|
|
|
|
|
""" |
|
|
Real-time Strawberry Detection and Ripeness Classification using Webcam |
|
|
Optimized for WSL (Windows Subsystem for Linux) environments |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import cv2 |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import argparse |
|
|
import time |
|
|
from pathlib import Path |
|
|
import sys |
|
|
import warnings |
|
|
|
|
|
|
|
|
warnings.filterwarnings('ignore') |
|
|
|
|
|
class StrawberryPickerWebcam:
    """Two-stage strawberry detector and ripeness classifier for video frames.

    A YOLO detector proposes bounding boxes; every sufficiently confident box
    is cropped and handed to a separate classifier that assigns one of
    ``ripeness_classes``. Results can be drawn on the frame and either shown
    live (``run_webcam``) or written to a new video file (``run_video_file``).
    """

    # Minimum detector confidence for a box to be passed on to the classifier.
    DETECTION_THRESHOLD = 0.5

    # Class-level default so ``visualize`` can always read an FPS value, even
    # when it is called before ``run_webcam`` / ``run_video_file`` set one.
    fps = 0.0

    def __init__(self, detector_path, classifier_path, device='cpu'):
        """
        Initialize the strawberry picker system

        Loads both models and builds the classifier preprocessing pipeline.
        If either model cannot be loaded, an error is printed and the process
        exits with status 1.

        Args:
            detector_path: Path to YOLOv8 detection model
            classifier_path: Path to EfficientNet classification model
            device: Device to run inference on ('cpu' or 'cuda')
        """
        print("🍓 Initializing Strawberry Picker AI System...")

        self.device = device
        self.fps = 0.0
        # Order must match the class indices produced by the classifier.
        self.ripeness_classes = ['unripe', 'partially-ripe', 'ripe', 'overripe']

        # Drawing colour per ripeness class, in OpenCV's BGR channel order.
        self.colors = {
            'unripe': (0, 255, 0),
            'partially-ripe': (0, 255, 255),
            'ripe': (0, 0, 255),
            'overripe': (128, 0, 128)
        }

        print("Loading detection model...")
        try:
            # Imported lazily so a missing package is reported through the
            # same error path as a bad model file.
            from ultralytics import YOLO
            self.detector = YOLO(detector_path)
            print("✅ Detection model loaded successfully")
        except Exception as e:
            print(f"❌ Error loading detection model: {e}")
            sys.exit(1)

        print("Loading classification model...")
        try:
            # NOTE(security): torch.load unpickles the file, which can execute
            # arbitrary code. Only load checkpoints from a trusted source.
            # NOTE(review): recent PyTorch releases default to
            # weights_only=True, which rejects fully pickled models - confirm
            # against the installed version.
            self.classifier = torch.load(classifier_path, map_location=device)
            if not isinstance(self.classifier, torch.nn.Module):
                # A bare state_dict has no .eval() and cannot be called.
                raise TypeError(
                    "checkpoint does not contain a full model object "
                    f"(got {type(self.classifier).__name__})"
                )
            self.classifier.eval()
            print("✅ Classification model loaded successfully")
        except Exception as e:
            print(f"❌ Error loading classification model: {e}")
            sys.exit(1)

        # Imported here (not relied on as a module global) so the class also
        # works when this file is imported rather than run as a script.
        from torchvision import transforms

        # Preprocessing applied to every crop before classification.
        self.transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        print("✅ System initialized and ready!")

    @staticmethod
    def _clip_box(box, frame_width, frame_height):
        """Truncate a box to integers and clamp it to the frame bounds.

        Args:
            box: Iterable of four numbers (x1, y1, x2, y2)
            frame_width: Frame width in pixels
            frame_height: Frame height in pixels

        Returns:
            Tuple (x1, y1, x2, y2) of ints lying inside the frame
        """
        x1, y1, x2, y2 = (int(v) for v in box)
        return (max(0, x1), max(0, y1),
                min(frame_width, x2), min(frame_height, y2))

    def detect_and_classify(self, frame):
        """
        Detect strawberries and classify their ripeness in a frame

        Args:
            frame: Input frame (BGR format)

        Returns:
            results: List of dicts, one per classified detection, with keys
                'bbox' (x1, y1, x2, y2), 'ripeness', 'confidence'
                (classifier probability) and 'detection_confidence'
        """
        frame_height, frame_width = frame.shape[:2]

        # The detector receives the untouched BGR frame: Ultralytics treats
        # numpy input as BGR, so converting first would swap red and blue.
        detection_results = self.detector(frame)

        # The classifier path goes through PIL, which expects RGB.
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        results = []
        for result in detection_results:
            boxes = result.boxes.xyxy.cpu().numpy()
            confidences = result.boxes.conf.cpu().numpy()

            for box, conf in zip(boxes, confidences):
                if conf < self.DETECTION_THRESHOLD:
                    continue

                x1, y1, x2, y2 = self._clip_box(box, frame_width, frame_height)
                crop = frame_rgb[y1:y2, x1:x2]
                if crop.size == 0:
                    # Box collapsed after clamping; nothing to classify.
                    continue

                # Best effort: one bad crop must not abort the whole frame.
                try:
                    crop_pil = Image.fromarray(crop)
                    input_tensor = self.transform(crop_pil).unsqueeze(0).to(self.device)

                    with torch.no_grad():
                        output = self.classifier(input_tensor)
                        probabilities = torch.softmax(output, dim=1)
                        predicted_class = torch.argmax(probabilities, dim=1).item()
                        confidence = probabilities[0][predicted_class].item()

                    results.append({
                        'bbox': (x1, y1, x2, y2),
                        'ripeness': self.ripeness_classes[predicted_class],
                        'confidence': confidence,
                        'detection_confidence': float(conf)
                    })
                except Exception as e:
                    print(f"Warning: Error classifying crop: {e}")
                    continue

        return results

    def visualize(self, frame, results):
        """
        Draw bounding boxes and labels on frame

        Args:
            frame: Input frame
            results: Detection/classification results

        Returns:
            visualized_frame: Copy of the frame with drawings; the input
                frame is not modified
        """
        vis_frame = frame.copy()
        font = cv2.FONT_HERSHEY_SIMPLEX

        for result in results:
            x1, y1, x2, y2 = result['bbox']
            ripeness = result['ripeness']
            conf = result['confidence']

            color = self.colors[ripeness]
            cv2.rectangle(vis_frame, (x1, y1), (x2, y2), color, 2)

            label = f"{ripeness} ({conf:.2f})"
            text_w, text_h = cv2.getTextSize(label, font, 0.6, 2)[0]

            # Label sits above the box; when that would leave the frame
            # (box touching the top edge) it is moved just inside the box.
            if y1 - text_h - 10 >= 0:
                label_bottom = y1
            else:
                label_bottom = y1 + text_h + 10

            cv2.rectangle(vis_frame, (x1, label_bottom - text_h - 10),
                          (x1 + text_w, label_bottom), color, -1)
            cv2.putText(vis_frame, label, (x1, label_bottom - 5),
                        font, 0.6, (255, 255, 255), 2)

        fps_text = f"FPS: {self.fps:.1f}"
        cv2.putText(vis_frame, fps_text, (10, 30), font, 1, (0, 255, 0), 2)

        title = "Strawberry Picker AI - Press 'q' to quit"
        cv2.putText(vis_frame, title, (10, 60), font, 0.7, (0, 255, 0), 2)

        return vis_frame

    def run_webcam(self, camera_index=0, width=640, height=480):
        """
        Run real-time inference on webcam

        Shows annotated frames in a window until 'q' is pressed, the camera
        stops delivering frames, or Ctrl+C is hit. Pressing 's' saves the
        current annotated frame as ``screenshot_<n>.jpg`` in the working
        directory.

        Args:
            camera_index: Camera index (0 for default webcam)
            width: Frame width
            height: Frame height
        """
        print(f"\n📹 Starting webcam (camera {camera_index})...")
        print("Press 'q' to quit, 's' to save screenshot")
        print("Make sure strawberries are well-lit and clearly visible\n")

        cap = cv2.VideoCapture(camera_index)

        if not cap.isOpened():
            print(f"❌ Error: Could not open camera {camera_index}")
            print("\nTroubleshooting tips for WSL:")
            print("1. Install v4l2loopback: sudo apt-get install v4l2loopback-dkms")
            print("2. Load module: sudo modprobe v4l2loopback")
            print("3. Use IP webcam app on phone as alternative")
            print("4. Or use pre-recorded video file")
            cap.release()
            return

        self.fps = 0.0
        frame_count = 0
        screenshot_count = 0
        start_time = time.time()

        try:
            # Inside the try so the camera is released even if this fails.
            cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
            cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)

            while True:
                ret, frame = cap.read()
                if not ret:
                    print("❌ Error: Could not read frame from camera")
                    break

                results = self.detect_and_classify(frame)
                vis_frame = self.visualize(frame, results)

                # FPS is the running average since start, refreshed every
                # 10 frames.
                frame_count += 1
                if frame_count % 10 == 0:
                    elapsed = time.time() - start_time
                    if elapsed > 0:
                        self.fps = frame_count / elapsed

                cv2.imshow('Strawberry Picker AI', vis_frame)

                key = cv2.waitKey(1) & 0xFF
                if key == ord('q'):
                    print("\n👋 Quitting...")
                    break
                elif key == ord('s'):
                    screenshot_path = f"screenshot_{screenshot_count}.jpg"
                    if cv2.imwrite(screenshot_path, vis_frame):
                        print(f"📸 Screenshot saved: {screenshot_path}")
                        screenshot_count += 1
                    else:
                        print(f"❌ Error: Could not save screenshot: {screenshot_path}")

        except KeyboardInterrupt:
            print("\n🛑 Interrupted by user")

        finally:
            cap.release()
            cv2.destroyAllWindows()
            print("✅ Webcam session ended")

    def run_video_file(self, video_path):
        """
        Run inference on a video file

        Every frame is annotated and written to ``output_<input file name>``
        in the working directory, encoded with the 'mp4v' fourcc.

        Args:
            video_path: Path to video file
        """
        print(f"\n🎬 Processing video: {video_path}")

        cap = cv2.VideoCapture(video_path)

        if not cap.isOpened():
            print(f"❌ Error: Could not open video file: {video_path}")
            cap.release()
            return

        # Keep the fractional rate (e.g. 29.97) so the output does not drift.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # Missing or invalid rate in the container; 30 is an arbitrary
            # fallback so the writer still gets a usable value.
            fps = 30.0
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        print(f"Video info: {width}x{height}, {fps:g} FPS, {total_frames} frames")

        output_path = f"output_{Path(video_path).name}"
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        if not out.isOpened():
            print(f"❌ Error: Could not open output file for writing: {output_path}")
            cap.release()
            out.release()
            return

        # visualize() overlays self.fps; here it reports processing speed.
        self.fps = 0.0
        frame_count = 0
        start_time = time.time()

        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                results = self.detect_and_classify(frame)
                vis_frame = self.visualize(frame, results)
                out.write(vis_frame)

                frame_count += 1
                if frame_count % 30 == 0:
                    elapsed = time.time() - start_time
                    if elapsed > 0:
                        self.fps = frame_count / elapsed

                    if total_frames > 0:
                        progress = (frame_count / total_frames) * 100
                        print(f"Progress: {progress:.1f}% ({frame_count}/{total_frames}) - "
                              f"Time: {elapsed:.1f}s")
                    else:
                        # Frame count unknown, so no percentage can be shown.
                        print(f"Progress: {frame_count} frames - Time: {elapsed:.1f}s")

        except KeyboardInterrupt:
            print("\n🛑 Interrupted by user")

        finally:
            # No window is opened by this method, so only the capture and
            # the writer need to be released.
            cap.release()
            out.release()
            print(f"✅ Video processing complete. Output saved to: {output_path}")
|
|
|
|
|
def main(argv=None):
    """Parse command-line arguments and run the selected mode.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads the arguments from ``sys.argv``.

    Exits with status 1 when video mode is requested without ``--input`` or
    when the system fails to initialize.
    """
    parser = argparse.ArgumentParser(
        description='Real-time Strawberry Detection and Ripeness Classification'
    )

    parser.add_argument(
        '--detector',
        type=str,
        default='detection_model/best.pt',
        help='Path to YOLOv8 detection model'
    )

    parser.add_argument(
        '--classifier',
        type=str,
        default='classification_model/best_enhanced_classifier.pth',
        help='Path to EfficientNet classification model'
    )

    parser.add_argument(
        '--mode',
        type=str,
        choices=['webcam', 'video'],
        default='webcam',
        help='Mode: webcam or video file'
    )

    parser.add_argument(
        '--input',
        type=str,
        help='Path to video file (if mode=video)'
    )

    parser.add_argument(
        '--camera',
        type=int,
        default=0,
        help='Camera index (default: 0)'
    )

    parser.add_argument(
        '--width',
        type=int,
        default=640,
        help='Camera frame width'
    )

    parser.add_argument(
        '--height',
        type=int,
        default=480,
        help='Camera frame height'
    )

    parser.add_argument(
        '--device',
        type=str,
        default='auto',
        choices=['auto', 'cpu', 'cuda'],
        help='Device to use for inference'
    )

    args = parser.parse_args(argv)

    # Validate the arguments before loading any model, so a missing --input
    # is reported immediately instead of after both models have been loaded.
    if args.mode == 'video' and not args.input:
        print("❌ Error: --input required for video mode")
        sys.exit(1)

    # 'auto' picks the GPU when one is visible to PyTorch.
    if args.device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = args.device

    print(f"Using device: {device}")

    if device == 'cpu':
        print("⚠️ Running on CPU - this will be slower. Consider using GPU if available.")

    try:
        picker = StrawberryPickerWebcam(
            detector_path=args.detector,
            classifier_path=args.classifier,
            device=device
        )
    except Exception as e:
        print(f"❌ Failed to initialize system: {e}")
        sys.exit(1)

    if args.mode == 'webcam':
        picker.run_webcam(
            camera_index=args.camera,
            width=args.width,
            height=args.height
        )
    elif args.mode == 'video':
        picker.run_video_file(args.input)
|
|
|
|
|
if __name__ == "__main__":
    # Fail early with an actionable message when a runtime dependency is
    # missing. The torchvision import also binds the module-level name
    # `transforms` when the file is run as a script.
    try:
        import torch
        import cv2
        from PIL import Image
        from torchvision import transforms
    except ImportError as e:
        print(f"❌ Missing required library: {e}")
        # The hint lists every third-party package this file imports,
        # including numpy and ultralytics.
        print("Install with: pip install torch torchvision opencv-python pillow numpy ultralytics")
        sys.exit(1)

    main()