lepanto1571's picture
Upload 17 files
a1e520f verified
raw
history blame
6.43 kB
# Standard library imports
import asyncio
import functools
import os
import subprocess
import tempfile
import traceback
from pathlib import Path
from typing import Dict, Any, List, Optional

# Third-party imports
import gradio as gr
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from jsonschema import validate, ValidationError

# Local imports
from schema import MCP_CONTEXT_SCHEMA, MCP_PREDICT_RESPONSE_SCHEMA
from utils import validate_mime_type, format_results_to_mcp, get_mcp_context, ALLOWED_MIME_TYPES
app = FastAPI(title="PoseBusters MCP API")

# Upper bound on a single `bust` run; subprocess.run kills the child on expiry.
BUST_TIMEOUT_SECONDS = 300  # 5 minutes


def _ligand_id(upload: UploadFile) -> str:
    """Return the ligand identifier reported in result rows: the upload's file stem.

    Falls back to ``"unknown"`` when the client sent no filename, so the
    error-reporting paths cannot themselves fail on ``Path(None)``.
    """
    filename = upload.filename
    return Path(filename).stem if filename else "unknown"


def _error_response(ligand_id: str, details: str, **extra: Any) -> JSONResponse:
    """Build an HTTP 500 response shaped like a one-row MCP validation result.

    Args:
        ligand_id: Value for the ``ligand_id`` column.
        details: Failure reason placed in the ``details`` column.
        **extra: Additional top-level keys, appended after ``data``.
    """
    content: Dict[str, Any] = {
        "object_id": "validation_results",
        "data": {
            "columns": ["ligand_id", "status", "passed/total", "details"],
            "rows": [[ligand_id, "❌", "0/0", details]],
        },
    }
    content.update(extra)
    return JSONResponse(status_code=500, content=content)


async def _save_upload(upload: UploadFile, suffix: str, directory: str) -> str:
    """Write ``upload`` into ``directory`` under a random name ending in ``suffix``.

    The name is generated by ``NamedTemporaryFile``, so the client-supplied
    filename is never used as a filesystem path.

    Returns:
        Path of the written file.
    """
    content = await upload.read()
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix, dir=directory) as tmp:
        tmp.write(content)
    return tmp.name


def _build_command(
    action: str,
    ligand_path: str,
    protein_path: str,
    crystal_path: Optional[str],
) -> List[str]:
    """Translate an MCP ``action`` into the ``bust`` argv list.

    Raises:
        HTTPException: 400 for an unknown action, or for
            ``redocking_validation`` without a crystal ligand file.
    """
    if action == "validate_pose":
        return ["bust", ligand_path, "-p", protein_path, "--outfmt", "csv"]
    if action == "redocking_validation":
        if not crystal_path:
            raise HTTPException(status_code=400, detail="Missing crystal ligand file")
        return ["bust", ligand_path, "-l", crystal_path, "-p", protein_path, "--outfmt", "csv"]
    raise HTTPException(status_code=400, detail=f"Unknown action: {action}")


async def _run_bust(cmd: List[str], workdir: str) -> subprocess.CompletedProcess:
    """Run ``cmd`` with ``workdir`` as its cwd without blocking the event loop.

    ``subprocess.run`` blocks for up to BUST_TIMEOUT_SECONDS, so it runs in the
    loop's default thread-pool executor. The list-form argv (no shell) avoids
    shell injection; a non-zero exit is not an error (``check=False``).

    Raises:
        subprocess.TimeoutExpired: If the process exceeds the timeout.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        None,
        functools.partial(
            subprocess.run,
            cmd,
            cwd=workdir,  # per-call cwd instead of a process-wide os.chdir
            capture_output=True,
            text=True,
            check=False,
            timeout=BUST_TIMEOUT_SECONDS,
        ),
    )


@app.post("/mcp/predict")
async def predict(
    action: str = Form(...),
    ligand_input: UploadFile = File(...),
    protein_input: UploadFile = File(...),
    crystal_input: Optional[UploadFile] = File(None),
) -> Dict[str, Any]:
    """
    MCP-compliant prediction endpoint.

    Validates the upload MIME types and runs the PoseBusters ``bust`` CLI on
    the uploaded files. The response is validated against
    MCP_PREDICT_RESPONSE_SCHEMA.

    Args:
        action: ``"validate_pose"`` or ``"redocking_validation"``.
        ligand_input: Ligand pose file (stored with an ``.sdf`` suffix).
        protein_input: Protein file (stored with a ``.pdb`` suffix).
        crystal_input: Reference crystal ligand (``.sdf``); required for
            ``redocking_validation``.

    Returns:
        The MCP response dict on success. On a schema mismatch, a timeout, or
        an unexpected error, it returns an HTTP 500 JSONResponse with a single
        failed row.

    Raises:
        HTTPException: Re-raised unchanged. This covers a 400 for an unknown
            action or a missing crystal file, and presumably whatever
            ``validate_mime_type`` raises for a bad upload.
    """
    ligand_id = _ligand_id(ligand_input)
    try:
        validate_mime_type(ligand_input, ".sdf")
        validate_mime_type(protein_input, ".pdb")
        if crystal_input:
            validate_mime_type(crystal_input, ".sdf")

        # One private scratch directory per request. Uploads are written into
        # it and `bust` runs with it as cwd, so anything the tool emits stays
        # isolated. Everything is removed when this block exits.
        with tempfile.TemporaryDirectory() as tmpdir:
            ligand_path = await _save_upload(ligand_input, ".sdf", tmpdir)
            protein_path = await _save_upload(protein_input, ".pdb", tmpdir)
            crystal_path = (
                await _save_upload(crystal_input, ".sdf", tmpdir)
                if crystal_input
                else None
            )
            cmd = _build_command(action, ligand_path, protein_path, crystal_path)
            result = await _run_bust(cmd, tmpdir)

        response = format_results_to_mcp(result.stdout, result.stderr, ligand_id)
        try:
            validate(instance=response, schema=MCP_PREDICT_RESPONSE_SCHEMA)
        except ValidationError as e:
            return _error_response(ligand_id, f"Schema validation error: {str(e)}")
        return response
    except subprocess.TimeoutExpired as e:
        return _error_response(ligand_id, f"Process timed out after {e.timeout} seconds")
    except HTTPException:
        raise
    except Exception as e:
        # NOTE(review): sending the traceback to the client exposes server
        # internals (paths, library versions). It is kept only for response
        # compatibility; consider logging it server-side instead.
        return _error_response(
            ligand_id,
            str(e),
            error=str(e),
            traceback=traceback.format_exc(),
        )
@app.get("/mcp/context")
def context() -> Dict[str, Any]:
    """
    Serve the MCP context document, presumably with empty initial validation results.

    The payload from ``get_mcp_context()`` is checked against
    MCP_CONTEXT_SCHEMA. On a schema mismatch the client gets an HTTP 500
    JSON body with an ``error`` message instead of the payload.
    """
    try:
        payload = get_mcp_context()
        validate(instance=payload, schema=MCP_CONTEXT_SCHEMA)
    except ValidationError as err:
        message = f"Schema validation error: {str(err)}"
        return JSONResponse(status_code=500, content={"error": message})
    else:
        return payload
def main(host: str = "0.0.0.0", port: int = 7860) -> None:
    """Run the FastAPI server under uvicorn.

    Args:
        host: Interface to bind. The default ``"0.0.0.0"`` listens on all
            interfaces.
        port: TCP port to listen on (default 7860).
    """
    # Imported here rather than at module level, so importing this module does
    # not pull in uvicorn.
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()