Instructions to use treadon/mlx-nucleus-image with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MLX
How to use treadon/mlx-nucleus-image with MLX:
# Download the model from the Hub
pip install huggingface_hub[hf_xet]
huggingface-cli download --local-dir mlx-nucleus-image treadon/mlx-nucleus-image
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- LM Studio
Upload nucleus_image/dit.py with huggingface_hub
Browse files
- nucleus_image/dit.py +6 -3
nucleus_image/dit.py
CHANGED
|
@@ -466,7 +466,7 @@ class NucleusMoEDiT(nn.Module):
|
|
| 466 |
# proj_out: [64, 2048] → 64 = patch_size² * out_channels
|
| 467 |
self.proj_out = nn.Linear(hidden, in_channels, bias=False)
|
| 468 |
|
| 469 |
-
def __call__(self, hidden_states, timestep, txt_kv):
|
| 470 |
B = hidden_states.shape[0]
|
| 471 |
|
| 472 |
x = self.img_in(hidden_states)
|
|
@@ -481,8 +481,11 @@ class NucleusMoEDiT(nn.Module):
|
|
| 481 |
|
| 482 |
# Build RoPE: image patches are on a grid, text follows after
|
| 483 |
N_img = hidden_states.shape[1]
|
| 484 |
-
grid_h
|
| 485 |
-
|
|
|
|
|
|
|
|
|
|
| 486 |
img_cos, img_sin = compute_image_rope(
|
| 487 |
grid_h, grid_w, self._axes_dim,
|
| 488 |
self._pos_cos, self._pos_sin, self._neg_cos, self._neg_sin,
|
|
|
|
| 466 |
# proj_out: [64, 2048] → 64 = patch_size² * out_channels
|
| 467 |
self.proj_out = nn.Linear(hidden, in_channels, bias=False)
|
| 468 |
|
| 469 |
+
def __call__(self, hidden_states, timestep, txt_kv, grid_h=None, grid_w=None):
|
| 470 |
B = hidden_states.shape[0]
|
| 471 |
|
| 472 |
x = self.img_in(hidden_states)
|
|
|
|
| 481 |
|
| 482 |
# Build RoPE: image patches are on a grid, text follows after
|
| 483 |
N_img = hidden_states.shape[1]
|
| 484 |
+
if grid_h is None or grid_w is None:
|
| 485 |
+
# Fallback: assume square (works only for square images)
|
| 486 |
+
grid_h = int(N_img ** 0.5)
|
| 487 |
+
grid_w = N_img // grid_h
|
| 488 |
+
assert grid_h * grid_w == N_img, f"Grid {grid_h}x{grid_w} != N_img {N_img}"
|
| 489 |
img_cos, img_sin = compute_image_rope(
|
| 490 |
grid_h, grid_w, self._axes_dim,
|
| 491 |
self._pos_cos, self._pos_sin, self._neg_cos, self._neg_sin,
|