mimelens-001-tiny-byte-s2 / modeling_mimelens.py
mjbommar's picture
mimelens-001 cell: tiny/byte/s2
85ca285 verified
Raw
History Blame Contribute Delete
16.5 kB
"""HuggingFace-compatible inference model for MimeLens.

Copied verbatim into each per-cell HF repo (`mjbommar/mimelens-001-*`). Lets
users do:

    from transformers import AutoModel, AutoConfig
    config = AutoConfig.from_pretrained("mjbommar/mimelens-001-medium-bpe-16k-s1",
                                        trust_remote_code=True)
    model = AutoModel.from_pretrained("mjbommar/mimelens-001-medium-bpe-16k-s1",
                                      trust_remote_code=True)
    # → model(input_ids, attention_mask).pooler_output is the mean-pooled
    #   body-token embedding, shape (batch, hidden_size).

The architecture is the small ModernBERT-style encoder from
binary_embedding.models.encoder, vendored here to make each HF repo
self-contained (no pip install binary_embedding required at inference time).
Parameter naming is byte-compatible with the saved best.safetensors files so
that AutoModel.from_pretrained() loads weights without prefix surgery.

Pure torch; no scapy / sklearn / external deps. The mean-pool returned here
is the same projection used throughout the paper; the cls_pool layer is
known to receive no gradient under MLM-only training (see paper §3.4) and is
kept only for state-dict compatibility — do not use it for downstream tasks.
"""
from __future__ import annotations
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPooling, SequenceClassifierOutput
from .configuration_mimelens import MimeLensConfig
# ---------------------------------------------------------------------------
# Building blocks
# ---------------------------------------------------------------------------
class RMSNorm(nn.Module):
    """Bias-free root-mean-square normalisation with a learned per-channel scale.

    The mean of squares is accumulated in fp32; the reciprocal RMS is then
    cast back to the input dtype before being applied, which keeps the
    statistic stable for low-precision (e.g. bf16) inputs.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        # Scale starts at 1 so a freshly built norm is a pure RMS rescale.
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` over its last dimension and apply the learned scale."""
        mean_square = x.float().pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps).to(x.dtype)
        scaled = x * inv_rms
        return scaled * self.weight
def _build_rope_cache(seq_len: int, head_dim: int, base: float,
device: torch.device, dtype: torch.dtype):
positions = torch.arange(seq_len, device=device, dtype=torch.float32)
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device,
dtype=torch.float32) / head_dim))
freqs = torch.einsum("p,d->pd", positions, inv_freq)
cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).to(dtype)
sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).to(dtype)
return cos, sin
def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
"""Apply rotary position embedding. x: (..., seq, head_dim)."""
d = x.shape[-1]
x1, x2 = x[..., : d // 2], x[..., d // 2 :]
rotated = torch.cat([-x2, x1], dim=-1)
return x * cos + rotated * sin
class Attention(nn.Module):
    """Multi-head self-attention with a fused bias-free QKV projection and RoPE."""

    def __init__(self, hidden_size: int, num_heads: int, head_dim: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        inner_dim = num_heads * head_dim
        self.qkv = nn.Linear(hidden_size, 3 * inner_dim, bias=False)
        self.out = nn.Linear(inner_dim, hidden_size, bias=False)

    def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                attn_mask: torch.Tensor) -> torch.Tensor:
        """Attend over ``x`` of shape ``(batch, seq, hidden)``.

        ``cos`` / ``sin`` are the rotary tables applied to queries and keys.
        ``attn_mask`` is handed straight to
        ``F.scaled_dot_product_attention``; the callers in this file pass an
        additive ``(batch, 1, 1, seq)`` mask that is 0 at real tokens and the
        dtype's most negative value at padding.
        """
        batch, seq, _ = x.shape
        fused = self.qkv(x).reshape(batch, seq, 3, self.num_heads, self.head_dim)
        # (batch, seq, 3, heads, dim) → three (batch, heads, seq, dim) views.
        queries, keys, values = fused.permute(2, 0, 3, 1, 4).unbind(dim=0)

        queries = _apply_rope(queries, cos, sin)
        keys = _apply_rope(keys, cos, sin)

        attended = F.scaled_dot_product_attention(queries, keys, values,
                                                  attn_mask=attn_mask)
        # Back to (batch, seq, heads * dim) before the output projection.
        merged = attended.transpose(1, 2).contiguous().reshape(batch, seq, -1)
        return self.out(merged)
class FFN(nn.Module):
    """Gated-GELU feed-forward block: ``w_down(gelu(w_gate(x)) * w_up(x))``.

    All three projections are bias-free.
    """

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.w_gate = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.w_up = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.w_down = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward transform along the last dimension."""
        gate = F.gelu(self.w_gate(x))
        value = self.w_up(x)
        return self.w_down(gate * value)
class Layer(nn.Module):
    """One pre-norm encoder block: attention and FFN, each with a residual."""

    def __init__(self, config: MimeLensConfig):
        super().__init__()
        hidden = config.hidden_size
        eps = config.rms_norm_eps
        self.norm1 = RMSNorm(hidden, eps=eps)
        self.attn = Attention(hidden, config.num_attention_heads, config.head_dim)
        self.norm2 = RMSNorm(hidden, eps=eps)
        self.ffn = FFN(hidden, config.intermediate_size)

    def forward(self, x, cos, sin, attn_mask):
        """Return ``x`` after the attention and feed-forward residual updates."""
        attn_out = self.attn(self.norm1(x), cos, sin, attn_mask)
        x = x + attn_out
        ffn_out = self.ffn(self.norm2(x))
        return x + ffn_out
# ---------------------------------------------------------------------------
# Top-level model
# ---------------------------------------------------------------------------
class MimeLensModel(PreTrainedModel):
    """MimeLens encoder: token ids → mean-pooled embedding.

    Use `forward(input_ids, attention_mask)` and consume the `pooler_output`
    field (mean over body tokens, skipping the CLS / SEP / PAD positions).
    The final-norm output is also returned as `last_hidden_state` if you want
    to do your own pooling.

    Body positions are taken to be indices ``1 .. lens - 2`` of each row,
    where ``lens`` is the row sum of ``attention_mask``; inputs must therefore
    be right-padded and laid out as ``[CLS, body..., SEP, PAD...]``.
    """

    config_class = MimeLensConfig
    base_model_prefix = "mimelens"

    def __init__(self, config: MimeLensConfig):
        super().__init__(config)
        self.config = config
        # Parameter naming MUST match best.safetensors: flat (no `encoder.` prefix).
        self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([Layer(config) for _ in range(config.num_hidden_layers)])
        self.final_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # cls_pool is kept ONLY for safetensors compatibility — receives no
        # gradient under MLM-only training; use mean-pool instead.
        self.cls_pool = nn.Linear(config.hidden_size, config.cls_pool_dim, bias=False)
        # Lazily-built RoPE cache. Only the most recent
        # (device, dtype, seq_len, inference-mode) combination is kept.
        self._rope_cache: Optional[tuple[torch.Tensor, torch.Tensor]] = None
        self._rope_cache_meta: Optional[tuple[torch.device, torch.dtype, int, bool]] = None
        # post_init() triggers the standard HF initialisation hook
        # (`_init_weights` below). In normal use the weights are then replaced
        # by the pretrained checkpoint loaded through from_pretrained().
        self.post_init()

    def _init_weights(self, module):
        """Initialise one submodule with the usual HF normal scheme.

        Linear and Embedding weights are drawn from
        ``N(0, config.initializer_range)``; Linear biases and the embedding
        padding row (when present) are zeroed. Other module types are left
        at their constructor defaults. When loading a pretrained checkpoint
        these values should only matter for parameters the checkpoint does
        not supply.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _get_rope(self, seq_len: int, device: torch.device, dtype: torch.dtype):
        """Return cached ``(cos, sin)`` RoPE tables, rebuilding when stale.

        The cache key includes whether inference mode is active. Tables built
        under ``torch.inference_mode()`` are inference tensors, and autograd
        refuses to save those for backward, so a later grad-enabled forward
        must get freshly built tables instead of the cached ones.
        """
        meta = (device, dtype, seq_len, torch.is_inference_mode_enabled())
        if self._rope_cache_meta != meta:
            self._rope_cache = _build_rope_cache(seq_len, self.config.head_dim,
                                                 self.config.rope_theta,
                                                 device=device, dtype=dtype)
            self._rope_cache_meta = meta
        return self._rope_cache

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Encode a batch of token ids.

        Args:
            input_ids: ``(batch, seq)`` token ids.
            attention_mask: ``(batch, seq)`` mask of {0, 1}, 1 at real tokens.
                Defaults to all ones (no padding).
            output_hidden_states: when true, also return a tuple holding the
                embedding output followed by the output of every layer
                (taken before the final norm).
            return_dict: when false, return a plain tuple instead of a
                ``BaseModelOutputWithPooling``.

        Returns:
            ``BaseModelOutputWithPooling`` with ``last_hidden_state``
            ``(batch, seq, hidden)``, ``pooler_output`` ``(batch, hidden)``
            and, if requested, ``hidden_states``. With ``return_dict=False``
            the tuple is ``(last_hidden_state, pooler_output)``, with the
            hidden-states tuple appended when requested.
        """
        B, S = input_ids.shape
        x = self.embed(input_ids)  # (B, S, H)

        if attention_mask is None:
            attention_mask = torch.ones(B, S, device=input_ids.device, dtype=torch.long)

        # Additive SDPA mask: real (=1) → 0, pad (=0) → most negative finite
        # value of the working dtype. Shape (B, 1, 1, S) broadcasts over
        # heads and query positions.
        inverted = 1.0 - attention_mask.to(x.dtype)
        attn_mask = inverted.masked_fill(inverted.bool(), torch.finfo(x.dtype).min)
        attn_mask = attn_mask.view(B, 1, 1, S)

        cos, sin = self._get_rope(S, device=x.device, dtype=x.dtype)

        all_hidden_states = (x,) if output_hidden_states else None
        for layer in self.layers:
            x = layer(x, cos, sin, attn_mask)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (x,)

        x = self.final_norm(x)
        last_hidden_state = x

        # Mean-pool over BODY tokens (skip CLS @ pos 0, SEP @ pos lens-1, PAD).
        lens = attention_mask.sum(dim=1, keepdim=True)                 # (B, 1)
        positions = torch.arange(S, device=x.device).unsqueeze(0)      # (1, S)
        body_mask = (positions >= 1) & (positions < (lens - 1))        # (B, S) bool
        body_mask_f = body_mask.to(x.dtype).unsqueeze(-1)              # (B, S, 1)
        # clamp avoids 0/0 for rows with no body tokens (pooled is then 0).
        pooled = (x * body_mask_f).sum(dim=1) / body_mask_f.sum(dim=1).clamp(min=1)

        if not return_dict:
            outputs = (last_hidden_state, pooled)
            if output_hidden_states:
                outputs = outputs + (all_hidden_states,)
            return outputs
        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled,
            hidden_states=all_hidden_states,
        )

    # ---------------------------------------------------------------------
    # Helper utilities for users (not part of the standard HF surface)
    # ---------------------------------------------------------------------
    def encode_bytes(
        self,
        byte_window: bytes,
        tokenizer=None,
        seq_len: Optional[int] = None,
    ) -> torch.Tensor:
        """Convenience: encode one raw byte window into a mean-pooled embedding.

        For byte cells (`config.mimelens_vocab_pipeline == 'byte'`), tokenizer
        is ignored. For BPE cells, pass a BinaryTokenizer (from
        `mjbommar/binary-tokenizer-001-*`) or any tokenizer with `.encode(bytes)
        -> list[int]`.

        The window is truncated to ``seq_len - 2`` tokens, wrapped as
        ``[CLS, ..., SEP]`` and right-padded to ``seq_len`` (default:
        ``config.max_position_embeddings``).

        Returns a (1, hidden_size) tensor on the same device as the model,
        computed under ``torch.inference_mode()``.

        Raises:
            ValueError: if ``seq_len`` is smaller than 2, or if a BPE cell is
                used without a tokenizer.
        """
        seq_len = seq_len or self.config.max_position_embeddings
        if seq_len < 2:
            # A negative body budget would slice from the wrong end and
            # overflow the requested window.
            raise ValueError(
                f"seq_len must be at least 2 to hold CLS and SEP, got {seq_len}"
            )
        body = seq_len - 2
        cls_id = self.config.cls_token_id
        sep_id = self.config.sep_token_id
        pad_id = self.config.pad_token_id

        if self.config.mimelens_vocab_pipeline == "byte":
            ids = [b + self.config.byte_offset for b in byte_window[:body]]
        else:
            if tokenizer is None:
                raise ValueError(
                    f"BPE cell {self.config.mimelens_cell_id} requires a tokenizer; "
                    f"load with e.g. `_native.BinaryTokenizer.from_file(...)` from "
                    f"{self.config.mimelens_tokenizer_hub_id}"
                )
            ids = list(tokenizer.encode(byte_window))[:body]

        out_ids = [cls_id, *ids, sep_id]
        n_pad = seq_len - len(out_ids)
        attn = [1] * len(out_ids) + [0] * n_pad
        out_ids = out_ids + [pad_id] * n_pad

        device = next(self.parameters()).device
        input_ids = torch.tensor([out_ids], dtype=torch.long, device=device)
        attention_mask = torch.tensor([attn], dtype=torch.long, device=device)
        with torch.inference_mode():
            return self(input_ids, attention_mask=attention_mask).pooler_output
class MimeLensForSequenceClassification(PreTrainedModel):
    """MimeLens encoder + a 125-class libmagic-MIME classifier head.

    Lets users do, in one line:

        from transformers import pipeline
        clf = pipeline("text-classification",
                       model="mjbommar/mimelens-001-medium-bpe-16k-s1",
                       trust_remote_code=True)
        clf(open("some.bin", "rb").read(4096).decode("latin-1"))
        # → [{"label": "text/x-python", "score": 0.91}, ...]

    The classifier head is the same logistic-regression probe the paper
    reports on the magic-files corpus, re-fit on the full 4,096-file
    labelled set and baked into `model.safetensors` as `classifier.weight`
    and `classifier.bias`. Labels live in `config.id2label` / `config.label2id`.

    For embedding-only use, load via `AutoModel.from_pretrained(...)` instead,
    which returns mean-pooled embeddings and ignores the classifier head.

    Pooling takes body positions to be indices ``1 .. lens - 2`` of each row
    (``lens`` = row sum of ``attention_mask``), so inputs must be right-padded
    and laid out as ``[CLS, body..., SEP, PAD...]``.
    """

    config_class = MimeLensConfig
    base_model_prefix = "mimelens"

    def __init__(self, config: MimeLensConfig):
        super().__init__(config)
        self.config = config
        self.num_labels = getattr(config, "num_labels", 125)
        # The encoder body, identical to MimeLensModel — same parameter names so
        # the encoder weights load from the same safetensors keys.
        self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([Layer(config) for _ in range(config.num_hidden_layers)])
        self.final_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Kept for state-dict compatibility; not used in forward().
        self.cls_pool = nn.Linear(config.hidden_size, config.cls_pool_dim, bias=False)
        # The classifier head over the mean-pooled body embedding.
        self.classifier = nn.Linear(config.hidden_size, self.num_labels)
        # Lazily-built RoPE cache. Only the most recent
        # (device, dtype, seq_len, inference-mode) combination is kept.
        self._rope_cache: Optional[tuple[torch.Tensor, torch.Tensor]] = None
        self._rope_cache_meta: Optional[tuple[torch.device, torch.dtype, int, bool]] = None
        self.post_init()

    def _init_weights(self, module):
        """Initialise one submodule with the usual HF normal scheme.

        Linear and Embedding weights are drawn from
        ``N(0, config.initializer_range)``; Linear biases and the embedding
        padding row (when present) are zeroed. Other module types are left
        at their constructor defaults.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _get_rope(self, seq_len: int, device: torch.device, dtype: torch.dtype):
        """Return cached ``(cos, sin)`` RoPE tables, rebuilding when stale.

        The cache key includes whether inference mode is active. Tables built
        under ``torch.inference_mode()`` are inference tensors, and autograd
        refuses to save those for backward, so a later grad-enabled forward
        (e.g. fine-tuning after an inference call) must get fresh tables.
        """
        meta = (device, dtype, seq_len, torch.is_inference_mode_enabled())
        if self._rope_cache_meta != meta:
            self._rope_cache = _build_rope_cache(seq_len, self.config.head_dim,
                                                 self.config.rope_theta,
                                                 device=device, dtype=dtype)
            self._rope_cache_meta = meta
        return self._rope_cache

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: bool = True,
    ):
        """Classify a batch of token-id sequences.

        Args:
            input_ids: ``(batch, seq)`` token ids.
            attention_mask: ``(batch, seq)`` mask of {0, 1}, 1 at real tokens.
                Defaults to all ones (no padding).
            labels: optional ``(batch,)`` class indices; when given, the
                cross-entropy loss against the logits is returned too.
            return_dict: when false, return ``(loss, logits)`` (or just
                ``(logits,)`` without labels) instead of a
                ``SequenceClassifierOutput``.
        """
        B, S = input_ids.shape
        x = self.embed(input_ids)

        if attention_mask is None:
            attention_mask = torch.ones(B, S, device=input_ids.device, dtype=torch.long)

        # Additive SDPA mask: 0 at real tokens, most negative finite value of
        # the working dtype at padding; (B, 1, 1, S) broadcasts over heads
        # and query positions.
        inverted = 1.0 - attention_mask.to(x.dtype)
        attn_mask = inverted.masked_fill(inverted.bool(), torch.finfo(x.dtype).min)
        attn_mask = attn_mask.view(B, 1, 1, S)

        cos, sin = self._get_rope(S, device=x.device, dtype=x.dtype)
        for layer in self.layers:
            x = layer(x, cos, sin, attn_mask)
        x = self.final_norm(x)

        # Mean-pool over body tokens (skip CLS @ pos 0, SEP @ pos lens-1, PAD).
        lens = attention_mask.sum(dim=1, keepdim=True)
        positions = torch.arange(S, device=x.device).unsqueeze(0)
        body_mask = (positions >= 1) & (positions < (lens - 1))
        body_mask_f = body_mask.to(x.dtype).unsqueeze(-1)
        # clamp avoids 0/0 for rows with no body tokens (pooled is then 0).
        pooled = (x * body_mask_f).sum(dim=1) / body_mask_f.sum(dim=1).clamp(min=1)

        # Cast pooled to classifier dtype (bf16 encoder + fp32 classifier is common).
        logits = self.classifier(pooled.to(self.classifier.weight.dtype))

        loss = None
        if labels is not None:
            # Labels may arrive on a different device than the model output.
            loss = F.cross_entropy(logits, labels.to(logits.device))

        if not return_dict:
            return (loss, logits) if loss is not None else (logits,)
        return SequenceClassifierOutput(loss=loss, logits=logits)