""" Phase 3 — Inference Pipeline """ import os import numpy as np import cv2 import torch import torch.nn.functional as F from typing import Optional, Dict, Any, Tuple from model import UNet from checkpoint_loader import load_model_for_inference from clinical_metrics import run_clinical_pipeline, ClinicalResult # ── Minimum plausible disc area in pixels (512×512 image). # Anything smaller is almost certainly not a real fundus image. # A disc typically covers ~3–5 % of the 512×512 canvas ≈ 7,000–13,000 px. # We gate at 1 % (≈ 2,600 px) to be conservative. _MIN_DISC_AREA_PX = 2_600 _MIN_CUP_AREA_PX = 100 # below this the cup reading is meaningless _MIN_CUP_DISC_RATIO = 0.01 # cup/disc area fraction — below = unreliable class Phase3Pipeline: def __init__( self, repo_id: str = "Nj-1111/EyeeSEE", epoch: Optional[int] = None, mc_passes: int = 20, uncertainty_threshold: float = 0.05, device: Optional[torch.device] = None, token: Optional[str] = None, debug: bool = False, # gate all debug I/O behind this flag ): self.repo_id = repo_id self.mc_passes = mc_passes self.uncertainty_threshold = uncertainty_threshold self.debug = debug self.device = device or torch.device( 'cuda' if torch.cuda.is_available() else 'cpu' ) self.token = ( token or os.getenv('HF_TOKEN_2') or os.getenv('HF_TOKEN') ) # IMP: lower dropout improves stability self.model = UNet( in_channels=1, n_classes=3, base_filters=64, dropout=0.1 ) load_model_for_inference( model=self.model, repo_id=repo_id, epoch=epoch, device=self.device, token=self.token ) # ────────────────────────────────────────────────────────────────────── # preprocessing # ────────────────────────────────────────────────────────────────────── def _preprocess(self, image: np.ndarray) -> torch.Tensor: if image.ndim == 3: image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) image = cv2.resize( image, (512, 512), interpolation=cv2.INTER_AREA ) image = image.astype(np.float32) / 255.0 return ( torch.from_numpy(image) .unsqueeze(0) .unsqueeze(0) .to(self.device) ) # ────────────────────────────────────────────────────────────────────── # mask cleanup (separate functions — disc and cup have different scales) # ────────────────────────────────────────────────────────────────────── def _clean_disc(self, binary_mask: np.ndarray) -> np.ndarray: """ Morphological cleanup for the disc mask. Disc is large enough that a 5×5 kernel is safe. """ binary_mask = binary_mask.astype(np.uint8) if binary_mask.sum() == 0: return binary_mask kernel = np.ones((5, 5), np.uint8) binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel) binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel) n, labels, stats, _ = cv2.connectedComponentsWithStats(binary_mask) if n <= 1: return binary_mask largest = 1 + np.argmax(stats[1:, cv2.CC_STAT_AREA]) cleaned = np.zeros_like(binary_mask) cleaned[labels == largest] = 1 return cleaned def _clean_cup(self, binary_mask: np.ndarray) -> np.ndarray: """ Morphological cleanup for the cup mask. Uses an adaptive kernel: after anatomical clipping the remaining cup may be small — a 5×5 open would erase it. We drop to 3×3 for small remnants. """ binary_mask = binary_mask.astype(np.uint8) if binary_mask.sum() == 0: return binary_mask k = 3 if binary_mask.sum() < 3_000 else 5 kernel = np.ones((k, k), np.uint8) binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel) binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel) n, labels, stats, _ = cv2.connectedComponentsWithStats(binary_mask) if n <= 1: return binary_mask largest = 1 + np.argmax(stats[1:, cv2.CC_STAT_AREA]) cleaned = np.zeros_like(binary_mask) cleaned[labels == largest] = 1 return cleaned # ────────────────────────────────────────────────────────────────────── # anatomical enforcement # ────────────────────────────────────────────────────────────────────── def _enforce_cup_in_disc( self, disc_mask: np.ndarray, cup_mask: np.ndarray, ) -> Tuple[np.ndarray, np.ndarray, int]: """ Hard-clip cup to disc boundary. Returns (disc_mask, corrected_cup, n_pixels_removed). Pure logical AND — no morphology here. """ corrected_cup = np.logical_and(cup_mask == 1, disc_mask == 1).astype(np.uint8) violations = int(cup_mask.sum()) - int(corrected_cup.sum()) return disc_mask, corrected_cup, violations # ────────────────────────────────────────────────────────────────────── # MC-Dropout segmentation # ────────────────────────────────────────────────────────────────────── def _mc_segment( self, tensor: torch.Tensor, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Run MC-Dropout inference and return: disc_mask (raw, thresholded, NOT yet cleaned) cup_mask (raw, thresholded, NOT yet cleaned) var_probs (3, H, W) — full variance map, used for ROI uncertainty later Masks are intentionally returned uncleaned so that the caller can enforce topology FIRST and clean AFTER (see run()). """ # BatchNorm stays frozen; only dropout layers go to train mode self.model.eval() for m in self.model.modules(): if isinstance(m, (torch.nn.Dropout, torch.nn.Dropout2d, torch.nn.Dropout3d)): m.train() all_probs = [] with torch.no_grad(): for _ in range(self.mc_passes): logits = self.model(tensor) probs = F.softmax(logits, dim=1) all_probs.append(probs.cpu().numpy()) self.model.eval() all_probs = np.stack(all_probs, axis=0) # (T, 1, 3, H, W) mean_probs = all_probs.mean(axis=0)[0] # (3, H, W) var_probs = all_probs.var(axis=0)[0] # (3, H, W) # ── Separate thresholds — critical fix ───────────────────────── # Cup probability maps are more diffuse than disc maps. # Using 0.25 for both causes the cup to spill far outside the disc; # the anatomical clipping then wipes it out entirely. # A higher cup threshold trims diffuse edges back inside the disc. disc_mask = (mean_probs[1] > 0.35).astype(np.uint8) cup_mask = (mean_probs[2] > 0.55).astype(np.uint8) if self.debug: cv2.imwrite("disc_prob_debug.png", (mean_probs[1] * 255).astype(np.uint8)) cv2.imwrite("cup_prob_debug.png", (mean_probs[2] * 255).astype(np.uint8)) print(f"[debug] raw disc px={disc_mask.sum()} raw cup px={cup_mask.sum()}") return disc_mask, cup_mask, var_probs # ────────────────────────────────────────────────────────────────────── # uncertainty over anatomical ROI # ────────────────────────────────────────────────────────────────────── @staticmethod def _roi_uncertainty( var_probs: np.ndarray, # (3, H, W) disc_mask: np.ndarray, # (H, W) cup_mask: np.ndarray, # (H, W) ) -> float: """ Mean MC-Dropout variance restricted to the disc+cup region. Averaging variance over the WHOLE image (including background) suppresses clinically important local uncertainty because the model is very confident about the large background class. Restricting to the anatomical ROI makes the score meaningful. """ roi = (disc_mask == 1) | (cup_mask == 1) if not roi.any(): # No anatomical region found — fall back to global (will be high) return float(var_probs[1:].mean()) # Channels 1 (disc) and 2 (cup) variance, sampled at ROI pixels return float(var_probs[1:, roi].mean()) # ────────────────────────────────────────────────────────────────────── # public API # ────────────────────────────────────────────────────────────────────── def run(self, image: np.ndarray) -> Dict[str, Any]: """ Full pipeline. Processing order (corrected): 1. threshold (separate disc / cup thresholds) 2. enforce topology ← BEFORE any morphology 3. clean disc + clean cup (separate kernels) 4. enforce topology again ← morphology can reintroduce violations 5. compute ROI uncertainty 6. minimum-cup sanity gate 7. clinical metrics """ tensor = self._preprocess(image) # Step 1+2: raw threshold → first topology enforcement # Cup is intentionally NOT cleaned yet so morphology doesn't expand # it past the disc before the AND clip. disc_mask, cup_mask, var_probs = self._mc_segment(tensor) disc_mask, cup_mask, violations = self._enforce_cup_in_disc(disc_mask, cup_mask) # Step 3: clean both masks (cup with adaptive small kernel) disc_mask = self._clean_disc(disc_mask) cup_mask = self._clean_cup(cup_mask) # Step 4: second enforcement — morphology can re-expand cup slightly # outside a cleaned (potentially slightly shrunk) disc disc_mask, cup_mask, extra_violations = self._enforce_cup_in_disc(disc_mask, cup_mask) violations += extra_violations # Step 5: uncertainty over anatomical ROI (not the whole image) uncertainty = self._roi_uncertainty(var_probs, disc_mask, cup_mask) if self.debug: print( f"[debug] post-clean disc px={disc_mask.sum()} " f"cup px={cup_mask.sum()} " f"violations={violations} " f"roi_uncertainty={uncertainty:.6f}" ) # ── Step 6: minimum-cup sanity gate ─────────────────────────── # A tiny cup remnant (e.g. a few dozen pixels surviving morphology) # produces a meaningless vCDR. Zero it out and warn instead. disc_area = int(disc_mask.sum()) cup_area = int(cup_mask.sum()) extra_warnings: list[str] = [] if disc_area < _MIN_DISC_AREA_PX: extra_warnings.append( f"Disc too small ({disc_area} px) — image may not be a fundus photo. " "Segmentation result is unreliable." ) if cup_area > 0 and cup_area < _MIN_CUP_AREA_PX: extra_warnings.append( f"Cup remnant too small ({cup_area} px) — suppressed. " "Cup segmentation is unreliable for this image." ) cup_mask = np.zeros_like(cup_mask) cup_area = 0 if cup_area > 0 and disc_area > 0: if (cup_area / disc_area) < _MIN_CUP_DISC_RATIO: extra_warnings.append( f"Cup/disc area ratio ({cup_area}/{disc_area} = " f"{cup_area/disc_area:.3f}) is below minimum — " "cup reading may be unreliable." ) # Step 7: clinical pipeline clinical = run_clinical_pipeline( disc_mask=disc_mask, cup_mask=cup_mask, uncertainty=uncertainty, uncertainty_threshold=self.uncertainty_threshold ) # Inject all accumulated warnings if violations > 0: clinical.warnings.append( f"Mask corrected: {violations} cup pixels clipped to disc boundary." ) clinical.warnings.extend(extra_warnings) report = { 'vcdr': clinical.vcdr, 'isnt': clinical.isnt.to_dict(), 'risk_level': clinical.risk_level.value, 'uncertainty': round(uncertainty, 6), 'high_uncertainty': clinical.high_uncertainty, 'disc_area_px': clinical.disc_area_px, 'cup_area_px': clinical.cup_area_px, 'disc_center': clinical.disc_center, 'cup_center': clinical.cup_center, 'sanity_passed': clinical.sanity_passed, 'warnings': clinical.warnings, } return { 'disc_mask': disc_mask, 'cup_mask': cup_mask, 'uncertainty': uncertainty, 'clinical': clinical, 'report': report, } def run_from_path(self, image_path: str) -> Dict[str, Any]: image = cv2.imread(image_path) if image is None: raise FileNotFoundError(f"Cannot load image: {image_path}") return self.run(image)