Llama-3.2-1B-DMS-8x

Description

Llama-3.2-1B-DMS-8x is a derivative of meta-llama/Llama-3.2-1B that integrates Dynamic Memory Sparsification (DMS) for KV-cache compression research and serving experiments.

DMS learns per-KV-head token eviction decisions that interpolate between a recent-token sliding window and full attention. The checkpoint stores those decisions in the borrowed-neuron convention used by the DMS paper and can be served by DMS-aware runtimes to reduce KV-cache memory at inference time.
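To make the interpolation concrete, here is a minimal sketch of how binary eviction decisions plus a retention window could translate into an attention visibility mask for a single KV head. The function names, the binary form of the decisions, and the exact window boundary (`i - j < window`) are illustrative assumptions, not the FastDMS or DMS-paper implementation.

```python
def dms_visibility_mask(evict, window):
    """Return mask[i][j] = True if query position i may attend to key position j.

    evict[j] is an assumed binary eviction decision for token j (one KV head).
    Tokens marked for eviction stay visible while they are within `window`
    positions of the query; unmarked tokens stay visible indefinitely.
    """
    n = len(evict)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):  # causal: keys at or before the query
            recent = (i - j) < window  # assumed boundary convention
            mask[i][j] = recent or not evict[j]
    return mask


def full_causal_mask(n):
    """Reference mask for ordinary causal attention."""
    return [[j <= i for j in range(n)] for i in range(n)]


def sliding_window_mask(n, window):
    """Reference mask for a pure recent-token sliding window."""
    return [[j <= i and (i - j) < window for j in range(n)] for i in range(n)]
```

Under these assumptions, evicting no tokens reproduces full causal attention and evicting every token reproduces the sliding window; learned decisions fall somewhere in between.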

This checkpoint is part of Shisa AI's FastDMS work. The stock Transformers library can load the model as an ordinary LlamaForCausalLM with a dense KV cache, but compact KV-cache behavior requires a runtime that consumes the packaged DMS metadata and eviction signal.
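As a rough sense of scale, the sketch below estimates dense KV-cache size for one sequence and what an ideal 8x reduction would leave. The architecture defaults (16 layers, 8 KV heads, head dimension 64, 2-byte BF16 elements) are assumptions taken from the base Llama-3.2-1B configuration; verify them against config.json. The estimate ignores any bookkeeping overhead a DMS-aware runtime adds.

```python
def kv_cache_bytes(num_tokens, num_layers=16, num_kv_heads=8,
                   head_dim=64, bytes_per_elem=2):
    """Dense KV-cache size in bytes for one sequence.

    The leading factor of 2 accounts for storing both keys and values.
    Default architecture values are assumptions; check config.json.
    """
    per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem
    return per_token * num_tokens


def compressed_kv_cache_bytes(num_tokens, compression_ratio, **arch):
    """Idealized size if the cache shrinks by exactly `compression_ratio`."""
    return kv_cache_bytes(num_tokens, **arch) / compression_ratio


MIB = 1024 * 1024
dense = kv_cache_bytes(4096)
compact = compressed_kv_cache_bytes(4096, 8)
print(f"dense:   {dense / MIB:.1f} MiB")
print(f"compact: {compact / MIB:.1f} MiB (ideal 8x)")
```

With these assumed values, a 4096-token context costs 32 KiB per token, or 128 MiB dense, against 16 MiB at an ideal 8x ratio.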

This model is for research and development.

DMS Metadata

The runtime parameters are included both in config.json and in dms_metadata.json:

Parameter          Value
DMS window         256 tokens
DMS alpha scale    100.0
DMS alpha offset   5.0
Target CR          8x
Base model         meta-llama/Llama-3.2-1B

The full retained training log is included as training_log.json.

Training

The checkpoint was trained with the two-phase DMS procedure:

  1. Borrowed-neuron zeroing for 2000 steps.
  2. DMS retrofitting with logit distillation and a compression loss, progressing through checkpoints from compression ratio 2x (CR2) to 8x (CR8).

Training used WikiText-2 text chunks, context length 4096, DMS window 256, learning rate 3e-5, and compression weight 100.
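The following is a hedged sketch of how a retrofitting objective of this kind could combine the two terms. It is not the training code for this checkpoint: the one-sided penalty, its normalization, the KL direction, and all function names are assumptions chosen for illustration. Only the compression weight of 100 and the CR2 to CR8 targets come from the description above.

```python
import math


def target_eviction_fraction(cr):
    """Fraction of tokens that must be evicted to reach compression ratio `cr`."""
    return 1.0 - 1.0 / cr


def kl_divergence(teacher_probs, student_probs):
    """KL(teacher || student) in nats for one token's distribution.

    Assumes student_probs are strictly positive.
    """
    return sum(p * math.log(p / q)
               for p, q in zip(teacher_probs, student_probs) if p > 0)


def retrofit_loss(teacher_probs, student_probs, mean_eviction,
                  target_cr, compression_weight=100.0):
    """Distillation term plus a weighted, one-sided compression penalty.

    The penalty is zero once the mean eviction rate reaches the target
    (an assumed form; the actual loss may be normalized differently).
    """
    distill = kl_divergence(teacher_probs, student_probs)
    shortfall = max(target_eviction_fraction(target_cr) - mean_eviction, 0.0)
    return distill + compression_weight * shortfall


for cr in (2, 4, 8):
    print(f"CR{cr}: evict at least {target_eviction_fraction(cr):.3f} of tokens")
```

Under this reading, the CR2 stage asks for half of the tokens to be evicted and the CR8 stage for seven eighths.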

The retained run took 1219.7s (20.3 minutes) of DMS training wall time on a single RTX PRO 6000 Blackwell GPU.

Evaluation

On the local WikiText-2 512 x 2 strict-mask gate, the final checkpoint measured:

Metric                  Value
Strict-mask PPL         7.2171
PPL delta vs base       -1.31%
KLD vs base             0.021957 nats/token
Eviction rate           0.5767
Effective compression   2.36x
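The reported eviction rate and effective compression are consistent with the relation compression = 1 / (1 - eviction rate). The card does not state this definition explicitly, so treat the sketch below as an assumed reading of the two numbers rather than the evaluation script.

```python
def effective_compression(eviction_rate):
    """Compression ratio implied by the fraction of KV entries evicted."""
    return 1.0 / (1.0 - eviction_rate)


def required_eviction_rate(compression_ratio):
    """Eviction rate needed to reach a given compression ratio."""
    return 1.0 - 1.0 / compression_ratio


print(f"{effective_compression(0.5767):.2f}x")  # eviction rate measured on this gate
print(f"{required_eviction_rate(8):.3f}")       # eviction rate the 8x target implies
```

By the same relation, the 8x target corresponds to an eviction rate of 0.875, so the 2.36x figure describes what was measured on this particular gate rather than the nominal target.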

In Shisa AI's compact-DMS serving experiments, this checkpoint is the canonical Llama-3.2-1B corrected-mask v5 source used for the FastDMS/nano-vLLM serving rows.

Quick Start

Dense loading with Transformers (no KV-cache compression):

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_id = "shisa-ai/Llama-3.2-1B-DMS-8x"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

DMS-aware serving:

See FastDMS for the runtime path that uses dms_metadata.json and the checkpoint's learned eviction signal for compact KV-cache inference.

License and Terms

This checkpoint is a derivative of Meta Llama 3.2 and is released under the Llama 3.2 license terms inherited from the base model.
