| """HuggingFace PreTrainedModel wrapper for the CoralBay encoder.""" |
|
|
| import torch |
| from transformers import PreTrainedModel |
|
|
| from configuration_coralbay import CoralBayConfig |
| from swin_unetr import SwinUNETREncoder |
|
|
|
|
class CoralBayModel(PreTrainedModel):
    """HuggingFace-compatible CoralBay model wrapping :class:`SwinUNETREncoder`.

    The encoder is built from the following fields of the supplied config:
    ``in_channels``, ``feature_size``, ``spatial_dims``, ``out_indices``,
    ``embeddings_type`` and ``use_v2``. :meth:`forward` passes its input
    straight to that encoder and returns the encoder's output unchanged.

    Args:
        config: A :class:`CoralBayConfig` instance.
    """

    config_class = CoralBayConfig

    # ``forward`` takes its principal input as ``tensor``, not the library
    # default ``input_ids``. Transformers looks up ``main_input_name`` in
    # batch dicts (e.g. ``estimate_tokens`` / ``floating_point_ops``, which
    # ``Trainer`` uses for FLOs accounting), so the name must match.
    main_input_name = "tensor"

    def __init__(self, config: CoralBayConfig) -> None:
        super().__init__(config)

        self.encoder = SwinUNETREncoder(
            in_channels=config.in_channels,
            feature_size=config.feature_size,
            spatial_dims=config.spatial_dims,
            out_indices=config.out_indices,
            embeddings_type=config.embeddings_type,
            use_v2=config.use_v2,
        )

        # Final construction step expected of PreTrainedModel subclasses;
        # it runs the library's post-construction (weight-init) hooks.
        self.post_init()

    def forward(self, tensor: torch.Tensor) -> torch.Tensor | list[torch.Tensor]:
        """Forward pass that delegates to the underlying encoder.

        Args:
            tensor: Input forwarded unchanged to ``self.encoder``. It is
                documented as shape ``(B, C, T, H, W)``. That presumably
                corresponds to ``spatial_dims=3``; TODO confirm against
                :class:`SwinUNETREncoder`.

        Returns:
            Whatever the encoder returns. Per the annotation, this is either
            a single tensor or a list of tensors. Which form you get
            presumably depends on the encoder's ``out_indices`` setting;
            verify in :class:`SwinUNETREncoder`.
        """
        return self.encoder(tensor)
|
|