# LUS / model.py
# Uploaded by huzie111 ("Upload model.py", commit 8222e7f, 2.9 kB).
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
class EfficientNetEncoder(nn.Module):
    """Wrap a ``timm`` backbone built with ``features_only=True``.

    Args:
        model_name: ``timm`` model identifier passed to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model``; whether to load
            pretrained weights.

    Attributes:
        backbone: The ``timm`` feature-extraction model.
        feature_channels: Per-stage channel counts, as reported by the
            backbone's ``feature_info.channels()``.
    """

    def __init__(self, model_name='efficientnet_b3', pretrained=False):
        super().__init__()
        backbone = timm.create_model(
            model_name,
            features_only=True,
            pretrained=pretrained,
        )
        self.backbone = backbone
        self.feature_channels = backbone.feature_info.channels()

    def forward(self, x):
        """Run the backbone on ``x`` and return its output unchanged.

        With ``features_only=True`` this is the backbone's sequence of
        intermediate feature maps.
        """
        feature_maps = self.backbone(x)
        return feature_maps
class SegFormerDecoder(nn.Module):
    """Segmentation head that fuses multi-scale encoder feature maps.

    Processing steps:

    1. Each input feature map is projected to ``decoder_dim`` channels with a
       1x1 conv + BatchNorm + ReLU.
    2. Each projection is bilinearly resized to the spatial size (H, W) of the
       first feature map.
    3. The projections are concatenated along channels and fused into
       ``num_classes`` maps.
    4. The fused maps are upsampled to ``output_size``.

    Args:
        encoder_channels: Channel count of each feature map that ``forward``
            will receive, in the same order.
        num_classes: Number of output channels.
        decoder_dim: Channel width used by the projections and the fusion block.
        output_size: Size passed to ``F.interpolate`` for the final upsample
            (``(height, width)`` or an int). Defaults to ``(224, 224)``, the
            size this decoder previously hard-coded. ``None`` skips the final
            upsample and returns maps at the first feature map's resolution.
    """

    def __init__(self, encoder_channels, num_classes=1, decoder_dim=256,
                 output_size=(224, 224)):
        super().__init__()
        self.output_size = output_size
        # Submodule names/order are kept identical to the original layout so
        # existing state_dict checkpoints still load.
        self.proj = nn.ModuleList(
            nn.Sequential(
                nn.Conv2d(in_ch, decoder_dim, kernel_size=1),
                nn.BatchNorm2d(decoder_dim),
                nn.ReLU(inplace=True),
            )
            for in_ch in encoder_channels
        )
        self.fusion = nn.Sequential(
            nn.Conv2d(decoder_dim * len(encoder_channels), decoder_dim, kernel_size=1),
            nn.BatchNorm2d(decoder_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(decoder_dim, decoder_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(decoder_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(decoder_dim, num_classes, kernel_size=1),
        )

    def forward(self, features, output_size=None):
        """Fuse ``features`` and return segmentation logits.

        Args:
            features: Sequence of feature maps, one per entry of
                ``encoder_channels``. Presumably ``(N, C, H, W)`` tensors with
                the highest-resolution map first (TODO confirm with caller).
            output_size: Optional per-call override of ``self.output_size``.

        Returns:
            Tensor with ``num_classes`` channels, resized to the requested size.

        Raises:
            ValueError: If the number of feature maps does not match the number
                of projection layers.
        """
        if len(features) != len(self.proj):
            raise ValueError(
                f"expected {len(self.proj)} feature maps, got {len(features)}"
            )
        # Use the full (H, W) rather than just H so non-square inputs work.
        target_hw = tuple(features[0].shape[-2:])
        projected = []
        for proj_layer, feat in zip(self.proj, features):
            proj = proj_layer(feat)
            if tuple(proj.shape[-2:]) != target_hw:
                proj = F.interpolate(proj, size=target_hw,
                                     mode='bilinear', align_corners=False)
            projected.append(proj)
        out = self.fusion(torch.cat(projected, dim=1))
        size = self.output_size if output_size is None else output_size
        if size is not None:
            out = F.interpolate(out, size=size, mode='bilinear', align_corners=False)
        return out
class LungUltrasoundModel(nn.Module):
    """Multi-task network: image classification plus segmentation.

    A shared ``EfficientNetEncoder`` feeds two heads:

    * ``classifier``: global-average-pools the last (deepest) encoder feature
      map and maps it through an MLP to ``num_classes`` logits.
    * ``seg_decoder``: a ``SegFormerDecoder`` that fuses all encoder feature
      maps into ``num_seg_classes`` output maps.

    Args:
        num_classes: Number of classification outputs.
        num_seg_classes: Number of segmentation output channels.
        encoder_name: ``timm`` backbone name (default ``'efficientnet_b3'``,
            as previously hard-coded).
        pretrained: Whether ``timm`` loads pretrained backbone weights
            (default ``False``, as previously hard-coded).
        decoder_dim: Channel width of the segmentation decoder (default 256,
            as previously hard-coded).
    """

    def __init__(self, num_classes=3, num_seg_classes=1,
                 encoder_name='efficientnet_b3', pretrained=False,
                 decoder_dim=256):
        super().__init__()
        self.encoder = EfficientNetEncoder(encoder_name, pretrained=pretrained)
        encoder_channels = self.encoder.feature_channels
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(encoder_channels[-1], 512),
            # NOTE(review): ReLU is applied before BatchNorm1d here. The order
            # is kept unchanged so existing checkpoints behave identically.
            # BatchNorm1d also raises on a batch of size 1 in training mode.
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(512),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )
        self.seg_decoder = SegFormerDecoder(encoder_channels, num_seg_classes, decoder_dim)

    def forward(self, x):
        """Return ``(class_logits, seg_logits)`` for input batch ``x``.

        The classifier only sees the last encoder feature map. The segmentation
        head sees all of them, and its output spatial size is set by the decoder.
        """
        features = self.encoder(x)
        class_out = self.classifier(features[-1])
        seg_out = self.seg_decoder(features)
        return class_out, seg_out