""" Production-Ready Model Optimizations Stable, tested optimizations for ChatterboxTTS deployment """ import torch import logging import time import hashlib import gc from typing import Dict, Any, Optional, List from pathlib import Path class ProductionOptimizer: """Production-ready optimization wrapper for ChatterboxTTS""" def __init__(self, model, enable_caching=True, enable_memory_optimization=True): self.model = model self.enable_caching = enable_caching self.enable_memory_optimization = enable_memory_optimization self.logger = logging.getLogger(__name__) # Initialize components self.text_cache = TextCache() if enable_caching else None self.memory_manager = MemoryManager() if enable_memory_optimization else None # Performance tracking self.stats = { 'total_generations': 0, 'cache_hits': 0, 'memory_cleanups': 0, 'total_time_saved': 0.0, 'successful_generations': 0 } # Apply safe optimizations self._apply_safe_optimizations() self.logger.info("🚀 Production optimizer initialized") def _apply_safe_optimizations(self): """Apply only safe, tested optimizations""" optimizations = [] # 1. Disable gradient computation (safe for inference) torch.set_grad_enabled(False) optimizations.append("gradients_disabled") # 2. Enable cuDNN benchmark (safe for consistent input sizes) if torch.backends.cudnn.is_available(): torch.backends.cudnn.benchmark = True optimizations.append("cudnn_benchmark") # 3. Set models to eval mode if available if hasattr(self.model, 't3') and hasattr(self.model.t3, 'eval'): self.model.t3.eval() optimizations.append("t3_eval") if hasattr(self.model, 's3gen') and hasattr(self.model.s3gen, 'eval'): self.model.s3gen.eval() optimizations.append("s3gen_eval") self.logger.info(f"✅ Applied safe optimizations: {', '.join(optimizations)}") def generate(self, text: str, **kwargs) -> torch.Tensor: """Optimized generation with production-safe enhancements""" start_time = time.time() self.stats['total_generations'] += 1 # Check cache first if self.text_cache: cached_result = self.text_cache.get(text, kwargs) if cached_result is not None: self.stats['cache_hits'] += 1 self.logger.debug(f"📋 Cache hit for: '{text[:30]}...'") return cached_result # Memory management before generation if self.memory_manager and self.memory_manager.should_cleanup(): freed = self.memory_manager.cleanup() if freed > 0: self.stats['memory_cleanups'] += 1 try: # Generate with inference mode for better performance with torch.inference_mode(): audio = self.model.generate(text, **kwargs) # Cache the result if self.text_cache: self.text_cache.put(text, audio, kwargs) self.stats['successful_generations'] += 1 # Track performance generation_time = time.time() - start_time # Estimate baseline time (conservative 5% overhead estimate) estimated_baseline = generation_time * 1.05 time_saved = estimated_baseline - generation_time self.stats['total_time_saved'] += time_saved return audio except Exception as e: self.logger.error(f"❌ Generation failed for text '{text[:30]}...': {e}") raise def prepare_conditionals(self, *args, **kwargs): """Optimized voice conditioning""" with torch.inference_mode(): return self.model.prepare_conditionals(*args, **kwargs) def get_stats(self) -> Dict[str, Any]: """Get performance statistics""" stats = self.stats.copy() if self.text_cache: cache_stats = self.text_cache.get_stats() stats.update(cache_stats) if self.memory_manager: memory_stats = self.memory_manager.get_stats() stats.update(memory_stats) # Calculate derived metrics if stats['total_generations'] > 0: stats['cache_hit_rate'] = (stats['cache_hits'] / stats['total_generations']) * 100 stats['success_rate'] = (stats['successful_generations'] / stats['total_generations']) * 100 stats['avg_time_saved'] = stats['total_time_saved'] / stats['total_generations'] return stats def log_performance_summary(self): """Log comprehensive performance summary""" stats = self.get_stats() self.logger.info("📊 PRODUCTION OPTIMIZER SUMMARY") self.logger.info("=" * 50) self.logger.info(f"Total generations: {stats['total_generations']}") self.logger.info(f"Successful generations: {stats['successful_generations']}") self.logger.info(f"Success rate: {stats.get('success_rate', 0):.1f}%") self.logger.info(f"Cache hits: {stats['cache_hits']}") self.logger.info(f"Cache hit rate: {stats.get('cache_hit_rate', 0):.1f}%") self.logger.info(f"Memory cleanups: {stats['memory_cleanups']}") self.logger.info(f"Total time saved: {stats['total_time_saved']:.2f}s") self.logger.info(f"Average time saved per generation: {stats.get('avg_time_saved', 0):.3f}s") class TextCache: """Lightweight text caching for repeated generations""" def __init__(self, max_size=500): self.cache = {} self.access_order = [] # For LRU eviction self.max_size = max_size self.hits = 0 self.misses = 0 def _make_key(self, text: str, params: Dict = None) -> str: """Create cache key from text and parameters""" key_data = text if params: # Include parameters that affect audio generation relevant_params = { 'exaggeration': params.get('exaggeration'), 'cfg_weight': params.get('cfg_weight'), 'temperature': params.get('temperature'), 'min_p': params.get('min_p'), 'repetition_penalty': params.get('repetition_penalty') } key_data += "|" + str(sorted(relevant_params.items())) return hashlib.md5(key_data.encode()).hexdigest() def get(self, text: str, params: Dict = None) -> Optional[torch.Tensor]: """Get cached result""" key = self._make_key(text, params) if key in self.cache: self.hits += 1 # Update access order for LRU if key in self.access_order: self.access_order.remove(key) self.access_order.append(key) return self.cache[key] self.misses += 1 return None def put(self, text: str, result: torch.Tensor, params: Dict = None): """Store result in cache""" if len(text) > 500: # Don't cache very long texts return key = self._make_key(text, params) # Evict oldest if at capacity if len(self.cache) >= self.max_size and key not in self.cache: oldest_key = self.access_order.pop(0) del self.cache[oldest_key] self.cache[key] = result.clone() # Store a copy to avoid reference issues # Update access order if key in self.access_order: self.access_order.remove(key) self.access_order.append(key) def get_stats(self) -> Dict[str, Any]: """Get cache statistics""" total = self.hits + self.misses return { 'cache_hits': self.hits, 'cache_misses': self.misses, 'cache_size': len(self.cache), 'cache_hit_rate': (self.hits / total * 100) if total > 0 else 0 } class MemoryManager: """Smart memory management for stable performance""" def __init__(self, cleanup_threshold_gb=5.5): self.cleanup_threshold_gb = cleanup_threshold_gb self.cleanups_performed = 0 self.total_memory_freed = 0.0 def get_memory_usage(self) -> float: """Get current GPU memory usage in GB""" if torch.cuda.is_available(): return torch.cuda.memory_allocated() / (1024 ** 3) return 0.0 def should_cleanup(self) -> bool: """Check if memory cleanup is needed""" return self.get_memory_usage() > self.cleanup_threshold_gb def cleanup(self) -> float: """Perform memory cleanup and return amount freed""" initial_memory = self.get_memory_usage() # Standard cleanup gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize() final_memory = self.get_memory_usage() freed = initial_memory - final_memory if freed > 0.1: # Significant cleanup self.cleanups_performed += 1 self.total_memory_freed += freed return freed def get_stats(self) -> Dict[str, Any]: """Get memory management statistics""" return { 'current_memory_gb': self.get_memory_usage(), 'memory_cleanups': self.cleanups_performed, 'total_memory_freed_gb': self.total_memory_freed, 'cleanup_threshold_gb': self.cleanup_threshold_gb } class BatchOptimizer: """Optimized batch processing for multiple texts""" def __init__(self, optimizer: ProductionOptimizer): self.optimizer = optimizer self.logger = logging.getLogger(__name__) def generate_batch(self, texts: List[str], **kwargs) -> List[torch.Tensor]: """Generate audio for multiple texts with optimizations""" if not texts: return [] self.logger.info(f"🔥 Processing batch of {len(texts)} texts") start_time = time.time() results = [] cache_hits = 0 for i, text in enumerate(texts): try: # Check cache first if self.optimizer.text_cache: cached = self.optimizer.text_cache.get(text, kwargs) if cached is not None: results.append(cached) cache_hits += 1 continue # Generate audio audio = self.optimizer.generate(text, **kwargs) results.append(audio) # Progress logging for large batches if len(texts) > 10 and (i + 1) % 5 == 0: self.logger.info(f" Processed {i + 1}/{len(texts)} texts") except Exception as e: self.logger.error(f"❌ Failed to process text {i + 1}: {e}") # Create silent audio as fallback results.append(torch.zeros(1, 24000)) total_time = time.time() - start_time cache_hit_rate = (cache_hits / len(texts)) * 100 self.logger.info(f"✅ Batch completed in {total_time:.2f}s") self.logger.info(f" Cache hit rate: {cache_hit_rate:.1f}%") self.logger.info(f" Average time per text: {total_time / len(texts):.2f}s") return results def create_production_optimizer(model, enable_caching=True, enable_memory_optimization=True): """Create a production-ready optimized wrapper for ChatterboxTTS""" return ProductionOptimizer( model, enable_caching=enable_caching, enable_memory_optimization=enable_memory_optimization ) def create_batch_optimizer(model): """Create an optimized batch processor""" prod_optimizer = create_production_optimizer(model) return BatchOptimizer(prod_optimizer)