# app.py for Gradio App on Hugging Face Spaces (with Sandbox Tab) import gradio as gr import torch from transformers import AutoModelForCausalLM, AutoTokenizer import random import time import json from typing import List, Dict, Tuple, Optional # --- Configuration --- MODEL_NAME = "Qwen/Qwen3-VL-4B-Instruct-FP8" # For Hugging Face Spaces, loading the model once here is common. # Performance on CPU will be slower. print("Loading model...") try: tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained( MODEL_NAME, torch_dtype=torch.float16, # Use float16 for FP8 model if supported, or bfloat16 device_map="auto", # Automatically map to available devices (CPU or GPU) trust_remote_code=True ).eval() # Set to evaluation mode print("Model loaded successfully.") except Exception as e: print(f"Failed to load model: {e}") model = None tokenizer = None # --- Simulated Cost Database --- COSTS = { "ask_question": 10.0, "physical_exam": 25.0, "order_cbc": 50.0, "order_xray": 150.0, "administer_med": 30.0, "end_case": 0.0, "start_case": 0.0 # Cost for starting is typically 0 } # --- State Management Class --- class MedicalSimulatorState: def __init__(self): self.patient_profile: Optional[Dict] = None self.chat_history: List[Tuple[Optional[str], Optional[str]]] = [] # Gradio chat format: [(user_msg, bot_msg), ...] self.vitals: Dict[str, float] = {"HR": 72.0, "BP_Sys": 120.0, "BP_Dia": 80.0, "Temp": 98.6, "O2_Sat": 98.0} self.total_cost: float = 0.0 self.is_case_active: bool = False self.underlying_diagnosis: str = "" self.ordered_tests: Dict[str, str] = {} # e.g., {"cbc": "pending", "xray": "result..."} # Add more state variables as needed # --- Core AI Interaction Function (Medical Context) --- def get_ai_response_medical(user_input: str, history: List[Tuple[Optional[str], Optional[str]]], patient_profile: Dict, underlying_diagnosis: str) -> str: if not model or not tokenizer: return "Error: AI model is not loaded." # Construct a prompt for the AI based on history and patient profile history_str = "\n".join([f"{'User' if h[0] else 'System/AI'}: {h[0] or h[1]}" for h in history]) context = f"Patient Profile: Name: {patient_profile['name']}, Age: {patient_profile['age']}, Gender: {patient_profile['gender']}, Chief Complaint: {patient_profile['chief_complaint']}, History: {patient_profile['history']}, Medications: {patient_profile['medications']}, Allergies: {patient_profile['allergies']}, Social History: {patient_profile['social_history']}, Financial Status: {patient_profile['financial_status']}, Code Status: {patient_profile['code_status']}\nCurrent Chat History:\n{history_str}\nUser Action/Question: {user_input}\n" prompt = f"<|system|>You are an AI patient in a medical simulation. Role-play as the patient described in the profile. Be consistent with their history, demographics, and potential complaints. Respond to the user's input (which could be a question, exam instruction, or order). Your responses should simulate realistic patient dialogue and reactions, including potential anxiety or concerns. Do not reveal the secret diagnosis '{underlying_diagnosis}' directly, but your responses should be consistent with having that condition. Respond as if you are the patient speaking.<|user|>{context}<|assistant|>" try: inputs = tokenizer(prompt, return_tensors="pt") if inputs["input_ids"].shape[1] > 32768: # Check for model max length return "Error: Input prompt is too long for the model." # Generate response generate_ids = model.generate( inputs.input_ids, max_new_tokens=512, # Limit generated tokens do_sample=True, temperature=0.7, top_p=0.9, pad_token_id=tokenizer.eos_token_id ) # Decode the generated response response_text = tokenizer.decode(generate_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) return response_text.strip() except Exception as e: print(f"Error during AI generation: {e}") return f"An error occurred while processing the AI response: {e}" # --- Core AI Interaction Function (General/Sandbox Context) --- def get_ai_response_general(user_input: str, history: List[Tuple[Optional[str], Optional[str]]]) -> str: if not model or not tokenizer: return "Error: AI model is not loaded." # Construct a prompt for the AI based on general chat history history_str = "\n".join([f"{'User' if h[0] else 'Assistant'}: {h[0] or h[1]}" for h in history]) prompt = f"<|system|>You are a helpful assistant. Answer the user's questions and follow their instructions.<|user|>{history_str}\n{user_input}<|assistant|>" try: inputs = tokenizer(prompt, return_tensors="pt") if inputs["input_ids"].shape[1] > 32768: # Check for model max length return "Error: Input prompt is too long for the model." # Generate response generate_ids = model.generate( inputs.input_ids, max_new_tokens=512, # Limit generated tokens do_sample=True, temperature=0.7, top_p=0.9, pad_token_id=tokenizer.eos_token_id ) # Decode the generated response response_text = tokenizer.decode(generate_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) return response_text.strip() except Exception as e: print(f"Error during AI generation: {e}") return f"An error occurred while processing the AI response: {e}" # --- Tool Functions (Modify State) --- def start_case(case_type: str, state: MedicalSimulatorState) -> Tuple[MedicalSimulatorState, List[Tuple[Optional[str], Optional[str]]], str, str]: # --- Generate Patient Profile (Simplified Example) --- names = ["John Smith", "Emily Johnson", "Michael Brown", "Sarah Davis"] chief_complaints = { "General": ["Chest pain", "Shortness of breath", "Abdominal pain"], "Pediatric": ["Fever", "Cough", "Ear ache"], "Psychiatry": ["Feeling anxious", "Difficulty sleeping", "Low mood"], "Dual Diagnosis": ["Chest pain and feels anxious", "Abdominal pain after drinking"] } complaint_options = chief_complaints.get(case_type, chief_complaints["General"]) name = random.choice(names) age = random.randint(18, 80) if case_type != "Pediatric" else random.randint(0, 17) gender = random.choice(["Male", "Female"]) chief_complaint = random.choice(complaint_options) # Define underlying diagnosis based on complaint or case type diag_map = { "Chest pain": "Acute Myocardial Infarction", "Shortness of breath": "Pneumonia", "Abdominal pain": "Appendicitis", "Fever": "Viral Infection", "Cough": "Bronchitis", "Ear ache": "Otitis Media", "Feeling anxious": "Generalized Anxiety Disorder", "Difficulty sleeping": "Insomnia", "Low mood": "Major Depressive Disorder", "Chest pain and feels anxious": "Acute MI with Anxiety", "Abdominal pain after drinking": "Alcoholic Gastritis with Substance Use Disorder" } underlying_diagnosis = diag_map.get(chief_complaint, "Unknown Condition") patient = { "name": name, "age": age, "gender": gender, "chief_complaint": chief_complaint, "history": "Patient history relevant to complaint.", "medications": "Current medications.", "allergies": "Known allergies (e.g., Penicillin).", "social_history": "Social history details.", "financial_status": "Insurance status.", "code_status": "Full Code", "language": "English" } # Reset state state.patient_profile = patient state.chat_history = [("System", "New Case Started."), ("AI Patient", f"Hi, I'm {patient['name']}. I've been having {patient['chief_complaint']}.")] state.vitals = {"HR": 72.0, "BP_Sys": 120.0, "BP_Dia": 80.0, "Temp": 98.6, "O2_Sat": 98.0} state.total_cost = 0.0 state.is_case_active = True state.underlying_diagnosis = underlying_diagnosis state.ordered_tests = {} profile_str = "\n".join([f"{k.replace('_', ' ').title()}: {v}" for k, v in patient.items()]) return state, state.chat_history, f"${state.total_cost:.2f}", profile_str def handle_chat(user_input: str, history: List[Tuple[Optional[str], Optional[str]]], state: MedicalSimulatorState) -> Tuple[List[Tuple[Optional[str], Optional[str]]], str]: if not state.is_case_active or not user_input.strip(): return history, f"${state.total_cost:.2f}" # Add user message to history history.append((user_input, None)) # Get AI response ai_response = get_ai_response_medical(user_input, history[:-1], state.patient_profile, state.underlying_diagnosis) # Pass history without the user's new message yet # Add AI response to history history[-1] = (user_input, ai_response) # Update the last entry with the AI's response return history, f"${state.total_cost:.2f}" # Cost doesn't change here, just return current def use_tool(tool_name: str, state: MedicalSimulatorState) -> Tuple[MedicalSimulatorState, List[Tuple[Optional[str], Optional[str]]], str]: if not state.is_case_active: return state, state.chat_history, f"${state.total_cost:.2f}" cost = COSTS.get(tool_name, 0.0) state.total_cost += cost if tool_name == "ask_question": ai_response = get_ai_response_medical("The user asks a general question to gather more history.", state.chat_history, state.patient_profile, state.underlying_diagnosis) state.chat_history.append(("System", f"[Action: {tool_name}, Cost: ${cost:.2f}]")) state.chat_history.append(("AI Patient", ai_response)) elif tool_name == "order_cbc": state.chat_history.append(("System", f"[Action: {tool_name}, Cost: ${cost:.2f}]")) state.chat_history.append(("Lab", "CBC Ordered. Result pending...")) # Simulate result appearing after a delay or another action # For now, add a simple result shortly after time.sleep(0.5) # Simulate processing delay state.chat_history.append(("Lab", "CBC Result: WBC slightly elevated, otherwise unremarkable.")) elif tool_name == "administer_med": med_name = "Medication X" # Simplified, could take input state.chat_history.append(("System", f"[Action: {tool_name} - {med_name}, Cost: ${cost:.2f}]")) # Check for allergies here in a real app state.chat_history.append(("AI Patient", f"Okay, I took the {med_name}.")) elif tool_name == "physical_exam": state.chat_history.append(("System", f"[Action: {tool_name}, Cost: ${cost:.2f}]")) state.chat_history.append(("System", "Physical Exam Performed. Findings noted.")) elif tool_name == "order_xray": state.chat_history.append(("System", f"[Action: {tool_name}, Cost: ${cost:.2f}]")) state.chat_history.append(("Imaging", "X-Ray Ordered. Result pending...")) # Placeholder for image result (could be a URL or base64 string) time.sleep(0.5) # Simulate processing delay state.chat_history.append(("Imaging", "Chest X-Ray Result: Normal lung fields, no acute findings. (Placeholder Image)")) # Add other tools as needed... return state, state.chat_history, f"${state.total_cost:.2f}" def end_case(state: MedicalSimulatorState) -> Tuple[MedicalSimulatorState, List[Tuple[Optional[str], Optional[str]]], str, str]: if not state.is_case_active: # Return current state if no case is active profile_str = "\n".join([f"{k.replace('_', ' ').title()}: {v}" for k, v in (state.patient_profile or {}).items()]) return state, state.chat_history, f"${state.total_cost:.2f}", profile_str state.chat_history.append(("System", "Case Ended by User.")) state.is_case_active = False # In a full implementation, trigger the evaluation logic here # evaluation = run_evaluation(state) # Placeholder # state.chat_history.append(("System", f"Evaluation: {evaluation}")) # Add evaluation to chat or separate component profile_str = "\n".join([f"{k.replace('_', ' ').title()}: {v}" for k, v in (state.patient_profile or {}).items()]) return state, state.chat_history, f"${state.total_cost:.2f}", profile_str # --- Sandbox Chat Handler --- def handle_sandbox_chat(user_input: str, history: List[Tuple[Optional[str], Optional[str]]]) -> List[Tuple[Optional[str], Optional[str]]]: if not user_input.strip(): return history # Add user message to history history.append((user_input, None)) # Get AI response (general context) ai_response = get_ai_response_general(user_input, history[:-1]) # Add AI response to history history[-1] = (user_input, ai_response) return history # --- Gradio Interface --- with gr.Blocks(title="Advanced Medical Simulator") as demo: gr.Markdown("# Advanced Medical Simulator") with gr.Tab("Medical Simulation"): # State component to hold the simulator state across interactions state = gr.State(lambda: MedicalSimulatorState()) with gr.Row(): with gr.Column(scale=2): # Chat Interface chatbot = gr.Chatbot(label="Patient Interaction", height=400, bubble_full_width=False) with gr.Row(): user_input = gr.Textbox(label="Your Action / Question", placeholder="Type your action or question here...", scale=4) submit_btn = gr.Button("Submit", scale=1) with gr.Column(scale=1): # Patient Chart / Info patient_chart = gr.Markdown(label="Patient Chart", value="Click 'Start New Case' to begin.") cost_display = gr.Textbox(label="Total Cost", value="$0.00", interactive=False) with gr.Row(): # Tool Panel with gr.Column(): gr.Markdown("### Tools") with gr.Row(): ask_btn = gr.Button("Ask Question ($10)") exam_btn = gr.Button("Physical Exam ($25)") with gr.Row(): cbc_btn = gr.Button("Order CBC ($50)") xray_btn = gr.Button("Order X-Ray ($150)") with gr.Row(): med_btn = gr.Button("Administer Med ($30)") end_btn = gr.Button("End Case", variant="stop") # Red button for ending with gr.Row(): # Case Controls start_case_btn = gr.Button("Start New Case (General)") case_type_dropdown = gr.Dropdown(["General", "Psychiatry", "Pediatric", "Dual Diagnosis"], label="Case Type", value="General") # Event Handling for Medical Simulation Tab start_case_btn.click( fn=start_case, inputs=[case_type_dropdown, state], outputs=[state, chatbot, cost_display, patient_chart] ) submit_btn.click( fn=handle_chat, inputs=[user_input, chatbot, state], outputs=[chatbot, cost_display] ).then( fn=lambda: "", # Clear the input textbox after submission inputs=[], outputs=[user_input] ) ask_btn.click(fn=lambda s: use_tool("ask_question", s), inputs=[state], outputs=[state, chatbot, cost_display]) exam_btn.click(fn=lambda s: use_tool("physical_exam", s), inputs=[state], outputs=[state, chatbot, cost_display]) cbc_btn.click(fn=lambda s: use_tool("order_cbc", s), inputs=[state], outputs=[state, chatbot, cost_display]) xray_btn.click(fn=lambda s: use_tool("order_xray", s), inputs=[state], outputs=[state, chatbot, cost_display]) med_btn.click(fn=lambda s: use_tool("administer_med", s), inputs=[state], outputs=[state, chatbot, cost_display]) end_btn.click(fn=end_case, inputs=[state], outputs=[state, chatbot, cost_display, patient_chart]) with gr.Tab("AI Sandbox"): sandbox_chatbot = gr.Chatbot(label="General AI Chat", height=500, bubble_full_width=False) with gr.Row(): sandbox_input = gr.Textbox(label="Message", placeholder="Ask anything...", scale=4) sandbox_submit = gr.Button("Send", scale=1) sandbox_submit.click( fn=handle_sandbox_chat, inputs=[sandbox_input, sandbox_chatbot], outputs=[sandbox_chatbot] ).then( fn=lambda: "", # Clear the input textbox after submission inputs=[], outputs=[sandbox_input] ) # Launch the app # For Hugging Face Spaces, Gradio handles the launch. demo.launch()