Cyber-Blacat committed on
Commit 3d78f04 · verified · 1 Parent(s): 9179741

Update app.py

Files changed (1):
  1. app.py +101 -67
app.py CHANGED
@@ -3,6 +3,7 @@ from transformers import AutoModelForCTC, AutoFeatureExtractor, AutoTokenizer
 import torch
 import numpy as np
 import warnings
+import librosa
 
 warnings.filterwarnings("ignore")
 
@@ -12,6 +13,27 @@ model = None
 feature_extractor = None
 tokenizer = None
 
+def normalize_audio(audio):
+    """RMS normalization."""
+    rms = np.sqrt(np.mean(audio ** 2))
+    if rms > 0:
+        audio = audio / rms
+    audio = np.clip(audio, -1.0, 1.0)
+    return audio
+
+def remove_silence(audio, sample_rate, threshold=0.01):
+    """Trim leading and trailing silence."""
+    energy = np.abs(audio)
+    above_threshold = energy > threshold
+    if not np.any(above_threshold):
+        return audio
+    start = np.where(above_threshold)[0][0]
+    end = np.where(above_threshold)[0][-1]
+    buffer = int(0.1 * sample_rate)
+    start = max(0, start - buffer)
+    end = min(len(audio), end + buffer)
+    return audio[start:end]
+
 def load_model_with_token(hf_token):
     global model, feature_extractor, tokenizer
 
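The two preprocessing helpers this hunk introduces are easy to sanity-check in isolation. A minimal sketch, with the helper bodies condensed slightly from the diff and a synthetic test tone that is purely illustrative:

```python
import numpy as np

# Helper bodies condensed from this hunk.
def normalize_audio(audio):
    rms = np.sqrt(np.mean(audio ** 2))
    if rms > 0:
        audio = audio / rms
    return np.clip(audio, -1.0, 1.0)

def remove_silence(audio, sample_rate, threshold=0.01):
    above = np.abs(audio) > threshold
    if not np.any(above):
        return audio
    idx = np.where(above)[0]
    pad = int(0.1 * sample_rate)  # keep a 100 ms buffer on each side
    return audio[max(0, idx[0] - pad):min(len(audio), idx[-1] + pad)]

# 1 s silence + 1 s quiet 440 Hz tone + 1 s silence, at 16 kHz.
sr = 16000
t = np.linspace(0, 1, sr, endpoint=False)
tone = (0.05 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)
clip = np.concatenate([np.zeros(sr, np.float32), tone, np.zeros(sr, np.float32)])

trimmed = remove_silence(normalize_audio(clip), sr)
print(len(clip) / sr, "->", len(trimmed) / sr)  # 3.0 -> ~1.2 s
```

One property worth noting: after scaling to unit RMS, peaks routinely exceed 1.0 (a unit-RMS sine already peaks at about 1.41), so the clip step flattens the loudest samples.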
@@ -21,6 +43,7 @@ def load_model_with_token(hf_token):
     try:
         device = "cuda" if torch.cuda.is_available() else "cpu"
         print("🔄 Loading model components...")
+        print(f"📱 Device: {device}")
 
         feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID, token=hf_token.strip())
         tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=hf_token.strip())
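For reference, the load flow here follows the standard transformers pattern for gated checkpoints, where token= authenticates against the gated repo. A condensed sketch; the load_components wrapper, the .to(device) move, and the .eval() call are assumptions, and MODEL_ID's actual value is defined elsewhere in app.py:

```python
import torch
from transformers import AutoModelForCTC, AutoFeatureExtractor, AutoTokenizer

def load_components(model_id: str, hf_token: str):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    token = hf_token.strip()  # pasted tokens often carry stray whitespace
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_id, token=token)
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
    model = AutoModelForCTC.from_pretrained(model_id, token=token).to(device)
    model.eval()  # inference only
    return feature_extractor, tokenizer, model
```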
@@ -29,22 +52,15 @@ def load_model_with_token(hf_token):
 
         print(f"✅ Loaded: {type(feature_extractor)}, {type(tokenizer)}, {type(model)}")
 
-        # On success: the button reads "✅ Loaded" and the transcribe button is enabled
         return gr.update(interactive=False, value="✅ Model Loaded Successfully!"), gr.update(interactive=True)
     except Exception as e:
         print(f"Error loading model: {e}")
         import traceback
         traceback.print_exc()
-
-        # On failure: the button shows the error message and the transcribe button stays disabled
         return gr.update(interactive=True, value=f"❌ Error: {str(e)}"), gr.update(interactive=False)
 
 
 def transcribe_audio(audio_input):
-    """
-    Note: audio_input is no longer a file path; it is the numpy audio Gradio passes
-    directly, as a tuple: (sample_rate: int, waveform: np.ndarray)
-    """
     global model, feature_extractor, tokenizer
 
     if audio_input is None:
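The two-tuple of gr.update(...) return values is what drives the button state machine: each element maps positionally onto one component in the click handler's outputs= list. A minimal sketch of the same pattern (load_stub and its strings are hypothetical):

```python
import gradio as gr

def load_stub(token):
    if token.strip():
        # Success: freeze the load button, unlock the action button.
        return gr.update(interactive=False, value="✅ Loaded"), gr.update(interactive=True)
    # Failure: keep the load button usable, leave the action button disabled.
    return gr.update(interactive=True, value="❌ Error: empty token"), gr.update(interactive=False)

with gr.Blocks() as demo:
    token_box = gr.Textbox(type="password")
    load_btn = gr.Button("Load")
    action_btn = gr.Button("Run", interactive=False)
    # First returned update -> load_btn, second -> action_btn.
    load_btn.click(fn=load_stub, inputs=[token_box], outputs=[load_btn, action_btn])
```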
@@ -54,27 +70,38 @@ def transcribe_audio(audio_input):
         return "❌ Please load the model first!"
 
     try:
-        # 1. Unpack the audio Gradio passes in: sample rate + waveform
+        # 1. Unpack the audio
         sample_rate, waveform = audio_input
 
-        # If multi-channel (stereo), keep only the first channel
+        # 2. Convert to mono
         if waveform.ndim == 2:
-            waveform = waveform[:, 0]  # keep a single channel
+            waveform = waveform[:, 0]
 
-        # 2. Convert to float32, normalized to [-1, 1] (if not already)
-        # Gradio defaults to int16 in [-32768, 32767]; dividing by 32768 gives float32
+        # 3. Convert to float32 and normalize
         if waveform.dtype == np.int16:
            waveform = waveform.astype(np.float32) / 32768.0
         elif waveform.dtype != np.float32:
             waveform = waveform.astype(np.float32)
 
-        # 3. If the sample rate is not 16 kHz, resample with librosa (optional)
+        # 4. RMS normalization
+        waveform = normalize_audio(waveform)
+
+        # 5. Remove silence
+        waveform = remove_silence(waveform, sample_rate)
+
+        # 6. Check the length
+        duration = len(waveform) / sample_rate
+        if duration < 0.1:
+            return "⚠️ Audio is too short."
+        if duration > 60:
+            return "⚠️ Audio is too long."
+
+        # 7. Resample
         if sample_rate != 16000:
-            import librosa
             waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)
             sample_rate = 16000
 
-        # 4. Process the audio with the feature extractor
+        # 8. Feature extraction
         inputs = feature_extractor(
             waveform,
             sampling_rate=sample_rate,
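The dtype and resampling steps are worth a worked example: Gradio's numpy audio arrives as (sample_rate, int16 array), dividing by 32768 maps it into [-1, 1), and librosa then resamples to the model's 16 kHz. A self-contained sketch, where the 44.1 kHz random buffer is purely illustrative:

```python
import numpy as np
import librosa

# Fake a Gradio microphone payload: 1 s of int16 samples at 44.1 kHz.
sr = 44100
pcm = (np.random.randn(sr) * 3000).astype(np.int16)

wav = pcm.astype(np.float32) / 32768.0                       # int16 -> float32 in [-1, 1)
wav16k = librosa.resample(wav, orig_sr=sr, target_sr=16000)  # 44100 samples -> 16000

print(wav16k.dtype, wav16k.shape)  # float32 (16000,)
```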
@@ -82,19 +109,45 @@ def transcribe_audio(audio_input):
         )
         inputs = {k: v.to(model.device) for k, v in inputs.items()}
 
-        # 5. Model generation; decode with beam search.
+        # [Fix]
+        # Auto-detect the key that holds the feature data
+        # (it may be 'input_features' or 'input_values');
+        # skip 'attention_mask' and find the actual input tensor
+        input_tensor = None
+        for key, val in inputs.items():
+            if isinstance(val, torch.Tensor) and val.ndim > 1:
+                input_tensor = val
+                break
+
+        if input_tensor is None:
+            return "❌ Error: Could not extract audio features."
+
+        # Read the stride safely; default to 4 if the attribute is missing
+        stride = 4
+        if hasattr(feature_extractor, 'stride'):
+            s = feature_extractor.stride
+            stride = s[0] if isinstance(s, (list, tuple)) else s
+
+        # Compute max_length dynamically
+        max_length = input_tensor.shape[1] // stride + 50
+
+        # 9. Beam-search decoding
         with torch.no_grad():
             outputs = model.generate(
                 **inputs,
-                max_length=inputs["input_values"].shape[1] // feature_extractor.stride[0] + 50,
-                num_beams=8,  # beam search
+                max_length=max_length,
+                num_beams=8,  # beam search improves accuracy
                 temperature=1.0,
             )
 
-        # 6. Decode
+        # 10. Decode
         transcription = tokenizer.batch_decode(outputs.tolist(), skip_special_tokens=True)[0]
 
-        return transcription
+        # 11. Post-process
+        transcription = transcription.strip()
+        import re
+        transcription = re.sub(r'\s+', ' ', transcription)
+
+        return transcription if transcription else "⚠️ No speech detected."
 
     except Exception as e:
         import traceback
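The core of the fix: the old code hard-coded inputs["input_values"], which raises KeyError when the extractor emits input_features instead, and read feature_extractor.stride[0] without checking the attribute exists. The max_length arithmetic itself is simple; a worked sketch, with an assumed 1200-frame input and illustrative key names and shapes:

```python
import torch

# Assumed extractor output; key names and shapes are illustrative.
inputs = {
    "input_features": torch.zeros(1, 1200, 80),              # (batch, frames, mel bins)
    "attention_mask": torch.ones(1, 1200, dtype=torch.long),
}

# Same detection idea as the diff: first tensor with more than one dim wins.
input_tensor = next(v for v in inputs.values()
                    if isinstance(v, torch.Tensor) and v.ndim > 1)

stride = 4  # fallback used when the extractor exposes no .stride
max_length = input_tensor.shape[1] // stride + 50
print(max_length)  # 1200 // 4 + 50 = 350 tokens
```

Note that a 2-D attention_mask would also pass the ndim > 1 test, so the detection loop implicitly relies on dict insertion order putting the feature key first.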
@@ -104,74 +157,55 @@ def transcribe_audio(audio_input):
 
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.Markdown("# 🏥 MedASR - Medical Speech Recognition")
-    gr.Markdown("AI-powered medical dictation system. Upload or record audio to transcribe.")
+    gr.Markdown("Optimized for medical dictation with Beam Search decoding.")
 
     gr.Markdown("---")
 
     with gr.Row():
         with gr.Column(scale=1):
             gr.Markdown("## 🔑 Authentication")
-            hf_token = gr.Textbox(
-                label="HuggingFace Token",
-                type="password",
-                placeholder="hf_...",
-                info="Required to access the gated MedASR model"
-            )
+            hf_token = gr.Textbox(label="HuggingFace Token", type="password", placeholder="hf_...")
+            load_model_btn = gr.Button("🔑 Load Model", variant="primary", size="lg")
 
-            # Use the button itself to show status; no separate status textbox needed
-            load_model_btn = gr.Button(
-                "🔑 Load Model",
-                variant="primary",
-                size="lg"
-            )
-
-        with gr.Column(scale=2):
             gr.Markdown("## 📝 Tips")
             gr.Markdown("""
-            - **Speak clearly** in English
-            - **Short phrases** work best (3-10 seconds)
-            - **Quiet environment** improves accuracy
-            - Try medical terms: *patient, diagnosis, treatment, medication*
+            - Speak **clearly and slowly**
+            - Use **medical terms**
+            - Short audio (3-10s) is best
+            - Quiet environment
             """)
-
-    gr.Markdown("---")
 
-    with gr.Row():
-        with gr.Column():
-            gr.Markdown("## 🎙️ Audio Input")
-            audio_input = gr.Audio(
-                sources=["microphone", "upload"],
-                type="numpy",
-                label="Record or upload audio"
-            )
-            transcribe_btn = gr.Button(
-                "🔄 Transcribe",
-                variant="secondary",
-                size="lg",
-                interactive=False
-            )
+        with gr.Column(scale=2):
+            gr.Markdown("## 🎙️ Input & Result")
+            audio_input = gr.Audio(sources=["microphone", "upload"], type="numpy")
 
-        with gr.Column():
-            gr.Markdown("## 📄 Transcription Result")
-            output_text = gr.Textbox(
-                label="",
-                lines=12,
-                placeholder="Transcription will appear here...",
-                show_label=False
-            )
+            with gr.Row():
+                transcribe_btn = gr.Button("🔄 Transcribe", variant="secondary", size="lg", interactive=False)
+                clear_btn = gr.Button("🗑️ Clear", variant="stop", size="lg")
+
+            output_text = gr.Textbox(label="Result", lines=12, placeholder="...")
+            audio_info = gr.Textbox(label="Info", lines=2, interactive=False)
+
+            def transcribe_wrapper(audio_in):
+                res = transcribe_audio(audio_in)
+                info = "Status: Success" if "❌" not in res and "⚠️" not in res else "Status: Check result"
+                return res, info
 
-    # Event bindings
     load_model_btn.click(
         fn=load_model_with_token,
         inputs=[hf_token],
-        # The first return value updates the load button; the second toggles the transcribe button
         outputs=[load_model_btn, transcribe_btn]
     )
 
     transcribe_btn.click(
-        fn=transcribe_audio,
+        fn=transcribe_wrapper,
         inputs=[audio_input],
-        outputs=[output_text]
+        outputs=[output_text, audio_info]
+    )
+
+    clear_btn.click(
+        fn=lambda: ("", "Ready"),
+        outputs=[output_text, audio_info]
     )
 
 if __name__ == "__main__":
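The wiring at the end follows the standard Blocks event pattern: transcribe_wrapper fans one result out to two components, and the clear button resets both via a lambda that takes no inputs. A freestanding sketch of the same idea (run_wrapper and its strings are hypothetical):

```python
import gradio as gr

def run_wrapper(text):
    # Stand-in for transcribe_audio: one result, fanned out to two boxes.
    res = text.upper() if text else "⚠️ No input."
    info = "Status: Success" if "⚠️" not in res else "Status: Check result"
    return res, info

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    with gr.Row():
        run_btn = gr.Button("Run")
        clear_btn = gr.Button("Clear", variant="stop")
    out = gr.Textbox(label="Result")
    info = gr.Textbox(label="Info", interactive=False)

    run_btn.click(fn=run_wrapper, inputs=[inp], outputs=[out, info])
    # No inputs= needed: the lambda takes no arguments and returns one value per output.
    clear_btn.click(fn=lambda: ("", "Ready"), outputs=[out, info])

if __name__ == "__main__":
    demo.launch()
```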
 