ashiq24's picture
Upload folder using huggingface_hub
f1fa947 verified
Raw
History Blame Contribute Delete
8.78 kB
"""
Filtered layer wrappers for DINOv2 models.
These functions monkey-patch DINOv2 embeddings to apply soft equivariance
filters to position embeddings and patch projections.
Supports both standard DINOv2 and DINOv2 with register tokens.
"""
import os
import sys
import importlib
import torch
from typing import Optional
from types import MethodType
import math
def _ensure_softeq():
"""Make the bundled softeq package importable."""
_dir = os.path.dirname(os.path.abspath(__file__))
if os.path.isdir(os.path.join(_dir, "softeq")):
if _dir not in sys.path:
sys.path.insert(0, _dir)
return
if importlib.util.find_spec("softeq") is not None:
return
from huggingface_hub import snapshot_download
parts = os.path.normpath(_dir).split(os.sep)
idx = next((i for i, p in enumerate(parts) if p == "transformers_modules"), None)
if idx is None or idx + 2 >= len(parts):
raise ImportError("Cannot locate the bundled softeq package.")
owner = parts[idx + 1]
repo = parts[idx + 2].replace("_hyphen_", "-").replace("_dot_", ".")
revision = parts[idx + 3] if idx + 3 < len(parts) else None
snap = snapshot_download(
f"{owner}/{repo}", revision=revision, allow_patterns=["softeq/**"],
)
sys.path.insert(0, snap)
_ensure_softeq()
FilteredConv2d = importlib.import_module("softeq.layers.fconv2d").FilteredConv2d
get_invariant_filter = importlib.import_module("softeq.equi_utils.filter_factory").get_invariant_filter
def custom_dinov2embeddings_forward(
self,
pixel_values: torch.Tensor,
bool_masked_pos: Optional[torch.BoolTensor] = None,
) -> torch.Tensor:
"""
Forward pass for standard DINOv2 embeddings with filtered patch projection (including positional embeddings projection).
Args:
pixel_values: Input images of shape (B, C, H, W).
bool_masked_pos: Optional patch mask used for masked-token replacement.
Returns:
Embedded token sequence with CLS token prepended and filtered positional encoding added.
"""
batch_size, num_channels, height, width = pixel_values.shape
target_dtype = self.patch_embeddings.projection.weight.dtype
embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
if bool_masked_pos is not None and self.use_mask_token:
embeddings = torch.where(
bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings
)
# add the [CLS] token to the embedded patch tokens
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
self.position_embeddings.data[:, 1:, :] = self.filter_pos.smooth(
self.position_embeddings.data[:, 1:, :]
)
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
embeddings = self.dropout(embeddings)
return embeddings
def custom_dinov2_with_registers_embeddings_forward(
self,
pixel_values: torch.Tensor,
bool_masked_pos: Optional[torch.BoolTensor] = None,
) -> torch.Tensor:
"""
Forward pass for DINOv2 with register tokens and filtered patch projection (including positional embeddings projection).
Args:
pixel_values: Input images of shape (B, C, H, W).
bool_masked_pos: Optional patch mask used for masked-token replacement.
Returns:
Embedded token sequence with CLS, register tokens, and patch tokens.
"""
batch_size, num_channels, height, width = pixel_values.shape
target_dtype = self.patch_embeddings.projection.weight.dtype
embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
if bool_masked_pos is not None:
embeddings = torch.where(
bool_masked_pos.unsqueeze(-1), self.mask_token.to(embeddings.dtype).unsqueeze(0), embeddings
)
# add the [CLS] token to the embedded patch tokens
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
# add positional encoding to each token
# Apply filtering to positional embeddings - only to patch token positions
if hasattr(self, 'position_embeddings'):
# Only filter patch positions (skip CLS position)
self.position_embeddings.data[:, 1:, :] = self.filter_pos.smooth(
self.position_embeddings.data[:, 1:, :]
)
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
# add register tokens
embeddings = torch.cat(
(embeddings[:, :1], self.register_tokens.expand(embeddings.shape[0], -1, -1), embeddings[:, 1:]), dim=1
)
embeddings = self.dropout(embeddings)
return embeddings
def monkeypatch_dinov2embeddings(dinov2embeddings, filter_configs):
"""
Monkey patch DINOv2 embeddings to apply filters.
Supports both standard DINOv2 and DINOv2 with register tokens.
- Standard DINOv2: No register tokens
- DINOv2 with registers: Has register tokens (typically 4)
Register tokens should NOT be filtered as they are learned representations.
Args:
dinov2embeddings: DINOv2 embeddings module.
filter_configs: Dictionary with filter configuration, including:
- n_rotations: Number of discrete rotations.
- soft_thresholding: Softness for patch embedding filter.
- soft_thresholding_pos: Softness for position embedding filter.
- decomposition_method: 'svd' or 'schur'.
- group_type: 'rotation' or 'roto_reflection'.
- hard_mask: Use hard mask for smoothing.
- preserve_norm: Preserve weight norms after projection.
- joint_decomposition: Use joint decomposition for multi-generator groups.
"""
# Determine DINOv2 variant
is_register_variant = 'WithRegisters' in dinov2embeddings.__class__.__name__
group_type = filter_configs.get("group_type", "rotation")
# Filter patch embeddings (same for both variants)
original_conv = dinov2embeddings.patch_embeddings.projection
kernel_size = original_conv.kernel_size[0]
# assert square kernel
assert kernel_size == original_conv.kernel_size[1], "Kernel size is not square"
filter = get_invariant_filter(
group_type=group_type,
n_rotations=filter_configs["n_rotations"],
input_size=(1, kernel_size , kernel_size),
soft_threshold=filter_configs["soft_thresholding"],
decomposition_method=filter_configs["decomposition_method"],
debug=False,
hard_mask=filter_configs.get("hard_mask", False),
preserve_norm=filter_configs.get("preserve_norm", False),
joint_decomposition=filter_configs.get("joint_decomposition", True)
)
dinov2embeddings.patch_embeddings.projection = FilteredConv2d(dinov2embeddings.patch_embeddings.projection, filter)
# Handle positional embeddings
if hasattr(dinov2embeddings, 'position_embeddings') and dinov2embeddings.position_embeddings is not None:
# Filter positional embeddings for patch tokens only
num_patches = dinov2embeddings.patch_embeddings.num_patches
# Check if num_patches is a perfect square
if math.sqrt(num_patches).is_integer():
len_pos = int(math.sqrt(num_patches))
filter_pos = get_invariant_filter(
group_type=group_type,
n_rotations=filter_configs["n_rotations"],
input_size=(1, len_pos, len_pos),
soft_threshold=filter_configs["soft_thresholding_pos"],
decomposition_method=filter_configs["decomposition_method"],
debug=False,
hard_mask=filter_configs.get("hard_mask", False),
preserve_norm=filter_configs.get("preserve_norm", False),
joint_decomposition=filter_configs.get("joint_decomposition", True)
)
dinov2embeddings.filter_pos = filter_pos
else:
print(f"Warning: DINOv2 num_patches ({num_patches}) is not a perfect square. Skipping positional embedding filtering.")
else:
print("DINOv2 model does not have traditional positional embeddings.")
# Replace forward method with appropriate custom implementation
if is_register_variant:
dinov2embeddings.forward = MethodType(custom_dinov2_with_registers_embeddings_forward, dinov2embeddings)
else:
dinov2embeddings.forward = MethodType(custom_dinov2embeddings_forward, dinov2embeddings)
print("Applied standard DINOv2 forward method")