/** * LFM2-Audio Model Runner for ONNX Runtime Web * * Runs audio model inference using ONNX models: * 1. decoder.onnx - LFM2 backbone (shared with text) * 2. audio_encoder.onnx - Conformer encoder for ASR (mel → embeddings) * 3. audio_embedding.onnx - Audio code embeddings for TTS * 4. audio_detokenizer.onnx - Audio codes → STFT features * 5. vocoder_depthformer.onnx - Autoregressive codebook prediction * * Supports ASR mode for the webapp (transcription). */ import * as ort from 'onnxruntime-web'; import { AutoTokenizer, env } from '@huggingface/transformers'; import { loadMelConfig, computeMelSpectrogram, loadAudioFile } from './audio-processor.js'; // Cache configuration const CACHE_NAME = 'onnx-models-v1'; const IDB_NAME = 'onnx-model-cache'; const IDB_STORE = 'models'; // IndexedDB helpers for fallback caching let idbPromise = null; function openIDB() { if (idbPromise) return idbPromise; idbPromise = new Promise((resolve, reject) => { const request = indexedDB.open(IDB_NAME, 1); request.onerror = () => reject(request.error); request.onsuccess = () => resolve(request.result); request.onupgradeneeded = (event) => { const db = event.target.result; if (!db.objectStoreNames.contains(IDB_STORE)) { db.createObjectStore(IDB_STORE); } }; }); return idbPromise; } async function idbGet(key) { try { const db = await openIDB(); return new Promise((resolve, reject) => { const tx = db.transaction(IDB_STORE, 'readonly'); const store = tx.objectStore(IDB_STORE); const request = store.get(key); request.onerror = () => reject(request.error); request.onsuccess = () => resolve(request.result); }); } catch (e) { return null; } } async function idbSet(key, value) { try { const db = await openIDB(); return new Promise((resolve, reject) => { const tx = db.transaction(IDB_STORE, 'readwrite'); const store = tx.objectStore(IDB_STORE); const request = store.put(value, key); request.onerror = () => reject(request.error); request.onsuccess = () => resolve(); }); } catch (e) { // Ignore cache write failures } } // Special tokens for audio model const SPECIAL_TOKENS = { AUDIO_START: 128, // <|audio_start|> TEXT_START: 129, // <|text_start|> TEXT_END: 130, // <|text_end|> MIXED_START: 131, // <|mixed_start|> MIXED_END: 132, // <|mixed_end|> IM_END: 7, // <|im_end|> }; // Audio codebook constants const NUM_CODEBOOKS = 8; const CODEBOOK_VOCAB = 2049; const END_OF_AUDIO_TOKEN = 2048; // Default system prompts (matching Python lfm2-audio-infer) const DEFAULT_SYSTEM_PROMPT_ASR = 'Perform ASR.'; const DEFAULT_SYSTEM_PROMPT_TTS = 'Perform TTS. Use the UK female voice.'; const DEFAULT_SYSTEM_PROMPT_INTERLEAVED = 'Respond with interleaved text and audio.'; // Max tokens defaults (matching liquid-audio) // Each audio frame = 80ms (6x upsampling in detokenizer, 320 hop, 24kHz) // 1024 frames ≈ 82 seconds of audio const DEFAULT_MAX_TOKENS_AUDIO = 1024; // TTS and interleaved modes const DEFAULT_MAX_TOKENS_TEXT = 100; // ASR mode // Timestamped logging helper let _logStartTime = null; function log(...args) { if (_logStartTime === null) { _logStartTime = performance.now(); } const elapsed = ((performance.now() - _logStartTime) / 1000).toFixed(2); console.log(`[${elapsed}s]`, ...args); } function logReset() { _logStartTime = performance.now(); } /** * Fetch with caching support using Cache API or IndexedDB fallback */ async function fetchWithCache(url, options = {}) { if (!url.startsWith('http://') && !url.startsWith('https://')) { return fetch(url, options); } const fileName = url.split('/').pop(); // Try Cache API first if (typeof caches !== 'undefined') { try { const cache = await caches.open(CACHE_NAME); const cached = await cache.match(url); if (cached) { console.log(`[Cache HIT] ${fileName}`); return cached; } console.log(`[Cache MISS] Fetching ${fileName}...`); const response = await fetch(url, options); if (response.ok) { cache.put(url, response.clone()); } return response; } catch (e) { // Fall through to IndexedDB } } // Try IndexedDB fallback if (typeof indexedDB !== 'undefined') { try { const cached = await idbGet(url); if (cached) { console.log(`[IDB Cache HIT] ${fileName}`); return new Response(cached.data, { status: 200, headers: { 'Content-Type': cached.contentType || 'application/octet-stream' }, }); } console.log(`[IDB Cache MISS] Fetching ${fileName}...`); const response = await fetch(url, options); if (response.ok) { const clone = response.clone(); const data = await clone.arrayBuffer(); const contentType = response.headers.get('Content-Type') || 'application/octet-stream'; await idbSet(url, { data, contentType }); } return response; } catch (e) { console.warn('IndexedDB cache failed:', e); } } // Direct fetch as last resort console.log(`[No Cache] Fetching ${fileName}...`); return fetch(url, options); } /** * Clear the model cache (both Cache API and IndexedDB) */ export async function clearModelCache() { let deleted = false; // Clear Cache API if (typeof caches !== 'undefined') { try { deleted = await caches.delete(CACHE_NAME); } catch (e) { // Ignore } } // Clear IndexedDB if (typeof indexedDB !== 'undefined') { try { await new Promise((resolve, reject) => { const request = indexedDB.deleteDatabase(IDB_NAME); request.onerror = () => reject(request.error); request.onsuccess = () => resolve(); }); idbPromise = null; // Reset the cached promise deleted = true; } catch (e) { // Ignore } } console.log(deleted ? 'Model cache cleared' : 'No cache to clear'); return deleted; } /** * Get cache storage usage info */ export async function getCacheInfo() { if ('storage' in navigator && 'estimate' in navigator.storage) { const estimate = await navigator.storage.estimate(); return { used: estimate.usage || 0, available: estimate.quota || 0, }; } return null; } /** * Load tokenizer from model path */ async function loadTokenizerFromPath(modelPath) { const isRemote = modelPath.startsWith('http://') || modelPath.startsWith('https://'); console.log(`Loading tokenizer from ${isRemote ? 'remote' : 'local'}: ${modelPath}`); const fetchOptions = isRemote ? { mode: 'cors', credentials: 'omit' } : {}; const [tokenizerResponse, configResponse] = await Promise.all([ fetchWithCache(`${modelPath}/tokenizer.json`, fetchOptions), fetchWithCache(`${modelPath}/tokenizer_config.json`, fetchOptions), ]); if (!tokenizerResponse.ok) { throw new Error(`Failed to fetch tokenizer.json: ${tokenizerResponse.status}`); } if (!configResponse.ok) { throw new Error(`Failed to fetch tokenizer_config.json: ${configResponse.status}`); } const tokenizerJSON = await tokenizerResponse.text(); const configJSON = await configResponse.text(); // Parse tokenizer.json to extract special token IDs const tokenizerData = JSON.parse(tokenizerJSON); const specialTokens = {}; if (tokenizerData.added_tokens) { for (const token of tokenizerData.added_tokens) { specialTokens[token.content] = token.id; } console.log('Found special tokens:', Object.keys(specialTokens).length); } // Create tokenizer using transformers.js const fakeModelId = `tokenizer-${Date.now()}`; const fileCache = { 'tokenizer.json': tokenizerJSON, 'tokenizer_config.json': configJSON, }; const originalFetch = globalThis.fetch; globalThis.fetch = async (input, init) => { const url = typeof input === 'string' ? input : input.url; if (url.includes(fakeModelId)) { for (const [filename, content] of Object.entries(fileCache)) { if (url.includes(filename)) { return new Response(content, { status: 200, headers: { 'Content-Type': 'application/json' }, }); } } return new Response('Not found', { status: 404 }); } return originalFetch(input, init); }; const originalAllowLocal = env.allowLocalModels; env.allowLocalModels = false; try { const tokenizer = await AutoTokenizer.from_pretrained(fakeModelId); console.log('Tokenizer created successfully'); return { tokenizer, specialTokens }; } finally { globalThis.fetch = originalFetch; env.allowLocalModels = originalAllowLocal; } } export class AudioModel { constructor() { this.tokenizer = null; this.decoderSession = null; this.audioEncoderSession = null; this.audioEmbeddingSession = null; this.audioEmbeddingWeight = null; // Direct lookup (faster than ONNX) this.audioDetokenizerSession = null; this.vocoderSession = null; this.config = null; this.embedTokensWeight = null; // Model config this.hiddenSize = 2048; this.numLayers = 16; this.numKVHeads = 8; this.headDim = 64; this.convL = 3; this.layerTypes = []; this.vocabSize = 65536; // === Stateful cache for multi-turn conversation === this.cache = null; this.cacheSeqLen = 0; } /** * Reset conversation state (KV cache). * Call this to start a new conversation. */ reset() { this.cache = null; this.cacheSeqLen = 0; log('Conversation state reset'); } /** * Load the audio model from a directory * @param {string} modelPath - Path to model directory * @param {object} options - Loading options */ async load(modelPath, options = {}) { const { progressCallback, device = 'webgpu', quantization = null } = options; const report = (status, progress = 0, file = '') => { if (progressCallback) { progressCallback({ status, progress, file }); } }; const executionProviders = device === 'webgpu' ? ['webgpu', 'wasm'] : ['wasm']; try { // Load mel config for audio processing await loadMelConfig(modelPath); // Load tokenizer report('loading', 0, 'tokenizer'); const { tokenizer } = await loadTokenizerFromPath(modelPath); this.tokenizer = tokenizer; // Load config report('loading', 5, 'config'); const configResponse = await fetch(`${modelPath}/config.json`, { mode: 'cors', credentials: 'omit', }); this.config = await configResponse.json(); // Extract model dimensions from config const lfmConfig = this.config.lfm || {}; this.hiddenSize = lfmConfig.hidden_size || 2048; this.numLayers = lfmConfig.num_hidden_layers || 16; this.numKVHeads = lfmConfig.num_key_value_heads || 8; this.headDim = Math.floor(this.hiddenSize / (lfmConfig.num_attention_heads || 32)); this.convL = lfmConfig.conv_L_cache || 3; this.layerTypes = lfmConfig.layer_types || []; this.vocabSize = lfmConfig.vocab_size || 65536; console.log('Model config:', { hiddenSize: this.hiddenSize, numLayers: this.numLayers, numKVHeads: this.numKVHeads, headDim: this.headDim, }); // Parse quantization config const quantConfig = typeof quantization === 'object' ? quantization : { decoder: quantization, audioEncoder: quantization, audioEmbedding: quantization, audioDetokenizer: quantization, vocoder: quantization, }; // Helper to load ONNX model with external data const loadOnnxWithExternalData = async (name, progress, quantSuffix = null, extraOptions = {}) => { const suffix = quantSuffix ? `_${quantSuffix}` : ''; const fileName = `${name}${suffix}`; report('loading', progress, `${fileName}.onnx`); const onnxPath = `${modelPath}/onnx/${fileName}.onnx`; const fetchOptions = { mode: 'cors', credentials: 'omit' }; console.log(`Loading ${fileName}...`); const sessionOptions = { executionProviders, ...extraOptions }; const onnxResponse = await fetchWithCache(onnxPath, fetchOptions); if (!onnxResponse.ok) { throw new Error(`Failed to fetch ${fileName}.onnx: ${onnxResponse.status}`); } const onnxBuffer = await onnxResponse.arrayBuffer(); console.log(`Loaded ${fileName}.onnx: ${(onnxBuffer.byteLength / 1024 / 1024).toFixed(1)} MB`); // Load external data files sessionOptions.externalData = []; // Try single .onnx_data file const singleDataPath = `${modelPath}/onnx/${fileName}.onnx_data`; try { const dataResponse = await fetchWithCache(singleDataPath, fetchOptions); const contentType = dataResponse.headers.get('content-type') || ''; if (dataResponse.ok && !contentType.includes('text/html')) { const dataBuffer = await dataResponse.arrayBuffer(); if (dataBuffer.byteLength > 1000) { // Sanity check console.log(`Loaded ${fileName}.onnx_data: ${(dataBuffer.byteLength / 1024 / 1024).toFixed(1)} MB`); sessionOptions.externalData.push({ path: `${fileName}.onnx_data`, data: new Uint8Array(dataBuffer), }); } } } catch (e) { // File doesn't exist } // Try numbered files - stop on first 404 for (let i = 1; ; i++) { const numberedDataPath = `${modelPath}/onnx/${fileName}.onnx_data_${i}`; const dataResponse = await fetch(numberedDataPath, fetchOptions); if (dataResponse.status === 404 || !dataResponse.ok) break; const contentType = dataResponse.headers.get('content-type') || ''; if (contentType.includes('text/html')) break; const dataBuffer = await dataResponse.arrayBuffer(); if (dataBuffer.byteLength < 1000) break; console.log(`Loaded ${fileName}.onnx_data_${i}: ${(dataBuffer.byteLength / 1024 / 1024).toFixed(1)} MB`); sessionOptions.externalData.push({ path: `${fileName}.onnx_data_${i}`, data: new Uint8Array(dataBuffer), }); } if (sessionOptions.externalData.length === 0) { delete sessionOptions.externalData; } const session = await ort.InferenceSession.create(new Uint8Array(onnxBuffer), sessionOptions); console.log(`Session created for ${fileName}`); return session; }; // Load decoder // On WebGPU: keep KV cache outputs on GPU to avoid GPU→CPU→GPU roundtrips between steps const decoderOpts = device === 'webgpu' ? (() => { const loc = {}; for (let i = 0; i < this.layerTypes.length; i++) { if (this.layerTypes[i] === 'conv') { loc[`present_conv.${i}`] = 'gpu-buffer'; } else { loc[`present.${i}.key`] = 'gpu-buffer'; loc[`present.${i}.value`] = 'gpu-buffer'; } } return { preferredOutputLocation: loc }; })() : {}; this.decoderSession = await loadOnnxWithExternalData('decoder', 10, quantConfig.decoder, decoderOpts); // Load embed_tokens weight for text embedding lookup report('loading', 30, 'embed_tokens'); this.embedTokensWeight = await this.loadEmbedTokensWeight(modelPath); // Load audio encoder (for ASR) this.audioEncoderSession = await loadOnnxWithExternalData('audio_encoder', 50, quantConfig.audioEncoder); // Load audio embedding (for TTS) - try binary first, fallback to ONNX report('loading', 65, 'audio_embedding'); this.audioEmbeddingWeight = await this.loadAudioEmbeddingWeight(modelPath); if (!this.audioEmbeddingWeight) { // Fallback to ONNX model this.audioEmbeddingSession = await loadOnnxWithExternalData('audio_embedding', 70, quantConfig.audioEmbedding); } else { console.log('Using direct audio embedding lookup (binary)'); } // Load audio detokenizer (for TTS output) try { this.audioDetokenizerSession = await loadOnnxWithExternalData('audio_detokenizer', 85, quantConfig.audioDetokenizer); } catch (e) { console.warn('Audio detokenizer not available:', e); } // Load vocoder (for TTS) // On WebGPU: keep KV cache on GPU to avoid GPU→CPU→GPU roundtrips between steps try { const vocoderOpts = device === 'webgpu' ? { preferredOutputLocation: { new_keys: 'gpu-buffer', new_values: 'gpu-buffer', depth_slices: 'gpu-buffer' } } : {}; this.vocoderSession = await loadOnnxWithExternalData('vocoder_depthformer', 95, quantConfig.vocoder, vocoderOpts); } catch (e) { console.warn('Vocoder not available:', e); } report('done', 100, ''); return true; } catch (error) { let errorMessage = error; if (typeof error === 'number') { errorMessage = `ONNX Runtime error code: ${error}`; } else if (error instanceof Error) { errorMessage = error.message; } console.error('Failed to load audio model:', errorMessage); throw new Error(errorMessage); } } /** * Get audio embeddings for given token indices and sum across codebooks * * Uses direct array indexing if binary weight available (fast), * falls back to ONNX session otherwise. * * @param {number[]} audioTokens - Array of 8 token indices (one per codebook) * @returns {Float32Array} Summed embedding [hiddenSize] */ async getAudioEmbedding(audioTokens) { const NUM_CODEBOOKS = 8; const hiddenSize = this.hiddenSize; const summedEmbeds = new Float32Array(hiddenSize); if (this.audioEmbeddingWeight) { // Direct lookup (much faster - no ONNX call) const weight = this.audioEmbeddingWeight.weight; for (let cb = 0; cb < NUM_CODEBOOKS; cb++) { const tokenIdx = audioTokens[cb]; const offset = tokenIdx * hiddenSize; for (let h = 0; h < hiddenSize; h++) { summedEmbeds[h] += weight[offset + h]; } } } else { // Fallback to ONNX session const audioTokensTensor = new ort.Tensor('int64', new BigInt64Array(audioTokens.map(BigInt)), [1, NUM_CODEBOOKS]); const result = await this.audioEmbeddingSession.run({ audio_codes: audioTokensTensor }); const embeddings = result.audio_embeds.data; for (let cb = 0; cb < NUM_CODEBOOKS; cb++) { for (let h = 0; h < hiddenSize; h++) { summedEmbeds[h] += embeddings[cb * hiddenSize + h]; } } } return summedEmbeds; } /** * Load audio_embedding.weight from raw binary file for direct lookup * * This eliminates ONNX model calls (352 per generation → 0). * Falls back to ONNX session if binary not available. */ async loadAudioEmbeddingWeight(modelPath) { const fetchOptions = { mode: 'cors', credentials: 'omit' }; try { // Load metadata const metaResponse = await fetchWithCache(`${modelPath}/onnx/audio_embedding.json`, fetchOptions); if (!metaResponse.ok) { console.log('audio_embedding.json not found, will use ONNX model'); return null; } const meta = await metaResponse.json(); console.log('audio_embedding metadata:', meta); // Load binary weight const binResponse = await fetchWithCache(`${modelPath}/onnx/audio_embedding.bin`, fetchOptions); if (!binResponse.ok) { console.log('audio_embedding.bin not found, will use ONNX model'); return null; } const buffer = await binResponse.arrayBuffer(); const weight = new Float32Array(buffer); if (weight.length !== meta.vocab_size * meta.hidden_size) { console.error('audio_embedding size mismatch:', weight.length, 'expected:', meta.vocab_size * meta.hidden_size); return null; } console.log(`Loaded audio_embedding: [${meta.vocab_size}, ${meta.hidden_size}] (${(buffer.byteLength / 1e6).toFixed(1)} MB)`); return { weight, vocabSize: meta.vocab_size, hiddenSize: meta.hidden_size }; } catch (e) { console.log('Failed to load audio_embedding.bin:', e); return null; } } /** * Load embed_tokens.weight from raw binary file for text embedding lookup * * The Python export saves embed_tokens.weight as: * - embed_tokens.bin: raw float32 binary (vocab_size * hidden_size * 4 bytes) * - embed_tokens.json: metadata (vocab_size, hidden_size) */ async loadEmbedTokensWeight(modelPath) { const fetchOptions = { mode: 'cors', credentials: 'omit' }; // Load metadata const metaResponse = await fetchWithCache(`${modelPath}/onnx/embed_tokens.json`, fetchOptions); if (!metaResponse.ok) { console.warn('embed_tokens.json not found, TTS will be unavailable'); return null; } const meta = await metaResponse.json(); console.log('embed_tokens metadata:', meta); // Load binary weight const binResponse = await fetchWithCache(`${modelPath}/onnx/embed_tokens.bin`, fetchOptions); if (!binResponse.ok) { console.warn('embed_tokens.bin not found, TTS will be unavailable'); return null; } const buffer = await binResponse.arrayBuffer(); const weight = new Float32Array(buffer); if (weight.length !== meta.vocab_size * meta.hidden_size) { console.error('embed_tokens size mismatch:', weight.length, 'expected:', meta.vocab_size * meta.hidden_size); return null; } console.log(`Loaded embed_tokens: [${meta.vocab_size}, ${meta.hidden_size}] (${(buffer.byteLength / 1e6).toFixed(1)} MB)`); return { weight, vocabSize: meta.vocab_size, hiddenSize: meta.hidden_size }; } /** * Get text embeddings for token IDs using pre-loaded weight * @param {number[]} tokenIds - Array of token IDs * @returns {ort.Tensor} - Embeddings tensor [1, seq_len, hidden_size] */ getTextEmbeddings(tokenIds) { if (!this.embedTokensWeight) { throw new Error('embed_tokens weight not loaded'); } const { weight, hiddenSize } = this.embedTokensWeight; const seqLen = tokenIds.length; const embeddings = new Float32Array(seqLen * hiddenSize); for (let i = 0; i < seqLen; i++) { const tokenId = tokenIds[i]; const srcOffset = tokenId * hiddenSize; const dstOffset = i * hiddenSize; embeddings.set(weight.subarray(srcOffset, srcOffset + hiddenSize), dstOffset); } return new ort.Tensor('float32', embeddings, [1, seqLen, hiddenSize]); } /** * Initialize KV cache for generation */ initializeCache() { const cache = {}; for (let idx = 0; idx < this.layerTypes.length; idx++) { const layerType = this.layerTypes[idx]; if (layerType === 'conv') { cache[`past_conv.${idx}`] = new ort.Tensor( 'float32', new Float32Array(1 * this.hiddenSize * this.convL), [1, this.hiddenSize, this.convL] ); } else { cache[`past_key_values.${idx}.key`] = new ort.Tensor( 'float32', new Float32Array(0), [1, this.numKVHeads, 0, this.headDim] ); cache[`past_key_values.${idx}.value`] = new ort.Tensor( 'float32', new Float32Array(0), [1, this.numKVHeads, 0, this.headDim] ); } } return cache; } /** * Update cache from decoder outputs */ updateCache(cache, outputs) { for (const name of Object.keys(outputs)) { if (name.startsWith('present_conv.')) { const cacheName = name.replace('present_conv', 'past_conv'); if (cacheName in cache) { cache[cacheName] = outputs[name]; } } else if (name.startsWith('present.')) { const cacheName = name.replace('present.', 'past_key_values.'); if (cacheName in cache) { cache[cacheName] = outputs[name]; } } } } /** * Run decoder with embeddings */ async runDecoder(embeds, attentionMask, cache) { const feeds = { inputs_embeds: embeds, attention_mask: attentionMask, ...cache, }; const outputs = await this.decoderSession.run(feeds); return { logits: outputs.logits, hiddenStates: outputs.hidden_states, outputs, }; } /** * Sample next token */ sampleToken(logits, temperature = 0.7) { if (temperature === 0) { let maxIdx = 0; let maxVal = logits[0]; for (let i = 1; i < logits.length; i++) { if (logits[i] > maxVal) { maxVal = logits[i]; maxIdx = i; } } return maxIdx; } // Temperature sampling const scaledLogits = new Float32Array(logits.length); let maxLogit = -Infinity; for (let i = 0; i < logits.length; i++) { scaledLogits[i] = logits[i] / temperature; maxLogit = Math.max(maxLogit, scaledLogits[i]); } let sumExp = 0; for (let i = 0; i < scaledLogits.length; i++) { scaledLogits[i] = Math.exp(scaledLogits[i] - maxLogit); sumExp += scaledLogits[i]; } const probs = new Float32Array(scaledLogits.length); for (let i = 0; i < probs.length; i++) { probs[i] = scaledLogits[i] / sumExp; } // Sample from distribution const r = Math.random(); let cumsum = 0; for (let i = 0; i < probs.length; i++) { cumsum += probs[i]; if (r < cumsum) return i; } return probs.length - 1; } /** * Transcribe audio to text (ASR mode) * @param {Float32Array} audioData - Audio samples in [-1, 1] * @param {number} sampleRate - Audio sample rate * @param {object} options - Generation options */ async transcribe(audioData, sampleRate, options = {}) { const { maxNewTokens = DEFAULT_MAX_TOKENS_TEXT, temperature = 0, systemPrompt = DEFAULT_SYSTEM_PROMPT_ASR, onToken, } = options; if (!this.audioEncoderSession) { throw new Error('Audio encoder not loaded'); } if (!this.embedTokensWeight) { throw new Error('embed_tokens not loaded - required for ASR'); } logReset(); log('=== ASR Transcription ==='); log('Audio samples:', audioData.length, 'Sample rate:', sampleRate); // 1. Compute mel spectrogram const { melFeatures, numFrames } = computeMelSpectrogram(audioData, sampleRate); log('Mel spectrogram frames:', numFrames); // 2. Run audio encoder const melTensor = new ort.Tensor( 'float32', melFeatures, [1, numFrames, 128] // [batch, time, n_mels] ); const melLengths = new ort.Tensor( 'int64', new BigInt64Array([BigInt(numFrames)]), [1] ); const encoderOutputs = await this.audioEncoderSession.run({ mel_spectrogram: melTensor, mel_lengths: melLengths, }); const audioEmbeds = encoderOutputs.audio_embeddings; log('Audio embeddings shape:', audioEmbeds.dims); // 3. Build prompt: prefix + audio + suffix const prefixText = `<|startoftext|><|im_start|>system\n${systemPrompt}<|im_end|>\n<|im_start|>user\n`; const suffixText = '<|im_end|>\n<|im_start|>assistant\n'; // Use add_special_tokens: false to match Python behavior (prompt already has special tokens) const prefixIds = Array.from(this.tokenizer.encode(prefixText, { add_special_tokens: false })); const suffixIds = Array.from(this.tokenizer.encode(suffixText, { add_special_tokens: false })); log('Prefix tokens:', prefixIds.length, 'Suffix tokens:', suffixIds.length); // Get text embeddings const prefixEmbeds = this.getTextEmbeddings(prefixIds); const suffixEmbeds = this.getTextEmbeddings(suffixIds); // 4. Concatenate embeddings: prefix + audio + suffix const prefixLen = prefixIds.length; const audioLen = audioEmbeds.dims[1]; const suffixLen = suffixIds.length; const totalLen = prefixLen + audioLen + suffixLen; const { hiddenSize } = this.embedTokensWeight; const allEmbeds = new Float32Array(totalLen * hiddenSize); // Copy prefix embeddings allEmbeds.set(prefixEmbeds.data, 0); // Copy audio embeddings allEmbeds.set(new Float32Array(audioEmbeds.data.buffer, audioEmbeds.data.byteOffset, audioLen * hiddenSize), prefixLen * hiddenSize); // Copy suffix embeddings allEmbeds.set(suffixEmbeds.data, (prefixLen + audioLen) * hiddenSize); const inputEmbeds = new ort.Tensor('float32', allEmbeds, [1, totalLen, hiddenSize]); const attentionMask = new ort.Tensor('int64', new BigInt64Array(totalLen).fill(1n), [1, totalLen]); // 5. Initialize cache and run prefill const cache = this.initializeCache(); let { logits, hiddenStates, outputs } = await this.runDecoder(inputEmbeds, attentionMask, cache); this.updateCache(cache, outputs); // 6. Generate tokens const generatedTokens = []; let currentLen = totalLen; for (let i = 0; i < maxNewTokens; i++) { // Get logits for last position - shape is [1, seq_len, vocab_size] const logitsData = logits.data; const seqLen = logits.dims[1]; const lastLogits = new Float32Array(this.vocabSize); const offset = (seqLen - 1) * this.vocabSize; for (let j = 0; j < this.vocabSize; j++) { lastLogits[j] = logitsData[offset + j]; } const nextToken = this.sampleToken(lastLogits, temperature); // Check for stop tokens if (nextToken === this.tokenizer.eos_token_id || nextToken === SPECIAL_TOKENS.IM_END) { log('Stop token reached'); break; } generatedTokens.push(nextToken); if (onToken) { const text = this.tokenizer.decode(generatedTokens); onToken(text, nextToken); } // Get embedding for next token const nextEmbeds = this.getTextEmbeddings([nextToken]); currentLen++; const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]); // Run decoder with single token ({ logits, hiddenStates, outputs } = await this.runDecoder(nextEmbeds, nextMask, cache)); this.updateCache(cache, outputs); } const result = this.tokenizer.decode(generatedTokens); log(`Generated ${generatedTokens.length} tokens: "${result}"`); return result; } /** * Generate response from messages * @param {Array} messages - Chat messages * @param {object} options - Generation options */ async generate(messages, options = {}) { const { maxNewTokens = 256, onToken, audioData = null, sampleRate = null } = options; // If audio data provided, do ASR if (audioData && sampleRate) { return this.transcribe(audioData, sampleRate, { maxNewTokens, onToken, }); } // Text-only generation (simplified) const prompt = this.tokenizer.apply_chat_template(messages, { add_generation_prompt: true, tokenize: false, }); const inputIds = this.tokenizer.encode(prompt); console.log('Input tokens:', inputIds.length); // Initialize cache const cache = this.initializeCache(); const generatedTokens = []; // Note: Full implementation needs proper text embedding support // This is a placeholder that shows the model is loaded return '[Text generation requires full embedding support - model loaded successfully]'; } /** * Initialize reusable vocoder tensors to reduce allocation overhead */ _initVocoderCache() { if (this._vocoderCache) return; const numLayers = 6; const numKvHeads = 8; const headDim = 32; // Pre-allocate data arrays const stepIdxData = new BigInt64Array(1); const prevTokenData = new BigInt64Array(1); const seqlensKData = new Int32Array(1); const totalSeqLenData = new Int32Array(1); // Pre-allocate tensors that can be reused this._vocoderCache = { hiddenTensor: null, // Created per-call since hiddenState changes stepIdxData, prevTokenData, seqlensKData, totalSeqLenData, // Pre-create reusable tensors (ONNX Runtime reads from the data array) stepIdxTensor: new ort.Tensor('int64', stepIdxData, []), prevTokenTensor: new ort.Tensor('int64', prevTokenData, [1]), seqlensKTensor: new ort.Tensor('int32', seqlensKData, [1]), totalSeqLenTensor: new ort.Tensor('int32', totalSeqLenData, []), emptyKeysData: new Float32Array(0), emptyValuesData: new Float32Array(0), emptyDepthSlicesData: new Float32Array(8 * 1024), // zeros for step 0 // Reusable sampling arrays scaledLogits: new Float32Array(2049), // codebook vocab size indices: new Uint16Array(2049), // Use typed array for faster reset probs: new Float32Array(64), // top-k size }; // Initialize indices for (let i = 0; i < 2049; i++) { this._vocoderCache.indices[i] = i; } } /** * Sample audio codes using vocoder depthformer * Optimized to reduce tensor creation overhead * @param {Float32Array} hiddenState - [hidden_size] hidden state * @param {number} temperature - Sampling temperature * @param {number} topK - Top-k sampling * @returns {number[]} - 8 codebook values */ async sampleAudioCodes(hiddenState, temperature = 0.8, topK = 64) { if (!this.vocoderSession) { throw new Error('Vocoder not loaded'); } // Initialize cache on first call this._initVocoderCache(); const cache = this._vocoderCache; const numCodebooks = 8; const numLayers = 6; const numKvHeads = 8; const headDim = 32; const codes = []; let prevToken = 0; // Create hidden state tensor (must be new since data changes) const hiddenTensor = new ort.Tensor('float32', hiddenState, [1, this.hiddenSize]); // Initialize empty KV cache let pastKeys = new ort.Tensor( 'float32', cache.emptyKeysData, [numLayers, 1, numKvHeads, 0, headDim] ); let pastValues = new ort.Tensor( 'float32', cache.emptyValuesData, [numLayers, 1, numKvHeads, 0, headDim] ); // Reuse step_idx and prev_token tensors by updating their data cache.stepIdxData[0] = 0n; cache.prevTokenData[0] = 0n; // depth_slices_in: zeros at step 0 (model ignores it), then fed back from output let depthSlicesIn = new ort.Tensor('float32', cache.emptyDepthSlicesData, [1, 8, 1024]); for (let i = 0; i < numCodebooks; i++) { // Update mutable tensor data (tensor objects reuse the underlying data arrays) cache.stepIdxData[0] = BigInt(i); cache.prevTokenData[0] = BigInt(prevToken); cache.seqlensKData[0] = i; cache.totalSeqLenData[0] = i + 1; const feeds = { hidden_states: hiddenTensor, depth_slices_in: depthSlicesIn, step_idx: cache.stepIdxTensor, prev_token: cache.prevTokenTensor, past_keys: pastKeys, past_values: pastValues, seqlens_k: cache.seqlensKTensor, total_seq_len: cache.totalSeqLenTensor, }; const outputs = await this.vocoderSession.run(feeds); depthSlicesIn = outputs.depth_slices; const logits = outputs.logits.data; const vocabSize = logits.length; // Sample with temperature and top-k (reusing cached arrays) let token; if (temperature <= 0) { // Greedy token = 0; let maxVal = logits[0]; for (let j = 1; j < vocabSize; j++) { if (logits[j] > maxVal) { maxVal = logits[j]; token = j; } } } else { // Top-k sampling with reused arrays const scaledLogits = cache.scaledLogits; const indices = cache.indices; const probs = cache.probs; // Scale logits by temperature and find top-k in single pass // Use partial selection sort (O(k*n) which is fast for small k) for (let j = 0; j < vocabSize; j++) { scaledLogits[j] = logits[j] / temperature; indices[j] = j; } // Partial sort to get top-k for (let j = 0; j < topK; j++) { let maxIdx = j; for (let k = j + 1; k < vocabSize; k++) { if (scaledLogits[indices[k]] > scaledLogits[indices[maxIdx]]) { maxIdx = k; } } // Swap const tmp = indices[j]; indices[j] = indices[maxIdx]; indices[maxIdx] = tmp; } // Softmax over top-k const maxLogit = scaledLogits[indices[0]]; let sumExp = 0; for (let j = 0; j < topK; j++) { probs[j] = Math.exp(scaledLogits[indices[j]] - maxLogit); sumExp += probs[j]; } for (let j = 0; j < topK; j++) { probs[j] /= sumExp; } // Sample const r = Math.random(); let cumsum = 0; token = indices[topK - 1]; // Default to last for (let j = 0; j < topK; j++) { cumsum += probs[j]; if (r < cumsum) { token = indices[j]; break; } } } codes.push(token); prevToken = token; // Update KV cache pastKeys = outputs.new_keys; pastValues = outputs.new_values; } return codes; } /** * Generate speech from text (TTS mode) * @param {string} text - Text to convert to speech * @param {object} options - Generation options * @returns {object} - { audioCodes: number[][], textTokens: number[] } */ async generateSpeech(text, options = {}) { const { maxNewTokens = DEFAULT_MAX_TOKENS_AUDIO, textTemperature = 0.7, audioTemperature = 0.8, audioTopK = 64, systemPrompt = DEFAULT_SYSTEM_PROMPT_TTS, onToken, onAudioFrame, } = options; logReset(); log('=== TTS Generation ==='); log('Text:', text); if (!this.embedTokensWeight) { throw new Error('embed_tokens not loaded - required for TTS'); } if (!this.vocoderSession) { throw new Error('Vocoder not loaded - required for TTS'); } // Build TTS prompt const prompt = `<|startoftext|><|im_start|>system\n${systemPrompt}<|im_end|>\n<|im_start|>user\n${text}<|im_end|>\n<|im_start|>assistant\n`; // Use add_special_tokens: false to match Python behavior (prompt already has special tokens) const inputIds = Array.from(this.tokenizer.encode(prompt, { add_special_tokens: false })); log('TTS prompt tokens:', inputIds.length); // Get embeddings and run prefill const inputEmbeds = this.getTextEmbeddings(inputIds); const cache = this.initializeCache(); const attentionMask = new ort.Tensor('int64', new BigInt64Array(inputIds.length).fill(1n), [1, inputIds.length]); let { logits, hiddenStates, outputs } = await this.runDecoder(inputEmbeds, attentionMask, cache); this.updateCache(cache, outputs); // Phase 1: Generate text until <|audio_start|> token const textTokens = []; let currentLen = inputIds.length; let inAudioMode = false; let tokensGenerated = 0; while (tokensGenerated < maxNewTokens && !inAudioMode) { const logitsData = logits.data; const seqLen = logits.dims[1]; // Get logits for last position - shape is [1, seq_len, vocab_size] const lastLogits = new Float32Array(this.vocabSize); const offset = (seqLen - 1) * this.vocabSize; for (let i = 0; i < this.vocabSize; i++) { lastLogits[i] = logitsData[offset + i]; } const nextToken = this.sampleToken(lastLogits, textTemperature); tokensGenerated++; if (nextToken === this.tokenizer.eos_token_id) { console.warn('Model produced EOS before audio, TTS may not work'); break; } if (nextToken === SPECIAL_TOKENS.AUDIO_START) { log('Model entered audio mode'); inAudioMode = true; // Feed audio_start token to get hidden states for first audio frame const nextEmbeds = this.getTextEmbeddings([SPECIAL_TOKENS.AUDIO_START]); currentLen++; const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]); ({ logits, hiddenStates, outputs } = await this.runDecoder(nextEmbeds, nextMask, cache)); this.updateCache(cache, outputs); break; } textTokens.push(nextToken); // Continue text generation const nextEmbeds = this.getTextEmbeddings([nextToken]); currentLen++; const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]); ({ logits, hiddenStates, outputs } = await this.runDecoder(nextEmbeds, nextMask, cache)); this.updateCache(cache, outputs); } if (!inAudioMode) { console.warn('Model did not enter audio mode, forcing audio generation'); // Force audio start token const nextEmbeds = this.getTextEmbeddings([SPECIAL_TOKENS.AUDIO_START]); currentLen++; const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]); ({ logits, hiddenStates, outputs } = await this.runDecoder(nextEmbeds, nextMask, cache)); this.updateCache(cache, outputs); tokensGenerated++; } // Phase 2: Generate audio frames using depthformer const audioCodes = []; const startTime = performance.now(); while (tokensGenerated < maxNewTokens) { // Get hidden state for last position const hiddenData = hiddenStates.data; const seqLen = hiddenStates.dims[1]; const lastHidden = hiddenData.slice((seqLen - 1) * this.hiddenSize, seqLen * this.hiddenSize); // Sample audio codes const frameCodes = await this.sampleAudioCodes(lastHidden, audioTemperature, audioTopK); // Check for end-of-audio // Only check first codebook (matching liquid-audio TTS behavior) if (frameCodes[0] >= END_OF_AUDIO_TOKEN) { log(`End of audio at frame ${audioCodes.length}`); break; } // Log progress periodically if (audioCodes.length % 50 === 0) { log(`Generated ${audioCodes.length} audio frames`); } audioCodes.push(frameCodes); tokensGenerated++; if (onAudioFrame) { onAudioFrame(frameCodes, audioCodes.length); } // Feed back audio codes to continue generation // Audio embedding expects tokens in range [0, 16392) where: // token = codebook_idx * 2049 + code_value const clampedCodes = frameCodes.map(c => Math.min(c, 2047)); const audioTokens = clampedCodes.map((code, idx) => idx * CODEBOOK_VOCAB + code); // Get summed embeddings for all 8 codebooks const summedEmbeds = await this.getAudioEmbedding(audioTokens); const nextEmbeds = new ort.Tensor('float32', summedEmbeds, [1, 1, this.hiddenSize]); currentLen++; const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]); ({ logits, hiddenStates, outputs } = await this.runDecoder(nextEmbeds, nextMask, cache)); this.updateCache(cache, outputs); } const elapsed = (performance.now() - startTime) / 1000; const framesPerSec = audioCodes.length / elapsed; log(`Generated ${audioCodes.length} audio frames in ${elapsed.toFixed(2)}s (${framesPerSec.toFixed(1)} frames/s)`); const textOutput = textTokens.length > 0 ? this.tokenizer.decode(textTokens) : ''; return { audioCodes, textTokens, textOutput }; } /** * Generate interleaved response (mixed text/audio) with stateful KV cache. * * The cache is preserved between calls for multi-turn conversation. * Call reset() to start a new conversation. * * @param {Float32Array} audioData - Input audio samples * @param {number} sampleRate - Audio sample rate * @param {string} textPrompt - Optional text prompt (unused, for API compatibility) * @param {object} options - Generation options * @returns {object} - { text: string, audioCodes: number[][] } */ async generateInterleaved(audioData, sampleRate, textPrompt = '', options = {}) { const { maxNewTokens = DEFAULT_MAX_TOKENS_AUDIO, textTemperature = 1.0, audioTemperature = 1.0, audioTopK = 4, systemPrompt = DEFAULT_SYSTEM_PROMPT_INTERLEAVED, onToken, onAudioFrame, } = options; // Counter-based mode switching (matching liquid-audio) const INTERLEAVED_N_TEXT = 6; const INTERLEAVED_N_AUDIO = 12; logReset(); log('=== Interleaved Generation ==='); log('Cache state:', this.cache ? `exists (seq_len=${this.cacheSeqLen})` : 'null (new conversation)'); log('Audio samples:', audioData.length, 'Sample rate:', sampleRate); if (!this.audioEncoderSession) { throw new Error('Audio encoder not loaded - required for interleaved mode'); } if (!this.embedTokensWeight) { throw new Error('embed_tokens not loaded - required for interleaved mode'); } if (!this.vocoderSession) { throw new Error('Vocoder not loaded - required for interleaved mode'); } // Timing accumulators let timeAudioEncode = 0; let timePrefill = 0; let timeTextDecode = 0; let timeAudioDecode = 0; let timeVocoder = 0; let timeAudioEmbed = 0; // 1. Compute mel spectrogram and encode audio let tStep = performance.now(); const { melFeatures, numFrames } = computeMelSpectrogram(audioData, sampleRate); const timeMel = performance.now() - tStep; const melTensor = new ort.Tensor('float32', melFeatures, [1, numFrames, 128]); const melLengths = new ort.Tensor('int64', new BigInt64Array([BigInt(numFrames)]), [1]); tStep = performance.now(); const encoderOutputs = await this.audioEncoderSession.run({ mel_spectrogram: melTensor, mel_lengths: melLengths, }); timeAudioEncode = performance.now() - tStep; const audioEmbeds = encoderOutputs.audio_embeddings; log(`Mel: ${timeMel.toFixed(0)}ms, AudioEnc: ${timeAudioEncode.toFixed(0)}ms, frames: ${numFrames}`); const { hiddenSize } = this.embedTokensWeight; // 2. Build prompt based on whether this is first turn or continuation let inputEmbeds; let newSeqLen; if (this.cache === null) { // === First turn: full prompt with system message === log('First turn - initializing conversation'); this.cache = this.initializeCache(); this.cacheSeqLen = 0; const prefixText = `<|startoftext|><|im_start|>system\n${systemPrompt}<|im_end|>\n<|im_start|>user\n`; const suffixText = '<|im_end|>\n<|im_start|>assistant\n'; const prefixIds = Array.from(this.tokenizer.encode(prefixText, { add_special_tokens: false })); const suffixIds = Array.from(this.tokenizer.encode(suffixText, { add_special_tokens: false })); const prefixEmbeds = this.getTextEmbeddings(prefixIds); const suffixEmbeds = this.getTextEmbeddings(suffixIds); const prefixLen = prefixIds.length; const audioLen = audioEmbeds.dims[1]; const suffixLen = suffixIds.length; newSeqLen = prefixLen + audioLen + suffixLen; const allEmbeds = new Float32Array(newSeqLen * hiddenSize); allEmbeds.set(prefixEmbeds.data, 0); allEmbeds.set( new Float32Array(audioEmbeds.data.buffer, audioEmbeds.data.byteOffset, audioLen * hiddenSize), prefixLen * hiddenSize ); allEmbeds.set(suffixEmbeds.data, (prefixLen + audioLen) * hiddenSize); inputEmbeds = new ort.Tensor('float32', allEmbeds, [1, newSeqLen, hiddenSize]); } else { // === Continuation: user turn only === log(`Continuing conversation (cache seq_len=${this.cacheSeqLen})`); const userPrefixText = '<|im_start|>user\n'; const suffixText = '<|im_end|>\n<|im_start|>assistant\n'; const userPrefixIds = Array.from(this.tokenizer.encode(userPrefixText, { add_special_tokens: false })); const suffixIds = Array.from(this.tokenizer.encode(suffixText, { add_special_tokens: false })); const userPrefixEmbeds = this.getTextEmbeddings(userPrefixIds); const suffixEmbeds = this.getTextEmbeddings(suffixIds); const userPrefixLen = userPrefixIds.length; const audioLen = audioEmbeds.dims[1]; const suffixLen = suffixIds.length; newSeqLen = userPrefixLen + audioLen + suffixLen; const allEmbeds = new Float32Array(newSeqLen * hiddenSize); allEmbeds.set(userPrefixEmbeds.data, 0); allEmbeds.set( new Float32Array(audioEmbeds.data.buffer, audioEmbeds.data.byteOffset, audioLen * hiddenSize), userPrefixLen * hiddenSize ); allEmbeds.set(suffixEmbeds.data, (userPrefixLen + audioLen) * hiddenSize); inputEmbeds = new ort.Tensor('float32', allEmbeds, [1, newSeqLen, hiddenSize]); } // 3. Run prefill with attention mask covering full sequence const totalLen = this.cacheSeqLen + newSeqLen; const attentionMask = new ort.Tensor('int64', new BigInt64Array(totalLen).fill(1n), [1, totalLen]); tStep = performance.now(); let { logits, hiddenStates, outputs } = await this.runDecoder(inputEmbeds, attentionMask, this.cache); timePrefill = performance.now() - tStep; this.updateCache(this.cache, outputs); this.cacheSeqLen = totalLen; log(`Prefill: ${timePrefill.toFixed(0)}ms, new tokens: ${newSeqLen}, total: ${totalLen}`); // 4. Generate with counter-based mode switching const textTokens = []; const audioCodes = []; let currentLen = totalLen; let inAudioMode = false; let modalityLeft = INTERLEAVED_N_TEXT; let textDone = false; const startTime = performance.now(); for (let step = 0; step < maxNewTokens; step++) { modalityLeft--; if (inAudioMode) { // Generate audio frame using depthformer const hiddenData = hiddenStates.data; const seqLen = hiddenStates.dims[1]; const lastHidden = hiddenData.slice((seqLen - 1) * hiddenSize, seqLen * hiddenSize); tStep = performance.now(); const frameCodes = await this.sampleAudioCodes(lastHidden, audioTemperature, audioTopK); timeVocoder += performance.now() - tStep; // Switch back to text after N audio frames (if text not done) if (modalityLeft <= 0 && !textDone) { inAudioMode = false; modalityLeft = INTERLEAVED_N_TEXT; } // Check for end of audio - first codebook == 2048 (matching liquid-audio) if (frameCodes[0] === END_OF_AUDIO_TOKEN) { log(`End of audio at step ${step}`); // Set all codes to 2048 (matching liquid-audio) for (let i = 0; i < NUM_CODEBOOKS; i++) { frameCodes[i] = END_OF_AUDIO_TOKEN; } inAudioMode = false; // Don't save this frame, but still feed it back } else { // Save valid frame (clamped to 0-2047) const clampedFrame = frameCodes.map(c => Math.min(c, 2047)); audioCodes.push(clampedFrame); if (onAudioFrame) { onAudioFrame(clampedFrame, audioCodes.length); } if (audioCodes.length % 50 === 0) { log(`Generated ${audioCodes.length} audio frames`); } } // Get embeddings for next step (always feed back, even for 2048 frames) tStep = performance.now(); const feedCodes = frameCodes.map(c => c === END_OF_AUDIO_TOKEN ? END_OF_AUDIO_TOKEN : Math.min(c, 2047)); const audioTokens = feedCodes.map((code, idx) => idx * CODEBOOK_VOCAB + code); // Get summed embeddings for all 8 codebooks const summedEmbeds = await this.getAudioEmbedding(audioTokens); timeAudioEmbed += performance.now() - tStep; const nextEmbeds = new ort.Tensor('float32', summedEmbeds, [1, 1, hiddenSize]); currentLen++; const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]); tStep = performance.now(); ({ logits, hiddenStates, outputs } = await this.runDecoder(nextEmbeds, nextMask, this.cache)); timeAudioDecode += performance.now() - tStep; this.updateCache(this.cache, outputs); } else { // Generate text token const logitsData = logits.data; const seqLen = logits.dims[1]; // Get logits for last position - shape is [1, seq_len, vocab_size] const lastLogits = new Float32Array(this.vocabSize); const offset = (seqLen - 1) * this.vocabSize; for (let i = 0; i < this.vocabSize; i++) { lastLogits[i] = logitsData[offset + i]; } const token = this.sampleToken(lastLogits, textTemperature); // Check for end of turn if (token === this.tokenizer.eos_token_id || token === SPECIAL_TOKENS.IM_END) { log(`End of turn at step ${step}`); break; } // Check for <|text_end|> token (130) if (token === SPECIAL_TOKENS.TEXT_END) { log(`Text end at step ${step}`); textDone = true; } // Switch to audio after N text tokens OR text_end if (modalityLeft <= 0 || textDone) { inAudioMode = true; modalityLeft = INTERLEAVED_N_AUDIO; } textTokens.push(token); if (onToken) { const decodedText = this.tokenizer.decode(textTokens, { skip_special_tokens: true }); onToken(decodedText, token); } // Get embedding for next step const nextEmbeds = this.getTextEmbeddings([token]); currentLen++; const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]); tStep = performance.now(); ({ logits, hiddenStates, outputs } = await this.runDecoder(nextEmbeds, nextMask, this.cache)); timeTextDecode += performance.now() - tStep; this.updateCache(this.cache, outputs); } } // 5. Feed <|im_end|> token to close assistant turn in cache const imEndEmbeds = this.getTextEmbeddings([SPECIAL_TOKENS.IM_END]); currentLen++; const finalMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]); ({ outputs } = await this.runDecoder(imEndEmbeds, finalMask, this.cache)); this.updateCache(this.cache, outputs); this.cacheSeqLen = currentLen; // Decode with skip_special_tokens to clean up special tokens like <|text_end|> const text = this.tokenizer.decode(textTokens, { skip_special_tokens: true }); // Print timing summary log(`=== Summary ===`); log(` Mel: ${timeMel.toFixed(0)}ms, AudioEnc: ${timeAudioEncode.toFixed(0)}ms, Prefill: ${timePrefill.toFixed(0)}ms`); log(` TextDec: ${timeTextDecode.toFixed(0)}ms (${textTokens.length} tok), AudioDec: ${timeAudioDecode.toFixed(0)}ms`); log(` Vocoder: ${timeVocoder.toFixed(0)}ms, AudioEmbed: ${timeAudioEmbed.toFixed(0)}ms`); log(`Output: ${textTokens.length} text tokens, ${audioCodes.length} audio frames`); log(`Text: "${text}"`); log(`Cache seq_len: ${this.cacheSeqLen}`); return { text, audioCodes }; } /** * Generate interleaved response from text-only input (continuation turn). * Uses the stateful KV cache from previous turns. Produces both text AND audio. * * @param {string} userText - User's text message * @param {object} options - Generation options * @returns {object} - { text: string, audioCodes: number[][] } */ async generateInterleavedFromText(userText, options = {}) { const { maxNewTokens = DEFAULT_MAX_TOKENS_AUDIO, textTemperature = 1.0, audioTemperature = 1.0, audioTopK = 4, systemPrompt = DEFAULT_SYSTEM_PROMPT_INTERLEAVED, onToken, onAudioFrame, } = options; // Counter-based mode switching (matching liquid-audio) const INTERLEAVED_N_TEXT = 6; const INTERLEAVED_N_AUDIO = 12; logReset(); log('=== Text-Only Interleaved Generation ==='); log(`Cache state: ${this.cache ? `exists (seq_len=${this.cacheSeqLen})` : 'null (new conversation)'}`); log(`User text: ${userText}`); if (!this.embedTokensWeight) { throw new Error('embed_tokens not loaded - required for text generation'); } if (!this.vocoderSession) { throw new Error('Vocoder not loaded - required for interleaved mode'); } // Timing accumulators let timePrefill = 0; let timeTextDecode = 0; let timeAudioDecode = 0; let timeVocoder = 0; let timeAudioEmbed = 0; let tStep; const { hiddenSize } = this.embedTokensWeight; // Build prompt based on whether this is first turn or continuation let inputEmbeds; let newSeqLen; if (this.cache === null) { // === First turn: full prompt with system message === log('First turn - initializing conversation'); this.cache = this.initializeCache(); this.cacheSeqLen = 0; const prefixText = `<|startoftext|><|im_start|>system\n${systemPrompt}<|im_end|>\n<|im_start|>user\n${userText}<|im_end|>\n<|im_start|>assistant\n`; const prefixIds = Array.from(this.tokenizer.encode(prefixText, { add_special_tokens: false })); const prefixEmbeds = this.getTextEmbeddings(prefixIds); newSeqLen = prefixIds.length; inputEmbeds = new ort.Tensor('float32', prefixEmbeds.data, [1, newSeqLen, hiddenSize]); } else { // === Continuation: user turn only === log(`Continuing conversation (cache seq_len=${this.cacheSeqLen})`); const userTurnText = `<|im_start|>user\n${userText}<|im_end|>\n<|im_start|>assistant\n`; const userTurnIds = Array.from(this.tokenizer.encode(userTurnText, { add_special_tokens: false })); const userTurnEmbeds = this.getTextEmbeddings(userTurnIds); newSeqLen = userTurnIds.length; inputEmbeds = new ort.Tensor('float32', userTurnEmbeds.data, [1, newSeqLen, hiddenSize]); } // Run prefill with attention mask covering full sequence const totalLen = this.cacheSeqLen + newSeqLen; const attentionMask = new ort.Tensor('int64', new BigInt64Array(totalLen).fill(1n), [1, totalLen]); tStep = performance.now(); let { logits, hiddenStates, outputs } = await this.runDecoder(inputEmbeds, attentionMask, this.cache); timePrefill = performance.now() - tStep; this.updateCache(this.cache, outputs); this.cacheSeqLen = totalLen; log(`Prefill: ${timePrefill.toFixed(0)}ms, new tokens: ${newSeqLen}, total: ${totalLen}`); // Generate with counter-based mode switching const textTokens = []; const audioCodes = []; let currentLen = totalLen; let inAudioMode = false; let modalityLeft = INTERLEAVED_N_TEXT; let textDone = false; for (let step = 0; step < maxNewTokens; step++) { modalityLeft--; if (inAudioMode) { // Generate audio frame using depthformer const hiddenData = hiddenStates.data; const seqLen = hiddenStates.dims[1]; const lastHidden = hiddenData.slice((seqLen - 1) * hiddenSize, seqLen * hiddenSize); tStep = performance.now(); const frameCodes = await this.sampleAudioCodes(lastHidden, audioTemperature, audioTopK); timeVocoder += performance.now() - tStep; // Switch back to text after N audio frames (if text not done) if (modalityLeft <= 0 && !textDone) { inAudioMode = false; modalityLeft = INTERLEAVED_N_TEXT; } // Check for end of audio if (frameCodes[0] === END_OF_AUDIO_TOKEN) { log(`End of audio at step ${step}`); for (let i = 0; i < NUM_CODEBOOKS; i++) { frameCodes[i] = END_OF_AUDIO_TOKEN; } inAudioMode = false; } else { const clampedFrame = frameCodes.map(c => Math.min(c, 2047)); audioCodes.push(clampedFrame); if (onAudioFrame) { onAudioFrame(clampedFrame, audioCodes.length); } if (audioCodes.length % 50 === 0) { log(`Generated ${audioCodes.length} audio frames`); } } // Get embeddings for next step tStep = performance.now(); const feedCodes = frameCodes.map(c => c === END_OF_AUDIO_TOKEN ? END_OF_AUDIO_TOKEN : Math.min(c, 2047)); const audioTokens = feedCodes.map((code, idx) => idx * CODEBOOK_VOCAB + code); const summedEmbeds = await this.getAudioEmbedding(audioTokens); timeAudioEmbed += performance.now() - tStep; const nextEmbeds = new ort.Tensor('float32', summedEmbeds, [1, 1, hiddenSize]); currentLen++; const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]); tStep = performance.now(); ({ logits, hiddenStates, outputs } = await this.runDecoder(nextEmbeds, nextMask, this.cache)); timeAudioDecode += performance.now() - tStep; this.updateCache(this.cache, outputs); } else { // Generate text token const logitsData = logits.data; const seqLen = logits.dims[1]; const lastLogits = new Float32Array(this.vocabSize); const offset = (seqLen - 1) * this.vocabSize; for (let i = 0; i < this.vocabSize; i++) { lastLogits[i] = logitsData[offset + i]; } const token = this.sampleToken(lastLogits, textTemperature); // Check for end of turn if (token === this.tokenizer.eos_token_id || token === SPECIAL_TOKENS.IM_END) { log(`End of turn at step ${step}`); break; } // Check for <|text_end|> token if (token === SPECIAL_TOKENS.TEXT_END) { log(`Text end at step ${step}`); textDone = true; } // Switch to audio after N text tokens OR text_end if (modalityLeft <= 0 || textDone) { inAudioMode = true; modalityLeft = INTERLEAVED_N_AUDIO; } textTokens.push(token); if (onToken) { const decodedText = this.tokenizer.decode(textTokens, { skip_special_tokens: true }); onToken(decodedText, token); } // Get embedding for next step const nextEmbeds = this.getTextEmbeddings([token]); currentLen++; const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]); tStep = performance.now(); ({ logits, hiddenStates, outputs } = await this.runDecoder(nextEmbeds, nextMask, this.cache)); timeTextDecode += performance.now() - tStep; this.updateCache(this.cache, outputs); } } // Feed <|im_end|> token to close assistant turn in cache const imEndEmbeds = this.getTextEmbeddings([SPECIAL_TOKENS.IM_END]); currentLen++; const finalMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]); ({ outputs } = await this.runDecoder(imEndEmbeds, finalMask, this.cache)); this.updateCache(this.cache, outputs); this.cacheSeqLen = currentLen; const text = this.tokenizer.decode(textTokens, { skip_special_tokens: true }); log(`=== Summary ===`); log(` Prefill: ${timePrefill.toFixed(0)}ms`); log(` TextDec: ${timeTextDecode.toFixed(0)}ms (${textTokens.length} tok), AudioDec: ${timeAudioDecode.toFixed(0)}ms`); log(` Vocoder: ${timeVocoder.toFixed(0)}ms, AudioEmbed: ${timeAudioEmbed.toFixed(0)}ms`); log(`Output: ${textTokens.length} text tokens, ${audioCodes.length} audio frames`); log(`Text: "${text}"`); log(`Cache seq_len: ${this.cacheSeqLen}`); return { text, audioCodes }; } /** * Generate text-only response (for follow-up turns without audio). * Uses the stateful KV cache from previous interleaved turns. * * @param {string} userText - User's text input * @param {object} options - Generation options * @returns {object} - { text: string } */ async generateTextOnly(userText, options = {}) { const { maxNewTokens = 256, temperature = 0.7, systemPrompt = 'You are a helpful assistant.', onToken, } = options; logReset(); log('=== Text-Only Generation ==='); log('Cache state:', this.cache ? `exists (seq_len=${this.cacheSeqLen})` : 'null (new conversation)'); log('User text:', userText); if (!this.embedTokensWeight) { throw new Error('embed_tokens not loaded'); } const { hiddenSize } = this.embedTokensWeight; // Build prompt based on whether we have existing cache let inputEmbeds; let newSeqLen; if (this.cache === null) { // First turn: include system message log('First turn - initializing conversation'); this.cache = this.initializeCache(); this.cacheSeqLen = 0; const promptText = `<|startoftext|><|im_start|>system\n${systemPrompt}<|im_end|>\n<|im_start|>user\n${userText}<|im_end|>\n<|im_start|>assistant\n`; const promptIds = Array.from(this.tokenizer.encode(promptText, { add_special_tokens: false })); inputEmbeds = this.getTextEmbeddings(promptIds); newSeqLen = promptIds.length; } else { // Continuation: just user turn log(`Continuing conversation (cache seq_len=${this.cacheSeqLen})`); const turnText = `<|im_start|>user\n${userText}<|im_end|>\n<|im_start|>assistant\n`; const turnIds = Array.from(this.tokenizer.encode(turnText, { add_special_tokens: false })); inputEmbeds = this.getTextEmbeddings(turnIds); newSeqLen = turnIds.length; } // Run prefill const totalLen = this.cacheSeqLen + newSeqLen; const attentionMask = new ort.Tensor('int64', new BigInt64Array(totalLen).fill(1n), [1, totalLen]); let { logits, outputs } = await this.runDecoder(inputEmbeds, attentionMask, this.cache); this.updateCache(this.cache, outputs); this.cacheSeqLen = totalLen; // Generate tokens const textTokens = []; let currentLen = totalLen; for (let i = 0; i < maxNewTokens; i++) { const logitsData = logits.data; const seqLen = logits.dims[1]; const lastLogits = new Float32Array(this.vocabSize); const offset = (seqLen - 1) * this.vocabSize; for (let j = 0; j < this.vocabSize; j++) { lastLogits[j] = logitsData[offset + j]; } const nextToken = this.sampleToken(lastLogits, temperature); // Check for stop tokens if (nextToken === this.tokenizer.eos_token_id || nextToken === SPECIAL_TOKENS.IM_END) { log('Stop token reached'); break; } textTokens.push(nextToken); if (onToken) { const text = this.tokenizer.decode(textTokens, { skip_special_tokens: true }); onToken(text, nextToken); } // Get embedding for next token const nextEmbeds = this.getTextEmbeddings([nextToken]); currentLen++; const nextMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]); ({ logits, outputs } = await this.runDecoder(nextEmbeds, nextMask, this.cache)); this.updateCache(this.cache, outputs); } // Feed <|im_end|> to close turn const imEndEmbeds = this.getTextEmbeddings([SPECIAL_TOKENS.IM_END]); currentLen++; const finalMask = new ort.Tensor('int64', new BigInt64Array(currentLen).fill(1n), [1, currentLen]); ({ outputs } = await this.runDecoder(imEndEmbeds, finalMask, this.cache)); this.updateCache(this.cache, outputs); this.cacheSeqLen = currentLen; const text = this.tokenizer.decode(textTokens, { skip_special_tokens: true }); log(`Generated ${textTokens.length} tokens: "${text}"`); log(`Cache seq_len: ${this.cacheSeqLen}`); return { text }; } /** * Decode audio codes to waveform using audio detokenizer + ISTFT * @param {number[][]} audioCodes - Array of [8] codebook values per frame * @returns {Float32Array} - Audio waveform samples in [-1, 1] */ async decodeAudioCodes(audioCodes) { if (!this.audioDetokenizerSession) { throw new Error('Audio detokenizer not loaded'); } if (audioCodes.length < 2) { console.warn('Not enough audio codes to decode'); return new Float32Array(0); } const decodeStart = performance.now(); log(`Decoding ${audioCodes.length} audio frames...`); // ISTFT parameters (fixed for this model) const nFft = 1280; const hopLength = 320; const winLength = 1280; const nFftBins = nFft / 2 + 1; // Stack codes: [T, 8] -> [8, T] and add batch -> [1, 8, T] const T = audioCodes.length; const codesTransposed = new BigInt64Array(8 * T); for (let t = 0; t < T; t++) { for (let cb = 0; cb < 8; cb++) { codesTransposed[cb * T + t] = BigInt(Math.min(audioCodes[t][cb], 2047)); } } // Run detokenizer: [1, 8, T] -> [1, T, 1282] const codesTensor = new ort.Tensor('int64', codesTransposed, [1, 8, T]); const detokStart = performance.now(); const detokOutputs = await this.audioDetokenizerSession.run({ audio_codes: codesTensor }); const stftFeatures = detokOutputs.stft_features; log(`Detokenizer: ${(performance.now() - detokStart).toFixed(0)}ms, STFT frames: ${stftFeatures.dims[1]}`); // Get raw data - shape is [1, T, 1282], we need to skip batch dimension const stftData = stftFeatures.data; const actualT = stftFeatures.dims[1]; // Convert to complex STFT: [log_magnitude | angle] -> complex const complexStft = new Array(nFftBins); for (let f = 0; f < nFftBins; f++) { complexStft[f] = new Array(actualT); for (let t = 0; t < actualT; t++) { const logMag = stftData[t * 1282 + f]; const angle = stftData[t * 1282 + nFftBins + f]; const mag = Math.exp(logMag); // Store as [real, imag] complexStft[f][t] = [mag * Math.cos(angle), mag * Math.sin(angle)]; } } // ISTFT with 'same' padding const istftStart = performance.now(); const waveform = this.istftSamePadding(complexStft, nFft, hopLength, winLength, actualT); log(`ISTFT: ${(performance.now() - istftStart).toFixed(0)}ms`); // Find max/min without spread operator (avoid stack overflow on large arrays) let waveMax = -Infinity, waveMin = Infinity; for (let i = 0; i < waveform.length; i++) { if (waveform[i] > waveMax) waveMax = waveform[i]; if (waveform[i] < waveMin) waveMin = waveform[i]; } log('ISTFT output - length:', waveform.length, 'max:', waveMax.toFixed(4), 'min:', waveMin.toFixed(4)); // Check for invalid values if (isNaN(waveMax) || isNaN(waveMin) || !isFinite(waveMax) || !isFinite(waveMin)) { console.error('ISTFT produced invalid values (NaN/Inf)'); return new Float32Array(0); } // Normalize to [-1, 1] let maxVal = Math.max(Math.abs(waveMax), Math.abs(waveMin)); if (maxVal > 0) { for (let i = 0; i < waveform.length; i++) { waveform[i] = (waveform[i] / maxVal) * 0.9; } } else { console.warn('ISTFT produced all-zero waveform'); } log(`Decoded audio: ${waveform.length} samples (${(waveform.length / 24000).toFixed(2)}s)`); return waveform; } /** * ISTFT with 'same' padding matching liquid_audio. * Uses Bluestein FFT for O(N log N) IRFFT on any size. * * Matches Python: np.fft.irfft(spec, n_fft, axis=0, norm="backward") */ istftSamePadding(complexStft, nFft, hopLength, winLength, T) { const N = complexStft.length; // nFftBins = nFft/2 + 1 = 641 const pad = Math.floor((winLength - hopLength) / 2); // Generate Hann window const window = new Float32Array(winLength); for (let i = 0; i < winLength; i++) { window[i] = 0.5 * (1 - Math.cos(2 * Math.PI * i / (winLength - 1))); } // Initialize Bluestein FFT for size nFft (cached for reuse) if (!this._bluesteinCache || this._bluesteinCache.n !== nFft) { this._bluesteinCache = this._initBluestein(nFft); } const bluestein = this._bluesteinCache; // Pre-allocate buffers for IFFT const fullRe = new Float32Array(nFft); const fullIm = new Float32Array(nFft); // Process all frames const ifftFrames = new Array(T); for (let t = 0; t < T; t++) { // Build full spectrum from one-sided (conjugate symmetry) fullRe.fill(0); fullIm.fill(0); // Copy positive frequencies for (let k = 0; k < N; k++) { fullRe[k] = complexStft[k][t][0]; fullIm[k] = complexStft[k][t][1]; } // Mirror negative frequencies (conjugate symmetry for real signal) for (let k = 1; k < N - 1; k++) { fullRe[nFft - k] = fullRe[k]; fullIm[nFft - k] = -fullIm[k]; } // IFFT using Bluestein: IFFT(X) = conj(FFT(conj(X))) / N // Conjugate input for (let i = 0; i < nFft; i++) fullIm[i] = -fullIm[i]; // FFT this._bluesteinFFT(fullRe, fullIm, bluestein); // Conjugate and scale for (let i = 0; i < nFft; i++) { fullRe[i] /= nFft; fullIm[i] = -fullIm[i] / nFft; } // Apply window (take first winLength samples) const windowedFrame = new Float32Array(winLength); for (let n = 0; n < winLength; n++) { windowedFrame[n] = fullRe[n] * window[n]; } ifftFrames[t] = windowedFrame; // Debug first frame if (t === 0) { let maxVal = 0; let hasNaN = false; for (let n = 0; n < winLength; n++) { if (isNaN(windowedFrame[n]) || !isFinite(windowedFrame[n])) { hasNaN = true; break; } const absVal = Math.abs(windowedFrame[n] / (window[n] + 1e-10)); if (absVal > maxVal) maxVal = absVal; } if (hasNaN) { console.error('IRFFT frame 0 contains NaN/Inf values!'); } } } // Overlap-add const outputSize = (T - 1) * hopLength + winLength; const audio = new Float32Array(outputSize); const windowEnvelope = new Float32Array(outputSize); const windowSq = new Float32Array(winLength); for (let i = 0; i < winLength; i++) { windowSq[i] = window[i] * window[i]; } for (let t = 0; t < T; t++) { const start = t * hopLength; for (let n = 0; n < winLength; n++) { audio[start + n] += ifftFrames[t][n]; windowEnvelope[start + n] += windowSq[n]; } } // Normalize and trim padding const trimmedLength = outputSize - 2 * pad; const trimmed = new Float32Array(trimmedLength); for (let i = 0; i < trimmedLength; i++) { const srcIdx = i + pad; if (windowEnvelope[srcIdx] > 1e-8) { trimmed[i] = audio[srcIdx] / windowEnvelope[srcIdx]; } else { trimmed[i] = audio[srcIdx]; } } return trimmed; } /** * Initialize Bluestein FFT for size n (any size, not just power of 2) */ _initBluestein(n) { // Bluestein's algorithm: converts any-size FFT to power-of-2 FFT via convolution // FFT size for convolution: next power of 2 >= 2n - 1 let m = 1; while (m < 2 * n - 1) m <<= 1; // Chirp sequence: W_n^(k^2/2) = exp(-πi * k² / n) const chirpRe = new Float32Array(n); const chirpIm = new Float32Array(n); for (let k = 0; k < n; k++) { const angle = Math.PI * k * k / n; chirpRe[k] = Math.cos(angle); chirpIm[k] = -Math.sin(angle); // exp(-i*angle) } // Precompute FFT of chirp filter (b sequence) // b[k] = conj(chirp[k]) for k in [0, n-1] // b[m-k] = conj(chirp[k]) for k in [1, n-1] // conj(chirp[k]) = chirpRe[k] - i*chirpIm[k] const bRe = new Float32Array(m); const bIm = new Float32Array(m); bRe[0] = chirpRe[0]; bIm[0] = -chirpIm[0]; // conjugate for (let k = 1; k < n; k++) { bRe[k] = chirpRe[k]; bIm[k] = -chirpIm[k]; // conjugate bRe[m - k] = chirpRe[k]; bIm[m - k] = -chirpIm[k]; // conjugate } // FFT of b (in-place) this._fftRadix2InPlace(bRe, bIm, m, false); // Precompute twiddle factors for radix-2 FFT of size m const twiddleRe = new Float32Array(m / 2); const twiddleIm = new Float32Array(m / 2); for (let i = 0; i < m / 2; i++) { const angle = -2 * Math.PI * i / m; twiddleRe[i] = Math.cos(angle); twiddleIm[i] = Math.sin(angle); } return { n, m, chirpRe, chirpIm, bRe, bIm, twiddleRe, twiddleIm }; } /** * Bluestein FFT for any size */ _bluesteinFFT(re, im, cache) { const { n, m, chirpRe, chirpIm, bRe, bIm, twiddleRe, twiddleIm } = cache; // a[k] = x[k] * chirp[k] for k in [0, n-1], zero-padded to m // chirp[k] = chirpRe[k] + i*chirpIm[k] // (re + i*im) * (chirpRe + i*chirpIm) = (re*chirpRe - im*chirpIm) + i*(im*chirpRe + re*chirpIm) const aRe = new Float32Array(m); const aIm = new Float32Array(m); for (let k = 0; k < n; k++) { aRe[k] = re[k] * chirpRe[k] - im[k] * chirpIm[k]; aIm[k] = im[k] * chirpRe[k] + re[k] * chirpIm[k]; } // FFT of a this._fftRadix2(aRe, aIm, twiddleRe, twiddleIm); // Pointwise multiply: a = a * b for (let k = 0; k < m; k++) { const tmpRe = aRe[k] * bRe[k] - aIm[k] * bIm[k]; const tmpIm = aRe[k] * bIm[k] + aIm[k] * bRe[k]; aRe[k] = tmpRe; aIm[k] = tmpIm; } // IFFT of a (using FFT with conjugate trick) for (let k = 0; k < m; k++) aIm[k] = -aIm[k]; this._fftRadix2(aRe, aIm, twiddleRe, twiddleIm); for (let k = 0; k < m; k++) { aRe[k] /= m; aIm[k] = -aIm[k] / m; } // X[k] = chirp[k] * y[k] // Same multiplication as for a: (aRe + i*aIm) * (chirpRe + i*chirpIm) for (let k = 0; k < n; k++) { re[k] = aRe[k] * chirpRe[k] - aIm[k] * chirpIm[k]; im[k] = aIm[k] * chirpRe[k] + aRe[k] * chirpIm[k]; } } /** * In-place radix-2 FFT (Cooley-Tukey) with precomputed twiddles */ _fftRadix2(re, im, twiddleRe, twiddleIm) { const n = re.length; // Bit-reversal permutation for (let i = 0, j = 0; i < n; i++) { if (i < j) { let tmp = re[i]; re[i] = re[j]; re[j] = tmp; tmp = im[i]; im[i] = im[j]; im[j] = tmp; } let k = n >> 1; while (k > 0 && k <= j) { j -= k; k >>= 1; } j += k; } // Cooley-Tukey butterflies for (let len = 2; len <= n; len <<= 1) { const halfLen = len >> 1; const step = n / len; for (let i = 0; i < n; i += len) { for (let j = 0; j < halfLen; j++) { const twIdx = j * step; const wRe = twiddleRe[twIdx]; const wIm = twiddleIm[twIdx]; const u = i + j; const v = u + halfLen; const tRe = wRe * re[v] - wIm * im[v]; const tIm = wRe * im[v] + wIm * re[v]; re[v] = re[u] - tRe; im[v] = im[u] - tIm; re[u] += tRe; im[u] += tIm; } } } } /** * In-place radix-2 FFT without precomputed twiddles (for initialization) */ _fftRadix2InPlace(re, im, n, inverse = false) { // Bit-reversal for (let i = 0, j = 0; i < n; i++) { if (i < j) { let tmp = re[i]; re[i] = re[j]; re[j] = tmp; tmp = im[i]; im[i] = im[j]; im[j] = tmp; } let k = n >> 1; while (k > 0 && k <= j) { j -= k; k >>= 1; } j += k; } // Butterflies const sign = inverse ? 1 : -1; for (let len = 2; len <= n; len <<= 1) { const halfLen = len >> 1; const angle = sign * 2 * Math.PI / len; const wRe = Math.cos(angle); const wIm = Math.sin(angle); for (let i = 0; i < n; i += len) { let curRe = 1, curIm = 0; for (let j = 0; j < halfLen; j++) { const u = i + j; const v = u + halfLen; const tRe = curRe * re[v] - curIm * im[v]; const tIm = curRe * im[v] + curIm * re[v]; re[v] = re[u] - tRe; im[v] = im[u] - tIm; re[u] += tRe; im[u] += tIm; const newRe = curRe * wRe - curIm * wIm; curIm = curRe * wIm + curIm * wRe; curRe = newRe; } } } if (inverse) { for (let i = 0; i < n; i++) { re[i] /= n; im[i] /= n; } } } /** * Free resources */ dispose() { this.tokenizer = null; this.decoderSession = null; this.audioEncoderSession = null; this.audioEmbeddingSession = null; this.audioEmbeddingWeight = null; this.audioDetokenizerSession = null; this.vocoderSession = null; this.embedTokensWeight = null; } } // Re-export audio utilities export { loadAudioFile }; export default AudioModel;