# RF-DETR_SAHI / app.py
import hashlib
import tempfile
from pathlib import Path
from typing import Any
import gradio as gr
import torch
from PIL import Image
from rfdetr.detr import (
RFDETR,
RFDETRBase,
RFDETRLarge,
RFDETRMedium,
RFDETRNano,
RFDETRSmall,
)
from rfdetr.util.coco_classes import COCO_CLASSES
from sahi import AutoDetectionModel
from sahi.predict import get_prediction, get_sliced_prediction
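# Local helper module (assumed to be kofi.py alongside this file) exposing the JS
# snippet used to render the Ko-fi widget in the footer.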
from kofi import SCRIPT
APP_DIR = Path(__file__).parent
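# Use the GPU when available; this device string is handed to SAHI's AutoDetectionModel below.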
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
EXAMPLES_DIR = APP_DIR / "examples"
HEADER = """# [RF-DETR](https://github.com/roboflow/rf-detr) + [SAHI](https://github.com/obss/sahi) 🔥"""
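# Each example row follows the `inputs` order of gr.Examples below:
# image, checkpoint, resolution, confidence, slice height, slice width, overlap height/width ratios.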
IMAGE_PROCESSING_EXAMPLES = [
[
Image.open(EXAMPLES_DIR / "xingchen-yan-uDn6y3jii0Q-unsplash.jpg"),
"medium",
896,
0.6,
300,
380,
0.2,
0.2,
],
]
def load_model(checkpoint: str, resolution: int) -> RFDETR:
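    """Instantiate the RF-DETR checkpoint selected in the UI at the requested inference resolution."""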
if checkpoint == "nano":
return RFDETRNano(resolution=resolution)
if checkpoint == "small":
return RFDETRSmall(resolution=resolution)
if checkpoint == "medium":
return RFDETRMedium(resolution=resolution)
if checkpoint == "base":
return RFDETRBase(resolution=resolution)
    if checkpoint == "large":
        return RFDETRLarge(resolution=resolution)
    raise ValueError(
        "checkpoint must be one of 'nano', 'small', 'medium', 'base', or 'large'"
    )
def extend_model(
model: Any,
confidence_threshold: float,
category_mapping: dict,
):
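    """Wrap an RF-DETR model in SAHI's AutoDetectionModel so it supports sliced inference."""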
model = AutoDetectionModel.from_pretrained(
model_type="roboflow",
model=model,
confidence_threshold=confidence_threshold,
category_mapping=category_mapping,
device=DEVICE,
)
return model
def run(
image_processing_input_image: Image.Image,
image_processing_checkpoint_dropdown: str,
image_processing_resolution_slider: int,
image_processing_confidence_slider: float,
slice_height: int,
slice_width: int,
overlap_height_ratio: float,
overlap_width_ratio: float,
):
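    """Run a plain RF-DETR prediction and a SAHI sliced prediction, returning both visualizations."""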
image_processing_input_image = image_processing_input_image.convert("RGB")
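    # Hash the pixel data to derive a deterministic file name for the temporary image.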
image_hash = hashlib.md5(image_processing_input_image.tobytes()).hexdigest()
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir = Path(temp_dir)
image_path = temp_dir / f"{image_hash}.jpg"
image_processing_input_image.save(str(image_path))
# Load model:
original_model = load_model(
checkpoint=image_processing_checkpoint_dropdown,
resolution=image_processing_resolution_slider,
)
# Extend model with SAHI:
model = extend_model(
model=original_model,
confidence_threshold=image_processing_confidence_slider,
category_mapping=COCO_CLASSES,
)
# Run original model prediction
prediction = get_prediction(
str(image_path),
model,
)
original_filename = f"{image_path.stem}_prediction"
prediction.export_visuals(
export_dir=str(temp_dir),
file_name=original_filename,
)
original_path = temp_dir / f"{original_filename}.png"
original_pil = Image.open(original_path)
# Run sliced model prediction
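        # SAHI tiles the image into overlapping slices, runs the detector on each slice,
        # and merges the per-slice detections back into full-image coordinates.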
prediction_sliced = get_sliced_prediction(
str(image_path),
model,
slice_width=slice_width,
slice_height=slice_height,
overlap_width_ratio=overlap_width_ratio,
overlap_height_ratio=overlap_height_ratio,
postprocess_match_threshold=image_processing_confidence_slider,
)
        sliced_filename = f"{image_path.stem}_sliced_prediction"
        prediction_sliced.export_visuals(
            export_dir=str(temp_dir),
            file_name=sliced_filename,
        )
        sliced_path = temp_dir / f"{sliced_filename}.png"
sliced_pil = Image.open(sliced_path)
return original_pil, sliced_pil
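# Gradio UI: input/output images, RF-DETR and SAHI controls, examples, and a Ko-fi footer.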
with gr.Blocks(js=SCRIPT) as demo:
gr.Markdown(HEADER)
with gr.Row():
with gr.Column():
gr.Markdown("## Input")
image_processing_input_image = gr.Image(
label="Original Image",
image_mode="RGB",
type="pil",
height=600,
)
with gr.Column():
gr.Markdown("## Output")
image_processing_output_image = gr.ImageSlider(
label="Original vs Sliced Prediction",
image_mode="RGB",
type="pil",
height=600,
)
with gr.Row():
with gr.Column():
gr.Markdown("## RF-DETR Configuration")
image_processing_confidence_slider = gr.Slider(
label="Confidence",
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.5,
)
image_processing_resolution_dropdown = gr.Dropdown(
label="Inference Resolution (dividable by 32 and 56)",
choices=[224, 448, 672, 896, 1008, 1120, 1344, 1568, 1792, 2016, 2240],
value=896,
)
image_processing_checkpoint_dropdown = gr.Dropdown(
label="Model Size",
choices=["nano", "small", "medium", "base", "large"],
value="base",
)
with gr.Column():
gr.Markdown("## SAHI Configuration")
slice_width = gr.Slider(
label="Slice Width",
minimum=100,
maximum=500,
step=1,
value=224,
)
slice_height = gr.Slider(
label="Slice Height",
minimum=100,
maximum=500,
step=1,
value=224,
)
overlap_width_ratio = gr.Slider(
label="Overlap Width Ratio",
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.2,
)
overlap_height_ratio = gr.Slider(
label="Overlap Height Ratio",
minimum=0.0,
maximum=1.0,
step=0.05,
value=0.2,
)
with gr.Row():
with gr.Column():
image_processing_submit_button = gr.Button("Run")
with gr.Row():
with gr.Column():
gr.Markdown("## Examples")
gr.Examples(
fn=run,
examples=IMAGE_PROCESSING_EXAMPLES,
inputs=[
image_processing_input_image,
image_processing_checkpoint_dropdown,
image_processing_resolution_dropdown,
image_processing_confidence_slider,
slice_height,
slice_width,
overlap_height_ratio,
overlap_width_ratio,
],
outputs=[image_processing_output_image],
run_on_click=True,
cache_examples=False,
cache_mode="eager",
)
with gr.Row():
with gr.Column():
gr.HTML('<div id="kofi" style="text-align: center;"></div>')
image_processing_submit_button.click(
fn=run,
inputs=[
image_processing_input_image,
image_processing_checkpoint_dropdown,
image_processing_resolution_dropdown,
image_processing_confidence_slider,
slice_height,
slice_width,
overlap_height_ratio,
overlap_width_ratio,
],
outputs=[image_processing_output_image],
)
if __name__ == "__main__":
demo.launch(
debug=False,
)