"""Run VITS text-to-speech inference on an MNN model through PyMNN.

The script tokenizes input text character by character with a JSON
vocabulary, feeds the resulting ids (plus optional speaker / emotion control
inputs) to the MNN graph, and writes the produced waveform as a mono 16-bit
PCM WAV file.
"""

import argparse
import json
import wave
from pathlib import Path

import numpy as np
import MNN.expr as expr
import MNN.nn as nn


def load_vocab(vocab_path: Path) -> dict[str, int]:
    """Load the tokenizer vocabulary from a UTF-8 encoded JSON file.

    The parsed JSON is returned as-is; it is expected (but not validated) to
    map each symbol to its integer token id.
    """
    with vocab_path.open(encoding="utf-8") as handle:
        return json.load(handle)


def normalize_text(text: str, vocab: dict[str, int]) -> str:
    """Lowercase *text* and drop every character that is not a vocab key.

    Lowercasing is done per character, and the filtered result has leading
    and trailing whitespace stripped.
    """
    lowered = "".join(char.lower() for char in text)
    return "".join(char for char in lowered if char in vocab).strip()


def tokenize(
    text: str,
    vocab: dict[str, int],
    add_blank: bool = True,
    blank_id: int = 0,
) -> np.ndarray:
    """Convert *text* into an int32 array of token ids with shape (1, N).

    Args:
        text: Raw input text; it is normalized with ``normalize_text`` first.
        vocab: Mapping from single characters to token ids.
        add_blank: When true, ``blank_id`` is placed before, between and
            after the token ids, giving ``2 * len(ids) + 1`` entries.
        blank_id: Token id used for the interspersed blanks (default 0).

    Raises:
        ValueError: If no character of *text* survives normalization.
    """
    filtered = normalize_text(text, vocab)
    if not filtered:
        raise ValueError("Text becomes empty after tokenizer normalization.")

    token_ids = [vocab[char] for char in filtered]
    if add_blank:
        # Blanks occupy the even slots, real tokens the odd slots.
        interspersed = [blank_id] * (len(token_ids) * 2 + 1)
        interspersed[1::2] = token_ids
        token_ids = interspersed

    # The outer list adds the leading batch axis of size 1.
    return np.asarray([token_ids], dtype=np.int32)


def write_wav(path: Path, waveform: np.ndarray, sample_rate: int) -> None:
    """Write *waveform* to *path* as a mono, 16-bit PCM WAV file.

    Samples are clipped to [-1.0, 1.0] and scaled by 32767 before being
    truncated to int16. NaN samples are written as silence.
    """
    # Casting NaN to int16 is undefined and produces garbage samples, so
    # non-finite values are mapped to in-range numbers before clipping.
    finite = np.nan_to_num(waveform, nan=0.0, posinf=1.0, neginf=-1.0)
    clipped = np.clip(finite, -1.0, 1.0)
    pcm = (clipped * 32767.0).astype(np.int16)

    with wave.open(str(path), "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)  # 2 bytes per sample == 16-bit PCM
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(pcm.tobytes())


def build_placeholder(value: np.ndarray, reference_var) -> object:
    """Create an MNN placeholder holding *value*.

    The placeholder takes its shape from *value* and its data format and
    dtype from *reference_var* (the matching variable of the loaded graph).
    """
    placeholder = expr.placeholder(
        list(value.shape),
        reference_var.data_format,
        reference_var.dtype,
    )
    placeholder.write(value)
    return placeholder


def maybe_add_control_input(
    model_inputs: list[object],
    model_input_names: list[str],
    graph_vars: dict[str, object],
    input_name: str,
    input_value: int | None,
    flag_name: str,
) -> None:
    """Append an optional scalar control input when the graph exposes it.

    *model_inputs* and *model_input_names* are extended in place, keeping the
    two lists aligned.

    Args:
        model_inputs: Placeholders passed to the module, mutated in place.
        model_input_names: Graph input names, mutated in place.
        graph_vars: Variables of the loaded graph, keyed by name.
        input_name: Name of the control input to look up in *graph_vars*.
        input_value: Value supplied by the caller, or ``None`` if omitted.
        flag_name: Command-line flag named in the error message.

    Raises:
        ValueError: If a value was supplied but the graph has no such input,
            or the graph has the input but no value was supplied.
    """
    if input_name not in graph_vars:
        if input_value is not None:
            raise ValueError(
                f"Model does not expose {input_name}. Re-export the model with {input_name} as a graph input."
            )
        return

    if input_value is None:
        raise ValueError(f"Model expects {input_name}. Pass {flag_name}.")

    value = np.asarray([input_value], dtype=np.int32)
    model_inputs.append(build_placeholder(value, graph_vars[input_name]))
    model_input_names.append(input_name)


def synthesize(
    model_path: Path,
    vocab_path: Path,
    text: str,
    speaker_id: int | None = None,
    style_id: int | None = None,
) -> np.ndarray:
    """Synthesize a waveform for *text* with the MNN model at *model_path*.

    Args:
        model_path: Path to the ``.mnn`` model file.
        vocab_path: Path to the tokenizer vocabulary JSON.
        text: Text to synthesize.
        speaker_id: Value for the graph's ``speaker_id`` input, if it has one.
        style_id: Value for the graph's ``emotion_id`` input, if it has one.

    Returns:
        The model's ``waveform`` output with its leading axis removed
        (presumably shape ``(1, samples)`` before squeezing - confirm
        against the exported model).

    Raises:
        ValueError: If the text normalizes to nothing, or a control input is
            missing or unexpected (see ``maybe_add_control_input``).
        RuntimeError: If the module returns no outputs.
    """
    vocab = load_vocab(vocab_path)
    input_ids = tokenize(text, vocab)
    attention_mask = np.ones_like(input_ids, dtype=np.int32)

    # The graph variables are loaded only to discover which inputs exist and
    # which data format / dtype each placeholder must use.
    graph_vars = expr.load_as_dict(str(model_path))
    model_inputs = [
        build_placeholder(input_ids, graph_vars["input_ids"]),
        build_placeholder(attention_mask, graph_vars["attention_mask"]),
    ]
    model_input_names = ["input_ids", "attention_mask"]

    maybe_add_control_input(
        model_inputs, model_input_names, graph_vars, "speaker_id", speaker_id, "--speaker-id"
    )
    # The CLI calls it "style", the graph input is named "emotion_id".
    maybe_add_control_input(
        model_inputs, model_input_names, graph_vars, "emotion_id", style_id, "--style-id"
    )

    module = nn.load_module_from_file(
        str(model_path),
        model_input_names,
        ["waveform"],
        dynamic=True,
    )
    outputs = module.forward(model_inputs)
    if not outputs:
        raise RuntimeError("PyMNN returned no outputs.")

    waveform = outputs[0].read()
    return np.asarray(waveform).squeeze(0)


def parse_args() -> argparse.Namespace:
    """Parse the command-line arguments of the inference script."""
    parser = argparse.ArgumentParser(description="Run VITS MNN inference with PyMNN.")
    parser.add_argument("--text", required=True, help="Input text to synthesize.")
    parser.add_argument("--model", default="vits_tamil_static.mnn", help="Path to the MNN model.")
    parser.add_argument("--vocab", default="vocab.json", help="Path to the tokenizer vocabulary JSON.")
    parser.add_argument("--output", default="audio_mnn.wav", help="Output WAV file path.")
    parser.add_argument("--sample-rate", type=int, default=24000, help="Output WAV sample rate.")
    parser.add_argument(
        "--speaker-id",
        type=int,
        default=None,
        help="Speaker ID for models that expose speaker_id.",
    )
    parser.add_argument(
        "--style-id",
        "--emotion-id",
        dest="style_id",
        type=int,
        default=None,
        help="Style or emotion ID for models that expose emotion_id.",
    )
    return parser.parse_args()


def main() -> None:
    """Synthesize the requested text and write the result as a WAV file."""
    args = parse_args()
    waveform = synthesize(
        Path(args.model),
        Path(args.vocab),
        args.text,
        speaker_id=args.speaker_id,
        style_id=args.style_id,
    )
    write_wav(Path(args.output), waveform, args.sample_rate)
    print(f"wrote {args.output} with {waveform.shape[-1]} samples at {args.sample_rate} Hz")


if __name__ == "__main__":
    main()