Gemma 4 E2B — Text-Only ONNX INT4 (repack)
Text-only repackage of onnx-community/gemma-4-E2B-it-ONNX. Multimodal heads (vision, audio) and non-Q4 quantization variants are stripped; the text decoder + embed table at INT4 (MatMulNBits) are kept as-is.
Inspired by Mathieu Grenier — Gemma 4 text-only GPU 6Go optimisation. The article documents the same goal — running Gemma 4 text-only on a constrained GPU budget. This repo packages the artefacts so downstream users skip the trial-and-error.
Files
config.json # text-only (text_config promoted)
generation_config.json
tokenizer.json + tokenizer_config.json + chat_template.jinja
onnx/
embed_tokens_q4.onnx (+ every .onnx_data* sidecar)
decoder_model_merged_q4.onnx (+ every .onnx_data* sidecar)
Why a 2-graph layout
Gemma 4 uses per-layer inputs / per-layer projections (a Mixture-of-Depths-like routing). The community ONNX export splits this into:
- `embed_tokens_q4` — `input_ids` → `inputs_embeds`, `per_layer_inputs`
- `decoder_model_merged_q4` — `inputs_embeds`, `per_layer_inputs`, `attention_mask`, `position_ids`, `past_key_values.*` → `logits`, `present.*`
ORTModelForCausalLM.from_pretrained(...).generate() does not orchestrate this 2-graph pattern as of optimum==1.x / optimum==2.x dev. Use the custom loop below.
Usage (Python, custom ORT loop)
import onnxruntime as ort
import numpy as np
from transformers import AutoTokenizer
from huggingface_hub import snapshot_download
# --- Fetch the repo snapshot and load tokenizer + both ONNX graphs ---
local = snapshot_download("tss-deposium/gemma-4-E2B-text-only-onnx-int4")
tok = AutoTokenizer.from_pretrained(local)

# Ask ORT which execution providers this build actually exposes instead of
# inferring CUDA support from get_device(); keep CUDA first so it is preferred.
_available_providers = ort.get_available_providers()
providers = [
    p for p in ("CUDAExecutionProvider", "CPUExecutionProvider")
    if p in _available_providers
]

embed_session = ort.InferenceSession(f"{local}/onnx/embed_tokens_q4.onnx", providers=providers)
decoder_session = ort.InferenceSession(f"{local}/onnx/decoder_model_merged_q4.onnx", providers=providers)

# --- Discover I/O signatures (Gemma 4 specifics: per_layer_inputs, KV cache layout) ---
DEC_INPUTS = {x.name: x for x in decoder_session.get_inputs()}
DEC_OUTPUTS = {x.name: x for x in decoder_session.get_outputs()}
EMB_OUTPUTS = [x.name for x in embed_session.get_outputs()]

# Both lists are sorted the same way so that position i of one pairs with
# position i of the other.
PAST_KEYS = sorted(n for n in DEC_INPUTS if n.startswith("past_key_values"))
PRESENT_KEYS = sorted(n for n in DEC_OUTPUTS if n.startswith(("present.", "present_key_values.")))

# Optional decoder inputs; only fed when the graph declares them.
HAS_PER_LAYER = "per_layer_inputs" in DEC_INPUTS
HAS_NUM_LOGITS_TO_KEEP = "num_logits_to_keep" in DEC_INPUTS

# zip() would silently truncate on a length mismatch and leave part of the
# cache unmapped, so fail loudly here instead of mid-generation.
if len(PAST_KEYS) != len(PRESENT_KEYS):
    raise RuntimeError(
        f"KV cache mismatch: {len(PAST_KEYS)} past inputs vs "
        f"{len(PRESENT_KEYS)} present outputs"
    )

# Maps each decoder cache output name to the cache input it feeds next step.
present_to_past = dict(zip(PRESENT_KEYS, PAST_KEYS))
# ONNX Runtime type strings -> NumPy dtypes. Built once at import time rather
# than on every call (the lookup runs several times per decoding step).
_ORT_TO_NUMPY = {
    "tensor(float16)": np.float16,
    "tensor(float)": np.float32,
    "tensor(double)": np.float64,
    "tensor(int64)": np.int64,
    "tensor(int32)": np.int32,
    "tensor(int8)": np.int8,
    "tensor(uint8)": np.uint8,
    "tensor(bool)": np.bool_,
}


def _ort_dtype(t):
    """Return the NumPy dtype matching an ONNX Runtime type string.

    Args:
        t: Type string as reported by a session input/output, e.g.
            ``"tensor(float16)"``.

    Returns:
        The corresponding NumPy scalar type.

    Raises:
        KeyError: If ``t`` is not one of the supported type strings.
    """
    try:
        return _ORT_TO_NUMPY[t]
    except KeyError:
        # Same exception type as a plain dict lookup, but with a message that
        # says what went wrong.
        raise KeyError(f"unsupported ONNX Runtime tensor type: {t!r}") from None
def _init_past_kv(batch=1):
    """Build the zero-filled KV-cache feeds for the first decoder call.

    Every name in ``PAST_KEYS`` gets an array whose shape comes from the
    decoder's declared input shape. Concrete integer dims are used as-is;
    symbolic dims are resolved as described in the comments below.

    Args:
        batch: Size substituted for a symbolic dim at axis 0.

    Returns:
        Dict mapping each past-key input name to a zero array of the
        declared dtype.

    Raises:
        RuntimeError: If a symbolic dim cannot be resolved.
    """
    # First concrete, positive size observed at each axis across all cache
    # inputs; used as a fallback for symbolic dims on other inputs.
    known_sizes = {}
    for name in PAST_KEYS:
        for axis, size in enumerate(DEC_INPUTS[name].shape):
            if isinstance(size, int) and size > 0 and axis not in known_sizes:
                known_sizes[axis] = size

    def _resolve(name, axis, size):
        if isinstance(size, int):
            return size
        if axis == 0:
            return batch
        if axis == 2:
            # Presumably the cached-sequence axis: nothing is cached yet.
            return 0
        if axis in known_sizes:
            return known_sizes[axis]
        raise RuntimeError(f"unresolved symbolic dim {size} in {name}")

    cache = {}
    for name in PAST_KEYS:
        spec = DEC_INPUTS[name]
        dims = [_resolve(name, axis, size) for axis, size in enumerate(spec.shape)]
        cache[name] = np.zeros(dims, dtype=_ort_dtype(spec.type))
    return cache
def _run_embed(input_ids):
    """Run the embedding graph on ``input_ids``.

    Outputs are picked by name: the first output whose name contains
    "embed" is the token embeddings, and the first whose name contains
    "per_layer" (if any) is the per-layer input tensor.

    Args:
        input_ids: Array fed to the embed session's first input.

    Returns:
        Tuple ``(embeds, per_layer)``; ``per_layer`` is ``None`` when the
        graph exposes no matching output.
    """
    feed_name = embed_session.get_inputs()[0].name
    results = embed_session.run(None, {feed_name: input_ids})
    by_name = dict(zip(EMB_OUTPUTS, results))

    embed_name = next(name for name in EMB_OUTPUTS if "embed" in name.lower())
    per_layer_name = next(
        (name for name in EMB_OUTPUTS if "per_layer" in name.lower()), None
    )

    per_layer = None if per_layer_name is None else by_name[per_layer_name]
    return by_name[embed_name], per_layer
def _run_decoder(embeds, per_layer, mask, pos, past):
    """Run one decoder step and return the logits plus the updated KV cache.

    Args:
        embeds: Token embeddings for the current step.
        per_layer: Per-layer inputs from the embed graph, or ``None`` when
            that graph does not produce them.
        mask: Attention mask covering all tokens seen so far.
        pos: Position ids for the tokens in ``embeds``.
        past: Dict of KV-cache arrays keyed by decoder input name.

    Returns:
        Tuple ``(logits, new_past)`` where ``new_past`` is keyed by the
        decoder's past-key input names, ready for the next call.

    Raises:
        RuntimeError: If the decoder declares ``per_layer_inputs`` but
            ``per_layer`` is ``None``.
    """
    def _cast(name, array):
        # copy=False skips the copy when the array already has the declared dtype.
        return array.astype(_ort_dtype(DEC_INPUTS[name].type), copy=False)

    feeds = {
        "inputs_embeds": _cast("inputs_embeds", embeds),
        "attention_mask": _cast("attention_mask", mask),
        "position_ids": _cast("position_ids", pos),
    }
    if HAS_PER_LAYER:
        if per_layer is None:
            # Without this guard the failure is an opaque AttributeError on None.
            raise RuntimeError(
                "decoder expects per_layer_inputs but the embed graph did not return it"
            )
        feeds["per_layer_inputs"] = _cast("per_layer_inputs", per_layer)
    if HAS_NUM_LOGITS_TO_KEEP:
        # Only the logits of the last position are used by the caller.
        feeds["num_logits_to_keep"] = np.array(
            1, dtype=_ort_dtype(DEC_INPUTS["num_logits_to_keep"].type)
        )
    feeds.update(past)

    outs = decoder_session.run(None, feeds)
    # DEC_OUTPUTS preserves the session's output order, so there is no need
    # to re-query the output names on every step.
    name_to_arr = dict(zip(DEC_OUTPUTS, outs))
    new_past = {present_to_past[n]: name_to_arr[n] for n in PRESENT_KEYS}
    return name_to_arr.get("logits", outs[0]), new_past
def generate(prompt: str, max_new_tokens: int = 64) -> str:
    """Greedily decode a reply to ``prompt`` using the two ONNX graphs.

    The prompt is wrapped in the tokenizer's chat template as a single user
    turn. The whole prompt is run through the decoder once, then tokens are
    generated one at a time, reusing the KV cache, until an EOS id is
    produced or ``max_new_tokens`` tokens have been generated.

    Args:
        prompt: User message text.
        max_new_tokens: Upper bound on generated tokens; values <= 0 yield
            an empty string.

    Returns:
        The decoded completion with special tokens removed.
    """
    import json  # kept local: the surrounding script does not import json

    if max_new_tokens <= 0:
        return ""

    msgs = [{"role": "user", "content": prompt}]
    try:
        text = tok.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=True, enable_thinking=False
        )
    except TypeError:
        # Fallback for tokenizers that reject the enable_thinking keyword.
        text = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    ids = tok(text, return_tensors="np").input_ids.astype(np.int64)

    # EOS ids come from generation_config.json (int or list), falling back to
    # the tokenizer's EOS when the key is absent or null.
    gen_cfg_path = f"{local}/generation_config.json"
    with open(gen_cfg_path, encoding="utf-8") as fh:
        eos_raw = json.load(fh).get("eos_token_id", tok.eos_token_id)
    if eos_raw is None:
        eos_raw = tok.eos_token_id
    eos_ids = set(eos_raw) if isinstance(eos_raw, list) else {eos_raw}

    # Prefill: run the whole prompt once against an empty cache.
    past = _init_past_kv(batch=1)
    seq = ids.shape[1]
    mask = np.ones((1, seq), dtype=np.int64)
    pos = np.arange(seq, dtype=np.int64)[None, :]
    embeds, per_layer = _run_embed(ids)
    logits, past = _run_decoder(embeds, per_layer, mask, pos, past)
    next_id = int(np.argmax(logits[0, -1, :]))
    out = [next_id]

    # Decode: feed one token per step; the mask grows by one each time and
    # the new token's position is the last index of the grown mask.
    while len(out) < max_new_tokens and next_id not in eos_ids:
        tok_arr = np.array([[next_id]], dtype=np.int64)
        mask = np.concatenate([mask, np.ones((1, 1), dtype=np.int64)], axis=1)
        pos = np.array([[mask.shape[1] - 1]], dtype=np.int64)
        embeds, per_layer = _run_embed(tok_arr)
        logits, past = _run_decoder(embeds, per_layer, mask, pos, past)
        next_id = int(np.argmax(logits[0, -1, :]))
        out.append(next_id)

    # Drop the terminating EOS so it cannot leak into the text if the
    # tokenizer does not flag that id as a special token.
    if out[-1] in eos_ids:
        out.pop()
    return tok.decode(out, skip_special_tokens=True)
# Example run: generate a reply to a French prompt and print it.
print(generate("Bonjour, explique en une phrase ce que tu es."))
Credits
Approach inspired by Mathieu Grenier — Gemma 4 text-only GPU 6Go optimisation. The article documents the same goal (text-only Gemma 4 on a constrained GPU budget). This repo packages the resulting artefacts so downstream users skip the trial-and-error.
Provenance
- Source: onnx-community/gemma-4-E2B-it-ONNX
- Repack pipeline: theseedship/deposium-turbov3/docs/gemma4_e4b_text_only_onnx_int4_export.ipynb
- License: Gemma terms of use (inherited)
- Downloads last month
- 19