kevinwang676 commited on
Commit
ee6d11a
·
verified ·
1 Parent(s): fdd3aa3

Create app_demo.py

Browse files
Files changed (1) hide show
  1. app_demo.py +125 -0
app_demo.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import sys
16
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
17
+ sys.path.append('{}/third_party/Matcha-TTS'.format(ROOT_DIR))
18
+
19
+ #os.system("apt-get install sox libsox-dev")
20
+ os.system("pip install pysox")
21
+ os.system("pip install torch==2.1.0+cu121 torchvision==0.16.0+cu121 torchaudio==2.1.0 torchtext==0.16.0+cpu torchdata==0.7.0 --index-url https://download.pytorch.org/whl/cu121")
22
+
23
+
24
+ import argparse
25
+ import gradio as gr
26
+ import numpy as np
27
+ import torch
28
+ import torchaudio
29
+ import random
30
+ import librosa
31
+
32
+ import logging
33
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
34
+
35
+ import spaces
36
+
37
+ logging.basicConfig(level=logging.DEBUG,
38
+ format='%(asctime)s %(levelname)s %(message)s')
39
+
40
+ def generate_seed():
41
+ seed = random.randint(1, 100000000)
42
+ return {
43
+ "__type__": "update",
44
+ "value": seed
45
+ }
46
+
47
+ def set_all_random_seed(seed):
48
+ random.seed(seed)
49
+ np.random.seed(seed)
50
+ torch.manual_seed(seed)
51
+ torch.cuda.manual_seed_all(seed)
52
+
53
+ max_val = 0.8
54
+ def postprocess(speech, top_db=60, hop_length=220, win_length=440):
55
+ speech, _ = librosa.effects.trim(
56
+ speech, top_db=top_db,
57
+ frame_length=win_length,
58
+ hop_length=hop_length
59
+ )
60
+ if speech.abs().max() > max_val:
61
+ speech = speech / speech.abs().max() * max_val
62
+ speech = torch.concat([speech, torch.zeros(1, int(target_sr * 0.2))], dim=1)
63
+ return speech
64
+
65
+ inference_mode_list = ['预训练音色', '3s极速复刻', '跨语种复刻', '自然语言控制']
66
+ instruct_dict = {'预训练音色': '1. 选择预训练音色\n2.点击生成音频按钮',
67
+ '3s极速复刻': '1. 选择prompt音频文件,或录入prompt音频,若同时提供,优先选择prompt音频文件\n2. 输入prompt文本\n3.点击生成音频按钮',
68
+ '跨语种复刻': '1. 选择prompt音频文件,或录入prompt音频,若同时提供,优先选择prompt音频文件\n2.点击生成音频按钮',
69
+ '自然语言控制': '1. 输入instruct文本\n2.点击生成音频按钮'}
70
+ def change_instruction(mode_checkbox_group):
71
+ return instruct_dict[mode_checkbox_group]
72
+
73
+ @spaces.GPU
74
+ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed):
75
+
76
+ return "jay_short.wav"
77
+
78
+ def main():
79
+ with gr.Blocks() as demo:
80
+ gr.Markdown("### 代码库 [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) 预训练模型 [CosyVoice-300M](https://www.modelscope.cn/models/iic/CosyVoice-300M) [CosyVoice-300M-Instruct](https://www.modelscope.cn/models/iic/CosyVoice-300M-Instruct) [CosyVoice-300M-SFT](https://www.modelscope.cn/models/iic/CosyVoice-300M-SFT)")
81
+ gr.Markdown("#### 请输入需要合成的文本,选择推理模式,并按照提示步骤进行操作")
82
+
83
+ tts_text = gr.Textbox(label="输入合成文本", lines=1, value="我是通义实验室语音团队全新推出的生成式语音大模型,提供舒适自然的语音合成能力。")
84
+
85
+ with gr.Row():
86
+ mode_checkbox_group = gr.Radio(choices=inference_mode_list, label='选择推理模式', value=inference_mode_list[0])
87
+ instruction_text = gr.Text(label="操作步骤", value=instruct_dict[inference_mode_list[0]], scale=0.5)
88
+ sft_dropdown = gr.Dropdown(choices=sft_spk, label='选择预训练音色', value=sft_spk[0], scale=0.25)
89
+ with gr.Column(scale=0.25):
90
+ seed_button = gr.Button(value="\U0001F3B2")
91
+ seed = gr.Number(value=0, label="随机推理种子")
92
+
93
+ with gr.Row():
94
+ prompt_wav_upload = gr.Audio(sources='upload', type='filepath', label='选择prompt音频文件,注意采样率不低于16khz')
95
+ prompt_wav_record = gr.Audio(sources='microphone', type='filepath', label='录制prompt音频文件')
96
+ prompt_text = gr.Textbox(label="输入prompt文本", lines=1, placeholder="请输入prompt文本,需与prompt音频内容一致,暂时不支持自动识别...", value='')
97
+ instruct_text = gr.Textbox(label="输入instruct文本", lines=1, placeholder="请输入instruct文本.", value='')
98
+
99
+ generate_button = gr.Button("生成音频")
100
+
101
+ audio_output = gr.Audio(label="合成音频")
102
+
103
+ seed_button.click(generate_seed, inputs=[], outputs=seed)
104
+ generate_button.click(generate_audio,
105
+ inputs=[tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed],
106
+ outputs=[audio_output])
107
+ mode_checkbox_group.change(fn=change_instruction, inputs=[mode_checkbox_group], outputs=[instruction_text])
108
+ demo.queue()
109
+ demo.launch(share=False, show_error=True)
110
+
111
+ if __name__ == '__main__':
112
+ parser = argparse.ArgumentParser()
113
+ parser.add_argument('--port',
114
+ type=int,
115
+ default=8000)
116
+ parser.add_argument('--model_dir',
117
+ type=str,
118
+ default='iic/CosyVoice-300M',
119
+ help='local path or modelscope repo id')
120
+ args = parser.parse_args()
121
+ cosyvoice = CosyVoice(args.model_dir)
122
+ sft_spk = cosyvoice.list_avaliable_spks()
123
+ prompt_sr, target_sr = 16000, 22050
124
+ default_data = np.zeros(target_sr)
125
+ main()