Jimmy commited on
Commit
b533acb
·
1 Parent(s): 92f2b6e

Add skip layer guidance

Browse files
Files changed (3) hide show
  1. i2v_inference.py +13 -0
  2. wan/image2video.py +38 -18
  3. wan/modules/model.py +5 -2
i2v_inference.py CHANGED
@@ -413,6 +413,9 @@ def parse_args():
413
  parser.add_argument("--teacache", type=float, default=0.25, help="TeaCache multiplier, e.g. 0.5, 2.0, etc.")
414
  parser.add_argument("--teacache-start", type=float, default=0.1, help="Teacache start step percentage [0..100]")
415
  parser.add_argument("--seed", type=int, default=-1, help="Random seed. -1 means random each time.")
 
 
 
416
 
417
  # LoRA usage
418
  parser.add_argument("--loras-choices", type=str, default="", help="Comma-separated list of chosen LoRA indices or preset names to load. Usually you only use the preset.")
@@ -540,6 +543,12 @@ def main():
540
  except:
541
  raise ValueError(f"Invalid resolution: '{resolution_str}'")
542
 
 
 
 
 
 
 
543
  # Additional checks (from your original code).
544
  if "480p" in args.transformer_file:
545
  # Then we cannot exceed certain area for 480p model
@@ -628,6 +637,10 @@ def main():
628
  callback=None, # or define your own callback if you want
629
  enable_RIFLEx=enable_riflex,
630
  VAE_tile_size=VAE_tile_size,
 
 
 
 
631
  )
632
  except Exception as e:
633
  offloadobj.unload_all()
 
413
  parser.add_argument("--teacache", type=float, default=0.25, help="TeaCache multiplier, e.g. 0.5, 2.0, etc.")
414
  parser.add_argument("--teacache-start", type=float, default=0.1, help="Teacache start step percentage [0..100]")
415
  parser.add_argument("--seed", type=int, default=-1, help="Random seed. -1 means random each time.")
416
+ parser.add_argument("--slg-layers", type=str, default=None, help="Which layers to use for skip layer guidance")
417
+ parser.add_argument("--slg-start", type=float, default=0.0, help="Percentage in to start SLG")
418
+ parser.add_argument("--slg-end", type=float, default=1.0, help="Percentage in to end SLG")
419
 
420
  # LoRA usage
421
  parser.add_argument("--loras-choices", type=str, default="", help="Comma-separated list of chosen LoRA indices or preset names to load. Usually you only use the preset.")
 
543
  except:
544
  raise ValueError(f"Invalid resolution: '{resolution_str}'")
545
 
546
+ # Parse slg_layers from comma-separated string to a Python list of ints (or None if not provided)
547
+ if args.slg_layers:
548
+ slg_list = [int(x) for x in args.slg_layers.split(",")]
549
+ else:
550
+ slg_list = None
551
+
552
  # Additional checks (from your original code).
553
  if "480p" in args.transformer_file:
554
  # Then we cannot exceed certain area for 480p model
 
637
  callback=None, # or define your own callback if you want
638
  enable_RIFLEx=enable_riflex,
639
  VAE_tile_size=VAE_tile_size,
640
+ joint_pass=slg_list is None, # set if you want a small speed improvement without SLG
641
+ slg_layers=slg_list,
642
+ slg_start=args.slg_start,
643
+ slg_end=args.slg_end,
644
  )
645
  except Exception as e:
646
  offloadobj.unload_all()
wan/image2video.py CHANGED
@@ -132,22 +132,25 @@ class WanI2V:
132
  self.sample_neg_prompt = config.sample_neg_prompt
133
 
134
  def generate(self,
135
- input_prompt,
136
- img,
137
- max_area=720 * 1280,
138
- frame_num=81,
139
- shift=5.0,
140
- sample_solver='unipc',
141
- sampling_steps=40,
142
- guide_scale=5.0,
143
- n_prompt="",
144
- seed=-1,
145
- offload_model=True,
146
- callback = None,
147
- enable_RIFLEx = False,
148
- VAE_tile_size= 0,
149
- joint_pass = False,
150
- ):
 
 
 
151
  r"""
152
  Generates video frames from input image and text prompt using diffusion process.
153
 
@@ -331,25 +334,42 @@ class WanI2V:
331
  callback(-1, None)
332
 
333
  for i, t in enumerate(tqdm(timesteps)):
 
 
 
 
334
  offload.set_step_no_for_lora(i)
335
  latent_model_input = [latent.to(self.device)]
336
  timestep = [t]
337
 
338
  timestep = torch.stack(timestep).to(self.device)
339
  if joint_pass:
 
 
340
  noise_pred_cond, noise_pred_uncond = self.model(
341
  latent_model_input, t=timestep, current_step=i, **arg_both)
342
  if self._interrupt:
343
  return None
344
  else:
345
  noise_pred_cond = self.model(
346
- latent_model_input, t=timestep, current_step=i, is_uncond = False, **arg_c)[0]
 
 
 
 
 
347
  if self._interrupt:
348
  return None
349
  if offload_model:
350
  torch.cuda.empty_cache()
351
  noise_pred_uncond = self.model(
352
- latent_model_input, t=timestep, current_step=i, is_uncond = True, **arg_null)[0]
 
 
 
 
 
 
353
  if self._interrupt:
354
  return None
355
  del latent_model_input
 
132
  self.sample_neg_prompt = config.sample_neg_prompt
133
 
134
  def generate(self,
135
+ input_prompt,
136
+ img,
137
+ max_area=720 * 1280,
138
+ frame_num=81,
139
+ shift=5.0,
140
+ sample_solver='unipc',
141
+ sampling_steps=40,
142
+ guide_scale=5.0,
143
+ n_prompt="",
144
+ seed=-1,
145
+ offload_model=True,
146
+ callback = None,
147
+ enable_RIFLEx = False,
148
+ VAE_tile_size= 0,
149
+ joint_pass = False,
150
+ slg_layers = None,
151
+ slg_start = 0.0,
152
+ slg_end = 1.0,
153
+ ):
154
  r"""
155
  Generates video frames from input image and text prompt using diffusion process.
156
 
 
334
  callback(-1, None)
335
 
336
  for i, t in enumerate(tqdm(timesteps)):
337
+ slg_layers_local = None
338
+ if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps):
339
+ slg_layers_local = slg_layers
340
+
341
  offload.set_step_no_for_lora(i)
342
  latent_model_input = [latent.to(self.device)]
343
  timestep = [t]
344
 
345
  timestep = torch.stack(timestep).to(self.device)
346
  if joint_pass:
347
+ if slg_layers is not None:
348
+ raise ValueError('Can not use SLG and joint-pass')
349
  noise_pred_cond, noise_pred_uncond = self.model(
350
  latent_model_input, t=timestep, current_step=i, **arg_both)
351
  if self._interrupt:
352
  return None
353
  else:
354
  noise_pred_cond = self.model(
355
+ latent_model_input,
356
+ t=timestep,
357
+ current_step=i,
358
+ is_uncond=False,
359
+ **arg_c,
360
+ )[0]
361
  if self._interrupt:
362
  return None
363
  if offload_model:
364
  torch.cuda.empty_cache()
365
  noise_pred_uncond = self.model(
366
+ latent_model_input,
367
+ t=timestep,
368
+ current_step=i,
369
+ is_uncond=True,
370
+ slg_layers=slg_layers_local,
371
+ **arg_null,
372
+ )[0]
373
  if self._interrupt:
374
  return None
375
  del latent_model_input
wan/modules/model.py CHANGED
@@ -716,7 +716,8 @@ class WanModel(ModelMixin, ConfigMixin):
716
  pipeline = None,
717
  current_step = 0,
718
  context2 = None,
719
- is_uncond=False
 
720
  ):
721
  r"""
722
  Forward pass through the diffusion model
@@ -843,7 +844,9 @@ class WanModel(ModelMixin, ConfigMixin):
843
  # context=context,
844
  context_lens=context_lens)
845
 
846
- for block in self.blocks:
 
 
847
  if pipeline._interrupt:
848
  if joint_pass:
849
  return None, None
 
716
  pipeline = None,
717
  current_step = 0,
718
  context2 = None,
719
+ is_uncond=False,
720
+ slg_layers=None,
721
  ):
722
  r"""
723
  Forward pass through the diffusion model
 
844
  # context=context,
845
  context_lens=context_lens)
846
 
847
+ for block_idx, block in enumerate(self.blocks):
848
+ if slg_layers is not None and block_idx in slg_layers and is_uncond:
849
+ continue
850
  if pipeline._interrupt:
851
  if joint_pass:
852
  return None, None