| """ |
Phase 3 — Inference pipeline: Monte-Carlo-dropout disc/cup segmentation followed by clinical metrics.
| """ |
|
|
import logging
import os
from typing import Optional, Dict, Any, Tuple

import numpy as np
import cv2
import torch
import torch.nn.functional as F

from model import UNet
from checkpoint_loader import load_model_for_inference
from clinical_metrics import run_clinical_pipeline, ClinicalResult
|
|
|
|
class Phase3Pipeline:
    """Disc/cup segmentation with Monte-Carlo dropout, followed by clinical metrics.

    Workflow of :meth:`run`:

    1. Convert the image to grayscale, resize it to ``input_size`` x
       ``input_size`` and divide pixel values by 255.
    2. Run ``mc_passes`` forward passes of a 3-class ``UNet`` with only its
       dropout layers in training mode, and average the softmax outputs.
    3. Threshold channel 1 (disc) and channel 2 (cup) of the mean
       probabilities at ``prob_threshold``, then clean each mask.
    4. Clip the cup mask to the disc mask.
    5. Pass masks and MC uncertainty to ``run_clinical_pipeline``.

    Masks (and therefore any pixel areas/centres derived from them) are in
    the resized frame, not the original image resolution.
    """

    def __init__(
        self,
        repo_id: str = "Nj-1111/EyeeSEE",
        epoch: Optional[int] = None,
        mc_passes: int = 20,
        uncertainty_threshold: float = 0.05,
        device: Optional[torch.device] = None,
        token: Optional[str] = None,
        prob_threshold: float = 0.25,
        input_size: int = 512,
    ):
        """Build the model and load its weights.

        Args:
            repo_id: Checkpoint repository forwarded to
                ``load_model_for_inference``.
            epoch: Checkpoint epoch forwarded to ``load_model_for_inference``
                (how ``None`` is interpreted is up to that loader).
            mc_passes: Number of stochastic forward passes; must be >= 1.
            uncertainty_threshold: Forwarded to ``run_clinical_pipeline``.
            device: Inference device; defaults to CUDA if available, else CPU.
            token: Access token; falls back to the ``HF_TOKEN_2`` and then the
                ``HF_TOKEN`` environment variable.
            prob_threshold: Mean-probability cutoff above which a pixel is
                labelled disc / cup.
            input_size: Side length (pixels) the input is resized to.

        Raises:
            ValueError: If ``mc_passes`` or ``input_size`` is < 1.
        """
        if mc_passes < 1:
            # np.stack would otherwise fail later on an empty list.
            raise ValueError(f"mc_passes must be >= 1, got {mc_passes}")
        if input_size < 1:
            raise ValueError(f"input_size must be >= 1, got {input_size}")

        self.repo_id = repo_id
        self.mc_passes = mc_passes
        self.uncertainty_threshold = uncertainty_threshold
        self.prob_threshold = prob_threshold
        self.input_size = input_size

        self.device = device or torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu'
        )

        # Explicit argument wins, then HF_TOKEN_2, then HF_TOKEN.
        self.token = (
            token or
            os.getenv('HF_TOKEN_2') or
            os.getenv('HF_TOKEN')
        )

        self.model = UNet(
            in_channels=1,
            n_classes=3,
            base_filters=64,
            dropout=0.1
        )

        load_model_for_inference(
            model=self.model,
            repo_id=repo_id,
            epoch=epoch,
            device=self.device,
            token=self.token
        )

        # _preprocess places inputs on self.device, so make the weights'
        # placement explicit too; a no-op if the loader already moved them.
        self.model.to(self.device)
        self.model.eval()

    def _preprocess(self, image: np.ndarray) -> torch.Tensor:
        """Convert an image array to a ``(1, 1, S, S)`` float32 tensor on ``self.device``.

        ``S`` is ``self.input_size``. Accepts a 2-D grayscale array, a 3-D
        single-channel array, or a 3-D multi-channel array in BGR order (as
        returned by ``cv2.imread``). Values are divided by 255, which assumes
        8-bit input.

        Raises:
            ValueError: If ``image`` is neither 2-D nor 3-D.
        """
        if image.ndim == 3:
            if image.shape[2] == 1:
                # cvtColor rejects single-channel input; just drop the axis.
                image = np.ascontiguousarray(image[:, :, 0])
            else:
                image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        elif image.ndim != 2:
            raise ValueError(
                f"Expected a 2-D or 3-D image array, got shape {image.shape}"
            )

        # Note: the aspect ratio is not preserved; the image is stretched to
        # a square.
        image = cv2.resize(
            image,
            (self.input_size, self.input_size),
            interpolation=cv2.INTER_AREA
        )

        image = image.astype(np.float32) / 255.0

        # (S, S) -> (1, 1, S, S): batch and channel axes.
        return (
            torch.from_numpy(image)
            .unsqueeze(0)
            .unsqueeze(0)
            .to(self.device)
        )

    def _clean_mask(self, binary_mask: np.ndarray, kernel_size: int = 5) -> np.ndarray:
        """Denoise a binary mask and keep only its largest connected component.

        A morphological opening removes small isolated foreground specks, and
        a closing then fills small gaps. Of the remaining foreground
        components, only the one with the largest pixel area is kept.

        Args:
            binary_mask: 2-D array; values that are nonzero after a ``uint8``
                cast are foreground.
            kernel_size: Side length of the square structuring element.

        Returns:
            ``uint8`` array with values 0/1 (all zeros if no foreground
            survives the opening).
        """
        binary_mask = binary_mask.astype(np.uint8)

        kernel = np.ones((kernel_size, kernel_size), np.uint8)

        binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel)
        binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)

        n, labels, stats, _ = cv2.connectedComponentsWithStats(binary_mask)

        # Label 0 is the background, so n <= 1 means no foreground at all.
        if n <= 1:
            return binary_mask

        largest = 1 + np.argmax(stats[1:, cv2.CC_STAT_AREA])

        cleaned = np.zeros_like(binary_mask)
        cleaned[labels == largest] = 1

        return cleaned

    @staticmethod
    def _enable_mc_dropout(model: torch.nn.Module) -> int:
        """Switch every dropout submodule of ``model`` to training mode.

        All other modules keep whatever mode they are in. The caller is
        expected to have called ``model.eval()`` first.

        Returns:
            Number of dropout modules that were switched on.
        """
        # Dropout1d only exists in newer torch versions.
        dropout_types = tuple(
            cls for cls in (
                torch.nn.Dropout,
                getattr(torch.nn, "Dropout1d", None),
                torch.nn.Dropout2d,
                torch.nn.Dropout3d,
                torch.nn.AlphaDropout,
                torch.nn.FeatureAlphaDropout,
            )
            if cls is not None
        )

        enabled = 0
        for module in model.modules():
            if isinstance(module, dropout_types):
                module.train()
                enabled += 1
        return enabled

    def _mc_segment(
        self,
        tensor: torch.Tensor
    ) -> Tuple[np.ndarray, np.ndarray, float]:
        """Run Monte-Carlo dropout inference and return cleaned disc/cup masks.

        Args:
            tensor: Model input as produced by :meth:`_preprocess`,
                shape ``(1, 1, S, S)``.

        Returns:
            ``(disc_mask, cup_mask, uncertainty)``:

            * ``disc_mask`` / ``cup_mask``: ``uint8`` 0/1 masks taken from
              channels 1 and 2 of the pass-averaged softmax, thresholded at
              ``self.prob_threshold`` and passed through :meth:`_clean_mask`.
            * ``uncertainty``: mean over channels 1 onward and all pixels of
              the per-pixel variance across passes (population variance).
        """
        log = logging.getLogger(__name__)

        # Put everything in eval mode, then re-enable only dropout so each
        # pass samples a different dropout mask.
        self.model.eval()
        if self._enable_mc_dropout(self.model) == 0:
            log.warning(
                "No dropout layers found; MC passes are deterministic and "
                "the reported uncertainty will be 0."
            )

        pass_probs = []
        try:
            with torch.no_grad():
                for _ in range(self.mc_passes):
                    logits = self.model(tensor)
                    probs = F.softmax(logits, dim=1)
                    pass_probs.append(probs.cpu().numpy())
        finally:
            # Restore plain eval mode even if a forward pass raised.
            self.model.eval()

        # Axis 0 indexes the MC pass; [0] then selects the single batch item.
        stacked = np.stack(pass_probs, axis=0)
        mean_probs = stacked.mean(axis=0)[0]
        var_probs = stacked.var(axis=0)[0]

        log.debug(
            "DISC mean/max: %f %f",
            float(mean_probs[1].mean()), float(mean_probs[1].max())
        )
        log.debug(
            "CUP mean/max: %f %f",
            float(mean_probs[2].mean()), float(mean_probs[2].max())
        )

        disc_mask = (mean_probs[1] > self.prob_threshold).astype(np.uint8)
        cup_mask = (mean_probs[2] > self.prob_threshold).astype(np.uint8)

        log.debug("Raw disc pixels: %d", int(disc_mask.sum()))
        log.debug("Raw cup pixels : %d", int(cup_mask.sum()))

        disc_mask = self._clean_mask(disc_mask)
        cup_mask = self._clean_mask(cup_mask)

        log.debug("Clean disc pixels: %d", int(disc_mask.sum()))
        log.debug("Clean cup pixels : %d", int(cup_mask.sum()))

        return (
            disc_mask,
            cup_mask,
            float(var_probs[1:].mean())
        )

    def _enforce_cup_in_disc(
        self,
        disc_mask: np.ndarray,
        cup_mask: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, int]:
        """Clip the cup mask so that it lies entirely inside the disc mask.

        Args:
            disc_mask: 0/1 disc mask (returned unchanged).
            cup_mask: 0/1 cup mask.

        Returns:
            ``(disc_mask, corrected_cup, violations)`` where ``corrected_cup``
            is the ``uint8`` intersection of the two masks and ``violations``
            is the number of cup pixels that were removed.
        """
        corrected_cup = np.logical_and(
            cup_mask == 1,
            disc_mask == 1
        ).astype(np.uint8)

        violations = int(cup_mask.sum()) - int(corrected_cup.sum())

        return disc_mask, corrected_cup, violations

    def run(self, image: np.ndarray) -> Dict[str, Any]:
        """Segment disc and cup in ``image`` and compute clinical metrics.

        Args:
            image: Image array accepted by :meth:`_preprocess` (2-D grayscale
                or 3-D BGR, 8-bit).

        Returns:
            Dict with keys:

            * ``'disc_mask'``, ``'cup_mask'``: final 0/1 masks in the resized
              frame.
            * ``'uncertainty'``: MC uncertainty (see :meth:`_mc_segment`).
            * ``'clinical'``: the object returned by ``run_clinical_pipeline``.
            * ``'report'``: plain-dict summary of the clinical result, with the
              uncertainty rounded to 6 decimals.

            If cup pixels had to be clipped to the disc, a message is appended
            to ``clinical.warnings``.
        """
        tensor = self._preprocess(image)

        disc_mask, cup_mask, uncertainty = self._mc_segment(tensor)

        disc_mask, cup_mask, violations = self._enforce_cup_in_disc(
            disc_mask,
            cup_mask
        )

        clinical = run_clinical_pipeline(
            disc_mask=disc_mask,
            cup_mask=cup_mask,
            uncertainty=uncertainty,
            uncertainty_threshold=self.uncertainty_threshold
        )

        if violations > 0:
            clinical.warnings.append(
                f"Mask corrected: {violations} cup pixels clipped to disc boundary."
            )

        report = {
            'vcdr': clinical.vcdr,
            'isnt': clinical.isnt.to_dict(),
            'risk_level': clinical.risk_level.value,
            'uncertainty': round(uncertainty, 6),
            'high_uncertainty': clinical.high_uncertainty,
            'disc_area_px': clinical.disc_area_px,
            'cup_area_px': clinical.cup_area_px,
            'disc_center': clinical.disc_center,
            'cup_center': clinical.cup_center,
            'sanity_passed': clinical.sanity_passed,
            'warnings': clinical.warnings,
        }

        return {
            'disc_mask': disc_mask,
            'cup_mask': cup_mask,
            'uncertainty': uncertainty,
            'clinical': clinical,
            'report': report,
        }

    def run_from_path(self, image_path: "str | os.PathLike[str]") -> Dict[str, Any]:
        """Load an image with OpenCV (3-channel BGR) and run :meth:`run` on it.

        Args:
            image_path: Filesystem path (``str`` or path-like object).

        Raises:
            FileNotFoundError: If ``cv2.imread`` cannot load the file (it
                returns ``None`` for missing, unreadable or unsupported files).
        """
        path = os.fspath(image_path)
        image = cv2.imread(path)

        if image is None:
            raise FileNotFoundError(f"Cannot load image: {path}")

        return self.run(image)
|
|