---
license: apache-2.0
language:
- en
base_model:
- FacebookAI/roberta-large
tags:
- token-classification
- named-entity-recognition
- medical-nlp
- crf
- biomedical
- jargon-identification
---
# Medical Jargon Identifier with CRF
A PyTorch model that performs **fine-grained medical jargon identification** using a **RoBERTa-large** backbone enhanced by a **Conditional Random Field (CRF)** layer.
Fine-tuned on the **MedReadMe** dataset introduced by Jiang & Xu (2024).
---
## 🧠 Overview
* **Architecture**: RoBERTa-large → Linear classifier → CRF
* **Task**: Token-level classification into **7 medical jargon categories** + BIO tagging
* **Input**: Raw English text (sentences or paragraphs)
* **Output**: Word-level spans labeled with jargon type and boundaries
---
## 🎯 Supported Jargon Categories
| Label (BIO) | Meaning |
| ------------------------------------ | -------------------------------------------- |
| `medical-jargon-google-easy` | Easily Google-able medical terms |
| `medical-jargon-google-hard` | Complex, hard-to-Google medical terms |
| `medical-name-entity` | Named diseases, drugs, procedures |
| `general-complex` | Complex general vocabulary |
| `abbr-medical` | Medical abbreviations (e.g., ECG, CBC) |
| `abbr-general` | General abbreviations |
| `general-medical-multisense` | Words with both lay and medical meanings |
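
With BIO tagging, each category contributes a `B-` (begin) and an `I-` (inside) tag, plus a single `O` tag for tokens outside any span, which gives 2 × 7 + 1 = 15 labels. A minimal sketch of such a label map follows. The tag spelling, the id order, and the helper name `build_bio_label_map` are assumptions for illustration; the authoritative mapping is the `label_map` stored in `config.json`.

```python
# Illustrative only: the real tag names and ids live in config.json ("label_map").
CATEGORIES = [
    "medical-jargon-google-easy",
    "medical-jargon-google-hard",
    "medical-name-entity",
    "general-complex",
    "abbr-medical",
    "abbr-general",
    "general-medical-multisense",
]


def build_bio_label_map(categories):
    """Return {label: id} with "O" at id 0, then a B-/I- pair per category."""
    label_map = {"O": 0}
    for category in categories:
        for prefix in ("B", "I"):
            label_map[f"{prefix}-{category}"] = len(label_map)
    return label_map


label_map = build_bio_label_map(CATEGORIES)
# Inverse mapping, built the same way the Quick Start model builds id2label.
id2label = {idx: label for label, idx in label_map.items()}
print(len(label_map))  # 15
```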
---
## 📁 Files & Format
* `pytorch_model.bin` – model weights
* `config.json` – hyper-parameters & label map
* `tokenizer.json`, `vocab.json`, `merges.txt` – RoBERTa tokenizer assets
* `modeling_jargon.py` – custom `CRFTokenClassificationModel` class
* `requirements.txt` – runtime dependencies
---
## 🔧 Quick Start
```python
from transformers import AutoTokenizer, AutoModel
from huggingface_hub import PyTorchModelHubMixin
from torchcrf import CRF
import torch
import torch.nn as nn


class CRFTokenClassificationModel(nn.Module, PyTorchModelHubMixin):
    def __init__(self, config):
        super().__init__()
        # Load base transformer model
        self.transformer = AutoModel.from_pretrained(config["pretrained_model_name"])
        # Classification layers
        self.dropout = nn.Dropout(config["hidden_dropout_prob"])
        self.classifier = nn.Linear(config["hidden_size"], config["num_labels"])
        # CRF layer for sequence labeling
        self.crf = CRF(config["num_labels"], batch_first=True)
        # Label mappings (id -> label name)
        self.id2label = {v: k for k, v in config["label_map"].items()}

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        # Get transformer outputs
        outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = self.dropout(outputs.last_hidden_state)
        logits = self.classifier(sequence_output)
        # Calculate loss if labels provided (training mode)
        if labels is not None:
            mask = attention_mask.bool() if attention_mask is not None else None
            # The CRF returns a log-likelihood, so negate it to get a loss
            loss = -self.crf(logits, labels, mask=mask, reduction='mean')
            return {"loss": loss, "logits": logits}
        # Return logits only (inference mode)
        return {"logits": logits}

    def decode(self, logits, mask):
        # Use CRF (Viterbi) to decode the best tag sequence per input
        return self.crf.decode(logits, mask.bool())


# 1. Load model and tokenizer
model_name = "DNivalis/med-jargon-crf"
tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True)
model = CRFTokenClassificationModel.from_pretrained(model_name)
model.eval()

# 2. Prepare input text
text = "The patient presented with elevated CRP and intermittent AF."
inputs = tokenizer(text, return_tensors="pt")

# 3. Run inference
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs["logits"]
    # Decode best sequence using CRF
    predicted_tags = model.decode(logits, inputs["attention_mask"])[0]

# 4. Group tagged tokens into spans
# Assumes id 0 is the "O" tag and entity labels carry "B-"/"I-" prefixes.
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
labels = [model.id2label[tag_id] for tag_id in predicted_tags]


def entity_type(label):
    # Strip the BIO prefix so "B-x" and "I-x" compare as the same entity type
    return label[2:] if label.startswith(("B-", "I-")) else label


spans = []  # (entity type, start index, end index), end exclusive
i = 0
while i < len(labels):
    if predicted_tags[i] == 0:
        i += 1
        continue
    end = i + 1
    # Extend over following tokens of the same type until "O" or a new "B-" tag
    while (end < len(labels)
           and predicted_tags[end] != 0
           and not labels[end].startswith("B-")
           and entity_type(labels[end]) == entity_type(labels[i])):
        end += 1
    spans.append((entity_type(labels[i]), i, end))
    i = end  # Skip past the span so each entity is reported once

# 5. Display results
print("Detected medical jargon:")
for label, start, end in spans:
    # Convert tokens back to text
    detected_text = tokenizer.convert_tokens_to_string(tokens[start:end])
    print(f"{label}: '{detected_text.strip()}'")
```
---
## 🏥 Supported Tasks
* **Medical jargon detection** – binary, 3-class, or 7-category granularity
* **Named-entity recognition** – extract spans of medical interest
* **Readability analysis** – density of jargon per sentence or document
* **Downstream QA & summarization** – filter or simplify complex terms
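
The coarser granularities and the density measure can be derived from the 7-category output by collapsing labels. The sketch below assumes a three-way grouping (medical jargon, general complex, abbreviation) and counts density as the fraction of tagged words. `COARSE_GROUP`, `to_three_class`, `to_binary`, and `jargon_density` are hypothetical names, the grouping itself is an assumption that should be checked against the MedReadMe paper, and the sample labels are hand-written rather than model output.

```python
# Hypothetical grouping of the seven fine-grained categories into three classes.
COARSE_GROUP = {
    "medical-jargon-google-easy": "medical-jargon",
    "medical-jargon-google-hard": "medical-jargon",
    "medical-name-entity": "medical-jargon",
    "general-complex": "general-complex",
    "general-medical-multisense": "general-complex",
    "abbr-medical": "abbreviation",
    "abbr-general": "abbreviation",
}


def strip_bio(label):
    """Drop a leading "B-" or "I-" so only the category name remains."""
    return label[2:] if label.startswith(("B-", "I-")) else label


def to_three_class(label):
    """Map a BIO label to its coarse class, keeping "O" unchanged."""
    category = strip_bio(label)
    return "O" if category == "O" else COARSE_GROUP[category]


def to_binary(label):
    """Collapse every category into a single "jargon" class."""
    return "O" if strip_bio(label) == "O" else "jargon"


def jargon_density(labels):
    """Fraction of words tagged as any kind of jargon (0.0 for empty input)."""
    if not labels:
        return 0.0
    return sum(to_binary(label) == "jargon" for label in labels) / len(labels)


# Hand-written word-level labels for a nine-word sentence (not model output).
word_labels = ["O", "O", "O", "O", "O", "B-abbr-medical", "O", "O", "B-abbr-medical"]
print([to_three_class(label) for label in word_labels])
print(jargon_density(word_labels))
```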
---
## 🌍 Language
English only.
---
## 📚 Training Data
Fine-tuned on **MedReadMe**: 4,520 sentences with fine-grained jargon span annotations, including the novel *Google-Easy* and *Google-Hard* categories.
---
## 📖 Citation
If you use this model or the underlying dataset, please cite:
```bibtex
@article{jiang2024medreadmesystematicstudyfinegrained,
title={MedReadMe: A Systematic Study for Fine-grained Sentence Readability in Medical Domain},
author={Chao Jiang and Wei Xu},
year={2024},
eprint={2405.02144},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2405.02144}
}
```
---
## 📝 License & Usage
Licensed under **Apache 2.0**.
* ✅ Allowed: research, commercial use, derivative works
* Include license notice and attribution in any distribution
---
## ⚠️ Important Notes
* Model outputs are **not medical advice**; use for research/educational purposes only.
* Performance may vary on text that differs substantially from the MedReadMe training domain.
* Consider additional post-processing for production systems (e.g., confidence filtering).
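
One possible form of confidence filtering is sketched below. It assumes spans are available as `(label, start, end)` tuples with an exclusive end index, and scores each span by the mean softmax probability of its predicted tags under the per-token emission scores (with the Quick Start variables, `logits[0].tolist()` and `predicted_tags`). This is only a rough proxy, since it ignores the CRF transition scores. The names `span_confidence` and `filter_spans` and the 0.5 threshold are illustrative choices, not part of the released model.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(score - peak) for score in scores]
    total = sum(exps)
    return [value / total for value in exps]


def span_confidence(emissions, tags, start, end):
    """Mean emission probability of the predicted tag over tokens start..end-1."""
    probs = [softmax(emissions[i])[tags[i]] for i in range(start, end)]
    return sum(probs) / len(probs)


def filter_spans(spans, emissions, tags, threshold=0.5):
    """Keep spans whose proxy confidence reaches the threshold."""
    return [
        (label, start, end)
        for label, start, end in spans
        if span_confidence(emissions, tags, start, end) >= threshold
    ]


# Toy example with 3 tokens and 3 tags; all values are made up.
emissions = [[4.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.2, 0.3, 0.1]]
tags = [0, 1, 1]
spans = [("abbr-medical", 1, 2), ("abbr-medical", 2, 3)]
# The second span is dropped because its tag probability is below 0.5.
print(filter_spans(spans, emissions, tags))
```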
---
## ☎️ Contact
For questions, issues, or licensing inquiries, open an issue on the [model repository](https://huggingface.co/DNivalis/med-jargon-crf).