# app.py for Gradio App on Hugging Face Spaces (with Sandbox Tab)
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import random
import time
import json
from typing import List, Dict, Tuple, Optional
# --- Configuration ---
MODEL_NAME = "Qwen/Qwen3-VL-4B-Instruct-FP8"
# For Hugging Face Spaces, loading the model once at import time is standard.
# Note: this is an FP8 vision-language checkpoint aimed at GPU inference; on a
# CPU-only Space it will be slow, and a smaller text-only model may be a safer
# choice (see the fallback note below the load block).
print("Loading model...")
try:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
        torch_dtype=torch.float16,  # half precision for activations; bfloat16 also works on supported hardware (FP8 weights need compatible GPU kernels)
device_map="auto", # Automatically map to available devices (CPU or GPU)
trust_remote_code=True
).eval() # Set to evaluation mode
print("Model loaded successfully.")
except Exception as e:
print(f"Failed to load model: {e}")
model = None
tokenizer = None
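# Optional fallback (a hypothetical sketch, not part of the original flow): on a
# CPU-only Space one could retry the load with a small text-only chat checkpoint,
# e.g. "Qwen/Qwen2.5-0.5B-Instruct", before leaving model as None.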
# --- Simulated Cost Database ---
COSTS = {
"ask_question": 10.0,
"physical_exam": 25.0,
"order_cbc": 50.0,
"order_xray": 150.0,
"administer_med": 30.0,
"end_case": 0.0,
"start_case": 0.0 # Cost for starting is typically 0
}
# --- State Management Class ---
class MedicalSimulatorState:
def __init__(self):
self.patient_profile: Optional[Dict] = None
self.chat_history: List[Tuple[Optional[str], Optional[str]]] = [] # Gradio chat format: [(user_msg, bot_msg), ...]
self.vitals: Dict[str, float] = {"HR": 72.0, "BP_Sys": 120.0, "BP_Dia": 80.0, "Temp": 98.6, "O2_Sat": 98.0}
self.total_cost: float = 0.0
self.is_case_active: bool = False
self.underlying_diagnosis: str = ""
self.ordered_tests: Dict[str, str] = {} # e.g., {"cbc": "pending", "xray": "result..."}
# Add more state variables as needed
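    def snapshot(self) -> str:
        # Hypothetical debugging helper (not used by the original flow): serialize
        # the lightweight fields so a session can be logged between Gradio events.
        # This also gives the json import above a concrete use.
        return json.dumps({
            "vitals": self.vitals,
            "total_cost": self.total_cost,
            "is_case_active": self.is_case_active,
            "underlying_diagnosis": self.underlying_diagnosis,
            "ordered_tests": self.ordered_tests,
        })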
# --- Core AI Interaction Function (Medical Context) ---
def get_ai_response_medical(user_input: str, history: List[Tuple[Optional[str], Optional[str]]], patient_profile: Dict, underlying_diagnosis: str) -> str:
if not model or not tokenizer:
return "Error: AI model is not loaded."
# Construct a prompt for the AI based on history and patient profile
history_str = "\n".join([f"{'User' if h[0] else 'System/AI'}: {h[0] or h[1]}" for h in history])
context = f"Patient Profile: Name: {patient_profile['name']}, Age: {patient_profile['age']}, Gender: {patient_profile['gender']}, Chief Complaint: {patient_profile['chief_complaint']}, History: {patient_profile['history']}, Medications: {patient_profile['medications']}, Allergies: {patient_profile['allergies']}, Social History: {patient_profile['social_history']}, Financial Status: {patient_profile['financial_status']}, Code Status: {patient_profile['code_status']}\nCurrent Chat History:\n{history_str}\nUser Action/Question: {user_input}\n"
prompt = f"<|system|>You are an AI patient in a medical simulation. Role-play as the patient described in the profile. Be consistent with their history, demographics, and potential complaints. Respond to the user's input (which could be a question, exam instruction, or order). Your responses should simulate realistic patient dialogue and reactions, including potential anxiety or concerns. Do not reveal the secret diagnosis '{underlying_diagnosis}' directly, but your responses should be consistent with having that condition. Respond as if you are the patient speaking.<|user|>{context}<|assistant|>"
try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)  # move tensors to the model's device
if inputs["input_ids"].shape[1] > 32768: # Check for model max length
return "Error: Input prompt is too long for the model."
# Generate response
        generate_ids = model.generate(
            **inputs,  # pass the attention_mask along with input_ids
max_new_tokens=512, # Limit generated tokens
do_sample=True,
temperature=0.7,
top_p=0.9,
pad_token_id=tokenizer.eos_token_id
)
# Decode the generated response
response_text = tokenizer.decode(generate_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
return response_text.strip()
except Exception as e:
print(f"Error during AI generation: {e}")
return f"An error occurred while processing the AI response: {e}"
# --- Core AI Interaction Function (General/Sandbox Context) ---
def get_ai_response_general(user_input: str, history: List[Tuple[Optional[str], Optional[str]]]) -> str:
if not model or not tokenizer:
return "Error: AI model is not loaded."
# Construct a prompt for the AI based on general chat history
history_str = "\n".join([f"{'User' if h[0] else 'Assistant'}: {h[0] or h[1]}" for h in history])
prompt = f"<|system|>You are a helpful assistant. Answer the user's questions and follow their instructions.<|user|>{history_str}\n{user_input}<|assistant|>"
try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)  # move tensors to the model's device
if inputs["input_ids"].shape[1] > 32768: # Check for model max length
return "Error: Input prompt is too long for the model."
# Generate response
        generate_ids = model.generate(
            **inputs,  # pass the attention_mask along with input_ids
max_new_tokens=512, # Limit generated tokens
do_sample=True,
temperature=0.7,
top_p=0.9,
pad_token_id=tokenizer.eos_token_id
)
# Decode the generated response
response_text = tokenizer.decode(generate_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
return response_text.strip()
except Exception as e:
print(f"Error during AI generation: {e}")
return f"An error occurred while processing the AI response: {e}"
# --- Tool Functions (Modify State) ---
def start_case(case_type: str, state: MedicalSimulatorState) -> Tuple[MedicalSimulatorState, List[Tuple[Optional[str], Optional[str]]], str, str]:
# --- Generate Patient Profile (Simplified Example) ---
names = ["John Smith", "Emily Johnson", "Michael Brown", "Sarah Davis"]
chief_complaints = {
"General": ["Chest pain", "Shortness of breath", "Abdominal pain"],
"Pediatric": ["Fever", "Cough", "Ear ache"],
"Psychiatry": ["Feeling anxious", "Difficulty sleeping", "Low mood"],
"Dual Diagnosis": ["Chest pain and feels anxious", "Abdominal pain after drinking"]
}
complaint_options = chief_complaints.get(case_type, chief_complaints["General"])
name = random.choice(names)
age = random.randint(18, 80) if case_type != "Pediatric" else random.randint(0, 17)
gender = random.choice(["Male", "Female"])
chief_complaint = random.choice(complaint_options)
# Define underlying diagnosis based on complaint or case type
diag_map = {
"Chest pain": "Acute Myocardial Infarction",
"Shortness of breath": "Pneumonia",
"Abdominal pain": "Appendicitis",
"Fever": "Viral Infection",
"Cough": "Bronchitis",
"Ear ache": "Otitis Media",
"Feeling anxious": "Generalized Anxiety Disorder",
"Difficulty sleeping": "Insomnia",
"Low mood": "Major Depressive Disorder",
"Chest pain and feels anxious": "Acute MI with Anxiety",
"Abdominal pain after drinking": "Alcoholic Gastritis with Substance Use Disorder"
}
underlying_diagnosis = diag_map.get(chief_complaint, "Unknown Condition")
patient = {
"name": name,
"age": age,
"gender": gender,
"chief_complaint": chief_complaint,
"history": "Patient history relevant to complaint.",
"medications": "Current medications.",
"allergies": "Known allergies (e.g., Penicillin).",
"social_history": "Social history details.",
"financial_status": "Insurance status.",
"code_status": "Full Code",
"language": "English"
}
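    # Note: the fields above are static placeholders; a fuller version could ask
    # the model to generate a coherent history and medication list per diagnosis.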
# Reset state
state.patient_profile = patient
state.chat_history = [("System", "New Case Started."), ("AI Patient", f"Hi, I'm {patient['name']}. I've been having {patient['chief_complaint']}.")]
state.vitals = {"HR": 72.0, "BP_Sys": 120.0, "BP_Dia": 80.0, "Temp": 98.6, "O2_Sat": 98.0}
state.total_cost = 0.0
state.is_case_active = True
state.underlying_diagnosis = underlying_diagnosis
state.ordered_tests = {}
profile_str = "\n".join([f"{k.replace('_', ' ').title()}: {v}" for k, v in patient.items()])
return state, state.chat_history, f"${state.total_cost:.2f}", profile_str
def handle_chat(user_input: str, history: List[Tuple[Optional[str], Optional[str]]], state: MedicalSimulatorState) -> Tuple[List[Tuple[Optional[str], Optional[str]]], str]:
if not state.is_case_active or not user_input.strip():
return history, f"${state.total_cost:.2f}"
# Add user message to history
history.append((user_input, None))
    # Get the AI response; pass the history without the pending entry, since
    # user_input is supplied separately
    ai_response = get_ai_response_medical(user_input, history[:-1], state.patient_profile, state.underlying_diagnosis)
    # Fill in the AI response on the pending entry
    history[-1] = (user_input, ai_response)
    # Keep the session state in sync with the chatbot component so that tool
    # handlers, which render state.chat_history, do not drop these turns
    state.chat_history = history
    return history, f"${state.total_cost:.2f}"  # Cost does not change for plain chat
def use_tool(tool_name: str, state: MedicalSimulatorState) -> Tuple[MedicalSimulatorState, List[Tuple[Optional[str], Optional[str]]], str]:
if not state.is_case_active:
return state, state.chat_history, f"${state.total_cost:.2f}"
cost = COSTS.get(tool_name, 0.0)
state.total_cost += cost
if tool_name == "ask_question":
ai_response = get_ai_response_medical("The user asks a general question to gather more history.", state.chat_history, state.patient_profile, state.underlying_diagnosis)
state.chat_history.append(("System", f"[Action: {tool_name}, Cost: ${cost:.2f}]"))
state.chat_history.append(("AI Patient", ai_response))
elif tool_name == "order_cbc":
state.chat_history.append(("System", f"[Action: {tool_name}, Cost: ${cost:.2f}]"))
state.chat_history.append(("Lab", "CBC Ordered. Result pending..."))
# Simulate result appearing after a delay or another action
# For now, add a simple result shortly after
time.sleep(0.5) # Simulate processing delay
state.chat_history.append(("Lab", "CBC Result: WBC slightly elevated, otherwise unremarkable."))
elif tool_name == "administer_med":
med_name = "Medication X" # Simplified, could take input
state.chat_history.append(("System", f"[Action: {tool_name} - {med_name}, Cost: ${cost:.2f}]"))
# Check for allergies here in a real app
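        # Hedged sketch of such a check (hypothetical): the allergies field is
        # free text in this demo, so a naive substring match is all we can do.
        allergies = (state.patient_profile or {}).get("allergies", "")
        if med_name.lower() in allergies.lower():
            state.chat_history.append((None, f"[System] WARNING: chart lists a possible allergy relevant to {med_name}."))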
state.chat_history.append(("AI Patient", f"Okay, I took the {med_name}."))
elif tool_name == "physical_exam":
state.chat_history.append(("System", f"[Action: {tool_name}, Cost: ${cost:.2f}]"))
state.chat_history.append(("System", "Physical Exam Performed. Findings noted."))
elif tool_name == "order_xray":
state.chat_history.append(("System", f"[Action: {tool_name}, Cost: ${cost:.2f}]"))
state.chat_history.append(("Imaging", "X-Ray Ordered. Result pending..."))
# Placeholder for image result (could be a URL or base64 string)
time.sleep(0.5) # Simulate processing delay
state.chat_history.append(("Imaging", "Chest X-Ray Result: Normal lung fields, no acute findings. (Placeholder Image)"))
# Add other tools as needed...
return state, state.chat_history, f"${state.total_cost:.2f}"
def end_case(state: MedicalSimulatorState) -> Tuple[MedicalSimulatorState, List[Tuple[Optional[str], Optional[str]]], str, str]:
if not state.is_case_active:
# Return current state if no case is active
profile_str = "\n".join([f"{k.replace('_', ' ').title()}: {v}" for k, v in (state.patient_profile or {}).items()])
return state, state.chat_history, f"${state.total_cost:.2f}", profile_str
state.chat_history.append(("System", "Case Ended by User."))
state.is_case_active = False
    # In a full implementation, trigger the evaluation logic here, e.g. using the
    # run_evaluation sketch defined below:
    # state.chat_history.append((None, f"[System] Evaluation: {run_evaluation(state)}"))
profile_str = "\n".join([f"{k.replace('_', ' ').title()}: {v}" for k, v in (state.patient_profile or {}).items()])
return state, state.chat_history, f"${state.total_cost:.2f}", profile_str
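# Hedged sketch of the run_evaluation hook referenced above (hypothetical, not
# part of the original flow): a real version would grade the user's diagnostic
# reasoning; this one only summarizes billed actions and cost from the chat log.
def run_evaluation(state: MedicalSimulatorState) -> str:
    actions = sum(
        1 for _, msg in state.chat_history
        if msg and msg.startswith("[System] Action:")
    )
    return (
        f"Underlying diagnosis was '{state.underlying_diagnosis}'. "
        f"{actions} billed action(s) taken, total cost ${state.total_cost:.2f}."
    )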
# --- Sandbox Chat Handler ---
def handle_sandbox_chat(user_input: str, history: List[Tuple[Optional[str], Optional[str]]]) -> List[Tuple[Optional[str], Optional[str]]]:
if not user_input.strip():
return history
# Add user message to history
history.append((user_input, None))
# Get AI response (general context)
ai_response = get_ai_response_general(user_input, history[:-1])
# Add AI response to history
history[-1] = (user_input, ai_response)
return history
# --- Gradio Interface ---
with gr.Blocks(title="Advanced Medical Simulator") as demo:
gr.Markdown("# Advanced Medical Simulator")
with gr.Tab("Medical Simulation"):
# State component to hold the simulator state across interactions
state = gr.State(lambda: MedicalSimulatorState())
with gr.Row():
with gr.Column(scale=2):
# Chat Interface
chatbot = gr.Chatbot(label="Patient Interaction", height=400, bubble_full_width=False)
with gr.Row():
user_input = gr.Textbox(label="Your Action / Question", placeholder="Type your action or question here...", scale=4)
submit_btn = gr.Button("Submit", scale=1)
with gr.Column(scale=1):
# Patient Chart / Info
patient_chart = gr.Markdown(label="Patient Chart", value="Click 'Start New Case' to begin.")
cost_display = gr.Textbox(label="Total Cost", value="$0.00", interactive=False)
with gr.Row():
# Tool Panel
with gr.Column():
gr.Markdown("### Tools")
with gr.Row():
ask_btn = gr.Button("Ask Question ($10)")
exam_btn = gr.Button("Physical Exam ($25)")
with gr.Row():
cbc_btn = gr.Button("Order CBC ($50)")
xray_btn = gr.Button("Order X-Ray ($150)")
with gr.Row():
med_btn = gr.Button("Administer Med ($30)")
end_btn = gr.Button("End Case", variant="stop") # Red button for ending
with gr.Row():
# Case Controls
start_case_btn = gr.Button("Start New Case (General)")
case_type_dropdown = gr.Dropdown(["General", "Psychiatry", "Pediatric", "Dual Diagnosis"], label="Case Type", value="General")
# Event Handling for Medical Simulation Tab
start_case_btn.click(
fn=start_case,
inputs=[case_type_dropdown, state],
outputs=[state, chatbot, cost_display, patient_chart]
)
submit_btn.click(
fn=handle_chat,
inputs=[user_input, chatbot, state],
outputs=[chatbot, cost_display]
).then(
fn=lambda: "", # Clear the input textbox after submission
inputs=[],
outputs=[user_input]
)
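        # Addition: also fire on Enter in the textbox, mirroring the Submit button.
        user_input.submit(
            fn=handle_chat,
            inputs=[user_input, chatbot, state],
            outputs=[chatbot, cost_display]
        ).then(
            fn=lambda: "",
            inputs=[],
            outputs=[user_input]
        )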
ask_btn.click(fn=lambda s: use_tool("ask_question", s), inputs=[state], outputs=[state, chatbot, cost_display])
exam_btn.click(fn=lambda s: use_tool("physical_exam", s), inputs=[state], outputs=[state, chatbot, cost_display])
cbc_btn.click(fn=lambda s: use_tool("order_cbc", s), inputs=[state], outputs=[state, chatbot, cost_display])
xray_btn.click(fn=lambda s: use_tool("order_xray", s), inputs=[state], outputs=[state, chatbot, cost_display])
med_btn.click(fn=lambda s: use_tool("administer_med", s), inputs=[state], outputs=[state, chatbot, cost_display])
end_btn.click(fn=end_case, inputs=[state], outputs=[state, chatbot, cost_display, patient_chart])
with gr.Tab("AI Sandbox"):
sandbox_chatbot = gr.Chatbot(label="General AI Chat", height=500, bubble_full_width=False)
with gr.Row():
sandbox_input = gr.Textbox(label="Message", placeholder="Ask anything...", scale=4)
sandbox_submit = gr.Button("Send", scale=1)
sandbox_submit.click(
fn=handle_sandbox_chat,
inputs=[sandbox_input, sandbox_chatbot],
outputs=[sandbox_chatbot]
).then(
fn=lambda: "", # Clear the input textbox after submission
inputs=[],
outputs=[sandbox_input]
)
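        # Addition: also fire on Enter, mirroring the Send button.
        sandbox_input.submit(
            fn=handle_sandbox_chat,
            inputs=[sandbox_input, sandbox_chatbot],
            outputs=[sandbox_chatbot]
        ).then(
            fn=lambda: "",
            inputs=[],
            outputs=[sandbox_input]
        )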
# Launch the app
# On Hugging Face Spaces, running app.py serves the app through this call.
demo.launch()