Spaces:
Running
Running
| import gradio as gr | |
| from Models import VisionModel | |
| import huggingface_hub | |
| from PIL import Image | |
| import torch.amp.autocast_mode | |
| from pathlib import Path | |
# Hugging Face Hub repository that is downloaded at startup; it provides the
# model loaded by VisionModel.load_model and the top_tags.txt tag list.
MODEL_REPO: str = "fancyfeast/joytag"
def predict(image: Image.Image):
    """Run JoyTag on one image and return per-tag scores for ``gr.Label``.

    Relies on the module-level ``model`` and ``top_tags``. Both are created
    after this function is defined but before the Gradio server is launched.

    Args:
        image: The image delivered by the ``gr.Image(type='pil')`` input
            component.

    Returns:
        A dict mapping each tag name in ``top_tags`` to its sigmoid score as a
        plain Python float. Entry ``i`` of the model's ``'tags'`` output is
        paired with ``top_tags[i]``.
    """
    # Inference only: no_grad stops autograd from recording the forward pass,
    # which would otherwise keep intermediate activations alive for nothing.
    with torch.no_grad(), torch.amp.autocast_mode.autocast('cuda', enabled=True):
        preds = model(image)
    # NOTE(review): the raw PIL image is passed straight to the model. Confirm
    # that VisionModel does its own preprocessing and that preds['tags'] is 1-D.
    # A leading batch dimension would make the per-index lookup below wrong.
    # .float() undoes any reduced precision from autocast. A single .tolist()
    # turns the whole tensor into Python floats in one transfer, instead of
    # creating one 0-d tensor per tag.
    scores = preds['tags'].sigmoid().float().cpu().tolist()
    return {tag: scores[i] for i, tag in enumerate(top_tags)}
print("Downloading model...")
path = huggingface_hub.snapshot_download(MODEL_REPO)
print("Loading model...")
model = VisionModel.load_model(path)
model.eval()
# NOTE(review): predict() runs under CUDA autocast, but nothing here visibly
# moves the model to a device. Confirm that VisionModel.load_model places it
# on the GPU.

# Read with an explicit encoding so tag names decode the same way on every
# platform. The default encoding depends on the locale (e.g. cp1252 on Windows).
with open(Path(path) / 'top_tags.txt', 'r', encoding='utf-8') as f:
    # One tag per line, with blank lines skipped. The i-th tag names the i-th
    # score in predict()'s output.
    top_tags = [tag for tag in (line.strip() for line in f) if tag]
print("Starting server...")
# Web UI. The input takes a file upload or a webcam capture and hands it to
# predict() as a PIL image. The output label widget shows the five
# highest-scoring tags.
gradio_app = gr.Interface(
    fn=predict,
    inputs=gr.Image(
        label="Source",
        sources=['upload', 'webcam'],
        type='pil',
    ),
    outputs=[
        gr.Label(label="Result", num_top_classes=5),
    ],
    title="JoyTag",
)
if __name__ == "__main__":
    # Start the web server only when this file is executed as a script,
    # not when it is imported.
    gradio_app.launch()