"""Custom audio tracks for WebRTC voice pipeline integration."""

import asyncio
import fractions
import logging
from typing import Optional

import numpy as np
from aiortc import MediaStreamTrack
from aiortc.mediastreams import MediaStreamError
from av import AudioFrame

logger = logging.getLogger(__name__)

# WebRTC audio constants
WEBRTC_SAMPLE_RATE = 48000
WHISPER_SAMPLE_RATE = 16000
TTS_SAMPLE_RATE = 24000
CHANNELS = 1
SAMPLES_PER_FRAME = 960  # 20ms at 48kHz

# All PCM handled here is signed 16-bit.
_BYTES_PER_SAMPLE = 2
# Outgoing buffer depth in 20 ms frames (1500 frames = 30 s). recv() plays
# frames out in real time, so the queue has to absorb TTS audio that is
# synthesized faster than it is spoken.
_DEFAULT_OUTPUT_QUEUE_FRAMES = 1500


def _put_dropping_oldest(queue: asyncio.Queue, item: bytes) -> None:
    """Enqueue *item* without blocking, evicting the oldest entry if full."""
    try:
        queue.put_nowait(item)
        return
    except asyncio.QueueFull:
        pass
    try:
        queue.get_nowait()
        queue.put_nowait(item)
    except asyncio.QueueEmpty:
        pass


def _linear_resample(samples: np.ndarray, src_rate: int, dst_rate: int) -> np.ndarray:
    """
    Resample a 1-D signal with linear interpolation.

    Output sample ``k`` is taken at source position ``k * src_rate / dst_rate``
    so the time scale is exact (an endpoint-to-endpoint ``linspace`` would
    stretch every chunk slightly). Positions past the last input sample are
    clamped to it by ``np.interp``.

    No anti-aliasing filter is applied when downsampling; use a polyphase
    resampler (e.g. ``scipy.signal.resample_poly``) if quality matters.

    Args:
        samples: 1-D array of samples (any numeric scale)
        src_rate: Sample rate of *samples*, must be positive
        dst_rate: Desired sample rate, must be positive

    Returns:
        1-D array holding ``len(samples) * dst_rate // src_rate`` samples

    Raises:
        ValueError: If either rate is not positive
    """
    if src_rate <= 0 or dst_rate <= 0:
        raise ValueError(f"Invalid sample rates: {src_rate} -> {dst_rate}")

    count = len(samples)
    if count == 0 or src_rate == dst_rate:
        return samples

    # Integer arithmetic avoids losing a sample to float truncation.
    new_length = int(count * dst_rate // src_rate)
    if new_length == 0:
        return samples[:0]

    positions = np.arange(new_length, dtype=np.float64) * (src_rate / dst_rate)
    return np.interp(positions, np.arange(count), samples)


def _channel_count(frame: AudioFrame) -> int:
    """Return the number of channels in *frame*'s layout (at least 1)."""
    layout = getattr(frame, "layout", None)
    count = getattr(layout, "nb_channels", None)
    if count is None:
        channels = getattr(layout, "channels", None)
        count = len(channels) if channels is not None else 1
    return max(int(count), 1)


def _to_unit_float(samples: np.ndarray) -> np.ndarray:
    """Scale samples to float32 in [-1, 1] according to their dtype."""
    dtype = samples.dtype
    if np.issubdtype(dtype, np.signedinteger):
        full_scale = float(np.iinfo(dtype).max) + 1.0
        return samples.astype(np.float32) / full_scale
    if np.issubdtype(dtype, np.unsignedinteger):
        # Unsigned PCM is offset binary: the midpoint represents silence.
        midpoint = (float(np.iinfo(dtype).max) + 1.0) / 2.0
        return (samples.astype(np.float32) - midpoint) / midpoint
    # Float formats are nominally in [-1, 1] already; clip any overshoot.
    return np.clip(samples.astype(np.float32), -1.0, 1.0)


def _frame_to_mono_float(frame: AudioFrame) -> np.ndarray:
    """
    Extract *frame*'s samples as a mono float array in [-1, 1].

    ``AudioFrame.to_ndarray()`` returns ``(channels, samples)`` for planar
    formats but ``(1, samples * channels)`` with interleaved samples for
    packed formats, so packed multi-channel data is de-interleaved before
    the channels are averaged.
    """
    data = np.asarray(frame.to_ndarray())
    if data.ndim == 1:
        data = data[np.newaxis, :]
    elif data.ndim > 2:
        data = data.reshape(data.shape[0], -1)

    if data.shape[0] == 1:
        channels = _channel_count(frame)
        if channels > 1 and data.shape[1] % channels == 0:
            data = data[0].reshape(-1, channels).T

    return _to_unit_float(data).mean(axis=0)


class AudioInputTrack:
    """
    Wrapper that receives audio from WebRTC and forwards to voice pipeline.

    Accumulates audio frames from WebRTC track and provides them via
    an async generator, resampled for Whisper (16kHz).
    """

    def __init__(self, track: MediaStreamTrack):
        self._track = track
        self._queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=100)
        self._running = True
        self._task: Optional[asyncio.Task] = None

    async def start(self):
        """Start receiving audio frames from WebRTC track.

        Calling this again while the receive task is still running is a no-op,
        so a second task never competes for frames from the same track.
        """
        if self._task is not None and not self._task.done():
            logger.debug("AudioInputTrack already started")
            return
        self._task = asyncio.create_task(self._receive_loop())
        logger.info("AudioInputTrack started")

    async def _receive_loop(self):
        """Receive audio frames from WebRTC track and queue them.

        Runs until stop() is called, the remote stream ends, or an unexpected
        error occurs; ``_running`` is cleared on every exit path so that
        audio_generator() terminates as well.
        """
        try:
            while self._running:
                try:
                    frame = await self._track.recv()
                except MediaStreamError:
                    logger.info("Media stream ended")
                    break

                # Convert AudioFrame to PCM bytes (resampled for Whisper)
                pcm_data = self._frame_to_pcm(frame)
                if pcm_data:
                    # Prefer fresh audio: drop the oldest chunk when full.
                    _put_dropping_oldest(self._queue, pcm_data)
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.error(f"Audio receive error: {e}", exc_info=True)
        finally:
            self._running = False

    def _frame_to_pcm(self, frame: AudioFrame) -> bytes:
        """
        Convert AudioFrame to PCM bytes suitable for Whisper.

        The frame is mixed down to mono, normalized to [-1, 1] and resampled
        from its own sample rate (48kHz is assumed when the frame does not
        report one) to 16kHz.

        Args:
            frame: Decoded audio frame received from the WebRTC track

        Returns:
            PCM 16-bit mono bytes at 16kHz, or b"" if the frame is empty
            or cannot be converted
        """
        try:
            mono = _frame_to_mono_float(frame)
            src_rate = getattr(frame, "sample_rate", None) or WEBRTC_SAMPLE_RATE
            resampled = _linear_resample(mono, src_rate, WHISPER_SAMPLE_RATE)
            if len(resampled) == 0:
                return b""

            # Convert to 16-bit PCM
            scaled = np.clip(resampled, -1.0, 1.0) * 32767
            return np.rint(scaled).astype(np.int16).tobytes()
        except Exception as e:
            # A single bad frame must not take the receive loop down.
            logger.error(f"Frame conversion error: {e}")
            return b""

    async def audio_generator(self):
        """
        Async generator yielding audio chunks for voice pipeline.

        Yields chunks while the receive loop is running, then drains whatever
        is still queued before finishing.

        Yields:
            bytes: PCM audio chunks (16-bit, 16kHz, mono)
        """
        while self._running:
            try:
                # The timeout lets the loop notice that _running was cleared.
                chunk = await asyncio.wait_for(self._queue.get(), timeout=1.0)
                yield chunk
            except asyncio.TimeoutError:
                # No audio received, but keep waiting if still running
                continue
            except Exception as e:
                logger.error(f"Audio generator error: {e}")
                break

        # Drain remaining queue
        while not self._queue.empty():
            try:
                chunk = self._queue.get_nowait()
                yield chunk
            except asyncio.QueueEmpty:
                break

    async def stop(self):
        """Stop receiving audio and wait for the receive task to finish."""
        self._running = False
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
        logger.info("AudioInputTrack stopped")


class AudioOutputTrack(MediaStreamTrack):
    """
    MediaStreamTrack that sends audio from voice pipeline to WebRTC.

    Receives PCM audio from TTS (24kHz) and converts to WebRTC format (48kHz).
    Frames are handed to aiortc at real-time pace (one 20ms frame per 20ms);
    silence is sent whenever no audio is queued.
    """

    kind = "audio"

    def __init__(self, max_queue_frames: int = _DEFAULT_OUTPUT_QUEUE_FRAMES):
        """
        Args:
            max_queue_frames: Maximum number of 20ms frames buffered for
                sending; the oldest frame is dropped once this is exceeded
        """
        super().__init__()
        self._queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=max_queue_frames)
        self._timestamp = 0
        self._sample_rate = WEBRTC_SAMPLE_RATE
        self._samples_per_frame = SAMPLES_PER_FRAME
        self._silence_frame = self._create_silence_frame()
        # Resampled audio that does not yet fill a whole frame; carried over
        # to the next add_audio() call so speech is not padded mid-stream.
        self._pending = bytearray()
        # Event-loop time at which the first frame was sent (pacing origin).
        self._start_time: Optional[float] = None

    def _create_silence_frame(self) -> bytes:
        """Create a silence frame for when no audio is available."""
        return bytes(self._samples_per_frame * _BYTES_PER_SAMPLE)

    async def add_audio(self, pcm_data: bytes, source_sample_rate: int = TTS_SAMPLE_RATE):
        """
        Add PCM audio to be sent via WebRTC.

        Audio is resampled to 48kHz and cut into 20ms frames. A trailing
        partial frame is kept and completed by the next call (or flushed by
        recv() once the queue runs dry), so consecutive calls join seamlessly.

        Args:
            pcm_data: PCM 16-bit audio bytes
            source_sample_rate: Sample rate of input (TTS is typically 24kHz)
        """
        if not pcm_data:
            return

        # A dangling odd byte cannot be a 16-bit sample; keeping it would
        # shift every later sample in the carry-over buffer by one byte.
        if len(pcm_data) % _BYTES_PER_SAMPLE:
            logger.debug("Dropping trailing odd byte from PCM input")
            pcm_data = pcm_data[: len(pcm_data) - 1]
            if not pcm_data:
                return

        # Resample if needed (24kHz -> 48kHz for WebRTC)
        if source_sample_rate != self._sample_rate:
            pcm_data = self._resample(pcm_data, source_sample_rate, self._sample_rate)

        self._pending.extend(pcm_data)

        # Queue every complete frame; the remainder stays in _pending.
        frame_bytes = self._samples_per_frame * _BYTES_PER_SAMPLE
        complete = len(self._pending) - len(self._pending) % frame_bytes
        for offset in range(0, complete, frame_bytes):
            chunk = bytes(self._pending[offset:offset + frame_bytes])
            _put_dropping_oldest(self._queue, chunk)
        del self._pending[:complete]

    def _resample(self, pcm_data: bytes, src_rate: int, dst_rate: int) -> bytes:
        """
        Resample audio from source to destination sample rate.

        Interpolation is done directly on the int16 sample values and rounded,
        so audio that needs no rate change is returned bit-exact. A trailing
        odd byte (incomplete sample) is ignored.

        Args:
            pcm_data: PCM 16-bit audio bytes
            src_rate: Source sample rate
            dst_rate: Destination sample rate

        Returns:
            Resampled PCM bytes; the input is returned unchanged if
            resampling fails (e.g. a non-positive sample rate)
        """
        try:
            sample_count = len(pcm_data) // _BYTES_PER_SAMPLE
            if sample_count == 0:
                return b""

            audio = np.frombuffer(pcm_data, dtype=np.int16, count=sample_count)
            resampled = _linear_resample(audio.astype(np.float64), src_rate, dst_rate)
            if len(resampled) == 0:
                return b""

            # Convert back to 16-bit PCM
            return np.clip(np.rint(resampled), -32768, 32767).astype(np.int16).tobytes()
        except (ValueError, TypeError) as e:
            logger.error(f"Resample error: {e}")
            return pcm_data

    async def _wait_for_frame_slot(self) -> None:
        """Sleep until the next frame is due according to the media clock.

        aiortc calls recv() in a tight loop and does not throttle it, so the
        track itself must release frames in real time. The frame with
        presentation timestamp ``pts`` is due ``pts / sample_rate`` seconds
        after the first frame was sent.
        """
        loop = asyncio.get_running_loop()
        if self._start_time is None:
            self._start_time = loop.time()
            return
        due = self._start_time + self._timestamp / self._sample_rate
        delay = due - loop.time()
        if delay > 0:
            await asyncio.sleep(delay)

    def _next_frame_bytes(self) -> bytes:
        """Return exactly one frame of PCM: queued audio, tail, or silence."""
        frame_bytes = self._samples_per_frame * _BYTES_PER_SAMPLE
        try:
            pcm_data = self._queue.get_nowait()
        except asyncio.QueueEmpty:
            if not self._pending:
                return self._silence_frame
            # Underrun: flush the partial tail so the end of an utterance
            # is not held back waiting for audio that may never arrive.
            pcm_data = bytes(self._pending)
            self._pending.clear()

        # Ensure correct frame size
        if len(pcm_data) < frame_bytes:
            # Pad with silence
            pcm_data = pcm_data + bytes(frame_bytes - len(pcm_data))
        elif len(pcm_data) > frame_bytes:
            # Truncate
            pcm_data = pcm_data[:frame_bytes]
        return pcm_data

    async def recv(self) -> AudioFrame:
        """
        Called by aiortc to get the next audio frame.

        Waits until the frame is due, then returns queued audio or silence
        if none is available.

        Returns:
            AudioFrame ready to be sent via WebRTC

        Raises:
            MediaStreamError: If the frame cannot be produced
        """
        try:
            await self._wait_for_frame_slot()
            pcm_data = self._next_frame_bytes()

            # Create AudioFrame
            frame = AudioFrame(format="s16", layout="mono", samples=self._samples_per_frame)
            frame.sample_rate = self._sample_rate
            frame.pts = self._timestamp
            frame.time_base = fractions.Fraction(1, self._sample_rate)

            # Copy data to frame
            frame.planes[0].update(pcm_data)

            self._timestamp += self._samples_per_frame
            return frame
        except Exception as e:
            logger.error(f"Audio output recv error: {e}")
            raise MediaStreamError() from e

    def clear_queue(self):
        """Clear the audio queue (e.g., on interrupt).

        Also discards the partial frame carried over from add_audio(), so no
        stale audio is played after the interrupt.
        """
        while not self._queue.empty():
            try:
                self._queue.get_nowait()
            except asyncio.QueueEmpty:
                break
        self._pending.clear()