alakxender committed
Commit d81be79
Parent: 91ab2cd
Files changed (5)
  1. app.py +0 -2
  2. cbox_dv.py +0 -451
  3. cbox_test.py +0 -79
  4. chatterbox_dhivehi.py +0 -210
  5. requirements.txt +0 -1
app.py CHANGED
@@ -16,7 +16,6 @@ from csm1b_dv import (
     change_model_and_update_ui
 )
 from dia_1_6B_dv import get_dia_1_6B_tab
-from cbox_dv import get_cbox_dv
 
 with gr.Blocks(
     title="Dhivehi (Thaana) Text-to-Speech",
@@ -38,7 +37,6 @@ with gr.Blocks(
 ) as app:
     get_csm1b_tab()
     get_dia_1_6B_tab()
-    get_cbox_dv()
 
 if __name__ == "__main__":
     app.launch(share=False)
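With the cbox_dv import and tab removed, app.py keeps only the CSM-1B and Dia-1.6B tabs. A minimal sketch of the resulting file, assuming only what the diff context shows (theme, CSS, and other gr.Blocks keyword arguments omitted):

```python
import gradio as gr

from csm1b_dv import get_csm1b_tab, change_model_and_update_ui
from dia_1_6B_dv import get_dia_1_6B_tab

with gr.Blocks(title="Dhivehi (Thaana) Text-to-Speech") as app:
    get_csm1b_tab()     # CSM-1B tab
    get_dia_1_6B_tab()  # Dia-1.6B tab

if __name__ == "__main__":
    app.launch(share=False)
```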
cbox_dv.py DELETED
@@ -1,451 +0,0 @@
-from pathlib import Path
-import os
-try:
-    from huggingface_hub import snapshot_download
-    _target = Path.home() / ".chatterbox-tts-dhivehi"
-    if not (_target.exists() and any(_target.rglob("*"))):
-        snapshot_download(
-            repo_id="alakxender/chatterbox-tts-dhivehi",
-            local_dir=str(_target),
-            local_dir_use_symlinks=False,
-            resume_download=True
-        )
-except Exception as _e:
-    pass
-
-from chatterbox.tts import ChatterboxTTS
-import torchaudio
-from pathlib import Path
-import torch
-import random
-import numpy as np
-import gradio as gr
-import tempfile
-import os
-import chatterbox_dhivehi
-import warnings
-
-warnings.filterwarnings("ignore")
-
-chatterbox_dhivehi.extend_dhivehi()
-
-class TTSApp:
-    def __init__(self, checkpoint=f"{_target}/kn_cbox"):
-        self.checkpoint = checkpoint
-        self.model = None
-        self.load_model()
-
-    def load_model(self):
-        """Load the TTS model"""
-        try:
-            print(f"Loading model with checkpoint: {self.checkpoint}")
-            self.model = ChatterboxTTS.from_dhivehi(
-                ckpt_dir=Path(self.checkpoint),
-                device="cuda" if torch.cuda.is_available() else "cpu"
-            )
-            print("Model loaded successfully!")
-        except Exception as e:
-            print(f"Error loading model: {e}")
-            raise e
-
-    def set_seed(self, seed: int):
-        """Set random seed for reproducibility"""
-        torch.manual_seed(seed)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed(seed)
-            torch.cuda.manual_seed_all(seed)
-        random.seed(seed)
-        np.random.seed(seed)
-
-    def generate_speech(self,
-                        text,
-                        reference_audio,
-                        exaggeration=0.5,
-                        temperature=0.1,
-                        cfg_weight=0.5,
-                        seed=42):
-        """Generate speech from text using voice cloning"""
-
-        # Clean the input text
-        text = self.clean_text(text)
-
-        if not text:
-            return None, "Please enter some text to generate speech."
-
-        if self.model is None:
-            return None, "Model not loaded. Please check your model paths."
-
-        try:
-            # Set seed for reproducibility
-            self.set_seed(seed)
-
-            # Handle reference audio - make it optional
-            audio_prompt_path = reference_audio
-
-            print(f"Generating audio for: {text[:50]}...")
-            if audio_prompt_path:
-                print(f"Using reference audio: {audio_prompt_path}")
-            else:
-                print("Generating without reference audio")
-
-            # Generate audio - handle optional reference audio
-            if audio_prompt_path:
-                audio = self.model.generate(
-                    text=text,
-                    audio_prompt_path=audio_prompt_path,
-                    exaggeration=exaggeration,
-                    temperature=temperature,
-                    cfg_weight=cfg_weight,
-                )
-            else:
-                # Try without reference audio
-                try:
-                    audio = self.model.generate(
-                        text=text,
-                        exaggeration=exaggeration,
-                        temperature=temperature,
-                        cfg_weight=cfg_weight,
-                    )
-                except TypeError:
-                    # If the model requires audio_prompt_path, try with empty string
-                    audio = self.model.generate(
-                        text=text,
-                        audio_prompt_path="",
-                        exaggeration=exaggeration,
-                        temperature=temperature,
-                        cfg_weight=cfg_weight,
-                    )
-
-            # Save to temporary file
-            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
-                output_path = tmp_file.name
-
-            torchaudio.save(output_path, audio, 24000)
-
-            return output_path, f"Successfully generated speech! Audio length: {audio.shape[1]/24000:.2f} seconds"
-
-        except Exception as e:
-            error_msg = f"Error generating speech: {str(e)}"
-            print(error_msg)
-            return None, error_msg
-
-    def clean_text(self, text):
-        """Clean text by removing newlines at start/end, double spaces, and extra whitespace"""
-        import re
-
-        # Remove newlines at start and end
-        text = text.strip('\n\r')
-
-        # Replace multiple spaces with single space
-        text = re.sub(r'\s+', ' ', text)
-
-        # Strip leading and trailing spaces
-        text = text.strip()
-
-        return text
-
-    def split_sentences(self, text):
-        """Split text into sentences based on periods, ensuring each sentence is at least 150 characters"""
-        # Clean the input text first
-        text = self.clean_text(text)
-
-        # First, split by periods normally
-        initial_sentences = []
-        current_sentence = ""
-
-        for char in text:
-            current_sentence += char
-            if char == '.':
-                # Add sentence if it's not empty after stripping spaces from both sides
-                stripped_sentence = current_sentence.strip()
-                if stripped_sentence:
-                    initial_sentences.append(stripped_sentence)
-                current_sentence = ""
-
-        # Add remaining text if any (without period), stripped of spaces from both sides
-        stripped_remaining = current_sentence.strip()
-        if stripped_remaining:
-            initial_sentences.append(stripped_remaining)
-
-        # If we only have one sentence, return it
-        if len(initial_sentences) <= 1:
-            return initial_sentences
-
-        # Now combine sentences until each is at least 150 characters
-        final_sentences = []
-        combined_sentence = ""
-
-        for sentence in initial_sentences:
-            if combined_sentence:
-                combined_sentence += " " + sentence
-            else:
-                combined_sentence = sentence
-
-            # If combined sentence is >= 150 chars, add it to final list
-            if len(combined_sentence) >= 150:
-                final_sentences.append(combined_sentence.strip())
-                combined_sentence = ""
-
-        # Add any remaining combined sentence (even if < 150 chars)
-        if combined_sentence.strip():
-            final_sentences.append(combined_sentence.strip())
-
-        return final_sentences
-
-    def generate_speech_multi_sentence(self,
-                                       text,
-                                       reference_audio,
-                                       exaggeration=0.5,
-                                       temperature=0.1,
-                                       cfg_weight=0.5,
-                                       seed=42):
-        """Generate speech from text with multi-sentence support and progress tracking"""
-
-        # Clean the input text
-        text = self.clean_text(text)
-
-        if not text:
-            yield None, "Please enter some text to generate speech."
-            return
-
-        if self.model is None:
-            yield None, "Model not loaded. Please check your model paths."
-            return
-
-        # Split text into sentences
-        sentences = self.split_sentences(text)
-
-        # If only one sentence or no periods, use regular method
-        if len(sentences) <= 1:
-            yield None, "🎵 Generating single sentence..."
-            result_audio, result_status = self.generate_speech(text, reference_audio, exaggeration, temperature, cfg_weight, seed)
-            yield result_audio, result_status
-            return
-
-        try:
-            # Set seed for reproducibility
-            self.set_seed(seed)
-
-            # Handle reference audio - make it optional
-            audio_prompt_path = reference_audio
-
-            yield None, f"🚀 Starting generation for {len(sentences)} sentences..."
-            print(f"Processing {len(sentences)} sentences...")
-
-            all_audio_segments = []
-            total_duration = 0
-
-            for i, sentence in enumerate(sentences):
-                # Calculate progress percentage
-                progress_percent = int((i / len(sentences)) * 90)  # Reserve last 10% for combining
-                yield None, f"🎵 Generating sentence {i+1}/{len(sentences)} ({progress_percent}%): {sentence[:50]}..."
-
-                print(f"Generating audio for sentence {i+1}/{len(sentences)}: {sentence[:50]}...")
-
-                # Generate audio for this sentence
-                try:
-                    if audio_prompt_path:
-                        audio = self.model.generate(
-                            text=sentence,
-                            audio_prompt_path=audio_prompt_path,
-                            exaggeration=exaggeration,
-                            temperature=temperature,
-                            cfg_weight=cfg_weight,
-                        )
-                    else:
-                        # Try without reference audio
-                        try:
-                            audio = self.model.generate(
-                                text=sentence,
-                                exaggeration=exaggeration,
-                                temperature=temperature,
-                                cfg_weight=cfg_weight,
-                            )
-                        except TypeError:
-                            # If the model requires audio_prompt_path, try with empty string
-                            audio = self.model.generate(
-                                text=sentence,
-                                audio_prompt_path="",
-                                exaggeration=exaggeration,
-                                temperature=temperature,
-                                cfg_weight=cfg_weight,
-                            )
-                except Exception as model_error:
-                    # If the model fails due to missing reference audio, try with default behavior
-                    if "reference_voice.wav not found" in str(model_error) or "No reference audio provided" in str(model_error):
-                        print("Attempting generation without reference audio...")
-                        # Try different approaches for models that don't support None reference audio
-                        try:
-                            # Some models might accept an empty string
-                            audio = self.model.generate(
-                                text=sentence,
-                                audio_prompt_path="",
-                                exaggeration=exaggeration,
-                                temperature=temperature,
-                                cfg_weight=cfg_weight,
-                            )
-                        except:
-                            # If that fails, try without the audio_prompt_path parameter entirely
-                            audio = self.model.generate(
-                                text=sentence,
-                                exaggeration=exaggeration,
-                                temperature=temperature,
-                                cfg_weight=cfg_weight,
-                            )
-                    else:
-                        raise model_error
-
-                all_audio_segments.append(audio)
-                total_duration += audio.shape[1] / 24000
-
-            # Concatenate all audio segments
-            yield None, "🔧 Combining audio segments (95%)..."
-            print("Combining audio segments...")
-            combined_audio = torch.cat(all_audio_segments, dim=1)
-
-            # Save to temporary file
-            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
-                output_path = tmp_file.name
-
-            torchaudio.save(output_path, combined_audio, 24000)
-            print("Multi-sentence processing complete!")
-
-            yield output_path, f"✅ Successfully generated speech from {len(sentences)} sentences! Total audio length: {total_duration:.2f} seconds"
-
-        except Exception as e:
-            error_msg = f"❌ Error generating multi-sentence speech: {str(e)}"
-            print(error_msg)
-            yield None, error_msg
-
-def get_cbox_dv():
-    """Create the Gradio interface"""
-
-    # Initialize the TTS app
-    #tts_app = TTSApp()
-
-    # Sample texts in Dhivehi
-    sample_texts = [
-        "ކާޑު ނުލައި ފައިސާ ދެއްކޭ ނެޝަނަލް ކިއުއާރް ކޯޑް އެމްއެމްއޭ އިން ތައާރަފްކުރަނީ",
-        """ފުޓްބޯޅަ ސްކޫލްގެ ބިމާއި ގުދަންބަރި ބިމުގައި އިމާރާތް ކުރުމުގެ މަސައްކަތް ހުއްޓާލަން އަންގައިފި...
-Construction work on football school land and warehouse land has been ordered to stop""",
-        "ސިވިލް ސާވިސްގެ ހިދުމަތުގެ މުއްދަތު ގުނުމުގައި ކުންފުނިތަކާއި އިދާރާތަކަށް ހިދުމަތްކުރި މުއްދަތު ހިމަނަނީ",
-        """އެ ރަށުގެ ބިން ހިއްކުމާއި ބަނދަރުގެ ނެރު ބަދަލުކުރުމާއި ގޮނޑުދޮށް ހިމާޔަތް ކުރުމުގެ މަސައްކަތް އެމްޓީސީސީއާ މިނިސްޓްރީން ހަވާލުކުރީ މިދިޔަ މަހު ރައީސް އެ ރަށަށް ކުރެއްވި ދަތުރުފުޅުގައި.
-The ministry handed over the land reclamation, replacement of the port canal and beach protection to MTCC during the President's visit to the village last month"""
-    ]
-
-    with gr.Tab("🎤 ChatterboxTTS"):
-        gr.Markdown("# 🎤 ChatterboxTTS - Dhivehi Text-to-Speech with Voice Cloning")
-        gr.Markdown("Generate natural-sounding Dhivehi speech with voice cloning capabilities.")
-
-        # Row 1: Text input and Reference audio
-        with gr.Row():
-            text_input = gr.Textbox(
-                label="Text to Convert",
-                placeholder="Enter Dhivehi text here...",
-                lines=6,
-                value=sample_texts[0],
-                rtl=True,
-                elem_classes=["textbox1"]
-            )
-            reference_audio = gr.Audio(
-                label="Reference Voice Audio (optional - for voice cloning)",
-                type="filepath",
-                sources=["upload", "microphone"],
-            )
-
-        # Row 2: Example buttons
-        gr.Markdown("**Quick Examples:**")
-        with gr.Row():
-            sample_btn1 = gr.Button("Sample 1", size="sm")
-            sample_btn2 = gr.Button("Sample 2", size="sm")
-            sample_btn3 = gr.Button("Sample 3", size="sm")
-            sample_btn4 = gr.Button("Sample 4", size="sm")
-
-        # Row 3: Advanced settings
-        with gr.Accordion("Advanced Settings", open=False):
-            with gr.Row():
-                exaggeration = gr.Slider(
-                    minimum=0.0,
-                    maximum=2.0,
-                    value=0.5,
-                    step=0.1,
-                    label="Exaggeration",
-                    info="Controls expressiveness"
-                )
-                temperature = gr.Slider(
-                    minimum=0.01,
-                    maximum=1.0,
-                    value=0.35,
-                    step=0.01,
-                    label="Temperature",
-                    info="Controls randomness"
-                )
-                cfg_weight = gr.Slider(
-                    minimum=0.0,
-                    maximum=2.0,
-                    value=0.3,
-                    step=0.1,
-                    label="CFG Weight",
-                    info="Classifier-free guidance weight"
-                )
-                seed = gr.Slider(
-                    minimum=0,
-                    maximum=9999,
-                    value=42,
-                    step=1,
-                    label="Seed",
-                    info="For reproducible results"
-                )
-
-        # Row 4: Generate button
-        generate_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")
-
-        # Row 5: Output section
-        with gr.Row():
-            with gr.Column():
-                output_audio = gr.Audio(label="Generated Speech", type="filepath")
-                status_message = gr.Textbox(label="Status", interactive=False)
-
-        # Event handlers
-        def set_sample_text(sample_idx):
-            return sample_texts[sample_idx]
-
-        sample_btn1.click(lambda: set_sample_text(0), outputs=[text_input])
-        sample_btn2.click(lambda: set_sample_text(1), outputs=[text_input])
-        sample_btn3.click(lambda: set_sample_text(2), outputs=[text_input])
-        sample_btn4.click(lambda: set_sample_text(3), outputs=[text_input])
-
-        def generate_with_progress(text, reference_audio, exaggeration, temperature, cfg_weight, seed):
-            """Generate speech with streaming progress updates"""
-            # Use the streaming generator from the TTS app
-            #for result_audio, result_status in tts_app.generate_speech_multi_sentence(
-            #    text, reference_audio, exaggeration, temperature, cfg_weight, seed
-            #):
-            #    yield result_audio, result_status
-
-        generate_btn.click(
-            fn=generate_with_progress,
-            inputs=[text_input, reference_audio, exaggeration, temperature, cfg_weight, seed],
-            outputs=[output_audio, status_message]
-        )
-
-        # Instructions
-        with gr.Accordion("Tips", open=False):
-            gr.Markdown("""
-            ### General Use (TTS and Voice Agents):
-            - The default settings (exaggeration=0.5, cfg=0.5) work well for most prompts.
-            - If the reference speaker has a fast speaking style, lowering cfg to around 0.3 can improve pacing.
-
-            ### Expressive or Dramatic Speech:
-            - Try lower cfg values (e.g. ~0.3) and increase exaggeration to around 0.7 or higher.
-            - Higher exaggeration tends to speed up speech; reducing cfg helps compensate with slower, more deliberate pacing.
-
-            ### Language Transfer Notes:
-            - Ensure that the reference clip matches the specified language tag. Otherwise, language transfer outputs may inherit the accent of the reference clip's language.
-            - To mitigate this, set the CFG weight to 0.
-
-            ### Additional Tips:
-            - For best voice cloning results, use clear audio with minimal background noise
-            - The reference audio should be 3-10 seconds long
-            - Use the same seed value for reproducible results
-            """)
cbox_test.py DELETED
@@ -1,79 +0,0 @@
-from pathlib import Path
-import os
-try:
-    from huggingface_hub import snapshot_download
-    _target = Path.home() / ".chatterbox-tts-dhivehi"
-    if not (_target.exists() and any(_target.rglob("*"))):
-        snapshot_download(
-            repo_id="alakxender/chatterbox-tts-dhivehi",
-            local_dir=str(_target),
-            local_dir_use_symlinks=False,
-            resume_download=True
-        )
-except Exception as _e:
-    pass
-
-from chatterbox.tts import ChatterboxTTS
-import chatterbox_dhivehi
-import torchaudio
-import torch
-import numpy as np
-import random
-# ---- User settings (edit these) ----
-CKPT_DIR = f"{_target}/kn_cbox"  # path to your finetuned checkpoint dir
-REF_WAV = f"{_target}/samples/reference_audio.wav"  # optional 3–10s clean reference; "" to disable
-#REF_WAV = ""
-TEXT = "މި ރިޕޯޓާ ގުޅޭ ގޮތުން އެނިމަލް ވެލްފެއާ މިނިސްޓްރީން އަދި ވާހަކައެއް ނުދައްކާ"  # sample Dhivehi text
-TEXT = f"{TEXT}, The Animal Welfare Ministry has not yet commented on the report"
-EXAGGERATION = 0.4
-TEMPERATURE = 0.3
-CFG_WEIGHT = 0.7
-SEED = 42
-SAMPLE_RATE = 24000
-OUT_PATH = "out.wav"
-# ------------------------------------
-
-# Extend Dhivehi support from local file
-chatterbox_dhivehi.extend_dhivehi()
-
-# Seed for reproducibility
-torch.manual_seed(SEED)
-if torch.cuda.is_available():
-    torch.cuda.manual_seed(SEED)
-    torch.cuda.manual_seed_all(SEED)
-random.seed(SEED)
-np.random.seed(SEED)
-
-# Load model
-device = "cuda" if torch.cuda.is_available() else "cpu"
-print(f"Loading ChatterboxTTS from: {CKPT_DIR} on {device}")
-model = ChatterboxTTS.from_dhivehi(ckpt_dir=Path(CKPT_DIR), device=device)
-print("Model loaded.")
-
-# Generate (reference audio optional)
-print(f"Generating audio... ref={'yes' if REF_WAV else 'no'}")
-gen_kwargs = dict(
-    text=TEXT,
-    exaggeration=EXAGGERATION,
-    temperature=TEMPERATURE,
-    cfg_weight=CFG_WEIGHT,
-)
-
-try:
-    if REF_WAV:
-        gen_kwargs["audio_prompt_path"] = REF_WAV
-        audio = model.generate(**gen_kwargs)
-    else:
-        # Try without reference first; if backend requires audio_prompt_path, fall back to ""
-        try:
-            audio = model.generate(**gen_kwargs)
-        except TypeError:
-            gen_kwargs["audio_prompt_path"] = ""
-            audio = model.generate(**gen_kwargs)
-except Exception as e:
-    raise RuntimeError(f"Generation failed: {e}")
-
-# Save
-torchaudio.save(OUT_PATH, audio, SAMPLE_RATE)
-dur = audio.shape[1] / SAMPLE_RATE
-print(f"Saved {OUT_PATH} ({dur:.2f}s)")
chatterbox_dhivehi.py DELETED
@@ -1,210 +0,0 @@
-# chatterbox_dhivehi.py
-"""
-Dhivehi extension for ChatterboxTTS.
-
-Requires: chatterbox-tts 0.1.4 (not tested on any other version)
-
-Adds:
-  - load_t3_with_vocab(state_dict, device, force_vocab_size): load T3 with a specific vocab size,
-    resizing both the embedding and the projection head, and padding checkpoint weights if needed.
-  - from_dhivehi(...): classmethod for building a ChatterboxTTS from a checkpoint directory,
-    using load_t3_with_vocab under the hood (defaults to vocab=1199).
-  - extend_dhivehi(): attach the above to ChatterboxTTS (idempotent).
-
-Usage in app.py:
-    import chatterbox_dhivehi
-    chatterbox_dhivehi.extend_dhivehi()
-
-    self.model = ChatterboxTTS.from_dhivehi(
-        ckpt_dir=Path(self.checkpoint),
-        device="cuda" if torch.cuda.is_available() else "cpu",
-        force_vocab_size=2000,
-    )
-"""
-
-from __future__ import annotations
-import logging
-from pathlib import Path
-from typing import Optional, Union
-
-import torch
-import torch.nn as nn
-from safetensors.torch import load_file
-
-# Core chatterbox imports
-from chatterbox.tts import ChatterboxTTS, Conditionals
-from chatterbox.models.t3 import T3
-from chatterbox.models.s3gen import S3Gen
-from chatterbox.models.tokenizers import EnTokenizer
-from chatterbox.models.voice_encoder import VoiceEncoder
-
-
-# Helpers
-
-def _expand_or_trim_rows(t: torch.Tensor, new_rows: int, init_std: float = 0.02) -> torch.Tensor:
-    """
-    Return a tensor with first dimension resized to `new_rows`.
-    If expanding, newly added rows are randomly initialized N(0, init_std).
-    """
-    old_rows = t.shape[0]
-    if new_rows == old_rows:
-        return t.clone()
-    if new_rows < old_rows:
-        return t[:new_rows].clone()
-    # expand
-    out = t.new_empty((new_rows,) + t.shape[1:])
-    out[:old_rows] = t
-    out[old_rows:].normal_(mean=0.0, std=init_std)
-    return out
-
-
-def _prepare_resized_state_dict(sd: dict, new_vocab: int, init_std: float = 0.02) -> dict:
-    """
-    Create a modified copy of `sd` where text_emb/text_head weights (and bias) match `new_vocab`.
-    """
-    sd = sd.copy()
-
-    # text embedding: [vocab, dim]
-    if "text_emb.weight" in sd:
-        sd["text_emb.weight"] = _expand_or_trim_rows(sd["text_emb.weight"], new_vocab, init_std)
-
-    # text projection head: Linear(out=vocab, in=dim)
-    if "text_head.weight" in sd:
-        sd["text_head.weight"] = _expand_or_trim_rows(sd["text_head.weight"], new_vocab, init_std)
-    if "text_head.bias" in sd:
-        bias = sd["text_head.bias"]
-        if bias.ndim == 1:
-            sd["text_head.bias"] = _expand_or_trim_rows(bias.unsqueeze(1), new_vocab, init_std).squeeze(1)
-
-    return sd
-
-
-def _resize_model_vocab_layers(model: T3, new_vocab: int, dim: Optional[int] = None) -> None:
-    """
-    Rebuild model.text_emb and model.text_head to match `new_vocab`.
-    Embedding dim is inferred from existing layers if not provided.
-    """
-    if dim is None:
-        if hasattr(model, "text_emb") and isinstance(model.text_emb, nn.Embedding):
-            dim = model.text_emb.embedding_dim
-        elif hasattr(model, "text_head") and isinstance(model.text_head, nn.Linear):
-            dim = model.text_head.in_features
-        else:
-            raise RuntimeError("Cannot infer text embedding dimension from T3 model.")
-    model.text_emb = nn.Embedding(new_vocab, dim)
-    model.text_head = nn.Linear(dim, new_vocab, bias=True)
-
-
-# Public API
-
-def load_t3_with_vocab(
-    t3_state_dict: dict,
-    device: str = "cpu",
-    *,
-    force_vocab_size: Optional[int] = None,
-    init_std: float = 0.02,
-) -> T3:
-    """
-    Load a T3 model with a specified vocabulary size.
-
-    - Removes a leading "t3." prefix on state_dict keys if present.
-    - Resizes BOTH `text_emb` and `text_head` to `force_vocab_size` (or to the checkpoint vocab if not forced).
-    - Pads checkpoint weights when the target vocab is larger than the checkpoint's.
-
-    Args:
-        t3_state_dict: state dict loaded from t3_cfg.safetensors (or similar).
-        device: "cpu", "cuda", or "mps".
-        force_vocab_size: desired vocab size (e.g., 2000 for Dhivehi-extended models).
-        init_std: std for random init of padded rows.
-
-    Returns:
-        T3: model moved to `device` and set to eval().
-    """
-    logger = logging.getLogger(__name__)
-
-    # Strip "t3." prefix if present
-    if any(k.startswith("t3.") for k in t3_state_dict.keys()):
-        t3_state_dict = {k[len("t3."):]: v for k, v in t3_state_dict.items()}
-
-    # derive checkpoint vocab if available
-    ckpt_vocab_size = None
-    if "text_emb.weight" in t3_state_dict and t3_state_dict["text_emb.weight"].ndim == 2:
-        ckpt_vocab_size = int(t3_state_dict["text_emb.weight"].shape[0])
-    elif "text_head.weight" in t3_state_dict and t3_state_dict["text_head.weight"].ndim == 2:
-        ckpt_vocab_size = int(t3_state_dict["text_head.weight"].shape[0])
-
-    target_vocab = int(force_vocab_size) if force_vocab_size is not None else ckpt_vocab_size
-    if target_vocab is None:
-        raise RuntimeError("Could not determine vocab size. Provide force_vocab_size.")
-
-    logger.info(f"Loading T3 with vocab={target_vocab} (ckpt_vocab={ckpt_vocab_size})")
-
-    # Build a base model and resize layers to accept the incoming state dict
-    t3 = T3()
-    _resize_model_vocab_layers(t3, target_vocab)
-
-    # Patch the checkpoint tensors to the target vocab
-    patched_sd = _prepare_resized_state_dict(t3_state_dict, target_vocab, init_std)
-
-    # Load (strict=False to tolerate benign extra/missing keys)
-    t3.load_state_dict(patched_sd, strict=False)
-    return t3.to(device).eval()
-
-
-def from_dhivehi(
-    cls,
-    *,
-    ckpt_dir: Union[str, Path],
-    device: str = "cpu",
-    force_vocab_size: int = 1199,
-):
-    """
-    Construct a Dhivehi-extended ChatterboxTTS from a checkpoint directory.
-
-    Expected files in `ckpt_dir`:
-      - ve.safetensors
-      - t3_cfg.safetensors
-      - s3gen.safetensors
-      - tokenizer.json
-      - conds.pt (optional)
-    """
-    ckpt_dir = Path(ckpt_dir)
-
-    # Voice encoder
-    ve = VoiceEncoder()
-    ve.load_state_dict(load_file(ckpt_dir / "ve.safetensors"))
-    ve.to(device).eval()
-
-    # T3 with Dhivehi vocab extension
-    t3_state = load_file(ckpt_dir / "t3_cfg.safetensors")
-    t3 = load_t3_with_vocab(t3_state, device=device, force_vocab_size=force_vocab_size)
-
-    # S3Gen
-    s3gen = S3Gen()
-    s3gen.load_state_dict(load_file(ckpt_dir / "s3gen.safetensors"), strict=False)
-    s3gen.to(device).eval()
-
-    # Tokenizer
-    tokenizer = EnTokenizer(str(ckpt_dir / "tokenizer.json"))
-
-    # Optional conditionals
-    conds = None
-    conds_path = ckpt_dir / "conds.pt"
-    if conds_path.exists():
-        # Always safe-load to CPU first; .to(device) later
-        conds = Conditionals.load(conds_path, map_location="cpu").to(device)
-
-    return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
-
-
-def extend_dhivehi():
-    """
-    Attach Dhivehi-specific helpers to ChatterboxTTS (idempotent).
-    - ChatterboxTTS.load_t3_with_vocab (staticmethod)
-    - ChatterboxTTS.from_dhivehi (classmethod)
-    """
-    if getattr(ChatterboxTTS, "_dhivehi_extended", False):
-        return
-    ChatterboxTTS.load_t3_with_vocab = staticmethod(load_t3_with_vocab)
-    ChatterboxTTS.from_dhivehi = classmethod(from_dhivehi)
-    ChatterboxTTS._dhivehi_extended = True
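The vocabulary resize in `_expand_or_trim_rows` keeps the pretrained rows and draws any newly added rows from N(0, 0.02); the same rule is applied to `text_emb.weight`, `text_head.weight`, and (via unsqueeze/squeeze) `text_head.bias`. A quick runnable check of the expand path, mirroring the helper's tensor ops:

```python
import torch

old = torch.arange(12, dtype=torch.float32).reshape(3, 4)  # checkpoint vocab = 3, dim = 4
new_rows = 5                                               # target (extended) vocab = 5
out = old.new_empty((new_rows,) + old.shape[1:])
out[:3] = old
out[3:].normal_(mean=0.0, std=0.02)  # freshly initialized rows for the new tokens

assert torch.equal(out[:3], old)     # pretrained rows survive the resize
print(out.shape)                     # torch.Size([5, 4])
```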
requirements.txt CHANGED
@@ -1,4 +1,3 @@
-chatterbox-tts==0.1.4
 transformers==4.53.0
 librosa==0.11.0
 accelerate==1.8.1