"""
facs2dipl / dipl2norm — Multi-task Char-level Transformer

Model definition, encode/decode helpers, and greedy inference.

Save this file to:
    /content/drive/MyDrive/facs2dipl_multitask/model_def.py

Then in any notebook run:
    %run /content/drive/MyDrive/facs2dipl_multitask/model_def.py
"""

import math

import torch
import torch.nn as nn

# ── Special token indices (must match training) ────────────────────────────
PAD, SOS, EOS, UNK = 0, 1, 2, 3

# Task prefix token indices
DIPL_IDX = 4  # source prefix for the facs→dipl task
NORM_IDX = 5  # source prefix for the dipl→norm task

# Ids that never contribute a character to decoded output.
_SPECIAL_IDS = frozenset((PAD, SOS, EOS, UNK, DIPL_IDX, NORM_IDX))


# ── Model ──────────────────────────────────────────────────────────────────
class PositionalEncoding(nn.Module):
    """Add a fixed sinusoidal position table to the input, then apply dropout.

    The table is stored as the non-trainable buffer ``pe`` with shape
    ``(1, max_len, d_model)``.
    """

    def __init__(self, d_model, max_len, dropout):
        """
        Args:
            d_model : width of the vectors the encoding is added to
            max_len : number of positions in the precomputed table
            dropout : dropout probability applied after the addition
        """
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(max_len).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = pos * div
        pe[:, 0::2] = torch.sin(angles)
        # An odd d_model has one fewer odd-indexed column than even-indexed
        # ones, so trim the cosine block; for even d_model this is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + pe[:, :x.size(1)])``.

        ``x.size(1)`` must not exceed the ``max_len`` the table was built
        with, otherwise the addition fails on mismatched shapes.
        """
        return self.dropout(x + self.pe[:, :x.size(1)])


class CharSeq2Seq(nn.Module):
    """Encoder-decoder Transformer over character ids.

    Source and target share one embedding table and one positional encoding.
    Both stacks are built with ``batch_first=True`` and ``norm_first=True``
    and end in a final ``LayerNorm``. ``proj`` maps decoder output to
    vocabulary logits.
    """

    def __init__(self, vocab_size, d_model, n_heads, n_enc, n_dec, d_ff, max_len, dropout):
        """
        Args:
            vocab_size : number of token ids (size of embedding and output)
            d_model    : embedding / hidden width
            n_heads    : attention heads per layer
            n_enc      : number of encoder layers
            n_dec      : number of decoder layers
            d_ff       : feed-forward width inside each layer
            max_len    : number of positions in the positional table
            dropout    : dropout probability used throughout
        """
        super().__init__()
        self.d_model = d_model
        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=PAD)
        self.pos_enc = PositionalEncoding(d_model, max_len, dropout)
        enc_layer = nn.TransformerEncoderLayer(
            d_model, n_heads, d_ff, dropout, batch_first=True, norm_first=True
        )
        dec_layer = nn.TransformerDecoderLayer(
            d_model, n_heads, d_ff, dropout, batch_first=True, norm_first=True
        )
        # norm_first=True makes PyTorch turn the nested-tensor path off anyway
        # (with a UserWarning); requesting False up front is the same runtime
        # behaviour without the warning.
        self.encoder = nn.TransformerEncoder(
            enc_layer, n_enc, norm=nn.LayerNorm(d_model), enable_nested_tensor=False
        )
        self.decoder = nn.TransformerDecoder(dec_layer, n_dec, norm=nn.LayerNorm(d_model))
        self.proj = nn.Linear(d_model, vocab_size)
        self._init_weights()

    def _init_weights(self):
        """Xavier-initialise every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        # The loop above also rewrites the embedding matrix, including the row
        # nn.Embedding had zeroed for padding_idx; put the zeros back.
        with torch.no_grad():
            self.embed.weight[PAD].zero_()

    def encode(self, src, src_key_padding_mask):
        """Embed ``src`` (scaled by sqrt(d_model)), add positions, run the encoder.

        Args:
            src                  : token ids, batch first
            src_key_padding_mask : mask forwarded to the encoder as
                                   ``src_key_padding_mask``

        Returns:
            Encoder output ("memory") for use by ``decode``.
        """
        x = self.pos_enc(self.embed(src) * math.sqrt(self.d_model))
        return self.encoder(x, src_key_padding_mask=src_key_padding_mask)

    def decode(self, tgt, memory, tgt_mask, tgt_key_padding_mask, memory_key_padding_mask):
        """Embed ``tgt`` (scaled by sqrt(d_model)), add positions, run the decoder.

        Args:
            tgt                     : target-side token ids, batch first
            memory                  : encoder output from ``encode``
            tgt_mask                : attention mask over target positions
            tgt_key_padding_mask    : padding mask for ``tgt``
            memory_key_padding_mask : padding mask for ``memory``

        Returns:
            Decoder hidden states; ``proj`` has not been applied.
        """
        x = self.pos_enc(self.embed(tgt) * math.sqrt(self.d_model))
        return self.decoder(
            x,
            memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )

    def forward(self, src, tgt):
        """Run encoder and decoder and return vocabulary logits.

        Padding masks are derived from ``== PAD`` on both inputs, and a square
        subsequent (causal) mask of size ``tgt.size(1)`` is applied to the
        target.

        Args:
            src : source token ids, batch first
            tgt : decoder-input token ids, batch first

        Returns:
            ``proj`` applied to the decoder output: one vector of
            ``vocab_size`` logits per target position.
        """
        src_pad_mask = (src == PAD)
        tgt_pad_mask = (tgt == PAD)
        T = tgt.size(1)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(T, device=src.device)
        memory = self.encode(src, src_pad_mask)
        out = self.decode(tgt, memory, tgt_mask, tgt_pad_mask, src_pad_mask)
        return self.proj(out)


# ── Helpers ────────────────────────────────────────────────────────────────
def _encode_with_prefix(first_idx, text, c2i, max_len):
    """Build ``[first_idx, chars..., EOS, PAD...]``, truncated/padded to ``max_len``."""
    ids = [first_idx]
    ids.extend(c2i.get(c, UNK) for c in text)
    # Keep at most max_len - 1 ids so that EOS always fits.
    ids = ids[:max_len - 1] + [EOS]
    ids += [PAD] * (max_len - len(ids))
    return ids


def encode_text(text, task_idx, c2i, max_len):
    """
    Encode source text with a task prefix token.
    Layout: [task_token, char, char, ..., EOS, PAD, ...]

    Characters missing from ``c2i`` become UNK. Text that does not fit is
    truncated so that EOS is always present. ``max_len`` needs to be at
    least 2 for the task token to survive truncation.
    """
    return _encode_with_prefix(task_idx, text, c2i, max_len)


def encode_target(text, c2i, max_len):
    """Encode target: [SOS, char, char, ..., EOS, PAD, ...]

    Same truncation, padding, and UNK handling as ``encode_text``.
    """
    return _encode_with_prefix(SOS, text, c2i, max_len)


def decode_ids(ids, i2c):
    """Decode token ids to string, stopping at EOS, skipping special tokens.

    PAD, SOS, UNK and the task prefix ids contribute nothing; ids missing
    from ``i2c`` contribute the empty string.
    """
    chars = []
    for i in ids:
        if i == EOS:
            break
        if i not in _SPECIAL_IDS:
            chars.append(i2c.get(i, ''))
    return ''.join(chars)


@torch.no_grad()
def greedy_decode(model, src, max_len, device, i2c):
    """
    Greedy decode a batch. Task is implicit in the src prefix token.

    Note: this calls ``model.eval()`` and does not restore the previous
    mode; call ``model.train()`` afterwards if training continues.

    Args:
        model   : CharSeq2Seq in eval mode
        src     : LongTensor (B, S) — first token is the task prefix
        max_len : int — cap on output length including SOS; must not exceed
                  the max_len the model's positional table was built with
        device  : str or torch.device
        i2c     : dict[int, str]

    Returns:
        list[str] of length B
    """
    model.eval()
    src = src.to(device)
    src_pad_mask = (src == PAD)
    memory = model.encode(src, src_pad_mask)

    B = src.size(0)
    ys = torch.full((B, 1), SOS, dtype=torch.long, device=device)
    done = torch.zeros(B, dtype=torch.bool, device=device)

    for _ in range(max_len - 1):
        T = ys.size(1)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(T, device=device)
        tgt_pad = (ys == PAD)
        out = model.decode(ys, memory, tgt_mask, tgt_pad, src_pad_mask)
        # Only the newest position is needed to pick the next token.
        next_tok = model.proj(out[:, -1]).argmax(-1)
        # Rows that already produced EOS keep receiving PAD.
        next_tok = next_tok.masked_fill(done, PAD)
        ys = torch.cat([ys, next_tok.unsqueeze(1)], dim=1)
        done = done | (next_tok == EOS)
        if done.all():
            break

    return [decode_ids(row.tolist(), i2c) for row in ys]