Create span_ner_model.py
Browse files- span_ner_model.py +12 -13
span_ner_model.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
|
| 2 |
"""
|
| 3 |
Medieval Latin NER - Custom Span-NER Architecture
|
| 4 |
=============================================================================
|
|
@@ -20,14 +19,14 @@ class Config:
|
|
| 20 |
TEXT_DIM = 1024
|
| 21 |
LABEL_MODEL = "BAAI/bge-m3"
|
| 22 |
LABEL_DIM = 1024
|
| 23 |
-
|
| 24 |
MAX_SPAN_WIDTH = 80
|
| 25 |
WIDTH_EMB_DIM = 64
|
| 26 |
SPAN_HIDDEN = 512
|
| 27 |
ATTENTION_HEADS = 4
|
| 28 |
-
|
| 29 |
MAX_SEQ_LEN = 512
|
| 30 |
-
PREDICT_TEMP = 1.35
|
| 31 |
|
| 32 |
# ---------------------------------------------------------------------------
|
| 33 |
# 2. LABEL DICTIONARY & PROMPTS
|
|
@@ -92,16 +91,16 @@ class SpanRepLayer(nn.Module):
|
|
| 92 |
w_emb = self.width_emb(width)
|
| 93 |
|
| 94 |
idx = torch.arange(L, device=seq_out.device).view(1, 1, L)
|
| 95 |
-
mask = (idx >= spans[:,:,0:1]) & (idx <= spans[:,:,1:2])
|
| 96 |
|
| 97 |
-
att_logits = self.att_query(seq_out)
|
| 98 |
att_logits = att_logits.unsqueeze(1).expand(B, S, L, self.num_heads)
|
| 99 |
|
| 100 |
mask_expanded = mask.unsqueeze(-1).expand(-1, -1, -1, self.num_heads)
|
| 101 |
att_logits = att_logits.masked_fill(~mask_expanded, float('-inf'))
|
| 102 |
-
att_weights = F.softmax(att_logits, dim=2)
|
| 103 |
|
| 104 |
-
h_pool = torch.einsum('bslm,blh->bsmh', att_weights, seq_out)
|
| 105 |
h_pool = h_pool.reshape(B, S, self.num_heads * H)
|
| 106 |
|
| 107 |
return torch.cat([h_start, h_end, h_pool, w_emb], dim=-1)
|
|
@@ -112,7 +111,7 @@ class SpanNERModel(nn.Module):
|
|
| 112 |
self.cfg = cfg
|
| 113 |
self.text_enc = AutoModel.from_pretrained(cfg.TEXT_MODEL, add_pooling_layer=False)
|
| 114 |
self.label_enc = AutoModel.from_pretrained(cfg.LABEL_MODEL)
|
| 115 |
-
|
| 116 |
self.span_layer = SpanRepLayer(cfg.TEXT_DIM, cfg.MAX_SPAN_WIDTH, cfg.WIDTH_EMB_DIM, num_heads=cfg.ATTENTION_HEADS)
|
| 117 |
|
| 118 |
self.label_proj = nn.Sequential(
|
|
@@ -129,7 +128,7 @@ class SpanNERModel(nn.Module):
|
|
| 129 |
nn.Linear(cfg.SPAN_HIDDEN, cfg.SPAN_HIDDEN)
|
| 130 |
)
|
| 131 |
|
| 132 |
-
self.logit_scale = nn.Parameter(torch.tensor(1.0))
|
| 133 |
self._raw_label_embs = None
|
| 134 |
|
| 135 |
@torch.no_grad()
|
|
@@ -220,11 +219,11 @@ class SpanNERModel(nn.Module):
|
|
| 220 |
covered = set(range(ws, we + 1))
|
| 221 |
if flat_ner and covered & taken: continue
|
| 222 |
if flat_ner: taken |= covered
|
| 223 |
-
|
| 224 |
start_char = tokens_info[ws]["start"]
|
| 225 |
end_char = tokens_info[we]["end"]
|
| 226 |
text_span = text[start_char:end_char]
|
| 227 |
-
|
| 228 |
result.append({
|
| 229 |
"label": label,
|
| 230 |
"score": round(score, 4),
|
|
@@ -235,4 +234,4 @@ class SpanNERModel(nn.Module):
|
|
| 235 |
"end_word": we
|
| 236 |
})
|
| 237 |
|
| 238 |
-
return result
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
Medieval Latin NER - Custom Span-NER Architecture
|
| 3 |
=============================================================================
|
|
|
|
| 19 |
TEXT_DIM = 1024
|
| 20 |
LABEL_MODEL = "BAAI/bge-m3"
|
| 21 |
LABEL_DIM = 1024
|
| 22 |
+
|
| 23 |
MAX_SPAN_WIDTH = 80
|
| 24 |
WIDTH_EMB_DIM = 64
|
| 25 |
SPAN_HIDDEN = 512
|
| 26 |
ATTENTION_HEADS = 4
|
| 27 |
+
|
| 28 |
MAX_SEQ_LEN = 512
|
| 29 |
+
PREDICT_TEMP = 1.35
|
| 30 |
|
| 31 |
# ---------------------------------------------------------------------------
|
| 32 |
# 2. LABEL DICTIONARY & PROMPTS
|
|
|
|
| 91 |
w_emb = self.width_emb(width)
|
| 92 |
|
| 93 |
idx = torch.arange(L, device=seq_out.device).view(1, 1, L)
|
| 94 |
+
mask = (idx >= spans[:,:,0:1]) & (idx <= spans[:,:,1:2])
|
| 95 |
|
| 96 |
+
att_logits = self.att_query(seq_out)
|
| 97 |
att_logits = att_logits.unsqueeze(1).expand(B, S, L, self.num_heads)
|
| 98 |
|
| 99 |
mask_expanded = mask.unsqueeze(-1).expand(-1, -1, -1, self.num_heads)
|
| 100 |
att_logits = att_logits.masked_fill(~mask_expanded, float('-inf'))
|
| 101 |
+
att_weights = F.softmax(att_logits, dim=2)
|
| 102 |
|
| 103 |
+
h_pool = torch.einsum('bslm,blh->bsmh', att_weights, seq_out)
|
| 104 |
h_pool = h_pool.reshape(B, S, self.num_heads * H)
|
| 105 |
|
| 106 |
return torch.cat([h_start, h_end, h_pool, w_emb], dim=-1)
|
|
|
|
| 111 |
self.cfg = cfg
|
| 112 |
self.text_enc = AutoModel.from_pretrained(cfg.TEXT_MODEL, add_pooling_layer=False)
|
| 113 |
self.label_enc = AutoModel.from_pretrained(cfg.LABEL_MODEL)
|
| 114 |
+
|
| 115 |
self.span_layer = SpanRepLayer(cfg.TEXT_DIM, cfg.MAX_SPAN_WIDTH, cfg.WIDTH_EMB_DIM, num_heads=cfg.ATTENTION_HEADS)
|
| 116 |
|
| 117 |
self.label_proj = nn.Sequential(
|
|
|
|
| 128 |
nn.Linear(cfg.SPAN_HIDDEN, cfg.SPAN_HIDDEN)
|
| 129 |
)
|
| 130 |
|
| 131 |
+
self.logit_scale = nn.Parameter(torch.tensor(1.0))
|
| 132 |
self._raw_label_embs = None
|
| 133 |
|
| 134 |
@torch.no_grad()
|
|
|
|
| 219 |
covered = set(range(ws, we + 1))
|
| 220 |
if flat_ner and covered & taken: continue
|
| 221 |
if flat_ner: taken |= covered
|
| 222 |
+
|
| 223 |
start_char = tokens_info[ws]["start"]
|
| 224 |
end_char = tokens_info[we]["end"]
|
| 225 |
text_span = text[start_char:end_char]
|
| 226 |
+
|
| 227 |
result.append({
|
| 228 |
"label": label,
|
| 229 |
"score": round(score, 4),
|
|
|
|
| 234 |
"end_word": we
|
| 235 |
})
|
| 236 |
|
| 237 |
+
return result
|