"""
Bird Species Classifier MCP Server on Modal

Updated to support:
1. classify_from_base64() - for IDE/Cursor clients and Gradio
2. classify_from_url()    - for fallback/public images
"""

import base64
import json
import os
import secrets
from io import BytesIO

import httpx
import modal
import torch
from fastmcp import FastMCP
from PIL import Image

# ============================================================================
# MODAL APP CONFIGURATION
# ============================================================================

app = modal.App("bird-classifier-mcp")

image = modal.Image.debian_slim(python_version="3.12").pip_install(
    "transformers==4.46.0",
    "torch==2.5.1",
    "pillow==10.4.0",
    "fastmcp>=2.13.0",
    "pydantic>=2.10.0,<3.0.0",
    "fastapi==0.115.14",
    "httpx>=0.28.0",
)

API_KEY_SECRET = modal.Secret.from_name("bird-classifier-api-key")

# Largest response body classify_from_url will accept. The download is
# streamed and aborted past this limit so a huge URL cannot exhaust memory.
MAX_DOWNLOAD_BYTES = 20 * 1024 * 1024

# ============================================================================
# MCP SERVER DEFINITION
# ============================================================================


def make_mcp_server():
    """Create the FastMCP server and register the bird classification tools.

    The Hugging Face image-classification pipeline is loaded once, on CUDA
    device 0, and captured by the tool closures. Every request served by this
    server instance therefore reuses the same model.

    Returns:
        A configured ``FastMCP`` instance exposing ``classify_from_base64`` and
        ``classify_from_url``.
    """
    from transformers import pipeline

    mcp = FastMCP("Bird Species Classifier")

    print("🔄 Loading bird classifier model...")
    classifier = pipeline(
        "image-classification",
        model="prithivMLmods/Bird-Species-Classifier-526",
        device=0,
    )
    print("✅ Model loaded!")

    def preprocess_image(img: Image.Image, max_size: int = 800) -> Image.Image:
        """Convert to RGB and downscale so the longest side is at most max_size."""
        if img.mode != "RGB":
            img = img.convert("RGB")
        longest = max(img.size)
        if longest > max_size:
            ratio = max_size / longest
            # Clamp to 1px: very thin images would otherwise round to 0 and
            # make resize() raise.
            new_size = (
                max(1, int(img.size[0] * ratio)),
                max(1, int(img.size[1] * ratio)),
            )
            img = img.resize(new_size, Image.Resampling.LANCZOS)
        return img

    def classify_image_bytes(image_bytes: bytes, source: str) -> str:
        """Decode encoded image bytes, classify them, and return the JSON result.

        Raises whatever PIL or the pipeline raises; the calling tool turns that
        into an error payload.
        """
        img = preprocess_image(Image.open(BytesIO(image_bytes)))
        # NOTE: inference is a synchronous call, so it occupies the event loop
        # while it runs.
        top_result = classifier(img, top_k=1)[0]
        return json.dumps({
            "species": top_result["label"],
            "confidence": round(top_result["score"], 4),
            "source": source,
        })

    def error_result(exc: Exception) -> str:
        """Log exc and return the error payload shared by both tools."""
        print(f"[ERROR]: {exc!r}")
        return json.dumps({
            "error": str(exc),
            "species": None,
            "confidence": 0.0,
        })

    async def download_image(url: str, max_bytes: int = MAX_DOWNLOAD_BYTES) -> bytes:
        """Fetch url without blocking the event loop, capped at max_bytes.

        Raises:
            httpx.HTTPError: on transport errors or a non-2xx status.
            ValueError: if the body grows beyond max_bytes.
        """
        chunks = []
        total = 0
        async with httpx.AsyncClient(follow_redirects=True, timeout=15) as client:
            async with client.stream("GET", url) as response:
                response.raise_for_status()
                async for chunk in response.aiter_bytes():
                    total += len(chunk)
                    if total > max_bytes:
                        raise ValueError(
                            f"Image at URL exceeds the {max_bytes}-byte download limit"
                        )
                    chunks.append(chunk)
        return b"".join(chunks)

    # ========================================================================
    # TOOL 1: classify_from_base64 (PRIMARY - for IDE/Gradio)
    # ========================================================================
    @mcp.tool()
    async def classify_from_base64(image_data: str) -> str:
        """
        Classify a bird species from base64-encoded image data.

        This is the primary tool for IDE clients and Gradio apps.
        Accepts raw base64 or data URL format.

        Args:
            image_data: Base64-encoded image string (PNG/JPG).
                Can be raw base64 or "data:image/png;base64,..."

        Returns:
            JSON string with species name and confidence score.
            Format: {"species": "Common Name", "confidence": 0.95, "source": "base64"}
            On failure: {"error": "...", "species": null, "confidence": 0.0}
        """
        try:
            # Handle data URL format: keep only the payload after the comma.
            if image_data.startswith("data:"):
                _, sep, payload = image_data.partition(",")
                if not sep:
                    raise ValueError("Malformed data URL: missing ',' before payload")
                image_data = payload

            print(f"[STATUS]: Decoding base64 image ({len(image_data)} chars)...")
            image_bytes = base64.b64decode(image_data)

            print("[STATUS]: Classifying image...")
            return classify_image_bytes(image_bytes, "base64")
        except Exception as e:  # tool boundary: report failures as JSON
            return error_result(e)

    # ========================================================================
    # TOOL 2: classify_from_url (FALLBACK)
    # ========================================================================
    @mcp.tool()
    async def classify_from_url(image_url: str) -> str:
        """
        Download image from URL and classify bird species.

        Fallback tool for clients that have URL access.

        Args:
            image_url: URL to the image (https://example.com/bird.jpg)

        Returns:
            JSON string with species name and confidence score.
            Format: {"species": "Common Name", "confidence": 0.95, "source": "url"}
            On failure: {"error": "...", "species": null, "confidence": 0.0}
        """
        try:
            print("[STATUS]: Downloading from URL...")
            image_bytes = await download_image(image_url)
            return classify_image_bytes(image_bytes, "url")
        except Exception as e:  # tool boundary: report failures as JSON
            return error_result(e)

    return mcp


# ============================================================================
# WEB ENDPOINT WITH AUTHENTICATION
# ============================================================================


@app.function(
    image=image,
    # gpu="L40S",
    gpu="T4",
    secrets=[API_KEY_SECRET],
    timeout=300,
    min_containers=0,
    max_containers=5,
    scaledown_window=60,
)
@modal.asgi_app()
def web():
    """ASGI web endpoint for the MCP server, guarded by an X-API-Key header.

    The MCP streamable-HTTP app is mounted at the root, so its endpoint is
    served at ``/mcp/``. The expected key is read from the ``API_KEY``
    environment variable supplied by the Modal secret.
    """
    from fastapi import FastAPI, Request
    from fastapi.responses import JSONResponse

    print("[STATUS]: Starting MCP server...")
    mcp = make_mcp_server()
    mcp_app = mcp.http_app(transport="streamable-http", stateless_http=True)

    fastapi_app = FastAPI(
        title="Bird Classifier MCP Server",
        description="MCP server for bird species classification",
        # The MCP app's lifespan must be run by the outer app.
        lifespan=mcp_app.lifespan,
    )

    expected_key = os.environ.get("API_KEY")
    if not expected_key:
        print("[WARNING]: API_KEY is not set; every request will be rejected.")

    @fastapi_app.middleware("http")
    async def verify_api_key(request: Request, call_next):
        """Reject any request whose X-API-Key header does not match API_KEY."""
        api_key = request.headers.get("X-API-Key")
        # Fail closed if the secret is missing. Compare in constant time so
        # response timing cannot leak how much of the key matched.
        if (
            not expected_key
            or not api_key
            or not secrets.compare_digest(api_key.encode(), expected_key.encode())
        ):
            return JSONResponse(
                status_code=401,
                content={"error": "Invalid or missing API key"},
            )
        return await call_next(request)

    fastapi_app.mount("/", mcp_app)

    print("[STATUS]: MCP server is ready!")
    return fastapi_app
# ============================================================================
# TEST FUNCTION
# ============================================================================


@app.function(image=image, secrets=[API_KEY_SECRET])
async def test_classifier():
    """Smoke-test the deployed MCP server end to end.

    Connects to the ``web`` endpoint's ``/mcp/`` path over streamable HTTP
    using the ``API_KEY`` from the Modal secret. It then lists the registered
    tools and calls both ``classify_from_url`` and ``classify_from_base64``
    on the same sample image. Each result is printed, including any error
    payload a tool returns. Failures are printed rather than raised.
    """
    from fastmcp import Client
    from fastmcp.client.transports import StreamableHttpTransport

    def banner(title: str) -> None:
        print("\n" + "=" * 70)
        print(title)
        print("=" * 70)

    def print_result(result) -> None:
        """Print species/confidence from a tool's JSON reply, or its error."""
        if not result.content:
            print("[RESULT]: (empty response)")
            return
        data = json.loads(result.content[0].text)
        if data.get("error"):
            # The tools report failures in-band; surface the message.
            print(f"[RESULT]: tool error: {data['error']}")
        else:
            print(f"[RESULT]: {data.get('species')} ({data.get('confidence'):.1%})")

    print("\n" + "=" * 70)
    print("[STATUS]: Testing Bird Classifier MCP server...")
    print("=" * 70 + "\n")

    api_key = os.environ.get("API_KEY")
    if not api_key:
        print("[ERROR]: API_KEY is not set; check the bird-classifier-api-key secret")
        return

    server_url = f"{web.get_web_url()}/mcp/"
    transport = StreamableHttpTransport(
        url=server_url,
        headers={"X-API-Key": api_key},
    )
    client = Client(transport)

    test_url = "https://images.unsplash.com/photo-1444464666168-49d633b86797?w=400"

    try:
        async with client:
            # List tools
            print("\nAvailable Tools:")
            tools = await client.list_tools()
            for tool in tools:
                print(f"  - {tool.name}")

            # Test classify_from_url
            banner("[TEST 1]: classify_from_url")
            result = await client.call_tool(
                "classify_from_url",
                arguments={"image_url": test_url},
            )
            print_result(result)

            # Test classify_from_base64 with the same image, fetched locally
            banner("[TEST 2]: classify_from_base64")
            async with httpx.AsyncClient(follow_redirects=True, timeout=15) as http:
                response = await http.get(test_url)
                response.raise_for_status()
            encoded = base64.b64encode(response.content).decode("ascii")
            result = await client.call_tool(
                "classify_from_base64",
                arguments={"image_data": encoded},
            )
            print_result(result)

    except Exception as e:
        print(f"[ERROR]: {e}")

    print("\n" + "=" * 70)
    print("[STATUS]: Test complete!")
    print("=" * 70 + "\n")