Ali7880 committed · verified
Commit f83caf8 · 1 Parent(s): 0c32818

Create app.py

Files changed (1)
  1. app.py +113 -0
app.py ADDED
@@ -0,0 +1,113 @@
+ import gradio as gr
+ import torch
+ import torch.nn as nn
+ from transformers import AutoModelForImageClassification, AutoImageProcessor
+ from huggingface_hub import hf_hub_download
+
+ # ========================================
+ # MODEL DEFINITION
+ # ========================================
+ class MultiHeadContentModerator(nn.Module):
+     def __init__(self, base_model_name="Falconsai/nsfw_image_detection", num_violence_labels=2):
+         super().__init__()
+         original_model = AutoModelForImageClassification.from_pretrained(base_model_name)
+         hidden_size = original_model.config.hidden_size
+
+         self.vit = original_model.vit  # shared ViT backbone
+         self.nsfw_classifier = original_model.classifier  # pretrained NSFW head
+         self.violence_classifier = nn.Linear(hidden_size, num_violence_labels)  # new violence head
+
+         # Falconsai uses: {0: 'normal', 1: 'nsfw'}
+         self.nsfw_id2label = {0: 'normal', 1: 'nsfw'}
+         self.violence_id2label = {0: 'safe', 1: 'violence'}
+
+     def forward(self, pixel_values, task='both'):
+         outputs = self.vit(pixel_values=pixel_values)
+         pooled_output = outputs.last_hidden_state[:, 0]  # [CLS] token embedding
+
+         if task == 'both':
+             return {
+                 'nsfw': self.nsfw_classifier(pooled_output),
+                 'violence': self.violence_classifier(pooled_output)
+             }
+         elif task == 'nsfw':
+             return self.nsfw_classifier(pooled_output)
+         elif task == 'violence':
+             return self.violence_classifier(pooled_output)
+
+ # ========================================
+ # LOAD MODEL
+ # ========================================
+ MODEL_ID = "Ali7880/multihead-content-moderator"
+
+ # Download the multi-head checkpoint from the Hub
+ checkpoint_path = hf_hub_download(MODEL_ID, "multihead_model.pt")
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
+
+ # Create and load model
+ model = MultiHeadContentModerator(
+     base_model_name=checkpoint['base_model'],
+     num_violence_labels=checkpoint['num_violence_labels']
+ )
+ model.load_state_dict(checkpoint['model_state_dict'])
+ model.violence_id2label = checkpoint['violence_id2label']
+ model.nsfw_id2label = checkpoint['nsfw_id2label']
+ model.eval()
+
+ processor = AutoImageProcessor.from_pretrained(MODEL_ID)
+
+ # ========================================
+ # INFERENCE FUNCTION
+ # ========================================
+ def moderate_image(image):
+     if image is None:
+         return None, None, "Please upload an image"
+
+     # Preprocess
+     inputs = processor(images=image, return_tensors="pt")
+
+     # Run both heads in a single forward pass
+     with torch.no_grad():
+         outputs = model(inputs['pixel_values'], task='both')
+         nsfw_probs = torch.softmax(outputs['nsfw'], dim=-1).numpy()[0]
+         violence_probs = torch.softmax(outputs['violence'], dim=-1).numpy()[0]
+
+     # Format results as {label: probability} dicts for gr.Label
+     nsfw_results = {model.nsfw_id2label[i]: float(p) for i, p in enumerate(nsfw_probs)}
+     violence_results = {model.violence_id2label[i]: float(p) for i, p in enumerate(violence_probs)}
+
+     # Falconsai: {0: 'normal', 1: 'nsfw'}, Violence: {0: 'safe', 1: 'violence'}
+     is_nsfw = nsfw_probs.argmax() == 1  # 1 = nsfw
+     is_violent = violence_probs.argmax() == 1  # 1 = violence
+
+     flags = []
+     if is_nsfw:
+         flags.append(f"NSFW ({nsfw_results.get('nsfw', 0):.0%})")
+     if is_violent:
+         flags.append(f"Violence ({violence_results.get('violence', 0):.0%})")
+
+     if flags:
+         verdict = "❌ UNSAFE - " + ", ".join(flags)
+     else:
+         verdict = f"✅ SAFE (Normal: {nsfw_results.get('normal', 0):.0%}, Safe: {violence_results.get('safe', 0):.0%})"
+
+     return nsfw_results, violence_results, verdict
+
+ # ========================================
+ # GRADIO INTERFACE
+ # ========================================
+ demo = gr.Interface(
+     fn=moderate_image,
+     inputs=gr.Image(type="pil", label="Upload Image"),
+     outputs=[
+         gr.Label(label="NSFW Detection", num_top_classes=2),
+         gr.Label(label="Violence Detection", num_top_classes=2),
+         gr.Textbox(label="Overall Verdict")
+     ],
+     title="🛡️ Multi-Head Content Moderator",
+     description="Upload an image to check for NSFW and Violence content simultaneously.",
+     theme="default"
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
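
Note: the loading code above expects multihead_model.pt to be a plain torch.save dictionary containing the keys base_model, num_violence_labels, model_state_dict, nsfw_id2label, and violence_id2label. The training script is not part of this commit, so the following is only a minimal, hypothetical sketch of how a compatible checkpoint could be produced, assuming a MultiHeadContentModerator whose violence head has already been trained:

import torch

def save_checkpoint(model, path="multihead_model.pt"):
    # Hypothetical helper, not part of this commit: serializes the fields
    # that app.py reads back when it reconstructs the model.
    torch.save(
        {
            "base_model": "Falconsai/nsfw_image_detection",  # assumed backbone name
            "num_violence_labels": 2,
            "model_state_dict": model.state_dict(),
            "nsfw_id2label": model.nsfw_id2label,
            "violence_id2label": model.violence_id2label,
        },
        path,
    )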