| | import gradio as gr |
| | import torch |
| | import torch.nn.functional as F |
| | import torchvision.transforms as transforms |
| | from PIL import Image |
| | import os |
| | from ResNet_for_CC import CC_model |
| |
|
| | |
# Run on the first CUDA device when PyTorch can see one; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
| |
|
| | |
# Checkpoint path is relative, so it is resolved against the current working directory.
model_path = "CC_net.pt"
# 14 output classes, one per entry in the label list used for display.
model = CC_model(num_classes1=14)

# map_location=device lets a checkpoint saved on another device (e.g. a GPU) be
# loaded onto whatever device was selected for this process.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
state_dict = torch.load(model_path, map_location=device)

# strict=False tolerates key mismatches (presumably the checkpoint holds weights for
# parts of CC_model not needed here — TODO confirm), but it would also silently leave
# any unmatched parameter at its random initialization. Report mismatches instead of
# hiding them so a wrong/corrupt checkpoint is noticed at startup.
_load_result = model.load_state_dict(state_dict, strict=False)
if _load_result.missing_keys:
    print(
        f"[WARNING] {len(_load_result.missing_keys)} model parameter(s) not found in "
        f"checkpoint and left uninitialized: {_load_result.missing_keys}"
    )
if _load_result.unexpected_keys:
    print(
        f"[WARNING] {len(_load_result.unexpected_keys)} checkpoint entr(y/ies) ignored "
        f"by the model: {_load_result.unexpected_keys}"
    )

model.to(device)
# Inference mode: disables dropout and uses running batch-norm statistics.
model.eval()
| |
|
| | |
# Display names for the classifier outputs: entry i names model output index i,
# so this order must match the label order the model was trained with.
class_labels = [
    "T-Shirt",      # 0
    "Shirt",        # 1
    "Knitwear",     # 2
    "Chiffon",      # 3
    "Sweater",      # 4
    "Hoodie",       # 5
    "Windbreaker",  # 6
    "Jacket",       # 7
    "Downcoat",     # 8
    "Suit",         # 9
    "Shawl",        # 10
    "Dress",        # 11
    "Vest",         # 12
    "Underwear",    # 13
]
| |
|
| | |
# Sample images bundled with the app, keyed by the name shown in the dropdown.
# Paths are relative to the current working directory.
# NOTE(review): the "Vest" sample points at "dress.jpg" — confirm the file really shows a vest.
default_images = dict(
    Shawl="shawlOG.webp",
    Jacket="jacket.jpg",
    Sweater="sweater.webp",
    Vest="dress.jpg",
)

# Gallery items as (image path, caption) pairs, i.e. each (label, path) entry flipped.
default_images_gallery = list(zip(default_images.values(), default_images.keys()))
| |
|
| | |
# Built once at import time instead of on every call; the pipeline holds no per-image state.
_PREPROCESS = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # The commonly used ImageNet per-channel mean/std (3 channels, RGB order).
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image):
    """Turn a PIL image into a normalized single-image batch on ``device``.

    The image is first forced to 3-channel RGB. Grayscale ("L"), palette ("P")
    or alpha ("RGBA", common for .webp/.png files) inputs would otherwise yield
    a tensor whose channel count does not match the 3-element mean/std passed to
    ``Normalize``, and normalization would fail.

    Args:
        image: A ``PIL.Image.Image`` in any mode.

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)`` on ``device``.
    """
    return _PREPROCESS(image.convert("RGB")).unsqueeze(0).to(device)
| |
|
| | |
def classify_image(selected_default, uploaded_image):
    """Classify the uploaded image, or else the chosen bundled sample.

    Args:
        selected_default: Key into ``default_images``; used only when no image
            was uploaded.
        uploaded_image: The uploaded image as a numpy array (as delivered by
            ``gr.Image(type="numpy")``), or ``None`` when nothing was uploaded.
            Takes precedence over ``selected_default``.

    Returns:
        A human-readable result string. Failures are returned as a message
        rather than raised, so they appear in the output textbox.
    """
    # Guard: a cleared dropdown with no upload would otherwise surface as the
    # cryptic "Error in classification: None" (a KeyError on None).
    if uploaded_image is None and selected_default is None:
        return "Please upload an image or select a default image."

    try:
        if uploaded_image is not None:
            # Normalize to RGB so the result does not depend on the source image mode.
            image = Image.fromarray(uploaded_image).convert("RGB")
        else:
            # The context manager closes the file handle; convert() returns a new,
            # fully loaded image that remains valid after the file is closed.
            with Image.open(default_images[selected_default]) as source:
                image = source.convert("RGB")

        batch = preprocess_image(image)

        with torch.no_grad():
            output = model(batch)

        # NOTE(review): the model may return a tuple; element 1 is assumed to hold
        # the class logits — confirm against ResNet_for_CC.CC_model.
        if isinstance(output, tuple):
            output = output[1]

        probabilities = F.softmax(output, dim=1)
        predicted_class = torch.argmax(probabilities, dim=1).item()

        # Defensive: the label list only covers the indices it names.
        if 0 <= predicted_class < len(class_labels):
            predicted_label = class_labels[predicted_class]
            confidence = probabilities[0][predicted_class].item() * 100
            return f"Predicted Class: {predicted_label} (Confidence: {confidence:.2f}%)"
        return "[ERROR] Model returned an invalid class index."

    except Exception as e:
        # Top-level UI boundary: show the message in the textbox, but also log the
        # full traceback to the server console so failures stay diagnosable.
        import traceback

        traceback.print_exc()
        return f"Error in classification: {e}"
| |
|
| | |
# Single-page UI: sample gallery, sample picker, optional upload, result box and
# a button. When an image is uploaded, it is classified instead of the dropdown choice.
with gr.Blocks() as interface:
    gr.Markdown("# Clothing1M Image Classifier")
    gr.Markdown("Upload a clothing image or select from the predefined images below.")

    # Preview of the bundled samples; no event is attached to it here, so the
    # sample to classify is chosen with the dropdown below.
    gallery = gr.Gallery(
        label="Default Images",
        elem_id="default_gallery",
        value=default_images_gallery,
    )

    default_selector = gr.Dropdown(
        label="Select a Default Image",
        value="Shawl",
        choices=list(default_images),
    )

    image_upload = gr.Image(label="Or Upload Your Own Image", type="numpy")

    output_text = gr.Textbox(label="Classification Result")

    classify_button = gr.Button("Classify Image")

    # Inputs are passed positionally as classify_image(selected_default, uploaded_image).
    classify_button.click(
        outputs=output_text,
        inputs=[default_selector, image_upload],
        fn=classify_image,
    )
| |
|
| | |
def main():
    """Announce startup and serve the Gradio classifier UI."""
    print("[INFO] Launching Gradio interface...")
    interface.launch()


if __name__ == "__main__":
    main()
| |
|