DeepBeepMeep commited on
Commit
941388c
·
1 Parent(s): 3d16998

Added CFG Zero *

Browse files
README.md CHANGED
@@ -19,6 +19,12 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
19
 
20
 
21
  ## 🔥 Latest News!!
 
 
 
 
 
 
22
  * Mar 19 2022: 👋 Wan2.1GP v3.1: Faster launch and RAM optimizations (should require less RAM to run)\
23
  You will need one more *pip install -r requirements.txt*
24
  * Mar 18 2022: 👋 Wan2.1GP v3.0:
 
19
 
20
 
21
  ## 🔥 Latest News!!
22
+ * Mar 19 2022: 👋 Wan2.1GP v3.2:
23
+ - Added Classifier-Free Guidance Zero Star. The video should match better the text prompt (especially with text2video) at no performance cost: many thanks to the **CFG Zero * Team:**\
24
+ Dont hesitate to give them a star if you appreciate the results: https://github.com/WeichenFan/CFG-Zero-star
25
+ - Added back support for Pytorch compilation with Loras. It seems it had been broken for some time
26
+ - Added possibility to keep a number of pregenerated videos in the Video Gallery (useful to compare outputs of different settings)
27
+ You will need one more *pip install -r requirements.txt*
28
  * Mar 19 2022: 👋 Wan2.1GP v3.1: Faster launch and RAM optimizations (should require less RAM to run)\
29
  You will need one more *pip install -r requirements.txt*
30
  * Mar 18 2022: 👋 Wan2.1GP v3.0:
gradio_server.py CHANGED
@@ -23,7 +23,7 @@ import asyncio
23
  from wan.utils import prompt_parser
24
  PROMPT_VARS_MAX = 10
25
 
26
- target_mmgp_version = "3.3.3"
27
  from importlib.metadata import version
28
  mmgp_version = version("mmgp")
29
  if mmgp_version != target_mmgp_version:
@@ -300,6 +300,7 @@ if not Path(server_config_filename).is_file():
300
  "metadata_type": "metadata",
301
  "default_ui": "t2v",
302
  "boost" : 1,
 
303
  "vae_config": 0,
304
  "profile" : profile_type.LowRAM_LowVRAM }
305
 
@@ -382,7 +383,6 @@ if len(args.vae_config) > 0:
382
 
383
  reload_needed = False
384
  default_ui = server_config.get("default_ui", "t2v")
385
- metadata = server_config.get("metadata_type", "metadata")
386
  save_path = server_config.get("save_path", os.path.join(os.getcwd(), "gradio_outputs"))
387
  use_image2video = default_ui != "t2v"
388
  if args.t2v:
@@ -741,7 +741,8 @@ def apply_changes( state,
741
  vae_config_choice,
742
  metadata_choice,
743
  default_ui_choice ="t2v",
744
- boost_choice = 1
 
745
  ):
746
  if args.lock_config:
747
  return
@@ -760,6 +761,7 @@ def apply_changes( state,
760
  "metadata_choice": metadata_choice,
761
  "default_ui" : default_ui_choice,
762
  "boost" : boost_choice,
 
763
  }
764
 
765
  if Path(server_config_filename).is_file():
@@ -792,7 +794,7 @@ def apply_changes( state,
792
  text_encoder_filename = server_config["text_encoder_filename"]
793
  vae_config = server_config["vae_config"]
794
  boost = server_config["boost"]
795
- if all(change in ["attention_mode", "vae_config", "default_ui", "boost", "save_path", "metadata_choice"] for change in changes ):
796
  pass
797
  else:
798
  reload_needed = True
@@ -849,13 +851,17 @@ def refresh_gallery(state, txt):
849
  if len(prompt) == 0:
850
  return file_list, gr.Text(visible= False, value="")
851
  else:
 
 
 
 
852
  prompts_max = state.get("prompts_max",0)
853
  prompt_no = state.get("prompt_no",0)
854
  if prompts_max >1 :
855
  label = f"Current Prompt ({prompt_no+1}/{prompts_max})"
856
  else:
857
  label = f"Current Prompt"
858
- return file_list, gr.Text(visible= True, value=prompt, label=label)
859
 
860
 
861
  def finalize_gallery(state):
@@ -863,6 +869,8 @@ def finalize_gallery(state):
863
  if "in_progress" in state:
864
  del state["in_progress"]
865
  choice = state.get("selected",0)
 
 
866
 
867
  state["extra_orders"] = 0
868
  time.sleep(0.2)
@@ -930,6 +938,7 @@ def generate_video(
930
  tea_cache_start_step_perc,
931
  loras_choices,
932
  loras_mult_choices,
 
933
  image_to_continue,
934
  image_to_end,
935
  video_to_continue,
@@ -938,7 +947,9 @@ def generate_video(
938
  slg_switch,
939
  slg_layers,
940
  slg_start,
941
- slg_end,
 
 
942
  state,
943
  image2video,
944
  progress=gr.Progress() #track_tqdm= True
@@ -1031,6 +1042,8 @@ def generate_video(
1031
  if len(prompts) ==0:
1032
  return
1033
  if image2video:
 
 
1034
  if image_to_continue is not None:
1035
  if isinstance(image_to_continue, list):
1036
  image_to_continue = [ tup[0] for tup in image_to_continue ]
@@ -1135,7 +1148,6 @@ def generate_video(
1135
  if "abort" in state:
1136
  del state["abort"]
1137
  state["in_progress"] = True
1138
- state["selected"] = 0
1139
 
1140
  enable_RIFLEx = RIFLEx_setting == 0 and video_length > (6* 16) or RIFLEx_setting == 1
1141
  # VAE Tiling
@@ -1172,8 +1184,21 @@ def generate_video(
1172
  if seed == None or seed <0:
1173
  seed = random.randint(0, 999999999)
1174
 
1175
- file_list = []
 
 
 
 
 
 
 
 
 
 
 
 
1176
  state["file_list"] = file_list
 
1177
  global save_path
1178
  os.makedirs(save_path, exist_ok=True)
1179
  video_no = 0
@@ -1240,6 +1265,8 @@ def generate_video(
1240
  slg_layers = slg_layers,
1241
  slg_start = slg_start/100,
1242
  slg_end = slg_end/100,
 
 
1243
  )
1244
 
1245
  else:
@@ -1260,6 +1287,8 @@ def generate_video(
1260
  slg_layers = slg_layers,
1261
  slg_start = slg_start/100,
1262
  slg_end = slg_end/100,
 
 
1263
  )
1264
  except Exception as e:
1265
  gen_in_progress = False
@@ -1326,7 +1355,7 @@ def generate_video(
1326
  value_range=(-1, 1))
1327
 
1328
  configs = get_settings_dict(state, use_image2video, prompt, 0 if image_to_end == None else 1 , video_length, raw_resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
1329
- loras_mult_choices, tea_cache , tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end)
1330
 
1331
  metadata_choice = server_config.get("metadata_choice","metadata")
1332
  if metadata_choice == "json":
@@ -1642,7 +1671,7 @@ def switch_advanced(state, new_advanced, lset_name):
1642
 
1643
 
1644
  def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
1645
- loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc):
1646
 
1647
  loras = state["loras"]
1648
  activated_loras = [Path( loras[int(no)]).parts[-1] for no in loras_choices ]
@@ -1666,7 +1695,9 @@ def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resol
1666
  "slg_switch": slg_switch,
1667
  "slg_layers": slg_layers,
1668
  "slg_start_perc": slg_start_perc,
1669
- "slg_end_perc": slg_end_perc
 
 
1670
  }
1671
 
1672
  if i2v:
@@ -1678,13 +1709,13 @@ def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resol
1678
  return ui_settings
1679
 
1680
  def save_settings(state, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
1681
- loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc):
1682
 
1683
  if state.get("validate_success",0) != 1:
1684
  return
1685
 
1686
  ui_defaults = get_settings_dict(state, use_image2video, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
1687
- loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc)
1688
 
1689
  defaults_filename = get_settings_file_name(use_image2video)
1690
 
@@ -1955,7 +1986,7 @@ def generate_video_tab(image2video=False):
1955
  label="RIFLEx positional embedding to generate long video"
1956
  )
1957
  with gr.Row():
1958
- gr.Markdown("<B>Experimental: Skip Layer guidance,should improve video quality</B>")
1959
  with gr.Row():
1960
  slg_switch = gr.Dropdown(
1961
  choices=[
@@ -1979,6 +2010,23 @@ def generate_video_tab(image2video=False):
1979
  with gr.Row():
1980
  slg_start_perc = gr.Slider(0, 100, value=ui_defaults["slg_start_perc"], step=1, label="Denoising Steps % start")
1981
  slg_end_perc = gr.Slider(0, 100, value=ui_defaults["slg_end_perc"], step=1, label="Denoising Steps % end")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1982
  with gr.Row():
1983
  save_settings_btn = gr.Button("Set Settings as Default")
1984
  show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
@@ -1997,7 +2045,7 @@ def generate_video_tab(image2video=False):
1997
  save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
1998
  save_settings, inputs = [state, prompt, image_prompt_type_radio, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt,
1999
  loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers,
2000
- slg_start_perc, slg_end_perc ], outputs = [])
2001
  save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
2002
  confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
2003
  save_lset, inputs=[state, lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn,refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
@@ -2035,6 +2083,7 @@ def generate_video_tab(image2video=False):
2035
  tea_cache_start_step_perc,
2036
  loras_choices,
2037
  loras_mult_choices,
 
2038
  image_to_continue,
2039
  image_to_end,
2040
  video_to_continue,
@@ -2044,6 +2093,8 @@ def generate_video_tab(image2video=False):
2044
  slg_layers,
2045
  slg_start_perc,
2046
  slg_end_perc,
 
 
2047
  state,
2048
  gr.State(image2video)
2049
  ],
@@ -2175,9 +2226,24 @@ def generate_configuration_tab():
2175
  ("Add metadata to video", "metadata"),
2176
  ("Neither", "none")
2177
  ],
2178
- value=metadata,
2179
  label="Metadata Handling"
2180
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2181
  msg = gr.Markdown()
2182
  apply_btn = gr.Button("Apply Changes")
2183
  apply_btn.click(
@@ -2195,6 +2261,7 @@ def generate_configuration_tab():
2195
  metadata_choice,
2196
  default_ui_choice,
2197
  boost_choice,
 
2198
  ],
2199
  outputs= msg
2200
  )
@@ -2262,7 +2329,7 @@ def create_demo():
2262
  }
2263
  """
2264
  with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md")) as demo:
2265
- gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v3.1 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
2266
  gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !</FONT>")
2267
 
2268
  with gr.Accordion("Click here for some Info on how to use Wan2GP", open = False):
 
23
  from wan.utils import prompt_parser
24
  PROMPT_VARS_MAX = 10
25
 
26
+ target_mmgp_version = "3.3.4"
27
  from importlib.metadata import version
28
  mmgp_version = version("mmgp")
29
  if mmgp_version != target_mmgp_version:
 
300
  "metadata_type": "metadata",
301
  "default_ui": "t2v",
302
  "boost" : 1,
303
+ "clear_file_list" : 0,
304
  "vae_config": 0,
305
  "profile" : profile_type.LowRAM_LowVRAM }
306
 
 
383
 
384
  reload_needed = False
385
  default_ui = server_config.get("default_ui", "t2v")
 
386
  save_path = server_config.get("save_path", os.path.join(os.getcwd(), "gradio_outputs"))
387
  use_image2video = default_ui != "t2v"
388
  if args.t2v:
 
741
  vae_config_choice,
742
  metadata_choice,
743
  default_ui_choice ="t2v",
744
+ boost_choice = 1,
745
+ clear_file_list = 0,
746
  ):
747
  if args.lock_config:
748
  return
 
761
  "metadata_choice": metadata_choice,
762
  "default_ui" : default_ui_choice,
763
  "boost" : boost_choice,
764
+ "clear_file_list" : clear_file_list
765
  }
766
 
767
  if Path(server_config_filename).is_file():
 
794
  text_encoder_filename = server_config["text_encoder_filename"]
795
  vae_config = server_config["vae_config"]
796
  boost = server_config["boost"]
797
+ if all(change in ["attention_mode", "vae_config", "default_ui", "boost", "save_path", "metadata_choice", "clear_file_list"] for change in changes ):
798
  pass
799
  else:
800
  reload_needed = True
 
851
  if len(prompt) == 0:
852
  return file_list, gr.Text(visible= False, value="")
853
  else:
854
+ choice = 0
855
+ if "in_progress" in state:
856
+ choice = state.get("selected",0)
857
+
858
  prompts_max = state.get("prompts_max",0)
859
  prompt_no = state.get("prompt_no",0)
860
  if prompts_max >1 :
861
  label = f"Current Prompt ({prompt_no+1}/{prompts_max})"
862
  else:
863
  label = f"Current Prompt"
864
+ return gr.Gallery(selected_index=choice, value = file_list), gr.Text(visible= True, value=prompt, label=label)
865
 
866
 
867
  def finalize_gallery(state):
 
869
  if "in_progress" in state:
870
  del state["in_progress"]
871
  choice = state.get("selected",0)
872
+ # file_list = state.get("file_list", [])
873
+
874
 
875
  state["extra_orders"] = 0
876
  time.sleep(0.2)
 
938
  tea_cache_start_step_perc,
939
  loras_choices,
940
  loras_mult_choices,
941
+ image_prompt_type,
942
  image_to_continue,
943
  image_to_end,
944
  video_to_continue,
 
947
  slg_switch,
948
  slg_layers,
949
  slg_start,
950
+ slg_end,
951
+ cfg_star_switch,
952
+ cfg_zero_step,
953
  state,
954
  image2video,
955
  progress=gr.Progress() #track_tqdm= True
 
1042
  if len(prompts) ==0:
1043
  return
1044
  if image2video:
1045
+ if image_prompt_type == 0:
1046
+ image_to_end = None
1047
  if image_to_continue is not None:
1048
  if isinstance(image_to_continue, list):
1049
  image_to_continue = [ tup[0] for tup in image_to_continue ]
 
1148
  if "abort" in state:
1149
  del state["abort"]
1150
  state["in_progress"] = True
 
1151
 
1152
  enable_RIFLEx = RIFLEx_setting == 0 and video_length > (6* 16) or RIFLEx_setting == 1
1153
  # VAE Tiling
 
1184
  if seed == None or seed <0:
1185
  seed = random.randint(0, 999999999)
1186
 
1187
+ clear_file_list = server_config.get("clear_file_list", 0)
1188
+ file_list = state.get("file_list", [])
1189
+ if clear_file_list > 0:
1190
+ file_list_current_size = len(file_list)
1191
+ keep_file_from = max(file_list_current_size - clear_file_list, 0)
1192
+ files_removed = keep_file_from
1193
+ choice = state.get("selected",0)
1194
+ choice = max(choice- files_removed, 0)
1195
+ file_list = file_list[ keep_file_from: ]
1196
+ else:
1197
+ file_list = []
1198
+ choice = 0
1199
+ state["selected"] = choice
1200
  state["file_list"] = file_list
1201
+
1202
  global save_path
1203
  os.makedirs(save_path, exist_ok=True)
1204
  video_no = 0
 
1265
  slg_layers = slg_layers,
1266
  slg_start = slg_start/100,
1267
  slg_end = slg_end/100,
1268
+ cfg_star_switch = cfg_star_switch,
1269
+ cfg_zero_step = cfg_zero_step,
1270
  )
1271
 
1272
  else:
 
1287
  slg_layers = slg_layers,
1288
  slg_start = slg_start/100,
1289
  slg_end = slg_end/100,
1290
+ cfg_star_switch = cfg_star_switch,
1291
+ cfg_zero_step = cfg_zero_step,
1292
  )
1293
  except Exception as e:
1294
  gen_in_progress = False
 
1355
  value_range=(-1, 1))
1356
 
1357
  configs = get_settings_dict(state, use_image2video, prompt, 0 if image_to_end == None else 1 , video_length, raw_resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
1358
+ loras_mult_choices, tea_cache , tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end, cfg_star_switch, cfg_zero_step)
1359
 
1360
  metadata_choice = server_config.get("metadata_choice","metadata")
1361
  if metadata_choice == "json":
 
1671
 
1672
 
1673
  def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
1674
+ loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step):
1675
 
1676
  loras = state["loras"]
1677
  activated_loras = [Path( loras[int(no)]).parts[-1] for no in loras_choices ]
 
1695
  "slg_switch": slg_switch,
1696
  "slg_layers": slg_layers,
1697
  "slg_start_perc": slg_start_perc,
1698
+ "slg_end_perc": slg_end_perc,
1699
+ "cfg_star_switch": cfg_star_switch,
1700
+ "cfg_zero_step": cfg_zero_step
1701
  }
1702
 
1703
  if i2v:
 
1709
  return ui_settings
1710
 
1711
  def save_settings(state, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
1712
+ loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step):
1713
 
1714
  if state.get("validate_success",0) != 1:
1715
  return
1716
 
1717
  ui_defaults = get_settings_dict(state, use_image2video, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
1718
+ loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step)
1719
 
1720
  defaults_filename = get_settings_file_name(use_image2video)
1721
 
 
1986
  label="RIFLEx positional embedding to generate long video"
1987
  )
1988
  with gr.Row():
1989
+ gr.Markdown("<B>Experimental: Skip Layer Guidance, should improve video quality</B>")
1990
  with gr.Row():
1991
  slg_switch = gr.Dropdown(
1992
  choices=[
 
2010
  with gr.Row():
2011
  slg_start_perc = gr.Slider(0, 100, value=ui_defaults["slg_start_perc"], step=1, label="Denoising Steps % start")
2012
  slg_end_perc = gr.Slider(0, 100, value=ui_defaults["slg_end_perc"], step=1, label="Denoising Steps % end")
2013
+
2014
+ with gr.Row():
2015
+ gr.Markdown("<B>Experimental: Classifier-Free Guidance Zero Star, better adherence to Text Prompt")
2016
+ with gr.Row():
2017
+ cfg_star_switch = gr.Dropdown(
2018
+ choices=[
2019
+ ("OFF", 0),
2020
+ ("ON", 1),
2021
+ ],
2022
+ value=ui_defaults.get("cfg_star_switch",0),
2023
+ visible=True,
2024
+ scale = 1,
2025
+ label="CFG Star"
2026
+ )
2027
+ with gr.Row():
2028
+ cfg_zero_step = gr.Slider(-1, 39, value=ui_defaults.get("cfg_zero_step",-1), step=1, label="CFG Zero below this Layer (Extra Process)")
2029
+
2030
  with gr.Row():
2031
  save_settings_btn = gr.Button("Set Settings as Default")
2032
  show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
 
2045
  save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
2046
  save_settings, inputs = [state, prompt, image_prompt_type_radio, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt,
2047
  loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers,
2048
+ slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step ], outputs = [])
2049
  save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
2050
  confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
2051
  save_lset, inputs=[state, lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn,refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
 
2083
  tea_cache_start_step_perc,
2084
  loras_choices,
2085
  loras_mult_choices,
2086
+ image_prompt_type_radio,
2087
  image_to_continue,
2088
  image_to_end,
2089
  video_to_continue,
 
2093
  slg_layers,
2094
  slg_start_perc,
2095
  slg_end_perc,
2096
+ cfg_star_switch,
2097
+ cfg_zero_step,
2098
  state,
2099
  gr.State(image2video)
2100
  ],
 
2226
  ("Add metadata to video", "metadata"),
2227
  ("Neither", "none")
2228
  ],
2229
+ value=server_config.get("metadata_type", "metadata"),
2230
  label="Metadata Handling"
2231
  )
2232
+
2233
+ clear_file_list_choice = gr.Dropdown(
2234
+ choices=[
2235
+ ("None", 0),
2236
+ ("Keep the last video", 1),
2237
+ ("Keep the last 5 videos", 5),
2238
+ ("Keep the last 10 videos", 10),
2239
+ ("Keep the last 20 videos", 20),
2240
+ ("Keep the last 30 videos", 30),
2241
+ ],
2242
+ value=server_config.get("clear_file_list", 0),
2243
+ label="Keep Previously Generated Videos when starting a Generation Batch"
2244
+ )
2245
+
2246
+
2247
  msg = gr.Markdown()
2248
  apply_btn = gr.Button("Apply Changes")
2249
  apply_btn.click(
 
2261
  metadata_choice,
2262
  default_ui_choice,
2263
  boost_choice,
2264
+ clear_file_list_choice,
2265
  ],
2266
  outputs= msg
2267
  )
 
2329
  }
2330
  """
2331
  with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md")) as demo:
2332
+ gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v3.2 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
2333
  gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !</FONT>")
2334
 
2335
  with gr.Accordion("Click here for some Info on how to use Wan2GP", open = False):
requirements.txt CHANGED
@@ -16,6 +16,6 @@ gradio>=5.0.0
16
  numpy>=1.23.5,<2
17
  einops
18
  moviepy==1.0.3
19
- mmgp==3.3.3
20
  peft==0.14.0
21
  mutagen
 
16
  numpy>=1.23.5,<2
17
  einops
18
  moviepy==1.0.3
19
+ mmgp==3.3.4
20
  peft==0.14.0
21
  mutagen
wan/image2video.py CHANGED
@@ -28,79 +28,19 @@ from wan.modules.posemb_layers import get_rotary_pos_embed
28
 
29
  from PIL import Image
30
 
31
- def lanczos(samples, width, height):
32
- images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
33
- images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
34
- images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
35
- result = torch.stack(images)
36
- return result.to(samples.device, samples.dtype)
37
-
38
- def bislerp(samples, width, height):
39
- def slerp(b1, b2, r):
40
- '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
41
-
42
- c = b1.shape[-1]
43
-
44
- #norms
45
- b1_norms = torch.norm(b1, dim=-1, keepdim=True)
46
- b2_norms = torch.norm(b2, dim=-1, keepdim=True)
47
-
48
- #normalize
49
- b1_normalized = b1 / b1_norms
50
- b2_normalized = b2 / b2_norms
51
-
52
- #zero when norms are zero
53
- b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0
54
- b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0
55
-
56
- #slerp
57
- dot = (b1_normalized*b2_normalized).sum(1)
58
- omega = torch.acos(dot)
59
- so = torch.sin(omega)
60
-
61
- #technically not mathematically correct, but more pleasing?
62
- res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized
63
- res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c)
64
-
65
- #edge cases for same or polar opposites
66
- res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
67
- res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
68
- return res
69
-
70
-
71
- def common_upscale(samples, width, height, upscale_method, crop):
72
- orig_shape = tuple(samples.shape)
73
- if len(orig_shape) > 4:
74
- samples = samples.reshape(samples.shape[0], samples.shape[1], -1, samples.shape[-2], samples.shape[-1])
75
- samples = samples.movedim(2, 1)
76
- samples = samples.reshape(-1, orig_shape[1], orig_shape[-2], orig_shape[-1])
77
- if crop == "center":
78
- old_width = samples.shape[-1]
79
- old_height = samples.shape[-2]
80
- old_aspect = old_width / old_height
81
- new_aspect = width / height
82
- x = 0
83
- y = 0
84
- if old_aspect > new_aspect:
85
- x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
86
- elif old_aspect < new_aspect:
87
- y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
88
- s = samples.narrow(-2, y, old_height - y * 2).narrow(-1, x, old_width - x * 2)
89
- else:
90
- s = samples
91
 
92
- if upscale_method == "bislerp":
93
- out = bislerp(s, width, height)
94
- elif upscale_method == "lanczos":
95
- out = lanczos(s, width, height)
96
- else:
97
- out = torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
98
 
99
- if len(orig_shape) == 4:
100
- return out
101
 
102
- out = out.reshape((orig_shape[0], -1, orig_shape[1]) + (height, width))
103
- return out.movedim(2, 1).reshape(orig_shape[:-2] + (height, width))
 
 
 
104
 
105
  class WanI2V:
106
 
@@ -227,6 +167,8 @@ class WanI2V:
227
  slg_layers = None,
228
  slg_start = 0.0,
229
  slg_end = 1.0,
 
 
230
  ):
231
  r"""
232
  Generates video frames from input image and text prompt using diffusion process.
@@ -375,7 +317,7 @@ class WanI2V:
375
 
376
  # sample videos
377
  latent = noise
378
-
379
  freqs = get_rotary_pos_embed(latent.shape[1:], enable_RIFLEx= enable_RIFLEx)
380
 
381
  arg_c = {
@@ -456,8 +398,23 @@ class WanI2V:
456
  del latent_model_input
457
  if offload_model:
458
  torch.cuda.empty_cache()
459
- noise_pred = noise_pred_uncond + guide_scale * (
460
- noise_pred_cond - noise_pred_uncond)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
461
  del noise_pred_uncond
462
 
463
  latent = latent.to(
 
28
 
29
  from PIL import Image
30
 
31
+ def optimized_scale(positive_flat, negative_flat):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ # Calculate dot production
34
+ dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
 
 
 
 
35
 
36
+ # Squared norm of uncondition
37
+ squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
38
 
39
+ # st_star = v_cond^T * v_uncond / ||v_uncond||^2
40
+ st_star = dot_product / squared_norm
41
+
42
+ return st_star
43
+
44
 
45
  class WanI2V:
46
 
 
167
  slg_layers = None,
168
  slg_start = 0.0,
169
  slg_end = 1.0,
170
+ cfg_star_switch = True,
171
+ cfg_zero_step = 5,
172
  ):
173
  r"""
174
  Generates video frames from input image and text prompt using diffusion process.
 
317
 
318
  # sample videos
319
  latent = noise
320
+ batch_size = latent.shape[0]
321
  freqs = get_rotary_pos_embed(latent.shape[1:], enable_RIFLEx= enable_RIFLEx)
322
 
323
  arg_c = {
 
398
  del latent_model_input
399
  if offload_model:
400
  torch.cuda.empty_cache()
401
+ # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
402
+ noise_pred_text = noise_pred_cond
403
+ if cfg_star_switch:
404
+ positive_flat = noise_pred_text.view(batch_size, -1)
405
+ negative_flat = noise_pred_uncond.view(batch_size, -1)
406
+
407
+ alpha = optimized_scale(positive_flat,negative_flat)
408
+ alpha = alpha.view(batch_size, 1, 1, 1)
409
+
410
+
411
+ if (i <= cfg_zero_step):
412
+ noise_pred = noise_pred_text*0.
413
+ else:
414
+ noise_pred = noise_pred_uncond * alpha + guide_scale * (noise_pred_text - noise_pred_uncond * alpha)
415
+ else:
416
+ noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond)
417
+
418
  del noise_pred_uncond
419
 
420
  latent = latent.to(
wan/modules/attention.py CHANGED
@@ -70,30 +70,31 @@ def sageattn_wrapper(
70
 
71
  return o
72
 
73
- # # try:
74
  # if True:
75
- # from sageattention import sageattn_qk_int8_pv_fp8_window_cuda
76
- # @torch.compiler.disable()
77
- # def sageattn_window_wrapper(
78
- # qkv_list,
79
- # attention_length,
80
- # window
81
- # ):
82
- # q,k, v = qkv_list
83
- # padding_length = q.shape[0] -attention_length
84
- # q = q[:attention_length, :, : ].unsqueeze(0)
85
- # k = k[:attention_length, :, : ].unsqueeze(0)
86
- # v = v[:attention_length, :, : ].unsqueeze(0)
87
- # o = sageattn_qk_int8_pv_fp8_window_cuda(q, k, v, tensor_layout="NHD", window = window).squeeze(0)
88
- # del q, k ,v
89
- # qkv_list.clear()
90
-
91
- # if padding_length > 0:
92
- # o = torch.cat([o, torch.empty( (padding_length, *o.shape[-2:]), dtype= o.dtype, device=o.device ) ], 0)
93
-
94
- # return o
95
- # # except ImportError:
96
- # # sageattn = sageattn_qk_int8_pv_fp8_window_cuda
 
97
 
98
  @torch.compiler.disable()
99
  def sdpa_wrapper(
@@ -253,17 +254,19 @@ def pay_attention(
253
  # nb_latents = embed_sizes[0] * embed_sizes[1]* embed_sizes[2]
254
 
255
  # window = 0
256
- # start_window_step = int(max_steps * 0.4)
257
  # start_layer = 10
258
- # if (layer < start_layer ) or current_step <start_window_step:
 
259
  # window = 0
260
  # else:
261
- # coef = min((max_steps - current_step)/(max_steps-start_window_step),1)*max(min((25 - layer)/(25-start_layer),1),0) * 0.7 + 0.3
 
262
  # print(f"step: {current_step}, layer: {layer}, coef:{coef:0.1f}]")
263
  # window = math.ceil(coef* nb_latents)
264
 
265
  # invert_spaces = (layer + current_step) % 2 == 0 and window > 0
266
-
267
  # def flip(q):
268
  # q = q.reshape(*embed_sizes, *q.shape[-2:])
269
  # q = q.transpose(0,2)
 
70
 
71
  return o
72
 
73
+ # try:
74
  # if True:
75
+ # from .sage2_core import sageattn_qk_int8_pv_fp8_window_cuda
76
+ # @torch.compiler.disable()
77
+ # def sageattn_window_wrapper(
78
+ # qkv_list,
79
+ # attention_length,
80
+ # window
81
+ # ):
82
+ # q,k, v = qkv_list
83
+ # padding_length = q.shape[0] -attention_length
84
+ # q = q[:attention_length, :, : ].unsqueeze(0)
85
+ # k = k[:attention_length, :, : ].unsqueeze(0)
86
+ # v = v[:attention_length, :, : ].unsqueeze(0)
87
+ # qkvl_list = [q, k , v]
88
+ # del q, k ,v
89
+ # o = sageattn_qk_int8_pv_fp8_window_cuda(qkvl_list, tensor_layout="NHD", window = window).squeeze(0)
90
+ # qkv_list.clear()
91
+
92
+ # if padding_length > 0:
93
+ # o = torch.cat([o, torch.empty( (padding_length, *o.shape[-2:]), dtype= o.dtype, device=o.device ) ], 0)
94
+
95
+ # return o
96
+ # except ImportError:
97
+ # sageattn = sageattn_qk_int8_pv_fp8_window_cuda
98
 
99
  @torch.compiler.disable()
100
  def sdpa_wrapper(
 
254
  # nb_latents = embed_sizes[0] * embed_sizes[1]* embed_sizes[2]
255
 
256
  # window = 0
257
+ # start_window_step = int(max_steps * 0.3)
258
  # start_layer = 10
259
+ # end_layer = 30
260
+ # if (layer < start_layer or layer > end_layer ) or current_step <start_window_step:
261
  # window = 0
262
  # else:
263
+ # # coef = min((max_steps - current_step)/(max_steps-start_window_step),1)*max(min((25 - layer)/(25-start_layer),1),0) * 0.7 + 0.3
264
+ # coef = 0.3
265
  # print(f"step: {current_step}, layer: {layer}, coef:{coef:0.1f}]")
266
  # window = math.ceil(coef* nb_latents)
267
 
268
  # invert_spaces = (layer + current_step) % 2 == 0 and window > 0
269
+ # invert_spaces = False
270
  # def flip(q):
271
  # q = q.reshape(*embed_sizes, *q.shape[-2:])
272
  # q = q.transpose(0,2)
wan/modules/model.py CHANGED
@@ -647,26 +647,6 @@ class WanModel(ModelMixin, ConfigMixin):
647
  self.init_weights()
648
 
649
 
650
- # self.freqs = torch.cat([
651
- # rope_params(1024, d - 4 * (d // 6)), #44
652
- # rope_params(1024, 2 * (d // 6)), #42
653
- # rope_params(1024, 2 * (d // 6)) #42
654
- # ],dim=1)
655
-
656
-
657
- def get_rope_freqs(self, nb_latent_frames, RIFLEx_k = None, device = "cuda"):
658
- dim = self.dim
659
- num_heads = self.num_heads
660
- d = dim // num_heads
661
- assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
662
-
663
-
664
- c1, s1 = rope_params_riflex(1024, dim= d - 4 * (d // 6), L_test=nb_latent_frames, k = RIFLEx_k ) if RIFLEx_k != None else rope_params(1024, dim= d - 4 * (d // 6)) #44
665
- c2, s2 = rope_params(1024, 2 * (d // 6)) #42
666
- c3, s3 = rope_params(1024, 2 * (d // 6)) #42
667
-
668
- return (torch.cat([c1,c2,c3],dim=1).to(device) , torch.cat([s1,s2,s3],dim=1).to(device))
669
-
670
  def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
671
  rescale_func = np.poly1d(self.coefficients)
672
  e_list = []
 
647
  self.init_weights()
648
 
649
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
650
  def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
651
  rescale_func = np.poly1d(self.coefficients)
652
  e_list = []
wan/modules/sage2_core.py CHANGED
@@ -925,11 +925,11 @@ def sageattn_qk_int8_pv_fp8_window_cuda(
925
 
926
  if pv_accum_dtype == "fp32":
927
  if smooth_v:
928
- lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window = window)
929
  else:
930
- lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window = window)
931
  elif pv_accum_dtype == "fp32+fp32":
932
- lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window = window)
933
 
934
  o = o[..., :head_dim_og]
935
 
 
925
 
926
  if pv_accum_dtype == "fp32":
927
  if smooth_v:
928
+ lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window)
929
  else:
930
+ lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window)
931
  elif pv_accum_dtype == "fp32+fp32":
932
+ lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window)
933
 
934
  o = o[..., :head_dim_og]
935
 
wan/text2video.py CHANGED
@@ -24,6 +24,20 @@ from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
24
  from wan.modules.posemb_layers import get_rotary_pos_embed
25
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  class WanT2V:
28
 
29
  def __init__(
@@ -136,6 +150,8 @@ class WanT2V:
136
  slg_layers = None,
137
  slg_start = 0.0,
138
  slg_end = 1.0,
 
 
139
  ):
140
  r"""
141
  Generates video frames from text prompt using diffusion process.
@@ -240,7 +256,7 @@ class WanT2V:
240
 
241
  # sample videos
242
  latents = noise
243
-
244
  freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
245
  arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
246
  arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
@@ -249,7 +265,6 @@ class WanT2V:
249
  # arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
250
  # arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
251
  # arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
252
-
253
  if self.model.enable_teacache:
254
  self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
255
  if callback != None:
@@ -280,8 +295,23 @@ class WanT2V:
280
  return None
281
 
282
  del latent_model_input
283
- noise_pred = noise_pred_uncond + guide_scale * (
284
- noise_pred_cond - noise_pred_uncond)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
  del noise_pred_uncond
286
 
287
  temp_x0 = sample_scheduler.step(
 
24
  from wan.modules.posemb_layers import get_rotary_pos_embed
25
 
26
 
27
+ def optimized_scale(positive_flat, negative_flat):
28
+
29
+ # Calculate dot production
30
+ dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
31
+
32
+ # Squared norm of uncondition
33
+ squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
34
+
35
+ # st_star = v_cond^T * v_uncond / ||v_uncond||^2
36
+ st_star = dot_product / squared_norm
37
+
38
+ return st_star
39
+
40
+
41
  class WanT2V:
42
 
43
  def __init__(
 
150
  slg_layers = None,
151
  slg_start = 0.0,
152
  slg_end = 1.0,
153
+ cfg_star_switch = True,
154
+ cfg_zero_step = 5,
155
  ):
156
  r"""
157
  Generates video frames from text prompt using diffusion process.
 
256
 
257
  # sample videos
258
  latents = noise
259
+ batch_size =len(latents)
260
  freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
261
  arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
262
  arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
 
265
  # arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
266
  # arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
267
  # arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
 
268
  if self.model.enable_teacache:
269
  self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
270
  if callback != None:
 
295
  return None
296
 
297
  del latent_model_input
298
+
299
+ # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
300
+ noise_pred_text = noise_pred_cond
301
+ if cfg_star_switch:
302
+ positive_flat = noise_pred_text.view(batch_size, -1)
303
+ negative_flat = noise_pred_uncond.view(batch_size, -1)
304
+
305
+ alpha = optimized_scale(positive_flat,negative_flat)
306
+ alpha = alpha.view(batch_size, 1, 1, 1)
307
+
308
+
309
+ if (i <= cfg_zero_step):
310
+ noise_pred = noise_pred_text*0.
311
+ else:
312
+ noise_pred = noise_pred_uncond * alpha + guide_scale * (noise_pred_text - noise_pred_uncond * alpha)
313
+ else:
314
+ noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond)
315
  del noise_pred_uncond
316
 
317
  temp_x0 = sample_scheduler.step(