File size: 2,465 Bytes
dda3dc2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca5c8ec7",
   "metadata": {
    "vscode": {
     "languageId": "python"
    }
   },
   "outputs": [],
   "source": [
    "# Install dependencies into the environment of the running kernel.\n",
    "# %pip (unlike !pip) targets the interpreter that backs this kernel.\n",
    "%pip install -q transformers torch datasets\n",
    "\n",
    "from transformers import AutoTokenizer, AutoModelForSeq2SeqLM\n",
    "from datasets import load_dataset\n",
    "import re\n",
    "\n",
    "# Load the pretrained clinical BART checkpoint and its matching tokenizer.\n",
    "MODEL_NAME = \"dmacres/bart-large-mimiciii-v2\"\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "\n",
    "# Load the NOTEEVENTS train split and keep only the rows whose CATEGORY\n",
    "# is exactly \"Discharge summary\".\n",
    "notes = load_dataset(\n",
    "    path=\"ntphuc149/MIMIC-III-Clinical-Database\",\n",
    "    name=\"NOTEEVENTS\",\n",
    "    split=\"train\",\n",
    ").filter(lambda row: row[\"CATEGORY\"] == \"Discharge summary\")\n",
    "\n",
    "# Clean text\n",
    "def clean_text(text):\n",
    "    text = text.lower()\n",
    "    text = re.sub(r\"\\[\\*\\*.*?\\*\\*\\]\", \"\", text)\n",
    "    text = re.sub(r\"\\n+\", \" \", text)\n",
    "    return text.strip()\n",
    "\n",
    "# Use the first discharge summary as the demo input.\n",
    "first_note_text = notes[0][\"TEXT\"]\n",
    "sample_note = clean_text(first_note_text)\n",
    "\n",
    "# GenAI explanation function\n",
    "def generate_explanation(\n",
    "    note,\n",
    "    risk_score,\n",
    "    max_input_tokens=1024,\n",
    "    max_output_length=200,\n",
    "    num_beams=4,\n",
    "):\n",
    "    \"\"\"Generate text for a note and its predicted readmission risk.\n",
    "\n",
    "    The prompt is the note followed by the risk score and an instruction.\n",
    "    Tokenizer truncation cuts from the end of the prompt, so a note longer\n",
    "    than the input budget would push the risk score and the instruction\n",
    "    out of the model input. To prevent that, only the note is trimmed and\n",
    "    the trailing part of the prompt is always kept.\n",
    "\n",
    "    Args:\n",
    "        note: Discharge summary text placed at the start of the prompt.\n",
    "        risk_score: Predicted readmission risk, shown with two decimals.\n",
    "        max_input_tokens: Token budget for the whole tokenized prompt.\n",
    "        max_output_length: Value passed as max_length to model.generate.\n",
    "        num_beams: Beam width passed to model.generate.\n",
    "\n",
    "    Returns:\n",
    "        The first generated sequence decoded to text, special tokens removed.\n",
    "\n",
    "    NOTE(review): the checkpoint is loaded as a generic seq2seq model;\n",
    "    confirm that it follows the instruction instead of summarizing the note.\n",
    "    \"\"\"\n",
    "    header = \"\\nDischarge summary:\\n\"\n",
    "    footer = (\n",
    "        f\"\\n\\nPredicted readmission risk: {risk_score:.2f}\\n\\n\"\n",
    "        \"Explain the key clinical reasons for readmission risk.\\n\"\n",
    "    )\n",
    "\n",
    "    # Tokens used by everything except the note (special tokens included),\n",
    "    # plus some slack because joining text can shift token boundaries.\n",
    "    boundary_slack = 8\n",
    "    overhead = len(tokenizer(header + footer)[\"input_ids\"]) + boundary_slack\n",
    "    note_budget = max(max_input_tokens - overhead, 0)\n",
    "\n",
    "    # Request one token more than the budget: receiving it means the note\n",
    "    # overflowed and has to be cut. Notes that fit are left untouched.\n",
    "    note_ids = tokenizer(\n",
    "        note,\n",
    "        add_special_tokens=False,\n",
    "        truncation=True,\n",
    "        max_length=note_budget + 1,\n",
    "    )[\"input_ids\"]\n",
    "    if len(note_ids) > note_budget:\n",
    "        note = tokenizer.decode(note_ids[:note_budget], skip_special_tokens=True)\n",
    "\n",
    "    prompt = f\"{header}{note}{footer}\"\n",
    "\n",
    "    # Truncation stays on as a safety net; the note was already trimmed above.\n",
    "    inputs = tokenizer(\n",
    "        prompt,\n",
    "        return_tensors=\"pt\",\n",
    "        truncation=True,\n",
    "        max_length=max_input_tokens,\n",
    "    )\n",
    "\n",
    "    outputs = model.generate(\n",
    "        **inputs,\n",
    "        max_length=max_output_length,\n",
    "        num_beams=num_beams,\n",
    "        early_stopping=True,\n",
    "    )\n",
    "\n",
    "    return tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
    "\n",
    "# Demo: explain one note using a placeholder risk score.\n",
    "risk_score = 0.72  # from your classifier\n",
    "explanation = generate_explanation(note=sample_note, risk_score=risk_score)\n",
    "\n",
    "print(\"Predicted Risk:\", risk_score)\n",
    "print(\"\\nGenerated Explanation:\\n\", explanation, sep=\"\\n\")\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}