File size: 10,090 Bytes
09eaf7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
# -*- coding: utf-8 -*-
"""
step041_tts_higgs.py
HIGGS/Boson TTS — simple, stable, accent-aware (per-line synthesis).

Env (.env):
  BOSON_API_KEY=...
  BOSON_BASE_URL=https://hackathon.boson.ai/v1
  HIGGS_TTS_MODEL=higgs-audio-generation-Hackathon
Optional:
  HIGGS_TTS_SPEED=1.0     # speaking rate hint (server may clamp/ignore)
  HIGGS_TTS_PAD_MS=8      # tiny pad at start/end (ms)
  HIGGS_TTS_ALLOW_MISMATCH=0  # if 0 and text looks ASCII-English but lang != en, force 'en'

Public API (dispatcher-compatible):
  init_TTS()
  load_model()
  tts(text, output_path, speaker_wav=None, *, voice_type=None, target_language=None)

Notes:
  - Speak EXACTLY the provided `text` (pipeline passes line['translation']).
  - Unified language codes expected: zh-cn, zh-tw, en, ko, ja, es, fr.
"""

from __future__ import annotations
import os, base64, wave, time, random
from typing import Optional, Dict

import numpy as np
from dotenv import load_dotenv
from loguru import logger
from openai import OpenAI

# ------------------------------- Config ---------------------------------------

SR = 24000
SAMPLE_WIDTH = 2  # 16-bit PCM
NCHANNELS = 1

_client: Optional[OpenAI] = None
_model_name: Optional[str] = None

# The knobs below are read once, at import time. load_model() also calls
# load_dotenv(), but only later, so without this call the HIGGS_TTS_* values
# documented above as .env settings would be ignored unless the caller had
# already exported them. load_dotenv() does not override variables that are
# already present in os.environ.
load_dotenv()


def _env_float(name: str, default: float) -> float:
    """Read env var *name* as a float; unset/blank -> *default*, unparsable -> warn + *default*."""
    raw = (os.getenv(name) or "").strip()
    if not raw:
        return default
    try:
        return float(raw)
    except ValueError:
        logger.warning(f"[HIGGS TTS] Ignoring invalid {name}={raw!r}; using {default}.")
        return default


def _env_int(name: str, default: int) -> int:
    """Read env var *name* as an int (accepts "8" and "8.0"); same fallback rules as _env_float."""
    raw = (os.getenv(name) or "").strip()
    if not raw:
        return default
    try:
        return int(float(raw))
    except (ValueError, OverflowError):
        logger.warning(f"[HIGGS TTS] Ignoring invalid {name}={raw!r}; using {default}.")
        return default


def _env_flag(name: str, default: bool = False) -> bool:
    """Read env var *name* as a boolean.

    Accepts integers (any non-zero value is true, as before) plus the words
    true/yes/on and false/no/off, case-insensitively. Unset/blank -> *default*;
    anything else is logged and also yields *default* instead of crashing the import.
    """
    raw = (os.getenv(name) or "").strip().lower()
    if not raw:
        return default
    if raw in ("true", "yes", "on"):
        return True
    if raw in ("false", "no", "off"):
        return False
    try:
        return bool(int(raw))
    except ValueError:
        logger.warning(f"[HIGGS TTS] Ignoring invalid {name}={raw!r}; using {default}.")
        return default


# env knobs
_HIGGS_SPEED = _env_float("HIGGS_TTS_SPEED", 1.0)
_PAD_MS      = _env_int("HIGGS_TTS_PAD_MS", 8)
_ALLOW_MISMATCH = _env_flag("HIGGS_TTS_ALLOW_MISMATCH", False)

# ------------------------ Unified language normalization -----------------------

# Accept labels OR codes -> return canonical code.
# Keys must already be in _canon() form (stripped + lower-cased), because lookups
# are done on the canonicalized input.
_LANG_ALIASES: Dict[str, str] = {
    # Simplified Chinese
    "zh-cn": "zh-cn", "zh_cn": "zh-cn", "cn": "zh-cn",
    "zh-hans": "zh-cn",  # BCP-47 script subtag for Simplified
    "chinese (中文)": "zh-cn", "chinese": "zh-cn", "中文": "zh-cn",
    "simplified chinese (简体中文)": "zh-cn", "simplified chinese": "zh-cn", "简体中文": "zh-cn",

    # Traditional Chinese (accept the label in both simplified and traditional script)
    "zh-tw": "zh-tw", "zh_tw": "zh-tw", "tw": "zh-tw",
    "zh-hant": "zh-tw",  # BCP-47 script subtag for Traditional
    "traditional chinese (繁体中文)": "zh-tw", "traditional chinese": "zh-tw", "繁体中文": "zh-tw",
    "traditional chinese (繁體中文)": "zh-tw", "繁體中文": "zh-tw",

    # English
    "en": "en", "english": "en",

    # Korean
    "ko": "ko", "korean": "ko", "한국어": "ko",

    # Japanese
    "ja": "ja", "japanese": "ja", "日本語": "ja",

    # Spanish
    "es": "es", "spanish": "es", "español": "es",

    # French
    "fr": "fr", "french": "fr", "français": "fr",
}

# Every canonical code this backend accepts; each alias above maps into this set.
_ALLOWED_LANGS = {"zh-cn", "zh-tw", "en", "ko", "ja", "es", "fr"}

# Accent defaults by language code (one entry per allowed code)
DEFAULT_REGION: Dict[str, str] = {
    "en": "US",
    "zh-cn": "China",
    "zh-tw": "Taiwan",
    "ja": "Japan",
    "ko": "Korea",
    "fr": "France",
    "es": "Spain",
}

# ---------------------------- Initialization ----------------------------------

def init_TTS():
    """Dispatcher hook: prepare the backend by delegating to load_model().

    load_model() returns early once the client exists, so calling this more
    than once is harmless.
    """
    return load_model()

def load_model():
    """Create the shared OpenAI-compatible client for the Boson endpoint, once.

    Loads .env, then reads BOSON_API_KEY, BOSON_BASE_URL and HIGGS_TTS_MODEL.
    Variables that are unset *or blank* fall back to the documented defaults.
    Later calls return immediately once the client exists.

    Raises:
        RuntimeError: if BOSON_API_KEY is missing or blank.
    """
    global _client, _model_name
    if _client is not None:
        return
    load_dotenv()
    # Use `or default` rather than getenv's default argument: a variable that is
    # present but empty (e.g. `BOSON_BASE_URL=` in .env) must still fall back
    # instead of producing "".
    api_key = (os.getenv("BOSON_API_KEY") or "").strip()
    base_url = (os.getenv("BOSON_BASE_URL") or "").strip() or "https://hackathon.boson.ai/v1"
    model_name = (
        (os.getenv("HIGGS_TTS_MODEL") or "").strip() or "higgs-audio-generation-Hackathon"
    )
    if not api_key:
        raise RuntimeError("BOSON_API_KEY is not set.")
    _client = OpenAI(api_key=api_key, base_url=base_url)
    # Publish the model name only once the client exists, so both globals are
    # either set together or left untouched.
    _model_name = model_name
    logger.info(f"[HIGGS TTS] Client ready | base={base_url} | model={_model_name}")

# ------------------------------ Helpers ---------------------------------------

def _canon(s: Optional[str]) -> str:
    """Return *s* as a stripped, lower-cased string; any falsy input gives ""."""
    if s:
        return str(s).strip().lower()
    return ""

def _norm_lang(s: Optional[str]) -> str:
    """Map a language label or code to one of the canonical codes in _ALLOWED_LANGS.

    Blank/None input falls back to "en".

    Raises:
        ValueError: if the input is non-blank but not a recognized label/code.
    """
    wanted = _canon(s)
    resolved = _LANG_ALIASES.get(wanted, wanted)
    if resolved in _ALLOWED_LANGS:
        return resolved
    if not resolved:
        # Nothing specified -> default to English.
        return "en"
    # Anything else is an upstream misconfiguration; fail loudly.
    raise ValueError(f"[HIGGS TTS] Unsupported language: {s} -> {resolved}")

def _looks_ascii_english(text: str) -> bool:
    """Heuristic: True when *text* is pure ASCII and contains at least one letter.

    Empty input, non-ASCII text, and ASCII made only of digits, punctuation or
    whitespace all return False.
    """
    if not text or not text.isascii():
        return False
    return any(ch.isalpha() for ch in text)

def _accent_from_voice_or_default(voice_type: Optional[str], lang_code: str) -> str:
    """Choose the accent region named in the system prompt.

    *voice_type* is accepted but not consulted yet; the region comes from
    DEFAULT_REGION, with "US" for any code not listed there. Adapt this if
    voice_type ever starts encoding a region.
    """
    region = DEFAULT_REGION.get(lang_code)
    return "US" if region is None else region

def _system_prompt(lang_code: str, region: str) -> str:
    """Build the system message that keeps the model on a read-aloud-only task.

    The language is given by its code (e.g. "zh-cn"); the server is expected to
    interpret it. The rules forbid translation or paraphrase, set pause lengths
    for commas and sentence ends, and ask for a consistent voice.
    """
    rules = (
        f"Speak ONLY in {lang_code} with a native accent from {region}.",
        "Read the user's text verbatim; do NOT translate, paraphrase, or add words.",
        "Timing rules: treat commas as ~120ms pauses and sentence endings as ~220ms pauses.",
        "Do NOT read tags or metadata aloud.",
        "Keep natural prosody and native pronunciation.",
        "Maintain a consistent timbre, pitch, and speaking style across the entire utterance.",
    )
    return " ".join(rules)

def _b64_file(path: str) -> Optional[str]:
    """Return the base64 text of the file at *path*, or None if it is not a readable regular file.

    Uses os.path.isfile rather than os.path.exists, so a directory (or other
    non-file) path yields None instead of raising when opened.
    """
    if not path or not os.path.isfile(path):
        return None
    with open(path, "rb") as f:
        # base64 output is pure ASCII.
        return base64.b64encode(f.read()).decode("ascii")

def _jittered_sleep(base: float, attempt: int):
    """Block for a linearly growing, randomly scaled back-off interval.

    The delay is ``base * (attempt + 1)`` multiplied by a factor drawn uniformly
    from [0.2, 0.6).
    """
    scale = 0.2 + 0.4 * random.random()
    delay = base * (attempt + 1) * scale
    time.sleep(delay)

# --------------------------- Streaming synthesis --------------------------------

def _delta_pcm(chunk) -> bytes:
    """Return the raw PCM bytes carried by one streamed chunk (b"" if it carries none).

    Chunks without choices, and audio deltas without a ``data`` field, are
    skipped. Previously they raised IndexError/KeyError, which aborted the
    attempt and forced a retry.
    """
    choices = getattr(chunk, "choices", None)
    if not choices:
        return b""
    delta = getattr(choices[0], "delta", None)
    audio = getattr(delta, "audio", None)
    if not audio:
        return b""
    # The audio delta may surface as a plain dict or as an attribute object.
    data = audio.get("data") if isinstance(audio, dict) else getattr(audio, "data", None)
    return base64.b64decode(data) if data else b""


def _request_pcm(messages: list, lang_code: str) -> bytes:
    """Issue one streaming chat-completions request and return all PCM it produced."""
    stream = _client.chat.completions.create(
        model=_model_name,
        messages=messages,
        modalities=["text", "audio"],
        audio={"format": "pcm16"},
        stream=True,
        extra_body={"language": lang_code, "speed": float(_HIGGS_SPEED)},
    )
    return b"".join(_delta_pcm(chunk) for chunk in stream)


def _write_wav_atomic(out_path: str, pcm: bytes) -> None:
    """Write *pcm* as a mono 16-bit WAV at SR via a temp file, then move it onto *out_path*."""
    tmp_path = out_path + ".part"
    try:
        with wave.open(tmp_path, "wb") as wf:
            wf.setnchannels(NCHANNELS)
            wf.setsampwidth(SAMPLE_WIDTH)
            wf.setframerate(SR)
            wf.writeframes(pcm)
        # os.replace is atomic on POSIX, so tts()'s "already exists" check never
        # sees a half-written file.
        os.replace(tmp_path, out_path)
    finally:
        if os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError:
                logger.warning(f"[HIGGS TTS] Could not remove temp file {tmp_path}")


def _stream_pcm16_to_wav(
    text: str,
    out_path: str,
    lang_code: str,
    region: str,
    ref_b64: Optional[str],
    max_retries: int = 3,
    backoff: float = 0.6,
):
    """Synthesize *text* through the streaming endpoint and save it as a WAV file.

    Each attempt's audio is collected in memory and the file is written only
    after an attempt completes. A stream that dies part-way therefore no longer
    leaves its partial audio in front of the retried rendition. A final failure
    leaves no truncated file behind, which tts() would otherwise skip as
    already done.

    Args:
        text: Exact text to speak.
        out_path: Destination WAV path; parent directories are created.
        lang_code: Canonical language code, sent in the prompt and in extra_body.
        region: Accent region named in the system prompt.
        ref_b64: Optional base64 WAV used as a timbre reference.
        max_retries: Extra attempts after the first failure.
        backoff: Base back-off in seconds (doubled for rate-limit errors).

    Raises:
        RuntimeError: if load_model() has not initialised the client.
        Exception: whatever the client raised on the last attempt once retries are exhausted.
    """
    if _client is None or _model_name is None:
        raise RuntimeError("[HIGGS TTS] Client not initialised; call load_model() first.")

    os.makedirs(os.path.dirname(os.path.abspath(out_path)), exist_ok=True)

    messages = [{"role": "system", "content": _system_prompt(lang_code, region)}]
    if ref_b64:
        # The reference timbre is supplied as a prior assistant audio turn.
        messages.append({
            "role": "assistant",
            "content": [{"type": "input_audio", "input_audio": {"data": ref_b64, "format": "wav"}}],
        })
    messages.append({"role": "user", "content": text})

    pcm = b""
    for attempt in range(max_retries + 1):
        try:
            pcm = _request_pcm(messages, lang_code)
            break
        except Exception as e:
            msg = str(e)
            logger.warning(f"[HIGGS TTS] stream attempt {attempt + 1} failed: {msg}")
            if attempt >= max_retries:
                raise
            is_rate = ("429" in msg) or ("rate limit" in msg.lower())
            # Back off twice as long when the server signals rate limiting.
            _jittered_sleep(backoff * (2.0 if is_rate else 1.0), attempt)

    if not pcm:
        pcm = b"\x00\x00" * int(0.1 * SR)  # brief silence fallback
        logger.warning("[HIGGS TTS] No audio chunks received; wrote brief silence.")

    # Drop a dangling half-sample so the data chunk holds only whole frames.
    frame_bytes = SAMPLE_WIDTH * NCHANNELS
    pcm = pcm[: len(pcm) - len(pcm) % frame_bytes]

    # Same tiny silence pad at both ends.
    pad = b"\x00\x00" * int(SR * _PAD_MS / 1000.0) if _PAD_MS > 0 else b""
    _write_wav_atomic(out_path, pad + pcm + pad)

# ------------------------------- Public API ------------------------------------

# Target languages written in non-Latin scripts. Pure-ASCII text aimed at one of
# these almost certainly means an untranslated English line. For Latin-script
# targets (es, fr), ASCII-only lines are normal ("Hola", "Merci beaucoup"), so they
# must not be re-routed to English.
_NON_LATIN_LANGS = frozenset({"zh-cn", "zh-tw", "ja", "ko"})


def tts(
    text: str,
    output_path: str,
    speaker_wav: Optional[str] = None,
    *,
    voice_type: Optional[str] = None,
    target_language: Optional[str] = None,
) -> None:
    """
    Perform per-line synthesis and write a mono 16-bit PCM WAV at SR=24k.

    Args:
        text: Exact text to speak. Blank text produces ~80 ms of silence.
        output_path: Destination WAV. If it already exists and is larger than
            1024 bytes, it is treated as already synthesized and skipped.
        speaker_wav: Optional reference WAV whose timbre is passed to the model.
            A path that is not a readable file is logged and ignored.
        voice_type: Passed through to accent selection.
        target_language: UI label or canonical code; it is normalized to a code.
            Defaults to "en" when not given.

    Raises:
        ValueError: for an unsupported target_language.
        RuntimeError: if BOSON_API_KEY is not configured.
    """
    if os.path.exists(output_path) and os.path.getsize(output_path) > 1024:
        logger.info(f"[HIGGS TTS] Exists, skipping {output_path}")
        return

    load_model()

    # Normalize language to unified code
    lang_code = _norm_lang(target_language) if target_language else "en"

    # Safety net: ASCII-looking text with a non-Latin-script target is most likely
    # untranslated English, so speak it as English unless mismatch is allowed.
    if not _ALLOW_MISMATCH and lang_code in _NON_LATIN_LANGS and _looks_ascii_english(text):
        logger.warning(f"[HIGGS TTS] ASCII-looking text with lang={lang_code}; forcing 'en'. "
                       f"Set HIGGS_TTS_ALLOW_MISMATCH=1 to disable.")
        lang_code = "en"

    region = _accent_from_voice_or_default(voice_type, lang_code)

    # Empty text guard -> write a breath of silence (no need to read the reference).
    text = (text or "").strip()
    if not text:
        os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
        with wave.open(output_path, "wb") as wf:
            wf.setnchannels(NCHANNELS)
            wf.setsampwidth(SAMPLE_WIDTH)
            wf.setframerate(SR)
            wf.writeframes(b"\x00\x00" * int(0.08 * SR))
        logger.warning("[HIGGS TTS] Empty input text; wrote brief silence.")
        return

    # Optional timbre reference
    ref_b64 = None
    if speaker_wav:
        ref_b64 = _b64_file(speaker_wav)
        if ref_b64:
            logger.info(f"[HIGGS TTS] Using reference timbre: {speaker_wav}")
        else:
            logger.warning(f"[HIGGS TTS] Reference audio not usable, ignoring: {speaker_wav}")

    _stream_pcm16_to_wav(
        text=text,
        out_path=output_path,
        lang_code=lang_code,
        region=region,
        ref_b64=ref_b64,
        max_retries=3,
        backoff=0.6,
    )
    logger.info(f"[HIGGS TTS] Saved {output_path} | lang={lang_code}-{region} | speed={_HIGGS_SPEED} | pad_ms={_PAD_MS}")