File size: 3,048 Bytes
f60a5bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""Audio processing components for the TTS system."""

import torch
import logging
from abc import ABC, abstractmethod
from typing import Tuple, Optional
from nemo.collections.tts.models import AudioCodecModel
from transformers import AutoTokenizer
from .config import Config, AudioConfig
from .extractors import AudioCodeExtractor, TextExtractor

from nemo.utils.nemo_logging import Logger

nemo_logger = Logger()
nemo_logger.remove_stream_handlers()

logger = logging.getLogger(__name__)


class AudioProcessor(ABC):
    """Abstract base class for audio processing strategies."""
    
    @abstractmethod
    def decode_audio(self, audio_codes: torch.Tensor, length: torch.Tensor) -> torch.Tensor:
        pass


class NemoAudioProcessor(AudioProcessor):
    """NeMo-based audio processing implementation."""
    
    def __init__(self, config: AudioConfig):
        self.config = config
        self.device = config.device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self._model = None
    
    @property
    def model(self):
        if self._model is None:
            logger.info(f"Loading NeMo codec model: {self.config.nemo_model_name}")
            self._model = AudioCodecModel.from_pretrained(self.config.nemo_model_name).eval()
            self._model.to(self.device)
        return self._model
    
    def decode_audio(self, audio_codes: torch.Tensor, length: torch.Tensor) -> torch.Tensor:
        audio_codes, length = audio_codes.to(self.device), length.to(self.device)
        with torch.inference_mode():
            reconstructed_audio, _ = self.model.decode(tokens=audio_codes, tokens_len=length)
            return reconstructed_audio.cpu().detach().numpy().squeeze()


class NemoAudioPlayer:
    """Orchestrates audio generation from token sequences."""
    
    def __init__(self, config: Config, text_tokenizer_name: Optional[str] = None):
        self.config = config
        self.tokens = config.tokens
        self.audio_processor = NemoAudioProcessor(config.audio)
        self.code_extractor = AudioCodeExtractor(config.tokens)
        
        self.text_extractor = None
        if text_tokenizer_name:
            tokenizer = AutoTokenizer.from_pretrained(text_tokenizer_name)
            self.text_extractor = TextExtractor(config.tokens, tokenizer)
    
    def get_waveform(self, out_ids: torch.Tensor) -> Tuple[torch.Tensor, Optional[str]]:
        """Generate waveform from model output tokens."""
        try:
            out_ids = out_ids.flatten()
            self.code_extractor.validate_output(out_ids)
            audio_codes, length = self.code_extractor.extract_audio_codes(out_ids)
            
            output_audio = self.audio_processor.decode_audio(audio_codes, length)
            
            text = None
            if self.text_extractor:
                text = self.text_extractor.extract_text(out_ids)
            
            return output_audio, text
            
        except Exception as e:
            logger.error(f"Error generating waveform: {e}")
            raise