import gradio as gr
import whisper
from whisper.tokenizer import get_tokenizer

import classify
# pull in emotion detection
# --- Add element for specification
# pull in text classification
# --- Add custom labels
# --- Associate labels with radio elements
# add logic to initiate mock notification when detected
# pull in misophonia-specific model

# Static classes for now, but it would be best to have the user select from
# multiple lists, and to enter their own.
class_options = {
    "misophonia": ["chewing", "breathing", "mouthsounds", "popping", "sneezing", "yawning", "smacking", "sniffling", "panting"],
}
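# Further trigger categories (and eventually user-defined ones) could be added
# here as new key/value pairs; the UI below is currently hard-coded to the
# misophonia list.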
model = whisper.load_model("large")
tokenizer = get_tokenizer("large")
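# NOTE: the "large" checkpoint is loaded once at import time and is roughly
# 3 GB on disk; a smaller checkpoint (e.g. "base") may be easier to iterate with.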
def slider_logic(slider):
    # Map the 1-5 sensitivity setting to a score threshold: higher sensitivity
    # means a lower threshold, so the tool intervenes more readily.
    thresholds = {
        1: 0.45,
        2: 0.35,
        3: 0.25,
        4: 0.15,
        5: 0.05,  # assumed: continues the 0.10-step pattern so 5 is most sensitive
    }
    # Gradio delivers the slider value as a float; fall back to the middle
    # setting if an unexpected value comes through.
    return thresholds.get(int(slider), 0.25)
# Callback for the Gradio interface: scores the uploaded audio file against
# the trigger sounds the user selected.
def classify_toxicity(audio_file, selected_sounds, slider):
    threshold = slider_logic(slider)

    classify_anxiety = "misophonia"
    class_names = class_options.get(classify_anxiety, [])
    print("class names ", class_names, "classify_anxiety ", classify_anxiety)
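    # Zero-shot scoring: each class name is scored by the average
    # log-probability of its tokens given the audio, minus Whisper's text-only
    # ("internal LM") log-probability, so commonplace phrases are not favored
    # a priori.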
    internal_lm_average_logprobs = classify.calculate_internal_lm_average_logprobs(
        model=model,
        class_names=class_names,
        tokenizer=tokenizer,
    )
    audio_features = classify.calculate_audio_features(audio_file, model)
    average_logprobs = classify.calculate_average_logprobs(
        model=model,
        audio_features=audio_features,
        class_names=class_names,
        tokenizer=tokenizer,
    )
    average_logprobs -= internal_lm_average_logprobs
    scores = average_logprobs.softmax(-1).tolist()
    class_score_dict = {class_name: score for class_name, score in zip(class_names, scores)}
    # Keep only the scores for the sounds the user selected
    matching_label_score = {}
    for selected_class_name in selected_sounds:
        if selected_class_name in class_score_dict:
            matching_label_score[selected_class_name] = class_score_dict[selected_class_name]

    # max() on an empty dict raises ValueError, so guard against the case
    # where none of the selected sounds received a score
    affirm = " "
    if matching_label_score:
        highest_float = float(max(matching_label_score.values()))
        print("highest float ", highest_float)
        print("threshold ", threshold)
        if highest_float > threshold:
            affirm = "Threshold Exceeded, initiate intervention"

    return class_score_dict, affirm
with gr.Blocks() as iface:
    with gr.Column():
        miso_sounds = gr.CheckboxGroup(["chewing", "breathing", "mouthsounds", "popping", "sneezing", "yawning", "smacking", "sniffling", "panting"])
        sense_slider = gr.Slider(minimum=1, maximum=5, step=1.0, label="How readily do you want the tool to intervene? 1 = in extreme cases and 5 = at every opportunity")
    with gr.Column():
        aud_input = gr.Audio(source="upload", type="filepath", label="Upload Audio File")
        submit_btn = gr.Button("Run")
    with gr.Column():
        out_text = gr.Textbox(label="Intervention")
        out_class = gr.Label()
    submit_btn.click(fn=classify_toxicity, inputs=[aud_input, miso_sounds, sense_slider], outputs=[out_class, out_text])

iface.launch()