Spaces:
Runtime error
Runtime error
import argparse
import os
from typing import Optional, Sequence

from transformers import AutoConfig

from parler_tts import ParlerTTSDecoderConfig, ParlerTTSForCausalLM, ParlerTTSForConditionalGeneration
# Default generation length in seconds of audio; converted to decoder steps
# with the audio encoder's frame rate when the generation config is written.
MAX_GENERATION_SECONDS = 30


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line arguments of this script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Namespace with ``save_directory``, ``text_model`` and ``audio_model``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("save_directory", type=str, help="Directory where to save the model and the decoder.")
    # Both sub-models are mandatory: without them ``None`` would be passed to
    # AutoConfig.from_pretrained, so let argparse reject the call up front.
    parser.add_argument(
        "--text_model", type=str, required=True, help="Repository id or path to the text encoder."
    )
    parser.add_argument(
        "--audio_model", type=str, required=True, help="Repository id or path to the audio encoder."
    )
    return parser.parse_args(argv)


def decoder_config_kwargs(codebook_size: int, num_codebooks: int) -> dict:
    """Return the keyword arguments used to build the ``ParlerTTSDecoderConfig``.

    The special tokens are placed right after the ``codebook_size`` audio
    codebook entries: ``codebook_size`` serves as both pad and eos, and
    ``codebook_size + 1`` as bos.

    Args:
        codebook_size: ``codebook_size`` read from the audio encoder's config.
        num_codebooks: ``num_codebooks`` read from the audio encoder's config,
            forwarded unchanged to the decoder config.

    Returns:
        A dict of decoder hyperparameters and special-token ids.
    """
    pad_token_id = codebook_size
    bos_token_id = codebook_size + 1
    return {
        # Only two extra ids are needed (pad/eos and bos); +64 keeps the
        # vocabulary size a multiple of 64 whenever codebook_size is one.
        "vocab_size": codebook_size + 64,
        "max_position_embeddings": 4096,  # 30 s of audio is ~2580 positions
        "num_hidden_layers": 24,
        "ffn_dim": 4096,
        "num_attention_heads": 16,
        "layerdrop": 0.0,
        "use_cache": True,
        "activation_function": "gelu",
        "hidden_size": 1024,
        "dropout": 0.1,
        "attention_dropout": 0.0,
        "activation_dropout": 0.0,
        "pad_token_id": pad_token_id,
        "eos_token_id": pad_token_id,
        "bos_token_id": bos_token_id,
        "num_codebooks": num_codebooks,
    }


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Assemble an untrained Parler-TTS model from pretrained sub-models and save it.

    Writes a freshly initialised decoder to ``<save_directory>/decoder`` and
    the combined model to ``<save_directory>/parler-tts-untrained-600M/``.

    Args:
        argv: Command-line arguments; ``None`` means ``sys.argv[1:]``.
    """
    args = parse_args(argv)

    text_config = AutoConfig.from_pretrained(args.text_model)
    audio_config = AutoConfig.from_pretrained(args.audio_model)

    codebook_size = audio_config.codebook_size
    num_codebooks = audio_config.num_codebooks
    print("num_codebooks", num_codebooks)

    decoder_kwargs = decoder_config_kwargs(codebook_size, num_codebooks)
    pad_token_id = decoder_kwargs["pad_token_id"]
    bos_token_id = decoder_kwargs["bos_token_id"]

    # The decoder is written to disk first because from_sub_models_pretrained
    # below loads it back from this path.
    decoder_dir = os.path.join(args.save_directory, "decoder")
    decoder = ParlerTTSForCausalLM(ParlerTTSDecoderConfig(**decoder_kwargs))
    decoder.save_pretrained(decoder_dir)

    model = ParlerTTSForConditionalGeneration.from_sub_models_pretrained(
        text_encoder_pretrained_model_name_or_path=args.text_model,
        audio_encoder_pretrained_model_name_or_path=args.audio_model,
        decoder_pretrained_model_name_or_path=decoder_dir,
        vocab_size=text_config.vocab_size,
    )

    # Mirror the decoder's special-token ids in the generation config.
    generation_config = model.generation_config
    generation_config.decoder_start_token_id = bos_token_id
    generation_config.pad_token_id = pad_token_id
    generation_config.eos_token_id = pad_token_id
    # Other generation defaults: MAX_GENERATION_SECONDS of audio, sampled decoding.
    generation_config.max_length = int(MAX_GENERATION_SECONDS * model.audio_encoder.config.frame_rate)
    generation_config.do_sample = True

    model.config.pad_token_id = pad_token_id
    model.config.decoder_start_token_id = bos_token_id

    model.save_pretrained(os.path.join(args.save_directory, "parler-tts-untrained-600M/"))


if __name__ == "__main__":
    main()