old-icelandic-facs2dipl2norm / model_def_multitask.py
NKCZ's picture
Initial commit
9e47315 verified
Raw
History Blame Contribute Delete
5.71 kB
"""
facs2dipl / dipl2norm — Multi-task Char-level Transformer
Model definition, encode/decode helpers, and greedy inference.
Save this file to: /content/drive/MyDrive/facs2dipl_multitask/model_def.py
Then in any notebook run: %run /content/drive/MyDrive/facs2dipl_multitask/model_def.py
"""
import math
import torch
import torch.nn as nn
# ── Special token indices (must match training) ────────────────────────────
PAD = 0  # fills sequences out to a fixed length
SOS = 1  # first token of every target sequence
EOS = 2  # closes both source and target sequences
UNK = 3  # stands in for characters missing from the vocabulary

# Task prefix token indices (placed first in the source sequence)
DIPL_IDX, NORM_IDX = 4, 5  # <DIPL>: facs→dipl task, <NORM>: dipl→norm task
# ── Model ──────────────────────────────────────────────────────────────────
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to batch-first embeddings.

    Even feature columns hold ``sin(pos * freq)`` and odd feature columns hold
    ``cos(pos * freq)``, where the frequencies fall geometrically from 1
    towards 1/10000 across the feature dimension. The table is computed once
    for ``max_len`` positions and registered as the buffer ``pe`` (not a
    trainable parameter).

    Args:
        d_model: Size of the feature (last) dimension. Odd sizes are accepted.
        max_len: Number of positions in the precomputed table. Inputs with
            more positions than this cannot be added to the table.
        dropout: Dropout probability applied after the encoding is added.
    """

    def __init__(self, d_model, max_len, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(max_len).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = pos * div
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine
        # columns; trim the angle table so the assignment shapes agree.
        # For an even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Leading singleton dimension lets the table broadcast over the batch.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings of its positions, after dropout.

        Args:
            x: Tensor whose dimension 1 indexes positions and whose last
                dimension has size ``d_model``.
        """
        return self.dropout(x + self.pe[:, :x.size(1)])
class CharSeq2Seq(nn.Module):
    """Character-level encoder-decoder Transformer.

    Source and target ids share one embedding table and one positional
    encoder. Both stacks use pre-norm layers (``norm_first=True``) followed by
    a final LayerNorm, and the decoder output is projected to vocabulary
    logits. All tensors are batch-first.

    Args:
        vocab_size: Number of token ids (special tokens included).
        d_model: Embedding / model width.
        n_heads: Attention heads per layer.
        n_enc: Number of encoder layers.
        n_dec: Number of decoder layers.
        d_ff: Feed-forward width inside each layer.
        max_len: Longest sequence the positional table supports.
        dropout: Dropout probability used throughout.
    """

    def __init__(self, vocab_size, d_model, n_heads, n_enc, n_dec, d_ff, max_len, dropout):
        super().__init__()
        self.d_model = d_model
        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=PAD)
        self.pos_enc = PositionalEncoding(d_model, max_len, dropout)
        enc_layer = nn.TransformerEncoderLayer(
            d_model, n_heads, d_ff, dropout, batch_first=True, norm_first=True,
        )
        dec_layer = nn.TransformerDecoderLayer(
            d_model, n_heads, d_ff, dropout, batch_first=True, norm_first=True,
        )
        # The nested-tensor path is never taken with norm_first=True layers;
        # asking for it only produces a warning at construction time.
        self.encoder = nn.TransformerEncoder(
            enc_layer, n_enc, norm=nn.LayerNorm(d_model), enable_nested_tensor=False,
        )
        self.decoder = nn.TransformerDecoder(dec_layer, n_dec, norm=nn.LayerNorm(d_model))
        self.proj = nn.Linear(d_model, vocab_size)
        self._init_weights()

    def _init_weights(self):
        """Xavier-initialise every weight matrix, keeping the PAD row at zero."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        # The loop above also rewrites the embedding table, including the row
        # that padding_idx had zeroed; restore it.
        with torch.no_grad():
            self.embed.weight[self.embed.padding_idx].zero_()

    def encode(self, src, src_key_padding_mask):
        """Run the encoder stack.

        Args:
            src: LongTensor (B, S) of source token ids.
            src_key_padding_mask: BoolTensor (B, S), True at PAD positions.

        Returns:
            The encoder output ("memory") for the decoder to attend over.
        """
        # Embeddings are scaled by sqrt(d_model) before positions are added.
        x = self.pos_enc(self.embed(src) * math.sqrt(self.d_model))
        return self.encoder(x, src_key_padding_mask=src_key_padding_mask)

    def decode(self, tgt, memory, tgt_mask, tgt_key_padding_mask, memory_key_padding_mask):
        """Run the decoder stack and return its hidden states (not logits).

        Args:
            tgt: LongTensor (B, T) of target token ids.
            memory: Output of :meth:`encode`.
            tgt_mask: (T, T) attention mask over target positions; either
                boolean (True = blocked) or additive float.
            tgt_key_padding_mask: BoolTensor (B, T), True at PAD positions.
            memory_key_padding_mask: BoolTensor (B, S), True at source PADs.
        """
        x = self.pos_enc(self.embed(tgt) * math.sqrt(self.d_model))
        return self.decoder(
            x, memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )

    def forward(self, src, tgt):
        """Teacher-forced pass: return logits (B, T, vocab_size) for ``tgt``.

        Args:
            src: LongTensor (B, S) of source ids, PAD-filled.
            tgt: LongTensor (B, T) of decoder input ids, PAD-filled.
        """
        src_pad_mask = (src == PAD)
        tgt_pad_mask = (tgt == PAD)
        T = tgt.size(1)
        # Boolean causal mask (True above the diagonal = may not look ahead).
        # Using the same dtype as the padding masks avoids PyTorch's
        # mismatched-mask deprecation warning.
        tgt_mask = torch.triu(
            torch.ones(T, T, dtype=torch.bool, device=tgt.device), diagonal=1,
        )
        memory = self.encode(src, src_pad_mask)
        out = self.decode(tgt, memory, tgt_mask, tgt_pad_mask, src_pad_mask)
        return self.proj(out)
# ── Helpers ────────────────────────────────────────────────────────────────
def encode_text(text, task_idx, c2i, max_len):
    """
    Turn source text into a fixed-length id list led by a task token.

    Layout: [task_token, char, char, ..., EOS, PAD, ...]

    Characters absent from ``c2i`` map to UNK. The sequence is cut to
    ``max_len - 1`` ids before EOS is appended, then PAD-filled up to
    ``max_len``.
    """
    tokens = [task_idx]
    tokens.extend(c2i.get(ch, UNK) for ch in text)
    # Reserve the final slot for EOS.
    tokens = tokens[:max_len - 1]
    tokens.append(EOS)
    missing = max_len - len(tokens)
    tokens.extend([PAD] * missing)
    return tokens
def encode_target(text, c2i, max_len):
    """
    Turn target text into a fixed-length id list.

    Layout: [SOS, char, char, ..., EOS, PAD, ...]

    Characters absent from ``c2i`` map to UNK. Everything from index
    ``max_len - 1`` onward is dropped before EOS is appended, and the result
    is PAD-filled up to ``max_len``.
    """
    sequence = [SOS, *(c2i.get(ch, UNK) for ch in text)]
    # Drop the overflow in place so that EOS always fits.
    del sequence[max_len - 1:]
    sequence.append(EOS)
    return sequence + [PAD] * (max_len - len(sequence))
def decode_ids(ids, i2c):
    """
    Convert token ids back to text.

    Reading stops at the first EOS. Special tokens (PAD, SOS, UNK and the two
    task prefixes) are dropped, and ids missing from ``i2c`` contribute an
    empty string.
    """
    specials = frozenset((PAD, SOS, EOS, UNK, DIPL_IDX, NORM_IDX))
    pieces = []
    for token in ids:
        if token == EOS:
            return ''.join(pieces)
        if token in specials:
            continue
        pieces.append(i2c.get(token, ''))
    return ''.join(pieces)
@torch.no_grad()
def greedy_decode(model, src, max_len, device, i2c):
    """
    Greedy decode a batch. Task is implicit in the src prefix token.

    Every row starts from SOS and is extended one token at a time with the
    arg-max prediction. Rows that have already produced EOS are fed PAD from
    then on, and decoding stops early once every row has finished.

    Note: this calls ``model.eval()`` and does not restore the previous mode;
    call ``model.train()`` afterwards if training continues.

    Args:
        model   : CharSeq2Seq (switched to eval mode here)
        src     : LongTensor (B, S) — first token is the task prefix
        max_len : int, maximum output length in tokens including SOS; the
                  decoder input is run through the model's positional
                  encoding, so it should stay within the model's max_len
        device  : str or torch.device
        i2c     : dict[int, str]
    Returns:
        list[str] of length B
    """
    model.eval()
    src = src.to(device)
    src_pad_mask = (src == PAD)
    memory = model.encode(src, src_pad_mask)
    B = src.size(0)
    ys = torch.full((B, 1), SOS, dtype=torch.long, device=device)
    done = torch.zeros(B, dtype=torch.bool, device=device)

    steps = max(max_len - 1, 0)
    # Build the causal mask once and slice it per step instead of allocating
    # a new one every iteration. It is boolean (True = blocked) so that it
    # matches the dtype of the padding masks; mixing float and bool masks is
    # deprecated in PyTorch.
    causal = torch.triu(
        torch.ones(steps, steps, dtype=torch.bool, device=device), diagonal=1,
    )
    for _ in range(steps):
        T = ys.size(1)
        tgt_pad = (ys == PAD)
        out = model.decode(ys, memory, causal[:T, :T], tgt_pad, src_pad_mask)
        # Only the last position's logits decide the next token.
        next_tok = model.proj(out[:, -1]).argmax(-1)
        # Finished rows keep receiving PAD so they stay inert.
        next_tok = next_tok.masked_fill(done, PAD)
        ys = torch.cat([ys, next_tok.unsqueeze(1)], dim=1)
        done = done | (next_tok == EOS)
        if done.all():
            break
    return [decode_ids(row.tolist(), i2c) for row in ys]