File size: 7,069 Bytes
a1e520f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
# Standard library imports
import os
import io
import asyncio
from pathlib import Path
from typing import Any, Dict, Optional
from contextlib import contextmanager
# Third-party imports
import gradio as gr
from fastapi import FastAPI, File, Form, UploadFile
import uvicorn
import json
# Local imports
from utils import validate_mime_type, format_results_to_mcp
from main import predict, app as fastapi_app
# A SPACE_ID environment variable is taken as the signal that we are running
# inside Hugging Face Spaces.
IS_HF_SPACE = "SPACE_ID" in os.environ
# Working directory captured once, at import time; stable_cwd() returns here.
ORIGINAL_CWD = os.getcwd()


@contextmanager
def stable_cwd():
    """Run the ``with`` body from ``ORIGINAL_CWD``, then switch back.

    The directory that was current on entry is remembered and restored on
    exit, whether the body finishes normally or raises.
    """
    entry_dir = os.getcwd()
    try:
        os.chdir(ORIGINAL_CWD)
        yield
    finally:
        # Always undo the chdir, even when the body raised.
        os.chdir(entry_dir)
# Get MCP context with proper path handling
def get_mcp_context_wrapper() -> Dict[str, Any]:
    """Return the MCP context metadata, or a fallback dict if loading fails.

    On Hugging Face Spaces (``IS_HF_SPACE``) the context is read directly from
    ``mcp_context.json`` inside ``ORIGINAL_CWD``. Elsewhere the lookup is
    delegated to ``utils.get_mcp_context``. Both paths run inside
    ``stable_cwd()``, so relative paths resolve against ``ORIGINAL_CWD``.

    Returns:
        The parsed context dict on success. If any ``Exception`` is raised,
        a dict with an ``"error"`` message and minimal ``"app"`` metadata is
        returned instead, so the UI always has something to display.
    """
    try:
        with stable_cwd():
            # In Spaces, we know the file is in the root directory
            if IS_HF_SPACE:
                context_path = Path(ORIGINAL_CWD) / 'mcp_context.json'
                if not context_path.exists():
                    raise FileNotFoundError(f"MCP context file not found at {context_path}")
                # Decode explicitly as UTF-8 (the JSON interchange encoding)
                # instead of relying on the platform's locale default.
                with open(context_path, 'r', encoding='utf-8') as f:
                    return json.load(f)
            # Otherwise use the utility function that searches multiple locations
            from utils import get_mcp_context
            return get_mcp_context()
    except Exception as e:
        return {
            "error": f"Failed to load MCP context: {str(e)}",
            "app": {
                "id": "lepanto1571.wrapper.posebusters",
                "name": "PoseBusters MCP wrapper",
                "version": "0.5"
            }
        }
def _upload_source_name(file_obj) -> Optional[str]:
    """Return the best available (possibly path-qualified) name for an upload, or ``None``.

    Checks a non-empty ``orig_name`` attribute first, then a non-empty ``name``
    attribute, then the value itself when it is a path (``str`` / ``os.PathLike``).
    """
    for attr in ("orig_name", "name"):
        value = getattr(file_obj, attr, None)
        if value:
            return str(value)
    if isinstance(file_obj, (str, os.PathLike)):
        return os.fsdecode(file_obj)
    return None


def _upload_local_path(file_obj) -> Optional[str]:
    """Return the filesystem path to open for a non-file-like upload, or ``None``.

    A ``str``/``os.PathLike`` value is itself the path (``str`` subclasses
    included). Otherwise a path-valued ``name`` attribute is used.
    """
    if isinstance(file_obj, (str, os.PathLike)):
        return os.fsdecode(file_obj)
    name = getattr(file_obj, "name", None)
    if isinstance(name, (str, os.PathLike)):
        return os.fsdecode(name)
    return None


async def process_files(ligand_file, protein_file, crystal_file=None):
    """Process uploaded files using the robust predict function from main.py.

    Each upload may be ``None``, a filesystem path (``str``, including str
    subclasses, or ``os.PathLike``), or a file-like object whose ``read()``
    (and optional ``seek()``) may be sync or async.

    Args:
        ligand_file: Ligand structure upload (required).
        protein_file: Protein structure upload (required).
        crystal_file: Optional crystal ligand. When provided, the action is
            ``"redocking_validation"``; otherwise ``"validate_pose"``.

    Returns:
        Whatever ``predict`` returns on success. If any ``Exception`` is
        raised, a one-row ``validation_results`` table whose last cell holds
        the error message.
    """
    # Every handle created below, so the finally clause can close them all,
    # on both the success and the error path.
    opened = []

    async def create_upload_file(file_obj) -> "Optional[UploadFile]":
        """Wrap one upload in an ``UploadFile``; ``None`` passes through."""
        if file_obj is None:
            return None
        try:
            source_name = _upload_source_name(file_obj) or str(file_obj)
            if hasattr(file_obj, 'read'):
                # File-like object (e.g. in Spaces or already open file)
                content = file_obj.read()
                if asyncio.iscoroutine(content):
                    content = await content
                if hasattr(file_obj, 'seek'):
                    # Reset file pointer for future reads; seek() may itself
                    # be async, in which case it must be awaited to take effect.
                    rewound = file_obj.seek(0)
                    if asyncio.iscoroutine(rewound):
                        await rewound
                handle = io.BytesIO(
                    content if isinstance(content, (bytes, bytearray)) else content.encode('utf-8')
                )
            else:
                path = _upload_local_path(file_obj)
                if path is None or not os.path.exists(path):
                    raise ValueError(f"Unable to read file: {source_name}")
                handle = open(path, "rb")
            opened.append(handle)
            return UploadFile(file=handle, filename=Path(source_name).name)
        except Exception as e:
            raise ValueError(
                f"Failed to process file {getattr(file_obj, 'name', str(file_obj))}: {str(e)}"
            ) from e

    try:
        # NOTE(review): os.chdir is process-wide, so while this coroutine is
        # suspended at an await, other requests also run from ORIGINAL_CWD.
        with stable_cwd():
            # Convert files to UploadFile objects
            ligand_input = await create_upload_file(ligand_file)
            protein_input = await create_upload_file(protein_file)
            crystal_input = await create_upload_file(crystal_file)
            if ligand_input is None or protein_input is None:
                raise ValueError("Both ligand and protein files are required")
            # Use the action based on whether crystal file is provided
            action = "redocking_validation" if crystal_input is not None else "validate_pose"
            return await predict(
                action=action,
                ligand_input=ligand_input,
                protein_input=protein_input,
                crystal_input=crystal_input
            )
    except Exception as e:
        ligand_name = _upload_source_name(ligand_file)
        return {
            "object_id": "validation_results",
            "data": {
                "columns": ["ligand_id", "status", "passed/total", "details"],
                "rows": [[
                    Path(ligand_name).stem if ligand_name else 'unknown',
                    "❌",
                    "0/0",
                    str(e)
                ]]
            }
        }
    finally:
        for handle in opened:
            handle.close()
theme = gr.themes.Base()


def _single_file_input(label):
    """Build a one-file upload widget that hands the handler a filepath."""
    return gr.File(label=label, type="filepath", file_count="single")


# Validation tab: ligand, protein and optional crystal-ligand uploads, in that
# order, matching process_files(ligand_file, protein_file, crystal_file).
validation_ui = gr.Interface(
    fn=process_files,
    inputs=[
        _single_file_input(label)
        for label in (
            "Ligand (.sdf)",
            "Protein (.pdb)",
            "Crystal Ligand (.sdf, optional)",
        )
    ],
    outputs="json",
    title="PoseBusters Validation",
    description="Upload files to validate ligand–protein structures using PoseBusters.",
    theme=theme,
    cache_examples=False,  # Disable caching to prevent file handle issues
)

# Input-less tab that shows whatever get_mcp_context_wrapper() returns.
context_ui = gr.Interface(
    fn=get_mcp_context_wrapper,
    inputs=None,
    outputs=gr.JSON(),
    title="MCP Context",
    description="View the MCP context metadata that defines this tool's capabilities.",
    theme=theme,
    cache_examples=False,  # Disable caching to prevent file handle issues
)

# Single app exposing both interfaces as tabs.
demo = gr.TabbedInterface(
    interface_list=[validation_ui, context_ui],
    tab_names=["💊 Validation", "🔍 MCP Context"],
    title="PoseBusters MCP Wrapper",
    theme=theme,
)
def main():
    """Start serving the UI.

    On Hugging Face Spaces, Gradio's own launcher is used. Everywhere else,
    the Gradio app is mounted at "/" on the FastAPI app imported from
    main.py, and that app is served with uvicorn on port 7860.
    """
    if not IS_HF_SPACE:
        # For local/Docker, mount Gradio into FastAPI and run the server.
        mounted_app = gr.mount_gradio_app(fastapi_app, demo, path="/")
        uvicorn.run(mounted_app, host="0.0.0.0", port=7860)
        return
    # In Hugging Face Spaces, just run the Gradio app on all interfaces.
    demo.launch(server_name="0.0.0.0")


if __name__ == "__main__":
    main()
|