| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import timm |
|
|
| class EfficientNetEncoder(nn.Module): |
| def __init__(self, model_name='efficientnet_b3', pretrained=False): |
| super().__init__() |
| self.backbone = timm.create_model(model_name, pretrained=pretrained, features_only=True) |
| self.feature_channels = self.backbone.feature_info.channels() |
| |
| def forward(self, x): |
| return self.backbone(x) |
|
|
|
|
| class SegFormerDecoder(nn.Module): |
| def __init__(self, encoder_channels, num_classes=1, decoder_dim=256): |
| super().__init__() |
| self.proj = nn.ModuleList() |
| for in_ch in encoder_channels: |
| self.proj.append( |
| nn.Sequential( |
| nn.Conv2d(in_ch, decoder_dim, kernel_size=1), |
| nn.BatchNorm2d(decoder_dim), |
| nn.ReLU(inplace=True) |
| ) |
| ) |
| self.fusion = nn.Sequential( |
| nn.Conv2d(decoder_dim * len(encoder_channels), decoder_dim, kernel_size=1), |
| nn.BatchNorm2d(decoder_dim), |
| nn.ReLU(inplace=True), |
| nn.Conv2d(decoder_dim, decoder_dim, kernel_size=3, padding=1), |
| nn.BatchNorm2d(decoder_dim), |
| nn.ReLU(inplace=True), |
| nn.Conv2d(decoder_dim, num_classes, kernel_size=1) |
| ) |
| |
| def forward(self, features): |
| target_size = features[0].shape[2] |
| projected = [] |
| for i, feat in enumerate(features): |
| proj = self.proj[i](feat) |
| if proj.shape[2] != target_size: |
| proj = F.interpolate(proj, size=(target_size, target_size), |
| mode='bilinear', align_corners=False) |
| projected.append(proj) |
| concat = torch.cat(projected, dim=1) |
| out = self.fusion(concat) |
| out = F.interpolate(out, size=(224, 224), mode='bilinear', align_corners=False) |
| return out |
|
|
|
|
| class LungUltrasoundModel(nn.Module): |
| def __init__(self, num_classes=3, num_seg_classes=1): |
| super().__init__() |
| self.encoder = EfficientNetEncoder('efficientnet_b3', pretrained=False) |
| encoder_channels = self.encoder.feature_channels |
| |
| self.classifier = nn.Sequential( |
| nn.AdaptiveAvgPool2d(1), |
| nn.Flatten(), |
| nn.Linear(encoder_channels[-1], 512), |
| nn.ReLU(inplace=True), |
| nn.BatchNorm1d(512), |
| nn.Dropout(0.5), |
| nn.Linear(512, 256), |
| nn.ReLU(inplace=True), |
| nn.Dropout(0.3), |
| nn.Linear(256, num_classes) |
| ) |
| |
| self.seg_decoder = SegFormerDecoder(encoder_channels, num_seg_classes, 256) |
| |
| def forward(self, x): |
| features = self.encoder(x) |
| class_out = self.classifier(features[-1]) |
| seg_out = self.seg_decoder(features) |
| return class_out, seg_out |