Commit ·
d7b86b3
1
Parent(s): 09b1e3f
fix broken queue states, avoid unnecessary reloading
Browse files- gradio_server.py +46 -40
gradio_server.py
CHANGED
|
@@ -37,6 +37,7 @@ task_id = 0
|
|
| 37 |
progress_tracker = {}
|
| 38 |
tracker_lock = threading.Lock()
|
| 39 |
file_list = []
|
|
|
|
| 40 |
|
| 41 |
def runner():
|
| 42 |
global current_task_id
|
|
@@ -47,22 +48,23 @@ def runner():
|
|
| 47 |
with tracker_lock:
|
| 48 |
progress = progress_tracker.get(task_id, {})
|
| 49 |
|
| 50 |
-
if item['state']
|
| 51 |
current_step = progress.get('current_step', 0)
|
| 52 |
total_steps = progress.get('total_steps', 0)
|
| 53 |
elapsed = time.time() - progress.get('start_time', time.time())
|
| 54 |
status = progress.get('status', "")
|
| 55 |
-
|
| 56 |
item.update({
|
| 57 |
'progress': f"{((current_step/total_steps)*100 if total_steps > 0 else 0):.1f}%",
|
| 58 |
'steps': f"{current_step}/{total_steps}",
|
| 59 |
'time': f"{elapsed:.1f}s",
|
| 60 |
-
'
|
| 61 |
'status': f"{status}"
|
| 62 |
})
|
| 63 |
if not any(item['state'] == "Processing" for item in queue):
|
| 64 |
for item in queue:
|
| 65 |
if item['state'] == "Queued":
|
|
|
|
| 66 |
item['state'] = "Processing"
|
| 67 |
current_task_id = item['id']
|
| 68 |
threading.Thread(target=process_task, args=(item,)).start()
|
|
@@ -160,7 +162,8 @@ def add_video_task(*params):
|
|
| 160 |
"id": current_task_id,
|
| 161 |
"params": (current_task_id,) + params,
|
| 162 |
"state": "Queued",
|
| 163 |
-
"status": "
|
|
|
|
| 164 |
"progress": "0.0%",
|
| 165 |
"steps": f"0/{params[5]}",
|
| 166 |
"time": "--",
|
|
@@ -212,8 +215,8 @@ def update_queue_data():
|
|
| 212 |
for item in queue:
|
| 213 |
data.append([
|
| 214 |
str(item['id']),
|
| 215 |
-
item['state'],
|
| 216 |
item['status'],
|
|
|
|
| 217 |
item.get('progress', "0.0%"),
|
| 218 |
item.get('steps', ''),
|
| 219 |
item.get('time', '--'),
|
|
@@ -1013,8 +1016,8 @@ def build_callback(state, pipe, num_inference_steps, status):
|
|
| 1013 |
'total_steps': num_inference_steps,
|
| 1014 |
'start_time': start_time,
|
| 1015 |
'last_update': time.time(),
|
| 1016 |
-
'
|
| 1017 |
-
'
|
| 1018 |
}
|
| 1019 |
return update_progress
|
| 1020 |
|
|
@@ -1078,20 +1081,21 @@ def generate_video(
|
|
| 1078 |
progress=gr.Progress() #track_tqdm= True
|
| 1079 |
|
| 1080 |
):
|
| 1081 |
-
global wan_model, offloadobj
|
| 1082 |
reload_needed = state.get("_reload_needed", False)
|
| 1083 |
file_model_needed = model_needed(image2video)
|
| 1084 |
-
|
| 1085 |
-
|
| 1086 |
-
|
| 1087 |
-
|
| 1088 |
-
|
| 1089 |
-
|
| 1090 |
-
|
| 1091 |
-
|
| 1092 |
-
|
| 1093 |
-
|
| 1094 |
-
|
|
|
|
| 1095 |
|
| 1096 |
from PIL import Image
|
| 1097 |
import numpy as np
|
|
@@ -1251,13 +1255,6 @@ def generate_video(
|
|
| 1251 |
global save_path
|
| 1252 |
os.makedirs(save_path, exist_ok=True)
|
| 1253 |
abort = False
|
| 1254 |
-
with tracker_lock:
|
| 1255 |
-
progress_tracker[task_id] = {
|
| 1256 |
-
'current_step': 0,
|
| 1257 |
-
'total_steps': num_inference_steps,
|
| 1258 |
-
'start_time': time.time(),
|
| 1259 |
-
'last_update': time.time()
|
| 1260 |
-
}
|
| 1261 |
if trans.enable_teacache:
|
| 1262 |
trans.teacache_counter = 0
|
| 1263 |
trans.num_steps = num_inference_steps
|
|
@@ -1268,8 +1265,12 @@ def generate_video(
|
|
| 1268 |
status = f"{video_no}/{repeat_generation}"
|
| 1269 |
with tracker_lock:
|
| 1270 |
if task_id in progress_tracker:
|
| 1271 |
-
progress_tracker[task_id]['
|
| 1272 |
-
progress_tracker[task_id]['
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1273 |
callback = build_callback(state, trans, num_inference_steps, status)
|
| 1274 |
offload.shared_state["callback"] = callback
|
| 1275 |
gc.collect()
|
|
@@ -1279,7 +1280,7 @@ def generate_video(
|
|
| 1279 |
try:
|
| 1280 |
with tracker_lock:
|
| 1281 |
if task_id in progress_tracker:
|
| 1282 |
-
progress_tracker[task_id]['
|
| 1283 |
video_no += 1
|
| 1284 |
if image2video:
|
| 1285 |
samples = wan_model.generate(
|
|
@@ -1326,8 +1327,8 @@ def generate_video(
|
|
| 1326 |
gen_in_progress = False
|
| 1327 |
if temp_filename!= None and os.path.isfile(temp_filename):
|
| 1328 |
os.remove(temp_filename)
|
| 1329 |
-
offload.last_offload_obj.unload_all()
|
| 1330 |
-
offload.unload_loras_from_model(trans)
|
| 1331 |
# if compile:
|
| 1332 |
# cache_size = torch._dynamo.config.cache_size_limit
|
| 1333 |
# torch.compiler.reset()
|
|
@@ -1411,6 +1412,7 @@ def generate_video(
|
|
| 1411 |
print(f"New video saved to Path: "+video_path)
|
| 1412 |
file_list.append(video_path)
|
| 1413 |
seed += 1
|
|
|
|
| 1414 |
|
| 1415 |
if temp_filename!= None and os.path.isfile(temp_filename):
|
| 1416 |
os.remove(temp_filename)
|
|
@@ -2291,15 +2293,19 @@ def on_tab_select(t2v_state, i2v_state, evt: gr.SelectData):
|
|
| 2291 |
use_image2video = new_i2v
|
| 2292 |
|
| 2293 |
if(server_config.get("reload_model",2) == 1):
|
| 2294 |
-
|
| 2295 |
-
|
| 2296 |
-
|
| 2297 |
-
|
| 2298 |
-
|
| 2299 |
-
|
| 2300 |
-
|
| 2301 |
-
|
| 2302 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2303 |
|
| 2304 |
t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode)
|
| 2305 |
i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode)
|
|
|
|
| 37 |
progress_tracker = {}
|
| 38 |
tracker_lock = threading.Lock()
|
| 39 |
file_list = []
|
| 40 |
+
last_model_type = None
|
| 41 |
|
| 42 |
def runner():
|
| 43 |
global current_task_id
|
|
|
|
| 48 |
with tracker_lock:
|
| 49 |
progress = progress_tracker.get(task_id, {})
|
| 50 |
|
| 51 |
+
if item['state'] == "Processing":
|
| 52 |
current_step = progress.get('current_step', 0)
|
| 53 |
total_steps = progress.get('total_steps', 0)
|
| 54 |
elapsed = time.time() - progress.get('start_time', time.time())
|
| 55 |
status = progress.get('status', "")
|
| 56 |
+
repeats = progress.get("repeats")
|
| 57 |
item.update({
|
| 58 |
'progress': f"{((current_step/total_steps)*100 if total_steps > 0 else 0):.1f}%",
|
| 59 |
'steps': f"{current_step}/{total_steps}",
|
| 60 |
'time': f"{elapsed:.1f}s",
|
| 61 |
+
'repeats': f"{repeats}",
|
| 62 |
'status': f"{status}"
|
| 63 |
})
|
| 64 |
if not any(item['state'] == "Processing" for item in queue):
|
| 65 |
for item in queue:
|
| 66 |
if item['state'] == "Queued":
|
| 67 |
+
item['status'] = "Processing"
|
| 68 |
item['state'] = "Processing"
|
| 69 |
current_task_id = item['id']
|
| 70 |
threading.Thread(target=process_task, args=(item,)).start()
|
|
|
|
| 162 |
"id": current_task_id,
|
| 163 |
"params": (current_task_id,) + params,
|
| 164 |
"state": "Queued",
|
| 165 |
+
"status": "Queued",
|
| 166 |
+
"repeats": "0/0",
|
| 167 |
"progress": "0.0%",
|
| 168 |
"steps": f"0/{params[5]}",
|
| 169 |
"time": "--",
|
|
|
|
| 215 |
for item in queue:
|
| 216 |
data.append([
|
| 217 |
str(item['id']),
|
|
|
|
| 218 |
item['status'],
|
| 219 |
+
item['repeats'],
|
| 220 |
item.get('progress', "0.0%"),
|
| 221 |
item.get('steps', ''),
|
| 222 |
item.get('time', '--'),
|
|
|
|
| 1016 |
'total_steps': num_inference_steps,
|
| 1017 |
'start_time': start_time,
|
| 1018 |
'last_update': time.time(),
|
| 1019 |
+
'repeats': status,
|
| 1020 |
+
'status': phase
|
| 1021 |
}
|
| 1022 |
return update_progress
|
| 1023 |
|
|
|
|
| 1081 |
progress=gr.Progress() #track_tqdm= True
|
| 1082 |
|
| 1083 |
):
|
| 1084 |
+
global wan_model, offloadobj, last_model_type
|
| 1085 |
reload_needed = state.get("_reload_needed", False)
|
| 1086 |
file_model_needed = model_needed(image2video)
|
| 1087 |
+
with lock:
|
| 1088 |
+
queue_not_empty = len(queue) > 0
|
| 1089 |
+
if(last_model_type != image2video and (queue_not_empty or server_config.get("reload_model",1) == 2) and (file_model_needed != model_filename or reload_needed)):
|
| 1090 |
+
del wan_model
|
| 1091 |
+
if offloadobj is not None:
|
| 1092 |
+
offloadobj.release()
|
| 1093 |
+
del offloadobj
|
| 1094 |
+
gc.collect()
|
| 1095 |
+
print(f"Loading model {get_model_name(file_model_needed)}...")
|
| 1096 |
+
wan_model, offloadobj, trans = load_models(image2video)
|
| 1097 |
+
print(f"Model loaded")
|
| 1098 |
+
state["_reload_needed"] = False
|
| 1099 |
|
| 1100 |
from PIL import Image
|
| 1101 |
import numpy as np
|
|
|
|
| 1255 |
global save_path
|
| 1256 |
os.makedirs(save_path, exist_ok=True)
|
| 1257 |
abort = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1258 |
if trans.enable_teacache:
|
| 1259 |
trans.teacache_counter = 0
|
| 1260 |
trans.num_steps = num_inference_steps
|
|
|
|
| 1265 |
status = f"{video_no}/{repeat_generation}"
|
| 1266 |
with tracker_lock:
|
| 1267 |
if task_id in progress_tracker:
|
| 1268 |
+
progress_tracker[task_id]['status'] = "Encoding Prompt"
|
| 1269 |
+
progress_tracker[task_id]['repeats'] = status
|
| 1270 |
+
progress_tracker[task_id]['current_step'] = 0
|
| 1271 |
+
progress_tracker[task_id]['total_steps'] = num_inference_steps
|
| 1272 |
+
progress_tracker[task_id]['start_time'] = time.time()
|
| 1273 |
+
progress_tracker[task_id]['last_update'] = time.time()
|
| 1274 |
callback = build_callback(state, trans, num_inference_steps, status)
|
| 1275 |
offload.shared_state["callback"] = callback
|
| 1276 |
gc.collect()
|
|
|
|
| 1280 |
try:
|
| 1281 |
with tracker_lock:
|
| 1282 |
if task_id in progress_tracker:
|
| 1283 |
+
progress_tracker[task_id]['repeats'] = video_no
|
| 1284 |
video_no += 1
|
| 1285 |
if image2video:
|
| 1286 |
samples = wan_model.generate(
|
|
|
|
| 1327 |
gen_in_progress = False
|
| 1328 |
if temp_filename!= None and os.path.isfile(temp_filename):
|
| 1329 |
os.remove(temp_filename)
|
| 1330 |
+
if(offload.last_offload_obj): offload.last_offload_obj.unload_all()
|
| 1331 |
+
if(trans): offload.unload_loras_from_model(trans)
|
| 1332 |
# if compile:
|
| 1333 |
# cache_size = torch._dynamo.config.cache_size_limit
|
| 1334 |
# torch.compiler.reset()
|
|
|
|
| 1412 |
print(f"New video saved to Path: "+video_path)
|
| 1413 |
file_list.append(video_path)
|
| 1414 |
seed += 1
|
| 1415 |
+
last_model_type = image2video
|
| 1416 |
|
| 1417 |
if temp_filename!= None and os.path.isfile(temp_filename):
|
| 1418 |
os.remove(temp_filename)
|
|
|
|
| 2293 |
use_image2video = new_i2v
|
| 2294 |
|
| 2295 |
if(server_config.get("reload_model",2) == 1):
|
| 2296 |
+
with lock:
|
| 2297 |
+
queue_empty = len(queue) == 0
|
| 2298 |
+
if queue_empty:
|
| 2299 |
+
global wan_model, offloadobj
|
| 2300 |
+
if wan_model is not None:
|
| 2301 |
+
if offloadobj is not None:
|
| 2302 |
+
offloadobj.release()
|
| 2303 |
+
offloadobj = None
|
| 2304 |
+
wan_model = None
|
| 2305 |
+
gc.collect()
|
| 2306 |
+
torch.cuda.empty_cache()
|
| 2307 |
+
wan_model, offloadobj, trans = load_models(use_image2video)
|
| 2308 |
+
del trans
|
| 2309 |
|
| 2310 |
t2v_header = generate_header(transformer_filename_t2v, compile, attention_mode)
|
| 2311 |
i2v_header = generate_header(transformer_filename_i2v, compile, attention_mode)
|