""" Custom HuggingFace model for Open LM checkpoints. Open LM uses LayerNorm (not RMSNorm) and QK norm, which standard LlamaForCausalLM does not support. This module provides: - OpenLMConfig: LlamaConfig subclass with qk_norm flag - OpenLMForCausalLM: LlamaForCausalLM subclass with LayerNorm + QK norm Usage: model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True) """ from typing import Callable, Optional import torch import torch.nn as nn from transformers import LlamaConfig, LlamaForCausalLM from transformers.models.llama.modeling_llama import ( ALL_ATTENTION_FUNCTIONS, LlamaAttention, LlamaRMSNorm, apply_rotary_pos_emb, eager_attention_forward, ) try: from typing import Unpack from transformers.utils.generic import TransformersKwargs except ImportError: pass from transformers.cache_utils import Cache class OpenLMConfig(LlamaConfig): model_type = "open_lm" def __init__(self, qk_norm: bool = True, **kwargs): super().__init__(**kwargs) self.qk_norm = qk_norm class OpenLMAttention(LlamaAttention): """LlamaAttention with QK norm applied before reshape (matching Open LM).""" def __init__(self, config: OpenLMConfig, layer_idx: int): super().__init__(config, layer_idx) if getattr(config, "qk_norm", False): self.q_norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps, bias=False) self.k_norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps, bias=False) else: self.q_norm = nn.Identity() self.k_norm = nn.Identity() def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) # QK norm applied to flat projected vectors BEFORE reshape (matches Open LM) query_states = self.q_norm(self.q_proj(hidden_states)).view(hidden_shape).transpose(1, 2) key_states = self.k_norm(self.k_proj(hidden_states)).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_values.update( key_states, value_states, self.layer_idx, cache_kwargs ) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, query_states, key_states, value_states, attention_mask, dropout=0.0 if not self.training else self.attention_dropout, scaling=self.scaling, **kwargs, ) attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights class OpenLMForCausalLM(LlamaForCausalLM): """LlamaForCausalLM with LayerNorm (instead of RMSNorm) and QK norm support.""" config_class = OpenLMConfig def __init__(self, config: OpenLMConfig): super().__init__(config) # Replace all LlamaRMSNorm with nn.LayerNorm(bias=False) eps = config.rms_norm_eps hidden_size = config.hidden_size self.model.norm = nn.LayerNorm(hidden_size, eps=eps, bias=False) for layer in self.model.layers: layer.input_layernorm = nn.LayerNorm(hidden_size, eps=eps, bias=False) layer.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=eps, bias=False) # Replace attention module with QK norm version layer.self_attn = OpenLMAttention(config, layer.self_attn.layer_idx) # Re-run post_init to tie weights etc. self.post_init()