lepanto1571's picture
Upload 17 files
a1e520f verified
raw
history blame
4.09 kB
"""Shared utilities for the PoseBusters MCP wrapper."""
# Standard library imports
import csv
import mimetypes
from pathlib import Path
from typing import Any, Dict, List
import json
# Third-party imports
from fastapi import UploadFile, HTTPException
# Constants
# Accepted MIME types per upload extension: the chemistry-specific type first,
# followed by two generic fallbacks (presumably because uploads are often
# labelled generically by clients — confirm against real traffic).
ALLOWED_MIME_TYPES = {
    extension: [specific_type, "application/octet-stream", "text/plain"]
    for extension, specific_type in (
        (".sdf", "chemical/x-mdl-sdfile"),
        (".pdb", "chemical/x-pdb"),
    )
}
def validate_mime_type(file: UploadFile, expected_ext: str) -> None:
    """Validate an uploaded file's extension and MIME type.

    Args:
        file: The uploaded file; its ``filename`` and ``content_type`` are inspected.
        expected_ext: Required extension including the leading dot (e.g. ``".sdf"``).
            It is also used as the key into ``ALLOWED_MIME_TYPES``.

    Raises:
        HTTPException: 400 if the filename is missing, the extension does not
            equal ``expected_ext``, or the MIME type is not allowed for it.
        KeyError: if ``expected_ext`` is not a key of ``ALLOWED_MIME_TYPES``.
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail=f"Missing filename for {expected_ext} file")

    file_ext = Path(file.filename).suffix.lower()
    if file_ext != expected_ext:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid file extension. Expected {expected_ext}, got {file_ext}",
        )

    # When the client sent no Content-Type, guess from the filename, and
    # finally fall back to a generic binary type.
    raw_type = file.content_type or mimetypes.guess_type(file.filename)[0] or "application/octet-stream"
    # A Content-Type header may carry parameters ("text/plain; charset=utf-8")
    # and media types are case-insensitive, so compare only the normalized
    # "type/subtype" part against the allow-list.
    media_type = raw_type.split(";", 1)[0].strip().lower()
    if media_type not in ALLOWED_MIME_TYPES[expected_ext]:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid MIME type for {expected_ext} file",
        )
def format_results_to_mcp(csv_text: str, stderr_text: str = "", ligand_id: str = "unknown") -> Dict[str, Any]:
    """Convert PoseBusters CSV output into the MCP tabular response format.

    Each CSV data row becomes one output row
    ``[molecule, status, "passed/total", details]``. Only cells whose value is
    exactly ``"True"`` or ``"False"`` count as test outcomes. The ``file`` and
    ``molecule`` columns, and any other (e.g. numeric) values, are ignored.

    Args:
        csv_text: CSV text produced by PoseBusters (header plus one row per molecule).
        stderr_text: PoseBusters stderr. Its first non-blank line becomes the
            error detail when ``csv_text`` yields no data rows.
        ligand_id: Identifier reported in the single failure row emitted when
            ``csv_text`` yields no data rows.

    Returns:
        ``{"object_id": "validation_results", "data": {"columns": [...], "rows": [...]}}``.
        ``rows`` always contains at least one entry.

    Raises:
        ValueError: if a row's length does not match the number of columns.
    """
    rows: List[List[str]] = []
    has_output = bool(csv_text.strip())

    if has_output:
        for record in csv.DictReader(csv_text.strip().splitlines()):
            # A missing column or a short row (value None) or blank cell is reported as "?".
            molecule_id = record.get("molecule") or "?"
            failed_tests: List[str] = []
            passed_n = total_n = 0
            for column, value in record.items():
                if column in {"file", "molecule"}:
                    continue
                # Only boolean cells are test outcomes; other columns are skipped.
                if value == "True":
                    passed_n += 1
                    total_n += 1
                elif value == "False":
                    failed_tests.append(column)
                    total_n += 1
            rows.append([
                molecule_id,
                "✅" if not failed_tests else "❌",
                f"{passed_n}/{total_n}",
                ", ".join(failed_tests) if failed_tests else "All tests passed",
            ])

    if not rows:
        # No output at all, or a header without data rows: emit one failure row
        # so the caller always sees why no results were produced.
        fallback = "PoseBusters output is empty." if not has_output else "PoseBusters output contains no results."
        stderr_clean = stderr_text.strip()
        error_msg = stderr_clean.splitlines()[0] if stderr_clean else fallback
        rows.append([ligand_id, "❌", "0/0", error_msg])

    response = {
        "object_id": "validation_results",
        "data": {
            "columns": ["ligand_id", "status", "passed/total", "details"],
            "rows": rows,
        },
    }

    # Sanity check: every row must line up with the declared columns.
    expected_cols = len(response["data"]["columns"])
    for row in rows:
        if len(row) != expected_cols:
            raise ValueError(f"Row length {len(row)} does not match columns length {expected_cols}")
    return response
def get_mcp_context() -> Dict[str, Any]:
    """Load and return the parsed contents of ``mcp_context.json``.

    The file is looked up in this module's directory first. If it is not
    there, the parent directory is tried (the app root when deployed on Spaces).

    Returns:
        The parsed JSON document.

    Raises:
        HTTPException: 500 if the file is missing or cannot be read, or if its
            content is not valid UTF-8 JSON.
    """
    # Resolve first so ``.parent`` is meaningful even if __file__ is relative.
    module_dir = Path(__file__).resolve().parent
    context_path = module_dir / "mcp_context.json"
    if not context_path.exists():
        context_path = module_dir.parent / "mcp_context.json"

    try:
        # Open directly instead of checking existence again (avoids a race);
        # JSON is read as UTF-8 regardless of the platform default encoding.
        with context_path.open("r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError as e:
        raise HTTPException(
            status_code=500,
            detail=f"MCP context file not found at {context_path}",
        ) from e
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        raise HTTPException(
            status_code=500,
            detail="Invalid MCP context JSON format",
        ) from e
    except OSError as e:
        # Permission errors, directories in place of the file, etc.
        raise HTTPException(status_code=500, detail=str(e)) from e