"""
Model loading for Pixagram Pixel Art Generator
With Commercial FaceID Support

This module loads:
1. SDXL Pipeline with ControlNet
2. Pixel Art LoRA
3. FaceID Module (commercial-friendly)
"""

from typing import Optional

import torch
from PIL import Image
from diffusers import (
    StableDiffusionXLControlNetImg2ImgPipeline,
    ControlNetModel,
    LCMScheduler,
)
from controlnet_aux import ZoeDetector
from huggingface_hub import hf_hub_download

from config import Config
from faceid import FaceIDModule, FaceIDAttnProcessor
from utils import preload_captioner


class ModelHandler:
    """
    Handles loading and management of all models.

    Components:
    - SDXL Pipeline with ControlNet
    - Zoe Depth Detector
    - LoRA Weights
    - FaceID Module (Commercial)
    """

    def __init__(self):
        # All components are populated by load_models(); they stay None
        # until that method has run.
        self.pipeline = None
        self.zoe_detector = None
        self.faceid = None

    def load_models(self):
        """
        Load every model component and prepare the pipeline for inference.

        Steps, in order:
        1. Zoe depth detector (moved to Config.DEVICE).
        2. ControlNet depth model.
        3. SDXL ControlNet img2img pipeline from a single checkpoint file
           downloaded from Config.REPO_ID.
        4. LoRA weights, fused at Config.LORA_STRENGTH.
        5. FaceID module, intentionally kept on CPU.
        Then: swap in the LCM scheduler, move the pipeline to Config.DEVICE
        (enabling xformers on CUDA when available), install FaceID
        cross-attention processors, preload the captioner and warm up the
        FaceID face-detection path.

        Raises:
            Exception: whatever ZoeDetector loading raised (logged, then
                re-raised). Errors from the other steps propagate unchanged.
        """
        print("=" * 60)
        print("Loading Pixagram Pixel Art Models (with FaceID)")
        print("=" * 60)

        # 1. Load Zoe Depth Detector
        print("\n[1/5] Loading Zoe Depth Detector...")
        try:
            self.zoe_detector = ZoeDetector.from_pretrained(Config.ANNOTATOR_REPO)
            self.zoe_detector.to(Config.DEVICE)
            print(" [OK] ZoeDetector loaded")
        except Exception as e:
            print(f" [ERROR] Failed to load ZoeDetector: {e}")
            raise

        # 2. Load ControlNet Depth
        print("\n[2/5] Loading ControlNet Zoe Depth...")
        controlnet = ControlNetModel.from_pretrained(
            Config.CN_ZOE_REPO,
            torch_dtype=Config.DTYPE,
        )
        print(" [OK] ControlNet Depth loaded")

        # 3. Load Base Pipeline with checkpoint
        print("\n[3/5] Loading SDXL Pipeline...")
        checkpoint_path = hf_hub_download(
            repo_id=Config.REPO_ID,
            filename=Config.CHECKPOINT_FILENAME,
        )
        print(f" [OK] Checkpoint downloaded: {Config.CHECKPOINT_FILENAME}")
        self.pipeline = StableDiffusionXLControlNetImg2ImgPipeline.from_single_file(
            checkpoint_path,
            controlnet=controlnet,
            torch_dtype=Config.DTYPE,
            use_safetensors=True,
        )
        print(" [OK] Pipeline loaded")

        # 4. Load LoRA
        print("\n[4/5] Loading LoRA weights...")
        lora_path = hf_hub_download(
            repo_id=Config.REPO_ID,
            filename=Config.LORA_FILENAME,
        )
        self.pipeline.load_lora_weights(lora_path)
        # Fusing bakes the LoRA delta into the base weights at the configured
        # strength, so no per-step LoRA overhead remains during sampling.
        self.pipeline.fuse_lora(lora_scale=Config.LORA_STRENGTH)
        print(f" [OK] LoRA loaded and fused: {Config.LORA_FILENAME}")

        # 5. Load FaceID Module (Commercial)
        # FaceID stays on CPU intentionally; per-generation results (K/V
        # projections) are precomputed there and moved to the pipeline
        # device in set_face_data().
        # NOTE(review): the module is created with dtype=Config.DTYPE on CPU;
        # if that is float16, confirm FaceIDModule supports half precision on
        # CPU for the torch version in use.
        print("\n[5/5] Loading FaceID Module (AuraFace)...")
        self.faceid = FaceIDModule(
            auraface_repo=Config.AURAFACE_REPO,
            face_embed_dim=Config.FACEID_EMBED_DIM,
            cross_attention_dim=Config.FACEID_CROSS_ATTENTION_DIM,
            num_tokens=Config.FACEID_NUM_TOKENS,
            device="cpu",
            dtype=Config.DTYPE,
        )
        # Keep entirely on CPU — no .to(DEVICE)
        print(" [OK] FaceID Module loaded (CPU)")
        print(" - Face Detection: OpenCV/MediaPipe (BSD/Apache 2.0)")
        print(" - Face Encoding: AuraFace (Commercial OK)")
        print(" - Adapters: Proprietary")

        # Setup scheduler: LCM reuses the checkpoint scheduler's config.
        print("\nConfiguring LCM Scheduler...")
        self.pipeline.scheduler = LCMScheduler.from_config(
            self.pipeline.scheduler.config
        )
        print(" [OK] LCM Scheduler configured")

        # Move pipeline to device and optimize
        print(f"\nMoving pipeline to {Config.DEVICE}...")
        self.pipeline.to(Config.DEVICE)

        # Accept "cuda", "cuda:N" and torch.device("cuda", N) alike; a plain
        # equality test against "cuda" would skip xformers for indexed devices.
        if str(Config.DEVICE).startswith("cuda"):
            try:
                self.pipeline.enable_xformers_memory_efficient_attention()
                print(" [OK] xformers enabled")
            except Exception as e:
                # Best effort: xformers is optional, fall back silently to the
                # default attention processors.
                print(f" [INFO] xformers not available: {e}")

        # Install FaceID processors ONCE on cross-attention layers only.
        # This must run after xformers is enabled: enabling xformers replaces
        # every processor, which would otherwise overwrite the FaceID ones.
        # Self-attention layers keep their native/xformers processors.
        print("\nSetting up FaceID cross-attention processors...")
        self._setup_faceid_attention()
        print(" [OK] FaceID processors ready (cross-attention only)")

        # Preload BLIP captioner (avoids 990MB download on first generation)
        print("\nPreloading captioner...")
        preload_captioner()

        # Run FaceID once on a blank grey image so lazily created runtime
        # state is initialised before the first real request. The image is a
        # single solid colour, so presumably only the face-detection path is
        # exercised; the face-encoding path is likely NOT warmed by this
        # call — TODO confirm against FaceIDModule.has_face.
        print("\nWarming up FaceID encoder...")
        dummy = Image.new("RGB", (640, 640), (128, 128, 128))
        _ = self.faceid.has_face(dummy)
        print(" [OK] AuraFace encoder warmed up")

        print("\n" + "=" * 60)
        print("Model loading complete!")
        print(" - Pixel Art Style: ✓")
        print(" - Depth Control: ✓")
        print(" - FaceID (Commercial): ✓")
        print("=" * 60 + "\n")

    def _setup_faceid_attention(self):
        """
        Install FaceID processors on cross-attention layers only.

        Called once during model loading. Processors whose name ends in
        "attn1.processor" (self-attention) keep their existing processor
        (native or xformers). Every other processor is replaced by a
        FaceIDAttnProcessor sized by _get_hidden_size(). It keeps its current
        processor when no hidden size can be derived from its name.

        The installed processors hold no module references; they only
        receive precomputed tensors via set_face_data().
        """
        unet = self.pipeline.unet
        attn_procs = {}
        for name, proc in unet.attn_processors.items():
            if name.endswith("attn1.processor"):
                # Self-attention — keep native/xformers processor as-is
                attn_procs[name] = proc
                continue

            # Cross-attention — install FaceID processor when the layer's
            # hidden size is known, otherwise keep the existing processor.
            hidden_size = self._get_hidden_size(name)
            if hidden_size is not None:
                attn_procs[name] = FaceIDAttnProcessor(hidden_size=hidden_size)
            else:
                attn_procs[name] = proc

        # set_attn_processor expects an entry for every attention layer,
        # hence the untouched processors are passed back too.
        unet.set_attn_processor(attn_procs)

    def set_face_data(self, face_tokens: torch.Tensor, face_scale: float):
        """
        Precompute FaceID K/V projections and attach them to every processor.

        The projections are computed by self.faceid (on CPU), then moved to
        the pipeline's device and handed to each FaceIDAttnProcessor.
        During diffusion the processors then only reuse these
        tensors instead of running projection layers.

        Args:
            face_tokens: face tokens for the current generation (presumably
                the output of get_face_tokens — TODO confirm with caller).
            face_scale: user-facing strength; multiplied by the adapter's
                learnable scale (self.faceid.scale) before being stored.
        """
        # Bake in the learnable scale from the adapter
        effective_scale = face_scale * self.faceid.scale.item()
        device = self.pipeline.device

        # One K/V pair per distinct hidden size. _get_hidden_size only yields
        # entries of unet.config.block_out_channels, so many cross-attention
        # layers share the same few projections.
        kv_cache = {}

        # The results are plain inference inputs: computing them under
        # no_grad keeps autograd graphs (referencing the CPU adapter weights)
        # from being attached to the tensors stored on every processor.
        with torch.no_grad():
            for proc in self.pipeline.unet.attn_processors.values():
                if not isinstance(proc, FaceIDAttnProcessor):
                    continue
                hs = proc.hidden_size
                if hs not in kv_cache:
                    face_k, face_v = self.faceid.get_kv_for_attention(
                        face_tokens, hs
                    )
                    kv_cache[hs] = (face_k.to(device), face_v.to(device))
                cached_k, cached_v = kv_cache[hs]
                proc.set_face_data(
                    face_k=cached_k,
                    face_v=cached_v,
                    face_scale=effective_scale,
                )

    def clear_face_data(self):
        """Clear precomputed face data from all FaceID attention processors."""
        for proc in self.pipeline.unet.attn_processors.values():
            if isinstance(proc, FaceIDAttnProcessor):
                proc.clear_face_data()

    def _get_hidden_size(self, name: str) -> Optional[int]:
        """
        Return the hidden size for an attention-processor name, or None.

        The block type and index are parsed from the processor name
        (e.g. "up_blocks.1....") and mapped onto
        unet.config.block_out_channels. Up blocks are mapped through the
        reversed list because the decoder path runs from the widest to the
        narrowest channels. Names outside mid/up/down blocks yield None.
        """
        unet = self.pipeline.unet
        channels = unet.config.block_out_channels
        if name.startswith("mid_block"):
            return channels[-1]
        if name.startswith("up_blocks"):
            block_id = int(name.split(".")[1])
            return list(reversed(channels))[block_id]
        if name.startswith("down_blocks"):
            block_id = int(name.split(".")[1])
            return channels[block_id]
        return None

    def check_face(self, image) -> bool:
        """Check if image contains a detectable face."""
        return self.faceid.has_face(image)

    def get_face_tokens(self, image, padding: float = 0.3):
        """Get face tokens from image (if face detected).

        Delegates to self.faceid.get_face_tokens with the given crop padding.
        """
        return self.faceid.get_face_tokens(image, padding=padding)


print("[OK] Model handler ready (FaceID enabled)")