import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint  # explicit: `import torch` alone does not guarantee this submodule is loaded
from torch import einsum
from einops import rearrange, repeat
from inspect import isfunction
import numpy as np
from abc import abstractmethod


# --- Utilities ---

def exists(val):
    """Return True when *val* is not None."""
    return val is not None


def default(val, d):
    """Return *val* if it is not None, otherwise *d*.

    If *d* is a function it is called (lazily) to produce the fallback value.
    """
    if exists(val):
        return val
    return d() if isfunction(d) else d


def checkpoint(func, inputs, params, flag):
    """Evaluate ``func(*inputs)``, optionally with gradient checkpointing.

    Args:
        func: Callable to evaluate.
        inputs: Tuple of positional arguments for *func*.
        params: Accepted for call-site compatibility; not used by this helper.
        flag: When truthy, route the call through ``torch.utils.checkpoint``.

    Returns:
        Whatever *func* returns.
    """
    # NOTE: the reentrant checkpoint variant only produces gradients when at
    # least one input requires grad, which is presumably why checkpointing is
    # skipped otherwise.
    if flag and any(x.requires_grad for x in inputs if isinstance(x, torch.Tensor)):
        return torch.utils.checkpoint.checkpoint(func, *inputs)
    return func(*inputs)


def timestep_embedding(timesteps, dim, max_period=10000):
    """Build sinusoidal embeddings for a batch of timesteps.

    Args:
        timesteps: 1-D tensor of N timesteps; a 0-d (scalar) tensor is
            treated as a batch of one.
        dim: Width of the output embedding.
        max_period: Controls the lowest frequency used.

    Returns:
        float32 tensor of shape (N, dim): cosine features followed by sine
        features, zero-padded by one column when *dim* is odd.
    """
    if timesteps.dim() == 0:
        timesteps = timesteps[None]
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


def zero_module(module):
    """Zero every parameter of *module* in place and return the module."""
    for p in module.parameters():
        p.detach().zero_()
    return module


def normalization(channels):
    """Return a 32-group GroupNorm over *channels* (eps=1e-6, affine)."""
    return nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True)


# --- Attention Blocks ---

class GEGLU(nn.Module):
    """Linear projection to 2*dim_out, split in half, output ``x * gelu(gate)``."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


class FeedForward(nn.Module):
    """Two-layer MLP: project in (GELU or GEGLU), dropout, project out.

    Args:
        dim: Input width.
        dim_out: Output width (defaults to *dim*).
        mult: Hidden width multiplier.
        glu: Use :class:`GEGLU` for the input projection instead of Linear+GELU.
        dropout: Dropout probability after the input projection.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
        self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))

    def forward(self, x):
        return self.net(x)


class CrossAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries come from ``x``; keys/values come from ``context`` (or from ``x``
    itself when no context is given, i.e. self-attention).

    Args:
        query_dim: Last-dimension width of ``x``.
        context_dim: Last-dimension width of ``context`` (defaults to *query_dim*).
        heads: Number of attention heads.
        dim_head: Width of each head.
        dropout: Dropout probability on the output projection.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` (b, n, query_dim) to ``context`` (b, m, context_dim).

        ``mask`` (optional) has one entry per context position per batch item
        (shape (b, m) or (b, ..., m) flattening to m); truthy entries are
        attended to, falsy entries are suppressed.
        """
        h = self.heads
        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)
        # Fold heads into the batch axis: (b, n, h*d) -> (b*h, n, d).
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
        if exists(mask):
            # sim is (b*h, i, j): flatten the mask to (b, j), make it boolean
            # (integer 0/1 masks would otherwise be bit-inverted by `~`), and
            # tile it across heads to (b*h, 1, j) so it broadcasts over i.
            mask = rearrange(mask.to(sim.device), 'b ... -> b (...)').bool()
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, -torch.finfo(sim.dtype).max)
        attn = sim.softmax(dim=-1)
        out = einsum('b i j, b j d -> b i d', attn, v)
        return self.to_out(rearrange(out, '(b h) n d -> b n (h d)', h=h))


# --- Transformer Blocks ---

class BasicTransformerBlock(nn.Module):
    """Pre-norm residual block: self-attention, cross-attention, feed-forward.

    When ``checkpoint`` is True, the forward pass goes through the module-level
    :func:`checkpoint` helper.
    """

    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
        super().__init__()
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)
        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head,
                                    dropout=dropout)
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.norm1, self.norm2, self.norm3 = nn.LayerNorm(dim), nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

    def _forward(self, x, context=None):
        x = self.attn1(self.norm1(x)) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x


class SpatialTransformer(nn.Module):
    """Apply transformer blocks to a (b, c, h, w) feature map.

    GroupNorm -> 1x1 conv -> flatten the h*w positions into tokens ->
    transformer blocks -> unflatten -> zero-initialised 1x1 conv -> residual add.
    Because ``proj_out`` starts at zero, the block is initially an identity.
    """

    def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0., context_dim=None):
        super().__init__()
        inner_dim = n_heads * d_head
        self.norm = normalization(in_channels)
        self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1)
        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
            for _ in range(depth)
        ])
        self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1))

    def forward(self, x, context=None):
        b, c, h, w = x.shape
        x_in = x
        x = rearrange(self.proj_in(self.norm(x)), 'b c h w -> b (h w) c')
        for block in self.transformer_blocks:
            x = block(x, context=context)
        return self.proj_out(rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)) + x_in


# --- UNet Components ---

class TimestepBlock(nn.Module):
    """Marker base class for modules whose forward also takes a timestep embedding."""

    @abstractmethod
    def forward(self, x, emb):
        pass


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """Sequential container that routes extra inputs to the layers that need them.

    :class:`TimestepBlock` layers receive ``emb``, :class:`SpatialTransformer`
    layers receive ``context``, and every other layer receives only ``x``.
    """

    def forward(self, x, emb, context=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, context)
            else:
                x = layer(x)
        return x


class ResBlock(TimestepBlock):
    """Residual conv block conditioned on a timestep embedding.

    Args:
        channels: Input channels.
        emb_channels: Width of the embedding vector ``emb``.
        dropout: Dropout probability inside ``out_layers``.
        out_channels: Output channels (defaults to *channels*).
        use_conv: When channels change, use a 3x3 conv (instead of 1x1) on the skip path.
        use_scale_shift_norm: Inject the embedding as ``norm(h) * (1 + scale) + shift``
            instead of adding it before ``out_layers``.
    """

    def __init__(self, channels, emb_channels, dropout, out_channels=None, use_conv=False,
                 use_scale_shift_norm=False):
        super().__init__()
        self.out_channels = out_channels or channels
        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            nn.Conv2d(channels, self.out_channels, 3, padding=1),
        )
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_channels, 2 * self.out_channels if use_scale_shift_norm else self.out_channels),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
        )
        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = nn.Conv2d(channels, self.out_channels, 3, padding=1)
        else:
            self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)
        self.use_scale_shift_norm = use_scale_shift_norm

    def forward(self, x, emb):
        h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # Append singleton spatial dims so the embedding broadcasts over h's spatial axes.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = self.out_layers(h + emb_out)
        return self.skip_connection(x) + h


class Upsample(nn.Module):
    """2x nearest-neighbour upsampling, optionally followed by a 3x3 conv."""

    def __init__(self, channels, use_conv, out_channels=None):
        super().__init__()
        self.conv = nn.Conv2d(channels, out_channels or channels, 3, padding=1) if use_conv else nn.Identity()

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(x)


class Downsample(nn.Module):
    """2x downsampling via a stride-2 3x3 conv, or a 2x2 average pool when ``use_conv`` is False."""

    def __init__(self, channels, use_conv, out_channels=None):
        super().__init__()
        self.op = (
            nn.Conv2d(channels, out_channels or channels, 3, stride=2, padding=1)
            if use_conv
            else nn.AvgPool2d(2)
        )

    def forward(self, x):
        return self.op(x)


# --- Main Model ---

class UNetModel(nn.Module):
    """UNet with timestep, style-vector and text cross-attention conditioning.

    Args:
        image_size: Accepted for interface compatibility; not used here.
        in_channels: Channels of the input ``x``.
        model_channels: Base channel width; level ``k`` uses ``model_channels * channel_mult[k]``.
        out_channels: Channels of the output.
        num_res_blocks: ResBlocks per level on the encoder side (decoder uses one more).
        attention_resolutions: Downsample factors (1, 2, 4, ...) at which a
            SpatialTransformer follows each ResBlock.
        dropout: Dropout probability inside ResBlocks.
        channel_mult: Per-level channel multipliers.
        conv_resample: Use convs (rather than pooling/plain interpolation) to resample.
        context_dim: Width of the cross-attention context.
        text_encoder: Optional encoder used to turn non-tensor ``context`` into features.
        args: Stored on ``self.args``; not otherwise used here.
    """

    def __init__(self, image_size, in_channels, model_channels, out_channels, num_res_blocks,
                 attention_resolutions, dropout=0, channel_mult=(1, 2, 4, 8), conv_resample=True,
                 context_dim=768, text_encoder=None, args=None):
        super().__init__()
        self.args = args
        self.model_channels = model_channels
        self.text_encoder = text_encoder
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            nn.Linear(model_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )
        # Safety fallback: not used by forward(). Presumably kept so existing
        # state_dicts still load — verify before removing.
        self.label_emb = nn.Embedding(1000, time_embed_dim)
        # Projects a 1280-wide style vector into the timestep-embedding space.
        self.style_lin = nn.Linear(1280, time_embed_dim)
        # Projects 768-wide text features down to a 320-wide context when requested.
        self.text_lin = nn.Linear(768, 320) if context_dim == 320 else nn.Identity()

        self.input_blocks = nn.ModuleList([
            TimestepEmbedSequential(nn.Conv2d(in_channels, model_channels, 3, padding=1))
        ])
        # Channel count of every encoder output, consumed in reverse by the decoder skips.
        input_block_chans = [model_channels]
        ch, ds = model_channels, 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [ResBlock(ch, time_embed_dim, dropout, out_channels=model_channels * mult)]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    layers.append(SpatialTransformer(ch, 8, ch // 8, depth=1, context_dim=context_dim))
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                self.input_blocks.append(TimestepEmbedSequential(Downsample(ch, conv_resample)))
                input_block_chans.append(ch)
                ds *= 2

        self.middle_block = TimestepEmbedSequential(
            ResBlock(ch, time_embed_dim, dropout),
            SpatialTransformer(ch, 8, ch // 8, depth=1, context_dim=context_dim),
            ResBlock(ch, time_embed_dim, dropout),
        )

        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                ich = input_block_chans.pop()
                layers = [ResBlock(ch + ich, time_embed_dim, dropout, out_channels=model_channels * mult)]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    layers.append(SpatialTransformer(ch, 8, ch // 8, depth=1, context_dim=context_dim))
                if level and i == num_res_blocks:
                    layers.append(Upsample(ch, conv_resample))
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))

        # After the decoder loop ch == model_channels * channel_mult[0]; the final
        # conv must consume that width (previously hard-coded to model_channels,
        # which broke whenever channel_mult[0] != 1).
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(nn.Conv2d(ch, out_channels, 3, padding=1)),
        )

    def forward(self, x, timesteps, context=None, style_extractor=None, **kwargs):
        """Run the UNet.

        Args:
            x: Input feature map (b, in_channels, H, W).
            timesteps: 1-D tensor of timesteps (or a 0-d scalar tensor).
            context: Cross-attention context: either a tensor, or (when a
                ``text_encoder`` is configured) keyword inputs for that encoder.
            style_extractor: Optional style features; a 3-D tensor is averaged
                over dim 1 before projection.
            **kwargs: Ignored.

        Returns:
            Tensor with ``out_channels`` channels.

        Raises:
            TypeError: ``context`` is not a tensor and no ``text_encoder`` is set.
        """
        # 1. Time embedding.
        t_emb = timestep_embedding(timesteps, self.model_channels)
        emb = self.time_embed(t_emb)

        # 2. Style embedding, added to the timestep embedding.
        if style_extractor is not None:
            if len(style_extractor.shape) == 3:
                style_vec = torch.mean(style_extractor, dim=1)
            else:
                style_vec = style_extractor
            emb = emb + self.style_lin(style_vec.to(device=x.device, dtype=emb.dtype))

        # 3. Text context.
        if context is not None and not isinstance(context, torch.Tensor) and self.text_encoder is None:
            raise TypeError("context must be a tensor when no text_encoder is configured")
        if context is not None and self.text_encoder is not None:
            # Skip encoding when the caller already passed encoded features.
            if not isinstance(context, torch.Tensor):
                # Assumes a Hugging Face-style encoder whose output exposes
                # `.last_hidden_state` (e.g. CANINE) — TODO confirm. The encoder
                # is run without gradients, i.e. it is not trained here.
                with torch.no_grad():
                    context = self.text_encoder(**context).last_hidden_state
            context = self.text_lin(context)

        # 4. UNet pass with skip connections from every encoder block.
        h = x.type(torch.float32)
        hs = []
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        return self.out(h)