lepanto1571's picture
Upload 17 files
a1e520f verified
raw
history blame
7.07 kB
# Standard library imports
import os
import io
import asyncio
from pathlib import Path
from typing import Any, Dict, Optional
from contextlib import contextmanager
# Third-party imports
import gradio as gr
from fastapi import FastAPI, File, Form, UploadFile
import uvicorn
import json
# Local imports
from utils import validate_mime_type, format_results_to_mcp
from main import predict, app as fastapi_app
# Running inside a Hugging Face Space is detected by the presence of SPACE_ID
IS_HF_SPACE = "SPACE_ID" in os.environ
# Working directory at import time, remembered so it can be restored later
ORIGINAL_CWD = os.getcwd()
@contextmanager
def stable_cwd():
    """Run the enclosed block from ORIGINAL_CWD, then return to the caller's directory.

    The directory that was current on entry is restored on exit, whether the
    block finishes normally or raises.
    """
    caller_dir = os.getcwd()
    try:
        # Switching happens inside the try so the caller's directory is
        # restored even if the switch itself fails.
        os.chdir(ORIGINAL_CWD)
        yield
    finally:
        os.chdir(caller_dir)
# Get MCP context with proper path handling
def get_mcp_context_wrapper() -> Dict[str, Any]:
    """Return the MCP context metadata, or a fallback payload if loading fails.

    When running in a Hugging Face Space, ``mcp_context.json`` is read from
    the directory the process was started in. Otherwise the lookup is
    delegated to ``utils.get_mcp_context``.

    Returns:
        The parsed context on success. On any failure, a dict holding an
        ``"error"`` message plus a minimal ``"app"`` descriptor, so the caller
        always receives JSON-serialisable data instead of an exception.
    """
    try:
        with stable_cwd():
            # In Spaces, we know the file is in the root directory
            if IS_HF_SPACE:
                context_path = Path(ORIGINAL_CWD) / "mcp_context.json"
                if not context_path.exists():
                    raise FileNotFoundError(f"MCP context file not found at {context_path}")
                # JSON is UTF-8 by specification; do not depend on the
                # locale's default encoding when decoding the file.
                with open(context_path, "r", encoding="utf-8") as f:
                    return json.load(f)
            # Otherwise use the utility function that searches multiple locations
            from utils import get_mcp_context
            return get_mcp_context()
    except Exception as e:
        # Deliberate catch-all: this feeds a UI panel, which should show the
        # failure reason rather than crash.
        return {
            "error": f"Failed to load MCP context: {e}",
            "app": {
                "id": "lepanto1571.wrapper.posebusters",
                "name": "PoseBusters MCP wrapper",
                "version": "0.5",
            },
        }
def _upload_source_name(file_obj: Any) -> Optional[str]:
    """Return the name or path carried by an upload, or None if it has none."""
    if isinstance(file_obj, (str, os.PathLike)):
        return str(file_obj) or None
    # Prefer the original client-side name over the temp-file name.
    for attr in ("orig_name", "name"):
        candidate = getattr(file_obj, attr, None)
        if isinstance(candidate, (str, os.PathLike)) and str(candidate):
            return str(candidate)
    return None


async def _to_upload_file(file_obj: Any) -> "Optional[UploadFile]":
    """Convert a path, path-carrying object or file-like object to an UploadFile.

    The content is read eagerly into memory so that no file handle opened
    here outlives this call.

    Returns:
        None when ``file_obj`` is None, otherwise an in-memory UploadFile.

    Raises:
        ValueError: if the input cannot be read; the cause is chained.
    """
    if file_obj is None:
        return None
    try:
        filename = _upload_source_name(file_obj) or str(file_obj)
        if hasattr(file_obj, "read"):
            # File-like object (e.g. in Spaces or already open file)
            content = file_obj.read()
            if asyncio.iscoroutine(content):
                content = await content
            if hasattr(file_obj, "seek"):
                rewound = file_obj.seek(0)  # Reset file pointer for future reads
                if asyncio.iscoroutine(rewound):
                    await rewound
            if isinstance(content, str):
                content = content.encode("utf-8")
        else:
            # Either the object is itself a path, or it carries one in `.name`.
            if isinstance(file_obj, (str, os.PathLike)):
                source = file_obj
            else:
                source = getattr(file_obj, "name", None)
            if not isinstance(source, (str, os.PathLike)) or not os.path.exists(source):
                raise ValueError(f"Unable to read file: {filename}")
            with open(source, "rb") as handle:
                content = handle.read()
        return UploadFile(file=io.BytesIO(content), filename=Path(filename).name)
    except Exception as e:
        raise ValueError(
            f"Failed to process file {getattr(file_obj, 'name', str(file_obj))}: {e}"
        ) from e


async def process_files(ligand_file, protein_file, crystal_file=None):
    """Process uploaded files using the robust predict function from main.py.

    Args:
        ligand_file: Ligand upload (path, path-carrying object or file-like).
        protein_file: Protein upload, same accepted forms.
        crystal_file: Optional crystal ligand upload. When present the
            "redocking_validation" action is used, otherwise "validate_pose".

    Returns:
        Whatever ``predict`` returns on success. On any failure, a
        "validation_results" table with a single failed row whose details
        column holds the error message; this function does not raise.
    """
    try:
        # NOTE: the working directory is process-wide, and it stays switched
        # while the awaits below are pending.
        with stable_cwd():
            # Convert files to UploadFile objects
            ligand_input = await _to_upload_file(ligand_file)
            protein_input = await _to_upload_file(protein_file)
            crystal_input = await _to_upload_file(crystal_file)
            if ligand_input is None or protein_input is None:
                raise ValueError("Both ligand and protein files are required")
            # Use the action based on whether crystal file is provided
            action = "redocking_validation" if crystal_input is not None else "validate_pose"
            # Call the robust predict function
            return await predict(
                action=action,
                ligand_input=ligand_input,
                protein_input=protein_input,
                crystal_input=crystal_input,
            )
    except Exception as e:
        ligand_id = Path(_upload_source_name(ligand_file) or "unknown").stem
        return {
            "object_id": "validation_results",
            "data": {
                "columns": ["ligand_id", "status", "passed/total", "details"],
                "rows": [[ligand_id, "❌", "0/0", str(e)]],
            },
        }
# Base theme shared by the interfaces defined in this module
theme = gr.themes.Base()
# Validation tab: ligand and protein uploads plus an optional crystal ligand
validation_ui = gr.Interface(
    title="PoseBusters Validation",
    description="Upload files to validate ligand–protein structures using PoseBusters.",
    fn=process_files,
    inputs=[
        # Same three single-file inputs, in the positional order process_files expects
        gr.File(label=upload_label, type="filepath", file_count="single")
        for upload_label in (
            "Ligand (.sdf)",
            "Protein (.pdb)",
            "Crystal Ligand (.sdf, optional)",
        )
    ],
    outputs="json",
    theme=theme,
    cache_examples=False,  # Disable caching to prevent file handle issues
)
# Read-only tab that shows the MCP context as JSON; it takes no inputs
context_ui = gr.Interface(
    title="MCP Context",
    description="View the MCP context metadata that defines this tool's capabilities.",
    fn=get_mcp_context_wrapper,
    inputs=None,
    outputs=gr.JSON(),
    theme=theme,
    cache_examples=False,  # Disable caching to prevent file handle issues
)
# Top-level app: both interfaces presented as tabs, in the order listed
demo = gr.TabbedInterface(
    title="PoseBusters MCP Wrapper",
    theme=theme,
    interface_list=[validation_ui, context_ui],
    tab_names=["💊 Validation", "🔍 MCP Context"],
)
def main():
    """Start the server in the mode that matches the current environment."""
    if not IS_HF_SPACE:
        # Local/Docker: serve the Gradio UI from the root of the FastAPI app
        combined_app = gr.mount_gradio_app(fastapi_app, demo, path="/")
        uvicorn.run(combined_app, host="0.0.0.0", port=7860)
        return
    # Hugging Face Spaces: the Gradio app is launched on its own
    demo.launch(server_name="0.0.0.0")


if __name__ == "__main__":
    main()