DNivalis committed on
Commit
e002b8e
·
verified ·
1 Parent(s): a6808ea

Update modeling_jargon.py

Browse files
Files changed (1) hide show
  1. modeling_jargon.py +17 -6
modeling_jargon.py CHANGED
@@ -1,27 +1,38 @@
 
1
  from huggingface_hub import PyTorchModelHubMixin
2
- from transformers import PreTrainedModel, AutoConfig, AutoModel
3
  from torchcrf import CRF
4
  import torch.nn as nn
5
- import torch
6
 
7
  class CRFTokenClassificationModel(nn.Module, PyTorchModelHubMixin):
8
  def __init__(self, config):
9
  super().__init__()
10
- self.transformer = AutoModel.from_config(config)
11
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
12
- self.classifier = nn.Linear(config.hidden_size, config.num_labels)
13
- self.crf = CRF(config.num_labels, batch_first=True)
 
 
 
 
 
 
 
 
14
 
15
  def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
 
16
  outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
17
  sequence_output = self.dropout(outputs.last_hidden_state)
18
  logits = self.classifier(sequence_output)
19
 
 
20
  if labels is not None:
21
  loss = -self.crf(logits, labels, mask=attention_mask.bool(), reduction='mean')
22
  return {"loss": loss, "logits": logits}
23
 
 
24
  return {"logits": logits}
25
 
26
  def decode(self, logits, mask):
 
27
  return self.crf.decode(logits, mask.bool())
 
1
+ from transformers import AutoModel
2
  from huggingface_hub import PyTorchModelHubMixin
 
3
  from torchcrf import CRF
4
  import torch.nn as nn
 
5
 
6
  class CRFTokenClassificationModel(nn.Module, PyTorchModelHubMixin):
7
  def __init__(self, config):
8
  super().__init__()
9
+ # Load base transformer model
10
+ self.transformer = AutoModel.from_pretrained(config["pretrained_model_name"])
11
+
12
+ # Classification layers
13
+ self.dropout = nn.Dropout(config["hidden_dropout_prob"])
14
+ self.classifier = nn.Linear(config["hidden_size"], config["num_labels"])
15
+
16
+ # CRF layer for sequence labeling
17
+ self.crf = CRF(config["num_labels"], batch_first=True)
18
+
19
+ # Label mappings
20
+ self.id2label = {v: k for k, v in config["label_map"].items()}
21
 
22
  def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
23
+ # Get transformer outputs
24
  outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
25
  sequence_output = self.dropout(outputs.last_hidden_state)
26
  logits = self.classifier(sequence_output)
27
 
28
+ # Calculate loss if labels provided (training mode)
29
  if labels is not None:
30
  loss = -self.crf(logits, labels, mask=attention_mask.bool(), reduction='mean')
31
  return {"loss": loss, "logits": logits}
32
 
33
+ # Return logits only (inference mode)
34
  return {"logits": logits}
35
 
36
  def decode(self, logits, mask):
37
+ # Use CRF to decode best sequence
38
  return self.crf.decode(logits, mask.bool())