| | import gradio as gr |
| | import torch |
| | import soundfile as sf |
| | import os |
| | import numpy as np |
| |
|
| | import os |
| | import soundfile as sf |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from torch.utils.data import Dataset, DataLoader |
| | from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2ForSequenceClassification |
| | from collections import Counter |
| |
|
# All inference runs on this device; the checkpoint is mapped onto it too.
device = torch.device("cpu")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
# wav2vec2 encoder with a 2-way sequence-classification head; all of its
# weights are then overwritten by the (strict) state-dict load below.
model = Wav2Vec2ForSequenceClassification.from_pretrained(
    "facebook/wav2vec2-base-960h", num_labels=2
).to(device)
model_path = "dysarthria_classifier12.pth"

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
model.load_state_dict(torch.load(model_path, map_location=device))
# Inference only: put dropout etc. into eval mode once, right after loading.
model.eval()
| |
|
| | |
| | |
| | |
| |
|
| |
|
# Heading shown at the top of the Gradio page.
title = "Upload an mp3 file for Pseudobulbar Palsy (PP) detection! (Thai Language)"
# Instructions shown under the heading: the prompts users should read aloud.
# The Thai parentheticals note the articulation each prompt emphasizes
# (lips, palate, increasing loudness).
description = """
The model was trained on Thai audio recordings with the following sentences so please use these sentences: \n
ชาวไร่ตัดต้นสนทำท่อนซุง\n
ปูม้าวิ่งไปมาบนใบไม้ (เน้นใช้ริมฝีปาก)\n
อีกาคอยคาบงูคาบไก่ (เน้นใช้เพดานปาก)\n
เพียงแค่ฝนตกลงที่หน้าต่างในบางครา\n
“อาาาาาาาาาาา”\n
“อีีีีีีีีี”\n
“อาาาา” (ดังขึ้นเรื่อยๆ)\n
“อาา อาาา อาาาาา”\n

"""
| |
|
| | |
| |
|
| |
|
| |
|
| |
|
def _to_mono_16k(wav_data, sample_rate, target_rate=16000):
    """Return *wav_data* as a 1-D waveform sampled at *target_rate* Hz.

    soundfile returns multi-channel audio as a ``(frames, channels)`` array;
    such input is averaged down to a single channel. Audio recorded at any
    other rate is resampled by linear interpolation. Linear interpolation
    applies no anti-alias filter; it is a dependency-free approximation that
    puts speech on the time scale the 16 kHz model input assumes.
    """
    wav = np.asarray(wav_data)
    if wav.ndim > 1:
        wav = wav.mean(axis=1)
    if sample_rate != target_rate and wav.size > 0:
        new_len = max(1, int(round(wav.size * target_rate / sample_rate)))
        src_times = np.arange(wav.size) / sample_rate
        dst_times = np.arange(new_len) / target_rate
        wav = np.interp(dst_times, src_times, wav)
    return wav


def predict(file_upload, microphone):
    """Classify one recording as PP / not PP and return a message for the UI.

    Args:
        file_upload: Path to an uploaded audio file, or None.
        microphone: Path to a microphone recording, or None. When both are
            given, the microphone recording is used and a warning is prepended.

    Returns:
        The prediction text (optionally prefixed by a warning), or an error
        string when neither input was provided.
    """
    # Fixed model input length in samples: shorter clips are zero-padded,
    # longer ones truncated.
    max_length = 100000

    if microphone is None and file_upload is None:
        return "ERROR: You have to either use the microphone or upload an audio file"

    warn_output = ""
    if microphone is not None and file_upload is not None:
        warn_output = (
            "WARNING: You've uploaded an audio file and used the microphone. "
            "The recorded file from the microphone will be used and the uploaded audio will be discarded.\n\n"
        )
    file_path = microphone if microphone is not None else file_upload

    model.eval()
    with torch.no_grad():
        wav_data, sample_rate = sf.read(file_path)
        # The processor's sampling_rate argument only declares the rate; it
        # does not resample. Convert to mono 16 kHz first so that recordings
        # at other rates (e.g. 44.1/48 kHz microphone captures) or with
        # several channels are not misread.
        wav_data = _to_mono_16k(wav_data, sample_rate)
        inputs = processor(wav_data, sampling_rate=16000, return_tensors="pt", padding=True)

        input_values = inputs.input_values.squeeze(0)
        n_samples = input_values.shape[-1]
        if n_samples < max_length:
            padding = torch.zeros((max_length - n_samples,))
            input_values = torch.cat([input_values, padding], dim=-1)
        else:
            input_values = input_values[:max_length]
        input_values = input_values.unsqueeze(0).to(device)

        logits = model(input_values=input_values).logits.squeeze()
        predicted_class_id = torch.argmax(logits, dim=-1).item()

    # Class id 1 is reported as PP; any other id as not PP.
    if predicted_class_id == 1:
        return warn_output + "You probably have PP"
    return warn_output + "You probably don't have PP"
# Two optional audio sources, an uploaded file and a microphone recording,
# both handed to `predict` as file paths; the result is shown as plain text.
upload_input = gr.inputs.Audio(source="upload", type="filepath", optional=True)
mic_input = gr.inputs.Audio(source="microphone", type="filepath", optional=True)

demo = gr.Interface(
    fn=predict,
    inputs=[upload_input, mic_input],
    outputs="text",
    title=title,
    description=description,
)
# Start the web UI (runs at module level, as before).
demo.launch()
| |
|
| |
|
| | |
| | |