#!/usr/bin/env python3
"""
AETHER-Micro Heterogeneous MoE.

Structure: large (大) experts + small (小) experts + shared experts.
The original notes describe the intended layout as 5 large + 15 small +
2 shared; the actual counts and widths are read from the config.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from .configuration_aether_micro import AETHERMicroConfig
from .router import WuXingRouter


class AETHERMicroExpert(nn.Module):
    """Single feed-forward expert: Linear -> GELU -> Linear (bias-free).

    Args:
        hidden_size: Input and output feature width.
        intermediate_size: Width of the hidden projection.
    """

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = nn.GELU()

    def forward(self, x):
        """Apply the expert to ``x`` (last dim must be ``hidden_size``)."""
        return self.down_proj(self.act_fn(self.gate_proj(x)))


class AETHERMicroMoE(nn.Module):
    """Mixture-of-experts layer with optional heterogeneous expert sizes.

    When ``config.enable_hetero_moe`` is set, three groups are built:

    - ``大_experts``: ``config.num_大_experts`` experts of width
      ``config.大_intermediate_size`` (router indices ``0..num_大-1``).
    - ``小_experts``: ``config.num_小_experts`` experts of width
      ``config.小_intermediate_size`` (router indices following the large
      experts).
    - ``shared_experts``: ``config.num_shared_experts`` experts of width
      ``config.shared_intermediate_size``; applied to every token and
      averaged.

    Otherwise a uniform baseline of ``config.num_小_experts`` experts of
    width ``config.小_intermediate_size`` is stored in ``experts`` and no
    shared experts are used.

    Each token is routed to its ``config.top_k`` highest-probability
    experts; the selected probabilities are renormalized to sum to one.
    The figures quoted in the original notes (5 x 2048, 15 x 1024,
    2 x 1536, k=2) are presumably the config defaults - verify against
    ``AETHERMicroConfig``.
    """

    def __init__(self, config: AETHERMicroConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.top_k = config.top_k

        if config.enable_hetero_moe:
            # Large (大) experts.
            self.大_experts = nn.ModuleList([
                AETHERMicroExpert(self.hidden_size, config.大_intermediate_size)
                for _ in range(config.num_大_experts)
            ])
            # Small (小) experts.
            self.小_experts = nn.ModuleList([
                AETHERMicroExpert(self.hidden_size, config.小_intermediate_size)
                for _ in range(config.num_小_experts)
            ])
            # Shared experts (always active).
            self.shared_experts = nn.ModuleList([
                AETHERMicroExpert(self.hidden_size, config.shared_intermediate_size)
                for _ in range(config.num_shared_experts)
            ])
            # NOTE: no combined ``self.experts`` ModuleList is registered, to
            # avoid the shared-tensor error in save_pretrained. forward()
            # walks 大_experts + 小_experts directly.
            self.num_experts = config.num_大_experts + config.num_小_experts
        else:
            # Uniform experts (baseline).
            self.experts = nn.ModuleList([
                AETHERMicroExpert(self.hidden_size, config.小_intermediate_size)
                for _ in range(config.num_小_experts)
            ])
            self.num_experts = config.num_小_experts
            self.shared_experts = None

        # Router: Wu-Xing router if enabled, otherwise a plain linear gate.
        if config.enable_wuxing:
            self.router = WuXingRouter(config)
        else:
            self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False)

    def _routed_experts(self):
        """Return the routed experts as a plain list in router-index order.

        A fresh Python list is built on each call instead of registering a
        combined ModuleList, so no parameter is registered under two names.
        """
        uniform_experts = getattr(self, "experts", None)
        if uniform_experts is not None:
            return list(uniform_experts)
        return [*self.大_experts, *self.小_experts]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route each token through its top-k experts and add shared experts.

        Args:
            hidden_states: (batch_size, seq_length, hidden_size)

        Returns:
            moe_output: (batch_size, seq_length, hidden_size), same dtype
            and device as ``hidden_states``.
        """
        batch_size, seq_length, hidden_size = hidden_states.shape

        # Flatten to (batch*seq, hidden); reshape also accepts
        # non-contiguous inputs, which view() would reject.
        hidden_states_flat = hidden_states.reshape(-1, hidden_size)

        # Routing logits: (batch*seq, num_experts).
        if hasattr(self, 'router'):
            router_logits = self.router(hidden_states).reshape(-1, self.num_experts)
        else:
            router_logits = self.gate(hidden_states_flat)

        # Top-k selection, then renormalize the kept probabilities.
        routing_weights = F.softmax(router_logits, dim=-1)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

        final_hidden_states = hidden_states_flat.new_zeros(
            (batch_size * seq_length, hidden_size)
        )

        for expert_idx, expert in enumerate(self._routed_experts()):
            # Rows = tokens routed to this expert, cols = the top-k slot that
            # chose it. topk yields distinct experts per token, so each token
            # appears at most once here.
            token_indices, slot_indices = torch.where(selected_experts == expert_idx)
            if token_indices.numel() == 0:
                continue

            expert_output = expert(hidden_states_flat[token_indices])
            expert_weights = routing_weights[token_indices, slot_indices].unsqueeze(-1)

            # Cast so accumulation works when router/expert outputs differ in
            # dtype from the input (e.g. under autocast).
            final_hidden_states.index_add_(
                0,
                token_indices,
                (expert_output * expert_weights).to(final_hidden_states.dtype),
            )

        # Shared experts (always active): average of their outputs. Skipped
        # when the list is empty to avoid a 0/0 division producing NaNs.
        if self.shared_experts is not None and len(self.shared_experts) > 0:
            shared_output = torch.zeros_like(hidden_states_flat)
            for shared_expert in self.shared_experts:
                shared_output += shared_expert(hidden_states_flat)
            shared_output /= len(self.shared_experts)
            final_hidden_states += shared_output

        # Reshape back to (batch, seq, hidden).
        return final_hidden_states.view(batch_size, seq_length, hidden_size)