# chatterbox-tts-dhivehi / chatterbox_dhivehi.py
# Page header residue (kept as a comment): alakxender's picture | t | 71b9145
# chatterbox_dhivehi.py
"""
Dhivehi extension for ChatterboxTTS.
Requires: chatterbox-tts 0.1.4 (not tested on any other version)
Adds:
- load_t3_with_vocab(state_dict, device, force_vocab_size): load T3 with a specific vocab size,
resizing both the embedding and the projection head, and padding (or trimming) checkpoint
weights so they match the target size.
- from_dhivehi(...): classmethod for building a ChatterboxTTS from a checkpoint directory,
using load_t3_with_vocab under the hood (defaults to vocab=2000).
- extend_dhivehi(): attach the above to ChatterboxTTS (idempotent).
Usage in app.py:
import chatterbox_dhivehi
chatterbox_dhivehi.extend_dhivehi()
self.model = ChatterboxTTS.from_dhivehi(
ckpt_dir=Path(self.checkpoint),
device="cuda" if torch.cuda.is_available() else "cpu",
force_vocab_size=2000,
)
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Optional, Union
import torch
import torch.nn as nn
from safetensors.torch import load_file
# Core chatterbox imports
from chatterbox.tts import ChatterboxTTS, Conditionals
from chatterbox.models.t3 import T3
from chatterbox.models.s3gen import S3Gen
from chatterbox.models.tokenizers import EnTokenizer
from chatterbox.models.voice_encoder import VoiceEncoder
# Helpers
def _expand_or_trim_rows(t: torch.Tensor, new_rows: int, init_std: float = 0.02) -> torch.Tensor:
    """
    Resize the leading dimension of ``t`` to ``new_rows`` and return a fresh tensor.

    - Same size or smaller: a copy of the first ``new_rows`` rows.
    - Larger: ``t`` followed by ``new_rows - t.shape[0]`` extra rows.
      The extra rows are sampled from a normal distribution with mean 0 and
      standard deviation ``init_std``, drawn from torch's default RNG.
      They have the same dtype and device as ``t``.

    The returned tensor never shares storage with ``t``.
    """
    have = t.shape[0]
    if new_rows <= have:
        # Covers both "unchanged" (full-length slice) and "trim"; clone() detaches
        # the result from t's storage.
        return t[:new_rows].clone()

    missing = new_rows - have
    fresh = t.new_empty((missing, *t.shape[1:]))
    fresh.normal_(mean=0.0, std=init_std)
    return torch.cat((t, fresh), dim=0)
def _prepare_resized_state_dict(sd: dict, new_vocab: int, init_std: float = 0.02) -> dict:
    """
    Return a shallow copy of ``sd`` with the text-vocab tensors resized to ``new_vocab`` rows.

    Affected keys (each only if present):
      - ``text_emb.weight``  (embedding table, rows = vocab)
      - ``text_head.weight`` (projection head, rows = vocab)
      - ``text_head.bias``   (only when 1-D; treated as a single column for resizing)

    Every other entry is carried over unchanged; the input mapping is not modified.
    """
    resized = sd.copy()

    # Weight matrices: resize along dim 0, in a fixed order (emb first, then head).
    for key in ("text_emb.weight", "text_head.weight"):
        if key in resized:
            resized[key] = _expand_or_trim_rows(resized[key], new_vocab, init_std)

    if "text_head.bias" in resized:
        bias = resized["text_head.bias"]
        if bias.ndim == 1:
            # Lift to a (vocab, 1) column so the row helper applies, then drop it again.
            as_column = bias.unsqueeze(1)
            resized["text_head.bias"] = _expand_or_trim_rows(as_column, new_vocab, init_std).squeeze(1)

    return resized
def _resize_model_vocab_layers(model: T3, new_vocab: int, dim: Optional[int] = None) -> None:
    """
    Rebuild ``model.text_emb`` and ``model.text_head`` to match ``new_vocab``.

    The replacement layers mirror the layers they replace where possible.

    Head bias: the head keeps its bias setting. Rebuilding a bias-free head
    with ``bias=True`` would add a bias parameter that a checkpoint without
    ``text_head.bias`` can never overwrite when loaded with ``strict=False``.
    That would leave random logit offsets in the model. Only when there is no
    existing ``nn.Linear`` head does this fall back to ``bias=True``.

    Dtype/device: each new layer keeps the dtype and device of the layer it
    replaces, when one exists.

    The new layers get PyTorch's default initialisation. Callers are expected
    to load real weights afterwards.

    Args:
        model: module exposing ``text_emb`` (nn.Embedding) and/or ``text_head`` (nn.Linear).
        new_vocab: number of embedding rows / head output features.
        dim: embedding width; inferred from the existing layers when None.

    Raises:
        RuntimeError: if ``dim`` is None and cannot be inferred from the model.
    """
    old_emb = getattr(model, "text_emb", None)
    old_emb = old_emb if isinstance(old_emb, nn.Embedding) else None
    old_head = getattr(model, "text_head", None)
    old_head = old_head if isinstance(old_head, nn.Linear) else None

    if dim is None:
        if old_emb is not None:
            dim = old_emb.embedding_dim
        elif old_head is not None:
            dim = old_head.in_features
        else:
            raise RuntimeError("Cannot infer text embedding dimension from T3 model.")

    emb_factory = {}
    if old_emb is not None:
        emb_factory = {"device": old_emb.weight.device, "dtype": old_emb.weight.dtype}

    head_factory = {}
    use_bias = True  # previous behaviour when there is no head to copy from
    if old_head is not None:
        head_factory = {"device": old_head.weight.device, "dtype": old_head.weight.dtype}
        use_bias = old_head.bias is not None

    model.text_emb = nn.Embedding(new_vocab, dim, **emb_factory)
    model.text_head = nn.Linear(dim, new_vocab, bias=use_bias, **head_factory)
# Public api
def load_t3_with_vocab(
    t3_state_dict: dict,
    device: str = "cpu",
    *,
    force_vocab_size: Optional[int] = None,
    init_std: float = 0.02,
) -> T3:
    """
    Load a T3 model with a specified text vocabulary size.

    - Removes a leading "t3." prefix from the state_dict keys that carry it.
      Keys without the prefix are left untouched.
    - Resizes BOTH `text_emb` and `text_head` to `force_vocab_size`.
      Without a forced size, they are resized to the checkpoint's own vocab size.
    - Pads checkpoint weights when the target vocab is larger than the checkpoint's,
      and trims them when it is smaller.
    - Loads non-strictly, and logs any missing/unexpected keys as warnings.

    Args:
        t3_state_dict: state dict loaded from t3_cfg.safetensors (or similar).
        device: "cpu", "cuda", or "mps".
        force_vocab_size: desired vocab size (e.g., 2000 for Dhivehi-extended models).
        init_std: std for random init of padded rows.

    Returns:
        T3: model moved to `device` and set to eval().

    Raises:
        RuntimeError: if no vocab size can be read from the checkpoint and none is forced.
    """
    logger = logging.getLogger(__name__)

    # Strip "t3." only from keys that actually start with it. Slicing every key
    # once any key is prefixed would corrupt unprefixed keys (e.g. "tfmr.x" -> "r.x").
    prefix = "t3."
    if any(k.startswith(prefix) for k in t3_state_dict):
        t3_state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in t3_state_dict.items()
        }

    # Derive the checkpoint vocab from the first 2-D vocab-indexed weight found.
    ckpt_vocab_size = None
    for key in ("text_emb.weight", "text_head.weight"):
        weight = t3_state_dict.get(key)
        if weight is not None and weight.ndim == 2:
            ckpt_vocab_size = int(weight.shape[0])
            break

    target_vocab = int(force_vocab_size) if force_vocab_size is not None else ckpt_vocab_size
    if target_vocab is None:
        raise RuntimeError("Could not determine vocab size. Provide force_vocab_size.")
    logger.info("Loading T3 with vocab=%s (ckpt_vocab=%s)", target_vocab, ckpt_vocab_size)

    # Build a base model and resize layers to accept the incoming state dict
    t3 = T3()
    _resize_model_vocab_layers(t3, target_vocab)

    # Patch the checkpoint tensors to the target vocab
    patched_sd = _prepare_resized_state_dict(t3_state_dict, target_vocab, init_std)

    # strict=False tolerates benign extra/missing keys. Report them, though, so a
    # mismatched checkpoint cannot silently leave layers at random init.
    result = t3.load_state_dict(patched_sd, strict=False)
    if result.missing_keys:
        logger.warning(
            "T3 load: %d missing key(s): %s", len(result.missing_keys), result.missing_keys
        )
    if result.unexpected_keys:
        logger.warning(
            "T3 load: %d unexpected key(s): %s", len(result.unexpected_keys), result.unexpected_keys
        )
    return t3.to(device).eval()
def from_dhivehi(
    cls,
    *,
    ckpt_dir: Union[str, Path],
    device: str = "cpu",
    force_vocab_size: int = 2000,
):
    """
    Build a Dhivehi-extended ChatterboxTTS from the files in a checkpoint directory.

    Files read from `ckpt_dir`:
      - ve.safetensors      -> VoiceEncoder weights (strict load)
      - t3_cfg.safetensors  -> T3 weights, via load_t3_with_vocab
      - s3gen.safetensors   -> S3Gen weights (strict=False)
      - tokenizer.json      -> EnTokenizer
      - conds.pt            -> Conditionals; optional, skipped when the file is absent

    Args:
        cls: class to instantiate (ChatterboxTTS once attached by extend_dhivehi).
        ckpt_dir: checkpoint directory.
        device: target device for all sub-models and conditionals.
        force_vocab_size: text vocab size passed through to load_t3_with_vocab.

    Returns:
        cls(t3, s3gen, ve, tokenizer, device, conds=conds)
    """
    root = Path(ckpt_dir)

    def _weights(filename):
        # safetensors state dict for a file inside the checkpoint directory
        return load_file(root / filename)

    # Voice encoder
    voice_encoder = VoiceEncoder()
    voice_encoder.load_state_dict(_weights("ve.safetensors"))
    voice_encoder.to(device).eval()

    # T3 with the resized (Dhivehi) text vocabulary
    t3 = load_t3_with_vocab(
        _weights("t3_cfg.safetensors"),
        device=device,
        force_vocab_size=force_vocab_size,
    )

    # S3Gen
    s3gen = S3Gen()
    s3gen.load_state_dict(_weights("s3gen.safetensors"), strict=False)
    s3gen.to(device).eval()

    tokenizer = EnTokenizer(str(root / "tokenizer.json"))

    # Deserialize the optional conditionals onto CPU first, then move them to `device`.
    conds_file = root / "conds.pt"
    conds = (
        Conditionals.load(conds_file, map_location="cpu").to(device)
        if conds_file.exists()
        else None
    )

    return cls(t3, s3gen, voice_encoder, tokenizer, device, conds=conds)
def extend_dhivehi():
    """
    Install the Dhivehi helpers on ChatterboxTTS; calling it again is a no-op.

    On the first call this attaches two helpers and sets the class flag
    `_dhivehi_extended` to True:
      - ChatterboxTTS.load_t3_with_vocab (staticmethod)
      - ChatterboxTTS.from_dhivehi (classmethod)
    Later calls see the flag and return without changing anything.
    """
    target = ChatterboxTTS
    already_patched = getattr(target, "_dhivehi_extended", False)
    if not already_patched:
        target.load_t3_with_vocab = staticmethod(load_t3_with_vocab)
        target.from_dhivehi = classmethod(from_dhivehi)
        target._dhivehi_extended = True