Gapeleon's picture
add modules
f60a5bb
"""Audio processing components for the TTS system."""
import torch
import logging
from abc import ABC, abstractmethod
from typing import Tuple, Optional
from nemo.collections.tts.models import AudioCodecModel
from transformers import AutoTokenizer
from .config import Config, AudioConfig
from .extractors import AudioCodeExtractor, TextExtractor
from nemo.utils.nemo_logging import Logger
nemo_logger = Logger()
nemo_logger.remove_stream_handlers()
logger = logging.getLogger(__name__)
class AudioProcessor(ABC):
"""Abstract base class for audio processing strategies."""
@abstractmethod
def decode_audio(self, audio_codes: torch.Tensor, length: torch.Tensor) -> torch.Tensor:
pass
class NemoAudioProcessor(AudioProcessor):
"""NeMo-based audio processing implementation."""
def __init__(self, config: AudioConfig):
self.config = config
self.device = config.device or ('cuda' if torch.cuda.is_available() else 'cpu')
self._model = None
@property
def model(self):
if self._model is None:
logger.info(f"Loading NeMo codec model: {self.config.nemo_model_name}")
self._model = AudioCodecModel.from_pretrained(self.config.nemo_model_name).eval()
self._model.to(self.device)
return self._model
def decode_audio(self, audio_codes: torch.Tensor, length: torch.Tensor) -> torch.Tensor:
audio_codes, length = audio_codes.to(self.device), length.to(self.device)
with torch.inference_mode():
reconstructed_audio, _ = self.model.decode(tokens=audio_codes, tokens_len=length)
return reconstructed_audio.cpu().detach().numpy().squeeze()
class NemoAudioPlayer:
"""Orchestrates audio generation from token sequences."""
def __init__(self, config: Config, text_tokenizer_name: Optional[str] = None):
self.config = config
self.tokens = config.tokens
self.audio_processor = NemoAudioProcessor(config.audio)
self.code_extractor = AudioCodeExtractor(config.tokens)
self.text_extractor = None
if text_tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(text_tokenizer_name)
self.text_extractor = TextExtractor(config.tokens, tokenizer)
def get_waveform(self, out_ids: torch.Tensor) -> Tuple[torch.Tensor, Optional[str]]:
"""Generate waveform from model output tokens."""
try:
out_ids = out_ids.flatten()
self.code_extractor.validate_output(out_ids)
audio_codes, length = self.code_extractor.extract_audio_codes(out_ids)
output_audio = self.audio_processor.decode_audio(audio_codes, length)
text = None
if self.text_extractor:
text = self.text_extractor.extract_text(out_ids)
return output_audio, text
except Exception as e:
logger.error(f"Error generating waveform: {e}")
raise