# vits-tts-mnn / reexport_vits_with_controls.py
# Uploaded by developerabu in commit 6d774ce ("Upload 7 files"); 4.69 kB.
import argparse
import shutil
import subprocess
from pathlib import Path
import onnx
import torch
from transformers import AutoModel, AutoTokenizer
class VitsExportWrapper(torch.nn.Module):
    """Adapt a VITS model to a positional, tensor-only signature for ONNX export.

    The exporter is handed a tuple of example tensors, so this wrapper maps
    those positional inputs onto the model's keyword arguments and returns
    only the ``waveform`` field of the model's output object.
    """

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        # Put the model in eval mode before it is traced.
        model.eval()
        self.model = model

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        speaker_id: torch.Tensor,
        emotion_id: torch.Tensor,
    ) -> torch.Tensor:
        """Call the wrapped model and return its ``waveform`` output.

        ``speaker_id`` and ``emotion_id`` are cast to int64 before being
        forwarded; the model is always invoked with ``return_dict=True``.
        """
        control_kwargs = {
            "speaker_id": speaker_id.long(),
            "emotion_id": emotion_id.long(),
        }
        result = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            **control_kwargs,
        )
        return result.waveform
def _describe_onnx_value(value: "onnx.ValueInfoProto") -> str:
    """Format one ONNX graph input/output as `` name: shape=[...], elem_type=N (NAME)``."""
    tensor_type = value.type.tensor_type
    dims = []
    for dim in tensor_type.shape.dim:
        # dim_value and dim_param are members of a protobuf oneof named "value".
        # Asking which member is set keeps a genuine size-0 axis as 0; the old
        # truthiness test on dim_value reported such an axis as "?".
        which = dim.WhichOneof("value")
        if which == "dim_value":
            dims.append(dim.dim_value)
        elif which == "dim_param":
            dims.append(dim.dim_param)
        else:
            dims.append("?")
    type_name = onnx.TensorProto.DataType.Name(tensor_type.elem_type)
    return f" {value.name}: shape={dims}, elem_type={tensor_type.elem_type} ({type_name})"


def inspect_onnx(onnx_path: Path) -> None:
    """Print the name, shape and element type of every ONNX graph input and output.

    Fixed axes print as integers, symbolic (dynamic) axes print as their
    parameter name (e.g. ``text_length``), and axes with neither set print
    as ``"?"``. Outputs are listed alongside inputs so that dynamic output
    axes (such as the waveform's audio axis) can be verified as well.

    Args:
        onnx_path: Path to the ``.onnx`` file to load.
    """
    model = onnx.load(str(onnx_path))
    sections = (
        ("onnx inputs:", model.graph.input),
        ("onnx outputs:", model.graph.output),
    )
    for heading, values in sections:
        print(heading)
        for value in values:
            print(_describe_onnx_value(value))
def inspect_mnn(mnn_path: Path) -> None:
    """Report shape, dtype and data format of the four expected MNN graph inputs.

    MNN is imported inside the function, so importing this script does not
    require the MNN Python package. Each expected input name that is absent
    from the loaded graph is reported as missing rather than raising.

    Args:
        mnn_path: Path to the converted ``.mnn`` model.
    """
    import MNN.expr as mnn_expr

    variables = mnn_expr.load_as_dict(str(mnn_path))
    expected_inputs = ("input_ids", "attention_mask", "speaker_id", "emotion_id")
    for input_name in expected_inputs:
        if input_name in variables:
            var = variables[input_name]
            print(f"mnn input {input_name}: shape={var.shape}, dtype={var.dtype}, format={var.data_format}")
        else:
            print(f"mnn input missing: {input_name}")
def _build_example_inputs(
    tokenizer, text: str, speaker_id: int, style_id: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Tokenize the sample text and build the int64 example tensors used for tracing.

    Returns:
        ``(input_ids, attention_mask, speaker_id, emotion_id)``; the two
        control tensors have shape ``(1,)``.

    Raises:
        ValueError: ``text`` tokenized to an empty sequence, which would give
            the tracer a zero-length text axis.
    """
    tokenized = tokenizer(text=text, return_tensors="pt")
    input_ids = tokenized["input_ids"].to(torch.long)
    if input_ids.numel() == 0:
        raise ValueError(f"sample text {text!r} produced no tokens; choose text the tokenizer can encode")
    # Only build an all-ones mask when the tokenizer did not return one.
    attention_mask = tokenized.get("attention_mask")
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    return (
        input_ids,
        attention_mask.to(torch.long),
        torch.tensor([speaker_id], dtype=torch.long),
        torch.tensor([style_id], dtype=torch.long),
    )


def export_onnx(args: argparse.Namespace) -> None:
    """Trace the local VITS checkpoint with speaker/emotion inputs and write it as ONNX.

    Loads tokenizer and model from ``args.model_dir``. Loading is offline only
    and uses ``trust_remote_code=True``, so custom code in that directory may
    be executed. ``args.text`` is tokenized as the tracing example, and
    ``VitsExportWrapper`` is exported with the inputs ``input_ids``,
    ``attention_mask``, ``speaker_id`` and ``emotion_id`` and the output
    ``waveform``.

    Axis 1 of ``input_ids`` and ``attention_mask`` (``text_length``) and of
    ``waveform`` (``audio_length``) is exported as dynamic. Every other axis
    keeps its traced size: the batch axis stays 1, and the control inputs
    stay shape ``(1,)``.

    Args:
        args: Parsed CLI namespace. Reads ``model_dir``, ``text``,
            ``speaker_id``, ``style_id``, ``onnx_output`` and ``opset``.

    Raises:
        ValueError: ``args.text`` tokenized to an empty sequence.
    """
    model_dir = Path(args.model_dir)
    onnx_output = Path(args.onnx_output)
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True, local_files_only=True)
    model = AutoModel.from_pretrained(model_dir, trust_remote_code=True, local_files_only=True)
    wrapper = VitsExportWrapper(model)
    example_inputs = _build_example_inputs(tokenizer, args.text, args.speaker_id, args.style_id)

    onnx_output.parent.mkdir(parents=True, exist_ok=True)
    # Tracing an inference graph needs no gradients; disabling autograd keeps
    # the trace from recording backward state.
    with torch.no_grad():
        torch.onnx.export(
            wrapper,
            example_inputs,
            str(onnx_output),
            input_names=["input_ids", "attention_mask", "speaker_id", "emotion_id"],
            output_names=["waveform"],
            dynamic_axes={
                "input_ids": {1: "text_length"},
                "attention_mask": {1: "text_length"},
                "waveform": {1: "audio_length"},
            },
            opset_version=args.opset,
            do_constant_folding=True,
            dynamo=False,
        )
    print(f"wrote {args.onnx_output}")
    inspect_onnx(onnx_output)
def convert_to_mnn(args: argparse.Namespace) -> None:
    """Convert the exported ONNX file to MNN with the MNN converter CLI, then inspect it.

    Does nothing when ``args.mnn_output`` is empty or ``None``. The converter
    is looked up as ``MNNConvert`` first and then as ``mnnconvert``, the
    lowercase console-script name used by the MNN Python package.

    Args:
        args: Parsed CLI namespace. Reads ``onnx_output`` and ``mnn_output``.

    Raises:
        FileNotFoundError: Neither converter executable is on ``PATH``.
        subprocess.CalledProcessError: The converter exited with a non-zero status.
    """
    if not args.mnn_output:
        return
    converter = shutil.which("MNNConvert") or shutil.which("mnnconvert")
    if converter is None:
        raise FileNotFoundError("MNNConvert (or mnnconvert) not found in PATH")
    mnn_output = Path(args.mnn_output)
    # Create the destination directory only after a converter has been found,
    # so a failed lookup leaves the filesystem untouched.
    mnn_output.parent.mkdir(parents=True, exist_ok=True)
    command = [
        converter,
        "-f",
        "ONNX",
        "--modelFile",
        str(args.onnx_output),
        "--MNNModel",
        str(mnn_output),
        "--bizCode",
        "MNN",
    ]
    # List form with the default shell=False: paths are passed verbatim.
    subprocess.run(command, check=True)
    print(f"wrote {args.mnn_output}")
    inspect_mnn(mnn_output)
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the export script's command-line options.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``. ``None`` (the
            default) keeps normal command-line behaviour; passing a list lets
            tests or other scripts drive the parser directly.

    Returns:
        Namespace with ``model_dir``, ``text``, ``speaker_id``, ``style_id``,
        ``onnx_output``, ``mnn_output`` and ``opset``.
    """
    parser = argparse.ArgumentParser(description="Re-export local VITS weights with speaker/style control inputs.")
    parser.add_argument("--model-dir", default=".", help="Directory containing config, tokenizer, custom code, and weights.")
    parser.add_argument("--text", default="வணக்கம்", help="Sample text used to trace the export graph.")
    parser.add_argument("--speaker-id", type=int, default=18, help="Sample speaker ID used during tracing.")
    parser.add_argument("--style-id", type=int, default=0, help="Sample style or emotion ID used during tracing.")
    parser.add_argument("--onnx-output", default="vits_tamil_with_controls.onnx", help="Path for exported ONNX.")
    # The default is non-empty, so MNN conversion runs unless an empty value
    # is passed; the help text states this instead of calling the option "optional".
    parser.add_argument(
        "--mnn-output",
        default="vits_tamil_with_controls.mnn",
        help="Output path for the converted MNN model (conversion runs by default; pass an empty string to skip it).",
    )
    parser.add_argument("--opset", type=int, default=17, help="ONNX opset version.")
    return parser.parse_args(argv)
def main() -> None:
    """Export the ONNX model, then run the MNN conversion step, using CLI options."""
    options = parse_args()
    export_onnx(options)
    convert_to_mnn(options)


if __name__ == "__main__":
    # Run only when executed as a script, not when imported.
    main()