#!/usr/bin/env python3
"""
AETHER-Micro Multi-Token Prediction Loss

Reference: "Better & Faster Large Language Models via Multi-token Prediction" (Meta, 2024)

MTP predicts next N tokens simultaneously instead of just the next token.
This improves both perplexity and reasoning capabilities.
"""

import torch
import torch.nn as nn
from typing import Tuple, Dict

from .configuration_aether_micro import AETHERMicroConfig


class MTPLoss(nn.Module):
    """
    Multi-Token Prediction Loss (Block 5)

    Owns one bias-free linear head per look-ahead distance. Head ``k``
    (1-based) maps the hidden state at position ``t`` to vocabulary logits
    that are scored against the label at position ``t + k``:
    - k=1: Standard NTP (next token prediction)
    - k=2: Predict the token 2 positions ahead
    - k=3: Predict the token 3 positions ahead
    - k=4: Predict the token 4 positions ahead

    Loss = weighted sum of the per-head cross-entropy losses, divided by the
    sum of the weights.
    Weight decreases with distance: 1.0, 0.5, 0.33, 0.25 (i.e. ``1 / k``).
    """

    def __init__(self, config: AETHERMicroConfig):
        """
        Build the prediction heads.

        Args:
            config: Must provide ``hidden_size`` and ``vocab_size``. The
                optional ``mtp_num_predictions`` attribute sets the number of
                heads and defaults to 4 when absent.

        Raises:
            ValueError: If the number of prediction heads is less than 1.
        """
        super().__init__()

        self.hidden_size = config.hidden_size
        self.vocab_size = config.vocab_size
        self.num_predictions = getattr(config, 'mtp_num_predictions', 4)

        # Fail at construction time: with zero heads the weight normaliser in
        # forward() is 0 and the division there would raise.
        if self.num_predictions < 1:
            raise ValueError(
                f"mtp_num_predictions must be >= 1, got {self.num_predictions}"
            )

        # Separate prediction heads for each future token
        self.prediction_heads = nn.ModuleList([
            nn.Linear(self.hidden_size, self.vocab_size, bias=False)
            for _ in range(self.num_predictions)
        ])

    def forward(
        self,
        hidden_states: torch.Tensor,  # (batch, seq, hidden)
        labels: torch.LongTensor      # (batch, seq)
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Compute MTP loss

        Args:
            hidden_states: (batch_size, seq_length, hidden_size)
            labels: (batch_size, seq_length). Entries equal to the
                ``nn.CrossEntropyLoss`` default ``ignore_index`` (-100) do
                not contribute to the loss.

        Returns:
            loss: 0-dim tensor holding the weighted sum of the per-head
                losses divided by the sum of the weights of all configured
                heads. When the sequence is too short for any head
                (``seq_length <= 1``) a zero tensor on the device/dtype of
                ``hidden_states`` is returned.
            metrics: Unweighted per-head losses as Python floats, keyed
                ``mtp_loss_<k>``. Heads that were skipped because the
                sequence is too short have no entry.

        Raises:
            ValueError: If the leading (batch, seq) dimensions of ``labels``
                and ``hidden_states`` differ.
        """
        _, seq_len, _ = hidden_states.shape
        if tuple(labels.shape[:2]) != tuple(hidden_states.shape[:2]):
            raise ValueError(
                "labels and hidden_states must agree on (batch, seq): got "
                f"{tuple(labels.shape[:2])} vs {tuple(hidden_states.shape[:2])}"
            )

        loss_fct = nn.CrossEntropyLoss()
        metrics: Dict[str, float] = {}
        total_loss = None

        for step, head in enumerate(self.prediction_heads, start=1):
            # A head looking `step` tokens ahead needs at least step + 1
            # positions. Longer look-aheads need even more, so stop here.
            if seq_len <= step:
                break

            # The last `step` positions have no target, so drop them before
            # the (expensive) vocabulary projection rather than after it.
            logits = head(hidden_states[:, :-step, :])
            # Target for position t is the label at position t + step.
            targets = labels[:, step:]

            step_loss = loss_fct(
                logits.reshape(-1, self.vocab_size),
                targets.reshape(-1),
            )

            # Weighted sum (decay by distance)
            weight = 1.0 / step
            weighted = weight * step_loss
            total_loss = weighted if total_loss is None else total_loss + weighted
            metrics[f'mtp_loss_{step}'] = step_loss.item()

        if total_loss is None:
            # No head could be evaluated; still honour the tensor return type
            # so callers can combine or log the result uniformly.
            return hidden_states.new_zeros(()), metrics

        # Normalize by sum of weights of all configured heads
        # weights: 1.0, 0.5, 0.33, 0.25 -> sum ~= 2.08
        weight_sum = sum(
            1.0 / step for step in range(1, self.num_predictions + 1)
        )
        return total_loss / weight_sum, metrics


def create_mtp_heads(config: AETHERMicroConfig) -> MTPLoss:
    """
    Factory function to create MTP heads

    Args:
        config: Model configuration forwarded unchanged to ``MTPLoss``.

    Returns:
        A freshly constructed ``MTPLoss`` module.

    Usage:
        mtp_heads = create_mtp_heads(config)
        loss, metrics = mtp_heads(hidden_states, labels)
    """
    return MTPLoss(config)