---
license: apache-2.0
language:
- en
base_model:
- FacebookAI/roberta-large
tags:
- token-classification
- named-entity-recognition
- medical-nlp
- crf
- biomedical
- jargon-identification
---

# Medical Jargon Identifier with CRF

A PyTorch model that performs **fine-grained medical jargon identification**. It uses a **RoBERTa-large** backbone topped with a **Conditional Random Field (CRF)** layer, and it is fine-tuned on the **MedReadMe** dataset introduced by Jiang & Xu (2024).

---

## 🧠 Overview

* **Architecture**: RoBERTa-large → Linear classifier → CRF
* **Task**: Token-level classification into **7 medical jargon categories** with BIO tagging
* **Input**: Raw English text (sentences or paragraphs)
* **Output**: Word-level spans labeled with jargon type and boundaries

---

## 🎯 Supported Jargon Categories

| Label (BIO)                  | Meaning                                  |
| ---------------------------- | ---------------------------------------- |
| `medical-jargon-google-easy` | Easily Google-able medical terms         |
| `medical-jargon-google-hard` | Complex, hard-to-Google medical terms    |
| `medical-name-entity`        | Named diseases, drugs, procedures        |
| `general-complex`            | Complex general vocabulary               |
| `abbr-medical`               | Medical abbreviations (e.g., ECG, CBC)   |
| `abbr-general`               | General abbreviations                    |
| `general-medical-multisense` | Words with both lay and medical meanings |

---

## 📁 Files & Format

* `pytorch_model.bin` – model weights
* `config.json` – hyper-parameters & label map
* `tokenizer.json`, `vocab.json`, `merges.txt` – RoBERTa tokenizer assets
* `modeling_jargon.py` – custom `CRFTokenClassificationModel` class
* `requirements.txt` – runtime dependencies

---

## 🔧 Quick Start

```python
from transformers import AutoTokenizer, AutoModel
from huggingface_hub import PyTorchModelHubMixin
from torchcrf import CRF  # provided by the `pytorch-crf` package
import torch
import torch.nn as nn


class CRFTokenClassificationModel(nn.Module, PyTorchModelHubMixin):
    def __init__(self, config):
        super().__init__()
        # Load base transformer model
        self.transformer = AutoModel.from_pretrained(config["pretrained_model_name"])
        # Classification layers
        self.dropout = nn.Dropout(config["hidden_dropout_prob"])
        self.classifier = nn.Linear(config["hidden_size"], config["num_labels"])
        # CRF layer for sequence labeling
        self.crf = CRF(config["num_labels"], batch_first=True)
        # Label mappings (tag id -> label name)
        self.id2label = {v: k for k, v in config["label_map"].items()}

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        # Get transformer outputs
        outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = self.dropout(outputs.last_hidden_state)
        logits = self.classifier(sequence_output)

        # Calculate loss if labels are provided (training mode)
        if labels is not None:
            # The CRF returns a log-likelihood, so the loss is its negative
            mask = attention_mask.bool() if attention_mask is not None else None
            loss = -self.crf(logits, labels, mask=mask, reduction="mean")
            return {"loss": loss, "logits": logits}

        # Return logits only (inference mode)
        return {"logits": logits}

    def decode(self, logits, mask):
        # Use the CRF (Viterbi decoding) to find the best tag sequence
        return self.crf.decode(logits, mask.bool())


# 1. Load model and tokenizer
model_name = "DNivalis/med-jargon-crf"
tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True)
model = CRFTokenClassificationModel.from_pretrained(model_name)
model.eval()

# 2. Prepare input text
text = "The patient presented with elevated CRP and intermittent AF."
inputs = tokenizer(text, return_tensors="pt")

# 3. Run inference
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs["logits"]
    # Decode the best tag sequence with the CRF
    predicted_tags = model.decode(logits, inputs["attention_mask"])[0]

# 4. Collect tokens that received a non-"O" tag (assumes tag id 0 is "O")
spans = [(i, model.id2label[tag_id]) for i, tag_id in enumerate(predicted_tags) if tag_id != 0]
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

# 5. Display results, merging consecutive tokens of the same entity type into one span
def entity_type(label):
    # Strip a "B-"/"I-" prefix if the label has one
    return label[2:] if label[:2] in ("B-", "I-") else label


print("Detected medical jargon:")
next_free = 0  # first token index not yet covered by a printed span
for token_idx, label in spans:
    if token_idx < next_free:
        continue  # already printed as part of an earlier span
    end_idx = token_idx + 1
    while end_idx < len(predicted_tags):
        next_label = model.id2label[predicted_tags[end_idx]]
        # Stop at an "O" tag, a new "B-" tag, or a change of entity type
        if (predicted_tags[end_idx] == 0
                or next_label.startswith("B-")
                or entity_type(next_label) != entity_type(label)):
            break
        end_idx += 1
    next_free = end_idx
    # Convert subword tokens back to text
    detected_text = tokenizer.convert_tokens_to_string(tokens[token_idx:end_idx])
    print(f"{entity_type(label)}: '{detected_text.strip()}'")
```

---

## 🏥 Supported Tasks

* **Medical jargon detection** – binary, 3-class, or 7-category granularity
* **Named-entity recognition** – extract spans of medical interest
* **Readability analysis** – density of jargon per sentence or document
* **Downstream QA & summarization** – filter or simplify complex terms

---

## 🌍 Language

English only.

---

## 📚 Training Data

Fine-tuned on **MedReadMe**: 4,520 sentences with fine-grained jargon span annotations, including the novel *Google-Easy* and *Google-Hard* categories.

---

## 📖 Citation

If you use this model or the underlying dataset, please cite:

```bibtex
@article{jiang2024medreadmesystematicstudyfinegrained,
  title={MedReadMe: A Systematic Study for Fine-grained Sentence Readability in Medical Domain},
  author={Chao Jiang and Wei Xu},
  year={2024},
  eprint={2405.02144},
  archivePrefix={arXiv},
  primaryClass={cs.CL},
  url={https://arxiv.org/abs/2405.02144}
}
```

---

## 📝 License & Usage

Licensed under **Apache 2.0**.

* ✅ Allowed: research, commercial use, derivative works
* Include the license notice and attribution in any distribution

---

## ⚠️ Important Notes

* Model outputs are **not medical advice**; use them for research or educational purposes only.
* Performance may vary on text that differs substantially from the MedReadMe training domain.
* Consider additional post-processing for production systems (e.g., confidence filtering).

---

## ☎️ Contact

For questions, issues, or licensing inquiries, open an issue on the [model repository](https://huggingface.co/DNivalis/med-jargon-crf).
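
The *Readability analysis* task listed under Supported Tasks comes down to simple post-processing of the decoded tags. The following is a minimal, hypothetical sketch: the helpers `bio_to_spans` and `jargon_density` are not part of this repository. It assumes you already have one label per word in `O` / `B-<category>` / `I-<category>` form, for example by keeping the tag of each word's first subword token. The labels in the usage lines are hand-written for illustration and are not model predictions.

```python
from typing import List, Optional, Tuple


def bio_to_spans(words: List[str], labels: List[str]) -> List[Tuple[str, str]]:
    """Group word-level BIO labels into (category, text) spans.

    Hypothetical helper; assumes labels are "O", "B-<category>" or "I-<category>".
    """
    spans: List[Tuple[str, List[str]]] = []
    current: Optional[str] = None  # category of the currently open span, if any
    for word, label in zip(words, labels):
        if label == "O":
            current = None
            continue
        prefix, _, category = label.partition("-")
        if prefix == "B" or category != current:
            # A "B-" tag, or an "I-" tag that does not continue the open span,
            # starts a new span (lenient decoding)
            spans.append((category, [word]))
            current = category
        else:
            spans[-1][1].append(word)
    return [(category, " ".join(ws)) for category, ws in spans]


def jargon_density(labels: List[str]) -> float:
    """Fraction of words tagged with any jargon category (0.0 for empty input)."""
    if not labels:
        return 0.0
    return sum(label != "O" for label in labels) / len(labels)


# Illustrative, hand-written labels (not actual model output)
words = ["The", "patient", "presented", "with", "elevated", "CRP",
         "and", "intermittent", "AF", "."]
labels = ["O", "O", "O", "O", "O", "B-abbr-medical",
          "O", "O", "B-abbr-medical", "O"]
print(bio_to_spans(words, labels))  # [('abbr-medical', 'CRP'), ('abbr-medical', 'AF')]
print(jargon_density(labels))       # 0.2
```

Here a stray `I-` tag that does not continue an open span of the same category starts a new span. A stricter decoder could discard such tags instead. Sentence-level or document-level density is simply `jargon_density` applied to the corresponding slice of labels.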