---
language:
- en
library_name: transformers
pipeline_tag: feature-extraction
tags:
- audio
- speech
- autoregressive
- transformers
- custom_code
datasets:
- LibriLight
license: apache-2.0
pretty_name: AuriStream-1B (1-pred)
---

# AuriStream-1B (1-pred)

**AuriStream** is a biologically-inspired, GPT-style autoregressive Transformer trained to predict **cochlear tokens** over long speech contexts. Cochlear tokens are discrete codes produced by a companion "WavCoch" tokenizer (through **transformation imitation**).

This repository hosts the **1-prediction-head** AuriStream-1B variant, trained on **LibriLight (~60k h)** for **~500k steps**. It uses a long context window (~20 s, ~4096 tokens), learns rich, time-aligned representations (useful for linear probing), and can roll out future tokens to generate **speech continuations**.

Inputs are **token IDs**; use the model with a WavCoch quantizer for audio->tokens and with a vocoder-enabled WavCoch checkpoint for tokens->audio.

---

## Installation

```bash
pip install -U torch torchaudio transformers
```

This model uses custom code; when loading from Hugging Face, pass `trust_remote_code=True`.

---

## Use Case 1) get hidden-state embeddings from a WAV

```python
import torch, torchaudio
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# 1) Load the WavCoch tokenizer (audio -> token IDs)
quantizer = AutoModel.from_pretrained(
    "TuKoResearch/WavCochV8192", trust_remote_code=True
).to(device).eval()

# 2) Load the AuriStream LM (tokens -> hidden states / next-token preds)
lm = AutoModel.from_pretrained(
    "TuKoResearch/AuriStream1B_1Pred_librilight_500k", trust_remote_code=True
).to(device).eval()

# 3) Read an audio file (mono, 16 kHz recommended)
wav, sr = torchaudio.load("sample.wav")
if wav.size(0) > 1:  # stereo -> mono
    wav = wav.mean(dim=0, keepdim=True)
if sr != 16_000:
    wav = torchaudio.transforms.Resample(sr, 16_000)(wav)
    sr = 16_000

# 4) Quantize to cochlear token IDs
with torch.no_grad():
    # quantizer.quantize expects (B, 1, T); returns LongTensor (B, L)
    token_ids = quantizer.quantize(wav.unsqueeze(0).to(device))  # (1, L)

# 5) Forward pass with hidden states
with torch.no_grad():
    out = lm(token_ids, output_hidden_states=True)

last_layer = out["hidden_states"][-1]    # (1, T, D)
clip_embedding = last_layer.mean(dim=1)  # time mean-pool -> (1, D)
print("Pooled embedding shape:", clip_embedding.shape)
```

**Notes**

* `output_hidden_states=True` returns all layers; choose a layer or pool over time.
* For word/phone segments, slice the time axis before pooling.

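The segment-pooling note above can be sketched as a small helper that converts a time span in seconds into a slice of the hidden-state time axis. This is a minimal sketch, not part of the AuriStream API: `segment_frame_range` is a hypothetical name, and it assumes the hidden-state frames are evenly spaced over the clip, so the frame rate is simply `n_frames / clip_duration_s`.

```python
import math


def segment_frame_range(start_s, end_s, clip_duration_s, n_frames):
    """Map a [start_s, end_s) time span to a [lo, hi) slice of the frame axis.

    Hypothetical helper: assumes frames are evenly spaced over the clip.
    """
    if not 0.0 <= start_s < end_s <= clip_duration_s:
        raise ValueError("need 0 <= start_s < end_s <= clip_duration_s")
    frames_per_sec = n_frames / clip_duration_s
    lo = int(start_s * frames_per_sec)                    # round start down
    hi = max(lo + 1, math.ceil(end_s * frames_per_sec))   # keep at least one frame
    return lo, min(hi, n_frames)


# Usage with the tensors from the example above (not executed here):
# lo, hi = segment_frame_range(0.5, 1.0, wav.size(1) / sr, last_layer.size(1))
# word_embedding = last_layer[:, lo:hi].mean(dim=1)  # (1, D)
```

Rounding the start down and the end up keeps every frame that overlaps the segment, which is usually the safer choice for short phones; the word or phone boundaries themselves would come from an external alignment.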
---

## Use Case 2) generate a speech continuation (token rollout)

```python
import torch, torchaudio
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# WavCoch tokenizer (audio->tokens, tokens->cochleagram)
quantizer = AutoModel.from_pretrained(
    "TuKoResearch/WavCochV8192", trust_remote_code=True
).to(device).eval()

# AuriStream LM (tokens->next tokens)
lm = AutoModel.from_pretrained(
    "TuKoResearch/AuriStream1B_1Pred_librilight_500k", trust_remote_code=True
).to(device).eval()

# Load & prep a short prompt (e.g., 3s of audio at 16 kHz)
wav, sr = torchaudio.load("prompt.wav")
if wav.size(0) > 1:
    wav = wav.mean(dim=0, keepdim=True)
if sr != 16_000:
    wav = torchaudio.transforms.Resample(sr, 16_000)(wav)
    sr = 16_000

prompt_seconds = 3
wav = wav[:, : sr * prompt_seconds]

# Quantize prompt to token IDs
with torch.no_grad():
    prompt_tokens = quantizer.quantize(wav.unsqueeze(0).to(device))  # (1, L)

# Decide how many future tokens to generate
tokens_per_sec = prompt_tokens.size(1) / float(prompt_seconds)
rollout_seconds = 3
rollout_steps = int(round(tokens_per_sec * rollout_seconds))

# Roll out future tokens
with torch.no_grad():
    # returns (pred_tokens, pred_logits); temperature/top_k/top_p/seed optional
    pred_tokens, _ = lm.generate(
        prompt_tokens, rollout_steps, temp=0.7, top_k=50, top_p=0.95, seed=0
    )

full_tokens = torch.cat([prompt_tokens, pred_tokens], dim=1)  # (1, L+K)
```

---

## Citation

If you use this model, please cite:

```bibtex
@misc{tuckute2025cochleartokens,
  title         = {Representing Speech Through Autoregressive Prediction of Cochlear Tokens},
  author        = {Tuckute, Greta and Kotar, Klemen and Fedorenko, Evelina and Yamins, Daniel L. K.},
  year          = {2025},
  eprint        = {2508.11598},
  archivePrefix = {arXiv},
  url           = {https://arxiv.org/abs/2508.11598}
}
```