#!/usr/bin/env python3
"""
AETHER-Micro Wu-Xing Router

5-Agent Router with Magic Square initialization
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .configuration_aether_micro import AETHERMicroConfig


class WuXingRouter(nn.Module):
    """
    5-Agent Wu-Xing Router (金/水/木/火/土 - Metal/Water/Wood/Fire/Earth)

    Features:
    - 5-Agent activation with Magic Square initialization
    - 1:k mapping: each Agent controls k = num_experts // 5 consecutive
      experts (k = 4 for 20 experts, the fallback count used when the config
      has no ``num_小_experts`` field)
    - Wu-Xing Interaction bias: not implemented yet (``self.wuxing`` is None)
    """

    # Display names of the five agents, in gate-output order. They are used
    # as suffixes of the per-agent utilization keys in the router stats.
    AGENT_NAMES = ("金", "水", "木", "火", "土")

    def __init__(self, config: AETHERMicroConfig):
        """
        Build the agent gate from ``config``.

        Args:
            config: model config; reads ``enable_hetero_moe``,
                ``num_大_experts`` / ``num_小_experts``, ``hidden_size`` and
                ``enable_magic_init``.

        Raises:
            ValueError: if the resulting expert count is not a positive
                multiple of the number of agents (5).
        """
        super().__init__()
        self.config = config
        self.num_agents = 5

        if config.enable_hetero_moe:
            self.num_experts = config.num_大_experts + config.num_小_experts
        else:
            # Fall back to 20 experts when the config lacks the field.
            self.num_experts = getattr(config, 'num_小_experts', 20)

        # forward() copies each agent's probability onto a contiguous group of
        # experts_per_agent experts. The groups must tile the expert axis
        # exactly; otherwise expert_probs would silently have fewer than
        # num_experts columns.
        if self.num_experts <= 0 or self.num_experts % self.num_agents != 0:
            raise ValueError(
                f"WuXingRouter requires num_experts to be a positive multiple of "
                f"{self.num_agents}, got {self.num_experts}"
            )
        self.experts_per_agent = self.num_experts // self.num_agents

        # Agent gate: hidden_size -> 5 agents
        self.agent_gate = nn.Linear(config.hidden_size, self.num_agents, bias=False)

        # Magic Square initialization (if enabled)
        if config.enable_magic_init:
            self._init_magic_square()

        # Wu-Xing Interaction (disabled for now - will be integrated in a future version)
        # TODO: Integrate Wu-Xing inter-layer propagation
        self.wuxing = None

    def _init_magic_square(self):
        """
        Initialize agent_gate with a 5×5 magic square (magic sum = 65).

        This is the classic odd-order (Siamese / de la Loubère) square: every
        row, every column and both main diagonals sum to 65. It is NOT
        pandiagonal; for example the broken anti-diagonal 8+7+6+10+9 = 40.

            17 24  1  8 15
            23  5  7 14 16
             4  6 13 20 22
            10 12 19 21  3
            11 18 25  2  9

        Values are normalized to [-1, 1] via (value - 13) / 12. The 5×5
        pattern is then tiled column-wise across the (num_agents=5,
        hidden_size) weight. The last tile is truncated when hidden_size is
        not a multiple of 5.
        """
        magic_square = torch.tensor([
            [17, 24,  1,  8, 15],
            [23,  5,  7, 14, 16],
            [ 4,  6, 13, 20, 22],
            [10, 12, 19, 21,  3],
            [11, 18, 25,  2,  9],
        ], dtype=torch.float32)

        # Normalize 1..25 to [-1, 1]
        magic_square = (magic_square - 13.0) / 12.0

        weight = self.agent_gate.weight  # (num_agents=5, hidden_size)
        hidden_size = weight.shape[1]
        num_tiles = (hidden_size + 4) // 5  # ceil(hidden_size / 5)
        tiled = magic_square.repeat(1, num_tiles)[:, :hidden_size]

        with torch.no_grad():
            weight.copy_(tiled.to(device=weight.device, dtype=weight.dtype))

    def forward(self, hidden_states: torch.Tensor, return_router_stats: bool = False):
        """
        Route hidden_states to experts via 5-Agent activation.

        Args:
            hidden_states: (batch_size, seq_length, hidden_size)
            return_router_stats: If True, return (expert_probs, router_stats)
                for monitoring.

        Returns:
            expert_probs: (batch_size, seq_length, num_experts). Agent i's
                softmax probability is copied to experts
                [i*k, (i+1)*k) with k = experts_per_agent, so each row sums
                to k rather than to 1.
            router_stats (only when return_router_stats is True): dict of
                Python floats with keys "router/entropy",
                "router/entropy_ratio", "router/imbalance" and
                "router/agent_<name>_util" for each agent.

        Raises:
            ValueError: if hidden_states is not 3-dimensional.
        """
        if hidden_states.dim() != 3:
            raise ValueError(
                "hidden_states must be (batch, seq, hidden), got shape "
                f"{tuple(hidden_states.shape)}"
            )

        # Agent activation: (batch, seq, hidden) -> (batch, seq, 5)
        agent_logits = self.agent_gate(hidden_states)
        agent_probs = F.softmax(agent_logits, dim=-1)  # (batch, seq, 5)

        # 1:k mapping: repeat each agent probability k = experts_per_agent times
        # -> (batch, seq, num_experts)
        expert_probs = agent_probs.repeat_interleave(self.experts_per_agent, dim=-1)

        if not return_router_stats:
            return expert_probs

        # The stats are only for logging, so keep them out of the autograd graph.
        with torch.no_grad():
            # Average probability each agent receives over batch and sequence: (5,)
            mean_agent_util = agent_probs.mean(dim=(0, 1))

            # Entropy H = -sum(p * log p). Its maximum is log(5) ≈ 1.609 (uniform
            # routing) and its minimum is 0 (collapse onto one agent).
            entropy = -(agent_probs * torch.log(agent_probs + 1e-9)).sum(dim=-1).mean()

            # Load imbalance: max_util / (1 / num_agents) - 1.0; 0 means perfectly balanced.
            imbalance = (mean_agent_util.max() / (1.0 / self.num_agents)) - 1.0

        entropy_value = entropy.item()
        router_stats = {
            "router/entropy": entropy_value,
            "router/entropy_ratio": entropy_value / math.log(self.num_agents),
            "router/imbalance": imbalance.item(),
        }
        # A single .tolist() means one device sync instead of one per agent.
        for name, util in zip(self.AGENT_NAMES, mean_agent_util.tolist()):
            router_stats[f"router/agent_{name}_util"] = util

        return expert_probs, router_stats