""" Medieval Latin NER - Custom Span-NER Architecture ============================================================================= Core architecture for the Span-NER model utilizing a bi-encoder approach with a frozen BGE-M3 semantic label space and XLM-RoBERTa-Large text encoder. """ import re import torch import torch.nn as nn import torch.nn.functional as F from transformers import AutoModel # --------------------------------------------------------------------------- # 1. CONFIGURATION # --------------------------------------------------------------------------- class Config: TEXT_MODEL = "FacebookAI/xlm-roberta-large" TEXT_DIM = 1024 LABEL_MODEL = "BAAI/bge-m3" LABEL_DIM = 1024 MAX_SPAN_WIDTH = 80 WIDTH_EMB_DIM = 64 SPAN_HIDDEN = 512 ATTENTION_HEADS = 4 MAX_SEQ_LEN = 512 PREDICT_TEMP = 1.35 # --------------------------------------------------------------------------- # 2. LABEL DICTIONARY & PROMPTS # --------------------------------------------------------------------------- LABEL_DICT = { "PER": "individual person name without any titles or roles, strictly the given name or family name", "ACTOR": "full noun phrase referring to a person including their name plus noble title, profession, geographic origin, or social status", "TITLE": "social rank, noble title, ecclesiastical office, profession, or papal rank such as comes, abbas, episcopus", "REL": "word or phrase indicating family, kinship, marriage, or social relationship like filius, uxor, frater", "LOC": "geographical place, settlement, city, diocese, region, or named territory", "INS": "monastery, abbey, church, cell, or religious order functioning as a corporate and legal body", "NAT": "natural landscape feature such as a river, stream, forest, mountain, or valley", "EST": "short physical plot of land, estate, farm, meadows, woods, vineyards, or courtyards", "PROP": "detailed boundary description of a property, grange, estate, or island including past owners, movables, and immovables", "LEG": "legal 
clause declaring rights, conditions, penalties, permissions, or papal commands", "TRANS": "verb or phrase denoting a core transaction, confirmation, transfer, sale, gift, or donation", "TIM": "time period, duration, general dating formula, indiction, or papal/royal regnal year", "DAT": "specific calendar date, precise year of incarnation often starting with Anno or Datum, or named liturgical feast day", "MON": "money, currency, coin, or monetary value such as libra, solidus, denarius, uncia, or marca", "TAX": "customary toll, legal tax, tithe, exaction, lucrum camere, or tribute paid to an authority", "COM": "harvested crops, food, physical goods, salt, wine, wax, gold, wood, or animals traded or given", "NUM": "number written as a word or roman numeral, including fractions and quantities", "MEA": "unit of measurement for land, volume, or weight such as mansus, carratas, aratrum, or talentum", "RELIC": "holy relic, cross, altar, or sacred object of veneration within a church", } LABEL_KEYS = list(LABEL_DICT.keys()) LABEL_DESCS = list(LABEL_DICT.values()) LABEL2ID = {k: i for i, k in enumerate(LABEL_KEYS)} ID2LABEL = {i: k for k, i in LABEL2ID.items()} NUM_LABELS = len(LABEL_DICT) def char_tokenize(text): return [{"token": m.group(), "start": m.start(), "end": m.end()} for m in re.finditer(r'\w+|[^\w\s]', text)] # --------------------------------------------------------------------------- # 3. 
# 3. MODEL ARCHITECTURE
# ---------------------------------------------------------------------------


class SpanRepLayer(nn.Module):
    """Build a fixed-size vector for every candidate span.

    Each span vector is the concatenation of
      * the hidden state at the span's start position,
      * the hidden state at the span's end position,
      * ``num_heads`` attention-pooled summaries of the positions inside the
        span (one learned scoring head each), and
      * a learned embedding of the span width.

    Args:
        hidden: Size of the encoder hidden states.
        max_span_width: Largest width with its own embedding; larger widths
            are clamped to this value.
        width_emb_dim: Size of the width embedding.
        num_heads: Number of attention-pooling heads.
    """

    def __init__(self, hidden, max_span_width, width_emb_dim, num_heads=4):
        super().__init__()
        self.max_span_width = max_span_width
        self.num_heads = num_heads
        # One embedding per clamped width value in [0, max_span_width].
        self.width_emb = nn.Embedding(max_span_width + 1, width_emb_dim)
        # Per-position scorer producing one attention logit per head.
        self.att_query = nn.Sequential(
            nn.Linear(hidden, hidden // 2),
            nn.GELU(),
            nn.Linear(hidden // 2, num_heads),
        )
        # start + end + pooled heads + width embedding
        self.span_dim = 2 * hidden + (num_heads * hidden) + width_emb_dim

    def forward(self, seq_out, spans):
        """Return span representations of shape (B, S, span_dim).

        Args:
            seq_out: Encoder output, (B, L, H).
            spans: Long tensor (B, S, 3) of ``[start_pos, end_pos, width]``
                with inclusive positions into ``seq_out``. Spans must satisfy
                start <= end; otherwise every attention logit of the span is
                -inf and the softmax yields NaN.
        """
        B, S, _ = spans.shape
        L = seq_out.size(1)
        H = seq_out.size(-1)
        device = seq_out.device

        # Build the batch index on the data's device instead of the CPU so
        # advanced indexing never mixes index tensors from different devices.
        batch_idx = torch.arange(B, device=device).unsqueeze(1)
        h_start = seq_out[batch_idx, spans[:, :, 0]]
        h_end = seq_out[batch_idx, spans[:, :, 1]]

        width = spans[:, :, 2].clamp(0, self.max_span_width)
        w_emb = self.width_emb(width)

        # mask[b, s, l] is True when position l lies inside span s.
        idx = torch.arange(L, device=device).view(1, 1, L)
        mask = (idx >= spans[:, :, 0:1]) & (idx <= spans[:, :, 1:2])

        att_logits = self.att_query(seq_out)  # (B, L, heads)
        att_logits = att_logits.unsqueeze(1).expand(B, S, L, self.num_heads)
        mask_expanded = mask.unsqueeze(-1).expand(-1, -1, -1, self.num_heads)
        att_logits = att_logits.masked_fill(~mask_expanded, float("-inf"))
        att_weights = F.softmax(att_logits, dim=2)  # normalise over positions

        h_pool = torch.einsum("bslm,blh->bsmh", att_weights, seq_out)
        h_pool = h_pool.reshape(B, S, self.num_heads * H)

        return torch.cat([h_start, h_end, h_pool, w_emb], dim=-1)


class SpanNERModel(nn.Module):
    """Bi-encoder span classifier over the labels of LABEL_DICT.

    Candidate spans are encoded by the text encoder plus SpanRepLayer and
    projected by ``span_proj``. Label descriptions are encoded by the label
    encoder (mean-pooled and cached) and projected by ``label_proj``. A
    span/label logit is the cosine similarity of the two projections times
    ``exp(logit_scale)``.

    Args:
        cfg: Configuration object exposing the attributes defined on Config.
    """

    def __init__(self, cfg: "Config"):
        super().__init__()
        self.cfg = cfg
        self.text_enc = AutoModel.from_pretrained(cfg.TEXT_MODEL, add_pooling_layer=False)
        self.label_enc = AutoModel.from_pretrained(cfg.LABEL_MODEL)
        self.span_layer = SpanRepLayer(
            cfg.TEXT_DIM, cfg.MAX_SPAN_WIDTH, cfg.WIDTH_EMB_DIM, num_heads=cfg.ATTENTION_HEADS
        )
        self.label_proj = nn.Sequential(
            nn.Linear(cfg.LABEL_DIM, cfg.SPAN_HIDDEN),
            nn.GELU(),
            nn.LayerNorm(cfg.SPAN_HIDDEN),
            nn.Linear(cfg.SPAN_HIDDEN, cfg.SPAN_HIDDEN),
        )
        self.span_proj = nn.Sequential(
            nn.Linear(self.span_layer.span_dim, cfg.SPAN_HIDDEN),
            nn.GELU(),
            nn.LayerNorm(cfg.SPAN_HIDDEN),
            nn.Dropout(0.2),
            nn.Linear(cfg.SPAN_HIDDEN, cfg.SPAN_HIDDEN),
        )
        # Similarity scale in log space; predict() uses exp(), clamped at 120.
        self.logit_scale = nn.Parameter(torch.tensor(1.0))
        # Lazily built cache of normalised label-description embeddings, one
        # row per label id (see _build_label_cache).
        self._raw_label_embs = None

    @torch.no_grad()
    def _build_label_cache(self, label_tokenizer, device):
        """Encode every label description and store it in ``_raw_label_embs``.

        Each entry of LABEL_DESCS is tokenised (at most 128 tokens), run
        through the label encoder, mean-pooled over non-padding tokens and
        L2-normalised. Row i corresponds to label id i.
        """
        enc = label_tokenizer(
            LABEL_DESCS, padding=True, truncation=True, max_length=128, return_tensors="pt"
        ).to(device)
        out = self.label_enc(**enc).last_hidden_state
        mask = enc["attention_mask"].unsqueeze(-1).float()
        pooled = F.normalize((out * mask).sum(1) / mask.sum(1), dim=-1, eps=1e-8)
        self._raw_label_embs = pooled.detach()

    @staticmethod
    def _same_device(a, b):
        """Return True when *a* and *b* (devices or device strings) match.

        An unset device index (e.g. ``"cuda"``) matches any index of the same
        device type, so ``"cuda"`` and ``cuda:0`` compare equal, and a plain
        string compares equal to the corresponding ``torch.device``.
        """
        a, b = torch.device(a), torch.device(b)
        if a.type != b.type:
            return False
        return a.index is None or b.index is None or a.index == b.index

    @staticmethod
    def _subword_bounds(word_ids):
        """Map each word index to its first and last sub-word position.

        ``None`` entries (positions not tied to a word) are ignored.
        """
        first_sw, last_sw = {}, {}
        for sw_idx, w_idx in enumerate(word_ids):
            if w_idx is None:
                continue
            first_sw.setdefault(w_idx, sw_idx)
            last_sw[w_idx] = sw_idx
        return first_sw, last_sw

    @staticmethod
    def _covered_word_count(word_ids, n_words, max_seq_len):
        """Return ``(covered, truncated)`` for one encoded chunk of *n_words*.

        If the encoding is shorter than *max_seq_len*, no truncation happened
        and all words are covered. Otherwise trailing words may have been cut
        and the last surviving word may be only partially encoded, so it is
        excluded; at least one word is always reported so callers progress.
        """
        if len(word_ids) < max_seq_len:
            return n_words, False
        last_word = max((w for w in word_ids if w is not None), default=0)
        return max(1, last_word), True

    @staticmethod
    def _resolve_candidates(candidates, flat_ner):
        """Deduplicate and rank ``(start_word, end_word, label, score)`` tuples.

        Duplicates of the same (start, end, label) keep their highest score.
        The result is sorted by descending score (stable). With *flat_ner*,
        candidates are kept greedily and any candidate sharing a word with an
        already kept one is dropped.
        """
        best = {}
        for ws, we, label, score in candidates:
            key = (ws, we, label)
            if key not in best or score > best[key]:
                best[key] = score
        ranked = sorted(
            ((ws, we, label, score) for (ws, we, label), score in best.items()),
            key=lambda cand: -cand[3],
        )
        taken, kept = set(), []
        for ws, we, label, score in ranked:
            covered = set(range(ws, we + 1))
            if flat_ner:
                if covered & taken:
                    continue
                taken |= covered
            kept.append((ws, we, label, score))
        return kept

    def _score_spans(self, seq_out, span_list, label_norm, scale, device, batch_size=4096):
        """Return a CPU tensor (n_spans, n_labels) of sigmoid span/label scores.

        Spans are processed *batch_size* at a time to bound the size of the
        attention-pooling tensors built by SpanRepLayer.
        """
        pieces = []
        for i in range(0, len(span_list), batch_size):
            batch = torch.tensor(span_list[i:i + batch_size], dtype=torch.long).unsqueeze(0).to(device)
            span_feat = F.normalize(self.span_proj(self.span_layer(seq_out, batch)), p=2, dim=-1)
            logits = torch.einsum("bsd,ld->bsl", span_feat, label_norm) * scale
            pieces.append(logits.squeeze(0).cpu())
        return torch.sigmoid(torch.cat(pieces, dim=0) / self.cfg.PREDICT_TEMP)

    def _chunk_candidates(self, enc, word_ids, covered, offset, label_norm, scale, threshold, device):
        """Score the spans of one chunk and return those at or above *threshold*.

        Spans start and end within the first *covered* words of the chunk,
        are at most ``cfg.MAX_SPAN_WIDTH`` words wide, and are skipped when a
        boundary word has no sub-word in the encoding. Returns
        ``(start_word, end_word, label, score)`` tuples in (span, label)
        order; *offset* turns chunk-local word indices into document ones.
        """
        first_sw, last_sw = self._subword_bounds(word_ids)
        max_width = self.cfg.MAX_SPAN_WIDTH
        span_list, span_word_bounds = [], []
        for ws in range(covered):
            if ws not in first_sw:
                continue
            for we in range(ws, min(ws + max_width, covered)):
                if we not in last_sw:
                    continue
                span_list.append([first_sw[ws], last_sw[we], we - ws + 1])
                span_word_bounds.append((offset + ws, offset + we))
        if not span_list:
            return []

        seq_out = self.text_enc(
            input_ids=enc["input_ids"].to(device),
            attention_mask=enc["attention_mask"].to(device),
        ).last_hidden_state
        scores = self._score_spans(seq_out, span_list, label_norm, scale, device)

        # Compare in float64 so the cut-off matches a Python-float comparison
        # `score >= threshold` exactly; a float32 comparison could round the
        # threshold first.
        rows, cols = (scores.to(torch.float64) >= threshold).nonzero(as_tuple=True)
        values = scores[rows, cols].tolist()
        return [
            (*span_word_bounds[si], ID2LABEL[li], score)
            for si, li, score in zip(rows.tolist(), cols.tolist(), values)
        ]

    def predict(self, text, label_tokenizer, text_tokenizer, threshold, flat_ner, device):
        """Extract labelled entity spans from raw *text*.

        Words come from char_tokenize. Long texts are processed in
        overlapping word windows. When a window's sub-words exceed
        ``cfg.MAX_SEQ_LEN`` and the tokenizer truncates it, the next window
        starts relative to the words that were actually encoded, so no word
        is skipped.

        Args:
            text: Input string.
            label_tokenizer: Tokenizer for the label encoder; used only when
                the label-embedding cache has to be (re)built.
            text_tokenizer: Tokenizer for the text encoder, called with
                ``is_split_into_words=True``; its output's ``word_ids()`` is
                used, which requires a fast tokenizer.
            threshold: Minimum sigmoid score for a (span, label) pair.
            flat_ner: If True, keep spans greedily by score and drop any span
                overlapping a kept one; if False, nested/overlapping spans are
                all returned.
            device: ``torch.device`` or device string to run inference on.

        Returns:
            List of dicts with keys ``label``, ``score`` (rounded to 4
            decimals), ``text``, ``start_char``, ``end_char``, ``start_word``
            and ``end_word`` (inclusive word indices), in descending score
            order.

        Side effects:
            Switches the module to eval mode and may rebuild the cached label
            embeddings.
        """
        self.eval()
        tokens_info = char_tokenize(text)
        all_tokens = [t["token"] for t in tokens_info]
        n_total = len(all_tokens)
        if n_total == 0:
            return []

        window = self.cfg.MAX_SEQ_LEN - 2  # leave room for special tokens
        stride = max(1, self.cfg.MAX_SEQ_LEN - 50)
        overlap = max(0, window - stride)

        all_candidates = []
        with torch.no_grad():
            if self._raw_label_embs is None or not self._same_device(self._raw_label_embs.device, device):
                self._build_label_cache(label_tokenizer, device)
            label_feat = self.label_proj(self._raw_label_embs.to(device))
            # Loop invariants: normalised label features and similarity scale.
            label_norm = F.normalize(label_feat, p=2, dim=-1)
            scale = self.logit_scale.exp().clamp(max=120.0)

            start_idx = 0
            while start_idx < n_total:
                tokens = all_tokens[start_idx:start_idx + window]
                enc = text_tokenizer(
                    tokens,
                    is_split_into_words=True,
                    max_length=self.cfg.MAX_SEQ_LEN,
                    truncation=True,
                    padding=False,
                    return_tensors="pt",
                ).to(device)
                word_ids = enc.word_ids(batch_index=0)
                covered, truncated = self._covered_word_count(word_ids, len(tokens), self.cfg.MAX_SEQ_LEN)
                all_candidates.extend(
                    self._chunk_candidates(enc, word_ids, covered, start_idx, label_norm, scale, threshold, device)
                )
                if truncated:
                    # Advance only past the words actually encoded, keeping an
                    # overlap (capped at half the chunk) so spans near the
                    # boundary are seen by both windows.
                    start_idx += max(1, covered - min(overlap, covered // 2))
                else:
                    start_idx += stride

        result = []
        for ws, we, label, score in self._resolve_candidates(all_candidates, flat_ner):
            start_char = tokens_info[ws]["start"]
            end_char = tokens_info[we]["end"]
            result.append({
                "label": label,
                "score": round(score, 4),
                "text": text[start_char:end_char],
                "start_char": start_char,
                "end_char": end_char,
                "start_word": ws,
                "end_word": we,
            })
        return result