# chatterbox_dhivehi.py
"""
Dhivehi extension for ChatterboxTTS. 

Requires: chatterbox-tts 0.1.4 (not tested on any other version)

Adds:
  - load_t3_with_vocab(state_dict, device, force_vocab_size): load T3 with a specific vocab size,
    resizing both the embedding and the projection head, and padding checkpoint weights if needed.
  - from_dhivehi(...): classmethod for building a ChatterboxTTS from a checkpoint directory,
    using load_t3_with_vocab under the hood (defaults to vocab=2000).
  - extend_dhivehi(): attach the above to ChatterboxTTS (idempotent).

Usage in app.py:
    import chatterbox_dhivehi
    chatterbox_dhivehi.extend_dhivehi()

    self.model = ChatterboxTTS.from_dhivehi(
        ckpt_dir=Path(self.checkpoint),
        device="cuda" if torch.cuda.is_available() else "cpu",
        force_vocab_size=2000,
    )
"""

from __future__ import annotations
import logging
from pathlib import Path
from typing import Optional, Union

import torch
import torch.nn as nn
from safetensors.torch import load_file

# Core chatterbox imports
from chatterbox.tts import ChatterboxTTS, Conditionals
from chatterbox.models.t3 import T3
from chatterbox.models.s3gen import S3Gen
from chatterbox.models.tokenizers import EnTokenizer
from chatterbox.models.voice_encoder import VoiceEncoder


# Helpers

def _expand_or_trim_rows(t: torch.Tensor, new_rows: int, init_std: float = 0.02) -> torch.Tensor:
    """
    Return a tensor with first dimension resized to `new_rows`.
    If expanding, newly added rows are randomly initialized N(0, init_std).
    """
    old_rows = t.shape[0]
    if new_rows == old_rows:
        return t.clone()
    if new_rows < old_rows:
        return t[:new_rows].clone()
    # expand
    out = t.new_empty((new_rows,) + t.shape[1:])
    out[:old_rows] = t
    out[old_rows:].normal_(mean=0.0, std=init_std)
    return out
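
# Shape illustration (comments only, so importing the module stays side-effect
# free; `w` is a hypothetical weight, not part of any checkpoint):
#
#     w = torch.zeros(3, 4)
#     _expand_or_trim_rows(w, 5).shape  # torch.Size([5, 4]); 2 new rows ~ N(0, 0.02)
#     _expand_or_trim_rows(w, 2).shape  # torch.Size([2, 4]); truncated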


def _prepare_resized_state_dict(sd: dict, new_vocab: int, init_std: float = 0.02) -> dict:
    """
    Create a modified copy of `sd` where text_emb/text_head weights (and bias) match `new_vocab`.
    """
    sd = sd.copy()

    # text embedding: [vocab, dim]
    if "text_emb.weight" in sd:
        sd["text_emb.weight"] = _expand_or_trim_rows(sd["text_emb.weight"], new_vocab, init_std)

    # text projection head: Linear(out=vocab, in=dim)
    if "text_head.weight" in sd:
        sd["text_head.weight"] = _expand_or_trim_rows(sd["text_head.weight"], new_vocab, init_std)
    if "text_head.bias" in sd:
        bias = sd["text_head.bias"]
        if bias.ndim == 1:
            sd["text_head.bias"] = _expand_or_trim_rows(bias.unsqueeze(1), new_vocab, init_std).squeeze(1)

    return sd
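
# Effect on a hypothetical checkpoint (shapes only; 704 and 1024 are assumed
# source dimensions for illustration, not values read from any real file):
#
#     sd = {"text_emb.weight": torch.zeros(704, 1024),
#           "text_head.weight": torch.zeros(704, 1024),
#           "text_head.bias": torch.zeros(704)}
#     patched = _prepare_resized_state_dict(sd, 2000)
#     patched["text_emb.weight"].shape  # torch.Size([2000, 1024])
#     patched["text_head.bias"].shape   # torch.Size([2000])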


def _resize_model_vocab_layers(model: T3, new_vocab: int, dim: Optional[int] = None) -> None:
    """
    Rebuild model.text_emb and model.text_head to match `new_vocab`.
    Embedding dim is inferred from existing layers if not provided.
    """
    if dim is None:
        if hasattr(model, "text_emb") and isinstance(model.text_emb, nn.Embedding):
            dim = model.text_emb.embedding_dim
        elif hasattr(model, "text_head") and isinstance(model.text_head, nn.Linear):
            dim = model.text_head.in_features
        else:
            raise RuntimeError("Cannot infer text embedding dimension from T3 model.")
    model.text_emb = nn.Embedding(new_vocab, dim)
    model.text_head = nn.Linear(dim, new_vocab, bias=True)


# Public API

def load_t3_with_vocab(
    t3_state_dict: dict,
    device: str = "cpu",
    *,
    force_vocab_size: Optional[int] = None,
    init_std: float = 0.02,
) -> T3:
    """
    Load a T3 model with a specified vocabulary size.

    - Removes a leading "t3." prefix on state_dict keys if present.
    - Resizes BOTH `text_emb` and `text_head` to `force_vocab_size` (or to the checkpoint vocab if not forced).
    - Pads checkpoint weights when the target vocab is larger than the checkpoint's.

    Args:
        t3_state_dict: state dict loaded from t3_cfg.safetensors (or similar).
        device: "cpu", "cuda", or "mps".
        force_vocab_size: desired vocab size (e.g., 2000 for Dhivehi-extended models).
        init_std: std for random init of padded rows.

    Returns:
        T3: model moved to `device` and set to eval().
    """
    logger = logging.getLogger(__name__)

    # Strip "t3." prefix if present
    if any(k.startswith("t3.") for k in t3_state_dict.keys()):
        t3_state_dict = {k[len("t3."):]: v for k, v in t3_state_dict.items()}

    # derive checkpoint vocab if available
    ckpt_vocab_size = None
    if "text_emb.weight" in t3_state_dict and t3_state_dict["text_emb.weight"].ndim == 2:
        ckpt_vocab_size = int(t3_state_dict["text_emb.weight"].shape[0])
    elif "text_head.weight" in t3_state_dict and t3_state_dict["text_head.weight"].ndim == 2:
        ckpt_vocab_size = int(t3_state_dict["text_head.weight"].shape[0])

    target_vocab = int(force_vocab_size) if force_vocab_size is not None else ckpt_vocab_size
    if target_vocab is None:
        raise RuntimeError("Could not determine vocab size. Provide force_vocab_size.")

    logger.info(f"Loading T3 with vocab={target_vocab} (ckpt_vocab={ckpt_vocab_size})")

    # Build a base model and resize layers to accept the incoming state dict
    t3 = T3()
    _resize_model_vocab_layers(t3, target_vocab)

    # Patch the checkpoint tensors to the target vocab
    patched_sd = _prepare_resized_state_dict(t3_state_dict, target_vocab, init_std)

    # strict=False tolerates benign extra/missing keys; log them for debugging
    missing, unexpected = t3.load_state_dict(patched_sd, strict=False)
    if missing or unexpected:
        logger.debug(f"T3 load_state_dict: missing={missing}, unexpected={unexpected}")
    return t3.to(device).eval()
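
# Standalone usage (a sketch; the checkpoint path is an assumption):
#
#     from safetensors.torch import load_file
#     sd = load_file("checkpoints/dhivehi/t3_cfg.safetensors")
#     t3 = load_t3_with_vocab(sd, device="cpu", force_vocab_size=2000)
#     assert t3.text_emb.num_embeddings == 2000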


def from_dhivehi(
    cls,
    *,
    ckpt_dir: Union[str, Path],
    device: str = "cpu",
    force_vocab_size: int = 2000,
):
    """
    Construct a Dhivehi-extended ChatterboxTTS from a checkpoint directory.

    Expected files in `ckpt_dir`:
      - ve.safetensors
      - t3_cfg.safetensors
      - s3gen.safetensors
      - tokenizer.json
      - conds.pt (optional)
    """
    ckpt_dir = Path(ckpt_dir)

    # Voice encoder
    ve = VoiceEncoder()
    ve.load_state_dict(load_file(ckpt_dir / "ve.safetensors"))
    ve.to(device).eval()

    # T3 with Dhivehi vocab extension
    t3_state = load_file(ckpt_dir / "t3_cfg.safetensors")
    t3 = load_t3_with_vocab(t3_state, device=device, force_vocab_size=force_vocab_size)

    # S3Gen
    s3gen = S3Gen()
    s3gen.load_state_dict(load_file(ckpt_dir / "s3gen.safetensors"), strict=False)
    s3gen.to(device).eval()

    # Tokenizer
    tokenizer = EnTokenizer(str(ckpt_dir / "tokenizer.json"))

    # Optional conditionals
    conds = None
    conds_path = ckpt_dir / "conds.pt"
    if conds_path.exists():
        # Load onto CPU first, then move to the target device
        conds = Conditionals.load(conds_path, map_location="cpu").to(device)

    return cls(t3, s3gen, ve, tokenizer, device, conds=conds)


def extend_dhivehi():
    """
    Attach Dhivehi-specific helpers to ChatterboxTTS (idempotent).
    - ChatterboxTTS.load_t3_with_vocab (staticmethod)
    - ChatterboxTTS.from_dhivehi (classmethod)
    """
    if getattr(ChatterboxTTS, "_dhivehi_extended", False):
        return
    ChatterboxTTS.load_t3_with_vocab = staticmethod(load_t3_with_vocab)
    ChatterboxTTS.from_dhivehi = classmethod(from_dhivehi)
    ChatterboxTTS._dhivehi_extended = True
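

# Optional smoke test (a sketch, not part of the packaged module). The default
# checkpoint directory below is an assumption; point it at your own export
# containing ve.safetensors, t3_cfg.safetensors, s3gen.safetensors, and
# tokenizer.json (plus an optional conds.pt).
if __name__ == "__main__":
    import sys

    extend_dhivehi()
    ckpt = Path(sys.argv[1]) if len(sys.argv) > 1 else Path("checkpoints/dhivehi")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tts = ChatterboxTTS.from_dhivehi(ckpt_dir=ckpt, device=device, force_vocab_size=2000)
    # text_emb.num_embeddings should equal the forced vocab size (2000)
    print(f"Loaded on {device}; T3 vocab = {tts.t3.text_emb.num_embeddings}")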