File size: 3,404 Bytes
a78ad5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
import torch.nn as nn
import torch.nn.functional as F


class DiceFocalLoss(nn.Module):
    """Combined Dice + Focal loss for class-imbalanced binary segmentation.

    Focal loss down-weights easy (confidently correct) pixels, forcing the
    network to focus on hard/uncertain pixels — the main failure mode when
    plateauing near 0.1.

    Inputs are ``(N, C, *spatial)`` tensors with at least one spatial
    dimension, e.g. ``(N, C, H, W)`` or ``(N, C, D, H, W)``.  ``y_pred`` holds
    raw logits (a sigmoid is applied internally).  ``y_true`` holds targets in
    ``[0, 1]`` with the same shape; bool/integer masks are cast to float.

    Args:
        alpha:        Focal weighting factor for the positive class (0.25
                      typical).  Negative pixels are weighted by
                      ``1 - alpha``, as in the RetinaNet formulation.
        gamma:        Focal modulating exponent.  Higher = more focus on hard
                      pixels.  Tune in [0.5, 5.0] via Optuna.
        dice_weight:  Weight of the Dice component.
        focal_weight: Weight of the Focal component.
        smooth:       Laplace smoothing added to both the Dice numerator and
                      denominator.
    """

    def __init__(
        self,
        alpha: float = 0.25,
        gamma: float = 2.0,
        dice_weight: float = 0.5,
        focal_weight: float = 0.5,
        smooth: float = 1e-6,
    ) -> None:
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.dice_weight = dice_weight
        self.focal_weight = focal_weight
        self.smooth = smooth

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """Return ``dice_weight * Dice + focal_weight * Focal`` as a scalar.

        Args:
            y_pred: Logits of shape ``(N, C, *spatial)``.
            y_true: Targets with the same shape as ``y_pred``.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: If ``y_pred`` has no spatial dimension (``dim() < 3``).
        """
        if y_pred.dim() < 3:
            raise ValueError(
                "expected y_pred of shape (N, C, *spatial), "
                f"got {tuple(y_pred.shape)}"
            )
        target = y_true.float()

        # ---- Focal part ------------------------------------------------
        bce = F.binary_cross_entropy_with_logits(y_pred, target, reduction="none")
        # For hard labels exp(-BCE) is the probability given to the true class.
        pt = torch.exp(-bce)
        # Class-balancing weight: alpha on positives, (1 - alpha) on negatives.
        alpha_t = self.alpha * target + (1.0 - self.alpha) * (1.0 - target)
        # Clamp the base away from exactly 0: for gamma < 1 the derivative of
        # x ** gamma is infinite at 0, which turns perfectly-predicted pixels
        # into NaN gradients (inf * 0).  `tiny` only affects exact zeros.
        eps = torch.finfo(bce.dtype).tiny
        modulator = (1.0 - pt).clamp(min=eps) ** self.gamma
        focal_loss = (alpha_t * modulator * bce).mean()

        # ---- Dice part -------------------------------------------------
        # Per-sample, per-channel Dice over every spatial dimension.
        spatial_dims = tuple(range(2, y_pred.dim()))
        pred_sig = torch.sigmoid(y_pred)
        inter = (pred_sig * target).sum(dim=spatial_dims)
        denom = pred_sig.sum(dim=spatial_dims) + target.sum(dim=spatial_dims)
        dice_loss = (1.0 - (2.0 * inter + self.smooth) / (denom + self.smooth)).mean()

        return self.dice_weight * dice_loss + self.focal_weight * focal_loss


class DiceLoss(nn.Module):
    """Soft Dice loss on channel 0, pooled over the whole batch.

    No activation is applied here, so ``y_pred`` should already contain
    probabilities (e.g. the output of a sigmoid).  All channels other than
    channel 0 are ignored, and the Dice score is computed over every element
    of that channel across the batch (not per sample).

    Args:
        smooth: Additive smoothing term for numerator and denominator.
    """

    def __init__(self, smooth: float = 1.0) -> None:
        super().__init__()
        self.smooth = smooth

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """Return ``1 - Dice`` between channel 0 of ``y_pred`` and ``y_true``.

        Raises ``AssertionError`` on a shape mismatch (the check is an
        ``assert`` and is therefore skipped under ``python -O``).
        """
        assert y_pred.shape == y_true.shape
        probs = y_pred[:, 0].reshape(-1)
        target = y_true[:, 0].reshape(-1)
        overlap = torch.sum(probs * target)
        total = torch.sum(probs) + torch.sum(target)
        score = (2.0 * overlap + self.smooth) / (total + self.smooth)
        return 1.0 - score


class BCEDiceLoss(nn.Module):
    """Weighted sum of binary cross-entropy (on logits) and Dice loss.

    Args:
        smooth:          Smoothing term forwarded to ``DiceLoss``.
        bce_weight:      Weight of the BCE component.
        dice_weight:     Weight of the Dice component.
        label_smoothing: Amount by which the BCE targets are pulled towards
                         0.5; ``0.0`` disables it.  The Dice term always uses
                         the unsmoothed targets.
    """

    def __init__(
        self,
        smooth: float = 1.0,
        bce_weight: float = 0.5,
        dice_weight: float = 0.5,
        label_smoothing: float = 0.0,
    ) -> None:
        super().__init__()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.label_smoothing = label_smoothing
        self.dice = DiceLoss(smooth=smooth)

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """Return ``bce_weight * BCE + dice_weight * Dice`` as a scalar.

        Args:
            y_pred: Raw logits; a sigmoid is applied before the Dice term.
            y_true: Targets with the same shape as ``y_pred``.  Bool or
                    integer masks are accepted and cast to ``y_pred``'s dtype.

        Returns:
            Scalar loss tensor.
        """
        # binary_cross_entropy_with_logits rejects bool/integer targets, so
        # normalise the dtype once for both the BCE and the Dice term.
        target = y_true.to(dtype=y_pred.dtype)
        if self.label_smoothing > 0.0:
            # Smooth labels towards 0.5: prevents overconfident BCE
            y_bce = target * (1.0 - self.label_smoothing) + self.label_smoothing * 0.5
        else:
            y_bce = target
        bce = F.binary_cross_entropy_with_logits(y_pred, y_bce)
        dice = self.dice(torch.sigmoid(y_pred), target)
        return self.bce_weight * bce + self.dice_weight * dice