ERCDiDip commited on
Commit
dfef0e5
·
verified ·
1 Parent(s): 176f183

Create span_ner_model.py

Browse files
Files changed (1) hide show
  1. span_ner_model.py +12 -13
span_ner_model.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  """
3
  Medieval Latin NER - Custom Span-NER Architecture
4
  =============================================================================
@@ -20,14 +19,14 @@ class Config:
20
  TEXT_DIM = 1024
21
  LABEL_MODEL = "BAAI/bge-m3"
22
  LABEL_DIM = 1024
23
-
24
  MAX_SPAN_WIDTH = 80
25
  WIDTH_EMB_DIM = 64
26
  SPAN_HIDDEN = 512
27
  ATTENTION_HEADS = 4
28
-
29
  MAX_SEQ_LEN = 512
30
- PREDICT_TEMP = 1.35
31
 
32
  # ---------------------------------------------------------------------------
33
  # 2. LABEL DICTIONARY & PROMPTS
@@ -92,16 +91,16 @@ class SpanRepLayer(nn.Module):
92
  w_emb = self.width_emb(width)
93
 
94
  idx = torch.arange(L, device=seq_out.device).view(1, 1, L)
95
- mask = (idx >= spans[:,:,0:1]) & (idx <= spans[:,:,1:2])
96
 
97
- att_logits = self.att_query(seq_out)
98
  att_logits = att_logits.unsqueeze(1).expand(B, S, L, self.num_heads)
99
 
100
  mask_expanded = mask.unsqueeze(-1).expand(-1, -1, -1, self.num_heads)
101
  att_logits = att_logits.masked_fill(~mask_expanded, float('-inf'))
102
- att_weights = F.softmax(att_logits, dim=2)
103
 
104
- h_pool = torch.einsum('bslm,blh->bsmh', att_weights, seq_out)
105
  h_pool = h_pool.reshape(B, S, self.num_heads * H)
106
 
107
  return torch.cat([h_start, h_end, h_pool, w_emb], dim=-1)
@@ -112,7 +111,7 @@ class SpanNERModel(nn.Module):
112
  self.cfg = cfg
113
  self.text_enc = AutoModel.from_pretrained(cfg.TEXT_MODEL, add_pooling_layer=False)
114
  self.label_enc = AutoModel.from_pretrained(cfg.LABEL_MODEL)
115
-
116
  self.span_layer = SpanRepLayer(cfg.TEXT_DIM, cfg.MAX_SPAN_WIDTH, cfg.WIDTH_EMB_DIM, num_heads=cfg.ATTENTION_HEADS)
117
 
118
  self.label_proj = nn.Sequential(
@@ -129,7 +128,7 @@ class SpanNERModel(nn.Module):
129
  nn.Linear(cfg.SPAN_HIDDEN, cfg.SPAN_HIDDEN)
130
  )
131
 
132
- self.logit_scale = nn.Parameter(torch.tensor(1.0))
133
  self._raw_label_embs = None
134
 
135
  @torch.no_grad()
@@ -220,11 +219,11 @@ class SpanNERModel(nn.Module):
220
  covered = set(range(ws, we + 1))
221
  if flat_ner and covered & taken: continue
222
  if flat_ner: taken |= covered
223
-
224
  start_char = tokens_info[ws]["start"]
225
  end_char = tokens_info[we]["end"]
226
  text_span = text[start_char:end_char]
227
-
228
  result.append({
229
  "label": label,
230
  "score": round(score, 4),
@@ -235,4 +234,4 @@ class SpanNERModel(nn.Module):
235
  "end_word": we
236
  })
237
 
238
- return result
 
 
1
  """
2
  Medieval Latin NER - Custom Span-NER Architecture
3
  =============================================================================
 
19
  TEXT_DIM = 1024
20
  LABEL_MODEL = "BAAI/bge-m3"
21
  LABEL_DIM = 1024
22
+
23
  MAX_SPAN_WIDTH = 80
24
  WIDTH_EMB_DIM = 64
25
  SPAN_HIDDEN = 512
26
  ATTENTION_HEADS = 4
27
+
28
  MAX_SEQ_LEN = 512
29
+ PREDICT_TEMP = 1.35
30
 
31
  # ---------------------------------------------------------------------------
32
  # 2. LABEL DICTIONARY & PROMPTS
 
91
  w_emb = self.width_emb(width)
92
 
93
  idx = torch.arange(L, device=seq_out.device).view(1, 1, L)
94
+ mask = (idx >= spans[:,:,0:1]) & (idx <= spans[:,:,1:2])
95
 
96
+ att_logits = self.att_query(seq_out)
97
  att_logits = att_logits.unsqueeze(1).expand(B, S, L, self.num_heads)
98
 
99
  mask_expanded = mask.unsqueeze(-1).expand(-1, -1, -1, self.num_heads)
100
  att_logits = att_logits.masked_fill(~mask_expanded, float('-inf'))
101
+ att_weights = F.softmax(att_logits, dim=2)
102
 
103
+ h_pool = torch.einsum('bslm,blh->bsmh', att_weights, seq_out)
104
  h_pool = h_pool.reshape(B, S, self.num_heads * H)
105
 
106
  return torch.cat([h_start, h_end, h_pool, w_emb], dim=-1)
 
111
  self.cfg = cfg
112
  self.text_enc = AutoModel.from_pretrained(cfg.TEXT_MODEL, add_pooling_layer=False)
113
  self.label_enc = AutoModel.from_pretrained(cfg.LABEL_MODEL)
114
+
115
  self.span_layer = SpanRepLayer(cfg.TEXT_DIM, cfg.MAX_SPAN_WIDTH, cfg.WIDTH_EMB_DIM, num_heads=cfg.ATTENTION_HEADS)
116
 
117
  self.label_proj = nn.Sequential(
 
128
  nn.Linear(cfg.SPAN_HIDDEN, cfg.SPAN_HIDDEN)
129
  )
130
 
131
+ self.logit_scale = nn.Parameter(torch.tensor(1.0))
132
  self._raw_label_embs = None
133
 
134
  @torch.no_grad()
 
219
  covered = set(range(ws, we + 1))
220
  if flat_ner and covered & taken: continue
221
  if flat_ner: taken |= covered
222
+
223
  start_char = tokens_info[ws]["start"]
224
  end_char = tokens_info[we]["end"]
225
  text_span = text[start_char:end_char]
226
+
227
  result.append({
228
  "label": label,
229
  "score": round(score, 4),
 
234
  "end_word": we
235
  })
236
 
237
+ return result