MONOSTEP v1

MONOSTEP is a small (~16.6M-parameter) experimental research model trained on GSM8K. Instead of predicting one token at a time, it predicts a fixed block of SLOTS = 4 tokens per forward pass using a chain of sequential "slot" heads, a lightweight take on multi-token prediction. It uses the GPT-2 byte-level tokenizer with a few added chat/control special tokens.

⚠️ This is a tiny, rough research artifact. Expect imperfect, playful answers; it is not suitable for production use.

Architecture

input_ids ──► Trunk (Transformer encoder, mean-pooled) ──► h_shared
                                                              │
                 init_state ─► Slot 0 ─► Slot 1 ─► Slot 2 ─► Slot 3
                                   │        │        │        │
                                 logits   logits   logits   logits   →  [B, SLOTS, V]
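  • Trunk: token + learned positional embeddings, followed by a norm_first (pre-LayerNorm) Transformer encoder with no causal mask (GELU, dim_feedforward = 4·d_model, dropout 0.1). The encoder output goes through masked mean-pooling over the sequence and a final LayerNorm. It produces a single shared vector h_shared that summarizes the whole prefix.
  • Slots: SLOTS heads with independent weights, applied in sequence. Each slot:
      - takes the concatenation [h_shared, h_prev],
      - passes it through a small MLP,
      - adds the result back to h_prev (residual) and applies LayerNorm,
      - projects the result to the vocabulary.
    The hidden state is threaded from slot to slot, starting from a learned init_state, so the block is produced left to right within one forward pass. Only this hidden state is passed along; the token chosen at one slot is not fed to the next.
  • Decoding: autoregressive in blocks of SLOTS tokens. Each forward pass emits up to 4 tokens, and the non-<empty> ones are appended to the context. The loop repeats until <|endoftext|> is produced, a block comes back entirely <empty>, or the token budget runs out.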
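
The decoding loop described above can be sketched in plain Python by replacing the forward pass with a scripted stub. This is a minimal illustration under assumptions. `decode_blocks` and `predict_block` are hypothetical names, with `predict_block` standing in for `model(...)` followed by a per-slot argmax. The content token ids are made up; only the <empty> and <|endoftext|> ids come from this card.

```python
from typing import Callable, List

EMPTY_ID = 50258  # <empty>, as listed in this card
EOS_ID = 50256    # <|endoftext|>
SLOTS = 4

def decode_blocks(prefix: List[int],
                  predict_block: Callable[[List[int]], List[int]],
                  max_new_tokens: int = 128) -> List[int]:
    """Block-autoregressive decoding: each call to predict_block yields SLOTS ids."""
    context, out = list(prefix), []
    for _ in range(max_new_tokens // SLOTS):
        block = predict_block(context)                # one "forward pass"
        block = [t for t in block if t != EMPTY_ID]   # drop <empty> fillers
        if not block:                                 # nothing left to emit
            break
        if EOS_ID in block:                           # keep only tokens before EOS
            out += block[:block.index(EOS_ID)]
            break
        out += block
        context += block                              # feed the block back as context
    return out

# Hypothetical stub: replays a fixed script of blocks, ignoring the context.
script = iter([[10, 11, 12, 13], [14, EMPTY_ID, 15, EMPTY_ID], [16, EOS_ID, 99, 99]])
print(decode_blocks([1, 2, 3], lambda ctx: next(script)))
# -> [10, 11, 12, 13, 14, 15, 16]
```

Note that the whole block is decided in one pass. If EOS appears mid-block, the tokens after it were still predicted and have to be discarded explicitly.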

Configuration

Field        Value              Field        Value
d_model      64                 slots        4
n_layers     4                  max_len      512
n_heads      4                  vocab_size   50262
Parameters   ~16.6M             Tokenizer    GPT-2 + specials
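
The ~16.6M figure can be reproduced by hand from the classes in the Usage section. The sketch below is a back-of-the-envelope count. It assumes PyTorch's default biases everywhere (nn.Linear, and the nn.MultiheadAttention inside nn.TransformerEncoderLayer) and untied per-slot vocabulary projections, as in the Usage code. `monostep_param_count` is a hypothetical helper, not part of the release.

```python
def monostep_param_count(vocab_size=50262, d_model=64, n_layers=4,
                         max_len=512, slots=4):
    """Count parameters of the Trunk/Slot/Monostep classes (n_heads has no effect)."""
    d, V = d_model, vocab_size
    ff = 4 * d                                          # dim_feedforward = 4 * d_model
    linear = lambda n_in, n_out: n_in * n_out + n_out   # weight + bias
    layer_norm = 2 * d                                  # weight + bias

    # One pre-norm encoder layer: attention, feed-forward, two LayerNorms
    encoder_layer = (3 * d * d + 3 * d                  # packed q/k/v in-projection
                     + linear(d, d)                     # attention out-projection
                     + linear(d, ff) + linear(ff, d)    # feed-forward
                     + 2 * layer_norm)
    # Trunk: token + positional embeddings, encoder stack, final LayerNorm
    trunk = V * d + max_len * d + n_layers * encoder_layer + layer_norm

    # Each slot: MLP on [h_shared, h_prev], LayerNorm, vocabulary projection
    slot = linear(2 * d, d) + linear(d, d) + layer_norm + linear(d, V)
    return {"trunk": trunk, "init_state": d, "slots": slots * slot,
            "total": trunk + d + slots * slot}

print(monostep_param_count())
# {'trunk': 3449600, 'init_state': 64, 'slots': 13118296, 'total': 16567960}
```

The breakdown shows where the parameters sit:

- about 13.1M in the four per-slot vocabulary projections,
- about 3.2M in the token embedding,
- only about 0.2M in the four encoder layers.

On a loaded model, `sum(p.numel() for p in model.parameters())` should give the same total.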

Special tokens added to the GPT-2 tokenizer (order matters: ids must match the checkpoint): <pad> (50257), <empty> (50258), plus <system>, <user>, <assistant>. EOS is GPT-2's <|endoftext|> (50256). <empty> is the padding label used to fill out a slot block and is skipped during generation.
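
Why order matters can be seen from the id arithmetic. The sketch below assumes that added tokens are appended after GPT-2's 50,257 base ids in the order they are registered, with <pad> first as in the Usage snippet. Under that assumption, the ids for <system>, <user> and <assistant> are inferred here; the card does not state them. `expected_special_ids` is a hypothetical helper.

```python
# Assumption: added tokens are appended after GPT-2's base vocabulary in the
# order they are registered (<pad> first, then the additional special tokens).
GPT2_BASE_VOCAB = 50257  # ids 0..50256; 50256 is <|endoftext|>
ADDED_TOKENS = ["<pad>", "<empty>", "<system>", "<user>", "<assistant>"]

def expected_special_ids(base=GPT2_BASE_VOCAB, added=ADDED_TOKENS):
    """Map each added token to the id it should receive (hypothetical helper)."""
    return {name: base + i for i, name in enumerate(added)}

print(expected_special_ids())
# {'<pad>': 50257, '<empty>': 50258, '<system>': 50259, '<user>': 50260, '<assistant>': 50261}
print(GPT2_BASE_VOCAB + len(ADDED_TOKENS))  # 50262, the checkpoint's vocab_size
```

With the real tokenizer from the Usage section, `tok.convert_tokens_to_ids("<system>")` and `len(tok)` can be compared against these values.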

Prompt format

<system> You are a helpful math assistant.
<user> {question}
<assistant> 
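
The template above matches the f-string used by generate() in the Usage section. Each role gets its own line, and generation starts right after the trailing space following <assistant>. A small helper makes that explicit. `build_prompt` and `SYSTEM_PROMPT` are illustrative names, and it is assumed (not stated) that training used exactly this whitespace.

```python
SYSTEM_PROMPT = "You are a helpful math assistant."

def build_prompt(question: str, system: str = SYSTEM_PROMPT) -> str:
    # One line per role; the model continues right after "<assistant> ".
    return f"<system> {system}\n<user> {question.strip()}\n<assistant> "

print(repr(build_prompt("  What is 7 + 5? ")))
# '<system> You are a helpful math assistant.\n<user> What is 7 + 5?\n<assistant> '
```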

Training

  • Data: openai/gsm8k (main), train split for training, test split for eval. Each answer is chunked into blocks of SLOTS tokens; the model learns to predict the next block from the running prefix.
  • Objective: mean cross-entropy across the 4 slots, ignoring the <empty> label.
  • Optimizer: AdamW, lr 3e-4, gradient clipping at 1.0.
  • Schedule: 10 epochs, batch size 16, seed 42.
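
The block targets described under Data can be reconstructed roughly as follows. This is a sketch under assumptions, not the released training code. It assumes:

- the answer is terminated with <|endoftext|>,
- blocks are aligned to the start of the answer,
- the final partial block is padded with <empty>.

`make_block_examples` is a hypothetical helper.

```python
from typing import List, Tuple

SLOTS = 4
EMPTY_ID = 50258  # <empty>: pads the final block
EOS_ID = 50256    # <|endoftext|>

def make_block_examples(prompt_ids: List[int], answer_ids: List[int],
                        slots: int = SLOTS) -> List[Tuple[List[int], List[int]]]:
    """Return (prefix, target_block) pairs; every target has exactly `slots` ids."""
    target = answer_ids + [EOS_ID]  # assumption: answers end with EOS
    examples = []
    for start in range(0, len(target), slots):
        block = target[start:start + slots]
        block = block + [EMPTY_ID] * (slots - len(block))  # pad the last block
        prefix = prompt_ids + target[:start]                # running prefix
        examples.append((prefix, block))
    return examples

for prefix, block in make_block_examples([1, 2], [10, 11, 12, 13, 14]):
    print(prefix, block)
# [1, 2] [10, 11, 12, 13]
# [1, 2, 10, 11, 12, 13] [14, 50256, 50258, 50258]
```

With logits of shape [B, SLOTS, V] and targets of shape [B, SLOTS], the objective could then be computed as `F.cross_entropy(logits.reshape(-1, V), targets.reshape(-1), ignore_index=EMPTY_ID)`. That averages over all non-<empty> positions; whether the original averages per slot first is not stated.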

Results (cross-entropy loss)

Epoch   Train loss   Eval loss
1       5.923        5.518
5       4.802        5.023
10      4.263        4.596
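
Because cross-entropy in PyTorch is measured in nats, the table converts directly to perplexity via exp(loss). This assumes the reported numbers are means over non-<empty> target tokens. `perplexity` and `LOSSES` are illustrative names.

```python
import math

# (train, eval) cross-entropy per epoch, copied from the table, in nats
LOSSES = {1: (5.923, 5.518), 5: (4.802, 5.023), 10: (4.263, 4.596)}

def perplexity(loss_nats: float) -> float:
    """Per-token perplexity from a mean cross-entropy in nats."""
    return math.exp(loss_nats)

for epoch, (train, eval_loss) in LOSSES.items():
    print(f"epoch {epoch:>2}: train ppl {perplexity(train):6.1f}  "
          f"eval ppl {perplexity(eval_loss):6.1f}")

# Reference point: uniform guessing over the full vocabulary
print(f"uniform baseline loss: {math.log(50262):.2f} nats")
```

At epoch 10 this gives an eval perplexity of roughly 99. Uniform guessing would give 50,262 (a loss of about 10.8 nats).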

[Figure: loss curve, monostep_gsm8k_loss.png]

Files

File                      Description
monostep_bundle.pt        torch.save bundle: model_state_dict, optimizer_state_dict, config, train_history, eval_history, sample_output
config.json               Training/architecture config
metrics.json              Per-epoch train/eval loss histories
tokenizer/                Saved GPT-2 tokenizer (with the added special tokens)
monostep_gsm8k_loss.png   Loss curve

Usage

The model uses a custom architecture, so you need the class definitions below (no transformers AutoModel support).

import torch, torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer

# --- tokenizer (GPT-2 + the same specials, in the same order) ---
tok = AutoTokenizer.from_pretrained("gpt2")
tok.add_special_tokens({
    "pad_token": "<pad>",
    "additional_special_tokens": ["<empty>", "<system>", "<user>", "<assistant>"],
})
EMPTY_ID = tok.convert_tokens_to_ids("<empty>")
EOS_ID = tok.eos_token_id

# --- architecture ---
class Trunk(nn.Module):
    def __init__(self, vocab_size, d_model, n_layers, n_heads, max_len):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        layer = nn.TransformerEncoderLayer(d_model, n_heads, 4 * d_model, dropout=0.1,
                                           activation="gelu", batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.out_norm = nn.LayerNorm(d_model)
    def forward(self, ids, mask):
        b, t = ids.shape
        pos = torch.arange(t, device=ids.device).unsqueeze(0).expand(b, -1)
        x = self.encoder(self.tok_emb(ids) + self.pos_emb(pos), src_key_padding_mask=~mask)
        m = mask.unsqueeze(-1).to(x.dtype)
        return self.out_norm((x * m).sum(1) / m.sum(1).clamp_min(1.0))

class Slot(nn.Module):
    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(d_model * 2, d_model), nn.GELU(), nn.Linear(d_model, d_model))
        self.norm = nn.LayerNorm(d_model)
        self.to_vocab = nn.Linear(d_model, vocab_size)
    def forward(self, h_shared, h_prev):
        h = self.norm(h_prev + self.ff(torch.cat([h_shared, h_prev], -1)))
        return h, self.to_vocab(h)

class Monostep(nn.Module):
    def __init__(self, vocab_size, d_model=64, n_layers=4, n_heads=4, max_len=512, slots=4):
        super().__init__()
        self.trunk = Trunk(vocab_size, d_model, n_layers, n_heads, max_len)
        self.init_state = nn.Parameter(torch.zeros(d_model))
        self.slots = nn.ModuleList([Slot(d_model, vocab_size) for _ in range(slots)])
    def forward(self, ids, mask):
        h_shared = self.trunk(ids, mask)
        h, outs = self.init_state.unsqueeze(0).expand(ids.size(0), -1), []
        for slot in self.slots:
            h, logits = slot(h_shared, h)
            outs.append(logits)
        return torch.stack(outs, dim=1)  # [B, SLOTS, V]

# --- load ---
ckpt = torch.load(hf_hub_download("wop/Monostep-v1", "monostep_bundle.pt"),
                  map_location="cpu", weights_only=False)
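cfg = ckpt["config"]
# the special tokens must produce exactly the checkpoint's vocabulary size
assert len(tok) == cfg["vocab_size"], "tokenizer specials do not match the checkpoint"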
model = Monostep(cfg["vocab_size"], cfg["d_model"], cfg["n_layers"],
                 cfg["n_heads"], cfg["max_len"], cfg["slots"])
model.load_state_dict(ckpt["model_state_dict"])
model.eval()

# --- block-autoregressive generation ---
@torch.no_grad()
def generate(question, max_new_tokens=128):
    prompt = f"<system> You are a helpful math assistant.\n<user> {question.strip()}\n<assistant> "
    # keep at most max_len tokens: the learned positional table has max_len rows
    ids = torch.tensor([tok(prompt, add_special_tokens=False).input_ids])[:, -cfg["max_len"]:]
    out = []
    for _ in range(max_new_tokens // cfg["slots"]):
        logits = model(ids, torch.ones_like(ids, dtype=torch.bool))[0]  # [SLOTS, V]
        block = [t for t in logits.argmax(-1).tolist() if t != EMPTY_ID]  # drop <empty> fillers
        if not block:
            break
        if EOS_ID in block:
            out += block[:block.index(EOS_ID)]  # discard EOS and anything predicted after it
            break
        out += block
        ids = torch.cat([ids, torch.tensor([block])], dim=1)[:, -cfg["max_len"]:]
    return tok.decode(out, skip_special_tokens=True)

print(generate("If a shop has 12 apples and sells 5, how many are left?"))

This checkpoint is also wired into the Cosmos T2-Accelerate chat demo as a selectable model (Monostep v1), which streams the 4-token blocks live.

Limitations

  • Tiny capacity (~16.6M params) and only 10 epochs of training: answers are frequently wrong.
  • Trained on single-turn GSM8K math word problems; out-of-domain or multi-turn prompts are out of distribution.
  • Each 4-token block is decoded from a single mean-pooled summary of the prefix. Tokens within a block are not conditioned on each other's predicted ids, only on the threaded hidden state, so intra-block coherence is weaker than with standard token-by-token decoding.

License

Released under the MIT license. The model uses the GPT-2 tokenizer (MIT) and was trained on GSM8K (MIT).
