"""Ramo causal LM.

Contents: RoPE (default / dynamic-NTK / YaRN), GQA attention with optional
QK-norm, dense or MoE MLPs, and a Hugging Face ``PreTrainedModel`` wrapper.
"""
import math
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, List

import torch
from torch import nn
import torch.nn.functional as F

from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_ramo import RamoConfig


# ---------------------------
# Norm + RoPE (ported from your training code)
# ---------------------------


class RMSNorm(nn.Module):
    """Root-mean-square norm with a learned per-channel scale (no bias, no centering)."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (..., hidden).
        # Compute the statistics in at least float32, so x.pow(2) cannot
        # overflow or lose precision in fp16/bf16. Cast back to the input
        # dtype before applying the learned scale.
        input_dtype = x.dtype
        compute_dtype = torch.promote_types(input_dtype, torch.float32)
        xc = x.to(compute_dtype)
        variance = xc.pow(2).mean(-1, keepdim=True)
        xc = xc * torch.rsqrt(variance + self.eps)
        return xc.to(input_dtype) * self.weight


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Split the last dim into halves (x1, x2) and return cat(-x2, x1)."""
    half = x.shape[-1] // 2
    x1 = x[..., :half]
    x2 = x[..., half:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Apply rotary position embeddings to ``q`` and ``k``.

    Args:
        q, k: query/key tensors whose last dim is the head dim.
        cos, sin: tables from ``RotaryEmbedding``, shape (bsz, seq, rotary_dim).
        unsqueeze_dim: where to insert the broadcast axis for heads
            (1 for q/k laid out as (bsz, heads, seq, head_dim)).

    Returns:
        (q_embed, k_embed) with the same shapes as q and k.

    If rotary_dim is smaller than the head dim (partial_rotary_factor < 1),
    only the leading rotary_dim channels are rotated. The remaining channels
    pass through unchanged.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    rotary_dim = cos.shape[-1]
    if rotary_dim < q.shape[-1]:
        q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
        k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
        q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
        k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
        return torch.cat((q_embed, q_pass), dim=-1), torch.cat((k_embed, k_pass), dim=-1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def _rope_cfg(cfg: RamoConfig) -> Dict[str, Any]:
    """Return ``cfg.rope_config``, treating a missing or None value as ``{}``."""
    return getattr(cfg, "rope_config", None) or {}


def _rotary_dim(cfg: RamoConfig) -> int:
    """Number of rotated channels per head: head_dim * partial_rotary_factor."""
    head_dim = cfg.hidden_size // cfg.num_attention_heads
    return int(head_dim * float(_rope_cfg(cfg).get("partial_rotary_factor", 1.0)))


def _get_mscale(scale, mscale=None):
    """YaRN attention-temperature factor ``0.1 * mscale * ln(scale) + 1``.

    Returns 1.0 when ``mscale`` is None (no extra scaling is configured) or
    when ``scale <= 1`` (no context extension; ln(scale) would be <= 0).
    """
    if mscale is None or scale <= 1:
        return 1.0
    return float(0.1 * mscale * math.log(scale) + 1.0)


def _compute_default_rope_parameters(cfg: RamoConfig, device, seq_len: Optional[int] = None):
    """Standard RoPE inverse frequencies.

    Returns:
        (inv_freq of length rotary_dim // 2, attention_scaling = 1.0).
        ``seq_len`` is accepted for interface parity and is unused.
    """
    base = float(_rope_cfg(cfg).get("rope_theta", 10000.0))
    dim = _rotary_dim(cfg)
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
    return inv_freq, 1.0


def _compute_dynamic_ntk_parameters(cfg: RamoConfig, device, seq_len: Optional[int] = None):
    """Dynamic NTK-aware RoPE.

    The base is enlarged only once the sequence exceeds max_position_embeddings:

        base' = base * (factor * L / max_pos - (factor - 1)) ** (dim / (dim - 2))

    with L = max(seq_len, max_pos). At L == max_pos (including init, where
    seq_len is None) this reduces to the unscaled base, so short sequences keep
    the original frequencies.

    Returns:
        (inv_freq, attention_scaling = 1.0).
    """
    rope = _rope_cfg(cfg)
    base = float(rope.get("rope_theta", 10000.0))
    dim = _rotary_dim(cfg)
    max_pos = int(cfg.max_position_embeddings)
    factor = float(rope.get("factor", 1.0))
    seq_len = max_pos if seq_len is None else max(int(seq_len), max_pos)
    base = base * ((factor * seq_len / max_pos) - (factor - 1)) ** (dim / (dim - 2))
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
    return inv_freq, 1.0


def _compute_yarn_parameters(cfg: RamoConfig, device, seq_len: Optional[int] = None):
    """YaRN RoPE: per-dimension blend of interpolated and extrapolated frequencies.

    Returns:
        (inv_freq, attention_factor). ``RotaryEmbedding`` multiplies cos/sin by
        attention_factor. ``seq_len`` is unused.
    """
    rope = _rope_cfg(cfg)
    rope_theta = float(rope.get("rope_theta", 10000.0))
    factor = float(rope.get("factor", 1.0))
    # Prefer the config attribute (original behaviour), then rope_config,
    # then max_position_embeddings.
    original_max_position_embeddings = int(
        getattr(cfg, "original_max_position_embeddings", None)
        or rope.get("original_max_position_embeddings", None)
        or cfg.max_position_embeddings
    )
    dim = _rotary_dim(cfg)
    attention_factor = rope.get("attention_factor", None)
    beta_fast = float(rope.get("beta_fast", 32))
    beta_slow = float(rope.get("beta_slow", 1))
    mscale = rope.get("mscale", None)
    mscale_all_dim = rope.get("mscale_all_dim", None)

    if attention_factor is None:
        if mscale and mscale_all_dim:
            attention_factor = float(_get_mscale(factor, mscale) / _get_mscale(factor, mscale_all_dim))
        else:
            attention_factor = _get_mscale(factor, mscale)

    def find_correction_dim(num_rotations, dim, base, max_position_embeddings):
        # Dim index whose RoPE frequency completes `num_rotations` full turns
        # over `max_position_embeddings` positions.
        return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))

    def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):
        low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
        high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
        return max(low, 0), min(high, dim - 1)

    def linear_ramp_factor(minv, maxv, dimh):
        if minv == maxv:
            maxv += 0.001  # avoid division by zero
        linear_func = (torch.arange(dimh, dtype=torch.float32, device=device) - minv) / (maxv - minv)
        return torch.clamp(linear_func, 0, 1)

    pos_freqs = rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
    inv_freq_extrapolation = 1.0 / pos_freqs
    inv_freq_interpolation = 1.0 / (factor * pos_freqs)

    low, high = find_correction_range(beta_fast, beta_slow, dim, rope_theta, original_max_position_embeddings)
    # 1 -> keep the original (extrapolated) frequency, 0 -> fully interpolate.
    inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, dim // 2)
    inv_freq = (
        inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)
        + inv_freq_extrapolation * inv_freq_extrapolation_factor
    )
    return inv_freq, float(attention_factor)


ROPE_INIT_FUNCTIONS = {
    "default": _compute_default_rope_parameters,
    "dynamic": _compute_dynamic_ntk_parameters,
    "yarn": _compute_yarn_parameters,
}


class RotaryEmbedding(nn.Module):
    """Build (cos, sin) RoPE tables for the given position ids.

    The variant comes from ``rope_config["rope_type"]`` (default "default").
    The tables are multiplied by the attention scaling factor returned by the
    init function; only YaRN returns a value other than 1.0.
    """

    def __init__(self, config: RamoConfig):
        super().__init__()
        self.config = config
        self.rope_type = _rope_cfg(config).get("rope_type", "default")
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(config, device="cpu")
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

    def _dynamic_frequency_update(self, position_ids, device):
        """For "dynamic" RoPE, recompute inv_freq when positions exceed the cached length.

        Restores the original frequencies once the sequence is short again.
        """
        seq_len = int(torch.max(position_ids).item()) + 1
        if seq_len > self.max_seq_len_cached:
            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device=device, seq_len=seq_len)
            self.register_buffer("inv_freq", inv_freq, persistent=False)
            self.max_seq_len_cached = seq_len

        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (cos, sin), each (bsz, seq, rotary_dim), in x.dtype.

        x supplies only the device/dtype. position_ids is (bsz, seq).
        """
        if self.rope_type == "dynamic":
            self._dynamic_frequency_update(position_ids, device=x.device)

        inv_freq = self.inv_freq.to(device=x.device)
        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()

        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        # Compute the angles in float32 regardless of any active autocast.
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        cos = cos * float(self.attention_scaling)
        sin = sin * float(self.attention_scaling)
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


# ---------------------------
# Cache adapter (HF tuple <-> update() API)
# ---------------------------


class KVCache:
    """
    Stores per-layer (k,v) with shape:
      k: (bsz, n_kv_heads, seq, head_dim)
      v: (bsz, n_kv_heads, seq, head_dim)
    """

    def __init__(self, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None):
        self.key_cache: List[Optional[torch.Tensor]] = []
        self.value_cache: List[Optional[torch.Tensor]] = []
        if past_key_values is not None:
            for (k, v) in past_key_values:
                self.key_cache.append(k)
                self.value_cache.append(v)

    def update(self, k: torch.Tensor, v: torch.Tensor, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Append k/v for ``layer_idx`` along the seq dim (2); return the full cached k/v."""
        while len(self.key_cache) <= layer_idx:
            self.key_cache.append(None)
            self.value_cache.append(None)

        if self.key_cache[layer_idx] is None:
            self.key_cache[layer_idx] = k
            self.value_cache[layer_idx] = v
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], k], dim=2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], v], dim=2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
        """Return a tuple of (k, v) per layer; empty layers become (empty, empty) tensors."""
        out = []
        for k, v in zip(self.key_cache, self.value_cache):
            if k is None or v is None:
                out.append((torch.empty(0), torch.empty(0)))
            else:
                out.append((k, v))
        return tuple(out)

    @property
    def past_len(self) -> int:
        """Seq length of the first populated layer, or 0 if nothing is cached."""
        for k in self.key_cache:
            if k is not None:
                return int(k.shape[2])
        return 0


# ---------------------------
# Attention mask helpers
# ---------------------------


def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int) -> torch.Tensor:
    """Turn a (bsz, src_len) 1/0 padding mask into an additive (bsz, 1, tgt_len, src_len) mask.

    Kept tokens become 0 and padded tokens become finfo(dtype).min.
    """
    bsz, src_len = mask.shape
    expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    inverted = 1.0 - expanded
    return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)


def _make_causal_mask(input_shape, dtype, device, past_key_values_length=0):
    """Additive causal mask of shape (bsz, 1, tgt_len, past + tgt_len) in ``dtype``.

    Query i may attend to key j iff j <= i + past_key_values_length.
    """
    bsz, tgt_len = input_shape
    src_len = tgt_len + past_key_values_length
    # Create the mask in `dtype`. Without this it defaults to float32, and the
    # eager path then mixes fp32 probs with bf16/fp16 values in a matmul.
    mask = torch.full((tgt_len, src_len), torch.finfo(dtype).min, dtype=dtype, device=device)
    # allow attending to <= current position
    i = torch.arange(tgt_len, device=device).unsqueeze(1)
    j = torch.arange(src_len, device=device).unsqueeze(0)
    mask = mask.masked_fill(j <= (i + past_key_values_length), 0)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, src_len)


def prepare_decoder_attention_mask(attention_mask, input_shape, dtype, device, past_key_values_length=0):
    """Combine the causal mask with an optional (bsz, past + tgt_len) padding mask.

    Returns an additive (bsz, 1, tgt_len, past + tgt_len) mask in ``dtype``.
    Blocked positions hold finfo(dtype).min. Masks are combined with
    masked_fill, not addition, so min + min never overflows to -inf.
    """
    bsz, tgt_len = input_shape
    combined = _make_causal_mask(input_shape, dtype, device, past_key_values_length=past_key_values_length)
    if attention_mask is not None:
        expanded = _expand_mask(attention_mask, dtype, tgt_len=tgt_len)
        combined = combined.masked_fill(expanded.to(torch.bool), torch.finfo(dtype).min)
    return combined


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each kv head n_rep times: (bsz, n_kv, seq, hd) -> (bsz, n_kv * n_rep, seq, hd)."""
    if n_rep == 1:
        return x
    bsz, n_kv, seq, hd = x.shape
    x = x[:, :, None, :, :].expand(bsz, n_kv, n_rep, seq, hd)
    return x.reshape(bsz, n_kv * n_rep, seq, hd)


# ---------------------------
# MLP + MoE
# ---------------------------


class MLP(nn.Module):
    """Gated SiLU MLP: down(silu(gate(x)) * up(x)), all projections bias-free."""

    def __init__(self, config: RamoConfig, intermediate_size: Optional[int] = None):
        super().__init__()
        hidden = config.hidden_size
        inter = int(intermediate_size if intermediate_size is not None else config.intermediate_size)
        self.gate_proj = nn.Linear(hidden, inter, bias=False)
        self.up_proj = nn.Linear(hidden, inter, bias=False)
        self.down_proj = nn.Linear(inter, hidden, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class MoEGate(nn.Module):
    """Softmax top-k router over ``n_routed_experts``.

    In training mode it also returns a Switch-style load-balancing loss.
    ``seq_aux`` is read from the config but is not used here.
    """

    def __init__(self, config: RamoConfig):
        super().__init__()
        moe = config.moe_config
        self.top_k = int(moe["num_experts_per_tok"])
        self.n_routed_experts = int(moe["n_routed_experts"])
        self.routed_scaling_factor = float(moe.get("routed_scaling_factor", 1.0))
        self.seq_aux = bool(moe.get("seq_aux", True))
        self.norm_topk_prob = bool(moe.get("norm_topk_prob", False))
        self.gating_dim = int(config.hidden_size)

        # IMPORTANT: parameter name must be "weight" to match checkpoint key: *.mlp.gate.weight
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """x: (tokens, hidden). Returns (topk_idx, topk_weight float32, aux_loss or None)."""
        logits = F.linear(x, self.weight)  # (tokens, n_experts)
        probs = F.softmax(logits, dim=-1, dtype=torch.float32)  # stable

        topk_weight, topk_idx = torch.topk(probs, k=self.top_k, dim=-1)
        if self.norm_topk_prob:
            topk_weight = topk_weight / (topk_weight.sum(dim=-1, keepdim=True) + 1e-9)
        topk_weight = topk_weight * self.routed_scaling_factor

        aux_loss = None
        if self.training:
            # load-balancing loss (Switch-style)
            # importance: mean probability per expert
            importance = probs.mean(dim=0)  # (n_experts,)
            # load: fraction of tokens dispatched (using top-1)
            top1 = topk_idx[:, 0]
            load = torch.bincount(top1, minlength=self.n_routed_experts).float()
            load = load / (load.sum() + 1e-9)
            aux_loss = (self.n_routed_experts * (importance * load).sum())

        return topk_idx, topk_weight, aux_loss


class MoE(nn.Module):
    """Routed-expert MLP layer plus an optional always-on shared expert."""

    def __init__(self, config: RamoConfig):
        super().__init__()
        moe = config.moe_config
        self.n_routed_experts = int(moe["n_routed_experts"])
        self.n_shared_experts = int(moe.get("n_shared_experts", 0))
        self.moe_intermediate = int(moe["intermediate_size"])

        self.gate = MoEGate(config)
        self.experts = nn.ModuleList([MLP(config, intermediate_size=self.moe_intermediate) for _ in range(self.n_routed_experts)])

        self.shared_experts = None
        if self.n_shared_experts and self.n_shared_experts > 0:
            # Checkpoint key: *.mlp.shared_experts.* (a single MLP).
            # NOTE(review): its intermediate size is moe_intermediate, NOT
            # moe_intermediate * n_shared_experts (the DeepSeek-style layout).
            # Confirm against the checkpoint's tensor shapes before changing it.
            self.shared_experts = MLP(config, intermediate_size=self.moe_intermediate)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """x: (bsz, seq, hidden). Returns (output of the same shape, router aux_loss or None)."""
        bsz, seq, hidden = x.shape
        x_flat = x.reshape(bsz * seq, hidden)

        topk_idx, topk_weight, aux_loss = self.gate(x_flat)

        out = torch.zeros_like(x_flat)

        # shared experts contribution (dense)
        if self.shared_experts is not None:
            out = out + self.shared_experts(x_flat)

        # Routed experts: gather the tokens routed to each expert, run them, and
        # scatter-add the gate-weighted outputs back.
        for expert_id in range(self.n_routed_experts):
            mask = (topk_idx == expert_id)  # (tokens, top_k)
            if not mask.any():
                continue
            tok_idx, which = mask.nonzero(as_tuple=True)
            x_e = x_flat.index_select(0, tok_idx)
            y_e = self.experts[expert_id](x_e)
            w = topk_weight[tok_idx, which].unsqueeze(-1).to(y_e.dtype)
            out.index_add_(0, tok_idx, y_e * w)

        return out.reshape(bsz, seq, hidden), aux_loss


# ---------------------------
# Decoder blocks
# ---------------------------


class Attention(nn.Module):
    """Multi-head attention with GQA, optional per-head QK RMSNorm, RoPE and KV caching."""

    def __init__(self, config: RamoConfig, layer_idx: int):
        super().__init__()
        assert config.num_attention_heads % config.num_key_value_heads == 0

        if config.attention_implementation == "auto":
            self.use_sdpa_attention = hasattr(F, "scaled_dot_product_attention")
        else:
            self.use_sdpa_attention = (config.attention_implementation == "sdpa")

        self.use_qk_norm = bool(getattr(config, "use_qk_norm", True))
        self.layer_idx = int(layer_idx)

        self.hidden_size = int(config.hidden_size)
        self.num_heads = int(config.num_attention_heads)
        self.head_size = self.hidden_size // self.num_heads
        self.num_key_value_heads = int(config.num_key_value_heads)
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.scale = self.head_size ** -0.5

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_size, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_size, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_size, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_size, self.hidden_size, bias=False)

        self.dropout = nn.Dropout(p=float(getattr(config, "attention_dropout", 0.0)))

        if self.use_qk_norm:
            self.q_norm = RMSNorm(self.head_size)
            self.k_norm = RMSNorm(self.head_size)

    def _eager_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Reference attention: softmax(scale * q @ k^T + mask) @ v.

        q/k/v are (bsz, heads, seq, head_dim). Returns (bsz, seq, heads, head_dim).
        The softmax runs in float32 for numerical stability under fp16/bf16.
        """
        scores = (self.scale * q) @ k.transpose(-1, -2)
        if attn_mask is not None:
            scores = scores + attn_mask.to(scores.dtype)
        probs = scores.softmax(-1, dtype=torch.float32).to(q.dtype)
        if self.training:
            probs = self.dropout(probs)
        return (probs @ v).permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[KVCache] = None,
    ) -> torch.Tensor:
        """hidden_states: (bsz, seq, hidden). Returns (bsz, seq, hidden).

        ``attention_mask`` is the additive mask from prepare_decoder_attention_mask.
        """
        bsz, seq_len, _ = hidden_states.shape

        q = self.q_proj(hidden_states).reshape(bsz, seq_len, -1, self.head_size)
        k = self.k_proj(hidden_states).reshape(bsz, seq_len, -1, self.head_size)
        v = self.v_proj(hidden_states).reshape(bsz, seq_len, -1, self.head_size)

        if self.use_qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)

        # to (bsz, heads, seq, head_dim)
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        if position_embeddings is not None:
            cos, sin = position_embeddings
            q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1)

        if past_key_values is not None:
            # The cache stores k/v as (bsz, kv_heads, seq, head_dim), which is
            # already the layout here because k_proj/v_proj produce kv_heads.
            k, v = past_key_values.update(k, v, self.layer_idx)

        # Expand kv to heads for GQA
        if self.num_key_value_heads != self.num_heads:
            k = repeat_kv(k, self.num_key_value_groups)
            v = repeat_kv(v, self.num_key_value_groups)

        if self.use_sdpa_attention:
            dropout_p = self.dropout.p if self.training else 0.0
            attn_mask = attention_mask
            if attn_mask is not None:
                # SDPA can choke on broadcasted/expanded views -> materialize + match heads
                if attn_mask.dim() == 4 and attn_mask.size(1) == 1:
                    attn_mask = attn_mask.expand(bsz, self.num_heads, attn_mask.size(2), attn_mask.size(3))
                attn_mask = attn_mask.to(dtype=q.dtype).contiguous()

            try:
                attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=dropout_p)
                attn = attn.transpose(1, 2)  # (bsz, seq, heads, head_dim)
            except RuntimeError:
                # Fallback: eager attention (always works)
                attn = self._eager_attention(q, k, v, attn_mask)
        else:
            attn = self._eager_attention(q, k, v, attention_mask)

        attn = attn.reshape(bsz, seq_len, -1)
        return self.o_proj(attn)


class DecoderLayer(nn.Module):
    """Pre-norm decoder block: x + attn(norm(x)), then + mlp(norm(.)).

    The MLP is an MoE only when moe_config provides intermediate_size,
    num_experts_per_tok, n_routed_experts and n_shared_experts, and
    layer_idx >= n_dense_layer. Otherwise it is a dense MLP.
    """

    def __init__(self, config: RamoConfig, layer_idx: int):
        super().__init__()
        self.attn_norm = RMSNorm(config.hidden_size)
        self.attn = Attention(config, layer_idx)
        self.mlp_norm = RMSNorm(config.hidden_size)

        use_moe = (
            config.moe_config is not None
            and config.moe_config.get("intermediate_size", None) is not None
            and config.moe_config.get("num_experts_per_tok", None) is not None
            and config.moe_config.get("n_routed_experts", None) is not None
            and config.moe_config.get("n_shared_experts", None) is not None
            and layer_idx >= int(config.moe_config.get("n_dense_layer", 0))
        )
        self.mlp = MoE(config) if use_moe else MLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[KVCache] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return (hidden_states, aux_loss); aux_loss is None for dense layers."""
        residual = hidden_states
        hidden_states = self.attn(
            hidden_states=self.attn_norm(hidden_states),
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
        )
        hidden_states = hidden_states + residual

        if isinstance(self.mlp, MoE):
            mlp_out, aux_loss = self.mlp(self.mlp_norm(hidden_states))
        else:
            mlp_out = self.mlp(self.mlp_norm(hidden_states))
            aux_loss = None

        hidden_states = hidden_states + mlp_out
        return hidden_states, aux_loss


class LlmModel(nn.Module):
    """Embedding -> decoder layers -> final RMSNorm -> LM head."""

    def __init__(self, config: RamoConfig):
        super().__init__()
        self.config = config
        self.rotary_emb = RotaryEmbedding(config=config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([DecoderLayer(config, i) for i in range(config.num_hidden_layers)])
        self.head_norm = RMSNorm(config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[KVCache] = None,
        use_cache: bool = False,
        logits_to_keep: int = 0,
    ) -> Tuple[torch.Tensor, Optional[KVCache], Optional[torch.Tensor]]:
        """Run the decoder.

        Args:
            input_ids: (bsz, seq) new tokens only.
            attention_mask: (bsz, past_len + seq) 1/0 padding mask covering the
                cached and new tokens; defaults to all ones.
            position_ids: (bsz, seq); by default derived from attention_mask
                (left-padding aware) for the new tokens.
            past_key_values: KVCache updated in place by each layer, or None.
            use_cache: accepted for API symmetry; caching is controlled solely
                by whether past_key_values is given.
            logits_to_keep: if > 0, only the last N positions are projected.

        Returns:
            (logits, past_key_values, aux_loss), where aux_loss is the mean
            router loss over layers that returned one, or None.
        """
        bsz, seq_len = input_ids.shape
        device = input_ids.device
        dtype = self.embed_tokens.weight.dtype
        past_len = past_key_values.past_len if past_key_values is not None else 0

        if attention_mask is None:
            # Must span cached + new tokens to line up with the causal mask,
            # whose key length is past_len + seq_len.
            attention_mask = torch.ones((bsz, past_len + seq_len), device=device, dtype=torch.long)

        if position_ids is None:
            # left-padding aware position ids; keep only the new tokens' positions
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            position_ids = position_ids[:, -seq_len:]

        attn_mask = prepare_decoder_attention_mask(
            attention_mask=attention_mask,
            input_shape=(bsz, seq_len),
            dtype=dtype,
            device=device,
            past_key_values_length=past_len,
        )

        hidden_states = self.embed_tokens(input_ids)
        pos_emb = self.rotary_emb(hidden_states, position_ids)

        aux_losses = []
        for layer in self.layers:
            hidden_states, aux = layer(
                hidden_states=hidden_states,
                position_embeddings=pos_emb,
                attention_mask=attn_mask,
                past_key_values=past_key_values,
            )
            if aux is not None:
                aux_losses.append(aux)

        hidden_states = self.head_norm(hidden_states)

        if logits_to_keep and logits_to_keep > 0:
            hidden_states = hidden_states[:, -logits_to_keep:, :]

        logits = self.lm_head(hidden_states)

        aux_loss = None
        if aux_losses:
            aux_loss = torch.stack(aux_losses).mean()

        return logits, past_key_values, aux_loss


# ---------------------------
# HF wrapper
# ---------------------------


class RamoForCausalLM(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``LlmModel``.

    Accepts and returns legacy tuple caches.
    """

    config_class = RamoConfig
    base_model_prefix = "model"
    _no_split_modules = ["DecoderLayer", "Attention", "MoE", "MLP"]

    def __init__(self, config: RamoConfig):
        super().__init__(config)
        self.model = LlmModel(config)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.model.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.model.lm_head = new_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        **kwargs
    ) -> CausalLMOutputWithPast:
        """Compute logits and, if ``labels`` is given, the shifted causal-LM loss.

        Extra ``kwargs`` are accepted and ignored.
        """
        cache = KVCache(past_key_values) if (use_cache or past_key_values is not None) else None

        logits, cache, aux_loss = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=cache,
            use_cache=use_cache,
        )

        legacy_cache = cache.to_legacy_cache() if cache is not None else None

        loss = None
        if labels is not None:
            # Standard causal LM loss: position t predicts token t + 1.
            # Computed in float32 so fp16/bf16 logits do not degrade the loss.
            shift_logits = logits[..., :-1, :].contiguous().float()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index=-100)
            if aux_loss is not None:
                loss = loss + aux_loss * 0.0  # keep aux_loss available without changing loss unless you want

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=legacy_cache,
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
        """Feed the full sequence each generation step.

        The KV cache is deliberately disabled here for compatibility, so each
        step recomputes attention over all tokens.
        """
        # Disable cache for maximum compatibility
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "use_cache": False,
            "past_key_values": None,
            "position_ids": None,
        }