import argparse import shutil import subprocess from pathlib import Path import onnx import torch from transformers import AutoModel, AutoTokenizer class VitsExportWrapper(torch.nn.Module): def __init__(self, model: torch.nn.Module): super().__init__() self.model = model.eval() def forward( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, speaker_id: torch.Tensor, emotion_id: torch.Tensor, ) -> torch.Tensor: outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, speaker_id=speaker_id.to(torch.long), emotion_id=emotion_id.to(torch.long), return_dict=True, ) return outputs.waveform def inspect_onnx(onnx_path: Path) -> None: model = onnx.load(str(onnx_path)) print("onnx inputs:") for value in model.graph.input: tensor_type = value.type.tensor_type dims = [] for dim in tensor_type.shape.dim: if dim.dim_value: dims.append(dim.dim_value) elif dim.dim_param: dims.append(dim.dim_param) else: dims.append("?") print(f" {value.name}: shape={dims}, elem_type={tensor_type.elem_type}") def inspect_mnn(mnn_path: Path) -> None: import MNN.expr as expr graph_vars = expr.load_as_dict(str(mnn_path)) for name in ("input_ids", "attention_mask", "speaker_id", "emotion_id"): if name not in graph_vars: print(f"mnn input missing: {name}") continue var = graph_vars[name] print(f"mnn input {name}: shape={var.shape}, dtype={var.dtype}, format={var.data_format}") def export_onnx(args: argparse.Namespace) -> None: model_dir = Path(args.model_dir) tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True, local_files_only=True) model = AutoModel.from_pretrained(model_dir, trust_remote_code=True, local_files_only=True) wrapper = VitsExportWrapper(model) tokenized = tokenizer(text=args.text, return_tensors="pt") input_ids = tokenized["input_ids"].to(torch.long) attention_mask = tokenized.get("attention_mask", torch.ones_like(input_ids)).to(torch.long) speaker_id = torch.tensor([args.speaker_id], dtype=torch.long) emotion_id = torch.tensor([args.style_id], dtype=torch.long) torch.onnx.export( wrapper, (input_ids, attention_mask, speaker_id, emotion_id), str(args.onnx_output), input_names=["input_ids", "attention_mask", "speaker_id", "emotion_id"], output_names=["waveform"], dynamic_axes={ "input_ids": {1: "text_length"}, "attention_mask": {1: "text_length"}, "waveform": {1: "audio_length"}, }, opset_version=args.opset, do_constant_folding=True, dynamo=False, ) print(f"wrote {args.onnx_output}") inspect_onnx(Path(args.onnx_output)) def convert_to_mnn(args: argparse.Namespace) -> None: if not args.mnn_output: return converter = shutil.which("MNNConvert") if converter is None: raise FileNotFoundError("MNNConvert not found in PATH") command = [ converter, "-f", "ONNX", "--modelFile", str(args.onnx_output), "--MNNModel", str(args.mnn_output), "--bizCode", "MNN", ] subprocess.run(command, check=True) print(f"wrote {args.mnn_output}") inspect_mnn(Path(args.mnn_output)) def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Re-export local VITS weights with speaker/style control inputs.") parser.add_argument("--model-dir", default=".", help="Directory containing config, tokenizer, custom code, and weights.") parser.add_argument("--text", default="வணக்கம்", help="Sample text used to trace the export graph.") parser.add_argument("--speaker-id", type=int, default=18, help="Sample speaker ID used during tracing.") parser.add_argument("--style-id", type=int, default=0, help="Sample style or emotion ID used during tracing.") parser.add_argument("--onnx-output", default="vits_tamil_with_controls.onnx", help="Path for exported ONNX.") parser.add_argument("--mnn-output", default="vits_tamil_with_controls.mnn", help="Optional output path for converted MNN.") parser.add_argument("--opset", type=int, default=17, help="ONNX opset version.") return parser.parse_args() def main() -> None: args = parse_args() export_onnx(args) convert_to_mnn(args) if __name__ == "__main__": main()