# EyeeSEE / visualise.py
# Uploaded by Nj-1111 in commit f9b628d ("Upload 9 files"); 7.24 kB.
"""
Phase 3 β€” Visualisation
"""
import numpy as np
import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.gridspec import GridSpec
from typing import Optional, Dict, Any
# Overlay colours. OpenCV drawing calls take BGR tuples, matplotlib takes RGB.
# The RGB variants are derived by reversing the BGR tuples so that a legend
# built from them can never disagree with what OpenCV actually draws
# (previously CUP_COLOUR_RGB repeated the BGR tuple unreversed).
DISC_COLOUR_BGR = (0, 215, 255)   # RGB (255, 215, 0): gold
CUP_COLOUR_BGR = (0, 255, 128)    # RGB (128, 255, 0): yellow-green
DISC_COLOUR_RGB = DISC_COLOUR_BGR[::-1]
CUP_COLOUR_RGB = CUP_COLOUR_BGR[::-1]

# Hex colour for each risk-level label; unknown labels fall back to grey
# at the call site.
RISK_COLOURS = {
    'Healthy': '#2ECC71',
    'Glaucoma Suspect': '#F39C12',
    'High Risk': '#E74C3C',
}
def draw_segmentation_overlay(
    image: np.ndarray,
    disc_mask: np.ndarray,
    cup_mask: np.ndarray,
    alpha: float = 0.35
) -> np.ndarray:
    """Return a BGR copy of *image* with the disc and cup masks painted on.

    Each mask region is alpha-blended with its colour (disc first, then cup,
    so the cup is drawn on top of the disc). Its external contour is then
    traced with a 2 px line in the same colour.

    Args:
        image: BGR image (H, W, 3), BGRA image (H, W, 4), or greyscale image
            (H, W) / (H, W, 1). The input array is not modified.
        disc_mask: Optic-disc mask with the same height and width as *image*.
            Any value > 0 counts as foreground, so 0/1, bool and 0/255 masks
            all render the same.
        cup_mask: Optic-cup mask, same convention as *disc_mask*.
        alpha: Opacity of the colour fill (0 = invisible, 1 = opaque).

    Returns:
        A new 3-channel BGR image with fills and outlines drawn.
    """
    # Normalise the input to a fresh 3-channel BGR canvas. Assigning a
    # 3-tuple colour below would fail on 1- or 4-channel arrays.
    if image.ndim == 3 and image.shape[2] == 1:
        image = image[:, :, 0]
    if image.ndim == 2:
        overlay = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    elif image.shape[2] == 4:
        overlay = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
    else:
        overlay = image.copy()

    for mask, colour in ((disc_mask, DISC_COLOUR_BGR), (cup_mask, CUP_COLOUR_BGR)):
        # One binary mask drives both the fill and the contour. Previously the
        # fill used `mask == 1` while findContours used any non-zero pixel, so
        # 0/255 masks were outlined but never filled.
        binary = (np.asarray(mask) > 0).astype(np.uint8)
        fill = overlay.copy()
        fill[binary.astype(bool)] = colour
        # Blend in place: pixels outside the mask are identical in both
        # inputs and therefore unchanged.
        cv2.addWeighted(fill, alpha, overlay, 1 - alpha, 0, overlay)
        contours, _ = cv2.findContours(
            binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        cv2.drawContours(overlay, contours, -1, colour, 2)
    return overlay
# Shared styling for the report figure.
_PANEL_BG = '#1A1A2E'
_TITLE_KW = dict(color='white', fontsize=11, fontweight='bold', pad=8)
# The risk panel has room for this many warning lines; any extra warnings
# are summarised in a single "+N more" line instead of being dropped.
_MAX_WARNINGS = 4


def _draw_image_panel(ax: plt.Axes, img: np.ndarray, title: str,
                      cmap: Optional[str] = None) -> None:
    """Show *img* without axes under a report-styled *title*."""
    ax.imshow(img, cmap=cmap)
    ax.set_title(title, **_TITLE_KW)
    ax.axis('off')


def _draw_overlay_panel(ax: plt.Axes, overlay_rgb: np.ndarray) -> None:
    """Show the segmentation overlay with a Disc/Cup colour legend."""
    _draw_image_panel(ax, overlay_rgb, 'Segmentation Overlay')
    # The overlay is painted with the *_BGR constants, so build the legend
    # from those (reversed to RGB). This guarantees the legend matches the
    # colours actually drawn.
    ax.legend(
        handles=[
            mpatches.Patch(color=np.array(DISC_COLOUR_BGR[::-1]) / 255, label='Disc'),
            mpatches.Patch(color=np.array(CUP_COLOUR_BGR[::-1]) / 255, label='Cup'),
        ],
        loc='lower right', fontsize=8,
        facecolor=_PANEL_BG, labelcolor='white',
    )


def _draw_isnt_panel(ax: plt.Axes, isnt: Dict[str, Any]) -> None:
    """Bar chart of I/S/N/T rim thickness; green if the ISNT rule holds, else red."""
    sectors = ['Inferior', 'Superior', 'Nasal', 'Temporal']
    values = [isnt[name.lower()] for name in sectors]
    satisfied = isnt['rule_satisfied']
    colour = '#2ECC71' if satisfied else '#E74C3C'
    ax.bar(sectors, values, color=colour, edgecolor='white', linewidth=0.5)
    ax.set_facecolor(_PANEL_BG)
    ax.tick_params(colors='#CCCCCC', labelsize=8)
    ax.spines[:].set_color('#444')
    rule_txt = '✓ ISNT Satisfied' if satisfied else '✗ ISNT Violated'
    ax.set_title(f'ISNT Rim Thickness ({rule_txt})', color=colour,
                 fontsize=10, fontweight='bold', pad=8)


def _draw_vcdr_gauge(ax: plt.Axes, vcdr: float) -> None:
    """Draw a semicircular dial for the vertical cup-to-disc ratio."""
    ax.set_facecolor(_PANEL_BG)
    # Grey background arc, from angle pi (vCDR 0, left) to angle 0 (vCDR 1, right).
    theta = np.linspace(np.pi, 0, 300)
    ax.plot(np.cos(theta), np.sin(theta), lw=12, color='#444', solid_capstyle='round')
    # NOTE(review): band cut-offs (0.65 / 0.80) are hard-coded here; confirm
    # they match the thresholds used to assign report['risk_level'].
    for lo, hi, col in ((0.0, 0.65, '#2ECC71'), (0.65, 0.80, '#F39C12'), (0.80, 1.0, '#E74C3C')):
        seg = np.linspace(np.pi * (1 - lo), np.pi * (1 - hi), 100)
        ax.plot(np.cos(seg), np.sin(seg), lw=12, color=col, solid_capstyle='butt')
    # Clamp the needle onto the dial so out-of-range values point at an end
    # rather than below the gauge. Omit it for NaN/inf, which have no
    # meaningful direction; the numeric label below still shows the raw value.
    if np.isfinite(vcdr):
        angle = np.pi * (1 - float(np.clip(vcdr, 0.0, 1.0)))
        ax.annotate('', xy=(0.7 * np.cos(angle), 0.7 * np.sin(angle)), xytext=(0, 0),
                    arrowprops=dict(arrowstyle='->', color='white', lw=2.5))
    ax.text(0, -0.20, f'vCDR = {vcdr:.2f}', ha='center', fontsize=14,
            fontweight='bold', color='white')
    ax.text(-0.9, -0.25, 'Healthy', fontsize=7, color='#2ECC71', ha='center')
    ax.text(0, -0.40, 'Suspect', fontsize=7, color='#F39C12', ha='center')
    ax.text(0.9, -0.25, 'High', fontsize=7, color='#E74C3C', ha='center')
    ax.set_xlim(-1.2, 1.2)
    ax.set_ylim(-0.6, 1.2)
    ax.set_title('Vertical CDR', **_TITLE_KW)
    ax.axis('off')


def _draw_risk_panel(ax: plt.Axes, report: Dict[str, Any]) -> None:
    """Risk-level badge, uncertainty value, and clinical warnings."""
    ax.set_facecolor(_PANEL_BG)
    ax.axis('off')
    risk_str = report['risk_level']
    risk_col = RISK_COLOURS.get(risk_str, '#888')
    ax.add_patch(plt.Rectangle((0.05, 0.55), 0.90, 0.35,
                               facecolor=risk_col, alpha=0.25,
                               transform=ax.transAxes, clip_on=False))
    ax.text(0.50, 0.73, risk_str, ha='center', va='center',
            fontsize=14, fontweight='bold', color=risk_col,
            transform=ax.transAxes)
    unc_col = '#E74C3C' if report['high_uncertainty'] else '#2ECC71'
    ax.text(0.50, 0.46, f"Uncertainty: {report['uncertainty']:.4f}",
            ha='center', fontsize=10, color=unc_col, transform=ax.transAxes)
    warns = report['warnings']
    for i, w in enumerate(warns[:_MAX_WARNINGS]):
        ax.text(0.05, 0.36 - i * 0.09, f'• {w}',
                fontsize=7.5, color='#FFD700', transform=ax.transAxes)
    hidden = len(warns) - _MAX_WARNINGS
    if hidden > 0:
        # Do not drop clinical warnings silently; say how many were cut.
        ax.text(0.05, 0.36 - _MAX_WARNINGS * 0.09, f'… +{hidden} more warning(s)',
                fontsize=7.5, color='#888', style='italic', transform=ax.transAxes)
    if not warns:
        ax.text(0.50, 0.25, 'No clinical warnings', ha='center',
                fontsize=9, color='#888', transform=ax.transAxes)
    ax.set_title('Screening Result', **_TITLE_KW)


def _draw_stats_panel(ax: plt.Axes, report: Dict[str, Any], mc_passes: int) -> None:
    """Two-column label/value table of structural measurements."""
    ax.set_facecolor(_PANEL_BG)
    ax.axis('off')
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    disc_area = report['disc_area_px']
    cup_area = report['cup_area_px']
    rows = [
        ('Disc Area', f"{disc_area:,} px"),
        ('Cup Area', f"{cup_area:,} px"),
        # max(..., 1) avoids division by zero for an empty disc mask.
        ('Cup/Disc %', f"{cup_area / max(disc_area, 1) * 100:.1f} %"),
        ('Disc Centre', str(report['disc_center'])),
        ('Cup Centre', str(report['cup_center'])),
        ('Sanity', '✓ Passed' if report['sanity_passed'] else '✗ Failed'),
        ('MC Passes', str(mc_passes)),
    ]
    for i, (label, value) in enumerate(rows):
        y = 0.88 - i * 0.12
        ax.text(0.05, y, label, fontsize=9, color='#AAAAAA', transform=ax.transAxes)
        ax.text(0.55, y, value, fontsize=9, color='white', transform=ax.transAxes,
                fontweight='bold')
        # Use ax.plot instead of axhline so transform=transAxes is valid.
        ax.plot([0.02, 0.98], [y - 0.04, y - 0.04],
                color='#333', linewidth=0.5, transform=ax.transAxes)
    ax.set_title('Structural Statistics', **_TITLE_KW)


def create_clinical_report_figure(
    image: np.ndarray,
    disc_mask: np.ndarray,
    cup_mask: np.ndarray,
    result: Dict[str, Any],
    save_path: Optional[str] = None,
    mc_passes: int = 20,
) -> plt.Figure:
    """Build the six-panel clinical report figure for one screening result.

    Panels (2 x 3 grid):
        1. original image
        2. segmentation overlay
        3. ISNT rim-thickness bars
        4. vertical-CDR gauge
        5. risk badge with uncertainty and warnings
        6. structural statistics table
    A title and a research-prototype disclaimer are added to the figure.

    Args:
        image: 2-D greyscale image, or a colour image in BGR channel order
            (it is converted with COLOR_BGR2RGB for display).
        disc_mask: Optic-disc mask, passed to draw_segmentation_overlay.
        cup_mask: Optic-cup mask, passed to draw_segmentation_overlay.
        result: Must contain a 'report' dict with the keys
            'isnt' ('inferior', 'superior', 'nasal', 'temporal',
            'rule_satisfied'), 'vcdr', 'risk_level', 'high_uncertainty',
            'uncertainty', 'warnings', 'disc_area_px', 'cup_area_px',
            'disc_center', 'cup_center' and 'sanity_passed'.
        save_path: If truthy, the figure is also saved there at 150 dpi.
        mc_passes: Value shown in the 'MC Passes' row. Display only; this
            function runs no passes. Defaults to the previously hard-coded 20.

    Returns:
        The matplotlib Figure. It is not closed; the caller owns it.

    Raises:
        KeyError: If *result* or its report lacks one of the keys above.
    """
    report = result['report']
    # Prepare both images before allocating the figure, so a bad image
    # raises without leaving an orphan figure behind.
    if image.ndim == 2:
        display_img, cmap_orig = image, 'gray'
    else:
        display_img, cmap_orig = cv2.cvtColor(image, cv2.COLOR_BGR2RGB), None
    overlay_rgb = cv2.cvtColor(
        draw_segmentation_overlay(image, disc_mask, cup_mask),
        cv2.COLOR_BGR2RGB
    )

    fig = plt.figure(figsize=(16, 10), facecolor=_PANEL_BG)
    gs = GridSpec(2, 3, figure=fig, hspace=0.40, wspace=0.30)

    _draw_image_panel(fig.add_subplot(gs[0, 0]), display_img, 'Original Image', cmap_orig)
    _draw_overlay_panel(fig.add_subplot(gs[0, 1]), overlay_rgb)
    _draw_isnt_panel(fig.add_subplot(gs[0, 2]), report['isnt'])
    _draw_vcdr_gauge(fig.add_subplot(gs[1, 0]), report['vcdr'])
    _draw_risk_panel(fig.add_subplot(gs[1, 1]), report)
    _draw_stats_panel(fig.add_subplot(gs[1, 2]), report, mc_passes)

    fig.suptitle('Glaucoma CDSS — Phase 3 Clinical Report',
                 color='white', fontsize=14, fontweight='bold', y=0.98)
    fig.text(0.5, 0.01,
             'RESEARCH PROTOTYPE — NOT A MEDICAL DEVICE. '
             'Results must be validated by a qualified ophthalmologist.',
             ha='center', fontsize=7.5, color='#666666', style='italic')
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight',
                    facecolor=fig.get_facecolor())
    return fig