# Scraped Hugging Face file-view header (not part of the module):
# uploaded by treadon — "Upload nucleus_image/vae.py with huggingface_hub",
# commit 914f9d0 (verified), 6.35 kB.
"""Qwen-Image VAE Decoder in MLX.
The original uses Conv3d (video-capable). For image generation T=1,
so all Conv3d reduce to Conv2d. We squeeze the temporal dimension
and use standard Conv2d.
Architecture (from config):
z_dim: 16 (latent channels)
base_dim: 96
dim_mult: [1, 2, 4, 4] → channels [96, 192, 384, 384]
num_res_blocks: 2 (each decoder up block holds num_res_blocks + 1 = 3 resnets)
8x spatial upscale
Weight naming (from safetensors):
decoder.conv_in, decoder.mid_block.{resnets,attentions}, decoder.up_blocks,
decoder.conv_out, decoder.norm_out
Uses 'gamma' for norm weights (not 'weight')
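Conv weights are 5D: [out, in, D, H, W]. When D == 1 the temporal axis can be
squeezed; for D > 1 kernels a plain squeeze is impossible and (presumably, for
causal convs at T=1) only the last temporal slice applies — TODO confirm against
the loader. MLX Conv2d stores weights as [out, H, W, in], so a transpose is also
required.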
"""
import mlx.core as mx
import mlx.nn as nn
class RMSNorm2D(nn.Module):
    """RMS normalization over the channel (last) axis with a learned ``gamma``.

    Computes ``x * rsqrt(mean(x**2, axis=-1) + eps) * gamma`` for channels-last
    input ``[B, H, W, C]``. The scale parameter is named ``gamma`` (not
    ``weight``) so it lines up with the checkpoint keys.

    Args:
        channels: size of the last (channel) axis of the input.
        eps: added to the mean square before the reciprocal square root.
    """

    def __init__(self, channels: int, eps: float = 1e-6):
        super().__init__()
        self.gamma = mx.ones((channels,))
        # Plain Python float, so MLX does not treat it as a trainable parameter.
        self.eps = eps

    def __call__(self, x):
        # Fused kernel: computes the mean-square statistic in higher precision,
        # so float16 activations cannot overflow in ``x * x`` the way the naive
        # ``sqrt(mean(x * x))`` form can.
        return mx.fast.rms_norm(x, self.gamma, self.eps)
class ResnetBlock(nn.Module):
    """Pre-activation residual block: (RMSNorm -> SiLU -> 3x3 conv), twice.

    Parameter names: norm1.gamma, conv1, norm2.gamma, conv2 and — only when
    the channel count changes — conv_shortcut, a 1x1 projection applied to
    the skip path so it matches ``out_ch``.
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.norm1 = RMSNorm2D(in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.norm2 = RMSNorm2D(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.conv_shortcut = (
            nn.Conv2d(in_ch, out_ch, kernel_size=1) if in_ch != out_ch else None
        )

    def __call__(self, x):
        skip = x if self.conv_shortcut is None else self.conv_shortcut(x)
        out = x
        for norm, conv in ((self.norm1, self.conv1), (self.norm2, self.conv2)):
            out = conv(nn.silu(norm(out)))
        return skip + out
class AttentionBlock(nn.Module):
    """Single-head spatial self-attention with a residual connection.

    Parameter names: norm.gamma, to_qkv.{weight,bias}, proj.{weight,bias}.
    Both projections are 1x1 ``nn.Conv2d`` layers (the checkpoint also stores
    them as 1x1 convs) applied to channels-last input ``[B, H, W, C]``.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.norm = RMSNorm2D(channels)
        self.to_qkv = nn.Conv2d(channels, channels * 3, kernel_size=1)
        self.proj = nn.Conv2d(channels, channels, kernel_size=1)

    def __call__(self, x):
        batch, height, width, channels = x.shape
        tokens = height * width
        # Flatten the spatial grid into one token axis: [B, H*W, 3C].
        qkv = self.to_qkv(self.norm(x)).reshape(batch, tokens, channels * 3)
        # Channel layout is [q | k | v], each `channels` wide.
        q = qkv[..., :channels]
        k = qkv[..., channels:2 * channels]
        v = qkv[..., 2 * channels:]
        scores = mx.matmul(q, mx.swapaxes(k, 1, 2)) * (channels ** -0.5)
        weights = mx.softmax(scores, axis=-1)
        attended = mx.matmul(weights, v).reshape(batch, height, width, channels)
        return self.proj(attended) + x
class PixelShufflePlaceholder(nn.Module):
    """Parameter-free stand-in for ``resample[0]`` of an upsampler.

    It exists only so that the 3x3 conv in ``Upsample.resample`` sits at
    index 1, matching the checkpoint key ``upsamplers.0.resample.1``. It is
    never invoked: ``Upsample`` performs its 2x spatial upsample directly with
    ``mx.repeat`` and then calls ``resample[1]``.

    NOTE(review): an earlier docstring described this slot as a "pixel
    unshuffle"; since the forward pass never uses it, its role in the
    reference model does not affect inference here — confirm if it matters.
    """
class Upsample(nn.Module):
    """2x nearest-neighbour spatial upsample followed by a 3x3 conv.

    Parameter names follow ``upsamplers.0.resample.{0,1}`` plus, when
    ``has_time_conv`` is set, ``upsamplers.0.time_conv``. ``resample[0]``
    carries no parameters and only keeps the conv at index 1. ``time_conv``
    is registered so checkpoint weights have a home, but it is never called.

    NOTE(review): ``time_conv`` is declared as a 1x1 Conv2d; if the checkpoint
    stores it with a different kernel shape, the loader must reconcile that.
    """

    def __init__(self, in_ch: int, out_ch: int, has_time_conv: bool = True):
        super().__init__()
        conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.resample = [PixelShufflePlaceholder(), conv]
        if has_time_conv:
            self.time_conv = nn.Conv2d(in_ch, in_ch * 2, kernel_size=1)

    def __call__(self, x):
        # Duplicate every pixel along both spatial axes of [B, H, W, C].
        for axis in (1, 2):
            x = mx.repeat(x, 2, axis=axis)
        return self.resample[1](x)
class UpBlock(nn.Module):
    """A stack of residual blocks, optionally followed by a 2x upsampler.

    Parameter names: resnets.{0..num_res_blocks-1} and, when
    ``upsample_out_ch`` is given, upsamplers.0. Only the first resnet maps
    ``in_ch`` to ``out_ch``; the rest stay at ``out_ch``.
    """

    def __init__(self, in_ch: int, out_ch: int, num_res_blocks: int = 3, upsample_out_ch: int = None, has_time_conv: bool = True):
        super().__init__()
        blocks = []
        channels = in_ch
        for _ in range(num_res_blocks):
            blocks.append(ResnetBlock(channels, out_ch))
            channels = out_ch
        self.resnets = blocks
        self.upsamplers = (
            None
            if upsample_out_ch is None
            else [Upsample(out_ch, upsample_out_ch, has_time_conv=has_time_conv)]
        )

    def __call__(self, x):
        for block in self.resnets:
            x = block(x)
        if self.upsamplers is None:
            return x
        return self.upsamplers[0](x)
class MidBlock(nn.Module):
    """Bottleneck: resnet -> self-attention -> resnet at a fixed width.

    Parameter names: resnets.{0,1}, attentions.0.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.resnets = [ResnetBlock(channels, channels) for _ in range(2)]
        self.attentions = [AttentionBlock(channels)]

    def __call__(self, x):
        before, after = self.resnets
        return after(self.attentions[0](before(x)))
class Decoder(nn.Module):
    """Full image decoder of the VAE (diffusers parameter naming).

    The default arguments reproduce the Qwen-Image configuration
    (z_dim=16, base_dim=96, dim_mult=[1, 2, 4, 4], num_res_blocks=2). That
    gives decoder widths [384, 384, 384, 192, 96] and three 2x upsamplers
    (8x spatial upscale).

    Args:
        z_dim: latent channels consumed by ``conv_in``.
        base_dim: base channel width; stage widths are ``base_dim * dim_mult``.
        dim_mult: per-stage channel multipliers, in encoder order.
        num_res_blocks: residual blocks per stage in the config. Each decoder
            up block holds ``num_res_blocks + 1`` of them.
        temporal_upsample: one flag per upsampler, in decoder order, saying
            whether that upsampler registers a ``time_conv`` (loaded, unused).
        out_channels: channels produced by ``conv_out``.

    Raises:
        ValueError: if ``dim_mult`` is empty, or ``temporal_upsample`` does
            not have exactly ``len(dim_mult) - 1`` entries.
    """

    def __init__(
        self,
        z_dim: int = 16,
        base_dim: int = 96,
        dim_mult: tuple = (1, 2, 4, 4),
        num_res_blocks: int = 2,
        temporal_upsample: tuple = (True, True, False),
        out_channels: int = 3,
    ):
        super().__init__()
        mults = list(dim_mult)
        if not mults:
            raise ValueError("dim_mult must contain at least one multiplier")
        num_upsamplers = len(mults) - 1
        if len(temporal_upsample) != num_upsamplers:
            raise ValueError(
                f"temporal_upsample needs {num_upsamplers} entries "
                f"(one per upsampler), got {len(temporal_upsample)}"
            )

        # Decoder widths: the deepest width once more for conv_in/mid_block,
        # then the encoder widths in reverse.
        # Defaults: 96 * [4, 4, 4, 2, 1] = [384, 384, 384, 192, 96].
        dims = [base_dim * m for m in [mults[-1]] + mults[::-1]]

        self.conv_in = nn.Conv2d(z_dim, dims[0], kernel_size=3, padding=1)
        self.mid_block = MidBlock(dims[0])

        # With the defaults this builds:
        #   block 0: 384->384, upsample 384->192 (with time_conv)
        #   block 1: 192->384 (conv_shortcut), upsample 384->192 (with time_conv)
        #   block 2: 192->192, upsample 192->96 (no time_conv)
        #   block 3: 96->96, no upsample
        up_blocks = []
        for i, (in_ch, out_ch) in enumerate(zip(dims[:-1], dims[1:])):
            if i > 0:
                # The previous block's upsampler halved its channel count.
                in_ch //= 2
            if i < num_upsamplers:
                up_kwargs = dict(
                    upsample_out_ch=out_ch // 2,
                    has_time_conv=bool(temporal_upsample[i]),
                )
            else:
                up_kwargs = dict(upsample_out_ch=None)
            up_blocks.append(
                UpBlock(in_ch, out_ch, num_res_blocks=num_res_blocks + 1, **up_kwargs)
            )
        self.up_blocks = up_blocks

        self.norm_out = RMSNorm2D(dims[-1])
        self.conv_out = nn.Conv2d(dims[-1], out_channels, kernel_size=3, padding=1)

    def __call__(self, z):
        x = self.conv_in(z)
        x = self.mid_block(x)
        for block in self.up_blocks:
            x = block(x)
        x = nn.silu(self.norm_out(x))
        x = self.conv_out(x)
        return x
class VAEDecoder(nn.Module):
    """Latent-to-image entry point: 1x1 ``post_quant_conv``, then ``Decoder``.

    Takes a channels-last latent with 16 channels and returns the raw
    ``Decoder`` output (with the default Decoder: 3 channels at 8x the
    spatial size). No clamping or rescaling is applied here.
    """

    def __init__(self):
        super().__init__()
        self.post_quant_conv = nn.Conv2d(16, 16, kernel_size=1)
        self.decoder = Decoder()

    def __call__(self, z):
        return self.decoder(self.post_quant_conv(z))