"""Hugging Face ``PreTrainedModel`` wrapper around the CoralBay encoder."""

import torch
from transformers import PreTrainedModel

from configuration_coralbay import CoralBayConfig
from swin_unetr import SwinUNETREncoder


class CoralBayModel(PreTrainedModel):
    """Expose :class:`SwinUNETREncoder` through the Hugging Face model API.

    The wrapper owns a single ``encoder`` submodule built from the config.
    Every forward call is handed straight to that encoder.

    Args:
        config: A :class:`CoralBayConfig` holding the encoder hyperparameters.
    """

    config_class = CoralBayConfig

    def __init__(self, config) -> None:
        super().__init__(config)
        # Config attribute names are forwarded verbatim as keyword arguments
        # of the same name to SwinUNETREncoder.
        encoder_fields = (
            "in_channels",
            "feature_size",
            "spatial_dims",
            "out_indices",
            "embeddings_type",
            "use_v2",
        )
        encoder_kwargs = {field: getattr(config, field) for field in encoder_fields}
        self.encoder = SwinUNETREncoder(**encoder_kwargs)
        # Hugging Face hook run once all submodules exist.
        self.post_init()

    def forward(self, tensor: torch.Tensor) -> torch.Tensor | list[torch.Tensor]:
        """Run the wrapped encoder on ``tensor``.

        Args:
            tensor: Encoder input. It is documented as ``(B, C, T, H, W)``.
                The expected rank presumably follows ``config.spatial_dims``;
                confirm this against SwinUNETREncoder.

        Returns:
            Whatever the encoder returns. It is typed as a single tensor or a
            list of tensors, depending on the encoder's ``out_indices`` setting.
        """
        encoded = self.encoder(tensor)
        return encoded