# taf-agent / python/taf_browser.py
# Author: karlexmarin (Hugging Face)
# Commit 6ab0441 — "feat: TAF Agent v0.1 — client-side transformer diagnostic" (39.9 kB)
"""
TAF Browser — Pyodide-compatible TAF formulas + recipes.
Pure-Python deterministic computations of TAF (Thermodynamic Attention Framework)
formulas, plus 5 cross-section recipes for the most common viability questions.
Author: Carles Marin <transformerkmarin@gmail.com>
License: Apache-2.0
"""
from __future__ import annotations
import math
import json
# ════════════════════════════════════════════════════════════════════════════
# §26 — γ-Thermodynamics (OUR contribution)
# ════════════════════════════════════════════════════════════════════════════
def gamma_pade(theta: float, T_eval: int) -> float:
    """§26.1 — Padé estimate of γ at evaluation length ``T_eval``.

    γ = (2θ - T√2) / (2θ + T√2).  For positive θ and T_eval the result is
    positive only while 2θ exceeds T_eval·√2.
    """
    two_theta = 2 * theta
    scaled_len = T_eval * math.sqrt(2)
    return (two_theta - scaled_len) / (two_theta + scaled_len)
def gamma_decompose(gamma_pade_val, has_GQA=False, has_SWA=False, n_params=0.0) -> dict:
    """§26.10 — additive architecture corrections to the Padé γ (n=23 OLS, paper session 28).

    Only three additive shifts are applied here:
      δ_GQA     = +0.11 when grouped-query attention is used,
      δ_SWA     = -0.21 when sliding-window attention is used,
      δ_post_IH = -0.15 when n_params >= 4e8.

    Returns a dict with the Padé centroid, each shift, and their sum
    ``gamma_corrected`` (in that key order).
    """
    shifts = {
        "delta_GQA": 0.11 if has_GQA else 0.0,
        "delta_SWA": -0.21 if has_SWA else 0.0,
        "delta_post_IH": -0.15 if n_params >= 4e8 else 0.0,
    }
    corrected = (gamma_pade_val + shifts["delta_GQA"] + shifts["delta_SWA"]
                 + shifts["delta_post_IH"])
    return {"pade_centroid": gamma_pade_val, **shifts, "gamma_corrected": corrected}
def d_horizon(theta: float, gamma: float):
    """§26.2 — attention horizon d_h = θ(1-γ)√2/(1+γ).

    Returns None when γ <= 0 or γ >= 1 (outside Phase A).
    """
    outside_phase_a = gamma <= 0 or gamma >= 1
    if outside_phase_a:
        return None
    return theta * (1 - gamma) * math.sqrt(2) / (1 + gamma)
def l_niah_c(d_horizon_val):
    """§26.5 — critical NIAH length L_NIAH^c = 2·d_horizon; a None horizon stays None."""
    if d_horizon_val is None:
        return None
    return d_horizon_val * 2
def chi_susceptibility(gamma: float) -> float:
    """§26.16 — susceptibility χ = 1/|γ-1|, which is +inf exactly at γ = 1."""
    if gamma != 1.0:
        return 1.0 / abs(gamma - 1.0)
    return float('inf')
def p_hallucinate(L: int, theta: float, gamma: float):
    """§26.9 — probability that a length-L query overshoots the attention horizon.

    P = max(0, 1 - (d_h/L)^(1-γ)) · √χ / (1 + √χ)

    Returns None when d_horizon is undefined (γ outside (0, 1)), when L <= 0,
    or when χ is infinite.
    """
    horizon = d_horizon(theta, gamma)
    if horizon is None or L <= 0:
        return None
    susceptibility = chi_susceptibility(gamma)
    if susceptibility == float('inf'):
        return None
    root_chi = math.sqrt(susceptibility)
    overshoot = max(0.0, 1.0 - (horizon / L) ** (1 - gamma))
    return overshoot * (root_chi / (1 + root_chi))
def theta_design(gamma_target: float, T_eval: int) -> float:
    """§26.3 — θ that makes the Padé γ equal ``gamma_target`` at length ``T_eval``.

    Inverse of γ = (2θ - T√2)/(2θ + T√2):  θ = T√2(1+γ) / (2(1-γ)).

    Raises:
        ValueError: if gamma_target <= -1 or gamma_target >= 1.
    """
    if gamma_target <= -1 or gamma_target >= 1:
        raise ValueError("gamma_target must be in (-1, 1)")
    numerator = T_eval * math.sqrt(2) * (1 + gamma_target)
    return numerator / (2 * (1 - gamma_target))
def alpha_opt(gamma_target: float, T_eval: int, theta_nominal: float) -> float:
    """§26.4 — scaling factor α = θ_design / θ_nominal.

    α > 1 means the θ required for ``gamma_target`` at ``T_eval`` exceeds the
    model's nominal θ.
    """
    required_theta = theta_design(gamma_target, T_eval)
    return required_theta / theta_nominal
def df_window(gamma: float, N: int, f: float = 0.90):
    """§26.7 — KV compression window D_f for a length-N context.

    D_f = ceil( [(1-f) + f·N^(1-γ)]^(1/(1-γ)) )

    Only defined for γ in the closed interval [0.65, 0.85]; returns None
    otherwise (including NaN γ).

    Returns:
        int window length, or None outside the validity zone.
    """
    if not (0.65 <= gamma <= 0.85):
        return None
    # γ <= 0.85 here, so 1-γ >= 0.15 and the exponent 1/(1-γ) is always finite.
    # The former `gamma >= 1` fallback could never execute and was removed.
    exponent = 1 - gamma
    inner = (1 - f) + f * N ** exponent
    return math.ceil(inner ** (1 / exponent))
def kv_soft_decay_regime(theta: float, gamma: float, T_train: int) -> str:
    """§26.8 — does soft KV decay apply? (régimen bound d_h ≳ T_train/2)

    Compares d_horizon with half the training length (floored at 1 to avoid
    division by zero) and returns:
      "applies"         when the ratio >= 1.2,
      "borderline"      when the ratio >= 0.8,
      "use-hard-cutoff" otherwise, or when d_horizon is undefined.
    """
    horizon = d_horizon(theta, gamma)
    if horizon is None:
        return "use-hard-cutoff"
    ratio = horizon / max(1, T_train / 2)
    for threshold, label in ((1.2, "applies"), (0.8, "borderline")):
        if ratio >= threshold:
            return label
    return "use-hard-cutoff"
# ════════════════════════════════════════════════════════════════════════════
# §17 — Pre-training viability formulas
# ════════════════════════════════════════════════════════════════════════════
def chinchilla_optimal_tokens(N_params: float, ratio: float = 20.0) -> float:
    """§17.30 — compute-optimal token budget D = ratio · N (Chinchilla default: 20 tokens/param)."""
    tokens = ratio * N_params
    return tokens
def chinchilla_optimal_N(D_tokens: float, ratio: float = 20.0) -> float:
    """§17.30 inverse — parameter count matched to a D_tokens budget: N = D / ratio."""
    params = D_tokens / ratio
    return params
def training_flops(N_params: float, D_tokens: float) -> float:
    """§17.10 — approximate total training compute C ≈ 6·N·D FLOPs."""
    flops_per_token = 6 * N_params  # 6 FLOPs per parameter per training token
    return flops_per_token * D_tokens
def training_memory_16N(N_params: float) -> dict:
    """§17.20 — training memory ≈ 16 bytes/parameter (weights + gradients + Adam moments).

    Returns {"bytes": total_bytes, "GB": total_bytes / 1e9}.
    """
    bytes_per_param = 16
    total = bytes_per_param * N_params
    return {"bytes": total, "GB": total / 1e9}
def emergent_threshold(N_params: float) -> str:
    """§17.60 — coarse capability tier for a model size (Wei 2022 heuristic).

    Tiers are checked from largest to smallest; the first floor that
    N_params reaches wins.
    """
    tiers = (
        (1e11, "above 100B — strong reasoning capabilities expected"),
        (1e10, "above 10B — most emergent capabilities present"),
        (1e9, "above 1B — basic instruction-following, not strong reasoning"),
        (1e8, "above 100M — useful for narrow tasks, no emergence"),
    )
    for floor, label in tiers:
        if N_params >= floor:
            return label
    return "below 100M — domain-specific tasks only"
# ════════════════════════════════════════════════════════════════════════════
# §19 — Inference economics
# ════════════════════════════════════════════════════════════════════════════
def kv_cache_memory(n_layers, n_kv_heads, d_head, seq_len, bytes_per_element=2.0) -> dict:
    """§19.1 — KV-cache size: 2 (K and V) · layers · kv_heads · d_head · seq_len · bytes.

    Returns the size in bytes, MB (1e6) and GB (1e9).
    """
    per_token = 2 * n_layers * n_kv_heads * d_head
    total = per_token * seq_len * bytes_per_element
    scaled = {unit: total / divisor for unit, divisor in (("MB", 1e6), ("GB", 1e9))}
    return {"bytes": total, **scaled}
def model_weights_memory(N_params, bytes_per_element=2.0) -> dict:
    """Weight-only inference memory in GB (bytes_per_element: BF16=2, INT8=1, INT4=0.5)."""
    size_bytes = N_params * bytes_per_element
    return {"GB": size_bytes / 1e9}
def inference_decode_throughput(N_params, hbm_GB_per_s, bytes_per_element=2.0) -> float:
    """§19.7 — memory-bound decode rate: tokens/s = HBM bandwidth / weight size (GB)."""
    weights_GB = N_params * bytes_per_element / 1e9
    tokens_per_second = hbm_GB_per_s / weights_GB
    return tokens_per_second
# ════════════════════════════════════════════════════════════════════════════
# §20 — Hardware catalog (curated from vendor docs 2026)
# ════════════════════════════════════════════════════════════════════════════
GPU_CATALOG = {
    # Per-GPU specs keyed by display name. Field meanings:
    #   flops   — peak BF16 TFLOP/s (cost_per_training_run multiplies by 1e12)
    #   vram_GB — on-board memory capacity in GB
    #   bw_GB_s — memory bandwidth in GB/s (drives the memory-bound decode estimates)
    #   usd_h   — cloud spot rental price in USD per GPU-hour; 0.0 marks a local card
    #   tdp     — board power in watts (not read by any function in this file)
    "H100 SXM": {"flops": 989, "vram_GB": 80, "bw_GB_s": 3350, "usd_h": 2.5, "tdp": 700},
    "H100 PCIe": {"flops": 756, "vram_GB": 80, "bw_GB_s": 2000, "usd_h": 2.0, "tdp": 350},
    "H200": {"flops": 989, "vram_GB": 141, "bw_GB_s": 4800, "usd_h": 3.5, "tdp": 700},
    "B200": {"flops": 2250, "vram_GB": 192, "bw_GB_s": 8000, "usd_h": 5.0, "tdp": 1000},
    "A100 80GB": {"flops": 312, "vram_GB": 80, "bw_GB_s": 2000, "usd_h": 1.2, "tdp": 400},
    "A100 40GB": {"flops": 312, "vram_GB": 40, "bw_GB_s": 1555, "usd_h": 1.0, "tdp": 400},
    "L40S": {"flops": 362, "vram_GB": 48, "bw_GB_s": 864, "usd_h": 0.7, "tdp": 350},
    "MI300X": {"flops": 1307, "vram_GB": 192, "bw_GB_s": 5300, "usd_h": 2.1, "tdp": 750},
    "RTX 4090": {"flops": 165, "vram_GB": 24, "bw_GB_s": 1008, "usd_h": 0.4, "tdp": 450},
    "RTX 5090": {"flops": 419, "vram_GB": 32, "bw_GB_s": 1792, "usd_h": 0.7, "tdp": 575},
    # NOTE(review): usd_h == 0.0 means "owned hardware"; any formula that divides
    # by usd_h (e.g. budget → hours) must special-case this entry.
    "RTX 5060Ti":{"flops": 36, "vram_GB": 16, "bw_GB_s": 448, "usd_h": 0.0, "tdp": 180},  # local
}
def cost_per_training_run(N_params: float, D_tokens: float, gpu: str = "H100 SXM",
                          n_gpus: int = 8, mfu: float = 0.45) -> dict:
    """§20.11 — wall-clock time and rental cost of one training run.

        hours = 6·N·D / (peak·MFU·n_gpus) / 3600
        USD   = hours · USD_per_GPU_hour · n_gpus

    Args:
        N_params: model parameters.
        D_tokens: training tokens.
        gpu: key into GPU_CATALOG.
        n_gpus: GPUs training in parallel (must be > 0).
        mfu: model-FLOPs utilisation as a fraction of peak (must be > 0).

    Returns:
        dict with total_FLOPs, hours, days, USD and the echoed gpu/n_gpus/mfu,
        or an ``{"error": ...}`` dict for an unknown GPU (with "available"
        names) or non-positive n_gpus/mfu.
    """
    info = GPU_CATALOG.get(gpu)
    if info is None:
        return {"error": f"unknown gpu '{gpu}'", "available": list(GPU_CATALOG.keys())}
    if n_gpus <= 0 or mfu <= 0:
        # Used to raise ZeroDivisionError (0) or report negative cost (< 0).
        return {"error": f"n_gpus and mfu must be positive (got n_gpus={n_gpus}, mfu={mfu})"}
    total_flops = training_flops(N_params, D_tokens)  # absolute FLOPs
    effective_flops_per_sec = info["flops"] * 1e12 * mfu * n_gpus
    seconds = total_flops / effective_flops_per_sec
    hours = seconds / 3600  # wall-clock: all n_gpus run concurrently
    usd = hours * info["usd_h"] * n_gpus  # every GPU is billed for the whole run
    return {
        "total_FLOPs": total_flops,
        "hours": hours,
        "days": hours / 24,
        "USD": usd,
        "gpu": gpu, "n_gpus": n_gpus, "mfu": mfu,
    }
def cost_per_inference_token(model_GB: float, gpu: str, batch: int = 1) -> dict:
    """§19.9 / §20.12 — $/Mtok derived from memory-bound decode on one GPU.

    tok/s = bw_GB_s / model_GB · batch (assumes aggregate throughput scales
    linearly with batch), then USD per million tokens at the GPU's hourly price.

    Returns:
        dict with tok_per_sec, USD_per_Mtok, gpu, batch; or an
        ``{"error": ...}`` dict for an unknown GPU (with "available" names,
        matching cost_per_training_run) or non-positive model_GB/batch.
    """
    info = GPU_CATALOG.get(gpu)
    if info is None:
        return {"error": f"unknown gpu '{gpu}'", "available": list(GPU_CATALOG.keys())}
    if model_GB <= 0 or batch <= 0:
        # Previously a ZeroDivisionError (model_GB == 0 or batch == 0).
        return {"error": f"model_GB and batch must be positive (got model_GB={model_GB}, batch={batch})"}
    tok_per_sec = info["bw_GB_s"] / model_GB * batch
    sec_per_Mtok = 1e6 / tok_per_sec
    h_per_Mtok = sec_per_Mtok / 3600
    usd_per_Mtok = h_per_Mtok * info["usd_h"]
    return {
        "tok_per_sec": tok_per_sec,
        "USD_per_Mtok": usd_per_Mtok,
        "gpu": gpu, "batch": batch,
    }
# ════════════════════════════════════════════════════════════════════════════
# §24 — Cost / ROI
# ════════════════════════════════════════════════════════════════════════════
API_PRICING = {
    # Frontier-API list prices in USD per million tokens, with input and output
    # priced separately. run_recipe_x1 averages the two 50/50 into one blended
    # rate. NOTE(review): list prices change often — confirm before relying on them.
    "GPT-4o": {"input": 2.5, "output": 10.0},
    "GPT-4o-mini": {"input": 0.15, "output": 0.60},
    "Claude-Opus-4": {"input": 15.0, "output": 75.0},
    "Claude-Sonnet-4":{"input": 3.0, "output": 15.0},
    "Claude-Haiku-4": {"input": 0.80, "output": 4.0},
    "Gemini-1.5-Pro": {"input": 1.25, "output": 5.0},
    "DeepSeek-V3": {"input": 0.27, "output": 1.10},
    "Llama-3.3-70B (Together)": {"input": 0.88, "output": 0.88},
}
def break_even_volume(training_cost: float, self_inference_per_Mtok: float,
                      api_per_Mtok: float, blend_input_output: float = 0.5) -> dict:
    """§24.3 — cumulative token volume after which a one-off training spend is recouped.

    Every million tokens served in-house instead of through the API saves
    ``api_per_Mtok - self_inference_per_Mtok`` USD; break-even is reached once
    those savings add up to ``training_cost``. The result is a total volume,
    not a monthly rate.

    ``blend_input_output`` is accepted but not used: callers pass prices that
    are already blended (as run_recipe_x1 does).

    Returns:
        {"savings_per_Mtok", "Mtok_breakeven", "tokens_breakeven"}, or an
        ``{"error": ...}`` dict when self-hosting is not cheaper per token.
    """
    margin = api_per_Mtok - self_inference_per_Mtok
    if margin <= 0:
        return {"error": "self-host more expensive than API per token; never breaks even"}
    mtok = training_cost / margin
    return dict(savings_per_Mtok=margin, Mtok_breakeven=mtok, tokens_breakeven=mtok * 1e6)
# ════════════════════════════════════════════════════════════════════════════
# RECIPES
# ════════════════════════════════════════════════════════════════════════════
# ─────────────────────────────────────────────────────────────────────
# X-2 — Long Context Viability
# ─────────────────────────────────────────────────────────────────────
def run_recipe_x2(theta, T_train, T_eval, n_attention_heads, n_kv_heads,
                  d_head, n_layers, n_params, has_SWA=False,
                  bytes_per_element=2.0, **_unused):
    """X-2: will model M serve length L doing NIAH retrieval?

    Chain: γ_Padé at T_eval → architecture-corrected γ → d_horizon →
    L_NIAH^c = 2·d_horizon → P_hallucinate at L = T_eval → KV-cache size.

    Verdict:
      "NO"       — corrected γ outside (0, 1); the mitigation quotes the θ and
                   α needed for γ = 0.85 at T_eval (α > 8 ⇒ "fine-tuning required").
      "YES"      — T_eval < d_horizon.
      "DEGRADED" — d_horizon <= T_eval < L_NIAH^c.
      "NO"       — T_eval >= L_NIAH^c.

    ``T_train`` is accepted (it is in the registered parameter list) but is not
    used by this recipe. ``**_unused`` swallows any extra keyword arguments.

    NOTE: ``locals()`` is handed to ``_wrap``, which reports every local not in
    its exclusion list as an "input" — renaming or adding locals changes the
    result payload.
    """
    chain = []
    # Step 1: Padé γ at the evaluation length.
    g_pade = gamma_pade(theta, T_eval)
    chain.append(_step(1, "§26.1", "γ_Padé", "γ = (2θ - T√2)/(2θ + T√2)",
                       {"theta": theta, "T_eval": T_eval}, g_pade,
                       _phase_label(g_pade)))
    # Step 2: architecture corrections; GQA is inferred from fewer KV heads than query heads.
    has_GQA = (n_kv_heads < n_attention_heads)
    decomp = gamma_decompose(g_pade, has_GQA=has_GQA, has_SWA=has_SWA, n_params=n_params)
    g_corr = decomp["gamma_corrected"]
    chain.append(_step(2, "§26.10", "γ-decomposition", "γ + δ_GQA + δ_SWA + δ_post_IH",
                       {"has_GQA": has_GQA, "has_SWA": has_SWA, "n_params": n_params},
                       g_corr, breakdown=decomp))
    # Step 3: horizon from the corrected γ (None outside Phase A).
    dh = d_horizon(theta, g_corr)
    chain.append(_step(3, "§26.2", "d_horizon", "d_h = θ(1-γ)√2/(1+γ)",
                       {"theta": theta, "gamma": g_corr}, dh,
                       "n/a — γ outside (0,1)" if dh is None else f"horizon at d={dh:.0f}"))
    # Step 4: NIAH ceiling = 2·d_horizon.
    l_niah = l_niah_c(dh)
    chain.append(_step(4, "§26.5", "L_NIAH^c", "L_NIAH^c = 2·d_horizon",
                       {"d_horizon": dh}, l_niah,
                       "n/a" if l_niah is None else f"NIAH 50% at L={l_niah:.0f}"))
    # Step 5: overshoot probability evaluated at L = T_eval.
    p_hallu = p_hallucinate(T_eval, theta, g_corr)
    chain.append(_step(5, "§26.9", "P_hallucinate", "max(0,1-(d_h/L)^(1-γ))·√χ/(1+√χ)",
                       {"L": T_eval, "theta": theta, "gamma": g_corr}, p_hallu,
                       "n/a (Phase B)" if p_hallu is None else f"{p_hallu*100:.1f}% predicted"))
    # Step 6: KV-cache footprint of one T_eval-long request.
    kv = kv_cache_memory(n_layers, n_kv_heads, d_head, T_eval, bytes_per_element)
    chain.append(_step(6, "§19.1", "KV cache memory", "2·L·n_kv·d_h·seq·B",
                       {"n_layers": n_layers, "n_kv_heads": n_kv_heads, "d_head": d_head,
                        "seq_len": T_eval, "bytes_per_element": bytes_per_element},
                       kv, f"{kv['GB']:.2f} GB per request"))
    # Verdict. dh (and therefore l_niah) is non-None in every branch after the first.
    if g_corr <= 0 or g_corr >= 1:
        verdict, reason = "NO", "Phase B / geometric collapse (γ_corrected outside (0,1))"
        mit = (f"Apply NTK-aware extension. Required θ for γ=0.85: "
               f"{theta_design(0.85, T_eval):,.0f}. α_opt = {alpha_opt(0.85, T_eval, theta):.2f} "
               f"({'fine-tuning required' if alpha_opt(0.85, T_eval, theta) > 8 else 'zero-shot may work'}).")
    elif dh is not None and T_eval < dh:
        margin = (1 - T_eval / dh) * 100
        verdict, reason = "YES", f"L={T_eval} inside d_horizon={dh:.0f} ({margin:.0f}% margin)."
        mit = "None required."
    elif dh is not None and T_eval < l_niah:
        verdict, reason = "DEGRADED", f"L between d_horizon ({dh:.0f}) and L_NIAH^c ({l_niah:.0f})."
        mit = "Consider context contraction OR NTK extension."
    else:
        verdict, reason = "NO", f"L={T_eval} exceeds NIAH ceiling {l_niah:.0f}."
        mit = f"Apply NTK extension; need θ ≈ {theta_design(0.85, T_eval):,.0f} for γ=0.85."
    return _wrap("X-2", "Long Context Viability", locals(), chain, verdict, reason, mit)
# ─────────────────────────────────────────────────────────────────────
# X-1 — Custom training vs API for a domain task
# ─────────────────────────────────────────────────────────────────────
def run_recipe_x1(N_params, D_tokens=None, gpu="H100 SXM", n_gpus=8, mfu=0.45,
                  api_model="GPT-4o", monthly_tokens_M=10.0, **_unused):
    """X-1: custom training (Chinchilla-optimal) vs paying a frontier API.

    Chain: Chinchilla D → training FLOPs → training $ → self-hosted $/Mtok
    (BF16 weights, batch 1) → API blended $/Mtok → break-even volume.

    Verdict is "YES (custom)" only when ``monthly_tokens_M`` is at least the
    break-even volume, i.e. the training spend is recouped within one month;
    otherwise "NO (API)". An unknown ``api_model`` silently falls back to
    2.0/8.0 USD per Mtok.

    If the training-cost or inference-cost helper returns an ``{"error": ...}``
    dict (e.g. unknown ``gpu``), that dict is returned with ``recipe_id`` added
    instead of raising KeyError.

    NOTE: ``locals()`` is handed to ``_wrap``; any local it does not exclude
    (check that ``flops`` is listed there) is reported as an "input".
    """
    chain = []
    # Step 1: Chinchilla optimal D (only when the caller did not pin D_tokens)
    if D_tokens is None:
        D_tokens = chinchilla_optimal_tokens(N_params)
    chain.append(_step(1, "§17.30", "Chinchilla optimal D", "D = 20·N",
                       {"N_params": N_params}, D_tokens,
                       f"recommended D = {D_tokens:.2e} tokens"))
    # Step 2: training FLOPs
    flops = training_flops(N_params, D_tokens)
    chain.append(_step(2, "§17.10", "Training FLOPs", "C = 6·N·D",
                       {"N": N_params, "D": D_tokens}, flops,
                       f"{flops:.2e} FLOPs total"))
    # Step 3: training cost
    cost = cost_per_training_run(N_params, D_tokens, gpu=gpu, n_gpus=n_gpus, mfu=mfu)
    if "error" in cost:
        # Previously fell through and crashed on cost['USD'] with KeyError.
        return {"recipe_id": "X-1", **cost}
    chain.append(_step(3, "§20.11", "Training cost",
                       "hours·USD/h·n_gpus = total $",
                       {"gpu": gpu, "n_gpus": n_gpus, "mfu": mfu}, cost,
                       f"${cost['USD']:,.0f} over {cost['days']:.1f} days"))
    # Step 4: model_GB and decode throughput
    model_GB = N_params * 2 / 1e9  # BF16
    inf = cost_per_inference_token(model_GB, gpu, batch=1)
    if "error" in inf:
        return {"recipe_id": "X-1", **inf}
    chain.append(_step(4, "§19.9 / §20.12", "Self-inference $/Mtok",
                       "BW / model_GB → tok/s → $/Mtok",
                       {"model_GB": model_GB, "gpu": gpu}, inf,
                       f"${inf['USD_per_Mtok']:.2f} per million tokens (single user)"))
    # Step 5: API blended price (50/50 input/output)
    api = API_PRICING.get(api_model, {"input": 2.0, "output": 8.0})
    api_blend = (api["input"] + api["output"]) / 2
    chain.append(_step(5, "§24.X", f"{api_model} blended price",
                       "(input + output) / 2 USD/Mtok",
                       {"api_model": api_model}, api_blend,
                       f"${api_blend:.2f}/Mtok blended"))
    # Step 6: break-even
    be = break_even_volume(cost["USD"], inf["USD_per_Mtok"], api_blend)
    chain.append(_step(6, "§24.3", "Break-even tokens", "training$ / (api - self) = Mtok",
                       {"training_cost": cost["USD"]}, be,
                       _be_interp(be, monthly_tokens_M)))
    # Verdict
    if "error" in be:
        verdict, reason = "NO", be["error"]
        mit = f"Stick with {api_model} API."
    elif monthly_tokens_M >= be["Mtok_breakeven"]:
        # monthly_tokens_M > 0 here because Mtok_breakeven > 0.
        verdict = "YES (custom)"
        months_to_payoff = be["Mtok_breakeven"] / monthly_tokens_M
        reason = (f"At {monthly_tokens_M} M tokens/month, break-even in "
                  f"{months_to_payoff:.1f} months. Long-term custom is cheaper.")
        mit = f"Train at {gpu}×{n_gpus}; serve self-hosted."
    else:
        # A zero/negative monthly volume never pays off (used to raise ZeroDivisionError).
        months = (be["Mtok_breakeven"] / monthly_tokens_M
                  if monthly_tokens_M > 0 else float("inf"))
        verdict = "NO (API)"
        reason = (f"At {monthly_tokens_M} M tokens/month, break-even in "
                  f"{months:.1f} months — too slow.")
        mit = f"Use {api_model} API (cheaper for your volume)."
    return _wrap("X-1", "Custom training vs API", locals(), chain, verdict, reason, mit)
def _be_interp(be, monthly):
    """One-line interpretation of a break_even_volume() result for the X-1 chain.

    Error dicts pass their message through; otherwise the payoff time is
    computed with monthly volume floored at 0.001 to avoid division by zero.
    """
    if "error" not in be:
        breakeven = be["Mtok_breakeven"]
        payoff_months = breakeven / max(monthly, 0.001)
        return f"break-even at {breakeven:.0f} Mtok ({payoff_months:.1f} months at {monthly} M/mo)"
    return be["error"]
# ─────────────────────────────────────────────────────────────────────
# X-3 — Pre-flight check on $5K training budget
# ─────────────────────────────────────────────────────────────────────
def run_recipe_x3(USD_budget=5000.0, gpu="H100 SXM", mfu=0.45, n_gpus=1, **_unused):
    """X-3: given a USD budget, which Chinchilla-optimal model can be trained?

    Chain: affordable wall-clock hours on ``n_gpus`` rented GPUs → effective
    FLOPs → N = √(C/120) with D = 20N → capability tier → 16N memory check
    against a single GPU's VRAM.

    Verdict: "TINY-MODEL" when N < 1e8, "MEMORY-LIMITED" when 16N bytes exceed
    one GPU's VRAM, otherwise "GO".

    Returns an ``{"error": ...}`` dict (with ``recipe_id``) instead of raising
    when ``gpu`` is unknown, when the GPU has no hourly price (e.g. the local
    RTX 5060Ti at $0/h), or when n_gpus <= 0, mfu < 0 or USD_budget < 0.

    NOTE: ``locals()`` is handed to ``_wrap``; do not add locals casually.
    """
    chain = []
    info = GPU_CATALOG.get(gpu)
    if info is None:
        # Previously KeyError; now matches cost_per_training_run's error shape.
        return {"recipe_id": "X-3", "error": f"unknown gpu '{gpu}'",
                "available": list(GPU_CATALOG.keys())}
    if info["usd_h"] <= 0 or n_gpus <= 0 or mfu < 0 or USD_budget < 0:
        # $0/h or zero GPUs made the budget→hours step divide by zero; negative
        # values made the sqrt below fail with a math domain error.
        return {"recipe_id": "X-3",
                "error": (f"cannot size a run for USD_budget={USD_budget}, n_gpus={n_gpus}, "
                          f"mfu={mfu} on '{gpu}' (${info['usd_h']}/h): need a positive "
                          f"hourly price and n_gpus, and non-negative budget and mfu")}
    # Step 1: wall-clock hours the budget buys when renting n_gpus in parallel
    hours = USD_budget / (info["usd_h"] * n_gpus)
    chain.append(_step(1, "§20.11", "Affordable GPU-hours", "USD / ($/h·n_gpus)",
                       {"USD": USD_budget, "gpu": gpu, "n_gpus": n_gpus}, hours,
                       f"{hours:.0f} wall-clock hours on {n_gpus}× {gpu} = "
                       f"{hours * n_gpus:.0f} GPU-hours ({hours/24:.1f} days at full use)"))
    # Step 2: max FLOPs
    max_flops = info["flops"] * 1e12 * mfu * n_gpus * hours * 3600
    chain.append(_step(2, "§17.10", "Max training FLOPs",
                       "peak·MFU·n_gpus·seconds",
                       {"peak_TFLOPs": info["flops"], "MFU": mfu}, max_flops,
                       f"{max_flops:.2e} effective FLOPs"))
    # Step 3: Chinchilla-optimal N (with D=20N)
    # 6·N·D = max_flops, D=20N → 120·N² = max_flops → N = sqrt(max_flops/120)
    N_chinchilla = math.sqrt(max_flops / 120)
    D_chinchilla = 20 * N_chinchilla
    chain.append(_step(3, "§17.30", "Chinchilla-optimal N",
                       "N = √(C/120) at D=20N", {"max_FLOPs": max_flops},
                       N_chinchilla,
                       f"N ≈ {N_chinchilla:.2e} params with D = {D_chinchilla:.2e} tokens"))
    # Step 4: emergence check
    emerg = emergent_threshold(N_chinchilla)
    chain.append(_step(4, "§17.60", "Emergence threshold", "Wei 2022 capability",
                       {"N": N_chinchilla}, emerg, emerg))
    # Step 5: memory budget check (against ONE GPU's VRAM)
    mem = training_memory_16N(N_chinchilla)
    fits = mem["GB"] <= info["vram_GB"]
    chain.append(_step(5, "§17.20", "16N training memory",
                       "model + grads + AdamW",
                       {"N": N_chinchilla}, mem,
                       f"{mem['GB']:.1f} GB needed; "
                       f"{'fits in ' if fits else 'EXCEEDS '}{info['vram_GB']} GB VRAM"))
    # Verdict
    if N_chinchilla < 1e8:
        verdict, reason = "TINY-MODEL", f"Budget supports only ~{N_chinchilla:.0e} params"
        mit = "Use LoRA fine-tuning of larger pretrained model instead."
    elif not fits:
        verdict, reason = "MEMORY-LIMITED", f"Chinchilla N ({N_chinchilla:.1e}) doesn't fit one {gpu}"
        mit = f"Use ZeRO-3 across multiple GPUs (need ≥{math.ceil(mem['GB']/info['vram_GB'])}× {gpu}) OR train smaller N undertrained."
    else:
        verdict = "GO"
        reason = (f"At ${USD_budget}, train {N_chinchilla:.1e}-param model on "
                  f"{D_chinchilla:.1e} tokens in ~{hours/24:.1f} days. "
                  f"Capability tier: {emerg.split('—')[0].strip()}.")
        mit = "None — proceed with Chinchilla-optimal recipe."
    return _wrap("X-3", "Budget pre-flight", locals(), chain, verdict, reason, mit)
# ─────────────────────────────────────────────────────────────────────
# X-5 — Hardware selection for serving
# ─────────────────────────────────────────────────────────────────────
def run_recipe_x5(N_params, T_eval=4096, n_layers=32, n_kv_heads=8, d_head=128,
                  bytes_per_weight=2.0, target_tokens_per_day=10_000_000.0,
                  concurrent_users=1, **_unused):
    """X-5: which GPU should I use to serve N-param model at L context?

    Chain: weight memory → KV cache for one T_eval-long request → total VRAM
    for ``concurrent_users`` requests → every catalog GPU with at least that
    much VRAM, ranked by $/Mtok (cheapest first; the top 5 go into the chain).

    Throughput per GPU is the memory-bound estimate bw_GB_s / weight_GB;
    ``concurrent_users`` only affects the VRAM requirement, not throughput.
    "users_supported" is tokens/day divided by ``target_tokens_per_day``.
    A catalog entry with usd_h == 0.0 ranks first at $0/Mtok whenever it fits.

    NOTE(review): the KV cache is sized with ``bytes_per_weight`` as its element
    size, so quantizing weights also shrinks the KV estimate — confirm this is
    intended, since a KV cache may use a different precision than the weights.

    NOTE: ``locals()`` is handed to ``_wrap``, which reports every local not in
    its exclusion list as an "input" — renaming or adding locals changes the
    result payload.
    """
    chain = []
    # Step 1: weights memory
    w_mem = model_weights_memory(N_params, bytes_per_weight)
    chain.append(_step(1, "§19.X", "Model weights memory",
                       "N · bytes_per_weight",
                       {"N": N_params, "bytes": bytes_per_weight}, w_mem,
                       f"{w_mem['GB']:.1f} GB for weights"))
    # Step 2: KV cache per request
    kv = kv_cache_memory(n_layers, n_kv_heads, d_head, T_eval, bytes_per_weight)
    chain.append(_step(2, "§19.1", "KV cache (per request)",
                       "2·L·n_kv·d_h·seq·B",
                       {"n_layers": n_layers, "n_kv": n_kv_heads,
                        "d_head": d_head, "seq": T_eval}, kv,
                       f"{kv['GB']:.2f} GB per concurrent request"))
    # Step 3: total memory needed
    total_GB = w_mem["GB"] + kv["GB"] * concurrent_users
    chain.append(_step(3, "§20.3", "Total GPU memory",
                       "weights + KV·n_concurrent", {}, {"GB": total_GB},
                       f"{total_GB:.1f} GB for {concurrent_users} concurrent users"))
    # Step 4: scan GPU catalog (single-GPU only; no tensor parallelism considered)
    candidates = []
    for name, info in GPU_CATALOG.items():
        if info["vram_GB"] < total_GB:
            continue
        # Decode throughput estimate (memory-bound)
        tok_per_s = info["bw_GB_s"] / w_mem["GB"]
        tok_per_day = tok_per_s * 86400
        capacity_users = tok_per_day / target_tokens_per_day
        usd_per_day = info["usd_h"] * 24
        usd_per_Mtok = (usd_per_day / (tok_per_day / 1e6)) if tok_per_day > 0 else float('inf')
        candidates.append({
            "gpu": name, "vram_GB": info["vram_GB"], "bw_GB_s": info["bw_GB_s"],
            "tok_per_sec": tok_per_s, "tok_per_day": tok_per_day,
            "USD_per_day": usd_per_day, "USD_per_Mtok": usd_per_Mtok,
            "users_supported": capacity_users,
        })
    candidates.sort(key=lambda c: c["USD_per_Mtok"])
    chain.append(_step(4, "§20", f"Eligible GPUs (≥{total_GB:.0f}GB)",
                       "filter + rank by $/Mtok",
                       {"min_VRAM": total_GB}, candidates[:5],
                       f"{len(candidates)} GPUs fit; cheapest: {candidates[0]['gpu'] if candidates else 'NONE'}"))
    # Verdict
    if not candidates:
        verdict, reason = "NO", f"No single GPU has ≥{total_GB:.0f} GB VRAM."
        mit = (f"Use tensor parallelism across multiple GPUs "
               f"(e.g. 2× H100 = 160GB), or quantize to INT8 (halves memory).")
    else:
        best = candidates[0]
        verdict = "YES"
        reason = (f"Best GPU: {best['gpu']} at ${best['USD_per_Mtok']:.2f}/Mtok. "
                  f"Supports {best['users_supported']:.1f}× your daily target.")
        mit = f"Provision {best['gpu']}, expected {best['tok_per_sec']:.0f} tok/s decode."
    return _wrap("X-5", "Hardware selection for serving", locals(), chain, verdict, reason, mit)
# ─────────────────────────────────────────────────────────────────────
# X-19 — KV compression decision (ours vs literature)
# ─────────────────────────────────────────────────────────────────────
def run_recipe_x19(theta, T_train, T_eval, n_attention_heads, n_kv_heads,
                   d_head, n_layers, n_params, has_SWA=False, **_unused):
    """X-19: should I use γ-soft KV decay, hard D_f, or literature methods?

    Chain: γ_Padé at T_eval → corrected γ → D_f window (only defined for
    γ in [0.65, 0.85]) → soft-decay regime vs T_train/2 → uncompressed KV size
    (2 bytes per element).

    Verdict priority:
      1. regime "applies" and D_f defined → "USE SOFT DECAY"
      2. D_f defined                      → "USE D_f HARD CUTOFF"
      3. regime "applies"                 → "USE SOFT DECAY (caveat)"
      4. corrected γ outside (0, 1)       → "USE LITERATURE METHODS"
      5. otherwise                        → "USE HARD T_train CUTOFF"
    A "borderline" regime is handled like "use-hard-cutoff".

    NOTE: ``locals()`` is handed to ``_wrap``; any local it does not exclude
    (check that ``dh_str`` is listed there) is reported as an "input".
    """
    chain = []
    # Step 1: γ_Padé
    g_pade = gamma_pade(theta, T_eval)
    chain.append(_step(1, "§26.1", "γ_Padé", "(2θ-T√2)/(2θ+T√2)",
                       {"theta": theta, "T_eval": T_eval}, g_pade, _phase_label(g_pade)))
    # Step 2: γ-decomposition (GQA inferred from fewer KV heads than query heads)
    has_GQA = n_kv_heads < n_attention_heads
    decomp = gamma_decompose(g_pade, has_GQA, has_SWA, n_params)
    g_corr = decomp["gamma_corrected"]
    chain.append(_step(2, "§26.10", "γ-decomposition", "5-axis adjustment",
                       {"has_GQA": has_GQA, "has_SWA": has_SWA, "n_params": n_params},
                       g_corr))
    # Step 3: §26.7 D_f window applicability (None outside the γ zone)
    df = df_window(g_corr, T_eval, f=0.90)
    df_zone_ok = df is not None
    chain.append(_step(3, "§26.7", "D_f window (γ in [0.65, 0.85])",
                       "[(1-f)+fN^(1-γ)]^(1/(1-γ))",
                       {"gamma": g_corr, "N": T_eval, "f": 0.9}, df,
                       f"D_f = {df}" if df_zone_ok
                       else f"NOT applicable (γ={g_corr:.3f} outside [0.65, 0.85])"))
    # Step 4: §26.8 soft decay régimen
    regime = kv_soft_decay_regime(theta, g_corr, T_train)
    dh = d_horizon(theta, g_corr)
    dh_str = f"{dh:.0f}" if dh is not None else "n/a"
    chain.append(_step(4, "§26.8", "Soft decay régimen", "d_h ≳ T_train/2",
                       {"theta": theta, "gamma": g_corr, "T_train": T_train}, regime,
                       f"d_horizon={dh_str}; regime: {regime}"))
    # Step 5: KV cache memory baseline (default 2 bytes per element)
    kv = kv_cache_memory(n_layers, n_kv_heads, d_head, T_eval)
    chain.append(_step(5, "§19.1", "Baseline KV memory", "2·L·n_kv·d_h·seq·B",
                       {"L": n_layers, "n_kv": n_kv_heads, "d_h": d_head, "seq": T_eval},
                       kv, f"{kv['GB']:.2f} GB without compression"))
    # Verdict
    if regime == "applies" and df_zone_ok:
        verdict = "USE SOFT DECAY"
        reason = (f"d_horizon ≳ T_train/2 AND γ in compression zone. "
                  f"Soft decay (1-d/d_h)^γ best (-21% PPL vs hard cutoff per F17).")
        mit = "Implement as 4D attention_mask additive bias with eager attention."
    elif df_zone_ok:
        verdict = "USE D_f HARD CUTOFF"
        reason = f"γ in [0.65, 0.85] zone but d_h < T_train/2. Hard truncation at D_f={df} works."
        mit = "Set cache_max_len = D_f."
    elif regime == "applies":
        verdict = "USE SOFT DECAY (caveat)"
        reason = "Régimen applies but γ outside D_f validity zone. Soft decay only."
        mit = "Soft decay; do not use D_f window."
    elif g_corr >= 1 or g_corr <= 0:
        verdict = "USE LITERATURE METHODS"
        reason = f"γ={g_corr:.3f} outside Phase A. Our formulas don't apply."
        mit = "Use SnapKV / PyramidKV / FastGen (literature heuristics)."
    else:
        verdict = "USE HARD T_train CUTOFF"
        reason = "Régimen not met AND γ outside zone. Cap context at T_train."
        mit = f"Set seq_len ≤ {T_train}, no extension."
    return _wrap("X-19", "KV compression decision", locals(), chain, verdict, reason, mit)
# ════════════════════════════════════════════════════════════════════════════
# Helpers
# ════════════════════════════════════════════════════════════════════════════
def _step(n, sec, name, formula, inputs, result, interpretation=None, breakdown=None):
    """Build one reasoning-chain entry for a recipe result.

    ``interpretation`` and ``breakdown`` are included only when truthy, after
    the mandatory step/section/name/formula/inputs/result keys.
    """
    entry = dict(step=n, section=sec, name=name, formula=formula,
                 inputs=inputs, result=result)
    optional = (("interpretation", interpretation), ("breakdown", breakdown))
    entry.update((key, value) for key, value in optional if value)
    return entry
def _wrap(rid, rname, locals_dict, chain, verdict, reason, mitigation):
    """Assemble the uniform recipe result dict.

    ``locals_dict`` is the calling recipe's ``locals()``. Every entry that is
    not private (leading underscore) and not one of the known intermediate
    variable names below is reported back under "inputs", in its original order.

    Returns:
        dict with recipe_id, recipe_name, inputs, chain, verdict, reason and
        mitigation.
    """
    # Intermediate/derived names used inside the recipe bodies. Any new local
    # added to a recipe must be listed here or it leaks into "inputs".
    # "flops" (X-1) and "dh_str" (X-19) were previously missing.
    internal = frozenset({
        "chain", "verdict", "reason", "mit", "info", "be", "kv", "g_pade", "g_corr",
        "decomp", "dh", "dh_str", "l_niah", "p_hallu", "cost", "flops", "model_GB",
        "inf", "api", "api_blend", "fits", "mem", "emerg", "max_flops", "hours",
        "N_chinchilla", "D_chinchilla", "candidates", "best", "tok_per_s",
        "tok_per_day", "capacity_users", "usd_per_day", "usd_per_Mtok",
        "total_GB", "w_mem", "df", "df_zone_ok", "regime", "has_GQA",
        "margin", "months", "months_to_payoff", "name",
    })
    inputs = {k: v for k, v in locals_dict.items()
              if not k.startswith("_") and k not in internal}
    return {"recipe_id": rid, "recipe_name": rname, "inputs": inputs,
            "chain": chain, "verdict": verdict, "reason": reason,
            "mitigation": mitigation}
def _phase_label(g):
    """Human-readable phase name for a γ value.

    Phase A is the open interval (0, 1); γ >= 1 is Hagedorn; anything else
    (including γ == 0 exactly and NaN) gets the catastrophic label.
    """
    if g >= 1:
        return "Phase B / Hagedorn"
    if g > 0:
        return "Phase A (long-range OK)"
    return "Phase B / catastrophic (negative γ — T too large for θ)"
# ════════════════════════════════════════════════════════════════════════════
# Recipe registry
# ════════════════════════════════════════════════════════════════════════════
RECIPES = {
    # Recipe registry used by list_recipes() and run_recipe(). Per entry:
    #   name / description — labels for the UI dropdown
    #   fn                 — recipe callable; run_recipe passes caller params as kwargs
    #   params             — parameter names the UI should collect (extra kwargs
    #                        are absorbed by each recipe's **_unused)
    #   category           — grouping tag
    #   uses_sections      — TAF sections the recipe draws on
    "X-1": {
        "name": "Custom Training vs API",
        "description": "Should I train a custom model or use a frontier API for my domain task?",
        "fn": run_recipe_x1,
        "params": ["N_params", "D_tokens", "gpu", "n_gpus", "mfu",
                   "api_model", "monthly_tokens_M"],
        "category": "build-vs-buy",
        "uses_sections": ["§17", "§19", "§20", "§24"],
    },
    "X-2": {
        "name": "Long Context Viability",
        "description": "Will model M serve length L doing Needle-in-a-Haystack retrieval?",
        "fn": run_recipe_x2,
        "params": ["theta", "T_train", "T_eval", "n_attention_heads", "n_kv_heads",
                   "d_head", "n_layers", "n_params", "has_SWA"],
        "category": "long-context",
        "uses_sections": ["§26", "§19"],
    },
    "X-3": {
        "name": "Budget Pre-flight",
        "description": "Given $ budget, what model is feasible to train?",
        "fn": run_recipe_x3,
        "params": ["USD_budget", "gpu", "mfu", "n_gpus"],
        "category": "training-budget",
        "uses_sections": ["§17", "§20"],
    },
    "X-5": {
        "name": "Hardware Selection",
        "description": "Which GPU should I use to serve my model at target throughput?",
        "fn": run_recipe_x5,
        "params": ["N_params", "T_eval", "n_layers", "n_kv_heads", "d_head",
                   "bytes_per_weight", "target_tokens_per_day", "concurrent_users"],
        "category": "serving",
        "uses_sections": ["§19", "§20"],
    },
    "X-19": {
        "name": "KV Compression Decision",
        "description": "Should I use soft decay, D_f cutoff, or literature methods to compress KV?",
        "fn": run_recipe_x19,
        "params": ["theta", "T_train", "T_eval", "n_attention_heads", "n_kv_heads",
                   "d_head", "n_layers", "n_params", "has_SWA"],
        "category": "kv-compression",
        "uses_sections": ["§26", "§19"],
    },
}
def list_recipes() -> str:
    """Serialize the recipe registry (without the callables) to JSON for the UI dropdown."""
    public_fields = ("name", "description", "category", "params", "uses_sections")
    entries = []
    for recipe_id, spec in RECIPES.items():
        entry = {"id": recipe_id}
        entry.update({field: spec[field] for field in public_fields})
        entries.append(entry)
    return json.dumps(entries)
def run_recipe(recipe_id: str, **params) -> dict:
    """Dispatch to the recipe registered under ``recipe_id`` with keyword ``params``.

    Unknown ids return an error dict listing the available ids; errors raised
    by the recipe itself (e.g. a missing required argument) propagate.
    """
    if recipe_id not in RECIPES:
        return {"error": f"unknown recipe '{recipe_id}'",
                "available": list(RECIPES.keys())}
    return RECIPES[recipe_id]["fn"](**params)
# ════════════════════════════════════════════════════════════════════════════
# Known model presets
# ════════════════════════════════════════════════════════════════════════════
PRESETS = {
    # Architecture presets keyed by Hugging Face model id. Keys match the
    # parameter names of run_recipe_x2 / run_recipe_x19 except T_eval, which
    # the caller must supply.
    #   theta       — positional base θ fed to gamma_pade (RoPE base for these models)
    #   T_train     — training context length
    #   n_kv_heads  — fewer than n_attention_heads ⇒ treated as GQA (δ_GQA)
    #   n_params    — >= 4e8 triggers δ_post_IH in gamma_decompose
    #   has_SWA     — True if the model uses sliding-window attention (δ_SWA)
    # NOTE(review): values presumably transcribed from each model's published
    # config; re-check against upstream when updating.
    "EleutherAI/pythia-2.8b": {
        "theta": 10000, "T_train": 2048,
        "n_attention_heads": 32, "n_kv_heads": 32,
        "d_head": 80, "n_layers": 32, "n_params": 2.8e9, "has_SWA": False,
    },
    "EleutherAI/pythia-1b": {
        "theta": 10000, "T_train": 2048,
        "n_attention_heads": 8, "n_kv_heads": 8,
        "d_head": 256, "n_layers": 16, "n_params": 1e9, "has_SWA": False,
    },
    "EleutherAI/pythia-1.4b": {
        "theta": 10000, "T_train": 2048,
        "n_attention_heads": 16, "n_kv_heads": 16,
        "d_head": 128, "n_layers": 24, "n_params": 1.4e9, "has_SWA": False,
    },
    "meta-llama/Meta-Llama-3-8B": {
        "theta": 500000, "T_train": 8192,
        "n_attention_heads": 32, "n_kv_heads": 8,
        "d_head": 128, "n_layers": 32, "n_params": 8e9, "has_SWA": False,
    },
    "meta-llama/Llama-3.2-1B": {
        "theta": 500000, "T_train": 131072,
        "n_attention_heads": 32, "n_kv_heads": 8,
        "d_head": 64, "n_layers": 16, "n_params": 1.2e9, "has_SWA": False,
    },
    "meta-llama/Llama-3.3-70B-Instruct": {
        "theta": 500000, "T_train": 131072,
        "n_attention_heads": 64, "n_kv_heads": 8,
        "d_head": 128, "n_layers": 80, "n_params": 70e9, "has_SWA": False,
    },
    "mistralai/Mistral-7B-v0.1": {
        "theta": 10000, "T_train": 8192,
        "n_attention_heads": 32, "n_kv_heads": 8,
        "d_head": 128, "n_layers": 32, "n_params": 7e9, "has_SWA": True,
    },
    "Qwen/Qwen2.5-7B": {
        "theta": 1000000, "T_train": 32768,
        "n_attention_heads": 28, "n_kv_heads": 4,
        "d_head": 128, "n_layers": 28, "n_params": 7.6e9, "has_SWA": False,
    },
    "Qwen/Qwen2.5-1.5B": {
        "theta": 1000000, "T_train": 32768,
        "n_attention_heads": 12, "n_kv_heads": 2,
        "d_head": 128, "n_layers": 28, "n_params": 1.5e9, "has_SWA": False,
    },
    "google/gemma-2-9b-it": {
        "theta": 10000, "T_train": 8192,
        "n_attention_heads": 16, "n_kv_heads": 8,
        "d_head": 256, "n_layers": 42, "n_params": 9e9, "has_SWA": True,
    },
    "microsoft/phi-3-mini-4k-instruct": {
        "theta": 10000, "T_train": 4096,
        "n_attention_heads": 32, "n_kv_heads": 32,
        "d_head": 96, "n_layers": 32, "n_params": 3.8e9, "has_SWA": True,
    },
}
def list_presets() -> str:
    """JSON summary of PRESETS for the UI: id, short label (last path part), θ and T_train."""
    summary = []
    for model_id, cfg in PRESETS.items():
        summary.append({
            "id": model_id,
            "label": model_id.rsplit("/", 1)[-1],
            "theta": cfg["theta"],
            "T_train": cfg["T_train"],
        })
    return json.dumps(summary)
def get_preset(model_id: str) -> dict:
    """Return a copy of the preset config for ``model_id`` ({} if unknown).

    The original returned the shared PRESETS entry itself, so a caller that
    tweaked it (e.g. overriding theta before running a recipe) silently
    corrupted the table for every later call. All preset values are scalars,
    so a shallow copy is sufficient.
    """
    return dict(PRESETS.get(model_id, {}))
# Smoke test: run every recipe once with representative inputs and print its verdict.
if __name__ == "__main__":
    _SMOKE_CASES = [
        ("X-2 Llama-3-8B @ 32K", "X-2",
         dict(theta=500_000, T_train=8192, T_eval=32_000,
              n_attention_heads=32, n_kv_heads=8, d_head=128,
              n_layers=32, n_params=8e9, has_SWA=False)),
        ("X-1 Llama-3-8B vs GPT-4o (10M tok/mo)", "X-1",
         dict(N_params=8e9, monthly_tokens_M=10.0, api_model="GPT-4o")),
        ("X-3 budget $5K", "X-3",
         dict(USD_budget=5000.0, gpu="H100 SXM", n_gpus=1)),
        ("X-5 serve Llama-3-8B at 4K", "X-5",
         dict(N_params=8e9, T_eval=4096, n_layers=32, n_kv_heads=8, d_head=128,
              target_tokens_per_day=10e6, concurrent_users=1)),
        ("X-19 KV compression for Llama-3-8B", "X-19",
         dict(theta=500_000, T_train=8192, T_eval=8192,
              n_attention_heads=32, n_kv_heads=8, d_head=128,
              n_layers=32, n_params=8e9)),
    ]
    for _title, _rid, _kwargs in _SMOKE_CASES:
        print(f"─── {_title} ───")
        r = run_recipe(_rid, **_kwargs)
        if "error" in r:
            print(f"Error: {r['error']}\n")
        else:
            # The separator was missing, so verdict and reason ran together.
            print(f"Verdict: {r['verdict']} — {r['reason']}\n")