Spaces:
Sleeping
Sleeping
| # app.py for Gradio App on Hugging Face Spaces (with Sandbox Tab) | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import random | |
| import time | |
| import json | |
| from typing import List, Dict, Tuple, Optional | |
# --- Configuration ---
MODEL_NAME = "Qwen/Qwen3-VL-4B-Instruct-FP8"


def _load_model_and_tokenizer(model_name):
    """Load and return ``(model, tokenizer)`` for ``model_name``.

    Called once at import time so every request reuses the same weights
    (generation will be slow on CPU-only hardware). If any step raises, the
    error is printed and ``(None, None)`` is returned; the AI helpers then
    answer with "Error: AI model is not loaded." instead of generating.
    """
    print("Loading model...")
    try:
        loaded_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        # NOTE(review): the checkpoint name suggests a vision-language FP8
        # model; confirm AutoModelForCausalLM and float16 are appropriate for it
        # in the installed transformers version, otherwise this always fails.
        loaded_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",  # automatic placement on the available CPU/GPU
            trust_remote_code=True,
        ).eval()  # evaluation mode for inference
    except Exception as e:
        print(f"Failed to load model: {e}")
        return None, None
    print("Model loaded successfully.")
    return loaded_model, loaded_tokenizer


model, tokenizer = _load_model_and_tokenizer(MODEL_NAME)
# --- Simulated Cost Database ---
# Price (shown in dollars) added to the running total each time use_tool runs
# the named action. Starting and ending a case are free.
COSTS = dict(
    ask_question=10.0,
    physical_exam=25.0,
    order_cbc=50.0,
    order_xray=150.0,
    administer_med=30.0,
    end_case=0.0,
    start_case=0.0,
)
# --- State Management Class ---
class MedicalSimulatorState:
    """Per-session container for one simulated patient case.

    A new instance has no patient, an empty chat log, baseline vitals, zero
    accumulated cost and no active case. Every mutable field is created fresh
    per instance, so two sessions never share containers.
    """

    def __init__(self):
        # Patient profile dict; None until a case has been started.
        self.patient_profile: Optional[Dict] = None
        # Chat log as (left, right) pairs, the tuple format gr.Chatbot displays.
        self.chat_history: List[Tuple[Optional[str], Optional[str]]] = []
        # Baseline vital signs.
        self.vitals: Dict[str, float] = dict(HR=72.0, BP_Sys=120.0, BP_Dia=80.0, Temp=98.6, O2_Sat=98.0)
        # Running total of tool costs for the current case.
        self.total_cost: float = 0.0
        self.is_case_active: bool = False
        # Hidden diagnosis the AI patient is asked to stay consistent with.
        self.underlying_diagnosis: str = ""
        # Test name -> status/result text, e.g. {"cbc": "pending"}.
        self.ordered_tests: Dict[str, str] = {}
        # Further state variables can be added here as the simulator grows.
# --- Core AI Interaction Functions (Medical and General/Sandbox Contexts) ---
# Prompts longer than this many tokens are refused rather than generated from.
_MAX_PROMPT_TOKENS = 32768
# Speaker labels that start_case/use_tool place in the first slot of chat-log
# pairs; for those pairs the second slot holds the message text.
_SPEAKER_LABELS = frozenset({"System", "AI Patient", "Lab", "Imaging"})


def _format_history(history, assistant_label, speaker_labels=frozenset()):
    """Render chat pairs as 'Speaker: text' lines, keeping BOTH sides of a pair.

    A pair whose first slot is one of ``speaker_labels`` is a (speaker, text)
    log entry. Any other pair is a (user_message, reply) exchange; either side
    may be None/empty and is then skipped.
    """
    lines = []
    for entry in history or []:
        first, second = entry[0], entry[1]
        if isinstance(first, str) and first in speaker_labels and second is not None:
            lines.append(f"{first}: {second}")
            continue
        if first:
            lines.append(f"User: {first}")
        if second:
            lines.append(f"{assistant_label}: {second}")
    return "\n".join(lines)


def _history_to_messages(history):
    """Convert (user_message, reply) pairs into chat-template messages."""
    messages = []
    for entry in history or []:
        user_msg, reply = entry[0], entry[1]
        if user_msg:
            messages.append({"role": "user", "content": str(user_msg)})
        if reply:
            messages.append({"role": "assistant", "content": str(reply)})
    return messages


def _generate_reply(system_text, messages, fallback_prompt):
    """Run one sampled generation and return the decoded reply text.

    When the tokenizer defines a chat template, the prompt is built from
    ``system_text`` + ``messages`` with it (the model's native turn markers);
    otherwise ``fallback_prompt`` is sent verbatim. Any failure is printed and
    returned as an error string, so callers always receive text.
    """
    try:
        templated = bool(getattr(tokenizer, "chat_template", None))
        if templated:
            prompt = tokenizer.apply_chat_template(
                [{"role": "system", "content": system_text}, *messages],
                tokenize=False,
                add_generation_prompt=True,
            )
        else:
            prompt = fallback_prompt
        # A templated prompt already contains the model's special tokens.
        inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=not templated)
        prompt_len = inputs["input_ids"].shape[1]
        if prompt_len > _MAX_PROMPT_TOKENS:
            return "Error: Input prompt is too long for the model."
        # The model is loaded with device_map="auto" and may sit on a GPU;
        # the input tensors must live on the same device.
        inputs = inputs.to(model.device)
        generate_ids = model.generate(
            inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            max_new_tokens=512,  # limit generated tokens
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )
        # The output echoes the prompt tokens first; decode only the new ones.
        response_text = tokenizer.decode(generate_ids[0][prompt_len:], skip_special_tokens=True)
        return response_text.strip()
    except Exception as e:  # UI boundary: report the failure instead of crashing the event
        print(f"Error during AI generation: {e}")
        return f"An error occurred while processing the AI response: {e}"


def get_ai_response_medical(user_input: str, history: List[Tuple[Optional[str], Optional[str]]], patient_profile: Dict, underlying_diagnosis: str) -> str:
    """Generate the simulated patient's reply to a clinician action or question.

    Args:
        user_input: The clinician's action/question text.
        history: Prior chat pairs, either (speaker_label, text) log entries or
            (user_message, reply) exchanges.
        patient_profile: Profile dict as built by start_case (name, age, gender,
            chief_complaint, history, medications, allergies, social_history,
            financial_status, code_status).
        underlying_diagnosis: Hidden diagnosis the patient must stay consistent
            with without revealing it.

    Returns:
        The reply text, or an error message if the model is unavailable or
        generation fails.
    """
    if model is None or tokenizer is None:
        return "Error: AI model is not loaded."
    history_str = _format_history(history, "AI Patient", _SPEAKER_LABELS)
    p = patient_profile
    context = (
        f"Patient Profile: Name: {p['name']}, Age: {p['age']}, Gender: {p['gender']}, "
        f"Chief Complaint: {p['chief_complaint']}, History: {p['history']}, "
        f"Medications: {p['medications']}, Allergies: {p['allergies']}, "
        f"Social History: {p['social_history']}, Financial Status: {p['financial_status']}, "
        f"Code Status: {p['code_status']}\n"
        f"Current Chat History:\n{history_str}\n"
        f"User Action/Question: {user_input}\n"
    )
    system_text = (
        "You are an AI patient in a medical simulation. Role-play as the patient described in the profile. "
        "Be consistent with their history, demographics, and potential complaints. "
        "Respond to the user's input (which could be a question, exam instruction, or order). "
        "Your responses should simulate realistic patient dialogue and reactions, including potential anxiety or concerns. "
        f"Do not reveal the secret diagnosis '{underlying_diagnosis}' directly, "
        "but your responses should be consistent with having that condition. "
        "Respond as if you are the patient speaking."
    )
    fallback_prompt = f"<|system|>{system_text}<|user|>{context}<|assistant|>"
    return _generate_reply(system_text, [{"role": "user", "content": context}], fallback_prompt)


def get_ai_response_general(user_input: str, history: List[Tuple[Optional[str], Optional[str]]]) -> str:
    """Generate a general-assistant reply for the sandbox chat.

    Args:
        user_input: The new user message.
        history: Prior (user_message, reply) pairs.

    Returns:
        The reply text, or an error message if the model is unavailable or
        generation fails.
    """
    if model is None or tokenizer is None:
        return "Error: AI model is not loaded."
    system_text = "You are a helpful assistant. Answer the user's questions and follow their instructions."
    history_str = _format_history(history, "Assistant")
    fallback_prompt = f"<|system|>{system_text}<|user|>{history_str}\n{user_input}<|assistant|>"
    # With a chat template, earlier turns are sent as real user/assistant turns.
    messages = _history_to_messages(history) + [{"role": "user", "content": user_input}]
    return _generate_reply(system_text, messages, fallback_prompt)
# --- Tool Functions (Modify State) ---
def start_case(case_type: str, state: "MedicalSimulatorState") -> Tuple["MedicalSimulatorState", List[Tuple[Optional[str], Optional[str]]], str, str]:
    """Generate a random patient for ``case_type`` and reset ``state`` for a new case.

    Args:
        case_type: "General", "Pediatric", "Psychiatry" or "Dual Diagnosis";
            any other value falls back to the "General" complaint list.
        state: Session state; reset in place.

    Returns:
        (state, chat history, formatted total cost, patient chart as a
        Markdown bullet list).
    """
    # --- Generate Patient Profile (Simplified Example) ---
    names = ["John Smith", "Emily Johnson", "Michael Brown", "Sarah Davis"]
    chief_complaints = {
        "General": ["Chest pain", "Shortness of breath", "Abdominal pain"],
        "Pediatric": ["Fever", "Cough", "Ear ache"],
        "Psychiatry": ["Feeling anxious", "Difficulty sleeping", "Low mood"],
        "Dual Diagnosis": ["Chest pain and feels anxious", "Abdominal pain after drinking"],
    }
    complaint_options = chief_complaints.get(case_type, chief_complaints["General"])
    name = random.choice(names)
    # Pediatric cases get a patient under 18; every other case type an adult.
    age = random.randint(18, 80) if case_type != "Pediatric" else random.randint(0, 17)
    gender = random.choice(["Male", "Female"])
    chief_complaint = random.choice(complaint_options)

    # Hidden diagnosis implied by the chosen complaint.
    diag_map = {
        "Chest pain": "Acute Myocardial Infarction",
        "Shortness of breath": "Pneumonia",
        "Abdominal pain": "Appendicitis",
        "Fever": "Viral Infection",
        "Cough": "Bronchitis",
        "Ear ache": "Otitis Media",
        "Feeling anxious": "Generalized Anxiety Disorder",
        "Difficulty sleeping": "Insomnia",
        "Low mood": "Major Depressive Disorder",
        "Chest pain and feels anxious": "Acute MI with Anxiety",
        "Abdominal pain after drinking": "Alcoholic Gastritis with Substance Use Disorder",
    }
    underlying_diagnosis = diag_map.get(chief_complaint, "Unknown Condition")

    patient = {
        "name": name,
        "age": age,
        "gender": gender,
        "chief_complaint": chief_complaint,
        "history": "Patient history relevant to complaint.",
        "medications": "Current medications.",
        "allergies": "Known allergies (e.g., Penicillin).",
        "social_history": "Social history details.",
        "financial_status": "Insurance status.",
        "code_status": "Full Code",
        "language": "English",
    }

    # Reset state for the new case.
    state.patient_profile = patient
    state.chat_history = [("System", "New Case Started."), ("AI Patient", f"Hi, I'm {patient['name']}. I've been having {patient['chief_complaint']}.")]
    state.vitals = {"HR": 72.0, "BP_Sys": 120.0, "BP_Dia": 80.0, "Temp": 98.6, "O2_Sat": 98.0}
    state.total_cost = 0.0
    state.is_case_active = True
    state.underlying_diagnosis = underlying_diagnosis
    state.ordered_tests = {}

    # One bullet per field: the chart is a gr.Markdown, where fields separated
    # by single newlines would collapse into one run-on paragraph.
    profile_str = "\n".join(f"- **{key.replace('_', ' ').title()}:** {value}" for key, value in patient.items())
    return state, state.chat_history, f"${state.total_cost:.2f}", profile_str
def handle_chat(user_input: str, history: List[Tuple[Optional[str], Optional[str]]], state: "MedicalSimulatorState") -> Tuple[List[Tuple[Optional[str], Optional[str]]], str]:
    """Send a typed action/question to the AI patient and log the exchange.

    Args:
        user_input: Text from the action/question box.
        history: Current chatbot value as chat pairs; may be None.
        state: Session state. Its ``chat_history`` is replaced with the updated
            log in place, because ``state`` is not one of this handler's outputs.

    Returns:
        (updated chat history, formatted total cost). With no active case or a
        blank message, ``history`` is returned unchanged.
    """
    # Typing is free, so the cost is only reported, never changed.
    cost_str = f"${state.total_cost:.2f}"
    if not state.is_case_active or not user_input or not user_input.strip():
        return history, cost_str
    log = list(history or [])
    # The model gets the prior turns as context and the new message separately.
    ai_response = get_ai_response_medical(user_input, log[:], state.patient_profile, state.underlying_diagnosis)
    log.append((user_input, ai_response))
    # Tool buttons render state.chat_history. Without this sync every typed
    # exchange would vanish on the next tool click and be missing from the
    # AI's context for tool-driven prompts.
    state.chat_history = log
    return log, cost_str
def use_tool(tool_name: str, state: "MedicalSimulatorState", costs: Optional[Dict[str, float]] = None) -> Tuple["MedicalSimulatorState", List[Tuple[Optional[str], Optional[str]]], str]:
    """Apply one tool action to the active case, charge for it, and log the outcome.

    Args:
        tool_name: "ask_question", "physical_exam", "order_cbc", "order_xray"
            or "administer_med". Other names are still charged from the price
            table (0.0 if absent) but add nothing to the chat log.
        state: Session state; mutated in place (cost, chat log, ordered tests).
        costs: Price table to charge from; defaults to the module-level COSTS.

    Returns:
        (state, state.chat_history, formatted total cost). Without an active
        case nothing is charged or logged.
    """
    if not state.is_case_active:
        return state, state.chat_history, f"${state.total_cost:.2f}"
    price_table = COSTS if costs is None else costs
    cost = price_table.get(tool_name, 0.0)
    state.total_cost += cost
    log = state.chat_history
    action_entry = ("System", f"[Action: {tool_name}, Cost: ${cost:.2f}]")

    if tool_name == "ask_question":
        # Ask the AI patient using the log as it stood before this action.
        ai_response = get_ai_response_medical("The user asks a general question to gather more history.", state.chat_history, state.patient_profile, state.underlying_diagnosis)
        log.append(action_entry)
        log.append(("AI Patient", ai_response))
    elif tool_name == "order_cbc":
        log.append(action_entry)
        log.append(("Lab", "CBC Ordered. Result pending..."))
        time.sleep(0.5)  # simulated lab turnaround
        cbc_result = "WBC slightly elevated, otherwise unremarkable."
        log.append(("Lab", f"CBC Result: {cbc_result}"))
        state.ordered_tests["cbc"] = cbc_result
    elif tool_name == "administer_med":
        med_name = "Medication X"  # simplified; could become an input
        log.append(("System", f"[Action: {tool_name} - {med_name}, Cost: ${cost:.2f}]"))
        # TODO: check the patient's allergies before administering.
        log.append(("AI Patient", f"Okay, I took the {med_name}."))
    elif tool_name == "physical_exam":
        log.append(action_entry)
        log.append(("System", "Physical Exam Performed. Findings noted."))
    elif tool_name == "order_xray":
        log.append(action_entry)
        log.append(("Imaging", "X-Ray Ordered. Result pending..."))
        time.sleep(0.5)  # simulated imaging turnaround
        xray_result = "Normal lung fields, no acute findings."
        # Placeholder for an image result (could be a URL or base64 string).
        log.append(("Imaging", f"Chest X-Ray Result: {xray_result} (Placeholder Image)"))
        state.ordered_tests["xray"] = xray_result
    # Add other tools as needed...
    return state, state.chat_history, f"${state.total_cost:.2f}"
def end_case(state: "MedicalSimulatorState") -> Tuple["MedicalSimulatorState", List[Tuple[Optional[str], Optional[str]]], str, str]:
    """End the active case, if any, and return refreshed values for the UI.

    Ending an already-ended case appends nothing. The chart is re-rendered from
    ``state.patient_profile`` as a Markdown bullet list. If no case was ever
    started, the "Click 'Start New Case' to begin." prompt is returned instead
    of an empty chart.

    Returns:
        (state, chat history, formatted total cost, patient chart Markdown).
    """
    if state.is_case_active:
        state.chat_history.append(("System", "Case Ended by User."))
        state.is_case_active = False
        # TODO: run the case evaluation here and surface it in the chat or a
        # separate component, e.g. evaluation = run_evaluation(state).
    profile = state.patient_profile or {}
    if not profile:
        profile_str = "Click 'Start New Case' to begin."
    else:
        # One bullet per field so gr.Markdown shows each field on its own line.
        profile_str = "\n".join(f"- **{key.replace('_', ' ').title()}:** {value}" for key, value in profile.items())
    return state, state.chat_history, f"${state.total_cost:.2f}", profile_str
# --- Sandbox Chat Handler ---
def handle_sandbox_chat(user_input: str, history: List[Tuple[Optional[str], Optional[str]]]) -> List[Tuple[Optional[str], Optional[str]]]:
    """Answer a sandbox message with the general assistant and append the exchange.

    Args:
        user_input: The new message; blank or None input is ignored.
        history: Current chatbot value as (user_message, reply) pairs; may be None.

    Returns:
        A new history list with ``(user_input, reply)`` appended, or the
        original ``history`` unchanged when the message is blank.
    """
    if not user_input or not user_input.strip():
        return history
    turns = list(history or [])
    # The model sees the prior turns; the new message is passed separately.
    reply = get_ai_response_general(user_input, turns[:])
    turns.append((user_input, reply))
    return turns
# --- Gradio Interface ---
with gr.Blocks(title="Advanced Medical Simulator") as demo:
    gr.Markdown("# Advanced Medical Simulator")

    with gr.Tab("Medical Simulation"):
        # Per-session simulator state; a factory is passed so each session
        # gets its own MedicalSimulatorState instance.
        state = gr.State(lambda: MedicalSimulatorState())

        with gr.Row():
            with gr.Column(scale=2):
                # Chat interface
                chatbot = gr.Chatbot(label="Patient Interaction", height=400, bubble_full_width=False)
                with gr.Row():
                    user_input = gr.Textbox(label="Your Action / Question", placeholder="Type your action or question here...", scale=4)
                    submit_btn = gr.Button("Submit", scale=1)
            with gr.Column(scale=1):
                # Patient chart / info
                patient_chart = gr.Markdown(label="Patient Chart", value="Click 'Start New Case' to begin.")
                cost_display = gr.Textbox(label="Total Cost", value="$0.00", interactive=False)
                with gr.Row():
                    # Tool panel: prices come from COSTS so labels match what is charged.
                    with gr.Column():
                        gr.Markdown("### Tools")
                        with gr.Row():
                            ask_btn = gr.Button(f"Ask Question (${COSTS['ask_question']:g})")
                            exam_btn = gr.Button(f"Physical Exam (${COSTS['physical_exam']:g})")
                        with gr.Row():
                            cbc_btn = gr.Button(f"Order CBC (${COSTS['order_cbc']:g})")
                            xray_btn = gr.Button(f"Order X-Ray (${COSTS['order_xray']:g})")
                        with gr.Row():
                            med_btn = gr.Button(f"Administer Med (${COSTS['administer_med']:g})")
                            end_btn = gr.Button("End Case", variant="stop")  # red button for ending
                with gr.Row():
                    # Case controls; the case type comes from the dropdown.
                    start_case_btn = gr.Button("Start New Case")
                    case_type_dropdown = gr.Dropdown(["General", "Psychiatry", "Pediatric", "Dual Diagnosis"], label="Case Type", value="General")

        # Event handling for the Medical Simulation tab
        start_case_btn.click(
            fn=start_case,
            inputs=[case_type_dropdown, state],
            outputs=[state, chatbot, cost_display, patient_chart],
        )
        # Both the Submit button and pressing Enter in the textbox send the
        # message; the textbox is cleared afterwards.
        for send_trigger in (submit_btn.click, user_input.submit):
            send_trigger(
                fn=handle_chat,
                inputs=[user_input, chatbot, state],
                outputs=[chatbot, cost_display],
            ).then(fn=lambda: "", inputs=[], outputs=[user_input])
        ask_btn.click(fn=lambda s: use_tool("ask_question", s), inputs=[state], outputs=[state, chatbot, cost_display])
        exam_btn.click(fn=lambda s: use_tool("physical_exam", s), inputs=[state], outputs=[state, chatbot, cost_display])
        cbc_btn.click(fn=lambda s: use_tool("order_cbc", s), inputs=[state], outputs=[state, chatbot, cost_display])
        xray_btn.click(fn=lambda s: use_tool("order_xray", s), inputs=[state], outputs=[state, chatbot, cost_display])
        med_btn.click(fn=lambda s: use_tool("administer_med", s), inputs=[state], outputs=[state, chatbot, cost_display])
        end_btn.click(fn=end_case, inputs=[state], outputs=[state, chatbot, cost_display, patient_chart])

    with gr.Tab("AI Sandbox"):
        sandbox_chatbot = gr.Chatbot(label="General AI Chat", height=500, bubble_full_width=False)
        with gr.Row():
            sandbox_input = gr.Textbox(label="Message", placeholder="Ask anything...", scale=4)
            sandbox_submit = gr.Button("Send", scale=1)
        # Send on button click or Enter, then clear the textbox.
        for send_trigger in (sandbox_submit.click, sandbox_input.submit):
            send_trigger(
                fn=handle_sandbox_chat,
                inputs=[sandbox_input, sandbox_chatbot],
                outputs=[sandbox_chatbot],
            ).then(fn=lambda: "", inputs=[], outputs=[sandbox_input])
# Launch the app. The server starts only when this file is executed as a
# script (assumed to be how the Space runs app.py — confirm for this Space);
# importing the module, e.g. from tests, defines `demo` without launching it.
if __name__ == "__main__":
    demo.launch()