| import torch |
| import torch.nn.functional as F |
| import gradio as gr |
| from PIL import Image |
| from torchvision import transforms |
| from cnn import CNN |
|
|
# Select the compute device: prefer CUDA, then Apple's MPS backend, and fall
# back to the CPU when neither accelerator is available.
if torch.cuda.is_available():
    _backend = "cuda"
elif torch.backends.mps.is_available():
    _backend = "mps"
else:
    _backend = "cpu"
device = torch.device(_backend)
|
|
# CIFAR-10 label names. Position i is the label reported for model output i,
# so the order must match the order the model was trained with.
classes = (
    "airplane automobile bird cat deer dog frog horse ship truck"
).split()
|
|
# Build the network and move it to the target device *before* loading the
# weights. load_state_dict copies values into the existing parameters, so
# loading straight onto `device` avoids a device -> CPU -> device round-trip.
model = CNN().to(device)
# weights_only=True restricts torch.load to tensors and plain containers.
# That is all a state_dict needs, and it avoids unpickling arbitrary objects
# from the checkpoint file.
model.load_state_dict(
    torch.load("cnn/model.pt", map_location=device, weights_only=True)
)
# Evaluation mode (affects layers such as dropout / batch-norm, if present).
model.eval()
|
|
# Preprocessing applied to every input image:
#   1. resize to 32x32 pixels,
#   2. convert to a float tensor (ToTensor scales 8-bit pixels to [0, 1]),
#   3. normalize each of the three channels as (x - 0.5) / 0.5, i.e. to [-1, 1].
_preprocessing_steps = [
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]
transform = transforms.Compose(_preprocessing_steps)
|
|
|
|
def predict(image):
    """Classify an image and return the probability of each class.

    Args:
        image: Image array as delivered by ``gr.Image(type="numpy")``, or
            ``None`` when no image has been provided.

    Returns:
        dict mapping each name in ``classes`` to its softmax probability as a
        Python float, or an empty dict when ``image`` is ``None``.
    """
    if image is None:
        return {}

    # Force a 3-channel RGB image whatever mode fromarray inferred, to match
    # the three-channel normalization in `transform`.
    pil_image = Image.fromarray(image).convert("RGB")
    # Add a leading batch dimension of size 1 and move to the model's device.
    batch = transform(pil_image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)
        # Softmax over the class dimension, keep the single sample, and pull
        # all scores to the host in one transfer rather than one per class.
        probabilities = F.softmax(logits, dim=1)[0].tolist()

    return dict(zip(classes, probabilities))
|
|
|
|
# Gradio UI: numpy image in, class-probability label out. The class count and
# the description are derived from `classes` so they cannot drift from the
# label list used by `predict`.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Label(num_top_classes=len(classes)),
    title="CNN Classifier",
    description=(
        f"Upload an image to classify it into one of {len(classes)} CIFAR-10 "
        f"categories: {', '.join(classes)}"
    ),
    # Assumes examples/1.png ... examples/7.png exist relative to the
    # working directory.
    examples=[[f"examples/{i}.png"] for i in range(1, 8)],
)
|
|
if __name__ == "__main__":
    # Launch only when run as a script, not on import.
    # share=True also requests a public Gradio share link alongside the local
    # server; pwa=True is Gradio's Progressive Web App option.
    demo.launch(pwa=True, share=True)
|
|