File size: 4,350 Bytes
f9d5179 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 | """
Drop this file next to phase3pipeline.py and run it once on any fundus image.
It prints raw per-channel probability statistics (plus pixel counts at a range
of thresholds) and saves visual debug images to the current working directory,
so you can see exactly what the model is predicting before any thresholding.

Usage:
    python diagnose_masks.py path/to/fundus.jpg
"""
import sys
import os
import numpy as np
import cv2
import torch
import torch.nn.functional as F
# ── adjust these imports to match your project layout ──────────────────
from model import UNet
from checkpoint_loader import load_model_for_inference
# ───────────────────────────────────────────────────────────────────────
# Repository id handed to load_model_for_inference (presumably a Hugging Face
# Hub repo, given the HF_TOKEN* variables below - confirm in checkpoint_loader).
REPO_ID = "Nj-1111/EyeeSEE"

# Access token: HF_TOKEN_2 wins when it is set to a non-empty value, otherwise
# HF_TOKEN is used; ends up None when neither variable is set.
TOKEN = os.environ.get("HF_TOKEN_2") or os.environ.get("HF_TOKEN")

# Prefer the GPU whenever torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
def load_model():
    """Build the UNet, load its checkpoint and return it in eval mode.

    The network is configured for a single input channel and three output
    classes (base_filters=64, dropout=0.1). Weights come from REPO_ID via
    load_model_for_inference, on DEVICE, authenticated with TOKEN.

    NOTE(review): epoch=None presumably selects the latest/default checkpoint,
    and load_model_for_inference is assumed to load the weights into the
    model in place (its return value is ignored) - confirm in
    checkpoint_loader.

    Returns:
        The UNet instance, switched to eval mode.
    """
    net = UNet(in_channels=1, n_classes=3, base_filters=64, dropout=0.1)
    load_model_for_inference(
        model=net,
        repo_id=REPO_ID,
        epoch=None,
        device=DEVICE,
        token=TOKEN,
    )
    # Inference mode: layers such as dropout stop behaving stochastically.
    net.eval()
    return net
def preprocess(path):
    """Load an image file as a (1, 1, 512, 512) float32 tensor on DEVICE.

    The file is read with OpenCV, converted from BGR to a single grayscale
    channel, resized to 512x512 with area interpolation and divided by 255.

    Raises:
        FileNotFoundError: if cv2.imread returns None for *path* (the file is
            missing or could not be decoded).
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise FileNotFoundError(path)
    resized = cv2.resize(
        cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY),
        (512, 512),
        interpolation=cv2.INTER_AREA,
    )
    scaled = resized.astype(np.float32) / 255.0
    # Add batch and channel axes: (H, W) -> (1, 1, H, W).
    return torch.from_numpy(scaled)[None, None].to(DEVICE)
def _print_channel_stats(channels):
    """Print min/max/mean/median for each (name, probability map) pair."""
    # Pad names to the longest one so the statistic columns line up.
    width = max(len(name) for name, _ in channels)
    for name, prob in channels:
        print(
            f" {name:<{width}} "
            f"min={prob.min():.4f} "
            f"max={prob.max():.4f} "
            f"mean={prob.mean():.4f} "
            f"median={np.median(prob):.4f}"
        )


def _print_threshold_counts(title, prob, thresholds):
    """Print how many pixels of *prob* are strictly above each threshold."""
    print(title)
    for t in thresholds:
        count = int((prob > t).sum())
        print(f" > {t:.2f} -> {count:6d} px")


def _print_cup_in_disc(disc, cup, disc_threshold, cup_thresholds):
    """Print, per cup threshold, how many cup pixels fall inside the disc mask.

    The disc mask uses the single fixed *disc_threshold*; only the cup
    threshold varies. A high percentage means the cup prediction sits inside
    the disc prediction, i.e. the two channels are spatially aligned.
    """
    print(
        f" Spatial alignment (cup pixels inside disc) "
        f"at disc>{disc_threshold:.2f}:"
    )
    disc_mask = disc > disc_threshold
    for t in cup_thresholds:
        cup_mask = cup > t
        total = int(cup_mask.sum())
        inside = int((cup_mask & disc_mask).sum())
        pct = (inside / total * 100) if total > 0 else 0.0
        print(
            f" cup>{t:.2f} total={total:5d} "
            f"inside_disc={inside:5d} ({pct:.1f}%)"
        )


def _prob_to_u8(prob):
    """Map a probability map in [0, 1] to uint8 [0, 255], rounding to nearest."""
    return np.clip(np.rint(prob * 255.0), 0, 255).astype(np.uint8)


def _save_debug_images(bg, disc, cup):
    """Write grey probability maps and a colour composite to the working dir.

    cv2.imwrite signals failure by returning False instead of raising, so
    each result is checked.

    Returns:
        (saved, failed): lists of the filenames that were / were not written.
    """
    bg_u8 = _prob_to_u8(bg)
    disc_u8 = _prob_to_u8(disc)
    cup_u8 = _prob_to_u8(cup)
    # Disc (HOT colormap) on the left, cup (COOL colormap) on the right.
    composite = np.hstack([
        cv2.applyColorMap(disc_u8, cv2.COLORMAP_HOT),
        cv2.applyColorMap(cup_u8, cv2.COLORMAP_COOL),
    ])
    outputs = [
        ("debug_bg_prob.png", bg_u8),
        ("debug_disc_prob.png", disc_u8),
        ("debug_cup_prob.png", cup_u8),
        ("debug_composite.png", composite),
    ]
    saved, failed = [], []
    for filename, image in outputs:
        if cv2.imwrite(filename, image):
            saved.append(filename)
        else:
            failed.append(filename)
    return saved, failed


def run(image_path):
    """Print raw softmax diagnostics for one fundus image and save debug maps.

    Runs one forward pass (no MC dropout, no thresholding applied to the
    saved maps) and reports, for the background / optic-disc / optic-cup
    channels: min / max / mean / median probability, pixel counts above a
    range of thresholds, and how many thresholded cup pixels fall inside the
    disc mask (disc > 0.35).

    Files written to the current working directory: debug_bg_prob.png,
    debug_disc_prob.png, debug_cup_prob.png and debug_composite.png.

    Args:
        image_path: path of the input image, as accepted by preprocess().
    """
    model = load_model()
    tensor = preprocess(image_path)

    # Single deterministic pass (no MC dropout) - we want the raw probs.
    with torch.no_grad():
        logits = model(tensor)
    probs = F.softmax(logits, dim=1).cpu().numpy()[0]  # (3, H, W)
    bg, disc, cup = probs[0], probs[1], probs[2]

    print("=" * 56)
    print(" RAW PROBABILITY DIAGNOSTICS")
    print("=" * 56)
    _print_channel_stats([("background", bg), ("disc", disc), ("cup", cup)])
    print()

    # How many pixels survive each threshold?
    _print_threshold_counts(
        " Disc pixels above threshold:",
        disc,
        (0.10, 0.20, 0.25, 0.30, 0.35, 0.40, 0.50),
    )
    print()
    _print_threshold_counts(
        " Cup pixels above threshold:",
        cup,
        (0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.50, 0.55, 0.60),
    )
    print()

    # Disc threshold fixed at 0.35; sweep the cup threshold.
    _print_cup_in_disc(
        disc,
        cup,
        0.35,
        (0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.50, 0.55),
    )
    print("=" * 56)

    saved, failed = _save_debug_images(bg, disc, cup)
    print()
    if saved:
        print(" Saved: " + " ".join(saved))
        print(" Open these images to see where the model is predicting disc and cup.")
    if failed:
        # Report rather than claim success: cv2.imwrite returned False.
        print(" WARNING: could not write: " + " ".join(failed), file=sys.stderr)
if __name__ == "__main__":
    # Expect the image path as the first command-line argument.
    if len(sys.argv) < 2:
        # sys.exit(str) prints the message to stderr and exits with status 1,
        # keeping usage errors off stdout.
        sys.exit("Usage: python diagnose_masks.py <path_to_fundus_image>")
    run(sys.argv[1])
|