DeepBeepMeep commited on
Commit ·
cb1518b
1
Parent(s): 90fc871
Queue adaptations
Browse files- gradio_server.py +1015 -550
- wan/image2video.py +45 -18
- wan/modules/model.py +3 -4
- wan/text2video.py +2 -2
gradio_server.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import os
|
| 2 |
import time
|
|
|
|
| 3 |
import threading
|
| 4 |
import argparse
|
| 5 |
from mmgp import offload, safetensors2, profile_type
|
|
@@ -33,15 +34,12 @@ mmgp_version = version("mmgp")
|
|
| 33 |
if mmgp_version != target_mmgp_version:
|
| 34 |
print(f"Incorrect version of mmgp ({mmgp_version}), version {target_mmgp_version} is needed. Please upgrade with the command 'pip install -r requirements.txt'")
|
| 35 |
exit()
|
| 36 |
-
queue = []
|
| 37 |
lock = threading.Lock()
|
| 38 |
current_task_id = None
|
| 39 |
task_id = 0
|
| 40 |
-
progress_tracker = {}
|
| 41 |
-
tracker_lock = threading.Lock()
|
| 42 |
-
file_list = []
|
| 43 |
last_model_type = None
|
| 44 |
-
last_status_string = ""
|
| 45 |
|
| 46 |
def format_time(seconds):
|
| 47 |
if seconds < 60:
|
|
@@ -77,37 +75,6 @@ def pil_to_base64_uri(pil_image, format="png", quality=75):
|
|
| 77 |
print(f"Error converting PIL to base64: {e}")
|
| 78 |
return None
|
| 79 |
|
| 80 |
-
def runner():
|
| 81 |
-
global current_task_id
|
| 82 |
-
while True:
|
| 83 |
-
with lock:
|
| 84 |
-
for item in queue:
|
| 85 |
-
task_id_runner = item['id']
|
| 86 |
-
with tracker_lock:
|
| 87 |
-
progress = progress_tracker.get(task_id_runner, {})
|
| 88 |
-
|
| 89 |
-
if item['state'] == "Processing":
|
| 90 |
-
current_step = progress.get('current_step', 0)
|
| 91 |
-
total_steps = progress.get('total_steps', 0)
|
| 92 |
-
elapsed = time.time() - progress.get('start_time', time.time())
|
| 93 |
-
status = progress.get('status', "")
|
| 94 |
-
repeats = progress.get("repeats", "0/0")
|
| 95 |
-
item.update({
|
| 96 |
-
'progress': f"{((current_step/total_steps)*100 if total_steps > 0 else 0):.1f}%",
|
| 97 |
-
'steps': f"{current_step}/{total_steps}",
|
| 98 |
-
'time': format_time(elapsed),
|
| 99 |
-
'repeats': f"{repeats}",
|
| 100 |
-
'status': f"{status}"
|
| 101 |
-
})
|
| 102 |
-
if not any(item['state'] == "Processing" for item in queue):
|
| 103 |
-
for item in queue:
|
| 104 |
-
if item['state'] == "Queued":
|
| 105 |
-
item['status'] = "Processing"
|
| 106 |
-
item['state'] = "Processing"
|
| 107 |
-
current_task_id = item['id']
|
| 108 |
-
threading.Thread(target=process_task, args=(item,)).start()
|
| 109 |
-
break
|
| 110 |
-
time.sleep(1)
|
| 111 |
|
| 112 |
def process_prompt_and_add_tasks(
|
| 113 |
prompt,
|
|
@@ -137,161 +104,290 @@ def process_prompt_and_add_tasks(
|
|
| 137 |
slg_end,
|
| 138 |
cfg_star_switch,
|
| 139 |
cfg_zero_step,
|
| 140 |
-
|
| 141 |
image2video
|
| 142 |
):
|
| 143 |
-
|
| 144 |
-
if
|
| 145 |
-
|
| 146 |
return
|
|
|
|
|
|
|
| 147 |
if len(prompt) ==0:
|
| 148 |
return
|
| 149 |
prompt, errors = prompt_parser.process_template(prompt)
|
| 150 |
if len(errors) > 0:
|
| 151 |
-
|
| 152 |
return
|
| 153 |
prompts = prompt.replace("\r", "").split("\n")
|
| 154 |
prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")]
|
| 155 |
if len(prompts) ==0:
|
| 156 |
return
|
| 157 |
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
resolution,
|
| 163 |
-
video_length,
|
| 164 |
-
seed,
|
| 165 |
-
num_inference_steps,
|
| 166 |
-
guidance_scale,
|
| 167 |
-
flow_shift,
|
| 168 |
-
embedded_guidance_scale,
|
| 169 |
-
repeat_generation,
|
| 170 |
-
multi_images_gen_type,
|
| 171 |
-
tea_cache,
|
| 172 |
-
tea_cache_start_step_perc,
|
| 173 |
-
loras_choices,
|
| 174 |
-
loras_mult_choices,
|
| 175 |
-
image_prompt_type,
|
| 176 |
-
image_to_continue,
|
| 177 |
-
image_to_end,
|
| 178 |
-
video_to_continue,
|
| 179 |
-
max_frames,
|
| 180 |
-
RIFLEx_setting,
|
| 181 |
-
slg_switch,
|
| 182 |
-
slg_layers,
|
| 183 |
-
slg_start,
|
| 184 |
-
slg_end,
|
| 185 |
-
cfg_star_switch,
|
| 186 |
-
cfg_zero_step,
|
| 187 |
-
state_arg,
|
| 188 |
-
image2video
|
| 189 |
-
)
|
| 190 |
-
add_video_task(*task_params)
|
| 191 |
-
return update_queue_data()
|
| 192 |
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
global task_id
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
|
|
|
|
|
|
| 228 |
if not selected_indices or len(selected_indices) == 0:
|
| 229 |
-
return update_queue_data()
|
| 230 |
idx = selected_indices[0]
|
| 231 |
if isinstance(idx, list):
|
| 232 |
idx = idx[0]
|
| 233 |
idx = int(idx)
|
| 234 |
with lock:
|
| 235 |
if idx > 0:
|
|
|
|
| 236 |
queue[idx], queue[idx-1] = queue[idx-1], queue[idx]
|
| 237 |
-
return update_queue_data()
|
| 238 |
|
| 239 |
-
def move_down(selected_indices):
|
| 240 |
if not selected_indices or len(selected_indices) == 0:
|
| 241 |
-
return update_queue_data()
|
| 242 |
idx = selected_indices[0]
|
| 243 |
if isinstance(idx, list):
|
| 244 |
idx = idx[0]
|
| 245 |
idx = int(idx)
|
| 246 |
with lock:
|
|
|
|
| 247 |
if idx < len(queue)-1:
|
| 248 |
queue[idx], queue[idx+1] = queue[idx+1], queue[idx]
|
| 249 |
-
return update_queue_data()
|
| 250 |
|
| 251 |
-
def remove_task(selected_indices):
|
| 252 |
if not selected_indices or len(selected_indices) == 0:
|
| 253 |
-
return update_queue_data()
|
| 254 |
idx = selected_indices[0]
|
| 255 |
if isinstance(idx, list):
|
| 256 |
idx = idx[0]
|
| 257 |
-
idx = int(idx)
|
| 258 |
with lock:
|
| 259 |
if idx < len(queue):
|
| 260 |
if idx == 0:
|
| 261 |
wan_model._interrupt = True
|
| 262 |
del queue[idx]
|
| 263 |
-
return update_queue_data()
|
| 264 |
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
|
| 296 |
def create_html_progress_bar(percentage=0.0, text="Idle", is_idle=True):
|
| 297 |
bar_class = "progress-bar-custom idle" if is_idle else "progress-bar-custom"
|
|
@@ -306,35 +402,35 @@ def create_html_progress_bar(percentage=0.0, text="Idle", is_idle=True):
|
|
| 306 |
"""
|
| 307 |
return html
|
| 308 |
|
| 309 |
-
def refresh_progress():
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
|
| 339 |
def update_generation_status(html_content):
|
| 340 |
if(html_content):
|
|
@@ -736,7 +832,6 @@ if args.i2v_1_3B:
|
|
| 736 |
|
| 737 |
only_allow_edit_in_advanced = False
|
| 738 |
lora_preselected_preset = args.lora_preset
|
| 739 |
-
lora_preselected_preset_for_i2v = use_image2video
|
| 740 |
# if args.fast : #or args.fastest
|
| 741 |
# transformer_filename_t2v = transformer_choices_t2v[2]
|
| 742 |
# attention_mode="sage2" if "sage2" in attention_modes_supported else "sage"
|
|
@@ -749,7 +844,6 @@ if args.compile: #args.fastest or
|
|
| 749 |
lock_ui_compile = True
|
| 750 |
|
| 751 |
model_filename = ""
|
| 752 |
-
lora_model_filename = ""
|
| 753 |
#attention_mode="sage"
|
| 754 |
#attention_mode="sage2"
|
| 755 |
#attention_mode="flash"
|
|
@@ -758,15 +852,12 @@ lora_model_filename = ""
|
|
| 758 |
# compile = "transformer"
|
| 759 |
|
| 760 |
def preprocess_loras(sd):
|
| 761 |
-
if not use_image2video:
|
| 762 |
-
return sd
|
| 763 |
-
|
| 764 |
-
new_sd = {}
|
| 765 |
first = next(iter(sd), None)
|
| 766 |
if first == None:
|
| 767 |
return sd
|
| 768 |
-
if
|
| 769 |
return sd
|
|
|
|
| 770 |
print("Converting Lora Safetensors format to Lora Diffusers format")
|
| 771 |
alphas = {}
|
| 772 |
repl_list = ["cross_attn", "self_attn", "ffn"]
|
|
@@ -845,14 +936,14 @@ download_models(transformer_filename_i2v if use_image2video else transformer_fil
|
|
| 845 |
def sanitize_file_name(file_name, rep =""):
|
| 846 |
return file_name.replace("/",rep).replace("\\",rep).replace(":",rep).replace("|",rep).replace("?",rep).replace("<",rep).replace(">",rep).replace("\"",rep)
|
| 847 |
|
| 848 |
-
def extract_preset(lset_name, loras):
|
| 849 |
loras_choices = []
|
| 850 |
loras_choices_files = []
|
| 851 |
loras_mult_choices = ""
|
| 852 |
prompt =""
|
| 853 |
full_prompt =""
|
| 854 |
lset_name = sanitize_file_name(lset_name)
|
| 855 |
-
lora_dir = get_lora_dir(
|
| 856 |
if not lset_name.endswith(".lset"):
|
| 857 |
lset_name_filename = os.path.join(lora_dir, lset_name + ".lset" )
|
| 858 |
else:
|
|
@@ -923,7 +1014,7 @@ def setup_loras(i2v, transformer, lora_dir, lora_preselected_preset, split_line
|
|
| 923 |
if not os.path.isfile(os.path.join(lora_dir, lora_preselected_preset + ".lset")):
|
| 924 |
raise Exception(f"Unknown preset '{lora_preselected_preset}'")
|
| 925 |
default_lora_preset = lora_preselected_preset
|
| 926 |
-
default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, _ , error = extract_preset(default_lora_preset, loras)
|
| 927 |
if len(error) > 0:
|
| 928 |
print(error[:200])
|
| 929 |
return loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset
|
|
@@ -1010,7 +1101,7 @@ def load_models(i2v):
|
|
| 1010 |
# kwargs["partialPinning"] = True
|
| 1011 |
elif profile == 3:
|
| 1012 |
kwargs["budgets"] = { "*" : "70%" }
|
| 1013 |
-
offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", **kwargs)
|
| 1014 |
if len(args.gpu) > 0:
|
| 1015 |
torch.set_default_device(args.gpu)
|
| 1016 |
|
|
@@ -1087,7 +1178,7 @@ def apply_changes( state,
|
|
| 1087 |
if gen_in_progress:
|
| 1088 |
yield "<DIV ALIGN=CENTER>Unable to change config when a generation is in progress</DIV>"
|
| 1089 |
return
|
| 1090 |
-
global offloadobj, wan_model, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset, loras_presets
|
| 1091 |
server_config = {"attention_mode" : attention_choice,
|
| 1092 |
"transformer_filename": transformer_choices_t2v[transformer_t2v_choice],
|
| 1093 |
"transformer_filename_i2v": transformer_choices_i2v[transformer_i2v_choice],
|
|
@@ -1152,44 +1243,152 @@ def save_video(final_frames, output_path, fps=24):
|
|
| 1152 |
final_frames = (final_frames * 255).astype(np.uint8)
|
| 1153 |
ImageSequenceClip(list(final_frames), fps=fps).write_videofile(output_path, verbose= False, logger = None)
|
| 1154 |
|
| 1155 |
-
|
| 1156 |
-
|
| 1157 |
-
|
| 1158 |
-
|
| 1159 |
-
|
| 1160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1161 |
# pipe._interrupt = True
|
| 1162 |
-
phase = "Aborting"
|
| 1163 |
elif step_idx == num_inference_steps:
|
| 1164 |
-
phase = "VAE Decoding"
|
| 1165 |
else:
|
| 1166 |
-
phase = "Denoising"
|
| 1167 |
-
|
| 1168 |
-
|
| 1169 |
-
|
| 1170 |
-
|
| 1171 |
-
|
| 1172 |
-
|
| 1173 |
-
|
| 1174 |
-
|
| 1175 |
-
|
| 1176 |
-
|
| 1177 |
-
|
| 1178 |
-
def
|
| 1179 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1180 |
|
| 1181 |
def refresh_gallery_on_trigger(state):
|
| 1182 |
-
|
| 1183 |
-
|
| 1184 |
-
|
|
|
|
|
|
|
| 1185 |
|
| 1186 |
def select_video(state , event_data: gr.EventData):
|
| 1187 |
data= event_data._data
|
|
|
|
|
|
|
| 1188 |
if data!=None:
|
| 1189 |
choice = data.get("index",0)
|
| 1190 |
-
file_list =
|
| 1191 |
-
|
| 1192 |
-
|
| 1193 |
return
|
| 1194 |
|
| 1195 |
def expand_slist(slist, num_inference_steps ):
|
|
@@ -1221,6 +1420,7 @@ def convert_image(image):
|
|
| 1221 |
|
| 1222 |
def generate_video(
|
| 1223 |
task_id,
|
|
|
|
| 1224 |
prompt,
|
| 1225 |
negative_prompt,
|
| 1226 |
resolution,
|
|
@@ -1254,22 +1454,29 @@ def generate_video(
|
|
| 1254 |
):
|
| 1255 |
|
| 1256 |
global wan_model, offloadobj, reload_needed, last_model_type
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1257 |
file_model_needed = model_needed(image2video)
|
| 1258 |
-
|
| 1259 |
-
|
| 1260 |
-
|
|
|
|
|
|
|
| 1261 |
del wan_model
|
| 1262 |
if offloadobj is not None:
|
| 1263 |
offloadobj.release()
|
| 1264 |
del offloadobj
|
| 1265 |
gc.collect()
|
| 1266 |
-
|
| 1267 |
wan_model, offloadobj, trans = load_models(image2video)
|
| 1268 |
-
|
| 1269 |
reload_needed= False
|
| 1270 |
|
| 1271 |
if wan_model == None:
|
| 1272 |
-
|
| 1273 |
if attention_mode == "auto":
|
| 1274 |
attn = get_auto_attention()
|
| 1275 |
elif attention_mode in attention_modes_supported:
|
|
@@ -1278,26 +1485,15 @@ def generate_video(
|
|
| 1278 |
gr.Info(f"You have selected attention mode '{attention_mode}'. However it is not installed or supported on your system. You should either install it or switch to the default 'sdpa' attention.")
|
| 1279 |
return
|
| 1280 |
|
| 1281 |
-
|
| 1282 |
-
|
| 1283 |
-
|
|
|
|
|
|
|
| 1284 |
|
| 1285 |
if slg_switch == 0:
|
| 1286 |
slg_layers = None
|
| 1287 |
-
if image2video:
|
| 1288 |
-
if "480p" in model_filename and not "Fun" in model_filename and width * height > 848*480:
|
| 1289 |
-
gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P")
|
| 1290 |
-
return
|
| 1291 |
|
| 1292 |
-
resolution = str(width) + "*" + str(height)
|
| 1293 |
-
if resolution not in ['720*1280', '1280*720', '480*832', '832*480']:
|
| 1294 |
-
gr.Info(f"Resolution {resolution} not supported by image 2 video")
|
| 1295 |
-
return
|
| 1296 |
-
|
| 1297 |
-
if "1.3B" in model_filename and width * height > 848*480:
|
| 1298 |
-
gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P")
|
| 1299 |
-
return
|
| 1300 |
-
|
| 1301 |
offload.shared_state["_attention"] = attn
|
| 1302 |
|
| 1303 |
# VAE Tiling
|
|
@@ -1321,16 +1517,7 @@ def generate_video(
|
|
| 1321 |
|
| 1322 |
trans = wan_model.model
|
| 1323 |
|
| 1324 |
-
global gen_in_progress
|
| 1325 |
-
gen_in_progress = True
|
| 1326 |
temp_filename = None
|
| 1327 |
-
if image2video:
|
| 1328 |
-
if video_to_continue != None and len(video_to_continue) >0 :
|
| 1329 |
-
input_image_or_video_path = video_to_continue
|
| 1330 |
-
# pipeline.num_input_frames = max_frames
|
| 1331 |
-
# pipeline.max_frames = max_frames
|
| 1332 |
-
else:
|
| 1333 |
-
input_image_or_video_path = None
|
| 1334 |
|
| 1335 |
loras = state["loras"]
|
| 1336 |
if len(loras) > 0:
|
|
@@ -1374,10 +1561,6 @@ def generate_video(
|
|
| 1374 |
raise gr.Error("Error while loading Loras: " + ", ".join(error_files))
|
| 1375 |
seed = None if seed == -1 else seed
|
| 1376 |
# negative_prompt = "" # not applicable in the inference
|
| 1377 |
-
|
| 1378 |
-
if "abort" in state:
|
| 1379 |
-
del state["abort"]
|
| 1380 |
-
state["in_progress"] = True
|
| 1381 |
|
| 1382 |
enable_RIFLEx = RIFLEx_setting == 0 and video_length > (6* 16) or RIFLEx_setting == 1
|
| 1383 |
# VAE Tiling
|
|
@@ -1414,49 +1597,53 @@ def generate_video(
|
|
| 1414 |
if seed == None or seed <0:
|
| 1415 |
seed = random.randint(0, 999999999)
|
| 1416 |
|
| 1417 |
-
global file_list
|
| 1418 |
-
clear_file_list = server_config.get("clear_file_list", 0)
|
| 1419 |
-
file_list = state.get("file_list", [])
|
| 1420 |
-
if clear_file_list > 0:
|
| 1421 |
-
file_list_current_size = len(file_list)
|
| 1422 |
-
keep_file_from = max(file_list_current_size - clear_file_list, 0)
|
| 1423 |
-
files_removed = keep_file_from
|
| 1424 |
-
choice = state.get("selected",0)
|
| 1425 |
-
choice = max(choice- files_removed, 0)
|
| 1426 |
-
file_list = file_list[ keep_file_from: ]
|
| 1427 |
-
else:
|
| 1428 |
-
file_list = []
|
| 1429 |
-
choice = 0
|
| 1430 |
-
state["selected"] = choice
|
| 1431 |
-
state["file_list"] = file_list
|
| 1432 |
-
|
| 1433 |
-
|
| 1434 |
global save_path
|
| 1435 |
os.makedirs(save_path, exist_ok=True)
|
| 1436 |
video_no = 0
|
| 1437 |
abort = False
|
| 1438 |
-
repeats = f"{video_no}/{repeat_generation}"
|
| 1439 |
-
callback = build_callback(task_id, state, trans, num_inference_steps, repeats)
|
| 1440 |
-
offload.shared_state["callback"] = callback
|
| 1441 |
gc.collect()
|
| 1442 |
torch.cuda.empty_cache()
|
| 1443 |
wan_model._interrupt = False
|
| 1444 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1445 |
try:
|
| 1446 |
-
|
| 1447 |
-
|
| 1448 |
-
|
| 1449 |
-
|
| 1450 |
-
|
| 1451 |
-
|
| 1452 |
-
|
| 1453 |
-
|
| 1454 |
-
|
| 1455 |
-
|
| 1456 |
if trans.enable_teacache:
|
| 1457 |
trans.teacache_counter = 0
|
| 1458 |
trans.num_steps = num_inference_steps
|
| 1459 |
-
trans.teacache_skipped_steps = 0
|
| 1460 |
trans.previous_residual_uncond = None
|
| 1461 |
trans.previous_residual_cond = None
|
| 1462 |
|
|
@@ -1464,8 +1651,8 @@ def generate_video(
|
|
| 1464 |
if image2video:
|
| 1465 |
samples = wan_model.generate(
|
| 1466 |
prompt,
|
| 1467 |
-
|
| 1468 |
-
|
| 1469 |
frame_num=(video_length // 4)* 4 + 1,
|
| 1470 |
max_area=MAX_AREA_CONFIGS[resolution],
|
| 1471 |
shift=flow_shift,
|
|
@@ -1483,7 +1670,7 @@ def generate_video(
|
|
| 1483 |
slg_end = slg_end/100,
|
| 1484 |
cfg_star_switch = cfg_star_switch,
|
| 1485 |
cfg_zero_step = cfg_zero_step,
|
| 1486 |
-
add_frames_for_end_image = not "Fun" in transformer_filename_i2v
|
| 1487 |
)
|
| 1488 |
else:
|
| 1489 |
samples = wan_model.generate(
|
|
@@ -1507,7 +1694,6 @@ def generate_video(
|
|
| 1507 |
cfg_zero_step = cfg_zero_step,
|
| 1508 |
)
|
| 1509 |
except Exception as e:
|
| 1510 |
-
gen_in_progress = False
|
| 1511 |
if temp_filename!= None and os.path.isfile(temp_filename):
|
| 1512 |
os.remove(temp_filename)
|
| 1513 |
offload.last_offload_obj.unload_all()
|
|
@@ -1530,15 +1716,23 @@ def generate_video(
|
|
| 1530 |
if any( keyword in frame.name for keyword in keyword_list):
|
| 1531 |
VRAM_crash = True
|
| 1532 |
break
|
|
|
|
|
|
|
|
|
|
| 1533 |
state["prompt"] = ""
|
| 1534 |
if VRAM_crash:
|
| 1535 |
-
|
| 1536 |
else:
|
| 1537 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1538 |
finally:
|
| 1539 |
-
|
| 1540 |
-
|
| 1541 |
-
|
|
|
|
| 1542 |
|
| 1543 |
if trans.enable_teacache:
|
| 1544 |
print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{num_inference_steps}" )
|
|
@@ -1555,7 +1749,7 @@ def generate_video(
|
|
| 1555 |
end_time = time.time()
|
| 1556 |
abort = True
|
| 1557 |
state["prompt"] = ""
|
| 1558 |
-
|
| 1559 |
else:
|
| 1560 |
sample = samples.cpu()
|
| 1561 |
# video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")
|
|
@@ -1574,7 +1768,7 @@ def generate_video(
|
|
| 1574 |
normalize=True,
|
| 1575 |
value_range=(-1, 1))
|
| 1576 |
|
| 1577 |
-
configs = get_settings_dict(state,
|
| 1578 |
loras_mult_choices, tea_cache , tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end, cfg_star_switch, cfg_zero_step)
|
| 1579 |
|
| 1580 |
metadata_choice = server_config.get("metadata_choice","metadata")
|
|
@@ -1596,9 +1790,159 @@ def generate_video(
|
|
| 1596 |
|
| 1597 |
if temp_filename!= None and os.path.isfile(temp_filename):
|
| 1598 |
os.remove(temp_filename)
|
| 1599 |
-
gen_in_progress = False
|
| 1600 |
offload.unload_loras_from_model(trans)
|
| 1601 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1602 |
|
| 1603 |
def get_new_preset_msg(advanced = True):
|
| 1604 |
if advanced:
|
|
@@ -1650,7 +1994,7 @@ def save_lset(state, lset_name, loras_choices, loras_mult_choices, prompt, save_
|
|
| 1650 |
|
| 1651 |
|
| 1652 |
lset_name_filename = lset_name + ".lset"
|
| 1653 |
-
full_lset_name_filename = os.path.join(get_lora_dir(
|
| 1654 |
|
| 1655 |
with open(full_lset_name_filename, "w", encoding="utf-8") as writer:
|
| 1656 |
writer.write(json.dumps(lset, indent=4))
|
|
@@ -1667,7 +2011,7 @@ def save_lset(state, lset_name, loras_choices, loras_mult_choices, prompt, save_
|
|
| 1667 |
|
| 1668 |
def delete_lset(state, lset_name):
|
| 1669 |
loras_presets = state["loras_presets"]
|
| 1670 |
-
lset_name_filename = os.path.join( get_lora_dir(
|
| 1671 |
if len(lset_name) > 0 and lset_name != get_new_preset_msg(True) and lset_name != get_new_preset_msg(False):
|
| 1672 |
if not os.path.isfile(lset_name_filename):
|
| 1673 |
raise gr.Error(f"Preset '{lset_name}' not found ")
|
|
@@ -1688,8 +2032,8 @@ def delete_lset(state, lset_name):
|
|
| 1688 |
def refresh_lora_list(state, lset_name, loras_choices):
|
| 1689 |
loras_names = state["loras_names"]
|
| 1690 |
prev_lora_names_selected = [ loras_names[int(i)] for i in loras_choices]
|
| 1691 |
-
|
| 1692 |
-
loras, loras_names, loras_presets, _, _, _, _ = setup_loras(
|
| 1693 |
state["loras"] = loras
|
| 1694 |
state["loras_names"] = loras_names
|
| 1695 |
state["loras_presets"] = loras_presets
|
|
@@ -1729,7 +2073,7 @@ def apply_lset(state, wizard_prompt_activated, lset_name, loras_choices, loras_m
|
|
| 1729 |
gr.Info("Please choose a preset in the list or create one")
|
| 1730 |
else:
|
| 1731 |
loras = state["loras"]
|
| 1732 |
-
loras_choices, loras_mult_choices, preset_prompt, full_prompt, error = extract_preset(lset_name, loras)
|
| 1733 |
if len(error) > 0:
|
| 1734 |
gr.Info(error)
|
| 1735 |
else:
|
|
@@ -1930,10 +2274,11 @@ def save_settings(state, prompt, image_prompt_type, video_length, resolution, nu
|
|
| 1930 |
if state.get("validate_success",0) != 1:
|
| 1931 |
return
|
| 1932 |
|
| 1933 |
-
|
|
|
|
| 1934 |
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step)
|
| 1935 |
|
| 1936 |
-
defaults_filename = get_settings_file_name(
|
| 1937 |
|
| 1938 |
with open(defaults_filename, "w", encoding="utf-8") as f:
|
| 1939 |
json.dump(ui_defaults, f, indent=4)
|
|
@@ -1976,7 +2321,12 @@ def generate_video_tab(image2video=False):
|
|
| 1976 |
|
| 1977 |
state_dict["advanced"] = advanced
|
| 1978 |
state_dict["loras_model"] = filename
|
| 1979 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1980 |
|
| 1981 |
loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset = setup_loras(image2video, None, get_lora_dir(image2video), preset_to_load, None)
|
| 1982 |
|
|
@@ -1989,7 +2339,7 @@ def generate_video_tab(image2video=False):
|
|
| 1989 |
launch_loras = []
|
| 1990 |
launch_multis_str = ""
|
| 1991 |
|
| 1992 |
-
if len(default_lora_preset) > 0 and image2video ==
|
| 1993 |
launch_preset = default_lora_preset
|
| 1994 |
launch_prompt = default_lora_preset_prompt
|
| 1995 |
launch_loras = default_loras_choices
|
|
@@ -2014,15 +2364,6 @@ def generate_video_tab(image2video=False):
|
|
| 2014 |
|
| 2015 |
|
| 2016 |
header = gr.Markdown(generate_header(model_filename, compile, attention_mode))
|
| 2017 |
-
with gr.Row(visible= image2video):
|
| 2018 |
-
with gr.Row(scale =2):
|
| 2019 |
-
gr.Markdown("<I>Wan2GP's Lora Festival ! Press the following button to download i2v <B>Remade</B> Loras collection (and bonuses Loras).")
|
| 2020 |
-
with gr.Row(scale =1):
|
| 2021 |
-
download_loras_btn = gr.Button("---> Let the Lora's Festival Start !", scale =1)
|
| 2022 |
-
with gr.Row(scale =1):
|
| 2023 |
-
gr.Markdown("")
|
| 2024 |
-
with gr.Row(visible= image2video) as download_status_row:
|
| 2025 |
-
download_status = gr.Markdown()
|
| 2026 |
with gr.Row():
|
| 2027 |
with gr.Column():
|
| 2028 |
with gr.Column(visible=False, elem_id="image-modal-container") as modal_container:
|
|
@@ -2250,89 +2591,112 @@ def generate_video_tab(image2video=False):
|
|
| 2250 |
cfg_zero_step = gr.Slider(-1, 39, value=ui_defaults.get("cfg_zero_step",-1), step=1, label="CFG Zero below this Layer (Extra Process)")
|
| 2251 |
|
| 2252 |
with gr.Row():
|
| 2253 |
-
save_settings_btn = gr.Button("Set Settings as Default")
|
| 2254 |
show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
|
| 2255 |
fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
|
| 2256 |
with gr.Column():
|
|
|
|
|
|
|
|
|
|
| 2257 |
gen_progress_html = gr.HTML(
|
| 2258 |
label="Status",
|
| 2259 |
value="Idle",
|
| 2260 |
-
elem_id="generation_progress_bar_container"
|
| 2261 |
)
|
| 2262 |
output = gr.Gallery(
|
| 2263 |
label="Generated videos", show_label=False, elem_id="gallery"
|
| 2264 |
, columns=[3], rows=[1], object_fit="contain", height=450, selected_index=0, interactive= False)
|
| 2265 |
generate_btn = gr.Button("Generate")
|
| 2266 |
-
|
| 2267 |
-
|
| 2268 |
-
|
| 2269 |
-
|
| 2270 |
-
|
| 2271 |
-
|
| 2272 |
-
|
| 2273 |
-
|
| 2274 |
-
|
| 2275 |
-
|
| 2276 |
-
|
| 2277 |
-
|
| 2278 |
-
|
| 2279 |
-
|
| 2280 |
-
|
| 2281 |
-
|
| 2282 |
-
|
| 2283 |
-
|
| 2284 |
-
|
| 2285 |
-
|
| 2286 |
-
|
| 2287 |
-
|
| 2288 |
-
|
| 2289 |
-
|
| 2290 |
-
|
| 2291 |
-
|
| 2292 |
-
|
| 2293 |
-
|
| 2294 |
-
|
| 2295 |
-
|
| 2296 |
-
|
| 2297 |
-
|
| 2298 |
-
|
| 2299 |
-
|
| 2300 |
-
|
| 2301 |
-
|
| 2302 |
-
|
| 2303 |
-
|
| 2304 |
-
|
| 2305 |
-
|
| 2306 |
-
|
| 2307 |
-
|
| 2308 |
-
|
| 2309 |
-
|
| 2310 |
-
|
| 2311 |
-
|
| 2312 |
-
|
| 2313 |
-
|
| 2314 |
-
|
| 2315 |
-
|
| 2316 |
-
|
| 2317 |
-
|
| 2318 |
-
|
| 2319 |
-
|
| 2320 |
-
|
| 2321 |
-
|
| 2322 |
-
|
| 2323 |
-
|
| 2324 |
-
|
| 2325 |
-
|
| 2326 |
-
|
| 2327 |
-
|
| 2328 |
-
|
| 2329 |
-
|
| 2330 |
-
|
| 2331 |
-
|
| 2332 |
-
|
| 2333 |
-
|
| 2334 |
-
|
| 2335 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2336 |
save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
|
| 2337 |
save_settings, inputs = [state, prompt, image_prompt_type_radio, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt,
|
| 2338 |
loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers,
|
|
@@ -2348,53 +2712,114 @@ def generate_video_tab(image2video=False):
|
|
| 2348 |
)
|
| 2349 |
refresh_lora_btn.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
|
| 2350 |
refresh_lora_btn2.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
|
| 2351 |
-
download_loras_btn.click(fn=download_loras, inputs=[], outputs=[download_status_row, download_status, presets_column, loras_column]).then(fn=refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
|
| 2352 |
output.select(select_video, state, None )
|
| 2353 |
|
| 2354 |
-
|
| 2355 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2356 |
).then(
|
| 2357 |
fn=process_prompt_and_add_tasks,
|
| 2358 |
-
inputs=
|
| 2359 |
-
prompt,
|
| 2360 |
-
negative_prompt,
|
| 2361 |
-
resolution,
|
| 2362 |
-
video_length,
|
| 2363 |
-
seed,
|
| 2364 |
-
num_inference_steps,
|
| 2365 |
-
guidance_scale,
|
| 2366 |
-
flow_shift,
|
| 2367 |
-
embedded_guidance_scale,
|
| 2368 |
-
repeat_generation,
|
| 2369 |
-
multi_images_gen_type,
|
| 2370 |
-
tea_cache_setting,
|
| 2371 |
-
tea_cache_start_step_perc,
|
| 2372 |
-
loras_choices,
|
| 2373 |
-
loras_mult_choices,
|
| 2374 |
-
image_prompt_type_radio,
|
| 2375 |
-
image_to_continue,
|
| 2376 |
-
image_to_end,
|
| 2377 |
-
video_to_continue,
|
| 2378 |
-
max_frames,
|
| 2379 |
-
RIFLEx_setting,
|
| 2380 |
-
slg_switch,
|
| 2381 |
-
slg_layers,
|
| 2382 |
-
slg_start_perc,
|
| 2383 |
-
slg_end_perc,
|
| 2384 |
-
cfg_star_switch,
|
| 2385 |
-
cfg_zero_step,
|
| 2386 |
-
state,
|
| 2387 |
-
gr.State(image2video)
|
| 2388 |
-
],
|
| 2389 |
outputs=queue_df
|
|
|
|
|
|
|
|
|
|
| 2390 |
)
|
|
|
|
|
|
|
| 2391 |
close_modal_button.click(
|
| 2392 |
lambda: gr.update(visible=False),
|
| 2393 |
inputs=[],
|
| 2394 |
outputs=[modal_container]
|
| 2395 |
)
|
| 2396 |
-
return loras_column, loras_choices, presets_column, lset_name, header, state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2397 |
|
|
|
|
| 2398 |
def generate_configuration_tab():
|
| 2399 |
state_dict = {}
|
| 2400 |
state = gr.State(state_dict)
|
|
@@ -2411,7 +2836,7 @@ def generate_configuration_tab():
|
|
| 2411 |
value= index,
|
| 2412 |
label="Transformer model for Text to Video",
|
| 2413 |
interactive= not lock_ui_transformer,
|
| 2414 |
-
visible=True
|
| 2415 |
)
|
| 2416 |
index = transformer_choices_i2v.index(transformer_filename_i2v)
|
| 2417 |
index = 0 if index ==0 else index
|
|
@@ -2428,7 +2853,7 @@ def generate_configuration_tab():
|
|
| 2428 |
value= index,
|
| 2429 |
label="Transformer model for Image to Video",
|
| 2430 |
interactive= not lock_ui_transformer,
|
| 2431 |
-
visible = True
|
| 2432 |
)
|
| 2433 |
index = text_encoder_choices.index(text_encoder_filename)
|
| 2434 |
index = 0 if index ==0 else index
|
|
@@ -2524,7 +2949,7 @@ def generate_configuration_tab():
|
|
| 2524 |
reload_choice = gr.Dropdown(
|
| 2525 |
choices=[
|
| 2526 |
("When changing tabs", 1),
|
| 2527 |
-
("When pressing
|
| 2528 |
],
|
| 2529 |
value=server_config.get("reload_model",2),
|
| 2530 |
label="Reload model"
|
|
@@ -2577,19 +3002,46 @@ def generate_about_tab():
|
|
| 2577 |
gr.Markdown("- <B>Remade_AI</B> : for creating their awesome Loras collection")
|
| 2578 |
|
| 2579 |
|
| 2580 |
-
def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
|
| 2581 |
-
global lora_model_filename, use_image2video
|
| 2582 |
-
|
| 2583 |
t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode)
|
| 2584 |
i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode)
|
| 2585 |
|
| 2586 |
new_t2v = evt.index == 0
|
| 2587 |
new_i2v = evt.index == 1
|
| 2588 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2589 |
|
| 2590 |
if(server_config.get("reload_model",2) == 1):
|
| 2591 |
-
|
| 2592 |
-
|
|
|
|
| 2593 |
if queue_empty:
|
| 2594 |
global wan_model, offloadobj
|
| 2595 |
if wan_model is not None:
|
|
@@ -2599,7 +3051,7 @@ def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
|
|
| 2599 |
wan_model = None
|
| 2600 |
gc.collect()
|
| 2601 |
torch.cuda.empty_cache()
|
| 2602 |
-
wan_model, offloadobj, trans = load_models(
|
| 2603 |
del trans
|
| 2604 |
|
| 2605 |
if new_t2v or new_i2v:
|
|
@@ -2625,11 +3077,15 @@ def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
|
|
| 2625 |
gr.Column(visible= visible),
|
| 2626 |
gr.Dropdown(choices=lset_choices, value=get_new_preset_msg(advanced), visible=visible),
|
| 2627 |
t2v_header,
|
|
|
|
|
|
|
| 2628 |
gr.Column(),
|
| 2629 |
gr.Dropdown(),
|
| 2630 |
gr.Column(),
|
| 2631 |
gr.Dropdown(),
|
| 2632 |
-
|
|
|
|
|
|
|
| 2633 |
]
|
| 2634 |
else:
|
| 2635 |
return [
|
|
@@ -2637,16 +3093,21 @@ def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
|
|
| 2637 |
gr.Dropdown(),
|
| 2638 |
gr.Column(),
|
| 2639 |
gr.Dropdown(),
|
| 2640 |
-
|
|
|
|
|
|
|
|
|
|
| 2641 |
gr.Column(visible= visible),
|
| 2642 |
gr.Dropdown(choices=new_loras_choices, visible=visible, value=[]),
|
| 2643 |
gr.Column(visible= visible),
|
| 2644 |
gr.Dropdown(choices=lset_choices, value=get_new_preset_msg(advanced), visible=visible),
|
| 2645 |
i2v_header,
|
|
|
|
|
|
|
| 2646 |
]
|
| 2647 |
|
| 2648 |
-
return [gr.Column(), gr.Dropdown(), gr.Column(), gr.Dropdown(), t2v_header,
|
| 2649 |
-
gr.Column(), gr.Dropdown(), gr.Column(), gr.Dropdown(), i2v_header]
|
| 2650 |
|
| 2651 |
|
| 2652 |
def create_demo():
|
|
@@ -2706,112 +3167,112 @@ def create_demo():
|
|
| 2706 |
overflow: hidden;
|
| 2707 |
text-overflow: ellipsis;
|
| 2708 |
}
|
| 2709 |
-
#queue_df td:nth-child(-n+5) {
|
| 2710 |
-
|
| 2711 |
-
|
| 2712 |
-
}
|
| 2713 |
-
#queue_df td:nth-child(6) {
|
| 2714 |
-
|
| 2715 |
-
}
|
| 2716 |
-
#queue_df th {
|
| 2717 |
-
|
| 2718 |
-
|
| 2719 |
-
|
| 2720 |
-
}
|
| 2721 |
-
#queue_df table {
|
| 2722 |
-
|
| 2723 |
-
|
| 2724 |
-
}
|
| 2725 |
-
#queue_df::-webkit-scrollbar {
|
| 2726 |
-
|
| 2727 |
-
}
|
| 2728 |
-
#queue_df {
|
| 2729 |
-
|
| 2730 |
-
|
| 2731 |
-
}
|
| 2732 |
-
#queue_df th:nth-child(1),
|
| 2733 |
-
#queue_df td:nth-child(1) {
|
| 2734 |
-
|
| 2735 |
-
|
| 2736 |
-
|
| 2737 |
-
}
|
| 2738 |
-
#queue_df th:nth-child(1) {
|
| 2739 |
-
|
| 2740 |
-
}
|
| 2741 |
-
#queue_df th:nth-child(2),
|
| 2742 |
-
#queue_df td:nth-child(2) {
|
| 2743 |
-
|
| 2744 |
-
|
| 2745 |
-
|
| 2746 |
-
}
|
| 2747 |
-
#queue_df th:nth-child(2) {
|
| 2748 |
-
|
| 2749 |
-
}
|
| 2750 |
-
#queue_df th:nth-child(3),
|
| 2751 |
-
#queue_df td:nth-child(3) {
|
| 2752 |
-
|
| 2753 |
-
|
| 2754 |
-
|
| 2755 |
-
}
|
| 2756 |
-
|
| 2757 |
-
|
| 2758 |
-
}
|
| 2759 |
-
#queue_df th:nth-child(4),
|
| 2760 |
-
#queue_df td:nth-child(4) {
|
| 2761 |
-
|
| 2762 |
-
|
| 2763 |
-
|
| 2764 |
-
}
|
| 2765 |
-
|
| 2766 |
-
|
| 2767 |
-
}
|
| 2768 |
-
#queue_df th:nth-child(5),
|
| 2769 |
-
#queue_df td:nth-child(5) {
|
| 2770 |
-
|
| 2771 |
-
|
| 2772 |
-
|
| 2773 |
-
}
|
| 2774 |
-
#queue_df th:nth-child(6),
|
| 2775 |
-
#queue_df td:nth-child(6) {
|
| 2776 |
-
|
| 2777 |
-
|
| 2778 |
-
|
| 2779 |
-
}
|
| 2780 |
-
|
| 2781 |
-
|
| 2782 |
-
}
|
| 2783 |
-
#queue_df th:nth-child(7), #queue_df td:nth-child(7),
|
| 2784 |
-
#queue_df th:nth-child(8), #queue_df td:nth-child(8) {
|
| 2785 |
-
|
| 2786 |
-
|
| 2787 |
-
|
| 2788 |
-
}
|
| 2789 |
-
#queue_df td:nth-child(7) img,
|
| 2790 |
-
#queue_df td:nth-child(8) img {
|
| 2791 |
-
|
| 2792 |
-
|
| 2793 |
-
|
| 2794 |
-
|
| 2795 |
-
|
| 2796 |
-
|
| 2797 |
-
}
|
| 2798 |
-
#queue_df th:nth-child(9), #queue_df td:nth-child(9),
|
| 2799 |
-
#queue_df th:nth-child(10), #queue_df td:nth-child(10),
|
| 2800 |
-
#queue_df th:nth-child(11), #queue_df td:nth-child(11) {
|
| 2801 |
-
|
| 2802 |
-
|
| 2803 |
-
|
| 2804 |
-
|
| 2805 |
-
|
| 2806 |
-
|
| 2807 |
-
}
|
| 2808 |
-
#queue_df td:nth-child(7):hover,
|
| 2809 |
-
#queue_df td:nth-child(8):hover,
|
| 2810 |
-
#queue_df td:nth-child(9):hover,
|
| 2811 |
-
#queue_df td:nth-child(10):hover,
|
| 2812 |
-
#queue_df td:nth-child(11):hover {
|
| 2813 |
-
|
| 2814 |
-
}
|
| 2815 |
#image-modal-container {
|
| 2816 |
position: fixed;
|
| 2817 |
top: 0;
|
|
@@ -2893,8 +3354,8 @@ def create_demo():
|
|
| 2893 |
pointer-events: none;
|
| 2894 |
}
|
| 2895 |
"""
|
| 2896 |
-
with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md")) as demo:
|
| 2897 |
-
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v3.
|
| 2898 |
gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !</FONT>")
|
| 2899 |
|
| 2900 |
with gr.Accordion("Click here for some Info on how to use Wan2GP", open = False):
|
|
@@ -2904,30 +3365,34 @@ def create_demo():
|
|
| 2904 |
gr.Markdown("- 1280 x 720 with a 14B model: 80 frames (5s): 11 GB of VRAM")
|
| 2905 |
gr.Markdown("It is not recommmended to generate a video longer than 8s (128 frames) even if there is still some VRAM left as some artifacts may appear")
|
| 2906 |
gr.Markdown("Please note that if your turn on compilation, the first denoising step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
|
| 2907 |
-
|
|
|
|
|
|
|
| 2908 |
|
| 2909 |
with gr.Tabs(selected="i2v" if use_image2video else "t2v") as main_tabs:
|
| 2910 |
with gr.Tab("Text To Video", id="t2v") as t2v_tab:
|
| 2911 |
-
t2v_loras_column, t2v_loras_choices, t2v_presets_column, t2v_lset_name, t2v_header, t2v_state = generate_video_tab()
|
| 2912 |
with gr.Tab("Image To Video", id="i2v") as i2v_tab:
|
| 2913 |
-
i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header, i2v_state = generate_video_tab(True)
|
| 2914 |
if not args.lock_config:
|
|
|
|
|
|
|
| 2915 |
with gr.Tab("Configuration"):
|
| 2916 |
generate_configuration_tab()
|
| 2917 |
with gr.Tab("About"):
|
| 2918 |
generate_about_tab()
|
| 2919 |
main_tabs.select(
|
| 2920 |
fn=on_tab_select,
|
| 2921 |
-
inputs=[t2v_state, i2v_state],
|
| 2922 |
outputs=[
|
| 2923 |
-
t2v_loras_column, t2v_loras_choices, t2v_presets_column, t2v_lset_name, t2v_header,
|
| 2924 |
-
i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header
|
| 2925 |
]
|
| 2926 |
)
|
| 2927 |
return demo
|
| 2928 |
|
| 2929 |
if __name__ == "__main__":
|
| 2930 |
-
threading.Thread(target=runner, daemon=True).start()
|
| 2931 |
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
| 2932 |
server_port = int(args.server_port)
|
| 2933 |
if os.name == "nt":
|
|
|
|
| 1 |
import os
|
| 2 |
import time
|
| 3 |
+
import sys
|
| 4 |
import threading
|
| 5 |
import argparse
|
| 6 |
from mmgp import offload, safetensors2, profile_type
|
|
|
|
| 34 |
if mmgp_version != target_mmgp_version:
|
| 35 |
print(f"Incorrect version of mmgp ({mmgp_version}), version {target_mmgp_version} is needed. Please upgrade with the command 'pip install -r requirements.txt'")
|
| 36 |
exit()
|
|
|
|
| 37 |
lock = threading.Lock()
|
| 38 |
current_task_id = None
|
| 39 |
task_id = 0
|
| 40 |
+
# progress_tracker = {}
|
| 41 |
+
# tracker_lock = threading.Lock()
|
|
|
|
| 42 |
last_model_type = None
|
|
|
|
| 43 |
|
| 44 |
def format_time(seconds):
|
| 45 |
if seconds < 60:
|
|
|
|
| 75 |
print(f"Error converting PIL to base64: {e}")
|
| 76 |
return None
|
| 77 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
def process_prompt_and_add_tasks(
|
| 80 |
prompt,
|
|
|
|
| 104 |
slg_end,
|
| 105 |
cfg_star_switch,
|
| 106 |
cfg_zero_step,
|
| 107 |
+
state,
|
| 108 |
image2video
|
| 109 |
):
|
| 110 |
+
|
| 111 |
+
if state.get("validate_success",0) != 1:
|
| 112 |
+
gr.Info("Validation failed, not adding tasks.")
|
| 113 |
return
|
| 114 |
+
|
| 115 |
+
state["validate_success"] = 0
|
| 116 |
if len(prompt) ==0:
|
| 117 |
return
|
| 118 |
prompt, errors = prompt_parser.process_template(prompt)
|
| 119 |
if len(errors) > 0:
|
| 120 |
+
gr.Info("Error processing prompt template: " + errors)
|
| 121 |
return
|
| 122 |
prompts = prompt.replace("\r", "").split("\n")
|
| 123 |
prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")]
|
| 124 |
if len(prompts) ==0:
|
| 125 |
return
|
| 126 |
|
| 127 |
+
file_model_needed = model_needed(image2video)
|
| 128 |
+
if image2video:
|
| 129 |
+
width, height = resolution.split("x")
|
| 130 |
+
width, height = int(width), int(height)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
|
| 132 |
+
if "480p" in file_model_needed and not "Fun" in file_model_needed and width * height > 848*480:
|
| 133 |
+
gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P")
|
| 134 |
+
return
|
| 135 |
+
resolution = str(width) + "*" + str(height)
|
| 136 |
+
if resolution not in ['720*1280', '1280*720', '480*832', '832*480']:
|
| 137 |
+
gr.Info(f"Resolution {resolution} not supported by image 2 video")
|
| 138 |
+
return
|
| 139 |
+
|
| 140 |
+
if "1.3B" in file_model_needed and width * height > 848*480:
|
| 141 |
+
gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P")
|
| 142 |
+
return
|
| 143 |
+
|
| 144 |
+
if image2video:
|
| 145 |
+
if image_to_continue == None or isinstance(image_to_continue, list) and len(image_to_continue) == 0:
|
| 146 |
+
return
|
| 147 |
+
if image_prompt_type == 0:
|
| 148 |
+
image_to_end = None
|
| 149 |
+
if isinstance(image_to_continue, list):
|
| 150 |
+
image_to_continue = [ convert_image(tup[0]) for tup in image_to_continue ]
|
| 151 |
+
else:
|
| 152 |
+
image_to_continue = [convert_image(image_to_continue)]
|
| 153 |
+
if image_to_end != None:
|
| 154 |
+
if isinstance(image_to_end , list):
|
| 155 |
+
image_to_end = [ convert_image(tup[0]) for tup in image_to_end ]
|
| 156 |
+
else:
|
| 157 |
+
image_to_end = [convert_image(image_to_end) ]
|
| 158 |
+
if len(image_to_continue) != len(image_to_end):
|
| 159 |
+
gr.Info("The number of start and end images should be the same ")
|
| 160 |
+
return
|
| 161 |
+
|
| 162 |
+
if multi_images_gen_type == 0:
|
| 163 |
+
new_prompts = []
|
| 164 |
+
new_image_to_continue = []
|
| 165 |
+
new_image_to_end = []
|
| 166 |
+
for i in range(len(prompts) * len(image_to_continue) ):
|
| 167 |
+
new_prompts.append( prompts[ i % len(prompts)] )
|
| 168 |
+
new_image_to_continue.append(image_to_continue[i // len(prompts)] )
|
| 169 |
+
if image_to_end != None:
|
| 170 |
+
new_image_to_end.append(image_to_end[i // len(prompts)] )
|
| 171 |
+
prompts = new_prompts
|
| 172 |
+
image_to_continue = new_image_to_continue
|
| 173 |
+
if image_to_end != None:
|
| 174 |
+
image_to_end = new_image_to_end
|
| 175 |
+
else:
|
| 176 |
+
if len(prompts) >= len(image_to_continue):
|
| 177 |
+
if len(prompts) % len(image_to_continue) !=0:
|
| 178 |
+
raise gr.Error("If there are more text prompts than input images the number of text prompts should be dividable by the number of images")
|
| 179 |
+
rep = len(prompts) // len(image_to_continue)
|
| 180 |
+
new_image_to_continue = []
|
| 181 |
+
new_image_to_end = []
|
| 182 |
+
for i, _ in enumerate(prompts):
|
| 183 |
+
new_image_to_continue.append(image_to_continue[i//rep] )
|
| 184 |
+
if image_to_end != None:
|
| 185 |
+
new_image_to_end.append(image_to_end[i//rep] )
|
| 186 |
+
image_to_continue = new_image_to_continue
|
| 187 |
+
if image_to_end != None:
|
| 188 |
+
image_to_end = new_image_to_end
|
| 189 |
+
else:
|
| 190 |
+
if len(image_to_continue) % len(prompts) !=0:
|
| 191 |
+
raise gr.Error("If there are more input images than text prompts the number of images should be dividable by the number of text prompts")
|
| 192 |
+
rep = len(image_to_continue) // len(prompts)
|
| 193 |
+
new_prompts = []
|
| 194 |
+
for i, _ in enumerate(image_to_continue):
|
| 195 |
+
new_prompts.append( prompts[ i//rep] )
|
| 196 |
+
prompts = new_prompts
|
| 197 |
+
|
| 198 |
+
# elif video_to_continue != None and len(video_to_continue) >0 :
|
| 199 |
+
# input_image_or_video_path = video_to_continue
|
| 200 |
+
# # pipeline.num_input_frames = max_frames
|
| 201 |
+
# # pipeline.max_frames = max_frames
|
| 202 |
+
# else:
|
| 203 |
+
# return
|
| 204 |
+
# else:
|
| 205 |
+
# input_image_or_video_path = None
|
| 206 |
+
if image_to_continue == None:
|
| 207 |
+
image_to_continue = [None] * len(prompts)
|
| 208 |
+
if image_to_end == None:
|
| 209 |
+
image_to_end = [None] * len(prompts)
|
| 210 |
+
|
| 211 |
+
for single_prompt, image_start, image_end in zip(prompts, image_to_continue, image_to_end) :
|
| 212 |
+
kwargs = {
|
| 213 |
+
"prompt" : single_prompt,
|
| 214 |
+
"negative_prompt" : negative_prompt,
|
| 215 |
+
"resolution" : resolution,
|
| 216 |
+
"video_length" : video_length,
|
| 217 |
+
"seed" : seed,
|
| 218 |
+
"num_inference_steps" : num_inference_steps,
|
| 219 |
+
"guidance_scale" : guidance_scale,
|
| 220 |
+
"flow_shift" : flow_shift,
|
| 221 |
+
"embedded_guidance_scale" : embedded_guidance_scale,
|
| 222 |
+
"repeat_generation" : repeat_generation,
|
| 223 |
+
"multi_images_gen_type" : multi_images_gen_type,
|
| 224 |
+
"tea_cache" : tea_cache,
|
| 225 |
+
"tea_cache_start_step_perc" : tea_cache_start_step_perc,
|
| 226 |
+
"loras_choices" : loras_choices,
|
| 227 |
+
"loras_mult_choices" : loras_mult_choices,
|
| 228 |
+
"image_prompt_type" : image_prompt_type,
|
| 229 |
+
"image_to_continue": image_start,
|
| 230 |
+
"image_to_end" : image_end,
|
| 231 |
+
"video_to_continue" : video_to_continue ,
|
| 232 |
+
"max_frames" : max_frames,
|
| 233 |
+
"RIFLEx_setting" : RIFLEx_setting,
|
| 234 |
+
"slg_switch" : slg_switch,
|
| 235 |
+
"slg_layers" : slg_layers,
|
| 236 |
+
"slg_start" : slg_start,
|
| 237 |
+
"slg_end" : slg_end,
|
| 238 |
+
"cfg_star_switch" : cfg_star_switch,
|
| 239 |
+
"cfg_zero_step" : cfg_zero_step,
|
| 240 |
+
"state" : state,
|
| 241 |
+
"image2video" : image2video
|
| 242 |
+
}
|
| 243 |
+
add_video_task(**kwargs)
|
| 244 |
+
|
| 245 |
+
gen = get_gen_info(state)
|
| 246 |
+
gen["prompts_max"] = len(prompts) + gen.get("prompts_max",0)
|
| 247 |
+
state["validate_success"] = 1
|
| 248 |
+
queue= gen.get("queue", [])
|
| 249 |
+
return update_queue_data(queue)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def add_video_task(**kwargs):
    """Queue a single generation task described by ``kwargs``.

    ``kwargs`` must contain at least ``state``, ``image2video``, ``prompt``,
    ``repeat_generation``, ``video_length``, ``num_inference_steps``,
    ``image_to_continue`` and ``image_to_end``. The task is appended to the
    ``queue`` list stored in the generation info of ``state`` and receives the
    next value of the module-wide ``task_id`` counter.

    Returns:
        Whatever ``update_queue_data`` returns for the updated queue.
    """
    global task_id
    # The queue is looked up before the counter is bumped so that a missing
    # "queue" entry does not consume a task id.
    task_queue = get_gen_info(kwargs["state"])["queue"]
    task_id += 1
    new_task_id = task_id
    first_image = kwargs["image_to_continue"]
    last_image = kwargs["image_to_end"]

    task_entry = {
        "id": new_task_id,
        "image2video": kwargs["image2video"],
        # Snapshot of every argument the task was created with.
        "params": kwargs.copy(),
        "repeats": kwargs["repeat_generation"],
        "length": kwargs["video_length"],
        "steps": kwargs["num_inference_steps"],
        "prompt": kwargs["prompt"],
        "start_image_data": first_image,
        "end_image_data": last_image,
        # Thumbnails used when the queue is rendered as a table.
        "start_image_data_base64": pil_to_base64_uri(first_image, format="jpeg", quality=70),
        "end_image_data_base64": pil_to_base64_uri(last_image, format="jpeg", quality=70),
    }
    task_queue.append(task_entry)
    return update_queue_data(task_queue)
|
| 278 |
+
|
| 279 |
+
def move_up(queue, selected_indices):
    """Move the selected queued task one position towards the front.

    Args:
        queue: Task list. Its first entry is not displayed by
            ``get_queue_table``, so displayed row ``i`` is ``queue[i + 1]``.
        selected_indices: Selection reported by the UI; only its first item is
            used, and it may itself be a ``[row, col]`` list.

    Returns:
        Whatever ``update_queue_data`` returns for the (possibly reordered)
        queue.
    """
    if not selected_indices:
        return update_queue_data(queue)
    row = selected_indices[0]
    if isinstance(row, list):
        row = row[0]
    # Translate the displayed row into its position inside ``queue``.
    pos = int(row) + 1
    with lock:
        # Row 0 is already first among the displayed tasks. The upper bound
        # protects against a stale selection: the queue may have shrunk since
        # the table was rendered, which previously raised IndexError.
        if 1 < pos < len(queue):
            queue[pos - 1], queue[pos] = queue[pos], queue[pos - 1]
    return update_queue_data(queue)
|
| 291 |
|
| 292 |
+
def move_down(queue, selected_indices):
    """Move the selected queued task one position towards the back.

    Only the first item of ``selected_indices`` is used (it may itself be a
    list, in which case its first element is the row). Displayed rows are
    offset by one relative to ``queue`` because the first queue entry is not
    displayed. Nothing is swapped when the task is already last.

    Returns:
        Whatever ``update_queue_data`` returns for the queue.
    """
    if not selected_indices or len(selected_indices) == 0:
        return update_queue_data(queue)

    picked = selected_indices[0]
    row = picked[0] if isinstance(picked, list) else picked
    pos = int(row) + 1  # displayed row -> position inside ``queue``

    with lock:
        has_successor = pos < len(queue) - 1
        if has_successor:
            queue[pos], queue[pos + 1] = queue[pos + 1], queue[pos]
    return update_queue_data(queue)
|
| 304 |
|
| 305 |
+
def remove_task(queue, selected_indices):
    """Delete the selected queued task from ``queue``.

    Only the first item of ``selected_indices`` is used (it may itself be a
    list, in which case its first element is the row). Displayed rows are
    offset by one relative to ``queue`` because the first queue entry is not
    displayed.

    Returns:
        Whatever ``update_queue_data`` returns for the queue.
    """
    if not selected_indices or len(selected_indices) == 0:
        return update_queue_data(queue)

    picked = selected_indices[0]
    row = picked[0] if isinstance(picked, list) else picked
    pos = int(row) + 1  # displayed row -> position inside ``queue``

    with lock:
        if pos < len(queue):
            # NOTE(review): with the +1 offset this is only reachable for a
            # selection of -1; it looks like a leftover from before the offset.
            if pos == 0:
                wan_model._interrupt = True
            del queue[pos]
    return update_queue_data(queue)
|
| 318 |
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def get_queue_table(queue):
    """Build the table rows that describe the pending tasks of ``queue``.

    The first entry of ``queue`` is skipped; every other task becomes one row.

    Args:
        queue: List of task dicts. Each displayed task must have a ``prompt``
            string and may have ``repeats``, ``length``, ``steps``,
            ``start_image_data_base64`` and ``end_image_data_base64``.

    Returns:
        A list of rows ``[repeats, prompt_html, length, steps, start_thumbnail,
        end_thumbnail, "↑", "↓", "✖"]``. The list is empty when there is no
        task after the first one.
    """
    # Local import: the prompt is user text and is embedded in HTML below.
    from html import escape

    thumbnail_size = "50px"

    def thumbnail(uri, alt):
        # An absent or empty URI yields an empty cell.
        if not uri:
            return ""
        return f'<img src="{uri}" alt="{alt}" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'

    rows = []
    for item in queue[1:]:
        prompt = item['prompt']
        # Truncate the raw text first so an HTML entity is never cut in half.
        shortened = (prompt[:97] + '...') if len(prompt) > 100 else prompt
        # Escaping keeps quotes from terminating the title attribute and keeps
        # markup characters in the prompt from being interpreted as HTML.
        prompt_cell = f'<span title="{escape(prompt)}">{escape(shortened)}</span>'
        rows.append([
            item.get('repeats', "1"),
            prompt_cell,
            item.get('length'),
            item.get('steps'),
            thumbnail(item.get('start_image_data_base64'), "Start"),
            thumbnail(item.get('end_image_data_base64'), "End"),
            "↑",
            "↓",
            "✖",
        ])
    return rows
|
| 379 |
+
def update_queue_data(queue):
    """Return the DataFrame update that reflects the current ``queue``.

    The rows come from ``get_queue_table``. When there is nothing to show the
    table is hidden; otherwise it is made visible with the new rows.
    """
    rows = get_queue_table(queue)
    if rows:
        return gr.DataFrame(value=rows, visible= True)
    return gr.DataFrame(visible=False)
|
| 391 |
|
| 392 |
def create_html_progress_bar(percentage=0.0, text="Idle", is_idle=True):
|
| 393 |
bar_class = "progress-bar-custom idle" if is_idle else "progress-bar-custom"
|
|
|
|
| 402 |
"""
|
| 403 |
return html
|
| 404 |
|
| 405 |
+
# def refresh_progress():
|
| 406 |
+
# global current_task_id, progress_tracker, last_status_string
|
| 407 |
+
# task_id_to_check = current_task_id
|
| 408 |
+
# is_idle = True
|
| 409 |
+
# status_string = "Starting..."
|
| 410 |
+
# progress_percent = 0.0
|
| 411 |
+
# html_content = ""
|
| 412 |
+
|
| 413 |
+
# with tracker_lock:
|
| 414 |
+
# with lock:
|
| 415 |
+
# processing_or_queued = any(item['state'] in ["Processing", "Queued"] for item in queue)
|
| 416 |
+
# if task_id_to_check is not None:
|
| 417 |
+
# progress_data = progress_tracker.get(task_id_to_check)
|
| 418 |
+
# if progress_data:
|
| 419 |
+
# is_idle = False
|
| 420 |
+
# current_step = progress_data.get('current_step', 0)
|
| 421 |
+
# total_steps = progress_data.get('total_steps', 0)
|
| 422 |
+
# status = progress_data.get('status', "Starting...")
|
| 423 |
+
# repeats = progress_data.get("repeats", 1)
|
| 424 |
+
|
| 425 |
+
# if total_steps > 0:
|
| 426 |
+
# progress_float = min(1.0, max(0.0, float(current_step) / float(total_steps)))
|
| 427 |
+
# progress_percent = progress_float * 100
|
| 428 |
+
# status_string = f"{status} [{repeats}] - {progress_percent:.1f}% complete ({current_step}/{total_steps} steps)"
|
| 429 |
+
# else:
|
| 430 |
+
# progress_percent = 0.0
|
| 431 |
+
# status_string = f"{status} [{repeats}] - Initializing..."
|
| 432 |
+
# html_content = create_html_progress_bar(progress_percent, status_string, is_idle)
|
| 433 |
+
# return gr.update(value=html_content)
|
| 434 |
|
| 435 |
def update_generation_status(html_content):
|
| 436 |
if(html_content):
|
|
|
|
| 832 |
|
| 833 |
only_allow_edit_in_advanced = False
|
| 834 |
lora_preselected_preset = args.lora_preset
|
|
|
|
| 835 |
# if args.fast : #or args.fastest
|
| 836 |
# transformer_filename_t2v = transformer_choices_t2v[2]
|
| 837 |
# attention_mode="sage2" if "sage2" in attention_modes_supported else "sage"
|
|
|
|
| 844 |
lock_ui_compile = True
|
| 845 |
|
| 846 |
model_filename = ""
|
|
|
|
| 847 |
#attention_mode="sage"
|
| 848 |
#attention_mode="sage2"
|
| 849 |
#attention_mode="flash"
|
|
|
|
| 852 |
# compile = "transformer"
|
| 853 |
|
| 854 |
def preprocess_loras(sd):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 855 |
first = next(iter(sd), None)
|
| 856 |
if first == None:
|
| 857 |
return sd
|
| 858 |
+
if not first.startswith("lora_unet_"):
|
| 859 |
return sd
|
| 860 |
+
new_sd = {}
|
| 861 |
print("Converting Lora Safetensors format to Lora Diffusers format")
|
| 862 |
alphas = {}
|
| 863 |
repl_list = ["cross_attn", "self_attn", "ffn"]
|
|
|
|
| 936 |
def sanitize_file_name(file_name, rep =""):
    """Return ``file_name`` with path and shell-sensitive characters replaced.

    Each of ``/ \\ : | ? < > "`` is replaced, in that order, by ``rep``
    (removed by default).
    """
    for forbidden in ("/", "\\", ":", "|", "?", "<", ">", "\""):
        file_name = file_name.replace(forbidden, rep)
    return file_name
|
| 938 |
|
| 939 |
+
def extract_preset(image2video, lset_name, loras):
|
| 940 |
loras_choices = []
|
| 941 |
loras_choices_files = []
|
| 942 |
loras_mult_choices = ""
|
| 943 |
prompt =""
|
| 944 |
full_prompt =""
|
| 945 |
lset_name = sanitize_file_name(lset_name)
|
| 946 |
+
lora_dir = get_lora_dir(image2video)
|
| 947 |
if not lset_name.endswith(".lset"):
|
| 948 |
lset_name_filename = os.path.join(lora_dir, lset_name + ".lset" )
|
| 949 |
else:
|
|
|
|
| 1014 |
if not os.path.isfile(os.path.join(lora_dir, lora_preselected_preset + ".lset")):
|
| 1015 |
raise Exception(f"Unknown preset '{lora_preselected_preset}'")
|
| 1016 |
default_lora_preset = lora_preselected_preset
|
| 1017 |
+
default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, _ , error = extract_preset(i2v, default_lora_preset, loras)
|
| 1018 |
if len(error) > 0:
|
| 1019 |
print(error[:200])
|
| 1020 |
return loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset
|
|
|
|
| 1101 |
# kwargs["partialPinning"] = True
|
| 1102 |
elif profile == 3:
|
| 1103 |
kwargs["budgets"] = { "*" : "70%" }
|
| 1104 |
+
offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = quantizeTransformer, loras = "transformer", coTenantsMap= {}, **kwargs)
|
| 1105 |
if len(args.gpu) > 0:
|
| 1106 |
torch.set_default_device(args.gpu)
|
| 1107 |
|
|
|
|
| 1178 |
if gen_in_progress:
|
| 1179 |
yield "<DIV ALIGN=CENTER>Unable to change config when a generation is in progress</DIV>"
|
| 1180 |
return
|
| 1181 |
+
global offloadobj, wan_model, server_config, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset, loras_presets
|
| 1182 |
server_config = {"attention_mode" : attention_choice,
|
| 1183 |
"transformer_filename": transformer_choices_t2v[transformer_t2v_choice],
|
| 1184 |
"transformer_filename_i2v": transformer_choices_i2v[transformer_i2v_choice],
|
|
|
|
| 1243 |
final_frames = (final_frames * 255).astype(np.uint8)
|
| 1244 |
ImageSequenceClip(list(final_frames), fps=fps).write_videofile(output_path, verbose= False, logger = None)
|
| 1245 |
|
| 1246 |
+
|
| 1247 |
+
def get_gen_info(state):
    """Return the generation-info dict stored under ``state["gen"]``.

    When the entry is missing (or ``None``) a fresh dict is created, stored in
    ``state`` and returned, so callers always share the same mutable object.
    """
    gen = state.get("gen")
    if gen is None:
        gen = state["gen"] = {}
    return gen
|
| 1253 |
+
|
| 1254 |
+
def build_callback(state, pipe, progress, status, num_inference_steps):
    """Create the progress callback used during a generation.

    Args:
        state: Session state dict; its generation info is read through
            ``get_gen_info`` and its ``"refresh"`` entry is updated.
        pipe: Not used by the callback (the only reference is commented out).
        progress: Callable invoked with the computed progress arguments.
        status: Shadowed inside the callback by ``gen["progress_status"]``.
        num_inference_steps: Total number of steps, used to detect the last
            step and as the progress total.

    Returns:
        ``callback(step_idx, force_refresh, read_state=False)``.
    """
    def callback(step_idx, force_refresh, read_state = False):
        gen = get_gen_info(state)
        refresh_id = gen.get("refresh", -1)
        if force_refresh or step_idx >= 0:
            pass
        else:
            # Neither forced nor a real step: only continue when the
            # generation info carries a refresh id newer than the one already
            # recorded in ``state``.
            refresh_id = gen.get("refresh", -1)
            if refresh_id < 0:
                return
            UI_refresh = state.get("refresh", 0)
            if UI_refresh >= refresh_id:
                return

        # The status text always comes from the generation info, not from the
        # ``status`` argument of ``build_callback``.
        status = gen["progress_status"]
        state["refresh"] = refresh_id
        if read_state:
            # Reuse the phase and step recorded by an earlier call.
            phase, step_idx = gen["progress_phase"]
        else:
            step_idx += 1
            if gen.get("abort", False):
                # pipe._interrupt = True
                phase = " - Aborting"
            elif step_idx == num_inference_steps:
                phase = " - VAE Decoding"
            else:
                phase = " - Denoising"
            gen["progress_phase"] = (phase, step_idx)
        status_msg = status + phase
        if step_idx >= 0:
            progress_args = [(step_idx , num_inference_steps) , status_msg , num_inference_steps]
        else:
            progress_args = [0, status_msg]

        progress(*progress_args)
        # Kept so the same arguments can be read back from the generation info.
        gen["progress_args"] = progress_args

    return callback
|
| 1292 |
+
def abort_generation(state):
    """Request that the generation currently in progress be aborted.

    When the generation info of ``state`` has an ``"in_progress"`` entry, the
    ``"abort"`` flag is set, ``"extra_orders"`` is reset to 0, the loaded
    model (if any) is told to interrupt and the user is notified.

    Returns:
        A ``(message, gr.Button)`` pair: the notification text with a disabled
        button when an abort was requested, otherwise an empty string with an
        enabled button.
    """
    gen = get_gen_info(state)
    if "in_progress" not in gen:
        return "", gr.Button(interactive= True)

    gen["abort"] = True
    gen["extra_orders"] = 0
    # The global model is deleted while a model is being (re)loaded and is set
    # to None elsewhere in this file, so it may be unbound or None here; the
    # "abort" flag above still records the request in that case.
    model = globals().get("wan_model")
    if model is not None:
        model._interrupt = True
    msg = "Processing Request to abort Current Generation"
    gr.Info(msg)
    return msg, gr.Button(interactive= False)
|
| 1304 |
+
|
| 1305 |
+
def is_gen_location(state):
    """Tell whether ``state`` matches the recorded generation location.

    Returns:
        ``None`` when the generation info has no ``"location"``; otherwise the
        result of comparing ``state["image2video"]`` with that location.
    """
    location = get_gen_info(state).get("location", None)
    return None if location is None else state["image2video"] == location
|
| 1312 |
+
|
| 1313 |
+
|
| 1314 |
+
def refresh_gallery(state, msg):
    """Refresh the gallery, the current-task panel and the queue table.

    Args:
        state: Session state dict holding the generation info.
        msg: Latest status message; stored as ``gen["last_msg"]`` when
            ``is_gen_location(state)`` is truthy.

    Returns:
        A 7-tuple of Gradio updates: gallery, HTML panel describing the task
        being generated, two buttons, a row, the queue table (from
        ``update_queue_data``) and a button whose interactivity reflects
        whether an abort has already been requested.
    """
    # Local import: the prompt is user text and is embedded in HTML below.
    from html import escape

    gen = get_gen_info(state)

    if is_gen_location(state):
        gen["last_msg"] = msg
    file_list = gen.get("file_list", None)
    choice = gen.get("selected",0)
    in_progress = "in_progress" in gen
    if in_progress and gen.get("last_selected", True):
        # Follow the newest file; ``file_list`` may still be None here.
        choice = max(len(file_list or []) - 1, 0)

    queue = gen.get("queue", [])
    abort_interactive = not gen.get("abort", False)
    if not in_progress or len(queue) == 0:
        return gr.Gallery(selected_index=choice, value = file_list), gr.HTML("", visible= False), gr.Button(visible=True), gr.Button(visible=False), gr.Row(visible=False), update_queue_data(queue), gr.Button(interactive= abort_interactive)

    # The first queue entry is the task whose prompt is shown in the panel.
    task = queue[0]
    thumbnail_size = "100px"

    def thumbnail(uri, alt):
        if not uri:
            return ""
        return f'<img src="{uri}" alt="{alt}" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'

    start_img_md = ""
    end_img_md = ""
    if task.get('image2video'):
        start_img_md = thumbnail(task.get('start_image_data_base64'), "Start")
        end_img_md = thumbnail(task.get('end_image_data_base64'), "End")

    # Escape the prompt so markup characters in it are shown literally.
    cells = ["<TD width=100%>" + escape(task["prompt"]) + "</TD>"]
    if start_img_md != "":
        cells.append("<TD>" + start_img_md + "</TD>")
    if end_img_md != "":
        cells.append("<TD>" + end_img_md + "</TD>")
    panel_html = (
        "<STYLE> #PINFO, #PINFO th, #PINFO td {border: 1px solid #CCCCCC;background-color:#FFFFFF;}</STYLE><TABLE WIDTH=100% ID=PINFO ><TR>"
        + "".join(cells)
        + "</TR></TABLE>"
    )
    html_output = gr.HTML(panel_html, visible= True)
    return gr.Gallery(selected_index=choice, value = file_list), html_output, gr.Button(visible=False), gr.Button(visible=True), gr.Row(visible=True), update_queue_data(queue), gr.Button(interactive= abort_interactive)
|
| 1356 |
+
|
| 1357 |
+
|
| 1358 |
+
|
| 1359 |
+
def finalize_generation(state):
    """Reset the generation flags of ``state`` after a generation has ended.

    Removes the ``"in_progress"`` marker from the generation info, resets
    ``"extra_orders"`` to 0, clears the module-level ``gen_in_progress`` flag
    and picks which gallery item should be selected.

    Returns:
        A 6-tuple of Gradio updates: the gallery with its selected index, an
        interactive button, a visible button, a hidden button, a hidden column
        and a hidden, emptied HTML component.
    """
    gen = get_gen_info(state)
    choice = gen.get("selected",0)
    if "in_progress" in gen:
        del gen["in_progress"]
    if gen.get("last_selected", True):
        file_list = gen.get("file_list", [])
        # NOTE(review): this is -1 when file_list is empty, whereas
        # refresh_gallery clamps the same expression to 0 - confirm intended.
        choice = len(file_list) - 1


    gen["extra_orders"] = 0
    # NOTE(review): purpose of this short pause is not visible here; presumably
    # it lets pending UI updates settle - confirm.
    time.sleep(0.2)
    global gen_in_progress
    gen_in_progress = False
    return gr.Gallery(selected_index=choice), gr.Button(interactive= True), gr.Button(visible= True), gr.Button(visible= False), gr.Column(visible= False), gr.HTML(visible= False, value="")
|
| 1374 |
+
|
| 1375 |
|
| 1376 |
def refresh_gallery_on_trigger(state):
    """Return a gallery update when the generation info requests a refresh.

    Reads the "update_gallery" flag from the generation info of ``state``.
    When the flag is set it is cleared and an update carrying the current
    "file_list" is returned.

    Returns:
        ``gr.update(value=<file list>)`` when a refresh was pending, otherwise
        an empty ``gr.update()`` so the bound component is left untouched.
    """
    gen = get_gen_info(state)

    if not gen.get("update_gallery", False):
        # Return an explicit no-op update: falling off the end of the function
        # would hand ``None`` to Gradio as the component's new value.
        return gr.update()

    gen["update_gallery"] = False
    return gr.update(value=gen.get("file_list", []))
|
| 1382 |
|
| 1383 |
def select_video(state , event_data: gr.EventData):
    """Record which gallery item was selected in the generation info.

    Stores the selected index under "selected" and sets "last_selected" to
    whether that index is the final entry of the current file list. Nothing
    is recorded when the event carries no data.
    """
    payload = event_data._data
    gen = get_gen_info(state)

    if payload is None:
        return

    index = payload.get("index", 0)
    known_files = gen.get("file_list", [])
    gen["last_selected"] = (index + 1) >= len(known_files)
    gen["selected"] = index
    return
|
| 1393 |
|
| 1394 |
def expand_slist(slist, num_inference_steps ):
|
|
|
|
| 1420 |
|
| 1421 |
def generate_video(
|
| 1422 |
task_id,
|
| 1423 |
+
progress,
|
| 1424 |
prompt,
|
| 1425 |
negative_prompt,
|
| 1426 |
resolution,
|
|
|
|
| 1454 |
):
|
| 1455 |
|
| 1456 |
global wan_model, offloadobj, reload_needed, last_model_type
|
| 1457 |
+
gen = get_gen_info(state)
|
| 1458 |
+
|
| 1459 |
+
file_list = gen["file_list"]
|
| 1460 |
+
prompt_no = gen["prompt_no"]
|
| 1461 |
+
|
| 1462 |
file_model_needed = model_needed(image2video)
|
| 1463 |
+
# queue = gen.get("queue", [])
|
| 1464 |
+
# with lock:
|
| 1465 |
+
# queue_not_empty = len(queue) > 0
|
| 1466 |
+
# if(last_model_type != image2video and (queue_not_empty or server_config.get("reload_model",1) == 2) and (file_model_needed != model_filename or reload_needed)):
|
| 1467 |
+
if file_model_needed != model_filename or reload_needed:
|
| 1468 |
del wan_model
|
| 1469 |
if offloadobj is not None:
|
| 1470 |
offloadobj.release()
|
| 1471 |
del offloadobj
|
| 1472 |
gc.collect()
|
| 1473 |
+
yield f"Loading model {get_model_name(file_model_needed)}..."
|
| 1474 |
wan_model, offloadobj, trans = load_models(image2video)
|
| 1475 |
+
yield f"Model loaded"
|
| 1476 |
reload_needed= False
|
| 1477 |
|
| 1478 |
if wan_model == None:
|
| 1479 |
+
gr.Info("Unable to generate a Video while a new configuration is being applied.")
|
| 1480 |
if attention_mode == "auto":
|
| 1481 |
attn = get_auto_attention()
|
| 1482 |
elif attention_mode in attention_modes_supported:
|
|
|
|
| 1485 |
gr.Info(f"You have selected attention mode '{attention_mode}'. However it is not installed or supported on your system. You should either install it or switch to the default 'sdpa' attention.")
|
| 1486 |
return
|
| 1487 |
|
| 1488 |
+
|
| 1489 |
+
|
| 1490 |
+
if not image2video:
|
| 1491 |
+
width, height = resolution.split("x")
|
| 1492 |
+
width, height = int(width), int(height)
|
| 1493 |
|
| 1494 |
if slg_switch == 0:
|
| 1495 |
slg_layers = None
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1496 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1497 |
offload.shared_state["_attention"] = attn
|
| 1498 |
|
| 1499 |
# VAE Tiling
|
|
|
|
| 1517 |
|
| 1518 |
trans = wan_model.model
|
| 1519 |
|
|
|
|
|
|
|
| 1520 |
temp_filename = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1521 |
|
| 1522 |
loras = state["loras"]
|
| 1523 |
if len(loras) > 0:
|
|
|
|
| 1561 |
raise gr.Error("Error while loading Loras: " + ", ".join(error_files))
|
| 1562 |
seed = None if seed == -1 else seed
|
| 1563 |
# negative_prompt = "" # not applicable in the inference
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1564 |
|
| 1565 |
enable_RIFLEx = RIFLEx_setting == 0 and video_length > (6* 16) or RIFLEx_setting == 1
|
| 1566 |
# VAE Tiling
|
|
|
|
| 1597 |
if seed == None or seed <0:
|
| 1598 |
seed = random.randint(0, 999999999)
|
| 1599 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1600 |
global save_path
|
| 1601 |
os.makedirs(save_path, exist_ok=True)
|
| 1602 |
video_no = 0
|
| 1603 |
abort = False
|
|
|
|
|
|
|
|
|
|
| 1604 |
gc.collect()
|
| 1605 |
torch.cuda.empty_cache()
|
| 1606 |
wan_model._interrupt = False
|
| 1607 |
+
gen["abort"] = False
|
| 1608 |
+
gen["prompt"] = prompt
|
| 1609 |
+
repeat_no = 0
|
| 1610 |
+
extra_generation = 0
|
| 1611 |
+
while True:
|
| 1612 |
+
extra_generation += gen.get("extra_orders",0)
|
| 1613 |
+
gen["extra_orders"] = 0
|
| 1614 |
+
total_generation = repeat_generation + extra_generation
|
| 1615 |
+
gen["total_generation"] = total_generation
|
| 1616 |
+
if abort or repeat_no >= total_generation:
|
| 1617 |
+
break
|
| 1618 |
+
repeat_no +=1
|
| 1619 |
+
gen["repeat_no"] = repeat_no
|
| 1620 |
+
prompts_max = gen["prompts_max"]
|
| 1621 |
+
status = get_generation_status(prompt_no, prompts_max, repeat_no, total_generation)
|
| 1622 |
+
|
| 1623 |
+
yield status
|
| 1624 |
+
|
| 1625 |
+
gen["progress_status"] = status
|
| 1626 |
+
gen["progress_phase"] = (" - Encoding Prompt", -1 )
|
| 1627 |
+
callback = build_callback(state, trans, progress, status, num_inference_steps)
|
| 1628 |
+
progress_args = [0, status + " - Encoding Prompt"]
|
| 1629 |
+
progress(*progress_args )
|
| 1630 |
+
gen["progress_args"] = progress_args
|
| 1631 |
+
|
| 1632 |
try:
|
| 1633 |
+
start_time = time.time()
|
| 1634 |
+
# with tracker_lock:
|
| 1635 |
+
# progress_tracker[task_id] = {
|
| 1636 |
+
# 'current_step': 0,
|
| 1637 |
+
# 'total_steps': num_inference_steps,
|
| 1638 |
+
# 'start_time': start_time,
|
| 1639 |
+
# 'last_update': start_time,
|
| 1640 |
+
# 'repeats': repeat_generation, # f"{video_no}/{repeat_generation}",
|
| 1641 |
+
# 'status': "Encoding Prompt"
|
| 1642 |
+
# }
|
| 1643 |
if trans.enable_teacache:
|
| 1644 |
trans.teacache_counter = 0
|
| 1645 |
trans.num_steps = num_inference_steps
|
| 1646 |
+
trans.teacache_skipped_steps = 0
|
| 1647 |
trans.previous_residual_uncond = None
|
| 1648 |
trans.previous_residual_cond = None
|
| 1649 |
|
|
|
|
| 1651 |
if image2video:
|
| 1652 |
samples = wan_model.generate(
|
| 1653 |
prompt,
|
| 1654 |
+
image_to_continue,
|
| 1655 |
+
image_to_end if image_to_end != None else None,
|
| 1656 |
frame_num=(video_length // 4)* 4 + 1,
|
| 1657 |
max_area=MAX_AREA_CONFIGS[resolution],
|
| 1658 |
shift=flow_shift,
|
|
|
|
| 1670 |
slg_end = slg_end/100,
|
| 1671 |
cfg_star_switch = cfg_star_switch,
|
| 1672 |
cfg_zero_step = cfg_zero_step,
|
| 1673 |
+
add_frames_for_end_image = not "Fun" in transformer_filename_i2v,
|
| 1674 |
)
|
| 1675 |
else:
|
| 1676 |
samples = wan_model.generate(
|
|
|
|
| 1694 |
cfg_zero_step = cfg_zero_step,
|
| 1695 |
)
|
| 1696 |
except Exception as e:
|
|
|
|
| 1697 |
if temp_filename!= None and os.path.isfile(temp_filename):
|
| 1698 |
os.remove(temp_filename)
|
| 1699 |
offload.last_offload_obj.unload_all()
|
|
|
|
| 1716 |
if any( keyword in frame.name for keyword in keyword_list):
|
| 1717 |
VRAM_crash = True
|
| 1718 |
break
|
| 1719 |
+
|
| 1720 |
+
_ , exc_value, exc_traceback = sys.exc_info()
|
| 1721 |
+
|
| 1722 |
state["prompt"] = ""
|
| 1723 |
if VRAM_crash:
|
| 1724 |
+
new_error = "The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames."
|
| 1725 |
else:
|
| 1726 |
+
new_error = gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
|
| 1727 |
+
tb = traceback.format_exc().split('\n')[:-2]
|
| 1728 |
+
print('\n'.join(tb))
|
| 1729 |
+
raise gr.Error(new_error, print_exception= False)
|
| 1730 |
+
|
| 1731 |
finally:
|
| 1732 |
+
pass
|
| 1733 |
+
# with tracker_lock:
|
| 1734 |
+
# if task_id in progress_tracker:
|
| 1735 |
+
# del progress_tracker[task_id]
|
| 1736 |
|
| 1737 |
if trans.enable_teacache:
|
| 1738 |
print(f"Teacache Skipped Steps:{trans.teacache_skipped_steps}/{num_inference_steps}" )
|
|
|
|
| 1749 |
end_time = time.time()
|
| 1750 |
abort = True
|
| 1751 |
state["prompt"] = ""
|
| 1752 |
+
# yield f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s"
|
| 1753 |
else:
|
| 1754 |
sample = samples.cpu()
|
| 1755 |
# video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")
|
|
|
|
| 1768 |
normalize=True,
|
| 1769 |
value_range=(-1, 1))
|
| 1770 |
|
| 1771 |
+
configs = get_settings_dict(state, image2video, prompt, 0 if image_to_end == None else 1 , video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
|
| 1772 |
loras_mult_choices, tea_cache , tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end, cfg_star_switch, cfg_zero_step)
|
| 1773 |
|
| 1774 |
metadata_choice = server_config.get("metadata_choice","metadata")
|
|
|
|
| 1790 |
|
| 1791 |
if temp_filename!= None and os.path.isfile(temp_filename):
|
| 1792 |
os.remove(temp_filename)
|
|
|
|
| 1793 |
offload.unload_loras_from_model(trans)
|
| 1794 |
|
| 1795 |
+
def prepare_generate_video(state):
    """Choose button/column visibility depending on whether validation passed.

    Returns:
        Three component updates (button, button, column). When
        ``state["validate_success"]`` equals 1 the first button is hidden and
        the second button and the column are shown; otherwise the first button
        is shown and the other two are hidden.
    """
    validated = state.get("validate_success",0) == 1
    return (
        gr.Button(visible= not validated),
        gr.Button(visible= validated),
        gr.Column(visible= validated),
    )
|
| 1800 |
+
|
| 1801 |
+
|
| 1802 |
+
def wait_tasks_done(state, progress=gr.Progress()):
    """Relay status messages and progress while a generation is in progress.

    Generator: first yields the stored "last_msg" (if any). When
    ``is_gen_location(state)`` returns ``None`` or a truthy value it stops
    there. Otherwise it polls the generation info every 0.5 seconds, yielding
    each new non-empty "last_msg", forwarding "progress_args" to ``progress``,
    and stopping once "in_progress" is no longer set.
    """
    gen = get_gen_info(state)
    gen_location = is_gen_location(state)

    last_msg = gen.get("last_msg", "")
    if last_msg:
        yield last_msg

    if gen_location is None or gen_location:
        return gr.Text()

    while True:
        current_msg = gen.get("last_msg", "")
        if current_msg and current_msg != last_msg:
            yield current_msg
            last_msg = current_msg

        latest_progress = gen.get("progress_args", None)
        if latest_progress is not None:
            progress(*latest_progress)

        if not gen.get("in_progress", False):
            return
        time.sleep(0.5)
|
| 1829 |
+
|
| 1830 |
+
|
| 1831 |
+
|
| 1832 |
+
def process_tasks(state, progress=gr.Progress()):
    """Run every queued generation task in order, yielding status strings.

    Generator used as a Gradio event handler. It does nothing when the queue
    stored in the generation info is empty. Otherwise it:

    * records ``state["image2video"]`` under "location";
    * trims the stored file list to its last ``clear_file_list`` entries
      (server configuration), or empties it when that setting is not
      positive, shifting the stored selection accordingly;
    * raises the module-level ``gen_in_progress`` flag and the "in_progress"
      marker;
    * runs ``generate_video`` for the task at the head of the queue, relaying
      each status it yields, then removes that task and moves on to the next;
    * finally yields a summary line with the total elapsed time.

    If ``generate_video`` raises, the whole queue is cleared and the exception
    propagates unchanged.
    """
    global gen_in_progress

    gen = get_gen_info(state)
    queue = gen.get("queue", [])
    if not queue:
        return

    gen["location"] = state["image2video"]

    # "clear_file_list" holds how many of the most recent files to keep.
    keep_last = server_config.get("clear_file_list", 0)
    file_list = gen.get("file_list", [])
    if keep_last > 0:
        files_removed = max(len(file_list) - keep_last, 0)
        choice = max(gen.get("selected", 0) - files_removed, 0)
        file_list = file_list[files_removed:]
    else:
        file_list = []
        choice = 0
    gen["selected"] = choice
    gen["file_list"] = file_list

    start_time = time.time()

    gen_in_progress = True
    gen["in_progress"] = True

    prompt_no = 0
    while queue:
        prompt_no += 1
        gen["prompt_no"] = prompt_no
        task = queue[0]
        iterator = iter(generate_video(task["id"], progress, **task["params"]))

        while True:
            try:
                # "#" is the sentinel returned once the task generator is exhausted.
                status = next(iterator, "#")
            except BaseException:
                # A failed task invalidates everything still waiting.
                queue.clear()
                raise
            if status == "#":
                break
            yield status

        if gen.get("abort"):
            # Presumably set by the abort handler (not visible here); an
            # aborted run drops all remaining tasks, as it did before.
            queue.clear()
        else:
            # Normal completion: drop only the finished task so the remaining
            # queued prompts still run. Previously the queue was wiped here
            # as well, because the success flag was set after the sentinel check.
            queue[:] = [item for item in queue if item["id"] != task["id"]]

    gen["prompts_max"] = 0
    gen["prompt"] = ""
    end_time = time.time()
    if gen.get("abort"):
        yield f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s"
    else:
        yield f"Total Generation Time: {end_time-start_time:.1f}s"
|
| 1893 |
+
|
| 1894 |
+
|
| 1895 |
+
def get_generation_status(prompt_no, prompts_max, repeat_no, repeat_max):
    """Build the short status label for the generation currently running.

    The prompt counter is included unless ``prompts_max`` is 1 and the sample
    counter unless ``repeat_max`` is 1; when both are 1 the label is "Video".

    >>> get_generation_status(2, 5, 3, 4)
    'Prompt 2/5, Sample 3/4'
    """
    parts = []
    if prompts_max != 1:
        parts.append(f"Prompt {prompt_no}/{prompts_max}")
    if repeat_max != 1:
        parts.append(f"Sample {repeat_no}/{repeat_max}")
    return ", ".join(parts) if parts else "Video"
|
| 1906 |
+
|
| 1907 |
+
|
| 1908 |
+
# Module-level counter; get_new_refresh_id() increments it on every call.
refresh_id = 0


def get_new_refresh_id():
    """Increment the module-level ``refresh_id`` counter and return its new value."""
    global refresh_id
    refresh_id = refresh_id + 1
    return refresh_id
|
| 1914 |
+
|
| 1915 |
+
def update_status(state):
    """Recompute the stored progress label from the current generation counters.

    Reads "prompt_no", "prompts_max", "total_generation" and "repeat_no" from
    the generation info of ``state``, stores the resulting label under
    "progress_status" and stamps a fresh "refresh" id.

    Nothing is changed when the counters have not been published yet.
    """
    gen = get_gen_info(state)

    # These keys are only written once a generation is under way; this handler
    # can fire earlier, so bail out quietly instead of raising KeyError.
    if any(key not in gen for key in ("prompt_no", "total_generation", "repeat_no")):
        return

    status = get_generation_status(
        gen["prompt_no"],
        gen.get("prompts_max", 0),
        gen["repeat_no"],
        gen["total_generation"],
    )
    gen["progress_status"] = status
    gen["refresh"] = get_new_refresh_id()
|
| 1924 |
+
|
| 1925 |
+
|
| 1926 |
+
def one_more_sample(state):
    """Register a request for one additional sample of the current prompt.

    Increments "extra_orders" in the generation info of ``state``. When a
    generation is in progress and its counters are available, it also
    refreshes "progress_status" with the enlarged total, stamps a new
    "refresh" id and shows an informational toast.

    Returns:
        ``state``, unchanged as an object (its generation info is mutated).
    """
    gen = get_gen_info(state)
    extra_orders = gen.get("extra_orders", 0) + 1
    gen["extra_orders"] = extra_orders

    if not gen.get("in_progress", False):
        return state

    # The counters are published only after the generation loop has started;
    # the order above is already recorded, so skip the label refresh rather
    # than raising KeyError when they are still missing.
    if any(key not in gen for key in ("prompt_no", "total_generation", "repeat_no")):
        return state

    total_generation = gen["total_generation"] + extra_orders
    status = get_generation_status(
        gen["prompt_no"],
        gen.get("prompts_max", 0),
        gen["repeat_no"],
        total_generation,
    )

    gen["progress_status"] = status
    gen["refresh"] = get_new_refresh_id()
    gr.Info(f"An extra sample generation is planned for a total of {total_generation} videos for this prompt")

    return state
|
| 1946 |
|
| 1947 |
def get_new_preset_msg(advanced = True):
|
| 1948 |
if advanced:
|
|
|
|
| 1994 |
|
| 1995 |
|
| 1996 |
lset_name_filename = lset_name + ".lset"
|
| 1997 |
+
full_lset_name_filename = os.path.join(get_lora_dir(state["image2video"]), lset_name_filename)
|
| 1998 |
|
| 1999 |
with open(full_lset_name_filename, "w", encoding="utf-8") as writer:
|
| 2000 |
writer.write(json.dumps(lset, indent=4))
|
|
|
|
| 2011 |
|
| 2012 |
def delete_lset(state, lset_name):
|
| 2013 |
loras_presets = state["loras_presets"]
|
| 2014 |
+
lset_name_filename = os.path.join( get_lora_dir(state["image2video"]), sanitize_file_name(lset_name) + ".lset" )
|
| 2015 |
if len(lset_name) > 0 and lset_name != get_new_preset_msg(True) and lset_name != get_new_preset_msg(False):
|
| 2016 |
if not os.path.isfile(lset_name_filename):
|
| 2017 |
raise gr.Error(f"Preset '{lset_name}' not found ")
|
|
|
|
| 2032 |
def refresh_lora_list(state, lset_name, loras_choices):
|
| 2033 |
loras_names = state["loras_names"]
|
| 2034 |
prev_lora_names_selected = [ loras_names[int(i)] for i in loras_choices]
|
| 2035 |
+
image2video= state["image2video"]
|
| 2036 |
+
loras, loras_names, loras_presets, _, _, _, _ = setup_loras(image2video, None, get_lora_dir(image2video), lora_preselected_preset, None)
|
| 2037 |
state["loras"] = loras
|
| 2038 |
state["loras_names"] = loras_names
|
| 2039 |
state["loras_presets"] = loras_presets
|
|
|
|
| 2073 |
gr.Info("Please choose a preset in the list or create one")
|
| 2074 |
else:
|
| 2075 |
loras = state["loras"]
|
| 2076 |
+
loras_choices, loras_mult_choices, preset_prompt, full_prompt, error = extract_preset(state["image2video"], lset_name, loras)
|
| 2077 |
if len(error) > 0:
|
| 2078 |
gr.Info(error)
|
| 2079 |
else:
|
|
|
|
| 2274 |
if state.get("validate_success",0) != 1:
|
| 2275 |
return
|
| 2276 |
|
| 2277 |
+
image2video = state["image2video"]
|
| 2278 |
+
ui_defaults = get_settings_dict(state, image2video, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
|
| 2279 |
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step)
|
| 2280 |
|
| 2281 |
+
defaults_filename = get_settings_file_name(image2video)
|
| 2282 |
|
| 2283 |
with open(defaults_filename, "w", encoding="utf-8") as f:
|
| 2284 |
json.dump(ui_defaults, f, indent=4)
|
|
|
|
| 2321 |
|
| 2322 |
state_dict["advanced"] = advanced
|
| 2323 |
state_dict["loras_model"] = filename
|
| 2324 |
+
state_dict["image2video"] = image2video
|
| 2325 |
+
gen = dict()
|
| 2326 |
+
gen["queue"] = []
|
| 2327 |
+
state_dict["gen"] = gen
|
| 2328 |
+
|
| 2329 |
+
preset_to_load = lora_preselected_preset if use_image2video == image2video else ""
|
| 2330 |
|
| 2331 |
loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset = setup_loras(image2video, None, get_lora_dir(image2video), preset_to_load, None)
|
| 2332 |
|
|
|
|
| 2339 |
launch_loras = []
|
| 2340 |
launch_multis_str = ""
|
| 2341 |
|
| 2342 |
+
if len(default_lora_preset) > 0 and image2video == use_image2video:
|
| 2343 |
launch_preset = default_lora_preset
|
| 2344 |
launch_prompt = default_lora_preset_prompt
|
| 2345 |
launch_loras = default_loras_choices
|
|
|
|
| 2364 |
|
| 2365 |
|
| 2366 |
header = gr.Markdown(generate_header(model_filename, compile, attention_mode))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2367 |
with gr.Row():
|
| 2368 |
with gr.Column():
|
| 2369 |
with gr.Column(visible=False, elem_id="image-modal-container") as modal_container:
|
|
|
|
| 2591 |
cfg_zero_step = gr.Slider(-1, 39, value=ui_defaults.get("cfg_zero_step",-1), step=1, label="CFG Zero below this Layer (Extra Process)")
|
| 2592 |
|
| 2593 |
with gr.Row():
|
| 2594 |
+
save_settings_btn = gr.Button("Set Settings as Default", visible = not args.lock_config)
|
| 2595 |
show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
|
| 2596 |
fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
|
| 2597 |
with gr.Column():
|
| 2598 |
+
gen_status = gr.Text(label="Status", interactive= False)
|
| 2599 |
+
full_sync = gr.Text(label="Status", interactive= False, visible= False)
|
| 2600 |
+
light_sync = gr.Text(label="Status", interactive= False, visible= False)
|
| 2601 |
gen_progress_html = gr.HTML(
|
| 2602 |
label="Status",
|
| 2603 |
value="Idle",
|
| 2604 |
+
elem_id="generation_progress_bar_container", visible= False
|
| 2605 |
)
|
| 2606 |
output = gr.Gallery(
|
| 2607 |
label="Generated videos", show_label=False, elem_id="gallery"
|
| 2608 |
, columns=[3], rows=[1], object_fit="contain", height=450, selected_index=0, interactive= False)
|
| 2609 |
generate_btn = gr.Button("Generate")
|
| 2610 |
+
add_to_queue_btn = gr.Button("Add New Prompt To Queue", visible = False)
|
| 2611 |
+
|
| 2612 |
+
with gr.Column(visible= False) as current_gen_column:
|
| 2613 |
+
with gr.Row():
|
| 2614 |
+
gen_info = gr.HTML(visible=False, min_height=1)
|
| 2615 |
+
with gr.Row():
|
| 2616 |
+
onemore_btn = gr.Button("One More Sample Please !")
|
| 2617 |
+
abort_btn = gr.Button("Abort")
|
| 2618 |
+
|
| 2619 |
+
queue_df = gr.DataFrame(
|
| 2620 |
+
headers=["Qty","Prompt", "Length","Steps","Start", "End", "", "", ""],
|
| 2621 |
+
datatype=[ "str","markdown","str", "markdown", "markdown", "markdown", "str", "str", "str"],
|
| 2622 |
+
interactive=False,
|
| 2623 |
+
col_count=(9, "fixed"),
|
| 2624 |
+
wrap=True,
|
| 2625 |
+
value=[],
|
| 2626 |
+
visible= False,
|
| 2627 |
+
# every=1,
|
| 2628 |
+
elem_id="queue_df"
|
| 2629 |
+
)
|
| 2630 |
+
# queue_df = gr.HTML("",
|
| 2631 |
+
# visible= False,
|
| 2632 |
+
# elem_id="queue_df"
|
| 2633 |
+
# )
|
| 2634 |
+
|
| 2635 |
+
def handle_selection(state, evt: gr.SelectData):
    """React to a click on a cell of the queue table.

    Column 6 moves the clicked task up, column 7 moves it down and column 8
    removes it (also decrementing "prompts_max" and refreshing the status).
    Columns 4 and 5 show the task's stored start / end image in the modal
    when one exists. Any other cell leaves everything unchanged and keeps the
    modal hidden.

    Returns:
        A 3-tuple of updates, wired to the queue table, the modal image and
        the modal container.
    """
    gen = get_gen_info(state)
    queue = gen.get("queue", [])

    if evt.index is None:
        return gr.update(), gr.update(), gr.update(visible=False)
    row_index, col_index = evt.index

    # Action columns: reorder or delete the clicked task.
    if col_index == 6:
        new_df_data = move_up(queue, [row_index])
        return new_df_data, gr.update(), gr.update(visible=False)
    if col_index == 7:
        new_df_data = move_down(queue, [row_index])
        return new_df_data, gr.update(), gr.update(visible=False)
    if col_index == 8:
        new_df_data = remove_task(queue, [row_index])
        gen["prompts_max"] = gen.get("prompts_max",0) - 1
        update_status(state)
        return new_df_data, gr.update(), gr.update(visible=False)

    # Image columns: look up the matching image payload of the clicked task.
    image_key = {4: 'start_image_data', 5: 'end_image_data'}.get(col_index)
    image_data_to_show = None
    if image_key is not None:
        with lock:
            if row_index < len(queue):
                image_data_to_show = queue[row_index].get(image_key)

    if image_data_to_show:
        return gr.update(), gr.update(value=image_data_to_show), gr.update(visible=True)
    return gr.update(), gr.update(), gr.update(visible=False)
|
| 2674 |
+
selected_indices = gr.State([])
|
| 2675 |
+
queue_df.select(
|
| 2676 |
+
fn=handle_selection,
|
| 2677 |
+
inputs=state,
|
| 2678 |
+
outputs=[queue_df, modal_image_display, modal_container],
|
| 2679 |
+
)
|
| 2680 |
+
# gallery_update_trigger.change(
|
| 2681 |
+
# fn=refresh_gallery_on_trigger,
|
| 2682 |
+
# inputs=[state],
|
| 2683 |
+
# outputs=[output]
|
| 2684 |
+
# )
|
| 2685 |
+
# queue_df.change(
|
| 2686 |
+
# fn=refresh_gallery,
|
| 2687 |
+
# inputs=[state],
|
| 2688 |
+
# outputs=[gallery_update_trigger]
|
| 2689 |
+
# ).then(
|
| 2690 |
+
# fn=refresh_progress,
|
| 2691 |
+
# inputs=None,
|
| 2692 |
+
# outputs=[progress_update_trigger]
|
| 2693 |
+
# )
|
| 2694 |
+
progress_update_trigger.change(
|
| 2695 |
+
fn=update_generation_status,
|
| 2696 |
+
inputs=[progress_update_trigger],
|
| 2697 |
+
outputs=[gen_progress_html],
|
| 2698 |
+
show_progress="hidden"
|
| 2699 |
+
)
|
| 2700 |
save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
|
| 2701 |
save_settings, inputs = [state, prompt, image_prompt_type_radio, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt,
|
| 2702 |
loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers,
|
|
|
|
| 2712 |
)
|
| 2713 |
refresh_lora_btn.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
|
| 2714 |
refresh_lora_btn2.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
|
|
|
|
| 2715 |
output.select(select_video, state, None )
|
| 2716 |
|
| 2717 |
+
gen_status.change(refresh_gallery,
|
| 2718 |
+
inputs = [state, gen_status],
|
| 2719 |
+
outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn])
|
| 2720 |
+
|
| 2721 |
+
full_sync.change(refresh_gallery,
|
| 2722 |
+
inputs = [state, gen_status],
|
| 2723 |
+
outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn]
|
| 2724 |
+
).then( fn=wait_tasks_done,
|
| 2725 |
+
inputs= [state],
|
| 2726 |
+
outputs =[gen_status],
|
| 2727 |
+
).then(finalize_generation,
|
| 2728 |
+
inputs= [state],
|
| 2729 |
+
outputs= [output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info]
|
| 2730 |
+
)
|
| 2731 |
+
light_sync.change(refresh_gallery,
|
| 2732 |
+
inputs = [state, gen_status],
|
| 2733 |
+
outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn]
|
| 2734 |
+
)
|
| 2735 |
+
|
| 2736 |
+
abort_btn.click(abort_generation, [state], [gen_status, abort_btn] ) #.then(refresh_gallery, inputs = [state, gen_info], outputs = [output, gen_info, queue_df] )
|
| 2737 |
+
onemore_btn.click(fn=one_more_sample,inputs=[state], outputs= [state])
|
| 2738 |
+
|
| 2739 |
+
|
| 2740 |
+
gen_inputs=[
|
| 2741 |
+
prompt,
|
| 2742 |
+
negative_prompt,
|
| 2743 |
+
resolution,
|
| 2744 |
+
video_length,
|
| 2745 |
+
seed,
|
| 2746 |
+
num_inference_steps,
|
| 2747 |
+
guidance_scale,
|
| 2748 |
+
flow_shift,
|
| 2749 |
+
embedded_guidance_scale,
|
| 2750 |
+
repeat_generation,
|
| 2751 |
+
multi_images_gen_type,
|
| 2752 |
+
tea_cache_setting,
|
| 2753 |
+
tea_cache_start_step_perc,
|
| 2754 |
+
loras_choices,
|
| 2755 |
+
loras_mult_choices,
|
| 2756 |
+
image_prompt_type_radio,
|
| 2757 |
+
image_to_continue,
|
| 2758 |
+
image_to_end,
|
| 2759 |
+
video_to_continue,
|
| 2760 |
+
max_frames,
|
| 2761 |
+
RIFLEx_setting,
|
| 2762 |
+
slg_switch,
|
| 2763 |
+
slg_layers,
|
| 2764 |
+
slg_start_perc,
|
| 2765 |
+
slg_end_perc,
|
| 2766 |
+
cfg_star_switch,
|
| 2767 |
+
cfg_zero_step,
|
| 2768 |
+
state,
|
| 2769 |
+
gr.State(image2video)
|
| 2770 |
+
]
|
| 2771 |
+
|
| 2772 |
+
generate_btn.click(fn=validate_wizard_prompt,
|
| 2773 |
+
inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
|
| 2774 |
+
outputs= [prompt]
|
| 2775 |
+
).then(fn=process_prompt_and_add_tasks,
|
| 2776 |
+
inputs = gen_inputs,
|
| 2777 |
+
outputs= queue_df
|
| 2778 |
+
).then(fn=prepare_generate_video,
|
| 2779 |
+
inputs= [state],
|
| 2780 |
+
outputs= [generate_btn, add_to_queue_btn, current_gen_column],
|
| 2781 |
+
).then(fn=process_tasks,
|
| 2782 |
+
inputs= [state],
|
| 2783 |
+
outputs= [gen_status],
|
| 2784 |
+
).then(finalize_generation,
|
| 2785 |
+
inputs= [state],
|
| 2786 |
+
outputs= [output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info]
|
| 2787 |
+
)
|
| 2788 |
+
|
| 2789 |
+
add_to_queue_btn.click(fn=validate_wizard_prompt,
|
| 2790 |
+
inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
|
| 2791 |
+
outputs= [prompt]
|
| 2792 |
).then(
|
| 2793 |
fn=process_prompt_and_add_tasks,
|
| 2794 |
+
inputs = gen_inputs,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2795 |
outputs=queue_df
|
| 2796 |
+
).then(
|
| 2797 |
+
fn=update_status,
|
| 2798 |
+
inputs = [state],
|
| 2799 |
)
|
| 2800 |
+
|
| 2801 |
+
|
| 2802 |
close_modal_button.click(
|
| 2803 |
lambda: gr.update(visible=False),
|
| 2804 |
inputs=[],
|
| 2805 |
outputs=[modal_container]
|
| 2806 |
)
|
| 2807 |
+
return loras_column, loras_choices, presets_column, lset_name, header, light_sync, full_sync, state
|
| 2808 |
+
|
| 2809 |
+
def generate_doxnload_tab(presets_column, loras_column, lset_name,loras_choices, state):
    """Build the Loras download banner and wire its button.

    Lays out a row containing the announcement text, the download button and
    an empty spacer, followed by a row holding a Markdown status area.
    Clicking the button runs ``download_loras`` (updating the status row, the
    status text and the two columns passed in) and then ``refresh_lora_list``
    to refresh the preset and Lora selectors.
    """
    with gr.Row():
        with gr.Row(scale =2):
            gr.Markdown("<I>Wan2GP's Lora Festival ! Press the following button to download i2v <B>Remade</B> Loras collection (and bonuses Loras).")
        with gr.Row(scale =1):
            download_loras_btn = gr.Button("---> Let the Lora's Festival Start !", scale =1)
        with gr.Row(scale =1):
            gr.Markdown("")
    with gr.Row() as download_status_row:
        download_status = gr.Markdown()

    # First download, then refresh the selectors once the download step ends.
    download_event = download_loras_btn.click(
        fn=download_loras,
        inputs=[],
        outputs=[download_status_row, download_status, presets_column, loras_column],
    )
    download_event.then(
        fn=refresh_lora_list,
        inputs=[state, lset_name,loras_choices],
        outputs=[lset_name, loras_choices],
    )
|
| 2821 |
|
| 2822 |
+
|
| 2823 |
def generate_configuration_tab():
|
| 2824 |
state_dict = {}
|
| 2825 |
state = gr.State(state_dict)
|
|
|
|
| 2836 |
value= index,
|
| 2837 |
label="Transformer model for Text to Video",
|
| 2838 |
interactive= not lock_ui_transformer,
|
| 2839 |
+
visible=True
|
| 2840 |
)
|
| 2841 |
index = transformer_choices_i2v.index(transformer_filename_i2v)
|
| 2842 |
index = 0 if index ==0 else index
|
|
|
|
| 2853 |
value= index,
|
| 2854 |
label="Transformer model for Image to Video",
|
| 2855 |
interactive= not lock_ui_transformer,
|
| 2856 |
+
visible = True,
|
| 2857 |
)
|
| 2858 |
index = text_encoder_choices.index(text_encoder_filename)
|
| 2859 |
index = 0 if index ==0 else index
|
|
|
|
| 2949 |
reload_choice = gr.Dropdown(
|
| 2950 |
choices=[
|
| 2951 |
("When changing tabs", 1),
|
| 2952 |
+
("When pressing Generate", 2),
|
| 2953 |
],
|
| 2954 |
value=server_config.get("reload_model",2),
|
| 2955 |
label="Reload model"
|
|
|
|
| 3002 |
gr.Markdown("- <B>Remade_AI</B> : for creating their awesome Loras collection")
|
| 3003 |
|
| 3004 |
|
| 3005 |
+
def on_tab_select(global_state, t2v_state, i2v_state, evt: gr.SelectData):
|
|
|
|
|
|
|
| 3006 |
t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode)
|
| 3007 |
i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode)
|
| 3008 |
|
| 3009 |
new_t2v = evt.index == 0
|
| 3010 |
new_i2v = evt.index == 1
|
| 3011 |
+
i2v_light_sync = gr.Text()
|
| 3012 |
+
t2v_light_sync = gr.Text()
|
| 3013 |
+
i2v_full_sync = gr.Text()
|
| 3014 |
+
t2v_full_sync = gr.Text()
|
| 3015 |
+
if new_t2v or new_i2v:
|
| 3016 |
+
last_tab_was_image2video =global_state.get("last_tab_was_image2video", None)
|
| 3017 |
+
if last_tab_was_image2video == None or last_tab_was_image2video:
|
| 3018 |
+
gen = i2v_state["gen"]
|
| 3019 |
+
t2v_state["gen"] = gen
|
| 3020 |
+
else:
|
| 3021 |
+
gen = t2v_state["gen"]
|
| 3022 |
+
i2v_state["gen"] = gen
|
| 3023 |
+
|
| 3024 |
+
|
| 3025 |
+
if last_tab_was_image2video != None and new_t2v != new_i2v:
|
| 3026 |
+
gen_location = gen.get("location", None)
|
| 3027 |
+
if "in_progress" in gen and gen_location !=None and not (gen_location and new_i2v or not gen_location and new_t2v) :
|
| 3028 |
+
if new_i2v:
|
| 3029 |
+
i2v_full_sync = gr.Text(str(time.time()))
|
| 3030 |
+
else:
|
| 3031 |
+
t2v_full_sync = gr.Text(str(time.time()))
|
| 3032 |
+
else:
|
| 3033 |
+
if new_i2v:
|
| 3034 |
+
i2v_light_sync = gr.Text(str(time.time()))
|
| 3035 |
+
else:
|
| 3036 |
+
t2v_light_sync = gr.Text(str(time.time()))
|
| 3037 |
+
|
| 3038 |
+
|
| 3039 |
+
global_state["last_tab_was_image2video"] = new_i2v
|
| 3040 |
|
| 3041 |
if(server_config.get("reload_model",2) == 1):
|
| 3042 |
+
queue = gen.get("queue", [])
|
| 3043 |
+
|
| 3044 |
+
queue_empty = len(queue) == 0
|
| 3045 |
if queue_empty:
|
| 3046 |
global wan_model, offloadobj
|
| 3047 |
if wan_model is not None:
|
|
|
|
| 3051 |
wan_model = None
|
| 3052 |
gc.collect()
|
| 3053 |
torch.cuda.empty_cache()
|
| 3054 |
+
wan_model, offloadobj, trans = load_models(new_i2v)
|
| 3055 |
del trans
|
| 3056 |
|
| 3057 |
if new_t2v or new_i2v:
|
|
|
|
| 3077 |
gr.Column(visible= visible),
|
| 3078 |
gr.Dropdown(choices=lset_choices, value=get_new_preset_msg(advanced), visible=visible),
|
| 3079 |
t2v_header,
|
| 3080 |
+
t2v_light_sync,
|
| 3081 |
+
t2v_full_sync,
|
| 3082 |
gr.Column(),
|
| 3083 |
gr.Dropdown(),
|
| 3084 |
gr.Column(),
|
| 3085 |
gr.Dropdown(),
|
| 3086 |
+
gr.Markdown(),
|
| 3087 |
+
gr.Text(),
|
| 3088 |
+
gr.Text(),
|
| 3089 |
]
|
| 3090 |
else:
|
| 3091 |
return [
|
|
|
|
| 3093 |
gr.Dropdown(),
|
| 3094 |
gr.Column(),
|
| 3095 |
gr.Dropdown(),
|
| 3096 |
+
gr.Markdown(),
|
| 3097 |
+
gr.Text(),
|
| 3098 |
+
gr.Text(),
|
| 3099 |
+
gr.Text(),
|
| 3100 |
gr.Column(visible= visible),
|
| 3101 |
gr.Dropdown(choices=new_loras_choices, visible=visible, value=[]),
|
| 3102 |
gr.Column(visible= visible),
|
| 3103 |
gr.Dropdown(choices=lset_choices, value=get_new_preset_msg(advanced), visible=visible),
|
| 3104 |
i2v_header,
|
| 3105 |
+
i2v_light_sync,
|
| 3106 |
+
i2v_full_sync,
|
| 3107 |
]
|
| 3108 |
|
| 3109 |
+
return [gr.Column(), gr.Dropdown(), gr.Column(), gr.Dropdown(), t2v_header, t2v_light_sync, t2v_full_sync,
|
| 3110 |
+
gr.Column(), gr.Dropdown(), gr.Column(), gr.Dropdown(), i2v_header, i2v_light_sync, i2v_full_sync]
|
| 3111 |
|
| 3112 |
|
| 3113 |
def create_demo():
|
|
|
|
| 3167 |
overflow: hidden;
|
| 3168 |
text-overflow: ellipsis;
|
| 3169 |
}
|
| 3170 |
+
# #queue_df td:nth-child(-n+5) {
|
| 3171 |
+
# cursor: default !important;
|
| 3172 |
+
# pointer-events: none;
|
| 3173 |
+
# }
|
| 3174 |
+
# #queue_df td:nth-child(6) {
|
| 3175 |
+
# cursor: default !important;
|
| 3176 |
+
# }
|
| 3177 |
+
# #queue_df th {
|
| 3178 |
+
# pointer-events: none;
|
| 3179 |
+
# text-align: center;
|
| 3180 |
+
# vertical-align: middle;
|
| 3181 |
+
# }
|
| 3182 |
+
# #queue_df table {
|
| 3183 |
+
# width: 100%;
|
| 3184 |
+
# overflow: hidden !important;
|
| 3185 |
+
# }
|
| 3186 |
+
# #queue_df::-webkit-scrollbar {
|
| 3187 |
+
# display: none !important;
|
| 3188 |
+
# }
|
| 3189 |
+
# #queue_df {
|
| 3190 |
+
# scrollbar-width: none !important;
|
| 3191 |
+
# -ms-overflow-style: none !important;
|
| 3192 |
+
# }
|
| 3193 |
+
# #queue_df th:nth-child(1),
|
| 3194 |
+
# #queue_df td:nth-child(1) {
|
| 3195 |
+
# width: 90px;
|
| 3196 |
+
# text-align: center;
|
| 3197 |
+
# vertical-align: middle;
|
| 3198 |
+
# }
|
| 3199 |
+
# #queue_df th:nth-child(1) {
|
| 3200 |
+
# font-size: 0.8em;
|
| 3201 |
+
# }
|
| 3202 |
+
# #queue_df th:nth-child(2),
|
| 3203 |
+
# #queue_df td:nth-child(2) {
|
| 3204 |
+
# width: 85px;
|
| 3205 |
+
# text-align: center;
|
| 3206 |
+
# vertical-align: middle;
|
| 3207 |
+
# }
|
| 3208 |
+
# #queue_df th:nth-child(2) {
|
| 3209 |
+
# font-size: 0.5em;
|
| 3210 |
+
# }
|
| 3211 |
+
# #queue_df th:nth-child(3),
|
| 3212 |
+
# #queue_df td:nth-child(3) {
|
| 3213 |
+
# width: 75px;
|
| 3214 |
+
# text-align: center;
|
| 3215 |
+
# vertical-align: middle;
|
| 3216 |
+
# }
|
| 3217 |
+
# #queue_df th:nth-child(3) {
|
| 3218 |
+
# font-size: 0.6em;
|
| 3219 |
+
# }
|
| 3220 |
+
# #queue_df th:nth-child(4),
|
| 3221 |
+
# #queue_df td:nth-child(4) {
|
| 3222 |
+
# width: 65px;
|
| 3223 |
+
# text-align: center;
|
| 3224 |
+
# white-space: nowrap;
|
| 3225 |
+
# }
|
| 3226 |
+
# #queue_df th:nth-child(4) {
|
| 3227 |
+
# font-size: 0.9em;
|
| 3228 |
+
# }
|
| 3229 |
+
# #queue_df th:nth-child(5),
|
| 3230 |
+
# #queue_df td:nth-child(5) {
|
| 3231 |
+
# width: 60px;
|
| 3232 |
+
# text-align: center;
|
| 3233 |
+
# white-space: nowrap;
|
| 3234 |
+
# }
|
| 3235 |
+
# #queue_df th:nth-child(6),
|
| 3236 |
+
# #queue_df td:nth-child(6) {
|
| 3237 |
+
# width: auto;
|
| 3238 |
+
# text-align: center;
|
| 3239 |
+
# white-space: normal;
|
| 3240 |
+
# }
|
| 3241 |
+
# #queue_df th:nth-child(6) {
|
| 3242 |
+
# font-size: 0.8em;
|
| 3243 |
+
# }
|
| 3244 |
+
# #queue_df th:nth-child(7), #queue_df td:nth-child(7),
|
| 3245 |
+
# #queue_df th:nth-child(8), #queue_df td:nth-child(8) {
|
| 3246 |
+
# width: 60px;
|
| 3247 |
+
# text-align: center;
|
| 3248 |
+
# vertical-align: middle;
|
| 3249 |
+
# }
|
| 3250 |
+
# #queue_df td:nth-child(7) img,
|
| 3251 |
+
# #queue_df td:nth-child(8) img {
|
| 3252 |
+
# max-width: 50px;
|
| 3253 |
+
# max-height: 50px;
|
| 3254 |
+
# object-fit: contain;
|
| 3255 |
+
# display: block;
|
| 3256 |
+
# margin: auto;
|
| 3257 |
+
# cursor: pointer;
|
| 3258 |
+
# }
|
| 3259 |
+
# #queue_df th:nth-child(9), #queue_df td:nth-child(9),
|
| 3260 |
+
# #queue_df th:nth-child(10), #queue_df td:nth-child(10),
|
| 3261 |
+
# #queue_df th:nth-child(11), #queue_df td:nth-child(11) {
|
| 3262 |
+
# width: 20px;
|
| 3263 |
+
# padding: 2px !important;
|
| 3264 |
+
# cursor: pointer;
|
| 3265 |
+
# text-align: center;
|
| 3266 |
+
# font-weight: bold;
|
| 3267 |
+
# vertical-align: middle;
|
| 3268 |
+
# }
|
| 3269 |
+
# #queue_df td:nth-child(7):hover,
|
| 3270 |
+
# #queue_df td:nth-child(8):hover,
|
| 3271 |
+
# #queue_df td:nth-child(9):hover,
|
| 3272 |
+
# #queue_df td:nth-child(10):hover,
|
| 3273 |
+
# #queue_df td:nth-child(11):hover {
|
| 3274 |
+
# background-color: #e0e0e0;
|
| 3275 |
+
# }
|
| 3276 |
#image-modal-container {
|
| 3277 |
position: fixed;
|
| 3278 |
top: 0;
|
|
|
|
| 3354 |
pointer-events: none;
|
| 3355 |
}
|
| 3356 |
"""
|
| 3357 |
+
with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md"), title= "Wan2GP") as demo:
|
| 3358 |
+
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v3.4 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
|
| 3359 |
gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !</FONT>")
|
| 3360 |
|
| 3361 |
with gr.Accordion("Click here for some Info on how to use Wan2GP", open = False):
|
|
|
|
| 3365 |
gr.Markdown("- 1280 x 720 with a 14B model: 80 frames (5s): 11 GB of VRAM")
|
| 3366 |
gr.Markdown("It is not recommmended to generate a video longer than 8s (128 frames) even if there is still some VRAM left as some artifacts may appear")
|
| 3367 |
gr.Markdown("Please note that if your turn on compilation, the first denoising step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
|
| 3368 |
+
global_dict = {}
|
| 3369 |
+
global_dict["last_tab_was_image2video"] = use_image2video
|
| 3370 |
+
global_state = gr.State(global_dict)
|
| 3371 |
|
| 3372 |
with gr.Tabs(selected="i2v" if use_image2video else "t2v") as main_tabs:
|
| 3373 |
with gr.Tab("Text To Video", id="t2v") as t2v_tab:
|
| 3374 |
+
t2v_loras_column, t2v_loras_choices, t2v_presets_column, t2v_lset_name, t2v_header, t2v_light_sync, t2v_full_sync, t2v_state = generate_video_tab(False)
|
| 3375 |
with gr.Tab("Image To Video", id="i2v") as i2v_tab:
|
| 3376 |
+
i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header, i2v_light_sync, i2v_full_sync, i2v_state = generate_video_tab(True)
|
| 3377 |
if not args.lock_config:
|
| 3378 |
+
with gr.Tab("Downloads", id="downloads") as downloads_tab:
|
| 3379 |
+
generate_doxnload_tab(i2v_presets_column, i2v_loras_column, i2v_lset_name, i2v_loras_choices, i2v_state)
|
| 3380 |
with gr.Tab("Configuration"):
|
| 3381 |
generate_configuration_tab()
|
| 3382 |
with gr.Tab("About"):
|
| 3383 |
generate_about_tab()
|
| 3384 |
main_tabs.select(
|
| 3385 |
fn=on_tab_select,
|
| 3386 |
+
inputs=[global_state, t2v_state, i2v_state],
|
| 3387 |
outputs=[
|
| 3388 |
+
t2v_loras_column, t2v_loras_choices, t2v_presets_column, t2v_lset_name, t2v_header, t2v_light_sync, t2v_full_sync,
|
| 3389 |
+
i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header, i2v_light_sync, i2v_full_sync
|
| 3390 |
]
|
| 3391 |
)
|
| 3392 |
return demo
|
| 3393 |
|
| 3394 |
if __name__ == "__main__":
|
| 3395 |
+
# threading.Thread(target=runner, daemon=True).start()
|
| 3396 |
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
|
| 3397 |
server_port = int(args.server_port)
|
| 3398 |
if os.name == "nt":
|
wan/image2video.py
CHANGED
|
@@ -40,7 +40,12 @@ def optimized_scale(positive_flat, negative_flat):
|
|
| 40 |
st_star = dot_product / squared_norm
|
| 41 |
|
| 42 |
return st_star
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
class WanI2V:
|
| 46 |
|
|
@@ -90,7 +95,6 @@ class WanI2V:
|
|
| 90 |
|
| 91 |
self.num_train_timesteps = config.num_train_timesteps
|
| 92 |
self.param_dtype = config.param_dtype
|
| 93 |
-
|
| 94 |
shard_fn = partial(shard_model, device_id=device_id)
|
| 95 |
self.text_encoder = T5EncoderModel(
|
| 96 |
text_len=config.text_len,
|
|
@@ -208,16 +212,16 @@ class WanI2V:
|
|
| 208 |
- H: Frame height (from max_area)
|
| 209 |
- W: Frame width from max_area)
|
| 210 |
"""
|
| 211 |
-
img = TF.to_tensor(img)
|
| 212 |
lat_frames = int((frame_num - 1) // self.vae_stride[0] + 1)
|
| 213 |
any_end_frame = img2 !=None
|
| 214 |
if any_end_frame:
|
| 215 |
any_end_frame = True
|
| 216 |
-
img2 = TF.to_tensor(img2)
|
| 217 |
if add_frames_for_end_image:
|
| 218 |
frame_num +=1
|
| 219 |
lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2)
|
| 220 |
-
|
| 221 |
h, w = img.shape[1:]
|
| 222 |
aspect_ratio = h / w
|
| 223 |
lat_h = round(
|
|
@@ -229,6 +233,15 @@ class WanI2V:
|
|
| 229 |
h = lat_h * self.vae_stride[1]
|
| 230 |
w = lat_w * self.vae_stride[2]
|
| 231 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
max_seq_len = lat_frames * lat_h * lat_w // ( self.patch_size[1] * self.patch_size[2])
|
| 233 |
max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
|
| 234 |
|
|
@@ -273,21 +286,32 @@ class WanI2V:
|
|
| 273 |
|
| 274 |
from mmgp import offload
|
| 275 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 276 |
offload.last_offload_obj.unload_all()
|
| 277 |
if any_end_frame:
|
| 278 |
-
img_interpolated = torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode='bicubic').transpose(0, 1).to(torch.bfloat16)
|
| 279 |
-
img2_interpolated = torch.nn.functional.interpolate(img2[None].cpu(), size=(h, w), mode='bicubic').transpose(0, 1).to(torch.bfloat16)
|
| 280 |
mean2 = 0
|
| 281 |
enc= torch.concat([
|
| 282 |
img_interpolated,
|
| 283 |
-
torch.full( (3, frame_num-2, h, w), mean2, device=
|
| 284 |
-
|
| 285 |
], dim=1).to(self.device)
|
| 286 |
else:
|
| 287 |
enc= torch.concat([
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
torch.zeros(3, frame_num-1, h, w, device="cpu", dtype= torch.bfloat16)
|
| 291 |
], dim=1).to(self.device)
|
| 292 |
|
| 293 |
lat_y = self.vae.encode([enc], VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]
|
|
@@ -333,7 +357,8 @@ class WanI2V:
|
|
| 333 |
'seq_len': max_seq_len,
|
| 334 |
'y': [y],
|
| 335 |
'freqs' : freqs,
|
| 336 |
-
'pipeline' : self
|
|
|
|
| 337 |
}
|
| 338 |
|
| 339 |
arg_null = {
|
|
@@ -342,7 +367,8 @@ class WanI2V:
|
|
| 342 |
'seq_len': max_seq_len,
|
| 343 |
'y': [y],
|
| 344 |
'freqs' : freqs,
|
| 345 |
-
'pipeline' : self
|
|
|
|
| 346 |
}
|
| 347 |
|
| 348 |
arg_both= {
|
|
@@ -352,7 +378,8 @@ class WanI2V:
|
|
| 352 |
'seq_len': max_seq_len,
|
| 353 |
'y': [y],
|
| 354 |
'freqs' : freqs,
|
| 355 |
-
'pipeline' : self
|
|
|
|
| 356 |
}
|
| 357 |
|
| 358 |
if offload_model:
|
|
@@ -363,7 +390,7 @@ class WanI2V:
|
|
| 363 |
|
| 364 |
# self.model.to(self.device)
|
| 365 |
if callback != None:
|
| 366 |
-
callback(-1,
|
| 367 |
|
| 368 |
for i, t in enumerate(tqdm(timesteps)):
|
| 369 |
offload.set_step_no_for_lora(self.model, i)
|
|
@@ -437,7 +464,7 @@ class WanI2V:
|
|
| 437 |
del timestep
|
| 438 |
|
| 439 |
if callback is not None:
|
| 440 |
-
callback(i,
|
| 441 |
|
| 442 |
|
| 443 |
x0 = [latent.to(self.device, dtype=torch.bfloat16)]
|
|
@@ -451,7 +478,7 @@ class WanI2V:
|
|
| 451 |
video = self.vae.decode(x0, VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]
|
| 452 |
|
| 453 |
if any_end_frame and add_frames_for_end_image:
|
| 454 |
-
# video[:, -1:] =
|
| 455 |
video = video[:, :-1]
|
| 456 |
|
| 457 |
else:
|
|
|
|
| 40 |
st_star = dot_product / squared_norm
|
| 41 |
|
| 42 |
return st_star
|
| 43 |
+
|
| 44 |
+
def resize_lanczos(img, h, w):
    """Resize a channels-first image tensor to ``h`` x ``w`` using Lanczos resampling.

    The tensor is round-tripped through an 8-bit PIL image so that PIL's
    Lanczos filter can be used for the resize.

    Args:
        img: Image tensor with the channel axis first (it is moved last
            before being handed to PIL). Values are scaled by 255 and
            clipped to [0, 255], so the input is presumably in [0, 1]
            -- TODO confirm against callers.
        h: Target height in pixels.
        w: Target width in pixels.

    Returns:
        A CPU float32 tensor with the channel axis first again and values
        rescaled to [0, 1].
    """
    # Channels-last uint8 array is what Image.fromarray expects; the cast
    # to uint8 truncates the fractional part after clipping.
    channels_last = img.movedim(0, -1).cpu().numpy()
    as_bytes = np.clip(255. * channels_last, 0, 255).astype(np.uint8)

    # PIL takes the target size as (width, height), hence the (w, h) order.
    pil_image = Image.fromarray(as_bytes)
    pil_image = pil_image.resize((w, h), resample=Image.Resampling.LANCZOS)

    # Back to a float tensor in [0, 1] with channels first.
    rescaled = np.array(pil_image).astype(np.float32) / 255.0
    return torch.from_numpy(rescaled).movedim(-1, 0)
|
| 48 |
+
|
| 49 |
|
| 50 |
class WanI2V:
|
| 51 |
|
|
|
|
| 95 |
|
| 96 |
self.num_train_timesteps = config.num_train_timesteps
|
| 97 |
self.param_dtype = config.param_dtype
|
|
|
|
| 98 |
shard_fn = partial(shard_model, device_id=device_id)
|
| 99 |
self.text_encoder = T5EncoderModel(
|
| 100 |
text_len=config.text_len,
|
|
|
|
| 212 |
- H: Frame height (from max_area)
|
| 213 |
- W: Frame width from max_area)
|
| 214 |
"""
|
| 215 |
+
img = TF.to_tensor(img)
|
| 216 |
lat_frames = int((frame_num - 1) // self.vae_stride[0] + 1)
|
| 217 |
any_end_frame = img2 !=None
|
| 218 |
if any_end_frame:
|
| 219 |
any_end_frame = True
|
| 220 |
+
img2 = TF.to_tensor(img2)
|
| 221 |
if add_frames_for_end_image:
|
| 222 |
frame_num +=1
|
| 223 |
lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2)
|
| 224 |
+
|
| 225 |
h, w = img.shape[1:]
|
| 226 |
aspect_ratio = h / w
|
| 227 |
lat_h = round(
|
|
|
|
| 233 |
h = lat_h * self.vae_stride[1]
|
| 234 |
w = lat_w * self.vae_stride[2]
|
| 235 |
|
| 236 |
+
clip_image_size = self.clip.model.image_size
|
| 237 |
+
img_interpolated = resize_lanczos(img, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device)
|
| 238 |
+
img = resize_lanczos(img, clip_image_size, clip_image_size)
|
| 239 |
+
img = img.sub_(0.5).div_(0.5).to(self.device)
|
| 240 |
+
if img2!= None:
|
| 241 |
+
img_interpolated2 = resize_lanczos(img2, h, w).sub_(0.5).div_(0.5).unsqueeze(0).transpose(0,1).to(self.device)
|
| 242 |
+
img2 = resize_lanczos(img2, clip_image_size, clip_image_size)
|
| 243 |
+
img2 = img2.sub_(0.5).div_(0.5).to(self.device)
|
| 244 |
+
|
| 245 |
max_seq_len = lat_frames * lat_h * lat_w // ( self.patch_size[1] * self.patch_size[2])
|
| 246 |
max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
|
| 247 |
|
|
|
|
| 286 |
|
| 287 |
from mmgp import offload
|
| 288 |
|
| 289 |
+
|
| 290 |
+
# img_interpolated.save('aaa.png')
|
| 291 |
+
|
| 292 |
+
# img_interpolated = torch.from_numpy(np.array(img_interpolated).astype(np.float32) / 255.0).movedim(-1, 0)
|
| 293 |
+
|
| 294 |
+
# img_interpolated = torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode='lanczos')
|
| 295 |
+
# img_interpolated = img_interpolated.squeeze(0).transpose(0,2).transpose(1,0)
|
| 296 |
+
# img_interpolated = img_interpolated.clamp(-1, 1)
|
| 297 |
+
# img_interpolated = (img_interpolated + 1)/2
|
| 298 |
+
# img_interpolated = (img_interpolated*255).type(torch.uint8)
|
| 299 |
+
# img_interpolated = img_interpolated.cpu().numpy()
|
| 300 |
+
# xxx = Image.fromarray(img_interpolated, 'RGB')
|
| 301 |
+
# xxx.save('my.png')
|
| 302 |
+
|
| 303 |
offload.last_offload_obj.unload_all()
|
| 304 |
if any_end_frame:
|
|
|
|
|
|
|
| 305 |
mean2 = 0
|
| 306 |
enc= torch.concat([
|
| 307 |
img_interpolated,
|
| 308 |
+
torch.full( (3, frame_num-2, h, w), mean2, device=self.device, dtype= torch.bfloat16),
|
| 309 |
+
img_interpolated2,
|
| 310 |
], dim=1).to(self.device)
|
| 311 |
else:
|
| 312 |
enc= torch.concat([
|
| 313 |
+
img_interpolated,
|
| 314 |
+
torch.zeros(3, frame_num-1, h, w, device=self.device, dtype= torch.bfloat16)
|
|
|
|
| 315 |
], dim=1).to(self.device)
|
| 316 |
|
| 317 |
lat_y = self.vae.encode([enc], VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]
|
|
|
|
| 357 |
'seq_len': max_seq_len,
|
| 358 |
'y': [y],
|
| 359 |
'freqs' : freqs,
|
| 360 |
+
'pipeline' : self,
|
| 361 |
+
'callback' : callback
|
| 362 |
}
|
| 363 |
|
| 364 |
arg_null = {
|
|
|
|
| 367 |
'seq_len': max_seq_len,
|
| 368 |
'y': [y],
|
| 369 |
'freqs' : freqs,
|
| 370 |
+
'pipeline' : self,
|
| 371 |
+
'callback' : callback
|
| 372 |
}
|
| 373 |
|
| 374 |
arg_both= {
|
|
|
|
| 378 |
'seq_len': max_seq_len,
|
| 379 |
'y': [y],
|
| 380 |
'freqs' : freqs,
|
| 381 |
+
'pipeline' : self,
|
| 382 |
+
'callback' : callback
|
| 383 |
}
|
| 384 |
|
| 385 |
if offload_model:
|
|
|
|
| 390 |
|
| 391 |
# self.model.to(self.device)
|
| 392 |
if callback != None:
|
| 393 |
+
callback(-1, True)
|
| 394 |
|
| 395 |
for i, t in enumerate(tqdm(timesteps)):
|
| 396 |
offload.set_step_no_for_lora(self.model, i)
|
|
|
|
| 464 |
del timestep
|
| 465 |
|
| 466 |
if callback is not None:
|
| 467 |
+
callback(i, False)
|
| 468 |
|
| 469 |
|
| 470 |
x0 = [latent.to(self.device, dtype=torch.bfloat16)]
|
|
|
|
| 478 |
video = self.vae.decode(x0, VAE_tile_size, any_end_frame= any_end_frame and add_frames_for_end_image)[0]
|
| 479 |
|
| 480 |
if any_end_frame and add_frames_for_end_image:
|
| 481 |
+
# video[:, -1:] = img_interpolated2
|
| 482 |
video = video[:, :-1]
|
| 483 |
|
| 484 |
else:
|
wan/modules/model.py
CHANGED
|
@@ -704,6 +704,7 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
| 704 |
is_uncond=False,
|
| 705 |
max_steps = 0,
|
| 706 |
slg_layers=None,
|
|
|
|
| 707 |
):
|
| 708 |
r"""
|
| 709 |
Forward pass through the diffusion model
|
|
@@ -835,12 +836,10 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
| 835 |
freqs=freqs,
|
| 836 |
# context=context,
|
| 837 |
context_lens=context_lens)
|
| 838 |
-
|
| 839 |
for block_idx, block in enumerate(self.blocks):
|
| 840 |
offload.shared_state["layer"] = block_idx
|
| 841 |
-
if
|
| 842 |
-
|
| 843 |
-
offload.shared_state["callback"](-1, -1, True)
|
| 844 |
if pipeline._interrupt:
|
| 845 |
if joint_pass:
|
| 846 |
return None, None
|
|
|
|
| 704 |
is_uncond=False,
|
| 705 |
max_steps = 0,
|
| 706 |
slg_layers=None,
|
| 707 |
+
callback = None,
|
| 708 |
):
|
| 709 |
r"""
|
| 710 |
Forward pass through the diffusion model
|
|
|
|
| 836 |
freqs=freqs,
|
| 837 |
# context=context,
|
| 838 |
context_lens=context_lens)
|
|
|
|
| 839 |
for block_idx, block in enumerate(self.blocks):
|
| 840 |
offload.shared_state["layer"] = block_idx
|
| 841 |
+
if callback != None:
|
| 842 |
+
callback(-1, False, True)
|
|
|
|
| 843 |
if pipeline._interrupt:
|
| 844 |
if joint_pass:
|
| 845 |
return None, None
|
wan/text2video.py
CHANGED
|
@@ -268,7 +268,7 @@ class WanT2V:
|
|
| 268 |
if self.model.enable_teacache:
|
| 269 |
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
|
| 270 |
if callback != None:
|
| 271 |
-
callback(-1,
|
| 272 |
for i, t in enumerate(tqdm(timesteps)):
|
| 273 |
latent_model_input = latents
|
| 274 |
slg_layers_local = None
|
|
@@ -322,7 +322,7 @@ class WanT2V:
|
|
| 322 |
del temp_x0
|
| 323 |
|
| 324 |
if callback is not None:
|
| 325 |
-
callback(i,
|
| 326 |
|
| 327 |
x0 = latents
|
| 328 |
if offload_model:
|
|
|
|
| 268 |
if self.model.enable_teacache:
|
| 269 |
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
|
| 270 |
if callback != None:
|
| 271 |
+
callback(-1, True)
|
| 272 |
for i, t in enumerate(tqdm(timesteps)):
|
| 273 |
latent_model_input = latents
|
| 274 |
slg_layers_local = None
|
|
|
|
| 322 |
del temp_x0
|
| 323 |
|
| 324 |
if callback is not None:
|
| 325 |
+
callback(i, False)
|
| 326 |
|
| 327 |
x0 = latents
|
| 328 |
if offload_model:
|