"""FastAPI service that runs the PoseBusters `bust` CLI behind MCP-style endpoints."""

# Standard library imports
import asyncio
import functools
import os
import subprocess
import tempfile
import traceback
from pathlib import Path
from typing import Dict, Any, Optional

# Third-party imports
import gradio as gr
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from jsonschema import validate, ValidationError
from schema import MCP_CONTEXT_SCHEMA, MCP_PREDICT_RESPONSE_SCHEMA

# Local imports
from utils import (
    validate_mime_type,
    format_results_to_mcp,
    get_mcp_context,
    ALLOWED_MIME_TYPES,
)

app = FastAPI(title="PoseBusters MCP API")

# Upper bound, in seconds, on a single `bust` invocation.
BUST_TIMEOUT_SECONDS = 300

# Column header shared by every error payload returned from /mcp/predict.
_RESULT_COLUMNS = ["ligand_id", "status", "passed/total", "details"]

# Actions accepted by /mcp/predict.
_KNOWN_ACTIONS = ("validate_pose", "redocking_validation")


def _upload_stem(upload: UploadFile) -> str:
    """Return the stem of an upload's client-supplied filename.

    ``UploadFile.filename`` may be None. In that case return "unknown", so that
    the error handlers in ``predict`` can never raise while building a response.
    """
    filename = upload.filename
    if filename is None:
        return "unknown"
    return Path(filename).stem


def _error_response(ligand_id: str, details: str, **extra: Any) -> JSONResponse:
    """Build a 500 response shaped like a validation result with one failed row.

    Args:
        ligand_id: Value for the ``ligand_id`` column.
        details: Human-readable failure description for the ``details`` column.
        **extra: Additional top-level keys, appended after ``object_id`` and ``data``.
    """
    content: Dict[str, Any] = {
        "object_id": "validation_results",
        "data": {
            "columns": list(_RESULT_COLUMNS),
            "rows": [[ligand_id, "❌", "0/0", details]],
        },
    }
    content.update(extra)
    return JSONResponse(status_code=500, content=content)


async def _save_upload(upload: UploadFile, suffix: str, directory: str) -> str:
    """Write an upload to a uniquely named file inside *directory* and return its path."""
    content = await upload.read()
    # delete=False: the file must outlive this `with` so `bust` can read it.
    # It is removed together with *directory*.
    with tempfile.NamedTemporaryFile(dir=directory, suffix=suffix, delete=False) as tmp:
        tmp.write(content)
        return tmp.name


@app.post("/mcp/predict")
async def predict(
    action: str = Form(...),
    ligand_input: UploadFile = File(...),
    protein_input: UploadFile = File(...),
    crystal_input: Optional[UploadFile] = File(None),
) -> Dict[str, Any]:
    """
    MCP-compliant prediction endpoint.

    Validates file types and runs PoseBusters validation.
    Response is validated against MCP_PREDICT_RESPONSE_SCHEMA.

    Args:
        action: "validate_pose" (ligand vs. protein) or "redocking_validation"
            (ligand vs. crystal ligand and protein).
        ligand_input: Ligand file, checked against ".sdf".
        protein_input: Protein file, checked against ".pdb".
        crystal_input: Crystal ligand file, checked against ".sdf". Required
            for "redocking_validation".

    Returns:
        The dict produced by ``format_results_to_mcp``. On timeout, schema
        mismatch, or any other non-HTTP error, returns a 500 ``JSONResponse``
        with a single failed row instead.

    Raises:
        HTTPException: 400 for an unknown action or a missing crystal file.
            Any HTTPException raised by ``validate_mime_type`` is re-raised
            unchanged.
    """
    ligand_id = _upload_stem(ligand_input)
    try:
        # Validate MIME types
        validate_mime_type(ligand_input, ".sdf")
        validate_mime_type(protein_input, ".pdb")
        if crystal_input:
            validate_mime_type(crystal_input, ".sdf")

        # Reject malformed requests before anything is written to disk.
        if action not in _KNOWN_ACTIONS:
            raise HTTPException(status_code=400, detail=f"Unknown action: {action}")
        if action == "redocking_validation" and not crystal_input:
            raise HTTPException(status_code=400, detail="Missing crystal ligand file")

        # The uploads and anything `bust` writes to its cwd live in this
        # directory, and are removed with it when the `with` exits.
        with tempfile.TemporaryDirectory() as tmpdir:
            ligand_path = await _save_upload(ligand_input, ".sdf", tmpdir)
            protein_path = await _save_upload(protein_input, ".pdb", tmpdir)

            if action == "validate_pose":
                cmd = ["bust", ligand_path, "-p", protein_path, "--outfmt", "csv"]
            else:
                crystal_path = await _save_upload(crystal_input, ".sdf", tmpdir)
                cmd = [
                    "bust", ligand_path, "-l", crystal_path, "-p", protein_path,
                    "--outfmt", "csv",
                ]

            # A list argv (no shell) avoids shell injection.
            # cwd=tmpdir replaces a process-wide os.chdir. That chdir leaked into
            # concurrent requests and left the server's cwd pointing at a deleted
            # directory. subprocess.run blocks, so it runs on the default
            # executor to keep the event loop responsive.
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                None,
                functools.partial(
                    subprocess.run,
                    cmd,
                    capture_output=True,
                    text=True,
                    check=False,  # non-zero exit is reported via stdout/stderr
                    timeout=BUST_TIMEOUT_SECONDS,
                    cwd=tmpdir,
                ),
            )

        response = format_results_to_mcp(result.stdout, result.stderr, ligand_id)
        try:
            validate(instance=response, schema=MCP_PREDICT_RESPONSE_SCHEMA)
        except ValidationError as e:
            return _error_response(ligand_id, f"Schema validation error: {str(e)}")
        return response

    except subprocess.TimeoutExpired as e:
        return _error_response(
            ligand_id, f"Process timed out after {e.timeout} seconds"
        )
    except HTTPException:
        raise
    except Exception as e:
        # NOTE(review): this returns the server-side traceback to the client.
        # Consider dropping "traceback" outside of debug deployments.
        return _error_response(
            ligand_id,
            str(e),
            error=str(e),
            traceback=traceback.format_exc(),
        )


@app.get("/mcp/context")
def context() -> Dict[str, Any]:
    """
    Return MCP context with empty initial validation results.

    Response is validated against MCP_CONTEXT_SCHEMA. On a schema mismatch,
    returns a 500 ``JSONResponse`` with an "error" message instead.
    """
    try:
        ctx = get_mcp_context()
        validate(instance=ctx, schema=MCP_CONTEXT_SCHEMA)
    except ValidationError as e:
        return JSONResponse(
            status_code=500,
            content={"error": f"Schema validation error: {str(e)}"},
        )
    return ctx


def main():
    """Run the FastAPI server on 0.0.0.0:7860."""
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)


if __name__ == "__main__":
    main()