DeepBeepMeep commited on
Commit
4dc6747
·
1 Parent(s): 20fdfdd

Added TeaCache support

Browse files
Files changed (5) hide show
  1. README.md +3 -2
  2. gradio_server.py +27 -13
  3. wan/image2video.py +2 -3
  4. wan/modules/model.py +73 -16
  5. wan/text2video.py +2 -3
README.md CHANGED
@@ -20,7 +20,8 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
20
 
21
  ## 🔥 Latest News!!
22
 
23
- * Mar 03, 2025: 👋 Wan2.1GP by DeepBeepMeep v1 brings:
 
24
  - Support for all Wan including the Image to Video model
25
  - Reduced memory consumption by 2, with possiblity to generate more than 10s of video at 720p with a RTX 4090 and 10s of video at 480p with less than 12GB of VRAM. Many thanks to REFLEx (https://github.com/thu-ml/RIFLEx) for their algorithm that allows generating nice looking video longer than 5s.
26
  - The usual perks: web interface, multiple generations, loras support, sage attebtion, auto download of models, ...
@@ -162,7 +163,7 @@ You will find prebuilt Loras on https://civitai.com/ or you will be able to buil
162
 
163
  ### Profiles (for power users only)
164
  You can choose between 5 profiles, but two are really relevant here :
165
- - LowRAM_HighVRAM (3): loads entirely the model in VRAM, slighty faster, but less VRAM
166
  - LowRAM_LowVRAM (4): load only the part of the models that is needed, low VRAM and low RAM requirement but slightly slower
167
 
168
 
 
20
 
21
  ## 🔥 Latest News!!
22
 
23
+ * Mar 03, 2025: 👋 Wan2.1GP v1.1: added Tea Cache support for faster generations: optimization of kijai's implementation (https://github.com/kijai/ComfyUI-WanVideoWrapper/) of teacache (https://github.com/ali-vilab/TeaCache)
24
+ * Mar 02, 2025: 👋 Wan2.1GP by DeepBeepMeep v1 brings:
25
  - Support for all Wan including the Image to Video model
26
  - Reduced memory consumption by 2, with possiblity to generate more than 10s of video at 720p with a RTX 4090 and 10s of video at 480p with less than 12GB of VRAM. Many thanks to REFLEx (https://github.com/thu-ml/RIFLEx) for their algorithm that allows generating nice looking video longer than 5s.
27
  - The usual perks: web interface, multiple generations, loras support, sage attebtion, auto download of models, ...
 
163
 
164
  ### Profiles (for power users only)
165
  You can choose between 5 profiles, but two are really relevant here :
166
+ - LowRAM_HighVRAM (3): loads entirely the model in VRAM, slightly faster, but less VRAM
167
  - LowRAM_LowVRAM (4): load only the part of the models that is needed, low VRAM and low RAM requirement but slightly slower
168
 
169
 
gradio_server.py CHANGED
@@ -19,7 +19,7 @@ from wan.modules.attention import get_attention_modes
19
  import torch
20
  import gc
21
  import traceback
22
-
23
 
24
  def _parse_args():
25
  parser = argparse.ArgumentParser(
@@ -650,6 +650,7 @@ def generate_video(
650
  embedded_guidance_scale,
651
  repeat_generation,
652
  tea_cache,
 
653
  loras_choices,
654
  loras_mult_choices,
655
  image_to_continue,
@@ -783,12 +784,15 @@ def generate_video(
783
  break
784
 
785
  if trans.enable_teacache:
786
- trans.num_steps = num_inference_steps
787
- trans.cnt = 0
788
- trans.rel_l1_thresh = tea_cache #0.15 # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
789
- trans.accumulated_rel_l1_distance = 0
790
- trans.previous_modulated_input = None
791
- trans.previous_residual = None
 
 
 
792
 
793
  video_no += 1
794
  status = f"Video {video_no}/{total_video}"
@@ -799,6 +803,7 @@ def generate_video(
799
 
800
  gc.collect()
801
  torch.cuda.empty_cache()
 
802
  try:
803
  if use_image2video:
804
  samples = wan_model.generate(
@@ -858,6 +863,9 @@ def generate_video(
858
  else:
859
  raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
860
 
 
 
 
861
 
862
  if samples != None:
863
  samples = samples.to("cpu")
@@ -874,7 +882,10 @@ def generate_video(
874
  # video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")
875
 
876
  time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
877
- file_name = f"{time_flag}_seed{seed}_{prompt[:100].replace('/','').strip()}.mp4".replace(':',' ').replace('\\',' ')
 
 
 
878
  video_path = os.path.join(os.getcwd(), "gradio_outputs", file_name)
879
  cache_video(
880
  tensor=sample[None],
@@ -1189,14 +1200,16 @@ def create_demo():
1189
  flow_shift = gr.Slider(0.0, 25.0, value= default_flow_shift, step=0.1, label="Shift Scale")
1190
  tea_cache_setting = gr.Dropdown(
1191
  choices=[
1192
- ("Disabled", 0),
1193
- ("Fast (x1.6 speed up)", 0.1),
1194
- ("Faster (x2.1 speed up)", 0.15),
 
1195
  ],
1196
  value=default_tea_cache,
1197
- visible=False,
1198
- label="Tea Cache acceleration (the faster the acceleration the higher the degradation of the quality of the video. Consumes VRAM)"
1199
  )
 
1200
 
1201
  RIFLEx_setting = gr.Dropdown(
1202
  choices=[
@@ -1241,6 +1254,7 @@ def create_demo():
1241
  embedded_guidance_scale,
1242
  repeat_generation,
1243
  tea_cache_setting,
 
1244
  loras_choices,
1245
  loras_mult_choices,
1246
  image_to_continue,
 
19
  import torch
20
  import gc
21
  import traceback
22
+ import math
23
 
24
  def _parse_args():
25
  parser = argparse.ArgumentParser(
 
650
  embedded_guidance_scale,
651
  repeat_generation,
652
  tea_cache,
653
+ tea_cache_start_step_perc,
654
  loras_choices,
655
  loras_mult_choices,
656
  image_to_continue,
 
784
  break
785
 
786
  if trans.enable_teacache:
787
+ trans.teacache_counter = 0
788
+ trans.rel_l1_thresh = tea_cache
789
+ trans.teacache_start_step = max(math.ceil(tea_cache_start_step_perc*num_inference_steps/100),2)
790
+ trans.previous_residual_uncond = None
791
+ trans.previous_modulated_input_uncond = None
792
+ trans.previous_residual_cond = None
793
+ trans.previous_modulated_input_cond= None
794
+
795
+ trans.teacache_cache_device = "cuda" if profile==3 or profile==1 else "cpu"
796
 
797
  video_no += 1
798
  status = f"Video {video_no}/{total_video}"
 
803
 
804
  gc.collect()
805
  torch.cuda.empty_cache()
806
+ wan_model._interrupt = False
807
  try:
808
  if use_image2video:
809
  samples = wan_model.generate(
 
863
  else:
864
  raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
865
 
866
+ if trans.enable_teacache:
867
+ trans.previous_residual_uncond = None
868
+ trans.previous_residual_cond = None
869
 
870
  if samples != None:
871
  samples = samples.to("cpu")
 
882
  # video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")
883
 
884
  time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
885
+ if os.name == 'nt':
886
+ file_name = f"{time_flag}_seed{seed}_{prompt[:50].replace('/','').strip()}.mp4".replace(':',' ').replace('\\',' ')
887
+ else:
888
+ file_name = f"{time_flag}_seed{seed}_{prompt[:100].replace('/','').strip()}.mp4".replace(':',' ').replace('\\',' ')
889
  video_path = os.path.join(os.getcwd(), "gradio_outputs", file_name)
890
  cache_video(
891
  tensor=sample[None],
 
1200
  flow_shift = gr.Slider(0.0, 25.0, value= default_flow_shift, step=0.1, label="Shift Scale")
1201
  tea_cache_setting = gr.Dropdown(
1202
  choices=[
1203
+ ("Tea Cache Disabled", 0),
1204
+ ("0.03 (around x1.6 speed up)", 0.03),
1205
+ ("0.05 (around x2 speed up)", 0.05),
1206
+ ("0.10 (around x3 speed up)", 0.1),
1207
  ],
1208
  value=default_tea_cache,
1209
+ visible=True,
1210
+ label="Tea Cache Threshold to Skip Steps (the higher, the more steps are skipped but the lower the quality of the video (Tea Cache Consumes VRAM)"
1211
  )
1212
+ tea_cache_start_step_perc = gr.Slider(2, 100, value=20, step=1, label="Tea Cache starting moment in percentage of generation (the later, the higher the quality but also the lower the speed gain)")
1213
 
1214
  RIFLEx_setting = gr.Dropdown(
1215
  choices=[
 
1254
  embedded_guidance_scale,
1255
  repeat_generation,
1256
  tea_cache_setting,
1257
+ tea_cache_start_step_perc,
1258
  loras_choices,
1259
  loras_mult_choices,
1260
  image_to_continue,
wan/image2video.py CHANGED
@@ -316,7 +316,6 @@ class WanI2V:
316
  if callback != None:
317
  callback(-1, None)
318
 
319
- self._interrupt = False
320
  for i, t in enumerate(tqdm(timesteps)):
321
  latent_model_input = [latent.to(self.device)]
322
  timestep = [t]
@@ -324,13 +323,13 @@ class WanI2V:
324
  timestep = torch.stack(timestep).to(self.device)
325
 
326
  noise_pred_cond = self.model(
327
- latent_model_input, t=timestep, **arg_c)[0]
328
  if self._interrupt:
329
  return None
330
  if offload_model:
331
  torch.cuda.empty_cache()
332
  noise_pred_uncond = self.model(
333
- latent_model_input, t=timestep, **arg_null)[0]
334
  if self._interrupt:
335
  return None
336
  del latent_model_input
 
316
  if callback != None:
317
  callback(-1, None)
318
 
 
319
  for i, t in enumerate(tqdm(timesteps)):
320
  latent_model_input = [latent.to(self.device)]
321
  timestep = [t]
 
323
  timestep = torch.stack(timestep).to(self.device)
324
 
325
  noise_pred_cond = self.model(
326
+ latent_model_input, t=timestep, current_step=i, is_uncond = False, **arg_c)[0]
327
  if self._interrupt:
328
  return None
329
  if offload_model:
330
  torch.cuda.empty_cache()
331
  noise_pred_uncond = self.model(
332
+ latent_model_input, t=timestep, current_step=i, is_uncond = True, **arg_null)[0]
333
  if self._interrupt:
334
  return None
335
  del latent_model_input
wan/modules/model.py CHANGED
@@ -146,6 +146,11 @@ def rope_apply(x, grid_sizes, freqs):
146
  output.append(x_i)
147
  return torch.stack(output) #.float()
148
 
 
 
 
 
 
149
 
150
  class WanRMSNorm(nn.Module):
151
 
@@ -662,6 +667,7 @@ class WanModel(ModelMixin, ConfigMixin):
662
 
663
  return freqs
664
 
 
665
  def forward(
666
  self,
667
  x,
@@ -672,6 +678,8 @@ class WanModel(ModelMixin, ConfigMixin):
672
  y=None,
673
  freqs = None,
674
  pipeline = None,
 
 
675
  ):
676
  r"""
677
  Forward pass through the diffusion model
@@ -723,7 +731,6 @@ class WanModel(ModelMixin, ConfigMixin):
723
  e = self.time_embedding(
724
  sinusoidal_embedding_1d(self.freq_dim, t))
725
  e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(torch.bfloat16)
726
- # assert e.dtype == torch.float32 and e0.dtype == torch.float32
727
 
728
  # context
729
  context_lens = None
@@ -737,21 +744,71 @@ class WanModel(ModelMixin, ConfigMixin):
737
  if clip_fea is not None:
738
  context_clip = self.img_emb(clip_fea) # bs x 257 x dim
739
  context = torch.concat([context_clip, context], dim=1)
740
-
741
- # arguments
742
- kwargs = dict(
743
- e=e0,
744
- seq_lens=seq_lens,
745
- grid_sizes=grid_sizes,
746
- freqs=freqs,
747
- context=context,
748
- context_lens=context_lens)
749
-
750
- for block in self.blocks:
751
- if pipeline._interrupt:
752
- return [None]
753
-
754
- x = block(x, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
755
 
756
  # head
757
  x = self.head(x, e)
 
146
  output.append(x_i)
147
  return torch.stack(output) #.float()
148
 
149
+ def relative_l1_distance(last_tensor, current_tensor):
150
+ l1_distance = torch.abs(last_tensor - current_tensor).mean()
151
+ norm = torch.abs(last_tensor).mean()
152
+ relative_l1_distance = l1_distance / norm
153
+ return relative_l1_distance.to(torch.float32)
154
 
155
  class WanRMSNorm(nn.Module):
156
 
 
667
 
668
  return freqs
669
 
670
+
671
  def forward(
672
  self,
673
  x,
 
678
  y=None,
679
  freqs = None,
680
  pipeline = None,
681
+ current_step = 0,
682
+ is_uncond=False
683
  ):
684
  r"""
685
  Forward pass through the diffusion model
 
731
  e = self.time_embedding(
732
  sinusoidal_embedding_1d(self.freq_dim, t))
733
  e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(torch.bfloat16)
 
734
 
735
  # context
736
  context_lens = None
 
744
  if clip_fea is not None:
745
  context_clip = self.img_emb(clip_fea) # bs x 257 x dim
746
  context = torch.concat([context_clip, context], dim=1)
747
+ # deepbeepmeep optimization of kijai's implementation (https://github.com/kijai/ComfyUI-WanVideoWrapper/) of teacache (https://github.com/ali-vilab/TeaCache)
748
+ should_calc = True
749
+ if self.enable_teacache and current_step >= self.teacache_start_step:
750
+ if current_step == self.teacache_start_step:
751
+ self.accumulated_rel_l1_distance_cond = 0
752
+ self.accumulated_rel_l1_distance_uncond = 0
753
+ self.teacache_skipped_cond_steps = 0
754
+ self.teacache_skipped_uncond_steps = 0
755
+ else:
756
+ prev_input = self.previous_modulated_input_uncond if is_uncond else self.previous_modulated_input_cond
757
+ acc_distance_attr = 'accumulated_rel_l1_distance_uncond' if is_uncond else 'accumulated_rel_l1_distance_cond'
758
+
759
+ temb_relative_l1 = relative_l1_distance(prev_input, e0)
760
+ setattr(self, acc_distance_attr, getattr(self, acc_distance_attr) + temb_relative_l1)
761
+
762
+ if getattr(self, acc_distance_attr) < self.rel_l1_thresh:
763
+ should_calc = False
764
+ self.teacache_counter += 1
765
+ else:
766
+ should_calc = True
767
+ setattr(self, acc_distance_attr, 0)
768
+
769
+ if is_uncond:
770
+ self.previous_modulated_input_uncond = e0.clone()
771
+ if should_calc:
772
+ self.previous_residual_uncond = None
773
+ else:
774
+ x += self.previous_residual_uncond
775
+ self.teacache_skipped_cond_steps += 1
776
+ # print(f"Skipped uncond:{self.teacache_skipped_cond_steps}/{current_step}" )
777
+ else:
778
+ self.previous_modulated_input_cond = e0.clone()
779
+ if should_calc:
780
+ self.previous_residual_cond = None
781
+ else:
782
+ x += self.previous_residual_cond
783
+ self.teacache_skipped_uncond_steps += 1
784
+ # print(f"Skipped uncond:{self.teacache_skipped_uncond_steps}/{current_step}" )
785
+
786
+ if should_calc:
787
+ if self.enable_teacache:
788
+ ori_hidden_states = x.clone()
789
+ # arguments
790
+ kwargs = dict(
791
+ e=e0,
792
+ seq_lens=seq_lens,
793
+ grid_sizes=grid_sizes,
794
+ freqs=freqs,
795
+ context=context,
796
+ context_lens=context_lens)
797
+
798
+ for block in self.blocks:
799
+ if pipeline._interrupt:
800
+ return [None]
801
+
802
+ x = block(x, **kwargs)
803
+
804
+ if self.enable_teacache:
805
+ residual = ori_hidden_states # just to have a readable code
806
+ torch.sub(x, ori_hidden_states, out=residual)
807
+ if is_uncond:
808
+ self.previous_residual_uncond = residual
809
+ else:
810
+ self.previous_residual_cond = residual
811
+ del residual, ori_hidden_states
812
 
813
  # head
814
  x = self.head(x, e)
wan/text2video.py CHANGED
@@ -248,7 +248,6 @@ class WanT2V:
248
 
249
  if callback != None:
250
  callback(-1, None)
251
- self._interrupt = False
252
  for i, t in enumerate(tqdm(timesteps)):
253
  latent_model_input = latents
254
  timestep = [t]
@@ -257,11 +256,11 @@ class WanT2V:
257
 
258
  # self.model.to(self.device)
259
  noise_pred_cond = self.model(
260
- latent_model_input, t=timestep, **arg_c)[0]
261
  if self._interrupt:
262
  return None
263
  noise_pred_uncond = self.model(
264
- latent_model_input, t=timestep, **arg_null)[0]
265
  if self._interrupt:
266
  return None
267
 
 
248
 
249
  if callback != None:
250
  callback(-1, None)
 
251
  for i, t in enumerate(tqdm(timesteps)):
252
  latent_model_input = latents
253
  timestep = [t]
 
256
 
257
  # self.model.to(self.device)
258
  noise_pred_cond = self.model(
259
+ latent_model_input, t=timestep,current_step=i, is_uncond = False, **arg_c)[0]
260
  if self._interrupt:
261
  return None
262
  noise_pred_uncond = self.model(
263
+ latent_model_input, t=timestep,current_step=i, is_uncond = True, **arg_null)[0]
264
  if self._interrupt:
265
  return None
266