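"""Gradio demo for RF-DETR + SAHI.

Runs a full-image RF-DETR prediction and a SAHI sliced prediction on the same
image and shows the two visualizations side by side in an image slider. Model
checkpoint, inference resolution, confidence threshold, and SAHI slicing
parameters are configurable from the UI.
"""
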
import hashlib
import tempfile
from pathlib import Path
from typing import Any

import gradio as gr
import torch
from PIL import Image
from rfdetr.detr import (
    RFDETR,
    RFDETRBase,
    RFDETRLarge,
    RFDETRMedium,
    RFDETRNano,
    RFDETRSmall,
)
from rfdetr.util.coco_classes import COCO_CLASSES
from sahi import AutoDetectionModel
from sahi.predict import get_prediction, get_sliced_prediction

from kofi import SCRIPT
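# `kofi` is assumed to be a small local module shipped alongside this app that
# exposes SCRIPT, the JavaScript passed to `gr.Blocks(js=...)` to render the
# Ko-fi widget into the `#kofi` placeholder defined at the bottom of the UI.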

APP_DIR = Path(__file__).parent
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
EXAMPLES_DIR = APP_DIR / "examples"

HEADER = """# [RF-DETR](https://github.com/roboflow/rf-detr) + [SAHI](https://github.com/obss/sahi) 🔥"""

IMAGE_PROCESSING_EXAMPLES = [
    [
        Image.open(EXAMPLES_DIR / "xingchen-yan-uDn6y3jii0Q-unsplash.jpg"),
        "medium",
        896,
        0.6,
        300,
        380,
        0.2,
        0.2,
    ],
]
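# Each example row follows the `gr.Examples` input order used below: input image,
# checkpoint, inference resolution, confidence threshold, slice height, slice
# width, overlap height ratio, overlap width ratio.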


def load_model(checkpoint: str, resolution: int) -> RFDETR:
    """Instantiate the RF-DETR checkpoint selected in the UI."""
    if checkpoint == "nano":
        return RFDETRNano(resolution=resolution)
    if checkpoint == "small":
        return RFDETRSmall(resolution=resolution)
    if checkpoint == "medium":
        return RFDETRMedium(resolution=resolution)
    if checkpoint == "base":
        return RFDETRBase(resolution=resolution)
    if checkpoint == "large":
        return RFDETRLarge(resolution=resolution)
    raise ValueError(
        "checkpoint must be one of 'nano', 'small', 'medium', 'base', or 'large'"
    )


def extend_model(
    model: Any,
    confidence_threshold: float,
    category_mapping: dict,
):
    """Wrap an RF-DETR model in SAHI's AutoDetectionModel for sliced inference."""
    return AutoDetectionModel.from_pretrained(
        model_type="roboflow",
        model=model,
        confidence_threshold=confidence_threshold,
        category_mapping=category_mapping,
        device=DEVICE,
    )


def run(
    image_processing_input_image: Image.Image,
    image_processing_checkpoint_dropdown: str,
    image_processing_resolution_slider: int,
    image_processing_confidence_slider: float,
    slice_height: int,
    slice_width: int,
    overlap_height_ratio: float,
    overlap_width_ratio: float,
):
    image_processing_input_image = image_processing_input_image.convert("RGB")
    image_hash = hashlib.md5(image_processing_input_image.tobytes()).hexdigest()

    with tempfile.TemporaryDirectory() as temp_dir:
        temp_dir = Path(temp_dir)
        image_path = temp_dir / f"{image_hash}.jpg"
        image_processing_input_image.save(str(image_path))

        # Load model:
        original_model = load_model(
            checkpoint=image_processing_checkpoint_dropdown,
            resolution=image_processing_resolution_slider,
        )

        # Extend model with SAHI:
        model = extend_model(
            model=original_model,
            confidence_threshold=image_processing_confidence_slider,
            category_mapping=COCO_CLASSES,
        )

        # Run original model prediction
        prediction = get_prediction(
            str(image_path),
            model,
        )
        original_filename = f"{image_path.stem}_prediction"
        prediction.export_visuals(
            export_dir=str(temp_dir),
            file_name=original_filename,
        )
        original_path = temp_dir / f"{original_filename}.png"
        original_pil = Image.open(original_path)
        # Force PIL to read the pixel data before the temporary directory is removed.
        original_pil.load()

        # Run sliced model prediction
        prediction_sliced = get_sliced_prediction(
            str(image_path),
            model,
            slice_width=slice_width,
            slice_height=slice_height,
            overlap_width_ratio=overlap_width_ratio,
            overlap_height_ratio=overlap_height_ratio,
            postprocess_match_threshold=image_processing_confidence_slider,
        )
        sliced_filename = f"{image_path.stem}_sliced_prediction"
        prediction_sliced.export_visuals(
            export_dir=str(temp_dir),
            file_name=sliced_filename,
        )
        sliced_path = temp_dir / f"{sliced_filename}.png"
        sliced_pil = Image.open(sliced_path)
        sliced_pil.load()

    return original_pil, sliced_pil


with gr.Blocks(js=SCRIPT) as demo:
    gr.Markdown(HEADER)
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Input")
            image_processing_input_image = gr.Image(
                label="Original Image",
                image_mode="RGB",
                type="pil",
                height=600,
            )
        with gr.Column():
            gr.Markdown("## Output")
            image_processing_output_image = gr.ImageSlider(
                label="Original vs Sliced Prediction",
                image_mode="RGB",
                type="pil",
                height=600,
            )
    with gr.Row():
        with gr.Column():
            gr.Markdown("## RF-DETR Configuration")
            image_processing_confidence_slider = gr.Slider(
                label="Confidence",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.5,
            )
            image_processing_resolution_dropdown = gr.Dropdown(
                label="Inference Resolution (must be divisible by 56)",
                choices=[224, 448, 672, 896, 1008, 1120, 1344, 1568, 1792, 2016, 2240],
                value=896,
            )
            image_processing_checkpoint_dropdown = gr.Dropdown(
                label="Model Size",
                choices=["nano", "small", "medium", "base", "large"],
                value="base",
            )
        with gr.Column():
            gr.Markdown("## SAHI Configuration")
            slice_width = gr.Slider(
                label="Slice Width",
                minimum=100,
                maximum=500,
                step=1,
                value=224,
            )
            slice_height = gr.Slider(
                label="Slice Height",
                minimum=100,
                maximum=500,
                step=1,
                value=224,
            )
            overlap_width_ratio = gr.Slider(
                label="Overlap Width Ratio",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.2,
            )
            overlap_height_ratio = gr.Slider(
                label="Overlap Height Ratio",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.2,
            )
    with gr.Row():
        with gr.Column():
            image_processing_submit_button = gr.Button("Run")
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Examples")
            gr.Examples(
                fn=run,
                examples=IMAGE_PROCESSING_EXAMPLES,
                inputs=[
                    image_processing_input_image,
                    image_processing_checkpoint_dropdown,
                    image_processing_resolution_dropdown,
                    image_processing_confidence_slider,
                    slice_height,
                    slice_width,
                    overlap_height_ratio,
                    overlap_width_ratio,
                ],
                outputs=[image_processing_output_image],
                run_on_click=True,
                cache_examples=False,
                cache_mode="eager",
            )
    with gr.Row():
        with gr.Column():
            gr.HTML('<div id="kofi" style="text-align: center;"></div>')

    image_processing_submit_button.click(
        fn=run,
        inputs=[
            image_processing_input_image,
            image_processing_checkpoint_dropdown,
            image_processing_resolution_dropdown,
            image_processing_confidence_slider,
            slice_height,
            slice_width,
            overlap_height_ratio,
            overlap_width_ratio,
        ],
        outputs=[image_processing_output_image],
    )


if __name__ == "__main__":
    demo.launch(
        debug=False,
    )
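
# To try this demo outside the Space (a sketch; exact package versions are not
# pinned here): install gradio, torch, pillow, rfdetr, and sahi, place the
# example image under ./examples and the local `kofi` module next to this file,
# then run `python app.py`.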