DeepBeepMeep committed on
Commit ·
4dc6747
1
Parent(s): 20fdfdd
Added TeaCache support
Browse files- README.md +3 -2
- gradio_server.py +27 -13
- wan/image2video.py +2 -3
- wan/modules/model.py +73 -16
- wan/text2video.py +2 -3
README.md
CHANGED
|
@@ -20,7 +20,8 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
|
|
| 20 |
|
| 21 |
## 🔥 Latest News!!
|
| 22 |
|
| 23 |
-
* Mar 03, 2025: 👋 Wan2.1GP
|
|
|
|
| 24 |
- Support for all Wan including the Image to Video model
|
| 25 |
- Memory consumption reduced by a factor of 2, making it possible to generate more than 10s of video at 720p with an RTX 4090 and 10s of video at 480p with less than 12GB of VRAM. Many thanks to RIFLEx (https://github.com/thu-ml/RIFLEx) for their algorithm that allows generating nice-looking videos longer than 5s.
|
| 26 |
- The usual perks: web interface, multiple generations, LoRA support, Sage attention, auto download of models, ...
|
|
@@ -162,7 +163,7 @@ You will find prebuilt Loras on https://civitai.com/ or you will be able to buil
|
|
| 162 |
|
| 163 |
### Profiles (for power users only)
|
| 164 |
You can choose between 5 profiles, but only two are really relevant here:
|
| 165 |
-
- LowRAM_HighVRAM (3): loads entirely the model in VRAM,
|
| 166 |
- LowRAM_LowVRAM (4): loads only the parts of the model that are needed; low VRAM and low RAM requirements, but slightly slower
|
| 167 |
|
| 168 |
|
|
|
|
| 20 |
|
| 21 |
## 🔥 Latest News!!
|
| 22 |
|
| 23 |
+
* Mar 03, 2025: 👋 Wan2.1GP v1.1: added Tea Cache support for faster generations: optimization of kijai's implementation (https://github.com/kijai/ComfyUI-WanVideoWrapper/) of teacache (https://github.com/ali-vilab/TeaCache)
|
| 24 |
+
* Mar 02, 2025: 👋 Wan2.1GP by DeepBeepMeep v1 brings:
|
| 25 |
- Support for all Wan including the Image to Video model
|
| 26 |
- Memory consumption reduced by a factor of 2, making it possible to generate more than 10s of video at 720p with an RTX 4090 and 10s of video at 480p with less than 12GB of VRAM. Many thanks to RIFLEx (https://github.com/thu-ml/RIFLEx) for their algorithm that allows generating nice-looking videos longer than 5s.
|
| 27 |
- The usual perks: web interface, multiple generations, LoRA support, Sage attention, auto download of models, ...
|
|
|
|
| 163 |
|
| 164 |
### Profiles (for power users only)
|
| 165 |
You can choose between 5 profiles, but only two are really relevant here:
|
| 166 |
+
- LowRAM_HighVRAM (3): loads the model entirely into VRAM; slightly faster, but leaves less VRAM available for the video generation itself
|
| 167 |
- LowRAM_LowVRAM (4): loads only the parts of the model that are needed; low VRAM and low RAM requirements, but slightly slower
|
| 168 |
|
| 169 |
|
gradio_server.py
CHANGED
|
@@ -19,7 +19,7 @@ from wan.modules.attention import get_attention_modes
|
|
| 19 |
import torch
|
| 20 |
import gc
|
| 21 |
import traceback
|
| 22 |
-
|
| 23 |
|
| 24 |
def _parse_args():
|
| 25 |
parser = argparse.ArgumentParser(
|
|
@@ -650,6 +650,7 @@ def generate_video(
|
|
| 650 |
embedded_guidance_scale,
|
| 651 |
repeat_generation,
|
| 652 |
tea_cache,
|
|
|
|
| 653 |
loras_choices,
|
| 654 |
loras_mult_choices,
|
| 655 |
image_to_continue,
|
|
@@ -783,12 +784,15 @@ def generate_video(
|
|
| 783 |
break
|
| 784 |
|
| 785 |
if trans.enable_teacache:
|
| 786 |
-
trans.
|
| 787 |
-
trans.
|
| 788 |
-
trans.
|
| 789 |
-
trans.
|
| 790 |
-
trans.
|
| 791 |
-
trans.
|
|
|
|
|
|
|
|
|
|
| 792 |
|
| 793 |
video_no += 1
|
| 794 |
status = f"Video {video_no}/{total_video}"
|
|
@@ -799,6 +803,7 @@ def generate_video(
|
|
| 799 |
|
| 800 |
gc.collect()
|
| 801 |
torch.cuda.empty_cache()
|
|
|
|
| 802 |
try:
|
| 803 |
if use_image2video:
|
| 804 |
samples = wan_model.generate(
|
|
@@ -858,6 +863,9 @@ def generate_video(
|
|
| 858 |
else:
|
| 859 |
raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
|
| 860 |
|
|
|
|
|
|
|
|
|
|
| 861 |
|
| 862 |
if samples != None:
|
| 863 |
samples = samples.to("cpu")
|
|
@@ -874,7 +882,10 @@ def generate_video(
|
|
| 874 |
# video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")
|
| 875 |
|
| 876 |
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
|
| 877 |
-
|
|
|
|
|
|
|
|
|
|
| 878 |
video_path = os.path.join(os.getcwd(), "gradio_outputs", file_name)
|
| 879 |
cache_video(
|
| 880 |
tensor=sample[None],
|
|
@@ -1189,14 +1200,16 @@ def create_demo():
|
|
| 1189 |
flow_shift = gr.Slider(0.0, 25.0, value= default_flow_shift, step=0.1, label="Shift Scale")
|
| 1190 |
tea_cache_setting = gr.Dropdown(
|
| 1191 |
choices=[
|
| 1192 |
-
("Disabled", 0),
|
| 1193 |
-
("
|
| 1194 |
-
("
|
|
|
|
| 1195 |
],
|
| 1196 |
value=default_tea_cache,
|
| 1197 |
-
visible=
|
| 1198 |
-
label="Tea Cache
|
| 1199 |
)
|
|
|
|
| 1200 |
|
| 1201 |
RIFLEx_setting = gr.Dropdown(
|
| 1202 |
choices=[
|
|
@@ -1241,6 +1254,7 @@ def create_demo():
|
|
| 1241 |
embedded_guidance_scale,
|
| 1242 |
repeat_generation,
|
| 1243 |
tea_cache_setting,
|
|
|
|
| 1244 |
loras_choices,
|
| 1245 |
loras_mult_choices,
|
| 1246 |
image_to_continue,
|
|
|
|
| 19 |
import torch
|
| 20 |
import gc
|
| 21 |
import traceback
|
| 22 |
+
import math
|
| 23 |
|
| 24 |
def _parse_args():
|
| 25 |
parser = argparse.ArgumentParser(
|
|
|
|
| 650 |
embedded_guidance_scale,
|
| 651 |
repeat_generation,
|
| 652 |
tea_cache,
|
| 653 |
+
tea_cache_start_step_perc,
|
| 654 |
loras_choices,
|
| 655 |
loras_mult_choices,
|
| 656 |
image_to_continue,
|
|
|
|
| 784 |
break
|
| 785 |
|
| 786 |
if trans.enable_teacache:
|
| 787 |
+
trans.teacache_counter = 0
|
| 788 |
+
trans.rel_l1_thresh = tea_cache
|
| 789 |
+
trans.teacache_start_step = max(math.ceil(tea_cache_start_step_perc*num_inference_steps/100),2)
|
| 790 |
+
trans.previous_residual_uncond = None
|
| 791 |
+
trans.previous_modulated_input_uncond = None
|
| 792 |
+
trans.previous_residual_cond = None
|
| 793 |
+
trans.previous_modulated_input_cond= None
|
| 794 |
+
|
| 795 |
+
trans.teacache_cache_device = "cuda" if profile==3 or profile==1 else "cpu"
|
| 796 |
|
| 797 |
video_no += 1
|
| 798 |
status = f"Video {video_no}/{total_video}"
|
|
|
|
| 803 |
|
| 804 |
gc.collect()
|
| 805 |
torch.cuda.empty_cache()
|
| 806 |
+
wan_model._interrupt = False
|
| 807 |
try:
|
| 808 |
if use_image2video:
|
| 809 |
samples = wan_model.generate(
|
|
|
|
| 863 |
else:
|
| 864 |
raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
|
| 865 |
|
| 866 |
+
if trans.enable_teacache:
|
| 867 |
+
trans.previous_residual_uncond = None
|
| 868 |
+
trans.previous_residual_cond = None
|
| 869 |
|
| 870 |
if samples != None:
|
| 871 |
samples = samples.to("cpu")
|
|
|
|
| 882 |
# video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")
|
| 883 |
|
| 884 |
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%Hh%Mm%Ss")
|
| 885 |
+
if os.name == 'nt':
|
| 886 |
+
file_name = f"{time_flag}_seed{seed}_{prompt[:50].replace('/','').strip()}.mp4".replace(':',' ').replace('\\',' ')
|
| 887 |
+
else:
|
| 888 |
+
file_name = f"{time_flag}_seed{seed}_{prompt[:100].replace('/','').strip()}.mp4".replace(':',' ').replace('\\',' ')
|
| 889 |
video_path = os.path.join(os.getcwd(), "gradio_outputs", file_name)
|
| 890 |
cache_video(
|
| 891 |
tensor=sample[None],
|
|
|
|
| 1200 |
flow_shift = gr.Slider(0.0, 25.0, value= default_flow_shift, step=0.1, label="Shift Scale")
|
| 1201 |
tea_cache_setting = gr.Dropdown(
|
| 1202 |
choices=[
|
| 1203 |
+
("Tea Cache Disabled", 0),
|
| 1204 |
+
("0.03 (around x1.6 speed up)", 0.03),
|
| 1205 |
+
("0.05 (around x2 speed up)", 0.05),
|
| 1206 |
+
("0.10 (around x3 speed up)", 0.1),
|
| 1207 |
],
|
| 1208 |
value=default_tea_cache,
|
| 1209 |
+
visible=True,
|
| 1210 |
+
label="Tea Cache Threshold to Skip Steps (the higher, the more steps are skipped but the lower the quality of the video (Tea Cache Consumes VRAM)"
|
| 1211 |
)
|
| 1212 |
+
tea_cache_start_step_perc = gr.Slider(2, 100, value=20, step=1, label="Tea Cache starting moment in percentage of generation (the later, the higher the quality but also the lower the speed gain)")
|
| 1213 |
|
| 1214 |
RIFLEx_setting = gr.Dropdown(
|
| 1215 |
choices=[
|
|
|
|
| 1254 |
embedded_guidance_scale,
|
| 1255 |
repeat_generation,
|
| 1256 |
tea_cache_setting,
|
| 1257 |
+
tea_cache_start_step_perc,
|
| 1258 |
loras_choices,
|
| 1259 |
loras_mult_choices,
|
| 1260 |
image_to_continue,
|
wan/image2video.py
CHANGED
|
@@ -316,7 +316,6 @@ class WanI2V:
|
|
| 316 |
if callback != None:
|
| 317 |
callback(-1, None)
|
| 318 |
|
| 319 |
-
self._interrupt = False
|
| 320 |
for i, t in enumerate(tqdm(timesteps)):
|
| 321 |
latent_model_input = [latent.to(self.device)]
|
| 322 |
timestep = [t]
|
|
@@ -324,13 +323,13 @@ class WanI2V:
|
|
| 324 |
timestep = torch.stack(timestep).to(self.device)
|
| 325 |
|
| 326 |
noise_pred_cond = self.model(
|
| 327 |
-
latent_model_input, t=timestep, **arg_c)[0]
|
| 328 |
if self._interrupt:
|
| 329 |
return None
|
| 330 |
if offload_model:
|
| 331 |
torch.cuda.empty_cache()
|
| 332 |
noise_pred_uncond = self.model(
|
| 333 |
-
latent_model_input, t=timestep, **arg_null)[0]
|
| 334 |
if self._interrupt:
|
| 335 |
return None
|
| 336 |
del latent_model_input
|
|
|
|
| 316 |
if callback != None:
|
| 317 |
callback(-1, None)
|
| 318 |
|
|
|
|
| 319 |
for i, t in enumerate(tqdm(timesteps)):
|
| 320 |
latent_model_input = [latent.to(self.device)]
|
| 321 |
timestep = [t]
|
|
|
|
| 323 |
timestep = torch.stack(timestep).to(self.device)
|
| 324 |
|
| 325 |
noise_pred_cond = self.model(
|
| 326 |
+
latent_model_input, t=timestep, current_step=i, is_uncond = False, **arg_c)[0]
|
| 327 |
if self._interrupt:
|
| 328 |
return None
|
| 329 |
if offload_model:
|
| 330 |
torch.cuda.empty_cache()
|
| 331 |
noise_pred_uncond = self.model(
|
| 332 |
+
latent_model_input, t=timestep, current_step=i, is_uncond = True, **arg_null)[0]
|
| 333 |
if self._interrupt:
|
| 334 |
return None
|
| 335 |
del latent_model_input
|
wan/modules/model.py
CHANGED
|
@@ -146,6 +146,11 @@ def rope_apply(x, grid_sizes, freqs):
|
|
| 146 |
output.append(x_i)
|
| 147 |
return torch.stack(output) #.float()
|
| 148 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
class WanRMSNorm(nn.Module):
|
| 151 |
|
|
@@ -662,6 +667,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
| 662 |
|
| 663 |
return freqs
|
| 664 |
|
|
|
|
| 665 |
def forward(
|
| 666 |
self,
|
| 667 |
x,
|
|
@@ -672,6 +678,8 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
| 672 |
y=None,
|
| 673 |
freqs = None,
|
| 674 |
pipeline = None,
|
|
|
|
|
|
|
| 675 |
):
|
| 676 |
r"""
|
| 677 |
Forward pass through the diffusion model
|
|
@@ -723,7 +731,6 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
| 723 |
e = self.time_embedding(
|
| 724 |
sinusoidal_embedding_1d(self.freq_dim, t))
|
| 725 |
e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(torch.bfloat16)
|
| 726 |
-
# assert e.dtype == torch.float32 and e0.dtype == torch.float32
|
| 727 |
|
| 728 |
# context
|
| 729 |
context_lens = None
|
|
@@ -737,21 +744,71 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
| 737 |
if clip_fea is not None:
|
| 738 |
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
| 739 |
context = torch.concat([context_clip, context], dim=1)
|
| 740 |
-
|
| 741 |
-
|
| 742 |
-
|
| 743 |
-
|
| 744 |
-
|
| 745 |
-
|
| 746 |
-
|
| 747 |
-
|
| 748 |
-
|
| 749 |
-
|
| 750 |
-
|
| 751 |
-
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 755 |
|
| 756 |
# head
|
| 757 |
x = self.head(x, e)
|
|
|
|
| 146 |
output.append(x_i)
|
| 147 |
return torch.stack(output) #.float()
|
| 148 |
|
| 149 |
+
def relative_l1_distance(last_tensor, current_tensor):
|
| 150 |
+
l1_distance = torch.abs(last_tensor - current_tensor).mean()
|
| 151 |
+
norm = torch.abs(last_tensor).mean()
|
| 152 |
+
relative_l1_distance = l1_distance / norm
|
| 153 |
+
return relative_l1_distance.to(torch.float32)
|
| 154 |
|
| 155 |
class WanRMSNorm(nn.Module):
|
| 156 |
|
|
|
|
| 667 |
|
| 668 |
return freqs
|
| 669 |
|
| 670 |
+
|
| 671 |
def forward(
|
| 672 |
self,
|
| 673 |
x,
|
|
|
|
| 678 |
y=None,
|
| 679 |
freqs = None,
|
| 680 |
pipeline = None,
|
| 681 |
+
current_step = 0,
|
| 682 |
+
is_uncond=False
|
| 683 |
):
|
| 684 |
r"""
|
| 685 |
Forward pass through the diffusion model
|
|
|
|
| 731 |
e = self.time_embedding(
|
| 732 |
sinusoidal_embedding_1d(self.freq_dim, t))
|
| 733 |
e0 = self.time_projection(e).unflatten(1, (6, self.dim)).to(torch.bfloat16)
|
|
|
|
| 734 |
|
| 735 |
# context
|
| 736 |
context_lens = None
|
|
|
|
| 744 |
if clip_fea is not None:
|
| 745 |
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
| 746 |
context = torch.concat([context_clip, context], dim=1)
|
| 747 |
+
# deepbeepmeep optimization of kijai's implementation (https://github.com/kijai/ComfyUI-WanVideoWrapper/) of teacache (https://github.com/ali-vilab/TeaCache)
|
| 748 |
+
should_calc = True
|
| 749 |
+
if self.enable_teacache and current_step >= self.teacache_start_step:
|
| 750 |
+
if current_step == self.teacache_start_step:
|
| 751 |
+
self.accumulated_rel_l1_distance_cond = 0
|
| 752 |
+
self.accumulated_rel_l1_distance_uncond = 0
|
| 753 |
+
self.teacache_skipped_cond_steps = 0
|
| 754 |
+
self.teacache_skipped_uncond_steps = 0
|
| 755 |
+
else:
|
| 756 |
+
prev_input = self.previous_modulated_input_uncond if is_uncond else self.previous_modulated_input_cond
|
| 757 |
+
acc_distance_attr = 'accumulated_rel_l1_distance_uncond' if is_uncond else 'accumulated_rel_l1_distance_cond'
|
| 758 |
+
|
| 759 |
+
temb_relative_l1 = relative_l1_distance(prev_input, e0)
|
| 760 |
+
setattr(self, acc_distance_attr, getattr(self, acc_distance_attr) + temb_relative_l1)
|
| 761 |
+
|
| 762 |
+
if getattr(self, acc_distance_attr) < self.rel_l1_thresh:
|
| 763 |
+
should_calc = False
|
| 764 |
+
self.teacache_counter += 1
|
| 765 |
+
else:
|
| 766 |
+
should_calc = True
|
| 767 |
+
setattr(self, acc_distance_attr, 0)
|
| 768 |
+
|
| 769 |
+
if is_uncond:
|
| 770 |
+
self.previous_modulated_input_uncond = e0.clone()
|
| 771 |
+
if should_calc:
|
| 772 |
+
self.previous_residual_uncond = None
|
| 773 |
+
else:
|
| 774 |
+
x += self.previous_residual_uncond
|
| 775 |
+
self.teacache_skipped_cond_steps += 1
|
| 776 |
+
# print(f"Skipped uncond:{self.teacache_skipped_cond_steps}/{current_step}" )
|
| 777 |
+
else:
|
| 778 |
+
self.previous_modulated_input_cond = e0.clone()
|
| 779 |
+
if should_calc:
|
| 780 |
+
self.previous_residual_cond = None
|
| 781 |
+
else:
|
| 782 |
+
x += self.previous_residual_cond
|
| 783 |
+
self.teacache_skipped_uncond_steps += 1
|
| 784 |
+
# print(f"Skipped uncond:{self.teacache_skipped_uncond_steps}/{current_step}" )
|
| 785 |
+
|
| 786 |
+
if should_calc:
|
| 787 |
+
if self.enable_teacache:
|
| 788 |
+
ori_hidden_states = x.clone()
|
| 789 |
+
# arguments
|
| 790 |
+
kwargs = dict(
|
| 791 |
+
e=e0,
|
| 792 |
+
seq_lens=seq_lens,
|
| 793 |
+
grid_sizes=grid_sizes,
|
| 794 |
+
freqs=freqs,
|
| 795 |
+
context=context,
|
| 796 |
+
context_lens=context_lens)
|
| 797 |
+
|
| 798 |
+
for block in self.blocks:
|
| 799 |
+
if pipeline._interrupt:
|
| 800 |
+
return [None]
|
| 801 |
+
|
| 802 |
+
x = block(x, **kwargs)
|
| 803 |
+
|
| 804 |
+
if self.enable_teacache:
|
| 805 |
+
residual = ori_hidden_states # just to have a readable code
|
| 806 |
+
torch.sub(x, ori_hidden_states, out=residual)
|
| 807 |
+
if is_uncond:
|
| 808 |
+
self.previous_residual_uncond = residual
|
| 809 |
+
else:
|
| 810 |
+
self.previous_residual_cond = residual
|
| 811 |
+
del residual, ori_hidden_states
|
| 812 |
|
| 813 |
# head
|
| 814 |
x = self.head(x, e)
|
wan/text2video.py
CHANGED
|
@@ -248,7 +248,6 @@ class WanT2V:
|
|
| 248 |
|
| 249 |
if callback != None:
|
| 250 |
callback(-1, None)
|
| 251 |
-
self._interrupt = False
|
| 252 |
for i, t in enumerate(tqdm(timesteps)):
|
| 253 |
latent_model_input = latents
|
| 254 |
timestep = [t]
|
|
@@ -257,11 +256,11 @@ class WanT2V:
|
|
| 257 |
|
| 258 |
# self.model.to(self.device)
|
| 259 |
noise_pred_cond = self.model(
|
| 260 |
-
latent_model_input, t=timestep, **arg_c)[0]
|
| 261 |
if self._interrupt:
|
| 262 |
return None
|
| 263 |
noise_pred_uncond = self.model(
|
| 264 |
-
latent_model_input, t=timestep, **arg_null)[0]
|
| 265 |
if self._interrupt:
|
| 266 |
return None
|
| 267 |
|
|
|
|
| 248 |
|
| 249 |
if callback != None:
|
| 250 |
callback(-1, None)
|
|
|
|
| 251 |
for i, t in enumerate(tqdm(timesteps)):
|
| 252 |
latent_model_input = latents
|
| 253 |
timestep = [t]
|
|
|
|
| 256 |
|
| 257 |
# self.model.to(self.device)
|
| 258 |
noise_pred_cond = self.model(
|
| 259 |
+
latent_model_input, t=timestep,current_step=i, is_uncond = False, **arg_c)[0]
|
| 260 |
if self._interrupt:
|
| 261 |
return None
|
| 262 |
noise_pred_uncond = self.model(
|
| 263 |
+
latent_model_input, t=timestep,current_step=i, is_uncond = True, **arg_null)[0]
|
| 264 |
if self._interrupt:
|
| 265 |
return None
|
| 266 |
|