#!/usr/bin/env python3
"""Greedy TDT/RNNT inference with CoreML parakeet-tdt-1.1b.

Usage:
    python infer.py audio.wav [--compute-units ALL|CPU_ONLY|CPU_AND_NE]

The model packages, ``metadata.json`` and ``vocab.json`` are loaded from the
directory that contains this script.
"""
import json
import sys
from pathlib import Path

import coremltools as ct
import numpy as np
import soundfile as sf

REPO_DIR = Path(__file__).parent
SAMPLE_RATE = 16_000

_COMPUTE_UNITS_FLAG = "--compute-units"
_USAGE = (
    "Usage: python infer.py audio.wav "
    "[--compute-units ALL|CPU_ONLY|CPU_AND_NE]"
)


def load_audio(path, max_samples: int):
    """Read an audio file and fit it into a fixed-length float32 buffer.

    Args:
        path: Audio file to read (anything ``soundfile.read`` accepts).
        max_samples: Exact number of samples in the returned buffer.

    Returns:
        A ``(audio, actual)`` tuple. ``audio`` is a float32 array of shape
        ``(1, max_samples)``: the signal zero-padded at the end when shorter
        than ``max_samples`` and truncated when longer. ``actual`` is the
        number of real (non-padding) samples, never more than ``max_samples``.

    Raises:
        ValueError: If the file's sample rate is not ``SAMPLE_RATE``.
    """
    data, sr = sf.read(path, dtype="float32", always_2d=False)
    if sr != SAMPLE_RATE:
        raise ValueError(f"Expected {SAMPLE_RATE} Hz, got {sr} Hz.")
    if data.ndim > 1:
        # Multi-channel input: only the first channel is used.
        data = data[:, 0]

    actual = min(len(data), max_samples)
    # Start from silence and copy in the real samples; anything past
    # max_samples is dropped.
    padded = np.zeros(max_samples, dtype=np.float32)
    padded[:actual] = data[:actual]
    return padded.reshape(1, -1), actual


def _decoder_step(dec_model, token, target_length, h, c):
    """Run the decoder model on one token; return (decoder_out, h_out, c_out)."""
    out = dec_model.predict({
        "targets": np.array([[token]], dtype=np.int32),
        "target_length": target_length,
        "h_in": h,
        "c_in": c,
    })
    return out["decoder"], out["h_out"], out["c_out"]


def transcribe(audio_path, compute_units: str = "ALL") -> str:
    """Transcribe one audio file with greedy decoding.

    Args:
        audio_path: Path of the audio file; it must be sampled at
            ``SAMPLE_RATE`` (see ``load_audio``).
        compute_units: ``"ALL"``, ``"CPU_ONLY"`` or ``"CPU_AND_NE"``
            (case-insensitive). Any other value falls back to ``"ALL"``.

    Returns:
        The decoded text, with ``"▁"`` replaced by spaces and surrounding
        whitespace stripped.

    Raises:
        ValueError: If the audio is not sampled at ``SAMPLE_RATE``.
    """
    # JSON is UTF-8; be explicit because the vocab is expected to contain
    # non-ASCII pieces such as "▁", which is replaced below.
    meta = json.loads((REPO_DIR / "metadata.json").read_text(encoding="utf-8"))
    vocab = json.loads((REPO_DIR / "vocab.json").read_text(encoding="utf-8"))

    blank = meta["blank_id"]
    max_samples = meta["max_audio_samples"]
    bins = meta.get("duration_bins", [1])
    # The recurrent state dimensions come from the declared shape of the
    # decoder's "h_in" input: index 0 and index 2 of that shape.
    decoder_inputs = meta["components"]["decoder"]["inputs"]
    d_layers = decoder_inputs["h_in"][0]
    d_hidden = decoder_inputs["h_in"][2]

    cu_map = {
        "ALL": ct.ComputeUnit.ALL,
        "CPU_ONLY": ct.ComputeUnit.CPU_ONLY,
        "CPU_AND_NE": ct.ComputeUnit.CPU_AND_NE,
    }
    # Unknown names silently fall back to ALL.
    cu = cu_map.get(compute_units.upper(), ct.ComputeUnit.ALL)

    mel_enc = ct.models.MLModel(
        str(REPO_DIR / "parakeet_mel_encoder.mlpackage"), compute_units=cu
    )
    # The decoder is pinned to CPU_ONLY regardless of `compute_units`.
    # NOTE(review): the rationale is not recorded in this file.
    dec_model = ct.models.MLModel(
        str(REPO_DIR / "parakeet_decoder.mlpackage"),
        compute_units=ct.ComputeUnit.CPU_ONLY,
    )
    jd_model = ct.models.MLModel(
        str(REPO_DIR / "parakeet_joint_decision_single_step.mlpackage"),
        compute_units=cu,
    )

    audio, actual = load_audio(audio_path, max_samples)
    length = np.array([actual], dtype=np.int32)
    enc_out = mel_enc.predict({"audio_signal": audio, "audio_length": length})
    encoder = enc_out["encoder"]
    enc_len = int(enc_out["encoder_length"][0])

    # Prime the decoder with the blank token and an all-zero state.
    h = np.zeros((d_layers, 1, d_hidden), dtype=np.float32)
    c = np.zeros((d_layers, 1, d_hidden), dtype=np.float32)
    tlen = np.array([1], dtype=np.int32)
    dec_state, h, c = _decoder_step(dec_model, blank, tlen, h, c)

    tokens = []
    t = 0
    while t < enc_len:
        # One joint step: the encoder slice at index t (last axis) against
        # the first slice of the current decoder output.
        jd = jd_model.predict({
            "encoder_step": encoder[:, :, t:t + 1],
            "decoder_step": dec_state[:, :, :1],
        })
        tok = int(jd["token_id"].flat[0])
        dur = int(jd["duration"].flat[0])
        # `duration` is used as an index into the duration bins, clamped to
        # the last bin; with no bins configured, advance a single step.
        adv = bins[min(dur, len(bins) - 1)] if bins else 1

        if tok != blank:
            tokens.append(tok)
            # The decoder state only moves forward on a non-blank token.
            dec_state, h, c = _decoder_step(dec_model, tok, tlen, h, c)

        # Always advance at least one step so the loop terminates; as a
        # consequence at most one token is emitted per encoder index.
        t += max(1, adv)

    # Token ids outside the vocabulary are skipped rather than raising.
    # NOTE(review): assumes vocab.json is a list indexed by token id.
    pieces = (vocab[i] for i in tokens if i < len(vocab))
    return "".join(pieces).replace("▁", " ").strip()


def _parse_cli(argv):
    """Split CLI arguments into (audio_path, compute_units).

    The first argument that is not the compute-units flag or its value is
    taken as the audio path, so the flag may appear before or after it.
    Returns ``None`` when no audio path is given or the flag has no value.
    """
    compute_units = "ALL"
    positionals = []
    remaining = iter(argv)
    for arg in remaining:
        if arg == _COMPUTE_UNITS_FLAG:
            value = next(remaining, None)
            if value is None:
                return None
            compute_units = value
        else:
            positionals.append(arg)
    if not positionals:
        return None
    return positionals[0], compute_units


def main(argv=None) -> int:
    """Command-line entry point; returns the process exit status.

    Prints the transcription and returns 0 on success; prints the usage
    line and returns 1 when the arguments are unusable.
    """
    parsed = _parse_cli(sys.argv[1:] if argv is None else argv)
    if parsed is None:
        print(_USAGE)
        return 1
    audio_path, compute_units = parsed
    print(transcribe(audio_path, compute_units))
    return 0


if __name__ == "__main__":
    sys.exit(main())