"""
Commercial FaceID Module for Pixagram
=====================================

This module provides face identity preservation during image-to-image
generation. All components use commercially-permissive licenses:

- Face Detection: OpenCV Haar Cascades (BSD License) or MediaPipe (Apache 2.0)
- Face Encoding: AuraFace (Commercial OK, fal.ai)
- Projection Layers: Custom (Your IP)

License: Apache 2.0 / Your Proprietary License
Copyright (c) 2024 Pixagram SA
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import numpy as np
from typing import Optional, Tuple, Union, List
import cv2

# Optional: Use MediaPipe if available (Apache 2.0 license)
try:
    import mediapipe as mp

    # Verify the solutions API is present (removed in mediapipe >= 0.10.18)
    _ = mp.solutions.face_detection
    MEDIAPIPE_AVAILABLE = True
except (ImportError, AttributeError):
    MEDIAPIPE_AVAILABLE = False
    print("[INFO] MediaPipe not available or incompatible version, using OpenCV for face detection")


class FaceDetector:
    """
    Commercial-friendly face detector using OpenCV or MediaPipe.

    Licenses:
    - OpenCV: BSD License (Commercial OK)
    - MediaPipe: Apache 2.0 (Commercial OK)
    """

    def __init__(self, use_mediapipe: bool = True, min_confidence: float = 0.5):
        """
        Args:
            use_mediapipe: Prefer MediaPipe; silently falls back to OpenCV
                when MediaPipe is not importable.
            min_confidence: Minimum detection confidence. Only used by the
                MediaPipe backend (Haar cascades have no confidence score).
        """
        self.min_confidence = min_confidence
        self.use_mediapipe = use_mediapipe and MEDIAPIPE_AVAILABLE

        if self.use_mediapipe:
            self.mp_face_detection = mp.solutions.face_detection
            self.detector = self.mp_face_detection.FaceDetection(
                model_selection=1,  # Full range model
                min_detection_confidence=min_confidence,
            )
            print(" [OK] FaceDetector: Using MediaPipe (Apache 2.0)")
        else:
            # OpenCV Haar Cascade (BSD License)
            self.face_cascade = cv2.CascadeClassifier(
                cv2.data.haarcascades + 'haarcascade_frontalface_default.xml'
            )
            # Also load profile face detector for better coverage
            self.profile_cascade = cv2.CascadeClassifier(
                cv2.data.haarcascades + 'haarcascade_profileface.xml'
            )
            print(" [OK] FaceDetector: Using OpenCV Haar Cascades (BSD)")

    @staticmethod
    def _to_rgb(image_np: np.ndarray) -> np.ndarray:
        """Convert a grayscale (H, W) / (H, W, 1) or RGBA array to RGB; RGB passes through."""
        if image_np.ndim == 2:
            return cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
        if image_np.shape[2] == 1:
            return cv2.cvtColor(image_np[:, :, 0], cv2.COLOR_GRAY2RGB)
        if image_np.shape[2] == 4:
            return cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)
        return image_np

    def _detect_mediapipe(self, image_rgb: np.ndarray) -> List[Tuple[int, int, int, int]]:
        """Run MediaPipe and convert its relative boxes to pixel boxes clipped to the image."""
        results = self.detector.process(image_rgb)
        if not results.detections:
            return []

        img_h, img_w = image_rgb.shape[:2]
        faces = []
        for detection in results.detections:
            bbox = detection.location_data.relative_bounding_box
            x = int(bbox.xmin * img_w)
            y = int(bbox.ymin * img_h)
            # MediaPipe boxes may extend past the frame (e.g. negative xmin);
            # clip so callers get the same in-bounds convention as OpenCV.
            x1, y1 = max(0, x), max(0, y)
            x2 = min(img_w, x + int(bbox.width * img_w))
            y2 = min(img_h, y + int(bbox.height * img_h))
            if x2 > x1 and y2 > y1:
                faces.append((x1, y1, x2 - x1, y2 - y1))
        return faces

    def _detect_opencv(self, image_rgb: np.ndarray) -> List[Tuple[int, int, int, int]]:
        """Run Haar cascades: frontal first, profile only when no frontal face is found."""
        gray = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY)
        params = dict(scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))

        boxes = self.face_cascade.detectMultiScale(gray, **params)
        if len(boxes) == 0:
            boxes = self.profile_cascade.detectMultiScale(gray, **params)
        # detectMultiScale yields numpy integer rows; return plain Python ints.
        return [tuple(int(v) for v in box) for box in boxes]

    def detect(self, image: Union[Image.Image, np.ndarray]) -> List[Tuple[int, int, int, int]]:
        """
        Detect faces in image.

        Args:
            image: PIL image, or numpy array assumed to be RGB, RGBA, or
                grayscale ((H, W) or (H, W, 1)).

        Returns:
            List of (x, y, width, height) bounding boxes in pixels,
            within the image bounds.
        """
        image_np = np.array(image) if isinstance(image, Image.Image) else image
        image_rgb = self._to_rgb(image_np)

        if self.use_mediapipe:
            return self._detect_mediapipe(image_rgb)
        return self._detect_opencv(image_rgb)

    def get_largest_face(
        self,
        image: Union[Image.Image, np.ndarray],
        padding: float = 0.3,
    ) -> Optional[Image.Image]:
        """
        Get the largest face from an image with padding.

        Args:
            image: Input image
            padding: Fraction of face size to add as padding on each side

        Returns:
            Cropped face image (clamped to the image borders) or None if no
            face detected
        """
        if isinstance(image, Image.Image):
            image_np = np.array(image)
            pil_image = image
        else:
            image_np = image
            pil_image = Image.fromarray(image)

        faces = self.detect(image_np)
        if not faces:
            return None

        # Largest face by area
        x, y, w, h = max(faces, key=lambda f: f[2] * f[3])

        pad_w = int(w * padding)
        pad_h = int(h * padding)
        img_h, img_w = image_np.shape[:2]
        box = (
            max(0, x - pad_w),
            max(0, y - pad_h),
            min(img_w, x + w + pad_w),
            min(img_h, y + h + pad_h),
        )
        return pil_image.crop(box)

    def has_face(self, image: Union[Image.Image, np.ndarray]) -> bool:
        """Check if image contains a face."""
        return len(self.detect(image)) > 0


class AuraFaceEncoder:
    """
    Face encoder using AuraFace via InsightFace's FaceAnalysis API.

    License: Apache 2.0 (Commercial OK, fal.ai)

    The AuraFace repo bundles detection (SCRFD), alignment, and recognition
    (glintr100) models — all Apache 2.0 licensed. FaceAnalysis handles the
    full pipeline: detect → align → encode, producing 512-dim face
    embeddings.

    This is NOT an nn.Module. Runs entirely on CPU via onnxruntime.
    """

    def __init__(
        self,
        repo_id: str = "fal/AuraFace-v1",
    ):
        self.embed_dim = 512  # AuraFace output dimension

        # Imported lazily: only needed when the encoder is actually built.
        from huggingface_hub import snapshot_download
        from insightface.app import FaceAnalysis

        # Download the full AuraFace model package.
        # NOTE: both paths are relative to the current working directory and
        # must agree: FaceAnalysis(name="auraface", root=".") loads from
        # ./models/auraface, which is where snapshot_download writes.
        print(" Downloading AuraFace model package...")
        snapshot_download(
            repo_id,
            local_dir="models/auraface",
        )

        # Initialize InsightFace with AuraFace models (CPU only)
        self.app = FaceAnalysis(
            name="auraface",
            root=".",
            providers=["CPUExecutionProvider"],
        )
        self.app.prepare(ctx_id=-1, det_size=(640, 640))
        print(" [OK] AuraFaceEncoder loaded (InsightFace + AuraFace)")

    @staticmethod
    def _to_bgr(image: Image.Image) -> np.ndarray:
        """Convert a PIL image to a contiguous BGR uint8 array (InsightFace convention)."""
        return np.ascontiguousarray(np.array(image.convert("RGB"))[:, :, ::-1])

    def __call__(
        self,
        image: Image.Image,
        dtype: torch.dtype = None,
    ) -> Optional[torch.Tensor]:
        """
        Detect, align, and encode the largest face in the image.

        The full InsightFace pipeline runs here: SCRFD detection → landmark
        alignment → AuraFace encoding. This produces much better embeddings
        than detect-crop-encode because the face is properly aligned before
        recognition.

        Args:
            image: Full input image (PIL), not pre-cropped
            dtype: Target dtype for output tensor (float32 when None)

        Returns:
            L2-normalized face embedding [1, 512] on CPU, or None if no face
        """
        faces = self.app.get(self._to_bgr(image))
        if not faces:
            return None

        # Largest face by bounding-box area
        largest = max(
            faces,
            key=lambda f: (f.bbox[2] - f.bbox[0]) * (f.bbox[3] - f.bbox[1]),
        )

        # normed_embedding is already L2-normalized, 512-dim
        embedding = torch.from_numpy(largest.normed_embedding.copy()).unsqueeze(0).float()
        if dtype is not None:
            embedding = embedding.to(dtype)
        return embedding

    def has_face(self, image: Image.Image) -> bool:
        """Quick check if image contains a detectable face."""
        return len(self.app.get(self._to_bgr(image))) > 0


class FaceIDAdapter(nn.Module):
    """
    Adapter that projects face embeddings to cross-attention space.

    Uses per-token projections so each token can specialize in different
    facial features (e.g., some tokens for mouth shape, others for age
    lines, eyes, skin texture). This preserves much more detail than a
    single large projection.

    License: Your proprietary IP
    """

    def __init__(
        self,
        face_embed_dim: int = 512,
        cross_attention_dim: int = 2048,
        num_tokens: int = 16,
    ):
        super().__init__()
        self.num_tokens = num_tokens
        self.cross_attention_dim = cross_attention_dim

        # Shared feature expansion: face_embed_dim → cross_attention_dim.
        # Lifts the face embedding into the cross-attention space once; each
        # token then gets a specialized view of it.
        self.shared_proj = nn.Sequential(
            nn.Linear(face_embed_dim, cross_attention_dim),
            nn.LayerNorm(cross_attention_dim),
            nn.GELU(),
        )

        # Per-token projections: each token gets its own linear layer so it
        # can specialize in different facial features.
        self.token_projs = nn.ModuleList(
            nn.Linear(cross_attention_dim, cross_attention_dim)
            for _ in range(num_tokens)
        )

        self.norm = nn.LayerNorm(cross_attention_dim)

        # Learnable scale parameter — starts at 1.0 so signal passes through.
        # Not applied in forward(); exposed to callers (FaceIDModule.scale).
        self.scale = nn.Parameter(torch.ones(1))

        self._init_weights()

    def _init_weights(self):
        """Xavier init for the shared projection; near-identity per-token projections."""
        for module in self.shared_proj.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight, gain=1.0)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

        # Near-identity init so each token starts as a slightly different
        # view of the same face features; the small random offset keeps the
        # tokens from being identical and encourages specialization.
        for proj in self.token_projs:
            nn.init.eye_(proj.weight)
            with torch.no_grad():
                proj.weight.add_(torch.randn_like(proj.weight) * 0.02)
            if proj.bias is not None:
                nn.init.zeros_(proj.bias)

    def forward(self, face_embedding: torch.Tensor) -> torch.Tensor:
        """
        Project face embedding to cross-attention tokens.

        The input is moved to the adapter's own device and parameter dtype
        (float32 when managed by FaceIDModule), so a CPU embedding from the
        encoder can be fed directly to a GPU-resident adapter.

        Args:
            face_embedding: [B, face_embed_dim]

        Returns:
            Cross-attention tokens [B, num_tokens, cross_attention_dim] in the
            input's dtype, on the adapter's device.
        """
        input_dtype = face_embedding.dtype
        param = self.norm.weight
        x = face_embedding.to(device=param.device, dtype=param.dtype)

        shared = self.shared_proj(x)  # [B, cross_attention_dim]
        tokens = torch.stack(
            [proj(shared) for proj in self.token_projs], dim=1
        )  # [B, num_tokens, cross_attention_dim]

        return self.norm(tokens).to(input_dtype)


class FaceIDModule(nn.Module):
    """
    Complete Commercial FaceID Module for SDXL
    ==========================================

    Integrates face detection, encoding, and cross-attention injection for
    identity-preserving image generation.

    All components use commercially-permissive licenses:
    - Face Detection: OpenCV (BSD) or MediaPipe (Apache 2.0)
    - Face Encoding: AuraFace (Commercial OK, fal.ai)
    - Adapters: Your proprietary IP

    Usage:
        faceid = FaceIDModule()
        faceid.to("cuda")

        # Check for face and get embeddings
        if faceid.has_face(image):
            face_tokens = faceid.get_face_tokens(image)
            # Inject into cross-attention...

    License: Apache 2.0 / Your Proprietary
    Copyright (c) 2024 Pixagram SA
    """

    def __init__(
        self,
        auraface_repo: str = "fal/AuraFace-v1",
        face_embed_dim: int = 512,
        cross_attention_dim: int = 2048,
        num_tokens: int = 4,
        device: str = "cuda",
        dtype: torch.dtype = torch.float16,
    ):
        super().__init__()
        # Stored for callers; the module is NOT moved here — call .to().
        self.device = device
        # Output dtype for face tokens / embeddings returned by this module.
        self.dtype = dtype
        self.num_tokens = num_tokens

        # Face Detector (MediaPipe if available, else OpenCV). Not used by
        # has_face/get_face_tokens/forward, which rely on AuraFace's SCRFD.
        self.detector = FaceDetector(use_mediapipe=MEDIAPIPE_AVAILABLE)

        # Face Encoder: AuraFace via InsightFace FaceAnalysis.
        # Handles detection + alignment + encoding in a single call.
        # NOT an nn.Module — runs entirely on CPU via onnxruntime.
        self.encoder = AuraFaceEncoder(
            repo_id=auraface_repo,
        )

        # Cross-Attention Adapter - kept in float32 (see to())
        self.adapter = FaceIDAdapter(
            face_embed_dim=face_embed_dim,
            cross_attention_dim=cross_attention_dim,
            num_tokens=num_tokens,
        )

        # K/V projection layers for the attention widths used in SDXL.
        # They project cross_attention_dim face tokens to each layer's size.
        self.to_k_ip = nn.ModuleDict()
        self.to_v_ip = nn.ModuleDict()
        for dim in (320, 640, 1280, 2048):
            self.to_k_ip[str(dim)] = self._truncated_identity_proj(cross_attention_dim, dim)
            self.to_v_ip[str(dim)] = self._truncated_identity_proj(cross_attention_dim, dim)

    @staticmethod
    def _truncated_identity_proj(in_dim: int, out_dim: int, gain: float = 0.1) -> nn.Linear:
        """
        Bias-free Linear initialized to a scaled truncated identity.

        Output channel i copies input channel i (times `gain`) for the first
        min(in_dim, out_dim) channels; everything else is zero. This lets the
        face signal pass through with untrained weights instead of being
        scrambled by a random init; training learns better projections.
        """
        proj = nn.Linear(in_dim, out_dim, bias=False)
        n = min(in_dim, out_dim)
        with torch.no_grad():
            proj.weight.zero_()
            proj.weight[:n, :n] = torch.eye(n) * gain
        return proj

    def to(self, *args, **kwargs):
        """
        Move module to device/dtype.

        Note: self.encoder (AuraFace) is NOT an nn.Module — it uses
        onnxruntime and always runs on CPU. Only the adapter and K/V
        projections are PyTorch modules that get moved. The adapter is
        forced back to float32 for numerical stability.

        NOTE: .cuda()/.half()/.float() do not go through this override.
        """
        super().to(*args, **kwargs)
        self.adapter = self.adapter.float()
        return self

    def _encode(self, image: Union[Image.Image, np.ndarray]) -> Optional[torch.Tensor]:
        """Return the float32 CPU AuraFace embedding [1, 512] of the largest face, or None."""
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        # Encode in float32: the adapter runs in float32, so rounding the
        # embedding to a lower precision first would only lose information.
        return self.encoder(image)

    def has_face(self, image: Union[Image.Image, np.ndarray]) -> bool:
        """Check if image contains a detectable face (uses AuraFace SCRFD)."""
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        return self.encoder.has_face(image)

    def get_face_tokens(
        self,
        image: Union[Image.Image, np.ndarray],
        padding: float = 0.3,
    ) -> Optional[torch.Tensor]:
        """
        Full pipeline: detect, align, encode, and project to tokens.

        AuraFace's InsightFace pipeline handles detection + alignment +
        encoding in a single call — much better quality than separate
        detect-crop-encode since the face is properly aligned using
        landmarks before recognition.

        Args:
            image: Input image (may or may not contain a face)
            padding: Unused (kept for API compatibility, alignment is automatic)

        Returns:
            Cross-attention tokens [1, num_tokens, cross_attention_dim] in
            self.dtype on the adapter's device, or None if no face detected
        """
        face_embedding = self._encode(image)
        if face_embedding is None:
            return None
        # The adapter moves the CPU embedding to its own device.
        return self.adapter(face_embedding).to(self.dtype)

    @staticmethod
    def _project_fp32(proj: nn.Linear, tokens: torch.Tensor) -> torch.Tensor:
        """Apply `proj` in float32 on its own device, whatever the module/input dtype."""
        weight = proj.weight
        bias = proj.bias.float() if proj.bias is not None else None
        return F.linear(
            tokens.to(device=weight.device, dtype=torch.float32),
            weight.float(),
            bias,
        )

    def get_kv_for_attention(
        self,
        face_tokens: torch.Tensor,
        hidden_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get key-value projections for a specific attention dimension.

        Args:
            face_tokens: [B, num_tokens, cross_attention_dim]
            hidden_size: The hidden size of the attention layer

        Returns:
            (key, value) tensors [B, num_tokens, hidden_size] in the dtype
            of `face_tokens`, on the projection layers' device.
        """
        input_dtype = face_tokens.dtype

        dim_key = str(hidden_size)
        if dim_key not in self.to_k_ip:
            # Fallback: use the closest available projection width, then
            # resample the channel axis to hidden_size below.
            closest_dim = min(
                (int(d) for d in self.to_k_ip.keys()),
                key=lambda d: abs(d - hidden_size),
            )
            dim_key = str(closest_dim)

        # Computed in float32 for numerical stability.
        k = self._project_fp32(self.to_k_ip[dim_key], face_tokens)
        v = self._project_fp32(self.to_v_ip[dim_key], face_tokens)

        if k.shape[-1] != hidden_size:
            k = self._resize_last_dim(k, hidden_size)
            v = self._resize_last_dim(v, hidden_size)

        return k.to(input_dtype), v.to(input_dtype)

    @staticmethod
    def _resize_last_dim(x: torch.Tensor, size: int) -> torch.Tensor:
        """Linearly resample the last axis of `x` to `size`, for any leading shape."""
        lead = x.shape[:-1]
        # F.interpolate's 'linear' mode requires 3-D (N, C, L) input.
        flat = x.reshape(-1, 1, x.shape[-1])
        flat = F.interpolate(flat, size=size, mode="linear", align_corners=False)
        return flat.reshape(*lead, size)

    @property
    def scale(self) -> torch.Tensor:
        """Get the learnable scale parameter."""
        return self.adapter.scale

    def forward(
        self,
        image: Union[Image.Image, np.ndarray],
        padding: float = 0.3,
    ) -> dict:
        """
        Process image and return all face-related outputs.

        Args:
            image: Input image
            padding: Unused (kept for API compatibility)

        Returns:
            Dictionary with:
                - 'has_face': bool
                - 'face_embedding': Tensor [1, 512] in self.dtype (CPU) or None
                - 'face_tokens': Tensor in self.dtype or None
                - 'scale': Tensor
        """
        face_embedding = self._encode(image)
        if face_embedding is None:
            return {
                'has_face': False,
                'face_embedding': None,
                'face_tokens': None,
                'scale': self.scale,
            }

        face_tokens = self.adapter(face_embedding)
        return {
            'has_face': True,
            'face_embedding': face_embedding.to(self.dtype),
            'face_tokens': face_tokens.to(self.dtype),
            'scale': self.scale,
        }


class FaceIDAttnProcessor:
    """
    Custom attention processor that injects face identity.

    Precomputed K/V tensors are stored directly on the processor before
    each generation. During diffusion steps, only cheap tensor ops run —
    no linear layers, no module references, no device crossings beyond
    moving the stored K/V to the query's device/dtype.
    """

    def __init__(self, hidden_size: int):
        # Attention width this processor was created for (informational;
        # __call__ derives the width from the attention layer itself).
        self.hidden_size = hidden_size
        # Precomputed K/V set before generation, cleared after
        self._face_k = None  # [B_face, num_tokens, hidden_size]
        self._face_v = None  # [B_face, num_tokens, hidden_size]
        self._face_scale = 0.0

    def set_face_data(
        self,
        face_k: torch.Tensor,
        face_v: torch.Tensor,
        face_scale: float,
    ):
        """Store precomputed K/V projections for this layer's hidden_size."""
        self._face_k = face_k
        self._face_v = face_v
        self._face_scale = face_scale

    def clear_face_data(self):
        """Clear face data after generation."""
        self._face_k = None
        self._face_v = None
        self._face_scale = 0.0

    @staticmethod
    def _match_batch(t: torch.Tensor, batch_size: int) -> torch.Tensor:
        """
        Tile face K/V along the batch axis to match the attention batch.

        With classifier-free guidance the attention batch is [uncond, cond],
        i.e. a multiple of the face batch; tiling gives every half the same
        face data. Callers wanting distinct unconditional face data should
        precompute K/V with the full batch size.
        """
        n = t.shape[0]
        if n == batch_size:
            return t
        if n > 0 and batch_size % n == 0:
            return t.repeat(batch_size // n, *([1] * (t.ndim - 1)))
        raise ValueError(
            f"Face K/V batch size {n} is incompatible with attention batch size {batch_size}"
        )

    def _face_attention(self, query: torch.Tensor, heads: int, head_dim: int) -> torch.Tensor:
        """Attend `query` [B, heads, L, head_dim] over the stored face K/V."""
        batch_size = query.shape[0]
        face_k = self._face_k.to(dtype=query.dtype, device=query.device)
        face_v = self._face_v.to(dtype=query.dtype, device=query.device)
        face_k = self._match_batch(face_k, batch_size)
        face_v = self._match_batch(face_v, batch_size)

        face_k = face_k.reshape(batch_size, -1, heads, head_dim).transpose(1, 2)
        face_v = face_v.reshape(batch_size, -1, heads, head_dim).transpose(1, 2)

        return F.scaled_dot_product_attention(
            query, face_k, face_v, attn_mask=None, dropout_p=0.0, is_causal=False
        )

    def __call__(
        self,
        attn,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Process attention with optional face identity injection.

        Mirrors the standard SDPA attention path (spatial norm, mask
        preparation, group norm, cross norm, optional q/k norms), then adds
        `face_scale * attention(query, face_k, face_v)` when face data is
        set. Uses precomputed K/V — zero module calls for the face branch.
        """
        residual = hidden_states
        input_ndim = hidden_states.ndim
        target_dtype = hidden_states.dtype

        spatial_norm = getattr(attn, "spatial_norm", None)
        if spatial_norm is not None:
            hidden_states = spatial_norm(hidden_states, temb)

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            # SDPA expects (batch, heads, query_len, key_len).
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif getattr(attn, "norm_cross", None):
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        norm_q = getattr(attn, "norm_q", None)
        if norm_q is not None:
            query = norm_q(query)
        norm_k = getattr(attn, "norm_k", None)
        if norm_k is not None:
            key = norm_k(key)

        # Standard scaled dot-product attention
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # ===== FACE IDENTITY INJECTION =====
        if self._face_k is not None and self._face_scale > 0:
            face_attention = self._face_attention(query, attn.heads, head_dim)
            hidden_states = hidden_states + self._face_scale * face_attention
        # ===================================

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, inner_dim)
        hidden_states = hidden_states.to(target_dtype)

        # Output projection (linear), then dropout
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states


print("[OK] FaceID module loaded (AuraFace - Commercial License)")