Merge remote-tracking branch 'upstream/main' into queues
Browse files- README.md +6 -0
- gradio_server.py +115 -33
- requirements.txt +1 -1
- wan/image2video.py +30 -73
- wan/modules/attention.py +30 -27
- wan/modules/model.py +0 -20
- wan/modules/sage2_core.py +3 -3
- wan/text2video.py +34 -4
README.md
CHANGED
|
@@ -19,6 +19,12 @@ In this repository, we present **Wan2.1**, a comprehensive and open suite of vid
|
|
| 19 |
|
| 20 |
|
| 21 |
## 🔥 Latest News!!
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
* Mar 19 2022: 👋 Wan2.1GP v3.1: Faster launch and RAM optimizations (should require less RAM to run)\
|
| 23 |
You will need one more *pip install -r requirements.txt*
|
| 24 |
* Mar 18 2022: 👋 Wan2.1GP v3.0:
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
## 🔥 Latest News!!
|
| 22 |
+
* Mar 19 2022: 👋 Wan2.1GP v3.2:
|
| 23 |
+
- Added Classifier-Free Guidance Zero Star. The video should better match the text prompt (especially with text2video) at no performance cost. Many thanks to the **CFG Zero * Team:**\
|
| 24 |
+
Don't hesitate to give them a star if you appreciate the results: https://github.com/WeichenFan/CFG-Zero-star
|
| 25 |
+
- Added back support for PyTorch compilation with Loras. It seems it had been broken for some time.
|
| 26 |
+
- Added the possibility to keep a number of previously generated videos in the Video Gallery (useful for comparing outputs of different settings)
|
| 27 |
+
You will need one more *pip install -r requirements.txt*
|
| 28 |
* Mar 19 2022: 👋 Wan2.1GP v3.1: Faster launch and RAM optimizations (should require less RAM to run)\
|
| 29 |
You will need one more *pip install -r requirements.txt*
|
| 30 |
* Mar 18 2022: 👋 Wan2.1GP v3.0:
|
gradio_server.py
CHANGED
|
@@ -24,7 +24,7 @@ import asyncio
|
|
| 24 |
from wan.utils import prompt_parser
|
| 25 |
PROMPT_VARS_MAX = 10
|
| 26 |
|
| 27 |
-
target_mmgp_version = "3.3.
|
| 28 |
from importlib.metadata import version
|
| 29 |
mmgp_version = version("mmgp")
|
| 30 |
if mmgp_version != target_mmgp_version:
|
|
@@ -98,6 +98,7 @@ def process_prompt_and_add_tasks(
|
|
| 98 |
tea_cache_start_step_perc,
|
| 99 |
loras_choices,
|
| 100 |
loras_mult_choices,
|
|
|
|
| 101 |
image_to_continue,
|
| 102 |
image_to_end,
|
| 103 |
video_to_continue,
|
|
@@ -107,6 +108,8 @@ def process_prompt_and_add_tasks(
|
|
| 107 |
slg_layers,
|
| 108 |
slg_start,
|
| 109 |
slg_end,
|
|
|
|
|
|
|
| 110 |
state_arg,
|
| 111 |
image2video
|
| 112 |
):
|
|
@@ -138,6 +141,7 @@ def process_prompt_and_add_tasks(
|
|
| 138 |
tea_cache_start_step_perc,
|
| 139 |
loras_choices,
|
| 140 |
loras_mult_choices,
|
|
|
|
| 141 |
image_to_continue,
|
| 142 |
image_to_end,
|
| 143 |
video_to_continue,
|
|
@@ -147,6 +151,8 @@ def process_prompt_and_add_tasks(
|
|
| 147 |
slg_layers,
|
| 148 |
slg_start,
|
| 149 |
slg_end,
|
|
|
|
|
|
|
| 150 |
state_arg,
|
| 151 |
image2video
|
| 152 |
)
|
|
@@ -380,7 +386,6 @@ def _parse_args():
|
|
| 380 |
default="",
|
| 381 |
help="Server name"
|
| 382 |
)
|
| 383 |
-
|
| 384 |
parser.add_argument(
|
| 385 |
"--gpu",
|
| 386 |
type=str,
|
|
@@ -482,7 +487,6 @@ def get_lora_dir(i2v):
|
|
| 482 |
|
| 483 |
attention_modes_installed = get_attention_modes()
|
| 484 |
attention_modes_supported = get_supported_attention_modes()
|
| 485 |
-
|
| 486 |
args = _parse_args()
|
| 487 |
args.flow_reverse = True
|
| 488 |
|
|
@@ -513,6 +517,7 @@ if not Path(server_config_filename).is_file():
|
|
| 513 |
"metadata_type": "metadata",
|
| 514 |
"default_ui": "t2v",
|
| 515 |
"boost" : 1,
|
|
|
|
| 516 |
"vae_config": 0,
|
| 517 |
"profile" : profile_type.LowRAM_LowVRAM,
|
| 518 |
"reload_model": 2 }
|
|
@@ -596,7 +601,6 @@ if len(args.vae_config) > 0:
|
|
| 596 |
|
| 597 |
reload_needed = False
|
| 598 |
default_ui = server_config.get("default_ui", "t2v")
|
| 599 |
-
metadata = server_config.get("metadata_type", "metadata")
|
| 600 |
save_path = server_config.get("save_path", os.path.join(os.getcwd(), "gradio_outputs"))
|
| 601 |
use_image2video = default_ui != "t2v"
|
| 602 |
if args.t2v:
|
|
@@ -956,6 +960,7 @@ def apply_changes( state,
|
|
| 956 |
metadata_choice,
|
| 957 |
default_ui_choice ="t2v",
|
| 958 |
boost_choice = 1,
|
|
|
|
| 959 |
reload_choice = 1
|
| 960 |
):
|
| 961 |
if args.lock_config:
|
|
@@ -975,6 +980,7 @@ def apply_changes( state,
|
|
| 975 |
"metadata_choice": metadata_choice,
|
| 976 |
"default_ui" : default_ui_choice,
|
| 977 |
"boost" : boost_choice,
|
|
|
|
| 978 |
"reload_model" : reload_choice,
|
| 979 |
}
|
| 980 |
|
|
@@ -1008,7 +1014,7 @@ def apply_changes( state,
|
|
| 1008 |
text_encoder_filename = server_config["text_encoder_filename"]
|
| 1009 |
vae_config = server_config["vae_config"]
|
| 1010 |
boost = server_config["boost"]
|
| 1011 |
-
if all(change in ["attention_mode", "vae_config", "default_ui", "boost", "save_path", "metadata_choice"] for change in changes ):
|
| 1012 |
pass
|
| 1013 |
else:
|
| 1014 |
reload_needed = True
|
|
@@ -1059,6 +1065,10 @@ def finalize_gallery(state):
|
|
| 1059 |
if "in_progress" in state:
|
| 1060 |
del state["in_progress"]
|
| 1061 |
choice = state.get("selected",0)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1062 |
time.sleep(0.2)
|
| 1063 |
global gen_in_progress
|
| 1064 |
gen_in_progress = False
|
|
@@ -1096,6 +1106,7 @@ def generate_video(
|
|
| 1096 |
tea_cache_start_step_perc,
|
| 1097 |
loras_choices,
|
| 1098 |
loras_mult_choices,
|
|
|
|
| 1099 |
image_to_continue,
|
| 1100 |
image_to_end,
|
| 1101 |
video_to_continue,
|
|
@@ -1104,7 +1115,9 @@ def generate_video(
|
|
| 1104 |
slg_switch,
|
| 1105 |
slg_layers,
|
| 1106 |
slg_start,
|
| 1107 |
-
slg_end,
|
|
|
|
|
|
|
| 1108 |
state,
|
| 1109 |
image2video,
|
| 1110 |
progress=gr.Progress() #track_tqdm= True
|
|
@@ -1243,7 +1256,6 @@ def generate_video(
|
|
| 1243 |
if "abort" in state:
|
| 1244 |
del state["abort"]
|
| 1245 |
state["in_progress"] = True
|
| 1246 |
-
state["selected"] = 0
|
| 1247 |
|
| 1248 |
enable_RIFLEx = RIFLEx_setting == 0 and video_length > (6* 16) or RIFLEx_setting == 1
|
| 1249 |
# VAE Tiling
|
|
@@ -1281,17 +1293,25 @@ def generate_video(
|
|
| 1281 |
seed = random.randint(0, 999999999)
|
| 1282 |
|
| 1283 |
global file_list
|
| 1284 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1285 |
global save_path
|
| 1286 |
os.makedirs(save_path, exist_ok=True)
|
| 1287 |
-
abort = False
|
| 1288 |
-
if trans.enable_teacache:
|
| 1289 |
-
trans.teacache_counter = 0
|
| 1290 |
-
trans.num_steps = num_inference_steps
|
| 1291 |
-
trans.teacache_skipped_steps = 0
|
| 1292 |
-
trans.previous_residual_uncond = None
|
| 1293 |
-
trans.previous_residual_cond = None
|
| 1294 |
video_no = 0
|
|
|
|
| 1295 |
repeats = f"{video_no}/{repeat_generation}"
|
| 1296 |
callback = build_callback(task_id, state, trans, num_inference_steps, repeats)
|
| 1297 |
offload.shared_state["callback"] = callback
|
|
@@ -1301,14 +1321,22 @@ def generate_video(
|
|
| 1301 |
for i in range(repeat_generation):
|
| 1302 |
try:
|
| 1303 |
with tracker_lock:
|
|
|
|
| 1304 |
progress_tracker[task_id] = {
|
| 1305 |
'current_step': 0,
|
| 1306 |
'total_steps': num_inference_steps,
|
| 1307 |
-
'start_time':
|
| 1308 |
-
'last_update':
|
| 1309 |
'repeats': f"{video_no}/{repeat_generation}",
|
| 1310 |
'status': "Encoding Prompt"
|
| 1311 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1312 |
video_no += 1
|
| 1313 |
if image2video:
|
| 1314 |
samples = wan_model.generate(
|
|
@@ -1330,6 +1358,8 @@ def generate_video(
|
|
| 1330 |
slg_layers = slg_layers,
|
| 1331 |
slg_start = slg_start/100,
|
| 1332 |
slg_end = slg_end/100,
|
|
|
|
|
|
|
| 1333 |
)
|
| 1334 |
else:
|
| 1335 |
samples = wan_model.generate(
|
|
@@ -1349,13 +1379,15 @@ def generate_video(
|
|
| 1349 |
slg_layers = slg_layers,
|
| 1350 |
slg_start = slg_start/100,
|
| 1351 |
slg_end = slg_end/100,
|
|
|
|
|
|
|
| 1352 |
)
|
| 1353 |
except Exception as e:
|
| 1354 |
gen_in_progress = False
|
| 1355 |
if temp_filename!= None and os.path.isfile(temp_filename):
|
| 1356 |
os.remove(temp_filename)
|
| 1357 |
-
|
| 1358 |
-
|
| 1359 |
# if compile:
|
| 1360 |
# cache_size = torch._dynamo.config.cache_size_limit
|
| 1361 |
# torch.compiler.reset()
|
|
@@ -1399,6 +1431,7 @@ def generate_video(
|
|
| 1399 |
end_time = time.time()
|
| 1400 |
abort = True
|
| 1401 |
state["prompt"] = ""
|
|
|
|
| 1402 |
else:
|
| 1403 |
sample = samples.cpu()
|
| 1404 |
# video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")
|
|
@@ -1416,9 +1449,9 @@ def generate_video(
|
|
| 1416 |
nrow=1,
|
| 1417 |
normalize=True,
|
| 1418 |
value_range=(-1, 1))
|
| 1419 |
-
|
| 1420 |
configs = get_settings_dict(state, use_image2video, prompt, 0 if image_to_end == None else 1 , video_length, raw_resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
|
| 1421 |
-
loras_mult_choices, tea_cache , tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end)
|
| 1422 |
|
| 1423 |
metadata_choice = server_config.get("metadata_choice","metadata")
|
| 1424 |
if metadata_choice == "json":
|
|
@@ -1432,7 +1465,15 @@ def generate_video(
|
|
| 1432 |
|
| 1433 |
print(f"New video saved to Path: "+video_path)
|
| 1434 |
file_list.append(video_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1435 |
seed += 1
|
|
|
|
|
|
|
| 1436 |
last_model_type = image2video
|
| 1437 |
|
| 1438 |
if temp_filename!= None and os.path.isfile(temp_filename):
|
|
@@ -1725,8 +1766,9 @@ def switch_advanced(state, new_advanced, lset_name):
|
|
| 1725 |
else:
|
| 1726 |
return gr.Row(visible=new_advanced), gr.Row(visible=True), gr.Button(visible=True), gr.Row(visible= False), gr.Dropdown(choices=lset_choices, value= lset_name)
|
| 1727 |
|
|
|
|
| 1728 |
def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
|
| 1729 |
-
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc):
|
| 1730 |
|
| 1731 |
loras = state["loras"]
|
| 1732 |
activated_loras = [Path( loras[int(no)]).parts[-1] for no in loras_choices ]
|
|
@@ -1750,7 +1792,9 @@ def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resol
|
|
| 1750 |
"slg_switch": slg_switch,
|
| 1751 |
"slg_layers": slg_layers,
|
| 1752 |
"slg_start_perc": slg_start_perc,
|
| 1753 |
-
"slg_end_perc": slg_end_perc
|
|
|
|
|
|
|
| 1754 |
}
|
| 1755 |
|
| 1756 |
if i2v:
|
|
@@ -1758,21 +1802,22 @@ def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resol
|
|
| 1758 |
ui_settings["image_prompt_type"] = image_prompt_type
|
| 1759 |
else:
|
| 1760 |
ui_settings["type"] = "Wan2.1GP by DeepBeepMeep - text2video"
|
|
|
|
| 1761 |
return ui_settings
|
| 1762 |
|
| 1763 |
def save_settings(state, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
|
| 1764 |
-
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc):
|
| 1765 |
|
| 1766 |
if state.get("validate_success",0) != 1:
|
| 1767 |
return
|
| 1768 |
|
| 1769 |
ui_defaults = get_settings_dict(state, use_image2video, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
|
| 1770 |
-
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc)
|
| 1771 |
|
| 1772 |
defaults_filename = get_settings_file_name(use_image2video)
|
| 1773 |
|
| 1774 |
with open(defaults_filename, "w", encoding="utf-8") as f:
|
| 1775 |
-
json.dump(ui_defaults
|
| 1776 |
|
| 1777 |
gr.Info("New Default Settings saved")
|
| 1778 |
|
|
@@ -1907,6 +1952,7 @@ def generate_video_tab(image2video=False):
|
|
| 1907 |
return gr.Gallery(visible = (image_prompt_type_radio == 1) )
|
| 1908 |
else:
|
| 1909 |
return gr.Image(visible = (image_prompt_type_radio == 1) )
|
|
|
|
| 1910 |
image_prompt_type_radio.change(fn=switch_image_prompt_type_radio, inputs=[image_prompt_type_radio], outputs=[image_to_end])
|
| 1911 |
|
| 1912 |
|
|
@@ -2037,7 +2083,7 @@ def generate_video_tab(image2video=False):
|
|
| 2037 |
label="RIFLEx positional embedding to generate long video"
|
| 2038 |
)
|
| 2039 |
with gr.Row():
|
| 2040 |
-
gr.Markdown("<B>Experimental: Skip Layer
|
| 2041 |
with gr.Row():
|
| 2042 |
slg_switch = gr.Dropdown(
|
| 2043 |
choices=[
|
|
@@ -2061,6 +2107,23 @@ def generate_video_tab(image2video=False):
|
|
| 2061 |
with gr.Row():
|
| 2062 |
slg_start_perc = gr.Slider(0, 100, value=ui_defaults["slg_start_perc"], step=1, label="Denoising Steps % start")
|
| 2063 |
slg_end_perc = gr.Slider(0, 100, value=ui_defaults["slg_end_perc"], step=1, label="Denoising Steps % end")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2064 |
with gr.Row():
|
| 2065 |
save_settings_btn = gr.Button("Set Settings as Default")
|
| 2066 |
show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
|
|
@@ -2103,7 +2166,7 @@ def generate_video_tab(image2video=False):
|
|
| 2103 |
save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
|
| 2104 |
save_settings, inputs = [state, prompt, image_prompt_type_radio, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt,
|
| 2105 |
loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers,
|
| 2106 |
-
slg_start_perc, slg_end_perc ], outputs = [])
|
| 2107 |
save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
|
| 2108 |
confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
|
| 2109 |
save_lset, inputs=[state, lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn,refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
|
|
@@ -2133,6 +2196,7 @@ def generate_video_tab(image2video=False):
|
|
| 2133 |
tea_cache_start_step_perc,
|
| 2134 |
loras_choices,
|
| 2135 |
loras_mult_choices,
|
|
|
|
| 2136 |
image_to_continue,
|
| 2137 |
image_to_end,
|
| 2138 |
video_to_continue,
|
|
@@ -2142,6 +2206,8 @@ def generate_video_tab(image2video=False):
|
|
| 2142 |
slg_layers,
|
| 2143 |
slg_start_perc,
|
| 2144 |
slg_end_perc,
|
|
|
|
|
|
|
| 2145 |
state,
|
| 2146 |
gr.State(image2video)
|
| 2147 |
]
|
|
@@ -2276,7 +2342,7 @@ def generate_configuration_tab():
|
|
| 2276 |
("Add metadata to video", "metadata"),
|
| 2277 |
("Neither", "none")
|
| 2278 |
],
|
| 2279 |
-
value=metadata,
|
| 2280 |
label="Metadata Handling"
|
| 2281 |
)
|
| 2282 |
reload_choice = gr.Dropdown(
|
|
@@ -2287,6 +2353,21 @@ def generate_configuration_tab():
|
|
| 2287 |
value=server_config.get("reload_model",2),
|
| 2288 |
label="Reload model"
|
| 2289 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2290 |
msg = gr.Markdown()
|
| 2291 |
apply_btn = gr.Button("Apply Changes")
|
| 2292 |
apply_btn.click(
|
|
@@ -2304,6 +2385,7 @@ def generate_configuration_tab():
|
|
| 2304 |
metadata_choice,
|
| 2305 |
default_ui_choice,
|
| 2306 |
boost_choice,
|
|
|
|
| 2307 |
reload_choice,
|
| 2308 |
],
|
| 2309 |
outputs= msg
|
|
@@ -2322,6 +2404,9 @@ def generate_about_tab():
|
|
| 2322 |
def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
|
| 2323 |
global lora_model_filename, use_image2video
|
| 2324 |
|
|
|
|
|
|
|
|
|
|
| 2325 |
new_t2v = evt.index == 0
|
| 2326 |
new_i2v = evt.index == 1
|
| 2327 |
use_image2video = new_i2v
|
|
@@ -2341,9 +2426,6 @@ def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
|
|
| 2341 |
wan_model, offloadobj, trans = load_models(use_image2video)
|
| 2342 |
del trans
|
| 2343 |
|
| 2344 |
-
t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode)
|
| 2345 |
-
i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode)
|
| 2346 |
-
|
| 2347 |
if new_t2v:
|
| 2348 |
lora_model_filename = t2v_state["loras_model"]
|
| 2349 |
if ("1.3B" in transformer_filename_t2v and not "1.3B" in lora_model_filename or "14B" in transformer_filename_t2v and not "14B" in lora_model_filename):
|
|
@@ -2470,7 +2552,7 @@ def create_demo():
|
|
| 2470 |
}
|
| 2471 |
"""
|
| 2472 |
with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md")) as demo:
|
| 2473 |
-
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v3.
|
| 2474 |
gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !</FONT>")
|
| 2475 |
|
| 2476 |
with gr.Accordion("Click here for some Info on how to use Wan2GP", open = False):
|
|
|
|
| 24 |
from wan.utils import prompt_parser
|
| 25 |
PROMPT_VARS_MAX = 10
|
| 26 |
|
| 27 |
+
target_mmgp_version = "3.3.4"
|
| 28 |
from importlib.metadata import version
|
| 29 |
mmgp_version = version("mmgp")
|
| 30 |
if mmgp_version != target_mmgp_version:
|
|
|
|
| 98 |
tea_cache_start_step_perc,
|
| 99 |
loras_choices,
|
| 100 |
loras_mult_choices,
|
| 101 |
+
image_prompt_type,
|
| 102 |
image_to_continue,
|
| 103 |
image_to_end,
|
| 104 |
video_to_continue,
|
|
|
|
| 108 |
slg_layers,
|
| 109 |
slg_start,
|
| 110 |
slg_end,
|
| 111 |
+
cfg_star_switch,
|
| 112 |
+
cfg_zero_step,
|
| 113 |
state_arg,
|
| 114 |
image2video
|
| 115 |
):
|
|
|
|
| 141 |
tea_cache_start_step_perc,
|
| 142 |
loras_choices,
|
| 143 |
loras_mult_choices,
|
| 144 |
+
image_prompt_type,
|
| 145 |
image_to_continue,
|
| 146 |
image_to_end,
|
| 147 |
video_to_continue,
|
|
|
|
| 151 |
slg_layers,
|
| 152 |
slg_start,
|
| 153 |
slg_end,
|
| 154 |
+
cfg_star_switch,
|
| 155 |
+
cfg_zero_step,
|
| 156 |
state_arg,
|
| 157 |
image2video
|
| 158 |
)
|
|
|
|
| 386 |
default="",
|
| 387 |
help="Server name"
|
| 388 |
)
|
|
|
|
| 389 |
parser.add_argument(
|
| 390 |
"--gpu",
|
| 391 |
type=str,
|
|
|
|
| 487 |
|
| 488 |
attention_modes_installed = get_attention_modes()
|
| 489 |
attention_modes_supported = get_supported_attention_modes()
|
|
|
|
| 490 |
args = _parse_args()
|
| 491 |
args.flow_reverse = True
|
| 492 |
|
|
|
|
| 517 |
"metadata_type": "metadata",
|
| 518 |
"default_ui": "t2v",
|
| 519 |
"boost" : 1,
|
| 520 |
+
"clear_file_list" : 0,
|
| 521 |
"vae_config": 0,
|
| 522 |
"profile" : profile_type.LowRAM_LowVRAM,
|
| 523 |
"reload_model": 2 }
|
|
|
|
| 601 |
|
| 602 |
reload_needed = False
|
| 603 |
default_ui = server_config.get("default_ui", "t2v")
|
|
|
|
| 604 |
save_path = server_config.get("save_path", os.path.join(os.getcwd(), "gradio_outputs"))
|
| 605 |
use_image2video = default_ui != "t2v"
|
| 606 |
if args.t2v:
|
|
|
|
| 960 |
metadata_choice,
|
| 961 |
default_ui_choice ="t2v",
|
| 962 |
boost_choice = 1,
|
| 963 |
+
clear_file_list = 0,
|
| 964 |
reload_choice = 1
|
| 965 |
):
|
| 966 |
if args.lock_config:
|
|
|
|
| 980 |
"metadata_choice": metadata_choice,
|
| 981 |
"default_ui" : default_ui_choice,
|
| 982 |
"boost" : boost_choice,
|
| 983 |
+
"clear_file_list" : clear_file_list
|
| 984 |
"reload_model" : reload_choice,
|
| 985 |
}
|
| 986 |
|
|
|
|
| 1014 |
text_encoder_filename = server_config["text_encoder_filename"]
|
| 1015 |
vae_config = server_config["vae_config"]
|
| 1016 |
boost = server_config["boost"]
|
| 1017 |
+
if all(change in ["attention_mode", "vae_config", "default_ui", "boost", "save_path", "metadata_choice", "clear_file_list"] for change in changes ):
|
| 1018 |
pass
|
| 1019 |
else:
|
| 1020 |
reload_needed = True
|
|
|
|
| 1065 |
if "in_progress" in state:
|
| 1066 |
del state["in_progress"]
|
| 1067 |
choice = state.get("selected",0)
|
| 1068 |
+
# file_list = state.get("file_list", [])
|
| 1069 |
+
|
| 1070 |
+
|
| 1071 |
+
state["extra_orders"] = 0
|
| 1072 |
time.sleep(0.2)
|
| 1073 |
global gen_in_progress
|
| 1074 |
gen_in_progress = False
|
|
|
|
| 1106 |
tea_cache_start_step_perc,
|
| 1107 |
loras_choices,
|
| 1108 |
loras_mult_choices,
|
| 1109 |
+
image_prompt_type,
|
| 1110 |
image_to_continue,
|
| 1111 |
image_to_end,
|
| 1112 |
video_to_continue,
|
|
|
|
| 1115 |
slg_switch,
|
| 1116 |
slg_layers,
|
| 1117 |
slg_start,
|
| 1118 |
+
slg_end,
|
| 1119 |
+
cfg_star_switch,
|
| 1120 |
+
cfg_zero_step,
|
| 1121 |
state,
|
| 1122 |
image2video,
|
| 1123 |
progress=gr.Progress() #track_tqdm= True
|
|
|
|
| 1256 |
if "abort" in state:
|
| 1257 |
del state["abort"]
|
| 1258 |
state["in_progress"] = True
|
|
|
|
| 1259 |
|
| 1260 |
enable_RIFLEx = RIFLEx_setting == 0 and video_length > (6* 16) or RIFLEx_setting == 1
|
| 1261 |
# VAE Tiling
|
|
|
|
| 1293 |
seed = random.randint(0, 999999999)
|
| 1294 |
|
| 1295 |
global file_list
|
| 1296 |
+
clear_file_list = server_config.get("clear_file_list", 0)
|
| 1297 |
+
file_list = state.get("file_list", [])
|
| 1298 |
+
if clear_file_list > 0:
|
| 1299 |
+
file_list_current_size = len(file_list)
|
| 1300 |
+
keep_file_from = max(file_list_current_size - clear_file_list, 0)
|
| 1301 |
+
files_removed = keep_file_from
|
| 1302 |
+
choice = state.get("selected",0)
|
| 1303 |
+
choice = max(choice- files_removed, 0)
|
| 1304 |
+
file_list = file_list[ keep_file_from: ]
|
| 1305 |
+
else:
|
| 1306 |
+
file_list = []
|
| 1307 |
+
choice = 0
|
| 1308 |
+
state["selected"] = choice
|
| 1309 |
+
state["file_list"] = file_list
|
| 1310 |
+
|
| 1311 |
global save_path
|
| 1312 |
os.makedirs(save_path, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1313 |
video_no = 0
|
| 1314 |
+
abort = False
|
| 1315 |
repeats = f"{video_no}/{repeat_generation}"
|
| 1316 |
callback = build_callback(task_id, state, trans, num_inference_steps, repeats)
|
| 1317 |
offload.shared_state["callback"] = callback
|
|
|
|
| 1321 |
for i in range(repeat_generation):
|
| 1322 |
try:
|
| 1323 |
with tracker_lock:
|
| 1324 |
+
start_time = time.time()
|
| 1325 |
progress_tracker[task_id] = {
|
| 1326 |
'current_step': 0,
|
| 1327 |
'total_steps': num_inference_steps,
|
| 1328 |
+
'start_time': start_time,
|
| 1329 |
+
'last_update': start_time,
|
| 1330 |
'repeats': f"{video_no}/{repeat_generation}",
|
| 1331 |
'status': "Encoding Prompt"
|
| 1332 |
}
|
| 1333 |
+
if trans.enable_teacache:
|
| 1334 |
+
trans.teacache_counter = 0
|
| 1335 |
+
trans.num_steps = num_inference_steps
|
| 1336 |
+
trans.teacache_skipped_steps = 0
|
| 1337 |
+
trans.previous_residual_uncond = None
|
| 1338 |
+
trans.previous_residual_cond = None
|
| 1339 |
+
|
| 1340 |
video_no += 1
|
| 1341 |
if image2video:
|
| 1342 |
samples = wan_model.generate(
|
|
|
|
| 1358 |
slg_layers = slg_layers,
|
| 1359 |
slg_start = slg_start/100,
|
| 1360 |
slg_end = slg_end/100,
|
| 1361 |
+
cfg_star_switch = cfg_star_switch,
|
| 1362 |
+
cfg_zero_step = cfg_zero_step,
|
| 1363 |
)
|
| 1364 |
else:
|
| 1365 |
samples = wan_model.generate(
|
|
|
|
| 1379 |
slg_layers = slg_layers,
|
| 1380 |
slg_start = slg_start/100,
|
| 1381 |
slg_end = slg_end/100,
|
| 1382 |
+
cfg_star_switch = cfg_star_switch,
|
| 1383 |
+
cfg_zero_step = cfg_zero_step,
|
| 1384 |
)
|
| 1385 |
except Exception as e:
|
| 1386 |
gen_in_progress = False
|
| 1387 |
if temp_filename!= None and os.path.isfile(temp_filename):
|
| 1388 |
os.remove(temp_filename)
|
| 1389 |
+
offload.last_offload_obj.unload_all()
|
| 1390 |
+
offload.unload_loras_from_model(trans)
|
| 1391 |
# if compile:
|
| 1392 |
# cache_size = torch._dynamo.config.cache_size_limit
|
| 1393 |
# torch.compiler.reset()
|
|
|
|
| 1431 |
end_time = time.time()
|
| 1432 |
abort = True
|
| 1433 |
state["prompt"] = ""
|
| 1434 |
+
yield f"Video generation was aborted. Total Generation Time: {end_time-start_time:.1f}s"
|
| 1435 |
else:
|
| 1436 |
sample = samples.cpu()
|
| 1437 |
# video = rearrange(sample.cpu().numpy(), "c t h w -> t h w c")
|
|
|
|
| 1449 |
nrow=1,
|
| 1450 |
normalize=True,
|
| 1451 |
value_range=(-1, 1))
|
| 1452 |
+
|
| 1453 |
configs = get_settings_dict(state, use_image2video, prompt, 0 if image_to_end == None else 1 , video_length, raw_resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
|
| 1454 |
+
loras_mult_choices, tea_cache , tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start, slg_end, cfg_star_switch, cfg_zero_step)
|
| 1455 |
|
| 1456 |
metadata_choice = server_config.get("metadata_choice","metadata")
|
| 1457 |
if metadata_choice == "json":
|
|
|
|
| 1465 |
|
| 1466 |
print(f"New video saved to Path: "+video_path)
|
| 1467 |
file_list.append(video_path)
|
| 1468 |
+
if video_no < total_video:
|
| 1469 |
+
yield status
|
| 1470 |
+
else:
|
| 1471 |
+
end_time = time.time()
|
| 1472 |
+
state["prompt"] = ""
|
| 1473 |
+
yield f"Total Generation Time: {end_time-start_time:.1f}s"
|
| 1474 |
seed += 1
|
| 1475 |
+
repeat_no += 1
|
| 1476 |
+
|
| 1477 |
last_model_type = image2video
|
| 1478 |
|
| 1479 |
if temp_filename!= None and os.path.isfile(temp_filename):
|
|
|
|
| 1766 |
else:
|
| 1767 |
return gr.Row(visible=new_advanced), gr.Row(visible=True), gr.Button(visible=True), gr.Row(visible= False), gr.Dropdown(choices=lset_choices, value= lset_name)
|
| 1768 |
|
| 1769 |
+
|
| 1770 |
def get_settings_dict(state, i2v, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
|
| 1771 |
+
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step):
|
| 1772 |
|
| 1773 |
loras = state["loras"]
|
| 1774 |
activated_loras = [Path( loras[int(no)]).parts[-1] for no in loras_choices ]
|
|
|
|
| 1792 |
"slg_switch": slg_switch,
|
| 1793 |
"slg_layers": slg_layers,
|
| 1794 |
"slg_start_perc": slg_start_perc,
|
| 1795 |
+
"slg_end_perc": slg_end_perc,
|
| 1796 |
+
"cfg_star_switch": cfg_star_switch,
|
| 1797 |
+
"cfg_zero_step": cfg_zero_step
|
| 1798 |
}
|
| 1799 |
|
| 1800 |
if i2v:
|
|
|
|
| 1802 |
ui_settings["image_prompt_type"] = image_prompt_type
|
| 1803 |
else:
|
| 1804 |
ui_settings["type"] = "Wan2.1GP by DeepBeepMeep - text2video"
|
| 1805 |
+
|
| 1806 |
return ui_settings
|
| 1807 |
|
| 1808 |
def save_settings(state, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
|
| 1809 |
+
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step):
|
| 1810 |
|
| 1811 |
if state.get("validate_success",0) != 1:
|
| 1812 |
return
|
| 1813 |
|
| 1814 |
ui_defaults = get_settings_dict(state, use_image2video, prompt, image_prompt_type, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt, loras_choices,
|
| 1815 |
+
loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers, slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step)
|
| 1816 |
|
| 1817 |
defaults_filename = get_settings_file_name(use_image2video)
|
| 1818 |
|
| 1819 |
with open(defaults_filename, "w", encoding="utf-8") as f:
|
| 1820 |
+
json.dump(ui_defaults, f, indent=4)
|
| 1821 |
|
| 1822 |
gr.Info("New Default Settings saved")
|
| 1823 |
|
|
|
|
| 1952 |
return gr.Gallery(visible = (image_prompt_type_radio == 1) )
|
| 1953 |
else:
|
| 1954 |
return gr.Image(visible = (image_prompt_type_radio == 1) )
|
| 1955 |
+
|
| 1956 |
image_prompt_type_radio.change(fn=switch_image_prompt_type_radio, inputs=[image_prompt_type_radio], outputs=[image_to_end])
|
| 1957 |
|
| 1958 |
|
|
|
|
| 2083 |
label="RIFLEx positional embedding to generate long video"
|
| 2084 |
)
|
| 2085 |
with gr.Row():
|
| 2086 |
+
gr.Markdown("<B>Experimental: Skip Layer Guidance, should improve video quality</B>")
|
| 2087 |
with gr.Row():
|
| 2088 |
slg_switch = gr.Dropdown(
|
| 2089 |
choices=[
|
|
|
|
| 2107 |
with gr.Row():
|
| 2108 |
slg_start_perc = gr.Slider(0, 100, value=ui_defaults["slg_start_perc"], step=1, label="Denoising Steps % start")
|
| 2109 |
slg_end_perc = gr.Slider(0, 100, value=ui_defaults["slg_end_perc"], step=1, label="Denoising Steps % end")
|
| 2110 |
+
|
| 2111 |
+
with gr.Row():
|
| 2112 |
+
gr.Markdown("<B>Experimental: Classifier-Free Guidance Zero Star, better adherence to Text Prompt")
|
| 2113 |
+
with gr.Row():
|
| 2114 |
+
cfg_star_switch = gr.Dropdown(
|
| 2115 |
+
choices=[
|
| 2116 |
+
("OFF", 0),
|
| 2117 |
+
("ON", 1),
|
| 2118 |
+
],
|
| 2119 |
+
value=ui_defaults.get("cfg_star_switch",0),
|
| 2120 |
+
visible=True,
|
| 2121 |
+
scale = 1,
|
| 2122 |
+
label="CFG Star"
|
| 2123 |
+
)
|
| 2124 |
+
with gr.Row():
|
| 2125 |
+
cfg_zero_step = gr.Slider(-1, 39, value=ui_defaults.get("cfg_zero_step",-1), step=1, label="CFG Zero below this Layer (Extra Process)")
|
| 2126 |
+
|
| 2127 |
with gr.Row():
|
| 2128 |
save_settings_btn = gr.Button("Set Settings as Default")
|
| 2129 |
show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
|
|
|
|
| 2166 |
save_settings_btn.click( fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
|
| 2167 |
save_settings, inputs = [state, prompt, image_prompt_type_radio, video_length, resolution, num_inference_steps, seed, repeat_generation, multi_images_gen_type, guidance_scale, flow_shift, negative_prompt,
|
| 2168 |
loras_choices, loras_mult_choices, tea_cache_setting, tea_cache_start_step_perc, RIFLEx_setting, slg_switch, slg_layers,
|
| 2169 |
+
slg_start_perc, slg_end_perc, cfg_star_switch, cfg_zero_step ], outputs = [])
|
| 2170 |
save_lset_btn.click(validate_save_lset, inputs=[lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
|
| 2171 |
confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
|
| 2172 |
save_lset, inputs=[state, lset_name, loras_choices, loras_mult_choices, prompt, save_lset_prompt_drop], outputs=[lset_name, apply_lset_btn,refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
|
|
|
|
| 2196 |
tea_cache_start_step_perc,
|
| 2197 |
loras_choices,
|
| 2198 |
loras_mult_choices,
|
| 2199 |
+
image_prompt_type_radio,
|
| 2200 |
image_to_continue,
|
| 2201 |
image_to_end,
|
| 2202 |
video_to_continue,
|
|
|
|
| 2206 |
slg_layers,
|
| 2207 |
slg_start_perc,
|
| 2208 |
slg_end_perc,
|
| 2209 |
+
cfg_star_switch,
|
| 2210 |
+
cfg_zero_step,
|
| 2211 |
state,
|
| 2212 |
gr.State(image2video)
|
| 2213 |
]
|
|
|
|
| 2342 |
("Add metadata to video", "metadata"),
|
| 2343 |
("Neither", "none")
|
| 2344 |
],
|
| 2345 |
+
value=server_config.get("metadata_type", "metadata"),
|
| 2346 |
label="Metadata Handling"
|
| 2347 |
)
|
| 2348 |
reload_choice = gr.Dropdown(
|
|
|
|
| 2353 |
value=server_config.get("reload_model",2),
|
| 2354 |
label="Reload model"
|
| 2355 |
)
|
| 2356 |
+
|
| 2357 |
+
clear_file_list_choice = gr.Dropdown(
|
| 2358 |
+
choices=[
|
| 2359 |
+
("None", 0),
|
| 2360 |
+
("Keep the last video", 1),
|
| 2361 |
+
("Keep the last 5 videos", 5),
|
| 2362 |
+
("Keep the last 10 videos", 10),
|
| 2363 |
+
("Keep the last 20 videos", 20),
|
| 2364 |
+
("Keep the last 30 videos", 30),
|
| 2365 |
+
],
|
| 2366 |
+
value=server_config.get("clear_file_list", 0),
|
| 2367 |
+
label="Keep Previously Generated Videos when starting a Generation Batch"
|
| 2368 |
+
)
|
| 2369 |
+
|
| 2370 |
+
|
| 2371 |
msg = gr.Markdown()
|
| 2372 |
apply_btn = gr.Button("Apply Changes")
|
| 2373 |
apply_btn.click(
|
|
|
|
| 2385 |
metadata_choice,
|
| 2386 |
default_ui_choice,
|
| 2387 |
boost_choice,
|
| 2388 |
+
clear_file_list_choice,
|
| 2389 |
reload_choice,
|
| 2390 |
],
|
| 2391 |
outputs= msg
|
|
|
|
| 2404 |
def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
|
| 2405 |
global lora_model_filename, use_image2video
|
| 2406 |
|
| 2407 |
+
t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode)
|
| 2408 |
+
i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode)
|
| 2409 |
+
|
| 2410 |
new_t2v = evt.index == 0
|
| 2411 |
new_i2v = evt.index == 1
|
| 2412 |
use_image2video = new_i2v
|
|
|
|
| 2426 |
wan_model, offloadobj, trans = load_models(use_image2video)
|
| 2427 |
del trans
|
| 2428 |
|
|
|
|
|
|
|
|
|
|
| 2429 |
if new_t2v:
|
| 2430 |
lora_model_filename = t2v_state["loras_model"]
|
| 2431 |
if ("1.3B" in transformer_filename_t2v and not "1.3B" in lora_model_filename or "14B" in transformer_filename_t2v and not "14B" in lora_model_filename):
|
|
|
|
| 2552 |
}
|
| 2553 |
"""
|
| 2554 |
with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md")) as demo:
|
| 2555 |
+
gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v3.2 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
|
| 2556 |
gr.Markdown("<FONT SIZE=3>Welcome to Wan 2.1GP a super fast and low VRAM AI Video Generator !</FONT>")
|
| 2557 |
|
| 2558 |
with gr.Accordion("Click here for some Info on how to use Wan2GP", open = False):
|
requirements.txt
CHANGED
|
@@ -16,6 +16,6 @@ gradio>=5.0.0
|
|
| 16 |
numpy>=1.23.5,<2
|
| 17 |
einops
|
| 18 |
moviepy==1.0.3
|
| 19 |
-
mmgp==3.3.
|
| 20 |
peft==0.14.0
|
| 21 |
mutagen
|
|
|
|
| 16 |
numpy>=1.23.5,<2
|
| 17 |
einops
|
| 18 |
moviepy==1.0.3
|
| 19 |
+
mmgp==3.3.4
|
| 20 |
peft==0.14.0
|
| 21 |
mutagen
|
wan/image2video.py
CHANGED
|
@@ -28,79 +28,19 @@ from wan.modules.posemb_layers import get_rotary_pos_embed
|
|
| 28 |
|
| 29 |
from PIL import Image
|
| 30 |
|
| 31 |
-
def
|
| 32 |
-
images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
|
| 33 |
-
images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
|
| 34 |
-
images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
|
| 35 |
-
result = torch.stack(images)
|
| 36 |
-
return result.to(samples.device, samples.dtype)
|
| 37 |
-
|
| 38 |
-
def bislerp(samples, width, height):
|
| 39 |
-
def slerp(b1, b2, r):
|
| 40 |
-
'''slerps batches b1, b2 according to ratio r, batches should be flat e.g. NxC'''
|
| 41 |
-
|
| 42 |
-
c = b1.shape[-1]
|
| 43 |
-
|
| 44 |
-
#norms
|
| 45 |
-
b1_norms = torch.norm(b1, dim=-1, keepdim=True)
|
| 46 |
-
b2_norms = torch.norm(b2, dim=-1, keepdim=True)
|
| 47 |
-
|
| 48 |
-
#normalize
|
| 49 |
-
b1_normalized = b1 / b1_norms
|
| 50 |
-
b2_normalized = b2 / b2_norms
|
| 51 |
-
|
| 52 |
-
#zero when norms are zero
|
| 53 |
-
b1_normalized[b1_norms.expand(-1,c) == 0.0] = 0.0
|
| 54 |
-
b2_normalized[b2_norms.expand(-1,c) == 0.0] = 0.0
|
| 55 |
-
|
| 56 |
-
#slerp
|
| 57 |
-
dot = (b1_normalized*b2_normalized).sum(1)
|
| 58 |
-
omega = torch.acos(dot)
|
| 59 |
-
so = torch.sin(omega)
|
| 60 |
-
|
| 61 |
-
#technically not mathematically correct, but more pleasing?
|
| 62 |
-
res = (torch.sin((1.0-r.squeeze(1))*omega)/so).unsqueeze(1)*b1_normalized + (torch.sin(r.squeeze(1)*omega)/so).unsqueeze(1) * b2_normalized
|
| 63 |
-
res *= (b1_norms * (1.0-r) + b2_norms * r).expand(-1,c)
|
| 64 |
-
|
| 65 |
-
#edge cases for same or polar opposites
|
| 66 |
-
res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
|
| 67 |
-
res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
|
| 68 |
-
return res
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
def common_upscale(samples, width, height, upscale_method, crop):
|
| 72 |
-
orig_shape = tuple(samples.shape)
|
| 73 |
-
if len(orig_shape) > 4:
|
| 74 |
-
samples = samples.reshape(samples.shape[0], samples.shape[1], -1, samples.shape[-2], samples.shape[-1])
|
| 75 |
-
samples = samples.movedim(2, 1)
|
| 76 |
-
samples = samples.reshape(-1, orig_shape[1], orig_shape[-2], orig_shape[-1])
|
| 77 |
-
if crop == "center":
|
| 78 |
-
old_width = samples.shape[-1]
|
| 79 |
-
old_height = samples.shape[-2]
|
| 80 |
-
old_aspect = old_width / old_height
|
| 81 |
-
new_aspect = width / height
|
| 82 |
-
x = 0
|
| 83 |
-
y = 0
|
| 84 |
-
if old_aspect > new_aspect:
|
| 85 |
-
x = round((old_width - old_width * (new_aspect / old_aspect)) / 2)
|
| 86 |
-
elif old_aspect < new_aspect:
|
| 87 |
-
y = round((old_height - old_height * (old_aspect / new_aspect)) / 2)
|
| 88 |
-
s = samples.narrow(-2, y, old_height - y * 2).narrow(-1, x, old_width - x * 2)
|
| 89 |
-
else:
|
| 90 |
-
s = samples
|
| 91 |
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
elif upscale_method == "lanczos":
|
| 95 |
-
out = lanczos(s, width, height)
|
| 96 |
-
else:
|
| 97 |
-
out = torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
|
| 98 |
|
| 99 |
-
|
| 100 |
-
|
| 101 |
|
| 102 |
-
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
class WanI2V:
|
| 106 |
|
|
@@ -227,6 +167,8 @@ class WanI2V:
|
|
| 227 |
slg_layers = None,
|
| 228 |
slg_start = 0.0,
|
| 229 |
slg_end = 1.0,
|
|
|
|
|
|
|
| 230 |
):
|
| 231 |
r"""
|
| 232 |
Generates video frames from input image and text prompt using diffusion process.
|
|
@@ -375,7 +317,7 @@ class WanI2V:
|
|
| 375 |
|
| 376 |
# sample videos
|
| 377 |
latent = noise
|
| 378 |
-
|
| 379 |
freqs = get_rotary_pos_embed(latent.shape[1:], enable_RIFLEx= enable_RIFLEx)
|
| 380 |
|
| 381 |
arg_c = {
|
|
@@ -456,8 +398,23 @@ class WanI2V:
|
|
| 456 |
del latent_model_input
|
| 457 |
if offload_model:
|
| 458 |
torch.cuda.empty_cache()
|
| 459 |
-
|
| 460 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 461 |
del noise_pred_uncond
|
| 462 |
|
| 463 |
latent = latent.to(
|
|
|
|
| 28 |
|
| 29 |
from PIL import Image
|
| 30 |
|
| 31 |
+
def optimized_scale(positive_flat, negative_flat):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
+
# Calculate dot product
|
| 34 |
+
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
+
# Squared norm of the unconditional prediction
|
| 37 |
+
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
|
| 38 |
|
| 39 |
+
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
|
| 40 |
+
st_star = dot_product / squared_norm
|
| 41 |
+
|
| 42 |
+
return st_star
|
| 43 |
+
|
| 44 |
|
| 45 |
class WanI2V:
|
| 46 |
|
|
|
|
| 167 |
slg_layers = None,
|
| 168 |
slg_start = 0.0,
|
| 169 |
slg_end = 1.0,
|
| 170 |
+
cfg_star_switch = True,
|
| 171 |
+
cfg_zero_step = 5,
|
| 172 |
):
|
| 173 |
r"""
|
| 174 |
Generates video frames from input image and text prompt using diffusion process.
|
|
|
|
| 317 |
|
| 318 |
# sample videos
|
| 319 |
latent = noise
|
| 320 |
+
batch_size = latent.shape[0]
|
| 321 |
freqs = get_rotary_pos_embed(latent.shape[1:], enable_RIFLEx= enable_RIFLEx)
|
| 322 |
|
| 323 |
arg_c = {
|
|
|
|
| 398 |
del latent_model_input
|
| 399 |
if offload_model:
|
| 400 |
torch.cuda.empty_cache()
|
| 401 |
+
# CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
|
| 402 |
+
noise_pred_text = noise_pred_cond
|
| 403 |
+
if cfg_star_switch:
|
| 404 |
+
positive_flat = noise_pred_text.view(batch_size, -1)
|
| 405 |
+
negative_flat = noise_pred_uncond.view(batch_size, -1)
|
| 406 |
+
|
| 407 |
+
alpha = optimized_scale(positive_flat,negative_flat)
|
| 408 |
+
alpha = alpha.view(batch_size, 1, 1, 1)
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
if (i <= cfg_zero_step):
|
| 412 |
+
noise_pred = noise_pred_text*0.
|
| 413 |
+
else:
|
| 414 |
+
noise_pred = noise_pred_uncond * alpha + guide_scale * (noise_pred_text - noise_pred_uncond * alpha)
|
| 415 |
+
else:
|
| 416 |
+
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond)
|
| 417 |
+
|
| 418 |
del noise_pred_uncond
|
| 419 |
|
| 420 |
latent = latent.to(
|
wan/modules/attention.py
CHANGED
|
@@ -70,30 +70,31 @@ def sageattn_wrapper(
|
|
| 70 |
|
| 71 |
return o
|
| 72 |
|
| 73 |
-
#
|
| 74 |
# if True:
|
| 75 |
-
#
|
| 76 |
-
#
|
| 77 |
-
#
|
| 78 |
-
#
|
| 79 |
-
#
|
| 80 |
-
#
|
| 81 |
-
#
|
| 82 |
-
#
|
| 83 |
-
#
|
| 84 |
-
#
|
| 85 |
-
#
|
| 86 |
-
#
|
| 87 |
-
#
|
| 88 |
-
#
|
| 89 |
-
#
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
#
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
#
|
| 96 |
-
#
|
|
|
|
| 97 |
|
| 98 |
@torch.compiler.disable()
|
| 99 |
def sdpa_wrapper(
|
|
@@ -253,17 +254,19 @@ def pay_attention(
|
|
| 253 |
# nb_latents = embed_sizes[0] * embed_sizes[1]* embed_sizes[2]
|
| 254 |
|
| 255 |
# window = 0
|
| 256 |
-
# start_window_step = int(max_steps * 0.
|
| 257 |
# start_layer = 10
|
| 258 |
-
#
|
|
|
|
| 259 |
# window = 0
|
| 260 |
# else:
|
| 261 |
-
# coef = min((max_steps - current_step)/(max_steps-start_window_step),1)*max(min((25 - layer)/(25-start_layer),1),0) * 0.7 + 0.3
|
|
|
|
| 262 |
# print(f"step: {current_step}, layer: {layer}, coef:{coef:0.1f}]")
|
| 263 |
# window = math.ceil(coef* nb_latents)
|
| 264 |
|
| 265 |
# invert_spaces = (layer + current_step) % 2 == 0 and window > 0
|
| 266 |
-
|
| 267 |
# def flip(q):
|
| 268 |
# q = q.reshape(*embed_sizes, *q.shape[-2:])
|
| 269 |
# q = q.transpose(0,2)
|
|
|
|
| 70 |
|
| 71 |
return o
|
| 72 |
|
| 73 |
+
# try:
|
| 74 |
# if True:
|
| 75 |
+
# from .sage2_core import sageattn_qk_int8_pv_fp8_window_cuda
|
| 76 |
+
# @torch.compiler.disable()
|
| 77 |
+
# def sageattn_window_wrapper(
|
| 78 |
+
# qkv_list,
|
| 79 |
+
# attention_length,
|
| 80 |
+
# window
|
| 81 |
+
# ):
|
| 82 |
+
# q,k, v = qkv_list
|
| 83 |
+
# padding_length = q.shape[0] -attention_length
|
| 84 |
+
# q = q[:attention_length, :, : ].unsqueeze(0)
|
| 85 |
+
# k = k[:attention_length, :, : ].unsqueeze(0)
|
| 86 |
+
# v = v[:attention_length, :, : ].unsqueeze(0)
|
| 87 |
+
# qkvl_list = [q, k , v]
|
| 88 |
+
# del q, k ,v
|
| 89 |
+
# o = sageattn_qk_int8_pv_fp8_window_cuda(qkvl_list, tensor_layout="NHD", window = window).squeeze(0)
|
| 90 |
+
# qkv_list.clear()
|
| 91 |
+
|
| 92 |
+
# if padding_length > 0:
|
| 93 |
+
# o = torch.cat([o, torch.empty( (padding_length, *o.shape[-2:]), dtype= o.dtype, device=o.device ) ], 0)
|
| 94 |
+
|
| 95 |
+
# return o
|
| 96 |
+
# except ImportError:
|
| 97 |
+
# sageattn = sageattn_qk_int8_pv_fp8_window_cuda
|
| 98 |
|
| 99 |
@torch.compiler.disable()
|
| 100 |
def sdpa_wrapper(
|
|
|
|
| 254 |
# nb_latents = embed_sizes[0] * embed_sizes[1]* embed_sizes[2]
|
| 255 |
|
| 256 |
# window = 0
|
| 257 |
+
# start_window_step = int(max_steps * 0.3)
|
| 258 |
# start_layer = 10
|
| 259 |
+
# end_layer = 30
|
| 260 |
+
# if (layer < start_layer or layer > end_layer ) or current_step <start_window_step:
|
| 261 |
# window = 0
|
| 262 |
# else:
|
| 263 |
+
# # coef = min((max_steps - current_step)/(max_steps-start_window_step),1)*max(min((25 - layer)/(25-start_layer),1),0) * 0.7 + 0.3
|
| 264 |
+
# coef = 0.3
|
| 265 |
# print(f"step: {current_step}, layer: {layer}, coef:{coef:0.1f}]")
|
| 266 |
# window = math.ceil(coef* nb_latents)
|
| 267 |
|
| 268 |
# invert_spaces = (layer + current_step) % 2 == 0 and window > 0
|
| 269 |
+
# invert_spaces = False
|
| 270 |
# def flip(q):
|
| 271 |
# q = q.reshape(*embed_sizes, *q.shape[-2:])
|
| 272 |
# q = q.transpose(0,2)
|
wan/modules/model.py
CHANGED
|
@@ -647,26 +647,6 @@ class WanModel(ModelMixin, ConfigMixin):
|
|
| 647 |
self.init_weights()
|
| 648 |
|
| 649 |
|
| 650 |
-
# self.freqs = torch.cat([
|
| 651 |
-
# rope_params(1024, d - 4 * (d // 6)), #44
|
| 652 |
-
# rope_params(1024, 2 * (d // 6)), #42
|
| 653 |
-
# rope_params(1024, 2 * (d // 6)) #42
|
| 654 |
-
# ],dim=1)
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
def get_rope_freqs(self, nb_latent_frames, RIFLEx_k = None, device = "cuda"):
|
| 658 |
-
dim = self.dim
|
| 659 |
-
num_heads = self.num_heads
|
| 660 |
-
d = dim // num_heads
|
| 661 |
-
assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
c1, s1 = rope_params_riflex(1024, dim= d - 4 * (d // 6), L_test=nb_latent_frames, k = RIFLEx_k ) if RIFLEx_k != None else rope_params(1024, dim= d - 4 * (d // 6)) #44
|
| 665 |
-
c2, s2 = rope_params(1024, 2 * (d // 6)) #42
|
| 666 |
-
c3, s3 = rope_params(1024, 2 * (d // 6)) #42
|
| 667 |
-
|
| 668 |
-
return (torch.cat([c1,c2,c3],dim=1).to(device) , torch.cat([s1,s2,s3],dim=1).to(device))
|
| 669 |
-
|
| 670 |
def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
|
| 671 |
rescale_func = np.poly1d(self.coefficients)
|
| 672 |
e_list = []
|
|
|
|
| 647 |
self.init_weights()
|
| 648 |
|
| 649 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 650 |
def compute_teacache_threshold(self, start_step, timesteps = None, speed_factor =0):
|
| 651 |
rescale_func = np.poly1d(self.coefficients)
|
| 652 |
e_list = []
|
wan/modules/sage2_core.py
CHANGED
|
@@ -925,11 +925,11 @@ def sageattn_qk_int8_pv_fp8_window_cuda(
|
|
| 925 |
|
| 926 |
if pv_accum_dtype == "fp32":
|
| 927 |
if smooth_v:
|
| 928 |
-
lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window
|
| 929 |
else:
|
| 930 |
-
lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse,
|
| 931 |
elif pv_accum_dtype == "fp32+fp32":
|
| 932 |
-
lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse,
|
| 933 |
|
| 934 |
o = o[..., :head_dim_og]
|
| 935 |
|
|
|
|
| 925 |
|
| 926 |
if pv_accum_dtype == "fp32":
|
| 927 |
if smooth_v:
|
| 928 |
+
lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window)
|
| 929 |
else:
|
| 930 |
+
lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window)
|
| 931 |
elif pv_accum_dtype == "fp32+fp32":
|
| 932 |
+
lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse, window)
|
| 933 |
|
| 934 |
o = o[..., :head_dim_og]
|
| 935 |
|
wan/text2video.py
CHANGED
|
@@ -24,6 +24,20 @@ from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
|
| 24 |
from wan.modules.posemb_layers import get_rotary_pos_embed
|
| 25 |
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
class WanT2V:
|
| 28 |
|
| 29 |
def __init__(
|
|
@@ -136,6 +150,8 @@ class WanT2V:
|
|
| 136 |
slg_layers = None,
|
| 137 |
slg_start = 0.0,
|
| 138 |
slg_end = 1.0,
|
|
|
|
|
|
|
| 139 |
):
|
| 140 |
r"""
|
| 141 |
Generates video frames from text prompt using diffusion process.
|
|
@@ -240,7 +256,7 @@ class WanT2V:
|
|
| 240 |
|
| 241 |
# sample videos
|
| 242 |
latents = noise
|
| 243 |
-
|
| 244 |
freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
|
| 245 |
arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
|
| 246 |
arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
|
|
@@ -249,7 +265,6 @@ class WanT2V:
|
|
| 249 |
# arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
|
| 250 |
# arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
|
| 251 |
# arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
|
| 252 |
-
|
| 253 |
if self.model.enable_teacache:
|
| 254 |
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
|
| 255 |
if callback != None:
|
|
@@ -280,8 +295,23 @@ class WanT2V:
|
|
| 280 |
return None
|
| 281 |
|
| 282 |
del latent_model_input
|
| 283 |
-
|
| 284 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 285 |
del noise_pred_uncond
|
| 286 |
|
| 287 |
temp_x0 = sample_scheduler.step(
|
|
|
|
| 24 |
from wan.modules.posemb_layers import get_rotary_pos_embed
|
| 25 |
|
| 26 |
|
| 27 |
+
def optimized_scale(positive_flat, negative_flat):
|
| 28 |
+
|
| 29 |
+
# Calculate dot product
|
| 30 |
+
dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
|
| 31 |
+
|
| 32 |
+
# Squared norm of the unconditional prediction
|
| 33 |
+
squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
|
| 34 |
+
|
| 35 |
+
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
|
| 36 |
+
st_star = dot_product / squared_norm
|
| 37 |
+
|
| 38 |
+
return st_star
|
| 39 |
+
|
| 40 |
+
|
| 41 |
class WanT2V:
|
| 42 |
|
| 43 |
def __init__(
|
|
|
|
| 150 |
slg_layers = None,
|
| 151 |
slg_start = 0.0,
|
| 152 |
slg_end = 1.0,
|
| 153 |
+
cfg_star_switch = True,
|
| 154 |
+
cfg_zero_step = 5,
|
| 155 |
):
|
| 156 |
r"""
|
| 157 |
Generates video frames from text prompt using diffusion process.
|
|
|
|
| 256 |
|
| 257 |
# sample videos
|
| 258 |
latents = noise
|
| 259 |
+
batch_size =len(latents)
|
| 260 |
freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
|
| 261 |
arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
|
| 262 |
arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self}
|
|
|
|
| 265 |
# arg_c = {'context': context, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
|
| 266 |
# arg_null = {'context': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
|
| 267 |
# arg_both = {'context': context, 'context2': context_null, 'seq_len': seq_len, 'freqs': freqs, 'pipeline': self, "max_steps": sampling_steps}
|
|
|
|
| 268 |
if self.model.enable_teacache:
|
| 269 |
self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
|
| 270 |
if callback != None:
|
|
|
|
| 295 |
return None
|
| 296 |
|
| 297 |
del latent_model_input
|
| 298 |
+
|
| 299 |
+
# CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
|
| 300 |
+
noise_pred_text = noise_pred_cond
|
| 301 |
+
if cfg_star_switch:
|
| 302 |
+
positive_flat = noise_pred_text.view(batch_size, -1)
|
| 303 |
+
negative_flat = noise_pred_uncond.view(batch_size, -1)
|
| 304 |
+
|
| 305 |
+
alpha = optimized_scale(positive_flat,negative_flat)
|
| 306 |
+
alpha = alpha.view(batch_size, 1, 1, 1)
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
if (i <= cfg_zero_step):
|
| 310 |
+
noise_pred = noise_pred_text*0.
|
| 311 |
+
else:
|
| 312 |
+
noise_pred = noise_pred_uncond * alpha + guide_scale * (noise_pred_text - noise_pred_uncond * alpha)
|
| 313 |
+
else:
|
| 314 |
+
noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond)
|
| 315 |
del noise_pred_uncond
|
| 316 |
|
| 317 |
temp_x0 = sample_scheduler.step(
|