from typing import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """Two stacked ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU`` stages.

    ``padding=1`` keeps the spatial size unchanged; only the channel count
    changes from ``in_ch`` to ``out_ch``. The convolutions have no bias
    because each one is immediately followed by BatchNorm, whose learned
    shift makes a conv bias redundant.
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


class Down(nn.Module):
    """Encoder stage: 2x2 max-pool (halves H and W, flooring) then DoubleConv."""

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.pool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_ch, out_ch),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.pool_conv(x)


class Up(nn.Module):
    """Decoder stage: upsample ``x1``, align it to skip ``x2``, concat, DoubleConv.

    ``in_ch`` is the channel count of the deeper map ``x1``. The transposed
    conv doubles H and W and halves the channels to ``in_ch // 2``. The skip
    tensor ``x2`` must therefore also carry ``in_ch // 2`` channels, so that
    the concatenation has ``in_ch`` channels as ``DoubleConv(in_ch, out_ch)``
    expects.
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, in_ch // 2, 2, stride=2)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        x1 = self.up(x1)
        # MaxPool2d floors odd sizes, so the upsampled map can be smaller
        # than the skip tensor. Pad it, splitting the difference between the
        # two sides, so the two can be concatenated. F.pad takes
        # [left, right, top, bottom] for the last two dimensions.
        dy = x2.size(2) - x1.size(2)
        dx = x2.size(3) - x1.size(3)
        x1 = F.pad(x1, [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2])
        # Skip features first, upsampled features second.
        return self.conv(torch.cat([x2, x1], dim=1))


class UNet(nn.Module):
    """U-Net with four down/up stages and dropout on the bottleneck.

    Args:
        in_channels: channels of the input image.
        n_classes: channels of the output map (one per class).
        base_filters: width of the first stage; each deeper stage doubles it,
            up to ``16 * base_filters`` at the bottleneck.
        dropout: ``Dropout2d`` probability applied to the bottleneck features.

    ``forward`` maps ``(N, in_channels, H, W)`` to raw logits of shape
    ``(N, n_classes, H, W)``; no softmax is applied here.
    """

    def __init__(self, in_channels: int = 1, n_classes: int = 3,
                 base_filters: int = 64, dropout: float = 0.1):
        super().__init__()
        f = base_filters
        self.inc = DoubleConv(in_channels, f)
        self.down1 = Down(f, f * 2)
        self.down2 = Down(f * 2, f * 4)
        self.down3 = Down(f * 4, f * 8)
        self.down4 = Down(f * 8, f * 16)
        self.drop = nn.Dropout2d(dropout)
        self.up1 = Up(f * 16, f * 8)
        self.up2 = Up(f * 8, f * 4)
        self.up3 = Up(f * 4, f * 2)
        self.up4 = Up(f * 2, f)
        self.outc = nn.Conv2d(f, n_classes, 1)
        self._init_weights()

    def _init_weights(self) -> None:
        """Kaiming-normal init for Conv2d weights; BatchNorm set to identity.

        NOTE(review): ``nn.ConvTranspose2d`` is not a subclass of
        ``nn.Conv2d``, so the ``Up.up`` layers keep PyTorch's default init.
        The ``outc`` bias is also left at its default. Confirm this is
        intended.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Encoder: keep every stage's output for the skip connections.
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.drop(self.down4(x4))
        # Decoder: each Up consumes the deeper map plus the matching skip.
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        return self.outc(x)


class DiceLoss(nn.Module):
    """Soft multi-class Dice loss computed on softmax probabilities.

    ``pred`` holds logits of shape ``(N, C, H, W)``. ``target`` holds integer
    class indices of shape ``(N, H, W)``; ``F.one_hot`` requires an int64
    tensor. Dice is computed per sample and per class, then averaged over
    both. The background class is included in that average. ``smooth`` is
    added to numerator and denominator, so a class absent from both
    prediction and target scores 1.
    """

    def __init__(self, smooth: float = 1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        pred = F.softmax(pred, dim=1)
        # one_hot yields (N, H, W, C); move the class axis next to batch.
        oh = F.one_hot(target, pred.size(1)).permute(0, 3, 1, 2).float()
        inter = (pred * oh).sum(dim=(2, 3))
        union = pred.sum(dim=(2, 3)) + oh.sum(dim=(2, 3))
        dice = (2.0 * inter + self.smooth) / (union + self.smooth)
        return 1.0 - dice.mean()


class CombinedLoss(nn.Module):
    """Weighted sum of soft Dice loss and class-weighted cross-entropy.

        loss = dice_weight * Dice + (1 - dice_weight) * CE(weight=class_weights)

    The defaults reproduce the original configuration: 0.5 * Dice +
    0.5 * CE, with CE weights background=0.2, disc=1.5, cup=6.0.

    Args:
        class_weights: per-class cross-entropy weights, in class-index order.
        dice_weight: fraction of the total loss taken by the Dice term.
    """

    def __init__(self, *, class_weights: Sequence[float] = (0.2, 1.5, 6.0),
                 dice_weight: float = 0.5):
        super().__init__()
        self.dice = DiceLoss()
        # A buffer lets Module.to()/.cuda()/.double() move and cast it with
        # the module, so no custom ``to`` override is needed. It is
        # non-persistent so state_dict keys stay exactly {"ce.weight"}.
        self.register_buffer(
            "_w", torch.tensor([float(w) for w in class_weights]),
            persistent=False,
        )
        self.ce = nn.CrossEntropyLoss(weight=self._w)
        self.dice_weight = dice_weight

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        dice_term = self.dice(pred, target)
        ce_term = self.ce(pred, target)
        return self.dice_weight * dice_term + (1.0 - self.dice_weight) * ce_term


def calculate_dice(pred, target) -> float:
    """Return the binary Dice coefficient between two masks.

    Both inputs are converted with ``.astype(bool)`` (e.g. NumPy arrays), so
    any nonzero value counts as foreground. A 1e-5 smoothing term is added
    to numerator and denominator, so two empty masks score 1.0 rather than
    dividing by zero.
    """
    p = pred.astype(bool)
    t = target.astype(bool)
    return float((2.0 * (p & t).sum() + 1e-5) / (p.sum() + t.sum() + 1e-5))