{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3f1e9a20",
   "metadata": {},
   "source": [
    "# Clinical readmission-risk explanations with BART\n",
    "\n",
    "Generates a free-text explanation for a readmission-risk score by prompting a pretrained clinical seq2seq model with a MIMIC-III discharge summary.\n",
    "\n",
    "**Caveats**\n",
    "\n",
    "- MIMIC-III is credentialed data covered by the PhysioNet data use agreement. Confirm that the Hugging Face mirror loaded below is permitted for your use, and **clear all outputs before sharing or committing**, because they can contain note-derived text.\n",
    "- The explanation is generated from the note text and the score only. It does not use the classifier's features or internals, so it should not be read as a faithful explanation of the prediction.\n",
    "- NOTE(review): confirm that `dmacres/bart-large-mimiciii-v2` was fine-tuned for instruction-style prompts; a summarization-style checkpoint may ignore the instruction and return a summary instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca5c8ec7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install into the kernel's own environment (%pip, not !pip).\n",
    "# TODO: pin exact versions (e.g. transformers==X.Y.Z) for reproducibility.\n",
    "%pip install -q transformers torch datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b7d2c41",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n",
    "\n",
    "# Configuration\n",
    "MODEL_NAME = \"dmacres/bart-large-mimiciii-v2\"\n",
    "DATASET_NAME = \"ntphuc149/MIMIC-III-Clinical-Database\"\n",
    "DATASET_CONFIG = \"NOTEEVENTS\"\n",
    "NOTE_CATEGORY = \"Discharge summary\"\n",
    "\n",
    "MAX_INPUT_TOKENS = 1024  # prompt truncation length (BART-large position embeddings cover 1024 tokens)\n",
    "MAX_NEW_TOKENS = 200  # upper bound on generated explanation length\n",
    "NUM_BEAMS = 4  # beam width for generation"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c0d4e12",
   "metadata": {},
   "source": [
    "## Load model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a4f6b33",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e7c9d54",
   "metadata": {},
   "source": [
    "## Load discharge summaries\n",
    "\n",
    "Loads the full `NOTEEVENTS` table and keeps only discharge summaries. The full table is large, so the first run can take a while."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d1a8e65",
   "metadata": {},
   "outputs": [],
   "source": [
    "notes_raw = load_dataset(DATASET_NAME, DATASET_CONFIG, split=\"train\")\n",
    "\n",
    "# Strip defensively in case CATEGORY values carry stray whitespace or are missing.\n",
    "notes = notes_raw.filter(\n",
    "    lambda row: (row[\"CATEGORY\"] or \"\").strip() == NOTE_CATEGORY\n",
    ")\n",
    "assert len(notes) > 0, f\"no notes with CATEGORY == {NOTE_CATEGORY!r}\"\n",
    "len(notes)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6f3b2a76",
   "metadata": {},
   "source": [
    "## Clean text\n",
    "\n",
    "Removes de-identification placeholders of the form `[** ... **]` and normalizes whitespace."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a5c3b87",
   "metadata": {},
   "outputs": [],
   "source": [
    "_PHI_PLACEHOLDER = re.compile(r\"\\[\\*\\*.*?\\*\\*\\]\", flags=re.DOTALL)\n",
    "_WHITESPACE = re.compile(r\"\\s+\")\n",
    "\n",
    "\n",
    "def clean_text(text, lowercase=True):\n",
    "    \"\"\"Normalize a MIMIC-III note for model input.\n",
    "\n",
    "    Removes de-identification placeholders of the form ``[** ... **]``\n",
    "    (including ones that span a line break), collapses every run of\n",
    "    whitespace to a single space, and optionally lowercases.\n",
    "\n",
    "    Args:\n",
    "        text: Raw note text. ``None`` or ``\"\"`` yields ``\"\"``.\n",
    "        lowercase: Lowercase the text first (default ``True``, the original behavior).\n",
    "\n",
    "    Returns:\n",
    "        The cleaned text with no leading or trailing whitespace.\n",
    "    \"\"\"\n",
    "    if not text:\n",
    "        return \"\"\n",
    "    if lowercase:\n",
    "        text = text.lower()\n",
    "    text = _PHI_PLACEHOLDER.sub(\"\", text)\n",
    "    # Removing placeholders leaves double spaces; collapse all whitespace in one pass.\n",
    "    return _WHITESPACE.sub(\" \", text).strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b6d4c98",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_note = clean_text(notes[0][\"TEXT\"])\n",
    "assert sample_note, \"first discharge summary is empty after cleaning\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c7e5da9",
   "metadata": {},
   "source": [
    "## Generate explanation\n",
    "\n",
    "The score and instruction are placed **before** the note. Discharge summaries often exceed the input limit, and the tokenizer truncates from the right by default. With the note first, truncation would silently cut off the instruction and the score."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad8f6eba",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_explanation(\n",
    "    note,\n",
    "    risk_score,\n",
    "    max_input_tokens=MAX_INPUT_TOKENS,\n",
    "    max_new_tokens=MAX_NEW_TOKENS,\n",
    "    num_beams=NUM_BEAMS,\n",
    "):\n",
    "    \"\"\"Generate a free-text explanation for a readmission-risk score.\n",
    "\n",
    "    The prompt puts the score and instruction *before* the note. The\n",
    "    tokenizer truncates from the right by default, so when a long note\n",
    "    exceeds ``max_input_tokens`` only the tail of the note is dropped,\n",
    "    never the instruction.\n",
    "\n",
    "    Args:\n",
    "        note: Cleaned discharge-summary text (see ``clean_text``).\n",
    "        risk_score: Classifier output, rendered with two decimals.\n",
    "            NOTE(review): presumably a probability in [0, 1]; confirm.\n",
    "        max_input_tokens: Prompt length limit in tokens.\n",
    "        max_new_tokens: Maximum number of tokens to generate.\n",
    "        num_beams: Beam width for beam search.\n",
    "\n",
    "    Returns:\n",
    "        The decoded model output with special tokens removed.\n",
    "    \"\"\"\n",
    "    prompt = (\n",
    "        f\"Predicted readmission risk: {risk_score:.2f}\\n\"\n",
    "        \"Explain the key clinical reasons for readmission risk.\\n\"\n",
    "        \"\\n\"\n",
    "        f\"Discharge summary:\\n{note}\\n\"\n",
    "    )\n",
    "\n",
    "    # Move inputs to wherever the model lives so this also works after model.to(\"cuda\").\n",
    "    inputs = tokenizer(\n",
    "        prompt,\n",
    "        return_tensors=\"pt\",\n",
    "        truncation=True,\n",
    "        max_length=max_input_tokens,\n",
    "    ).to(model.device)\n",
    "\n",
    "    outputs = model.generate(\n",
    "        **inputs,\n",
    "        max_new_tokens=max_new_tokens,\n",
    "        num_beams=num_beams,\n",
    "        early_stopping=True,\n",
    "    )\n",
    "    return tokenizer.decode(outputs[0], skip_special_tokens=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be9a7fcb",
   "metadata": {},
   "source": [
    "## Demo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfab8d0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Placeholder: replace with the score produced by your readmission classifier.\n",
    "risk_score = 0.72\n",
    "explanation = generate_explanation(sample_note, risk_score)\n",
    "\n",
    "print(\"Predicted Risk:\", risk_score)\n",
    "print(\"\\nGenerated Explanation:\\n\")\n",
    "print(explanation)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}