|
|
""" |
|
|
Bird Species Classifier MCP Server on Modal

Updated to support:
1. classify_from_base64() - for IDE/Cursor clients and Gradio
2. classify_from_url() - for fallback/public images
|
|
""" |
|
|
|
|
|
import modal |
|
|
from fastmcp import FastMCP |
|
|
import base64 |
|
|
import json |
|
|
import httpx |
|
|
from io import BytesIO |
|
|
from PIL import Image |
|
|
import torch |
|
|
import os |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Modal application object; the functions in this file register against it.
app = modal.App("bird-classifier-mcp")

# Packages installed into the container image, with their version constraints.
_PIP_PACKAGES = (
    "transformers==4.46.0",
    "torch==2.5.1",
    "pillow==10.4.0",
    "fastmcp>=2.13.0",
    "pydantic>=2.10.0,<3.0.0",
    "fastapi==0.115.14",
    "httpx>=0.28.0",
)

# Container image: Debian slim on Python 3.12 with the packages above.
_base_image = modal.Image.debian_slim(python_version="3.12")
image = _base_image.pip_install(*_PIP_PACKAGES)

# Named Modal secret attached to the functions below.
# NOTE(review): presumably this is what provides the API_KEY environment
# variable read elsewhere in this file -- confirm against the secret's contents.
API_KEY_SECRET = modal.Secret.from_name("bird-classifier-api-key")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def make_mcp_server():
    """Create the FastMCP server and register the bird classification tools.

    Loads the image-classification pipeline once and registers two tools that
    share it: ``classify_from_base64`` and ``classify_from_url``.

    Returns:
        The configured ``FastMCP`` instance.
    """
    from transformers import pipeline

    mcp = FastMCP("Bird Species Classifier")

    print("🔄 Loading bird classifier model...")
    classifier = pipeline(
        "image-classification",
        model="prithivMLmods/Bird-Species-Classifier-526",
        # Use the first GPU when CUDA is available; -1 selects the CPU so the
        # server can still start in a container without a GPU.
        device=0 if torch.cuda.is_available() else -1,
    )
    print("✅ Model loaded!")

    def preprocess_image(image: Image.Image, max_size: int = 800) -> Image.Image:
        """Convert to RGB and shrink so the longest side is at most ``max_size``."""
        if image.mode != "RGB":
            image = image.convert("RGB")

        longest_side = max(image.size)
        if longest_side > max_size:
            ratio = max_size / longest_side
            # Clamp each side to at least 1 px: for a very elongated image the
            # short side would otherwise truncate to 0, an invalid target size.
            new_size = (
                max(1, int(image.size[0] * ratio)),
                max(1, int(image.size[1] * ratio)),
            )
            image = image.resize(new_size, Image.Resampling.LANCZOS)

        return image

    def classify_image(image: Image.Image, source: str) -> str:
        """Classify ``image`` and return the JSON success payload.

        ``source`` is echoed back in the payload so callers can tell which
        tool produced the result.
        """
        image = preprocess_image(image)

        print("[STATUS]: Classifying image...")
        results = classifier(image, top_k=1)
        top_result = results[0]

        return json.dumps({
            "species": top_result["label"],
            "confidence": round(top_result["score"], 4),
            "source": source,
        })

    def error_payload(exc: Exception) -> str:
        """Return the JSON payload both tools use to report a failure."""
        return json.dumps({
            "error": str(exc),
            "species": None,
            "confidence": 0.0,
        })

    @mcp.tool()
    async def classify_from_base64(image_data: str) -> str:
        """
        Classify a bird species from base64-encoded image data.

        This is the primary tool for IDE clients and Gradio apps.
        Accepts raw base64 or data URL format.

        Args:
            image_data: Base64-encoded image string (PNG/JPG)
                        Can be raw base64 or "data:image/png;base64,..."

        Returns:
            JSON string with species name and confidence score
            Format: {"species": "Common Name", "confidence": 0.95, "source": "base64"}
            On failure: {"error": "...", "species": null, "confidence": 0.0}
        """
        try:
            if image_data.startswith("data:"):
                # Data URL: keep only the payload after the first comma.
                _header, separator, payload = image_data.partition(",")
                if not separator:
                    raise ValueError(
                        "Malformed data URL: missing ',' before the base64 payload"
                    )
                image_data = payload

            print(f"[STATUS]: Decoding base64 image ({len(image_data)} chars)...")
            image_bytes = base64.b64decode(image_data)
            image = Image.open(BytesIO(image_bytes))

            return classify_image(image, source="base64")

        except Exception as e:
            # Failures are reported in-band as JSON rather than raised, so the
            # caller always receives the same payload shape.
            return error_payload(e)

    @mcp.tool()
    async def classify_from_url(image_url: str) -> str:
        """
        Download image from URL and classify bird species.

        Fallback tool for clients that have URL access.

        Args:
            image_url: URL to the image (https://example.com/bird.jpg)

        Returns:
            JSON string with species name and confidence score
            Format: {"species": "Common Name", "confidence": 0.95, "source": "url"}
            On failure: {"error": "...", "species": null, "confidence": 0.0}
        """
        try:
            print("[STATUS]: Downloading from URL...")
            # NOTE(review): image_url is caller-supplied and is fetched from
            # inside the container, redirects included. Consider restricting
            # the allowed hosts if this server can reach private services.
            #
            # Use the async client so the download does not block the event
            # loop for up to the 15 s timeout.
            async with httpx.AsyncClient(follow_redirects=True, timeout=15) as client:
                response = await client.get(image_url)
            response.raise_for_status()

            image = Image.open(BytesIO(response.content))

            return classify_image(image, source="url")

        except Exception as e:
            # Same in-band error reporting as classify_from_base64.
            return error_payload(e)

    return mcp
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.function(
    image=image,
    gpu="T4",
    secrets=[API_KEY_SECRET],
    timeout=300,
    min_containers=0,
    max_containers=5,
    scaledown_window=60,
)
@modal.asgi_app()
def web():
    """ASGI web endpoint for MCP server with API key auth.

    Builds the MCP server, wraps its streamable-HTTP app in a FastAPI
    application, and rejects with HTTP 401 any request whose ``X-API-Key``
    header does not match the ``API_KEY`` environment variable.

    Returns:
        The FastAPI application with the MCP app mounted at ``/``.
    """
    import hmac

    from fastapi import FastAPI, Request
    from fastapi.responses import JSONResponse

    print("[STATUS]: Starting MCP server...")

    mcp = make_mcp_server()
    mcp_app = mcp.http_app(transport="streamable-http", stateless_http=True)

    fastapi_app = FastAPI(
        title="Bird Classifier MCP Server",
        description="MCP server for bird species classification",
        # Reuse the MCP app's lifespan on the outer FastAPI app.
        lifespan=mcp_app.lifespan,
    )

    @fastapi_app.middleware("http")
    async def verify_api_key(request: Request, call_next):
        """Verify API key on every request"""
        api_key = request.headers.get("X-API-Key")
        expected_key = os.environ.get("API_KEY")

        # Fail closed when either side is missing or empty. Compare as bytes
        # in constant time so response timing does not leak how much of the
        # key matched; bytes also avoid compare_digest's ASCII-only
        # restriction on str arguments.
        authorized = (
            bool(api_key)
            and bool(expected_key)
            and hmac.compare_digest(
                api_key.encode("utf-8"), expected_key.encode("utf-8")
            )
        )
        if not authorized:
            return JSONResponse(
                status_code=401,
                content={"error": "Invalid or missing API key"},
            )

        return await call_next(request)

    fastapi_app.mount("/", mcp_app)

    print("[STATUS]: MCP server is ready!")
    return fastapi_app
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.function(image=image, secrets=[API_KEY_SECRET])
async def test_classifier():
    """Test MCP server.

    Connects to the ``web`` endpoint over streamable HTTP with the ``API_KEY``
    header, lists the available tools, and calls ``classify_from_url`` on one
    sample image. Results and failures are printed; nothing is returned.
    """
    from fastmcp import Client
    from fastmcp.client.transports import StreamableHttpTransport

    banner = "=" * 70

    print("\n" + banner)
    print("[STATUS]: Testing Bird Classifier MCP server...")
    print(banner + "\n")

    server_url = f"{web.get_web_url()}/mcp/"

    transport = StreamableHttpTransport(
        url=server_url,
        headers={"X-API-Key": os.environ.get("API_KEY")},
    )

    client = Client(transport)

    try:
        async with client:
            print("\nAvailable Tools:")
            tools = await client.list_tools()
            for tool in tools:
                print(f"  - {tool.name}")

            print("\n" + banner)
            print("[TEST 1]: classify_from_url")
            print(banner)

            test_url = "https://images.unsplash.com/photo-1444464666168-49d633b86797?w=400"
            result = await client.call_tool(
                "classify_from_url",
                arguments={"image_url": test_url},
            )

            if result.content:
                result_text = result.content[0].text
                data = json.loads(result_text)
                if data.get("error"):
                    # The tool reports failures in-band; show the message
                    # instead of a misleading "None (0.0%)" result line.
                    print(f"[ERROR]: {data['error']}")
                else:
                    print(f"[RESULT]: {data.get('species')} ({data.get('confidence'):.1%})")

    except Exception as e:
        # Top-level boundary of a smoke test: report the failure and still
        # print the closing banner below.
        print(f"[ERROR]: {e}")

    print("\n" + banner)
    print("[STATUS]: Test complete!")
    print(banner + "\n")
|
|
|