Gapeleon committed
Commit f60a5bb · 1 Parent(s): 8c00aff

add modules

kanitts/__init__.py ADDED
@@ -0,0 +1,24 @@
+ """NanoCodec TTS System - A modular text-to-speech system."""
+
+ from .config import Config, AudioConfig, ModelConfig
+ from .tokens import TokenRegistry
+ from .audio import NemoAudioPlayer, AudioProcessor, NemoAudioProcessor
+ from .models import KaniModel, InputProcessor, ModelInference
+ from .extractors import AudioCodeExtractor, TextExtractor
+ from .factory import TTSFactory
+
+ __all__ = [
+     'Config',
+     'AudioConfig',
+     'ModelConfig',
+     'TokenRegistry',
+     'NemoAudioPlayer',
+     'AudioProcessor',
+     'NemoAudioProcessor',
+     'KaniModel',
+     'InputProcessor',
+     'ModelInference',
+     'AudioCodeExtractor',
+     'TextExtractor',
+     'TTSFactory',
+ ]
kanitts/audio.py ADDED
@@ -0,0 +1,82 @@
+ """Audio processing components for the TTS system."""
+
+ import torch
+ import logging
+ from abc import ABC, abstractmethod
+ from typing import Tuple, Optional
+ from nemo.collections.tts.models import AudioCodecModel
+ from transformers import AutoTokenizer
+ from .config import Config, AudioConfig
+ from .extractors import AudioCodeExtractor, TextExtractor
+
+ from nemo.utils.nemo_logging import Logger
+
+ nemo_logger = Logger()
+ nemo_logger.remove_stream_handlers()
+
+ logger = logging.getLogger(__name__)
+
+
+ class AudioProcessor(ABC):
+     """Abstract base class for audio processing strategies."""
+
+     @abstractmethod
+     def decode_audio(self, audio_codes: torch.Tensor, length: torch.Tensor) -> torch.Tensor:
+         pass
+
+
+ class NemoAudioProcessor(AudioProcessor):
+     """NeMo-based audio processing implementation."""
+
+     def __init__(self, config: AudioConfig):
+         self.config = config
+         self.device = config.device or ('cuda' if torch.cuda.is_available() else 'cpu')
+         self._model = None
+
+     @property
+     def model(self):
+         if self._model is None:
+             logger.info(f"Loading NeMo codec model: {self.config.nemo_model_name}")
+             self._model = AudioCodecModel.from_pretrained(self.config.nemo_model_name).eval()
+             self._model.to(self.device)
+         return self._model
+
+     def decode_audio(self, audio_codes: torch.Tensor, length: torch.Tensor) -> torch.Tensor:
+         audio_codes, length = audio_codes.to(self.device), length.to(self.device)
+         with torch.inference_mode():
+             reconstructed_audio, _ = self.model.decode(tokens=audio_codes, tokens_len=length)
+         return reconstructed_audio.cpu().detach().numpy().squeeze()
+
+
+ class NemoAudioPlayer:
+     """Orchestrates audio generation from token sequences."""
+
+     def __init__(self, config: Config, text_tokenizer_name: Optional[str] = None):
+         self.config = config
+         self.tokens = config.tokens
+         self.audio_processor = NemoAudioProcessor(config.audio)
+         self.code_extractor = AudioCodeExtractor(config.tokens)
+
+         self.text_extractor = None
+         if text_tokenizer_name:
+             tokenizer = AutoTokenizer.from_pretrained(text_tokenizer_name)
+             self.text_extractor = TextExtractor(config.tokens, tokenizer)
+
+     def get_waveform(self, out_ids: torch.Tensor) -> Tuple[torch.Tensor, Optional[str]]:
+         """Generate waveform from model output tokens."""
+         try:
+             out_ids = out_ids.flatten()
+             self.code_extractor.validate_output(out_ids)
+             audio_codes, length = self.code_extractor.extract_audio_codes(out_ids)
+
+             output_audio = self.audio_processor.decode_audio(audio_codes, length)
+
+             text = None
+             if self.text_extractor:
+                 text = self.text_extractor.extract_text(out_ids)
+
+             return output_audio, text
+
+         except Exception as e:
+             logger.error(f"Error generating waveform: {e}")
+             raise
kanitts/config.py ADDED
@@ -0,0 +1,41 @@
+ """Configuration classes for the TTS system."""
+
+ from dataclasses import dataclass
+ from typing import Optional
+ from .tokens import TokenRegistry
+
+
+ @dataclass
+ class AudioConfig:
+     """Configuration for audio processing."""
+     nemo_model_name: str = "nvidia/nemo-nano-codec-22khz-0.6kbps-12.5fps"
+     sample_rate: int = 22050
+     device: Optional[str] = None
+
+
+ @dataclass
+ class ModelConfig:
+     """Configuration for language model."""
+     model_name: str = 'nineninesix/kani-tts-450m-0.1-pt'
+     device_map: str = "auto"
+     torch_dtype: str = "bfloat16"
+     max_new_tokens: int = 1200
+     temperature: float = 0.6
+     top_p: float = 0.95
+     repetition_penalty: float = 1.1
+
+
+ @dataclass
+ class Config:
+     """Main configuration container."""
+     model: ModelConfig
+     audio: AudioConfig
+     tokens: TokenRegistry
+
+     @classmethod
+     def default(cls) -> 'Config':
+         return cls(
+             model=ModelConfig(),
+             audio=AudioConfig(),
+             tokens=TokenRegistry()
+         )
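
For reference, a minimal sketch (not part of this commit) of overriding the defaults above instead of calling `Config.default()`; the package name `kanitts` follows the module paths in this commit, and the chosen values are purely illustrative:

```python
from kanitts import AudioConfig, Config, ModelConfig, TokenRegistry

# Illustrative overrides only; any field left out keeps the defaults from config.py.
config = Config(
    model=ModelConfig(temperature=0.8, max_new_tokens=800),
    audio=AudioConfig(device="cpu"),   # force CPU decoding of the codec
    tokens=TokenRegistry(),            # default special-token layout
)
```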
kanitts/extractors.py ADDED
@@ -0,0 +1,66 @@
+ """Extractors for processing audio and text from token sequences."""
+
+ import torch
+ from typing import Tuple, Optional
+ from .tokens import TokenRegistry
+
+
+ class AudioCodeExtractor:
+     """Handles extraction and validation of audio codes from token sequences."""
+
+     def __init__(self, token_registry: TokenRegistry):
+         self.tokens = token_registry
+
+     def validate_output(self, out_ids: torch.Tensor) -> None:
+         """Validate that required speech tokens are present."""
+         start_present = self.tokens.start_of_speech in out_ids
+         end_present = self.tokens.end_of_speech in out_ids
+
+         if not (start_present and end_present):
+             raise ValueError('Special speech tokens not found in output')
+
+     def extract_audio_codes(self, out_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Extract and process audio codes from token sequence."""
+         try:
+             start_idx = (out_ids == self.tokens.start_of_speech).nonzero(as_tuple=True)[0][0].item()
+             end_idx = (out_ids == self.tokens.end_of_speech).nonzero(as_tuple=True)[0][0].item()
+         except IndexError:
+             raise ValueError('Speech tokens not found in sequence')
+
+         if start_idx >= end_idx:
+             raise ValueError('Invalid audio codes sequence - start token after end token')
+
+         audio_codes = out_ids[start_idx + 1:end_idx]
+
+         if len(audio_codes) % 4 != 0:
+             raise ValueError('Audio codes length must be a multiple of 4')
+
+         audio_codes = audio_codes.reshape(-1, 4)
+         audio_codes = audio_codes - torch.tensor([self.tokens.codebook_size * i for i in range(4)])
+         audio_codes = audio_codes - self.tokens.audio_tokens_start
+
+         if (audio_codes < 0).sum().item() > 0:
+             raise ValueError('Invalid audio tokens detected')
+
+         audio_codes = audio_codes.T.unsqueeze(0)
+         length = torch.tensor([audio_codes.shape[-1]])
+
+         return audio_codes, length
+
+
+ class TextExtractor:
+     """Handles text extraction from token sequences."""
+
+     def __init__(self, token_registry: TokenRegistry, tokenizer):
+         self.tokens = token_registry
+         self.tokenizer = tokenizer
+
+     def extract_text(self, out_ids: torch.Tensor) -> Optional[str]:
+         """Extract text from token sequence."""
+         try:
+             start_idx = (out_ids == self.tokens.start_of_text).nonzero(as_tuple=True)[0][0].item()
+             end_idx = (out_ids == self.tokens.end_of_text).nonzero(as_tuple=True)[0][0].item()
+             text_tokens = out_ids[start_idx:end_idx + 1]
+             return self.tokenizer.decode(text_tokens, skip_special_tokens=True)
+         except (IndexError, AttributeError):
+             return None
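
As context for `extract_audio_codes`, here is a small sketch (not part of this commit) of the frame arithmetic it inverts, assuming the packing convention implied by the offsets in `TokenRegistry`: a code `c` from codebook `i` (0-3) is stored in the LM vocabulary as `audio_tokens_start + i * codebook_size + c`.

```python
import torch

from kanitts.tokens import TokenRegistry

tokens = TokenRegistry()               # audio_tokens_start = 64410, codebook_size = 4032
frame = torch.tensor([5, 17, 100, 3])  # one 4-codebook frame of codec codes (illustrative values)

# Pack the frame into LM token ids using the assumed convention above.
packed = frame + tokens.audio_tokens_start + tokens.codebook_size * torch.arange(4)

# extract_audio_codes reverses this: subtract the per-codebook offset, then the base offset.
unpacked = packed - torch.tensor([tokens.codebook_size * i for i in range(4)]) - tokens.audio_tokens_start
assert torch.equal(unpacked, frame)
```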
kanitts/factory.py ADDED
@@ -0,0 +1,21 @@
+ """Factory for creating TTS system components."""
+
+ from typing import Optional, Tuple
+ from .config import Config
+ from .audio import NemoAudioPlayer
+ from .models import KaniModel
+
+
+ class TTSFactory:
+     """Factory for creating TTS system components."""
+
+     @staticmethod
+     def create_system(config: Optional[Config] = None) -> Tuple[KaniModel, NemoAudioPlayer]:
+         """Create a complete TTS system."""
+         if config is None:
+             config = Config.default()
+
+         player = NemoAudioPlayer(config)
+         model = KaniModel(config, player)
+
+         return model, player
kanitts/models.py ADDED
@@ -0,0 +1,92 @@
+ """Model inference components for the TTS system."""
+
+ import torch
+ import logging
+ from typing import Tuple
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from .config import Config, ModelConfig
+ from .tokens import TokenRegistry
+ from .audio import NemoAudioPlayer
+
+ logger = logging.getLogger(__name__)
+
+
+ class InputProcessor:
+     """Handles input text processing and tokenization."""
+
+     def __init__(self, tokenizer, token_registry: TokenRegistry):
+         self.tokenizer = tokenizer
+         self.tokens = token_registry
+
+     def prepare_input(self, text: str) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Prepare input text for model inference."""
+         input_ids = self.tokenizer(text, return_tensors="pt").input_ids
+
+         start_token = torch.tensor([[self.tokens.start_of_human]], dtype=torch.int64)
+         end_tokens = torch.tensor([[self.tokens.end_of_text, self.tokens.end_of_human]], dtype=torch.int64)
+
+         modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
+         attention_mask = torch.ones(1, modified_input_ids.shape[1], dtype=torch.int64)
+
+         return modified_input_ids, attention_mask
+
+
+ class ModelInference:
+     """Handles model inference operations."""
+
+     def __init__(self, model, config: ModelConfig, token_registry: TokenRegistry):
+         self.model = model
+         self.config = config
+         self.tokens = token_registry
+         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+         """Generate tokens from input."""
+         input_ids = input_ids.to(self.device)
+         attention_mask = attention_mask.to(self.device)
+
+         with torch.no_grad():
+             generated_ids = self.model.generate(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 max_new_tokens=self.config.max_new_tokens,
+                 do_sample=True,
+                 temperature=self.config.temperature,
+                 top_p=self.config.top_p,
+                 repetition_penalty=self.config.repetition_penalty,
+                 num_return_sequences=1,
+                 eos_token_id=self.tokens.end_of_speech,
+             )
+         return generated_ids.to('cpu')
+
+
+ class KaniModel:
+     """Main text-to-speech model orchestrator."""
+
+     def __init__(self, config: Config, player: NemoAudioPlayer):
+         self.config = config
+         self.player = player
+
+         logger.info(f"Loading model: {config.model.model_name}")
+         torch_dtype = getattr(torch, config.model.torch_dtype)
+         self.model = AutoModelForCausalLM.from_pretrained(
+             config.model.model_name,
+             torch_dtype=torch_dtype,
+             device_map=config.model.device_map,
+         )
+
+         self.tokenizer = AutoTokenizer.from_pretrained(config.model.model_name)
+         self.input_processor = InputProcessor(self.tokenizer, config.tokens)
+         self.inference = ModelInference(self.model, config.model, config.tokens)
+
+     def run_model(self, text: str) -> Tuple[torch.Tensor, str]:
+         """Generate audio from input text."""
+         try:
+             logger.info(f"Processing text: {text[:50]}...")
+             input_ids, attention_mask = self.input_processor.prepare_input(text)
+             model_output = self.inference.generate(input_ids, attention_mask)
+             audio, _ = self.player.get_waveform(model_output)
+             return audio, text
+         except Exception as e:
+             logger.error(f"Error in model execution: {e}")
+             raise
kanitts/tokens.py ADDED
@@ -0,0 +1,19 @@
+ """Token registry for managing special tokens in the TTS system."""
+
+
+ class TokenRegistry:
+     """Centralized token management for audio codec operations."""
+
+     def __init__(self, tokenizer_length: int = 64400):
+         self.tokenizer_length = tokenizer_length
+         self.start_of_text = 1
+         self.end_of_text = 2
+         self.start_of_speech = tokenizer_length + 1
+         self.end_of_speech = tokenizer_length + 2
+         self.start_of_human = tokenizer_length + 3
+         self.end_of_human = tokenizer_length + 4
+         self.start_of_ai = tokenizer_length + 5
+         self.end_of_ai = tokenizer_length + 6
+         self.pad_token = tokenizer_length + 7
+         self.audio_tokens_start = tokenizer_length + 10
+         self.codebook_size = 4032
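
Taken together, a minimal end-to-end sketch (not part of this commit) of how the added modules are meant to be wired up. It assumes `nemo_toolkit`, `transformers`, and `soundfile` are installed and that the checkpoints named in config.py can be downloaded; the input string and output path are illustrative only.

```python
import soundfile as sf

from kanitts import Config, TTSFactory

config = Config.default()
model, player = TTSFactory.create_system(config)       # loads the LM and wires in the codec-backed player

audio, text = model.run_model("Hello from Kani TTS.")   # numpy waveform plus the input text
sf.write("kani_tts_sample.wav", audio, config.audio.sample_rate)
```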