Spaces:
Sleeping
Sleeping
Add v1.0.0 model support with KPipeline implementation
Browse files- README.md +1 -0
- app.py +36 -16
- requirements.txt +5 -1
- tts_factory.py +22 -0
- tts_model_v1.py +168 -0
README.md
CHANGED
|
@@ -10,6 +10,7 @@ pinned: true
|
|
| 10 |
short_description: Accelerated Text-To-Speech on Kokoro-82M
|
| 11 |
models:
|
| 12 |
- hexgrad/kLegacy
|
|
|
|
| 13 |
---
|
| 14 |
|
| 15 |
# Kokoro TTS Demo Space
|
|
|
|
| 10 |
short_description: Accelerated Text-To-Speech on Kokoro-82M
|
| 11 |
models:
|
| 12 |
- hexgrad/kLegacy
|
| 13 |
+
- hexgrad/Kokoro-82M
|
| 14 |
---
|
| 15 |
|
| 16 |
# Kokoro TTS Demo Space
|
app.py
CHANGED
|
@@ -9,13 +9,13 @@ from lib import format_audio_output
|
|
| 9 |
from lib.ui_content import header_html, demo_text_info, styling
|
| 10 |
from lib.book_utils import get_available_books, get_book_info, get_chapter_text
|
| 11 |
from lib.text_utils import count_tokens
|
| 12 |
-
from
|
| 13 |
|
| 14 |
# Set HF_HOME for faster restarts with cached models/voices
|
| 15 |
os.environ["HF_HOME"] = "/data/.huggingface"
|
| 16 |
|
| 17 |
-
#
|
| 18 |
-
model =
|
| 19 |
|
| 20 |
# Configure logging
|
| 21 |
logging.basicConfig(level=logging.DEBUG)
|
|
@@ -24,21 +24,24 @@ logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
|
| 24 |
logger = logging.getLogger(__name__)
|
| 25 |
logger.debug("Starting app initialization...")
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
def initialize_model():
|
| 30 |
"""Initialize model and get voices"""
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
| 32 |
if not model.initialize():
|
| 33 |
raise gr.Error("Failed to initialize model")
|
| 34 |
-
|
| 35 |
-
voices = model.list_voices()
|
| 36 |
-
if not voices:
|
| 37 |
-
raise gr.Error("No voices found. Please check the voices directory.")
|
| 38 |
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
def update_progress(chunk_num, total_chunks, tokens_per_sec, rtf, progress_state, start_time, gpu_timeout, progress):
|
| 44 |
# Calculate time metrics
|
|
@@ -382,6 +385,14 @@ with gr.Blocks(title="Kokoro TTS Demo", css=styling) as demo:
|
|
| 382 |
)
|
| 383 |
|
| 384 |
with gr.Group():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 385 |
voice_dropdown = gr.Dropdown(
|
| 386 |
label="Voice(s)",
|
| 387 |
choices=[], # Start empty, will be populated after initialization
|
|
@@ -390,6 +401,15 @@ with gr.Blocks(title="Kokoro TTS Demo", css=styling) as demo:
|
|
| 390 |
multiselect=True
|
| 391 |
)
|
| 392 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 393 |
speed_slider = gr.Slider(
|
| 394 |
label="Speed",
|
| 395 |
minimum=0.5,
|
|
@@ -436,9 +456,9 @@ with gr.Blocks(title="Kokoro TTS Demo", css=styling) as demo:
|
|
| 436 |
with gr.Column():
|
| 437 |
gr.Markdown(demo_text_info)
|
| 438 |
|
| 439 |
-
# Initialize voices on load
|
| 440 |
demo.load(
|
| 441 |
-
fn=initialize_model,
|
| 442 |
outputs=[voice_dropdown]
|
| 443 |
)
|
| 444 |
|
|
|
|
| 9 |
from lib.ui_content import header_html, demo_text_info, styling
|
| 10 |
from lib.book_utils import get_available_books, get_book_info, get_chapter_text
|
| 11 |
from lib.text_utils import count_tokens
|
| 12 |
+
from tts_factory import TTSFactory
|
| 13 |
|
| 14 |
# Set HF_HOME for faster restarts with cached models/voices
|
| 15 |
os.environ["HF_HOME"] = "/data/.huggingface"
|
| 16 |
|
| 17 |
+
# Initialize model variable
|
| 18 |
+
model = None
|
| 19 |
|
| 20 |
# Configure logging
|
| 21 |
logging.basicConfig(level=logging.DEBUG)
|
|
|
|
| 24 |
logger = logging.getLogger(__name__)
|
| 25 |
logger.debug("Starting app initialization...")
|
| 26 |
|
| 27 |
+
def initialize_model(version="v0.19"):
    """Initialize the requested model version and return a voice-dropdown update.

    Args:
        version: Version string forwarded to ``TTSFactory.create_model``.

    Returns:
        The result of ``gr.update(choices=..., value=...)`` listing the
        model's voices, with 'af_sky' preselected when it is available and
        the first listed voice otherwise.

    Raises:
        gr.Error: if the model cannot be created or initialized, or if it
            reports no voices.
    """
    global model
    try:
        # Build into a local first so a failed switch does not replace a
        # previously working global model with a half-initialized one.
        candidate = TTSFactory.create_model(version)
        if not candidate.initialize():
            raise gr.Error("Failed to initialize model")

        voices = candidate.list_voices()
        if not voices:
            raise gr.Error("No voices found. Please check the voices directory.")
    except gr.Error:
        # Already a user-facing error; re-wrapping it would duplicate the
        # "Failed to initialize model" prefix in the message.
        raise
    except Exception as e:
        raise gr.Error(f"Failed to initialize model: {str(e)}") from e

    model = candidate

    # voices is guaranteed non-empty at this point.
    default_voice = 'af_sky' if 'af_sky' in voices else voices[0]
    return gr.update(choices=voices, value=default_voice)
|
| 45 |
|
| 46 |
def update_progress(chunk_num, total_chunks, tokens_per_sec, rtf, progress_state, start_time, gpu_timeout, progress):
|
| 47 |
# Calculate time metrics
|
|
|
|
| 385 |
)
|
| 386 |
|
| 387 |
with gr.Group():
|
| 388 |
+
version_dropdown = gr.Dropdown(
|
| 389 |
+
label="Model Version",
|
| 390 |
+
choices=["v0.19", "v1.0.0"],
|
| 391 |
+
value="v0.19",
|
| 392 |
+
allow_custom_value=False,
|
| 393 |
+
multiselect=False
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
voice_dropdown = gr.Dropdown(
|
| 397 |
label="Voice(s)",
|
| 398 |
choices=[], # Start empty, will be populated after initialization
|
|
|
|
| 401 |
multiselect=True
|
| 402 |
)
|
| 403 |
|
| 404 |
+
def on_version_change(version):
|
| 405 |
+
return initialize_model(version)
|
| 406 |
+
|
| 407 |
+
version_dropdown.change(
|
| 408 |
+
fn=on_version_change,
|
| 409 |
+
inputs=[version_dropdown],
|
| 410 |
+
outputs=[voice_dropdown]
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
speed_slider = gr.Slider(
|
| 414 |
label="Speed",
|
| 415 |
minimum=0.5,
|
|
|
|
| 456 |
with gr.Column():
|
| 457 |
gr.Markdown(demo_text_info)
|
| 458 |
|
| 459 |
+
# Initialize voices on load with default version
|
| 460 |
demo.load(
|
| 461 |
+
fn=lambda: initialize_model("v0.19"),
|
| 462 |
outputs=[voice_dropdown]
|
| 463 |
)
|
| 464 |
|
requirements.txt
CHANGED
|
@@ -9,4 +9,8 @@ regex==2024.11.6
|
|
| 9 |
tiktoken==0.8.0
|
| 10 |
transformers==4.47.1
|
| 11 |
munch==4.0.0
|
| 12 |
-
matplotlib==3.4.3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
tiktoken==0.8.0
|
| 10 |
transformers==4.47.1
|
| 11 |
munch==4.0.0
|
| 12 |
+
matplotlib==3.4.3
|
| 13 |
+
|
| 14 |
+
# v1.0.0 dependencies
|
| 15 |
+
kokoro>=1.0.0
|
| 16 |
+
misaki[en]>=0.1.0
|
tts_factory.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tts_model import TTSModel
|
| 2 |
+
from tts_model_v1 import TTSModelV1
|
| 3 |
+
|
| 4 |
+
class TTSFactory:
    """Builds the TTS model object that matches a requested version string."""

    @staticmethod
    def create_model(version="v0.19"):
        """Return a fresh model instance for ``version``.

        Args:
            version: "v0.19" selects TTSModel, "v1.0.0" selects TTSModelV1.

        Returns:
            A newly constructed TTSModel or TTSModelV1 instance.

        Raises:
            ValueError: if ``version`` is not one of the supported strings.
        """
        # Guard clauses: each supported version returns immediately, and
        # anything that falls through is rejected.
        if version == "v1.0.0":
            return TTSModelV1()
        if version == "v0.19":
            return TTSModel()
        raise ValueError(f"Unsupported version: {version}")
|
tts_model_v1.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import time
|
| 5 |
+
from typing import Tuple, List
|
| 6 |
+
import soundfile as sf
|
| 7 |
+
from kokoro import KPipeline
|
| 8 |
+
import spaces
|
| 9 |
+
|
| 10 |
+
class TTSModelV1:
    """KPipeline-based TTS model for v1.0.0.

    Voice packs are looked up as ``<voices_dir>/voices/<name>.pt``.
    """

    # Divisor used to convert a sample count into seconds of audio.
    SAMPLE_RATE = 24000

    def __init__(self):
        self.pipeline = None  # set by initialize()
        self.voices_dir = "voices"
        self.model_repo = "hexgrad/Kokoro-82M"

    def _voice_path(self, voice_name: str) -> str:
        """Return the on-disk path of the voice pack called ``voice_name``."""
        return os.path.join(self.voices_dir, "voices", f"{voice_name}.pt")

    def initialize(self) -> bool:
        """Create the KPipeline and check that the local voices directory exists.

        Returns:
            True on success; False if any step raised (the error is printed).
        """
        try:
            print("Initializing v1.0.0 model...")

            # Initialize KPipeline with American English
            self.pipeline = KPipeline(lang_code='a')

            # Verify local voice files are available
            voices_dir = os.path.join(self.voices_dir, "voices")
            if not os.path.exists(voices_dir):
                raise ValueError("Voice files not found")

            # An empty voices directory is only warned about here; callers
            # decide whether having no voices is fatal.
            available_voices = self.list_voices()
            if not available_voices:
                print("Warning: No voices found after initialization")
            else:
                print(f"Found {len(available_voices)} voices")

            print("Model initialization complete")
            return True

        except Exception as e:
            print(f"Error initializing model: {str(e)}")
            return False

    def list_voices(self) -> List[str]:
        """List available voices.

        Returns:
            Sorted names (file name without the ``.pt`` suffix) of the voice
            packs in the voices directory; empty if the directory is missing.
        """
        voices_subdir = os.path.join(self.voices_dir, "voices")
        if not os.path.exists(voices_subdir):
            return []
        # Sorted so the result does not depend on os.listdir() ordering.
        return sorted(
            entry[:-3]
            for entry in os.listdir(voices_subdir)
            if entry.endswith(".pt")
        )

    def _mix_voices(self, voice_names):
        """Load the named voice packs and return their element-wise mean.

        Voices that fail to load are skipped with a warning.

        Raises:
            ValueError: if none of the voices could be loaded.
        """
        loaded = []
        for voice in voice_names:
            voice_path = self._voice_path(voice)
            try:
                try:
                    voicepack = torch.load(voice_path, weights_only=True)
                except Exception as e:
                    print(f"Warning: weights_only load failed, attempting full load: {str(e)}")
                    # SECURITY NOTE: weights_only=False performs full
                    # unpickling; acceptable only for trusted voice files.
                    voicepack = torch.load(voice_path, weights_only=False)
                loaded.append(voicepack)
            except Exception as e:
                print(f"Warning: Failed to load voice {voice}: {str(e)}")

        if not loaded:
            # torch.stack([]) would otherwise fail with an opaque error.
            raise ValueError("None of the requested voices could be loaded")

        # Combine voices by taking mean
        return torch.mean(torch.stack(loaded), dim=0)

    @spaces.GPU(duration=None)  # Duration will be set by the UI
    def generate_speech(self, text: str, voice_names: list[str], speed: float = 1.0, gpu_timeout: int = 60, progress_callback=None, progress_state=None, progress=None) -> Tuple[np.ndarray, float, dict]:
        """Generate speech from text using KPipeline.

        Args:
            text: Input text to convert to speech
            voice_names: List of voice names to use (will be mixed if multiple);
                a single name given as a plain string is also accepted
            speed: Speech speed multiplier
            gpu_timeout: Passed through unchanged to progress_callback
            progress_callback: Optional callback invoked once per chunk
            progress_state: Dictionary tracking generation progress metrics;
                when truthy, its "tokens_per_sec" and "rtf" entries are
                copied into the returned metrics
            progress: Progress callback from Gradio, passed through to
                progress_callback

        Returns:
            Tuple of (concatenated audio, audio duration in seconds,
            metrics dict with chunk_times, chunk_sizes, tokens_per_sec,
            rtf, total_tokens and total_time).

        Raises:
            ValueError: if text or voice_names is empty, if no requested
                voice could be loaded for mixing, or if the pipeline
                produced no audio.
        """
        mixed_voice_path = None
        try:
            start_time = time.time()

            if not text or not voice_names:
                raise ValueError("Text and voice name are required")

            # A bare string would otherwise be indexed character by character.
            if isinstance(voice_names, str):
                voice_names = [voice_names]
            else:
                voice_names = list(voice_names)

            # Handle voice mixing
            if len(voice_names) > 1:
                voicepack = self._mix_voices(voice_names)
                voice_name = "_".join(voice_names)
                # Save mixed voice temporarily (removed in the finally block).
                # NOTE(review): the file lands in the regular voices directory,
                # so a joined name equal to an existing voice would overwrite
                # and then delete that voice - confirm names cannot collide.
                mixed_voice_path = self._voice_path(voice_name)
                torch.save(voicepack, mixed_voice_path)
            else:
                voice_name = voice_names[0]

            # Generate speech using KPipeline
            # NOTE(review): the pipeline receives the voice *name*, not the
            # saved file path - confirm KPipeline resolves it to the local file.
            generator = self.pipeline(
                text,
                voice=voice_name,
                speed=speed,
                split_pattern=r'\n+'  # Default chunking pattern
            )

            # Process chunks and collect metrics
            audio_chunks = []
            chunk_times = []
            chunk_sizes = []
            total_tokens = 0

            # The generator does its work while producing each item, so the
            # timer must start before iteration and restart after every chunk.
            chunk_start = time.time()
            for i, (gs, ps, audio) in enumerate(generator):
                chunk_time = time.time() - chunk_start

                # Store chunk audio
                audio_chunks.append(audio)

                # Calculate metrics
                chunk_times.append(chunk_time)
                chunk_sizes.append(len(gs))  # Use grapheme length as chunk size
                total_tokens += len(gs)

                # Update progress if callback provided
                if progress_callback:
                    chunk_duration = len(audio) / self.SAMPLE_RATE
                    # Guard both ratios: a coarse clock can report 0.0 elapsed
                    # and an empty chunk has zero duration.
                    tokens_per_sec = len(gs) / chunk_time if chunk_time > 0 else 0.0
                    rtf = chunk_time / chunk_duration if chunk_duration > 0 else 0.0
                    progress_callback(
                        i + 1,
                        -1,  # Total chunks unknown with generator
                        tokens_per_sec,
                        rtf,
                        progress_state,
                        start_time,
                        gpu_timeout,
                        progress
                    )

                print(f"Chunk {i+1} processed in {chunk_time:.2f}s")
                print(f"Graphemes: {gs}")
                print(f"Phonemes: {ps}")

                chunk_start = time.time()

            if not audio_chunks:
                raise ValueError("No audio was generated for the given text")

            # Concatenate audio chunks
            audio = np.concatenate(audio_chunks)

            # Return audio and metrics
            return (
                audio,
                len(audio) / self.SAMPLE_RATE,
                {
                    "chunk_times": chunk_times,
                    "chunk_sizes": chunk_sizes,
                    "tokens_per_sec": [float(x) for x in progress_state["tokens_per_sec"]] if progress_state else [],
                    "rtf": [float(x) for x in progress_state["rtf"]] if progress_state else [],
                    "total_tokens": total_tokens,
                    "total_time": time.time() - start_time
                }
            )

        except Exception as e:
            print(f"Error generating speech: {str(e)}")
            raise

        finally:
            # Cleanup temporary mixed voice if created; best effort, and it
            # runs on failure too so the temp file does not linger.
            if mixed_voice_path is not None:
                try:
                    os.remove(mixed_voice_path)
                except OSError:
                    pass
|