File size: 1,291 Bytes
17825b3 9fbba73 17825b3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 | """HuggingFace PreTrainedModel wrapper for the CoralBay encoder."""
import torch
from transformers import PreTrainedModel
from configuration_coralbay import CoralBayConfig
from swin_unetr import SwinUNETREncoder
class CoralBayModel(PreTrainedModel):
    """Expose :class:`SwinUNETREncoder` through the HuggingFace ``PreTrainedModel`` API.

    The wrapper owns a single ``encoder`` submodule whose constructor
    arguments are taken from a :class:`CoralBayConfig`.

    Args:
        config: A :class:`CoralBayConfig` instance.
    """

    # Tells the HuggingFace loading machinery which config class pairs with this model.
    config_class = CoralBayConfig

    def __init__(self, config) -> None:
        super().__init__(config)
        # Each encoder keyword argument is read from the config attribute of
        # the same name; keep this tuple in sync with SwinUNETREncoder's signature.
        encoder_kwargs = {
            field: getattr(config, field)
            for field in (
                "in_channels",
                "feature_size",
                "spatial_dims",
                "out_indices",
                "embeddings_type",
                "use_v2",
            )
        }
        self.encoder = SwinUNETREncoder(**encoder_kwargs)
        # HuggingFace post-construction hook; it must run only after every
        # submodule has been created.
        self.post_init()

    def forward(self, tensor: torch.Tensor) -> torch.Tensor | list[torch.Tensor]:
        """Run the wrapped encoder on ``tensor`` and return its output unchanged.

        Args:
            tensor: Input batch, documented as shape ``(B, C, T, H, W)``.

        Returns:
            The encoder's result. Depending on its ``out_indices`` setting, this is
            either an aggregated feature vector or a list of intermediate
            feature maps.
        """
        encoder = self.encoder
        return encoder(tensor)
|