Spaces:
Build error
Build error
File size: 2,465 Bytes
dda3dc2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 | {
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "ca5c8ec7",
"metadata": {
"vscode": {
"languageId": "python"
}
},
"outputs": [],
"source": [
"# Install dependencies. %pip (unlike !pip) installs into the running kernel's environment.\n",
"# TODO: pin the versions this notebook was validated with (e.g. transformers==X.Y.Z).\n",
"%pip install -q transformers torch datasets\n",
"\n",
"from transformers import AutoTokenizer, AutoModelForSeq2SeqLM\n",
"from datasets import load_dataset\n",
"import re\n",
"\n",
"# Load pretrained clinical BART model\n",
"MODEL_NAME = \"dmacres/bart-large-mimiciii-v2\"\n",
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
"model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)\n",
"\n",
"# Load all clinical notes, then keep only discharge summaries.\n",
"# NOTE(review): MIMIC-III is credentialed PhysioNet data; confirm this Hub copy is\n",
"# permitted under your data use agreement before running or sharing outputs.\n",
"NOTE_CATEGORY = \"Discharge summary\"\n",
"all_notes = load_dataset(\n",
"    \"ntphuc149/MIMIC-III-Clinical-Database\",\n",
"    \"NOTEEVENTS\",\n",
"    split=\"train\",\n",
")\n",
"\n",
"# `notes` holds the filtered subset; later code in this cell reads it under this name.\n",
"notes = all_notes.filter(lambda x: x[\"CATEGORY\"] == NOTE_CATEGORY)\n",
"\n",
"# Clean text\n",
"def clean_text(text, lowercase=True):\n",
"    \"\"\"Normalise a raw MIMIC note before it is fed to the model.\n",
"\n",
"    Removes bracketed [** ... **] spans (MIMIC de-identification placeholders)\n",
"    and collapses every whitespace run (newlines, carriage returns, tabs,\n",
"    indentation and the gaps left by removed placeholders) into a single space.\n",
"\n",
"    Args:\n",
"        text: raw note text; None or empty input returns \"\".\n",
"        lowercase: lowercase the text first (default True, the original behaviour).\n",
"\n",
"    Returns:\n",
"        The cleaned, stripped string.\n",
"    \"\"\"\n",
"    if not text:\n",
"        return \"\"\n",
"    if lowercase:\n",
"        # NOTE(review): the tokenizer is likely case-sensitive; confirm lowercasing\n",
"        # matches how this checkpoint was trained.\n",
"        text = text.lower()\n",
"    text = re.sub(r\"\\[\\*\\*.*?\\*\\*\\]\", \"\", text)\n",
"    # \\s+ instead of \\n+: also folds carriage returns, tabs and space runs that\n",
"    # would otherwise waste the model's limited input-token budget.\n",
"    text = re.sub(r\"\\s+\", \" \", text)\n",
"    return text.strip()\n",
"\n",
"# Fail with a clear message (instead of an IndexError) if the category filter matched nothing.\n",
"assert len(notes) > 0, \"no discharge summaries matched the CATEGORY filter\"\n",
"sample_note = clean_text(notes[0][\"TEXT\"])\n",
"\n",
"# GenAI explanation function\n",
"def generate_explanation(note, risk_score, max_input_tokens=1024, max_output_tokens=200):\n",
"    \"\"\"Generate a free-text explanation of a readmission-risk score for one note.\n",
"\n",
"    The score and the instruction come *before* the note. The tokenizer truncates\n",
"    the prompt to max_input_tokens (from the right by default), and discharge\n",
"    summaries can easily exceed that budget; with the note first, the score and\n",
"    the instruction were the part cut off. Now only the tail of the note is lost.\n",
"\n",
"    Args:\n",
"        note: cleaned discharge-summary text.\n",
"        risk_score: classifier output, formatted with 2 decimals\n",
"            (presumably a probability in [0, 1]).\n",
"        max_input_tokens: token budget for the encoder input (default 1024, as before).\n",
"        max_output_tokens: max_length passed to model.generate (default 200, as before).\n",
"\n",
"    Returns:\n",
"        The decoded generated text with special tokens removed.\n",
"\n",
"    NOTE(review): the text is generated from the note and the score only; it is\n",
"    not derived from the classifier and must not be presented as its reasoning.\n",
"    Also confirm this checkpoint was fine-tuned to follow instructions.\n",
"    \"\"\"\n",
"    prompt = f\"\"\"\n",
"Predicted readmission risk: {risk_score:.2f}\n",
"\n",
"Explain the key clinical reasons for readmission risk.\n",
"\n",
"Discharge summary:\n",
"{note}\n",
"\"\"\"\n",
"\n",
"    inputs = tokenizer(\n",
"        prompt,\n",
"        return_tensors=\"pt\",\n",
"        truncation=True,\n",
"        max_length=max_input_tokens,\n",
"    )\n",
"\n",
"    outputs = model.generate(\n",
"        **inputs,\n",
"        max_length=max_output_tokens,\n",
"        num_beams=4,\n",
"        early_stopping=True,\n",
"    )\n",
"\n",
"    return tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
"\n",
"# Demo: explain one note using a placeholder score\n",
"risk_score = 0.72  # placeholder: substitute the score produced by your readmission classifier\n",
"explanation = generate_explanation(sample_note, risk_score)\n",
"\n",
"print(\"Predicted Risk:\", risk_score)\n",
"# One call prints the header, a blank line, then the explanation (same output as two prints).\n",
"print(\"\\nGenerated Explanation:\\n\", explanation, sep=\"\\n\")\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
|