Gemma 4 E2B — Text-Only ONNX INT4 (repack)

Text-only repackage of onnx-community/gemma-4-E2B-it-ONNX. The multimodal graphs (vision, audio) and all non-Q4 quantization variants are stripped; the text path (embedding graph + text decoder, INT4 via MatMulNBits) is kept as-is.

Inspired by Mathieu Grenier's article "Gemma 4 text-only GPU 6Go optimisation", which pursues the same goal: running Gemma 4 text-only on a constrained GPU budget (6 GB). This repo packages the resulting artefacts so downstream users can skip the trial-and-error.

Files

config.json                     # text-only (text_config promoted)
generation_config.json
tokenizer.json + tokenizer_config.json + chat_template.jinja
onnx/
  embed_tokens_q4.onnx (+ every .onnx_data* sidecar)
  decoder_model_merged_q4.onnx (+ every .onnx_data* sidecar)
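If you would rather pull the same subset straight from the source repo, an allow-list over file names does the selection. This is a minimal sketch, assuming the source repo uses the file names listed above; `TEXT_ONLY_PATTERNS`, `is_text_only_q4` and `fetch_text_only` are hypothetical names, while `snapshot_download(..., allow_patterns=...)` is the huggingface_hub download filter. It only reproduces the file selection: the source `config.json` is the multimodal one, so the text_config promotion done in this repo is not replicated.

```python
from fnmatch import fnmatch

# Assumption: the source repo (onnx-community/gemma-4-E2B-it-ONNX) uses the same
# file names as listed above; every other graph and quantization suffix is excluded.
TEXT_ONLY_PATTERNS = [
    "config.json",  # note: multimodal in the source repo, not the promoted text_config
    "generation_config.json",
    "tokenizer.json",
    "tokenizer_config.json",
    "chat_template.jinja",
    "onnx/embed_tokens_q4.onnx*",          # graph plus its .onnx_data* sidecars
    "onnx/decoder_model_merged_q4.onnx*",  # graph plus its .onnx_data* sidecars
]


def is_text_only_q4(path: str) -> bool:
    """True if a repo-relative path belongs to the text-only Q4 subset."""
    return any(fnmatch(path, pattern) for pattern in TEXT_ONLY_PATTERNS)


def fetch_text_only(repo_id: str = "onnx-community/gemma-4-E2B-it-ONNX") -> str:
    """Download only the text-only Q4 files (needs network access and huggingface_hub)."""
    from huggingface_hub import snapshot_download  # lazy import: the filter works offline
    return snapshot_download(repo_id, allow_patterns=TEXT_ONLY_PATTERNS)
```

The trailing `*` on the graph patterns is what picks up the external-data sidecars, while the literal `.onnx` right after `_q4` keeps other suffixes such as `_q4f16` out.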

Why a 2-graph layout

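Gemma 4 uses per-layer inputs: alongside the regular token embeddings, a separate per-layer input tensor is fed into each decoder layer through its own projection. The community ONNX export therefore splits inference into two graphs: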

  1. embed_tokens_q4: input_ids → inputs_embeds, per_layer_inputs
  2. decoder_model_merged_q4: inputs_embeds, per_layer_inputs, attention_mask, position_ids, past_key_values.* → logits, present.*
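Before wiring a loop, it can help to confirm that a given pair of sessions actually exposes this contract. A minimal sketch, assuming the exact tensor names shown in the list (`input_ids`, `inputs_embeds`, `per_layer_inputs`, `past_key_values.*`, `present.*`); `check_two_graph_contract` and `fake_session` are hypothetical helpers. The check only relies on `get_inputs()`/`get_outputs()` returning objects with a `.name`, which `onnxruntime.InferenceSession` provides, so a stand-in works for offline tests.

```python
from types import SimpleNamespace

# Tensor names taken from the 2-graph layout above (assumed exact).
EMBED_OUTPUTS = {"inputs_embeds", "per_layer_inputs"}
DECODER_INPUTS = {"inputs_embeds", "per_layer_inputs", "attention_mask", "position_ids"}


def _names(nodes) -> set:
    return {node.name for node in nodes}


def check_two_graph_contract(embed_session, decoder_session) -> list:
    """Return a list of problems; an empty list means the sessions match the layout."""
    problems = []
    emb_in, emb_out = _names(embed_session.get_inputs()), _names(embed_session.get_outputs())
    dec_in, dec_out = _names(decoder_session.get_inputs()), _names(decoder_session.get_outputs())
    if "input_ids" not in emb_in:
        problems.append("embed graph has no 'input_ids' input")
    problems += [f"embed graph is missing output '{n}'" for n in sorted(EMBED_OUTPUTS - emb_out)]
    problems += [f"decoder is missing input '{n}'" for n in sorted(DECODER_INPUTS - dec_in)]
    if "logits" not in dec_out:
        problems.append("decoder has no 'logits' output")
    # Every past_key_values.<suffix> input should have a matching present.<suffix> output.
    past = {n[len("past_key_values."):] for n in dec_in if n.startswith("past_key_values.")}
    present = {n[len("present."):] for n in dec_out if n.startswith("present.")}
    if past != present:
        problems.append("past_key_values.* inputs and present.* outputs do not pair up")
    return problems


def fake_session(inputs, outputs):
    """Offline stand-in exposing get_inputs()/get_outputs() like onnxruntime does."""
    return SimpleNamespace(
        get_inputs=lambda: [SimpleNamespace(name=n) for n in inputs],
        get_outputs=lambda: [SimpleNamespace(name=n) for n in outputs],
    )
```

With real sessions, `check_two_graph_contract(embed_session, decoder_session)` should return an empty list; anything else names the input or output that differs from the expected layout.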

As of optimum 1.x (and the 2.x dev branch), ORTModelForCausalLM.from_pretrained(...).generate() does not orchestrate this 2-graph pattern, so use the custom loop below.
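The per-step bookkeeping that `generate()` would normally handle is small: each decoder call needs an attention mask covering cached plus new tokens, and position ids that continue from the cache length. A minimal sketch, assuming unpadded sequences and a KV cache that grows by exactly the tokens fed; `step_feeds` is a hypothetical helper name.

```python
import numpy as np


def step_feeds(past_len: int, new_len: int, batch: int = 1):
    """attention_mask and position_ids for one decoder call.

    past_len: tokens already held in the KV cache (0 for the prefill call)
    new_len:  tokens fed in this call (prompt length for prefill, then 1 per step)
    """
    total = past_len + new_len
    # No padding assumed: every cached and new position is attended to.
    attention_mask = np.ones((batch, total), dtype=np.int64)
    # Positions resume where the cache left off.
    position_ids = np.tile(np.arange(past_len, total, dtype=np.int64), (batch, 1))
    return attention_mask, position_ids
```

The prefill call is `step_feeds(0, prompt_len)`; decode step `i` (counting from 0) is `step_feeds(prompt_len + i, 1)`, which matches the mask and position updates in the loop below.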

Usage (Python, custom ORT loop)

import onnxruntime as ort
import numpy as np
from transformers import AutoTokenizer
from huggingface_hub import snapshot_download

local = snapshot_download("tss-deposium/gemma-4-E2B-text-only-onnx-int4")
tok = AutoTokenizer.from_pretrained(local)

# Use CUDA when this onnxruntime build ships it, always keeping CPU as the fallback
providers = [p for p in ("CUDAExecutionProvider", "CPUExecutionProvider") if p in ort.get_available_providers()]
embed_session   = ort.InferenceSession(f"{local}/onnx/embed_tokens_q4.onnx", providers=providers)
decoder_session = ort.InferenceSession(f"{local}/onnx/decoder_model_merged_q4.onnx", providers=providers)

# --- Discover I/O signatures (Gemma 4 specifics: per_layer_inputs, KV cache layout) ---
DEC_INPUTS  = {x.name: x for x in decoder_session.get_inputs()}
DEC_OUTPUTS = {x.name: x for x in decoder_session.get_outputs()}
EMB_OUTPUTS = [x.name for x in embed_session.get_outputs()]
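PAST_KEYS    = sorted([n for n in DEC_INPUTS if n.startswith("past_key_values")])
PRESENT_KEYS = sorted([n for n in DEC_OUTPUTS if n.startswith(("present.", "present_key_values."))])
HAS_PER_LAYER = "per_layer_inputs" in DEC_INPUTS
HAS_NUM_LOGITS_TO_KEEP = "num_logits_to_keep" in DEC_INPUTS
# Pair each cache output with its cache input by layer suffix ("present.3.key" -> "past_key_values.3.key")
# instead of by sort order, so a missing or extra entry fails loudly rather than silently misaligning.
present_to_past = {n: "past_key_values" + n[n.index("."):] for n in PRESENT_KEYS}
if set(present_to_past.values()) != set(PAST_KEYS):
    raise RuntimeError("present.* outputs and past_key_values.* inputs do not pair up")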

def _ort_dtype(t):
    return {"tensor(float16)": np.float16, "tensor(float)": np.float32,
            "tensor(int64)": np.int64, "tensor(int32)": np.int32, "tensor(bool)": np.bool_}[t]

def _init_past_kv(batch=1):
    config_dims = {}
    for n in PAST_KEYS:
        for i, d in enumerate(DEC_INPUTS[n].shape):
            if isinstance(d, int) and d > 0: config_dims.setdefault(i, d)
    pkv = {}
    for n in PAST_KEYS:
        meta = DEC_INPUTS[n]
        shape = []
        for i, d in enumerate(meta.shape):
            if isinstance(d, int):       shape.append(d)
            elif i == 0:                  shape.append(batch)
            elif i == 2:                  shape.append(0)
            elif i in config_dims:        shape.append(config_dims[i])
            else: raise RuntimeError(f"unresolved symbolic dim {d} in {n}")
        pkv[n] = np.zeros(shape, dtype=_ort_dtype(meta.type))
    return pkv

def _run_embed(input_ids):
    outs = embed_session.run(None, {embed_session.get_inputs()[0].name: input_ids})
    out_map = dict(zip(EMB_OUTPUTS, outs))
    embeds = next(out_map[n] for n in EMB_OUTPUTS if "embed" in n.lower())
    per_layer = next((out_map[n] for n in EMB_OUTPUTS if "per_layer" in n.lower()), None)
    return embeds, per_layer

def _run_decoder(embeds, per_layer, mask, pos, past):
    feeds = {
        "inputs_embeds":  embeds.astype(_ort_dtype(DEC_INPUTS["inputs_embeds"].type)),
        "attention_mask": mask.astype(_ort_dtype(DEC_INPUTS["attention_mask"].type)),
        "position_ids":   pos.astype(_ort_dtype(DEC_INPUTS["position_ids"].type)),
    }
    if HAS_PER_LAYER:
        feeds["per_layer_inputs"] = per_layer.astype(_ort_dtype(DEC_INPUTS["per_layer_inputs"].type))
    if HAS_NUM_LOGITS_TO_KEEP:
        feeds["num_logits_to_keep"] = np.array(1, dtype=_ort_dtype(DEC_INPUTS["num_logits_to_keep"].type))
    feeds.update(past)
    outs = decoder_session.run(None, feeds)
    out_names = [o.name for o in decoder_session.get_outputs()]
    name_to_arr = dict(zip(out_names, outs))
    return name_to_arr.get("logits", outs[0]), {present_to_past[n]: name_to_arr[n] for n in PRESENT_KEYS}

def generate(prompt: str, max_new_tokens: int = 64) -> str:
    msgs = [{"role": "user", "content": prompt}]
    try:
        text = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True, enable_thinking=False)
    except TypeError:
        text = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    ids = tok(text, return_tensors="np").input_ids.astype(np.int64)

    import json
    # eos_token_id may be a single int or a list of ints; close the file after reading
    with open(f"{local}/generation_config.json") as f:
        eos_raw = json.load(f).get("eos_token_id", tok.eos_token_id)
    eos_ids = set(eos_raw) if isinstance(eos_raw, list) else {eos_raw}

    past = _init_past_kv(batch=1)
    seq = ids.shape[1]
    mask = np.ones((1, seq), dtype=np.int64)
    pos = np.arange(seq, dtype=np.int64)[None, :]

    embeds, per_layer = _run_embed(ids)
    logits, past = _run_decoder(embeds, per_layer, mask, pos, past)
    next_id = int(np.argmax(logits[:, -1, :]))
    out = [next_id]
    cur = seq
    while len(out) < max_new_tokens and next_id not in eos_ids:
        cur += 1
        tok_arr = np.array([[next_id]], dtype=np.int64)
        mask = np.concatenate([mask, np.ones((1, 1), dtype=np.int64)], axis=1)
        pos  = np.array([[cur - 1]], dtype=np.int64)
        embeds, per_layer = _run_embed(tok_arr)
        logits, past = _run_decoder(embeds, per_layer, mask, pos, past)
        next_id = int(np.argmax(logits[:, -1, :]))
        out.append(next_id)
    return tok.decode(out, skip_special_tokens=True)

print(generate("Bonjour, explique en une phrase ce que tu es."))  # "Hello, explain in one sentence what you are."
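The loop above decodes greedily (`np.argmax`). If you want sampling instead, here is a minimal sketch of temperature + top-k + top-p (nucleus) sampling over one logits row. `sample_next` and `sampling_params` are hypothetical helpers; the fallback values (temperature 1.0, top_k 50, top_p 1.0) are the transformers `GenerationConfig` defaults, used when a key is absent. Whether this repo's `generation_config.json` sets these keys is an assumption to check.

```python
import json

import numpy as np


def sampling_params(path: str) -> dict:
    """Read sampling knobs from a generation_config.json, falling back to transformers' defaults."""
    with open(path) as f:
        cfg = json.load(f)
    return {
        "temperature": cfg.get("temperature", 1.0),
        "top_k": cfg.get("top_k", 50),
        "top_p": cfg.get("top_p", 1.0),
    }


def sample_next(logits_row, temperature=1.0, top_k=50, top_p=1.0, rng=None) -> int:
    """Draw one token id from a 1-D logits vector (temperature, then top-k, then top-p)."""
    rng = np.random.default_rng() if rng is None else rng
    if temperature <= 0:
        return int(np.argmax(logits_row))  # temperature 0 means greedy
    logits = np.asarray(logits_row, dtype=np.float64) / temperature
    vocab = logits.shape[-1]
    k = min(top_k, vocab) if top_k and top_k > 0 else vocab
    # Indices of the k largest logits, sorted from most to least likely.
    top_idx = np.argpartition(logits, -k)[-k:]
    top_idx = top_idx[np.argsort(logits[top_idx])[::-1]]
    # Softmax over the kept logits (shifted by the max for numerical stability).
    probs = np.exp(logits[top_idx] - logits[top_idx].max())
    probs /= probs.sum()
    # Nucleus: keep the smallest prefix whose cumulative probability reaches top_p.
    cutoff = int(np.searchsorted(np.cumsum(probs), top_p)) + 1
    top_idx, probs = top_idx[:cutoff], probs[:cutoff]
    probs /= probs.sum()
    return int(rng.choice(top_idx, p=probs))
```

To use it, compute `params = sampling_params(f"{local}/generation_config.json")` once and replace each `int(np.argmax(logits[:, -1, :]))` in `generate()` with `sample_next(logits[0, -1, :], **params)`.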

Provenance

  • Source: onnx-community/gemma-4-E2B-it-ONNX
  • Repack pipeline: theseedship/deposium-turbov3/docs/gemma4_e4b_text_only_onnx_int4_export.ipynb
  • License: Gemma terms of use (inherited)