"""Long-audio diarization via vLLM + ReDimNet2 B6 cross-chunk linkage. Sends ≤28s windows of audio to a running vLLM diarize endpoint in parallel (exploiting vLLM batching), then re-clusters speakers globally using ReDimNet2 embeddings of each parsed segment. Usage: python diarize_long_vllm.py podcast.wav \\ --vllm http://127.0.0.1:8000 \\ --model syvai/cohere-transcribe-diarize \\ --tau 0.45 """ import argparse, asyncio, io, json, os, re, sys, time from typing import List, Dict import numpy as np import soundfile as sf import torch from scipy.cluster.hierarchy import linkage, fcluster from scipy.spatial.distance import squareform import aiohttp SR = 16000 CHUNK_S = 28.0 OVERLAP_S = 2.0 MIN_SEG_S = 0.20 EMB_MIN_S = 0.50 EMB_MAX_S = 3.0 # cap each embedding crop — beyond ~2-3s a single speaker, extra audio adds little CLUSTER_TAU = 0.45 PROMPT_BY_LANG = { "en": "<|startofcontext|><|startoftranscript|><|emo:undefined|><|en|><|en|><|pnc|><|noitn|><|timestamp|><|diarize|>", "da": "<|startofcontext|><|startoftranscript|><|emo:undefined|><|da|><|da|><|pnc|><|noitn|><|timestamp|><|diarize|>", } def load_audio(path): wav, sr = sf.read(path) if wav.ndim > 1: wav = wav.mean(axis=1) if sr != SR: import librosa wav = librosa.resample(wav.astype(np.float32), orig_sr=sr, target_sr=SR) return wav.astype(np.float32) async def transcribe_chunk(session, url, model, language, chunk_wav, idx, prompt): buf = io.BytesIO() sf.write(buf, chunk_wav, SR, format="WAV", subtype="PCM_16") data = aiohttp.FormData() data.add_field("file", buf.getvalue(), filename=f"chunk{idx:04d}.wav", content_type="audio/wav") data.add_field("model", model) data.add_field("language", language) data.add_field("response_format", "diarized_json") data.add_field("prompt", prompt) data.add_field("max_completion_tokens", "400") async with session.post(f"{url}/v1/audio/transcriptions", data=data) as resp: body = await resp.json() return idx, body async def fetch_all_chunks(audio, vllm_url, model, language, chunk_s, overlap_s, concurrency=8): step = max(0.1, chunk_s - overlap_s) total_s = len(audio) / SR chunks = [] t0 = 0.0 cid = 0 while t0 < total_s: t1 = min(t0 + chunk_s, total_s) if t1 - t0 < 2.0: t1 = total_s; t0 = max(0.0, t1 - chunk_s) chunks.append((cid, t0, t1, audio[int(t0*SR):int(t1*SR)])) if t1 >= total_s: break t0 += step cid += 1 prompt = PROMPT_BY_LANG[language] print(f" {len(chunks)} chunks, concurrency={concurrency}", flush=True) timeout = aiohttp.ClientTimeout(total=600) conn = aiohttp.TCPConnector(limit=concurrency) sem = asyncio.Semaphore(concurrency) async def bounded(session, *args): async with sem: return await transcribe_chunk(session, *args) async with aiohttp.ClientSession(timeout=timeout, connector=conn) as session: tasks = [bounded(session, vllm_url, model, language, w, c, prompt) for (c, _, _, w) in chunks] results = await asyncio.gather(*tasks) by_id = {idx: body for idx, body in results} return chunks, by_id def load_redimnet(model_name="b6", compile_=False): # Enable TF32 — free conv speedup on Ampere. Skip cudnn.benchmark when shapes vary. torch.set_float32_matmul_precision("high") m = torch.hub.load("PalabraAI/redimnet2", "redimnet2", model_name=model_name, train_type="lm", pretrained=True, source="github", trust_repo=True ).cuda().eval() if compile_: m = torch.compile(m, mode="reduce-overhead", dynamic=False, fullgraph=False) return m @torch.inference_mode() def embed_segments(redimnet, audio, segs, batch_size=32, fixed_len_s=0.0): """Embed all segments via ReDimNet2 in batched forward passes. Pads to per-batch max length so we don't waste compute on the longest tail. If fixed_len_s > 0, every crop is pad/trim'd to that many seconds — gives a single static shape so torch.compile reuses one CUDA graph for all batches.""" crops, keep = [], [] for s in segs: st, ed = s["abs_start"], s["abs_end"] if ed - st < EMB_MIN_S: mid = 0.5*(st+ed); half = 0.5*EMB_MIN_S st = max(0.0, mid-half); ed = min(len(audio)/SR, mid+half) a = audio[max(0,int(st*SR)):min(len(audio), int(ed*SR))] if len(a) < SR // 4: continue if len(a) > int(SR * EMB_MAX_S): a = a[:int(SR * EMB_MAX_S)] crops.append(a.astype(np.float32)); keep.append(s) if not crops: return np.zeros((0, 192), dtype=np.float32), [] embs_out = [None] * len(crops) if fixed_len_s > 0: # All crops padded/trimmed to a single fixed length: same shape every batch. L = int(fixed_len_s * SR) for start in range(0, len(crops), batch_size): idxs = list(range(start, min(start+batch_size, len(crops)))) batch = np.zeros((len(idxs), L), dtype=np.float32) for bi, i in enumerate(idxs): n = min(len(crops[i]), L) batch[bi, :n] = crops[i][:n] x = torch.from_numpy(batch).cuda() e = redimnet(x).float().cpu().numpy() for bi, i in enumerate(idxs): embs_out[i] = e[bi] else: # Per-batch dynamic length: sort desc so the longest is first → minimum pad waste. order = sorted(range(len(crops)), key=lambda i: -len(crops[i])) for start in range(0, len(order), batch_size): idxs = order[start:start+batch_size] L = max(len(crops[i]) for i in idxs) L = max(L, SR // 2) batch = np.zeros((len(idxs), L), dtype=np.float32) for bi, i in enumerate(idxs): n = min(len(crops[i]), L) batch[bi, :n] = crops[i][:n] x = torch.from_numpy(batch).cuda() e = redimnet(x).float().cpu().numpy() for bi, i in enumerate(idxs): embs_out[i] = e[bi] return np.stack(embs_out, axis=0), keep def cluster(embs, tau=CLUSTER_TAU, max_speakers=8): if len(embs) <= 1: return np.zeros(len(embs), dtype=int) e = embs / (np.linalg.norm(embs, axis=1, keepdims=True) + 1e-8) sim = np.clip(e @ e.T, -1, 1); dist = 1 - sim; np.fill_diagonal(dist, 0) Z = linkage(squareform(dist, checks=False), method="average") labels = fcluster(Z, t=tau, criterion="distance") uniq, counts = np.unique(labels, return_counts=True) if len(uniq) > max_speakers: keep = uniq[np.argsort(-counts)[:max_speakers]] cents = {int(k): (e[labels==k].mean(0) / (np.linalg.norm(e[labels==k].mean(0))+1e-8)) for k in keep} keep_set = set(int(x) for x in keep) for i in range(len(labels)): if int(labels[i]) not in keep_set: labels[i] = max(cents, key=lambda k: float(np.dot(e[i], cents[k]))) seen, remap = {}, [] for x in labels: if x not in seen: seen[x] = len(seen) remap.append(seen[x]) return np.asarray(remap, dtype=int) async def main(): ap = argparse.ArgumentParser() ap.add_argument("audio") ap.add_argument("--vllm", default="http://127.0.0.1:8000") ap.add_argument("--model", default="syvai/cohere-transcribe-diarize") ap.add_argument("--language", default="en", choices=list(PROMPT_BY_LANG)) ap.add_argument("--chunk-s", type=float, default=CHUNK_S) ap.add_argument("--overlap-s", type=float, default=OVERLAP_S) ap.add_argument("--tau", type=float, default=CLUSTER_TAU) ap.add_argument("--concurrency", type=int, default=32) ap.add_argument("--embed-batch", type=int, default=32) ap.add_argument("--compile", action="store_true", help="torch.compile(redim) — adds first-call overhead but is faster after") ap.add_argument("--fixed-len", type=float, default=0.0, help="pad/trim every crop to this many seconds before embedding (makes torch.compile reuse one CUDA graph)") ap.add_argument("--redimnet", default="b6") args = ap.parse_args() print(f"loading audio {args.audio}", flush=True) audio = load_audio(args.audio) dur = len(audio) / SR print(f" {dur:.1f}s", flush=True) print("decoding chunks via vLLM…", flush=True) t0 = time.time() chunks, results = await fetch_all_chunks( audio, args.vllm, args.model, args.language, args.chunk_s, args.overlap_s, args.concurrency, ) decode_s = time.time() - t0 print(f" decode: {decode_s:.2f}s ({dur/decode_s:.1f}× RTF)", flush=True) # Translate each chunk's segments into absolute time + chunk id all_segs = [] for (cid, t0c, t1c, _) in chunks: body = results.get(cid, {}) for seg in body.get("segments", []): abs_start = t0c + float(seg["start"]) abs_end = min(t0c + float(seg["end"]), len(audio)/SR) if abs_end - abs_start < MIN_SEG_S: continue all_segs.append({ "chunk_id": cid, "abs_start": abs_start, "abs_end": abs_end, "local_speaker": seg["speaker"], "text": seg["text"], }) print(f" parsed {len(all_segs)} segments") if not all_segs: print("(empty)") return print("embedding via ReDimNet2 b6…", flush=True) t1 = time.time() redim = load_redimnet(args.redimnet, compile_=args.compile) embs, segs = embed_segments(redim, audio, all_segs, batch_size=args.embed_batch, fixed_len_s=args.fixed_len) emb_s = time.time() - t1 print(f" embed: {emb_s:.2f}s for {len(segs)} segs ({len(segs)/emb_s:.1f} seg/s)", flush=True) labels = cluster(embs, tau=args.tau) for s, lab in zip(segs, labels): s["speaker"] = f"SPEAKER_{int(lab):02d}" segs.sort(key=lambda s: s["abs_start"]) print(f"\n=== {len(segs)} segments, {len(set(s['speaker'] for s in segs))} speakers, RTF total = {dur/(decode_s+emb_s):.1f}× ===") for s in segs: print(f"[{s['abs_start']:7.2f} – {s['abs_end']:7.2f}] {s['speaker']} {s['text'][:80]}") if __name__ == "__main__": asyncio.run(main())