import gradio as gr
from transformers import AutoModelForCTC, AutoFeatureExtractor, AutoTokenizer
import torch
import numpy as np
import re
import warnings
import librosa

warnings.filterwarnings("ignore")

MODEL_ID = "google/medasr"

# Populated by load_model_with_token() once the user supplies a token.
model = None
feature_extractor = None
tokenizer = None
def normalize_audio(audio, target_rms=0.1):
    """Scale the waveform to a target RMS level, then clip to [-1, 1].

    Scaling to a moderate target (rather than dividing by the raw RMS,
    which would force the RMS to 1.0) keeps peaks inside [-1, 1] so the
    final clip does not distort the signal.
    """
    rms = np.sqrt(np.mean(audio ** 2))
    if rms > 0:
        audio = audio * (target_rms / rms)
    audio = np.clip(audio, -1.0, 1.0)
    return audio
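# Illustrative check (not part of the app; the numbers are assumptions):
# a quiet recording is raised to the target level instead of being
# driven into the clip.
#   >>> x = 0.01 * np.random.randn(16000).astype(np.float32)
#   >>> float(np.sqrt(np.mean(normalize_audio(x) ** 2)))  # ~0.1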
def remove_silence(audio, sample_rate, threshold=0.01):
    """Trim leading and trailing silence below an amplitude threshold."""
    energy = np.abs(audio)
    above_threshold = energy > threshold
    if not np.any(above_threshold):
        return audio  # all silence: return unchanged rather than empty
    start = np.where(above_threshold)[0][0]
    end = np.where(above_threshold)[0][-1]
    buffer = int(0.1 * sample_rate)  # keep 100 ms of context on each side
    start = max(0, start - buffer)
    end = min(len(audio), end + buffer)
    return audio[start:end]
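# Illustrative check (not part of the app): 0.5 s of silence, 1 s of
# "speech", 0.5 s of silence at 16 kHz keeps the speech plus the 100 ms
# buffer on each side, i.e. about 1.2 s of samples.
#   >>> clip = np.concatenate([np.zeros(8000), 0.5 * np.ones(16000), np.zeros(8000)])
#   >>> len(remove_silence(clip, 16000)) / 16000  # ~1.2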
def load_model_with_token(hf_token):
    global model, feature_extractor, tokenizer
    if not hf_token or not hf_token.strip():
        # Keep the button interactive so the user can retry with a token.
        return gr.update(interactive=True, value="❌ Token cannot be empty!"), gr.update(interactive=False)
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print("🔄 Loading model components...")
        print(f"📱 Device: {device}")
        feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID, token=hf_token.strip())
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=hf_token.strip())
        model = AutoModelForCTC.from_pretrained(MODEL_ID, token=hf_token.strip()).to(device)
        model.eval()
        print(f"✅ Loaded: {type(feature_extractor)}, {type(tokenizer)}, {type(model)}")
        return gr.update(interactive=False, value="✅ Model Loaded Successfully!"), gr.update(interactive=True)
    except Exception as e:
        print(f"Error loading model: {e}")
        import traceback
        traceback.print_exc()
        return gr.update(interactive=True, value=f"❌ Error: {str(e)}"), gr.update(interactive=False)
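# Quick smoke test outside the UI (hypothetical placeholder token --
# substitute a real HF token that has access to the model):
#   >>> load_model_with_token("hf_xxxxxxxx")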
def transcribe_audio(audio_input):
    global model, feature_extractor, tokenizer
    if audio_input is None:
        return "⚠️ Please upload or record audio."
    if model is None or feature_extractor is None:
        return "❌ Please load the model first!"
    try:
        # 1. Unpack the (sample_rate, waveform) tuple from gr.Audio
        sample_rate, waveform = audio_input
        # 2. Down-mix stereo to mono
        if waveform.ndim == 2:
            waveform = waveform[:, 0]
        # 3. Convert to float32 in [-1, 1]
        if waveform.dtype == np.int16:
            waveform = waveform.astype(np.float32) / 32768.0
        elif waveform.dtype != np.float32:
            waveform = waveform.astype(np.float32)
        # 4. RMS normalization
        waveform = normalize_audio(waveform)
        # 5. Trim leading/trailing silence
        waveform = remove_silence(waveform, sample_rate)
        # 6. Sanity-check the duration
        duration = len(waveform) / sample_rate
        if duration < 0.1:
            return "⚠️ Audio is too short."
        if duration > 60:
            return "⚠️ Audio is too long."
        # 7. Resample to the 16 kHz the model expects
        if sample_rate != 16000:
            waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)
            sample_rate = 16000
        # 8. Feature extraction
        inputs = feature_extractor(
            waveform,
            sampling_rate=sample_rate,
            return_tensors="pt",
        )
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        # Sanity check: the extractor's feature key varies by model
        # ('input_values' or 'input_features'); make sure one exists.
        if not any(isinstance(v, torch.Tensor) and v.ndim > 1 for v in inputs.values()):
            return "❌ Error: Could not extract audio features."
        # 9. CTC decoding: a CTC head emits one logit vector per audio
        #    frame, so we take the per-frame argmax (greedy decoding).
        #    CTC models do not support seq2seq-style model.generate().
        with torch.no_grad():
            logits = model(**inputs).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        # 10. Decode token ids to text (the CTC tokenizer collapses
        #     repeated ids and removes blank tokens)
        transcription = tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        # 11. Post-processing: collapse whitespace
        transcription = re.sub(r'\s+', ' ', transcription).strip()
        return transcription if transcription else "⚠️ No speech detected."
    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"❌ Transcription error: {str(e)}"
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🏥 MedASR - Medical Speech Recognition")
    gr.Markdown("Optimized for medical dictation, with greedy CTC decoding.")
    gr.Markdown("---")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## 🔑 Authentication")
            hf_token = gr.Textbox(label="HuggingFace Token", type="password", placeholder="hf_...")
            load_model_btn = gr.Button("🔑 Load Model", variant="primary", size="lg")
            gr.Markdown("## 📝 Tips")
            gr.Markdown("""
            - Speak **clearly and slowly**
            - Use **medical terms**
            - Short audio (3-10s) is best
            - Quiet environment
            """)
        with gr.Column(scale=2):
            gr.Markdown("## 🎙️ Input & Result")
            audio_input = gr.Audio(sources=["microphone", "upload"], type="numpy")
            with gr.Row():
                transcribe_btn = gr.Button("🔄 Transcribe", variant="secondary", size="lg", interactive=False)
                clear_btn = gr.Button("🗑️ Clear", variant="stop", size="lg")
            output_text = gr.Textbox(label="Result", lines=12, placeholder="...")
            audio_info = gr.Textbox(label="Info", lines=2, interactive=False)
    def transcribe_wrapper(audio_in):
        res = transcribe_audio(audio_in)
        info = "Status: Success" if "❌" not in res and "⚠️" not in res else "Status: Check result"
        return res, info
    load_model_btn.click(
        fn=load_model_with_token,
        inputs=[hf_token],
        outputs=[load_model_btn, transcribe_btn]
    )
    transcribe_btn.click(
        fn=transcribe_wrapper,
        inputs=[audio_input],
        outputs=[output_text, audio_info]
    )
    clear_btn.click(
        fn=lambda: ("", "Ready"),
        outputs=[output_text, audio_info]
    )

if __name__ == "__main__":
    demo.launch()
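# On Hugging Face Spaces, launch() is sufficient; for local testing,
# demo.launch(share=True) also prints a temporary public URL.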