""" Streaming AMD (Answering Machine Detection) Classifier using Whisper. Processes PCM audio chunks in real-time and outputs classification once confident. Uses Whisper encoder (speech understanding) — critical for distinguishing human-recorded voicemail greetings from live human speech. Architecture: WhisperForAudioClassification with accumulating buffer. - Accepts 8kHz or 16kHz PCM audio chunks - Maintains internal buffer (up to 10s rolling window) - Runs inference every N ms, outputs (label, confidence) when threshold met - Class-specific thresholds for optimal early detection Usage: from streaming_amd import StreamingAMDClassifier classifier = StreamingAMDClassifier("AbijahKaj/whisper-telephony-amd") # Process chunks as they arrive from telephony stream for pcm_chunk in audio_stream: result = classifier.process_chunk(pcm_chunk) if result is not None: label, confidence, elapsed_s = result print(f"Detected: {label} ({confidence:.2f}) after {elapsed_s:.1f}s") break """ import numpy as np import torch from typing import Optional, Tuple, List, Dict from dataclasses import dataclass, field from transformers import AutoFeatureExtractor, WhisperForAudioClassification @dataclass class AMDConfig: """Configuration for streaming AMD classifier.""" model_id: str = "AbijahKaj/whisper-telephony-amd" device: str = "cpu" # Audio input_sample_rate: int = 8000 # Telephony standard model_sample_rate: int = 16000 # Whisper expects 16kHz # Streaming chunk_duration_ms: int = 160 # Telephony frame (160ms) min_audio_ms: int = 800 # Min audio before first inference inference_interval_ms: int = 500 # Run inference every 500ms max_audio_ms: int = 10000 # Max 10s buffer # Confidence thresholds (per-class) thresholds: Dict[str, float] = field(default_factory=lambda: { "human": 0.80, "voicemail": 0.75, "ivr": 0.70, # IVR has distinctive patterns "answering_machine": 0.75, }) min_consecutive: int = 2 # Require N consecutive same-class predictions global_threshold: float = 0.90 # Any class above this → immediate decision @dataclass class StreamingState: """Internal state for streaming inference.""" audio_buffer: List[np.ndarray] = field(default_factory=list) total_samples: int = 0 inference_count: int = 0 prediction_history: List[Tuple[str, float]] = field(default_factory=list) consecutive_counts: Dict[str, int] = field(default_factory=lambda: { "human": 0, "voicemail": 0, "ivr": 0, "answering_machine": 0 }) elapsed_samples: int = 0 class StreamingAMDClassifier: """Real-time streaming AMD classifier using Whisper encoder.""" def __init__(self, config: Optional[AMDConfig] = None, model_id: Optional[str] = None): if config is None: config = AMDConfig() if model_id: config.model_id = model_id self.config = config self.state = StreamingState() print(f"Loading AMD model: {config.model_id}") self.feature_extractor = AutoFeatureExtractor.from_pretrained(config.model_id) self.model = WhisperForAudioClassification.from_pretrained(config.model_id) self.model.to(config.device) self.model.eval() self._resample_ratio = config.model_sample_rate / config.input_sample_rate self._min_samples = int(config.min_audio_ms / 1000 * config.input_sample_rate) self._interval_samples = int(config.inference_interval_ms / 1000 * config.input_sample_rate) self._max_samples = int(config.max_audio_ms / 1000 * config.input_sample_rate) self._since_inference = 0 print(f"Ready. Device={config.device}, input={config.input_sample_rate}Hz") def reset(self): self.state = StreamingState() self._since_inference = 0 def _resample(self, audio: np.ndarray) -> np.ndarray: if self.config.input_sample_rate == self.config.model_sample_rate: return audio n = len(audio) out_n = int(n * self._resample_ratio) return np.interp(np.linspace(0, n-1, out_n), np.arange(n), audio).astype(np.float32) @torch.no_grad() def _infer(self, audio: np.ndarray) -> Tuple[str, float, np.ndarray]: audio_16k = self._resample(audio) inputs = self.feature_extractor( audio_16k, sampling_rate=self.config.model_sample_rate, return_tensors="pt", padding="max_length", max_length=self.config.max_audio_ms // 1000 * self.config.model_sample_rate, truncation=True, ) inputs = {k: v.to(self.config.device) for k, v in inputs.items()} logits = self.model(**inputs).logits probs = torch.softmax(logits, dim=-1)[0].cpu().numpy() idx = int(np.argmax(probs)) label = self.model.config.id2label[str(idx)] return label, float(probs[idx]), probs def _confident(self, label: str, conf: float) -> bool: if conf >= self.config.global_threshold: return True threshold = self.config.thresholds.get(label, 0.80) if conf >= threshold and self.state.consecutive_counts[label] >= self.config.min_consecutive - 1: return True return False def process_chunk(self, pcm: np.ndarray, sample_rate: Optional[int] = None) -> Optional[Tuple[str, float, float]]: """ Process a PCM audio chunk. Args: pcm: Audio samples (int16 or float32) sample_rate: Override sample rate Returns: None if not yet confident, or (label, confidence, elapsed_seconds) """ if pcm.dtype == np.int16: pcm = pcm.astype(np.float32) / 32768.0 if sample_rate and sample_rate != self.config.input_sample_rate: self.config.input_sample_rate = sample_rate self._resample_ratio = self.config.model_sample_rate / sample_rate self._min_samples = int(self.config.min_audio_ms / 1000 * sample_rate) self._interval_samples = int(self.config.inference_interval_ms / 1000 * sample_rate) self._max_samples = int(self.config.max_audio_ms / 1000 * sample_rate) self.state.audio_buffer.append(pcm) self.state.total_samples += len(pcm) self.state.elapsed_samples += len(pcm) self._since_inference += len(pcm) if self.state.total_samples < self._min_samples: return None if self._since_inference < self._interval_samples: return None self._since_inference = 0 full = np.concatenate(self.state.audio_buffer) if len(full) > self._max_samples: full = full[-self._max_samples:] label, conf, probs = self._infer(full) self.state.inference_count += 1 self.state.prediction_history.append((label, conf)) for cls in self.state.consecutive_counts: self.state.consecutive_counts[cls] = self.state.consecutive_counts[cls] + 1 if cls == label else 0 if self._confident(label, conf) or self.state.total_samples >= self._max_samples: return (label, conf, self.state.elapsed_samples / self.config.input_sample_rate) return None def get_current(self) -> Optional[Tuple[str, float]]: return self.state.prediction_history[-1] if self.state.prediction_history else None def elapsed_ms(self) -> float: return self.state.elapsed_samples / self.config.input_sample_rate * 1000 def simulate_call(audio: np.ndarray, sr: int = 8000, model_id: str = "AbijahKaj/whisper-telephony-amd", chunk_ms: int = 160) -> dict: """Simulate streaming AMD on a complete audio array.""" config = AMDConfig(model_id=model_id, input_sample_rate=sr, chunk_duration_ms=chunk_ms) clf = StreamingAMDClassifier(config=config) chunk_n = int(chunk_ms / 1000 * sr) for i in range(0, len(audio), chunk_n): chunk = audio[i:i + chunk_n] if len(chunk) == 0: break result = clf.process_chunk(chunk) if result: label, conf, elapsed = result return {"label": label, "confidence": conf, "elapsed_ms": elapsed * 1000, "inferences": clf.state.inference_count, "history": clf.state.prediction_history} cur = clf.get_current() if cur: return {"label": cur[0], "confidence": cur[1], "elapsed_ms": clf.elapsed_ms(), "inferences": clf.state.inference_count, "note": "max audio reached"} return {"label": "unknown", "confidence": 0.0, "elapsed_ms": 0.0}