# Spaces:
# Runtime error
| import argparse | |
| from pathlib import Path | |
| from typing import Optional | |
| import torch | |
| import utils | |
| from models import SynthesizerTrn | |
| from text.symbols import symbols | |
# ONNX operator-set version requested from torch.onnx.export in main().
OPSET_VERSION: int = 15
def _parse_args() -> argparse.Namespace:
    """Parse the exporter's command line and check that the input files exist."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path", type=Path, required=True, help="Path to model weights (.pth)"
    )
    parser.add_argument(
        "--config-path", type=Path, required=True, help="Path to model config (.json)"
    )
    parser.add_argument(
        "--output", type=Path, required=True, help="Path to output model (.onnx)"
    )
    args = parser.parse_args()
    # Fail fast with a usage message rather than a traceback from deep inside
    # config/checkpoint loading.
    for flag, path in (
        ("--model-path", args.model_path),
        ("--config-path", args.config_path),
    ):
        if not path.is_file():
            parser.error(f"{flag}: file not found: {path}")
    return args


def _load_generator(config_path: Path, model_path: Path) -> tuple:
    """Build the SynthesizerTrn described by the config and load its weights.

    Args:
        config_path: model config file, read with ``utils.get_hparams_from_file``.
        model_path: checkpoint file, loaded with ``utils.load_checkpoint``.

    Returns:
        ``(model_g, hps)``: the generator (switched to eval mode, weights loaded)
        and the parsed hyper-parameters. ``hps.data.use_mel_posterior_encoder``
        is set to match the chosen posterior encoder.
    """
    hps = utils.get_hparams_from_file(config_path)
    # Configs that set use_mel_posterior_encoder get an 80-channel (mel)
    # posterior encoder; otherwise the linear-spectrogram size is used.
    if (
        "use_mel_posterior_encoder" in hps.model.keys()
        and hps.model.use_mel_posterior_encoder
    ):
        print("Using mel posterior encoder for VITS2")
        posterior_channels = 80  # vits2
        hps.data.use_mel_posterior_encoder = True
    else:
        print("Using lin posterior encoder for VITS1")
        posterior_channels = hps.data.filter_length // 2 + 1
        hps.data.use_mel_posterior_encoder = False

    model_g = SynthesizerTrn(
        len(symbols),
        posterior_channels,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    )
    model_g.eval()
    utils.load_checkpoint(model_path, model_g, None)
    return model_g, hps


def _use_inference_forward(model_g) -> None:
    """Make ``model_g(...)`` call ``model_g.infer`` so the exporter traces inference."""

    def infer_forward(text, text_lengths, scales, sid=None):
        # `scales` packs (noise_scale, length_scale, noise_scale_w) into one
        # tensor so they are exposed as a single "scales" graph input.
        # `sid` defaults to None so single-speaker exports can omit it.
        noise_scale = scales[0]
        length_scale = scales[1]
        noise_scale_w = scales[2]
        return model_g.infer(
            text,
            text_lengths,
            noise_scale=noise_scale,
            length_scale=length_scale,
            noise_scale_w=noise_scale_w,
            sid=sid,
        )[0]

    # Instance attribute shadows the class's forward for this object only.
    model_g.forward = infer_forward


def _build_export_inputs(
    n_symbols: int, n_speakers: int, dummy_input_length: int = 50
) -> tuple:
    """Return ``(dummy_input, input_names, dynamic_axes)`` for torch.onnx.export.

    The speaker-id input is only included when ``n_speakers > 1``: a ``None``
    argument is not traced as a graph input, and torch.onnx matches
    ``input_names`` positionally against the graph inputs, so always naming a
    "sid" input would mismatch for single-speaker models.
    """
    sequences = torch.randint(
        low=0, high=n_symbols, size=(1, dummy_input_length), dtype=torch.long
    )
    sequence_lengths = torch.LongTensor([sequences.size(1)])
    # noise, length, noise_w
    scales = torch.FloatTensor([0.667, 1.0, 0.8])

    dummy_input = [sequences, sequence_lengths, scales]
    input_names = ["input", "input_lengths", "scales"]
    dynamic_axes = {
        "input": {0: "batch_size", 1: "phonemes"},
        "input_lengths": {0: "batch_size"},
        # NOTE(review): presumably the output is (batch, 1, samples), so axis 1
        # may not really be dynamic — confirm against model_g.infer.
        "output": {0: "batch_size", 1: "time1", 2: "time2"},
    }
    if n_speakers > 1:
        dummy_input.append(torch.LongTensor([0]))
        input_names.append("sid")
        # One speaker id per batch item, like the other batched inputs.
        dynamic_axes["sid"] = {0: "batch_size"}
    return tuple(dummy_input), input_names, dynamic_axes


def main() -> None:
    """Export a SynthesizerTrn checkpoint's inference path to an ONNX file.

    Reads --model-path/--config-path, rebuilds the generator, replaces its
    forward with ``infer`` and traces it with dummy inputs into --output.
    """
    # Seed before anything consumes the RNG (model init, dummy token ids) so
    # the export is reproducible.
    torch.manual_seed(1234)
    args = _parse_args()
    args.output.parent.mkdir(parents=True, exist_ok=True)

    model_g, hps = _load_generator(args.config_path, args.model_path)
    _use_inference_forward(model_g)

    dummy_input, input_names, dynamic_axes = _build_export_inputs(
        len(symbols), hps.data.n_speakers
    )

    # Export
    torch.onnx.export(
        model=model_g,
        args=dummy_input,
        f=str(args.output),
        verbose=False,
        opset_version=OPSET_VERSION,
        input_names=input_names,
        output_names=["output"],
        dynamic_axes=dynamic_axes,
    )
    print(f"Exported model to {args.output}")
# Run the exporter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()