Instructions to use treadon/mlx-nucleus-image with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MLX
How to use treadon/mlx-nucleus-image with MLX:
# Download the model from the Hub
pip install huggingface_hub[hf_xet]
huggingface-cli download --local-dir mlx-nucleus-image treadon/mlx-nucleus-image
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- LM Studio
| """Nucleus-Image MLX Pipeline. | |
| MoE DiT + VAE in MLX, text encoder in PyTorch (hybrid). | |
| """ | |
| import json | |
| import time | |
| from pathlib import Path | |
| from typing import Optional | |
| import mlx.core as mx | |
| import mlx.nn as nn | |
| import numpy as np | |
| from PIL import Image | |
| from huggingface_hub import snapshot_download | |
| from .dit import NucleusMoEDiT | |
| from .vae import VAEDecoder | |
| from .scheduler import FlowMatchEulerScheduler | |
def patchify(x, patch_size=2):
    """Pack an NHWC tensor into a sequence of flattened patches.

    [B, H, W, C] -> [B, (H/p)*(W/p), C*p*p]

    Each token is laid out channels-first, i.e. as [C, ph, pw]; the original
    author notes this is meant to match diffusers' ``_pack_latents``.

    Args:
        x: 4-D array in NHWC order. Anything exposing ``shape``, ``reshape``
            and ``transpose`` (an MLX or NumPy array) works.
        patch_size: Side length ``p`` of the square patches.

    Returns:
        Array of shape ``[B, (H/p)*(W/p), C*p*p]``.

    Raises:
        ValueError: If ``x`` is not 4-D, or if H or W is not a multiple of
            ``patch_size``.
    """
    if len(x.shape) != 4:
        raise ValueError(
            f"patchify expects a 4-D [B, H, W, C] array, got shape {tuple(x.shape)}"
        )
    B, H, W, C = x.shape
    p = patch_size
    # Fail early with a readable message instead of an opaque reshape error.
    if H % p or W % p:
        raise ValueError(
            f"Spatial size ({H}, {W}) must be divisible by patch_size={p}"
        )
    x = x.reshape(B, H // p, p, W // p, p, C)
    # [B, H/p, p, W/p, p, C] -> [B, H/p, W/p, C, p, p]
    x = x.transpose(0, 1, 3, 5, 2, 4)
    return x.reshape(B, (H // p) * (W // p), C * p * p)
def unpatchify(x, h, w, patch_size=2):
    """Unpack a sequence of flattened patches back into an NHWC tensor.

    [B, N, C*p*p] -> [B, H, W, C]

    Inverse of ``patchify``: each token is read as [C, ph, pw].

    Args:
        x: 3-D array of shape ``[B, N, C*p*p]``.
        h: Target height H of the output.
        w: Target width W of the output.
        patch_size: Side length ``p`` of the square patches.

    Returns:
        Array of shape ``[B, h, w, C]``.

    Raises:
        ValueError: If ``x`` is not 3-D, if ``h``/``w`` are not multiples of
            ``patch_size``, if the token width is not a multiple of ``p*p``,
            or if the token count does not equal ``(h/p)*(w/p)``.
    """
    if len(x.shape) != 3:
        raise ValueError(
            f"unpatchify expects a 3-D [B, N, C*p*p] array, got shape {tuple(x.shape)}"
        )
    B, N, D = x.shape
    p = patch_size
    # Without these checks a truncated C = D // (p*p) can, for unlucky sizes,
    # let both reshapes succeed and silently return scrambled data.
    if D % (p * p):
        raise ValueError(
            f"Token width {D} is not a multiple of patch_size**2={p * p}"
        )
    if h % p or w % p:
        raise ValueError(
            f"Spatial size ({h}, {w}) must be divisible by patch_size={p}"
        )
    C = D // (p * p)
    hp, wp = h // p, w // p
    if N != hp * wp:
        raise ValueError(
            f"Expected {hp * wp} tokens for a {h}x{w} output, got {N}"
        )
    x = x.reshape(B, hp, wp, C, p, p)
    # [B, hp, wp, C, p, p] -> [B, hp, p, wp, p, C]
    x = x.transpose(0, 1, 4, 2, 5, 3)
    return x.reshape(B, h, w, C)
class NucleusImagePipeline:
    """Image generation pipeline: DiT denoiser + VAE decoder running in MLX.

    Text encoding is not done here; ``generate`` takes precomputed text
    embeddings (the module docstring says the text encoder runs in PyTorch).
    """

    def __init__(self, dit, vae, scheduler, latents_mean, latents_std):
        """Store the already-constructed pipeline components.

        Args:
            dit: Denoiser, called as ``dit(tokens, timestep, text_embeddings)``.
            vae: Decoder, called on the denormalized NHWC latents.
            scheduler: Object with ``sigmas``/``timesteps`` attributes and a
                ``step(pred, index, tokens)`` method.
            latents_mean: Per-channel latent mean used to denormalize.
            latents_std: Per-channel latent std used to denormalize.
        """
        self.dit = dit
        self.vae = vae
        self.scheduler = scheduler
        self.latents_mean = latents_mean
        self.latents_std = latents_std

    # Declared static so the constructor works both via the class and via an
    # instance; without the decorator an instance call would pass ``self`` as
    # ``model_id``.
    @staticmethod
    def from_pretrained(model_id="NucleusAI/Nucleus-Image", quantize=None):
        """Download a checkpoint from the Hub and build a ready pipeline.

        Args:
            model_id: Hub repository id passed to ``snapshot_download``.
            quantize: If truthy, the number of bits to quantize the DiT to
                via ``nn.quantize``. The VAE is never quantized.

        Returns:
            A ``NucleusImagePipeline`` with weights loaded.
        """
        path = Path(snapshot_download(model_id))

        # Config
        with open(path / "transformer" / "config.json") as f:
            dit_config = json.load(f)
        with open(path / "vae" / "config.json") as f:
            vae_config = json.load(f)

        # DiT: merge every safetensors shard into one weight dict.
        print("Loading DiT...")
        dit = NucleusMoEDiT(dit_config)
        dit_weights = {}
        for shard in sorted((path / "transformer").glob("*.safetensors")):
            dit_weights.update(mx.load(str(shard)))
        dit.load_weights(list(dit_weights.items()))
        if quantize:
            print(f"Quantizing DiT to {quantize}-bit...")
            nn.quantize(dit, bits=quantize)

        # VAE: only decoder-side tensors are kept, converted to MLX layout.
        print("Loading VAE...")
        vae = VAEDecoder()
        raw_vae = mx.load(str(path / "vae" / "diffusion_pytorch_model.safetensors"))
        skipped_prefixes = ("encoder.", "quant_conv", "latents_", "bn.")
        skipped_names = ("spatial_scale_factor", "temporal_scale_factor")
        vae_w = {}
        for k, v in raw_vae.items():
            if k.startswith(skipped_prefixes) or k in skipped_names:
                continue
            if "weight" in k and v.ndim == 5:
                D = v.shape[2]
                # CausalConv3d: for T=1 input with padding=(2*p, 0), only the
                # LAST temporal slice of the kernel contributes
                v = v[:, :, -1, :, :] if D > 1 else v.squeeze(2)
                # [O, I, kH, kW] -> [O, kH, kW, I] (channels-last kernel)
                v = v.transpose(0, 2, 3, 1)
            elif "weight" in k and v.ndim == 4:
                v = v.transpose(0, 2, 3, 1)
            if "gamma" in k:
                v = v.squeeze()
            vae_w[k] = v
        vae.load_weights(list(vae_w.items()))

        # Latent stats
        latents_mean = mx.array(vae_config["latents_mean"])
        latents_std = mx.array(vae_config["latents_std"])

        scheduler = FlowMatchEulerScheduler()
        return NucleusImagePipeline(dit, vae, scheduler, latents_mean, latents_std)

    def generate(self, text_embeddings=None, neg_text_embeddings=None,
                 height=1024, width=1024, num_inference_steps=50,
                 guidance_scale=4.0, seed=None):
        """Run the denoising loop and decode the result to a PIL image.

        Args:
            text_embeddings: Conditioning embeddings, 2-D ``[T, D]`` or 3-D
                ``[B, T, D]``. Defaults to zeros of shape ``(1, 1, 4096)``.
            neg_text_embeddings: Embeddings for the unconditional branch of
                classifier-free guidance, same ranks accepted. Defaults to
                zeros shaped like the conditioning embeddings.
            height: Output height in pixels; must be a positive multiple of 16.
            width: Output width in pixels; must be a positive multiple of 16.
            num_inference_steps: Number of denoising steps (>= 1).
            guidance_scale: CFG scale; guidance is applied only when > 1.0.
            seed: If not None, seeds ``mx.random`` before sampling noise.

        Returns:
            A ``PIL.Image.Image`` built from the first image of the batch.

        Raises:
            ValueError: If ``height``/``width`` are not positive multiples of
                16 or ``num_inference_steps`` is less than 1.
        """
        # The VAE downsamples by 8 and the latents are split into 2x2 patches,
        # so pixel sizes must be multiples of 16 or patchify cannot tile them.
        if height <= 0 or width <= 0 or height % 16 or width % 16:
            raise ValueError(
                f"height and width must be positive multiples of 16, "
                f"got {height}x{width}"
            )
        if num_inference_steps < 1:
            raise ValueError(
                f"num_inference_steps must be >= 1, got {num_inference_steps}"
            )

        t_start = time.time()
        latent_h = height // 8  # VAE is 8x
        latent_w = width // 8

        if text_embeddings is None:
            text_embeddings = mx.zeros((1, 1, 4096))
        text_bth = mx.expand_dims(text_embeddings, 0) if text_embeddings.ndim == 2 else text_embeddings

        do_cfg = guidance_scale > 1.0
        if do_cfg:
            if neg_text_embeddings is None:
                neg_text_embeddings = mx.zeros_like(text_bth)
            elif neg_text_embeddings.ndim == 2:
                # Same batch-dim promotion the positive embeddings receive.
                neg_text_embeddings = mx.expand_dims(neg_text_embeddings, 0)

        if seed is not None:
            mx.random.seed(seed)

        # Generate noise in latent space, then patchify
        latents = mx.random.normal((1, latent_h, latent_w, 16))
        tokens = patchify(latents, patch_size=2)

        # Calculate dynamic shift based on image sequence length
        image_seq_len = tokens.shape[1]
        base_seq_len = 256
        max_seq_len = 4096
        base_shift = 0.5
        max_shift = 1.15
        # Linear interpolation of shift based on sequence length
        m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
        b = base_shift - m * base_seq_len
        mu = image_seq_len * m + b

        sigmas = np.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps)
        # Apply shift: sigma_shifted = exp(mu) * sigma / (1 + (exp(mu) - 1) * sigma)
        shift = np.exp(mu)
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        # The scheduler gets a trailing 0.0 sigma so the last step lands on
        # the clean sample.
        self.scheduler.sigmas = mx.concatenate([mx.array(sigmas), mx.array([0.0])])
        self.scheduler.timesteps = mx.array(sigmas) * 1000

        for i, t in enumerate(self.scheduler.timesteps):
            # Normalize: divide by num_train_timesteps (1000) matching diffusers pipeline
            # Transformer receives sigma (0-1), Timesteps(scale=1000) handles the rest
            t_normalized = mx.array([t.item() / 1000.0])
            pred = self.dit(tokens, t_normalized, text_bth)
            if do_cfg:
                neg_pred = self.dit(tokens, t_normalized, neg_text_embeddings)
                # CFG with norm rescaling: keep the guided prediction's
                # per-token norm equal to the conditional prediction's.
                comb = neg_pred + guidance_scale * (pred - neg_pred)
                cond_norm = mx.sqrt(mx.sum(pred * pred, axis=-1, keepdims=True) + 1e-8)
                noise_norm = mx.sqrt(mx.sum(comb * comb, axis=-1, keepdims=True) + 1e-8)
                pred = comb * (cond_norm / noise_norm)
            # Negate prediction (from diffusers pipeline line 597)
            pred = -pred
            tokens = self.scheduler.step(pred, i, tokens)
            # Force evaluation each step so the lazy graph does not grow
            # across the whole loop.
            mx.eval(tokens)
        denoise_time = time.time() - t_start

        # Unpatchify
        latents = unpatchify(tokens, latent_h, latent_w, patch_size=2)

        # Denormalize: latents * std + mean
        # diffusers computes: latents_std_inv = 1/config_std, then latents / std_inv = latents * config_std
        mean = self.latents_mean.reshape(1, 1, 1, -1)
        std = self.latents_std.reshape(1, 1, 1, -1)
        latents = latents * std + mean

        # VAE decode
        images = self.vae(latents)
        mx.eval(images)
        total_time = time.time() - t_start
        print(f" Denoise: {denoise_time:.1f}s | Decode: {total_time - denoise_time:.1f}s | Total: {total_time:.1f}s")

        # Map [-1, 1] to [0, 255] and hand the first image to PIL.
        images = mx.clip(images, -1, 1)
        images = ((images + 1) / 2 * 255).astype(mx.uint8)
        return Image.fromarray(np.array(images[0]))