DNivalis committed on
Commit
a6808ea
·
verified ·
1 Parent(s): 07aa6ed

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +68 -9
README.md CHANGED
@@ -56,23 +56,82 @@ Fine-tuned on the **MedReadMe** dataset introduced by Jiang & Xu (2024).
56
  ## 🔧 Quick Start
57
 
58
  ```python
59
- from transformers import AutoTokenizer
60
- from modeling_jargon import CRFTokenClassificationModel
61
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  model_name = "DNivalis/med-jargon-crf"
63
  tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True)
64
  model = CRFTokenClassificationModel.from_pretrained(model_name)
65
  model.eval()
66
 
 
67
  text = "The patient presented with elevated CRP and intermittent AF."
68
  inputs = tokenizer(text, return_tensors="pt")
69
- with torch.no_grad():
70
- logits = model(**inputs)["logits"]
71
- tags = model.decode(logits, inputs["attention_mask"])[0]
72
 
73
- # Convert IDs → labels
74
- id2label = model.config.id2label
75
- spans = [(i, id2label[t]) for i, t in enumerate(tags) if t != 0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  ```
77
 
78
  ---
 
56
  ## 🔧 Quick Start
57
 
58
  ```python
59
+ from transformers import AutoTokenizer, AutoModel
60
+ from huggingface_hub import PyTorchModelHubMixin
61
+ from torchcrf import CRF
62
+ import torch
63
+ import torch.nn as nn
64
+
65
class CRFTokenClassificationModel(nn.Module, PyTorchModelHubMixin):
    """Transformer encoder followed by dropout, a linear layer and a CRF.

    ``config`` is a mapping read for these keys: ``pretrained_model_name``,
    ``hidden_dropout_prob``, ``hidden_size``, ``num_labels`` and
    ``label_map`` (label name -> integer id).
    """

    def __init__(self, config):
        super().__init__()
        # Base transformer encoder.
        self.transformer = AutoModel.from_pretrained(config["pretrained_model_name"])

        # Per-token classification head producing the CRF emission scores.
        self.dropout = nn.Dropout(config["hidden_dropout_prob"])
        self.classifier = nn.Linear(config["hidden_size"], config["num_labels"])

        # CRF layer for sequence labeling.
        self.crf = CRF(config["num_labels"], batch_first=True)

        # Invert label_map so predicted ids can be turned back into names.
        self.id2label = {v: k for k, v in config["label_map"].items()}

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Compute emission logits and, when ``labels`` is given, the CRF loss.

        Extra keyword arguments are accepted and ignored, so a full tokenizer
        output can be splatted in with ``model(**inputs)``.

        Returns:
            ``{"logits": logits}`` for inference, or
            ``{"loss": loss, "logits": logits}`` when ``labels`` is provided.
        """
        outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = self.dropout(outputs.last_hidden_state)
        logits = self.classifier(sequence_output)

        if labels is not None:
            # attention_mask is optional in the signature; only convert it
            # when present. The CRF layer's mask argument is itself optional.
            mask = attention_mask.bool() if attention_mask is not None else None
            # The CRF layer's return value is negated to obtain a loss.
            loss = -self.crf(logits, labels, mask=mask, reduction='mean')
            return {"loss": loss, "logits": logits}

        return {"logits": logits}

    def decode(self, logits, mask=None):
        """Return the CRF's best tag sequence for each item in the batch.

        ``mask`` may be omitted; it is converted to bool only when provided.
        """
        return self.crf.decode(logits, mask.bool() if mask is not None else None)
98
+
99
+
100
# 1. Load model and tokenizer
model_name = "DNivalis/med-jargon-crf"
tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True)
model = CRFTokenClassificationModel.from_pretrained(model_name)
model.eval()

# 2. Prepare input text
text = "The patient presented with elevated CRP and intermittent AF."
inputs = tokenizer(text, return_tensors="pt")

# 3. Run inference
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs["logits"]
    # Decode best sequence using CRF (first and only item of the batch)
    predicted_tags = model.decode(logits, inputs["attention_mask"])[0]

# 4. Extract spans from predictions
# Walk the tag sequence once and group maximal runs of identical tag ids, so
# that each run is reported exactly once instead of once per token.
# NOTE(review): tag id 0 is treated as the "outside" label, as in the
# original snippet -- confirm against the model's label_map.
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
spans = []  # (start_idx, end_idx_exclusive, label)
start_idx = 0
while start_idx < len(predicted_tags):
    tag_id = predicted_tags[start_idx]
    end_idx = start_idx + 1
    while end_idx < len(predicted_tags) and predicted_tags[end_idx] == tag_id:
        end_idx += 1
    if tag_id != 0:
        spans.append((start_idx, end_idx, model.id2label[tag_id]))
    start_idx = end_idx

# 5. Display results
print("Detected medical jargon:")
for start_idx, end_idx, label in spans:
    # Convert tokens back to text
    detected_tokens = tokens[start_idx:end_idx]
    detected_text = tokenizer.convert_tokens_to_string(detected_tokens)

    print(f"{label}: '{detected_text.strip()}'")
135
  ```
136
 
137
  ---