DeepBeepMeep committed on
Commit
bd91f8e
·
1 Parent(s): d6835bd

Added Vace ControlNet support

Browse files
gradio_server.py CHANGED
@@ -14,7 +14,7 @@ import gradio as gr
14
  import random
15
  import json
16
  import wan
17
- from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES
18
  from wan.utils.utils import cache_video
19
  from wan.modules.attention import get_attention_modes, get_supported_attention_modes
20
  import torch
@@ -55,6 +55,11 @@ def format_time(seconds):
55
  def pil_to_base64_uri(pil_image, format="png", quality=75):
56
  if pil_image is None:
57
  return None
 
 
 
 
 
58
  buffer = io.BytesIO()
59
  try:
60
  img_to_save = pil_image
@@ -93,10 +98,11 @@ def process_prompt_and_add_tasks(
93
  loras_choices,
94
  loras_mult_choices,
95
  image_prompt_type,
96
- image_to_continue,
97
- image_to_end,
98
- video_to_continue,
99
  max_frames,
 
100
  temporal_upsampling,
101
  spatial_upsampling,
102
  RIFLEx_setting,
@@ -127,9 +133,9 @@ def process_prompt_and_add_tasks(
127
  return
128
 
129
  file_model_needed = model_needed(image2video)
 
 
130
  if image2video:
131
- width, height = resolution.split("x")
132
- width, height = int(width), int(height)
133
 
134
  if "480p" in file_model_needed and not "Fun" in file_model_needed and width * height > 848*480:
135
  gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P")
@@ -143,74 +149,94 @@ def process_prompt_and_add_tasks(
143
  gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P")
144
  return
145
 
146
- if image2video:
147
- if image_to_continue == None or isinstance(image_to_continue, list) and len(image_to_continue) == 0:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  return
149
  if image_prompt_type == 0:
150
- image_to_end = None
151
- if isinstance(image_to_continue, list):
152
- image_to_continue = [ convert_image(tup[0]) for tup in image_to_continue ]
153
  else:
154
- image_to_continue = [convert_image(image_to_continue)]
155
- if image_to_end != None:
156
- if isinstance(image_to_end , list):
157
- image_to_end = [ convert_image(tup[0]) for tup in image_to_end ]
158
  else:
159
- image_to_end = [convert_image(image_to_end) ]
160
- if len(image_to_continue) != len(image_to_end):
161
  gr.Info("The number of start and end images should be the same ")
162
  return
163
 
164
  if multi_images_gen_type == 0:
165
  new_prompts = []
166
- new_image_to_continue = []
167
- new_image_to_end = []
168
- for i in range(len(prompts) * len(image_to_continue) ):
169
  new_prompts.append( prompts[ i % len(prompts)] )
170
- new_image_to_continue.append(image_to_continue[i // len(prompts)] )
171
- if image_to_end != None:
172
- new_image_to_end.append(image_to_end[i // len(prompts)] )
173
  prompts = new_prompts
174
- image_to_continue = new_image_to_continue
175
- if image_to_end != None:
176
- image_to_end = new_image_to_end
177
  else:
178
- if len(prompts) >= len(image_to_continue):
179
- if len(prompts) % len(image_to_continue) !=0:
180
  raise gr.Error("If there are more text prompts than input images the number of text prompts should be dividable by the number of images")
181
- rep = len(prompts) // len(image_to_continue)
182
- new_image_to_continue = []
183
- new_image_to_end = []
184
  for i, _ in enumerate(prompts):
185
- new_image_to_continue.append(image_to_continue[i//rep] )
186
- if image_to_end != None:
187
- new_image_to_end.append(image_to_end[i//rep] )
188
- image_to_continue = new_image_to_continue
189
- if image_to_end != None:
190
- image_to_end = new_image_to_end
191
  else:
192
- if len(image_to_continue) % len(prompts) !=0:
193
  raise gr.Error("If there are more input images than text prompts the number of images should be dividable by the number of text prompts")
194
- rep = len(image_to_continue) // len(prompts)
195
  new_prompts = []
196
- for i, _ in enumerate(image_to_continue):
197
  new_prompts.append( prompts[ i//rep] )
198
  prompts = new_prompts
199
 
200
- # elif video_to_continue != None and len(video_to_continue) >0 :
201
- # input_image_or_video_path = video_to_continue
202
- # # pipeline.num_input_frames = max_frames
203
- # # pipeline.max_frames = max_frames
204
- # else:
205
- # return
206
- # else:
207
- # input_image_or_video_path = None
208
- if image_to_continue == None:
209
- image_to_continue = [None] * len(prompts)
210
- if image_to_end == None:
211
- image_to_end = [None] * len(prompts)
212
-
213
- for single_prompt, image_start, image_end in zip(prompts, image_to_continue, image_to_end) :
214
  kwargs = {
215
  "prompt" : single_prompt,
216
  "negative_prompt" : negative_prompt,
@@ -228,10 +254,11 @@ def process_prompt_and_add_tasks(
228
  "loras_choices" : loras_choices,
229
  "loras_mult_choices" : loras_mult_choices,
230
  "image_prompt_type" : image_prompt_type,
231
- "image_to_continue": image_start,
232
- "image_to_end" : image_end,
233
- "video_to_continue" : video_to_continue ,
234
  "max_frames" : max_frames,
 
235
  "temporal_upsampling" : temporal_upsampling,
236
  "spatial_upsampling" : spatial_upsampling,
237
  "RIFLEx_setting" : RIFLEx_setting,
@@ -262,8 +289,9 @@ def add_video_task(**kwargs):
262
  queue = gen["queue"]
263
  task_id += 1
264
  current_task_id = task_id
265
- start_image_data = kwargs["image_to_continue"]
266
- end_image_data = kwargs["image_to_end"]
 
267
 
268
  queue.append({
269
  "id": current_task_id,
@@ -275,7 +303,7 @@ def add_video_task(**kwargs):
275
  "prompt": kwargs["prompt"],
276
  "start_image_data": start_image_data,
277
  "end_image_data": end_image_data,
278
- "start_image_data_base64": pil_to_base64_uri(start_image_data, format="jpeg", quality=70),
279
  "end_image_data_base64": pil_to_base64_uri(end_image_data, format="jpeg", quality=70)
280
  })
281
  return update_queue_data(queue)
@@ -342,6 +370,7 @@ def get_queue_table(queue):
342
  full_prompt = item['prompt'].replace('"', '"')
343
  prompt_cell = f'<span title="{full_prompt}">{truncated_prompt}</span>'
344
  start_img_uri =item.get('start_image_data_base64')
 
345
  end_img_uri = item.get('end_image_data_base64')
346
  thumbnail_size = "50px"
347
  num_steps = item.get('steps')
@@ -694,6 +723,9 @@ attention_modes_installed = get_attention_modes()
694
  attention_modes_supported = get_supported_attention_modes()
695
  args = _parse_args()
696
  args.flow_reverse = True
 
 
 
697
  # torch.backends.cuda.matmul.allow_fp16_accumulation = True
698
  lock_ui_attention = False
699
  lock_ui_transformer = False
@@ -706,7 +738,7 @@ quantizeTransformer = args.quantize_transformer
706
  check_loras = args.check_loras ==1
707
  advanced = args.advanced
708
 
709
- transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors"]
710
  transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", ]
711
  text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors"]
712
 
@@ -750,7 +782,7 @@ def get_default_settings(filename, i2v):
750
  "prompts": get_default_prompt(i2v),
751
  "resolution": "832x480",
752
  "video_length": 81,
753
- "image_prompt_type" : 0,
754
  "num_inference_steps": 30,
755
  "seed": -1,
756
  "repeat_generation": 1,
@@ -1149,6 +1181,9 @@ def get_model_name(model_filename):
1149
  if "Fun" in model_filename:
1150
  model_name = "Fun InP image2video"
1151
  model_name += " 14B" if "14B" in model_filename else " 1.3B"
 
 
 
1152
  elif "image" in model_filename:
1153
  model_name = "Wan2.1 image2video"
1154
  model_name += " 720p" if "720p" in model_filename else " 480p"
@@ -1353,22 +1388,22 @@ def refresh_gallery(state, msg):
1353
  end_img_md = ""
1354
  prompt = task["prompt"]
1355
 
1356
- if task.get('image2video'):
1357
- start_img_uri = task.get('start_image_data_base64')
1358
- end_img_uri = task.get('end_image_data_base64')
1359
- thumbnail_size = "100px"
1360
- if start_img_uri:
1361
- start_img_md = f'<img src="{start_img_uri}" alt="Start" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
1362
- if end_img_uri:
1363
- end_img_md = f'<img src="{end_img_uri}" alt="End" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
1364
 
1365
  label = f"Prompt of Video being Generated"
1366
 
1367
  html = "<STYLE> #PINFO, #PINFO th, #PINFO td {border: 1px solid #CCCCCC;background-color:#FFFFFF;}</STYLE><TABLE WIDTH=100% ID=PINFO ><TR><TD width=100%>" + prompt + "</TD>"
1368
  if start_img_md != "":
1369
  html += "<TD>" + start_img_md + "</TD>"
1370
- if end_img_md != "":
1371
- html += "<TD>" + end_img_md + "</TD>"
1372
 
1373
  html += "</TR></TABLE>"
1374
  html_output = gr.HTML(html, visible= True)
@@ -1419,24 +1454,26 @@ def expand_slist(slist, num_inference_steps ):
1419
  new_slist.append(slist[ int(pos)])
1420
  pos += inc
1421
  return new_slist
1422
-
1423
  def convert_image(image):
1424
- from PIL import ExifTags
1425
-
1426
- image = image.convert('RGB')
1427
- for orientation in ExifTags.TAGS.keys():
1428
- if ExifTags.TAGS[orientation]=='Orientation':
1429
- break
1430
- exif = image.getexif()
1431
- if not orientation in exif:
1432
- return image
1433
- if exif[orientation] == 3:
1434
- image=image.rotate(180, expand=True)
1435
- elif exif[orientation] == 6:
1436
- image=image.rotate(270, expand=True)
1437
- elif exif[orientation] == 8:
1438
- image=image.rotate(90, expand=True)
1439
- return image
 
 
 
1440
 
1441
  def generate_video(
1442
  task_id,
@@ -1457,10 +1494,11 @@ def generate_video(
1457
  loras_choices,
1458
  loras_mult_choices,
1459
  image_prompt_type,
1460
- image_to_continue,
1461
- image_to_end,
1462
- video_to_continue,
1463
  max_frames,
 
1464
  temporal_upsampling,
1465
  spatial_upsampling,
1466
  RIFLEx_setting,
@@ -1507,7 +1545,6 @@ def generate_video(
1507
  gr.Info(f"You have selected attention mode '{attention_mode}'. However it is not installed or supported on your system. You should either install it or switch to the default 'sdpa' attention.")
1508
  return
1509
 
1510
-
1511
 
1512
  if not image2video:
1513
  width, height = resolution.split("x")
@@ -1586,7 +1623,7 @@ def generate_video(
1586
 
1587
  enable_RIFLEx = RIFLEx_setting == 0 and video_length > (6* 16) or RIFLEx_setting == 1
1588
  # VAE Tiling
1589
- device_mem_capacity = torch.cuda.get_device_properties(0).total_memory / 1048576
1590
 
1591
  joint_pass = boost ==1 #and profile != 1 and profile != 3
1592
  # TeaCache
@@ -1615,6 +1652,17 @@ def generate_video(
1615
  else:
1616
  raise gr.Error("Teacache not supported for this model")
1617
 
 
 
 
 
 
 
 
 
 
 
 
1618
  import random
1619
  if seed == None or seed <0:
1620
  seed = random.randint(0, 999999999)
@@ -1673,8 +1721,8 @@ def generate_video(
1673
  if image2video:
1674
  samples = wan_model.generate(
1675
  prompt,
1676
- image_to_continue,
1677
- image_to_end if image_to_end != None else None,
1678
  frame_num=(video_length // 4)* 4 + 1,
1679
  max_area=MAX_AREA_CONFIGS[resolution],
1680
  shift=flow_shift,
@@ -1697,6 +1745,9 @@ def generate_video(
1697
  else:
1698
  samples = wan_model.generate(
1699
  prompt,
 
 
 
1700
  frame_num=(video_length // 4)* 4 + 1,
1701
  size=(width, height),
1702
  shift=flow_shift,
@@ -1745,7 +1796,7 @@ def generate_video(
1745
  new_error = "The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames."
1746
  else:
1747
  new_error = gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
1748
- tb = traceback.format_exc().split('\n')[:-2]
1749
  print('\n'.join(tb))
1750
  raise gr.Error(new_error, print_exception= False)
1751
 
@@ -1799,7 +1850,7 @@ def generate_video(
1799
 
1800
  if exp > 0:
1801
  from rife.inference import temporal_interpolation
1802
- sample = temporal_interpolation( os.path.join("ckpts", "flownet.pkl"), sample, exp, device="cuda")
1803
  fps = fps * 2**exp
1804
 
1805
  if len(spatial_upsampling) > 0:
@@ -1831,8 +1882,7 @@ def generate_video(
1831
  normalize=True,
1832
  value_range=(-1, 1))
1833
 
1834
-
1835
- configs = get_settings_dict(state, image2video, prompt, 0 if image_to_end == None else 1 , video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
1836
  loras_mult_choices, tea_cache , tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end, cfg_star_switch, cfg_zero_step)
1837
 
1838
  metadata_choice = server_config.get("metadata_choice","metadata")
@@ -2294,7 +2344,7 @@ def switch_advanced(state, new_advanced, lset_name):
2294
  return gr.Row(visible=new_advanced), gr.Row(visible=True), gr.Button(visible=True), gr.Row(visible= False), gr.Dropdown(choices=lset_choices, value= lset_name)
2295
 
2296
 
2297
- def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
2298
  loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step):
2299
 
2300
  loras = state["loras"]
@@ -2330,18 +2380,22 @@ def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resol
2330
  ui_settings["type"] = "Wan2.1GP by DeepBeepMeep - image2video"
2331
  ui_settings["image_prompt_type"] = image_prompt_type
2332
  else:
 
 
 
 
2333
  ui_settings["type"] = "Wan2.1GP by DeepBeepMeep - text2video"
2334
 
2335
  return ui_settings
2336
 
2337
- def save_settings(state, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
2338
  loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step):
2339
 
2340
  if state.get("validate_success",0) != 1:
2341
  return
2342
 
2343
  image2video = state["image2video"]
2344
- ui_defaults = get_settings_dict(state, image2video, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
2345
  loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step)
2346
 
2347
  defaults_filename = get_settings_file_name(image2video)
@@ -2379,6 +2433,25 @@ def download_loras():
2379
  writer.write(f"Loras downloaded on the {dt} at {time.time()} on the {time.time()}")
2380
  return
2381
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2382
  def generate_video_tab(image2video=False):
2383
  filename = transformer_filename_i2v if image2video else transformer_filename_t2v
2384
  ui_defaults= get_default_settings(filename, image2video)
@@ -2387,6 +2460,7 @@ def generate_video_tab(image2video=False):
2387
 
2388
  state_dict["advanced"] = advanced
2389
  state_dict["loras_model"] = filename
 
2390
  state_dict["image2video"] = image2video
2391
  gen = dict()
2392
  gen["queue"] = []
@@ -2461,31 +2535,51 @@ def generate_video_tab(image2video=False):
2461
  save_lset_btn = gr.Button("Save", size="sm", min_width= 1)
2462
  delete_lset_btn = gr.Button("Delete", size="sm", min_width= 1)
2463
  cancel_lset_btn = gr.Button("Don't do it !", size="sm", min_width= 1 , visible=False)
2464
- video_to_continue = gr.Video(label= "Video to continue", visible= image2video and False) #######
2465
- image_prompt_type= ui_defaults.get("image_prompt_type",0)
2466
- image_prompt_type_radio = gr.Radio( [("Use only a Start Image", 0),("Use both a Start and an End Image", 1)], value =image_prompt_type, label="Location", show_label= False, scale= 3, visible=image2video)
2467
-
2468
- if args.multiple_images:
2469
- image_to_continue = gr.Gallery(
2470
- label="Images as starting points for new videos", type ="pil", #file_types= "image",
2471
- columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=image2video)
2472
- else:
2473
- image_to_continue = gr.Image(label= "Image as a starting point for a new video", type ="pil", visible=image2video)
2474
 
2475
- if args.multiple_images:
2476
- image_to_end = gr.Gallery(
2477
- label="Images as ending points for new videos", type ="pil", #file_types= "image",
2478
- columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=image_prompt_type==1)
2479
- else:
2480
- image_to_end = gr.Image(label= "Last Image for a new video", type ="pil", visible=image_prompt_type==1)
 
 
 
 
 
 
 
 
 
 
2481
 
2482
- def switch_image_prompt_type_radio(image_prompt_type_radio):
2483
- if args.multiple_images:
2484
- return gr.Gallery(visible = (image_prompt_type_radio == 1) )
 
 
 
 
 
 
 
 
2485
  else:
2486
- return gr.Image(visible = (image_prompt_type_radio == 1) )
 
 
 
 
2487
 
2488
- image_prompt_type_radio.change(fn=switch_image_prompt_type_radio, inputs=[image_prompt_type_radio], outputs=[image_to_end])
 
 
 
 
 
 
 
 
2489
 
2490
 
2491
  advanced_prompt = advanced
@@ -2518,7 +2612,6 @@ def generate_video_tab(image2video=False):
2518
  wizard_prompt = gr.Textbox(visible = not advanced_prompt, label="Prompts (each new line of prompt will generate a new video, # lines = comments)", value=default_wizard_prompt, lines=3)
2519
  wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False)
2520
  wizard_variables_var = gr.Text(wizard_variables, visible = False)
2521
- state = gr.State(state_dict)
2522
  with gr.Row():
2523
  if image2video:
2524
  resolution = gr.Dropdown(
@@ -2555,8 +2648,6 @@ def generate_video_tab(image2video=False):
2555
  video_length = gr.Slider(5, 193, value=ui_defaults["video_length"], step=4, label="Number of frames (16 = 1s)")
2556
  with gr.Column():
2557
  num_inference_steps = gr.Slider(1, 100, value=ui_defaults["num_inference_steps"], step=1, label="Number of Inference Steps")
2558
- with gr.Row():
2559
- max_frames = gr.Slider(1, 100, value=9, step=1, label="Number of input frames to use for Video2World prediction", visible=image2video and False) #########
2560
  show_advanced = gr.Checkbox(label="Advanced Mode", value=advanced)
2561
  with gr.Row(visible=advanced) as advanced_row:
2562
  with gr.Column():
@@ -2605,7 +2696,7 @@ def generate_video_tab(image2video=False):
2605
  tea_cache_start_step_perc = gr.Slider(0, 100, value=ui_defaults["tea_cache_start_step_perc"], step=1, label="Tea Cache starting moment in % of generation")
2606
 
2607
  with gr.Row():
2608
- gr.Markdown("<B>Upsampling</B>")
2609
  with gr.Row():
2610
  temporal_upsampling_choice = gr.Dropdown(
2611
  choices=[
@@ -2687,9 +2778,10 @@ def generate_video_tab(image2video=False):
2687
  show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
2688
  fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
2689
  with gr.Column():
2690
- gen_status = gr.Text(label="Status", interactive= False)
2691
- full_sync = gr.Text(label="Status", interactive= False, visible= False)
2692
- light_sync = gr.Text(label="Status", interactive= False, visible= False)
 
2693
  gen_progress_html = gr.HTML(
2694
  label="Status",
2695
  value="Idle",
@@ -2709,8 +2801,8 @@ def generate_video_tab(image2video=False):
2709
  abort_btn = gr.Button("Abort")
2710
 
2711
  queue_df = gr.DataFrame(
2712
- headers=["Qty","Prompt", "Length","Steps","Start", "End", "", "", ""],
2713
- datatype=[ "str","markdown","str", "markdown", "markdown", "markdown", "str", "str", "str"],
2714
  column_widths= ["50","", "65","55", "60", "60", "30", "30", "35"],
2715
  interactive=False,
2716
  col_count=(9, "fixed"),
@@ -2792,7 +2884,7 @@ def generate_video_tab(image2video=False):
2792
  show_progress="hidden"
2793
  )
2794
  save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
2795
- save_settings, inputs = [state, prompt, image_prompt_type_radio, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt,
2796
  loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling_choice, spatial_upsampling_choice, RIFLEx_setting, slg_switch, slg_layers,
2797
  slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step ], outputs = [])
2798
  save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
@@ -2808,21 +2900,30 @@ def generate_video_tab(image2video=False):
2808
  refresh_lora_btn2.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
2809
  output.select(select_video, state, None )
2810
 
 
 
2811
  gen_status.change(refresh_gallery,
2812
  inputs = [state, gen_status],
2813
  outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn])
2814
 
2815
- full_sync.change(refresh_gallery,
 
 
 
2816
  inputs = [state, gen_status],
2817
  outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn]
2818
- ).then( fn=wait_tasks_done,
2819
  inputs= [state],
2820
  outputs =[gen_status],
2821
  ).then(finalize_generation,
2822
  inputs= [state],
2823
  outputs= [output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info]
2824
  )
2825
- light_sync.change(refresh_gallery,
 
 
 
 
2826
  inputs = [state, gen_status],
2827
  outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn]
2828
  )
@@ -2848,10 +2949,11 @@ def generate_video_tab(image2video=False):
2848
  loras_choices,
2849
  loras_mult_choices,
2850
  image_prompt_type_radio,
2851
- image_to_continue,
2852
- image_to_end,
2853
- video_to_continue,
2854
  max_frames,
 
2855
  temporal_upsampling_choice,
2856
  spatial_upsampling_choice,
2857
  RIFLEx_setting,
@@ -2902,7 +3004,7 @@ def generate_video_tab(image2video=False):
2902
  )
2903
  return loras_column, loras_choices, presets_column, lset_name, header, light_sync, full_sync, state
2904
 
2905
- def generate_doxnload_tab(presets_column, loras_column, lset_name,loras_choices, state):
2906
  with gr.Row():
2907
  with gr.Row(scale =2):
2908
  gr.Markdown("<I>Wan2GP's Lora Festival ! Press the following button to download i2v <B>Remade</B> Loras collection (and bonuses Loras).")
@@ -2928,6 +3030,7 @@ def generate_configuration_tab():
2928
  ("WAN 2.1 1.3B Text to Video 16 bits (recommended)- the small model for fast generations with low VRAM requirements", 0),
2929
  ("WAN 2.1 14B Text to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 1),
2930
  ("WAN 2.1 14B Text to Video quantized to 8 bits (recommended) - the default engine but quantized", 2),
 
2931
  ],
2932
  value= index,
2933
  label="Transformer model for Text to Video",
@@ -3108,16 +3211,17 @@ def on_tab_select(global_state, t2v_state, i2v_state, evt: gr.SelectData):
3108
  t2v_light_sync = gr.Text()
3109
  i2v_full_sync = gr.Text()
3110
  t2v_full_sync = gr.Text()
3111
- if new_t2v or new_i2v:
3112
- last_tab_was_image2video =global_state.get("last_tab_was_image2video", None)
3113
- if last_tab_was_image2video == None or last_tab_was_image2video:
3114
- gen = i2v_state["gen"]
3115
- t2v_state["gen"] = gen
3116
- else:
3117
- gen = t2v_state["gen"]
3118
- i2v_state["gen"] = gen
3119
 
 
 
 
 
 
 
 
3120
 
 
 
3121
  if last_tab_was_image2video != None and new_t2v != new_i2v:
3122
  gen_location = gen.get("location", None)
3123
  if "in_progress" in gen and gen_location !=None and not (gen_location and new_i2v or not gen_location and new_t2v) :
@@ -3131,7 +3235,6 @@ def on_tab_select(global_state, t2v_state, i2v_state, evt: gr.SelectData):
3131
  else:
3132
  t2v_light_sync = gr.Text(str(time.time()))
3133
 
3134
-
3135
  global_state["last_tab_was_image2video"] = new_i2v
3136
 
3137
  if(server_config.get("reload_model",2) == 1):
@@ -3433,7 +3536,7 @@ def create_demo():
3433
  }
3434
  """
3435
  with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md"), title= "Wan2GP") as demo:
3436
- gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v3.4 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
3437
  gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !</FONT>")
3438
 
3439
  with gr.Accordion("Click here for some Info on how to use Wan2GP", open = False):
@@ -3454,7 +3557,7 @@ def create_demo():
3454
  i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header, i2v_light_sync, i2v_full_sync, i2v_state = generate_video_tab(True)
3455
  if not args.lock_config:
3456
  with gr.Tab("Downloads", id="downloads") as downloads_tab:
3457
- generate_doxnload_tab(i2v_presets_column, i2v_loras_column, i2v_lset_name, i2v_loras_choices, i2v_state)
3458
  with gr.Tab("Configuration"):
3459
  generate_configuration_tab()
3460
  with gr.Tab("About"):
 
14
  import random
15
  import json
16
  import wan
17
+ from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES, VACE_SIZE_CONFIGS
18
  from wan.utils.utils import cache_video
19
  from wan.modules.attention import get_attention_modes, get_supported_attention_modes
20
  import torch
 
55
  def pil_to_base64_uri(pil_image, format="png", quality=75):
56
  if pil_image is None:
57
  return None
58
+
59
+ if isinstance(pil_image, str):
60
+ from wan.utils.utils import get_video_frame
61
+ pil_image = get_video_frame(pil_image, 0)
62
+
63
  buffer = io.BytesIO()
64
  try:
65
  img_to_save = pil_image
 
98
  loras_choices,
99
  loras_mult_choices,
100
  image_prompt_type,
101
+ image_source1,
102
+ image_source2,
103
+ image_source3,
104
  max_frames,
105
+ remove_background_image_ref,
106
  temporal_upsampling,
107
  spatial_upsampling,
108
  RIFLEx_setting,
 
133
  return
134
 
135
  file_model_needed = model_needed(image2video)
136
+ width, height = resolution.split("x")
137
+ width, height = int(width), int(height)
138
  if image2video:
 
 
139
 
140
  if "480p" in file_model_needed and not "Fun" in file_model_needed and width * height > 848*480:
141
  gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P")
 
149
  gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P")
150
  return
151
 
152
+ if not image2video:
153
+ if "Vace" in file_model_needed and "1.3B" in file_model_needed :
154
+ resolution_reformated = str(height) + "*" + str(width)
155
+ if not resolution_reformated in VACE_SIZE_CONFIGS:
156
+ res = VACE_SIZE_CONFIGS.keys().join(" and ")
157
+ gr.Info(f"Video Resolution for Vace model is not supported. Only {res} resolutions are allowed.")
158
+ return
159
+
160
+ if not "I" in image_prompt_type:
161
+ image_source1 = None
162
+ if not "V" in image_prompt_type:
163
+ image_source2 = None
164
+ if not "M" in image_prompt_type:
165
+ image_source3 = None
166
+
167
+ if isinstance(image_source1, list):
168
+ image_source1 = [ convert_image(tup[0]) for tup in image_source1 ]
169
+
170
+ from wan.utils.utils import resize_and_remove_background
171
+ image_source1 = resize_and_remove_background(image_source1, width, height, remove_background_image_ref ==1)
172
+
173
+ image_source1 = [ image_source1 ] * len(prompts)
174
+ image_source2 = [ image_source2 ] * len(prompts)
175
+ image_source3 = [ image_source3 ] * len(prompts)
176
+
177
+ else:
178
+ if image_source1 == None or isinstance(image_source1, list) and len(image_source1) == 0:
179
  return
180
  if image_prompt_type == 0:
181
+ image_source2 = None
182
+ if isinstance(image_source1, list):
183
+ image_source1 = [ convert_image(tup[0]) for tup in image_source1 ]
184
  else:
185
+ image_source1 = [convert_image(image_source1)]
186
+ if image_source2 != None:
187
+ if isinstance(image_source2 , list):
188
+ image_source2 = [ convert_image(tup[0]) for tup in image_source2 ]
189
  else:
190
+ image_source2 = [convert_image(image_source2) ]
191
+ if len(image_source1) != len(image_source2):
192
  gr.Info("The number of start and end images should be the same ")
193
  return
194
 
195
  if multi_images_gen_type == 0:
196
  new_prompts = []
197
+ new_image_source1 = []
198
+ new_image_source2 = []
199
+ for i in range(len(prompts) * len(image_source1) ):
200
  new_prompts.append( prompts[ i % len(prompts)] )
201
+ new_image_source1.append(image_source1[i // len(prompts)] )
202
+ if image_source2 != None:
203
+ new_image_source2.append(image_source2[i // len(prompts)] )
204
  prompts = new_prompts
205
+ image_source1 = new_image_source1
206
+ if image_source2 != None:
207
+ image_source2 = new_image_source2
208
  else:
209
+ if len(prompts) >= len(image_source1):
210
+ if len(prompts) % len(image_source1) !=0:
211
  raise gr.Error("If there are more text prompts than input images the number of text prompts should be dividable by the number of images")
212
+ rep = len(prompts) // len(image_source1)
213
+ new_image_source1 = []
214
+ new_image_source2 = []
215
  for i, _ in enumerate(prompts):
216
+ new_image_source1.append(image_source1[i//rep] )
217
+ if image_source2 != None:
218
+ new_image_source2.append(image_source2[i//rep] )
219
+ image_source1 = new_image_source1
220
+ if image_source2 != None:
221
+ image_source2 = new_image_source2
222
  else:
223
+ if len(image_source1) % len(prompts) !=0:
224
  raise gr.Error("If there are more input images than text prompts the number of images should be dividable by the number of text prompts")
225
+ rep = len(image_source1) // len(prompts)
226
  new_prompts = []
227
+ for i, _ in enumerate(image_source1):
228
  new_prompts.append( prompts[ i//rep] )
229
  prompts = new_prompts
230
 
231
+
232
+ if image_source1 == None:
233
+ image_source1 = [None] * len(prompts)
234
+ if image_source2 == None:
235
+ image_source2 = [None] * len(prompts)
236
+ if image_source3 == None:
237
+ image_source3 = [None] * len(prompts)
238
+
239
+ for single_prompt, image_source1, image_source2, image_source3 in zip(prompts, image_source1, image_source2, image_source3) :
 
 
 
 
 
240
  kwargs = {
241
  "prompt" : single_prompt,
242
  "negative_prompt" : negative_prompt,
 
254
  "loras_choices" : loras_choices,
255
  "loras_mult_choices" : loras_mult_choices,
256
  "image_prompt_type" : image_prompt_type,
257
+ "image_source1": image_source1,
258
+ "image_source2" : image_source2,
259
+ "image_source3" : image_source3 ,
260
  "max_frames" : max_frames,
261
+ "remove_background_image_ref" : remove_background_image_ref,
262
  "temporal_upsampling" : temporal_upsampling,
263
  "spatial_upsampling" : spatial_upsampling,
264
  "RIFLEx_setting" : RIFLEx_setting,
 
289
  queue = gen["queue"]
290
  task_id += 1
291
  current_task_id = task_id
292
+ start_image_data = kwargs["image_source1"]
293
+ start_image_data = [start_image_data] if not isinstance(start_image_data, list) else start_image_data
294
+ end_image_data = kwargs["image_source2"]
295
 
296
  queue.append({
297
  "id": current_task_id,
 
303
  "prompt": kwargs["prompt"],
304
  "start_image_data": start_image_data,
305
  "end_image_data": end_image_data,
306
+ "start_image_data_base64": [pil_to_base64_uri(img, format="jpeg", quality=70) for img in start_image_data],
307
  "end_image_data_base64": pil_to_base64_uri(end_image_data, format="jpeg", quality=70)
308
  })
309
  return update_queue_data(queue)
 
370
  full_prompt = item['prompt'].replace('"', '&quot;')
371
  prompt_cell = f'<span title="{full_prompt}">{truncated_prompt}</span>'
372
  start_img_uri =item.get('start_image_data_base64')
373
+ start_img_uri = start_img_uri[0] if start_img_uri !=None else None
374
  end_img_uri = item.get('end_image_data_base64')
375
  thumbnail_size = "50px"
376
  num_steps = item.get('steps')
 
723
  attention_modes_supported = get_supported_attention_modes()
724
  args = _parse_args()
725
  args.flow_reverse = True
726
+ processing_device = args.gpu
727
+ if len(processing_device) == 0:
728
+ processing_device ="cuda"
729
  # torch.backends.cuda.matmul.allow_fp16_accumulation = True
730
  lock_ui_attention = False
731
  lock_ui_transformer = False
 
738
  check_loras = args.check_loras ==1
739
  advanced = args.advanced
740
 
741
+ transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_bf16.safetensors"]
742
  transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", ]
743
  text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors"]
744
 
 
782
  "prompts": get_default_prompt(i2v),
783
  "resolution": "832x480",
784
  "video_length": 81,
785
+ "image_prompt_type" : 0 if i2v else "",
786
  "num_inference_steps": 30,
787
  "seed": -1,
788
  "repeat_generation": 1,
 
1181
  if "Fun" in model_filename:
1182
  model_name = "Fun InP image2video"
1183
  model_name += " 14B" if "14B" in model_filename else " 1.3B"
1184
+ elif "Vace" in model_filename:
1185
+ model_name = "Vace ControlNet text2video"
1186
+ model_name += " 14B" if "14B" in model_filename else " 1.3B"
1187
  elif "image" in model_filename:
1188
  model_name = "Wan2.1 image2video"
1189
  model_name += " 720p" if "720p" in model_filename else " 480p"
 
1388
  end_img_md = ""
1389
  prompt = task["prompt"]
1390
 
1391
+ start_img_uri = task.get('start_image_data_base64')
1392
+ start_img_uri = start_img_uri[0] if start_img_uri !=None else None
1393
+ end_img_uri = task.get('end_image_data_base64')
1394
+ thumbnail_size = "100px"
1395
+ if start_img_uri:
1396
+ start_img_md = f'<img src="{start_img_uri}" alt="Start" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
1397
+ if end_img_uri:
1398
+ end_img_md = f'<img src="{end_img_uri}" alt="End" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
1399
 
1400
  label = f"Prompt of Video being Generated"
1401
 
1402
  html = "<STYLE> #PINFO, #PINFO th, #PINFO td {border: 1px solid #CCCCCC;background-color:#FFFFFF;}</STYLE><TABLE WIDTH=100% ID=PINFO ><TR><TD width=100%>" + prompt + "</TD>"
1403
  if start_img_md != "":
1404
  html += "<TD>" + start_img_md + "</TD>"
1405
+ if end_img_md != "":
1406
+ html += "<TD>" + end_img_md + "</TD>"
1407
 
1408
  html += "</TR></TABLE>"
1409
  html_output = gr.HTML(html, visible= True)
 
1454
  new_slist.append(slist[ int(pos)])
1455
  pos += inc
1456
  return new_slist
 
1457
  def convert_image(image):
1458
+
1459
+ from PIL import ExifTags, ImageOps
1460
+ from typing import cast
1461
+
1462
+ return cast(Image, ImageOps.exif_transpose(image))
1463
+ # image = image.convert('RGB')
1464
+ # for orientation in ExifTags.TAGS.keys():
1465
+ # if ExifTags.TAGS[orientation]=='Orientation':
1466
+ # break
1467
+ # exif = image.getexif()
1468
+ # return image
1469
+ # if not orientation in exif:
1470
+ # if exif[orientation] == 3:
1471
+ # image=image.rotate(180, expand=True)
1472
+ # elif exif[orientation] == 6:
1473
+ # image=image.rotate(270, expand=True)
1474
+ # elif exif[orientation] == 8:
1475
+ # image=image.rotate(90, expand=True)
1476
+ # return image
1477
 
1478
  def generate_video(
1479
  task_id,
 
1494
  loras_choices,
1495
  loras_mult_choices,
1496
  image_prompt_type,
1497
+ image_source1,
1498
+ image_source2,
1499
+ image_source3,
1500
  max_frames,
1501
+ remove_background_image_ref,
1502
  temporal_upsampling,
1503
  spatial_upsampling,
1504
  RIFLEx_setting,
 
1545
  gr.Info(f"You have selected attention mode '{attention_mode}'. However it is not installed or supported on your system. You should either install it or switch to the default 'sdpa' attention.")
1546
  return
1547
 
 
1548
 
1549
  if not image2video:
1550
  width, height = resolution.split("x")
 
1623
 
1624
  enable_RIFLEx = RIFLEx_setting == 0 and video_length > (6* 16) or RIFLEx_setting == 1
1625
  # VAE Tiling
1626
+ device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576
1627
 
1628
  joint_pass = boost ==1 #and profile != 1 and profile != 3
1629
  # TeaCache
 
1652
  else:
1653
  raise gr.Error("Teacache not supported for this model")
1654
 
1655
+ if "Vace" in model_filename:
1656
+ resolution_reformated = str(height) + "*" + str(width)
1657
+ src_video, src_mask, src_ref_images = wan_model.prepare_source([image_source2],
1658
+ [image_source3],
1659
+ [image_source1],
1660
+ video_length, VACE_SIZE_CONFIGS[resolution_reformated], "cpu",
1661
+ trim_video=max_frames)
1662
+ else:
1663
+ src_video, src_mask, src_ref_images = None, None, None
1664
+
1665
+
1666
  import random
1667
  if seed == None or seed <0:
1668
  seed = random.randint(0, 999999999)
 
1721
  if image2video:
1722
  samples = wan_model.generate(
1723
  prompt,
1724
+ image_source1,
1725
+ image_source2 if image_source2 != None else None,
1726
  frame_num=(video_length // 4)* 4 + 1,
1727
  max_area=MAX_AREA_CONFIGS[resolution],
1728
  shift=flow_shift,
 
1745
  else:
1746
  samples = wan_model.generate(
1747
  prompt,
1748
+ input_frames = src_video,
1749
+ input_ref_images= src_ref_images,
1750
+ input_masks = src_mask,
1751
  frame_num=(video_length // 4)* 4 + 1,
1752
  size=(width, height),
1753
  shift=flow_shift,
 
1796
  new_error = "The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames."
1797
  else:
1798
  new_error = gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
1799
+ tb = traceback.format_exc().split('\n')[:-1]
1800
  print('\n'.join(tb))
1801
  raise gr.Error(new_error, print_exception= False)
1802
 
 
1850
 
1851
  if exp > 0:
1852
  from rife.inference import temporal_interpolation
1853
+ sample = temporal_interpolation( os.path.join("ckpts", "flownet.pkl"), sample, exp, device=processing_device)
1854
  fps = fps * 2**exp
1855
 
1856
  if len(spatial_upsampling) > 0:
 
1882
  normalize=True,
1883
  value_range=(-1, 1))
1884
 
1885
+ configs = get_settings_dict(state, image2video, True, prompt, image_prompt_type, max_frames , remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
 
1886
  loras_mult_choices, tea_cache , tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end, cfg_star_switch, cfg_zero_step)
1887
 
1888
  metadata_choice = server_config.get("metadata_choice","metadata")
 
2344
  return gr.Row(visible=new_advanced), gr.Row(visible=True), gr.Button(visible=True), gr.Row(visible= False), gr.Dropdown(choices=lset_choices, value= lset_name)
2345
 
2346
 
2347
+ def get_settings_dict(state, i2v, image_metadata, prompt, image_prompt_type, max_frames, remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
2348
  loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step):
2349
 
2350
  loras = state["loras"]
 
2380
  ui_settings["type"] = "Wan2.1GP by DeepBeepMeep - image2video"
2381
  ui_settings["image_prompt_type"] = image_prompt_type
2382
  else:
2383
+ if "Vace" in transformer_filename_t2v or not image_metadata:
2384
+ ui_settings["image_prompt_type"] = image_prompt_type
2385
+ ui_settings["max_frames"] = max_frames
2386
+ ui_settings["remove_background_image_ref"] = remove_background_image_ref
2387
  ui_settings["type"] = "Wan2.1GP by DeepBeepMeep - text2video"
2388
 
2389
  return ui_settings
2390
 
2391
+ def save_settings(state, prompt, image_prompt_type, max_frames, remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
2392
  loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step):
2393
 
2394
  if state.get("validate_success",0) != 1:
2395
  return
2396
 
2397
  image2video = state["image2video"]
2398
+ ui_defaults = get_settings_dict(state, image2video, False, prompt, image_prompt_type, max_frames, remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
2399
  loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step)
2400
 
2401
  defaults_filename = get_settings_file_name(image2video)
 
2433
  writer.write(f"Loras downloaded on the {dt} at {time.time()} on the {time.time()}")
2434
  return
2435
 
2436
def refresh_i2v_image_prompt_type_radio(state, image_prompt_type_radio):
    """Toggle the visibility of the image2video end-image input.

    The end image is shown only when the radio value equals 1 (the
    "Use both a Start and an End Image" choice). The update is built with
    the same component class as the input itself: a Gallery when
    ``args.multiple_images`` is set, a single Image otherwise.

    ``state`` is accepted because the event wiring passes it, but it is
    not read here.
    """
    component_cls = gr.Gallery if args.multiple_images else gr.Image
    show_end_image = image_prompt_type_radio == 1
    return component_cls(visible=show_end_image)
2441
+
2442
def refresh_t2v_image_prompt_type_radio(state, image_prompt_type_radio):
    """Refresh the visibility of the text2video reference inputs.

    ``image_prompt_type_radio`` is a string of flags: "I" (reference
    images), "V" (reference video) and "M" (video mask). The enclosing
    column is shown only when ``state["image_input_type_model"]`` names a
    Vace model and the tab is not the image2video one.

    Returns:
        A tuple of component updates, in the order of the outputs wired to
        this handler: prompt column, the radio itself, reference image
        gallery, reference video, video mask, max-frames slider and
        remove-background checkbox.
    """
    vace_model = "Vace" in state["image_input_type_model"] and not state["image2video"]
    use_images = "I" in image_prompt_type_radio
    use_video = "V" in image_prompt_type_radio
    use_mask = "M" in image_prompt_type_radio
    return (
        gr.Column(visible=vace_model),
        gr.Radio(value=image_prompt_type_radio),
        gr.Gallery(visible=use_images),
        gr.Video(visible=use_video),
        gr.Video(visible=use_mask),
        # max_frames is declared as a gr.Slider, so its update uses the same
        # component class instead of the gr.Text that was returned before.
        gr.Slider(visible=use_video),
        gr.Checkbox(visible=use_images),
    )
2445
+
2446
def check_refresh_input_type(state):
    """Detect a change of text2video model and signal the UI to refresh.

    For a text2video tab, the model file recorded in
    ``state["image_input_type_model"]`` is compared with the one currently
    returned by ``model_needed(False)``. When they differ, the state is
    updated and a ``gr.Text`` holding the current timestamp is returned, so
    the hidden trigger textbox receives a new value. In every other case a
    bare ``gr.Text()`` is returned.
    """
    if state["image2video"]:
        return gr.Text()

    recorded_model = state["image_input_type_model"]
    required_model = model_needed(False)
    if recorded_model == required_model:
        return gr.Text()

    state["image_input_type_model"] = required_model
    return gr.Text(value=str(time.time()))
2454
+
2455
  def generate_video_tab(image2video=False):
2456
  filename = transformer_filename_i2v if image2video else transformer_filename_t2v
2457
  ui_defaults= get_default_settings(filename, image2video)
 
2460
 
2461
  state_dict["advanced"] = advanced
2462
  state_dict["loras_model"] = filename
2463
+ state_dict["image_input_type_model"] = filename
2464
  state_dict["image2video"] = image2video
2465
  gen = dict()
2466
  gen["queue"] = []
 
2535
  save_lset_btn = gr.Button("Save", size="sm", min_width= 1)
2536
  delete_lset_btn = gr.Button("Delete", size="sm", min_width= 1)
2537
  cancel_lset_btn = gr.Button("Don't do it !", size="sm", min_width= 1 , visible=False)
 
 
 
 
 
 
 
 
 
 
2538
 
2539
+ state = gr.State(state_dict)
2540
+ vace_model = "Vace" in filename and not image2video
2541
+ trigger_refresh_input_type = gr.Text(interactive= False, visible= False)
2542
+ with gr.Column(visible= image2video or vace_model) as image_prompt_column:
2543
+ if image2video:
2544
+ image_source3 = gr.Video(label= "Placeholder", visible= image2video and False)
2545
+
2546
+ image_prompt_type= ui_defaults.get("image_prompt_type",0)
2547
+ image_prompt_type_radio = gr.Radio( [("Use only a Start Image", 0),("Use both a Start and an End Image", 1)], value =image_prompt_type, label="Location", show_label= False, scale= 3)
2548
+
2549
+ if args.multiple_images:
2550
+ image_source1 = gr.Gallery(
2551
+ label="Images as starting points for new videos", type ="pil", #file_types= "image",
2552
+ columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True)
2553
+ else:
2554
+ image_source1 = gr.Image(label= "Image as a starting point for a new video", type ="pil")
2555
 
2556
+ if args.multiple_images:
2557
+ image_source2 = gr.Gallery(
2558
+ label="Images as ending points for new videos", type ="pil", #file_types= "image",
2559
+ columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=image_prompt_type==1)
2560
+ else:
2561
+ image_source2 = gr.Image(label= "Last Image for a new video", type ="pil", visible=image_prompt_type==1)
2562
+
2563
+
2564
+ image_prompt_type_radio.change(fn=refresh_i2v_image_prompt_type_radio, inputs=[state, image_prompt_type_radio], outputs=[image_source2])
2565
+ max_frames = gr.Slider(1, 100,step=1, visible = False)
2566
+ remove_background_image_ref = gr.Text(visible = False)
2567
  else:
2568
+ image_prompt_type= ui_defaults.get("image_prompt_type","I")
2569
+ image_prompt_type_radio = gr.Radio( [("Use Images Ref", "I"),("a Video", "V"), ("Images + a Video", "IV"), ("Video + Video Mask", "VM"), ("Images + Video + Mask", "IVM")], value =image_prompt_type, label="Location", show_label= False, scale= 3, visible = vace_model)
2570
+ image_source1 = gr.Gallery(
2571
+ label="Reference Images of Faces and / or Object to be found in the Video", type ="pil",
2572
+ columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in image_prompt_type )
2573
 
2574
+ image_source2 = gr.Video(label= "Reference Video", visible= "V" in image_prompt_type )
2575
+ with gr.Row():
2576
+ max_frames = gr.Slider(0, 100, value=ui_defaults.get("max_frames",0), step=1, label="Nb of frames in Reference Video to use in Video (0 for as many as possible)", visible= "V" in image_prompt_type, scale = 2 )
2577
+ remove_background_image_ref = gr.Checkbox(value=ui_defaults.get("remove_background_image_ref",1), label= "Remove Images Ref. Background", visible= "I" in image_prompt_type, scale =1 )
2578
+
2579
+ image_source3 = gr.Video(label= "Video Mask (white pixels = Mask)", visible= "M" in image_prompt_type )
2580
+
2581
+
2582
+ gr.on(triggers=[image_prompt_type_radio.change, trigger_refresh_input_type.change], fn=refresh_t2v_image_prompt_type_radio, inputs=[state, image_prompt_type_radio], outputs=[image_prompt_column, image_prompt_type_radio, image_source1, image_source2, image_source3, max_frames, remove_background_image_ref])
2583
 
2584
 
2585
  advanced_prompt = advanced
 
2612
  wizard_prompt = gr.Textbox(visible = not advanced_prompt, label="Prompts (each new line of prompt will generate a new video, # lines = comments)", value=default_wizard_prompt, lines=3)
2613
  wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False)
2614
  wizard_variables_var = gr.Text(wizard_variables, visible = False)
 
2615
  with gr.Row():
2616
  if image2video:
2617
  resolution = gr.Dropdown(
 
2648
  video_length = gr.Slider(5, 193, value=ui_defaults["video_length"], step=4, label="Number of frames (16 = 1s)")
2649
  with gr.Column():
2650
  num_inference_steps = gr.Slider(1, 100, value=ui_defaults["num_inference_steps"], step=1, label="Number of Inference Steps")
 
 
2651
  show_advanced = gr.Checkbox(label="Advanced Mode", value=advanced)
2652
  with gr.Row(visible=advanced) as advanced_row:
2653
  with gr.Column():
 
2696
  tea_cache_start_step_perc = gr.Slider(0, 100, value=ui_defaults["tea_cache_start_step_perc"], step=1, label="Tea Cache starting moment in % of generation")
2697
 
2698
  with gr.Row():
2699
+ gr.Markdown("<B>Upsampling - postprocessing that may improve fluidity and the size of the video</B>")
2700
  with gr.Row():
2701
  temporal_upsampling_choice = gr.Dropdown(
2702
  choices=[
 
2778
  show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
2779
  fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
2780
  with gr.Column():
2781
+ gen_status = gr.Text(interactive= False)
2782
+ full_sync = gr.Text(interactive= False, visible= False)
2783
+ light_sync = gr.Text(interactive= False, visible= False)
2784
+
2785
  gen_progress_html = gr.HTML(
2786
  label="Status",
2787
  value="Idle",
 
2801
  abort_btn = gr.Button("Abort")
2802
 
2803
  queue_df = gr.DataFrame(
2804
+ headers=["Qty","Prompt", "Length","Steps","", "", "", "", ""],
2805
+ datatype=[ "str","markdown","str", "markdown", "markdown", "markdown", "str", "str", "str"],
2806
  column_widths= ["50","", "65","55", "60", "60", "30", "30", "35"],
2807
  interactive=False,
2808
  col_count=(9, "fixed"),
 
2884
  show_progress="hidden"
2885
  )
2886
  save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
2887
+ save_settings, inputs = [state, prompt, image_prompt_type_radio, max_frames, remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt,
2888
  loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling_choice, spatial_upsampling_choice, RIFLEx_setting, slg_switch, slg_layers,
2889
  slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step ], outputs = [])
2890
  save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
 
2900
  refresh_lora_btn2.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
2901
  output.select(select_video, state, None )
2902
 
2903
+
2904
+
2905
  gen_status.change(refresh_gallery,
2906
  inputs = [state, gen_status],
2907
  outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn])
2908
 
2909
+ full_sync.change(fn= check_refresh_input_type,
2910
+ inputs= [state],
2911
+ outputs= [trigger_refresh_input_type]
2912
+ ).then(fn=refresh_gallery,
2913
  inputs = [state, gen_status],
2914
  outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn]
2915
+ ).then(fn=wait_tasks_done,
2916
  inputs= [state],
2917
  outputs =[gen_status],
2918
  ).then(finalize_generation,
2919
  inputs= [state],
2920
  outputs= [output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info]
2921
  )
2922
+
2923
+ light_sync.change(fn= check_refresh_input_type,
2924
+ inputs= [state],
2925
+ outputs= [trigger_refresh_input_type]
2926
+ ).then(fn=refresh_gallery,
2927
  inputs = [state, gen_status],
2928
  outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn]
2929
  )
 
2949
  loras_choices,
2950
  loras_mult_choices,
2951
  image_prompt_type_radio,
2952
+ image_source1,
2953
+ image_source2,
2954
+ image_source3,
2955
  max_frames,
2956
+ remove_background_image_ref,
2957
  temporal_upsampling_choice,
2958
  spatial_upsampling_choice,
2959
  RIFLEx_setting,
 
3004
  )
3005
  return loras_column, loras_choices, presets_column, lset_name, header, light_sync, full_sync, state
3006
 
3007
+ def generate_download_tab(presets_column, loras_column, lset_name,loras_choices, state):
3008
  with gr.Row():
3009
  with gr.Row(scale =2):
3010
  gr.Markdown("<I>Wan2GP's Lora Festival ! Press the following button to download i2v <B>Remade</B> Loras collection (and bonuses Loras).")
 
3030
  ("WAN 2.1 1.3B Text to Video 16 bits (recommended)- the small model for fast generations with low VRAM requirements", 0),
3031
  ("WAN 2.1 14B Text to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 1),
3032
  ("WAN 2.1 14B Text to Video quantized to 8 bits (recommended) - the default engine but quantized", 2),
3033
+ ("WAN 2.1 VACE 1.3B Text to Video / Control Net - text generation driven by reference images or videos", 3),
3034
  ],
3035
  value= index,
3036
  label="Transformer model for Text to Video",
 
3211
  t2v_light_sync = gr.Text()
3212
  i2v_full_sync = gr.Text()
3213
  t2v_full_sync = gr.Text()
 
 
 
 
 
 
 
 
3214
 
3215
+ last_tab_was_image2video =global_state.get("last_tab_was_image2video", None)
3216
+ if last_tab_was_image2video == None or last_tab_was_image2video:
3217
+ gen = i2v_state["gen"]
3218
+ t2v_state["gen"] = gen
3219
+ else:
3220
+ gen = t2v_state["gen"]
3221
+ i2v_state["gen"] = gen
3222
 
3223
+
3224
+ if new_t2v or new_i2v:
3225
  if last_tab_was_image2video != None and new_t2v != new_i2v:
3226
  gen_location = gen.get("location", None)
3227
  if "in_progress" in gen and gen_location !=None and not (gen_location and new_i2v or not gen_location and new_t2v) :
 
3235
  else:
3236
  t2v_light_sync = gr.Text(str(time.time()))
3237
 
 
3238
  global_state["last_tab_was_image2video"] = new_i2v
3239
 
3240
  if(server_config.get("reload_model",2) == 1):
 
3536
  }
3537
  """
3538
  with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md"), title= "Wan2GP") as demo:
3539
+ gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v4.0 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
3540
  gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !</FONT>")
3541
 
3542
  with gr.Accordion("Click here for some Info on how to use Wan2GP", open = False):
 
3557
  i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header, i2v_light_sync, i2v_full_sync, i2v_state = generate_video_tab(True)
3558
  if not args.lock_config:
3559
  with gr.Tab("Downloads", id="downloads") as downloads_tab:
3560
+ generate_download_tab(i2v_presets_column, i2v_loras_column, i2v_lset_name, i2v_loras_choices, i2v_state)
3561
  with gr.Tab("Configuration"):
3562
  generate_configuration_tab()
3563
  with gr.Tab("About"):
requirements.txt CHANGED
@@ -11,11 +11,15 @@ easydict
11
  ftfy
12
  dashscope
13
  imageio-ffmpeg
14
- # flash_attn
15
  gradio>=5.0.0
16
  numpy>=1.23.5,<2
17
  einops
18
  moviepy==1.0.3
19
  mmgp==3.3.4
20
  peft==0.14.0
21
- mutagen
 
 
 
 
 
11
  ftfy
12
  dashscope
13
  imageio-ffmpeg
14
+ # flash_attn
15
  gradio>=5.0.0
16
  numpy>=1.23.5,<2
17
  einops
18
  moviepy==1.0.3
19
  mmgp==3.3.4
20
  peft==0.14.0
21
+ mutagen
22
+ decord
23
+ onnxruntime-gpu
24
+ rembg[gpu]==2.0.65
25
+ # rembg==2.0.65
wan/configs/__init__.py CHANGED
@@ -40,3 +40,17 @@ SUPPORTED_SIZES = {
40
  'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
41
  't2i-14B': tuple(SIZE_CONFIGS.keys()),
42
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
41
  't2i-14B': tuple(SIZE_CONFIGS.keys()),
42
  }
43
+
44
# Sizes accepted by the VACE 1.3B model. Each key is the "<a>*<b>" string
# used for lookups and maps to the matching (a, b) integer pair. This table
# is the single source of truth; the two tables below are derived from it so
# the three can never drift apart.
VACE_SIZE_CONFIGS = {
    '480*832': (480, 832),
    '832*480': (832, 480),
}

# Pixel count (a * b) for every supported size key.
VACE_MAX_AREA_CONFIGS = {
    size_key: dims[0] * dims[1] for size_key, dims in VACE_SIZE_CONFIGS.items()
}

# Size keys available per VACE model variant, in declaration order.
VACE_SUPPORTED_SIZES = {
    'vace-1.3B': tuple(VACE_SIZE_CONFIGS.keys()),
}
wan/modules/model.py CHANGED
@@ -377,6 +377,7 @@ class WanI2VCrossAttention(WanSelfAttention):
377
  return x
378
 
379
 
 
380
  WAN_CROSSATTENTION_CLASSES = {
381
  't2v_cross_attn': WanT2VCrossAttention,
382
  'i2v_cross_attn': WanI2VCrossAttention,
@@ -393,7 +394,9 @@ class WanAttentionBlock(nn.Module):
393
  window_size=(-1, -1),
394
  qk_norm=True,
395
  cross_attn_norm=False,
396
- eps=1e-6):
 
 
397
  super().__init__()
398
  self.dim = dim
399
  self.ffn_dim = ffn_dim
@@ -422,6 +425,7 @@ class WanAttentionBlock(nn.Module):
422
 
423
  # modulation
424
  self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
 
425
 
426
  def forward(
427
  self,
@@ -432,6 +436,8 @@ class WanAttentionBlock(nn.Module):
432
  freqs,
433
  context,
434
  context_lens,
 
 
435
  ):
436
  r"""
437
  Args:
@@ -480,10 +486,49 @@ class WanAttentionBlock(nn.Module):
480
  x.addcmul_(y, e[5])
481
 
482
 
483
-
484
- return x
485
-
486
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
487
  class Head(nn.Module):
488
 
489
  def __init__(self, dim, out_dim, patch_size, eps=1e-6):
@@ -544,6 +589,8 @@ class WanModel(ModelMixin, ConfigMixin):
544
 
545
  @register_to_config
546
  def __init__(self,
 
 
547
  model_type='t2v',
548
  patch_size=(1, 2, 2),
549
  text_len=512,
@@ -628,12 +675,13 @@ class WanModel(ModelMixin, ConfigMixin):
628
  self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
629
 
630
  # blocks
631
- cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
632
- self.blocks = nn.ModuleList([
633
- WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
634
- window_size, qk_norm, cross_attn_norm, eps)
635
- for _ in range(num_layers)
636
- ])
 
637
 
638
  # head
639
  self.head = Head(dim, out_dim, patch_size, eps)
@@ -646,6 +694,33 @@ class WanModel(ModelMixin, ConfigMixin):
646
  # initialize weights
647
  self.init_weights()
648
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
649
 
650
  def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
651
  rescale_func = np.poly1d(self.coefficients)
@@ -688,6 +763,36 @@ class WanModel(ModelMixin, ConfigMixin):
688
  self.rel_l1_thresh = best_threshold
689
  print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
690
  return best_threshold
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
691
 
692
  def forward(
693
  self,
@@ -695,6 +800,8 @@ class WanModel(ModelMixin, ConfigMixin):
695
  t,
696
  context,
697
  seq_len,
 
 
698
  clip_fea=None,
699
  y=None,
700
  freqs = None,
@@ -829,13 +936,23 @@ class WanModel(ModelMixin, ConfigMixin):
829
  self.previous_residual_cond = None
830
  ori_hidden_states = x_list[0].clone()
831
  # arguments
 
832
  kwargs = dict(
833
- # e=e0,
834
  seq_lens=seq_lens,
835
  grid_sizes=grid_sizes,
836
  freqs=freqs,
837
- # context=context,
838
  context_lens=context_lens)
 
 
 
 
 
 
 
 
 
 
 
839
  for block_idx, block in enumerate(self.blocks):
840
  offload.shared_state["layer"] = block_idx
841
  if callback != None:
@@ -852,9 +969,10 @@ class WanModel(ModelMixin, ConfigMixin):
852
  x_list[0] = block(x_list[0], context = context_list[0], e= e0, **kwargs)
853
 
854
  else:
855
- for i, (x, context) in enumerate(zip(x_list, context_list)):
856
- x_list[i] = block(x, context = context, e= e0, **kwargs)
857
  del x
 
858
 
859
  if self.enable_teacache:
860
  if joint_pass:
 
377
  return x
378
 
379
 
380
+
381
  WAN_CROSSATTENTION_CLASSES = {
382
  't2v_cross_attn': WanT2VCrossAttention,
383
  'i2v_cross_attn': WanI2VCrossAttention,
 
394
  window_size=(-1, -1),
395
  qk_norm=True,
396
  cross_attn_norm=False,
397
+ eps=1e-6,
398
+ block_id=None
399
+ ):
400
  super().__init__()
401
  self.dim = dim
402
  self.ffn_dim = ffn_dim
 
425
 
426
  # modulation
427
  self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
428
+ self.block_id = block_id
429
 
430
  def forward(
431
  self,
 
436
  freqs,
437
  context,
438
  context_lens,
439
+ hints= None,
440
+ context_scale=1.0,
441
  ):
442
  r"""
443
  Args:
 
486
  x.addcmul_(y, e[5])
487
 
488
 
489
+ if self.block_id is not None and hints != None:
490
+ if context_scale == 1:
491
+ x.add_(hints[self.block_id])
492
+ else:
493
+ x.add_(hints[self.block_id], alpha =context_scale)
494
+ return x
495
+
496
+ class VaceWanAttentionBlock(WanAttentionBlock):
497
+ def __init__(
498
+ self,
499
+ cross_attn_type,
500
+ dim,
501
+ ffn_dim,
502
+ num_heads,
503
+ window_size=(-1, -1),
504
+ qk_norm=True,
505
+ cross_attn_norm=False,
506
+ eps=1e-6,
507
+ block_id=0
508
+ ):
509
+ super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
510
+ self.block_id = block_id
511
+ if block_id == 0:
512
+ self.before_proj = nn.Linear(self.dim, self.dim)
513
+ nn.init.zeros_(self.before_proj.weight)
514
+ nn.init.zeros_(self.before_proj.bias)
515
+ self.after_proj = nn.Linear(self.dim, self.dim)
516
+ nn.init.zeros_(self.after_proj.weight)
517
+ nn.init.zeros_(self.after_proj.bias)
518
+
519
+ def forward(self, c, x, **kwargs):
520
+ # behold dbm magic !
521
+ if self.block_id == 0:
522
+ c = self.before_proj(c) + x
523
+ all_c = []
524
+ else:
525
+ all_c = c
526
+ c = all_c.pop(-1)
527
+ c = super().forward(c, **kwargs)
528
+ c_skip = self.after_proj(c)
529
+ all_c += [c_skip, c]
530
+ return all_c
531
+
532
  class Head(nn.Module):
533
 
534
  def __init__(self, dim, out_dim, patch_size, eps=1e-6):
 
589
 
590
  @register_to_config
591
  def __init__(self,
592
+ vace_layers=None,
593
+ vace_in_dim=None,
594
  model_type='t2v',
595
  patch_size=(1, 2, 2),
596
  text_len=512,
 
675
  self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
676
 
677
  # blocks
678
+ if vace_layers == None:
679
+ cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
680
+ self.blocks = nn.ModuleList([
681
+ WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
682
+ window_size, qk_norm, cross_attn_norm, eps)
683
+ for _ in range(num_layers)
684
+ ])
685
 
686
  # head
687
  self.head = Head(dim, out_dim, patch_size, eps)
 
694
  # initialize weights
695
  self.init_weights()
696
 
697
+ if vace_layers != None:
698
+ self.vace_layers = [i for i in range(0, self.num_layers, 2)] if vace_layers is None else vace_layers
699
+ self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim
700
+
701
+ assert 0 in self.vace_layers
702
+ self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}
703
+
704
+ # blocks
705
+ self.blocks = nn.ModuleList([
706
+ WanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
707
+ self.cross_attn_norm, self.eps,
708
+ block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None)
709
+ for i in range(self.num_layers)
710
+ ])
711
+
712
+ # vace blocks
713
+ self.vace_blocks = nn.ModuleList([
714
+ VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
715
+ self.cross_attn_norm, self.eps, block_id=i)
716
+ for i in self.vace_layers
717
+ ])
718
+
719
+ # vace patch embeddings
720
+ self.vace_patch_embedding = nn.Conv3d(
721
+ self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size
722
+ )
723
+
724
 
725
  def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
726
  rescale_func = np.poly1d(self.coefficients)
 
763
  self.rel_l1_thresh = best_threshold
764
  print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
765
  return best_threshold
766
+
767
+ def forward_vace(
768
+ self,
769
+ x,
770
+ vace_context,
771
+ seq_len,
772
+ context,
773
+ e,
774
+ kwargs
775
+ ):
776
+ # embeddings
777
+ c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
778
+ c = [u.flatten(2).transpose(1, 2) for u in c]
779
+ if (len(c) == 1 and seq_len == c[0].size(1)):
780
+ c = c[0]
781
+ else:
782
+ c = torch.cat([
783
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
784
+ dim=1) for u in c
785
+ ])
786
+
787
+ # arguments
788
+ new_kwargs = dict(x=x)
789
+ new_kwargs.update(kwargs)
790
+
791
+ for block in self.vace_blocks:
792
+ c = block(c, context= context, e= e, **new_kwargs)
793
+ hints = c[:-1]
794
+
795
+ return hints
796
 
797
  def forward(
798
  self,
 
800
  t,
801
  context,
802
  seq_len,
803
+ vace_context = None,
804
+ vace_context_scale=1.0,
805
  clip_fea=None,
806
  y=None,
807
  freqs = None,
 
936
  self.previous_residual_cond = None
937
  ori_hidden_states = x_list[0].clone()
938
  # arguments
939
+
940
  kwargs = dict(
 
941
  seq_lens=seq_lens,
942
  grid_sizes=grid_sizes,
943
  freqs=freqs,
 
944
  context_lens=context_lens)
945
+
946
+ if vace_context == None:
947
+ hints_list = [None ] *len(x_list)
948
+ else:
949
+ hints_list = []
950
+ for x, context in zip(x_list, context_list) :
951
+ hints_list.append( self.forward_vace(x, vace_context, seq_len, context= context, e= e0, kwargs= kwargs))
952
+ del x, context
953
+ kwargs['context_scale'] = vace_context_scale
954
+
955
+
956
  for block_idx, block in enumerate(self.blocks):
957
  offload.shared_state["layer"] = block_idx
958
  if callback != None:
 
969
  x_list[0] = block(x_list[0], context = context_list[0], e= e0, **kwargs)
970
 
971
  else:
972
+ for i, (x, context, hints) in enumerate(zip(x_list, context_list, hints_list)):
973
+ x_list[i] = block(x, context = context, hints= hints, e= e0, **kwargs)
974
  del x
975
+ del context, hints
976
 
977
  if self.enable_teacache:
978
  if joint_pass:
wan/text2video.py CHANGED
@@ -13,7 +13,9 @@ import torch
13
  import torch.cuda.amp as amp
14
  import torch.distributed as dist
15
  from tqdm import tqdm
16
-
 
 
17
  from .distributed.fsdp import shard_model
18
  from .modules.model import WanModel
19
  from .modules.t5 import T5EncoderModel
@@ -22,6 +24,7 @@ from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
22
  get_sampling_sigmas, retrieve_timesteps)
23
  from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
24
  from wan.modules.posemb_layers import get_rotary_pos_embed
 
25
 
26
 
27
  def optimized_scale(positive_flat, negative_flat):
@@ -105,8 +108,6 @@ class WanT2V:
105
 
106
  self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel, writable_tensors= False)
107
 
108
-
109
-
110
  self.model.eval().requires_grad_(False)
111
 
112
  if use_usp:
@@ -132,8 +133,148 @@ class WanT2V:
132
 
133
  self.sample_neg_prompt = config.sample_neg_prompt
134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  def generate(self,
136
  input_prompt,
 
 
 
 
137
  size=(1280, 720),
138
  frame_num=81,
139
  shift=5.0,
@@ -187,14 +328,6 @@ class WanT2V:
187
  - W: Frame width from size)
188
  """
189
  # preprocess
190
- F = frame_num
191
- target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
192
- size[1] // self.vae_stride[1],
193
- size[0] // self.vae_stride[2])
194
-
195
- seq_len = math.ceil((target_shape[2] * target_shape[3]) /
196
- (self.patch_size[1] * self.patch_size[2]) *
197
- target_shape[1] / self.sp_size) * self.sp_size
198
 
199
  if n_prompt == "":
200
  n_prompt = self.sample_neg_prompt
@@ -213,6 +346,29 @@ class WanT2V:
213
  context_null = self.text_encoder([n_prompt], torch.device('cpu'))
214
  context = [t.to(self.device) for t in context]
215
  context_null = [t.to(self.device) for t in context_null]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
217
  noise = [
218
  torch.randn(
@@ -261,10 +417,12 @@ class WanT2V:
261
  arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
262
  arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
263
  arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
 
 
 
 
 
264
 
265
- # arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
266
- # arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
267
- # arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
268
  if self.model.enable_teacache:
269
  self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
270
  if callback != None:
@@ -281,7 +439,7 @@ class WanT2V:
281
  # self.model.to(self.device)
282
  if joint_pass:
283
  noise_pred_cond, noise_pred_uncond = self.model(
284
- latent_model_input, t=timestep,current_step=i, slg_layers=slg_layers_local, **arg_both)
285
  if self._interrupt:
286
  return None
287
  else:
@@ -329,7 +487,11 @@ class WanT2V:
329
  self.model.cpu()
330
  torch.cuda.empty_cache()
331
  if self.rank == 0:
332
- videos = self.vae.decode(x0, VAE_tile_size)
 
 
 
 
333
 
334
 
335
  del noise, latents
 
13
  import torch.cuda.amp as amp
14
  import torch.distributed as dist
15
  from tqdm import tqdm
16
+ from PIL import Image
17
+ import torchvision.transforms.functional as TF
18
+ import torch.nn.functional as F
19
  from .distributed.fsdp import shard_model
20
  from .modules.model import WanModel
21
  from .modules.t5 import T5EncoderModel
 
24
  get_sampling_sigmas, retrieve_timesteps)
25
  from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
26
  from wan.modules.posemb_layers import get_rotary_pos_embed
27
+ from .utils.vace_preprocessor import VaceVideoProcessor
28
 
29
 
30
  def optimized_scale(positive_flat, negative_flat):
 
108
 
109
  self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel, writable_tensors= False)
110
 
 
 
111
  self.model.eval().requires_grad_(False)
112
 
113
  if use_usp:
 
133
 
134
  self.sample_neg_prompt = config.sample_neg_prompt
135
 
136
+ if "Vace" in model_filename:
137
+ self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
138
+ min_area=480*832,
139
+ max_area=480*832,
140
+ min_fps=config.sample_fps,
141
+ max_fps=config.sample_fps,
142
+ zero_start=True,
143
+ seq_len=32760,
144
+ keep_last=True)
145
+
146
+ def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0):
147
+ if ref_images is None:
148
+ ref_images = [None] * len(frames)
149
+ else:
150
+ assert len(frames) == len(ref_images)
151
+
152
+ if masks is None:
153
+ latents = self.vae.encode(frames, tile_size = tile_size)
154
+ else:
155
+ inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
156
+ reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
157
+ inactive = self.vae.encode(inactive, tile_size = tile_size)
158
+ reactive = self.vae.encode(reactive, tile_size = tile_size)
159
+ latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
160
+
161
+ cat_latents = []
162
+ for latent, refs in zip(latents, ref_images):
163
+ if refs is not None:
164
+ if masks is None:
165
+ ref_latent = self.vae.encode(refs, tile_size = tile_size)
166
+ else:
167
+ ref_latent = self.vae.encode(refs, tile_size = tile_size)
168
+ ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent]
169
+ assert all([x.shape[1] == 1 for x in ref_latent])
170
+ latent = torch.cat([*ref_latent, latent], dim=1)
171
+ cat_latents.append(latent)
172
+ return cat_latents
173
+
174
+ def vace_encode_masks(self, masks, ref_images=None):
175
+ if ref_images is None:
176
+ ref_images = [None] * len(masks)
177
+ else:
178
+ assert len(masks) == len(ref_images)
179
+
180
+ result_masks = []
181
+ for mask, refs in zip(masks, ref_images):
182
+ c, depth, height, width = mask.shape
183
+ new_depth = int((depth + 3) // self.vae_stride[0])
184
+ height = 2 * (int(height) // (self.vae_stride[1] * 2))
185
+ width = 2 * (int(width) // (self.vae_stride[2] * 2))
186
+
187
+ # reshape
188
+ mask = mask[0, :, :, :]
189
+ mask = mask.view(
190
+ depth, height, self.vae_stride[1], width, self.vae_stride[1]
191
+ ) # depth, height, 8, width, 8
192
+ mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width
193
+ mask = mask.reshape(
194
+ self.vae_stride[1] * self.vae_stride[2], depth, height, width
195
+ ) # 8*8, depth, height, width
196
+
197
+ # interpolation
198
+ mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0)
199
+
200
+ if refs is not None:
201
+ length = len(refs)
202
+ mask_pad = torch.zeros_like(mask[:, :length, :, :])
203
+ mask = torch.cat((mask_pad, mask), dim=1)
204
+ result_masks.append(mask)
205
+ return result_masks
206
+
207
+ def vace_latent(self, z, m):
208
+ return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
209
+
210
+ def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device, trim_video= 0):
211
+ image_sizes = []
212
+ for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
213
+ if sub_src_mask is not None and sub_src_video is not None:
214
+ src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video)
215
+ src_video[i] = src_video[i].to(device)
216
+ src_mask[i] = src_mask[i].to(device)
217
+ src_video_shape = src_video[i].shape
218
+ if src_video_shape[1] != num_frames:
219
+ src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
220
+ src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
221
+
222
+ src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
223
+ image_sizes.append(src_video[i].shape[2:])
224
+ elif sub_src_video is None:
225
+ src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
226
+ src_mask[i] = torch.ones_like(src_video[i], device=device)
227
+ image_sizes.append(image_size)
228
+ else:
229
+ src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video)
230
+ src_video[i] = src_video[i].to(device)
231
+ src_video_shape = src_video[i].shape
232
+ if src_video_shape[1] != num_frames:
233
+ src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
234
+ src_mask[i] = torch.ones_like(src_video[i], device=device)
235
+ image_sizes.append(src_video[i].shape[2:])
236
+
237
+ for i, ref_images in enumerate(src_ref_images):
238
+ if ref_images is not None:
239
+ image_size = image_sizes[i]
240
+ for j, ref_img in enumerate(ref_images):
241
+ if ref_img is not None:
242
+ ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
243
+ if ref_img.shape[-2:] != image_size:
244
+ canvas_height, canvas_width = image_size
245
+ ref_height, ref_width = ref_img.shape[-2:]
246
+ white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1]
247
+ scale = min(canvas_height / ref_height, canvas_width / ref_width)
248
+ new_height = int(ref_height * scale)
249
+ new_width = int(ref_width * scale)
250
+ resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
251
+ top = (canvas_height - new_height) // 2
252
+ left = (canvas_width - new_width) // 2
253
+ white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
254
+ ref_img = white_canvas
255
+ src_ref_images[i][j] = ref_img.to(device)
256
+ return src_video, src_mask, src_ref_images
257
+
258
+ def decode_latent(self, zs, ref_images=None, tile_size= 0 ):
259
+ if ref_images is None:
260
+ ref_images = [None] * len(zs)
261
+ else:
262
+ assert len(zs) == len(ref_images)
263
+
264
+ trimed_zs = []
265
+ for z, refs in zip(zs, ref_images):
266
+ if refs is not None:
267
+ z = z[:, len(refs):, :, :]
268
+ trimed_zs.append(z)
269
+
270
+ return self.vae.decode(trimed_zs, tile_size= tile_size)
271
+
272
  def generate(self,
273
  input_prompt,
274
+ input_frames= None,
275
+ input_masks = None,
276
+ input_ref_images = None,
277
+ context_scale=1.0,
278
  size=(1280, 720),
279
  frame_num=81,
280
  shift=5.0,
 
328
  - W: Frame width from size)
329
  """
330
  # preprocess
 
 
 
 
 
 
 
 
331
 
332
  if n_prompt == "":
333
  n_prompt = self.sample_neg_prompt
 
346
  context_null = self.text_encoder([n_prompt], torch.device('cpu'))
347
  context = [t.to(self.device) for t in context]
348
  context_null = [t.to(self.device) for t in context_null]
349
+
350
+ if input_frames != None:
351
+ # vace context encode
352
+ input_frames = [u.to(self.device) for u in input_frames]
353
+ input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images]
354
+ input_masks = [u.to(self.device) for u in input_masks]
355
+
356
+ z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size)
357
+ m0 = self.vace_encode_masks(input_masks, input_ref_images)
358
+ z = self.vace_latent(z0, m0)
359
+
360
+ target_shape = list(z0[0].shape)
361
+ target_shape[0] = int(target_shape[0] / 2)
362
+ else:
363
+ F = frame_num
364
+ target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
365
+ size[1] // self.vae_stride[1],
366
+ size[0] // self.vae_stride[2])
367
+
368
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) /
369
+ (self.patch_size[1] * self.patch_size[2]) *
370
+ target_shape[1] / self.sp_size) * self.sp_size
371
+
372
 
373
  noise = [
374
  torch.randn(
 
417
  arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
418
  arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
419
  arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
420
+ if input_frames != None:
421
+ vace_dict = {'vace_context' : z, 'vace_context_scale' : context_scale}
422
+ arg_c.update(vace_dict)
423
+ arg_null.update(vace_dict)
424
+ arg_both.update(vace_dict)
425
 
 
 
 
426
  if self.model.enable_teacache:
427
  self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
428
  if callback != None:
 
439
  # self.model.to(self.device)
440
  if joint_pass:
441
  noise_pred_cond, noise_pred_uncond = self.model(
442
+ latent_model_input, t=timestep, current_step=i, slg_layers=slg_layers_local, **arg_both)
443
  if self._interrupt:
444
  return None
445
  else:
 
487
  self.model.cpu()
488
  torch.cuda.empty_cache()
489
  if self.rank == 0:
490
+
491
+ if input_frames == None:
492
+ videos = self.vae.decode(x0, VAE_tile_size)
493
+ else:
494
+ videos = self.decode_latent(x0, input_ref_images, VAE_tile_size)
495
 
496
 
497
  del noise, latents
wan/utils/utils.py CHANGED
@@ -3,21 +3,70 @@ import argparse
3
  import binascii
4
  import os
5
  import os.path as osp
 
 
6
 
7
  import imageio
8
  import torch
 
9
  import torchvision
10
  from PIL import Image
11
  import numpy as np
 
 
12
 
13
  __all__ = ['cache_video', 'cache_image', 'str2bool']
14
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def resize_lanczos(img, h, w):
16
  img = Image.fromarray(np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8))
17
  img = img.resize((w,h), resample=Image.Resampling.LANCZOS)
18
  return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0)
19
 
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  def rand_name(length=8, suffix=''):
22
  name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
23
  if suffix:
 
3
  import binascii
4
  import os
5
  import os.path as osp
6
+ import torchvision.transforms.functional as TF
7
+ import torch.nn.functional as F
8
 
9
  import imageio
10
  import torch
11
+ import decord
12
  import torchvision
13
  from PIL import Image
14
  import numpy as np
15
+ from rembg import remove, new_session
16
+
17
 
18
  __all__ = ['cache_video', 'cache_image', 'str2bool']
19
 
20
+
21
+
22
+ from PIL import Image
23
+
24
+ def get_video_frame(file_name, frame_no):
25
+ decord.bridge.set_bridge('torch')
26
+ reader = decord.VideoReader(file_name)
27
+
28
+ frame = reader.get_batch([frame_no]).squeeze(0)
29
+ img = Image.fromarray(frame.numpy().astype(np.uint8))
30
+ return img
31
+
32
  def resize_lanczos(img, h, w):
33
  img = Image.fromarray(np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8))
34
  img = img.resize((w,h), resample=Image.Resampling.LANCZOS)
35
  return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0)
36
 
37
 
38
+ def remove_background(img, session=None):
39
+ if session ==None:
40
+ session = new_session()
41
+ img = Image.fromarray(np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8))
42
+ img = remove(img, session=session, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
43
+ return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0)
44
+
45
+
46
+
47
+
48
+ def resize_and_remove_background(img_list, canvas_width, canvas_height, rm_background ):
49
+ if rm_background:
50
+ session = new_session()
51
+
52
+ output_list =[]
53
+ for img in img_list:
54
+ width, height = img.size
55
+ white_canvas = np.full( (canvas_height, canvas_width, 3), 255, dtype= np.uint8 )
56
+ scale = min(canvas_height / height, canvas_width / width)
57
+ new_height = int(height * scale)
58
+ new_width = int(width * scale)
59
+ resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS)
60
+ if rm_background:
61
+ resized_image = remove(resized_image, session=session, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
62
+ top = (canvas_height - new_height) // 2
63
+ left = (canvas_width - new_width) // 2
64
+ white_canvas[top:top + new_height, left:left + new_width, :] = np.array(resized_image)
65
+ img = Image.fromarray(white_canvas)
66
+ output_list.append(img)
67
+ return output_list
68
+
69
+
70
  def rand_name(length=8, suffix=''):
71
  name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
72
  if suffix:
wan/utils/vace_preprocessor.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import numpy as np
4
+ from PIL import Image
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import torchvision.transforms.functional as TF
8
+
9
+
10
+ class VaceImageProcessor(object):
11
+ def __init__(self, downsample=None, seq_len=None):
12
+ self.downsample = downsample
13
+ self.seq_len = seq_len
14
+
15
+ def _pillow_convert(self, image, cvt_type='RGB'):
16
+ if image.mode != cvt_type:
17
+ if image.mode == 'P':
18
+ image = image.convert(f'{cvt_type}A')
19
+ if image.mode == f'{cvt_type}A':
20
+ bg = Image.new(cvt_type,
21
+ size=(image.width, image.height),
22
+ color=(255, 255, 255))
23
+ bg.paste(image, (0, 0), mask=image)
24
+ image = bg
25
+ else:
26
+ image = image.convert(cvt_type)
27
+ return image
28
+
29
+ def _load_image(self, img_path):
30
+ if img_path is None or img_path == '':
31
+ return None
32
+ img = Image.open(img_path)
33
+ img = self._pillow_convert(img)
34
+ return img
35
+
36
+ def _resize_crop(self, img, oh, ow, normalize=True):
37
+ """
38
+ Resize, center crop, convert to tensor, and normalize.
39
+ """
40
+ # resize and crop
41
+ iw, ih = img.size
42
+ if iw != ow or ih != oh:
43
+ # resize
44
+ scale = max(ow / iw, oh / ih)
45
+ img = img.resize(
46
+ (round(scale * iw), round(scale * ih)),
47
+ resample=Image.Resampling.LANCZOS
48
+ )
49
+ assert img.width >= ow and img.height >= oh
50
+
51
+ # center crop
52
+ x1 = (img.width - ow) // 2
53
+ y1 = (img.height - oh) // 2
54
+ img = img.crop((x1, y1, x1 + ow, y1 + oh))
55
+
56
+ # normalize
57
+ if normalize:
58
+ img = TF.to_tensor(img).sub_(0.5).div_(0.5).unsqueeze(1)
59
+ return img
60
+
61
+ def _image_preprocess(self, img, oh, ow, normalize=True, **kwargs):
62
+ return self._resize_crop(img, oh, ow, normalize)
63
+
64
+ def load_image(self, data_key, **kwargs):
65
+ return self.load_image_batch(data_key, **kwargs)
66
+
67
+ def load_image_pair(self, data_key, data_key2, **kwargs):
68
+ return self.load_image_batch(data_key, data_key2, **kwargs)
69
+
70
+ def load_image_batch(self, *data_key_batch, normalize=True, seq_len=None, **kwargs):
71
+ seq_len = self.seq_len if seq_len is None else seq_len
72
+ imgs = []
73
+ for data_key in data_key_batch:
74
+ img = self._load_image(data_key)
75
+ imgs.append(img)
76
+ w, h = imgs[0].size
77
+ dh, dw = self.downsample[1:]
78
+
79
+ # compute output size
80
+ scale = min(1., np.sqrt(seq_len / ((h / dh) * (w / dw))))
81
+ oh = int(h * scale) // dh * dh
82
+ ow = int(w * scale) // dw * dw
83
+ assert (oh // dh) * (ow // dw) <= seq_len
84
+ imgs = [self._image_preprocess(img, oh, ow, normalize) for img in imgs]
85
+ return *imgs, (oh, ow)
86
+
87
+
88
+ class VaceVideoProcessor(object):
89
+ def __init__(self, downsample, min_area, max_area, min_fps, max_fps, zero_start, seq_len, keep_last, **kwargs):
90
+ self.downsample = downsample
91
+ self.min_area = min_area
92
+ self.max_area = max_area
93
+ self.min_fps = min_fps
94
+ self.max_fps = max_fps
95
+ self.zero_start = zero_start
96
+ self.keep_last = keep_last
97
+ self.seq_len = seq_len
98
+ assert seq_len >= min_area / (self.downsample[1] * self.downsample[2])
99
+
100
+ @staticmethod
101
+ def resize_crop(video: torch.Tensor, oh: int, ow: int):
102
+ """
103
+ Resize, center crop and normalize for decord loaded video (torch.Tensor type)
104
+
105
+ Parameters:
106
+ video - video to process (torch.Tensor): Tensor from `reader.get_batch(frame_ids)`, in shape of (T, H, W, C)
107
+ oh - target height (int)
108
+ ow - target width (int)
109
+
110
+ Returns:
111
+ The processed video (torch.Tensor): Normalized tensor range [-1, 1], in shape of (C, T, H, W)
112
+
113
+ Raises:
114
+ """
115
+ # permute ([t, h, w, c] -> [t, c, h, w])
116
+ video = video.permute(0, 3, 1, 2)
117
+
118
+ # resize and crop
119
+ ih, iw = video.shape[2:]
120
+ if ih != oh or iw != ow:
121
+ # resize
122
+ scale = max(ow / iw, oh / ih)
123
+ video = F.interpolate(
124
+ video,
125
+ size=(round(scale * ih), round(scale * iw)),
126
+ mode='bicubic',
127
+ antialias=True
128
+ )
129
+ assert video.size(3) >= ow and video.size(2) >= oh
130
+
131
+ # center crop
132
+ x1 = (video.size(3) - ow) // 2
133
+ y1 = (video.size(2) - oh) // 2
134
+ video = video[:, :, y1:y1 + oh, x1:x1 + ow]
135
+
136
+ # permute ([t, c, h, w] -> [c, t, h, w]) and normalize
137
+ video = video.transpose(0, 1).float().div_(127.5).sub_(1.)
138
+ return video
139
+
140
+ def _video_preprocess(self, video, oh, ow):
141
+ return self.resize_crop(video, oh, ow)
142
+
143
+ def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box, rng):
144
+ target_fps = min(fps, self.max_fps)
145
+ duration = frame_timestamps[-1].mean()
146
+ x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
147
+ h, w = y2 - y1, x2 - x1
148
+ ratio = h / w
149
+ df, dh, dw = self.downsample
150
+
151
+ # min/max area of the [latent video]
152
+ min_area_z = self.min_area / (dh * dw)
153
+ max_area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
154
+
155
+ # sample a frame number of the [latent video]
156
+ rand_area_z = np.square(np.power(2, rng.uniform(
157
+ np.log2(np.sqrt(min_area_z)),
158
+ np.log2(np.sqrt(max_area_z))
159
+ )))
160
+ of = min(
161
+ (int(duration * target_fps) - 1) // df + 1,
162
+ int(self.seq_len / rand_area_z)
163
+ )
164
+
165
+ # deduce target shape of the [latent video]
166
+ target_area_z = min(max_area_z, int(self.seq_len / of))
167
+ oh = round(np.sqrt(target_area_z * ratio))
168
+ ow = int(target_area_z / oh)
169
+ of = (of - 1) * df + 1
170
+ oh *= dh
171
+ ow *= dw
172
+
173
+ # sample frame ids
174
+ target_duration = of / target_fps
175
+ begin = 0. if self.zero_start else rng.uniform(0, duration - target_duration)
176
+ timestamps = np.linspace(begin, begin + target_duration, of)
177
+ frame_ids = np.argmax(np.logical_and(
178
+ timestamps[:, None] >= frame_timestamps[None, :, 0],
179
+ timestamps[:, None] < frame_timestamps[None, :, 1]
180
+ ), axis=1).tolist()
181
+ return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
182
+
183
+ def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w, crop_box, rng, max_frames= 0):
184
+ import math
185
+ target_fps = self.max_fps
186
+ video_duration = frame_timestamps[-1][1]
187
+ video_frame_duration = 1 /fps
188
+ target_frame_duration = 1 / target_fps
189
+
190
+ cur_time = 0
191
+ target_time = 0
192
+ frame_no = 0
193
+ frame_ids =[]
194
+ for i in range(max_frames):
195
+ add_frames_count = math.ceil( (target_time -cur_time) / video_frame_duration )
196
+ frame_no += add_frames_count
197
+ frame_ids.append(frame_no)
198
+ cur_time += add_frames_count * video_frame_duration
199
+ target_time += target_frame_duration
200
+ if cur_time > video_duration:
201
+ break
202
+
203
+ x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
204
+ h, w = y2 - y1, x2 - x1
205
+ ratio = h / w
206
+ df, dh, dw = self.downsample
207
+ seq_len = self.seq_len
208
+ # min/max area of the [latent video]
209
+ min_area_z = self.min_area / (dh * dw)
210
+ # max_area_z = min(seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
211
+ max_area_z = min_area_z # workaround bug
212
+ # sample a frame number of the [latent video]
213
+ rand_area_z = np.square(np.power(2, rng.uniform(
214
+ np.log2(np.sqrt(min_area_z)),
215
+ np.log2(np.sqrt(max_area_z))
216
+ )))
217
+
218
+ seq_len = max_area_z * ((max_frames- 1) // df +1)
219
+
220
+ # of = min(
221
+ # (len(frame_ids) - 1) // df + 1,
222
+ # int(seq_len / rand_area_z)
223
+ # )
224
+ of = (len(frame_ids) - 1) // df + 1
225
+
226
+
227
+ # deduce target shape of the [latent video]
228
+ # target_area_z = min(max_area_z, int(seq_len / of))
229
+ target_area_z = max_area_z
230
+ oh = round(np.sqrt(target_area_z * ratio))
231
+ ow = int(target_area_z / oh)
232
+ of = (of - 1) * df + 1
233
+ oh *= dh
234
+ ow *= dw
235
+
236
+ return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
237
+
238
+ def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng, max_frames= 0):
239
+ if self.keep_last:
240
+ return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h, w, crop_box, rng, max_frames= max_frames)
241
+ else:
242
+ return self._get_frameid_bbox_default(fps, frame_timestamps, h, w, crop_box, rng, max_frames= max_frames)
243
+
244
def load_video(self, data_key, crop_box=None, seed=2024, **kwargs):
    """
    Load a single video; convenience wrapper around `load_video_batch`.

    `data_key` is passed as the only positional key, `crop_box`, `seed` and any
    extra keyword arguments are forwarded as-is, and the batch loader's return
    value is returned unchanged.
    """
    forwarded = dict(kwargs, crop_box=crop_box, seed=seed)
    return self.load_video_batch(data_key, **forwarded)
246
+
247
def load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs):
    """
    Load two videos together; convenience wrapper around `load_video_batch`.

    Both keys are passed positionally in the given order, `crop_box`, `seed` and
    any extra keyword arguments are forwarded as-is, and the batch loader's
    return value is returned unchanged.
    """
    forwarded = dict(kwargs, crop_box=crop_box, seed=seed)
    return self.load_video_batch(data_key, data_key2, **forwarded)
249
+
250
def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, max_frames=0, trim_video=0, **kwargs):
    """
    Read the same frames from one or more videos and preprocess them identically.

    Frame rate, timestamps and frame size are taken from the first video; the
    number of usable frames is the shortest length across all videos.

    Parameters:
        data_key_batch - one or more sources accepted by `decord.VideoReader`
        crop_box - optional (x1, x2, y1, y2) crop applied before preprocessing
        seed - base seed for the sampling generator
        max_frames - upper bound on sampled frames, forwarded to `_get_frameid_bbox`
        trim_video - when > 0, `max_frames` is additionally capped to this value
        kwargs - accepted and ignored

    Returns:
        A flat tuple: one processed tensor per input video, followed by
        frame_ids, (oh, ow) and the target fps
    """
    # NOTE(review): `hash()` of a str is salted per process unless PYTHONHASHSEED
    # is set, so this seed is presumably not reproducible across runs - confirm.
    rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000)

    # read video
    import decord
    decord.bridge.set_bridge('torch')
    readers = [decord.VideoReader(key) for key in data_key_batch]
    primary = readers[0]

    fps = primary.get_avg_fps()
    length = min(len(reader) for reader in readers)
    frame_timestamps = np.array(
        [primary.get_frame_timestamp(index) for index in range(length)],
        dtype=np.float32,
    )
    if trim_video > 0:
        max_frames = min(max_frames, trim_video)
    h, w = primary.next().shape[:2]
    frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(
        fps, frame_timestamps, h, w, crop_box, rng, max_frames=max_frames)

    # preprocess video: crop every clip first, then resize/normalize each one
    cropped = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers]
    videos = [self._video_preprocess(clip, oh, ow) for clip in cropped]
    return (*videos, frame_ids, (oh, ow), fps)
276
+
277
+
278
def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size, device):
    """
    Fill in missing video/mask inputs and letterbox reference images onto a canvas.

    The three input lists are modified in place and also returned.

    Parameters:
        src_video - list of video tensors or None
        src_mask - list of mask tensors or None, paired with `src_video`
        src_ref_images - list whose items are None or lists of reference image
            tensors (or None); each image is indexed as (C, 1, H, W) here
        num_frames - number of frames for generated placeholder video/mask
        image_size - (height, width) of the target canvas; any 2-item sequence
        device - device for newly created tensors

    Returns:
        (src_video, src_mask, src_ref_images)
    """
    # An entry with neither video nor mask becomes an all-zero video with an all-one mask.
    for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
        if sub_src_video is None and sub_src_mask is None:
            src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
            src_mask[i] = torch.ones((1, num_frames, image_size[0], image_size[1]), device=device)

    for ref_images in src_ref_images:
        if ref_images is None:
            continue
        for j, ref_img in enumerate(ref_images):
            if ref_img is None:
                continue
            canvas_height, canvas_width = image_size
            ref_height, ref_width = ref_img.shape[-2:]
            # Compare plain ints: a torch.Size never equals a list, which used to
            # force a pointless re-canvas when image_size was passed as a list.
            if (ref_height, ref_width) == (canvas_height, canvas_width):
                continue
            # canvas filled with 1.0 (white, presumably in a [-1, 1] value range)
            white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device)
            scale = min(canvas_height / ref_height, canvas_width / ref_width)
            # keep at least one pixel so extreme aspect ratios cannot yield a 0-sized resize
            new_height = max(1, int(ref_height * scale))
            new_width = max(1, int(ref_width * scale))
            resized_image = F.interpolate(
                ref_img.squeeze(1).unsqueeze(0),
                size=(new_height, new_width),
                mode='bilinear',
                align_corners=False,
            ).squeeze(0).unsqueeze(1)
            # paste the resized image centered on the canvas
            top = (canvas_height - new_height) // 2
            left = (canvas_width - new_width) // 2
            white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
            ref_images[j] = white_canvas
    return src_video, src_mask, src_ref_images