Instructions to use treadon/mlx-nucleus-image with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MLX
How to use treadon/mlx-nucleus-image with MLX:
# Download the model from the Hub
pip install huggingface_hub[hf_xet]
huggingface-cli download --local-dir mlx-nucleus-image treadon/mlx-nucleus-image
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- LM Studio
Upload nucleus_image/pipeline.py with huggingface_hub
Browse files
nucleus_image/pipeline.py
CHANGED
|
@@ -119,6 +119,10 @@ class NucleusImagePipeline:
|
|
| 119 |
latents = mx.random.normal((1, latent_h, latent_w, 16))
|
| 120 |
tokens = patchify(latents, patch_size=2)
|
| 121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
# Sigma schedule: raw linspace, no shift
|
| 123 |
# (scheduler config: use_dynamic_shifting=False, shift=1.0)
|
| 124 |
sigmas = np.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps)
|
|
@@ -131,10 +135,10 @@ class NucleusImagePipeline:
|
|
| 131 |
# Transformer receives sigma (0-1), Timesteps(scale=1000) handles the rest
|
| 132 |
t_normalized = mx.array([t.item() / 1000.0])
|
| 133 |
|
| 134 |
-
pred = self.dit(tokens, t_normalized, text_bth)
|
| 135 |
|
| 136 |
if do_cfg:
|
| 137 |
-
neg_pred = self.dit(tokens, t_normalized, neg_text_embeddings)
|
| 138 |
# CFG with norm rescaling
|
| 139 |
comb = neg_pred + guidance_scale * (pred - neg_pred)
|
| 140 |
cond_norm = mx.sqrt(mx.sum(pred * pred, axis=-1, keepdims=True) + 1e-8)
|
|
|
|
| 119 |
latents = mx.random.normal((1, latent_h, latent_w, 16))
|
| 120 |
tokens = patchify(latents, patch_size=2)
|
| 121 |
|
| 122 |
+
# Grid dimensions for RoPE (patch_size=2)
|
| 123 |
+
grid_h = latent_h // 2
|
| 124 |
+
grid_w = latent_w // 2
|
| 125 |
+
|
| 126 |
# Sigma schedule: raw linspace, no shift
|
| 127 |
# (scheduler config: use_dynamic_shifting=False, shift=1.0)
|
| 128 |
sigmas = np.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps)
|
|
|
|
| 135 |
# Transformer receives sigma (0-1), Timesteps(scale=1000) handles the rest
|
| 136 |
t_normalized = mx.array([t.item() / 1000.0])
|
| 137 |
|
| 138 |
+
pred = self.dit(tokens, t_normalized, text_bth, grid_h=grid_h, grid_w=grid_w)
|
| 139 |
|
| 140 |
if do_cfg:
|
| 141 |
+
neg_pred = self.dit(tokens, t_normalized, neg_text_embeddings, grid_h=grid_h, grid_w=grid_w)
|
| 142 |
# CFG with norm rescaling
|
| 143 |
comb = neg_pred + guidance_scale * (pred - neg_pred)
|
| 144 |
cond_norm = mx.sqrt(mx.sum(pred * pred, axis=-1, keepdims=True) + 1e-8)
|