# MedASR / app.py
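"""Gradio demo for medical speech recognition with the MedASR model.

Flow: supply a HuggingFace token to load the model, then record or
upload audio; the waveform is converted to mono float32, RMS-normalized,
silence-trimmed, resampled to 16 kHz, and decoded with beam search.
"""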
import re
import traceback
import warnings

import gradio as gr
import librosa
import numpy as np
import torch
from transformers import AutoModelForCTC, AutoFeatureExtractor, AutoTokenizer

warnings.filterwarnings("ignore")
MODEL_ID = "google/medasr"
model = None
feature_extractor = None
tokenizer = None
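# Model components are loaded lazily in load_model_with_token() once the
# user supplies a HuggingFace token, so the app can start without credentials.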
def normalize_audio(audio, target_rms=0.1):
    """Scale the waveform to a target RMS level, then clip to [-1, 1].

    Scaling straight to unit RMS would push most speech peaks past 1.0
    and the clip below would distort them; 0.1 is a conservative level
    chosen to leave headroom.
    """
    rms = np.sqrt(np.mean(audio ** 2))
    if rms > 0:
        audio = audio * (target_rms / rms)
    audio = np.clip(audio, -1.0, 1.0)
    return audio
def remove_silence(audio, sample_rate, threshold=0.01):
    """Trim leading and trailing silence below an amplitude threshold."""
    energy = np.abs(audio)
    above_threshold = energy > threshold
    if not np.any(above_threshold):
        return audio
    start = np.where(above_threshold)[0][0]
    end = np.where(above_threshold)[0][-1]
    buffer = int(0.1 * sample_rate)  # keep 100 ms of context on each side
    start = max(0, start - buffer)
    end = min(len(audio), end + buffer)
    return audio[start:end]
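# Note: the trimmer above only strips the edges; pauses inside the
# utterance are kept intact. The fixed 0.01 threshold assumes the
# waveform was already RMS-normalized by normalize_audio().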
def load_model_with_token(hf_token):
    global model, feature_extractor, tokenizer
    if not hf_token or not hf_token.strip():
        # Keep the load button clickable so the user can retry with a token.
        return gr.update(interactive=True, value="❌ Token cannot be empty!"), gr.update(interactive=False)
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print("🔄 Loading model components...")
        print(f"📱 Device: {device}")
        feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID, token=hf_token.strip())
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=hf_token.strip())
        model = AutoModelForCTC.from_pretrained(MODEL_ID, token=hf_token.strip()).to(device)
        model.eval()
        print(f"✅ Loaded: {type(feature_extractor)}, {type(tokenizer)}, {type(model)}")
        return gr.update(interactive=False, value="✅ Model Loaded Successfully!"), gr.update(interactive=True)
    except Exception as e:
        print(f"Error loading model: {e}")
        traceback.print_exc()
        return gr.update(interactive=True, value=f"❌ Error: {str(e)}"), gr.update(interactive=False)
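# load_model_with_token returns a pair of gr.update objects: the first
# re-labels/locks the Load button, the second enables or disables the
# Transcribe button depending on whether loading succeeded.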
def transcribe_audio(audio_input):
    global model, feature_extractor, tokenizer
    if audio_input is None:
        return "⚠️ Please upload or record audio."
    if model is None or feature_extractor is None:
        return "❌ Please load the model first!"
    try:
        # 1. Unpack the (sample_rate, waveform) tuple from gr.Audio
        sample_rate, waveform = audio_input
        # 2. Convert to mono (keep the first channel)
        if waveform.ndim == 2:
            waveform = waveform[:, 0]
        # 3. Convert to float32 (rescale 16-bit PCM to [-1, 1])
        if waveform.dtype == np.int16:
            waveform = waveform.astype(np.float32) / 32768.0
        elif waveform.dtype != np.float32:
            waveform = waveform.astype(np.float32)
        # 4. RMS normalization
        waveform = normalize_audio(waveform)
        # 5. Trim leading/trailing silence
        waveform = remove_silence(waveform, sample_rate)
        # 6. Check duration
        duration = len(waveform) / sample_rate
        if duration < 0.1:
            return "⚠️ Audio is too short."
        if duration > 60:
            return "⚠️ Audio is too long."
        # 7. Resample to the model's expected 16 kHz
        if sample_rate != 16000:
            waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)
            sample_rate = 16000
        # 8. Feature extraction
        inputs = feature_extractor(
            waveform,
            sampling_rate=sample_rate,
            return_tensors="pt",
        )
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        # [Fix] Auto-detect which key holds the feature tensor (it may be
        # 'input_features' or 'input_values' depending on the extractor),
        # explicitly skipping 'attention_mask' to get the real input tensor.
        input_tensor = None
        for key, val in inputs.items():
            if key == "attention_mask":
                continue
            if isinstance(val, torch.Tensor) and val.ndim > 1:
                input_tensor = val
                break
        if input_tensor is None:
            return "❌ Error: Could not extract audio features."
        # Read the feature extractor's stride if it exposes one; default to 4
        stride = 4
        if hasattr(feature_extractor, 'stride'):
            s = feature_extractor.stride
            stride = s[0] if isinstance(s, (list, tuple)) else s
        # Derive max_length dynamically from the input length
        max_length = input_tensor.shape[1] // stride + 50
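        # Decoding below relies on model.generate(); this assumes the
        # checkpoint ships a generation-capable head. num_beams=8 trades
        # latency for accuracy (num_beams=1 would be greedy decoding).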
        # 9. Beam-search decoding
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=max_length,
                num_beams=8,
            )
        # 10. Decode token IDs to text
        transcription = tokenizer.batch_decode(outputs.tolist(), skip_special_tokens=True)[0]
        # 11. Post-process: trim and collapse whitespace
        transcription = transcription.strip()
        transcription = re.sub(r'\s+', ' ', transcription)
        return transcription if transcription else "⚠️ No speech detected."
    except Exception as e:
        traceback.print_exc()
        return f"❌ Transcription error: {str(e)}"
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🏥 MedASR - Medical Speech Recognition")
    gr.Markdown("Optimized for medical dictation with Beam Search decoding.")
    gr.Markdown("---")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## 🔑 Authentication")
            hf_token = gr.Textbox(label="HuggingFace Token", type="password", placeholder="hf_...")
            load_model_btn = gr.Button("🔑 Load Model", variant="primary", size="lg")
            gr.Markdown("## 📝 Tips")
            gr.Markdown("""
- Speak **clearly and slowly**
- Use **medical terms**
- Short audio (3-10s) is best
- Quiet environment
""")
        with gr.Column(scale=2):
            gr.Markdown("## 🎙️ Input & Result")
            audio_input = gr.Audio(sources=["microphone", "upload"], type="numpy")
            with gr.Row():
                transcribe_btn = gr.Button("🔄 Transcribe", variant="secondary", size="lg", interactive=False)
                clear_btn = gr.Button("🗑️ Clear", variant="stop", size="lg")
            output_text = gr.Textbox(label="Result", lines=12, placeholder="...")
            audio_info = gr.Textbox(label="Info", lines=2, interactive=False)

    def transcribe_wrapper(audio_in):
        res = transcribe_audio(audio_in)
        info = "Status: Success" if "❌" not in res and "⚠️" not in res else "Status: Check result"
        return res, info
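    # Event wiring: loading the model flips the button states (see the
    # gr.update pair returned by load_model_with_token), which enables
    # transcription only once a model is actually in memory.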
    load_model_btn.click(
        fn=load_model_with_token,
        inputs=[hf_token],
        outputs=[load_model_btn, transcribe_btn],
    )
    transcribe_btn.click(
        fn=transcribe_wrapper,
        inputs=[audio_input],
        outputs=[output_text, audio_info],
    )
    clear_btn.click(
        fn=lambda: ("", "Ready"),
        outputs=[output_text, audio_info],
    )
if __name__ == "__main__":
    demo.launch()
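# A sketch of the runtime dependencies this Space assumes (versions
# unpinned here; treat as an assumption, not a spec):
#   pip install gradio transformers torch librosa numpy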