"""Export a trained synthesizer checkpoint (VITS1 / VITS2 config) to ONNX.

Usage:
    python <this script> --model-path MODEL.pth --config-path CONFIG.json \
        --output MODEL.onnx
"""

import argparse
from pathlib import Path
from typing import Optional

import torch

import utils
from models import SynthesizerTrn
from text.symbols import symbols

# ONNX opset version passed to torch.onnx.export.
OPSET_VERSION = 15


def main() -> None:
    """Parse CLI arguments, load the checkpoint, and write the ONNX model.

    The generator's ``forward`` is replaced by a thin wrapper around
    ``infer`` so that the exported graph performs inference and takes the
    three scale values as a single tensor input.
    """
    # Fixed seed so the random dummy input below is reproducible between runs.
    torch.manual_seed(1234)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path",
        required=True,
        type=Path,
        help="Path to model weights (.pth)",
    )
    parser.add_argument(
        "--config-path",
        required=True,
        type=Path,
        help="Path to model config (.json)",
    )
    parser.add_argument(
        "--output",
        required=True,
        type=Path,
        help="Path to output model (.onnx)",
    )
    args = parser.parse_args()

    args.output.parent.mkdir(parents=True, exist_ok=True)

    hps = utils.get_hparams_from_file(args.config_path)

    # The posterior encoder input width depends on whether the config asks
    # for mel spectrograms (VITS2) or linear spectrograms (VITS1).
    use_mel_posterior_encoder = "use_mel_posterior_encoder" in hps.model.keys() and bool(
        hps.model.use_mel_posterior_encoder
    )
    if use_mel_posterior_encoder:
        print("Using mel posterior encoder for VITS2")
        posterior_channels = 80  # vits2
        hps.data.use_mel_posterior_encoder = True
    else:
        print("Using lin posterior encoder for VITS1")
        posterior_channels = hps.data.filter_length // 2 + 1
        hps.data.use_mel_posterior_encoder = False

    model_g = SynthesizerTrn(
        len(symbols),
        posterior_channels,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    )
    model_g.eval()

    # Third argument is None here; presumably the optimizer slot, which is
    # not needed for export -- confirm against utils.load_checkpoint.
    utils.load_checkpoint(args.model_path, model_g, None)

    def infer_forward(text, text_lengths, scales, sid=None):
        """Run ``model_g.infer`` with scales unpacked from one tensor.

        ``scales`` is indexed as [noise_scale, length_scale, noise_scale_w].
        Returns the first element of the ``infer`` result.
        """
        noise_scale = scales[0]
        length_scale = scales[1]
        noise_scale_w = scales[2]
        audio = model_g.infer(
            text,
            text_lengths,
            noise_scale=noise_scale,
            length_scale=length_scale,
            noise_scale_w=noise_scale_w,
            sid=sid,
        )[0]

        return audio

    # torch.onnx.export traces model.forward, so point it at inference.
    model_g.forward = infer_forward

    # Random symbol ids are enough for tracing; the sequence axis is marked
    # dynamic in the export call below.
    dummy_input_length = 50
    sequences = torch.randint(
        low=0, high=len(symbols), size=(1, dummy_input_length), dtype=torch.long
    )
    sequence_lengths = torch.LongTensor([sequences.size(1)])

    # A speaker id is only supplied when the config declares several speakers.
    sid: Optional[torch.LongTensor] = None
    if hps.data.n_speakers > 1:
        sid = torch.LongTensor([0])

    # noise, length, noise_w
    scales = torch.FloatTensor([0.667, 1.0, 0.8])
    dummy_input = (sequences, sequence_lengths, scales, sid)

    # Export
    torch.onnx.export(
        model=model_g,
        args=dummy_input,
        f=str(args.output),
        verbose=False,
        opset_version=OPSET_VERSION,
        input_names=["input", "input_lengths", "scales", "sid"],
        output_names=["output"],
        dynamic_axes={
            "input": {0: "batch_size", 1: "phonemes"},
            "input_lengths": {0: "batch_size"},
            "output": {0: "batch_size", 1: "time1", 2: "time2"},
        },
    )

    print(f"Exported model to {args.output}")


if __name__ == "__main__":
    main()