treadon commited on
Commit
be2ece9
·
verified ·
1 Parent(s): 9b67aa8

Upload nucleus_image/dit.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. nucleus_image/dit.py +6 -3
nucleus_image/dit.py CHANGED
@@ -466,7 +466,7 @@ class NucleusMoEDiT(nn.Module):
466
  # proj_out: [64, 2048] → 64 = patch_size² * out_channels
467
  self.proj_out = nn.Linear(hidden, in_channels, bias=False)
468
 
469
- def __call__(self, hidden_states, timestep, txt_kv):
470
  B = hidden_states.shape[0]
471
 
472
  x = self.img_in(hidden_states)
@@ -481,8 +481,11 @@ class NucleusMoEDiT(nn.Module):
481
 
482
  # Build RoPE: image patches are on a grid, text follows after
483
  N_img = hidden_states.shape[1]
484
- grid_h = int(N_img ** 0.5)
485
- grid_w = N_img // grid_h
 
 
 
486
  img_cos, img_sin = compute_image_rope(
487
  grid_h, grid_w, self._axes_dim,
488
  self._pos_cos, self._pos_sin, self._neg_cos, self._neg_sin,
 
466
  # proj_out: [64, 2048] → 64 = patch_size² * out_channels
467
  self.proj_out = nn.Linear(hidden, in_channels, bias=False)
468
 
469
+ def __call__(self, hidden_states, timestep, txt_kv, grid_h=None, grid_w=None):
470
  B = hidden_states.shape[0]
471
 
472
  x = self.img_in(hidden_states)
 
481
 
482
  # Build RoPE: image patches are on a grid, text follows after
483
  N_img = hidden_states.shape[1]
484
+ if grid_h is None or grid_w is None:
485
+ # Fallback: assume square (works only for square images)
486
+ grid_h = int(N_img ** 0.5)
487
+ grid_w = N_img // grid_h
488
+ assert grid_h * grid_w == N_img, f"Grid {grid_h}x{grid_w} != N_img {N_img}"
489
  img_cos, img_sin = compute_image_rope(
490
  grid_h, grid_w, self._axes_dim,
491
  self._pos_cos, self._pos_sin, self._neg_cos, self._neg_sin,