Tophness2022 commited on
Commit
134cb56
·
1 Parent(s): 12652e0

add queue saving/loading/clearing/autosaving/autoloading, fix empty prompt logic

Browse files
Files changed (1) hide show
  1. wgp.py +581 -12
wgp.py CHANGED
@@ -28,6 +28,12 @@ from wan.utils import prompt_parser
28
  import base64
29
  import io
30
  from PIL import Image
 
 
 
 
 
 
31
  PROMPT_VARS_MAX = 10
32
 
33
  target_mmgp_version = "3.3.4"
@@ -98,10 +104,14 @@ def process_prompt_and_add_tasks(state, model_choice):
98
  inputs["state"] = state
99
  inputs.pop("lset_name")
100
  if inputs == None:
101
- return
 
102
  prompt = inputs["prompt"]
103
  if len(prompt) ==0:
104
- return
 
 
 
105
  prompt, errors = prompt_parser.process_template(prompt)
106
  if len(errors) > 0:
107
  gr.Info("Error processing prompt template: " + errors)
@@ -111,7 +121,10 @@ def process_prompt_and_add_tasks(state, model_choice):
111
  prompts = prompt.replace("\r", "").split("\n")
112
  prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")]
113
  if len(prompts) ==0:
114
- return
 
 
 
115
 
116
  resolution = inputs["resolution"]
117
  width, height = resolution.split("x")
@@ -250,9 +263,6 @@ def process_prompt_and_add_tasks(state, model_choice):
250
  queue= gen.get("queue", [])
251
  return update_queue_data(queue)
252
 
253
-
254
-
255
-
256
  def add_video_task(**inputs):
257
  global task_id
258
  state = inputs["state"]
@@ -327,6 +337,444 @@ def remove_task(queue, selected_indices):
327
  del queue[idx]
328
  return update_queue_data(queue)
329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
 
331
 
332
  def get_queue_table(queue):
@@ -390,7 +838,7 @@ def get_queue_table(queue):
390
  ])
391
  return data
392
  def update_queue_data(queue):
393
-
394
  data = get_queue_table(queue)
395
 
396
  # if len(data) == 0:
@@ -1993,6 +2441,7 @@ def process_tasks(state, progress=gr.Progress()):
1993
  yield status
1994
 
1995
  queue[:] = [item for item in queue if item['id'] != task['id']]
 
1996
 
1997
  gen["prompts_max"] = 0
1998
  gen["prompt"] = ""
@@ -2716,7 +3165,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
2716
  wizard_variables = "\n".join(variables)
2717
  for _ in range( PROMPT_VARS_MAX - len(prompt_vars)):
2718
  prompt_vars.append(gr.Textbox(visible= False, min_width=80, show_label= False))
2719
-
2720
  with gr.Column(not advanced_prompt) as prompt_column_wizard:
2721
  wizard_prompt = gr.Textbox(visible = not advanced_prompt, label="Prompts (each new line of prompt will generate a new video, # lines = comments)", value=default_wizard_prompt, lines=3)
2722
  wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False)
@@ -2902,7 +3351,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
2902
  queue_df = gr.DataFrame(
2903
  headers=["Qty","Prompt", "Length","Steps","", "", "", "", ""],
2904
  datatype=[ "str","markdown","str", "markdown", "markdown", "markdown", "str", "str", "str"],
2905
- column_widths= ["50","", "65","55", "60", "60", "30", "30", "35"],
2906
  interactive=False,
2907
  col_count=(9, "fixed"),
2908
  wrap=True,
@@ -2911,6 +3360,72 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
2911
  visible= False,
2912
  elem_id="queue_df"
2913
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2914
 
2915
  extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column,
2916
  prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, advanced_row] # show_advanced presets_column,
@@ -3014,7 +3529,16 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
3014
  outputs=[modal_container]
3015
  )
3016
 
3017
- return loras_choices, lset_name, state
 
 
 
 
 
 
 
 
 
3018
 
3019
  def generate_download_tab(lset_name,loras_choices, state):
3020
  with gr.Row():
@@ -3479,8 +4003,15 @@ def create_demo():
3479
  with gr.Row():
3480
  header = gr.Markdown(generate_header(transformer_filename, compile, attention_mode), visible= True)
3481
  with gr.Row():
3482
-
3483
- loras_choices, lset_name, state = generate_video_tab(model_choice = model_choice, header = header)
 
 
 
 
 
 
 
3484
  with gr.Tab("Informations"):
3485
  generate_info_tab()
3486
  if not args.lock_config:
@@ -3491,9 +4022,47 @@ def create_demo():
3491
  with gr.Tab("About"):
3492
  generate_about_tab()
3493
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3494
  return demo
3495
 
3496
  if __name__ == "__main__":
 
3497
  # threading.Thread(target=runner, daemon=True).start()
3498
  os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
3499
  server_port = int(args.server_port)
 
28
  import base64
29
  import io
30
  from PIL import Image
31
+ import zipfile
32
+ import tempfile
33
+ import shutil
34
+ import atexit
35
+ global_queue_ref = []
36
+ AUTOSAVE_FILENAME = "queue.zip"
37
  PROMPT_VARS_MAX = 10
38
 
39
  target_mmgp_version = "3.3.4"
 
104
  inputs["state"] = state
105
  inputs.pop("lset_name")
106
  if inputs == None:
107
+ gr.Warning("Internal state error: Could not retrieve inputs for the model.")
108
+ return update_queue_data(queue)
109
  prompt = inputs["prompt"]
110
  if len(prompt) ==0:
111
+ gr.Info("Prompt cannot be empty.")
112
+ gen = get_gen_info(state)
113
+ queue = gen.get("queue", [])
114
+ return get_queue_table(queue)
115
  prompt, errors = prompt_parser.process_template(prompt)
116
  if len(errors) > 0:
117
  gr.Info("Error processing prompt template: " + errors)
 
121
  prompts = prompt.replace("\r", "").split("\n")
122
  prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")]
123
  if len(prompts) ==0:
124
+ gr.Info("Prompt cannot be empty.")
125
+ gen = get_gen_info(state)
126
+ queue = gen.get("queue", [])
127
+ return get_queue_table(queue)
128
 
129
  resolution = inputs["resolution"]
130
  width, height = resolution.split("x")
 
263
  queue= gen.get("queue", [])
264
  return update_queue_data(queue)
265
 
 
 
 
266
  def add_video_task(**inputs):
267
  global task_id
268
  state = inputs["state"]
 
337
  del queue[idx]
338
  return update_queue_data(queue)
339
 
340
+ def update_global_queue_ref(queue):
341
+ global global_queue_ref
342
+ with lock:
343
+ global_queue_ref = queue[:]
344
+
345
+ def save_queue_action(state):
346
+ gen = get_gen_info(state)
347
+ queue = gen.get("queue", [])
348
+
349
+ if not queue or len(queue) <=1 : # Check if queue is empty or only has the placeholder
350
+ gr.Info("Queue is empty. Nothing to save.")
351
+ return None # Return None if nothing to save
352
+
353
+ # Use an in-memory buffer for the zip file
354
+ zip_buffer = io.BytesIO()
355
+
356
+ # Still use a temporary directory *only* for storing images before zipping
357
+ with tempfile.TemporaryDirectory() as tmpdir:
358
+ queue_manifest = []
359
+ image_paths_in_zip = {} # Tracks image PIL object ID -> filename in zip
360
+
361
+ for task_index, task in enumerate(queue):
362
+ # Skip the placeholder item if it exists
363
+ if task is None or not isinstance(task, dict) or task_index == 0: continue
364
+
365
+ params_copy = task.get('params', {}).copy()
366
+ task_id_s = task.get('id', f"task_{task_index}") # Use a different var name
367
+
368
+ image_keys = ["image_start", "image_end", "image_refs"]
369
+ for key in image_keys:
370
+ images_pil = params_copy.get(key)
371
+ if images_pil is None:
372
+ continue
373
+
374
+ # Ensure images_pil is always a list for processing
375
+ is_originally_list = isinstance(images_pil, list)
376
+ if not is_originally_list:
377
+ images_pil = [images_pil]
378
+
379
+ image_filenames_for_json = []
380
+ for img_index, pil_image in enumerate(images_pil):
381
+ # Ensure it's actually a PIL Image object before proceeding
382
+ if not isinstance(pil_image, Image.Image):
383
+ print(f"Warning: Expected PIL Image for key '{key}' in task {task_id_s}, got {type(pil_image)}. Skipping image.")
384
+ continue
385
+
386
+ # Use object ID to check if this specific image instance is already saved
387
+ img_id = id(pil_image)
388
+ if img_id in image_paths_in_zip:
389
+ # If already saved, just add its filename to the list
390
+ image_filenames_for_json.append(image_paths_in_zip[img_id])
391
+ continue # Move to the next image in the list
392
+
393
+ # Image not saved yet, create filename and save path
394
+ img_filename_in_zip = f"task{task_id_s}_{key}_{img_index}.png"
395
+ img_save_path = os.path.join(tmpdir, img_filename_in_zip)
396
+
397
+ try:
398
+ # Save the image to the temporary directory
399
+ pil_image.save(img_save_path, "PNG")
400
+ image_filenames_for_json.append(img_filename_in_zip)
401
+ # Store the mapping from image ID to its filename in the zip
402
+ image_paths_in_zip[img_id] = img_filename_in_zip
403
+ except Exception as e:
404
+ print(f"Error saving image {img_filename_in_zip} for task {task_id_s}: {e}")
405
+ # Optionally decide if you want to continue or fail here
406
+
407
+ # Update the params_copy with the list of filenames (or single filename)
408
+ if image_filenames_for_json:
409
+ params_copy[key] = image_filenames_for_json if is_originally_list else image_filenames_for_json[0]
410
+ else:
411
+ # If no images were successfully processed for this key, remove it
412
+ params_copy.pop(key, None)
413
+
414
+
415
+ # Clean up parameters before adding to manifest
416
+ params_copy.pop('state', None)
417
+ params_copy.pop('start_image_data_base64', None) # Don't need base64 in saved queue
418
+ params_copy.pop('end_image_data_base64', None)
419
+ # Also remove the actual PIL data if it somehow remained
420
+ params_copy.pop('start_image_data', None)
421
+ params_copy.pop('end_image_data', None)
422
+
423
+ manifest_entry = {
424
+ "id": task.get('id'),
425
+ "params": params_copy,
426
+ # Keep other necessary top-level task info if needed, like repeats etc.
427
+ # Example: "repeats": task.get('repeats', 1)
428
+ }
429
+ queue_manifest.append(manifest_entry)
430
+
431
+ # --- Create queue.json content ---
432
+ manifest_path = os.path.join(tmpdir, "queue.json")
433
+ try:
434
+ with open(manifest_path, 'w', encoding='utf-8') as f:
435
+ # Dump only the relevant manifest data
436
+ json.dump(queue_manifest, f, indent=4)
437
+ except Exception as e:
438
+ print(f"Error writing queue.json: {e}")
439
+ gr.Warning("Failed to create queue manifest.")
440
+ return None # Return None on failure
441
+
442
+ # --- Create the zip file in memory ---
443
+ try:
444
+ with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zf:
445
+ # Add queue.json
446
+ zf.write(manifest_path, arcname="queue.json")
447
+
448
+ # Add all unique images that were saved to the temp dir
449
+ for saved_img_rel_path in image_paths_in_zip.values():
450
+ saved_img_abs_path = os.path.join(tmpdir, saved_img_rel_path)
451
+ if os.path.exists(saved_img_abs_path):
452
+ zf.write(saved_img_abs_path, arcname=saved_img_rel_path)
453
+ else:
454
+ # This shouldn't happen if saving was successful, but good to check
455
+ print(f"Warning: Image file {saved_img_rel_path} not found during zipping.")
456
+
457
+ # --- Prepare for return ---
458
+ # Move buffer position to the beginning
459
+ zip_buffer.seek(0)
460
+ # Read the binary content
461
+ zip_binary_content = zip_buffer.getvalue()
462
+ # Encode as base64 string
463
+ zip_base64 = base64.b64encode(zip_binary_content).decode('utf-8')
464
+ print(f"Queue successfully prepared as base64 string ({len(zip_base64)} chars).")
465
+ return zip_base64
466
+
467
+ except Exception as e:
468
+ print(f"Error creating zip file in memory: {e}")
469
+ gr.Warning("Failed to create zip data for download.")
470
+ return None # Return None on failure
471
+ finally:
472
+ zip_buffer.close()
473
+
474
+ def load_queue_action(filepath, state):
475
+ global task_id
476
+ gen = get_gen_info(state)
477
+ original_queue = gen.get("queue", []) # Store original queue for error case
478
+
479
+ if not filepath or not hasattr(filepath, 'name') or not Path(filepath.name).is_file():
480
+ print("[load_queue_action] Warning: No valid file selected or file not found.")
481
+ # Return the current state of the DataFrame
482
+ return update_queue_data(original_queue)
483
+
484
+ newly_loaded_queue = []
485
+ max_id_in_file = 0
486
+ error_message = ""
487
+ local_queue_copy_for_global_ref = None
488
+
489
+ try:
490
+ print(f"[load_queue_action] Attempting to load queue from: {filepath.name}")
491
+ with tempfile.TemporaryDirectory() as tmpdir:
492
+ with zipfile.ZipFile(filepath.name, 'r') as zf:
493
+ if "queue.json" not in zf.namelist(): raise ValueError("queue.json not found in zip file")
494
+ print(f"[load_queue_action] Extracting {filepath.name} to {tmpdir}")
495
+ zf.extractall(tmpdir)
496
+ print(f"[load_queue_action] Extraction complete.")
497
+
498
+ manifest_path = os.path.join(tmpdir, "queue.json")
499
+ print(f"[load_queue_action] Reading manifest: {manifest_path}")
500
+ with open(manifest_path, 'r', encoding='utf-8') as f:
501
+ loaded_manifest = json.load(f)
502
+ print(f"[load_queue_action] Manifest loaded. Processing {len(loaded_manifest)} tasks.")
503
+
504
+ for task_index, task_data in enumerate(loaded_manifest):
505
+ # (Keep the existing task processing logic here...)
506
+ if task_data is None or not isinstance(task_data, dict):
507
+ print(f"[load_queue_action] Skipping invalid task data at index {task_index}")
508
+ continue
509
+
510
+ params = task_data.get('params', {})
511
+ task_id_loaded = task_data.get('id', 0)
512
+ max_id_in_file = max(max_id_in_file, task_id_loaded)
513
+ loaded_pil_images = {}
514
+ image_keys = ["image_start", "image_end", "image_refs"]
515
+ params['state'] = state # Add state back temporarily for consistency if needed by internal logic, but it's removed before saving
516
+
517
+ for key in image_keys:
518
+ image_filenames = params.get(key)
519
+ if image_filenames is None: continue
520
+ is_list = isinstance(image_filenames, list)
521
+ if not is_list: image_filenames = [image_filenames]
522
+ loaded_pils = []
523
+ for img_filename_in_zip in image_filenames:
524
+ if not isinstance(img_filename_in_zip, str): continue
525
+ img_load_path = os.path.join(tmpdir, img_filename_in_zip)
526
+ if not os.path.exists(img_load_path):
527
+ print(f"[load_queue_action] Image file not found during load: {img_load_path}")
528
+ continue
529
+ try:
530
+ pil_image = Image.open(img_load_path)
531
+ # Ensure the image data is loaded into memory before the temp dir is cleaned up
532
+ pil_image.load()
533
+ # Convert image right after loading
534
+ converted_image = convert_image(pil_image)
535
+ loaded_pils.append(converted_image)
536
+ pil_image.close() # Close the file handle
537
+ except Exception as img_e:
538
+ print(f"[load_queue_action] Error loading image {img_filename_in_zip}: {img_e}")
539
+ if loaded_pils:
540
+ params[key] = loaded_pils if is_list else loaded_pils[0]
541
+ loaded_pil_images[key] = params[key] # Store loaded PILs for preview generation
542
+ else: params.pop(key, None)
543
+
544
+ # Generate preview base64 strings
545
+ primary_preview_pil, secondary_preview_pil = None, None
546
+ start_prev_pil_list = loaded_pil_images.get("image_start")
547
+ end_prev_pil_list = loaded_pil_images.get("image_end")
548
+ ref_prev_pil_list = loaded_pil_images.get("image_refs")
549
+
550
+ # Extract first image for preview if available
551
+ if start_prev_pil_list:
552
+ primary_preview_pil = start_prev_pil_list[0] if isinstance(start_prev_pil_list, list) and start_prev_pil_list else start_prev_pil_list if not isinstance(start_prev_pil_list, list) else None
553
+ if end_prev_pil_list:
554
+ secondary_preview_pil = end_prev_pil_list[0] if isinstance(end_prev_pil_list, list) and end_prev_pil_list else end_prev_pil_list if not isinstance(end_prev_pil_list, list) else None
555
+ elif ref_prev_pil_list and isinstance(ref_prev_pil_list, list) and ref_prev_pil_list:
556
+ primary_preview_pil = ref_prev_pil_list[0]
557
+
558
+ # Generate base64 only if PIL image exists
559
+ start_b64 = [pil_to_base64_uri(primary_preview_pil, format="jpeg", quality=70)] if primary_preview_pil else None
560
+ end_b64 = [pil_to_base64_uri(secondary_preview_pil, format="jpeg", quality=70)] if secondary_preview_pil else None
561
+
562
+ # Get top-level image data (PIL objects) for runtime task
563
+ top_level_start_image = loaded_pil_images.get("image_start")
564
+ top_level_end_image = loaded_pil_images.get("image_end")
565
+
566
+ # Construct the runtime task dictionary
567
+ runtime_task = {
568
+ "id": task_id_loaded,
569
+ "params": params.copy(), # Use a copy of params
570
+ # Extract necessary params for top level if they exist
571
+ "repeats": params.get('repeat_generation', 1),
572
+ "length": params.get('video_length'),
573
+ "steps": params.get('num_inference_steps'),
574
+ "prompt": params.get('prompt'),
575
+ # Store the actual loaded PIL image data here
576
+ "start_image_data": top_level_start_image,
577
+ "end_image_data": top_level_end_image,
578
+ # Store base64 previews generated above
579
+ "start_image_data_base64": start_b64,
580
+ "end_image_data_base64": end_b64,
581
+ }
582
+ newly_loaded_queue.append(runtime_task)
583
+ print(f"[load_queue_action] Processed task {task_index+1}/{len(loaded_manifest)}, ID: {task_id_loaded}")
584
+
585
+ # --- State Update ---
586
+ with lock:
587
+ print("[load_queue_action] Acquiring lock to update state...")
588
+ gen["queue"] = newly_loaded_queue[:] # Replace the queue in the state
589
+ local_queue_copy_for_global_ref = gen["queue"][:] # Copy for global ref update
590
+ current_max_id_in_new_queue = max([t['id'] for t in newly_loaded_queue if 'id' in t] + [0]) # Safer max ID calculation
591
+
592
+ # Update global task ID only if the loaded max ID is higher
593
+ if current_max_id_in_new_queue > task_id:
594
+ print(f"[load_queue_action] Updating global task_id from {task_id} to {current_max_id_in_new_queue + 1}")
595
+ task_id = current_max_id_in_new_queue + 1 # Ensure next ID is unique
596
+ else:
597
+ print(f"[load_queue_action] Global task_id ({task_id}) is >= max in file ({current_max_id_in_new_queue}). Not changing task_id.")
598
+
599
+ gen["prompts_max"] = len(newly_loaded_queue)
600
+ print("[load_queue_action] State update complete. Releasing lock.")
601
+
602
+ # --- Global Reference Update ---
603
+ if local_queue_copy_for_global_ref is not None:
604
+ print("[load_queue_action] Updating global queue reference...")
605
+ update_global_queue_ref(local_queue_copy_for_global_ref)
606
+ else:
607
+ # This case should ideally not be reached if state update happens
608
+ print("[load_queue_action] Warning: Skipping global ref update as local copy is None.")
609
+
610
+ print(f"[load_queue_action] Queue load successful. Returning DataFrame update for {len(newly_loaded_queue)} tasks.")
611
+ # *** Return the DataFrame update object ***
612
+ return update_queue_data(newly_loaded_queue)
613
+
614
+ except (ValueError, zipfile.BadZipFile, FileNotFoundError, Exception) as e:
615
+ error_message = f"Error during queue load: {e}"
616
+ print(f"[load_queue_action] Caught error: {error_message}")
617
+ traceback.print_exc()
618
+ # Optionally show a Gradio warning/error to the user
619
+ gr.Warning(f"Failed to load queue: {error_message[:200]}") # Show truncated error
620
+
621
+ # *** Return the DataFrame update for the original queue ***
622
+ print("[load_queue_action] Load failed. Returning DataFrame update for original queue.")
623
+ return update_queue_data(original_queue)
624
+ finally:
625
+ # Clean up the uploaded file object if it exists and has a path
626
+ if filepath and hasattr(filepath, 'name') and filepath.name and os.path.exists(filepath.name):
627
+ try:
628
+ # Gradio often uses temp files, attempting removal is good practice
629
+ # os.remove(filepath.name)
630
+ # print(f"[load_queue_action] Cleaned up temporary upload file: {filepath.name}")
631
+ pass # Let Gradio manage its temp files unless specifically needed
632
+ except OSError as e:
633
+ # Ignore errors like "file not found" if already cleaned up
634
+ print(f"[load_queue_action] Info: Could not remove temp file {filepath.name}: {e}")
635
+ pass
636
+
637
+ def clear_queue_action(state):
638
+ gen = get_gen_info(state)
639
+ queue = gen.get("queue", [])
640
+ if not queue:
641
+ gr.Info("Queue is already empty.")
642
+ return update_queue_data([])
643
+
644
+ with lock:
645
+ queue.clear()
646
+ gen["prompts_max"] = 0
647
+
648
+ gr.Info("Queue cleared.")
649
+ return update_queue_data([])
650
+
651
+ def autosave_queue():
652
+ global global_queue_ref
653
+ if not global_queue_ref:
654
+ print("Autosave: Queue is empty, nothing to save.")
655
+ return
656
+
657
+ print(f"Autosaving queue ({len(global_queue_ref)} items) to {AUTOSAVE_FILENAME}...")
658
+ temp_state_for_save = {"gen": {"queue": global_queue_ref}}
659
+ zip_file_path = None
660
+ try:
661
+
662
+ def _save_queue_to_file(queue_to_save, output_filename):
663
+ if not queue_to_save: return None
664
+ with tempfile.TemporaryDirectory() as tmpdir:
665
+ queue_manifest = []
666
+ image_paths_in_zip = {}
667
+ for task_index, task in enumerate(queue_to_save):
668
+ if task is None or not isinstance(task, dict): continue
669
+ params_copy = task.get('params', {}).copy()
670
+ task_id_s = task.get('id', f"task_{task_index}")
671
+ image_keys = ["image_start", "image_end", "image_refs"]
672
+ for key in image_keys:
673
+ images_pil = params_copy.get(key)
674
+ if images_pil is None: continue
675
+ is_list = isinstance(images_pil, list)
676
+ if not is_list: images_pil = [images_pil]
677
+ image_filenames_for_json = []
678
+ for img_index, pil_image in enumerate(images_pil):
679
+ if not isinstance(pil_image, Image.Image): continue
680
+ img_id = id(pil_image)
681
+ if img_id in image_paths_in_zip:
682
+ image_filenames_for_json.append(image_paths_in_zip[img_id])
683
+ continue
684
+ img_filename_in_zip = f"task{task_id_s}_{key}_{img_index}.png"
685
+ img_save_path = os.path.join(tmpdir, img_filename_in_zip)
686
+ try:
687
+ pil_image.save(img_save_path, "PNG")
688
+ image_filenames_for_json.append(img_filename_in_zip)
689
+ image_paths_in_zip[img_id] = img_filename_in_zip
690
+ except Exception as e:
691
+ print(f"Autosave error saving image {img_filename_in_zip}: {e}")
692
+ if image_filenames_for_json:
693
+ params_copy[key] = image_filenames_for_json if is_list else image_filenames_for_json[0]
694
+ else:
695
+ params_copy.pop(key, None)
696
+ params_copy.pop('state', None)
697
+ params_copy.pop('start_image_data_base64', None)
698
+ params_copy.pop('end_image_data_base64', None)
699
+ manifest_entry = {
700
+ "id": task.get('id'), "params": params_copy,
701
+ }
702
+ queue_manifest.append(manifest_entry)
703
+ manifest_path = os.path.join(tmpdir, "queue.json")
704
+ with open(manifest_path, 'w', encoding='utf-8') as f: json.dump(queue_manifest, f, indent=4)
705
+ with zipfile.ZipFile(output_filename, 'w', zipfile.ZIP_DEFLATED) as zf:
706
+ zf.write(manifest_path, arcname="queue.json")
707
+ for saved_img_rel_path in image_paths_in_zip.values():
708
+ saved_img_abs_path = os.path.join(tmpdir, saved_img_rel_path)
709
+ if os.path.exists(saved_img_abs_path):
710
+ zf.write(saved_img_abs_path, arcname=saved_img_rel_path)
711
+ return output_filename
712
+ return None # Should not happen if queue has items
713
+
714
+ saved_path = _save_queue_to_file(global_queue_ref, AUTOSAVE_FILENAME)
715
+
716
+ if saved_path:
717
+ print(f"Queue autosaved successfully to {saved_path}")
718
+ else:
719
+ print("Autosave failed.")
720
+ except Exception as e:
721
+ print(f"Error during autosave: {e}")
722
+ traceback.print_exc()
723
+
724
+
725
+ def autoload_queue(state):
726
+ global task_id
727
+ # Initial check using the original state
728
+ try:
729
+ gen = get_gen_info(state) # Make sure initial state is a dict
730
+ original_queue = gen.get("queue", [])
731
+ except AttributeError:
732
+ print("[autoload_queue] Error: Initial state is not a dictionary. Cannot autoload.")
733
+ # Return default values indicating no load occurred and the state is unchanged
734
+ return gr.update(visible=False), False, state # Return an empty DF update
735
+
736
+ loaded_flag = False
737
+ dataframe_update = update_queue_data(original_queue) # Default update is the original queue
738
+
739
+ if not original_queue and Path(AUTOSAVE_FILENAME).is_file():
740
+ print(f"Autoloading queue from {AUTOSAVE_FILENAME}...")
741
+ class MockFile:
742
+ def __init__(self, name):
743
+ self.name = name
744
+ mock_filepath = MockFile(AUTOSAVE_FILENAME)
745
+
746
+ # Call load_queue_action, it modifies 'state' internally and returns a DataFrame update
747
+ dataframe_update = load_queue_action(mock_filepath, state)
748
+
749
+ # Now check the 'state' dictionary which should have been modified by load_queue_action
750
+ gen = get_gen_info(state) # Use the (potentially) modified state dictionary
751
+ loaded_queue_after_action = gen.get("queue", [])
752
+
753
+ if loaded_queue_after_action: # Check if the queue in the state is now populated
754
+ print(f"Autoload successful. Loaded {len(loaded_queue_after_action)} tasks into state.")
755
+ loaded_flag = True
756
+ # Global ref update was already done inside load_queue_action if successful
757
+ else:
758
+ print("Autoload attempted but queue in state remains empty (file might be empty or invalid).")
759
+ # Ensure state reflects empty queue if load failed but file existed
760
+ with lock:
761
+ gen["queue"] = []
762
+ gen["prompts_max"] = 0
763
+ update_global_queue_ref([])
764
+ dataframe_update = update_queue_data([]) # Ensure UI shows empty queue
765
+
766
+ else: # Handle cases where autoload shouldn't happen
767
+ if original_queue:
768
+ print("Autoload skipped: Queue is not empty.")
769
+ update_global_queue_ref(original_queue) # Ensure global ref matches current state
770
+ dataframe_update = update_queue_data(original_queue) # UI should show current queue
771
+ else:
772
+ print(f"Autoload skipped: {AUTOSAVE_FILENAME} not found.")
773
+ update_global_queue_ref([]) # Ensure global ref is empty
774
+ dataframe_update = update_queue_data([]) # UI should show empty queue
775
+
776
+ # Return the DataFrame update needed for the UI, the flag, and the final state dictionary
777
+ return dataframe_update, loaded_flag, state
778
 
779
 
780
  def get_queue_table(queue):
 
838
  ])
839
  return data
840
  def update_queue_data(queue):
841
+ update_global_queue_ref(queue)
842
  data = get_queue_table(queue)
843
 
844
  # if len(data) == 0:
 
2441
  yield status
2442
 
2443
  queue[:] = [item for item in queue if item['id'] != task['id']]
2444
+ update_global_queue_ref(queue)
2445
 
2446
  gen["prompts_max"] = 0
2447
  gen["prompt"] = ""
 
3165
  wizard_variables = "\n".join(variables)
3166
  for _ in range( PROMPT_VARS_MAX - len(prompt_vars)):
3167
  prompt_vars.append(gr.Textbox(visible= False, min_width=80, show_label= False))
3168
+
3169
  with gr.Column(not advanced_prompt) as prompt_column_wizard:
3170
  wizard_prompt = gr.Textbox(visible = not advanced_prompt, label="Prompts (each new line of prompt will generate a new video, # lines = comments)", value=default_wizard_prompt, lines=3)
3171
  wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False)
 
3351
  queue_df = gr.DataFrame(
3352
  headers=["Qty","Prompt", "Length","Steps","", "", "", "", ""],
3353
  datatype=[ "str","markdown","str", "markdown", "markdown", "markdown", "str", "str", "str"],
3354
+ column_widths= ["5%", None, "7%", "7%", "10%", "10%", "3%", "3%", "3%"],
3355
  interactive=False,
3356
  col_count=(9, "fixed"),
3357
  wrap=True,
 
3360
  visible= False,
3361
  elem_id="queue_df"
3362
  )
3363
+ with gr.Row():
3364
+ queue_zip_base64_output = gr.Text(visible=False)
3365
+ save_queue_btn = gr.DownloadButton("Save Queue", size="sm")
3366
+ load_queue_btn = gr.UploadButton("Load Queue", file_types=[".zip"], size="sm")
3367
+ clear_queue_btn = gr.Button("Clear Queue", size="sm", variant="stop")
3368
+ trigger_zip_download_js = """
3369
+ (base64String) => {
3370
+ if (!base64String) {
3371
+ console.log("No base64 zip data received, skipping download.");
3372
+ return;
3373
+ }
3374
+ try {
3375
+ const byteCharacters = atob(base64String);
3376
+ const byteNumbers = new Array(byteCharacters.length);
3377
+ for (let i = 0; i < byteCharacters.length; i++) {
3378
+ byteNumbers[i] = byteCharacters.charCodeAt(i);
3379
+ }
3380
+ const byteArray = new Uint8Array(byteNumbers);
3381
+ const blob = new Blob([byteArray], { type: 'application/zip' });
3382
+
3383
+ const url = URL.createObjectURL(blob);
3384
+ const a = document.createElement('a');
3385
+ a.style.display = 'none';
3386
+ a.href = url;
3387
+ a.download = 'queue.zip';
3388
+ document.body.appendChild(a);
3389
+ a.click();
3390
+
3391
+ window.URL.revokeObjectURL(url);
3392
+ document.body.removeChild(a);
3393
+ console.log("Zip download triggered.");
3394
+ } catch (e) {
3395
+ console.error("Error processing base64 data or triggering download:", e);
3396
+ }
3397
+ }
3398
+ """
3399
+ save_queue_btn.click(
3400
+ fn=save_queue_action,
3401
+ inputs=[state],
3402
+ outputs=[queue_zip_base64_output]
3403
+ ).then(
3404
+ fn=None,
3405
+ inputs=[queue_zip_base64_output],
3406
+ outputs=None,
3407
+ js=trigger_zip_download_js
3408
+ )
3409
+
3410
+ load_queue_btn.upload(
3411
+ fn=load_queue_action,
3412
+ inputs=[load_queue_btn, state],
3413
+ outputs=[queue_df]
3414
+ ).then(
3415
+ fn=lambda s: gr.update(visible=bool(get_gen_info(s).get("queue",[]))),
3416
+ inputs=[state],
3417
+ outputs=[current_gen_column]
3418
+ )
3419
+
3420
+ clear_queue_btn.click(
3421
+ fn=clear_queue_action,
3422
+ inputs=[state],
3423
+ outputs=[queue_df]
3424
+ ).then(
3425
+ fn=lambda: gr.update(visible=False),
3426
+ inputs=None,
3427
+ outputs=[current_gen_column]
3428
+ )
3429
 
3430
  extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column,
3431
  prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, advanced_row] # show_advanced presets_column,
 
3529
  outputs=[modal_container]
3530
  )
3531
 
3532
+ return (
3533
+ loras_choices, lset_name, state, queue_df, current_gen_column,
3534
+ gen_status, output, abort_btn, generate_btn, add_to_queue_btn,
3535
+ gen_info,
3536
+ prompt, wizard_prompt, wizard_prompt_activated_var, wizard_variables_var,
3537
+ prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars,
3538
+ advanced_row, image_prompt_column, video_prompt_column,
3539
+ *prompt_vars
3540
+ )
3541
+
3542
 
3543
  def generate_download_tab(lset_name,loras_choices, state):
3544
  with gr.Row():
 
4003
  with gr.Row():
4004
  header = gr.Markdown(generate_header(transformer_filename, compile, attention_mode), visible= True)
4005
  with gr.Row():
4006
+ (
4007
+ loras_choices, lset_name, state, queue_df, current_gen_column,
4008
+ gen_status, output, abort_btn, generate_btn, add_to_queue_btn,
4009
+ gen_info,
4010
+ prompt, wizard_prompt, wizard_prompt_activated_var, wizard_variables_var,
4011
+ prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars,
4012
+ advanced_row, image_prompt_column, video_prompt_column,
4013
+ *prompt_vars_outputs
4014
+ ) = generate_video_tab(model_choice=model_choice, header=header)
4015
  with gr.Tab("Informations"):
4016
  generate_info_tab()
4017
  if not args.lock_config:
 
4022
  with gr.Tab("About"):
4023
  generate_about_tab()
4024
 
4025
+ should_start_flag = gr.State(False)
4026
+ def run_autoload_and_prepare_ui(current_state):
4027
+ df_update, loaded_flag, modified_state = autoload_queue(current_state)
4028
+ should_start_processing = loaded_flag
4029
+ return df_update, gr.update(visible=loaded_flag), should_start_processing, modified_state
4030
+
4031
+ def start_processing_if_needed(should_start, current_state):
4032
+ if not isinstance(current_state, dict) or 'gen' not in current_state:
4033
+ yield "Error: Invalid state received before processing."
4034
+ return
4035
+ if should_start:
4036
+ yield from process_tasks(current_state)
4037
+ else:
4038
+ yield "Autoload complete. Processing not started."
4039
+
4040
+ def finalize_generation_with_state(current_state):
4041
+ if not isinstance(current_state, dict) or 'gen' not in current_state:
4042
+ return gr.update(), gr.update(interactive=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=""), current_state
4043
+ gallery_update, abort_btn_update, gen_btn_update, add_queue_btn_update, current_gen_col_update, gen_info_update = finalize_generation(current_state)
4044
+ return gallery_update, abort_btn_update, gen_btn_update, add_queue_btn_update, current_gen_col_update, gen_info_update, current_state
4045
+
4046
+ demo.load(
4047
+ fn=run_autoload_and_prepare_ui,
4048
+ inputs=[state],
4049
+ outputs=[queue_df, current_gen_column, should_start_flag, state]
4050
+ ).then(
4051
+ fn=start_processing_if_needed,
4052
+ inputs=[should_start_flag, state],
4053
+ outputs=[gen_status],
4054
+ trigger_mode="once"
4055
+ ).then(
4056
+ fn=finalize_generation_with_state,
4057
+ inputs=[state],
4058
+ outputs=[output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info, state],
4059
+ trigger_mode="always_last"
4060
+ )
4061
+
4062
  return demo
4063
 
4064
  if __name__ == "__main__":
4065
+ atexit.register(autosave_queue)
4066
  # threading.Thread(target=runner, daemon=True).start()
4067
  os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
4068
  server_port = int(args.server_port)