Commit ·
134cb56
1
Parent(s): 12652e0
add queue saving/loading/clearing/autosaving/autoloading, fix empty prompt logic
Browse files
wgp.py
CHANGED
|
@@ -28,6 +28,12 @@ from wan.utils import prompt_parser
|
|
| 28 |
import base64
|
| 29 |
import io
|
| 30 |
from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
PROMPT_VARS_MAX = 10
|
| 32 |
|
| 33 |
target_mmgp_version = "3.3.4"
|
|
@@ -98,10 +104,14 @@ def process_prompt_and_add_tasks(state, model_choice):
|
|
| 98 |
inputs["state"] = state
|
| 99 |
inputs.pop("lset_name")
|
| 100 |
if inputs == None:
|
| 101 |
-
|
|
|
|
| 102 |
prompt = inputs["prompt"]
|
| 103 |
if len(prompt) ==0:
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
| 105 |
prompt, errors = prompt_parser.process_template(prompt)
|
| 106 |
if len(errors) > 0:
|
| 107 |
gr.Info("Error processing prompt template: " + errors)
|
|
@@ -111,7 +121,10 @@ def process_prompt_and_add_tasks(state, model_choice):
|
|
| 111 |
prompts = prompt.replace("\r", "").split("\n")
|
| 112 |
prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")]
|
| 113 |
if len(prompts) ==0:
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
| 115 |
|
| 116 |
resolution = inputs["resolution"]
|
| 117 |
width, height = resolution.split("x")
|
|
@@ -250,9 +263,6 @@ def process_prompt_and_add_tasks(state, model_choice):
|
|
| 250 |
queue= gen.get("queue", [])
|
| 251 |
return update_queue_data(queue)
|
| 252 |
|
| 253 |
-
|
| 254 |
-
|
| 255 |
-
|
| 256 |
def add_video_task(**inputs):
|
| 257 |
global task_id
|
| 258 |
state = inputs["state"]
|
|
@@ -327,6 +337,444 @@ def remove_task(queue, selected_indices):
|
|
| 327 |
del queue[idx]
|
| 328 |
return update_queue_data(queue)
|
| 329 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
|
| 331 |
|
| 332 |
def get_queue_table(queue):
|
|
@@ -390,7 +838,7 @@ def get_queue_table(queue):
|
|
| 390 |
])
|
| 391 |
return data
|
| 392 |
def update_queue_data(queue):
|
| 393 |
-
|
| 394 |
data = get_queue_table(queue)
|
| 395 |
|
| 396 |
# if len(data) == 0:
|
|
@@ -1993,6 +2441,7 @@ def process_tasks(state, progress=gr.Progress()):
|
|
| 1993 |
yield status
|
| 1994 |
|
| 1995 |
queue[:] = [item for item in queue if item['id'] != task['id']]
|
|
|
|
| 1996 |
|
| 1997 |
gen["prompts_max"] = 0
|
| 1998 |
gen["prompt"] = ""
|
|
@@ -2716,7 +3165,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
|
| 2716 |
wizard_variables = "\n".join(variables)
|
| 2717 |
for _ in range( PROMPT_VARS_MAX - len(prompt_vars)):
|
| 2718 |
prompt_vars.append(gr.Textbox(visible= False, min_width=80, show_label= False))
|
| 2719 |
-
|
| 2720 |
with gr.Column(not advanced_prompt) as prompt_column_wizard:
|
| 2721 |
wizard_prompt = gr.Textbox(visible = not advanced_prompt, label="Prompts (each new line of prompt will generate a new video, # lines = comments)", value=default_wizard_prompt, lines=3)
|
| 2722 |
wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False)
|
|
@@ -2902,7 +3351,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
|
| 2902 |
queue_df = gr.DataFrame(
|
| 2903 |
headers=["Qty","Prompt", "Length","Steps","", "", "", "", ""],
|
| 2904 |
datatype=[ "str","markdown","str", "markdown", "markdown", "markdown", "str", "str", "str"],
|
| 2905 |
-
column_widths= ["
|
| 2906 |
interactive=False,
|
| 2907 |
col_count=(9, "fixed"),
|
| 2908 |
wrap=True,
|
|
@@ -2911,6 +3360,72 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
|
| 2911 |
visible= False,
|
| 2912 |
elem_id="queue_df"
|
| 2913 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2914 |
|
| 2915 |
extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column,
|
| 2916 |
prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, advanced_row] # show_advanced presets_column,
|
|
@@ -3014,7 +3529,16 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
|
| 3014 |
outputs=[modal_container]
|
| 3015 |
)
|
| 3016 |
|
| 3017 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3018 |
|
| 3019 |
def generate_download_tab(lset_name,loras_choices, state):
|
| 3020 |
with gr.Row():
|
|
@@ -3479,8 +4003,15 @@ def create_demo():
|
|
| 3479 |
with gr.Row():
|
| 3480 |
header = gr.Markdown(generate_header(transformer_filename, compile, attention_mode), visible= True)
|
| 3481 |
with gr.Row():
|
| 3482 |
-
|
| 3483 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3484 |
with gr.Tab("Informations"):
|
| 3485 |
generate_info_tab()
|
| 3486 |
if not args.lock_config:
|
|
@@ -3491,9 +4022,47 @@ def create_demo():
|
|
| 3491 |
with gr.Tab("About"):
|
| 3492 |
generate_about_tab()
|
| 3493 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3494 |
return demo
|
| 3495 |
|
| 3496 |
if __name__ == "__main__":
|
|
|
|
| 3497 |
# threading.Thread(target=runner, daemon=True).start()
|
| 3498 |
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
| 3499 |
server_port = int(args.server_port)
|
|
|
|
| 28 |
import base64
|
| 29 |
import io
|
| 30 |
from PIL import Image
|
| 31 |
+
import zipfile
|
| 32 |
+
import tempfile
|
| 33 |
+
import shutil
|
| 34 |
+
import atexit
|
| 35 |
+
global_queue_ref = []
|
| 36 |
+
AUTOSAVE_FILENAME = "queue.zip"
|
| 37 |
PROMPT_VARS_MAX = 10
|
| 38 |
|
| 39 |
target_mmgp_version = "3.3.4"
|
|
|
|
| 104 |
inputs["state"] = state
|
| 105 |
inputs.pop("lset_name")
|
| 106 |
if inputs == None:
|
| 107 |
+
gr.Warning("Internal state error: Could not retrieve inputs for the model.")
|
| 108 |
+
return update_queue_data(queue)
|
| 109 |
prompt = inputs["prompt"]
|
| 110 |
if len(prompt) ==0:
|
| 111 |
+
gr.Info("Prompt cannot be empty.")
|
| 112 |
+
gen = get_gen_info(state)
|
| 113 |
+
queue = gen.get("queue", [])
|
| 114 |
+
return get_queue_table(queue)
|
| 115 |
prompt, errors = prompt_parser.process_template(prompt)
|
| 116 |
if len(errors) > 0:
|
| 117 |
gr.Info("Error processing prompt template: " + errors)
|
|
|
|
| 121 |
prompts = prompt.replace("\r", "").split("\n")
|
| 122 |
prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")]
|
| 123 |
if len(prompts) ==0:
|
| 124 |
+
gr.Info("Prompt cannot be empty.")
|
| 125 |
+
gen = get_gen_info(state)
|
| 126 |
+
queue = gen.get("queue", [])
|
| 127 |
+
return get_queue_table(queue)
|
| 128 |
|
| 129 |
resolution = inputs["resolution"]
|
| 130 |
width, height = resolution.split("x")
|
|
|
|
| 263 |
queue= gen.get("queue", [])
|
| 264 |
return update_queue_data(queue)
|
| 265 |
|
|
|
|
|
|
|
|
|
|
| 266 |
def add_video_task(**inputs):
|
| 267 |
global task_id
|
| 268 |
state = inputs["state"]
|
|
|
|
| 337 |
del queue[idx]
|
| 338 |
return update_queue_data(queue)
|
| 339 |
|
| 340 |
+
def update_global_queue_ref(queue):
|
| 341 |
+
global global_queue_ref
|
| 342 |
+
with lock:
|
| 343 |
+
global_queue_ref = queue[:]
|
| 344 |
+
|
| 345 |
+
def save_queue_action(state):
|
| 346 |
+
gen = get_gen_info(state)
|
| 347 |
+
queue = gen.get("queue", [])
|
| 348 |
+
|
| 349 |
+
if not queue or len(queue) <=1 : # Check if queue is empty or only has the placeholder
|
| 350 |
+
gr.Info("Queue is empty. Nothing to save.")
|
| 351 |
+
return None # Return None if nothing to save
|
| 352 |
+
|
| 353 |
+
# Use an in-memory buffer for the zip file
|
| 354 |
+
zip_buffer = io.BytesIO()
|
| 355 |
+
|
| 356 |
+
# Still use a temporary directory *only* for storing images before zipping
|
| 357 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 358 |
+
queue_manifest = []
|
| 359 |
+
image_paths_in_zip = {} # Tracks image PIL object ID -> filename in zip
|
| 360 |
+
|
| 361 |
+
for task_index, task in enumerate(queue):
|
| 362 |
+
# Skip the placeholder item if it exists
|
| 363 |
+
if task is None or not isinstance(task, dict) or task_index == 0: continue
|
| 364 |
+
|
| 365 |
+
params_copy = task.get('params', {}).copy()
|
| 366 |
+
task_id_s = task.get('id', f"task_{task_index}") # Use a different var name
|
| 367 |
+
|
| 368 |
+
image_keys = ["image_start", "image_end", "image_refs"]
|
| 369 |
+
for key in image_keys:
|
| 370 |
+
images_pil = params_copy.get(key)
|
| 371 |
+
if images_pil is None:
|
| 372 |
+
continue
|
| 373 |
+
|
| 374 |
+
# Ensure images_pil is always a list for processing
|
| 375 |
+
is_originally_list = isinstance(images_pil, list)
|
| 376 |
+
if not is_originally_list:
|
| 377 |
+
images_pil = [images_pil]
|
| 378 |
+
|
| 379 |
+
image_filenames_for_json = []
|
| 380 |
+
for img_index, pil_image in enumerate(images_pil):
|
| 381 |
+
# Ensure it's actually a PIL Image object before proceeding
|
| 382 |
+
if not isinstance(pil_image, Image.Image):
|
| 383 |
+
print(f"Warning: Expected PIL Image for key '{key}' in task {task_id_s}, got {type(pil_image)}. Skipping image.")
|
| 384 |
+
continue
|
| 385 |
+
|
| 386 |
+
# Use object ID to check if this specific image instance is already saved
|
| 387 |
+
img_id = id(pil_image)
|
| 388 |
+
if img_id in image_paths_in_zip:
|
| 389 |
+
# If already saved, just add its filename to the list
|
| 390 |
+
image_filenames_for_json.append(image_paths_in_zip[img_id])
|
| 391 |
+
continue # Move to the next image in the list
|
| 392 |
+
|
| 393 |
+
# Image not saved yet, create filename and save path
|
| 394 |
+
img_filename_in_zip = f"task{task_id_s}_{key}_{img_index}.png"
|
| 395 |
+
img_save_path = os.path.join(tmpdir, img_filename_in_zip)
|
| 396 |
+
|
| 397 |
+
try:
|
| 398 |
+
# Save the image to the temporary directory
|
| 399 |
+
pil_image.save(img_save_path, "PNG")
|
| 400 |
+
image_filenames_for_json.append(img_filename_in_zip)
|
| 401 |
+
# Store the mapping from image ID to its filename in the zip
|
| 402 |
+
image_paths_in_zip[img_id] = img_filename_in_zip
|
| 403 |
+
except Exception as e:
|
| 404 |
+
print(f"Error saving image {img_filename_in_zip} for task {task_id_s}: {e}")
|
| 405 |
+
# Optionally decide if you want to continue or fail here
|
| 406 |
+
|
| 407 |
+
# Update the params_copy with the list of filenames (or single filename)
|
| 408 |
+
if image_filenames_for_json:
|
| 409 |
+
params_copy[key] = image_filenames_for_json if is_originally_list else image_filenames_for_json[0]
|
| 410 |
+
else:
|
| 411 |
+
# If no images were successfully processed for this key, remove it
|
| 412 |
+
params_copy.pop(key, None)
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
# Clean up parameters before adding to manifest
|
| 416 |
+
params_copy.pop('state', None)
|
| 417 |
+
params_copy.pop('start_image_data_base64', None) # Don't need base64 in saved queue
|
| 418 |
+
params_copy.pop('end_image_data_base64', None)
|
| 419 |
+
# Also remove the actual PIL data if it somehow remained
|
| 420 |
+
params_copy.pop('start_image_data', None)
|
| 421 |
+
params_copy.pop('end_image_data', None)
|
| 422 |
+
|
| 423 |
+
manifest_entry = {
|
| 424 |
+
"id": task.get('id'),
|
| 425 |
+
"params": params_copy,
|
| 426 |
+
# Keep other necessary top-level task info if needed, like repeats etc.
|
| 427 |
+
# Example: "repeats": task.get('repeats', 1)
|
| 428 |
+
}
|
| 429 |
+
queue_manifest.append(manifest_entry)
|
| 430 |
+
|
| 431 |
+
# --- Create queue.json content ---
|
| 432 |
+
manifest_path = os.path.join(tmpdir, "queue.json")
|
| 433 |
+
try:
|
| 434 |
+
with open(manifest_path, 'w', encoding='utf-8') as f:
|
| 435 |
+
# Dump only the relevant manifest data
|
| 436 |
+
json.dump(queue_manifest, f, indent=4)
|
| 437 |
+
except Exception as e:
|
| 438 |
+
print(f"Error writing queue.json: {e}")
|
| 439 |
+
gr.Warning("Failed to create queue manifest.")
|
| 440 |
+
return None # Return None on failure
|
| 441 |
+
|
| 442 |
+
# --- Create the zip file in memory ---
|
| 443 |
+
try:
|
| 444 |
+
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zf:
|
| 445 |
+
# Add queue.json
|
| 446 |
+
zf.write(manifest_path, arcname="queue.json")
|
| 447 |
+
|
| 448 |
+
# Add all unique images that were saved to the temp dir
|
| 449 |
+
for saved_img_rel_path in image_paths_in_zip.values():
|
| 450 |
+
saved_img_abs_path = os.path.join(tmpdir, saved_img_rel_path)
|
| 451 |
+
if os.path.exists(saved_img_abs_path):
|
| 452 |
+
zf.write(saved_img_abs_path, arcname=saved_img_rel_path)
|
| 453 |
+
else:
|
| 454 |
+
# This shouldn't happen if saving was successful, but good to check
|
| 455 |
+
print(f"Warning: Image file {saved_img_rel_path} not found during zipping.")
|
| 456 |
+
|
| 457 |
+
# --- Prepare for return ---
|
| 458 |
+
# Move buffer position to the beginning
|
| 459 |
+
zip_buffer.seek(0)
|
| 460 |
+
# Read the binary content
|
| 461 |
+
zip_binary_content = zip_buffer.getvalue()
|
| 462 |
+
# Encode as base64 string
|
| 463 |
+
zip_base64 = base64.b64encode(zip_binary_content).decode('utf-8')
|
| 464 |
+
print(f"Queue successfully prepared as base64 string ({len(zip_base64)} chars).")
|
| 465 |
+
return zip_base64
|
| 466 |
+
|
| 467 |
+
except Exception as e:
|
| 468 |
+
print(f"Error creating zip file in memory: {e}")
|
| 469 |
+
gr.Warning("Failed to create zip data for download.")
|
| 470 |
+
return None # Return None on failure
|
| 471 |
+
finally:
|
| 472 |
+
zip_buffer.close()
|
| 473 |
+
|
| 474 |
+
def load_queue_action(filepath, state):
|
| 475 |
+
global task_id
|
| 476 |
+
gen = get_gen_info(state)
|
| 477 |
+
original_queue = gen.get("queue", []) # Store original queue for error case
|
| 478 |
+
|
| 479 |
+
if not filepath or not hasattr(filepath, 'name') or not Path(filepath.name).is_file():
|
| 480 |
+
print("[load_queue_action] Warning: No valid file selected or file not found.")
|
| 481 |
+
# Return the current state of the DataFrame
|
| 482 |
+
return update_queue_data(original_queue)
|
| 483 |
+
|
| 484 |
+
newly_loaded_queue = []
|
| 485 |
+
max_id_in_file = 0
|
| 486 |
+
error_message = ""
|
| 487 |
+
local_queue_copy_for_global_ref = None
|
| 488 |
+
|
| 489 |
+
try:
|
| 490 |
+
print(f"[load_queue_action] Attempting to load queue from: {filepath.name}")
|
| 491 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 492 |
+
with zipfile.ZipFile(filepath.name, 'r') as zf:
|
| 493 |
+
if "queue.json" not in zf.namelist(): raise ValueError("queue.json not found in zip file")
|
| 494 |
+
print(f"[load_queue_action] Extracting {filepath.name} to {tmpdir}")
|
| 495 |
+
zf.extractall(tmpdir)
|
| 496 |
+
print(f"[load_queue_action] Extraction complete.")
|
| 497 |
+
|
| 498 |
+
manifest_path = os.path.join(tmpdir, "queue.json")
|
| 499 |
+
print(f"[load_queue_action] Reading manifest: {manifest_path}")
|
| 500 |
+
with open(manifest_path, 'r', encoding='utf-8') as f:
|
| 501 |
+
loaded_manifest = json.load(f)
|
| 502 |
+
print(f"[load_queue_action] Manifest loaded. Processing {len(loaded_manifest)} tasks.")
|
| 503 |
+
|
| 504 |
+
for task_index, task_data in enumerate(loaded_manifest):
|
| 505 |
+
# (Keep the existing task processing logic here...)
|
| 506 |
+
if task_data is None or not isinstance(task_data, dict):
|
| 507 |
+
print(f"[load_queue_action] Skipping invalid task data at index {task_index}")
|
| 508 |
+
continue
|
| 509 |
+
|
| 510 |
+
params = task_data.get('params', {})
|
| 511 |
+
task_id_loaded = task_data.get('id', 0)
|
| 512 |
+
max_id_in_file = max(max_id_in_file, task_id_loaded)
|
| 513 |
+
loaded_pil_images = {}
|
| 514 |
+
image_keys = ["image_start", "image_end", "image_refs"]
|
| 515 |
+
params['state'] = state # Add state back temporarily for consistency if needed by internal logic, but it's removed before saving
|
| 516 |
+
|
| 517 |
+
for key in image_keys:
|
| 518 |
+
image_filenames = params.get(key)
|
| 519 |
+
if image_filenames is None: continue
|
| 520 |
+
is_list = isinstance(image_filenames, list)
|
| 521 |
+
if not is_list: image_filenames = [image_filenames]
|
| 522 |
+
loaded_pils = []
|
| 523 |
+
for img_filename_in_zip in image_filenames:
|
| 524 |
+
if not isinstance(img_filename_in_zip, str): continue
|
| 525 |
+
img_load_path = os.path.join(tmpdir, img_filename_in_zip)
|
| 526 |
+
if not os.path.exists(img_load_path):
|
| 527 |
+
print(f"[load_queue_action] Image file not found during load: {img_load_path}")
|
| 528 |
+
continue
|
| 529 |
+
try:
|
| 530 |
+
pil_image = Image.open(img_load_path)
|
| 531 |
+
# Ensure the image data is loaded into memory before the temp dir is cleaned up
|
| 532 |
+
pil_image.load()
|
| 533 |
+
# Convert image right after loading
|
| 534 |
+
converted_image = convert_image(pil_image)
|
| 535 |
+
loaded_pils.append(converted_image)
|
| 536 |
+
pil_image.close() # Close the file handle
|
| 537 |
+
except Exception as img_e:
|
| 538 |
+
print(f"[load_queue_action] Error loading image {img_filename_in_zip}: {img_e}")
|
| 539 |
+
if loaded_pils:
|
| 540 |
+
params[key] = loaded_pils if is_list else loaded_pils[0]
|
| 541 |
+
loaded_pil_images[key] = params[key] # Store loaded PILs for preview generation
|
| 542 |
+
else: params.pop(key, None)
|
| 543 |
+
|
| 544 |
+
# Generate preview base64 strings
|
| 545 |
+
primary_preview_pil, secondary_preview_pil = None, None
|
| 546 |
+
start_prev_pil_list = loaded_pil_images.get("image_start")
|
| 547 |
+
end_prev_pil_list = loaded_pil_images.get("image_end")
|
| 548 |
+
ref_prev_pil_list = loaded_pil_images.get("image_refs")
|
| 549 |
+
|
| 550 |
+
# Extract first image for preview if available
|
| 551 |
+
if start_prev_pil_list:
|
| 552 |
+
primary_preview_pil = start_prev_pil_list[0] if isinstance(start_prev_pil_list, list) and start_prev_pil_list else start_prev_pil_list if not isinstance(start_prev_pil_list, list) else None
|
| 553 |
+
if end_prev_pil_list:
|
| 554 |
+
secondary_preview_pil = end_prev_pil_list[0] if isinstance(end_prev_pil_list, list) and end_prev_pil_list else end_prev_pil_list if not isinstance(end_prev_pil_list, list) else None
|
| 555 |
+
elif ref_prev_pil_list and isinstance(ref_prev_pil_list, list) and ref_prev_pil_list:
|
| 556 |
+
primary_preview_pil = ref_prev_pil_list[0]
|
| 557 |
+
|
| 558 |
+
# Generate base64 only if PIL image exists
|
| 559 |
+
start_b64 = [pil_to_base64_uri(primary_preview_pil, format="jpeg", quality=70)] if primary_preview_pil else None
|
| 560 |
+
end_b64 = [pil_to_base64_uri(secondary_preview_pil, format="jpeg", quality=70)] if secondary_preview_pil else None
|
| 561 |
+
|
| 562 |
+
# Get top-level image data (PIL objects) for runtime task
|
| 563 |
+
top_level_start_image = loaded_pil_images.get("image_start")
|
| 564 |
+
top_level_end_image = loaded_pil_images.get("image_end")
|
| 565 |
+
|
| 566 |
+
# Construct the runtime task dictionary
|
| 567 |
+
runtime_task = {
|
| 568 |
+
"id": task_id_loaded,
|
| 569 |
+
"params": params.copy(), # Use a copy of params
|
| 570 |
+
# Extract necessary params for top level if they exist
|
| 571 |
+
"repeats": params.get('repeat_generation', 1),
|
| 572 |
+
"length": params.get('video_length'),
|
| 573 |
+
"steps": params.get('num_inference_steps'),
|
| 574 |
+
"prompt": params.get('prompt'),
|
| 575 |
+
# Store the actual loaded PIL image data here
|
| 576 |
+
"start_image_data": top_level_start_image,
|
| 577 |
+
"end_image_data": top_level_end_image,
|
| 578 |
+
# Store base64 previews generated above
|
| 579 |
+
"start_image_data_base64": start_b64,
|
| 580 |
+
"end_image_data_base64": end_b64,
|
| 581 |
+
}
|
| 582 |
+
newly_loaded_queue.append(runtime_task)
|
| 583 |
+
print(f"[load_queue_action] Processed task {task_index+1}/{len(loaded_manifest)}, ID: {task_id_loaded}")
|
| 584 |
+
|
| 585 |
+
# --- State Update ---
|
| 586 |
+
with lock:
|
| 587 |
+
print("[load_queue_action] Acquiring lock to update state...")
|
| 588 |
+
gen["queue"] = newly_loaded_queue[:] # Replace the queue in the state
|
| 589 |
+
local_queue_copy_for_global_ref = gen["queue"][:] # Copy for global ref update
|
| 590 |
+
current_max_id_in_new_queue = max([t['id'] for t in newly_loaded_queue if 'id' in t] + [0]) # Safer max ID calculation
|
| 591 |
+
|
| 592 |
+
# Update global task ID only if the loaded max ID is higher
|
| 593 |
+
if current_max_id_in_new_queue > task_id:
|
| 594 |
+
print(f"[load_queue_action] Updating global task_id from {task_id} to {current_max_id_in_new_queue + 1}")
|
| 595 |
+
task_id = current_max_id_in_new_queue + 1 # Ensure next ID is unique
|
| 596 |
+
else:
|
| 597 |
+
print(f"[load_queue_action] Global task_id ({task_id}) is >= max in file ({current_max_id_in_new_queue}). Not changing task_id.")
|
| 598 |
+
|
| 599 |
+
gen["prompts_max"] = len(newly_loaded_queue)
|
| 600 |
+
print("[load_queue_action] State update complete. Releasing lock.")
|
| 601 |
+
|
| 602 |
+
# --- Global Reference Update ---
|
| 603 |
+
if local_queue_copy_for_global_ref is not None:
|
| 604 |
+
print("[load_queue_action] Updating global queue reference...")
|
| 605 |
+
update_global_queue_ref(local_queue_copy_for_global_ref)
|
| 606 |
+
else:
|
| 607 |
+
# This case should ideally not be reached if state update happens
|
| 608 |
+
print("[load_queue_action] Warning: Skipping global ref update as local copy is None.")
|
| 609 |
+
|
| 610 |
+
print(f"[load_queue_action] Queue load successful. Returning DataFrame update for {len(newly_loaded_queue)} tasks.")
|
| 611 |
+
# *** Return the DataFrame update object ***
|
| 612 |
+
return update_queue_data(newly_loaded_queue)
|
| 613 |
+
|
| 614 |
+
except (ValueError, zipfile.BadZipFile, FileNotFoundError, Exception) as e:
|
| 615 |
+
error_message = f"Error during queue load: {e}"
|
| 616 |
+
print(f"[load_queue_action] Caught error: {error_message}")
|
| 617 |
+
traceback.print_exc()
|
| 618 |
+
# Optionally show a Gradio warning/error to the user
|
| 619 |
+
gr.Warning(f"Failed to load queue: {error_message[:200]}") # Show truncated error
|
| 620 |
+
|
| 621 |
+
# *** Return the DataFrame update for the original queue ***
|
| 622 |
+
print("[load_queue_action] Load failed. Returning DataFrame update for original queue.")
|
| 623 |
+
return update_queue_data(original_queue)
|
| 624 |
+
finally:
|
| 625 |
+
# Clean up the uploaded file object if it exists and has a path
|
| 626 |
+
if filepath and hasattr(filepath, 'name') and filepath.name and os.path.exists(filepath.name):
|
| 627 |
+
try:
|
| 628 |
+
# Gradio often uses temp files, attempting removal is good practice
|
| 629 |
+
# os.remove(filepath.name)
|
| 630 |
+
# print(f"[load_queue_action] Cleaned up temporary upload file: {filepath.name}")
|
| 631 |
+
pass # Let Gradio manage its temp files unless specifically needed
|
| 632 |
+
except OSError as e:
|
| 633 |
+
# Ignore errors like "file not found" if already cleaned up
|
| 634 |
+
print(f"[load_queue_action] Info: Could not remove temp file {filepath.name}: {e}")
|
| 635 |
+
pass
|
| 636 |
+
|
| 637 |
+
def clear_queue_action(state):
|
| 638 |
+
gen = get_gen_info(state)
|
| 639 |
+
queue = gen.get("queue", [])
|
| 640 |
+
if not queue:
|
| 641 |
+
gr.Info("Queue is already empty.")
|
| 642 |
+
return update_queue_data([])
|
| 643 |
+
|
| 644 |
+
with lock:
|
| 645 |
+
queue.clear()
|
| 646 |
+
gen["prompts_max"] = 0
|
| 647 |
+
|
| 648 |
+
gr.Info("Queue cleared.")
|
| 649 |
+
return update_queue_data([])
|
| 650 |
+
|
| 651 |
+
def autosave_queue():
|
| 652 |
+
global global_queue_ref
|
| 653 |
+
if not global_queue_ref:
|
| 654 |
+
print("Autosave: Queue is empty, nothing to save.")
|
| 655 |
+
return
|
| 656 |
+
|
| 657 |
+
print(f"Autosaving queue ({len(global_queue_ref)} items) to {AUTOSAVE_FILENAME}...")
|
| 658 |
+
temp_state_for_save = {"gen": {"queue": global_queue_ref}}
|
| 659 |
+
zip_file_path = None
|
| 660 |
+
try:
|
| 661 |
+
|
| 662 |
+
def _save_queue_to_file(queue_to_save, output_filename):
|
| 663 |
+
if not queue_to_save: return None
|
| 664 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 665 |
+
queue_manifest = []
|
| 666 |
+
image_paths_in_zip = {}
|
| 667 |
+
for task_index, task in enumerate(queue_to_save):
|
| 668 |
+
if task is None or not isinstance(task, dict): continue
|
| 669 |
+
params_copy = task.get('params', {}).copy()
|
| 670 |
+
task_id_s = task.get('id', f"task_{task_index}")
|
| 671 |
+
image_keys = ["image_start", "image_end", "image_refs"]
|
| 672 |
+
for key in image_keys:
|
| 673 |
+
images_pil = params_copy.get(key)
|
| 674 |
+
if images_pil is None: continue
|
| 675 |
+
is_list = isinstance(images_pil, list)
|
| 676 |
+
if not is_list: images_pil = [images_pil]
|
| 677 |
+
image_filenames_for_json = []
|
| 678 |
+
for img_index, pil_image in enumerate(images_pil):
|
| 679 |
+
if not isinstance(pil_image, Image.Image): continue
|
| 680 |
+
img_id = id(pil_image)
|
| 681 |
+
if img_id in image_paths_in_zip:
|
| 682 |
+
image_filenames_for_json.append(image_paths_in_zip[img_id])
|
| 683 |
+
continue
|
| 684 |
+
img_filename_in_zip = f"task{task_id_s}_{key}_{img_index}.png"
|
| 685 |
+
img_save_path = os.path.join(tmpdir, img_filename_in_zip)
|
| 686 |
+
try:
|
| 687 |
+
pil_image.save(img_save_path, "PNG")
|
| 688 |
+
image_filenames_for_json.append(img_filename_in_zip)
|
| 689 |
+
image_paths_in_zip[img_id] = img_filename_in_zip
|
| 690 |
+
except Exception as e:
|
| 691 |
+
print(f"Autosave error saving image {img_filename_in_zip}: {e}")
|
| 692 |
+
if image_filenames_for_json:
|
| 693 |
+
params_copy[key] = image_filenames_for_json if is_list else image_filenames_for_json[0]
|
| 694 |
+
else:
|
| 695 |
+
params_copy.pop(key, None)
|
| 696 |
+
params_copy.pop('state', None)
|
| 697 |
+
params_copy.pop('start_image_data_base64', None)
|
| 698 |
+
params_copy.pop('end_image_data_base64', None)
|
| 699 |
+
manifest_entry = {
|
| 700 |
+
"id": task.get('id'), "params": params_copy,
|
| 701 |
+
}
|
| 702 |
+
queue_manifest.append(manifest_entry)
|
| 703 |
+
manifest_path = os.path.join(tmpdir, "queue.json")
|
| 704 |
+
with open(manifest_path, 'w', encoding='utf-8') as f: json.dump(queue_manifest, f, indent=4)
|
| 705 |
+
with zipfile.ZipFile(output_filename, 'w', zipfile.ZIP_DEFLATED) as zf:
|
| 706 |
+
zf.write(manifest_path, arcname="queue.json")
|
| 707 |
+
for saved_img_rel_path in image_paths_in_zip.values():
|
| 708 |
+
saved_img_abs_path = os.path.join(tmpdir, saved_img_rel_path)
|
| 709 |
+
if os.path.exists(saved_img_abs_path):
|
| 710 |
+
zf.write(saved_img_abs_path, arcname=saved_img_rel_path)
|
| 711 |
+
return output_filename
|
| 712 |
+
return None # Should not happen if queue has items
|
| 713 |
+
|
| 714 |
+
saved_path = _save_queue_to_file(global_queue_ref, AUTOSAVE_FILENAME)
|
| 715 |
+
|
| 716 |
+
if saved_path:
|
| 717 |
+
print(f"Queue autosaved successfully to {saved_path}")
|
| 718 |
+
else:
|
| 719 |
+
print("Autosave failed.")
|
| 720 |
+
except Exception as e:
|
| 721 |
+
print(f"Error during autosave: {e}")
|
| 722 |
+
traceback.print_exc()
|
| 723 |
+
|
| 724 |
+
|
| 725 |
+
def autoload_queue(state):
|
| 726 |
+
global task_id
|
| 727 |
+
# Initial check using the original state
|
| 728 |
+
try:
|
| 729 |
+
gen = get_gen_info(state) # Make sure initial state is a dict
|
| 730 |
+
original_queue = gen.get("queue", [])
|
| 731 |
+
except AttributeError:
|
| 732 |
+
print("[autoload_queue] Error: Initial state is not a dictionary. Cannot autoload.")
|
| 733 |
+
# Return default values indicating no load occurred and the state is unchanged
|
| 734 |
+
return gr.update(visible=False), False, state # Return an empty DF update
|
| 735 |
+
|
| 736 |
+
loaded_flag = False
|
| 737 |
+
dataframe_update = update_queue_data(original_queue) # Default update is the original queue
|
| 738 |
+
|
| 739 |
+
if not original_queue and Path(AUTOSAVE_FILENAME).is_file():
|
| 740 |
+
print(f"Autoloading queue from {AUTOSAVE_FILENAME}...")
|
| 741 |
+
class MockFile:
|
| 742 |
+
def __init__(self, name):
|
| 743 |
+
self.name = name
|
| 744 |
+
mock_filepath = MockFile(AUTOSAVE_FILENAME)
|
| 745 |
+
|
| 746 |
+
# Call load_queue_action, it modifies 'state' internally and returns a DataFrame update
|
| 747 |
+
dataframe_update = load_queue_action(mock_filepath, state)
|
| 748 |
+
|
| 749 |
+
# Now check the 'state' dictionary which should have been modified by load_queue_action
|
| 750 |
+
gen = get_gen_info(state) # Use the (potentially) modified state dictionary
|
| 751 |
+
loaded_queue_after_action = gen.get("queue", [])
|
| 752 |
+
|
| 753 |
+
if loaded_queue_after_action: # Check if the queue in the state is now populated
|
| 754 |
+
print(f"Autoload successful. Loaded {len(loaded_queue_after_action)} tasks into state.")
|
| 755 |
+
loaded_flag = True
|
| 756 |
+
# Global ref update was already done inside load_queue_action if successful
|
| 757 |
+
else:
|
| 758 |
+
print("Autoload attempted but queue in state remains empty (file might be empty or invalid).")
|
| 759 |
+
# Ensure state reflects empty queue if load failed but file existed
|
| 760 |
+
with lock:
|
| 761 |
+
gen["queue"] = []
|
| 762 |
+
gen["prompts_max"] = 0
|
| 763 |
+
update_global_queue_ref([])
|
| 764 |
+
dataframe_update = update_queue_data([]) # Ensure UI shows empty queue
|
| 765 |
+
|
| 766 |
+
else: # Handle cases where autoload shouldn't happen
|
| 767 |
+
if original_queue:
|
| 768 |
+
print("Autoload skipped: Queue is not empty.")
|
| 769 |
+
update_global_queue_ref(original_queue) # Ensure global ref matches current state
|
| 770 |
+
dataframe_update = update_queue_data(original_queue) # UI should show current queue
|
| 771 |
+
else:
|
| 772 |
+
print(f"Autoload skipped: {AUTOSAVE_FILENAME} not found.")
|
| 773 |
+
update_global_queue_ref([]) # Ensure global ref is empty
|
| 774 |
+
dataframe_update = update_queue_data([]) # UI should show empty queue
|
| 775 |
+
|
| 776 |
+
# Return the DataFrame update needed for the UI, the flag, and the final state dictionary
|
| 777 |
+
return dataframe_update, loaded_flag, state
|
| 778 |
|
| 779 |
|
| 780 |
def get_queue_table(queue):
|
|
|
|
| 838 |
])
|
| 839 |
return data
|
| 840 |
def update_queue_data(queue):
|
| 841 |
+
update_global_queue_ref(queue)
|
| 842 |
data = get_queue_table(queue)
|
| 843 |
|
| 844 |
# if len(data) == 0:
|
|
|
|
| 2441 |
yield status
|
| 2442 |
|
| 2443 |
queue[:] = [item for item in queue if item['id'] != task['id']]
|
| 2444 |
+
update_global_queue_ref(queue)
|
| 2445 |
|
| 2446 |
gen["prompts_max"] = 0
|
| 2447 |
gen["prompt"] = ""
|
|
|
|
| 3165 |
wizard_variables = "\n".join(variables)
|
| 3166 |
for _ in range( PROMPT_VARS_MAX - len(prompt_vars)):
|
| 3167 |
prompt_vars.append(gr.Textbox(visible= False, min_width=80, show_label= False))
|
| 3168 |
+
|
| 3169 |
with gr.Column(not advanced_prompt) as prompt_column_wizard:
|
| 3170 |
wizard_prompt = gr.Textbox(visible = not advanced_prompt, label="Prompts (each new line of prompt will generate a new video, # lines = comments)", value=default_wizard_prompt, lines=3)
|
| 3171 |
wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False)
|
|
|
|
| 3351 |
queue_df = gr.DataFrame(
|
| 3352 |
headers=["Qty","Prompt", "Length","Steps","", "", "", "", ""],
|
| 3353 |
datatype=[ "str","markdown","str", "markdown", "markdown", "markdown", "str", "str", "str"],
|
| 3354 |
+
column_widths= ["5%", None, "7%", "7%", "10%", "10%", "3%", "3%", "3%"],
|
| 3355 |
interactive=False,
|
| 3356 |
col_count=(9, "fixed"),
|
| 3357 |
wrap=True,
|
|
|
|
| 3360 |
visible= False,
|
| 3361 |
elem_id="queue_df"
|
| 3362 |
)
|
| 3363 |
+
with gr.Row():
|
| 3364 |
+
queue_zip_base64_output = gr.Text(visible=False)
|
| 3365 |
+
save_queue_btn = gr.DownloadButton("Save Queue", size="sm")
|
| 3366 |
+
load_queue_btn = gr.UploadButton("Load Queue", file_types=[".zip"], size="sm")
|
| 3367 |
+
clear_queue_btn = gr.Button("Clear Queue", size="sm", variant="stop")
|
| 3368 |
+
trigger_zip_download_js = """
|
| 3369 |
+
(base64String) => {
|
| 3370 |
+
if (!base64String) {
|
| 3371 |
+
console.log("No base64 zip data received, skipping download.");
|
| 3372 |
+
return;
|
| 3373 |
+
}
|
| 3374 |
+
try {
|
| 3375 |
+
const byteCharacters = atob(base64String);
|
| 3376 |
+
const byteNumbers = new Array(byteCharacters.length);
|
| 3377 |
+
for (let i = 0; i < byteCharacters.length; i++) {
|
| 3378 |
+
byteNumbers[i] = byteCharacters.charCodeAt(i);
|
| 3379 |
+
}
|
| 3380 |
+
const byteArray = new Uint8Array(byteNumbers);
|
| 3381 |
+
const blob = new Blob([byteArray], { type: 'application/zip' });
|
| 3382 |
+
|
| 3383 |
+
const url = URL.createObjectURL(blob);
|
| 3384 |
+
const a = document.createElement('a');
|
| 3385 |
+
a.style.display = 'none';
|
| 3386 |
+
a.href = url;
|
| 3387 |
+
a.download = 'queue.zip';
|
| 3388 |
+
document.body.appendChild(a);
|
| 3389 |
+
a.click();
|
| 3390 |
+
|
| 3391 |
+
window.URL.revokeObjectURL(url);
|
| 3392 |
+
document.body.removeChild(a);
|
| 3393 |
+
console.log("Zip download triggered.");
|
| 3394 |
+
} catch (e) {
|
| 3395 |
+
console.error("Error processing base64 data or triggering download:", e);
|
| 3396 |
+
}
|
| 3397 |
+
}
|
| 3398 |
+
"""
|
| 3399 |
+
save_queue_btn.click(
|
| 3400 |
+
fn=save_queue_action,
|
| 3401 |
+
inputs=[state],
|
| 3402 |
+
outputs=[queue_zip_base64_output]
|
| 3403 |
+
).then(
|
| 3404 |
+
fn=None,
|
| 3405 |
+
inputs=[queue_zip_base64_output],
|
| 3406 |
+
outputs=None,
|
| 3407 |
+
js=trigger_zip_download_js
|
| 3408 |
+
)
|
| 3409 |
+
|
| 3410 |
+
load_queue_btn.upload(
|
| 3411 |
+
fn=load_queue_action,
|
| 3412 |
+
inputs=[load_queue_btn, state],
|
| 3413 |
+
outputs=[queue_df]
|
| 3414 |
+
).then(
|
| 3415 |
+
fn=lambda s: gr.update(visible=bool(get_gen_info(s).get("queue",[]))),
|
| 3416 |
+
inputs=[state],
|
| 3417 |
+
outputs=[current_gen_column]
|
| 3418 |
+
)
|
| 3419 |
+
|
| 3420 |
+
clear_queue_btn.click(
|
| 3421 |
+
fn=clear_queue_action,
|
| 3422 |
+
inputs=[state],
|
| 3423 |
+
outputs=[queue_df]
|
| 3424 |
+
).then(
|
| 3425 |
+
fn=lambda: gr.update(visible=False),
|
| 3426 |
+
inputs=None,
|
| 3427 |
+
outputs=[current_gen_column]
|
| 3428 |
+
)
|
| 3429 |
|
| 3430 |
extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column,
|
| 3431 |
prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, advanced_row] # show_advanced presets_column,
|
|
|
|
| 3529 |
outputs=[modal_container]
|
| 3530 |
)
|
| 3531 |
|
| 3532 |
+
return (
|
| 3533 |
+
loras_choices, lset_name, state, queue_df, current_gen_column,
|
| 3534 |
+
gen_status, output, abort_btn, generate_btn, add_to_queue_btn,
|
| 3535 |
+
gen_info,
|
| 3536 |
+
prompt, wizard_prompt, wizard_prompt_activated_var, wizard_variables_var,
|
| 3537 |
+
prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars,
|
| 3538 |
+
advanced_row, image_prompt_column, video_prompt_column,
|
| 3539 |
+
*prompt_vars
|
| 3540 |
+
)
|
| 3541 |
+
|
| 3542 |
|
| 3543 |
def generate_download_tab(lset_name,loras_choices, state):
|
| 3544 |
with gr.Row():
|
|
|
|
| 4003 |
with gr.Row():
|
| 4004 |
header = gr.Markdown(generate_header(transformer_filename, compile, attention_mode), visible= True)
|
| 4005 |
with gr.Row():
|
| 4006 |
+
(
|
| 4007 |
+
loras_choices, lset_name, state, queue_df, current_gen_column,
|
| 4008 |
+
gen_status, output, abort_btn, generate_btn, add_to_queue_btn,
|
| 4009 |
+
gen_info,
|
| 4010 |
+
prompt, wizard_prompt, wizard_prompt_activated_var, wizard_variables_var,
|
| 4011 |
+
prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars,
|
| 4012 |
+
advanced_row, image_prompt_column, video_prompt_column,
|
| 4013 |
+
*prompt_vars_outputs
|
| 4014 |
+
) = generate_video_tab(model_choice=model_choice, header=header)
|
| 4015 |
with gr.Tab("Informations"):
|
| 4016 |
generate_info_tab()
|
| 4017 |
if not args.lock_config:
|
|
|
|
| 4022 |
with gr.Tab("About"):
|
| 4023 |
generate_about_tab()
|
| 4024 |
|
| 4025 |
+
should_start_flag = gr.State(False)
|
| 4026 |
+
def run_autoload_and_prepare_ui(current_state):
|
| 4027 |
+
df_update, loaded_flag, modified_state = autoload_queue(current_state)
|
| 4028 |
+
should_start_processing = loaded_flag
|
| 4029 |
+
return df_update, gr.update(visible=loaded_flag), should_start_processing, modified_state
|
| 4030 |
+
|
| 4031 |
+
def start_processing_if_needed(should_start, current_state):
|
| 4032 |
+
if not isinstance(current_state, dict) or 'gen' not in current_state:
|
| 4033 |
+
yield "Error: Invalid state received before processing."
|
| 4034 |
+
return
|
| 4035 |
+
if should_start:
|
| 4036 |
+
yield from process_tasks(current_state)
|
| 4037 |
+
else:
|
| 4038 |
+
yield "Autoload complete. Processing not started."
|
| 4039 |
+
|
| 4040 |
+
def finalize_generation_with_state(current_state):
|
| 4041 |
+
if not isinstance(current_state, dict) or 'gen' not in current_state:
|
| 4042 |
+
return gr.update(), gr.update(interactive=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=""), current_state
|
| 4043 |
+
gallery_update, abort_btn_update, gen_btn_update, add_queue_btn_update, current_gen_col_update, gen_info_update = finalize_generation(current_state)
|
| 4044 |
+
return gallery_update, abort_btn_update, gen_btn_update, add_queue_btn_update, current_gen_col_update, gen_info_update, current_state
|
| 4045 |
+
|
| 4046 |
+
demo.load(
|
| 4047 |
+
fn=run_autoload_and_prepare_ui,
|
| 4048 |
+
inputs=[state],
|
| 4049 |
+
outputs=[queue_df, current_gen_column, should_start_flag, state]
|
| 4050 |
+
).then(
|
| 4051 |
+
fn=start_processing_if_needed,
|
| 4052 |
+
inputs=[should_start_flag, state],
|
| 4053 |
+
outputs=[gen_status],
|
| 4054 |
+
trigger_mode="once"
|
| 4055 |
+
).then(
|
| 4056 |
+
fn=finalize_generation_with_state,
|
| 4057 |
+
inputs=[state],
|
| 4058 |
+
outputs=[output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info, state],
|
| 4059 |
+
trigger_mode="always_last"
|
| 4060 |
+
)
|
| 4061 |
+
|
| 4062 |
return demo
|
| 4063 |
|
| 4064 |
if __name__ == "__main__":
|
| 4065 |
+
atexit.register(autosave_queue)
|
| 4066 |
# threading.Thread(target=runner, daemon=True).start()
|
| 4067 |
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
| 4068 |
server_port = int(args.server_port)
|