Spaces:
Runtime error
Runtime error
| """Audio processing components for the TTS system.""" | |
| import torch | |
| import logging | |
| from abc import ABC, abstractmethod | |
| from typing import Tuple, Optional | |
| from nemo.collections.tts.models import AudioCodecModel | |
| from transformers import AutoTokenizer | |
| from .config import Config, AudioConfig | |
| from .extractors import AudioCodeExtractor, TextExtractor | |
| from nemo.utils.nemo_logging import Logger | |
| nemo_logger = Logger() | |
| nemo_logger.remove_stream_handlers() | |
| logger = logging.getLogger(__name__) | |
| class AudioProcessor(ABC): | |
| """Abstract base class for audio processing strategies.""" | |
| def decode_audio(self, audio_codes: torch.Tensor, length: torch.Tensor) -> torch.Tensor: | |
| pass | |
| class NemoAudioProcessor(AudioProcessor): | |
| """NeMo-based audio processing implementation.""" | |
| def __init__(self, config: AudioConfig): | |
| self.config = config | |
| self.device = config.device or ('cuda' if torch.cuda.is_available() else 'cpu') | |
| self._model = None | |
| def model(self): | |
| if self._model is None: | |
| logger.info(f"Loading NeMo codec model: {self.config.nemo_model_name}") | |
| self._model = AudioCodecModel.from_pretrained(self.config.nemo_model_name).eval() | |
| self._model.to(self.device) | |
| return self._model | |
| def decode_audio(self, audio_codes: torch.Tensor, length: torch.Tensor) -> torch.Tensor: | |
| audio_codes, length = audio_codes.to(self.device), length.to(self.device) | |
| with torch.inference_mode(): | |
| reconstructed_audio, _ = self.model.decode(tokens=audio_codes, tokens_len=length) | |
| return reconstructed_audio.cpu().detach().numpy().squeeze() | |
| class NemoAudioPlayer: | |
| """Orchestrates audio generation from token sequences.""" | |
| def __init__(self, config: Config, text_tokenizer_name: Optional[str] = None): | |
| self.config = config | |
| self.tokens = config.tokens | |
| self.audio_processor = NemoAudioProcessor(config.audio) | |
| self.code_extractor = AudioCodeExtractor(config.tokens) | |
| self.text_extractor = None | |
| if text_tokenizer_name: | |
| tokenizer = AutoTokenizer.from_pretrained(text_tokenizer_name) | |
| self.text_extractor = TextExtractor(config.tokens, tokenizer) | |
| def get_waveform(self, out_ids: torch.Tensor) -> Tuple[torch.Tensor, Optional[str]]: | |
| """Generate waveform from model output tokens.""" | |
| try: | |
| out_ids = out_ids.flatten() | |
| self.code_extractor.validate_output(out_ids) | |
| audio_codes, length = self.code_extractor.extract_audio_codes(out_ids) | |
| output_audio = self.audio_processor.decode_audio(audio_codes, length) | |
| text = None | |
| if self.text_extractor: | |
| text = self.text_extractor.extract_text(out_ids) | |
| return output_audio, text | |
| except Exception as e: | |
| logger.error(f"Error generating waveform: {e}") | |
| raise |