vits-tts-mnn / run_mnn_inference.py
developerabu's picture
Upload 7 files
6d774ce verified
Raw
History Blame Contribute Delete
4.77 kB
import argparse
import json
import wave
from pathlib import Path
import numpy as np
import MNN.expr as expr
import MNN.nn as nn
def load_vocab(vocab_path: Path) -> dict[str, int]:
    """Read the tokenizer vocabulary from a UTF-8 encoded JSON file.

    Args:
        vocab_path: Location of the vocabulary JSON file.

    Returns:
        The decoded JSON document. The annotation states a mapping from
        token string to integer id; the content is not validated here.
    """
    raw_json = vocab_path.read_text(encoding="utf-8")
    return json.loads(raw_json)
def normalize_text(text: str, vocab: dict[str, int]) -> str:
    """Lowercase *text* and drop every character missing from *vocab*.

    Args:
        text: Raw input text.
        vocab: Vocabulary whose keys are the characters to keep.

    Returns:
        The lowercased, filtered text with leading and trailing whitespace
        stripped. May be the empty string.
    """
    # Lowercase one character at a time (not the whole string at once):
    # str.lower() on a full string applies context-sensitive rules that a
    # per-character conversion does not, and a single character may expand
    # to several characters when lowercased.
    lowered_chars: list[str] = []
    for original_char in text:
        lowered_chars.extend(original_char.lower())

    kept_chars = [candidate for candidate in lowered_chars if candidate in vocab]
    return "".join(kept_chars).strip()
def tokenize(text: str, vocab: dict[str, int], add_blank: bool = True) -> np.ndarray:
    """Convert *text* into a batch of one row of int32 token ids.

    Args:
        text: Raw input text; it is normalized with ``normalize_text`` first.
        vocab: Mapping from character to token id.
        add_blank: When true, the id 0 is placed before, between and after
            the character ids, giving ``2 * n + 1`` entries for ``n``
            characters.

    Returns:
        An ``int32`` array of shape ``(1, sequence_length)``.

    Raises:
        ValueError: If nothing is left of *text* after normalization.
    """
    cleaned = normalize_text(text, vocab)
    if not cleaned:
        raise ValueError("Text becomes empty after tokenizer normalization.")

    ids = [vocab[symbol] for symbol in cleaned]
    if add_blank:
        # Start with the leading 0, then follow every id with another 0.
        with_blanks = [0]
        for token_id in ids:
            with_blanks.append(token_id)
            with_blanks.append(0)
        ids = with_blanks

    batch = [ids]
    return np.asarray(batch, dtype=np.int32)
def write_wav(path: Path, waveform: np.ndarray, sample_rate: int) -> None:
    """Write *waveform* to *path* as a mono, 16-bit PCM WAV file.

    Args:
        path: Destination file; it is opened for writing in binary mode.
        waveform: Audio samples. Values outside ``[-1.0, 1.0]`` are clipped
            to that range before conversion.
        sample_rate: Frame rate stored in the WAV header.
    """
    # Scale the clipped samples to the int16 range; the integer cast
    # truncates the fractional part.
    scaled = np.clip(waveform, -1.0, 1.0) * 32767.0
    frame_bytes = scaled.astype(np.int16).tobytes()

    with wave.open(str(path), "wb") as writer:
        writer.setframerate(sample_rate)
        writer.setsampwidth(2)  # two bytes per sample
        writer.setnchannels(1)  # single channel
        writer.writeframes(frame_bytes)
def build_placeholder(value: np.ndarray, reference_var) -> object:
    """Create an MNN placeholder holding *value*.

    Args:
        value: Array to feed; its shape becomes the placeholder's shape.
        reference_var: Graph variable whose ``data_format`` and ``dtype``
            are copied onto the new placeholder.

    Returns:
        The placeholder after *value* has been written into it.
    """
    shape = list(value.shape)
    input_var = expr.placeholder(shape, reference_var.data_format, reference_var.dtype)
    input_var.write(value)
    return input_var
def maybe_add_control_input(
model_inputs: list[object],
model_input_names: list[str],
graph_vars: dict[str, object],
input_name: str,
input_value: int | None,
flag_name: str,
) -> None:
if input_name not in graph_vars:
if input_value is not None:
raise ValueError(
f"Model does not expose {input_name}. Re-export the model with {input_name} as a graph input."
)
return
if input_value is None:
raise ValueError(f"Model expects {input_name}. Pass {flag_name}.")
value = np.asarray([input_value], dtype=np.int32)
model_inputs.append(build_placeholder(value, graph_vars[input_name]))
model_input_names.append(input_name)
def synthesize(
    model_path: Path,
    vocab_path: Path,
    text: str,
    speaker_id: int | None = None,
    style_id: int | None = None,
) -> np.ndarray:
    """Run the MNN model on *text* and return the generated waveform.

    Args:
        model_path: Path to the ``.mnn`` model file.
        vocab_path: Path to the tokenizer vocabulary JSON.
        text: Text to synthesize.
        speaker_id: Value for the graph's ``speaker_id`` input, if any.
        style_id: Value for the graph's ``emotion_id`` input, if any.

    Returns:
        The first model output with its leading axis squeezed away. The
        squeeze requires that axis to have length 1.

    Raises:
        ValueError: If the text is empty after normalization, or a control
            id is supplied / missing in disagreement with the graph inputs.
        RuntimeError: If the forward pass yields no outputs.
    """
    # Tokenize before touching the model so bad text fails early.
    token_vocab = load_vocab(vocab_path)
    input_ids = tokenize(text, token_vocab)
    attention_mask = np.ones_like(input_ids, dtype=np.int32)

    # The graph is loaded as a dict only to inspect its input variables.
    graph_vars = expr.load_as_dict(str(model_path))

    required_feeds = {"input_ids": input_ids, "attention_mask": attention_mask}
    model_inputs = [
        build_placeholder(array, graph_vars[name]) for name, array in required_feeds.items()
    ]
    model_input_names = list(required_feeds)

    optional_controls = (
        ("speaker_id", speaker_id, "--speaker-id"),
        ("emotion_id", style_id, "--style-id"),
    )
    for control_name, control_value, control_flag in optional_controls:
        maybe_add_control_input(
            model_inputs, model_input_names, graph_vars, control_name, control_value, control_flag
        )

    output_names = ["waveform"]
    module = nn.load_module_from_file(str(model_path), model_input_names, output_names, dynamic=True)

    outputs = module.forward(model_inputs)
    if not outputs:
        raise RuntimeError("PyMNN returned no outputs.")

    first_output = outputs[0].read()
    return np.asarray(first_output).squeeze(0)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Run VITS MNN inference with PyMNN.")
parser.add_argument("--text", required=True, help="Input text to synthesize.")
parser.add_argument("--model", default="vits_tamil_static.mnn", help="Path to the MNN model.")
parser.add_argument("--vocab", default="vocab.json", help="Path to the tokenizer vocabulary JSON.")
parser.add_argument("--output", default="audio_mnn.wav", help="Output WAV file path.")
parser.add_argument("--sample-rate", type=int, default=24000, help="Output WAV sample rate.")
parser.add_argument("--speaker-id", type=int, default=None, help="Speaker ID for models that expose speaker_id.")
parser.add_argument(
"--style-id",
"--emotion-id",
dest="style_id",
type=int,
default=None,
help="Style or emotion ID for models that expose emotion_id.",
)
return parser.parse_args()
def main() -> None:
    """Parse the CLI options, synthesize the audio and write it to disk.

    Prints a one-line summary with the output path, the size of the
    waveform's last axis and the sample rate.
    """
    options = parse_args()

    audio = synthesize(
        model_path=Path(options.model),
        vocab_path=Path(options.vocab),
        text=options.text,
        speaker_id=options.speaker_id,
        style_id=options.style_id,
    )

    destination = Path(options.output)
    write_wav(destination, audio, options.sample_rate)

    sample_count = audio.shape[-1]
    print(f"wrote {options.output} with {sample_count} samples at {options.sample_rate} Hz")
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()