DeepBeepMeep committed on
Commit
238b44c
·
1 Parent(s): a8a3e31

added default models selection option

Browse files
Files changed (1) hide show
  1. wgp.py +81 -81
wgp.py CHANGED
@@ -441,6 +441,12 @@ def _parse_args():
441
  help="Prevent modifying the configuration from the web interface"
442
  )
443
 
 
 
 
 
 
 
444
  parser.add_argument(
445
  "--preload",
446
  type=str,
@@ -700,11 +706,15 @@ transformer_choices_t2v=["ckpts/wan2.1_text2video_1.3B_bf16.safetensors", "ckpts
700
  transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", ]
701
  transformer_choices = transformer_choices_t2v + transformer_choices_i2v
702
  text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors"]
703
- server_config_filename = "gradio_config.json"
 
 
 
 
704
 
705
  if not Path(server_config_filename).is_file():
706
  server_config = {"attention_mode" : "auto",
707
- "transformer_type": "t2v",
708
  "transformer_quantization": "int8",
709
  "text_encoder_filename" : text_encoder_choices[1],
710
  "save_path": os.path.join(os.getcwd(), "gradio_outputs"),
@@ -826,7 +836,8 @@ def get_default_settings(filename):
826
  ui_defaults["num_inference_steps"] = default_number_steps
827
  return ui_defaults
828
 
829
- transformer_type = server_config.get("transformer_type", "t2v")
 
830
  transformer_quantization =server_config.get("transformer_quantization", "int8")
831
  transformer_filename = get_model_filename(transformer_type, transformer_quantization)
832
  text_encoder_filename = server_config["text_encoder_filename"]
@@ -1213,22 +1224,25 @@ def get_model_name(model_filename):
1213
  # return header
1214
 
1215
 
1216
- def generate_header(compile, attention_mode):
1217
 
1218
- header = "<DIV style='align:right;width:100%'><FONT SIZE=2>Attention mode: " + (attention_mode if attention_mode!="auto" else "auto/" + get_auto_attention() )
1219
  if attention_mode not in attention_modes_installed:
1220
  header += " -NOT INSTALLED-"
1221
  elif attention_mode not in attention_modes_supported:
1222
  header += " -NOT SUPPORTED-"
 
1223
 
1224
  if compile:
1225
- header += ", pytorch compilation ON"
 
 
1226
  header += "<FONT></DIV>"
1227
 
1228
  return header
1229
 
1230
  def apply_changes( state,
1231
- transformer_type_choice,
1232
  text_encoder_choice,
1233
  save_path_choice,
1234
  attention_choice,
@@ -1239,16 +1253,15 @@ def apply_changes( state,
1239
  quantization_choice,
1240
  boost_choice = 1,
1241
  clear_file_list = 0,
1242
- reload_choice = 1
1243
  ):
1244
  if args.lock_config:
1245
  return
1246
  if gen_in_progress:
1247
- yield "<DIV ALIGN=CENTER>Unable to change config when a generation is in progress</DIV>"
1248
- return
1249
  global offloadobj, wan_model, server_config, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset, loras_presets
1250
  server_config = {"attention_mode" : attention_choice,
1251
- "transformer_type": transformer_type_choice,
1252
  "text_encoder_filename" : text_encoder_choices[text_encoder_choice],
1253
  "save_path" : save_path_choice,
1254
  "compile" : compile_choice,
@@ -1281,7 +1294,7 @@ def apply_changes( state,
1281
  if v != v_old:
1282
  changes.append(k)
1283
 
1284
- global attention_mode, profile, compile, transformer_filename, text_encoder_filename, vae_config, boost, lora_dir, reload_needed, reload_model, transformer_quantization, transformer_type
1285
  attention_mode = server_config["attention_mode"]
1286
  profile = server_config["profile"]
1287
  compile = server_config["compile"]
@@ -1290,15 +1303,19 @@ def apply_changes( state,
1290
  boost = server_config["boost"]
1291
  reload_model = server_config["reload_model"]
1292
  transformer_quantization = server_config["transformer_quantization"]
1293
- transformer_filename = get_model_filename(transformer_type, transformer_quantization)
1294
-
1295
- if all(change in ["attention_mode", "vae_config", "default_ui", "boost", "save_path", "metadata_choice", "clear_file_list"] for change in changes ):
1296
- pass
 
 
 
1297
  else:
1298
  reload_needed = True
 
1299
 
1300
-
1301
- yield "<DIV ALIGN=CENTER>The new configuration has been succesfully applied</DIV>"
1302
 
1303
 
1304
 
@@ -2505,16 +2522,11 @@ def handle_celll_selection(state, evt: gr.SelectData):
2505
 
2506
 
2507
  def change_model(state, model_choice):
2508
- model_filename = ""
2509
- for filename in model_list:
2510
- if get_model_type(filename) == model_choice:
2511
- model_filename = filename
2512
- break
2513
- if len(model_filename) == 0:
2514
  return
2515
-
2516
  state["model_filename"] = model_filename
2517
- header = generate_header(compile=compile, attention_mode=attention_mode)
2518
  return header
2519
 
2520
  def fill_inputs(state):
@@ -3014,53 +3026,32 @@ def generate_download_tab(lset_name,loras_choices, state):
3014
  download_loras_btn.click(fn=download_loras, inputs=[], outputs=[download_status_row, download_status]).then(fn=refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
3015
 
3016
 
3017
- def generate_configuration_tab():
3018
  state_dict = {}
3019
  state = gr.State(state_dict)
3020
  gr.Markdown("Please click Apply Changes at the bottom so that the changes are effective. Some choices below may be locked if the app has been launched by specifying a config preset.")
3021
  with gr.Column():
3022
- index = transformer_choices.index(transformer_filename)
3023
- index = 0 if index ==0 else index
3024
-
3025
  model_list = []
 
3026
  for model_type in model_types:
3027
  choice = get_model_filename(model_type, transformer_quantization)
3028
  model_list.append(choice)
3029
  dropdown_choices = [ ( get_model_name(choice), get_model_type(choice) ) for choice in model_list]
3030
- transformer_type_choice = gr.Dropdown(
3031
  choices= dropdown_choices,
3032
- value= get_model_type(transformer_filename),
3033
- label= "Default Wan Transformer Model",
3034
- scale= 2
 
3035
  )
3036
 
3037
- # transformer_choice = gr.Dropdown(
3038
- # choices=[
3039
- # ("WAN 2.1 1.3B Text to Video 16 bits (recommended)- the small model for fast generations with low VRAM requirements", 0),
3040
- # ("WAN 2.1 14B Text to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 1),
3041
- # ("WAN 2.1 14B Text to Video quantized to 8 bits (recommended) - the default engine but quantized", 2),
3042
- # ("WAN 2.1 VACE 1.3B Text to Video / Control Net - text generation driven by reference images or videos", 3),
3043
- # ("WAN 2.1 - 480p 14B Image to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 4),
3044
- # ("WAN 2.1 - 480p 14B Image to Video quantized to 8 bits (recommended) - the default engine but quantized", 5),
3045
- # ("WAN 2.1 - 720p 14B Image to Video 16 bits - the default engine in its original glory, offers a slightly better image quality but slower and requires more RAM", 6),
3046
- # ("WAN 2.1 - 720p 14B Image to Video quantized to 8 bits - the default engine but quantized", 7),
3047
- # ("WAN 2.1 - Fun InP 1.3B 16 bits - the small model for fast generations with low VRAM requirements", 8),
3048
- # ("WAN 2.1 - Fun InP 14B 16 bits - Fun InP version in its original glory, offers a slightly better image quality but slower and requires more RAM", 9),
3049
- # ("WAN 2.1 - Fun InP 14B quantized to 8 bits - quantized Fun InP version", 10),
3050
- # ],
3051
- # value= index,
3052
- # label="Transformer model for Image to Video",
3053
- # interactive= not lock_ui_transformer,
3054
- # visible = True,
3055
- # )
3056
-
3057
  quantization_choice = gr.Dropdown(
3058
  choices=[
3059
  ("Int8 Quantization (recommended)", "int8"),
3060
  ("BF16 (no quantization)", "bf16"),
3061
  ],
3062
  value= transformer_quantization,
3063
- label="Wan Transformer Model Quantization (if available)",
3064
  )
3065
 
3066
  index = text_encoder_choices.index(text_encoder_filename)
@@ -3137,14 +3128,14 @@ def generate_configuration_tab():
3137
  value= profile,
3138
  label="Profile (for power users only, not needed to change it)"
3139
  )
3140
- default_ui_choice = gr.Dropdown(
3141
- choices=[
3142
- ("Text to Video", "t2v"),
3143
- ("Image to Video", "i2v"),
3144
- ],
3145
- value= default_ui,
3146
- label="Default mode when launching the App if not '--t2v' ot '--i2v' switch is specified when launching the server ",
3147
- )
3148
  metadata_choice = gr.Dropdown(
3149
  choices=[
3150
  ("Export JSON files", "json"),
@@ -3184,7 +3175,7 @@ def generate_configuration_tab():
3184
  fn=apply_changes,
3185
  inputs=[
3186
  state,
3187
- transformer_type_choice,
3188
  text_encoder_choice,
3189
  save_path_choice,
3190
  attention_choice,
@@ -3197,7 +3188,7 @@ def generate_configuration_tab():
3197
  clear_file_list_choice,
3198
  reload_choice,
3199
  ],
3200
- outputs= msg
3201
  )
3202
 
3203
  def generate_about_tab():
@@ -3221,6 +3212,22 @@ def generate_info_tab():
3221
  gr.Markdown("Please note that if your turn on compilation, the first denoising step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
3222
 
3223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3224
 
3225
 
3226
 
@@ -3457,22 +3464,15 @@ def create_demo():
3457
  with gr.Tabs(selected="video_gen", ) as main_tabs:
3458
  with gr.Tab("Video Generator", id="video_gen") as t2v_tab:
3459
  with gr.Row():
3460
- header = gr.Markdown(generate_header(compile, attention_mode), visible= True)
 
 
 
 
 
 
3461
  with gr.Row():
3462
- gr.Markdown("<div class='title-with-lines'><div class=line width=100%></div></div>")
3463
-
3464
- model_list = []
3465
- for model_type in model_types:
3466
- choice = get_model_filename(model_type, transformer_quantization)
3467
- model_list.append(choice)
3468
- dropdown_choices = [ ( get_model_name(choice), get_model_type(choice) ) for choice in model_list]
3469
- model_choice = gr.Dropdown(
3470
- choices= dropdown_choices,
3471
- value= get_model_type(transformer_filename),
3472
- show_label= False,
3473
- scale= 2
3474
- )
3475
- gr.Markdown("<div class='title-with-lines'><div class=line width=100%></div></div>")
3476
  with gr.Row():
3477
 
3478
  loras_choices, lset_name, state = generate_video_tab(model_choice = model_choice, header = header)
@@ -3482,7 +3482,7 @@ def create_demo():
3482
  with gr.Tab("Downloads", id="downloads") as downloads_tab:
3483
  generate_download_tab(lset_name, loras_choices, state)
3484
  with gr.Tab("Configuration"):
3485
- generate_configuration_tab()
3486
  with gr.Tab("About"):
3487
  generate_about_tab()
3488
 
 
441
  help="Prevent modifying the configuration from the web interface"
442
  )
443
 
444
+ parser.add_argument(
445
+ "--lock-model",
446
+ action="store_true",
447
+ help="Prevent switch models"
448
+ )
449
+
450
  parser.add_argument(
451
  "--preload",
452
  type=str,
 
706
  transformer_choices_i2v=["ckpts/wan2.1_image2video_480p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_480p_14B_quanto_int8.safetensors", "ckpts/wan2.1_image2video_720p_14B_bf16.safetensors", "ckpts/wan2.1_image2video_720p_14B_quanto_int8.safetensors", "ckpts/wan2.1_Fun_InP_1.3B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_bf16.safetensors", "ckpts/wan2.1_Fun_InP_14B_quanto_int8.safetensors", ]
707
  transformer_choices = transformer_choices_t2v + transformer_choices_i2v
708
  text_encoder_choices = ["ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors"]
709
+ server_config_filename = "wgp_config.json"
710
+
711
+ if not os.path.isfile(server_config_filename) and os.path.isfile("gradio_config.json"):
712
+ import shutil
713
+ shutil.move("gradio_config.json", server_config_filename)
714
 
715
  if not Path(server_config_filename).is_file():
716
  server_config = {"attention_mode" : "auto",
717
+ "transformer_types": [],
718
  "transformer_quantization": "int8",
719
  "text_encoder_filename" : text_encoder_choices[1],
720
  "save_path": os.path.join(os.getcwd(), "gradio_outputs"),
 
836
  ui_defaults["num_inference_steps"] = default_number_steps
837
  return ui_defaults
838
 
839
+ transformer_types = server_config.get("transformer_types", [])
840
+ transformer_type = transformer_types[0] if len(transformer_types) > 0 else model_types[0]
841
  transformer_quantization =server_config.get("transformer_quantization", "int8")
842
  transformer_filename = get_model_filename(transformer_type, transformer_quantization)
843
  text_encoder_filename = server_config["text_encoder_filename"]
 
1224
  # return header
1225
 
1226
 
1227
def generate_header(model_filename, compile, attention_mode):
    """Build the HTML status banner describing the current runtime settings.

    Args:
        model_filename: Filename of the active transformer model. It is only
            inspected for the substring "int8" to report quantization.
        compile: Truthy when PyTorch compilation is enabled.
        attention_mode: Configured attention mode. The value "auto" is shown
            as "auto/" followed by the result of get_auto_attention().

    Returns:
        An HTML string made of a DIV wrapping a FONT element that lists the
        attention mode and, when applicable, the compilation and
        quantization status.
    """
    if attention_mode == "auto":
        attention_label = "auto/" + get_auto_attention()
    else:
        attention_label = attention_mode

    header = "<DIV style='align:right;width:100%'><FONT SIZE=3>Attention mode <B>" + attention_label
    # Warn when the configured mode is missing from the installed or
    # supported lists maintained at module level.
    if attention_mode not in attention_modes_installed:
        header += " -NOT INSTALLED-"
    elif attention_mode not in attention_modes_supported:
        header += " -NOT SUPPORTED-"
    header += "</B>"

    if compile:
        header += ", Pytorch compilation <B>ON</B>"
    if "int8" in model_filename:
        header += ", Quantization <B>Int8</B>"
    # Close the FONT element opened at the start of the banner so the markup
    # stays balanced.
    header += "</FONT></DIV>"

    return header
1243
 
1244
  def apply_changes( state,
1245
+ transformer_types_choices,
1246
  text_encoder_choice,
1247
  save_path_choice,
1248
  attention_choice,
 
1253
  quantization_choice,
1254
  boost_choice = 1,
1255
  clear_file_list = 0,
1256
+ reload_choice = 1,
1257
  ):
1258
  if args.lock_config:
1259
  return
1260
  if gen_in_progress:
1261
+ return "<DIV ALIGN=CENTER>Unable to change config when a generation is in progress</DIV>"
 
1262
  global offloadobj, wan_model, server_config, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset, loras_presets
1263
  server_config = {"attention_mode" : attention_choice,
1264
+ "transformer_types": transformer_types_choices,
1265
  "text_encoder_filename" : text_encoder_choices[text_encoder_choice],
1266
  "save_path" : save_path_choice,
1267
  "compile" : compile_choice,
 
1294
  if v != v_old:
1295
  changes.append(k)
1296
 
1297
+ global attention_mode, profile, compile, transformer_filename, text_encoder_filename, vae_config, boost, lora_dir, reload_needed, reload_model, transformer_quantization, transformer_types
1298
  attention_mode = server_config["attention_mode"]
1299
  profile = server_config["profile"]
1300
  compile = server_config["compile"]
 
1303
  boost = server_config["boost"]
1304
  reload_model = server_config["reload_model"]
1305
  transformer_quantization = server_config["transformer_quantization"]
1306
+ transformer_types = server_config["transformer_types"]
1307
+ transformer_type = get_model_type(transformer_filename)
1308
+ if not transformer_type in transformer_types:
1309
+ transformer_type = transformer_types[0] if len(transformer_types) > 0 else model_types[0]
1310
+ transformer_filename = get_model_filename(transformer_type, transformer_quantization)
1311
+ if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_choice", "clear_file_list"] for change in changes ):
1312
+ model_choice = gr.Dropdown()
1313
  else:
1314
  reload_needed = True
1315
+ model_choice = generate_dropdown_model_list()
1316
 
1317
+ header = generate_header(transformer_filename, compile=compile, attention_mode= attention_mode)
1318
+ return "<DIV ALIGN=CENTER>The new configuration has been succesfully applied</DIV>", header, model_choice
1319
 
1320
 
1321
 
 
2522
 
2523
 
2524
def change_model(state, model_choice):
    """Switch the model recorded in the UI state and refresh the header.

    Args:
        state: Mutable mapping holding the UI session state; its
            "model_filename" entry is overwritten with the resolved filename.
        model_choice: Model type selected in the dropdown, or None when
            nothing is selected.

    Returns:
        The HTML header produced by generate_header for the newly selected
        model, or None when model_choice is None (state is left untouched).
    """
    # Identity comparison is the reliable way to detect "no selection".
    if model_choice is None:
        # NOTE(review): None is returned to the bound output component here;
        # confirm this does not blank the header in the UI.
        return
    model_filename = get_model_filename(model_choice, transformer_quantization)
    state["model_filename"] = model_filename
    header = generate_header(model_filename, compile=compile, attention_mode=attention_mode)
    return header
2531
 
2532
  def fill_inputs(state):
 
3026
  download_loras_btn.click(fn=download_loras, inputs=[], outputs=[download_status_row, download_status]).then(fn=refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices])
3027
 
3028
 
3029
+ def generate_configuration_tab(header, model_choice):
3030
  state_dict = {}
3031
  state = gr.State(state_dict)
3032
  gr.Markdown("Please click Apply Changes at the bottom so that the changes are effective. Some choices below may be locked if the app has been launched by specifying a config preset.")
3033
  with gr.Column():
 
 
 
3034
  model_list = []
3035
+
3036
  for model_type in model_types:
3037
  choice = get_model_filename(model_type, transformer_quantization)
3038
  model_list.append(choice)
3039
  dropdown_choices = [ ( get_model_name(choice), get_model_type(choice) ) for choice in model_list]
3040
+ transformer_types_choices = gr.Dropdown(
3041
  choices= dropdown_choices,
3042
+ value= transformer_types,
3043
+ label= "Selectable Wan Transformer Models (keep empty to get All of them)",
3044
+ scale= 2,
3045
+ multiselect= True
3046
  )
3047
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3048
  quantization_choice = gr.Dropdown(
3049
  choices=[
3050
  ("Int8 Quantization (recommended)", "int8"),
3051
  ("BF16 (no quantization)", "bf16"),
3052
  ],
3053
  value= transformer_quantization,
3054
+ label="Wan Transformer Model Quantization Type (if available)",
3055
  )
3056
 
3057
  index = text_encoder_choices.index(text_encoder_filename)
 
3128
  value= profile,
3129
  label="Profile (for power users only, not needed to change it)"
3130
  )
3131
+ # default_ui_choice = gr.Dropdown(
3132
+ # choices=[
3133
+ # ("Text to Video", "t2v"),
3134
+ # ("Image to Video", "i2v"),
3135
+ # ],
3136
+ # value= default_ui,
3137
+ # label="Default mode when launching the App if not '--t2v' ot '--i2v' switch is specified when launching the server ",
3138
+ # )
3139
  metadata_choice = gr.Dropdown(
3140
  choices=[
3141
  ("Export JSON files", "json"),
 
3175
  fn=apply_changes,
3176
  inputs=[
3177
  state,
3178
+ transformer_types_choices,
3179
  text_encoder_choice,
3180
  save_path_choice,
3181
  attention_choice,
 
3188
  clear_file_list_choice,
3189
  reload_choice,
3190
  ],
3191
+ outputs= [msg , header, model_choice]
3192
  )
3193
 
3194
  def generate_about_tab():
 
3212
  gr.Markdown("Please note that if your turn on compilation, the first denoising step of the first video generation will be slow due to the compilation. Therefore all your tests should be done with compilation turned off.")
3213
 
3214
 
3215
def generate_dropdown_model_list():
    """Create the model-selection dropdown for the video generator.

    The selectable model types are the module-level ``transformer_types``
    or, when that list is empty, every entry of ``model_types``. The type of
    the currently loaded model (derived from ``transformer_filename``) is
    always offered and is the preselected value.

    Returns:
        A ``gr.Dropdown`` whose choices are ``(display name, model type)``
        pairs built from the filename resolved for each type at the current
        ``transformer_quantization``.
    """
    current_model_type = get_model_type(transformer_filename)
    # Work on a copy: appending to the module-level list itself would
    # silently change transformer_types / model_types for every other user
    # of those globals.
    dropdown_types = list(transformer_types if len(transformer_types) > 0 else model_types)
    if current_model_type not in dropdown_types:
        dropdown_types.append(current_model_type)

    model_list = [
        get_model_filename(model_type, transformer_quantization)
        for model_type in dropdown_types
    ]
    dropdown_choices = [
        (get_model_name(choice), get_model_type(choice)) for choice in model_list
    ]
    return gr.Dropdown(
        choices=dropdown_choices,
        value=current_model_type,
        show_label=False,
        scale=2,
    )
3231
 
3232
 
3233
 
 
3464
  with gr.Tabs(selected="video_gen", ) as main_tabs:
3465
  with gr.Tab("Video Generator", id="video_gen") as t2v_tab:
3466
  with gr.Row():
3467
+ if args.lock_model:
3468
+ gr.Markdown("<div class='title-with-lines'><div class=line></div><h2>" + get_model_name(transformer_filename) + "</h2><div class=line></div>")
3469
+ model_choice = gr.Dropdown(visible=False, value= get_model_type(transformer_filename))
3470
+ else:
3471
+ gr.Markdown("<div class='title-with-lines'><div class=line width=100%></div></div>")
3472
+ model_choice = generate_dropdown_model_list()
3473
+ gr.Markdown("<div class='title-with-lines'><div class=line width=100%></div></div>")
3474
  with gr.Row():
3475
+ header = gr.Markdown(generate_header(transformer_filename, compile, attention_mode), visible= True)
 
 
 
 
 
 
 
 
 
 
 
 
 
3476
  with gr.Row():
3477
 
3478
  loras_choices, lset_name, state = generate_video_tab(model_choice = model_choice, header = header)
 
3482
  with gr.Tab("Downloads", id="downloads") as downloads_tab:
3483
  generate_download_tab(lset_name, loras_choices, state)
3484
  with gr.Tab("Configuration"):
3485
+ generate_configuration_tab(header, model_choice)
3486
  with gr.Tab("About"):
3487
  generate_about_tab()
3488