import os
import json
import logging
import tempfile
from functools import partial
from typing import Literal

import torch
from omegaconf import OmegaConf
from vocos import Vocos
from huggingface_hub import snapshot_download

from .model.dvae import DVAE
from .model.gpt import GPT_warpper
from .utils.gpu_utils import select_device
from .utils.infer_utils import (
    count_invalid_characters,
    detect_language,
    apply_character_map,
    apply_half2full_map,
    HomophonesReplacer,
)
from .utils.io_utils import get_latest_modified_file
from .infer.api import refine_text, infer_code
from .utils.download import check_all_assets, download_all_assets

logging.basicConfig(level=logging.INFO)


class Chat:
    def __init__(self):
        self.pretrain_models = {}
        self.normalizer = {}
        self.homophones_replacer = None
        self.logger = logging.getLogger(__name__)

    def check_model(self, level=logging.INFO, use_decoder=False):
        not_finish = False
        check_list = ['vocos', 'gpt', 'tokenizer']
        check_list.append('decoder' if use_decoder else 'dvae')

        for module in check_list:
            if module not in self.pretrain_models:
                self.logger.log(logging.WARNING, f'{module} not initialized.')
                not_finish = True

        if not not_finish:
            self.logger.log(level, 'All initialized.')

        return not not_finish

    def load_models(
        self,
        source: Literal['huggingface', 'local', 'custom'] = 'local',
        force_redownload=False,
        custom_path='<LOCAL_PATH>',
        **kwargs,
    ):
        if source == 'local':
            download_path = os.getcwd()
            if not check_all_assets(update=True):
                with tempfile.TemporaryDirectory() as tmp:
                    download_all_assets(tmpdir=tmp)
                if not check_all_assets(update=False):
                    logging.error('could not satisfy all needed assets.')
                    exit(1)
        elif source == 'huggingface':
            hf_home = os.getenv('HF_HOME', os.path.expanduser('~/.cache/huggingface'))
            try:
                download_path = get_latest_modified_file(
                    os.path.join(hf_home, 'hub/models--2Noise--ChatTTS/snapshots')
                )
            except Exception:
                download_path = None
            if download_path is None or force_redownload:
                self.logger.log(logging.INFO, 'Download from HF: https://huggingface.co/2Noise/ChatTTS')
                download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
            else:
                self.logger.log(logging.INFO, f'Load from cache: {download_path}')
        elif source == 'custom':
            self.logger.log(logging.INFO, f'Load from local: {custom_path}')
            download_path = custom_path

        # config/path.yaml maps _load() keyword names to paths relative to download_path.
        paths = OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml'))
        self._load(**{k: os.path.join(download_path, v) for k, v in paths.items()}, **kwargs)
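
    # Usage sketch for load_models() (kept as comments so this file stays
    # importable; the custom path is illustrative, not shipped with this module):
    #
    #   chat = Chat()
    #   chat.load_models()                      # 'local': assets in the working directory
    #   chat.load_models(source='huggingface')  # pull or reuse the HF snapshot cache
    #   chat.load_models(source='custom', custom_path='/path/to/ChatTTS-assets')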

    def _load(
        self,
        vocos_config_path: str = None,
        vocos_ckpt_path: str = None,
        dvae_config_path: str = None,
        dvae_ckpt_path: str = None,
        gpt_config_path: str = None,
        gpt_ckpt_path: str = None,
        decoder_config_path: str = None,
        decoder_ckpt_path: str = None,
        tokenizer_path: str = None,
        device: str = None,
        compile: bool = True,
    ):
        if not device:
            device = select_device(4096)
            self.logger.log(logging.INFO, f'use {device}')

        if vocos_config_path:
            vocos = Vocos.from_hparams(vocos_config_path).to(
                # Vocos crashes on MPS, so fall back to CPU there.
                'cpu' if torch.backends.mps.is_available() else device
            ).eval()
            assert vocos_ckpt_path, 'vocos_ckpt_path should not be None'
            vocos.load_state_dict(torch.load(vocos_ckpt_path))
            self.pretrain_models['vocos'] = vocos
            self.logger.log(logging.INFO, 'vocos loaded.')

        if dvae_config_path:
            cfg = OmegaConf.load(dvae_config_path)
            dvae = DVAE(**cfg).to(device).eval()
            assert dvae_ckpt_path, 'dvae_ckpt_path should not be None'
            dvae.load_state_dict(torch.load(dvae_ckpt_path))
            self.pretrain_models['dvae'] = dvae
            self.logger.log(logging.INFO, 'dvae loaded.')

        if gpt_config_path:
            cfg = OmegaConf.load(gpt_config_path)
            gpt = GPT_warpper(**cfg).to(device).eval()
            assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
            gpt.load_state_dict(torch.load(gpt_ckpt_path))
            if compile and 'cuda' in str(device):
                try:
                    gpt.gpt.forward = torch.compile(gpt.gpt.forward, backend='inductor', dynamic=True)
                except RuntimeError as e:
                    self.logger.warning(f'Compile failed: {e}. Falling back to eager mode.')
            self.pretrain_models['gpt'] = gpt
            # Speaker statistics (std/mean) are expected next to the GPT checkpoint.
            spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), 'spk_stat.pt')
            assert os.path.exists(spk_stat_path), f'Missing spk_stat.pt: {spk_stat_path}'
            self.pretrain_models['spk_stat'] = torch.load(spk_stat_path).to(device)
            self.logger.log(logging.INFO, 'gpt loaded.')

        if decoder_config_path:
            cfg = OmegaConf.load(decoder_config_path)
            decoder = DVAE(**cfg).to(device).eval()
            assert decoder_ckpt_path, 'decoder_ckpt_path should not be None'
            decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location='cpu'))
            self.pretrain_models['decoder'] = decoder
            self.logger.log(logging.INFO, 'decoder loaded.')

        if tokenizer_path:
            tokenizer = torch.load(tokenizer_path, map_location='cpu')
            tokenizer.padding_side = 'left'
            self.pretrain_models['tokenizer'] = tokenizer
            self.logger.log(logging.INFO, 'tokenizer loaded.')

        self.check_model()
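
    # A plausible config/path.yaml consumed by load_models() (file names below are
    # illustrative assumptions; only the keys are fixed, since they must match
    # _load()'s keyword arguments):
    #
    #   vocos_config_path: config/vocos.yaml
    #   vocos_ckpt_path: asset/Vocos.pt
    #   dvae_config_path: config/dvae.yaml
    #   dvae_ckpt_path: asset/DVAE.pt
    #   gpt_config_path: config/gpt.yaml
    #   gpt_ckpt_path: asset/GPT.pt
    #   decoder_config_path: config/decoder.yaml
    #   decoder_ckpt_path: asset/Decoder.pt
    #   tokenizer_path: asset/tokenizer.pt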

    def _infer(
        self,
        text,
        skip_refine_text=False,
        refine_text_only=False,
        params_refine_text=None,
        params_infer_code=None,
        use_decoder=True,
        do_text_normalization=True,
        lang=None,
        stream=False,
        do_homophone_replacement=True,
    ):
        # Avoid mutable default arguments: params_infer_code is mutated (pop) below,
        # which would otherwise leak state across calls through a shared default dict.
        params_refine_text = dict(params_refine_text) if params_refine_text else {}
        params_infer_code = dict(params_infer_code) if params_infer_code else {'prompt': '[speed_5]'}

        assert self.check_model(use_decoder=use_decoder)

        if not isinstance(text, list):
            text = [text]

        if do_text_normalization:
            for i, t in enumerate(text):
                _lang = detect_language(t) if lang is None else lang
                if self.init_normalizer(_lang):
                    text[i] = self.normalizer[_lang](t)
                    if _lang == 'zh':
                        text[i] = apply_half2full_map(text[i])

        for i, t in enumerate(text):
            invalid_characters = count_invalid_characters(t)
            if len(invalid_characters):
                self.logger.log(logging.WARNING, f'Invalid characters found! : {invalid_characters}')
                text[i] = apply_character_map(t)
            if do_homophone_replacement and self.init_homophones_replacer():
                # Replace on text[i], not t, so the character-map result above is kept.
                before = text[i]
                text[i] = self.homophones_replacer.replace(before)
                if before != text[i]:
                    self.logger.log(logging.INFO, f'Homophones replace: {before} -> {text[i]}')

        if not skip_refine_text:
            text_tokens = refine_text(
                self.pretrain_models,
                text,
                **params_refine_text,
            )['ids']
            # Keep only token ids below the id of [break_0], dropping special tokens.
            text_tokens = [
                i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')]
                for i in text_tokens
            ]
            text = self.pretrain_models['tokenizer'].batch_decode(text_tokens)
            if refine_text_only:
                yield text
                return

        text = [params_infer_code.get('prompt', '') + i for i in text]
        params_infer_code.pop('prompt', '')
        result_gen = infer_code(
            self.pretrain_models, text, **params_infer_code,
            return_hidden=use_decoder, stream=stream,
        )

        if use_decoder:
            field = 'hiddens'
            decoder_name = 'decoder'
        else:
            field = 'ids'
            decoder_name = 'dvae'

        def vocos_decode(spec):
            # Vocos crashes on MPS, so decode on CPU there.
            return [
                self.pretrain_models['vocos'].decode(
                    i.cpu() if torch.backends.mps.is_available() else i
                ).cpu().numpy()
                for i in spec
            ]

        if stream:
            length = 0
            for result in result_gen:
                assert len(result[field]) == 1
                chunk_data = result[field][0]
                start_seek = length
                length = len(chunk_data)
                self.logger.debug(f'{start_seek=} total len: {length}, new len: {length - start_seek}')
                chunk_data = chunk_data[start_seek:]
                if not len(chunk_data):
                    continue
                self.logger.debug(f'new hidden {len(chunk_data)=}')
                mel_spec = [self.pretrain_models[decoder_name](chunk_data[None].permute(0, 2, 1))]
                wav = vocos_decode(mel_spec)
                self.logger.debug(f'yield wav chunk {len(wav[0])=} {len(wav[0][0])=}')
                yield wav
            return

        mel_spec = [
            self.pretrain_models[decoder_name](i[None].permute(0, 2, 1))
            for i in next(result_gen)[field]
        ]
        yield vocos_decode(mel_spec)

    def infer(
        self,
        text,
        skip_refine_text=False,
        refine_text_only=False,
        params_refine_text=None,
        params_infer_code=None,
        use_decoder=True,
        do_text_normalization=True,
        lang=None,
        stream=False,
        do_homophone_replacement=True,
    ):
        res_gen = self._infer(
            text,
            skip_refine_text=skip_refine_text,
            refine_text_only=refine_text_only,
            params_refine_text=params_refine_text,
            params_infer_code=params_infer_code,
            use_decoder=use_decoder,
            do_text_normalization=do_text_normalization,
            lang=lang,
            stream=stream,
            do_homophone_replacement=do_homophone_replacement,
        )
        if stream:
            return res_gen
        return next(res_gen)
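
    # Usage sketch for infer() (as comments; the input text is illustrative).
    # Blocking mode returns a list of numpy waveforms, one per input text;
    # streaming mode returns a generator of such chunks:
    #
    #   wavs = chat.infer(['Hello, world.'])             # blocking
    #   for chunk in chat.infer('Hello.', stream=True):  # streaming
    #       play(chunk[0])  # `play` is a placeholder for your audio sink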

    def sample_random_speaker(self):
        # Draw a speaker embedding from the stored per-dimension std/mean statistics.
        dim = self.pretrain_models['gpt'].gpt.layers[0].mlp.gate_proj.in_features
        std, mean = self.pretrain_models['spk_stat'].chunk(2)
        return torch.randn(dim, device=std.device) * std + mean
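
    # Hedged sketch: a sampled speaker vector is typically passed back through
    # params_infer_code; the exact key is defined by infer_code(), not this file
    # ('spk_emb' is common upstream, but treat that name as an assumption):
    #
    #   spk = chat.sample_random_speaker()
    #   wavs = chat.infer(text, params_infer_code={'prompt': '[speed_5]', 'spk_emb': spk})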

    def init_normalizer(self, lang) -> bool:
        if lang in self.normalizer:
            return True
        if lang == 'zh':
            try:
                from tn.chinese.normalizer import Normalizer
                self.normalizer[lang] = Normalizer().normalize
                return True
            except ImportError:
                self.logger.log(
                    logging.WARNING,
                    'Package WeTextProcessing not found!',
                )
                self.logger.log(
                    logging.WARNING,
                    'Run: conda install -c conda-forge pynini=2.1.5 && pip install WeTextProcessing',
                )
        else:
            try:
                from nemo_text_processing.text_normalization.normalize import Normalizer
                self.normalizer[lang] = partial(
                    Normalizer(input_case='cased', lang=lang).normalize,
                    verbose=False,
                    punct_post_process=True,
                )
                return True
            except ImportError:
                self.logger.log(
                    logging.WARNING,
                    'Package nemo_text_processing not found!',
                )
                self.logger.log(
                    logging.WARNING,
                    'Run: conda install -c conda-forge pynini=2.1.5 && pip install nemo_text_processing',
                )
        return False

    def init_homophones_replacer(self) -> bool:
        if self.homophones_replacer:
            return True
        try:
            self.homophones_replacer = HomophonesReplacer(
                os.path.join(os.path.dirname(__file__), 'res', 'homophones_map.json')
            )
            self.logger.log(logging.INFO, 'homophones_replacer loaded.')
            return True
        except (IOError, json.JSONDecodeError) as e:
            self.logger.log(logging.WARNING, f'Error loading homophones map: {e}')
        except Exception as e:
            self.logger.log(logging.WARNING, f'Error loading homophones_replacer: {e}')
        return False
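

# A minimal end-to-end sketch (kept as comments because this module uses relative
# imports and is meant to be imported as a package; the `ChatTTS` package name,
# the `soundfile` dependency, and the 24 kHz rate are assumptions, not set here):
#
#   import soundfile as sf
#   import ChatTTS
#
#   chat = ChatTTS.Chat()
#   chat.load_models(source='huggingface')
#   wavs = chat.infer(['Hello, this is a test.'])
#   sf.write('output.wav', wavs[0][0], 24000)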