import gradio as gr
import torch
import torch.nn as nn
from transformers import AutoModelForImageClassification, AutoImageProcessor
from huggingface_hub import hf_hub_download

# ========================================
# MODEL DEFINITION
# ========================================
class MultiHeadContentModerator(nn.Module):
    def __init__(self, base_model_name="Falconsai/nsfw_image_detection", num_violence_labels=2):
        super().__init__()
        original_model = AutoModelForImageClassification.from_pretrained(base_model_name)
        hidden_size = original_model.config.hidden_size
        
        self.vit = original_model.vit
        self.nsfw_classifier = original_model.classifier
        self.violence_classifier = nn.Linear(hidden_size, num_violence_labels)
        
        # Falconsai uses: {0: 'normal', 1: 'nsfw'}
        self.nsfw_id2label = {0: 'normal', 1: 'nsfw'}
        self.violence_id2label = {0: 'safe', 1: 'violence'}
        
    def forward(self, pixel_values, task='both'):
        outputs = self.vit(pixel_values=pixel_values)
        pooled_output = outputs.last_hidden_state[:, 0]  # CLS token embedding
        
        if task == 'both':
            return {
                'nsfw': self.nsfw_classifier(pooled_output),
                'violence': self.violence_classifier(pooled_output)
            }
        elif task == 'nsfw':
            return self.nsfw_classifier(pooled_output)
        elif task == 'violence':
            return self.violence_classifier(pooled_output)

# ========================================
# LOAD MODEL
# ========================================
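# Note (inferred from the loading code below): "multihead_model.pt" is expected to be
# a plain dict saved with torch.save, containing the keys 'base_model',
# 'num_violence_labels', 'model_state_dict', 'violence_id2label', and 'nsfw_id2label'.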
MODEL_ID = "Ali7880/multihead-content-moderator"  # HF Hub repo with the multi-head checkpoint and processor config

# Download model files
checkpoint_path = hf_hub_download(MODEL_ID, "multihead_model.pt")
checkpoint = torch.load(checkpoint_path, map_location='cpu')

# Create and load model
model = MultiHeadContentModerator(
    base_model_name=checkpoint['base_model'],
    num_violence_labels=checkpoint['num_violence_labels']
)
model.load_state_dict(checkpoint['model_state_dict'])
model.violence_id2label = checkpoint['violence_id2label']
model.nsfw_id2label = checkpoint['nsfw_id2label']
model.eval()

processor = AutoImageProcessor.from_pretrained(MODEL_ID)

# ========================================
# INFERENCE FUNCTION
# ========================================
def moderate_image(image):
    if image is None:
        return None, None, "Please upload an image"
    
    # Preprocess
    inputs = processor(images=image, return_tensors="pt")
    
    # Predict
    with torch.no_grad():
        outputs = model(inputs['pixel_values'], task='both')
        nsfw_probs = torch.softmax(outputs['nsfw'], dim=-1).numpy()[0]
        violence_probs = torch.softmax(outputs['violence'], dim=-1).numpy()[0]
    
    # Format results
    nsfw_results = {model.nsfw_id2label[i]: float(p) for i, p in enumerate(nsfw_probs)}
    violence_results = {model.violence_id2label[i]: float(p) for i, p in enumerate(violence_probs)}
    
    # Falconsai: {0: 'normal', 1: 'nsfw'}, Violence: {0: 'safe', 1: 'violence'}
    is_nsfw = nsfw_probs.argmax() == 1  # 1 = nsfw
    is_violent = violence_probs.argmax() == 1  # 1 = violence
    
    flags = []
    if is_nsfw:
        flags.append(f"NSFW ({nsfw_results.get('nsfw', 0):.0%})")
    if is_violent:
        flags.append(f"Violence ({violence_results.get('violence', 0):.0%})")
    
    if flags:
        verdict = "❌ UNSAFE - " + ", ".join(flags)
    else:
        verdict = f"✅ SAFE (Normal: {nsfw_results.get('normal', 0):.0%}, Safe: {violence_results.get('safe', 0):.0%})"
    
    return nsfw_results, violence_results, verdict
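# Quick local sanity check (illustrative; "test.jpg" is a placeholder path):
#   from PIL import Image
#   print(moderate_image(Image.open("test.jpg")))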

# ========================================
# GRADIO INTERFACE
# ========================================
demo = gr.Interface(
    fn=moderate_image,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=[
        gr.Label(label="NSFW Detection", num_top_classes=2),
        gr.Label(label="Violence Detection", num_top_classes=2),
        gr.Textbox(label="Overall Verdict")
    ],
    title="🛡️ Multi-Head Content Moderator",
    description="Upload an image to check for NSFW and Violence content simultaneously.",
    theme="default"
)

if __name__ == "__main__":
    demo.launch()