# coralbay / modeling_coralbay.py
# Author: kaiko-ai-user
# Commit: "Update modeling_coralbay.py" (9fbba73, verified)
"""HuggingFace PreTrainedModel wrapper for the CoralBay encoder."""
import torch
from transformers import PreTrainedModel
from configuration_coralbay import CoralBayConfig
from swin_unetr import SwinUNETREncoder
class CoralBayModel(PreTrainedModel):
    """Expose :class:`SwinUNETREncoder` through the HuggingFace ``PreTrainedModel`` API.

    The wrapper owns a single submodule, ``self.encoder``, built from the
    hyperparameters stored on the config. ``forward`` hands its input
    straight to that encoder.

    Args:
        config: A :class:`CoralBayConfig` carrying the encoder settings
            (``in_channels``, ``feature_size``, ``spatial_dims``,
            ``out_indices``, ``embeddings_type``, ``use_v2``).
    """

    # Tells ``PreTrainedModel`` (e.g. ``from_pretrained``) which config type to use.
    config_class = CoralBayConfig

    def __init__(self, config) -> None:
        super().__init__(config)

        # Every encoder hyperparameter is read from the config and passed by keyword.
        self.encoder = SwinUNETREncoder(
            in_channels=config.in_channels,
            feature_size=config.feature_size,
            spatial_dims=config.spatial_dims,
            out_indices=config.out_indices,
            embeddings_type=config.embeddings_type,
            use_v2=config.use_v2,
        )

        # HF hook that must run after all submodules exist (weight-init / finalization).
        self.post_init()

    def forward(self, tensor: torch.Tensor) -> torch.Tensor | list[torch.Tensor]:
        """Run the wrapped encoder on ``tensor`` and return its output unchanged.

        Args:
            tensor: Input batch for the encoder. The original documentation
                describes it as ``(B, C, T, H, W)``; the exact rank presumably
                follows ``config.spatial_dims`` (TODO confirm).

        Returns:
            Whatever :class:`SwinUNETREncoder` produces: an aggregated feature
            vector or a list of intermediate feature maps, depending on the
            encoder's ``out_indices`` setting.
        """
        features = self.encoder(tensor)
        return features