"""Nucleus-Image MLX Pipeline.

MoE DiT + VAE in MLX, text encoder in PyTorch (hybrid).
"""

import json
import time
from pathlib import Path
from typing import Optional

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download

from .dit import NucleusMoEDiT
from .vae import VAEDecoder
from .scheduler import FlowMatchEulerScheduler

# Pixel-space size / latent-space size of the VAE ("VAE is 8x").
VAE_SCALE_FACTOR = 8
# Channel count of the noise latents fed to the DiT (before patchification).
LATENT_CHANNELS = 16
# Side length of the square latent patches that become DiT tokens.
PATCH_SIZE = 2


def patchify(x, patch_size=2):
    """Pack an NHWC latent grid into a sequence of flattened patches.

    ``[B, H, W, C] -> [B, (H/p)*(W/p), C*p*p]``

    Matches diffusers ``_pack_latents``: each token's features are laid out as
    ``[C, ph, pw]`` (channels first).

    Args:
        x: Array of shape ``[B, H, W, C]``.
        patch_size: Patch side length ``p``.

    Returns:
        Array of shape ``[B, (H/p)*(W/p), C*p*p]``.

    Raises:
        ValueError: If ``H`` or ``W`` is not a multiple of ``patch_size``.
    """
    B, H, W, C = x.shape
    p = patch_size
    if H % p or W % p:
        raise ValueError(f"latent size {H}x{W} is not divisible by patch_size={p}")
    x = x.reshape(B, H // p, p, W // p, p, C)
    # [B, H/p, p, W/p, p, C] -> [B, H/p, W/p, C, p, p]
    x = x.transpose(0, 1, 3, 5, 2, 4)
    return x.reshape(B, (H // p) * (W // p), C * p * p)


def unpatchify(x, h, w, patch_size=2):
    """Inverse of :func:`patchify`: ``[B, N, C*p*p] -> [B, h, w, C]``.

    Token feature layout is ``[C, ph, pw]``.

    Args:
        x: Token array of shape ``[B, N, C*p*p]`` with ``N == (h/p)*(w/p)``.
        h, w: Height and width of the latent grid to rebuild.
        patch_size: Patch side length ``p``.

    Returns:
        Array of shape ``[B, h, w, C]``.

    Raises:
        ValueError: If ``h``/``w`` are not multiples of ``patch_size`` or the
            token count does not match the requested grid.
    """
    B, N, D = x.shape
    p = patch_size
    if h % p or w % p:
        raise ValueError(f"grid size {h}x{w} is not divisible by patch_size={p}")
    C = D // (p * p)
    hp, wp = h // p, w // p
    if hp * wp != N:
        raise ValueError(f"{N} tokens cannot form a {hp}x{wp} patch grid")
    x = x.reshape(B, hp, wp, C, p, p)
    # [B, hp, wp, C, p, p] -> [B, hp, p, wp, p, C]
    x = x.transpose(0, 1, 4, 2, 5, 3)
    return x.reshape(B, h, w, C)


def _latent_grid_size(height, width):
    """Return the ``(latent_h, latent_w)`` grid for a requested pixel size.

    The latent grid must divide evenly into patches, so pixel sizes are rounded
    down to a multiple of ``VAE_SCALE_FACTOR * PATCH_SIZE`` (16). Without this,
    e.g. ``height=1000`` gives an odd latent height and ``patchify`` fails.

    Raises:
        ValueError: If either dimension is smaller than one patch (16 px).
    """
    multiple = VAE_SCALE_FACTOR * PATCH_SIZE
    if height < multiple or width < multiple:
        raise ValueError(f"height and width must be at least {multiple}, got {height}x{width}")
    if height % multiple or width % multiple:
        print(f"  height/width must be multiples of {multiple}; rounding {height}x{width} "
              f"down to {height // multiple * multiple}x{width // multiple * multiple}")
    return PATCH_SIZE * (height // multiple), PATCH_SIZE * (width // multiple)


def _add_batch_dim(embeddings):
    """Prepend a batch axis to 2-D embeddings; return higher-rank input unchanged."""
    return mx.expand_dims(embeddings, 0) if embeddings.ndim == 2 else embeddings


def _shifted_sigmas(num_inference_steps, image_seq_len, base_seq_len=256, max_seq_len=4096,
                    base_shift=0.5, max_shift=1.15):
    """Build the resolution-shifted flow-matching sigma schedule.

    ``mu`` is linearly interpolated from the image token count, then each
    linearly spaced sigma in ``[1, 1/steps]`` is shifted with
    ``exp(mu) * s / (1 + (exp(mu) - 1) * s)``.

    Returns:
        float32 numpy array of length ``num_inference_steps`` (no trailing 0).

    Raises:
        ValueError: If ``num_inference_steps < 1``.
    """
    if num_inference_steps < 1:
        raise ValueError(f"num_inference_steps must be >= 1, got {num_inference_steps}")
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    mu = image_seq_len * slope + intercept
    sigmas = np.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps)
    shift = np.exp(mu)
    sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
    # Explicit float32 so the schedule never enters MLX as float64.
    return sigmas.astype(np.float32)


def _cfg_with_norm_rescale(cond, uncond, guidance_scale, eps=1e-8):
    """Classifier-free guidance with per-token norm rescaling.

    Combines ``uncond + scale * (cond - uncond)`` and rescales it so that its
    L2 norm along the last axis matches that of ``cond`` (up to ``eps``).
    """
    comb = uncond + guidance_scale * (cond - uncond)
    cond_norm = mx.sqrt(mx.sum(cond * cond, axis=-1, keepdims=True) + eps)
    comb_norm = mx.sqrt(mx.sum(comb * comb, axis=-1, keepdims=True) + eps)
    return comb * (cond_norm / comb_norm)


def _convert_vae_weights(raw_weights):
    """Remap a diffusers VAE checkpoint to the decoder-only MLX ``VAEDecoder``.

    - Drops encoder / ``quant_conv`` weights, latent statistics, scale-factor
      entries and ``bn.`` entries.
    - 5-D conv kernels: CausalConv3d. For T=1 input with padding=(2*p, 0), only
      the LAST temporal slice of the kernel contributes, so that slice is kept
      and the kernel becomes 2-D.
    - 4-D kernels (including collapsed ones) go from PyTorch's channels-first
      ``[O, I, kH, kW]`` to MLX's channels-last ``[O, kH, kW, I]``.
    - ``gamma`` parameters are squeezed to drop singleton axes.
    """
    converted = {}
    for key, value in raw_weights.items():
        if key.startswith(("encoder.", "quant_conv", "latents_", "bn.")):
            continue
        if key in ("spatial_scale_factor", "temporal_scale_factor"):
            continue
        if "weight" in key and value.ndim == 5:
            value = value[:, :, -1, :, :].transpose(0, 2, 3, 1)
        elif "weight" in key and value.ndim == 4:
            value = value.transpose(0, 2, 3, 1)
        if "gamma" in key:
            value = value.squeeze()
        converted[key] = value
    return converted


class NucleusImagePipeline:
    """Text-to-image sampler: MoE DiT denoiser + VAE decoder, both in MLX.

    Text embeddings are supplied by the caller (no text encoder here).
    """

    def __init__(self, dit, vae, scheduler, latents_mean, latents_std):
        self.dit = dit
        self.vae = vae
        self.scheduler = scheduler
        # Per-channel latent statistics from the VAE config, used to
        # denormalise latents before decoding.
        self.latents_mean = latents_mean
        self.latents_std = latents_std

    @classmethod
    def from_pretrained(cls, model_id="NucleusAI/Nucleus-Image", quantize=None):
        """Download (or reuse the HF cache for) ``model_id`` and build a pipeline.

        Args:
            model_id: Hugging Face repo with ``transformer/`` and ``vae/`` folders.
            quantize: If truthy, bit width passed to ``nn.quantize`` for the DiT.

        Raises:
            FileNotFoundError: If no ``*.safetensors`` shards exist under
                ``transformer/``.
        """
        path = Path(snapshot_download(model_id))

        # Config
        dit_config = json.loads((path / "transformer" / "config.json").read_text(encoding="utf-8"))
        vae_config = json.loads((path / "vae" / "config.json").read_text(encoding="utf-8"))

        # DiT (weights may be sharded across several safetensors files)
        print("Loading DiT...")
        dit = NucleusMoEDiT(dit_config)
        shards = sorted((path / "transformer").glob("*.safetensors"))
        if not shards:
            raise FileNotFoundError(f"no .safetensors shards found in {path / 'transformer'}")
        dit_weights = {}
        for shard in shards:
            dit_weights.update(mx.load(str(shard)))
        dit.load_weights(list(dit_weights.items()))
        if quantize:
            print(f"Quantizing DiT to {quantize}-bit...")
            nn.quantize(dit, bits=quantize)

        # VAE (decoder only)
        print("Loading VAE...")
        vae = VAEDecoder()
        raw_vae = mx.load(str(path / "vae" / "diffusion_pytorch_model.safetensors"))
        vae.load_weights(list(_convert_vae_weights(raw_vae).items()))

        # Latent stats
        latents_mean = mx.array(vae_config["latents_mean"])
        latents_std = mx.array(vae_config["latents_std"])

        scheduler = FlowMatchEulerScheduler()
        return cls(dit, vae, scheduler, latents_mean, latents_std)

    def generate(self, text_embeddings: Optional[mx.array] = None,
                 neg_text_embeddings: Optional[mx.array] = None, height=1024, width=1024,
                 num_inference_steps=50, guidance_scale=4.0, seed=None) -> Image.Image:
        """Sample one image.

        Args:
            text_embeddings: Conditioning embeddings, batched or 2-D (a batch
                axis is added to 2-D input). Presumably ``[B, T, 4096]`` —
                ``None`` uses a ``(1, 1, 4096)`` zero placeholder.
            neg_text_embeddings: Unconditional embeddings for CFG, batched or
                2-D; zeros shaped like ``text_embeddings`` when ``None``.
            height, width: Requested pixel size; rounded down to multiples of 16.
            num_inference_steps: Number of denoising steps (>= 1).
            guidance_scale: CFG scale; CFG runs only when ``> 1.0``.
            seed: Optional seed for MLX's global RNG.

        Returns:
            The first decoded image as an 8-bit PIL image.
        """
        t_start = time.perf_counter()
        latent_h, latent_w = _latent_grid_size(height, width)

        if text_embeddings is None:
            text_embeddings = mx.zeros((1, 1, 4096))
        text_bth = _add_batch_dim(text_embeddings)

        do_cfg = guidance_scale > 1.0
        neg_bth = None
        if do_cfg:
            neg_bth = (mx.zeros_like(text_bth) if neg_text_embeddings is None
                       else _add_batch_dim(neg_text_embeddings))

        if seed is not None:
            mx.random.seed(seed)

        # Generate noise in latent space, then patchify
        latents = mx.random.normal((1, latent_h, latent_w, LATENT_CHANNELS))
        tokens = patchify(latents, patch_size=PATCH_SIZE)

        # Schedule shift depends on the image sequence length (token count).
        sigmas = _shifted_sigmas(num_inference_steps, tokens.shape[1])
        self.scheduler.sigmas = mx.concatenate([mx.array(sigmas), mx.array([0.0])])
        self.scheduler.timesteps = mx.array(sigmas) * 1000

        for i, t in enumerate(self.scheduler.timesteps):
            # Transformer receives sigma in [0, 1] (timestep / 1000), matching
            # the diffusers pipeline; its timestep embedding rescales it.
            t_normalized = mx.array([t.item() / 1000.0])
            pred = self.dit(tokens, t_normalized, text_bth)
            if do_cfg:
                neg_pred = self.dit(tokens, t_normalized, neg_bth)
                pred = _cfg_with_norm_rescale(pred, neg_pred, guidance_scale)
            # Negate prediction, as the diffusers pipeline does before stepping.
            pred = -pred
            tokens = self.scheduler.step(pred, i, tokens)
            mx.eval(tokens)

        denoise_time = time.perf_counter() - t_start

        latents = unpatchify(tokens, latent_h, latent_w, patch_size=PATCH_SIZE)

        # Denormalize: latents * std + mean. (diffusers computes
        # latents / (1 / config_std) + mean, which is the same thing.)
        mean = self.latents_mean.reshape(1, 1, 1, -1)
        std = self.latents_std.reshape(1, 1, 1, -1)
        latents = latents * std + mean

        # VAE decode
        images = self.vae(latents)
        mx.eval(images)
        total_time = time.perf_counter() - t_start
        print(f"  Denoise: {denoise_time:.1f}s | Decode: {total_time - denoise_time:.1f}s | Total: {total_time:.1f}s")

        # [-1, 1] -> [0, 255]; round (not truncate) before the uint8 cast.
        images = mx.clip(images, -1, 1)
        images = mx.round((images + 1) / 2 * 255).astype(mx.uint8)
        return Image.fromarray(np.array(images[0]))