Diggz10 committed on
Commit 09dbb5c · verified · 1 Parent(s): bcdc929

Update app.py

Files changed (1)
app.py +153 -88
app.py CHANGED
@@ -1,14 +1,12 @@
-# app.py — Voice Clarity Booster (MetricGAN+) for Hugging Face Spaces
-# Fixes:
-# - Robust mono conversion (handles [T], [T,C], [C,T]) to prevent 50-byte WAVs.
-# - Output autoplay, NaN/Inf sanitization, tiny-output fallback.
+# app.py — Voice Clarity Booster with mode switch + dry/wet mix
+# Modes: MetricGAN+ (denoise) | SepFormer (dereverb+denoise) | Bypass (EQ only)
 
-import io
 import os
+import io
 import tempfile
 from typing import Tuple, Optional
 
-# ---- Quiet noisy deprecation warnings (optional) ----
+# --- Quiet noisy deprecation warnings (optional) ---
 import warnings
 warnings.filterwarnings(
     "ignore",
@@ -26,32 +24,47 @@ import soundfile as sf
 import torch
 import torchaudio
 
-# ---- SpeechBrain import: prefer new API, fall back if older version ----
+# Prefer new SpeechBrain API; fall back for older versions
 try:
-    # SpeechBrain >= 1.0
     from speechbrain.inference import SpectralMaskEnhancement
-except Exception:  # pragma: no cover
-    # Older SpeechBrain (<1.0)
+except Exception:  # < 1.0
     from speechbrain.pretrained import SpectralMaskEnhancement  # type: ignore
 
+try:
+    # SepFormer enhancement model (WHAMR) via separation interface
+    from speechbrain.inference import SepformerSeparation
+except Exception:
+    from speechbrain.pretrained import SepformerSeparation  # type: ignore
+
 
 # -----------------------------
-# Model: SpeechBrain MetricGAN+
+# Cached models
 # -----------------------------
-_ENHANCER: Optional[SpectralMaskEnhancement] = None
 _DEVICE = "cpu"
+_ENHANCER_METRICGAN: Optional[SpectralMaskEnhancement] = None
+_ENHANCER_SEPFORMER: Optional[SepformerSeparation] = None
 
 
-def _get_enhancer() -> SpectralMaskEnhancement:
-    """Lazily load the enhancer and cache it."""
-    global _ENHANCER
-    if _ENHANCER is None:
-        _ENHANCER = SpectralMaskEnhancement.from_hparams(
+def _get_metricgan() -> SpectralMaskEnhancement:
+    global _ENHANCER_METRICGAN
+    if _ENHANCER_METRICGAN is None:
+        _ENHANCER_METRICGAN = SpectralMaskEnhancement.from_hparams(
             source="speechbrain/metricgan-plus-voicebank",
             savedir="pretrained/metricgan_plus_voicebank",
             run_opts={"device": _DEVICE},
         )
-    return _ENHANCER
+    return _ENHANCER_METRICGAN
+
+
+def _get_sepformer() -> SepformerSeparation:
+    global _ENHANCER_SEPFORMER
+    if _ENHANCER_SEPFORMER is None:
+        _ENHANCER_SEPFORMER = SepformerSeparation.from_hparams(
+            source="speechbrain/sepformer-whamr-enhancement",
+            savedir="pretrained/sepformer_whamr_enh",
+            run_opts={"device": _DEVICE},
+        )
+    return _ENHANCER_SEPFORMER
 
 
 # -----------------------------
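
Note: the two loaders above cache each model in a module-level global, so only the first request pays the download and load cost. A minimal smoke test of the same pattern outside the app (a sketch, assuming speechbrain and soundfile are installed; "noisy_16k.wav" is a placeholder for any 16 kHz mono WAV):

import soundfile as sf
from speechbrain.inference import SpectralMaskEnhancement  # >= 1.0 layout

enhancer = SpectralMaskEnhancement.from_hparams(
    source="speechbrain/metricgan-plus-voicebank",
    savedir="pretrained/metricgan_plus_voicebank",
    run_opts={"device": "cpu"},
)
# enhance_file reads the WAV, applies the spectral mask, and returns a waveform tensor
enhanced = enhancer.enhance_file("noisy_16k.wav")  # placeholder path
sf.write("enhanced_16k.wav", enhanced.squeeze().cpu().numpy(), 16000)
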
@@ -60,40 +73,28 @@ def _get_enhancer() -> SpectralMaskEnhancement:
 def _to_mono(wav: np.ndarray) -> np.ndarray:
     """
     Ensure mono [T] float32 robustly.
-
-    Accepts:
-    - [T] (mono)
-    - [T, C] (samples, channels)
-    - [C, T] (channels, samples)
-    - Any 2D shape where a dimension <= 8 is 'channels'
+    Accepts [T], [T,C], [C,T]; picks the 'channels' axis if <=8.
     """
     wav = np.asarray(wav, dtype=np.float32)
-
     if wav.ndim == 1:
         return wav
-
     if wav.ndim == 2:
-        T, U = wav.shape
-
-        # If one dimension is 1, just squeeze
-        if 1 in (T, U):
+        t, u = wav.shape
+        if 1 in (t, u):
             return wav.reshape(-1).astype(np.float32)
-
-        # Heuristic: if the last dim is small (<= 8), treat it as channels -> [T, C]
-        if U <= 8:
-            return wav.mean(axis=1).astype(np.float32)  # average across channel axis
-
-        # If the first dim is small (<= 8), treat it as channels -> [C, T]
-        if T <= 8:
+        if u <= 8:  # [T, C]
+            return wav.mean(axis=1).astype(np.float32)
+        if t <= 8:  # [C, T]
             return wav.mean(axis=0).astype(np.float32)
-
-        # Fallback: assume [T, C]
         return wav.mean(axis=1).astype(np.float32)
-
-    # Higher dims: flatten channels, keep time last if possible
+    # higher dims: fall back
     return wav.reshape(-1).astype(np.float32)
 
 
+def _sanitize(mono: np.ndarray) -> np.ndarray:
+    return np.nan_to_num(mono, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
+
+
 def _resample_torch(wav: torch.Tensor, sr_in: int, sr_out: int) -> torch.Tensor:
     if sr_in == sr_out:
         return wav
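
Note: for the two 2-D branches above, averaging over whichever axis holds the channels yields the same mono signal; a self-contained check of that equivalence:

import numpy as np

stereo_tc = np.random.rand(48000, 2).astype(np.float32)  # [T, C]
stereo_ct = stereo_tc.T                                   # [C, T]
mono_tc = stereo_tc.mean(axis=1)  # what the u <= 8 branch computes
mono_ct = stereo_ct.mean(axis=0)  # what the t <= 8 branch computes
assert mono_tc.shape == (48000,)
assert np.allclose(mono_tc, mono_ct)
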
@@ -107,7 +108,6 @@ def _highpass(wav: torch.Tensor, sr: int, cutoff_hz: float) -> torch.Tensor:
 
 
 def _presence_boost(wav: torch.Tensor, sr: int, gain_db: float) -> torch.Tensor:
-    """Simple presence EQ around ~4.5 kHz."""
     if abs(gain_db) < 1e-6:
         return wav
     center = 4500.0
@@ -116,74 +116,119 @@ def _presence_boost(wav: torch.Tensor, sr: int, gain_db: float) -> torch.Tensor:
 
 
 def _limit_peak(wav: torch.Tensor, target_dbfs: float = -1.0) -> torch.Tensor:
-    """Peak-normalize to target dBFS and hard-limit to [-1, 1]."""
     target_amp = 10.0 ** (target_dbfs / 20.0)
     peak = torch.max(torch.abs(wav)).item()
     if peak > 0:
-        scale = min(1.0, target_amp / peak)
-        wav = wav * scale
+        wav = wav * min(1.0, target_amp / peak)
     return torch.clamp(wav, -1.0, 1.0)
 
 
-def _sanitize(mono: np.ndarray) -> np.ndarray:
-    """Replace NaN/Inf with 0 to keep encoders happy."""
-    return np.nan_to_num(mono, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
+def _align_lengths(a: np.ndarray, b: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+    """Crop both to the shorter length so we can mix dry/wet safely."""
+    n = min(len(a), len(b))
+    return a[:n], b[:n]
+
+
+# -----------------------------
+# Core pipeline
+# -----------------------------
+def _run_metricgan(clean_16k_path: str) -> torch.Tensor:
+    enh = _get_metricgan()
+    with torch.no_grad():
+        out = enh.enhance_file(clean_16k_path)  # [1, T] float32 -1..1
+    return out
+
+
+def _run_sepformer(clean_16k_path: str) -> torch.Tensor:
+    sep = _get_sepformer()
+    with torch.no_grad():
+        # Some SB versions return [n_src, T]; others [1, T]
+        out = sep.separate_file(path=clean_16k_path)
+    # Normalize shape to [1, T]
+    if isinstance(out, torch.Tensor):
+        if out.dim() == 1:
+            out = out.unsqueeze(0)
+        elif out.dim() == 2 and out.shape[0] > 1:
+            out = out[:1, :]  # pick primary enhanced speech
+        return out
+    # If older API returns numpy or list, convert:
+    if hasattr(out, "numpy"):
+        t = torch.from_numpy(out)
+        if t.dim() == 1:
+            t = t.unsqueeze(0)
+        elif t.dim() == 2 and t.shape[0] > 1:
+            t = t[:1, :]
+        return t
+    if isinstance(out, (list, tuple)):
+        t = torch.tensor(out[0] if isinstance(out[0], (np.ndarray, list)) else out, dtype=torch.float32)
+        if t.dim() == 1:
+            t = t.unsqueeze(0)
+        return t
+    raise RuntimeError("Unexpected SepFormer output type")
 
 
 def _enhance_numpy_audio(
     audio: Tuple[int, np.ndarray],
-    presence_db: float = 3.0,
-    lowcut_hz: float = 75.0,
+    mode: str = "MetricGAN+ (denoise)",
+    dry_wet: float = 1.0,      # 0..1 (1 = fully processed)
+    presence_db: float = 0.0,  # default 0 for safer tone
+    lowcut_hz: float = 0.0,    # default 0 (off)
     out_sr: Optional[int] = None,
 ) -> Tuple[int, np.ndarray]:
     """
-    Core pipeline used by the Gradio UI.
     Input: (sr, np.float32 [T] or [T,C])
     Returns: (sr_out, np.float32 [T])
     """
     sr_in, wav_np = audio
-    wav_mono = _to_mono(wav_np)
+    wav_mono = _sanitize(_to_mono(wav_np))
 
-    # Guard: empty input
-    if wav_mono.size < 16:
-        # Return a short silent buffer at original SR to avoid empty files
+    # Guard: tiny input
+    if wav_mono.size < 32:
         return sr_in, np.zeros(1600 if sr_in else 1600, dtype=np.float32)
 
-    wav_t = torch.from_numpy(wav_mono).unsqueeze(0)  # [1, T]
-
-    # MetricGAN+ expects 16 kHz mono
-    enh = _get_enhancer()
-    wav_16k = _resample_torch(wav_t, sr_in, 16000)
+    dry_t = torch.from_numpy(wav_mono).unsqueeze(0)  # [1, T @ sr_in]
+    # Prepare 16k mono file for models
+    wav_16k = _resample_torch(dry_t, sr_in, 16000)
 
-    # Enhance via file path API for broad codec compatibility
     with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_in:
         sf.write(tmp_in.name, wav_16k.squeeze(0).numpy(), 16000, subtype="PCM_16")
         tmp_in.flush()
-        clean = enh.enhance_file(tmp_in.name)  # torch.Tensor [1, T]
-    try:
-        os.remove(tmp_in.name)
-    except Exception:
-        pass
-
-    # Optional polish: high-pass & presence EQ + peak limit
-    clean = _highpass(clean, 16000, lowcut_hz)
-    clean = _presence_boost(clean, 16000, presence_db)
-    clean = _limit_peak(clean, target_dbfs=-1.0)
+        path_16k = tmp_in.name
 
-    # Resample to requested output rate (or original)
+    try:
+        if mode.startswith("MetricGAN"):
+            proc = _run_metricgan(path_16k)  # [1, T @ 16k]
+        elif mode.startswith("SepFormer"):
+            proc = _run_sepformer(path_16k)  # [1, T @ 16k]
+        else:  # Bypass (EQ only)
+            proc = wav_16k
+    finally:
+        try:
+            os.remove(path_16k)
+        except Exception:
+            pass
+
+    # Subtle polish (applied to processed only)
+    proc = _highpass(proc, 16000, lowcut_hz)
+    proc = _presence_boost(proc, 16000, presence_db)
+    proc = _limit_peak(proc, target_dbfs=-1.0)
+
+    # Resample both to output rate for mixing & export
     sr_out = sr_in if (out_sr is None or out_sr <= 0) else int(out_sr)
-    clean_out = _resample_torch(clean, 16000, sr_out).squeeze(0).numpy().astype(np.float32)
+    proc_out = _resample_torch(proc, 16000, sr_out).squeeze(0).numpy().astype(np.float32)
+    dry_out = _resample_torch(dry_t, sr_in, sr_out).squeeze(0).numpy().astype(np.float32)
 
-    # Sanitize
-    clean_out = _sanitize(clean_out)
+    # Align and mix
+    proc_out, dry_out = _align_lengths(proc_out, dry_out)
+    dry_wet = float(np.clip(dry_wet, 0.0, 1.0))
+    mixed = dry_wet * proc_out + (1.0 - dry_wet) * dry_out  # dry*(1-dw) + proc*dw
+    mixed = _sanitize(mixed)
 
-    # Tiny-output fallback: if somehow too short, return processed original instead
-    if clean_out.size < 160:  # ~10 ms @16k
-        fallback = _sanitize(wav_16k.squeeze(0).numpy())
-        fallback = _resample_torch(torch.from_numpy(fallback).unsqueeze(0), 16000, sr_out).squeeze(0).numpy().astype(np.float32)
-        return sr_out, fallback
+    # Safety: if somehow too tiny, fall back to dry
+    if mixed.size < 160:
+        return sr_out, dry_out
 
-    return sr_out, clean_out
+    return sr_out, mixed
 
 
 # -----------------------------
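
Note: the mix line above is a plain linear crossfade: with the slider at 85%, the output is 0.85 x processed + 0.15 x dry, so a little of the untouched signal always bleeds through to mask artifacts. A self-contained check with illustrative arrays:

import numpy as np

dry = np.ones(3, dtype=np.float32)    # untouched input
proc = np.zeros(3, dtype=np.float32)  # pretend the model removed everything
dry_wet = 0.85
mixed = dry_wet * proc + (1.0 - dry_wet) * dry
assert np.allclose(mixed, 0.15)       # 15% of the dry signal remains
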
@@ -191,6 +236,8 @@ def _enhance_numpy_audio(
 # -----------------------------
 def gradio_enhance(
     audio: Tuple[int, np.ndarray],
+    mode: str,
+    dry_wet_pct: float,
     presence_db: float,
     lowcut_hz: float,
     output_sr: str,
@@ -201,25 +248,39 @@
     if output_sr in {"44100", "48000"}:
         out_sr = int(output_sr)
     sr_out, enhanced = _enhance_numpy_audio(
-        audio, presence_db=float(presence_db), lowcut_hz=float(lowcut_hz), out_sr=out_sr
+        audio,
+        mode=mode,
+        dry_wet=dry_wet_pct / 100.0,
+        presence_db=float(presence_db),
+        lowcut_hz=float(lowcut_hz),
+        out_sr=out_sr,
     )
     return (sr_out, enhanced)
 
 
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
-    gr.Markdown("## Voice Clarity Booster (MetricGAN+)")
+    gr.Markdown("## Voice Clarity Booster")
     with gr.Row():
         with gr.Column():
             in_audio = gr.Audio(
                 sources=["upload", "microphone"],
                 type="numpy",
-                label="Input (noisy speech)",
+                label="Input",
+            )
+            mode = gr.Radio(
+                choices=["MetricGAN+ (denoise)", "SepFormer (dereverb+denoise)", "Bypass (EQ only)"],
+                value="MetricGAN+ (denoise)",
+                label="Mode",
+            )
+            dry_wet = gr.Slider(
+                minimum=0, maximum=100, value=85, step=1,
+                label="Dry/Wet Mix (%) — lower to reduce artifacts"
             )
             presence = gr.Slider(
-                minimum=-12, maximum=12, value=3, step=0.5, label="Presence Boost (dB)"
+                minimum=-12, maximum=12, value=0, step=0.5, label="Presence Boost (dB)"
             )
             lowcut = gr.Slider(
-                minimum=0, maximum=200, value=75, step=5, label="Low-Cut (Hz)"
+                minimum=0, maximum=200, value=0, step=5, label="Low-Cut (Hz)"
             )
             out_sr = gr.Radio(
                 choices=["Original", "44100", "48000"],
@@ -230,7 +291,11 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         with gr.Column():
             out_audio = gr.Audio(type="numpy", label="Enhanced", autoplay=True)
 
-    btn.click(gradio_enhance, inputs=[in_audio, presence, lowcut, out_sr], outputs=[out_audio])
+    btn.click(
+        gradio_enhance,
+        inputs=[in_audio, mode, dry_wet, presence, lowcut, out_sr],
+        outputs=[out_audio],
+    )
 
-# IMPORTANT for Hugging Face Spaces: call launch() unguarded so the app starts.
+# Start server (Hugging Face Spaces expects this unguarded)
 demo.launch()
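
Note: the pipeline can also be driven without the UI. A hypothetical local driver (not part of the commit; it assumes app.py is importable, and since the module calls demo.launch() at import time, that line would need to be commented out first):

import numpy as np
import soundfile as sf
from app import _enhance_numpy_audio  # assumes demo.launch() is disabled

sr = 44100
t = np.arange(sr, dtype=np.float32) / sr
tone = 0.3 * np.sin(2.0 * np.pi * 220.0 * t)           # 1 s test tone
noise = 0.05 * np.random.randn(sr).astype(np.float32)  # synthetic hiss
noisy = (tone + noise).astype(np.float32)

sr_out, enhanced = _enhance_numpy_audio(
    (sr, noisy),
    mode="MetricGAN+ (denoise)",
    dry_wet=0.85,
    out_sr=48000,
)
sf.write("enhanced.wav", enhanced, sr_out)
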
 