Files changed (1)
app.py +536 -469
app.py CHANGED
@@ -1,491 +1,558 @@
- # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Liu Yue)
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import spaces
- import os
- import sys
- import argparse
  import gradio as gr
- import numpy as np
  import torch
  import torchaudio
- import random
  import librosa
- from funasr import AutoModel
- from funasr.utils.postprocess_utils import rich_transcription_postprocess
- ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
- sys.path.append('{}/third_party/Matcha-TTS'.format(ROOT_DIR))
-
- from modelscope import snapshot_download, HubApi
- from huggingface_hub import snapshot_download as hf_snapshot_download
-
- hf_snapshot_download('FunAudioLLM/Fun-CosyVoice3-0.5B-2512', local_dir='pretrained_models/Fun-CosyVoice3-0.5B')
- snapshot_download('iic/SenseVoiceSmall', local_dir='pretrained_models/SenseVoiceSmall')
- hf_snapshot_download('FunAudioLLM/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
- os.system(
-     "cd pretrained_models/CosyVoice-ttsfrd/ && "
-     "pip install ttsfrd_dependency-0.1-py3-none-any.whl && "
-     "pip install ttsfrd-0.4.2-cp310-cp310-linux_x86_64.whl && "
-     "apt install -y unzip && "
-     "rm -rf resource && "
-     "unzip resource.zip -d ."
- )
-
- from cosyvoice.cli.cosyvoice import AutoModel as CosyVoiceAutoModel
- from cosyvoice.utils.file_utils import logging, load_wav
- from cosyvoice.utils.common import set_all_random_seed, instruct_list
-
- # -----------------------------
- # i18n (En: British spelling)
- # -----------------------------
- LANG_EN = "En"
- LANG_ZH = "Zh"
-
- MODE_ZERO_SHOT = "zero_shot"
- MODE_INSTRUCT = "instruct"
-
- UI_TEXT = {
-     LANG_EN: {
-         "lang_label": "Language",
-         "md_links": (
-             "### Repository [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) \n"
-             "Pre-trained model [Fun-CosyVoice3-0.5B](https://huggingface.co/FunAudioLLM/Fun-CosyVoice3-0.5B-2512) \n"
-             "[CosyVoice2-0.5B](https://www.modelscope.cn/models/iic/CosyVoice2-0.5B) \n"
-             "[CosyVoice-300M](https://www.modelscope.cn/models/iic/CosyVoice-300M) \n"
-             "[CosyVoice-300M-Instruct](https://www.modelscope.cn/models/iic/CosyVoice-300M-Instruct) \n"
-             "[CosyVoice-300M-SFT](https://www.modelscope.cn/models/iic/CosyVoice-300M-SFT)"
-         ),
-         "md_hint": "#### Enter the text to synthesise, choose an inference mode, and follow the steps.",
-         "tts_label": "Text to synthesise",
-         "tts_default": "Her handwriting is very neat, which suggests she likes things tidy.",
-         "mode_label": "Inference mode",
-         "mode_zero_shot": "3s fast voice cloning",
-         "mode_instruct": "Natural language control",
-         "steps_label": "Steps",
-         "steps_zero_shot": (
-             "1. Choose a prompt audio file, or record prompt audio (≤ 30s). If both are provided, the uploaded file is used.\n"
-             "2. Enter the prompt text.\n"
-             "3. Click Generate audio."
-         ),
-         "steps_instruct": (
-             "1. Choose a prompt audio file, or record prompt audio (≤ 30s). If both are provided, the uploaded file is used.\n"
-             "2. Choose/enter the instruct text.\n"
-             "3. Click Generate audio."
-         ),
-         "stream_label": "Streaming inference",
-         "stream_no": "No",
-         "dice": "🎲",
-         "seed_label": "Random inference seed",
-         "upload_label": "Choose prompt audio file (sample rate ≥ 16 kHz)",
-         "record_label": "Record prompt audio",
-         "prompt_text_label": "Prompt text",
-         "prompt_text_ph": "Enter prompt text (auto recognition supported; you can edit the result)...",
-         "instruct_label": "Choose instruct text",
-         "generate_btn": "Generate audio",
-         "output_label": "Synthesised audio",
-         "warn_too_long": "Your input text is too long; please keep it within 200 characters.",
-         "warn_instruct_empty": "You are using Natural language control; please enter instruct text.",
-         "info_instruct_need_prompt": "You are using Natural language control; please provide prompt audio.",
-         "warn_prompt_missing": "Prompt audio is empty. Did you forget to provide prompt audio?",
-         "warn_prompt_sr_low": "Prompt audio sample rate {} is below {}.",
-         "warn_prompt_too_long_10s": "Please keep the prompt audio within 10 seconds to avoid poor inference quality.",
-         "warn_prompt_text_missing": "Prompt text is empty. Did you forget to enter prompt text?",
-         "info_instruct_ignored": "You are using 3s fast voice cloning; instruct text will be ignored.",
-         "warn_invalid_mode": "Invalid mode selection.",
-     },
-     LANG_ZH: {
-         "lang_label": "语言",
-         "md_links": (
-             "### 代码库 [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) \n"
-             "预训练模型 [Fun-CosyVoice3-0.5B](https://huggingface.co/FunAudioLLM/Fun-CosyVoice3-0.5B-2512) \n"
-             "[CosyVoice2-0.5B](https://www.modelscope.cn/models/iic/CosyVoice2-0.5B) \n"
-             "[CosyVoice-300M](https://www.modelscope.cn/models/iic/CosyVoice-300M) \n"
-             "[CosyVoice-300M-Instruct](https://www.modelscope.cn/models/iic/CosyVoice-300M-Instruct) \n"
-             "[CosyVoice-300M-SFT](https://www.modelscope.cn/models/iic/CosyVoice-300M-SFT)"
-         ),
-         "md_hint": "#### 请输入需要合成的文本,选择推理模式,并按照提示步骤进行操作",
-         "tts_label": "输入合成文本",
-         "tts_default": "Her handwriting is [M][AY0][N][UW1][T]并且很整洁,说明她[h][ào]干净。",
-         "mode_label": "选择推理模式",
-         "mode_zero_shot": "3s极速复刻",
-         "mode_instruct": "自然语言控制",
-         "steps_label": "操作步骤",
-         "steps_zero_shot": (
-             "1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n"
-             "2. 输入prompt文本\n"
-             "3. 点击生成音频按钮"
-         ),
-         "steps_instruct": (
-             "1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n"
-             "2. 输入instruct文本\n"
-             "3. 点击生成音频按钮"
-         ),
-         "stream_label": "是否流式推理",
-         "stream_no": "否",
-         "dice": "🎲",
-         "seed_label": "随机推理种子",
-         "upload_label": "选择prompt音频文件,注意采样率不低于16khz",
-         "record_label": "录制prompt音频文件",
-         "prompt_text_label": "prompt文本",
-         "prompt_text_ph": "请输入prompt文本,支持自动识别,您可以自行修正识别结果...",
-         "instruct_label": "选择instruct文本",
-         "generate_btn": "生成音频",
-         "output_label": "合成音频",
-         "warn_too_long": "您输入的文字过长,请限制在200字以内",
-         "warn_instruct_empty": "您正在使用自然语言控制模式, 请输入instruct文本",
-         "info_instruct_need_prompt": "您正在使用自然语言控制模式, 请输入prompt音频",
-         "warn_prompt_missing": "prompt音频为空,您是否忘记输入prompt音频?",
-         "warn_prompt_sr_low": "prompt音频采样率{}低于{}",
-         "warn_prompt_too_long_10s": "请限制输入音频在10s内,避免推理效果过低",
-         "warn_prompt_text_missing": "prompt文本为空,您是否忘记输入prompt文本?",
-         "info_instruct_ignored": "您正在使用3s极速复刻模式,instruct文本会被忽略!",
-         "warn_invalid_mode": "无效的模式选择",
-     },
- }
-
-
- def t(lang: str, key: str) -> str:
-     lang = lang if lang in UI_TEXT else LANG_ZH
-     return UI_TEXT[lang][key]
-
-
- def mode_choices(lang: str):
-     return [
-         (t(lang, "mode_zero_shot"), MODE_ZERO_SHOT),
-         (t(lang, "mode_instruct"), MODE_INSTRUCT),
-     ]
-
-
- def steps_for(lang: str, mode_value: str) -> str:
-     if mode_value == MODE_INSTRUCT:
-         return t(lang, "steps_instruct")
-     return t(lang, "steps_zero_shot")
-
-
- # -----------------------------
- # Audio post-process
- # -----------------------------
  max_val = 0.8
  top_db = 60
  hop_length = 220
  win_length = 440

-
- def generate_seed():
-     seed = random.randint(1, 100000000)
-     return {"__type__": "update", "value": seed}
-
-
- def postprocess(wav):
-     speech = load_wav(wav, target_sr=target_sr, min_sr=16000)
-     speech, _ = librosa.effects.trim(
-         speech, top_db=top_db, frame_length=win_length, hop_length=hop_length
-     )
-     if speech.abs().max() > max_val:
-         speech = speech / speech.abs().max() * max_val
-     speech = torch.concat([speech, torch.zeros(1, int(target_sr * 0.2))], dim=1)
-     torchaudio.save(wav, speech, target_sr)
-     return wav
-
- @spaces.GPU
- def prompt_wav_recognition(prompt_wav):
-     res = asr_model.generate(
-         input=prompt_wav,
-         language="auto",  # "zn", "en", "yue", "ja", "ko", "nospeech"
-         use_itn=True,
-     )
-     text = res[0]["text"].split("|>")[-1]
-     return text
-
-
- @spaces.GPU
- def generate_audio(
-     tts_text,
-     mode_value,
-     prompt_text,
-     prompt_wav_upload,
-     prompt_wav_record,
-     instruct_text,
-     seed,
-     stream,
-     ui_lang,
- ):
-     stream = False
-
-     if len(tts_text) > 200:
-         gr.Warning(t(ui_lang, "warn_too_long"))
-         return (target_sr, default_data)
-
-     sft_dropdown, speed = "", 1.0
-
-     if prompt_wav_upload is not None:
-         prompt_wav = prompt_wav_upload
-     elif prompt_wav_record is not None:
-         prompt_wav = prompt_wav_record
-     else:
-         prompt_wav = None
-
-     # instruct mode requirements
-     if mode_value == MODE_INSTRUCT:
-         if instruct_text == "":
-             gr.Warning(t(ui_lang, "warn_instruct_empty"))
-             return (target_sr, default_data)
-         if prompt_wav is None:
-             gr.Info(t(ui_lang, "info_instruct_need_prompt"))
-             return (target_sr, default_data)
-
-     # zero-shot requirements
-     if mode_value == MODE_ZERO_SHOT:
-         if prompt_wav is None:
-             gr.Warning(t(ui_lang, "warn_prompt_missing"))
-             return (target_sr, default_data)
-
-         info = torchaudio.info(prompt_wav)
          if info.sample_rate < prompt_sr:
-             gr.Warning(t(ui_lang, "warn_prompt_sr_low").format(info.sample_rate, prompt_sr))
-             return (target_sr, default_data)
-
-         if info.num_frames / info.sample_rate > 10:
-             gr.Warning(t(ui_lang, "warn_prompt_too_long_10s"))
-             return (target_sr, default_data)
-
-         if prompt_text == "":
-             gr.Warning(t(ui_lang, "warn_prompt_text_missing"))
-             return (target_sr, default_data)
-
-         if instruct_text != "":
-             gr.Info(t(ui_lang, "info_instruct_ignored"))
-
-     if mode_value == MODE_ZERO_SHOT:
-         logging.info("get zero_shot inference request")
          set_all_random_seed(seed)
          speech_list = []
-         for i in cosyvoice.inference_zero_shot(
-             tts_text,
-             "You are a helpful assistant.<|endofprompt|>" + prompt_text,
-             postprocess(prompt_wav),
-             stream=stream,
-             speed=speed,
          ):
              speech_list.append(i["tts_speech"])
-         return (target_sr, torch.concat(speech_list, dim=1).numpy().flatten())
-
-     if mode_value == MODE_INSTRUCT:
-         logging.info("get instruct inference request")
          set_all_random_seed(seed)
          speech_list = []
-         for i in cosyvoice.inference_instruct2(
-             tts_text,
-             instruct_text,
-             postprocess(prompt_wav),
-             stream=stream,
-             speed=speed,
          ):
              speech_list.append(i["tts_speech"])
-         return (target_sr, torch.concat(speech_list, dim=1).numpy().flatten())
-
-     gr.Warning(t(ui_lang, "warn_invalid_mode"))
-     return (target_sr, default_data)
-
-
- def on_mode_change(mode_value, ui_lang):
-     return steps_for(ui_lang, mode_value)
-
-
- def on_language_change(ui_lang, current_mode_value):
-     lang = ui_lang if ui_lang in (LANG_EN, LANG_ZH) else LANG_ZH
-     return (
-         gr.update(value=UI_TEXT[lang]["md_links"]),  # md_links
-         gr.update(value=UI_TEXT[lang]["md_hint"]),  # md_hint
-         gr.update(label=t(lang, "lang_label")),  # lang_radio label
-         gr.update(choices=mode_choices(lang), label=t(lang, "mode_label")),  # mode radio
-         gr.update(value=steps_for(lang, current_mode_value), label=t(lang, "steps_label")),  # steps box
-         gr.update(
-             choices=[(t(lang, "stream_no"), False)],
-             label=t(lang, "stream_label"),
-             value=False,
-         ),  # stream radio
-         gr.update(value=t(lang, "dice")),  # seed button text
-         gr.update(label=t(lang, "seed_label")),  # seed label
-         gr.update(label=t(lang, "tts_label"), value=t(lang, "tts_default")),  # tts textbox
-         gr.update(label=t(lang, "upload_label")),  # upload label
-         gr.update(label=t(lang, "record_label")),  # record label
-         gr.update(label=t(lang, "prompt_text_label"), placeholder=t(lang, "prompt_text_ph")),  # prompt text
-         gr.update(label=t(lang, "instruct_label")),  # instruct dropdown
-         gr.update(value=t(lang, "generate_btn")),  # generate button
-         gr.update(label=t(lang, "output_label")),  # output label
-     )
-
-
- def main():
-     with gr.Blocks() as demo:
-         md_links = gr.Markdown(UI_TEXT[LANG_ZH]["md_links"])
-         md_hint = gr.Markdown(UI_TEXT[LANG_ZH]["md_hint"])
-
-         lang_radio = gr.Radio(
-             choices=[LANG_EN, LANG_ZH],
-             value=LANG_ZH,
-             label=t(LANG_ZH, "lang_label"),
-         )
-
-         tts_text = gr.Textbox(
-             label=t(LANG_ZH, "tts_label"),
-             lines=1,
-             value=t(LANG_ZH, "tts_default"),
-         )
-
-         with gr.Row():
-             mode_radio = gr.Radio(
-                 choices=mode_choices(LANG_ZH),
-                 label=t(LANG_ZH, "mode_label"),
-                 value=MODE_ZERO_SHOT,
-             )
-             steps_box = gr.Textbox(
-                 label=t(LANG_ZH, "steps_label"),
-                 value=steps_for(LANG_ZH, MODE_ZERO_SHOT),
-                 lines=4,
-                 interactive=False,
-                 scale=0.5,
-             )
-             stream = gr.Radio(
-                 choices=[(t(LANG_ZH, "stream_no"), False)],
-                 label=t(LANG_ZH, "stream_label"),
-                 value=False,
              )
-             with gr.Column(scale=0.25):
-                 seed_button = gr.Button(value=t(LANG_ZH, "dice"))
-                 seed = gr.Number(value=0, label=t(LANG_ZH, "seed_label"))
-
-         with gr.Row():
-             prompt_wav_upload = gr.Audio(
-                 sources="upload",
-                 type="filepath",
-                 label=t(LANG_ZH, "upload_label"),
-             )
-             prompt_wav_record = gr.Audio(
-                 sources="microphone",
-                 type="filepath",
-                 label=t(LANG_ZH, "record_label"),
              )
-
-         prompt_text = gr.Textbox(
-             label=t(LANG_ZH, "prompt_text_label"),
-             lines=1,
-             placeholder=t(LANG_ZH, "prompt_text_ph"),
-             value="",
-         )
-         instruct_text = gr.Dropdown(
-             choices=instruct_list,
-             label=t(LANG_ZH, "instruct_label"),
-             value=instruct_list[0],
-         )
-
-         generate_button = gr.Button(t(LANG_ZH, "generate_btn"))
-         audio_output = gr.Audio(
-             label=t(LANG_ZH, "output_label"),
-             autoplay=True,
-             streaming=False,
-         )
-
-         seed_button.click(generate_seed, inputs=[], outputs=seed)
-
-         generate_button.click(
-             generate_audio,
-             inputs=[
-                 tts_text,
-                 mode_radio,
-                 prompt_text,
-                 prompt_wav_upload,
-                 prompt_wav_record,
-                 instruct_text,
-                 seed,
-                 stream,
-                 lang_radio,  # ui_lang
-             ],
-             outputs=[audio_output],
-         )
-
-         mode_radio.change(
-             fn=on_mode_change,
-             inputs=[mode_radio, lang_radio],
-             outputs=[steps_box],
-         )
-
-         prompt_wav_upload.change(
-             fn=prompt_wav_recognition,
-             inputs=[prompt_wav_upload],
-             outputs=[prompt_text],
-         )
-         prompt_wav_record.change(
-             fn=prompt_wav_recognition,
-             inputs=[prompt_wav_record],
-             outputs=[prompt_text],
-         )
-
-         lang_radio.change(
-             fn=on_language_change,
-             inputs=[lang_radio, mode_radio],
-             outputs=[
-                 md_links,
-                 md_hint,
-                 lang_radio,
-                 mode_radio,
-                 steps_box,
-                 stream,
-                 seed_button,
-                 seed,
-                 tts_text,
-                 prompt_wav_upload,
-                 prompt_wav_record,
-                 prompt_text,
-                 instruct_text,
-                 generate_button,
-                 audio_output,
-             ],
-         )
-
-         demo.queue(default_concurrency_limit=4).launch()
-

  if __name__ == "__main__":
-     cosyvoice = CosyVoiceAutoModel(
-         model_dir="pretrained_models/Fun-CosyVoice3-0.5B",
-         load_trt=False,
-         fp16=False,
-     )
-     sft_spk = cosyvoice.list_available_spks()
-
-     for stream in [False]:
-         for i, j in enumerate(
-             cosyvoice.inference_zero_shot(
-                 "收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。",
-                 "You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。",
-                 "zero_shot_prompt.wav",
-                 stream=stream,
-             )
-         ):
-             continue
-
-     prompt_sr = 16000
-     target_sr = 24000
-     default_data = np.zeros(target_sr)
-
-     model_dir = "pretrained_models/SenseVoiceSmall"
-     asr_model = AutoModel(
-         model=model_dir,
-         disable_update=True,
-         log_level="DEBUG",
-         device="cuda:0",
-     )
-
-     main()

  import gradio as gr
+ import sys
+ import os
  import torch
  import torchaudio
+ import torchaudio.transforms as T
+ import numpy as np
+ import tempfile
  import librosa
+ from pathlib import Path
+
+ print("=" * 60)
+ print("🎙️ Fun-CosyVoice3 TTS Initialization")
+ print("=" * 60)
+
+ # Step 1: Setup directories
+ print("\n📁 Step 1: Setting up directories...")
+ WORK_DIR = Path.cwd()
+ COSYVOICE_DIR = WORK_DIR / "CosyVoice"
+ MODEL_DIR = COSYVOICE_DIR / "pretrained_models" / "Fun-CosyVoice3-0.5B"
+
+ print(f"Working directory: {WORK_DIR}")
+ print(f"CosyVoice directory: {COSYVOICE_DIR}")
+ print(f"Model directory: {MODEL_DIR}")
+
+ # Step 2: Clone CosyVoice if needed
+ if not COSYVOICE_DIR.exists():
+     print("\n📥 Step 2: Cloning CosyVoice repository...")
+     import subprocess
+     try:
+         subprocess.run([
+             'git', 'clone', '--recursive',
+             'https://github.com/FunAudioLLM/CosyVoice.git',
+             str(COSYVOICE_DIR)
+         ], check=True)
+         print("✅ Repository cloned successfully")
+     except Exception as e:
+         print(f"❌ Failed to clone repository: {e}")
+         raise
+ else:
+     print("\n✅ Step 2: CosyVoice repository already exists")
+
+ # Step 3: Download models
+ if not MODEL_DIR.exists():
+     print("\n📥 Step 3: Downloading models (this may take a few minutes)...")
+     from huggingface_hub import snapshot_download
+     try:
+         print("Downloading Fun-CosyVoice3-0.5B-2512...")
+         snapshot_download(
+             'FunAudioLLM/Fun-CosyVoice3-0.5B-2512',
+             local_dir=str(MODEL_DIR),
+             local_dir_use_symlinks=False
+         )
+         print("✅ Model downloaded successfully")
+     except Exception as e:
+         print(f"❌ Failed to download model: {e}")
+         raise
+ else:
+     print("\n✅ Step 3: Models already exist")
+
+ # Step 4: Download ttsfrd (optional)
+ TTSFRD_DIR = COSYVOICE_DIR / "pretrained_models" / "CosyVoice-ttsfrd"
+ if not TTSFRD_DIR.exists():
+     print("\n📥 Step 4: Downloading ttsfrd...")
+     from huggingface_hub import snapshot_download
+     try:
+         snapshot_download(
+             'FunAudioLLM/CosyVoice-ttsfrd',
+             local_dir=str(TTSFRD_DIR),
+             local_dir_use_symlinks=False
+         )
+         print("✅ ttsfrd downloaded successfully")
+     except Exception as e:
+         print(f"⚠️ Failed to download ttsfrd (will use WeText): {e}")
+ else:
+     print("\n✅ Step 4: ttsfrd already exists")
+
+ # Step 5: Add to Python path
+ print("\n🔧 Step 5: Configuring Python path...")
+ sys.path.insert(0, str(COSYVOICE_DIR))
+ sys.path.insert(0, str(COSYVOICE_DIR / "third_party" / "Matcha-TTS"))
+ print(f"Added to path: {COSYVOICE_DIR}")
+ print(f"Added to path: {COSYVOICE_DIR / 'third_party' / 'Matcha-TTS'}")
+
+ # Step 6: Import CosyVoice
+ print("\n📦 Step 6: Importing CosyVoice...")
+ try:
+     from cosyvoice.cli.cosyvoice import AutoModel as CosyVoiceAutoModel
+     from cosyvoice.utils.file_utils import load_wav
+     from cosyvoice.utils.common import set_all_random_seed
+     print("✅ CosyVoice imported successfully")
+ except Exception as e:
+     print(f"❌ Failed to import CosyVoice: {e}")
+     raise
+
+ print("\n" + "=" * 60)
+ print("✅ Initialization completed successfully!")
+ print("=" * 60 + "\n")
+
+ # Global variables
+ cosyvoice = None
+ target_sr = 24000
+ prompt_sr = 16000
  max_val = 0.8
  top_db = 60
  hop_length = 220
  win_length = 440

+ def load_model():
+     """Load the CosyVoice model"""
+     global cosyvoice
+     if cosyvoice is None:
+         print("🚀 Loading CosyVoice model...")
+         try:
+             cosyvoice = CosyVoiceAutoModel(
+                 model_dir=str(MODEL_DIR),
+                 load_trt=False,
+                 fp16=False
+             )
+             print("✅ Model loaded successfully!")
+         except Exception as e:
+             print(f"❌ Error loading model: {e}")
+             import traceback
+             traceback.print_exc()
+             raise gr.Error(f"Failed to load model: {e}")
+     return cosyvoice
+
+ def postprocess(wav_path):
+     """Post-process audio - trim silence and normalize (from official code)"""
+     try:
+         speech = load_wav(wav_path, target_sr=target_sr, min_sr=16000)
+
+         # Trim silence from beginning and end
+         speech, _ = librosa.effects.trim(
+             speech, top_db=top_db,
+             frame_length=win_length,
+             hop_length=hop_length
+         )
+
+         # Normalize if too loud
+         if speech.abs().max() > max_val:
+             speech = speech / speech.abs().max() * max_val
+
+         # Add silence at the end
+         speech = torch.concat([speech, torch.zeros(1, int(target_sr * 0.2))], dim=1)
+
+         # Save back
+         torchaudio.save(wav_path, speech, target_sr)
+         return wav_path
+     except Exception as e:
+         print(f"⚠️ Postprocess warning: {e}")
+         return wav_path
+
+ def process_audio(audio_input):
+     """
+     Convert audio input to proper format for CosyVoice
+     Handles: stereo->mono, different dtypes, resampling
+     """
+     if audio_input is None:
+         return None
+
+     try:
+         sr, audio_data = audio_input
+
+         print(f"📊 Input audio - shape: {audio_data.shape}, dtype: {audio_data.dtype}, sr: {sr}Hz")
+
+         # Step 1: Normalize data type to float32
+         if audio_data.dtype == np.int16:
+             audio_data = audio_data.astype(np.float32) / 32768.0
+         elif audio_data.dtype == np.int32:
+             audio_data = audio_data.astype(np.float32) / 2147483648.0
+         elif audio_data.dtype == np.float64:
+             audio_data = audio_data.astype(np.float32)
+         elif audio_data.dtype != np.float32:
+             audio_data = audio_data.astype(np.float32)
+
+         # Step 2: Convert stereo to mono if needed
+         if len(audio_data.shape) == 2:
+             print(f"   Converting stereo ({audio_data.shape[1]} channels) to mono...")
+             if audio_data.shape[1] == 2:
+                 audio_data = audio_data.mean(axis=1)
+             elif audio_data.shape[1] == 1:
+                 audio_data = audio_data.squeeze()
+             else:
+                 audio_data = audio_data[:, 0]
+
+         # Step 3: Ensure 1D array
+         audio_data = audio_data.flatten()
+
+         # Step 4: Check and adjust duration
+         duration = len(audio_data) / sr
+         print(f"   Duration: {duration:.2f}s")
+
+         if duration < 1:
+             print("❌ Audio too short (minimum 1 second)")
+             return None
+
+         if duration > 30:
+             print(f"   ⚠️ Truncating audio from {duration:.2f}s to 30s")
+             audio_data = audio_data[:sr * 30]
+
+         # Step 5: Convert to torch tensor
+         audio_tensor = torch.from_numpy(audio_data).float()
+
+         # Step 6: Add channel dimension (1, samples)
+         if audio_tensor.dim() == 1:
+             audio_tensor = audio_tensor.unsqueeze(0)
+
+         print(f"   Tensor shape: {audio_tensor.shape}")
+
+         # Step 7: Resample if needed
+         if sr != target_sr:
+             print(f"   🔄 Resampling from {sr}Hz to {target_sr}Hz...")
+             resampler = T.Resample(sr, target_sr)
+             audio_tensor = resampler(audio_tensor)
+             sr = target_sr
+
+         # Step 8: Save to a temporary file (mkstemp avoids the race of mktemp)
+         fd, temp_path = tempfile.mkstemp(suffix='.wav')
+         os.close(fd)
+         torchaudio.save(temp_path, audio_tensor, sr)
+
+         # Step 9: Post-process (trim silence, normalize)
+         temp_path = postprocess(temp_path)
+
+         print(f"   ✅ Audio processed and saved: {os.path.basename(temp_path)}")
+         return temp_path
+
+     except Exception as e:
+         print(f"❌ Error processing audio: {e}")
+         import traceback
+         traceback.print_exc()
+         return None
+
+ def zero_shot_tts(tts_text, prompt_text, prompt_audio, seed, speed):
+     """Zero-shot TTS synthesis - following official code structure"""
+     prompt_audio_path = None
+     try:
+         # Validation
+         if not tts_text or not tts_text.strip():
+             return None, "❌ Please provide text to synthesize"
+
+         if len(tts_text) > 200:
+             return None, "❌ Text too long, please keep within 200 characters"
+
+         if not prompt_audio:
+             return None, "❌ Please upload reference audio"
+
+         if not prompt_text or not prompt_text.strip():
+             return None, "❌ Please provide prompt text"
+
+         # Load model
+         model = load_model()
+
+         # Process audio
+         prompt_audio_path = process_audio(prompt_audio)
+         if prompt_audio_path is None:
+             return None, "❌ Failed to process audio"
+
+         # Check sample rate
+         info = torchaudio.info(prompt_audio_path)
          if info.sample_rate < prompt_sr:
+             return None, f"❌ Audio sample rate {info.sample_rate} is below {prompt_sr}Hz"
+
+         # Check duration
+         duration = info.num_frames / info.sample_rate
+         if duration > 10:
+             return None, "❌ Please keep prompt audio within 10 seconds"
+
+         # Clean inputs
+         tts_text = tts_text.strip()
+         prompt_text = prompt_text.strip()
+
+         # Build prompt following official format
+         # IMPORTANT: This is the official format from the code
+         full_prompt = f"You are a helpful assistant.<|endofprompt|>{prompt_text}"
+
+         print("\n🎵 Generating speech...")
+         print(f"   TTS text: '{tts_text[:100]}{'...' if len(tts_text) > 100 else ''}'")
+         print(f"   Prompt text: '{prompt_text[:50]}{'...' if len(prompt_text) > 50 else ''}'")
+         print(f"   Full prompt: '{full_prompt[:80]}{'...' if len(full_prompt) > 80 else ''}'")
+         print(f"   Seed: {seed}, Speed: {speed}")
+
+         # Set random seed
          set_all_random_seed(seed)
+
+         # Generate - following official code exactly
          speech_list = []
+         for i in model.inference_zero_shot(
+             tts_text,            # Text to synthesize
+             full_prompt,         # Prompt with special format
+             prompt_audio_path,   # Processed prompt audio
+             stream=False,
+             speed=speed
          ):
              speech_list.append(i["tts_speech"])
+
+         # Concatenate all speech segments
+         output_speech = torch.concat(speech_list, dim=1)
+
+         # Clean up
+         if os.path.exists(prompt_audio_path):
+             os.remove(prompt_audio_path)
+
+         print(f"   ✅ Generated audio shape: {output_speech.shape}")
+         print("✅ Speech generated successfully!\n")
+
+         # Return as numpy array for Gradio
+         return (target_sr, output_speech.numpy().flatten()), "✅ Success!"
+
+     except Exception as e:
+         print(f"❌ Error in zero_shot_tts: {e}")
+         import traceback
+         traceback.print_exc()
+
+         # Clean up on error (prompt_audio_path is always bound; see init above)
+         if prompt_audio_path and os.path.exists(prompt_audio_path):
+             os.remove(prompt_audio_path)
+
+         return None, f"❌ Error: {str(e)}"
+
+ def instruct_tts(tts_text, instruct_text, prompt_audio, seed, speed):
+     """Instruction-based TTS - following official code structure"""
+     prompt_audio_path = None
+     try:
+         # Validation
+         if not tts_text or not tts_text.strip():
+             return None, "❌ Please provide text to synthesize"
+
+         if len(tts_text) > 200:
+             return None, "❌ Text too long, please keep within 200 characters"
+
+         if not prompt_audio:
+             return None, "❌ Please upload reference audio"
+
+         if not instruct_text or not instruct_text.strip():
+             return None, "❌ Please provide instruction text"
+
+         # Load model
+         model = load_model()
+
+         # Process audio
+         prompt_audio_path = process_audio(prompt_audio)
+         if prompt_audio_path is None:
+             return None, "❌ Failed to process audio"
+
+         # Clean inputs
+         tts_text = tts_text.strip()
+         instruct_text = instruct_text.strip()
+
+         print("\n📝 Generating speech with instruction...")
+         print(f"   TTS text: '{tts_text[:100]}{'...' if len(tts_text) > 100 else ''}'")
+         print(f"   Instruction: '{instruct_text}'")
+         print(f"   Seed: {seed}, Speed: {speed}")
+
+         # Set random seed
          set_all_random_seed(seed)
+
+         # Generate - following official code
          speech_list = []
+         for i in model.inference_instruct2(
+             tts_text,            # Text to synthesize
+             instruct_text,       # Instruction
+             prompt_audio_path,   # Processed prompt audio
+             stream=False,
+             speed=speed
          ):
              speech_list.append(i["tts_speech"])
+
+         # Concatenate all speech segments
+         output_speech = torch.concat(speech_list, dim=1)
+
+         # Clean up
+         if os.path.exists(prompt_audio_path):
+             os.remove(prompt_audio_path)
+
+         print(f"   ✅ Generated audio shape: {output_speech.shape}")
+         print("✅ Speech generated successfully!\n")
+
+         # Return as numpy array for Gradio
+         return (target_sr, output_speech.numpy().flatten()), "✅ Success!"
+
+     except Exception as e:
+         print(f"❌ Error: {e}")
+         import traceback
+         traceback.print_exc()
+
+         # Clean up on error (prompt_audio_path is always bound; see init above)
+         if prompt_audio_path and os.path.exists(prompt_audio_path):
+             os.remove(prompt_audio_path)
+
+         return None, f"❌ Error: {str(e)}"
+
+ # Instruction options (from official code)
+ instruct_options = [
+     "You are a helpful assistant. 请用广东话表达。<|endofprompt|>",
+     "You are a helpful assistant. 请用尽可能快地语速说一句话。<|endofprompt|>",
+     "You are a helpful assistant. 请用正常的语速说一句话。<|endofprompt|>",
+     "You are a helpful assistant. 请用慢一点的语速说一句话。<|endofprompt|>",
+     "You are a helpful assistant. Please speak in a professional tone.<|endofprompt|>",
+     "You are a helpful assistant. Please speak in a friendly tone.<|endofprompt|>",
+ ]
+
+ # Create Gradio interface
+ with gr.Blocks(title="Fun-CosyVoice3 TTS") as demo:
+     gr.Markdown("""
+     # 🎙️ Fun-CosyVoice3-0.5B Text-to-Speech
+
+     Advanced multilingual zero-shot TTS system supporting **9 languages** and **18+ Chinese dialects**.
+
+     Based on the official [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) implementation.
+     """)
+
+     with gr.Tabs():
+         # Tab 1: Zero-Shot TTS
+         with gr.Tab("🎯 Zero-Shot Voice Cloning (3s Fast Cloning)"):
+             gr.Markdown("""
+             ### Clone any voice with 3-10 seconds of reference audio
+
+             **Steps:**
+             1. Upload or record reference audio (≤30s, ≥16kHz)
+             2. Enter the **prompt text** (transcription of the reference audio)
+             3. Enter the **text to synthesize** (what you want the voice to say)
+             4. Click Generate
+             """)
+
+             with gr.Row():
+                 with gr.Column():
+                     zs_tts_text = gr.Textbox(
+                         label="Text to synthesize (what will be spoken)",
+                         placeholder="Enter the text you want to synthesize...",
+                         lines=2,
+                         value="Her handwriting is very neat, which suggests she likes things tidy."
+                     )
+
+                     zs_prompt_audio = gr.Audio(
+                         label="Reference audio (upload or record)",
+                         type="numpy"
+                     )
+
+                     zs_prompt_text = gr.Textbox(
+                         label="Prompt text (transcription of reference audio)",
+                         placeholder="Enter what is said in the reference audio...",
+                         lines=2,
+                         value=""
+                     )
+
+                     with gr.Row():
+                         zs_seed = gr.Number(label="Random seed", value=0, precision=0)
+                         zs_speed = gr.Slider(label="Speed", minimum=0.5, maximum=2.0, value=1.0, step=0.1)
+
+                     zs_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")
+
+                 with gr.Column():
+                     zs_output = gr.Audio(label="Generated speech")
+                     zs_status = gr.Textbox(label="Status", interactive=False)
+
+             zs_btn.click(
+                 fn=zero_shot_tts,
+                 inputs=[zs_tts_text, zs_prompt_text, zs_prompt_audio, zs_seed, zs_speed],
+                 outputs=[zs_output, zs_status]
              )
+
+             gr.Markdown("""
+             **Important:**
+             - **Text to synthesize**: The new text you want to hear in the cloned voice
+             - **Prompt text**: Transcription of what is said in your reference audio
+             - **Reference audio**: 3-10 seconds of clear speech
+
+             **Example:**
+             - Reference audio: Someone saying "Hello, how are you?"
+             - Prompt text: "Hello, how are you?"
+             - Text to synthesize: "This is a test of voice cloning"
+             - Result: "This is a test of voice cloning" in the cloned voice
+             """)
+
+         # Tab 2: Instruction-Based TTS
+         with gr.Tab("📝 Instruction-Based Control (Natural Language)"):
+             gr.Markdown("""
+             ### Control voice characteristics with natural language instructions
+
+             **Steps:**
+             1. Upload or record reference audio
+             2. Select or enter instruction (speed, dialect, emotion)
+             3. Enter text to synthesize
+             4. Click Generate
+             """)
+
+             with gr.Row():
+                 with gr.Column():
+                     inst_tts_text = gr.Textbox(
+                         label="Text to synthesize",
+                         placeholder="Enter your text...",
+                         lines=2,
+                         value="Welcome to the natural language control demo."
+                     )
+
+                     inst_prompt_audio = gr.Audio(
+                         label="Reference audio",
+                         type="numpy"
+                     )
+
+                     inst_text = gr.Dropdown(
+                         label="Instruction",
+                         choices=instruct_options,
+                         value=instruct_options[0]
+                     )
+
+                     with gr.Row():
+                         inst_seed = gr.Number(label="Random seed", value=0, precision=0)
+                         inst_speed = gr.Slider(label="Speed", minimum=0.5, maximum=2.0, value=1.0, step=0.1)
+
+                     inst_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")
+
+                 with gr.Column():
+                     inst_output = gr.Audio(label="Generated speech")
+                     inst_status = gr.Textbox(label="Status", interactive=False)
+
+             inst_btn.click(
+                 fn=instruct_tts,
+                 inputs=[inst_tts_text, inst_text, inst_prompt_audio, inst_seed, inst_speed],
+                 outputs=[inst_output, inst_status]
              )
+
+             gr.Markdown("""
+             **Example instructions:**
+             - "请用广东话表达" (Speak in Cantonese)
+             - "请用尽可能快地语速说" (Speak as fast as possible)
+             - "Please speak in a professional tone"
+             """)
+
+     gr.Markdown("""
+     ---
+     ### 📋 Supported Languages & Dialects
+
+     **Languages:** Chinese, English, Japanese, Korean, German, Spanish, French, Italian, Russian
+
+     **Chinese Dialects:** Guangdong, Minnan, Sichuan, Dongbei, Shanxi, Shanghai, Tianjin, Shandong, and more
+
+     ### ⚡ Performance
+     - Model: Fun-CosyVoice3-0.5B (500M parameters)
+     - Sample Rate: 24kHz
+     - Latency: ~5-10s on CPU, ~2-3s on GPU
+
+     ### 📚 Resources
+     [Paper](https://arxiv.org/abs/2505.17589) • [GitHub](https://github.com/FunAudioLLM/CosyVoice) • [Model](https://huggingface.co/FunAudioLLM/Fun-CosyVoice3-0.5B-2512)
+     """)

  if __name__ == "__main__":
+     print("\n🚀 Launching Gradio interface...")
+     demo.queue(max_size=10, default_concurrency_limit=2)
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7860,
+         show_error=True
+     )
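
For reviewers who want to exercise the new zero-shot path outside Gradio, here is a minimal sketch built only from the calls this diff uses. The `prompt.wav` clip and its transcript are hypothetical placeholders, and the paths assume the `CosyVoice` checkout and model download performed by the initialization steps above:

```python
# Minimal sketch of the zero-shot path added in this PR.
# Assumptions: the CosyVoice repo is checked out and the model is
# downloaded as in Steps 1-6 above; 'prompt.wav' (a 3-10 s, >=16 kHz
# clip) and its transcript are hypothetical placeholders.
import sys
sys.path.insert(0, "CosyVoice")
sys.path.insert(0, "CosyVoice/third_party/Matcha-TTS")

import torch
import torchaudio
from cosyvoice.cli.cosyvoice import AutoModel as CosyVoiceAutoModel

model = CosyVoiceAutoModel(
    model_dir="CosyVoice/pretrained_models/Fun-CosyVoice3-0.5B",
    load_trt=False,
    fp16=False,
)

# The reference transcript must carry the official prefix used by zero_shot_tts().
full_prompt = "You are a helpful assistant.<|endofprompt|>" + "Hello, how are you?"

# inference_zero_shot() is a generator; each item holds one speech segment.
segments = [
    out["tts_speech"]
    for out in model.inference_zero_shot(
        "This is a test of voice cloning.",  # text to synthesize
        full_prompt,                         # prefixed reference transcript
        "prompt.wav",                        # reference audio path
        stream=False,
        speed=1.0,
    )
]
torchaudio.save("output.wav", torch.concat(segments, dim=1), 24000)  # target_sr
```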