signsur4739379373 commited on
Commit
3bfb445
·
verified ·
1 Parent(s): 6c17d77

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -43
app.py CHANGED
@@ -266,67 +266,102 @@ import torch.nn.functional as F
266
 
267
 
268
 
269
- v21_path = hf_hub_download(
270
- repo_id="Phr00t/Qwen-Image-Edit-Rapid-AIO",
271
- filename="v21/Qwen-Rapid-AIO-NSFW-v21.safetensors",
272
- repo_type="model"
273
- )
274
-
275
- # 2. load the base architecture
276
- # we use the default flowmatch scheduler first to ensure the pipe inits correctly,
277
- # then we swap it to euler_a later
278
  print("loading base pipeline architecture...")
279
  pipe = QwenImageEditPlusPipeline.from_pretrained(
280
  "Qwen/Qwen-Image-Edit-2511",
281
  torch_dtype=torch.bfloat16
282
  ).to("cuda")
283
 
284
- # 3. switch scheduler to Euler Ancestral (Lightning requirement)
285
- # we configure it with the base config to keep timestep spacing correct
286
  pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
287
 
288
- # 4. load the massive 28GB v21 weights
289
- print(f"loading v21 weights from {v21_path}...")
 
 
 
 
 
 
 
 
290
  state_dict = load_file(v21_path)
291
 
292
- # 5. The "Brutal" Injection
293
- # Because this is an AIO file, keys might be prefixed with "model." or "transformer."
294
- # or they might match the pipeline exactly. We try the root load first.
295
- print("injecting AIO weights...")
296
 
297
- # clean up keys if necessary (common in comfyui > diffusers conversions)
298
- # this removes 'model.diffusion_model.' prefixes if they exist to match diffusers 'transformer.'
299
- new_state_dict = {}
 
 
 
 
 
 
 
300
  for k, v in state_dict.items():
 
 
301
  if k.startswith("model.diffusion_model."):
302
- new_key = k.replace("model.diffusion_model.", "transformer.")
303
- new_state_dict[new_key] = v
 
 
 
 
 
 
 
304
  elif k.startswith("first_stage_model."):
305
- new_key = k.replace("first_stage_model.", "vae.")
306
- new_state_dict[new_key] = v
307
- elif k.startswith("conditioner.embedders.0."):
308
- new_key = k.replace("conditioner.embedders.0.", "text_encoder.")
309
- new_state_dict[new_key] = v
310
- else:
311
- new_state_dict[k] = v
312
-
313
- # if no keys were renamed, just use the original
314
- if len(new_state_dict) == len(state_dict):
315
- final_dict = state_dict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
316
  else:
317
- print("detected comfyui keys, remapped for diffusers.")
318
- final_dict = new_state_dict
 
 
 
319
 
320
- # attempt load
321
- mismatched = pipe.load_state_dict(final_dict, strict=False)
322
- print("weights loaded.")
323
- print(f"missing keys (ignore if just config/aux): {len(mismatched.missing_keys)}")
324
- print(f"unexpected keys (ignore if comfy artifacts): {len(mismatched.unexpected_keys)}")
325
 
326
- # 6. cleanup
 
327
  del state_dict
328
- del new_state_dict
329
- del final_dict
 
330
  gc.collect()
331
  torch.cuda.empty_cache()
332
 
 
266
 
267
 
268
 
 
 
 
 
 
 
 
 
 
269
  print("loading base pipeline architecture...")
270
  pipe = QwenImageEditPlusPipeline.from_pretrained(
271
  "Qwen/Qwen-Image-Edit-2511",
272
  torch_dtype=torch.bfloat16
273
  ).to("cuda")
274
 
275
+ # force euler ancestral scheduler
 
276
  pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
277
 
278
+ # 2. DOWNLOAD & LOAD RAW WEIGHTS
279
+ # ------------------------------------------------------------------------------
280
+ print("accessing v21 checkpoint...")
281
+ v21_path = hf_hub_download(
282
+ repo_id="Phr00t/Qwen-Image-Edit-Rapid-AIO",
283
+ filename="v21/Qwen-Rapid-AIO-NSFW-v21.safetensors",
284
+ repo_type="model"
285
+ )
286
+
287
+ print(f"loading 28GB state dict into cpu memory...")
288
  state_dict = load_file(v21_path)
289
 
290
+ # 3. DYNAMIC COMPONENT MAPPING (NO ASSUMPTIONS)
291
+ # ------------------------------------------------------------------------------
292
+ print("sorting weights into components...")
 
293
 
294
+ # containers for the sorted weights
295
+ transformer_weights = {}
296
+ vae_weights = {}
297
+ text_encoder_weights = {}
298
+
299
+ # analyze the first key to determine the format
300
+ first_key = next(iter(state_dict.keys()))
301
+ print(f"format detection - first key detected: {first_key}")
302
+
303
+ # iterate and sort
304
  for k, v in state_dict.items():
305
+ # MAPPING: TRANSFORMER
306
+ # ComfyUI usually prefixes with 'model.diffusion_model.'
307
  if k.startswith("model.diffusion_model."):
308
+ new_key = k.replace("model.diffusion_model.", "")
309
+ transformer_weights[new_key] = v
310
+ # Or sometimes just 'transformer.' or 'model.'
311
+ elif k.startswith("transformer."):
312
+ new_key = k.replace("transformer.", "")
313
+ transformer_weights[new_key] = v
314
+
315
+ # MAPPING: VAE
316
+ # ComfyUI prefix: 'first_stage_model.'
317
  elif k.startswith("first_stage_model."):
318
+ new_key = k.replace("first_stage_model.", "")
319
+ vae_weights[new_key] = v
320
+ # Diffusers prefix: 'vae.'
321
+ elif k.startswith("vae."):
322
+ new_key = k.replace("vae.", "")
323
+ vae_weights[new_key] = v
324
+
325
+ # MAPPING: TEXT ENCODER
326
+ # ComfyUI prefix: 'conditioner.embedders.' or 'text_encoder.'
327
+ elif "text_encoder" in k or "conditioner" in k:
328
+ # this is tricky, we try to keep the suffix
329
+ if "conditioner.embedders.0." in k:
330
+ new_key = k.replace("conditioner.embedders.0.", "")
331
+ text_encoder_weights[new_key] = v
332
+ elif "text_encoder." in k:
333
+ new_key = k.replace("text_encoder.", "")
334
+ text_encoder_weights[new_key] = v
335
+
336
+ # 4. INJECT WEIGHTS (COMPONENT LEVEL)
337
+ # ------------------------------------------------------------------------------
338
+ print(f"injection statistics:")
339
+ print(f" - transformer keys found: {len(transformer_weights)}")
340
+ print(f" - vae keys found: {len(vae_weights)}")
341
+ print(f" - text encoder keys found: {len(text_encoder_weights)}")
342
+
343
+ if len(transformer_weights) > 0:
344
+ print("injecting transformer weights...")
345
+ msg = pipe.transformer.load_state_dict(transformer_weights, strict=False)
346
+ print(f"transformer missing keys: {len(msg.missing_keys)}")
347
  else:
348
+ print("CRITICAL WARNING: no transformer weights found in file. check mapping logic.")
349
+
350
+ if len(vae_weights) > 0:
351
+ print("injecting vae weights...")
352
+ pipe.vae.load_state_dict(vae_weights, strict=False)
353
 
354
+ if len(text_encoder_weights) > 0:
355
+ print("injecting text encoder weights...")
356
+ # text encoder structure can vary wildly, strict=False is mandatory here
357
+ pipe.text_encoder.load_state_dict(text_encoder_weights, strict=False)
 
358
 
359
+ # 5. CLEANUP & RUN
360
+ # ------------------------------------------------------------------------------
361
  del state_dict
362
+ del transformer_weights
363
+ del vae_weights
364
+ del text_encoder_weights
365
  gc.collect()
366
  torch.cuda.empty_cache()
367