# coralbay / swin_unetr.py
# Uploaded by kaiko-ai-user ("Upload 11 files", commit 17825b3, verified).
"""Encoder based on Swin UNETR."""
from typing import List, Literal, Tuple
import torch
from monai.inferers.inferer import Inferer
from monai.networks.blocks import unetr_block
from monai.networks.nets import swin_unetr
from monai.utils import misc
from torch import nn
from huggingface_hub import PyTorchModelHubMixin
class SwinUNETREncoder(nn.Module, PyTorchModelHubMixin):
    """Swin transformer encoder based on UNETR [0].

    Combines a MONAI Swin transformer backbone (`swinViT`) with residual
    `UnetrBasicBlock` encoders applied to the raw input and to selected
    backbone outputs. No decoder blocks are built. Outputs available:

    - the list of intermediate feature maps (`forward_features`),
    - a multi-scale pooled embedding (`forward_encoders`),
    - a pooled embedding of the last feature map (`forward_head`).

    - [0] UNETR: Transformers for 3D Medical Image Segmentation
      https://arxiv.org/pdf/2103.10504
    """

    def __init__(
        self,
        in_channels: int = 1,
        feature_size: int = 48,
        spatial_dims: int = 3,
        out_indices: int | None = None,
        inferer: Inferer | None = None,
        embeddings_type: Literal["multiscale", "head"] = "multiscale",
        use_v2: bool = True,
    ) -> None:
        """Build the UNETR encoder.

        Args:
            in_channels: Number of input channels.
            feature_size: The dimension of network feature size; it is also
                passed to the Swin backbone as `embed_dim`.
            spatial_dims: Number of spatial dimensions. Global pooling is 3D
                when this is 3 and 2D otherwise.
            out_indices: Number of trailing feature maps returned by
                `forward`. If None, the aggregated feature vector is
                returned instead. Must be non-negative.
            inferer: An optional MONAI `Inferer` used to extract features
                when the module is in eval mode.
            embeddings_type: Whether to use aggregated or head embeddings:
                - `multiscale`: multi-scale aggregated representation
                - `head`: last-stage (head) pooled representation
            use_v2: Whether to use SwinTransformerV2.

        Raises:
            ValueError: If `out_indices` is negative.
        """
        super().__init__()
        # A negative count would turn the tail slice in `forward` into a
        # slice that drops leading maps instead; reject it up front.
        if out_indices is not None and out_indices < 0:
            raise ValueError(
                f"out_indices must be None or non-negative, got {out_indices}."
            )

        self._in_channels = in_channels
        self._feature_size = feature_size
        self._spatial_dims = spatial_dims
        self._out_indices = out_indices
        self._inferer = inferer
        self._embeddings_type = embeddings_type
        self._use_v2 = use_v2
        self._window_size = misc.ensure_tuple_rep(7, spatial_dims)
        self._patch_size = misc.ensure_tuple_rep(2, spatial_dims)

        self.swinViT = swin_unetr.SwinTransformer(
            in_chans=in_channels,
            embed_dim=feature_size,
            window_size=self._window_size,
            patch_size=self._patch_size,
            depths=(2, 2, 2, 2),
            num_heads=(3, 6, 12, 24),
            mlp_ratio=4.0,
            qkv_bias=True,
            drop_rate=0.0,
            attn_drop_rate=0.0,
            drop_path_rate=0.0,
            norm_layer=torch.nn.LayerNorm,
            spatial_dims=spatial_dims,
            use_v2=use_v2,
        )

        def _basic_block(in_ch: int, out_ch: int) -> unetr_block.UnetrBasicBlock:
            """Build a residual UNETR block (kernel 3, stride 1, instance norm)."""
            return unetr_block.UnetrBasicBlock(
                spatial_dims=spatial_dims,
                in_channels=in_ch,
                out_channels=out_ch,
                kernel_size=3,
                stride=1,
                norm_name="instance",
                res_block=True,
            )

        # Keep these attribute names and this creation order: the names
        # determine the state_dict keys used when loading saved weights.
        self.encoder1 = _basic_block(in_channels, feature_size)
        self.encoder2 = _basic_block(feature_size, feature_size)
        self.encoder3 = _basic_block(2 * feature_size, 2 * feature_size)
        self.encoder4 = _basic_block(4 * feature_size, 4 * feature_size)
        self.encoder10 = _basic_block(16 * feature_size, 16 * feature_size)

        self._pool_op = (
            nn.AdaptiveAvgPool3d(output_size=(1, 1, 1))
            if spatial_dims == 3
            else nn.AdaptiveAvgPool2d(output_size=(1, 1))
        )

    def _forward_features(self, tensor: torch.Tensor) -> List[torch.Tensor]:
        """Extracts feature maps from the Swin Transformer and encoder blocks.

        Args:
            tensor: Input tensor of shape (B, C, *spatial) with
                `spatial_dims` spatial axes, e.g. (B, C, T, H, W) in 3D.

        Returns:
            Six feature maps, in this order:
            `[enc0, enc1, enc2, enc3, hidden_states[3], dec4]`, where
            `enc0` is `encoder1` applied to the raw input, `enc1`..`enc3`
            are `encoder2`..`encoder4` applied to `hidden_states[0..2]`,
            index 4 is the backbone output `hidden_states[3]` with no
            encoder block applied, and `dec4` is `encoder10` applied to
            `hidden_states[4]`.
        """
        hidden_states = self.swinViT(tensor)
        enc0 = self.encoder1(tensor)
        enc1 = self.encoder2(hidden_states[0])
        enc2 = self.encoder3(hidden_states[1])
        enc3 = self.encoder4(hidden_states[2])
        dec4 = self.encoder10(hidden_states[4])
        return [enc0, enc1, enc2, enc3, hidden_states[3], dec4]

    def forward_features(self, tensor: torch.Tensor) -> List[torch.Tensor]:
        """Computes feature maps using either standard forward pass or inference mode.

        If in eval mode (`self.training` is False) and an inferer was
        provided, the inferer drives `_forward_features`. Otherwise,
        `_forward_features` is called directly.

        Args:
            tensor: Input tensor of shape (B, C, *spatial),
                e.g. (B, C, T, H, W) in 3D.

        Returns:
            List of feature maps from encoder stages (see
            `_forward_features` for the ordering).
        """
        if not self.training and self._inferer is not None:
            return self._inferer(inputs=tensor, network=self._forward_features)
        return self._forward_features(tensor)

    def forward_encoders(self, features: List[torch.Tensor]) -> torch.Tensor:
        """Aggregates encoder features into a single feature vector.

        Each selected map is globally average-pooled and flattened to
        (B, channels). The results are concatenated along the channel axis.

        Args:
            features: List of feature maps as returned by `forward_features`.

        Returns:
            Aggregated feature vector (B, C'). With the encoder widths built
            in `__init__`, C' = 24 * feature_size.
        """
        # Index 4 (the raw `hidden_states[3]` map, which has no encoder block
        # applied) is excluded from the aggregation.
        selected_maps = features[:4] + features[5:]
        pooled_vectors = [
            torch.flatten(self._pool_op(feature_map), 1)
            for feature_map in selected_maps
        ]
        return torch.cat(pooled_vectors, dim=1)

    def forward_head(self, features: List[torch.Tensor]) -> torch.Tensor:
        """Casts last feature map into a single feature vector.

        Args:
            features: List of encoder feature maps.

        Returns:
            Globally pooled last feature map of shape (B, C'). With the
            encoder widths built in `__init__`, C' = 16 * feature_size.
        """
        last_feature_map = features[-1]
        pooled_features = self._pool_op(last_feature_map)
        return torch.flatten(pooled_features, 1)

    def forward_embeddings(self, tensor: torch.Tensor) -> torch.Tensor:
        """Computes the final feature vector selected by `embeddings_type`.

        Args:
            tensor: Input tensor of shape (B, C, *spatial),
                e.g. (B, C, T, H, W) in 3D.

        Returns:
            Aggregated feature vector of shape (B, C').

        Raises:
            ValueError: If `embeddings_type` is not "multiscale" or "head".
        """
        embeddings_to_forward_methods = {
            "multiscale": self.forward_encoders,
            "head": self.forward_head,
        }
        # Validate before running the (expensive) feature extraction.
        forward_method = embeddings_to_forward_methods.get(self._embeddings_type)
        if forward_method is None:
            raise ValueError(f"Unknown embeddings_type: {self._embeddings_type}")
        intermediates = self.forward_features(tensor)
        return forward_method(intermediates)

    def forward_intermediates(
        self, tensor: torch.Tensor
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Computes encoder features and their multi-scale embeddings.

        NOTE(review): the embeddings here always come from
        `forward_encoders` (multi-scale), regardless of `embeddings_type`.
        Confirm whether callers expect `head` embeddings when that type is
        configured.

        Args:
            tensor: Input tensor of shape (B, C, *spatial),
                e.g. (B, C, T, H, W) in 3D.

        Returns:
            Aggregated feature vector and list of intermediate features.
        """
        features = self.forward_features(tensor)
        embeddings = self.forward_encoders(features)
        return embeddings, features

    def forward(self, tensor: torch.Tensor) -> torch.Tensor | List[torch.Tensor]:
        """Forward pass through the encoder.

        If `out_indices` is None, returns the embedding selected by
        `embeddings_type`. Otherwise, returns the last `out_indices`
        intermediate feature maps: none for 0, and all of them when
        `out_indices` exceeds the number of maps.

        Args:
            tensor: Input tensor of shape (B, C, *spatial),
                e.g. (B, C, T, H, W) in 3D.

        Returns:
            Aggregated feature vector or intermediate features.
        """
        if self._out_indices is None:
            return self.forward_embeddings(tensor)
        intermediates = self.forward_features(tensor)
        # Use an explicit start index: `intermediates[-0:]` would return
        # every map when out_indices == 0 instead of none.
        start = max(len(intermediates) - self._out_indices, 0)
        return intermediates[start:]