Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| from PIL import Image, ImageDraw, ImageFont | |
| from transformers import AutoImageProcessor | |
| from transformers import AutoModelForObjectDetection | |
# The checkpoint can be a Hugging Face Hub repo id or a local directory.
# If the model lives on your own Hugging Face account, swap "mrdbourke" for your username.
model_save_path = "mrdbourke/detr_finetuned_trashify_box_detector_with_data_aug"

# Prefer the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the preprocessor and the detection model, placing the model on the chosen device
image_processor = AutoImageProcessor.from_pretrained(model_save_path)
model = AutoModelForObjectDetection.from_pretrained(model_save_path).to(device)

# Class id -> class name mapping, as stored in the model config
id2label = model.config.id2label
# Outline colour used for each class name when plotting boxes
color_dict = dict(
    bin="green",
    trash="blue",
    hand="purple",
    trash_arm="yellow",
)
# Every "not_*" class shares the same colour
color_dict.update(dict.fromkeys(("not_trash", "not_bin", "not_hand"), "red"))
| # Create helper functions for seeing if items from one list are in another | |
def any_in_list(list_a, list_b):
    """Return True when at least one element of list_a also appears in list_b, else False."""
    for candidate in list_a:
        if candidate in list_b:
            return True
    return False
def all_in_list(list_a, list_b):
    """Return True when every element of list_a also appears in list_b, else False."""
    for candidate in list_a:
        if candidate not in list_b:
            return False
    return True
def filter_highest_scoring_box_per_class(boxes, labels, scores):
    """
    Keep only the single highest scoring box for each class.

    Note: this is a simple per-class "top-1" filter rather than IoU-based
    Non-Max Suppression; box overlap is never inspected.

    Args:
        boxes: tensor of shape (N, 4)
        labels: tensor of shape (N,)
        scores: tensor of shape (N,)

    Returns:
        boxes: tensor of shape (K, 4), one box per unique class in `labels`
        labels: tensor of shape (K,)
        scores: tensor of shape (K,)

        K is the number of unique classes (0 when the inputs are empty).
        The kept items stay in their original relative order.
    """
    # Start with an all-False keep mask on the same device as the inputs,
    # then flip the winning index of every class to True
    keep_mask = torch.zeros(len(boxes), dtype=torch.bool, device=labels.device)

    # labels.unique() only yields ids that occur in `labels`, so every class
    # is guaranteed to have at least one candidate box
    for class_id in labels.unique():
        # Positions (in the original tensors) of the boxes of this class
        class_indices = torch.where(labels == class_id)[0]

        # argmax is relative to the class subset, so map it back to the original index
        best_idx = class_indices[scores[class_indices].argmax()]
        keep_mask[best_idx] = True

    return boxes[keep_mask], labels[keep_mask], scores[keep_mask]
def create_return_string(list_of_predicted_labels, target_items=("trash", "bin", "hand"), conf_threshold=None):
    """
    Build the text message shown to the user for a list of predicted labels.

    Args:
        list_of_predicted_labels: class-name strings predicted for one image.
        target_items: class names that must ALL be present to earn "+1".
        conf_threshold: confidence threshold the predictions were made with.
            Only used to word the "nothing detected" message; when omitted the
            message refers to "the current confidence threshold" instead.

    Returns:
        One of three messages: nothing relevant detected, some target items
        missing, or "+1!" when every target item was detected.
    """
    # Target items that did not show up in the predictions
    missing_items = [item for item in target_items if item not in list_of_predicted_labels]

    # None of the target items were detected (this also covers an empty prediction list)
    if len(missing_items) == len(target_items):
        if conf_threshold is None:
            threshold_text = "the current confidence threshold"
        else:
            threshold_text = f"confidence threshold {conf_threshold}"
        return f"No trash, bin or hand detected at {threshold_text}. Try another image or lowering the confidence threshold."

    if missing_items:
        # Some (but not all) target items were detected, report which ones are missing
        return_string = f"Detected the following items: {list_of_predicted_labels} (total: {len(list_of_predicted_labels)}). But missing the following in order to get +1: {missing_items}. If this is an error, try another image or altering the confidence threshold. Otherwise, the model may need to be updated with better data."
    else:
        # All target items occur = +1
        return_string = f"+1! Found the following items: {list_of_predicted_labels} (total: {len(list_of_predicted_labels)}), thank you for cleaning up the area!"

    print(return_string)
    return return_string


def _draw_detections(image, boxes, labels, scores, font):
    """Draw labelled boxes onto `image` in place and return the class names drawn (in order)."""
    draw = ImageDraw.Draw(image)
    label_names = []

    for box, label, score in zip(boxes, labels, scores):
        # Box corner coordinates
        x, y, x2, y2 = box.tolist()

        # Map the class id to its name and plotting colour
        label_name = id2label[label.item()]
        targ_color = color_dict[label_name]
        label_names.append(label_name)

        # Draw the rectangle
        draw.rectangle(xy=(x, y, x2, y2),
                       outline=targ_color,
                       width=3)

        # Draw "<class name> (<score>)" at the top-left corner of the box
        draw.text(xy=(x, y),
                  text=f"{label_name} ({round(score.item(), 3)})",
                  fill="white",
                  font=font)

    return label_names


def predict_on_image(image, conf_threshold):
    """
    Run the detector on a PIL image and plot the predictions two ways.

    Args:
        image: PIL image to predict on. It is drawn on in place and returned
            as the unfiltered output.
        conf_threshold: score threshold passed to the post-processing step.

    Returns:
        A 4-tuple of:
            image with every predicted box drawn,
            message for the unfiltered predictions,
            copy of the image with only the top scoring box per class drawn,
            message for the filtered predictions.
    """
    with torch.no_grad():
        inputs = image_processor(images=[image], return_tensors="pt")
        outputs = model(**inputs.to(device))

        target_sizes = torch.tensor([[image.size[1], image.size[0]]])  # height, width

        results = image_processor.post_process_object_detection(outputs,
                                                                threshold=conf_threshold,
                                                                target_sizes=target_sizes)[0]

    # Move every result tensor to the CPU before filtering/plotting
    results = {key: value.cpu() for key, value in results.items()}

    # Take a clean copy for the filtered plot BEFORE anything is drawn on the original
    image_nms = image.copy()

    # Get a font from ImageFont
    font = ImageFont.load_default(size=20)

    # Get original boxes, scores, labels
    original_boxes = results["boxes"]
    original_labels = results["labels"]
    original_scores = results["scores"]

    # Filter boxes and only keep 1x of each label with highest score
    filtered_boxes, filtered_labels, filtered_scores = filter_highest_scoring_box_per_class(boxes=original_boxes,
                                                                                            labels=original_labels,
                                                                                            scores=original_scores)

    # Plot all predictions on the original image and the filtered ones on the copy
    class_name_text_labels = _draw_detections(image, original_boxes, original_labels, original_scores, font)
    class_name_text_labels_nms = _draw_detections(image_nms, filtered_boxes, filtered_labels, filtered_scores, font)

    # Create the return strings (the threshold is passed along so the
    # "nothing detected" message can report it)
    return_string = create_return_string(list_of_predicted_labels=class_name_text_labels,
                                         conf_threshold=conf_threshold)
    return_string_nms = create_return_string(list_of_predicted_labels=class_name_text_labels_nms,
                                             conf_threshold=conf_threshold)

    return image, return_string, image_nms, return_string_nms
# Text shown under the title of the demo (rendered as markdown by Gradio)
demo_description = """Help clean up your local area! Upload an image and get +1 if there is all of the following items detected: trash, bin, hand.
The model in V3 is [same model](https://huggingface.co/mrdbourke/detr_finetuned_trashify_box_detector_with_data_aug) as in [V2](https://huggingface.co/spaces/mrdbourke/trashify_demo_v2) (trained with data augmentation) but has an additional post-processing step (NMS or [Non Maximum Suppression](https://paperswithcode.com/method/non-maximum-suppression)) to filter classes for only the highest scoring box of each class.
"""

# Each example is one list of values used to prefill the `inputs` components (image path, confidence threshold)
demo_examples = [
    ["examples/trashify_example_1.jpeg", 0.25],
    ["examples/trashify_example_2.jpeg", 0.25],
    ["examples/trashify_example_3.jpeg", 0.25],
]

# Input components, in the same order as the parameters of predict_on_image
demo_inputs = [
    gr.Image(type="pil", label="Target Image"),
    gr.Slider(minimum=0, maximum=1, value=0.25, label="Confidence Threshold"),
]

# Output components, in the same order as the values returned by predict_on_image
demo_outputs = [
    gr.Image(type="pil", label="Image Output (no filtering)"),
    gr.Text(label="Text Output (no filtering)"),
    gr.Image(type="pil", label="Image Output (with max score per class box filtering)"),
    gr.Text(label="Text Output (with max score per class box filtering)"),
]

# Assemble the interface
demo = gr.Interface(
    fn=predict_on_image,
    inputs=demo_inputs,
    outputs=demo_outputs,
    title="🚮 Trashify Object Detection Demo V3",
    description=demo_description,
    examples=demo_examples,
    cache_examples=True,
)

# Start serving the demo
demo.launch()