# AdnanSadi's picture
# Update app.py
# 49c815e verified
# -*- coding: utf-8 -*-
"""
@author: adnan-sadi
"""
import os
from functools import lru_cache

import gradio as gr
import numpy as np
import torch
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import set_seed
# Seed the random number generators through transformers.set_seed.
seed = 17
set_seed(seed)

# Authenticate with the Hugging Face Hub only when a token is configured.
# Calling login() without a token falls back to an interactive prompt, which
# cannot be answered when the app runs headless.
token = os.getenv("hf_token")
if token:
    login(token)
# Disease labels; the position of a label in this list is its class id.
labels = [
    'URTI',
    'HIV (initial infection)',
    'Pneumonia',
    'Chronic rhinosinusitis',
    'Viral pharyngitis',
    'Anemia',
    'Atrial fibrillation',
    'Allergic sinusitis',
    'Laryngospasm',
    'Cluster headache',
    'Anaphylaxis',
    'Spontaneous pneumothorax',
    'Acute pulmonary edema',
    'Tuberculosis',
    'Myasthenia gravis',
    'Panic attack',
    'Scombroid food poisoning',
    'Epiglottitis',
    'Inguinal hernia',
    'Boerhaave',
    'Pancreatic neoplasm',
    'Bronchitis',
    'SLE',
    'Acute laryngitis',
    'Unstable angina',
    'Bronchiectasis',
    'Possible NSTEMI / STEMI',
    'Chagas',
    'Localized edema',
    'Sarcoidosis',
    'Spontaneous rib fracture',
    'GERD',
    'Bronchospasm / acute asthma exacerbation',
    'Acute COPD exacerbation / infection',
    'Guillain-Barré syndrome',
    'Influenza',
    'Pulmonary embolism',
    'Stable angina',
    'Pericarditis',
    'Acute rhinosinusitis',
    'Whooping cough',
    'Myocarditis',
    'Acute dystonic reactions',
    'Pulmonary neoplasm',
    'Acute otitis media',
    'PSVT',
    'Croup',
    'Ebola',
    'Bronchiolitis',
]

# Two-way lookup between label names and their integer ids.
id2label = dict(enumerate(labels))
label2id = {name: index for index, name in id2label.items()}
# Maps every model name offered in the demo to its Hugging Face Hub checkpoint.
# "Choose" is the dropdown placeholder and resolves to the default checkpoint.
_MODEL_CHECKPOINTS = {
    "bert-base" : "AdnanSadi/Bert_DDXPlus_1",
    "distilbert-base" : "AdnanSadi/DistilBert_DDXPlus_2",
    "roberta-base": "AdnanSadi/Roberta_DDXPlus_1",
    "bds-bert-base": "AdnanSadi/BioDisSumBert_DDXPlus_1",
    "bert-sp-mtd": "AdnanSadi/Bert_DDXPlus_2",
    "distilbert-sp-mtd": "AdnanSadi/DistilBert_DDXPlus_3",
    "roberta-sp-mtd" : "AdnanSadi/Roberta_DDXPlus_2",
    "bds-bert-sp-mtd": "AdnanSadi/BioDisSumBert_DDXPlus_2",
    "Choose": "AdnanSadi/Roberta_DDXPlus_2",
}


# Bounded so that only a few checkpoints are kept in memory at once; keyed on
# the checkpoint path so names sharing a checkpoint share one cache entry.
@lru_cache(maxsize=4)
def _load_checkpoint(model_path):
    """Load (and memoize) the tokenizer/model pair stored at *model_path*."""
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    return tokenizer, model


# Function for getting model and tokenizer
def get_model_and_tokenizer(model_name):
    """Return the ``(tokenizer, model)`` pair for a demo model name.

    Loaded pairs are cached, so repeated requests for the same checkpoint do
    not reload it.

    Args:
        model_name: One of the keys of ``_MODEL_CHECKPOINTS``. An empty value
            (``None`` or ``""``) is treated like the "Choose" placeholder.

    Returns:
        A ``(tokenizer, model)`` tuple loaded with ``AutoTokenizer`` and
        ``AutoModelForSequenceClassification``.

    Raises:
        KeyError: If ``model_name`` is not a known model name.
    """
    if not model_name:
        # Nothing selected: use the same default as the "Choose" placeholder.
        model_name = "Choose"
    return _load_checkpoint(_MODEL_CHECKPOINTS[model_name])
# Diagnosis generation function
def get_diagnosis(age, sex, medhist, symptoms, model_name, threshold):
    """Return a numbered differential diagnosis for one patient report.

    The patient details are assembled into a single text, classified with the
    selected model, and every label whose sigmoid probability is at least
    ``threshold`` is listed together with that probability.

    Args:
        age: Patient age, interpolated into the text as-is.
        sex: Patient sex, interpolated into the text as-is.
        medhist: Medical-history text.
        symptoms: Symptom text.
        model_name: Model name passed to ``get_model_and_tokenizer``.
        threshold: Minimum probability for a label to be reported.

    Returns:
        One line per reported label, in label-id order, formatted as
        ``"<n>. <label> (Conf. Score: <p>)"`` with ``p`` rounded to four
        decimals. The string is empty when no label reaches the threshold.
    """
    # get tokenizer and model
    tokenizer, model = get_model_and_tokenizer(model_name)

    # assemble input text
    text = (
        "The following is a list of medical history and symptoms described by a patient."
        f"\nSex: {sex}, Age: {age}"
        f"\nMedical History:\n{medhist}"
        f"\nSymptoms:\n{symptoms}"
    )
    encoding = tokenizer(text, truncation=True, return_tensors="pt")

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        logits = model(**encoding).logits

    # Each label gets its own independent sigmoid probability.
    probs = torch.sigmoid(logits.squeeze().cpu())

    # Ids of the labels at or above the threshold, in ascending id order.
    selected_ids = torch.nonzero(probs >= threshold).flatten().tolist()

    # turn predicted ids into numbered lines with label names and scores
    return "".join(
        f"{rank}. {id2label[idx]} (Conf. Score: {round(probs[idx].item(), 4)})\n"
        for rank, idx in enumerate(selected_ids, start=1)
    )
# defining the examples for the demo
# Each sample is [gender, age, medical history, symptoms, model name].
sample_1 = [
    "Male",
    "90",
    (
        "- I work in a daycare.\n"
        "- I smoke cigarettes.\n"
        "- I have had a cold in the last 2 weeks.\n"
        "- I have not traveled anywhere in the last 4 weeks."
    ),
    (
        "- I feel pain.\n"
        "- The pain is sensitive.\n"
        "- I feel pain in my tonsil(R).\n"
        "- I feel pain in my thyroid cartilage.\n"
        "- I feel pain in my palate.\n"
        "- I feel pain in my pharynx.\n"
        "- I feel pain in my under the jaw.\n"
        "- On a scale of 0-10, the pain intensity is 4.\n"
        "- The pain does not radiate to anywhere.\n"
        "- On a scale of 0-10, the pain's location precision is 4.\n"
        "- On a scale of 0-10, the pace at which the pain appear is 2.\n"
        "- I have a cough.\n"
        "- I have noticed a change in the tone of my voice."
    ),
    "bds-bert-sp-mtd",
]
# Second example: [gender, age, medical history, symptoms, model name].
sample_2 = [
    "Female",
    "16",
    (
        "- I feel anxious.\n"
        "- I regularly drink coffee or tea.\n"
        "- I consume energy drinks regularly.\n"
        "- I regularly take stimulant drugs.\n"
        "- I have not traveled anywhere in the last 4 weeks."
    ),
    (
        "- I feel pain.\n"
        "- The pain is burning.\n"
        "- I feel pain in my back of head.\n"
        "- I feel pain in my forehead.\n"
        "- I feel pain in my temple(R).\n"
        "- On a scale of 0-10, the pain intensity is 6.\n"
        "- The pain does not radiate to anywhere.\n"
        "- On a scale of 0-10, the pain's location precision is 2.\n"
        "- On a scale of 0-10, the pace at which the pain appear is 6.\n"
        "- I am experiencing shortness of breath or difficulty breathing in a significant way.\n"
        "- I feel lightheaded, dizzy, and about to faint.\n"
        "- I feel palpitations."
    ),
    "bert-base",
]
# Third example: [gender, age, medical history, symptoms, model name].
sample_3 = [
    "Male",
    "57",
    (
        "- Some members of my family have been diagnosed with myasthenia gravis.\n"
        "- I have not traveled anywhere in the last 4 weeks."
    ),
    (
        "- I have pain or weakness in my jaw.\n"
        "- I have difficulty articulating words/speaking.\n"
        "- I have a feeling of discomfort/blockage when swallowing.\n"
        "- I am experiencing shortness of breath or difficulty breathing in a significant way.\n"
        "- I feel weakness in both arms and/or both legs."
    ),
    "roberta-sp-mtd",
]
# creating the demo
# Choices for the model dropdown; "Choose" is the placeholder entry.
model_names = [
    "Choose",
    "bert-base",
    "distilbert-base",
    "roberta-base",
    "bds-bert-base",
    "bert-sp-mtd",
    "distilbert-sp-mtd",
    "roberta-sp-mtd",
    "bds-bert-sp-mtd",
]
# Build the Gradio Blocks interface for the demo.
# NOTE(review): the nesting of the `with` blocks below cannot be read from this
# copy (indentation is missing) — restore it from the original file.
demo = gr.Blocks()
with demo:
# Introductory Markdown text for the page.
gr.Markdown("""
# Differential Diagnosis Tool
This demo contains the models described in paper [Automatic Differential Diagnosis using Transformer-Based Multi-Label Sequence Classification](https://doi.org/10.48550/arXiv.2408.15827).
The models were trained to provide a differential diagnosis based on the medical history and symptoms described by a patient.
Please fill out the following form with relevant information, including the age, sex, medical history, and symptoms of the patient. For best results, please provide the symptoms and medical history information as a list.
#### For reference, please look over some of the examples provided at the bottom.
### Acknowledgments: This project was funded by North South University CTRG.
""")
# Input form: age and gender, free-text medical history and symptoms,
# model choice and confidence threshold.
with gr.Row():
with gr.Column():
with gr.Row():
age = gr.components.Number(label="Age", interactive= True)
gender = gr.components.Dropdown(["Choose", "Male", "Female"], label="Gender", value="Choose", interactive=True)
medhist = gr.components.Textbox(label="Medical History", info="a list of patient's medical history", lines=5, interactive=True)
symptoms = gr.components.Textbox(label="Symptoms", info="a list of patient's symptoms", lines=5, interactive=True)
model_name = gr.components.Dropdown(model_names, label="Model", info="Defaults to Roberta",
value="Choose", interactive=True)
threshold_value = gr.components.Slider(0, 1, value=0.5, label="Model Confidence Threshold",
info="Choose between 0 and 1", interactive=True)
# Action buttons.
with gr.Row():
clear_btn = gr.Button("Clear")
submit_btn = gr.Button("Submit", variant="primary")
# Read-only textbox that receives the text returned by get_diagnosis.
output_box = gr.Textbox(label="Differential Diagnosis Based on Patient Report:", lines=5, interactive=False)
gr.Markdown("## Patient Examples")
# Example rows; each sample supplies values for the components in the order
# [gender, age, medhist, symptoms, model_name].
gr.Examples(
[sample_1, sample_2, sample_3],
[gender, age, medhist, symptoms, model_name],
)
# Clear writes fixed values to every input and to the output box.
# NOTE(review): the threshold is reset to 0.2 although the slider's initial
# value is 0.5 — confirm which one is intended.
clear_btn.click(lambda: [0, "Choose", None,None,"Choose",0.2, None],
outputs=[age, gender, medhist, symptoms, model_name, threshold_value, output_box])
# Submit passes the form values to get_diagnosis and shows its return value in
# the output box.
submit_btn.click(fn = get_diagnosis, inputs=[age, gender, medhist, symptoms, model_name, threshold_value],
outputs=output_box)
# Start the app with share and debug both set to False.
demo.launch(share=False, debug=False)