---
license: apache-2.0
base_model: nvidia/LocateAnything-3B
tags:
- onnx
- onnxruntime-web
- webgpu
- int4
- quantization
- object-detection
- visual-grounding
library_name: onnxruntime
pipeline_tag: object-detection
---

# LocateAnything-3B — ONNX WebGPU (INT4 + 4-bit embeddings)

In-browser ([onnxruntime-web](https://onnxruntime.ai/docs/tutorials/web/) / WebGPU) build of [`nvidia/LocateAnything-3B`](https://huggingface.co/nvidia/LocateAnything-3B), a visual-grounding / open-vocabulary detector. The language tower is weight-only **INT4** and the **embedding table is true group-wise 4-bit**.

## Why this repo exists

The naive "4-bit ONNX" of this model was ~3 GB because the model has **tied word embeddings** (`vocab 152681 × hidden 2048 = 1.25GB` in fp32). ORT's `MatMulNBits` INT4 quantizer compresses the tied `lm_head` MatMul but **leaves the input-embedding `Gather` at full fp32** — so 1.25 GB of fp32 embeddings stayed in the package.

This build fixes that with a **custom quantized embedding gather**:

1. The language graph was surgically rewired to consume `inputs_embeds` directly (the fp32 embedding `Gather` and its 1.25 GB initializer are removed). It still takes `input_ids` (used only by the SDLM block-mask `==` comparisons, not a Gather) and `visual_features` (spliced at the image token).
2. The embedding table ships as a **group-wise symmetric INT4 blob** (`(q-8)·scale`, block size 32): `embed_tokens_int4_packed.bin` (uint8 nibble-packed) + `embed_tokens_int4_scales.bin` (fp16).
3. The browser does the **gather + dequant in JS** to build `inputs_embeds`, then runs the INT4 language graph.
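
The gather + dequant step above can be sketched in plain JavaScript as follows. This is a minimal sketch, not the demo's actual code: the function names `halfToFloat` and `gatherDequant` are hypothetical, and it assumes both `.bin` files are loaded as flat, row-major typed arrays (`Uint8Array` for the packed nibbles, `Uint16Array` holding the raw fp16 scale bits). The nibble order (low nibble = even index, high nibble = odd index) follows the JS reference section of this README.

```javascript
// Decode one IEEE 754 half-precision value, given as its raw 16 bits.
function halfToFloat(h) {
  const sign = (h & 0x8000) ? -1 : 1;
  const exp = (h >> 10) & 0x1f;
  const frac = h & 0x03ff;
  if (exp === 0) return sign * Math.pow(2, -14) * (frac / 1024); // zero / subnormal
  if (exp === 0x1f) return frac ? NaN : sign * Infinity;         // NaN / infinity
  return sign * Math.pow(2, exp - 15) * (1 + frac / 1024);       // normal value
}

// Hypothetical helper: build a flat fp32 [tokenIds.length * hidden] buffer.
//   packed: Uint8Array,  logical shape [vocab, hidden / 2], two 4-bit values per byte
//   scales: Uint16Array, logical shape [vocab, hidden / group], raw fp16 bits
function gatherDequant(packed, scales, tokenIds, hidden = 2048, group = 32) {
  const bytesPerRow = hidden / 2;
  const groupsPerRow = hidden / group;
  const out = new Float32Array(tokenIds.length * hidden);
  for (let t = 0; t < tokenIds.length; t++) {
    const rowOff = tokenIds[t] * bytesPerRow;
    const scaleOff = tokenIds[t] * groupsPerRow;
    for (let b = 0; b < bytesPerRow; b++) {
      const byte = packed[rowOff + b];
      const j = 2 * b; // even element index; j and j + 1 share one scale group
      const s = halfToFloat(scales[scaleOff + Math.floor(j / group)]);
      out[t * hidden + j] = ((byte & 0x0f) - 8) * s;  // low nibble  -> even index
      out[t * hidden + j + 1] = ((byte >> 4) - 8) * s; // high nibble -> odd index
    }
  }
  return out;
}
```

The returned `Float32Array` could then be wrapped as the `inputs_embeds` tensor, for example with `new ort.Tensor('float32', embeds, [1, tokenIds.length, 2048])`; the `[1, seq, hidden]` shape is an assumption here and should be checked against `embed_tokens_int4_meta.json` and `web_config.json`.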

## Files (browser-facing)

| File | Size | Notes |
|------|------|-------|
| `onnx/vision_mlp.onnx` (+`.data`) | ~1.73 GB | MoonViT + projector, fp32 (see note below) |
| `onnx/language_tail_int4.onnx` (+`.data`) | ~1.69 GB | Qwen2 language tower + tied lm_head, weight-only INT4 (block 128) |
| `onnx/embed_tokens_int4_packed.bin` | ~156 MB | INT4 embedding table, uint8 nibble-packed `[152681, 1024]` |
| `onnx/embed_tokens_int4_scales.bin` | ~19.5 MB | fp16 group scales `[152681, 64]` |
| `onnx/embed_tokens_int4_meta.json` | — | layout / dequant scheme |
| `web_config.json` | — | runtime wiring, token ids, tail size |

**Total browser payload ≈ 3.6 GB.** The big win here is the language side: the embedding table dropped from **1.25 GB fp32 → 176 MB INT4** and the language tail from **2.9 GB → 1.69 GB**.

> **Vision precision note.** The vision tower's linears export as ONNX `Gemm` / dynamic `MatMul`,
> which ORT's `MatMulNBits` INT4 quantizer cannot compress, so it ships fp32. fp16 conversion is
> blocked by the explicit `.float()` Cast islands in the ONNX-friendly MoonViT RoPE patch (post-hoc
> fp16 conversion produces type clashes; native fp16 export hits a torch/MPS `expand_as`+float64
> limitation). A native mixed-precision vision re-export (Conv in fp32, rest fp16) is the planned
> follow-up to cut this to ~0.9 GB.

## Embedding gather / dequant (JS reference)

```
row  = packed[token_id]                     // uint8[hidden/2]
low  = row & 0x0F ; high = (row >> 4)       // two nibbles per byte (low = even idx, high = odd idx)
q    = interleave(low, high)                // uint4[hidden], values 0..15
emb  = (q - 8) * scales[token_id][j/32]     // fp32[hidden]; one scale per 32-wide group
```

The language graph then splices `visual_features` over the image-token positions and applies the SDLM block mask from `input_ids`.

## Validation

Validated against the fp32 PyTorch model on the sample image (slow / autoregressive mode):

- **Next-token argmax matches PyTorch exactly** (token `151672` = `` start).
- INT4 embedding gather error vs fp32 embeddings: `max_abs 0.017`, `mean_rel ≈ 10%` (per element), contributing only `~0.98` to the final logits (argmax-preserving).
- The dominant INT4 *weight* error (`~12.6` max logit delta) is unchanged from the baseline INT4 build.
- fp16 vision vs fp32 vision: see `validation_report.json`.

> Generation mode: use **`slow` (autoregressive, greedy)**. The earlier prefill tail positions used by
> the `fast`/MTP path diverge under INT4 and are not relied upon here.

## Intended use

Open-vocabulary detection / visual grounding: given an image and a category prompt (`Locate all the instances that matches the following description: .`), the model emits `labelx1 y1 x2 y2` with coordinates normalized to `0–1000`.

## KV-cache graph for in-browser use

`onnx/language_tail_kv_int4.onnx` (+`.data`, ~1.65 GB) is the **KV-cache** version of the language tail used by the live demo. It takes `inputs_embeds` (+ `input_ids` for the plain-causal mask, `position_ids`, and 36×2 `past_key/value` GQA tensors `[1,2,seq,128]`) and returns `logits` for the last position plus `present_key/value`. Prefill passes length-0 past; decode passes the growing cache. This makes autoregressive decoding ~13× faster than the cache-less graph.

Validated: prefill (empty past) → cached decode reproduces `label` detections; next-token argmax matches the fp32 torch model. See `kv_validation_report.json`.

**Live in-browser demo (WebGPU):** https://huggingface.co/spaces/Reza2kn/LocateAnything-3B-WebGPU

Source model & license: [`nvidia/LocateAnything-3B`](https://huggingface.co/nvidia/LocateAnything-3B).
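
The prefill-then-decode flow described in the KV-cache section can be sketched as a greedy loop. This is a hedged illustration only: `step`, `embedToken`, `emptyPast`, `eosId` and `maxNewTokens` are hypothetical parameters, not names from this repo. `step` is assumed to wrap one run of the KV-cache graph (including the `input_ids` and `position_ids` feeds, which are omitted here) and to resolve to `{ logits, present }`, where `logits` covers the last position only.

```javascript
// Index of the largest value in a logits array.
function argmax(logits) {
  let best = 0;
  for (let i = 1; i < logits.length; i++) {
    if (logits[i] > logits[best]) best = i;
  }
  return best;
}

// Hypothetical greedy decode loop over an abstract `step(embeds, past)` callback.
async function greedyDecode(step, promptEmbeds, embedToken, emptyPast, eosId, maxNewTokens) {
  // Prefill: feed the whole prompt with a length-0 past.
  let { logits, present } = await step(promptEmbeds, emptyPast);
  const tokens = [];
  for (let i = 0; i < maxNewTokens; i++) {
    const next = argmax(logits);
    tokens.push(next);
    if (next === eosId || i === maxNewTokens - 1) break;
    // Decode: feed only the new token's embedding plus the growing cache.
    ({ logits, present } = await step(embedToken(next), present));
  }
  return tokens;
}
```

Passing the session call in as a callback keeps the loop independent of the graph's actual input and output names, which should be taken from `web_config.json` rather than from this sketch.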