{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"pygments_lexer":"ipython3","nbconvert_exporter":"python","version":"3.6.4","file_extension":".py","codemirror_mode":{"name":"ipython","version":3},"name":"python","mimetype":"text/x-python"},"kaggle":{"accelerator":"none","dataSources":[],"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":false}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"# This Python 3 environment comes with many helpful analytics libraries installed\n# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python\n# For example, here's several helpful packages to load\n\nimport numpy as np # linear algebra\nimport pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)\n\n# Input data files are available in the read-only \"../input/\" directory\n# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory\n\nimport os\nfor dirname, _, filenames in os.walk('/kaggle/input'):\n for filename in filenames:\n print(os.path.join(dirname, filename))\n\n# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using \"Save & Run All\" \n# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session","metadata":{"_uuid":"8f2839f25d086af736a60e9eeb907d3b93b6e0e5","_cell_guid":"b1076dfc-b9ad-4769-8c92-a6c4dae69d19","trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Grad Cam Heatmap Generator","metadata":{}},{"cell_type":"code","source":"import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport matplotlib.pyplot as plt\nimport numpy as np\nfrom PIL import Image\nimport cv2\nimport os\nimport pandas as pd\nfrom scipy.ndimage import gaussian_filter\n\n# 
-------------------------------------------\n# Configuration\n# -------------------------------------------\nOUTPUT_DIR = \"/kaggle/working\"\nGRADCAM_OUTPUT_DIR = os.path.join(OUTPUT_DIR, 'gradcam_plus_plus_results')\nos.makedirs(GRADCAM_OUTPUT_DIR, exist_ok=True)\n\nprint(\"=\"*80)\nprint(\"HIERARCHICAL TRI-HEAD GRAD-CAM++ CONFIGURATION\")\nprint(\"=\"*80)\nprint(f\"Output directory: {GRADCAM_OUTPUT_DIR}\")\nprint(\"=\"*80)\n\n# -------------------------------------------\n# Class Mappings\n# -------------------------------------------\nDISEASE_CLASS_MAPPING = {\n 0: \"Breast_cancer\",\n 1: \"annrbc-anemia_processed\",\n 2: \"colon_processed\",\n 3: \"leukemia_processed\",\n 4: \"lung_processed\",\n 5: \"oral-cancer_processed\",\n 6: \"ovarian-cancer_processed\",\n 7: \"sickle-cell-new_processed\",\n 8: \"thalassemia_processed\",\n}\n\nSEVERITY_CLASS_MAPPING = {\n 0: \"Normal\",\n 1: \"Abnormal\",\n}\n\n# -------------------------------------------\n# Grad-CAM++ Implementation for ViT\n# -------------------------------------------\n# -------------------------------------------\n# Grad-CAM++ Implementation for ViT (FIXED for tuple outputs)\n# -------------------------------------------\nclass GradCAM:\n \"\"\"\n Standard Grad-CAM implementation for Vision Transformers\n Simplified - no second-order gradients, just straightforward CAM\n \"\"\"\n def __init__(self, model, target_layer):\n \"\"\"\n Args:\n model: Your Phase3 hierarchical model\n target_layer: The layer to hook (typically last transformer block)\n \"\"\"\n self.model = model\n self.target_layer = target_layer\n self.gradients = None\n self.activations = None\n \n # Register hooks\n self.handlers = []\n self._register_hooks()\n \n def _register_hooks(self):\n \"\"\"Register forward and backward hooks on target layer\"\"\"\n def forward_hook(module, input, output):\n # Handle tuple output (DINOv2 returns tuple)\n if isinstance(output, tuple):\n self.activations = output[0].detach()\n print(f\" šŸŖ 
Forward hook: Captured from tuple, shape {output[0].shape}\")\n else:\n self.activations = output.detach()\n print(f\" šŸŖ Forward hook: Captured tensor, shape {output.shape}\")\n \n def backward_hook(module, grad_input, grad_output):\n # Handle tuple output in gradients\n if isinstance(grad_output, tuple):\n grad = grad_output[0]\n if grad is not None:\n self.gradients = grad.detach()\n print(f\" šŸŖ Backward hook: Captured from tuple, shape {grad.shape}\")\n else:\n if grad_output is not None:\n self.gradients = grad_output.detach()\n print(f\" šŸŖ Backward hook: Captured tensor, shape {grad_output.shape}\")\n \n # Register hooks\n self.handlers.append(\n self.target_layer.register_forward_hook(forward_hook)\n )\n self.handlers.append(\n self.target_layer.register_full_backward_hook(backward_hook)\n )\n \n def remove_hooks(self):\n \"\"\"Remove all hooks\"\"\"\n for handle in self.handlers:\n handle.remove()\n \n def generate_cam(self, class_idx, logits):\n \"\"\"\n Generate standard Grad-CAM heatmap\n \n Args:\n class_idx: Target class index\n logits: Model output logits\n \n Returns:\n cam: Grad-CAM heatmap (H, W)\n \"\"\"\n # Zero gradients\n self.model.zero_grad()\n \n # Backward pass\n one_hot = torch.zeros_like(logits)\n one_hot[0, class_idx] = 1\n logits.backward(gradient=one_hot, retain_graph=True)\n \n # ========================================\n # šŸ” GRADIENT FLOW DEBUGGING\n # ========================================\n print(f\"\\n šŸ” GRADIENT FLOW CHECK:\")\n print(f\" {'='*60}\")\n \n # Check if gradients were captured\n if self.gradients is None:\n print(f\" āŒ CRITICAL: No gradients captured!\")\n return np.zeros((14, 14))\n else:\n print(f\" āœ… Gradients captured: {self.gradients.shape}\")\n print(f\" Min: {self.gradients.min().item():.6f}, Max: {self.gradients.max().item():.6f}\")\n print(f\" Mean: {self.gradients.mean().item():.6f}, Std: {self.gradients.std().item():.6f}\")\n \n # Check if activations were captured\n if self.activations 
is None:\n            print(f\"   ❌ CRITICAL: No activations captured!\")\n            return np.zeros((14, 14))\n        else:\n            print(f\"   ✅ Activations captured: {self.activations.shape}\")\n            print(f\"      Min: {self.activations.min().item():.6f}, Max: {self.activations.max().item():.6f}\")\n\n        print(f\"   {'='*60}\\n\")\n        # ========================================\n\n        # Standard Grad-CAM computation.\n        # Assumed layout (the token/channel reductions below rely on it; TODO confirm per backbone):\n        #   gradients:   [B, N, D]\n        #   activations: [B, N, D]\n\n        # Step 1: channel weights = gradient averaged over all N tokens\n        # (the CLS / prefix tokens are included in this average)\n        weights = self.gradients.mean(dim=1, keepdim=True)  # [B, 1, D]\n\n        print(f\"   📊 Weights (channel importance):\")\n        print(f\"      Shape: {weights.shape}\")\n        print(f\"      Min: {weights.min().item():.6f}, Max: {weights.max().item():.6f}\")\n\n        # Step 2: weighted sum over channels -> one score per token, [B, N]\n        cam = (weights * self.activations).sum(dim=2)\n\n        print(f\"   📊 CAM before ReLU:\")\n        print(f\"      Shape: {cam.shape}\")\n        print(f\"      Min: {cam.min().item():.6f}, Max: {cam.max().item():.6f}\")\n\n        # Step 3: ReLU is deliberately disabled, so negative evidence is kept and\n        # only rescaled by the min-max normalisation in Step 7. For textbook\n        # Grad-CAM, re-enable the line below.\n        # cam = F.relu(cam)\n        print(f\"   📊 CAM after Step 3 (ReLU disabled, values unchanged):\")\n        print(f\"      Min: {cam.min().item():.6f}, Max: {cam.max().item():.6f}\")\n\n        # Step 4: drop the batch dimension (only sample 0 is visualised)\n        cam = cam[0]  # [N]\n\n        # Step 5: drop the first token, assumed to be CLS\n        if cam.shape[0] > 1:\n            cam = cam[1:]\n            print(f\"   🎯 Removed CLS token, remaining tokens: {cam.shape[0]}\")\n\n        # Step 6: reshape to the square patch grid. If the remaining count is not\n        # a perfect square (e.g. register tokens after CLS), drop the surplus\n        # leading tokens instead of failing inside reshape(); the trailing\n        # grid_size**2 tokens are assumed to be the patch tokens.\n        grid_size = int(np.sqrt(cam.shape[0]))\n        surplus = cam.shape[0] - grid_size * grid_size\n        if surplus > 0:\n            cam = cam[surplus:]\n            print(f\"   ⚠️ Dropped {surplus} extra non-patch token(s) before reshape\")\n        print(f\"   📐 Grid size: {grid_size}x{grid_size}\")\n\n        cam = cam.reshape(grid_size, grid_size)\n\n        # Step 7: min-max normalise to [0, 1]\n        cam_min = cam.min()\n        cam_max = cam.max()\n\n        if cam_max > cam_min:\n            cam = (cam - cam_min) / (cam_max - cam_min)\n            print(f\"   ✅ Normalized CAM: min={cam.min().item():.4f}, max={cam.max().item():.4f}\")\n        
else:\n            print(f\"   ⚠️ WARNING: No variation in CAM (all same value)\")\n            cam = torch.zeros_like(cam)\n\n        print(f\"   📊 Final CAM mean: {cam.mean().item():.6f}\\n\")\n\n        return cam.cpu().numpy()\n\n\nclass GradCAMVisionTransformer:\n    \"\"\"\n    Wrapper to apply standard Grad-CAM to Vision Transformer models.\n\n    Locates the last block of ``model.backbone.vit.encoder.layer`` once and\n    builds a fresh, short-lived ``GradCAM`` (with its own hooks) for every\n    heatmap request.\n    \"\"\"\n    def __init__(self, model, device):\n        \"\"\"\n        Args:\n            model: hierarchical model exposing ``backbone.vit.encoder.layer``;\n                called as ``model([image_tensor], enable_gradients=True)`` and\n                expected to return (disease_logits, severity_logits_dict,\n                stage_logits, extra)\n            device: torch device that input tensors are moved to\n        \"\"\"\n        self.model = model\n        self.device = device\n        self.model.eval()\n\n        # Find the last transformer block\n        self.target_layer = self._find_target_layer()\n        print(f\"  🎯 Target layer for Grad-CAM: {self.target_layer}\")\n\n    def _find_target_layer(self):\n        \"\"\"\n        Find the last transformer block in a DINOv2/Phikon-style model.\n\n        Returns:\n            ``self.model.backbone.vit.encoder.layer[-1]``\n\n        Raises:\n            ValueError: if any part of that attribute path is missing; the\n                underlying AttributeError is chained as ``__cause__``.\n        \"\"\"\n        print(\"\\n🔍 Searching for target layer in DINOv2 architecture...\")\n\n        try:\n            if hasattr(self.model, 'backbone'):\n                vit_model = self.model.backbone.vit\n                print(f\"✅ Found backbone.vit: {type(vit_model).__name__}\")\n            else:\n                raise AttributeError(\"No backbone found\")\n\n            if hasattr(vit_model, 'encoder') and hasattr(vit_model.encoder, 'layer'):\n                num_layers = len(vit_model.encoder.layer)\n                last_layer = vit_model.encoder.layer[-1]\n                print(f\"✅ Found encoder with {num_layers} layers\")\n                print(f\"✅ Target layer: encoder.layer[-1] (layer {num_layers-1})\")\n                return last_layer\n            else:\n                raise AttributeError(\"No encoder.layer found\")\n\n        except AttributeError as e:\n            print(f\"❌ Error: {e}\")\n            # Chain the cause so the missing attribute stays in the traceback\n            raise ValueError(\"Could not find DINOv2 encoder layers\") from e\n\n    def generate_heatmap(self, image_tensor, target_class_idx, head_type='disease'):\n        \"\"\"\n        Generate a Grad-CAM heatmap for one head of the hierarchical model.\n\n        Args:\n            image_tensor: input image tensor, presumably [1, 3, H, W]\n            target_class_idx: class index (of the chosen head) to explain\n            head_type: 'disease', 'severity', or 'stage'. For 'severity', the\n                severity head is picked from this forward pass's own disease\n                argmax via DISEASE_CLASS_MAPPING.\n\n        Returns:\n            heatmap: 2-D numpy array of shape (H, W) -- the patch-grid CAM\n                resized with cubic interpolation, then Gaussian-smoothed\n                (sigma=2). Values may slightly leave [0, 1] after resizing.\n\n        Raises:\n            ValueError: for an unknown ``head_type``.\n        \"\"\"\n        # Fresh Grad-CAM instance: registers hooks on the target layer\n        gradcam = GradCAM(self.model, self.target_layer)\n\n        try:\n            # Forward pass with gradients enabled\n            image_tensor = image_tensor.to(self.device)\n            image_tensor.requires_grad = True\n\n            # Get logits based on head type\n            disease_logits, severity_logits, stage_logits, _ = self.model([image_tensor], enable_gradients=True)\n\n            if head_type == 'disease':\n                logits = disease_logits\n            elif head_type == 'severity':\n                disease_pred_idx = disease_logits.argmax(dim=1).item()\n                disease_name = DISEASE_CLASS_MAPPING.get(disease_pred_idx, f\"Unknown_{disease_pred_idx}\")\n                logits = severity_logits[disease_name]\n            elif head_type == 'stage':\n                logits = stage_logits\n            else:\n                raise ValueError(f\"Unknown head_type: {head_type}\")\n\n            # Generate CAM\n            cam = gradcam.generate_cam(target_class_idx, logits)\n        finally:\n            # Always detach the hooks, even if the forward/backward pass raised:\n            # the analysis loop catches per-image errors and keeps going, so a\n            # leaked pair of hooks would otherwise stay on the shared target\n            # layer and fire on every later pass.\n            gradcam.remove_hooks()\n\n        # Resize to match input image size\n        H, W = image_tensor.shape[2], image_tensor.shape[3]\n        cam_resized = cv2.resize(cam, (W, H), interpolation=cv2.INTER_CUBIC)\n\n        # Smooth the heatmap\n        cam_smooth = gaussian_filter(cam_resized, sigma=2)\n\n        return cam_smooth\n# -------------------------------------------\n# Helper Functions\n# -------------------------------------------\ndef create_gradcam_overlay(image_array, heatmap, alpha=0.5, colormap='jet'):\n    \"\"\"\n    Blend a heatmap onto the original image using a Matplotlib colormap.\n\n    With the default 'jet' colormap, blue = low and red = high activation.\n\n    Args:\n        image_array: original image as numpy array (H, W, 3); assumed uint8\n            in [0, 255] since it is divided by 255\n        heatmap: 2-D heatmap, expected in [0, 1] (clipped to that range);\n            resized to (H, W) if its shape differs\n        alpha: weight of the heatmap in the blend (0 = image only)\n        colormap: Matplotlib colormap name\n\n    Returns:\n        Overlayed image as numpy array (H, W, 3) in range [0, 1]\n    \"\"\"\n    # Normalize image to [0, 1]\n    img_normalized = image_array.astype(np.float32) / 255.0\n\n    # Resize heatmap to match image size if needed\n    target_h, target_w = img_normalized.shape[:2]\n    if heatmap.shape != (target_h, target_w):\n        print(f\"   📐 Resizing heatmap from {heatmap.shape} to ({target_h}, {target_w})\")\n        heatmap = cv2.resize(heatmap, (target_w, target_h), interpolation=cv2.INTER_CUBIC)\n    
\n    # Ensure heatmap is in [0, 1]\n    heatmap = np.clip(heatmap, 0, 1)\n\n    # Apply colormap (matplotlib.colormaps registry API)\n    import matplotlib\n    cmap = matplotlib.colormaps.get_cmap(colormap)\n    heatmap_colored = cmap(heatmap)[:, :, :3]  # Remove alpha channel\n\n    # Blend with original image\n    overlay = img_normalized * (1 - alpha) + heatmap_colored * alpha\n    overlay = np.clip(overlay, 0, 1)\n\n    return overlay\n\ndef load_and_preprocess_image(img_path, target_size=224):\n    \"\"\"\n    Load an image for both visualization and model input.\n\n    Whole-slide files (.svs/.ndpi, and .tif/.tiff when OpenSlide can open\n    them) are reduced to an OpenSlide thumbnail; everything else is read\n    with PIL. Every branch ends with an RGB image resized to exactly\n    (target_size, target_size), so ``img_array`` matches heatmaps rendered\n    at that size.\n\n    Depends on a global ``test_transform`` (PIL image -> tensor) defined in\n    an earlier cell. ``openslide`` is imported lazily, only for slides.\n\n    Returns:\n        (img_pil, img_array, img_tensor) with a leading batch dim added to\n        the transformed tensor, or (None, None, None) on any failure (the\n        error is printed).\n    \"\"\"\n    def _read_slide_thumbnail(path):\n        # Lazy import so plain images work without OpenSlide installed\n        import openslide\n        slide = openslide.OpenSlide(path)\n        try:\n            return slide.get_thumbnail((target_size, target_size))\n        finally:\n            slide.close()\n\n    try:\n        lower_path = img_path.lower()\n        if lower_path.endswith(('.svs', '.ndpi')):\n            img_pil = _read_slide_thumbnail(img_path)\n        elif lower_path.endswith(('.tif', '.tiff')):\n            try:\n                img_pil = _read_slide_thumbnail(img_path)\n            except Exception:\n                # Not a readable slide (or OpenSlide missing): plain TIFF via PIL\n                img_pil = Image.open(img_path)\n        else:\n            img_pil = Image.open(img_path)\n\n        # get_thumbnail() keeps the slide's aspect ratio, so force the exact\n        # square size for every branch.\n        img_pil = img_pil.convert('RGB').resize((target_size, target_size), Image.BILINEAR)\n\n        img_array = np.array(img_pil)\n        img_tensor = test_transform(img_pil).unsqueeze(0)\n\n        return img_pil, img_array, img_tensor\n\n    except Exception as e:\n        print(f\"Error loading image {img_path}: {e}\")\n        return None, None, None\n\n\ndef extract_stage_label(img_path):\n    \"\"\"\n    Extract stage label using strict hierarchy rules.\n\n    Valid structures:\n    - .../abnormal/test/image.png          -> no stage     -> return -1\n    - .../normal/test/image.png            -> no stage     -> return -1\n    - .../abnormal/{stage}/test/image.png  -> stage exists -> return '{stage}'\n\n    Also returns -1 when the path has no 'test' folder or nothing above it.\n    \"\"\"\n    path_parts = img_path.split(os.sep)\n\n    try:\n        # Find 'test' folder\n        test_idx = path_parts.index('test')\n\n        # 'test' is the first component: index -1 would wrap to the filename\n        if test_idx == 0:\n            return -1\n\n        # Folder immediately above 'test'\n        candidate = path_parts[test_idx - 1]\n\n        # If normal/abnormal is directly above test → no stage\n        if candidate.lower() in ('abnormal', 'normal'):\n            return -1\n\n        # Otherwise, this folder is the stage\n        return candidate\n\n    except 
(ValueError, IndexError):\n # 'test' not found or malformed path\n return -1\n\n\n# -------------------------------------------\n# Main Tri-Head Grad-CAM++ Analysis\n# -------------------------------------------\ndef run_tri_head_gradcam_plus_plus_analysis(model, device, collected_images):\n \"\"\"\n Run Grad-CAM++ analysis on disease head (Level 1), severity head (Level 2), and stage head (Level 3)\n Stage head is only analyzed when ground truth is abnormal and stage label is not -1\n Returns comprehensive dictionaries with all analysis results\n \"\"\"\n print(\"\\n\" + \"=\"*80)\n print(\"šŸ” STARTING TRI-HEAD GRAD-CAM++ ANALYSIS\")\n print(\"=\"*80)\n \n if not collected_images:\n print(\"āŒ No images provided\")\n return []\n \n print(f\"āœ… Processing {len(collected_images)} images\\n\")\n \n # Initialize Grad-CAM++ wrapper\n gradcam_wrapper = GradCAMVisionTransformer(model, device)\n\n \n all_results = []\n \n for idx, img_path in enumerate(collected_images):\n print(f\"\\n[{idx+1}/{len(collected_images)}] Processing: {os.path.basename(img_path)}\")\n \n try:\n # Load image\n img_pil, img_array, img_tensor = load_and_preprocess_image(img_path, target_size=224)\n \n if img_tensor is None:\n print(f\" āŒ Failed to load image\")\n continue\n \n # Extract metadata\n path_parts = img_path.split(os.sep)\n dataset_name = \"Unknown\"\n true_label = \"unknown\"\n stage_label = extract_stage_label(img_path)\n \n for part in path_parts:\n if part in [\"ovarian-cancer_processed\", \"oral-cancer_processed\", \n \"Breast_cancer\", \"colon_processed\", \"lung_processed\",\n \"annrbc-anemia_processed\", \"leukemia_processed\",\n \"sickle-cell-new_processed\", \"thalassemia_processed\"]:\n dataset_name = part\n if part in [\"normal\", \"abnormal\"]:\n true_label = part\n break\n \n # ===== STEP 1: Get predictions from all heads =====\n with torch.no_grad():\n img_tensor_device = img_tensor.to(device)\n disease_logits, severity_logits_dict, stage_logits, _ = 
model([img_tensor_device])\n \n # Disease prediction (Level 1)\n disease_pred_idx = disease_logits.argmax(dim=1).item()\n disease_probs = F.softmax(disease_logits, dim=1)\n disease_confidence = disease_probs[0, disease_pred_idx].item()\n disease_name = DISEASE_CLASS_MAPPING.get(disease_pred_idx, f\"Unknown_{disease_pred_idx}\")\n disease_all_probs = disease_probs[0].cpu().numpy()\n \n print(f\" šŸ“Š Level 1 (Disease): {disease_name}\")\n print(f\" Index: {disease_pred_idx}, Confidence: {disease_confidence:.4f}\")\n \n # Severity prediction (Level 2)\n severity_logits = severity_logits_dict[disease_name]\n severity_pred_idx = severity_logits.argmax(dim=1).item()\n severity_probs = F.softmax(severity_logits, dim=1)\n severity_confidence = severity_probs[0, severity_pred_idx].item()\n severity_label_text = SEVERITY_CLASS_MAPPING.get(severity_pred_idx, f\"Unknown_{severity_pred_idx}\")\n severity_all_probs = severity_probs[0].cpu().numpy()\n \n print(f\" šŸ“Š Level 2 (Severity): {severity_label_text}\")\n print(f\" Index: {severity_pred_idx}, Confidence: {severity_confidence:.4f}\")\n \n # Stage prediction (Level 3) - if available\n stage_pred_idx = None\n stage_confidence = None\n stage_all_probs = None\n \n if stage_logits is not None:\n stage_pred_idx = stage_logits.argmax(dim=1).item()\n stage_probs = F.softmax(stage_logits, dim=1)\n stage_confidence = stage_probs[0, stage_pred_idx].item()\n stage_all_probs = stage_probs[0].cpu().numpy()\n \n print(f\" šŸ“Š Level 3 (Stage): Stage {stage_pred_idx}\")\n print(f\" Confidence: {stage_confidence:.4f}\")\n print(f\" Ground Truth Stage: {stage_label}\")\n \n # ===== STEP 2: Generate Grad-CAM++ for disease head =====\n print(f\"\\n šŸ”„ Generating Grad-CAM++ for Disease Head...\")\n disease_heatmap = gradcam_wrapper.generate_heatmap(\n img_tensor.clone(),\n disease_pred_idx,\n head_type='disease'\n )\n \n # Create overlay\n disease_overlay = create_gradcam_overlay(img_array, disease_heatmap, alpha=0.5)\n \n print(f\" āœ… 
Disease heatmap generated\")\n print(f\" Min: {disease_heatmap.min():.4f}, Max: {disease_heatmap.max():.4f}\")\n \n # ===== STEP 3: Generate Grad-CAM++ for severity head =====\n print(f\" šŸ”„ Generating Grad-CAM++ for Severity Head...\")\n severity_heatmap = gradcam_wrapper.generate_heatmap(\n img_tensor.clone(),\n severity_pred_idx,\n head_type='severity'\n )\n \n # Create overlay\n severity_overlay = create_gradcam_overlay(img_array, severity_heatmap, alpha=0.5)\n \n print(f\" āœ… Severity heatmap generated\")\n print(f\" Min: {severity_heatmap.min():.4f}, Max: {severity_heatmap.max():.4f}\")\n \n # ===== STEP 4: Generate Grad-CAM++ for stage head (conditional) =====\n stage_heatmap = None\n stage_overlay = None\n include_stage_analysis = False\n \n # Check conditions: abnormal ground truth AND stage label != -1\n if true_label == \"abnormal\" and stage_label != -1 and stage_logits is not None:\n include_stage_analysis = True\n print(f\" šŸ”„ Generating Grad-CAM++ for Stage Head (GT: abnormal, Stage: {stage_label})...\")\n \n stage_heatmap = gradcam_wrapper.generate_heatmap(\n img_tensor.clone(),\n stage_pred_idx,\n head_type='stage'\n )\n \n # Create overlay\n stage_overlay = create_gradcam_overlay(img_array, stage_heatmap, alpha=0.5)\n \n print(f\" āœ… Stage heatmap generated\")\n print(f\" Min: {stage_heatmap.min():.4f}, Max: {stage_heatmap.max():.4f}\")\n else:\n reason = []\n if true_label != \"abnormal\":\n reason.append(f\"true_label='{true_label}'\")\n if stage_label == -1:\n reason.append(\"stage_label=-1\")\n if stage_logits is None:\n reason.append(\"stage_logits=None\")\n print(f\" ā­ļø Skipping Stage Head Analysis ({', '.join(reason)})\")\n \n # ===== STEP 5: Create Union Heatmap =====\n if include_stage_analysis:\n # Average of all three heatmaps\n union_heatmap = (disease_heatmap + severity_heatmap + stage_heatmap) / 3.0\n print(f\" šŸ“Š Union Heatmap: Average of 3 heads (Disease + Severity + Stage)\")\n else:\n # Average of two heatmaps\n 
union_heatmap = (disease_heatmap + severity_heatmap) / 2.0\n print(f\" šŸ“Š Union Heatmap: Average of 2 heads (Disease + Severity)\")\n \n union_overlay = create_gradcam_overlay(img_array, union_heatmap, alpha=0.5)\n \n # ===== STEP 6: Calculate statistics =====\n disease_mean_activation = float(disease_heatmap.mean())\n disease_max_activation = float(disease_heatmap.max())\n \n severity_mean_activation = float(severity_heatmap.mean())\n severity_max_activation = float(severity_heatmap.max())\n \n stage_mean_activation = None\n stage_max_activation = None\n if stage_heatmap is not None:\n stage_mean_activation = float(stage_heatmap.mean())\n stage_max_activation = float(stage_heatmap.max())\n \n union_mean_activation = float(union_heatmap.mean())\n union_max_activation = float(union_heatmap.max())\n \n # ===== STEP 7: Compile comprehensive results dictionary =====\n result_dict = {\n # ===== Image Information =====\n 'filename': os.path.basename(img_path),\n 'full_path': img_path,\n 'dataset_name': dataset_name,\n 'true_label': true_label,\n 'stage_label': stage_label,\n 'include_stage_analysis': include_stage_analysis,\n \n # ===== Original Image =====\n 'image': img_array,\n \n # ===== Level 1: Disease Head Results =====\n 'level1_disease': {\n 'predicted_class': disease_name,\n 'predicted_idx': disease_pred_idx,\n 'confidence': disease_confidence,\n 'all_probabilities': disease_all_probs,\n 'heatmap_raw': disease_heatmap,\n 'heatmap_overlay': disease_overlay,\n 'activation_stats': {\n 'mean': disease_mean_activation,\n 'max': disease_max_activation,\n }\n },\n \n # ===== Level 2: Severity Head Results =====\n 'level2_severity': {\n 'predicted_class': severity_label_text,\n 'predicted_idx': severity_pred_idx,\n 'confidence': severity_confidence,\n 'all_probabilities': severity_all_probs,\n 'heatmap_raw': severity_heatmap,\n 'heatmap_overlay': severity_overlay,\n 'activation_stats': {\n 'mean': severity_mean_activation,\n 'max': severity_max_activation,\n }\n },\n 
\n # ===== Level 3: Stage Head Results (conditional) =====\n 'level3_stage': {\n 'predicted_idx': stage_pred_idx,\n 'confidence': stage_confidence,\n 'all_probabilities': stage_all_probs,\n 'heatmap_raw': stage_heatmap,\n 'heatmap_overlay': stage_overlay,\n 'activation_stats': {\n 'mean': stage_mean_activation,\n 'max': stage_max_activation,\n } if stage_heatmap is not None else None\n },\n \n # ===== Union Results =====\n 'union': {\n 'heatmap_raw': union_heatmap,\n 'heatmap_overlay': union_overlay,\n 'num_heads_averaged': 3 if include_stage_analysis else 2,\n 'activation_stats': {\n 'mean': union_mean_activation,\n 'max': union_max_activation,\n }\n },\n \n # ===== Legacy Fields =====\n 'disease_heatmap': disease_heatmap,\n 'severity_heatmap': severity_heatmap,\n 'stage_heatmap': stage_heatmap,\n 'disease_pred': disease_name,\n 'disease_idx': disease_pred_idx,\n 'disease_conf': disease_confidence,\n 'severity_pred': severity_label_text,\n 'severity_idx': severity_pred_idx,\n 'severity_conf': severity_confidence,\n 'stage_pred_idx': stage_pred_idx,\n 'stage_conf': stage_confidence,\n }\n \n all_results.append(result_dict)\n \n print(f\" āœ… Completed tri-head Grad-CAM++ analysis\")\n \n except Exception as e:\n print(f\" āŒ Error: {e}\")\n import traceback\n traceback.print_exc()\n continue\n \n print(\"\\n\" + \"=\"*80)\n print(\"āœ… TRI-HEAD GRAD-CAM++ ANALYSIS COMPLETE\")\n print(f\"šŸ“¦ Generated {len(all_results)} comprehensive result dictionaries\")\n print(\"=\"*80)\n \n return all_results\n\n\n# -------------------------------------------\n# Visualization Function\n# -------------------------------------------\ndef display_tri_head_gradcam_grid(results):\n \"\"\"\n Display grid: each row = one image with 5 columns (or 4 if no stage analysis)\n [Original | Disease Grad-CAM++ | Severity Grad-CAM++ | Stage Grad-CAM++ (if available) | Union]\n \"\"\"\n if not results:\n print(\"No results to display\")\n return\n \n num_images = len(results)\n max_cols = 5\n 
\n # Create figure\n fig, axes = plt.subplots(num_images, max_cols, figsize=(35, 7 * num_images))\n \n # Handle single image case\n if num_images == 1:\n axes = axes.reshape(1, -1)\n \n cmap = plt.cm.jet\n \n for i, result in enumerate(results):\n has_stage = result['include_stage_analysis']\n \n # Column 1: Original Image\n axes[i, 0].imshow(result['image'])\n title_text = (\n f\"Original Image {i+1}\\n\"\n f\"Dataset: {result['dataset_name']}\\n\"\n f\"True Label: {result['true_label']}\\n\"\n f\"Stage GT: {result['stage_label']}\\n\"\n f\"File: {result['filename']}\"\n )\n axes[i, 0].set_title(title_text, fontsize=10, fontweight='bold', pad=10)\n axes[i, 0].axis('off')\n \n # Column 2: Disease Head Grad-CAM++ (Level 1)\n axes[i, 1].imshow(result['image'])\n \n disease_heatmap = result['level1_disease']['heatmap_raw']\n \n im1 = axes[i, 1].imshow(\n disease_heatmap,\n cmap=cmap,\n alpha=0.5,\n vmin=0,\n vmax=1\n )\n \n cbar1 = plt.colorbar(im1, ax=axes[i, 1], fraction=0.046, pad=0.04)\n cbar1.set_label('Activation', rotation=270, labelpad=15)\n \n disease_title = (\n f\"Level 1: Disease Head\\n\"\n f\"Predicted: {result['level1_disease']['predicted_class']}\\n\"\n f\"Confidence: {result['level1_disease']['confidence']:.4f}\\n\"\n f\"Mean Act: {result['level1_disease']['activation_stats']['mean']:.4f}\"\n )\n axes[i, 1].set_title(disease_title, fontsize=10, fontweight='bold', pad=10)\n axes[i, 1].axis('off')\n \n # Column 3: Severity Head Grad-CAM++ (Level 2)\n axes[i, 2].imshow(result['image'])\n \n severity_heatmap = result['level2_severity']['heatmap_raw']\n \n im2 = axes[i, 2].imshow(\n severity_heatmap,\n cmap=cmap,\n alpha=0.5,\n vmin=0,\n vmax=1\n )\n \n cbar2 = plt.colorbar(im2, ax=axes[i, 2], fraction=0.046, pad=0.04)\n cbar2.set_label('Activation', rotation=270, labelpad=15)\n \n severity_title = (\n f\"Level 2: Severity Head\\n\"\n f\"Predicted: {result['level2_severity']['predicted_class']}\\n\"\n f\"Confidence: 
{result['level2_severity']['confidence']:.4f}\\n\"\n f\"Mean Act: {result['level2_severity']['activation_stats']['mean']:.4f}\"\n )\n axes[i, 2].set_title(severity_title, fontsize=10, fontweight='bold', pad=10)\n axes[i, 2].axis('off')\n \n # Column 4: Stage Head Grad-CAM++ (Level 3) - Conditional\n if has_stage:\n axes[i, 3].imshow(result['image'])\n \n stage_heatmap = result['level3_stage']['heatmap_raw']\n \n im3 = axes[i, 3].imshow(\n stage_heatmap,\n cmap=cmap,\n alpha=0.5,\n vmin=0,\n vmax=1\n )\n \n cbar3 = plt.colorbar(im3, ax=axes[i, 3], fraction=0.046, pad=0.04)\n cbar3.set_label('Activation', rotation=270, labelpad=15)\n \n stage_title = (\n f\"Level 3: Stage Head\\n\"\n f\"Predicted: Stage {result['level3_stage']['predicted_idx']}\\n\"\n f\"Confidence: {result['level3_stage']['confidence']:.4f}\\n\"\n f\"Mean Act: {result['level3_stage']['activation_stats']['mean']:.4f}\"\n )\n axes[i, 3].set_title(stage_title, fontsize=10, fontweight='bold', pad=10)\n axes[i, 3].axis('off')\n else:\n # Display placeholder text\n axes[i, 3].text(\n 0.5, 0.5,\n \"Stage Analysis\\nNot Applicable\\n\\n\" +\n (f\"Reason: GT={result['true_label']}\\n\" if result['true_label'] != 'abnormal' else \"\") +\n (f\"Stage={result['stage_label']}\" if result['stage_label'] == -1 else \"\"),\n ha='center', va='center',\n fontsize=12, color='gray',\n transform=axes[i, 3].transAxes\n )\n axes[i, 3].axis('off')\n \n # Column 5: Union Grad-CAM++\n axes[i, 4].imshow(result['image'])\n \n union_heatmap = result['union']['heatmap_raw']\n \n im4 = axes[i, 4].imshow(\n union_heatmap,\n cmap=cmap,\n alpha=0.5,\n vmin=0,\n vmax=1\n )\n \n cbar4 = plt.colorbar(im4, ax=axes[i, 4], fraction=0.046, pad=0.04)\n cbar4.set_label('Activation', rotation=270, labelpad=15)\n \n union_title = (\n f\"Union: Combined Grad-CAM++\\n\"\n f\"Averaged {result['union']['num_heads_averaged']} Heads\\n\"\n f\"Disease: {result['level1_disease']['predicted_class']}\\n\"\n f\"Severity: 
{result['level2_severity']['predicted_class']}\"\n )\n if has_stage:\n union_title += f\"\\nStage: {result['level3_stage']['predicted_idx']}\"\n union_title += f\"\\nMean Act: {result['union']['activation_stats']['mean']:.4f}\"\n \n axes[i, 4].set_title(union_title, fontsize=10, fontweight='bold', pad=10)\n axes[i, 4].axis('off')\n \n # Print statistics\n print(f\"\\nšŸ“Š Image {i+1} ({result['filename']}) Statistics:\")\n print(f\" Disease Head: {result['level1_disease']['predicted_class']} \"\n f\"({result['level1_disease']['confidence']:.4f})\")\n print(f\" Mean Activation: {result['level1_disease']['activation_stats']['mean']:.4f}, \"\n f\"Max: {result['level1_disease']['activation_stats']['max']:.4f}\")\n \n print(f\" Severity Head: {result['level2_severity']['predicted_class']} \"\n f\"({result['level2_severity']['confidence']:.4f})\")\n print(f\" Mean Activation: {result['level2_severity']['activation_stats']['mean']:.4f}, \"\n f\"Max: {result['level2_severity']['activation_stats']['max']:.4f}\")\n \n if has_stage:\n print(f\" Stage Head: Stage {result['level3_stage']['predicted_idx']} \"\n f\"({result['level3_stage']['confidence']:.4f})\")\n print(f\" Mean Activation: {result['level3_stage']['activation_stats']['mean']:.4f}, \"\n f\"Max: {result['level3_stage']['activation_stats']['max']:.4f}\")\n else:\n print(f\" Stage Head: Not analyzed (GT: {result['true_label']}, Stage: {result['stage_label']})\")\n \n print(f\" Union Heatmap ({result['union']['num_heads_averaged']} heads):\")\n print(f\" Mean Activation: {result['union']['activation_stats']['mean']:.4f}, \"\n f\"Max: {result['union']['activation_stats']['max']:.4f}\")\n \n plt.suptitle(\n 'Hierarchical Model - Tri-Head Grad-CAM++ Analysis with Union\\n'\n 'Level 1: Disease | Level 2: Severity | Level 3: Stage (Conditional) | Union: Combined Analysis\\n'\n 'Red = High Activation | Blue = Low Activation',\n fontsize=16,\n fontweight='bold',\n y=0.998\n )\n \n plt.tight_layout()\n \n # Save\n 
grid_save_path = os.path.join(GRADCAM_OUTPUT_DIR, 'tri_head_union_gradcam_plus_plus_analysis.png')\n plt.savefig(grid_save_path, dpi=150, bbox_inches='tight')\n print(f\"\\nāœ… Grid saved to: {grid_save_path}\")\n \n plt.show()\n\n\n# -------------------------------------------\n# Execute Analysis\n# -------------------------------------------\nprint(\"\\n\" + \"=\"*80)\nprint(\"CHECKING FOR COLLECTED IMAGES\")\nprint(\"=\"*80)\n\ntry:\n if 'collected_images' in locals() or 'collected_images' in globals():\n print(f\"āœ… Found collected_images with {len(collected_images)} images\\n\")\n \n # Run tri-head Grad-CAM++ analysis\n gradcam_results = run_tri_head_gradcam_plus_plus_analysis(\n model,\n device,\n collected_images\n )\n \n # Display results\n if gradcam_results:\n display_tri_head_gradcam_grid(gradcam_results)\n \n # Save summary\n results_summary = []\n for r in gradcam_results:\n summary_row = {\n 'filename': r['filename'],\n 'dataset': r['dataset_name'],\n 'true_label': r['true_label'],\n 'stage_gt': r['stage_label'],\n 'disease_predicted': r['level1_disease']['predicted_class'],\n 'disease_confidence': r['level1_disease']['confidence'],\n 'disease_mean_activation': r['level1_disease']['activation_stats']['mean'],\n 'severity_predicted': r['level2_severity']['predicted_class'],\n 'severity_confidence': r['level2_severity']['confidence'],\n 'severity_mean_activation': r['level2_severity']['activation_stats']['mean'],\n }\n \n if r['include_stage_analysis']:\n summary_row.update({\n 'stage_predicted': r['level3_stage']['predicted_idx'],\n 'stage_confidence': r['level3_stage']['confidence'],\n 'stage_mean_activation': r['level3_stage']['activation_stats']['mean'],\n })\n else:\n summary_row.update({\n 'stage_predicted': 'N/A',\n 'stage_confidence': 'N/A',\n 'stage_mean_activation': 'N/A',\n })\n \n summary_row['union_heads_averaged'] = r['union']['num_heads_averaged']\n summary_row['union_mean_activation'] = r['union']['activation_stats']['mean']\n \n 
results_summary.append(summary_row)\n \n summary_df = pd.DataFrame(results_summary)\n summary_path = os.path.join(GRADCAM_OUTPUT_DIR, 'tri_head_union_gradcam_plus_plus_summary.csv')\n summary_df.to_csv(summary_path, index=False)\n print(f\"\\nāœ… Summary saved to: {summary_path}\")\n \n print(\"\\n\" + \"=\"*80)\n print(\"TRI-HEAD GRAD-CAM++ ANALYSIS SUMMARY\")\n print(\"=\"*80)\n print(summary_df.to_string(index=False))\n print(\"=\"*80)\n \n # Print structure of results for reference\n print(\"\\n\" + \"=\"*80)\n print(\"šŸ“¦ RESULTS STRUCTURE\")\n print(\"=\"*80)\n print(\"Each result dictionary contains:\")\n print(\" - filename, full_path, dataset_name, true_label, stage_label\")\n print(\" - include_stage_analysis: boolean flag\")\n print(\" - image: original image array\")\n print(\" - level1_disease: {\")\n print(\" predicted_class, predicted_idx, confidence, all_probabilities\")\n print(\" heatmap_raw, heatmap_overlay, activation_stats\")\n print(\" }\")\n print(\" - level2_severity: {\")\n print(\" predicted_class, predicted_idx, confidence, all_probabilities\")\n print(\" heatmap_raw, heatmap_overlay, activation_stats\")\n print(\" }\")\n print(\" - level3_stage: {\")\n print(\" predicted_idx, confidence, all_probabilities\")\n print(\" heatmap_raw (None if not analyzed), heatmap_overlay (None if not analyzed)\")\n print(\" activation_stats (None if not analyzed)\")\n print(\" }\")\n print(\" - union: {\")\n print(\" heatmap_raw (average of 2 or 3 heads)\")\n print(\" heatmap_overlay, num_heads_averaged, activation_stats\")\n print(\" }\")\n print(\"=\"*80)\n \n print(f\"\\nāœ… gradcam_results variable contains {len(gradcam_results)} dictionaries\")\n print(\" Use gradcam_results in the next cell for further analysis!\")\n \n else:\n print(\"\\nāŒ No results generated\")\n else:\n print(\"āŒ collected_images not found!\")\n \nexcept Exception as e:\n print(f\"āŒ Error: {e}\")\n import traceback\n 
traceback.print_exc()\n\n\n\n","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Attention Heatmap Generator","metadata":{}},{"cell_type":"code","source":"import torch\nimport torch.nn.functional as F\nimport matplotlib.pyplot as plt\nimport numpy as np\nfrom PIL import Image\nimport cv2\nfrom scipy.ndimage import zoom, gaussian_filter\nimport os\n\n# -------------------------------------------\n# Configuration\n# -------------------------------------------\nOUTPUT_DIR=\"/kaggle/working\"\nATTENTION_OUTPUT_DIR = os.path.join(OUTPUT_DIR, 'attention_results')\nos.makedirs(ATTENTION_OUTPUT_DIR, exist_ok=True)\n\nprint(\"=\"*80)\nprint(\"ATTENTION VISUALIZATION CONFIGURATION\")\nprint(\"=\"*80)\nprint(f\"Output directory: {ATTENTION_OUTPUT_DIR}\")\nprint(\"=\"*80)\n\n# -------------------------------------------\n# Attention Extraction Wrapper\n# -------------------------------------------\nclass AttentionExtractor(nn.Module):\n \"\"\"\n Wrapper to extract attention weights from the hierarchical model\n \"\"\"\n def __init__(self, phase3_model):\n super().__init__()\n self.phase3_model = phase3_model\n self.attention_weights = None\n self.tile_features = None\n \n def forward(self, tiles):\n \"\"\"\n Extract attention weights and tile features from the model\n \"\"\"\n # Get model outputs including attention weights\n disease_logits, severity_logits, stage_logits, attention_weights = self.phase3_model(tiles)\n \n # Store attention weights for visualization\n self.attention_weights = attention_weights\n \n return disease_logits, severity_logits, stage_logits, attention_weights\n\n# Create attention extractor\nattention_extractor = AttentionExtractor(model).to(device)\nattention_extractor.eval()\n\nprint(\"\\nāœ… Attention extractor created successfully\\n\")\n\n# -------------------------------------------\n# Helper Functions\n# -------------------------------------------\ndef extract_attention_map(model, 
preprocessed_image, device):\n \"\"\"\n Extract attention weights from the model for a single image\n Since each image is a single tile, we get attention for that single representation\n \n Args:\n model: AttentionExtractor model\n preprocessed_image: Preprocessed tensor (single image, not tiled)\n device: torch device\n \n Returns:\n attention_weights: numpy array of attention weight (single value for single tile)\n disease_logits, severity_logits, stage_logits: model outputs\n \"\"\"\n model.eval()\n \n try:\n # Each preprocessed image is already a single tensor of shape (C, H, W)\n # We need to add batch dimension and wrap in list\n if preprocessed_image.dim() == 3:\n # Single image: (C, H, W) -> (1, C, H, W)\n image_batch = preprocessed_image.unsqueeze(0)\n else:\n # Already has batch dimension\n image_batch = preprocessed_image\n \n # Wrap in list as model expects list of tile batches\n # Since we have single image as single tile, this is [1 tile batch]\n tiles_list = [image_batch.to(device)]\n \n with torch.no_grad():\n disease_logits, severity_logits, stage_logits, attention_weights = model(tiles_list)\n \n # Convert attention weights to numpy\n # For single tile, this will be shape (1,) or (1, 1)\n attention_np = attention_weights.squeeze().cpu().numpy()\n \n # Ensure it's at least 1D\n if attention_np.ndim == 0:\n attention_np = np.array([attention_np.item()])\n \n print(f\" Extracted attention weights: shape={attention_np.shape}, value={attention_np}\")\n \n return attention_np, disease_logits, severity_logits, stage_logits\n \n except Exception as e:\n print(f\" Error extracting attention: {e}\")\n import traceback\n traceback.print_exc()\n return None, None, None, None\n\ndef create_uniform_attention_heatmap(attention_weight, image_shape):\n \"\"\"\n Create a uniform attention heatmap for a single tile (entire image)\n Since the whole image is one tile, the attention is uniform across it\n \n Args:\n attention_weight: single attention weight value\n 
image_shape: tuple (height, width) of image\n \n Returns:\n heatmap: 2D array with uniform attention value\n \"\"\"\n # Since we have single tile = whole image, create uniform heatmap\n # with the attention weight value\n heatmap = np.full(image_shape, attention_weight, dtype=np.float32)\n \n # Normalize to [0, 1] for visualization\n if heatmap.max() > 0:\n heatmap = heatmap / heatmap.max()\n \n return heatmap\n\ndef extract_patch_level_attention(model_backbone, preprocessed_image, device, patch_size=16):\n \"\"\"\n Extract patch-level attention from ViT backbone\n ViT processes image as patches, we can visualize their importance\n \n Args:\n model_backbone: ViT backbone model\n preprocessed_image: Preprocessed tensor\n device: torch device\n patch_size: ViT patch size (default 16 for most ViTs)\n \n Returns:\n patch_attention_map: 2D heatmap showing patch-level importance\n \"\"\"\n try:\n if preprocessed_image.dim() == 3:\n image_batch = preprocessed_image.unsqueeze(0).to(device)\n else:\n image_batch = preprocessed_image.to(device)\n \n with torch.no_grad():\n # Get ViT outputs - last_hidden_state contains all patch embeddings\n outputs = model_backbone.vit(pixel_values=image_batch)\n # Shape: (batch, num_patches + 1, embed_dim)\n # First token is CLS token, rest are patch tokens\n \n hidden_states = outputs.last_hidden_state\n \n # Get patch tokens (exclude CLS token at index 0)\n patch_tokens = hidden_states[:, 1:, :] # (1, num_patches, embed_dim)\n \n # Compute importance as L2 norm of each patch embedding\n patch_importance = torch.norm(patch_tokens, p=2, dim=2).squeeze().cpu().numpy()\n \n # Calculate grid dimensions\n # For 224x224 image with patch_size=16: 14x14 = 196 patches\n num_patches = len(patch_importance)\n grid_size = int(np.sqrt(num_patches))\n \n # Reshape to 2D grid\n attention_grid = patch_importance.reshape(grid_size, grid_size)\n \n print(f\" Extracted patch-level attention: {grid_size}x{grid_size} patches\")\n \n return attention_grid\n \n 
except Exception as e:\n print(f\" Error extracting patch attention: {e}\")\n import traceback\n traceback.print_exc()\n return None\n\ndef create_patch_attention_heatmap(patch_attention_grid, target_shape):\n \"\"\"\n Upsample patch-level attention to image dimensions\n \n Args:\n patch_attention_grid: 2D grid of patch attention values\n target_shape: tuple (height, width) for output\n \n Returns:\n heatmap: upsampled attention heatmap\n \"\"\"\n # Calculate zoom factors\n zoom_factors = (target_shape[0] / patch_attention_grid.shape[0],\n target_shape[1] / patch_attention_grid.shape[1])\n \n # Upsample using bilinear interpolation\n heatmap = zoom(patch_attention_grid, zoom_factors, order=1)\n \n # Normalize to [0, 1]\n heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)\n \n # Apply smoothing for better visualization\n heatmap = gaussian_filter(heatmap, sigma=5)\n \n return heatmap\n\ndef load_image_for_attention(img_path, target_size=768):\n \"\"\"Load and resize image for attention visualization\"\"\"\n try:\n if img_path.lower().endswith(('.svs', '.ndpi')):\n slide = openslide.OpenSlide(img_path)\n img_pil = slide.get_thumbnail((target_size, target_size))\n slide.close()\n elif img_path.lower().endswith('.tif'):\n try:\n slide = openslide.OpenSlide(img_path)\n img_pil = slide.get_thumbnail((target_size, target_size))\n slide.close()\n except:\n img_pil = Image.open(img_path).convert('RGB')\n img_pil = img_pil.resize((target_size, target_size), Image.BILINEAR)\n else:\n img_pil = Image.open(img_path).convert('RGB')\n img_pil = img_pil.resize((target_size, target_size), Image.BILINEAR)\n \n img_array = np.array(img_pil)\n return img_pil, img_array\n \n except Exception as e:\n print(f\"Error loading image {img_path}: {e}\")\n return None, None\n\n# -------------------------------------------\n# Main Attention Extraction Function\n# -------------------------------------------\ndef run_attention_analysis(attention_model, device, 
collected_images, processed_images,\n main_class_mapping, stage_class_mapping):\n \"\"\"\n Extract and visualize attention weights from the hierarchical model\n Uses patch-level attention from ViT backbone since images are single tiles\n \n Args:\n attention_model: AttentionExtractor model\n device: torch device\n collected_images: list of image paths\n processed_images: list of preprocessed tensors (single images)\n main_class_mapping: dictionary mapping class indices to names\n stage_class_mapping: dictionary mapping stage indices to names\n \n Returns:\n list of results dictionaries\n \"\"\"\n print(\"\\n\" + \"=\"*80)\n print(\"šŸŽÆ STARTING ATTENTION WEIGHT EXTRACTION AND VISUALIZATION\")\n print(\"=\"*80)\n print(\"ā„¹ļø Note: Each image is treated as a single tile\")\n print(\"ā„¹ļø Using patch-level attention from ViT backbone for visualization\")\n \n if not collected_images or not processed_images:\n print(\"āŒ No images or preprocessed data provided\")\n return []\n \n print(f\"āœ… Processing {len(collected_images)} images\\n\")\n \n all_results = []\n \n for idx, (img_path, preprocessed_image) in enumerate(zip(collected_images, processed_images)):\n print(f\"\\n[{idx+1}/{len(collected_images)}] Processing: {os.path.basename(img_path)}\")\n print(f\" Path: {img_path}\")\n print(f\" Image shape: {preprocessed_image.shape}\")\n \n try:\n # Load original image for visualization\n img_pil, img_array = load_image_for_attention(img_path, target_size=768)\n \n if img_array is None:\n print(f\" āŒ Failed to load image\")\n continue\n \n # Extract MIL-level attention weights (single value for single tile)\n attention_weights, disease_logits, severity_logits, stage_logits = extract_attention_map(\n attention_model, \n preprocessed_image, \n device\n )\n \n if attention_weights is None:\n print(f\" āŒ Failed to extract attention\")\n continue\n \n # Extract patch-level attention from ViT backbone\n patch_attention = extract_patch_level_attention(\n 
attention_model.phase3_model.backbone,\n preprocessed_image,\n device\n )\n \n if patch_attention is not None:\n # Create heatmap from patch attention\n attention_heatmap = create_patch_attention_heatmap(\n patch_attention,\n img_array.shape[:2]\n )\n print(f\" āœ… Created patch-level attention heatmap\")\n else:\n # Fallback: uniform heatmap with MIL attention weight\n attention_heatmap = create_uniform_attention_heatmap(\n attention_weights[0],\n img_array.shape[:2]\n )\n print(f\" ā„¹ļø Using uniform attention heatmap\")\n \n # Get predictions\n with torch.no_grad():\n disease_probs = F.softmax(disease_logits, dim=1)\n disease_pred_idx = torch.argmax(disease_probs, dim=1).item()\n disease_confidence = disease_probs[0, disease_pred_idx].item()\n \n predicted_class_name = main_class_mapping.get(\n disease_pred_idx, \n f\"Unknown_Class_{disease_pred_idx}\"\n )\n \n # Get severity prediction\n if \"_normal\" in predicted_class_name:\n predicted_disease = predicted_class_name.replace(\"_normal\", \"\")\n severity_label = \"Normal\"\n elif \"_abnormal\" in predicted_class_name:\n predicted_disease = predicted_class_name.replace(\"_abnormal\", \"\")\n severity_label = \"Abnormal\"\n else:\n predicted_disease = predicted_class_name\n severity_label = \"Unknown\"\n \n print(f\" šŸ“Š Prediction: {predicted_class_name}\")\n print(f\" Confidence: {disease_confidence:.4f}\")\n print(f\" MIL Attention Weight: {attention_weights[0]:.4f}\")\n \n # Extract dataset and true label from path\n path_parts = img_path.split(os.sep)\n dataset_name = \"Unknown\"\n true_label = \"unknown\"\n \n for part in path_parts:\n if \"processed\" in part or \"cancer\" in part.lower():\n dataset_name = part\n if part in [\"normal\", \"abnormal\"]:\n true_label = part\n break\n \n # Store results\n result = {\n 'image': img_array,\n 'attention_heatmap': attention_heatmap,\n 'mil_attention_weight': attention_weights[0],\n 'true_label': true_label,\n 'dataset_name': dataset_name,\n 'predicted_class': 
predicted_class_name,\n 'predicted_disease': predicted_disease,\n 'severity': severity_label,\n 'class_idx': disease_pred_idx,\n 'confidence': disease_confidence,\n 'filename': os.path.basename(img_path),\n 'full_path': img_path\n }\n \n all_results.append(result)\n \n # Save individual attention heatmap (raw, for OpenCV processing)\n heatmap_filename = f\"attention_heatmap_{idx+1}_{os.path.splitext(os.path.basename(img_path))[0]}.npy\"\n heatmap_path = os.path.join(ATTENTION_OUTPUT_DIR, heatmap_filename)\n np.save(heatmap_path, attention_heatmap)\n \n print(f\" āœ… Attention heatmap saved to: {heatmap_filename}\")\n print(f\" āœ… Completed analysis\")\n \n except Exception as e:\n print(f\" āŒ Error: {e}\")\n import traceback\n traceback.print_exc()\n continue\n \n print(\"\\n\" + \"=\"*80)\n print(\"āœ… ATTENTION EXTRACTION COMPLETE\")\n print(f\"šŸ“ Results saved to: {ATTENTION_OUTPUT_DIR}\")\n print(\"=\"*80)\n \n return all_results\n\n# -------------------------------------------\n# Visualization Function\n# -------------------------------------------\ndef display_attention_grid(results):\n \"\"\"\n Display grid with original images and attention heatmap overlays\n \"\"\"\n if not results:\n print(\"No results to display\")\n return\n \n num_images = len(results)\n \n # Create figure: 3 columns (original, heatmap, overlay)\n fig, axes = plt.subplots(num_images, 3, figsize=(18, 6 * num_images))\n \n # Handle single image case\n if num_images == 1:\n axes = axes.reshape(1, -1)\n \n # Use 'jet' colormap for attention (blue to red)\n cmap = plt.cm.jet\n \n for i, result in enumerate(results):\n # Column 1: Original Image\n axes[i, 0].imshow(result['image'])\n title_text = (\n f\"Original Image {i+1}\\n\"\n f\"Dataset: {result['dataset_name']}\\n\"\n f\"True Label: {result['true_label']}\\n\"\n f\"File: {result['filename'][:30]}...\"\n )\n axes[i, 0].set_title(title_text, fontsize=10, fontweight='bold', pad=10)\n axes[i, 0].axis('off')\n \n # Column 2: Attention 
Heatmap\n im = axes[i, 1].imshow(result['attention_heatmap'], cmap=cmap)\n cbar = plt.colorbar(im, ax=axes[i, 1], fraction=0.046, pad=0.04)\n cbar.set_label('Attention Weight', rotation=270, labelpad=15)\n \n heatmap_title = (\n f\"Attention Heatmap {i+1}\\n\"\n f\"Patch-Level Importance\\n\"\n f\"MIL Weight: {result['mil_attention_weight']:.4f}\"\n )\n axes[i, 1].set_title(heatmap_title, fontsize=10, fontweight='bold', pad=10)\n axes[i, 1].axis('off')\n \n # Column 3: Overlay\n axes[i, 2].imshow(result['image'])\n axes[i, 2].imshow(result['attention_heatmap'], cmap=cmap, alpha=0.5)\n \n overlay_title = (\n f\"Overlay {i+1}\\n\"\n f\"Predicted: {result['predicted_class']}\\n\"\n f\"Confidence: {result['confidence']:.4f}\"\n )\n axes[i, 2].set_title(overlay_title, fontsize=10, fontweight='bold', pad=10)\n axes[i, 2].axis('off')\n \n # Print statistics\n high_attention = np.sum(result['attention_heatmap'] > 0.7) / result['attention_heatmap'].size * 100\n medium_attention = np.sum((result['attention_heatmap'] > 0.4) & \n (result['attention_heatmap'] <= 0.7)) / result['attention_heatmap'].size * 100\n low_attention = np.sum(result['attention_heatmap'] <= 0.4) / result['attention_heatmap'].size * 100\n \n print(f\"\\nšŸ“Š Image {i+1} ({result['filename']}) Attention Statistics:\")\n print(f\" Predicted: {result['predicted_class']}\")\n print(f\" Confidence: {result['confidence']:.4f}\")\n print(f\" MIL Attention Weight: {result['mil_attention_weight']:.4f}\")\n print(f\" High attention regions (>0.7): {high_attention:.1f}%\")\n print(f\" Medium attention regions (0.4-0.7): {medium_attention:.1f}%\")\n print(f\" Low attention regions (<0.4): {low_attention:.1f}%\")\n \n plt.suptitle(\n 'Hierarchical Model - Patch-Level Attention Visualization\\n'\n 'Warmer colors (red/yellow) indicate higher attention | Cooler colors (blue) indicate lower attention',\n fontsize=16,\n fontweight='bold',\n y=0.998\n )\n \n plt.tight_layout()\n \n # Save grid\n grid_save_path = 
os.path.join(ATTENTION_OUTPUT_DIR, 'attention_visualization_grid.png')\n plt.savefig(grid_save_path, dpi=150, bbox_inches='tight')\n print(f\"\\nāœ… Grid visualization saved to: {grid_save_path}\")\n \n plt.show()\n\n# -------------------------------------------\n# Execute Attention Analysis\n# -------------------------------------------\nprint(\"\\n\" + \"=\"*80)\nprint(\"CHECKING FOR COLLECTED AND PROCESSED IMAGES\")\nprint(\"=\"*80)\n\ntry:\n # Check if required variables exist\n if 'collected_images' in locals() or 'collected_images' in globals():\n if 'processed_images' in locals() or 'processed_images' in globals():\n print(f\"āœ… Found collected_images: {len(collected_images)} images\")\n print(f\"āœ… Found processed_images: {len(processed_images)} tensors\\n\")\n \n # Run attention analysis\n attention_results = run_attention_analysis(\n attention_extractor,\n device,\n collected_images,\n processed_images,\n DISEASE_CLASS_MAPPING,\n STAGE_CLASS_MAPPING\n )\n \n # Display results\n if attention_results:\n display_attention_grid(attention_results)\n \n # Save results summary\n results_summary = []\n for r in attention_results:\n high_attn = np.sum(r['attention_heatmap'] > 0.7) / r['attention_heatmap'].size * 100\n \n results_summary.append({\n 'filename': r['filename'],\n 'dataset': r['dataset_name'],\n 'true_label': r['true_label'],\n 'predicted_class': r['predicted_class'],\n 'confidence': r['confidence'],\n 'mil_attention_weight': r['mil_attention_weight'],\n 'high_attention_area_%': high_attn,\n 'max_attention': r['attention_heatmap'].max(),\n 'mean_attention': r['attention_heatmap'].mean()\n })\n \n summary_df = pd.DataFrame(results_summary)\n summary_path = os.path.join(ATTENTION_OUTPUT_DIR, 'attention_summary.csv')\n summary_df.to_csv(summary_path, index=False)\n print(f\"\\nāœ… Summary saved to: {summary_path}\")\n \n print(\"\\n\" + \"=\"*80)\n print(\"ATTENTION ANALYSIS SUMMARY\")\n print(\"=\"*80)\n print(summary_df.to_string(index=False))\n 
print(\"=\"*80)\n \n # Save attention results for next cell (OpenCV feature extraction)\n print(\"\\n\" + \"=\"*80)\n print(\"ATTENTION RESULTS READY FOR OPENCV PROCESSING\")\n print(\"=\"*80)\n print(f\"āœ… Variable 'attention_results' contains {len(attention_results)} results\")\n print(\"āœ… Each result includes:\")\n print(\" - Original image\")\n print(\" - Patch-level attention heatmap (smoothed)\")\n print(\" - MIL attention weight\")\n print(\" - Predictions and metadata\")\n print(\"\\nšŸ’” Use 'attention_results' in the next cell for OpenCV feature extraction\")\n print(\"=\"*80)\n else:\n print(\"\\nāŒ No results generated\")\n else:\n print(\"āŒ processed_images not found!\")\n print(\"Please run the preprocessing cell first\")\n else:\n print(\"āŒ collected_images not found!\")\n print(\"Please run the image collection cell first\")\n \nexcept Exception as e:\n print(f\"āŒ Error: {e}\")\n import traceback\n traceback.print_exc()","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Heatmap Feature Extractor","metadata":{}},{"cell_type":"code","source":"import numpy as np\nimport cv2\nfrom sklearn.cluster import DBSCAN\nimport matplotlib.pyplot as plt\nimport os\nfrom scipy.ndimage import maximum_filter\nfrom skimage.feature import graycomatrix, graycoprops\nfrom openai import OpenAI\nimport time\nfrom scipy.stats import spearmanr\n\n# ================================================================\n# HEATMAP FEATURE EXTRACTOR CLASS\n# Adapted for Hierarchical Model Pipeline\n# ================================================================\nclass HeatmapFeatureExtractor:\n \n def __init__(self, attention_result):\n \"\"\"\n attention_result is one entry from attention_results\n \"\"\"\n self.heatmap = attention_result['attention_heatmap']\n self.original_image = attention_result['image']\n self.prediction_info = {\n 'predicted_class': attention_result['predicted_class'],\n 'predicted_disease': 
attention_result['predicted_disease'],\n 'severity': attention_result['severity'],\n 'confidence': attention_result['confidence'],\n 'class_idx': attention_result['class_idx'],\n 'mil_attention': attention_result['mil_attention_weight']\n }\n self.true_label = attention_result['true_label']\n self.dataset_name = attention_result['dataset_name']\n self.filename = attention_result['filename']\n \n # ---------------------------------------------------------\n # METHOD 1: Brightest Region Analysis\n # ---------------------------------------------------------\n def get_brightest_region(self):\n \"\"\"\n IMPROVED: Comprehensive analysis of high-attention regions\n \"\"\"\n heatmap = self.heatmap.astype(float)\n H, W = heatmap.shape\n \n heatmap_norm = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)\n \n # 1. PRIMARY HOTSPOT\n brightest_idx = np.unravel_index(np.argmax(heatmap_norm), heatmap_norm.shape)\n y_bright, x_bright = brightest_idx\n intensity_bright = heatmap_norm[y_bright, x_bright]\n \n position_bright = self._get_anatomical_position(y_bright, x_bright, H, W)\n \n primary_hotspot = {\n \"pixel\": (int(y_bright), int(x_bright)),\n \"position\": position_bright,\n \"intensity\": float(intensity_bright)\n }\n \n # 2. SECONDARY HOTSPOTS\n secondary_hotspots = self._find_secondary_hotspots(heatmap_norm, H, W, threshold=0.6)\n \n # 3. ATTENTION PATTERN\n attention_pattern = self._determine_attention_pattern(heatmap_norm, H, W)\n \n # 4. SPATIAL COVERAGE\n spatial_coverage = self._calculate_spatial_coverage(heatmap_norm, H, W)\n \n # 5. 
HOTSPOT COUNT\n hotspot_count = 1 + len(secondary_hotspots)\n \n return {\n \"primary_hotspot\": primary_hotspot,\n \"secondary_hotspots\": secondary_hotspots,\n \"attention_pattern\": attention_pattern,\n \"spatial_coverage\": spatial_coverage,\n \"hotspot_count\": hotspot_count\n }\n\n def _get_anatomical_position(self, y, x, H, W):\n \"\"\"Convert pixel coordinates to descriptive position\"\"\"\n y_rel = y / H\n x_rel = x / W\n \n center_y, center_x = H / 2, W / 2\n dist_from_center = np.sqrt((y - center_y)**2 + (x - center_x)**2)\n max_dist = np.sqrt((H/2)**2 + (W/2)**2)\n dist_ratio = dist_from_center / max_dist\n \n center_threshold_inner = 0.35\n center_threshold_outer = 0.65\n periphery_threshold = 0.75\n \n if y_rel < center_threshold_inner:\n vert = \"upper\"\n elif y_rel > center_threshold_outer:\n vert = \"lower\"\n else:\n vert = \"mid\"\n \n if x_rel < center_threshold_inner:\n horiz = \"left\"\n elif x_rel > center_threshold_outer:\n horiz = \"right\"\n else:\n horiz = \"center\"\n \n if horiz == \"center\" and vert == \"mid\":\n position = \"center\"\n elif horiz == \"center\":\n position = f\"{vert}-center\"\n elif vert == \"mid\":\n position = f\"{horiz}-center\"\n else:\n position = f\"{vert}-{horiz}\"\n \n if dist_ratio > periphery_threshold:\n position = f\"{position} (periphery)\"\n elif dist_ratio < 0.3:\n position = f\"{position} (core)\"\n \n return position\n \n def _find_secondary_hotspots(self, heatmap_norm, H, W, threshold=0.6, min_distance=20):\n \"\"\"Find additional significant attention regions\"\"\"\n secondary = []\n \n neighborhood_size = max(10, min(H, W) // 20)\n local_max = maximum_filter(heatmap_norm, size=neighborhood_size)\n \n peaks = (heatmap_norm == local_max) & (heatmap_norm > threshold * heatmap_norm.max())\n peak_coords = np.argwhere(peaks)\n peak_intensities = heatmap_norm[peaks]\n sorted_indices = np.argsort(peak_intensities)[::-1]\n \n primary_y, primary_x = np.unravel_index(np.argmax(heatmap_norm), 
heatmap_norm.shape)\n \n for idx in sorted_indices[:5]:\n y, x = peak_coords[idx]\n \n if np.sqrt((y - primary_y)**2 + (x - primary_x)**2) < min_distance:\n continue\n \n too_close = False\n for existing in secondary:\n ey, ex = existing[\"pixel\"]\n if np.sqrt((y - ey)**2 + (x - ex)**2) < min_distance:\n too_close = True\n break\n \n if too_close:\n continue\n \n position = self._get_anatomical_position(y, x, H, W)\n intensity = float(heatmap_norm[y, x])\n \n secondary.append({\n \"pixel\": (int(y), int(x)),\n \"position\": position,\n \"intensity\": intensity\n })\n \n return secondary\n\n def _determine_attention_pattern(self, heatmap_norm, H, W):\n \"\"\"Determine overall attention distribution pattern\"\"\"\n center_y, center_x = H // 2, W // 2\n \n Y, X = np.ogrid[:H, :W]\n dist_from_center = np.sqrt((Y - center_y)**2 + (X - center_x)**2)\n max_dist = np.sqrt((H/2)**2 + (W/2)**2)\n \n core_mask = dist_from_center < (max_dist * 0.3)\n mid_mask = (dist_from_center >= max_dist * 0.3) & (dist_from_center < max_dist * 0.7)\n periphery_mask = dist_from_center >= (max_dist * 0.7)\n \n core_attention = np.mean(heatmap_norm[core_mask])\n mid_attention = np.mean(heatmap_norm[mid_mask])\n periphery_attention = np.mean(heatmap_norm[periphery_mask])\n \n high_attention_pixels = np.sum(heatmap_norm > 0.7) / heatmap_norm.size\n \n if core_attention > 0.7 and core_attention > mid_attention * 1.5:\n return \"centralized (focused on center)\"\n elif periphery_attention > 0.7 and periphery_attention > core_attention * 1.5:\n return \"peripheral (focused on edges)\"\n elif mid_attention > core_attention and mid_attention > periphery_attention:\n return \"ring-like (donut pattern)\"\n elif high_attention_pixels > 0.5:\n return \"diffuse (spread across image)\"\n elif high_attention_pixels > 0.1 and high_attention_pixels < 0.3:\n return \"focal (single concentrated region)\"\n else:\n return \"scattered (multiple regions)\"\n\n def _calculate_spatial_coverage(self, heatmap_norm, 
H, W):\n \"\"\"Calculate percentage of attention in each spatial region\"\"\"\n center_y, center_x = H // 2, W // 2\n \n Y, X = np.ogrid[:H, :W]\n dist_from_center = np.sqrt((Y - center_y)**2 + (X - center_x)**2)\n max_dist = np.sqrt((H/2)**2 + (W/2)**2)\n \n core_mask = dist_from_center < (max_dist * 0.3)\n mid_mask = (dist_from_center >= max_dist * 0.3) & (dist_from_center < max_dist * 0.7)\n periphery_mask = dist_from_center >= (max_dist * 0.7)\n \n total_attention = np.sum(heatmap_norm)\n core_sum = np.sum(heatmap_norm[core_mask])\n mid_sum = np.sum(heatmap_norm[mid_mask])\n periphery_sum = np.sum(heatmap_norm[periphery_mask])\n \n return {\n \"center_attention\": float(core_sum / total_attention * 100) if total_attention > 0 else 0,\n \"mid_region_attention\": float(mid_sum / total_attention * 100) if total_attention > 0 else 0,\n \"periphery_attention\": float(periphery_sum / total_attention * 100) if total_attention > 0 else 0\n }\n\n # ---------------------------------------------------------\n # METHOD 2: Scatter Analysis\n # ---------------------------------------------------------\n def get_activation_scatter(self, threshold_ratio=0.6):\n \"\"\"Determine if heatmap is focused or scattered\"\"\"\n heatmap = self.heatmap.astype(float)\n H, W = heatmap.shape\n\n heatmap_norm = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-6)\n thresh = threshold_ratio * heatmap_norm.max()\n high_pixels = np.argwhere(heatmap_norm >= thresh)\n\n if len(high_pixels) == 0:\n return {\n \"scatter_level\": \"low\",\n \"num_clusters\": 0,\n \"clusters_sizes\": []\n }\n\n clustering = DBSCAN(eps=8, min_samples=20).fit(high_pixels)\n labels = clustering.labels_\n unique_labels = [lb for lb in np.unique(labels) if lb != -1]\n\n cluster_sizes = []\n for lb in unique_labels:\n cluster_sizes.append(int(np.sum(labels == lb)))\n\n num_clusters = len(unique_labels)\n\n if num_clusters == 1:\n scatter = \"low\"\n elif 2 <= num_clusters <= 3:\n scatter = \"medium\"\n 
else:\n scatter = \"high\"\n\n return {\n \"scatter_level\": scatter,\n \"num_clusters\": num_clusters,\n \"clusters_sizes\": cluster_sizes\n }\n \n # ---------------------------------------------------------\n # METHOD 3: Dominant Color Analysis\n # ---------------------------------------------------------\n def get_dominant_focus_color(self, threshold_ratio=0.6, k_clusters=5):\n \"\"\"Detect dominant color in attention-focused regions\"\"\"\n heatmap = self.heatmap.astype(float)\n orig = self.original_image.copy()\n \n H, W = heatmap.shape\n heatmap_norm = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-6)\n mask = (heatmap_norm >= threshold_ratio).astype(np.uint8) * 255\n \n if np.sum(mask) == 0:\n return {\n \"dominant_color_rgb\": None,\n \"dominant_color_hsv\": None,\n \"dominant_color_name\": \"none\",\n \"color_confidence\": 0.0\n }\n \n kernel = np.ones((3, 3), np.uint8)\n mask_clean = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=2)\n contours, _ = cv2.findContours(mask_clean, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)\n\n if len(contours) == 0:\n return {\n \"dominant_color_rgb\": None,\n \"dominant_color_hsv\": None,\n \"dominant_color_name\": \"none\",\n \"color_confidence\": 0.0\n }\n \n activation_mask = np.zeros_like(mask_clean)\n cv2.drawContours(activation_mask, contours, -1, 255, -1)\n focus_pixels = orig[activation_mask == 255]\n \n if len(focus_pixels) < 10:\n return {\n \"dominant_color_rgb\": None,\n \"dominant_color_hsv\": None,\n \"dominant_color_name\": \"none\",\n \"color_confidence\": 0.0\n }\n\n Z = np.float32(focus_pixels)\n criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 20, 1.0)\n K = k_clusters\n \n _, labels, centers = cv2.kmeans(\n Z, K, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS\n )\n \n counts = np.bincount(labels.flatten())\n sorted_indices = np.argsort(counts)[::-1]\n \n dominant_index = None\n dominant_color = None\n \n for idx in sorted_indices:\n candidate_color = 
centers[idx].astype(int)\n hsv_color = cv2.cvtColor(\n np.uint8([[candidate_color]]), \n cv2.COLOR_RGB2HSV\n )[0][0]\n \n if hsv_color[1] > 30:\n dominant_index = idx\n dominant_color = candidate_color\n break\n \n if dominant_index is None:\n dominant_index = sorted_indices[0]\n dominant_color = centers[dominant_index].astype(int)\n \n hsv_color = cv2.cvtColor(\n np.uint8([[dominant_color]]), \n cv2.COLOR_RGB2HSV\n )[0][0]\n \n color_confidence = counts[dominant_index] / len(labels) * 100\n dominant_name = self._map_color_to_name(\n dominant_color.tolist(), \n hsv_color.tolist()\n )\n \n return {\n \"dominant_color_rgb\": dominant_color.tolist(),\n \"dominant_color_hsv\": hsv_color.tolist(),\n \"dominant_color_name\": dominant_name,\n \"color_confidence\": float(color_confidence)\n }\n\n def _map_color_to_name(self, rgb, hsv=None):\n \"\"\"Enhanced color naming using HSV color space\"\"\"\n r, g, b = rgb\n \n if hsv is None:\n hsv_array = cv2.cvtColor(np.uint8([[rgb]]), cv2.COLOR_RGB2HSV)[0][0]\n h, s, v = hsv_array.tolist()\n else:\n h, s, v = hsv\n \n if s < 30:\n if v > 200:\n return \"white / very light\"\n elif v > 150:\n return \"light gray / pale\"\n elif v > 80:\n return \"gray\"\n else:\n return \"dark gray / black\"\n \n if v < 60:\n return \"very dark / black\"\n \n if 130 <= h <= 160:\n if s > 100:\n return \"purple / violet\"\n else:\n return \"light purple / lavender\"\n \n if 160 <= h <= 180 or h <= 10:\n if v > 180 and s < 100:\n return \"pink / light red\"\n elif s > 150:\n return \"magenta / bright pink\"\n else:\n return \"pink / rose\"\n \n if h <= 10:\n if v < 150:\n return \"dark red / maroon\"\n else:\n return \"red / crimson\"\n \n if 10 <= h < 25:\n if v < 130:\n return \"brown / dark tan\"\n else:\n return \"orange / tan\"\n \n if 25 <= h < 40:\n if s < 80:\n return \"beige / cream\"\n else:\n return \"yellow / golden\"\n \n if 40 <= h < 80:\n if v > 180:\n return \"light green / pale green\"\n elif v > 120:\n return \"green\"\n else:\n 
return \"dark green\"\n \n if 80 <= h < 100:\n return \"cyan / turquoise\"\n \n if 100 <= h < 130:\n if s > 150:\n return \"blue / deep blue\"\n elif v > 180:\n return \"light blue / sky blue\"\n else:\n return \"blue\"\n \n max_channel = max(r, g, b)\n if max_channel == r:\n return \"reddish tones\"\n elif max_channel == g:\n return \"greenish tones\"\n elif max_channel == b:\n return \"bluish tones\"\n else:\n return \"mixed color region\"\n\n # ---------------------------------------------------------\n # METHOD 4: Texture Analysis\n # ---------------------------------------------------------\n def get_texture_analysis(self, threshold_ratio=0.6):\n \"\"\"\n Analyze texture patterns in high-attention regions using GLCM\n \n Returns generic image-based descriptions without medical assumptions\n \"\"\"\n heatmap = self.heatmap.astype(float)\n orig = self.original_image.copy()\n \n H, W = heatmap.shape\n heatmap_norm = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-6)\n mask = (heatmap_norm >= threshold_ratio).astype(np.uint8) * 255\n \n if np.sum(mask) < 100:\n return {\n \"texture_classification\": \"insufficient data\",\n \"texture_description\": \"Not enough attention data to analyze texture\",\n \"texture_scores\": {\n \"uniformity\": 0,\n \"organization\": 0,\n \"complexity\": 0,\n \"smoothness\": 0\n },\n \"glcm_features\": {\n \"contrast\": 0.0,\n \"correlation\": 0.0,\n \"energy\": 0.0,\n \"homogeneity\": 0.0\n }\n }\n \n # Clean mask\n kernel = np.ones((3, 3), np.uint8)\n mask_clean = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=2)\n \n # Extract focused region from original image\n contours, _ = cv2.findContours(mask_clean, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)\n \n if len(contours) == 0:\n return {\n \"texture_classification\": \"no region detected\",\n \"texture_description\": \"No focused region detected\",\n \"texture_scores\": {\n \"uniformity\": 0,\n \"organization\": 0,\n \"complexity\": 0,\n \"smoothness\": 0\n 
},\n \"glcm_features\": {\n \"contrast\": 0.0,\n \"correlation\": 0.0,\n \"energy\": 0.0,\n \"homogeneity\": 0.0\n }\n }\n \n # Create activation mask\n activation_mask = np.zeros_like(mask_clean)\n cv2.drawContours(activation_mask, contours, -1, 255, -1)\n \n # Get bounding box of the region\n y_coords, x_coords = np.where(activation_mask == 255)\n if len(y_coords) == 0:\n return {\n \"texture_classification\": \"invalid region\",\n \"texture_description\": \"Invalid region for texture analysis\",\n \"texture_scores\": {\n \"uniformity\": 0,\n \"organization\": 0,\n \"complexity\": 0,\n \"smoothness\": 0\n },\n \"glcm_features\": {\n \"contrast\": 0.0,\n \"correlation\": 0.0,\n \"energy\": 0.0,\n \"homogeneity\": 0.0\n }\n }\n \n y_min, y_max = y_coords.min(), y_coords.max()\n x_min, x_max = x_coords.min(), x_coords.max()\n \n # Extract region\n region_rgb = orig[y_min:y_max+1, x_min:x_max+1]\n region_mask = activation_mask[y_min:y_max+1, x_min:x_max+1]\n \n # Convert to grayscale for texture analysis\n region_gray = cv2.cvtColor(region_rgb, cv2.COLOR_RGB2GRAY)\n \n # Apply mask to focus only on high-attention pixels\n region_gray_masked = region_gray.copy()\n region_gray_masked[region_mask == 0] = 0\n \n # Quantize to reduce GLCM computation (64 levels)\n region_quantized = (region_gray_masked / 4).astype(np.uint8)\n \n # Compute GLCM\n # distances: [1] means immediate neighbors\n # angles: [0, Ļ€/4, Ļ€/2, 3Ļ€/4] for rotation invariance\n distances = [1]\n angles = [0, np.pi/4, np.pi/2, 3*np.pi/4]\n \n try:\n glcm = graycomatrix(\n region_quantized,\n distances=distances,\n angles=angles,\n levels=64,\n symmetric=True,\n normed=True\n )\n \n # Extract features (averaged across all angles)\n contrast = float(graycoprops(glcm, 'contrast')[0].mean())\n correlation = float(graycoprops(glcm, 'correlation')[0].mean())\n energy = float(graycoprops(glcm, 'energy')[0].mean())\n homogeneity = float(graycoprops(glcm, 'homogeneity')[0].mean())\n \n except Exception as e:\n 
print(f\" Warning: GLCM computation failed: {e}\")\n return {\n \"texture_classification\": \"computation error\",\n \"texture_description\": \"Error computing texture features\",\n \"texture_scores\": {\n \"uniformity\": 0,\n \"organization\": 0,\n \"complexity\": 0,\n \"smoothness\": 0\n },\n \"glcm_features\": {\n \"contrast\": 0.0,\n \"correlation\": 0.0,\n \"energy\": 0.0,\n \"homogeneity\": 0.0\n }\n }\n \n # Convert to 0-100 scores\n uniformity_score = int(energy * 100)\n organization_score = int(max(0, min(100, (correlation + 1) * 50))) # Scale -1,1 to 0,100\n complexity_score = int((1 - energy) * 100)\n smoothness_score = int(homogeneity * 100)\n \n # Classify texture based on GLCM features\n classification, description = self._classify_texture(\n contrast, correlation, energy, homogeneity\n )\n \n return {\n \"texture_classification\": classification,\n \"texture_description\": description,\n \"texture_scores\": {\n \"uniformity\": uniformity_score,\n \"organization\": organization_score,\n \"complexity\": complexity_score,\n \"smoothness\": smoothness_score\n },\n \"glcm_features\": {\n \"contrast\": round(contrast, 2),\n \"correlation\": round(correlation, 3),\n \"energy\": round(energy, 3),\n \"homogeneity\": round(homogeneity, 3)\n }\n }\n \n def _classify_texture(self, contrast, correlation, energy, homogeneity):\n \"\"\"\n Classify texture based on GLCM features\n Returns (classification, description) tuple\n \"\"\"\n # Rule-based classification\n if contrast < 100 and homogeneity > 0.8:\n classification = \"uniform and smooth\"\n description = \"Model focused on a region with smooth, uniform texture showing consistent patterns with minimal variation\"\n \n elif correlation > 0.7 and energy > 0.3:\n classification = \"structured and regular\"\n description = \"Model focused on a region with organized, structured patterns exhibiting regular, repeating elements\"\n \n elif contrast > 400 and correlation < 0.4:\n classification = \"irregular and 
chaotic\"\n description = \"Model focused on a region with irregular, chaotic texture displaying highly variable patterns with no clear organization\"\n \n elif homogeneity < 0.5:\n classification = \"rough and coarse\"\n description = \"Model focused on a region with rough, coarse texture showing sharp intensity changes and abrupt transitions\"\n \n elif energy < 0.2:\n classification = \"complex and varied\"\n description = \"Model focused on a region with complex, varied texture containing multiple different patterns and high visual diversity\"\n \n else:\n classification = \"moderate texture\"\n description = \"Model focused on a region with moderate texture complexity showing intermediate characteristics\"\n \n return classification, description","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Final Output Grid","metadata":{}},{"cell_type":"code","source":"import warnings\nwarnings.filterwarnings(\"ignore\")\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.utils.data import Dataset, DataLoader\nimport torchvision.transforms as transforms\nfrom PIL import Image\nimport os\nimport numpy as np\nimport pandas as pd\nfrom transformers import Dinov2Model\nimport openslide\nfrom tqdm import tqdm\nimport logging\nimport matplotlib\nimport matplotlib.pyplot as plt\nimport matplotlib.patches as mpatches\nfrom matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec\nfrom openai import OpenAI\nimport time\n\nlogging.basicConfig(level=logging.INFO)\n\nMODEL_PATH = \"/kaggle/input/models/ulimaank/updated-diagnostic-model-jan-18/other/default/1/phase3_mil_best.pth\"\nOUTPUT_DIR = \"/kaggle/working\"\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nprint(f\"Device: {device}, GPUs: {torch.cuda.device_count()}\")\n\nDISEASE_NAMES = [\n 'Breast_cancer', 'annrbc-anemia_processed', 'colon_processed',\n 'leukemia_processed', 'lung_processed', 
'oral-cancer_processed',\n 'ovarian-cancer_processed', 'sickle-cell-new_processed', 'thalassemia_processed'\n]\n\nSTAGE_NAMES = {\n 0: 'Breast_cancer - ductal_carcinoma',\n 1: 'Breast_cancer - lobular_carcinoma',\n 2: 'Breast_cancer - mucinous_carcinoma',\n 3: 'Breast_cancer - papillary_carcinoma',\n 4: 'leukemia_processed - Early',\n 5: 'leukemia_processed - Pre',\n 6: 'leukemia_processed - Pro',\n 7: 'lung_processed - lung_aca',\n 8: 'lung_processed - lung_scc',\n 9: 'ovarian-cancer_processed - CC',\n 10: 'ovarian-cancer_processed - EC',\n 11: 'ovarian-cancer_processed - HGSC',\n 12: 'ovarian-cancer_processed - LGSC',\n 13: 'ovarian-cancer_processed - MC'\n}\n\nDISEASE_CLASS_MAPPING = {\n 0: \"Breast_cancer\",\n 1: \"annrbc-anemia_processed\",\n 2: \"colon_processed\",\n 3: \"leukemia_processed\",\n 4: \"lung_processed\",\n 5: \"oral-cancer_processed\",\n 6: \"ovarian-cancer_processed\",\n 7: \"sickle-cell-new_processed\",\n 8: \"thalassemia_processed\",\n}\n\nSTAGE_CLASS_MAPPING = STAGE_NAMES\n\nTARGET_SIZE = 256\nstandardize_transform = transforms.Resize((TARGET_SIZE, TARGET_SIZE))\n\n\n# ================================================================\n# DATA COLLECTION\n# ================================================================\n\ndef collect_images_from_folder(folder_path):\n images = []\n valid_extensions = ('.svs', '.tif', '.ndpi', '.png', '.jpg', '.jpeg', '.tiff')\n for root, dirs, files in os.walk(folder_path):\n for f in files:\n if f.lower().endswith(valid_extensions):\n images.append(os.path.join(root, f))\n return images\n\n\n# ================================================================\n# DATASET\n# ================================================================\n\nclass SimpleSlideDataset(Dataset):\n def __init__(self, image_paths, tile_size=224, max_tiles=1000):\n self.image_paths = image_paths\n self.tile_size = tile_size\n self.max_tiles = max_tiles\n\n def __len__(self):\n return len(self.image_paths)\n\n def __getitem__(self, 
idx):\n slide_path = self.image_paths[idx]\n tiles = []\n try:\n if slide_path.lower().endswith(('.svs', '.ndpi')):\n slide = openslide.OpenSlide(slide_path)\n width, height = slide.dimensions\n for y in range(0, height, self.tile_size):\n for x in range(0, width, self.tile_size):\n if len(tiles) >= self.max_tiles:\n break\n tile = slide.read_region((x, y), 0, (self.tile_size, self.tile_size)).convert('RGB')\n tiles.append(standardize_transform(tile))\n if len(tiles) >= self.max_tiles:\n break\n slide.close()\n elif slide_path.lower().endswith('.tif'):\n try:\n slide = openslide.OpenSlide(slide_path)\n width, height = slide.dimensions\n for y in range(0, height, self.tile_size):\n for x in range(0, width, self.tile_size):\n if len(tiles) >= self.max_tiles:\n break\n tile = slide.read_region((x, y), 0, (self.tile_size, self.tile_size)).convert('RGB')\n tiles.append(standardize_transform(tile))\n if len(tiles) >= self.max_tiles:\n break\n slide.close()\n except openslide.OpenSlideError:\n tiles = [standardize_transform(Image.open(slide_path).convert('RGB'))]\n except Exception:\n tiles = [standardize_transform(Image.open(slide_path).convert('RGB'))]\n else:\n tiles = [standardize_transform(Image.open(slide_path).convert('RGB'))]\n\n if not tiles:\n raise ValueError(\"No tiles extracted\")\n return tiles, slide_path\n except Exception as e:\n logging.error(f\"Error processing slide {slide_path}: {e}\")\n return [], slide_path\n\n\ntest_transform = transforms.Compose([\n transforms.Resize(224),\n transforms.ToTensor(),\n transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n])\n\n\ndef simple_collate(batch):\n valid_batch = [item for item in batch if item[0]]\n if not valid_batch:\n return [], []\n tiles_list, paths = zip(*valid_batch)\n processed_tiles = [torch.stack([test_transform(tile) for tile in tiles]) for tiles in tiles_list]\n return processed_tiles, list(paths)\n\n\n# ================================================================\n# 
MODEL ARCHITECTURE\n# ================================================================\n\nclass ViTBackbone(nn.Module):\n def __init__(self):\n super().__init__()\n self.vit = Dinov2Model.from_pretrained(\"owkin/phikon-v2\")\n\n def forward(self, x):\n return self.vit(pixel_values=x).last_hidden_state[:, 0]\n\n\nclass ClassificationHead(nn.Module):\n def __init__(self, in_dim=1024, num_classes=2, hidden_dim=512):\n super().__init__()\n self.classifier = nn.Sequential(\n nn.Linear(in_dim, hidden_dim),\n nn.ReLU(),\n nn.Dropout(0.3),\n nn.Linear(hidden_dim, num_classes)\n )\n\n def forward(self, x):\n return self.classifier(x)\n\n\nclass HierarchicalMILAggregator(nn.Module):\n def __init__(self, embed_dim=1024, num_heads=8, num_layers=2,\n num_diseases=6, num_stage_classes=0, disease_names=None):\n super().__init__()\n self.pre_norm = nn.LayerNorm(embed_dim)\n encoder_layer = nn.TransformerEncoderLayer(\n d_model=embed_dim, nhead=num_heads, batch_first=True, dropout=0.1)\n self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)\n self.attention = nn.Sequential(nn.Linear(embed_dim, 256), nn.Tanh(), nn.Linear(256, 1))\n self.disease_head = ClassificationHead(embed_dim, num_diseases)\n self.severity_heads = nn.ModuleDict()\n for name in disease_names:\n self.severity_heads[name] = ClassificationHead(embed_dim, 2)\n self.stage_head = ClassificationHead(embed_dim, num_stage_classes) if num_stage_classes > 0 else None\n self.disease_name_to_idx = {n: i for i, n in enumerate(disease_names)}\n self.idx_to_disease_name = {i: n for n, i in self.disease_name_to_idx.items()}\n self.disease_names = disease_names\n\n def forward(self, tile_features):\n normalized = self.pre_norm(tile_features)\n aggregated = self.transformer(normalized)\n attn_scores = self.attention(aggregated)\n attn_weights = torch.softmax(attn_scores.squeeze(-1), dim=1)\n weighted = torch.sum(aggregated * attn_weights.unsqueeze(-1), dim=1)\n disease_logits = self.disease_head(weighted)\n 
severity_logits = {n: self.severity_heads[n](weighted) for n in self.disease_names}\n stage_logits = self.stage_head(weighted) if self.stage_head is not None else None\n return disease_logits, severity_logits, stage_logits, attn_weights\n\n\nclass Phase3Model(nn.Module):\n def __init__(self, backbone, num_diseases=6, num_stage_classes=0, disease_names=None):\n super().__init__()\n self.backbone = backbone\n for param in self.backbone.parameters():\n param.requires_grad = False\n self.aggregator = HierarchicalMILAggregator(\n num_diseases=num_diseases,\n num_stage_classes=num_stage_classes,\n disease_names=disease_names\n )\n\n def forward(self, tiles, enable_gradients=False):\n all_features = []\n for batch_tiles in tiles:\n if batch_tiles.numel() == 0:\n continue\n batch_tiles = batch_tiles.to(next(self.backbone.parameters()).device)\n if enable_gradients:\n batch_features = self.backbone(batch_tiles)\n else:\n with torch.no_grad():\n batch_features = self.backbone(batch_tiles)\n all_features.append(batch_features)\n if not all_features:\n raise ValueError(\"No valid tile features could be extracted.\")\n all_features = torch.stack(all_features)\n return self.aggregator(all_features)\n\n\n# ================================================================\n# PREDICTION\n# ================================================================\n\ndef predict_image(model, tiles, disease_names, stage_names):\n model.eval()\n try:\n with torch.no_grad():\n disease_logits, severity_logits, stage_logits, _ = model(tiles)\n disease_probs = F.softmax(disease_logits, dim=1)\n disease_pred_idx = torch.argmax(disease_probs, dim=1).item()\n disease_confidence = disease_probs[0, disease_pred_idx].item()\n predicted_disease = disease_names[disease_pred_idx]\n\n severity_probs = F.softmax(severity_logits[predicted_disease], dim=1)\n severity_pred = torch.argmax(severity_probs, dim=1).item()\n severity_confidence = severity_probs[0, severity_pred].item()\n severity_label = \"Normal\" if 
severity_pred == 0 else \"Abnormal\"\n\n stage_label = \"N/A\"\n stage_confidence = 0.0\n if severity_pred == 1 and stage_logits is not None:\n stage_probs = F.softmax(stage_logits, dim=1)\n stage_pred_idx = torch.argmax(stage_probs, dim=1).item()\n stage_confidence = stage_probs[0, stage_pred_idx].item()\n stage_label = stage_names.get(stage_pred_idx, f\"Stage_{stage_pred_idx}\")\n\n return {\n 'disease': predicted_disease,\n 'disease_confidence': disease_confidence,\n 'severity': severity_label,\n 'severity_confidence': severity_confidence,\n 'stage': stage_label,\n 'stage_confidence': stage_confidence\n }\n except Exception as e:\n logging.error(f\"Error during prediction: {e}\")\n return None\n\n\n# ================================================================\n# GPT EXPLANATION GENERATOR\n# ================================================================\n\ndef generate_comprehensive_explanation(comprehensive_data):\n \"\"\"\n Calls GPT-4o-mini to convert technical XAI metrics into a human-friendly\n explanation. Falls back to a template string if the API call fails.\n Only references Attention and GradCAM heatmaps.\n \"\"\"\n try:\n client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))\n\n prompt = f\"\"\"You are an AI explainability assistant helping users understand how a hierarchical medical image classification model made its decision. 
Convert the following technical analysis into a clear, accessible explanation.\n\nHIERARCHICAL MODEL PREDICTION:\n- Region: {comprehensive_data['predicted_disease']} ({comprehensive_data['gradcam_disease_conf']:.1%} confidence)\n- Status Level: {comprehensive_data['gradcam_severity']} ({comprehensive_data['gradcam_severity_conf']:.1%} confidence)\n- Stage Level: {comprehensive_data['predicted_stage']} ({comprehensive_data['stage_confidence']:.1%} confidence)\n\nGRADCAM ANALYSIS (Gradient-weighted Class Activation Mapping):\n- Note: Bright/warm regions in GradCAM indicate areas that most strongly influenced the model's prediction\n\nSPATIAL ATTENTION PATTERN AND VISUAL CHARACTERISTICS (from Attention Heatmap):\n- Primary Focus: {comprehensive_data['primary_position']} (intensity: {comprehensive_data['primary_intensity']:.2f})\n- Attention Hotspots: {comprehensive_data['hotspot_count']}\n- Spatial Distribution: Center {comprehensive_data['center_attention']:.1f}%, Mid-region {comprehensive_data['mid_attention']:.1f}%, Periphery {comprehensive_data['periphery_attention']:.1f}%\n- Clustering: {comprehensive_data['scatter_level']} scatter level with {comprehensive_data['num_clusters']} clusters\n- Dominant Color: {comprehensive_data['dominant_color']} ({comprehensive_data['color_confidence']:.1f}% confidence)\n- Texture Pattern: {comprehensive_data['texture_classification']}\n- Texture Scores: Uniformity {comprehensive_data['uniformity']}/100, Organization {comprehensive_data['organization']}/100, Complexity {comprehensive_data['complexity']}/100, Smoothness {comprehensive_data['smoothness']}/100\n\nCRITICAL INSTRUCTIONS:\n1. Write in clear, accessible language for someone without medical or technical expertise\n2. Ground ALL statements in the provided data - do NOT add medical interpretations or diagnoses\n3. Explain how the two explainability methods (Attention Heatmap and GradCAM) show WHERE the model focused\n4. 
Describe WHAT visual patterns were detected, not WHY medically\n5. Keep it concise but informative (under 100 words)\n6. Structure with clear sections\n7. Make it conversational but professional\n8. Visual Characteristics and Spatial Attention Pattern were taken from Attention Heatmap\n\nGenerate a comprehensive explanation covering: what the model decided, where it looked, what the Attention and GradCAM methods revealed, what visual characteristics were important, and how confident we can be in the decision.\n\nFormat as natural paragraphs, not bullet points.\"\"\"\n\n response = client.chat.completions.create(\n model=\"gpt-4o-mini\",\n messages=[\n {\"role\": \"system\", \"content\": \"You are an expert at explaining complex AI model decisions in simple, clear language. You help users understand model behavior without making medical claims.\"},\n {\"role\": \"user\", \"content\": prompt}\n ],\n temperature=0.7,\n max_tokens=700\n )\n return response.choices[0].message.content.strip()\n\n except Exception as e:\n logging.warning(f\"OpenAI API call failed: {e}. Using fallback template.\")\n return (\n f\"MODEL DECISION SUMMARY\\n\\n\"\n f\"The model classified this as '{comprehensive_data['predicted_disease']}' \"\n f\"with severity '{comprehensive_data['gradcam_severity']}' \"\n f\"and stage '{comprehensive_data['predicted_stage']}' \"\n f\"({comprehensive_data['stage_confidence']:.1%} confidence).\\n\\n\"\n f\"ATTENTION ANALYSIS\\n\\n\"\n f\"Primary focus: {comprehensive_data['primary_position']} region. \"\n f\"Attention shows {comprehensive_data['scatter_level']} scatter across \"\n f\"{comprehensive_data['num_clusters']} clusters. 
\"\n f\"Distribution - Center: {comprehensive_data['center_attention']:.1f}%, \"\n f\"Mid: {comprehensive_data['mid_attention']:.1f}%, \"\n f\"Periphery: {comprehensive_data['periphery_attention']:.1f}%.\\n\\n\"\n f\"GRADCAM ANALYSIS\\n\\n\"\n f\"GradCAM confidence: {comprehensive_data['gradcam_disease_conf']:.1%} (disease), \"\n f\"{comprehensive_data['gradcam_severity_conf']:.1%} (severity).\\n\\n\"\n f\"VISUAL PATTERNS\\n\\n\"\n f\"Dominant color: {comprehensive_data['dominant_color']}. \"\n f\"Texture: {comprehensive_data['texture_classification']} \"\n f\"(uniformity {comprehensive_data['uniformity']}/100, \"\n f\"smoothness {comprehensive_data['smoothness']}/100).\"\n )\n\n\n# ================================================================\n# RENDER EXPLANATION TEXT -> RGB NUMPY ARRAY FOR ROW 3\n# ================================================================\n\ndef _render_explanation_to_image(explanation_text, figsize=(16, 4)):\n \"\"\"\n Renders a plain-text explanation string into an (H, W, 3) uint8 numpy\n array that fills the entire Row 3 panel in display_prediction().\n \"\"\"\n fig, ax = plt.subplots(figsize=figsize, facecolor='#0F0F2A')\n fig.subplots_adjust(left=0, right=1, top=1, bottom=0)\n ax.set_facecolor('#0F0F2A')\n ax.set_xlim(0, 1)\n ax.set_ylim(0, 1)\n ax.axis('off')\n\n ax.add_patch(mpatches.FancyBboxPatch(\n (0.0, 0.0), 1.0, 1.0,\n boxstyle=\"round,pad=0.01\",\n linewidth=2,\n edgecolor='#3498DB',\n facecolor='#16213E',\n transform=ax.transAxes,\n clip_on=False\n ))\n\n ax.text(\n 0.5, 0.93,\n 'Textual Explanation',\n ha='center', va='top',\n fontsize=11,\n fontweight='bold',\n color='#3498DB',\n transform=ax.transAxes\n )\n\n ax.add_line(plt.Line2D(\n [0.02, 0.98], [0.855, 0.855],\n transform=ax.transAxes,\n color='#3498DB',\n linewidth=1.0\n ))\n\n ax.text(\n 0.02, 0.83,\n explanation_text,\n va='top', ha='left',\n fontsize=10,\n color='#E0E0F0',\n family='monospace',\n wrap=True,\n transform=ax.transAxes\n )\n\n fig.canvas.draw()\n 
buf = fig.canvas.buffer_rgba()\n img_array = np.frombuffer(buf, dtype=np.uint8).reshape(\n fig.canvas.get_width_height()[::-1] + (4,)\n )\n plt.close(fig)\n return img_array[:, :, :3]\n\n\n# ================================================================\n# ATTENTION OVERLAY HELPERS\n# ================================================================\n\ndef _build_attention_overlay(img_array, heatmap_raw):\n \"\"\"\n Takes img_array (H, W, 3) uint8 and heatmap_raw (H, W) float,\n returns an RGB overlay as np.ndarray in [0, 1].\n \"\"\"\n import cv2\n img_norm = img_array.astype(np.float32) / 255.0\n hm = heatmap_raw.astype(np.float32)\n\n h, w = img_norm.shape[:2]\n if hm.shape != (h, w):\n hm = cv2.resize(hm, (w, h), interpolation=cv2.INTER_CUBIC)\n\n hm_min, hm_max = hm.min(), hm.max()\n if hm_max > hm_min:\n hm = (hm - hm_min) / (hm_max - hm_min)\n\n cmap = matplotlib.colormaps.get_cmap('jet')\n hm_colored = cmap(hm)[:, :, :3]\n overlay = img_norm * 0.5 + hm_colored * 0.5\n return np.clip(overlay, 0, 1)\n\n\ndef _preprocess_images_for_attention(image_paths):\n processed = []\n for p in image_paths:\n try:\n img_pil = Image.open(p).convert('RGB').resize((224, 224), Image.BILINEAR)\n tensor = test_transform(img_pil)\n processed.append(tensor)\n except Exception as e:\n logging.warning(f\"Could not preprocess {p} for attention: {e}\")\n processed.append(torch.zeros(3, 224, 224))\n return processed\n\n\n# ================================================================\n# DISPLAY FUNCTION - 3-ROW LAYOUT\n# Row 1 : Original Image | Diagnostic Report\n# Row 2 : Attention Heatmap | GradCAM Heatmap\n# Row 3 : GPT-Generated Human-Friendly Text Explanation\n# ================================================================\n\ndef display_prediction(image_path, prediction,\n heatmap_images=None,\n heatmap_titles=None,\n explanation_image=None):\n severity = prediction['severity']\n accent = '#E74C3C' if severity == 'Abnormal' else '#2ECC71'\n bg_color = '#1A1A2E'\n 
panel_color = '#16213E'\n border_dim = '#2A2A4A'\n\n # Exactly 2 heatmaps: Attention + GradCAM\n if heatmap_titles is None:\n heatmap_titles = ['Attention Heatmap', 'GradCAM Heatmap']\n\n fig = plt.figure(figsize=(16, 14), facecolor=bg_color)\n outer_gs = GridSpec(3, 1, figure=fig,\n height_ratios=[5, 4, 3],\n hspace=0.08)\n\n # ===== ROW 1: Original Image | Diagnostic Report =====\n row1_gs = GridSpecFromSubplotSpec(1, 2,\n subplot_spec=outer_gs[0],\n width_ratios=[1, 1.2],\n wspace=0.05)\n\n ax_img = fig.add_subplot(row1_gs[0])\n ax_img.set_facecolor(bg_color)\n try:\n img = Image.open(image_path).convert('RGB')\n ax_img.imshow(img)\n except Exception:\n ax_img.text(0.5, 0.5, 'WSI / Slide\\n(preview unavailable)',\n ha='center', va='center', color='white', fontsize=13,\n transform=ax_img.transAxes)\n for spine in ax_img.spines.values():\n spine.set_edgecolor(accent)\n spine.set_linewidth(3)\n ax_img.set_xticks([])\n ax_img.set_yticks([])\n ax_img.set_title(os.path.basename(image_path), color='white',\n fontsize=11, pad=8, fontweight='bold')\n\n ax_info = fig.add_subplot(row1_gs[1])\n ax_info.set_facecolor(bg_color)\n ax_info.set_xlim(0, 1)\n ax_info.set_ylim(0, 1)\n ax_info.axis('off')\n ax_info.text(0.5, 0.96, 'Diagnostic Report',\n ha='center', va='top', fontsize=15, fontweight='bold',\n color='white', transform=ax_info.transAxes)\n divider = plt.Line2D([0.05, 0.95], [0.89, 0.89],\n transform=ax_info.transAxes,\n color=accent, linewidth=1.5)\n ax_info.add_line(divider)\n\n def draw_card(ax, y, label, value, confidence, color):\n ax.add_patch(mpatches.FancyBboxPatch(\n (0.04, y - 0.11), 0.92, 0.14,\n boxstyle=\"round,pad=0.01\",\n linewidth=1.5, edgecolor=color,\n facecolor=panel_color, transform=ax.transAxes, clip_on=False\n ))\n ax.text(0.10, y - 0.01, label.upper(),\n ha='left', va='center', fontsize=8, color='#A0A0C0',\n fontweight='bold', transform=ax.transAxes)\n ax.text(0.10, y - 0.05, value,\n ha='left', va='center', fontsize=13, color='white',\n 
fontweight='bold', transform=ax.transAxes)\n if confidence > 0:\n bar_y = y - 0.09\n ax.add_patch(mpatches.FancyBboxPatch(\n (0.08, bar_y), 0.60, 0.015,\n boxstyle=\"round,pad=0.001\", linewidth=0,\n facecolor='#0F3460', transform=ax.transAxes, clip_on=False\n ))\n ax.add_patch(mpatches.FancyBboxPatch(\n (0.08, bar_y), 0.60 * confidence, 0.015,\n boxstyle=\"round,pad=0.001\", linewidth=0,\n facecolor=color, transform=ax.transAxes, clip_on=False\n ))\n ax.text(0.72, bar_y + 0.007, f'{confidence:.1%}',\n ha='left', va='center', fontsize=9, color=color,\n fontweight='bold', transform=ax.transAxes)\n\n draw_card(ax_info, 0.78,\n 'Region',\n prediction['disease'].replace('_processed', '').replace('_', ' ').title(),\n prediction['disease_confidence'], '#3498DB')\n draw_card(ax_info, 0.57, 'Status', severity,\n prediction['severity_confidence'], accent)\n\n stage_val = prediction['stage']\n stage_conf = prediction['stage_confidence']\n if stage_val == 'N/A':\n stage_display = 'N/A (Normal)'\n stage_conf = 0\n else:\n stage_display = (stage_val.split(' - ')[-1].replace('_', ' ').title()\n if ' - ' in stage_val else stage_val)\n draw_card(ax_info, 0.36, 'Stage / Subtype', stage_display, stage_conf, '#F39C12')\n\n # ===== ROW 2: Attention Heatmap | GradCAM Heatmap (2 columns) =====\n row2_gs = GridSpecFromSubplotSpec(1, 2,\n subplot_spec=outer_gs[1],\n wspace=0.06)\n for col_idx in range(2):\n ax_hm = fig.add_subplot(row2_gs[col_idx])\n ax_hm.set_facecolor(panel_color)\n\n if heatmap_images and col_idx < len(heatmap_images) and heatmap_images[col_idx] is not None:\n hm = heatmap_images[col_idx]\n ax_hm.imshow(hm if isinstance(hm, np.ndarray) else np.array(hm))\n else:\n ax_hm.set_xlim(0, 1)\n ax_hm.set_ylim(0, 1)\n ax_hm.add_patch(mpatches.FancyBboxPatch(\n (0.05, 0.05), 0.90, 0.90,\n boxstyle=\"round,pad=0.02\",\n linewidth=1.5, linestyle='--',\n edgecolor='#4A4A6A', facecolor='#0F0F2A',\n transform=ax_hm.transAxes, clip_on=False\n ))\n ax_hm.text(0.5, 0.5, '[ 
Heatmap\\nPlaceholder ]',\n ha='center', va='center',\n color='#4A4A6A', fontsize=9, fontstyle='italic',\n transform=ax_hm.transAxes)\n\n title = heatmap_titles[col_idx] if col_idx < len(heatmap_titles) else f'Heatmap {col_idx+1}'\n ax_hm.set_title(title, color='#A0A0C0', fontsize=9, fontweight='bold', pad=5)\n for spine in ax_hm.spines.values():\n spine.set_edgecolor(border_dim)\n spine.set_linewidth(1.2)\n ax_hm.set_xticks([])\n ax_hm.set_yticks([])\n\n # ===== ROW 3: GPT Explanation =====\n ax_text = fig.add_subplot(outer_gs[2])\n ax_text.set_facecolor(panel_color)\n\n if explanation_image is not None:\n exp_img = explanation_image if isinstance(explanation_image, np.ndarray) \\\n else np.array(explanation_image)\n ax_text.imshow(exp_img, aspect='auto')\n ax_text.set_xticks([])\n ax_text.set_yticks([])\n else:\n ax_text.set_xlim(0, 1)\n ax_text.set_ylim(0, 1)\n ax_text.axis('off')\n ax_text.add_patch(mpatches.FancyBboxPatch(\n (0.01, 0.05), 0.98, 0.90,\n boxstyle=\"round,pad=0.02\",\n linewidth=1.5, linestyle='--',\n edgecolor='#4A4A6A', facecolor='#0F0F2A',\n transform=ax_text.transAxes, clip_on=False\n ))\n ax_text.text(0.5, 0.80,\n 'Human-Friendly Text Explanation',\n ha='center', va='center',\n color='#A0A0C0', fontsize=11, fontweight='bold',\n transform=ax_text.transAxes)\n ax_text.text(0.5, 0.42,\n '[ Textual Explanation Placeholder ]\\n\\n'\n 'The model focused on [extracted features]\\n'\n 'because [rule-based reasoning] ...',\n ha='center', va='center',\n color='#3A3A5A', fontsize=10, fontstyle='italic',\n transform=ax_text.transAxes)\n\n for spine in ax_text.spines.values():\n spine.set_edgecolor(border_dim)\n spine.set_linewidth(1.2)\n\n plt.tight_layout(pad=1.2)\n plt.show()\n print()\n\n\n# ================================================================\n# MODEL LOADER\n# ================================================================\n\ndef load_model():\n if not os.path.exists(MODEL_PATH):\n print(f\"ERROR: Model not found at {MODEL_PATH}\")\n 
return None, None\n\n checkpoint = torch.load(MODEL_PATH, map_location=device)\n state_dict = checkpoint['model_state_dict']\n\n if list(state_dict.keys())[0].startswith('module.'):\n state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}\n\n severity_head_names = set()\n for k in state_dict.keys():\n if k.startswith('aggregator.severity_heads.'):\n parts = k.split('.')\n if len(parts) > 2:\n severity_head_names.add(parts[2])\n disease_names = sorted(list(severity_head_names))\n\n num_diseases = state_dict['aggregator.disease_head.classifier.3.weight'].shape[0]\n num_stage_classes = (\n state_dict['aggregator.stage_head.classifier.3.weight'].shape[0]\n if 'aggregator.stage_head.classifier.3.weight' in state_dict else 0\n )\n\n print(f\"\\nModel config -> diseases: {num_diseases} | stages: {num_stage_classes}\")\n print(f\"Classes: {', '.join(disease_names)}\\n\")\n\n backbone = ViTBackbone()\n model = Phase3Model(backbone, num_diseases=num_diseases,\n num_stage_classes=num_stage_classes,\n disease_names=disease_names).to(device)\n model.load_state_dict(state_dict, strict=True)\n model.eval()\n print(\"Model loaded successfully!\\n\")\n return model, disease_names\n\n\n# ================================================================\n# MAIN INFERENCE PIPELINE\n# ================================================================\n\ndef run_inference():\n print(\"\\n\" + \"=\" * 70)\n print(\" HIERARCHICAL MIL MODEL - PATHOLOGY INFERENCE\")\n print(\"=\" * 70)\n print(\"\\nOptions:\")\n print(\" 1. Single image (provide full path to one image file)\")\n print(\" 2. 
Folder (provide path to a folder; all images processed)\")\n\n choice = input(\"\\nSelect option (1 / 2): \").strip()\n\n if choice == '1':\n image_path = input(\"Enter image path: \").strip()\n if not os.path.isfile(image_path):\n print(f\"ERROR: File not found -> {image_path}\")\n return\n all_images = [image_path]\n elif choice == '2':\n folder_path = input(\"Enter folder path: \").strip()\n if not os.path.isdir(folder_path):\n print(f\"ERROR: Folder not found -> {folder_path}\")\n return\n all_images = collect_images_from_folder(folder_path)\n if not all_images:\n print(\"No valid images found in the folder.\")\n return\n print(f\"Found {len(all_images)} image(s).\")\n else:\n print(\"Invalid option.\")\n return\n\n # ----------------------------------------------------------\n # Load model\n # ----------------------------------------------------------\n print(\"\\nLoading model ...\")\n model, disease_names = load_model()\n if model is None:\n return\n\n # ----------------------------------------------------------\n # STEP 1: Run Attention and GradCAM analyses only\n # ----------------------------------------------------------\n print(\"\\n\" + \"=\" * 70)\n print(\"Running Attention Analysis ...\")\n print(\"=\" * 70)\n processed_images = _preprocess_images_for_attention(all_images)\n attn_results = run_attention_analysis(\n attention_extractor,\n device,\n all_images,\n processed_images,\n DISEASE_CLASS_MAPPING,\n STAGE_CLASS_MAPPING\n )\n\n print(\"\\n\" + \"=\" * 70)\n print(\"Running GradCAM Analysis ...\")\n print(\"=\" * 70)\n gradcam_results = run_tri_head_gradcam_plus_plus_analysis(model, device, all_images)\n\n # ----------------------------------------------------------\n # STEP 2: Feature extraction from attention heatmaps\n # ----------------------------------------------------------\n print(\"\\n\" + \"=\" * 70)\n print(\"Running Feature Extraction ...\")\n print(\"=\" * 70)\n\n explanations_list = []\n\n for i, attention_result in 
enumerate(attn_results):\n print(f\" [{i+1}/{len(attn_results)}] Extracting features: {attention_result['filename']}\")\n\n extractor = HeatmapFeatureExtractor(attention_result)\n\n bright = extractor.get_brightest_region()\n scatter = extractor.get_activation_scatter()\n dom_color = extractor.get_dominant_focus_color()\n texture = extractor.get_texture_analysis()\n\n explanations_list.append({\n \"brightest\": bright,\n \"scatter\": scatter,\n \"dominant_color\": dom_color,\n \"texture\": texture,\n })\n\n print(f\" āœ… Position={bright['primary_hotspot']['position']}, \"\n f\"Scatter={scatter['scatter_level']}, \"\n f\"Color={dom_color['dominant_color_name']}, \"\n f\"Texture={texture['texture_classification']}\")\n\n # ----------------------------------------------------------\n # STEP 3: Standard inference loop + GPT explanation + display\n # ----------------------------------------------------------\n dataset = SimpleSlideDataset(all_images)\n dataloader = DataLoader(dataset, batch_size=1, shuffle=False,\n collate_fn=simple_collate, num_workers=2, pin_memory=True)\n\n results = []\n\n for batch_idx, batch in enumerate(tqdm(dataloader, desc=\"Running inference\")):\n tiles, paths = batch\n if not tiles or not paths:\n continue\n\n slide_path = paths[0]\n\n try:\n img_idx = all_images.index(slide_path)\n except ValueError:\n img_idx = batch_idx\n\n prediction = predict_image(model, tiles, disease_names, STAGE_NAMES)\n if prediction is None:\n print(f\"Failed to process: {slide_path}\")\n continue\n\n # ---- Build the 2 overlay images: Attention + GradCAM ----\n attn_overlay = None\n if attn_results and img_idx < len(attn_results):\n ar = attn_results[img_idx]\n attn_overlay = _build_attention_overlay(ar['image'], ar['attention_heatmap'])\n\n gradcam_overlay = None\n if gradcam_results and img_idx < len(gradcam_results):\n gradcam_overlay = gradcam_results[img_idx]['union']['heatmap_overlay']\n\n # ---- Build comprehensive_data dict for GPT ----\n exp = 
explanations_list[img_idx]\n bright = exp['brightest']\n scatter_res = exp['scatter']\n dom = exp['dominant_color']\n texture = exp['texture']\n\n # GradCAM-sourced confidence and severity values\n gradcam_disease_conf = gradcam_results[img_idx]['level1_disease']['confidence'] \\\n if gradcam_results else 0.0\n gradcam_severity_conf = gradcam_results[img_idx]['level2_severity']['confidence'] \\\n if gradcam_results else 0.0\n gradcam_severity = gradcam_results[img_idx]['level2_severity']['predicted_class'] \\\n if gradcam_results else 'N/A'\n\n comprehensive_data = {\n 'predicted_disease': prediction['disease'],\n 'gradcam_severity': gradcam_severity,\n 'predicted_stage': prediction['stage'],\n 'stage_confidence': prediction['stage_confidence'],\n 'gradcam_disease_conf': gradcam_disease_conf,\n 'gradcam_severity_conf': gradcam_severity_conf,\n 'primary_position': bright['primary_hotspot']['position'],\n 'primary_intensity': bright['primary_hotspot']['intensity'],\n 'hotspot_count': bright['hotspot_count'],\n 'center_attention': bright['spatial_coverage']['center_attention'],\n 'mid_attention': bright['spatial_coverage']['mid_region_attention'],\n 'periphery_attention': bright['spatial_coverage']['periphery_attention'],\n 'scatter_level': scatter_res['scatter_level'],\n 'num_clusters': scatter_res['num_clusters'],\n 'dominant_color': dom['dominant_color_name'],\n 'color_confidence': dom['color_confidence'],\n 'texture_classification': texture['texture_classification'],\n 'uniformity': texture['texture_scores']['uniformity'],\n 'organization': texture['texture_scores']['organization'],\n 'complexity': texture['texture_scores']['complexity'],\n 'smoothness': texture['texture_scores']['smoothness'],\n }\n\n # ---- Generate GPT explanation -> render to image for Row 3 ----\n print(f\"\\n Generating GPT explanation for image {img_idx + 1} ...\")\n explanation_text = generate_comprehensive_explanation(comprehensive_data)\n explanation_image = 
_render_explanation_to_image(explanation_text)\n time.sleep(0.5)\n\n # ---- Display: Original | Report | Attention | GradCAM | Explanation ----\n display_prediction(\n image_path = slide_path,\n prediction = prediction,\n heatmap_images = [attn_overlay, gradcam_overlay],\n heatmap_titles = ['Attention Heatmap', 'GradCAM Heatmap'],\n explanation_image = explanation_image\n )\n\n results.append({\n 'image_path': slide_path,\n 'image_name': os.path.basename(slide_path),\n 'predicted_disease': prediction['disease'],\n 'disease_confidence': prediction['disease_confidence'],\n 'predicted_severity': prediction['severity'],\n 'severity_confidence': prediction['severity_confidence'],\n 'predicted_stage': prediction['stage'],\n 'stage_confidence': prediction['stage_confidence']\n })\n\n if results:\n df = pd.DataFrame(results)\n out_path = os.path.join(OUTPUT_DIR, \"inference_results.csv\")\n df.to_csv(out_path, index=False)\n print(f\"\\nResults saved -> {out_path}\")\n print(f\"Total processed: {len(results)}\")\n\n print(\"\\n\" + \"=\" * 70)\n print(\"DONE\")\n print(\"=\" * 70)\n\n\nif __name__ == \"__main__\":\n run_inference()","metadata":{"trusted":true},"outputs":[],"execution_count":null}]}