|
|
|
|
|
import os
|
|
|
import subprocess
|
|
|
import tempfile
|
|
|
import traceback
|
|
|
from pathlib import Path
|
|
|
from typing import Dict, Any, Optional
|
|
|
|
|
|
|
|
|
import gradio as gr
|
|
|
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
|
|
|
from fastapi.responses import JSONResponse
|
|
|
from jsonschema import validate, ValidationError
|
|
|
from schema import MCP_CONTEXT_SCHEMA, MCP_PREDICT_RESPONSE_SCHEMA
|
|
|
|
|
|
|
|
|
from utils import validate_mime_type, format_results_to_mcp, get_mcp_context, ALLOWED_MIME_TYPES
|
|
|
|
|
|
# ASGI application object; the route decorators further down register
# their handlers on this instance.
app = FastAPI(
    title="PoseBusters MCP API",
)
|
|
|
|
|
|
def _error_response(ligand_id: str, detail: str, **extra: Any) -> JSONResponse:
    """Build the HTTP 500 error payload shared by every failure path of ``predict``.

    Args:
        ligand_id: Identifier placed in the first column of the single result row.
        detail: Human-readable failure description placed in the ``details`` column.
        **extra: Additional top-level keys merged into the response body.

    Returns:
        A ``JSONResponse`` with status 500 and a one-row ``validation_results`` table.
    """
    content: Dict[str, Any] = {
        "object_id": "validation_results",
        "data": {
            "columns": ["ligand_id", "status", "passed/total", "details"],
            "rows": [[ligand_id, "❌", "0/0", detail]],
        },
    }
    content.update(extra)
    return JSONResponse(status_code=500, content=content)


async def _save_upload(upload: UploadFile, destination: Path) -> str:
    """Write the bytes of *upload* to *destination* and return that path as a string."""
    destination.write_bytes(await upload.read())
    return str(destination)


@app.post("/mcp/predict")
async def predict(
    action: str = Form(...),
    ligand_input: UploadFile = File(...),
    protein_input: UploadFile = File(...),
    crystal_input: Optional[UploadFile] = File(None),
) -> Dict[str, Any]:
    """
    MCP-compliant prediction endpoint.

    Validates file types and runs the ``bust`` command-line tool on the uploads.
    Response is validated against MCP_PREDICT_RESPONSE_SCHEMA.

    Args:
        action: ``"validate_pose"`` or ``"redocking_validation"``.
        ligand_input: Uploaded ligand file, checked with ``validate_mime_type(..., ".sdf")``.
        protein_input: Uploaded protein file, checked with ``validate_mime_type(..., ".pdb")``.
        crystal_input: Optional crystal ligand file (``.sdf``); required when
            ``action`` is ``"redocking_validation"``.

    Returns:
        The dict produced by ``format_results_to_mcp`` when it satisfies the
        response schema; otherwise a status-500 ``JSONResponse`` carrying a
        one-row error table (schema violation, timeout, or any other exception).

    Raises:
        HTTPException: 400 for an unknown ``action`` or a missing crystal file.
            Any ``HTTPException`` raised by the called helpers is re-raised unchanged.
    """
    # Function-scope imports keep this block self-contained without touching
    # the module's import section.
    import asyncio
    import functools

    # Resolve the id once, tolerating a missing filename, so the error
    # handlers below can never fail while building their own payload.
    ligand_id = Path(ligand_input.filename or "unknown").stem

    try:
        validate_mime_type(ligand_input, ".sdf")
        validate_mime_type(protein_input, ".pdb")
        if crystal_input:
            validate_mime_type(crystal_input, ".sdf")

        # Reject malformed requests before anything is written to disk.
        if action not in ("validate_pose", "redocking_validation"):
            raise HTTPException(status_code=400, detail=f"Unknown action: {action}")
        if action == "redocking_validation" and not crystal_input:
            raise HTTPException(status_code=400, detail="Missing crystal ligand file")

        # All uploads live inside the temporary directory, so leaving the
        # ``with`` block removes them even when a later step fails.
        with tempfile.TemporaryDirectory() as tmpdir:
            workdir = Path(tmpdir)
            ligand_path = await _save_upload(ligand_input, workdir / "ligand.sdf")
            protein_path = await _save_upload(protein_input, workdir / "protein.pdb")

            if action == "validate_pose":
                cmd = ["bust", ligand_path, "-p", protein_path, "--outfmt", "csv"]
            else:
                crystal_path = await _save_upload(crystal_input, workdir / "crystal.sdf")
                cmd = [
                    "bust", ligand_path,
                    "-l", crystal_path,
                    "-p", protein_path,
                    "--outfmt", "csv",
                ]

            # ``cwd=`` scopes the working directory to the child process;
            # os.chdir() would change it for the whole server and for every
            # concurrent request.
            run_bust = functools.partial(
                subprocess.run,
                cmd,
                capture_output=True,
                text=True,
                check=False,
                timeout=300,
                cwd=tmpdir,
            )
            # Run the blocking call in the default executor so the event loop
            # stays responsive while the tool runs (up to the 300 s timeout).
            result = await asyncio.get_running_loop().run_in_executor(None, run_bust)

        response = format_results_to_mcp(result.stdout, result.stderr, ligand_id)

        try:
            validate(instance=response, schema=MCP_PREDICT_RESPONSE_SCHEMA)
        except ValidationError as e:
            return _error_response(ligand_id, f"Schema validation error: {str(e)}")
        return response

    except subprocess.TimeoutExpired as e:
        return _error_response(ligand_id, f"Process timed out after {e.timeout} seconds")
    except HTTPException:
        # Client errors keep their own status code.
        raise
    except Exception as e:
        # NOTE(review): the traceback is returned to the client to preserve the
        # existing response shape; this exposes server internals and should be
        # replaced by server-side logging if no caller depends on the field.
        return _error_response(
            ligand_id,
            str(e),
            error=str(e),
            traceback=traceback.format_exc(),
        )
|
|
|
|
|
|
@app.get("/mcp/context")
def context() -> Dict[str, Any]:
    """
    Serve the MCP context document.

    The value returned by ``get_mcp_context()`` is checked against
    MCP_CONTEXT_SCHEMA before being sent. If a ``ValidationError`` is raised,
    a status-500 JSON body with a single ``error`` message is returned instead.
    """
    try:
        payload = get_mcp_context()
        validate(instance=payload, schema=MCP_CONTEXT_SCHEMA)
    except ValidationError as exc:
        body = {"error": f"Schema validation error: {str(exc)}"}
        return JSONResponse(status_code=500, content=body)
    return payload
|
|
|
|
|
|
def main():
    """Serve ``app`` with uvicorn on host 0.0.0.0, port 7860."""
    # Imported here so uvicorn is only required when the server is started
    # through this entry point.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)
|
|
|
|
|
|
# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()