Tophness2022 committed on
Commit
d7b86b3
·
1 Parent(s): 09b1e3f

fix broken queue states, avoid unnecessary reloading

Browse files
Files changed (1) hide show
  1. gradio_server.py +46 -40
gradio_server.py CHANGED
@@ -37,6 +37,7 @@ task_id = 0
37
  progress_tracker = {}
38
  tracker_lock = threading.Lock()
39
  file_list = []
 
40
 
41
  def runner():
42
  global current_task_id
@@ -47,22 +48,23 @@ def runner():
47
  with tracker_lock:
48
  progress = progress_tracker.get(task_id, {})
49
 
50
- if item['state'] != "Queued" and item['state'] != "Finished":
51
  current_step = progress.get('current_step', 0)
52
  total_steps = progress.get('total_steps', 0)
53
  elapsed = time.time() - progress.get('start_time', time.time())
54
  status = progress.get('status', "")
55
- state = progress.get("state")
56
  item.update({
57
  'progress': f"{((current_step/total_steps)*100 if total_steps > 0 else 0):.1f}%",
58
  'steps': f"{current_step}/{total_steps}",
59
  'time': f"{elapsed:.1f}s",
60
- 'state': f"{state}",
61
  'status': f"{status}"
62
  })
63
  if not any(item['state'] == "Processing" for item in queue):
64
  for item in queue:
65
  if item['state'] == "Queued":
 
66
  item['state'] = "Processing"
67
  current_task_id = item['id']
68
  threading.Thread(target=process_task, args=(item,)).start()
@@ -160,7 +162,8 @@ def add_video_task(*params):
160
  "id": current_task_id,
161
  "params": (current_task_id,) + params,
162
  "state": "Queued",
163
- "status": "0/0",
 
164
  "progress": "0.0%",
165
  "steps": f"0/{params[5]}",
166
  "time": "--",
@@ -212,8 +215,8 @@ def update_queue_data():
212
  for item in queue:
213
  data.append([
214
  str(item['id']),
215
- item['state'],
216
  item['status'],
 
217
  item.get('progress', "0.0%"),
218
  item.get('steps', ''),
219
  item.get('time', '--'),
@@ -1013,8 +1016,8 @@ def build_callback(state, pipe, num_inference_steps, status):
1013
  'total_steps': num_inference_steps,
1014
  'start_time': start_time,
1015
  'last_update': time.time(),
1016
- 'status': status,
1017
- 'state': phase
1018
  }
1019
  return update_progress
1020
 
@@ -1078,20 +1081,21 @@ def generate_video(
1078
  progress=gr.Progress() #track_tqdm= True
1079
 
1080
  ):
1081
- global wan_model, offloadobj
1082
  reload_needed = state.get("_reload_needed", False)
1083
  file_model_needed = model_needed(image2video)
1084
- if(server_config.get("reload_model",1) == 2):
1085
- if file_model_needed != model_filename or reload_needed:
1086
- del wan_model
1087
- if offloadobj is not None:
1088
- offloadobj.release()
1089
- del offloadobj
1090
- gc.collect()
1091
- print(f"Loading model {get_model_name(file_model_needed)}...")
1092
- wan_model, offloadobj, trans = load_models(image2video)
1093
- print(f"Model loaded")
1094
- state["_reload_needed"] = False
 
1095
 
1096
  from PIL import Image
1097
  import numpy as np
@@ -1251,13 +1255,6 @@ def generate_video(
1251
  global save_path
1252
  os.makedirs(save_path, exist_ok=True)
1253
  abort = False
1254
- with tracker_lock:
1255
- progress_tracker[task_id] = {
1256
- 'current_step': 0,
1257
- 'total_steps': num_inference_steps,
1258
- 'start_time': time.time(),
1259
- 'last_update': time.time()
1260
- }
1261
  if trans.enable_teacache:
1262
  trans.teacache_counter = 0
1263
  trans.num_steps = num_inference_steps
@@ -1268,8 +1265,12 @@ def generate_video(
1268
  status = f"{video_no}/{repeat_generation}"
1269
  with tracker_lock:
1270
  if task_id in progress_tracker:
1271
- progress_tracker[task_id]['state'] = "Encoding Prompt"
1272
- progress_tracker[task_id]['status'] = status
 
 
 
 
1273
  callback = build_callback(state, trans, num_inference_steps, status)
1274
  offload.shared_state["callback"] = callback
1275
  gc.collect()
@@ -1279,7 +1280,7 @@ def generate_video(
1279
  try:
1280
  with tracker_lock:
1281
  if task_id in progress_tracker:
1282
- progress_tracker[task_id]['status'] = video_no
1283
  video_no += 1
1284
  if image2video:
1285
  samples = wan_model.generate(
@@ -1326,8 +1327,8 @@ def generate_video(
1326
  gen_in_progress = False
1327
  if temp_filename!= None and os.path.isfile(temp_filename):
1328
  os.remove(temp_filename)
1329
- offload.last_offload_obj.unload_all()
1330
- offload.unload_loras_from_model(trans)
1331
  # if compile:
1332
  # cache_size = torch._dynamo.config.cache_size_limit
1333
  # torch.compiler.reset()
@@ -1411,6 +1412,7 @@ def generate_video(
1411
  print(f"New video saved to Path: "+video_path)
1412
  file_list.append(video_path)
1413
  seed += 1
 
1414
 
1415
  if temp_filename!= None and os.path.isfile(temp_filename):
1416
  os.remove(temp_filename)
@@ -2291,15 +2293,19 @@ def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
2291
  use_image2video = new_i2v
2292
 
2293
  if(server_config.get("reload_model",2) == 1):
2294
- global wan_model, offloadobj
2295
- if wan_model is not None:
2296
- if offloadobj is not None:
2297
- offloadobj.release()
2298
- offloadobj = None
2299
- wan_model = None
2300
- gc.collect()
2301
- torch.cuda.empty_cache()
2302
- wan_model, offloadobj, trans = load_models(use_image2video)
 
 
 
 
2303
 
2304
  t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode)
2305
  i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode)
 
37
  progress_tracker = {}
38
  tracker_lock = threading.Lock()
39
  file_list = []
40
+ last_model_type = None
41
 
42
  def runner():
43
  global current_task_id
 
48
  with tracker_lock:
49
  progress = progress_tracker.get(task_id, {})
50
 
51
+ if item['state'] == "Processing":
52
  current_step = progress.get('current_step', 0)
53
  total_steps = progress.get('total_steps', 0)
54
  elapsed = time.time() - progress.get('start_time', time.time())
55
  status = progress.get('status', "")
56
+ repeats = progress.get("repeats")
57
  item.update({
58
  'progress': f"{((current_step/total_steps)*100 if total_steps > 0 else 0):.1f}%",
59
  'steps': f"{current_step}/{total_steps}",
60
  'time': f"{elapsed:.1f}s",
61
+ 'repeats': f"{repeats}",
62
  'status': f"{status}"
63
  })
64
  if not any(item['state'] == "Processing" for item in queue):
65
  for item in queue:
66
  if item['state'] == "Queued":
67
+ item['status'] = "Processing"
68
  item['state'] = "Processing"
69
  current_task_id = item['id']
70
  threading.Thread(target=process_task, args=(item,)).start()
 
162
  "id": current_task_id,
163
  "params": (current_task_id,) + params,
164
  "state": "Queued",
165
+ "status": "Queued",
166
+ "repeats": "0/0",
167
  "progress": "0.0%",
168
  "steps": f"0/{params[5]}",
169
  "time": "--",
 
215
  for item in queue:
216
  data.append([
217
  str(item['id']),
 
218
  item['status'],
219
+ item['repeats'],
220
  item.get('progress', "0.0%"),
221
  item.get('steps', ''),
222
  item.get('time', '--'),
 
1016
  'total_steps': num_inference_steps,
1017
  'start_time': start_time,
1018
  'last_update': time.time(),
1019
+ 'repeats': status,
1020
+ 'status': phase
1021
  }
1022
  return update_progress
1023
 
 
1081
  progress=gr.Progress() #track_tqdm= True
1082
 
1083
  ):
1084
+ global wan_model, offloadobj, last_model_type
1085
  reload_needed = state.get("_reload_needed", False)
1086
  file_model_needed = model_needed(image2video)
1087
+ with lock:
1088
+ queue_not_empty = len(queue) > 0
1089
+ if(last_model_type != image2video and (queue_not_empty or server_config.get("reload_model",1) == 2) and (file_model_needed != model_filename or reload_needed)):
1090
+ del wan_model
1091
+ if offloadobj is not None:
1092
+ offloadobj.release()
1093
+ del offloadobj
1094
+ gc.collect()
1095
+ print(f"Loading model {get_model_name(file_model_needed)}...")
1096
+ wan_model, offloadobj, trans = load_models(image2video)
1097
+ print(f"Model loaded")
1098
+ state["_reload_needed"] = False
1099
 
1100
  from PIL import Image
1101
  import numpy as np
 
1255
  global save_path
1256
  os.makedirs(save_path, exist_ok=True)
1257
  abort = False
 
 
 
 
 
 
 
1258
  if trans.enable_teacache:
1259
  trans.teacache_counter = 0
1260
  trans.num_steps = num_inference_steps
 
1265
  status = f"{video_no}/{repeat_generation}"
1266
  with tracker_lock:
1267
  if task_id in progress_tracker:
1268
+ progress_tracker[task_id]['status'] = "Encoding Prompt"
1269
+ progress_tracker[task_id]['repeats'] = status
1270
+ progress_tracker[task_id]['current_step'] = 0
1271
+ progress_tracker[task_id]['total_steps'] = num_inference_steps
1272
+ progress_tracker[task_id]['start_time'] = time.time()
1273
+ progress_tracker[task_id]['last_update'] = time.time()
1274
  callback = build_callback(state, trans, num_inference_steps, status)
1275
  offload.shared_state["callback"] = callback
1276
  gc.collect()
 
1280
  try:
1281
  with tracker_lock:
1282
  if task_id in progress_tracker:
1283
+ progress_tracker[task_id]['repeats'] = video_no
1284
  video_no += 1
1285
  if image2video:
1286
  samples = wan_model.generate(
 
1327
  gen_in_progress = False
1328
  if temp_filename!= None and os.path.isfile(temp_filename):
1329
  os.remove(temp_filename)
1330
+ if(offload.last_offload_obj): offload.last_offload_obj.unload_all()
1331
+ if(trans): offload.unload_loras_from_model(trans)
1332
  # if compile:
1333
  # cache_size = torch._dynamo.config.cache_size_limit
1334
  # torch.compiler.reset()
 
1412
  print(f"New video saved to Path: "+video_path)
1413
  file_list.append(video_path)
1414
  seed += 1
1415
+ last_model_type = image2video
1416
 
1417
  if temp_filename!= None and os.path.isfile(temp_filename):
1418
  os.remove(temp_filename)
 
2293
  use_image2video = new_i2v
2294
 
2295
  if(server_config.get("reload_model",2) == 1):
2296
+ with lock:
2297
+ queue_empty = len(queue) == 0
2298
+ if queue_empty:
2299
+ global wan_model, offloadobj
2300
+ if wan_model is not None:
2301
+ if offloadobj is not None:
2302
+ offloadobj.release()
2303
+ offloadobj = None
2304
+ wan_model = None
2305
+ gc.collect()
2306
+ torch.cuda.empty_cache()
2307
+ wan_model, offloadobj, trans = load_models(use_image2video)
2308
+ del trans
2309
 
2310
  t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode)
2311
  i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode)