Jimmy commited on
Commit ·
b533acb
1
Parent(s): 92f2b6e
Add skip layer guidance
Browse files- i2v_inference.py +13 -0
- wan/image2video.py +38 -18
- wan/modules/model.py +5 -2
i2v_inference.py
CHANGED
|
@@ -413,6 +413,9 @@ def parse_args():
|
|
| 413 |
parser.add_argument("--teacache", type=float, default=0.25, help="TeaCache multiplier, e.g. 0.5, 2.0, etc.")
|
| 414 |
parser.add_argument("--teacache-start", type=float, default=0.1, help="Teacache start step percentage [0..100]")
|
| 415 |
parser.add_argument("--seed", type=int, default=-1, help="Random seed. -1 means random each time.")
|
|
|
|
|
|
|
|
|
|
| 416 |
|
| 417 |
# LoRA usage
|
| 418 |
parser.add_argument("--loras-choices", type=str, default="", help="Comma-separated list of chosen LoRA indices or preset names to load. Usually you only use the preset.")
|
|
@@ -540,6 +543,12 @@ def main():
|
|
| 540 |
except:
|
| 541 |
raise ValueError(f"Invalid resolution: '{resolution_str}'")
|
| 542 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 543 |
# Additional checks (from your original code).
|
| 544 |
if "480p" in args.transformer_file:
|
| 545 |
# Then we cannot exceed certain area for 480p model
|
|
@@ -628,6 +637,10 @@ def main():
|
|
| 628 |
callback=None, # or define your own callback if you want
|
| 629 |
enable_RIFLEx=enable_riflex,
|
| 630 |
VAE_tile_size=VAE_tile_size,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 631 |
)
|
| 632 |
except Exception as e:
|
| 633 |
offloadobj.unload_all()
|
|
|
|
| 413 |
parser.add_argument("--teacache", type=float, default=0.25, help="TeaCache multiplier, e.g. 0.5, 2.0, etc.")
|
| 414 |
parser.add_argument("--teacache-start", type=float, default=0.1, help="Teacache start step percentage [0..100]")
|
| 415 |
parser.add_argument("--seed", type=int, default=-1, help="Random seed. -1 means random each time.")
|
| 416 |
+
parser.add_argument("--slg-layers", type=str, default=None, help="Which layers to use for skip layer guidance")
|
| 417 |
+
parser.add_argument("--slg-start", type=float, default=0.0, help="Percentage in to start SLG")
|
| 418 |
+
parser.add_argument("--slg-end", type=float, default=1.0, help="Percentage in to end SLG")
|
| 419 |
|
| 420 |
# LoRA usage
|
| 421 |
parser.add_argument("--loras-choices", type=str, default="", help="Comma-separated list of chosen LoRA indices or preset names to load. Usually you only use the preset.")
|
|
|
|
| 543 |
except:
|
| 544 |
raise ValueError(f"Invalid resolution: '{resolution_str}'")
|
| 545 |
|
| 546 |
+
# Parse slg_layers from comma-separated string to a Python list of ints (or None if not provided)
|
| 547 |
+
if args.slg_layers:
|
| 548 |
+
slg_list = [int(x) for x in args.slg_layers.split(",")]
|
| 549 |
+
else:
|
| 550 |
+
slg_list = None
|
| 551 |
+
|
| 552 |
# Additional checks (from your original code).
|
| 553 |
if "480p" in args.transformer_file:
|
| 554 |
# Then we cannot exceed certain area for 480p model
|
|
|
|
| 637 |
callback=None, # or define your own callback if you want
|
| 638 |
enable_RIFLEx=enable_riflex,
|
| 639 |
VAE_tile_size=VAE_tile_size,
|
| 640 |
+
joint_pass=slg_list is None, # set if you want a small speed improvement without SLG
|
| 641 |
+
slg_layers=slg_list,
|
| 642 |
+
slg_start=args.slg_start,
|
| 643 |
+
slg_end=args.slg_end,
|
| 644 |
)
|
| 645 |
except Exception as e:
|
| 646 |
offloadobj.unload_all()
|
wan/image2video.py
CHANGED
|
@@ -132,22 +132,25 @@ class WanI2V:
|
|
| 132 |
self.sample_neg_prompt = config.sample_neg_prompt
|
| 133 |
|
| 134 |
def generate(self,
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
|
|
|
|
|
|
|
|
|
| 151 |
r"""
|
| 152 |
Generates video frames from input image and text prompt using diffusion process.
|
| 153 |
|
|
@@ -331,25 +334,42 @@ class WanI2V:
|
|
| 331 |
callback(-1, None)
|
| 332 |
|
| 333 |
for i, t in enumerate(tqdm(timesteps)):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
offload.set_step_no_for_lora(i)
|
| 335 |
latent_model_input = [latent.to(self.device)]
|
| 336 |
timestep = [t]
|
| 337 |
|
| 338 |
timestep = torch.stack(timestep).to(self.device)
|
| 339 |
if joint_pass:
|
|
|
|
|
|
|
| 340 |
noise_pred_cond, noise_pred_uncond = self.model(
|
| 341 |
latent_model_input, t=timestep, current_step=i, **arg_both)
|
| 342 |
if self._interrupt:
|
| 343 |
return None
|
| 344 |
else:
|
| 345 |
noise_pred_cond = self.model(
|
| 346 |
-
latent_model_input,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 347 |
if self._interrupt:
|
| 348 |
return None
|
| 349 |
if offload_model:
|
| 350 |
torch.cuda.empty_cache()
|
| 351 |
noise_pred_uncond = self.model(
|
| 352 |
-
latent_model_input,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 353 |
if self._interrupt:
|
| 354 |
return None
|
| 355 |
del latent_model_input
|
|
|
|
| 132 |
self.sample_neg_prompt = config.sample_neg_prompt
|
| 133 |
|
| 134 |
def generate(self,
|
| 135 |
+
input_prompt,
|
| 136 |
+
img,
|
| 137 |
+
max_area=720 * 1280,
|
| 138 |
+
frame_num=81,
|
| 139 |
+
shift=5.0,
|
| 140 |
+
sample_solver='unipc',
|
| 141 |
+
sampling_steps=40,
|
| 142 |
+
guide_scale=5.0,
|
| 143 |
+
n_prompt="",
|
| 144 |
+
seed=-1,
|
| 145 |
+
offload_model=True,
|
| 146 |
+
callback = None,
|
| 147 |
+
enable_RIFLEx = False,
|
| 148 |
+
VAE_tile_size= 0,
|
| 149 |
+
joint_pass = False,
|
| 150 |
+
slg_layers = None,
|
| 151 |
+
slg_start = 0.0,
|
| 152 |
+
slg_end = 1.0,
|
| 153 |
+
):
|
| 154 |
r"""
|
| 155 |
Generates video frames from input image and text prompt using diffusion process.
|
| 156 |
|
|
|
|
| 334 |
callback(-1, None)
|
| 335 |
|
| 336 |
for i, t in enumerate(tqdm(timesteps)):
|
| 337 |
+
slg_layers_local = None
|
| 338 |
+
if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps):
|
| 339 |
+
slg_layers_local = slg_layers
|
| 340 |
+
|
| 341 |
offload.set_step_no_for_lora(i)
|
| 342 |
latent_model_input = [latent.to(self.device)]
|
| 343 |
timestep = [t]
|
| 344 |
|
| 345 |
timestep = torch.stack(timestep).to(self.device)
|
| 346 |
if joint_pass:
|
| 347 |
+
if slg_layers is not None:
|
| 348 |
+
raise ValueError('Can not use SLG and joint-pass')
|
| 349 |
noise_pred_cond, noise_pred_uncond = self.model(
|
| 350 |
latent_model_input, t=timestep, current_step=i, **arg_both)
|
| 351 |
if self._interrupt:
|
| 352 |
return None
|
| 353 |
else:
|
| 354 |
noise_pred_cond = self.model(
|
| 355 |
+
latent_model_input,
|
| 356 |
+
t=timestep,
|
| 357 |
+
current_step=i,
|
| 358 |
+
is_uncond=False,
|
| 359 |
+
**arg_c,
|
| 360 |
+
)[0]
|
| 361 |
if self._interrupt:
|
| 362 |
return None
|
| 363 |
if offload_model:
|
| 364 |
torch.cuda.empty_cache()
|
| 365 |
noise_pred_uncond = self.model(
|
| 366 |
+
latent_model_input,
|
| 367 |
+
t=timestep,
|
| 368 |
+
current_step=i,
|
| 369 |
+
is_uncond=True,
|
| 370 |
+
slg_layers=slg_layers_local,
|
| 371 |
+
**arg_null,
|
| 372 |
+
)[0]
|
| 373 |
if self._interrupt:
|
| 374 |
return None
|
| 375 |
del latent_model_input
|
wan/modules/model.py
CHANGED
|
@@ -716,7 +716,8 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
| 716 |
pipeline = None,
|
| 717 |
current_step = 0,
|
| 718 |
context2 = None,
|
| 719 |
-
is_uncond=False
|
|
|
|
| 720 |
):
|
| 721 |
r"""
|
| 722 |
Forward pass through the diffusion model
|
|
@@ -843,7 +844,9 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
| 843 |
# context=context,
|
| 844 |
context_lens=context_lens)
|
| 845 |
|
| 846 |
-
for block in self.blocks:
|
|
|
|
|
|
|
| 847 |
if pipeline._interrupt:
|
| 848 |
if joint_pass:
|
| 849 |
return None, None
|
|
|
|
| 716 |
pipeline = None,
|
| 717 |
current_step = 0,
|
| 718 |
context2 = None,
|
| 719 |
+
is_uncond=False,
|
| 720 |
+
slg_layers=None,
|
| 721 |
):
|
| 722 |
r"""
|
| 723 |
Forward pass through the diffusion model
|
|
|
|
| 844 |
# context=context,
|
| 845 |
context_lens=context_lens)
|
| 846 |
|
| 847 |
+
for block_idx, block in enumerate(self.blocks):
|
| 848 |
+
if slg_layers is not None and block_idx in slg_layers and is_uncond:
|
| 849 |
+
continue
|
| 850 |
if pipeline._interrupt:
|
| 851 |
if joint_pass:
|
| 852 |
return None, None
|