treadon commited on
Commit
45e2c49
·
verified ·
1 Parent(s): be2ece9

Upload nucleus_image/pipeline.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. nucleus_image/pipeline.py +6 -2
nucleus_image/pipeline.py CHANGED
@@ -119,6 +119,10 @@ class NucleusImagePipeline:
119
  latents = mx.random.normal((1, latent_h, latent_w, 16))
120
  tokens = patchify(latents, patch_size=2)
121
 
 
 
 
 
122
  # Sigma schedule: raw linspace, no shift
123
  # (scheduler config: use_dynamic_shifting=False, shift=1.0)
124
  sigmas = np.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps)
@@ -131,10 +135,10 @@ class NucleusImagePipeline:
131
  # Transformer receives sigma (0-1), Timesteps(scale=1000) handles the rest
132
  t_normalized = mx.array([t.item() / 1000.0])
133
 
134
- pred = self.dit(tokens, t_normalized, text_bth)
135
 
136
  if do_cfg:
137
- neg_pred = self.dit(tokens, t_normalized, neg_text_embeddings)
138
  # CFG with norm rescaling
139
  comb = neg_pred + guidance_scale * (pred - neg_pred)
140
  cond_norm = mx.sqrt(mx.sum(pred * pred, axis=-1, keepdims=True) + 1e-8)
 
119
  latents = mx.random.normal((1, latent_h, latent_w, 16))
120
  tokens = patchify(latents, patch_size=2)
121
 
122
+ # Grid dimensions for RoPE (patch_size=2)
123
+ grid_h = latent_h // 2
124
+ grid_w = latent_w // 2
125
+
126
  # Sigma schedule: raw linspace, no shift
127
  # (scheduler config: use_dynamic_shifting=False, shift=1.0)
128
  sigmas = np.linspace(1.0, 1.0 / num_inference_steps, num_inference_steps)
 
135
  # Transformer receives sigma (0-1), Timesteps(scale=1000) handles the rest
136
  t_normalized = mx.array([t.item() / 1000.0])
137
 
138
+ pred = self.dit(tokens, t_normalized, text_bth, grid_h=grid_h, grid_w=grid_w)
139
 
140
  if do_cfg:
141
+ neg_pred = self.dit(tokens, t_normalized, neg_text_embeddings, grid_h=grid_h, grid_w=grid_w)
142
  # CFG with norm rescaling
143
  comb = neg_pred + guidance_scale * (pred - neg_pred)
144
  cond_norm = mx.sqrt(mx.sum(pred * pred, axis=-1, keepdims=True) + 1e-8)