DeepBeepMeep commited on
Commit
cb1518b
·
1 Parent(s): 90fc871

Queue adaptations

Browse files
Files changed (4) hide show
  1. gradio_server.py +1015 -550
  2. wan/image2video.py +45 -18
  3. wan/modules/model.py +3 -4
  4. wan/text2video.py +2 -2
gradio_server.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import time
 
3
  import threading
4
  import argparse
5
  from mmgp import offload, safetensors2, profile_type
@@ -33,15 +34,12 @@ mmgp_version = version("mmgp")
33
  if mmgp_version != target_mmgp_version:
34
  print(f"Incorrect version of mmgp ({mmgp_version}), version {target_mmgp_version} is needed. Please upgrade with the command 'pip install -r requirements.txt'")
35
  exit()
36
- queue = []
37
  lock = threading.Lock()
38
  current_task_id = None
39
  task_id = 0
40
- progress_tracker = {}
41
- tracker_lock = threading.Lock()
42
- file_list = []
43
  last_model_type = None
44
- last_status_string = ""
45
 
46
  def format_time(seconds):
47
  if seconds < 60:
@@ -77,37 +75,6 @@ def pil_to_base64_uri(pil_image, format="png", quality=75):
77
  print(f"Error converting PIL to base64: {e}")
78
  return None
79
 
80
- def runner():
81
- global current_task_id
82
- while True:
83
- with lock:
84
- for item in queue:
85
- task_id_runner = item['id']
86
- with tracker_lock:
87
- progress = progress_tracker.get(task_id_runner, {})
88
-
89
- if item['state'] == "Processing":
90
- current_step = progress.get('current_step', 0)
91
- total_steps = progress.get('total_steps', 0)
92
- elapsed = time.time() - progress.get('start_time', time.time())
93
- status = progress.get('status', "")
94
- repeats = progress.get("repeats", "0/0")
95
- item.update({
96
- 'progress': f"{((current_step/total_steps)*100 if total_steps > 0 else 0):.1f}%",
97
- 'steps': f"{current_step}/{total_steps}",
98
- 'time': format_time(elapsed),
99
- 'repeats': f"{repeats}",
100
- 'status': f"{status}"
101
- })
102
- if not any(item['state'] == "Processing" for item in queue):
103
- for item in queue:
104
- if item['state'] == "Queued":
105
- item['status'] = "Processing"
106
- item['state'] = "Processing"
107
- current_task_id = item['id']
108
- threading.Thread(target=process_task, args=(item,)).start()
109
- break
110
- time.sleep(1)
111
 
112
  def process_prompt_and_add_tasks(
113
  prompt,
@@ -137,161 +104,290 @@ def process_prompt_and_add_tasks(
137
  slg_end,
138
  cfg_star_switch,
139
  cfg_zero_step,
140
- state_arg,
141
  image2video
142
  ):
143
-
144
- if state_arg.get("validate_success",0) != 1:
145
- print("Validation failed, not adding tasks.")
146
  return
 
 
147
  if len(prompt) ==0:
148
  return
149
  prompt, errors = prompt_parser.process_template(prompt)
150
  if len(errors) > 0:
151
- print("Error processing prompt template: " + errors)
152
  return
153
  prompts = prompt.replace("\r", "").split("\n")
154
  prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")]
155
  if len(prompts) ==0:
156
  return
157
 
158
- for single_prompt in prompts:
159
- task_params = (
160
- single_prompt,
161
- negative_prompt,
162
- resolution,
163
- video_length,
164
- seed,
165
- num_inference_steps,
166
- guidance_scale,
167
- flow_shift,
168
- embedded_guidance_scale,
169
- repeat_generation,
170
- multi_images_gen_type,
171
- tea_cache,
172
- tea_cache_start_step_perc,
173
- loras_choices,
174
- loras_mult_choices,
175
- image_prompt_type,
176
- image_to_continue,
177
- image_to_end,
178
- video_to_continue,
179
- max_frames,
180
- RIFLEx_setting,
181
- slg_switch,
182
- slg_layers,
183
- slg_start,
184
- slg_end,
185
- cfg_star_switch,
186
- cfg_zero_step,
187
- state_arg,
188
- image2video
189
- )
190
- add_video_task(*task_params)
191
- return update_queue_data()
192
 
193
- def process_task(task):
194
- try:
195
- task_id, *params = task['params']
196
- generate_video(task_id, *params)
197
- finally:
198
- with lock:
199
- queue[:] = [item for item in queue if item['id'] != task['id']]
200
- with tracker_lock:
201
- if task['id'] in progress_tracker:
202
- del progress_tracker[task['id']]
203
-
204
- def add_video_task(*params):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  global task_id
206
- with lock:
207
- task_id += 1
208
- current_task_id = task_id
209
- start_image_data = params[16] if len(params) > 16 else None
210
- end_image_data = params[17] if len(params) > 17 else None
211
-
212
- queue.append({
213
- "id": current_task_id,
214
- "params": (current_task_id,) + params,
215
- "state": "Queued",
216
- "status": "Queued",
217
- "repeats": "0/0",
218
- "progress": "0.0%",
219
- "steps": f"0/{params[5]}",
220
- "time": "--",
221
- "prompt": params[0],
222
- "start_image_data": start_image_data,
223
- "end_image_data": end_image_data
224
- })
225
- return update_queue_data()
226
-
227
- def move_up(selected_indices):
 
 
228
  if not selected_indices or len(selected_indices) == 0:
229
- return update_queue_data()
230
  idx = selected_indices[0]
231
  if isinstance(idx, list):
232
  idx = idx[0]
233
  idx = int(idx)
234
  with lock:
235
  if idx > 0:
 
236
  queue[idx], queue[idx-1] = queue[idx-1], queue[idx]
237
- return update_queue_data()
238
 
239
- def move_down(selected_indices):
240
  if not selected_indices or len(selected_indices) == 0:
241
- return update_queue_data()
242
  idx = selected_indices[0]
243
  if isinstance(idx, list):
244
  idx = idx[0]
245
  idx = int(idx)
246
  with lock:
 
247
  if idx < len(queue)-1:
248
  queue[idx], queue[idx+1] = queue[idx+1], queue[idx]
249
- return update_queue_data()
250
 
251
- def remove_task(selected_indices):
252
  if not selected_indices or len(selected_indices) == 0:
253
- return update_queue_data()
254
  idx = selected_indices[0]
255
  if isinstance(idx, list):
256
  idx = idx[0]
257
- idx = int(idx)
258
  with lock:
259
  if idx < len(queue):
260
  if idx == 0:
261
  wan_model._interrupt = True
262
  del queue[idx]
263
- return update_queue_data()
264
 
265
- def update_queue_data():
266
- with lock:
267
- data = []
268
- for item in queue:
269
- truncated_prompt = (item['prompt'][:97] + '...') if len(item['prompt']) > 100 else item['prompt']
270
- full_prompt = item['prompt'].replace('"', '&quot;')
271
- prompt_cell = f'<span title="{full_prompt}">{truncated_prompt}</span>'
272
- start_img_uri = pil_to_base64_uri(item.get('start_image_data'), format="jpeg", quality=70)
273
- end_img_uri = pil_to_base64_uri(item.get('end_image_data'), format="jpeg", quality=70)
274
- thumbnail_size = "50px"
275
- start_img_md = ""
276
- end_img_md = ""
277
- if start_img_uri:
278
- start_img_md = f'<img src="{start_img_uri}" alt="Start" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
279
- if end_img_uri:
280
- end_img_md = f'<img src="{end_img_uri}" alt="End" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
281
- data.append([
282
- item.get('status', "Starting"),
283
- item.get('repeats', "0/0"),
284
- item.get('progress', "0.0%"),
285
- item.get('steps', ''),
286
- item.get('time', '--'),
287
- prompt_cell,
288
- start_img_md,
289
- end_img_md,
290
- "↑",
291
- "",
292
- ""
293
- ])
294
- return data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
 
296
  def create_html_progress_bar(percentage=0.0, text="Idle", is_idle=True):
297
  bar_class = "progress-bar-custom idle" if is_idle else "progress-bar-custom"
@@ -306,35 +402,35 @@ def create_html_progress_bar(percentage=0.0, text="Idle", is_idle=True):
306
  """
307
  return html
308
 
309
- def refresh_progress():
310
- global current_task_id, progress_tracker, last_status_string
311
- task_id_to_check = current_task_id
312
- is_idle = True
313
- status_string = "Starting..."
314
- progress_percent = 0.0
315
- html_content = ""
316
-
317
- with tracker_lock:
318
- with lock:
319
- processing_or_queued = any(item['state'] in ["Processing", "Queued"] for item in queue)
320
- if task_id_to_check is not None:
321
- progress_data = progress_tracker.get(task_id_to_check)
322
- if progress_data:
323
- is_idle = False
324
- current_step = progress_data.get('current_step', 0)
325
- total_steps = progress_data.get('total_steps', 0)
326
- status = progress_data.get('status', "Starting...")
327
- repeats = progress_data.get("repeats", "0/0")
328
-
329
- if total_steps > 0:
330
- progress_float = min(1.0, max(0.0, float(current_step) / float(total_steps)))
331
- progress_percent = progress_float * 100
332
- status_string = f"{status} [{repeats}] - {progress_percent:.1f}% complete ({current_step}/{total_steps} steps)"
333
- else:
334
- progress_percent = 0.0
335
- status_string = f"{status} [{repeats}] - Initializing..."
336
- html_content = create_html_progress_bar(progress_percent, status_string, is_idle)
337
- return gr.update(value=html_content)
338
 
339
  def update_generation_status(html_content):
340
  if(html_content):
@@ -736,7 +832,6 @@ if args.i2v_1_3B:
736
 
737
  only_allow_edit_in_advanced = False
738
  lora_preselected_preset = args.lora_preset
739
- lora_preselected_preset_for_i2v = use_image2video
740
  # if args.fast : #or args.fastest
741
  # transformer_filename_t2v = transformer_choices_t2v[2]
742
  # attention_mode="sage2" if "sage2" in attention_modes_supported else "sage"
@@ -749,7 +844,6 @@ if args.compile: #args.fastest or
749
  lock_ui_compile = True
750
 
751
  model_filename = ""
752
- lora_model_filename = ""
753
  #attention_mode="sage"
754
  #attention_mode="sage2"
755
  #attention_mode="flash"
@@ -758,15 +852,12 @@ lora_model_filename = ""
758
  # compile = "transformer"
759
 
760
  def preprocess_loras(sd):
761
- if not use_image2video:
762
- return sd
763
-
764
- new_sd = {}
765
  first = next(iter(sd), None)
766
  if first == None:
767
  return sd
768
- if not first.startswith("lora_unet_"):
769
  return sd
 
770
  print("Converting Lora Safetensors format to Lora Diffusers format")
771
  alphas = {}
772
  repl_list = ["cross_attn", "self_attn", "ffn"]
@@ -845,14 +936,14 @@ download_models(transformer_filename_i2v if use_image2video else transformer_fil
845
  def sanitize_file_name(file_name, rep =""):
846
  return file_name.replace("/",rep).replace("\\",rep).replace(":",rep).replace("|",rep).replace("?",rep).replace("<",rep).replace(">",rep).replace("\"",rep)
847
 
848
- def extract_preset(lset_name, loras):
849
  loras_choices = []
850
  loras_choices_files = []
851
  loras_mult_choices = ""
852
  prompt =""
853
  full_prompt =""
854
  lset_name = sanitize_file_name(lset_name)
855
- lora_dir = get_lora_dir(use_image2video)
856
  if not lset_name.endswith(".lset"):
857
  lset_name_filename = os.path.join(lora_dir, lset_name + ".lset" )
858
  else:
@@ -923,7 +1014,7 @@ def setup_loras(i2v, transformer, lora_dir, lora_preselected_preset, split_line
923
  if not os.path.isfile(os.path.join(lora_dir, lora_preselected_preset + ".lset")):
924
  raise Exception(f"Unknown preset '{lora_preselected_preset}'")
925
  default_lora_preset = lora_preselected_preset
926
- default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, _ , error = extract_preset(default_lora_preset, loras)
927
  if len(error) > 0:
928
  print(error[:200])
929
  return loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset
@@ -1010,7 +1101,7 @@ def load_models(i2v):
1010
  # kwargs["partialPinning"] = True
1011
  elif profile == 3:
1012
  kwargs["budgets"] = { "*" : "70%" }
1013
- offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", **kwargs)
1014
  if len(args.gpu) > 0:
1015
  torch.set_default_device(args.gpu)
1016
 
@@ -1087,7 +1178,7 @@ def apply_changes( state,
1087
  if gen_in_progress:
1088
  yield "<DIV ALIGN=CENTER>Unable to change config when a generation is in progress</DIV>"
1089
  return
1090
- global offloadobj, wan_model, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset, loras_presets
1091
  server_config = {"attention_mode" : attention_choice,
1092
  "transformer_filename": transformer_choices_t2v[transformer_t2v_choice],
1093
  "transformer_filename_i2v": transformer_choices_i2v[transformer_i2v_choice],
@@ -1152,44 +1243,152 @@ def save_video(final_frames, output_path, fps=24):
1152
  final_frames = (final_frames * 255).astype(np.uint8)
1153
  ImageSequenceClip(list(final_frames), fps=fps).write_videofile(output_path, verbose= False, logger = None)
1154
 
1155
- def build_callback(taskid, state, pipe, num_inference_steps, repeats):
1156
- start_time = time.time()
1157
- def update_progress(step_idx, _):
1158
- with tracker_lock:
1159
- step_idx += 1
1160
- if state.get("abort", False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1161
  # pipe._interrupt = True
1162
- phase = "Aborting"
1163
  elif step_idx == num_inference_steps:
1164
- phase = "VAE Decoding"
1165
  else:
1166
- phase = "Denoising"
1167
- elapsed = time.time() - start_time
1168
- progress_tracker[taskid] = {
1169
- 'current_step': step_idx,
1170
- 'total_steps': num_inference_steps,
1171
- 'start_time': start_time,
1172
- 'last_update': time.time(),
1173
- 'repeats': repeats,
1174
- 'status': phase
1175
- }
1176
- return update_progress
1177
-
1178
- def refresh_gallery(state):
1179
- return gr.update(value=state.get("file_list", []))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1180
 
1181
  def refresh_gallery_on_trigger(state):
1182
- if(state.get("update_gallery", False)):
1183
- state['update_gallery'] = False
1184
- return gr.update(value=state.get("file_list", []))
 
 
1185
 
1186
  def select_video(state , event_data: gr.EventData):
1187
  data= event_data._data
 
 
1188
  if data!=None:
1189
  choice = data.get("index",0)
1190
- file_list = state.get("file_list", [])
1191
- state["last_selected"] = (choice + 1) >= len(file_list)
1192
- state["selected"] = choice
1193
  return
1194
 
1195
  def expand_slist(slist, num_inference_steps ):
@@ -1221,6 +1420,7 @@ def convert_image(image):
1221
 
1222
  def generate_video(
1223
  task_id,
 
1224
  prompt,
1225
  negative_prompt,
1226
  resolution,
@@ -1254,22 +1454,29 @@ def generate_video(
1254
  ):
1255
 
1256
  global wan_model, offloadobj, reload_needed, last_model_type
 
 
 
 
 
1257
  file_model_needed = model_needed(image2video)
1258
- with lock:
1259
- queue_not_empty = len(queue) > 0
1260
- if(last_model_type != image2video and (queue_not_empty or server_config.get("reload_model",1) == 2) and (file_model_needed != model_filename or reload_needed)):
 
 
1261
  del wan_model
1262
  if offloadobj is not None:
1263
  offloadobj.release()
1264
  del offloadobj
1265
  gc.collect()
1266
- print(f"Loading model {get_model_name(file_model_needed)}...")
1267
  wan_model, offloadobj, trans = load_models(image2video)
1268
- print(f"Model loaded")
1269
  reload_needed= False
1270
 
1271
  if wan_model == None:
1272
- raise gr.Error("Unable to generate a Video while a new configuration is being applied.")
1273
  if attention_mode == "auto":
1274
  attn = get_auto_attention()
1275
  elif attention_mode in attention_modes_supported:
@@ -1278,26 +1485,15 @@ def generate_video(
1278
  gr.Info(f"You have selected attention mode '{attention_mode}'. However it is not installed or supported on your system. You should either install it or switch to the default 'sdpa' attention.")
1279
  return
1280
 
1281
- raw_resolution = resolution
1282
- width, height = resolution.split("x")
1283
- width, height = int(width), int(height)
 
 
1284
 
1285
  if slg_switch == 0:
1286
  slg_layers = None
1287
- if image2video:
1288
- if "480p" in model_filename and not "Fun" in model_filename and width * height > 848*480:
1289
- gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P")
1290
- return
1291
 
1292
- resolution = str(width) + "*" + str(height)
1293
- if resolution not in ['720*1280', '1280*720', '480*832', '832*480']:
1294
- gr.Info(f"Resolution {resolution} not supported by image 2 video")
1295
- return
1296
-
1297
- if "1.3B" in model_filename and width * height > 848*480:
1298
- gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P")
1299
- return
1300
-
1301
  offload.shared_state["_attention"] = attn
1302
 
1303
  # VAE Tiling
@@ -1321,16 +1517,7 @@ def generate_video(
1321
 
1322
  trans = wan_model.model
1323
 
1324
- global gen_in_progress
1325
- gen_in_progress = True
1326
  temp_filename = None
1327
- if image2video:
1328
- if video_to_continue != None and len(video_to_continue) >0 :
1329
- input_image_or_video_path = video_to_continue
1330
- # pipeline.num_input_frames = max_frames
1331
- # pipeline.max_frames = max_frames
1332
- else:
1333
- input_image_or_video_path = None
1334
 
1335
  loras = state["loras"]
1336
  if len(loras) > 0:
@@ -1374,10 +1561,6 @@ def generate_video(
1374
  raise gr.Error("Error while loading Loras: " + ", ".join(error_files))
1375
  seed = None if seed == -1 else seed
1376
  # negative_prompt = "" # not applicable in the inference
1377
-
1378
- if "abort" in state:
1379
- del state["abort"]
1380
- state["in_progress"] = True
1381
 
1382
  enable_RIFLEx = RIFLEx_setting == 0 and video_length > (6* 16) or RIFLEx_setting == 1
1383
  # VAE Tiling
@@ -1414,49 +1597,53 @@ def generate_video(
1414
  if seed == None or seed <0:
1415
  seed = random.randint(0, 999999999)
1416
 
1417
- global file_list
1418
- clear_file_list = server_config.get("clear_file_list", 0)
1419
- file_list = state.get("file_list", [])
1420
- if clear_file_list > 0:
1421
- file_list_current_size = len(file_list)
1422
- keep_file_from = max(file_list_current_size - clear_file_list, 0)
1423
- files_removed = keep_file_from
1424
- choice = state.get("selected",0)
1425
- choice = max(choice- files_removed, 0)
1426
- file_list = file_list[ keep_file_from: ]
1427
- else:
1428
- file_list = []
1429
- choice = 0
1430
- state["selected"] = choice
1431
- state["file_list"] = file_list
1432
-
1433
-
1434
  global save_path
1435
  os.makedirs(save_path, exist_ok=True)
1436
  video_no = 0
1437
  abort = False
1438
- repeats = f"{video_no}/{repeat_generation}"
1439
- callback = build_callback(task_id, state, trans, num_inference_steps, repeats)
1440
- offload.shared_state["callback"] = callback
1441
  gc.collect()
1442
  torch.cuda.empty_cache()
1443
  wan_model._interrupt = False
1444
- for i in range(repeat_generation):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1445
  try:
1446
- with tracker_lock:
1447
- start_time = time.time()
1448
- progress_tracker[task_id] = {
1449
- 'current_step': 0,
1450
- 'total_steps': num_inference_steps,
1451
- 'start_time': start_time,
1452
- 'last_update': start_time,
1453
- 'repeats': f"{video_no}/{repeat_generation}",
1454
- 'status': "Encoding Prompt"
1455
- }
1456
  if trans.enable_teacache:
1457
  trans.teacache_counter = 0
1458
  trans.num_steps = num_inference_steps
1459
- trans.teacache_skipped_steps = 0
1460
  trans.previous_residual_uncond = None
1461
  trans.previous_residual_cond = None
1462
 
@@ -1464,8 +1651,8 @@ def generate_video(
1464
  if image2video:
1465
  samples = wan_model.generate(
1466
  prompt,
1467
- convert_image(image_to_continue),
1468
- convert_image(image_to_end) if image_to_end != None else None,
1469
  frame_num=(video_length // 4)* 4 + 1,
1470
  max_area=MAX_AREA_CONFIGS[resolution],
1471
  shift=flow_shift,
@@ -1483,7 +1670,7 @@ def generate_video(
1483
  slg_end = slg_end/100,
1484
  cfg_star_switch = cfg_star_switch,
1485
  cfg_zero_step = cfg_zero_step,
1486
- add_frames_for_end_image = not "Fun" in transformer_filename_i2v
1487
  )
1488
  else:
1489
  samples = wan_model.generate(
@@ -1507,7 +1694,6 @@ def generate_video(
1507
  cfg_zero_step = cfg_zero_step,
1508
  )
1509
  except Exception as e:
1510
- gen_in_progress = False
1511
  if temp_filename!= None and os.path.isfile(temp_filename):
1512
  os.remove(temp_filename)
1513
  offload.last_offload_obj.unload_all()
@@ -1530,15 +1716,23 @@ def generate_video(
1530
  if any( keyword in frame.name for keyword in keyword_list):
1531
  VRAM_crash = True
1532
  break
 
 
 
1533
  state["prompt"] = ""
1534
  if VRAM_crash:
1535
- raise gr.Error("The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames.")
1536
  else:
1537
- raise gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
 
 
 
 
1538
  finally:
1539
- with tracker_lock:
1540
- if task_id in progress_tracker:
1541
- del progress_tracker[task_id]
 
1542
 
1543
  if trans.enable_teacache:
1544
  print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{num_inference_steps}" )
@@ -1555,7 +1749,7 @@ def generate_video(
1555
  end_time = time.time()
1556
  abort = True
1557
  state["prompt"] = ""
1558
- print(f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s")
1559
  else:
1560
  sample = samples.cpu()
1561
  # video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")
@@ -1574,7 +1768,7 @@ def generate_video(
1574
  normalize=True,
1575
  value_range=(-1, 1))
1576
 
1577
- configs = get_settings_dict(state, use_image2video, prompt, 0 if image_to_end == None else 1 , video_length, raw_resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
1578
  loras_mult_choices, tea_cache , tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end, cfg_star_switch, cfg_zero_step)
1579
 
1580
  metadata_choice = server_config.get("metadata_choice","metadata")
@@ -1596,9 +1790,159 @@ def generate_video(
1596
 
1597
  if temp_filename!= None and os.path.isfile(temp_filename):
1598
  os.remove(temp_filename)
1599
- gen_in_progress = False
1600
  offload.unload_loras_from_model(trans)
1601
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1602
 
1603
  def get_new_preset_msg(advanced = True):
1604
  if advanced:
@@ -1650,7 +1994,7 @@ def save_lset(state, lset_name, loras_choices, loras_mult_choices, prompt, save_
1650
 
1651
 
1652
  lset_name_filename = lset_name + ".lset"
1653
- full_lset_name_filename = os.path.join(get_lora_dir(use_image2video), lset_name_filename)
1654
 
1655
  with open(full_lset_name_filename, "w", encoding="utf-8") as writer:
1656
  writer.write(json.dumps(lset, indent=4))
@@ -1667,7 +2011,7 @@ def save_lset(state, lset_name, loras_choices, loras_mult_choices, prompt, save_
1667
 
1668
  def delete_lset(state, lset_name):
1669
  loras_presets = state["loras_presets"]
1670
- lset_name_filename = os.path.join( get_lora_dir(use_image2video), sanitize_file_name(lset_name) + ".lset" )
1671
  if len(lset_name) > 0 and lset_name != get_new_preset_msg(True) and lset_name != get_new_preset_msg(False):
1672
  if not os.path.isfile(lset_name_filename):
1673
  raise gr.Error(f"Preset '{lset_name}' not found ")
@@ -1688,8 +2032,8 @@ def delete_lset(state, lset_name):
1688
  def refresh_lora_list(state, lset_name, loras_choices):
1689
  loras_names = state["loras_names"]
1690
  prev_lora_names_selected = [ loras_names[int(i)] for i in loras_choices]
1691
-
1692
- loras, loras_names, loras_presets, _, _, _, _ = setup_loras(use_image2video, None, get_lora_dir(use_image2video), lora_preselected_preset, None)
1693
  state["loras"] = loras
1694
  state["loras_names"] = loras_names
1695
  state["loras_presets"] = loras_presets
@@ -1729,7 +2073,7 @@ def apply_lset(state, wizard_prompt_activated, lset_name, loras_choices, loras_m
1729
  gr.Info("Please choose a preset in the list or create one")
1730
  else:
1731
  loras = state["loras"]
1732
- loras_choices, loras_mult_choices, preset_prompt, full_prompt, error = extract_preset(lset_name, loras)
1733
  if len(error) > 0:
1734
  gr.Info(error)
1735
  else:
@@ -1930,10 +2274,11 @@ def save_settings(state, prompt, image_prompt_type, video_length, resolution, nu
1930
  if state.get("validate_success",0) != 1:
1931
  return
1932
 
1933
- ui_defaults = get_settings_dict(state, use_image2video, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
 
1934
  loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step)
1935
 
1936
- defaults_filename = get_settings_file_name(use_image2video)
1937
 
1938
  with open(defaults_filename, "w", encoding="utf-8") as f:
1939
  json.dump(ui_defaults, f, indent=4)
@@ -1976,7 +2321,12 @@ def generate_video_tab(image2video=False):
1976
 
1977
  state_dict["advanced"] = advanced
1978
  state_dict["loras_model"] = filename
1979
- preset_to_load = lora_preselected_preset if lora_preselected_preset_for_i2v == image2video else ""
 
 
 
 
 
1980
 
1981
  loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset = setup_loras(image2video, None, get_lora_dir(image2video), preset_to_load, None)
1982
 
@@ -1989,7 +2339,7 @@ def generate_video_tab(image2video=False):
1989
  launch_loras = []
1990
  launch_multis_str = ""
1991
 
1992
- if len(default_lora_preset) > 0 and image2video == lora_preselected_preset_for_i2v:
1993
  launch_preset = default_lora_preset
1994
  launch_prompt = default_lora_preset_prompt
1995
  launch_loras = default_loras_choices
@@ -2014,15 +2364,6 @@ def generate_video_tab(image2video=False):
2014
 
2015
 
2016
  header = gr.Markdown(generate_header(model_filename, compile, attention_mode))
2017
- with gr.Row(visible= image2video):
2018
- with gr.Row(scale =2):
2019
- gr.Markdown("<I>Wan2GP's Lora Festival ! Press the following button to download i2v <B>Remade</B> Loras collection (and bonuses Loras).")
2020
- with gr.Row(scale =1):
2021
- download_loras_btn = gr.Button("---> Let the Lora's Festival Start !", scale =1)
2022
- with gr.Row(scale =1):
2023
- gr.Markdown("")
2024
- with gr.Row(visible= image2video) as download_status_row:
2025
- download_status = gr.Markdown()
2026
  with gr.Row():
2027
  with gr.Column():
2028
  with gr.Column(visible=False, elem_id="image-modal-container") as modal_container:
@@ -2250,89 +2591,112 @@ def generate_video_tab(image2video=False):
2250
  cfg_zero_step = gr.Slider(-1, 39, value=ui_defaults.get("cfg_zero_step",-1), step=1, label="CFG Zero below this Layer (Extra Process)")
2251
 
2252
  with gr.Row():
2253
- save_settings_btn = gr.Button("Set Settings as Default")
2254
  show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
2255
  fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
2256
  with gr.Column():
 
 
 
2257
  gen_progress_html = gr.HTML(
2258
  label="Status",
2259
  value="Idle",
2260
- elem_id="generation_progress_bar_container"
2261
  )
2262
  output = gr.Gallery(
2263
  label="Generated videos", show_label=False, elem_id="gallery"
2264
  , columns=[3], rows=[1], object_fit="contain", height=450, selected_index=0, interactive= False)
2265
  generate_btn = gr.Button("Generate")
2266
- queue_df = gr.DataFrame(
2267
- headers=["Status", "Completed", "Progress", "Steps", "Time", "Prompt", "Start", "End", "", "", ""],
2268
- datatype=["str", "str", "str", "str", "str", "markdown", "markdown", "markdown", "str", "str", "str"],
2269
- interactive=False,
2270
- col_count=(11, "fixed"),
2271
- wrap=True,
2272
- value=update_queue_data,
2273
- every=1,
2274
- elem_id="queue_df"
2275
- )
2276
- def handle_selection(evt: gr.SelectData):
2277
- if evt.index is None:
2278
- return gr.update(), gr.update(), gr.update(visible=False)
2279
- row_index, col_index = evt.index
2280
- cell_value = None
2281
- if col_index in [8, 9, 10]:
2282
- if col_index == 8: cell_value = "↑"
2283
- elif col_index == 9: cell_value = "↓"
2284
- elif col_index == 10: cell_value = ""
2285
- if col_index == 8:
2286
- new_df_data = move_up([row_index])
2287
- return new_df_data, gr.update(), gr.update(visible=False)
2288
- elif col_index == 9:
2289
- new_df_data = move_down([row_index])
2290
- return new_df_data, gr.update(), gr.update(visible=False)
2291
- elif col_index == 10:
2292
- new_df_data = remove_task([row_index])
2293
- return new_df_data, gr.update(), gr.update(visible=False)
2294
- start_img_col_idx = 6
2295
- end_img_col_idx = 7
2296
- image_data_to_show = None
2297
- if col_index == start_img_col_idx:
2298
- with lock:
2299
- if row_index < len(queue):
2300
- image_data_to_show = queue[row_index].get('start_image_data')
2301
- elif col_index == end_img_col_idx:
2302
- with lock:
2303
- if row_index < len(queue):
2304
- image_data_to_show = queue[row_index].get('end_image_data')
2305
-
2306
- if image_data_to_show:
2307
- return gr.update(), gr.update(value=image_data_to_show), gr.update(visible=True)
2308
- else:
2309
- return gr.update(), gr.update(), gr.update(visible=False)
2310
- selected_indices = gr.State([])
2311
- queue_df.select(
2312
- fn=handle_selection,
2313
- inputs=None,
2314
- outputs=[queue_df, modal_image_display, modal_container],
2315
- )
2316
- gallery_update_trigger.change(
2317
- fn=refresh_gallery_on_trigger,
2318
- inputs=[state],
2319
- outputs=[output]
2320
- )
2321
- queue_df.change(
2322
- fn=refresh_gallery,
2323
- inputs=[state],
2324
- outputs=[gallery_update_trigger]
2325
- ).then(
2326
- fn=refresh_progress,
2327
- inputs=None,
2328
- outputs=[progress_update_trigger]
2329
- )
2330
- progress_update_trigger.change(
2331
- fn=update_generation_status,
2332
- inputs=[progress_update_trigger],
2333
- outputs=[gen_progress_html],
2334
- show_progress="hidden"
2335
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2336
  save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
2337
  save_settings, inputs = [state, prompt, image_prompt_type_radio, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt,
2338
  loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers,
@@ -2348,53 +2712,114 @@ def generate_video_tab(image2video=False):
2348
  )
2349
  refresh_lora_btn.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
2350
  refresh_lora_btn2.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
2351
- download_loras_btn.click(fn=download_loras, inputs=[], outputs=[download_status_row, download_status, presets_column, loras_column]).then(fn=refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
2352
  output.select(select_video, state, None )
2353
 
2354
- generate_btn.click(
2355
- fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2356
  ).then(
2357
  fn=process_prompt_and_add_tasks,
2358
- inputs=[
2359
- prompt,
2360
- negative_prompt,
2361
- resolution,
2362
- video_length,
2363
- seed,
2364
- num_inference_steps,
2365
- guidance_scale,
2366
- flow_shift,
2367
- embedded_guidance_scale,
2368
- repeat_generation,
2369
- multi_images_gen_type,
2370
- tea_cache_setting,
2371
- tea_cache_start_step_perc,
2372
- loras_choices,
2373
- loras_mult_choices,
2374
- image_prompt_type_radio,
2375
- image_to_continue,
2376
- image_to_end,
2377
- video_to_continue,
2378
- max_frames,
2379
- RIFLEx_setting,
2380
- slg_switch,
2381
- slg_layers,
2382
- slg_start_perc,
2383
- slg_end_perc,
2384
- cfg_star_switch,
2385
- cfg_zero_step,
2386
- state,
2387
- gr.State(image2video)
2388
- ],
2389
  outputs=queue_df
 
 
 
2390
  )
 
 
2391
  close_modal_button.click(
2392
  lambda: gr.update(visible=False),
2393
  inputs=[],
2394
  outputs=[modal_container]
2395
  )
2396
- return loras_column, loras_choices, presets_column, lset_name, header, state
 
 
 
 
 
 
 
 
 
 
 
 
 
2397
 
 
2398
  def generate_configuration_tab():
2399
  state_dict = {}
2400
  state = gr.State(state_dict)
@@ -2411,7 +2836,7 @@ def generate_configuration_tab():
2411
  value= index,
2412
  label="Transformer model for Text to Video",
2413
  interactive= not lock_ui_transformer,
2414
- visible=True #not use_image2video
2415
  )
2416
  index = transformer_choices_i2v.index(transformer_filename_i2v)
2417
  index = 0 if index ==0 else index
@@ -2428,7 +2853,7 @@ def generate_configuration_tab():
2428
  value= index,
2429
  label="Transformer model for Image to Video",
2430
  interactive= not lock_ui_transformer,
2431
- visible = True # use_image2video,
2432
  )
2433
  index = text_encoder_choices.index(text_encoder_filename)
2434
  index = 0 if index ==0 else index
@@ -2524,7 +2949,7 @@ def generate_configuration_tab():
2524
  reload_choice = gr.Dropdown(
2525
  choices=[
2526
  ("When changing tabs", 1),
2527
- ("When pressing generate", 2),
2528
  ],
2529
  value=server_config.get("reload_model",2),
2530
  label="Reload model"
@@ -2577,19 +3002,46 @@ def generate_about_tab():
2577
  gr.Markdown("- <B>Remade_AI</B> : for creating their awesome Loras collection")
2578
 
2579
 
2580
- def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
2581
- global lora_model_filename, use_image2video
2582
-
2583
  t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode)
2584
  i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode)
2585
 
2586
  new_t2v = evt.index == 0
2587
  new_i2v = evt.index == 1
2588
- use_image2video = new_i2v
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2589
 
2590
  if(server_config.get("reload_model",2) == 1):
2591
- with lock:
2592
- queue_empty = len(queue) == 0
 
2593
  if queue_empty:
2594
  global wan_model, offloadobj
2595
  if wan_model is not None:
@@ -2599,7 +3051,7 @@ def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
2599
  wan_model = None
2600
  gc.collect()
2601
  torch.cuda.empty_cache()
2602
- wan_model, offloadobj, trans = load_models(use_image2video)
2603
  del trans
2604
 
2605
  if new_t2v or new_i2v:
@@ -2625,11 +3077,15 @@ def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
2625
  gr.Column(visible= visible),
2626
  gr.Dropdown(choices=lset_choices, value=get_new_preset_msg(advanced), visible=visible),
2627
  t2v_header,
 
 
2628
  gr.Column(),
2629
  gr.Dropdown(),
2630
  gr.Column(),
2631
  gr.Dropdown(),
2632
- i2v_header,
 
 
2633
  ]
2634
  else:
2635
  return [
@@ -2637,16 +3093,21 @@ def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
2637
  gr.Dropdown(),
2638
  gr.Column(),
2639
  gr.Dropdown(),
2640
- t2v_header,
 
 
 
2641
  gr.Column(visible= visible),
2642
  gr.Dropdown(choices=new_loras_choices, visible=visible, value=[]),
2643
  gr.Column(visible= visible),
2644
  gr.Dropdown(choices=lset_choices, value=get_new_preset_msg(advanced), visible=visible),
2645
  i2v_header,
 
 
2646
  ]
2647
 
2648
- return [gr.Column(), gr.Dropdown(), gr.Column(), gr.Dropdown(), t2v_header,
2649
- gr.Column(), gr.Dropdown(), gr.Column(), gr.Dropdown(), i2v_header]
2650
 
2651
 
2652
  def create_demo():
@@ -2706,112 +3167,112 @@ def create_demo():
2706
  overflow: hidden;
2707
  text-overflow: ellipsis;
2708
  }
2709
- #queue_df td:nth-child(-n+5) {
2710
- cursor: default !important;
2711
- pointer-events: none;
2712
- }
2713
- #queue_df td:nth-child(6) {
2714
- cursor: default !important;
2715
- }
2716
- #queue_df th {
2717
- pointer-events: none;
2718
- text-align: center;
2719
- vertical-align: middle;
2720
- }
2721
- #queue_df table {
2722
- width: 100%;
2723
- overflow: hidden !important;
2724
- }
2725
- #queue_df::-webkit-scrollbar {
2726
- display: none !important;
2727
- }
2728
- #queue_df {
2729
- scrollbar-width: none !important;
2730
- -ms-overflow-style: none !important;
2731
- }
2732
- #queue_df th:nth-child(1),
2733
- #queue_df td:nth-child(1) {
2734
- width: 90px;
2735
- text-align: center;
2736
- vertical-align: middle;
2737
- }
2738
- #queue_df th:nth-child(1) {
2739
- font-size: 0.8em;
2740
- }
2741
- #queue_df th:nth-child(2),
2742
- #queue_df td:nth-child(2) {
2743
- width: 85px;
2744
- text-align: center;
2745
- vertical-align: middle;
2746
- }
2747
- #queue_df th:nth-child(2) {
2748
- font-size: 0.5em;
2749
- }
2750
- #queue_df th:nth-child(3),
2751
- #queue_df td:nth-child(3) {
2752
- width: 75px;
2753
- text-align: center;
2754
- vertical-align: middle;
2755
- }
2756
- #queue_df th:nth-child(3) {
2757
- font-size: 0.6em;
2758
- }
2759
- #queue_df th:nth-child(4),
2760
- #queue_df td:nth-child(4) {
2761
- width: 65px;
2762
- text-align: center;
2763
- white-space: nowrap;
2764
- }
2765
- #queue_df th:nth-child(4) {
2766
- font-size: 0.9em;
2767
- }
2768
- #queue_df th:nth-child(5),
2769
- #queue_df td:nth-child(5) {
2770
- width: 60px;
2771
- text-align: center;
2772
- white-space: nowrap;
2773
- }
2774
- #queue_df th:nth-child(6),
2775
- #queue_df td:nth-child(6) {
2776
- width: auto;
2777
- text-align: center;
2778
- white-space: normal;
2779
- }
2780
- #queue_df th:nth-child(6) {
2781
- font-size: 0.8em;
2782
- }
2783
- #queue_df th:nth-child(7), #queue_df td:nth-child(7),
2784
- #queue_df th:nth-child(8), #queue_df td:nth-child(8) {
2785
- width: 60px;
2786
- text-align: center;
2787
- vertical-align: middle;
2788
- }
2789
- #queue_df td:nth-child(7) img,
2790
- #queue_df td:nth-child(8) img {
2791
- max-width: 50px;
2792
- max-height: 50px;
2793
- object-fit: contain;
2794
- display: block;
2795
- margin: auto;
2796
- cursor: pointer;
2797
- }
2798
- #queue_df th:nth-child(9), #queue_df td:nth-child(9),
2799
- #queue_df th:nth-child(10), #queue_df td:nth-child(10),
2800
- #queue_df th:nth-child(11), #queue_df td:nth-child(11) {
2801
- width: 20px;
2802
- padding: 2px !important;
2803
- cursor: pointer;
2804
- text-align: center;
2805
- font-weight: bold;
2806
- vertical-align: middle;
2807
- }
2808
- #queue_df td:nth-child(7):hover,
2809
- #queue_df td:nth-child(8):hover,
2810
- #queue_df td:nth-child(9):hover,
2811
- #queue_df td:nth-child(10):hover,
2812
- #queue_df td:nth-child(11):hover {
2813
- background-color: #e0e0e0;
2814
- }
2815
  #image-modal-container {
2816
  position: fixed;
2817
  top: 0;
@@ -2893,8 +3354,8 @@ def create_demo():
2893
  pointer-events: none;
2894
  }
2895
  """
2896
- with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md")) as demo:
2897
- gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v3.3 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
2898
  gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !</FONT>")
2899
 
2900
  with gr.Accordion("Click here for some Info on how to use Wan2GP", open = False):
@@ -2904,30 +3365,34 @@ def create_demo():
2904
  gr.Markdown("- 1280 x 720 with a 14B model: 80 frames (5s): 11 GB of VRAM")
2905
  gr.Markdown("It is not recommmended to generate a video longer than 8s (128 frames) even if there is still some VRAM left as some artifacts may appear")
2906
  gr.Markdown("Please note that if your turn on compilation, the first denoising step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
2907
-
 
 
2908
 
2909
  with gr.Tabs(selected="i2v" if use_image2video else "t2v") as main_tabs:
2910
  with gr.Tab("Text To Video", id="t2v") as t2v_tab:
2911
- t2v_loras_column, t2v_loras_choices, t2v_presets_column, t2v_lset_name, t2v_header, t2v_state = generate_video_tab()
2912
  with gr.Tab("Image To Video", id="i2v") as i2v_tab:
2913
- i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header, i2v_state = generate_video_tab(True)
2914
  if not args.lock_config:
 
 
2915
  with gr.Tab("Configuration"):
2916
  generate_configuration_tab()
2917
  with gr.Tab("About"):
2918
  generate_about_tab()
2919
  main_tabs.select(
2920
  fn=on_tab_select,
2921
- inputs=[t2v_state, i2v_state],
2922
  outputs=[
2923
- t2v_loras_column, t2v_loras_choices, t2v_presets_column, t2v_lset_name, t2v_header,
2924
- i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header
2925
  ]
2926
  )
2927
  return demo
2928
 
2929
  if __name__ == "__main__":
2930
- threading.Thread(target=runner, daemon=True).start()
2931
  os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
2932
  server_port = int(args.server_port)
2933
  if os.name == "nt":
 
1
  import os
2
  import time
3
+ import sys
4
  import threading
5
  import argparse
6
  from mmgp import offload, safetensors2, profile_type
 
34
  if mmgp_version != target_mmgp_version:
35
  print(f"Incorrect version of mmgp ({mmgp_version}), version {target_mmgp_version} is needed. Please upgrade with the command 'pip install -r requirements.txt'")
36
  exit()
 
37
  lock = threading.Lock()
38
  current_task_id = None
39
  task_id = 0
40
+ # progress_tracker = {}
41
+ # tracker_lock = threading.Lock()
 
42
  last_model_type = None
 
43
 
44
  def format_time(seconds):
45
  if seconds < 60:
 
75
  print(f"Error converting PIL to base64: {e}")
76
  return None
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  def process_prompt_and_add_tasks(
80
  prompt,
 
104
  slg_end,
105
  cfg_star_switch,
106
  cfg_zero_step,
107
+ state,
108
  image2video
109
  ):
110
+
111
+ if state.get("validate_success",0) != 1:
112
+ gr.Info("Validation failed, not adding tasks.")
113
  return
114
+
115
+ state["validate_success"] = 0
116
  if len(prompt) ==0:
117
  return
118
  prompt, errors = prompt_parser.process_template(prompt)
119
  if len(errors) > 0:
120
+ gr.Info("Error processing prompt template: " + errors)
121
  return
122
  prompts = prompt.replace("\r", "").split("\n")
123
  prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")]
124
  if len(prompts) ==0:
125
  return
126
 
127
+ file_model_needed = model_needed(image2video)
128
+ if image2video:
129
+ width, height = resolution.split("x")
130
+ width, height = int(width), int(height)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
+ if "480p" in file_model_needed and not "Fun" in file_model_needed and width * height > 848*480:
133
+ gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P")
134
+ return
135
+ resolution = str(width) + "*" + str(height)
136
+ if resolution not in ['720*1280', '1280*720', '480*832', '832*480']:
137
+ gr.Info(f"Resolution {resolution} not supported by image 2 video")
138
+ return
139
+
140
+ if "1.3B" in file_model_needed and width * height > 848*480:
141
+ gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P")
142
+ return
143
+
144
+ if image2video:
145
+ if image_to_continue == None or isinstance(image_to_continue, list) and len(image_to_continue) == 0:
146
+ return
147
+ if image_prompt_type == 0:
148
+ image_to_end = None
149
+ if isinstance(image_to_continue, list):
150
+ image_to_continue = [ convert_image(tup[0]) for tup in image_to_continue ]
151
+ else:
152
+ image_to_continue = [convert_image(image_to_continue)]
153
+ if image_to_end != None:
154
+ if isinstance(image_to_end , list):
155
+ image_to_end = [ convert_image(tup[0]) for tup in image_to_end ]
156
+ else:
157
+ image_to_end = [convert_image(image_to_end) ]
158
+ if len(image_to_continue) != len(image_to_end):
159
+ gr.Info("The number of start and end images should be the same ")
160
+ return
161
+
162
+ if multi_images_gen_type == 0:
163
+ new_prompts = []
164
+ new_image_to_continue = []
165
+ new_image_to_end = []
166
+ for i in range(len(prompts) * len(image_to_continue) ):
167
+ new_prompts.append( prompts[ i % len(prompts)] )
168
+ new_image_to_continue.append(image_to_continue[i // len(prompts)] )
169
+ if image_to_end != None:
170
+ new_image_to_end.append(image_to_end[i // len(prompts)] )
171
+ prompts = new_prompts
172
+ image_to_continue = new_image_to_continue
173
+ if image_to_end != None:
174
+ image_to_end = new_image_to_end
175
+ else:
176
+ if len(prompts) >= len(image_to_continue):
177
+ if len(prompts) % len(image_to_continue) !=0:
178
+ raise gr.Error("If there are more text prompts than input images the number of text prompts should be dividable by the number of images")
179
+ rep = len(prompts) // len(image_to_continue)
180
+ new_image_to_continue = []
181
+ new_image_to_end = []
182
+ for i, _ in enumerate(prompts):
183
+ new_image_to_continue.append(image_to_continue[i//rep] )
184
+ if image_to_end != None:
185
+ new_image_to_end.append(image_to_end[i//rep] )
186
+ image_to_continue = new_image_to_continue
187
+ if image_to_end != None:
188
+ image_to_end = new_image_to_end
189
+ else:
190
+ if len(image_to_continue) % len(prompts) !=0:
191
+ raise gr.Error("If there are more input images than text prompts the number of images should be dividable by the number of text prompts")
192
+ rep = len(image_to_continue) // len(prompts)
193
+ new_prompts = []
194
+ for i, _ in enumerate(image_to_continue):
195
+ new_prompts.append( prompts[ i//rep] )
196
+ prompts = new_prompts
197
+
198
+ # elif video_to_continue != None and len(video_to_continue) >0 :
199
+ # input_image_or_video_path = video_to_continue
200
+ # # pipeline.num_input_frames = max_frames
201
+ # # pipeline.max_frames = max_frames
202
+ # else:
203
+ # return
204
+ # else:
205
+ # input_image_or_video_path = None
206
+ if image_to_continue == None:
207
+ image_to_continue = [None] * len(prompts)
208
+ if image_to_end == None:
209
+ image_to_end = [None] * len(prompts)
210
+
211
+ for single_prompt, image_start, image_end in zip(prompts, image_to_continue, image_to_end) :
212
+ kwargs = {
213
+ "prompt" : single_prompt,
214
+ "negative_prompt" : negative_prompt,
215
+ "resolution" : resolution,
216
+ "video_length" : video_length,
217
+ "seed" : seed,
218
+ "num_inference_steps" : num_inference_steps,
219
+ "guidance_scale" : guidance_scale,
220
+ "flow_shift" : flow_shift,
221
+ "embedded_guidance_scale" : embedded_guidance_scale,
222
+ "repeat_generation" : repeat_generation,
223
+ "multi_images_gen_type" : multi_images_gen_type,
224
+ "tea_cache" : tea_cache,
225
+ "tea_cache_start_step_perc" : tea_cache_start_step_perc,
226
+ "loras_choices" : loras_choices,
227
+ "loras_mult_choices" : loras_mult_choices,
228
+ "image_prompt_type" : image_prompt_type,
229
+ "image_to_continue": image_start,
230
+ "image_to_end" : image_end,
231
+ "video_to_continue" : video_to_continue ,
232
+ "max_frames" : max_frames,
233
+ "RIFLEx_setting" : RIFLEx_setting,
234
+ "slg_switch" : slg_switch,
235
+ "slg_layers" : slg_layers,
236
+ "slg_start" : slg_start,
237
+ "slg_end" : slg_end,
238
+ "cfg_star_switch" : cfg_star_switch,
239
+ "cfg_zero_step" : cfg_zero_step,
240
+ "state" : state,
241
+ "image2video" : image2video
242
+ }
243
+ add_video_task(**kwargs)
244
+
245
+ gen = get_gen_info(state)
246
+ gen["prompts_max"] = len(prompts) + gen.get("prompts_max",0)
247
+ state["validate_success"] = 1
248
+ queue= gen.get("queue", [])
249
+ return update_queue_data(queue)
250
+
251
+
252
+
253
+
254
def add_video_task(**kwargs):
    """Append one generation task described by ``kwargs`` to the session queue.

    ``kwargs`` must contain ``state`` (the session state dict) plus the
    generation settings read below. The module-level ``task_id`` counter is
    incremented and used as the new task's id.

    Returns the queue table update produced by ``update_queue_data``.
    """
    global task_id
    session_state = kwargs["state"]
    pending = get_gen_info(session_state)["queue"]
    task_id += 1
    new_id = task_id
    first_frame = kwargs["image_to_continue"]
    last_frame = kwargs["image_to_end"]

    entry = {
        "id": new_id,
        "image2video": kwargs["image2video"],
        # Snapshot of every setting, kept with the task.
        "params": kwargs.copy(),
        "repeats": kwargs["repeat_generation"],
        "length": kwargs["video_length"],
        "steps": kwargs["num_inference_steps"],
        "prompt": kwargs["prompt"],
        "start_image_data": first_frame,
        "end_image_data": last_frame,
        # Base64 thumbnails are precomputed once so the table can embed them.
        "start_image_data_base64": pil_to_base64_uri(first_frame, format="jpeg", quality=70),
        "end_image_data_base64": pil_to_base64_uri(last_frame, format="jpeg", quality=70),
    }
    pending.append(entry)
    return update_queue_data(pending)
278
+
279
def move_up(queue, selected_indices):
    """Move the selected table row one position earlier in ``queue``.

    ``selected_indices`` holds table row indices; only the first one is used
    (it may itself be wrapped in a list). Table row ``r`` corresponds to
    ``queue[r + 1]`` because ``get_queue_table`` does not list ``queue[0]``.

    The first table row cannot move up, and an index that no longer exists in
    the queue (e.g. a stale selection) is ignored instead of raising
    ``IndexError``.

    Returns the queue table update produced by ``update_queue_data``.
    """
    if not selected_indices:
        return update_queue_data(queue)
    row = selected_indices[0]
    if isinstance(row, list):
        row = row[0]
    row = int(row)
    with lock:
        pos = row + 1  # table row -> queue position
        # row > 0 keeps the first listed task from swapping with queue[0];
        # pos < len(queue) guards against an out-of-range selection.
        if row > 0 and pos < len(queue):
            queue[pos], queue[pos - 1] = queue[pos - 1], queue[pos]
    return update_queue_data(queue)
291
 
292
def move_down(queue, selected_indices):
    """Move the selected table row one position later in ``queue``.

    Only the first entry of ``selected_indices`` is used (it may be wrapped in
    a list). The row index is shifted by one to obtain the queue position, and
    the swap happens only when a following element exists.

    Returns the queue table update produced by ``update_queue_data``.
    """
    if not selected_indices:
        return update_queue_data(queue)
    first = selected_indices[0]
    row = int(first[0] if isinstance(first, list) else first)
    with lock:
        pos = row + 1  # table row -> queue position
        has_successor = pos < len(queue) - 1
        if has_successor:
            queue[pos], queue[pos + 1] = queue[pos + 1], queue[pos]
    return update_queue_data(queue)
304
 
305
def remove_task(queue, selected_indices):
    """Delete the selected table row's task from ``queue``.

    Only the first entry of ``selected_indices`` is used (it may be wrapped in
    a list). The row index is shifted by one to obtain the queue position.
    If that position is 0, ``wan_model._interrupt`` is set before deletion.

    Returns the queue table update produced by ``update_queue_data``.
    """
    if not selected_indices:
        return update_queue_data(queue)
    first = selected_indices[0]
    row = first[0] if isinstance(first, list) else first
    pos = int(row) + 1  # table row -> queue position
    with lock:
        if not pos < len(queue):
            # Nothing at that position: leave the queue untouched.
            pass
        else:
            if pos == 0:
                wan_model._interrupt = True
            del queue[pos]
    return update_queue_data(queue)
318
 
319
+
320
+
321
def get_queue_table(queue):
    """Build the rows of the pending-task table.

    The first element of ``queue`` is skipped; every following task yields
    one row::

        [repeats, prompt_cell, length, steps, start_img, end_img, "↑", "↓", "✖"]

    ``prompt_cell`` is a ``<span>`` whose visible text is the prompt, cut to
    97 characters plus ``...`` when longer than 100, and whose ``title``
    holds the full prompt. Both are HTML-escaped so that prompts containing
    ``<``, ``>``, ``&`` or quotes cannot break the generated markup.
    The image cells are ``<img>`` tags built from the task's base64 URIs, or
    empty strings when no URI is stored.

    Returns a list of rows (empty when there is nothing after ``queue[0]``).
    """
    from html import escape as html_escape  # local import keeps the block self-contained

    thumbnail_size = "50px"
    img_style = (
        f"max-width:{thumbnail_size}; max-height:{thumbnail_size}; "
        "display: block; margin: auto; object-fit: contain;"
    )

    def thumbnail(uri, alt):
        # An absent or empty URI yields an empty cell.
        if not uri:
            return ""
        return f'<img src="{uri}" alt="{alt}" style="{img_style}" />'

    rows = []
    for item in queue[1:]:
        prompt = item['prompt']
        # Truncate before escaping so an HTML entity is never cut in half.
        shown = (prompt[:97] + '...') if len(prompt) > 100 else prompt
        tooltip = html_escape(prompt, quote=True)
        prompt_cell = f'<span title="{tooltip}">{html_escape(shown, quote=True)}</span>'
        rows.append([
            item.get('repeats', "1"),
            prompt_cell,
            item.get('length'),
            item.get('steps'),
            thumbnail(item.get('start_image_data_base64'), "Start"),
            thumbnail(item.get('end_image_data_base64'), "End"),
            "↑",
            "↓",
            "✖",
        ])
    return rows
379
def update_queue_data(queue):
    """Return the ``gr.DataFrame`` update that reflects ``queue``.

    The table is hidden when ``get_queue_table`` produces no rows and shown
    with those rows otherwise.
    """
    rows = get_queue_table(queue)
    if rows:
        return gr.DataFrame(value=rows, visible= True)
    return gr.DataFrame(visible=False)
391
 
392
  def create_html_progress_bar(percentage=0.0, text="Idle", is_idle=True):
393
  bar_class = "progress-bar-custom idle" if is_idle else "progress-bar-custom"
 
402
  """
403
  return html
404
 
405
+ # def refresh_progress():
406
+ # global current_task_id, progress_tracker, last_status_string
407
+ # task_id_to_check = current_task_id
408
+ # is_idle = True
409
+ # status_string = "Starting..."
410
+ # progress_percent = 0.0
411
+ # html_content = ""
412
+
413
+ # with tracker_lock:
414
+ # with lock:
415
+ # processing_or_queued = any(item['state'] in ["Processing", "Queued"] for item in queue)
416
+ # if task_id_to_check is not None:
417
+ # progress_data = progress_tracker.get(task_id_to_check)
418
+ # if progress_data:
419
+ # is_idle = False
420
+ # current_step = progress_data.get('current_step', 0)
421
+ # total_steps = progress_data.get('total_steps', 0)
422
+ # status = progress_data.get('status', "Starting...")
423
+ # repeats = progress_data.get("repeats", 1)
424
+
425
+ # if total_steps > 0:
426
+ # progress_float = min(1.0, max(0.0, float(current_step) / float(total_steps)))
427
+ # progress_percent = progress_float * 100
428
+ # status_string = f"{status} [{repeats}] - {progress_percent:.1f}% complete ({current_step}/{total_steps} steps)"
429
+ # else:
430
+ # progress_percent = 0.0
431
+ # status_string = f"{status} [{repeats}] - Initializing..."
432
+ # html_content = create_html_progress_bar(progress_percent, status_string, is_idle)
433
+ # return gr.update(value=html_content)
434
 
435
  def update_generation_status(html_content):
436
  if(html_content):
 
832
 
833
  only_allow_edit_in_advanced = False
834
  lora_preselected_preset = args.lora_preset
 
835
  # if args.fast : #or args.fastest
836
  # transformer_filename_t2v = transformer_choices_t2v[2]
837
  # attention_mode="sage2" if "sage2" in attention_modes_supported else "sage"
 
844
  lock_ui_compile = True
845
 
846
  model_filename = ""
 
847
  #attention_mode="sage"
848
  #attention_mode="sage2"
849
  #attention_mode="flash"
 
852
  # compile = "transformer"
853
 
854
  def preprocess_loras(sd):
 
 
 
 
855
  first = next(iter(sd), None)
856
  if first == None:
857
  return sd
858
+ if not first.startswith("lora_unet_"):
859
  return sd
860
+ new_sd = {}
861
  print("Converting Lora Safetensors format to Lora Diffusers format")
862
  alphas = {}
863
  repl_list = ["cross_attn", "self_attn", "ffn"]
 
936
  def sanitize_file_name(file_name, rep =""):
937
  return file_name.replace("/",rep).replace("\\",rep).replace(":",rep).replace("|",rep).replace("?",rep).replace("<",rep).replace(">",rep).replace("\"",rep)
938
 
939
+ def extract_preset(image2video, lset_name, loras):
940
  loras_choices = []
941
  loras_choices_files = []
942
  loras_mult_choices = ""
943
  prompt =""
944
  full_prompt =""
945
  lset_name = sanitize_file_name(lset_name)
946
+ lora_dir = get_lora_dir(image2video)
947
  if not lset_name.endswith(".lset"):
948
  lset_name_filename = os.path.join(lora_dir, lset_name + ".lset" )
949
  else:
 
1014
  if not os.path.isfile(os.path.join(lora_dir, lora_preselected_preset + ".lset")):
1015
  raise Exception(f"Unknown preset '{lora_preselected_preset}'")
1016
  default_lora_preset = lora_preselected_preset
1017
+ default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, _ , error = extract_preset(i2v, default_lora_preset, loras)
1018
  if len(error) > 0:
1019
  print(error[:200])
1020
  return loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset
 
1101
  # kwargs["partialPinning"] = True
1102
  elif profile == 3:
1103
  kwargs["budgets"] = { "*" : "70%" }
1104
+ offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", coTenantsMap= {}, **kwargs)
1105
  if len(args.gpu) > 0:
1106
  torch.set_default_device(args.gpu)
1107
 
 
1178
  if gen_in_progress:
1179
  yield "<DIV ALIGN=CENTER>Unable to change config when a generation is in progress</DIV>"
1180
  return
1181
+ global offloadobj, wan_model, server_config, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset, loras_presets
1182
  server_config = {"attention_mode" : attention_choice,
1183
  "transformer_filename": transformer_choices_t2v[transformer_t2v_choice],
1184
  "transformer_filename_i2v": transformer_choices_i2v[transformer_i2v_choice],
 
1243
  final_frames = (final_frames * 255).astype(np.uint8)
1244
  ImageSequenceClip(list(final_frames), fps=fps).write_videofile(output_path, verbose= False, logger = None)
1245
 
1246
+
1247
def get_gen_info(state):
    """Return the generation-info dict stored under ``state["gen"]``.

    When the key is missing (or holds ``None``) a new empty dict is created,
    stored in ``state`` and returned, so callers always share one mutable
    dict per ``state``.
    """
    cache = state.get("gen", None)
    if cache is None:
        cache = dict()
        state["gen"] = cache
    return cache
+ return cache
1253
+
1254
def build_callback(state, pipe, progress, status, num_inference_steps):
    """Build the progress callback bound to ``state`` and ``progress``.

    ``pipe`` and ``status`` are accepted but not read here: the status text
    is taken from ``gen["progress_status"]`` each time the callback runs.

    The returned ``callback(step_idx, force_refresh, read_state=False)``:

    * when ``force_refresh`` is false and ``step_idx`` is negative, does
      nothing unless ``gen["refresh"]`` is non-negative and greater than
      ``state["refresh"]``;
    * records ``gen["refresh"]`` into ``state["refresh"]``;
    * either re-reads the last ``(phase, step_idx)`` from
      ``gen["progress_phase"]`` (``read_state``) or advances ``step_idx`` by
      one, picks the phase suffix and stores the pair there;
    * calls ``progress(*progress_args)`` and keeps the arguments in
      ``gen["progress_args"]``.
    """
    def callback(step_idx, force_refresh, read_state = False):
        gen = get_gen_info(state)
        refresh_id = gen.get("refresh", -1)
        if not (force_refresh or step_idx >= 0):
            # Passive poll: skip when no refresh was requested or it was already handled.
            if refresh_id < 0 or state.get("refresh", 0) >= refresh_id:
                return

        base_status = gen["progress_status"]
        state["refresh"] = refresh_id
        if read_state:
            phase, step_idx = gen["progress_phase"]
        else:
            step_idx += 1
            if gen.get("abort", False):
                phase = " - Aborting"
            elif step_idx == num_inference_steps:
                phase = " - VAE Decoding"
            else:
                phase = " - Denoising"
            gen["progress_phase"] = (phase, step_idx)

        status_msg = base_status + phase
        progress_args = (
            [(step_idx , num_inference_steps) , status_msg , num_inference_steps]
            if step_idx >= 0
            else [0, status_msg]
        )

        progress(*progress_args)
        gen["progress_args"] = progress_args

    return callback
1292
def abort_generation(state):
    """Request that the generation currently in progress be aborted.

    When ``gen`` has no ``"in_progress"`` key nothing is changed and
    ``("", enabled button)`` is returned. Otherwise the abort flag is set,
    ``extra_orders`` is reset to 0, ``wan_model._interrupt`` is raised, the
    user is notified through ``gr.Info`` and ``(message, disabled button)``
    is returned.
    """
    gen = get_gen_info(state)
    if "in_progress" not in gen:
        return "", gr.Button(interactive= True)

    gen["abort"] = True
    gen["extra_orders"] = 0
    wan_model._interrupt = True
    notice = "Processing Request to abort Current Generation"
    gr.Info(notice)
    return notice, gr.Button(interactive= False)
1304
+
1305
def is_gen_location(state):
    """Tell whether ``state["image2video"]`` matches ``gen["location"]``.

    Returns ``None`` when no location is recorded, otherwise the result of
    comparing ``state["image2video"]`` with the recorded location.
    """
    location = get_gen_info(state).get("location", None)
    if location is None:
        return None
    return state["image2video"] == location
1312
+
1313
+
1314
def refresh_gallery(state, msg):
    """Compute the UI updates shown while (or after) videos are generated.

    ``msg`` is stored as ``gen["last_msg"]`` when ``is_gen_location(state)``
    is truthy.

    Returns a 7-tuple of Gradio components, in this order: the gallery
    (file list and selected index), an HTML block, two buttons with opposite
    visibility, a row, the queue table update from ``update_queue_data`` and
    a button whose interactivity is the negation of ``gen["abort"]``
    (presumably the abort button; confirm against the event wiring).

    While a generation is in progress and the queue is non-empty, the HTML
    block shows the prompt of ``queue[0]`` and, for image-to-video tasks, its
    start/end thumbnails. Otherwise the HTML block and the row are hidden.

    The prompt is HTML-escaped before being embedded, and a missing
    ``file_list`` is treated as empty when computing the selected index.
    """
    from html import escape as html_escape  # local import keeps the block self-contained

    gen = get_gen_info(state)

    if is_gen_location(state):
        gen["last_msg"] = msg
    file_list = gen.get("file_list", None)
    choice = gen.get("selected", 0)
    in_progress = "in_progress" in gen
    if in_progress and gen.get("last_selected", True):
        # Follow the newest video; ``file_list`` may not exist yet.
        choice = max(len(file_list or []) - 1, 0)

    queue = gen.get("queue", [])
    abort_interactive = not gen.get("abort", False)
    if not in_progress or len(queue) == 0:
        return (
            gr.Gallery(selected_index=choice, value = file_list),
            gr.HTML("", visible= False),
            gr.Button(visible=True),
            gr.Button(visible=False),
            gr.Row(visible=False),
            update_queue_data(queue),
            gr.Button(interactive= abort_interactive),
        )

    task = queue[0]
    start_img_md = ""
    end_img_md = ""
    prompt = html_escape(task["prompt"], quote=True)

    if task.get('image2video'):
        start_img_uri = task.get('start_image_data_base64')
        end_img_uri = task.get('end_image_data_base64')
        thumbnail_size = "100px"
        if start_img_uri:
            start_img_md = f'<img src="{start_img_uri}" alt="Start" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
        if end_img_uri:
            end_img_md = f'<img src="{end_img_uri}" alt="End" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'

    parts = [
        "<STYLE> #PINFO, #PINFO th, #PINFO td {border: 1px solid #CCCCCC;background-color:#FFFFFF;}</STYLE><TABLE WIDTH=100% ID=PINFO ><TR><TD width=100%>" + prompt + "</TD>"
    ]
    if start_img_md != "":
        parts.append("<TD>" + start_img_md + "</TD>")
    if end_img_md != "":
        parts.append("<TD>" + end_img_md + "</TD>")
    parts.append("</TR></TABLE>")

    html_output = gr.HTML("".join(parts), visible= True)
    return (
        gr.Gallery(selected_index=choice, value = file_list),
        html_output,
        gr.Button(visible=False),
        gr.Button(visible=True),
        gr.Row(visible=True),
        update_queue_data(queue),
        gr.Button(interactive= abort_interactive),
    )
1356
+
1357
+
1358
+
1359
def finalize_generation(state):
    """Reset generation bookkeeping once a generation run has ended.

    Removes ``gen["in_progress"]`` if present, resets ``gen["extra_orders"]``
    to 0, waits 0.2 seconds and clears the module-level ``gen_in_progress``
    flag. The gallery selection is ``gen["selected"]`` unless
    ``gen["last_selected"]`` is truthy (default), in which case it is the
    index of the last entry of ``gen["file_list"]``.

    Returns a 6-tuple of Gradio component updates: gallery, an interactive
    button, a visible button, a hidden button, a hidden column and an empty
    hidden HTML block.
    """
    global gen_in_progress

    gen = get_gen_info(state)
    choice = gen.get("selected", 0)
    gen.pop("in_progress", None)
    if gen.get("last_selected", True):
        choice = len(gen.get("file_list", [])) - 1

    gen["extra_orders"] = 0
    time.sleep(0.2)
    gen_in_progress = False
    return (
        gr.Gallery(selected_index=choice),
        gr.Button(interactive= True),
        gr.Button(visible= True),
        gr.Button(visible= False),
        gr.Column(visible= False),
        gr.HTML(visible= False, value=""),
    )
1374
+
1375
 
1376
def refresh_gallery_on_trigger(state):
    """Push the current file list to the gallery when an update is pending.

    When ``gen["update_gallery"]`` is truthy the flag is cleared and a
    ``gr.update`` carrying ``gen["file_list"]`` (default ``[]``) is returned;
    otherwise the function returns ``None``.
    """
    gen = get_gen_info(state)
    if not gen.get("update_gallery", False):
        return None
    gen['update_gallery'] = False
    return gr.update(value=gen.get("file_list", []))
1382
 
1383
def select_video(state , event_data: gr.EventData):
    """Record which gallery item was selected.

    Reads ``index`` (default 0) from the event payload, stores it in
    ``gen["selected"]`` and sets ``gen["last_selected"]`` to whether that
    index is at or beyond the last entry of ``gen["file_list"]``. Does
    nothing when the event carries no payload. Always returns ``None``.
    """
    payload = event_data._data
    gen = get_gen_info(state)

    if payload is None:
        return
    picked = payload.get("index", 0)
    known_files = gen.get("file_list", [])
    gen["last_selected"] = (picked + 1) >= len(known_files)
    gen["selected"] = picked
    return
1393
 
1394
  def expand_slist(slist, num_inference_steps ):
 
1420
 
1421
  def generate_video(
1422
  task_id,
1423
+ progress,
1424
  prompt,
1425
  negative_prompt,
1426
  resolution,
 
1454
  ):
1455
 
1456
  global wan_model, offloadobj, reload_needed, last_model_type
1457
+ gen = get_gen_info(state)
1458
+
1459
+ file_list = gen["file_list"]
1460
+ prompt_no = gen["prompt_no"]
1461
+
1462
  file_model_needed = model_needed(image2video)
1463
+ # queue = gen.get("queue", [])
1464
+ # with lock:
1465
+ # queue_not_empty = len(queue) > 0
1466
+ # if(last_model_type != image2video and (queue_not_empty or server_config.get("reload_model",1) == 2) and (file_model_needed != model_filename or reload_needed)):
1467
+ if file_model_needed != model_filename or reload_needed:
1468
  del wan_model
1469
  if offloadobj is not None:
1470
  offloadobj.release()
1471
  del offloadobj
1472
  gc.collect()
1473
+ yield f"Loading model {get_model_name(file_model_needed)}..."
1474
  wan_model, offloadobj, trans = load_models(image2video)
1475
+ yield f"Model loaded"
1476
  reload_needed= False
1477
 
1478
  if wan_model == None:
1479
+ gr.Info("Unable to generate a Video while a new configuration is being applied.")
1480
  if attention_mode == "auto":
1481
  attn = get_auto_attention()
1482
  elif attention_mode in attention_modes_supported:
 
1485
  gr.Info(f"You have selected attention mode '{attention_mode}'. However it is not installed or supported on your system. You should either install it or switch to the default 'sdpa' attention.")
1486
  return
1487
 
1488
+
1489
+
1490
+ if not image2video:
1491
+ width, height = resolution.split("x")
1492
+ width, height = int(width), int(height)
1493
 
1494
  if slg_switch == 0:
1495
  slg_layers = None
 
 
 
 
1496
 
 
 
 
 
 
 
 
 
 
1497
  offload.shared_state["_attention"] = attn
1498
 
1499
  # VAE Tiling
 
1517
 
1518
  trans = wan_model.model
1519
 
 
 
1520
  temp_filename = None
 
 
 
 
 
 
 
1521
 
1522
  loras = state["loras"]
1523
  if len(loras) > 0:
 
1561
  raise gr.Error("Error while loading Loras: " + ", ".join(error_files))
1562
  seed = None if seed == -1 else seed
1563
  # negative_prompt = "" # not applicable in the inference
 
 
 
 
1564
 
1565
  enable_RIFLEx = RIFLEx_setting == 0 and video_length > (6* 16) or RIFLEx_setting == 1
1566
  # VAE Tiling
 
1597
  if seed == None or seed <0:
1598
  seed = random.randint(0, 999999999)
1599
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1600
  global save_path
1601
  os.makedirs(save_path, exist_ok=True)
1602
  video_no = 0
1603
  abort = False
 
 
 
1604
  gc.collect()
1605
  torch.cuda.empty_cache()
1606
  wan_model._interrupt = False
1607
+ gen["abort"] = False
1608
+ gen["prompt"] = prompt
1609
+ repeat_no = 0
1610
+ extra_generation = 0
1611
+ while True:
1612
+ extra_generation += gen.get("extra_orders",0)
1613
+ gen["extra_orders"] = 0
1614
+ total_generation = repeat_generation + extra_generation
1615
+ gen["total_generation"] = total_generation
1616
+ if abort or repeat_no >= total_generation:
1617
+ break
1618
+ repeat_no +=1
1619
+ gen["repeat_no"] = repeat_no
1620
+ prompts_max = gen["prompts_max"]
1621
+ status = get_generation_status(prompt_no, prompts_max, repeat_no, total_generation)
1622
+
1623
+ yield status
1624
+
1625
+ gen["progress_status"] = status
1626
+ gen["progress_phase"] = (" - Encoding Prompt", -1 )
1627
+ callback = build_callback(state, trans, progress, status, num_inference_steps)
1628
+ progress_args = [0, status + " - Encoding Prompt"]
1629
+ progress(*progress_args )
1630
+ gen["progress_args"] = progress_args
1631
+
1632
  try:
1633
+ start_time = time.time()
1634
+ # with tracker_lock:
1635
+ # progress_tracker[task_id] = {
1636
+ # 'current_step': 0,
1637
+ # 'total_steps': num_inference_steps,
1638
+ # 'start_time': start_time,
1639
+ # 'last_update': start_time,
1640
+ # 'repeats': repeat_generation, # f"{video_no}/{repeat_generation}",
1641
+ # 'status': "Encoding Prompt"
1642
+ # }
1643
  if trans.enable_teacache:
1644
  trans.teacache_counter = 0
1645
  trans.num_steps = num_inference_steps
1646
+ trans.teacache_skipped_steps = 0
1647
  trans.previous_residual_uncond = None
1648
  trans.previous_residual_cond = None
1649
 
 
1651
  if image2video:
1652
  samples = wan_model.generate(
1653
  prompt,
1654
+ image_to_continue,
1655
+ image_to_end if image_to_end != None else None,
1656
  frame_num=(video_length // 4)* 4 + 1,
1657
  max_area=MAX_AREA_CONFIGS[resolution],
1658
  shift=flow_shift,
 
1670
  slg_end = slg_end/100,
1671
  cfg_star_switch = cfg_star_switch,
1672
  cfg_zero_step = cfg_zero_step,
1673
+ add_frames_for_end_image = not "Fun" in transformer_filename_i2v,
1674
  )
1675
  else:
1676
  samples = wan_model.generate(
 
1694
  cfg_zero_step = cfg_zero_step,
1695
  )
1696
  except Exception as e:
 
1697
  if temp_filename!= None and os.path.isfile(temp_filename):
1698
  os.remove(temp_filename)
1699
  offload.last_offload_obj.unload_all()
 
1716
  if any( keyword in frame.name for keyword in keyword_list):
1717
  VRAM_crash = True
1718
  break
1719
+
1720
+ _ , exc_value, exc_traceback = sys.exc_info()
1721
+
1722
  state["prompt"] = ""
1723
  if VRAM_crash:
1724
+ new_error = "The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames."
1725
  else:
1726
+ new_error = gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
1727
+ tb = traceback.format_exc().split('\n')[:-2]
1728
+ print('\n'.join(tb))
1729
+ raise gr.Error(new_error, print_exception= False)
1730
+
1731
  finally:
1732
+ pass
1733
+ # with tracker_lock:
1734
+ # if task_id in progress_tracker:
1735
+ # del progress_tracker[task_id]
1736
 
1737
  if trans.enable_teacache:
1738
  print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{num_inference_steps}" )
 
1749
  end_time = time.time()
1750
  abort = True
1751
  state["prompt"] = ""
1752
+ # yield f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s"
1753
  else:
1754
  sample = samples.cpu()
1755
  # video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")
 
1768
  normalize=True,
1769
  value_range=(-1, 1))
1770
 
1771
+ configs = get_settings_dict(state, image2video, prompt, 0 if image_to_end == None else 1 , video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
1772
  loras_mult_choices, tea_cache , tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end, cfg_star_switch, cfg_zero_step)
1773
 
1774
  metadata_choice = server_config.get("metadata_choice","metadata")
 
1790
 
1791
  if temp_filename!= None and os.path.isfile(temp_filename):
1792
  os.remove(temp_filename)
 
1793
  offload.unload_loras_from_model(trans)
1794
 
1795
def prepare_generate_video(state):
    """Switch the buttons to "generating" mode if validation succeeded.

    Returns updates for (generate button, add-to-queue button,
    current-generation column). When ``state["validate_success"]`` is not 1
    the idle layout is kept: generate button visible, the other two hidden.
    """
    validated = state.get("validate_success",0) == 1
    return gr.Button(visible= not validated), gr.Button(visible= validated), gr.Column(visible= validated)
1800
+
1801
+
1802
def wait_tasks_done(state, progress=gr.Progress()):
    """Mirror the status of a generation that this call did not start.

    Generator: first yields the last known status message (if any). When
    ``is_gen_location(state)`` is ``None`` or truthy it stops right there.
    Otherwise it polls the shared ``gen`` dict every 0.5 s, yielding each new
    ``last_msg`` and replaying ``progress_args`` on ``progress`` until
    ``gen["in_progress"]`` is no longer set.
    """
    gen = get_gen_info(state)
    gen_location = is_gen_location(state)

    last_msg = gen.get("last_msg", "")
    if len(last_msg) > 0:
        yield last_msg

    if gen_location is None or gen_location:
        # Inside a generator the returned value is not yielded to the caller.
        return gr.Text()

    while True:
        current_msg = gen.get("last_msg", "")
        if len(current_msg) > 0 and current_msg != last_msg:
            yield current_msg
            last_msg = current_msg

        args = gen.get("progress_args", None)
        if args is not None:
            progress(*args)

        if not gen.get("in_progress", False):
            break
        time.sleep(0.5)
1829
+
1830
+
1831
+
1832
def process_tasks(state, progress=gr.Progress()):
    """Run every queued generation task in order, streaming status messages.

    Generator: yields the status strings produced by ``generate_video`` for
    each task, then a final "Total Generation Time" (or "aborted") message.
    Returns immediately when the queue is empty.

    Side effects on the ``gen`` dict: sets ``location``, trims or resets
    ``file_list`` / ``selected`` according to the ``clear_file_list`` server
    setting, sets ``in_progress`` and ``prompt_no``, and resets
    ``prompts_max`` / ``prompt`` at the end. Also raises the module-level
    ``gen_in_progress`` flag.

    If a task raises, the queue is emptied and the exception propagates.
    """
    global gen_in_progress

    gen = get_gen_info(state)
    queue = gen.get("queue", [])
    if len(queue) == 0:
        return

    gen["location"] = state["image2video"]

    # Keep at most `clear_file_list` of the most recent files from earlier
    # runs; a value of 0 (or less) starts from an empty gallery.
    clear_file_list = server_config.get("clear_file_list", 0)
    file_list = gen.get("file_list", [])
    if clear_file_list > 0:
        keep_file_from = max(len(file_list) - clear_file_list, 0)
        # Shift the remembered selection by the number of dropped files.
        choice = max(gen.get("selected", 0) - keep_file_from, 0)
        file_list = file_list[keep_file_from:]
    else:
        file_list = []
        choice = 0
    gen["selected"] = choice
    gen["file_list"] = file_list

    start_time = time.time()
    gen_in_progress = True
    gen["in_progress"] = True

    prompt_no = 0
    while len(queue) > 0:
        prompt_no += 1
        gen["prompt_no"] = prompt_no
        task = queue[0]
        task_id = task["id"]
        params = task['params']
        iterator = iter(generate_video(task_id, progress, **params))
        while True:
            ok = False
            try:
                status = next(iterator, "#")
                # Reaching this line means the generator did not raise.
                # Normal exhaustion (the "#" sentinel) must count as success:
                # previously the flag was set after the sentinel check, so a
                # task that finished normally also wiped the remaining queue.
                ok = True
            finally:
                if not ok:
                    # The task failed: drop everything still queued.
                    queue.clear()
            if status == "#":
                break
            yield status

        if gen.get("abort"):
            # An abort still stops the whole run, as it did before.
            queue.clear()
            break

        # Remove only the task that just finished; others stay queued.
        queue[:] = [item for item in queue if item['id'] != task['id']]

    gen["prompts_max"] = 0
    gen["prompt"] = ""
    end_time = time.time()
    if gen.get("abort"):
        yield f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s"
    else:
        yield f"Total Generation Time: {end_time-start_time:.1f}s"
1893
+
1894
+
1895
def get_generation_status(prompt_no, prompts_max, repeat_no, repeat_max):
    """Build the short progress label shown while videos are generated.

    Args:
        prompt_no: 1-based index of the prompt being processed.
        prompts_max: total number of prompts; values of 1 or less are
            treated as "single prompt", so a reset or decremented counter
            (0 or negative) never produces a label such as "Prompt 1/0".
        repeat_no: 1-based index of the sample for the current prompt.
        repeat_max: total number of samples for the current prompt.

    Returns:
        "Video", "Sample i/n", "Prompt p/m" or "Prompt p/m, Sample i/n".
    """
    parts = []
    if prompts_max > 1:
        parts.append(f"Prompt {prompt_no}/{prompts_max}")
    if repeat_max != 1:
        parts.append(f"Sample {repeat_no}/{repeat_max}")
    if not parts:
        return "Video"
    return ", ".join(parts)
1906
+
1907
+
1908
# Module-level counter; each call to get_new_refresh_id() advances it by one.
refresh_id = 0


def get_new_refresh_id():
    """Advance the module-level ``refresh_id`` counter and return its new value."""
    global refresh_id
    refresh_id = refresh_id + 1
    return refresh_id
1914
+
1915
def update_status(state):
    """Recompute the progress label from the counters stored in ``gen``.

    Writes the label to ``gen["progress_status"]`` and stores a fresh id in
    ``gen["refresh"]``. Returns ``None``.

    The counters ``prompt_no``, ``total_generation`` and ``repeat_no`` are
    written by the generation code only once a run has reached the
    corresponding step. If any of them is missing there is nothing to
    refresh yet, so the function returns without touching ``gen`` instead of
    raising ``KeyError``.
    """
    gen = get_gen_info(state)
    if any(key not in gen for key in ("prompt_no", "total_generation", "repeat_no")):
        return

    status = get_generation_status(
        gen["prompt_no"],
        gen.get("prompts_max", 0),
        gen["repeat_no"],
        gen["total_generation"],
    )
    gen["progress_status"] = status
    gen["refresh"] = get_new_refresh_id()
1924
+
1925
+
1926
def one_more_sample(state):
    """Order one additional sample for the prompt currently being generated.

    Increments ``gen["extra_orders"]``. When a generation is in progress and
    its counters are available, also refreshes ``gen["progress_status"]`` /
    ``gen["refresh"]`` and shows an info message with the new total.

    Returns ``state`` unchanged (it is wired as both input and output).
    """
    gen = get_gen_info(state)
    extra_orders = gen.get("extra_orders", 0) + 1
    gen["extra_orders"] = extra_orders

    if not gen.get("in_progress", False):
        return state

    # The counters below are published by the generation code only after it
    # has reached the sampling loop. If the button is pressed before that,
    # the order recorded above is still kept; only the status refresh is
    # skipped, instead of failing with a KeyError.
    if any(key not in gen for key in ("prompt_no", "total_generation", "repeat_no")):
        return state

    total_generation = gen["total_generation"] + extra_orders
    status = get_generation_status(
        gen["prompt_no"],
        gen.get("prompts_max", 0),
        gen["repeat_no"],
        total_generation,
    )

    gen["progress_status"] = status
    gen["refresh"] = get_new_refresh_id()
    gr.Info(f"An extra sample generation is planned for a total of {total_generation} videos for this prompt")

    return state
1946
 
1947
  def get_new_preset_msg(advanced = True):
1948
  if advanced:
 
1994
 
1995
 
1996
  lset_name_filename = lset_name + ".lset"
1997
+ full_lset_name_filename = os.path.join(get_lora_dir(state["image2video"]), lset_name_filename)
1998
 
1999
  with open(full_lset_name_filename, "w", encoding="utf-8") as writer:
2000
  writer.write(json.dumps(lset, indent=4))
 
2011
 
2012
  def delete_lset(state, lset_name):
2013
  loras_presets = state["loras_presets"]
2014
+ lset_name_filename = os.path.join( get_lora_dir(state["image2video"]), sanitize_file_name(lset_name) + ".lset" )
2015
  if len(lset_name) > 0 and lset_name != get_new_preset_msg(True) and lset_name != get_new_preset_msg(False):
2016
  if not os.path.isfile(lset_name_filename):
2017
  raise gr.Error(f"Preset '{lset_name}' not found ")
 
2032
  def refresh_lora_list(state, lset_name, loras_choices):
2033
  loras_names = state["loras_names"]
2034
  prev_lora_names_selected = [ loras_names[int(i)] for i in loras_choices]
2035
+ image2video= state["image2video"]
2036
+ loras, loras_names, loras_presets, _, _, _, _ = setup_loras(image2video, None, get_lora_dir(image2video), lora_preselected_preset, None)
2037
  state["loras"] = loras
2038
  state["loras_names"] = loras_names
2039
  state["loras_presets"] = loras_presets
 
2073
  gr.Info("Please choose a preset in the list or create one")
2074
  else:
2075
  loras = state["loras"]
2076
+ loras_choices, loras_mult_choices, preset_prompt, full_prompt, error = extract_preset(state["image2video"], lset_name, loras)
2077
  if len(error) > 0:
2078
  gr.Info(error)
2079
  else:
 
2274
  if state.get("validate_success",0) != 1:
2275
  return
2276
 
2277
+ image2video = state["image2video"]
2278
+ ui_defaults = get_settings_dict(state, image2video, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
2279
  loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step)
2280
 
2281
+ defaults_filename = get_settings_file_name(image2video)
2282
 
2283
  with open(defaults_filename, "w", encoding="utf-8") as f:
2284
  json.dump(ui_defaults, f, indent=4)
 
2321
 
2322
  state_dict["advanced"] = advanced
2323
  state_dict["loras_model"] = filename
2324
+ state_dict["image2video"] = image2video
2325
+ gen = dict()
2326
+ gen["queue"] = []
2327
+ state_dict["gen"] = gen
2328
+
2329
+ preset_to_load = lora_preselected_preset if use_image2video == image2video else ""
2330
 
2331
  loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset = setup_loras(image2video, None, get_lora_dir(image2video), preset_to_load, None)
2332
 
 
2339
  launch_loras = []
2340
  launch_multis_str = ""
2341
 
2342
+ if len(default_lora_preset) > 0 and image2video == use_image2video:
2343
  launch_preset = default_lora_preset
2344
  launch_prompt = default_lora_preset_prompt
2345
  launch_loras = default_loras_choices
 
2364
 
2365
 
2366
  header = gr.Markdown(generate_header(model_filename, compile, attention_mode))
 
 
 
 
 
 
 
 
 
2367
  with gr.Row():
2368
  with gr.Column():
2369
  with gr.Column(visible=False, elem_id="image-modal-container") as modal_container:
 
2591
  cfg_zero_step = gr.Slider(-1, 39, value=ui_defaults.get("cfg_zero_step",-1), step=1, label="CFG Zero below this Layer (Extra Process)")
2592
 
2593
  with gr.Row():
2594
+ save_settings_btn = gr.Button("Set Settings as Default", visible = not args.lock_config)
2595
  show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
2596
  fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
2597
  with gr.Column():
2598
+ gen_status = gr.Text(label="Status", interactive= False)
2599
+ full_sync = gr.Text(label="Status", interactive= False, visible= False)
2600
+ light_sync = gr.Text(label="Status", interactive= False, visible= False)
2601
  gen_progress_html = gr.HTML(
2602
  label="Status",
2603
  value="Idle",
2604
+ elem_id="generation_progress_bar_container", visible= False
2605
  )
2606
  output = gr.Gallery(
2607
  label="Generated videos", show_label=False, elem_id="gallery"
2608
  , columns=[3], rows=[1], object_fit="contain", height=450, selected_index=0, interactive= False)
2609
  generate_btn = gr.Button("Generate")
2610
+ add_to_queue_btn = gr.Button("Add New Prompt To Queue", visible = False)
2611
+
2612
+ with gr.Column(visible= False) as current_gen_column:
2613
+ with gr.Row():
2614
+ gen_info = gr.HTML(visible=False, min_height=1)
2615
+ with gr.Row():
2616
+ onemore_btn = gr.Button("One More Sample Please !")
2617
+ abort_btn = gr.Button("Abort")
2618
+
2619
+ queue_df = gr.DataFrame(
2620
+ headers=["Qty","Prompt", "Length","Steps","Start", "End", "", "", ""],
2621
+ datatype=[ "str","markdown","str", "markdown", "markdown", "markdown", "str", "str", "str"],
2622
+ interactive=False,
2623
+ col_count=(9, "fixed"),
2624
+ wrap=True,
2625
+ value=[],
2626
+ visible= False,
2627
+ # every=1,
2628
+ elem_id="queue_df"
2629
+ )
2630
+ # queue_df = gr.HTML("",
2631
+ # visible= False,
2632
+ # elem_id="queue_df"
2633
+ # )
2634
+
2635
def handle_selection(state, evt: gr.SelectData):
    """React to a click on a cell of the queue table.

    Column 6 / 7 / 8 move the clicked task up, move it down, or remove it
    (removal also decrements ``gen["prompts_max"]`` and refreshes the
    status). Column 4 / 5 open the modal with the task's start / end image
    when one is stored. Any other click leaves everything unchanged.

    Returns a 3-tuple of updates for (queue table, modal image, modal
    container).
    """
    gen = get_gen_info(state)
    queue = gen.get("queue", [])

    if evt.index is None:
        return gr.update(), gr.update(), gr.update(visible=False)

    row_index, col_index = evt.index

    # Action columns: reorder or delete the clicked task.
    if col_index == 6:
        table = move_up(queue, [row_index])
        return table, gr.update(), gr.update(visible=False)
    if col_index == 7:
        table = move_down(queue, [row_index])
        return table, gr.update(), gr.update(visible=False)
    if col_index == 8:
        table = remove_task(queue, [row_index])
        gen["prompts_max"] = gen.get("prompts_max",0) - 1
        update_status(state)
        return table, gr.update(), gr.update(visible=False)

    # Thumbnail columns: look up the full image stored with the task.
    image_keys = {4: 'start_image_data', 5: 'end_image_data'}
    image_key = image_keys.get(col_index)
    image = None
    if image_key is not None:
        with lock:
            if row_index < len(queue):
                image = queue[row_index].get(image_key)

    if not image:
        return gr.update(), gr.update(), gr.update(visible=False)
    return gr.update(), gr.update(value=image), gr.update(visible=True)
2674
+ selected_indices = gr.State([])
2675
+ queue_df.select(
2676
+ fn=handle_selection,
2677
+ inputs=state,
2678
+ outputs=[queue_df, modal_image_display, modal_container],
2679
+ )
2680
+ # gallery_update_trigger.change(
2681
+ # fn=refresh_gallery_on_trigger,
2682
+ # inputs=[state],
2683
+ # outputs=[output]
2684
+ # )
2685
+ # queue_df.change(
2686
+ # fn=refresh_gallery,
2687
+ # inputs=[state],
2688
+ # outputs=[gallery_update_trigger]
2689
+ # ).then(
2690
+ # fn=refresh_progress,
2691
+ # inputs=None,
2692
+ # outputs=[progress_update_trigger]
2693
+ # )
2694
+ progress_update_trigger.change(
2695
+ fn=update_generation_status,
2696
+ inputs=[progress_update_trigger],
2697
+ outputs=[gen_progress_html],
2698
+ show_progress="hidden"
2699
+ )
2700
  save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
2701
  save_settings, inputs = [state, prompt, image_prompt_type_radio, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt,
2702
  loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers,
 
2712
  )
2713
  refresh_lora_btn.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
2714
  refresh_lora_btn2.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
 
2715
  output.select(select_video, state, None )
2716
 
2717
+ gen_status.change(refresh_gallery,
2718
+ inputs = [state, gen_status],
2719
+ outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn])
2720
+
2721
+ full_sync.change(refresh_gallery,
2722
+ inputs = [state, gen_status],
2723
+ outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn]
2724
+ ).then( fn=wait_tasks_done,
2725
+ inputs= [state],
2726
+ outputs =[gen_status],
2727
+ ).then(finalize_generation,
2728
+ inputs= [state],
2729
+ outputs= [output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info]
2730
+ )
2731
+ light_sync.change(refresh_gallery,
2732
+ inputs = [state, gen_status],
2733
+ outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn]
2734
+ )
2735
+
2736
+ abort_btn.click(abort_generation, [state], [gen_status, abort_btn] ) #.then(refresh_gallery, inputs = [state, gen_info], outputs = [output, gen_info, queue_df] )
2737
+ onemore_btn.click(fn=one_more_sample,inputs=[state], outputs= [state])
2738
+
2739
+
2740
+ gen_inputs=[
2741
+ prompt,
2742
+ negative_prompt,
2743
+ resolution,
2744
+ video_length,
2745
+ seed,
2746
+ num_inference_steps,
2747
+ guidance_scale,
2748
+ flow_shift,
2749
+ embedded_guidance_scale,
2750
+ repeat_generation,
2751
+ multi_images_gen_type,
2752
+ tea_cache_setting,
2753
+ tea_cache_start_step_perc,
2754
+ loras_choices,
2755
+ loras_mult_choices,
2756
+ image_prompt_type_radio,
2757
+ image_to_continue,
2758
+ image_to_end,
2759
+ video_to_continue,
2760
+ max_frames,
2761
+ RIFLEx_setting,
2762
+ slg_switch,
2763
+ slg_layers,
2764
+ slg_start_perc,
2765
+ slg_end_perc,
2766
+ cfg_star_switch,
2767
+ cfg_zero_step,
2768
+ state,
2769
+ gr.State(image2video)
2770
+ ]
2771
+
2772
+ generate_btn.click(fn=validate_wizard_prompt,
2773
+ inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
2774
+ outputs= [prompt]
2775
+ ).then(fn=process_prompt_and_add_tasks,
2776
+ inputs = gen_inputs,
2777
+ outputs= queue_df
2778
+ ).then(fn=prepare_generate_video,
2779
+ inputs= [state],
2780
+ outputs= [generate_btn, add_to_queue_btn, current_gen_column],
2781
+ ).then(fn=process_tasks,
2782
+ inputs= [state],
2783
+ outputs= [gen_status],
2784
+ ).then(finalize_generation,
2785
+ inputs= [state],
2786
+ outputs= [output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info]
2787
+ )
2788
+
2789
+ add_to_queue_btn.click(fn=validate_wizard_prompt,
2790
+ inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
2791
+ outputs= [prompt]
2792
  ).then(
2793
  fn=process_prompt_and_add_tasks,
2794
+ inputs = gen_inputs,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2795
  outputs=queue_df
2796
+ ).then(
2797
+ fn=update_status,
2798
+ inputs = [state],
2799
  )
2800
+
2801
+
2802
  close_modal_button.click(
2803
  lambda: gr.update(visible=False),
2804
  inputs=[],
2805
  outputs=[modal_container]
2806
  )
2807
+ return loras_column, loras_choices, presets_column, lset_name, header, light_sync, full_sync, state
2808
+
2809
+ def generate_doxnload_tab(presets_column, loras_column, lset_name,loras_choices, state):
# Builds the content of the "Downloads" tab: an introduction text, a button
# that starts the Loras download, and a status area.
# The components passed in (presets_column, loras_column, lset_name,
# loras_choices, state) belong to another tab; they are only used as
# inputs/outputs of the click handler wired at the end of this function.
# NOTE: the function name is spelled "doxnload" (sic); create_demo calls it
# under this exact name, so it must not be renamed here alone.
2810
+ with gr.Row():
2811
+ with gr.Row(scale =2):
2812
+ gr.Markdown("<I>Wan2GP's Lora Festival ! Press the following button to download i2v <B>Remade</B> Loras collection (and bonuses Loras).")
2813
+ with gr.Row(scale =1):
2814
+ download_loras_btn = gr.Button("---> Let the Lora's Festival Start !", scale =1)
2815
+ with gr.Row(scale =1):
2816
+ gr.Markdown("")
2817
+ with gr.Row() as download_status_row:
2818
+ download_status = gr.Markdown()
2819
+
2820
# Clicking the button runs download_loras, then refresh_lora_list to update
# the preset name and Loras choice components.
+ download_loras_btn.click(fn=download_loras, inputs=[], outputs=[download_status_row, download_status, presets_column, loras_column]).then(fn=refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
2821
 
2822
+
2823
  def generate_configuration_tab():
2824
  state_dict = {}
2825
  state = gr.State(state_dict)
 
2836
  value= index,
2837
  label="Transformer model for Text to Video",
2838
  interactive= not lock_ui_transformer,
2839
+ visible=True
2840
  )
2841
  index = transformer_choices_i2v.index(transformer_filename_i2v)
2842
  index = 0 if index ==0 else index
 
2853
  value= index,
2854
  label="Transformer model for Image to Video",
2855
  interactive= not lock_ui_transformer,
2856
+ visible = True,
2857
  )
2858
  index = text_encoder_choices.index(text_encoder_filename)
2859
  index = 0 if index ==0 else index
 
2949
  reload_choice = gr.Dropdown(
2950
  choices=[
2951
  ("When changing tabs", 1),
2952
+ ("When pressing Generate", 2),
2953
  ],
2954
  value=server_config.get("reload_model",2),
2955
  label="Reload model"
 
3002
  gr.Markdown("- <B>Remade_AI</B> : for creating their awesome Loras collection")
3003
 
3004
 
3005
+ def on_tab_select(global_state, t2v_state, i2v_state, evt: gr.SelectData):
 
 
3006
  t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode)
3007
  i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode)
3008
 
3009
  new_t2v = evt.index == 0
3010
  new_i2v = evt.index == 1
3011
+ i2v_light_sync = gr.Text()
3012
+ t2v_light_sync = gr.Text()
3013
+ i2v_full_sync = gr.Text()
3014
+ t2v_full_sync = gr.Text()
3015
+ if new_t2v or new_i2v:
3016
+ last_tab_was_image2video =global_state.get("last_tab_was_image2video", None)
3017
+ if last_tab_was_image2video == None or last_tab_was_image2video:
3018
+ gen = i2v_state["gen"]
3019
+ t2v_state["gen"] = gen
3020
+ else:
3021
+ gen = t2v_state["gen"]
3022
+ i2v_state["gen"] = gen
3023
+
3024
+
3025
+ if last_tab_was_image2video != None and new_t2v != new_i2v:
3026
+ gen_location = gen.get("location", None)
3027
+ if "in_progress" in gen and gen_location !=None and not (gen_location and new_i2v or not gen_location and new_t2v) :
3028
+ if new_i2v:
3029
+ i2v_full_sync = gr.Text(str(time.time()))
3030
+ else:
3031
+ t2v_full_sync = gr.Text(str(time.time()))
3032
+ else:
3033
+ if new_i2v:
3034
+ i2v_light_sync = gr.Text(str(time.time()))
3035
+ else:
3036
+ t2v_light_sync = gr.Text(str(time.time()))
3037
+
3038
+
3039
+ global_state["last_tab_was_image2video"] = new_i2v
3040
 
3041
  if(server_config.get("reload_model",2) == 1):
3042
+ queue = gen.get("queue", [])
3043
+
3044
+ queue_empty = len(queue) == 0
3045
  if queue_empty:
3046
  global wan_model, offloadobj
3047
  if wan_model is not None:
 
3051
  wan_model = None
3052
  gc.collect()
3053
  torch.cuda.empty_cache()
3054
+ wan_model, offloadobj, trans = load_models(new_i2v)
3055
  del trans
3056
 
3057
  if new_t2v or new_i2v:
 
3077
  gr.Column(visible= visible),
3078
  gr.Dropdown(choices=lset_choices, value=get_new_preset_msg(advanced), visible=visible),
3079
  t2v_header,
3080
+ t2v_light_sync,
3081
+ t2v_full_sync,
3082
  gr.Column(),
3083
  gr.Dropdown(),
3084
  gr.Column(),
3085
  gr.Dropdown(),
3086
+ gr.Markdown(),
3087
+ gr.Text(),
3088
+ gr.Text(),
3089
  ]
3090
  else:
3091
  return [
 
3093
  gr.Dropdown(),
3094
  gr.Column(),
3095
  gr.Dropdown(),
3096
+ gr.Markdown(),
3097
+ gr.Text(),
3098
+ gr.Text(),
3099
+ gr.Text(),
3100
  gr.Column(visible= visible),
3101
  gr.Dropdown(choices=new_loras_choices, visible=visible, value=[]),
3102
  gr.Column(visible= visible),
3103
  gr.Dropdown(choices=lset_choices, value=get_new_preset_msg(advanced), visible=visible),
3104
  i2v_header,
3105
+ i2v_light_sync,
3106
+ i2v_full_sync,
3107
  ]
3108
 
3109
+ return [gr.Column(), gr.Dropdown(), gr.Column(), gr.Dropdown(), t2v_header, t2v_light_sync, t2v_full_sync,
3110
+ gr.Column(), gr.Dropdown(), gr.Column(), gr.Dropdown(), i2v_header, i2v_light_sync, i2v_full_sync]
3111
 
3112
 
3113
  def create_demo():
 
3167
  overflow: hidden;
3168
  text-overflow: ellipsis;
3169
  }
3170
+ # #queue_df td:nth-child(-n+5) {
3171
+ # cursor: default !important;
3172
+ # pointer-events: none;
3173
+ # }
3174
+ # #queue_df td:nth-child(6) {
3175
+ # cursor: default !important;
3176
+ # }
3177
+ # #queue_df th {
3178
+ # pointer-events: none;
3179
+ # text-align: center;
3180
+ # vertical-align: middle;
3181
+ # }
3182
+ # #queue_df table {
3183
+ # width: 100%;
3184
+ # overflow: hidden !important;
3185
+ # }
3186
+ # #queue_df::-webkit-scrollbar {
3187
+ # display: none !important;
3188
+ # }
3189
+ # #queue_df {
3190
+ # scrollbar-width: none !important;
3191
+ # -ms-overflow-style: none !important;
3192
+ # }
3193
+ # #queue_df th:nth-child(1),
3194
+ # #queue_df td:nth-child(1) {
3195
+ # width: 90px;
3196
+ # text-align: center;
3197
+ # vertical-align: middle;
3198
+ # }
3199
+ # #queue_df th:nth-child(1) {
3200
+ # font-size: 0.8em;
3201
+ # }
3202
+ # #queue_df th:nth-child(2),
3203
+ # #queue_df td:nth-child(2) {
3204
+ # width: 85px;
3205
+ # text-align: center;
3206
+ # vertical-align: middle;
3207
+ # }
3208
+ # #queue_df th:nth-child(2) {
3209
+ # font-size: 0.5em;
3210
+ # }
3211
+ # #queue_df th:nth-child(3),
3212
+ # #queue_df td:nth-child(3) {
3213
+ # width: 75px;
3214
+ # text-align: center;
3215
+ # vertical-align: middle;
3216
+ # }
3217
+ # #queue_df th:nth-child(3) {
3218
+ # font-size: 0.6em;
3219
+ # }
3220
+ # #queue_df th:nth-child(4),
3221
+ # #queue_df td:nth-child(4) {
3222
+ # width: 65px;
3223
+ # text-align: center;
3224
+ # white-space: nowrap;
3225
+ # }
3226
+ # #queue_df th:nth-child(4) {
3227
+ # font-size: 0.9em;
3228
+ # }
3229
+ # #queue_df th:nth-child(5),
3230
+ # #queue_df td:nth-child(5) {
3231
+ # width: 60px;
3232
+ # text-align: center;
3233
+ # white-space: nowrap;
3234
+ # }
3235
+ # #queue_df th:nth-child(6),
3236
+ # #queue_df td:nth-child(6) {
3237
+ # width: auto;
3238
+ # text-align: center;
3239
+ # white-space: normal;
3240
+ # }
3241
+ # #queue_df th:nth-child(6) {
3242
+ # font-size: 0.8em;
3243
+ # }
3244
+ # #queue_df th:nth-child(7), #queue_df td:nth-child(7),
3245
+ # #queue_df th:nth-child(8), #queue_df td:nth-child(8) {
3246
+ # width: 60px;
3247
+ # text-align: center;
3248
+ # vertical-align: middle;
3249
+ # }
3250
+ # #queue_df td:nth-child(7) img,
3251
+ # #queue_df td:nth-child(8) img {
3252
+ # max-width: 50px;
3253
+ # max-height: 50px;
3254
+ # object-fit: contain;
3255
+ # display: block;
3256
+ # margin: auto;
3257
+ # cursor: pointer;
3258
+ # }
3259
+ # #queue_df th:nth-child(9), #queue_df td:nth-child(9),
3260
+ # #queue_df th:nth-child(10), #queue_df td:nth-child(10),
3261
+ # #queue_df th:nth-child(11), #queue_df td:nth-child(11) {
3262
+ # width: 20px;
3263
+ # padding: 2px !important;
3264
+ # cursor: pointer;
3265
+ # text-align: center;
3266
+ # font-weight: bold;
3267
+ # vertical-align: middle;
3268
+ # }
3269
+ # #queue_df td:nth-child(7):hover,
3270
+ # #queue_df td:nth-child(8):hover,
3271
+ # #queue_df td:nth-child(9):hover,
3272
+ # #queue_df td:nth-child(10):hover,
3273
+ # #queue_df td:nth-child(11):hover {
3274
+ # background-color: #e0e0e0;
3275
+ # }
3276
  #image-modal-container {
3277
  position: fixed;
3278
  top: 0;
 
3354
  pointer-events: none;
3355
  }
3356
  """
3357
+ with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md"), title= "Wan2GP") as demo:
3358
+ gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v3.4 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
3359
  gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !</FONT>")
3360
 
3361
  with gr.Accordion("Click here for some Info on how to use Wan2GP", open = False):
 
3365
  gr.Markdown("- 1280 x 720 with a 14B model: 80 frames (5s): 11 GB of VRAM")
3366
  gr.Markdown("It is not recommmended to generate a video longer than 8s (128 frames) even if there is still some VRAM left as some artifacts may appear")
3367
  gr.Markdown("Please note that if your turn on compilation, the first denoising step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
3368
+ global_dict = {}
3369
+ global_dict["last_tab_was_image2video"] = use_image2video
3370
+ global_state = gr.State(global_dict)
3371
 
3372
  with gr.Tabs(selected="i2v" if use_image2video else "t2v") as main_tabs:
3373
  with gr.Tab("Text To Video", id="t2v") as t2v_tab:
3374
+ t2v_loras_column, t2v_loras_choices, t2v_presets_column, t2v_lset_name, t2v_header, t2v_light_sync, t2v_full_sync, t2v_state = generate_video_tab(False)
3375
  with gr.Tab("Image To Video", id="i2v") as i2v_tab:
3376
+ i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header, i2v_light_sync, i2v_full_sync, i2v_state = generate_video_tab(True)
3377
  if not args.lock_config:
3378
+ with gr.Tab("Downloads", id="downloads") as downloads_tab:
3379
+ generate_doxnload_tab(i2v_presets_column, i2v_loras_column, i2v_lset_name, i2v_loras_choices, i2v_state)
3380
  with gr.Tab("Configuration"):
3381
  generate_configuration_tab()
3382
  with gr.Tab("About"):
3383
  generate_about_tab()
3384
  main_tabs.select(
3385
  fn=on_tab_select,
3386
+ inputs=[global_state, t2v_state, i2v_state],
3387
  outputs=[
3388
+ t2v_loras_column, t2v_loras_choices, t2v_presets_column, t2v_lset_name, t2v_header, t2v_light_sync, t2v_full_sync,
3389
+ i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header, i2v_light_sync, i2v_full_sync
3390
  ]
3391
  )
3392
  return demo
3393
 
3394
  if __name__ == "__main__":
3395
+ # threading.Thread(target=runner, daemon=True).start()
3396
  os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
3397
  server_port = int(args.server_port)
3398
  if os.name == "nt":
wan/image2video.py CHANGED
@@ -40,7 +40,12 @@ def optimized_scale(positive_flat, negative_flat):
40
  st_star = dot_product / squared_norm
41
 
42
  return st_star
43
-
 
 
 
 
 
44
 
45
  class WanI2V:
46
 
@@ -90,7 +95,6 @@ class WanI2V:
90
 
91
  self.num_train_timesteps = config.num_train_timesteps
92
  self.param_dtype = config.param_dtype
93
-
94
  shard_fn = partial(shard_model, device_id=device_id)
95
  self.text_encoder = T5EncoderModel(
96
  text_len=config.text_len,
@@ -208,16 +212,16 @@ class WanI2V:
208
  - H: Frame height (from max_area)
209
  - W: Frame width from max_area)
210
  """
211
- img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
212
  lat_frames = int((frame_num - 1) // self.vae_stride[0] + 1)
213
  any_end_frame = img2 !=None
214
  if any_end_frame:
215
  any_end_frame = True
216
- img2 = TF.to_tensor(img2).sub_(0.5).div_(0.5).to(self.device)
217
  if add_frames_for_end_image:
218
  frame_num +=1
219
  lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2)
220
-
221
  h, w = img.shape[1:]
222
  aspect_ratio = h / w
223
  lat_h = round(
@@ -229,6 +233,15 @@ class WanI2V:
229
  h = lat_h * self.vae_stride[1]
230
  w = lat_w * self.vae_stride[2]
231
 
 
 
 
 
 
 
 
 
 
232
  max_seq_len = lat_frames * lat_h * lat_w // ( self.patch_size[1] * self.patch_size[2])
233
  max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
234
 
@@ -273,21 +286,32 @@ class WanI2V:
273
 
274
  from mmgp import offload
275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
  offload.last_offload_obj.unload_all()
277
  if any_end_frame:
278
- img_interpolated = torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode='bicubic').transpose(0, 1).to(torch.bfloat16)
279
- img2_interpolated = torch.nn.functional.interpolate(img2[None].cpu(), size=(h, w), mode='bicubic').transpose(0, 1).to(torch.bfloat16)
280
  mean2 = 0
281
  enc= torch.concat([
282
  img_interpolated,
283
- torch.full( (3, frame_num-2, h, w), mean2, device="cpu", dtype= torch.bfloat16),
284
- img2_interpolated,
285
  ], dim=1).to(self.device)
286
  else:
287
  enc= torch.concat([
288
- torch.nn.functional.interpolate(
289
- img[None].cpu(), size=(h, w), mode='bicubic').transpose(0, 1).to(torch.bfloat16),
290
- torch.zeros(3, frame_num-1, h, w, device="cpu", dtype= torch.bfloat16)
291
  ], dim=1).to(self.device)
292
 
293
  lat_y = self.vae.encode([enc], VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]
@@ -333,7 +357,8 @@ class WanI2V:
333
  'seq_len': max_seq_len,
334
  'y': [y],
335
  'freqs' : freqs,
336
- 'pipeline' : self
 
337
  }
338
 
339
  arg_null = {
@@ -342,7 +367,8 @@ class WanI2V:
342
  'seq_len': max_seq_len,
343
  'y': [y],
344
  'freqs' : freqs,
345
- 'pipeline' : self
 
346
  }
347
 
348
  arg_both= {
@@ -352,7 +378,8 @@ class WanI2V:
352
  'seq_len': max_seq_len,
353
  'y': [y],
354
  'freqs' : freqs,
355
- 'pipeline' : self
 
356
  }
357
 
358
  if offload_model:
@@ -363,7 +390,7 @@ class WanI2V:
363
 
364
  # self.model.to(self.device)
365
  if callback != None:
366
- callback(-1, None)
367
 
368
  for i, t in enumerate(tqdm(timesteps)):
369
  offload.set_step_no_for_lora(self.model, i)
@@ -437,7 +464,7 @@ class WanI2V:
437
  del timestep
438
 
439
  if callback is not None:
440
- callback(i, latent)
441
 
442
 
443
  x0 = [latent.to(self.device, dtype=torch.bfloat16)]
@@ -451,7 +478,7 @@ class WanI2V:
451
  video = self.vae.decode(x0, VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]
452
 
453
  if any_end_frame and add_frames_for_end_image:
454
- # video[:, -1:] = img2_interpolated
455
  video = video[:, :-1]
456
 
457
  else:
 
40
  st_star = dot_product / squared_norm
41
 
42
  return st_star
43
+
44
+ def resize_lanczos(img, h, w):
45
+ img = Image.fromarray(np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8))
46
+ img = img.resize((w,h), resample=Image.Resampling.LANCZOS)
47
+ return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0)
48
+
49
 
50
  class WanI2V:
51
 
 
95
 
96
  self.num_train_timesteps = config.num_train_timesteps
97
  self.param_dtype = config.param_dtype
 
98
  shard_fn = partial(shard_model, device_id=device_id)
99
  self.text_encoder = T5EncoderModel(
100
  text_len=config.text_len,
 
212
  - H: Frame height (from max_area)
213
  - W: Frame width from max_area)
214
  """
215
+ img = TF.to_tensor(img)
216
  lat_frames = int((frame_num - 1) // self.vae_stride[0] + 1)
217
  any_end_frame = img2 !=None
218
  if any_end_frame:
219
  any_end_frame = True
220
+ img2 = TF.to_tensor(img2)
221
  if add_frames_for_end_image:
222
  frame_num +=1
223
  lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2)
224
+
225
  h, w = img.shape[1:]
226
  aspect_ratio = h / w
227
  lat_h = round(
 
233
  h = lat_h * self.vae_stride[1]
234
  w = lat_w * self.vae_stride[2]
235
 
236
+ clip_image_size = self.clip.model.image_size
237
+ img_interpolated = resize_lanczos(img, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device)
238
+ img = resize_lanczos(img, clip_image_size, clip_image_size)
239
+ img = img.sub_(0.5).div_(0.5).to(self.device)
240
+ if img2!= None:
241
+ img_interpolated2 = resize_lanczos(img2, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device)
242
+ img2 = resize_lanczos(img2, clip_image_size, clip_image_size)
243
+ img2 = img2.sub_(0.5).div_(0.5).to(self.device)
244
+
245
  max_seq_len = lat_frames * lat_h * lat_w // ( self.patch_size[1] * self.patch_size[2])
246
  max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
247
 
 
286
 
287
  from mmgp import offload
288
 
289
+
290
+ # img_interpolated.save('aaa.png')
291
+
292
+ # img_interpolated = torch.from_numpy(np.array(img_interpolated).astype(np.float32) / 255.0).movedim(-1, 0)
293
+
294
+ # img_interpolated = torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode='lanczos')
295
+ # img_interpolated = img_interpolated.squeeze(0).transpose(0,2).transpose(1,0)
296
+ # img_interpolated = img_interpolated.clamp(-1, 1)
297
+ # img_interpolated = (img_interpolated + 1)/2
298
+ # img_interpolated = (img_interpolated*255).type(torch.uint8)
299
+ # img_interpolated = img_interpolated.cpu().numpy()
300
+ # xxx = Image.fromarray(img_interpolated, 'RGB')
301
+ # xxx.save('my.png')
302
+
303
  offload.last_offload_obj.unload_all()
304
  if any_end_frame:
 
 
305
  mean2 = 0
306
  enc= torch.concat([
307
  img_interpolated,
308
+ torch.full( (3, frame_num-2, h, w), mean2, device=self.device, dtype= torch.bfloat16),
309
+ img_interpolated2,
310
  ], dim=1).to(self.device)
311
  else:
312
  enc= torch.concat([
313
+ img_interpolated,
314
+ torch.zeros(3, frame_num-1, h, w, device=self.device, dtype= torch.bfloat16)
 
315
  ], dim=1).to(self.device)
316
 
317
  lat_y = self.vae.encode([enc], VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]
 
357
  'seq_len': max_seq_len,
358
  'y': [y],
359
  'freqs' : freqs,
360
+ 'pipeline' : self,
361
+ 'callback' : callback
362
  }
363
 
364
  arg_null = {
 
367
  'seq_len': max_seq_len,
368
  'y': [y],
369
  'freqs' : freqs,
370
+ 'pipeline' : self,
371
+ 'callback' : callback
372
  }
373
 
374
  arg_both= {
 
378
  'seq_len': max_seq_len,
379
  'y': [y],
380
  'freqs' : freqs,
381
+ 'pipeline' : self,
382
+ 'callback' : callback
383
  }
384
 
385
  if offload_model:
 
390
 
391
  # self.model.to(self.device)
392
  if callback != None:
393
+ callback(-1, True)
394
 
395
  for i, t in enumerate(tqdm(timesteps)):
396
  offload.set_step_no_for_lora(self.model, i)
 
464
  del timestep
465
 
466
  if callback is not None:
467
+ callback(i, False)
468
 
469
 
470
  x0 = [latent.to(self.device, dtype=torch.bfloat16)]
 
478
  video = self.vae.decode(x0, VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]
479
 
480
  if any_end_frame and add_frames_for_end_image:
481
+ # video[:, -1:] = img_interpolated2
482
  video = video[:, :-1]
483
 
484
  else:
wan/modules/model.py CHANGED
@@ -704,6 +704,7 @@ class WanModel(ModelMixin, ConfigMixin):
704
  is_uncond=False,
705
  max_steps = 0,
706
  slg_layers=None,
 
707
  ):
708
  r"""
709
  Forward pass through the diffusion model
@@ -835,12 +836,10 @@ class WanModel(ModelMixin, ConfigMixin):
835
  freqs=freqs,
836
  # context=context,
837
  context_lens=context_lens)
838
-
839
  for block_idx, block in enumerate(self.blocks):
840
  offload.shared_state["layer"] = block_idx
841
- if "refresh" in offload.shared_state:
842
- del offload.shared_state["refresh"]
843
- offload.shared_state["callback"](-1, -1, True)
844
  if pipeline._interrupt:
845
  if joint_pass:
846
  return None, None
 
704
  is_uncond=False,
705
  max_steps = 0,
706
  slg_layers=None,
707
+ callback = None,
708
  ):
709
  r"""
710
  Forward pass through the diffusion model
 
836
  freqs=freqs,
837
  # context=context,
838
  context_lens=context_lens)
 
839
  for block_idx, block in enumerate(self.blocks):
840
  offload.shared_state["layer"] = block_idx
841
+ if callback != None:
842
+ callback(-1, False, True)
 
843
  if pipeline._interrupt:
844
  if joint_pass:
845
  return None, None
wan/text2video.py CHANGED
@@ -268,7 +268,7 @@ class WanT2V:
268
  if self.model.enable_teacache:
269
  self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
270
  if callback != None:
271
- callback(-1, None)
272
  for i, t in enumerate(tqdm(timesteps)):
273
  latent_model_input = latents
274
  slg_layers_local = None
@@ -322,7 +322,7 @@ class WanT2V:
322
  del temp_x0
323
 
324
  if callback is not None:
325
- callback(i, latents)
326
 
327
  x0 = latents
328
  if offload_model:
 
268
  if self.model.enable_teacache:
269
  self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
270
  if callback != None:
271
+ callback(-1, True)
272
  for i, t in enumerate(tqdm(timesteps)):
273
  latent_model_input = latents
274
  slg_layers_local = None
 
322
  del temp_x0
323
 
324
  if callback is not None:
325
+ callback(i, False)
326
 
327
  x0 = latents
328
  if offload_model: