Instructions to use treadon/mlx-nucleus-image with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MLX
How to use treadon/mlx-nucleus-image with MLX:
# Download the model from the Hub
pip install huggingface_hub[hf_xet]
huggingface-cli download --local-dir mlx-nucleus-image treadon/mlx-nucleus-image
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- LM Studio
| """Nucleus-Image MoE DiT in MLX. | |
| All dimensions verified against actual safetensors weight shapes. | |
| """ | |
| import math | |
| import mlx.core as mx | |
| import mlx.nn as nn | |
| # ── RoPE ── | |
def build_rope_freqs(axes_dim, theta, max_len=4096):
    """Build per-axis RoPE cos/sin tables for non-negative and negative positions.

    Mirrors NucleusMoEEmbedRope: the positive tables cover positions
    0 .. max_len-1 and the negative tables cover -max_len .. -1, in ascending order.

    Args:
        axes_dim: Rotary dimension of each axis (one table set per entry).
        theta: RoPE base frequency.
        max_len: Number of rows in every positive / negative table.

    Returns:
        (pos_cos, pos_sin, neg_cos, neg_sin): four lists with one entry per
        axis, each entry of shape [max_len, dim // 2].
    """
    # The position grids are identical for every axis, so build them once.
    positions = mx.arange(max_len).astype(mx.float32)  # 0, 1, ..., max_len-1
    negatives = positions - max_len                     # -max_len, ..., -1

    pos_cos, pos_sin, neg_cos, neg_sin = [], [], [], []
    for axis_dim in axes_dim:
        exponent = mx.arange(0, axis_dim, 2).astype(mx.float32) / axis_dim
        inv_freq = 1.0 / (theta ** exponent)

        pos_phase = mx.outer(positions, inv_freq)
        pos_cos.append(mx.cos(pos_phase))
        pos_sin.append(mx.sin(pos_phase))

        neg_phase = mx.outer(negatives, inv_freq)
        neg_cos.append(mx.cos(neg_phase))
        neg_sin.append(mx.sin(neg_phase))
    return pos_cos, pos_sin, neg_cos, neg_sin
def _rope_axis_rows(pos_table, neg_table, size, scale_rope):
    """Select the RoPE table rows for ``size`` consecutive positions on one axis.

    With ``scale_rope`` the positions are centred: the last ceil(size/2) rows
    of ``neg_table`` (positions -ceil(size/2) .. -1) followed by the first
    floor(size/2) rows of ``pos_table`` (positions 0 .. floor(size/2)-1).
    Otherwise the first ``size`` rows of ``pos_table`` (positions 0 .. size-1).

    Raises:
        ValueError: if ``size`` is not positive or needs more rows than the
            tables hold (slicing would otherwise silently return too few rows).
    """
    if size <= 0:
        raise ValueError(f"RoPE axis size must be positive, got {size}")
    if not scale_rope:
        if size > pos_table.shape[0]:
            raise ValueError(
                f"RoPE axis size {size} exceeds table length {pos_table.shape[0]}"
            )
        return pos_table[:size]
    n_neg = size - size // 2  # ceil(size / 2) negative positions
    n_pos = size // 2         # floor(size / 2) non-negative positions
    if n_neg > neg_table.shape[0] or n_pos > pos_table.shape[0]:
        raise ValueError(
            f"RoPE axis size {size} needs more centred positions than the tables hold"
        )
    # n_neg >= 1 here, so neg_table[-n_neg:] never degenerates to [-0:] (the whole table).
    return mx.concatenate([neg_table[-n_neg:], pos_table[:n_pos]], axis=0)


def compute_image_rope(height, width, axes_dim, pos_cos, pos_sin, neg_cos, neg_sin, scale_rope=True):
    """Compute RoPE cos/sin for a single-frame grid of image patches.

    Axis 0 (frame) always uses position 0. With scale_rope=True (default for
    Nucleus) height and width use centred positions:
        height: [-ceil(H/2), ..., -1, 0, 1, ..., floor(H/2)-1]
        width:  [-ceil(W/2), ..., -1, 0, 1, ..., floor(W/2)-1]
    otherwise they use 0 .. H-1 and 0 .. W-1.

    Args:
        height, width: Patch-grid size (must be positive).
        axes_dim: Per-axis rotary dims [frame, height, width].
        pos_cos, pos_sin, neg_cos, neg_sin: Tables from build_rope_freqs.
        scale_rope: Use centred height/width positions.

    Returns:
        (img_cos, img_sin), each of shape [H*W, sum(axes_dim) // 2], rows in
        row-major (height, width) order.

    Raises:
        ValueError: if the grid is empty or exceeds the precomputed tables.
    """
    frame = 1
    grid = (frame, height, width)

    # Select (and validate) the per-axis rows first.
    h_cos = _rope_axis_rows(pos_cos[1], neg_cos[1], height, scale_rope)
    h_sin = _rope_axis_rows(pos_sin[1], neg_sin[1], height, scale_rope)
    w_cos = _rope_axis_rows(pos_cos[2], neg_cos[2], width, scale_rope)
    w_sin = _rope_axis_rows(pos_sin[2], neg_sin[2], width, scale_rope)

    # Frame axis: every patch sits at frame position 0.
    f_cos = mx.broadcast_to(pos_cos[0][:frame].reshape(frame, 1, 1, -1), grid + (axes_dim[0] // 2,))
    f_sin = mx.broadcast_to(pos_sin[0][:frame].reshape(frame, 1, 1, -1), grid + (axes_dim[0] // 2,))
    # Height varies along grid axis 1, width along grid axis 2.
    h_cos = mx.broadcast_to(h_cos.reshape(1, height, 1, -1), grid + (axes_dim[1] // 2,))
    h_sin = mx.broadcast_to(h_sin.reshape(1, height, 1, -1), grid + (axes_dim[1] // 2,))
    w_cos = mx.broadcast_to(w_cos.reshape(1, 1, width, -1), grid + (axes_dim[2] // 2,))
    w_sin = mx.broadcast_to(w_sin.reshape(1, 1, width, -1), grid + (axes_dim[2] // 2,))

    img_cos = mx.concatenate([f_cos, h_cos, w_cos], axis=-1).reshape(frame * height * width, -1)
    img_sin = mx.concatenate([f_sin, h_sin, w_sin], axis=-1).reshape(frame * height * width, -1)
    return img_cos, img_sin
def compute_text_rope(max_txt_len, max_vid_index, axes_dim, pos_cos, pos_sin):
    """Compute RoPE cos/sin for text tokens.

    Text token j uses position ``max_vid_index + j`` on every axis (positive
    tables only), so text positions start right after the image positions.
    NucleusMoEDiT passes max(grid_h // 2, grid_w // 2) when scale_rope is on.

    Returns:
        (txt_cos, txt_sin), each of shape [max_txt_len, sum(axes_dim) // 2].

    Raises:
        ValueError: if the requested positions fall outside the tables. Plain
            slicing would silently return fewer than max_txt_len rows and fail
            later with an opaque shape error.
    """
    n_axes = len(axes_dim)
    start = max_vid_index
    stop = start + max_txt_len
    table_len = min(pos_cos[i].shape[0] for i in range(n_axes))
    if start < 0 or stop > table_len:
        raise ValueError(
            f"Text RoPE positions [{start}, {stop}) exceed the table length {table_len}"
        )
    txt_cos = mx.concatenate([pos_cos[i][start:stop] for i in range(n_axes)], axis=-1)
    txt_sin = mx.concatenate([pos_sin[i][start:stop] for i in range(n_axes)], axis=-1)
    return txt_cos, txt_sin
def apply_rotary_emb_complex(x, freqs_cos, freqs_sin):
    """Apply rotary embeddings as a complex multiplication on channel pairs.

    x: [B, S, H, D] with D even. Consecutive channels (2k, 2k+1) are treated
       as the real and imaginary parts of one complex number.
    freqs_cos, freqs_sin: [S, D//2], real and imaginary parts of the complex
       exponentials e^{i*angle}.

    Uses (a+bi)(c+di) = (ac-bd) + (ad+bc)i. The product is computed in
    float32 and cast back to x.dtype. Low-precision activations (e.g. bfloat16)
    therefore don't round cos/sin before the rotation. Float32 inputs give
    exactly the previous result.
    """
    # [B, S, H, D] -> [B, S, H, D//2, 2] pairs, promoted to float32.
    pairs = x.astype(mx.float32).reshape(*x.shape[:-1], -1, 2)
    x_real = pairs[..., 0]  # [B, S, H, D//2]
    x_imag = pairs[..., 1]
    # Broadcast freqs to [1, S, 1, D//2].
    cos_ = freqs_cos.astype(mx.float32)[None, :, None, :]
    sin_ = freqs_sin.astype(mx.float32)[None, :, None, :]
    out_real = x_real * cos_ - x_imag * sin_
    out_imag = x_real * sin_ + x_imag * cos_
    # Interleave back: [B, S, H, D//2, 2] -> [B, S, H, D], original dtype.
    out = mx.stack([out_real, out_imag], axis=-1)
    return out.reshape(x.shape).astype(x.dtype)
def timestep_embedding(t, dim: int, max_period: float = 10000):
    """Sinusoidal timestep embedding laid out as [cos | sin].

    Args:
        t: Timesteps indexed as t[:, None], i.e. a 1-D array of shape [N].
            Any scaling (e.g. *1000) must already be applied.
        dim: Output width.
        max_period: Base of the geometric frequency progression
            freq_k = max_period ** (-k / (dim // 2)), k = 0 .. dim//2 - 1.

    Returns:
        float32 array of shape [N, dim]. For odd ``dim`` a zero column is
        appended so the width is exactly ``dim``. Previously it silently
        came out as dim - 1.
    """
    half = dim // 2
    freqs = mx.exp(-math.log(max_period) * mx.arange(0, half).astype(mx.float32) / half)
    args = t.astype(mx.float32)[:, None] * freqs[None, :]
    emb = mx.concatenate([mx.cos(args), mx.sin(args)], axis=-1)
    if dim % 2:
        emb = mx.pad(emb, [(0, 0), (0, 1)])
    return emb
| # ── TimestepEmbedder: linear_1 [8192, 2048], linear_2 [2048, 8192] ── | |
class TimestepEmbedder(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) over a timestep embedding.

    Maps in_dim -> expand_dim -> in_dim; both linear layers carry a bias.
    """

    def __init__(self, in_dim: int, expand_dim: int):
        super().__init__()
        self.linear_1 = nn.Linear(in_dim, expand_dim)
        self.linear_2 = nn.Linear(expand_dim, in_dim)

    def __call__(self, t):
        hidden = nn.silu(self.linear_1(t))
        return self.linear_2(hidden)
| # ── Gated projection (shared between dense FFN and shared expert) ── | |
class GatedProj(nn.Module):
    """Bias-free SwiGLU input projection (checkpoint key net.0.proj: [hidden*2, in_dim]).

    One linear layer yields 2 * hidden_dim features. The first half is the
    value and the second half the gate, combined as value * silu(gate).
    """

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.proj = nn.Linear(in_dim, hidden_dim * 2, bias=False)

    def __call__(self, x):
        value, gate = mx.split(self.proj(x), 2, axis=-1)
        activated_gate = nn.silu(gate)
        return value * activated_gate
| # ── Dense FFN (layers 0-2) ── | |
| # net.0.proj: [10752, 2048] → gated, hidden=5376 | |
| # net.2: [2048, 5376] | |
class DenseFFN(nn.Module):
    """Dense SwiGLU feed-forward for the non-MoE layers.

    net[0] is the gated input projection (net.0.proj). net[1] is a None
    placeholder so that the bias-free output projection lands at index 2,
    matching the checkpoint key net.2.
    """

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        gated_in = GatedProj(in_dim, hidden_dim)
        out_proj = nn.Linear(hidden_dim, in_dim, bias=False)
        self.net = [gated_in, None, out_proj]

    def __call__(self, x):
        gated_in, _, out_proj = self.net
        return out_proj(gated_in(x))
| # ── MoE FFN (layers 3-31) ── | |
class SharedExpert(nn.Module):
    """SwiGLU expert applied to every token of an MoE layer.

    Checkpoint shapes: shared_expert.net.0.proj [2688, 2048], net.2 [2048, 1344].
    net[1] is a None placeholder so the down projection sits at index 2.
    """

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.net = [
            GatedProj(in_dim, hidden_dim),
            None,
            nn.Linear(hidden_dim, in_dim, bias=False),
        ]

    def __call__(self, x):
        hidden = self.net[0](x)
        return self.net[2](hidden)
class Experts(nn.Module):
    """Holder for the packed (stacked) routed-expert weights.

    Shapes follow the safetensors layout:
        gate_up_proj: [num_experts, in_dim, hidden*2]
        down_proj:    [num_experts, hidden, in_dim]
    Both start as zeros and are presumably overwritten when the checkpoint is
    loaded. The expert computation itself is done by MoEFFN.
    """

    def __init__(self, in_dim: int, hidden_dim: int, num_experts: int):
        super().__init__()
        gate_up_shape = (num_experts, in_dim, 2 * hidden_dim)
        down_shape = (num_experts, hidden_dim, in_dim)
        self.gate_up_proj = mx.zeros(gate_up_shape)
        self.down_proj = mx.zeros(down_shape)
class MoEFFN(nn.Module):
    """Mixture-of-experts FFN: expert-choice routed experts plus a shared expert.

    Parameters:
        gate: Linear(2 * in_dim -> num_experts), no bias ([64, 4096] in the
            checkpoint). Its input is concatenate([timestep_emb, token]) on the
            feature axis, timestep FIRST as built in __call__.
            NOTE(review): earlier docs said concat(token, timestep). Confirm
            the order against the gate weights.
        experts: packed routed-expert weights (see Experts).
        shared_expert: SwiGLU FFN applied to every token.

    Routing is expert-choice (intended to match the diffusers NucleusMoELayer).
    Each expert picks its top-C tokens by router probability, with
    C = min(S, max(1, ceil(capacity_factor * S / num_experts))). A token picked
    by several experts gets weights equal to its selected probabilities
    renormalised to sum to route_scale. Routing is independent per batch element.
    """

    def __init__(self, in_dim: int, expert_hidden_dim: int, num_experts: int = 64,
                 route_scale: float = 2.5, capacity_factor: float = 2.0):
        super().__init__()
        self.num_experts = num_experts
        self.route_scale = route_scale
        self.capacity_factor = capacity_factor
        # Router input is 2*in_dim (timestep embedding + token concat).
        self.gate = nn.Linear(in_dim * 2, num_experts, bias=False)
        self.experts = Experts(in_dim, expert_hidden_dim, num_experts)
        self.shared_expert = SharedExpert(in_dim, expert_hidden_dim)

    def _route(self, tokens, affinity, capacity):
        """Expert-choice routing for one sample.

        tokens: [S, D]; affinity: [E, S] router probabilities.
        Returns the weighted routed-expert output, [S, D].
        """
        S, D = tokens.shape
        # Each expert takes its top-`capacity` tokens (descending score).
        top_idx = mx.argsort(-affinity, axis=-1)[:, :capacity]        # [E, C]
        top_scores = mx.take_along_axis(affinity, top_idx, axis=-1)    # [E, C]
        num_experts, cap = top_idx.shape
        flat_idx = top_idx.reshape(-1)                                 # [E*C]
        flat_scores = top_scores.reshape(-1)                           # [E*C]

        # Per-token sum of selected scores. Scatter-add accumulates duplicate
        # indices, which replaces the former [E*C, S] one-hot matmul (O(S^2)).
        token_sums = mx.zeros((S,), dtype=flat_scores.dtype).at[flat_idx].add(flat_scores)
        flat_norm = flat_scores / (token_sums[flat_idx] + 1e-12) * self.route_scale

        # All experts at once: batched matmul over the leading expert axis.
        gathered = tokens[flat_idx].reshape(num_experts, cap, D)       # [E, C, D]
        h = gathered @ self.experts.gate_up_proj                       # [E, C, hidden*2]
        gate_h, up_h = mx.split(h, 2, axis=-1)  # SwiGLUExperts: first=gate, second=up
        h = (nn.silu(gate_h) * up_h) @ self.experts.down_proj          # [E, C, D]
        weighted = h.reshape(-1, D) * flat_norm[:, None]               # [E*C, D]

        # Scatter each weighted expert output back onto its source token.
        return mx.zeros((S, D), dtype=weighted.dtype).at[flat_idx].add(weighted)

    def __call__(self, x, timestep_emb=None, unmodulated_x=None):
        """Apply the MoE FFN.

        Args:
            x: [B, S, D] modulated tokens processed by the experts.
            timestep_emb: [B, D] conditioning vector concatenated to the router
                input. The gate expects 2*D features, so passing None only
                works if the gate was built with width D.
            unmodulated_x: [B, S, D] tokens the router scores (decoupled
                routing); defaults to x.

        Returns:
            [B, S, D]: shared-expert output plus weighted routed-expert output.
        """
        B, S, D = x.shape
        # Decoupled routing: router sees unmodulated tokens + timestep.
        if unmodulated_x is None:
            unmodulated_x = x
        if timestep_emb is not None:
            t_expanded = mx.broadcast_to(timestep_emb[:, None, :], (B, S, D))
            router_input = mx.concatenate([t_expanded, unmodulated_x], axis=-1)
        else:
            router_input = unmodulated_x

        # Softmax in float32, then back to the activation dtype.
        logits = self.gate(router_input)  # [B, S, E]
        scores = mx.softmax(logits.astype(mx.float32), axis=-1).astype(x.dtype)
        affinity = mx.transpose(scores, (0, 2, 1))  # [B, E, S]

        # Clamp to S: asking for more tokens than exist would misalign the
        # per-expert blocks of the flattened selection.
        capacity = max(1, math.ceil(self.capacity_factor * S / self.num_experts))
        capacity = min(capacity, S)

        # Route every sample. The former code used only x[0] and returned
        # [1, S, D], which silently broadcast sample 0's output when B > 1.
        routed = mx.stack(
            [self._route(x[b], affinity[b], capacity) for b in range(B)], axis=0
        )
        # Shared expert (always runs on all tokens).
        return self.shared_expert(x) + routed
| # ── Joint Attention ── | |
class JointAttention(nn.Module):
    """Attention in which image tokens attend over image + text tokens.

    Queries come only from the image stream. Keys/values are the image
    keys/values concatenated along the sequence with keys/values projected
    from the text stream (add_k_proj / add_v_proj). Per-head RMSNorm is applied
    to queries and keys, then RoPE. num_heads query heads share num_kv_heads
    key/value heads (grouped-query attention).
    """

    def __init__(self, hidden_dim: int, num_heads: int, num_kv_heads: int, head_dim: int):
        super().__init__()
        if num_heads % num_kv_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be a multiple of num_kv_heads ({num_kv_heads})"
            )
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.to_q = nn.Linear(hidden_dim, num_heads * head_dim, bias=False)
        self.to_k = nn.Linear(hidden_dim, num_kv_heads * head_dim, bias=False)
        self.to_v = nn.Linear(hidden_dim, num_kv_heads * head_dim, bias=False)
        self.to_out = [nn.Linear(num_heads * head_dim, hidden_dim, bias=False)]
        self.add_k_proj = nn.Linear(hidden_dim, num_kv_heads * head_dim, bias=False)
        self.add_v_proj = nn.Linear(hidden_dim, num_kv_heads * head_dim, bias=False)
        self.norm_q = nn.RMSNorm(head_dim)
        self.norm_k = nn.RMSNorm(head_dim)
        # norm_added_q is not used in __call__ (text tokens issue no queries);
        # presumably kept so the parameter tree matches the checkpoint keys.
        self.norm_added_q = nn.RMSNorm(head_dim)
        self.norm_added_k = nn.RMSNorm(head_dim)

    def __call__(self, img_x, txt_kv, img_rope=None, txt_rope=None):
        """Attend image queries over the joint [image; text] key/value sequence.

        Args:
            img_x: [B, S_img, hidden_dim] image tokens.
            txt_kv: [B, S_txt, hidden_dim] text tokens (already in hidden_dim).
            img_rope, txt_rope: optional (cos, sin) pairs, each [S, head_dim // 2].

        Returns:
            [B, S_img, hidden_dim].
        """
        B, S_img, _ = img_x.shape
        S_txt = txt_kv.shape[1]

        q = self.to_q(img_x).reshape(B, S_img, self.num_heads, self.head_dim)
        k = self.to_k(img_x).reshape(B, S_img, self.num_kv_heads, self.head_dim)
        v = self.to_v(img_x).reshape(B, S_img, self.num_kv_heads, self.head_dim)
        q = self.norm_q(q)
        k = self.norm_k(k)

        txt_k = self.add_k_proj(txt_kv).reshape(B, S_txt, self.num_kv_heads, self.head_dim)
        txt_v = self.add_v_proj(txt_kv).reshape(B, S_txt, self.num_kv_heads, self.head_dim)
        txt_k = self.norm_added_k(txt_k)

        # Apply RoPE
        if img_rope is not None:
            img_cos, img_sin = img_rope
            q = apply_rotary_emb_complex(q, img_cos, img_sin)
            k = apply_rotary_emb_complex(k, img_cos, img_sin)
        if txt_rope is not None:
            txt_cos, txt_sin = txt_rope
            txt_k = apply_rotary_emb_complex(txt_k, txt_cos, txt_sin)

        # Joint sequence: image tokens first, then text tokens.
        k = mx.concatenate([k, txt_k], axis=1)
        v = mx.concatenate([v, txt_v], axis=1)

        # [B, S, heads, Dh] -> [B, heads, S, Dh]. K/V keep num_kv_heads heads:
        # the fused kernel handles grouped-query attention natively (query head
        # h reads kv head h // (num_heads // num_kv_heads), the same grouping
        # mx.repeat along the head axis produced), so no tiled K/V copies are made.
        q = q.transpose(0, 2, 1, 3)
        k = k.transpose(0, 2, 1, 3)
        v = v.transpose(0, 2, 1, 3)
        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0 / math.sqrt(self.head_dim))
        out = out.transpose(0, 2, 1, 3).reshape(B, S_img, -1)
        return self.to_out[0](out)
| # ── Transformer Block ── | |
| # img_mod.1: [8192, 2048] → 4 modulations, split in order (scale_attn, gate_attn, scale_mlp, gate_mlp) | |
class TransformerBlock(nn.Module):
    """One DiT block: modulated joint attention, then a modulated FFN.

    The FFN is an MoEFFN when is_moe is truthy, otherwise a DenseFFN. Each
    block owns its own text projection (encoder_proj) from text_dim to
    hidden_dim, applied to txt_kv on every call.
    """

    def __init__(self, hidden_dim: int, num_heads: int, num_kv_heads: int,
                 head_dim: int, dense_hidden: int, is_moe: bool,
                 num_experts: int, expert_hidden: int, route_scale: float,
                 text_dim: int, capacity_factor: float = 2.0):
        super().__init__()
        self.is_moe = bool(is_moe)
        self.attn = JointAttention(hidden_dim, num_heads, num_kv_heads, head_dim)
        if self.is_moe:
            ffn = MoEFFN(hidden_dim, expert_hidden, num_experts, route_scale, capacity_factor)
        else:
            ffn = DenseFFN(hidden_dim, dense_hidden)
        self.img_mlp = ffn
        # Pre-norms: LayerNorm without affine parameters.
        self.pre_attn_norm = nn.LayerNorm(hidden_dim, affine=False)
        self.pre_mlp_norm = nn.LayerNorm(hidden_dim, affine=False)
        # Produces 4 modulation chunks: scale_attn, gate_attn, scale_mlp, gate_mlp.
        self.img_mod = [nn.SiLU(), nn.Linear(hidden_dim, hidden_dim * 4)]
        self.encoder_proj = nn.Linear(text_dim, hidden_dim)

    def __call__(self, img_x, txt_kv, c, img_rope=None, txt_rope=None):
        # Conditioning -> [B, 1, 4*hidden], split into four [B, 1, hidden] chunks.
        silu, mod_linear = self.img_mod
        mod = mod_linear(silu(c))[:, None, :]
        scale_attn, gate_attn, scale_mlp, gate_mlp = mx.split(mod, 4, axis=-1)
        # Gates are clipped to [-2, 2] and squashed with tanh.
        gate_attn = mx.tanh(mx.clip(gate_attn, -2.0, 2.0))
        gate_mlp = mx.tanh(mx.clip(gate_mlp, -2.0, 2.0))

        text_ctx = self.encoder_proj(txt_kv)

        # Attention branch: LayerNorm -> scale -> joint attention -> gated residual.
        attn_in = self.pre_attn_norm(img_x) * (1 + scale_attn)
        attn_out = self.attn(attn_in, text_ctx, img_rope=img_rope, txt_rope=txt_rope)
        img_x = img_x + gate_attn * attn_out

        # FFN branch: LayerNorm -> scale -> FFN -> gated residual.
        # The MoE router sees the unmodulated normalized tokens plus c.
        normed = self.pre_mlp_norm(img_x)
        ffn_in = normed * (1 + scale_mlp)
        if self.is_moe:
            ffn_out = self.img_mlp(ffn_in, timestep_emb=c, unmodulated_x=normed)
        else:
            ffn_out = self.img_mlp(ffn_in)
        return img_x + gate_mlp * ffn_out
| # ── Full Model ── | |
class NucleusMoEDiT(nn.Module):
    """Nucleus-Image MoE diffusion transformer.

    The forward pass runs:
      - patch embedding (img_in);
      - timestep conditioning;
      - the TransformerBlock stack (layers 0-2 dense, the rest MoE) with joint
        image/text attention and 3-axis RoPE;
      - AdaLayerNormContinuous-style output norm;
      - proj_out.

    Config keys read (defaults in parentheses): num_attention_heads (16),
    attention_head_dim (128), num_key_value_heads (4), num_layers (32),
    num_experts (64), moe_intermediate_dim (1344), route_scale (2.5),
    joint_attention_dim (4096), in_channels (64), out_channels (16),
    axes_dims_rope ([16, 56, 56]), capacity_factors (0.0 for dense layers,
    2.0 for MoE layers, one entry per layer).
    """

    def __init__(self, config: dict):
        super().__init__()
        num_heads = config.get("num_attention_heads", 16)
        head_dim = config.get("attention_head_dim", 128)
        hidden = num_heads * head_dim  # 2048 with the defaults
        num_kv_heads = config.get("num_key_value_heads", 4)
        num_layers = config.get("num_layers", 32)
        num_experts = config.get("num_experts", 64)
        expert_hidden = config.get("moe_intermediate_dim", 1344)
        route_scale = config.get("route_scale", 2.5)
        text_dim = config.get("joint_attention_dim", 4096)
        in_channels = config.get("in_channels", 64)
        out_channels = config.get("out_channels", 16)
        axes_dims = config.get("axes_dims_rope", [16, 56, 56])
        # Dense FFN hidden: from weight [10752, 2048] → 10752/2 = 5376
        dense_hidden = 5376

        self.hidden_dim = hidden
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.img_in = nn.Linear(in_channels, hidden)
        # Timestep: linear_1 [8192, 2048], linear_2 [2048, 8192]
        self.time_text_embed = {
            "timestep_embedder": TimestepEmbedder(hidden, 8192),
            "norm": nn.RMSNorm(hidden),
        }
        self.txt_norm = nn.RMSNorm(text_dim)
        self._axes_dim = axes_dims
        self._scale_rope = True  # Nucleus uses centered positions
        self._pos_cos, self._pos_sin, self._neg_cos, self._neg_sin = build_rope_freqs(axes_dims, 10000)

        dense_layers = {0, 1, 2}
        # Default is derived from num_layers so other depths don't index past
        # the list. For 32 layers it equals the former [0.0]*3 + [2.0]*29.
        capacity_factors = config.get(
            "capacity_factors",
            [0.0 if i in dense_layers else 2.0 for i in range(num_layers)],
        )
        self.transformer_blocks = [
            TransformerBlock(
                hidden, num_heads, num_kv_heads, head_dim, dense_hidden,
                is_moe=(i not in dense_layers),
                num_experts=num_experts, expert_hidden=expert_hidden,
                route_scale=route_scale, text_dim=text_dim,
                capacity_factor=capacity_factors[i],
            )
            for i in range(num_layers)
        ]
        # norm_out: AdaLayerNormContinuous(elementwise_affine=False, eps=1e-6).
        # LayerNorm is applied before the adaptive modulation.
        self.norm_out = {
            "norm": nn.LayerNorm(hidden, affine=False, eps=1e-6),
            "linear": nn.Linear(hidden, hidden * 2),
        }
        # proj_out: [64, 2048] → 64 = patch_size² * out_channels
        self.proj_out = nn.Linear(hidden, in_channels, bias=False)

    def __call__(self, hidden_states, timestep, txt_kv, grid_h=None, grid_w=None):
        """Run the transformer.

        Args:
            hidden_states: [B, N_img, in_channels] patchified latents.
            timestep: 1-D timesteps, multiplied by 1000 here before the
                sinusoidal embedding.
            txt_kv: [B, T_txt, joint_attention_dim] text-encoder states.
            grid_h, grid_w: patch-grid size. If either is None, a square grid
                is assumed.

        Returns:
            [B, N_img, in_channels] output of proj_out.

        Raises:
            ValueError: if grid_h * grid_w != N_img.
        """
        x = self.img_in(hidden_states)

        # Timesteps(scale=1000): the incoming timestep is multiplied by 1000
        # before the sinusoidal embedding.
        # NOTE(review): this implies callers pass sigma-scale values (e.g.
        # scheduler timestep / 1000). Passing sigma*1000 would scale twice.
        # Confirm against the calling pipeline.
        t_emb = timestep_embedding(timestep * 1000, self.hidden_dim).astype(x.dtype)
        c = self.time_text_embed["timestep_embedder"](t_emb)
        c = self.time_text_embed["norm"](c)
        txt_kv = self.txt_norm(txt_kv)

        # Build RoPE: image patches are on a grid, text follows after.
        n_img = hidden_states.shape[1]
        if grid_h is None or grid_w is None:
            # Fallback: assume square (works only for square images).
            grid_h = math.isqrt(n_img)
            grid_w = n_img // grid_h
        if grid_h * grid_w != n_img:
            raise ValueError(f"Grid {grid_h}x{grid_w} != N_img {n_img}")
        img_cos, img_sin = compute_image_rope(
            grid_h, grid_w, self._axes_dim,
            self._pos_cos, self._pos_sin, self._neg_cos, self._neg_sin,
            scale_rope=self._scale_rope,
        )
        T_txt = txt_kv.shape[1]
        # With scale_rope, text starts at max(H//2, W//2).
        max_vid_idx = max(grid_h // 2, grid_w // 2) if self._scale_rope else max(grid_h, grid_w)
        txt_cos, txt_sin = compute_text_rope(T_txt, max_vid_idx, self._axes_dim, self._pos_cos, self._pos_sin)
        img_rope = (img_cos, img_sin)
        txt_rope = (txt_cos, txt_sin)

        for block in self.transformer_blocks:
            x = block(x, txt_kv, c, img_rope=img_rope, txt_rope=txt_rope)

        # AdaLayerNormContinuous: norm first, then adaptive modulation.
        mod = self.norm_out["linear"](nn.silu(c))
        scale, shift = mx.split(mod, 2, axis=-1)  # scale first, shift second
        x = self.norm_out["norm"](x) * (1 + scale[:, None, :]) + shift[:, None, :]
        x = self.proj_out(x)
        return x