"""The EDEN encoder-decoder Transformer (training/inference reference model).""" from __future__ import annotations import math import torch import torch.nn as nn import torch.nn.functional as F from .config import TrainConfig from .constants import * class PositionalEncoding(nn.Module): def __init__(self, d_model: int, max_len: int, dropout: float): super().__init__() self.dropout = nn.Dropout(dropout) pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) self.register_buffer("pe", pe.unsqueeze(0), persistent=False) def forward(self, x: torch.Tensor) -> torch.Tensor: x = x + self.pe[:, : x.size(1), :].to(dtype=x.dtype) return self.dropout(x) class EdenTransformer(nn.Module): def __init__(self, cfg: TrainConfig): super().__init__() self.cfg = cfg self.pad_id = PAD_ID self.bos_id = BOS_ID self.eos_id = EOS_ID self.scale = math.sqrt(cfg.d_model) self.embedding = nn.Embedding(cfg.vocab_size, cfg.d_model, padding_idx=PAD_ID) self.pos = PositionalEncoding(cfg.d_model, cfg.max_len + 4, cfg.dropout) self.transformer = nn.Transformer( d_model=cfg.d_model, nhead=cfg.n_heads, num_encoder_layers=cfg.n_layers, num_decoder_layers=cfg.n_layers, dim_feedforward=cfg.dim_feedforward, dropout=cfg.dropout, activation="gelu", batch_first=True, norm_first=True, ) self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False) self.lm_head.weight = self.embedding.weight self._reset_parameters() def _reset_parameters(self) -> None: for name, param in self.named_parameters(): if param.dim() > 1 and "embedding" not in name: nn.init.xavier_uniform_(param) nn.init.normal_(self.embedding.weight, mean=0.0, std=0.02) with torch.no_grad(): self.embedding.weight[PAD_ID].zero_() def parameter_count(self) -> int: return sum(p.numel() for p in self.parameters() if p.requires_grad) def encode(self, src: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: src_padding = src.eq(PAD_ID) src_emb = self.pos(self.embedding(src) * self.scale) memory = self.transformer.encoder(src_emb, src_key_padding_mask=src_padding) return memory, src_padding def decode( self, tgt: torch.Tensor, memory: torch.Tensor, src_padding: torch.Tensor, ) -> torch.Tensor: tgt_padding = tgt.eq(PAD_ID) tgt_emb = self.pos(self.embedding(tgt) * self.scale) tgt_len = tgt.size(1) causal_mask = torch.triu( torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt.device), diagonal=1, ) hidden = self.transformer.decoder( tgt_emb, memory, tgt_mask=causal_mask, tgt_key_padding_mask=tgt_padding, memory_key_padding_mask=src_padding, ) return hidden def forward(self, src: torch.Tensor, tgt_in: torch.Tensor) -> torch.Tensor: memory, src_padding = self.encode(src) hidden = self.decode(tgt_in, memory, src_padding) return self.lm_head(hidden)