# Create a Gradio demo for the Multi-Head model
import gradio as gr
import torch
import torch.nn as nn
from transformers import AutoModelForImageClassification, AutoImageProcessor
from huggingface_hub import hf_hub_download

# ========================================
# MODEL DEFINITION
# ========================================
class MultiHeadContentModerator(nn.Module):
    def __init__(self, base_model_name="Falconsai/nsfw_image_detection", num_violence_labels=2):
        super().__init__()
        original_model = AutoModelForImageClassification.from_pretrained(base_model_name)
        hidden_size = original_model.config.hidden_size

        # Shared ViT backbone, the original NSFW head, and a new violence head
        self.vit = original_model.vit
        self.nsfw_classifier = original_model.classifier
        self.violence_classifier = nn.Linear(hidden_size, num_violence_labels)

        # Falconsai uses: {0: 'normal', 1: 'nsfw'}
        self.nsfw_id2label = {0: 'normal', 1: 'nsfw'}
        self.violence_id2label = {0: 'safe', 1: 'violence'}

    def forward(self, pixel_values, task='both'):
        outputs = self.vit(pixel_values=pixel_values)
        pooled_output = outputs.last_hidden_state[:, 0]  # CLS token embedding

        if task == 'both':
            return {
                'nsfw': self.nsfw_classifier(pooled_output),
                'violence': self.violence_classifier(pooled_output)
            }
        elif task == 'nsfw':
            return self.nsfw_classifier(pooled_output)
        elif task == 'violence':
            return self.violence_classifier(pooled_output)

# ========================================
# LOAD MODEL
# ========================================
MODEL_ID = "Ali7880/multihead-content-moderator"  # Change this!

# Download model files
checkpoint_path = hf_hub_download(MODEL_ID, "multihead_model.pt")
checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

# Create and load model
model = MultiHeadContentModerator(
    base_model_name=checkpoint['base_model'],
    num_violence_labels=checkpoint['num_violence_labels']
)
model.load_state_dict(checkpoint['model_state_dict'])
model.violence_id2label = checkpoint['violence_id2label']
model.nsfw_id2label = checkpoint['nsfw_id2label']
model.eval()

processor = AutoImageProcessor.from_pretrained(MODEL_ID)

# ========================================
# INFERENCE FUNCTION
# ========================================
def moderate_image(image):
    if image is None:
        return None, None, "Please upload an image"

    # Preprocess
    inputs = processor(images=image, return_tensors="pt")

    # Predict
    with torch.no_grad():
        outputs = model(inputs['pixel_values'], task='both')
        nsfw_probs = torch.softmax(outputs['nsfw'], dim=-1).numpy()[0]
        violence_probs = torch.softmax(outputs['violence'], dim=-1).numpy()[0]

    # Format results
    nsfw_results = {model.nsfw_id2label[i]: float(p) for i, p in enumerate(nsfw_probs)}
    violence_results = {model.violence_id2label[i]: float(p) for i, p in enumerate(violence_probs)}

    # Falconsai: {0: 'normal', 1: 'nsfw'}, Violence: {0: 'safe', 1: 'violence'}
    is_nsfw = nsfw_probs.argmax() == 1      # 1 = nsfw
    is_violent = violence_probs.argmax() == 1  # 1 = violence

    flags = []
    if is_nsfw:
        flags.append(f"NSFW ({nsfw_results.get('nsfw', 0):.0%})")
    if is_violent:
        flags.append(f"Violence ({violence_results.get('violence', 0):.0%})")

    if flags:
        verdict = "❌ UNSAFE - " + ", ".join(flags)
    else:
        verdict = f"✅ SAFE (Normal: {nsfw_results.get('normal', 0):.0%}, Safe: {violence_results.get('safe', 0):.0%})"

    return nsfw_results, violence_results, verdict

# ========================================
# GRADIO INTERFACE
# ========================================
demo = gr.Interface(
    fn=moderate_image,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=[
        gr.Label(label="NSFW Detection", num_top_classes=2),
        gr.Label(label="Violence Detection", num_top_classes=2),
        gr.Textbox(label="Overall Verdict")
    ],
    title="🛡️ Multi-Head Content Moderator",
    description="Upload an image to check for NSFW and Violence content simultaneously."
)

if __name__ == "__main__":
    demo.launch()