|
|
|
|
|
|
|
|
| |
| |
|
|
| import os |
| import cv2 |
| import torch |
| import numpy as np |
| from ultralytics import YOLO |
| import wandb |
| import matplotlib.pyplot as plt |
| from datetime import datetime |
| from google.colab import userdata |
| |
| |
# Authenticate with Weights & Biases using the API key stored in Colab secrets
# under the name "WANDB".
_wandb_api_key = userdata.get('WANDB')
wandb.login(key=_wandb_api_key)
|
|
def setup_wandb():
    """Open a new W&B run in the "Object-detection" project.

    The run name is ``run_`` followed by the local start time formatted as
    ``%Y%m%d_%H%M%S``. The logged config records model, dataset, image size and
    batch size; the latter two mirror the literals passed in ``train_model``.
    """
    started_at = datetime.now().strftime('%Y%m%d_%H%M%S')
    hyperparams = dict(
        model="yolov8n",
        dataset="coco128",
        img_size=640,
        batch_size=8,
    )
    wandb.init(
        project="Object-detection",
        name=f"run_{started_at}",
        config=hyperparams,
    )
|
|
def load_model():
    """Load the pretrained ``yolov8n.pt`` checkpoint on CUDA if available, else CPU.

    Returns:
        The Ultralytics ``YOLO`` model, moved to the selected device.
    """
    detector = YOLO("yolov8n.pt")
    detector.to('cuda' if torch.cuda.is_available() else 'cpu')
    return detector
|
|
def train_model(model):
    """Fine-tune ``model`` on ``coco128.yaml`` and return the same model object.

    Settings: 20 epochs, 640 px images, batch size 8, early-stopping patience of
    3 epochs, checkpoints saved. Image size and batch size match the values in
    ``setup_wandb``'s logged config -- keep the two in sync.

    Args:
        model: An Ultralytics ``YOLO`` model.

    Returns:
        The ``model`` instance that was passed in. NOTE(review): this relies on
        Ultralytics updating the in-memory model after training; confirm it holds
        the best checkpoint for the installed Ultralytics version.
    """
    use_gpu = torch.cuda.is_available()
    # The training summary returned by model.train() is not needed here; metrics
    # are collected separately by validate_model().
    model.train(
        data="coco128.yaml",
        epochs=20,
        imgsz=640,
        batch=8,
        device='0' if use_gpu else 'cpu',  # '0' selects the first GPU by index
        patience=3,
        save=True,
    )
    return model
|
|
def validate_model(model):
    """Run ``model.val()`` with default arguments and log box metrics to W&B.

    Logged keys: ``val/mAP50``, ``val/mAP50-95``, ``val/precision`` and
    ``val/recall``, taken from ``metrics.box.map50``, ``.map``, ``.mp`` and
    ``.mr`` respectively.

    Returns:
        The metrics object returned by ``model.val()``.
    """
    metrics = model.val()
    box_metrics = metrics.box
    wandb.log({
        "val/mAP50": box_metrics.map50,
        "val/mAP50-95": box_metrics.map,
        "val/precision": box_metrics.mp,
        "val/recall": box_metrics.mr,
    })
    return metrics
|
|
def visualize_results(results, img_path, output_path="detection_results.jpg"):
    """Save a side-by-side figure: original image (left), annotated prediction (right).

    Args:
        results: Inference results for ``img_path``; only ``results[0].plot()`` is used.
        img_path: Path of the original image, read with OpenCV.
        output_path: Destination of the saved figure. Defaults to
            ``"detection_results.jpg"`` (the previously hard-coded name).

    Returns:
        ``output_path``.

    Raises:
        ValueError: If OpenCV cannot read ``img_path``.
    """
    img = cv2.imread(img_path)
    if img is None:
        raise ValueError(f"Failed to load image: {img_path}")
    pred_img = results[0].plot()
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
    try:
        # Both arrays are treated as BGR (OpenCV order); matplotlib expects RGB.
        ax1.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        ax1.axis('off')
        ax2.imshow(cv2.cvtColor(pred_img, cv2.COLOR_BGR2RGB))
        ax2.axis('off')
        # Save/close this specific figure instead of pyplot's implicit "current" one.
        fig.savefig(output_path)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    return output_path
|
|
def test_image(model, img_path="test_image.jpg"):
    """Run inference on one image and log the visualization and detections to W&B.

    Logs the side-by-side figure produced by ``visualize_results`` under
    ``test_results``, plus ``boxes.cls`` and ``boxes.conf`` of the first result
    as plain lists under ``detections`` and ``confidences``.

    Returns:
        The raw results returned by ``model(img_path)``.

    Raises:
        FileNotFoundError: If ``img_path`` does not exist.
    """
    if not os.path.exists(img_path):
        raise FileNotFoundError(f"Image not found: {img_path}")
    detections = model(img_path)
    figure_path = visualize_results(detections, img_path)
    boxes = detections[0].boxes
    wandb.log({
        "test_results": wandb.Image(figure_path),
        "detections": boxes.cls.tolist(),
        "confidences": boxes.conf.tolist(),
    })
    return detections
|
|
def webcam_demo(model):
    """Run live inference on webcam device 0 and display each annotated frame.

    Frames are displayed with ``google.colab.patches.cv2_imshow``. The loop ends
    when a frame cannot be read or ``cv2.waitKey`` reports 'q'. Any ``Exception``
    raised inside the demo (including the Colab import failing outside Colab) is
    printed rather than propagated. The capture device is always released.

    Args:
        model: A callable YOLO model; ``model(frame)[0].plot()`` must return an image.
    """
    # Bind before the try so the finally clause cannot hit NameError when the
    # Colab import (or VideoCapture itself) fails before `cap` is assigned.
    cap = None
    try:
        from google.colab.patches import cv2_imshow
        cap = cv2.VideoCapture(0)
        if not cap.isOpened():
            print("Webcam not available - skipping demo")
            return
        print("Press 'q' to quit webcam demo")
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            results = model(frame)
            annotated = results[0].plot()
            cv2_imshow(annotated)
            # NOTE(review): cv2_imshow renders inline rather than in an OpenCV
            # window, so waitKey may never see a key press; the loop may only end
            # on a failed read or an interrupt -- confirm in the target runtime.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    except Exception as e:
        print(f"Webcam error: {e}")
    finally:
        if cap is not None:
            cap.release()
        cv2.destroyAllWindows()
|
|
def export_model(weights=None):
    """Export trained YOLO weights to TorchScript and upload the export to W&B.

    Args:
        weights: Path to a ``.pt`` checkpoint. When omitted, the most recently
            modified ``runs/detect/train*/weights/best.pt`` is used. Ultralytics
            writes repeated training runs to ``train``, ``train2``, ..., so a
            hard-coded ``runs/detect/train`` can point at a stale run.

    Returns:
        The path of the exported file, as returned by ``model.export``.

    Raises:
        FileNotFoundError: If ``weights`` is omitted and no ``best.pt`` is found.
    """
    from pathlib import Path

    if weights is None:
        candidates = list(Path("runs/detect").glob("train*/weights/best.pt"))
        if not candidates:
            raise FileNotFoundError(
                "No trained weights found under runs/detect/train*/weights/best.pt"
            )
        weights = max(candidates, key=lambda p: p.stat().st_mtime)
    model = YOLO(str(weights))
    exported_path = model.export(format="torchscript")
    # Upload the file export() actually wrote (next to the weights). A bare
    # "best.torchscript" is resolved against the CWD and matches nothing.
    wandb.save(str(exported_path))
    return exported_path
|
|
def main():
    """Run the pipeline: W&B setup, load, train, validate, sample test, export.

    The W&B run is always closed. If any step raises (including KeyboardInterrupt),
    the run is finished with a non-zero exit code so it is marked failed, and the
    error is re-raised. A missing sample test image is reported and skipped, so
    an expensive training run still reaches the export step.
    """
    setup_wandb()
    try:
        model = load_model()
        model = train_model(model)
        validate_model(model)
        try:
            test_image(model)
        except FileNotFoundError as err:
            # test_image raises this when its sample image is absent.
            print(f"{err} - skipping test image")
        export_model()
    except BaseException:
        wandb.finish(exit_code=1)
        raise
    else:
        wandb.finish()
|
|
if __name__ == "__main__":
    main()

    # Colab-only: offer the TorchScript export as a browser download. Kept under
    # the guard so that merely importing this module does not start a download.
    from pathlib import Path

    from google.colab import files

    # Ultralytics writes repeated runs to train, train2, ...; pick the newest
    # export rather than hard-coding runs/detect/train.
    exports = list(Path("runs/detect").glob("train*/weights/best.torchscript"))
    if not exports:
        raise FileNotFoundError(
            "No TorchScript export found under runs/detect/train*/weights/"
        )
    files.download(str(max(exports, key=lambda p: p.stat().st_mtime)))
|
|
|
|