---
library_name: mlx
tags:
  - mlx
  - text-to-image
  - image-generation
  - mixture-of-experts
  - dit
  - apple-silicon
  - nucleus-image
base_model: NucleusAI/Nucleus-Image
license: apache-2.0
pipeline_tag: text-to-image
---

MLX Nucleus-Image

An MLX port of NucleusAI/Nucleus-Image, a 17B-parameter Mixture-of-Experts (MoE) diffusion transformer (DiT) for text-to-image generation. It runs natively on Apple Silicon.

Sample Outputs (512x512, 30 steps, CFG 4.0, 4-bit)

Prompts:

- "A red apple on a white table"
- "A golden retriever puppy playing in autumn leaves"
- "A futuristic city skyline at sunset"
- "A steaming cup of coffee on a rainy windowsill"
- "An astronaut riding a horse on the moon"

Quick Start

1. Clone the repo

```bash
git clone https://huggingface.co/treadon/mlx-nucleus-image
cd mlx-nucleus-image
```

2. Install dependencies

```bash
pip install mlx torch transformers huggingface_hub pillow
```

3. Generate an image

```bash
python generate.py --prompt "A red apple on a white table" --seed 42
```

The first run downloads ~34GB of weights (cached for subsequent runs).

More options

```bash
python generate.py \
  --prompt "A futuristic city skyline at sunset" \
  --height 512 --width 512 \
  --steps 30 --cfg 4.0 \
  --seed 42 --output city.png \
  --quantize 4
```

| Flag | Default | Description |
|---|---|---|
| `--prompt` | (required) | Text prompt |
| `--height` | 512 | Image height in pixels |
| `--width` | 512 | Image width in pixels |
| `--steps` | 50 | Number of denoising steps |
| `--cfg` | 4.0 | Classifier-free guidance scale |
| `--seed` | random | Random seed |
| `--output` | output.png | Output file path |
| `--quantize` | 4 | Quantization bits (4, 8, or None) |

Architecture

| Component | Details |
|---|---|
| DiT | 17B total params, ~2B active per token |
| Layers | 32 (3 dense + 29 MoE) |
| Experts | 64 routed + 1 shared per MoE layer |
| Routing | Expert-choice (capacity-based) |
| Attention | GQA: 16 query / 4 KV heads, head_dim=128 |
| Text Encoder | Qwen3-VL-8B-Instruct (PyTorch, hybrid) |
| VAE | AutoencoderKLQwenImage, 16-ch latents |
| Scheduler | Flow Matching Euler |
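A flow-matching Euler scheduler integrates the DiT's predicted velocity from pure noise (sigma = 1) toward the image (sigma = 0), and classifier-free guidance (the `--cfg` scale) blends conditional and unconditional predictions at each step. The sketch below is a minimal illustration under stated assumptions, not the code in `scheduler.py`: `velocity_fn` is a hypothetical stand-in for the DiT, the model is assumed to predict the velocity `noise - x0` under the interpolation `x = (1 - sigma) * x0 + sigma * noise`, and sigmas are spaced linearly (the real scheduler may shift them).

```python
from typing import Callable, List

Vector = List[float]


def cfg_combine(v_uncond: Vector, v_cond: Vector, guidance_scale: float) -> Vector:
    # Classifier-free guidance: move from the unconditional prediction toward the conditional one.
    return [u + guidance_scale * (c - u) for u, c in zip(v_uncond, v_cond)]


def euler_flow_matching(
    x: Vector,
    velocity_fn: Callable[[Vector, float, bool], Vector],  # hypothetical DiT stand-in
    num_steps: int,
    guidance_scale: float,
) -> Vector:
    # Assumed linear schedule 1.0 -> 0.0; a production scheduler may apply a timestep shift.
    sigmas = [1.0 - i / num_steps for i in range(num_steps + 1)]
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        v_cond = velocity_fn(x, sigma, True)     # prompt embedding
        v_uncond = velocity_fn(x, sigma, False)  # empty-prompt embedding
        v = cfg_combine(v_uncond, v_cond, guidance_scale)
        # Euler step along the ODE dx/dsigma = v (sigma decreases, so the step is negative).
        x = [xi + (sigma_next - sigma) * vi for xi, vi in zip(x, v)]
    return x
```

With an oracle velocity that knows the clean sample, the straight-line path is integrated exactly, which is a quick sanity check of the update rule: `euler_flow_matching(noise, lambda x, s, c: [(xi - t) / s for xi, t in zip(x, x0)], 10, 4.0)` returns (up to rounding) `x0`.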

Performance (M4 Pro 64GB)

| Resolution | Steps | CFG | Quantization | Time |
|---|---|---|---|---|
| 256x256 | 20 | 4.0 | 4-bit | ~54s |
| 512x512 | 20 | 4.0 | 4-bit | ~70s |
| 512x512 | 30 | 4.0 | 4-bit | ~100s |
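A rough reading of the two 512x512 rows, assuming wall time grows linearly with the number of steps: the points give a per-step cost and a fixed overhead (which would cover work outside the denoising loop). This is back-of-the-envelope arithmetic on the approximate timings above, not an additional benchmark, and the 50-step figure is an extrapolation.

```python
from typing import Tuple


def linear_fit(p1: Tuple[int, float], p2: Tuple[int, float]) -> Tuple[float, float]:
    # Fit time = overhead + per_step * steps through two (steps, seconds) points.
    (s1, t1), (s2, t2) = p1, p2
    per_step = (t2 - t1) / (s2 - s1)
    overhead = t1 - per_step * s1
    return per_step, overhead


per_step, overhead = linear_fit((20, 70.0), (30, 100.0))
print(f"{per_step:.1f} s/step, ~{overhead:.0f} s fixed")  # 3.0 s/step, ~10 s fixed
# Extrapolate to the CLI default of 50 steps (estimate only).
print(f"50 steps: ~{overhead + 50 * per_step:.0f} s")      # 50 steps: ~160 s
```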

How It Works

The port uses a hybrid approach:

- Text encoder stays in PyTorch (Qwen3-VL-8B, ~16GB). It is loaded once to extract embeddings, then freed.
- DiT (17B MoE) runs in MLX, with optional 4-bit quantization of the attention and modulation layers. Expert weights stay in bfloat16.
- VAE decoder runs in MLX (~50MB converted). The original CausalConv3d weights are converted to Conv2d.
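Why the CausalConv3d to Conv2d conversion is exact for still images: causal padding prepends `kT - 1` zero frames in time, so for a single input frame every temporal tap except the last multiplies zeros. PyTorch stores Conv3d weights as `(out_ch, in_ch, kT, kH, kW)`, so the surviving tap is `kernel[:, :, -1, :, :]`. The pure-Python sketch below checks this on one channel with zero "same" spatial padding; the helper names are hypothetical, and real layers also mix channels, which does not change the per-tap argument.

```python
from typing import List

Grid = List[List[float]]


def conv2d_same(x: Grid, k: Grid) -> Grid:
    # Single-channel 2D cross-correlation with zero "same" padding (odd kernel sizes).
    h, w = len(x), len(x[0])
    kh, kw = len(k), len(k[0])
    ph, pw = kh // 2, kw // 2
    out = [[0.0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            acc = 0.0
            for a in range(kh):
                for b in range(kw):
                    ii, jj = i + a - ph, j + b - pw
                    if 0 <= ii < h and 0 <= jj < w:
                        acc += k[a][b] * x[ii][jj]
            out[i][j] = acc
    return out


def causal_conv3d_single_frame(frame: Grid, k3: List[Grid]) -> Grid:
    # Causal temporal padding: prepend kT - 1 all-zero frames before the real frame.
    kt = len(k3)
    zeros = [[0.0] * len(frame[0]) for _ in frame]
    frames = [zeros] * (kt - 1) + [frame]
    # The single output frame sums the 2D convolution of every temporal tap.
    out = [[0.0] * len(frame[0]) for _ in frame]
    for t in range(kt):
        partial = conv2d_same(frames[t], k3[t])
        out = [[o + p for o, p in zip(orow, prow)] for orow, prow in zip(out, partial)]
    return out
```

For any frame and kernel, `causal_conv3d_single_frame(frame, k3)` equals `conv2d_same(frame, k3[-1])`, which is the justification for keeping only the last temporal slice.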

Key Conversion Details

| PyTorch | MLX | Notes |
|---|---|---|
| CausalConv3d (5D weights) | Conv2d (last temporal slice) | With causal padding and a single frame, only `kernel[:,:,-1,:,:]` contributes |
| SwiGLU activation | `value * silu(gate)` (dense), `silu(gate) * up` (experts) | Different split conventions! |
| NucleusMoEEmbedRope (complex polar) | cos/sin decomposition | `scale_rope=True`: centered positions [-H/2..H/2] |
| Expert-choice MoE routing | argsort + indicator matrix scatter | Each expert picks its top-C tokens |
| AdaLayerNormContinuous | LayerNorm(affine=False) + adaptive scale/shift | Scale first, shift second |
| Timesteps(scale=1000) | `timestep_embedding(sigma * 1000)` | Pipeline normalizes t/1000 before the DiT |
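The SwiGLU row warns that dense layers and experts split their projections differently. A minimal sketch, assuming each gated projection is stored fused and chunked in half: dense blocks are taken as `[value | gate]` (the order used by diffusers' SwiGLU module) and experts as `[gate | up]`. Both orders are assumptions about the checkpoint layout; the point is that applying one convention to the other's weights silently gives different results.

```python
import math
from typing import List, Tuple

Vector = List[float]


def silu(x: float) -> float:
    return x / (1.0 + math.exp(-x))


def split_half(h: Vector) -> Tuple[Vector, Vector]:
    n = len(h) // 2
    return h[:n], h[n:]


def swiglu_dense(h: Vector) -> Vector:
    # Assumed dense convention: first half is the value, second half the gate.
    value, gate = split_half(h)
    return [v * silu(g) for v, g in zip(value, gate)]


def swiglu_expert(h: Vector) -> Vector:
    # Assumed expert convention: first half is the gate, second half the up projection.
    gate, up = split_half(h)
    return [silu(g) * u for g, u in zip(gate, up)]
```

Swapping the two halves, for example `swiglu_expert(h[2:] + h[:2])` for a length-4 `h`, reproduces `swiglu_dense(h)`, which is a cheap check when validating converted weights.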
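The RoPE row replaces complex multiplication by `e^{i*theta}` with real cos/sin arithmetic, and `scale_rope=True` centers spatial positions around zero. A sketch under assumptions: positions follow the Qwen-Image-style centering `[-(n - n//2), ..., n//2 - 1]`, features are rotated as interleaved pairs `(x0, x1), (x2, x3), ...`, and the base is 10000. These helpers are hypothetical and cover one axis only; the real embedding combines several axes.

```python
import math
from typing import List, Tuple

Vector = List[float]


def centered_positions(n: int) -> List[int]:
    # Assumed centering for scale_rope=True, e.g. n=4 -> [-2, -1, 0, 1].
    return list(range(-(n - n // 2), n // 2))


def rope_cos_sin(positions: List[int], dim: int, theta: float = 10000.0) -> Tuple[List[Vector], List[Vector]]:
    inv_freq = [theta ** (-2 * i / dim) for i in range(dim // 2)]
    cos = [[math.cos(p * f) for f in inv_freq] for p in positions]
    sin = [[math.sin(p * f) for f in inv_freq] for p in positions]
    return cos, sin


def apply_rope_pairs(x: Vector, cos_row: Vector, sin_row: Vector) -> Vector:
    # Treat (a, b) as a + ib and multiply by cos + i*sin without complex dtypes.
    out: Vector = []
    for k, (c, s) in enumerate(zip(cos_row, sin_row)):
        a, b = x[2 * k], x[2 * k + 1]
        out.extend([a * c - b * s, a * s + b * c])
    return out
```

Each output pair matches `complex(a, b) * complex(c, s)`, which is the identity the cos/sin decomposition relies on.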
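In expert-choice routing the experts pick tokens rather than tokens picking experts: each expert sorts all tokens by its routing score and keeps the top C, and gated outputs are scattered back to their positions. The sketch assumes token-wise softmax scores and a capacity of `ceil(tokens * capacity_factor / experts)`; both are illustrative choices, the helper names are hypothetical, and the shared expert (applied to every token) is omitted.

```python
import math
from typing import Callable, List, Tuple

Vector = List[float]


def expert_capacity(num_tokens: int, num_experts: int, capacity_factor: float = 1.0) -> int:
    # Hypothetical capacity rule: on average each token is processed capacity_factor times.
    return math.ceil(num_tokens * capacity_factor / num_experts)


def softmax(row: Vector) -> Vector:
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def expert_choice_route(logits: List[Vector], capacity: int) -> List[List[Tuple[int, float]]]:
    # logits[t][e]; returns, per expert, the (token, gate) pairs it selected.
    probs = [softmax(row) for row in logits]
    num_tokens, num_experts = len(logits), len(logits[0])
    routes = []
    for e in range(num_experts):
        # argsort this expert's column in descending order and keep the top-C tokens.
        order = sorted(range(num_tokens), key=lambda t, e=e: probs[t][e], reverse=True)
        routes.append([(t, probs[t][e]) for t in order[:capacity]])
    return routes


def combine(x: List[Vector], routes: List[List[Tuple[int, float]]],
            experts: List[Callable[[Vector], Vector]]) -> List[Vector]:
    # Scatter-add gated expert outputs back to their token positions.
    out = [[0.0] * len(x[0]) for _ in x]
    for e, picks in enumerate(routes):
        for t, gate in picks:
            y = experts[e](x[t])
            out[t] = [o + gate * yi for o, yi in zip(out[t], y)]
    return out
```

A consequence of this design is that a token can be picked by several experts or by none; with a small capacity, unpicked tokens receive no routed output (only the shared expert would still apply in the full model).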
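"Scale first, shift second" refers to how the conditioning projection is chunked in AdaLayerNormContinuous: the input is normalized without learned affine parameters, then modulated as `norm(x) * (1 + scale) + shift`. In this sketch `emb` is assumed to already be the output of the conditioning projection (length `2 * D`); swapping the two chunks changes the result, which is an easy conversion bug to miss.

```python
import math
from typing import List

Vector = List[float]


def layer_norm(x: Vector, eps: float = 1e-6) -> Vector:
    # LayerNorm(affine=False): zero mean, unit variance, no learned weight or bias.
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    inv = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * inv for v in x]


def ada_layer_norm_continuous(x: Vector, emb: Vector) -> Vector:
    # emb is chunked as (scale, shift), in that order.
    d = len(x)
    scale, shift = emb[:d], emb[d:]
    return [n * (1.0 + sc) + sh for n, sc, sh in zip(layer_norm(x), scale, shift)]
```

With `emb = [0]*4 + [1]*4` the output is the normalized input plus one; with the halves swapped it would instead be the normalized input times two.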
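The timestep row reconciles two conventions: the pipeline hands the DiT a normalized value `sigma = t / 1000`, and the embedding multiplies it back by 1000 before the sinusoidal encoding. The sketch follows the common diffusers-style sinusoidal formula; `flip_sin_to_cos=True`, `downscale_freq_shift=0` and `max_period=10000` are assumed defaults, not values confirmed for Nucleus-Image.

```python
import math
from typing import List


def timestep_embedding(t: float, dim: int, max_period: float = 10000.0,
                       flip_sin_to_cos: bool = True, downscale_freq_shift: float = 0.0,
                       scale: float = 1.0) -> List[float]:
    # Geometric frequency ladder over half the embedding width (dim assumed even).
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / (half - downscale_freq_shift)) for i in range(half)]
    args = [scale * t * f for f in freqs]  # scale=1000 undoes the pipeline's t/1000
    sin = [math.sin(a) for a in args]
    cos = [math.cos(a) for a in args]
    return cos + sin if flip_sin_to_cos else sin + cos
```

Embedding `sigma = 0.5` with `scale=1000.0` yields exactly the same vector as embedding the raw timestep `500.0` with `scale=1.0`, so the two conventions are interchangeable as long as exactly one of them applies the factor of 1000.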

Python API

```python
import gc

import torch
import mlx.core as mx
from transformers import AutoProcessor, AutoModel
from nucleus_image import NucleusImagePipeline

# Encode text with PyTorch (the encoder is loaded once, then freed)
processor = AutoProcessor.from_pretrained("NucleusAI/Nucleus-Image", subfolder="processor", trust_remote_code=True)
text_model = AutoModel.from_pretrained("NucleusAI/Nucleus-Image", subfolder="text_encoder", dtype=torch.bfloat16, trust_remote_code=True).eval()

SYSTEM = "You are an image generation assistant."

def encode(prompt):
    messages = [{"role": "system", "content": SYSTEM}, {"role": "user", "content": [{"type": "text", "text": prompt}]}]
    formatted = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[formatted], return_tensors="pt", padding=True)
    with torch.no_grad():
        out = text_model(input_ids=inputs["input_ids"], attention_mask=inputs.get("attention_mask"),
                         output_hidden_states=True, use_cache=False)
    # 8th-from-last hidden layer of the single prompt; cast to float32 because NumPy has no bfloat16
    return mx.array(out.hidden_states[-8][0].cpu().float().numpy())

text_emb = encode("A red apple on a white table")
neg_emb = encode("")  # empty negative prompt for classifier-free guidance
del text_model, processor  # drop references to the ~16GB encoder
gc.collect()  # collect any reference cycles still holding encoder tensors

# Generate with MLX
pipe = NucleusImagePipeline.from_pretrained(quantize=4)
img = pipe.generate(
    text_embeddings=mx.expand_dims(text_emb, 0),      # add batch dimension
    neg_text_embeddings=mx.expand_dims(neg_emb, 0),
    height=512, width=512, num_inference_steps=30, guidance_scale=4.0, seed=42,
)
img.save("output.png")
```

Files

```text
generate.py              # CLI entry point
samples/                 # Pre-generated sample images
nucleus_image/
  dit.py                 # 17B MoE DiT
  vae.py                 # VAE decoder (Conv3d -> Conv2d)
  pipeline.py            # End-to-end pipeline
  scheduler.py           # Flow matching Euler scheduler
dit/                     # Pre-converted DiT weights (safetensors)
vae/                     # Pre-converted VAE weights (safetensors)
```

Acknowledgments