| """ |
| facs2dipl / dipl2norm — Multi-task Char-level Transformer |
| Model definition, encode/decode helpers, and greedy inference. |
| |
| Save this file to: /content/drive/MyDrive/facs2dipl_multitask/model_def.py |
| Then in any notebook run: %run /content/drive/MyDrive/facs2dipl_multitask/model_def.py |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
|
|
| |
# Special token ids occupying the lowest slots of the shared vocabulary.
PAD, SOS, EOS, UNK = range(4)

# Task-prefix ids: encode_text() puts one of these at position 0 of the
# encoder input. Presumably DIPL_IDX selects facs2dipl and NORM_IDX selects
# dipl2norm (inferred from the names; confirm against the data pipeline).
DIPL_IDX, NORM_IDX = 4, 5
|
|
|
|
| |
|
|
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first input, then apply dropout.

    Args:
        d_model: Feature size of the inputs. May be odd; in that case the last
            sine column simply has no cosine partner.
        max_len: Longest sequence length ``forward`` can handle.
        dropout: Dropout probability applied after the addition.
    """

    def __init__(self, d_model, max_len, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # Frequencies 1 / 10000^(2i / d_model), one per (sin, cos) column pair.
        div = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(pos * div)
        # Odd d_model has ceil(d/2) sine columns but only floor(d/2) cosine
        # columns, so only the matching number of frequencies is used here.
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])
        # Registered as a buffer: follows .to(device) and is stored in the
        # state_dict, but is not a trainable parameter.
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        """Return ``dropout(x + pe)`` for ``x`` of shape (B, T, d_model) with T <= max_len."""
        return self.dropout(x + self.pe[:, :x.size(1)])
|
|
|
|
class CharSeq2Seq(nn.Module):
    """Pre-LayerNorm encoder-decoder Transformer over one shared character vocabulary.

    Source and target ids go through the same embedding table (scaled by
    sqrt(d_model)) and the same sinusoidal positional encoding. Decoder states
    are projected to vocabulary logits. The model receives no separate task
    argument, so any task conditioning must be carried by the source ids
    (e.g. a task-prefix token).

    Args:
        vocab_size: Number of token ids, special tokens included.
        d_model: Embedding / hidden size (PyTorch requires it divisible by n_heads).
        n_heads: Attention heads per layer.
        n_enc: Number of encoder layers.
        n_dec: Number of decoder layers.
        d_ff: Hidden size of each feed-forward sub-layer.
        max_len: Longest source or target length the positional encoding supports.
        dropout: Dropout probability used throughout.
    """

    def __init__(self, vocab_size, d_model, n_heads, n_enc, n_dec, d_ff, max_len, dropout):
        super().__init__()
        self.d_model = d_model
        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=PAD)
        self.pos_enc = PositionalEncoding(d_model, max_len, dropout)
        enc_layer = nn.TransformerEncoderLayer(
            d_model, n_heads, d_ff, dropout, batch_first=True, norm_first=True
        )
        dec_layer = nn.TransformerDecoderLayer(
            d_model, n_heads, d_ff, dropout, batch_first=True, norm_first=True
        )
        # Pre-LN layers leave the residual stream un-normalised, so each stack
        # ends with a LayerNorm. The nested-tensor fast path cannot be used with
        # norm_first=True; disabling it explicitly only silences PyTorch's warning.
        self.encoder = nn.TransformerEncoder(
            enc_layer, n_enc, norm=nn.LayerNorm(d_model), enable_nested_tensor=False
        )
        self.decoder = nn.TransformerDecoder(dec_layer, n_dec, norm=nn.LayerNorm(d_model))
        self.proj = nn.Linear(d_model, vocab_size)
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform init for every parameter with >1 dim (biases and norm weights keep defaults)."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        # The loop above also overwrote the embedding row for padding_idx, which
        # nn.Embedding initialises to zeros and never updates. Re-zero it so PAD
        # really embeds to the zero vector. zero_() consumes no RNG, so all
        # other weights are identical to before.
        with torch.no_grad():
            self.embed.weight[self.embed.padding_idx].zero_()

    @staticmethod
    def _causal_mask(size, device):
        """Boolean (size, size) look-ahead mask; True marks future positions that may not be attended."""
        return torch.triu(torch.ones(size, size, dtype=torch.bool, device=device), diagonal=1)

    def encode(self, src, src_key_padding_mask):
        """Embed ``src`` ids (B, S) and run the encoder; returns memory (B, S, d_model)."""
        x = self.pos_enc(self.embed(src) * math.sqrt(self.d_model))
        return self.encoder(x, src_key_padding_mask=src_key_padding_mask)

    def decode(self, tgt, memory, tgt_mask, tgt_key_padding_mask, memory_key_padding_mask):
        """Embed ``tgt`` ids (B, T) and run the decoder over ``memory``.

        Returns hidden states (B, T, d_model), not logits; apply ``self.proj``
        to obtain vocabulary scores.
        """
        x = self.pos_enc(self.embed(tgt) * math.sqrt(self.d_model))
        return self.decoder(
            x, memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )

    def forward(self, src, tgt):
        """Run encoder and decoder over whole sequences; returns logits (B, T, vocab_size).

        PAD ids in ``src`` and ``tgt`` are masked as attention keys, and ``tgt``
        is causally masked so position t only sees positions <= t.
        """
        src_pad_mask = src == PAD
        tgt_pad_mask = tgt == PAD
        # Boolean causal mask, matching the dtype of the boolean padding masks.
        # Mixing a float attn_mask with a bool key_padding_mask triggers
        # PyTorch's "mismatched masks" deprecation warning.
        tgt_mask = self._causal_mask(tgt.size(1), tgt.device)
        memory = self.encode(src, src_pad_mask)
        out = self.decode(tgt, memory, tgt_mask, tgt_pad_mask, src_pad_mask)
        return self.proj(out)
|
|
|
|
| |
|
|
def encode_text(text, task_idx, c2i, max_len):
    """Turn a source string into a fixed-length encoder input.

    The result is ``[task_idx, c2i[ch]..., EOS, PAD...]`` and has exactly
    ``max_len`` ids whenever ``max_len >= 1``. Characters missing from ``c2i``
    map to UNK. When the prefix plus text is too long, the tail is cut so that
    EOS still fits; with ``max_len == 1`` only EOS remains.
    """
    body = [task_idx, *(c2i.get(ch, UNK) for ch in text)]
    seq = body[: max_len - 1]
    seq.append(EOS)
    return seq + [PAD] * (max_len - len(seq))
|
|
|
|
def encode_target(text, c2i, max_len):
    """Turn a target string into a fixed-length decoder sequence.

    Layout is ``[SOS, c2i[ch]..., EOS, PAD...]`` with exactly ``max_len`` ids
    when ``max_len >= 1``. Unknown characters map to UNK, and over-long text is
    truncated so EOS always fits.
    """
    keep = max_len - 1
    tokens = ([SOS] + [c2i.get(ch, UNK) for ch in text])[:keep]
    tokens.append(EOS)
    tokens.extend([PAD] * (max_len - len(tokens)))
    return tokens
|
|
|
|
def decode_ids(ids, i2c):
    """Convert a sequence of token ids back into text.

    Reading stops at the first EOS. PAD, SOS, UNK and the task-prefix ids are
    dropped (so a predicted UNK leaves no trace in the output), and ids absent
    from ``i2c`` contribute an empty string.
    """
    ignored = frozenset((PAD, SOS, EOS, UNK, DIPL_IDX, NORM_IDX))

    def _before_eos(seq):
        # Yield ids up to, but not including, the first EOS.
        for tok in seq:
            if tok == EOS:
                return
            yield tok

    return ''.join(i2c.get(tok, '') for tok in _before_eos(ids) if tok not in ignored)
|
|
|
|
@torch.no_grad()
def greedy_decode(model, src, max_len, device, i2c):
    """
    Greedy decode a batch. Task is implicit in the src prefix token.

    At each step the arg-max token is appended for every row. Rows that have
    emitted EOS are padded with PAD from then on, and decoding stops early
    once every row has finished.

    Args:
        model   : CharSeq2Seq. It is switched to eval mode for decoding, and
                  its previous train/eval mode is restored afterwards.
        src     : LongTensor (B, S) — first token is the task prefix
        max_len : int — total target length including the leading SOS, so at
                  most max_len - 1 tokens are generated
        device  : str or torch.device
        i2c     : dict[int, str]

    Returns:
        list[str] of length B
    """
    # Remember the caller's mode. Leaving a model in eval mode would silently
    # disable dropout if this is called for validation mid-training.
    was_training = model.training
    model.eval()
    try:
        src = src.to(device)
        src_pad_mask = src == PAD
        memory = model.encode(src, src_pad_mask)

        batch = src.size(0)
        ys = torch.full((batch, 1), SOS, dtype=torch.long, device=device)
        done = torch.zeros(batch, dtype=torch.bool, device=device)

        for _ in range(max_len - 1):
            steps = ys.size(1)
            # Boolean causal mask (True = blocked), matching the boolean padding
            # masks; mixing float and bool masks triggers a PyTorch deprecation warning.
            tgt_mask = torch.triu(
                torch.ones(steps, steps, dtype=torch.bool, device=device), diagonal=1
            )
            # PAD tokens appended to finished rows are masked out as keys.
            out = model.decode(ys, memory, tgt_mask, ys == PAD, src_pad_mask)
            next_tok = model.proj(out[:, -1]).argmax(-1)
            next_tok = next_tok.masked_fill(done, PAD)
            ys = torch.cat([ys, next_tok.unsqueeze(1)], dim=1)
            done |= next_tok == EOS
            if done.all():
                break
    finally:
        model.train(was_training)

    # Single device-to-host transfer for the whole batch instead of one per row.
    return [decode_ids(row, i2c) for row in ys.tolist()]
|
|