metadata
library_name: mlx
tags:
  - mlx
  - text-to-image
  - image-generation
  - mixture-of-experts
  - dit
  - apple-silicon
  - nucleus-image
base_model: NucleusAI/Nucleus-Image
license: apache-2.0
pipeline_tag: text-to-image

MLX Nucleus-Image

An MLX port of NucleusAI/Nucleus-Image, a 17B-parameter Mixture-of-Experts (MoE) diffusion transformer (DiT) for text-to-image generation. It runs natively on Apple Silicon.

Sample Outputs (512x512, 30 steps, CFG 4.0, 4-bit)

Prompts:

  • "A red apple on a white table"
  • "A golden retriever puppy playing in autumn leaves"
  • "A futuristic city skyline at sunset with flying cars"
  • "A steaming cup of coffee on a rainy windowsill"
  • "An astronaut riding a horse on the moon, digital art"

Architecture

| Component | Details |
| --- | --- |
| DiT | 17B total params, ~2B active per token |
| Layers | 32 (3 dense + 29 MoE) |
| Experts | 64 routed + 1 shared per MoE layer |
| Routing | Expert-choice (capacity-based) |
| Attention | GQA: 16 query / 4 KV heads, head_dim=128 |
| Text Encoder | Qwen3-VL-8B-Instruct (PyTorch, hybrid) |
| VAE | AutoencoderKLQwenImage, 16-ch latents |
| Scheduler | Flow Matching Euler with dynamic sigma shift |
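
To make the attention row concrete, here is a minimal NumPy sketch of grouped-query attention with the head counts from the table. It is an illustration rather than the code in dit.py: the function name `gqa_attention` is hypothetical, and the mapping of query head h to KV head h // 4 is the usual convention, assumed here.

```python
import numpy as np

NUM_Q_HEADS, NUM_KV_HEADS, HEAD_DIM = 16, 4, 128  # values from the table
GROUP = NUM_Q_HEADS // NUM_KV_HEADS               # 4 query heads share one KV head


def gqa_attention(q, k, v):
    """q: (16, T, 128); k, v: (4, T, 128). Returns (16, T, 128)."""
    # Repeat each KV head so query head h reads KV head h // GROUP (assumed mapping).
    k = np.repeat(k, GROUP, axis=0)
    v = np.repeat(v, GROUP, axis=0)
    logits = q @ k.transpose(0, 2, 1) / np.sqrt(HEAD_DIM)  # (16, T, T)
    logits -= logits.max(axis=-1, keepdims=True)           # stabilize the softmax
    weights = np.exp(logits)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v


rng = np.random.default_rng(0)
q = rng.normal(size=(NUM_Q_HEADS, 6, HEAD_DIM))
k = rng.normal(size=(NUM_KV_HEADS, 6, HEAD_DIM))
v = rng.normal(size=(NUM_KV_HEADS, 6, HEAD_DIM))
out = gqa_attention(q, k, v)
print(out.shape)  # (16, 6, 128)
```

Only 4 sets of keys and values are stored per layer instead of 16, which is the main saving of GQA over full multi-head attention.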

Quick Start

pip install mlx torch transformers huggingface_hub pillow
import torch
import mlx.core as mx
from transformers import AutoProcessor, AutoModel
from nucleus_image.pipeline import NucleusImagePipeline

# Step 1: Encode text (PyTorch — runs once, then freed)
processor = AutoProcessor.from_pretrained(
    "NucleusAI/Nucleus-Image", subfolder="processor", trust_remote_code=True
)
text_model = AutoModel.from_pretrained(
    "NucleusAI/Nucleus-Image", subfolder="text_encoder",
    dtype=torch.bfloat16, trust_remote_code=True
)
text_model.eval()

PROMPT = "A red apple on a white table"
SYSTEM = "You are an image generation assistant."
messages = [
    {"role": "system", "content": SYSTEM},
    {"role": "user", "content": [{"type": "text", "text": PROMPT}]},
]
formatted = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[formatted], return_tensors="pt", padding=True)

with torch.no_grad():
    outputs = text_model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs.get("attention_mask"),
        output_hidden_states=True, use_cache=False,
    )
    hidden = outputs.hidden_states[-8][0]  # 8th from last
text_emb = mx.array(hidden.cpu().float().numpy())

del text_model, processor  # Free ~16GB

# Step 2: Generate image (MLX)
pipe = NucleusImagePipeline.from_pretrained(quantize=4)  # 4-bit DiT

img = pipe.generate(
    text_embeddings=mx.expand_dims(text_emb, 0),
    height=512, width=512,
    num_inference_steps=30,
    guidance_scale=4.0,
    seed=42,
)
img.save("output.png")

Performance (M4 Pro 64GB)

| Resolution | Steps | CFG | Quantization | Time |
| --- | --- | --- | --- | --- |
| 256x256 | 20 | 4.0 | 4-bit | ~54s |
| 512x512 | 20 | 4.0 | 4-bit | ~70s |
| 512x512 | 30 | 4.0 | 4-bit | ~100s |
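
As a rough reading of the two 512x512 rows, a straight-line fit separates the per-step cost from a fixed cost. This is only back-of-the-envelope arithmetic on two approximate measurements; the helper `estimate_seconds` is hypothetical, and what the fixed part actually contains is not stated here.

```python
# Two 512x512 measurements from the table: (steps, seconds).
(steps_a, time_a), (steps_b, time_b) = (20, 70.0), (30, 100.0)

per_step = (time_b - time_a) / (steps_b - steps_a)  # seconds per denoising step
fixed = time_a - per_step * steps_a                 # cost that does not scale with steps


def estimate_seconds(steps: int) -> float:
    """Linear extrapolation; only as good as the two data points behind it."""
    return fixed + per_step * steps


print(per_step, fixed, estimate_seconds(50))  # 3.0 10.0 160.0
```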

How It Works

The port takes a hybrid approach:

  • Text encoder stays in PyTorch (Qwen3-VL-8B, ~16GB). It is loaded once to extract embeddings, then freed.
  • DiT (17B MoE) runs in MLX, with optional 4-bit quantization of the attention/modulation layers. Expert weights stay in bfloat16.
  • VAE decoder runs in MLX (254MB). Conv3d weights are converted to Conv2d by keeping only the last temporal kernel slice, which suffices for a single image because CausalConv3d pads only before the frame.
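
The Conv3d to Conv2d step can be checked numerically. The sketch below uses plain NumPy with hypothetical helper names and assumes the causal padding is zero padding: for a single frame, the 3D convolution matches a 2D convolution that uses only the last temporal slice of the kernel.

```python
import numpy as np


def conv2d(x, w):
    """Valid cross-correlation. x: (C_in, H, W); w: (C_out, C_in, kH, kW)."""
    c_out, _, kh, kw = w.shape
    h_out, w_out = x.shape[1] - kh + 1, x.shape[2] - kw + 1
    out = np.zeros((c_out, h_out, w_out))
    for i in range(h_out):
        for j in range(w_out):
            patch = x[:, i:i + kh, j:j + kw]
            out[:, i, j] = np.tensordot(w, patch, axes=([1, 2, 3], [0, 1, 2]))
    return out


def causal_conv3d_single_frame(x, w):
    """x: one frame (C_in, H, W); w: (C_out, C_in, kT, kH, kW).

    Causal padding places kT - 1 zero frames before the input (zeros assumed),
    so the temporal window seen by the kernel is [0, ..., 0, x].
    """
    kt = w.shape[2]
    padded = np.concatenate([np.zeros((kt - 1, *x.shape)), x[None]], axis=0)
    # Each temporal tap is a 2D convolution over its own frame; sum them up.
    return sum(conv2d(padded[t], w[:, :, t]) for t in range(kt))


rng = np.random.default_rng(0)
w3d = rng.normal(size=(8, 4, 3, 3, 3))
frame = rng.normal(size=(4, 6, 6))

w2d = w3d[:, :, -1]  # keep only the last temporal slice
out3d = causal_conv3d_single_frame(frame, w3d)
out2d = conv2d(frame, w2d)
print(np.allclose(out3d, out2d))  # True
```

MLX convolutions use a channels-last layout, so an actual conversion would likely also transpose the sliced weight; that step is left out of this sketch.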

Key Conversion Details

| PyTorch | MLX | Notes |
| --- | --- | --- |
| CausalConv3d (5D weights) | Conv2d (last temporal slice) | Causal padding means only kernel[:,:,-1,:,:] matters |
| SwiGLU activation | value * silu(gate) (dense), silu(gate) * up (experts) | Different split conventions! |
| NucleusMoEEmbedRope (complex polar) | cos/sin decomposition | scale_rope=True: centered positions [-H/2..H/2] |
| Expert-choice MoE routing | argsort + indicator matrix scatter | Each expert picks top-C tokens |
| AdaLayerNormContinuous | LayerNorm(affine=False) + adaptive scale/shift | scale first, shift second |
| Timesteps(scale=1000) | timestep_embedding(sigma * 1000) | Pipeline normalizes t/1000 before DiT |
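
The expert-choice row is probably the least familiar, so here is a small NumPy sketch of the idea: argsort per expert, a one-hot indicator matrix for gathering, and the same matrix for scattering results back. All names are hypothetical. The softmax gating, the toy linear experts, and the omission of the shared expert are simplifying assumptions, not a description of dit.py.

```python
import numpy as np


def expert_choice_route(scores, capacity):
    """scores: (T, E) router affinities. Each expert picks its top-C tokens.

    Returns token_idx (E, C) and a one-hot indicator `dispatch` (E, C, T).
    """
    num_tokens = scores.shape[0]
    order = np.argsort(-scores.T, axis=1)     # per expert, tokens by descending score
    token_idx = order[:, :capacity]           # (E, C)
    dispatch = np.eye(num_tokens)[token_idx]  # (E, C, T)
    return token_idx, dispatch


def moe_forward(x, scores, expert_weights, capacity):
    """x: (T, D); expert_weights: (E, D, D) toy linear experts."""
    token_idx, dispatch = expert_choice_route(scores, capacity)
    gates = np.take_along_axis(scores.T, token_idx, axis=1)  # (E, C)
    expert_in = dispatch @ x                                  # gather as a matmul: (E, C, D)
    expert_out = expert_in @ expert_weights                   # batched per-expert matmul
    # Scatter back: the indicator routes each weighted output to its token row.
    return np.einsum("ect,ec,ecd->td", dispatch, gates, expert_out)


rng = np.random.default_rng(0)
T, E, D, C = 8, 4, 16, 2
x = rng.normal(size=(T, D))
logits = rng.normal(size=(T, E))
scores = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)  # softmax (assumed)
experts = rng.normal(size=(E, D, D)) * 0.1
y = moe_forward(x, scores, experts, capacity=C)
print(y.shape)  # (8, 16)
```

Unlike token-choice routing, a token here may be picked by several experts or by none; every expert processes exactly C tokens, which keeps the per-expert workload fixed.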

Files

nucleus_image/
  dit.py        — 17B MoE DiT (517 lines)
  vae.py        — VAE decoder with Conv3d→Conv2d (189 lines)
  pipeline.py   — End-to-end pipeline (196 lines)
  scheduler.py  — Flow matching Euler scheduler (24 lines)

~960 lines total for the full port.
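
The scheduler is the smallest file, and the general technique fits in a few lines. The sketch below follows the common flow-matching Euler formulation with an exponential sigma shift; it is not a copy of scheduler.py. The function names are hypothetical, and `mu` is passed in directly with a made-up value, whereas dynamic shifting would derive it from the image size.

```python
import numpy as np


def shifted_sigmas(num_steps, mu):
    """Return num_steps + 1 noise levels running from 1.0 down to 0.0."""
    sigmas = np.linspace(1.0, 1.0 / num_steps, num_steps)  # uniform grid before shifting
    # Exponential time shift: a larger mu keeps sigma high for more of the steps.
    sigmas = np.exp(mu) / (np.exp(mu) + (1.0 / sigmas - 1.0))
    return np.append(sigmas, 0.0)                          # sigma = 0 is the clean sample


def euler_step(x, velocity, sigma, sigma_next):
    """One explicit Euler step of dx/dsigma = velocity."""
    return x + (sigma_next - sigma) * velocity


# Toy check: with the exact constant velocity (noise - data), integrating
# from sigma = 1 to sigma = 0 should recover the data.
rng = np.random.default_rng(0)
data = rng.normal(size=(4, 4))
noise = rng.normal(size=(4, 4))
sigmas = shifted_sigmas(num_steps=30, mu=1.0)  # mu = 1.0 is an illustrative value
x = noise
for s, s_next in zip(sigmas[:-1], sigmas[1:]):
    x = euler_step(x, noise - data, s, s_next)
print(np.allclose(x, data))  # True
```

In the real pipeline the velocity comes from the DiT at each step rather than from a known constant field.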

Acknowledgments