#!/usr/bin/env python3 """ MLX sparse generative evaluation — sparse attention on EVERY token (prefill + decode). Self-contained: requires only mlx, mlx-lm, safetensors, numpy. Runs natively on Apple Silicon (M1/M2/M3/M4). Decode uses a FIXED-SIZE KV cache: after prefill, exactly top_k entries are kept per layer, and each decode step scores all top_k+1 candidates and evicts the lowest- scoring one. The cache never grows beyond top_k regardless of output length. Usage (dense baseline): python eval_sparse_generate.py --model mlx-community/Qwen3-8B-4bit \\ --tasks gsm8k --limit 50 Usage (sparse, fixed-2K decode cache): python eval_sparse_generate.py --model mlx-community/Qwen3-8B-4bit \\ --indexer-path lightning_indexer_best_assembled.safetensors \\ --run-config run_config.json --top-k 2048 --tasks gsm8k --limit 50 """ from __future__ import annotations import argparse import json import math import re import time from pathlib import Path from typing import Callable import numpy as np import mlx.core as mx import mlx.nn as nn from mlx_lm.generate import generate_step from mlx_lm.models.base import create_causal_mask, scaled_dot_product_attention from mlx_lm.models.cache import _BaseCache, KVCache from mlx_lm.utils import load as mlx_load from safetensors import safe_open # ── Indexer ─────────────────────────────────────────────────────────────────── def _config_value(config: dict, key: str, default: int) -> int: metadata = config.get("metadata", {}) return int(config.get(key, metadata.get(key, default))) def _layer_norm_last(x, eps: float = 1e-5): m = mx.mean(x, axis=-1, keepdims=True) c = x - m return c * mx.rsqrt(mx.mean(c * c, axis=-1, keepdims=True) + eps) class Indexer(nn.Module): """Lightweight sparse-attention indexer: predicts which KV positions each query needs.""" def __init__(self, dim: int, proj_dim: int, n_heads: int, rope_dim: int) -> None: super().__init__() self.proj_dim = proj_dim self.n_heads = n_heads self.rope_dim = min(rope_dim, proj_dim) 
self.q_proj = nn.Linear(dim, n_heads * proj_dim, bias=True) self.k_proj = nn.Linear(dim, proj_dim, bias=True) self.weight_proj = nn.Linear(dim, n_heads, bias=False) self.softmax_scale = 1.0 / math.sqrt(float(proj_dim)) self.weight_scale = 1.0 / math.sqrt(float(n_heads)) def encode(self, x, rope_fn=None): """Encode hidden states into query, key, and weight tensors. Args: x: [batch, seq_len, dim] rope_fn: optional function to apply RoPE to queries and keys Returns: q: [batch, seq_len, n_heads, proj_dim] k: [batch, seq_len, proj_dim] w: [batch, seq_len, n_heads] """ b, t, _ = x.shape q = self.q_proj(x).reshape(b, t, self.n_heads, self.proj_dim) k = _layer_norm_last(self.k_proj(x)) if rope_fn is not None and self.rope_dim > 0: q_pe = rope_fn(q[:, :, :, :self.rope_dim].transpose(0, 2, 1, 3)).transpose(0, 2, 1, 3) q = mx.concatenate([q_pe, q[:, :, :, self.rope_dim:]], axis=-1) k_pe = rope_fn(k[:, None, :, :self.rope_dim]).squeeze(1) k = mx.concatenate([k_pe, k[:, :, self.rope_dim:]], axis=-1) w = self.weight_proj(x.astype(mx.float32)) * self.weight_scale return q, k, w def load_indexers(path: str, dim: int, proj_dim: int, n_heads: int, rope_dim: int) -> dict[int, Indexer]: """Load per-layer indexers from a safetensors checkpoint.""" f = safe_open(str(path), framework="numpy") keys = list(f.keys()) layer_ids = sorted(set(int(k.split('.')[1]) for k in keys if k.startswith('layers.'))) out: dict[int, Indexer] = {} for li in layer_ids: idx = Indexer(dim, proj_dim, n_heads, rope_dim) idx.q_proj.weight = mx.array(f.get_tensor(f"layers.{li}.q_proj.weight")) idx.q_proj.bias = mx.array(f.get_tensor(f"layers.{li}.q_proj.bias")) idx.k_proj.weight = mx.array(f.get_tensor(f"layers.{li}.k_proj.weight")) idx.k_proj.bias = mx.array(f.get_tensor(f"layers.{li}.k_proj.bias")) if f"layers.{li}.weight_proj.weight" in keys: idx.weight_proj.weight = mx.array(f.get_tensor(f"layers.{li}.weight_proj.weight")) out[li] = idx return out # ── Sparse KV cache 
# ── Sparse KV cache ───────────────────────────────────────────────────────────


class SparseKVCache(_BaseCache):
    """KV cache whose contents are pruned externally (via ``replace``) to a fixed budget.

    ``offset`` is the logical sequence position (tokens processed so far), which
    is what RoPE needs; it can exceed ``len(self)``, the number of entries kept.
    """

    def __init__(self) -> None:
        self.keys = self.values = None
        self.offset = 0

    @property
    def state(self):
        return [] if self.keys is None else (self.keys, self.values)

    @state.setter
    def state(self, v):
        if not v:
            self.keys = self.values = None
            self.offset = 0
            return
        self.keys, self.values = v

    def __len__(self) -> int:
        return 0 if self.keys is None else int(self.keys.shape[2])

    def update_and_fetch(self, keys, values):
        """Grow cache unconditionally (used by dense path / non-patched layers)."""
        if self.keys is None:
            self.keys, self.values = keys, values
        else:
            self.keys = mx.concatenate([self.keys, keys], axis=2)
            self.values = mx.concatenate([self.values, values], axis=2)
        self.offset += keys.shape[2]
        return self.keys, self.values

    def replace(self, keys, values, offset: int):
        """Overwrite the stored entries wholesale and set the logical position."""
        self.keys, self.values, self.offset = keys, values, int(offset)
        return keys, values

    def make_mask(self, N: int, window_size=None, return_array: bool = False):
        # A single decode query needs no mask unless one is explicitly requested.
        if N == 1 and window_size is None and not return_array:
            return None
        return create_causal_mask(N, offset=self.offset, window_size=window_size)


# ── Sparse attention patch ────────────────────────────────────────────────────


def patch_sparse_generate(
    model,
    indexers: dict[int, Indexer],
    top_k: int,
) -> Callable[[], None]:
    """Monkey-patch model with sparse attention for ALL steps (prefill + decode).

    The decode cache holds at most top_k entries per layer: after each decode
    step, the indexer scores every cached candidate plus the new token and
    evicts the lowest-scoring one (recent tokens are protected).

    Args:
        model: mlx-lm model (already loaded, parameters eval'd)
        indexers: dict mapping layer_idx → Indexer
        top_k: number of KV positions to keep per layer

    Returns:
        clear_buffers(): call this between prompts to reset indexer state.
        ``clear_buffers.restore()`` undoes the patch (restores the original
        block ``__call__`` and ``model.make_cache``).
    """
    layers = model.model.layers
    block_cls = layers[0].__class__
    original_call = block_cls.__call__
    had_make_cache = hasattr(model, "make_cache")
    original_make_cache = getattr(model, "make_cache", None)
    # The patch is installed on the shared block class, so layers are told apart by identity.
    sparse_set = {id(layers[li]): (li, layers[li].self_attn, indexers[li]) for li in indexers}
    key_buffers: dict[int, mx.array] = {}  # layer_idx → [1, kept, proj_dim] indexer keys

    # ── QKV projection ────────────────────────────────────────────────────────
    def _project_qkv(attn, x, offset=None):
        """Project to per-head q/k/v ([b, heads, t, head_dim]) and apply RoPE to q and k."""
        b, t, _ = x.shape
        q = attn.q_proj(x).reshape(b, t, attn.n_heads, -1)
        k = attn.k_proj(x).reshape(b, t, attn.n_kv_heads, -1)
        v = attn.v_proj(x).reshape(b, t, attn.n_kv_heads, -1)
        if hasattr(attn, "q_norm"):
            q = attn.q_norm(q)
        if hasattr(attn, "k_norm"):
            k = attn.k_norm(k)
        q, k, v = q.transpose(0, 2, 1, 3), k.transpose(0, 2, 1, 3), v.transpose(0, 2, 1, 3)
        rope = (lambda a: attn.rope(a, offset=offset)) if offset is not None else attn.rope
        return rope(q), rope(k), v

    # ── Top-k selection ───────────────────────────────────────────────────────
    def _select_topk(scores: np.ndarray, keep: int, force_last: bool = False) -> np.ndarray:
        """Sorted indices of the `keep` highest scores.

        With force_last, index n-1 is guaranteed by swapping out the weakest pick.
        """
        n = scores.shape[0]
        if keep >= n:
            return np.arange(n, dtype=np.int32)
        selected = np.argpartition(-scores, keep - 1)[:keep].astype(np.int32)
        if force_last and (n - 1) not in selected:
            selected[int(np.argmin(scores[selected]))] = n - 1
        return np.sort(np.unique(selected))

    def _select_topk_decode(scores: np.ndarray, keep: int, local_window: int = 128) -> np.ndarray:
        """Top-k for decode: always protect the most recent local_window tokens from eviction.

        Recently-generated tokens sit at the HIGH end of the index array (new tokens
        are appended last and sort to the end). Without this window, the indexer —
        which was trained on prefill patterns — underscores recent decode tokens,
        causing them to be evicted and leading to repetition loops at 4K+ generated
        tokens.

        Vectorised: this runs once per layer per generated token, so it avoids
        Python-level set/sort work over ~top_k elements.
        """
        n = scores.shape[0]
        if keep >= n:
            return np.arange(n, dtype=np.int32)
        n_protect = min(local_window, keep)
        protected = np.arange(n - n_protect, n, dtype=np.int32)
        n_free = keep - n_protect
        if n_free <= 0:
            return protected
        masked = scores.copy()
        masked[n - n_protect:] = -np.inf  # protected slots never compete for free slots
        extra = np.argpartition(-masked, n_free - 1)[:n_free].astype(np.int32)
        # union1d returns the sorted, de-duplicated union (same as sort(unique(...))).
        return np.union1d(extra, protected).astype(np.int32)

    # ── Sparse prefill ────────────────────────────────────────────────────────
    def _sparse_prefill(attn, indexer: Indexer, x, layer_idx: int):
        """Chunked sparse prefill.

        Phase 1: build all per-chunk score graphs lazily in MLX (no GPU sync per chunk).
        Phase 2: single mx.eval for all chunks — 1 GPU→CPU sync per layer instead of N.
        Phase 3: mask-building + sparse attention using the materialized indices.

        Returns:
            (attention output after o_proj, pruned keys, pruned values) — the pruned
            K/V are the top_k positions chosen by the LAST query, used to seed the cache.
        """
        b, t, _ = x.shape
        q, k, v = _project_qkv(attn, x)
        q_enc_mx, k_enc_mx, w_enc_mx = indexer.encode(x, rope_fn=attn.rope)
        # Batch index 0 only: this path assumes batch size 1.
        q_enc_0 = q_enc_mx[0]  # [t, n_heads, proj_dim] — stays on GPU
        k_enc_0 = k_enc_mx[0]  # [t, proj_dim]
        w_enc_0 = w_enc_mx[0]  # [t, n_heads]
        scale = float(indexer.softmax_scale)
        n_heads = int(q_enc_0.shape[1])
        proj_dim = int(q_enc_0.shape[2])
        chunk = 256
        meta = []  # (qs, qe_end, keep)
        lazy_topk = []  # MLX lazy arrays or None (trivial case)

        # ── Phase 1: build lazy scoring graphs (no eval yet) ──────────────────
        for qs in range(0, t, chunk):
            qe_end = min(t, qs + chunk)
            q_len = qe_end - qs
            n_keys = qe_end
            keep = min(top_k, n_keys)
            meta.append((qs, qe_end, keep))
            if keep < n_keys:
                q_c = q_enc_0[qs:qe_end]  # [q_len, n_heads, proj_dim]
                w_c = w_enc_0[qs:qe_end]  # [q_len, n_heads]
                k_p = k_enc_0[:n_keys]  # [n_keys, proj_dim]
                raw = mx.maximum(
                    0., mx.matmul(q_c.reshape(q_len * n_heads, proj_dim), k_p.T)
                )  # [q_len*n_heads, n_keys]
                scores = (
                    raw.reshape(q_len, n_heads, n_keys) * w_c[:, :, None]
                ).sum(1) * scale  # [q_len, n_keys]
                q_pos = mx.arange(qs, qe_end)[:, None]
                k_pos = mx.arange(n_keys)[None, :]
                # Future keys sink to the bottom; any that still get picked are filtered in Phase 3.
                scores = mx.where(k_pos <= q_pos, scores, mx.array(-1e9, dtype=scores.dtype))
                lazy_topk.append(mx.argsort(-scores, axis=1)[:, :keep])  # lazy
            else:
                lazy_topk.append(None)

        # ── Phase 2: single GPU sync for all chunks ────────────────────────────
        nonnull = [m for m in lazy_topk if m is not None]
        if nonnull:
            mx.eval(*nonnull)
        sel_mats = []
        for lz, (qs, qe_end, keep) in zip(lazy_topk, meta):
            q_len = qe_end - qs
            if lz is not None:
                sel_mats.append(np.array(lz, dtype=np.int32))
            else:
                # Budget covers every key: select all of them (causality applied below).
                sel_mats.append(np.broadcast_to(np.arange(keep, dtype=np.int32), (q_len, keep)).copy())

        # ── Phase 3: mask-building + sparse attention ──────────────────────────
        outputs = []
        for sel_mat, (qs, qe_end, keep) in zip(sel_mats, meta):
            q_len = qe_end - qs
            # Vectorized causal filtering (no Python loop over rows)
            q_pos = np.arange(qs, qe_end)[:, None]  # [q_len, 1]
            valid = (sel_mat >= 0) & (sel_mat <= q_pos)  # [q_len, keep]
            ti = np.unique(sel_mat[valid])  # union of keys any query in the chunk needs
            tp = mx.array(ti)
            qs_ = q[:, :, qs:qe_end, :]
            ks_ = mx.take(k, tp, axis=2)
            vs_ = mx.take(v, tp, axis=2)
            # Build boolean mask: mn[r, j] = True iff ti[j] is selected for query r
            r_idx, c_idx = np.nonzero(valid)
            ti_idx = np.searchsorted(ti, sel_mat[r_idx, c_idx])
            mn = np.zeros((q_len, len(ti)), bool)
            if r_idx.size:
                mn[r_idx, ti_idx] = True
            out = scaled_dot_product_attention(
                qs_, ks_, vs_, cache=None, scale=attn.scale, mask=mx.array(mn)[None, None]
            )
            outputs.append(out)
        out = mx.concatenate(outputs, 2).transpose(0, 2, 1, 3).reshape(b, t, -1)

        # ── Cache selection: score from last query in MLX ─────────────────────
        last_q = q_enc_0[-1]  # [n_heads, proj_dim]
        last_w = w_enc_0[-1]  # [n_heads]
        raw = mx.maximum(0., mx.matmul(last_q, k_enc_0.T))  # [n_heads, t]
        ls_mx = (raw * last_w[:, None]).sum(0) * scale  # [t]
        sel = _select_topk(np.array(ls_mx, np.float32), min(top_k, t), force_last=True)
        sm = mx.array(sel)
        kb = mx.stop_gradient(mx.take(k_enc_mx, sm, axis=1))
        mx.eval(kb)
        key_buffers[layer_idx] = kb
        return attn.o_proj(out), mx.take(k, sm, 2), mx.take(v, sm, 2)

    # ── Sparse decode (fixed-2K cache) ────────────────────────────────────────
    def _sparse_decode(attn, indexer: Indexer, x, cache: SparseKVCache, layer_idx: int):
        """Decode with fixed-size KV cache: scores top_k+1 positions, evicts 1 per step."""
        b, _, _ = x.shape
        off = cache.offset
        q, k_new, v_new = _project_qkv(attn, x, offset=off)
        rope_fn = lambda a: attn.rope(a, offset=off)
        q_enc, k_enc_new, w_enc = indexer.encode(x, rope_fn=rope_fn)
        # q_enc: [b, 1, n_heads, proj_dim]
        # k_enc_new: [b, 1, proj_dim]
        # w_enc: [b, 1, n_heads]
        k_hist = key_buffers[layer_idx]  # [b, top_k, proj_dim]
        full_k_enc = mx.concatenate([k_hist, k_enc_new], axis=1)  # [b, top_k+1, proj_dim]
        full_k = mx.concatenate([cache.keys, k_new], axis=2)  # [b, n_kv_heads, top_k+1, head_dim]
        full_v = mx.concatenate([cache.values, v_new], axis=2)
        # Score all top_k+1 positions for the current query
        scores = mx.maximum(0., mx.einsum("bqhd,bkd->bqhk", q_enc, full_k_enc))  # [b, 1, n_heads, top_k+1]
        w_last = w_enc[:, -1:, :]  # [b, 1, n_heads]
        scores = (scores * w_last[:, :, :, None]).sum(2) * indexer.softmax_scale  # [b, 1, top_k+1]
        scores_np = np.array(scores[0, 0], np.float32)  # [top_k+1]
        n_full = int(full_k_enc.shape[1])
        keep = min(top_k, n_full)
        sel = _select_topk_decode(scores_np, keep)  # keeps newest + local window of recent tokens
        sm = mx.array(sel)
        k_sel = mx.take(full_k, sm, axis=2)
        v_sel = mx.take(full_v, sm, axis=2)
        out = scaled_dot_product_attention(q, k_sel, v_sel, cache=None, scale=attn.scale, mask=None)
        cache.replace(k_sel, v_sel, offset=off + 1)
        # Update key buffer to match the pruned set
        kb = mx.stop_gradient(mx.take(full_k_enc, sm, axis=1))
        mx.eval(kb)
        key_buffers[layer_idx] = kb
        return attn.o_proj(out.transpose(0, 2, 1, 3).reshape(b, 1, -1))

    # ── Patched block forward ─────────────────────────────────────────────────
    def patched_call(self, x, mask=None, cache=None, **kw):
        lid = id(self)
        if lid not in sparse_set:
            return original_call(self, x, mask=mask, cache=cache, **kw)
        layer_idx, attn, indexer = sparse_set[lid]
        # A populated cache means prefill already ran for this layer.
        is_decode = cache is not None and cache.offset > 0
        r = self.input_layernorm(x)
        if is_decode:
            h = _sparse_decode(attn, indexer, r, cache, layer_idx)
        else:
            h, rk, rv = _sparse_prefill(attn, indexer, r, layer_idx)
            if cache is not None:
                cache.replace(rk, rv, offset=r.shape[1])
        x = x + h
        return x + self.mlp(self.post_attention_layernorm(x))

    def make_cache():
        caches = [KVCache() for _ in range(len(layers))]
        for li in indexers:
            caches[li] = SparseKVCache()
        return caches

    def clear_buffers():
        key_buffers.clear()

    def restore():
        block_cls.__call__ = original_call
        if had_make_cache:
            model.make_cache = original_make_cache
        elif hasattr(model, "make_cache"):
            delattr(model, "make_cache")
        key_buffers.clear()

    block_cls.__call__ = patched_call
    model.make_cache = make_cache
    # Previously `restore` was unreachable; expose it without changing the return type.
    clear_buffers.restore = restore
    print(f"Patched {len(sparse_set)} MLX layers (top_k={top_k}), fixed-size decode cache")
    return clear_buffers


# ── GSM8K evaluation ──────────────────────────────────────────────────────────

GSM8K_4SHOT_RAW = [
    (
        "Weng earns $12 an hour for babysitting. Yesterday, she just did 50 minutes of babysitting. "
        "How much did she earn?",
        "Weng earns 12/60 = $0.2 per minute. Working 50 minutes, she earned 0.2 x 50 = $10. "
        "The answer is 10.",
    ),
    (
        "Betty is saving money for a new wallet which costs $100. Betty has only half of the money "
        "she needs. Her parents decided to give her $15 for that purpose, and her grandparents twice "
        "as much as her parents. How much more money does Betty need to buy the wallet?",
        "In the beginning, Betty has only 100/2 = $50. Betty's grandparents gave her 15*2 = $30. "
        "Betty has 50+15+30 = $95. Betty needs 100-95 = $5 more. The answer is 5.",
    ),
    (
        "Julie is reading a 120-page book. Yesterday, she was able to read 12 pages and today, she "
        "read twice as many pages as yesterday. If she wants to read half of the remaining pages "
        "tomorrow, how many pages should she read tomorrow?",
        "Maila read 12 x 2 = 24 pages today. So she was able to read a total of 12+24 = 36 pages "
        "since yesterday. There are 120-36 = 84 pages left to be read. Since she wants to read half "
        "of the remaining pages, she should read 84/2 = 42 pages tomorrow. The answer is 42.",
    ),
    (
        "James writes a 3-page letter to 2 different friends twice a week. How many pages does he "
        "write a year?",
        "He writes each friend 3*2=6 pages a week. So he writes 6*2=12 pages a week. "
        "That means he writes 12*52=624 pages a year. The answer is 624.",
    ),
]


def _build_gsm8k_prompt(question: str, tokenizer) -> str:
    """Build GSM8K prompt using chat template (matches PyTorch eval that achieved 92%).

    Without the chat template, Qwen3 continues generating new Q&A pairs after the
    answer, causing extract_number to pick a number from the spurious continuation
    rather than the actual answer.
    """
    if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
        messages = []
        for q, a in GSM8K_4SHOT_RAW:
            messages.append({"role": "user", "content": q})
            messages.append({"role": "assistant", "content": a})
        messages.append({"role": "user", "content": question})
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    # Fallback: raw 4-shot (not recommended for Qwen3 — model never generates EOS)
    shots = "\n\n".join(
        f"Question: {q}\nAnswer: {a}" for q, a in GSM8K_4SHOT_RAW
    )
    return shots + f"\n\nQuestion: {question}\nAnswer:"


_NUMBER_RE = re.compile(r"-?\d+(?:\.\d+)?")


def extract_number(text: str):
    """Return the last number in ``text`` (thousands separators removed), or None."""
    nums = _NUMBER_RE.findall(text.replace(",", ""))
    return nums[-1] if nums else None


def _answers_match(pred, gold) -> bool:
    """True when two extracted answers denote the same number (so "10.0" matches "10")."""
    if pred is None or gold is None:
        return False
    try:
        return math.isclose(float(pred), float(gold), rel_tol=0.0, abs_tol=1e-6)
    except ValueError:
        return str(pred) == str(gold)


def _eos_ids(tokenizer) -> set[int]:
    """All stop-token ids: the wrapper's ``eos_token_ids`` (if any) plus ``eos_token_id``."""
    ids = {int(i) for i in (getattr(tokenizer, "eos_token_ids", None) or ())}
    if tokenizer.eos_token_id is not None:
        ids.add(int(tokenizer.eos_token_id))
    return ids


def eval_gsm8k(model, tokenizer, limit: int, max_new_tokens: int, clear_fn=None):
    """Greedy-decode up to ``limit`` GSM8K test problems and score the final number.

    Returns:
        (accuracy, number of problems evaluated, elapsed seconds)
    """
    from datasets import load_dataset

    ds = load_dataset("openai/gsm8k", "main", split="test", trust_remote_code=True)
    examples = list(ds)[:limit]
    n_examples = len(examples)  # may be < limit if the split is shorter
    stop_ids = _eos_ids(tokenizer)
    correct, total = 0, 0
    t0 = time.time()
    for ex in examples:
        if clear_fn:
            clear_fn()
        prompt = _build_gsm8k_prompt(ex["question"], tokenizer)
        input_ids = mx.array(tokenizer.encode(prompt))
        tokens = []
        for token, _ in generate_step(
            input_ids,
            model,
            max_tokens=max_new_tokens,
            sampler=lambda x: mx.argmax(x, axis=-1),
            # Whole prompt in one step, so the sparse prefill sees the full context at once.
            prefill_step_size=max(1, int(input_ids.shape[0])),
        ):
            t = int(token.item()) if hasattr(token, "item") else int(token)
            if t in stop_ids:
                break
            tokens.append(t)
        generated = tokenizer.decode(tokens)
        pred = extract_number(generated)
        gold = extract_number(ex["answer"])
        match = _answers_match(pred, gold)
        correct += int(match)
        total += 1
        print(f"  [{total}/{n_examples}] gen={len(tokens)} pred={pred} gold={gold} "
              f"{'OK' if match else 'X'} ({time.time() - t0:.0f}s)", flush=True)
    elapsed = time.time() - t0
    acc = correct / total if total else 0.0
    print(f"\n  GSM8K: {correct}/{total} = {acc:.3f} ({elapsed:.1f}s)")
    return acc, total, elapsed
# ── CLI ───────────────────────────────────────────────────────────────────────


def main():
    """Parse CLI args, load the model (optionally patching in sparse attention), run tasks, report."""
    p = argparse.ArgumentParser(
        description="Sparse attention evaluation for Qwen3-8B on Apple Silicon.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""\
Examples:
  # Dense baseline (50 GSM8K problems)
  python eval_sparse_generate.py --limit 50

  # Sparse with fixed-2K decode cache
  python eval_sparse_generate.py --limit 50 \\
      --indexer-path lightning_indexer_best_assembled.safetensors \\
      --run-config run_config.json

  # Different model / top-k
  python eval_sparse_generate.py --model mlx-community/Qwen3-8B-4bit \\
      --top-k 1024 --indexer-path lightning_indexer_best_assembled.safetensors \\
      --run-config run_config.json --limit 100
""",
    )
    p.add_argument("--model", default="mlx-community/Qwen3-8B-4bit")
    p.add_argument("--indexer-path", default="", help="empty = dense baseline")
    p.add_argument("--run-config", default="run_config.json")
    p.add_argument("--top-k", type=int, default=0, help="0 = use top_k from run config (default 2048)")
    p.add_argument("--tasks", default="gsm8k", help="comma-separated; currently only gsm8k")
    p.add_argument("--limit", type=int, default=10)
    p.add_argument("--max-new-tokens", type=int, default=4096)
    p.add_argument("--out", default="", help="optional path for a JSON results file")
    p.add_argument("--seed", type=int, default=42)
    args = p.parse_args()

    # Seed both RNGs (decoding is greedy, but any MLX-side randomness is covered too).
    np.random.seed(args.seed)
    mx.random.seed(args.seed)

    print(f"Loading {args.model} ...")
    model, tokenizer = mlx_load(args.model, tokenizer_config={"trust_remote_code": True})
    mx.eval(model.parameters())
    mx.synchronize()

    clear_fn = None
    sparse = bool(args.indexer_path)
    mode = "sparse" if sparse else "dense"
    if sparse:
        rc_path = Path(args.run_config)
        if not rc_path.exists():
            raise FileNotFoundError(f"run_config not found: {rc_path}")
        rc = json.loads(rc_path.read_text())
        # All settings go through _config_value so a nested "metadata" section is honoured uniformly.
        dim = _config_value(rc, "hidden_size", 4096)
        proj_dim = _config_value(rc, "proj_dim", 69)
        n_heads = _config_value(rc, "indexer_heads", 6)
        rope_dim = _config_value(rc, "rope_dim", 64)
        tk = args.top_k if args.top_k > 0 else _config_value(rc, "top_k", 2048)
        print(f"Loading indexers from {args.indexer_path} ...")
        indexers = load_indexers(args.indexer_path, dim, proj_dim, n_heads, rope_dim)
        mx.eval([idx.parameters() for idx in indexers.values()])
        mx.synchronize()
        clear_fn = patch_sparse_generate(model, indexers, tk)
        print(f"Mode: SPARSE (top_k={tk}, {len(indexers)} layers)")
    else:
        print("Mode: DENSE")

    results = {
        "model": args.model,
        "mode": mode,
        "limit": args.limit,
        "max_new_tokens": args.max_new_tokens,
    }
    if sparse:
        # Record the sparse configuration so saved results are self-describing.
        results["top_k"] = tk
        results["indexer_path"] = args.indexer_path

    tasks = [t.strip() for t in args.tasks.split(",") if t.strip()]
    for task in tasks:
        if task == "gsm8k":
            acc, n, el = eval_gsm8k(model, tokenizer, args.limit, args.max_new_tokens, clear_fn)
            results["gsm8k"] = {"acc": acc, "n": n, "elapsed_s": round(el, 1)}
        else:
            # Previously unknown tasks were dropped silently.
            print(f"WARNING: unknown task {task!r} — skipped")

    print(f"\n── SUMMARY ({mode}) ──")
    for t in tasks:
        if t in results and isinstance(results[t], dict):
            print(f"  {t:20s} {results[t]['acc']:.4f}  (n={results[t]['n']})")

    if args.out:
        Path(args.out).write_text(json.dumps(results, indent=2))
        print(f"Saved to {args.out}")


if __name__ == "__main__":
    main()