"""
facs2dipl / dipl2norm — Multi-task Char-level Transformer
Model definition, encode/decode helpers, and greedy inference.
Save this file to: /content/drive/MyDrive/facs2dipl_multitask/model_def.py
Then in any notebook run: %run /content/drive/MyDrive/facs2dipl_multitask/model_def.py
"""
import math
import torch
import torch.nn as nn
# ── Special token indices (must match training) ─────────────────────────────
PAD = 0  # padding id; also the Embedding padding_idx
SOS = 1  # first token of every target / decoder input
EOS = 2  # appended after the text; decoding stops here
UNK = 3  # substituted for characters missing from the vocabulary
# Task prefix token indices (placed as the first source token)
DIPL_IDX = 4  # <DIPL> — facs→dipl task
NORM_IDX = 5  # <NORM> — dipl→norm task
# ── Model ───────────────────────────────────────────────────────────────────
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to embeddings, then apply dropout.

    The table is precomputed for ``max_len`` positions and kept as a
    non-trainable buffer ``pe`` of shape ``(1, max_len, d_model)``; even
    columns hold sines, odd columns cosines. Works for odd ``d_model`` too.
    """

    def __init__(self, d_model, max_len, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        # One frequency per sin/cos column pair: 10000^(-2i / d_model).
        div = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        # With an odd d_model there is one fewer cosine column than sine
        # column, so only the first d_model // 2 frequencies are used here
        # (for even d_model this is all of them).
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])
        # Buffer, not Parameter: stored in state_dict and moved by .to(),
        # but never updated by the optimiser.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + pe)`` for the first ``x.size(1)`` positions.

        ``x`` is expected batch-first, ``(B, T, d_model)`` with ``T <= max_len``.
        """
        return self.dropout(x + self.pe[:, :x.size(1)])
class CharSeq2Seq(nn.Module):
    """Character-level encoder-decoder Transformer.

    Source and target ids share one embedding table (``embed``). Embeddings
    are scaled by sqrt(d_model), given sinusoidal positions, and fed to
    pre-LayerNorm (``norm_first=True``), batch-first encoder/decoder stacks,
    each followed by a final LayerNorm. ``proj`` maps decoder states to
    vocabulary logits.

    Args:
        vocab_size: number of token ids (special tokens + characters).
        d_model: embedding / hidden width.
        n_heads: attention heads per layer.
        n_enc, n_dec: number of encoder / decoder layers.
        d_ff: feed-forward width inside each layer.
        max_len: longest sequence the positional table covers.
        dropout: dropout probability used throughout.
    """

    def __init__(self, vocab_size, d_model, n_heads, n_enc, n_dec, d_ff, max_len, dropout):
        super().__init__()
        self.d_model = d_model
        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=PAD)
        self.pos_enc = PositionalEncoding(d_model, max_len, dropout)
        enc_layer = nn.TransformerEncoderLayer(d_model, n_heads, d_ff, dropout, batch_first=True, norm_first=True)
        dec_layer = nn.TransformerDecoderLayer(d_model, n_heads, d_ff, dropout, batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(enc_layer, n_enc, norm=nn.LayerNorm(d_model))
        self.decoder = nn.TransformerDecoder(dec_layer, n_dec, norm=nn.LayerNorm(d_model))
        self.proj = nn.Linear(d_model, vocab_size)
        self._init_weights()

    def _init_weights(self):
        """Xavier-initialise every weight matrix, keeping the PAD embedding at zero."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        # The loop above also overwrote the embedding table, including the
        # PAD row that nn.Embedding initialises to zeros. padding_idx stops
        # that row from ever receiving gradients, so without this reset PAD
        # would keep a random, never-trained vector.
        # (Loading a checkpoint with load_state_dict overrides this anyway.)
        if self.embed.padding_idx is not None:
            with torch.no_grad():
                self.embed.weight[self.embed.padding_idx].zero_()

    @staticmethod
    def _causal_mask(size, device):
        """Return a ``(size, size)`` bool mask, True where a query would see a later position."""
        return torch.triu(torch.ones(size, size, dtype=torch.bool, device=device), diagonal=1)

    def encode(self, src, src_key_padding_mask):
        """Embed ``src`` and run the encoder; True entries of the padding mask are ignored as keys."""
        x = self.pos_enc(self.embed(src) * math.sqrt(self.d_model))
        return self.encoder(x, src_key_padding_mask=src_key_padding_mask)

    def decode(self, tgt, memory, tgt_mask, tgt_key_padding_mask, memory_key_padding_mask):
        """Embed ``tgt`` and run the decoder over ``memory``; returns hidden states, not logits."""
        x = self.pos_enc(self.embed(tgt) * math.sqrt(self.d_model))
        return self.decoder(
            x, memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )

    def forward(self, src, tgt):
        """Return logits ``(B, T, vocab_size)`` for teacher-forced decoding of ``tgt``.

        ``tgt`` is presumably the decoder input (SOS-prefixed target,
        shifted right) — confirm against the training loop.
        """
        src_pad_mask = (src == PAD)
        tgt_pad_mask = (tgt == PAD)
        # Bool causal mask, so its type matches the bool padding masks.
        # A float -inf mask mixed with bool padding masks triggers PyTorch's
        # "mismatched key_padding_mask and attn_mask" deprecation warning;
        # the masked positions are identical either way.
        tgt_mask = self._causal_mask(tgt.size(1), src.device)
        memory = self.encode(src, src_pad_mask)
        out = self.decode(tgt, memory, tgt_mask, tgt_pad_mask, src_pad_mask)
        return self.proj(out)
# ── Helpers ─────────────────────────────────────────────────────────────────
def encode_text(text, task_idx, c2i, max_len):
    """
    Convert source text into a fixed-length id list led by a task token.

    Layout: [task_token, char, char, ..., EOS, PAD, ...]
    The task token plus character ids are cut to ``max_len - 1`` entries,
    EOS is appended, and the result is right-padded with PAD up to
    ``max_len`` (assumes ``max_len >= 1``). Characters missing from ``c2i``
    become UNK.
    """
    body = [task_idx]
    body.extend(c2i.get(ch, UNK) for ch in text)
    seq = body[:max_len - 1]
    seq.append(EOS)
    n_pad = max_len - len(seq)
    return seq + [PAD] * n_pad
def encode_target(text, c2i, max_len):
    """
    Convert target text into a fixed-length id list.

    Layout: [SOS, char, char, ..., EOS, PAD, ...]
    SOS plus character ids are cut to ``max_len - 1`` entries, EOS is
    appended, and PAD fills the rest up to ``max_len``. Unknown characters
    become UNK.
    """
    tokens = [SOS, *(c2i.get(ch, UNK) for ch in text)]
    tokens = tokens[:max_len - 1]
    tokens.append(EOS)
    tokens.extend([PAD] * (max_len - len(tokens)))
    return tokens
def decode_ids(ids, i2c):
    """
    Turn token ids back into a string.

    Reading stops at the first EOS. PAD, SOS, UNK and the task-prefix ids
    are dropped, and ids not present in ``i2c`` contribute nothing.
    """
    special = frozenset((PAD, SOS, EOS, UNK, DIPL_IDX, NORM_IDX))
    pieces = []
    for tok in ids:
        if tok == EOS:
            break
        if tok in special:
            continue
        pieces.append(i2c.get(tok, ''))
    return ''.join(pieces)
@torch.no_grad()
def greedy_decode(model, src, max_len, device, i2c):
    """
    Greedy-decode a batch. The task is implicit in the src prefix token.

    Args:
        model   : CharSeq2Seq. It is put in eval mode for the call, and its
                  previous train/eval mode is restored afterwards.
        src     : LongTensor (B, S) — first token is the task prefix
        max_len : int; the decoded sequence (including SOS) is at most this
                  long, i.e. at most max_len - 1 tokens are generated
        device  : str or torch.device
        i2c     : dict[int, str]
    Returns:
        list[str] of length B
    """
    # Remember the caller's mode. Previously the model was left in eval mode,
    # which silently disabled dropout if this ran inside a training loop.
    was_training = model.training
    model.eval()
    try:
        src = src.to(device)
        src_pad_mask = (src == PAD)
        memory = model.encode(src, src_pad_mask)
        batch = src.size(0)
        ys = torch.full((batch, 1), SOS, dtype=torch.long, device=device)
        done = torch.zeros(batch, dtype=torch.bool, device=device)
        for _ in range(max_len - 1):
            T = ys.size(1)
            # Bool causal mask (True = future position, blocked) to match the
            # bool padding masks and avoid PyTorch's mismatched-mask warning.
            tgt_mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=1)
            tgt_pad = (ys == PAD)
            out = model.decode(ys, memory, tgt_mask, tgt_pad, src_pad_mask)
            next_tok = model.proj(out[:, -1]).argmax(-1)
            # Sequences that already emitted EOS only receive PAD from here on.
            next_tok = next_tok.masked_fill(done, PAD)
            ys = torch.cat([ys, next_tok.unsqueeze(1)], dim=1)
            done = done | (next_tok == EOS)
            if done.all():
                break
    finally:
        model.train(was_training)
    return [decode_ids(row.tolist(), i2c) for row in ys]