"""HuggingFace-compatible inference model for MimeLens. Copied verbatim into each per-cell HF repo (`mjbommar/mimelens-001-*`). Lets users do: from transformers import AutoModel, AutoConfig config = AutoConfig.from_pretrained("mjbommar/mimelens-001-medium-bpe-16k-s1", trust_remote_code=True) model = AutoModel.from_pretrained("mjbommar/mimelens-001-medium-bpe-16k-s1", trust_remote_code=True) # → forward(input_ids, attention_mask) returns the mean-pooled body-token # embedding, shape (batch, hidden_size). The architecture is the small ModernBERT-style encoder from binary_embedding.models.encoder, vendored here to make each HF repo self-contained (no pip install binary_embedding required at inference time). Parameter naming is byte-compatible with the saved best.safetensors files so that AutoModel.from_pretrained() loads weights without prefix surgery. Pure torch; no scapy / sklearn / external deps. The mean-pool returned here is the same projection used throughout the paper; the cls_pool layer is known to receive no gradient under MLM-only training (see paper §3.4) and is kept only for state-dict compatibility — do not use it for downstream tasks. """ from __future__ import annotations from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F from transformers import PreTrainedModel from transformers.modeling_outputs import BaseModelOutputWithPooling, SequenceClassifierOutput from .configuration_mimelens import MimeLensConfig # --------------------------------------------------------------------------- # Building blocks # --------------------------------------------------------------------------- class RMSNorm(nn.Module): """RMSNorm without bias. bf16-safe (norm computed in fp32).""" def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.eps = eps def forward(self, x: torch.Tensor) -> torch.Tensor: variance = x.float().pow(2).mean(-1, keepdim=True) normed = x * torch.rsqrt(variance + self.eps).to(x.dtype) return normed * self.weight def _build_rope_cache(seq_len: int, head_dim: int, base: float, device: torch.device, dtype: torch.dtype): positions = torch.arange(seq_len, device=device, dtype=torch.float32) inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device, dtype=torch.float32) / head_dim)) freqs = torch.einsum("p,d->pd", positions, inv_freq) cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).to(dtype) sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).to(dtype) return cos, sin def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: """Apply rotary position embedding. x: (..., seq, head_dim).""" d = x.shape[-1] x1, x2 = x[..., : d // 2], x[..., d // 2 :] rotated = torch.cat([-x2, x1], dim=-1) return x * cos + rotated * sin class Attention(nn.Module): def __init__(self, hidden_size: int, num_heads: int, head_dim: int): super().__init__() self.num_heads = num_heads self.head_dim = head_dim self.qkv = nn.Linear(hidden_size, 3 * num_heads * head_dim, bias=False) self.out = nn.Linear(num_heads * head_dim, hidden_size, bias=False) def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor: B, S, _ = x.shape qkv = self.qkv(x).reshape(B, S, 3, self.num_heads, self.head_dim) q, k, v = qkv.unbind(dim=2) # each (B, S, H, D) q = _apply_rope(q.transpose(1, 2), cos, sin) # (B, H, S, D) k = _apply_rope(k.transpose(1, 2), cos, sin) v = v.transpose(1, 2) # attn_mask: (B, 1, 1, S) with -inf at pad positions, 0 at real positions out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) out = out.transpose(1, 2).contiguous().reshape(B, S, -1) return self.out(out) class FFN(nn.Module): """GeGLU FFN: gelu(w_gate(x)) * w_up(x) → w_down.""" def __init__(self, hidden_size: int, intermediate_size: int): super().__init__() self.w_gate = nn.Linear(hidden_size, intermediate_size, bias=False) self.w_up = nn.Linear(hidden_size, intermediate_size, bias=False) self.w_down = nn.Linear(intermediate_size, hidden_size, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w_down(F.gelu(self.w_gate(x)) * self.w_up(x)) class Layer(nn.Module): def __init__(self, config: MimeLensConfig): super().__init__() self.norm1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.attn = Attention(config.hidden_size, config.num_attention_heads, config.head_dim) self.norm2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.ffn = FFN(config.hidden_size, config.intermediate_size) def forward(self, x, cos, sin, attn_mask): x = x + self.attn(self.norm1(x), cos, sin, attn_mask) x = x + self.ffn(self.norm2(x)) return x # --------------------------------------------------------------------------- # Top-level model # --------------------------------------------------------------------------- class MimeLensModel(PreTrainedModel): """MimeLens encoder: bytes → mean-pooled embedding. Use `forward(input_ids, attention_mask)` and consume the `pooler_output` field (mean over body tokens, skipping the CLS / SEP / PAD positions). Last-hidden-state is also returned as `last_hidden_state` if you want to do your own pooling. """ config_class = MimeLensConfig base_model_prefix = "mimelens" def __init__(self, config: MimeLensConfig): super().__init__(config) self.config = config # Parameter naming MUST match best.safetensors: flat (no `encoder.` prefix). self.embed = nn.Embedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([Layer(config) for _ in range(config.num_hidden_layers)]) self.final_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) # cls_pool is kept ONLY for safetensors compatibility — receives no # gradient under MLM-only training; use mean-pool instead. self.cls_pool = nn.Linear(config.hidden_size, config.cls_pool_dim, bias=False) # Lazily-built RoPE cache (one per device/dtype combination). self._rope_cache: Optional[tuple[torch.Tensor, torch.Tensor]] = None self._rope_cache_meta: Optional[tuple[torch.device, torch.dtype, int]] = None # No weight init here — we always load_state_dict from a pretrained # checkpoint via from_pretrained(). HF complains if we don't provide # an init_weights; provide the no-op version. self.post_init() def _init_weights(self, module): """No-op: we always load from a pretrained checkpoint.""" if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() def _get_rope(self, seq_len: int, device: torch.device, dtype: torch.dtype): meta = (device, dtype, seq_len) if self._rope_cache_meta != meta: self._rope_cache = _build_rope_cache(seq_len, self.config.head_dim, self.config.rope_theta, device=device, dtype=dtype) self._rope_cache_meta = meta return self._rope_cache def forward( self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, output_hidden_states: bool = False, return_dict: bool = True, ): B, S = input_ids.shape x = self.embed(input_ids) # (B, S, H) # Build SDPA attention mask: (B, 1, 1, S) additive, -inf at pad. if attention_mask is None: attention_mask = torch.ones(B, S, device=input_ids.device, dtype=torch.long) # Convert to additive: real (=1) → 0, pad (=0) → -inf # SDPA expects mask broadcastable to (B, H, S, S) attn_mask = attention_mask.to(x.dtype) attn_mask = (1.0 - attn_mask).masked_fill((1.0 - attn_mask).bool(), torch.finfo(x.dtype).min) # shape: (B, 1, 1, S) — broadcasts over heads and queries attn_mask = attn_mask.view(B, 1, 1, S) cos, sin = self._get_rope(S, device=x.device, dtype=x.dtype) for layer in self.layers: x = layer(x, cos, sin, attn_mask) x = self.final_norm(x) last_hidden_state = x # Mean-pool over BODY tokens (skip CLS @ pos 0, SEP @ pos lens-1, PAD). # attention_mask is (B, S) of {0,1}. lens = attention_mask.sum(dim=1, keepdim=True) # (B, 1) positions = torch.arange(S, device=x.device).unsqueeze(0) # (1, S) body_mask = (positions >= 1) & (positions < (lens - 1)) # (B, S) bool body_mask_f = body_mask.to(x.dtype).unsqueeze(-1) # (B, S, 1) pooled = (x * body_mask_f).sum(dim=1) / body_mask_f.sum(dim=1).clamp(min=1) # shape: (B, H) if not return_dict: return (last_hidden_state, pooled) return BaseModelOutputWithPooling( last_hidden_state=last_hidden_state, pooler_output=pooled, ) # --------------------------------------------------------------------- # Helper utilities for users (not part of the standard HF surface) # --------------------------------------------------------------------- def encode_bytes( self, byte_window: bytes, tokenizer=None, seq_len: Optional[int] = None, ) -> torch.Tensor: """Convenience: encode one raw byte window into a mean-pooled embedding. For byte cells (`config.mimelens_vocab_pipeline == 'byte'`), tokenizer is ignored. For BPE cells, pass a BinaryTokenizer (from `mjbommar/binary-tokenizer-001-*`) or any tokenizer with `.encode(bytes) -> list[int]`. Returns a (1, hidden_size) tensor on the same device as the model. """ seq_len = seq_len or self.config.max_position_embeddings body = seq_len - 2 cls_id = self.config.cls_token_id sep_id = self.config.sep_token_id pad_id = self.config.pad_token_id if self.config.mimelens_vocab_pipeline == "byte": ids = [b + self.config.byte_offset for b in byte_window[:body]] else: if tokenizer is None: raise ValueError( f"BPE cell {self.config.mimelens_cell_id} requires a tokenizer; " f"load with e.g. `_native.BinaryTokenizer.from_file(...)` from " f"{self.config.mimelens_tokenizer_hub_id}" ) ids = list(tokenizer.encode(byte_window))[:body] out_ids = [cls_id, *ids, sep_id] attn = [1] * len(out_ids) + [0] * (seq_len - len(out_ids)) out_ids = out_ids + [pad_id] * (seq_len - len(out_ids)) device = next(self.parameters()).device input_ids = torch.tensor([out_ids], dtype=torch.long, device=device) attention_mask = torch.tensor([attn], dtype=torch.long, device=device) with torch.inference_mode(): return self(input_ids, attention_mask=attention_mask).pooler_output class MimeLensForSequenceClassification(PreTrainedModel): """MimeLens encoder + a 125-class libmagic-MIME classifier head. Lets users do, in one line: from transformers import pipeline clf = pipeline("text-classification", model="mjbommar/mimelens-001-medium-bpe-16k-s1", trust_remote_code=True) clf(open("some.bin", "rb").read(4096).decode("latin-1")) # → [{"label": "text/x-python", "score": 0.91}, ...] The classifier head is the same logistic-regression probe the paper reports on the magic-files corpus, re-fit on the full 4,096-file labelled set and baked into `model.safetensors` as `classifier.weight` and `classifier.bias`. Labels live in `config.id2label` / `config.label2id`. For embedding-only use, load via `AutoModel.from_pretrained(...)` instead, which returns mean-pooled embeddings and ignores the classifier head. """ config_class = MimeLensConfig base_model_prefix = "mimelens" def __init__(self, config: MimeLensConfig): super().__init__(config) self.config = config self.num_labels = getattr(config, "num_labels", 125) # The encoder body, identical to MimeLensModel — same parameter names so # the encoder weights load from the same safetensors keys. self.embed = nn.Embedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([Layer(config) for _ in range(config.num_hidden_layers)]) self.final_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.cls_pool = nn.Linear(config.hidden_size, config.cls_pool_dim, bias=False) # The 125-way classifier head. self.classifier = nn.Linear(config.hidden_size, self.num_labels) self._rope_cache: Optional[tuple[torch.Tensor, torch.Tensor]] = None self._rope_cache_meta: Optional[tuple[torch.device, torch.dtype, int]] = None self.post_init() def _init_weights(self, module): if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() def _get_rope(self, seq_len: int, device: torch.device, dtype: torch.dtype): meta = (device, dtype, seq_len) if self._rope_cache_meta != meta: self._rope_cache = _build_rope_cache(seq_len, self.config.head_dim, self.config.rope_theta, device=device, dtype=dtype) self._rope_cache_meta = meta return self._rope_cache def forward( self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, labels: Optional[torch.LongTensor] = None, return_dict: bool = True, ): B, S = input_ids.shape x = self.embed(input_ids) if attention_mask is None: attention_mask = torch.ones(B, S, device=input_ids.device, dtype=torch.long) attn_mask = attention_mask.to(x.dtype) attn_mask = (1.0 - attn_mask).masked_fill((1.0 - attn_mask).bool(), torch.finfo(x.dtype).min) attn_mask = attn_mask.view(B, 1, 1, S) cos, sin = self._get_rope(S, device=x.device, dtype=x.dtype) for layer in self.layers: x = layer(x, cos, sin, attn_mask) x = self.final_norm(x) lens = attention_mask.sum(dim=1, keepdim=True) positions = torch.arange(S, device=x.device).unsqueeze(0) body_mask = (positions >= 1) & (positions < (lens - 1)) body_mask_f = body_mask.to(x.dtype).unsqueeze(-1) pooled = (x * body_mask_f).sum(dim=1) / body_mask_f.sum(dim=1).clamp(min=1) # Cast pooled to classifier dtype (bf16 encoder + fp32 classifier is common). logits = self.classifier(pooled.to(self.classifier.weight.dtype)) loss = None if labels is not None: loss = F.cross_entropy(logits, labels) if not return_dict: return (loss, logits) if loss is not None else (logits,) return SequenceClassifierOutput(loss=loss, logits=logits)