"""
Bag inference for the full HT-Demucs FT 4-stem ONNX ensemble.

Runs all 4 specialist sub-models and aggregates their outputs using the
htdemucs_ft bag's one-hot weight matrix (drums-model -> drums stem only,
bass-model -> bass stem only, etc).

NO TORCH at inference. Just numpy + onnxruntime + soundfile.

The input must be sampled at 44.1 kHz (other rates are rejected). Mono input
is duplicated to stereo; channels beyond the first two are dropped.

Usage:
    python bag_infer.py your-song.mp3 ./out/
    # writes out/drums.wav, out/bass.wav, out/other.wav, out/vocals.wav
    # optional: --providers {cpu,coreml,cuda,dml}

Or as a library:
    import bag_infer
    stems = bag_infer.separate_all("song.mp3")
    # stems: dict[str, numpy.ndarray (2, samples)]
"""
from __future__ import annotations

import argparse
import sys  # NOTE(review): looks unused in the visible code
import time
from pathlib import Path

import numpy as np
import onnxruntime as ort
import soundfile as sf

# separate() feeds every sub-model fixed windows of N_SAMPLES samples
# (SEGMENT_S seconds at SAMPLE_RATE), zero-padding the final window, and
# raises for any input rate other than SAMPLE_RATE.
SAMPLE_RATE = 44100
SEGMENT_S = 7.8
N_SAMPLES = int(SEGMENT_S * SAMPLE_RATE)  # 343,980
N_CHANNELS = 2
# Stem names, in order. separate() assumes this is also the row order of
# each sub-model's stacked (4, channels, N) output.
SOURCES = ["drums", "bass", "other", "vocals"]
# Directory containing this script; the default .onnx files are looked up here.
HERE = Path(__file__).resolve().parent

# The bag's weight matrix for htdemucs_ft is one-hot per stem:
#   drums specialist  (bag.models[0]) -> contributes only to drums stem
#   bass specialist   (bag.models[1]) -> contributes only to bass stem
#   other specialist  (bag.models[2]) -> contributes only to other stem
#   vocals specialist (bag.models[3]) -> contributes only to vocals stem
# So aggregation is trivial: pick row N from model N's output.
# Default location of each stem's specialist sub-model (key = stem it serves).
DEFAULT_ONNX_FILES = {
    "drums": HERE / "htdemucs_ft_drums.onnx",
    "bass": HERE / "htdemucs_ft_bass.onnx",
    "other": HERE / "htdemucs_ft_other.onnx",
    "vocals": HERE / "htdemucs_ft_vocals.onnx",
}

# Fraction of each chunk that overlaps its neighbour; also the window ramp length.
_OVERLAP_FRAC = 0.25


def _make_transition_window(segment: int, overlap_frac: float = 0.25) -> np.ndarray:
    """Build a trapezoidal overlap-add window of length *segment*.

    The first and last ``int(segment * overlap_frac)`` samples ramp linearly
    up / down; the middle is flat at 1. The ramps are strictly positive
    (they never reach 0). A ramp that starts at exactly 0 would zero out the
    very first output sample, and the very last one when the final chunk is
    full length, because only a single chunk covers those positions. The up
    and down ramps are mirror images that sum to exactly 1 sample-by-sample,
    so neighbouring chunks that overlap by the ramp length cross-fade with
    constant gain.

    Args:
        segment: window length in samples.
        overlap_frac: ramp length as a fraction of *segment*; must lie in
            [0, 0.5] so the two ramps do not overlap each other.

    Returns:
        float32 array of shape (segment,).

    Raises:
        ValueError: if *overlap_frac* is outside [0, 0.5].
    """
    if not 0.0 <= overlap_frac <= 0.5:
        raise ValueError(f"overlap_frac must be in [0, 0.5], got {overlap_frac}")
    window = np.ones(segment, dtype=np.float32)
    transition = int(segment * overlap_frac)
    if transition == 0:
        # Nothing to ramp. Careful: window[-0:] would address the WHOLE array.
        return window
    # (k + 1) / (T + 1) for k = 0..T-1: rises towards 1 without touching 0.
    fade = np.arange(1, transition + 1, dtype=np.float32) / np.float32(transition + 1)
    window[:transition] = fade
    window[-transition:] = fade[::-1]
    return window


def _load_sessions(onnx_files: dict[str, Path | str],
                   providers: list[str] | None = None,
                   ) -> dict[str, ort.InferenceSession]:
    """Open one onnxruntime session per stem.

    Args:
        onnx_files: {stem_name: path to that stem's .onnx file}; paths may be
            ``pathlib.Path`` or ``str``.
        providers: onnxruntime execution providers in priority order;
            defaults to CPU only.

    Returns:
        {stem_name: InferenceSession}, in the iteration order of *onnx_files*.

    Raises:
        FileNotFoundError: if a model file does not exist.
    """
    if providers is None:
        providers = ["CPUExecutionProvider"]
    sessions: dict[str, ort.InferenceSession] = {}
    for stem, path in onnx_files.items():
        path = Path(path)  # tolerate plain-string paths in user overrides
        if not path.exists():
            raise FileNotFoundError(
                f"Missing {stem} model at {path}. Download all 4 .onnx files "
                "into the same directory as this script.")
        sessions[stem] = ort.InferenceSession(str(path), providers=providers)
    return sessions


def _load_audio(path) -> tuple[np.ndarray, int]:
    """Read an audio file as a (2, samples) float32 array plus its sample rate.

    Mono input is duplicated to both channels; any channels beyond the first
    two are dropped. *path* may be a ``str``, a ``pathlib.Path`` or anything
    else ``soundfile.read`` accepts.
    """
    src = str(path) if isinstance(path, Path) else path
    audio, sr = sf.read(src, dtype="float32", always_2d=True)
    audio = audio.T  # soundfile returns (frames, channels)
    if audio.shape[0] == 1:
        audio = np.tile(audio, (2, 1))
    elif audio.shape[0] > 2:
        audio = audio[:2]
    return audio, sr


def separate(mix: np.ndarray, sample_rate: int,
             onnx_files: dict[str, Path] | None = None,
             providers: list[str] | None = None,
             verbose: bool = True) -> dict[str, np.ndarray]:
    """Run full 4-stem chunked overlap-add separation.

    The mix is cut into N_SAMPLES-long chunks overlapping by 25 %; the final
    chunk is zero-padded to full length. Every chunk is run through all four
    specialist sub-models. From each sub-model only the row of its own stem
    is kept, because the htdemucs_ft bag weights are one-hot. Chunk outputs
    are blended with a trapezoidal window and normalised by the summed
    window weight.

    Args:
        mix: (channels, samples) float32 in [-1, 1], 44.1 kHz stereo.
        sample_rate: must equal 44100.
        onnx_files: optional dict overriding the default file locations;
            stems it does not mention fall back to DEFAULT_ONNX_FILES.
        providers: onnxruntime EPs; defaults to CPU.
        verbose: print progress per chunk.

    Returns:
        dict of {stem_name: (channels, samples) float32}.

    Raises:
        ValueError: wrong sample rate or input shape.
        FileNotFoundError: a model file is missing.
    """
    if sample_rate != SAMPLE_RATE:
        raise ValueError(f"Bound to {SAMPLE_RATE} Hz; got {sample_rate}.")
    if mix.ndim != 2 or mix.shape[0] != N_CHANNELS:
        raise ValueError(f"Expected (2, samples) input, got {mix.shape}")

    # Partial overrides keep the defaults for the stems they omit (previously
    # a partial dict crashed with KeyError inside the chunk loop).
    files = {**DEFAULT_ONNX_FILES, **(onnx_files or {})}
    sessions = _load_sessions(files, providers)
    if verbose:
        first_session = next(iter(sessions.values()))
        print(f" loaded {len(sessions)} ONNX sessions on "
              f"{first_session.get_providers()[0]}")

    total_len = mix.shape[1]
    out = {stem: np.zeros((N_CHANNELS, total_len), dtype=np.float32)
           for stem in SOURCES}
    if total_len == 0:
        # Nothing to separate; skip inference on pure padding (and the RTF
        # computation below, which would divide by a zero duration).
        return out

    overlap = int(N_SAMPLES * _OVERLAP_FRAC)
    stride = N_SAMPLES - overlap
    n_chunks = max(1, (total_len + stride - 1) // stride)  # ceil(total / stride)
    if verbose:
        print(f" input: {total_len:,} samples ({total_len / sample_rate:.1f}s)")
        print(f" chunks: {n_chunks}")

    # Ramp length equals `overlap`, so one chunk's fade-out lines up exactly
    # with the next chunk's fade-in.
    window = _make_transition_window(N_SAMPLES, _OVERLAP_FRAC)
    weight = np.zeros(total_len, dtype=np.float32)
    # Row of each stem in a sub-model's stacked output (0/1/2/3 matches bag.models[idx]).
    target_rows = {stem: row for row, stem in enumerate(SOURCES)}

    t0 = time.perf_counter()
    for i in range(n_chunks):
        start = i * stride
        end = min(start + N_SAMPLES, total_len)
        chunk_len = end - start
        chunk = mix[:, start:end]
        if chunk_len < N_SAMPLES:
            # The ONNX graph expects exactly N_SAMPLES; zero-pad the tail chunk.
            chunk = np.pad(chunk, ((0, 0), (0, N_SAMPLES - chunk_len)),
                           mode="constant")
        x = chunk[np.newaxis, ...].astype(np.float32)  # (1, channels, N_SAMPLES)
        w = window[:chunk_len]

        # Run each specialist; take only its target stem row.
        for stem in SOURCES:
            # assumes output "stems" is (batch, 4, 2, N) -> [0] gives (4, 2, N)
            stems = sessions[stem].run(["stems"], {"mix": x})[0][0]
            out[stem][:, start:end] += stems[target_rows[stem], :, :chunk_len] * w
        weight[start:end] += w
        if verbose:
            print(f" chunk {i+1}/{n_chunks}: "
                  f"{time.perf_counter() - t0:.1f}s elapsed")

    # Divide by the accumulated window weight so overlaps average correctly.
    # The window is strictly positive, so the clamp is purely defensive.
    np.maximum(weight, 1e-8, out=weight)
    for stem in SOURCES:
        out[stem] /= weight

    if verbose:
        elapsed = time.perf_counter() - t0
        rtf = elapsed / (total_len / sample_rate)
        print(f" total: {elapsed:.2f}s (RTF {rtf:.2f}, "
              f"{len(SOURCES)} sub-models × {n_chunks} chunks = "
              f"{len(SOURCES) * n_chunks} ONNX runs)")
    return out


def separate_all(input_path: str, **kwargs) -> dict[str, np.ndarray]:
    """Convenience: load audio, run separation, return all 4 stems.

    Args:
        input_path: audio file readable by soundfile (must be 44.1 kHz). Mono
            is upmixed to stereo and channels beyond the first two are dropped.
        **kwargs: forwarded to :func:`separate` (``onnx_files``,
            ``providers``, ``verbose``).

    Returns:
        {stem_name: (2, samples) float32}.
    """
    audio, sr = _load_audio(input_path)
    return separate(audio, sr, **kwargs)


def main() -> None:
    """CLI entry point: separate one file and write ``<stem>.wav`` per stem."""
    ap = argparse.ArgumentParser(
        description=__doc__,
        # Keep the module docstring's line breaks (usage examples) in --help.
        formatter_class=argparse.RawDescriptionHelpFormatter)
    ap.add_argument("input", type=Path)
    ap.add_argument("out_dir", type=Path)
    ap.add_argument("--providers", type=str, default="cpu",
                    choices=["cpu", "coreml", "cuda", "dml"])
    ap.add_argument("--float32", action="store_true",
                    help="write 32-bit float WAVs instead of 16-bit PCM "
                         "(keeps peaks above full scale instead of clipping)")
    args = ap.parse_args()

    providers_map = {
        "cpu": ["CPUExecutionProvider"],
        "coreml": ["CoreMLExecutionProvider", "CPUExecutionProvider"],
        "cuda": ["CUDAExecutionProvider", "CPUExecutionProvider"],
        "dml": ["DmlExecutionProvider", "CPUExecutionProvider"],
    }

    args.out_dir.mkdir(parents=True, exist_ok=True)
    print(f"Loading {args.input} ...")
    audio, sr = _load_audio(args.input)
    print(f" shape {audio.shape}, sr {sr}")

    stems = separate(audio, sr, providers=providers_map[args.providers])
    for stem, audio_out in stems.items():
        out_path = args.out_dir / f"{stem}.wav"
        if args.float32:
            sf.write(str(out_path), audio_out.T, sr, subtype="FLOAT")
        else:
            # soundfile's default WAV subtype is 16-bit PCM, and libsndfile does
            # not clip float->int conversion by default; saturate explicitly so
            # separated peaks above full scale cannot overflow.
            sf.write(str(out_path), np.clip(audio_out, -1.0, 1.0).T, sr)
        print(f" wrote {out_path}")


if __name__ == "__main__":
    main()