Nj-1111 commited on
Commit
58fb621
Β·
verified Β·
1 Parent(s): da27670

Upload 2 files

Browse files
Files changed (2) hide show
  1. clinical_metrics_v2.py +382 -0
  2. phase3pipeline_v2.py +351 -0
clinical_metrics_v2.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Phase 3 β€” Clinical Logic & Verification Engine
3
+ ===============================================
4
+ Implements deterministic clinical metric extraction from segmentation masks.
5
+ No ML here β€” pure geometry and ophthalmology math.
6
+
7
+ Mask convention (from Phase 2 training):
8
+ 0 = Background
9
+ 1 = Optic Disc
10
+ 2 = Optic Cup
11
+ """
12
+
13
+ import numpy as np
14
+ import cv2
15
+ from dataclasses import dataclass, field
16
+ from typing import Tuple, Dict, Optional
17
+ from enum import Enum
18
+
19
+ # ── ISNT tolerance margin ──────────────────────────────────────────────────
20
+ # Exact I > S > N > T fails on tiny numerical noise.
21
+ # A margin of 0.2 rim-pixels absorbs rounding without masking real violations.
22
+ _ISNT_MARGIN = 0.2
23
+
24
+
25
+ class RiskLevel(Enum):
26
+ HEALTHY = "Healthy"
27
+ SUSPECT = "Glaucoma Suspect"
28
+ HIGH = "High Risk"
29
+
30
+
31
+ class SanityError(Exception):
32
+ """Raised when mask geometry violates anatomical constraints."""
33
+ pass
34
+
35
+
36
+ @dataclass
37
+ class ISNTResult:
38
+ inferior: float = 0.0
39
+ superior: float = 0.0
40
+ nasal: float = 0.0
41
+ temporal: float = 0.0
42
+ rule_satisfied: bool = False
43
+
44
+ def to_dict(self) -> Dict[str, float]:
45
+ return {
46
+ 'inferior': round(self.inferior, 4),
47
+ 'superior': round(self.superior, 4),
48
+ 'nasal': round(self.nasal, 4),
49
+ 'temporal': round(self.temporal, 4),
50
+ 'rule_satisfied': self.rule_satisfied,
51
+ }
52
+
53
+
54
+ @dataclass
55
+ class ClinicalResult:
56
+ vcdr: float = 0.0
57
+ isnt: ISNTResult = field(default_factory=ISNTResult)
58
+ disc_area_px: int = 0
59
+ cup_area_px: int = 0
60
+ disc_center: Tuple[int, int] = (0, 0)
61
+ cup_center: Tuple[int, int] = (0, 0)
62
+ uncertainty: float = 0.0
63
+ high_uncertainty: bool = False
64
+ risk_level: RiskLevel = RiskLevel.HEALTHY
65
+ sanity_passed: bool = False
66
+ warnings: list = field(default_factory=list)
67
+
68
+ def to_dict(self) -> dict:
69
+ return {
70
+ 'vcdr': round(self.vcdr, 4),
71
+ 'isnt': self.isnt.to_dict(),
72
+ 'disc_area_px': self.disc_area_px,
73
+ 'cup_area_px': self.cup_area_px,
74
+ 'disc_center': self.disc_center,
75
+ 'cup_center': self.cup_center,
76
+ 'uncertainty': round(self.uncertainty, 6),
77
+ 'high_uncertainty': self.high_uncertainty,
78
+ 'risk_level': self.risk_level.value,
79
+ 'sanity_passed': self.sanity_passed,
80
+ 'warnings': self.warnings,
81
+ }
82
+
83
+
84
+ # ─────────────────────────────────────────────────────────────────────────────
85
+ # Step 1 β€” Sanity checks
86
+ # ─────────────────────────────────────────────────────────────────────────────
87
+
88
+ def run_sanity_checks(disc_mask: np.ndarray, cup_mask: np.ndarray) -> None:
89
+ """
90
+ Enforce anatomical constraints. Raises SanityError on hard violations.
91
+
92
+ Checks:
93
+ 1. Masks are binary (0/1 values only).
94
+ 2. Disc region is non-empty.
95
+ 3. Cup is 100 % contained inside the disc (hard anatomical law).
96
+ 4. Disconnected regions β€” warn only; upstream _clean_* handles them.
97
+ """
98
+ # 1. Binary check
99
+ for name, mask in [('disc', disc_mask), ('cup', cup_mask)]:
100
+ unique_vals = np.unique(mask)
101
+ if not set(unique_vals).issubset({0, 1}):
102
+ raise SanityError(
103
+ f"{name} mask contains non-binary values: {unique_vals}"
104
+ )
105
+
106
+ # 2. Non-empty disc
107
+ if int(disc_mask.sum()) == 0:
108
+ raise SanityError("Optic disc mask is empty β€” segmentation failure.")
109
+
110
+ # 3. Cup βŠ‚ Disc
111
+ cup_outside_disc = np.logical_and(cup_mask == 1, disc_mask == 0)
112
+ if cup_outside_disc.any():
113
+ n_violation = int(cup_outside_disc.sum())
114
+ raise SanityError(
115
+ f"Cup extends outside disc boundary ({n_violation} pixels). "
116
+ "Anatomically impossible β€” reject segmentation."
117
+ )
118
+
119
+ # 4. Single connected component β€” warn only, do not reject.
120
+ # Small segmentation gaps from noisy boundaries are handled upstream
121
+ # by _clean_disc / _clean_cup.
122
+ for name, mask in [('disc', disc_mask), ('cup', cup_mask)]:
123
+ if mask.sum() == 0:
124
+ continue
125
+ n_labels, _ = cv2.connectedComponents(mask.astype(np.uint8))
126
+ if n_labels > 2:
127
+ pass # upstream cleanup handles; avoid hard rejection here
128
+
129
+
130
+ # ──────────────────────────��──────────────────────────────────────────────────
131
+ # Step 2 β€” vCDR
132
+ # ─────────────────────────────────────────────────────────────────────────────
133
+
134
+ def calculate_vcdr(
135
+ disc_mask: np.ndarray,
136
+ cup_mask: np.ndarray,
137
+ ) -> Tuple[float, dict]:
138
+ """
139
+ Calculate vertical Cup-to-Disc Ratio using vertical extrema.
140
+
141
+ Clinical basis:
142
+ Horizontal disc expansion is less indicative of early glaucoma.
143
+ vCDR is the primary screening metric.
144
+
145
+ Returns:
146
+ vcdr (float): ratio in [0, 1]
147
+ details (dict): raw pixel measurements for transparency
148
+ """
149
+ disc_rows = np.where(disc_mask.any(axis=1))[0]
150
+ cup_rows = np.where(cup_mask.any(axis=1))[0]
151
+
152
+ if disc_rows.size == 0:
153
+ return 0.0, {}
154
+
155
+ disc_v_diam = int(disc_rows.max() - disc_rows.min() + 1)
156
+ cup_v_diam = int(cup_rows.max() - cup_rows.min() + 1) if cup_rows.size > 0 else 0
157
+
158
+ vcdr = cup_v_diam / disc_v_diam if disc_v_diam > 0 else 0.0
159
+
160
+ details = {
161
+ 'disc_top_px': int(disc_rows.min()),
162
+ 'disc_bottom_px': int(disc_rows.max()),
163
+ 'disc_v_diam_px': disc_v_diam,
164
+ 'cup_top_px': int(cup_rows.min()) if cup_rows.size > 0 else None,
165
+ 'cup_bottom_px': int(cup_rows.max()) if cup_rows.size > 0 else None,
166
+ 'cup_v_diam_px': cup_v_diam,
167
+ }
168
+ return round(vcdr, 4), details
169
+
170
+
171
+ # ─────────────────────────────────────────────────────────────────────────────
172
+ # Step 3 β€” ISNT Rule
173
+ # ─────────────────────────────────────────────────────────────────────────────
174
+
175
+ def _disc_centroid(disc_mask: np.ndarray) -> Tuple[int, int]:
176
+ M = cv2.moments(disc_mask.astype(np.uint8))
177
+ if M['m00'] == 0:
178
+ h, w = disc_mask.shape
179
+ return h // 2, w // 2
180
+ cy = int(M['m01'] / M['m00'])
181
+ cx = int(M['m10'] / M['m00'])
182
+ return cy, cx
183
+
184
+
185
+ def _rim_thickness_in_quadrant(
186
+ quadrant_mask: np.ndarray,
187
+ disc_mask: np.ndarray,
188
+ cup_mask: np.ndarray,
189
+ ) -> float:
190
+ """
191
+ Mean Euclidean distance from cup boundary to disc boundary,
192
+ measured inside a specific quadrant.
193
+
194
+ Uses a distance transform on the disc interior so that each pixel
195
+ inside the disc gets its distance to the disc edge.
196
+ We then sample only rim pixels (disc=1, cup=0) in this quadrant.
197
+ """
198
+ rim = np.logical_and(disc_mask == 1, cup_mask == 0).astype(np.uint8)
199
+ rim_in_quad = np.logical_and(rim, quadrant_mask).astype(np.uint8)
200
+
201
+ if rim_in_quad.sum() == 0:
202
+ return 0.0
203
+
204
+ dist_to_disc_edge = cv2.distanceTransform(
205
+ disc_mask.astype(np.uint8), cv2.DIST_L2, 5
206
+ )
207
+
208
+ thicknesses = dist_to_disc_edge[rim_in_quad == 1]
209
+ return float(np.mean(thicknesses))
210
+
211
+
212
+ def calculate_isnt(
213
+ disc_mask: np.ndarray,
214
+ cup_mask: np.ndarray,
215
+ ) -> ISNTResult:
216
+ """
217
+ Calculate neuro-retinal rim thickness in four ISNT quadrants.
218
+
219
+ Quadrant definition (image coordinates):
220
+ Superior β€” top half (rows < cy)
221
+ Inferior β€” bottom half (rows >= cy)
222
+ Nasal β€” right half (cols >= cx) [standard right eye convention]
223
+ Temporal β€” left half (cols < cx)
224
+
225
+ ISNT rule (with tolerance margin):
226
+ Inferior > Superior - margin
227
+ Superior > Nasal - margin
228
+ Nasal > Temporal - margin
229
+
230
+ Using a strict exact ordering falsely triggers violations when
231
+ quadrant thicknesses differ by sub-pixel amounts due to rounding.
232
+ A margin of _ISNT_MARGIN (0.2 px) absorbs noise without masking
233
+ real rim thinning (which presents as differences of several pixels).
234
+ """
235
+ h, w = disc_mask.shape
236
+ cy, cx = _disc_centroid(disc_mask)
237
+
238
+ superior_q = np.zeros((h, w), dtype=bool)
239
+ inferior_q = np.zeros((h, w), dtype=bool)
240
+ nasal_q = np.zeros((h, w), dtype=bool)
241
+ temporal_q = np.zeros((h, w), dtype=bool)
242
+
243
+ superior_q[:cy, :] = True
244
+ inferior_q[cy:, :] = True
245
+ nasal_q[:, cx:] = True
246
+ temporal_q[:, :cx] = True
247
+
248
+ I = _rim_thickness_in_quadrant(inferior_q, disc_mask, cup_mask)
249
+ S = _rim_thickness_in_quadrant(superior_q, disc_mask, cup_mask)
250
+ N = _rim_thickness_in_quadrant(nasal_q, disc_mask, cup_mask)
251
+ T = _rim_thickness_in_quadrant(temporal_q, disc_mask, cup_mask)
252
+
253
+ # Tolerated ISNT check β€” absorbs sub-pixel numerical noise
254
+ rule_ok = (
255
+ I > S - _ISNT_MARGIN and
256
+ S > N - _ISNT_MARGIN and
257
+ N > T - _ISNT_MARGIN
258
+ )
259
+
260
+ return ISNTResult(
261
+ inferior=I, superior=S, nasal=N, temporal=T,
262
+ rule_satisfied=rule_ok,
263
+ )
264
+
265
+
266
+ # ─────────────────────────────────────────────────────────────────────────────
267
+ # Step 4 β€” Risk classification
268
+ # ─────────────────────────────────────────────────────────────────────────────
269
+
270
+ def classify_risk(
271
+ vcdr: float,
272
+ isnt: ISNTResult,
273
+ uncertainty: float,
274
+ uncertainty_threshold: float = 0.05,
275
+ ) -> Tuple[RiskLevel, list]:
276
+ """
277
+ Rule-based risk stratification.
278
+
279
+ Thresholds from clinical literature:
280
+ vCDR < 0.65 β†’ Healthy
281
+ vCDR 0.65–0.80 β†’ Suspect
282
+ vCDR > 0.80 β†’ High Risk
283
+ ISNT violation (any risk) β†’ escalate to at least Suspect
284
+ High uncertainty β†’ override to Suspect regardless of vCDR
285
+ """
286
+ warnings = []
287
+
288
+ if uncertainty > uncertainty_threshold:
289
+ warnings.append(
290
+ f"High model uncertainty ({uncertainty:.4f}) β€” result may be unreliable."
291
+ )
292
+ return RiskLevel.SUSPECT, warnings
293
+
294
+ if vcdr > 0.80:
295
+ risk = RiskLevel.HIGH
296
+ warnings.append(
297
+ f"vCDR {vcdr:.2f} exceeds 0.80 β€” urgent referral recommended."
298
+ )
299
+ elif vcdr > 0.65:
300
+ risk = RiskLevel.SUSPECT
301
+ warnings.append(f"vCDR {vcdr:.2f} in borderline range (0.65–0.80).")
302
+ else:
303
+ risk = RiskLevel.HEALTHY
304
+
305
+ if not isnt.rule_satisfied:
306
+ warnings.append("ISNT rule violated β€” neuro-retinal rim thinning detected.")
307
+ if risk == RiskLevel.HEALTHY:
308
+ risk = RiskLevel.SUSPECT
309
+
310
+ return risk, warnings
311
+
312
+
313
+ # ─────────────────────────────────────────────────────────────────────────────
314
+ # Main pipeline entry point
315
+ # ─────────────────────────────────────────────────────────────────────────────
316
+
317
+ def run_clinical_pipeline(
318
+ disc_mask: np.ndarray,
319
+ cup_mask: np.ndarray,
320
+ uncertainty: float = 0.0,
321
+ uncertainty_threshold: float = 0.05,
322
+ ) -> ClinicalResult:
323
+ """
324
+ Execute complete Phase 3 clinical pipeline on binary masks.
325
+
326
+ Args:
327
+ disc_mask: uint8 binary array (1 = disc, 0 = background)
328
+ cup_mask: uint8 binary array (1 = cup, 0 = background)
329
+ uncertainty: scalar from Phase 2 MC-Dropout (ROI-restricted)
330
+ uncertainty_threshold: flag above this value as high uncertainty
331
+
332
+ Returns:
333
+ ClinicalResult dataclass
334
+ """
335
+ result = ClinicalResult()
336
+ result.uncertainty = float(uncertainty)
337
+ result.high_uncertainty = uncertainty > uncertainty_threshold
338
+
339
+ disc_mask = (disc_mask > 0).astype(np.uint8)
340
+ cup_mask = (cup_mask > 0).astype(np.uint8)
341
+
342
+ # ── Sanity checks ──────────────────────────────────────────────────
343
+ try:
344
+ run_sanity_checks(disc_mask, cup_mask)
345
+ result.sanity_passed = True
346
+ except SanityError as e:
347
+ result.warnings.append(f"SANITY FAIL: {e}")
348
+ result.sanity_passed = False
349
+ # Only abort if there is no disc at all β€” otherwise continue
350
+ # computing metrics on whatever geometry we have.
351
+ if disc_mask.sum() == 0:
352
+ result.risk_level = RiskLevel.SUSPECT
353
+ return result
354
+
355
+ # ── Structural measurements ────────────────────────────────────────
356
+ result.disc_area_px = int(disc_mask.sum())
357
+ result.cup_area_px = int(cup_mask.sum())
358
+
359
+ dy, dx = _disc_centroid(disc_mask)
360
+ result.disc_center = (int(dx), int(dy))
361
+
362
+ if cup_mask.sum() > 0:
363
+ M = cv2.moments(cup_mask)
364
+ if M['m00'] > 0:
365
+ result.cup_center = (
366
+ int(M['m10'] / M['m00']),
367
+ int(M['m01'] / M['m00']),
368
+ )
369
+
370
+ # ── vCDR ──────────────────────────────────────────────────────────
371
+ result.vcdr, _ = calculate_vcdr(disc_mask, cup_mask)
372
+
373
+ # ── ISNT (with tolerance margin) ──────────────────────────────────
374
+ result.isnt = calculate_isnt(disc_mask, cup_mask)
375
+
376
+ # ── Risk classification ─────���─────────────────────────────────────
377
+ result.risk_level, warnings = classify_risk(
378
+ result.vcdr, result.isnt, uncertainty, uncertainty_threshold
379
+ )
380
+ result.warnings.extend(warnings)
381
+
382
+ return result
phase3pipeline_v2.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Phase 3 β€” Inference Pipeline
3
+ """
4
+
5
+ import os
6
+ import numpy as np
7
+ import cv2
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from typing import Optional, Dict, Any, Tuple
11
+
12
+ from model import UNet
13
+ from checkpoint_loader import load_model_for_inference
14
+ from clinical_metrics import run_clinical_pipeline, ClinicalResult
15
+
16
+ # ── Minimum plausible disc area in pixels (512Γ—512 image).
17
+ # Anything smaller is almost certainly not a real fundus image.
18
+ # A disc typically covers ~3–5 % of the 512Γ—512 canvas β‰ˆ 7,000–13,000 px.
19
+ # We gate at 1 % (β‰ˆ 2,600 px) to be conservative.
20
+ _MIN_DISC_AREA_PX = 2_600
21
+ _MIN_CUP_AREA_PX = 100 # below this the cup reading is meaningless
22
+ _MIN_CUP_DISC_RATIO = 0.01 # cup/disc area fraction β€” below = unreliable
23
+
24
+
25
+ class Phase3Pipeline:
26
+
27
+ def __init__(
28
+ self,
29
+ repo_id: str = "Nj-1111/EyeeSEE",
30
+ epoch: Optional[int] = None,
31
+ mc_passes: int = 20,
32
+ uncertainty_threshold: float = 0.05,
33
+ device: Optional[torch.device] = None,
34
+ token: Optional[str] = None,
35
+ debug: bool = False, # gate all debug I/O behind this flag
36
+ ):
37
+ self.repo_id = repo_id
38
+ self.mc_passes = mc_passes
39
+ self.uncertainty_threshold = uncertainty_threshold
40
+ self.debug = debug
41
+
42
+ self.device = device or torch.device(
43
+ 'cuda' if torch.cuda.is_available() else 'cpu'
44
+ )
45
+
46
+ self.token = (
47
+ token or
48
+ os.getenv('HF_TOKEN_2') or
49
+ os.getenv('HF_TOKEN')
50
+ )
51
+
52
+ # IMP: lower dropout improves stability
53
+ self.model = UNet(
54
+ in_channels=1,
55
+ n_classes=3,
56
+ base_filters=64,
57
+ dropout=0.1
58
+ )
59
+
60
+ load_model_for_inference(
61
+ model=self.model,
62
+ repo_id=repo_id,
63
+ epoch=epoch,
64
+ device=self.device,
65
+ token=self.token
66
+ )
67
+
68
+ # ──────────────────────────────────────────────────────────────────────
69
+ # preprocessing
70
+ # ──────────────────────────────────────────────────────────────────────
71
+
72
+ def _preprocess(self, image: np.ndarray) -> torch.Tensor:
73
+
74
+ if image.ndim == 3:
75
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
76
+
77
+ image = cv2.resize(
78
+ image,
79
+ (512, 512),
80
+ interpolation=cv2.INTER_AREA
81
+ )
82
+
83
+ image = image.astype(np.float32) / 255.0
84
+
85
+ return (
86
+ torch.from_numpy(image)
87
+ .unsqueeze(0)
88
+ .unsqueeze(0)
89
+ .to(self.device)
90
+ )
91
+
92
+ # ──────────────────────────────────────────────────────────────────────
93
+ # mask cleanup (separate functions β€” disc and cup have different scales)
94
+ # ──────────────────────────────────────────────────────────────────────
95
+
96
+ def _clean_disc(self, binary_mask: np.ndarray) -> np.ndarray:
97
+ """
98
+ Morphological cleanup for the disc mask.
99
+ Disc is large enough that a 5Γ—5 kernel is safe.
100
+ """
101
+ binary_mask = binary_mask.astype(np.uint8)
102
+ if binary_mask.sum() == 0:
103
+ return binary_mask
104
+
105
+ kernel = np.ones((5, 5), np.uint8)
106
+ binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel)
107
+ binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)
108
+
109
+ n, labels, stats, _ = cv2.connectedComponentsWithStats(binary_mask)
110
+ if n <= 1:
111
+ return binary_mask
112
+ largest = 1 + np.argmax(stats[1:, cv2.CC_STAT_AREA])
113
+ cleaned = np.zeros_like(binary_mask)
114
+ cleaned[labels == largest] = 1
115
+ return cleaned
116
+
117
+ def _clean_cup(self, binary_mask: np.ndarray) -> np.ndarray:
118
+ """
119
+ Morphological cleanup for the cup mask.
120
+ Uses an adaptive kernel: after anatomical clipping the remaining cup
121
+ may be small β€” a 5Γ—5 open would erase it. We drop to 3Γ—3 for
122
+ small remnants.
123
+ """
124
+ binary_mask = binary_mask.astype(np.uint8)
125
+ if binary_mask.sum() == 0:
126
+ return binary_mask
127
+
128
+ k = 3 if binary_mask.sum() < 3_000 else 5
129
+ kernel = np.ones((k, k), np.uint8)
130
+ binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, kernel)
131
+ binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)
132
+
133
+ n, labels, stats, _ = cv2.connectedComponentsWithStats(binary_mask)
134
+ if n <= 1:
135
+ return binary_mask
136
+ largest = 1 + np.argmax(stats[1:, cv2.CC_STAT_AREA])
137
+ cleaned = np.zeros_like(binary_mask)
138
+ cleaned[labels == largest] = 1
139
+ return cleaned
140
+
141
+ # ──────────────────────────────────────────────────────────────────────
142
+ # anatomical enforcement
143
+ # ──────────────────────────────────────────────────────────────────────
144
+
145
+ def _enforce_cup_in_disc(
146
+ self,
147
+ disc_mask: np.ndarray,
148
+ cup_mask: np.ndarray,
149
+ ) -> Tuple[np.ndarray, np.ndarray, int]:
150
+ """
151
+ Hard-clip cup to disc boundary.
152
+ Returns (disc_mask, corrected_cup, n_pixels_removed).
153
+ Pure logical AND β€” no morphology here.
154
+ """
155
+ corrected_cup = np.logical_and(cup_mask == 1, disc_mask == 1).astype(np.uint8)
156
+ violations = int(cup_mask.sum()) - int(corrected_cup.sum())
157
+ return disc_mask, corrected_cup, violations
158
+
159
+ # ──────────────────────────────────────────────────────────────────────
160
+ # MC-Dropout segmentation
161
+ # ──────────────────────────────────────────────────────────────────────
162
+
163
+ def _mc_segment(
164
+ self,
165
+ tensor: torch.Tensor,
166
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
167
+ """
168
+ Run MC-Dropout inference and return:
169
+ disc_mask (raw, thresholded, NOT yet cleaned)
170
+ cup_mask (raw, thresholded, NOT yet cleaned)
171
+ var_probs (3, H, W) β€” full variance map, used for ROI uncertainty later
172
+
173
+ Masks are intentionally returned uncleaned so that the caller can
174
+ enforce topology FIRST and clean AFTER (see run()).
175
+ """
176
+ # BatchNorm stays frozen; only dropout layers go to train mode
177
+ self.model.eval()
178
+ for m in self.model.modules():
179
+ if isinstance(m, (torch.nn.Dropout, torch.nn.Dropout2d, torch.nn.Dropout3d)):
180
+ m.train()
181
+
182
+ all_probs = []
183
+ with torch.no_grad():
184
+ for _ in range(self.mc_passes):
185
+ logits = self.model(tensor)
186
+ probs = F.softmax(logits, dim=1)
187
+ all_probs.append(probs.cpu().numpy())
188
+
189
+ self.model.eval()
190
+
191
+ all_probs = np.stack(all_probs, axis=0) # (T, 1, 3, H, W)
192
+ mean_probs = all_probs.mean(axis=0)[0] # (3, H, W)
193
+ var_probs = all_probs.var(axis=0)[0] # (3, H, W)
194
+
195
+ # ── Separate thresholds β€” critical fix ─────────────────────────
196
+ # Cup probability maps are more diffuse than disc maps.
197
+ # Using 0.25 for both causes the cup to spill far outside the disc;
198
+ # the anatomical clipping then wipes it out entirely.
199
+ # A higher cup threshold trims diffuse edges back inside the disc.
200
+ disc_mask = (mean_probs[1] > 0.35).astype(np.uint8)
201
+ cup_mask = (mean_probs[2] > 0.55).astype(np.uint8)
202
+
203
+ if self.debug:
204
+ cv2.imwrite("disc_prob_debug.png", (mean_probs[1] * 255).astype(np.uint8))
205
+ cv2.imwrite("cup_prob_debug.png", (mean_probs[2] * 255).astype(np.uint8))
206
+ print(f"[debug] raw disc px={disc_mask.sum()} raw cup px={cup_mask.sum()}")
207
+
208
+ return disc_mask, cup_mask, var_probs
209
+
210
+ # ──────────────────────────────────────────────────────────────────────
211
+ # uncertainty over anatomical ROI
212
+ # ──────────────────────────────────────────────────────────────────────
213
+
214
+ @staticmethod
215
+ def _roi_uncertainty(
216
+ var_probs: np.ndarray, # (3, H, W)
217
+ disc_mask: np.ndarray, # (H, W)
218
+ cup_mask: np.ndarray, # (H, W)
219
+ ) -> float:
220
+ """
221
+ Mean MC-Dropout variance restricted to the disc+cup region.
222
+
223
+ Averaging variance over the WHOLE image (including background)
224
+ suppresses clinically important local uncertainty because the model
225
+ is very confident about the large background class.
226
+ Restricting to the anatomical ROI makes the score meaningful.
227
+ """
228
+ roi = (disc_mask == 1) | (cup_mask == 1)
229
+ if not roi.any():
230
+ # No anatomical region found β€” fall back to global (will be high)
231
+ return float(var_probs[1:].mean())
232
+ # Channels 1 (disc) and 2 (cup) variance, sampled at ROI pixels
233
+ return float(var_probs[1:, roi].mean())
234
+
235
+ # ──────────────────────────────────────────────────────────────────────
236
+ # public API
237
+ # ──────────────────────────────────────────────────────────────────────
238
+
239
+ def run(self, image: np.ndarray) -> Dict[str, Any]:
240
+ """
241
+ Full pipeline. Processing order (corrected):
242
+
243
+ 1. threshold (separate disc / cup thresholds)
244
+ 2. enforce topology ← BEFORE any morphology
245
+ 3. clean disc + clean cup (separate kernels)
246
+ 4. enforce topology again ← morphology can reintroduce violations
247
+ 5. compute ROI uncertainty
248
+ 6. minimum-cup sanity gate
249
+ 7. clinical metrics
250
+ """
251
+ tensor = self._preprocess(image)
252
+
253
+ # Step 1+2: raw threshold β†’ first topology enforcement
254
+ # Cup is intentionally NOT cleaned yet so morphology doesn't expand
255
+ # it past the disc before the AND clip.
256
+ disc_mask, cup_mask, var_probs = self._mc_segment(tensor)
257
+
258
+ disc_mask, cup_mask, violations = self._enforce_cup_in_disc(disc_mask, cup_mask)
259
+
260
+ # Step 3: clean both masks (cup with adaptive small kernel)
261
+ disc_mask = self._clean_disc(disc_mask)
262
+ cup_mask = self._clean_cup(cup_mask)
263
+
264
+ # Step 4: second enforcement β€” morphology can re-expand cup slightly
265
+ # outside a cleaned (potentially slightly shrunk) disc
266
+ disc_mask, cup_mask, extra_violations = self._enforce_cup_in_disc(disc_mask, cup_mask)
267
+ violations += extra_violations
268
+
269
+ # Step 5: uncertainty over anatomical ROI (not the whole image)
270
+ uncertainty = self._roi_uncertainty(var_probs, disc_mask, cup_mask)
271
+
272
+ if self.debug:
273
+ print(
274
+ f"[debug] post-clean disc px={disc_mask.sum()} "
275
+ f"cup px={cup_mask.sum()} "
276
+ f"violations={violations} "
277
+ f"roi_uncertainty={uncertainty:.6f}"
278
+ )
279
+
280
+ # ── Step 6: minimum-cup sanity gate ───────────────────────────
281
+ # A tiny cup remnant (e.g. a few dozen pixels surviving morphology)
282
+ # produces a meaningless vCDR. Zero it out and warn instead.
283
+ disc_area = int(disc_mask.sum())
284
+ cup_area = int(cup_mask.sum())
285
+ extra_warnings: list[str] = []
286
+
287
+ if disc_area < _MIN_DISC_AREA_PX:
288
+ extra_warnings.append(
289
+ f"Disc too small ({disc_area} px) β€” image may not be a fundus photo. "
290
+ "Segmentation result is unreliable."
291
+ )
292
+
293
+ if cup_area > 0 and cup_area < _MIN_CUP_AREA_PX:
294
+ extra_warnings.append(
295
+ f"Cup remnant too small ({cup_area} px) β€” suppressed. "
296
+ "Cup segmentation is unreliable for this image."
297
+ )
298
+ cup_mask = np.zeros_like(cup_mask)
299
+ cup_area = 0
300
+
301
+ if cup_area > 0 and disc_area > 0:
302
+ if (cup_area / disc_area) < _MIN_CUP_DISC_RATIO:
303
+ extra_warnings.append(
304
+ f"Cup/disc area ratio ({cup_area}/{disc_area} = "
305
+ f"{cup_area/disc_area:.3f}) is below minimum β€” "
306
+ "cup reading may be unreliable."
307
+ )
308
+
309
+ # Step 7: clinical pipeline
310
+ clinical = run_clinical_pipeline(
311
+ disc_mask=disc_mask,
312
+ cup_mask=cup_mask,
313
+ uncertainty=uncertainty,
314
+ uncertainty_threshold=self.uncertainty_threshold
315
+ )
316
+
317
+ # Inject all accumulated warnings
318
+ if violations > 0:
319
+ clinical.warnings.append(
320
+ f"Mask corrected: {violations} cup pixels clipped to disc boundary."
321
+ )
322
+ clinical.warnings.extend(extra_warnings)
323
+
324
+ report = {
325
+ 'vcdr': clinical.vcdr,
326
+ 'isnt': clinical.isnt.to_dict(),
327
+ 'risk_level': clinical.risk_level.value,
328
+ 'uncertainty': round(uncertainty, 6),
329
+ 'high_uncertainty': clinical.high_uncertainty,
330
+ 'disc_area_px': clinical.disc_area_px,
331
+ 'cup_area_px': clinical.cup_area_px,
332
+ 'disc_center': clinical.disc_center,
333
+ 'cup_center': clinical.cup_center,
334
+ 'sanity_passed': clinical.sanity_passed,
335
+ 'warnings': clinical.warnings,
336
+ }
337
+
338
+ return {
339
+ 'disc_mask': disc_mask,
340
+ 'cup_mask': cup_mask,
341
+ 'uncertainty': uncertainty,
342
+ 'clinical': clinical,
343
+ 'report': report,
344
+ }
345
+
346
+ def run_from_path(self, image_path: str) -> Dict[str, Any]:
347
+
348
+ image = cv2.imread(image_path)
349
+ if image is None:
350
+ raise FileNotFoundError(f"Cannot load image: {image_path}")
351
+ return self.run(image)