# app.py for Gradio App on Hugging Face Spaces (with Sandbox Tab)
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import random
import time
import json
from typing import List, Dict, Tuple, Optional
# --- Configuration ---
MODEL_NAME = "Qwen/Qwen3-VL-4B-Instruct-FP8"
# For Hugging Face Spaces, loading the model once here is common.
# Performance on CPU will be slower.
print("Loading model...")
try:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    # NOTE: this checkpoint is a vision-language (VL) model; AutoModelForCausalLM
    # may not accept it, in which case the class recommended on the model card
    # (or a text-only Qwen checkpoint) would be needed instead.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16,  # request half precision; FP8-quantized weights keep their own format
        device_map="auto",  # automatically map to available devices (CPU or GPU)
        trust_remote_code=True
    ).eval()  # set to evaluation mode
print("Model loaded successfully.")
except Exception as e:
print(f"Failed to load model: {e}")
model = None
tokenizer = None
# --- Simulated Cost Database ---
COSTS = {
"ask_question": 10.0,
"physical_exam": 25.0,
"order_cbc": 50.0,
"order_xray": 150.0,
"administer_med": 30.0,
"end_case": 0.0,
"start_case": 0.0 # Cost for starting is typically 0
}
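# Optional helper (a sketch, not wired into the UI below): deriving button
# labels from COSTS keeps the button text and the cost table from drifting
# apart. `tool_label` is a hypothetical name introduced here.
def tool_label(text: str, tool_name: str) -> str:
    return f"{text} (${COSTS.get(tool_name, 0.0):.0f})"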
# --- State Management Class ---
class MedicalSimulatorState:
def __init__(self):
self.patient_profile: Optional[Dict] = None
self.chat_history: List[Tuple[Optional[str], Optional[str]]] = [] # Gradio chat format: [(user_msg, bot_msg), ...]
self.vitals: Dict[str, float] = {"HR": 72.0, "BP_Sys": 120.0, "BP_Dia": 80.0, "Temp": 98.6, "O2_Sat": 98.0}
self.total_cost: float = 0.0
self.is_case_active: bool = False
self.underlying_diagnosis: str = ""
self.ordered_tests: Dict[str, str] = {} # e.g., {"cbc": "pending", "xray": "result..."}
# Add more state variables as needed
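# Sketch of a persistence helper (an assumption; nothing below calls it):
# snapshot the current case to JSON for logging or download. This also puts
# the otherwise-unused `json` import to work. `export_case_log` is a
# hypothetical name.
def export_case_log(state: "MedicalSimulatorState") -> str:
    return json.dumps(
        {
            "patient_profile": state.patient_profile,
            "chat_history": state.chat_history,
            "vitals": state.vitals,
            "total_cost": state.total_cost,
            "underlying_diagnosis": state.underlying_diagnosis,
            "ordered_tests": state.ordered_tests,
        },
        indent=2,
    )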
# --- Core AI Interaction Function (Medical Context) ---
def get_ai_response_medical(user_input: str, history: List[Tuple[Optional[str], Optional[str]]], patient_profile: Dict, underlying_diagnosis: str) -> str:
if not model or not tokenizer:
return "Error: AI model is not loaded."
    # Flatten the chat history into labeled lines. Tool/system entries store a
    # role label in the first slot; ordinary turns store (user_msg, ai_msg),
    # so format whichever slots are present instead of dropping one of them.
    role_labels = {"System", "AI Patient", "Lab", "Imaging"}
    lines = []
    for left, right in history:
        if left in role_labels:
            lines.append(f"{left}: {right}")
        else:
            if left:
                lines.append(f"User: {left}")
            if right:
                lines.append(f"Patient: {right}")
    history_str = "\n".join(lines)
    context = (
        f"Patient Profile: Name: {patient_profile['name']}, Age: {patient_profile['age']}, "
        f"Gender: {patient_profile['gender']}, Chief Complaint: {patient_profile['chief_complaint']}, "
        f"History: {patient_profile['history']}, Medications: {patient_profile['medications']}, "
        f"Allergies: {patient_profile['allergies']}, Social History: {patient_profile['social_history']}, "
        f"Financial Status: {patient_profile['financial_status']}, Code Status: {patient_profile['code_status']}\n"
        f"Current Chat History:\n{history_str}\n"
        f"User Action/Question: {user_input}\n"
    )
    # Build the prompt with the tokenizer's chat template; the previous
    # hand-written <|system|>/<|user|> markers are not Qwen special tokens
    # and would be tokenized as literal text.
    system_msg = (
        "You are an AI patient in a medical simulation. Role-play as the patient described in the profile. "
        "Be consistent with their history, demographics, and potential complaints. Respond to the user's input "
        "(which could be a question, exam instruction, or order). Your responses should simulate realistic "
        "patient dialogue and reactions, including potential anxiety or concerns. Do not reveal the secret "
        f"diagnosis '{underlying_diagnosis}' directly, but your responses should be consistent with having "
        "that condition. Respond as if you are the patient speaking."
    )
    messages = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": context},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)  # move tensors to the model's device
if inputs["input_ids"].shape[1] > 32768: # Check for model max length
return "Error: Input prompt is too long for the model."
# Generate response
        generate_ids = model.generate(
            **inputs,  # pass input_ids plus attention_mask together
max_new_tokens=512, # Limit generated tokens
do_sample=True,
temperature=0.7,
top_p=0.9,
pad_token_id=tokenizer.eos_token_id
)
# Decode the generated response
response_text = tokenizer.decode(generate_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
return response_text.strip()
except Exception as e:
print(f"Error during AI generation: {e}")
return f"An error occurred while processing the AI response: {e}"
# --- Core AI Interaction Function (General/Sandbox Context) ---
def get_ai_response_general(user_input: str, history: List[Tuple[Optional[str], Optional[str]]]) -> str:
if not model or not tokenizer:
return "Error: AI model is not loaded."
    # Flatten the (user, assistant) chat history into labeled lines
    lines = []
    for user_msg, bot_msg in history:
        if user_msg:
            lines.append(f"User: {user_msg}")
        if bot_msg:
            lines.append(f"Assistant: {bot_msg}")
    history_str = "\n".join(lines)
    # Use the tokenizer's chat template rather than literal <|system|> markers
    messages = [
        {"role": "system", "content": "You are a helpful assistant. Answer the user's questions and follow their instructions."},
        {"role": "user", "content": f"{history_str}\n{user_input}"},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)  # move tensors to the model's device
if inputs["input_ids"].shape[1] > 32768: # Check for model max length
return "Error: Input prompt is too long for the model."
# Generate response
        generate_ids = model.generate(
            **inputs,  # pass input_ids plus attention_mask together
max_new_tokens=512, # Limit generated tokens
do_sample=True,
temperature=0.7,
top_p=0.9,
pad_token_id=tokenizer.eos_token_id
)
# Decode the generated response
response_text = tokenizer.decode(generate_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
return response_text.strip()
except Exception as e:
print(f"Error during AI generation: {e}")
return f"An error occurred while processing the AI response: {e}"
# --- Tool Functions (Modify State) ---
def start_case(case_type: str, state: MedicalSimulatorState) -> Tuple[MedicalSimulatorState, List[Tuple[Optional[str], Optional[str]]], str, str]:
# --- Generate Patient Profile (Simplified Example) ---
names = ["John Smith", "Emily Johnson", "Michael Brown", "Sarah Davis"]
chief_complaints = {
"General": ["Chest pain", "Shortness of breath", "Abdominal pain"],
"Pediatric": ["Fever", "Cough", "Ear ache"],
"Psychiatry": ["Feeling anxious", "Difficulty sleeping", "Low mood"],
"Dual Diagnosis": ["Chest pain and feels anxious", "Abdominal pain after drinking"]
}
complaint_options = chief_complaints.get(case_type, chief_complaints["General"])
name = random.choice(names)
    age = random.randint(18, 80) if case_type != "Pediatric" else random.randint(2, 17)  # keep pediatric patients old enough to converse
gender = random.choice(["Male", "Female"])
chief_complaint = random.choice(complaint_options)
# Define underlying diagnosis based on complaint or case type
diag_map = {
"Chest pain": "Acute Myocardial Infarction",
"Shortness of breath": "Pneumonia",
"Abdominal pain": "Appendicitis",
"Fever": "Viral Infection",
"Cough": "Bronchitis",
"Ear ache": "Otitis Media",
"Feeling anxious": "Generalized Anxiety Disorder",
"Difficulty sleeping": "Insomnia",
"Low mood": "Major Depressive Disorder",
"Chest pain and feels anxious": "Acute MI with Anxiety",
"Abdominal pain after drinking": "Alcoholic Gastritis with Substance Use Disorder"
}
underlying_diagnosis = diag_map.get(chief_complaint, "Unknown Condition")
patient = {
"name": name,
"age": age,
"gender": gender,
"chief_complaint": chief_complaint,
"history": "Patient history relevant to complaint.",
"medications": "Current medications.",
"allergies": "Known allergies (e.g., Penicillin).",
"social_history": "Social history details.",
"financial_status": "Insurance status.",
"code_status": "Full Code",
"language": "English"
}
# Reset state
state.patient_profile = patient
state.chat_history = [("System", "New Case Started."), ("AI Patient", f"Hi, I'm {patient['name']}. I've been having {patient['chief_complaint']}.")]
state.vitals = {"HR": 72.0, "BP_Sys": 120.0, "BP_Dia": 80.0, "Temp": 98.6, "O2_Sat": 98.0}
state.total_cost = 0.0
state.is_case_active = True
state.underlying_diagnosis = underlying_diagnosis
state.ordered_tests = {}
    profile_str = "\n".join([f"- **{k.replace('_', ' ').title()}:** {v}" for k, v in patient.items()])  # markdown list so each field renders on its own line
return state, state.chat_history, f"${state.total_cost:.2f}", profile_str
def handle_chat(user_input: str, history: List[Tuple[Optional[str], Optional[str]]], state: MedicalSimulatorState) -> Tuple[List[Tuple[Optional[str], Optional[str]]], str]:
if not state.is_case_active or not user_input.strip():
return history, f"${state.total_cost:.2f}"
# Add user message to history
history.append((user_input, None))
# Get AI response
ai_response = get_ai_response_medical(user_input, history[:-1], state.patient_profile, state.underlying_diagnosis) # Pass history without the user's new message yet
# Add AI response to history
history[-1] = (user_input, ai_response) # Update the last entry with the AI's response
return history, f"${state.total_cost:.2f}" # Cost doesn't change here, just return current
def use_tool(tool_name: str, state: MedicalSimulatorState) -> Tuple[MedicalSimulatorState, List[Tuple[Optional[str], Optional[str]]], str]:
if not state.is_case_active:
return state, state.chat_history, f"${state.total_cost:.2f}"
cost = COSTS.get(tool_name, 0.0)
state.total_cost += cost
if tool_name == "ask_question":
ai_response = get_ai_response_medical("The user asks a general question to gather more history.", state.chat_history, state.patient_profile, state.underlying_diagnosis)
state.chat_history.append(("System", f"[Action: {tool_name}, Cost: ${cost:.2f}]"))
state.chat_history.append(("AI Patient", ai_response))
elif tool_name == "order_cbc":
state.chat_history.append(("System", f"[Action: {tool_name}, Cost: ${cost:.2f}]"))
state.chat_history.append(("Lab", "CBC Ordered. Result pending..."))
        # Simulate the result appearing after a delay or another action.
        # For now, append a canned result immediately; see the
        # resolve_pending_tests sketch after this function for a deferred version.
time.sleep(0.5) # Simulate processing delay
state.chat_history.append(("Lab", "CBC Result: WBC slightly elevated, otherwise unremarkable."))
elif tool_name == "administer_med":
med_name = "Medication X" # Simplified, could take input
state.chat_history.append(("System", f"[Action: {tool_name} - {med_name}, Cost: ${cost:.2f}]"))
        # Check for allergies here in a real app (see the check_allergy sketch after this function)
state.chat_history.append(("AI Patient", f"Okay, I took the {med_name}."))
elif tool_name == "physical_exam":
state.chat_history.append(("System", f"[Action: {tool_name}, Cost: ${cost:.2f}]"))
state.chat_history.append(("System", "Physical Exam Performed. Findings noted."))
elif tool_name == "order_xray":
state.chat_history.append(("System", f"[Action: {tool_name}, Cost: ${cost:.2f}]"))
state.chat_history.append(("Imaging", "X-Ray Ordered. Result pending..."))
# Placeholder for image result (could be a URL or base64 string)
time.sleep(0.5) # Simulate processing delay
state.chat_history.append(("Imaging", "Chest X-Ray Result: Normal lung fields, no acute findings. (Placeholder Image)"))
# Add other tools as needed...
return state, state.chat_history, f"${state.total_cost:.2f}"
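# Sketches tied to the inline comments in use_tool (hypothetical helpers, not
# wired to any button): resolve_pending_tests shows how the so-far-unused
# state.ordered_tests dict could defer lab/imaging results to a later turn,
# and check_allergy shows the allergy guard mentioned before administering a
# medication.
def resolve_pending_tests(state: MedicalSimulatorState) -> MedicalSimulatorState:
    # Flip every pending order to a canned result; a fuller version would
    # key the result text off state.underlying_diagnosis.
    for test, status in list(state.ordered_tests.items()):
        if status == "pending":
            state.ordered_tests[test] = "resulted"
            state.chat_history.append(("Lab", f"{test.upper()} result now available."))
    return state
def check_allergy(state: MedicalSimulatorState, med_name: str) -> bool:
    # Naive substring match against the free-text allergies field.
    allergies = (state.patient_profile or {}).get("allergies", "")
    return med_name.lower() in allergies.lower()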
def end_case(state: MedicalSimulatorState) -> Tuple[MedicalSimulatorState, List[Tuple[Optional[str], Optional[str]]], str, str]:
if not state.is_case_active:
# Return current state if no case is active
        profile_str = "\n".join([f"- **{k.replace('_', ' ').title()}:** {v}" for k, v in (state.patient_profile or {}).items()])  # markdown list so each field renders on its own line
return state, state.chat_history, f"${state.total_cost:.2f}", profile_str
state.chat_history.append(("System", "Case Ended by User."))
state.is_case_active = False
    # In a full implementation, trigger the evaluation logic here;
    # a rule-based run_evaluation sketch follows this function:
    # evaluation = run_evaluation(state)
    # state.chat_history.append(("System", f"Evaluation: {evaluation}"))
    profile_str = "\n".join([f"- **{k.replace('_', ' ').title()}:** {v}" for k, v in (state.patient_profile or {}).items()])  # markdown list so each field renders on its own line
return state, state.chat_history, f"${state.total_cost:.2f}", profile_str
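# Sketch of the evaluation hook referenced in end_case (an assumption: a
# simple rule-based grade on cost plus whether the diagnosis was ever named;
# a real version might instead ask the model to grade the transcript).
def run_evaluation(state: MedicalSimulatorState) -> str:
    transcript = " ".join(str(msg) for pair in state.chat_history for msg in pair if msg)
    verdict = "mentioned" if state.underlying_diagnosis.lower() in transcript.lower() else "not mentioned"
    return (
        f"Underlying diagnosis: {state.underlying_diagnosis} ({verdict} during the workup). "
        f"Total cost of care: ${state.total_cost:.2f}."
    )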
# --- Sandbox Chat Handler ---
def handle_sandbox_chat(user_input: str, history: List[Tuple[Optional[str], Optional[str]]]) -> List[Tuple[Optional[str], Optional[str]]]:
if not user_input.strip():
return history
# Add user message to history
history.append((user_input, None))
# Get AI response (general context)
ai_response = get_ai_response_general(user_input, history[:-1])
# Add AI response to history
history[-1] = (user_input, ai_response)
return history
# --- Gradio Interface ---
with gr.Blocks(title="Advanced Medical Simulator") as demo:
gr.Markdown("# Advanced Medical Simulator")
with gr.Tab("Medical Simulation"):
# State component to hold the simulator state across interactions
state = gr.State(lambda: MedicalSimulatorState())
with gr.Row():
with gr.Column(scale=2):
# Chat Interface
                chatbot = gr.Chatbot(label="Patient Interaction", height=400)  # tuple-format history; bubble_full_width was deprecated in newer Gradio releases
with gr.Row():
user_input = gr.Textbox(label="Your Action / Question", placeholder="Type your action or question here...", scale=4)
submit_btn = gr.Button("Submit", scale=1)
with gr.Column(scale=1):
# Patient Chart / Info
patient_chart = gr.Markdown(label="Patient Chart", value="Click 'Start New Case' to begin.")
cost_display = gr.Textbox(label="Total Cost", value="$0.00", interactive=False)
with gr.Row():
# Tool Panel
with gr.Column():
gr.Markdown("### Tools")
with gr.Row():
ask_btn = gr.Button("Ask Question ($10)")
exam_btn = gr.Button("Physical Exam ($25)")
with gr.Row():
cbc_btn = gr.Button("Order CBC ($50)")
xray_btn = gr.Button("Order X-Ray ($150)")
with gr.Row():
med_btn = gr.Button("Administer Med ($30)")
end_btn = gr.Button("End Case", variant="stop") # Red button for ending
with gr.Row():
# Case Controls
            start_case_btn = gr.Button("Start New Case")  # case type is taken from the dropdown
case_type_dropdown = gr.Dropdown(["General", "Psychiatry", "Pediatric", "Dual Diagnosis"], label="Case Type", value="General")
# Event Handling for Medical Simulation Tab
start_case_btn.click(
fn=start_case,
inputs=[case_type_dropdown, state],
outputs=[state, chatbot, cost_display, patient_chart]
)
submit_btn.click(
fn=handle_chat,
inputs=[user_input, chatbot, state],
outputs=[chatbot, cost_display]
).then(
fn=lambda: "", # Clear the input textbox after submission
inputs=[],
outputs=[user_input]
)
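        # Convenience sketch (an addition, assuming Enter-to-send is wanted):
        # pressing Enter in the textbox triggers the same handler as Submit.
        user_input.submit(
            fn=handle_chat,
            inputs=[user_input, chatbot, state],
            outputs=[chatbot, cost_display]
        ).then(
            fn=lambda: "",
            inputs=[],
            outputs=[user_input]
        )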
ask_btn.click(fn=lambda s: use_tool("ask_question", s), inputs=[state], outputs=[state, chatbot, cost_display])
exam_btn.click(fn=lambda s: use_tool("physical_exam", s), inputs=[state], outputs=[state, chatbot, cost_display])
cbc_btn.click(fn=lambda s: use_tool("order_cbc", s), inputs=[state], outputs=[state, chatbot, cost_display])
xray_btn.click(fn=lambda s: use_tool("order_xray", s), inputs=[state], outputs=[state, chatbot, cost_display])
med_btn.click(fn=lambda s: use_tool("administer_med", s), inputs=[state], outputs=[state, chatbot, cost_display])
end_btn.click(fn=end_case, inputs=[state], outputs=[state, chatbot, cost_display, patient_chart])
with gr.Tab("AI Sandbox"):
        sandbox_chatbot = gr.Chatbot(label="General AI Chat", height=500)  # bubble_full_width was deprecated in newer Gradio releases
with gr.Row():
sandbox_input = gr.Textbox(label="Message", placeholder="Ask anything...", scale=4)
sandbox_submit = gr.Button("Send", scale=1)
sandbox_submit.click(
fn=handle_sandbox_chat,
inputs=[sandbox_input, sandbox_chatbot],
outputs=[sandbox_chatbot]
).then(
fn=lambda: "", # Clear the input textbox after submission
inputs=[],
outputs=[sandbox_input]
)
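        # Convenience sketch (an addition, mirroring the Send wiring above):
        # Enter-to-send for the sandbox textbox.
        sandbox_input.submit(
            fn=handle_sandbox_chat,
            inputs=[sandbox_input, sandbox_chatbot],
            outputs=[sandbox_chatbot]
        ).then(
            fn=lambda: "",
            inputs=[],
            outputs=[sandbox_input]
        )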
# Launch the app
# For Hugging Face Spaces, Gradio handles the launch.
demo.launch()