Update app.py
Browse files
app.py
CHANGED
|
@@ -266,67 +266,102 @@ import torch.nn.functional as F
|
|
| 266 |
|
| 267 |
|
| 268 |
|
| 269 |
-
v21_path = hf_hub_download(
|
| 270 |
-
repo_id="Phr00t/Qwen-Image-Edit-Rapid-AIO",
|
| 271 |
-
filename="v21/Qwen-Rapid-AIO-NSFW-v21.safetensors",
|
| 272 |
-
repo_type="model"
|
| 273 |
-
)
|
| 274 |
-
|
| 275 |
-
# 2. load the base architecture
|
| 276 |
-
# we use the default flowmatch scheduler first to ensure the pipe inits correctly,
|
| 277 |
-
# then we swap it to euler_a later
|
| 278 |
print("loading base pipeline architecture...")
|
| 279 |
pipe = QwenImageEditPlusPipeline.from_pretrained(
|
| 280 |
"Qwen/Qwen-Image-Edit-2511",
|
| 281 |
torch_dtype=torch.bfloat16
|
| 282 |
).to("cuda")
|
| 283 |
|
| 284 |
-
#
|
| 285 |
-
# we configure it with the base config to keep timestep spacing correct
|
| 286 |
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
|
| 287 |
|
| 288 |
-
#
|
| 289 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
state_dict = load_file(v21_path)
|
| 291 |
|
| 292 |
-
#
|
| 293 |
-
#
|
| 294 |
-
|
| 295 |
-
print("injecting AIO weights...")
|
| 296 |
|
| 297 |
-
#
|
| 298 |
-
|
| 299 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
for k, v in state_dict.items():
|
|
|
|
|
|
|
| 301 |
if k.startswith("model.diffusion_model."):
|
| 302 |
-
new_key = k.replace("model.diffusion_model.", "
|
| 303 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
elif k.startswith("first_stage_model."):
|
| 305 |
-
new_key = k.replace("first_stage_model.", "
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
#
|
| 314 |
-
|
| 315 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 316 |
else:
|
| 317 |
-
print("
|
| 318 |
-
|
|
|
|
|
|
|
|
|
|
| 319 |
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
print(f"unexpected keys (ignore if comfy artifacts): {len(mismatched.unexpected_keys)}")
|
| 325 |
|
| 326 |
-
#
|
|
|
|
| 327 |
del state_dict
|
| 328 |
-
del
|
| 329 |
-
del
|
|
|
|
| 330 |
gc.collect()
|
| 331 |
torch.cuda.empty_cache()
|
| 332 |
|
|
|
|
| 266 |
|
| 267 |
|
| 268 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
# Build the base Qwen image-edit pipeline in bfloat16, then move it to the GPU.
print("loading base pipeline architecture...")
pipe = QwenImageEditPlusPipeline.from_pretrained(
    "Qwen/Qwen-Image-Edit-2511",
    torch_dtype=torch.bfloat16,
)
pipe = pipe.to("cuda")

# Swap in the Euler Ancestral scheduler, constructed from the config of the
# scheduler the pipeline was loaded with.
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
|
| 277 |
|
| 278 |
+
# 2. DOWNLOAD & LOAD RAW WEIGHTS
# ------------------------------------------------------------------------------
# Fetch the single-file v21 checkpoint from the Hub (hf_hub_download returns the
# local path of the downloaded file) and read every tensor into `state_dict`.
print("accessing v21 checkpoint...")
v21_path = hf_hub_download(
    "Phr00t/Qwen-Image-Edit-Rapid-AIO",
    "v21/Qwen-Rapid-AIO-NSFW-v21.safetensors",
    repo_type="model",
)

print("loading 28GB state dict into cpu memory...")
state_dict = load_file(v21_path)
|
| 289 |
|
| 290 |
+
# 3. DYNAMIC COMPONENT MAPPING (NO ASSUMPTIONS)
# ------------------------------------------------------------------------------
print("sorting weights into components...")


def _split_checkpoint_by_component(checkpoint):
    """Split a flat checkpoint mapping into per-component state dicts.

    Keys are routed by naming convention:
      * transformer:  leading 'model.diffusion_model.' or 'transformer.'
      * vae:          leading 'first_stage_model.' or 'vae.'
      * text encoder: contains 'conditioner.embedders.0.' or 'text_encoder.'

    The matched prefix/marker is removed from the key before it is stored.

    Args:
        checkpoint: mapping of checkpoint key -> tensor.

    Returns:
        (transformer, vae, text_encoder, unmapped) where the first three are
        dicts of renamed key -> tensor and ``unmapped`` is the list of
        original keys that matched no rule (these tensors are not loaded).
    """
    transformer = {}
    vae = {}
    text_encoder = {}
    unmapped = []

    # (leading prefix, destination dict), checked in order.
    prefix_routes = (
        ("model.diffusion_model.", transformer),
        ("transformer.", transformer),
        ("first_stage_model.", vae),
        ("vae.", vae),
    )
    # Text-encoder markers are matched anywhere in the key, in this order.
    text_encoder_markers = ("conditioner.embedders.0.", "text_encoder.")

    for key, tensor in checkpoint.items():
        for prefix, destination in prefix_routes:
            if key.startswith(prefix):
                # Slice off only the leading prefix; str.replace would also
                # delete any later occurrence of the same text in the key.
                destination[key[len(prefix):]] = tensor
                break
        else:
            for marker in text_encoder_markers:
                if marker in key:
                    # Remove just the first occurrence of the marker.
                    text_encoder[key.replace(marker, "", 1)] = tensor
                    break
            else:
                # Record the key instead of dropping it silently.
                unmapped.append(key)

    return transformer, vae, text_encoder, unmapped


# Show the first key as a hint of the checkpoint's naming format
# (None when the checkpoint is empty, instead of raising StopIteration).
first_key = next(iter(state_dict), None)
print(f"format detection - first key detected: {first_key}")

# Sorting happens inside the helper so no temporary references to the weight
# dicts outlive this step and block the later memory cleanup.
(
    transformer_weights,
    vae_weights,
    text_encoder_weights,
    unmapped_keys,
) = _split_checkpoint_by_component(state_dict)

if unmapped_keys:
    print(
        f"WARNING: {len(unmapped_keys)} checkpoint keys matched no component "
        f"mapping and will not be loaded (e.g. {unmapped_keys[0]})"
    )
|
| 335 |
+
|
| 336 |
+
# 4. INJECT WEIGHTS (COMPONENT LEVEL)
# ------------------------------------------------------------------------------
def _inject_component(label, module, weights):
    """Load ``weights`` into ``module`` non-strictly and report key mismatches.

    Args:
        label: component name used in the log lines.
        module: object exposing ``load_state_dict(state_dict, strict=...)``.
        weights: mapping of parameter name -> tensor; may be empty.

    Returns:
        The value returned by ``load_state_dict``, or None when ``weights``
        is empty and nothing was loaded.
    """
    if not weights:
        return None
    print(f"injecting {label} weights...")
    # strict=False tolerates key mismatches, so report both directions here;
    # otherwise a wrong key mapping would go unnoticed.
    result = module.load_state_dict(weights, strict=False)
    print(f"{label} missing keys: {len(result.missing_keys)}")
    print(f"{label} unexpected keys: {len(result.unexpected_keys)}")
    return result


print("injection statistics:")
print(f" - transformer keys found: {len(transformer_weights)}")
print(f" - vae keys found: {len(vae_weights)}")
print(f" - text encoder keys found: {len(text_encoder_weights)}")

msg = _inject_component("transformer", pipe.transformer, transformer_weights)
if msg is None:
    # Best-effort: keep running on the base transformer weights, but say so loudly.
    print("CRITICAL WARNING: no transformer weights found in file. check mapping logic.")

_inject_component("vae", pipe.vae, vae_weights)
_inject_component("text encoder", pipe.text_encoder, text_encoder_weights)
|
|
|
|
| 358 |
|
| 359 |
+
# 5. CLEANUP & RUN
# ------------------------------------------------------------------------------
# Unbind the raw checkpoint and the per-component dicts, then run a garbage
# collection pass and ask torch to release its cached CUDA memory.
del state_dict, transformer_weights, vae_weights, text_encoder_weights
gc.collect()
torch.cuda.empty_cache()
|
| 367 |
|