# ddpm-cd / pipeline.py
# BiliSakura's picture
# Add files using upload-large-folder tool
# 01176da verified
"""
DDPMCDPipeline for change detection.

pipeline.py ships inside the model repo, so load it with custom_pipeline="pipeline"
(a path relative to the repo root). image_A and image_B are fed to the UNet
unchanged; any resizing or normalisation must be done by the caller.

Usage::

    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "BiliSakura/ddpm-cd",
        custom_pipeline="pipeline",
        trust_remote_code=True,
        cd_head_subfolder="levir-50-100",
    )
    change_map = pipe(image_A, image_B, timesteps=[50, 100])
"""
import json
import math
import os
from inspect import isfunction
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import DDPMScheduler
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin # ModelMixin subclasses nn.Module
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from tqdm.auto import tqdm
# ===========================================================================
# UNet (SR3-style) - all components inlined
# ===========================================================================
def _exists(x):
    """Report whether *x* holds a value, i.e. is anything other than ``None``."""
    if x is None:
        return False
    return True
def _default(val, d):
    """Return *val* when it is set, otherwise the fallback *d*.

    If *d* is a plain Python function (``inspect.isfunction``), it is called
    lazily to produce the fallback. Any other object is returned unchanged.
    """
    if _exists(val):
        return val
    if isfunction(d):
        return d()
    return d
class PositionalEncoding(nn.Module):
    """Sinusoidal embedding of a continuous noise level.

    A new axis is inserted at position 1 of the input and broadcast against
    ``dim // 2`` frequencies ``10000 ** (-k / (dim // 2))``. The result holds
    the sines followed by the cosines, so a 1-D input of shape ``(B,)`` yields
    ``(B, 2 * (dim // 2))``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, noise_level):
        half = self.dim // 2
        exponents = torch.arange(half, dtype=noise_level.dtype, device=noise_level.device) / half
        freqs = torch.exp(-math.log(1e4) * exponents.unsqueeze(0))
        angles = noise_level.unsqueeze(1) * freqs
        return torch.cat([angles.sin(), angles.cos()], dim=-1)
class FeatureWiseAffine(nn.Module):
    """Inject a noise-level embedding into a 4-D feature map.

    ``noise_func`` projects the embedding to ``out_channels`` values per sample,
    or ``2 * out_channels`` when ``use_affine_level`` is set. The projection is
    reshaped to ``(B, -1, 1, 1)``. It is then either added to ``x``, or split
    into ``(gamma, beta)`` and applied as ``(1 + gamma) * x + beta``.
    """

    def __init__(self, in_channels, out_channels, use_affine_level=False):
        super().__init__()
        self.use_affine_level = use_affine_level
        width = out_channels * (1 + self.use_affine_level)
        self.noise_func = nn.Sequential(nn.Linear(in_channels, width))

    def forward(self, x, noise_embed):
        projected = self.noise_func(noise_embed).view(x.shape[0], -1, 1, 1)
        if not self.use_affine_level:
            return x + projected
        gamma, beta = projected.chunk(2, dim=1)
        return (1 + gamma) * x + beta
class Swish(nn.Module):
    """Swish / SiLU activation, ``x * sigmoid(x)``."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
class Upsample(nn.Module):
    """Double H and W with nearest-neighbour interpolation, then apply a 3x3 conv."""

    def __init__(self, dim):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="nearest")
        self.conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1)

    def forward(self, x):
        upsampled = self.up(x)
        return self.conv(upsampled)
class Downsample(nn.Module):
    """Stride-2 3x3 convolution (padding 1) that reduces H and W by a factor of two."""

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv2d(dim, dim, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)
class Block(nn.Module):
    """GroupNorm -> Swish -> (Dropout or Identity) -> 3x3 conv.

    The Sequential layout (index 3 is the conv) is kept fixed because it
    determines the checkpoint parameter names.
    """

    def __init__(self, dim, dim_out, groups=32, dropout=0):
        super().__init__()
        drop = nn.Identity() if dropout == 0 else nn.Dropout(dropout)
        layers = [
            nn.GroupNorm(groups, dim),
            Swish(),
            drop,
            nn.Conv2d(dim, dim_out, 3, padding=1),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)
class ResnetBlock(nn.Module):
    """Two-conv residual block with optional noise-level conditioning.

    Args:
        dim: input channels.
        dim_out: output channels.
        noise_level_emb_dim: width of the noise-level embedding. ``None``
            disables conditioning, and ``time_emb`` is then ignored.
        dropout: dropout probability used inside ``block2``.
        use_affine_level: forwarded to :class:`FeatureWiseAffine`.
        norm_groups: GroupNorm groups for both conv blocks.

    The output is ``block2(cond(block1(x))) + res_conv(x)``, where ``res_conv``
    is a 1x1 conv when the channel count changes and an identity otherwise.
    """

    def __init__(self, dim, dim_out, noise_level_emb_dim=None, dropout=0, use_affine_level=False, norm_groups=32):
        super().__init__()
        # nn.Linear cannot be built with in_features=None, which is what
        # UNet(with_noise_level_emb=False) passes. Only create the projection when
        # an embedding width exists. The attribute name is unchanged, so existing
        # checkpoints keep their parameter keys.
        self.noise_func = (
            FeatureWiseAffine(noise_level_emb_dim, dim_out, use_affine_level)
            if noise_level_emb_dim is not None
            else None
        )
        self.block1 = Block(dim, dim_out, groups=norm_groups)
        self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb):
        h = self.block1(x)
        if self.noise_func is not None:
            h = self.noise_func(h, time_emb)
        h = self.block2(h)
        return h + self.res_conv(x)
class SelfAttention(nn.Module):
    """Spatial self-attention over all H*W positions, with a residual connection.

    ``qkv`` produces queries, keys and values in one 1x1 conv. The scores are
    scaled by ``1 / sqrt(channel)``. NOTE: this uses the full channel count,
    not the per-head width, which only differs when ``n_head > 1``.
    """

    def __init__(self, in_channel, n_head=1, norm_groups=32):
        super().__init__()
        self.n_head = n_head
        self.norm = nn.GroupNorm(norm_groups, in_channel)
        self.qkv = nn.Conv2d(in_channel, in_channel * 3, 1, bias=False)
        self.out = nn.Conv2d(in_channel, in_channel, 1)

    def forward(self, input):
        b, c, h, w = input.shape
        heads = self.n_head
        per_head = c // heads
        normed = self.norm(input)
        stacked = self.qkv(normed).view(b, heads, per_head * 3, h, w)
        q, k, v = stacked.chunk(3, dim=2)
        scores = torch.einsum("bnchw, bncyx -> bnhwyx", q, k).contiguous()
        scores = scores / math.sqrt(c)
        # Softmax jointly over every key position (y, x) for each query position.
        weights = torch.softmax(scores.view(b, heads, h, w, -1), -1)
        weights = weights.view(b, heads, h, w, h, w)
        mixed = torch.einsum("bnhwyx, bncyx -> bnchw", weights, v).contiguous()
        return self.out(mixed.view(b, c, h, w)) + input
class ResnetBlocWithAttn(nn.Module):
    """A :class:`ResnetBlock`, optionally followed by :class:`SelfAttention`."""

    def __init__(self, dim, dim_out, *, noise_level_emb_dim=None, norm_groups=32, dropout=0, with_attn=False):
        super().__init__()
        self.with_attn = with_attn
        self.res_block = ResnetBlock(dim, dim_out, noise_level_emb_dim, norm_groups=norm_groups, dropout=dropout)
        if with_attn:
            self.attn = SelfAttention(dim_out, norm_groups=norm_groups)
        else:
            self.attn = None

    def forward(self, x, time_emb):
        out = self.res_block(x, time_emb)
        return self.attn(out) if self.with_attn else out
class UNet(ModelMixin, ConfigMixin):
    """SR3-style UNet with noise-level conditioning. Supports feat_need=True for intermediate features.

    The ``@register_to_config`` decorator stores the constructor arguments in
    the diffusers config. Sub-modules are appended in a fixed order, which
    determines the ``downs.N`` / ``ups.N`` parameter names and the order in
    which weights are randomly initialised. Do not reorder the construction
    if existing checkpoints are to keep loading.
    """
    @register_to_config
    def __init__(
        self,
        in_channel=6,
        out_channel=3,
        inner_channel=32,
        norm_groups=32,
        channel_mults=(1, 2, 4, 8, 8),
        attn_res=(8,),
        res_blocks=3,
        dropout=0,
        with_noise_level_emb=True,
        image_size=128,
    ):
        super().__init__()
        # Width of the conditioning vector given to every residual block (None = unconditioned).
        noise_level_channel = inner_channel if with_noise_level_emb else None
        # Noise level -> sinusoidal encoding -> MLP (inner_channel -> 4*inner_channel -> inner_channel).
        self.noise_level_mlp = (
            nn.Sequential(
                PositionalEncoding(inner_channel),
                nn.Linear(inner_channel, inner_channel * 4),
                Swish(),
                nn.Linear(inner_channel * 4, inner_channel),
            )
            if with_noise_level_emb
            else None
        )
        num_mults = len(channel_mults)
        # feat_channels mirrors, in order, the channel count of every tensor that
        # forward() pushes onto its skip list. now_res tracks the current spatial
        # size, so attention is enabled exactly at the resolutions in attn_res.
        pre_channel, feat_channels, now_res = inner_channel, [inner_channel], image_size
        self.init_conv = nn.Conv2d(in_channel, inner_channel, 3, padding=1)
        downs = []
        for ind in range(num_mults):
            use_attn = now_res in attn_res
            channel_mult = inner_channel * channel_mults[ind]
            # res_blocks residual blocks per level; each output becomes a skip feature.
            for _ in range(res_blocks):
                downs.append(
                    ResnetBlocWithAttn(
                        pre_channel, channel_mult,
                        noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
                        dropout=dropout, with_attn=use_attn,
                    )
                )
                feat_channels.append(channel_mult)
                pre_channel = channel_mult
            # Every level except the deepest ends with a Downsample, whose output is also a skip feature.
            if ind < num_mults - 1:
                downs.append(Downsample(pre_channel))
                feat_channels.append(pre_channel)
                now_res = now_res // 2
        self.downs = nn.ModuleList(downs)
        # Bottleneck: one residual block with attention, one without.
        self.mid = nn.ModuleList([
            ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel,
                               norm_groups=norm_groups, dropout=dropout, with_attn=True),
            ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel,
                               norm_groups=norm_groups, dropout=dropout, with_attn=False),
        ])
        ups = []
        for ind in reversed(range(num_mults)):
            use_attn = now_res in attn_res
            channel_mult = inner_channel * channel_mults[ind]
            # res_blocks + 1 blocks per level. Each concatenates one skip feature,
            # popped in reverse push order, so every pushed skip is consumed exactly once.
            for _ in range(res_blocks + 1):
                ups.append(
                    ResnetBlocWithAttn(
                        pre_channel + feat_channels.pop(), channel_mult,
                        noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups,
                        dropout=dropout, with_attn=use_attn,
                    )
                )
                pre_channel = channel_mult
            if ind > 0:
                ups.append(Upsample(pre_channel))
                now_res = now_res * 2
        self.ups = nn.ModuleList(ups)
        # Output channels fall back to in_channel only when out_channel is None.
        self.final_conv = Block(pre_channel, _default(out_channel, lambda: in_channel), groups=norm_groups)

    def forward(self, x, time, feat_need=False):
        """Run the UNet.

        Args:
            x: input batch in NCHW layout, as consumed by ``init_conv``.
            time: noise level passed through ``noise_level_mlp``. The callers in
                this file pass a ``(B, 1)`` tensor of sqrt(alpha_cumprod) values.
            feat_need: when True, return intermediate features instead of the output.

        Returns:
            The ``final_conv`` output. When ``feat_need`` is True, returns a tuple
            ``(fe, fd)`` instead:

            - ``fe`` is the ``init_conv`` output followed by every ``downs`` layer
              output, from shallow to deep.
            - ``fd`` holds the output of every decoder ``ResnetBlocWithAttn``,
              reversed so that ``fd[0]`` comes from the last (highest-resolution)
              decoder block.

            ``final_conv`` still runs in that case, but its result is discarded.
        """
        t = self.noise_level_mlp(time) if _exists(self.noise_level_mlp) else None
        x = self.init_conv(x)
        feats = [x]
        for layer in self.downs:
            x = layer(x, t) if isinstance(layer, ResnetBlocWithAttn) else layer(x)
            feats.append(x)
        # Snapshot the encoder features before the decoder pops them off ``feats``.
        fe = feats.copy() if feat_need else None
        for layer in self.mid:
            x = layer(x, t) if isinstance(layer, ResnetBlocWithAttn) else layer(x)
        fd = [] if feat_need else None
        for layer in self.ups:
            if isinstance(layer, ResnetBlocWithAttn):
                x = layer(torch.cat((x, feats.pop()), dim=1), t)
                if feat_need:
                    fd.append(x)
            else:
                x = layer(x)
        x = self.final_conv(x)
        return (fe, list(reversed(fd))) if feat_need else x
# ===========================================================================
# Change detection head
# ===========================================================================
class ChannelSELayer(nn.Module):
    """Channel squeeze-and-excitation: rescale each channel by a learned gate.

    The gate is computed from each channel's spatial mean via
    ``fc1 -> ReLU -> fc2 -> sigmoid`` and is broadcast over H and W.
    """

    def __init__(self, num_channels, reduction_ratio=2):
        super().__init__()
        hidden = num_channels // reduction_ratio
        self.fc1 = nn.Linear(num_channels, hidden, bias=True)
        self.fc2 = nn.Linear(hidden, num_channels, bias=True)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        n, ch, _h, _w = x.size()
        squeezed = x.view(n, ch, -1).mean(dim=2)
        gate = self.relu(self.fc1(squeezed))
        gate = self.sigmoid(self.fc2(gate))
        return x * gate.view(n, ch, 1, 1)
class SpatialSELayer(nn.Module):
    """Spatial squeeze-and-excitation: gate every pixel by a sigmoid of a 1x1 projection.

    ``forward(x, weights=None)`` expects a 4-D ``x`` of shape ``(B, C, H, W)``.
    If ``weights`` (``C`` values) is given, it is used as a bias-free 1x1
    kernel instead of ``self.conv``.
    """

    def __init__(self, num_channels):
        super().__init__()
        self.conv = nn.Conv2d(num_channels, 1, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, weights=None):
        b, c, h, w = x.size()
        if weights is None:
            logits = self.conv(x)
        else:
            logits = F.conv2d(x, weights.view(1, c, 1, 1))
        gate = self.sigmoid(logits).view(b, 1, h, w)
        return x * gate
class ChannelSpatialSELayer(nn.Module):
    """Sum of a channel-wise and a spatial squeeze-and-excitation branch (scSE)."""

    def __init__(self, num_channels, reduction_ratio=2):
        super().__init__()
        self.cSE = ChannelSELayer(num_channels, reduction_ratio)
        self.sSE = SpatialSELayer(num_channels)

    def forward(self, x):
        channel_gated = self.cSE(x)
        spatial_gated = self.sSE(x)
        return channel_gated + spatial_gated
def _get_in_channels(feat_scales, inner_channel, channel_multiplier):
    """Return the summed channel count of the UNet features picked by *feat_scales*.

    Each scale index ``s`` in ``[0, 14]`` belongs to level ``s // 3``, with three
    consecutive indices per level. That level has
    ``inner_channel * channel_multiplier[s // 3]`` channels.

    Raises:
        ValueError: if any scale lies outside ``[0, 14]``.
    """
    total = 0
    for s in feat_scales:
        # Negative indices used to fall silently into level 0, although as list
        # indices they select a feature from the other end of the list.
        if not 0 <= s <= 14:
            raise ValueError("feat_scales 0<=s<=14")
        total += inner_channel * channel_multiplier[s // 3]
    return total
class AttentionBlock(nn.Module):
    """3x3 conv + ReLU followed by scSE attention, mapping ``dim`` to ``dim_out`` channels."""

    def __init__(self, dim, dim_out):
        super().__init__()
        conv = nn.Conv2d(dim, dim_out, 3, padding=1)
        self.block = nn.Sequential(conv, nn.ReLU(), ChannelSpatialSELayer(dim_out, 2))

    def forward(self, x):
        return self.block(x)
class CDBlock(nn.Module):
    """Per-scale feature projector used by the change-detection head.

    With several timesteps, the input is expected to have
    ``dim * len(time_steps)`` channels, and a 1x1 conv + ReLU first reduces it
    to ``dim``. A 3x3 conv + ReLU then produces ``dim_out`` channels.
    """

    def __init__(self, dim, dim_out, time_steps):
        super().__init__()
        n_steps = len(time_steps)
        layers = []
        if n_steps > 1:
            layers += [nn.Conv2d(dim * n_steps, dim, 1), nn.ReLU()]
        layers += [nn.Conv2d(dim, dim_out, 3, padding=1), nn.ReLU()]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)
class cd_head_v2(nn.Module):
    """Change detection head (version 2).

    Scales are processed from the largest index to the smallest. At each scale,
    a shared :class:`CDBlock` projects both images' features, and their
    absolute difference is added to the result carried over from the previous
    scale. Between scales, an :class:`AttentionBlock` maps to the next scale's
    width and the result is upsampled 2x (bilinear). A two-conv classifier then
    produces ``out_channels`` maps.

    Args:
        feat_scales: UNet feature indices (0..14). At least two are required.
        out_channels: channels of the classifier output.
        inner_channel, channel_multiplier: UNet width settings used to infer
            the channel count of each scale.
        img_size: stored on the instance; not used by this class.
        time_steps: diffusion timesteps whose features are concatenated per scale.

    Raises:
        ValueError: if fewer than two feature scales are given.
    """

    def __init__(self, feat_scales, out_channels=2, inner_channel=None, channel_multiplier=None, img_size=256, time_steps=None):
        super().__init__()
        self.feat_scales = sorted(list(feat_scales), reverse=True)
        if len(self.feat_scales) < 2:
            # With a single scale no AttentionBlock is built, so the classifier's
            # input width (dim_out) is undefined and forward() never sets ``x``.
            raise ValueError(f"cd_head_v2 needs at least two feat_scales, got {self.feat_scales}")
        self.in_channels = _get_in_channels(self.feat_scales, inner_channel, channel_multiplier)
        self.img_size, self.time_steps = img_size, time_steps
        # Layout: CDBlock(s0), AttentionBlock(s0->s1), CDBlock(s1), ..., CDBlock(s_last).
        # The append order fixes the ``decoder.N`` parameter names.
        self.decoder = nn.ModuleList()
        for i in range(len(self.feat_scales)):
            dim = _get_in_channels([self.feat_scales[i]], inner_channel, channel_multiplier)
            self.decoder.append(CDBlock(dim, dim, time_steps))
            if i < len(self.feat_scales) - 1:
                dim_out = _get_in_channels([self.feat_scales[i + 1]], inner_channel, channel_multiplier)
                self.decoder.append(AttentionBlock(dim, dim_out))
        # dim_out is the width of the last (smallest-index) selected scale.
        self.clfr_stg1 = nn.Conv2d(dim_out, 64, 3, padding=1)
        self.clfr_stg2 = nn.Conv2d(64, out_channels, 3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, feats_A, feats_B):
        """Return change logits for two images.

        Args:
            feats_A, feats_B: per-timestep feature lists, indexed as
                ``feats[timestep_idx][scale]``, with at least
                ``len(time_steps)`` entries.

        Returns:
            The ``clfr_stg2`` output with ``out_channels`` channels.

        NOTE(review): the classifier reads ``x`` (the upsampled output of the
        last AttentionBlock), so the ``diff`` computed by the final CDBlock is
        never used. The released heads were presumably trained this way, so
        this behaviour is left unchanged.
        """
        lvl, x = 0, None
        for layer in self.decoder:
            if isinstance(layer, CDBlock):
                scale = self.feat_scales[lvl]
                # Concatenate this scale's features across timesteps along channels.
                f_A = feats_A[0][scale]
                f_B = feats_B[0][scale]
                if len(self.time_steps) > 1:
                    for i in range(1, len(self.time_steps)):
                        f_A = torch.cat((f_A, feats_A[i][scale]), dim=1)
                        f_B = torch.cat((f_B, feats_B[i][scale]), dim=1)
                # The same CDBlock weights project both images.
                diff = torch.abs(layer(f_A) - layer(f_B))
                if lvl > 0:
                    diff = diff + x
                lvl += 1
            else:
                diff = layer(diff)
                x = F.interpolate(diff, scale_factor=2, mode="bilinear")
        return self.clfr_stg2(self.relu(self.clfr_stg1(x)))
# ===========================================================================
# Diffusion utilities
# ===========================================================================
def _precompute_alpha_tables(scheduler):
    """Return ``sqrt([1.0, *alphas_cumprod])`` as a NumPy array.

    Index 0 holds 1.0 (no noise), and index ``t`` holds
    ``sqrt(alphas_cumprod[t - 1])``. ``scheduler.alphas_cumprod`` may be a
    torch tensor on any device, or an array-like.
    """
    ac = scheduler.alphas_cumprod
    if isinstance(ac, torch.Tensor):
        # Tensor.numpy() only works for CPU tensors without grad, and diffusers
        # schedulers may move alphas_cumprod to the sample device (e.g. in add_noise).
        ac = ac.detach().cpu().numpy()
    return np.sqrt(np.append(1.0, ac))
def _q_sample(x_start, continuous_sqrt_alpha_cumprod, noise=None):
    """Noise *x_start* to the level ``a = continuous_sqrt_alpha_cumprod``.

    Returns ``a * x_start + sqrt(1 - a**2) * noise``, with ``a`` broadcast
    against ``x_start``. Standard Gaussian noise shaped like ``x_start`` is
    drawn when *noise* is None.
    """
    a = continuous_sqrt_alpha_cumprod
    eps = torch.randn_like(x_start) if noise is None else noise
    sigma = (1 - a ** 2).sqrt()
    return a * x_start + sigma * eps
@torch.no_grad()
def _extract_features(model, x, t, sqrt_alphas):
    """Noise *x* to timestep *t* and return the model's intermediate features.

    For each sample, a continuous noise level is drawn uniformly between
    ``sqrt_alphas[t - 1]`` and ``sqrt_alphas[t]`` using the global NumPy RNG.
    *x* is noised with ``_q_sample``, and the function returns
    ``model(x_noisy, level, feat_need=True)``, where ``level`` has shape ``(B, 1)``.

    Args:
        model: UNet-like callable that accepts ``feat_need=True``.
        x: input batch; the first axis is the batch axis.
        t: 1-based timestep with ``1 <= t < len(sqrt_alphas)``.
        sqrt_alphas: table from ``_precompute_alpha_tables``.

    Raises:
        ValueError: if *t* is out of range. For example, ``t = 0`` would
            otherwise wrap to ``sqrt_alphas[-1]`` and sample from the wrong
            interval.
    """
    if not 1 <= t < len(sqrt_alphas):
        raise ValueError(f"timestep t must satisfy 1 <= t < {len(sqrt_alphas)}, got {t}")
    b = x.shape[0]
    # Follow x's floating dtype (e.g. fp16 models); non-float inputs keep the float32 default.
    dtype = x.dtype if x.is_floating_point() else torch.float32
    lvl = torch.as_tensor(
        np.random.uniform(sqrt_alphas[t - 1], sqrt_alphas[t], size=b),
        dtype=dtype,
        device=x.device,
    ).view(b, -1)
    noise = torch.randn_like(x)
    x_noisy = _q_sample(x, lvl.view(-1, 1, 1, 1), noise)
    return model(x_noisy, lvl, feat_need=True)
# ===========================================================================
# Pipeline
# ===========================================================================
class DDPMCDPipeline(DiffusionPipeline):
    """DDPM-based change detection. Load with trust_remote_code=True.
    For consolidated ddpm-cd repo with multiple cd_head variants, pass cd_head_subfolder
    (e.g. 'levir-50-100', 'whu-50-100-400', 'cdd-50-100', etc.) when loading.

    The change-detection head is not a registered pipeline component. It is read
    lazily from ``<model root>/cd_head[/<cd_head_subfolder>]`` (``config.json``
    plus a weights file), and on every call it is moved to the device and dtype
    of the UNet features.
    """

    def __init__(self, unet, scheduler, cd_head=None, cd_head_subfolder=None):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)
        self.cd_head = cd_head
        self._cd_head_subfolder = cd_head_subfolder
        # Infer base path from unet config (dirname of unet subfolder = model root)
        unet_path = getattr(getattr(unet, "config", None), "_name_or_path", None)
        self._cd_head_base_path = os.path.dirname(unet_path) if unet_path else None

    def _resolve_base_path(self):
        """Return the existing model-root directory to search, or None."""
        base = self._cd_head_base_path
        if base is None:
            cfg = getattr(self.unet, "config", None)
            base = os.path.dirname(getattr(cfg, "_name_or_path", "")) if cfg else None
        return base if base and os.path.isdir(base) else None

    @staticmethod
    def _load_cd_head_state_dict(cd_dir):
        """Read cd_head weights from *cd_dir*, preferring safetensors over .bin.

        Raises:
            FileNotFoundError: if neither weight file exists.
        """
        for name in ("diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.bin"):
            path = os.path.join(cd_dir, name)
            if not os.path.exists(path):
                continue
            if path.endswith(".safetensors"):
                from safetensors.torch import load_file
                return load_file(path, device="cpu")
            try:
                state = torch.load(path, map_location="cpu", weights_only=True)
            except TypeError:
                # Older torch without ``weights_only``. NOTE: this falls back to full
                # unpickling, so only load .bin files from trusted sources.
                state = torch.load(path, map_location="cpu")
            return state.state_dict() if hasattr(state, "state_dict") else state
        # Fail loudly: silently continuing would run the head with random weights.
        raise FileNotFoundError(f"No cd_head weights (.safetensors or .bin) found in {cd_dir}")

    def _read_cd_head(self):
        """Build a cd_head_v2 from disk, or return None if there is no ``cd_head`` folder.

        Raises:
            RuntimeError: if ``cd_head`` exists but the selected folder has no
                ``config.json`` (missing or unknown ``cd_head_subfolder``).
            FileNotFoundError: if ``config.json`` exists but no weight file does.
        """
        base = self._resolve_base_path()
        if base is None:
            return None
        root = os.path.join(base, "cd_head")
        # Check the folder before listing it: pretrained-only repos have no cd_head.
        if not os.path.isdir(root):
            return None
        subfolder = self._cd_head_subfolder
        cd_dir = os.path.join(root, subfolder) if subfolder else root
        config_path = os.path.join(cd_dir, "config.json")
        if not os.path.isfile(config_path):
            subdirs = sorted(d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d)))
            if subfolder:
                raise RuntimeError(
                    f"cd_head_subfolder {subfolder!r} not found (no config.json in {cd_dir}). "
                    f"Available: {subdirs}."
                )
            # Consolidated repo: cd_head_subfolder is required
            raise RuntimeError(
                "DDPMCDPipeline requires cd_head_subfolder when loading from consolidated ddpm-cd repo. "
                f"Available: {subdirs}. Example: from_pretrained(..., cd_head_subfolder='levir-50-100')"
            )
        with open(config_path, encoding="utf-8") as f:
            cfg = json.load(f)
        head = cd_head_v2(**cfg)
        head.load_state_dict(self._load_cd_head_state_dict(cd_dir))
        return head.eval()

    def _load_cd_head_if_needed(self):
        """Lazy-load cd_head from disk when first needed (path inferred from unet)."""
        if self.cd_head is None:
            self.cd_head = self._read_cd_head()

    def load_cd_head(self, pretrained_model_name_or_path=None, subfolder=None):
        """Manually (re)load cd_head from the given path (or infer from unet).
        subfolder: e.g. 'levir-50-100', 'whu-50-100-400' for consolidated ddpm-cd repo.

        Unlike the lazy loader, this replaces an already-loaded head, so it can be
        used to switch variants. If no ``cd_head`` folder is found, the current
        head is kept.
        """
        if pretrained_model_name_or_path:
            self._cd_head_base_path = pretrained_model_name_or_path
        if subfolder is not None:
            self._cd_head_subfolder = subfolder
        head = self._read_cd_head()
        if head is not None:
            self.cd_head = head

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Load the pipeline, then eagerly load cd_head (``cd_head_subfolder`` kwarg selects a variant)."""
        cd_head_subfolder = kwargs.pop("cd_head_subfolder", None)
        pipe = super().from_pretrained(pretrained_model_name_or_path, **kwargs)
        pipe._cd_head_base_path = pretrained_model_name_or_path if os.path.isdir(pretrained_model_name_or_path) else None
        pipe._cd_head_subfolder = cd_head_subfolder
        pipe._load_cd_head_if_needed()
        return pipe

    @torch.no_grad()
    def __call__(self, image_A, image_B, timesteps=None, feat_type="dec"):
        """Predict a change map between two images.

        Args:
            image_A, image_B: image batches passed to the UNet unchanged.
            timesteps: 1-based diffusion timesteps to extract features at
                (default ``[50, 100]``).
            feat_type: ``"dec"`` uses decoder features; any other value uses
                encoder features.

        Returns:
            The output of ``cd_head`` (the change logits).

        Raises:
            RuntimeError: if no cd_head could be loaded.
        """
        self._load_cd_head_if_needed()
        if self.cd_head is None:
            raise RuntimeError("DDPMCDPipeline requires cd_head. Could not load from disk.")
        timesteps = timesteps or [50, 100]
        sqrt_a = _precompute_alpha_tables(self.scheduler)
        use_decoder = feat_type == "dec"
        feats_A, feats_B = [], []
        for t in timesteps:
            fe_A, fd_A = _extract_features(self.unet, image_A, t, sqrt_a)
            fe_B, fd_B = _extract_features(self.unet, image_B, t, sqrt_a)
            feats_A.append(fd_A if use_decoder else fe_A)
            feats_B.append(fd_B if use_decoder else fe_B)
        # cd_head is not registered with register_modules, so pipe.to(...) never
        # moves it. Align it with the features to avoid a device/dtype mismatch.
        ref = feats_A[0][0]
        self.cd_head.to(device=ref.device, dtype=ref.dtype)
        return self.cd_head(feats_A, feats_B)

    @torch.no_grad()
    def generate(self, batch_size=1, in_channels=3, image_size=256, num_inference_steps=None, generator=None):
        """Sample images unconditionally with the UNet and the pipeline's scheduler.

        Args:
            batch_size: number of images to sample.
            in_channels: channels of the initial noise; should match the UNet's ``in_channel``.
            image_size: height and width of the samples.
            num_inference_steps: scheduler steps (defaults to ``num_train_timesteps``).
            generator: optional ``torch.Generator``, used for the initial noise and,
                when the scheduler's ``step`` accepts it, for every step's noise.

        Returns:
            The final sample tensor.
        """
        import inspect

        device = next(self.unet.parameters()).device
        steps = num_inference_steps or self.scheduler.config.num_train_timesteps
        sqrt_a = _precompute_alpha_tables(self.scheduler)
        image = torch.randn((batch_size, in_channels, image_size, image_size), device=device, generator=generator)
        self.scheduler.set_timesteps(steps)
        step_kwargs = {}
        # Forward the generator so the per-step variance noise is reproducible,
        # but only to schedulers whose step() accepts it.
        if generator is not None and "generator" in inspect.signature(self.scheduler.step).parameters:
            step_kwargs["generator"] = generator
        for t in tqdm(self.scheduler.timesteps, desc="Sampling"):
            # The UNet is conditioned on sqrt(alphas_cumprod[t]), stored at index t + 1.
            idx = min(int(t) + 1, len(sqrt_a) - 1)
            lvl = torch.FloatTensor([sqrt_a[idx]]).repeat(batch_size, 1).to(device)
            noise_pred = self.unet(image, lvl)
            image = self.scheduler.step(noise_pred, t, image, **step_kwargs).prev_sample
        return image