File size: 16,940 Bytes
244da05
d9cba7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244da05
 
d9cba7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244da05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d9cba7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244da05
d9cba7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244da05
d9cba7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244da05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d9cba7a
 
 
 
 
244da05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d9cba7a
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
# app.py for Gradio App on Hugging Face Spaces (with Sandbox Tab)
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import random
import time
import json
from typing import List, Dict, Tuple, Optional

# --- Configuration ---
# Hugging Face Hub checkpoint used for every AI reply in this app.
# NOTE(review): this is a vision-language ("VL") checkpoint loaded through the
# text-only AutoModelForCausalLM / AutoTokenizer classes; confirm the installed
# transformers version supports loading it this way.
MODEL_NAME = "Qwen/Qwen3-VL-4B-Instruct-FP8"
# Load the tokenizer and model once, at import time, so every request reuses them.
# Performance on CPU will be slower.
# If anything raises during loading, the error is printed and BOTH globals are
# set to None; the AI response functions check for that and return an error
# string instead of generating, so the rest of the app still starts.
print("Loading model...")
try:
    # trust_remote_code=True allows custom code shipped in the model repo to run.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16, # NOTE(review): checkpoint name says FP8 but float16 is requested; confirm how the quantized weights are handled
        device_map="auto", # Let transformers place the weights on the available devices (CPU or GPU)
        trust_remote_code=True
    ).eval() # Set to evaluation mode (inference only)
    print("Model loaded successfully.")
except Exception as e:
    print(f"Failed to load model: {e}")
    model = None
    tokenizer = None

# --- Simulated Cost Database ---
# Dollar price charged for each simulator action, keyed by tool name.
# Starting or ending a case is free.
COSTS = dict(
    ask_question=10.0,
    physical_exam=25.0,
    order_cbc=50.0,
    order_xray=150.0,
    administer_med=30.0,
    end_case=0.0,
    start_case=0.0,
)

# --- State Management Class ---
class MedicalSimulatorState:
    """Mutable container for one simulated patient encounter.

    A new instance has no patient, an empty transcript, baseline vitals, zero
    accumulated cost and no active case. Every instance gets its own lists and
    dicts, so instances never share mutable data.
    """

    def __init__(self):
        # Case lifecycle and running bill.
        self.is_case_active: bool = False
        self.total_cost: float = 0.0

        # Patient details and the hidden diagnosis; empty until a case starts.
        self.patient_profile: Optional[Dict] = None
        self.underlying_diagnosis: str = ""

        # Transcript in Gradio's tuple chat format: [(left_msg, right_msg), ...].
        self.chat_history: List[Tuple[Optional[str], Optional[str]]] = []

        # Baseline vital signs.
        self.vitals: Dict[str, float] = dict(HR=72.0, BP_Sys=120.0, BP_Dia=80.0, Temp=98.6, O2_Sat=98.0)

        # Tests ordered so far, keyed by test name (e.g. "cbc" -> "pending").
        self.ordered_tests: Dict[str, str] = {}
        # Add more state variables as needed

# --- Core AI Interaction Helpers (shared by the medical and sandbox chats) ---
# Longest prompt, in tokens, that will be sent to the model.
_MAX_PROMPT_TOKENS = 32768
# Left-slot labels the simulator itself writes into chat entries (start_case,
# use_tool). Such entries are rendered as "<label>: <text>" in transcripts.
_SPEAKER_LABELS = frozenset({"System", "AI Patient", "Lab", "Imaging"})
# (display label, patient_profile key) pairs, in prompt order.
_PROFILE_FIELDS = (
    ("Name", "name"),
    ("Age", "age"),
    ("Gender", "gender"),
    ("Chief Complaint", "chief_complaint"),
    ("History", "history"),
    ("Medications", "medications"),
    ("Allergies", "allergies"),
    ("Social History", "social_history"),
    ("Financial Status", "financial_status"),
    ("Code Status", "code_status"),
)


def _format_history(history: List[Tuple[Optional[str], Optional[str]]], user_label: str = "User", bot_label: str = "System/AI") -> str:
    """Render a tuple-format chat history as a plain-text transcript.

    Both halves of every (left, right) entry are kept, one line each. Entries
    whose left slot is a simulator speaker label (see _SPEAKER_LABELS) become a
    single "<label>: <text>" line. Empty halves are skipped. ``None`` or an
    empty history yields "".
    """
    lines: List[str] = []
    for left, right in history or []:
        # The simulator stores (speaker, text) pairs in the same list as real
        # (user_msg, bot_msg) turns; render those as a labelled line.
        if isinstance(left, str) and left in _SPEAKER_LABELS and right:
            lines.append(f"{left}: {right}")
            continue
        if left:
            lines.append(f"{user_label}: {left}")
        if right:
            lines.append(f"{bot_label}: {right}")
    return "\n".join(lines)


def _build_prompt(messages: List[Dict[str, str]]) -> str:
    """Turn role/content messages into one prompt string.

    If the loaded tokenizer has a chat template, it is used with a generation
    prompt appended. Otherwise the legacy "<|role|>content...<|assistant|>"
    layout is produced.
    """
    if getattr(tokenizer, "chat_template", None):
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    return "".join(f"<|{m['role']}|>{m['content']}" for m in messages) + "<|assistant|>"


def _generate_reply(messages: List[Dict[str, str]]) -> str:
    """Sample one model reply for *messages*.

    Never raises. It returns an error string when the model is not loaded, the
    prompt is too long, or generation fails.
    """
    if model is None or tokenizer is None:
        return "Error: AI model is not loaded."

    try:
        templated = bool(getattr(tokenizer, "chat_template", None))
        prompt = _build_prompt(messages)
        # A chat template already inserts the special tokens the model expects;
        # don't let the tokenizer add a second set (e.g. a duplicate BOS).
        inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=not templated)
        prompt_len = inputs["input_ids"].shape[1]
        if prompt_len > _MAX_PROMPT_TOKENS:
            return "Error: Input prompt is too long for the model."

        # The tokenizer returns CPU tensors. Move them to wherever device_map
        # placed the model (possibly a GPU) to avoid a device mismatch.
        inputs = inputs.to(model.device)
        generate_ids = model.generate(
            **inputs,  # input_ids and attention_mask
            max_new_tokens=512,  # Limit generated tokens
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Decode only the newly generated tokens, not the echoed prompt.
        response_text = tokenizer.decode(generate_ids[0][prompt_len:], skip_special_tokens=True)
        return response_text.strip()

    except Exception as e:
        print(f"Error during AI generation: {e}")
        return f"An error occurred while processing the AI response: {e}"


# --- Core AI Interaction Function (Medical Context) ---
def get_ai_response_medical(user_input: str, history: List[Tuple[Optional[str], Optional[str]]], patient_profile: Dict, underlying_diagnosis: str) -> str:
    """Reply in character as the simulated patient.

    Args:
        user_input: The clinician's latest action or question.
        history: Prior chat entries, in tuple format. Both user and AI turns
            are included in the prompt.
        patient_profile: Patient details. Missing keys (or a ``None`` profile)
            are shown as "Unknown" instead of raising.
        underlying_diagnosis: The hidden diagnosis the patient must stay
            consistent with without revealing it.

    Returns:
        The generated reply, or an error string (never raises for model errors).
    """
    profile = patient_profile or {}
    profile_text = ", ".join(f"{label}: {profile.get(key, 'Unknown')}" for label, key in _PROFILE_FIELDS)
    history_str = _format_history(history, user_label="User", bot_label="System/AI")
    context = (
        f"Patient Profile: {profile_text}\n"
        f"Current Chat History:\n{history_str}\n"
        f"User Action/Question: {user_input}\n"
    )
    system_prompt = (
        "You are an AI patient in a medical simulation. Role-play as the patient described in the profile. "
        "Be consistent with their history, demographics, and potential complaints. "
        "Respond to the user's input (which could be a question, exam instruction, or order). "
        "Your responses should simulate realistic patient dialogue and reactions, including potential anxiety or concerns. "
        f"Do not reveal the secret diagnosis '{underlying_diagnosis}' directly, "
        "but your responses should be consistent with having that condition. "
        "Respond as if you are the patient speaking."
    )
    return _generate_reply([
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": context},
    ])


# --- Core AI Interaction Function (General/Sandbox Context) ---
def get_ai_response_general(user_input: str, history: List[Tuple[Optional[str], Optional[str]]]) -> str:
    """Reply as a general-purpose assistant in the sandbox chat.

    Each prior (user_msg, bot_msg) pair becomes separate user/assistant
    messages, so the model sees the real conversation turns. Empty halves are
    skipped.

    Returns:
        The generated reply, or an error string (never raises for model errors).
    """
    messages: List[Dict[str, str]] = [
        {"role": "system", "content": "You are a helpful assistant. Answer the user's questions and follow their instructions."}
    ]
    for user_msg, bot_msg in history or []:
        if user_msg:
            messages.append({"role": "user", "content": str(user_msg)})
        if bot_msg:
            messages.append({"role": "assistant", "content": str(bot_msg)})
    messages.append({"role": "user", "content": user_input})
    return _generate_reply(messages)


# --- Tool Functions (Modify State) ---
# Pools used to randomise new patients. Complaint lists are keyed by case type;
# unknown case types fall back to "General".
_PATIENT_NAMES = ["John Smith", "Emily Johnson", "Michael Brown", "Sarah Davis"]
_CHIEF_COMPLAINTS = {
    "General": ["Chest pain", "Shortness of breath", "Abdominal pain"],
    "Pediatric": ["Fever", "Cough", "Ear ache"],
    "Psychiatry": ["Feeling anxious", "Difficulty sleeping", "Low mood"],
    "Dual Diagnosis": ["Chest pain and feels anxious", "Abdominal pain after drinking"],
}
# Hidden diagnosis implied by each chief complaint.
_DIAGNOSIS_BY_COMPLAINT = {
    "Chest pain": "Acute Myocardial Infarction",
    "Shortness of breath": "Pneumonia",
    "Abdominal pain": "Appendicitis",
    "Fever": "Viral Infection",
    "Cough": "Bronchitis",
    "Ear ache": "Otitis Media",
    "Feeling anxious": "Generalized Anxiety Disorder",
    "Difficulty sleeping": "Insomnia",
    "Low mood": "Major Depressive Disorder",
    "Chest pain and feels anxious": "Acute MI with Anxiety",
    "Abdominal pain after drinking": "Alcoholic Gastritis with Substance Use Disorder",
}


def start_case(case_type: str, state: "MedicalSimulatorState", rng: Optional[random.Random] = None) -> Tuple["MedicalSimulatorState", List[Tuple[Optional[str], Optional[str]]], str, str]:
    """Generate a random patient and reset *state* for a new case.

    Args:
        case_type: "General", "Pediatric", "Psychiatry" or "Dual Diagnosis".
            Other values use the General complaints and an adult age.
        state: Simulator state; overwritten in place.
        rng: Optional random source, for reproducible cases. Defaults to the
            module-level ``random`` functions.

    Returns:
        (state, chat history, formatted cost "$0.00", patient chart as Markdown)
    """
    rng = random if rng is None else rng

    complaint_options = _CHIEF_COMPLAINTS.get(case_type, _CHIEF_COMPLAINTS["General"])
    name = rng.choice(_PATIENT_NAMES)
    age = rng.randint(0, 17) if case_type == "Pediatric" else rng.randint(18, 80)
    gender = rng.choice(["Male", "Female"])
    chief_complaint = rng.choice(complaint_options)
    underlying_diagnosis = _DIAGNOSIS_BY_COMPLAINT.get(chief_complaint, "Unknown Condition")

    patient = {
        "name": name,
        "age": age,
        "gender": gender,
        "chief_complaint": chief_complaint,
        "history": "Patient history relevant to complaint.",
        "medications": "Current medications.",
        "allergies": "Known allergies (e.g., Penicillin).",
        "social_history": "Social history details.",
        "financial_status": "Insurance status.",
        "code_status": "Full Code",
        "language": "English",
    }

    # Reset state; vitals get a fresh dict so cases never share it.
    # Lowercase the complaint's first letter so it reads naturally mid-sentence.
    complaint_phrase = chief_complaint[:1].lower() + chief_complaint[1:]
    state.patient_profile = patient
    state.chat_history = [
        ("System", "New Case Started."),
        ("AI Patient", f"Hi, I'm {name}. I've been having {complaint_phrase}."),
    ]
    state.vitals = {"HR": 72.0, "BP_Sys": 120.0, "BP_Dia": 80.0, "Temp": 98.6, "O2_Sat": 98.0}
    state.total_cost = 0.0
    state.is_case_active = True
    state.underlying_diagnosis = underlying_diagnosis
    state.ordered_tests = {}

    # The chart is shown in gr.Markdown, where a bare "\n" does not break the
    # line. Two trailing spaces before the newline force a Markdown hard break,
    # so each field gets its own line.
    profile_str = "  \n".join(f"{k.replace('_', ' ').title()}: {v}" for k, v in patient.items())
    return state, state.chat_history, f"${state.total_cost:.2f}", profile_str


def handle_chat(user_input: str, history: List[Tuple[Optional[str], Optional[str]]], state: "MedicalSimulatorState") -> Tuple[List[Tuple[Optional[str], Optional[str]]], str]:
    """Answer a free-text message from the clinician as the simulated patient.

    Args:
        user_input: Text typed by the user. Blank or ``None`` input is ignored.
        history: Current chatbot contents, in tuple format. ``None`` counts as
            empty.
        state: Simulator state. Its ``chat_history`` is pointed at the updated
            history.

    Returns:
        (updated history, formatted total cost). The cost is unchanged because
        free-text chat is not billed.
    """
    history = history if history is not None else []
    cost_text = f"${state.total_cost:.2f}"
    if not state.is_case_active or not (user_input or "").strip():
        return history, cost_text

    # The model sees the conversation as it was before this message.
    ai_response = get_ai_response_medical(user_input, list(history), state.patient_profile, state.underlying_diagnosis)
    history.append((user_input, ai_response))

    # Tool actions re-render the chatbot from state.chat_history, so keep it in
    # step. Otherwise the next tool click would drop this exchange from the UI
    # and from the AI's context.
    # NOTE(review): state is not a Gradio output of this handler; this relies on
    # in-place mutation of the gr.State object persisting - confirm.
    state.chat_history = history
    return history, cost_text


def use_tool(tool_name: str, state: MedicalSimulatorState) -> Tuple[MedicalSimulatorState, List[Tuple[Optional[str], Optional[str]]], str]:
    """Perform one priced simulator action on the active case.

    The tool's COSTS price (0.0 if absent) is added to the running total, and
    the action plus its simulated outcome are appended to the case transcript.
    Tool names without a branch below still add their price but log nothing.
    When no case is active the state is returned untouched.

    Returns:
        (state, state.chat_history, formatted total cost such as "$25.00")
    """
    if not state.is_case_active:
        return state, state.chat_history, f"${state.total_cost:.2f}"

    price = COSTS.get(tool_name, 0.0)
    state.total_cost += price
    transcript = state.chat_history
    action_entry = ("System", f"[Action: {tool_name}, Cost: ${price:.2f}]")

    if tool_name == "ask_question":
        # Query the AI patient before logging the action, so the prompt sees
        # the transcript as it was before this action.
        patient_reply = get_ai_response_medical(
            "The user asks a general question to gather more history.",
            transcript,
            state.patient_profile,
            state.underlying_diagnosis,
        )
        transcript.extend([action_entry, ("AI Patient", patient_reply)])
    elif tool_name == "order_cbc":
        transcript.extend([action_entry, ("Lab", "CBC Ordered. Result pending...")])
        time.sleep(0.5)  # artificial lab turnaround
        transcript.append(("Lab", "CBC Result: WBC slightly elevated, otherwise unremarkable."))
    elif tool_name == "administer_med":
        drug = "Medication X"  # placeholder; no allergy check yet
        transcript.extend([
            ("System", f"[Action: {tool_name} - {drug}, Cost: ${price:.2f}]"),
            ("AI Patient", f"Okay, I took the {drug}."),
        ])
    elif tool_name == "physical_exam":
        transcript.extend([action_entry, ("System", "Physical Exam Performed. Findings noted.")])
    elif tool_name == "order_xray":
        transcript.extend([action_entry, ("Imaging", "X-Ray Ordered. Result pending...")])
        time.sleep(0.5)  # artificial imaging turnaround
        transcript.append(("Imaging", "Chest X-Ray Result: Normal lung fields, no acute findings. (Placeholder Image)"))

    return state, state.chat_history, f"${state.total_cost:.2f}"


def end_case(state: "MedicalSimulatorState") -> Tuple["MedicalSimulatorState", List[Tuple[Optional[str], Optional[str]]], str, str]:
    """Close the active case, if any, and return the values the UI refreshes.

    If a case is active, a "Case Ended by User." entry is logged and the case
    is marked inactive. Calling this when no case is active changes nothing.

    Returns:
        (state, chat history, formatted total cost, patient chart as Markdown).
        The chart shows the start hint when there is no patient profile.
    """
    if state.is_case_active:
        state.chat_history.append(("System", "Case Ended by User."))
        state.is_case_active = False
        # TODO: run the case evaluation here and surface its result.

    # gr.Markdown does not break lines on a bare "\n". Two trailing spaces
    # before the newline force a Markdown hard break, so each field gets its own line.
    profile = state.patient_profile or {}
    profile_str = "  \n".join(f"{k.replace('_', ' ').title()}: {v}" for k, v in profile.items())
    if not profile_str:
        # Keep the chart's initial hint rather than blanking it.
        profile_str = "Click 'Start New Case' to begin."
    return state, state.chat_history, f"${state.total_cost:.2f}", profile_str

# --- Sandbox Chat Handler ---
def handle_sandbox_chat(user_input: str, history: List[Tuple[Optional[str], Optional[str]]]) -> List[Tuple[Optional[str], Optional[str]]]:
    """Append one general-assistant exchange to the sandbox chat.

    Args:
        user_input: Text typed by the user. Blank or ``None`` input is ignored.
        history: Current sandbox chatbot contents. ``None`` counts as empty.

    Returns:
        The history with ``(user_input, reply)`` appended, or unchanged for
        blank input.
    """
    history = history if history is not None else []
    if not (user_input or "").strip():
        return history

    # The model sees the conversation as it was before this message.
    ai_response = get_ai_response_general(user_input, list(history))
    history.append((user_input, ai_response))
    return history


# --- Gradio Interface ---
with gr.Blocks(title="Advanced Medical Simulator") as demo:
    gr.Markdown("# Advanced Medical Simulator")

    with gr.Tab("Medical Simulation"):
        # State component holding the simulator state across interactions;
        # its initial value comes from calling the lambda.
        state = gr.State(lambda: MedicalSimulatorState())

        with gr.Row():
            with gr.Column(scale=2):
                # Chat Interface
                chatbot = gr.Chatbot(label="Patient Interaction", height=400, bubble_full_width=False)
                with gr.Row():
                    user_input = gr.Textbox(label="Your Action / Question", placeholder="Type your action or question here...", scale=4)
                    submit_btn = gr.Button("Submit", scale=1)

            with gr.Column(scale=1):
                # Patient Chart / Info
                patient_chart = gr.Markdown(label="Patient Chart", value="Click 'Start New Case' to begin.")
                cost_display = gr.Textbox(label="Total Cost", value="$0.00", interactive=False)

        with gr.Row():
            # Tool Panel
            with gr.Column():
                gr.Markdown("### Tools")
                with gr.Row():
                    ask_btn = gr.Button("Ask Question ($10)")
                    exam_btn = gr.Button("Physical Exam ($25)")
                with gr.Row():
                    cbc_btn = gr.Button("Order CBC ($50)")
                    xray_btn = gr.Button("Order X-Ray ($150)")
                with gr.Row():
                    med_btn = gr.Button("Administer Med ($30)")
                    end_btn = gr.Button("End Case", variant="stop")  # Red button for ending

        with gr.Row():
            # Case Controls. The case type comes from the dropdown, so the
            # button label does not name a specific type.
            start_case_btn = gr.Button("Start New Case")
            case_type_dropdown = gr.Dropdown(["General", "Psychiatry", "Pediatric", "Dual Diagnosis"], label="Case Type", value="General")

        # Event Handling for Medical Simulation Tab
        start_case_btn.click(
            fn=start_case,
            inputs=[case_type_dropdown, state],
            outputs=[state, chatbot, cost_display, patient_chart]
        )

        # Clicking Submit and pressing Enter in the textbox run the same
        # handler, then clear the textbox.
        for chat_trigger in (submit_btn.click, user_input.submit):
            chat_trigger(
                fn=handle_chat,
                inputs=[user_input, chatbot, state],
                outputs=[chatbot, cost_display]
            ).then(
                fn=lambda: "",  # Clear the input textbox after submission
                inputs=[],
                outputs=[user_input]
            )

        ask_btn.click(fn=lambda s: use_tool("ask_question", s), inputs=[state], outputs=[state, chatbot, cost_display])
        exam_btn.click(fn=lambda s: use_tool("physical_exam", s), inputs=[state], outputs=[state, chatbot, cost_display])
        cbc_btn.click(fn=lambda s: use_tool("order_cbc", s), inputs=[state], outputs=[state, chatbot, cost_display])
        xray_btn.click(fn=lambda s: use_tool("order_xray", s), inputs=[state], outputs=[state, chatbot, cost_display])
        med_btn.click(fn=lambda s: use_tool("administer_med", s), inputs=[state], outputs=[state, chatbot, cost_display])
        end_btn.click(fn=end_case, inputs=[state], outputs=[state, chatbot, cost_display, patient_chart])

    with gr.Tab("AI Sandbox"):
        sandbox_chatbot = gr.Chatbot(label="General AI Chat", height=500, bubble_full_width=False)
        with gr.Row():
            sandbox_input = gr.Textbox(label="Message", placeholder="Ask anything...", scale=4)
            sandbox_submit = gr.Button("Send", scale=1)

        # Send button and Enter key behave identically.
        for sandbox_trigger in (sandbox_submit.click, sandbox_input.submit):
            sandbox_trigger(
                fn=handle_sandbox_chat,
                inputs=[sandbox_input, sandbox_chatbot],
                outputs=[sandbox_chatbot]
            ).then(
                fn=lambda: "",  # Clear the input textbox after submission
                inputs=[],
                outputs=[sandbox_input]
            )

# Launch the app when this file is run as a script (e.g. `python app.py`).
# The guard stops a mere import of this module (for example by a reload
# runner) from starting a second server.
if __name__ == "__main__":
    demo.launch()