"""Long-audio diarize+timestamps using hviske-style diar model + ReDimNet2 B6 for cross-chunk speaker linkage. Pipeline: 1. Slide a 28 s window with 2 s overlap over the input 2. For each window: run diar model → parse [{local_spk, start, end, text}] 3. For each parsed segment: extract waveform slice, run ReDimNet2 B6 → 192-D embedding 4. Cluster all (segment, embedding) globally with AHC (cosine distance, threshold τ) → map (chunk_id, local_spk) → global_speaker_id 5. Stitch into a single transcript over the full audio timeline Dependencies at inference time: - torch - transformers (for the diar model) - numpy, scipy - soundfile (audio I/O) - torch.hub auto-downloads ReDimNet2 (uses only torch) Usage: from diarize_long import diarize_long_audio segments = diarize_long_audio( audio="podcast.wav", diar_model_id="syvai/hviske-multilingual-diarize-ts", language="en", ) for s in segments: print(f"[{s['start']:6.2f} - {s['end']:6.2f}] SPK{s['speaker']:02d} {s['text']}") """ from __future__ import annotations import os, io, re, math from dataclasses import dataclass from typing import List, Optional, Union, Iterable, Dict, Tuple import numpy as np import torch import soundfile as sf from scipy.cluster.hierarchy import linkage, fcluster from scipy.spatial.distance import squareform # ── Defaults ──────────────────────────────────────────────────────────────── SR = 16000 CHUNK_S = 28.0 # window length (model trained on ≤ 30 s) OVERLAP_S = 2.0 # window overlap MIN_SEG_S = 0.20 # ignore segments shorter than this EMB_MIN_S = 0.50 # if a segment is shorter than this, expand symmetrically for embedding CLUSTER_TAU = 0.45 # cosine distance threshold for AHC (1 - cos_sim); 0.45 is a good default for podcast/panel audio (verified on a 126s YouTube clip — Bernie stays one ID across 5 chunks) DEVICE = "cuda" if torch.cuda.is_available() else "cpu" DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32 # Parsing pattern matches the diar+ts model's segment template SEG_RE = re.compile(r"<\|spltoken(\d+)\|><\|t:(\d+\.\d+)\|>(.*?)<\|t:(\d+\.\d+)\|>", re.DOTALL) TOK_STRIP = re.compile(r"<\|[^|]+\|>") PROMPT_BY_LANG = { "da": "<|startofcontext|><|startoftranscript|><|emo:undefined|><|da|><|da|><|pnc|><|noitn|><|timestamp|><|diarize|>", "en": "<|startofcontext|><|startoftranscript|><|emo:undefined|><|en|><|en|><|pnc|><|noitn|><|timestamp|><|diarize|>", } # ── Data classes ──────────────────────────────────────────────────────────── @dataclass class Segment: start: float # seconds in full audio end: float local_spk: int # speaker index within the chunk (0..7) text: str chunk_id: int embedding: Optional[np.ndarray] = None # 192-D from ReDimNet2 speaker: int = -1 # filled in after clustering # ── Model loading ─────────────────────────────────────────────────────────── def load_diar_model(model_id: str, token: Optional[str] = None): from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq processor = AutoProcessor.from_pretrained(model_id, token=token) model = AutoModelForSpeechSeq2Seq.from_pretrained( model_id, token=token, dtype=DTYPE, low_cpu_mem_usage=True ).to(DEVICE).eval() return processor, model def load_redimnet(model_name: str = "b6", train_type: str = "lm"): """Download + load ReDimNet2 from torch.hub.""" model = torch.hub.load( "PalabraAI/redimnet2", "redimnet2", model_name=model_name, train_type=train_type, pretrained=True, source="github", trust_repo=True, ) model = model.to(DEVICE).eval() return model # ── Decoding ──────────────────────────────────────────────────────────────── def decode_chunk( processor, model, chunk_wav: np.ndarray, language: str, max_new_tokens: int = 400, ) -> str: """Return raw decoded text (with special tokens) for one ≤30 s chunk.""" prompt = PROMPT_BY_LANG[language] prompt_ids = processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt")["input_ids"].to(DEVICE) fe = processor.feature_extractor(chunk_wav.astype(np.float32), sampling_rate=SR, return_tensors="pt") with torch.inference_mode(): out = model.generate( input_features=fe["input_features"].to(DTYPE).to(DEVICE), attention_mask=torch.ones(fe["input_features"].shape[:2], device=DEVICE), decoder_input_ids=prompt_ids, max_new_tokens=max_new_tokens, do_sample=False, # repetition_penalty=1.2 lives in generation_config ) txt = processor.tokenizer.decode(out[0], skip_special_tokens=False) if "<|diarize|>" in txt: txt = txt.split("<|diarize|>", 1)[1] return txt.replace("<|endoftext|>", "") def parse_segments(text: str) -> List[Tuple[int, float, float, str]]: """Pull (spk, start_s, end_s, text) tuples out of decoded text.""" out = [] for m in SEG_RE.finditer(text): st = float(m.group(2)); ed = float(m.group(4)) if ed <= st: ed = st + 0.05 clean = TOK_STRIP.sub("", m.group(3)).strip() out.append((int(m.group(1)), st, ed, clean)) return out # ── Embedding extraction ──────────────────────────────────────────────────── def embed_segment( redimnet, audio: np.ndarray, start_s: float, end_s: float, ) -> np.ndarray: """Crop audio to [start, end] (with EMB_MIN_S padding if too short), embed.""" if end_s - start_s < EMB_MIN_S: mid = 0.5 * (start_s + end_s) half = max(0.5 * EMB_MIN_S, 0.5 * (end_s - start_s)) start_s = max(0.0, mid - half) end_s = min(len(audio) / SR, mid + half) s_idx = max(0, int(start_s * SR)) e_idx = min(len(audio), int(end_s * SR)) if e_idx - s_idx < int(SR * 0.05): return None wav = torch.from_numpy(audio[s_idx:e_idx].astype(np.float32)).unsqueeze(0).to(DEVICE) with torch.inference_mode(): emb = redimnet(wav) return emb.squeeze(0).float().cpu().numpy() # ── Clustering ────────────────────────────────────────────────────────────── def cluster_speakers( embs: np.ndarray, threshold: float = CLUSTER_TAU, max_speakers: int = 8 ) -> np.ndarray: """Cluster embeddings with AHC (average linkage on cosine distance). Cap clusters at max_speakers by reassigning small clusters to nearest centroid.""" if len(embs) <= 1: return np.zeros(len(embs), dtype=int) norms = np.linalg.norm(embs, axis=1, keepdims=True) + 1e-8 embs_n = embs / norms sim = embs_n @ embs_n.T sim = np.clip(sim, -1.0, 1.0) dist = 1.0 - sim np.fill_diagonal(dist, 0.0) cond = squareform(dist, checks=False) Z = linkage(cond, method="average") labels = fcluster(Z, t=threshold, criterion="distance") # If we have more than max_speakers clusters, keep the largest max_speakers # and reassign the rest to the nearest centroid. uniq, counts = np.unique(labels, return_counts=True) if len(uniq) > max_speakers: keep = uniq[np.argsort(-counts)[:max_speakers]] keep_set = set(int(x) for x in keep) cents = {} for k in keep: mask = labels == k c = embs_n[mask].mean(axis=0) c /= (np.linalg.norm(c) + 1e-8) cents[int(k)] = c for i in range(len(labels)): if int(labels[i]) not in keep_set: best_sim, best_k = -np.inf, int(keep[0]) for k, c in cents.items(): s = float(np.dot(embs_n[i], c)) if s > best_sim: best_sim, best_k = s, k labels[i] = best_k seen, remap = {}, [] for x in labels: if x not in seen: seen[x] = len(seen) remap.append(seen[x]) return np.asarray(remap, dtype=int) # ── Top-level pipeline ────────────────────────────────────────────────────── def _load_audio(audio: Union[str, np.ndarray, Tuple[np.ndarray, int]]) -> np.ndarray: if isinstance(audio, str): arr, sr = sf.read(audio) if arr.ndim > 1: arr = arr.mean(axis=1) if sr != SR: import librosa arr = librosa.resample(arr.astype(np.float32), orig_sr=sr, target_sr=SR) return arr.astype(np.float32) if isinstance(audio, tuple): arr, sr = audio if arr.ndim > 1: arr = arr.mean(axis=1) if sr != SR: import librosa arr = librosa.resample(arr.astype(np.float32), orig_sr=sr, target_sr=SR) return arr.astype(np.float32) arr = np.asarray(audio, dtype=np.float32) if arr.ndim > 1: arr = arr.mean(axis=1) return arr def diarize_long_audio( audio: Union[str, np.ndarray, Tuple[np.ndarray, int]], diar_model_id: str = "syvai/hviske-multilingual-diarize-ts", language: str = "en", redimnet_model_name: str = "b6", chunk_s: float = CHUNK_S, overlap_s: float = OVERLAP_S, cluster_threshold: float = CLUSTER_TAU, hf_token: Optional[str] = None, diar: Optional[Tuple] = None, redimnet=None, verbose: bool = False, ) -> List[Dict]: """Diarize long audio of any length. Returns list of segment dicts.""" if language not in PROMPT_BY_LANG: raise ValueError(f"language must be one of {list(PROMPT_BY_LANG)}") wav = _load_audio(audio) total_s = len(wav) / SR # Lazy model load (pass pre-loaded models if you want to reuse them across calls) if diar is None: diar = load_diar_model(diar_model_id, token=hf_token) processor, diar_model = diar if redimnet is None: redimnet = load_redimnet(redimnet_model_name) step = max(0.1, chunk_s - overlap_s) segs: List[Segment] = [] chunk_id = 0 t0 = 0.0 while t0 < total_s: t1 = min(t0 + chunk_s, total_s) # If the last chunk is too small, snap it back so we cover the tail if t1 - t0 < 2.0: t1 = total_s t0 = max(0.0, t1 - chunk_s) s_idx, e_idx = int(t0 * SR), int(t1 * SR) chunk_wav = wav[s_idx:e_idx] text = decode_chunk(processor, diar_model, chunk_wav, language=language) parsed = parse_segments(text) if verbose: print(f" chunk {chunk_id} [{t0:6.2f}-{t1:6.2f}] {len(parsed)} segs", flush=True) for (local_spk, st, ed, txt) in parsed: abs_start = t0 + st abs_end = min(t0 + ed, total_s) if abs_end - abs_start < MIN_SEG_S: continue seg = Segment(start=abs_start, end=abs_end, local_spk=local_spk, text=txt, chunk_id=chunk_id) seg.embedding = embed_segment(redimnet, wav, abs_start, abs_end) if seg.embedding is None: continue segs.append(seg) if t1 >= total_s: break t0 += step chunk_id += 1 if not segs: return [] # Dedupe overlapping segments from adjacent chunks (same speaker + > 50% IoU) segs = _dedupe_overlaps(segs) # Global cluster across all chunks embs = np.stack([s.embedding for s in segs], axis=0) labels = cluster_speakers(embs, threshold=cluster_threshold) for seg, lab in zip(segs, labels): seg.speaker = int(lab) segs.sort(key=lambda s: s.start) return [ {"start": s.start, "end": s.end, "speaker": s.speaker, "text": s.text, "chunk_id": s.chunk_id, "local_spk": s.local_spk} for s in segs ] def _dedupe_overlaps(segs: List[Segment]) -> List[Segment]: """If two segments from consecutive chunks overlap by > 50 % AND share the same local_spk pattern (same chunk-internal speaker emitting the same text), keep the one whose time is more central to its parent chunk.""" segs.sort(key=lambda s: (s.start, s.chunk_id)) keep = [] i = 0 while i < len(segs): cur = segs[i] j = i + 1 absorb_into = cur while j < len(segs) and segs[j].start < cur.end: other = segs[j] inter = max(0.0, min(cur.end, other.end) - max(cur.start, other.start)) union = max(cur.end, other.end) - min(cur.start, other.start) if union <= 0: j += 1; continue iou = inter / union # Same-speaker only if same chunk; otherwise text overlap heuristic same_text = (other.text[:20] == cur.text[:20]) if cur.text and other.text else False if iou > 0.5 and (other.chunk_id != cur.chunk_id) and same_text: # Drop `other` — duplicate from neighboring chunk segs[j] = None # type: ignore[assignment] j += 1 keep.append(absorb_into) # Advance past any j that we dropped i += 1 while i < len(segs) and segs[i] is None: i += 1 return [s for s in keep if s is not None] # ── CLI smoke test ────────────────────────────────────────────────────────── if __name__ == "__main__": import argparse, json, time ap = argparse.ArgumentParser() ap.add_argument("audio", help="Path to audio file") ap.add_argument("--ckpt", default="syvai/hviske-multilingual-diarize-ts") ap.add_argument("--lang", default="en", choices=["en", "da"]) ap.add_argument("--redimnet", default="b6") ap.add_argument("--tau", type=float, default=CLUSTER_TAU) ap.add_argument("--verbose", action="store_true") ap.add_argument("--token", default=os.environ.get("HF_TOKEN")) args = ap.parse_args() t0 = time.time() segs = diarize_long_audio( args.audio, diar_model_id=args.ckpt, language=args.lang, redimnet_model_name=args.redimnet, cluster_threshold=args.tau, hf_token=args.token, verbose=args.verbose, ) el = time.time() - t0 print(f"\n=== {len(segs)} segments in {el:.1f}s ===") for s in segs: print(f"[{s['start']:7.2f} - {s['end']:7.2f}] SPK{s['speaker']:02d} {s['text']}")