# BirdScopeAI / modal_bird_classifier.py
# Author: facemelter
# Commit ff0e97f: "Initial commit to hf space for hackathon"
"""
Bird Species Classifier MCP Server on Modal
Updated to support:
1. classify_from_base64() - for IDE/Cursor clients and Gradio
2. classify_from_url() - for fallback/public images
"""
import modal
from fastmcp import FastMCP
import base64
import json
import httpx
from io import BytesIO
from PIL import Image
import torch
import os
# ============================================================================
# MODAL APP CONFIGURATION
# ============================================================================
app = modal.App("bird-classifier-mcp")
image = modal.Image.debian_slim(python_version="3.12").pip_install(
"transformers==4.46.0",
"torch==2.5.1",
"pillow==10.4.0",
"fastmcp>=2.13.0",
"pydantic>=2.10.0,<3.0.0",
"fastapi==0.115.14",
"httpx>=0.28.0",
)
API_KEY_SECRET = modal.Secret.from_name("bird-classifier-api-key")
# ============================================================================
# MCP SERVER DEFINITION
# ============================================================================
def make_mcp_server():
    """Create the FastMCP server and register the bird-classification tools.

    Loads the ``prithivMLmods/Bird-Species-Classifier-526`` image-classification
    pipeline once, on device 0. It then registers two tools that share the
    pipeline: ``classify_from_base64`` and ``classify_from_url``.

    Both tools return a JSON string. On success it is
    ``{"species", "confidence", "source"}``. On failure it is
    ``{"error", "species": None, "confidence": 0.0}``.

    Returns:
        The configured ``FastMCP`` instance.
    """
    from transformers import pipeline

    mcp = FastMCP("Bird Species Classifier")
    print("🔄 Loading bird classifier model...")
    classifier = pipeline(
        "image-classification",
        model="prithivMLmods/Bird-Species-Classifier-526",
        device=0,  # first accelerator device; `web` requests a GPU from Modal
    )
    print("✅ Model loaded!")

    def preprocess_image(image: Image.Image, max_size: int = 800) -> Image.Image:
        """Convert *image* to RGB and shrink it so its longest side is <= max_size.

        Aspect ratio is preserved. Each side is clamped to at least 1 px so very
        elongated images never yield a zero-sized resize target, which PIL
        rejects.
        """
        if image.mode != "RGB":
            image = image.convert("RGB")
        longest_side = max(image.size)
        if longest_side > max_size:
            ratio = max_size / longest_side
            width, height = image.size
            new_size = (max(1, int(width * ratio)), max(1, int(height * ratio)))
            image = image.resize(new_size, Image.Resampling.LANCZOS)
        return image

    def classify_image(image: Image.Image, source: str) -> str:
        """Preprocess *image*, run top-1 classification, return the JSON result."""
        image = preprocess_image(image)
        print("[STATUS]: Classifying image...")
        top_result = classifier(image, top_k=1)[0]
        return json.dumps({
            "species": top_result["label"],
            "confidence": round(top_result["score"], 4),
            "source": source,
        })

    def error_response(exc: Exception) -> str:
        """Serialize *exc* in the tools' shared error shape."""
        return json.dumps({
            "error": str(exc),
            "species": None,
            "confidence": 0.0,
        })

    # ========================================================================
    # TOOL 1: classify_from_base64 (PRIMARY - for IDE/Gradio)
    # ========================================================================
    @mcp.tool()
    async def classify_from_base64(image_data: str) -> str:
        """
        Classify a bird species from base64-encoded image data.
        This is the primary tool for IDE clients and Gradio apps.
        Accepts raw base64 or data URL format.
        Args:
            image_data: Base64-encoded image string (PNG/JPG)
                Can be raw base64 or "data:image/png;base64,..."
        Returns:
            JSON string with species name and confidence score
            Format: {"species": "Common Name", "confidence": 0.95, "source": "base64"}
        """
        try:
            # Strip the "data:<mime>;base64," prefix of a data URL.
            if image_data.startswith("data:"):
                _, separator, image_data = image_data.partition(",")
                if not separator:
                    raise ValueError("Malformed data URL: missing ',' separator")
            print(f"[STATUS]: Decoding base64 image ({len(image_data)} chars)...")
            image_bytes = base64.b64decode(image_data)
            return classify_image(Image.open(BytesIO(image_bytes)), "base64")
        except Exception as e:
            # Tool boundary: report failures to the MCP client as JSON.
            return error_response(e)

    # ========================================================================
    # TOOL 2: classify_from_url (FALLBACK)
    # ========================================================================
    @mcp.tool()
    async def classify_from_url(image_url: str) -> str:
        """
        Download image from URL and classify bird species.
        Fallback tool for clients that have URL access.
        Args:
            image_url: URL to the image (https://example.com/bird.jpg)
        Returns:
            JSON string with species name and confidence score
            Format: {"species": "Common Name", "confidence": 0.95, "source": "url"}
        """
        try:
            print("[STATUS]: Downloading from URL...")
            # Async client so a slow download does not block the event loop
            # that serves other MCP requests.
            async with httpx.AsyncClient(follow_redirects=True, timeout=15) as http:
                response = await http.get(image_url)
                response.raise_for_status()
            return classify_image(Image.open(BytesIO(response.content)), "url")
        except Exception as e:
            # Tool boundary: report failures to the MCP client as JSON.
            return error_response(e)

    return mcp
# ============================================================================
# WEB ENDPOINT WITH AUTHENTICATION
# ============================================================================
@app.function(
    image=image,
    # gpu="L40S",  # alternative, larger GPU option
    gpu="T4",
    secrets=[API_KEY_SECRET],
    timeout=300,
    min_containers=0,
    max_containers=5,
    scaledown_window=60,
)
@modal.asgi_app()
def web():
    """Serve the MCP server over streamable HTTP behind an API-key check.

    Every request must carry an ``X-API-Key`` header equal to the ``API_KEY``
    environment variable. Otherwise a 401 JSON response is returned. If
    ``API_KEY`` is unset or empty, all requests are rejected (fail closed).

    Returns:
        The FastAPI ASGI application with the MCP app mounted at ``/``.
    """
    import hmac

    from fastapi import FastAPI, Request
    from fastapi.responses import JSONResponse

    print("[STATUS]: Starting MCP server...")
    mcp = make_mcp_server()
    mcp_app = mcp.http_app(transport="streamable-http", stateless_http=True)
    fastapi_app = FastAPI(
        title="Bird Classifier MCP Server",
        description="MCP server for bird species classification",
        # The MCP app's lifespan must run for its session manager to start.
        lifespan=mcp_app.lifespan,
    )

    @fastapi_app.middleware("http")
    async def verify_api_key(request: Request, call_next):
        """Reject the request with 401 unless X-API-Key matches API_KEY."""
        api_key = request.headers.get("X-API-Key")
        expected_key = os.environ.get("API_KEY")
        # Constant-time comparison on bytes: avoids a timing side channel and
        # the TypeError compare_digest raises for non-ASCII str input.
        authorized = bool(api_key) and bool(expected_key) and hmac.compare_digest(
            api_key.encode("utf-8"), expected_key.encode("utf-8")
        )
        if not authorized:
            return JSONResponse(
                status_code=401,
                content={"error": "Invalid or missing API key"},
            )
        return await call_next(request)

    fastapi_app.mount("/", mcp_app)
    print("[STATUS]: MCP server is ready!")
    return fastapi_app
# ============================================================================
# TEST FUNCTION
# ============================================================================
@app.function(image=image, secrets=[API_KEY_SECRET])
async def test_classifier():
    """Smoke-test the deployed MCP server: list its tools and classify one URL.

    Connects to the ``web`` endpoint's ``/mcp/`` path, sending the ``API_KEY``
    environment variable as the ``X-API-Key`` header. Failures are printed,
    not raised, including error payloads returned by the tool itself.
    """
    from fastmcp import Client
    from fastmcp.client.transports import StreamableHttpTransport

    rule = "=" * 70
    print("\n" + rule)
    print("[STATUS]: Testing Bird Classifier MCP server...")
    print(rule + "\n")

    api_key = os.environ.get("API_KEY")
    if not api_key:
        # Without a key every request would be rejected with 401 anyway.
        print("[ERROR]: API_KEY environment variable is not set")
        return

    server_url = f"{web.get_web_url()}/mcp/"
    transport = StreamableHttpTransport(
        url=server_url,
        headers={"X-API-Key": api_key},
    )
    client = Client(transport)

    try:
        async with client:
            print("\nAvailable Tools:")
            for tool in await client.list_tools():
                print(f" - {tool.name}")

            print("\n" + rule)
            print("[TEST 1]: classify_from_url")
            print(rule)
            test_url = "https://images.unsplash.com/photo-1444464666168-49d633b86797?w=400"
            result = await client.call_tool(
                "classify_from_url",
                arguments={"image_url": test_url},
            )
            if result.content:
                data = json.loads(result.content[0].text)
                # The tools report failures inside the JSON payload.
                if data.get("error"):
                    print(f"[RESULT]: tool error: {data['error']}")
                else:
                    print(f"[RESULT]: {data.get('species')} ({data.get('confidence'):.1%})")
    except Exception as e:
        print(f"[ERROR]: {e}")

    print("\n" + rule)
    print("[STATUS]: Test complete!")
    print(rule + "\n")