---
license: apache-2.0
base_model: Qwen/Qwen3-8B
datasets:
- allenai/c4
tags:
- sparse-attention
- qwen3
- custom-code
- indexer
- experimental
- prefill
- efficiency
- mlx
- apple-silicon
pipeline_tag: text-generation
language:
- en
---

# Qwen3-8B All-Sparse Indexer

> **Experimental research artifact:** a Dynamic Sparse Attention (DSA) indexer trained at 2K context length. This repository is intended as an exploratory learned sparse-attention index, not a finished production method. The inference code is written in MLX.

A lightweight **sparse-attention indexer** trained to approximate dense attention behavior in [Qwen/Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B). Conceptually, this is a **DeepSeek-style learned index** in the sense that a small auxiliary network predicts which key-value positions are worth keeping for attention. This is an independent research artifact and is not affiliated with DeepSeek. Early results suggest the approach can work in some settings, but more research is needed.

**Runs natively on Apple Silicon (M1/M2/M3/M4) via [MLX](https://github.com/ml-explore/mlx).** No CUDA or bitsandbytes required.

## Model Description

This repository contains a Dynamic Sparse Attention (DSA) indexer checkpoint and the MLX runtime code needed to patch it into a standard Qwen3-8B model. The indexer is a small auxiliary network (~72 M parameters total, ~2 M per layer) that runs alongside the frozen base model. For every layer and every query, it predicts a top-k set of key positions that may be useful for attention, allowing exploratory sparse prefill and fixed-budget decode experiments.

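
To make the per-query top-k selection described above concrete, here is a minimal pure-Python sketch. It is not the repository's MLX implementation: `select_top_k` and the toy `scores` matrix are hypothetical, and the sketch assumes causal attention (a query may only keep keys at or before its own position).

```python
import heapq


def select_top_k(scores, top_k):
    """Return, for each query position, the key positions to keep.

    scores[q][k] is a hypothetical indexer score for query q and key k.
    Assumption: attention is causal, so only keys with k <= q are eligible.
    """
    kept = []
    for q, row in enumerate(scores):
        causal_keys = range(q + 1)
        if len(causal_keys) <= top_k:
            # Fewer candidates than the budget: nothing is dropped.
            kept.append(list(causal_keys))
        else:
            best = heapq.nlargest(top_k, causal_keys, key=lambda k: row[k])
            kept.append(sorted(best))
    return kept


# Toy 4x4 score matrix (entries above the diagonal are never read).
scores = [
    [0.9, 0.0, 0.0, 0.0],
    [0.1, 0.8, 0.0, 0.0],
    [0.7, 0.2, 0.5, 0.0],
    [0.3, 0.9, 0.1, 0.6],
]
kept = select_top_k(scores, top_k=2)
print(kept)  # [[0], [0, 1], [0, 2], [1, 3]]
```

With a budget of 2, the first two queries keep every causal key, while later queries keep only their two highest-scoring keys. The same effect appears at full scale: a context no longer than `top_k` is not pruned at all.
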
Key properties:

- **Base model**: `mlx-community/Qwen3-8B-4bit` (4-bit quantized, Apple Silicon)
- **Framework**: MLX, runs on Apple M-series chips
- **Indexer coverage**: All 36 transformer layers (full-model sparse attention)
- **Sparse budget**: top_k = 2048 positions per query per layer
- **Fixed-size decode cache**: KV cache stays at exactly 2048 entries, even for multi-thousand-token outputs
- **Current evidence**: promising but preliminary; some quality and retrieval checks work, others degrade or fail depending on context length and top_k

## Benchmarks

### Needle-in-a-Haystack (Retrieval)

A short phrase ("The secret code is 7392") was hidden at the midpoint of a context filled with diverse natural-language prose. The model was then asked to retrieve it. These are exploratory single-task measurements, not a comprehensive long-context benchmark.

| Context | top_k | Pass | Retention | Prefill Time | Peak Memory |
| ------- | ----- | ---- | --------- | ------------ | ----------- |
| 2K      | 2048  | ✅   | 100%      | 9s           | 5.4 GB      |
| 4K      | 2048  | ✅   | 55%       | 22s          | 6.1 GB      |
| 8K      | 2048  | ✅   | 27%       | 55s          | 7.5 GB      |
| 16K     | 2048  | ❌   | 14%       | 131s         | 9.2 GB      |
| 16K     | 4096  | ✅   | 28%       | 180s         | 9.2 GB      |
| 32K     | 2048  | ❌   | 7%        | 306s         | 13.2 GB     |

> **Interpretation:** in these runs, retrieval became much less reliable when `top_k`
> was too small relative to context length. A larger `top_k` may help on retrieval-style
> tasks, but this should be treated as an experimental observation rather than a settled rule.

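
As a rough way to reason about `top_k` relative to context length, the sketch below computes the nominal fraction of key positions that a query at the end of the context can keep. The helper `nominal_retention` is hypothetical, and the token counts assume that "2K" means exactly 2048 tokens and so on, which may not match the actual prompts. The measured Retention column in the table is slightly higher than these nominal values (for example 55% at 4K rather than 50%), and this sketch does not try to account for that gap.

```python
def nominal_retention(top_k, context_len):
    """Fraction of key positions kept for a query at the end of the context."""
    return min(1.0, top_k / context_len)


# (label, assumed context length in tokens, top_k) for the rows in the table.
rows = [
    ("2K", 2048, 2048),
    ("4K", 4096, 2048),
    ("8K", 8192, 2048),
    ("16K", 16384, 2048),
    ("16K", 16384, 4096),
    ("32K", 32768, 2048),
]
ratios = [nominal_retention(top_k, n) for _, n, top_k in rows]
for (label, _, top_k), ratio in zip(rows, ratios):
    print(f"{label:>3}  top_k={top_k}  nominal retention={ratio:.4f}")
```

In the table, the two failing rows are the ones whose ratio falls below 25%. That is an observation about six runs, not a threshold to rely on.
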
### Memory (KV Cache)

| Context | Dense KV | Sparse KV (top_k=2048) | Savings |
| ------- | -------- | ---------------------- | ------- |
| 4K      | 604 MB   | 302 MB                 | 50%     |
| 8K      | 1,208 MB | 302 MB                 | 75%     |
| 16K     | 2,416 MB | 302 MB                 | 87.5%   |
| 32K     | 4,832 MB | 302 MB                 | 93.7%   |

*Measured on Qwen3-8B 4-bit, Apple Silicon, MLX.*

### Quality

| Benchmark                | Dense  | Sparse (top_k=2048) |
| ------------------------ | ------ | ------------------- |
| GSM8K accuracy (4-shot)  | 95%    | 92%                 |
| PPL on C4 (seq_len=2048) | 13.526 | 13.533 (+0.058%)    |
| PPL on C4 (seq_len=8192) | 15.628 | 15.653 (+0.16%)     |

## Training Details

| Parameter                | Value                                  |
| ------------------------ | -------------------------------------- |
| Base model               | `Qwen/Qwen3-8B`                        |
| Quantization             | 4-bit (MLX)                            |
| Training dataset         | `allenai/c4` (English split)           |
| Training tokens          | 15 000 000                             |
| Validation tokens        | 1 000 000                              |
| Sequence length          | 2048                                   |
| Sparse layers            | All 36 (layers 0–35)                   |
| top_k                    | 2048                                   |
| Indexer heads            | 6                                      |
| Projection dim           | 69                                     |
| RoPE dim                 | 64                                     |
| Parameters per layer     | ~2 003 427                             |
| Total indexer parameters | ~72 123 372                            |
| Loss aggregation         | per-layer                              |
| Support loss weight      | 0.1                                    |
| LR schedule              | warmup-cosine (5% warmup, min LR 1e-6) |

## Files

| File                                           | Description                                                      |
| ---------------------------------------------- | ---------------------------------------------------------------- |
| `lightning_indexer_best_assembled.safetensors` | Sparse indexer checkpoint (safetensors format)                   |
| `run_config.json`                              | Training and sparse-layer configuration                          |
| `eval_sparse_generate.py`                      | Sparse patching + MLX runtime + GSM8K evaluation                 |
| `demo.py`                                      | One-command demo: loads model + indexer, runs sample prompts     |
| `requirements.txt`                             | Runtime dependencies (mlx, mlx-lm, safetensors, numpy, datasets) |
| `ppl_results_assembled.json`                   | Dense vs. sparse perplexity evaluation summary                   |

## Quick Start

```bash
# Requires Python 3.9+ and an Apple Silicon Mac (M1/M2/M3/M4)
# Clone or download this repository, then enter the repo root
# If cloning from the Hugging Face Hub, use your actual repo URL.
pip install -r requirements.txt
python demo.py
```

The demo automatically downloads `mlx-community/Qwen3-8B-4bit` (~5 GB), loads the sparse indexer, and runs three sample prompts showing dense vs. sparse output side by side.

### More options

```bash
# Single custom prompt
python demo.py --prompt "What causes the northern lights?"

# Interactive chat REPL
python demo.py --interactive

# GSM8K accuracy eval: dense baseline
python eval_sparse_generate.py --limit 100

# GSM8K accuracy eval: sparse (fixed-2K decode cache)
python eval_sparse_generate.py --limit 100 \
    --indexer-path lightning_indexer_best_assembled.safetensors \
    --run-config run_config.json

# Override top-k budget
python eval_sparse_generate.py --top-k 1024 \
    --indexer-path lightning_indexer_best_assembled.safetensors \
    --run-config run_config.json --limit 50
```

### Programmatic usage

```python
import json
from pathlib import Path

import mlx.core as mx
from mlx_lm.generate import generate_step
from mlx_lm.utils import load as mlx_load

from eval_sparse_generate import load_indexers, patch_sparse_generate

# Load base model (auto-downloads from HF on first run)
model, tokenizer = mlx_load(
    "mlx-community/Qwen3-8B-4bit",
    tokenizer_config={"trust_remote_code": True},
)
# Force the lazily loaded weights to materialize before patching
mx.eval(model.parameters())
mx.synchronize()

# Load indexers
rc = json.loads(Path("run_config.json").read_text())
# hidden_size may sit at the top level or under "metadata"; fall back to 4096
dim = int(rc.get("hidden_size", rc.get("metadata", {}).get("hidden_size", 4096)))
indexers = load_indexers(
    "lightning_indexer_best_assembled.safetensors",
    dim=dim,
    proj_dim=rc["proj_dim"],
    n_heads=rc["indexer_heads"],
    rope_dim=rc["rope_dim"],
)
mx.eval([idx.parameters() for idx in indexers.values()])

# Patch: attention is now sparse for ALL steps
clear_fn = patch_sparse_generate(model, indexers, top_k=2048)

# Generate as normal
prompt = "<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n"
input_ids = mx.array(tokenizer.encode(prompt))
tokens = []
for tok, _ in generate_step(
    input_ids,
    model,
    max_tokens=128,
    sampler=lambda x: mx.argmax(x, axis=-1),  # greedy decoding
    prefill_step_size=int(input_ids.shape[0]),  # prefill the prompt in one chunk
):
    t = int(tok.item())
    if t == tokenizer.eos_token_id:
        break
    tokens.append(t)
print(tokenizer.decode(tokens))

# Between prompts, clear indexer state
clear_fn()
```

> **Note:** The indexer monkey-patches `block.__call__` and `model.make_cache` at runtime.
> It does **not** use `from_pretrained()`. You need `eval_sparse_generate.py` alongside the checkpoint.

## Decode Cache Design

After prefill, each layer's KV cache is pruned to exactly **top_k = 2048** entries. During decode, every new token is scored by the indexer against all top_k+1 candidates (top_k cached + 1 new), and the lowest-scoring entry is evicted. The cache stays at exactly 2048 entries regardless of how many tokens are generated.

This means:

- **O(top_k)** per decode step (not O(seq_len))
- **Constant memory** during generation: the KV cache never grows
- Designed to keep decode memory bounded even for much longer generations, though broader validation is still needed

## Limitations

- Intended for Apple Silicon (M1/M2/M3/M4). Runs on Intel Mac CPU but will be slow.
- Trained on English C4 data only. Other languages and strongly out-of-distribution domains were not evaluated.
- Fixed top_k = 2048 across all layers. Per-layer adaptive budgets may improve results further.
- Tested with `mlx-community/Qwen3-8B-4bit`. Other quantization levels were not validated.
- Training used sequences of length 2048, so longer contexts rely on extrapolation. Very long contexts (> 32K) go beyond anything in the benchmarks above.
- Quality and retrieval behavior are still preliminary and can vary materially with context length, prompt type, and `top_k`.
- Some exploratory retrieval tests pass at 2K-8K with `top_k=2048` and at 16K with `top_k=4096`, but this should not be read as a general guarantee.
- The current MLX runtime is a research implementation. Some measured regimes still show worse latency than dense baselines even when KV-cache scaling looks better.

## Citation

If you reference this artifact directly, cite the published Hugging Face repository URL for this model card or your associated paper. This repository is a custom-code MLX runtime artifact rather than a standard `from_pretrained()` Transformers checkpoint, so cite the specific published repo you upload.
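
For readers who want a concrete picture of the eviction rule in the Decode Cache Design section, the following pure-Python sketch mirrors it on a toy scale. It is an illustration under assumptions, not the code in `eval_sparse_generate.py`: `decode_step` and `toy_scores` are hypothetical, and the scores are fixed per position, whereas the real indexer rescores candidates for each new query.

```python
def decode_step(cache, new_position, score_fn, top_k):
    """Add one position, then evict the lowest-scoring candidate if over budget.

    cache:     list of cached key positions (at most top_k entries)
    score_fn:  hypothetical stand-in for the indexer, maps position -> score
    """
    candidates = cache + [new_position]  # top_k cached + 1 new
    if len(candidates) <= top_k:
        return candidates
    evicted = min(candidates, key=score_fn)
    candidates.remove(evicted)
    return candidates


# Toy setup: budget of 3, static scores per position (a simplification).
toy_scores = {0: 0.9, 1: 0.2, 2: 0.6, 3: 0.4, 4: 0.8, 5: 0.1}
cache = [0, 1, 2]
sizes = []
for pos in (3, 4, 5):
    cache = decode_step(cache, pos, toy_scores.get, top_k=3)
    sizes.append(len(cache))

print(cache)  # [0, 2, 4]
print(sizes)  # [3, 3, 3]
```

Each step examines only `top_k + 1` candidates, which is where the O(top_k) cost per decode step comes from. Note that the newest position can itself be the one evicted, as happens with position 5 in this toy run.
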