#!/usr/bin/env python3
"""
AETHER-Micro Latent Thought Loop (LTL)
Dynamic K-step latent reasoning with confidence gating
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple

from .configuration_aether_micro import AETHERMicroConfig


class AETHERMicroLatentThought(nn.Module):
    """
    Latent Thought Loop with Dynamic K and Confidence Gate

    Features:
    - Dynamic K selection (K=0: direct, K=1: shallow, K=2: deep)
    - Confidence Head with positive bias initialization
    - Thought Counter Embedding for K-step tracking
    - Efficient branching: K=0 direct path; each K>0 path is computed only
      for the tokens routed to that depth

    Args:
        config: AETHERMicroConfig with LTL settings (reads hidden_size,
            latent_dim, num_latents, max_k)

    Architecture:
    1. Latent Projection: hidden → latent_dim
    2. Latent Layers: 2 layers of self-attention + FFN
    3. Output Projection: latent_dim → hidden
    4. Confidence Head: hidden → 1 (sigmoid with bias=1.4)
    5. K Predictor: hidden → max_k + 1 (K=0/1/2 logits when max_k=2)
    """

    def __init__(self, config: AETHERMicroConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.latent_dim = config.latent_dim
        self.num_latents = config.num_latents
        self.max_k = config.max_k

        # Latent tokens (learnable), shared by every routed token.
        self.latent_tokens = nn.Parameter(
            torch.randn(1, self.num_latents, self.latent_dim)
        )

        # Projections between the hidden space and the latent space.
        self.down_proj = nn.Linear(self.hidden_size, self.latent_dim, bias=False)
        self.up_proj = nn.Linear(self.latent_dim, self.hidden_size, bias=False)

        # Latent layers (simplified Transformer); latent_dim must be divisible by 8.
        self.latent_layers = nn.ModuleList([
            LatentLayer(self.latent_dim, num_heads=8)
            for _ in range(2)
        ])

        # Confidence Head (sigmoid with positive bias).
        self.confidence_head = nn.Linear(self.hidden_size, 1, bias=True)
        with torch.no_grad():
            # Initialize bias to 1.4 → sigmoid(1.4) ≈ 0.80 (initial confidence).
            self.confidence_head.bias.fill_(1.4)

        # K Predictor: (max_k + 1)-way classification (3-way for max_k=2).
        self.k_predictor = nn.Linear(self.hidden_size, self.max_k + 1, bias=False)

        # Thought Counter Embedding: one vector per depth 0..max_k.
        self.thought_counter = nn.Embedding(self.max_k + 1, self.latent_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        deterministic: bool = False,
    ) -> Tuple[torch.Tensor, dict]:
        """
        Route every token to a thinking depth K and apply the latent loop.

        Args:
            hidden_states: (batch_size, seq_length, hidden_size)
            deterministic: If True, use argmax for K selection. Argmax is
                also used whenever the module is in eval mode, so inference
                is never perturbed by Gumbel noise.

        Returns:
            output: (batch_size, seq_length, hidden_size)
            metrics: dict of Python floats with k0_ratio, k1_ratio, k2_ratio
                (plus kN_ratio for every extra depth when max_k > 2), avg_k,
                avg_confidence

        Raises:
            ValueError: if hidden_states is not 3-D.
        """
        if hidden_states.dim() != 3:
            raise ValueError(
                "hidden_states must be (batch_size, seq_length, hidden_size), "
                f"got shape {tuple(hidden_states.shape)}"
            )

        # 1. Predict K for each token.
        k_logits = self.k_predictor(hidden_states)  # (B, S, max_k+1)
        if deterministic or not self.training:
            k_values = torch.argmax(k_logits, dim=-1)  # (B, S)
            k_gate = None
        else:
            # Sample K with Gumbel-Softmax. Taking argmax of a hard sample on
            # its own would cut the graph and leave k_predictor without any
            # gradient. Instead, keep the soft probability of the chosen K and
            # turn it into a straight-through gate. The gate is exactly 1.0 in
            # the forward pass (x - x == 0) but carries d(prob) backward.
            k_soft = F.gumbel_softmax(k_logits, tau=1.0, hard=False, dim=-1)
            k_values = torch.argmax(k_soft, dim=-1)  # (B, S)
            chosen = k_soft.gather(-1, k_values.unsqueeze(-1))  # (B, S, 1)
            k_gate = 1.0 + (chosen - chosen.detach())  # (B, S, 1)

        # 2. Predict confidence.
        confidence = torch.sigmoid(self.confidence_head(hidden_states))  # (B, S, 1)

        # 3. Every token is written exactly once below (k_values covers
        #    0..max_k), so the zero init is only a placeholder.
        output = torch.zeros_like(hidden_states)

        # 4. Process tokens grouped by their K value.
        for k in range(self.max_k + 1):
            mask = k_values == k  # (B, S)
            if not mask.any():
                continue
            rows, cols = mask.nonzero(as_tuple=True)
            h = hidden_states[rows, cols]  # (N, hidden_size)

            if k == 0:
                # K=0: direct path (no latent reasoning).
                h_final = h
            else:
                # K>0: confidence-gated latent reasoning with a residual.
                h_final = h + confidence[rows, cols] * self._think(h, k)

            if k_gate is not None:
                h_final = h_final * k_gate[rows, cols]

            # The cast keeps index_put valid if the gate promoted the dtype
            # (e.g. an fp32 gate under autocast with fp16 activations).
            output[rows, cols] = h_final.to(output.dtype)

        # 5. Compute metrics.
        return output, self._routing_metrics(k_values, confidence)

    def _think(self, h: torch.Tensor, k: int) -> torch.Tensor:
        """
        Run k rounds of the latent layers for the gathered tokens h.

        Args:
            h: (N, hidden_size) hidden states of the tokens routed to depth k
            k: thinking depth, 1..max_k

        Returns:
            (N, hidden_size) up-projected state of the token slot
        """
        # Shared learnable latents, tagged with the depth-k counter embedding.
        step = torch.tensor([k], device=h.device)
        latent = self.latent_tokens.expand(h.size(0), -1, -1)  # (N, num_latents, latent_dim)
        latent = latent + self.thought_counter(step).unsqueeze(1)

        # Slot 0 holds the down-projected token; only that slot is read back.
        x = torch.cat([self.down_proj(h).unsqueeze(1), latent], dim=1)  # (N, 1+num_latents, latent_dim)
        for _ in range(k):
            for layer in self.latent_layers:
                x = layer(x)
        return self.up_proj(x[:, 0, :])

    def _routing_metrics(self, k_values: torch.Tensor, confidence: torch.Tensor) -> dict:
        """Summarize K routing and confidence as Python floats in one host transfer."""
        # Always report k0..k2 (the historical keys), plus any deeper K.
        num_ratios = max(self.max_k, 2) + 1
        with torch.no_grad():
            stats = [(k_values == k).float().mean() for k in range(num_ratios)]
            stats += [k_values.float().mean(), confidence.mean()]
            # double() is exact for every float dtype, so the values match
            # per-tensor .item() calls; tolist() is a single transfer.
            values = torch.stack([s.double() for s in stats]).tolist()

        metrics = {f'k{k}_ratio': values[k] for k in range(num_ratios)}
        metrics['avg_k'] = values[num_ratios]
        metrics['avg_confidence'] = values[num_ratios + 1]
        return metrics


class LatentLayer(nn.Module):
    """
    Simplified Transformer layer for latent space

    Architecture:
    - LayerNorm → Self-Attention → Residual
    - LayerNorm → FFN → Residual
    """

    def __init__(self, latent_dim: int, num_heads: int = 8):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_heads = num_heads

        # Self-Attention
        self.norm1 = nn.LayerNorm(latent_dim)
        self.attn = nn.MultiheadAttention(
            latent_dim,
            num_heads,
            dropout=0.0,
            batch_first=True
        )

        # FFN
        self.norm2 = nn.LayerNorm(latent_dim)
        self.ffn = nn.Sequential(
            nn.Linear(latent_dim, latent_dim * 4),
            nn.GELU(),
            nn.Linear(latent_dim * 4, latent_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch_size, seq_length, latent_dim)

        Returns:
            x: (batch_size, seq_length, latent_dim)
        """
        # Self-Attention (pre-norm)
        x_norm = self.norm1(x)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm, need_weights=False)
        x = x + attn_out

        # FFN (pre-norm)
        x = x + self.ffn(self.norm2(x))

        return x


# ========================================
# Integration with AETHERMicroDecoderLayer
# ========================================

def integrate_ltl_into_decoder_layer(decoder_layer, config: AETHERMicroConfig):
    """
    Integrate LTL into existing AETHERMicroDecoderLayer

    Attaches an AETHERMicroLatentThought as ``decoder_layer.latent_thought``
    and replaces ``decoder_layer.forward`` with a wrapper. The wrapper returns
    ``(hidden_states, ltl_metrics)``. Calling this function again on an
    already-integrated layer is a no-op.

    Usage:
        from .latent_thought import integrate_ltl_into_decoder_layer

        # After creating decoder layer
        decoder_layer = AETHERMicroDecoderLayer(config)
        if config.enable_latent_thought:
            integrate_ltl_into_decoder_layer(decoder_layer, config)
    """
    if not config.enable_latent_thought:
        return

    if getattr(decoder_layer, '_ltl_integrated', False):
        # Wrapping twice would feed the (hidden_states, metrics) tuple from
        # the first wrapper into the LTL module and discard the first LTL.
        return

    # Add LTL module
    decoder_layer.latent_thought = AETHERMicroLatentThought(config)

    # Wrap original forward method.
    # NOTE(review): the closure is tied to this particular instance; a copy of
    # the layer (e.g. copy.deepcopy) will presumably still call back into the
    # original instance — verify before cloning integrated layers.
    original_forward = decoder_layer.forward

    def forward_with_ltl(
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        deterministic: bool = False,
    ) -> Tuple[torch.Tensor, Optional[dict]]:
        """
        Forward pass with LTL

        Returns:
            hidden_states: (batch_size, seq_length, hidden_size)
            ltl_metrics: dict or None (None when latent_thought was removed)
        """
        # Original forward (Attention + MoE).
        # NOTE(review): assumes original_forward returns a bare tensor — confirm.
        hidden_states = original_forward(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids
        )

        # LTL processing
        latent_thought = getattr(decoder_layer, 'latent_thought', None)
        if latent_thought is None:
            return hidden_states, None
        return latent_thought(hidden_states, deterministic=deterministic)

    decoder_layer.forward = forward_with_ltl
    decoder_layer._ltl_integrated = True