| | import gradio as gr |
| | import requests |
| | import random |
| | from src.classification_model import ClassificationModel |
| | from src.util.extract import extract_image_urls |
| |
|
| | |
| | |
| | |
| |
|
# Cap on how many input images one request will process; the UI also
# pre-allocates this many image slots.
max_input_image = 10

print('start...')

# Build the classifier once at import time and cache the model names it offers.
clf = ClassificationModel()
model_names = clf.get_model_names()

# Output components; both lists are filled while the Blocks layout is built
# and are then used as the click handler's outputs.
output_labels, output_images = [], []
| |
|
def predict(models, img_url, img_files):
    """Classify every input image with every selected model.

    Args:
        models: Model names chosen in the dropdown. ``None`` or an empty
            list means no model is run and every output stays hidden.
        img_url: Raw text of the URL box; handed to ``extract_image_urls``.
        img_files: Files from the upload widget, or ``None`` when nothing
            was uploaded.

    Returns:
        A dict mapping each component in ``output_labels`` and
        ``output_images`` to the ``gr.Label`` / ``gr.Image`` update that
        should be applied to it.
    """
    print(f'model choosen: {models}')
    model_predictions = {}
    # Nothing selected may arrive as None; treat it like an empty selection
    # instead of failing in enumerate().
    selected_models = models or []

    # Hide every output first so results of a previous run do not linger.
    # Label slots are laid out as len(model_names) consecutive slots per
    # image, so the slot index tells which model name the slot started with.
    for slot, label in enumerate(output_labels):
        slot_name = model_names[slot % len(model_names)]
        model_predictions[label] = gr.Label(label=f'# {slot_name}', visible=False)

    for img in output_images:
        model_predictions[img] = gr.Image(visible=False)

    sources = extract_image_urls(img_url) + (img_files or [])
    for i, source in enumerate(sources):
        print(f'{i} type: {type(source)} --> {source}')
        # Only max_input_image image slots exist; ignore anything beyond.
        if i >= max_input_image:
            break

        for j, m in enumerate(selected_models):
            results = clf.classify(m, source)
            print(f'{m} --> {results}')

            # The j-th selected model takes the j-th label slot of image i.
            idx = j + (len(model_names) * i)
            label_value = {raw.class_name: raw.confidence for raw in results}
            # NOTE(review): "3 seconds" is a fixed string, not a measured
            # duration -- confirm whether real timing was intended.
            model_predictions[output_labels[idx]] = gr.Label(
                label=f'# {m}, 3 seconds', value=label_value, visible=True)
            model_predictions[output_images[i]] = gr.Image(
                visible=True, value=source, label=f'image {i}')

    return model_predictions
| |
|
# Layout: inputs on the left, one image slot plus one label slot per model
# for each of the max_input_image possible inputs on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Image Classification Benchmark")
    # Derive the limit shown to the user from the constant that enforces it.
    gr.Markdown(f"You can input at maximum {max_input_image} images at once (urls or files)")

    with gr.Row():
        with gr.Column(scale=1):
            model = gr.Dropdown(choices=model_names, multiselect=True, label='Choose the model')
            img_urls = gr.Textbox(label='Image Urls (separated with comma)')
            img_files = gr.File(label='Upload Files', file_count='multiple', file_types=['image'])
            apply = gr.Button("Classify", variant='primary')
        with gr.Column(scale=1):
            # Only the first image's slots start visible; predict() toggles
            # the rest. Labels are appended as len(model_names) consecutive
            # slots per image, which is the layout predict() indexes into.
            for i in range(max_input_image):
                output_images.append(gr.Image(interactive=False, visible=(i == 0)))
                # Iterate the cached model_names (not a fresh
                # clf.get_model_names() call) so the slot count always matches
                # the len(model_names) used for indexing in predict().
                for name in model_names:
                    output_labels.append(gr.Label(label=f'# {name}', visible=(i == 0)))

    apply.click(fn=predict,
                inputs=[model, img_urls, img_files],
                outputs=output_images + output_labels)

demo.queue().launch()