File size: 1,291 Bytes
17825b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9fbba73
17825b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
"""HuggingFace PreTrainedModel wrapper for the CoralBay encoder."""

import torch
from transformers import PreTrainedModel

from configuration_coralbay import CoralBayConfig
from swin_unetr import SwinUNETREncoder


class CoralBayModel(PreTrainedModel):
    """Expose :class:`SwinUNETREncoder` through the HuggingFace model API.

    The model owns a single submodule, ``self.encoder``, which is built
    from fields of the supplied configuration. Calling the model simply
    runs that encoder.

    Args:
        config: A :class:`CoralBayConfig` instance. The attributes
            ``in_channels``, ``feature_size``, ``spatial_dims``,
            ``out_indices``, ``embeddings_type`` and ``use_v2`` are read
            from it and passed to the encoder constructor.
    """

    # Lets the HuggingFace loading machinery pair this model class with
    # its configuration class.
    config_class = CoralBayConfig

    def __init__(self, config) -> None:
        super().__init__(config)

        # Gather the encoder hyper-parameters in one place; each value is
        # forwarded under the same keyword name it has on the config.
        encoder_kwargs = {
            "in_channels": config.in_channels,
            "feature_size": config.feature_size,
            "spatial_dims": config.spatial_dims,
            "out_indices": config.out_indices,
            "embeddings_type": config.embeddings_type,
            "use_v2": config.use_v2,
        }
        self.encoder = SwinUNETREncoder(**encoder_kwargs)

        # Run the base-class post-construction hook only after every
        # submodule has been attached.
        self.post_init()

    def forward(self, tensor: torch.Tensor) -> torch.Tensor | list[torch.Tensor]:
        """Run the wrapped encoder on ``tensor`` and return its output.

        Args:
            tensor: Input tensor of shape ``(B, C, T, H, W)``.

        Returns:
            Whatever the encoder produces: an aggregated feature vector
            or a list of intermediate feature maps, depending on the
            encoder's ``out_indices`` setting.
        """
        features = self.encoder(tensor)
        return features