#!/usr/bin/env python3 """ Quick demo: dense vs sparse attention on Qwen3-8B (Apple Silicon / MLX). Usage: python demo.py # three built-in sample prompts python demo.py --prompt "What is entropy?" # single custom prompt python demo.py --interactive # interactive chat REPL """ from __future__ import annotations import argparse import json import sys import time from pathlib import Path REPO_DIR = Path(__file__).resolve().parent SAMPLE_PROMPTS = [ "Explain how a transformer model processes a long input sequence, step by step.", "A train leaves city A at 9 AM traveling at 60 mph. Another train leaves city B " "(300 miles away) at 10 AM traveling at 80 mph toward city A. When do they meet?", "Write a Python function that finds the longest palindromic substring in a string.", ] def _config_value(config: dict, key: str, default: int) -> int: metadata = config.get("metadata", {}) return int(config.get(key, metadata.get(key, default))) def _check_deps() -> None: missing = [] for pkg in ["mlx", "mlx_lm", "safetensors", "numpy"]: try: __import__(pkg) except ImportError: missing.append(pkg) if missing: print(f"Missing packages: {', '.join(missing)}") print(f"Install with: pip install -r {REPO_DIR / 'requirements.txt'}") sys.exit(1) def _find_indexer() -> Path: path = REPO_DIR / "lightning_indexer_best_assembled.safetensors" if path.exists(): return path print("Error: No indexer checkpoint found.") print("Expected: lightning_indexer_best_assembled.safetensors") sys.exit(1) def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser( description="Dense vs sparse attention demo on Qwen3-8B (MLX / Apple Silicon).", ) p.add_argument("--prompt", default="", help="Single prompt. Omit to run built-in samples.") p.add_argument("--interactive", action="store_true", help="Interactive chat REPL.") p.add_argument("--max-new-tokens", type=int, default=256) p.add_argument("--model", default="mlx-community/Qwen3-8B-4bit") p.add_argument("--top-k", type=int, default=0) return p.parse_args() def _generate(model, tokenizer, prompt: str, max_new_tokens: int) -> tuple[str, float, int]: import mlx.core as mx from mlx_lm.generate import generate_step input_ids = mx.array(tokenizer.encode(prompt)) tokens = [] t0 = time.perf_counter() for tok, _ in generate_step( input_ids, model, max_tokens=max_new_tokens, sampler=lambda x: mx.argmax(x, axis=-1), prefill_step_size=int(input_ids.shape[0]), ): t = int(tok.item()) if hasattr(tok, 'item') else int(tok) if t == tokenizer.eos_token_id: break tokens.append(t) elapsed = time.perf_counter() - t0 return tokenizer.decode(tokens), elapsed, len(tokens) def main() -> None: args = parse_args() print("Checking dependencies...") _check_deps() import mlx.core as mx from mlx_lm.utils import load as mlx_load from eval_sparse_generate import load_indexers, patch_sparse_generate config_path = REPO_DIR / "run_config.json" if not config_path.exists(): print(f"Error: run_config.json not found at {config_path}") sys.exit(1) rc = json.loads(config_path.read_text()) indexer_path = _find_indexer() dim = _config_value(rc, "hidden_size", 4096) proj_dim = int(rc.get("proj_dim", 69)) n_heads = int(rc.get("indexer_heads", 6)) rope_dim = int(rc.get("rope_dim", 64)) top_k = args.top_k or int(rc.get("top_k", 2048)) print(f"\nModel: {args.model}") print(f"Indexer: {indexer_path.name}") print(f"top_k: {top_k} ({len(rc.get('sparse_layers', []))} layers)") print("\nLoading model (first run downloads ~5 GB)...") model, tokenizer = mlx_load(args.model, tokenizer_config={"trust_remote_code": True}) mx.eval(model.parameters()); mx.synchronize() # Save the original (dense) block call before patching block_cls = model.model.layers[0].__class__ dense_call = block_cls.__call__ print(f"Loading indexers from {indexer_path.name} ...") indexers = load_indexers(str(indexer_path), dim, proj_dim, n_heads, rope_dim) mx.eval([idx.parameters() for idx in indexers.values()]); mx.synchronize() clear_fn = patch_sparse_generate(model, indexers, top_k) sparse_call = block_cls.__call__ # patched version def run_dense(prompt_text: str) -> tuple[str, float, int]: block_cls.__call__ = dense_call full = f"<|im_start|>user\n{prompt_text}<|im_end|>\n<|im_start|>assistant\n" return _generate(model, tokenizer, full, args.max_new_tokens) def run_sparse(prompt_text: str) -> tuple[str, float, int]: block_cls.__call__ = sparse_call clear_fn() full = f"<|im_start|>user\n{prompt_text}<|im_end|>\n<|im_start|>assistant\n" return _generate(model, tokenizer, full, args.max_new_tokens) def compare(prompt_text: str) -> None: print(f"\nPrompt: {prompt_text[:120]}{'...' if len(prompt_text) > 120 else ''}") dense_text, dense_t, dense_n = run_dense(prompt_text) print(f"\n[DENSE] {dense_n} tok / {dense_t:.1f}s ({dense_n/max(dense_t,0.1):.0f} tok/s)") print(dense_text[:400]) sparse_text, sparse_t, sparse_n = run_sparse(prompt_text) print(f"\n[SPARSE] {sparse_n} tok / {sparse_t:.1f}s ({sparse_n/max(sparse_t,0.1):.0f} tok/s)") print(sparse_text[:400]) match = dense_text.strip() == sparse_text.strip() print(f"\n Exact match: {'YES' if match else 'no (slight divergence, expected)'}") print("-" * 60) if args.interactive: print("\n" + "=" * 60) print("Interactive compare mode — /quit to exit") print("=" * 60) while True: try: user_text = input("\nYou> ").strip() except (EOFError, KeyboardInterrupt): print("\nBye!") break if not user_text or user_text in {"/quit", "quit", "exit"}: break compare(user_text) return prompts = [args.prompt] if args.prompt else SAMPLE_PROMPTS print(f"\nRunning {len(prompts)} prompt(s)...\n" + "=" * 60) for i, p in enumerate(prompts, 1): print(f"\n[{i}/{len(prompts)}]") compare(p) print("\nDone. Sparse outputs should closely match dense.") print("For GSM8K accuracy eval: python eval_sparse_generate.py --limit 100") if __name__ == "__main__": main()