# chatterbox_dhivehi.py
"""
Dhivehi extension for ChatterboxTTS.
Requires: chatterbox-tts 0.1.4 (not tested on any other version)
Adds:
- load_t3_with_vocab(state_dict, device, force_vocab_size): load T3 with a specific vocab size,
resizing both the embedding and the projection head, and padding checkpoint weights if needed.
- from_dhivehi(...): classmethod for building a ChatterboxTTS from a checkpoint directory,
using load_t3_with_vocab under the hood (defaults to vocab=2000).
- extend_dhivehi(): attach the above to ChatterboxTTS (idempotent).
Usage in app.py:
import chatterbox_dhivehi
chatterbox_dhivehi.extend_dhivehi()
self.model = ChatterboxTTS.from_dhivehi(
ckpt_dir=Path(self.checkpoint),
device="cuda" if torch.cuda.is_available() else "cpu",
force_vocab_size=2000,
)
"""
from __future__ import annotations

import logging
from pathlib import Path
from typing import Optional, Union

import torch
import torch.nn as nn
from safetensors.torch import load_file

# Core chatterbox imports
from chatterbox.tts import ChatterboxTTS, Conditionals
from chatterbox.models.t3 import T3
from chatterbox.models.s3gen import S3Gen
from chatterbox.models.tokenizers import EnTokenizer
from chatterbox.models.voice_encoder import VoiceEncoder


# Helpers

def _expand_or_trim_rows(t: torch.Tensor, new_rows: int, init_std: float = 0.02) -> torch.Tensor:
    """
    Return a tensor with its first dimension resized to `new_rows`.

    If expanding, the newly added rows are randomly initialized from
    N(0, init_std); if trimming, trailing rows are dropped.
    """
    old_rows = t.shape[0]
    if new_rows == old_rows:
        return t.clone()
    if new_rows < old_rows:
        return t[:new_rows].clone()
    # Expand: copy the existing rows, then randomly initialize the new tail.
    out = t.new_empty((new_rows,) + t.shape[1:])
    out[:old_rows] = t
    out[old_rows:].normal_(mean=0.0, std=init_std)
    return out
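
# Illustrative sketch (kept as a comment so importing this module stays
# side-effect free); the shapes here are arbitrary example values:
#
#     w = torch.randn(3, 8)
#     w5 = _expand_or_trim_rows(w, 5)   # (5, 8): rows 0-2 copied, rows 3-4 ~ N(0, 0.02)
#     w2 = _expand_or_trim_rows(w, 2)   # (2, 8): last row dropped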


def _prepare_resized_state_dict(sd: dict, new_vocab: int, init_std: float = 0.02) -> dict:
    """
    Create a modified (shallow) copy of `sd` in which the text_emb/text_head
    weights (and bias) match `new_vocab`.
    """
    sd = sd.copy()
    # Text embedding: [vocab, dim]
    if "text_emb.weight" in sd:
        sd["text_emb.weight"] = _expand_or_trim_rows(sd["text_emb.weight"], new_vocab, init_std)
    # Text projection head: Linear(out=vocab, in=dim), so rows are vocab entries.
    if "text_head.weight" in sd:
        sd["text_head.weight"] = _expand_or_trim_rows(sd["text_head.weight"], new_vocab, init_std)
    if "text_head.bias" in sd:
        bias = sd["text_head.bias"]
        if bias.ndim == 1:
            # Resize the 1-D bias by treating it as a [vocab, 1] column vector.
            sd["text_head.bias"] = _expand_or_trim_rows(bias.unsqueeze(1), new_vocab, init_std).squeeze(1)
    return sd
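
# Shape sketch (V stands for whatever vocab size the checkpoint ships with;
# 2000 is the Dhivehi default used further below):
#
#     sd["text_emb.weight"]   # [V, dim] -> [2000, dim]
#     sd["text_head.weight"]  # [V, dim] -> [2000, dim]
#     sd["text_head.bias"]    # [V]      -> [2000]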


def _resize_model_vocab_layers(model: T3, new_vocab: int, dim: Optional[int] = None) -> None:
    """
    Rebuild model.text_emb and model.text_head to match `new_vocab`.

    The embedding dimension is inferred from the existing layers if not provided.
    """
    if dim is None:
        if hasattr(model, "text_emb") and isinstance(model.text_emb, nn.Embedding):
            dim = model.text_emb.embedding_dim
        elif hasattr(model, "text_head") and isinstance(model.text_head, nn.Linear):
            dim = model.text_head.in_features
        else:
            raise RuntimeError("Cannot infer text embedding dimension from T3 model.")
    model.text_emb = nn.Embedding(new_vocab, dim)
    model.text_head = nn.Linear(dim, new_vocab, bias=True)
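
# Note: both layers must be rebuilt to the same vocab size before loading.
# load_state_dict() raises on parameter shape mismatches even with
# strict=False, so the module shapes have to match the patched checkpoint.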


# Public API

def load_t3_with_vocab(
    t3_state_dict: dict,
    device: str = "cpu",
    *,
    force_vocab_size: Optional[int] = None,
    init_std: float = 0.02,
) -> T3:
    """
    Load a T3 model with a specified vocabulary size.

    - Removes a leading "t3." prefix on state_dict keys if present.
    - Resizes BOTH `text_emb` and `text_head` to `force_vocab_size`
      (or to the checkpoint vocab if not forced).
    - Pads checkpoint weights when the target vocab is larger than the checkpoint's.

    Args:
        t3_state_dict: state dict loaded from t3_cfg.safetensors (or similar).
        device: "cpu", "cuda", or "mps".
        force_vocab_size: desired vocab size (e.g., 2000 for Dhivehi-extended models).
        init_std: std for the random init of padded rows.

    Returns:
        T3: model moved to `device` and set to eval().
    """
    logger = logging.getLogger(__name__)

    # Strip a leading "t3." prefix if present.
    if any(k.startswith("t3.") for k in t3_state_dict):
        t3_state_dict = {k[len("t3."):]: v for k, v in t3_state_dict.items()}

    # Derive the checkpoint vocab size if available.
    ckpt_vocab_size = None
    if "text_emb.weight" in t3_state_dict and t3_state_dict["text_emb.weight"].ndim == 2:
        ckpt_vocab_size = int(t3_state_dict["text_emb.weight"].shape[0])
    elif "text_head.weight" in t3_state_dict and t3_state_dict["text_head.weight"].ndim == 2:
        ckpt_vocab_size = int(t3_state_dict["text_head.weight"].shape[0])

    target_vocab = int(force_vocab_size) if force_vocab_size is not None else ckpt_vocab_size
    if target_vocab is None:
        raise RuntimeError("Could not determine vocab size. Provide force_vocab_size.")
    logger.info(f"Loading T3 with vocab={target_vocab} (ckpt_vocab={ckpt_vocab_size})")

    # Build a base model and resize its layers to accept the incoming state dict.
    t3 = T3()
    _resize_model_vocab_layers(t3, target_vocab)

    # Patch the checkpoint tensors to the target vocab.
    patched_sd = _prepare_resized_state_dict(t3_state_dict, target_vocab, init_std)

    # strict=False tolerates benign extra/missing keys.
    t3.load_state_dict(patched_sd, strict=False)
    return t3.to(device).eval()
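
# Standalone usage sketch ("ckpt_dir" is a placeholder directory, not a path
# shipped with this module):
#
#     state = load_file(Path("ckpt_dir") / "t3_cfg.safetensors")
#     t3 = load_t3_with_vocab(state, device="cpu", force_vocab_size=2000)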


def from_dhivehi(
    cls,
    *,
    ckpt_dir: Union[str, Path],
    device: str = "cpu",
    force_vocab_size: int = 2000,
):
    """
    Construct a Dhivehi-extended ChatterboxTTS from a checkpoint directory.

    Expected files in `ckpt_dir`:
        - ve.safetensors
        - t3_cfg.safetensors
        - s3gen.safetensors
        - tokenizer.json
        - conds.pt (optional)
    """
    ckpt_dir = Path(ckpt_dir)

    # Voice encoder
    ve = VoiceEncoder()
    ve.load_state_dict(load_file(ckpt_dir / "ve.safetensors"))
    ve.to(device).eval()

    # T3 with the Dhivehi vocab extension
    t3_state = load_file(ckpt_dir / "t3_cfg.safetensors")
    t3 = load_t3_with_vocab(t3_state, device=device, force_vocab_size=force_vocab_size)

    # S3Gen
    s3gen = S3Gen()
    s3gen.load_state_dict(load_file(ckpt_dir / "s3gen.safetensors"), strict=False)
    s3gen.to(device).eval()

    # Tokenizer
    tokenizer = EnTokenizer(str(ckpt_dir / "tokenizer.json"))

    # Optional conditionals
    conds = None
    conds_path = ckpt_dir / "conds.pt"
    if conds_path.exists():
        # Always safe-load to CPU first, then move to the target device.
        conds = Conditionals.load(conds_path, map_location="cpu").to(device)

    return cls(t3, s3gen, ve, tokenizer, device, conds=conds)
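
# Note: from_dhivehi takes `cls` as its first parameter even though it is
# defined at module level; extend_dhivehi() below wraps it in classmethod()
# so it binds correctly once attached to ChatterboxTTS.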


def extend_dhivehi():
    """
    Attach the Dhivehi-specific helpers to ChatterboxTTS (idempotent):
      - ChatterboxTTS.load_t3_with_vocab (staticmethod)
      - ChatterboxTTS.from_dhivehi (classmethod)
    """
    if getattr(ChatterboxTTS, "_dhivehi_extended", False):
        return
    ChatterboxTTS.load_t3_with_vocab = staticmethod(load_t3_with_vocab)
    ChatterboxTTS.from_dhivehi = classmethod(from_dhivehi)
    ChatterboxTTS._dhivehi_extended = True
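

if __name__ == "__main__":
    # Minimal smoke test; a sketch only. "checkpoints/dhivehi" is a
    # placeholder path (an assumption, not shipped with this module).
    extend_dhivehi()
    model = ChatterboxTTS.from_dhivehi(
        ckpt_dir=Path("checkpoints/dhivehi"),
        device="cuda" if torch.cuda.is_available() else "cpu",
        force_vocab_size=2000,
    )
    print("Loaded", type(model).__name__)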