"""Encoder based on Swin UNETR.""" from typing import List, Literal, Tuple import torch from monai.inferers.inferer import Inferer from monai.networks.blocks import unetr_block from monai.networks.nets import swin_unetr from monai.utils import misc from torch import nn from huggingface_hub import PyTorchModelHubMixin class SwinUNETREncoder(nn.Module, PyTorchModelHubMixin): """Swin transformer encoder based on UNETR [0]. - [0] UNETR: Transformers for 3D Medical Image Segmentation https://arxiv.org/pdf/2103.10504 """ def __init__( self, in_channels: int = 1, feature_size: int = 48, spatial_dims: int = 3, out_indices: int | None = None, inferer: Inferer | None = None, embeddings_type: Literal["multiscale", "head"] = "multiscale", use_v2: bool = True, ) -> None: """Build the UNETR encoder. Args: in_channels: Number of input channels. feature_size: The dimension of network feature size. spatial_dims: Number of spatial dimensions. out_indices: Number of feature outputs. If None, the aggregated feature vector is returned. inferer: An optional MONAI `Inferer` for efficient inference during evaluation. embeddings_type: Whether to use aggregated or head embeddings: - `multiscale`: multi-scale aggregated representation - `head`: last-stage (head) pooled representation use_v2: Whether to use SwinTransformerV2. """ super().__init__() self._in_channels = in_channels self._feature_size = feature_size self._spatial_dims = spatial_dims self._out_indices = out_indices self._inferer = inferer self._embeddings_type = embeddings_type self._use_v2 = use_v2 self._window_size = misc.ensure_tuple_rep(7, spatial_dims) self._patch_size = misc.ensure_tuple_rep(2, spatial_dims) self.swinViT = swin_unetr.SwinTransformer( in_chans=in_channels, embed_dim=feature_size, window_size=self._window_size, patch_size=self._patch_size, depths=(2, 2, 2, 2), num_heads=(3, 6, 12, 24), mlp_ratio=4.0, qkv_bias=True, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0, norm_layer=torch.nn.LayerNorm, spatial_dims=spatial_dims, use_v2=use_v2, ) self.encoder1 = unetr_block.UnetrBasicBlock( spatial_dims=spatial_dims, in_channels=in_channels, out_channels=feature_size, kernel_size=3, stride=1, norm_name="instance", res_block=True, ) self.encoder2 = unetr_block.UnetrBasicBlock( spatial_dims=spatial_dims, in_channels=feature_size, out_channels=feature_size, kernel_size=3, stride=1, norm_name="instance", res_block=True, ) self.encoder3 = unetr_block.UnetrBasicBlock( spatial_dims=spatial_dims, in_channels=2 * feature_size, out_channels=2 * feature_size, kernel_size=3, stride=1, norm_name="instance", res_block=True, ) self.encoder4 = unetr_block.UnetrBasicBlock( spatial_dims=spatial_dims, in_channels=4 * feature_size, out_channels=4 * feature_size, kernel_size=3, stride=1, norm_name="instance", res_block=True, ) self.encoder10 = unetr_block.UnetrBasicBlock( spatial_dims=spatial_dims, in_channels=16 * feature_size, out_channels=16 * feature_size, kernel_size=3, stride=1, norm_name="instance", res_block=True, ) self._pool_op = ( nn.AdaptiveAvgPool3d(output_size=(1, 1, 1)) if spatial_dims == 3 else nn.AdaptiveAvgPool2d(output_size=(1, 1)) ) def _forward_features(self, tensor: torch.Tensor) -> List[torch.Tensor]: """Extracts feature maps from the Swin Transformer and encoder blocks. Args: tensor: Input tensor of shape (B, C, T, H, W). Returns: List of feature maps from encoder stages. """ hidden_states = self.swinViT(tensor) enc0 = self.encoder1(tensor) enc1 = self.encoder2(hidden_states[0]) enc2 = self.encoder3(hidden_states[1]) enc3 = self.encoder4(hidden_states[2]) dec4 = self.encoder10(hidden_states[4]) return [enc0, enc1, enc2, enc3, hidden_states[3], dec4] def forward_features(self, tensor: torch.Tensor) -> List[torch.Tensor]: """Computes feature maps using either standard forward pass or inference mode. If in inference mode (`self.training` is False) and an inference method (`self._inferer`) is available, the `_inferer` is used to extract features. Otherwise, `_forward_features` is called directly. Args: tensor: Input tensor of shape (B, C, T, H, W). Returns: List of feature maps from encoder stages. """ if not self.training and self._inferer: return self._inferer(inputs=tensor, network=self._forward_features) return self._forward_features(tensor) def forward_encoders(self, features: List[torch.Tensor]) -> torch.Tensor: """Aggregates encoder features into a single feature vector. Args: features: List of feature maps from encoder stages. Returns: Aggregated feature vector (B, C'). """ batch_size = features[0].shape[0] reduced_features = [] for patch_features in features[:4] + features[5:]: hidden_features = self._pool_op(patch_features) hidden_features_reduced = hidden_features.view(batch_size, -1) reduced_features.append(hidden_features_reduced) return torch.cat(reduced_features, dim=1) def forward_head(self, features: List[torch.Tensor]) -> torch.Tensor: """Casts last feature map into a single feature vector. Args: features: List of encoder feature maps. Returns: Aggregated feature vector (B, C'). """ last_feature_map = features[-1] pooled_features = self._pool_op(last_feature_map) return torch.flatten(pooled_features, 1) def forward_embeddings(self, tensor: torch.Tensor) -> torch.Tensor: """Computes the final aggregated feature vector. Args: tensor: Input tensor of shape (B, C, T, H, W). Returns: Aggregated feature vector of shape (B, C'). """ embeddings_to_forward_methods = { "multiscale": self.forward_encoders, "head": self.forward_head, } forward_method = embeddings_to_forward_methods.get(self._embeddings_type) if forward_method is None: raise ValueError(f"Unknown embeddings_type: {self._embeddings_type}") intermediates = self.forward_features(tensor) return forward_method(intermediates) def forward_intermediates( self, tensor: torch.Tensor ) -> Tuple[torch.Tensor, List[torch.Tensor]]: """Computes encoder features and their embeddings. Args: tensor: Input tensor of shape (B, C, T, H, W). Returns: Aggregated feature vector and list of intermediate features. """ features = self.forward_features(tensor) embeddings = self.forward_encoders(features) return embeddings, features def forward(self, tensor: torch.Tensor) -> torch.Tensor | List[torch.Tensor]: """Forward pass through the encoder. If `self._out_indices` is None, it returns the aggregated feature vector. Otherwise, it returns the intermediate feature maps up to the specified index. Args: tensor: Input tensor of shape (B, C, T, H, W). Returns: Aggregated feature vector or intermediate features. """ if self._out_indices is None: return self.forward_embeddings(tensor) intermediates = self.forward_features(tensor) return intermediates[-1 * self._out_indices :]