#!/usr/bin/env python3
"""
AETHER-Micro Attention Layer
Multi-Head Attention with RoPE + GQA + Flash Attention
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

from .configuration_aether_micro import AETHERMicroConfig
from .embeddings import AETHERMicroRotaryEmbedding
from .utils import repeat_kv, apply_rotary_pos_emb


def _build_causal_attention_bias(
    attention_mask: torch.Tensor,
    query_length: int,
    key_length: int,
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    """Merge a caller-supplied attention mask with the causal mask into one additive bias.

    ``F.scaled_dot_product_attention`` raises an error when both ``attn_mask`` and
    ``is_causal=True`` are given. Causality therefore has to be folded into the
    mask explicitly whenever the caller supplies one.

    Args:
        attention_mask: Either a 2-D ``(batch, key_length)`` key-padding mask or a mask
            broadcastable to ``(batch, num_heads, query_length, key_length)``
            (e.g. ``(batch, 1, query_length, key_length)``). Boolean/integer masks are
            keep-masks (True / nonzero = attend). Floating masks are additive biases
            (0 = attend, large negative = masked). These are the two conventions
            ``F.scaled_dot_product_attention`` documents for ``attn_mask``.
        query_length: Number of query positions.
        key_length: Number of key positions.
        dtype: dtype of the query tensor. SDPA requires a floating mask to have the
            same dtype as the query.
        device: Device the bias must live on (the device of the query tensor).

    Returns:
        Additive bias of dtype ``dtype``, broadcastable to the attention-score shape.
        Masked positions hold ``torch.finfo(dtype).min``. That value is finite, so a
        fully-masked row (e.g. a left-padding query) yields a uniform distribution
        instead of NaN.
    """
    attention_mask = attention_mask.to(device=device)
    if attention_mask.dim() == 2:
        # (batch, key_length) padding mask -> (batch, 1, 1, key_length) so it broadcasts
        # over heads and query positions.
        attention_mask = attention_mask[:, None, None, :]

    min_value = torch.finfo(dtype).min
    # Lower-triangular keep-mask: query i may see keys up to i. The diagonal offset
    # only matters if key_length > query_length; forward() always passes equal lengths.
    causal = torch.ones(query_length, key_length, dtype=torch.bool, device=device).tril(
        diagonal=key_length - query_length
    )

    if attention_mask.is_floating_point():
        # Clamp so -inf entries stay finite. This also covers the finfo.min of a wider
        # dtype, which overflows to -inf when downcast to fp16.
        bias = attention_mask.to(dtype).clamp(min=min_value)
        return bias.masked_fill(~causal, min_value)

    allowed = attention_mask.bool() & causal
    return torch.zeros_like(allowed, dtype=dtype).masked_fill(~allowed, min_value)


class AETHERMicroAttention(nn.Module):
    """
    Multi-Head Attention with RoPE and GQA

    Features:
    - Rotary Position Embedding (RoPE)
    - Grouped Query Attention (GQA): num_key_value_heads < num_heads
    - Flash Attention compatibility (F.scaled_dot_product_attention)

    Attention is always causal. When no ``attention_mask`` is given, SDPA runs with
    ``is_causal=True`` and no explicit mask (the fused fast path). When a mask is
    given, it is combined with the causal mask and passed to SDPA as an additive bias.
    """

    def __init__(self, config: AETHERMicroConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        # Floor division: if hidden_size is not a multiple of num_heads, the attention
        # width (num_heads * head_dim) is smaller than hidden_size. o_proj maps it back.
        self.head_dim = self.hidden_size // self.num_heads

        # GQA: num_key_value_heads <= num_heads
        if self.num_heads % self.num_key_value_heads != 0:
            raise ValueError(
                f"num_heads ({self.num_heads}) must be divisible by "
                f"num_key_value_heads ({self.num_key_value_heads})"
            )
        # Number of query heads that share each key/value head.
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        # Q, K, V projections
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        # RoPE
        self.rotary_emb = AETHERMicroRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: (batch_size, seq_length, hidden_size)
            attention_mask: Optional mask. Accepted forms:
                (batch_size, 1, seq_length, seq_length), anything broadcastable to
                (batch_size, num_heads, seq_length, seq_length), or a 2-D
                (batch_size, seq_length) key-padding mask.
                Bool/integer masks mean True/nonzero = attend. Float masks are
                additive (0 = attend, large negative = masked).
                The causal mask is always applied in addition. None selects the
                pure-causal fast path.
            position_ids: (batch_size, seq_length) or None (defaults to 0..seq_length-1)

        Returns:
            attn_output: (batch_size, seq_length, hidden_size)
        """
        batch_size, seq_length, _ = hidden_states.shape

        # Q, K, V projections
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Reshape for multi-head attention:
        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query_states = query_states.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(batch_size, seq_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, seq_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # RoPE
        kv_seq_len = key_states.shape[-2]
        # NOTE(review): cos/sin are requested for kv_seq_len positions only. If a caller
        # passes position_ids >= kv_seq_len, apply_rotary_pos_emb may index past the
        # table. Confirm against AETHERMicroRotaryEmbedding / apply_rotary_pos_emb.
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        # Position IDs (default: 0, 1, 2, ..., seq_length-1)
        if position_ids is None:
            position_ids = torch.arange(
                seq_length, dtype=torch.long, device=hidden_states.device
            )
            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        # Repeat K/V for GQA (if num_key_value_groups > 1)
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        dropout_p = self.config.attention_dropout if self.training else 0.0
        if attention_mask is None:
            # Pure causal: no explicit mask keeps the call eligible for the fused
            # FlashAttention backend, which does not accept an arbitrary attn_mask.
            attn_output = F.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=None,
                dropout_p=dropout_p,
                is_causal=True,
            )
        else:
            # SDPA forbids attn_mask together with is_causal=True, so fold causality
            # into the caller's mask instead of silently dropping the mask.
            attn_bias = _build_causal_attention_bias(
                attention_mask,
                query_length=seq_length,
                key_length=kv_seq_len,
                dtype=query_states.dtype,
                device=query_states.device,
            )
            attn_output = F.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=attn_bias,
                dropout_p=dropout_p,
                is_causal=False,
            )

        # Reshape back: (batch, num_heads, seq, head_dim) -> (batch, seq, num_heads * head_dim).
        # Use the true attention width (which equals o_proj.in_features), not
        # hidden_size: the two differ when hidden_size % num_heads != 0.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, seq_length, self.num_heads * self.head_dim)

        # Output projection back to hidden_size
        attn_output = self.o_proj(attn_output)
        return attn_output