Instructions to use developerabu/vits-tts-mnn with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use developerabu/vits-tts-mnn with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-to-speech", model="developerabu/vits-tts-mnn", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("developerabu/vits-tts-mnn", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import argparse | |
| import shutil | |
| import subprocess | |
| from pathlib import Path | |
| import onnx | |
| import torch | |
| from transformers import AutoModel, AutoTokenizer | |
class VitsExportWrapper(torch.nn.Module):
    """Give a model a purely positional, tensor-only ``forward`` for ONNX tracing.

    The wrapped model is put into eval mode on construction. Each call forwards
    the four inputs as keyword arguments, with ``return_dict=True``, and returns
    only the ``waveform`` attribute of the result. The wrapped model is presumably
    the VITS checkpoint loaded with ``trust_remote_code`` whose ``forward``
    accepts ``speaker_id`` / ``emotion_id``; confirm against the custom code.
    """

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        # Switch to inference behaviour before the module is registered.
        model.eval()
        self.model = model

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        speaker_id: torch.Tensor,
        emotion_id: torch.Tensor,
    ) -> torch.Tensor:
        """Run the wrapped model and return its waveform tensor.

        ``speaker_id`` and ``emotion_id`` are cast to ``torch.long`` so callers
        may pass any numeric dtype for the control inputs.
        """
        call_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "speaker_id": speaker_id.to(torch.long),
            "emotion_id": emotion_id.to(torch.long),
            "return_dict": True,
        }
        model_output = self.model(**call_kwargs)
        return model_output.waveform
def inspect_onnx(onnx_path: Path) -> None:
    """Print the name, shape and element type of every input of an ONNX graph.

    Each dimension is shown as its fixed size when one is set, otherwise as
    its symbolic name, otherwise as ``"?"``. ``elem_type`` is printed as the
    raw integer stored in the ONNX tensor type (an ``onnx.TensorProto`` data
    type code).
    """
    graph = onnx.load(str(onnx_path)).graph
    print("onnx inputs:")
    for graph_input in graph.input:
        tensor_type = graph_input.type.tensor_type
        # Falsy dim_value (unset) falls back to dim_param, then to "?".
        shape = [dim.dim_value or dim.dim_param or "?" for dim in tensor_type.shape.dim]
        print(f" {graph_input.name}: shape={shape}, elem_type={tensor_type.elem_type}")
def inspect_mnn(mnn_path: Path) -> None:
    """Report the four expected control/text inputs found in an MNN model.

    For each of ``input_ids``, ``attention_mask``, ``speaker_id`` and
    ``emotion_id``, print its shape, dtype and data format. Print a "missing"
    line instead when the name is absent from the loaded variables. ``MNN`` is
    imported inside the function, presumably so the rest of the script can run
    without the MNN Python package installed.
    """
    import MNN.expr as expr

    expected_inputs = ("input_ids", "attention_mask", "speaker_id", "emotion_id")
    loaded_vars = expr.load_as_dict(str(mnn_path))
    for input_name in expected_inputs:
        if input_name in loaded_vars:
            mnn_var = loaded_vars[input_name]
            print(f"mnn input {input_name}: shape={mnn_var.shape}, dtype={mnn_var.dtype}, format={mnn_var.data_format}")
        else:
            print(f"mnn input missing: {input_name}")
def export_onnx(args: argparse.Namespace) -> None:
    """Trace the local VITS checkpoint through ``VitsExportWrapper`` and write ONNX.

    Loads the tokenizer and model from ``args.model_dir`` (local files only,
    remote code trusted). It tokenizes ``args.text`` as the tracing sample and
    exports a graph with these signatures:

    - inputs ``input_ids``, ``attention_mask``, ``speaker_id``, ``emotion_id``;
    - output ``waveform``.

    Axis 1 of the two text inputs ("text_length") and of the waveform
    ("audio_length") is declared dynamic. The input signature of the written
    file is printed afterwards via ``inspect_onnx``.

    Args:
        args: Parsed CLI namespace. Reads ``model_dir``, ``text``,
            ``speaker_id``, ``style_id`` (fed as ``emotion_id``),
            ``onnx_output`` and ``opset``.

    Raises:
        ValueError: If tokenizing ``args.text`` yields no tokens, because an
            empty sample cannot be traced into a meaningful graph.
    """
    model_dir = Path(args.model_dir)
    onnx_output = Path(args.onnx_output)

    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True, local_files_only=True)
    model = AutoModel.from_pretrained(model_dir, trust_remote_code=True, local_files_only=True)
    wrapper = VitsExportWrapper(model)

    tokenized = tokenizer(text=args.text, return_tensors="pt")
    input_ids = tokenized["input_ids"].to(torch.long)
    if input_ids.numel() == 0:
        raise ValueError(f"sample text {args.text!r} produced no tokens; pass a non-empty --text")

    # Only synthesize an all-ones mask when the tokenizer did not return one.
    attention_mask = tokenized.get("attention_mask")
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    attention_mask = attention_mask.to(torch.long)

    speaker_id = torch.tensor([args.speaker_id], dtype=torch.long)
    emotion_id = torch.tensor([args.style_id], dtype=torch.long)

    # Make sure a nested output path such as "out/model.onnx" can be written.
    onnx_output.parent.mkdir(parents=True, exist_ok=True)

    # Tracing only needs forward values; disabling autograd avoids building a
    # gradient graph for every traced op.
    with torch.no_grad():
        torch.onnx.export(
            wrapper,
            (input_ids, attention_mask, speaker_id, emotion_id),
            str(onnx_output),
            input_names=["input_ids", "attention_mask", "speaker_id", "emotion_id"],
            output_names=["waveform"],
            dynamic_axes={
                "input_ids": {1: "text_length"},
                "attention_mask": {1: "text_length"},
                "waveform": {1: "audio_length"},
            },
            opset_version=args.opset,
            do_constant_folding=True,
            dynamo=False,
        )
    print(f"wrote {args.onnx_output}")
    inspect_onnx(onnx_output)
def convert_to_mnn(args: argparse.Namespace) -> None:
    """Convert the exported ONNX file to MNN with the ``MNNConvert`` CLI.

    Does nothing when ``args.mnn_output`` is falsy (``None`` or empty).
    Otherwise it locates ``MNNConvert`` on ``PATH`` and converts
    ``args.onnx_output`` into ``args.mnn_output``. It then prints the inputs of
    the resulting model via ``inspect_mnn``.

    Raises:
        FileNotFoundError: If ``MNNConvert`` cannot be found on ``PATH``.
        subprocess.CalledProcessError: If the converter exits non-zero.
    """
    mnn_output = args.mnn_output
    if not mnn_output:
        return None

    converter = shutil.which("MNNConvert")
    if converter is None:
        raise FileNotFoundError("MNNConvert not found in PATH")

    source_flags = ["-f", "ONNX", "--modelFile", str(args.onnx_output)]
    target_flags = ["--MNNModel", str(mnn_output), "--bizCode", "MNN"]
    subprocess.run([converter, *source_flags, *target_flags], check=True)

    print(f"wrote {mnn_output}")
    inspect_mnn(Path(mnn_output))
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options of the export script.

    Args:
        argv: Arguments to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before. An explicit list lets other
            code and tests drive the parser.

    Returns:
        Namespace with ``model_dir``, ``text``, ``speaker_id``, ``style_id``,
        ``onnx_output``, ``mnn_output`` and ``opset``. ``mnn_output`` is
        ``None`` when ``--no-mnn`` is given.
    """
    parser = argparse.ArgumentParser(description="Re-export local VITS weights with speaker/style control inputs.")
    parser.add_argument("--model-dir", default=".", help="Directory containing config, tokenizer, custom code, and weights.")
    parser.add_argument("--text", default="வணக்கம்", help="Sample text used to trace the export graph.")
    parser.add_argument("--speaker-id", type=int, default=18, help="Sample speaker ID used during tracing.")
    parser.add_argument("--style-id", type=int, default=0, help="Sample style or emotion ID used during tracing.")
    parser.add_argument("--onnx-output", default="vits_tamil_with_controls.onnx", help="Path for exported ONNX.")

    # --mnn-output has a non-empty default, so previously the only way to skip
    # the MNN step was `--mnn-output ""`. --no-mnn shares the same dest and
    # stores None; a falsy mnn_output makes the conversion step return early.
    mnn_group = parser.add_mutually_exclusive_group()
    mnn_group.add_argument("--mnn-output", default="vits_tamil_with_controls.mnn", help="Optional output path for converted MNN.")
    mnn_group.add_argument(
        "--no-mnn",
        dest="mnn_output",
        action="store_const",
        const=None,
        default=argparse.SUPPRESS,  # never overrides the --mnn-output default
        help="Skip the MNN conversion step.",
    )

    parser.add_argument("--opset", type=int, default=17, help="ONNX opset version.")
    return parser.parse_args(argv)
def main() -> None:
    """Parse CLI arguments, export ONNX, then run the MNN conversion step."""
    cli_args = parse_args()
    # Order matters: the MNN step consumes the ONNX file written by the first.
    for stage in (export_onnx, convert_to_mnn):
        stage(cli_args)


if __name__ == "__main__":
    main()