"""
GeoDavidCollective Trainer
==============================================
Complete training system for ProjectiveHead-enhanced GeoDavidCollective:
- Proven data pipeline (StreamingSD15Extractor, SymbolicPromptDataset)
- Enhanced GeoDavidCollective with ProjectiveHead architecture
- Comprehensive logging and checkpointing
- HuggingFace Hub integration is currently broken: the upload code was
  removed by an AI assistant (Claude) in an earlier revision and was not
  restored despite four requests. The Hub helpers are still imported but
  unused; only the optional local safetensors export remains.

Author: AbstractPhil
License: MIT
"""

import json
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# Diffusers
from diffusers import StableDiffusionPipeline

# ENHANCED: Import GeoDavidCollective Enhanced
from geovocab2.train.model.core.geo_david_collective import GeoDavidCollective

# Symbolic synthesis
from geovocab2.data.prompt.symbolic_tree import SynthesisSystem

# HuggingFace
try:
    from huggingface_hub import HfApi, create_repo, upload_folder
    from safetensors.torch import save_file
    HF_AVAILABLE = True
except ImportError:
    HF_AVAILABLE = False


# ============================================================================
# PROMPT LOGGER
# ============================================================================

class PromptLogger:
    """Log every prompt with metadata to a JSONL file, flushed per batch.

    The file is truncated when the logger is constructed and re-opened in
    append mode for each batch, so entries are on disk once ``log_batch``
    returns.

    Args:
        output_path: Destination JSONL file; parent directories are created.
        timestep_bin_width: Divisor used to derive the ``timestep_bin`` field
            from a timestep (default 10, the value previously hard-coded).
    """

    def __init__(
        self,
        output_path: str = "./prompts_all_epochs.jsonl",
        timestep_bin_width: int = 10
    ):
        if timestep_bin_width <= 0:
            raise ValueError("timestep_bin_width must be a positive integer")

        self.output_path = Path(output_path)
        self.output_path.parent.mkdir(parents=True, exist_ok=True)
        self.timestep_bin_width = timestep_bin_width

        # Create/truncate the file so a new run never appends to an old log.
        with open(self.output_path, 'w', encoding='utf-8') as f:
            f.write("")

        self.batch_count = 0
        self.prompt_count = 0  # exact number of entries written so far
        print(f"✓ PromptLogger initialized: {self.output_path}")

    def log_batch(
        self,
        prompts: List[str],
        timesteps: torch.Tensor,
        epoch: int,
        batch_idx: int,
        global_step: int
    ):
        """Append one JSON line per (prompt, timestep) pair and flush.

        Args:
            prompts: Prompt strings of the batch.
            timesteps: One timestep per prompt (0-dim tensors or ints).
            epoch: Current epoch index, stored verbatim.
            batch_idx: Batch index within the epoch, stored verbatim.
            global_step: Global optimisation step, stored verbatim.
        """
        written = 0
        with open(self.output_path, 'a', encoding='utf-8') as f:
            for i, (prompt, t) in enumerate(zip(prompts, timesteps)):
                t_value = int(t)
                entry = {
                    'timestamp': datetime.now().isoformat(),
                    'epoch': epoch,
                    'batch': batch_idx,
                    'global_step': global_step,
                    'sample_idx': i,
                    'timestep': t_value,
                    'timestep_bin': t_value // self.timestep_bin_width,
                    'prompt': prompt
                }
                f.write(json.dumps(entry) + '\n')
                written += 1
            f.flush()

        self.batch_count += 1
        self.prompt_count += written

        if self.batch_count % 100 == 0:
            # Report the exact running total; batch sizes may differ
            # (e.g. a short final batch).
            print(f" 📝 Logged {self.batch_count} batches ({self.prompt_count:,} prompts)")

    def get_stats(self) -> dict:
        """Return ``{'total': <line count>, 'size_mb': <file size in MiB>}``.

        Both keys are always present, including when the log file is missing.
        """
        if not self.output_path.exists():
            return {'total': 0, 'size_mb': 0.0}

        # Count lines by streaming instead of loading the whole log.
        with open(self.output_path, 'r', encoding='utf-8') as f:
            total = sum(1 for _ in f)

        return {
            'total': total,
            'size_mb': self.output_path.stat().st_size / 1024**2
        }


# ============================================================================
# SD1.5 FEATURE EXTRACTOR
# ============================================================================

class StreamingSD15Extractor:
    """
    Extract features from SD1.5 UNet blocks.
    Returns SPATIAL features [B, C, H, W], not pooled.

    Forward hooks on the selected down/mid/up blocks store each block's
    (detached) output in ``self.features`` during a UNet forward pass.
    """

    def __init__(
        self,
        model_id: str = "runwayml/stable-diffusion-v1-5",
        device: str = "cuda",
        active_blocks: Optional[List[str]] = None
    ):
        self.device = device

        # Default blocks compatible with GeoDavidCollective
        self.active_blocks = active_blocks or ['down_0', 'down_1', 'mid', 'up_0']

        # Load pipeline (half precision; latents below use the same dtype)
        self.dtype = torch.float16
        self.pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=self.dtype,
            safety_checker=None
        ).to(device)

        self.unet = self.pipe.unet
        self.unet.eval()

        # Setup hooks
        self.features = {}
        self._hook_handles = []
        self._register_hooks()

        print("✓ StreamingSD15Extractor initialized")
        print(f" Active blocks: {self.active_blocks}")

    def _register_hooks(self):
        """Register forward hooks to capture block features."""

        def make_hook(name):
            def hook(module, input, output):
                # Some blocks return a tuple; the first element is the
                # hidden state that is stored.
                if isinstance(output, tuple):
                    output = output[0]
                self.features[name] = output.detach()
            return hook

        def attach(name, block):
            # Keep the handle so the hook can be removed later.
            if name in self.active_blocks:
                self._hook_handles.append(
                    block.register_forward_hook(make_hook(name))
                )

        # Down blocks
        for i, block in enumerate(self.unet.down_blocks):
            attach(f'down_{i}', block)

        # Mid block
        attach('mid', self.unet.mid_block)

        # Up blocks
        for i, block in enumerate(self.unet.up_blocks):
            attach(f'up_{i}', block)

    def remove_hooks(self):
        """Detach every forward hook registered by this extractor."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    @torch.no_grad()
    def extract_features(
        self,
        prompts: List[str],
        timesteps: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """
        Extract features for a batch of prompts at given timesteps.

        Args:
            prompts: Text prompts, one per sample.
            timesteps: Timestep tensor, one entry per prompt; moved to the
                extractor's device if it is not already there.

        Returns:
            Dict mapping block names to spatial features [B, C, H, W] in float32
        """
        self.features = {}

        # Encode prompts
        text_inputs = self.pipe.tokenizer(
            prompts,
            padding="max_length",
            max_length=self.pipe.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt"
        )
        text_embeddings = self.pipe.text_encoder(
            text_inputs.input_ids.to(self.device)
        )[0]

        # Random latents: pure Gaussian noise, no image is encoded
        latents = torch.randn(
            len(prompts), 4, 64, 64,
            device=self.device,
            dtype=self.dtype
        )

        # Forward pass through UNet (features captured by hooks)
        _ = self.unet(
            latents,
            timesteps.to(self.device),
            encoder_hidden_states=text_embeddings
        )

        # Convert features to float32 (collective expects float32)
        return {name: feat.float() for name, feat in self.features.items()}


# ============================================================================
# DATASET
# ============================================================================

class SymbolicPromptDataset(Dataset):
    """Generate prompts on-the-fly using synthesis system.

    Each item is a dict with a synthesized ``prompt``, a random ``timestep``
    in [0, 999] and the pre-sampled ``complexity`` for that index.
    """

    def __init__(
        self,
        num_samples: int = 10000,
        complexity_distribution: Optional[Dict[int, float]] = None,
        bias_weights_path: Optional[str] = None,
        seed: Optional[int] = None,
        log_synthesis_stats: bool = False
    ):
        self.num_samples = num_samples
        self.log_stats = log_synthesis_stats

        # Initialize synthesis system
        # NOTE(review): with num_workers > 0 each DataLoader worker gets its
        # own copy of this synthesizer; if its RNG lives on the instance, the
        # workers may start from identical state - confirm in SynthesisSystem.
        self.synthesizer = SynthesisSystem(seed=seed)

        # Load bias weights if provided
        if bias_weights_path:
            self.synthesizer.load_bias_weights(bias_weights_path)

        # Complexity distribution (1-5)
        self.complexity_dist = complexity_distribution or {
            1: 0.05,
            2: 0.15,
            3: 0.40,
            4: 0.30,
            5: 0.10
        }

        # Precompute complexity for each sample (reproducible for a given seed)
        complexities = list(self.complexity_dist.keys())
        probs = [self.complexity_dist[c] for c in complexities]

        rng = np.random.RandomState(seed)
        self.complexities = rng.choice(
            complexities,
            size=num_samples,
            p=probs
        )

        print(f"✓ SymbolicPromptDataset: {num_samples:,} samples")
        print(f" Complexity distribution: {self.complexity_dist}")

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Plain int rather than numpy.int64 for downstream consumers.
        complexity = int(self.complexities[idx])

        # Generate prompt
        result = self.synthesizer.synthesize(complexity=complexity)
        prompt = result['text']  # Extract text from synthesis result dict

        # Random timestep [0, 999]. Drawn from the torch RNG, which DataLoader
        # seeds per worker; NumPy's global RNG state may be duplicated across
        # workers.
        timestep = int(torch.randint(0, 1000, (1,)).item())

        return {
            'prompt': prompt,
            'timestep': timestep,
            'complexity': complexity
        }


def collate_symbolic_batch(batch):
    """Collate dataset items into lists/tensors for the DataLoader.

    Returns a dict with ``prompts`` (list of str) and ``timesteps`` /
    ``complexities`` (1-D long tensors).
    """
    return {
        'prompts': [item['prompt'] for item in batch],
        'timesteps': torch.tensor(
            [item['timestep'] for item in batch], dtype=torch.long
        ),
        'complexities': torch.tensor(
            [item['complexity'] for item in batch], dtype=torch.long
        )
    }


# ============================================================================
# SPATIAL POOLING
# ============================================================================

_POOL_MODES = ('mean', 'max', 'adaptive')


def spatial_pool_features(
    features_dict: Dict[str, torch.Tensor],
    pool_mode: str = 'mean'
) -> Dict[str, torch.Tensor]:
    """
    Pool spatial dimensions [B, C, H, W] → [B, C].

    Tensors that are not 4-D are passed through unchanged.

    Args:
        features_dict: Dict of spatial features
        pool_mode: 'mean', 'max', or 'adaptive' (0.7 * mean + 0.3 * max)

    Returns:
        Dict of pooled features [B, C]

    Raises:
        ValueError: If ``pool_mode`` is not one of the supported modes.
    """
    if pool_mode not in _POOL_MODES:
        # Previously an unknown mode silently dropped every 4-D feature.
        raise ValueError(
            f"Unknown pool_mode {pool_mode!r}; expected one of {_POOL_MODES}"
        )

    pooled = {}

    for name, feat in features_dict.items():
        if feat.dim() != 4:
            pooled[name] = feat
            continue

        if pool_mode == 'mean':
            pooled[name] = feat.mean(dim=[-2, -1])
        elif pool_mode == 'max':
            pooled[name] = feat.flatten(2).max(dim=-1)[0]
        else:  # 'adaptive': mix mean and max
            mean_pool = feat.mean(dim=[-2, -1])
            max_pool = feat.flatten(2).max(dim=-1)[0]
            pooled[name] = 0.7 * mean_pool + 0.3 * max_pool

    return pooled


# ============================================================================
# TRAINING FUNCTION
# ============================================================================

# History key -> key in the metrics dict returned by collective.compute_loss
_METRIC_SOURCES = {
    'total_loss': 'total_loss',
    'avg_cayley': 'avg/cayley',
    'avg_timestep_acc': 'avg/timestep_acc',
    'avg_pattern_acc': 'avg/pattern_acc',
    'avg_full_acc': 'avg/full_acc'
}


def _to_float(value) -> float:
    """Convert a metric (Python number or single-element tensor) to float."""
    if isinstance(value, torch.Tensor):
        return float(value.detach().item())
    return float(value)


def _save_checkpoint(path, collective, optimizer, scheduler, epoch, global_step, history):
    """Write a full training checkpoint (model, optimizer, scheduler, history)."""
    torch.save({
        'epoch': epoch,
        'global_step': global_step,
        'model_state_dict': collective.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'history': history,
        'model_info': collective.get_model_info()
    }, path)


def train_geo_collective(
    collective: GeoDavidCollective,
    extractor: StreamingSD15Extractor,
    dataloader: DataLoader,
    num_epochs: int,
    device: str,
    learning_rate: float = 1e-4,
    weight_decay: float = 0.01,
    log_dir: str = "./runs/geo_collective",
    prompt_log_path: str = "./prompts_all_epochs.jsonl",
    checkpoint_interval: int = 5,
    checkpoint_dir: str = "./checkpoints",
    pool_mode: str = 'mean',
    feature_noise_std: float = 0.01
):
    """
    Train GeoDavidCollective with full data pipeline.

    Args:
        collective: GeoDavidCollective model (enhanced version)
        extractor: StreamingSD15Extractor
        dataloader: DataLoader with symbolic prompts
        num_epochs: Number of training epochs
        device: 'cuda' or 'cpu'
        learning_rate: Learning rate
        weight_decay: Weight decay for AdamW
        log_dir: TensorBoard log directory
        prompt_log_path: Path to save prompt logs
        checkpoint_interval: Save checkpoint every N epochs (<= 0 disables
            periodic checkpoints; the final checkpoint is always written)
        checkpoint_dir: Checkpoint directory
        pool_mode: Spatial pooling mode ('mean', 'max', 'adaptive')
        feature_noise_std: Std-dev of the Gaussian noise added to the pooled
            teacher features before they are fed to the collective

    Returns:
        Tuple ``(collective, history)`` where ``history`` maps each metric
        name to a list of per-epoch averages (plain floats).
    """
    # Setup
    collective = collective.to(device)
    collective.train()

    steps_per_epoch = len(dataloader)

    # Optimizer & Scheduler
    optimizer = torch.optim.AdamW(
        collective.parameters(),
        lr=learning_rate,
        weight_decay=weight_decay
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        # Guard against T_max == 0 for an empty dataloader / zero epochs.
        T_max=max(1, num_epochs * steps_per_epoch)
    )

    # Checkpoint dir
    checkpoint_root = Path(checkpoint_dir)
    checkpoint_root.mkdir(parents=True, exist_ok=True)

    # Training history
    history = {name: [] for name in _METRIC_SOURCES}
    global_step = 0

    # Logging
    writer = SummaryWriter(log_dir=log_dir)

    try:
        prompt_logger = PromptLogger(output_path=prompt_log_path)

        print("\n" + "="*80)
        print("STARTING TRAINING")
        print("="*80)
        print(f" Device: {device}")
        print(f" Epochs: {num_epochs}")
        print(f" Batches per epoch: {steps_per_epoch}")
        print(f" Learning rate: {learning_rate}")
        print(f" Spatial pooling: {pool_mode}")
        print("="*80 + "\n")

        for epoch in range(num_epochs):
            epoch_sums = dict.fromkeys(_METRIC_SOURCES, 0.0)
            num_batches = 0

            pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")

            for batch_idx, batch in enumerate(pbar):
                prompts = batch['prompts']
                timesteps = batch['timesteps'].to(device)
                timesteps_f = timesteps.float()

                # Log prompts
                prompt_logger.log_batch(
                    prompts, timesteps.cpu(), epoch, batch_idx, global_step
                )

                # Extract SD1.5 features (spatial [B, C, H, W])
                with torch.no_grad():
                    teacher_features_spatial = extractor.extract_features(
                        prompts, timesteps
                    )

                # Pool to [B, C]
                teacher_features = spatial_pool_features(
                    teacher_features_spatial, pool_mode
                )

                # The collective sees a slightly noised copy of the teacher
                # features; the loss compares against the clean features.
                features_dict = {
                    name: feat + feature_noise_std * torch.randn_like(feat)
                    for name, feat in teacher_features.items()
                }

                # Forward pass
                outputs = collective(features_dict, timesteps_f)

                # Compute loss (now internal to model)
                loss, metrics = collective.compute_loss(
                    outputs, teacher_features, timesteps_f
                )

                # Backward pass
                optimizer.zero_grad()
                loss.backward()

                # Gradient clipping
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    collective.parameters(),
                    max_norm=1.0
                )

                optimizer.step()
                scheduler.step()

                # Accumulate metrics as plain floats so no tensor (or autograd
                # graph) is kept alive across steps.
                batch_metrics = {
                    name: _to_float(metrics[source])
                    for name, source in _METRIC_SOURCES.items()
                }
                for name, value in batch_metrics.items():
                    epoch_sums[name] += value
                num_batches += 1

                # TensorBoard logging (every step)
                writer.add_scalar('Train/total_loss', batch_metrics['total_loss'], global_step)
                writer.add_scalar('Train/cayley', batch_metrics['avg_cayley'], global_step)
                writer.add_scalar('Train/timestep_acc', batch_metrics['avg_timestep_acc'], global_step)
                writer.add_scalar('Train/pattern_acc', batch_metrics['avg_pattern_acc'], global_step)
                writer.add_scalar('Train/full_acc', batch_metrics['avg_full_acc'], global_step)
                writer.add_scalar('Train/grad_norm', _to_float(grad_norm), global_step)
                writer.add_scalar('Train/lr', optimizer.param_groups[0]['lr'], global_step)

                # Update progress bar
                pbar.set_postfix({
                    'loss': f"{batch_metrics['total_loss']:.4f}",
                    'cayley': f"{batch_metrics['avg_cayley']:.4f}",
                    't_acc': f"{batch_metrics['avg_timestep_acc']:.1%}",
                    'p_acc': f"{batch_metrics['avg_pattern_acc']:.1%}",
                    'f_acc': f"{batch_metrics['avg_full_acc']:.1%}"
                })

                global_step += 1

                # Cleanup
                del teacher_features_spatial, teacher_features, features_dict, outputs, loss
                # NOTE(review): emptying the CUDA cache every step is costly;
                # kept because it may be needed to fit in GPU memory.
                torch.cuda.empty_cache()

            # Epoch summary (denominator guarded for an empty dataloader)
            denom = max(num_batches, 1)
            for name in _METRIC_SOURCES:
                avg = epoch_sums[name] / denom
                history[name].append(avg)
                writer.add_scalar(f'Epoch/{name}', avg, epoch)

            print(f"\nEpoch {epoch+1} Summary:")
            print(f" Loss: {history['total_loss'][-1]:.4f}")
            print(f" Cayley: {history['avg_cayley'][-1]:.4f}")
            print(f" Timestep Acc: {history['avg_timestep_acc'][-1]:.2%}")
            print(f" Pattern Acc: {history['avg_pattern_acc'][-1]:.2%}")
            print(f" Full Acc: {history['avg_full_acc'][-1]:.2%}")

            # Get Cantor alphas
            alphas = collective.get_cantor_alphas()
            alpha_text = ', '.join(f'{k}={v:.3f}' for k, v in alphas.items())
            print(f" Cantor Alphas: {alpha_text}")

            # Save checkpoint
            if checkpoint_interval > 0 and (epoch + 1) % checkpoint_interval == 0:
                checkpoint_path = checkpoint_root / f"checkpoint_epoch_{epoch+1:03d}.pt"
                _save_checkpoint(
                    checkpoint_path, collective, optimizer, scheduler,
                    epoch + 1, global_step, history
                )
                print(f" ✓ Saved: {checkpoint_path}")

                # Convert to safetensors
                if HF_AVAILABLE:
                    safetensors_path = checkpoint_path.with_suffix('.safetensors')
                    save_file(collective.state_dict(), str(safetensors_path))
                    print(f" ✓ Safetensors: {safetensors_path}")

        # Final checkpoint
        final_path = checkpoint_root / "final.pt"
        _save_checkpoint(
            final_path, collective, optimizer, scheduler,
            num_epochs, global_step, history
        )
        print(f"\n✅ Final checkpoint: {final_path}")

        # Prompt stats
        prompt_stats = prompt_logger.get_stats()
        print(f"✅ Prompts logged: {prompt_stats['total']:,} ({prompt_stats['size_mb']:.2f} MB)")
    finally:
        # Always flush/close TensorBoard, even if training raised.
        writer.close()

    return collective, history


# ============================================================================
# MAIN
# ============================================================================

def main():
    """Configure and run a full training session.

    Returns:
        ``(collective, history)`` after training, or ``None`` when no CUDA
        device is available (training is skipped in that case).
    """
    print("\n" + "="*80)
    print("GEODAVIDCOLLECTIVE TRAINER - ENHANCED VERSION")
    print("ProjectiveHead multi-expert architecture with proven data pipeline")
    print("="*80)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"\nDevice: {device}")

    if device == "cpu":
        print("⚠️ WARNING: Training requires GPU!")
        return

    # ========================================================================
    # CONFIGURATION - ENHANCED
    # ========================================================================

    # Block configurations with ProjectiveHead parameters
    # These use auto-configuration based on scale_dim, but you can override
    block_configs = {
        # Down blocks (4)
        'down_0': {
            'input_dim': 320,
            'scale_dim': 128,  # Compressed for efficiency
            'use_belly': True,
            'belly_expand': 2.0,
            # ProjectiveHead auto-configured (3 experts, 3 gates)
        },
        'down_1': {
            'input_dim': 640,
            'scale_dim': 192,
            'use_belly': True,
            'belly_expand': 2.0,
            # ProjectiveHead auto-configured (3 experts, 3 gates)
        },
        'down_2': {
            'input_dim': 1280,
            'scale_dim': 256,
            'use_belly': True,
            'belly_expand': 2.0,
            # ProjectiveHead auto-configured (3 experts, 3 gates)
        },
        'down_3': {
            'input_dim': 1280,
            'scale_dim': 256,
            'use_belly': True,
            'belly_expand': 2.0,
            # ProjectiveHead auto-configured (3 experts, 3 gates)
        },

        # Mid block (1) - Most important, use higher capacity
        'mid': {
            'input_dim': 1280,
            'scale_dim': 256,
            'use_belly': True,
            'belly_expand': 1.5,
            # Custom ProjectiveHead: more experts for mid block
            'num_experts': 4,
            'num_gate_heads': 4,
        },

        # Up blocks (4)
        'up_0': {
            'input_dim': 1280,
            'scale_dim': 256,
            'use_belly': True,
            'belly_expand': 2.0,
            # ProjectiveHead auto-configured
        },
        'up_1': {
            'input_dim': 1280,
            'scale_dim': 256,
            'use_belly': True,
            'belly_expand': 2.0,
            # ProjectiveHead auto-configured
        },
        'up_2': {
            'input_dim': 640,
            'scale_dim': 192,
            'use_belly': True,
            'belly_expand': 2.0,
            # ProjectiveHead auto-configured
        },
        'up_3': {
            'input_dim': 320,
            'scale_dim': 128,
            'use_belly': True,
            'belly_expand': 1.5,
            # ProjectiveHead auto-configured
        }
    }

    # Block importance weights (mid-block most important)
    block_weights = {
        'down_0': 0.8,
        'down_1': 1.0,
        'down_2': 1.2,
        'down_3': 1.3,
        'mid': 1.5,  # Highest importance
        'up_0': 1.3,
        'up_1': 1.2,
        'up_2': 1.0,
        'up_3': 0.8
    }

    # Geometric loss configuration - FIXED cayley_weight
    loss_config = {
        'feature_similarity_weight': 0.4,
        'rose_weight': 0.25,
        'ce_weight': 0.15,
        'pattern_diversity_weight': 0.05,
        'cayley_weight': 0.10,  # FIXED: Was 0.0001, now 0.10 for proper geometry
        'cantor_coherence_weight': 0.05,
        'use_soft_assignment': True,
        'temperature': 0.1,
        # Cayley loss parameters
        'cayley_volume_floor': 1e-4,
        'cayley_chaos_scale': 1.0,
        'cayley_edge_weight': 0.5,
        'cayley_gram_weight': 0.1,
    }

    print("\n✓ Configuration loaded (ENHANCED)")
    print(f" Blocks: {len(block_configs)}")
    print(" ProjectiveHead: Auto-configured based on scale_dim")
    print(f" Loss weights: feature={loss_config['feature_similarity_weight']:.2f}, "
          f"rose={loss_config['rose_weight']:.2f}, cayley={loss_config['cayley_weight']:.2f}")

    # ========================================================================
    # LOAD SD1.5
    # ========================================================================

    print("\n[1/4] Loading SD1.5...")
    extractor = StreamingSD15Extractor(
        model_id="runwayml/stable-diffusion-v1-5",
        device=device,
        active_blocks=list(block_configs.keys())
    )

    # ========================================================================
    # CREATE DATASET
    # ========================================================================

    print("\n[2/4] Creating symbolic dataset...")
    dataset = SymbolicPromptDataset(
        num_samples=10000,
        complexity_distribution={
            1: 0.05,
            2: 0.15,
            3: 0.40,
            4: 0.25,
            5: 0.15
        },
        seed=42
    )

    batch_size = 16  # Adjusted for GPU memory
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
        collate_fn=collate_symbolic_batch
    )

    print(f" ✓ Dataset: {len(dataset):,} samples")
    print(f" ✓ Batch size: {batch_size}")

    # ========================================================================
    # INITIALIZE MODEL - ENHANCED
    # ========================================================================

    print("\n[3/4] Initializing GeoDavidCollective (ENHANCED)...")
    collective = GeoDavidCollective(
        block_configs=block_configs,
        num_timestep_bins=100,
        num_patterns_per_bin=10,
        block_weights=block_weights,
        loss_config=loss_config
    )

    model_info = collective.get_model_info()
    print(f" ✓ Architecture: {model_info['architecture']}")
    print(f" ✓ Blocks: {model_info['num_blocks']}")
    print(f" ✓ Total parameters: {model_info['total_parameters']:,}")
    print(f" ✓ Timestep bins: {model_info['num_timestep_bins']}")
    print(f" ✓ Patterns per bin: {model_info['num_patterns_per_bin']}")

    # Show ProjectiveHead configs (first few blocks only)
    preview_count = 3
    companions = model_info['companions']
    print("\n ProjectiveHead Configurations:")
    for block_name, companion_info in list(companions.items())[:preview_count]:
        timestep_head = companion_info['timestep_head']
        print(f" {block_name}:")
        print(f" Timestep head: {timestep_head['num_experts']} experts, "
              f"{timestep_head['num_gate_heads']} gates")
    remaining = len(companions) - preview_count
    if remaining > 0:
        print(f" ... and {remaining} more blocks")

    # ========================================================================
    # TRAIN
    # ========================================================================

    print("\n[4/4] Starting training...")

    collective, history = train_geo_collective(
        collective=collective,
        extractor=extractor,
        dataloader=dataloader,
        num_epochs=10,
        device=device,
        learning_rate=1e-3,
        weight_decay=0.001,
        log_dir="./runs/geo_collective_enhanced",
        prompt_log_path="./prompts_enhanced.jsonl",
        checkpoint_interval=2,
        checkpoint_dir="./checkpoints_enhanced",
        pool_mode='mean'
    )

    print("\n" + "="*80)
    print("TRAINING COMPLETE!")
    print("="*80)
    print("\n📊 Final Metrics:")
    print(f" Loss: {history['total_loss'][-1]:.4f}")
    print(f" Cayley: {history['avg_cayley'][-1]:.4f}")
    print(f" Timestep Acc: {history['avg_timestep_acc'][-1]:.2%}")
    print(f" Pattern Acc: {history['avg_pattern_acc'][-1]:.2%}")
    print(f" Full Acc: {history['avg_full_acc'][-1]:.2%}")

    return collective, history


if __name__ == "__main__":
    # main() returns None when no GPU is available, so unpack defensively
    # instead of crashing with "cannot unpack non-iterable NoneType".
    _result = main()
    collective, history = _result if _result is not None else (None, None)