# chatterbox_dhivehi.py
"""
Dhivehi extension for ChatterboxTTS.

Requires: chatterbox-tts 0.1.4 (not tested on any other version).

Adds:
  - load_t3_with_vocab(state_dict, device, force_vocab_size): load T3 with a
    specific vocab size, resizing both the embedding and the projection head,
    and padding checkpoint weights if needed.
  - from_dhivehi(...): classmethod for building a ChatterboxTTS from a
    checkpoint directory, using load_t3_with_vocab under the hood
    (defaults to vocab=2000).
  - extend_dhivehi(): attach the above to ChatterboxTTS (idempotent).

Usage in app.py:

    import chatterbox_dhivehi
    chatterbox_dhivehi.extend_dhivehi()

    self.model = ChatterboxTTS.from_dhivehi(
        ckpt_dir=Path(self.checkpoint),
        device="cuda" if torch.cuda.is_available() else "cpu",
        force_vocab_size=2000,
    )
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Optional, Union

import torch
import torch.nn as nn
from safetensors.torch import load_file

# Core chatterbox imports
from chatterbox.tts import ChatterboxTTS, Conditionals
from chatterbox.models.t3 import T3
from chatterbox.models.s3gen import S3Gen
from chatterbox.models.tokenizers import EnTokenizer
from chatterbox.models.voice_encoder import VoiceEncoder


# Helpers

def _expand_or_trim_rows(t: torch.Tensor, new_rows: int, init_std: float = 0.02) -> torch.Tensor:
    """
    Return a tensor with its first dimension resized to `new_rows`.

    If expanding, newly added rows are randomly initialized from N(0, init_std);
    if shrinking, trailing rows are dropped.
    """
    old_rows = t.shape[0]
    if new_rows == old_rows:
        return t.clone()
    if new_rows < old_rows:
        return t[:new_rows].clone()
    # Expand: copy the old rows, randomly initialize the new ones in place.
    out = t.new_empty((new_rows,) + t.shape[1:])
    out[:old_rows] = t
    out[old_rows:].normal_(mean=0.0, std=init_std)
    return out
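
# A minimal sketch of the resize behavior; the shapes below are illustrative
# only, not read from a real checkpoint:
#
#   w = torch.zeros(704, 1024)              # pretend text_emb.weight
#   _expand_or_trim_rows(w, 2000).shape     # torch.Size([2000, 1024]); rows 704+ ~ N(0, 0.02)
#   _expand_or_trim_rows(w, 500).shape      # torch.Size([500, 1024]); trailing rows dropped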


def _prepare_resized_state_dict(sd: dict, new_vocab: int, init_std: float = 0.02) -> dict:
    """
    Create a modified copy of `sd` in which the text_emb/text_head weights
    (and bias) match `new_vocab`.
    """
    sd = sd.copy()
    # Text embedding: [vocab, dim]
    if "text_emb.weight" in sd:
        sd["text_emb.weight"] = _expand_or_trim_rows(sd["text_emb.weight"], new_vocab, init_std)
    # Text projection head: Linear(out=vocab, in=dim)
    if "text_head.weight" in sd:
        sd["text_head.weight"] = _expand_or_trim_rows(sd["text_head.weight"], new_vocab, init_std)
    if "text_head.bias" in sd:
        bias = sd["text_head.bias"]
        if bias.ndim == 1:
            # Resize the 1-D bias as a column vector, then flatten it back.
            sd["text_head.bias"] = _expand_or_trim_rows(bias.unsqueeze(1), new_vocab, init_std).squeeze(1)
    return sd
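
# Illustrative sketch (keys mirror the T3 layout this module targets; the
# shapes are made up):
#
#   sd = {
#       "text_emb.weight": torch.zeros(704, 1024),
#       "text_head.weight": torch.zeros(704, 1024),
#       "text_head.bias": torch.zeros(704),
#   }
#   patched = _prepare_resized_state_dict(sd, 2000)
#   patched["text_head.bias"].shape         # torch.Size([2000])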


def _resize_model_vocab_layers(model: T3, new_vocab: int, dim: Optional[int] = None) -> None:
    """
    Rebuild model.text_emb and model.text_head to match `new_vocab`.

    The embedding dim is inferred from the existing layers if not provided.
    """
    if dim is None:
        if hasattr(model, "text_emb") and isinstance(model.text_emb, nn.Embedding):
            dim = model.text_emb.embedding_dim
        elif hasattr(model, "text_head") and isinstance(model.text_head, nn.Linear):
            dim = model.text_head.in_features
        else:
            raise RuntimeError("Cannot infer text embedding dimension from T3 model.")
    model.text_emb = nn.Embedding(new_vocab, dim)
    model.text_head = nn.Linear(dim, new_vocab, bias=True)
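
# Note: the rebuilt layers start from PyTorch's default init, but both are
# fully overwritten by load_state_dict() in load_t3_with_vocab(), so the only
# random rows that survive are the N(0, init_std) padding added by
# _prepare_resized_state_dict().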


# Public API

def load_t3_with_vocab(
    t3_state_dict: dict,
    device: str = "cpu",
    *,
    force_vocab_size: Optional[int] = None,
    init_std: float = 0.02,
) -> T3:
    """
    Load a T3 model with a specified vocabulary size.

    - Removes a leading "t3." prefix on state_dict keys if present.
    - Resizes BOTH `text_emb` and `text_head` to `force_vocab_size`
      (or to the checkpoint vocab if not forced).
    - Pads checkpoint weights when the target vocab is larger than the checkpoint's.

    Args:
        t3_state_dict: state dict loaded from t3_cfg.safetensors (or similar).
        device: "cpu", "cuda", or "mps".
        force_vocab_size: desired vocab size (e.g., 2000 for Dhivehi-extended models).
        init_std: std for random init of padded rows.

    Returns:
        T3: model moved to `device` and set to eval().
    """
    logger = logging.getLogger(__name__)

    # Strip the "t3." prefix if present.
    if any(k.startswith("t3.") for k in t3_state_dict):
        t3_state_dict = {k[len("t3."):]: v for k, v in t3_state_dict.items()}

    # Derive the checkpoint vocab if available.
    ckpt_vocab_size = None
    if "text_emb.weight" in t3_state_dict and t3_state_dict["text_emb.weight"].ndim == 2:
        ckpt_vocab_size = int(t3_state_dict["text_emb.weight"].shape[0])
    elif "text_head.weight" in t3_state_dict and t3_state_dict["text_head.weight"].ndim == 2:
        ckpt_vocab_size = int(t3_state_dict["text_head.weight"].shape[0])

    target_vocab = int(force_vocab_size) if force_vocab_size is not None else ckpt_vocab_size
    if target_vocab is None:
        raise RuntimeError("Could not determine vocab size. Provide force_vocab_size.")
    logger.info(f"Loading T3 with vocab={target_vocab} (ckpt_vocab={ckpt_vocab_size})")

    # Build a base model and resize its layers to accept the incoming state dict.
    t3 = T3()
    _resize_model_vocab_layers(t3, target_vocab)

    # Patch the checkpoint tensors to the target vocab.
    patched_sd = _prepare_resized_state_dict(t3_state_dict, target_vocab, init_std)

    # Load with strict=False to tolerate benign extra/missing keys.
    t3.load_state_dict(patched_sd, strict=False)
    return t3.to(device).eval()
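
# Standalone usage sketch (the path is a placeholder; assumes a checkpoint
# whose tokenizer was extended to 2000 text tokens):
#
#   sd = load_file("checkpoints/dv/t3_cfg.safetensors")
#   t3 = load_t3_with_vocab(sd, device="cpu", force_vocab_size=2000)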


def from_dhivehi(
    cls,
    *,
    ckpt_dir: Union[str, Path],
    device: str = "cpu",
    force_vocab_size: int = 2000,
):
    """
    Construct a Dhivehi-extended ChatterboxTTS from a checkpoint directory.

    Expected files in `ckpt_dir`:
      - ve.safetensors
      - t3_cfg.safetensors
      - s3gen.safetensors
      - tokenizer.json
      - conds.pt (optional)
    """
    ckpt_dir = Path(ckpt_dir)

    # Voice encoder
    ve = VoiceEncoder()
    ve.load_state_dict(load_file(ckpt_dir / "ve.safetensors"))
    ve.to(device).eval()

    # T3 with the Dhivehi vocab extension
    t3_state = load_file(ckpt_dir / "t3_cfg.safetensors")
    t3 = load_t3_with_vocab(t3_state, device=device, force_vocab_size=force_vocab_size)

    # S3Gen
    s3gen = S3Gen()
    s3gen.load_state_dict(load_file(ckpt_dir / "s3gen.safetensors"), strict=False)
    s3gen.to(device).eval()

    # Tokenizer
    tokenizer = EnTokenizer(str(ckpt_dir / "tokenizer.json"))

    # Optional conditionals
    conds = None
    conds_path = ckpt_dir / "conds.pt"
    if conds_path.exists():
        # Always safe-load to CPU first, then move to `device`.
        conds = Conditionals.load(conds_path, map_location="cpu").to(device)

    return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
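
# Because from_dhivehi() takes an explicit `cls` first argument, it can also be
# called directly, before extend_dhivehi() has been applied. A sketch with a
# placeholder path:
#
#   model = from_dhivehi(ChatterboxTTS, ckpt_dir="checkpoints/dv", device="cpu")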


def extend_dhivehi():
    """
    Attach Dhivehi-specific helpers to ChatterboxTTS (idempotent).

    - ChatterboxTTS.load_t3_with_vocab (staticmethod)
    - ChatterboxTTS.from_dhivehi (classmethod)
    """
    if getattr(ChatterboxTTS, "_dhivehi_extended", False):
        return
    ChatterboxTTS.load_t3_with_vocab = staticmethod(load_t3_with_vocab)
    ChatterboxTTS.from_dhivehi = classmethod(from_dhivehi)
    ChatterboxTTS._dhivehi_extended = True
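

if __name__ == "__main__":
    # Smoke-test sketch: needs a real checkpoint directory laid out as in
    # from_dhivehi()'s docstring; "checkpoints/dhivehi" is a placeholder path.
    extend_dhivehi()
    tts = ChatterboxTTS.from_dhivehi(
        ckpt_dir="checkpoints/dhivehi",
        device="cuda" if torch.cuda.is_available() else "cpu",
        force_vocab_size=2000,
    )
    print(type(tts).__name__, "loaded")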