MDLM-OWT — flash-attn-free reimplementation
A drop-in reimplementation of kuleshov-group/mdlm-owt
(Sahoo et al., Simple and Effective Masked Diffusion Language Models, NeurIPS 2024)
that removes the hard dependency on flash-attn. The weights are the original
weights, byte-for-byte; only the modeling code changed, so the model runs anywhere —
Windows, CPU, and recent CUDA GPUs (including Blackwell / sm_120) where a flash-attn
build is impractical.
This port was produced as the reference substrate for a cross-layer-transcoder
interpretability study, where a deterministic, dependency-light, hookable forward pass
matters. It is published in case it is useful to others blocked by the flash-attn
requirement.
What changed (and what did not)
Three changes to the computation, all in the modeling code:
- Attention: torch.nn.functional.scaled_dot_product_attention (non-causal, full bidirectional) replaces flash_attn_varlen_qkvpacked_func. The two are equivalent here: one sequence per batch element, no causal mask, and the default softmax scale 1/sqrt(head_dim).
- Rotary embeddings: applied with an explicit rotate_half instead of flash_attn.layers.rotary. This reproduces flash-attn's non-interleaved convention exactly (the modeling file carries a doctest that proves the equivalence).
- Precision: the transformer stack runs in the model's native dtype (fp32 by default) rather than under a bfloat16 autocast. fp32 is preferable for interpretability and for use as a numerical oracle.
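The first two changes can be checked in isolation. The following is a minimal, self-contained sketch, not the repository's modeling code: it compares scaled_dot_product_attention against attention written out by hand with the 1/sqrt(head_dim) scale, and checks that rotate_half-style rotary matches rotating each (x1, x2) half-pair as a complex number. The helper names, the tensor shapes, and the rotary base of 10000 are assumptions chosen for illustration.

```python
import math

import torch
import torch.nn.functional as F


def explicit_attention(q, k, v):
    """Non-causal attention written out by hand, default 1/sqrt(head_dim) scale."""
    scale = 1.0 / math.sqrt(q.shape[-1])
    weights = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
    return weights @ v


def rotate_half(x):
    """Non-interleaved convention: (x1, x2) -> (-x2, x1) over the two halves of the last dim."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary(x, positions, base=10000.0):
    """Rotary embedding for x of shape [seq, head_dim]; base=10000 is an assumed value."""
    half = x.shape[-1] // 2
    inv_freq = 1.0 / base ** (torch.arange(half, dtype=x.dtype) / half)
    angles = positions[:, None].to(x.dtype) * inv_freq[None, :]  # [seq, half]
    cos = torch.cat((angles.cos(), angles.cos()), dim=-1)        # same angle for both halves
    sin = torch.cat((angles.sin(), angles.sin()), dim=-1)
    return x * cos + rotate_half(x) * sin


if __name__ == "__main__":
    torch.manual_seed(0)
    # [batch, heads, seq, head_dim]: 12 heads of 64 dims (768 / 12).
    q, k, v = (torch.randn(2, 12, 16, 64) for _ in range(3))
    sdpa = F.scaled_dot_product_attention(q, k, v, is_causal=False)
    print("max |sdpa - explicit|:", (sdpa - explicit_attention(q, k, v)).abs().max().item())

    # The same rotation expressed as complex multiplication of x1 + i*x2 by exp(i*theta).
    x = torch.randn(16, 64, dtype=torch.float64)
    pos = torch.arange(16)
    half = 32
    theta = pos[:, None].double() / 10000.0 ** (torch.arange(half, dtype=torch.float64) / half)
    z = torch.complex(x[:, :half], x[:, half:]) * torch.polar(torch.ones_like(theta), theta)
    out = apply_rotary(x, pos)
    print("rotary matches complex form:",
          torch.allclose(out[:, :half], z.real) and torch.allclose(out[:, half:], z.imag))
```

The complex-number comparison is a convenient independent reference: in the non-interleaved layout, dimension j pairs with dimension j + head_dim/2, which is exactly the (real, imaginary) split used above.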
Plus one compatibility fix: the constructor now calls self.post_init(), which the
upstream code omits. transformers ≥ 5 requires it to finalize from_pretrained (without
it, loading raises AttributeError: ... 'all_tied_weights_keys'). MDLM unties its
embeddings, so this registers no tied weights — it only initializes HF bookkeeping.
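The constructor pattern looks like this in a minimal sketch with a hypothetical toy model (ToyConfig and ToyModel are illustrative names, not the MDLM modeling code): build every submodule first, then call self.post_init() as the last step of __init__, so that from_pretrained finds the bookkeeping it expects.

```python
import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel


class ToyConfig(PretrainedConfig):
    model_type = "toy-untied-lm"  # hypothetical model type

    def __init__(self, vocab_size=32, hidden_size=16, **kwargs):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        kwargs.setdefault("tie_word_embeddings", False)  # untied, like MDLM
        super().__init__(**kwargs)


class ToyModel(PreTrainedModel):
    config_class = ToyConfig

    def __init__(self, config):
        super().__init__(config)
        # Separate input and output embeddings: no tied weights to register.
        self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
        self.out = nn.Linear(config.hidden_size, config.vocab_size)
        # Last step of __init__: lets transformers finish its own setup
        # (weight init and related bookkeeping used by from_pretrained).
        self.post_init()

    def forward(self, input_ids):
        return self.out(self.embed(input_ids))
```

A save_pretrained / from_pretrained round trip on this toy class should reproduce the same outputs; the model card reports that on transformers ≥ 5 the load step fails when post_init() is omitted.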
Unchanged: every weight (the original model.safetensors, SHA256
47149e73f7552f39ea9776dbe74d925d25237bcf2ed2e2ec03cdff9d51c82aa4), all parameter
names, the architecture, and the forward signature. The original weights load with no
remapping.
Numerical note: because the upstream model runs its block stack under a bf16 autocast while this port runs fp32, logits differ from the upstream path at roughly bf16 tolerance. This port computes the higher-precision fp32 result of the same math.
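The size of that gap can be illustrated with a toy layer. In this sketch, the hypothetical helper bf16_autocast_gap uses a single random 768-wide Linear on CPU as a stand-in for the block stack, and compares its plain fp32 output with the same layer under torch.autocast with dtype=torch.bfloat16. The relative error it reports should be small and on the order of bf16 precision (8 significand bits, roughly 2 to 3 decimal digits); the exact value depends on the layer and inputs and is not a measurement of the full model.

```python
import torch


def bf16_autocast_gap(d_model=768, tokens=128, seed=0):
    """Relative L2 gap between a Linear layer's fp32 output and its bf16-autocast output."""
    torch.manual_seed(seed)
    layer = torch.nn.Linear(d_model, d_model)  # stand-in for one matmul in the block stack
    x = torch.randn(tokens, d_model)
    with torch.no_grad():
        ref = layer(x)  # fp32 path, as in this port
        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            low = layer(x)  # bf16 autocast path, as in the upstream block stack
    rel = ((low.float() - ref).norm() / ref.norm()).item()
    return low.dtype, rel


dtype, rel = bf16_autocast_gap()
print(dtype, f"relative gap {rel:.2e}")
```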
Usage
trust_remote_code=True is required (the modeling code ships with the repo). The
tokenizer is plain GPT-2; the model's vocabulary is GPT-2's 50257 tokens plus a [MASK]
token at id 50257 (vocab size 50258).
```python
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForMaskedLM.from_pretrained(
    "TheQweaker/mdlm-owt-noflash", trust_remote_code=True).eval()

MASK = 50257
ids = tokenizer("The capital of France is Paris.")["input_ids"]
pos = ids.index(tokenizer(" Paris")["input_ids"][0])
ids[pos] = MASK  # mask the word " Paris"
x = torch.tensor([ids])

with torch.no_grad():
    logits = model(input_ids=x)  # raw [1, L, 50258] tensor (config sets return_dict=False)
logits[..., MASK] = float("-inf")  # never predict [MASK]
print(tokenizer.decode([logits[0, pos].argmax().item()]))  # -> " Paris"
```
timesteps may be omitted: the checkpoint was trained with time_conditioning=False, so
the time input is ignored internally. (Pass return_dict=True if you want a
MaskedLMOutput instead of a raw tensor.)
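When several positions are masked at once, the same post-processing applies to every masked slot. A minimal sketch, where fill_masks is a hypothetical helper and not part of the repo: it does a single greedy fill of all [MASK] positions from a raw logits tensor, still never predicting [MASK] itself.

```python
import torch

MASK = 50257  # [MASK] id; vocabulary size 50258


def fill_masks(input_ids, logits, mask_id=MASK):
    """One-shot greedy fill: replace each [MASK] in input_ids with its argmax token.

    input_ids: [B, L] long tensor; logits: [B, L, V] raw model output.
    """
    logits = logits.clone()               # leave the caller's tensor untouched
    logits[..., mask_id] = float("-inf")  # never predict [MASK]
    preds = logits.argmax(dim=-1)         # [B, L]
    return torch.where(input_ids == mask_id, preds, input_ids)
```

With the model loaded as in the example above, this would be called as fill_masks(x, model(input_ids=x)) inside torch.no_grad(). Filling many masks in one shot ignores dependencies between them; the diffusion sampler unmasks gradually instead.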
Generation
The original checkpoint ships no sampler. sample.py in this repo provides a faithful
masked-diffusion ancestral sampler (SUBS parameterization, linear schedule):
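```python
import torch
from sample import generate  # sample.py from this repo must be importable (e.g. in the working directory)

# model and tokenizer as loaded in the Usage example
gen = torch.Generator().manual_seed(0)
ids = generate(model, seq_len=64, num_steps=128, top_k=50, generator=gen)
print(tokenizer.decode(ids.tolist()))
```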
Note that a 169M-parameter masked-diffusion model under simple ancestral sampling tends to produce locally grammatical but repetitive text; sample quality is sensitive to step count, temperature, and top-k.
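For readers who want to see what the ancestral update does, here is a minimal sketch of one common form of it, written against a generic denoiser callable rather than sample.py, whose internals may differ. Assumptions: CPU tensors, a linear schedule alpha(t) = 1 - t with t running from 1 to 0, and the SUBS rules of Sahoo et al. ([MASK] gets zero probability; tokens already unmasked are carried over unchanged). Going from t to s < t, each still-masked token is revealed with probability (alpha_s - alpha_t) / (1 - alpha_t), drawing its value from the top-k filtered denoiser distribution; at the last step that probability is exactly 1, so no [MASK] remains. The function name ancestral_sample_sketch is hypothetical.

```python
import torch


@torch.no_grad()
def ancestral_sample_sketch(denoiser, seq_len, num_steps, vocab_size=50258,
                            mask_id=50257, top_k=None, generator=None):
    """Masked-diffusion ancestral sampling (SUBS, linear schedule); illustrative, CPU only.

    denoiser: callable mapping a [1, L] long tensor to [1, L, vocab_size] logits.
    """
    x = torch.full((1, seq_len), mask_id, dtype=torch.long)  # start fully masked
    for i in range(num_steps):
        t = 1.0 - i / num_steps          # current time
        s = 1.0 - (i + 1) / num_steps    # next (earlier) time; exactly 0.0 at the last step
        alpha_t, alpha_s = 1.0 - t, 1.0 - s  # assumed linear schedule alpha(t) = 1 - t
        logits = denoiser(x).float()
        logits[..., mask_id] = float("-inf")  # SUBS: zero probability for [MASK]
        if top_k is not None:
            kth = torch.topk(logits, top_k, dim=-1).values[..., -1:]
            logits = logits.masked_fill(logits < kth, float("-inf"))
        probs = torch.softmax(logits, dim=-1).view(-1, vocab_size)
        x0 = torch.multinomial(probs, 1, generator=generator).view(1, seq_len)
        # A still-masked token is revealed with prob (alpha_s - alpha_t) / (1 - alpha_t).
        p_reveal = (alpha_s - alpha_t) / (1.0 - alpha_t)
        reveal = torch.rand(x.shape, generator=generator) < p_reveal
        # SUBS carry-over: already-unmasked tokens never change.
        x = torch.where((x == mask_id) & reveal, x0, x)
    return x[0]
```

With this card's model, the denoiser could be lambda x: model(input_ids=x), since the config returns raw logits. Fewer steps reveal more tokens per step, which is one reason sample quality depends on the step count.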
Model details (unchanged from upstream)
- Architecture: 12-layer DiT-style transformer, d_model 768, 12 heads, adaLN conditioning, untied embeddings, rotary positions. ≈169.6M total parameters (≈92M non-embedding; the input and output embeddings are untied, ≈77M together). Context length 1024.
- Training: OpenWebText, 1M steps (Sahoo et al., 2024).
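The split between embedding and non-embedding parameters follows from the vocabulary size and width alone. A quick check of the arithmetic, assuming each untied embedding is a 50258 × 768 weight matrix and ignoring any output bias:

```python
vocab_size = 50258  # GPT-2's 50257 tokens + [MASK]
d_model = 768

embed_in = vocab_size * d_model    # input embedding table
embed_out = vocab_size * d_model   # untied output projection weights (bias ignored)
embedding_params = embed_in + embed_out

total = 169.6e6  # approximate total from the model card
non_embedding = total - embedding_params
print(f"{embedding_params / 1e6:.1f}M embedding, {non_embedding / 1e6:.1f}M non-embedding")
# -> 77.2M embedding, 92.4M non-embedding
```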
Requirements and environment
No flash-attn. The runtime dependencies are just torch (2.0 or later, which added
scaled_dot_product_attention), transformers, and safetensors; the tokenizer is
plain GPT-2, loaded through tokenizers. trust_remote_code=True is required.
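A quick way to check whether an environment meets these requirements before downloading anything. check_runtime is a hypothetical helper, not shipped with the repo; it only inspects what is importable and never imports flash-attn.

```python
import importlib.util

import torch
import torch.nn.functional as F


def check_runtime():
    """Report whether the pieces this port needs are available (illustrative sketch)."""
    def has(module_name):
        return importlib.util.find_spec(module_name) is not None

    return {
        "torch": torch.__version__,
        "sdpa": hasattr(F, "scaled_dot_product_attention"),
        "transformers": has("transformers"),
        "safetensors": has("safetensors"),
        "tokenizers": has("tokenizers"),
        # Not needed by this port; reported for information only.
        "flash_attn (unused)": has("flash_attn"),
        "cuda": torch.cuda.is_available(),
    }


for key, value in check_runtime().items():
    print(f"{key}: {value}")
```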
This port was developed and validated against the following exact stack:
| Component | Version |
|---|---|
| Python | 3.14.0 |
| OS | Windows 11 |
| PyTorch | 2.10.0+cu130 (CUDA 13.0) |
| transformers | 5.1.0 |
| safetensors | 0.7.0 |
| tokenizers | 0.22.2 |
| GPU | NVIDIA RTX 5060 Ti (16 GB, Blackwell / sm_120) |
The self.post_init() compatibility fix (see What changed) is what makes loading work
on transformers ≥ 5. Because post_init() also exists in older transformers releases,
the model remains loadable there too. The model runs on CPU and on any CUDA GPU; the
point of the port is that it does not require a flash-attn build.
Provenance and credit
All weights and the architecture are the work of Sahoo et al. This repository contributes
only the flash-attn-free modeling code. Please cite the original work and consult the
MDLM GitHub repository and
paper.
```bibtex
@inproceedings{sahoo2024simple,
  title={Simple and Effective Masked Diffusion Language Models},
  author={Subham Sekhar Sahoo and Marianne Arriola and Aaron Gokaslan and Edgar Mariano Marroquin and Alexander M Rush and Yair Schiff and Justin T Chiu and Volodymyr Kuleshov},
  booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
  year={2024},
  url={https://openreview.net/forum?id=L4uaAR4ArM}
}
```
License
Apache-2.0, matching the upstream checkpoint.