MDLM-OWT — flash-attn-free reimplementation

A drop-in reimplementation of kuleshov-group/mdlm-owt (Sahoo et al., Simple and Effective Masked Diffusion Language Models, NeurIPS 2024) that removes the hard dependency on flash-attn. The weights are the original weights, byte-for-byte; only the modeling code changed, so the model runs wherever a recent PyTorch does: on Windows, on CPU, and on recent CUDA GPUs (including Blackwell / sm_120) where a flash-attn build is impractical.

This port was produced as the reference substrate for a cross-layer-transcoder interpretability study, where a deterministic, dependency-light, hookable forward pass matters. It is published in case it is useful to others blocked by the flash-attn requirement.

What changed (and what did not)

Three changes to the computation, all in the modeling code:

  1. Attention — torch.nn.functional.scaled_dot_product_attention (non-causal, full bidirectional) replaces flash_attn_varlen_qkvpacked_func. Equivalent here: one sequence per batch element, no causal mask, default softmax scale 1/sqrt(head_dim).
  2. Rotary embeddings — applied with an explicit rotate_half instead of flash_attn.layers.rotary. This reproduces flash-attn's non-interleaved convention exactly (the modeling file carries a doctest that checks the equivalence).
  3. Precision — the transformer stack runs in the model's native dtype (fp32 by default) rather than under a bfloat16 autocast. fp32 is preferable for interpretability and for use as a numerical oracle.
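
The non-interleaved (rotate_half) convention named in item 2 can be illustrated without any framework. The sketch below is not the code from the modeling file: the function names and the base of 10000 are assumptions made for illustration. It shows that the rotate_half formulation is the same as rotating each pair of dimensions (i, i + d/2) by the angle position × base^(-2i/d).

```python
import math

def rotate_half(x):
    """Non-interleaved rotation: [x1, x2] -> [-x2, x1], split at the midpoint."""
    half = len(x) // 2
    return [-v for v in x[half:]] + list(x[:half])

def apply_rotary(x, position, base=10000.0):
    """Rotate one head vector x (even length) for a single position.

    base=10000.0 is the common RoPE default and an assumption here.
    """
    d = len(x)
    half = d // 2
    # One angle per pair; the same angles are repeated for both halves.
    angles = [position * base ** (-2.0 * i / d) for i in range(half)]
    cos = [math.cos(a) for a in angles] * 2
    sin = [math.sin(a) for a in angles] * 2
    rot = rotate_half(x)
    return [xi * c + ri * s for xi, c, ri, s in zip(x, cos, rot, sin)]

def rotate_pairs(x, position, base=10000.0):
    """Reference: explicitly rotate each (i, i + d/2) pair by its angle."""
    d = len(x)
    half = d // 2
    out = list(x)
    for i in range(half):
        a = position * base ** (-2.0 * i / d)
        x1, x2 = x[i], x[i + half]
        out[i] = x1 * math.cos(a) - x2 * math.sin(a)
        out[i + half] = x2 * math.cos(a) + x1 * math.sin(a)
    return out

vec = [0.5, -1.0, 2.0, 0.25, 1.5, -0.75, 0.0, 3.0]
rotated = apply_rotary(vec, position=7)
reference = rotate_pairs(vec, position=7)
```

An interleaved convention would instead pair dimensions (2i, 2i + 1); mixing the two conventions silently changes the model's outputs, which is why the port pins this detail down.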

Plus one compatibility fix: the constructor now calls self.post_init(), which the upstream code omits. transformers ≥ 5 requires it to finalize from_pretrained (without it, loading raises AttributeError: ... 'all_tied_weights_keys'). MDLM unties its embeddings, so this registers no tied weights — it only initializes HF bookkeeping.

Unchanged: every weight (the original model.safetensors, SHA256 47149e73f7552f39ea9776dbe74d925d25237bcf2ed2e2ec03cdff9d51c82aa4), all parameter names, the architecture, and the forward signature. The original weights load with no remapping.
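
To confirm that a local copy of the weights matches the digest above, a streaming hash with the standard library is enough. This is a minimal sketch; the helper names sha256_of_file and weights_match are illustrative, and the path you pass in depends on where the file was downloaded.

```python
import hashlib

# Digest of the original model.safetensors, as quoted in this card.
EXPECTED_SHA256 = "47149e73f7552f39ea9776dbe74d925d25237bcf2ed2e2ec03cdff9d51c82aa4"

def sha256_of_file(path, chunk_size=1 << 20):
    """Hash a file in 1 MiB chunks so large checkpoints need little memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def weights_match(path, expected=EXPECTED_SHA256):
    """Return True if the file at `path` has the expected SHA256 digest."""
    return sha256_of_file(path) == expected.lower()
```

For example, weights_match("model.safetensors") would be called with the path of the downloaded file.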

Numerical note: because the upstream model runs its block stack under a bf16 autocast while this port runs fp32, logits differ from the upstream path at roughly bf16 tolerance. This port computes the higher-precision fp32 result of the same math.
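
For a sense of scale: bfloat16 keeps 8 significant bits, so rounding a single value to bf16 perturbs it by a relative error of at most 2^-8 (about 0.4%). The sketch below emulates that rounding with the standard library only. It illustrates the precision of the format and does not reproduce the accumulated error of the upstream autocast path; the helper name to_bf16 is illustrative.

```python
import struct

def to_bf16(x):
    """Round a finite float to the nearest bfloat16 value (ties to even).

    bfloat16 is the top 16 bits of an IEEE float32, so rounding amounts to
    rounding away the low 16 bits. Intended for ordinary finite inputs only.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)   # round half to even
    bits &= 0xFFFF0000                    # drop the low 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

values = [1.0, 3.14159, 0.1, -123.456]
max_rel_err = max(abs(to_bf16(v) - v) / abs(v) for v in values)
print(f"max relative rounding error: {max_rel_err:.5f} (bound: {2 ** -8:.5f})")
```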

Usage

trust_remote_code=True is required (the modeling code ships with the repo). The tokenizer is plain GPT-2; the model's vocabulary is GPT-2's 50257 tokens plus a [MASK] token at id 50257 (vocab size 50258).

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForMaskedLM.from_pretrained(
    "TheQweaker/mdlm-owt-noflash", trust_remote_code=True).eval()

MASK = 50257
ids = tokenizer("The capital of France is Paris.")["input_ids"]
pos = ids.index(tokenizer(" Paris")["input_ids"][0])
ids[pos] = MASK  # mask the word " Paris"

x = torch.tensor([ids])
with torch.no_grad():
    logits = model(input_ids=x)        # raw [1, L, 50258] tensor (config sets return_dict=False)
logits[..., MASK] = float("-inf")      # never predict [MASK]
print(tokenizer.decode([logits[0, pos].argmax().item()]))   # -> " Paris"

The timesteps argument may be omitted: the checkpoint was trained with time_conditioning=False, so the time input is ignored internally. Pass return_dict=True if you want a MaskedLMOutput instead of a raw tensor.

Generation

The original checkpoint ships no sampler. sample.py in this repo provides a faithful masked-diffusion ancestral sampler (SUBS parameterization, linear schedule):

from sample import generate   # from this repo
gen = torch.Generator().manual_seed(0)
ids = generate(model, seq_len=64, num_steps=128, top_k=50, generator=gen)
print(tokenizer.decode(ids.tolist()))

Note that a 169M-parameter masked-diffusion model under simple ancestral sampling tends to produce locally grammatical but repetitive text; sample quality is sensitive to step count, temperature, and top-k.
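
For orientation, one reverse step of ancestral sampling under SUBS with a linear schedule can be sketched as follows. This is an illustrative toy in plain Python, not the contents of sample.py: it assumes the masking probability at time t equals t, and the names toy_denoiser, ancestral_step, and run are hypothetical. In a step from t to s < t, each masked position is revealed with probability (t - s) / t, and tokens that are already revealed are carried over unchanged.

```python
import random

MASK = 50257  # [MASK] id, as in the Usage section

def toy_denoiser(tokens, rng):
    """Stand-in for the model (hypothetical): one proposed token per position.

    A real sampler would draw from the model's softmax with [MASK] excluded.
    """
    return [rng.randrange(MASK) for _ in tokens]

def ancestral_step(tokens, t, s, rng):
    """Move from noise level t to s < t, assuming mask probability == t."""
    proposals = toy_denoiser(tokens, rng)
    p_unmask = (t - s) / t
    out = []
    for tok, prop in zip(tokens, proposals):
        if tok != MASK:
            out.append(tok)      # carry-over: revealed tokens stay fixed
        elif rng.random() < p_unmask:
            out.append(prop)     # reveal this position
        else:
            out.append(MASK)     # stay masked until a later step
    return out

def run(seq_len=16, num_steps=8, seed=0):
    rng = random.Random(seed)
    tokens = [MASK] * seq_len    # fully masked at t = 1
    for i in range(num_steps, 0, -1):
        t, s = i / num_steps, (i - 1) / num_steps
        tokens = ancestral_step(tokens, t, s, rng)
    return tokens

sample = run()
```

In the last step s is 0, so the reveal probability is 1 and no [MASK] tokens remain. With few steps, many positions are revealed at once from the same forward pass, which is one reason step count affects sample quality.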

Model details (unchanged from upstream)

  • Architecture: 12-layer DiT-style transformer, d_model 768, 12 heads, adaLN conditioning, untied embeddings, rotary positions. ≈169.6M total parameters (≈92M non-embedding; the input and output embeddings are untied, ≈77M together). Context length 1024.
  • Training: OpenWebText, 1M steps (Sahoo et al., 2024).
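
The embedding figure above follows from the vocabulary size and model width stated in this card. A quick arithmetic check, assuming the two embedding matrices carry no bias terms (variable names are illustrative, and the total is the rounded 169.6M figure quoted above):

```python
VOCAB_SIZE = 50258              # GPT-2's 50257 tokens plus [MASK]
D_MODEL = 768
TOTAL_PARAMS_APPROX = 169.6e6   # rounded total quoted in this card

# Input and output embeddings are untied, so each is counted separately.
embedding_params = 2 * VOCAB_SIZE * D_MODEL
non_embedding_approx = TOTAL_PARAMS_APPROX - embedding_params

print(f"embeddings:    {embedding_params / 1e6:.1f}M")
print(f"non-embedding: {non_embedding_approx / 1e6:.1f}M (approximate)")
```

Both results agree with the ≈77M and ≈92M figures in the architecture bullet.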

Requirements and environment

No flash-attn. The runtime dependencies are just torch (any recent version with scaled_dot_product_attention), transformers, and safetensors; the tokenizer is plain gpt2 (tokenizers). trust_remote_code=True is required.
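
To see which of these packages are installed, and at what version, without importing them, the standard library's importlib.metadata is sufficient. This is a small convenience sketch; the helper name installed_versions is illustrative. It also looks up flash-attn (the distribution name is assumed to be flash-attn), which may be absent.

```python
from importlib.metadata import PackageNotFoundError, version

def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found

report = installed_versions(
    ["torch", "transformers", "safetensors", "tokenizers", "flash-attn"]
)
for name, ver in report.items():
    print(f"{name:14s} {ver or 'not installed'}")
```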

This port was developed and validated against the following exact stack:

Component     Version
Python        3.14.0
OS            Windows 11
PyTorch       2.10.0+cu130 (CUDA 13.0)
transformers  5.1.0
safetensors   0.7.0
tokenizers    0.22.2
GPU           NVIDIA RTX 5060 Ti (16 GB, Blackwell / sm_120)

The self.post_init() call (see What changed) is what makes loading work on transformers ≥ 5. The post_init() method also exists in older transformers releases, so the call is harmless there and the model remains loadable. The model runs on CPU and on any CUDA GPU; the point of the port is that it does not require a flash-attn build.

Provenance and credit

All weights and the architecture are the work of Sahoo et al. This repository contributes only the flash-attn-free modeling code. Please cite the original work and consult the MDLM GitHub repository and paper.

@inproceedings{sahoo2024simple,
    title={Simple and Effective Masked Diffusion Language Models},
    author={Subham Sekhar Sahoo and Marianne Arriola and Aaron Gokaslan and Edgar Mariano Marroquin and Alexander M Rush and Yair Schiff and Justin T Chiu and Volodymyr Kuleshov},
    booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
    year={2024},
    url={https://openreview.net/forum?id=L4uaAR4ArM}
}

License

Apache-2.0, matching the upstream checkpoint.
