Tophness2022 commited on
Commit
b0bdc4c
·
2 Parent(s): 4fba9c9941388c

Merge remote-tracking branch 'upstream/main' into queues

Browse files
README.md CHANGED
@@ -19,6 +19,12 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
19
 
20
 
21
  ## 🔥 Latest News!!
 
 
 
 
 
 
22
  * Mar 19 2022: 👋 Wan2.1GP v3.1: Faster launch and RAM optimizations (should require less RAM to run)\
23
  You will need one more *pip install -r requirements.txt*
24
  * Mar 18 2022: 👋 Wan2.1GP v3.0:
 
19
 
20
 
21
  ## 🔥 Latest News!!
22
+ * Mar 19 2022: 👋 Wan2.1GP v3.2:
23
+ - Added Classifier-Free Guidance Zero Star. The video should match better the text prompt (especially with text2video) at no performance cost: many thanks to the **CFG Zero * Team:**\
24
+ Dont hesitate to give them a star if you appreciate the results: https://github.com/WeichenFan/CFG-Zero-star
25
+ - Added back support for Pytorch compilation with Loras. It seems it had been broken for some time
26
+ - Added possibility to keep a number of pregenerated videos in the Video Gallery (useful to compare outputs of different settings)
27
+ You will need one more *pip install -r requirements.txt*
28
  * Mar 19 2022: 👋 Wan2.1GP v3.1: Faster launch and RAM optimizations (should require less RAM to run)\
29
  You will need one more *pip install -r requirements.txt*
30
  * Mar 18 2022: 👋 Wan2.1GP v3.0:
gradio_server.py CHANGED
@@ -24,7 +24,7 @@ import asyncio
24
  from wan.utils import prompt_parser
25
  PROMPT_VARS_MAX = 10
26
 
27
- target_mmgp_version = "3.3.3"
28
  from importlib.metadata import version
29
  mmgp_version = version("mmgp")
30
  if mmgp_version != target_mmgp_version:
@@ -98,6 +98,7 @@ def process_prompt_and_add_tasks(
98
  tea_cache_start_step_perc,
99
  loras_choices,
100
  loras_mult_choices,
 
101
  image_to_continue,
102
  image_to_end,
103
  video_to_continue,
@@ -107,6 +108,8 @@ def process_prompt_and_add_tasks(
107
  slg_layers,
108
  slg_start,
109
  slg_end,
 
 
110
  state_arg,
111
  image2video
112
  ):
@@ -138,6 +141,7 @@ def process_prompt_and_add_tasks(
138
  tea_cache_start_step_perc,
139
  loras_choices,
140
  loras_mult_choices,
 
141
  image_to_continue,
142
  image_to_end,
143
  video_to_continue,
@@ -147,6 +151,8 @@ def process_prompt_and_add_tasks(
147
  slg_layers,
148
  slg_start,
149
  slg_end,
 
 
150
  state_arg,
151
  image2video
152
  )
@@ -380,7 +386,6 @@ def _parse_args():
380
  default="",
381
  help="Server name"
382
  )
383
-
384
  parser.add_argument(
385
  "--gpu",
386
  type=str,
@@ -482,7 +487,6 @@ def get_lora_dir(i2v):
482
 
483
  attention_modes_installed = get_attention_modes()
484
  attention_modes_supported = get_supported_attention_modes()
485
-
486
  args = _parse_args()
487
  args.flow_reverse = True
488
 
@@ -513,6 +517,7 @@ if not Path(server_config_filename).is_file():
513
  "metadata_type": "metadata",
514
  "default_ui": "t2v",
515
  "boost" : 1,
 
516
  "vae_config": 0,
517
  "profile" : profile_type.LowRAM_LowVRAM,
518
  "reload_model": 2 }
@@ -596,7 +601,6 @@ if len(args.vae_config) > 0:
596
 
597
  reload_needed = False
598
  default_ui = server_config.get("default_ui", "t2v")
599
- metadata = server_config.get("metadata_type", "metadata")
600
  save_path = server_config.get("save_path", os.path.join(os.getcwd(), "gradio_outputs"))
601
  use_image2video = default_ui != "t2v"
602
  if args.t2v:
@@ -956,6 +960,7 @@ def apply_changes( state,
956
  metadata_choice,
957
  default_ui_choice ="t2v",
958
  boost_choice = 1,
 
959
  reload_choice = 1
960
  ):
961
  if args.lock_config:
@@ -975,6 +980,7 @@ def apply_changes( state,
975
  "metadata_choice": metadata_choice,
976
  "default_ui" : default_ui_choice,
977
  "boost" : boost_choice,
 
978
  "reload_model" : reload_choice,
979
  }
980
 
@@ -1008,7 +1014,7 @@ def apply_changes( state,
1008
  text_encoder_filename = server_config["text_encoder_filename"]
1009
  vae_config = server_config["vae_config"]
1010
  boost = server_config["boost"]
1011
- if all(change in ["attention_mode", "vae_config", "default_ui", "boost", "save_path", "metadata_choice"] for change in changes ):
1012
  pass
1013
  else:
1014
  reload_needed = True
@@ -1059,6 +1065,10 @@ def finalize_gallery(state):
1059
  if "in_progress" in state:
1060
  del state["in_progress"]
1061
  choice = state.get("selected",0)
 
 
 
 
1062
  time.sleep(0.2)
1063
  global gen_in_progress
1064
  gen_in_progress = False
@@ -1096,6 +1106,7 @@ def generate_video(
1096
  tea_cache_start_step_perc,
1097
  loras_choices,
1098
  loras_mult_choices,
 
1099
  image_to_continue,
1100
  image_to_end,
1101
  video_to_continue,
@@ -1104,7 +1115,9 @@ def generate_video(
1104
  slg_switch,
1105
  slg_layers,
1106
  slg_start,
1107
- slg_end,
 
 
1108
  state,
1109
  image2video,
1110
  progress=gr.Progress() #track_tqdm= True
@@ -1243,7 +1256,6 @@ def generate_video(
1243
  if "abort" in state:
1244
  del state["abort"]
1245
  state["in_progress"] = True
1246
- state["selected"] = 0
1247
 
1248
  enable_RIFLEx = RIFLEx_setting == 0 and video_length > (6* 16) or RIFLEx_setting == 1
1249
  # VAE Tiling
@@ -1281,17 +1293,25 @@ def generate_video(
1281
  seed = random.randint(0, 999999999)
1282
 
1283
  global file_list
1284
- state["file_list"] = file_list
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1285
  global save_path
1286
  os.makedirs(save_path, exist_ok=True)
1287
- abort = False
1288
- if trans.enable_teacache:
1289
- trans.teacache_counter = 0
1290
- trans.num_steps = num_inference_steps
1291
- trans.teacache_skipped_steps = 0
1292
- trans.previous_residual_uncond = None
1293
- trans.previous_residual_cond = None
1294
  video_no = 0
 
1295
  repeats = f"{video_no}/{repeat_generation}"
1296
  callback = build_callback(task_id, state, trans, num_inference_steps, repeats)
1297
  offload.shared_state["callback"] = callback
@@ -1301,14 +1321,22 @@ def generate_video(
1301
  for i in range(repeat_generation):
1302
  try:
1303
  with tracker_lock:
 
1304
  progress_tracker[task_id] = {
1305
  'current_step': 0,
1306
  'total_steps': num_inference_steps,
1307
- 'start_time': time.time(),
1308
- 'last_update': time.time(),
1309
  'repeats': f"{video_no}/{repeat_generation}",
1310
  'status': "Encoding Prompt"
1311
  }
 
 
 
 
 
 
 
1312
  video_no += 1
1313
  if image2video:
1314
  samples = wan_model.generate(
@@ -1330,6 +1358,8 @@ def generate_video(
1330
  slg_layers = slg_layers,
1331
  slg_start = slg_start/100,
1332
  slg_end = slg_end/100,
 
 
1333
  )
1334
  else:
1335
  samples = wan_model.generate(
@@ -1349,13 +1379,15 @@ def generate_video(
1349
  slg_layers = slg_layers,
1350
  slg_start = slg_start/100,
1351
  slg_end = slg_end/100,
 
 
1352
  )
1353
  except Exception as e:
1354
  gen_in_progress = False
1355
  if temp_filename!= None and os.path.isfile(temp_filename):
1356
  os.remove(temp_filename)
1357
- if(offload.last_offload_obj): offload.last_offload_obj.unload_all()
1358
- if(trans): offload.unload_loras_from_model(trans)
1359
  # if compile:
1360
  # cache_size = torch._dynamo.config.cache_size_limit
1361
  # torch.compiler.reset()
@@ -1399,6 +1431,7 @@ def generate_video(
1399
  end_time = time.time()
1400
  abort = True
1401
  state["prompt"] = ""
 
1402
  else:
1403
  sample = samples.cpu()
1404
  # video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")
@@ -1416,9 +1449,9 @@ def generate_video(
1416
  nrow=1,
1417
  normalize=True,
1418
  value_range=(-1, 1))
1419
-
1420
  configs = get_settings_dict(state, use_image2video, prompt, 0 if image_to_end == None else 1 , video_length, raw_resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
1421
- loras_mult_choices, tea_cache , tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end)
1422
 
1423
  metadata_choice = server_config.get("metadata_choice","metadata")
1424
  if metadata_choice == "json":
@@ -1432,7 +1465,15 @@ def generate_video(
1432
 
1433
  print(f"New video saved to Path: "+video_path)
1434
  file_list.append(video_path)
 
 
 
 
 
 
1435
  seed += 1
 
 
1436
  last_model_type = image2video
1437
 
1438
  if temp_filename!= None and os.path.isfile(temp_filename):
@@ -1725,8 +1766,9 @@ def switch_advanced(state, new_advanced, lset_name):
1725
  else:
1726
  return gr.Row(visible=new_advanced), gr.Row(visible=True), gr.Button(visible=True), gr.Row(visible= False), gr.Dropdown(choices=lset_choices, value= lset_name)
1727
 
 
1728
  def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
1729
- loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc):
1730
 
1731
  loras = state["loras"]
1732
  activated_loras = [Path( loras[int(no)]).parts[-1] for no in loras_choices ]
@@ -1750,7 +1792,9 @@ def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resol
1750
  "slg_switch": slg_switch,
1751
  "slg_layers": slg_layers,
1752
  "slg_start_perc": slg_start_perc,
1753
- "slg_end_perc": slg_end_perc
 
 
1754
  }
1755
 
1756
  if i2v:
@@ -1758,21 +1802,22 @@ def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resol
1758
  ui_settings["image_prompt_type"] = image_prompt_type
1759
  else:
1760
  ui_settings["type"] = "Wan2.1GP by DeepBeepMeep - text2video"
 
1761
  return ui_settings
1762
 
1763
  def save_settings(state, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
1764
- loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc):
1765
 
1766
  if state.get("validate_success",0) != 1:
1767
  return
1768
 
1769
  ui_defaults = get_settings_dict(state, use_image2video, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
1770
- loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc)
1771
 
1772
  defaults_filename = get_settings_file_name(use_image2video)
1773
 
1774
  with open(defaults_filename, "w", encoding="utf-8") as f:
1775
- json.dump(ui_defaults , f, indent=4)
1776
 
1777
  gr.Info("New Default Settings saved")
1778
 
@@ -1907,6 +1952,7 @@ def generate_video_tab(image2video=False):
1907
  return gr.Gallery(visible = (image_prompt_type_radio == 1) )
1908
  else:
1909
  return gr.Image(visible = (image_prompt_type_radio == 1) )
 
1910
  image_prompt_type_radio.change(fn=switch_image_prompt_type_radio, inputs=[image_prompt_type_radio], outputs=[image_to_end])
1911
 
1912
 
@@ -2037,7 +2083,7 @@ def generate_video_tab(image2video=False):
2037
  label="RIFLEx positional embedding to generate long video"
2038
  )
2039
  with gr.Row():
2040
- gr.Markdown("<B>Experimental: Skip Layer guidance,should improve video quality</B>")
2041
  with gr.Row():
2042
  slg_switch = gr.Dropdown(
2043
  choices=[
@@ -2061,6 +2107,23 @@ def generate_video_tab(image2video=False):
2061
  with gr.Row():
2062
  slg_start_perc = gr.Slider(0, 100, value=ui_defaults["slg_start_perc"], step=1, label="Denoising Steps % start")
2063
  slg_end_perc = gr.Slider(0, 100, value=ui_defaults["slg_end_perc"], step=1, label="Denoising Steps % end")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2064
  with gr.Row():
2065
  save_settings_btn = gr.Button("Set Settings as Default")
2066
  show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
@@ -2103,7 +2166,7 @@ def generate_video_tab(image2video=False):
2103
  save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
2104
  save_settings, inputs = [state, prompt, image_prompt_type_radio, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt,
2105
  loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers,
2106
- slg_start_perc, slg_end_perc ], outputs = [])
2107
  save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
2108
  confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
2109
  save_lset, inputs=[state, lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn,refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
@@ -2133,6 +2196,7 @@ def generate_video_tab(image2video=False):
2133
  tea_cache_start_step_perc,
2134
  loras_choices,
2135
  loras_mult_choices,
 
2136
  image_to_continue,
2137
  image_to_end,
2138
  video_to_continue,
@@ -2142,6 +2206,8 @@ def generate_video_tab(image2video=False):
2142
  slg_layers,
2143
  slg_start_perc,
2144
  slg_end_perc,
 
 
2145
  state,
2146
  gr.State(image2video)
2147
  ]
@@ -2276,7 +2342,7 @@ def generate_configuration_tab():
2276
  ("Add metadata to video", "metadata"),
2277
  ("Neither", "none")
2278
  ],
2279
- value=metadata,
2280
  label="Metadata Handling"
2281
  )
2282
  reload_choice = gr.Dropdown(
@@ -2287,6 +2353,21 @@ def generate_configuration_tab():
2287
  value=server_config.get("reload_model",2),
2288
  label="Reload model"
2289
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2290
  msg = gr.Markdown()
2291
  apply_btn = gr.Button("Apply Changes")
2292
  apply_btn.click(
@@ -2304,6 +2385,7 @@ def generate_configuration_tab():
2304
  metadata_choice,
2305
  default_ui_choice,
2306
  boost_choice,
 
2307
  reload_choice,
2308
  ],
2309
  outputs= msg
@@ -2322,6 +2404,9 @@ def generate_about_tab():
2322
  def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
2323
  global lora_model_filename, use_image2video
2324
 
 
 
 
2325
  new_t2v = evt.index == 0
2326
  new_i2v = evt.index == 1
2327
  use_image2video = new_i2v
@@ -2341,9 +2426,6 @@ def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
2341
  wan_model, offloadobj, trans = load_models(use_image2video)
2342
  del trans
2343
 
2344
- t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode)
2345
- i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode)
2346
-
2347
  if new_t2v:
2348
  lora_model_filename = t2v_state["loras_model"]
2349
  if ("1.3B" in transformer_filename_t2v and not "1.3B" in lora_model_filename or "14B" in transformer_filename_t2v and not "14B" in lora_model_filename):
@@ -2470,7 +2552,7 @@ def create_demo():
2470
  }
2471
  """
2472
  with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md")) as demo:
2473
- gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v3.1 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
2474
  gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !</FONT>")
2475
 
2476
  with gr.Accordion("Click here for some Info on how to use Wan2GP", open = False):
 
24
  from wan.utils import prompt_parser
25
  PROMPT_VARS_MAX = 10
26
 
27
+ target_mmgp_version = "3.3.4"
28
  from importlib.metadata import version
29
  mmgp_version = version("mmgp")
30
  if mmgp_version != target_mmgp_version:
 
98
  tea_cache_start_step_perc,
99
  loras_choices,
100
  loras_mult_choices,
101
+ image_prompt_type,
102
  image_to_continue,
103
  image_to_end,
104
  video_to_continue,
 
108
  slg_layers,
109
  slg_start,
110
  slg_end,
111
+ cfg_star_switch,
112
+ cfg_zero_step,
113
  state_arg,
114
  image2video
115
  ):
 
141
  tea_cache_start_step_perc,
142
  loras_choices,
143
  loras_mult_choices,
144
+ image_prompt_type,
145
  image_to_continue,
146
  image_to_end,
147
  video_to_continue,
 
151
  slg_layers,
152
  slg_start,
153
  slg_end,
154
+ cfg_star_switch,
155
+ cfg_zero_step,
156
  state_arg,
157
  image2video
158
  )
 
386
  default="",
387
  help="Server name"
388
  )
 
389
  parser.add_argument(
390
  "--gpu",
391
  type=str,
 
487
 
488
  attention_modes_installed = get_attention_modes()
489
  attention_modes_supported = get_supported_attention_modes()
 
490
  args = _parse_args()
491
  args.flow_reverse = True
492
 
 
517
  "metadata_type": "metadata",
518
  "default_ui": "t2v",
519
  "boost" : 1,
520
+ "clear_file_list" : 0,
521
  "vae_config": 0,
522
  "profile" : profile_type.LowRAM_LowVRAM,
523
  "reload_model": 2 }
 
601
 
602
  reload_needed = False
603
  default_ui = server_config.get("default_ui", "t2v")
 
604
  save_path = server_config.get("save_path", os.path.join(os.getcwd(), "gradio_outputs"))
605
  use_image2video = default_ui != "t2v"
606
  if args.t2v:
 
960
  metadata_choice,
961
  default_ui_choice ="t2v",
962
  boost_choice = 1,
963
+ clear_file_list = 0,
964
  reload_choice = 1
965
  ):
966
  if args.lock_config:
 
980
  "metadata_choice": metadata_choice,
981
  "default_ui" : default_ui_choice,
982
  "boost" : boost_choice,
983
+ "clear_file_list" : clear_file_list
984
  "reload_model" : reload_choice,
985
  }
986
 
 
1014
  text_encoder_filename = server_config["text_encoder_filename"]
1015
  vae_config = server_config["vae_config"]
1016
  boost = server_config["boost"]
1017
+ if all(change in ["attention_mode", "vae_config", "default_ui", "boost", "save_path", "metadata_choice", "clear_file_list"] for change in changes ):
1018
  pass
1019
  else:
1020
  reload_needed = True
 
1065
  if "in_progress" in state:
1066
  del state["in_progress"]
1067
  choice = state.get("selected",0)
1068
+ # file_list = state.get("file_list", [])
1069
+
1070
+
1071
+ state["extra_orders"] = 0
1072
  time.sleep(0.2)
1073
  global gen_in_progress
1074
  gen_in_progress = False
 
1106
  tea_cache_start_step_perc,
1107
  loras_choices,
1108
  loras_mult_choices,
1109
+ image_prompt_type,
1110
  image_to_continue,
1111
  image_to_end,
1112
  video_to_continue,
 
1115
  slg_switch,
1116
  slg_layers,
1117
  slg_start,
1118
+ slg_end,
1119
+ cfg_star_switch,
1120
+ cfg_zero_step,
1121
  state,
1122
  image2video,
1123
  progress=gr.Progress() #track_tqdm= True
 
1256
  if "abort" in state:
1257
  del state["abort"]
1258
  state["in_progress"] = True
 
1259
 
1260
  enable_RIFLEx = RIFLEx_setting == 0 and video_length > (6* 16) or RIFLEx_setting == 1
1261
  # VAE Tiling
 
1293
  seed = random.randint(0, 999999999)
1294
 
1295
  global file_list
1296
+ clear_file_list = server_config.get("clear_file_list", 0)
1297
+ file_list = state.get("file_list", [])
1298
+ if clear_file_list > 0:
1299
+ file_list_current_size = len(file_list)
1300
+ keep_file_from = max(file_list_current_size - clear_file_list, 0)
1301
+ files_removed = keep_file_from
1302
+ choice = state.get("selected",0)
1303
+ choice = max(choice- files_removed, 0)
1304
+ file_list = file_list[ keep_file_from: ]
1305
+ else:
1306
+ file_list = []
1307
+ choice = 0
1308
+ state["selected"] = choice
1309
+ state["file_list"] = file_list
1310
+
1311
  global save_path
1312
  os.makedirs(save_path, exist_ok=True)
 
 
 
 
 
 
 
1313
  video_no = 0
1314
+ abort = False
1315
  repeats = f"{video_no}/{repeat_generation}"
1316
  callback = build_callback(task_id, state, trans, num_inference_steps, repeats)
1317
  offload.shared_state["callback"] = callback
 
1321
  for i in range(repeat_generation):
1322
  try:
1323
  with tracker_lock:
1324
+ start_time = time.time()
1325
  progress_tracker[task_id] = {
1326
  'current_step': 0,
1327
  'total_steps': num_inference_steps,
1328
+ 'start_time': start_time,
1329
+ 'last_update': start_time,
1330
  'repeats': f"{video_no}/{repeat_generation}",
1331
  'status': "Encoding Prompt"
1332
  }
1333
+ if trans.enable_teacache:
1334
+ trans.teacache_counter = 0
1335
+ trans.num_steps = num_inference_steps
1336
+ trans.teacache_skipped_steps = 0
1337
+ trans.previous_residual_uncond = None
1338
+ trans.previous_residual_cond = None
1339
+
1340
  video_no += 1
1341
  if image2video:
1342
  samples = wan_model.generate(
 
1358
  slg_layers = slg_layers,
1359
  slg_start = slg_start/100,
1360
  slg_end = slg_end/100,
1361
+ cfg_star_switch = cfg_star_switch,
1362
+ cfg_zero_step = cfg_zero_step,
1363
  )
1364
  else:
1365
  samples = wan_model.generate(
 
1379
  slg_layers = slg_layers,
1380
  slg_start = slg_start/100,
1381
  slg_end = slg_end/100,
1382
+ cfg_star_switch = cfg_star_switch,
1383
+ cfg_zero_step = cfg_zero_step,
1384
  )
1385
  except Exception as e:
1386
  gen_in_progress = False
1387
  if temp_filename!= None and os.path.isfile(temp_filename):
1388
  os.remove(temp_filename)
1389
+ offload.last_offload_obj.unload_all()
1390
+ offload.unload_loras_from_model(trans)
1391
  # if compile:
1392
  # cache_size = torch._dynamo.config.cache_size_limit
1393
  # torch.compiler.reset()
 
1431
  end_time = time.time()
1432
  abort = True
1433
  state["prompt"] = ""
1434
+ yield f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s"
1435
  else:
1436
  sample = samples.cpu()
1437
  # video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")
 
1449
  nrow=1,
1450
  normalize=True,
1451
  value_range=(-1, 1))
1452
+
1453
  configs = get_settings_dict(state, use_image2video, prompt, 0 if image_to_end == None else 1 , video_length, raw_resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
1454
+ loras_mult_choices, tea_cache , tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end, cfg_star_switch, cfg_zero_step)
1455
 
1456
  metadata_choice = server_config.get("metadata_choice","metadata")
1457
  if metadata_choice == "json":
 
1465
 
1466
  print(f"New video saved to Path: "+video_path)
1467
  file_list.append(video_path)
1468
+ if video_no < total_video:
1469
+ yield status
1470
+ else:
1471
+ end_time = time.time()
1472
+ state["prompt"] = ""
1473
+ yield f"Total Generation Time: {end_time-start_time:.1f}s"
1474
  seed += 1
1475
+ repeat_no += 1
1476
+
1477
  last_model_type = image2video
1478
 
1479
  if temp_filename!= None and os.path.isfile(temp_filename):
 
1766
  else:
1767
  return gr.Row(visible=new_advanced), gr.Row(visible=True), gr.Button(visible=True), gr.Row(visible= False), gr.Dropdown(choices=lset_choices, value= lset_name)
1768
 
1769
+
1770
  def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
1771
+ loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step):
1772
 
1773
  loras = state["loras"]
1774
  activated_loras = [Path( loras[int(no)]).parts[-1] for no in loras_choices ]
 
1792
  "slg_switch": slg_switch,
1793
  "slg_layers": slg_layers,
1794
  "slg_start_perc": slg_start_perc,
1795
+ "slg_end_perc": slg_end_perc,
1796
+ "cfg_star_switch": cfg_star_switch,
1797
+ "cfg_zero_step": cfg_zero_step
1798
  }
1799
 
1800
  if i2v:
 
1802
  ui_settings["image_prompt_type"] = image_prompt_type
1803
  else:
1804
  ui_settings["type"] = "Wan2.1GP by DeepBeepMeep - text2video"
1805
+
1806
  return ui_settings
1807
 
1808
  def save_settings(state, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
1809
+ loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step):
1810
 
1811
  if state.get("validate_success",0) != 1:
1812
  return
1813
 
1814
  ui_defaults = get_settings_dict(state, use_image2video, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
1815
+ loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step)
1816
 
1817
  defaults_filename = get_settings_file_name(use_image2video)
1818
 
1819
  with open(defaults_filename, "w", encoding="utf-8") as f:
1820
+ json.dump(ui_defaults, f, indent=4)
1821
 
1822
  gr.Info("New Default Settings saved")
1823
 
 
1952
  return gr.Gallery(visible = (image_prompt_type_radio == 1) )
1953
  else:
1954
  return gr.Image(visible = (image_prompt_type_radio == 1) )
1955
+
1956
  image_prompt_type_radio.change(fn=switch_image_prompt_type_radio, inputs=[image_prompt_type_radio], outputs=[image_to_end])
1957
 
1958
 
 
2083
  label="RIFLEx positional embedding to generate long video"
2084
  )
2085
  with gr.Row():
2086
+ gr.Markdown("<B>Experimental: Skip Layer Guidance, should improve video quality</B>")
2087
  with gr.Row():
2088
  slg_switch = gr.Dropdown(
2089
  choices=[
 
2107
  with gr.Row():
2108
  slg_start_perc = gr.Slider(0, 100, value=ui_defaults["slg_start_perc"], step=1, label="Denoising Steps % start")
2109
  slg_end_perc = gr.Slider(0, 100, value=ui_defaults["slg_end_perc"], step=1, label="Denoising Steps % end")
2110
+
2111
+ with gr.Row():
2112
+ gr.Markdown("<B>Experimental: Classifier-Free Guidance Zero Star, better adherence to Text Prompt")
2113
+ with gr.Row():
2114
+ cfg_star_switch = gr.Dropdown(
2115
+ choices=[
2116
+ ("OFF", 0),
2117
+ ("ON", 1),
2118
+ ],
2119
+ value=ui_defaults.get("cfg_star_switch",0),
2120
+ visible=True,
2121
+ scale = 1,
2122
+ label="CFG Star"
2123
+ )
2124
+ with gr.Row():
2125
+ cfg_zero_step = gr.Slider(-1, 39, value=ui_defaults.get("cfg_zero_step",-1), step=1, label="CFG Zero below this Layer (Extra Process)")
2126
+
2127
  with gr.Row():
2128
  save_settings_btn = gr.Button("Set Settings as Default")
2129
  show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
 
2166
  save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
2167
  save_settings, inputs = [state, prompt, image_prompt_type_radio, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt,
2168
  loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers,
2169
+ slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step ], outputs = [])
2170
  save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
2171
  confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
2172
  save_lset, inputs=[state, lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn,refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
 
2196
  tea_cache_start_step_perc,
2197
  loras_choices,
2198
  loras_mult_choices,
2199
+ image_prompt_type_radio,
2200
  image_to_continue,
2201
  image_to_end,
2202
  video_to_continue,
 
2206
  slg_layers,
2207
  slg_start_perc,
2208
  slg_end_perc,
2209
+ cfg_star_switch,
2210
+ cfg_zero_step,
2211
  state,
2212
  gr.State(image2video)
2213
  ]
 
2342
  ("Add metadata to video", "metadata"),
2343
  ("Neither", "none")
2344
  ],
2345
+ value=server_config.get("metadata_type", "metadata"),
2346
  label="Metadata Handling"
2347
  )
2348
  reload_choice = gr.Dropdown(
 
2353
  value=server_config.get("reload_model",2),
2354
  label="Reload model"
2355
  )
2356
+
2357
+ clear_file_list_choice = gr.Dropdown(
2358
+ choices=[
2359
+ ("None", 0),
2360
+ ("Keep the last video", 1),
2361
+ ("Keep the last 5 videos", 5),
2362
+ ("Keep the last 10 videos", 10),
2363
+ ("Keep the last 20 videos", 20),
2364
+ ("Keep the last 30 videos", 30),
2365
+ ],
2366
+ value=server_config.get("clear_file_list", 0),
2367
+ label="Keep Previously Generated Videos when starting a Generation Batch"
2368
+ )
2369
+
2370
+
2371
  msg = gr.Markdown()
2372
  apply_btn = gr.Button("Apply Changes")
2373
  apply_btn.click(
 
2385
  metadata_choice,
2386
  default_ui_choice,
2387
  boost_choice,
2388
+ clear_file_list_choice,
2389
  reload_choice,
2390
  ],
2391
  outputs= msg
 
2404
  def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
2405
  global lora_model_filename, use_image2video
2406
 
2407
+ t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode)
2408
+ i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode)
2409
+
2410
  new_t2v = evt.index == 0
2411
  new_i2v = evt.index == 1
2412
  use_image2video = new_i2v
 
2426
  wan_model, offloadobj, trans = load_models(use_image2video)
2427
  del trans
2428
 
 
 
 
2429
  if new_t2v:
2430
  lora_model_filename = t2v_state["loras_model"]
2431
  if ("1.3B" in transformer_filename_t2v and not "1.3B" in lora_model_filename or "14B" in transformer_filename_t2v and not "14B" in lora_model_filename):
 
2552
  }
2553
  """
2554
  with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md")) as demo:
2555
+ gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v3.2 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
2556
  gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !</FONT>")
2557
 
2558
  with gr.Accordion("Click here for some Info on how to use Wan2GP", open = False):
requirements.txt CHANGED
@@ -16,6 +16,6 @@ gradio>=5.0.0
16
  numpy>=1.23.5,<2
17
  einops
18
  moviepy==1.0.3
19
- mmgp==3.3.3
20
  peft==0.14.0
21
  mutagen
 
16
  numpy>=1.23.5,<2
17
  einops
18
  moviepy==1.0.3
19
+ mmgp==3.3.4
20
  peft==0.14.0
21
  mutagen
wan/image2video.py CHANGED
@@ -28,79 +28,19 @@ from wan.modules.posemb_layers import get_rotary_pos_embed
28
 
29
  from PIL import Image
30
 
31
- def lanczos(samples, width, height):
32
- images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
33
- images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
34
- images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
35
- result = torch.stack(images)
36
- return result.to(samples.device, samples.dtype)
37
-
38
- def bislerp(samples, width, height):
39
- def slerp(b1, b2, r):
40
- '''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
41
-
42
- c = b1.shape[-1]
43
-
44
- #norms
45
- b1_norms = torch.norm(b1, dim=-1, keepdim=True)
46
- b2_norms = torch.norm(b2, dim=-1, keepdim=True)
47
-
48
- #normalize
49
- b1_normalized = b1 / b1_norms
50
- b2_normalized = b2 / b2_norms
51
-
52
- #zero when norms are zero
53
- b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0
54
- b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0
55
-
56
- #slerp
57
- dot = (b1_normalized*b2_normalized).sum(1)
58
- omega = torch.acos(dot)
59
- so = torch.sin(omega)
60
-
61
- #technically not mathematically correct, but more pleasing?
62
- res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized
63
- res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c)
64
-
65
- #edge cases for same or polar opposites
66
- res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
67
- res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
68
- return res
69
-
70
-
71
- def common_upscale(samples, width, height, upscale_method, crop):
72
- orig_shape = tuple(samples.shape)
73
- if len(orig_shape) > 4:
74
- samples = samples.reshape(samples.shape[0], samples.shape[1], -1, samples.shape[-2], samples.shape[-1])
75
- samples = samples.movedim(2, 1)
76
- samples = samples.reshape(-1, orig_shape[1], orig_shape[-2], orig_shape[-1])
77
- if crop == "center":
78
- old_width = samples.shape[-1]
79
- old_height = samples.shape[-2]
80
- old_aspect = old_width / old_height
81
- new_aspect = width / height
82
- x = 0
83
- y = 0
84
- if old_aspect > new_aspect:
85
- x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
86
- elif old_aspect < new_aspect:
87
- y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
88
- s = samples.narrow(-2, y, old_height - y * 2).narrow(-1, x, old_width - x * 2)
89
- else:
90
- s = samples
91
 
92
- if upscale_method == "bislerp":
93
- out = bislerp(s, width, height)
94
- elif upscale_method == "lanczos":
95
- out = lanczos(s, width, height)
96
- else:
97
- out = torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
98
 
99
- if len(orig_shape) == 4:
100
- return out
101
 
102
- out = out.reshape((orig_shape[0], -1, orig_shape[1]) + (height, width))
103
- return out.movedim(2, 1).reshape(orig_shape[:-2] + (height, width))
 
 
 
104
 
105
  class WanI2V:
106
 
@@ -227,6 +167,8 @@ class WanI2V:
227
  slg_layers = None,
228
  slg_start = 0.0,
229
  slg_end = 1.0,
 
 
230
  ):
231
  r"""
232
  Generates video frames from input image and text prompt using diffusion process.
@@ -375,7 +317,7 @@ class WanI2V:
375
 
376
  # sample videos
377
  latent = noise
378
-
379
  freqs = get_rotary_pos_embed(latent.shape[1:], enable_RIFLEx= enable_RIFLEx)
380
 
381
  arg_c = {
@@ -456,8 +398,23 @@ class WanI2V:
456
  del latent_model_input
457
  if offload_model:
458
  torch.cuda.empty_cache()
459
- noise_pred = noise_pred_uncond + guide_scale * (
460
- noise_pred_cond - noise_pred_uncond)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
461
  del noise_pred_uncond
462
 
463
  latent = latent.to(
 
28
 
29
  from PIL import Image
30
 
31
+ def optimized_scale(positive_flat, negative_flat):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ # Calculate dot production
34
+ dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
 
 
 
 
35
 
36
+ # Squared norm of uncondition
37
+ squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
38
 
39
+ # st_star = v_cond^T * v_uncond / ||v_uncond||^2
40
+ st_star = dot_product / squared_norm
41
+
42
+ return st_star
43
+
44
 
45
  class WanI2V:
46
 
 
167
  slg_layers = None,
168
  slg_start = 0.0,
169
  slg_end = 1.0,
170
+ cfg_star_switch = True,
171
+ cfg_zero_step = 5,
172
  ):
173
  r"""
174
  Generates video frames from input image and text prompt using diffusion process.
 
317
 
318
  # sample videos
319
  latent = noise
320
+ batch_size = latent.shape[0]
321
  freqs = get_rotary_pos_embed(latent.shape[1:], enable_RIFLEx= enable_RIFLEx)
322
 
323
  arg_c = {
 
398
  del latent_model_input
399
  if offload_model:
400
  torch.cuda.empty_cache()
401
+ # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
402
+ noise_pred_text = noise_pred_cond
403
+ if cfg_star_switch:
404
+ positive_flat = noise_pred_text.view(batch_size, -1)
405
+ negative_flat = noise_pred_uncond.view(batch_size, -1)
406
+
407
+ alpha = optimized_scale(positive_flat,negative_flat)
408
+ alpha = alpha.view(batch_size, 1, 1, 1)
409
+
410
+
411
+ if (i <= cfg_zero_step):
412
+ noise_pred = noise_pred_text*0.
413
+ else:
414
+ noise_pred = noise_pred_uncond * alpha + guide_scale * (noise_pred_text - noise_pred_uncond * alpha)
415
+ else:
416
+ noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond)
417
+
418
  del noise_pred_uncond
419
 
420
  latent = latent.to(
wan/modules/attention.py CHANGED
@@ -70,30 +70,31 @@ def sageattn_wrapper(
70
 
71
  return o
72
 
73
- # # try:
74
  # if True:
75
- # from sageattention import sageattn_qk_int8_pv_fp8_window_cuda
76
- # @torch.compiler.disable()
77
- # def sageattn_window_wrapper(
78
- # qkv_list,
79
- # attention_length,
80
- # window
81
- # ):
82
- # q,k, v = qkv_list
83
- # padding_length = q.shape[0] -attention_length
84
- # q = q[:attention_length, :, : ].unsqueeze(0)
85
- # k = k[:attention_length, :, : ].unsqueeze(0)
86
- # v = v[:attention_length, :, : ].unsqueeze(0)
87
- # o = sageattn_qk_int8_pv_fp8_window_cuda(q, k, v, tensor_layout="NHD", window = window).squeeze(0)
88
- # del q, k ,v
89
- # qkv_list.clear()
90
-
91
- # if padding_length > 0:
92
- # o = torch.cat([o, torch.empty( (padding_length, *o.shape[-2:]), dtype= o.dtype, device=o.device ) ], 0)
93
-
94
- # return o
95
- # # except ImportError:
96
- # # sageattn = sageattn_qk_int8_pv_fp8_window_cuda
 
97
 
98
  @torch.compiler.disable()
99
  def sdpa_wrapper(
@@ -253,17 +254,19 @@ def pay_attention(
253
  # nb_latents = embed_sizes[0] * embed_sizes[1]* embed_sizes[2]
254
 
255
  # window = 0
256
- # start_window_step = int(max_steps * 0.4)
257
  # start_layer = 10
258
- # if (layer < start_layer ) or current_step <start_window_step:
 
259
  # window = 0
260
  # else:
261
- # coef = min((max_steps - current_step)/(max_steps-start_window_step),1)*max(min((25 - layer)/(25-start_layer),1),0) * 0.7 + 0.3
 
262
  # print(f"step: {current_step}, layer: {layer}, coef:{coef:0.1f}]")
263
  # window = math.ceil(coef* nb_latents)
264
 
265
  # invert_spaces = (layer + current_step) % 2 == 0 and window > 0
266
-
267
  # def flip(q):
268
  # q = q.reshape(*embed_sizes, *q.shape[-2:])
269
  # q = q.transpose(0,2)
 
70
 
71
  return o
72
 
73
+ # try:
74
  # if True:
75
+ # from .sage2_core import sageattn_qk_int8_pv_fp8_window_cuda
76
+ # @torch.compiler.disable()
77
+ # def sageattn_window_wrapper(
78
+ # qkv_list,
79
+ # attention_length,
80
+ # window
81
+ # ):
82
+ # q,k, v = qkv_list
83
+ # padding_length = q.shape[0] -attention_length
84
+ # q = q[:attention_length, :, : ].unsqueeze(0)
85
+ # k = k[:attention_length, :, : ].unsqueeze(0)
86
+ # v = v[:attention_length, :, : ].unsqueeze(0)
87
+ # qkvl_list = [q, k , v]
88
+ # del q, k ,v
89
+ # o = sageattn_qk_int8_pv_fp8_window_cuda(qkvl_list, tensor_layout="NHD", window = window).squeeze(0)
90
+ # qkv_list.clear()
91
+
92
+ # if padding_length > 0:
93
+ # o = torch.cat([o, torch.empty( (padding_length, *o.shape[-2:]), dtype= o.dtype, device=o.device ) ], 0)
94
+
95
+ # return o
96
+ # except ImportError:
97
+ # sageattn = sageattn_qk_int8_pv_fp8_window_cuda
98
 
99
  @torch.compiler.disable()
100
  def sdpa_wrapper(
 
254
  # nb_latents = embed_sizes[0] * embed_sizes[1]* embed_sizes[2]
255
 
256
  # window = 0
257
+ # start_window_step = int(max_steps * 0.3)
258
  # start_layer = 10
259
+ # end_layer = 30
260
+ # if (layer < start_layer or layer > end_layer ) or current_step <start_window_step:
261
  # window = 0
262
  # else:
263
+ # # coef = min((max_steps - current_step)/(max_steps-start_window_step),1)*max(min((25 - layer)/(25-start_layer),1),0) * 0.7 + 0.3
264
+ # coef = 0.3
265
  # print(f"step: {current_step}, layer: {layer}, coef:{coef:0.1f}]")
266
  # window = math.ceil(coef* nb_latents)
267
 
268
  # invert_spaces = (layer + current_step) % 2 == 0 and window > 0
269
+ # invert_spaces = False
270
  # def flip(q):
271
  # q = q.reshape(*embed_sizes, *q.shape[-2:])
272
  # q = q.transpose(0,2)
wan/modules/model.py CHANGED
@@ -647,26 +647,6 @@ class WanModel(ModelMixin, ConfigMixin):
647
  self.init_weights()
648
 
649
 
650
- # self.freqs = torch.cat([
651
- # rope_params(1024, d - 4 * (d // 6)), #44
652
- # rope_params(1024, 2 * (d // 6)), #42
653
- # rope_params(1024, 2 * (d // 6)) #42
654
- # ],dim=1)
655
-
656
-
657
- def get_rope_freqs(self, nb_latent_frames, RIFLEx_k = None, device = "cuda"):
658
- dim = self.dim
659
- num_heads = self.num_heads
660
- d = dim // num_heads
661
- assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
662
-
663
-
664
- c1, s1 = rope_params_riflex(1024, dim= d - 4 * (d // 6), L_test=nb_latent_frames, k = RIFLEx_k ) if RIFLEx_k != None else rope_params(1024, dim= d - 4 * (d // 6)) #44
665
- c2, s2 = rope_params(1024, 2 * (d // 6)) #42
666
- c3, s3 = rope_params(1024, 2 * (d // 6)) #42
667
-
668
- return (torch.cat([c1,c2,c3],dim=1).to(device) , torch.cat([s1,s2,s3],dim=1).to(device))
669
-
670
  def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
671
  rescale_func = np.poly1d(self.coefficients)
672
  e_list = []
 
647
  self.init_weights()
648
 
649
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
650
  def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
651
  rescale_func = np.poly1d(self.coefficients)
652
  e_list = []
wan/modules/sage2_core.py CHANGED
@@ -925,11 +925,11 @@ def sageattn_qk_int8_pv_fp8_window_cuda(
925
 
926
  if pv_accum_dtype == "fp32":
927
  if smooth_v:
928
- lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window = window)
929
  else:
930
- lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window = window)
931
  elif pv_accum_dtype == "fp32+fp32":
932
- lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window = window)
933
 
934
  o = o[..., :head_dim_og]
935
 
 
925
 
926
  if pv_accum_dtype == "fp32":
927
  if smooth_v:
928
+ lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window)
929
  else:
930
+ lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window)
931
  elif pv_accum_dtype == "fp32+fp32":
932
+ lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window)
933
 
934
  o = o[..., :head_dim_og]
935
 
wan/text2video.py CHANGED
@@ -24,6 +24,20 @@ from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
24
  from wan.modules.posemb_layers import get_rotary_pos_embed
25
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  class WanT2V:
28
 
29
  def __init__(
@@ -136,6 +150,8 @@ class WanT2V:
136
  slg_layers = None,
137
  slg_start = 0.0,
138
  slg_end = 1.0,
 
 
139
  ):
140
  r"""
141
  Generates video frames from text prompt using diffusion process.
@@ -240,7 +256,7 @@ class WanT2V:
240
 
241
  # sample videos
242
  latents = noise
243
-
244
  freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
245
  arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
246
  arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
@@ -249,7 +265,6 @@ class WanT2V:
249
  # arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
250
  # arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
251
  # arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
252
-
253
  if self.model.enable_teacache:
254
  self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
255
  if callback != None:
@@ -280,8 +295,23 @@ class WanT2V:
280
  return None
281
 
282
  del latent_model_input
283
- noise_pred = noise_pred_uncond + guide_scale * (
284
- noise_pred_cond - noise_pred_uncond)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
  del noise_pred_uncond
286
 
287
  temp_x0 = sample_scheduler.step(
 
24
  from wan.modules.posemb_layers import get_rotary_pos_embed
25
 
26
 
27
+ def optimized_scale(positive_flat, negative_flat):
28
+
29
+ # Calculate dot production
30
+ dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
31
+
32
+ # Squared norm of uncondition
33
+ squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
34
+
35
+ # st_star = v_cond^T * v_uncond / ||v_uncond||^2
36
+ st_star = dot_product / squared_norm
37
+
38
+ return st_star
39
+
40
+
41
  class WanT2V:
42
 
43
  def __init__(
 
150
  slg_layers = None,
151
  slg_start = 0.0,
152
  slg_end = 1.0,
153
+ cfg_star_switch = True,
154
+ cfg_zero_step = 5,
155
  ):
156
  r"""
157
  Generates video frames from text prompt using diffusion process.
 
256
 
257
  # sample videos
258
  latents = noise
259
+ batch_size =len(latents)
260
  freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
261
  arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
262
  arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
 
265
  # arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
266
  # arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
267
  # arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
 
268
  if self.model.enable_teacache:
269
  self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
270
  if callback != None:
 
295
  return None
296
 
297
  del latent_model_input
298
+
299
+ # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
300
+ noise_pred_text = noise_pred_cond
301
+ if cfg_star_switch:
302
+ positive_flat = noise_pred_text.view(batch_size, -1)
303
+ negative_flat = noise_pred_uncond.view(batch_size, -1)
304
+
305
+ alpha = optimized_scale(positive_flat,negative_flat)
306
+ alpha = alpha.view(batch_size, 1, 1, 1)
307
+
308
+
309
+ if (i <= cfg_zero_step):
310
+ noise_pred = noise_pred_text*0.
311
+ else:
312
+ noise_pred = noise_pred_uncond * alpha + guide_scale * (noise_pred_text - noise_pred_uncond * alpha)
313
+ else:
314
+ noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond)
315
  del noise_pred_uncond
316
 
317
  temp_x0 = sample_scheduler.step(