primerz committed on
Commit
a27ebe7
·
verified ·
1 Parent(s): 53b19bc

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +386 -659
models.py CHANGED
@@ -1,698 +1,425 @@
1
  """
2
- Generation logic for Pixagram AI Pixel Art Generator
 
3
  """
4
  import torch
5
- import numpy as np
6
- import cv2
7
- from PIL import Image
8
- import torch.nn.functional as F
9
- from torchvision import transforms
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  from config import (
12
- device, dtype, TRIGGER_WORD, MULTI_SCALE_FACTORS,
13
- ADAPTIVE_THRESHOLDS, ADAPTIVE_PARAMS, CAPTION_CONFIG, IDENTITY_BOOST_MULTIPLIER
14
- )
15
- from utils import (
16
- sanitize_text, enhanced_color_match, color_match, create_face_mask,
17
- draw_kps, get_demographic_description, calculate_optimal_size, enhance_face_crop
18
- )
19
- from models import (
20
- load_face_analysis, load_controlnets, load_image_encoder,
21
- load_sdxl_pipeline, load_lora, setup_ip_adapter, setup_compel,
22
- setup_scheduler, optimize_pipeline, load_caption_model, set_clip_skip,
23
- load_openpose_detector, load_depth_models
24
  )
25
 
26
 
27
- class RetroArtConverter:
28
- """Main class for retro art generation"""
 
 
29
 
30
- def __init__(self):
31
- self.device = device
32
- self.dtype = dtype
33
- self.models_loaded = {
34
- 'custom_checkpoint': False,
35
- 'lora': False,
36
- 'instantid': False,
37
- 'depth': False,
38
- 'ip_adapter': False,
39
- 'openpose': False
40
- }
41
-
42
- # Initialize face analysis
43
- self.face_app, self.face_detection_enabled = load_face_analysis()
44
-
45
- # Load Depth Detector Chain (Zoe -> MiDaS)
46
- self.depth_detector, self.depth_detector_name, depth_success = load_depth_models()
47
- self.models_loaded['depth'] = depth_success
48
-
49
- # Load OpenPose detector
50
- self.openpose_detector, openpose_success = load_openpose_detector()
51
- self.models_loaded['openpose'] = openpose_success
52
-
53
- # Load ControlNets
54
- controlnet_depth, self.controlnet_instantid, self.controlnet_openpose, instantid_success = load_controlnets(self.depth_detector_name)
55
- self.controlnet_depth = controlnet_depth
56
- self.instantid_enabled = instantid_success
57
- self.models_loaded['instantid'] = instantid_success
58
-
59
- # Load image encoder
60
- if self.instantid_enabled:
61
- self.image_encoder = load_image_encoder()
62
- else:
63
- self.image_encoder = None
64
-
65
- # Robust ControlNet Loading
66
- self.instantid_active = self.instantid_enabled and self.controlnet_instantid is not None
67
- self.depth_active = self.depth_detector is not None and self.controlnet_depth is not None
68
- self.openpose_active = self.openpose_detector is not None and self.controlnet_openpose is not None
69
-
70
- controlnets = []
71
- if self.instantid_active:
72
- controlnets.append(self.controlnet_instantid)
73
- print(" [CN] InstantID (Identity) active")
74
- else:
75
- print(" [CN] InstantID (Identity) DISABLED")
76
 
77
- if self.depth_active:
78
- controlnets.append(self.controlnet_depth)
79
- print(f" [CN] Depth ({self.depth_detector_name}) active")
80
- else:
81
- print(f" [CN] Depth ({self.depth_detector_name}) DISABLED (Detector or ControlNet missing)")
82
-
83
- if self.openpose_active:
84
- controlnets.append(self.controlnet_openpose)
85
- print(" [CN] OpenPose (Expression) active")
86
- else:
87
- print(" [CN] OpenPose (Expression) DISABLED (Detector or ControlNet missing)")
88
 
89
- if not controlnets:
90
- print("[WARNING] No ControlNets loaded!")
91
-
92
- print(f"Initializing with {len(controlnets)} active ControlNet(s)")
93
-
94
- # Load SDXL pipeline
95
- self.pipe, checkpoint_success = load_sdxl_pipeline(controlnets if controlnets else None)
96
-
97
- self.models_loaded['custom_checkpoint'] = checkpoint_success
98
-
99
- # Load LORA
100
- lora_success = load_lora(self.pipe)
101
- self.models_loaded['lora'] = lora_success
102
-
103
- # Setup IP-Adapter
104
- if self.instantid_active and self.image_encoder is not None:
105
- self.image_proj_model, ip_adapter_success = setup_ip_adapter(self.pipe, self.image_encoder)
106
- self.models_loaded['ip_adapter'] = ip_adapter_success
107
- else:
108
- print("[INFO] Face preservation: IP-Adapter disabled (InstantID model failed or encoder failed)")
109
- self.models_loaded['ip_adapter'] = False
110
- self.image_proj_model = None
111
-
112
- # Setup Compel
113
- self.compel, self.use_compel = setup_compel(self.pipe)
114
-
115
- # Setup LCM scheduler
116
- setup_scheduler(self.pipe)
117
-
118
- # Optimize pipeline
119
- optimize_pipeline(self.pipe)
120
-
121
- # Load caption model
122
- self.caption_processor, self.caption_model, self.caption_enabled, self.caption_model_type = load_caption_model()
123
-
124
- # Report caption model status
125
- if self.caption_enabled and self.caption_model is not None:
126
- if self.caption_model_type == "git":
127
- print(" [OK] Using GIT for detailed captions")
128
- elif self.caption_model_type == "blip":
129
- print(" [OK] Using BLIP for standard captions")
130
- else:
131
- print(" [OK] Caption model loaded")
132
-
133
-
134
- # Set CLIP skip
135
- set_clip_skip(self.pipe)
136
-
137
- # Track controlnet configuration
138
- self.using_multiple_controlnets = isinstance(controlnets, list)
139
- print(f"Pipeline initialized with {'multiple' if self.using_multiple_controlnets else 'single'} ControlNet(s)")
140
-
141
- # Print model status
142
- self._print_status()
143
-
144
- print(" [OK] Model initialization complete!")
145
-
146
- def _print_status(self):
147
- """Print model loading status"""
148
- print("\n=== MODEL STATUS ===")
149
- for model, loaded in self.models_loaded.items():
150
- status = "[OK] LOADED" if loaded else "[FALLBACK/DISABLED]"
151
- print(f"{model}: {status}")
152
- print("===================\n")
153
-
154
- print("=== UPGRADE VERIFICATION ===")
155
- try:
156
- # Check for enhanced classes if they exist
157
- pass
158
  except Exception as e:
159
- print(f"[INFO] Verification skipped: {e}")
160
- print("============================\n")
161
-
162
- def get_depth_map(self, image):
163
- """Generate depth map using the loaded detector (Zoe/MiDaS)"""
164
- if self.depth_detector is not None:
165
- try:
166
- if image.mode != 'RGB':
167
- image = image.convert('RGB')
168
-
169
- orig_width, orig_height = image.size
170
- orig_width = int(orig_width)
171
- orig_height = int(orig_height)
172
-
173
- target_width = int((orig_width // 64) * 64)
174
- target_height = int((orig_height // 64) * 64)
175
-
176
- target_width = int(max(64, target_width))
177
- target_height = int(max(64, target_height))
178
-
179
- size_for_depth = (int(target_width), int(target_height))
180
-
181
- image_resized = image.resize(size_for_depth, Image.LANCZOS)
182
-
183
- # --- FIX for numpy.int64 error ---
184
- # .copy() forces PIL to create a new image, stripping numpy-typed metadata
185
- image_for_depth = image_resized.copy()
186
- # --- END FIX ---
187
-
188
- if target_width != orig_width or target_height != orig_height:
189
- print(f"[DEPTH] Resized for {self.depth_detector_name}Detector: {orig_width}x{orig_height} -> {target_width}x{target_height}")
190
-
191
- with torch.no_grad():
192
- depth_image = self.depth_detector(image_for_depth)
193
-
194
- depth_width, depth_height = depth_image.size
195
- if depth_width != orig_width or depth_height != orig_height:
196
- depth_image = depth_image.resize((int(orig_width), int(orig_height)), Image.LANCZOS)
197
-
198
- print(f"[DEPTH] {self.depth_detector_name} depth map generated: {orig_width}x{orig_height}")
199
- return depth_image
200
-
201
- except Exception as e:
202
- print(f"[DEPTH] {self.depth_detector_name}Detector failed ({e}), falling back to grayscale depth")
203
- gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
204
- depth_colored = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
205
- return Image.fromarray(depth_colored)
206
  else:
207
- print("[DEPTH] No depth detector active, falling back to grayscale depth")
208
- gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
209
- depth_colored = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
210
- return Image.fromarray(depth_colored)
211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
- def add_trigger_word(self, prompt):
214
- """Add trigger word to prompt if not present"""
215
- if TRIGGER_WORD.lower() not in prompt.lower():
216
- if not prompt or not prompt.strip():
217
- return TRIGGER_WORD
218
- return f"{TRIGGER_WORD}, {prompt}"
219
- return prompt
220
 
221
- def extract_multi_scale_face(self, face_crop, face):
222
- """
223
- Extract face features at multiple scales for better detail.
224
- +1-2% improvement in face preservation.
225
- """
226
- try:
227
- multi_scale_embeds = []
228
-
229
- for scale in MULTI_SCALE_FACTORS:
230
- # Resize
231
- w, h = face_crop.size
232
- scaled_size = (int(w * scale), int(h * scale))
233
- scaled_crop = face_crop.resize(scaled_size, Image.LANCZOS)
234
-
235
- # Pad/crop back to original
236
- scaled_crop = scaled_crop.resize((w, h), Image.LANCZOS)
237
-
238
- # Extract features
239
- scaled_array = cv2.cvtColor(np.array(scaled_crop), cv2.COLOR_RGB2BGR)
240
- scaled_faces = self.face_app.get(scaled_array)
241
-
242
- if len(scaled_faces) > 0:
243
- multi_scale_embeds.append(scaled_faces[0].normed_embedding)
244
-
245
- # Average embeddings
246
- if len(multi_scale_embeds) > 0:
247
- averaged = np.mean(multi_scale_embeds, axis=0)
248
- # Renormalize
249
- averaged = averaged / np.linalg.norm(averaged)
250
- print(f"[MULTI-SCALE] Combined {len(multi_scale_embeds)} scales")
251
- return averaged
252
-
253
- return face.normed_embedding
254
-
255
- except Exception as e:
256
- print(f"[MULTI-SCALE] Failed: {e}, using single scale")
257
- return face.normed_embedding
258
 
259
- def detect_face_quality(self, face):
260
- """
261
- Detect face quality and adaptively adjust parameters.
262
- +2-3% consistency improvement.
263
- """
264
  try:
265
- bbox = face.bbox
266
- face_size = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
267
- det_score = float(face.det_score) if hasattr(face, 'det_score') else 1.0
268
-
269
- # Small face -> boost identity preservation
270
- if face_size < ADAPTIVE_THRESHOLDS['small_face_size']:
271
- return ADAPTIVE_PARAMS['small_face'].copy()
272
-
273
- # Low confidence -> boost preservation
274
- elif det_score < ADAPTIVE_THRESHOLDS['low_confidence']:
275
- return ADAPTIVE_PARAMS['low_confidence'].copy()
276
-
277
- # Check for profile/side view (if pose available)
278
- elif hasattr(face, 'pose') and len(face.pose) > 1:
279
- try:
280
- yaw = float(face.pose[1])
281
- if abs(yaw) > ADAPTIVE_THRESHOLDS['profile_angle']:
282
- return ADAPTIVE_PARAMS['profile_view'].copy()
283
- except (ValueError, TypeError, IndexError):
284
- pass
285
-
286
- # Good quality face - use provided parameters
287
- return None
288
-
289
  except Exception as e:
290
- print(f"[ADAPTIVE] Quality detection failed: {e}")
291
- return None
 
292
 
293
- def validate_and_adjust_parameters(self, strength, guidance_scale, lora_scale,
294
- identity_preservation, identity_control_scale,
295
- depth_control_scale, consistency_mode=True,
296
- expression_control_scale=0.6):
297
- """
298
- Enhanced parameter validation with stricter rules for consistency.
299
- """
300
- if consistency_mode:
301
- print("[CONSISTENCY] Applying strict parameter validation...")
302
- adjustments = []
303
-
304
- # Rule 1: Strong inverse relationship between identity and LORA
305
- if identity_preservation > 1.2:
306
- original_lora = lora_scale
307
- lora_scale = min(lora_scale, 1.0)
308
- if abs(lora_scale - original_lora) > 0.01:
309
- adjustments.append(f"LORA: {original_lora:.2f}->{lora_scale:.2f} (high identity)")
310
-
311
- # Rule 2: Strength-based profile activation
312
- if strength < 0.5:
313
- # Maximum preservation mode
314
- if identity_preservation < 1.3:
315
- original_identity = identity_preservation
316
- identity_preservation = 1.3
317
- adjustments.append(f"Identity: {original_identity:.2f}->{identity_preservation:.2f} (max preservation)")
318
- if lora_scale > 0.9:
319
- original_lora = lora_scale
320
- lora_scale = 0.9
321
- adjustments.append(f"LORA: {original_lora:.2f}->{lora_scale:.2f} (max preservation)")
322
- if guidance_scale > 1.3:
323
- original_cfg = guidance_scale
324
- guidance_scale = 1.3
325
- adjustments.append(f"CFG: {original_cfg:.2f}->{guidance_scale:.2f} (max preservation)")
326
-
327
- elif strength > 0.7:
328
- # Artistic transformation mode
329
- if identity_preservation > 1.0:
330
- original_identity = identity_preservation
331
- identity_preservation = 1.0
332
- adjustments.append(f"Identity: {original_identity:.2f}->{identity_preservation:.2f} (artistic mode)")
333
- if lora_scale < 1.2:
334
- original_lora = lora_scale
335
- lora_scale = 1.2
336
- adjustments.append(f"LORA: {original_lora:.2f}->{lora_scale:.2f} (artistic mode)")
337
-
338
- # Rule 3: CFG-LORA relationship
339
- if guidance_scale > 1.4 and lora_scale > 1.2:
340
- original_lora = lora_scale
341
- lora_scale = 1.1
342
- adjustments.append(f"LORA: {original_lora:.2f}->{lora_scale:.2f} (high CFG detected)")
343
-
344
- # Rule 4: LCM sweet spot enforcement
345
- original_cfg = guidance_scale
346
- guidance_scale = max(1.0, min(guidance_scale, 1.5))
347
- if abs(guidance_scale - original_cfg) > 0.01:
348
- adjustments.append(f"CFG: {original_cfg:.2f}->{guidance_scale:.2f} (LCM optimal)")
349
-
350
- # Rule 5: ControlNet balance
351
- total_control = 0
352
- if self.instantid_active:
353
- total_control += identity_control_scale
354
- if self.depth_active:
355
- total_control += depth_control_scale
356
- if self.openpose_active:
357
- total_control += expression_control_scale
358
-
359
- if total_control > 2.0:
360
- scale_factor = 2.0 / total_control
361
- original_id_ctrl = identity_control_scale
362
- original_depth_ctrl = depth_control_scale
363
- original_expr_ctrl = expression_control_scale
364
-
365
- if self.instantid_active:
366
- identity_control_scale *= scale_factor
367
- if self.depth_active:
368
- depth_control_scale *= scale_factor
369
- if self.openpose_active:
370
- expression_control_scale *= scale_factor
371
-
372
- adjustments.append(f"ControlNets balanced: ID {original_id_ctrl:.2f}->{identity_control_scale:.2f}, Depth {original_depth_ctrl:.2f}->{depth_control_scale:.2f}, Expr {original_expr_ctrl:.2f}->{expression_control_scale:.2f}")
373
-
374
- if adjustments:
375
- print(" [OK] Applied adjustments:")
376
- for adj in adjustments:
377
- print(f" - {adj}")
378
- else:
379
- print(" [OK] Parameters already optimal")
380
-
381
- return strength, guidance_scale, lora_scale, identity_preservation, identity_control_scale, depth_control_scale, expression_control_scale
382
 
383
- def generate_caption(self, image, max_length=None, num_beams=None):
384
- """Generate a descriptive caption for the image (supports BLIP-2, GIT, BLIP)."""
385
- if not self.caption_enabled or self.caption_model is None:
386
- return None
387
-
388
- if max_length is None:
389
- if self.caption_model_type == "blip2":
390
- max_length = 50
391
- elif self.caption_model_type == "git":
392
- max_length = 40
393
- else:
394
- max_length = CAPTION_CONFIG['max_length']
395
-
396
- if num_beams is None:
397
- num_beams = CAPTION_CONFIG['num_beams']
398
-
399
- try:
400
- if self.caption_model_type == "blip2":
401
- inputs = self.caption_processor(image, return_tensors="pt").to(self.device, self.dtype)
402
- with torch.no_grad():
403
- output = self.caption_model.generate(
404
- **inputs, max_length=max_length, num_beams=num_beams, min_length=10,
405
- length_penalty=1.0, repetition_penalty=1.5, early_stopping=True
406
- )
407
- caption = self.caption_processor.decode(output[0], skip_special_tokens=True)
408
-
409
- elif self.caption_model_type == "git":
410
- inputs = self.caption_processor(images=image, return_tensors="pt").to(self.device, self.dtype)
411
- with torch.no_grad():
412
- output = self.caption_model.generate(
413
- pixel_values=inputs.pixel_values, max_length=max_length, num_beams=num_beams, min_length=10,
414
- length_penalty=1.0, repetition_penalty=1.5, early_stopping=True
415
- )
416
- caption = self.caption_processor.batch_decode(output, skip_special_tokens=True)[0]
417
-
418
- else:
419
- inputs = self.caption_processor(image, return_tensors="pt").to(self.device, self.dtype)
420
- with torch.no_grad():
421
- output = self.caption_model.generate(
422
- **inputs, max_length=max_length, num_beams=num_beams, early_stopping=True
423
- )
424
- caption = self.caption_processor.decode(output[0], skip_special_tokens=True)
425
-
426
- return caption.strip()
427
-
428
- except Exception as e:
429
- print(f"Caption generation failed: {e}")
430
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
431
 
432
- def generate_retro_art(
433
- self,
434
- input_image,
435
- prompt="retro game character, vibrant colors, detailed",
436
- negative_prompt="blurry, low quality, ugly, distorted",
437
- num_inference_steps=12,
438
- guidance_scale=1.0,
439
- depth_control_scale=0.8,
440
- identity_control_scale=0.85,
441
- expression_control_scale=0.6,
442
- lora_scale=1.0,
443
- identity_preservation=0.8,
444
- strength=0.75,
445
- enable_color_matching=False,
446
- consistency_mode=True,
447
- seed=-1
448
- ):
449
- """Generate retro art with img2img pipeline and enhanced InstantID"""
450
-
451
- # --- FIX for Compel tensor mismatch error ---
452
- prompt = sanitize_text(prompt)
453
- if not prompt or not prompt.strip():
454
- prompt = "" # Ensure prompt is not None or just whitespace
455
-
456
- negative_prompt = sanitize_text(negative_prompt)
457
- if not negative_prompt or not negative_prompt.strip():
458
- negative_prompt = "" # Ensure negative_prompt is "" if blank
459
- # --- END FIX ---
460
-
461
- if consistency_mode:
462
- print("\n[CONSISTENCY] Validating and adjusting parameters...")
463
- strength, guidance_scale, lora_scale, identity_preservation, identity_control_scale, depth_control_scale, expression_control_scale = \
464
- self.validate_and_adjust_parameters(
465
- strength, guidance_scale, lora_scale, identity_preservation,
466
- identity_control_scale, depth_control_scale, consistency_mode,
467
- expression_control_scale
468
- )
469
-
470
- prompt = self.add_trigger_word(prompt)
471
-
472
- original_width, original_height = input_image.size
473
- target_width, target_height = calculate_optimal_size(original_width, original_height)
474
-
475
- print(f"Resizing from {original_width}x{original_height} to {target_width}x{target_height}")
476
- print(f"Prompt: {prompt}")
477
- print(f"Img2Img Strength: {strength}")
478
-
479
- resized_image = input_image.resize((int(target_width), int(target_height)), Image.LANCZOS)
480
-
481
- depth_image = None
482
- if self.depth_active:
483
- depth_image = self.get_depth_map(resized_image)
484
- if depth_image.size != (target_width, target_height):
485
- depth_image = depth_image.resize((int(target_width), int(target_height)), Image.LANCZOS)
486
-
487
- openpose_image = None
488
- if self.openpose_active:
489
- print("Generating OpenPose map...")
490
  try:
491
- openpose_image = self.openpose_detector(resized_image, face_only=True)
 
492
  except Exception as e:
493
- print(f"OpenPose failed, using blank map: {e}")
494
- openpose_image = Image.new("RGB", (target_width, target_height), (0,0,0))
495
-
 
496
 
497
- face_kps_image = None
498
- face_embeddings = None
499
- face_crop_enhanced = None
500
- has_detected_faces = False
501
- face_bbox_original = None
502
 
503
- if self.instantid_active and self.face_app is not None:
504
- print("Detecting faces and extracting keypoints...")
505
- img_array = cv2.cvtColor(np.array(resized_image), cv2.COLOR_RGB2BGR)
506
- faces = self.face_app.get(img_array)
507
 
508
- if len(faces) > 0:
509
- has_detected_faces = True
510
- print(f"Detected {len(faces)} face(s)")
511
-
512
- face = sorted(faces, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))[-1]
513
-
514
- adaptive_params = self.detect_face_quality(face)
515
- if adaptive_params is not None:
516
- print(f"[ADAPTIVE] {adaptive_params['reason']}")
517
- identity_preservation = adaptive_params['identity_preservation']
518
- identity_control_scale = adaptive_params['identity_control_scale']
519
- guidance_scale = adaptive_params['guidance_scale']
520
- lora_scale = adaptive_params['lora_scale']
521
-
522
- face_embeddings_base = face.normed_embedding
523
-
524
- bbox = face.bbox.astype(int)
525
- x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
526
- face_bbox_original = [x1, y1, x2, y2]
527
-
528
- face_width = x2 - x1
529
- face_height = y2 - y1
530
- padding_x = int(face_width * 0.3)
531
- padding_y = int(face_height * 0.3)
532
- x1 = max(0, x1 - padding_x)
533
- y1 = max(0, y1 - padding_y)
534
- x2 = min(resized_image.width, x2 + padding_x)
535
- y2 = min(resized_image.height, y2 + padding_y)
536
-
537
- face_crop = resized_image.crop((x1, y1, x2, y2))
538
-
539
- face_embeddings = self.extract_multi_scale_face(face_crop, face)
540
- face_crop_enhanced = enhance_face_crop(face_crop)
541
- face_kps = face.kps
542
- face_kps_image = draw_kps(resized_image, face_kps)
543
-
544
- from utils import get_facial_attributes, build_enhanced_prompt
545
- facial_attrs = get_facial_attributes(face)
546
- prompt = build_enhanced_prompt(prompt, facial_attrs, TRIGGER_WORD)
547
-
548
- age = facial_attrs['age']
549
- gender_code = facial_attrs['gender']
550
- det_score = facial_attrs['quality']
551
-
552
- gender_str = 'M' if gender_code == 1 else ('F' if gender_code == 0 else 'N/A')
553
- print(f"Face info: bbox={face.bbox}, age={age if age else 'N/A'}, gender={gender_str}")
554
- print(f"Face crop size: {face_crop.size}, enhanced: {face_crop_enhanced.size if face_crop_enhanced else 'N/A'}")
555
-
556
- if hasattr(self.pipe, 'set_adapters') and self.models_loaded['lora']:
557
- try:
558
- self.pipe.set_adapters(["retroart"], adapter_weights=[lora_scale])
559
- print(f"LORA scale: {lora_scale}")
560
- except Exception as e:
561
- print(f"Could not set LORA scale: {e}")
562
-
563
- pipe_kwargs = {
564
- "image": resized_image,
565
- "strength": strength,
566
- "num_inference_steps": num_inference_steps,
567
- "guidance_scale": guidance_scale,
568
- }
569
-
570
- if seed == -1:
571
- generator = torch.Generator(device=self.device)
572
- actual_seed = generator.seed()
573
- print(f"[SEED] Using random seed: {actual_seed}")
574
- else:
575
- generator = torch.Generator(device=self.device).manual_seed(seed)
576
- actual_seed = seed
577
- print(f"[SEED] Using fixed seed: {actual_seed}")
578
 
579
- pipe_kwargs["generator"] = generator
580
 
581
- if self.use_compel and self.compel is not None:
582
  try:
583
- print("Encoding prompts with Compel...")
584
- conditioning = self.compel(prompt)
585
- negative_conditioning = self.compel(negative_prompt)
586
-
587
- pipe_kwargs["prompt_embeds"] = conditioning[0]
588
- pipe_kwargs["pooled_prompt_embeds"] = conditioning[1]
589
- pipe_kwargs["negative_prompt_embeds"] = negative_conditioning[0]
590
- pipe_kwargs["negative_pooled_prompt_embeds"] = negative_conditioning[1]
591
-
592
- print("[OK] Using Compel-encoded prompts")
593
  except Exception as e:
594
- print(f"Compel encoding failed, using standard prompts: {e}")
595
- pipe_kwargs["prompt"] = prompt
596
- pipe_kwargs["negative_prompt"] = negative_prompt
597
- else:
598
- pipe_kwargs["prompt"] = prompt
599
- pipe_kwargs["negative_prompt"] = negative_prompt
600
-
601
- if hasattr(self.pipe, 'text_encoder'):
602
- pipe_kwargs["clip_skip"] = 2
603
-
604
- control_images = []
605
- conditioning_scales = []
606
- scale_debug_str = []
607
-
608
- if self.instantid_active:
609
- if has_detected_faces and face_kps_image is not None:
610
- control_images.append(face_kps_image)
611
- conditioning_scales.append(identity_control_scale)
612
- scale_debug_str.append(f"Identity: {identity_control_scale:.2f}")
613
-
614
- if face_embeddings is not None and self.models_loaded.get('ip_adapter', False) and face_crop_enhanced is not None:
615
- print(f"Processing InstantID face embeddings with Resampler...")
616
-
617
- with torch.no_grad():
618
- face_emb_tensor = torch.from_numpy(face_embeddings).to(device=self.device, dtype=self.dtype)
619
- face_emb_tensor = face_emb_tensor.reshape(1, -1, 512)
620
- face_proj_embeds = self.image_proj_model(face_emb_tensor)
621
-
622
- boosted_scale = identity_preservation * IDENTITY_BOOST_MULTIPLIER
623
- face_proj_embeds = face_proj_embeds * boosted_scale
624
-
625
- print(f" - Face embedding: {face_emb_tensor.shape} -> {face_proj_embeds.shape}, Scale: {boosted_scale:.2f}")
626
-
627
- if 'prompt_embeds' in pipe_kwargs:
628
- original_embeds = pipe_kwargs['prompt_embeds']
629
-
630
- if original_embeds.shape[0] > 1: # Handle CFG
631
- face_proj_embeds = torch.cat([torch.zeros_like(face_proj_embeds), face_proj_embeds], dim=0)
632
-
633
- combined_embeds = torch.cat([original_embeds, face_proj_embeds], dim=1)
634
- pipe_kwargs['prompt_embeds'] = combined_embeds
635
- print(f" [OK] Face embeddings concatenated successfully! New shape: {combined_embeds.shape}")
636
- else:
637
- print(f" [WARNING] Can't concatenate - no prompt_embeds (use Compel)")
638
-
639
- elif has_detected_faces:
640
- print(" Face detected but IP-Adapter/embeddings unavailable, using keypoints only")
641
-
642
- else:
643
- print("Using blank map for InstantID (no face/disabled)")
644
- control_images.append(Image.new("RGB", (target_width, target_height), (0,0,0)))
645
- conditioning_scales.append(0.0)
646
- scale_debug_str.append("Identity: 0.00")
647
-
648
- if self.depth_active:
649
- control_images.append(depth_image)
650
- conditioning_scales.append(depth_control_scale)
651
- scale_debug_str.append(f"Depth ({self.depth_detector_name}): {depth_control_scale:.2f}")
652
-
653
- if self.openpose_active:
654
- control_images.append(openpose_image)
655
- conditioning_scales.append(expression_control_scale)
656
- scale_debug_str.append(f"Expression: {expression_control_scale:.2f}")
657
-
658
- if control_images:
659
- pipe_kwargs["control_image"] = control_images
660
- pipe_kwargs["controlnet_conditioning_scale"] = conditioning_scales
661
- print(f"Active ControlNets: {len(control_images)}")
662
  else:
663
- print("No active ControlNets, running standard Img2Img")
664
 
 
665
 
666
- print(f"Generating with LCM: Steps={num_inference_steps}, CFG={guidance_scale}, Strength={strength}")
667
- print(f"Controlnet scales - {' | '.join(scale_debug_str)}")
668
- result = self.pipe(**pipe_kwargs)
669
 
670
- generated_image = result.images[0]
671
 
672
- if enable_color_matching and has_detected_faces:
673
- print("Applying enhanced face-aware color matching...")
674
- try:
675
- if face_bbox_original is not None:
676
- generated_image = enhanced_color_match(
677
- generated_image,
678
- resized_image,
679
- face_bbox=face_bbox_original
680
- )
681
- print("[OK] Enhanced color matching applied (face-aware)")
682
- else:
683
- generated_image = color_match(generated_image, resized_image, mode='mkl')
684
- print("[OK] Standard color matching applied")
685
- except Exception as e:
686
- print(f"Color matching failed: {e}")
687
- elif enable_color_matching:
688
- print("Applying standard color matching...")
689
- try:
690
- generated_image = color_match(generated_image, resized_image, mode='mkl')
691
- print("[OK] Standard color matching applied")
692
- except Exception as e:
693
- print(f"Color matching failed: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
694
 
695
- return generated_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
696
 
697
 
698
- print("[OK] Generator class ready")
 
1
  """
2
+ Model loading and initialization for Pixagram AI Pixel Art Generator
3
+ FIXED VERSION with proper IP-Adapter and BLIP-2 support
4
  """
5
  import torch
6
+ import time
7
+ from diffusers import (
8
+ StableDiffusionXLControlNetImg2ImgPipeline,
9
+ ControlNetModel,
10
+ AutoencoderKL,
11
+ LCMScheduler
12
+ )
13
+ from diffusers.models.attention_processor import AttnProcessor2_0
14
+ from transformers import CLIPVisionModelWithProjection
15
+ from insightface.app import FaceAnalysis
16
+ # --- MODIFIED: Import detectors (LeReSDetector removed) ---
17
+ from controlnet_aux import ZoeDetector, OpenposeDetector, MidasDetector
18
+ from huggingface_hub import hf_hub_download
19
+ from compel import Compel, ReturnedEmbeddingsType
20
+
21
+ # Use reference implementation's attention processor
22
+ from attention_processor import IPAttnProcessor2_0, AttnProcessor
23
+ from resampler import Resampler
24
+
25
+ # --- ERROR WAS HERE ---
26
+ # The "from models import (...)" block that was here has been removed
27
+ # as it was causing a circular import.
28
+ # --- END FIX ---
29
 
30
  from config import (
31
+ device, dtype, MODEL_REPO, MODEL_FILES, HUGGINGFACE_TOKEN,
32
+ FACE_DETECTION_CONFIG, CLIP_SKIP, DOWNLOAD_CONFIG
 
 
 
 
 
 
 
 
 
 
33
  )
34
 
35
 
36
+ def download_model_with_retry(repo_id, filename, max_retries=None):
37
+ """Download model with retry logic and proper token handling."""
38
+ if max_retries is None:
39
+ max_retries = DOWNLOAD_CONFIG['max_retries']
40
 
41
+ for attempt in range(max_retries):
42
+ try:
43
+ print(f" Attempting to download {filename} (attempt {attempt + 1}/{max_retries})...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
+ kwargs = {"repo_type": "model"}
46
+ if HUGGINGFACE_TOKEN:
47
+ kwargs["token"] = HUGGINGFACE_TOKEN
48
+
49
+ path = hf_hub_download(
50
+ repo_id=repo_id,
51
+ filename=filename,
52
+ **kwargs
53
+ )
54
+ print(f" [OK] Downloaded: {filename}")
55
+ return path
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  except Exception as e:
58
+ print(f" [WARNING] Download attempt {attempt + 1} failed: {e}")
59
+
60
+ if attempt < max_retries - 1:
61
+ print(f" Retrying in {DOWNLOAD_CONFIG['retry_delay']} seconds...")
62
+ time.sleep(DOWNLOAD_CONFIG['retry_delay'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  else:
64
+ print(f" [ERROR] Failed to download {filename} after {max_retries} attempts")
65
+ raise
66
+
67
+ return None
68
 
69
+
70
+ def load_face_analysis():
71
+ """Load face analysis model with proper error handling."""
72
+ print("Loading face analysis model...")
73
+ try:
74
+ face_app = FaceAnalysis(
75
+ name=FACE_DETECTION_CONFIG['model_name'],
76
+ root='./models/insightface',
77
+ providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
78
+ )
79
+ face_app.prepare(
80
+ ctx_id=FACE_DETECTION_CONFIG['ctx_id'],
81
+ det_size=FACE_DETECTION_CONFIG['det_size']
82
+ )
83
+ print(" [OK] Face analysis model loaded successfully")
84
+ return face_app, True
85
+ except Exception as e:
86
+ print(f" [WARNING] Face detection not available: {e}")
87
+ return None, False
88
+
89
+ # --- MODIFIED FUNCTION: Depth Detector Fallback Chain (Zoe -> MiDaS) ---
90
+ def load_depth_models():
91
+ """
92
+ Load depth detector with fallback: Zoe -> MiDaS.
93
+ """
94
+ print("Loading depth detector...")
95
+
96
+ # 1. Try Zoe
97
+ try:
98
+ detector = ZoeDetector.from_pretrained("lllyasviel/Annotators")
99
+ detector.to(device)
100
+ print(" [OK] Using Zoe Depth detector")
101
+ return detector, "zoe", True
102
+ except Exception as e_zoe:
103
+ print(f" [INFO] Zoe failed ({e_zoe}), falling back to MiDaS...")
104
+
105
+ # 2. Try MiDaS
106
+ try:
107
+ detector = MidasDetector.from_pretrained("lllyasviel/Annotators")
108
+ detector.to(device)
109
+ print(" [OK] Using MiDaS Depth detector")
110
+ return detector, "midas", True
111
+ except Exception as e_midas:
112
+ print(f" [WARNING] All depth detectors failed ({e_midas})")
113
+ return None, "none", False
114
+ # --- END MODIFICATION ---
115
+
116
+
117
def load_openpose_detector():
    """Load the OpenPose annotator and move it to the configured device.

    Returns:
        tuple: ``(detector, True)`` on success, ``(None, False)`` otherwise.
    """
    print("Loading OpenPose detector...")
    try:
        pose_detector = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
        pose_detector.to(device)
    except Exception as e:
        print(f" [WARNING] OpenPose not available: {e}")
        return None, False
    else:
        print(" [OK] OpenPose loaded successfully")
        return pose_detector, True
128
+
129
+
130
def load_controlnets(depth_detector_name="zoe"):
    """Load the depth, OpenPose and InstantID ControlNet models.

    Args:
        depth_detector_name: ``"zoe"``, ``"midas"`` or ``"none"``. A name
            with no mapped repository is handled like ``"none"`` (the depth
            ControlNet is skipped).

    Returns:
        tuple: ``(controlnet_depth, controlnet_instantid,
        controlnet_openpose, instantid_ok)``. Any model that failed to load
        is ``None``; ``instantid_ok`` reports only the InstantID load.
    """
    # Depth ControlNet repository keyed by the detector that was loaded.
    # NOTE(review): confirm these repo ids (and the OpenPose one below)
    # exist on the Hub; a missing repo only surfaces as a WARNING here.
    repos_by_detector = {
        "zoe": "diffusers/controlnet-zoe-depth-sdxl-1.0",
        "midas": "diffusers/controlnet-midas-sdxl-1.0",
        "none": None,
    }
    depth_repo = repos_by_detector.get(depth_detector_name)
    controlnet_depth = None

    if not depth_repo:
        print(" [INFO] No depth detector loaded, skipping depth ControlNet.")
    else:
        print(f"Loading ControlNet Depth model for {depth_detector_name} ({depth_repo})...")
        try:
            controlnet_depth = ControlNetModel.from_pretrained(
                depth_repo,
                torch_dtype=dtype,
            ).to(device)
            print(f" [OK] ControlNet {depth_detector_name} Depth loaded")
        except Exception as e:
            print(f" [WARNING] Could not load {depth_detector_name} ControlNet: {e}")

    print("Loading ControlNet OpenPose model...")
    try:
        controlnet_openpose = ControlNetModel.from_pretrained(
            "diffusers/controlnet-openpose-sdxl-1.0",
            torch_dtype=dtype,
        ).to(device)
        print(" [OK] ControlNet OpenPose loaded")
    except Exception as e:
        print(f" [WARNING] ControlNet OpenPose not available: {e}")
        controlnet_openpose = None

    print("Loading InstantID ControlNet...")
    try:
        controlnet_instantid = ControlNetModel.from_pretrained(
            "InstantX/InstantID",
            subfolder="ControlNetModel",
            torch_dtype=dtype,
        ).to(device)
        print(" [OK] InstantID ControlNet loaded successfully")
    except Exception as e:
        print(f" [WARNING] InstantID ControlNet not available: {e}")
        return controlnet_depth, None, controlnet_openpose, False

    return controlnet_depth, controlnet_instantid, controlnet_openpose, True
179
+
180
+
181
def load_image_encoder():
    """Load the CLIP vision encoder published with IP-Adapter.

    Returns:
        The encoder moved to the configured device, or ``None`` if loading
        failed.
    """
    print("Loading CLIP Image Encoder for IP-Adapter...")
    try:
        encoder = CLIPVisionModelWithProjection.from_pretrained(
            "h94/IP-Adapter",
            subfolder="models/image_encoder",
            torch_dtype=dtype,
        )
        encoder = encoder.to(device)
    except Exception as e:
        print(f" [ERROR] Could not load image encoder: {e}")
        return None
    else:
        print(" [OK] CLIP Image Encoder loaded successfully")
        return encoder
195
+
196
+
197
def load_sdxl_pipeline(controlnets):
    """Build the SDXL ControlNet img2img pipeline.

    The custom single-file checkpoint is tried first; if anything in that
    path fails, the stock SDXL base model is loaded instead.

    Args:
        controlnets: ControlNet model(s) forwarded to the pipeline's
            ``controlnet`` argument.

    Returns:
        tuple: ``(pipe, True)`` when the custom checkpoint was used,
        ``(pipe, False)`` when the base model fallback was used. A failure
        of the fallback itself is not caught and propagates to the caller.
    """
    print("Loading SDXL checkpoint (horizon) with bundled VAE from HuggingFace Hub...")
    try:
        checkpoint_file = download_model_with_retry(MODEL_REPO, MODEL_FILES['checkpoint'])

        custom_pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_single_file(
            checkpoint_file,
            controlnet=controlnets,
            torch_dtype=dtype,
            use_safetensors=True,
        ).to(device)
        print(" [OK] Custom checkpoint loaded successfully (VAE bundled)")
        return custom_pipe, True
    except Exception as e:
        print(f" [WARNING] Could not load custom checkpoint: {e}")
        print(" Using default SDXL base model")

    # Fallback path: deliberately outside any try so errors are visible.
    base_pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        controlnet=controlnets,
        torch_dtype=dtype,
        use_safetensors=True,
    ).to(device)
    return base_pipe, False
221
+
222
+
223
def load_lora(pipe):
    """Download the LoRA weights and attach them to ``pipe`` as ``"retroart"``.

    Args:
        pipe: Pipeline exposing ``load_lora_weights``.

    Returns:
        bool: ``True`` if the weights were loaded, ``False`` on any error.
    """
    print("Loading LORA (retroart) from HuggingFace Hub...")
    try:
        weights_file = download_model_with_retry(MODEL_REPO, MODEL_FILES['lora'])
        pipe.load_lora_weights(weights_file, adapter_name="retroart")
    except Exception as e:
        print(f" [WARNING] Could not load LORA: {e}")
        return False
    else:
        print(" [OK] LORA loaded successfully")
        return True
234
+
235
+
236
def _split_instantid_state_dict(state_dict):
    """Split an InstantID checkpoint into image-projection and IP-Adapter weights.

    Two checkpoint layouts are accepted:

    * nested: ``{"image_proj": {...}, "ip_adapter": {...}}``
    * flat: keys prefixed with ``"image_proj."`` / ``"ip_adapter."``

    Returns:
        tuple: ``(image_proj_state_dict, ip_adapter_state_dict)``; either
        dict is empty when the checkpoint has no matching entries.
    """
    sections = {"image_proj": {}, "ip_adapter": {}}
    for section, weights in sections.items():
        nested = state_dict.get(section)
        if isinstance(nested, dict):
            weights.update(nested)

        prefix = section + "."
        for key, value in state_dict.items():
            if key.startswith(prefix):
                # Slice instead of str.replace so only the leading prefix
                # is removed, never a later occurrence inside the key.
                weights[key[len(prefix):]] = value
    return sections["image_proj"], sections["ip_adapter"]


def setup_ip_adapter(pipe, image_encoder):
    """Set up the InstantID IP-Adapter (Resampler + attention processors).

    Downloads ``ip-adapter.bin`` from ``InstantX/InstantID``, builds a
    ``Resampler`` image-projection model, installs IP-Adapter attention
    processors on ``pipe.unet`` and attaches ``image_encoder`` to ``pipe``.

    Args:
        pipe: Pipeline whose ``unet`` receives the attention processors.
        image_encoder: Image encoder to attach; when ``None`` the setup is
            skipped entirely.

    Returns:
        tuple: ``(image_proj_model, True)`` on success, ``(None, False)``
        when ``image_encoder`` is ``None`` or setup raised.
    """
    if image_encoder is None:
        return None, False

    print("Setting up IP-Adapter for InstantID face embeddings (proper implementation)...")
    try:
        # Download InstantID weights
        ip_adapter_path = download_model_with_retry(
            "InstantX/InstantID",
            "ip-adapter.bin"
        )

        # SECURITY NOTE: torch.load unpickles the file; the checkpoint is
        # trusted only as far as the upstream Hub repository is trusted.
        state_dict = torch.load(ip_adapter_path, map_location="cpu")

        # Accept both the nested and the flat checkpoint layouts; looking
        # only for flat prefixes leaves both dicts empty on a nested file.
        image_proj_state_dict, ip_adapter_state_dict = _split_instantid_state_dict(state_dict)

        print("Creating Resampler (Perceiver architecture) with custom settings...")
        unet_config = pipe.unet.config
        # NOTE(review): the reference InstantID pipeline appears to build the
        # Resampler with depth=4 / num_queries=16. With strict=True below, a
        # shape mismatch would land in the "randomly initialized" warning
        # branch - confirm these custom settings match the checkpoint.
        image_proj_model = Resampler(
            dim=1280,
            depth=8,
            dim_head=64,
            heads=20,
            num_queries=32,
            embedding_dim=512,
            output_dim=unet_config.cross_attention_dim,
            ff_mult=4
        )

        image_proj_model.eval()
        image_proj_model = image_proj_model.to(device, dtype=dtype)

        # Load image_proj weights
        if image_proj_state_dict:
            try:
                image_proj_model.load_state_dict(image_proj_state_dict, strict=True)
                print(" [OK] Resampler loaded with pretrained weights")
            except Exception as e:
                print(f" [WARNING] Could not load Resampler weights: {e}")
                print(" Using randomly initialized Resampler")
        else:
            print(" [WARNING] No image_proj weights found, using random initialization")

        # Setup IP-Adapter attention processors
        print("Setting up IP-Adapter attention processors...")
        attn_procs = {}
        num_tokens = 32
        block_out_channels = unet_config.block_out_channels

        for name in pipe.unet.attn_processors.keys():
            # Self-attention (attn1) layers have no cross-attention input.
            cross_attention_dim = (
                None if name.endswith("attn1.processor")
                else unet_config.cross_attention_dim
            )
            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor2_0()
                continue

            if name.startswith("up_blocks"):
                # "up_blocks.<id>...." - up blocks walk the channels in reverse.
                block_id = int(name.split(".")[1])
                hidden_size = list(reversed(block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name.split(".")[1])
                hidden_size = block_out_channels[block_id]
            else:
                # mid_block and any other processor use the last channel count.
                hidden_size = block_out_channels[-1]

            attn_procs[name] = IPAttnProcessor2_0(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                scale=1.0,
                num_tokens=num_tokens
            ).to(device, dtype=dtype)

        pipe.unet.set_attn_processor(attn_procs)

        if ip_adapter_state_dict:
            try:
                ip_layers = torch.nn.ModuleList(pipe.unet.attn_processors.values())
                ip_layers.load_state_dict(ip_adapter_state_dict, strict=False)
                print(" [OK] IP-Adapter attention weights loaded")
            except Exception as e:
                print(f" [WARNING] Could not load IP-Adapter weights: {e}")
        else:
            print(" [WARNING] No ip_adapter weights found")

        pipe.image_encoder = image_encoder

        print(" [OK] IP-Adapter fully loaded with InstantID architecture")
        print(" - Resampler: 8 layers, 20 heads, 32 output tokens")
        print(" - Face embeddings: 512D -> 32x2048D")

        return image_proj_model, True

    except Exception as e:
        print(f" [ERROR] Could not setup IP-Adapter: {e}")
        import traceback
        traceback.print_exc()
        return None, False
345
+
346
+
347
def setup_compel(pipe):
    """Create a Compel prompt processor wired to both SDXL text encoders.

    Args:
        pipe: Pipeline providing ``tokenizer``/``tokenizer_2`` and
            ``text_encoder``/``text_encoder_2``.

    Returns:
        tuple: ``(compel, True)`` on success, ``(None, False)`` otherwise.
    """
    print("Setting up Compel for enhanced prompt processing...")
    try:
        tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
        text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
        prompt_processor = Compel(
            tokenizer=tokenizers,
            text_encoder=text_encoders,
            returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
            # Pooled output is requested from the second encoder only.
            requires_pooled=[False, True],
        )
    except Exception as e:
        print(f" [WARNING] Compel not available: {e}")
        return None, False
    else:
        print(" [OK] Compel loaded successfully")
        return prompt_processor, True
362
+
363
+
364
def setup_scheduler(pipe):
    """Replace ``pipe.scheduler`` with an ``LCMScheduler`` built from its config.

    Args:
        pipe: Pipeline whose scheduler is swapped in place.
    """
    print("Setting up LCM scheduler...")
    existing_config = pipe.scheduler.config
    pipe.scheduler = LCMScheduler.from_config(existing_config)
    print(" [OK] LCM scheduler configured")
369
+
370
+
371
def optimize_pipeline(pipe):
    """Enable xformers memory-efficient attention when running on CUDA.

    Args:
        pipe: Pipeline to optimise in place. Nothing is done when the
            configured device is not ``"cuda"``.
    """
    if device != "cuda":
        return

    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception as e:
        # Best-effort optimisation: report and continue without xformers.
        print(f" [INFO] xformers not available: {e}")
    else:
        print(" [OK] xformers enabled")
379
+
380
+
381
def load_caption_model():
    """Load an image-captioning model, trying GIT-Large and then BLIP base.

    Returns:
        tuple: ``(processor, model, enabled, kind)`` where ``kind`` is
        ``'git'``, ``'blip'`` or ``'none'``. When both attempts fail the
        result is ``(None, None, False, 'none')``.
    """
    print("Loading caption model...")

    # First choice: GIT-Large.
    try:
        from transformers import AutoProcessor, AutoModelForCausalLM

        print(" Attempting GIT-Large (recommended)...")
        git_processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
        git_model = AutoModelForCausalLM.from_pretrained(
            "microsoft/git-large-coco",
            torch_dtype=dtype,
        ).to(device)
        print(" [OK] GIT-Large model loaded (produces detailed captions)")
        return git_processor, git_model, True, 'git'
    except Exception as e1:
        print(f" [INFO] GIT-Large not available: {e1}")

    # Fallback: BLIP base.
    try:
        from transformers import BlipProcessor, BlipForConditionalGeneration

        print(" Attempting BLIP base (fallback)...")
        blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        blip_model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-base",
            torch_dtype=dtype,
        ).to(device)
        print(" [OK] BLIP base model loaded (standard captions)")
        return blip_processor, blip_model, True, 'blip'
    except Exception as e2:
        print(f" [WARNING] Caption models not available: {e2}")
        print(" Caption generation will be disabled")
        return None, None, False, 'none'
417
+
418
+
419
def set_clip_skip(pipe):
    """Report the configured CLIP skip value for ``pipe``.

    This function does not modify ``pipe``: it only logs ``CLIP_SKIP`` when
    the pipeline has a ``text_encoder`` attribute. The log message says so,
    instead of claiming the value was set.

    Args:
        pipe: Pipeline object; only checked for a ``text_encoder`` attribute.

    Returns:
        None
    """
    if not hasattr(pipe, 'text_encoder'):
        return
    # NOTE(review): diffusers pipelines take ``clip_skip`` as a call-time
    # argument - confirm the generation code passes it.
    print(f" [INFO] CLIP skip value is {CLIP_SKIP} (not applied here; pass clip_skip= to the pipeline call)")
423
 
424
 
425
# Import-time banner: printed once all loader functions above are defined.
print("[OK] Model loading functions ready")