Kremon96 committed (verified)
Commit 0cb8d70 · 1 parent: fb6f4bb

Upload mtl_tts (4).py

Files changed (1): mtl_tts (4).py (+307, −0)

mtl_tts (4).py ADDED
@@ -0,0 +1,307 @@
from dataclasses import dataclass
from pathlib import Path
import os

import librosa
import torch
import perth
import torch.nn.functional as F
from safetensors.torch import load_file as load_safetensors
from huggingface_hub import snapshot_download

from .models.t3 import T3
from .models.t3.modules.t3_config import T3Config
from .models.s3tokenizer import S3_SR, drop_invalid_tokens
from .models.s3gen import S3GEN_SR, S3Gen
from .models.tokenizers import MTLTokenizer
from .models.voice_encoder import VoiceEncoder
from .models.t3.modules.cond_enc import T3Cond


REPO_ID = "ResembleAI/chatterbox"

# Supported languages for the multilingual model
SUPPORTED_LANGUAGES = {
    "ar": "Arabic",
    "da": "Danish",
    "de": "German",
    "el": "Greek",
    "en": "English",
    "es": "Spanish",
    "fi": "Finnish",
    "fr": "French",
    "he": "Hebrew",
    "hi": "Hindi",
    "it": "Italian",
    "ja": "Japanese",
    "ko": "Korean",
    "ms": "Malay",
    "nl": "Dutch",
    "no": "Norwegian",
    "pl": "Polish",
    "pt": "Portuguese",
    "ru": "Russian",
    "sv": "Swedish",
    "sw": "Swahili",
    "tr": "Turkish",
    "zh": "Chinese",
}


def punc_norm(text: str) -> str:
    """
    Quick cleanup for punctuation produced by LLMs, or for characters
    rarely seen in the training dataset.
    """
    if len(text) == 0:
        return "You need to add some text for me to talk."

    # Capitalise first letter
    if text[0].islower():
        text = text[0].upper() + text[1:]

    # Collapse repeated whitespace into single spaces
    text = " ".join(text.split())

    # Replace uncommon/LLM punctuation
    punc_to_replace = [
        ("...", ", "),
        ("…", ", "),
        (":", ","),
        (" - ", ", "),
        (";", ", "),
        ("—", "-"),
        ("–", "-"),
        (" ,", ","),
        ("“", "\""),
        ("”", "\""),
        ("‘", "'"),
        ("’", "'"),
    ]
    for old_char_sequence, new_char in punc_to_replace:
        text = text.replace(old_char_sequence, new_char)

    # Add a full stop if there is no ending punctuation
    text = text.rstrip(" ")
    sentence_enders = {".", "!", "?", "-", ",", "、", ",", "。", "?", "!"}
    if not any(text.endswith(p) for p in sentence_enders):
        text += "."

    return text
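
# Illustrative, hand-traced examples of the normalization above:
#   punc_norm("hello   world…")  ->  "Hello world,"   (spaces collapsed; "…" -> ", ",
#                                                      trailing comma counts as an ender)
#   punc_norm("wait: really")    ->  "Wait, really."  (":" -> ","; full stop appended)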


@dataclass
class Conditionals:
    """
    Conditionals for T3 and S3Gen
    - T3 conditionals:
        - speaker_emb
        - clap_emb
        - cond_prompt_speech_tokens
        - cond_prompt_speech_emb
        - emotion_adv
    - S3Gen conditionals:
        - prompt_token
        - prompt_token_len
        - prompt_feat
        - prompt_feat_len
        - embedding
    """
    t3: T3Cond
    gen: dict

    def to(self, device):
        self.t3 = self.t3.to(device=device)
        for k, v in self.gen.items():
            if torch.is_tensor(v):
                self.gen[k] = v.to(device=device)
        return self

    def save(self, fpath: Path):
        arg_dict = dict(
            t3=self.t3.__dict__,
            gen=self.gen
        )
        torch.save(arg_dict, fpath)

    @classmethod
    def load(cls, fpath, map_location="cpu"):
        kwargs = torch.load(fpath, map_location=map_location, weights_only=True)
        return cls(T3Cond(**kwargs['t3']), kwargs['gen'])


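# Illustrative round-trip with hypothetical paths: prepare a voice once,
# persist it with Conditionals.save, and reload it in a later session.
#   model.prepare_conditionals("reference.wav")
#   model.conds.save(Path("voice_conds.pt"))
#   model.conds = Conditionals.load("voice_conds.pt").to(model.device)

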
class ChatterboxMultilingualTTS:
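    # Maximum reference-audio lengths used when building conditionals:
    # 6 s at the speech-tokenizer rate S3_SR (the 16 kHz stream fed to the
    # tokenizer and voice encoder) and 10 s at the S3Gen rate S3GEN_SR.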
    ENC_COND_LEN = 6 * S3_SR
    DEC_COND_LEN = 10 * S3GEN_SR

    def __init__(
        self,
        t3: T3,
        s3gen: S3Gen,
        ve: VoiceEncoder,
        tokenizer: MTLTokenizer,
        device: str,
        conds: Conditionals = None,
    ):
        self.sr = S3GEN_SR  # sample rate of synthesized audio
        self.t3 = t3
        self.s3gen = s3gen
        self.ve = ve
        self.tokenizer = tokenizer
        self.device = device
        self.conds = conds
        self.watermarker = perth.PerthImplicitWatermarker()

    @classmethod
    def get_supported_languages(cls):
        """Return a dictionary of supported language codes and names."""
        return SUPPORTED_LANGUAGES.copy()

    @classmethod
    def from_local(cls, ckpt_dir, device) -> 'ChatterboxMultilingualTTS':
        ckpt_dir = Path(ckpt_dir)

        # Determine map_location based on device
        if device in ["cpu", "mps"]:
            map_location = torch.device('cpu')
        else:
            map_location = None

        ve = VoiceEncoder()
        ve.load_state_dict(
            torch.load(ckpt_dir / "ve.pt", map_location=map_location, weights_only=True)
        )
        ve.to(device).eval()

        t3 = T3(T3Config.multilingual())
        t3_state = load_safetensors(ckpt_dir / "t3_mtl23ls_v2.safetensors")
        if "model" in t3_state.keys():
            t3_state = t3_state["model"][0]
        t3.load_state_dict(t3_state)
        t3.to(device).eval()

        s3gen = S3Gen()
        s3gen.load_state_dict(
            torch.load(ckpt_dir / "s3gen.pt", map_location=map_location, weights_only=True)
        )
        s3gen.to(device).eval()

        tokenizer = MTLTokenizer(
            str(ckpt_dir / "grapheme_mtl_merged_expanded_v1.json")
        )

        conds = None
        if (builtin_voice := ckpt_dir / "conds.pt").exists():
            conds = Conditionals.load(builtin_voice, map_location=map_location).to(device)

        return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
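
    # from_pretrained fetches only the files listed in allow_patterns from the
    # Hugging Face Hub (passing the HF_TOKEN env var, if set) and delegates to
    # from_local on the cached snapshot directory.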
    @classmethod
    def from_pretrained(cls, device: torch.device) -> 'ChatterboxMultilingualTTS':
        ckpt_dir = Path(
            snapshot_download(
                repo_id=REPO_ID,
                repo_type="model",
                revision="main",
                allow_patterns=[
                    "ve.pt",
                    "t3_mtl23ls_v2.safetensors",
                    "s3gen.pt",
                    "grapheme_mtl_merged_expanded_v1.json",
                    "conds.pt",
                    "Cangjie5_TC.json",
                ],
                token=os.getenv("HF_TOKEN"),
            )
        )
        return cls.from_local(ckpt_dir, device)

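    # prepare_conditionals derives everything generate() needs from a single
    # reference wav: an S3Gen reference dict, a speech-token prompt for T3,
    # and a voice-encoder speaker embedding.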
    def prepare_conditionals(self, wav_fpath, exaggeration=0.5):
        # Load reference wav
        s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR)

        ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR)

        s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN]
        s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device)

        # Speech cond prompt tokens
        t3_cond_prompt_tokens = None
        if plen := self.t3.hp.speech_cond_prompt_len:
            s3_tokzr = self.s3gen.tokenizer
            t3_cond_prompt_tokens, _ = s3_tokzr.forward([ref_16k_wav[:self.ENC_COND_LEN]], max_len=plen)
            t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens).to(self.device)

        # Voice-encoder speaker embedding
        ve_embed = torch.from_numpy(self.ve.embeds_from_wavs([ref_16k_wav], sample_rate=S3_SR))
        ve_embed = ve_embed.mean(axis=0, keepdim=True).to(self.device)

        t3_cond = T3Cond(
            speaker_emb=ve_embed,
            cond_prompt_speech_tokens=t3_cond_prompt_tokens,
            emotion_adv=exaggeration * torch.ones(1, 1, 1),
        ).to(device=self.device)
        self.conds = Conditionals(t3_cond, s3gen_ref_dict)
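
    # generate() normalizes and tokenizes the text, duplicates the token
    # sequence for classifier-free guidance, samples speech tokens with T3,
    # vocodes them with S3Gen, and watermarks the waveform before returning it.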
    def generate(
        self,
        text,
        language_id,
        audio_prompt_path=None,
        exaggeration=0.5,
        cfg_weight=0.5,
        temperature=0.8,
        repetition_penalty=2.0,
        min_p=0.05,
        top_p=1.0,
    ):
        # Validate language_id
        if language_id and language_id.lower() not in SUPPORTED_LANGUAGES:
            supported_langs = ", ".join(SUPPORTED_LANGUAGES.keys())
            raise ValueError(
                f"Unsupported language_id '{language_id}'. "
                f"Supported languages: {supported_langs}"
            )

        if audio_prompt_path:
            self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration)
        else:
            assert self.conds is not None, "Please `prepare_conditionals` first or specify `audio_prompt_path`"

        # Update exaggeration if needed
        if float(exaggeration) != float(self.conds.t3.emotion_adv[0, 0, 0].item()):
            _cond: T3Cond = self.conds.t3
            self.conds.t3 = T3Cond(
                speaker_emb=_cond.speaker_emb,
                cond_prompt_speech_tokens=_cond.cond_prompt_speech_tokens,
                emotion_adv=exaggeration * torch.ones(1, 1, 1),
            ).to(device=self.device)

        # Norm and tokenize text
        text = punc_norm(text)
        text_tokens = self.tokenizer.text_to_tokens(
            text, language_id=language_id.lower() if language_id else None
        ).to(self.device)
        text_tokens = torch.cat([text_tokens, text_tokens], dim=0)  # Need two seqs for CFG

        sot = self.t3.hp.start_text_token
        eot = self.t3.hp.stop_text_token
        text_tokens = F.pad(text_tokens, (1, 0), value=sot)
        text_tokens = F.pad(text_tokens, (0, 1), value=eot)

        with torch.inference_mode():
            speech_tokens = self.t3.inference(
                t3_cond=self.conds.t3,
                text_tokens=text_tokens,
                max_new_tokens=1000,  # TODO: use the value in config
                temperature=temperature,
                cfg_weight=cfg_weight,
                repetition_penalty=repetition_penalty,
                min_p=min_p,
                top_p=top_p,
            )
            # Extract only the conditional batch.
            speech_tokens = speech_tokens[0]

            # TODO: output becomes 1D
            speech_tokens = drop_invalid_tokens(speech_tokens)
            speech_tokens = speech_tokens.to(self.device)

            wav, _ = self.s3gen.inference(
                speech_tokens=speech_tokens,
                ref_dict=self.conds.gen,
            )
            wav = wav.squeeze(0).detach().cpu().numpy()
        watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr)
        return torch.from_numpy(watermarked_wav).unsqueeze(0)
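
A minimal usage sketch (not part of the uploaded file). It assumes the module is importable as chatterbox.mtl_tts, as the package-relative imports suggest, and that torchaudio is available for writing the output; the reference-clip path is hypothetical.

import torch
import torchaudio
from chatterbox.mtl_tts import ChatterboxMultilingualTTS  # assumed import path

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ChatterboxMultilingualTTS.from_pretrained(device)

# generate() returns a (1, num_samples) float tensor at model.sr
wav = model.generate(
    "Bonjour tout le monde !",
    language_id="fr",
    audio_prompt_path="reference.wav",  # hypothetical reference clip; resampled on load
)
torchaudio.save("output.wav", wav, model.sr)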