import math
from typing import Any, Dict, Tuple

import torch
from timm.models.vision_transformer import DropPath, Mlp
from torch import Tensor, nn
from torch.nn import functional as F

from .layer import ConvLayer2d, PosCNN


def _same_padding(kernel_size, dilation=1):
    """Return the padding that keeps H x W unchanged for a stride-1 conv.

    For integer arguments this is ``dilation * (kernel_size - 1) // 2``
    (exact for odd kernel sizes). For anything that is not a plain ``int``
    the historical hard-coded value ``1`` is returned so that existing
    callers keep their previous behaviour.
    """
    if isinstance(kernel_size, int) and isinstance(dilation, int):
        return dilation * (kernel_size - 1) // 2
    return 1


class BaseModule(nn.Module):
    """Base class for all modules"""

    def __init__(self, *args, **kwargs):
        # Extra positional/keyword arguments are accepted and ignored.
        super().__init__()

    def forward(self, x: Any, *args, **kwargs) -> Any:
        """Must be overridden by subclasses."""
        raise NotImplementedError

    def __repr__(self):
        # Only the class name is shown (the nested nn.Module tree is hidden).
        return "{}".format(self.__class__.__name__)


class Attention(nn.Module):
    """Multi-head self-attention over an input unpacked as ``(B, N, C)``.

    Args:
        dim: Channel size ``C``; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused q/k/v projection has a bias term.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Standard 1/sqrt(d_head) scaling of the q.k^T logits.
        self.scale = head_dim**-0.5

        # One linear layer produces q, k and v at once.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention; the output has the same shape as ``x``."""
        B, N, C = x.shape
        # [B, N, 3C] --> [B, N, 3, heads, C/heads] --> [3, B, heads, N, C/heads]
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # [B, heads, N, C/heads] --> [B, N, heads, C/heads] --> [B, N, C]
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class LayerScale(nn.Module):
    """Multiply the input by a learnable vector ``gamma`` of length ``dim``.

    Args:
        dim: Length of ``gamma`` (broadcast against the input's last axis).
        init_values: Initial value of every entry of ``gamma``.
        inplace: If True the input tensor is scaled in place via ``mul_``.
    """

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        return x.mul_(self.gamma) if self.inplace else x * self.gamma


class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual.

    Computes ``x = x + drop_path(ls(attn(norm1(x))))`` followed by
    ``x = x + drop_path(ls(mlp(norm2(x))))``.

    Args:
        dim: Channel size of the tokens.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden size of the MLP is ``int(dim * mlp_ratio)``.
        qkv_bias: Forwarded to :class:`Attention`.
        drop: Dropout used for the attention projection and inside the MLP.
        attn_drop: Dropout on the attention weights.
        init_values: If truthy, enables :class:`LayerScale` with this initial
            value; otherwise (``None`` or ``0``) an identity is used.
        drop_path: Stochastic-depth rate; ``DropPath`` is used only if > 0.
        act_layer: Activation class handed to the MLP.
        norm_layer: Normalisation class; it is called as
            ``norm_layer(dim, elementwise_affine=True)`` and therefore has to
            accept that keyword.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        init_values=None,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim, elementwise_affine=True)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim, elementwise_affine=True)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        self.ls2 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x


class MobileViTBlock(BaseModule):
    """
    MVP Block - MobileViT with Positional Encoding

    Pipeline of :meth:`forward`: local convolutions (k x k then 1 x 1) -->
    unfold into patches --> transformer blocks (``pos_pe`` is applied after
    the first layer) --> fold back --> 1 x 1 projection --> optional fusion
    conv over ``cat(input, projection)``.

    Args:
        in_channels: Channels of the input (and output) feature map.
        transformer_dim: Channel size used inside the transformer blocks.
        n_transformer_blocks: Number of :class:`Block` layers.
        head_dim: Channels per attention head; ``transformer_dim`` must be
            divisible by it.
        attn_dropout: Dropout on attention weights inside each block.
        dropout: Dropout for the projection/MLP inside each block.
        patch_h: Patch height used by :meth:`unfolding` / :meth:`folding`.
        patch_w: Patch width used by :meth:`unfolding` / :meth:`folding`.
        conv_ksize: Kernel size of the k x k convolutions.
        dilation: Dilation of the input k x k convolution.
        no_fusion: If True (default) the fusion convolution is not created
            and :meth:`forward` returns the 1 x 1 projection directly.

    NOTE(review): ``ConvLayer2d`` and ``PosCNN`` come from ``.layer`` and are
    not visible here; it is assumed that ``ConvLayer2d`` forwards ``padding``
    and ``dilation`` to a regular 2-D convolution -- confirm.
    """

    def __init__(
        self,
        in_channels=128,
        transformer_dim=128,
        n_transformer_blocks=2,
        head_dim=64,
        attn_dropout=0.0,
        dropout=0.0,
        patch_h=2,
        patch_w=2,
        conv_ksize=3,
        dilation=1,
        no_fusion=True,
    ) -> None:
        super().__init__()

        # The padding used to be hard-coded to 1, which only preserves the
        # spatial size for conv_ksize=3 / dilation=1. Any other setting
        # changed H x W in the local branch and made the fusion
        # `torch.cat((res, fm))` fail. Derive it from the arguments instead
        # (still 1 for the defaults).
        conv_3x3_in = ConvLayer2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=conv_ksize,
            stride=1,
            use_norm=True,
            use_act=True,
            dilation=dilation,
            padding=_same_padding(conv_ksize, dilation),
        )
        conv_1x1_in = ConvLayer2d(
            in_channels=in_channels,
            out_channels=transformer_dim,
            kernel_size=1,
            stride=1,
            use_norm=False,
            use_act=False,
        )

        conv_1x1_out = ConvLayer2d(
            in_channels=transformer_dim,
            out_channels=in_channels,
            kernel_size=1,
            stride=1,
            use_norm=True,
            use_act=True,
        )
        conv_3x3_out = None
        if not no_fusion:
            # Takes cat(input, projection), hence 2 * in_channels inputs.
            conv_3x3_out = ConvLayer2d(
                in_channels=2 * in_channels,
                out_channels=in_channels,
                kernel_size=conv_ksize,
                stride=1,
                padding=_same_padding(conv_ksize),
                use_norm=True,
                use_act=True,
            )

        self.local_rep = nn.Sequential()
        self.local_rep.add_module(name="conv_3x3", module=conv_3x3_in)
        self.local_rep.add_module(name="conv_1x1", module=conv_1x1_in)

        # Positional Encoding Generator
        self.pos_pe = PosCNN(in_chans=transformer_dim, embed_dim=transformer_dim)

        assert (
            transformer_dim % head_dim == 0
        ), "transformer_dim should be divisible by head_dim"
        num_heads = transformer_dim // head_dim

        global_rep = [
            Block(
                dim=transformer_dim,
                num_heads=num_heads,
                mlp_ratio=4.0,
                qkv_bias=True,
                attn_drop=attn_dropout,
                drop=dropout,
                norm_layer=nn.LayerNorm,
            )
            for _ in range(n_transformer_blocks)
        ]
        # Final normalisation after the last transformer block.
        global_rep.append(nn.LayerNorm(transformer_dim))
        self.global_rep = nn.Sequential(*global_rep)

        self.conv_proj = conv_1x1_out
        self.fusion = conv_3x3_out  # None when no_fusion is True

        self.patch_h = patch_h
        self.patch_w = patch_w
        self.patch_area = self.patch_w * self.patch_h

        # Configuration kept for introspection.
        self.cnn_in_dim = in_channels
        self.cnn_out_dim = transformer_dim
        self.n_heads = num_heads
        self.dropout = dropout
        self.attn_dropout = attn_dropout
        self.dilation = dilation
        self.n_blocks = n_transformer_blocks
        self.conv_ksize = conv_ksize

    def unfolding(self, feature_map: Tensor) -> Tuple[Tensor, Dict]:
        """Split a ``[B, C, H, W]`` map into patches of shape ``[B*P, N, C]``.

        ``P = patch_h * patch_w`` is the number of pixels per patch and ``N``
        the number of patches. If ``H``/``W`` are not multiples of the patch
        size, the map is first bilinearly resized up to the next multiple.

        Returns:
            ``(patches, info_dict)`` where ``info_dict`` stores what
            :meth:`folding` needs to invert the operation: ``orig_size``,
            ``batch_size``, ``interpolate``, ``total_patches``,
            ``num_patches_w`` and ``num_patches_h``.
        """
        patch_w, patch_h = self.patch_w, self.patch_h
        patch_area = int(patch_w * patch_h)
        batch_size, in_channels, orig_h, orig_w = feature_map.shape

        # Round the spatial size up to a multiple of the patch size.
        new_h = int(math.ceil(orig_h / self.patch_h) * self.patch_h)
        new_w = int(math.ceil(orig_w / self.patch_w) * self.patch_w)

        interpolate = False
        if new_w != orig_w or new_h != orig_h:
            # Note: Padding can be done, but then it needs to be handled in attention function.
            feature_map = F.interpolate(
                feature_map, size=(new_h, new_w), mode="bilinear", align_corners=False
            )
            interpolate = True

        # number of patches along width and height
        num_patch_w = new_w // patch_w  # n_w
        num_patch_h = new_h // patch_h  # n_h
        num_patches = num_patch_h * num_patch_w  # N

        # [B, C, H, W] --> [B * C * n_h, p_h, n_w, p_w]
        reshaped_fm = feature_map.reshape(
            batch_size * in_channels * num_patch_h, patch_h, num_patch_w, patch_w
        )
        # [B * C * n_h, p_h, n_w, p_w] --> [B * C * n_h, n_w, p_h, p_w]
        transposed_fm = reshaped_fm.transpose(1, 2)
        # [B * C * n_h, n_w, p_h, p_w] --> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
        reshaped_fm = transposed_fm.reshape(
            batch_size, in_channels, num_patches, patch_area
        )
        # [B, C, N, P] --> [B, P, N, C]
        transposed_fm = reshaped_fm.transpose(1, 3)
        # [B, P, N, C] --> [BP, N, C]
        patches = transposed_fm.reshape(batch_size * patch_area, num_patches, -1)

        info_dict = {
            "orig_size": (orig_h, orig_w),
            "batch_size": batch_size,
            "interpolate": interpolate,
            "total_patches": num_patches,
            "num_patches_w": num_patch_w,
            "num_patches_h": num_patch_h,
        }

        return patches, info_dict

    def folding(self, patches: Tensor, info_dict: Dict) -> Tensor:
        """Inverse of :meth:`unfolding`: ``[B*P, N, C]`` --> ``[B, C, H, W]``.

        Args:
            patches: 3-D tensor as produced by :meth:`unfolding`.
            info_dict: The dictionary returned by :meth:`unfolding`.

        Returns:
            The feature map, resized back to ``info_dict["orig_size"]`` when
            :meth:`unfolding` had to interpolate.

        Raises:
            AssertionError: If ``patches`` is not 3-dimensional.
        """
        n_dim = patches.dim()
        assert n_dim == 3, "Tensor should be of shape BPxNxC. Got: {}".format(
            patches.shape
        )
        # [BP, N, C] --> [B, P, N, C]
        patches = patches.contiguous().view(
            info_dict["batch_size"], self.patch_area, info_dict["total_patches"], -1
        )

        batch_size, _, _, channels = patches.size()
        num_patch_h = info_dict["num_patches_h"]
        num_patch_w = info_dict["num_patches_w"]

        # [B, P, N, C] --> [B, C, N, P]
        patches = patches.transpose(1, 3)

        # [B, C, N, P] --> [B*C*n_h, n_w, p_h, p_w]
        feature_map = patches.reshape(
            batch_size * channels * num_patch_h, num_patch_w, self.patch_h, self.patch_w
        )
        # [B*C*n_h, n_w, p_h, p_w] --> [B*C*n_h, p_h, n_w, p_w]
        feature_map = feature_map.transpose(1, 2)
        # [B*C*n_h, p_h, n_w, p_w] --> [B, C, H, W]
        feature_map = feature_map.reshape(
            batch_size, channels, num_patch_h * self.patch_h, num_patch_w * self.patch_w
        )
        if info_dict["interpolate"]:
            # Undo the resize done in `unfolding`.
            feature_map = F.interpolate(
                feature_map,
                size=info_dict["orig_size"],
                mode="bilinear",
                align_corners=False,
            )
        return feature_map

    def forward(self, x: Tensor) -> Tensor:
        """Run the block on a 4-D feature map ``x`` (``[B, C, H, W]``).

        Returns the output of ``conv_proj``, or of the fusion convolution
        applied to ``cat(x, conv_proj_output)`` when fusion is enabled.
        """
        res = x  # kept for the optional fusion branch

        fm = self.local_rep(x)

        # convert feature map to patches
        patches, info_dict = self.unfolding(fm)
        num_patch_h = info_dict["num_patches_h"]
        num_patch_w = info_dict["num_patches_w"]

        # learn global representations
        for j, transformer_layer in enumerate(self.global_rep):
            patches = transformer_layer(patches)
            if j == 0:
                # The positional encoding is injected once, right after the
                # first layer of `global_rep`.
                patches = self.pos_pe(patches, num_patch_h, num_patch_w)  # PEG here

        # [B x Patch x Patches x C] --> [B x C x Patches x Patch]
        fm = self.folding(patches=patches, info_dict=info_dict)

        fm = self.conv_proj(fm)

        if self.fusion is not None:
            fm = self.fusion(torch.cat((res, fm), dim=1))
        return fm