Instructions to use treadon/mlx-nucleus-image with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MLX
How to use treadon/mlx-nucleus-image with MLX:
# Download the model from the Hub
pip install huggingface_hub[hf_xet]
huggingface-cli download --local-dir mlx-nucleus-image treadon/mlx-nucleus-image
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- LM Studio
File size: 7,108 Bytes
"""Nucleus-Image MLX Pipeline.
MoE DiT + VAE in MLX, text encoder in PyTorch (hybrid).
"""
import json
import time
from pathlib import Path
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download
from .dit import NucleusMoEDiT
from .vae import VAEDecoder
from .scheduler import FlowMatchEulerScheduler
def patchify(x, patch_size=2):
    """Fold non-overlapping ``p x p`` spatial patches of an NHWC tensor into tokens.

    Shapes: ``[B, H, W, C]`` -> ``[B, (H/p)*(W/p), C*p*p]``. Tokens are ordered
    row-major over the patch grid, and each token's feature vector is laid out
    channel-major as ``[C, ph, pw]``, matching diffusers ``_pack_latents``.
    """
    batch, height, width, channels = x.shape
    rows, cols = height // patch_size, width // patch_size
    # Split each spatial axis into (grid index, intra-patch offset):
    # [B, rows, ph, cols, pw, C]
    split = x.reshape(batch, rows, patch_size, cols, patch_size, channels)
    # Bring the grid axes together and move channels ahead of the patch offsets:
    # [B, rows, cols, C, ph, pw]
    grouped = split.transpose(0, 1, 3, 5, 2, 4)
    token_dim = channels * patch_size * patch_size
    return grouped.reshape(batch, rows * cols, token_dim)
def unpatchify(x, h, w, patch_size=2):
    """Unfold a token sequence back into an NHWC grid; inverse of ``patchify``.

    Shapes: ``[B, N, C*p*p]`` -> ``[B, h, w, C]`` where ``N`` must equal
    ``(h/p)*(w/p)``. Each token's features are read as ``[C, ph, pw]``.
    """
    batch, _num_tokens, features = x.shape
    channels = features // (patch_size * patch_size)
    rows, cols = h // patch_size, w // patch_size
    # Expose the grid and the channel-major patch layout:
    # [B, rows, cols, C, ph, pw]
    grid = x.reshape(batch, rows, cols, channels, patch_size, patch_size)
    # Interleave grid and intra-patch offsets, channels last:
    # [B, rows, ph, cols, pw, C]
    interleaved = grid.transpose(0, 1, 4, 2, 5, 3)
    return interleaved.reshape(batch, h, w, channels)
class NucleusImagePipeline:
    """Nucleus-Image text-to-image pipeline (MoE DiT denoiser + VAE decoder, in MLX).

    Text embeddings are computed by the caller and passed to ``generate``;
    this class runs the flow-matching denoising loop and the VAE decode.
    """

    # Pixels per latent cell along each spatial axis (latent size = image size // 8).
    VAE_SCALE_FACTOR = 8
    # Side of the square latent patches that become DiT tokens.
    PATCH_SIZE = 2
    # Channels of the sampled latent noise.
    LATENT_CHANNELS = 16
    # Last-axis size of the all-zero placeholder used when no text embedding is given.
    TEXT_EMBED_DIM = 4096
    # VAE checkpoint entries the decoder-only MLX VAE does not load.
    # Note: "quant_conv" does not match "post_quant_conv", which is kept.
    _VAE_SKIP_PREFIXES = ("encoder.", "quant_conv", "latents_", "bn.")
    _VAE_SKIP_KEYS = ("spatial_scale_factor", "temporal_scale_factor")

    def __init__(self, dit, vae, scheduler, latents_mean, latents_std):
        """Store the pipeline components.

        Args:
            dit: Denoiser, called as ``dit(tokens, t, text_embeddings)``.
            vae: VAE decoder, called on NHWC latents.
            scheduler: Flow-matching scheduler. ``generate`` overwrites its
                ``sigmas``/``timesteps`` and calls ``step(pred, i, tokens)``.
            latents_mean: Per-channel mean used to denormalize latents.
            latents_std: Per-channel std used to denormalize latents.
        """
        self.dit = dit
        self.vae = vae
        self.scheduler = scheduler
        self.latents_mean = latents_mean
        self.latents_std = latents_std

    @classmethod
    def from_pretrained(cls, model_id: str = "NucleusAI/Nucleus-Image",
                        quantize: Optional[int] = None) -> "NucleusImagePipeline":
        """Fetch a checkpoint via ``snapshot_download`` and build the pipeline.

        Args:
            model_id: Hugging Face Hub repo id. The repo is expected to contain
                ``transformer/`` (config + safetensors shards) and ``vae/``
                (config + ``diffusion_pytorch_model.safetensors``).
            quantize: Bit width passed to ``mlx.nn.quantize`` for the DiT.
                A falsy value disables quantization.

        Returns:
            A ready-to-use pipeline instance (of ``cls``).
        """
        path = Path(snapshot_download(model_id))
        with open(path / "vae" / "config.json", encoding="utf-8") as fh:
            vae_config = json.load(fh)

        dit = cls._load_dit(path / "transformer", quantize)
        vae = cls._load_vae(path / "vae")

        latents_mean = mx.array(vae_config["latents_mean"])
        latents_std = mx.array(vae_config["latents_std"])
        return cls(dit, vae, FlowMatchEulerScheduler(), latents_mean, latents_std)

    @staticmethod
    def _load_dit(transformer_dir, quantize=None):
        """Build the DiT from its config, load every shard, optionally quantize."""
        with open(transformer_dir / "config.json", encoding="utf-8") as fh:
            dit_config = json.load(fh)
        print("Loading DiT...")
        dit = NucleusMoEDiT(dit_config)
        weights = {}
        for shard in sorted(transformer_dir.glob("*.safetensors")):
            weights.update(mx.load(str(shard)))
        dit.load_weights(list(weights.items()))
        if quantize:
            print(f"Quantizing DiT to {quantize}-bit...")
            nn.quantize(dit, bits=quantize)
        return dit

    @classmethod
    def _load_vae(cls, vae_dir):
        """Build the MLX VAE decoder and load converted PyTorch weights into it."""
        print("Loading VAE...")
        vae = VAEDecoder()
        raw = mx.load(str(vae_dir / "diffusion_pytorch_model.safetensors"))
        vae.load_weights(list(cls._convert_vae_weights(raw).items()))
        return vae

    @classmethod
    def _convert_vae_weights(cls, raw):
        """Drop unused VAE entries and reorder conv kernels to channels-last.

        PyTorch conv kernels are ``[O, I, (D,) kH, kW]``. They are converted
        to ``[O, kH, kW, I]``, the layout used by MLX convolutions.
        """
        converted = {}
        for key, value in raw.items():
            if key.startswith(cls._VAE_SKIP_PREFIXES) or key in cls._VAE_SKIP_KEYS:
                continue
            if "weight" in key and value.ndim == 5:
                # CausalConv3d: for T=1 input with padding=(2*p, 0), only the
                # LAST temporal slice of the kernel contributes (this also
                # covers depth 1, where it equals squeezing the axis).
                value = value[:, :, -1, :, :].transpose(0, 2, 3, 1)
            elif "weight" in key and value.ndim == 4:
                value = value.transpose(0, 2, 3, 1)
            if "gamma" in key:
                value = value.squeeze()
            converted[key] = value
        return converted

    @staticmethod
    def _ensure_batched(embeddings):
        """Prepend a batch axis to 2-D embeddings (presumably ``[tokens, hidden]``)."""
        return mx.expand_dims(embeddings, 0) if embeddings.ndim == 2 else embeddings

    @staticmethod
    def _shifted_sigmas(image_seq_len, num_inference_steps,
                        base_seq_len=256, max_seq_len=4096,
                        base_shift=0.5, max_shift=1.15):
        """Return the resolution-dependent, shifted flow-matching sigma schedule.

        ``mu`` is linearly interpolated from the token count, and then
        ``sigma' = e^mu * sigma / (1 + (e^mu - 1) * sigma)`` is applied to a
        linear schedule from 1 down to ``1 / num_inference_steps``.
        """
        slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
        intercept = base_shift - slope * base_seq_len
        mu = image_seq_len * slope + intercept
        sigmas = np.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps)
        shift = np.exp(mu)
        return shift * sigmas / (1 + (shift - 1) * sigmas)

    def generate(self, text_embeddings=None, neg_text_embeddings=None,
                 height: int = 1024, width: int = 1024, num_inference_steps: int = 50,
                 guidance_scale: float = 4.0, seed: Optional[int] = None) -> "Image.Image":
        """Run denoising and VAE decoding; return the first image as a PIL image.

        Args:
            text_embeddings: Conditioning embeddings. 2-D inputs get a batch axis
                prepended. ``None`` uses an all-zero ``(1, 1, TEXT_EMBED_DIM)``
                placeholder.
            neg_text_embeddings: Negative embeddings for classifier-free
                guidance. They are batched like ``text_embeddings``. ``None``
                uses zeros shaped like the positive embeddings.
            height, width: Output size in pixels. Each must be a positive
                multiple of ``VAE_SCALE_FACTOR * PATCH_SIZE`` (16).
            num_inference_steps: Number of scheduler steps (>= 1).
            guidance_scale: CFG is applied only when this is > 1.0.
            seed: Optional seed for ``mx.random``.

        Raises:
            ValueError: If the image size or step count is invalid.
        """
        multiple = self.VAE_SCALE_FACTOR * self.PATCH_SIZE
        if height <= 0 or width <= 0 or height % multiple or width % multiple:
            raise ValueError(
                f"height and width must be positive multiples of {multiple}, "
                f"got {height}x{width}"
            )
        if num_inference_steps < 1:
            raise ValueError(f"num_inference_steps must be >= 1, got {num_inference_steps}")

        t_start = time.perf_counter()
        latent_h = height // self.VAE_SCALE_FACTOR
        latent_w = width // self.VAE_SCALE_FACTOR

        if text_embeddings is None:
            text_embeddings = mx.zeros((1, 1, self.TEXT_EMBED_DIM))
        text_bth = self._ensure_batched(text_embeddings)

        do_cfg = guidance_scale > 1.0
        neg_bth = None
        if do_cfg:
            # Batch the negative branch exactly like the positive one so both
            # DiT calls receive conditioning of the same rank.
            neg_bth = (mx.zeros_like(text_bth) if neg_text_embeddings is None
                       else self._ensure_batched(neg_text_embeddings))

        if seed is not None:
            mx.random.seed(seed)

        # Sample noise in latent space, then patchify into DiT tokens.
        latents = mx.random.normal((1, latent_h, latent_w, self.LATENT_CHANNELS))
        tokens = patchify(latents, patch_size=self.PATCH_SIZE)

        sigmas = self._shifted_sigmas(tokens.shape[1], num_inference_steps)
        self.scheduler.sigmas = mx.concatenate([mx.array(sigmas), mx.array([0.0])])
        self.scheduler.timesteps = mx.array(sigmas) * 1000

        for i, t in enumerate(self.scheduler.timesteps):
            # Transformer receives sigma (0-1): divide by num_train_timesteps
            # (1000), matching the diffusers pipeline.
            t_normalized = mx.array([t.item() / 1000.0])
            pred = self.dit(tokens, t_normalized, text_bth)
            if do_cfg:
                neg_pred = self.dit(tokens, t_normalized, neg_bth)
                # CFG, then rescale each token's prediction so its L2 norm over
                # the last axis matches the conditional prediction's norm.
                comb = neg_pred + guidance_scale * (pred - neg_pred)
                cond_norm = mx.sqrt(mx.sum(pred * pred, axis=-1, keepdims=True) + 1e-8)
                noise_norm = mx.sqrt(mx.sum(comb * comb, axis=-1, keepdims=True) + 1e-8)
                pred = comb * (cond_norm / noise_norm)
            # Negate the prediction before stepping, as the diffusers pipeline does.
            pred = -pred
            tokens = self.scheduler.step(pred, i, tokens)
            mx.eval(tokens)

        denoise_time = time.perf_counter() - t_start

        latents = unpatchify(tokens, latent_h, latent_w, patch_size=self.PATCH_SIZE)
        # Denormalize: latents * std + mean. diffusers computes
        # latents / (1 / config_std), which is the same as latents * config_std.
        mean = self.latents_mean.reshape(1, 1, 1, -1)
        std = self.latents_std.reshape(1, 1, 1, -1)
        latents = latents * std + mean

        images = self.vae(latents)
        mx.eval(images)
        total_time = time.perf_counter() - t_start
        print(f" Denoise: {denoise_time:.1f}s | Decode: {total_time - denoise_time:.1f}s | Total: {total_time:.1f}s")

        # Map [-1, 1] -> [0, 255]; round (rather than truncate) before the
        # uint8 cast so pixel values are not biased downward.
        images = mx.clip(images, -1, 1)
        images = mx.round((images + 1) / 2 * 255).astype(mx.uint8)
        return Image.fromarray(np.array(images[0]))
|