| """ |
| Soft-Equivariant ViT for Semantic Segmentation. |
| |
| Architecture (state-dict compatible with training checkpoints): |
| * self.vit – ViTModel backbone (no classification head) |
| * self.classifier – nn.Conv2d(hidden_size, num_labels, kernel_size=1) segmentation head |
| * Logits are bilinearly upsampled to the label resolution when labels are given, otherwise to the input resolution. |
| |
| Filter buffers are stored in model.safetensors alongside learnable weights. |
| """ |
|
|
| import math |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import PreTrainedModel, ViTConfig, ViTModel |
| from transformers.modeling_outputs import SemanticSegmenterOutput |
|
|
| from .configuration_softeq import SoftEqConfig |
| from .filtered_layers_vit import monkeypatch_vitembeddings, monkeypatch_vitoutput |
|
|
|
|
class FilteredViTSeg(PreTrainedModel):
    """
    Soft-Equivariant Vision Transformer for semantic segmentation.

    Uses a linear-probing style head: the ViT patch tokens (the leading CLS
    token is dropped) are reshaped to a spatial grid, classified by a 1x1
    Conv, then bilinearly upsampled to the label resolution when labels are
    given, otherwise to the input resolution.

    The submodule names ``vit`` and ``classifier`` must stay fixed so that the
    state dict remains compatible with existing training checkpoints.
    """

    config_class = SoftEqConfig

    def __init__(self, config: SoftEqConfig):
        """
        Build the backbone, the segmentation head and the optional filters.

        Args:
            config: Model configuration. Reads ``pretrained_model`` (source of
                the backbone ``ViTConfig``), ``num_labels``, ``ignore_index``,
                the ``filter_*`` / ``attention_output_filter_list`` switches
                and the ``freeze_*`` switches.
        """
        super().__init__(config)

        print(f"Loading ViT segmentation backbone config from: {config.pretrained_model}")
        vit_config = ViTConfig.from_pretrained(config.pretrained_model)

        # Backbone only: segmentation consumes per-patch tokens, so no pooler.
        self.vit = ViTModel(vit_config, add_pooling_layer=False)
        self.ignore_index = config.ignore_index

        hidden_size = vit_config.hidden_size
        self.classifier = nn.Conv2d(hidden_size, config.num_labels, kernel_size=1, bias=True)
        nn.init.normal_(self.classifier.weight, std=0.02)
        nn.init.zeros_(self.classifier.bias)

        # Pixels whose label equals ignore_index do not contribute to the loss.
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=self.ignore_index)

        filter_config = config._make_filter_config()
        # Presumably forces filter tensors to be created on CPU regardless of
        # any ambient default device; they are later loaded from the checkpoint.
        with torch.device('cpu'):
            if config.filter_patch_embeddings:
                print("Applying soft-equivariant filter to patch embeddings...")
                monkeypatch_vitembeddings(self.vit.embeddings, filter_config)

            if config.filter_attention_output:
                # Read after the embedding patch above so the count reflects
                # the (possibly patched) embeddings module.
                num_patches = self.vit.embeddings.patch_embeddings.num_patches
                filtered_layers = set(config.attention_output_filter_list or [])
                for layer_idx, layer in enumerate(self.vit.encoder.layer):
                    if layer_idx in filtered_layers:
                        monkeypatch_vitoutput(
                            layer.attention.output, filter_config, num_patches
                        )

        if config.freeze_patch_embeddings:
            proj = self.vit.embeddings.patch_embeddings.projection
            # NOTE(review): only the projection weight is frozen; its bias
            # (if any) stays trainable — confirm this is intended.
            if hasattr(proj, "weight"):
                proj.weight.requires_grad = False
        if config.freeze_position_embeddings:
            if hasattr(self.vit.embeddings, "position_embeddings"):
                self.vit.embeddings.position_embeddings.requires_grad = False

        self.post_init()

    @staticmethod
    def _tokens_to_feature_map(patch_tokens, input_height, input_width, patch_size):
        """
        Reshape a ``(batch, num_patches, hidden)`` token tensor to a
        ``(batch, hidden, grid_h, grid_w)`` feature map.

        The grid is derived from the input size and the patch size, so
        non-square inputs are supported. If that derivation does not match
        the token count (e.g. a patched embedding changed it), fall back to a
        square grid, as earlier versions always did.

        Raises:
            ValueError: if the token count fits neither layout.
        """
        batch_size, num_patches, hidden_size = patch_tokens.shape
        if isinstance(patch_size, (tuple, list)):
            patch_h, patch_w = patch_size
        else:
            patch_h = patch_w = patch_size

        grid_h, grid_w = input_height // patch_h, input_width // patch_w
        if grid_h * grid_w != num_patches:
            side = math.isqrt(num_patches)
            if side * side != num_patches:
                raise ValueError(
                    f"Cannot arrange {num_patches} patch tokens into a grid for an "
                    f"input of {input_height}x{input_width} with patch size "
                    f"{patch_h}x{patch_w}."
                )
            grid_h = grid_w = side

        # Assumes tokens are laid out row-major over the patch grid.
        return patch_tokens.permute(0, 2, 1).reshape(batch_size, hidden_size, grid_h, grid_w)

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Compute per-pixel class logits and, when labels are given, the loss.

        Args:
            pixel_values: Image batch of shape ``(batch, channels, height, width)``.
            labels: Optional integer class map whose last two dims give the
                target spatial size; entries equal to ``ignore_index`` are
                excluded from the loss.
            output_hidden_states: Return backbone hidden states. ``None``
                falls back to ``self.config.output_hidden_states``.
            output_attentions: Return backbone attentions. ``None`` falls back
                to ``self.config.output_attentions``.
            return_dict: Return a ``SemanticSegmenterOutput`` instead of a
                tuple. ``None`` falls back to ``self.config.use_return_dict``.

        Returns:
            ``SemanticSegmenterOutput``, or a tuple
            ``([loss,] logits, [hidden_states], [attentions])``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Resolve against *this* model's config: self.vit carries its own
        # ViTConfig, so passing None through would silently ignore values set
        # on SoftEqConfig and let the tuple/dict outputs disagree.
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )

        input_height, input_width = pixel_values.shape[-2:]

        outputs = self.vit(
            pixel_values=pixel_values,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=True,
        )

        # Drop the leading [CLS] token; only patch tokens map to image locations.
        patch_tokens = outputs.last_hidden_state[:, 1:, :]
        feature_map = self._tokens_to_feature_map(
            patch_tokens, input_height, input_width, self.vit.config.patch_size
        )

        logits = self.classifier(feature_map)
        target_size = labels.shape[-2:] if labels is not None else (input_height, input_width)
        logits = F.interpolate(logits, size=target_size, mode="bilinear", align_corners=False)

        loss = self.loss_fct(logits, labels) if labels is not None else None

        if not return_dict:
            # ModelOutput slicing skips None fields, so this appends only the
            # hidden states / attentions that were actually requested.
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SemanticSegmenterOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|