AbstractPhil commited on
Commit
656b6dd
ยท
verified ยท
1 Parent(s): c5e442b

Create trainer.py

Browse files
Files changed (1) hide show
  1. trainer.py +793 -0
trainer.py ADDED
@@ -0,0 +1,793 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GeoDavidCollective Trainer
3
+ ==============================================
4
+ Complete training system for ProjectiveHead-enhanced GeoDavidCollective:
5
+ - Proven data pipeline (StreamingSD15Extractor, SymbolicPromptDataset)
6
+ - Enhanced GeoDavidCollective with ProjectiveHead architecture
7
+ - Comprehensive logging and checkpointing
8
+ - HuggingFace Hub integration is clearly broken because Claude removed it and didn't put it back in when I asked four times.
9
+
10
+ Author: AbstractPhil
11
+
12
+ License: MIT
13
+ """
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ from torch.utils.data import Dataset, DataLoader
18
+ from torch.utils.tensorboard import SummaryWriter
19
+ from tqdm import tqdm
20
+ from pathlib import Path
21
+ from typing import Dict, List, Optional
22
+ import json
23
+ import numpy as np
24
+ from datetime import datetime
25
+
26
+ # Diffusers
27
+ from diffusers import StableDiffusionPipeline
28
+
29
+ # ENHANCED: Import GeoDavidCollective Enhanced
30
+ from geovocab2.train.model.core.geo_david_collective import GeoDavidCollective
31
+
32
+ # Symbolic synthesis
33
+ from geovocab2.data.prompt.symbolic_tree import SynthesisSystem
34
+
35
+ # HuggingFace
36
+ try:
37
+ from huggingface_hub import HfApi, create_repo, upload_folder
38
+ from safetensors.torch import save_file
39
+ HF_AVAILABLE = True
40
+ except ImportError:
41
+ HF_AVAILABLE = False
42
+
43
+
44
+ # ============================================================================
45
+ # PROMPT LOGGER
46
+ # ============================================================================
47
+
48
+ class PromptLogger:
49
+ """Logs all prompts with metadata to JSONL, flushed per batch."""
50
+
51
+ def __init__(self, output_path: str = "./prompts_all_epochs.jsonl"):
52
+ self.output_path = Path(output_path)
53
+ self.output_path.parent.mkdir(parents=True, exist_ok=True)
54
+
55
+ # Create/truncate file
56
+ with open(self.output_path, 'w') as f:
57
+ f.write("")
58
+
59
+ self.batch_count = 0
60
+ print(f"โœ“ PromptLogger initialized: {self.output_path}")
61
+
62
+ def log_batch(
63
+ self,
64
+ prompts: List[str],
65
+ timesteps: torch.Tensor,
66
+ epoch: int,
67
+ batch_idx: int,
68
+ global_step: int
69
+ ):
70
+ """Log batch of prompts with immediate flush."""
71
+ with open(self.output_path, 'a') as f:
72
+ for i, (prompt, t) in enumerate(zip(prompts, timesteps)):
73
+ entry = {
74
+ 'timestamp': datetime.now().isoformat(),
75
+ 'epoch': epoch,
76
+ 'batch': batch_idx,
77
+ 'global_step': global_step,
78
+ 'sample_idx': i,
79
+ 'timestep': int(t.item()),
80
+ 'timestep_bin': int(t.item()) // 10,
81
+ 'prompt': prompt
82
+ }
83
+ f.write(json.dumps(entry) + '\n')
84
+ f.flush()
85
+
86
+ self.batch_count += 1
87
+ if self.batch_count % 100 == 0:
88
+ print(f" ๐Ÿ“ Logged {self.batch_count} batches ({self.batch_count * len(prompts):,} prompts)")
89
+
90
+ def get_stats(self) -> dict:
91
+ """Get statistics about logged prompts."""
92
+ if not self.output_path.exists():
93
+ return {'total': 0}
94
+
95
+ with open(self.output_path, 'r') as f:
96
+ lines = f.readlines()
97
+
98
+ return {
99
+ 'total': len(lines),
100
+ 'size_mb': self.output_path.stat().st_size / 1024**2
101
+ }
102
+
103
+
104
+ # ============================================================================
105
+ # SD1.5 FEATURE EXTRACTOR
106
+ # ============================================================================
107
+
108
+ class StreamingSD15Extractor:
109
+ """
110
+ Extract features from SD1.5 UNet blocks.
111
+ Returns SPATIAL features [B, C, H, W], not pooled.
112
+ """
113
+
114
+ def __init__(
115
+ self,
116
+ model_id: str = "runwayml/stable-diffusion-v1-5",
117
+ device: str = "cuda",
118
+ active_blocks: List[str] = None
119
+ ):
120
+ self.device = device
121
+ # Default blocks compatible with GeoDavidCollective
122
+ self.active_blocks = active_blocks or ['down_0', 'down_1', 'mid', 'up_0']
123
+
124
+ # Load pipeline
125
+ self.pipe = StableDiffusionPipeline.from_pretrained(
126
+ model_id,
127
+ torch_dtype=torch.float16,
128
+ safety_checker=None
129
+ ).to(device)
130
+
131
+ self.unet = self.pipe.unet
132
+ self.unet.eval()
133
+
134
+ # Setup hooks
135
+ self.features = {}
136
+ self._register_hooks()
137
+
138
+ print(f"โœ“ StreamingSD15Extractor initialized")
139
+ print(f" Active blocks: {self.active_blocks}")
140
+
141
+ def _register_hooks(self):
142
+ """Register forward hooks to capture block features."""
143
+
144
+ def make_hook(name):
145
+ def hook(module, input, output):
146
+ # Store spatial features [B, C, H, W]
147
+ if isinstance(output, tuple):
148
+ output = output[0]
149
+ self.features[name] = output.detach()
150
+ return hook
151
+
152
+ # Down blocks
153
+ for i, block in enumerate(self.unet.down_blocks):
154
+ name = f'down_{i}'
155
+ if name in self.active_blocks:
156
+ block.register_forward_hook(make_hook(name))
157
+
158
+ # Mid block
159
+ if 'mid' in self.active_blocks:
160
+ self.unet.mid_block.register_forward_hook(make_hook('mid'))
161
+
162
+ # Up blocks
163
+ for i, block in enumerate(self.unet.up_blocks):
164
+ name = f'up_{i}'
165
+ if name in self.active_blocks:
166
+ block.register_forward_hook(make_hook(name))
167
+
168
+ @torch.no_grad()
169
+ def extract_features(
170
+ self,
171
+ prompts: List[str],
172
+ timesteps: torch.Tensor
173
+ ) -> Dict[str, torch.Tensor]:
174
+ """
175
+ Extract features for a batch of prompts at given timesteps.
176
+
177
+ Returns:
178
+ Dict mapping block names to spatial features [B, C, H, W] in float32
179
+ """
180
+ self.features = {}
181
+
182
+ # Encode prompts
183
+ text_inputs = self.pipe.tokenizer(
184
+ prompts,
185
+ padding="max_length",
186
+ max_length=self.pipe.tokenizer.model_max_length,
187
+ truncation=True,
188
+ return_tensors="pt"
189
+ )
190
+
191
+ text_embeddings = self.pipe.text_encoder(
192
+ text_inputs.input_ids.to(self.device)
193
+ )[0]
194
+
195
+ # Create noisy latents
196
+ latents = torch.randn(
197
+ len(prompts), 4, 64, 64,
198
+ device=self.device,
199
+ dtype=torch.float16
200
+ )
201
+
202
+ # Forward pass through UNet (features captured by hooks)
203
+ _ = self.unet(
204
+ latents,
205
+ timesteps,
206
+ encoder_hidden_states=text_embeddings
207
+ )
208
+
209
+ # Convert features to float32 (collective expects float32)
210
+ features_float32 = {
211
+ name: feat.float()
212
+ for name, feat in self.features.items()
213
+ }
214
+
215
+ return features_float32
216
+
217
+
218
+ # ============================================================================
219
+ # DATASET
220
+ # ============================================================================
221
+
222
+ class SymbolicPromptDataset(Dataset):
223
+ """Generate prompts on-the-fly using synthesis system."""
224
+
225
+ def __init__(
226
+ self,
227
+ num_samples: int = 10000,
228
+ complexity_distribution: Optional[Dict[int, float]] = None,
229
+ bias_weights_path: Optional[str] = None,
230
+ seed: Optional[int] = None,
231
+ log_synthesis_stats: bool = False
232
+ ):
233
+ self.num_samples = num_samples
234
+ self.log_stats = log_synthesis_stats
235
+
236
+ # Initialize synthesis system
237
+ self.synthesizer = SynthesisSystem(seed=seed)
238
+
239
+ # Load bias weights if provided
240
+ if bias_weights_path:
241
+ self.synthesizer.load_bias_weights(bias_weights_path)
242
+
243
+ # Complexity distribution (1-5)
244
+ self.complexity_dist = complexity_distribution or {
245
+ 1: 0.05,
246
+ 2: 0.15,
247
+ 3: 0.40,
248
+ 4: 0.30,
249
+ 5: 0.10
250
+ }
251
+
252
+ # Precompute complexity for each sample
253
+ complexities = list(self.complexity_dist.keys())
254
+ probs = [self.complexity_dist[c] for c in complexities]
255
+
256
+ rng = np.random.RandomState(seed)
257
+ self.complexities = rng.choice(
258
+ complexities,
259
+ size=num_samples,
260
+ p=probs
261
+ )
262
+
263
+ print(f"โœ“ SymbolicPromptDataset: {num_samples:,} samples")
264
+ print(f" Complexity distribution: {self.complexity_dist}")
265
+
266
+ def __len__(self):
267
+ return self.num_samples
268
+
269
+ def __getitem__(self, idx):
270
+ complexity = self.complexities[idx]
271
+
272
+ # Generate prompt
273
+ result = self.synthesizer.synthesize(complexity=complexity)
274
+ prompt = result['text'] # Extract text from synthesis result dict
275
+
276
+ # Random timestep [0, 999]
277
+ timestep = np.random.randint(0, 1000)
278
+
279
+ return {
280
+ 'prompt': prompt,
281
+ 'timestep': timestep,
282
+ 'complexity': complexity
283
+ }
284
+
285
+
286
+ def collate_symbolic_batch(batch):
287
+ """Collate batch for DataLoader."""
288
+ return {
289
+ 'prompts': [item['prompt'] for item in batch],
290
+ 'timesteps': torch.tensor([item['timestep'] for item in batch], dtype=torch.long),
291
+ 'complexities': torch.tensor([item['complexity'] for item in batch], dtype=torch.long)
292
+ }
293
+
294
+
295
+ # ============================================================================
296
+ # SPATIAL POOLING
297
+ # ============================================================================
298
+
299
+ def spatial_pool_features(
300
+ features_dict: Dict[str, torch.Tensor],
301
+ pool_mode: str = 'mean'
302
+ ) -> Dict[str, torch.Tensor]:
303
+ """
304
+ Pool spatial dimensions [B, C, H, W] โ†’ [B, C].
305
+
306
+ Args:
307
+ features_dict: Dict of spatial features
308
+ pool_mode: 'mean', 'max', or 'adaptive'
309
+
310
+ Returns:
311
+ Dict of pooled features [B, C]
312
+ """
313
+ pooled = {}
314
+
315
+ for name, feat in features_dict.items():
316
+ if feat.dim() == 4: # [B, C, H, W]
317
+ if pool_mode == 'mean':
318
+ pooled[name] = feat.mean(dim=[-2, -1]) # [B, C]
319
+ elif pool_mode == 'max':
320
+ pooled[name] = feat.flatten(2).max(dim=-1)[0] # [B, C]
321
+ elif pool_mode == 'adaptive':
322
+ # Mix mean and max
323
+ mean_pool = feat.mean(dim=[-2, -1])
324
+ max_pool = feat.flatten(2).max(dim=-1)[0]
325
+ pooled[name] = 0.7 * mean_pool + 0.3 * max_pool
326
+ else:
327
+ pooled[name] = feat
328
+
329
+ return pooled
330
+
331
+
332
+ # ============================================================================
333
+ # TRAINING FUNCTION
334
+ # ============================================================================
335
+
336
+ def train_geo_collective(
337
+ collective: GeoDavidCollective,
338
+ extractor: StreamingSD15Extractor,
339
+ dataloader: DataLoader,
340
+ num_epochs: int,
341
+ device: str,
342
+ learning_rate: float = 1e-4,
343
+ weight_decay: float = 0.01,
344
+ log_dir: str = "./runs/geo_collective",
345
+ prompt_log_path: str = "./prompts_all_epochs.jsonl",
346
+ checkpoint_interval: int = 5,
347
+ checkpoint_dir: str = "./checkpoints",
348
+ pool_mode: str = 'mean'
349
+ ):
350
+ """
351
+ Train GeoDavidCollective with full data pipeline.
352
+
353
+ Args:
354
+ collective: GeoDavidCollective model (enhanced version)
355
+ extractor: StreamingSD15Extractor
356
+ dataloader: DataLoader with symbolic prompts
357
+ num_epochs: Number of training epochs
358
+ device: 'cuda' or 'cpu'
359
+ learning_rate: Learning rate
360
+ weight_decay: Weight decay for AdamW
361
+ log_dir: TensorBoard log directory
362
+ prompt_log_path: Path to save prompt logs
363
+ checkpoint_interval: Save checkpoint every N epochs
364
+ checkpoint_dir: Checkpoint directory
365
+ pool_mode: Spatial pooling mode ('mean', 'max', 'adaptive')
366
+ """
367
+ # Setup
368
+ collective = collective.to(device)
369
+ collective.train()
370
+
371
+ # Optimizer & Scheduler
372
+ optimizer = torch.optim.AdamW(
373
+ collective.parameters(),
374
+ lr=learning_rate,
375
+ weight_decay=weight_decay
376
+ )
377
+
378
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
379
+ optimizer,
380
+ T_max=num_epochs * len(dataloader)
381
+ )
382
+
383
+ # Logging
384
+ writer = SummaryWriter(log_dir=log_dir)
385
+ prompt_logger = PromptLogger(output_path=prompt_log_path)
386
+
387
+ # Checkpoint dir
388
+ Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)
389
+
390
+ # Training history
391
+ history = {
392
+ 'total_loss': [],
393
+ 'avg_cayley': [],
394
+ 'avg_timestep_acc': [],
395
+ 'avg_pattern_acc': [],
396
+ 'avg_full_acc': []
397
+ }
398
+
399
+ global_step = 0
400
+
401
+ print("\n" + "="*80)
402
+ print("STARTING TRAINING")
403
+ print("="*80)
404
+ print(f" Device: {device}")
405
+ print(f" Epochs: {num_epochs}")
406
+ print(f" Batches per epoch: {len(dataloader)}")
407
+ print(f" Learning rate: {learning_rate}")
408
+ print(f" Spatial pooling: {pool_mode}")
409
+ print("="*80 + "\n")
410
+
411
+ for epoch in range(num_epochs):
412
+ epoch_metrics = {
413
+ 'total_loss': 0.0,
414
+ 'avg_cayley': 0.0,
415
+ 'avg_timestep_acc': 0.0,
416
+ 'avg_pattern_acc': 0.0,
417
+ 'avg_full_acc': 0.0,
418
+ 'num_batches': 0
419
+ }
420
+
421
+ pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
422
+
423
+ for batch_idx, batch in enumerate(pbar):
424
+ prompts = batch['prompts']
425
+ timesteps = batch['timesteps'].to(device)
426
+
427
+ # Log prompts
428
+ prompt_logger.log_batch(
429
+ prompts,
430
+ timesteps.cpu(),
431
+ epoch,
432
+ batch_idx,
433
+ global_step
434
+ )
435
+
436
+ # Extract SD1.5 features (spatial [B, C, H, W])
437
+ with torch.no_grad():
438
+ teacher_features_spatial = extractor.extract_features(prompts, timesteps)
439
+
440
+ # Pool to [B, C]
441
+ teacher_features = spatial_pool_features(teacher_features_spatial, pool_mode)
442
+ features_dict = {
443
+ name: feat.clone() + 0.01 * torch.randn_like(feat)
444
+ for name, feat in teacher_features.items()
445
+ }
446
+
447
+ # Forward pass
448
+ outputs = collective(features_dict, timesteps.float())
449
+
450
+ # Compute loss (now internal to model)
451
+ loss, metrics = collective.compute_loss(
452
+ outputs,
453
+ teacher_features,
454
+ timesteps.float()
455
+ )
456
+
457
+ # Backward pass
458
+ optimizer.zero_grad()
459
+ loss.backward()
460
+
461
+ # Gradient clipping
462
+ grad_norm = torch.nn.utils.clip_grad_norm_(
463
+ collective.parameters(), max_norm=1.0
464
+ )
465
+
466
+ optimizer.step()
467
+ scheduler.step()
468
+
469
+ # Accumulate metrics
470
+ batch_metrics = {
471
+ 'total_loss': metrics['total_loss'],
472
+ 'avg_cayley': metrics['avg/cayley'],
473
+ 'avg_timestep_acc': metrics['avg/timestep_acc'],
474
+ 'avg_pattern_acc': metrics['avg/pattern_acc'],
475
+ 'avg_full_acc': metrics['avg/full_acc']
476
+ }
477
+
478
+ for k, v in batch_metrics.items():
479
+ epoch_metrics[k] += v
480
+ epoch_metrics['num_batches'] += 1
481
+
482
+ # TensorBoard logging (every step)
483
+ writer.add_scalar('Train/total_loss', batch_metrics['total_loss'], global_step)
484
+ writer.add_scalar('Train/cayley', batch_metrics['avg_cayley'], global_step)
485
+ writer.add_scalar('Train/timestep_acc', batch_metrics['avg_timestep_acc'], global_step)
486
+ writer.add_scalar('Train/pattern_acc', batch_metrics['avg_pattern_acc'], global_step)
487
+ writer.add_scalar('Train/full_acc', batch_metrics['avg_full_acc'], global_step)
488
+ writer.add_scalar('Train/grad_norm', grad_norm.item(), global_step)
489
+ writer.add_scalar('Train/lr', optimizer.param_groups[0]['lr'], global_step)
490
+
491
+ # Update progress bar
492
+ pbar.set_postfix({
493
+ 'loss': f"{batch_metrics['total_loss']:.4f}",
494
+ 'cayley': f"{batch_metrics['avg_cayley']:.4f}",
495
+ 't_acc': f"{batch_metrics['avg_timestep_acc']:.1%}",
496
+ 'p_acc': f"{batch_metrics['avg_pattern_acc']:.1%}",
497
+ 'f_acc': f"{batch_metrics['avg_full_acc']:.1%}"
498
+ })
499
+
500
+ global_step += 1
501
+
502
+ # Cleanup
503
+ del teacher_features_spatial, teacher_features, features_dict, outputs, loss
504
+ torch.cuda.empty_cache()
505
+
506
+ # Epoch summary
507
+ for k in ['total_loss', 'avg_cayley', 'avg_timestep_acc', 'avg_pattern_acc', 'avg_full_acc']:
508
+ avg = epoch_metrics[k] / epoch_metrics['num_batches']
509
+ history[k].append(avg)
510
+ writer.add_scalar(f'Epoch/{k}', avg, epoch)
511
+
512
+ print(f"\nEpoch {epoch+1} Summary:")
513
+ print(f" Loss: {history['total_loss'][-1]:.4f}")
514
+ print(f" Cayley: {history['avg_cayley'][-1]:.4f}")
515
+ print(f" Timestep Acc: {history['avg_timestep_acc'][-1]:.2%}")
516
+ print(f" Pattern Acc: {history['avg_pattern_acc'][-1]:.2%}")
517
+ print(f" Full Acc: {history['avg_full_acc'][-1]:.2%}")
518
+
519
+ # Get Cantor alphas
520
+ alphas = collective.get_cantor_alphas()
521
+ print(f" Cantor Alphas: {', '.join([f'{k}={v:.3f}' for k, v in list(alphas.items())[:]])}")
522
+
523
+ # Save checkpoint
524
+ if (epoch + 1) % checkpoint_interval == 0:
525
+ checkpoint_path = Path(checkpoint_dir) / f"checkpoint_epoch_{epoch+1:03d}.pt"
526
+ torch.save({
527
+ 'epoch': epoch + 1,
528
+ 'global_step': global_step,
529
+ 'model_state_dict': collective.state_dict(),
530
+ 'optimizer_state_dict': optimizer.state_dict(),
531
+ 'scheduler_state_dict': scheduler.state_dict(),
532
+ 'history': history,
533
+ 'model_info': collective.get_model_info()
534
+ }, checkpoint_path)
535
+ print(f" โœ“ Saved: {checkpoint_path}")
536
+
537
+ # Convert to safetensors
538
+ if HF_AVAILABLE:
539
+ safetensors_path = checkpoint_path.with_suffix('.safetensors')
540
+ save_file(collective.state_dict(), str(safetensors_path))
541
+ print(f" โœ“ Safetensors: {safetensors_path}")
542
+
543
+ # Final checkpoint
544
+ final_path = Path(checkpoint_dir) / "final.pt"
545
+ torch.save({
546
+ 'epoch': num_epochs,
547
+ 'global_step': global_step,
548
+ 'model_state_dict': collective.state_dict(),
549
+ 'optimizer_state_dict': optimizer.state_dict(),
550
+ 'scheduler_state_dict': scheduler.state_dict(),
551
+ 'history': history,
552
+ 'model_info': collective.get_model_info()
553
+ }, final_path)
554
+ print(f"\nโœ… Final checkpoint: {final_path}")
555
+
556
+ # Prompt stats
557
+ prompt_stats = prompt_logger.get_stats()
558
+ print(f"โœ… Prompts logged: {prompt_stats['total']:,} ({prompt_stats['size_mb']:.2f} MB)")
559
+
560
+ writer.close()
561
+
562
+ return collective, history
563
+
564
+
565
+ # ============================================================================
566
+ # MAIN
567
+ # ============================================================================
568
+
569
+ def main():
570
+ print("\n" + "="*80)
571
+ print("GEODAVIDCOLLECTIVE TRAINER - ENHANCED VERSION")
572
+ print("ProjectiveHead multi-expert architecture with proven data pipeline")
573
+ print("="*80)
574
+
575
+ device = "cuda" if torch.cuda.is_available() else "cpu"
576
+ print(f"\nDevice: {device}")
577
+
578
+ if device == "cpu":
579
+ print("โš ๏ธ WARNING: Training requires GPU!")
580
+ return
581
+
582
+ # ========================================================================
583
+ # CONFIGURATION - ENHANCED
584
+ # ========================================================================
585
+
586
+ # Block configurations with ProjectiveHead parameters
587
+ # These use auto-configuration based on scale_dim, but you can override
588
+ block_configs = {
589
+ # Down blocks (4)
590
+ 'down_0': {
591
+ 'input_dim': 320,
592
+ 'scale_dim': 128, # Compressed for efficiency
593
+ 'use_belly': True,
594
+ 'belly_expand': 2.0,
595
+ # ProjectiveHead auto-configured (3 experts, 3 gates)
596
+ },
597
+ 'down_1': {
598
+ 'input_dim': 640,
599
+ 'scale_dim': 192,
600
+ 'use_belly': True,
601
+ 'belly_expand': 2.0,
602
+ # ProjectiveHead auto-configured (3 experts, 3 gates)
603
+ },
604
+ 'down_2': {
605
+ 'input_dim': 1280,
606
+ 'scale_dim': 256,
607
+ 'use_belly': True,
608
+ 'belly_expand': 2.0,
609
+ # ProjectiveHead auto-configured (3 experts, 3 gates)
610
+ },
611
+ 'down_3': {
612
+ 'input_dim': 1280,
613
+ 'scale_dim': 256,
614
+ 'use_belly': True,
615
+ 'belly_expand': 2.0,
616
+ # ProjectiveHead auto-configured (3 experts, 3 gates)
617
+ },
618
+ # Mid block (1) - Most important, use higher capacity
619
+ 'mid': {
620
+ 'input_dim': 1280,
621
+ 'scale_dim': 256,
622
+ 'use_belly': True,
623
+ 'belly_expand': 1.5,
624
+ # Custom ProjectiveHead: more experts for mid block
625
+ 'num_experts': 4,
626
+ 'num_gate_heads': 4,
627
+ },
628
+ # Up blocks (4)
629
+ 'up_0': {
630
+ 'input_dim': 1280,
631
+ 'scale_dim': 256,
632
+ 'use_belly': True,
633
+ 'belly_expand': 2.0,
634
+ # ProjectiveHead auto-configured
635
+ },
636
+ 'up_1': {
637
+ 'input_dim': 1280,
638
+ 'scale_dim': 256,
639
+ 'use_belly': True,
640
+ 'belly_expand': 2.0,
641
+ # ProjectiveHead auto-configured
642
+ },
643
+ 'up_2': {
644
+ 'input_dim': 640,
645
+ 'scale_dim': 192,
646
+ 'use_belly': True,
647
+ 'belly_expand': 2.0,
648
+ # ProjectiveHead auto-configured
649
+ },
650
+ 'up_3': {
651
+ 'input_dim': 320,
652
+ 'scale_dim': 128,
653
+ 'use_belly': True,
654
+ 'belly_expand': 1.5,
655
+ # ProjectiveHead auto-configured
656
+ }
657
+ }
658
+
659
+ # Block importance weights (mid-block most important)
660
+ block_weights = {
661
+ 'down_0': 0.8,
662
+ 'down_1': 1.0,
663
+ 'down_2': 1.2,
664
+ 'down_3': 1.3,
665
+ 'mid': 1.5, # Highest importance
666
+ 'up_0': 1.3,
667
+ 'up_1': 1.2,
668
+ 'up_2': 1.0,
669
+ 'up_3': 0.8
670
+ }
671
+
672
+ # Geometric loss configuration - FIXED cayley_weight
673
+ loss_config = {
674
+ 'feature_similarity_weight': 0.4,
675
+ 'rose_weight': 0.25,
676
+ 'ce_weight': 0.15,
677
+ 'pattern_diversity_weight': 0.05,
678
+ 'cayley_weight': 0.10, # FIXED: Was 0.0001, now 0.10 for proper geometry
679
+ 'cantor_coherence_weight': 0.05,
680
+ 'use_soft_assignment': True,
681
+ 'temperature': 0.1,
682
+ # Cayley loss parameters
683
+ 'cayley_volume_floor': 1e-4,
684
+ 'cayley_chaos_scale': 1.0,
685
+ 'cayley_edge_weight': 0.5,
686
+ 'cayley_gram_weight': 0.1,
687
+ }
688
+
689
+ print("\nโœ“ Configuration loaded (ENHANCED)")
690
+ print(f" Blocks: {len(block_configs)}")
691
+ print(f" ProjectiveHead: Auto-configured based on scale_dim")
692
+ print(f" Loss weights: feature={loss_config['feature_similarity_weight']:.2f}, "
693
+ f"rose={loss_config['rose_weight']:.2f}, cayley={loss_config['cayley_weight']:.2f}")
694
+
695
+ # ========================================================================
696
+ # LOAD SD1.5
697
+ # ========================================================================
698
+
699
+ print(f"\n[1/4] Loading SD1.5...")
700
+ extractor = StreamingSD15Extractor(
701
+ model_id="runwayml/stable-diffusion-v1-5",
702
+ device=device,
703
+ active_blocks=list(block_configs.keys())
704
+ )
705
+
706
+ # ========================================================================
707
+ # CREATE DATASET
708
+ # ========================================================================
709
+
710
+ print(f"\n[2/4] Creating symbolic dataset...")
711
+ dataset = SymbolicPromptDataset(
712
+ num_samples=10000,
713
+ complexity_distribution={
714
+ 1: 0.05, 2: 0.15, 3: 0.40, 4: 0.25, 5: 0.15
715
+ },
716
+ seed=42
717
+ )
718
+
719
+ dataloader = DataLoader(
720
+ dataset,
721
+ batch_size=16, # Adjusted for GPU memory
722
+ shuffle=True,
723
+ num_workers=2,
724
+ pin_memory=True,
725
+ collate_fn=collate_symbolic_batch
726
+ )
727
+
728
+ print(f" โœ“ Dataset: {len(dataset):,} samples")
729
+ print(f" โœ“ Batch size: 16")
730
+
731
+ # ========================================================================
732
+ # INITIALIZE MODEL - ENHANCED
733
+ # ========================================================================
734
+
735
+ print(f"\n[3/4] Initializing GeoDavidCollective (ENHANCED)...")
736
+ collective = GeoDavidCollective(
737
+ block_configs=block_configs,
738
+ num_timestep_bins=100,
739
+ num_patterns_per_bin=10,
740
+ block_weights=block_weights,
741
+ loss_config=loss_config
742
+ )
743
+
744
+ model_info = collective.get_model_info()
745
+ print(f" โœ“ Architecture: {model_info['architecture']}")
746
+ print(f" โœ“ Blocks: {model_info['num_blocks']}")
747
+ print(f" โœ“ Total parameters: {model_info['total_parameters']:,}")
748
+ print(f" โœ“ Timestep bins: {model_info['num_timestep_bins']}")
749
+ print(f" โœ“ Patterns per bin: {model_info['num_patterns_per_bin']}")
750
+
751
+ # Show ProjectiveHead configs
752
+ print(f"\n ProjectiveHead Configurations:")
753
+ for block_name, companion_info in list(model_info['companions'].items())[:3]:
754
+ print(f" {block_name}:")
755
+ print(f" Timestep head: {companion_info['timestep_head']['num_experts']} experts, "
756
+ f"{companion_info['timestep_head']['num_gate_heads']} gates")
757
+ print(f" ... and {len(model_info['companions'])-3} more blocks")
758
+
759
+ # ========================================================================
760
+ # TRAIN
761
+ # ========================================================================
762
+
763
+ print(f"\n[4/4] Starting training...")
764
+ collective, history = train_geo_collective(
765
+ collective=collective,
766
+ extractor=extractor,
767
+ dataloader=dataloader,
768
+ num_epochs=10,
769
+ device=device,
770
+ learning_rate=1e-3,
771
+ weight_decay=0.001,
772
+ log_dir="./runs/geo_collective_enhanced",
773
+ prompt_log_path="./prompts_enhanced.jsonl",
774
+ checkpoint_interval=2,
775
+ checkpoint_dir="./checkpoints_enhanced",
776
+ pool_mode='mean'
777
+ )
778
+
779
+ print("\n" + "="*80)
780
+ print("TRAINING COMPLETE!")
781
+ print("="*80)
782
+ print(f"\n๐Ÿ“Š Final Metrics:")
783
+ print(f" Loss: {history['total_loss'][-1]:.4f}")
784
+ print(f" Cayley: {history['avg_cayley'][-1]:.4f}")
785
+ print(f" Timestep Acc: {history['avg_timestep_acc'][-1]:.2%}")
786
+ print(f" Pattern Acc: {history['avg_pattern_acc'][-1]:.2%}")
787
+ print(f" Full Acc: {history['avg_full_acc'][-1]:.2%}")
788
+
789
+ return collective, history
790
+
791
+
792
+ if __name__ == "__main__":
793
+ collective, history = main()