|
|
|
|
|
import os
|
|
|
import io
|
|
|
import asyncio
|
|
|
from pathlib import Path
|
|
|
from typing import Any, Dict, Optional
|
|
|
from contextlib import contextmanager
|
|
|
|
|
|
|
|
|
import gradio as gr
|
|
|
from fastapi import FastAPI, File, Form, UploadFile
|
|
|
import uvicorn
|
|
|
import json
|
|
|
|
|
|
|
|
|
from utils import validate_mime_type, format_results_to_mcp
|
|
|
from main import predict, app as fastapi_app
|
|
|
|
|
|
|
|
|
# True when the SPACE_ID environment variable is present (presumably set by the
# Hugging Face Spaces runtime); selects the Spaces-specific code paths below.
IS_HF_SPACE: bool = "SPACE_ID" in os.environ

# Working directory captured once at import time; stable_cwd() switches back to it.
ORIGINAL_CWD: str = os.getcwd()
|
|
|
|
|
|
@contextmanager
def stable_cwd():
    """Run the ``with`` body from ``ORIGINAL_CWD``, then return to the prior directory.

    The directory active on entry is recorded first and re-entered in a
    ``finally`` clause, so it is restored even when the body raises.
    """
    entry_dir = os.getcwd()
    try:
        os.chdir(ORIGINAL_CWD)
        yield
    finally:
        # Always undo the chdir, regardless of how the body exited.
        os.chdir(entry_dir)
|
|
|
|
|
|
|
|
|
def get_mcp_context_wrapper() -> Dict[str, Any]:
    """Load the MCP context metadata, handling the Hugging Face Spaces layout.

    When ``IS_HF_SPACE`` is true the context is parsed directly from
    ``mcp_context.json`` inside ``ORIGINAL_CWD``; otherwise loading is
    delegated to ``utils.get_mcp_context``. Both paths run inside
    ``stable_cwd()``.

    Returns:
        The context dict. On any failure a fallback dict is returned instead
        of raising. It carries an ``"error"`` message and minimal ``"app"``
        metadata, so the MCP Context tab still has something to display.
    """
    try:
        with stable_cwd():
            if IS_HF_SPACE:
                context_path = Path(ORIGINAL_CWD) / 'mcp_context.json'
                if not context_path.exists():
                    raise FileNotFoundError(f"MCP context file not found at {context_path}")
                # JSON text is UTF-8; don't let the platform's locale encoding decide.
                return json.loads(context_path.read_text(encoding="utf-8"))

            # Only resolved on the non-Spaces path.
            from utils import get_mcp_context
            return get_mcp_context()
    except Exception as e:
        return {
            "error": f"Failed to load MCP context: {str(e)}",
            "app": {
                "id": "lepanto1571.wrapper.posebusters",
                "name": "PoseBusters MCP wrapper",
                "version": "0.5"
            }
        }
|
|
|
|
|
|
def _display_name(file_obj) -> str:
    """Return the file stem used to label a result row, or ``"unknown"``.

    Accepts a filesystem path (``str``/``os.PathLike``, which is what
    ``gr.File(type="filepath")`` produces) or an object exposing
    ``orig_name``/``name``.
    """
    if file_obj is None:
        return "unknown"
    if isinstance(file_obj, (str, os.PathLike)):
        return Path(file_obj).stem or "unknown"
    name = getattr(file_obj, "orig_name", None) or getattr(file_obj, "name", None)
    return Path(str(name)).stem if name else "unknown"


async def _to_upload_file(file_obj) -> "Optional[UploadFile]":
    """Wrap a Gradio file value in an in-memory ``UploadFile`` for ``predict``.

    Args:
        file_obj: Any of the following:
            - ``None``;
            - a filesystem path (``str``/``os.PathLike``);
            - an object with a sync or async ``read`` method;
            - an object whose ``name`` attribute is an existing path.

    Returns:
        ``None`` when ``file_obj`` is ``None``. Otherwise, an ``UploadFile``
        backed by an ``io.BytesIO`` holding the file's bytes, so no OS file
        handle is left open.

    Raises:
        ValueError: if the file cannot be read (original error chained).
    """
    if file_obj is None:
        return None

    try:
        if isinstance(file_obj, (str, os.PathLike)):
            # gr.File(type="filepath") hands the handler a plain path string.
            source = Path(file_obj)
            return UploadFile(file=io.BytesIO(source.read_bytes()), filename=source.name)

        filename = (
            getattr(file_obj, "orig_name", None)
            or getattr(file_obj, "name", None)
            or str(file_obj)
        )

        if hasattr(file_obj, "read"):
            if asyncio.iscoroutinefunction(file_obj.read):
                content = await file_obj.read()
            else:
                content = file_obj.read()
            if hasattr(file_obj, "seek"):
                # Rewind the caller's object; seek may be async on some file types.
                rewind = file_obj.seek(0)
                if asyncio.iscoroutine(rewind):
                    await rewind
            if isinstance(content, str):
                content = content.encode("utf-8")
            return UploadFile(file=io.BytesIO(content), filename=Path(filename).name)

        path_attr = getattr(file_obj, "name", None)
        if path_attr and os.path.exists(path_attr):
            # Read eagerly so no open OS file handle outlives this call.
            return UploadFile(
                file=io.BytesIO(Path(path_attr).read_bytes()),
                filename=Path(filename).name,
            )

        raise ValueError(f"Unable to read file: {filename}")
    except Exception as e:
        raise ValueError(
            f"Failed to process file {getattr(file_obj, 'name', str(file_obj))}: {str(e)}"
        ) from e


async def process_files(ligand_file, protein_file, crystal_file=None):
    """Validate an uploaded ligand/protein pair via ``predict`` from main.py.

    Args:
        ligand_file: Gradio value for the ligand upload (path string or file object).
        protein_file: Gradio value for the protein upload.
        crystal_file: Optional crystal-ligand upload. When given, the
            ``"redocking_validation"`` action is used instead of ``"validate_pose"``.

    Returns:
        Whatever ``predict`` returns on success. On any error, a
        ``"validation_results"`` dict with one failure row whose last column
        holds the error message; errors are reported, never raised.
    """
    try:
        # Fail fast before touching the filesystem.
        if ligand_file is None or protein_file is None:
            raise ValueError("Both ligand and protein files are required")

        # NOTE(review): stable_cwd() changes the process-wide cwd and stays in
        # effect across the awaits below; concurrent requests could interleave.
        # Confirm this is acceptable.
        with stable_cwd():
            ligand_input = await _to_upload_file(ligand_file)
            protein_input = await _to_upload_file(protein_file)
            crystal_input = await _to_upload_file(crystal_file)

            action = "redocking_validation" if crystal_input is not None else "validate_pose"

            return await predict(
                action=action,
                ligand_input=ligand_input,
                protein_input=protein_input,
                crystal_input=crystal_input
            )
    except Exception as e:
        return {
            "object_id": "validation_results",
            "data": {
                "columns": ["ligand_id", "status", "passed/total", "details"],
                "rows": [[
                    _display_name(ligand_file),
                    "❌",
                    "0/0",
                    str(e)
                ]]
            }
        }
|
|
|
|
|
|
# One theme instance shared by both tabs and by the tabbed wrapper.
theme = gr.themes.Base()


def _single_file_input(label: str) -> gr.File:
    """Build a single-file upload component configured with ``type="filepath"``."""
    return gr.File(label=label, type="filepath", file_count="single")


# Tab 1: ligand + protein (+ optional crystal ligand) uploads, handled by process_files.
validation_ui = gr.Interface(
    fn=process_files,
    inputs=[
        _single_file_input("Ligand (.sdf)"),
        _single_file_input("Protein (.pdb)"),
        _single_file_input("Crystal Ligand (.sdf, optional)"),
    ],
    outputs="json",
    title="PoseBusters Validation",
    description="Upload files to validate ligand–protein structures using PoseBusters.",
    theme=theme,
    cache_examples=False,
)

# Tab 2: no inputs; renders the dict returned by get_mcp_context_wrapper.
context_ui = gr.Interface(
    fn=get_mcp_context_wrapper,
    inputs=None,
    outputs=gr.JSON(),
    title="MCP Context",
    description="View the MCP context metadata that defines this tool's capabilities.",
    theme=theme,
    cache_examples=False,
)

# Tab label -> interface, in display order (dicts preserve insertion order).
_tabs = {
    "💊 Validation": validation_ui,
    "🔍 MCP Context": context_ui,
}

demo = gr.TabbedInterface(
    interface_list=list(_tabs.values()),
    tab_names=list(_tabs.keys()),
    title="PoseBusters MCP Wrapper",
    theme=theme,
)
|
|
|
|
|
|
def main():
    """Start serving the UI, choosing the server according to IS_HF_SPACE.

    - Outside Spaces: the Gradio ``demo`` is mounted at ``/`` on the FastAPI
      app imported from main.py, and uvicorn serves it on 0.0.0.0:7860.
    - On Spaces: ``demo.launch`` runs Gradio's own server bound to 0.0.0.0.
    """
    if not IS_HF_SPACE:
        combined = gr.mount_gradio_app(fastapi_app, demo, path="/")
        uvicorn.run(combined, host="0.0.0.0", port=7860)
        return

    demo.launch(server_name="0.0.0.0")


if __name__ == "__main__":
    main()
|
|
|
|