Tophness2022 commited on
Commit
372ce2e
·
1 Parent(s): 5b0e9b1

add image thumbnails and previews to queue items

Browse files
Files changed (1) hide show
  1. gradio_server.py +141 -28
gradio_server.py CHANGED
@@ -22,6 +22,9 @@ import traceback
22
  import math
23
  import asyncio
24
  from wan.utils import prompt_parser
 
 
 
25
  PROMPT_VARS_MAX = 10
26
 
27
  target_mmgp_version = "3.3.4"
@@ -50,6 +53,29 @@ def format_time(seconds):
50
  minutes = int((seconds % 3600) // 60)
51
  return f"{hours}h {minutes}m"
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  def runner():
54
  global current_task_id
55
  while True:
@@ -175,6 +201,9 @@ def add_video_task(*params):
175
  with lock:
176
  task_id += 1
177
  current_task_id = task_id
 
 
 
178
  queue.append({
179
  "id": current_task_id,
180
  "params": (current_task_id,) + params,
@@ -184,9 +213,11 @@ def add_video_task(*params):
184
  "progress": "0.0%",
185
  "steps": f"0/{params[5]}",
186
  "time": "--",
187
- "prompt": params[0]
 
 
188
  })
189
- return
190
 
191
  def move_up(selected_indices):
192
  if not selected_indices or len(selected_indices) == 0:
@@ -233,6 +264,15 @@ def update_queue_data():
233
  truncated_prompt = (item['prompt'][:97] + '...') if len(item['prompt']) > 100 else item['prompt']
234
  full_prompt = item['prompt'].replace('"', '"')
235
  prompt_cell = f'<span title="{full_prompt}">{truncated_prompt}</span>'
 
 
 
 
 
 
 
 
 
236
  data.append([
237
  item.get('status', "Starting"),
238
  item.get('repeats', "0/0"),
@@ -240,6 +280,8 @@ def update_queue_data():
240
  item.get('steps', ''),
241
  item.get('time', '--'),
242
  prompt_cell,
 
 
243
  "↑",
244
  "↓",
245
  "✖"
@@ -1143,7 +1185,6 @@ def generate_video(
1143
  print(f"Model loaded")
1144
  reload_needed= False
1145
 
1146
- from PIL import Image
1147
  import numpy as np
1148
  import tempfile
1149
 
@@ -1905,6 +1946,10 @@ def generate_video_tab(image2video=False):
1905
  download_status = gr.Markdown()
1906
  with gr.Row():
1907
  with gr.Column():
 
 
 
 
1908
  gallery_update_trigger = gr.Textbox(value="0", visible=False, label="_gallery_trigger")
1909
  with gr.Row(visible= len(loras)>0) as presets_column:
1910
  lset_choices = [ (preset, preset) for preset in loras_presets ] + [(get_new_preset_msg(advanced), "")]
@@ -2134,25 +2179,49 @@ def generate_video_tab(image2video=False):
2134
  , columns=[3], rows=[1], object_fit="contain", height=450, selected_index=0, interactive= False)
2135
  generate_btn = gr.Button("Generate")
2136
  queue_df = gr.DataFrame(
2137
- headers=["Status", "Completed", "Progress", "Steps", "Time", "Prompt", "", "", ""],
2138
- datatype=["str", "str", "str", "str", "str", "markdown", "str", "str", "str"],
2139
  interactive=False,
2140
- col_count=(9, "fixed"),
2141
  wrap=True,
2142
  value=update_queue_data,
2143
  every=1,
2144
  elem_id="queue_df"
2145
  )
2146
  def handle_selection(evt: gr.SelectData):
2147
- cell_value = evt.value
2148
- selected_index = evt.index
2149
- if cell_value == "↑":
2150
- return move_up([selected_index])
2151
- elif cell_value == "↓":
2152
- return move_down([selected_index])
2153
- elif cell_value == "":
2154
- return remove_task([selected_index])
2155
- return queue_df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2156
  def refresh_gallery_on_trigger(state):
2157
  if(state.get("update_gallery", False)):
2158
  state['update_gallery'] = False
@@ -2160,7 +2229,8 @@ def generate_video_tab(image2video=False):
2160
  selected_indices = gr.State([])
2161
  queue_df.select(
2162
  fn=handle_selection,
2163
- outputs=selected_indices
 
2164
  )
2165
  gallery_update_trigger.change(
2166
  fn=refresh_gallery_on_trigger,
@@ -2229,6 +2299,11 @@ def generate_video_tab(image2video=False):
2229
  inputs=original_inputs,
2230
  outputs=queue_df
2231
  )
 
 
 
 
 
2232
  return loras_choices, lset_name, header, state
2233
 
2234
  def generate_configuration_tab():
@@ -2524,16 +2599,9 @@ def create_demo():
2524
  #queue_df th {
2525
  pointer-events: none;
2526
  }
2527
- #queue_df .tabulator-col {
2528
- pointer-events: none;
2529
- }
2530
- #queue_df .tabulator-col .tabulator-arrow {
2531
- display: none;
2532
- }
2533
  #queue_df table {
2534
  overflow: hidden !important;
2535
  }
2536
-
2537
  #queue_df::-webkit-scrollbar {
2538
  display: none !important;
2539
  }
@@ -2545,7 +2613,8 @@ def create_demo():
2545
  width: 100px;
2546
  }
2547
  #queue_df td:nth-child(6) {
2548
- width: 300px;
 
2549
  }
2550
  #queue_df td:nth-child(7),
2551
  #queue_df td:nth-child(8),
@@ -2553,12 +2622,56 @@ def create_demo():
2553
  cursor: pointer;
2554
  text-align: center;
2555
  font-weight: bold;
 
 
 
 
2556
  }
2557
- #queue_df td:nth-child(7):hover,
2558
- #queue_df td:nth-child(8):hover,
2559
- #queue_df td:nth-child(9):hover {
2560
- background-color: #e0e0e0;
 
 
 
2561
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2562
  """
2563
  with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md")) as demo:
2564
  gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v3.2 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
 
22
  import math
23
  import asyncio
24
  from wan.utils import prompt_parser
25
+ import base64
26
+ import io
27
+ from PIL import Image
28
  PROMPT_VARS_MAX = 10
29
 
30
  target_mmgp_version = "3.3.4"
 
53
  minutes = int((seconds % 3600) // 60)
54
  return f"{hours}h {minutes}m"
55
 
56
+ def pil_to_base64_uri(pil_image, format="png", quality=75):
57
+ if pil_image is None:
58
+ return None
59
+ buffer = io.BytesIO()
60
+ try:
61
+ img_to_save = pil_image
62
+ if format.lower() == 'jpeg' and pil_image.mode == 'RGBA':
63
+ img_to_save = pil_image.convert('RGB')
64
+ elif format.lower() == 'png' and pil_image.mode not in ['RGB', 'RGBA', 'L', 'P']:
65
+ img_to_save = pil_image.convert('RGBA')
66
+ elif pil_image.mode == 'P':
67
+ img_to_save = pil_image.convert('RGBA' if 'transparency' in pil_image.info else 'RGB')
68
+ if format.lower() == 'jpeg':
69
+ img_to_save.save(buffer, format=format, quality=quality)
70
+ else:
71
+ img_to_save.save(buffer, format=format)
72
+ img_bytes = buffer.getvalue()
73
+ encoded_string = base64.b64encode(img_bytes).decode("utf-8")
74
+ return f"data:image/{format.lower()};base64,{encoded_string}"
75
+ except Exception as e:
76
+ print(f"Error converting PIL to base64: {e}")
77
+ return None
78
+
79
  def runner():
80
  global current_task_id
81
  while True:
 
201
  with lock:
202
  task_id += 1
203
  current_task_id = task_id
204
+ start_image_data = params[16] if len(params) > 16 else None
205
+ end_image_data = params[17] if len(params) > 17 else None
206
+
207
  queue.append({
208
  "id": current_task_id,
209
  "params": (current_task_id,) + params,
 
213
  "progress": "0.0%",
214
  "steps": f"0/{params[5]}",
215
  "time": "--",
216
+ "prompt": params[0],
217
+ "start_image_data": start_image_data,
218
+ "end_image_data": end_image_data
219
  })
220
+ return update_queue_data()
221
 
222
  def move_up(selected_indices):
223
  if not selected_indices or len(selected_indices) == 0:
 
264
  truncated_prompt = (item['prompt'][:97] + '...') if len(item['prompt']) > 100 else item['prompt']
265
  full_prompt = item['prompt'].replace('"', '&quot;')
266
  prompt_cell = f'<span title="{full_prompt}">{truncated_prompt}</span>'
267
+ start_img_uri = pil_to_base64_uri(item.get('start_image_data'), format="jpeg", quality=70)
268
+ end_img_uri = pil_to_base64_uri(item.get('end_image_data'), format="jpeg", quality=70)
269
+ thumbnail_size = "50px"
270
+ start_img_md = ""
271
+ end_img_md = ""
272
+ if start_img_uri:
273
+ start_img_md = f'<img src="{start_img_uri}" alt="Start" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
274
+ if end_img_uri:
275
+ end_img_md = f'<img src="{end_img_uri}" alt="End" style="max-width:{thumbnail_size}; max-height:{thumbnail_size}; display: block; margin: auto; object-fit: contain;" />'
276
  data.append([
277
  item.get('status', "Starting"),
278
  item.get('repeats', "0/0"),
 
280
  item.get('steps', ''),
281
  item.get('time', '--'),
282
  prompt_cell,
283
+ start_img_md,
284
+ end_img_md,
285
  "↑",
286
  "↓",
287
  "✖"
 
1185
  print(f"Model loaded")
1186
  reload_needed= False
1187
 
 
1188
  import numpy as np
1189
  import tempfile
1190
 
 
1946
  download_status = gr.Markdown()
1947
  with gr.Row():
1948
  with gr.Column():
1949
+ with gr.Column(visible=False, elem_id="image-modal-container") as modal_container:
1950
+ with gr.Row(elem_id="image-modal-close-button-row"):
1951
+ close_modal_button = gr.Button("❌", size="sm")
1952
+ modal_image_display = gr.Image(label="Full Resolution Image", interactive=False, show_label=False)
1953
  gallery_update_trigger = gr.Textbox(value="0", visible=False, label="_gallery_trigger")
1954
  with gr.Row(visible= len(loras)>0) as presets_column:
1955
  lset_choices = [ (preset, preset) for preset in loras_presets ] + [(get_new_preset_msg(advanced), "")]
 
2179
  , columns=[3], rows=[1], object_fit="contain", height=450, selected_index=0, interactive= False)
2180
  generate_btn = gr.Button("Generate")
2181
  queue_df = gr.DataFrame(
2182
+ headers=["Status", "Completed", "Progress", "Steps", "Time", "Prompt", "Start", "End", "", "", ""],
2183
+ datatype=["str", "str", "str", "str", "str", "markdown", "markdown", "markdown", "str", "str", "str"],
2184
  interactive=False,
2185
+ col_count=(11, "fixed"),
2186
  wrap=True,
2187
  value=update_queue_data,
2188
  every=1,
2189
  elem_id="queue_df"
2190
  )
2191
  def handle_selection(evt: gr.SelectData):
2192
+ if evt.index is None:
2193
+ return gr.update(), gr.update(), gr.update(visible=False)
2194
+ row_index, col_index = evt.index
2195
+ cell_value = None
2196
+ if col_index in [8, 9, 10]:
2197
+ if col_index == 8: cell_value = "↑"
2198
+ elif col_index == 9: cell_value = ""
2199
+ elif col_index == 10: cell_value = "✖"
2200
+ if col_index == 8:
2201
+ new_df_data = move_up([row_index])
2202
+ return new_df_data, gr.update(), gr.update(visible=False)
2203
+ elif col_index == 9:
2204
+ new_df_data = move_down([row_index])
2205
+ return new_df_data, gr.update(), gr.update(visible=False)
2206
+ elif col_index == 10:
2207
+ new_df_data = remove_task([row_index])
2208
+ return new_df_data, gr.update(), gr.update(visible=False)
2209
+ start_img_col_idx = 6
2210
+ end_img_col_idx = 7
2211
+ image_data_to_show = None
2212
+ if col_index == start_img_col_idx:
2213
+ with lock:
2214
+ if row_index < len(queue):
2215
+ image_data_to_show = queue[row_index].get('start_image_data')
2216
+ elif col_index == end_img_col_idx:
2217
+ with lock:
2218
+ if row_index < len(queue):
2219
+ image_data_to_show = queue[row_index].get('end_image_data')
2220
+
2221
+ if image_data_to_show:
2222
+ return gr.update(), gr.update(value=image_data_to_show), gr.update(visible=True)
2223
+ else:
2224
+ return gr.update(), gr.update(), gr.update(visible=False)
2225
  def refresh_gallery_on_trigger(state):
2226
  if(state.get("update_gallery", False)):
2227
  state['update_gallery'] = False
 
2229
  selected_indices = gr.State([])
2230
  queue_df.select(
2231
  fn=handle_selection,
2232
+ inputs=None,
2233
+ outputs=[queue_df, modal_image_display, modal_container],
2234
  )
2235
  gallery_update_trigger.change(
2236
  fn=refresh_gallery_on_trigger,
 
2299
  inputs=original_inputs,
2300
  outputs=queue_df
2301
  )
2302
+ close_modal_button.click(
2303
+ lambda: gr.update(visible=False),
2304
+ inputs=[],
2305
+ outputs=[modal_container]
2306
+ )
2307
  return loras_choices, lset_name, header, state
2308
 
2309
  def generate_configuration_tab():
 
2599
  #queue_df th {
2600
  pointer-events: none;
2601
  }
 
 
 
 
 
 
2602
  #queue_df table {
2603
  overflow: hidden !important;
2604
  }
 
2605
  #queue_df::-webkit-scrollbar {
2606
  display: none !important;
2607
  }
 
2613
  width: 100px;
2614
  }
2615
  #queue_df td:nth-child(6) {
2616
+ width: auto;
2617
+ min-width: 200px;
2618
  }
2619
  #queue_df td:nth-child(7),
2620
  #queue_df td:nth-child(8),
 
2622
  cursor: pointer;
2623
  text-align: center;
2624
  font-weight: bold;
2625
+ width: 60px;
2626
+ text-align: center;
2627
+ padding: 2px !important;
2628
+ cursor: pointer;
2629
  }
2630
+ #queue_df td:nth-child(10) img,
2631
+ #queue_df td:nth-child(11) img {
2632
+ max-width: 50px;
2633
+ max-height: 50px;
2634
+ object-fit: contain;
2635
+ display: block;
2636
+ margin: auto;
2637
  }
2638
+ #image-modal-container {
2639
+ position: fixed;
2640
+ top: 0;
2641
+ left: 0;
2642
+ width: 100%;
2643
+ height: 100%;
2644
+ background-color: rgba(0, 0, 0, 0.7);
2645
+ justify-content: center;
2646
+ align-items: center;
2647
+ z-index: 1000;
2648
+ padding: 20px;
2649
+ box-sizing: border-box;
2650
+ }
2651
+ #image-modal-container > div {
2652
+ background-color: white;
2653
+ padding: 15px;
2654
+ border-radius: 8px;
2655
+ max-width: 90%;
2656
+ max-height: 90%;
2657
+ overflow: auto;
2658
+ position: relative;
2659
+ display: flex;
2660
+ flex-direction: column;
2661
+ }
2662
+ #image-modal-container img {
2663
+ max-width: 100%;
2664
+ max-height: 80vh;
2665
+ object-fit: contain;
2666
+ margin-top: 10px;
2667
+ }
2668
+ #image-modal-close-button-row {
2669
+ display: flex;
2670
+ justify-content: flex-end;
2671
+ }
2672
+ #image-modal-close-button-row button {
2673
+ cursor: pointer;
2674
+ }
2675
  """
2676
  with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md")) as demo:
2677
  gr.Markdown("<div align=center><H1>Wan 2.1<SUP>GP</SUP> v3.2 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3> (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")