DeepBeepMeep committed on
Commit ·
bd91f8e
1
Parent(s): d6835bd
Added Vace ControlNet support
Browse files- gradio_server.py +255 -152
- requirements.txt +6 -2
- wan/configs/__init__.py +14 -0
- wan/modules/model.py +133 -15
- wan/text2video.py +178 -16
- wan/utils/utils.py +49 -0
- wan/utils/vace_preprocessor.py +298 -0
gradio_server.py
CHANGED
|
@@ -14,7 +14,7 @@ import gradio as gr
|
|
| 14 |
import random
|
| 15 |
import json
|
| 16 |
import wan
|
| 17 |
-
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES
|
| 18 |
from wan.utils.utils import cache_video
|
| 19 |
from wan.modules.attention import get_attention_modes, get_supported_attention_modes
|
| 20 |
import torch
|
|
@@ -55,6 +55,11 @@ def format_time(seconds):
|
|
| 55 |
def pil_to_base64_uri(pil_image, format="png", quality=75):
|
| 56 |
if pil_image is None:
|
| 57 |
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
buffer = io.BytesIO()
|
| 59 |
try:
|
| 60 |
img_to_save = pil_image
|
|
@@ -93,10 +98,11 @@ def process_prompt_and_add_tasks(
|
|
| 93 |
loras_choices,
|
| 94 |
loras_mult_choices,
|
| 95 |
image_prompt_type,
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
max_frames,
|
|
|
|
| 100 |
temporal_upsampling,
|
| 101 |
spatial_upsampling,
|
| 102 |
RIFLEx_setting,
|
|
@@ -127,9 +133,9 @@ def process_prompt_and_add_tasks(
|
|
| 127 |
return
|
| 128 |
|
| 129 |
file_model_needed = model_needed(image2video)
|
|
|
|
|
|
|
| 130 |
if image2video:
|
| 131 |
-
width, height = resolution.split("x")
|
| 132 |
-
width, height = int(width), int(height)
|
| 133 |
|
| 134 |
if "480p" in file_model_needed and not "Fun" in file_model_needed and width * height > 848*480:
|
| 135 |
gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P")
|
|
@@ -143,74 +149,94 @@ def process_prompt_and_add_tasks(
|
|
| 143 |
gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P")
|
| 144 |
return
|
| 145 |
|
| 146 |
-
if image2video:
|
| 147 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
return
|
| 149 |
if image_prompt_type == 0:
|
| 150 |
-
|
| 151 |
-
if isinstance(
|
| 152 |
-
|
| 153 |
else:
|
| 154 |
-
|
| 155 |
-
if
|
| 156 |
-
if isinstance(
|
| 157 |
-
|
| 158 |
else:
|
| 159 |
-
|
| 160 |
-
if len(
|
| 161 |
gr.Info("The number of start and end images should be the same ")
|
| 162 |
return
|
| 163 |
|
| 164 |
if multi_images_gen_type == 0:
|
| 165 |
new_prompts = []
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
for i in range(len(prompts) * len(
|
| 169 |
new_prompts.append( prompts[ i % len(prompts)] )
|
| 170 |
-
|
| 171 |
-
if
|
| 172 |
-
|
| 173 |
prompts = new_prompts
|
| 174 |
-
|
| 175 |
-
if
|
| 176 |
-
|
| 177 |
else:
|
| 178 |
-
if len(prompts) >= len(
|
| 179 |
-
if len(prompts) % len(
|
| 180 |
raise gr.Error("If there are more text prompts than input images the number of text prompts should be dividable by the number of images")
|
| 181 |
-
rep = len(prompts) // len(
|
| 182 |
-
|
| 183 |
-
|
| 184 |
for i, _ in enumerate(prompts):
|
| 185 |
-
|
| 186 |
-
if
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
if
|
| 190 |
-
|
| 191 |
else:
|
| 192 |
-
if len(
|
| 193 |
raise gr.Error("If there are more input images than text prompts the number of images should be dividable by the number of text prompts")
|
| 194 |
-
rep = len(
|
| 195 |
new_prompts = []
|
| 196 |
-
for i, _ in enumerate(
|
| 197 |
new_prompts.append( prompts[ i//rep] )
|
| 198 |
prompts = new_prompts
|
| 199 |
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
image_to_continue = [None] * len(prompts)
|
| 210 |
-
if image_to_end == None:
|
| 211 |
-
image_to_end = [None] * len(prompts)
|
| 212 |
-
|
| 213 |
-
for single_prompt, image_start, image_end in zip(prompts, image_to_continue, image_to_end) :
|
| 214 |
kwargs = {
|
| 215 |
"prompt" : single_prompt,
|
| 216 |
"negative_prompt" : negative_prompt,
|
|
@@ -228,10 +254,11 @@ def process_prompt_and_add_tasks(
|
|
| 228 |
"loras_choices" : loras_choices,
|
| 229 |
"loras_mult_choices" : loras_mult_choices,
|
| 230 |
"image_prompt_type" : image_prompt_type,
|
| 231 |
-
"
|
| 232 |
-
"
|
| 233 |
-
"
|
| 234 |
"max_frames" : max_frames,
|
|
|
|
| 235 |
"temporal_upsampling" : temporal_upsampling,
|
| 236 |
"spatial_upsampling" : spatial_upsampling,
|
| 237 |
"RIFLEx_setting" : RIFLEx_setting,
|
|
@@ -262,8 +289,9 @@ def add_video_task(**kwargs):
|
|
| 262 |
queue = gen["queue"]
|
| 263 |
task_id += 1
|
| 264 |
current_task_id = task_id
|
| 265 |
-
start_image_data = kwargs["
|
| 266 |
-
|
|
|
|
| 267 |
|
| 268 |
queue.append({
|
| 269 |
"id": current_task_id,
|
|
@@ -275,7 +303,7 @@ def add_video_task(**kwargs):
|
|
| 275 |
"prompt": kwargs["prompt"],
|
| 276 |
"start_image_data": start_image_data,
|
| 277 |
"end_image_data": end_image_data,
|
| 278 |
-
"start_image_data_base64":
|
| 279 |
"end_image_data_base64": pil_to_base64_uri(end_image_data, format="jpeg", quality=70)
|
| 280 |
})
|
| 281 |
return update_queue_data(queue)
|
|
@@ -342,6 +370,7 @@ def get_queue_table(queue):
|
|
| 342 |
full_prompt = item['prompt'].replace('"', '"')
|
| 343 |
prompt_cell = f'<span title="{full_prompt}">{truncated_prompt}</span>'
|
| 344 |
start_img_uri =item.get('start_image_data_base64')
|
|
|
|
| 345 |
end_img_uri = item.get('end_image_data_base64')
|
| 346 |
thumbnail_size = "50px"
|
| 347 |
num_steps = item.get('steps')
|
|
@@ -694,6 +723,9 @@ attention_modes_installed = get_attention_modes()
|
|
| 694 |
attention_modes_supported = get_supported_attention_modes()
|
| 695 |
args = _parse_args()
|
| 696 |
args.flow_reverse = True
|
|
|
|
|
|
|
|
|
|
| 697 |
# torch.backends.cuda.matmul.allow_fp16_accumulation = True
|
| 698 |
lock_ui_attention = False
|
| 699 |
lock_ui_transformer = False
|
|
@@ -706,7 +738,7 @@ quantizeTransformer = args.quantize_transformer
|
|
| 706 |
check_loras = args.check_loras ==1
|
| 707 |
advanced = args.advanced
|
| 708 |
|
| 709 |
-
transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors"]
|
| 710 |
transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", ]
|
| 711 |
text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors"]
|
| 712 |
|
|
@@ -750,7 +782,7 @@ def get_default_settings(filename, i2v):
|
|
| 750 |
"prompts": get_default_prompt(i2v),
|
| 751 |
"resolution": "832x480",
|
| 752 |
"video_length": 81,
|
| 753 |
-
"image_prompt_type" : 0,
|
| 754 |
"num_inference_steps": 30,
|
| 755 |
"seed": -1,
|
| 756 |
"repeat_generation": 1,
|
|
@@ -1149,6 +1181,9 @@ def get_model_name(model_filename):
|
|
| 1149 |
if "Fun" in model_filename:
|
| 1150 |
model_name = "Fun InP image2video"
|
| 1151 |
model_name += " 14B" if "14B" in model_filename else " 1.3B"
|
|
|
|
|
|
|
|
|
|
| 1152 |
elif "image" in model_filename:
|
| 1153 |
model_name = "Wan2.1 image2video"
|
| 1154 |
model_name += " 720p" if "720p" in model_filename else " 480p"
|
|
@@ -1353,22 +1388,22 @@ def refresh_gallery(state, msg):
|
|
| 1353 |
end_img_md = ""
|
| 1354 |
prompt = task["prompt"]
|
| 1355 |
|
| 1356 |
-
|
| 1357 |
-
|
| 1358 |
-
|
| 1359 |
-
|
| 1360 |
-
|
| 1361 |
-
|
| 1362 |
-
|
| 1363 |
-
|
| 1364 |
|
| 1365 |
label = f"Prompt of Video being Generated"
|
| 1366 |
|
| 1367 |
html = "<STYLE> #PINFO, #PINFO th, #PINFO td {border: 1px solid #CCCCCC;background-color:#FFFFFF;}</STYLE><TABLE WIDTH=100% ID=PINFO ><TR><TD width=100%>" + prompt + "</TD>"
|
| 1368 |
if start_img_md != "":
|
| 1369 |
html += "<TD>" + start_img_md + "</TD>"
|
| 1370 |
-
|
| 1371 |
-
|
| 1372 |
|
| 1373 |
html += "</TR></TABLE>"
|
| 1374 |
html_output = gr.HTML(html, visible= True)
|
|
@@ -1419,24 +1454,26 @@ def expand_slist(slist, num_inference_steps ):
|
|
| 1419 |
new_slist.append(slist[ int(pos)])
|
| 1420 |
pos += inc
|
| 1421 |
return new_slist
|
| 1422 |
-
|
| 1423 |
def convert_image(image):
|
| 1424 |
-
|
| 1425 |
-
|
| 1426 |
-
|
| 1427 |
-
|
| 1428 |
-
|
| 1429 |
-
|
| 1430 |
-
|
| 1431 |
-
if
|
| 1432 |
-
|
| 1433 |
-
|
| 1434 |
-
|
| 1435 |
-
|
| 1436 |
-
|
| 1437 |
-
|
| 1438 |
-
|
| 1439 |
-
|
|
|
|
|
|
|
|
|
|
| 1440 |
|
| 1441 |
def generate_video(
|
| 1442 |
task_id,
|
|
@@ -1457,10 +1494,11 @@ def generate_video(
|
|
| 1457 |
loras_choices,
|
| 1458 |
loras_mult_choices,
|
| 1459 |
image_prompt_type,
|
| 1460 |
-
|
| 1461 |
-
|
| 1462 |
-
|
| 1463 |
max_frames,
|
|
|
|
| 1464 |
temporal_upsampling,
|
| 1465 |
spatial_upsampling,
|
| 1466 |
RIFLEx_setting,
|
|
@@ -1507,7 +1545,6 @@ def generate_video(
|
|
| 1507 |
gr.Info(f"You have selected attention mode '{attention_mode}'. However it is not installed or supported on your system. You should either install it or switch to the default 'sdpa' attention.")
|
| 1508 |
return
|
| 1509 |
|
| 1510 |
-
|
| 1511 |
|
| 1512 |
if not image2video:
|
| 1513 |
width, height = resolution.split("x")
|
|
@@ -1586,7 +1623,7 @@ def generate_video(
|
|
| 1586 |
|
| 1587 |
enable_RIFLEx = RIFLEx_setting == 0 and video_length > (6* 16) or RIFLEx_setting == 1
|
| 1588 |
# VAE Tiling
|
| 1589 |
-
device_mem_capacity = torch.cuda.get_device_properties(
|
| 1590 |
|
| 1591 |
joint_pass = boost ==1 #and profile != 1 and profile != 3
|
| 1592 |
# TeaCache
|
|
@@ -1615,6 +1652,17 @@ def generate_video(
|
|
| 1615 |
else:
|
| 1616 |
raise gr.Error("Teacache not supported for this model")
|
| 1617 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1618 |
import random
|
| 1619 |
if seed == None or seed <0:
|
| 1620 |
seed = random.randint(0, 999999999)
|
|
@@ -1673,8 +1721,8 @@ def generate_video(
|
|
| 1673 |
if image2video:
|
| 1674 |
samples = wan_model.generate(
|
| 1675 |
prompt,
|
| 1676 |
-
|
| 1677 |
-
|
| 1678 |
frame_num=(video_length // 4)* 4 + 1,
|
| 1679 |
max_area=MAX_AREA_CONFIGS[resolution],
|
| 1680 |
shift=flow_shift,
|
|
@@ -1697,6 +1745,9 @@ def generate_video(
|
|
| 1697 |
else:
|
| 1698 |
samples = wan_model.generate(
|
| 1699 |
prompt,
|
|
|
|
|
|
|
|
|
|
| 1700 |
frame_num=(video_length // 4)* 4 + 1,
|
| 1701 |
size=(width, height),
|
| 1702 |
shift=flow_shift,
|
|
@@ -1745,7 +1796,7 @@ def generate_video(
|
|
| 1745 |
new_error = "The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames."
|
| 1746 |
else:
|
| 1747 |
new_error = gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
|
| 1748 |
-
tb = traceback.format_exc().split('\n')[:-
|
| 1749 |
print('\n'.join(tb))
|
| 1750 |
raise gr.Error(new_error, print_exception= False)
|
| 1751 |
|
|
@@ -1799,7 +1850,7 @@ def generate_video(
|
|
| 1799 |
|
| 1800 |
if exp > 0:
|
| 1801 |
from rife.inference import temporal_interpolation
|
| 1802 |
-
sample = temporal_interpolation( os.path.join("ckpts", "flownet.pkl"), sample, exp, device=
|
| 1803 |
fps = fps * 2**exp
|
| 1804 |
|
| 1805 |
if len(spatial_upsampling) > 0:
|
|
@@ -1831,8 +1882,7 @@ def generate_video(
|
|
| 1831 |
normalize=True,
|
| 1832 |
value_range=(-1, 1))
|
| 1833 |
|
| 1834 |
-
|
| 1835 |
-
configs = get_settings_dict(state, image2video, prompt, 0 if image_to_end == None else 1 , video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
|
| 1836 |
loras_mult_choices, tea_cache , tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end, cfg_star_switch, cfg_zero_step)
|
| 1837 |
|
| 1838 |
metadata_choice = server_config.get("metadata_choice","metadata")
|
|
@@ -2294,7 +2344,7 @@ def switch_advanced(state, new_advanced, lset_name):
|
|
| 2294 |
return gr.Row(visible=new_advanced), gr.Row(visible=True), gr.Button(visible=True), gr.Row(visible= False), gr.Dropdown(choices=lset_choices, value= lset_name)
|
| 2295 |
|
| 2296 |
|
| 2297 |
-
def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
|
| 2298 |
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step):
|
| 2299 |
|
| 2300 |
loras = state["loras"]
|
|
@@ -2330,18 +2380,22 @@ def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resol
|
|
| 2330 |
ui_settings["type"] = "Wan2.1GP by DeepBeepMeep - image2video"
|
| 2331 |
ui_settings["image_prompt_type"] = image_prompt_type
|
| 2332 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2333 |
ui_settings["type"] = "Wan2.1GP by DeepBeepMeep - text2video"
|
| 2334 |
|
| 2335 |
return ui_settings
|
| 2336 |
|
| 2337 |
-
def save_settings(state, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
|
| 2338 |
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step):
|
| 2339 |
|
| 2340 |
if state.get("validate_success",0) != 1:
|
| 2341 |
return
|
| 2342 |
|
| 2343 |
image2video = state["image2video"]
|
| 2344 |
-
ui_defaults = get_settings_dict(state, image2video, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
|
| 2345 |
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step)
|
| 2346 |
|
| 2347 |
defaults_filename = get_settings_file_name(image2video)
|
|
@@ -2379,6 +2433,25 @@ def download_loras():
|
|
| 2379 |
writer.write(f"Loras downloaded on the {dt} at {time.time()} on the {time.time()}")
|
| 2380 |
return
|
| 2381 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2382 |
def generate_video_tab(image2video=False):
|
| 2383 |
filename = transformer_filename_i2v if image2video else transformer_filename_t2v
|
| 2384 |
ui_defaults= get_default_settings(filename, image2video)
|
|
@@ -2387,6 +2460,7 @@ def generate_video_tab(image2video=False):
|
|
| 2387 |
|
| 2388 |
state_dict["advanced"] = advanced
|
| 2389 |
state_dict["loras_model"] = filename
|
|
|
|
| 2390 |
state_dict["image2video"] = image2video
|
| 2391 |
gen = dict()
|
| 2392 |
gen["queue"] = []
|
|
@@ -2461,31 +2535,51 @@ def generate_video_tab(image2video=False):
|
|
| 2461 |
save_lset_btn = gr.Button("Save", size="sm", min_width= 1)
|
| 2462 |
delete_lset_btn = gr.Button("Delete", size="sm", min_width= 1)
|
| 2463 |
cancel_lset_btn = gr.Button("Don't do it !", size="sm", min_width= 1 , visible=False)
|
| 2464 |
-
video_to_continue = gr.Video(label= "Video to continue", visible= image2video and False) #######
|
| 2465 |
-
image_prompt_type= ui_defaults.get("image_prompt_type",0)
|
| 2466 |
-
image_prompt_type_radio = gr.Radio( [("Use only a Start Image", 0),("Use both a Start and an End Image", 1)], value =image_prompt_type, label="Location", show_label= False, scale= 3, visible=image2video)
|
| 2467 |
-
|
| 2468 |
-
if args.multiple_images:
|
| 2469 |
-
image_to_continue = gr.Gallery(
|
| 2470 |
-
label="Images as starting points for new videos", type ="pil", #file_types= "image",
|
| 2471 |
-
columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=image2video)
|
| 2472 |
-
else:
|
| 2473 |
-
image_to_continue = gr.Image(label= "Image as a starting point for a new video", type ="pil", visible=image2video)
|
| 2474 |
|
| 2475 |
-
|
| 2476 |
-
|
| 2477 |
-
|
| 2478 |
-
|
| 2479 |
-
|
| 2480 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2481 |
|
| 2482 |
-
|
| 2483 |
-
|
| 2484 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2485 |
else:
|
| 2486 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2487 |
|
| 2488 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2489 |
|
| 2490 |
|
| 2491 |
advanced_prompt = advanced
|
|
@@ -2518,7 +2612,6 @@ def generate_video_tab(image2video=False):
|
|
| 2518 |
wizard_prompt = gr.Textbox(visible = not advanced_prompt, label="Prompts (each new line of prompt will generate a new video, # lines = comments)", value=default_wizard_prompt, lines=3)
|
| 2519 |
wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False)
|
| 2520 |
wizard_variables_var = gr.Text(wizard_variables, visible = False)
|
| 2521 |
-
state = gr.State(state_dict)
|
| 2522 |
with gr.Row():
|
| 2523 |
if image2video:
|
| 2524 |
resolution = gr.Dropdown(
|
|
@@ -2555,8 +2648,6 @@ def generate_video_tab(image2video=False):
|
|
| 2555 |
video_length = gr.Slider(5, 193, value=ui_defaults["video_length"], step=4, label="Number of frames (16 = 1s)")
|
| 2556 |
with gr.Column():
|
| 2557 |
num_inference_steps = gr.Slider(1, 100, value=ui_defaults["num_inference_steps"], step=1, label="Number of Inference Steps")
|
| 2558 |
-
with gr.Row():
|
| 2559 |
-
max_frames = gr.Slider(1, 100, value=9, step=1, label="Number of input frames to use for Video2World prediction", visible=image2video and False) #########
|
| 2560 |
show_advanced = gr.Checkbox(label="Advanced Mode", value=advanced)
|
| 2561 |
with gr.Row(visible=advanced) as advanced_row:
|
| 2562 |
with gr.Column():
|
|
@@ -2605,7 +2696,7 @@ def generate_video_tab(image2video=False):
|
|
| 2605 |
tea_cache_start_step_perc = gr.Slider(0, 100, value=ui_defaults["tea_cache_start_step_perc"], step=1, label="Tea Cache starting moment in % of generation")
|
| 2606 |
|
| 2607 |
with gr.Row():
|
| 2608 |
-
gr.Markdown("<B>Upsampling</B>")
|
| 2609 |
with gr.Row():
|
| 2610 |
temporal_upsampling_choice = gr.Dropdown(
|
| 2611 |
choices=[
|
|
@@ -2687,9 +2778,10 @@ def generate_video_tab(image2video=False):
|
|
| 2687 |
show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
|
| 2688 |
fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
|
| 2689 |
with gr.Column():
|
| 2690 |
-
gen_status = gr.Text(
|
| 2691 |
-
full_sync = gr.Text(
|
| 2692 |
-
light_sync = gr.Text(
|
|
|
|
| 2693 |
gen_progress_html = gr.HTML(
|
| 2694 |
label="Status",
|
| 2695 |
value="Idle",
|
|
@@ -2709,8 +2801,8 @@ def generate_video_tab(image2video=False):
|
|
| 2709 |
abort_btn = gr.Button("Abort")
|
| 2710 |
|
| 2711 |
queue_df = gr.DataFrame(
|
| 2712 |
-
headers=["Qty","Prompt", "Length","Steps","
|
| 2713 |
-
datatype=[ "str","markdown","str", "markdown", "markdown", "markdown",
|
| 2714 |
column_widths= ["50","", "65","55", "60", "60", "30", "30", "35"],
|
| 2715 |
interactive=False,
|
| 2716 |
col_count=(9, "fixed"),
|
|
@@ -2792,7 +2884,7 @@ def generate_video_tab(image2video=False):
|
|
| 2792 |
show_progress="hidden"
|
| 2793 |
)
|
| 2794 |
save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
|
| 2795 |
-
save_settings, inputs = [state, prompt, image_prompt_type_radio, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt,
|
| 2796 |
loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling_choice, spatial_upsampling_choice, RIFLEx_setting, slg_switch, slg_layers,
|
| 2797 |
slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step ], outputs = [])
|
| 2798 |
save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
|
|
@@ -2808,21 +2900,30 @@ def generate_video_tab(image2video=False):
|
|
| 2808 |
refresh_lora_btn2.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
|
| 2809 |
output.select(select_video, state, None )
|
| 2810 |
|
|
|
|
|
|
|
| 2811 |
gen_status.change(refresh_gallery,
|
| 2812 |
inputs = [state, gen_status],
|
| 2813 |
outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn])
|
| 2814 |
|
| 2815 |
-
full_sync.change(
|
|
|
|
|
|
|
|
|
|
| 2816 |
inputs = [state, gen_status],
|
| 2817 |
outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn]
|
| 2818 |
-
).then(
|
| 2819 |
inputs= [state],
|
| 2820 |
outputs =[gen_status],
|
| 2821 |
).then(finalize_generation,
|
| 2822 |
inputs= [state],
|
| 2823 |
outputs= [output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info]
|
| 2824 |
)
|
| 2825 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2826 |
inputs = [state, gen_status],
|
| 2827 |
outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn]
|
| 2828 |
)
|
|
@@ -2848,10 +2949,11 @@ def generate_video_tab(image2video=False):
|
|
| 2848 |
loras_choices,
|
| 2849 |
loras_mult_choices,
|
| 2850 |
image_prompt_type_radio,
|
| 2851 |
-
|
| 2852 |
-
|
| 2853 |
-
|
| 2854 |
max_frames,
|
|
|
|
| 2855 |
temporal_upsampling_choice,
|
| 2856 |
spatial_upsampling_choice,
|
| 2857 |
RIFLEx_setting,
|
|
@@ -2902,7 +3004,7 @@ def generate_video_tab(image2video=False):
|
|
| 2902 |
)
|
| 2903 |
return loras_column, loras_choices, presets_column, lset_name, header, light_sync, full_sync, state
|
| 2904 |
|
| 2905 |
-
def
|
| 2906 |
with gr.Row():
|
| 2907 |
with gr.Row(scale =2):
|
| 2908 |
gr.Markdown("<I>Wan2GP's Lora Festival ! Press the following button to download i2v <B>Remade</B> Loras collection (and bonuses Loras).")
|
|
@@ -2928,6 +3030,7 @@ def generate_configuration_tab():
|
|
| 2928 |
("WAN 2.1 1.3B Text to Video 16 bits (recommended)- the small model for fast generations with low VRAM requirements", 0),
|
| 2929 |
("WAN 2.1 14B Text to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 1),
|
| 2930 |
("WAN 2.1 14B Text to Video quantized to 8 bits (recommended) - the default engine but quantized", 2),
|
|
|
|
| 2931 |
],
|
| 2932 |
value= index,
|
| 2933 |
label="Transformer model for Text to Video",
|
|
@@ -3108,16 +3211,17 @@ def on_tab_select(global_state, t2v_state, i2v_state, evt: gr.SelectData):
|
|
| 3108 |
t2v_light_sync = gr.Text()
|
| 3109 |
i2v_full_sync = gr.Text()
|
| 3110 |
t2v_full_sync = gr.Text()
|
| 3111 |
-
if new_t2v or new_i2v:
|
| 3112 |
-
last_tab_was_image2video =global_state.get("last_tab_was_image2video", None)
|
| 3113 |
-
if last_tab_was_image2video == None or last_tab_was_image2video:
|
| 3114 |
-
gen = i2v_state["gen"]
|
| 3115 |
-
t2v_state["gen"] = gen
|
| 3116 |
-
else:
|
| 3117 |
-
gen = t2v_state["gen"]
|
| 3118 |
-
i2v_state["gen"] = gen
|
| 3119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3120 |
|
|
|
|
|
|
|
| 3121 |
if last_tab_was_image2video != None and new_t2v != new_i2v:
|
| 3122 |
gen_location = gen.get("location", None)
|
| 3123 |
if "in_progress" in gen and gen_location !=None and not (gen_location and new_i2v or not gen_location and new_t2v) :
|
|
@@ -3131,7 +3235,6 @@ def on_tab_select(global_state, t2v_state, i2v_state, evt: gr.SelectData):
|
|
| 3131 |
else:
|
| 3132 |
t2v_light_sync = gr.Text(str(time.time()))
|
| 3133 |
|
| 3134 |
-
|
| 3135 |
global_state["last_tab_was_image2video"] = new_i2v
|
| 3136 |
|
| 3137 |
if(server_config.get("reload_model",2) == 1):
|
|
@@ -3433,7 +3536,7 @@ def create_demo():
|
|
| 3433 |
}
|
| 3434 |
"""
|
| 3435 |
with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md"), title= "Wan2GP") as demo:
|
| 3436 |
-
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP>
|
| 3437 |
gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !</FONT>")
|
| 3438 |
|
| 3439 |
with gr.Accordion("Click here for some Info on how to use Wan2GP", open = False):
|
|
@@ -3454,7 +3557,7 @@ def create_demo():
|
|
| 3454 |
i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header, i2v_light_sync, i2v_full_sync, i2v_state = generate_video_tab(True)
|
| 3455 |
if not args.lock_config:
|
| 3456 |
with gr.Tab("Downloads", id="downloads") as downloads_tab:
|
| 3457 |
-
|
| 3458 |
with gr.Tab("Configuration"):
|
| 3459 |
generate_configuration_tab()
|
| 3460 |
with gr.Tab("About"):
|
|
|
|
| 14 |
import random
|
| 15 |
import json
|
| 16 |
import wan
|
| 17 |
+
from wan.configs import MAX_AREA_CONFIGS, WAN_CONFIGS, SUPPORTED_SIZES, VACE_SIZE_CONFIGS
|
| 18 |
from wan.utils.utils import cache_video
|
| 19 |
from wan.modules.attention import get_attention_modes, get_supported_attention_modes
|
| 20 |
import torch
|
|
|
|
| 55 |
def pil_to_base64_uri(pil_image, format="png", quality=75):
|
| 56 |
if pil_image is None:
|
| 57 |
return None
|
| 58 |
+
|
| 59 |
+
if isinstance(pil_image, str):
|
| 60 |
+
from wan.utils.utils import get_video_frame
|
| 61 |
+
pil_image = get_video_frame(pil_image, 0)
|
| 62 |
+
|
| 63 |
buffer = io.BytesIO()
|
| 64 |
try:
|
| 65 |
img_to_save = pil_image
|
|
|
|
| 98 |
loras_choices,
|
| 99 |
loras_mult_choices,
|
| 100 |
image_prompt_type,
|
| 101 |
+
image_source1,
|
| 102 |
+
image_source2,
|
| 103 |
+
image_source3,
|
| 104 |
max_frames,
|
| 105 |
+
remove_background_image_ref,
|
| 106 |
temporal_upsampling,
|
| 107 |
spatial_upsampling,
|
| 108 |
RIFLEx_setting,
|
|
|
|
| 133 |
return
|
| 134 |
|
| 135 |
file_model_needed = model_needed(image2video)
|
| 136 |
+
width, height = resolution.split("x")
|
| 137 |
+
width, height = int(width), int(height)
|
| 138 |
if image2video:
|
|
|
|
|
|
|
| 139 |
|
| 140 |
if "480p" in file_model_needed and not "Fun" in file_model_needed and width * height > 848*480:
|
| 141 |
gr.Info("You must use the 720P image to video model to generate videos with a resolution equivalent to 720P")
|
|
|
|
| 149 |
gr.Info("You must use the 14B model to generate videos with a resolution equivalent to 720P")
|
| 150 |
return
|
| 151 |
|
| 152 |
+
if not image2video:
|
| 153 |
+
if "Vace" in file_model_needed and "1.3B" in file_model_needed :
|
| 154 |
+
resolution_reformated = str(height) + "*" + str(width)
|
| 155 |
+
if not resolution_reformated in VACE_SIZE_CONFIGS:
|
| 156 |
+
res = VACE_SIZE_CONFIGS.keys().join(" and ")
|
| 157 |
+
gr.Info(f"Video Resolution for Vace model is not supported. Only {res} resolutions are allowed.")
|
| 158 |
+
return
|
| 159 |
+
|
| 160 |
+
if not "I" in image_prompt_type:
|
| 161 |
+
image_source1 = None
|
| 162 |
+
if not "V" in image_prompt_type:
|
| 163 |
+
image_source2 = None
|
| 164 |
+
if not "M" in image_prompt_type:
|
| 165 |
+
image_source3 = None
|
| 166 |
+
|
| 167 |
+
if isinstance(image_source1, list):
|
| 168 |
+
image_source1 = [ convert_image(tup[0]) for tup in image_source1 ]
|
| 169 |
+
|
| 170 |
+
from wan.utils.utils import resize_and_remove_background
|
| 171 |
+
image_source1 = resize_and_remove_background(image_source1, width, height, remove_background_image_ref ==1)
|
| 172 |
+
|
| 173 |
+
image_source1 = [ image_source1 ] * len(prompts)
|
| 174 |
+
image_source2 = [ image_source2 ] * len(prompts)
|
| 175 |
+
image_source3 = [ image_source3 ] * len(prompts)
|
| 176 |
+
|
| 177 |
+
else:
|
| 178 |
+
if image_source1 == None or isinstance(image_source1, list) and len(image_source1) == 0:
|
| 179 |
return
|
| 180 |
if image_prompt_type == 0:
|
| 181 |
+
image_source2 = None
|
| 182 |
+
if isinstance(image_source1, list):
|
| 183 |
+
image_source1 = [ convert_image(tup[0]) for tup in image_source1 ]
|
| 184 |
else:
|
| 185 |
+
image_source1 = [convert_image(image_source1)]
|
| 186 |
+
if image_source2 != None:
|
| 187 |
+
if isinstance(image_source2 , list):
|
| 188 |
+
image_source2 = [ convert_image(tup[0]) for tup in image_source2 ]
|
| 189 |
else:
|
| 190 |
+
image_source2 = [convert_image(image_source2) ]
|
| 191 |
+
if len(image_source1) != len(image_source2):
|
| 192 |
gr.Info("The number of start and end images should be the same ")
|
| 193 |
return
|
| 194 |
|
| 195 |
if multi_images_gen_type == 0:
|
| 196 |
new_prompts = []
|
| 197 |
+
new_image_source1 = []
|
| 198 |
+
new_image_source2 = []
|
| 199 |
+
for i in range(len(prompts) * len(image_source1) ):
|
| 200 |
new_prompts.append( prompts[ i % len(prompts)] )
|
| 201 |
+
new_image_source1.append(image_source1[i // len(prompts)] )
|
| 202 |
+
if image_source2 != None:
|
| 203 |
+
new_image_source2.append(image_source2[i // len(prompts)] )
|
| 204 |
prompts = new_prompts
|
| 205 |
+
image_source1 = new_image_source1
|
| 206 |
+
if image_source2 != None:
|
| 207 |
+
image_source2 = new_image_source2
|
| 208 |
else:
|
| 209 |
+
if len(prompts) >= len(image_source1):
|
| 210 |
+
if len(prompts) % len(image_source1) !=0:
|
| 211 |
raise gr.Error("If there are more text prompts than input images the number of text prompts should be dividable by the number of images")
|
| 212 |
+
rep = len(prompts) // len(image_source1)
|
| 213 |
+
new_image_source1 = []
|
| 214 |
+
new_image_source2 = []
|
| 215 |
for i, _ in enumerate(prompts):
|
| 216 |
+
new_image_source1.append(image_source1[i//rep] )
|
| 217 |
+
if image_source2 != None:
|
| 218 |
+
new_image_source2.append(image_source2[i//rep] )
|
| 219 |
+
image_source1 = new_image_source1
|
| 220 |
+
if image_source2 != None:
|
| 221 |
+
image_source2 = new_image_source2
|
| 222 |
else:
|
| 223 |
+
if len(image_source1) % len(prompts) !=0:
|
| 224 |
raise gr.Error("If there are more input images than text prompts the number of images should be dividable by the number of text prompts")
|
| 225 |
+
rep = len(image_source1) // len(prompts)
|
| 226 |
new_prompts = []
|
| 227 |
+
for i, _ in enumerate(image_source1):
|
| 228 |
new_prompts.append( prompts[ i//rep] )
|
| 229 |
prompts = new_prompts
|
| 230 |
|
| 231 |
+
|
| 232 |
+
if image_source1 == None:
|
| 233 |
+
image_source1 = [None] * len(prompts)
|
| 234 |
+
if image_source2 == None:
|
| 235 |
+
image_source2 = [None] * len(prompts)
|
| 236 |
+
if image_source3 == None:
|
| 237 |
+
image_source3 = [None] * len(prompts)
|
| 238 |
+
|
| 239 |
+
for single_prompt, image_source1, image_source2, image_source3 in zip(prompts, image_source1, image_source2, image_source3) :
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
kwargs = {
|
| 241 |
"prompt" : single_prompt,
|
| 242 |
"negative_prompt" : negative_prompt,
|
|
|
|
| 254 |
"loras_choices" : loras_choices,
|
| 255 |
"loras_mult_choices" : loras_mult_choices,
|
| 256 |
"image_prompt_type" : image_prompt_type,
|
| 257 |
+
"image_source1": image_source1,
|
| 258 |
+
"image_source2" : image_source2,
|
| 259 |
+
"image_source3" : image_source3 ,
|
| 260 |
"max_frames" : max_frames,
|
| 261 |
+
"remove_background_image_ref" : remove_background_image_ref,
|
| 262 |
"temporal_upsampling" : temporal_upsampling,
|
| 263 |
"spatial_upsampling" : spatial_upsampling,
|
| 264 |
"RIFLEx_setting" : RIFLEx_setting,
|
|
|
|
| 289 |
queue = gen["queue"]
|
| 290 |
task_id += 1
|
| 291 |
current_task_id = task_id
|
| 292 |
+
start_image_data = kwargs["image_source1"]
|
| 293 |
+
start_image_data = [start_image_data] if not isinstance(start_image_data, list) else start_image_data
|
| 294 |
+
end_image_data = kwargs["image_source2"]
|
| 295 |
|
| 296 |
queue.append({
|
| 297 |
"id": current_task_id,
|
|
|
|
| 303 |
"prompt": kwargs["prompt"],
|
| 304 |
"start_image_data": start_image_data,
|
| 305 |
"end_image_data": end_image_data,
|
| 306 |
+
"start_image_data_base64": [pil_to_base64_uri(img, format="jpeg", quality=70) for img in start_image_data],
|
| 307 |
"end_image_data_base64": pil_to_base64_uri(end_image_data, format="jpeg", quality=70)
|
| 308 |
})
|
| 309 |
return update_queue_data(queue)
|
|
|
|
| 370 |
full_prompt = item['prompt'].replace('"', '"')
|
| 371 |
prompt_cell = f'<span title="{full_prompt}">{truncated_prompt}</span>'
|
| 372 |
start_img_uri =item.get('start_image_data_base64')
|
| 373 |
+
start_img_uri = start_img_uri[0] if start_img_uri !=None else None
|
| 374 |
end_img_uri = item.get('end_image_data_base64')
|
| 375 |
thumbnail_size = "50px"
|
| 376 |
num_steps = item.get('steps')
|
|
|
|
| 723 |
attention_modes_supported = get_supported_attention_modes()
|
| 724 |
args = _parse_args()
|
| 725 |
args.flow_reverse = True
|
| 726 |
+
processing_device = args.gpu
|
| 727 |
+
if len(processing_device) == 0:
|
| 728 |
+
processing_device ="cuda"
|
| 729 |
# torch.backends.cuda.matmul.allow_fp16_accumulation = True
|
| 730 |
lock_ui_attention = False
|
| 731 |
lock_ui_transformer = False
|
|
|
|
| 738 |
check_loras = args.check_loras ==1
|
| 739 |
advanced = args.advanced
|
| 740 |
|
| 741 |
+
transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_bf16.safetensors", "ckpts/wan2.1_text2video_14B_quanto_int8.safetensors", "ckpts/wan2.1_Vace_1.3B_preview_bf16.safetensors"]
|
| 742 |
transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", ]
|
| 743 |
text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors"]
|
| 744 |
|
|
|
|
| 782 |
"prompts": get_default_prompt(i2v),
|
| 783 |
"resolution": "832x480",
|
| 784 |
"video_length": 81,
|
| 785 |
+
"image_prompt_type" : 0 if i2v else "",
|
| 786 |
"num_inference_steps": 30,
|
| 787 |
"seed": -1,
|
| 788 |
"repeat_generation": 1,
|
|
|
|
| 1181 |
if "Fun" in model_filename:
|
| 1182 |
model_name = "Fun InP image2video"
|
| 1183 |
model_name += " 14B" if "14B" in model_filename else " 1.3B"
|
| 1184 |
+
elif "Vace" in model_filename:
|
| 1185 |
+
model_name = "Vace ControlNet text2video"
|
| 1186 |
+
model_name += " 14B" if "14B" in model_filename else " 1.3B"
|
| 1187 |
elif "image" in model_filename:
|
| 1188 |
model_name = "Wan2.1 image2video"
|
| 1189 |
model_name += " 720p" if "720p" in model_filename else " 480p"
|
|
|
|
| 1388 |
end_img_md = ""
|
| 1389 |
prompt = task["prompt"]
|
| 1390 |
|
| 1391 |
+
start_img_uri = task.get('start_image_data_base64')
|
| 1392 |
+
start_img_uri = start_img_uri[0] if start_img_uri !=None else None
|
| 1393 |
+
end_img_uri = task.get('end_image_data_base64')
|
| 1394 |
+
thumbnail_size = "100px"
|
| 1395 |
+
if start_img_uri:
|
| 1396 |
+
start_img_md = f'<img src="{start_img_uri}" alt="Start" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
|
| 1397 |
+
if end_img_uri:
|
| 1398 |
+
end_img_md = f'<img src="{end_img_uri}" alt="End" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
|
| 1399 |
|
| 1400 |
label = f"Prompt of Video being Generated"
|
| 1401 |
|
| 1402 |
html = "<STYLE> #PINFO, #PINFO th, #PINFO td {border: 1px solid #CCCCCC;background-color:#FFFFFF;}</STYLE><TABLE WIDTH=100% ID=PINFO ><TR><TD width=100%>" + prompt + "</TD>"
|
| 1403 |
if start_img_md != "":
|
| 1404 |
html += "<TD>" + start_img_md + "</TD>"
|
| 1405 |
+
if end_img_md != "":
|
| 1406 |
+
html += "<TD>" + end_img_md + "</TD>"
|
| 1407 |
|
| 1408 |
html += "</TR></TABLE>"
|
| 1409 |
html_output = gr.HTML(html, visible= True)
|
|
|
|
| 1454 |
new_slist.append(slist[ int(pos)])
|
| 1455 |
pos += inc
|
| 1456 |
return new_slist
|
|
|
|
| 1457 |
def convert_image(image):
    """Return *image* with its EXIF orientation applied.

    Delegates to ``PIL.ImageOps.exif_transpose``, so the caller gets back
    whatever that function returns for the given image.

    Args:
        image: the PIL image to normalise.

    Returns:
        The result of ``ImageOps.exif_transpose(image)``.
    """
    # Imported locally, as the original did, so the module-level import
    # block does not need to change.
    from PIL import ImageOps

    # The former ``typing.cast(Image, ...)`` wrapper was a runtime no-op that
    # only added a dependency on an outer ``Image`` name; the hand-rolled
    # rotation code that followed the return was unreachable and is removed.
    return ImageOps.exif_transpose(image)
|
| 1477 |
|
| 1478 |
def generate_video(
|
| 1479 |
task_id,
|
|
|
|
| 1494 |
loras_choices,
|
| 1495 |
loras_mult_choices,
|
| 1496 |
image_prompt_type,
|
| 1497 |
+
image_source1,
|
| 1498 |
+
image_source2,
|
| 1499 |
+
image_source3,
|
| 1500 |
max_frames,
|
| 1501 |
+
remove_background_image_ref,
|
| 1502 |
temporal_upsampling,
|
| 1503 |
spatial_upsampling,
|
| 1504 |
RIFLEx_setting,
|
|
|
|
| 1545 |
gr.Info(f"You have selected attention mode '{attention_mode}'. However it is not installed or supported on your system. You should either install it or switch to the default 'sdpa' attention.")
|
| 1546 |
return
|
| 1547 |
|
|
|
|
| 1548 |
|
| 1549 |
if not image2video:
|
| 1550 |
width, height = resolution.split("x")
|
|
|
|
| 1623 |
|
| 1624 |
enable_RIFLEx = RIFLEx_setting == 0 and video_length > (6* 16) or RIFLEx_setting == 1
|
| 1625 |
# VAE Tiling
|
| 1626 |
+
device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576
|
| 1627 |
|
| 1628 |
joint_pass = boost ==1 #and profile != 1 and profile != 3
|
| 1629 |
# TeaCache
|
|
|
|
| 1652 |
else:
|
| 1653 |
raise gr.Error("Teacache not supported for this model")
|
| 1654 |
|
| 1655 |
+
if "Vace" in model_filename:
|
| 1656 |
+
resolution_reformated = str(height) + "*" + str(width)
|
| 1657 |
+
src_video, src_mask, src_ref_images = wan_model.prepare_source([image_source2],
|
| 1658 |
+
[image_source3],
|
| 1659 |
+
[image_source1],
|
| 1660 |
+
video_length, VACE_SIZE_CONFIGS[resolution_reformated], "cpu",
|
| 1661 |
+
trim_video=max_frames)
|
| 1662 |
+
else:
|
| 1663 |
+
src_video, src_mask, src_ref_images = None, None, None
|
| 1664 |
+
|
| 1665 |
+
|
| 1666 |
import random
|
| 1667 |
if seed == None or seed <0:
|
| 1668 |
seed = random.randint(0, 999999999)
|
|
|
|
| 1721 |
if image2video:
|
| 1722 |
samples = wan_model.generate(
|
| 1723 |
prompt,
|
| 1724 |
+
image_source1,
|
| 1725 |
+
image_source2 if image_source2 != None else None,
|
| 1726 |
frame_num=(video_length // 4)* 4 + 1,
|
| 1727 |
max_area=MAX_AREA_CONFIGS[resolution],
|
| 1728 |
shift=flow_shift,
|
|
|
|
| 1745 |
else:
|
| 1746 |
samples = wan_model.generate(
|
| 1747 |
prompt,
|
| 1748 |
+
input_frames = src_video,
|
| 1749 |
+
input_ref_images= src_ref_images,
|
| 1750 |
+
input_masks = src_mask,
|
| 1751 |
frame_num=(video_length // 4)* 4 + 1,
|
| 1752 |
size=(width, height),
|
| 1753 |
shift=flow_shift,
|
|
|
|
| 1796 |
new_error = "The generation of the video has encountered an error: it is likely that you have unsufficient VRAM and you should therefore reduce the video resolution or its number of frames."
|
| 1797 |
else:
|
| 1798 |
new_error = gr.Error(f"The generation of the video has encountered an error, please check your terminal for more information. '{s}'")
|
| 1799 |
+
tb = traceback.format_exc().split('\n')[:-1]
|
| 1800 |
print('\n'.join(tb))
|
| 1801 |
raise gr.Error(new_error, print_exception= False)
|
| 1802 |
|
|
|
|
| 1850 |
|
| 1851 |
if exp > 0:
|
| 1852 |
from rife.inference import temporal_interpolation
|
| 1853 |
+
sample = temporal_interpolation( os.path.join("ckpts", "flownet.pkl"), sample, exp, device=processing_device)
|
| 1854 |
fps = fps * 2**exp
|
| 1855 |
|
| 1856 |
if len(spatial_upsampling) > 0:
|
|
|
|
| 1882 |
normalize=True,
|
| 1883 |
value_range=(-1, 1))
|
| 1884 |
|
| 1885 |
+
configs = get_settings_dict(state, image2video, True, prompt, image_prompt_type, max_frames , remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
|
|
|
|
| 1886 |
loras_mult_choices, tea_cache , tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end, cfg_star_switch, cfg_zero_step)
|
| 1887 |
|
| 1888 |
metadata_choice = server_config.get("metadata_choice","metadata")
|
|
|
|
| 2344 |
return gr.Row(visible=new_advanced), gr.Row(visible=True), gr.Button(visible=True), gr.Row(visible= False), gr.Dropdown(choices=lset_choices, value= lset_name)
|
| 2345 |
|
| 2346 |
|
| 2347 |
+
def get_settings_dict(state, i2v, image_metadata, prompt, image_prompt_type, max_frames, remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
|
| 2348 |
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step):
|
| 2349 |
|
| 2350 |
loras = state["loras"]
|
|
|
|
| 2380 |
ui_settings["type"] = "Wan2.1GP by DeepBeepMeep - image2video"
|
| 2381 |
ui_settings["image_prompt_type"] = image_prompt_type
|
| 2382 |
else:
|
| 2383 |
+
if "Vace" in transformer_filename_t2v or not image_metadata:
|
| 2384 |
+
ui_settings["image_prompt_type"] = image_prompt_type
|
| 2385 |
+
ui_settings["max_frames"] = max_frames
|
| 2386 |
+
ui_settings["remove_background_image_ref"] = remove_background_image_ref
|
| 2387 |
ui_settings["type"] = "Wan2.1GP by DeepBeepMeep - text2video"
|
| 2388 |
|
| 2389 |
return ui_settings
|
| 2390 |
|
| 2391 |
+
def save_settings(state, prompt, image_prompt_type, max_frames, remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
|
| 2392 |
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step):
|
| 2393 |
|
| 2394 |
if state.get("validate_success",0) != 1:
|
| 2395 |
return
|
| 2396 |
|
| 2397 |
image2video = state["image2video"]
|
| 2398 |
+
ui_defaults = get_settings_dict(state, image2video, False, prompt, image_prompt_type, max_frames, remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
|
| 2399 |
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling, spatial_upsampling, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step)
|
| 2400 |
|
| 2401 |
defaults_filename = get_settings_file_name(image2video)
|
|
|
|
| 2433 |
writer.write(f"Loras downloaded on the {dt} at {time.time()} on the {time.time()}")
|
| 2434 |
return
|
| 2435 |
|
| 2436 |
+
def refresh_i2v_image_prompt_type_radio(state, image_prompt_type_radio):
    """Toggle the visibility of the second image input of the image2video tab.

    Args:
        state: tab state; accepted because the event wiring passes it, but
            not read here.
        image_prompt_type_radio: current value of the image prompt type radio.

    Returns:
        A ``gr.Gallery`` (when ``args.multiple_images`` is truthy) or a
        ``gr.Image`` carrying only a ``visible`` flag, which is True exactly
        when the radio value equals 1.
    """
    show_second_input = image_prompt_type_radio == 1
    # The component class has to match the widget built for this launch mode.
    component_cls = gr.Gallery if args.multiple_images else gr.Image
    return component_cls(visible=show_second_input)
|
| 2441 |
+
|
| 2442 |
+
def refresh_t2v_image_prompt_type_radio(state, image_prompt_type_radio):
    """Compute visibility updates for the Vace inputs of the text2video tab.

    The radio value is treated as a set of letter flags: ``"I"`` (reference
    images), ``"V"`` (video) and ``"M"`` (mask).

    Args:
        state: tab state; ``state["image_input_type_model"]`` and
            ``state["image2video"]`` are read.
        image_prompt_type_radio: current value of the image prompt type radio.

    Returns:
        A 7-tuple of component updates, in this order: column (visible only
        for a Vace model outside image2video mode), radio (value echoed
        back), gallery (``"I"``), video (``"V"``), video (``"M"``),
        text (``"V"``), checkbox (``"I"``).
    """
    vace_model = "Vace" in state["image_input_type_model"] and not state["image2video"]

    # The radio value may not be a string (e.g. None when nothing is
    # selected); the substring tests below would then raise TypeError.
    # Treat any non-string value as "no flag set".
    flags = image_prompt_type_radio if isinstance(image_prompt_type_radio, str) else ""
    uses_images = "I" in flags
    uses_video = "V" in flags
    uses_mask = "M" in flags

    # NOTE(review): the sixth item is a gr.Text although the component that
    # shares the "V" flag elsewhere is a slider -- confirm against the
    # outputs list of the event that calls this function.
    return (
        gr.Column(visible=vace_model),
        gr.Radio(value=image_prompt_type_radio),
        gr.Gallery(visible=uses_images),
        gr.Video(visible=uses_video),
        gr.Video(visible=uses_mask),
        gr.Text(visible=uses_video),
        gr.Checkbox(visible=uses_images),
    )
|
| 2445 |
+
|
| 2446 |
+
def check_refresh_input_type(state):
    """Record a change of the text2video model and signal it to the UI.

    Args:
        state: tab state; ``state["image2video"]`` is read and
            ``state["image_input_type_model"]`` is read and possibly updated.

    Returns:
        ``gr.Text(value=<current time as a string>)`` when the tab is not in
        image2video mode and ``model_needed(False)`` differs from the model
        name stored in ``state``; the stored name is updated first. In every
        other case an argument-less ``gr.Text()``.
    """
    if state["image2video"]:
        return gr.Text()

    known_model = state["image_input_type_model"]
    required_model = model_needed(False)
    if known_model == required_model:
        return gr.Text()

    state["image_input_type_model"] = required_model
    # A fresh timestamp gives the text component a value it did not have
    # before.
    return gr.Text(value=str(time.time()))
|
| 2454 |
+
|
| 2455 |
def generate_video_tab(image2video=False):
|
| 2456 |
filename = transformer_filename_i2v if image2video else transformer_filename_t2v
|
| 2457 |
ui_defaults= get_default_settings(filename, image2video)
|
|
|
|
| 2460 |
|
| 2461 |
state_dict["advanced"] = advanced
|
| 2462 |
state_dict["loras_model"] = filename
|
| 2463 |
+
state_dict["image_input_type_model"] = filename
|
| 2464 |
state_dict["image2video"] = image2video
|
| 2465 |
gen = dict()
|
| 2466 |
gen["queue"] = []
|
|
|
|
| 2535 |
save_lset_btn = gr.Button("Save", size="sm", min_width= 1)
|
| 2536 |
delete_lset_btn = gr.Button("Delete", size="sm", min_width= 1)
|
| 2537 |
cancel_lset_btn = gr.Button("Don't do it !", size="sm", min_width= 1 , visible=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2538 |
|
| 2539 |
+
state = gr.State(state_dict)
|
| 2540 |
+
vace_model = "Vace" in filename and not image2video
|
| 2541 |
+
trigger_refresh_input_type = gr.Text(interactive= False, visible= False)
|
| 2542 |
+
with gr.Column(visible= image2video or vace_model) as image_prompt_column:
|
| 2543 |
+
if image2video:
|
| 2544 |
+
image_source3 = gr.Video(label= "Placeholder", visible= image2video and False)
|
| 2545 |
+
|
| 2546 |
+
image_prompt_type= ui_defaults.get("image_prompt_type",0)
|
| 2547 |
+
image_prompt_type_radio = gr.Radio( [("Use only a Start Image", 0),("Use both a Start and an End Image", 1)], value =image_prompt_type, label="Location", show_label= False, scale= 3)
|
| 2548 |
+
|
| 2549 |
+
if args.multiple_images:
|
| 2550 |
+
image_source1 = gr.Gallery(
|
| 2551 |
+
label="Images as starting points for new videos", type ="pil", #file_types= "image",
|
| 2552 |
+
columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True)
|
| 2553 |
+
else:
|
| 2554 |
+
image_source1 = gr.Image(label= "Image as a starting point for a new video", type ="pil")
|
| 2555 |
|
| 2556 |
+
if args.multiple_images:
|
| 2557 |
+
image_source2 = gr.Gallery(
|
| 2558 |
+
label="Images as ending points for new videos", type ="pil", #file_types= "image",
|
| 2559 |
+
columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible=image_prompt_type==1)
|
| 2560 |
+
else:
|
| 2561 |
+
image_source2 = gr.Image(label= "Last Image for a new video", type ="pil", visible=image_prompt_type==1)
|
| 2562 |
+
|
| 2563 |
+
|
| 2564 |
+
image_prompt_type_radio.change(fn=refresh_i2v_image_prompt_type_radio, inputs=[state, image_prompt_type_radio], outputs=[image_source2])
|
| 2565 |
+
max_frames = gr.Slider(1, 100,step=1, visible = False)
|
| 2566 |
+
remove_background_image_ref = gr.Text(visible = False)
|
| 2567 |
else:
|
| 2568 |
+
image_prompt_type= ui_defaults.get("image_prompt_type","I")
|
| 2569 |
+
image_prompt_type_radio = gr.Radio( [("Use Images Ref", "I"),("a Video", "V"), ("Images + a Video", "IV"), ("Video + Video Mask", "VM"), ("Images + Video + Mask", "IVM")], value =image_prompt_type, label="Location", show_label= False, scale= 3, visible = vace_model)
|
| 2570 |
+
image_source1 = gr.Gallery(
|
| 2571 |
+
label="Reference Images of Faces and / or Object to be found in the Video", type ="pil",
|
| 2572 |
+
columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in image_prompt_type )
|
| 2573 |
|
| 2574 |
+
image_source2 = gr.Video(label= "Reference Video", visible= "V" in image_prompt_type )
|
| 2575 |
+
with gr.Row():
|
| 2576 |
+
max_frames = gr.Slider(0, 100, value=ui_defaults.get("max_frames",0), step=1, label="Nb of frames in Reference Video to use in Video (0 for as many as possible)", visible= "V" in image_prompt_type, scale = 2 )
|
| 2577 |
+
remove_background_image_ref = gr.Checkbox(value=ui_defaults.get("remove_background_image_ref",1), label= "Remove Images Ref. Background", visible= "I" in image_prompt_type, scale =1 )
|
| 2578 |
+
|
| 2579 |
+
image_source3 = gr.Video(label= "Video Mask (white pixels = Mask)", visible= "M" in image_prompt_type )
|
| 2580 |
+
|
| 2581 |
+
|
| 2582 |
+
gr.on(triggers=[image_prompt_type_radio.change, trigger_refresh_input_type.change], fn=refresh_t2v_image_prompt_type_radio, inputs=[state, image_prompt_type_radio], outputs=[image_prompt_column, image_prompt_type_radio, image_source1, image_source2, image_source3, max_frames, remove_background_image_ref])
|
| 2583 |
|
| 2584 |
|
| 2585 |
advanced_prompt = advanced
|
|
|
|
| 2612 |
wizard_prompt = gr.Textbox(visible = not advanced_prompt, label="Prompts (each new line of prompt will generate a new video, # lines = comments)", value=default_wizard_prompt, lines=3)
|
| 2613 |
wizard_prompt_activated_var = gr.Text(wizard_prompt_activated, visible= False)
|
| 2614 |
wizard_variables_var = gr.Text(wizard_variables, visible = False)
|
|
|
|
| 2615 |
with gr.Row():
|
| 2616 |
if image2video:
|
| 2617 |
resolution = gr.Dropdown(
|
|
|
|
| 2648 |
video_length = gr.Slider(5, 193, value=ui_defaults["video_length"], step=4, label="Number of frames (16 = 1s)")
|
| 2649 |
with gr.Column():
|
| 2650 |
num_inference_steps = gr.Slider(1, 100, value=ui_defaults["num_inference_steps"], step=1, label="Number of Inference Steps")
|
|
|
|
|
|
|
| 2651 |
show_advanced = gr.Checkbox(label="Advanced Mode", value=advanced)
|
| 2652 |
with gr.Row(visible=advanced) as advanced_row:
|
| 2653 |
with gr.Column():
|
|
|
|
| 2696 |
tea_cache_start_step_perc = gr.Slider(0, 100, value=ui_defaults["tea_cache_start_step_perc"], step=1, label="Tea Cache starting moment in % of generation")
|
| 2697 |
|
| 2698 |
with gr.Row():
|
| 2699 |
+
gr.Markdown("<B>Upsampling - postprocessing that may improve fluidity and the size of the video</B>")
|
| 2700 |
with gr.Row():
|
| 2701 |
temporal_upsampling_choice = gr.Dropdown(
|
| 2702 |
choices=[
|
|
|
|
| 2778 |
show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
|
| 2779 |
fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
|
| 2780 |
with gr.Column():
|
| 2781 |
+
gen_status = gr.Text(interactive= False)
|
| 2782 |
+
full_sync = gr.Text(interactive= False, visible= False)
|
| 2783 |
+
light_sync = gr.Text(interactive= False, visible= False)
|
| 2784 |
+
|
| 2785 |
gen_progress_html = gr.HTML(
|
| 2786 |
label="Status",
|
| 2787 |
value="Idle",
|
|
|
|
| 2801 |
abort_btn = gr.Button("Abort")
|
| 2802 |
|
| 2803 |
queue_df = gr.DataFrame(
|
| 2804 |
+
headers=["Qty","Prompt", "Length","Steps","", "", "", "", ""],
|
| 2805 |
+
datatype=[ "str","markdown","str", "markdown", "markdown", "markdown", "str", "str", "str"],
|
| 2806 |
column_widths= ["50","", "65","55", "60", "60", "30", "30", "35"],
|
| 2807 |
interactive=False,
|
| 2808 |
col_count=(9, "fixed"),
|
|
|
|
| 2884 |
show_progress="hidden"
|
| 2885 |
)
|
| 2886 |
save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
|
| 2887 |
+
save_settings, inputs = [state, prompt, image_prompt_type_radio, max_frames, remove_background_image_ref, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt,
|
| 2888 |
loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, temporal_upsampling_choice, spatial_upsampling_choice, RIFLEx_setting, slg_switch, slg_layers,
|
| 2889 |
slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step ], outputs = [])
|
| 2890 |
save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
|
|
|
|
| 2900 |
refresh_lora_btn2.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
|
| 2901 |
output.select(select_video, state, None )
|
| 2902 |
|
| 2903 |
+
|
| 2904 |
+
|
| 2905 |
gen_status.change(refresh_gallery,
|
| 2906 |
inputs = [state, gen_status],
|
| 2907 |
outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn])
|
| 2908 |
|
| 2909 |
+
full_sync.change(fn= check_refresh_input_type,
|
| 2910 |
+
inputs= [state],
|
| 2911 |
+
outputs= [trigger_refresh_input_type]
|
| 2912 |
+
).then(fn=refresh_gallery,
|
| 2913 |
inputs = [state, gen_status],
|
| 2914 |
outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn]
|
| 2915 |
+
).then(fn=wait_tasks_done,
|
| 2916 |
inputs= [state],
|
| 2917 |
outputs =[gen_status],
|
| 2918 |
).then(finalize_generation,
|
| 2919 |
inputs= [state],
|
| 2920 |
outputs= [output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info]
|
| 2921 |
)
|
| 2922 |
+
|
| 2923 |
+
light_sync.change(fn= check_refresh_input_type,
|
| 2924 |
+
inputs= [state],
|
| 2925 |
+
outputs= [trigger_refresh_input_type]
|
| 2926 |
+
).then(fn=refresh_gallery,
|
| 2927 |
inputs = [state, gen_status],
|
| 2928 |
outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, queue_df, abort_btn]
|
| 2929 |
)
|
|
|
|
| 2949 |
loras_choices,
|
| 2950 |
loras_mult_choices,
|
| 2951 |
image_prompt_type_radio,
|
| 2952 |
+
image_source1,
|
| 2953 |
+
image_source2,
|
| 2954 |
+
image_source3,
|
| 2955 |
max_frames,
|
| 2956 |
+
remove_background_image_ref,
|
| 2957 |
temporal_upsampling_choice,
|
| 2958 |
spatial_upsampling_choice,
|
| 2959 |
RIFLEx_setting,
|
|
|
|
| 3004 |
)
|
| 3005 |
return loras_column, loras_choices, presets_column, lset_name, header, light_sync, full_sync, state
|
| 3006 |
|
| 3007 |
+
def generate_download_tab(presets_column, loras_column, lset_name,loras_choices, state):
|
| 3008 |
with gr.Row():
|
| 3009 |
with gr.Row(scale =2):
|
| 3010 |
gr.Markdown("<I>Wan2GP's Lora Festival ! Press the following button to download i2v <B>Remade</B> Loras collection (and bonuses Loras).")
|
|
|
|
| 3030 |
("WAN 2.1 1.3B Text to Video 16 bits (recommended)- the small model for fast generations with low VRAM requirements", 0),
|
| 3031 |
("WAN 2.1 14B Text to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 1),
|
| 3032 |
("WAN 2.1 14B Text to Video quantized to 8 bits (recommended) - the default engine but quantized", 2),
|
| 3033 |
+
("WAN 2.1 VACE 1.3B Text to Video / Control Net - text generation driven by reference images or videos", 3),
|
| 3034 |
],
|
| 3035 |
value= index,
|
| 3036 |
label="Transformer model for Text to Video",
|
|
|
|
| 3211 |
t2v_light_sync = gr.Text()
|
| 3212 |
i2v_full_sync = gr.Text()
|
| 3213 |
t2v_full_sync = gr.Text()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3214 |
|
| 3215 |
+
last_tab_was_image2video =global_state.get("last_tab_was_image2video", None)
|
| 3216 |
+
if last_tab_was_image2video == None or last_tab_was_image2video:
|
| 3217 |
+
gen = i2v_state["gen"]
|
| 3218 |
+
t2v_state["gen"] = gen
|
| 3219 |
+
else:
|
| 3220 |
+
gen = t2v_state["gen"]
|
| 3221 |
+
i2v_state["gen"] = gen
|
| 3222 |
|
| 3223 |
+
|
| 3224 |
+
if new_t2v or new_i2v:
|
| 3225 |
if last_tab_was_image2video != None and new_t2v != new_i2v:
|
| 3226 |
gen_location = gen.get("location", None)
|
| 3227 |
if "in_progress" in gen and gen_location !=None and not (gen_location and new_i2v or not gen_location and new_t2v) :
|
|
|
|
| 3235 |
else:
|
| 3236 |
t2v_light_sync = gr.Text(str(time.time()))
|
| 3237 |
|
|
|
|
| 3238 |
global_state["last_tab_was_image2video"] = new_i2v
|
| 3239 |
|
| 3240 |
if(server_config.get("reload_model",2) == 1):
|
|
|
|
| 3536 |
}
|
| 3537 |
"""
|
| 3538 |
with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md"), title= "Wan2GP") as demo:
|
| 3539 |
+
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v4.0 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
|
| 3540 |
gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !</FONT>")
|
| 3541 |
|
| 3542 |
with gr.Accordion("Click here for some Info on how to use Wan2GP", open = False):
|
|
|
|
| 3557 |
i2v_loras_column, i2v_loras_choices, i2v_presets_column, i2v_lset_name, i2v_header, i2v_light_sync, i2v_full_sync, i2v_state = generate_video_tab(True)
|
| 3558 |
if not args.lock_config:
|
| 3559 |
with gr.Tab("Downloads", id="downloads") as downloads_tab:
|
| 3560 |
+
generate_download_tab(i2v_presets_column, i2v_loras_column, i2v_lset_name, i2v_loras_choices, i2v_state)
|
| 3561 |
with gr.Tab("Configuration"):
|
| 3562 |
generate_configuration_tab()
|
| 3563 |
with gr.Tab("About"):
|
requirements.txt
CHANGED
|
@@ -11,11 +11,15 @@ easydict
|
|
| 11 |
ftfy
|
| 12 |
dashscope
|
| 13 |
imageio-ffmpeg
|
| 14 |
-
# flash_attn
|
| 15 |
gradio>=5.0.0
|
| 16 |
numpy>=1.23.5,<2
|
| 17 |
einops
|
| 18 |
moviepy==1.0.3
|
| 19 |
mmgp==3.3.4
|
| 20 |
peft==0.14.0
|
| 21 |
-
mutagen
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
ftfy
|
| 12 |
dashscope
|
| 13 |
imageio-ffmpeg
|
| 14 |
+
# flash_attn
|
| 15 |
gradio>=5.0.0
|
| 16 |
numpy>=1.23.5,<2
|
| 17 |
einops
|
| 18 |
moviepy==1.0.3
|
| 19 |
mmgp==3.3.4
|
| 20 |
peft==0.14.0
|
| 21 |
+
mutagen
|
| 22 |
+
decord
|
| 23 |
+
onnxruntime-gpu
|
| 24 |
+
rembg[gpu]==2.0.65
|
| 25 |
+
# rembg==2.0.65
|
wan/configs/__init__.py
CHANGED
|
@@ -40,3 +40,17 @@ SUPPORTED_SIZES = {
|
|
| 40 |
'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
|
| 41 |
't2i-14B': tuple(SIZE_CONFIGS.keys()),
|
| 42 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
|
| 41 |
't2i-14B': tuple(SIZE_CONFIGS.keys()),
|
| 42 |
}
|
| 43 |
+
|
| 44 |
+
# Sizes available to the Vace model, keyed by an "<a>*<b>" string whose value
# is the same pair as a tuple of ints.
VACE_SIZE_CONFIGS = {
    '480*832': (480, 832),
    '832*480': (832, 480),
}

# Pixel count (product of the pair) for each size key above.
VACE_MAX_AREA_CONFIGS = {
    size_key: first * second
    for size_key, (first, second) in VACE_SIZE_CONFIGS.items()
}

# Size keys allowed per Vace model variant, in declaration order.
VACE_SUPPORTED_SIZES = {
    'vace-1.3B': tuple(VACE_SIZE_CONFIGS),
}
|
wan/modules/model.py
CHANGED
|
@@ -377,6 +377,7 @@ class WanI2VCrossAttention(WanSelfAttention):
|
|
| 377 |
return x
|
| 378 |
|
| 379 |
|
|
|
|
| 380 |
WAN_CROSSATTENTION_CLASSES = {
|
| 381 |
't2v_cross_attn': WanT2VCrossAttention,
|
| 382 |
'i2v_cross_attn': WanI2VCrossAttention,
|
|
@@ -393,7 +394,9 @@ class WanAttentionBlock(nn.Module):
|
|
| 393 |
window_size=(-1, -1),
|
| 394 |
qk_norm=True,
|
| 395 |
cross_attn_norm=False,
|
| 396 |
-
eps=1e-6
|
|
|
|
|
|
|
| 397 |
super().__init__()
|
| 398 |
self.dim = dim
|
| 399 |
self.ffn_dim = ffn_dim
|
|
@@ -422,6 +425,7 @@ class WanAttentionBlock(nn.Module):
|
|
| 422 |
|
| 423 |
# modulation
|
| 424 |
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
|
|
|
| 425 |
|
| 426 |
def forward(
|
| 427 |
self,
|
|
@@ -432,6 +436,8 @@ class WanAttentionBlock(nn.Module):
|
|
| 432 |
freqs,
|
| 433 |
context,
|
| 434 |
context_lens,
|
|
|
|
|
|
|
| 435 |
):
|
| 436 |
r"""
|
| 437 |
Args:
|
|
@@ -480,10 +486,49 @@ class WanAttentionBlock(nn.Module):
|
|
| 480 |
x.addcmul_(y, e[5])
|
| 481 |
|
| 482 |
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 487 |
class Head(nn.Module):
|
| 488 |
|
| 489 |
def __init__(self, dim, out_dim, patch_size, eps=1e-6):
|
|
@@ -544,6 +589,8 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
| 544 |
|
| 545 |
@register_to_config
|
| 546 |
def __init__(self,
|
|
|
|
|
|
|
| 547 |
model_type='t2v',
|
| 548 |
patch_size=(1, 2, 2),
|
| 549 |
text_len=512,
|
|
@@ -628,12 +675,13 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
| 628 |
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
|
| 629 |
|
| 630 |
# blocks
|
| 631 |
-
|
| 632 |
-
|
| 633 |
-
|
| 634 |
-
|
| 635 |
-
|
| 636 |
-
|
|
|
|
| 637 |
|
| 638 |
# head
|
| 639 |
self.head = Head(dim, out_dim, patch_size, eps)
|
|
@@ -646,6 +694,33 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
| 646 |
# initialize weights
|
| 647 |
self.init_weights()
|
| 648 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 649 |
|
| 650 |
def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
|
| 651 |
rescale_func = np.poly1d(self.coefficients)
|
|
@@ -688,6 +763,36 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
| 688 |
self.rel_l1_thresh = best_threshold
|
| 689 |
print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
|
| 690 |
return best_threshold
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 691 |
|
| 692 |
def forward(
|
| 693 |
self,
|
|
@@ -695,6 +800,8 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
| 695 |
t,
|
| 696 |
context,
|
| 697 |
seq_len,
|
|
|
|
|
|
|
| 698 |
clip_fea=None,
|
| 699 |
y=None,
|
| 700 |
freqs = None,
|
|
@@ -829,13 +936,23 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
| 829 |
self.previous_residual_cond = None
|
| 830 |
ori_hidden_states = x_list[0].clone()
|
| 831 |
# arguments
|
|
|
|
| 832 |
kwargs = dict(
|
| 833 |
-
# e=e0,
|
| 834 |
seq_lens=seq_lens,
|
| 835 |
grid_sizes=grid_sizes,
|
| 836 |
freqs=freqs,
|
| 837 |
-
# context=context,
|
| 838 |
context_lens=context_lens)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 839 |
for block_idx, block in enumerate(self.blocks):
|
| 840 |
offload.shared_state["layer"] = block_idx
|
| 841 |
if callback != None:
|
|
@@ -852,9 +969,10 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
| 852 |
x_list[0] = block(x_list[0], context = context_list[0], e= e0, **kwargs)
|
| 853 |
|
| 854 |
else:
|
| 855 |
-
for i, (x, context) in enumerate(zip(x_list, context_list)):
|
| 856 |
-
x_list[i] = block(x, context = context, e= e0, **kwargs)
|
| 857 |
del x
|
|
|
|
| 858 |
|
| 859 |
if self.enable_teacache:
|
| 860 |
if joint_pass:
|
|
|
|
| 377 |
return x
|
| 378 |
|
| 379 |
|
| 380 |
+
|
| 381 |
WAN_CROSSATTENTION_CLASSES = {
|
| 382 |
't2v_cross_attn': WanT2VCrossAttention,
|
| 383 |
'i2v_cross_attn': WanI2VCrossAttention,
|
|
|
|
| 394 |
window_size=(-1, -1),
|
| 395 |
qk_norm=True,
|
| 396 |
cross_attn_norm=False,
|
| 397 |
+
eps=1e-6,
|
| 398 |
+
block_id=None
|
| 399 |
+
):
|
| 400 |
super().__init__()
|
| 401 |
self.dim = dim
|
| 402 |
self.ffn_dim = ffn_dim
|
|
|
|
| 425 |
|
| 426 |
# modulation
|
| 427 |
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
| 428 |
+
self.block_id = block_id
|
| 429 |
|
| 430 |
def forward(
|
| 431 |
self,
|
|
|
|
| 436 |
freqs,
|
| 437 |
context,
|
| 438 |
context_lens,
|
| 439 |
+
hints= None,
|
| 440 |
+
context_scale=1.0,
|
| 441 |
):
|
| 442 |
r"""
|
| 443 |
Args:
|
|
|
|
| 486 |
x.addcmul_(y, e[5])
|
| 487 |
|
| 488 |
|
| 489 |
+
if self.block_id is not None and hints != None:
|
| 490 |
+
if context_scale == 1:
|
| 491 |
+
x.add_(hints[self.block_id])
|
| 492 |
+
else:
|
| 493 |
+
x.add_(hints[self.block_id], alpha =context_scale)
|
| 494 |
+
return x
|
| 495 |
+
|
| 496 |
+
class VaceWanAttentionBlock(WanAttentionBlock):
    """Attention block used by the VACE control branch.

    On top of the parent block it owns zero-initialised linear projections:
    ``after_proj`` on every block and ``before_proj`` on block 0 only.
    ``forward`` threads a running Python list through the chain of blocks;
    the last list entry is always the current control stream and the entries
    before it are the projected outputs collected so far.
    """

    def __init__(
        self,
        cross_attn_type,
        dim,
        ffn_dim,
        num_heads,
        window_size=(-1, -1),
        qk_norm=True,
        cross_attn_norm=False,
        eps=1e-6,
        block_id=0
    ):
        super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
        self.block_id = block_id

        def zero_linear():
            # dim -> dim projection whose weight and bias both start at zero.
            layer = nn.Linear(self.dim, self.dim)
            nn.init.zeros_(layer.weight)
            nn.init.zeros_(layer.bias)
            return layer

        # Only the first block merges the raw control tokens with `x`.
        if block_id == 0:
            self.before_proj = zero_linear()
        self.after_proj = zero_linear()

    def forward(self, c, x, **kwargs):
        """Run one control block.

        Args:
            c: for block 0, the control tokens; for every other block, the
                list returned by the previous block (modified in place).
            x: tensor added to the projected control tokens in block 0;
                ignored by the other blocks.
            **kwargs: forwarded unchanged to the parent block's ``forward``.

        Returns:
            The running list with this block's ``after_proj`` output and the
            updated control stream appended, in that order.
        """
        if self.block_id != 0:
            # Reuse the caller's list and take the current stream off its end.
            all_c = c
            c = all_c.pop(-1)
        else:
            c = self.before_proj(c) + x
            all_c = []
        c = super().forward(c, **kwargs)
        skip = self.after_proj(c)
        all_c.extend((skip, c))
        return all_c
|
| 531 |
+
|
| 532 |
class Head(nn.Module):
|
| 533 |
|
| 534 |
def __init__(self, dim, out_dim, patch_size, eps=1e-6):
|
|
|
|
| 589 |
|
| 590 |
@register_to_config
|
| 591 |
def __init__(self,
|
| 592 |
+
vace_layers=None,
|
| 593 |
+
vace_in_dim=None,
|
| 594 |
model_type='t2v',
|
| 595 |
patch_size=(1, 2, 2),
|
| 596 |
text_len=512,
|
|
|
|
| 675 |
self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
|
| 676 |
|
| 677 |
# blocks
|
| 678 |
+
if vace_layers == None:
|
| 679 |
+
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
|
| 680 |
+
self.blocks = nn.ModuleList([
|
| 681 |
+
WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
|
| 682 |
+
window_size, qk_norm, cross_attn_norm, eps)
|
| 683 |
+
for _ in range(num_layers)
|
| 684 |
+
])
|
| 685 |
|
| 686 |
# head
|
| 687 |
self.head = Head(dim, out_dim, patch_size, eps)
|
|
|
|
| 694 |
# initialize weights
|
| 695 |
self.init_weights()
|
| 696 |
|
| 697 |
+
if vace_layers != None:
|
| 698 |
+
self.vace_layers = [i for i in range(0, self.num_layers, 2)] if vace_layers is None else vace_layers
|
| 699 |
+
self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim
|
| 700 |
+
|
| 701 |
+
assert 0 in self.vace_layers
|
| 702 |
+
self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}
|
| 703 |
+
|
| 704 |
+
# blocks
|
| 705 |
+
self.blocks = nn.ModuleList([
|
| 706 |
+
WanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
|
| 707 |
+
self.cross_attn_norm, self.eps,
|
| 708 |
+
block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None)
|
| 709 |
+
for i in range(self.num_layers)
|
| 710 |
+
])
|
| 711 |
+
|
| 712 |
+
# vace blocks
|
| 713 |
+
self.vace_blocks = nn.ModuleList([
|
| 714 |
+
VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
|
| 715 |
+
self.cross_attn_norm, self.eps, block_id=i)
|
| 716 |
+
for i in self.vace_layers
|
| 717 |
+
])
|
| 718 |
+
|
| 719 |
+
# vace patch embeddings
|
| 720 |
+
self.vace_patch_embedding = nn.Conv3d(
|
| 721 |
+
self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size
|
| 722 |
+
)
|
| 723 |
+
|
| 724 |
|
| 725 |
def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
|
| 726 |
rescale_func = np.poly1d(self.coefficients)
|
|
|
|
| 763 |
self.rel_l1_thresh = best_threshold
|
| 764 |
print(f"Tea Cache, best threshold found:{best_threshold:0.2f} with gain x{len(timesteps)/(target_nb_steps - best_signed_diff):0.2f} for a target of x{speed_factor}")
|
| 765 |
return best_threshold
|
| 766 |
+
|
| 767 |
+
def forward_vace(
    self,
    x,
    vace_context,
    seq_len,
    context,
    e,
    kwargs
):
    """Run the VACE control branch and return its per-block outputs.

    Args:
        x: tensor handed to every control block as the ``x`` keyword.
        vace_context: iterable of tensors; each one gets a leading dimension
            added and is passed through ``self.vace_patch_embedding``.
        seq_len: token length that embedded entries are zero-padded to.
        context: forwarded to every control block as ``context``.
        e: forwarded to every control block as ``e``.
        kwargs: dict of extra keyword arguments for the blocks; an ``x`` key
            in it would override the ``x`` argument.

    Returns:
        Everything the last control block returned except its final entry.
    """
    # embeddings: patchify each entry, then flatten to (1, tokens, channels)
    tokens = []
    for u in vace_context:
        embedded = self.vace_patch_embedding(u.unsqueeze(0))
        tokens.append(embedded.flatten(2).transpose(1, 2))

    if len(tokens) == 1 and seq_len == tokens[0].size(1):
        # Single entry that already has the right length: no padding needed.
        c = tokens[0]
    else:
        padded = []
        for u in tokens:
            filler = u.new_zeros(1, seq_len - u.size(1), u.size(2))
            padded.append(torch.cat([u, filler], dim=1))
        c = torch.cat(padded)

    # arguments
    block_kwargs = {'x': x}
    block_kwargs.update(kwargs)

    for block in self.vace_blocks:
        c = block(c, context=context, e=e, **block_kwargs)

    # The trailing entry is dropped; only the preceding entries are returned.
    return c[:-1]
|
| 796 |
|
| 797 |
def forward(
|
| 798 |
self,
|
|
|
|
| 800 |
t,
|
| 801 |
context,
|
| 802 |
seq_len,
|
| 803 |
+
vace_context = None,
|
| 804 |
+
vace_context_scale=1.0,
|
| 805 |
clip_fea=None,
|
| 806 |
y=None,
|
| 807 |
freqs = None,
|
|
|
|
| 936 |
self.previous_residual_cond = None
|
| 937 |
ori_hidden_states = x_list[0].clone()
|
| 938 |
# arguments
|
| 939 |
+
|
| 940 |
kwargs = dict(
|
|
|
|
| 941 |
seq_lens=seq_lens,
|
| 942 |
grid_sizes=grid_sizes,
|
| 943 |
freqs=freqs,
|
|
|
|
| 944 |
context_lens=context_lens)
|
| 945 |
+
|
| 946 |
+
if vace_context == None:
|
| 947 |
+
hints_list = [None ] *len(x_list)
|
| 948 |
+
else:
|
| 949 |
+
hints_list = []
|
| 950 |
+
for x, context in zip(x_list, context_list) :
|
| 951 |
+
hints_list.append( self.forward_vace(x, vace_context, seq_len, context= context, e= e0, kwargs= kwargs))
|
| 952 |
+
del x, context
|
| 953 |
+
kwargs['context_scale'] = vace_context_scale
|
| 954 |
+
|
| 955 |
+
|
| 956 |
for block_idx, block in enumerate(self.blocks):
|
| 957 |
offload.shared_state["layer"] = block_idx
|
| 958 |
if callback != None:
|
|
|
|
| 969 |
x_list[0] = block(x_list[0], context = context_list[0], e= e0, **kwargs)
|
| 970 |
|
| 971 |
else:
|
| 972 |
+
for i, (x, context, hints) in enumerate(zip(x_list, context_list, hints_list)):
|
| 973 |
+
x_list[i] = block(x, context = context, hints= hints, e= e0, **kwargs)
|
| 974 |
del x
|
| 975 |
+
del context, hints
|
| 976 |
|
| 977 |
if self.enable_teacache:
|
| 978 |
if joint_pass:
|
wan/text2video.py
CHANGED
|
@@ -13,7 +13,9 @@ import torch
|
|
| 13 |
import torch.cuda.amp as amp
|
| 14 |
import torch.distributed as dist
|
| 15 |
from tqdm import tqdm
|
| 16 |
-
|
|
|
|
|
|
|
| 17 |
from .distributed.fsdp import shard_model
|
| 18 |
from .modules.model import WanModel
|
| 19 |
from .modules.t5 import T5EncoderModel
|
|
@@ -22,6 +24,7 @@ from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
|
|
| 22 |
get_sampling_sigmas, retrieve_timesteps)
|
| 23 |
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
| 24 |
from wan.modules.posemb_layers import get_rotary_pos_embed
|
|
|
|
| 25 |
|
| 26 |
|
| 27 |
def optimized_scale(positive_flat, negative_flat):
|
|
@@ -105,8 +108,6 @@ class WanT2V:
|
|
| 105 |
|
| 106 |
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel, writable_tensors= False)
|
| 107 |
|
| 108 |
-
|
| 109 |
-
|
| 110 |
self.model.eval().requires_grad_(False)
|
| 111 |
|
| 112 |
if use_usp:
|
|
@@ -132,8 +133,148 @@ class WanT2V:
|
|
| 132 |
|
| 133 |
self.sample_neg_prompt = config.sample_neg_prompt
|
| 134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
def generate(self,
|
| 136 |
input_prompt,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
size=(1280, 720),
|
| 138 |
frame_num=81,
|
| 139 |
shift=5.0,
|
|
@@ -187,14 +328,6 @@ class WanT2V:
|
|
| 187 |
- W: Frame width from size)
|
| 188 |
"""
|
| 189 |
# preprocess
|
| 190 |
-
F = frame_num
|
| 191 |
-
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
|
| 192 |
-
size[1] // self.vae_stride[1],
|
| 193 |
-
size[0] // self.vae_stride[2])
|
| 194 |
-
|
| 195 |
-
seq_len = math.ceil((target_shape[2] * target_shape[3]) /
|
| 196 |
-
(self.patch_size[1] * self.patch_size[2]) *
|
| 197 |
-
target_shape[1] / self.sp_size) * self.sp_size
|
| 198 |
|
| 199 |
if n_prompt == "":
|
| 200 |
n_prompt = self.sample_neg_prompt
|
|
@@ -213,6 +346,29 @@ class WanT2V:
|
|
| 213 |
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
|
| 214 |
context = [t.to(self.device) for t in context]
|
| 215 |
context_null = [t.to(self.device) for t in context_null]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
|
| 217 |
noise = [
|
| 218 |
torch.randn(
|
|
@@ -261,10 +417,12 @@ class WanT2V:
|
|
| 261 |
arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
|
| 262 |
arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
|
| 263 |
arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
|
| 265 |
-
# arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
|
| 266 |
-
# arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
|
| 267 |
-
# arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
|
| 268 |
if self.model.enable_teacache:
|
| 269 |
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
|
| 270 |
if callback != None:
|
|
@@ -281,7 +439,7 @@ class WanT2V:
|
|
| 281 |
# self.model.to(self.device)
|
| 282 |
if joint_pass:
|
| 283 |
noise_pred_cond, noise_pred_uncond = self.model(
|
| 284 |
-
latent_model_input, t=timestep,current_step=i, slg_layers=slg_layers_local, **arg_both)
|
| 285 |
if self._interrupt:
|
| 286 |
return None
|
| 287 |
else:
|
|
@@ -329,7 +487,11 @@ class WanT2V:
|
|
| 329 |
self.model.cpu()
|
| 330 |
torch.cuda.empty_cache()
|
| 331 |
if self.rank == 0:
|
| 332 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
|
| 334 |
|
| 335 |
del noise, latents
|
|
|
|
| 13 |
import torch.cuda.amp as amp
|
| 14 |
import torch.distributed as dist
|
| 15 |
from tqdm import tqdm
|
| 16 |
+
from PIL import Image
|
| 17 |
+
import torchvision.transforms.functional as TF
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
from .distributed.fsdp import shard_model
|
| 20 |
from .modules.model import WanModel
|
| 21 |
from .modules.t5 import T5EncoderModel
|
|
|
|
| 24 |
get_sampling_sigmas, retrieve_timesteps)
|
| 25 |
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
| 26 |
from wan.modules.posemb_layers import get_rotary_pos_embed
|
| 27 |
+
from .utils.vace_preprocessor import VaceVideoProcessor
|
| 28 |
|
| 29 |
|
| 30 |
def optimized_scale(positive_flat, negative_flat):
|
|
|
|
| 108 |
|
| 109 |
self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel, writable_tensors= False)
|
| 110 |
|
|
|
|
|
|
|
| 111 |
self.model.eval().requires_grad_(False)
|
| 112 |
|
| 113 |
if use_usp:
|
|
|
|
| 133 |
|
| 134 |
self.sample_neg_prompt = config.sample_neg_prompt
|
| 135 |
|
| 136 |
+
if "Vace" in model_filename:
|
| 137 |
+
self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
|
| 138 |
+
min_area=480*832,
|
| 139 |
+
max_area=480*832,
|
| 140 |
+
min_fps=config.sample_fps,
|
| 141 |
+
max_fps=config.sample_fps,
|
| 142 |
+
zero_start=True,
|
| 143 |
+
seq_len=32760,
|
| 144 |
+
keep_last=True)
|
| 145 |
+
|
| 146 |
+
def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0):
    """Encode control frames, and optional reference images, with the VAE.

    Args:
        frames: list of frame tensors handed to ``self.vae.encode``.
        ref_images: ``None`` or a list (one entry per item of ``frames``)
            whose entries are ``None`` or a list of reference image tensors.
        masks: optional list of masks, one per item of ``frames``. When
            given, each clip is split into ``frame * (1 - mask)`` and
            ``frame * mask``, both parts are encoded, and the two latents
            are concatenated along dim 0.
        tile_size: forwarded to ``self.vae.encode``.

    Returns:
        List with one latent per item of ``frames``. When references are
        present their latents are prepended along dim 1; in masked mode each
        reference latent is first extended with zeros along dim 0 so that it
        matches the doubled channel count.
    """
    if ref_images is None:
        ref_images = [None] * len(frames)
    else:
        assert len(frames) == len(ref_images)

    if masks is None:
        latents = self.vae.encode(frames, tile_size=tile_size)
    else:
        # Build and encode one masked copy at a time so both pixel-space
        # copies never have to be alive at once. The original
        # `+ 0 * m` / `+ 0 * (1 - m)` terms added nothing and are dropped.
        inactive = self.vae.encode([i * (1 - m) for i, m in zip(frames, masks)], tile_size=tile_size)
        reactive = self.vae.encode([i * m for i, m in zip(frames, masks)], tile_size=tile_size)
        latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]

    cat_latents = []
    for latent, refs in zip(latents, ref_images):
        if refs is not None:
            # The encode call is the same in both modes; only the zero
            # channel padding depends on whether masks are used.
            ref_latent = self.vae.encode(refs, tile_size=tile_size)
            if masks is not None:
                ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent]
            # Every reference must encode to exactly one latent frame.
            assert all(x.shape[1] == 1 for x in ref_latent)
            latent = torch.cat([*ref_latent, latent], dim=1)
        cat_latents.append(latent)
    return cat_latents
|
| 173 |
+
|
| 174 |
+
def vace_encode_masks(self, masks, ref_images=None):
    """Fold pixel-space masks into the latent grid.

    Each mask's first channel is split into ``stride_h x stride_w`` spatial
    patches; the patch positions become the leading dimension, and the frame
    axis is then resampled with nearest-exact interpolation.

    Args:
        masks: list of 4-D tensors unpacked as (c, depth, height, width);
            only channel 0 is used. Height and width must be exactly
            ``2 * (size // (2 * stride)) * stride`` for the reshape to work.
        ref_images: ``None`` or a list (one entry per mask) whose entries are
            ``None`` or a list of reference images; only their count is used.

    Returns:
        List of tensors of shape
        (stride_h * stride_w, new_depth [+ number of refs], height, width),
        where the frames prepended for references are all zeros.
    """
    if ref_images is None:
        ref_images = [None] * len(masks)
    else:
        assert len(masks) == len(ref_images)

    stride_t = self.vae_stride[0]
    stride_h = self.vae_stride[1]
    stride_w = self.vae_stride[2]

    result_masks = []
    for mask, refs in zip(masks, ref_images):
        c, depth, height, width = mask.shape
        new_depth = int((depth + 3) // stride_t)
        height = 2 * (int(height) // (stride_h * 2))
        width = 2 * (int(width) // (stride_w * 2))

        # reshape
        mask = mask[0, :, :, :]
        # The width axis is split by the width stride. The original used
        # vae_stride[1] here too, which only works when both strides match.
        mask = mask.reshape(
            depth, height, stride_h, width, stride_w
        )  # depth, height, stride_h, width, stride_w
        mask = mask.permute(2, 4, 0, 1, 3)  # stride_h, stride_w, depth, height, width
        mask = mask.reshape(
            stride_h * stride_w, depth, height, width
        )  # stride_h*stride_w, depth, height, width

        # interpolation
        mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0)

        if refs is not None:
            # One zero frame per reference. Allocated explicitly because
            # slicing the mask would give too few frames when there are more
            # references than latent frames.
            length = len(refs)
            mask_pad = mask.new_zeros(mask.shape[0], length, *mask.shape[2:])
            mask = torch.cat((mask_pad, mask), dim=1)
        result_masks.append(mask)
    return result_masks
|
| 206 |
+
|
| 207 |
+
def vace_latent(self, z, m):
    """Pair each entry of ``z`` with the matching entry of ``m`` and join them along dim 0."""
    combined = []
    for latent, mask in zip(z, m):
        combined.append(torch.cat([latent, mask], dim=0))
    return combined
|
| 209 |
+
|
| 210 |
+
def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device, trim_video= 0):
    """Load and pad the control videos, masks and reference images.

    The three input lists are modified in place and also returned.

    Args:
        src_video: list whose entries are ``None`` or something accepted by
            ``self.vid_proc.load_video`` / ``load_video_pair`` (presumably a
            path -- confirm against the video processor).
        src_mask: list parallel to ``src_video``, entries ``None`` or a mask
            source for ``load_video_pair``.
        src_ref_images: list whose entries are ``None`` or a list of images
            accepted by ``TF.to_tensor`` (``None`` items are left alone).
        num_frames: frame count every video and mask is padded to.
        image_size: (height, width) used when an entry has no source video.
        device: device the resulting tensors are moved to or created on.
        trim_video: forwarded to the video processor.

    Returns:
        Tuple ``(src_video, src_mask, src_ref_images)``.
    """

    def pad_frames(clip, filler_factory, video_shape):
        # Append the missing frames along dim 1; the filler is always sized
        # from the *video* shape, for masks too, as in the original code.
        filler = filler_factory(video_shape[0], num_frames - video_shape[1], *video_shape[-2:])
        return torch.cat([clip, filler], dim=1)

    image_sizes = []
    for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
        if sub_src_video is None:
            # No source video: all-zero frames, mask of ones everywhere.
            src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
            src_mask[i] = torch.ones_like(src_video[i], device=device)
            image_sizes.append(image_size)
        elif sub_src_mask is not None:
            src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames=num_frames, trim_video=trim_video)
            src_video[i] = src_video[i].to(device)
            src_mask[i] = src_mask[i].to(device)
            src_video_shape = src_video[i].shape
            if src_video_shape[1] != num_frames:
                # Padded frames are zero in the video and one in the mask.
                src_video[i] = pad_frames(src_video[i], src_video[i].new_zeros, src_video_shape)
                src_mask[i] = pad_frames(src_mask[i], src_mask[i].new_ones, src_video_shape)
            # Keep one channel and rescale with (v + 1) / 2, clamped to [0, 1]
            # (presumably the loader returns values in [-1, 1] -- confirm).
            src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
            image_sizes.append(src_video[i].shape[2:])
        else:
            # Video without a mask: the mask is ones everywhere.
            src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames=num_frames, trim_video=trim_video)
            src_video[i] = src_video[i].to(device)
            src_video_shape = src_video[i].shape
            if src_video_shape[1] != num_frames:
                src_video[i] = pad_frames(src_video[i], src_video[i].new_zeros, src_video_shape)
            src_mask[i] = torch.ones_like(src_video[i], device=device)
            image_sizes.append(src_video[i].shape[2:])

    for i, ref_images in enumerate(src_ref_images):
        if ref_images is None:
            continue
        canvas_height, canvas_width = image_sizes[i]
        for j, ref_img in enumerate(ref_images):
            if ref_img is None:
                continue
            # to_tensor output rescaled with (v - 0.5) / 0.5, plus a frame axis.
            ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
            ref_height, ref_width = ref_img.shape[-2:]
            # Compare plain ints: a torch.Size never equals a list, so the
            # original check always resized when image_size was a list.
            if (ref_height, ref_width) != (canvas_height, canvas_width):
                white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device)  # [-1, 1]
                scale = min(canvas_height / ref_height, canvas_width / ref_width)
                # Never let an extreme aspect ratio collapse a side to 0.
                new_height = max(1, int(ref_height * scale))
                new_width = max(1, int(ref_width * scale))
                resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
                # Centre the resized image on the white canvas.
                top = (canvas_height - new_height) // 2
                left = (canvas_width - new_width) // 2
                white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
                ref_img = white_canvas
            src_ref_images[i][j] = ref_img.to(device)
    return src_video, src_mask, src_ref_images
|
| 257 |
+
|
| 258 |
+
def decode_latent(self, zs, ref_images=None, tile_size= 0 ):
    """Decode latents after removing the frames that belong to reference images.

    For every latent whose ``ref_images`` entry is not ``None``, the first
    ``len(refs)`` positions along dim 1 are dropped before the list is passed
    to ``self.vae.decode`` together with ``tile_size``.
    """
    if ref_images is None:
        ref_images = [None] * len(zs)
    else:
        assert len(zs) == len(ref_images)

    without_refs = [
        z if refs is None else z[:, len(refs):, :, :]
        for z, refs in zip(zs, ref_images)
    ]
    return self.vae.decode(without_refs, tile_size=tile_size)
|
| 271 |
+
|
| 272 |
def generate(self,
|
| 273 |
input_prompt,
|
| 274 |
+
input_frames= None,
|
| 275 |
+
input_masks = None,
|
| 276 |
+
input_ref_images = None,
|
| 277 |
+
context_scale=1.0,
|
| 278 |
size=(1280, 720),
|
| 279 |
frame_num=81,
|
| 280 |
shift=5.0,
|
|
|
|
| 328 |
- W: Frame width from size)
|
| 329 |
"""
|
| 330 |
# preprocess
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
|
| 332 |
if n_prompt == "":
|
| 333 |
n_prompt = self.sample_neg_prompt
|
|
|
|
| 346 |
context_null = self.text_encoder([n_prompt], torch.device('cpu'))
|
| 347 |
context = [t.to(self.device) for t in context]
|
| 348 |
context_null = [t.to(self.device) for t in context_null]
|
| 349 |
+
|
| 350 |
+
if input_frames != None:
|
| 351 |
+
# vace context encode
|
| 352 |
+
input_frames = [u.to(self.device) for u in input_frames]
|
| 353 |
+
input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images]
|
| 354 |
+
input_masks = [u.to(self.device) for u in input_masks]
|
| 355 |
+
|
| 356 |
+
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size)
|
| 357 |
+
m0 = self.vace_encode_masks(input_masks, input_ref_images)
|
| 358 |
+
z = self.vace_latent(z0, m0)
|
| 359 |
+
|
| 360 |
+
target_shape = list(z0[0].shape)
|
| 361 |
+
target_shape[0] = int(target_shape[0] / 2)
|
| 362 |
+
else:
|
| 363 |
+
F = frame_num
|
| 364 |
+
target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
|
| 365 |
+
size[1] // self.vae_stride[1],
|
| 366 |
+
size[0] // self.vae_stride[2])
|
| 367 |
+
|
| 368 |
+
seq_len = math.ceil((target_shape[2] * target_shape[3]) /
|
| 369 |
+
(self.patch_size[1] * self.patch_size[2]) *
|
| 370 |
+
target_shape[1] / self.sp_size) * self.sp_size
|
| 371 |
+
|
| 372 |
|
| 373 |
noise = [
|
| 374 |
torch.randn(
|
|
|
|
| 417 |
arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
|
| 418 |
arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
|
| 419 |
arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
|
| 420 |
+
if input_frames != None:
|
| 421 |
+
vace_dict = {'vace_context' : z, 'vace_context_scale' : context_scale}
|
| 422 |
+
arg_c.update(vace_dict)
|
| 423 |
+
arg_null.update(vace_dict)
|
| 424 |
+
arg_both.update(vace_dict)
|
| 425 |
|
|
|
|
|
|
|
|
|
|
| 426 |
if self.model.enable_teacache:
|
| 427 |
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
|
| 428 |
if callback != None:
|
|
|
|
| 439 |
# self.model.to(self.device)
|
| 440 |
if joint_pass:
|
| 441 |
noise_pred_cond, noise_pred_uncond = self.model(
|
| 442 |
+
latent_model_input, t=timestep, current_step=i, slg_layers=slg_layers_local, **arg_both)
|
| 443 |
if self._interrupt:
|
| 444 |
return None
|
| 445 |
else:
|
|
|
|
| 487 |
self.model.cpu()
|
| 488 |
torch.cuda.empty_cache()
|
| 489 |
if self.rank == 0:
|
| 490 |
+
|
| 491 |
+
if input_frames == None:
|
| 492 |
+
videos = self.vae.decode(x0, VAE_tile_size)
|
| 493 |
+
else:
|
| 494 |
+
videos = self.decode_latent(x0, input_ref_images, VAE_tile_size)
|
| 495 |
|
| 496 |
|
| 497 |
del noise, latents
|
wan/utils/utils.py
CHANGED
|
@@ -3,21 +3,70 @@ import argparse
|
|
| 3 |
import binascii
|
| 4 |
import os
|
| 5 |
import os.path as osp
|
|
|
|
|
|
|
| 6 |
|
| 7 |
import imageio
|
| 8 |
import torch
|
|
|
|
| 9 |
import torchvision
|
| 10 |
from PIL import Image
|
| 11 |
import numpy as np
|
|
|
|
|
|
|
| 12 |
|
| 13 |
__all__ = ['cache_video', 'cache_image', 'str2bool']
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
def resize_lanczos(img, h, w):
|
| 16 |
img = Image.fromarray(np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8))
|
| 17 |
img = img.resize((w,h), resample=Image.Resampling.LANCZOS)
|
| 18 |
return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0)
|
| 19 |
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
def rand_name(length=8, suffix=''):
|
| 22 |
name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
|
| 23 |
if suffix:
|
|
|
|
| 3 |
import binascii
|
| 4 |
import os
|
| 5 |
import os.path as osp
|
| 6 |
+
import torchvision.transforms.functional as TF
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
|
| 9 |
import imageio
|
| 10 |
import torch
|
| 11 |
+
import decord
|
| 12 |
import torchvision
|
| 13 |
from PIL import Image
|
| 14 |
import numpy as np
|
| 15 |
+
from rembg import remove, new_session
|
| 16 |
+
|
| 17 |
|
| 18 |
__all__ = ['cache_video', 'cache_image', 'str2bool']
|
| 19 |
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
from PIL import Image
|
| 23 |
+
|
| 24 |
+
def get_video_frame(file_name, frame_no):
    """Read frame number ``frame_no`` of the video ``file_name`` and return it as a PIL image.

    Note that this switches decord's global bridge to ``'torch'``.
    """
    decord.bridge.set_bridge('torch')
    video = decord.VideoReader(file_name)

    # get_batch returns a batch of one frame; drop the batch axis.
    pixels = video.get_batch([frame_no]).squeeze(0).numpy()
    return Image.fromarray(pixels.astype(np.uint8))
|
| 31 |
+
|
| 32 |
def resize_lanczos(img, h, w):
    """Resize a channel-first image tensor to (h, w) with Lanczos resampling.

    The tensor is scaled by 255, clipped to [0, 255] and converted to an
    8-bit PIL image (so values are presumably expected in [0, 1]). The result
    comes back as a float32 channel-first tensor divided by 255.
    """
    channels_last = img.movedim(0, -1).cpu().numpy()
    as_uint8 = np.clip(255. * channels_last, 0, 255).astype(np.uint8)
    pil_img = Image.fromarray(as_uint8).resize((w,h), resample=Image.Resampling.LANCZOS)
    resized = np.array(pil_img).astype(np.float32) / 255.0
    return torch.from_numpy(resized).movedim(-1, 0)
|
| 36 |
|
| 37 |
|
| 38 |
+
def remove_background(img, session=None):
    """Replace the background of a channel-first image tensor with white.

    Args:
        img: channel-first tensor; it is scaled by 255 and clipped to
            [0, 255] before conversion (values presumably in [0, 1]).
        session: rembg session to reuse; a new one is created when omitted.

    Returns:
        float32 channel-first tensor (values divided by 255) of the RGB
        image produced by rembg.
    """
    # Identity check: `== None` would go through the session's __eq__.
    if session is None:
        session = new_session()
    img = Image.fromarray(np.clip(255. * img.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8))
    img = remove(img, session=session, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
    return torch.from_numpy(np.array(img).astype(np.float32) / 255.0).movedim(-1, 0)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def resize_and_remove_background(img_list, canvas_width, canvas_height, rm_background ):
    """Fit each image onto a white canvas, optionally removing its background.

    Every image is scaled (aspect ratio kept, Lanczos resampling) to fit
    inside ``canvas_width x canvas_height`` and centred on a white RGB
    canvas of exactly that size.

    Args:
        img_list: iterable of PIL images.
        canvas_width: width of the output images in pixels.
        canvas_height: height of the output images in pixels.
        rm_background: when true, each resized image is passed through
            rembg with a white background colour before being pasted.

    Returns:
        List of RGB PIL images, one per input image.
    """
    # One rembg session is shared by every image.
    session = new_session() if rm_background else None

    output_list = []
    for img in img_list:
        width, height = img.size
        white_canvas = np.full((canvas_height, canvas_width, 3), 255, dtype=np.uint8)
        scale = min(canvas_height / height, canvas_width / width)
        # Never let an extreme aspect ratio collapse a side to 0 pixels.
        new_height = max(1, int(height * scale))
        new_width = max(1, int(width * scale))
        resized_image = img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
        if rm_background:
            resized_image = remove(resized_image, session=session, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
        if resized_image.mode != 'RGB':
            # The canvas has 3 channels, so RGBA / L / P inputs used to fail
            # on the assignment below. Transparent pixels are shown as white
            # to match the canvas.
            if 'A' in resized_image.getbands():
                rgba = resized_image.convert('RGBA')
                flattened = Image.new('RGB', rgba.size, (255, 255, 255))
                flattened.paste(rgba, mask=rgba.getchannel('A'))
                resized_image = flattened
            else:
                resized_image = resized_image.convert('RGB')
        top = (canvas_height - new_height) // 2
        left = (canvas_width - new_width) // 2
        white_canvas[top:top + new_height, left:left + new_width, :] = np.array(resized_image)
        output_list.append(Image.fromarray(white_canvas))
    return output_list
|
| 68 |
+
|
| 69 |
+
|
| 70 |
def rand_name(length=8, suffix=''):
|
| 71 |
name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')
|
| 72 |
if suffix:
|
wan/utils/vace_preprocessor.py
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import torchvision.transforms.functional as TF
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class VaceImageProcessor(object):
    """Load still images and resize/crop them to a size bounded by a token budget.

    Attributes are plain values set in ``__init__``:
        downsample: 3-item sequence; items 1 and 2 are the height and width
            granularity that output sizes are rounded down to.
        seq_len: default upper bound on ``(oh // dh) * (ow // dw)``.
    """

    def __init__(self, downsample=None, seq_len=None):
        self.downsample = downsample
        self.seq_len = seq_len

    def _pillow_convert(self, image, cvt_type='RGB'):
        """Return ``image`` in mode ``cvt_type``, flattening alpha onto white."""
        if image.mode != cvt_type:
            if image.mode == 'P':
                # Palette images go through the alpha variant first so that
                # any transparency is composited below.
                image = image.convert(f'{cvt_type}A')
            if image.mode == f'{cvt_type}A':
                bg = Image.new(cvt_type,
                               size=(image.width, image.height),
                               color=(255, 255, 255))
                bg.paste(image, (0, 0), mask=image)
                image = bg
            else:
                image = image.convert(cvt_type)
        return image

    def _load_image(self, img_path):
        """Open ``img_path`` and return it converted by ``_pillow_convert``.

        Returns ``None`` when ``img_path`` is ``None`` or an empty string.
        """
        if img_path is None or img_path == '':
            return None
        # Image.open is lazy: force the pixel data in before the file is
        # closed so no handle stays open for images that need no conversion.
        with Image.open(img_path) as opened:
            img = self._pillow_convert(opened)
            img.load()
        return img

    def _resize_crop(self, img, oh, ow, normalize=True):
        """
        Resize, center crop, convert to tensor, and normalize.

        The image is scaled (Lanczos) so that it covers ``ow x oh`` and is
        then centre-cropped to exactly that size. With ``normalize`` the
        result is a tensor rescaled with ``(v - 0.5) / 0.5`` and given an
        extra axis at position 1; otherwise the cropped PIL image is returned.
        """
        # resize and crop
        iw, ih = img.size
        if iw != ow or ih != oh:
            # resize
            scale = max(ow / iw, oh / ih)
            img = img.resize(
                (round(scale * iw), round(scale * ih)),
                resample=Image.Resampling.LANCZOS
            )
            assert img.width >= ow and img.height >= oh

            # center crop
            x1 = (img.width - ow) // 2
            y1 = (img.height - oh) // 2
            img = img.crop((x1, y1, x1 + ow, y1 + oh))

        # normalize
        if normalize:
            img = TF.to_tensor(img).sub_(0.5).div_(0.5).unsqueeze(1)
        return img

    def _image_preprocess(self, img, oh, ow, normalize=True, **kwargs):
        """Preprocess one image; extra keyword arguments are ignored."""
        return self._resize_crop(img, oh, ow, normalize)

    def load_image(self, data_key, **kwargs):
        """Load a single image; see ``load_image_batch``."""
        return self.load_image_batch(data_key, **kwargs)

    def load_image_pair(self, data_key, data_key2, **kwargs):
        """Load two images with a shared output size; see ``load_image_batch``."""
        return self.load_image_batch(data_key, data_key2, **kwargs)

    def load_image_batch(self, *data_key_batch, normalize=True, seq_len=None, **kwargs):
        """Load several images and bring them all to one common size.

        The output size is derived from the first image that could be
        loaded: it is shrunk (never enlarged) until
        ``(oh // dh) * (ow // dw) <= seq_len`` and rounded down to multiples
        of ``dh`` and ``dw``.

        Args:
            *data_key_batch: image paths; ``None`` or ``''`` entries yield
                ``None`` in the result.
            normalize: forwarded to the per-image preprocessing.
            seq_len: overrides ``self.seq_len`` when given.

        Returns:
            Tuple of the processed images followed by ``(oh, ow)``.

        Raises:
            ValueError: if no ``seq_len`` or ``downsample`` is configured, or
                if none of the entries names an image.
        """
        seq_len = self.seq_len if seq_len is None else seq_len
        # Fail early and clearly instead of with a TypeError further down.
        if seq_len is None:
            raise ValueError('seq_len must be set on the processor or passed to load_image_batch')
        if self.downsample is None:
            raise ValueError('downsample must be set on the processor')

        imgs = [self._load_image(data_key) for data_key in data_key_batch]
        reference = next((img for img in imgs if img is not None), None)
        if reference is None:
            raise ValueError('load_image_batch needs at least one image path')
        w, h = reference.size
        dh, dw = self.downsample[1:]

        # compute output size
        scale = min(1., np.sqrt(seq_len / ((h / dh) * (w / dw))))
        oh = int(h * scale) // dh * dh
        ow = int(w * scale) // dw * dw
        assert (oh // dh) * (ow // dw) <= seq_len
        # Missing entries stay None instead of crashing the preprocessing.
        imgs = [None if img is None else self._image_preprocess(img, oh, ow, normalize) for img in imgs]
        return *imgs, (oh, ow)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class VaceVideoProcessor(object):
    """Sample frames from videos and resize them to a budget-limited size.

    Args:
        downsample: ``(df, dh, dw)`` temporal/height/width downsample factors.
        min_area: Minimum output pixel area; in ``keep_last`` mode this is
            the area that is actually used.
        max_area: Maximum output pixel area (default sampling mode only).
        min_fps: Stored but not read by any method in this class.
        max_fps: Upper bound on the sampling fps (default mode) or the exact
            target fps (``keep_last`` mode).
        zero_start: In default mode, start sampling at t=0 instead of at a
            random offset.
        seq_len: Budget on the number of downsampled cells.
        keep_last: Select the ``_get_frameid_bbox_adjust_last`` sampler
            instead of the default one.
    """

    def __init__(self, downsample, min_area, max_area, min_fps, max_fps, zero_start, seq_len, keep_last, **kwargs):
        self.downsample = downsample
        self.min_area = min_area
        self.max_area = max_area
        self.min_fps = min_fps
        self.max_fps = max_fps
        self.zero_start = zero_start
        self.keep_last = keep_last
        self.seq_len = seq_len
        # The budget must at least hold one frame at the minimum area.
        assert seq_len >= min_area / (self.downsample[1] * self.downsample[2])

    @staticmethod
    def resize_crop(video: torch.Tensor, oh: int, ow: int):
        """
        Resize, center crop and normalize for decord loaded video (torch.Tensor type)

        Parameters:
          video - video to process (torch.Tensor): Tensor from `reader.get_batch(frame_ids)`, in shape of (T, H, W, C)
          oh - target height (int)
          ow - target width (int)

        Returns:
            The processed video (torch.Tensor): Normalized tensor range [-1, 1], in shape of (C, T, H, W)
        """
        # permute ([t, h, w, c] -> [t, c, h, w])
        video = video.permute(0, 3, 1, 2)

        # resize and crop
        ih, iw = video.shape[2:]
        if ih != oh or iw != ow:
            # resize: use the larger ratio so both sides reach the target
            scale = max(ow / iw, oh / ih)
            video = F.interpolate(
                video,
                size=(round(scale * ih), round(scale * iw)),
                mode='bicubic',
                antialias=True
            )
            assert video.size(3) >= ow and video.size(2) >= oh

            # center crop
            x1 = (video.size(3) - ow) // 2
            y1 = (video.size(2) - oh) // 2
            video = video[:, :, y1:y1 + oh, x1:x1 + ow]

        # permute ([t, c, h, w] -> [c, t, h, w]) and normalize
        video = video.transpose(0, 1).float().div_(127.5).sub_(1.)
        return video

    def _video_preprocess(self, video, oh, ow):
        """Preprocessing hook; currently delegates to :meth:`resize_crop`."""
        return self.resize_crop(video, oh, ow)

    def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box, rng):
        """Pick frame ids and an output size under the ``seq_len`` budget.

        A latent area is drawn log-uniformly between the min and max area,
        which bounds the number of latent frames; the final area is then
        whatever the budget allows for that frame count.

        Args:
            fps: Source frame rate.
            frame_timestamps: Array of per-frame ``(start, end)`` times.
            h, w: Source frame height and width.
            crop_box: Optional ``(x1, x2, y1, y2)`` crop; full frame if None.
            rng: NumPy random generator used for area and start-offset draws.

        Returns:
            ``(frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps)``.
        """
        target_fps = min(fps, self.max_fps)
        duration = frame_timestamps[-1].mean()
        x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
        h, w = y2 - y1, x2 - x1
        ratio = h / w
        df, dh, dw = self.downsample

        # min/max area of the [latent video]
        min_area_z = self.min_area / (dh * dw)
        max_area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))

        # sample a frame number of the [latent video]
        rand_area_z = np.square(np.power(2, rng.uniform(
            np.log2(np.sqrt(min_area_z)),
            np.log2(np.sqrt(max_area_z))
        )))
        of = min(
            (int(duration * target_fps) - 1) // df + 1,
            int(self.seq_len / rand_area_z)
        )

        # deduce target shape of the [latent video]
        target_area_z = min(max_area_z, int(self.seq_len / of))
        oh = round(np.sqrt(target_area_z * ratio))
        ow = int(target_area_z / oh)
        of = (of - 1) * df + 1
        oh *= dh
        ow *= dw

        # sample frame ids: map evenly spaced timestamps to the frame whose
        # [start, end) interval contains them
        target_duration = of / target_fps
        begin = 0. if self.zero_start else rng.uniform(0, duration - target_duration)
        timestamps = np.linspace(begin, begin + target_duration, of)
        frame_ids = np.argmax(np.logical_and(
            timestamps[:, None] >= frame_timestamps[None, :, 0],
            timestamps[:, None] < frame_timestamps[None, :, 1]
        ), axis=1).tolist()
        return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps

    def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w, crop_box, rng, max_frames= 0):
        """Resample the video to ``max_fps`` from its start, up to ``max_frames``.

        For each target time step the first source frame starting at or after
        it is selected. Sampling stops at ``max_frames`` ids or at the end of
        the video, whichever comes first. The output size is derived from
        ``min_area`` and the crop aspect ratio.

        Args:
            fps: Source frame rate.
            frame_timestamps: Array of per-frame ``(start, end)`` times; its
                length is taken as the number of available frames.
            h, w: Source frame height and width.
            crop_box: Optional ``(x1, x2, y1, y2)`` crop; full frame if None.
            rng: Unused in this mode; kept for a uniform sampler signature.
            max_frames: Maximum number of frame ids to emit.

        Returns:
            ``(frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps)``.
        """
        import math

        target_fps = self.max_fps
        video_frames_count = len(frame_timestamps)
        video_duration = frame_timestamps[-1][1]
        video_frame_duration = 1 / fps
        target_frame_duration = 1 / target_fps

        cur_time = 0
        target_time = 0
        frame_no = 0
        frame_ids = []
        # NOTE: times are accumulated in floating point, so ceil() can
        # occasionally step one source frame further than exact arithmetic.
        for _ in range(max_frames):
            add_frames_count = math.ceil((target_time - cur_time) / video_frame_duration)
            frame_no += add_frames_count
            # Check the index BEFORE recording it: an id equal to or past the
            # frame count cannot be read back from the video.
            if frame_no >= video_frames_count:
                break
            frame_ids.append(frame_no)
            cur_time += add_frames_count * video_frame_duration
            target_time += target_frame_duration
            if cur_time > video_duration:
                break

        x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
        h, w = y2 - y1, x2 - x1
        ratio = h / w
        _, dh, dw = self.downsample

        # The latent area is pinned to the minimum area in this mode (the
        # budget-based max-area formula was disabled upstream as a
        # workaround), so no random draw is involved.
        target_area_z = self.min_area / (dh * dw)

        # deduce target shape of the [latent video]
        oh = round(np.sqrt(target_area_z * ratio))
        ow = int(target_area_z / oh)
        oh *= dh
        ow *= dw

        return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps

    def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng, max_frames= 0):
        """Dispatch to the sampler selected by ``self.keep_last``.

        ``max_frames`` only applies to the ``keep_last`` sampler; the default
        sampler derives its frame count from the ``seq_len`` budget and takes
        no such argument.
        """
        if self.keep_last:
            return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h, w, crop_box, rng, max_frames=max_frames)
        return self._get_frameid_bbox_default(fps, frame_timestamps, h, w, crop_box, rng)

    def load_video(self, data_key, crop_box=None, seed=2024, **kwargs):
        """Load a single video; see :meth:`load_video_batch`."""
        return self.load_video_batch(data_key, crop_box=crop_box, seed=seed, **kwargs)

    def load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs):
        """Load two videos with identical sampling; see :meth:`load_video_batch`."""
        return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)

    def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, max_frames= 0, trim_video =0, **kwargs):
        """Read the same frames from every video and preprocess them.

        Frame rate, timestamps and frame size are taken from the first
        video; the usable length is the shortest video in the batch.

        Args:
            *data_key_batch: Video sources accepted by ``decord.VideoReader``.
            crop_box: Optional ``(x1, x2, y1, y2)`` crop applied to every frame.
            seed: Base seed for the sampling RNG.
            max_frames: Maximum number of frames (``keep_last`` mode).
            trim_video: If positive, additionally caps ``max_frames``.

        Returns:
            A tuple ``(*videos, frame_ids, (oh, ow), fps)`` where ``fps`` is
            the sampling frame rate chosen by the sampler.
        """
        import zlib

        import decord

        # Derive the per-video seed offset from a stable digest: the builtin
        # hash() of a str is salted per process, which would make the `seed`
        # argument non-reproducible across runs.
        key_digest = zlib.crc32(str(data_key_batch[0]).encode('utf-8', 'replace'))
        rng = np.random.default_rng(seed + key_digest % 10000)

        # read video
        decord.bridge.set_bridge('torch')
        readers = [decord.VideoReader(data_k) for data_k in data_key_batch]

        fps = readers[0].get_avg_fps()
        length = min(len(r) for r in readers)
        frame_timestamps = np.array(
            [readers[0].get_frame_timestamp(i) for i in range(length)],
            dtype=np.float32,
        )
        max_frames = min(max_frames, trim_video) if trim_video > 0 else max_frames
        h, w = readers[0].next().shape[:2]
        frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, frame_timestamps, h, w, crop_box, rng, max_frames=max_frames)

        # preprocess video
        videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers]
        videos = [self._video_preprocess(video, oh, ow) for video in videos]
        return *videos, frame_ids, (oh, ow), fps
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size, device):
    """Fill in missing video/mask inputs and letterbox reference images.

    The three input lists are modified in place and also returned.

    * Where both the video and the mask of an entry are ``None``, they are
      replaced by an all-zero video of shape ``(3, num_frames, H, W)`` and an
      all-one mask of shape ``(1, num_frames, H, W)``.
    * Every reference image whose last two dimensions differ from
      ``image_size`` is resized with its aspect ratio preserved and centred
      on a canvas filled with ones.

    Args:
        src_video: List of video tensors or ``None``.
        src_mask: List of mask tensors or ``None``, parallel to ``src_video``.
        src_ref_images: List whose entries are ``None`` or lists of reference
            image tensors (indexed below as ``(C, 1, H, W)``) or ``None``.
        num_frames: Frame count for synthesized placeholders.
        image_size: ``(height, width)`` of the target canvas; tuple or list.
        device: Device for newly created tensors.

    Returns:
        ``(src_video, src_mask, src_ref_images)``.
    """
    canvas_height, canvas_width = image_size[0], image_size[1]

    for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
        if sub_src_video is None and sub_src_mask is None:
            src_video[i] = torch.zeros((3, num_frames, canvas_height, canvas_width), device=device)
            src_mask[i] = torch.ones((1, num_frames, canvas_height, canvas_width), device=device)

    for i, ref_images in enumerate(src_ref_images):
        if ref_images is None:
            continue
        for j, ref_img in enumerate(ref_images):
            if ref_img is None:
                continue
            ref_height, ref_width = ref_img.shape[-2:]
            # Compare element-wise: a torch.Size never equals a list, so a
            # direct `shape != image_size` test would treat every image as
            # mismatched whenever image_size is passed as a list.
            if ref_height == canvas_height and ref_width == canvas_width:
                continue

            white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device)  # [-1, 1]
            scale = min(canvas_height / ref_height, canvas_width / ref_width)
            # Clamp to one pixel: int() truncation yields 0 for very
            # elongated images, which interpolate() rejects.
            new_height = max(1, int(ref_height * scale))
            new_width = max(1, int(ref_width * scale))
            # (C, 1, H, W) -> (1, C, H, W) for interpolate, then back again.
            resized_image = F.interpolate(
                ref_img.squeeze(1).unsqueeze(0),
                size=(new_height, new_width),
                mode='bilinear',
                align_corners=False,
            ).squeeze(0).unsqueeze(1)
            top = (canvas_height - new_height) // 2
            left = (canvas_width - new_width) // 2
            white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
            src_ref_images[i][j] = white_canvas
    return src_video, src_mask, src_ref_images
|