Update app.py
Browse files
app.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
|
@@ -42,7 +44,7 @@ MODEL_ID = "Ali7880/multihead-content-moderator" # Change this!
|
|
| 42 |
|
| 43 |
# Download model files
|
| 44 |
checkpoint_path = hf_hub_download(MODEL_ID, "multihead_model.pt")
|
| 45 |
-
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
| 46 |
|
| 47 |
# Create and load model
|
| 48 |
model = MultiHeadContentModerator(
|
|
@@ -88,8 +90,8 @@ def moderate_image(image):
|
|
| 88 |
|
| 89 |
if flags:
|
| 90 |
verdict = "β UNSAFE - " + ", ".join(flags)
|
|
|
|
| 91 |
verdict = f"β
SAFE (Normal: {nsfw_results.get('normal', 0):.0%}, Safe: {violence_results.get('safe', 0):.0%})"
|
| 92 |
-
verdict = f"β
SAFE (NSFW: {nsfw_results.get('safe', 0):.0%}, Violence: {violence_results.get('safe', 0):.0%})"
|
| 93 |
|
| 94 |
return nsfw_results, violence_results, verdict
|
| 95 |
|
|
@@ -110,4 +112,4 @@ demo = gr.Interface(
|
|
| 110 |
)
|
| 111 |
|
| 112 |
if __name__ == "__main__":
|
| 113 |
-
demo.launch()
|
|
|
|
| 1 |
+
# Create a Gradio demo for the Multi-Head model
|
| 2 |
+
|
| 3 |
import gradio as gr
|
| 4 |
import torch
|
| 5 |
import torch.nn as nn
|
|
|
|
| 44 |
|
| 45 |
# Download model files
|
| 46 |
checkpoint_path = hf_hub_download(MODEL_ID, "multihead_model.pt")
|
| 47 |
+
checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
|
| 48 |
|
| 49 |
# Create and load model
|
| 50 |
model = MultiHeadContentModerator(
|
|
|
|
| 90 |
|
| 91 |
if flags:
|
| 92 |
verdict = "β UNSAFE - " + ", ".join(flags)
|
| 93 |
+
else:
|
| 94 |
verdict = f"β
SAFE (Normal: {nsfw_results.get('normal', 0):.0%}, Safe: {violence_results.get('safe', 0):.0%})"
|
|
|
|
| 95 |
|
| 96 |
return nsfw_results, violence_results, verdict
|
| 97 |
|
|
|
|
| 112 |
)
|
| 113 |
|
| 114 |
if __name__ == "__main__":
|
| 115 |
+
demo.launch()
|