File size: 3,404 Bytes
a78ad5d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 | import torch
import torch.nn as nn
import torch.nn.functional as F
class DiceFocalLoss(nn.Module):
    """Combined Dice + Focal loss for class-imbalanced binary segmentation.

    Expects raw logits ``y_pred`` and targets ``y_true`` of identical shape
    ``(N, C, *spatial)`` with at least one spatial dimension. The focal term
    down-weights well-classified pixels so training concentrates on hard /
    uncertain ones, which helps when training plateaus. The Dice term
    optimizes region overlap per sample and channel.

    Args:
        alpha: Focal class-balance weight. Positive pixels are weighted by
            ``alpha`` and negative pixels by ``1 - alpha`` (0.25 typical).
            A negative value disables class weighting entirely.
        gamma: Focal modulating exponent. Higher = more focus on hard
            pixels. Tune in [0.5, 5.0] via Optuna; values below 1 are
            numerically safe here (see ``forward``).
        dice_weight: Weight of the Dice component.
        focal_weight: Weight of the Focal component.
        smooth: Laplace smoothing added to the Dice numerator and
            denominator (keeps empty masks finite).
    """

    def __init__(
        self,
        alpha: float = 0.25,
        gamma: float = 2.0,
        dice_weight: float = 0.5,
        focal_weight: float = 0.5,
        smooth: float = 1e-6,
    ) -> None:
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.dice_weight = dice_weight
        self.focal_weight = focal_weight
        self.smooth = smooth

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """Return the weighted Dice + Focal loss as a scalar tensor.

        Args:
            y_pred: Logits of shape ``(N, C, *spatial)``.
            y_true: Binary targets with the same shape (any numeric/bool dtype).

        Raises:
            ValueError: If ``y_pred`` has fewer than 3 dimensions.
        """
        if y_pred.dim() < 3:
            # An empty reduction tuple would make ``sum`` reduce over *all*
            # dims, silently producing a wrong Dice, so reject it up front.
            raise ValueError(
                f"expected y_pred of shape (N, C, *spatial), got {tuple(y_pred.shape)}"
            )
        target = y_true.float()

        # ---- Focal part ------------------------------------------------
        bce = F.binary_cross_entropy_with_logits(y_pred, target, reduction="none")
        # 1 - p_t where p_t = exp(-bce). ``-expm1(-bce)`` stays accurate for
        # tiny bce, where ``1 - exp(-bce)`` would round to exactly 0. The clamp
        # keeps ``x ** gamma`` off 0, whose gradient is inf for gamma < 1 and
        # would turn into NaN once multiplied by a zero upstream gradient.
        one_minus_pt = (-torch.expm1(-bce)).clamp_min(torch.finfo(bce.dtype).tiny)
        focal = one_minus_pt ** self.gamma * bce
        if self.alpha >= 0:
            alpha_t = self.alpha * target + (1.0 - self.alpha) * (1.0 - target)
            focal = alpha_t * focal
        focal_loss = focal.mean()

        # ---- Dice part -------------------------------------------------
        # Soft Dice per (sample, channel) over every spatial dim, then averaged.
        spatial_dims = tuple(range(2, y_pred.dim()))
        prob = torch.sigmoid(y_pred)
        inter = (prob * target).sum(dim=spatial_dims)
        denom = prob.sum(dim=spatial_dims) + target.sum(dim=spatial_dims)
        dice_loss = (1.0 - (2.0 * inter + self.smooth) / (denom + self.smooth)).mean()

        return self.dice_weight * dice_loss + self.focal_weight * focal_loss
class DiceLoss(nn.Module):
    """Soft Dice loss on already-activated predictions (e.g. sigmoid outputs).

    Only channel 0 of the inputs is used. It is flattened across the whole
    batch, so the Dice coefficient is pooled over all samples rather than
    averaged per sample.

    Args:
        smooth: Laplace smoothing added to the numerator and denominator.
    """

    def __init__(self, smooth: float = 1.0) -> None:
        super().__init__()
        self.smooth = smooth

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """Return ``1 - Dice`` for channel 0 as a scalar tensor.

        Args:
            y_pred: Predictions of shape ``(N, C, ...)``, presumably in [0, 1].
            y_true: Targets with the same shape as ``y_pred``.

        Raises:
            ValueError: If the two tensors differ in shape.
        """
        # An explicit check rather than ``assert``, which ``python -O`` strips.
        if y_pred.size() != y_true.size():
            raise ValueError(
                "y_pred and y_true must have the same shape, got "
                f"{tuple(y_pred.size())} and {tuple(y_true.size())}"
            )
        pred = y_pred[:, 0].reshape(-1)
        true = y_true[:, 0].reshape(-1)
        intersection = (pred * true).sum()
        dsc = (2.0 * intersection + self.smooth) / (
            pred.sum() + true.sum() + self.smooth
        )
        return 1.0 - dsc
class BCEDiceLoss(nn.Module):
    """Weighted sum of BCE-with-logits and soft Dice loss for binary segmentation.

    ``y_pred`` holds raw logits. BCE is applied to them directly, and Dice to
    their sigmoid via ``DiceLoss``. Optional label smoothing affects only the
    BCE targets; Dice always sees the hard labels.

    Args:
        smooth: Laplace smoothing forwarded to ``DiceLoss``.
        bce_weight: Weight of the BCE component.
        dice_weight: Weight of the Dice component.
        label_smoothing: Amount in [0, 1] by which BCE targets are pulled
            towards 0.5 (0 disables smoothing).

    Raises:
        ValueError: If ``label_smoothing`` is outside [0, 1].
    """

    def __init__(
        self,
        smooth: float = 1.0,
        bce_weight: float = 0.5,
        dice_weight: float = 0.5,
        label_smoothing: float = 0.0,
    ) -> None:
        super().__init__()
        if not 0.0 <= label_smoothing <= 1.0:
            raise ValueError(f"label_smoothing must be in [0, 1], got {label_smoothing}")
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.label_smoothing = label_smoothing
        self.dice = DiceLoss(smooth=smooth)

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """Return the weighted BCE + Dice loss as a scalar tensor."""
        # BCE-with-logits rejects integer/bool targets, so cast once up front.
        # Otherwise integer masks would only work when label smoothing
        # happened to promote them to float.
        target = y_true.float()
        if self.label_smoothing > 0.0:
            # Smooth labels towards 0.5: prevents overconfident BCE.
            bce_target = target * (1.0 - self.label_smoothing) + self.label_smoothing * 0.5
        else:
            bce_target = target
        bce = F.binary_cross_entropy_with_logits(y_pred, bce_target)
        dice = self.dice(torch.sigmoid(y_pred), target)
        return self.bce_weight * bce + self.dice_weight * dice