""" DDPMCDPipeline for change detection. pipeline.py is in the repo — use custom_pipeline="pipeline" (relative path). Usage:: from diffusers import DiffusionPipeline pipe = DiffusionPipeline.from_pretrained( "BiliSakura/ddpm-cd", custom_pipeline="pipeline", trust_remote_code=True, cd_head_subfolder="levir-50-100", ) change_map = pipe(image_A, image_B, timesteps=[50, 100]) """ import json import math import os from inspect import isfunction import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from diffusers import DDPMScheduler from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.models.modeling_utils import ModelMixin # ModelMixin subclasses nn.Module from diffusers.pipelines.pipeline_utils import DiffusionPipeline from tqdm.auto import tqdm # =========================================================================== # UNet (SR3-style) - all components inlined # =========================================================================== def _exists(x): return x is not None def _default(val, d): if _exists(val): return val return d() if isfunction(d) else d class PositionalEncoding(nn.Module): def __init__(self, dim): super().__init__() self.dim = dim def forward(self, noise_level): count = self.dim // 2 step = torch.arange(count, dtype=noise_level.dtype, device=noise_level.device) / count encoding = noise_level.unsqueeze(1) * torch.exp(-math.log(1e4) * step.unsqueeze(0)) return torch.cat([torch.sin(encoding), torch.cos(encoding)], dim=-1) class FeatureWiseAffine(nn.Module): def __init__(self, in_channels, out_channels, use_affine_level=False): super().__init__() self.use_affine_level = use_affine_level self.noise_func = nn.Sequential(nn.Linear(in_channels, out_channels * (1 + self.use_affine_level))) def forward(self, x, noise_embed): batch = x.shape[0] if self.use_affine_level: gamma, beta = self.noise_func(noise_embed).view(batch, -1, 1, 1).chunk(2, dim=1) x = (1 + gamma) * x + beta else: x = x + self.noise_func(noise_embed).view(batch, -1, 1, 1) return x class Swish(nn.Module): def forward(self, x): return x * torch.sigmoid(x) class Upsample(nn.Module): def __init__(self, dim): super().__init__() self.up = nn.Upsample(scale_factor=2, mode="nearest") self.conv = nn.Conv2d(dim, dim, 3, padding=1) def forward(self, x): return self.conv(self.up(x)) class Downsample(nn.Module): def __init__(self, dim): super().__init__() self.conv = nn.Conv2d(dim, dim, 3, 2, 1) def forward(self, x): return self.conv(x) class Block(nn.Module): def __init__(self, dim, dim_out, groups=32, dropout=0): super().__init__() self.block = nn.Sequential( nn.GroupNorm(groups, dim), Swish(), nn.Dropout(dropout) if dropout != 0 else nn.Identity(), nn.Conv2d(dim, dim_out, 3, padding=1), ) def forward(self, x): return self.block(x) class ResnetBlock(nn.Module): def __init__(self, dim, dim_out, noise_level_emb_dim=None, dropout=0, use_affine_level=False, norm_groups=32): super().__init__() self.noise_func = FeatureWiseAffine(noise_level_emb_dim, dim_out, use_affine_level) self.block1 = Block(dim, dim_out, groups=norm_groups) self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout) self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() def forward(self, x, time_emb): h = self.block1(x) h = self.noise_func(h, time_emb) h = self.block2(h) return h + self.res_conv(x) class SelfAttention(nn.Module): def __init__(self, in_channel, n_head=1, norm_groups=32): super().__init__() self.n_head = n_head self.norm = nn.GroupNorm(norm_groups, in_channel) self.qkv = nn.Conv2d(in_channel, in_channel * 3, 1, bias=False) self.out = nn.Conv2d(in_channel, in_channel, 1) def forward(self, input): batch, channel, height, width = input.shape n_head, head_dim = self.n_head, channel // self.n_head norm = self.norm(input) qkv = self.qkv(norm).view(batch, n_head, head_dim * 3, height, width) query, key, value = qkv.chunk(3, dim=2) attn = torch.einsum("bnchw, bncyx -> bnhwyx", query, key).contiguous() / math.sqrt(channel) attn = torch.softmax(attn.view(batch, n_head, height, width, -1), -1) attn = attn.view(batch, n_head, height, width, height, width) out = torch.einsum("bnhwyx, bncyx -> bnchw", attn, value).contiguous() return self.out(out.view(batch, channel, height, width)) + input class ResnetBlocWithAttn(nn.Module): def __init__(self, dim, dim_out, *, noise_level_emb_dim=None, norm_groups=32, dropout=0, with_attn=False): super().__init__() self.with_attn = with_attn self.res_block = ResnetBlock(dim, dim_out, noise_level_emb_dim, norm_groups=norm_groups, dropout=dropout) self.attn = SelfAttention(dim_out, norm_groups=norm_groups) if with_attn else None def forward(self, x, time_emb): x = self.res_block(x, time_emb) if self.with_attn: x = self.attn(x) return x class UNet(ModelMixin, ConfigMixin): """SR3-style UNet with noise-level conditioning. Supports feat_need=True for intermediate features.""" @register_to_config def __init__( self, in_channel=6, out_channel=3, inner_channel=32, norm_groups=32, channel_mults=(1, 2, 4, 8, 8), attn_res=(8,), res_blocks=3, dropout=0, with_noise_level_emb=True, image_size=128, ): super().__init__() noise_level_channel = inner_channel if with_noise_level_emb else None self.noise_level_mlp = ( nn.Sequential( PositionalEncoding(inner_channel), nn.Linear(inner_channel, inner_channel * 4), Swish(), nn.Linear(inner_channel * 4, inner_channel), ) if with_noise_level_emb else None ) num_mults = len(channel_mults) pre_channel, feat_channels, now_res = inner_channel, [inner_channel], image_size self.init_conv = nn.Conv2d(in_channel, inner_channel, 3, padding=1) downs = [] for ind in range(num_mults): use_attn = now_res in attn_res channel_mult = inner_channel * channel_mults[ind] for _ in range(res_blocks): downs.append( ResnetBlocWithAttn( pre_channel, channel_mult, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, dropout=dropout, with_attn=use_attn, ) ) feat_channels.append(channel_mult) pre_channel = channel_mult if ind < num_mults - 1: downs.append(Downsample(pre_channel)) feat_channels.append(pre_channel) now_res = now_res // 2 self.downs = nn.ModuleList(downs) self.mid = nn.ModuleList([ ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, dropout=dropout, with_attn=True), ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, dropout=dropout, with_attn=False), ]) ups = [] for ind in reversed(range(num_mults)): use_attn = now_res in attn_res channel_mult = inner_channel * channel_mults[ind] for _ in range(res_blocks + 1): ups.append( ResnetBlocWithAttn( pre_channel + feat_channels.pop(), channel_mult, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, dropout=dropout, with_attn=use_attn, ) ) pre_channel = channel_mult if ind > 0: ups.append(Upsample(pre_channel)) now_res = now_res * 2 self.ups = nn.ModuleList(ups) self.final_conv = Block(pre_channel, _default(out_channel, lambda: in_channel), groups=norm_groups) def forward(self, x, time, feat_need=False): t = self.noise_level_mlp(time) if _exists(self.noise_level_mlp) else None x = self.init_conv(x) feats = [x] for layer in self.downs: x = layer(x, t) if isinstance(layer, ResnetBlocWithAttn) else layer(x) feats.append(x) fe = feats.copy() if feat_need else None for layer in self.mid: x = layer(x, t) if isinstance(layer, ResnetBlocWithAttn) else layer(x) fd = [] if feat_need else None for layer in self.ups: if isinstance(layer, ResnetBlocWithAttn): x = layer(torch.cat((x, feats.pop()), dim=1), t) if feat_need: fd.append(x) else: x = layer(x) x = self.final_conv(x) return (fe, list(reversed(fd))) if feat_need else x # =========================================================================== # Change detection head # =========================================================================== class ChannelSELayer(nn.Module): def __init__(self, num_channels, reduction_ratio=2): super().__init__() reduced = num_channels // reduction_ratio self.fc1 = nn.Linear(num_channels, reduced, bias=True) self.fc2 = nn.Linear(reduced, num_channels, bias=True) self.relu, self.sigmoid = nn.ReLU(), nn.Sigmoid() def forward(self, x): b, c, _, _ = x.size() s = x.view(b, c, -1).mean(dim=2) s = self.sigmoid(self.fc2(self.relu(self.fc1(s)))).view(b, c, 1, 1) return x * s class SpatialSELayer(nn.Module): def __init__(self, num_channels): super().__init__() self.conv = nn.Conv2d(num_channels, 1, 1) self.sigmoid = nn.Sigmoid() def forward(self, x, weights=None): b, c, h, w = x.size() out = F.conv2d(x, weights.view(1, c, 1, 1)) if weights is not None else self.conv(x) return x * self.sigmoid(out).view(b, 1, h, w) class ChannelSpatialSELayer(nn.Module): def __init__(self, num_channels, reduction_ratio=2): super().__init__() self.cSE = ChannelSELayer(num_channels, reduction_ratio) self.sSE = SpatialSELayer(num_channels) def forward(self, x): return self.cSE(x) + self.sSE(x) def _get_in_channels(feat_scales, inner_channel, channel_multiplier): m, cm = inner_channel, channel_multiplier r = 0 for s in feat_scales: if s < 3: r += m * cm[0] elif s < 6: r += m * cm[1] elif s < 9: r += m * cm[2] elif s < 12: r += m * cm[3] elif s < 15: r += m * cm[4] else: raise ValueError("feat_scales 0<=s<=14") return r class AttentionBlock(nn.Module): def __init__(self, dim, dim_out): super().__init__() self.block = nn.Sequential( nn.Conv2d(dim, dim_out, 3, padding=1), nn.ReLU(), ChannelSpatialSELayer(dim_out, 2), ) def forward(self, x): return self.block(x) class CDBlock(nn.Module): def __init__(self, dim, dim_out, time_steps): super().__init__() if len(time_steps) > 1: self.block = nn.Sequential( nn.Conv2d(dim * len(time_steps), dim, 1), nn.ReLU(), nn.Conv2d(dim, dim_out, 3, padding=1), nn.ReLU(), ) else: self.block = nn.Sequential(nn.Conv2d(dim, dim_out, 3, padding=1), nn.ReLU()) def forward(self, x): return self.block(x) class cd_head_v2(nn.Module): """Change detection head (version 2).""" def __init__(self, feat_scales, out_channels=2, inner_channel=None, channel_multiplier=None, img_size=256, time_steps=None): super().__init__() self.feat_scales = sorted(list(feat_scales), reverse=True) self.in_channels = _get_in_channels(self.feat_scales, inner_channel, channel_multiplier) self.img_size, self.time_steps = img_size, time_steps self.decoder = nn.ModuleList() for i in range(len(self.feat_scales)): dim = _get_in_channels([self.feat_scales[i]], inner_channel, channel_multiplier) self.decoder.append(CDBlock(dim, dim, time_steps)) if i < len(self.feat_scales) - 1: dim_out = _get_in_channels([self.feat_scales[i + 1]], inner_channel, channel_multiplier) self.decoder.append(AttentionBlock(dim, dim_out)) self.clfr_stg1 = nn.Conv2d(dim_out, 64, 3, padding=1) self.clfr_stg2 = nn.Conv2d(64, out_channels, 3, padding=1) self.relu = nn.ReLU() def forward(self, feats_A, feats_B): lvl, x = 0, None for layer in self.decoder: if isinstance(layer, CDBlock): f_A = feats_A[0][self.feat_scales[lvl]] f_B = feats_B[0][self.feat_scales[lvl]] if len(self.time_steps) > 1: for i in range(1, len(self.time_steps)): f_A = torch.cat((f_A, feats_A[i][self.feat_scales[lvl]]), dim=1) f_B = torch.cat((f_B, feats_B[i][self.feat_scales[lvl]]), dim=1) diff = torch.abs(layer(f_A) - layer(f_B)) if lvl > 0: diff = diff + x lvl += 1 else: diff = layer(diff) x = F.interpolate(diff, scale_factor=2, mode="bilinear") return self.clfr_stg2(self.relu(self.clfr_stg1(x))) # =========================================================================== # Diffusion utilities # =========================================================================== def _precompute_alpha_tables(scheduler): ac = scheduler.alphas_cumprod.numpy() return np.sqrt(np.append(1.0, ac)) def _q_sample(x_start, continuous_sqrt_alpha_cumprod, noise=None): if noise is None: noise = torch.randn_like(x_start) return continuous_sqrt_alpha_cumprod * x_start + (1 - continuous_sqrt_alpha_cumprod ** 2).sqrt() * noise @torch.no_grad() def _extract_features(model, x, t, sqrt_alphas): b = x.shape[0] lvl = torch.FloatTensor( np.random.uniform(sqrt_alphas[t - 1], sqrt_alphas[t], size=b) ).to(x.device).view(b, -1) noise = torch.randn_like(x) x_noisy = _q_sample(x, lvl.view(-1, 1, 1, 1), noise) return model(x_noisy, lvl, feat_need=True) # =========================================================================== # Pipeline # =========================================================================== class DDPMCDPipeline(DiffusionPipeline): """DDPM-based change detection. Load with trust_remote_code=True. For consolidated ddpm-cd repo with multiple cd_head variants, pass cd_head_subfolder (e.g. 'levir-50-100', 'whu-50-100-400', 'cdd-50-100', etc.) when loading.""" def __init__(self, unet, scheduler, cd_head=None, cd_head_subfolder=None): super().__init__() self.register_modules(unet=unet, scheduler=scheduler) self.cd_head = cd_head self._cd_head_subfolder = cd_head_subfolder # Infer base path from unet config (dirname of unet subfolder = model root) unet_path = getattr(getattr(unet, "config", None), "_name_or_path", None) self._cd_head_base_path = os.path.dirname(unet_path) if unet_path else None def _load_cd_head_if_needed(self): """Lazy-load cd_head from disk when first needed (path inferred from unet).""" if self.cd_head is not None: return base = self._cd_head_base_path if base is None: cfg = getattr(self.unet, "config", None) base = os.path.dirname(getattr(cfg, "_name_or_path", "")) if cfg else None if not base or not os.path.isdir(base): return # no cd_head (e.g. pretrained-only model) subfolder = self._cd_head_subfolder if subfolder: cd_dir = os.path.join(base, "cd_head", subfolder) else: cd_dir = os.path.join(base, "cd_head") if not os.path.isfile(os.path.join(cd_dir, "config.json")): # Consolidated repo: cd_head_subfolder is required subdirs = sorted([d for d in os.listdir(cd_dir) if os.path.isdir(os.path.join(cd_dir, d))]) raise RuntimeError( "DDPMCDPipeline requires cd_head_subfolder when loading from consolidated ddpm-cd repo. " f"Available: {subdirs}. Example: from_pretrained(..., cd_head_subfolder='levir-50-100')" ) if not os.path.isdir(cd_dir): return # no cd_head (e.g. pretrained-only model) with open(os.path.join(cd_dir, "config.json")) as f: cfg = json.load(f) ch = cd_head_v2(**cfg) for name in ("diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.bin"): p = os.path.join(cd_dir, name) if os.path.exists(p): if p.endswith(".safetensors"): from safetensors.torch import load_file ch.load_state_dict(load_file(p, device="cpu")) else: try: s = torch.load(p, map_location="cpu", weights_only=True) except TypeError: s = torch.load(p, map_location="cpu") ch.load_state_dict(s.state_dict() if hasattr(s, "state_dict") else s) break self.cd_head = ch def load_cd_head(self, pretrained_model_name_or_path=None, subfolder=None): """Manually load cd_head from the given path (or infer from unet). subfolder: e.g. 'levir-50-100', 'whu-50-100-400' for consolidated ddpm-cd repo.""" if pretrained_model_name_or_path: self._cd_head_base_path = pretrained_model_name_or_path if subfolder is not None: self._cd_head_subfolder = subfolder self._load_cd_head_if_needed() @classmethod def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): cd_head_subfolder = kwargs.pop("cd_head_subfolder", None) pipe = super().from_pretrained(pretrained_model_name_or_path, **kwargs) pipe._cd_head_base_path = pretrained_model_name_or_path if os.path.isdir(pretrained_model_name_or_path) else None pipe._cd_head_subfolder = cd_head_subfolder pipe._load_cd_head_if_needed() return pipe @torch.no_grad() def __call__(self, image_A, image_B, timesteps=None, feat_type="dec"): self._load_cd_head_if_needed() if self.cd_head is None: raise RuntimeError("DDPMCDPipeline requires cd_head. Could not load from disk.") timesteps = timesteps or [50, 100] sqrt_a = _precompute_alpha_tables(self.scheduler) feats_A, feats_B = [], [] for t in timesteps: fe_A, fd_A = _extract_features(self.unet, image_A, t, sqrt_a) fe_B, fd_B = _extract_features(self.unet, image_B, t, sqrt_a) feats_A.append(fd_A if feat_type == "dec" else fe_A) feats_B.append(fd_B if feat_type == "dec" else fe_B) return self.cd_head(feats_A, feats_B) @torch.no_grad() def generate(self, batch_size=1, in_channels=3, image_size=256, num_inference_steps=None, generator=None): device = next(self.unet.parameters()).device steps = num_inference_steps or self.scheduler.config.num_train_timesteps sqrt_a = _precompute_alpha_tables(self.scheduler) image = torch.randn((batch_size, in_channels, image_size, image_size), device=device, generator=generator) self.scheduler.set_timesteps(steps) for t in tqdm(self.scheduler.timesteps, desc="Sampling"): idx = min(int(t) + 1, len(sqrt_a) - 1) lvl = torch.FloatTensor([sqrt_a[idx]]).repeat(batch_size, 1).to(device) noise_pred = self.unet(image, lvl) image = self.scheduler.step(noise_pred, t, image).prev_sample return image