diff --git a/HF_Deploy/.gitattributes b/HF_Deploy/.gitattributes deleted file mode 100644 index 83cfd8dbb643612f79f25d84b65ac7e4b3c4fb7f..0000000000000000000000000000000000000000 --- a/HF_Deploy/.gitattributes +++ /dev/null @@ -1,36 +0,0 @@ -*.7z filter=lfs diff=lfs merge=lfs -text -*.arrow filter=lfs diff=lfs merge=lfs -text -*.bin filter=lfs diff=lfs merge=lfs -text -*.bz2 filter=lfs diff=lfs merge=lfs -text -*.ckpt filter=lfs diff=lfs merge=lfs -text -*.ftz filter=lfs diff=lfs merge=lfs -text -*.gz filter=lfs diff=lfs merge=lfs -text -*.h5 filter=lfs diff=lfs merge=lfs -text -*.joblib filter=lfs diff=lfs merge=lfs -text -*.lfs.* filter=lfs diff=lfs merge=lfs -text -*.mlmodel filter=lfs diff=lfs merge=lfs -text -*.model filter=lfs diff=lfs merge=lfs -text -*.msgpack filter=lfs diff=lfs merge=lfs -text -*.npy filter=lfs diff=lfs merge=lfs -text -*.npz filter=lfs diff=lfs merge=lfs -text -*.onnx filter=lfs diff=lfs merge=lfs -text -*.ot filter=lfs diff=lfs merge=lfs -text -*.parquet filter=lfs diff=lfs merge=lfs -text -*.pb filter=lfs diff=lfs merge=lfs -text -*.pickle filter=lfs diff=lfs merge=lfs -text -*.pkl filter=lfs diff=lfs merge=lfs -text -*.pt filter=lfs diff=lfs merge=lfs -text -*.pth filter=lfs diff=lfs merge=lfs -text -*.rar filter=lfs diff=lfs merge=lfs -text -*.safetensors filter=lfs diff=lfs merge=lfs -text -saved_model/**/* filter=lfs diff=lfs merge=lfs -text -*.tar.* filter=lfs diff=lfs merge=lfs -text -*.tar filter=lfs diff=lfs merge=lfs -text -*.tflite filter=lfs diff=lfs merge=lfs -text -*.tgz filter=lfs diff=lfs merge=lfs -text -*.wasm filter=lfs diff=lfs merge=lfs -text -*.xz filter=lfs diff=lfs merge=lfs -text -*.zip filter=lfs diff=lfs merge=lfs -text -*.zst filter=lfs diff=lfs merge=lfs -text -*tfevents* filter=lfs diff=lfs merge=lfs -text -*.wav filter=lfs diff=lfs merge=lfs -text diff --git a/HF_Deploy/.gitignore b/HF_Deploy/.gitignore deleted file mode 100644 index 7a60b85e148f80966a550e5ab6a762a907c69ca6..0000000000000000000000000000000000000000 --- a/HF_Deploy/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -__pycache__/ -*.pyc diff --git a/HF_Deploy/README.md b/HF_Deploy/README.md deleted file mode 100644 index 7e8bf3450430c0ab63a515d520664a456d29252b..0000000000000000000000000000000000000000 --- a/HF_Deploy/README.md +++ /dev/null @@ -1,14 +0,0 @@ ---- -title: ChatterboxTTS DNXS Spokenword -emoji: πŸŒ– -colorFrom: blue -colorTo: red -sdk: gradio -sdk_version: 5.39.0 -app_file: app.py -pinned: false -license: apache-2.0 -short_description: 'ChatterboxTTS Gradio interface for custom workflow. ' ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/HF_Deploy/Text_Input/Goliath/test1.txt b/HF_Deploy/Text_Input/Goliath/test1.txt deleted file mode 100644 index d1da22da6218a4a9a9a02cddc7e540fd792e84f3..0000000000000000000000000000000000000000 --- a/HF_Deploy/Text_Input/Goliath/test1.txt +++ /dev/null @@ -1,7 +0,0 @@ -My dear fellow, I spent a considerable length of time in the Peninsula. I was with a British rifle brigade when I met Sir Arthur Wellesley. And I was a prisoner of the French at Salamanca - 18 12 I think it was. I always find it’s best to see both sides of both sides, if you see what I mean. - -I didn’t really think you approved of war sir, said Benton sadly. - -The Doctor turned his attention back to the twisting country lane. He sighed as he changed gear for another sharp corner. Sometimes it’s inevitable, he noted with genuine sadness. I’m a man of peace, but I seem to spend much of my time caught up in conflict. The central paradox of my life, perhaps. - -Benton leant back in the seat. What’s the central paradox of mine? he asked, fascinated. diff --git a/HF_Deploy/Text_Input/README.md b/HF_Deploy/Text_Input/README.md deleted file mode 100644 index a65e96f281ed935443775375c3b60c96700d6559..0000000000000000000000000000000000000000 --- a/HF_Deploy/Text_Input/README.md +++ /dev/null @@ -1,40 +0,0 @@ -# Text Input Directory - -Place your book text files here for audiobook generation. - -## Directory Structure -Create a subdirectory for each book: -``` -Text_Input/ -β”œβ”€β”€ Book Name 1/ -β”‚ β”œβ”€β”€ book.txt # Main text file -β”‚ β”œβ”€β”€ cover.jpg # Book cover image (optional) -β”‚ └── book.nfo # Metadata file (optional) -β”œβ”€β”€ Book Name 2/ -β”‚ β”œβ”€β”€ another_book.txt -β”‚ └── cover.png -└── ... -``` - -## Text File Requirements -- **Format**: Plain text (.txt) files -- **Encoding**: UTF-8 -- **Content**: Clean text without excessive formatting -- **Structure**: Use paragraph breaks for natural speech flow - -## Optional Files -- **cover.jpg/png**: Book cover image for M4B metadata -- **book.nfo**: XML metadata file with book information (title, author, etc.) - -## Text Preparation Tips -- Remove table of contents, page numbers, headers/footers -- Keep chapter headings (e.g., "Chapter 1") -- Use proper punctuation for natural speech -- Remove excessive line breaks or formatting -- Ensure UTF-8 encoding for special characters - -## Processing -1. Add your book directory to Text_Input/ -2. Run the main program and select your book -3. The system will chunk the text and generate JSON metadata -4. Use the generated chunks for TTS audiobook creation \ No newline at end of file diff --git a/HF_Deploy/Text_Input/test b/HF_Deploy/Text_Input/test deleted file mode 100644 index 65db84fece1a20f286f99c082b17d14cd28f3a8f..0000000000000000000000000000000000000000 --- a/HF_Deploy/Text_Input/test +++ /dev/null @@ -1,20 +0,0 @@ -She stood alone in the hallway. The lights flickered overhead. "I don't like this," she whispered. "Too quiet. Too cold." - - - -*** - -Chapter 1 - -A crash echoed from somewhere far off. -He turned. "Was that you?" - -"No," she said. "It wasn't me." - ---- - -They moved cautiously down the corridor. Every step sounded like thunder. Each shadow seemed to breathe. - -Chapter 2 - -Something moved behind the curtain. diff --git a/HF_Deploy/app.py b/HF_Deploy/app.py deleted file mode 100644 index 78502357a91344f0180705065b816fe6ff96d4e0..0000000000000000000000000000000000000000 --- a/HF_Deploy/app.py +++ /dev/null @@ -1,523 +0,0 @@ -#!/usr/bin/env python3 -""" -Comprehensive Gradio Launcher for ChatterboxTTS -Automatically handles all requirements, installation, and setup -""" - -import sys -import os -import subprocess -import importlib -import pkg_resources -from pathlib import Path -import time - -class GradioLauncher: - def __init__(self): - self.required_packages = { - # Core packages with fallbacks - 'gradio': {'min_version': '4.0.0', 'install_name': 'gradio>=4.0.0'}, - 'torch': {'min_version': '2.0.0', 'install_name': 'torch>=2.0.0'}, - 'torchaudio': {'min_version': '2.0.0', 'install_name': 'torchaudio>=2.0.0'}, - 'transformers': {'min_version': '4.20.0', 'install_name': 'transformers>=4.20.0'}, - 'huggingface_hub': {'min_version': '0.15.0', 'install_name': 'huggingface_hub>=0.15.0'}, - 'safetensors': {'min_version': '0.3.0', 'install_name': 'safetensors>=0.3.0'}, - - # Audio processing - 'soundfile': {'min_version': '0.12.0', 'install_name': 'soundfile>=0.12.0'}, - 'librosa': {'min_version': '0.10.0', 'install_name': 'librosa>=0.10.0'}, - 'pydub': {'min_version': '0.25.0', 'install_name': 'pydub>=0.25.0'}, - - # Voice Analysis (optional but recommended) - 'parselmouth': {'min_version': '0.4.3', 'install_name': 'praat-parselmouth>=0.4.3', 'optional': True}, - 'matplotlib': {'min_version': '3.5.0', 'install_name': 'matplotlib>=3.5.0'}, - 'scipy': {'min_version': '1.8.0', 'install_name': 'scipy>=1.8.0'}, - 'numpy': {'min_version': '1.21.0', 'install_name': 'numpy>=1.21.0'}, - - # System utilities - 'psutil': {'min_version': '5.8.0', 'install_name': 'psutil>=5.8.0'}, - 'vaderSentiment': {'min_version': '3.3.0', 'install_name': 'vaderSentiment>=3.3.0'}, - } - - self.chatterbox_git_url = 'git+https://github.com/resemble-ai/chatterbox-tts.git' - self.optional_packages = ['parselmouth', 'pynvml'] - - def print_header(self): - """Print launcher header""" - print("=" * 70) - print("πŸš€ ChatterboxTTS Gradio Launcher") - print("=" * 70) - print("πŸ”§ Comprehensive setup and dependency manager") - print("πŸ“¦ Automatically installs missing requirements") - print("🌐 Launches web interface when ready") - print("-" * 70) - - def check_python_version(self): - """Check if Python version is compatible""" - print("🐍 Checking Python version...") - - version_info = sys.version_info - if version_info.major < 3 or (version_info.major == 3 and version_info.minor < 8): - print("❌ Error: Python 3.8+ required") - print(f" Current version: {version_info.major}.{version_info.minor}.{version_info.micro}") - print(" Please upgrade Python and try again") - sys.exit(1) - - print(f"βœ… Python {version_info.major}.{version_info.minor}.{version_info.micro} - Compatible") - - def check_working_directory(self): - """Verify we're in the correct directory""" - print("πŸ“ Checking working directory...") - - - if missing_files: - print(f"❌ Error: Missing required files/directories: {', '.join(missing_files)}") - print(" Please run this script from the ChatterboxTTS root directory") - print(" Expected structure:") - print(" β”œβ”€β”€ gradio_main_interface.py") - print(" β”œβ”€β”€ gradio_tabs/") - print(" β”œβ”€β”€ config/") - print(" β”œβ”€β”€ src/") - print(" └── ...") - return False - - print("βœ… Working directory structure verified") - return True - - def create_directories(self): - """Create required directories if they don't exist""" - print("πŸ“‚ Creating required directories...") - - directories = ['Voice_Samples', 'Text_Input', 'Audiobook', 'Output', 'voice_analyzer'] - created = [] - - for dir_name in directories: - dir_path = Path(dir_name) - if not dir_path.exists(): - dir_path.mkdir(parents=True, exist_ok=True) - created.append(dir_name) - - if created: - print(f"βœ… Created directories: {', '.join(created)}") - else: - print("βœ… All required directories exist") - - def check_package_installed(self, package_name): - """Check if a package is installed and get its version""" - # If we have a virtual environment, check there first - if hasattr(self, 'venv_python') and Path(self.venv_python).exists(): - try: - cmd = [self.venv_python, '-c', f''' -try: - import {package_name} - print("INSTALLED", getattr({package_name}, "__version__", "0.0.0")) -except ImportError: - print("NOT_INSTALLED") -'''] - result = subprocess.run(cmd, capture_output=True, text=True, timeout=10) - if result.returncode == 0: - output = result.stdout.strip() - if output.startswith("INSTALLED"): - version = output.split(" ", 1)[1] if " " in output else "0.0.0" - return True, version - else: - return False, None - except Exception: - pass # Fall back to local check - - # Fallback to local Python environment check - try: - if package_name == 'parselmouth': - # Special case for praat-parselmouth - import parselmouth - return True, getattr(parselmouth, '__version__', '0.0.0') - else: - module = importlib.import_module(package_name) - version = getattr(module, '__version__', '0.0.0') - return True, version - except ImportError: - try: - # Try with pkg_resources as fallback - pkg = pkg_resources.get_distribution(package_name) - return True, pkg.version - except (pkg_resources.DistributionNotFound, ImportError): - return False, None - - def compare_versions(self, current, required): - """Compare version strings""" - try: - current_parts = [int(x) for x in current.split('.')] - required_parts = [int(x) for x in required.split('.')] - - # Pad shorter version with zeros - max_len = max(len(current_parts), len(required_parts)) - current_parts.extend([0] * (max_len - len(current_parts))) - required_parts.extend([0] * (max_len - len(required_parts))) - - return current_parts >= required_parts - except (ValueError, AttributeError): - # If we can't parse versions, assume it's okay - return True - - def setup_virtual_environment(self): - """Set up virtual environment if in externally managed environment""" - venv_path = Path("venv") - - if not venv_path.exists(): - print("πŸ”§ Creating virtual environment (externally managed Python detected)...") - try: - result = subprocess.run( - [sys.executable, '-m', 'venv', 'venv'], - capture_output=True, - text=True, - timeout=60 - ) - if result.returncode != 0: - print(f" ❌ Failed to create virtual environment: {result.stderr}") - return False - print(" βœ… Virtual environment created") - except Exception as e: - print(f" ❌ Error creating virtual environment: {e}") - return False - else: - print("πŸ”§ Using existing virtual environment...") - - # Update sys.executable to use venv python - if os.name == 'nt': # Windows - self.venv_python = str(venv_path / "Scripts" / "python.exe") - self.venv_pip = str(venv_path / "Scripts" / "pip.exe") - else: # Unix/Linux/Mac - self.venv_python = str(venv_path / "bin" / "python") - self.venv_pip = str(venv_path / "bin" / "pip") - - # Verify venv python works - try: - result = subprocess.run([self.venv_python, '--version'], capture_output=True, text=True) - if result.returncode == 0: - print(f" βœ… Virtual environment Python: {result.stdout.strip()}") - return True - else: - print(" ❌ Virtual environment Python not working") - return False - except Exception as e: - print(f" ❌ Error testing virtual environment: {e}") - return False - - def install_package(self, package_spec): - """Install a package using pip (with virtual environment support)""" - try: - print(f" Installing {package_spec}...") - - # Use venv pip if available, otherwise system pip - pip_executable = getattr(self, 'venv_pip', None) - if pip_executable and Path(pip_executable).exists(): - cmd = [pip_executable, 'install', package_spec] - else: - cmd = [sys.executable, '-m', 'pip', 'install', package_spec] - - result = subprocess.run( - cmd, - capture_output=True, - text=True, - timeout=300 # 5 minute timeout - ) - - if result.returncode == 0: - print(f" βœ… Successfully installed {package_spec}") - return True - else: - print(f" ❌ Failed to install {package_spec}") - print(f" Error: {result.stderr}") - - # If we get externally-managed error, try setting up venv - if "externally-managed-environment" in result.stderr and not hasattr(self, 'venv_python'): - print(" πŸ”„ Detected externally managed environment, setting up virtual environment...") - if self.setup_virtual_environment(): - # Retry installation with venv - return self.install_package(package_spec) - - return False - - except subprocess.TimeoutExpired: - print(f" ⏰ Installation of {package_spec} timed out") - return False - except Exception as e: - print(f" ❌ Error installing {package_spec}: {str(e)}") - return False - - def check_and_install_requirements(self): - """Check and install all required packages""" - print("πŸ“¦ Checking package requirements...") - - missing_packages = [] - outdated_packages = [] - optional_missing = [] - - # Check each required package - for package_name, info in self.required_packages.items(): - is_installed, current_version = self.check_package_installed(package_name) - min_version = info['min_version'] - is_optional = info.get('optional', False) - - if not is_installed: - if is_optional: - optional_missing.append((package_name, info)) - print(f" ⚠️ Optional package missing: {package_name}") - else: - missing_packages.append((package_name, info)) - print(f" ❌ Missing required package: {package_name}") - elif current_version and not self.compare_versions(current_version, min_version): - if is_optional: - print(f" ⚠️ Optional package outdated: {package_name} {current_version} < {min_version}") - else: - outdated_packages.append((package_name, info)) - print(f" ❌ Outdated package: {package_name} {current_version} < {min_version}") - else: - status = "βœ…" if not is_optional else "πŸ”§" - print(f" {status} {package_name}: {current_version}") - - # Install missing/outdated packages - if missing_packages or outdated_packages: - print(f"\nπŸ”§ Installing {len(missing_packages + outdated_packages)} required packages...") - - for package_name, info in missing_packages + outdated_packages: - install_spec = info['install_name'] - if not self.install_package(install_spec): - print(f"❌ Critical error: Failed to install {package_name}") - return False - - # Install ChatterboxTTS if not available - print("🎀 Checking ChatterboxTTS installation...") - try: - import chatterbox - print(" βœ… ChatterboxTTS already installed") - except ImportError: - print(" πŸ“₯ Installing ChatterboxTTS from GitHub...") - if not self.install_package(self.chatterbox_git_url): - print(" ⚠️ ChatterboxTTS installation failed - some features may not work") - - # Try to install optional packages - if optional_missing: - print(f"\n🎯 Installing {len(optional_missing)} optional packages...") - for package_name, info in optional_missing: - install_spec = info['install_name'] - if self.install_package(install_spec): - print(f" βœ… Optional package {package_name} installed successfully") - else: - print(f" ⚠️ Optional package {package_name} failed - voice analysis may be limited") - - return True - - def check_gpu_availability(self): - """Check for GPU availability""" - print("πŸ–₯️ Checking GPU availability...") - - try: - import torch - if torch.cuda.is_available(): - gpu_count = torch.cuda.device_count() - gpu_name = torch.cuda.get_device_name(0) - print(f" βœ… CUDA GPU available: {gpu_name} ({gpu_count} device{'s' if gpu_count > 1 else ''})") - return True - elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): - print(" βœ… Apple Metal Performance Shaders (MPS) available") - return True - else: - print(" ⚠️ No GPU acceleration available - using CPU") - print(" πŸ’‘ For better performance, consider using a GPU-enabled environment") - return False - except Exception as e: - print(f" ❌ Error checking GPU: {str(e)}") - return False - - def verify_installation(self): - """Verify that all components can be imported""" - print("πŸ” Verifying installation...") - - critical_imports = [ - ('gradio', 'Gradio web interface'), - ('torch', 'PyTorch machine learning'), - ('transformers', 'Hugging Face transformers'), - ('librosa', 'Audio processing'), - ('soundfile', 'Audio file I/O'), - ('numpy', 'Numerical computing'), - ('matplotlib', 'Plotting and visualization') - ] - - optional_imports = [ - ('parselmouth', 'Praat voice analysis'), - ('scipy', 'Scientific computing'), - ('psutil', 'System monitoring') - ] - - failed_critical = [] - failed_optional = [] - - # Check critical imports - for module_name, description in critical_imports: - try: - importlib.import_module(module_name) - print(f" βœ… {description}") - except ImportError as e: - print(f" ❌ {description}: {str(e)}") - failed_critical.append(module_name) - - # Check optional imports - for module_name, description in optional_imports: - try: - importlib.import_module(module_name) - print(f" πŸ”§ {description}") - except ImportError: - print(f" ⚠️ {description}: Not available") - failed_optional.append(module_name) - - if failed_critical: - print(f"\n❌ Critical imports failed: {', '.join(failed_critical)}") - print(" The interface may not work properly") - return False - - if failed_optional: - print(f"\n⚠️ Optional features unavailable: {', '.join(failed_optional)}") - print(" Voice analysis features may be limited") - - print("βœ… Installation verification complete") - return True - - def launch_interface(self): - """Launch the Gradio interface""" - print("\nπŸš€ Launching ChatterboxTTS Gradio Interface...") - print("-" * 50) - - # If we're using a virtual environment, launch with venv python - if hasattr(self, 'venv_python') and Path(self.venv_python).exists(): - print("πŸ”§ Using virtual environment Python...") - try: - print("🌐 Starting web server...") - print("πŸ“± Interface will be available in your browser") - print("πŸ”— Default URL: http://localhost:7860") - - if os.getenv("RUNPOD_POD_ID"): - print("☁️ RunPod deployment detected") - elif os.getenv("COLAB_GPU"): - print("☁️ Google Colab detected - sharing link will be generated") - - print("\n" + "=" * 50) - print("πŸŽ‰ LAUNCHING CHATTERBOX TTS!") - print("=" * 50) - - # Launch using virtual environment python - subprocess.run([self.venv_python, "gradio_main_interface.py"]) - - except KeyboardInterrupt: - print("\n\nπŸ‘‹ Shutdown requested by user") - print(" Thanks for using ChatterboxTTS!") - sys.exit(0) - except Exception as e: - print(f"\n❌ Error launching with virtual environment: {str(e)}") - print(" Falling back to direct import...") - self._launch_direct() - else: - self._launch_direct() - - def _launch_direct(self): - """Launch interface by direct import""" - try: - # Import and launch - from gradio_main_interface import launch_interface - - print("🌐 Starting web server...") - print("πŸ“± Interface will be available in your browser") - print("πŸ”— Default URL: http://localhost:7860") - - if os.getenv("RUNPOD_POD_ID"): - print("☁️ RunPod deployment detected") - elif os.getenv("COLAB_GPU"): - print("☁️ Google Colab detected - sharing link will be generated") - - print("\n" + "=" * 50) - print("πŸŽ‰ LAUNCHING CHATTERBOX TTS!") - print("=" * 50) - - # Small delay for user to read messages - time.sleep(2) - - # Launch the interface - launch_interface() - - except KeyboardInterrupt: - print("\n\nπŸ‘‹ Shutdown requested by user") - print(" Thanks for using ChatterboxTTS!") - sys.exit(0) - except Exception as e: - print(f"\n❌ Error launching interface: {str(e)}") - print("\nTroubleshooting tips:") - print("1. Check that all dependencies are installed") - print("2. Verify you're in the correct directory") - if hasattr(self, 'venv_python'): - print(f"3. Try running: {self.venv_python} gradio_main_interface.py") - else: - print("3. Try running: python3 gradio_main_interface.py") - sys.exit(1) - - def run(self): - """Run the complete launcher process""" - self.print_header() - - # Step 1: Check Python version - self.check_python_version() - - # Step 2: Check working directory - if not self.check_working_directory(): - sys.exit(1) - - # Step 3: Create required directories - self.create_directories() - - # Step 4: Check and install requirements - if not self.check_and_install_requirements(): - print("\n❌ Failed to install required packages") - sys.exit(1) - - # Step 5: Check GPU availability - self.check_gpu_availability() - - # Step 6: Verify installation - if not self.verify_installation(): - print("\n⚠️ Installation verification failed") - print(" Proceeding anyway - some features may not work") - - # Step 7: Launch interface - self.launch_interface() - -def main(): - """Main entry point""" - launcher = GradioLauncher() - launcher.run() - -if __name__ == "__main__": - # Add current directory to Python path for HF Spaces - import sys - import os - sys.path.append(os.path.dirname(os.path.abspath(__file__))) - - # Fix OpenMP environment variable for HuggingFace Spaces - os.environ["OMP_NUM_THREADS"] = "1" - - # Skip launcher logic for HF Spaces, run interface directly - try: - # Import the actual Gradio interface - import gradio_main_interface - - # Create and launch the interface - demo = gradio_main_interface.create_main_interface() - demo.launch( - server_name="0.0.0.0", - server_port=7860, - share=False, - show_error=True - ) - except ImportError as e: - print(f"❌ Failed to import gradio_main_interface: {e}") - # Fallback to launcher if needed - launcher = GradioLauncher() - launcher.launch_interface() diff --git a/HF_Deploy/config/__init__.py b/HF_Deploy/config/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/HF_Deploy/config/config.py b/HF_Deploy/config/config.py deleted file mode 100644 index cf1d3a0ed877541d16e542a5571ea6dcbf174fe4..0000000000000000000000000000000000000000 --- a/HF_Deploy/config/config.py +++ /dev/null @@ -1,159 +0,0 @@ -""" -GenTTS Configuration Module -Central location for all settings, paths, and feature toggles -""" - -import os -from pathlib import Path - -# ============================================================================ -# CORE DIRECTORIES -# ============================================================================ -TEXT_INPUT_ROOT = Path("Text_Input") -AUDIOBOOK_ROOT = Path("Audiobook") -VOICE_SAMPLES_DIR = Path("Voice_Samples") - -# ============================================================================ -# TEXT PROCESSING SETTINGS -# ============================================================================ -MAX_CHUNK_WORDS = 28 -MIN_CHUNK_WORDS = 4 - -# ============================================================================ -# WORKER AND PERFORMANCE SETTINGS -# ============================================================================ -MAX_WORKERS = 2 # Keep at 2 - GPU utilization already high -TEST_MAX_WORKERS = 6 # For experimentation -USE_DYNAMIC_WORKERS = False # Toggle for testing -VRAM_SAFETY_THRESHOLD = 6.5 # GB - -# ============================================================================ -# AUDIO QUALITY SETTINGS -# ============================================================================ -ENABLE_MID_DROP_CHECK = False -ENABLE_ASR = False -ASR_WORKERS = 4 # Parallel ASR on CPU threads - -# ============================================================================ -# TTS HUM DETECTION SETTINGS -# ============================================================================ -ENABLE_HUM_DETECTION = False # Disabled for speed (re-enable if quality issues) -HUM_FREQ_MIN = 50 # Hz - Lower frequency bound for hum detection -HUM_FREQ_MAX = 200 # Hz - Upper frequency bound for hum detection -HUM_ENERGY_THRESHOLD = 0.3 # Ratio of hum energy to total energy (0.1-0.5 range) -HUM_STEADY_THRESHOLD = 0.6 # Ratio of segments with steady amplitude (0.5-0.8 range) -HUM_AMPLITUDE_MIN = 0.005 # Minimum RMS for steady hum detection -HUM_AMPLITUDE_MAX = 0.1 # Maximum RMS for steady hum detection - -# ============================================================================ -# AUDIO TRIMMING SETTINGS -# ============================================================================ -ENABLE_AUDIO_TRIMMING = True # Enable automatic audio trimming after TTS -SPEECH_ENDPOINT_THRESHOLD = 0.005 # RMS threshold to detect end of speech (more aggressive) -TRIMMING_BUFFER_MS = 50 # Small buffer after detected speech endpoint - -# ============================================================================ -# SILENCE DURATION SETTINGS (milliseconds) -# ============================================================================ -SILENCE_CHAPTER_START = 500 # Half second for chapter beginnings -SILENCE_CHAPTER_END = 800 # Longer pause before new chapter -SILENCE_SECTION_BREAK = 600 # Section transitions -SILENCE_PARAGRAPH_END = 300 # Standard paragraph breaks - -# Punctuation-specific silence settings (milliseconds) -SILENCE_COMMA = 150 # Brief pause after commas -SILENCE_SEMICOLON = 250 # Medium pause after semicolons -SILENCE_COLON = 300 # Pause after colons -SILENCE_PERIOD = 400 # Sentence end pause -SILENCE_QUESTION_MARK = 450 # Question pause (slightly longer) -SILENCE_EXCLAMATION = 400 # Exclamation pause -SILENCE_DASH = 200 # Em dash pause -SILENCE_ELLIPSIS = 350 # Ellipsis pause (suspense) -SILENCE_QUOTE_END = 250 # End of quoted speech - -# Chunk-level silence settings -ENABLE_CHUNK_END_SILENCE = True # Add silence to end of every chunk -CHUNK_END_SILENCE_MS = 200 # Default silence at end of each chunk - -# Content boundary silence settings (milliseconds) -SILENCE_PARAGRAPH_FALLBACK = 500 # Original paragraph logic fallback - -# ============================================================================ -# AUDIO NORMALIZATION SETTINGS -# ============================================================================ -ENABLE_NORMALIZATION = True # Global ON/OFF switch for normalization -NORMALIZATION_TYPE = "peak" # Options: "loudness", "peak", "simple", "none" -TARGET_LUFS = -16 # Target loudness (LUFS) for broadcast standard -TARGET_PEAK_DB = -1.5 # Target peak level (dB) to prevent clipping -TARGET_LRA = 11 # Target loudness range for consistency - -# ============================================================================ -# AUDIO PLAYBACK SPEED SETTINGS -# ============================================================================ -ATEMPO_SPEED = 0.95 # Playback speed multiplier (0.5-2.0 range, 1.0 = normal speed) - -# ============================================================================ -# ENVIRONMENT SETUP -# ============================================================================ -os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" -os.environ["TRANSFORMERS_NO_PROGRESS_BAR"] = "1" -os.environ["HF_TRANSFORMERS_NO_TQDM"] = "1" -os.environ["TORCH_HUB_DIR"] = "/tmp/torch_hub_silent" - -# ============================================================================ -# COLOR CODES FOR TERMINAL OUTPUT -# ============================================================================ -RESET = "\033[0m" -BOLD = "\033[1m" -RED = "\033[91m" -GREEN = "\033[92m" -YELLOW = "\033[93m" -CYAN = "\033[96m" - -# ============================================================================ -# TTS MODEL PARAMETERS (DEFAULTS) -# ============================================================================ -DEFAULT_EXAGGERATION = 0.4 # Emotion intensity (0.0-2.0 range) -DEFAULT_CFG_WEIGHT = 0.5 # Faithfulness to text (0.0-1.0 range) -DEFAULT_TEMPERATURE = 0.9 # Randomness/creativity (0.0-1.0 range) - -# ============================================================================ -# VADER SENTIMENT TO TTS PARAMETER MAPPING -# ============================================================================ -# These settings control how VADER sentiment analysis dynamically adjusts TTS parameters. -# The formula used is: new_param = base_param + (compound_score * sensitivity) -# The result is then clamped within the defined MIN/MAX range. - -# --- Base TTS Parameters (used as the starting point) --- -# These are the same as the main defaults, but listed here for clarity. -BASE_EXAGGERATION = DEFAULT_EXAGGERATION # Default: 1.0 -BASE_CFG_WEIGHT = DEFAULT_CFG_WEIGHT # Default: 0.7 -BASE_TEMPERATURE = DEFAULT_TEMPERATURE # Default: 0.7 - -# --- Sensitivity --- -# How much VADER's compound score affects each parameter. -# Higher values mean more dramatic changes based on sentiment. -VADER_EXAGGERATION_SENSITIVITY = 0.5 # e.g., compound of 0.8 -> 1.0 + (0.8 * 0.5) = 1.4 -VADER_CFG_WEIGHT_SENSITIVITY = -0.2 # Negative: more emotional text is less strict -VADER_TEMPERATURE_SENSITIVITY = 0.15 # More emotional text gets slightly more creative - -# --- Min/Max Clamps --- -# Hard limits to prevent extreme, undesirable audio artifacts. -TTS_PARAM_MIN_EXAGGERATION = 0.1 -TTS_PARAM_MAX_EXAGGERATION = 2.0 -TTS_PARAM_MIN_CFG_WEIGHT = 0.1 -TTS_PARAM_MAX_CFG_WEIGHT = 1.0 - -TTS_PARAM_MIN_TEMPERATURE = 0.1 -TTS_PARAM_MAX_TEMPERATURE = 5.0 - -# ============================================================================ -# BATCH PROCESSING SETTINGS -# ============================================================================ -BATCH_SIZE = 250 # Larger batches for better speed (monitor VRAM) -CLEANUP_INTERVAL = 500 # Deep cleanup every N chunks (reduced frequency for speed) - -# ============================================================================ -# FEATURE TOGGLES -# ============================================================================ -shutdown_requested = False # Global shutdown flag diff --git a/HF_Deploy/gradio_main_interface.py b/HF_Deploy/gradio_main_interface.py deleted file mode 100644 index 448e91c21b755a590bcea159c22b3a30e87e12b0..0000000000000000000000000000000000000000 --- a/HF_Deploy/gradio_main_interface.py +++ /dev/null @@ -1,148 +0,0 @@ -#!/usr/bin/env python3 -""" -ChatterboxTTS DNXS-Spokneword Gradio Main Interface -Modular web interface with separate tab modules -""" - -import gradio as gr -import sys -import os -from pathlib import Path - -# Add the current directory to Python path for imports -sys.path.append(str(Path(__file__).parent)) - -# Import tab modules -try: - from gradio_tabs.tab1_convert_book import create_convert_book_tab - TAB1_AVAILABLE = True -except ImportError as e: - print(f"⚠️ Tab 1 not available: {e}") - TAB1_AVAILABLE = False - -try: - from gradio_tabs.tab6_settings import create_settings_tab_interface - TAB6_AVAILABLE = True -except ImportError as e: - print(f"⚠️ Tab 6 (Settings) not available: {e}") - TAB6_AVAILABLE = False - -def create_placeholder_tab(tab_name, tab_number): - """Create a placeholder tab for future implementation""" - with gr.Column(): - gr.Markdown(f"# 🚧 {tab_name}") - gr.Markdown(f"*Tab {tab_number} - Coming Soon*") - gr.Markdown("This tab will be implemented in a future update.") - - gr.Button("Placeholder Button", interactive=False) - -def create_main_interface(): - """Create the main ChatterboxTTS Gradio interface with all tabs""" - - with gr.Blocks( - title="ChatterboxTTS - Complete Interface", - theme=gr.themes.Soft(), - css=""" - .gradio-container { - max-width: 1200px !important; - } - """ - ) as demo: - - # Header - gr.Markdown(""" - # 🎀 ChatterboxTTS - Complete Web Interface - *Modular audiobook generation system with advanced TTS capabilities* - """) - - # Tab interface - with gr.Tabs(): - # Tab 1: Convert Book (Working) - if TAB1_AVAILABLE: - with gr.Tab("1. Convert Book"): - create_convert_book_tab() - else: - with gr.Tab("1. Convert Book"): - create_placeholder_tab("Convert Book", 1) - - # Tab 2-10: Placeholders for now - with gr.Tab("2. File Management"): - create_placeholder_tab("File Management", 2) - - with gr.Tab("3. Voice Analysis"): - create_placeholder_tab("Voice Analysis", 3) - - with gr.Tab("4. Batch Processing"): - create_placeholder_tab("Batch Processing", 4) - - with gr.Tab("5. Audio Tools"): - create_placeholder_tab("Audio Tools", 5) - - # Tab 6: Settings (Working) - if TAB6_AVAILABLE: - with gr.Tab("6. Settings"): - create_settings_tab_interface() - else: - with gr.Tab("6. Settings"): - create_placeholder_tab("Settings", 6) - - with gr.Tab("7. Chunk Tools"): - create_placeholder_tab("Chunk Tools", 7) - - with gr.Tab("8. Voice Training"): - create_placeholder_tab("Voice Training", 8) - - with gr.Tab("9. System Monitor"): - create_placeholder_tab("System Monitor", 9) - - with gr.Tab("10. About"): - create_placeholder_tab("About", 10) - - # Footer - gr.Markdown(""" - --- - *ChatterboxTTS Gradio Interface - Modular Design* - Each tab is a separate module for easy maintenance and development. - """) - - return demo - -def launch_interface(): - """Launch the main interface""" - print("πŸš€ ChatterboxTTS - Starting Main Interface") - print("πŸ“Š Tab Status:") - print(f" Tab 1 (Convert Book): {'βœ… Available' if TAB1_AVAILABLE else '❌ Not Available'}") - print(" Tabs 2-10: 🚧 Placeholder (Coming Soon)") - print("-" * 50) - - demo = create_main_interface() - - # Launch configuration - launch_kwargs = { - 'server_name': '0.0.0.0', - 'server_port': 7860, - 'show_error': True, - 'quiet': False - } - - # Detect cloud environments - if os.getenv("RUNPOD_POD_ID"): - print("☁️ RunPod deployment detected") - launch_kwargs['share'] = True - elif os.getenv("COLAB_GPU"): - print("☁️ Google Colab detected") - launch_kwargs['share'] = True - else: - print("πŸ’» Local deployment") - launch_kwargs['share'] = False - - print(f"🌐 Interface will be available at: http://localhost:{launch_kwargs['server_port']}") - - try: - demo.launch(**launch_kwargs) - except Exception as e: - print(f"❌ Error launching interface: {e}") - raise - -if __name__ == "__main__": - launch_interface() diff --git a/HF_Deploy/gradio_tabs/__init__.py b/HF_Deploy/gradio_tabs/__init__.py deleted file mode 100644 index 15d3c56e398fc8f70a495002744e47217d3ba566..0000000000000000000000000000000000000000 --- a/HF_Deploy/gradio_tabs/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -""" -ChatterboxTTS Gradio Tabs Package -Modular tab system for the web interface -""" - -# Make this directory a Python package -__version__ = "1.0.0" \ No newline at end of file diff --git a/HF_Deploy/gradio_tabs/tab1_convert_book.py b/HF_Deploy/gradio_tabs/tab1_convert_book.py deleted file mode 100644 index 732da1b22bc9c70094be2fa887e7141bee83795e..0000000000000000000000000000000000000000 --- a/HF_Deploy/gradio_tabs/tab1_convert_book.py +++ /dev/null @@ -1,1173 +0,0 @@ -#!/usr/bin/env python3 -""" -Gradio Tab 1: Convert Book -Exact replica of PyQt5 GUI Tab 1 functionality -""" - -import gradio as gr -import os -import sys -import threading -import subprocess -import tempfile -import json -import warnings -import re -import time -from pathlib import Path -from typing import List, Dict, Any, Optional, Tuple - -# Suppress CUDA deprecation warnings -warnings.filterwarnings("ignore", category=FutureWarning, message=".*torch.backends.cuda.sdp_kernel.*") -warnings.filterwarnings("ignore", category=FutureWarning, message=".*sdp_kernel.*") - -# Import ChatterboxTTS modules and ensure all config variables are available -# First set defaults, then try to import from config -DEFAULT_EXAGGERATION = 0.4 -DEFAULT_CFG_WEIGHT = 0.5 -DEFAULT_TEMPERATURE = 0.9 -TTS_PARAM_MIN_EXAGGERATION = 0.0 -TTS_PARAM_MAX_EXAGGERATION = 2.0 -TTS_PARAM_MIN_CFG_WEIGHT = 0.0 -TTS_PARAM_MAX_CFG_WEIGHT = 1.0 -TTS_PARAM_MIN_TEMPERATURE = 0.0 -TTS_PARAM_MAX_TEMPERATURE = 5.0 -ENABLE_REGENERATION_LOOP = True -MAX_REGENERATION_ATTEMPTS = 3 -QUALITY_THRESHOLD = 0.7 -ENABLE_SENTIMENT_SMOOTHING = True -SENTIMENT_SMOOTHING_WINDOW = 3 -SENTIMENT_SMOOTHING_METHOD = "rolling" -ENABLE_MFCC_VALIDATION = False -ENABLE_OUTPUT_VALIDATION = False -SPECTRAL_ANOMALY_THRESHOLD = 0.8 -OUTPUT_VALIDATION_THRESHOLD = 0.85 - -# Try to import config and override defaults if available -try: - from config.config import * - CONFIG_AVAILABLE = True - print("βœ… Config loaded successfully") -except ImportError: - print("⚠️ Config not available - using defaults") - CONFIG_AVAILABLE = False - -# Import the actual conversion functions from GUI -try: - # We need to import the actual conversion logic - import importlib.util - gui_spec = importlib.util.spec_from_file_location("chatterbox_gui", "chatterbox_gui.py") - gui_module = importlib.util.module_from_spec(gui_spec) - # We'll access the GUI's conversion methods - GUI_AVAILABLE = True -except Exception as e: - print(f"⚠️ GUI module not available: {e}") - GUI_AVAILABLE = False - -# Global state for conversion with enhanced stats -conversion_state = { - 'running': False, - 'progress': 0, - 'status': 'Ready', - 'thread': None, - 'realtime_factor': '--', - 'vram_usage': '-- GB', - 'current_chunk': '--', - 'eta': '--', - 'elapsed': '--' -} - -def parse_progress_stats(output_line): - """Parse progress statistics from TTS engine output""" - # Look for progress pattern: "πŸŒ€ Chunk 2/13 | ⏱ Elapsed: 0:01:31 | ETA: 0:09:54 | Remaining: 0:08:23 | Realtime: 0.11x | VRAM: 3.3GB" - progress_pattern = r'πŸŒ€ Chunk (\d+)/(\d+).*?Realtime: ([\d.]+)x.*?VRAM: ([\d.]+)GB' - match = re.search(progress_pattern, output_line) - - if match: - current_chunk = int(match.group(1)) - total_chunks = int(match.group(2)) - realtime_factor = f"{match.group(3)}x" - vram_usage = f"{match.group(4)} GB" - - # Update global state - conversion_state['current_chunk'] = f"{current_chunk}/{total_chunks}" - conversion_state['realtime_factor'] = realtime_factor - conversion_state['vram_usage'] = vram_usage - conversion_state['progress'] = int((current_chunk / total_chunks) * 100) if total_chunks > 0 else 0 - - print(f"πŸ“Š Stats Updated: Chunk {current_chunk}/{total_chunks}, {realtime_factor}, {vram_usage}") - return True - else: - # Try alternative patterns in case the format is different - alt_pattern = r'Chunk (\d+)/(\d+).*?Realtime: ([\d.]+)x.*?VRAM: ([\d.]+)GB' - alt_match = re.search(alt_pattern, output_line) - if alt_match: - current_chunk = int(alt_match.group(1)) - total_chunks = int(alt_match.group(2)) - realtime_factor = f"{alt_match.group(3)}x" - vram_usage = f"{alt_match.group(4)} GB" - - conversion_state['current_chunk'] = f"{current_chunk}/{total_chunks}" - conversion_state['realtime_factor'] = realtime_factor - conversion_state['vram_usage'] = vram_usage - conversion_state['progress'] = int((current_chunk / total_chunks) * 100) if total_chunks > 0 else 0 - - print(f"πŸ“Š Stats Updated: Chunk {current_chunk}/{total_chunks}, {realtime_factor}, {vram_usage}") - return True - - return False - -def get_progress_stats(): - """Get current progress statistics for UI update""" - return ( - conversion_state['realtime_factor'], - conversion_state['vram_usage'], - conversion_state['current_chunk'], - conversion_state['progress'] - ) - -def get_book_folders(): - """Get available book folders from Text_Input directory""" - text_input_dir = Path("Text_Input") - if not text_input_dir.exists(): - return [] - - folders = [] - for item in text_input_dir.iterdir(): - if item.is_dir(): - folders.append(item.name) # Show only folder name, not full path - - return sorted(folders) - -def get_text_files_in_folder(folder_name): - """Get text files in selected book folder""" - if not folder_name: - return [] - - # Build full path from folder name - folder = Path("Text_Input") / folder_name - if not folder.exists(): - return [] - - text_files = [] - for file in folder.glob("*.txt"): - text_files.append(file.name) - - return sorted(text_files) - -def get_voice_samples(): - """Get available voice samples from Voice_Samples directory""" - voice_dir = Path("Voice_Samples") - if not voice_dir.exists(): - return [] - - voices = [] - for file in voice_dir.glob("*.wav"): - voices.append(file.name) # Show only filename, not full path - - return sorted(voices) - -def find_generated_audiobook(book_folder_path, voice_sample_path): - """Find the generated audiobook files""" - try: - book_folder = Path(book_folder_path) - voice_file = Path(voice_sample_path) - voice_name = voice_file.stem - - # Look in Output/ directory first (final audiobooks) - output_dir = Path("Output") - if output_dir.exists(): - # Look for M4B files with voice name - for m4b_file in output_dir.glob(f"*[{voice_name}]*.m4b"): - if m4b_file.exists(): - return str(m4b_file), "M4B audiobook" - - # Look for WAV files with voice name - for wav_file in output_dir.glob(f"*[{voice_name}]*.wav"): - if wav_file.exists(): - return str(wav_file), "WAV audiobook" - - # Look in Audiobook/ directory (processing output) - audiobook_dir = Path("Audiobook") / book_folder.name - if audiobook_dir.exists(): - # Look for M4B files - for m4b_file in audiobook_dir.glob(f"*[{voice_name}]*.m4b"): - if m4b_file.exists(): - return str(m4b_file), "M4B audiobook" - - # Look for WAV files - for wav_file in audiobook_dir.glob(f"*[{voice_name}]*.wav"): - if wav_file.exists(): - return str(wav_file), "WAV audiobook" - - # Look for combined files - for combined_file in audiobook_dir.glob("*_combined.*"): - if combined_file.suffix in ['.wav', '.m4b', '.mp3']: - return str(combined_file), f"{combined_file.suffix.upper()[1:]} combined audiobook" - - return None, "No audiobook found" - - except Exception as e: - print(f"Error finding audiobook: {e}") - return None, f"Error: {str(e)}" - -def run_book_conversion(book_path, text_file_path, voice_path, tts_params, quality_params, config_params): - """Run the actual book conversion - Direct call to TTS engine with progress monitoring""" - try: - # Import the real TTS engine function directly (avoid interface.py) - from modules.tts_engine import process_book_folder - - # Extract enable_asr from tts_params (matching GUI exactly) - enable_asr = tts_params.get('enable_asr', False) - - print(f"πŸš€ Starting book conversion with GUI parameters") - print(f"πŸ“– Book: {book_path}") - print(f"πŸ“„ Text file: {text_file_path}") - print(f"🎀 Voice: {voice_path}") - print(f"πŸŽ›οΈ TTS Params: {tts_params}") - print(f"πŸ”¬ Quality Params: {quality_params}") - print(f"βš™οΈ Config Params: {config_params}") - - # Set up progress callback function - def progress_callback(current_chunk, total_chunks, realtime_factor, vram_usage): - """Callback function to update progress from TTS engine""" - conversion_state['current_chunk'] = f"{current_chunk}/{total_chunks}" - conversion_state['realtime_factor'] = f"{realtime_factor}x" - conversion_state['vram_usage'] = f"{vram_usage} GB" - conversion_state['progress'] = int((current_chunk / total_chunks) * 100) if total_chunks > 0 else 0 - print(f"πŸ“Š Progress: {current_chunk}/{total_chunks} ({conversion_state['progress']}%) - {realtime_factor}x - {vram_usage}GB") - - # Add progress callback to config params - config_params['progress_callback'] = progress_callback - - # Convert string paths to Path objects (required by TTS engine) - book_dir_path = Path(book_path) - voice_path_obj = Path(voice_path) - - # Auto-detect device with fallback to CPU - import torch - if torch.cuda.is_available(): - device = "cuda" - print("βœ… Using CUDA GPU for processing") - else: - device = "cpu" - print("πŸ’» Using CPU for processing (no GPU available)") - - # Direct call to TTS engine (function only accepts: book_dir, voice_path, tts_params, device, skip_cleanup) - result = process_book_folder( - book_dir=book_dir_path, - voice_path=voice_path_obj, - tts_params=tts_params, - device=device, - skip_cleanup=False - ) - - print(f"βœ… Conversion completed successfully") - return {'success': True, 'result': result} - - except Exception as e: - print(f"❌ Conversion failed: {e}") - import traceback - traceback.print_exc() - return {'success': False, 'error': str(e)} - -def regenerate_m4b_file(selected_m4b, playback_speed): - """Regenerate M4B file with new playback speed""" - if not selected_m4b: - return "❌ Please select an M4B file first", None - - try: - print(f"πŸ”„ Regenerating M4B: {selected_m4b} at {playback_speed}x speed") - - # Import M4B regeneration tools - from tools.combine_only import apply_playback_speed_to_m4b - - # Find the M4B file path - audiobook_root = Path("Audiobook") - m4b_path = None - - for book_dir in audiobook_root.iterdir(): - if book_dir.is_dir(): - for m4b_file in book_dir.glob("*.m4b"): - if m4b_file.name == selected_m4b: - m4b_path = m4b_file - break - if m4b_path: - break - - if not m4b_path: - return "❌ M4B file not found", None - - # Create new filename with speed suffix - speed_suffix = f"_speed{playback_speed}x".replace(".", "p") - new_name = m4b_path.stem + speed_suffix + ".m4b" - output_path = m4b_path.parent / new_name - - # Apply speed change - success = apply_playback_speed_to_m4b(str(m4b_path), str(output_path), playback_speed) - - if success: - return f"βœ… Regenerated M4B at {playback_speed}x speed: {new_name}", str(output_path) - else: - return "❌ Failed to regenerate M4B", None - - except Exception as e: - print(f"❌ M4B regeneration failed: {e}") - return f"❌ Error: {str(e)}", None - -def create_convert_book_tab(): - """Create Tab 1: Convert Book with all GUI functionality""" - - with gr.Column(): - gr.Markdown("# πŸš€ Convert Book") - gr.Markdown("*Main TTS conversion functionality - matches GUI Tab 1*") - - # Main Content Layout - with gr.Row(): - # Left Column - File Uploads - with gr.Column(scale=2): - gr.Markdown("### πŸ“š Book Selection") - - # Book text file upload only - text_file_upload = gr.File( - label="πŸ“š Upload Book Text File", - file_types=[".txt"], - file_count="single", - interactive=True - ) - - gr.Markdown("### 🎀 Voice Selection") - - # Single voice upload with integrated playback - voice_file_upload = gr.File( - label="🎀 Upload Voice Sample", - file_types=[".wav", ".mp3", ".m4a"], - file_count="single", - interactive=True - ) - - # Voice sample player (becomes active after upload) - voice_audio = gr.Audio( - label="Voice Sample Preview", - interactive=False, - show_download_button=False, - visible=False - ) - - # Right Column - All Settings - with gr.Column(scale=1): - gr.Markdown("### βš™οΈ Quick Settings") - - # VADER and ASR - vader_enabled = gr.Checkbox( - label="Use VADER sentiment analysis", - value=True, - info="Adjust TTS params per chunk based on emotion" - ) - - # ASR System with intelligent model selection - with gr.Row(): - asr_enabled = gr.Checkbox( - label="🎀 Enable ASR validation", - value=False, - info="Smart quality control with automatic model selection" - ) - - # ASR Configuration (initially hidden) - with gr.Column(visible=False) as asr_config_group: - gr.Markdown("#### πŸ” ASR Configuration") - - # System analysis display - system_analysis = gr.Textbox( - label="System Analysis", - value="Click 'Analyze System' to detect capabilities", - lines=3, - interactive=False - ) - - analyze_system_btn = gr.Button( - "πŸ” Analyze System", - size="sm", - variant="secondary" - ) - - # ASR Level Selection - asr_level = gr.Radio( - label="ASR Quality Level", - choices=[ - ("🟒 SAFE - Fast processing, basic accuracy", "safe"), - ("🟑 MODERATE - Balanced speed/accuracy (recommended)", "moderate"), - ("πŸ”΄ INSANE - Best accuracy, may stress system", "insane") - ], - value="moderate", - info="Automatically selects best models for your system" - ) - - # Selected models display - selected_models = gr.Textbox( - label="Selected ASR Models", - value="Select level to see model configuration", - lines=2, - interactive=False - ) - - # Batch processing - add_to_batch = gr.Checkbox( - label="πŸ“¦ Add to batch queue", - value=False, - info="Queue for batch processing" - ) - - gr.Markdown("### πŸ”„ Regeneration Settings") - - regeneration_enabled = gr.Checkbox( - label="Enable automatic chunk regeneration", - value=ENABLE_REGENERATION_LOOP, - info="Retry failed chunks automatically" - ) - - max_attempts = gr.Slider( - label="Max Attempts", - minimum=1, maximum=10, step=1, - value=MAX_REGENERATION_ATTEMPTS - ) - - quality_threshold = gr.Slider( - label="Quality Threshold", - minimum=0.1, maximum=1.0, step=0.05, - value=QUALITY_THRESHOLD - ) - - gr.Markdown("### πŸ“Š Sentiment Smoothing") - - sentiment_smoothing = gr.Checkbox( - label="Enable sentiment smoothing", - value=ENABLE_SENTIMENT_SMOOTHING, - info="Smooth emotional transitions" - ) - - smoothing_window = gr.Slider( - label="Window Size", - minimum=1, maximum=10, step=1, - value=SENTIMENT_SMOOTHING_WINDOW - ) - - smoothing_method = gr.Dropdown( - label="Smoothing Method", - choices=["rolling", "exp_decay"], - value=SENTIMENT_SMOOTHING_METHOD - ) - - gr.Markdown("### πŸ” Advanced Detection") - - mfcc_validation = gr.Checkbox( - label="MFCC spectral analysis", - value=ENABLE_MFCC_VALIDATION, - info="Advanced audio quality detection" - ) - - output_validation = gr.Checkbox( - label="Output validation", - value=ENABLE_OUTPUT_VALIDATION, - info="Quality control clearinghouse for enabled checks" - ) - - spectral_threshold = gr.Slider( - label="Spectral Threshold", - minimum=0.1, maximum=1.0, step=0.05, - value=SPECTRAL_ANOMALY_THRESHOLD - ) - - output_threshold = gr.Slider( - label="Output Threshold", - minimum=0.1, maximum=1.0, step=0.05, - value=OUTPUT_VALIDATION_THRESHOLD - ) - - - # TTS Parameters - with gr.Row(): - with gr.Column(): - gr.Markdown("### πŸŽ›οΈ TTS Parameters") - - exaggeration = gr.Slider( - label="Exaggeration", - minimum=TTS_PARAM_MIN_EXAGGERATION, - maximum=TTS_PARAM_MAX_EXAGGERATION, - step=0.1, - value=DEFAULT_EXAGGERATION, - info="Emotional intensity" - ) - - cfg_weight = gr.Slider( - label="CFG Weight", - minimum=TTS_PARAM_MIN_CFG_WEIGHT, - maximum=TTS_PARAM_MAX_CFG_WEIGHT, - step=0.1, - value=DEFAULT_CFG_WEIGHT, - info="Text faithfulness" - ) - - temperature = gr.Slider( - label="Temperature", - minimum=TTS_PARAM_MIN_TEMPERATURE, - maximum=TTS_PARAM_MAX_TEMPERATURE, - step=0.1, - value=DEFAULT_TEMPERATURE, - info="Creativity/randomness" - ) - - with gr.Column(): - gr.Markdown("### ⚑ Advanced Sampling") - - min_p = gr.Slider( - label="Min-P", - minimum=0.0, maximum=0.5, step=0.01, - value=0.05, - info="Minimum probability threshold" - ) - - top_p = gr.Slider( - label="Top-P", - minimum=0.5, maximum=1.0, step=0.1, - value=1.0, - info="Nucleus sampling" - ) - - repetition_penalty = gr.Slider( - label="Repetition Penalty", - minimum=1.0, maximum=3.0, step=0.1, - value=2.0, - info="Reduce repetition" - ) - - gr.Markdown("### βš™οΈ Performance Settings") - - max_workers = gr.Number( - label="Max Workers", - minimum=1, maximum=8, step=1, - value=2, - info="⚠️ Only increase above 2 if CPU/GPU utilization < 70%" - ) - - # Action Buttons and Status - with gr.Row(): - with gr.Column(scale=2): - convert_btn = gr.Button( - "πŸš€ Start Conversion", - variant="primary", - size="lg", - interactive=True - ) - - # Status Display - status_display = gr.Textbox( - label="Status", - value="⏸ Ready", - interactive=False, - lines=1 - ) - - progress_display = gr.Number( - label="Progress %", - value=0, - interactive=False, - precision=0 - ) - - with gr.Column(scale=1): - gr.Markdown("### πŸ“Š Processing Stats") - - realtime_factor = gr.Textbox( - label="Realtime Factor", - value="--", - interactive=False - ) - - vram_usage = gr.Textbox( - label="VRAM Usage", - value="-- GB", - interactive=False - ) - - current_chunk = gr.Textbox( - label="Current Chunk", - value="--", - interactive=False - ) - - # Regenerate M4B Section (moved above audiobook player) - with gr.Row(): - with gr.Column(): - gr.Markdown("### πŸ”„ Regenerate M4B") - - with gr.Row(): - with gr.Column(scale=2): - m4b_file_selector = gr.Dropdown( - label="Select M4B File to Regenerate", - choices=[], - value=None, - interactive=True, - info="Choose from generated audiobook files" - ) - - with gr.Column(scale=1): - playback_speed = gr.Slider( - label="Playback Speed", - minimum=0.5, - maximum=2.0, - step=0.1, - value=1.0, - info="Speed adjustment for regeneration" - ) - - regenerate_m4b_btn = gr.Button( - "πŸ”„ Regenerate M4B", - variant="secondary", - size="lg" - ) - - # Generated Audiobook Player (simplified, play-only) - with gr.Row(): - with gr.Column(): - gr.Markdown("### 🎧 Generated Audiobook Player") - - # Audiobook file selector dropdown - audiobook_selector = gr.Dropdown( - label="Select Audiobook", - choices=[], - value=None, - interactive=True, - info="Choose from session audiobooks" - ) - - # Main audio player - play only, no upload - audio_player = gr.Audio( - label="Audiobook Player", - value=None, - interactive=False, - show_download_button=True, - show_share_button=False, - waveform_options=gr.WaveformOptions( - show_controls=True, - show_recording_waveform=False, - skip_length=10 - ) - ) - - # Event Handlers - def handle_voice_upload(voice_file): - """Handle voice file upload and show player""" - if voice_file is None: - return gr.update(value=None, visible=False) - - # Show the voice player with uploaded file - return gr.update(value=voice_file, visible=True) - - def get_session_audiobooks(): - """Get list of M4B files from current session, sorted by creation time (newest first)""" - audiobooks = [] - - # Look in Audiobook directory for M4B files - audiobook_root = Path("Audiobook") - if audiobook_root.exists(): - for book_dir in audiobook_root.iterdir(): - if book_dir.is_dir(): - # Look for M4B files in book directory - for m4b_file in book_dir.glob("*.m4b"): - # Get creation time for sorting - creation_time = m4b_file.stat().st_mtime - audiobooks.append((str(m4b_file), m4b_file.name, creation_time)) - - # Also check Output directory - output_root = Path("Output") - if output_root.exists(): - for m4b_file in output_root.glob("*.m4b"): - creation_time = m4b_file.stat().st_mtime - audiobooks.append((str(m4b_file), m4b_file.name, creation_time)) - - # Sort by creation time (newest first) - audiobooks.sort(key=lambda x: x[2], reverse=True) - - # Return just path and name (drop creation time) - return [(ab[0], ab[1]) for ab in audiobooks] - - def update_audiobook_dropdowns(latest_file=None): - """Update audiobook dropdowns - after conversion both show latest, after regeneration only playback updates""" - audiobooks = get_session_audiobooks() - choices = [ab[1] for ab in audiobooks] # Just filenames for display - - # Determine what to set as selected - if latest_file: - # Use specific file if provided - selected_file = latest_file - elif choices: - # Default to newest file (first in sorted list) - selected_file = choices[0] - else: - selected_file = None - - return ( - gr.update(choices=choices, value=selected_file), # audiobook_selector (playback) - gr.update(choices=choices, value=selected_file) # m4b_file_selector (regeneration source) - ) - - def update_audiobook_dropdowns_after_conversion(): - """Update both dropdowns to show the newest generated file after conversion""" - return update_audiobook_dropdowns() - - def update_playback_only(new_file_name): - """Update only the playback dropdown after regeneration""" - audiobooks = get_session_audiobooks() - choices = [ab[1] for ab in audiobooks] - - return ( - gr.update(choices=choices, value=new_file_name), # audiobook_selector (playback) - new file - gr.update() # m4b_file_selector (regeneration) - no change - ) - - def load_selected_audiobook(selected_audiobook): - """Load selected audiobook into player""" - if not selected_audiobook: - return None - - # Find the full path for the selected audiobook - audiobooks = get_session_audiobooks() - for full_path, filename in audiobooks: - if filename == selected_audiobook: - return full_path - - return None - - def handle_asr_toggle(asr_enabled_val): - """Show/hide ASR configuration when ASR is toggled""" - return gr.update(visible=asr_enabled_val) - - def analyze_system(): - """Analyze system capabilities and return summary""" - try: - from modules.system_detector import get_system_profile, print_system_summary, categorize_system - - profile = get_system_profile() - categories = categorize_system(profile) - - summary = f"πŸ–₯️ System Profile:\n" - summary += f"VRAM: {profile['gpu']['total_mb']:,}MB total, {profile['available_vram_after_tts']:,}MB available after TTS ({categories['vram']} class)\n" - summary += f"RAM: {profile['ram']['total_mb']:,}MB total, {profile['ram']['available_mb']:,}MB available ({categories['ram']} class)\n" - summary += f"CPU: {profile['cpu_cores']} cores ({categories['cpu']} class)" - - if not profile['has_gpu']: - summary += f"\n⚠️ No CUDA GPU detected - ASR will run on CPU only" - - return summary - - except Exception as e: - return f"❌ Error analyzing system: {str(e)}" - - def update_asr_models(asr_level_val): - """Update ASR model display based on selected level""" - try: - from modules.system_detector import get_system_profile, recommend_asr_models - - profile = get_system_profile() - recommendations = recommend_asr_models(profile) - - if asr_level_val not in recommendations: - return "❌ Invalid ASR level selected" - - config = recommendations[asr_level_val] - primary = config['primary'] - fallback = config['fallback'] - - result = f"Primary: {primary['model']} on {primary['device'].upper()}\n" - result += f"Fallback: {fallback['model']} on {fallback['device'].upper()}" - - if asr_level_val == 'insane': - result += f"\n⚠️ WARNING: INSANE mode may cause memory pressure" - - return result - - except Exception as e: - return f"❌ Error getting models: {str(e)}" - - def start_conversion(text_file_upload, voice_file_upload, - vader_val, asr_val, asr_level_val, add_to_batch_val, - regen_enabled_val, max_attempts_val, quality_thresh_val, - sentiment_smooth_val, smooth_window_val, smooth_method_val, - mfcc_val, output_val, spectral_thresh_val, output_thresh_val, - exag_val, cfg_val, temp_val, min_p_val, top_p_val, rep_penalty_val, - max_workers_val): - """Start the actual book conversion - file upload version""" - - # Validation - if not text_file_upload: - return "❌ Please upload a text file", 0, None, None - if not voice_file_upload: - return "❌ Please upload a voice sample", 0, None, None - - # Check if already running - if conversion_state['running']: - return "⚠️ Conversion already in progress", conversion_state['progress'], None, None - - try: - # Create temporary book structure from uploads - import tempfile - import shutil - from datetime import datetime - - # Generate unique book name from text file - text_filename = Path(text_file_upload).name - book_name = text_filename.replace('.txt', '').replace(' ', '_') - timestamp = datetime.now().strftime("%H%M%S") - unique_book_name = f"{book_name}_{timestamp}" - - # Create directory structure - text_input_dir = Path("Text_Input") - text_input_dir.mkdir(exist_ok=True) - - book_dir = text_input_dir / unique_book_name - book_dir.mkdir(exist_ok=True) - - # Copy uploaded files to expected locations - text_dest = book_dir / f"{unique_book_name}.txt" - shutil.copy2(text_file_upload, text_dest) - - voice_samples_dir = Path("Voice_Samples") - voice_samples_dir.mkdir(exist_ok=True) - - voice_filename = Path(voice_file_upload).name - voice_dest = voice_samples_dir / voice_filename - shutil.copy2(voice_file_upload, voice_dest) - - print(f"πŸ“ Created book structure: {book_dir}") - print(f"πŸ“„ Text file: {text_dest}") - print(f"🎀 Voice file: {voice_dest}") - - except Exception as e: - return f"❌ Error setting up files: {e}", 0, None, None - - # Build ASR configuration first - asr_config = {'enabled': False} - if asr_val: - try: - from modules.system_detector import get_system_profile, recommend_asr_models - profile = get_system_profile() - recommendations = recommend_asr_models(profile) - - if asr_level_val in recommendations: - selected_config = recommendations[asr_level_val] - primary = selected_config['primary'] - fallback = selected_config['fallback'] - - asr_config = { - 'enabled': True, - 'level': asr_level_val, - 'primary_model': primary['model'], - 'primary_device': primary['device'], - 'fallback_model': fallback['model'], - 'fallback_device': fallback['device'] - } - except Exception as e: - print(f"⚠️ Error configuring ASR: {e}") - asr_config = {'enabled': False} - - # Prepare parameters (matching GUI structure exactly) - tts_params = { - 'exaggeration': exag_val, - 'cfg_weight': cfg_val, - 'temperature': temp_val, - 'min_p': min_p_val, - 'top_p': top_p_val, - 'repetition_penalty': rep_penalty_val, - 'enable_asr': asr_config.get('enabled', False), # Match GUI pattern - 'max_workers': int(max_workers_val) # User-defined worker count - } - - quality_params = { - 'regeneration_enabled': regen_enabled_val, - 'max_attempts': max_attempts_val, - 'quality_threshold': quality_thresh_val, - 'sentiment_smoothing': sentiment_smooth_val, - 'smoothing_window': smooth_window_val, - 'smoothing_method': smooth_method_val, - 'mfcc_validation': mfcc_val, - 'output_validation': output_val, - 'spectral_threshold': spectral_thresh_val, - 'output_threshold': output_thresh_val - } - - config_params = { - 'vader_enabled': vader_val, - 'asr_enabled': asr_val, - 'asr_config': asr_config, - 'add_to_batch': add_to_batch_val - } - - # Set conversion state - conversion_state['running'] = True - conversion_state['progress'] = 0 - conversion_state['status'] = 'Starting conversion...' - conversion_state['current_book'] = book_dir.name # Track current book - - try: - # Run conversion using the modular backend in a separate thread - import threading - - def run_conversion_thread(): - try: - result = run_book_conversion( - str(book_dir), str(text_dest), str(voice_dest), - tts_params, quality_params, config_params - ) - - if result['success']: - conversion_state['status'] = 'πŸŽ‰ CONVERSION COMPLETE! M4B audiobook ready for playback.' - conversion_state['progress'] = 100 - conversion_state['auto_refresh_needed'] = True # Flag for auto-refresh - else: - conversion_state['status'] = f"❌ Conversion failed: {result.get('error', 'Unknown error')}" - conversion_state['progress'] = 0 - - except Exception as e: - conversion_state['status'] = f"❌ Error: {str(e)}" - conversion_state['progress'] = 0 - finally: - conversion_state['running'] = False - - # Start conversion thread - thread = threading.Thread(target=run_conversion_thread) - thread.start() - - # Return immediate response - user will need to refresh to see final results - return ( - "πŸš€ Conversion started in background...", - 5, # Initial progress - None, - gr.update(), - gr.update() - ) - - except Exception as e: - conversion_state['status'] = f"❌ Error: {str(e)}" - return conversion_state['status'], 0, None, gr.update(), gr.update() - finally: - conversion_state['running'] = False - - - # Connect event handlers - - # ASR event handlers - asr_enabled.change( - handle_asr_toggle, - inputs=[asr_enabled], - outputs=[asr_config_group] - ) - - analyze_system_btn.click( - analyze_system, - inputs=[], - outputs=[system_analysis] - ) - - asr_level.change( - update_asr_models, - inputs=[asr_level], - outputs=[selected_models] - ) - - # Voice upload handler - voice_file_upload.change( - handle_voice_upload, - inputs=[voice_file_upload], - outputs=[voice_audio] - ) - - # Main conversion handler - convert_btn.click( - start_conversion, - inputs=[ - text_file_upload, voice_file_upload, - vader_enabled, asr_enabled, asr_level, add_to_batch, - regeneration_enabled, max_attempts, quality_threshold, - sentiment_smoothing, smoothing_window, smoothing_method, - mfcc_validation, output_validation, spectral_threshold, output_threshold, - exaggeration, cfg_weight, temperature, min_p, top_p, repetition_penalty, - max_workers - ], - outputs=[status_display, progress_display, audio_player, audiobook_selector, m4b_file_selector] - ) - - # Audiobook selector handler - audiobook_selector.change( - load_selected_audiobook, - inputs=[audiobook_selector], - outputs=[audio_player] - ) - - # M4B regeneration handler - def handle_m4b_regeneration(selected_m4b, speed): - """Handle M4B regeneration and update player""" - status_msg, new_m4b_path = regenerate_m4b_file(selected_m4b, speed) - - if new_m4b_path: - # Load the new M4B in the player - new_file_name = Path(new_m4b_path).name - new_audio = load_selected_audiobook(new_file_name) - # Update only playback dropdown, keep regeneration dropdown on source file - audiobook_choices, m4b_choices = update_playback_only(new_file_name) - return status_msg, new_audio, audiobook_choices, m4b_choices - else: - return status_msg, None, gr.update(), gr.update() - - regenerate_m4b_btn.click( - handle_m4b_regeneration, - inputs=[m4b_file_selector, playback_speed], - outputs=[status_display, audio_player, audiobook_selector, m4b_file_selector] - ) - - # Progress monitoring with file-based approach - def get_current_stats(): - """Get current progress statistics by monitoring output files""" - try: - if conversion_state['running']: - # Look for generated audio chunks to estimate progress - book_name = conversion_state.get('current_book', 'unknown') - audiobook_root = Path("Audiobook") / book_name / "TTS" / "audio_chunks" - - if audiobook_root.exists(): - chunk_files = list(audiobook_root.glob("chunk_*.wav")) - current_chunks = len(chunk_files) - - # Try to estimate total from JSON if available - json_path = Path("Text_Input") / f"{book_name}_chunks.json" - total_chunks = 0 - if json_path.exists(): - import json - with open(json_path, 'r') as f: - data = json.load(f) - total_chunks = len(data) - - if total_chunks > 0: - progress = int((current_chunks / total_chunks) * 100) - conversion_state['progress'] = progress - conversion_state['current_chunk'] = f"{current_chunks}/{total_chunks}" - - return ( - conversion_state.get('realtime_factor', '--'), - conversion_state.get('vram_usage', '-- GB'), - f"{current_chunks}/{total_chunks}", - progress - ) - - return ( - conversion_state.get('realtime_factor', '--'), - conversion_state.get('vram_usage', '-- GB'), - conversion_state.get('current_chunk', '--'), - conversion_state.get('progress', 0) - ) - except Exception as e: - print(f"Error getting stats: {e}") - return "--", "-- GB", "--", conversion_state.get('progress', 0) - - def auto_check_completion(): - """Automatically check for completion and refresh interface""" - # First get current stats - stats = get_current_stats() - - # Check if conversion just completed and needs auto-refresh - if (not conversion_state['running'] and - conversion_state['progress'] == 100 and - conversion_state.get('auto_refresh_needed', False)): - - # Clear the auto-refresh flag - conversion_state['auto_refresh_needed'] = False - print("πŸŽ‰ Auto-detected completion! Refreshing interface...") - - # Get completion results - status, progress, audio, audiobook_choices, m4b_choices = get_status_and_results() - - # Return combined stats + completion results - return ( - stats[0], # realtime_factor - stats[1], # vram_usage - stats[2], # current_chunk - 100, # progress (completed) - status, # completion status - audio, # audio player - audiobook_choices, # audiobook dropdown - m4b_choices # m4b dropdown - ) - else: - # Return stats + current status (no completion) - return ( - stats[0], # realtime_factor - stats[1], # vram_usage - stats[2], # current_chunk - stats[3], # progress - conversion_state.get('status', '⏸ Ready'), # current status - gr.update(), # no audio update - gr.update(), # no audiobook update - gr.update() # no m4b update - ) - - def get_status_and_results(): - """Get conversion status and results after completion""" - if not conversion_state['running'] and conversion_state['progress'] == 100: - # Conversion completed, update dropdowns - audiobook_choices, m4b_choices = update_audiobook_dropdowns_after_conversion() - latest_audiobook = None - if audiobook_choices['choices']: - latest_audiobook = load_selected_audiobook(audiobook_choices['choices'][0]) - - return ( - conversion_state['status'], - conversion_state['progress'], - latest_audiobook, - audiobook_choices, - m4b_choices - ) - else: - return ( - conversion_state['status'], - conversion_state['progress'], - None, - gr.update(), - gr.update() - ) - - # Create refresh buttons - with gr.Row(): - refresh_stats_btn = gr.Button("πŸ”„ Refresh Stats", size="sm", variant="secondary") - check_completion_btn = gr.Button("πŸ“‹ Check Completion", size="sm", variant="secondary") - - # Auto-refresh timer (checks every 5 seconds during conversion) - auto_timer = gr.Timer(5.0) # 5 second interval - - refresh_stats_btn.click( - auto_check_completion, - outputs=[realtime_factor, vram_usage, current_chunk, progress_display, status_display, audio_player, audiobook_selector, m4b_file_selector] - ) - - check_completion_btn.click( - get_status_and_results, - outputs=[status_display, progress_display, audio_player, audiobook_selector, m4b_file_selector] - ) - - # Auto-timer for progress monitoring and completion detection - auto_timer.tick( - auto_check_completion, - outputs=[realtime_factor, vram_usage, current_chunk, progress_display, status_display, audio_player, audiobook_selector, m4b_file_selector] - ) - - return { - 'convert_button': convert_btn, - 'status_display': status_display, - 'progress': progress_display - } - -if __name__ == "__main__": - # Test the tab - with gr.Blocks() as demo: - create_convert_book_tab() - - demo.launch() diff --git a/HF_Deploy/modules/__init__.py b/HF_Deploy/modules/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/HF_Deploy/modules/asr_manager.py b/HF_Deploy/modules/asr_manager.py deleted file mode 100644 index 17a011f0104854ed2025eb310fcb5b55ded6c7e9..0000000000000000000000000000000000000000 --- a/HF_Deploy/modules/asr_manager.py +++ /dev/null @@ -1,233 +0,0 @@ -""" -ASR Manager Module -Centralized ASR model loading with adaptive GPU/CPU fallback and real-time VRAM monitoring -""" - -import torch -import logging -from pathlib import Path -from config.config import DEFAULT_ASR_MODEL, ASR_MODEL_VRAM_MB, ASR_MODEL_RAM_MB - -def get_real_time_vram_status(): - """Get current GPU memory usage in real-time""" - try: - if torch.cuda.is_available(): - gpu_count = torch.cuda.device_count() - if gpu_count > 0: - # Use first GPU - total_vram = torch.cuda.get_device_properties(0).total_memory - allocated_vram = torch.cuda.memory_allocated(0) - reserved_vram = torch.cuda.memory_reserved(0) - available_vram = total_vram - allocated_vram - - return { - 'total_mb': total_vram // 1024 // 1024, - 'allocated_mb': allocated_vram // 1024 // 1024, - 'reserved_mb': reserved_vram // 1024 // 1024, - 'available_mb': available_vram // 1024 // 1024, - 'has_gpu': True - } - except Exception as e: - logging.warning(f"Failed to get real-time VRAM status: {e}") - - return { - 'total_mb': 0, - 'allocated_mb': 0, - 'reserved_mb': 0, - 'available_mb': 0, - 'has_gpu': False - } - -def calculate_available_vram_for_asr(safety_buffer_mb=500): - """Calculate VRAM available for ASR with safety buffer""" - vram_status = get_real_time_vram_status() - - if not vram_status['has_gpu']: - return 0 - - # Available VRAM minus safety buffer for stability - available_with_buffer = max(0, vram_status['available_mb'] - safety_buffer_mb) - - return available_with_buffer - -def can_model_fit_gpu(model_name, available_vram_mb): - """Check if a specific ASR model can fit in available VRAM""" - required_vram = ASR_MODEL_VRAM_MB.get(model_name, 0) - return available_vram_mb >= required_vram - -def try_load_model_with_fallback(model_name, primary_device, fallback_device="cpu"): - """Try to load model on primary device, fallback to secondary if it fails""" - import whisper - - # Convert device names for whisper compatibility - def convert_device_name(device): - if device.lower() == "gpu": - return "cuda" - return device.lower() - - primary_device_whisper = convert_device_name(primary_device) - fallback_device_whisper = convert_device_name(fallback_device) - - try: - print(f"🎯 Attempting to load {model_name} on {primary_device.upper()}") - model = whisper.load_model(model_name, device=primary_device_whisper) - print(f"βœ… Successfully loaded {model_name} on {primary_device.upper()}") - return model, primary_device - - except Exception as e: - print(f"⚠️ {model_name} failed on {primary_device} ({str(e)[:50]}...)") - - if fallback_device_whisper != primary_device_whisper: - try: - print(f"πŸ”„ Trying {model_name} on {fallback_device.upper()}") - model = whisper.load_model(model_name, device=fallback_device_whisper) - print(f"βœ… Successfully loaded {model_name} on {fallback_device.upper()}") - return model, fallback_device - - except Exception as fallback_e: - print(f"❌ {model_name} also failed on {fallback_device} ({str(fallback_e)[:50]}...)") - - # Both failed - raise Exception(f"Model {model_name} failed on both {primary_device} and {fallback_device}") - -def load_asr_model_adaptive(asr_config=None): - """ - Adaptive ASR model loading with real-time VRAM checking and intelligent fallback - - Args: - asr_config: ASR configuration dict from interfaces (None for GUI fallback) - - Returns: - tuple: (asr_model, actual_device_used) or (None, None) if all loading fails - """ - print(f"πŸ” Starting adaptive ASR model loading...") - - # Get current VRAM status - vram_status = get_real_time_vram_status() - available_vram = calculate_available_vram_for_asr() - - print(f"πŸ–₯️ Real-time VRAM status:") - print(f" Total: {vram_status['total_mb']:,}MB") - print(f" Allocated: {vram_status['allocated_mb']:,}MB") - print(f" Available for ASR: {available_vram:,}MB (with 500MB safety buffer)") - - # Determine what models to try based on config - if asr_config and asr_config.get('enabled') and 'primary_model' in asr_config: - # Intelligent selection from CLI/Gradio - primary_model = asr_config['primary_model'] - primary_device = asr_config['primary_device'] - fallback_model = asr_config['fallback_model'] - fallback_device = asr_config['fallback_device'] - - print(f"🧠 Using intelligent ASR config:") - print(f" Primary: {primary_model} on {primary_device.upper()}") - print(f" Fallback: {fallback_model} on {fallback_device.upper()}") - - # Real-time VRAM check for primary model - if primary_device.lower() == 'gpu': - if not vram_status['has_gpu']: - print(f"⚠️ No GPU available, forcing CPU mode") - primary_device = 'cpu' - elif not can_model_fit_gpu(primary_model, available_vram): - required = ASR_MODEL_VRAM_MB.get(primary_model, 0) - print(f"⚠️ Insufficient VRAM for {primary_model} (need {required}MB, have {available_vram}MB)") - print(f"πŸ”„ Switching primary to CPU") - primary_device = 'cpu' - - # Try primary model - try: - return try_load_model_with_fallback(primary_model, primary_device, primary_device) - except: - # Primary failed, try fallback model - print(f"πŸ”„ Primary model failed, trying fallback configuration...") - - # Real-time VRAM check for fallback model - if fallback_device.lower() == 'gpu': - if not vram_status['has_gpu']: - print(f"⚠️ No GPU available for fallback, using CPU") - fallback_device = 'cpu' - elif not can_model_fit_gpu(fallback_model, available_vram): - required = ASR_MODEL_VRAM_MB.get(fallback_model, 0) - print(f"⚠️ Insufficient VRAM for fallback {fallback_model} (need {required}MB, have {available_vram}MB)") - fallback_device = 'cpu' - - try: - return try_load_model_with_fallback(fallback_model, fallback_device, 'cpu') - except: - print(f"❌ Both configured models failed!") - - else: - # Fallback mode for GUI or missing config - print(f"πŸ”§ Using fallback mode: {DEFAULT_ASR_MODEL}") - - # Last resort: try default model with adaptive device selection - print(f"πŸ†˜ Last resort: trying {DEFAULT_ASR_MODEL} with adaptive device selection") - - # Choose device based on real-time VRAM availability - if vram_status['has_gpu'] and can_model_fit_gpu(DEFAULT_ASR_MODEL, available_vram): - device = 'cuda' # Use cuda directly for whisper - device_display = 'GPU' - print(f"βœ… Using GPU for {DEFAULT_ASR_MODEL}") - else: - device = 'cpu' - device_display = 'CPU' - print(f"πŸ”„ Using CPU for {DEFAULT_ASR_MODEL}") - - try: - import whisper - model = whisper.load_model(DEFAULT_ASR_MODEL, device=device) - print(f"βœ… Successfully loaded {DEFAULT_ASR_MODEL} on {device_display}") - return model, device_display.lower() - except Exception as e: - print(f"❌ Critical failure: Could not load {DEFAULT_ASR_MODEL} on {device}: {e}") - - # Ultimate fallback to CPU if GPU failed - if device == 'cuda': - try: - print(f"πŸ†˜ Ultimate fallback: {DEFAULT_ASR_MODEL} on CPU") - model = whisper.load_model(DEFAULT_ASR_MODEL, device='cpu') - print(f"βœ… Successfully loaded {DEFAULT_ASR_MODEL} on CPU") - return model, 'cpu' - except Exception as cpu_e: - print(f"πŸ’€ Complete failure: {cpu_e}") - - return None, None - -def cleanup_asr_model(asr_model): - """Clean up ASR model to free memory""" - if asr_model is not None: - try: - del asr_model - if torch.cuda.is_available(): - torch.cuda.empty_cache() - print(f"🧹 ASR model cleaned up") - except Exception as e: - logging.warning(f"Failed to cleanup ASR model: {e}") - -def get_asr_memory_info(): - """Get memory information for ASR debugging""" - vram_status = get_real_time_vram_status() - available_vram = calculate_available_vram_for_asr() - - info = { - 'vram_total_mb': vram_status['total_mb'], - 'vram_allocated_mb': vram_status['allocated_mb'], - 'vram_available_for_asr_mb': available_vram, - 'has_gpu': vram_status['has_gpu'] - } - - return info - -if __name__ == "__main__": - # Test the adaptive loading - print("Testing ASR Manager...") - info = get_asr_memory_info() - print(f"Memory info: {info}") - - # Test adaptive loading - model, device = load_asr_model_adaptive() - if model: - print(f"Test successful: Model loaded on {device}") - cleanup_asr_model(model) - else: - print("Test failed: No model loaded") \ No newline at end of file diff --git a/HF_Deploy/modules/audio_processor.py b/HF_Deploy/modules/audio_processor.py deleted file mode 100644 index 6709d0f1e9aa45c418a96dd2b9e9fe6d8443150c..0000000000000000000000000000000000000000 --- a/HF_Deploy/modules/audio_processor.py +++ /dev/null @@ -1,569 +0,0 @@ -""" -Audio Processing Module -Handles audio validation, effects, cleanup, and quality control -""" - -import numpy as np -import soundfile as sf -import logging -import shutil -import re -import time -from pathlib import Path -from pydub import AudioSegment, silence -from config.config import * - -# ============================================================================ -# AUDIO QUALITY DETECTION -# ============================================================================ - -def check_audio_health(wav_path): - """Enhanced audio health checking""" - data, samplerate = sf.read(str(wav_path)) - if len(data.shape) > 1: - data = data[:, 0] # mono only - - clipping = np.mean(np.abs(data) > 0.98) - silence_ratio = np.mean(np.abs(data) < 1e-4) - rms = np.sqrt(np.mean(data**2)) - mean_abs = np.mean(np.abs(data)) - flatness = mean_abs / (rms + 1e-8) - - return { - "clipping_ratio": round(clipping, 4), - "silence_ratio": round(silence_ratio, 4), - "flatness": round(flatness, 4), - } - -def detect_tts_hum_artifact(wav_path): - """ - Detect low-frequency TTS confusion hum using configurable parameters - """ - if not ENABLE_HUM_DETECTION: - return False, {} - - data, sr = sf.read(str(wav_path)) - if data.ndim > 1: - data = data[:, 0] # Mono - - # FFT analysis for frequency content - fft = np.fft.rfft(data) - freqs = np.fft.rfftfreq(len(data), 1/sr) - - # Focus on hum frequency range (configurable at top of file) - hum_mask = (freqs >= HUM_FREQ_MIN) & (freqs <= HUM_FREQ_MAX) - hum_energy = np.sum(np.abs(fft[hum_mask])) - total_energy = np.sum(np.abs(fft)) - - # Check for sustained low-level amplitude (steady hum characteristic) - segment_size = sr // 4 # 250ms segments - segments = [data[i:i+segment_size] for i in range(0, len(data)-segment_size, segment_size)] - - steady_segments = 0 - for segment in segments: - rms = np.sqrt(np.mean(segment**2)) - if HUM_AMPLITUDE_MIN < rms < HUM_AMPLITUDE_MAX: - steady_segments += 1 - - # Calculate hum indicators using configurable thresholds - hum_ratio = hum_energy / (total_energy + 1e-10) - steady_ratio = steady_segments / len(segments) if segments else 0 - - # Detection logic using configurable thresholds - has_hum = (hum_ratio > HUM_ENERGY_THRESHOLD) and (steady_ratio > HUM_STEADY_THRESHOLD) - - if has_hum: - logging.info(f"πŸ” TTS hum detected: {wav_path.name}") - logging.info(f" Frequency range: {HUM_FREQ_MIN}-{HUM_FREQ_MAX}Hz") - logging.info(f" Hum energy ratio: {hum_ratio:.3f} (threshold: {HUM_ENERGY_THRESHOLD})") - logging.info(f" Steady segments: {steady_ratio:.3f} (threshold: {HUM_STEADY_THRESHOLD})") - - return has_hum, { - "hum_ratio": hum_ratio, - "steady_ratio": steady_ratio, - "freq_range": f"{HUM_FREQ_MIN}-{HUM_FREQ_MAX}Hz" - } - -def smart_audio_validation(wav_path): - """Comprehensive audio validation with intelligent responses""" - # Standard health check - health = check_audio_health(wav_path) - - # TTS hum detection (if enabled) - has_hum, hum_metrics = detect_tts_hum_artifact(wav_path) - - # Decision matrix - if health["clipping_ratio"] > 0.05: - return handle_problematic_chunks(wav_path, "clipping", health) - elif health["flatness"] > 0.9: - return handle_problematic_chunks(wav_path, "corrupted", health) - elif has_hum: - return handle_problematic_chunks(wav_path, "tts_hum", hum_metrics) - else: - return wav_path # Passed all checks - -def has_mid_energy_drop(wav_tensor, sr, window_ms=250, threshold_ratio=None): - """Detect mid-chunk energy drops""" - wav = wav_tensor.squeeze().numpy() - win_samples = int(sr * window_ms / 1000) - segments = [wav[i:i+win_samples] for i in range(0, len(wav) - win_samples, win_samples)] - - rms_vals = [np.sqrt(np.mean(seg**2)) for seg in segments] - rms_avg = np.mean(rms_vals) - dynamic_thresh = threshold_ratio or max(0.02, 0.1 if rms_avg < 0.01 else 0.2) - - drop_sequence = 0 - consecutive_required = 2 - - for i, rms in enumerate(rms_vals): - if i < 3: - continue - if rms < rms_avg * dynamic_thresh: - drop_sequence += 1 - if drop_sequence >= consecutive_required: - return True - else: - drop_sequence = 0 - - return False - -# ============================================================================ -# PROBLEMATIC CHUNK HANDLING -# ============================================================================ - -def handle_problematic_chunks(wav_path, issue_type, metrics): - """Handle chunks with audio issues - quarantine for review""" - quarantine_dir = wav_path.parent / "quarantine" - quarantine_dir.mkdir(exist_ok=True) - - # Move to quarantine with descriptive name - quarantine_path = quarantine_dir / f"{wav_path.stem}_{issue_type}.wav" - shutil.move(str(wav_path), str(quarantine_path)) - - # Log for user review - logging.warning(f"🚨 Quarantined {issue_type}: {wav_path.name} β†’ {quarantine_path.name}") - logging.warning(f" Metrics: {metrics}") - - return quarantine_path - -def pause_for_chunk_review(quarantine_dir): - """Pause processing to allow manual chunk review/editing with proper workflow""" - quarantined_files = list(quarantine_dir.glob("*.wav")) - - if not quarantined_files: - return # No quarantined files, continue normally - - print(f"\n⚠️ {len(quarantined_files)} chunks quarantined in: {quarantine_dir}") - print("\nQuarantined chunks:") - for qfile in quarantined_files: - print(f" πŸ“ {qfile.name}") - - print("\nπŸ”§ Options:") - print("1. Continue processing (use quarantined chunks as-is)") - print("2. Pause to manually review/edit chunks") - - while True: - choice = input("\nEnter choice [1/2]: ").strip() - if choice in ['1', '2']: - break - print("❌ Invalid choice. Please enter 1 or 2.") - - if choice == "2": - print(f"\nπŸ›‘ Processing paused for manual review.") - print(f"πŸ“‚ Quarantined chunks are in: {quarantine_dir}") - print("\nπŸ“ Instructions:") - print(" 1. Edit the audio files in the quarantine folder") - print(" 2. Keep the original filenames (chunk numbering intact)") - print(" 3. Leave edited files IN the quarantine folder") - print(" 4. Press Enter below to continue processing") - - input("\n⏸️ Press Enter when you've finished editing...") - - # Verify files still exist after user editing - edited_files = list(quarantine_dir.glob("*.wav")) - if not edited_files: - print("⚠️ No files found in quarantine folder after editing!") - return - - print(f"βœ… Found {len(edited_files)} edited files, continuing...") - - # Move all chunks back to main audio folder (whether edited or not) - moved_count = 0 - for qfile in quarantine_dir.glob("*.wav"): - # Extract original chunk name from quarantine filename - FIXED LINE: - original_name = re.sub(r'_(clipping|corrupted|tts_hum)$', '', qfile.stem) + ".wav" - main_path = qfile.parent.parent / original_name - - try: - shutil.move(str(qfile), str(main_path)) - moved_count += 1 - print(f"↩️ Restored: {original_name}") - except Exception as e: - logging.error(f"❌ Failed to restore {qfile.name}: {e}") - - print(f"\nβœ… Restored {moved_count} chunks to main audio folder") - - # Clean up empty quarantine directory - if not any(quarantine_dir.iterdir()): - quarantine_dir.rmdir() - - return moved_count - -# ============================================================================ -# AUDIO EFFECTS AND PROCESSING -# ============================================================================ - -def detect_end_artifact(wav_path, window_ms=100): - """Enhanced artifact detection""" - data, sr = sf.read(str(wav_path)) - if data.ndim > 1: - data = data[:, 0] - - win_samples = int(window_ms / 1000 * sr) - if len(data) < win_samples * 2: - return False - - end = data[-win_samples:] - middle = data[len(data)//2 : len(data)//2 + win_samples] - - rms_end = np.sqrt(np.mean(end**2)) - rms_mid = np.sqrt(np.mean(middle**2)) + 1e-10 - rms_ratio = rms_end / rms_mid - - zcr = np.mean(np.diff(np.sign(end)) != 0) - - fft = np.fft.rfft(end) - freqs = np.fft.rfftfreq(len(end), 1/sr) - low_band = fft[freqs < 150] - low_energy = np.sum(np.abs(low_band)) / (np.sum(np.abs(fft)) + 1e-10) - - logging.info(f"{GREEN}[DEBUG]{RESET} Artifact metrics - {YELLOW}RMS ratio: {rms_ratio:.3f}{RESET}, " - f"{GREEN}ZCR: {zcr:.3f}{RESET}, {CYAN}LowEnergy: {low_energy:.3f}{RESET}") - - return rms_ratio > 0.6 or zcr > 0.2 or low_energy > 0.4 - -def find_end_of_speech(wav_path, sr=16000): - """Find end of speech using Silero VAD""" - import torch - import os - - # Set environment variables to suppress PyTorch Hub verbosity - old_vars = {} - suppress_vars = { - 'TORCH_HUB_VERBOSE': '0', - 'PYTHONWARNINGS': 'ignore', - 'TF_CPP_MIN_LOG_LEVEL': '3' - } - - # Save old values and set new ones - for key, value in suppress_vars.items(): - old_vars[key] = os.environ.get(key) - os.environ[key] = value - - # Temporarily disable logging for this operation - old_level = logging.getLogger().level - logging.getLogger().setLevel(logging.ERROR) - - try: - model, utils = torch.hub.load( - repo_or_dir='snakers4/silero-vad', - model='silero_vad', - force_reload=False, - verbose=False - ) - (get_speech_timestamps, _, read_audio, _, _) = utils - - wav = read_audio(str(wav_path), sampling_rate=sr) - speech_segments = get_speech_timestamps(wav, model, sampling_rate=sr) - - if not speech_segments: - return None - - last_seg_end = speech_segments[-1]['end'] - return int(last_seg_end * 1000 / sr) - - finally: - # Restore everything - logging.getLogger().setLevel(old_level) - for key, old_value in old_vars.items(): - if old_value is None: - os.environ.pop(key, None) - else: - os.environ[key] = old_value - -def fade_out_wav(wav_path, output_path=None, fade_ms=20): - """Apply fade-out to audio""" - data, sr = sf.read(str(wav_path)) - if data.ndim > 1: - data = data[:, 0] - - fade_samples = int(sr * fade_ms / 1000) - if len(data) < fade_samples: - return - - debug_path = wav_path.parent / f"{wav_path.stem}_pre_fade.wav" - sf.write(str(debug_path), data, sr) - - fade_curve = np.linspace(1.0, 0.0, fade_samples) - data[-fade_samples:] *= fade_curve - - sf.write(str(output_path or wav_path), data, sr) - -def apply_smart_fade(wav_path): - """Apply smart fade with artifact detection""" - eos_ms = find_end_of_speech(wav_path) - - if detect_end_artifact(wav_path): - fade_out_wav(wav_path) - -def apply_smart_fade_memory(audio_segment): - """Apply smart fade with artifact detection - in memory version""" - # For now, apply a gentle fade to all audio to prevent clicks - # TODO: Add proper artifact detection for memory processing - return audio_segment.fade_out(50) # 50ms fade out - -def smart_audio_validation_memory(audio_segment, sample_rate): - """Enhanced audio validation in memory - returns (audio, is_quarantined)""" - # Basic validation - can be enhanced with hum detection later - # For now, just return the audio as-is - is_quarantined = False - - # Could add memory-based hum detection here - # is_quarantined = detect_hum_memory(audio_segment, sample_rate) - - return audio_segment, is_quarantined - -def add_contextual_silence_memory(audio_segment, boundary_type): - """Add appropriate silence based on content boundary type - in memory""" - from pydub import AudioSegment - from config.config import ( - SILENCE_CHAPTER_START, SILENCE_CHAPTER_END, SILENCE_SECTION_BREAK, SILENCE_PARAGRAPH_END, - SILENCE_COMMA, SILENCE_SEMICOLON, SILENCE_COLON, SILENCE_PERIOD, SILENCE_QUESTION_MARK, - SILENCE_EXCLAMATION, SILENCE_DASH, SILENCE_ELLIPSIS, SILENCE_QUOTE_END - ) - - silence_durations = { - # Structural boundaries - "chapter_start": SILENCE_CHAPTER_START, - "chapter_end": SILENCE_CHAPTER_END, - "section_break": SILENCE_SECTION_BREAK, - "paragraph_end": SILENCE_PARAGRAPH_END, - # Punctuation boundaries - "comma": SILENCE_COMMA, - "semicolon": SILENCE_SEMICOLON, - "colon": SILENCE_COLON, - "period": SILENCE_PERIOD, - "question_mark": SILENCE_QUESTION_MARK, - "exclamation": SILENCE_EXCLAMATION, - "dash": SILENCE_DASH, - "ellipsis": SILENCE_ELLIPSIS, - "quote_end": SILENCE_QUOTE_END, - } - - if boundary_type in silence_durations: - duration = silence_durations[boundary_type] - silence_segment = AudioSegment.silent(duration=duration) - return audio_segment + silence_segment - - return audio_segment - -def smart_fade_out(wav_path, silence_thresh_db=-40, min_silence_len=300): - """Smart fade-out for natural audio endings""" - audio = AudioSegment.from_wav(wav_path) - tail_window_ms = 2000 - - if len(audio) < tail_window_ms: - logging.info(f"⚠️ {YELLOW}Skipping fade: {wav_path.name} too short ({len(audio)}ms < {tail_window_ms}ms){RESET}") - return - - tail = audio[-tail_window_ms:] - silent_ranges = silence.detect_silence(tail, min_silence_len=min_silence_len, silence_thresh=silence_thresh_db) - - min_tail_energy = max(tail.get_array_of_samples()) - if not silent_ranges or min_tail_energy > audio.max_possible_amplitude * 0.1: - logging.info(f"βœ… {GREEN}No fade needed for {wav_path.name} (no valid trailing silence){RESET}") - return - - fade_start_ms = silent_ranges[0][0] - fade_length_ms = tail_window_ms - fade_start_ms - - if fade_length_ms < 100: - logging.info(f"βœ… {GREEN}No fade needed for {wav_path.name} (fade too short: {fade_length_ms}ms){RESET}") - return - - fade_start_point = silent_ranges[0][0] - logging.info(f"⚠️ {RED}Fading tail of {wav_path.name} from {fade_start_point}ms to end{RESET}") - faded = audio[:fade_start_point] + audio[fade_start_point:].fade_out(duration=fade_length_ms) - faded.export(wav_path, format="wav") - -# ============================================================================ -# AUDIO TRIMMING -# ============================================================================ - -def trim_audio_endpoint(audio_segment, threshold=None, buffer_ms=None): - """ - Trim audio to the detected end of speech using RMS energy analysis. - - Args: - audio_segment: pydub AudioSegment object - threshold: RMS threshold for speech detection (from config if None) - buffer_ms: Buffer to add after detected endpoint (from config if None) - - Returns: - Trimmed AudioSegment - """ - if threshold is None: - threshold = SPEECH_ENDPOINT_THRESHOLD - if buffer_ms is None: - buffer_ms = TRIMMING_BUFFER_MS - - # Convert to numpy array for analysis - samples = np.array(audio_segment.get_array_of_samples()) - if audio_segment.channels == 2: - samples = samples.reshape((-1, 2)).mean(axis=1) - - # Normalize samples - samples = samples.astype(np.float32) / audio_segment.max_possible_amplitude - - # Calculate RMS in sliding windows (50ms windows) - window_size = int(0.05 * audio_segment.frame_rate) # 50ms - rms_values = [] - - for i in range(0, len(samples) - window_size, window_size // 2): - window = samples[i:i + window_size] - rms = np.sqrt(np.mean(window ** 2)) - rms_values.append(rms) - - # Find actual end of speech using energy decay detection - speech_end_idx = 0 # Default to beginning if no speech found - - # Look for a significant and sustained drop in energy - # Scan backwards to find where energy consistently stays above a higher threshold - strong_speech_threshold = threshold * 3 # 3x threshold for "real" speech - - for i in range(len(rms_values) - 1, -1, -1): - if rms_values[i] > strong_speech_threshold: - # Found strong speech, check if it's sustained - # Look forward to see if energy drops and stays low - sustained_speech = True - windows_ahead = min(10, len(rms_values) - i) # Look ahead up to 10 windows (250ms) - - # Check if most of the next windows have reasonable speech levels - speech_count = 0 - for j in range(i, min(i + windows_ahead, len(rms_values))): - if rms_values[j] > threshold: - speech_count += 1 - - # If this looks like the end of sustained speech content - if speech_count >= max(1, windows_ahead * 0.3): # At least 30% speech in next windows - speech_end_idx = i - break - - # If no strong speech found, fall back to simple threshold method but be conservative - if speech_end_idx == 0: - for i in range(len(rms_values) - 1, -1, -1): - if rms_values[i] > threshold * 2: # Use 2x threshold for fallback - speech_end_idx = i - break - - # Convert back to milliseconds and add buffer - # Convert window index to sample position, then to milliseconds - sample_position = speech_end_idx * (window_size // 2) - speech_end_ms = int(sample_position * 1000 / audio_segment.frame_rate) - trim_point_ms = min(speech_end_ms + buffer_ms, len(audio_segment)) - - return audio_segment[:trim_point_ms] - -def process_audio_with_trimming_and_silence(audio_segment, boundary_type, enable_trimming=None): - """ - Complete audio processing: trim to speech endpoint + add punctuation-based silence. - - Args: - audio_segment: pydub AudioSegment object - boundary_type: Boundary type from text processing - enable_trimming: Whether to trim audio (from config if None) - - Returns: - Processed AudioSegment with trimming and appropriate silence - """ - if enable_trimming is None: - enable_trimming = ENABLE_AUDIO_TRIMMING - - processed_audio = audio_segment - - # Step 1: Trim to speech endpoint if enabled - if enable_trimming: - processed_audio = trim_audio_endpoint(processed_audio) - - # Step 2: Add punctuation-appropriate silence - processed_audio = add_contextual_silence_memory(processed_audio, boundary_type) - - return processed_audio - -# ============================================================================ -# SILENCE AND CONTEXTUAL AUDIO -# ============================================================================ - -def add_contextual_silence(wav_path, boundary_type): - """Add appropriate silence based on content boundary type""" - silence_durations = { - # Structural boundaries - "chapter_start": SILENCE_CHAPTER_START, - "chapter_end": SILENCE_CHAPTER_END, - "section_break": SILENCE_SECTION_BREAK, - "paragraph_end": SILENCE_PARAGRAPH_END, - # Punctuation boundaries - "comma": SILENCE_COMMA, - "semicolon": SILENCE_SEMICOLON, - "colon": SILENCE_COLON, - "period": SILENCE_PERIOD, - "question_mark": SILENCE_QUESTION_MARK, - "exclamation": SILENCE_EXCLAMATION, - "dash": SILENCE_DASH, - "ellipsis": SILENCE_ELLIPSIS, - "quote_end": SILENCE_QUOTE_END, - } - - if boundary_type in silence_durations: - duration = silence_durations[boundary_type] - audio = AudioSegment.from_wav(wav_path) - silence_segment = AudioSegment.silent(duration=duration) - extended_audio = audio + silence_segment - extended_audio.export(wav_path, format="wav") - - logging.info(f"πŸ”‡ Added {duration}ms silence for {boundary_type}: {wav_path.name}") - -def add_chunk_end_silence(wav_path): - """Add configurable silence to end of chunk if enabled""" - if not ENABLE_CHUNK_END_SILENCE or CHUNK_END_SILENCE_MS <= 0: - return - - try: - audio = AudioSegment.from_wav(wav_path) - silence_segment = AudioSegment.silent(duration=CHUNK_END_SILENCE_MS) - audio_with_silence = audio + silence_segment - audio_with_silence.export(wav_path, format="wav") - logging.info(f"βž• Added {CHUNK_END_SILENCE_MS}ms end silence to {wav_path.name}") - except Exception as e: - logging.warning(f"⚠️ Failed to add end silence to {wav_path.name}: {e}") - -# ============================================================================ -# AUDIO UTILITY FUNCTIONS -# ============================================================================ - -def get_wav_duration(wav_path): - """Get WAV file duration""" - import wave - with wave.open(str(wav_path), 'rb') as wf: - frames = wf.getnframes() - rate = wf.getframerate() - return frames / float(rate) - -def get_chunk_audio_duration(wav_path): - """Get actual audio duration from WAV file""" - try: - data, sr = sf.read(str(wav_path)) - return len(data) / sr - except: - # Fallback to wave module - return get_wav_duration(wav_path) diff --git a/HF_Deploy/modules/batch_processor.py b/HF_Deploy/modules/batch_processor.py deleted file mode 100644 index 86e74e9a8155135044ee542af75240e241343c91..0000000000000000000000000000000000000000 --- a/HF_Deploy/modules/batch_processor.py +++ /dev/null @@ -1,31 +0,0 @@ -""" -Batch Processing Module -Handles multi-book batch processing operations -""" - -import torch -from modules.tts_engine import process_book_folder - -def pipeline_book_processing(book_queue): - """Process multiple books in sequence""" - completed_books = [] - device = "cuda" if torch.cuda.is_available() else "cpu" - - for book_info in book_queue: - book_dir = book_info['book_dir'] - voice_path = book_info['voice_path'] - tts_params = book_info['tts_params'] - - print(f"\n🎯 Processing: {book_dir.name}") - - try: - result = process_book_folder(book_dir, voice_path, tts_params, device) - if result[0]: # Check if final_m4b_path exists - completed_books.append(book_info) - print(f"βœ… Completed: {book_dir.name}") - else: - print(f"❌ Failed: {book_dir.name}") - except Exception as e: - print(f"❌ Error processing {book_dir.name}: {e}") - - return completed_books \ No newline at end of file diff --git a/HF_Deploy/modules/file_manager.py b/HF_Deploy/modules/file_manager.py deleted file mode 100644 index d4c33fd4fed1a2ab769cce0e919e76423b5cbea1..0000000000000000000000000000000000000000 --- a/HF_Deploy/modules/file_manager.py +++ /dev/null @@ -1,431 +0,0 @@ -""" -File Manager Module -Handles I/O operations, M4B conversion, metadata, and FFmpeg operations -""" - -import subprocess -import soundfile as sf -import os -import re -import time -import logging -from pathlib import Path -from config.config import * - -# ============================================================================ -# VOICE SAMPLE MANAGEMENT -# ============================================================================ - -def list_voice_samples(): - """List available voice samples""" - return sorted(VOICE_SAMPLES_DIR.glob("*.wav")) - -def ensure_voice_sample_compatibility(input_path, output_dir=None): - """Ensure voice sample is compatible with TTS (24kHz mono)""" - input_path = str(input_path) - ext = os.path.splitext(input_path)[1].lower() - basename = os.path.splitext(os.path.basename(input_path))[0] - output_dir = output_dir or os.path.dirname(input_path) - output_path = os.path.join(output_dir, basename + "_ttsready.wav") - - try: - info = sf.info(input_path) - if (ext == '.wav' and info.samplerate == 24000 and info.channels == 1): - return input_path - except Exception: - pass - - cmd = [ - "ffmpeg", "-y", - "-i", input_path, - "-ar", "24000", - "-ac", "1", - output_path - ] - subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - return output_path - -# ============================================================================ -# FFMPEG OPERATIONS -# ============================================================================ - -def run_ffmpeg(cmd): - """Run FFmpeg command with error handling""" - try: - subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - except subprocess.CalledProcessError as e: - logging.info(f"FFmpeg command failed: {' '.join(cmd)}") - logging.info(f"Error: {e}") - subprocess.run(cmd) - raise - -# ============================================================================ -# M4B CONVERSION WITH NORMALIZATION -# ============================================================================ - -def convert_to_m4b_with_peak_normalization(wav_path, temp_m4b_path, target_db=-3.0): - """Convert WAV to M4B with peak normalization""" - print("πŸš€ Converting to m4b with peak normalization...") - - # Build audio filter chain - audio_filters = [f"loudnorm=I=-16:TP={target_db}:LRA=11"] - if ATEMPO_SPEED != 1.0: - audio_filters.append(f"atempo={ATEMPO_SPEED}") - - cmd = [ - "ffmpeg", "-y", - "-i", str(wav_path), - "-af", ",".join(audio_filters), - "-c:a", "aac", - str(temp_m4b_path) - ] - - start_time = time.time() - process = subprocess.Popen(cmd, stderr=subprocess.PIPE, text=True) - - audio_secs = 0.0 - for line in process.stderr: - match = re.search(r"time=(\d{2}):(\d{2}):(\d{2})\.(\d{2})", line) - if match: - h, m, s, ms = map(int, match.groups()) - audio_secs = h * 3600 + m * 60 + s + ms / 100 - elapsed = time.time() - start_time - factor = audio_secs / elapsed if elapsed > 0 else 0.0 - print(f"πŸ“Ό FFmpeg (normalizing): {match.group(0)} | {factor:.2f}x realtime", end='\r') - - process.wait() - print("\nβœ… Conversion with normalization complete.") - -def convert_to_m4b_with_loudness_normalization(wav_path, temp_m4b_path): - """Convert WAV to M4B with two-pass loudness normalization""" - import json - - print("πŸš€ Converting to m4b with loudness normalization...") - - # Step 1: Analyze audio loudness - print("πŸ“Š Analyzing audio loudness...") - analyze_cmd = [ - "ffmpeg", "-y", - "-i", str(wav_path), - "-af", "loudnorm=I=-16:TP=-1.5:LRA=11:print_format=json", - "-f", "null", "-" - ] - - result = subprocess.run(analyze_cmd, capture_output=True, text=True) - - # Extract loudness measurements from stderr - loudness_data = None - for line in result.stderr.split('\n'): - if line.strip().startswith('{'): - try: - loudness_data = json.loads(line.strip()) - break - except: - continue - - if not loudness_data: - print("⚠️ Could not analyze loudness, falling back to single-pass...") - return convert_to_m4b_with_peak_normalization(wav_path, temp_m4b_path) - - # Step 2: Apply normalization with measured values - print("πŸ”§ Applying normalization...") - - # Build audio filter chain - audio_filters = [f"loudnorm=I=-16:TP=-1.5:LRA=11:measured_I={loudness_data['input_i']}:measured_LRA={loudness_data['input_lra']}:measured_TP={loudness_data['input_tp']}:measured_thresh={loudness_data['input_thresh']}:offset={loudness_data['target_offset']}:linear=true:print_format=summary"] - if ATEMPO_SPEED != 1.0: - audio_filters.append(f"atempo={ATEMPO_SPEED}") - - cmd = [ - "ffmpeg", "-y", - "-i", str(wav_path), - "-af", ",".join(audio_filters), - "-c:a", "aac", - str(temp_m4b_path) - ] - - start_time = time.time() - process = subprocess.Popen(cmd, stderr=subprocess.PIPE, text=True) - - audio_secs = 0.0 - for line in process.stderr: - match = re.search(r"time=(\d{2}):(\d{2}):(\d{2})\.(\d{2})", line) - if match: - h, m, s, ms = map(int, match.groups()) - audio_secs = h * 3600 + m * 60 + s + ms / 100 - elapsed = time.time() - start_time - factor = audio_secs / elapsed if elapsed > 0 else 0.0 - print(f"πŸ“Ό FFmpeg (normalizing): {match.group(0)} | {factor:.2f}x realtime", end='\r') - - process.wait() - print("\nβœ… Two-pass normalization complete.") - -def convert_to_m4b_with_simple_normalization(wav_path, temp_m4b_path, target_db=-6.0): - """Convert WAV to M4B with simple peak normalization""" - print("πŸš€ Converting to m4b with simple normalization...") - - # Build audio filter chain - audio_filters = [f"volume={target_db}dB"] - if ATEMPO_SPEED != 1.0: - audio_filters.append(f"atempo={ATEMPO_SPEED}") - - cmd = [ - "ffmpeg", "-y", - "-i", str(wav_path), - "-af", ",".join(audio_filters), - "-c:a", "aac", - str(temp_m4b_path) - ] - - start_time = time.time() - process = subprocess.Popen(cmd, stderr=subprocess.PIPE, text=True) - - audio_secs = 0.0 - for line in process.stderr: - match = re.search(r"time=(\d{2}):(\d{2}):(\d{2})\.(\d{2})", line) - if match: - h, m, s, ms = map(int, match.groups()) - audio_secs = h * 3600 + m * 60 + s + ms / 100 - elapsed = time.time() - start_time - factor = audio_secs / elapsed if elapsed > 0 else 0.0 - print(f"πŸ“Ό FFmpeg (normalizing): {match.group(0)} | {factor:.2f}x realtime", end='\r') - - process.wait() - print("\nβœ… Simple normalization complete.") - -def convert_to_m4b(wav_path, temp_m4b_path): - """Convert WAV to M4B with configurable normalization""" - if not ENABLE_NORMALIZATION or NORMALIZATION_TYPE == "none": - # Original function without normalization - print("πŸš€ Converting to m4b...") - - # Build audio filter for atempo if needed - audio_filter = [] - if ATEMPO_SPEED != 1.0: - audio_filter = ["-filter:a", f"atempo={ATEMPO_SPEED}"] - - cmd = [ - "ffmpeg", "-y", - "-i", str(wav_path) - ] + audio_filter + [ - "-c:a", "aac", - str(temp_m4b_path) - ] - - elif NORMALIZATION_TYPE == "loudness": - # EBU R128 loudness normalization (recommended for audiobooks) - return convert_to_m4b_with_loudness_normalization(wav_path, temp_m4b_path) - - elif NORMALIZATION_TYPE == "peak": - # Peak normalization - return convert_to_m4b_with_peak_normalization(wav_path, temp_m4b_path, TARGET_PEAK_DB) - - elif NORMALIZATION_TYPE == "simple": - # Simple volume adjustment - return convert_to_m4b_with_simple_normalization(wav_path, temp_m4b_path, TARGET_PEAK_DB) - - else: - # Fallback to no normalization - # Build audio filter for atempo if needed - audio_filter = [] - if ATEMPO_SPEED != 1.0: - audio_filter = ["-filter:a", f"atempo={ATEMPO_SPEED}"] - - cmd = [ - "ffmpeg", "-y", - "-i", str(wav_path) - ] + audio_filter + [ - "-c:a", "aac", - str(temp_m4b_path) - ] - - # Run the conversion (if not handled by specialized functions above) - start_time = time.time() - process = subprocess.Popen(cmd, stderr=subprocess.PIPE, text=True) - - audio_secs = 0.0 - for line in process.stderr: - match = re.search(r"time=(\d{2}):(\d{2}):(\d{2})\.(\d{2})", line) - if match: - h, m, s, ms = map(int, match.groups()) - audio_secs = h * 3600 + m * 60 + s + ms / 100 - elapsed = time.time() - start_time - factor = audio_secs / elapsed if elapsed > 0 else 0.0 - print(f"πŸ“Ό FFmpeg: {match.group(0)} | {factor:.2f}x realtime", end='\r') - - process.wait() - print("\nβœ… Conversion complete.") - -def add_metadata_to_m4b(temp_m4b_path, final_m4b_path, cover_path=None, nfo_path=None): - """Add metadata and cover to M4B""" - cmd = ["ffmpeg", "-y", "-i", str(temp_m4b_path)] - - if cover_path and cover_path.exists(): - cmd.extend(["-i", str(cover_path), "-map", "0", "-map", "1", "-c", "copy", "-disposition:v:0", "attached_pic"]) - else: - cmd.extend(["-map", "0", "-c", "copy"]) - - if nfo_path and nfo_path.exists(): - with open(nfo_path, 'r', encoding='utf-8') as f: - for line in f: - if ':' in line: - key, val = line.strip().split(':', 1) - cmd.extend(["-metadata", f"{key.strip()}={val.strip()}"]) - - cmd.append(str(final_m4b_path)) - run_ffmpeg(cmd) - temp_m4b_path.unlink(missing_ok=True) - -# ============================================================================ -# FILE UTILITIES -# ============================================================================ - -def chunk_sort_key(f): - """Extracts the chunk number for natural sorting""" - m = re.match(r"chunk_(\d+)\.wav", f.name) - return int(m.group(1)) if m else 0 - -def create_concat_file(chunk_paths, output_path): - """Create FFmpeg concat file for audio chunks""" - with open(output_path, 'w') as f: - for p in chunk_paths: - # Use absolute path to ensure FFmpeg can find the files - f.write(f"file '{str(p.resolve())}'\n") - - logging.info(f"concat.txt written with {len(chunk_paths)} chunks.") - return output_path - -def cleanup_temp_files(directory, patterns): - """Clean up temporary files matching patterns""" - files_cleaned = 0 - for pattern in patterns: - for temp_file in directory.glob(pattern): - temp_file.unlink(missing_ok=True) - files_cleaned += 1 - - return files_cleaned - -# ============================================================================ -# DIRECTORY MANAGEMENT -# ============================================================================ - -def setup_book_directories(book_dir): - """Set up directory structure for book processing""" - basename = book_dir.name - output_root = AUDIOBOOK_ROOT / basename - tts_dir = output_root / "TTS" - text_chunks_dir = tts_dir / "text_chunks" - audio_chunks_dir = tts_dir / "audio_chunks" - - # Create directories - for d in [output_root, tts_dir, text_chunks_dir, audio_chunks_dir]: - d.mkdir(parents=True, exist_ok=True) - - return output_root, tts_dir, text_chunks_dir, audio_chunks_dir - -def find_book_files(book_dir): - """Find text files, cover, and metadata for a book""" - text_files = sorted(book_dir.glob("*.txt")) - nfo_file = book_dir / "book.nfo" - cover_jpg = book_dir / "cover.jpg" - cover_png = book_dir / "cover.png" - cover_file = cover_jpg if cover_jpg.exists() else cover_png if cover_png.exists() else None - - return { - 'text': text_files[0] if text_files else None, - 'cover': cover_file, - 'nfo': nfo_file if nfo_file.exists() else None - } - -# ============================================================================ -# AUDIO FILE OPERATIONS -# ============================================================================ - -def combine_audio_chunks(chunk_paths, output_path): - """Combine audio chunks into single file using FFmpeg""" - concat_list_path = output_path.parent / "concat.txt" - create_concat_file(chunk_paths, concat_list_path) - - run_ffmpeg([ - "ffmpeg", "-y", "-f", "concat", "-safe", "0", - "-i", str(concat_list_path.resolve()), - "-c", "copy", str(output_path.resolve()) - ]) - - return output_path - -def get_audio_files_in_directory(directory, pattern="chunk_*.wav"): - """Get sorted list of audio files matching pattern""" - chunk_paths = sorted([f for f in directory.glob(pattern) - if re.fullmatch(r'chunk_\d{3,}\.wav', f.name)], - key=chunk_sort_key) - return chunk_paths - -# ============================================================================ -# VALIDATION AND VERIFICATION -# ============================================================================ - -def verify_audio_file(wav_path): - """Verify audio file is valid and readable""" - try: - info = sf.info(str(wav_path)) - return info.frames > 0 and info.samplerate > 0 - except Exception as e: - logging.error(f"Invalid audio file {wav_path}: {e}") - return False - -def verify_chunk_completeness(audio_chunks_dir, expected_count): - """Verify all expected chunks exist and are valid""" - missing_chunks = [] - invalid_chunks = [] - - for i in range(1, expected_count + 1): - chunk_path = audio_chunks_dir / f"chunk_{i:05}.wav" - - if not chunk_path.exists(): - missing_chunks.append(i) - elif not verify_audio_file(chunk_path): - invalid_chunks.append(i) - - return missing_chunks, invalid_chunks - -# ============================================================================ -# EXPORT AND IMPORT FUNCTIONS -# ============================================================================ - -def export_processing_log(output_dir, processing_info): - """Export comprehensive processing log""" - log_path = output_dir / "processing_complete.log" - - with open(log_path, 'w', encoding='utf-8') as f: - f.write("GenTTS Processing Complete\n") - f.write("=" * 50 + "\n\n") - - for key, value in processing_info.items(): - f.write(f"{key}: {value}\n") - - return log_path - -def save_chunk_info(text_chunks_dir, chunks_info): - """Save chunk information for debugging/resume""" - info_path = text_chunks_dir / "chunks_info.json" - - import json - with open(info_path, 'w', encoding='utf-8') as f: - json.dump(chunks_info, f, indent=2, ensure_ascii=False) - - return info_path - -def load_chunk_info(text_chunks_dir): - """Load chunk information if available""" - info_path = text_chunks_dir / "chunks_info.json" - - if not info_path.exists(): - return None - - import json - try: - with open(info_path, 'r', encoding='utf-8') as f: - return json.load(f) - except Exception as e: - logging.warning(f"Could not load chunk info: {e}") - return None diff --git a/HF_Deploy/modules/gui_json_generator.py b/HF_Deploy/modules/gui_json_generator.py deleted file mode 100644 index da562b23a2041521e3053bf3472653ac1c64f3d5..0000000000000000000000000000000000000000 --- a/HF_Deploy/modules/gui_json_generator.py +++ /dev/null @@ -1,217 +0,0 @@ -#!/usr/bin/env python3 -""" -GUI JSON Audio Generation Module - -This module provides JSON-to-audiobook generation specifically for GUI use. -It's based on utils/generate_from_json.py but adapted for GUI integration. -""" - -import torch -from pathlib import Path -import sys -from concurrent.futures import ThreadPoolExecutor, as_completed -import time -from datetime import timedelta - -# Add project root to path to allow module imports -project_root = Path(__file__).parent.parent -sys.path.append(str(project_root)) - -from config.config import * -from modules.tts_engine import load_optimized_model, process_one_chunk -from modules.file_manager import setup_book_directories, list_voice_samples, ensure_voice_sample_compatibility -from wrapper.chunk_loader import load_chunks -from src.chatterbox.tts import punc_norm -from modules.progress_tracker import log_chunk_progress, log_run -from tools.combine_only import combine_audio_for_book - - -def generate_audiobook_from_json(json_path, voice_name, temp_setting=None): - """ - Generate complete audiobook from JSON chunks file. - - Args: - json_path (str): Path to the JSON chunks file - voice_name (str): Name of the voice to use (without .wav extension) - temp_setting (float, optional): Temperature override for TTS - - Returns: - tuple: (success: bool, message: str, audiobook_path: str or None) - """ - try: - print(f"🎡 GUI JSON Generator: Starting audiobook generation") - print(f"πŸ“„ JSON file: {json_path}") - print(f"🎀 Voice: {voice_name}") - if temp_setting: - print(f"🌑️ Temperature override: {temp_setting}") - - # Determine book name from JSON path - json_file = Path(json_path) - - # Try to extract book name from path structure - if 'Audiobook' in json_file.parts: - audiobook_index = json_file.parts.index('Audiobook') - if audiobook_index + 1 < len(json_file.parts): - book_name = json_file.parts[audiobook_index + 1] - print(f"πŸ“š Detected book name from path: {book_name}") - else: - raise Exception("Cannot determine book name from Audiobook path") - elif json_file.stem.endswith('_chunks'): - book_name = json_file.stem.replace('_chunks', '') - print(f"πŸ“š Detected book name from filename: {book_name}") - else: - book_name = json_file.stem - print(f"πŸ“š Using filename as book name: {book_name}") - - # Load JSON chunks (READ ONLY - never modify the original) - print(f"πŸ“– Loading chunks from: {json_path}") - all_chunks = load_chunks(str(json_path)) - print(f"βœ… Found {len(all_chunks)} chunks.") - - # Find voice file - voice_files = list_voice_samples() - voice_path = None - for voice_file in voice_files: - if voice_file.stem == voice_name: - voice_path = voice_file - break - - if not voice_path: - available_voices = [vf.stem for vf in voice_files] - return False, f"Voice '{voice_name}' not found. Available: {available_voices}", None - - # Ensure voice compatibility - voice_path = ensure_voice_sample_compatibility(voice_path) - if isinstance(voice_path, str): - voice_path = Path(voice_path) - - print(f"🎀 Using voice: {voice_path.name}") - - # Setup device - if torch.cuda.is_available(): - device = "cuda" - elif torch.backends.mps.is_available(): - device = "mps" - else: - device = "cpu" - - print(f"πŸš€ Using device: {device}") - - # Load TTS model - print(f"πŸ€– Loading TTS model...") - model = load_optimized_model(device) - - # Prepare voice conditionals - print(f"🎀 Preparing voice conditionals...") - model.prepare_conditionals(voice_path) - - # Setup output directories - output_root = AUDIOBOOK_ROOT / book_name - tts_dir = output_root / "TTS" - text_chunks_dir = tts_dir / "text_chunks" - audio_chunks_dir = tts_dir / "audio_chunks" - - # Create directories - for dir_path in [output_root, tts_dir, text_chunks_dir, audio_chunks_dir]: - dir_path.mkdir(parents=True, exist_ok=True) - - # Clean existing audio chunks - print("🧹 Clearing old audio chunks...") - for wav_file in audio_chunks_dir.glob("*.wav"): - wav_file.unlink() - - # Process chunks - start_time = time.time() - total_chunks = len(all_chunks) - log_path = output_root / "gui_json_generation.log" - - print(f"πŸ”„ Generating {total_chunks} audio chunks...") - - with ThreadPoolExecutor(max_workers=2) as executor: - futures = [] - for i, chunk_data in enumerate(all_chunks): - # Use chunk's TTS params, with temperature override if provided - chunk_tts_params = chunk_data.get("tts_params", {}).copy() - if temp_setting is not None: - chunk_tts_params["temperature"] = temp_setting - - # Ensure required TTS params exist - chunk_tts_params.setdefault("exaggeration", DEFAULT_EXAGGERATION) - chunk_tts_params.setdefault("cfg_weight", DEFAULT_CFG_WEIGHT) - chunk_tts_params.setdefault("temperature", DEFAULT_TEMPERATURE) - - future = executor.submit( - process_one_chunk, - i, chunk_data['text'], text_chunks_dir, audio_chunks_dir, - voice_path, chunk_tts_params, start_time, total_chunks, - punc_norm, book_name, log_run, log_path, device, - model, None, all_chunks, chunk_data.get('boundary_type', 'none') - ) - futures.append(future) - - # Wait for all chunks to complete - completed_chunks = 0 - for future in as_completed(futures): - try: - result = future.result() - if result: - idx, _ = result - completed_chunks += 1 - log_chunk_progress(idx, total_chunks, start_time, 0) - print(f"βœ… Completed chunk {completed_chunks}/{total_chunks}") - except Exception as e: - print(f"❌ Error processing chunk: {e}") - - elapsed_time = time.time() - start_time - print(f"βœ… Audio generation complete in {timedelta(seconds=int(elapsed_time))}") - print(f"πŸ”Š Audio chunks generated in: {audio_chunks_dir}") - - # Combine chunks into final audiobook - print("πŸ”— Combining audio chunks into final audiobook...") - try: - success = combine_audio_for_book(str(output_root), voice_name) - if success: - # Look for the created audiobook file with voice name - final_m4b = output_root / f"{book_name} [{voice_name}].m4b" - if final_m4b.exists(): - print(f"πŸŽ‰ Audiobook created: {final_m4b.name}") - return True, "Audiobook generation completed successfully", str(final_m4b) - else: - return False, "Combine succeeded but final audiobook file not found", None - else: - return False, "Failed to combine audio chunks", None - except Exception as e: - return False, f"Error combining audio chunks: {e}", None - - except Exception as e: - error_msg = f"JSON generation error: {e}" - print(f"❌ {error_msg}") - return False, error_msg, None - - -def get_book_name_from_json_path(json_path): - """ - Extract book name from JSON file path. - - Args: - json_path (str): Path to JSON file - - Returns: - str: Detected book name - """ - json_file = Path(json_path) - - if 'Audiobook' in json_file.parts: - audiobook_index = json_file.parts.index('Audiobook') - if audiobook_index + 1 < len(json_file.parts): - return json_file.parts[audiobook_index + 1] - - if json_file.stem.endswith('_chunks'): - return json_file.stem.replace('_chunks', '') - - return json_file.stem - - -if __name__ == "__main__": - # CLI compatibility for testing - print("GUI JSON Generator - use from GUI or import as module") \ No newline at end of file diff --git a/HF_Deploy/modules/path_manager.py b/HF_Deploy/modules/path_manager.py deleted file mode 100644 index 0ab4c5f426f0f99d155268028909d7b859e8ded0..0000000000000000000000000000000000000000 --- a/HF_Deploy/modules/path_manager.py +++ /dev/null @@ -1,19 +0,0 @@ -from pathlib import Path -from config.config import AUDIOBOOK_ROOT - -def get_book_paths(book_name): - """Return standardized paths for a given book name""" - base = AUDIOBOOK_ROOT / book_name - tts_dir = base / "TTS" - return { - "book_folder": base, - "tts_dir": tts_dir, - "text_chunks": tts_dir / "text_chunks", - "audio_chunks": tts_dir / "audio_chunks", - "combined_wav": base / f"{book_name}.wav", - "final_m4b": base / f"{book_name}.m4b", - "concat_list": tts_dir / "audio_chunks" / "concat.txt", - "quarantine": tts_dir / "audio_chunks" / "quarantine", - "run_log": base / "run.log", - "chunk_log": base / "chunk_validation.log" - } diff --git a/HF_Deploy/modules/progress_tracker.py b/HF_Deploy/modules/progress_tracker.py deleted file mode 100644 index 5f73b45dca773ef96f2cb519059db8629a6ed9fd..0000000000000000000000000000000000000000 --- a/HF_Deploy/modules/progress_tracker.py +++ /dev/null @@ -1,306 +0,0 @@ -""" -Progress Tracker Module -Handles progress display, VRAM monitoring, logging systems, and performance tracking -""" - -import time -import sys -import logging -from datetime import timedelta -from pathlib import Path -from config.config import * - -# ============================================================================ -# LOGGING SETUP -# ============================================================================ - -def setup_logging(log_dir): - """Setup logging configuration""" - log_file = log_dir / "chunk_validation.log" - - # Clear existing log - open(log_file, 'w').close() - - logging.basicConfig( - filename=str(log_file), - level=logging.INFO, - format="%(asctime)s - %(levelname)s - %(message)s", - filemode='w' # Overwrite existing log - ) - - # Also log to console for important messages - console_handler = logging.StreamHandler() - console_handler.setLevel(logging.WARNING) - formatter = logging.Formatter('%(levelname)s - %(message)s') - console_handler.setFormatter(formatter) - logging.getLogger().addHandler(console_handler) - -def log_console(message, color=None): - """Log to both console and file with optional color""" - color_codes = { - "RED": RED, "GREEN": GREEN, "YELLOW": YELLOW, - "CYAN": CYAN, "BOLD": BOLD, "RESET": RESET - } - - prefix = color_codes.get(color, "") - suffix = RESET if color else "" - - print(f"{prefix}{message}{suffix}") - logging.info(message) - -def log_run(message, log_path): - """Log to run file""" - with open(log_path, "a", encoding="utf-8") as logf: - logf.write(message + "\n") - -# ============================================================================ -# PROGRESS TRACKING -# ============================================================================ - -def log_chunk_progress(i, total_chunks, start_time, total_audio_duration=0.0): - """Enhanced progress logging with accurate realtime factor""" - elapsed = time.time() - start_time - avg_time = elapsed / (i + 1) - eta = avg_time * total_chunks - remaining = eta - elapsed - - def fmt(seconds): - return str(timedelta(seconds=int(seconds))) - - # Show VRAM usage in progress - allocated, _ = monitor_vram_usage("chunk_progress") - - # Calculate ACCURATE realtime factor using actual audio duration - if total_audio_duration > 0 and elapsed > 0: - actual_realtime = total_audio_duration / elapsed - realtime_str = f"{GREEN}{actual_realtime:.2f}x{RESET}" - audio_str = f" | Audio: {GREEN}{fmt(total_audio_duration)}{RESET}" - else: - actual_realtime = 0.0 # Default value when calculating - realtime_str = f"{YELLOW}Calculating...{RESET}" - audio_str = "" - - # Force immediate output with explicit flushing - progress_msg = (f"\nπŸŒ€ Chunk {i+1}/{total_chunks} | ⏱ Elapsed: {CYAN}{fmt(elapsed)}{RESET} | " - f"ETA: {CYAN}{fmt(eta)}{RESET} | Remaining: {YELLOW}{fmt(remaining)}{RESET} | " - f"Realtime: {realtime_str} | VRAM: {GREEN}{allocated:.1f}GB{RESET}{audio_str}") - - print(progress_msg) - sys.stdout.flush() # Force immediate output - - # Create clean status message for GUI (without ANSI color codes) - realtime_display = f"{actual_realtime:.2f}x" if actual_realtime > 0 else "Calculating..." - clean_status = (f"Elapsed: {fmt(elapsed)} | ETA: {fmt(eta)} | Remaining: {fmt(remaining)} | " - f"Realtime: {realtime_display} | VRAM: {allocated:.1f}GB" + - (f" | Audio: {fmt(total_audio_duration)}" if total_audio_duration > 0 else "")) - - # Emit status to GUI if callback is available - if hasattr(log_chunk_progress, '_status_callback') and log_chunk_progress._status_callback: - log_chunk_progress._status_callback(clean_status) - - # Also log to file for debugging - realtime_log = f"{actual_realtime:.2f}x" if actual_realtime > 0 else "N/A" - logging.info(f"Progress: Chunk {i+1}/{total_chunks}, Elapsed: {fmt(elapsed)}, " - f"ETA: {fmt(eta)}, Realtime: {realtime_log}, " - f"Audio Duration: {fmt(total_audio_duration)}, VRAM: {allocated:.1f}GB") - -def display_batch_progress(batch_start, batch_end, total_chunks): - """Display batch processing progress""" - batch_progress = (batch_end / total_chunks) * 100 - print(f"\nπŸ“Š Batch Progress: {batch_start+1}-{batch_end}/{total_chunks} ({batch_progress:.1f}%)") - -def display_final_summary(elapsed_time, audio_duration, chunk_count, realtime_factor): - """Display final processing summary""" - elapsed_td = timedelta(seconds=int(elapsed_time)) - audio_td = timedelta(seconds=int(audio_duration)) - - print(f"\nπŸŽ‰ {GREEN}Processing Complete!{RESET}") - print(f"πŸ“Š Final Statistics:") - print(f" ⏱️ Processing Time: {CYAN}{elapsed_td}{RESET}") - print(f" 🎡 Audio Duration: {GREEN}{audio_td}{RESET}") - print(f" πŸ“¦ Total Chunks: {YELLOW}{chunk_count}{RESET}") - print(f" πŸš€ Realtime Factor: {BOLD}{realtime_factor:.2f}x{RESET}") - print(f" πŸ’Ύ Memory Efficiency: {GREEN}Optimized{RESET}") - -# ============================================================================ -# VRAM AND PERFORMANCE MONITORING -# ============================================================================ - -def monitor_vram_usage(operation_name=""): - """Real-time VRAM monitoring with threshold warnings""" - import torch - - if not torch.cuda.is_available(): - return 0, 0 - - allocated = torch.cuda.memory_allocated() / 1024**3 - reserved = torch.cuda.memory_reserved() / 1024**3 - - if allocated > VRAM_SAFETY_THRESHOLD: - logging.warning(f"⚠️ High VRAM usage during {operation_name}: {allocated:.1f}GB allocated, {reserved:.1f}GB reserved") - # Trigger memory optimization if available - optimize_memory_if_needed() - - return allocated, reserved - -def monitor_gpu_utilization(): - """Monitor GPU utilization if pynvml is available""" - try: - import pynvml - pynvml.nvmlInit() - handle = pynvml.nvmlDeviceGetHandleByIndex(0) - util = pynvml.nvmlDeviceGetUtilizationRates(handle) - temp = pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU) - - return { - "gpu_util": util.gpu, - "memory_util": util.memory, - "temperature": temp - } - except: - return {"gpu_util": "N/A", "memory_util": "N/A", "temperature": "N/A"} - -def optimize_memory_if_needed(): - """Trigger memory optimization when thresholds are exceeded""" - try: - # Try to use the enhanced CUDA memory optimization if available - from modules.tts_engine import optimize_cuda_memory_usage - optimize_cuda_memory_usage() - except ImportError: - # Fallback to basic optimization - import torch - import gc - torch.cuda.empty_cache() - gc.collect() - if torch.cuda.is_available(): - torch.cuda.ipc_collect() - -def display_system_info(): - """Display system information at startup""" - import torch - - print(f"\nπŸ–₯️ {CYAN}System Information:{RESET}") - - # CUDA info - if torch.cuda.is_available(): - gpu_name = torch.cuda.get_device_name(0) - total_vram = torch.cuda.get_device_properties(0).total_memory / 1024**3 - print(f" GPU: {GREEN}{gpu_name}{RESET}") - print(f" VRAM: {GREEN}{total_vram:.1f}GB{RESET}") - print(f" CUDA Version: {GREEN}{torch.version.cuda}{RESET}") - else: - print(f" GPU: {RED}Not Available{RESET}") - - # Memory threshold - print(f" VRAM Safety Threshold: {YELLOW}{VRAM_SAFETY_THRESHOLD}GB{RESET}") - - # Worker configuration - print(f" Max Workers: {YELLOW}{MAX_WORKERS}{RESET}") - print(f" Dynamic Workers: {YELLOW}{USE_DYNAMIC_WORKERS}{RESET}") - -# ============================================================================ -# PERFORMANCE TRACKING -# ============================================================================ - -class PerformanceTracker: - """Track performance metrics throughout processing""" - - def __init__(self): - self.start_time = time.time() - self.chunk_times = [] - self.vram_usage = [] - self.batch_times = [] - - def log_chunk_completion(self, chunk_index, audio_duration): - """Log individual chunk completion""" - current_time = time.time() - chunk_time = current_time - (self.start_time + sum(self.chunk_times)) - - self.chunk_times.append(chunk_time) - - # Track VRAM - allocated, reserved = monitor_vram_usage() - self.vram_usage.append((chunk_index, allocated, reserved)) - - def log_batch_completion(self, batch_size): - """Log batch completion""" - if len(self.chunk_times) >= batch_size: - batch_time = sum(self.chunk_times[-batch_size:]) - self.batch_times.append(batch_time) - - def get_performance_summary(self): - """Get comprehensive performance summary""" - total_time = time.time() - self.start_time - avg_chunk_time = sum(self.chunk_times) / len(self.chunk_times) if self.chunk_times else 0 - - vram_peak = max([usage[1] for usage in self.vram_usage]) if self.vram_usage else 0 - vram_avg = sum([usage[1] for usage in self.vram_usage]) / len(self.vram_usage) if self.vram_usage else 0 - - return { - "total_time": total_time, - "avg_chunk_time": avg_chunk_time, - "total_chunks": len(self.chunk_times), - "vram_peak": vram_peak, - "vram_average": vram_avg, - "batch_count": len(self.batch_times) - } - -# ============================================================================ -# ERROR AND WARNING TRACKING -# ============================================================================ - -def log_processing_error(chunk_id, error_message, error_type="GENERAL"): - """Log processing errors with categorization""" - timestamp = time.strftime('%Y-%m-%d %H:%M:%S') - error_log = f"[{timestamp}] {error_type} ERROR - Chunk {chunk_id}: {error_message}" - - logging.error(error_log) - print(f"{RED}❌ Error in chunk {chunk_id}: {error_message}{RESET}") - -def log_processing_warning(chunk_id, warning_message, warning_type="GENERAL"): - """Log processing warnings with categorization""" - timestamp = time.strftime('%Y-%m-%d %H:%M:%S') - warning_log = f"[{timestamp}] {warning_type} WARNING - Chunk {chunk_id}: {warning_message}" - - logging.warning(warning_log) - print(f"{YELLOW}⚠️ Warning in chunk {chunk_id}: {warning_message}{RESET}") - -# ============================================================================ -# REAL-TIME STATUS DISPLAY -# ============================================================================ - -def create_status_line(current_chunk, total_chunks, elapsed_time, realtime_factor, vram_usage): - """Create a single-line status for real-time updates""" - progress_percent = (current_chunk / total_chunks) * 100 - elapsed_str = str(timedelta(seconds=int(elapsed_time))) - - status = (f"πŸ”„ {current_chunk}/{total_chunks} ({progress_percent:.1f}%) | " - f"⏱️ {elapsed_str} | πŸš€ {realtime_factor:.2f}x | πŸ’Ύ {vram_usage:.1f}GB") - - return status - -def update_status_line(status_message): - """Update status line in place""" - print(f"\r{status_message}", end='', flush=True) - -# ============================================================================ -# EXPORT FUNCTIONS -# ============================================================================ - -def export_performance_report(output_dir, performance_data): - """Export detailed performance report""" - report_path = output_dir / "performance_report.txt" - - with open(report_path, 'w', encoding='utf-8') as f: - f.write("GenTTS Performance Report\n") - f.write("=" * 50 + "\n\n") - - f.write(f"Processing Summary:\n") - f.write(f" Total Processing Time: {timedelta(seconds=int(performance_data['total_time']))}\n") - f.write(f" Average Chunk Time: {performance_data['avg_chunk_time']:.2f}s\n") - f.write(f" Total Chunks Processed: {performance_data['total_chunks']}\n") - f.write(f" Peak VRAM Usage: {performance_data['vram_peak']:.2f}GB\n") - f.write(f" Average VRAM Usage: {performance_data['vram_average']:.2f}GB\n") - f.write(f" Batch Count: {performance_data['batch_count']}\n") - - return report_path diff --git a/HF_Deploy/modules/resume_handler.py b/HF_Deploy/modules/resume_handler.py deleted file mode 100644 index 950f71f1ffb3f20d0c3be4b12af07d50a9b44b20..0000000000000000000000000000000000000000 --- a/HF_Deploy/modules/resume_handler.py +++ /dev/null @@ -1,596 +0,0 @@ -""" -Resume Handler Module -Handles resume functionality for interrupted processing -""" - -import torch -import time -import logging -from datetime import timedelta -from pathlib import Path - -from config.config import * -from modules.text_processor import smart_punctuate, sentence_chunk_text -from modules.file_manager import ( - setup_book_directories, find_book_files, list_voice_samples, - ensure_voice_sample_compatibility, get_audio_files_in_directory, - combine_audio_chunks, convert_to_m4b, add_metadata_to_m4b -) -from modules.audio_processor import get_chunk_audio_duration, pause_for_chunk_review -from modules.progress_tracker import setup_logging, log_chunk_progress, log_run - -def analyze_existing_chunks(audio_chunks_dir): - """Analyze existing chunks to determine resume point""" - if not audio_chunks_dir.exists(): - return 0, [] - - chunk_paths = get_audio_files_in_directory(audio_chunks_dir) - - if not chunk_paths: - return 0, [] - - # Find the highest chunk number - chunk_numbers = [] - for chunk_path in chunk_paths: - import re - match = re.match(r"chunk_(\d+)\.wav", chunk_path.name) - if match: - chunk_numbers.append(int(match.group(1))) - - if not chunk_numbers: - return 0, [] - - chunk_numbers.sort() - last_chunk_number = max(chunk_numbers) - - # Check for gaps in sequence - missing_chunks = [] - for i in range(1, last_chunk_number + 1): - if i not in chunk_numbers: - missing_chunks.append(i) - - print(f"πŸ“Š Existing chunks analysis:") - print(f" Total chunks found: {GREEN}{len(chunk_numbers)}{RESET}") - print(f" Highest chunk number: {GREEN}{last_chunk_number}{RESET}") - if missing_chunks: - print(f" Missing chunks: {YELLOW}{len(missing_chunks)}{RESET}") - if len(missing_chunks) <= 10: - print(f" Missing: {missing_chunks}") - else: - print(f" Missing: {missing_chunks[:10]}... (+{len(missing_chunks)-10} more)") - - return last_chunk_number, missing_chunks - -def suggest_resume_point(last_chunk, missing_chunks): - """Suggest optimal resume point based on existing chunks""" - if not missing_chunks: - # No gaps, can resume from next chunk - return last_chunk + 1 - - # If there are missing chunks, suggest resuming from first missing - first_missing = min(missing_chunks) - - print(f"\nπŸ’‘ Resume suggestions:") - print(f" Resume from chunk {GREEN}{last_chunk + 1}{RESET} (continue from last)") - print(f" Resume from chunk {YELLOW}{first_missing}{RESET} (fill gaps first)") - - return first_missing - -def validate_resume_point(start_chunk, total_expected_chunks): - """Validate that resume point makes sense""" - if start_chunk < 1: - print(f"{RED}❌ Invalid resume point: {start_chunk}. Must be >= 1{RESET}") - return False - - if start_chunk > total_expected_chunks: - print(f"{RED}❌ Resume point {start_chunk} exceeds expected total chunks {total_expected_chunks}{RESET}") - return False - - return True - -def process_book_folder_resume(book_dir, voice_path, tts_params, device, start_chunk=1): - """Enhanced book processing with resume capability""" - from modules.tts_engine import process_one_chunk, load_optimized_model, get_optimal_workers - from src.chatterbox.tts import punc_norm - from concurrent.futures import ThreadPoolExecutor, as_completed - - # Setup directories - output_root, tts_dir, text_chunks_dir, audio_chunks_dir = setup_book_directories(book_dir) - - # Find book files - book_files = find_book_files(book_dir) - text_file = book_files['text'] - cover_file = book_files['cover'] - nfo_file = book_files['nfo'] - - if not text_file: - logging.info(f"[{book_dir.name}] ERROR: No .txt files found in the book folder.") - return None, None, [] - - text_files = [text_file] # Convert to list for compatibility - - # IMPORTANT: Don't delete existing directories if resuming - print(f"πŸ” DEBUG: start_chunk = {start_chunk}") - if start_chunk == 1: - print(f"⚠️ WARNING: start_chunk is 1 - this will clear existing chunks!") - print(f"πŸ“ About to clear: {audio_chunks_dir}") - - # Only clear on fresh start - import shutil - for d in [text_chunks_dir, audio_chunks_dir]: - if d.exists() and d.is_dir(): - print(f"πŸ—‘οΈ CLEARING DIRECTORY: {d}") - shutil.rmtree(d) - - for d in [output_root, tts_dir, text_chunks_dir, audio_chunks_dir]: - d.mkdir(parents=True, exist_ok=True) - else: - print(f"βœ… RESUME MODE: Preserving existing chunks in {audio_chunks_dir}") - # Ensure directories exist for resume - for d in [output_root, tts_dir, text_chunks_dir, audio_chunks_dir]: - d.mkdir(parents=True, exist_ok=True) - - setup_logging(output_root) - - # Load existing chunks from JSON (resume should use preprocessed data) - from modules.tts_engine import find_chunks_json_file - - json_file = find_chunks_json_file(book_dir.name) - if json_file: - print(f"πŸ“– Loading preprocessed chunks from: {json_file.name}") - from wrapper.chunk_loader import load_chunks - all_chunks = load_chunks(str(json_file)) - print(f"βœ… Loaded {len(all_chunks)} chunks with metadata") - else: - print(f"❌ No preprocessed chunks found for {book_dir.name}") - print(f"πŸ’‘ Use Option 1 to process this book from the beginning first.") - return None, None, [] - - # Validate resume point - if not validate_resume_point(start_chunk, len(all_chunks)): - return None, None, [] - - # Filter chunks to process (resume logic) - if start_chunk > 1: - print(f"πŸ”„ Resuming from chunk {start_chunk}") - print(f"πŸ“Š Skipping chunks 1-{start_chunk-1} (already completed)") - - # Check which chunks already exist - existing_chunks = [] - for i in range(start_chunk-1): - chunk_path = audio_chunks_dir / f"chunk_{i+1:05}.wav" - if chunk_path.exists(): - existing_chunks.append(i+1) - - print(f"βœ… Found {len(existing_chunks)} existing chunks") - - # Only process remaining chunks - chunks_to_process = all_chunks[start_chunk-1:] - chunk_offset = start_chunk - 1 - else: - chunks_to_process = all_chunks - chunk_offset = 0 - - run_log_lines = [ - f"\n===== RESUME Processing: {book_dir.name} =====", - f"Voice: {voice_path.name}", - f"Started: {time.strftime('%Y-%m-%d %H:%M:%S')}", - f"Resume from chunk: {start_chunk}", - f"Text files processed: {len(text_files)}", - f"Total chunks generated: {len(all_chunks)}", - f"Chunks to process: {len(chunks_to_process)}" - ] - - # Write initial run info immediately - initial_log = run_log_lines + [ - f"--- Generation Settings ---", - f"Batch Processing: Enabled ({BATCH_SIZE} chunks per batch)", - f"ASR Enabled: {ENABLE_ASR}", - f"Hum Detection: {ENABLE_HUM_DETECTION}", - f"Dynamic Workers: {USE_DYNAMIC_WORKERS}", - f"Voice used: {voice_path.name}", - f"Exaggeration: {tts_params['exaggeration']}", - f"CFG weight: {tts_params['cfg_weight']}", - f"Temperature: {tts_params['temperature']}", - f"Processing Status: IN PROGRESS...", - f"="*50 - ] - - log_run("\n".join(initial_log), output_root / "run.log") - print(f"πŸ“ Initial run info written to: {output_root / 'run.log'}") - - start_time = time.time() - total_chunks = len(all_chunks) - remaining_chunks = len(chunks_to_process) - log_path = output_root / "chunk_validation.log" - - # Calculate existing audio duration for accurate progress - total_audio_duration = 0.0 - if start_chunk > 1: - print("πŸ“Š Calculating existing audio duration...") - for i in range(start_chunk-1): - chunk_path = audio_chunks_dir / f"chunk_{i+1:05}.wav" - if chunk_path.exists(): - total_audio_duration += get_chunk_audio_duration(chunk_path) - print(f"πŸ“Š Existing audio: {timedelta(seconds=int(total_audio_duration))}") - - # Initialize performance optimizations - from modules.tts_engine import detect_deployment_environment, enable_gpu_persistence_mode - deployment_env = detect_deployment_environment() - print(f"🌍 Deployment environment: {deployment_env}") - - # Enable GPU persistence mode for better performance - gpu_persistence_enabled = enable_gpu_persistence_mode() - - # Batch processing for remaining chunks - print(f"πŸ“Š Processing {remaining_chunks} remaining chunks in batches of {BATCH_SIZE}") - - all_results = [] - - for batch_start in range(0, remaining_chunks, BATCH_SIZE): - batch_end = min(batch_start + BATCH_SIZE, remaining_chunks) - batch_chunks = chunks_to_process[batch_start:batch_end] - - actual_start_chunk = chunk_offset + batch_start + 1 - actual_end_chunk = chunk_offset + batch_end - - print(f"\nπŸ”„ Processing batch: chunks {actual_start_chunk}-{actual_end_chunk}") - - # Fresh model for each batch - model = load_optimized_model(device) - compatible_voice = ensure_voice_sample_compatibility(voice_path, output_dir=tts_dir) - - # Pre-warm model to eliminate first chunk quality variations - from modules.tts_engine import prewarm_model_with_voice - model = prewarm_model_with_voice(model, compatible_voice, tts_params) - - # Load ASR model once per batch if needed using adaptive manager - asr_model = None - asr_device_used = None - if ENABLE_ASR: - from modules.asr_manager import load_asr_model_adaptive - print(f"🎀 Loading ASR model for resume mode...") - # Resume mode uses fallback config (no intelligent selection) - asr_model, asr_device_used = load_asr_model_adaptive() - - futures = [] - batch_results = [] - - # Dynamic worker allocation - optimal_workers = get_optimal_workers() - print(f"πŸ”§ Using {optimal_workers} workers for batch {actual_start_chunk}-{actual_end_chunk}") - - # Try producer-consumer pipeline first (Phase 4 optimization) - batch_results = [] - if ENABLE_PRODUCER_CONSUMER_PIPELINE: - try: - print(f"πŸš€ Attempting producer-consumer pipeline for resume batch {actual_start_chunk}-{actual_end_chunk}") - from modules.tts_engine import process_chunks_with_pipeline - pipeline_results = process_chunks_with_pipeline( - all_chunks, batch_chunks, chunk_offset, text_chunks_dir, audio_chunks_dir, - voice_path, tts_params, start_time, total_chunks, punc_norm, book_dir.name, - log_run, log_path, device, model, asr_model, True, optimal_workers, # asr_enabled=True for resume - total_audio_duration # Pass accumulated duration for proper ETA calculation - ) - - # Handle tuple return from pipeline - if isinstance(pipeline_results, tuple) and len(pipeline_results) == 2: - batch_results, batch_audio_duration = pipeline_results - total_audio_duration += batch_audio_duration - else: - # Fallback for old return format - batch_results = pipeline_results - - if batch_results: - print(f"βœ… Producer-consumer pipeline completed resume batch: {len(batch_results)} chunks") - # Pipeline already handled progress logging internally - - except Exception as e: - logging.error(f"❌ Producer-consumer pipeline failed in resume: {e}") - if not ENABLE_PIPELINE_FALLBACK: - raise - batch_results = [] # Clear failed results - - # Fallback to original sequential processing if pipeline disabled or failed - if not batch_results: - print(f"πŸ”„ Using sequential processing fallback for resume batch {actual_start_chunk}-{actual_end_chunk}") - futures = [] - - with ThreadPoolExecutor(max_workers=optimal_workers) as executor: - for i, chunk_data in enumerate(batch_chunks): - global_chunk_index = chunk_offset + i - - # Check for shutdown request - if shutdown_requested: - print(f"\n⏹️ {YELLOW}Stopping submission of new chunks...{RESET}") - break - - chunk = chunk_data["text"] - all_chunk_texts = [cd["text"] for cd in all_chunks] - boundary_type = chunk_data.get("boundary_type", "none") - - futures.append(executor.submit( - process_one_chunk, - global_chunk_index, chunk, text_chunks_dir, audio_chunks_dir, - voice_path, tts_params, start_time, total_chunks, - punc_norm, book_dir.name, log_run, log_path, device, - model, asr_model, all_chunk_texts, boundary_type - )) - - # Wait for batch to complete - print(f"πŸ”„ {CYAN}Waiting for batch {actual_start_chunk}-{actual_end_chunk} to complete...{RESET}") - completed_count = 0 - - for fut in as_completed(futures): - try: - idx, wav_path = fut.result() - if wav_path and wav_path.exists(): - # Measure actual audio duration for this chunk - chunk_duration = get_chunk_audio_duration(wav_path) - total_audio_duration += chunk_duration - batch_results.append((idx, wav_path)) - - # Update progress every 10 chunks within batch - completed_count += 1 - if completed_count % 10 == 0: - current_chunk = chunk_offset + completed_count - log_chunk_progress(current_chunk - 1, total_chunks, start_time, total_audio_duration) - - except Exception as e: - logging.error(f"Future failed in batch: {e}") - - # Clean up model after batch - print(f"🧹 Cleaning up after batch {actual_start_chunk}-{actual_end_chunk}") - del model - if asr_model: - from modules.asr_manager import cleanup_asr_model - cleanup_asr_model(asr_model) - torch.cuda.empty_cache() - import gc - gc.collect() - time.sleep(2) - - all_results.extend(batch_results) - print(f"βœ… Batch {actual_start_chunk}-{actual_end_chunk} completed ({len(batch_results)} chunks)") - - # Final processing - combine ALL chunks (existing + new) - quarantine_dir = audio_chunks_dir / "quarantine" - pause_for_chunk_review(quarantine_dir) - - # Collect ALL chunk paths (both existing and newly created) - chunk_paths = [] - for i in range(total_chunks): - chunk_path = audio_chunks_dir / f"chunk_{i+1:05}.wav" - if chunk_path.exists(): - chunk_paths.append(chunk_path) - else: - logging.warning(f"Missing chunk file: chunk_{i+1:05}.wav") - - if not chunk_paths: - logging.info(f"{RED}❌ No valid audio chunks found. Skipping concatenation and conversion.{RESET}") - return None, None, [] - - print(f"πŸ“Š Found {len(chunk_paths)} total chunks for final audiobook") - - # Calculate timing - elapsed_total = time.time() - start_time - elapsed_td = timedelta(seconds=int(elapsed_total)) - - # Get total audio duration from ALL chunks - total_audio_duration_final = sum(get_chunk_audio_duration(chunk_path) for chunk_path in chunk_paths) - audio_duration_td = timedelta(seconds=int(total_audio_duration_final)) - realtime_factor = total_audio_duration_final / elapsed_total if elapsed_total > 0 else 0.0 - - print(f"\n⏱️ Resume Processing Complete:") - print(f" Elapsed Time: {CYAN}{str(elapsed_td)}{RESET}") - print(f" Audio Duration: {GREEN}{str(audio_duration_td)}{RESET}") - print(f" Realtime Factor: {YELLOW}{realtime_factor:.2f}x{RESET}") - - # Combine audio - combined_wav_path = output_root / f"{book_dir.name} [{voice_path.stem}].wav" - print("\nπŸ’Ύ Saving WAV file...") - combine_audio_chunks(chunk_paths, combined_wav_path) - - # M4B conversion - temp_m4b_path = output_root / "output.m4b" - final_m4b_path = output_root / f"{book_dir.name}[{voice_path.stem}].m4b" - convert_to_m4b(combined_wav_path, temp_m4b_path) - add_metadata_to_m4b(temp_m4b_path, final_m4b_path, cover_file, nfo_file) - - logging.info(f"Audiobook created: {final_m4b_path}") - - # Append final completion info - completion_log = [ - f"\n--- Resume Processing Complete ---", - f"Completed: {time.strftime('%Y-%m-%d %H:%M:%S')}", - f"Processing Time: {str(elapsed_td)}", - f"Audio Duration: {str(audio_duration_td)}", - f"Realtime Factor: {realtime_factor:.2f}x", - f"Total Chunks: {len(chunk_paths)}", - f"Combined WAV: {combined_wav_path}", - f"Final M4B: {final_m4b_path}" - ] - - # Append to existing log - log_run("\n".join(completion_log), output_root / "run.log") - print(f"πŸ“ Final completion info appended to: {output_root / 'run.log'}") - - return final_m4b_path, combined_wav_path, run_log_lines - -def resume_book_from_chunk(start_chunk): - """Interactive resume function for stuck book""" - print(f"\nπŸ”„ Resume Book Processing from Chunk {start_chunk}") - print("=" * 50) - - # Show available books from Audiobook directory (books that have started processing) - audiobook_root = Path(AUDIOBOOK_ROOT) - if not audiobook_root.exists(): - print(f"{RED}No audiobook directory found at {AUDIOBOOK_ROOT}.{RESET}") - return None - - book_dirs = sorted([d for d in audiobook_root.iterdir() if d.is_dir() and d.name != "Audio_Revisions"]) - if not book_dirs: - print(f"{RED}No books found in {AUDIOBOOK_ROOT}/ - no books have started processing.{RESET}") - print(f"πŸ’‘ Use Option 1 to start processing a new book first.") - return None - - print("Available books (in progress or completed):") - for i, book_dir in enumerate(book_dirs): - # All books in Audiobook/ should have processing data - audio_chunks_dir = book_dir / "TTS" / "audio_chunks" - if audio_chunks_dir.exists(): - last_chunk, missing = analyze_existing_chunks(audio_chunks_dir) - if missing: - status = f"(last chunk: {last_chunk}, {len(missing)} missing)" - else: - status = f"(completed: {last_chunk} chunks)" - else: - status = "(processing started but no chunks yet)" - - print(f" [{i}] {book_dir.name} {status}") - - while True: - try: - book_idx = int(input("Select book index: ")) - if 0 <= book_idx < len(book_dirs): - audiobook_dir = book_dirs[book_idx] - # Find corresponding Text_Input directory - text_input_book_dir = TEXT_INPUT_ROOT / audiobook_dir.name - if text_input_book_dir.exists(): - book_dir = text_input_book_dir - else: - print(f"❌ Text_Input directory not found for {audiobook_dir.name}") - print(f"πŸ’‘ The original book files may have been moved or deleted.") - continue - break - except Exception: - pass - print("Invalid selection. Try again.") - - # Analyze existing chunks for selected book - audiobook_dir = AUDIOBOOK_ROOT / book_dir.name - if audiobook_dir.exists(): - audio_chunks_dir = audiobook_dir / "TTS" / "audio_chunks" - if audio_chunks_dir.exists(): - last_chunk, missing = analyze_existing_chunks(audio_chunks_dir) - suggested_resume = suggest_resume_point(last_chunk, missing) - - print(f"\nSuggested resume point: {GREEN}{suggested_resume}{RESET}") - - # Allow user to override - user_input = input(f"Resume from chunk [{suggested_resume}]: ").strip() - if user_input: - try: - start_chunk = int(user_input) - except ValueError: - print(f"Invalid input, using suggested: {suggested_resume}") - start_chunk = suggested_resume - else: - start_chunk = suggested_resume - - # Show available voices - voice_files = list_voice_samples() - if not voice_files: - print(f"{RED}No voice samples found.{RESET}") - return None - - print("\nAvailable voices:") - for i, voice in enumerate(voice_files): - print(f" [{i}] {voice.name}") - - while True: - try: - voice_idx = int(input("Select voice index: ")) - if 0 <= voice_idx < len(voice_files): - voice_path = voice_files[voice_idx] - break - except Exception: - pass - print("Invalid selection. Try again.") - - # Get TTS parameters - def prompt_float(prompt, default): - val = input(f"{prompt} [{default}]: ").strip() - return float(val) if val else default - - exaggeration = prompt_float("Enter exaggeration (emotion intensity)", DEFAULT_EXAGGERATION) - cfg_weight = prompt_float("Enter cfg_weight (faithfulness to text)", DEFAULT_CFG_WEIGHT) - temperature = prompt_float("Enter temperature (randomness)", DEFAULT_TEMPERATURE) - - tts_params = dict(exaggeration=exaggeration, cfg_weight=cfg_weight, temperature=temperature) - - # Determine device with proper validation - from modules.tts_engine import get_best_available_device - device = get_best_available_device() - - print(f"\nπŸš€ Resuming {book_dir.name} from chunk {start_chunk}") - print(f"🎀 Voice: {voice_path.name}") - print(f"βš™οΈ Parameters: {tts_params}") - - # Process with resume - return process_book_folder_resume(book_dir, voice_path, tts_params, device, start_chunk) - -def find_incomplete_books(): - """Find books that appear to be incomplete""" - incomplete_books = [] - - for book_dir in TEXT_INPUT_ROOT.iterdir(): - if not book_dir.is_dir(): - continue - - audiobook_dir = AUDIOBOOK_ROOT / book_dir.name - if not audiobook_dir.exists(): - continue - - audio_chunks_dir = audiobook_dir / "TTS" / "audio_chunks" - if not audio_chunks_dir.exists(): - continue - - # Check if there's a final M4B - m4b_files = list(audiobook_dir.glob("*.m4b")) - wav_files = list(audiobook_dir.glob("*.wav")) - - if not m4b_files and not wav_files: - # No final output, likely incomplete - last_chunk, missing = analyze_existing_chunks(audio_chunks_dir) - if last_chunk > 0: - incomplete_books.append({ - "name": book_dir.name, - "last_chunk": last_chunk, - "missing_chunks": len(missing), - "path": book_dir - }) - - return incomplete_books - -def auto_resume_incomplete(): - """Automatically suggest resume for incomplete books""" - incomplete = find_incomplete_books() - - if not incomplete: - print(f"{GREEN}βœ… No incomplete books found!{RESET}") - return - - print(f"{YELLOW}πŸ“‹ Found {len(incomplete)} incomplete books:{RESET}") - for i, book in enumerate(incomplete): - print(f" [{i}] {book['name']} (last chunk: {book['last_chunk']}, missing: {book['missing_chunks']})") - - choice = input(f"\nSelect book to resume [0-{len(incomplete)-1}] or 'q' to quit: ").strip() - - if choice.lower() == 'q': - return - - try: - idx = int(choice) - if 0 <= idx < len(incomplete): - selected_book = incomplete[idx] - suggested_resume = selected_book['last_chunk'] + 1 - - print(f"\n🎯 Selected: {selected_book['name']}") - print(f"πŸ’‘ Suggested resume point: chunk {suggested_resume}") - - return resume_book_from_chunk(suggested_resume) - except ValueError: - print("Invalid selection.") - - return None diff --git a/HF_Deploy/modules/system_detector.py b/HF_Deploy/modules/system_detector.py deleted file mode 100644 index 7f28f582138e3811c721882b2bdd7377d03f8954..0000000000000000000000000000000000000000 --- a/HF_Deploy/modules/system_detector.py +++ /dev/null @@ -1,231 +0,0 @@ -""" -System Resource Detection Module -Detects VRAM, RAM, CPU cores and recommends appropriate ASR models -""" - -import psutil -import torch -import os -import sys -from pathlib import Path - -# Add project root to path for imports -if __name__ == "__main__": - sys.path.insert(0, str(Path(__file__).parent.parent)) - -from config.config import ASR_MODEL_VRAM_MB, ASR_MODEL_RAM_MB - -def get_gpu_memory(): - """Get total and available GPU memory in MB""" - try: - if torch.cuda.is_available(): - gpu_count = torch.cuda.device_count() - if gpu_count > 0: - # Use first GPU - total_vram = torch.cuda.get_device_properties(0).total_memory - allocated_vram = torch.cuda.memory_allocated(0) - available_vram = total_vram - allocated_vram - - return { - 'total_mb': total_vram // 1024 // 1024, - 'available_mb': available_vram // 1024 // 1024, - 'allocated_mb': allocated_vram // 1024 // 1024 - } - except: - pass - - return {'total_mb': 0, 'available_mb': 0, 'allocated_mb': 0} - -def get_system_memory(): - """Get total and available system RAM in MB""" - try: - memory = psutil.virtual_memory() - return { - 'total_mb': memory.total // 1024 // 1024, - 'available_mb': memory.available // 1024 // 1024, - 'used_mb': memory.used // 1024 // 1024 - } - except: - return {'total_mb': 0, 'available_mb': 0, 'used_mb': 0} - -def get_cpu_cores(): - """Get number of CPU cores""" - try: - return psutil.cpu_count(logical=False) or psutil.cpu_count() - except: - return 1 - -def estimate_tts_vram_usage(): - """Estimate VRAM usage by ChatterboxTTS (updated based on real usage)""" - return 5500 # 5.5GB in MB (was 7GB, adjusted based on actual 3.5GB usage + buffer) - -def get_system_profile(): - """Get complete system resource profile""" - gpu_info = get_gpu_memory() - ram_info = get_system_memory() - cpu_cores = get_cpu_cores() - - # Estimate available resources after TTS loading - tts_vram_estimate = estimate_tts_vram_usage() - available_vram_after_tts = max(0, gpu_info['available_mb'] - tts_vram_estimate) - - return { - 'gpu': gpu_info, - 'ram': ram_info, - 'cpu_cores': cpu_cores, - 'available_vram_after_tts': available_vram_after_tts, - 'has_gpu': gpu_info['total_mb'] > 0 - } - -def categorize_system(profile): - """Categorize system capabilities""" - gpu_total = profile['gpu']['total_mb'] - ram_total = profile['ram']['total_mb'] - cpu_cores = profile['cpu_cores'] - - # VRAM categories - if gpu_total < 4000: - vram_category = "low" - elif gpu_total <= 12000: - vram_category = "medium" - else: - vram_category = "high" - - # RAM categories - if ram_total < 16000: - ram_category = "low" - elif ram_total <= 64000: - ram_category = "medium" - else: - ram_category = "high" - - # CPU categories - if cpu_cores < 6: - cpu_category = "low" - elif cpu_cores <= 16: - cpu_category = "medium" - else: - cpu_category = "high" - - return { - 'vram': vram_category, - 'ram': ram_category, - 'cpu': cpu_category - } - -def get_safe_asr_models(profile): - """Get ASR models that can safely run on GPU with available VRAM""" - available_vram = profile['available_vram_after_tts'] - safe_models = [] - - for model, vram_req in ASR_MODEL_VRAM_MB.items(): - if vram_req <= available_vram: - safe_models.append(model) - - return safe_models - -def get_safe_cpu_models(profile): - """Get ASR models that can safely run on CPU with available RAM""" - available_ram = profile['ram']['available_mb'] - safe_models = [] - - for model, ram_req in ASR_MODEL_RAM_MB.items(): - if ram_req <= available_ram: - safe_models.append(model) - - return safe_models - -def recommend_asr_models(profile): - """Recommend Safe/Moderate/Insane ASR model configurations""" - categories = categorize_system(profile) - safe_gpu_models = get_safe_asr_models(profile) - safe_cpu_models = get_safe_cpu_models(profile) - - recommendations = {} - - # Model priority order (best to worst) - model_priority = ["large-v3", "large", "large-v2", "medium", "small", "base", "tiny"] - - # Safe: Conservative choice - safe_gpu = None - safe_cpu = None - - for model in reversed(model_priority): # Start from smallest - if model in safe_gpu_models and not safe_gpu: - safe_gpu = model - if model in safe_cpu_models and not safe_cpu: - safe_cpu = model - if safe_gpu and safe_cpu: - break - - # Moderate: Balanced choice - moderate_gpu = None - moderate_cpu = None - - # Try to get a model 1-2 steps up from safe - safe_idx = model_priority.index(safe_gpu) if safe_gpu else len(model_priority) - moderate_idx = max(0, safe_idx - 2) - - for i in range(moderate_idx, len(model_priority)): - model = model_priority[i] - if model in safe_gpu_models and not moderate_gpu: - moderate_gpu = model - if model in safe_cpu_models and not moderate_cpu: - moderate_cpu = model - if moderate_gpu and moderate_cpu: - break - - # Insane: Push the limits (best available models) - insane_gpu = None - insane_cpu = None - - # Get the best (largest) models that are safe - for model in model_priority: # Start from best - if model in safe_gpu_models and not insane_gpu: - insane_gpu = model - if model in safe_cpu_models and not insane_cpu: - insane_cpu = model - if insane_gpu and insane_cpu: - break - - # Build recommendations - recommendations['safe'] = { - 'primary': {'model': safe_gpu or safe_cpu, 'device': 'gpu' if safe_gpu else 'cpu'}, - 'fallback': {'model': safe_cpu, 'device': 'cpu'} - } - - recommendations['moderate'] = { - 'primary': {'model': moderate_gpu or moderate_cpu, 'device': 'gpu' if moderate_gpu else 'cpu'}, - 'fallback': {'model': moderate_cpu, 'device': 'cpu'} - } - - recommendations['insane'] = { - 'primary': {'model': insane_gpu or insane_cpu, 'device': 'gpu' if insane_gpu else 'cpu'}, - 'fallback': {'model': insane_cpu, 'device': 'cpu'} - } - - return recommendations - -def print_system_summary(profile): - """Print a human-readable system summary""" - categories = categorize_system(profile) - - print(f"πŸ–₯️ System Profile:") - print(f" VRAM: {profile['gpu']['total_mb']:,}MB total, {profile['available_vram_after_tts']:,}MB available after TTS ({categories['vram']} class)") - print(f" RAM: {profile['ram']['total_mb']:,}MB total, {profile['ram']['available_mb']:,}MB available ({categories['ram']} class)") - print(f" CPU: {profile['cpu_cores']} cores ({categories['cpu']} class)") - - if not profile['has_gpu']: - print(f" ⚠️ No CUDA GPU detected - ASR will run on CPU only") - -if __name__ == "__main__": - # Test the detection - profile = get_system_profile() - print_system_summary(profile) - - recommendations = recommend_asr_models(profile) - print(f"\nASR Model Recommendations:") - for level, config in recommendations.items(): - primary = config['primary'] - fallback = config['fallback'] - print(f"🟒 {level.upper()}: {primary['model']} ({primary['device']}) + {fallback['model']} (cpu fallback)") \ No newline at end of file diff --git a/HF_Deploy/modules/text_processor.py b/HF_Deploy/modules/text_processor.py deleted file mode 100644 index 225c39667ff6930f2696f8b25eb93fb4ffcbad2c..0000000000000000000000000000000000000000 --- a/HF_Deploy/modules/text_processor.py +++ /dev/null @@ -1,745 +0,0 @@ -""" -Text Processing Module -Handles text chunking, abbreviations, and preprocessing for TTS -""" - -import re -import logging -from pathlib import Path -from config.config import MAX_CHUNK_WORDS, MIN_CHUNK_WORDS, YELLOW, RESET - - - -# ============================================================================ -# ABBREVIATION REPLACEMENT SYSTEM -# ============================================================================ - -def load_abbreviations(file_path="utils/abbreviations.txt"): - """Load abbreviation replacements from external file""" - replacements = {} - abbrev_file = Path(file_path) - - if not abbrev_file.exists(): - print(f"⚠️ {YELLOW}Abbreviations file not found: {file_path}{RESET}") - print(f"πŸ“ Creating sample file...") - create_sample_abbreviations_file(abbrev_file) - return replacements - - try: - with open(abbrev_file, 'r', encoding='utf-8') as f: - for line_num, line in enumerate(f, 1): - line = line.strip() - - # Skip empty lines and comments - if not line or line.startswith('#'): - continue - - # Parse "abbrev -> replacement" format - if ' -> ' in line: - abbrev, replacement = line.split(' -> ', 1) - replacements[abbrev.strip()] = replacement.strip() - else: - print(f"⚠️ Invalid format on line {line_num}: {line}") - - print(f"βœ… Loaded {len(replacements)} abbreviation replacements from {file_path}") - - except Exception as e: - print(f"❌ Error loading abbreviations: {e}") - - return replacements - -def create_sample_abbreviations_file(file_path): - """Create a sample abbreviations file with common replacements""" - sample_content = """# Abbreviation Replacements for TTS -# Format: abbreviation -> replacement -# Lines starting with # are comments - -# Common titles and abbreviations -Dr. -> Doctor -Mr. -> Mister -Mrs. -> Missus -Ms. -> Miss -Prof. -> Professor -Rev. -> Reverend -Lt. -> Lieutenant -Capt. -> Captain -Gen. -> General -Col. -> Colonel -Jr. -> Junior -Sr. -> Senior - -# Political and organizations -M.P. -> MP -U.S. -> US -U.K. -> UK -U.N. -> UN -F.B.I. -> FBI -C.I.A. -> CIA -N.A.S.A. -> NASA - -# Common abbreviations -etc. -> et cetera -vs. -> versus -e.g. -> for example -i.e. -> that is -Inc. -> Incorporated -Corp. -> Corporation -Ltd. -> Limited -Co. -> Company - -# Numbers and ordinals -1st -> first -2nd -> second -3rd -> third -4th -> fourth -5th -> fifth -10th -> tenth -20th -> twentieth -21st -> twenty-first -30th -> thirtieth -40th -> fortieth -50th -> fiftieth -60th -> sixtieth -70th -> seventieth -80th -> eightieth -90th -> ninetieth -100th -> one hundredth - -# Time abbreviations -a.m. -> AM -p.m. -> PM -A.M. -> AM -P.M. -> PM -""" - - try: - with open(file_path, 'w', encoding='utf-8') as f: - f.write(sample_content) - print(f"πŸ“ Created sample abbreviations file: {file_path}") - print(f"πŸ’‘ Edit this file to add your own replacements!") - except Exception as e: - print(f"❌ Error creating sample file: {e}") - -def preprocess_abbreviations(text, replacements): - """Replace abbreviations with TTS-friendly versions""" - if not replacements: - return text - - original_text = text - replacements_made = 0 - - # Apply replacements (order matters for overlapping patterns) - for abbrev, replacement in replacements.items(): - if abbrev in text: - text = text.replace(abbrev, replacement) - replacements_made += 1 - - if replacements_made > 0: - logging.info(f"πŸ“ Applied {replacements_made} abbreviation replacements") - - return text - -# ============================================================================ -# TEXT PREPROCESSING AND CHUNKING -# ============================================================================ - -def smart_punctuate(text): - """ - Enhanced punctuation normalization with abbreviation replacement. - - PROCESSING REQUIREMENTS: - - Load and apply abbreviation replacements (Dr. -> Doctor, etc.) - - Add periods to lines that don't end with punctuation - - Replace Unicode smart quotes with ASCII quotes (", ') - - Remove problematic formatting (bold markdown, underlines) - - Preserve paragraph breaks (empty lines) - - This prepares text for consistent TTS processing. - """ - - # Load abbreviations and apply them - abbreviation_replacements = load_abbreviations() - text = preprocess_abbreviations(text, abbreviation_replacements) - - # Then continue with existing punctuation logic - lines = text.splitlines() - out = [] - - for l in lines: - stripped = l.strip() - - # Preserve empty lines (paragraph breaks) - if not stripped: - out.append("") # Keep the blank line - # Process non-empty lines - elif not re.search(r'[.!?]$', stripped) and not re.search(r'[.!?]["\']$', stripped): - out.append(stripped + ".") - else: - out.append(stripped) - - result = "\n".join(out) - - # Enhanced text preprocessing - replace curly quotes with straight quotes - result = result.replace('\u201c', '"').replace('\u201d', '"') # Replace smart double quotes " " - result = result.replace('\u2018', "'").replace('\u2019', "'") # Replace smart single quotes ' ' - - # Remove problematic formatting - result = re.sub(r'\*\*([^*]+)\*\*', r'\1', result) # Remove bold markdown - result = re.sub(r'_{2,}', '', result) # Remove underlines - - # Fix any escaped quotes that might appear in the text - result = result.replace('\\"', '"').replace("\\'", "'") - - # Additional quote normalization to prevent recurring dialogue corruption - result = re.sub(r'(["\'])\s*,\s*(["\'])', r'\1, \2', result) # Fix quote spacing around commas - result = re.sub(r'(["\'])\s*\.\s*(["\'])', r'\1. \2', result) # Fix quote spacing around periods - result = re.sub(r'(["\'])\s*([,.])\s*(["\'])\s*([,.])', r'\1\2 \3', result) # Remove duplicate punctuation - - # Debug logging for dialogue patterns - if '"' in result and ('replied' in result or 'said' in result): - print(f"πŸ—£οΈ DEBUG: Dialogue detected in smart_punctuate: {result[:100]}...") - - return result - -def fix_short_sentence_artifacts(chunk_text): - """ - Fix multiple short sentences that cause TTS errors. - Example: "Yes. No. Maybe." β†’ "Yes, no, maybe." - "Right." β†’ "Right," (if it's a single-word chunk) - """ - # Handle full chunk that is just one short sentence - words = chunk_text.strip().split() - if len(words) == 1 and chunk_text.strip().endswith('.'): - return chunk_text.strip()[:-1] + ',' # Replace period with comma - - parts = re.split(r'([.!?])', chunk_text.strip()) - if len(parts) < 2: - return chunk_text # nothing to fix - - # Reconstruct sentence-punctuation pairs - sentences = [] - for i in range(0, len(parts)-1, 2): - sentence = parts[i].strip() - punct = parts[i+1] - if sentence: - word_count = len(sentence.split()) - sentences.append((sentence, punct, word_count)) - - # Handle multiple short sentences - short_count = sum(1 for _, _, wc in sentences if wc <= 3) - - if short_count >= 2 and len(sentences) >= 2: - merged = ", ".join(s for s, _, _ in sentences) + "." - return merged - - # Handle case where first sentence is a single word - if len(sentences) >= 2 and sentences[0][2] == 1 and sentences[0][1] == ".": - # Replace period with comma - first, second = sentences[0][0], sentences[1][0] - rest = " ".join(s for s, _, _ in sentences[2:]) - new_text = f"{first}, {second}" - if rest: - new_text += " " + rest - return new_text - - return chunk_text - -def _is_apostrophe(text, pos): - """Check if a single quote at position pos is likely an apostrophe (not speech quote)""" - if pos == 0 or pos >= len(text) - 1: - return False - - # Check characters before and after - before = text[pos - 1] if pos > 0 else ' ' - after = text[pos + 1] if pos < len(text) - 1 else ' ' - - # It's likely an apostrophe if: - # 1. Preceded and followed by letters (contractions like "don't", possessives like "John's") - # 2. Or preceded by letter and followed by 's' or 't' (common contractions) - if before.isalpha() and after.isalpha(): - return True - if before.isalpha() and after in 's': - return True - - return False - -def sentence_chunk_text(text, max_words=MAX_CHUNK_WORDS, min_words=MIN_CHUNK_WORDS): - """ - Simple and reliable text chunking that follows the exact rules: - - TEXT CHUNKING RULES: - 1. Break at sentence boundaries (. ! ?) first (highest priority) - 2. If sentence > max_words, break at punctuation working backwards (; β€” , in that order) - 3. If no punctuation available, preserve sentence intact to maintain coherence - 4. Ensure all chunks meet min_words requirement by combining small chunks - - PUNCTUATION HIERARCHY (for breaking long sentences): - 1. . ! ? (sentence boundaries) - handled at sentence level - 2. ; (semicolon) - major pause - 3. β€” – (dashes) - major pause - 4. , (comma) - minor pause - 5. Preserve overlong sentences if no punctuation available - """ - import re - - # Process text paragraph by paragraph to preserve structure - paragraphs = text.split('\n\n') - all_final_chunks = [] - - for paragraph in paragraphs: - paragraph = paragraph.strip() - if not paragraph: - continue - - # Check if this is a chapter/section header - para_lower = paragraph.lower().strip() - is_chapter_header = ( - any(word in para_lower for word in ['chapter', 'section', 'part', 'prologue', 'epilogue']) and - len(paragraph.split()) <= 10 - ) - - if is_chapter_header: - # Chapter headers are their own chunks and always paragraph ends - all_final_chunks.append((paragraph, True)) - continue - - # Split into sentences using periods, exclamation marks, question marks - # This avoids the complex quote detection that was causing problems - sentences = re.split(r'([.!?])\s+', paragraph.strip()) - - # Reconstruct sentences with their punctuation - reconstructed_sentences = [] - for i in range(0, len(sentences) - 1, 2): - sentence = sentences[i].strip() - if i + 1 < len(sentences): - punct = sentences[i + 1] - sentence += punct - if sentence: - reconstructed_sentences.append(sentence) - - # Handle any remaining text (no ending punctuation) - if sentences and sentences[-1].strip(): - last_part = sentences[-1].strip() - if last_part and not last_part in '.!?': - reconstructed_sentences.append(last_part) - - # Process each sentence - paragraph_chunks = [] - for sent_idx, sentence in enumerate(reconstructed_sentences): - is_last_sentence = (sent_idx == len(reconstructed_sentences) - 1) - words = sentence.split() - - if len(words) <= max_words: - # Sentence fits, use as-is - paragraph_chunks.append((sentence.strip(), is_last_sentence)) - else: - # Sentence too long, break it using punctuation - broken_chunks = _break_long_sentence_simple(sentence, max_words) - # Only mark the last broken chunk as sentence end - for i, chunk in enumerate(broken_chunks): - is_chunk_end = (is_last_sentence and i == len(broken_chunks) - 1) - paragraph_chunks.append((chunk.strip(), is_chunk_end)) - - all_final_chunks.extend(paragraph_chunks) - - # Combine small chunks that don't meet min_words requirement - combined_chunks = _combine_small_chunks(all_final_chunks, min_words, max_words) - - return combined_chunks - -def _break_long_sentence_simple(sentence, max_words): - """Break a long sentence at punctuation marks, working backwards""" - import re - - # Punctuation patterns in priority order - patterns = [ - r';\s*', # semicolon + optional space - r'β€”\s*', # em dash + optional space - r'–\s*', # en dash + optional space - r',\s*', # comma + optional space - ] - - chunks = [] - remaining = sentence.strip() - - while remaining: - words = remaining.split() - if len(words) <= max_words: - chunks.append(remaining) - break - - # Find best break point working backwards - best_break = -1 - - # Try each punctuation pattern - for pattern in patterns: - matches = list(re.finditer(pattern, remaining)) - if matches: - # Find rightmost match that results in chunk <= max_words - for match in reversed(matches): - test_chunk = remaining[:match.end()].strip() - if len(test_chunk.split()) <= max_words: - best_break = match.end() - break - if best_break != -1: - break - - if best_break != -1: - # Found good break point - chunk = remaining[:best_break].strip() - chunks.append(chunk) - remaining = remaining[best_break:].strip() - else: - # No punctuation found - preserve sentence coherence by keeping it intact - # This prevents splitting sentences with potentially different sentiment - chunks.append(remaining) - break - - return chunks - -def _combine_small_chunks(chunks, min_words, max_words): - """Combine chunks that are too small""" - combined = [] - current_chunk = "" - current_is_para_end = False - - for chunk_text, is_para_end in chunks: - chunk_words = len(chunk_text.split()) - current_words = len(current_chunk.split()) if current_chunk else 0 - - if not current_chunk: - # First chunk - current_chunk = chunk_text - current_is_para_end = is_para_end - elif current_words + chunk_words <= max_words: - # Can combine - current_chunk = current_chunk + " " + chunk_text - current_is_para_end = is_para_end # Use the latest para_end flag - else: - # Can't combine, flush current and start new - if current_words >= min_words: - combined.append((current_chunk, current_is_para_end)) - current_chunk = chunk_text - current_is_para_end = is_para_end - else: - # Current chunk too small, force combine anyway - current_chunk = current_chunk + " " + chunk_text - current_is_para_end = is_para_end - - # Handle remaining chunk - if current_chunk: - combined.append((current_chunk, current_is_para_end)) - - return combined - -def break_long_sentence_backwards(sentence, max_words, min_words): - """ - Break a long sentence working backwards from the end to find natural punctuation. - - ALGORITHM: - 1. Start from sentence end, work backwards to find punctuation within max_words - 2. Break at the latest (rightmost) punctuation that keeps chunk <= max_words - 3. This preserves natural pauses and speech rhythm - 4. Continue processing remaining text normally - - PUNCTUATION HIERARCHY (in order of preference): - 1. . ! ? (sentence boundaries) - highest priority - 2. ; (semicolon) - major pause - 3. β€” (em dash) - major pause - 4. , (comma) - minor pause - 5. Force break at word limit (last resort) - """ - - # Punctuation patterns to search for (in order of preference) - punctuation_patterns = [ - r'[.!?]\s+', # sentence boundaries + required space (highest priority) - r';\s*', # semicolon + optional space - r'β€”\s*', # em dash + optional space - r'–\s*', # en dash + optional space - r',\s*', # comma + optional space - ] - - chunks = [] - remaining_text = sentence.strip() - - while remaining_text: - words = remaining_text.split() - - if len(words) <= max_words: - # Remaining text fits within limit - chunks.append(remaining_text.strip()) - break - - # Text exceeds max_words - find backwards break point - # Search for punctuation within the current 'remaining_text' up to max_words - # We need to find the *last* punctuation mark that results in a chunk <= max_words - best_break_index = -1 # Index in 'words' list - best_break_pos_in_text = -1 # Character position in 'remaining_text' - - # Iterate backwards from max_words down to min_words (or 1 if min_words is very small) - # to find the latest punctuation that keeps the chunk within limits. - for i in range(min(max_words, len(words)) -1, 0, -1): - sub_text = " ".join(words[:i+1]) # Text up to current word - - found_punctuation = False - for pattern in punctuation_patterns: - matches = list(re.finditer(pattern, sub_text)) - if matches: - # Take the rightmost match in this sub_text - last_match = matches[-1] - # Ensure the break is within the max_words limit - if len(sub_text[:last_match.end()].split()) <= max_words: - best_break_index = i # Store word index - best_break_pos_in_text = last_match.end() # Store char position - found_punctuation = True - break # Found a good break for this sub_text, move to next i - if found_punctuation: - break # Found the best break for the overall chunk, exit outer loop - - if best_break_pos_in_text != -1: - # Found punctuation - break after it, keeping punctuation with preceding text - chunk_text = remaining_text[:best_break_pos_in_text].strip() - chunks.append(chunk_text) - remaining_text = remaining_text[best_break_pos_in_text:].strip() - else: - # No punctuation found within the desired range - keep sentence intact - # This preserves sentence coherence over word count limits - chunks.append(remaining_text.strip()) - break - - return chunks - -# ============================================================================ -# CONTENT BOUNDARY DETECTION -# ============================================================================ - -def detect_punctuation_boundary(chunk_text): - """ - Detect the ending punctuation of a text chunk for precise silence insertion. - - Returns specific punctuation boundary types: - - "comma" -> Brief pause after commas - - "semicolon" -> Medium pause after semicolons - - "colon" -> Pause after colons - - "period" -> Sentence end pause - - "question_mark" -> Question pause - - "exclamation" -> Exclamation pause - - "dash" -> Em dash pause - - "ellipsis" -> Ellipsis pause (suspense) - - "quote_end" -> End of quoted speech - - None -> No specific punctuation detected - """ - # Strip whitespace and newlines for accurate detection - text = chunk_text.strip() - - if not text: - return None - - # Check ending punctuation patterns (in order of specificity) - if text.endswith('...'): - return "ellipsis" - elif text.endswith('"') or text.endswith("'"): - return "quote_end" - elif text.endswith('!'): - return "exclamation" - elif text.endswith('?'): - return "question_mark" - elif text.endswith('.'): - return "period" - elif text.endswith(':'): - return "colon" - elif text.endswith(';'): - return "semicolon" - elif text.endswith(','): - return "comma" - elif text.endswith('β€”') or text.endswith('–'): - return "dash" - - return None - -def detect_content_boundaries(chunk_text, chunk_index, all_chunks, is_paragraph_end=False): - """ - Detect chapter breaks and paragraph endings for appropriate silence insertion. - Now enhanced with punctuation-specific boundary detection. - - BOUNDARY DETECTION REQUIREMENTS: - - Chapter start: "Chapter N", "Ch. N", "I.", "1." patterns - - Chapter end: Next chunk is a chapter start - - Section break: Multiple asterisks, hashes, or em-dashes - - Paragraph end: Detected via chunking process flag or content analysis - - Punctuation: Specific ending punctuation for precise silence timing - - Returns boundary_type for silence insertion: - - "chapter_start" -> Long pause before chapter - - "chapter_end" -> Long pause after chapter - - "section_break" -> Medium pause for section breaks - - "paragraph_end" -> Short pause for paragraph breaks - - Punctuation types: "comma", "period", "question_mark", etc. - - None -> No special boundary detected - """ - boundary_type = None - - # Chapter detection (flexible patterns) - chapter_patterns = [ - r'^(Chapter \d+|CHAPTER \d+)', - r'^(Ch\. \d+|CH\. \d+)', - r'^\d+\.', # Simple "1." numbering - r'^[IVX]+\.', # Roman numerals "I.", "II.", etc. - ] - - for pattern in chapter_patterns: - if re.search(pattern, chunk_text.strip(), re.MULTILINE): - boundary_type = "chapter_start" - break - - # Look ahead for chapter start (current chunk ends chapter) - if chunk_index + 1 < len(all_chunks): - next_chunk = all_chunks[chunk_index + 1] - for pattern in chapter_patterns: - if re.search(pattern, next_chunk.strip()): - boundary_type = "chapter_end" - break - - # Section breaks (asterisks, multiple line breaks) - if re.search(r'\*{3,}|\#{3,}|β€”{3,}', chunk_text): - boundary_type = "section_break" - - # Paragraph ending detection - # Use the is_paragraph_end flag from chunking process since newlines are stripped - if is_paragraph_end and boundary_type is None: - boundary_type = "paragraph_end" - - # If no major structural boundary found, check punctuation - if boundary_type is None: - boundary_type = detect_punctuation_boundary(chunk_text) - - return boundary_type - -def _split_long_dialogue(sentence, max_words, recursion_depth=0): - """ - Split long dialogue sections that exceed word limits. - Tries to break at natural points: attribution, internal punctuation, then word boundaries. - """ - # Prevent infinite recursion - if recursion_depth > 3: - # Force word boundary split if recursion gets too deep - words = sentence.split() - sentences = [] - start = 0 - while start < len(words): - end = min(start + max_words, len(words)) - chunk_words = words[start:end] - sentences.append(' '.join(chunk_words)) - start = end - return sentences - - words = sentence.split() - if len(words) <= max_words: - return [sentence] - - sentences = [] - - # Strategy 1: Break at dialogue attribution (he said, she replied, etc.) - attribution_pattern = r'(\s+(?:he|she|I|they|[A-Z][a-z]+)\s+(?:said|replied|asked|shouted|whispered|continued|added|interrupted)[^.!?]*?[.!?]?\s*)' - attribution_matches = list(re.finditer(attribution_pattern, sentence, re.IGNORECASE)) - - if attribution_matches: - start = 0 - for match in attribution_matches: - # Check if breaking here keeps chunks under limit - before_attr = sentence[start:match.end()].strip() - if before_attr and len(before_attr.split()) <= max_words: - sentences.append(before_attr) - start = match.end() - - # Add remaining text - if start < len(sentence): - remaining = sentence[start:].strip() - if remaining: - if len(remaining.split()) > max_words: - # Recursively split if still too long, but with depth tracking - sentences.extend(_split_long_dialogue(remaining, max_words, recursion_depth + 1)) - else: - sentences.append(remaining) - - if sentences: # If we successfully split, return result - return sentences - - # Strategy 2: Break at internal punctuation (commas, semicolons within quotes) - punct_pattern = r'([,;:]\s+)' - parts = re.split(punct_pattern, sentence) - - current_chunk = "" - sentences = [] - for i, part in enumerate(parts): - test_chunk = current_chunk + part - if len(test_chunk.split()) > max_words and current_chunk: - sentences.append(current_chunk.strip()) - current_chunk = part - else: - current_chunk = test_chunk - - if current_chunk.strip(): - sentences.append(current_chunk.strip()) - - # Check if any resulting chunk is still too long and needs further splitting - final_sentences = [] - for chunk in sentences: - if len(chunk.split()) > max_words: - # Split oversized chunks using word boundaries - chunk_words = chunk.split() - start = 0 - while start < len(chunk_words): - end = min(start + max_words, len(chunk_words)) - sub_chunk_words = chunk_words[start:end] - final_sentences.append(' '.join(sub_chunk_words)) - start = end - else: - final_sentences.append(chunk) - - if len(final_sentences) > 1: # If we successfully split, return result - return final_sentences - - # Strategy 3: Force break at word boundaries (guaranteed to work) - sentences = [] - start = 0 - while start < len(words): - end = min(start + max_words, len(words)) - chunk_words = words[start:end] - sentences.append(' '.join(chunk_words)) - start = end - - return sentences - -# ============================================================================ -# UTILITY FUNCTIONS -# ============================================================================ - -def reload_abbreviations(): - """Reload abbreviations from file (useful for testing changes)""" - return load_abbreviations() - -def test_abbreviations(test_text="Dr. Smith met with the M.P. at 3:30 p.m. on the 21st."): - """Test abbreviation replacements on sample text""" - abbreviation_replacements = load_abbreviations() - print(f"Original: {test_text}") - processed = preprocess_abbreviations(test_text, abbreviation_replacements) - print(f"Processed: {processed}") - return processed - -def test_chunking(test_text=None, max_words=20, min_words=4): - """Test the enhanced chunking with sample or custom text""" - if test_text is None: - test_text = '''Though perfectly worldly-wise, and able, as she expressed it, to take care of herself, there was yet something curiously ingenuous in her single-minded attitude towards life, and her whole-hearted determination to "make good." This glimpse of a world unknown to me was not without its charm, and I enjoyed seeing her vivid little face light up as she talked.''' - - chunks = sentence_chunk_text(test_text, max_words=max_words, min_words=min_words) - - print("Enhanced Chunking Results:") - for i, (chunk, is_para) in enumerate(chunks): - word_count = len(chunk.split()) - print(f"Chunk {i+1} ({word_count} words): {chunk}") - if word_count > max_words: - print(f" βœ… Over {max_words} words but complete sentence (follows punctuation rules)") - print() - - return chunks diff --git a/HF_Deploy/modules/tts_engine.py b/HF_Deploy/modules/tts_engine.py deleted file mode 100644 index 3339213c53f9c16b1f7ccf6f13a33247839b7e74..0000000000000000000000000000000000000000 --- a/HF_Deploy/modules/tts_engine.py +++ /dev/null @@ -1,710 +0,0 @@ -""" -TTS Engine Module -Handles ChatterboxTTS interface, model loading, and chunk processing coordination -""" - -import torch -import gc -import time -import logging -import shutil -import sys -from datetime import timedelta -from concurrent.futures import ThreadPoolExecutor, as_completed -from pathlib import Path -import torchaudio as ta - -from config.config import * -from modules.text_processor import smart_punctuate, sentence_chunk_text, detect_content_boundaries - -def find_chunks_json_file(book_name): - """Find the corresponding chunks JSON file for a book""" - from config.config import AUDIOBOOK_ROOT - - # Look in the TTS processing directory - tts_chunks_dir = AUDIOBOOK_ROOT / book_name / "TTS" / "text_chunks" - json_path = tts_chunks_dir / "chunks_info.json" - - if json_path.exists(): - return json_path - - # Also check old Text_Input location for backwards compatibility - text_input_dir = Path("Text_Input") - possible_names = [ - f"{book_name}_chunks.json", - f"{book_name.lower()}_chunks.json", - f"{book_name.replace(' ', '_')}_chunks.json" - ] - - for name in possible_names: - old_json_path = text_input_dir / name - if old_json_path.exists(): - return old_json_path - - return None -from modules.audio_processor import ( - smart_audio_validation, apply_smart_fade, add_chunk_end_silence, - add_contextual_silence, pause_for_chunk_review, get_chunk_audio_duration, - has_mid_energy_drop, apply_smart_fade_memory, smart_audio_validation_memory -) -from modules.file_manager import ( - setup_book_directories, find_book_files, ensure_voice_sample_compatibility, - combine_audio_chunks, get_audio_files_in_directory, convert_to_m4b, add_metadata_to_m4b -) -from modules.progress_tracker import setup_logging, log_chunk_progress, log_run - -# ============================================================================ -# MEMORY AND MODEL MANAGEMENT -# ============================================================================ - -def monitor_gpu_activity(operation_name): - """Lightweight GPU monitoring for high-speed processing""" - # Disabled expensive pynvml queries to free up GPU cycles - if torch.cuda.is_available(): - allocated = torch.cuda.memory_allocated() / 1024**3 - # Skip GPU utilization queries during production runs - return allocated, 0 - return 0, 0 - -def optimize_memory_usage(): - """Aggressive memory management for 8GB VRAM""" - torch.cuda.empty_cache() - gc.collect() - if torch.cuda.is_available(): - torch.cuda.ipc_collect() - -def monitor_vram_usage(operation_name=""): - """Real-time VRAM monitoring""" - if torch.cuda.is_available(): - allocated = torch.cuda.memory_allocated() / 1024**3 - reserved = torch.cuda.memory_reserved() / 1024**3 - - if allocated > VRAM_SAFETY_THRESHOLD: - logging.warning(f"⚠️ High VRAM usage during {operation_name}: {allocated:.1f}GB allocated, {reserved:.1f}GB reserved") - optimize_memory_usage() - - return allocated, reserved - return 0, 0 - -def get_optimal_workers(user_max_workers=None): - """Dynamic worker allocation based on device type and resources""" - # Check for user override first - if user_max_workers is not None: - print(f"πŸ‘€ Using user-defined workers: {user_max_workers}") - return int(user_max_workers) - - if not USE_DYNAMIC_WORKERS: - return MAX_WORKERS - - # CPU-based worker calculation - if not torch.cuda.is_available(): - import psutil - cpu_cores = psutil.cpu_count(logical=False) # Physical cores - available_memory = psutil.virtual_memory().available / 1024**3 # GB - - # Each TTS model instance needs ~2-3GB RAM - # Conservative estimation: allow 1 worker per 4GB available RAM - memory_limited_workers = max(1, int(available_memory / 4)) - - # CPU-based calculation: use 50% of physical cores for intensive TTS work - cpu_limited_workers = max(1, int(cpu_cores * 0.5)) - - optimal_workers = min(memory_limited_workers, cpu_limited_workers, MAX_WORKERS) - print(f"πŸ’» CPU mode: {cpu_cores} cores, {available_memory:.1f}GB RAM β†’ {optimal_workers} workers") - return optimal_workers - - # GPU-based worker calculation (existing logic) - allocated_vram = torch.cuda.memory_allocated() / 1024**3 - - if allocated_vram < 5.0: - return min(TEST_MAX_WORKERS, MAX_WORKERS) - elif allocated_vram < VRAM_SAFETY_THRESHOLD: - return min(2, MAX_WORKERS) - else: - return 1 - -def load_optimized_model(device): - """Load TTS model with memory optimizations and device detection""" - from chatterbox.tts import ChatterboxTTS - - # Detect available device if not specified or if CUDA not available - if device == "cuda" and not torch.cuda.is_available(): - print("⚠️ CUDA not available, falling back to CPU") - device = "cpu" - elif device == "auto": - if torch.cuda.is_available(): - device = "cuda" - print("βœ… CUDA detected, using GPU") - else: - device = "cpu" - print("πŸ’» No GPU detected, using CPU") - - print(f"πŸ”§ Loading ChatterboxTTS model on device: {device}") - - try: - # Load model (ChatterboxTTS.from_pretrained doesn't support torch_dtype parameter) - model = ChatterboxTTS.from_pretrained(device=device) - logging.info(f"βœ… Loaded ChatterboxTTS model on {device}") - except Exception as e: - print(f"❌ Failed to load model on {device}: {e}") - if device == "cuda": - print("πŸ”„ Retrying with CPU...") - try: - model = ChatterboxTTS.from_pretrained(device="cpu") - logging.info("βœ… Loaded model on CPU (GPU failed)") - device = "cpu" - except Exception as e2: - print(f"❌ Failed to load model on CPU: {e2}") - raise e2 - else: - raise e - - # Only apply eval() and benchmark if the model has these attributes - if hasattr(model, 'eval'): - model.eval() - - # Set CUDNN benchmark for performance (if available) - if torch.backends.cudnn.is_available(): - torch.backends.cudnn.benchmark = True - - return model - -# ============================================================================ -# CHUNK PROCESSING -# ============================================================================ - -def patch_alignment_layer(tfmr, alignment_layer_idx=12): - """Patch alignment layer to avoid recursion""" - from types import MethodType - target_layer = tfmr.layers[alignment_layer_idx].self_attn - original_forward = target_layer.forward - - def patched_forward(self, *args, **kwargs): - kwargs['output_attentions'] = True - return original_forward(*args, **kwargs) - - target_layer.forward = MethodType(patched_forward, target_layer) - -def process_one_chunk( - i, chunk, text_chunks_dir, audio_chunks_dir, - voice_path, tts_params, start_time, total_chunks, - punc_norm, basename, log_run_func, log_path, device, - model, asr_model, all_chunks, boundary_type="none" -): - """Enhanced chunk processing with quality control, contextual silence, and deep cleanup""" - import difflib - from pydub import AudioSegment - - chunk_id_str = f"{i+1:05}" - chunk_path = text_chunks_dir / f"chunk_{chunk_id_str}.txt" - with open(chunk_path, 'w', encoding='utf-8') as cf: - cf.write(chunk) - - chunk_audio_path = audio_chunks_dir / f"chunk_{chunk_id_str}.wav" - - # ============================================================================ - # ENHANCED PERIODIC DEEP CLEANUP - # ============================================================================ - cleanup_interval = CLEANUP_INTERVAL - - # Skip cleanup on model reinitialization chunks to avoid conflicts - if (i + 1) % cleanup_interval == 0 and (i + 1) % BATCH_SIZE != 0: - print(f"\n🧹 {YELLOW}DEEP CLEANUP at chunk {i+1}/{total_chunks}...{RESET}") - - # Enhanced VRAM monitoring before cleanup - allocated_before = torch.cuda.memory_allocated() / 1024**3 if torch.cuda.is_available() else 0 - reserved_before = torch.cuda.memory_reserved() / 1024**3 if torch.cuda.is_available() else 0 - - print(f" Before: VRAM Allocated: {allocated_before:.1f}GB | Reserved: {reserved_before:.1f}GB") - - # Bulk temp file cleanup - print(" πŸ—‘οΈ Cleaning bulk temporary files...") - temp_patterns = ["*_try*.wav", "*_pre.wav", "*_fade*.wav", "*_debug*.wav", "*_temp*.wav", "*_backup*.wav"] - total_temp_files = 0 - for pattern in temp_patterns: - temp_files = list(audio_chunks_dir.glob(pattern)) - for temp_file in temp_files: - temp_file.unlink(missing_ok=True) - total_temp_files += len(temp_files) - - if total_temp_files > 0: - print(f" πŸ—‘οΈ Removed {total_temp_files} temporary audio files") - - # Aggressive CUDA context reset - print(" πŸ”„ Performing aggressive CUDA context reset...") - torch.cuda.synchronize() - torch.cuda.empty_cache() - torch.cuda.ipc_collect() - - # Force CUDA context reset - if hasattr(torch.cuda, 'reset_peak_memory_stats'): - torch.cuda.reset_peak_memory_stats() - if hasattr(torch._C, '_cuda_clearCublasWorkspaces'): - torch._C._cuda_clearCublasWorkspaces() - - # Force garbage collection multiple times - for _ in range(3): - gc.collect() - - # Clear model cache if it has one - if hasattr(model, 'clear_cache'): - model.clear_cache() - elif hasattr(model, 'reset_states'): - model.reset_states() - - # Brief pause to let GPU settle - time.sleep(1.0) - - # Monitor after cleanup - allocated_after = torch.cuda.memory_allocated() / 1024**3 if torch.cuda.is_available() else 0 - reserved_after = torch.cuda.memory_reserved() / 1024**3 if torch.cuda.is_available() else 0 - - print(f" After: VRAM Allocated: {allocated_after:.1f}GB | Reserved: {reserved_after:.1f}GB") - print(f" Freed: {allocated_before - allocated_after:.1f}GB allocated, {reserved_before - reserved_after:.1f}GB reserved") - print(f"🧹 {GREEN}Deep cleanup complete!{RESET}\n") - - best_sim, best_asr_text = -1, "" - wav_path_active = None - attempt_paths = [] - mid_drop_retries = 0 - max_mid_drop_retries = 2 - - for attempt_num in range(1, 3): - logging.info(f"πŸ” Starting TTS for chunk {chunk_id_str}, attempt {attempt_num}") - try: - tts_args = {k: v for k, v in tts_params.items() if k != "max_workers"} - - # monitor_gpu_activity(f"Before TTS chunk_{chunk_id_str}") # Disabled for speed - with torch.no_grad(): - wav = model.generate(chunk, **tts_args).detach().cpu() - # monitor_gpu_activity(f"After TTS chunk_{chunk_id_str}") # Disabled for speed - - if wav.dim() == 1: - wav = wav.unsqueeze(0) - - # Retry if mid-energy drop is enabled and detected (check in memory) - if ENABLE_MID_DROP_CHECK and has_mid_energy_drop(wav, model.sr): - mid_drop_retries += 1 - if mid_drop_retries >= max_mid_drop_retries: - logging.info(f"⚠️ Mid-drop retry limit reached for {chunk_id_str}. Accepting audio.") - else: - logging.info(f"⚠️ Mid-chunk noise detected in {chunk_id_str}. Retrying...") - continue - - # Convert tensor to AudioSegment for in-memory processing - import io - import soundfile as sf - from pydub import AudioSegment - - # Convert wav tensor to AudioSegment (in memory) - wav_np = wav.squeeze().numpy() - with io.BytesIO() as wav_buffer: - sf.write(wav_buffer, wav_np, model.sr, format='wav') - wav_buffer.seek(0) - audio_segment = AudioSegment.from_wav(wav_buffer) - - # Smart fade removed - replaced by precise audio trimming - # Audio health validation disabled for speed - - # Note: Audio trimming will handle end-of-speech cleanup more precisely - - # ASR validation (memory-based processing) - check user setting first - enable_asr_user = tts_params.get('enable_asr', False) - if (enable_asr_user or ENABLE_ASR) and asr_model is not None: - from modules.audio_processor import asr_f1_score - import io - import soundfile as sf - # monitor_gpu_activity(f"Before ASR chunk_{chunk_id_str}") # Disabled for speed - try: - # Process ASR completely in memory - no disk writes - # Convert AudioSegment to numpy array for ASR - samples = np.array(audio_segment.get_array_of_samples()) - if audio_segment.channels == 2: - samples = samples.reshape((-1, 2)).mean(axis=1) - - # Normalize to float32 for ASR model - audio_np = samples.astype(np.float32) / audio_segment.max_possible_amplitude - - # Use ASR model directly on numpy array (if supported) - # Note: This depends on the ASR model's input capabilities - result = asr_model.transcribe(audio_np) - - if not isinstance(result, dict) or "text" not in result: - raise ValueError(f"Invalid ASR result type: {type(result)}") - - asr_text = result.get("text", "").strip() - sim_ratio = asr_f1_score(punc_norm(chunk), asr_text) - - except Exception as e: - print(f"❌ ASR failed for {chunk_id_str}: {e}") - log_run_func(f"ASR VALIDATION FAILED - Chunk {chunk_id_str}:\nExpected:\n{chunk}\nActual:\n\nSimilarity: -1.000\n" + "="*50, log_path) - sim_ratio = -1.0 - continue - - logging.info(f"ASR similarity for chunk {chunk_id_str}: {sim_ratio:.3f}") - if sim_ratio < 0.7: - continue - - # Track best valid match - best_sim = sim_ratio - best_asr_text = asr_text - # monitor_gpu_activity(f"After ASR chunk_{chunk_id_str}") # Disabled for speed - - # Success - we have processed audio in memory - final_audio = audio_segment - break - - except Exception as e: - import traceback - logging.error(f"Exception during TTS attempt {attempt_num} for chunk {chunk_id_str}: {e}") - traceback.print_exc() - continue - - if 'final_audio' not in locals(): - logging.info(f"❌ Chunk {chunk_id_str} failed all attempts.") - return None, None - - # Apply trimming and contextual silence in memory before final save - from modules.audio_processor import process_audio_with_trimming_and_silence - - if boundary_type and boundary_type != "none": - final_audio = process_audio_with_trimming_and_silence(final_audio, boundary_type) - print(f"πŸ”‡ Added {boundary_type} silence to chunk {i+1:05}") - else: - # Apply trimming even without boundary type if enabled - if ENABLE_AUDIO_TRIMMING: - from modules.audio_processor import trim_audio_endpoint - final_audio = trim_audio_endpoint(final_audio) - - # Note: ENABLE_CHUNK_END_SILENCE is now handled by punctuation-specific silence - # The new system provides more precise silence based on actual punctuation - - # Final save - only disk write in entire process - final_path = audio_chunks_dir / f"chunk_{chunk_id_str}.wav" - final_audio.export(final_path, format="wav") - logging.info(f"βœ… Saved final chunk: {final_path.name}") - - # No intermediate file cleanup needed - all processing done in memory - - # Log details - only log ASR failures - asr_active = enable_asr_user or ENABLE_ASR - if asr_active and best_sim < 0.8: - log_run_func(f"ASR VALIDATION FAILED - Chunk {chunk_id_str}:\nExpected:\n{chunk}\nActual:\n{best_asr_text}\nSimilarity: {best_sim:.3f}\n" + "="*50, log_path) - elif not asr_active: - log_run_func(f"Chunk {chunk_id_str}: Original text: {chunk}", log_path) - - # Silence already added in memory above - no disk processing needed - - # Enhanced regular cleanup (every chunk) - del wav - optimize_memory_usage() - - # Additional per-chunk cleanup for long runs - if (i + 1) % 50 == 0: - torch.cuda.empty_cache() - gc.collect() - - return i, final_path - -# ============================================================================ -# MAIN BOOK PROCESSING FUNCTION -# ============================================================================ - -from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer -from wrapper.chunk_loader import save_chunks - -def generate_enriched_chunks(text_file, output_dir, user_tts_params=None): - """Reads a text file, performs VADER sentiment analysis, and returns enriched chunks.""" - analyzer = SentimentIntensityAnalyzer() - - raw_text = text_file.read_text(encoding='utf-8') - cleaned = smart_punctuate(raw_text) - chunks = sentence_chunk_text(cleaned) - - # Use user-provided parameters as base, or fall back to config defaults - if user_tts_params: - base_exaggeration = user_tts_params.get('exaggeration', BASE_EXAGGERATION) - base_cfg_weight = user_tts_params.get('cfg_weight', BASE_CFG_WEIGHT) - base_temperature = user_tts_params.get('temperature', BASE_TEMPERATURE) - else: - base_exaggeration = BASE_EXAGGERATION - base_cfg_weight = BASE_CFG_WEIGHT - base_temperature = BASE_TEMPERATURE - - enriched = [] - chunk_texts = [chunk_text for chunk_text, _ in chunks] - - for i, (chunk_text, is_para_end) in enumerate(chunks): - sentiment_scores = analyzer.polarity_scores(chunk_text) - compound_score = sentiment_scores['compound'] - - exaggeration = base_exaggeration + (compound_score * VADER_EXAGGERATION_SENSITIVITY) - cfg_weight = base_cfg_weight + (compound_score * VADER_CFG_WEIGHT_SENSITIVITY) - temperature = base_temperature + (compound_score * VADER_TEMPERATURE_SENSITIVITY) - - # Clamp values to defined min/max - exaggeration = round(max(TTS_PARAM_MIN_EXAGGERATION, min(exaggeration, TTS_PARAM_MAX_EXAGGERATION)), 2) - cfg_weight = round(max(TTS_PARAM_MIN_CFG_WEIGHT, min(cfg_weight, TTS_PARAM_MAX_CFG_WEIGHT)), 2) - temperature = round(max(TTS_PARAM_MIN_TEMPERATURE, min(temperature, TTS_PARAM_MAX_TEMPERATURE)), 2) - - boundary_type = detect_content_boundaries(chunk_text, i, chunk_texts, is_para_end) - - enriched.append({ - "index": i, - "text": chunk_text, - "word_count": len(chunk_text.split()), - "boundary_type": boundary_type if boundary_type else "none", - "sentiment_compound": compound_score, - "tts_params": { - "exaggeration": exaggeration, - "cfg_weight": cfg_weight, - "temperature": temperature - } - }) - - output_json_path = output_dir / "chunks_info.json" - save_chunks(output_json_path, enriched) - return enriched - -def process_book_folder(book_dir, voice_path, tts_params, device, skip_cleanup=False): - """Enhanced book processing with batch processing to prevent hangs""" - print(f"πŸ” DEBUG: Entering process_book_folder with book_dir='{book_dir}', voice_path='{voice_path}'") - - from chatterbox.tts import punc_norm - print(f"πŸ” DEBUG: Successfully imported punc_norm") - - # Setup directories - print(f"πŸ” DEBUG: Calling setup_book_directories...") - output_root, tts_dir, text_chunks_dir, audio_chunks_dir = setup_book_directories(book_dir) - print(f"πŸ” DEBUG: Directory setup complete") - - # Clean previous processing files (but skip for resume operations) - if skip_cleanup: - print(f"πŸ”„ RESUME MODE: Skipping cleanup to preserve existing chunks") - print(f"πŸ“ Preserving: {text_chunks_dir}, {audio_chunks_dir}") - else: - print(f"🧹 FRESH PROCESSING: Cleaning previous processing files...") - import glob - - # Clear text chunks - for txt_file in text_chunks_dir.glob("*.txt"): - txt_file.unlink(missing_ok=True) - for json_file in text_chunks_dir.glob("*.json"): - json_file.unlink(missing_ok=True) - - # Clear audio chunks - for wav_file in audio_chunks_dir.glob("*.wav"): - wav_file.unlink(missing_ok=True) - - # Clear logs - for log_file in output_root.glob("*.log"): - log_file.unlink(missing_ok=True) - - print(f"βœ… Cleanup complete") - - # Find book files - print(f"πŸ” DEBUG: Calling find_book_files...") - book_files = find_book_files(book_dir) - text_files = [book_files['text']] if book_files['text'] else [] - cover_file = book_files['cover'] - nfo_file = book_files['nfo'] - print(f"πŸ” DEBUG: Found text files: {text_files}") - - if not text_files: - logging.info(f"[{book_dir.name}] ERROR: No .txt files found in the book folder.") - return None, None, [] - - setup_logging(output_root) - - # Generate enriched chunks with VADER analysis using user parameters - all_chunks = generate_enriched_chunks(text_files[0], text_chunks_dir, tts_params) - - # Create run_log_lines - print(f"πŸ” DEBUG: Creating run_log_lines...") - print(f"πŸ” DEBUG: voice_path type: {type(voice_path)}, value: {voice_path}") - - # Extract voice name for logging - voice_name_for_log = voice_path.stem if hasattr(voice_path, 'stem') else Path(voice_path).stem - - run_log_lines = [ - f"\n===== Processing: {book_dir.name} =====", - f"Voice: {voice_name_for_log}", - f"Started: {time.strftime('%Y-%m-%d %H:%M:%S')}", - f"Text files processed: {len(text_files)}", - f"Total chunks generated: {len(all_chunks)}" - ] - - start_time = time.time() - total_chunks = len(all_chunks) - log_path = output_root / "chunk_validation.log" - total_audio_duration = 0.0 - - # Batch processing - print(f"πŸ“Š Processing {total_chunks} chunks in batches of {BATCH_SIZE}") - - all_results = [] - - for batch_start in range(0, total_chunks, BATCH_SIZE): - batch_end = min(batch_start + BATCH_SIZE, total_chunks) - batch_chunks = all_chunks[batch_start:batch_end] - - print(f"\nπŸ”„ Processing batch: chunks {batch_start+1}-{batch_end}") - - # Fresh model for each batch - model = load_optimized_model(device) - compatible_voice = ensure_voice_sample_compatibility(voice_path, output_dir=tts_dir) - model.prepare_conditionals(compatible_voice) - - # Load ASR model once per batch if needed (check user settings first, then global config) - asr_model = None - enable_asr_user = tts_params.get('enable_asr', False) - if enable_asr_user or ENABLE_ASR: - import whisper - print(f"🎀 Loading Whisper ASR model for batch... (user setting: {enable_asr_user})") - # Use same device as TTS model, with fallback to CPU - asr_device = device if torch.cuda.is_available() and device == "cuda" else "cpu" - print(f"🎀 Loading ASR model on device: {asr_device}") - asr_model = whisper.load_model("base", device=asr_device) - - futures = [] - batch_results = [] - - # Dynamic worker allocation - user_max_workers = tts_params.get('max_workers', None) - optimal_workers = get_optimal_workers(user_max_workers) - print(f"πŸ”§ Using {optimal_workers} workers for batch {batch_start+1}-{batch_end}") - - with ThreadPoolExecutor(max_workers=optimal_workers) as executor: - for i, chunk_data in enumerate(batch_chunks): - global_chunk_index = batch_start + i - - # Check for shutdown request - if shutdown_requested: - print(f"\n⏹️ {YELLOW}Stopping submission of new chunks...{RESET}") - break - - # Handle both dictionary and tuple formats for chunk data - if isinstance(chunk_data, dict): - chunk = chunk_data["text"] - boundary_type = chunk_data.get("boundary_type", "none") - # Use chunk-specific TTS params if available, otherwise fall back to global - chunk_tts_params = chunk_data.get("tts_params", tts_params) - else: - # Handle old tuple format (text, is_para_end) - convert to boundary_type - chunk = chunk_data[0] if len(chunk_data) > 0 else str(chunk_data) - # Convert old is_paragraph_end to boundary_type - is_old_para_end = chunk_data[1] if len(chunk_data) > 1 else False - boundary_type = "paragraph_end" if is_old_para_end else "none" - chunk_tts_params = tts_params # Fallback for old format - - # Handle both dictionary and tuple formats for backward compatibility - all_chunk_texts = [] - for cd in all_chunks: - if isinstance(cd, dict): - all_chunk_texts.append(cd["text"]) - else: - # Handle old tuple format (text, is_para_end) - all_chunk_texts.append(cd[0] if len(cd) > 0 else str(cd)) - - futures.append(executor.submit( - process_one_chunk, - global_chunk_index, chunk, text_chunks_dir, audio_chunks_dir, - voice_path, chunk_tts_params, start_time, total_chunks, - punc_norm, book_dir.name, log_run, log_path, device, - model, asr_model, all_chunk_texts, boundary_type - )) - - # Wait for batch to complete - print(f"πŸ”„ {CYAN}Waiting for batch {batch_start+1}-{batch_end} to complete...{RESET}") - completed_count = 0 - - for fut in as_completed(futures): - try: - idx, wav_path = fut.result() - if wav_path and wav_path.exists(): - # Measure actual audio duration for this chunk - chunk_duration = get_chunk_audio_duration(wav_path) - total_audio_duration += chunk_duration - batch_results.append((idx, wav_path)) - - # Update progress every 10 chunks within batch - completed_count += 1 - if completed_count % 10 == 0: - log_chunk_progress(batch_start + completed_count - 1, total_chunks, start_time, total_audio_duration) - - except Exception as e: - logging.error(f"Future failed in batch: {e}") - - # Clean up model after batch - print(f"🧹 Cleaning up after batch {batch_start+1}-{batch_end}") - del model - if asr_model: - del asr_model - torch.cuda.empty_cache() - gc.collect() - time.sleep(2) - - all_results.extend(batch_results) - print(f"βœ… Batch {batch_start+1}-{batch_end} completed ({len(batch_results)} chunks)") - - # Final processing - quarantine_dir = audio_chunks_dir / "quarantine" - pause_for_chunk_review(quarantine_dir) - - # Collect final chunk paths - chunk_paths = get_audio_files_in_directory(audio_chunks_dir) - - if not chunk_paths: - logging.info(f"{RED}❌ No valid audio chunks found. Skipping concatenation and conversion.{RESET}") - return None, None, [] - - # Calculate timing - elapsed_total = time.time() - start_time - elapsed_td = timedelta(seconds=int(elapsed_total)) - - total_audio_duration_final = sum(get_chunk_audio_duration(chunk_path) for chunk_path in chunk_paths) - audio_duration_td = timedelta(seconds=int(total_audio_duration_final)) - realtime_factor = total_audio_duration_final / elapsed_total if elapsed_total > 0 else 0.0 - - print(f"\n⏱️ TTS Processing Complete:") - print(f" Elapsed Time: {CYAN}{str(elapsed_td)}{RESET}") - print(f" Audio Duration: {GREEN}{str(audio_duration_td)}{RESET}") - print(f" Realtime Factor: {YELLOW}{realtime_factor:.2f}x{RESET}") - - # Combine audio - voice_name = voice_path.stem if hasattr(voice_path, 'stem') else Path(voice_path).stem - combined_wav_path = output_root / f"{book_dir.name} [{voice_name}].wav" - print("\nπŸ’Ύ Saving WAV file...") - combine_audio_chunks(chunk_paths, combined_wav_path) - - # M4B conversion with normalization - temp_m4b_path = output_root / "output.m4b" - final_m4b_path = output_root / f"{book_dir.name}[{voice_name}].m4b" - convert_to_m4b(combined_wav_path, temp_m4b_path) - add_metadata_to_m4b(temp_m4b_path, final_m4b_path, cover_file, nfo_file) - - logging.info(f"Audiobook created: {final_m4b_path}") - - # Add final info to run log - run_log_lines.extend([ - f"Combined WAV: {combined_wav_path}", - "--- Generation Settings ---", - f"Batch Processing: Enabled ({BATCH_SIZE} chunks per batch)", - f"ASR Enabled: {enable_asr_user or ENABLE_ASR} (user: {enable_asr_user}, global: {ENABLE_ASR})", - f"Hum Detection: {ENABLE_HUM_DETECTION}", - f"Dynamic Workers: {USE_DYNAMIC_WORKERS}", - f"Voice used: {voice_name}", - f"Exaggeration: {tts_params['exaggeration']}", - f"CFG weight: {tts_params['cfg_weight']}", - f"Temperature: {tts_params['temperature']}", - f"Processing Time: {str(elapsed_td)}", - f"Audio Duration: {str(audio_duration_td)}", - f"Realtime Factor: {realtime_factor:.2f}x", - f"Total Chunks: {len(chunk_paths)}" - ]) - - # Write the run log - log_run("\n".join(run_log_lines), output_root / "run.log") - print(f"πŸ“ Run log written to: {output_root / 'run.log'}") - - return final_m4b_path, combined_wav_path, run_log_lines diff --git a/HF_Deploy/modules/voice_detector.py b/HF_Deploy/modules/voice_detector.py deleted file mode 100644 index 48dfad5bfd4573c97e51eb86e73d0195081894e6..0000000000000000000000000000000000000000 --- a/HF_Deploy/modules/voice_detector.py +++ /dev/null @@ -1,240 +0,0 @@ -""" -Voice Detection Module -Handles voice detection from multiple sources: JSON metadata, log files, filenames -""" - -import re -import json -from pathlib import Path -from config.config import AUDIOBOOK_ROOT -from modules.file_manager import list_voice_samples - - -def get_likely_voices_for_book(book_name, chunks_json_path=None): - """ - Get the most likely voice candidates for a book using the 3 detection methods: - 1. JSON metadata/comments (if available) - 2. run.log file - 3. Generated audiobook filenames (may return multiple) - - Returns: list of (voice_name, voice_path, detection_method) tuples - """ - print(f"πŸ” Finding likely voices for book: {book_name}") - likely_voices = [] - - # Method 1: Check JSON metadata and comments - if chunks_json_path: - voice_from_json = get_voice_from_json(chunks_json_path) - if voice_from_json: - voice_path = find_voice_file_by_name(voice_from_json) - if voice_path: - likely_voices.append((voice_from_json, voice_path, "json_metadata")) - print(f"βœ… Voice found in JSON: {voice_from_json}") - - # Method 2: Check run.log file - voice_from_log = get_voice_from_log(book_name) - if voice_from_log: - voice_path = find_voice_file_by_name(voice_from_log) - if voice_path: - # Avoid duplicates - if not any(v[0] == voice_from_log for v in likely_voices): - likely_voices.append((voice_from_log, voice_path, "run_log")) - print(f"βœ… Voice found in run.log: {voice_from_log}") - - # Method 3: Check generated filename patterns (may find multiple) - voices_from_files = get_voices_from_filenames(book_name) - for voice_name in voices_from_files: - voice_path = find_voice_file_by_name(voice_name) - if voice_path: - # Avoid duplicates - if not any(v[0] == voice_name for v in likely_voices): - likely_voices.append((voice_name, voice_path, "filename_pattern")) - print(f"βœ… Voice found in filename: {voice_name}") - - if not likely_voices: - print(f"⚠️ No likely voices detected for {book_name}") - else: - print(f"πŸ“‹ Found {len(likely_voices)} likely voice candidates") - - return likely_voices - -def detect_voice_for_book(book_name, chunks_json_path=None): - """ - Detect the most likely voice for a book (returns first candidate) - For backwards compatibility with existing code - """ - likely_voices = get_likely_voices_for_book(book_name, chunks_json_path) - if likely_voices: - return likely_voices[0] # Return the first (most likely) candidate - return None, None, "not_found" - - -def get_voice_from_json(json_path): - """Extract voice information from JSON metadata""" - try: - with open(json_path, 'r', encoding='utf-8') as f: - content = f.read() - - # Check for voice metadata in JSON - if '"voice_used":' in content: - data = json.loads(content) - if isinstance(data, dict) and 'voice_used' in data: - return data['voice_used'] - elif isinstance(data, list) and data and 'voice_used' in data[0]: - return data[0]['voice_used'] - - # Check for voice as comment in JSON (fallback option) - voice_comment_match = re.search(r'//\s*voice:\s*([^\n]+)', content, re.IGNORECASE) - if voice_comment_match: - return voice_comment_match.group(1).strip() - - except Exception as e: - print(f"⚠️ Error reading JSON for voice info: {e}") - - return None - - -def get_voice_from_log(book_name): - """Extract voice information from run.log file""" - audiobook_root = Path(AUDIOBOOK_ROOT) - log_file = audiobook_root / book_name / "run.log" - - if log_file.exists(): - try: - with open(log_file, 'r', encoding='utf-8') as f: - for line in f: - line = line.strip() - if line.startswith("Voice: ") or line.startswith("Voice used: "): - voice_name = line.split(": ", 1)[1].strip() - return voice_name - except Exception as e: - print(f"⚠️ Error reading run log: {e}") - - return None - - -def get_voices_from_filenames(book_name): - """Extract voice names from existing audiobook filename patterns (may return multiple)""" - audiobook_root = Path(AUDIOBOOK_ROOT) - book_dir = audiobook_root / book_name - - if not book_dir.exists(): - return [] - - found_voices = [] - - # Look for WAV files with voice pattern: BookName [VoiceName].wav - for wav_file in book_dir.glob("*.wav"): - match = re.search(r'\[([^\]]+)\]\.wav$', wav_file.name) - if match: - voice_name = match.group(1) - if voice_name not in found_voices: - found_voices.append(voice_name) - - # Look for M4B files with voice pattern: BookName[VoiceName].m4b - for m4b_file in book_dir.glob("*.m4b"): - match = re.search(r'\[([^\]]+)\]\.m4b$', m4b_file.name) - if match: - voice_name = match.group(1) - if voice_name not in found_voices: - found_voices.append(voice_name) - - return found_voices - -def get_voice_from_filename(book_name): - """Extract voice name from existing audiobook filename patterns (backwards compatibility)""" - voices = get_voices_from_filenames(book_name) - return voices[0] if voices else None - - -def find_voice_file_by_name(voice_name): - """Find voice file by name in Voice_Samples directory""" - voice_files = list_voice_samples() - - # Exact match first - for voice_file in voice_files: - if voice_file.stem == voice_name: - return voice_file - - # Partial match (case insensitive) - voice_name_lower = voice_name.lower() - for voice_file in voice_files: - if voice_name_lower in voice_file.stem.lower(): - return voice_file - - return None - - - - -def add_voice_to_json(json_path, voice_name, method="metadata"): - """ - Add voice information to JSON file - - method options: - - "metadata": Add as top-level metadata - - "comment": Add as comment that doesn't affect parsing - """ - try: - with open(json_path, 'r', encoding='utf-8') as f: - content = f.read() - - if method == "metadata": - # Add voice as metadata to JSON structure - data = json.loads(content) - - if isinstance(data, list): - # For list format, add metadata as first element or update existing - if data and isinstance(data[0], dict) and not any(key.startswith('text') for key in data[0].keys()): - # First element is already metadata - data[0]['voice_used'] = voice_name - else: - # Insert metadata as first element - metadata = {"voice_used": voice_name, "_metadata": True} - data.insert(0, metadata) - elif isinstance(data, dict): - # For dict format, add to top level - data['voice_used'] = voice_name - - # Save updated JSON - with open(json_path, 'w', encoding='utf-8') as f: - json.dump(data, f, indent=2, ensure_ascii=False) - - elif method == "comment": - # Add voice as comment at the top of file - voice_comment = f"// voice: {voice_name}\n" - - if not content.startswith("// voice:"): - content = voice_comment + content - with open(json_path, 'w', encoding='utf-8') as f: - f.write(content) - - print(f"βœ… Added voice '{voice_name}' to {json_path.name} using {method} method") - return True - - except Exception as e: - print(f"❌ Error adding voice to JSON: {e}") - return False - - -def remove_voice_comment_from_json(json_path): - """Remove voice comment from JSON file for clean processing""" - try: - with open(json_path, 'r', encoding='utf-8') as f: - content = f.read() - - # Remove voice comment lines - lines = content.split('\n') - filtered_lines = [line for line in lines if not line.strip().startswith('// voice:')] - - if len(filtered_lines) != len(lines): - # Comments were removed, save cleaned version - cleaned_content = '\n'.join(filtered_lines) - with open(json_path, 'w', encoding='utf-8') as f: - f.write(cleaned_content) - return True - - except Exception as e: - print(f"⚠️ Error cleaning JSON comments: {e}") - - return False \ No newline at end of file diff --git a/HF_Deploy/requirements.txt b/HF_Deploy/requirements.txt deleted file mode 100644 index 0fef73fab6ff0874881aa235ca25baa2ba0741a0..0000000000000000000000000000000000000000 --- a/HF_Deploy/requirements.txt +++ /dev/null @@ -1,56 +0,0 @@ -# ChatterboxTTS HuggingFace Spaces Requirements -# Optimized for HF Spaces environment with flexible versions - -# Core ML and TTS - Essential (pinned versions for fast builds) -torch==2.6.0 -torchaudio==2.6.0 -transformers==4.46.3 -huggingface_hub>=0.15.0 -safetensors>=0.3.0 - -# Audio processing - Required -soundfile>=0.12.0 -librosa>=0.9.0 -pydub>=0.25.0 -audioread>=3.0.0 - -# ASR System - Intelligent ASR with fallback -openai-whisper>=20231117 - -# System monitoring and resource detection -psutil>=5.8.0 -pynvml>=11.0.0 - -# Core scientific computing (pinned for fast builds) -numpy==2.2.0 -scipy>=1.7.0 - -# Text processing -regex>=2023.0.0 -vaderSentiment>=3.3.0 - -# Web interface - Gradio (let HF manage version) -gradio>=4.0.0 - -# Progress and logging -tqdm>=4.60.0 - -# File handling -pathlib2>=2.3.0 - -# Configuration and utilities -python-dotenv>=1.0.0 - -# Optional utilities -requests>=2.25.0 -packaging>=21.0 - -# Core ChatterboxTTS model dependencies -chatterbox-tts>=0.1.2 -resemble-perth>=1.0.1 -omegaconf>=2.3.0 -einops>=0.6.0 -diffusers>=0.21.0 -tokenizers>=0.13.0 -conformer>=0.3.0 -s3tokenizer==0.2.0 \ No newline at end of file diff --git a/HF_Deploy/src/chatterbox/__init__.py b/HF_Deploy/src/chatterbox/__init__.py deleted file mode 100644 index c8aa565d6cf00b8eaf2b7896ea751bb8091fc77a..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .tts import ChatterboxTTS -from .vc import ChatterboxVC diff --git a/HF_Deploy/src/chatterbox/models/s3gen/__init__.py b/HF_Deploy/src/chatterbox/models/s3gen/__init__.py deleted file mode 100644 index bef618df9cff52479712a67364c1922b2e27ebff..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/s3gen/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .s3gen import S3Token2Wav as S3Gen -from .const import S3GEN_SR diff --git a/HF_Deploy/src/chatterbox/models/s3gen/const.py b/HF_Deploy/src/chatterbox/models/s3gen/const.py deleted file mode 100644 index 72de6a2355d1c30dc9ff3ad7ab83df64ea8a17df..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/s3gen/const.py +++ /dev/null @@ -1 +0,0 @@ -S3GEN_SR = 24000 diff --git a/HF_Deploy/src/chatterbox/models/s3gen/decoder.py b/HF_Deploy/src/chatterbox/models/s3gen/decoder.py deleted file mode 100644 index c568c2dfabd760aa2ee7dcfc688a19a2b5bc6484..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/s3gen/decoder.py +++ /dev/null @@ -1,317 +0,0 @@ -# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import pack, rearrange, repeat - -from .utils.mask import add_optional_chunk_mask -from .matcha.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, \ - TimestepEmbedding, Upsample1D -from .matcha.transformer import BasicTransformerBlock - - -def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: - assert mask.dtype == torch.bool - assert dtype in [torch.float32, torch.bfloat16, torch.float16] - mask = mask.to(dtype) - # attention mask bias - # NOTE(Mddct): torch.finfo jit issues - # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min - mask = (1.0 - mask) * -1.0e+10 - return mask - - - -class Transpose(torch.nn.Module): - def __init__(self, dim0: int, dim1: int): - super().__init__() - self.dim0 = dim0 - self.dim1 = dim1 - - def forward(self, x: torch.Tensor): - x = torch.transpose(x, self.dim0, self.dim1) - return x - - -class CausalBlock1D(Block1D): - def __init__(self, dim: int, dim_out: int): - super(CausalBlock1D, self).__init__(dim, dim_out) - self.block = torch.nn.Sequential( - CausalConv1d(dim, dim_out, 3), - Transpose(1, 2), - nn.LayerNorm(dim_out), - Transpose(1, 2), - nn.Mish(), - ) - - def forward(self, x: torch.Tensor, mask: torch.Tensor): - output = self.block(x * mask) - return output * mask - - -class CausalResnetBlock1D(ResnetBlock1D): - def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8): - super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups) - self.block1 = CausalBlock1D(dim, dim_out) - self.block2 = CausalBlock1D(dim_out, dim_out) - - -class CausalConv1d(torch.nn.Conv1d): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - stride: int = 1, - dilation: int = 1, - groups: int = 1, - bias: bool = True, - padding_mode: str = 'zeros', - device=None, - dtype=None - ) -> None: - super(CausalConv1d, self).__init__(in_channels, out_channels, - kernel_size, stride, - padding=0, dilation=dilation, - groups=groups, bias=bias, - padding_mode=padding_mode, - device=device, dtype=dtype) - assert stride == 1 - self.causal_padding = (kernel_size - 1, 0) - - def forward(self, x: torch.Tensor): - x = F.pad(x, self.causal_padding) - x = super(CausalConv1d, self).forward(x) - return x - - -class ConditionalDecoder(nn.Module): - def __init__( - self, - in_channels=320, - out_channels=80, - causal=True, - channels=[256], - dropout=0.0, - attention_head_dim=64, - n_blocks=4, - num_mid_blocks=12, - num_heads=8, - act_fn="gelu", - ): - """ - This decoder requires an input with the same shape of the target. So, if your text content - is shorter or longer than the outputs, please re-sampling it before feeding to the decoder. - """ - super().__init__() - channels = tuple(channels) - self.in_channels = in_channels - self.out_channels = out_channels - self.causal = causal - self.time_embeddings = SinusoidalPosEmb(in_channels) - time_embed_dim = channels[0] * 4 - self.time_mlp = TimestepEmbedding( - in_channels=in_channels, - time_embed_dim=time_embed_dim, - act_fn="silu", - ) - self.down_blocks = nn.ModuleList([]) - self.mid_blocks = nn.ModuleList([]) - self.up_blocks = nn.ModuleList([]) - - # NOTE jrm: `static_chunk_size` is missing? - self.static_chunk_size = 0 - - output_channel = in_channels - for i in range(len(channels)): # pylint: disable=consider-using-enumerate - input_channel = output_channel - output_channel = channels[i] - is_last = i == len(channels) - 1 - resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \ - ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) - transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - dim=output_channel, - num_attention_heads=num_heads, - attention_head_dim=attention_head_dim, - dropout=dropout, - activation_fn=act_fn, - ) - for _ in range(n_blocks) - ] - ) - downsample = ( - Downsample1D(output_channel) if not is_last else - CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1) - ) - self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample])) - - for _ in range(num_mid_blocks): - input_channel = channels[-1] - out_channels = channels[-1] - resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \ - ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) - - transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - dim=output_channel, - num_attention_heads=num_heads, - attention_head_dim=attention_head_dim, - dropout=dropout, - activation_fn=act_fn, - ) - for _ in range(n_blocks) - ] - ) - - self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks])) - - channels = channels[::-1] + (channels[0],) - for i in range(len(channels) - 1): - input_channel = channels[i] * 2 - output_channel = channels[i + 1] - is_last = i == len(channels) - 2 - resnet = CausalResnetBlock1D( - dim=input_channel, - dim_out=output_channel, - time_emb_dim=time_embed_dim, - ) if self.causal else ResnetBlock1D( - dim=input_channel, - dim_out=output_channel, - time_emb_dim=time_embed_dim, - ) - transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - dim=output_channel, - num_attention_heads=num_heads, - attention_head_dim=attention_head_dim, - dropout=dropout, - activation_fn=act_fn, - ) - for _ in range(n_blocks) - ] - ) - upsample = ( - Upsample1D(output_channel, use_conv_transpose=True) - if not is_last - else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1) - ) - self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample])) - self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1]) - self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1) - self.initialize_weights() - - def initialize_weights(self): - for m in self.modules(): - if isinstance(m, nn.Conv1d): - nn.init.kaiming_normal_(m.weight, nonlinearity="relu") - if m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.GroupNorm): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.Linear): - nn.init.kaiming_normal_(m.weight, nonlinearity="relu") - if m.bias is not None: - nn.init.constant_(m.bias, 0) - - def forward(self, x, mask, mu, t, spks=None, cond=None): - """Forward pass of the UNet1DConditional model. - - Args: - x (torch.Tensor): shape (batch_size, in_channels, time) - mask (_type_): shape (batch_size, 1, time) - t (_type_): shape (batch_size) - spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None. - cond (_type_, optional): placeholder for future use. Defaults to None. - - Raises: - ValueError: _description_ - ValueError: _description_ - - Returns: - _type_: _description_ - """ - - t = self.time_embeddings(t).to(t.dtype) - t = self.time_mlp(t) - - x = pack([x, mu], "b * t")[0] - - if spks is not None: - spks = repeat(spks, "b c -> b c t", t=x.shape[-1]) - x = pack([x, spks], "b * t")[0] - if cond is not None: - x = pack([x, cond], "b * t")[0] - - hiddens = [] - masks = [mask] - for resnet, transformer_blocks, downsample in self.down_blocks: - mask_down = masks[-1] - x = resnet(x, mask_down, t) - x = rearrange(x, "b c t -> b t c").contiguous() - # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down) - attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1) - attn_mask = mask_to_bias(attn_mask == 1, x.dtype) - for transformer_block in transformer_blocks: - x = transformer_block( - hidden_states=x, - attention_mask=attn_mask, - timestep=t, - ) - x = rearrange(x, "b t c -> b c t").contiguous() - hiddens.append(x) # Save hidden states for skip connections - x = downsample(x * mask_down) - masks.append(mask_down[:, :, ::2]) - masks = masks[:-1] - mask_mid = masks[-1] - - for resnet, transformer_blocks in self.mid_blocks: - x = resnet(x, mask_mid, t) - x = rearrange(x, "b c t -> b t c").contiguous() - # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid) - attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1) - attn_mask = mask_to_bias(attn_mask == 1, x.dtype) - for transformer_block in transformer_blocks: - x = transformer_block( - hidden_states=x, - attention_mask=attn_mask, - timestep=t, - ) - x = rearrange(x, "b t c -> b c t").contiguous() - - for resnet, transformer_blocks, upsample in self.up_blocks: - mask_up = masks.pop() - skip = hiddens.pop() - x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0] - x = resnet(x, mask_up, t) - x = rearrange(x, "b c t -> b t c").contiguous() - # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up) - attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1) - attn_mask = mask_to_bias(attn_mask == 1, x.dtype) - for transformer_block in transformer_blocks: - x = transformer_block( - hidden_states=x, - attention_mask=attn_mask, - timestep=t, - ) - x = rearrange(x, "b t c -> b c t").contiguous() - x = upsample(x * mask_up) - x = self.final_block(x, mask_up) - output = self.final_proj(x * mask_up) - return output * mask diff --git a/HF_Deploy/src/chatterbox/models/s3gen/f0_predictor.py b/HF_Deploy/src/chatterbox/models/s3gen/f0_predictor.py deleted file mode 100644 index 172c5f50bdece3d4ac2b3874b0a32deb9f957b93..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/s3gen/f0_predictor.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import torch -import torch.nn as nn -from torch.nn.utils.parametrizations import weight_norm - - -class ConvRNNF0Predictor(nn.Module): - def __init__(self, - num_class: int = 1, - in_channels: int = 80, - cond_channels: int = 512 - ): - super().__init__() - - self.num_class = num_class - self.condnet = nn.Sequential( - weight_norm( - nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1) - ), - nn.ELU(), - weight_norm( - nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) - ), - nn.ELU(), - weight_norm( - nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) - ), - nn.ELU(), - weight_norm( - nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) - ), - nn.ELU(), - weight_norm( - nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) - ), - nn.ELU(), - ) - self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.condnet(x) - x = x.transpose(1, 2) - return torch.abs(self.classifier(x).squeeze(-1)) diff --git a/HF_Deploy/src/chatterbox/models/s3gen/flow.py b/HF_Deploy/src/chatterbox/models/s3gen/flow.py deleted file mode 100644 index a460ddef5db032967e849a2c4e134fcdf58d622d..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/s3gen/flow.py +++ /dev/null @@ -1,242 +0,0 @@ -# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -import random -from typing import Dict, Optional -import torch -import torch.nn as nn -from torch.nn import functional as F -from omegaconf import DictConfig -from .utils.mask import make_pad_mask - - -class MaskedDiffWithXvec(torch.nn.Module): - def __init__(self, - input_size: int = 512, - output_size: int = 80, - spk_embed_dim: int = 192, - output_type: str = "mel", - vocab_size: int = 4096, - input_frame_rate: int = 50, - only_mask_loss: bool = True, - encoder: torch.nn.Module = None, - length_regulator: torch.nn.Module = None, - decoder: torch.nn.Module = None, - decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1, - 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine', - 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}), - 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64, - 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}, - mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050, - 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}): - super().__init__() - self.input_size = input_size - self.output_size = output_size - self.decoder_conf = decoder_conf - self.mel_feat_conf = mel_feat_conf - self.vocab_size = vocab_size - self.output_type = output_type - self.input_frame_rate = input_frame_rate - logging.info(f"input frame rate={self.input_frame_rate}") - self.input_embedding = nn.Embedding(vocab_size, input_size) - self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size) - self.encoder = encoder - self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size) - self.decoder = decoder - self.length_regulator = length_regulator - self.only_mask_loss = only_mask_loss - - def forward( - self, - batch: dict, - device: torch.device, - ) -> Dict[str, Optional[torch.Tensor]]: - token = batch['speech_token'].to(device) - token_len = batch['speech_token_len'].to(device) - feat = batch['speech_feat'].to(device) - feat_len = batch['speech_feat_len'].to(device) - embedding = batch['embedding'].to(device) - - # xvec projection - embedding = F.normalize(embedding, dim=1) - embedding = self.spk_embed_affine_layer(embedding) - - # concat text and prompt_text - mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device) - token = self.input_embedding(torch.clamp(token, min=0)) * mask - - # text encode - h, h_lengths = self.encoder(token, token_len) - h = self.encoder_proj(h) - h, h_lengths = self.length_regulator(h, feat_len) - - # get conditions - conds = torch.zeros(feat.shape, device=token.device) - for i, j in enumerate(feat_len): - if random.random() < 0.5: - continue - index = random.randint(0, int(0.3 * j)) - conds[i, :index] = feat[i, :index] - conds = conds.transpose(1, 2) - - mask = (~make_pad_mask(feat_len)).to(h) - feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1) - loss, _ = self.decoder.compute_loss( - feat.transpose(1, 2).contiguous(), - mask.unsqueeze(1), - h.transpose(1, 2).contiguous(), - embedding, - cond=conds - ) - return {'loss': loss} - - @torch.inference_mode() - def inference(self, - token, - token_len, - prompt_token, - prompt_token_len, - prompt_feat, - prompt_feat_len, - embedding, - flow_cache): - if self.fp16 is True: - prompt_feat = prompt_feat.half() - embedding = embedding.half() - - assert token.shape[0] == 1 - # xvec projection - embedding = F.normalize(embedding, dim=1) - embedding = self.spk_embed_affine_layer(embedding) - - # concat text and prompt_text - token_len1, token_len2 = prompt_token.shape[1], token.shape[1] - token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len - mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding) - token = self.input_embedding(torch.clamp(token, min=0)) * mask - - # text encode - h, h_lengths = self.encoder(token, token_len) - h = self.encoder_proj(h) - mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256) - h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate) - - # get conditions - conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype) - conds[:, :mel_len1] = prompt_feat - conds = conds.transpose(1, 2) - - mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h) - feat, flow_cache = self.decoder( - mu=h.transpose(1, 2).contiguous(), - mask=mask.unsqueeze(1), - spks=embedding, - cond=conds, - n_timesteps=10, - prompt_len=mel_len1, - flow_cache=flow_cache - ) - feat = feat[:, :, mel_len1:] - assert feat.shape[2] == mel_len2 - return feat.float(), flow_cache - - -class CausalMaskedDiffWithXvec(torch.nn.Module): - def __init__(self, - input_size: int = 512, - output_size: int = 80, - spk_embed_dim: int = 192, - output_type: str = "mel", - vocab_size: int = 6561, - input_frame_rate: int = 25, - only_mask_loss: bool = True, - token_mel_ratio: int = 2, - pre_lookahead_len: int = 3, - encoder: torch.nn.Module = None, - decoder: torch.nn.Module = None, - decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1, - 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine', - 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}), - 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64, - 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}, - mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050, - 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}): - super().__init__() - self.input_size = input_size - self.output_size = output_size - self.decoder_conf = decoder_conf - self.mel_feat_conf = mel_feat_conf - self.vocab_size = vocab_size - self.output_type = output_type - self.input_frame_rate = input_frame_rate - logging.info(f"input frame rate={self.input_frame_rate}") - self.input_embedding = nn.Embedding(vocab_size, input_size) - self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size) - self.encoder = encoder - self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size) - self.decoder = decoder - self.only_mask_loss = only_mask_loss - self.token_mel_ratio = token_mel_ratio - self.pre_lookahead_len = pre_lookahead_len - - # FIXME: this was missing - just putting it in as false - self.fp16 = False - - @torch.inference_mode() - def inference(self, - token, - token_len, - prompt_token, - prompt_token_len, - prompt_feat, - prompt_feat_len, - embedding, - finalize): - if self.fp16 is True: - prompt_feat = prompt_feat.half() - embedding = embedding.half() - - assert token.shape[0] == 1 - # xvec projection - embedding = F.normalize(embedding, dim=1) - embedding = self.spk_embed_affine_layer(embedding) - - # concat text and prompt_text - token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len - mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding) - token = self.input_embedding(torch.clamp(token, min=0)) * mask - - # text encode - h, h_lengths = self.encoder(token, token_len) - if finalize is False: - h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio] - mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1] - h = self.encoder_proj(h) - - # get conditions - conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype) - conds[:, :mel_len1] = prompt_feat - conds = conds.transpose(1, 2) - - mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h) - feat, _ = self.decoder( - mu=h.transpose(1, 2).contiguous(), - mask=mask.unsqueeze(1), - spks=embedding, - cond=conds, - n_timesteps=10 - ) - feat = feat[:, :, mel_len1:] - assert feat.shape[2] == mel_len2 - return feat.float(), None # NOTE jrm: why are they returning None here? diff --git a/HF_Deploy/src/chatterbox/models/s3gen/flow_matching.py b/HF_Deploy/src/chatterbox/models/s3gen/flow_matching.py deleted file mode 100644 index 8307e3c0d6120a81b6ff414fafa30e9fc63d015c..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/s3gen/flow_matching.py +++ /dev/null @@ -1,228 +0,0 @@ -# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import threading -import torch -import torch.nn.functional as F -from .matcha.flow_matching import BASECFM -from omegaconf import OmegaConf - - -CFM_PARAMS = OmegaConf.create({ - "sigma_min": 1e-06, - "solver": "euler", - "t_scheduler": "cosine", - "training_cfg_rate": 0.2, - "inference_cfg_rate": 0.7, - "reg_loss_type": "l1" -}) - - -class ConditionalCFM(BASECFM): - def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None): - super().__init__( - n_feats=in_channels, - cfm_params=cfm_params, - n_spks=n_spks, - spk_emb_dim=spk_emb_dim, - ) - self.t_scheduler = cfm_params.t_scheduler - self.training_cfg_rate = cfm_params.training_cfg_rate - self.inference_cfg_rate = cfm_params.inference_cfg_rate - in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0) - # Just change the architecture of the estimator here - self.estimator = estimator - self.lock = threading.Lock() - - @torch.inference_mode() - def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)): - """Forward diffusion - - Args: - mu (torch.Tensor): output of encoder - shape: (batch_size, n_feats, mel_timesteps) - mask (torch.Tensor): output_mask - shape: (batch_size, 1, mel_timesteps) - n_timesteps (int): number of diffusion steps - temperature (float, optional): temperature for scaling noise. Defaults to 1.0. - spks (torch.Tensor, optional): speaker ids. Defaults to None. - shape: (batch_size, spk_emb_dim) - cond: Not used but kept for future purposes - - Returns: - sample: generated mel-spectrogram - shape: (batch_size, n_feats, mel_timesteps) - """ - - z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature - cache_size = flow_cache.shape[2] - # fix prompt and overlap part mu and z - if cache_size != 0: - z[:, :, :cache_size] = flow_cache[:, :, :, 0] - mu[:, :, :cache_size] = flow_cache[:, :, :, 1] - z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2) - mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2) - flow_cache = torch.stack([z_cache, mu_cache], dim=-1) - - t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) - if self.t_scheduler == 'cosine': - t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) - return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache - - def solve_euler(self, x, t_span, mu, mask, spks, cond): - """ - Fixed euler solver for ODEs. - Args: - x (torch.Tensor): random noise - t_span (torch.Tensor): n_timesteps interpolated - shape: (n_timesteps + 1,) - mu (torch.Tensor): output of encoder - shape: (batch_size, n_feats, mel_timesteps) - mask (torch.Tensor): output_mask - shape: (batch_size, 1, mel_timesteps) - spks (torch.Tensor, optional): speaker ids. Defaults to None. - shape: (batch_size, spk_emb_dim) - cond: Not used but kept for future purposes - """ - t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] - t = t.unsqueeze(dim=0) - - # I am storing this because I can later plot it by putting a debugger here and saving it to a file - # Or in future might add like a return_all_steps flag - sol = [] - - # Do not use concat, it may cause memory format changed and trt infer with wrong results! - x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype) - mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype) - mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype) - t_in = torch.zeros([2], device=x.device, dtype=x.dtype) - spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype) - cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype) - for step in range(1, len(t_span)): - # Classifier-Free Guidance inference introduced in VoiceBox - x_in[:] = x - mask_in[:] = mask - mu_in[0] = mu - t_in[:] = t.unsqueeze(0) - spks_in[0] = spks - cond_in[0] = cond - dphi_dt = self.forward_estimator( - x_in, mask_in, - mu_in, t_in, - spks_in, - cond_in - ) - dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0) - dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt) - x = x + dt * dphi_dt - t = t + dt - sol.append(x) - if step < len(t_span) - 1: - dt = t_span[step + 1] - t - - return sol[-1].float() - - def forward_estimator(self, x, mask, mu, t, spks, cond): - if isinstance(self.estimator, torch.nn.Module): - return self.estimator.forward(x, mask, mu, t, spks, cond) - else: - with self.lock: - self.estimator.set_input_shape('x', (2, 80, x.size(2))) - self.estimator.set_input_shape('mask', (2, 1, x.size(2))) - self.estimator.set_input_shape('mu', (2, 80, x.size(2))) - self.estimator.set_input_shape('t', (2,)) - self.estimator.set_input_shape('spks', (2, 80)) - self.estimator.set_input_shape('cond', (2, 80, x.size(2))) - # run trt engine - self.estimator.execute_v2([x.contiguous().data_ptr(), - mask.contiguous().data_ptr(), - mu.contiguous().data_ptr(), - t.contiguous().data_ptr(), - spks.contiguous().data_ptr(), - cond.contiguous().data_ptr(), - x.data_ptr()]) - return x - - def compute_loss(self, x1, mask, mu, spks=None, cond=None): - """Computes diffusion loss - - Args: - x1 (torch.Tensor): Target - shape: (batch_size, n_feats, mel_timesteps) - mask (torch.Tensor): target mask - shape: (batch_size, 1, mel_timesteps) - mu (torch.Tensor): output of encoder - shape: (batch_size, n_feats, mel_timesteps) - spks (torch.Tensor, optional): speaker embedding. Defaults to None. - shape: (batch_size, spk_emb_dim) - - Returns: - loss: conditional flow matching loss - y: conditional flow - shape: (batch_size, n_feats, mel_timesteps) - """ - b, _, t = mu.shape - - # random timestep - t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) - if self.t_scheduler == 'cosine': - t = 1 - torch.cos(t * 0.5 * torch.pi) - # sample noise p(x_0) - z = torch.randn_like(x1) - - y = (1 - (1 - self.sigma_min) * t) * z + t * x1 - u = x1 - (1 - self.sigma_min) * z - - # during training, we randomly drop condition to trade off mode coverage and sample fidelity - if self.training_cfg_rate > 0: - cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate - mu = mu * cfg_mask.view(-1, 1, 1) - spks = spks * cfg_mask.view(-1, 1) - cond = cond * cfg_mask.view(-1, 1, 1) - - pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond) - loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1]) - return loss, y - - -class CausalConditionalCFM(ConditionalCFM): - def __init__(self, in_channels=240, cfm_params=CFM_PARAMS, n_spks=1, spk_emb_dim=80, estimator=None): - super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator) - self.rand_noise = torch.randn([1, 80, 50 * 300]) - - @torch.inference_mode() - def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): - """Forward diffusion - - Args: - mu (torch.Tensor): output of encoder - shape: (batch_size, n_feats, mel_timesteps) - mask (torch.Tensor): output_mask - shape: (batch_size, 1, mel_timesteps) - n_timesteps (int): number of diffusion steps - temperature (float, optional): temperature for scaling noise. Defaults to 1.0. - spks (torch.Tensor, optional): speaker ids. Defaults to None. - shape: (batch_size, spk_emb_dim) - cond: Not used but kept for future purposes - - Returns: - sample: generated mel-spectrogram - shape: (batch_size, n_feats, mel_timesteps) - """ - - z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature - # fix prompt and overlap part mu and z - t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) - if self.t_scheduler == 'cosine': - t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) - return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None diff --git a/HF_Deploy/src/chatterbox/models/s3gen/hifigan.py b/HF_Deploy/src/chatterbox/models/s3gen/hifigan.py deleted file mode 100644 index 33f9387e8018169d175fba777a9d70d89035348a..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/s3gen/hifigan.py +++ /dev/null @@ -1,474 +0,0 @@ -# jrm: adapted from CosyVoice/cosyvoice/hifigan/generator.py -# most modules should be reusable, but I found their SineGen changed a git. - -# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""HIFI-GAN""" - -from typing import Dict, Optional, List -import numpy as np -from scipy.signal import get_window -import torch -import torch.nn.functional as F -from torch.nn import Conv1d -from torch.nn import ConvTranspose1d -from torch.nn.utils import remove_weight_norm -from torch.nn.utils.parametrizations import weight_norm -from torch.distributions.uniform import Uniform -from torch import nn, sin, pow -from torch.nn import Parameter - - -class Snake(nn.Module): - ''' - Implementation of a sine-based periodic activation function - Shape: - - Input: (B, C, T) - - Output: (B, C, T), same shape as the input - Parameters: - - alpha - trainable parameter - References: - - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: - https://arxiv.org/abs/2006.08195 - Examples: - >>> a1 = snake(256) - >>> x = torch.randn(256) - >>> x = a1(x) - ''' - def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): - ''' - Initialization. - INPUT: - - in_features: shape of the input - - alpha: trainable parameter - alpha is initialized to 1 by default, higher values = higher-frequency. - alpha will be trained along with the rest of your model. - ''' - super(Snake, self).__init__() - self.in_features = in_features - - # initialize alpha - self.alpha_logscale = alpha_logscale - if self.alpha_logscale: # log scale alphas initialized to zeros - self.alpha = Parameter(torch.zeros(in_features) * alpha) - else: # linear scale alphas initialized to ones - self.alpha = Parameter(torch.ones(in_features) * alpha) - - self.alpha.requires_grad = alpha_trainable - - self.no_div_by_zero = 0.000000001 - - def forward(self, x): - ''' - Forward pass of the function. - Applies the function to the input elementwise. - Snake ∢= x + 1/a * sin^2 (xa) - ''' - alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] - if self.alpha_logscale: - alpha = torch.exp(alpha) - x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) - - return x - - - -def get_padding(kernel_size, dilation=1): - return int((kernel_size * dilation - dilation) / 2) - -def init_weights(m, mean=0.0, std=0.01): - classname = m.__class__.__name__ - if classname.find("Conv") != -1: - m.weight.data.normal_(mean, std) - - -"""hifigan based generator implementation. - -This code is modified from https://github.com/jik876/hifi-gan - ,https://github.com/kan-bayashi/ParallelWaveGAN and - https://github.com/NVIDIA/BigVGAN - -""" - - -class ResBlock(torch.nn.Module): - """Residual block module in HiFiGAN/BigVGAN.""" - def __init__( - self, - channels: int = 512, - kernel_size: int = 3, - dilations: List[int] = [1, 3, 5], - ): - super(ResBlock, self).__init__() - self.convs1 = nn.ModuleList() - self.convs2 = nn.ModuleList() - - for dilation in dilations: - self.convs1.append( - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation, - padding=get_padding(kernel_size, dilation) - ) - ) - ) - self.convs2.append( - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1) - ) - ) - ) - self.convs1.apply(init_weights) - self.convs2.apply(init_weights) - self.activations1 = nn.ModuleList([ - Snake(channels, alpha_logscale=False) - for _ in range(len(self.convs1)) - ]) - self.activations2 = nn.ModuleList([ - Snake(channels, alpha_logscale=False) - for _ in range(len(self.convs2)) - ]) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - for idx in range(len(self.convs1)): - xt = self.activations1[idx](x) - xt = self.convs1[idx](xt) - xt = self.activations2[idx](xt) - xt = self.convs2[idx](xt) - x = xt + x - return x - - def remove_weight_norm(self): - for idx in range(len(self.convs1)): - remove_weight_norm(self.convs1[idx]) - remove_weight_norm(self.convs2[idx]) - - -class SineGen(torch.nn.Module): - """ Definition of sine generator - SineGen(samp_rate, harmonic_num = 0, - sine_amp = 0.1, noise_std = 0.003, - voiced_threshold = 0, - flag_for_pulse=False) - samp_rate: sampling rate in Hz - harmonic_num: number of harmonic overtones (default 0) - sine_amp: amplitude of sine-wavefrom (default 0.1) - noise_std: std of Gaussian noise (default 0.003) - voiced_thoreshold: F0 threshold for U/V classification (default 0) - flag_for_pulse: this SinGen is used inside PulseGen (default False) - Note: when flag_for_pulse is True, the first time step of a voiced - segment is always sin(np.pi) or cos(0) - """ - - def __init__(self, samp_rate, harmonic_num=0, - sine_amp=0.1, noise_std=0.003, - voiced_threshold=0): - super(SineGen, self).__init__() - self.sine_amp = sine_amp - self.noise_std = noise_std - self.harmonic_num = harmonic_num - self.sampling_rate = samp_rate - self.voiced_threshold = voiced_threshold - - def _f02uv(self, f0): - # generate uv signal - uv = (f0 > self.voiced_threshold).type(torch.float32) - return uv - - @torch.no_grad() - def forward(self, f0): - """ - :param f0: [B, 1, sample_len], Hz - :return: [B, 1, sample_len] - """ - - F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device) - for i in range(self.harmonic_num + 1): - F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate - - theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1) - u_dist = Uniform(low=-np.pi, high=np.pi) - phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device) - phase_vec[:, 0, :] = 0 - - # generate sine waveforms - sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec) - - # generate uv signal - uv = self._f02uv(f0) - - # noise: for unvoiced should be similar to sine_amp - # std = self.sine_amp/3 -> max value ~ self.sine_amp - # . for voiced regions is self.noise_std - noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 - noise = noise_amp * torch.randn_like(sine_waves) - - # first: set the unvoiced part to 0 by uv - # then: additive noise - sine_waves = sine_waves * uv + noise - return sine_waves, uv, noise - - -class SourceModuleHnNSF(torch.nn.Module): - """ SourceModule for hn-nsf - SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, - add_noise_std=0.003, voiced_threshod=0) - sampling_rate: sampling_rate in Hz - harmonic_num: number of harmonic above F0 (default: 0) - sine_amp: amplitude of sine source signal (default: 0.1) - add_noise_std: std of additive Gaussian noise (default: 0.003) - note that amplitude of noise in unvoiced is decided - by sine_amp - voiced_threshold: threhold to set U/V given F0 (default: 0) - Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) - F0_sampled (batchsize, length, 1) - Sine_source (batchsize, length, 1) - noise_source (batchsize, length 1) - uv (batchsize, length, 1) - """ - - def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1, - add_noise_std=0.003, voiced_threshod=0): - super(SourceModuleHnNSF, self).__init__() - - self.sine_amp = sine_amp - self.noise_std = add_noise_std - - # to produce sine waveforms - self.l_sin_gen = SineGen(sampling_rate, harmonic_num, - sine_amp, add_noise_std, voiced_threshod) - - # to merge source harmonics into a single excitation - self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) - self.l_tanh = torch.nn.Tanh() - - def forward(self, x): - """ - Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) - F0_sampled (batchsize, length, 1) - Sine_source (batchsize, length, 1) - noise_source (batchsize, length 1) - """ - # source for harmonic branch - with torch.no_grad(): - sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2)) - sine_wavs = sine_wavs.transpose(1, 2) - uv = uv.transpose(1, 2) - sine_merge = self.l_tanh(self.l_linear(sine_wavs)) - - # source for noise branch, in the same shape as uv - noise = torch.randn_like(uv) * self.sine_amp / 3 - return sine_merge, noise, uv - - -class HiFTGenerator(nn.Module): - """ - HiFTNet Generator: Neural Source Filter + ISTFTNet - https://arxiv.org/abs/2309.09493 - """ - def __init__( - self, - in_channels: int = 80, - base_channels: int = 512, - nb_harmonics: int = 8, - sampling_rate: int = 22050, - nsf_alpha: float = 0.1, - nsf_sigma: float = 0.003, - nsf_voiced_threshold: float = 10, - upsample_rates: List[int] = [8, 8], - upsample_kernel_sizes: List[int] = [16, 16], - istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4}, - resblock_kernel_sizes: List[int] = [3, 7, 11], - resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], - source_resblock_kernel_sizes: List[int] = [7, 11], - source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]], - lrelu_slope: float = 0.1, - audio_limit: float = 0.99, - f0_predictor: torch.nn.Module = None, - ): - super(HiFTGenerator, self).__init__() - - self.out_channels = 1 - self.nb_harmonics = nb_harmonics - self.sampling_rate = sampling_rate - self.istft_params = istft_params - self.lrelu_slope = lrelu_slope - self.audio_limit = audio_limit - - self.num_kernels = len(resblock_kernel_sizes) - self.num_upsamples = len(upsample_rates) - self.m_source = SourceModuleHnNSF( - sampling_rate=sampling_rate, - upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"], - harmonic_num=nb_harmonics, - sine_amp=nsf_alpha, - add_noise_std=nsf_sigma, - voiced_threshod=nsf_voiced_threshold) - self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"]) - - self.conv_pre = weight_norm( - Conv1d(in_channels, base_channels, 7, 1, padding=3) - ) - - # Up - self.ups = nn.ModuleList() - for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): - self.ups.append( - weight_norm( - ConvTranspose1d( - base_channels // (2**i), - base_channels // (2**(i + 1)), - k, - u, - padding=(k - u) // 2, - ) - ) - ) - - # Down - self.source_downs = nn.ModuleList() - self.source_resblocks = nn.ModuleList() - downsample_rates = [1] + upsample_rates[::-1][:-1] - downsample_cum_rates = np.cumprod(downsample_rates) - for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)): - if u == 1: - self.source_downs.append( - Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1) - ) - else: - self.source_downs.append( - Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2)) - ) - - self.source_resblocks.append( - ResBlock(base_channels // (2 ** (i + 1)), k, d) - ) - - self.resblocks = nn.ModuleList() - for i in range(len(self.ups)): - ch = base_channels // (2**(i + 1)) - for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): - self.resblocks.append(ResBlock(ch, k, d)) - - self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3)) - self.ups.apply(init_weights) - self.conv_post.apply(init_weights) - self.reflection_pad = nn.ReflectionPad1d((1, 0)) - self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32)) - self.f0_predictor = f0_predictor - - def remove_weight_norm(self): - print('Removing weight norm...') - for l in self.ups: - remove_weight_norm(l) - for l in self.resblocks: - l.remove_weight_norm() - remove_weight_norm(self.conv_pre) - remove_weight_norm(self.conv_post) - self.m_source.remove_weight_norm() - for l in self.source_downs: - remove_weight_norm(l) - for l in self.source_resblocks: - l.remove_weight_norm() - - def _stft(self, x): - spec = torch.stft( - x, - self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device), - return_complex=True) - spec = torch.view_as_real(spec) # [B, F, TT, 2] - return spec[..., 0], spec[..., 1] - - def _istft(self, magnitude, phase): - magnitude = torch.clip(magnitude, max=1e2) - real = magnitude * torch.cos(phase) - img = magnitude * torch.sin(phase) - inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], - self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device)) - return inverse_transform - - def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor: - s_stft_real, s_stft_imag = self._stft(s.squeeze(1)) - s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1) - - x = self.conv_pre(x) - for i in range(self.num_upsamples): - x = F.leaky_relu(x, self.lrelu_slope) - x = self.ups[i](x) - - if i == self.num_upsamples - 1: - x = self.reflection_pad(x) - - # fusion - si = self.source_downs[i](s_stft) - si = self.source_resblocks[i](si) - x = x + si - - xs = None - for j in range(self.num_kernels): - if xs is None: - xs = self.resblocks[i * self.num_kernels + j](x) - else: - xs += self.resblocks[i * self.num_kernels + j](x) - x = xs / self.num_kernels - - x = F.leaky_relu(x) - x = self.conv_post(x) - magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :]) - phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy - - x = self._istft(magnitude, phase) - x = torch.clamp(x, -self.audio_limit, self.audio_limit) - return x - - def forward( - self, - batch: dict, - device: torch.device, - ) -> Dict[str, Optional[torch.Tensor]]: - speech_feat = batch['speech_feat'].transpose(1, 2).to(device) - # mel->f0 - f0 = self.f0_predictor(speech_feat) - # f0->source - s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t - s, _, _ = self.m_source(s) - s = s.transpose(1, 2) - # mel+source->speech - generated_speech = self.decode(x=speech_feat, s=s) - return generated_speech, f0 - - @torch.inference_mode() - def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor: - # mel->f0 - f0 = self.f0_predictor(speech_feat) - # f0->source - s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t - s, _, _ = self.m_source(s) - s = s.transpose(1, 2) - # use cache_source to avoid glitch - if cache_source.shape[2] != 0: - s[:, :, :cache_source.shape[2]] = cache_source - generated_speech = self.decode(x=speech_feat, s=s) - return generated_speech, s diff --git a/HF_Deploy/src/chatterbox/models/s3gen/matcha/decoder.py b/HF_Deploy/src/chatterbox/models/s3gen/matcha/decoder.py deleted file mode 100644 index 6919f32d9d0a04a0c734190d8b815abb40ad69db..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/s3gen/matcha/decoder.py +++ /dev/null @@ -1,443 +0,0 @@ -import math -from typing import Optional - -import torch -import torch.nn as nn -import torch.nn.functional as F -from conformer import ConformerBlock -from diffusers.models.activations import get_activation -from einops import pack, rearrange, repeat - -from .transformer import BasicTransformerBlock - - -class SinusoidalPosEmb(torch.nn.Module): - def __init__(self, dim): - super().__init__() - self.dim = dim - assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even" - - def forward(self, x, scale=1000): - if x.ndim < 1: - x = x.unsqueeze(0) - device = x.device - half_dim = self.dim // 2 - emb = math.log(10000) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb) - emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) - emb = torch.cat((emb.sin(), emb.cos()), dim=-1) - return emb - - -class Block1D(torch.nn.Module): - def __init__(self, dim, dim_out, groups=8): - super().__init__() - self.block = torch.nn.Sequential( - torch.nn.Conv1d(dim, dim_out, 3, padding=1), - torch.nn.GroupNorm(groups, dim_out), - nn.Mish(), - ) - - def forward(self, x, mask): - output = self.block(x * mask) - return output * mask - - -class ResnetBlock1D(torch.nn.Module): - def __init__(self, dim, dim_out, time_emb_dim, groups=8): - super().__init__() - self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out)) - - self.block1 = Block1D(dim, dim_out, groups=groups) - self.block2 = Block1D(dim_out, dim_out, groups=groups) - - self.res_conv = torch.nn.Conv1d(dim, dim_out, 1) - - def forward(self, x, mask, time_emb): - h = self.block1(x, mask) - h += self.mlp(time_emb).unsqueeze(-1) - h = self.block2(h, mask) - output = h + self.res_conv(x * mask) - return output - - -class Downsample1D(nn.Module): - def __init__(self, dim): - super().__init__() - self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1) - - def forward(self, x): - return self.conv(x) - - -class TimestepEmbedding(nn.Module): - def __init__( - self, - in_channels: int, - time_embed_dim: int, - act_fn: str = "silu", - out_dim: int = None, - post_act_fn: Optional[str] = None, - cond_proj_dim=None, - ): - super().__init__() - - self.linear_1 = nn.Linear(in_channels, time_embed_dim) - - if cond_proj_dim is not None: - self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) - else: - self.cond_proj = None - - self.act = get_activation(act_fn) - - if out_dim is not None: - time_embed_dim_out = out_dim - else: - time_embed_dim_out = time_embed_dim - self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) - - if post_act_fn is None: - self.post_act = None - else: - self.post_act = get_activation(post_act_fn) - - def forward(self, sample, condition=None): - if condition is not None: - sample = sample + self.cond_proj(condition) - sample = self.linear_1(sample) - - if self.act is not None: - sample = self.act(sample) - - sample = self.linear_2(sample) - - if self.post_act is not None: - sample = self.post_act(sample) - return sample - - -class Upsample1D(nn.Module): - """A 1D upsampling layer with an optional convolution. - - Parameters: - channels (`int`): - number of channels in the inputs and outputs. - use_conv (`bool`, default `False`): - option to use a convolution. - use_conv_transpose (`bool`, default `False`): - option to use a convolution transpose. - out_channels (`int`, optional): - number of output channels. Defaults to `channels`. - """ - - def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_conv_transpose = use_conv_transpose - self.name = name - - self.conv = None - if use_conv_transpose: - self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1) - elif use_conv: - self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1) - - def forward(self, inputs): - assert inputs.shape[1] == self.channels - if self.use_conv_transpose: - return self.conv(inputs) - - outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest") - - if self.use_conv: - outputs = self.conv(outputs) - - return outputs - - -class ConformerWrapper(ConformerBlock): - def __init__( # pylint: disable=useless-super-delegation - self, - *, - dim, - dim_head=64, - heads=8, - ff_mult=4, - conv_expansion_factor=2, - conv_kernel_size=31, - attn_dropout=0, - ff_dropout=0, - conv_dropout=0, - conv_causal=False, - ): - super().__init__( - dim=dim, - dim_head=dim_head, - heads=heads, - ff_mult=ff_mult, - conv_expansion_factor=conv_expansion_factor, - conv_kernel_size=conv_kernel_size, - attn_dropout=attn_dropout, - ff_dropout=ff_dropout, - conv_dropout=conv_dropout, - conv_causal=conv_causal, - ) - - def forward( - self, - hidden_states, - attention_mask, - encoder_hidden_states=None, - encoder_attention_mask=None, - timestep=None, - ): - return super().forward(x=hidden_states, mask=attention_mask.bool()) - - -class Decoder(nn.Module): - def __init__( - self, - in_channels, - out_channels, - channels=(256, 256), - dropout=0.05, - attention_head_dim=64, - n_blocks=1, - num_mid_blocks=2, - num_heads=4, - act_fn="snake", - down_block_type="transformer", - mid_block_type="transformer", - up_block_type="transformer", - ): - super().__init__() - channels = tuple(channels) - self.in_channels = in_channels - self.out_channels = out_channels - - self.time_embeddings = SinusoidalPosEmb(in_channels) - time_embed_dim = channels[0] * 4 - self.time_mlp = TimestepEmbedding( - in_channels=in_channels, - time_embed_dim=time_embed_dim, - act_fn="silu", - ) - - self.down_blocks = nn.ModuleList([]) - self.mid_blocks = nn.ModuleList([]) - self.up_blocks = nn.ModuleList([]) - - output_channel = in_channels - for i in range(len(channels)): # pylint: disable=consider-using-enumerate - input_channel = output_channel - output_channel = channels[i] - is_last = i == len(channels) - 1 - resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) - transformer_blocks = nn.ModuleList( - [ - self.get_block( - down_block_type, - output_channel, - attention_head_dim, - num_heads, - dropout, - act_fn, - ) - for _ in range(n_blocks) - ] - ) - downsample = ( - Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1) - ) - - self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample])) - - for i in range(num_mid_blocks): - input_channel = channels[-1] - out_channels = channels[-1] - - resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) - - transformer_blocks = nn.ModuleList( - [ - self.get_block( - mid_block_type, - output_channel, - attention_head_dim, - num_heads, - dropout, - act_fn, - ) - for _ in range(n_blocks) - ] - ) - - self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks])) - - channels = channels[::-1] + (channels[0],) - for i in range(len(channels) - 1): - input_channel = channels[i] - output_channel = channels[i + 1] - is_last = i == len(channels) - 2 - - resnet = ResnetBlock1D( - dim=2 * input_channel, - dim_out=output_channel, - time_emb_dim=time_embed_dim, - ) - transformer_blocks = nn.ModuleList( - [ - self.get_block( - up_block_type, - output_channel, - attention_head_dim, - num_heads, - dropout, - act_fn, - ) - for _ in range(n_blocks) - ] - ) - upsample = ( - Upsample1D(output_channel, use_conv_transpose=True) - if not is_last - else nn.Conv1d(output_channel, output_channel, 3, padding=1) - ) - - self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample])) - - self.final_block = Block1D(channels[-1], channels[-1]) - self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1) - - self.initialize_weights() - # nn.init.normal_(self.final_proj.weight) - - @staticmethod - def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn): - if block_type == "conformer": - block = ConformerWrapper( - dim=dim, - dim_head=attention_head_dim, - heads=num_heads, - ff_mult=1, - conv_expansion_factor=2, - ff_dropout=dropout, - attn_dropout=dropout, - conv_dropout=dropout, - conv_kernel_size=31, - ) - elif block_type == "transformer": - block = BasicTransformerBlock( - dim=dim, - num_attention_heads=num_heads, - attention_head_dim=attention_head_dim, - dropout=dropout, - activation_fn=act_fn, - ) - else: - raise ValueError(f"Unknown block type {block_type}") - - return block - - def initialize_weights(self): - for m in self.modules(): - if isinstance(m, nn.Conv1d): - nn.init.kaiming_normal_(m.weight, nonlinearity="relu") - - if m.bias is not None: - nn.init.constant_(m.bias, 0) - - elif isinstance(m, nn.GroupNorm): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - - elif isinstance(m, nn.Linear): - nn.init.kaiming_normal_(m.weight, nonlinearity="relu") - - if m.bias is not None: - nn.init.constant_(m.bias, 0) - - def forward(self, x, mask, mu, t, spks=None, cond=None): - """Forward pass of the UNet1DConditional model. - - Args: - x (torch.Tensor): shape (batch_size, in_channels, time) - mask (_type_): shape (batch_size, 1, time) - t (_type_): shape (batch_size) - spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None. - cond (_type_, optional): placeholder for future use. Defaults to None. - - Raises: - ValueError: _description_ - ValueError: _description_ - - Returns: - _type_: _description_ - """ - - t = self.time_embeddings(t) - t = self.time_mlp(t) - - x = pack([x, mu], "b * t")[0] - - if spks is not None: - spks = repeat(spks, "b c -> b c t", t=x.shape[-1]) - x = pack([x, spks], "b * t")[0] - - hiddens = [] - masks = [mask] - for resnet, transformer_blocks, downsample in self.down_blocks: - mask_down = masks[-1] - x = resnet(x, mask_down, t) - x = rearrange(x, "b c t -> b t c") - mask_down = rearrange(mask_down, "b 1 t -> b t") - for transformer_block in transformer_blocks: - x = transformer_block( - hidden_states=x, - attention_mask=mask_down, - timestep=t, - ) - x = rearrange(x, "b t c -> b c t") - mask_down = rearrange(mask_down, "b t -> b 1 t") - hiddens.append(x) # Save hidden states for skip connections - x = downsample(x * mask_down) - masks.append(mask_down[:, :, ::2]) - - masks = masks[:-1] - mask_mid = masks[-1] - - for resnet, transformer_blocks in self.mid_blocks: - x = resnet(x, mask_mid, t) - x = rearrange(x, "b c t -> b t c") - mask_mid = rearrange(mask_mid, "b 1 t -> b t") - for transformer_block in transformer_blocks: - x = transformer_block( - hidden_states=x, - attention_mask=mask_mid, - timestep=t, - ) - x = rearrange(x, "b t c -> b c t") - mask_mid = rearrange(mask_mid, "b t -> b 1 t") - - for resnet, transformer_blocks, upsample in self.up_blocks: - mask_up = masks.pop() - x = resnet(pack([x, hiddens.pop()], "b * t")[0], mask_up, t) - x = rearrange(x, "b c t -> b t c") - mask_up = rearrange(mask_up, "b 1 t -> b t") - for transformer_block in transformer_blocks: - x = transformer_block( - hidden_states=x, - attention_mask=mask_up, - timestep=t, - ) - x = rearrange(x, "b t c -> b c t") - mask_up = rearrange(mask_up, "b t -> b 1 t") - x = upsample(x * mask_up) - - x = self.final_block(x, mask_up) - output = self.final_proj(x * mask_up) - - return output * mask diff --git a/HF_Deploy/src/chatterbox/models/s3gen/matcha/flow_matching.py b/HF_Deploy/src/chatterbox/models/s3gen/matcha/flow_matching.py deleted file mode 100644 index add7b08c4661ae7b56a19898cab3e088414a1b40..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/s3gen/matcha/flow_matching.py +++ /dev/null @@ -1,129 +0,0 @@ -from abc import ABC - -import torch -import torch.nn.functional as F - -from .decoder import Decoder - - -class BASECFM(torch.nn.Module, ABC): - def __init__( - self, - n_feats, - cfm_params, - n_spks=1, - spk_emb_dim=128, - ): - super().__init__() - self.n_feats = n_feats - self.n_spks = n_spks - self.spk_emb_dim = spk_emb_dim - self.solver = cfm_params.solver - if hasattr(cfm_params, "sigma_min"): - self.sigma_min = cfm_params.sigma_min - else: - self.sigma_min = 1e-4 - - self.estimator = None - - @torch.inference_mode() - def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): - """Forward diffusion - - Args: - mu (torch.Tensor): output of encoder - shape: (batch_size, n_feats, mel_timesteps) - mask (torch.Tensor): output_mask - shape: (batch_size, 1, mel_timesteps) - n_timesteps (int): number of diffusion steps - temperature (float, optional): temperature for scaling noise. Defaults to 1.0. - spks (torch.Tensor, optional): speaker ids. Defaults to None. - shape: (batch_size, spk_emb_dim) - cond: Not used but kept for future purposes - - Returns: - sample: generated mel-spectrogram - shape: (batch_size, n_feats, mel_timesteps) - """ - z = torch.randn_like(mu) * temperature - t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) - return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond) - - def solve_euler(self, x, t_span, mu, mask, spks, cond): - """ - Fixed euler solver for ODEs. - Args: - x (torch.Tensor): random noise - t_span (torch.Tensor): n_timesteps interpolated - shape: (n_timesteps + 1,) - mu (torch.Tensor): output of encoder - shape: (batch_size, n_feats, mel_timesteps) - mask (torch.Tensor): output_mask - shape: (batch_size, 1, mel_timesteps) - spks (torch.Tensor, optional): speaker ids. Defaults to None. - shape: (batch_size, spk_emb_dim) - cond: Not used but kept for future purposes - """ - t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] - - # I am storing this because I can later plot it by putting a debugger here and saving it to a file - # Or in future might add like a return_all_steps flag - sol = [] - - for step in range(1, len(t_span)): - dphi_dt = self.estimator(x, mask, mu, t, spks, cond) - - x = x + dt * dphi_dt - t = t + dt - sol.append(x) - if step < len(t_span) - 1: - dt = t_span[step + 1] - t - - return sol[-1] - - def compute_loss(self, x1, mask, mu, spks=None, cond=None): - """Computes diffusion loss - - Args: - x1 (torch.Tensor): Target - shape: (batch_size, n_feats, mel_timesteps) - mask (torch.Tensor): target mask - shape: (batch_size, 1, mel_timesteps) - mu (torch.Tensor): output of encoder - shape: (batch_size, n_feats, mel_timesteps) - spks (torch.Tensor, optional): speaker embedding. Defaults to None. - shape: (batch_size, spk_emb_dim) - - Returns: - loss: conditional flow matching loss - y: conditional flow - shape: (batch_size, n_feats, mel_timesteps) - """ - b, _, t = mu.shape - - # random timestep - t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) - # sample noise p(x_0) - z = torch.randn_like(x1) - - y = (1 - (1 - self.sigma_min) * t) * z + t * x1 - u = x1 - (1 - self.sigma_min) * z - - loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / ( - torch.sum(mask) * u.shape[1] - ) - return loss, y - - -class CFM(BASECFM): - def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64): - super().__init__( - n_feats=in_channels, - cfm_params=cfm_params, - n_spks=n_spks, - spk_emb_dim=spk_emb_dim, - ) - - in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0) - # Just change the architecture of the estimator here - self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params) diff --git a/HF_Deploy/src/chatterbox/models/s3gen/matcha/text_encoder.py b/HF_Deploy/src/chatterbox/models/s3gen/matcha/text_encoder.py deleted file mode 100644 index 276eee7350b884cd37fd313f1e44db487a77f577..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/s3gen/matcha/text_encoder.py +++ /dev/null @@ -1,413 +0,0 @@ -""" from https://github.com/jaywalnut310/glow-tts """ - -import math - -import torch -import torch.nn as nn -from einops import rearrange - - -def sequence_mask(length, max_length=None): - if max_length is None: - max_length = length.max() - x = torch.arange(max_length, dtype=length.dtype, device=length.device) - return x.unsqueeze(0) < length.unsqueeze(1) - - - -class LayerNorm(nn.Module): - def __init__(self, channels, eps=1e-4): - super().__init__() - self.channels = channels - self.eps = eps - - self.gamma = torch.nn.Parameter(torch.ones(channels)) - self.beta = torch.nn.Parameter(torch.zeros(channels)) - - def forward(self, x): - n_dims = len(x.shape) - mean = torch.mean(x, 1, keepdim=True) - variance = torch.mean((x - mean) ** 2, 1, keepdim=True) - - x = (x - mean) * torch.rsqrt(variance + self.eps) - - shape = [1, -1] + [1] * (n_dims - 2) - x = x * self.gamma.view(*shape) + self.beta.view(*shape) - return x - - -class ConvReluNorm(nn.Module): - def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): - super().__init__() - self.in_channels = in_channels - self.hidden_channels = hidden_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.n_layers = n_layers - self.p_dropout = p_dropout - - self.conv_layers = torch.nn.ModuleList() - self.norm_layers = torch.nn.ModuleList() - self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)) - self.norm_layers.append(LayerNorm(hidden_channels)) - self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout)) - for _ in range(n_layers - 1): - self.conv_layers.append( - torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2) - ) - self.norm_layers.append(LayerNorm(hidden_channels)) - self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1) - self.proj.weight.data.zero_() - self.proj.bias.data.zero_() - - def forward(self, x, x_mask): - x_org = x - for i in range(self.n_layers): - x = self.conv_layers[i](x * x_mask) - x = self.norm_layers[i](x) - x = self.relu_drop(x) - x = x_org + self.proj(x) - return x * x_mask - - -class DurationPredictor(nn.Module): - def __init__(self, in_channels, filter_channels, kernel_size, p_dropout): - super().__init__() - self.in_channels = in_channels - self.filter_channels = filter_channels - self.p_dropout = p_dropout - - self.drop = torch.nn.Dropout(p_dropout) - self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) - self.norm_1 = LayerNorm(filter_channels) - self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2) - self.norm_2 = LayerNorm(filter_channels) - self.proj = torch.nn.Conv1d(filter_channels, 1, 1) - - def forward(self, x, x_mask): - x = self.conv_1(x * x_mask) - x = torch.relu(x) - x = self.norm_1(x) - x = self.drop(x) - x = self.conv_2(x * x_mask) - x = torch.relu(x) - x = self.norm_2(x) - x = self.drop(x) - x = self.proj(x * x_mask) - return x * x_mask - - -class RotaryPositionalEmbeddings(nn.Module): - """ - ## RoPE module - - Rotary encoding transforms pairs of features by rotating in the 2D plane. - That is, it organizes the $d$ features as $\frac{d}{2}$ pairs. - Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it - by an angle depending on the position of the token. - """ - - def __init__(self, d: int, base: int = 10_000): - r""" - * `d` is the number of features $d$ - * `base` is the constant used for calculating $\Theta$ - """ - super().__init__() - - self.base = base - self.d = int(d) - self.cos_cached = None - self.sin_cached = None - - def _build_cache(self, x: torch.Tensor): - r""" - Cache $\cos$ and $\sin$ values - """ - # Return if cache is already built - if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]: - return - - # Get sequence length - seq_len = x.shape[0] - - # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ - theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device) - - # Create position indexes `[0, 1, ..., seq_len - 1]` - seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device) - - # Calculate the product of position index and $\theta_i$ - idx_theta = torch.einsum("n,d->nd", seq_idx, theta) - - # Concatenate so that for row $m$ we have - # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$ - idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1) - - # Cache them - self.cos_cached = idx_theta2.cos()[:, None, None, :] - self.sin_cached = idx_theta2.sin()[:, None, None, :] - - def _neg_half(self, x: torch.Tensor): - # $\frac{d}{2}$ - d_2 = self.d // 2 - - # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ - return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1) - - def forward(self, x: torch.Tensor): - """ - * `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]` - """ - # Cache $\cos$ and $\sin$ values - x = rearrange(x, "b h t d -> t b h d") - - self._build_cache(x) - - # Split the features, we can choose to apply rotary embeddings only to a partial set of features. - x_rope, x_pass = x[..., : self.d], x[..., self.d :] - - # Calculate - # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ - neg_half_x = self._neg_half(x_rope) - - x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]]) - - return rearrange(torch.cat((x_rope, x_pass), dim=-1), "t b h d -> b h t d") - - -class MultiHeadAttention(nn.Module): - def __init__( - self, - channels, - out_channels, - n_heads, - heads_share=True, - p_dropout=0.0, - proximal_bias=False, - proximal_init=False, - ): - super().__init__() - assert channels % n_heads == 0 - - self.channels = channels - self.out_channels = out_channels - self.n_heads = n_heads - self.heads_share = heads_share - self.proximal_bias = proximal_bias - self.p_dropout = p_dropout - self.attn = None - - self.k_channels = channels // n_heads - self.conv_q = torch.nn.Conv1d(channels, channels, 1) - self.conv_k = torch.nn.Conv1d(channels, channels, 1) - self.conv_v = torch.nn.Conv1d(channels, channels, 1) - - # from https://nn.labml.ai/transformers/rope/index.html - self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5) - self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5) - - self.conv_o = torch.nn.Conv1d(channels, out_channels, 1) - self.drop = torch.nn.Dropout(p_dropout) - - torch.nn.init.xavier_uniform_(self.conv_q.weight) - torch.nn.init.xavier_uniform_(self.conv_k.weight) - if proximal_init: - self.conv_k.weight.data.copy_(self.conv_q.weight.data) - self.conv_k.bias.data.copy_(self.conv_q.bias.data) - torch.nn.init.xavier_uniform_(self.conv_v.weight) - - def forward(self, x, c, attn_mask=None): - q = self.conv_q(x) - k = self.conv_k(c) - v = self.conv_v(c) - - x, self.attn = self.attention(q, k, v, mask=attn_mask) - - x = self.conv_o(x) - return x - - def attention(self, query, key, value, mask=None): - b, d, t_s, t_t = (*key.size(), query.size(2)) - query = rearrange(query, "b (h c) t-> b h t c", h=self.n_heads) - key = rearrange(key, "b (h c) t-> b h t c", h=self.n_heads) - value = rearrange(value, "b (h c) t-> b h t c", h=self.n_heads) - - query = self.query_rotary_pe(query) - key = self.key_rotary_pe(key) - - scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels) - - if self.proximal_bias: - assert t_s == t_t, "Proximal bias is only available for self-attention." - scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) - if mask is not None: - scores = scores.masked_fill(mask == 0, -1e4) - p_attn = torch.nn.functional.softmax(scores, dim=-1) - p_attn = self.drop(p_attn) - output = torch.matmul(p_attn, value) - output = output.transpose(2, 3).contiguous().view(b, d, t_t) - return output, p_attn - - @staticmethod - def _attention_bias_proximal(length): - r = torch.arange(length, dtype=torch.float32) - diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) - return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) - - -class FFN(nn.Module): - def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.filter_channels = filter_channels - self.kernel_size = kernel_size - self.p_dropout = p_dropout - - self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) - self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2) - self.drop = torch.nn.Dropout(p_dropout) - - def forward(self, x, x_mask): - x = self.conv_1(x * x_mask) - x = torch.relu(x) - x = self.drop(x) - x = self.conv_2(x * x_mask) - return x * x_mask - - -class Encoder(nn.Module): - def __init__( - self, - hidden_channels, - filter_channels, - n_heads, - n_layers, - kernel_size=1, - p_dropout=0.0, - **kwargs, - ): - super().__init__() - self.hidden_channels = hidden_channels - self.filter_channels = filter_channels - self.n_heads = n_heads - self.n_layers = n_layers - self.kernel_size = kernel_size - self.p_dropout = p_dropout - - self.drop = torch.nn.Dropout(p_dropout) - self.attn_layers = torch.nn.ModuleList() - self.norm_layers_1 = torch.nn.ModuleList() - self.ffn_layers = torch.nn.ModuleList() - self.norm_layers_2 = torch.nn.ModuleList() - for _ in range(self.n_layers): - self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)) - self.norm_layers_1.append(LayerNorm(hidden_channels)) - self.ffn_layers.append( - FFN( - hidden_channels, - hidden_channels, - filter_channels, - kernel_size, - p_dropout=p_dropout, - ) - ) - self.norm_layers_2.append(LayerNorm(hidden_channels)) - - def forward(self, x, x_mask): - attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) - for i in range(self.n_layers): - x = x * x_mask - y = self.attn_layers[i](x, x, attn_mask) - y = self.drop(y) - x = self.norm_layers_1[i](x + y) - y = self.ffn_layers[i](x, x_mask) - y = self.drop(y) - x = self.norm_layers_2[i](x + y) - x = x * x_mask - return x - - -class TextEncoder(nn.Module): - def __init__( - self, - encoder_type, - encoder_params, - duration_predictor_params, - n_vocab, - n_spks=1, - spk_emb_dim=128, - ): - super().__init__() - self.encoder_type = encoder_type - self.n_vocab = n_vocab - self.n_feats = encoder_params.n_feats - self.n_channels = encoder_params.n_channels - self.spk_emb_dim = spk_emb_dim - self.n_spks = n_spks - - self.emb = torch.nn.Embedding(n_vocab, self.n_channels) - torch.nn.init.normal_(self.emb.weight, 0.0, self.n_channels**-0.5) - - if encoder_params.prenet: - self.prenet = ConvReluNorm( - self.n_channels, - self.n_channels, - self.n_channels, - kernel_size=5, - n_layers=3, - p_dropout=0.5, - ) - else: - self.prenet = lambda x, x_mask: x - - self.encoder = Encoder( - encoder_params.n_channels + (spk_emb_dim if n_spks > 1 else 0), - encoder_params.filter_channels, - encoder_params.n_heads, - encoder_params.n_layers, - encoder_params.kernel_size, - encoder_params.p_dropout, - ) - - self.proj_m = torch.nn.Conv1d(self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1) - self.proj_w = DurationPredictor( - self.n_channels + (spk_emb_dim if n_spks > 1 else 0), - duration_predictor_params.filter_channels_dp, - duration_predictor_params.kernel_size, - duration_predictor_params.p_dropout, - ) - - def forward(self, x, x_lengths, spks=None): - """Run forward pass to the transformer based encoder and duration predictor - - Args: - x (torch.Tensor): text input - shape: (batch_size, max_text_length) - x_lengths (torch.Tensor): text input lengths - shape: (batch_size,) - spks (torch.Tensor, optional): speaker ids. Defaults to None. - shape: (batch_size,) - - Returns: - mu (torch.Tensor): average output of the encoder - shape: (batch_size, n_feats, max_text_length) - logw (torch.Tensor): log duration predicted by the duration predictor - shape: (batch_size, 1, max_text_length) - x_mask (torch.Tensor): mask for the text input - shape: (batch_size, 1, max_text_length) - """ - x = self.emb(x) * math.sqrt(self.n_channels) - x = torch.transpose(x, 1, -1) - x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) - - x = self.prenet(x, x_mask) - if self.n_spks > 1: - x = torch.cat([x, spks.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1) - x = self.encoder(x, x_mask) - mu = self.proj_m(x) * x_mask - - x_dp = torch.detach(x) - logw = self.proj_w(x_dp, x_mask) - - return mu, logw, x_mask diff --git a/HF_Deploy/src/chatterbox/models/s3gen/matcha/transformer.py b/HF_Deploy/src/chatterbox/models/s3gen/matcha/transformer.py deleted file mode 100644 index dd1afa3aff5383912209e508676c6885e13ef4ee..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/s3gen/matcha/transformer.py +++ /dev/null @@ -1,316 +0,0 @@ -from typing import Any, Dict, Optional - -import torch -import torch.nn as nn -from diffusers.models.attention import ( - GEGLU, - GELU, - AdaLayerNorm, - AdaLayerNormZero, - ApproximateGELU, -) -from diffusers.models.attention_processor import Attention -from diffusers.models.lora import LoRACompatibleLinear -from diffusers.utils.torch_utils import maybe_allow_in_graph - - -class SnakeBeta(nn.Module): - """ - A modified Snake function which uses separate parameters for the magnitude of the periodic components - Shape: - - Input: (B, C, T) - - Output: (B, C, T), same shape as the input - Parameters: - - alpha - trainable parameter that controls frequency - - beta - trainable parameter that controls magnitude - References: - - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: - https://arxiv.org/abs/2006.08195 - Examples: - >>> a1 = snakebeta(256) - >>> x = torch.randn(256) - >>> x = a1(x) - """ - - def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True): - """ - Initialization. - INPUT: - - in_features: shape of the input - - alpha - trainable parameter that controls frequency - - beta - trainable parameter that controls magnitude - alpha is initialized to 1 by default, higher values = higher-frequency. - beta is initialized to 1 by default, higher values = higher-magnitude. - alpha will be trained along with the rest of your model. - """ - super().__init__() - self.in_features = out_features if isinstance(out_features, list) else [out_features] - self.proj = LoRACompatibleLinear(in_features, out_features) - - # initialize alpha - self.alpha_logscale = alpha_logscale - if self.alpha_logscale: # log scale alphas initialized to zeros - self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha) - self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha) - else: # linear scale alphas initialized to ones - self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha) - self.beta = nn.Parameter(torch.ones(self.in_features) * alpha) - - self.alpha.requires_grad = alpha_trainable - self.beta.requires_grad = alpha_trainable - - self.no_div_by_zero = 0.000000001 - - def forward(self, x): - """ - Forward pass of the function. - Applies the function to the input elementwise. - SnakeBeta ∢= x + 1/b * sin^2 (xa) - """ - x = self.proj(x) - if self.alpha_logscale: - alpha = torch.exp(self.alpha) - beta = torch.exp(self.beta) - else: - alpha = self.alpha - beta = self.beta - - x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2) - - return x - - -class FeedForward(nn.Module): - r""" - A feed-forward layer. - - Parameters: - dim (`int`): The number of channels in the input. - dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. - mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. - """ - - def __init__( - self, - dim: int, - dim_out: Optional[int] = None, - mult: int = 4, - dropout: float = 0.0, - activation_fn: str = "geglu", - final_dropout: bool = False, - ): - super().__init__() - inner_dim = int(dim * mult) - dim_out = dim_out if dim_out is not None else dim - - if activation_fn == "gelu": - act_fn = GELU(dim, inner_dim) - if activation_fn == "gelu-approximate": - act_fn = GELU(dim, inner_dim, approximate="tanh") - elif activation_fn == "geglu": - act_fn = GEGLU(dim, inner_dim) - elif activation_fn == "geglu-approximate": - act_fn = ApproximateGELU(dim, inner_dim) - elif activation_fn == "snakebeta": - act_fn = SnakeBeta(dim, inner_dim) - - self.net = nn.ModuleList([]) - # project in - self.net.append(act_fn) - # project dropout - self.net.append(nn.Dropout(dropout)) - # project out - self.net.append(LoRACompatibleLinear(inner_dim, dim_out)) - # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout - if final_dropout: - self.net.append(nn.Dropout(dropout)) - - def forward(self, hidden_states): - for module in self.net: - hidden_states = module(hidden_states) - return hidden_states - - -@maybe_allow_in_graph -class BasicTransformerBlock(nn.Module): - r""" - A basic Transformer block. - - Parameters: - dim (`int`): The number of channels in the input and output. - num_attention_heads (`int`): The number of heads to use for multi-head attention. - attention_head_dim (`int`): The number of channels in each head. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. - only_cross_attention (`bool`, *optional*): - Whether to use only cross-attention layers. In this case two cross attention layers are used. - double_self_attention (`bool`, *optional*): - Whether to use two self-attention layers. In this case no cross attention layers are used. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm (: - obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. - attention_bias (: - obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. - """ - - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - dropout=0.0, - cross_attention_dim: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - attention_bias: bool = False, - only_cross_attention: bool = False, - double_self_attention: bool = False, - upcast_attention: bool = False, - norm_elementwise_affine: bool = True, - norm_type: str = "layer_norm", - final_dropout: bool = False, - ): - super().__init__() - self.only_cross_attention = only_cross_attention - - self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" - self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" - - if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: - raise ValueError( - f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" - f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." - ) - - # Define 3 blocks. Each block has its own normalization layer. - # 1. Self-Attn - if self.use_ada_layer_norm: - self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) - elif self.use_ada_layer_norm_zero: - self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) - else: - self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - self.attn1 = Attention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - ) - - # 2. Cross-Attn - if cross_attention_dim is not None or double_self_attention: - # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. - # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during - # the second cross attention block. - self.norm2 = ( - AdaLayerNorm(dim, num_embeds_ada_norm) - if self.use_ada_layer_norm - else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - ) - self.attn2 = Attention( - query_dim=dim, - cross_attention_dim=cross_attention_dim if not double_self_attention else None, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - # scale_qk=False, # uncomment this to not to use flash attention - ) # is self-attn if encoder_hidden_states is none - else: - self.norm2 = None - self.attn2 = None - - # 3. Feed-forward - self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) - self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) - - # let chunk size default to None - self._chunk_size = None - self._chunk_dim = 0 - - def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int): - # Sets chunk feed-forward - self._chunk_size = chunk_size - self._chunk_dim = dim - - def forward( - self, - hidden_states: torch.FloatTensor, - attention_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - timestep: Optional[torch.LongTensor] = None, - cross_attention_kwargs: Dict[str, Any] = None, - class_labels: Optional[torch.LongTensor] = None, - ): - # Notice that normalization is always applied before the real computation in the following blocks. - # 1. Self-Attention - if self.use_ada_layer_norm: - norm_hidden_states = self.norm1(hidden_states, timestep) - elif self.use_ada_layer_norm_zero: - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( - hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype - ) - else: - norm_hidden_states = self.norm1(hidden_states) - - cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} - - attn_output = self.attn1( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, - attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask, - **cross_attention_kwargs, - ) - if self.use_ada_layer_norm_zero: - attn_output = gate_msa.unsqueeze(1) * attn_output - hidden_states = attn_output + hidden_states - - # 2. Cross-Attention - if self.attn2 is not None: - norm_hidden_states = ( - self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) - ) - - attn_output = self.attn2( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - **cross_attention_kwargs, - ) - hidden_states = attn_output + hidden_states - - # 3. Feed-forward - norm_hidden_states = self.norm3(hidden_states) - - if self.use_ada_layer_norm_zero: - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] - - if self._chunk_size is not None: - # "feed_forward_chunk_size" can be used to save memory - if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: - raise ValueError( - f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." - ) - - num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size - ff_output = torch.cat( - [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)], - dim=self._chunk_dim, - ) - else: - ff_output = self.ff(norm_hidden_states) - - if self.use_ada_layer_norm_zero: - ff_output = gate_mlp.unsqueeze(1) * ff_output - - hidden_states = ff_output + hidden_states - - return hidden_states diff --git a/HF_Deploy/src/chatterbox/models/s3gen/s3gen.py b/HF_Deploy/src/chatterbox/models/s3gen/s3gen.py deleted file mode 100644 index 97b7c0bd40ad6cd258ca3c4bd4ae752c78f28b19..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/s3gen/s3gen.py +++ /dev/null @@ -1,305 +0,0 @@ -# Modified from CosyVoice https://github.com/FunAudioLLM/CosyVoice -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -import numpy as np -import torch -import torchaudio as ta -from functools import lru_cache -from typing import Optional -from omegaconf import DictConfig - -from ..s3tokenizer import S3_SR, SPEECH_VOCAB_SIZE, S3Tokenizer -from .const import S3GEN_SR -from .flow import CausalMaskedDiffWithXvec -from .xvector import CAMPPlus -from .utils.mel import mel_spectrogram -from .f0_predictor import ConvRNNF0Predictor -from .hifigan import HiFTGenerator -from .transformer.upsample_encoder import UpsampleConformerEncoder -from .flow_matching import CausalConditionalCFM -from .decoder import ConditionalDecoder - - -def drop_invalid_tokens(x): - assert len(x.shape) <= 2 and x.shape[0] == 1, "only batch size of one allowed for now" - return x[x < SPEECH_VOCAB_SIZE] - - -# TODO: global resampler cache -@lru_cache(100) -def get_resampler(src_sr, dst_sr, device): - return ta.transforms.Resample(src_sr, dst_sr).to(device) - - -class S3Token2Mel(torch.nn.Module): - """ - CosyVoice2's CFM decoder maps S3 speech tokens to mel-spectrograms. - - TODO: make these modules configurable? - """ - def __init__(self): - super().__init__() - self.tokenizer = S3Tokenizer("speech_tokenizer_v2_25hz") - self.mel_extractor = mel_spectrogram # TODO: make it a torch module? - self.speaker_encoder = CAMPPlus() # use default args - - encoder = UpsampleConformerEncoder( - output_size=512, - attention_heads=8, - linear_units=2048, - num_blocks=6, - dropout_rate=0.1, - positional_dropout_rate=0.1, - attention_dropout_rate=0.1, - normalize_before=True, - input_layer='linear', - pos_enc_layer_type='rel_pos_espnet', - selfattention_layer_type='rel_selfattn', - input_size=512, - use_cnn_module=False, - macaron_style=False, - ) - - estimator = ConditionalDecoder( - in_channels=320, - out_channels=80, - causal=True, - channels=[256], - dropout=0.0, - attention_head_dim=64, - n_blocks=4, - num_mid_blocks=12, - num_heads=8, - act_fn='gelu', - ) - cfm_params = DictConfig({ - "sigma_min": 1e-06, - "solver": 'euler', - "t_scheduler": 'cosine', - "training_cfg_rate": 0.2, - "inference_cfg_rate": 0.7, - "reg_loss_type": 'l1', - }) - decoder = CausalConditionalCFM( - spk_emb_dim=80, - cfm_params=cfm_params, - estimator=estimator, - ) - - self.flow = CausalMaskedDiffWithXvec( - encoder=encoder, - decoder=decoder - ) - - self.resamplers = {} - - @property - def device(self): - params = self.tokenizer.parameters() - return next(params).device - - def embed_ref( - self, - ref_wav: torch.Tensor, - ref_sr: int, - device="auto", - ref_fade_out=True, - ): - device = self.device if device == "auto" else device - if isinstance(ref_wav, np.ndarray): - ref_wav = torch.from_numpy(ref_wav).float() - - if ref_wav.device != device: - ref_wav = ref_wav.to(device) - - if len(ref_wav.shape) == 1: - ref_wav = ref_wav.unsqueeze(0) # (B, L) - - if ref_wav.size(1) > 10 * ref_sr: - print("WARNING: cosydec received ref longer than 10s") - - ref_wav_24 = ref_wav - if ref_sr != S3GEN_SR: - ref_wav_24 = get_resampler(ref_sr, S3GEN_SR, device)(ref_wav) - - ref_mels_24 = self.mel_extractor(ref_wav_24).transpose(1, 2).to(device) - ref_mels_24_len = None - - # Resample to 16kHz - ref_wav_16 = get_resampler(ref_sr, S3_SR, device)(ref_wav).to(device) - - # Speaker embedding - ref_x_vector = self.speaker_encoder.inference(ref_wav_16) - - # Tokenize 16khz reference - ref_speech_tokens, ref_speech_token_lens = self.tokenizer(ref_wav_16) - - # Make sure mel_len = 2 * stoken_len (happens when the input is not padded to multiple of 40ms) - if ref_mels_24.shape[1] != 2 * ref_speech_tokens.shape[1]: - logging.warning( - "Reference mel length is not equal to 2 * reference token length.\n" - ) - ref_speech_tokens = ref_speech_tokens[:, :ref_mels_24.shape[1] // 2] - ref_speech_token_lens[0] = ref_speech_tokens.shape[1] - - return dict( - prompt_token=ref_speech_tokens.to(device), - prompt_token_len=ref_speech_token_lens, - prompt_feat=ref_mels_24, - prompt_feat_len=ref_mels_24_len, - embedding=ref_x_vector, - ) - - def forward( - self, - speech_tokens: torch.LongTensor, - # locally-computed ref embedding (mutex with ref_dict) - ref_wav: Optional[torch.Tensor], - ref_sr: Optional[int], - # pre-computed ref embedding (prod API) - ref_dict: Optional[dict] = None, - finalize: bool = False, - ): - """ - Generate waveforms from S3 speech tokens and a reference waveform, which the speaker timbre is inferred from. - - NOTE: - - The speaker encoder accepts 16 kHz waveform. - - S3TokenizerV2 accepts 16 kHz waveform. - - The mel-spectrogram for the reference assumes 24 kHz input signal. - - This function is designed for batch_size=1 only. - - Args - ---- - - `speech_tokens`: S3 speech tokens [B=1, T] - - `ref_wav`: reference waveform (`torch.Tensor` with shape=[B=1, T]) - - `ref_sr`: reference sample rate - - `finalize`: whether streaming is finished or not. Note that if False, the last 3 tokens will be ignored. - """ - assert (ref_wav is None) ^ (ref_dict is None), f"Must provide exactly one of ref_wav or ref_dict (got {ref_wav} and {ref_dict})" - - if ref_dict is None: - ref_dict = self.embed_ref(ref_wav, ref_sr) - else: - # type/device casting (all values will be numpy if it's from a prod API call) - for rk in list(ref_dict): - if isinstance(ref_dict[rk], np.ndarray): - ref_dict[rk] = torch.from_numpy(ref_dict[rk]) - if torch.is_tensor(ref_dict[rk]): - ref_dict[rk] = ref_dict[rk].to(self.device) - - if len(speech_tokens.shape) == 1: - speech_tokens = speech_tokens.unsqueeze(0) - - # assert speech_tokens.shape[0] == 1, "only batch size of one allowed for now" - speech_token_lens = torch.LongTensor([speech_tokens.size(1)]).to(self.device) - - output_mels, _ = self.flow.inference( - token=speech_tokens, - token_len=speech_token_lens, - finalize=finalize, - **ref_dict, - ) - return output_mels - - -class S3Token2Wav(S3Token2Mel): - """ - The decoder of CosyVoice2 is a concat of token-to-mel (CFM) and a mel-to-waveform (HiFiGAN) modules. - - TODO: make these modules configurable? - """ - - def __init__(self): - super().__init__() - - f0_predictor = ConvRNNF0Predictor() - self.mel2wav = HiFTGenerator( - sampling_rate=S3GEN_SR, - upsample_rates=[8, 5, 3], - upsample_kernel_sizes=[16, 11, 7], - source_resblock_kernel_sizes=[7, 7, 11], - source_resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]], - f0_predictor=f0_predictor, - ) - - # silence out a few ms and fade audio in to reduce artifacts - n_trim = S3GEN_SR // 50 # 20ms = half of a frame - trim_fade = torch.zeros(2 * n_trim) - trim_fade[n_trim:] = (torch.cos(torch.linspace(torch.pi, 0, n_trim)) + 1) / 2 - self.register_buffer("trim_fade", trim_fade, persistent=False) # (buffers get automatic device casting) - - def forward( - self, - speech_tokens, - # locally-computed ref embedding (mutex with ref_dict) - ref_wav: Optional[torch.Tensor], - ref_sr: Optional[int], - # pre-computed ref embedding (prod API) - ref_dict: Optional[dict] = None, - finalize: bool = False - ): - output_mels = super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize) - - # TODO jrm: ignoring the speed control (mel interpolation) and the HiFTGAN caching mechanisms for now. - hift_cache_source = torch.zeros(1, 1, 0).to(self.device) - - output_wavs, *_ = self.mel2wav.inference(speech_feat=output_mels, cache_source=hift_cache_source) - - if not self.training: - # NOTE: ad-hoc method to reduce "spillover" from the reference clip. - output_wavs[:, :len(self.trim_fade)] *= self.trim_fade - - return output_wavs - - @torch.inference_mode() - def flow_inference( - self, - speech_tokens, - # locally-computed ref embedding (mutex with ref_dict) - ref_wav: Optional[torch.Tensor] = None, - ref_sr: Optional[int] = None, - # pre-computed ref embedding (prod API) - ref_dict: Optional[dict] = None, - finalize: bool = False, - ): - return super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize) - - @torch.inference_mode() - def hift_inference(self, speech_feat, cache_source: torch.Tensor = None): - if cache_source is None: - cache_source = torch.zeros(1, 1, 0).to(self.device) - return self.mel2wav.inference(speech_feat=speech_feat, cache_source=cache_source) - - @torch.inference_mode() - def inference( - self, - speech_tokens, - # locally-computed ref embedding (mutex with ref_dict) - ref_wav: Optional[torch.Tensor] = None, - ref_sr: Optional[int] = None, - # pre-computed ref embedding (prod API) - ref_dict: Optional[dict] = None, - cache_source: torch.Tensor = None, # NOTE: this arg is for streaming, it can probably be removed here - finalize: bool = True, - ): - output_mels = self.flow_inference(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize) - output_wavs, output_sources = self.hift_inference(output_mels, cache_source) - - # NOTE: ad-hoc method to reduce "spillover" from the reference clip. - output_wavs[:, :len(self.trim_fade)] *= self.trim_fade - - return output_wavs, output_sources diff --git a/HF_Deploy/src/chatterbox/models/s3gen/transformer/__init__.py b/HF_Deploy/src/chatterbox/models/s3gen/transformer/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/HF_Deploy/src/chatterbox/models/s3gen/transformer/activation.py b/HF_Deploy/src/chatterbox/models/s3gen/transformer/activation.py deleted file mode 100644 index 8cea54816385d3b6585ccc2417bc71630d578177..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/s3gen/transformer/activation.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe) -# 2020 Northwestern Polytechnical University (Pengcheng Guo) -# 2020 Mobvoi Inc (Binbin Zhang) -# 2024 Alibaba Inc (Xiang Lyu) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Swish() activation function for Conformer.""" - -import torch -from torch import nn, sin, pow -from torch.nn import Parameter - - -class Swish(torch.nn.Module): - """Construct an Swish object.""" - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Return Swish activation function.""" - return x * torch.sigmoid(x) - - -# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. -# LICENSE is in incl_licenses directory. -class Snake(nn.Module): - ''' - Implementation of a sine-based periodic activation function - Shape: - - Input: (B, C, T) - - Output: (B, C, T), same shape as the input - Parameters: - - alpha - trainable parameter - References: - - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: - https://arxiv.org/abs/2006.08195 - Examples: - >>> a1 = snake(256) - >>> x = torch.randn(256) - >>> x = a1(x) - ''' - def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): - ''' - Initialization. - INPUT: - - in_features: shape of the input - - alpha: trainable parameter - alpha is initialized to 1 by default, higher values = higher-frequency. - alpha will be trained along with the rest of your model. - ''' - super(Snake, self).__init__() - self.in_features = in_features - - # initialize alpha - self.alpha_logscale = alpha_logscale - if self.alpha_logscale: # log scale alphas initialized to zeros - self.alpha = Parameter(torch.zeros(in_features) * alpha) - else: # linear scale alphas initialized to ones - self.alpha = Parameter(torch.ones(in_features) * alpha) - - self.alpha.requires_grad = alpha_trainable - - self.no_div_by_zero = 0.000000001 - - def forward(self, x): - ''' - Forward pass of the function. - Applies the function to the input elementwise. - Snake ∢= x + 1/a * sin^2 (xa) - ''' - alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] - if self.alpha_logscale: - alpha = torch.exp(alpha) - x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) - - return x diff --git a/HF_Deploy/src/chatterbox/models/s3gen/transformer/attention.py b/HF_Deploy/src/chatterbox/models/s3gen/transformer/attention.py deleted file mode 100644 index 95e1d84035e84b27cfa88680c3d42fc84c0b7aed..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/s3gen/transformer/attention.py +++ /dev/null @@ -1,330 +0,0 @@ -# Copyright (c) 2019 Shigeki Karita -# 2020 Mobvoi Inc (Binbin Zhang) -# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn) -# 2024 Alibaba Inc (Xiang Lyu) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Multi-Head Attention layer definition.""" - -import math -from typing import Tuple - -import torch -from torch import nn - - -class MultiHeadedAttention(nn.Module): - """Multi-Head Attention layer. - - Args: - n_head (int): The number of heads. - n_feat (int): The number of features. - dropout_rate (float): Dropout rate. - - """ - - def __init__(self, - n_head: int, - n_feat: int, - dropout_rate: float, - key_bias: bool = True): - """Construct an MultiHeadedAttention object.""" - super().__init__() - assert n_feat % n_head == 0 - # We assume d_v always equals d_k - self.d_k = n_feat // n_head - self.h = n_head - self.linear_q = nn.Linear(n_feat, n_feat) - self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias) - self.linear_v = nn.Linear(n_feat, n_feat) - self.linear_out = nn.Linear(n_feat, n_feat) - self.dropout = nn.Dropout(p=dropout_rate) - - def forward_qkv( - self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Transform query, key and value. - - Args: - query (torch.Tensor): Query tensor (#batch, time1, size). - key (torch.Tensor): Key tensor (#batch, time2, size). - value (torch.Tensor): Value tensor (#batch, time2, size). - - Returns: - torch.Tensor: Transformed query tensor, size - (#batch, n_head, time1, d_k). - torch.Tensor: Transformed key tensor, size - (#batch, n_head, time2, d_k). - torch.Tensor: Transformed value tensor, size - (#batch, n_head, time2, d_k). - - """ - n_batch = query.size(0) - q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) - k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) - v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) - q = q.transpose(1, 2) # (batch, head, time1, d_k) - k = k.transpose(1, 2) # (batch, head, time2, d_k) - v = v.transpose(1, 2) # (batch, head, time2, d_k) - - return q, k, v - - def forward_attention( - self, - value: torch.Tensor, - scores: torch.Tensor, - mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool) - ) -> torch.Tensor: - """Compute attention context vector. - - Args: - value (torch.Tensor): Transformed value, size - (#batch, n_head, time2, d_k). - scores (torch.Tensor): Attention score, size - (#batch, n_head, time1, time2). - mask (torch.Tensor): Mask, size (#batch, 1, time2) or - (#batch, time1, time2), (0, 0, 0) means fake mask. - - Returns: - torch.Tensor: Transformed value (#batch, time1, d_model) - weighted by the attention score (#batch, time1, time2). - - """ - n_batch = value.size(0) - # NOTE(xcsong): When will `if mask.size(2) > 0` be True? - # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the - # 1st chunk to ease the onnx export.] - # 2. pytorch training - if mask.size(2) > 0: # time2 > 0 - mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) - # For last chunk, time2 might be larger than scores.size(-1) - mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2) - scores = scores.masked_fill(mask, -float('inf')) - attn = torch.softmax(scores, dim=-1).masked_fill( - mask, 0.0) # (batch, head, time1, time2) - # NOTE(xcsong): When will `if mask.size(2) > 0` be False? - # 1. onnx(16/-1, -1/-1, 16/0) - # 2. jit (16/-1, -1/-1, 16/0, 16/4) - else: - attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) - - p_attn = self.dropout(attn) - x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) - x = (x.transpose(1, 2).contiguous().view(n_batch, -1, - self.h * self.d_k) - ) # (batch, time1, d_model) - - return self.linear_out(x) # (batch, time1, d_model) - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - pos_emb: torch.Tensor = torch.empty(0), - cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Compute scaled dot product attention. - - Args: - query (torch.Tensor): Query tensor (#batch, time1, size). - key (torch.Tensor): Key tensor (#batch, time2, size). - value (torch.Tensor): Value tensor (#batch, time2, size). - mask (torch.Tensor): Mask tensor (#batch, 1, time2) or - (#batch, time1, time2). - 1.When applying cross attention between decoder and encoder, - the batch padding mask for input is in (#batch, 1, T) shape. - 2.When applying self attention of encoder, - the mask is in (#batch, T, T) shape. - 3.When applying self attention of decoder, - the mask is in (#batch, L, L) shape. - 4.If the different position in decoder see different block - of the encoder, such as Mocha, the passed in mask could be - in (#batch, L, T) shape. But there is no such case in current - CosyVoice. - cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2), - where `cache_t == chunk_size * num_decoding_left_chunks` - and `head * d_k == size` - - - Returns: - torch.Tensor: Output tensor (#batch, time1, d_model). - torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2) - where `cache_t == chunk_size * num_decoding_left_chunks` - and `head * d_k == size` - - """ - q, k, v = self.forward_qkv(query, key, value) - - # NOTE(xcsong): - # when export onnx model, for 1st chunk, we feed - # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) - # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode). - # In all modes, `if cache.size(0) > 0` will alwayse be `True` - # and we will always do splitting and - # concatnation(this will simplify onnx export). Note that - # it's OK to concat & split zero-shaped tensors(see code below). - # when export jit model, for 1st chunk, we always feed - # cache(0, 0, 0, 0) since jit supports dynamic if-branch. - # >>> a = torch.ones((1, 2, 0, 4)) - # >>> b = torch.ones((1, 2, 3, 4)) - # >>> c = torch.cat((a, b), dim=2) - # >>> torch.equal(b, c) # True - # >>> d = torch.split(a, 2, dim=-1) - # >>> torch.equal(d[0], d[1]) # True - if cache.size(0) > 0: - key_cache, value_cache = torch.split(cache, - cache.size(-1) // 2, - dim=-1) - k = torch.cat([key_cache, k], dim=2) - v = torch.cat([value_cache, v], dim=2) - # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's - # non-trivial to calculate `next_cache_start` here. - new_cache = torch.cat((k, v), dim=-1) - - scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) - return self.forward_attention(v, scores, mask), new_cache - - -class RelPositionMultiHeadedAttention(MultiHeadedAttention): - """Multi-Head Attention layer with relative position encoding. - Paper: https://arxiv.org/abs/1901.02860 - Args: - n_head (int): The number of heads. - n_feat (int): The number of features. - dropout_rate (float): Dropout rate. - """ - - def __init__(self, - n_head: int, - n_feat: int, - dropout_rate: float, - key_bias: bool = True): - """Construct an RelPositionMultiHeadedAttention object.""" - super().__init__(n_head, n_feat, dropout_rate, key_bias) - # linear transformation for positional encoding - self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) - # these two learnable bias are used in matrix c and matrix d - # as described in https://arxiv.org/abs/1901.02860 Section 3.3 - self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) - self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) - torch.nn.init.xavier_uniform_(self.pos_bias_u) - torch.nn.init.xavier_uniform_(self.pos_bias_v) - - def rel_shift(self, x: torch.Tensor) -> torch.Tensor: - """Compute relative positional encoding. - - Args: - x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1). - time1 means the length of query vector. - - Returns: - torch.Tensor: Output tensor. - - """ - zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1), - device=x.device, - dtype=x.dtype) - x_padded = torch.cat([zero_pad, x], dim=-1) - - x_padded = x_padded.view(x.size()[0], - x.size()[1], - x.size(3) + 1, x.size(2)) - x = x_padded[:, :, 1:].view_as(x)[ - :, :, :, : x.size(-1) // 2 + 1 - ] # only keep the positions from 0 to time2 - return x - - def forward( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - pos_emb: torch.Tensor = torch.empty(0), - cache: torch.Tensor = torch.zeros((0, 0, 0, 0)) - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Compute 'Scaled Dot Product Attention' with rel. positional encoding. - Args: - query (torch.Tensor): Query tensor (#batch, time1, size). - key (torch.Tensor): Key tensor (#batch, time2, size). - value (torch.Tensor): Value tensor (#batch, time2, size). - mask (torch.Tensor): Mask tensor (#batch, 1, time2) or - (#batch, time1, time2), (0, 0, 0) means fake mask. - pos_emb (torch.Tensor): Positional embedding tensor - (#batch, time2, size). - cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2), - where `cache_t == chunk_size * num_decoding_left_chunks` - and `head * d_k == size` - Returns: - torch.Tensor: Output tensor (#batch, time1, d_model). - torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2) - where `cache_t == chunk_size * num_decoding_left_chunks` - and `head * d_k == size` - """ - q, k, v = self.forward_qkv(query, key, value) - q = q.transpose(1, 2) # (batch, time1, head, d_k) - - # NOTE(xcsong): - # when export onnx model, for 1st chunk, we feed - # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode) - # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode). - # In all modes, `if cache.size(0) > 0` will alwayse be `True` - # and we will always do splitting and - # concatnation(this will simplify onnx export). Note that - # it's OK to concat & split zero-shaped tensors(see code below). - # when export jit model, for 1st chunk, we always feed - # cache(0, 0, 0, 0) since jit supports dynamic if-branch. - # >>> a = torch.ones((1, 2, 0, 4)) - # >>> b = torch.ones((1, 2, 3, 4)) - # >>> c = torch.cat((a, b), dim=2) - # >>> torch.equal(b, c) # True - # >>> d = torch.split(a, 2, dim=-1) - # >>> torch.equal(d[0], d[1]) # True - if cache.size(0) > 0: - key_cache, value_cache = torch.split(cache, - cache.size(-1) // 2, - dim=-1) - k = torch.cat([key_cache, k], dim=2) - v = torch.cat([value_cache, v], dim=2) - # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's - # non-trivial to calculate `next_cache_start` here. - new_cache = torch.cat((k, v), dim=-1) - - n_batch_pos = pos_emb.size(0) - p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) - p = p.transpose(1, 2) # (batch, head, time1, d_k) - - # (batch, head, time1, d_k) - q_with_bias_u = (q + self.pos_bias_u.to(q.device)).transpose(1, 2) - # (batch, head, time1, d_k) - q_with_bias_v = (q + self.pos_bias_v.to(q.device)).transpose(1, 2) - - # compute attention score - # first compute matrix a and matrix c - # as described in https://arxiv.org/abs/1901.02860 Section 3.3 - # (batch, head, time1, time2) - matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) - - # compute matrix b and matrix d - # (batch, head, time1, time2) - matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) - # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used - if matrix_ac.shape != matrix_bd.shape: - matrix_bd = self.rel_shift(matrix_bd) - - scores = (matrix_ac + matrix_bd) / math.sqrt( - self.d_k) # (batch, head, time1, time2) - - return self.forward_attention(v, scores, mask), new_cache diff --git a/HF_Deploy/src/chatterbox/models/s3gen/transformer/convolution.py b/HF_Deploy/src/chatterbox/models/s3gen/transformer/convolution.py deleted file mode 100644 index 4d5d96149154776000991a681a666fbe55e562fe..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/s3gen/transformer/convolution.py +++ /dev/null @@ -1,145 +0,0 @@ -# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) -# 2024 Alibaba Inc (Xiang Lyu) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Modified from ESPnet(https://github.com/espnet/espnet) -"""ConvolutionModule definition.""" - -from typing import Tuple - -import torch -from torch import nn - - -class ConvolutionModule(nn.Module): - """ConvolutionModule in Conformer model.""" - - def __init__(self, - channels: int, - kernel_size: int = 15, - activation: nn.Module = nn.ReLU(), - norm: str = "batch_norm", - causal: bool = False, - bias: bool = True): - """Construct an ConvolutionModule object. - Args: - channels (int): The number of channels of conv layers. - kernel_size (int): Kernel size of conv layers. - causal (int): Whether use causal convolution or not - """ - super().__init__() - - self.pointwise_conv1 = nn.Conv1d( - channels, - 2 * channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - ) - # self.lorder is used to distinguish if it's a causal convolution, - # if self.lorder > 0: it's a causal convolution, the input will be - # padded with self.lorder frames on the left in forward. - # else: it's a symmetrical convolution - if causal: - padding = 0 - self.lorder = kernel_size - 1 - else: - # kernel_size should be an odd number for none causal convolution - assert (kernel_size - 1) % 2 == 0 - padding = (kernel_size - 1) // 2 - self.lorder = 0 - self.depthwise_conv = nn.Conv1d( - channels, - channels, - kernel_size, - stride=1, - padding=padding, - groups=channels, - bias=bias, - ) - - assert norm in ['batch_norm', 'layer_norm'] - if norm == "batch_norm": - self.use_layer_norm = False - self.norm = nn.BatchNorm1d(channels) - else: - self.use_layer_norm = True - self.norm = nn.LayerNorm(channels) - - self.pointwise_conv2 = nn.Conv1d( - channels, - channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - ) - self.activation = activation - - def forward( - self, - x: torch.Tensor, - mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - cache: torch.Tensor = torch.zeros((0, 0, 0)), - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Compute convolution module. - Args: - x (torch.Tensor): Input tensor (#batch, time, channels). - mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), - (0, 0, 0) means fake mask. - cache (torch.Tensor): left context cache, it is only - used in causal convolution (#batch, channels, cache_t), - (0, 0, 0) meas fake cache. - Returns: - torch.Tensor: Output tensor (#batch, time, channels). - """ - # exchange the temporal dimension and the feature dimension - x = x.transpose(1, 2) # (#batch, channels, time) - - # mask batch padding - if mask_pad.size(2) > 0: # time > 0 - x.masked_fill_(~mask_pad, 0.0) - - if self.lorder > 0: - if cache.size(2) == 0: # cache_t == 0 - x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0) - else: - assert cache.size(0) == x.size(0) # equal batch - assert cache.size(1) == x.size(1) # equal channel - x = torch.cat((cache, x), dim=2) - assert (x.size(2) > self.lorder) - new_cache = x[:, :, -self.lorder:] - else: - # It's better we just return None if no cache is required, - # However, for JIT export, here we just fake one tensor instead of - # None. - new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) - - # GLU mechanism - x = self.pointwise_conv1(x) # (batch, 2*channel, dim) - x = nn.functional.glu(x, dim=1) # (batch, channel, dim) - - # 1D Depthwise Conv - x = self.depthwise_conv(x) - if self.use_layer_norm: - x = x.transpose(1, 2) - x = self.activation(self.norm(x)) - if self.use_layer_norm: - x = x.transpose(1, 2) - x = self.pointwise_conv2(x) - # mask batch padding - if mask_pad.size(2) > 0: # time > 0 - x.masked_fill_(~mask_pad, 0.0) - - return x.transpose(1, 2), new_cache diff --git a/HF_Deploy/src/chatterbox/models/s3gen/transformer/embedding.py b/HF_Deploy/src/chatterbox/models/s3gen/transformer/embedding.py deleted file mode 100644 index eae8c8ecabb15b4174cc3aa73c070ae702bb5f82..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/s3gen/transformer/embedding.py +++ /dev/null @@ -1,294 +0,0 @@ -# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) -# 2024 Alibaba Inc (Xiang Lyu) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Modified from ESPnet(https://github.com/espnet/espnet) -"""Positonal Encoding Module.""" - -import math -from typing import Tuple, Union - -import torch -import torch.nn.functional as F -import numpy as np - - -class PositionalEncoding(torch.nn.Module): - """Positional encoding. - - :param int d_model: embedding dim - :param float dropout_rate: dropout rate - :param int max_len: maximum input length - - PE(pos, 2i) = sin(pos/(10000^(2i/dmodel))) - PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel))) - """ - - def __init__(self, - d_model: int, - dropout_rate: float, - max_len: int = 5000, - reverse: bool = False): - """Construct an PositionalEncoding object.""" - super().__init__() - self.d_model = d_model - self.xscale = math.sqrt(self.d_model) - self.dropout = torch.nn.Dropout(p=dropout_rate) - self.max_len = max_len - - self.pe = torch.zeros(self.max_len, self.d_model) - position = torch.arange(0, self.max_len, - dtype=torch.float32).unsqueeze(1) - div_term = torch.exp( - torch.arange(0, self.d_model, 2, dtype=torch.float32) * - -(math.log(10000.0) / self.d_model)) - self.pe[:, 0::2] = torch.sin(position * div_term) - self.pe[:, 1::2] = torch.cos(position * div_term) - self.pe = self.pe.unsqueeze(0) - - def forward(self, - x: torch.Tensor, - offset: Union[int, torch.Tensor] = 0) \ - -> Tuple[torch.Tensor, torch.Tensor]: - """Add positional encoding. - - Args: - x (torch.Tensor): Input. Its shape is (batch, time, ...) - offset (int, torch.tensor): position offset - - Returns: - torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) - torch.Tensor: for compatibility to RelPositionalEncoding - """ - - self.pe = self.pe.to(x.device) - pos_emb = self.position_encoding(offset, x.size(1), False) - x = x * self.xscale + pos_emb - return self.dropout(x), self.dropout(pos_emb) - - def position_encoding(self, - offset: Union[int, torch.Tensor], - size: int, - apply_dropout: bool = True) -> torch.Tensor: - """ For getting encoding in a streaming fashion - - Attention!!!!! - we apply dropout only once at the whole utterance level in a none - streaming way, but will call this function several times with - increasing input size in a streaming scenario, so the dropout will - be applied several times. - - Args: - offset (int or torch.tensor): start offset - size (int): required size of position encoding - - Returns: - torch.Tensor: Corresponding encoding - """ - # How to subscript a Union type: - # https://github.com/pytorch/pytorch/issues/69434 - if isinstance(offset, int): - assert offset + size <= self.max_len - pos_emb = self.pe[:, offset:offset + size] - elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar - assert offset + size <= self.max_len - pos_emb = self.pe[:, offset:offset + size] - else: # for batched streaming decoding on GPU - assert torch.max(offset) + size <= self.max_len - index = offset.unsqueeze(1) + \ - torch.arange(0, size).to(offset.device) # B X T - flag = index > 0 - # remove negative offset - index = index * flag - pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model - - if apply_dropout: - pos_emb = self.dropout(pos_emb) - return pos_emb - - -class RelPositionalEncoding(PositionalEncoding): - """Relative positional encoding module. - See : Appendix B in https://arxiv.org/abs/1901.02860 - Args: - d_model (int): Embedding dimension. - dropout_rate (float): Dropout rate. - max_len (int): Maximum input length. - """ - - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000): - """Initialize class.""" - super().__init__(d_model, dropout_rate, max_len, reverse=True) - - def forward(self, - x: torch.Tensor, - offset: Union[int, torch.Tensor] = 0) \ - -> Tuple[torch.Tensor, torch.Tensor]: - """Compute positional encoding. - Args: - x (torch.Tensor): Input tensor (batch, time, `*`). - Returns: - torch.Tensor: Encoded tensor (batch, time, `*`). - torch.Tensor: Positional embedding tensor (1, time, `*`). - """ - self.pe = self.pe.to(x.device) - x = x * self.xscale - pos_emb = self.position_encoding(offset, x.size(1), False) - return self.dropout(x), self.dropout(pos_emb) - - -class WhisperPositionalEncoding(PositionalEncoding): - """ Sinusoids position encoding used in openai-whisper.encoder - """ - - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500): - super().__init__(d_model, dropout_rate, max_len) - self.xscale = 1.0 - log_timescale_increment = np.log(10000) / (d_model // 2 - 1) - inv_timescales = torch.exp(-log_timescale_increment * - torch.arange(d_model // 2)) - scaled_time = torch.arange(max_len)[:, np.newaxis] * \ - inv_timescales[np.newaxis, :] - pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) - delattr(self, "pe") - self.register_buffer("pe", pe.unsqueeze(0)) - - -class LearnablePositionalEncoding(PositionalEncoding): - """ Learnable position encoding used in openai-whisper.decoder - """ - - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448): - super().__init__(d_model, dropout_rate, max_len) - # NOTE(xcsong): overwrite self.pe & self.xscale - self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model)) - self.xscale = 1.0 - - -class NoPositionalEncoding(torch.nn.Module): - """ No position encoding - """ - - def __init__(self, d_model: int, dropout_rate: float): - super().__init__() - self.d_model = d_model - self.dropout = torch.nn.Dropout(p=dropout_rate) - - def forward(self, - x: torch.Tensor, - offset: Union[int, torch.Tensor] = 0) \ - -> Tuple[torch.Tensor, torch.Tensor]: - """ Just return zero vector for interface compatibility - """ - pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device) - return self.dropout(x), pos_emb - - def position_encoding(self, offset: Union[int, torch.Tensor], - size: int) -> torch.Tensor: - return torch.zeros(1, size, self.d_model) - - -class EspnetRelPositionalEncoding(torch.nn.Module): - """Relative positional encoding module (new implementation). - - Details can be found in https://github.com/espnet/espnet/pull/2816. - - See : Appendix B in https://arxiv.org/abs/1901.02860 - - Args: - d_model (int): Embedding dimension. - dropout_rate (float): Dropout rate. - max_len (int): Maximum input length. - - """ - - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000): - """Construct an PositionalEncoding object.""" - super(EspnetRelPositionalEncoding, self).__init__() - self.d_model = d_model - self.xscale = math.sqrt(self.d_model) - self.dropout = torch.nn.Dropout(p=dropout_rate) - self.pe = None - self.extend_pe(torch.tensor(0.0).expand(1, max_len)) - - def extend_pe(self, x: torch.Tensor): - """Reset the positional encodings.""" - if self.pe is not None: - # self.pe contains both positive and negative parts - # the length of self.pe is 2 * input_len - 1 - if self.pe.size(1) >= x.size(1) * 2 - 1: - if self.pe.dtype != x.dtype or self.pe.device != x.device: - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - # Suppose `i` means to the position of query vecotr and `j` means the - # position of key vector. We use position relative positions when keys - # are to the left (i>j) and negative relative positions otherwise (i Tuple[torch.Tensor, torch.Tensor]: - """Add positional encoding. - - Args: - x (torch.Tensor): Input tensor (batch, time, `*`). - - Returns: - torch.Tensor: Encoded tensor (batch, time, `*`). - - """ - self.extend_pe(x) - x = x * self.xscale - pos_emb = self.position_encoding(size=x.size(1), offset=offset) - return self.dropout(x), self.dropout(pos_emb) - - def position_encoding(self, - offset: Union[int, torch.Tensor], - size: int) -> torch.Tensor: - """ For getting encoding in a streaming fashion - - Attention!!!!! - we apply dropout only once at the whole utterance level in a none - streaming way, but will call this function several times with - increasing input size in a streaming scenario, so the dropout will - be applied several times. - - Args: - offset (int or torch.tensor): start offset - size (int): required size of position encoding - - Returns: - torch.Tensor: Corresponding encoding - """ - pos_emb = self.pe[ - :, - self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size, - ] - return pos_emb diff --git a/HF_Deploy/src/chatterbox/models/s3gen/transformer/encoder_layer.py b/HF_Deploy/src/chatterbox/models/s3gen/transformer/encoder_layer.py deleted file mode 100644 index efbb12dd365770bebe8bca75276fe63be260a08f..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/s3gen/transformer/encoder_layer.py +++ /dev/null @@ -1,236 +0,0 @@ -# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu) -# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Modified from ESPnet(https://github.com/espnet/espnet) -"""Encoder self-attention layer definition.""" - -from typing import Optional, Tuple - -import torch -from torch import nn - - -class TransformerEncoderLayer(nn.Module): - """Encoder layer module. - - Args: - size (int): Input dimension. - self_attn (torch.nn.Module): Self-attention module instance. - `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` - instance can be used as the argument. - feed_forward (torch.nn.Module): Feed-forward module instance. - `PositionwiseFeedForward`, instance can be used as the argument. - dropout_rate (float): Dropout rate. - normalize_before (bool): - True: use layer_norm before each sub-block. - False: to use layer_norm after each sub-block. - """ - - def __init__( - self, - size: int, - self_attn: torch.nn.Module, - feed_forward: torch.nn.Module, - dropout_rate: float, - normalize_before: bool = True, - ): - """Construct an EncoderLayer object.""" - super().__init__() - self.self_attn = self_attn - self.feed_forward = feed_forward - self.norm1 = nn.LayerNorm(size, eps=1e-12) - self.norm2 = nn.LayerNorm(size, eps=1e-12) - self.dropout = nn.Dropout(dropout_rate) - self.size = size - self.normalize_before = normalize_before - - def forward( - self, - x: torch.Tensor, - mask: torch.Tensor, - pos_emb: torch.Tensor, - mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), - cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Compute encoded features. - - Args: - x (torch.Tensor): (#batch, time, size) - mask (torch.Tensor): Mask tensor for the input (#batch, time,time), - (0, 0, 0) means fake mask. - pos_emb (torch.Tensor): just for interface compatibility - to ConformerEncoderLayer - mask_pad (torch.Tensor): does not used in transformer layer, - just for unified api with conformer. - att_cache (torch.Tensor): Cache tensor of the KEY & VALUE - (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. - cnn_cache (torch.Tensor): Convolution cache in conformer layer - (#batch=1, size, cache_t2), not used here, it's for interface - compatibility to ConformerEncoderLayer. - Returns: - torch.Tensor: Output tensor (#batch, time, size). - torch.Tensor: Mask tensor (#batch, time, time). - torch.Tensor: att_cache tensor, - (#batch=1, head, cache_t1 + time, d_k * 2). - torch.Tensor: cnn_cahce tensor (#batch=1, size, cache_t2). - - """ - residual = x - if self.normalize_before: - x = self.norm1(x) - x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb=pos_emb, cache=att_cache) - x = residual + self.dropout(x_att) - if not self.normalize_before: - x = self.norm1(x) - - residual = x - if self.normalize_before: - x = self.norm2(x) - x = residual + self.dropout(self.feed_forward(x)) - if not self.normalize_before: - x = self.norm2(x) - - fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) - return x, mask, new_att_cache, fake_cnn_cache - - -class ConformerEncoderLayer(nn.Module): - """Encoder layer module. - Args: - size (int): Input dimension. - self_attn (torch.nn.Module): Self-attention module instance. - `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` - instance can be used as the argument. - feed_forward (torch.nn.Module): Feed-forward module instance. - `PositionwiseFeedForward` instance can be used as the argument. - feed_forward_macaron (torch.nn.Module): Additional feed-forward module - instance. - `PositionwiseFeedForward` instance can be used as the argument. - conv_module (torch.nn.Module): Convolution module instance. - `ConvlutionModule` instance can be used as the argument. - dropout_rate (float): Dropout rate. - normalize_before (bool): - True: use layer_norm before each sub-block. - False: use layer_norm after each sub-block. - """ - - def __init__( - self, - size: int, - self_attn: torch.nn.Module, - feed_forward: Optional[nn.Module] = None, - feed_forward_macaron: Optional[nn.Module] = None, - conv_module: Optional[nn.Module] = None, - dropout_rate: float = 0.1, - normalize_before: bool = True, - ): - """Construct an EncoderLayer object.""" - super().__init__() - self.self_attn = self_attn - self.feed_forward = feed_forward - self.feed_forward_macaron = feed_forward_macaron - self.conv_module = conv_module - self.norm_ff = nn.LayerNorm(size, eps=1e-12) # for the FNN module - self.norm_mha = nn.LayerNorm(size, eps=1e-12) # for the MHA module - if feed_forward_macaron is not None: - self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12) - self.ff_scale = 0.5 - else: - self.ff_scale = 1.0 - if self.conv_module is not None: - self.norm_conv = nn.LayerNorm(size, eps=1e-12) # for the CNN module - self.norm_final = nn.LayerNorm( - size, eps=1e-12) # for the final output of the block - self.dropout = nn.Dropout(dropout_rate) - self.size = size - self.normalize_before = normalize_before - - def forward( - self, - x: torch.Tensor, - mask: torch.Tensor, - pos_emb: torch.Tensor, - mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), - att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), - cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)), - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """Compute encoded features. - - Args: - x (torch.Tensor): (#batch, time, size) - mask (torch.Tensor): Mask tensor for the input (#batch, time,time), - (0, 0, 0) means fake mask. - pos_emb (torch.Tensor): positional encoding, must not be None - for ConformerEncoderLayer. - mask_pad (torch.Tensor): batch padding mask used for conv module. - (#batch, 1,time), (0, 0, 0) means fake mask. - att_cache (torch.Tensor): Cache tensor of the KEY & VALUE - (#batch=1, head, cache_t1, d_k * 2), head * d_k == size. - cnn_cache (torch.Tensor): Convolution cache in conformer layer - (#batch=1, size, cache_t2) - Returns: - torch.Tensor: Output tensor (#batch, time, size). - torch.Tensor: Mask tensor (#batch, time, time). - torch.Tensor: att_cache tensor, - (#batch=1, head, cache_t1 + time, d_k * 2). - torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2). - """ - - # whether to use macaron style - if self.feed_forward_macaron is not None: - residual = x - if self.normalize_before: - x = self.norm_ff_macaron(x) - x = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(x)) - if not self.normalize_before: - x = self.norm_ff_macaron(x) - - # multi-headed self-attention module - residual = x - if self.normalize_before: - x = self.norm_mha(x) - x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb, - att_cache) - x = residual + self.dropout(x_att) - if not self.normalize_before: - x = self.norm_mha(x) - - # convolution module - # Fake new cnn cache here, and then change it in conv_module - new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) - if self.conv_module is not None: - residual = x - if self.normalize_before: - x = self.norm_conv(x) - x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache) - x = residual + self.dropout(x) - - if not self.normalize_before: - x = self.norm_conv(x) - - # feed forward module - residual = x - if self.normalize_before: - x = self.norm_ff(x) - - x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) - if not self.normalize_before: - x = self.norm_ff(x) - - if self.conv_module is not None: - x = self.norm_final(x) - - return x, mask, new_att_cache, new_cnn_cache diff --git a/HF_Deploy/src/chatterbox/models/s3gen/transformer/positionwise_feed_forward.py b/HF_Deploy/src/chatterbox/models/s3gen/transformer/positionwise_feed_forward.py deleted file mode 100644 index b7a2cf6e7315e3a5ed2794423daff0a59cc5b208..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/s3gen/transformer/positionwise_feed_forward.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright (c) 2019 Shigeki Karita -# 2020 Mobvoi Inc (Binbin Zhang) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Positionwise feed forward layer definition.""" - -import torch - - -class PositionwiseFeedForward(torch.nn.Module): - """Positionwise feed forward layer. - - FeedForward are appied on each position of the sequence. - The output dim is same with the input dim. - - Args: - idim (int): Input dimenstion. - hidden_units (int): The number of hidden units. - dropout_rate (float): Dropout rate. - activation (torch.nn.Module): Activation function - """ - - def __init__( - self, - idim: int, - hidden_units: int, - dropout_rate: float, - activation: torch.nn.Module = torch.nn.ReLU(), - ): - """Construct a PositionwiseFeedForward object.""" - super(PositionwiseFeedForward, self).__init__() - self.w_1 = torch.nn.Linear(idim, hidden_units) - self.activation = activation - self.dropout = torch.nn.Dropout(dropout_rate) - self.w_2 = torch.nn.Linear(hidden_units, idim) - - def forward(self, xs: torch.Tensor) -> torch.Tensor: - """Forward function. - - Args: - xs: input tensor (B, L, D) - Returns: - output tensor, (B, L, D) - """ - return self.w_2(self.dropout(self.activation(self.w_1(xs)))) - - -class MoEFFNLayer(torch.nn.Module): - """ - Mixture of expert with Positionwise feed forward layer - See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf - The output dim is same with the input dim. - - Modified from https://github.com/Lightning-AI/lit-gpt/pull/823 - https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219 - Args: - n_expert: number of expert. - n_expert_per_token: The actual number of experts used for each frame - idim (int): Input dimenstion. - hidden_units (int): The number of hidden units. - dropout_rate (float): Dropout rate. - activation (torch.nn.Module): Activation function - """ - - def __init__( - self, - n_expert: int, - n_expert_per_token: int, - idim: int, - hidden_units: int, - dropout_rate: float, - activation: torch.nn.Module = torch.nn.ReLU(), - ): - super(MoEFFNLayer, self).__init__() - self.gate = torch.nn.Linear(idim, n_expert, bias=False) - self.experts = torch.nn.ModuleList( - PositionwiseFeedForward(idim, hidden_units, dropout_rate, - activation) for _ in range(n_expert)) - self.n_expert_per_token = n_expert_per_token - - def forward(self, xs: torch.Tensor) -> torch.Tensor: - """Foward function. - Args: - xs: input tensor (B, L, D) - Returns: - output tensor, (B, L, D) - - """ - B, L, D = xs.size( - ) # batch size, sequence length, embedding dimension (idim) - xs = xs.view(-1, D) # (B*L, D) - router = self.gate(xs) # (B*L, n_expert) - logits, indices = torch.topk( - router, self.n_expert_per_token - ) # probs:(B*L, n_expert), indices: (B*L, n_expert) - weights = torch.nn.functional.softmax( - logits, dim=1, - dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token) - output = torch.zeros_like(xs) # (B*L, D) - for i, expert in enumerate(self.experts): - mask = indices == i - batch_idx, ith_expert = torch.where(mask) - output[batch_idx] += weights[batch_idx, ith_expert, None] * expert( - xs[batch_idx]) - return output.view(B, L, D) diff --git a/HF_Deploy/src/chatterbox/models/s3gen/transformer/subsampling.py b/HF_Deploy/src/chatterbox/models/s3gen/transformer/subsampling.py deleted file mode 100644 index e17c2e324e3afb24e1b619effe29cef07c9c5b3a..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/s3gen/transformer/subsampling.py +++ /dev/null @@ -1,383 +0,0 @@ -# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu) -# 2024 Alibaba Inc (Xiang Lyu) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Modified from ESPnet(https://github.com/espnet/espnet) -"""Subsampling layer definition.""" - -from typing import Tuple, Union - -import torch - - -class BaseSubsampling(torch.nn.Module): - - def __init__(self): - super().__init__() - self.right_context = 0 - self.subsampling_rate = 1 - - def position_encoding(self, offset: Union[int, torch.Tensor], - size: int) -> torch.Tensor: - return self.pos_enc.position_encoding(offset, size) - - -class EmbedinigNoSubsampling(BaseSubsampling): - """Embedding input without subsampling - """ - - def __init__(self, idim: int, odim: int, dropout_rate: float, - pos_enc_class: torch.nn.Module): - super().__init__() - self.embed = torch.nn.Embedding(idim, odim) - self.pos_enc = pos_enc_class - - def forward( - self, - x: torch.Tensor, - x_mask: torch.Tensor, - offset: Union[int, torch.Tensor] = 0 - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Input x. - - Args: - x (torch.Tensor): Input tensor (#batch, time, idim). - x_mask (torch.Tensor): Input mask (#batch, 1, time). - - Returns: - torch.Tensor: linear input tensor (#batch, time', odim), - where time' = time . - torch.Tensor: linear input mask (#batch, 1, time'), - where time' = time . - - """ - x = self.embed(x) - x, pos_emb = self.pos_enc(x, offset) - return x, pos_emb, x_mask - - -class LinearNoSubsampling(BaseSubsampling): - """Linear transform the input without subsampling - - Args: - idim (int): Input dimension. - odim (int): Output dimension. - dropout_rate (float): Dropout rate. - - """ - - def __init__(self, idim: int, odim: int, dropout_rate: float, - pos_enc_class: torch.nn.Module): - """Construct an linear object.""" - super().__init__() - self.out = torch.nn.Sequential( - torch.nn.Linear(idim, odim), - torch.nn.LayerNorm(odim, eps=1e-5), - torch.nn.Dropout(dropout_rate), - ) - self.pos_enc = pos_enc_class - self.right_context = 0 - self.subsampling_rate = 1 - - def forward( - self, - x: torch.Tensor, - x_mask: torch.Tensor, - offset: Union[int, torch.Tensor] = 0 - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Input x. - - Args: - x (torch.Tensor): Input tensor (#batch, time, idim). - x_mask (torch.Tensor): Input mask (#batch, 1, time). - - Returns: - torch.Tensor: linear input tensor (#batch, time', odim), - where time' = time . - torch.Tensor: linear input mask (#batch, 1, time'), - where time' = time . - - """ - x = self.out(x) - x, pos_emb = self.pos_enc(x, offset) - return x, pos_emb, x_mask - - -class Conv1dSubsampling2(BaseSubsampling): - """Convolutional 1D subsampling (to 1/2 length). - It is designed for Whisper, ref: - https://github.com/openai/whisper/blob/main/whisper/model.py - - Args: - idim (int): Input dimension. - odim (int): Output dimension. - dropout_rate (float): Dropout rate. - - """ - - def __init__(self, idim: int, odim: int, dropout_rate: float, - pos_enc_class: torch.nn.Module): - """Construct an Conv1dSubsampling2 object.""" - super().__init__() - self.conv = torch.nn.Sequential( - torch.nn.Conv1d(idim, odim, kernel_size=3, padding=1), - torch.nn.GELU(), - torch.nn.Conv1d(odim, odim, kernel_size=3, stride=2, padding=1), - torch.nn.GELU(), - ) - self.pos_enc = pos_enc_class - # The right context for every conv layer is computed by: - # (kernel_size - 1) * frame_rate_of_this_layer - self.subsampling_rate = 2 - # 4 = (3 - 1) * 1 + (3 - 1) * 1 - self.right_context = 4 - - def forward( - self, - x: torch.Tensor, - x_mask: torch.Tensor, - offset: Union[int, torch.Tensor] = 0 - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Subsample x. - - Args: - x (torch.Tensor): Input tensor (#batch, time, idim). - x_mask (torch.Tensor): Input mask (#batch, 1, time). - - Returns: - torch.Tensor: Subsampled tensor (#batch, time', odim), - where time' = time // 2. - torch.Tensor: Subsampled mask (#batch, 1, time'), - where time' = time // 2. - torch.Tensor: positional encoding - - """ - time = x.size(1) - x = x.transpose(1, 2) # (b, f, t) - x = self.conv(x) - x = x.transpose(1, 2) # (b, t, f) - x, pos_emb = self.pos_enc(x, offset) - return x, pos_emb, x_mask[:, :, (time + 1) % 2::2] - - -class Conv2dSubsampling4(BaseSubsampling): - """Convolutional 2D subsampling (to 1/4 length). - - Args: - idim (int): Input dimension. - odim (int): Output dimension. - dropout_rate (float): Dropout rate. - - """ - - def __init__(self, idim: int, odim: int, dropout_rate: float, - pos_enc_class: torch.nn.Module): - """Construct an Conv2dSubsampling4 object.""" - super().__init__() - self.conv = torch.nn.Sequential( - torch.nn.Conv2d(1, odim, 3, 2), - torch.nn.ReLU(), - torch.nn.Conv2d(odim, odim, 3, 2), - torch.nn.ReLU(), - ) - self.out = torch.nn.Sequential( - torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)) - self.pos_enc = pos_enc_class - # The right context for every conv layer is computed by: - # (kernel_size - 1) * frame_rate_of_this_layer - self.subsampling_rate = 4 - # 6 = (3 - 1) * 1 + (3 - 1) * 2 - self.right_context = 6 - - def forward( - self, - x: torch.Tensor, - x_mask: torch.Tensor, - offset: Union[int, torch.Tensor] = 0 - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Subsample x. - - Args: - x (torch.Tensor): Input tensor (#batch, time, idim). - x_mask (torch.Tensor): Input mask (#batch, 1, time). - - Returns: - torch.Tensor: Subsampled tensor (#batch, time', odim), - where time' = time // 4. - torch.Tensor: Subsampled mask (#batch, 1, time'), - where time' = time // 4. - torch.Tensor: positional encoding - - """ - x = x.unsqueeze(1) # (b, c=1, t, f) - x = self.conv(x) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) - x, pos_emb = self.pos_enc(x, offset) - return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2] - - -class Conv2dSubsampling6(BaseSubsampling): - """Convolutional 2D subsampling (to 1/6 length). - Args: - idim (int): Input dimension. - odim (int): Output dimension. - dropout_rate (float): Dropout rate. - pos_enc (torch.nn.Module): Custom position encoding layer. - """ - - def __init__(self, idim: int, odim: int, dropout_rate: float, - pos_enc_class: torch.nn.Module): - """Construct an Conv2dSubsampling6 object.""" - super().__init__() - self.conv = torch.nn.Sequential( - torch.nn.Conv2d(1, odim, 3, 2), - torch.nn.ReLU(), - torch.nn.Conv2d(odim, odim, 5, 3), - torch.nn.ReLU(), - ) - self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), - odim) - self.pos_enc = pos_enc_class - # 10 = (3 - 1) * 1 + (5 - 1) * 2 - self.subsampling_rate = 6 - self.right_context = 10 - - def forward( - self, - x: torch.Tensor, - x_mask: torch.Tensor, - offset: Union[int, torch.Tensor] = 0 - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Subsample x. - Args: - x (torch.Tensor): Input tensor (#batch, time, idim). - x_mask (torch.Tensor): Input mask (#batch, 1, time). - - Returns: - torch.Tensor: Subsampled tensor (#batch, time', odim), - where time' = time // 6. - torch.Tensor: Subsampled mask (#batch, 1, time'), - where time' = time // 6. - torch.Tensor: positional encoding - """ - x = x.unsqueeze(1) # (b, c, t, f) - x = self.conv(x) - b, c, t, f = x.size() - x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f)) - x, pos_emb = self.pos_enc(x, offset) - return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3] - - -class Conv2dSubsampling8(BaseSubsampling): - """Convolutional 2D subsampling (to 1/8 length). - - Args: - idim (int): Input dimension. - odim (int): Output dimension. - dropout_rate (float): Dropout rate. - - """ - - def __init__(self, idim: int, odim: int, dropout_rate: float, - pos_enc_class: torch.nn.Module): - """Construct an Conv2dSubsampling8 object.""" - super().__init__() - self.conv = torch.nn.Sequential( - torch.nn.Conv2d(1, odim, 3, 2), - torch.nn.ReLU(), - torch.nn.Conv2d(odim, odim, 3, 2), - torch.nn.ReLU(), - torch.nn.Conv2d(odim, odim, 3, 2), - torch.nn.ReLU(), - ) - self.linear = torch.nn.Linear( - odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim) - self.pos_enc = pos_enc_class - self.subsampling_rate = 8 - # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4 - self.right_context = 14 - - def forward( - self, - x: torch.Tensor, - x_mask: torch.Tensor, - offset: Union[int, torch.Tensor] = 0 - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Subsample x. - - Args: - x (torch.Tensor): Input tensor (#batch, time, idim). - x_mask (torch.Tensor): Input mask (#batch, 1, time). - - Returns: - torch.Tensor: Subsampled tensor (#batch, time', odim), - where time' = time // 8. - torch.Tensor: Subsampled mask (#batch, 1, time'), - where time' = time // 8. - torch.Tensor: positional encoding - """ - x = x.unsqueeze(1) # (b, c, t, f) - x = self.conv(x) - b, c, t, f = x.size() - x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f)) - x, pos_emb = self.pos_enc(x, offset) - return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2] - - -class LegacyLinearNoSubsampling(BaseSubsampling): - """Linear transform the input without subsampling - - Args: - idim (int): Input dimension. - odim (int): Output dimension. - dropout_rate (float): Dropout rate. - - """ - - def __init__(self, idim: int, odim: int, dropout_rate: float, - pos_enc_class: torch.nn.Module): - """Construct an linear object.""" - super().__init__() - self.out = torch.nn.Sequential( - torch.nn.Linear(idim, odim), - torch.nn.LayerNorm(odim, eps=1e-5), - torch.nn.Dropout(dropout_rate), - torch.nn.ReLU(), - ) - self.pos_enc = pos_enc_class - self.right_context = 0 - self.subsampling_rate = 1 - - def forward( - self, - x: torch.Tensor, - x_mask: torch.Tensor, - offset: Union[int, torch.Tensor] = 0 - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Input x. - - Args: - x (torch.Tensor): Input tensor (#batch, time, idim). - x_mask (torch.Tensor): Input mask (#batch, 1, time). - - Returns: - torch.Tensor: linear input tensor (#batch, time', odim), - where time' = time . - torch.Tensor: linear input mask (#batch, 1, time'), - where time' = time . - - """ - x = self.out(x) - x, pos_emb = self.pos_enc(x, offset) - return x, pos_emb, x_mask diff --git a/HF_Deploy/src/chatterbox/models/s3gen/transformer/upsample_encoder.py b/HF_Deploy/src/chatterbox/models/s3gen/transformer/upsample_encoder.py deleted file mode 100644 index 766a5e4e77070ff5579b1a567607c2879391bf8a..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/s3gen/transformer/upsample_encoder.py +++ /dev/null @@ -1,318 +0,0 @@ -# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu) -# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn) -# 2024 Alibaba Inc (Xiang Lyu) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Modified from ESPnet(https://github.com/espnet/espnet) -"""Encoder definition.""" -from typing import Tuple - -import torch -from torch import nn -from torch.nn import functional as F - -from .convolution import ConvolutionModule -from .encoder_layer import ConformerEncoderLayer -from .positionwise_feed_forward import PositionwiseFeedForward -from ..utils.class_utils import ( - COSYVOICE_EMB_CLASSES, - COSYVOICE_SUBSAMPLE_CLASSES, - COSYVOICE_ATTENTION_CLASSES, - COSYVOICE_ACTIVATION_CLASSES, -) -from ..utils.mask import make_pad_mask -from ..utils.mask import add_optional_chunk_mask - - -class Upsample1D(nn.Module): - """A 1D upsampling layer with an optional convolution. - - Parameters: - channels (`int`): - number of channels in the inputs and outputs. - use_conv (`bool`, default `False`): - option to use a convolution. - use_conv_transpose (`bool`, default `False`): - option to use a convolution transpose. - out_channels (`int`, optional): - number of output channels. Defaults to `channels`. - """ - - def __init__(self, channels: int, out_channels: int, stride: int = 2): - super().__init__() - self.channels = channels - self.out_channels = out_channels - self.stride = stride - # In this mode, first repeat interpolate, than conv with stride=1 - self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0) - - def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor): - outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest") - outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0) - outputs = self.conv(outputs) - return outputs, input_lengths * self.stride - - -class PreLookaheadLayer(nn.Module): - def __init__(self, channels: int, pre_lookahead_len: int = 1): - super().__init__() - self.channels = channels - self.pre_lookahead_len = pre_lookahead_len - self.conv1 = nn.Conv1d( - channels, channels, - kernel_size=pre_lookahead_len + 1, - stride=1, padding=0, - ) - self.conv2 = nn.Conv1d( - channels, channels, - kernel_size=3, stride=1, padding=0, - ) - - def forward(self, inputs: torch.Tensor) -> torch.Tensor: - """ - inputs: (batch_size, seq_len, channels) - """ - outputs = inputs.transpose(1, 2).contiguous() - # look ahead - outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0) - outputs = F.leaky_relu(self.conv1(outputs)) - # outputs - outputs = F.pad(outputs, (2, 0), mode='constant', value=0.0) - outputs = self.conv2(outputs) - outputs = outputs.transpose(1, 2).contiguous() - - # residual connection - outputs = outputs + inputs - return outputs - - -class UpsampleConformerEncoder(torch.nn.Module): - - def __init__( - self, - input_size: int = 512, - output_size: int = 512, - attention_heads: int = 8, - linear_units: int = 2048, - num_blocks: int = 6, - dropout_rate: float = 0.1, - positional_dropout_rate: float = 0.1, - attention_dropout_rate: float = 0.1, - input_layer: str = "linear", - pos_enc_layer_type: str = "rel_pos_espnet", - normalize_before: bool = True, - static_chunk_size: int = 0, - use_dynamic_chunk: bool = False, - global_cmvn: torch.nn.Module = None, - use_dynamic_left_chunk: bool = False, - positionwise_conv_kernel_size: int = 1, - macaron_style: bool = False, - selfattention_layer_type: str = "rel_selfattn", - activation_type: str = "swish", - use_cnn_module: bool = False, - cnn_module_kernel: int = 15, - causal: bool = False, - cnn_module_norm: str = "batch_norm", - key_bias: bool = True, - gradient_checkpointing: bool = False, - ): - """ - Args: - input_size (int): input dim - output_size (int): dimension of attention - attention_heads (int): the number of heads of multi head attention - linear_units (int): the hidden units number of position-wise feed - forward - num_blocks (int): the number of decoder blocks - dropout_rate (float): dropout rate - attention_dropout_rate (float): dropout rate in attention - positional_dropout_rate (float): dropout rate after adding - positional encoding - input_layer (str): input layer type. - optional [linear, conv2d, conv2d6, conv2d8] - pos_enc_layer_type (str): Encoder positional encoding layer type. - opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos] - normalize_before (bool): - True: use layer_norm before each sub-block of a layer. - False: use layer_norm after each sub-block of a layer. - static_chunk_size (int): chunk size for static chunk training and - decoding - use_dynamic_chunk (bool): whether use dynamic chunk size for - training or not, You can only use fixed chunk(chunk_size > 0) - or dyanmic chunk size(use_dynamic_chunk = True) - global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module - use_dynamic_left_chunk (bool): whether use dynamic left chunk in - dynamic chunk training - key_bias: whether use bias in attention.linear_k, False for whisper models. - gradient_checkpointing: rerunning a forward-pass segment for each - checkpointed segment during backward. - """ - super().__init__() - self._output_size = output_size - - self.global_cmvn = global_cmvn - self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer]( - input_size, - output_size, - dropout_rate, - COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size, - positional_dropout_rate), - ) - - self.normalize_before = normalize_before - self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5) - self.static_chunk_size = static_chunk_size - self.use_dynamic_chunk = use_dynamic_chunk - self.use_dynamic_left_chunk = use_dynamic_left_chunk - self.gradient_checkpointing = gradient_checkpointing - activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]() - # self-attention module definition - encoder_selfattn_layer_args = ( - attention_heads, - output_size, - attention_dropout_rate, - key_bias, - ) - # feed-forward module definition - positionwise_layer_args = ( - output_size, - linear_units, - dropout_rate, - activation, - ) - # convolution module definition - convolution_layer_args = (output_size, cnn_module_kernel, activation, - cnn_module_norm, causal) - self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3) - self.encoders = torch.nn.ModuleList([ - ConformerEncoderLayer( - output_size, - COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type]( - *encoder_selfattn_layer_args), - PositionwiseFeedForward(*positionwise_layer_args), - PositionwiseFeedForward( - *positionwise_layer_args) if macaron_style else None, - ConvolutionModule( - *convolution_layer_args) if use_cnn_module else None, - dropout_rate, - normalize_before, - ) for _ in range(num_blocks) - ]) - self.up_layer = Upsample1D(channels=512, out_channels=512, stride=2) - self.up_embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer]( - input_size, - output_size, - dropout_rate, - COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size, - positional_dropout_rate), - ) - self.up_encoders = torch.nn.ModuleList([ - ConformerEncoderLayer( - output_size, - COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type]( - *encoder_selfattn_layer_args), - PositionwiseFeedForward(*positionwise_layer_args), - PositionwiseFeedForward( - *positionwise_layer_args) if macaron_style else None, - ConvolutionModule( - *convolution_layer_args) if use_cnn_module else None, - dropout_rate, - normalize_before, - ) for _ in range(4) - ]) - - def output_size(self) -> int: - return self._output_size - - def forward( - self, - xs: torch.Tensor, - xs_lens: torch.Tensor, - decoding_chunk_size: int = 0, - num_decoding_left_chunks: int = -1, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Embed positions in tensor. - - Args: - xs: padded input tensor (B, T, D) - xs_lens: input length (B) - decoding_chunk_size: decoding chunk size for dynamic chunk - 0: default for training, use random dynamic chunk. - <0: for decoding, use full chunk. - >0: for decoding, use fixed chunk size as set. - num_decoding_left_chunks: number of left chunks, this is for decoding, - the chunk size is decoding_chunk_size. - >=0: use num_decoding_left_chunks - <0: use all left chunks - Returns: - encoder output tensor xs, and subsampled masks - xs: padded output tensor (B, T' ~= T/subsample_rate, D) - masks: torch.Tensor batch padding mask after subsample - (B, 1, T' ~= T/subsample_rate) - NOTE(xcsong): - We pass the `__call__` method of the modules instead of `forward` to the - checkpointing API because `__call__` attaches all the hooks of the module. - https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 - """ - T = xs.size(1) - masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) - if self.global_cmvn is not None: - xs = self.global_cmvn(xs) - xs, pos_emb, masks = self.embed(xs, masks) - mask_pad = masks # (B, 1, T/subsample_rate) - chunk_masks = add_optional_chunk_mask(xs, masks, - self.use_dynamic_chunk, - self.use_dynamic_left_chunk, - decoding_chunk_size, - self.static_chunk_size, - num_decoding_left_chunks) - # lookahead + conformer encoder - xs = self.pre_lookahead_layer(xs) - xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad) - - # upsample + conformer encoder - xs = xs.transpose(1, 2).contiguous() - xs, xs_lens = self.up_layer(xs, xs_lens) - xs = xs.transpose(1, 2).contiguous() - T = xs.size(1) - masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T) - xs, pos_emb, masks = self.up_embed(xs, masks) - mask_pad = masks # (B, 1, T/subsample_rate) - chunk_masks = add_optional_chunk_mask(xs, masks, - self.use_dynamic_chunk, - self.use_dynamic_left_chunk, - decoding_chunk_size, - self.static_chunk_size * self.up_layer.stride, - num_decoding_left_chunks) - xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad) - - if self.normalize_before: - xs = self.after_norm(xs) - # Here we assume the mask is not changed in encoder layers, so just - # return the masks before encoder layers, and the masks will be used - # for cross attention with decoder later - return xs, masks - - def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor, - pos_emb: torch.Tensor, - mask_pad: torch.Tensor) -> torch.Tensor: - for layer in self.encoders: - xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) - return xs - - def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor, - pos_emb: torch.Tensor, - mask_pad: torch.Tensor) -> torch.Tensor: - for layer in self.up_encoders: - xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad) - return xs diff --git a/HF_Deploy/src/chatterbox/models/s3gen/utils/class_utils.py b/HF_Deploy/src/chatterbox/models/s3gen/utils/class_utils.py deleted file mode 100644 index cd31e48029ce1ee11728a2edbffec479cc0a1bd6..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/s3gen/utils/class_utils.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright [2023-11-28] -# 2024 Alibaba Inc (authors: Xiang Lyu) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import torch - -from ..transformer.activation import Swish -from ..transformer.subsampling import ( - LinearNoSubsampling, - EmbedinigNoSubsampling, - Conv1dSubsampling2, - Conv2dSubsampling4, - Conv2dSubsampling6, - Conv2dSubsampling8, -) -from ..transformer.embedding import ( - PositionalEncoding, - RelPositionalEncoding, - WhisperPositionalEncoding, - LearnablePositionalEncoding, - NoPositionalEncoding) -from ..transformer.attention import (MultiHeadedAttention, - RelPositionMultiHeadedAttention) -from ..transformer.embedding import EspnetRelPositionalEncoding -from ..transformer.subsampling import LegacyLinearNoSubsampling - - -COSYVOICE_ACTIVATION_CLASSES = { - "hardtanh": torch.nn.Hardtanh, - "tanh": torch.nn.Tanh, - "relu": torch.nn.ReLU, - "selu": torch.nn.SELU, - "swish": getattr(torch.nn, "SiLU", Swish), - "gelu": torch.nn.GELU, -} - -COSYVOICE_SUBSAMPLE_CLASSES = { - "linear": LinearNoSubsampling, - "linear_legacy": LegacyLinearNoSubsampling, - "embed": EmbedinigNoSubsampling, - "conv1d2": Conv1dSubsampling2, - "conv2d": Conv2dSubsampling4, - "conv2d6": Conv2dSubsampling6, - "conv2d8": Conv2dSubsampling8, - 'paraformer_dummy': torch.nn.Identity -} - -COSYVOICE_EMB_CLASSES = { - "embed": PositionalEncoding, - "abs_pos": PositionalEncoding, - "rel_pos": RelPositionalEncoding, - "rel_pos_espnet": EspnetRelPositionalEncoding, - "no_pos": NoPositionalEncoding, - "abs_pos_whisper": WhisperPositionalEncoding, - "embed_learnable_pe": LearnablePositionalEncoding, -} - -COSYVOICE_ATTENTION_CLASSES = { - "selfattn": MultiHeadedAttention, - "rel_selfattn": RelPositionMultiHeadedAttention, -} diff --git a/HF_Deploy/src/chatterbox/models/s3gen/utils/mask.py b/HF_Deploy/src/chatterbox/models/s3gen/utils/mask.py deleted file mode 100644 index 08c97a3ed6f49d9e623b252273d2eee9d26c408b..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/s3gen/utils/mask.py +++ /dev/null @@ -1,193 +0,0 @@ -# Copyright (c) 2019 Shigeki Karita -# 2020 Mobvoi Inc (Binbin Zhang) -# 2024 Alibaba Inc (authors: Xiang Lyu) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch - -''' -def subsequent_mask( - size: int, - device: torch.device = torch.device("cpu"), -) -> torch.Tensor: - """Create mask for subsequent steps (size, size). - - This mask is used only in decoder which works in an auto-regressive mode. - This means the current step could only do attention with its left steps. - - In encoder, fully attention is used when streaming is not necessary and - the sequence is not long. In this case, no attention mask is needed. - - When streaming is need, chunk-based attention is used in encoder. See - subsequent_chunk_mask for the chunk-based attention mask. - - Args: - size (int): size of mask - str device (str): "cpu" or "cuda" or torch.Tensor.device - dtype (torch.device): result dtype - - Returns: - torch.Tensor: mask - - Examples: - >>> subsequent_mask(3) - [[1, 0, 0], - [1, 1, 0], - [1, 1, 1]] - """ - ret = torch.ones(size, size, device=device, dtype=torch.bool) - return torch.tril(ret) -''' - - -def subsequent_chunk_mask( - size: int, - chunk_size: int, - num_left_chunks: int = -1, - device: torch.device = torch.device("cpu"), -) -> torch.Tensor: - """Create mask for subsequent steps (size, size) with chunk size, - this is for streaming encoder - - Args: - size (int): size of mask - chunk_size (int): size of chunk - num_left_chunks (int): number of left chunks - <0: use full chunk - >=0: use num_left_chunks - device (torch.device): "cpu" or "cuda" or torch.Tensor.device - - Returns: - torch.Tensor: mask - - Examples: - >>> subsequent_chunk_mask(4, 2) - [[1, 1, 0, 0], - [1, 1, 0, 0], - [1, 1, 1, 1], - [1, 1, 1, 1]] - """ - # NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks - # actually this is not needed after we have inference cache implemented, will remove it later - pos_idx = torch.arange(size, device=device) - block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size - ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1) - return ret - - -def add_optional_chunk_mask(xs: torch.Tensor, - masks: torch.Tensor, - use_dynamic_chunk: bool, - use_dynamic_left_chunk: bool, - decoding_chunk_size: int, - static_chunk_size: int, - num_decoding_left_chunks: int, - enable_full_context: bool = True): - """ Apply optional mask for encoder. - - Args: - xs (torch.Tensor): padded input, (B, L, D), L for max length - mask (torch.Tensor): mask for xs, (B, 1, L) - use_dynamic_chunk (bool): whether to use dynamic chunk or not - use_dynamic_left_chunk (bool): whether to use dynamic left chunk for - training. - decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's - 0: default for training, use random dynamic chunk. - <0: for decoding, use full chunk. - >0: for decoding, use fixed chunk size as set. - static_chunk_size (int): chunk size for static chunk training/decoding - if it's greater than 0, if use_dynamic_chunk is true, - this parameter will be ignored - num_decoding_left_chunks: number of left chunks, this is for decoding, - the chunk size is decoding_chunk_size. - >=0: use num_decoding_left_chunks - <0: use all left chunks - enable_full_context (bool): - True: chunk size is either [1, 25] or full context(max_len) - False: chunk size ~ U[1, 25] - - Returns: - torch.Tensor: chunk mask of the input xs. - """ - # Whether to use chunk mask or not - if use_dynamic_chunk: - max_len = xs.size(1) - if decoding_chunk_size < 0: - chunk_size = max_len - num_left_chunks = -1 - elif decoding_chunk_size > 0: - chunk_size = decoding_chunk_size - num_left_chunks = num_decoding_left_chunks - else: - # chunk size is either [1, 25] or full context(max_len). - # Since we use 4 times subsampling and allow up to 1s(100 frames) - # delay, the maximum frame is 100 / 4 = 25. - chunk_size = torch.randint(1, max_len, (1, )).item() - num_left_chunks = -1 - if chunk_size > max_len // 2 and enable_full_context: - chunk_size = max_len - else: - chunk_size = chunk_size % 25 + 1 - if use_dynamic_left_chunk: - max_left_chunks = (max_len - 1) // chunk_size - num_left_chunks = torch.randint(0, max_left_chunks, - (1, )).item() - chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size, - num_left_chunks, - xs.device) # (L, L) - chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) - chunk_masks = masks & chunk_masks # (B, L, L) - elif static_chunk_size > 0: - num_left_chunks = num_decoding_left_chunks - chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size, - num_left_chunks, - xs.device) # (L, L) - chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L) - chunk_masks = masks & chunk_masks # (B, L, L) - else: - chunk_masks = masks - assert chunk_masks.dtype == torch.bool - if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0: - logging.warning('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!') - chunk_masks[chunk_masks.sum(dim=-1)==0] = True - return chunk_masks - - -def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: - """Make mask tensor containing indices of padded part. - - See description of make_non_pad_mask. - - Args: - lengths (torch.Tensor): Batch of lengths (B,). - Returns: - torch.Tensor: Mask tensor containing indices of padded part. - - Examples: - >>> lengths = [5, 3, 2] - >>> make_pad_mask(lengths) - masks = [[0, 0, 0, 0 ,0], - [0, 0, 0, 1, 1], - [0, 0, 1, 1, 1]] - """ - batch_size = lengths.size(0) - max_len = max_len if max_len > 0 else lengths.max().item() - seq_range = torch.arange(0, - max_len, - dtype=torch.int64, - device=lengths.device) - seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) - seq_length_expand = lengths.unsqueeze(-1) - mask = seq_range_expand >= seq_length_expand - return mask diff --git a/HF_Deploy/src/chatterbox/models/s3gen/utils/mel.py b/HF_Deploy/src/chatterbox/models/s3gen/utils/mel.py deleted file mode 100644 index 5a9ff9d11d67e1d6a96dd97d45a02366a3bba300..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/s3gen/utils/mel.py +++ /dev/null @@ -1,81 +0,0 @@ -"""mel-spectrogram extraction in Matcha-TTS""" -from librosa.filters import mel as librosa_mel_fn -import torch -import numpy as np - - -# NOTE: they decalred these global vars -mel_basis = {} -hann_window = {} - - -def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): - return torch.log(torch.clamp(x, min=clip_val) * C) - - -def spectral_normalize_torch(magnitudes): - output = dynamic_range_compression_torch(magnitudes) - return output - -""" -feat_extractor: !name:matcha.utils.audio.mel_spectrogram - n_fft: 1920 - num_mels: 80 - sampling_rate: 24000 - hop_size: 480 - win_size: 1920 - fmin: 0 - fmax: 8000 - center: False - -""" - -def mel_spectrogram(y, n_fft=1920, num_mels=80, sampling_rate=24000, hop_size=480, win_size=1920, - fmin=0, fmax=8000, center=False): - """Copied from https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/audio.py - Set default values according to Cosyvoice's config. - """ - - if isinstance(y, np.ndarray): - y = torch.tensor(y).float() - - if len(y.shape) == 1: - y = y[None, ] - - if torch.min(y) < -1.0: - print("min value is ", torch.min(y)) - if torch.max(y) > 1.0: - print("max value is ", torch.max(y)) - - global mel_basis, hann_window # pylint: disable=global-statement,global-variable-not-assigned - if f"{str(fmax)}_{str(y.device)}" not in mel_basis: - mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) - mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) - hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) - - y = torch.nn.functional.pad( - y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" - ) - y = y.squeeze(1) - - spec = torch.view_as_real( - torch.stft( - y, - n_fft, - hop_length=hop_size, - win_length=win_size, - window=hann_window[str(y.device)], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=True, - ) - ) - - spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) - - spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) - spec = spectral_normalize_torch(spec) - - return spec diff --git a/HF_Deploy/src/chatterbox/models/s3gen/xvector.py b/HF_Deploy/src/chatterbox/models/s3gen/xvector.py deleted file mode 100644 index 6eb99af4aad25b33698211aa033d182d2f753379..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/s3gen/xvector.py +++ /dev/null @@ -1,428 +0,0 @@ -#!/usr/bin/env python3 -# -*- encoding: utf-8 -*- -# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. -# MIT License (https://opensource.org/licenses/MIT) -# Modified from 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker) - - -from collections import OrderedDict -import torch -import torch.nn.functional as F -import torch.utils.checkpoint as cp -import torchaudio.compliance.kaldi as Kaldi - - -def pad_list(xs, pad_value): - """Perform padding for the list of tensors. - - Args: - xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)]. - pad_value (float): Value for padding. - - Returns: - Tensor: Padded tensor (B, Tmax, `*`). - - Examples: - >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)] - >>> x - [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])] - >>> pad_list(x, 0) - tensor([[1., 1., 1., 1.], - [1., 1., 0., 0.], - [1., 0., 0., 0.]]) - - """ - n_batch = len(xs) - max_len = max(x.size(0) for x in xs) - pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value) - - for i in range(n_batch): - pad[i, : xs[i].size(0)] = xs[i] - - return pad - - -def extract_feature(audio): - features = [] - feature_times = [] - feature_lengths = [] - for au in audio: - feature = Kaldi.fbank(au.unsqueeze(0), num_mel_bins=80) - feature = feature - feature.mean(dim=0, keepdim=True) - features.append(feature) - feature_times.append(au.shape[0]) - feature_lengths.append(feature.shape[0]) - # padding for batch inference - features_padded = pad_list(features, pad_value=0) - # features = torch.cat(features) - return features_padded, feature_lengths, feature_times - - -class BasicResBlock(torch.nn.Module): - expansion = 1 - - def __init__(self, in_planes, planes, stride=1): - super(BasicResBlock, self).__init__() - self.conv1 = torch.nn.Conv2d( - in_planes, planes, kernel_size=3, stride=(stride, 1), padding=1, bias=False - ) - self.bn1 = torch.nn.BatchNorm2d(planes) - self.conv2 = torch.nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) - self.bn2 = torch.nn.BatchNorm2d(planes) - - self.shortcut = torch.nn.Sequential() - if stride != 1 or in_planes != self.expansion * planes: - self.shortcut = torch.nn.Sequential( - torch.nn.Conv2d( - in_planes, - self.expansion * planes, - kernel_size=1, - stride=(stride, 1), - bias=False, - ), - torch.nn.BatchNorm2d(self.expansion * planes), - ) - - def forward(self, x): - out = F.relu(self.bn1(self.conv1(x))) - out = self.bn2(self.conv2(out)) - out += self.shortcut(x) - out = F.relu(out) - return out - - -class FCM(torch.nn.Module): - def __init__(self, block=BasicResBlock, num_blocks=[2, 2], m_channels=32, feat_dim=80): - super(FCM, self).__init__() - self.in_planes = m_channels - self.conv1 = torch.nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False) - self.bn1 = torch.nn.BatchNorm2d(m_channels) - - self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=2) - self.layer2 = self._make_layer(block, m_channels, num_blocks[0], stride=2) - - self.conv2 = torch.nn.Conv2d( - m_channels, m_channels, kernel_size=3, stride=(2, 1), padding=1, bias=False - ) - self.bn2 = torch.nn.BatchNorm2d(m_channels) - self.out_channels = m_channels * (feat_dim // 8) - - def _make_layer(self, block, planes, num_blocks, stride): - strides = [stride] + [1] * (num_blocks - 1) - layers = [] - for stride in strides: - layers.append(block(self.in_planes, planes, stride)) - self.in_planes = planes * block.expansion - return torch.nn.Sequential(*layers) - - def forward(self, x): - x = x.unsqueeze(1) - out = F.relu(self.bn1(self.conv1(x))) - out = self.layer1(out) - out = self.layer2(out) - out = F.relu(self.bn2(self.conv2(out))) - - shape = out.shape - out = out.reshape(shape[0], shape[1] * shape[2], shape[3]) - return out - - -def get_nonlinear(config_str, channels): - nonlinear = torch.nn.Sequential() - for name in config_str.split("-"): - if name == "relu": - nonlinear.add_module("relu", torch.nn.ReLU(inplace=True)) - elif name == "prelu": - nonlinear.add_module("prelu", torch.nn.PReLU(channels)) - elif name == "batchnorm": - nonlinear.add_module("batchnorm", torch.nn.BatchNorm1d(channels)) - elif name == "batchnorm_": - nonlinear.add_module("batchnorm", torch.nn.BatchNorm1d(channels, affine=False)) - else: - raise ValueError("Unexpected module ({}).".format(name)) - return nonlinear - - -def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2): - mean = x.mean(dim=dim) - std = x.std(dim=dim, unbiased=unbiased) - stats = torch.cat([mean, std], dim=-1) - if keepdim: - stats = stats.unsqueeze(dim=dim) - return stats - - -class StatsPool(torch.nn.Module): - def forward(self, x): - return statistics_pooling(x) - - -class TDNNLayer(torch.nn.Module): - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - bias=False, - config_str="batchnorm-relu", - ): - super(TDNNLayer, self).__init__() - if padding < 0: - assert ( - kernel_size % 2 == 1 - ), "Expect equal paddings, but got even kernel size ({})".format(kernel_size) - padding = (kernel_size - 1) // 2 * dilation - self.linear = torch.nn.Conv1d( - in_channels, - out_channels, - kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - bias=bias, - ) - self.nonlinear = get_nonlinear(config_str, out_channels) - - def forward(self, x): - x = self.linear(x) - x = self.nonlinear(x) - return x - - -class CAMLayer(torch.nn.Module): - def __init__( - self, bn_channels, out_channels, kernel_size, stride, padding, dilation, bias, reduction=2 - ): - super(CAMLayer, self).__init__() - self.linear_local = torch.nn.Conv1d( - bn_channels, - out_channels, - kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - bias=bias, - ) - self.linear1 = torch.nn.Conv1d(bn_channels, bn_channels // reduction, 1) - self.relu = torch.nn.ReLU(inplace=True) - self.linear2 = torch.nn.Conv1d(bn_channels // reduction, out_channels, 1) - self.sigmoid = torch.nn.Sigmoid() - - def forward(self, x): - y = self.linear_local(x) - context = x.mean(-1, keepdim=True) + self.seg_pooling(x) - context = self.relu(self.linear1(context)) - m = self.sigmoid(self.linear2(context)) - return y * m - - def seg_pooling(self, x, seg_len=100, stype="avg"): - if stype == "avg": - seg = F.avg_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True) - elif stype == "max": - seg = F.max_pool1d(x, kernel_size=seg_len, stride=seg_len, ceil_mode=True) - else: - raise ValueError("Wrong segment pooling type.") - shape = seg.shape - seg = seg.unsqueeze(-1).expand(*shape, seg_len).reshape(*shape[:-1], -1) - seg = seg[..., : x.shape[-1]] - return seg - - -class CAMDenseTDNNLayer(torch.nn.Module): - def __init__( - self, - in_channels, - out_channels, - bn_channels, - kernel_size, - stride=1, - dilation=1, - bias=False, - config_str="batchnorm-relu", - memory_efficient=False, - ): - super(CAMDenseTDNNLayer, self).__init__() - assert kernel_size % 2 == 1, "Expect equal paddings, but got even kernel size ({})".format( - kernel_size - ) - padding = (kernel_size - 1) // 2 * dilation - self.memory_efficient = memory_efficient - self.nonlinear1 = get_nonlinear(config_str, in_channels) - self.linear1 = torch.nn.Conv1d(in_channels, bn_channels, 1, bias=False) - self.nonlinear2 = get_nonlinear(config_str, bn_channels) - self.cam_layer = CAMLayer( - bn_channels, - out_channels, - kernel_size, - stride=stride, - padding=padding, - dilation=dilation, - bias=bias, - ) - - def bn_function(self, x): - return self.linear1(self.nonlinear1(x)) - - def forward(self, x): - if self.training and self.memory_efficient: - x = cp.checkpoint(self.bn_function, x) - else: - x = self.bn_function(x) - x = self.cam_layer(self.nonlinear2(x)) - return x - - -class CAMDenseTDNNBlock(torch.nn.ModuleList): - def __init__( - self, - num_layers, - in_channels, - out_channels, - bn_channels, - kernel_size, - stride=1, - dilation=1, - bias=False, - config_str="batchnorm-relu", - memory_efficient=False, - ): - super(CAMDenseTDNNBlock, self).__init__() - for i in range(num_layers): - layer = CAMDenseTDNNLayer( - in_channels=in_channels + i * out_channels, - out_channels=out_channels, - bn_channels=bn_channels, - kernel_size=kernel_size, - stride=stride, - dilation=dilation, - bias=bias, - config_str=config_str, - memory_efficient=memory_efficient, - ) - self.add_module("tdnnd%d" % (i + 1), layer) - - def forward(self, x): - for layer in self: - x = torch.cat([x, layer(x)], dim=1) - return x - - -class TransitLayer(torch.nn.Module): - def __init__(self, in_channels, out_channels, bias=True, config_str="batchnorm-relu"): - super(TransitLayer, self).__init__() - self.nonlinear = get_nonlinear(config_str, in_channels) - self.linear = torch.nn.Conv1d(in_channels, out_channels, 1, bias=bias) - - def forward(self, x): - x = self.nonlinear(x) - x = self.linear(x) - return x - - -class DenseLayer(torch.nn.Module): - def __init__(self, in_channels, out_channels, bias=False, config_str="batchnorm-relu"): - super(DenseLayer, self).__init__() - self.linear = torch.nn.Conv1d(in_channels, out_channels, 1, bias=bias) - self.nonlinear = get_nonlinear(config_str, out_channels) - - def forward(self, x): - if len(x.shape) == 2: - x = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1) - else: - x = self.linear(x) - x = self.nonlinear(x) - return x - -# @tables.register("model_classes", "CAMPPlus") -class CAMPPlus(torch.nn.Module): - def __init__( - self, - feat_dim=80, - embedding_size=192, - growth_rate=32, - bn_size=4, - init_channels=128, - config_str="batchnorm-relu", - memory_efficient=True, - output_level="segment", - **kwargs, - ): - super().__init__() - - self.head = FCM(feat_dim=feat_dim) - channels = self.head.out_channels - self.output_level = output_level - - self.xvector = torch.nn.Sequential( - OrderedDict( - [ - ( - "tdnn", - TDNNLayer( - channels, - init_channels, - 5, - stride=2, - dilation=1, - padding=-1, - config_str=config_str, - ), - ), - ] - ) - ) - channels = init_channels - for i, (num_layers, kernel_size, dilation) in enumerate( - zip((12, 24, 16), (3, 3, 3), (1, 2, 2)) - ): - block = CAMDenseTDNNBlock( - num_layers=num_layers, - in_channels=channels, - out_channels=growth_rate, - bn_channels=bn_size * growth_rate, - kernel_size=kernel_size, - dilation=dilation, - config_str=config_str, - memory_efficient=memory_efficient, - ) - self.xvector.add_module("block%d" % (i + 1), block) - channels = channels + num_layers * growth_rate - self.xvector.add_module( - "transit%d" % (i + 1), - TransitLayer(channels, channels // 2, bias=False, config_str=config_str), - ) - channels //= 2 - - self.xvector.add_module("out_nonlinear", get_nonlinear(config_str, channels)) - - if self.output_level == "segment": - self.xvector.add_module("stats", StatsPool()) - self.xvector.add_module( - "dense", DenseLayer(channels * 2, embedding_size, config_str="batchnorm_") - ) - else: - assert ( - self.output_level == "frame" - ), "`output_level` should be set to 'segment' or 'frame'. " - - for m in self.modules(): - if isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)): - torch.nn.init.kaiming_normal_(m.weight.data) - if m.bias is not None: - torch.nn.init.zeros_(m.bias) - - def forward(self, x): - x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T) - x = self.head(x) - x = self.xvector(x) - if self.output_level == "frame": - x = x.transpose(1, 2) - return x - - def inference(self, audio_list): - speech, speech_lengths, speech_times = extract_feature(audio_list) - results = self.forward(speech.to(torch.float32)) - return results diff --git a/HF_Deploy/src/chatterbox/models/s3tokenizer/__init__.py b/HF_Deploy/src/chatterbox/models/s3tokenizer/__init__.py deleted file mode 100644 index cb2973ab128fde44060d3f2d37e3c1bdc7a25d96..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/s3tokenizer/__init__.py +++ /dev/null @@ -1,30 +0,0 @@ -from .s3tokenizer import ( - S3_SR, - S3_HOP, - S3_TOKEN_HOP, - S3_TOKEN_RATE, - SPEECH_VOCAB_SIZE, - S3Tokenizer, -) - - -SOS = SPEECH_VOCAB_SIZE -EOS = SPEECH_VOCAB_SIZE + 1 - - - -def drop_invalid_tokens(x): - """Drop SoS and EoS""" - assert len(x.shape) == 1 or (len(x.shape) == 2 and x.shape[0] == 1), "only batch size of one allowed for now" - if SOS in x: - s = (x == SOS).nonzero(as_tuple=True)[0].squeeze(0) + 1 - else: - s = 0 - - if EOS in x: - e = (x == EOS).nonzero(as_tuple=True)[0].squeeze(0) - else: - e = None - - x = x[s: e] - return x diff --git a/HF_Deploy/src/chatterbox/models/s3tokenizer/s3tokenizer.py b/HF_Deploy/src/chatterbox/models/s3tokenizer/s3tokenizer.py deleted file mode 100644 index 8648608ae4d8f28bfeec090b5fdb426b6b0ad336..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/s3tokenizer/s3tokenizer.py +++ /dev/null @@ -1,168 +0,0 @@ -from typing import List, Tuple - -import numpy as np -import librosa -import torch -import torch.nn.functional as F -from s3tokenizer.utils import padding -from s3tokenizer.model_v2 import ( - S3TokenizerV2, - ModelConfig, -) - - -# Sampling rate of the inputs to S3TokenizerV2 -S3_SR = 16_000 -S3_HOP = 160 # 100 frames/sec -S3_TOKEN_HOP = 640 # 25 tokens/sec -S3_TOKEN_RATE = 25 -SPEECH_VOCAB_SIZE = 6561 - - -class S3Tokenizer(S3TokenizerV2): - """ - s3tokenizer.S3TokenizerV2 with the following changes: - - a more integrated `forward` - - compute `log_mel_spectrogram` using `_mel_filters` and `window` in `register_buffers` - """ - - ignore_state_dict_missing = ("_mel_filters", "window") - - def __init__( - self, - name: str="speech_tokenizer_v2_25hz", - config: ModelConfig = ModelConfig() - ): - super().__init__(name) - - self.n_fft = 400 - _mel_filters = librosa.filters.mel( - sr=S3_SR, - n_fft=self.n_fft, - n_mels=config.n_mels - ) - self.register_buffer( - "_mel_filters", - torch.FloatTensor(_mel_filters), - ) - - self.register_buffer( - "window", - torch.hann_window(self.n_fft), - ) - - def pad(self, wavs, sr) -> List[torch.Tensor]: - """ - Given a list of wavs with the same `sample_rate`, pad them so that the length is multiple of 40ms (S3 runs at 25 token/sec). - """ - processed_wavs = [] - for wav in wavs: - if isinstance(wav, np.ndarray): - wav = torch.from_numpy(wav) - if wav.dim() == 1: - wav = wav.unsqueeze(0) - - n_tokens = (wav.shape[1] / sr) * S3_TOKEN_RATE - n_tokens = np.ceil(n_tokens) - intended_wav_len = n_tokens * (sr / S3_TOKEN_RATE) - intended_wav_len = int(intended_wav_len) - wav = torch.nn.functional.pad( - wav, - (0, intended_wav_len - wav.shape[-1]), - mode="constant", - value=0 - ) - processed_wavs.append(wav) - return processed_wavs - - def _prepare_audio(self, wavs): - """Prepare a list of audios for s3tokenizer processing.""" - processed_wavs = [] - for wav in wavs: - if isinstance(wav, np.ndarray): - wav = torch.from_numpy(wav) - if wav.dim() == 1: - wav = wav.unsqueeze(0) - - processed_wavs.append(wav) - return processed_wavs - - @torch.no_grad() - def forward( - self, - wavs: torch.Tensor, - accelerator: 'Accelerator'=None, - max_len: int=None, - ) -> Tuple[torch.Tensor, torch.LongTensor]: - """ - NOTE: mel-spec has a hop size of 160 points (100 frame/sec). - FIXME: this class inherits `nn.Module` but doesn't accept `torch.Tensor` and handles a list of wavs one by one, which is unexpected. - - Args - ---- - - `wavs`: 16 kHz speech audio - - `max_len` max length to truncate the output sequence to (25 token/sec). - NOTE: please pad the waveform if longer sequence is needed. - """ - processed_wavs = self._prepare_audio(wavs) - mels, mel_lens = [], [] - for wav in processed_wavs: - wav = wav.to(self.device) - mel = self.log_mel_spectrogram(wav) # [B=1, F, T] - if max_len is not None: - mel = mel[..., :max_len * 4] # num_mel_frames = 4 * num_tokens - mels.append(mel.squeeze(0)) - - mels, mel_lens = padding(mels) - if accelerator is None: - tokenizer = self - else: - tokenizer = accelerator.unwrap_model(self) - - speech_tokens, speech_token_lens = tokenizer.quantize(mels, mel_lens.to(self.device)) - return ( - speech_tokens.long().detach(), - speech_token_lens.long().detach(), - ) - - def log_mel_spectrogram( - self, - audio: torch.Tensor, - padding: int = 0, - ): - """ - Compute the log-Mel spectrogram of - - Parameters - ---------- - audio: torch.Tensor, shape = (*) - The path to audio or either a NumPy array or Tensor containing the - audio waveform in 16 kHz - - padding: int - Number of zero samples to pad to the right - - Returns - ------- - torch.Tensor, shape = (128, n_frames) - A Tensor that contains the Mel spectrogram - """ - if not torch.is_tensor(audio): - audio = torch.from_numpy(audio) - - audio = audio.to(self.device) - if padding > 0: - audio = F.pad(audio, (0, padding)) - stft = torch.stft( - audio, self.n_fft, S3_HOP, - window=self.window.to(self.device), - return_complex=True - ) - magnitudes = stft[..., :-1].abs()**2 - - mel_spec = self._mel_filters.to(self.device) @ magnitudes - - log_spec = torch.clamp(mel_spec, min=1e-10).log10() - log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) - log_spec = (log_spec + 4.0) / 4.0 - return log_spec diff --git a/HF_Deploy/src/chatterbox/models/t3/__init__.py b/HF_Deploy/src/chatterbox/models/t3/__init__.py deleted file mode 100644 index c15519f6107cd9f4d825e420d2ecbd85e92c8671..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/t3/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .t3 import T3 diff --git a/HF_Deploy/src/chatterbox/models/t3/inference/alignment_stream_analyzer (copy).py b/HF_Deploy/src/chatterbox/models/t3/inference/alignment_stream_analyzer (copy).py deleted file mode 100644 index d3a144f0f7f0cdef4a7a4c049db3b5433744296e..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/t3/inference/alignment_stream_analyzer (copy).py +++ /dev/null @@ -1,154 +0,0 @@ -# Copyright (c) 2025 Resemble AI -# Author: John Meade, Jeremy Hsu -# MIT License -import logging -import torch -from dataclasses import dataclass -from types import MethodType - - -logger = logging.getLogger(__name__) - - -@dataclass -class AlignmentAnalysisResult: - # was this frame detected as being part of a noisy beginning chunk with potential hallucinations? - false_start: bool - # was this frame detected as being part of a long tail with potential hallucinations? - long_tail: bool - # was this frame detected as repeating existing text content? - repetition: bool - # was the alignment position of this frame too far from the previous frame? - discontinuity: bool - # has inference reached the end of the text tokens? eg, this remains false if inference stops early - complete: bool - # approximate position in the text token sequence. Can be used for generating online timestamps. - position: int - - -class AlignmentStreamAnalyzer: - def __init__(self, tfmr, queue, text_tokens_slice, alignment_layer_idx=9, eos_idx=0): - """ - Some transformer TTS models implicitly solve text-speech alignment in one or more of their self-attention - activation maps. This module exploits this to perform online integrity checks which streaming. - A hook is injected into the specified attention layer, and heuristics are used to determine alignment - position, repetition, etc. - - NOTE: currently requires no queues. - """ - # self.queue = queue - self.text_tokens_slice = (i, j) = text_tokens_slice - self.eos_idx = eos_idx - self.alignment = torch.zeros(0, j-i) - # self.alignment_bin = torch.zeros(0, j-i) - self.curr_frame_pos = 0 - self.text_position = 0 - - self.started = False - self.started_at = None - - self.complete = False - self.completed_at = None - - # Using `output_attentions=True` is incompatible with optimized attention kernels, so - # using it for all layers slows things down too much. We can apply it to just one layer - # by intercepting the kwargs and adding a forward hook (credit: jrm) - self.last_aligned_attn = None - self._add_attention_spy(tfmr, alignment_layer_idx) - - def _add_attention_spy(self, tfmr, alignment_layer_idx): - """ - Adds a forward hook to a specific attention layer to collect outputs. - Using `output_attentions=True` is incompatible with optimized attention kernels, so - using it for all layers slows things down too much. - (credit: jrm) - """ - - def attention_forward_hook(module, input, output): - """ - See `LlamaAttention.forward`; the output is a 3-tuple: `attn_output, attn_weights, past_key_value`. - NOTE: - - When `output_attentions=True`, `LlamaSdpaAttention.forward` calls `LlamaAttention.forward`. - - `attn_output` has shape [B, H, T0, T0] for the 0th entry, and [B, H, 1, T0+i] for the rest i-th. - """ - step_attention = output[1].cpu() # (B, 16, N, N) - self.last_aligned_attn = step_attention[0].mean(0) # (N, N) - - target_layer = tfmr.layers[alignment_layer_idx].self_attn - hook_handle = target_layer.register_forward_hook(attention_forward_hook) - - # Backup original forward - original_forward = target_layer.forward - def patched_forward(self, *args, **kwargs): - kwargs['output_attentions'] = True - return original_forward(*args, **kwargs) - - # TODO: how to unpatch it? - target_layer.forward = MethodType(patched_forward, target_layer) - - def step(self, logits): - """ - Emits an AlignmentAnalysisResult into the output queue, and potentially modifies the logits to force an EOS. - """ - # extract approximate alignment matrix chunk (1 frame at a time after the first chunk) - aligned_attn = self.last_aligned_attn # (N, N) - i, j = self.text_tokens_slice - if self.curr_frame_pos == 0: - # first chunk has conditioning info, text tokens, and BOS token - A_chunk = aligned_attn[j:, i:j].clone().cpu() # (T, S) - else: - # subsequent chunks have 1 frame due to KV-caching - A_chunk = aligned_attn[:, i:j].clone().cpu() # (1, S) - - # TODO: monotonic masking; could have issue b/c spaces are often skipped. - A_chunk[:, self.curr_frame_pos + 1:] = 0 - - - self.alignment = torch.cat((self.alignment, A_chunk), dim=0) - - A = self.alignment - T, S = A.shape - - # update position - cur_text_posn = A_chunk[-1].argmax() - discontinuity = not(-4 < cur_text_posn - self.text_position < 7) # NOTE: very lenient! - if not discontinuity: - self.text_position = cur_text_posn - - # Hallucinations at the start of speech show up as activations at the bottom of the attention maps! - # To mitigate this, we just wait until there are no activations far off-diagonal in the last 2 tokens, - # and there are some strong activations in the first few tokens. - false_start = (not self.started) and (A[-2:, -2:].max() > 0.1 or A[:, :4].max() < 0.5) - self.started = not false_start - if self.started and self.started_at is None: - self.started_at = T - - # Is generation likely complete? - self.complete = self.complete or self.text_position >= S - 3 - if self.complete and self.completed_at is None: - self.completed_at = T - - # NOTE: EOS rarely assigned activations, and second-last token is often punctuation, so use last 3 tokens. - # NOTE: due to the false-start behaviour, we need to make sure we skip activations for the first few tokens. - last_text_token_duration = A[15:, -3:].sum() - - # Activations for the final token that last too long are likely hallucinations. - long_tail = self.complete and (A[self.completed_at:, -3:].sum(dim=0).max() >= 10) # 400ms - - # If there are activations in previous tokens after generation has completed, assume this is a repetition error. - repetition = self.complete and (A[self.completed_at:, :-5].max(dim=1).values.sum() > 5) - - # If a bad ending is detected, force emit EOS by modifying logits - # NOTE: this means logits may be inconsistent with latents! - if long_tail or repetition: - logger.warn(f"forcing EOS token, {long_tail=}, {repetition=}") - # (Β±2**15 is safe for all dtypes >= 16bit) - logits = -(2**15) * torch.ones_like(logits) - logits[..., self.eos_idx] = 2**15 - - # Suppress EoS to prevent early termination - if cur_text_posn < S - 3: # FIXME: arbitrary - logits[..., self.eos_idx] = -2**15 - - self.curr_frame_pos += 1 - return logits diff --git a/HF_Deploy/src/chatterbox/models/t3/inference/alignment_stream_analyzer.py b/HF_Deploy/src/chatterbox/models/t3/inference/alignment_stream_analyzer.py deleted file mode 100644 index 12efb95432f47c70122d655483d99724a65eeeae..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/t3/inference/alignment_stream_analyzer.py +++ /dev/null @@ -1,167 +0,0 @@ -# Copyright (c) 2025 Resemble AI -# Author: John Meade, Jeremy Hsu -# MIT License -import logging -import torch -from dataclasses import dataclass -from types import MethodType - - -logger = logging.getLogger(__name__) - - -@dataclass -class AlignmentAnalysisResult: - # was this frame detected as being part of a noisy beginning chunk with potential hallucinations? - false_start: bool - # was this frame detected as being part of a long tail with potential hallucinations? - long_tail: bool - # was this frame detected as repeating existing text content? - repetition: bool - # was the alignment position of this frame too far from the previous frame? - discontinuity: bool - # has inference reached the end of the text tokens? eg, this remains false if inference stops early - complete: bool - # approximate position in the text token sequence. Can be used for generating online timestamps. - position: int - - -class AlignmentStreamAnalyzer: - def __init__(self, tfmr, queue, text_tokens_slice, alignment_layer_idx=9, eos_idx=0): - """ - Some transformer TTS models implicitly solve text-speech alignment in one or more of their self-attention - activation maps. This module exploits this to perform online integrity checks which streaming. - A hook is injected into the specified attention layer, and heuristics are used to determine alignment - position, repetition, etc. - - NOTE: currently requires no queues. - """ - # self.queue = queue - self.text_tokens_slice = (i, j) = text_tokens_slice - self.eos_idx = eos_idx - self.alignment = torch.zeros(0, j-i) - # self.alignment_bin = torch.zeros(0, j-i) - self.curr_frame_pos = 0 - self.text_position = 0 - - self.started = False - self.started_at = None - - self.complete = False - self.completed_at = None - - # Using `output_attentions=True` is incompatible with optimized attention kernels, so - # using it for all layers slows things down too much. We can apply it to just one layer - # by intercepting the kwargs and adding a forward hook (credit: jrm) - self.last_aligned_attn = None - self._add_attention_spy(tfmr, alignment_layer_idx) - - def _add_attention_spy(self, tfmr, alignment_layer_idx): - """ - Adds a forward hook to a specific attention layer to collect outputs. - Using `output_attentions=True` is incompatible with optimized attention kernels, so - using it for all layers slows things down too much. - (credit: jrm) - """ - - def attention_forward_hook(module, input, output): - """ - See `LlamaAttention.forward`; the output is a 3-tuple: `attn_output, attn_weights, past_key_value`. - NOTE: - - When `output_attentions=True`, `LlamaSdpaAttention.forward` calls `LlamaAttention.forward`. - - `attn_output` has shape [B, H, T0, T0] for the 0th entry, and [B, H, 1, T0+i] for the rest i-th. - """ - step_attention = output[1].cpu() # (B, 16, N, N) - self.last_aligned_attn = step_attention[0].mean(0) # (N, N) - - target_layer = tfmr.layers[alignment_layer_idx].self_attn - hook_handle = target_layer.register_forward_hook(attention_forward_hook) - - # Backup original forward - from types import MethodType - - def patch_alignment_layer(tfmr, alignment_layer_idx=12): - # β›” Old (broken) logic – causes recursion: - # target_layer = tfmr.layers[alignment_layer_idx].self_attn - # def patched_forward(self, *args, **kwargs): - # kwargs['output_attentions'] = True - # return original_forward(*args, **kwargs) - # target_layer.forward = MethodType(patched_forward, target_layer) - - # βœ… Corrected logic to avoid recursion: - target_layer = tfmr.layers[alignment_layer_idx].self_attn - original_forward = target_layer.forward # Save the real unpatched forward - - def patched_forward(self, *args, **kwargs): - kwargs['output_attentions'] = True - return original_forward(*args, **kwargs) - - target_layer.forward = MethodType(patched_forward, target_layer) - - - def step(self, logits): - """ - Emits an AlignmentAnalysisResult into the output queue, and potentially modifies the logits to force an EOS. - """ - # extract approximate alignment matrix chunk (1 frame at a time after the first chunk) - aligned_attn = self.last_aligned_attn # (N, N) - i, j = self.text_tokens_slice - if self.curr_frame_pos == 0: - # first chunk has conditioning info, text tokens, and BOS token - A_chunk = aligned_attn[j:, i:j].clone().cpu() # (T, S) - else: - # subsequent chunks have 1 frame due to KV-caching - A_chunk = aligned_attn[:, i:j].clone().cpu() # (1, S) - - # TODO: monotonic masking; could have issue b/c spaces are often skipped. - A_chunk[:, self.curr_frame_pos + 1:] = 0 - - - self.alignment = torch.cat((self.alignment, A_chunk), dim=0) - - A = self.alignment - T, S = A.shape - - # update position - cur_text_posn = A_chunk[-1].argmax() - discontinuity = not(-4 < cur_text_posn - self.text_position < 7) # NOTE: very lenient! - if not discontinuity: - self.text_position = cur_text_posn - - # Hallucinations at the start of speech show up as activations at the bottom of the attention maps! - # To mitigate this, we just wait until there are no activations far off-diagonal in the last 2 tokens, - # and there are some strong activations in the first few tokens. - false_start = (not self.started) and (A[-2:, -2:].max() > 0.1 or A[:, :4].max() < 0.5) - self.started = not false_start - if self.started and self.started_at is None: - self.started_at = T - - # Is generation likely complete? - self.complete = self.complete or self.text_position >= S - 3 - if self.complete and self.completed_at is None: - self.completed_at = T - - # NOTE: EOS rarely assigned activations, and second-last token is often punctuation, so use last 3 tokens. - # NOTE: due to the false-start behaviour, we need to make sure we skip activations for the first few tokens. - last_text_token_duration = A[15:, -3:].sum() - - # Activations for the final token that last too long are likely hallucinations. - long_tail = self.complete and (A[self.completed_at:, -3:].sum(dim=0).max() >= 10) # 400ms - - # If there are activations in previous tokens after generation has completed, assume this is a repetition error. - repetition = self.complete and (A[self.completed_at:, :-5].max(dim=1).values.sum() > 5) - - # If a bad ending is detected, force emit EOS by modifying logits - # NOTE: this means logits may be inconsistent with latents! - if long_tail or repetition: - logger.warn(f"forcing EOS token, {long_tail=}, {repetition=}") - # (Β±2**15 is safe for all dtypes >= 16bit) - logits = -(2**15) * torch.ones_like(logits) - logits[..., self.eos_idx] = 2**15 - - # Suppress EoS to prevent early termination - if cur_text_posn < S - 3: # FIXME: arbitrary - logits[..., self.eos_idx] = -2**15 - - self.curr_frame_pos += 1 - return logits diff --git a/HF_Deploy/src/chatterbox/models/t3/inference/t3_hf_backend.py b/HF_Deploy/src/chatterbox/models/t3/inference/t3_hf_backend.py deleted file mode 100644 index 69a6bf20ecafb87f3beb799838c3a79ec134784e..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/t3/inference/t3_hf_backend.py +++ /dev/null @@ -1,116 +0,0 @@ -from typing import Optional - -import torch -from torch import nn as nn -from transformers import LlamaConfig, LlamaModel, LlamaPreTrainedModel, GenerationMixin -from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions - - -class T3HuggingfaceBackend(LlamaPreTrainedModel, GenerationMixin): - """ - Override some HuggingFace interface methods so we can use the standard `generate` method with our - custom embedding / logit layers. - - NOTE: need to extend "*PreTrainedModel" to avoid re-initializing weights! - """ - - def __init__( - self, - config: LlamaConfig, - llama: LlamaModel, - *, - speech_enc, - speech_head, - latents_queue=None, - logits_queue=None, - alignment_stream_analyzer: 'AlignmentStreamAnalyzer'=None, - ): - super().__init__(config) - self.model = llama - self.speech_enc = speech_enc - self.speech_head = speech_head - self._added_cond = False - self.alignment_stream_analyzer = alignment_stream_analyzer - - @torch.inference_mode() - def prepare_inputs_for_generation( - self, input_ids: torch.Tensor, decoder_cond: torch.Tensor, use_cache: bool, past_key_values=None, - # This argument was introduced in some recent version of transformers (>=4.29.1) - cache_position=None - ): - """ - This is a method used by huggingface's generate() method. - Overridden here to apply our custom speech token embedding layer. - - :param input_ids: (B, S) int64 tensors of input tokens. - :param decoder_cond: (B, T, C) float32 tensor of conditioning (prefixed to ) - """ - - # Make use of the kv cache: only the last input ID is new, we trim away all the ones before - if not use_cache: - past_key_values = None - if past_key_values is not None: - input_ids = input_ids[:, -1:] - - # custom speech token embedding layer - inputs_embeds = self.speech_enc(input_ids) - - # prefix decoder conditioning if applicable - if not self._added_cond: - assert past_key_values is not None # should be first step - if decoder_cond.size(0) != inputs_embeds.size(0): - decoder_cond = decoder_cond.expand(inputs_embeds.size(0), -1, -1) - inputs_embeds = torch.cat([decoder_cond, inputs_embeds], dim=1) - self._added_cond = True - - return { - "inputs_embeds": inputs_embeds, - "past_key_values": past_key_values, - "use_cache": use_cache, - } - - @torch.inference_mode() - def forward( - self, - inputs_embeds: torch.Tensor, - past_key_values: Optional[torch.Tensor]=None, - use_cache=True, - output_attentions=False, - output_hidden_states=True, - return_dict=True, - ): - """ - This is a method used by huggingface's generate() method. - Overridden here to apply our custom layer norm and speech logit projection layers. - - :param inputs_embeds: (B, S, C) float32 tensor of conditioning inputs. If past key values are given, - S should be 1. - """ - is_large_input = inputs_embeds.size(1) != 1 - has_cache = past_key_values is not None and len(past_key_values) > 0 - assert not (is_large_input and has_cache) - assert return_dict - assert output_hidden_states - - tfmr_out = self.model( - inputs_embeds=inputs_embeds, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=True, - ) - hidden_states = tfmr_out.hidden_states[-1] # (B, seq, dim) - - logits = self.speech_head(hidden_states) - # assert inputs_embeds.size(0) == 1 # (disabled for CFG) - - # NOTE: hallucination handler may modify logits to force emit an EOS token - # logits = self.alignment_stream_analyzer.step(logits) - - return CausalLMOutputWithCrossAttentions( - logits=logits, - past_key_values=tfmr_out.past_key_values, - hidden_states=tfmr_out.hidden_states, - attentions=tfmr_out.attentions, - ) diff --git a/HF_Deploy/src/chatterbox/models/t3/llama_configs.py b/HF_Deploy/src/chatterbox/models/t3/llama_configs.py deleted file mode 100644 index 14d068161ddb38de613c92d54e34d1ed72261d40..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/t3/llama_configs.py +++ /dev/null @@ -1,37 +0,0 @@ -LLAMA_520M_CONFIG_DICT = dict( - # Arbitrary small number that won't cause problems when loading. - # These param are unused due to custom input layers. - vocab_size=8, - # default params needed for loading most pretrained 1B weights - max_position_embeddings=131072, - hidden_size=1024, - intermediate_size=4096, - num_hidden_layers=30, - num_attention_heads=16, - attn_implementation="sdpa", - head_dim=64, - tie_word_embeddings=False, - hidden_act="silu", - attention_bias=False, - attention_dropout=0.0, - initializer_range=0.02, - mlp_bias=False, - model_type="llama", - num_key_value_heads=16, - pretraining_tp=1, - rms_norm_eps=1e-05, - rope_scaling=dict( - factor=8.0, - high_freq_factor=4.0, - low_freq_factor=1.0, - original_max_position_embeddings=8192, - rope_type="llama3" - ), - rope_theta=500000.0, - torch_dtype="bfloat16", - use_cache=True, -) - -LLAMA_CONFIGS = { - "Llama_520M": LLAMA_520M_CONFIG_DICT, -} diff --git a/HF_Deploy/src/chatterbox/models/t3/modules/cond_enc.py b/HF_Deploy/src/chatterbox/models/t3/modules/cond_enc.py deleted file mode 100644 index b5f15c685783fbb048f6c0e86fc2ea8fbf1ec3de..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/t3/modules/cond_enc.py +++ /dev/null @@ -1,97 +0,0 @@ -from dataclasses import dataclass -from typing import Optional - -import torch -from torch import nn, Tensor - -from .perceiver import Perceiver -from .t3_config import T3Config - - -@dataclass -class T3Cond: - """ - Dataclass container for most / all conditioning info. - TODO: serialization methods aren't used, keeping them around for convenience - """ - - speaker_emb: Tensor - clap_emb: Optional[Tensor] = None - cond_prompt_speech_tokens: Optional[Tensor] = None - cond_prompt_speech_emb: Optional[Tensor] = None - emotion_adv: Optional[Tensor] = 0.5 - - def to(self, *, device=None, dtype=None): - "Cast to a device and dtype. Dtype casting is ignored for long/int tensors." - for k, v in self.__dict__.items(): - if torch.is_tensor(v): - is_fp = type(v.view(-1)[0].item()) is not int - setattr(self, k, v.to(device=device, dtype=dtype if is_fp else None)) - return self - - def save(self, fpath): - torch.save(self.__dict__, fpath) - - @staticmethod - def load(fpath, map_location="cpu"): - kwargs = torch.load(fpath, map_location=map_location, weights_only=True) - return T3Cond(**kwargs) - - -class T3CondEnc(nn.Module): - """ - Handle all non-text conditioning, like speaker embeddings / prompts, CLAP, emotion, etc. - """ - - def __init__(self, hp: T3Config): - super().__init__() - self.hp = hp - if hp.encoder_type == "voice_encoder": - self.spkr_enc = nn.Linear(hp.speaker_embed_size, hp.n_channels) - else: - raise NotImplementedError(str(hp.encoder_type)) - - # emotion adv - self.emotion_adv_fc = None - if hp.emotion_adv: - self.emotion_adv_fc = nn.Linear(1, hp.n_channels, bias=False) - - # perceiver resampler - self.perceiver = None - if hp.use_perceiver_resampler: - self.perceiver = Perceiver() - - def forward(self, cond: T3Cond): - # Validate - assert (cond.cond_prompt_speech_tokens is None) == (cond.cond_prompt_speech_emb is None), \ - "no embeddings for cond_prompt_speech_tokens" - - # Speaker embedding projection - cond_spkr = self.spkr_enc(cond.speaker_emb.view(-1, self.hp.speaker_embed_size))[:, None] # (B, 1, dim) - empty = torch.zeros_like(cond_spkr[:, :0]) # (B, 0, dim) - - # TODO CLAP - assert cond.clap_emb is None, "clap_embed not implemented" - cond_clap = empty # (B, 0, dim) - - # Cond prompt - cond_prompt_speech_emb = cond.cond_prompt_speech_emb - if cond_prompt_speech_emb is None: - cond_prompt_speech_emb = empty # (B, 0, dim) - elif self.hp.use_perceiver_resampler: - cond_prompt_speech_emb = self.perceiver(cond_prompt_speech_emb) - - # Emotion Adv: must provide a value if this model uses emotion conditioning - cond_emotion_adv = empty # (B, 0, dim) - if self.hp.emotion_adv: - assert cond.emotion_adv is not None - cond_emotion_adv = self.emotion_adv_fc(cond.emotion_adv.view(-1, 1, 1)) - - # Concat and return - cond_embeds = torch.cat(( - cond_spkr, - cond_clap, - cond_prompt_speech_emb, - cond_emotion_adv, - ), dim=1) - return cond_embeds diff --git a/HF_Deploy/src/chatterbox/models/t3/modules/learned_pos_emb.py b/HF_Deploy/src/chatterbox/models/t3/modules/learned_pos_emb.py deleted file mode 100644 index 9b197f218192688f743a904676d66ff741eb33e3..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/t3/modules/learned_pos_emb.py +++ /dev/null @@ -1,32 +0,0 @@ -from typing import Union - -import torch -from torch import nn, Tensor - - -class LearnedPositionEmbeddings(nn.Module): - def __init__(self, seq_len, model_dim, init=.02): - super().__init__() - self.emb = nn.Embedding(seq_len, model_dim) - # Initializing this way is standard for GPT-2 - self.emb.weight.data.normal_(mean=0.0, std=init) - - def forward(self, x): - """ - Returns positional embeddings for index 0 up to the length of x - """ - sl = x.shape[1] - return self.emb(torch.arange(0, sl, device=x.device)) - - def get_fixed_embedding(self, idx: 'Union[int, Tensor]'): - """ - Args: - idx: scalar int or an integer tensor of shape (T,) or (B, T) - Returns: - positional embeddings for given indices, shape (B, T, dim), ie (1, 1, dim) for int input - """ - device = self.emb.weight.device - idx = idx.to(device) if torch.is_tensor(idx) else torch.tensor(idx, device=device) - idx = torch.atleast_2d(idx) - assert idx.ndim == 2 - return self.emb(idx) # (B, T, dim) diff --git a/HF_Deploy/src/chatterbox/models/t3/modules/perceiver.py b/HF_Deploy/src/chatterbox/models/t3/modules/perceiver.py deleted file mode 100644 index be9c5b863ce43ab43c0124a8ae0fa125b0da9673..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/t3/modules/perceiver.py +++ /dev/null @@ -1,212 +0,0 @@ -# Copyright (c) 2025 Resemble AI -# Author: Manmay Nakhashi -# MIT License -import math - -import torch -from torch import nn -import torch.nn.functional as F -from einops import rearrange - - -class RelativePositionBias(nn.Module): - def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8): - super().__init__() - self.scale = scale - self.causal = causal - self.num_buckets = num_buckets - self.max_distance = max_distance - self.relative_attention_bias = nn.Embedding(num_buckets, heads) - - @staticmethod - def _relative_position_bucket(relative_position, causal=True, num_buckets=32, max_distance=128): - ret = 0 - n = -relative_position - if not causal: - num_buckets //= 2 - ret += (n < 0).long() * num_buckets - n = torch.abs(n) - else: - n = torch.max(n, torch.zeros_like(n)) - - max_exact = num_buckets // 2 - is_small = n < max_exact - - val_if_large = max_exact + ( - torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) - ).long() - val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) - - ret += torch.where(is_small, n, val_if_large) - return ret - - def forward(self, qk_dots): - i, j, device = *qk_dots.shape[-2:], qk_dots.device - q_pos = torch.arange(i, dtype=torch.long, device=device) - k_pos = torch.arange(j, dtype=torch.long, device=device) - rel_pos = k_pos[None, :] - q_pos[:, None] - rp_bucket = self._relative_position_bucket(rel_pos, causal=self.causal, num_buckets=self.num_buckets, - max_distance=self.max_distance) - values = self.relative_attention_bias(rp_bucket) - bias = rearrange(values, 'i j h -> () h i j') - return qk_dots + (bias * self.scale) - - -class AttentionQKV(nn.Module): - def __init__(self, n_heads, head_dim, dropout_rate=0.1, scale=None, flash=False): - super().__init__() - self.n_heads = n_heads - self.head_dim = head_dim - self.scale = scale if scale is not None else head_dim ** -0.5 - self.flash = flash - self.dropout_rate = dropout_rate - self.dropout = nn.Dropout(dropout_rate) - self.flash_config = self.setup_flash_config() if flash else None - - def setup_flash_config(self): - # Setup flash attention configuration - flash_config = { - 'enable_flash': True, - 'enable_math': True, - 'enable_mem_efficient': True - } - return flash_config - - def forward(self, q, k, v, mask=None): - q, k, v = [self.split_heads(tensor) for tensor in [q, k, v]] - if self.flash: - out = self.flash_attention(q, k, v, mask=mask) - else: - out = self.scaled_dot_product_attention(q, k, v, mask=mask) - - return self.combine_heads(out) - - def scaled_dot_product_attention(self, q, k, v, mask=None): - sim = torch.einsum("bhlt,bhls->bhts", q, k) * self.scale - if mask is not None: - sim = sim.masked_fill(mask == 0, float('-inf')) - attn = torch.softmax(sim, dim=-1) - attn = self.dropout(attn) - return torch.einsum("bhts,bhls->bhlt", attn, v) - - def flash_attention(self, q, k, v, mask=None): - config = self.flash_config if self.flash_config else {} - with torch.backends.cuda.sdp_kernel(**config): - out = F.scaled_dot_product_attention( - q, k, v, - attn_mask=mask, - dropout_p=self.dropout_rate if self.training else 0. - ) - return out - - def split_heads(self, x): - bs, length, _ = x.shape - x = x.view(bs, length, self.n_heads, self.head_dim) - return x.permute(0, 2, 1, 3) - - def combine_heads(self, x): - bs, _, length, _ = x.shape - x = x.permute(0, 2, 1, 3).contiguous() - return x.view(bs, length, -1) - - -class AttentionBlock2(nn.Module): - """ - An attention block that allows spatial positions to attend to each other, - using AttentionQKV and separate linear transformations for Q, K, and V. - """ - - def __init__( - self, - channels, - num_heads=1, - num_head_channels=-1, - relative_pos_embeddings=False, - flash_attention=True, - dropout_rate=0.2, - scale=None - ): - super().__init__() - self.channels = channels - - if num_head_channels == -1: - self.num_heads = num_heads - else: - assert ( - channels % num_head_channels == 0 - ), f"channels {channels} is not divisible by num_head_channels {num_head_channels}" - self.num_heads = channels // num_head_channels - - self.norm = nn.LayerNorm(channels) - - # Separate linear layers for Q, K, and V - self.to_q = nn.Linear(channels, channels) - self.to_k = nn.Linear(channels, channels) - self.to_v = nn.Linear(channels, channels) - - self.attention = AttentionQKV(self.num_heads, channels // self.num_heads, dropout_rate=dropout_rate, flash=flash_attention, scale=scale) - - self.proj_out = nn.Linear(channels, channels) - - if relative_pos_embeddings: - self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64) - else: - self.relative_pos_embeddings = None - - def forward(self, x1, x2, mask=None): - b1, c1, *spatial1 = x1.shape - b2, c2, *spatial2 = x2.shape - - x1_norm = self.norm(x1) - x2_norm = self.norm(x2) - - q = self.to_q(x1_norm) - k = self.to_k(x2_norm) - v = self.to_v(x2_norm) - - h = self.attention(q, k, v, mask=mask) - h = self.proj_out(h) - - return (x1 + h).reshape(b1, c1, *spatial1) - - -class Perceiver(nn.Module): - """Inspired by https://arxiv.org/abs/2103.03206""" - def __init__(self, pre_attention_query_token=32, pre_attention_query_size=1024, embedding_dim=1024, num_attn_heads=4): - """ - Initialize the perceiver module. - - :param pre_attention_query_token: Number of query tokens for pre-attention - :param pre_attention_query_size: Size of each query token - :param embedding_dim: Dimension of the embedding space - :param num_attn_heads: Number of attention heads - """ - super().__init__() - - # Initialize the pre-attention query parameter - self.pre_attention_query = torch.nn.Parameter( - torch.empty(1, pre_attention_query_token, pre_attention_query_size) - ) - - # Calculate the variance for uniform initialization - query_variance = math.sqrt(3.0) * math.sqrt(2.0 / (pre_attention_query_token + pre_attention_query_token)) - - # Initialize the pre-attention query with uniform distribution - self.pre_attention_query.data.uniform_(-query_variance, query_variance) - - # Initialize the attention block - self.attn = AttentionBlock2(embedding_dim, num_attn_heads) - - def forward(self, h): - """ - Forward pass of the perceiver module. - :param h: Input tensor - :return: Output after applying attention mechanisms - """ - # Expand the pre-attention query to match the batch size of the input - query_ = self.pre_attention_query.expand(h.shape[0], -1, -1) - # Apply the first attention mechanism (cross-attention) - pre_att = self.attn(query_, h) - # Apply the second attention mechanism (self-attention) - attn = self.attn(pre_att, pre_att) - return attn diff --git a/HF_Deploy/src/chatterbox/models/t3/modules/t3_config.py b/HF_Deploy/src/chatterbox/models/t3/modules/t3_config.py deleted file mode 100644 index 2769d835692578c7f8fb0f9bcf6b42daa4b0cd03..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/t3/modules/t3_config.py +++ /dev/null @@ -1,27 +0,0 @@ -from ..llama_configs import LLAMA_CONFIGS - - -class T3Config: - start_text_token = 255 - stop_text_token = 0 - text_tokens_dict_size = 704 - max_text_tokens = 2048 - - start_speech_token = 6561 - stop_speech_token = 6562 - speech_tokens_dict_size = 8194 - max_speech_tokens = 4096 - - llama_config_name = "Llama_520M" - input_pos_emb = "learned" - speech_cond_prompt_len = 150 - - # For T3CondEnc - encoder_type = "voice_encoder" - speaker_embed_size = 256 - use_perceiver_resampler = True - emotion_adv = True - - @property - def n_channels(self): - return LLAMA_CONFIGS[self.llama_config_name]["hidden_size"] diff --git a/HF_Deploy/src/chatterbox/models/t3/t3.py b/HF_Deploy/src/chatterbox/models/t3/t3.py deleted file mode 100644 index 733febfcf0901a676ce7790fd7600852a03ead15..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/t3/t3.py +++ /dev/null @@ -1,381 +0,0 @@ -# Copyright (c) 2025 Resemble AI -# MIT License -import logging -from typing import Union, Optional, List - -from tqdm import tqdm -import torch -import torch.nn.functional as F -from torch import nn, Tensor -from transformers import LlamaModel, LlamaConfig -from transformers.generation.logits_process import TopPLogitsWarper, RepetitionPenaltyLogitsProcessor - -from .modules.learned_pos_emb import LearnedPositionEmbeddings - -from .modules.cond_enc import T3CondEnc, T3Cond -from .modules.t3_config import T3Config -from .llama_configs import LLAMA_CONFIGS -from .inference.t3_hf_backend import T3HuggingfaceBackend -from .inference.alignment_stream_analyzer import AlignmentStreamAnalyzer - - -logger = logging.getLogger(__name__) - - -class AttrDict(dict): - def __init__(self, *args, **kwargs): - super(AttrDict, self).__init__(*args, **kwargs) - self.__dict__ = self - - -def _ensure_BOT_EOT(text_tokens: Tensor, hp): - B = text_tokens.size(0) - assert (text_tokens == hp.start_text_token).int().sum() >= B, "missing start_text_token" - assert (text_tokens == hp.stop_text_token).int().sum() >= B, "missing stop_text_token" - - -class T3(nn.Module): - """ - Token-To-Token (T3) TTS model using huggingface transformer models as backbones, - * tokenization, including start / stop tokens are always added externally to this class - * conditioning data like CLAP, emotion, etc are all in a separate file for more modularity - * careful! this class assumes relative positional encoding -- with absolute PE, we would at - least want to reset the position to 0 when speech tokens begin, and optionally use a - different PE embedding space for speech. - """ - - def __init__(self, hp=T3Config()): - super().__init__() - self.hp = hp - self.cfg = LlamaConfig(**LLAMA_CONFIGS[hp.llama_config_name]) - self.tfmr = LlamaModel(self.cfg) - self.dim = self.cfg.hidden_size - self.deepspeed_patch_applied = False - - # conditioning / embedding - self.cond_enc = T3CondEnc(hp) - self.text_emb = nn.Embedding(hp.text_tokens_dict_size, self.dim) - self.speech_emb = nn.Embedding(hp.speech_tokens_dict_size, self.dim) - - # custom position embedding - if hp.input_pos_emb == "learned": - max_text_seq_len = hp.max_text_tokens + 2 - self.text_pos_emb = LearnedPositionEmbeddings(max_text_seq_len, self.dim) - - max_mel_seq_len = hp.max_speech_tokens + 2 + 2 - self.speech_pos_emb = LearnedPositionEmbeddings(max_mel_seq_len, self.dim) - - # logit projection - self.text_head = nn.Linear(self.cfg.hidden_size, hp.text_tokens_dict_size, bias=False) - self.speech_head = nn.Linear(self.cfg.hidden_size, hp.speech_tokens_dict_size, bias=False) - self.compiled = False - - @property - def device(self): - return self.speech_head.weight.device - - def prepare_conditioning(self, t3_cond: T3Cond): - """ - Token cond data needs to be embedded, so that needs to be here instead of in `T3CondEnc`. - """ - if t3_cond.cond_prompt_speech_tokens is not None and t3_cond.cond_prompt_speech_emb is None: - t3_cond.cond_prompt_speech_emb = self.speech_emb(t3_cond.cond_prompt_speech_tokens) + \ - self.speech_pos_emb(t3_cond.cond_prompt_speech_tokens) - return self.cond_enc(t3_cond) # (B, len_cond, dim) - - def prepare_input_embeds( - self, - *, - t3_cond: T3Cond, - text_tokens: torch.LongTensor, - speech_tokens: torch.LongTensor, - cfg_weight: float = 0.0, - ): - # prepare input embeddings (skip backbone tranformer embeddings) - cond_emb = self.prepare_conditioning(t3_cond) # (B, len_cond, dim) - text_emb = self.text_emb(text_tokens) # (B, len_text, dim) - if cfg_weight > 0.0: - text_emb[1].zero_() # CFG uncond - - speech_emb = self.speech_emb(speech_tokens) # (B, len_speech, dim) - if self.hp.input_pos_emb == "learned": - text_emb = text_emb + self.text_pos_emb(text_tokens) - speech_emb = speech_emb + self.speech_pos_emb(speech_tokens) - len_cond = cond_emb.size(1) - - if cond_emb.size(0) != text_emb.size(0): - cond_emb = cond_emb.expand(text_emb.size(0), -1, -1) - - # concat - embeds = torch.stack([ - torch.cat((ce, te, se)) - for ce, te, se in zip(cond_emb, text_emb, speech_emb) - ]) # (B, length, dim) - return embeds, len_cond - - def forward( - self, - *, - t3_cond: T3Cond, - text_tokens: torch.LongTensor, - text_token_lens: torch.LongTensor, - speech_tokens: torch.LongTensor, - speech_token_lens: torch.LongTensor, - training=False, - ): - _ensure_BOT_EOT(text_tokens, self.hp) - - # prepare custom input embeds - embeds, len_cond = self.prepare_input_embeds( - t3_cond=t3_cond, - text_tokens=text_tokens, - speech_tokens=speech_tokens, - ) - - # backbone tranformer forward - tfmr_out = self.tfmr.forward( - input_ids=None, - # position_ids=position_ids, # TODO? ROPE should be fine? - inputs_embeds=embeds, - output_hidden_states=True, - return_dict=True, - use_cache=(not training), - ) - hidden_states = tfmr_out.hidden_states[-1] # final tfmr layer output, (B, seq, dim) - - # post-processing: splice out text and speech parts of hidden states - len_text = text_tokens.size(1) - len_speech = speech_tokens.size(1) - B, _, dim = hidden_states.shape - device, dtype = hidden_states.device, hidden_states.dtype - text_latents = torch.zeros(B, len_text, dim, dtype=dtype, device=device) - speech_latents = torch.zeros(B, len_speech, dim, dtype=dtype, device=device) - ttl, stl = text_token_lens, speech_token_lens - for i in range(B): - text_end = len_cond + ttl[i].item() - speech_start = len_cond + text_tokens.size(1) - speech_end = speech_start + stl[i].item() - text_latents[i, :ttl[i]] = hidden_states[i, len_cond:text_end] - speech_latents[i, :stl[i]] = hidden_states[i, speech_start:speech_end] - - # logit projection - text_logits = self.text_head(text_latents) - speech_logits = self.speech_head(speech_latents) - - return AttrDict( - text_logits=text_logits, - text_latents=text_latents, - speech_logits=speech_logits, - speech_latents=speech_latents, - hidden_states=hidden_states, - ) - - def loss( - self, - *, - t3_cond: T3Cond, - text_tokens: torch.LongTensor, - text_token_lens: torch.LongTensor, - speech_tokens: torch.LongTensor, - speech_token_lens: torch.LongTensor, - ): - "training method" - len_text = text_tokens.size(1) - len_speech = speech_tokens.size(1) - assert len_text == text_token_lens.max() - assert len_speech == speech_token_lens.max() - - out = self.forward( - t3_cond=t3_cond, - text_tokens=text_tokens, - text_token_lens=text_token_lens, - speech_tokens=speech_tokens, - speech_token_lens=speech_token_lens, - training=True, - ) # (B, seq, vocab_size) - - # Calc CCE losses - IGNORE_ID = -100 - device = out.text_logits.device - mask_text = torch.arange(len_text, device=device)[None] >= text_token_lens[:, None] # (B, len_text) - mask_speech = torch.arange(len_speech, device=device)[None] >= speech_token_lens[:, None] # (B, len_speech) - masked_text = text_tokens.masked_fill(mask_text, IGNORE_ID) - masked_speech = speech_tokens.masked_fill(mask_speech, IGNORE_ID) - loss_text = F.cross_entropy(out.text_logits, masked_text, ignore_index=IGNORE_ID) - loss_speech = F.cross_entropy(out.speech_logits, masked_speech, ignore_index=IGNORE_ID) - - return loss_text, loss_speech - - @torch.inference_mode() - def inference( - self, - *, - t3_cond: T3Cond, - text_tokens: Tensor, - initial_speech_tokens: Optional[Tensor]=None, - - # misc conditioning - prepend_prompt_speech_tokens: Optional[Tensor]=None, - - # HF generate args - num_return_sequences=1, - max_new_tokens=None, - stop_on_eos=True, - do_sample=True, - temperature=0.8, - top_p=0.8, - length_penalty=1.0, - repetition_penalty=2.0, - cfg_weight=0, - ): - """ - Args: - text_tokens: a 1D (unbatched) or 2D (batched) tensor. - """ - # Validate / sanitize inputs - assert prepend_prompt_speech_tokens is None, "not implemented" - _ensure_BOT_EOT(text_tokens, self.hp) - text_tokens = torch.atleast_2d(text_tokens).to(dtype=torch.long, device=self.device) - - # Default initial speech to a single start-of-speech token - if initial_speech_tokens is None: - initial_speech_tokens = self.hp.start_speech_token * torch.ones_like(text_tokens[:, :1]) - - # Prepare custom input embeds - embeds, len_cond = self.prepare_input_embeds( - t3_cond=t3_cond, - text_tokens=text_tokens, - speech_tokens=initial_speech_tokens, - cfg_weight=cfg_weight, - ) - - # In order to use the standard HF generate method, we need to extend some methods to inject our custom logic - # Note the llama-specific logic. Other tfmr types can be added later. - - self.compiled = False - - # TODO? synchronize the expensive compile function - # with self.compile_lock: - if not self.compiled: - alignment_stream_analyzer = AlignmentStreamAnalyzer( - self.tfmr, - None, - text_tokens_slice=(len_cond, len_cond + text_tokens.size(-1)), - alignment_layer_idx=9, # TODO: hparam or something? - eos_idx=self.hp.stop_speech_token, - ) - patched_model = T3HuggingfaceBackend( - config=self.cfg, - llama=self.tfmr, - speech_enc=self.speech_emb, - speech_head=self.speech_head, - alignment_stream_analyzer=alignment_stream_analyzer, - ) - self.patched_model = patched_model - self.compiled = True - - # # Run normal generate method, which calls our custom extended methods - # return self.patched_model.generate( - # inputs=initial_speech_tokens, - # decoder_cond=embeds, - # bos_token_id=self.hp.start_speech_token, - # eos_token_id=(self.hp.stop_speech_token if stop_on_eos else -1), - # pad_token_id=self.hp.stop_speech_token, - # max_new_tokens=max_new_tokens or self.hp.max_speech_tokens, - # num_return_sequences=num_return_sequences, - # temperature=temperature, - # top_p=top_p, - # length_penalty=length_penalty, - # repetition_penalty=repetition_penalty, - # do_sample=do_sample, - # # cache_implementation=None if not self.compiled else "static", - # ) - - device = embeds.device - - bos_token = torch.tensor([[self.hp.start_speech_token]], dtype=torch.long, device=device) - bos_embed = self.speech_emb(bos_token) # shape: (B, 1, embed_dim) - bos_embed = bos_embed + self.speech_pos_emb.get_fixed_embedding(0) - - # batch_size=2 for CFG - bos_embed = torch.cat([bos_embed, bos_embed]) - - # Combine condition and BOS token for the initial input if cfg_weight > 0 - if cfg_weight > 0: - inputs_embeds = torch.cat([embeds, bos_embed], dim=1) - else: - inputs_embeds = embeds - - # Track generated token ids; start with the BOS token. - generated_ids = bos_token.clone() - predicted = [] # To store the predicted tokens - - # Instantiate the logits processors. - top_p_warper = TopPLogitsWarper(top_p=top_p) - repetition_penalty_processor = RepetitionPenaltyLogitsProcessor(penalty=repetition_penalty) - - # ---- Initial Forward Pass (no kv_cache yet) ---- - output = self.patched_model( - inputs_embeds=inputs_embeds, - past_key_values=None, - use_cache=True, - output_attentions=True, - output_hidden_states=True, - return_dict=True, - ) - # Initialize kv_cache with the full context. - past = output.past_key_values - - # ---- Generation Loop using kv_cache ---- - for i in tqdm(range(max_new_tokens), desc="Sampling", dynamic_ncols=True): - logits = output.logits[:, -1, :] - - # CFG - if cfg_weight > 0.0: - logits_cond = logits[0:1] - logits_uncond = logits[1:2] - logits = logits_cond + cfg_weight * (logits_cond - logits_uncond) - - logits = logits.squeeze(1) - - # Apply temperature scaling. - if temperature != 1.0: - logits = logits / temperature - - # Apply repetition penalty and top‑p filtering. - logits = repetition_penalty_processor(generated_ids, logits) - logits = top_p_warper(None, logits) - - # Convert logits to probabilities and sample the next token. - probs = torch.softmax(logits, dim=-1) - next_token = torch.multinomial(probs, num_samples=1) # shape: (B, 1) - - predicted.append(next_token) - generated_ids = torch.cat([generated_ids, next_token], dim=1) - - # Check for EOS token. - if next_token.view(-1) == self.hp.stop_speech_token: - break - - # Get embedding for the new token. - next_token_embed = self.speech_emb(next_token) - next_token_embed = next_token_embed + self.speech_pos_emb.get_fixed_embedding(i + 1) - - # For CFG - if cfg_weight > 0.0: - next_token_embed = torch.cat([next_token_embed, next_token_embed]) - - # Forward pass with only the new token and the cached past. - output = self.patched_model( - inputs_embeds=next_token_embed, - past_key_values=past, - output_attentions=True, - output_hidden_states=True, - return_dict=True, - ) - # Update the kv_cache. - past = output.past_key_values - - # Concatenate all predicted tokens along the sequence dimension. - predicted_tokens = torch.cat(predicted, dim=1) # shape: (B, num_tokens) - return predicted_tokens diff --git a/HF_Deploy/src/chatterbox/models/tokenizers/__init__.py b/HF_Deploy/src/chatterbox/models/tokenizers/__init__.py deleted file mode 100644 index 97457e6fd720a10b2c64d2cdbabce9ca5fbf9aad..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/tokenizers/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .tokenizer import EnTokenizer diff --git a/HF_Deploy/src/chatterbox/models/tokenizers/tokenizer.py b/HF_Deploy/src/chatterbox/models/tokenizers/tokenizer.py deleted file mode 100644 index f3536bc24db7d37cca9faff11c064c2c5d7c1c64..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/tokenizers/tokenizer.py +++ /dev/null @@ -1,50 +0,0 @@ -import logging - -import torch -from tokenizers import Tokenizer - - -# Special tokens -SOT = "[START]" -EOT = "[STOP]" -UNK = "[UNK]" -SPACE = "[SPACE]" -SPECIAL_TOKENS = [SOT, EOT, UNK, SPACE, "[PAD]", "[SEP]", "[CLS]", "[MASK]"] - -logger = logging.getLogger(__name__) - -class EnTokenizer: - def __init__(self, vocab_file_path): - self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path) - self.check_vocabset_sot_eot() - - def check_vocabset_sot_eot(self): - voc = self.tokenizer.get_vocab() - assert SOT in voc - assert EOT in voc - - def text_to_tokens(self, text: str): - text_tokens = self.encode(text) - text_tokens = torch.IntTensor(text_tokens).unsqueeze(0) - return text_tokens - - def encode( self, txt: str, verbose=False): - """ - clean_text > (append `lang_id`) > replace SPACE > encode text using Tokenizer - """ - txt = txt.replace(' ', SPACE) - code = self.tokenizer.encode(txt) - ids = code.ids - return ids - - def decode(self, seq): - if isinstance(seq, torch.Tensor): - seq = seq.cpu().numpy() - - txt: str = self.tokenizer.decode(seq, - skip_special_tokens=False) - txt = txt.replace(' ', '') - txt = txt.replace(SPACE, ' ') - txt = txt.replace(EOT, '') - txt = txt.replace(UNK, '') - return txt diff --git a/HF_Deploy/src/chatterbox/models/voice_encoder/__init__.py b/HF_Deploy/src/chatterbox/models/voice_encoder/__init__.py deleted file mode 100644 index 529e1e63e89f179ec06829bfc5f1afc80912433f..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/voice_encoder/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .voice_encoder import VoiceEncoder, VoiceEncConfig diff --git a/HF_Deploy/src/chatterbox/models/voice_encoder/config.py b/HF_Deploy/src/chatterbox/models/voice_encoder/config.py deleted file mode 100644 index 8e9782a20eac8bc41afaf38d80a8af862adac232..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/voice_encoder/config.py +++ /dev/null @@ -1,18 +0,0 @@ -class VoiceEncConfig: - num_mels = 40 - sample_rate = 16000 - speaker_embed_size = 256 - ve_hidden_size = 256 - flatten_lstm_params = False - n_fft = 400 - hop_size = 160 - win_size = 400 - fmax = 8000 - fmin = 0 - preemphasis = 0. - mel_power = 2.0 - mel_type = "amp" - normalized_mels = False - ve_partial_frames = 160 - ve_final_relu = True - stft_magnitude_min = 1e-4 diff --git a/HF_Deploy/src/chatterbox/models/voice_encoder/melspec.py b/HF_Deploy/src/chatterbox/models/voice_encoder/melspec.py deleted file mode 100644 index 69147fc8c591c9364ff829a157af0ea3fcbd5770..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/voice_encoder/melspec.py +++ /dev/null @@ -1,78 +0,0 @@ -from functools import lru_cache - -from scipy import signal -import numpy as np -import librosa - - -@lru_cache() -def mel_basis(hp): - assert hp.fmax <= hp.sample_rate // 2 - return librosa.filters.mel( - sr=hp.sample_rate, - n_fft=hp.n_fft, - n_mels=hp.num_mels, - fmin=hp.fmin, - fmax=hp.fmax) # -> (nmel, nfreq) - - -def preemphasis(wav, hp): - assert hp.preemphasis != 0 - wav = signal.lfilter([1, -hp.preemphasis], [1], wav) - wav = np.clip(wav, -1, 1) - return wav - - -def melspectrogram(wav, hp, pad=True): - # Run through pre-emphasis - if hp.preemphasis > 0: - wav = preemphasis(wav, hp) - assert np.abs(wav).max() - 1 < 1e-07 - - # Do the stft - spec_complex = _stft(wav, hp, pad=pad) - - # Get the magnitudes - spec_magnitudes = np.abs(spec_complex) - - if hp.mel_power != 1.0: - spec_magnitudes **= hp.mel_power - - # Get the mel and convert magnitudes->db - mel = np.dot(mel_basis(hp), spec_magnitudes) - if hp.mel_type == "db": - mel = _amp_to_db(mel, hp) - - # Normalise the mel from db to 0,1 - if hp.normalized_mels: - mel = _normalize(mel, hp).astype(np.float32) - - assert not pad or mel.shape[1] == 1 + len(wav) // hp.hop_size # Sanity check - return mel # (M, T) - - -def _stft(y, hp, pad=True): - # NOTE: after 0.8, pad mode defaults to constant, setting this to reflect for - # historical consistency and streaming-version consistency - return librosa.stft( - y, - n_fft=hp.n_fft, - hop_length=hp.hop_size, - win_length=hp.win_size, - center=pad, - pad_mode="reflect", - ) - - -def _amp_to_db(x, hp): - return 20 * np.log10(np.maximum(hp.stft_magnitude_min, x)) - - -def _db_to_amp(x): - return np.power(10.0, x * 0.05) - - -def _normalize(s, hp, headroom_db=15): - min_level_db = 20 * np.log10(hp.stft_magnitude_min) - s = (s - min_level_db) / (-min_level_db + headroom_db) - return s diff --git a/HF_Deploy/src/chatterbox/models/voice_encoder/voice_encoder.py b/HF_Deploy/src/chatterbox/models/voice_encoder/voice_encoder.py deleted file mode 100644 index d986f17fd6afab59364863b5e92fd56eec21236b..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/models/voice_encoder/voice_encoder.py +++ /dev/null @@ -1,274 +0,0 @@ -# Adapted from https://github.com/CorentinJ/Real-Time-Voice-Cloning -# MIT License -from typing import List, Union, Optional - -import numpy as np -from numpy.lib.stride_tricks import as_strided -import librosa -import torch -import torch.nn.functional as F -from torch import nn, Tensor - -from .config import VoiceEncConfig -from .melspec import melspectrogram - - -def pack(arrays, seq_len: int=None, pad_value=0): - """ - Given a list of length B of array-like objects of shapes (Ti, ...), packs them in a single tensor of - shape (B, T, ...) by padding each individual array on the right. - - :param arrays: a list of array-like objects of matching shapes except for the first axis. - :param seq_len: the value of T. It must be the maximum of the lengths Ti of the arrays at - minimum. Will default to that value if None. - :param pad_value: the value to pad the arrays with. - :return: a (B, T, ...) tensor - """ - if seq_len is None: - seq_len = max(len(array) for array in arrays) - else: - assert seq_len >= max(len(array) for array in arrays) - - # Convert lists to np.array - if isinstance(arrays[0], list): - arrays = [np.array(array) for array in arrays] - - # Convert to tensor and handle device - device = None - if isinstance(arrays[0], torch.Tensor): - tensors = arrays - device = tensors[0].device - else: - tensors = [torch.as_tensor(array) for array in arrays] - - # Fill the packed tensor with the array data - packed_shape = (len(tensors), seq_len, *tensors[0].shape[1:]) - packed_tensor = torch.full(packed_shape, pad_value, dtype=tensors[0].dtype, device=device) - - for i, tensor in enumerate(tensors): - packed_tensor[i, :tensor.size(0)] = tensor - - return packed_tensor - - -def get_num_wins( - n_frames: int, - step: int, - min_coverage: float, - hp: VoiceEncConfig, -): - assert n_frames > 0 - win_size = hp.ve_partial_frames - n_wins, remainder = divmod(max(n_frames - win_size + step, 0), step) - if n_wins == 0 or (remainder + (win_size - step)) / win_size >= min_coverage: - n_wins += 1 - target_n = win_size + step * (n_wins - 1) - return n_wins, target_n - - -def get_frame_step( - overlap: float, - rate: float, - hp: VoiceEncConfig, -): - # Compute how many frames separate two partial utterances - assert 0 <= overlap < 1 - if rate is None: - frame_step = int(np.round(hp.ve_partial_frames * (1 - overlap))) - else: - frame_step = int(np.round((hp.sample_rate / rate) / hp.ve_partial_frames)) - assert 0 < frame_step <= hp.ve_partial_frames - return frame_step - - -def stride_as_partials( - mel: np.ndarray, - hp: VoiceEncConfig, - overlap=0.5, - rate: float=None, - min_coverage=0.8, -): - """ - Takes unscaled mels in (T, M) format - TODO: doc - """ - assert 0 < min_coverage <= 1 - frame_step = get_frame_step(overlap, rate, hp) - - # Compute how many partials can fit in the mel - n_partials, target_len = get_num_wins(len(mel), frame_step, min_coverage, hp) - - # Trim or pad the mel spectrogram to match the number of partials - if target_len > len(mel): - mel = np.concatenate((mel, np.full((target_len - len(mel), hp.num_mels), 0))) - elif target_len < len(mel): - mel = mel[:target_len] - - # Ensure the numpy array data is float32 and contiguous in memory - mel = mel.astype(np.float32, order="C") - - # Re-arrange the array in memory to be of shape (N, P, M) with partials overlapping eachother, - # where N is the number of partials, P is the number of frames of each partial and M the - # number of channels of the mel spectrograms. - shape = (n_partials, hp.ve_partial_frames, hp.num_mels) - strides = (mel.strides[0] * frame_step, mel.strides[0], mel.strides[1]) - partials = as_strided(mel, shape, strides) - return partials - - -class VoiceEncoder(nn.Module): - def __init__(self, hp=VoiceEncConfig()): - super().__init__() - - self.hp = hp - - # Network definition - self.lstm = nn.LSTM(self.hp.num_mels, self.hp.ve_hidden_size, num_layers=3, batch_first=True) - if hp.flatten_lstm_params: - self.lstm.flatten_parameters() - self.proj = nn.Linear(self.hp.ve_hidden_size, self.hp.speaker_embed_size) - - # Cosine similarity scaling (fixed initial parameter values) - self.similarity_weight = nn.Parameter(torch.tensor([10.]), requires_grad=True) - self.similarity_bias = nn.Parameter(torch.tensor([-5.]), requires_grad=True) - - @property - def device(self): - return next(self.parameters()).device - - def forward(self, mels: torch.FloatTensor): - """ - Computes the embeddings of a batch of partial utterances. - - :param mels: a batch of unscaled mel spectrograms of same duration as a float32 tensor - of shape (B, T, M) where T is hp.ve_partial_frames - :return: the embeddings as a float32 tensor of shape (B, E) where E is - hp.speaker_embed_size. Embeddings are L2-normed and thus lay in the range [-1, 1]. - """ - if self.hp.normalized_mels and (mels.min() < 0 or mels.max() > 1): - raise Exception(f"Mels outside [0, 1]. Min={mels.min()}, Max={mels.max()}") - - # Pass the input through the LSTM layers - _, (hidden, _) = self.lstm(mels) - - # Project the final hidden state - raw_embeds = self.proj(hidden[-1]) - if self.hp.ve_final_relu: - raw_embeds = F.relu(raw_embeds) - - # L2 normalize the embeddings. - return raw_embeds / torch.linalg.norm(raw_embeds, dim=1, keepdim=True) - - def inference(self, mels: torch.Tensor, mel_lens, overlap=0.5, rate: float=None, min_coverage=0.8, batch_size=None): - """ - Computes the embeddings of a batch of full utterances with gradients. - - :param mels: (B, T, M) unscaled mels - :return: (B, E) embeddings on CPU - """ - mel_lens = mel_lens.tolist() if torch.is_tensor(mel_lens) else mel_lens - - # Compute where to split the utterances into partials - frame_step = get_frame_step(overlap, rate, self.hp) - n_partials, target_lens = zip(*(get_num_wins(l, frame_step, min_coverage, self.hp) for l in mel_lens)) - - # Possibly pad the mels to reach the target lengths - len_diff = max(target_lens) - mels.size(1) - if len_diff > 0: - pad = torch.full((mels.size(0), len_diff, self.hp.num_mels), 0, dtype=torch.float32) - mels = torch.cat((mels, pad.to(mels.device)), dim=1) - - # Group all partials together so that we can batch them easily - partials = [ - mel[i * frame_step: i * frame_step + self.hp.ve_partial_frames] - for mel, n_partial in zip(mels, n_partials) for i in range(n_partial) - ] - assert all(partials[0].shape == partial.shape for partial in partials) - partials = torch.stack(partials) - - # Forward the partials - n_chunks = int(np.ceil(len(partials) / (batch_size or len(partials)))) - partial_embeds = torch.cat([self(batch) for batch in partials.chunk(n_chunks)], dim=0).cpu() - - # Reduce the partial embeds into full embeds and L2-normalize them - slices = np.concatenate(([0], np.cumsum(n_partials))) - raw_embeds = [torch.mean(partial_embeds[start:end], dim=0) for start, end in zip(slices[:-1], slices[1:])] - raw_embeds = torch.stack(raw_embeds) - embeds = raw_embeds / torch.linalg.norm(raw_embeds, dim=1, keepdim=True) - - return embeds - - @staticmethod - def utt_to_spk_embed(utt_embeds: np.ndarray): - """ - Takes an array of L2-normalized utterance embeddings, computes the mean embedding and L2-normalize it to get a - speaker embedding. - """ - assert utt_embeds.ndim == 2 - utt_embeds = np.mean(utt_embeds, axis=0) - return utt_embeds / np.linalg.norm(utt_embeds, 2) - - @staticmethod - def voice_similarity(embeds_x: np.ndarray, embeds_y: np.ndarray): - """ - Cosine similarity for L2-normalized utterance embeddings or speaker embeddings - """ - embeds_x = embeds_x if embeds_x.ndim == 1 else VoiceEncoder.utt_to_spk_embed(embeds_x) - embeds_y = embeds_y if embeds_y.ndim == 1 else VoiceEncoder.utt_to_spk_embed(embeds_y) - return embeds_x @ embeds_y - - def embeds_from_mels( - self, mels: Union[Tensor, List[np.ndarray]], mel_lens=None, as_spk=False, batch_size=32, **kwargs - ): - """ - Convenience function for deriving utterance or speaker embeddings from mel spectrograms. - - :param mels: unscaled mels strictly within [0, 1] as either a (B, T, M) tensor or a list of (Ti, M) arrays. - :param mel_lens: if passing mels as a tensor, individual mel lengths - :param as_spk: whether to return utterance embeddings or a single speaker embedding - :param kwargs: args for inference() - - :returns: embeds as a (B, E) float32 numpy array if is False, else as a (E,) array - """ - # Load mels in memory and pack them - if isinstance(mels, List): - mels = [np.asarray(mel) for mel in mels] - assert all(m.shape[1] == mels[0].shape[1] for m in mels), "Mels aren't in (B, T, M) format" - mel_lens = [mel.shape[0] for mel in mels] - mels = pack(mels) - - # Embed them - with torch.inference_mode(): - utt_embeds = self.inference(mels.to(self.device), mel_lens, batch_size=batch_size, **kwargs).numpy() - - return self.utt_to_spk_embed(utt_embeds) if as_spk else utt_embeds - - def embeds_from_wavs( - self, - wavs: List[np.ndarray], - sample_rate, - as_spk=False, - batch_size=32, - trim_top_db: Optional[float]=20, - **kwargs - ): - """ - Wrapper around embeds_from_mels - - :param trim_top_db: this argument was only added for the sake of compatibility with metavoice's implementation - """ - if sample_rate != self.hp.sample_rate: - wavs = [ - librosa.resample(wav, orig_sr=sample_rate, target_sr=self.hp.sample_rate, res_type="kaiser_fast") - for wav in wavs - ] - - if trim_top_db: - wavs = [librosa.effects.trim(wav, top_db=trim_top_db)[0] for wav in wavs] - - if "rate" not in kwargs: - kwargs["rate"] = 1.3 # Resemble's default value. - - mels = [melspectrogram(w, self.hp).T for w in wavs] - - return self.embeds_from_mels(mels, as_spk=as_spk, batch_size=batch_size, **kwargs) diff --git a/HF_Deploy/src/chatterbox/tts.py b/HF_Deploy/src/chatterbox/tts.py deleted file mode 100644 index a7f223058d13ca08fd55f41fa823d98a74c0cc3d..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/tts.py +++ /dev/null @@ -1,281 +0,0 @@ -from dataclasses import dataclass -from pathlib import Path - -import librosa -import torch -import perth -import torch.nn.functional as F -from huggingface_hub import hf_hub_download -from safetensors.torch import load_file - -from .models.t3 import T3 -from .models.s3tokenizer import S3_SR, drop_invalid_tokens -from .models.s3gen import S3GEN_SR, S3Gen -from .models.tokenizers import EnTokenizer -from .models.voice_encoder import VoiceEncoder -from .models.t3.modules.cond_enc import T3Cond - - -REPO_ID = "ResembleAI/chatterbox" - - -def punc_norm(text: str) -> str: - """ - Quick cleanup func for punctuation from LLMs or - containing chars not seen often in the dataset - """ - if len(text) == 0: - return "You need to add some text for me to talk." - - # Capitalise first letter - if text[0].islower(): - text = text[0].upper() + text[1:] - - # Remove multiple space chars - text = " ".join(text.split()) - - # Replace uncommon/llm punc - punc_to_replace = [ - ("...", ", "), - ("…", ", "), - (":", ","), - (" - ", ", "), - (";", ", "), - ("β€”", "-"), - ("–", "-"), - (" ,", ","), - (""", '"'), - (""", '"'), - ("β€˜", "'"), - ("’", "'"), - ] - for old_char_sequence, new_char in punc_to_replace: - text = text.replace(old_char_sequence, new_char) - - # Add full stop if no ending punc - text = text.rstrip(" ") - sentence_enders = {".", "!", "?", "-", ","} - - # Check for punctuation at end, including inside quotes - has_ending_punct = False - if any(text.endswith(p) for p in sentence_enders): - has_ending_punct = True - elif len(text) >= 2 and text[-1] in ['"', "'"] and text[-2] in sentence_enders: - # Check for punctuation before closing quote: ?" or .' - has_ending_punct = True - - if not has_ending_punct: - text += "." - - return text - - -@dataclass -class Conditionals: - """ - Conditionals for T3 and S3Gen - - T3 conditionals: - - speaker_emb - - clap_emb - - cond_prompt_speech_tokens - - cond_prompt_speech_emb - - emotion_adv - - S3Gen conditionals: - - prompt_token - - prompt_token_len - - prompt_feat - - prompt_feat_len - - embedding - """ - t3: T3Cond - gen: dict - - def to(self, device): - self.t3 = self.t3.to(device=device) - for k, v in self.gen.items(): - if torch.is_tensor(v): - self.gen[k] = v.to(device=device) - return self - - def save(self, fpath: Path): - arg_dict = dict( - t3=self.t3.__dict__, - gen=self.gen - ) - torch.save(arg_dict, fpath) - - @classmethod - def load(cls, fpath, map_location="cpu"): - if isinstance(map_location, str): - map_location = torch.device(map_location) - kwargs = torch.load(fpath, map_location=map_location, weights_only=True) - return cls(T3Cond(**kwargs['t3']), kwargs['gen']) - - -class ChatterboxTTS: - ENC_COND_LEN = 6 * S3_SR - DEC_COND_LEN = 10 * S3GEN_SR - - def __init__( - self, - t3: T3, - s3gen: S3Gen, - ve: VoiceEncoder, - tokenizer: EnTokenizer, - device: str, - conds: Conditionals = None, - ): - self.sr = S3GEN_SR # sample rate of synthesized audio - self.t3 = t3 - self.s3gen = s3gen - self.ve = ve - self.tokenizer = tokenizer - self.device = device - self.conds = conds - self.watermarker = perth.PerthImplicitWatermarker() - - @classmethod - def from_local(cls, ckpt_dir, device) -> 'ChatterboxTTS': - ckpt_dir = Path(ckpt_dir) - - # Always load to CPU first for non-CUDA devices to handle CUDA-saved models - if device in ["cpu", "mps"]: - map_location = torch.device('cpu') - else: - map_location = None - - ve = VoiceEncoder() - ve.load_state_dict( - load_file(ckpt_dir / "ve.safetensors") - ) - ve.to(device).eval() - - t3 = T3() - t3_state = load_file(ckpt_dir / "t3_cfg.safetensors") - if "model" in t3_state.keys(): - t3_state = t3_state["model"][0] - t3.load_state_dict(t3_state) - t3.to(device).eval() - - s3gen = S3Gen() - s3gen.load_state_dict( - load_file(ckpt_dir / "s3gen.safetensors"), strict=False - ) - s3gen.to(device).eval() - - tokenizer = EnTokenizer( - str(ckpt_dir / "tokenizer.json") - ) - - conds = None - if (builtin_voice := ckpt_dir / "conds.pt").exists(): - conds = Conditionals.load(builtin_voice, map_location=map_location).to(device) - - return cls(t3, s3gen, ve, tokenizer, device, conds=conds) - - @classmethod - def from_pretrained(cls, device) -> 'ChatterboxTTS': - # Check if MPS is available on macOS - if device == "mps" and not torch.backends.mps.is_available(): - if not torch.backends.mps.is_built(): - print("MPS not available because the current PyTorch install was not built with MPS enabled.") - else: - print("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.") - device = "cpu" - - for fpath in ["ve.safetensors", "t3_cfg.safetensors", "s3gen.safetensors", "tokenizer.json", "conds.pt"]: - local_path = hf_hub_download(repo_id=REPO_ID, filename=fpath) - - return cls.from_local(Path(local_path).parent, device) - - def prepare_conditionals(self, wav_fpath, exaggeration=0.5): - ## Load reference wav - s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR) - - ref_16k_wav = librosa.resample(s3gen_ref_wav, orig_sr=S3GEN_SR, target_sr=S3_SR) - - s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN] - s3gen_ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device) - - # Speech cond prompt tokens - if plen := self.t3.hp.speech_cond_prompt_len: - s3_tokzr = self.s3gen.tokenizer - t3_cond_prompt_tokens, _ = s3_tokzr.forward([ref_16k_wav[:self.ENC_COND_LEN]], max_len=plen) - t3_cond_prompt_tokens = torch.atleast_2d(t3_cond_prompt_tokens).to(self.device) - - # Voice-encoder speaker embedding - ve_embed = torch.from_numpy(self.ve.embeds_from_wavs([ref_16k_wav], sample_rate=S3_SR)) - ve_embed = ve_embed.mean(axis=0, keepdim=True).to(self.device) - - t3_cond = T3Cond( - speaker_emb=ve_embed, - cond_prompt_speech_tokens=t3_cond_prompt_tokens, - emotion_adv=exaggeration * torch.ones(1, 1, 1), - ).to(device=self.device) - self.conds = Conditionals(t3_cond, s3gen_ref_dict) - - def generate( - self, - text, - audio_prompt_path=None, - exaggeration=0.5, - cfg_weight=0.5, - temperature=0.8, - min_p=0.05, - top_p=0.8, - repetition_penalty=2.0, - ): - if audio_prompt_path: - self.prepare_conditionals(audio_prompt_path, exaggeration=exaggeration) - else: - assert self.conds is not None, "Please `prepare_conditionals` first or specify `audio_prompt_path`" - - # Update exaggeration if needed - if exaggeration != self.conds.t3.emotion_adv[0, 0, 0]: - _cond: T3Cond = self.conds.t3 - self.conds.t3 = T3Cond( - speaker_emb=_cond.speaker_emb, - cond_prompt_speech_tokens=_cond.cond_prompt_speech_tokens, - emotion_adv=exaggeration * torch.ones(1, 1, 1), - ).to(device=self.device) - - # Norm and tokenize text - text = punc_norm(text) - text_tokens = self.tokenizer.text_to_tokens(text).to(self.device) - - if cfg_weight > 0.0: - text_tokens = torch.cat([text_tokens, text_tokens], dim=0) # Need two seqs for CFG - - sot = self.t3.hp.start_text_token - eot = self.t3.hp.stop_text_token - text_tokens = F.pad(text_tokens, (1, 0), value=sot) - text_tokens = F.pad(text_tokens, (0, 1), value=eot) - - with torch.inference_mode(): - speech_tokens = self.t3.inference( - t3_cond=self.conds.t3, - text_tokens=text_tokens, - max_new_tokens=1000, # TODO: use the value in config - temperature=temperature, - cfg_weight=cfg_weight, - min_p=min_p, - top_p=top_p, - repetition_penalty=repetition_penalty, - ) - # Extract only the conditional batch. - speech_tokens = speech_tokens[0] - - # TODO: output becomes 1D - speech_tokens = drop_invalid_tokens(speech_tokens) - - speech_tokens = speech_tokens[speech_tokens < 6561] - - speech_tokens = speech_tokens.to(self.device) - - wav, _ = self.s3gen.inference( - speech_tokens=speech_tokens, - ref_dict=self.conds.gen, - ) - wav = wav.squeeze(0).detach().cpu().numpy() - watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr) - return torch.from_numpy(watermarked_wav).unsqueeze(0) \ No newline at end of file diff --git a/HF_Deploy/src/chatterbox/vc.py b/HF_Deploy/src/chatterbox/vc.py deleted file mode 100644 index a9c32ed3567192f07eee68a78c0517c1324892fc..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox/vc.py +++ /dev/null @@ -1,104 +0,0 @@ -from pathlib import Path - -import librosa -import torch -import perth -from huggingface_hub import hf_hub_download -from safetensors.torch import load_file - -from .models.s3tokenizer import S3_SR -from .models.s3gen import S3GEN_SR, S3Gen - - -REPO_ID = "ResembleAI/chatterbox" - - -class ChatterboxVC: - ENC_COND_LEN = 6 * S3_SR - DEC_COND_LEN = 10 * S3GEN_SR - - def __init__( - self, - s3gen: S3Gen, - device: str, - ref_dict: dict=None, - ): - self.sr = S3GEN_SR - self.s3gen = s3gen - self.device = device - self.watermarker = perth.PerthImplicitWatermarker() - if ref_dict is None: - self.ref_dict = None - else: - self.ref_dict = { - k: v.to(device) if torch.is_tensor(v) else v - for k, v in ref_dict.items() - } - - @classmethod - def from_local(cls, ckpt_dir, device) -> 'ChatterboxVC': - ckpt_dir = Path(ckpt_dir) - - # Always load to CPU first for non-CUDA devices to handle CUDA-saved models - if device in ["cpu", "mps"]: - map_location = torch.device('cpu') - else: - map_location = None - - ref_dict = None - if (builtin_voice := ckpt_dir / "conds.pt").exists(): - states = torch.load(builtin_voice, map_location=map_location) - ref_dict = states['gen'] - - s3gen = S3Gen() - s3gen.load_state_dict( - load_file(ckpt_dir / "s3gen.safetensors"), strict=False - ) - s3gen.to(device).eval() - - return cls(s3gen, device, ref_dict=ref_dict) - - @classmethod - def from_pretrained(cls, device) -> 'ChatterboxVC': - # Check if MPS is available on macOS - if device == "mps" and not torch.backends.mps.is_available(): - if not torch.backends.mps.is_built(): - print("MPS not available because the current PyTorch install was not built with MPS enabled.") - else: - print("MPS not available because the current MacOS version is not 12.3+ and/or you do not have an MPS-enabled device on this machine.") - device = "cpu" - - for fpath in ["s3gen.safetensors", "conds.pt"]: - local_path = hf_hub_download(repo_id=REPO_ID, filename=fpath) - - return cls.from_local(Path(local_path).parent, device) - - def set_target_voice(self, wav_fpath): - ## Load reference wav - s3gen_ref_wav, _sr = librosa.load(wav_fpath, sr=S3GEN_SR) - - s3gen_ref_wav = s3gen_ref_wav[:self.DEC_COND_LEN] - self.ref_dict = self.s3gen.embed_ref(s3gen_ref_wav, S3GEN_SR, device=self.device) - - def generate( - self, - audio, - target_voice_path=None, - ): - if target_voice_path: - self.set_target_voice(target_voice_path) - else: - assert self.ref_dict is not None, "Please `prepare_conditionals` first or specify `target_voice_path`" - - with torch.inference_mode(): - audio_16, _ = librosa.load(audio, sr=S3_SR) - audio_16 = torch.from_numpy(audio_16).float().to(self.device)[None, ] - - s3_tokens, _ = self.s3gen.tokenizer(audio_16) - wav, _ = self.s3gen.inference( - speech_tokens=s3_tokens, - ref_dict=self.ref_dict, - ) - wav = wav.squeeze(0).detach().cpu().numpy() - watermarked_wav = self.watermarker.apply_watermark(wav, sample_rate=self.sr) - return torch.from_numpy(watermarked_wav).unsqueeze(0) \ No newline at end of file diff --git a/HF_Deploy/src/chatterbox_tts.egg-info/PKG-INFO b/HF_Deploy/src/chatterbox_tts.egg-info/PKG-INFO deleted file mode 100644 index a824b85d5e27fc6ebd05fe33445d7b0b6d55f49f..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox_tts.egg-info/PKG-INFO +++ /dev/null @@ -1,148 +0,0 @@ -Metadata-Version: 2.4 -Name: chatterbox-tts -Version: 0.1.2 -Summary: Chatterbox: Open Source TTS and Voice Conversion by Resemble AI -Author-email: resemble-ai -License: MIT License - - Copyright (c) 2025 Resemble AI - - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in all - copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - SOFTWARE. -Project-URL: Homepage, https://github.com/resemble-ai/chatterbox -Project-URL: Repository, https://github.com/resemble-ai/chatterbox -Requires-Python: >=3.8 -Description-Content-Type: text/markdown -License-File: LICENSE -Requires-Dist: numpy~=1.26.0 -Requires-Dist: resampy==0.4.3 -Requires-Dist: librosa==0.11.0 -Requires-Dist: s3tokenizer -Requires-Dist: torch==2.6.0 -Requires-Dist: torchaudio==2.6.0 -Requires-Dist: transformers==4.46.3 -Requires-Dist: diffusers==0.29.0 -Requires-Dist: resemble-perth==1.0.1 -Requires-Dist: omegaconf==2.3.0 -Requires-Dist: conformer==0.3.2 -Requires-Dist: safetensors==0.5.3 -Dynamic: license-file - - -cb-big2 - -# Chatterbox TTS - -[![Alt Text](https://img.shields.io/badge/listen-demo_samples-blue)](https://resemble-ai.github.io/chatterbox_demopage/) -[![Alt Text](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/ResembleAI/Chatterbox) -[![Alt Text](https://static-public.podonos.com/badges/insight-on-pdns-sm-dark.svg)](https://podonos.com/resembleai/chatterbox) -[![Discord](https://img.shields.io/discord/1377773249798344776?label=join%20discord&logo=discord&style=flat)](https://discord.gg/XqS7RxUp) - -_Made with β™₯️ by resemble-logo-horizontal - -We're excited to introduce Chatterbox, [Resemble AI's](https://resemble.ai) first production-grade open source TTS model. Licensed under MIT, Chatterbox has been benchmarked against leading closed-source systems like ElevenLabs, and is consistently preferred in side-by-side evaluations. - -Whether you're working on memes, videos, games, or AI agents, Chatterbox brings your content to life. It's also the first open source TTS model to support **emotion exaggeration control**, a powerful feature that makes your voices stand out. Try it now on our [Hugging Face Gradio app.](https://huggingface.co/spaces/ResembleAI/Chatterbox) - -If you like the model but need to scale or tune it for higher accuracy, check out our competitively priced TTS service (link). It delivers reliable performance with ultra-low latency of sub 200msβ€”ideal for production use in agents, applications, or interactive media. - -# Key Details -- SoTA zeroshot TTS -- 0.5B Llama backbone -- Unique exaggeration/intensity control -- Ultra-stable with alignment-informed inference -- Trained on 0.5M hours of cleaned data -- Watermarked outputs -- Easy voice conversion script -- [Outperforms ElevenLabs](https://podonos.com/resembleai/chatterbox) - -# Tips -- **General Use (TTS and Voice Agents):** - - The default settings (`exaggeration=0.5`, `cfg_weight=0.5`) work well for most prompts. - - If the reference speaker has a fast speaking style, lowering `cfg_weight` to around `0.3` can improve pacing. - -- **Expressive or Dramatic Speech:** - - Try lower `cfg_weight` values (e.g. `~0.3`) and increase `exaggeration` to around `0.7` or higher. - - Higher `exaggeration` tends to speed up speech; reducing `cfg_weight` helps compensate with slower, more deliberate pacing. - - -# Installation -``` -pip install chatterbox-tts -``` - - -# Usage -```python -import torchaudio as ta -from chatterbox.tts import ChatterboxTTS - -model = ChatterboxTTS.from_pretrained(device="cuda") - -text = "Ezreal and Jinx teamed up with Ahri, Yasuo, and Teemo to take down the enemy's Nexus in an epic late-game pentakill." -wav = model.generate(text) -ta.save("test-1.wav", wav, model.sr) - -# If you want to synthesize with a different voice, specify the audio prompt -AUDIO_PROMPT_PATH = "YOUR_FILE.wav" -wav = model.generate(text, audio_prompt_path=AUDIO_PROMPT_PATH) -ta.save("test-2.wav", wav, model.sr) -``` -See `example_tts.py` and `example_vc.py` for more examples. - -# Acknowledgements -- [Cosyvoice](https://github.com/FunAudioLLM/CosyVoice) -- [Real-Time-Voice-Cloning](https://github.com/CorentinJ/Real-Time-Voice-Cloning) -- [HiFT-GAN](https://github.com/yl4579/HiFTNet) -- [Llama 3](https://github.com/meta-llama/llama3) -- [S3Tokenizer](https://github.com/xingchensong/S3Tokenizer) - -# Built-in PerTh Watermarking for Responsible AI - -Every audio file generated by Chatterbox includes [Resemble AI's Perth (Perceptual Threshold) Watermarker](https://github.com/resemble-ai/perth) - imperceptible neural watermarks that survive MP3 compression, audio editing, and common manipulations while maintaining nearly 100% detection accuracy. - - -## Watermark extraction - -You can look for the watermark using the following script. - -```python -import perth -import librosa - -AUDIO_PATH = "YOUR_FILE.wav" - -# Load the watermarked audio -watermarked_audio, sr = librosa.load(AUDIO_PATH, sr=None) - -# Initialize watermarker (same as used for embedding) -watermarker = perth.PerthImplicitWatermarker() - -# Extract watermark -watermark = watermarker.get_watermark(watermarked_audio, sample_rate=sr) -print(f"Extracted watermark: {watermark}") -# Output: 0.0 (no watermark) or 1.0 (watermarked) -``` - - -# Official Discord - -πŸ‘‹ Join us on [Discord](https://discord.gg/XqS7RxUp) and let's build something awesome together! - -# Disclaimer -Don't use this model to do bad things. Prompts are sourced from freely available data on the internet. diff --git a/HF_Deploy/src/chatterbox_tts.egg-info/SOURCES.txt b/HF_Deploy/src/chatterbox_tts.egg-info/SOURCES.txt deleted file mode 100644 index e1ef10f6375e8d0271b67dbd3f80817eb60a6514..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox_tts.egg-info/SOURCES.txt +++ /dev/null @@ -1,53 +0,0 @@ -LICENSE -README.md -pyproject.toml -src/chatterbox/__init__.py -src/chatterbox/tts.py -src/chatterbox/vc.py -src/chatterbox/models/s3gen/__init__.py -src/chatterbox/models/s3gen/const.py -src/chatterbox/models/s3gen/decoder.py -src/chatterbox/models/s3gen/f0_predictor.py -src/chatterbox/models/s3gen/flow.py -src/chatterbox/models/s3gen/flow_matching.py -src/chatterbox/models/s3gen/hifigan.py -src/chatterbox/models/s3gen/s3gen.py -src/chatterbox/models/s3gen/xvector.py -src/chatterbox/models/s3gen/matcha/decoder.py -src/chatterbox/models/s3gen/matcha/flow_matching.py -src/chatterbox/models/s3gen/matcha/text_encoder.py -src/chatterbox/models/s3gen/matcha/transformer.py -src/chatterbox/models/s3gen/transformer/__init__.py -src/chatterbox/models/s3gen/transformer/activation.py -src/chatterbox/models/s3gen/transformer/attention.py -src/chatterbox/models/s3gen/transformer/convolution.py -src/chatterbox/models/s3gen/transformer/embedding.py -src/chatterbox/models/s3gen/transformer/encoder_layer.py -src/chatterbox/models/s3gen/transformer/positionwise_feed_forward.py -src/chatterbox/models/s3gen/transformer/subsampling.py -src/chatterbox/models/s3gen/transformer/upsample_encoder.py -src/chatterbox/models/s3gen/utils/class_utils.py -src/chatterbox/models/s3gen/utils/mask.py -src/chatterbox/models/s3gen/utils/mel.py -src/chatterbox/models/s3tokenizer/__init__.py -src/chatterbox/models/s3tokenizer/s3tokenizer.py -src/chatterbox/models/t3/__init__.py -src/chatterbox/models/t3/llama_configs.py -src/chatterbox/models/t3/t3.py -src/chatterbox/models/t3/inference/alignment_stream_analyzer.py -src/chatterbox/models/t3/inference/t3_hf_backend.py -src/chatterbox/models/t3/modules/cond_enc.py -src/chatterbox/models/t3/modules/learned_pos_emb.py -src/chatterbox/models/t3/modules/perceiver.py -src/chatterbox/models/t3/modules/t3_config.py -src/chatterbox/models/tokenizers/__init__.py -src/chatterbox/models/tokenizers/tokenizer.py -src/chatterbox/models/voice_encoder/__init__.py -src/chatterbox/models/voice_encoder/config.py -src/chatterbox/models/voice_encoder/melspec.py -src/chatterbox/models/voice_encoder/voice_encoder.py -src/chatterbox_tts.egg-info/PKG-INFO -src/chatterbox_tts.egg-info/SOURCES.txt -src/chatterbox_tts.egg-info/dependency_links.txt -src/chatterbox_tts.egg-info/requires.txt -src/chatterbox_tts.egg-info/top_level.txt \ No newline at end of file diff --git a/HF_Deploy/src/chatterbox_tts.egg-info/dependency_links.txt b/HF_Deploy/src/chatterbox_tts.egg-info/dependency_links.txt deleted file mode 100644 index 8b137891791fe96927ad78e64b0aad7bded08bdc..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox_tts.egg-info/dependency_links.txt +++ /dev/null @@ -1 +0,0 @@ - diff --git a/HF_Deploy/src/chatterbox_tts.egg-info/requires.txt b/HF_Deploy/src/chatterbox_tts.egg-info/requires.txt deleted file mode 100644 index b47bb685e20bfdb4559edd4071b3b012d95e2832..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox_tts.egg-info/requires.txt +++ /dev/null @@ -1,12 +0,0 @@ -numpy~=1.26.0 -resampy==0.4.3 -librosa==0.11.0 -s3tokenizer -torch==2.6.0 -torchaudio==2.6.0 -transformers==4.46.3 -diffusers==0.29.0 -resemble-perth==1.0.1 -omegaconf==2.3.0 -conformer==0.3.2 -safetensors==0.5.3 diff --git a/HF_Deploy/src/chatterbox_tts.egg-info/top_level.txt b/HF_Deploy/src/chatterbox_tts.egg-info/top_level.txt deleted file mode 100644 index 62e76aa0de2aa3d9767119087811d40bb9150d65..0000000000000000000000000000000000000000 --- a/HF_Deploy/src/chatterbox_tts.egg-info/top_level.txt +++ /dev/null @@ -1 +0,0 @@ -chatterbox diff --git a/HF_Deploy/tools/combine_only.py b/HF_Deploy/tools/combine_only.py deleted file mode 100644 index 3ddbe262b38e57a665c0530628a3315a85c11fbd..0000000000000000000000000000000000000000 --- a/HF_Deploy/tools/combine_only.py +++ /dev/null @@ -1,396 +0,0 @@ -""" -Combine Only Tool -Standalone tool for combining existing audio chunks into final audiobook -""" - -import re -import time -import logging -from datetime import timedelta -from pathlib import Path - -from config.config import * -from modules.file_manager import ( - get_audio_files_in_directory, combine_audio_chunks, - convert_to_m4b, add_metadata_to_m4b, find_book_files -) -from modules.audio_processor import get_wav_duration -from modules.progress_tracker import log_console, log_run -import subprocess -import shutil - -def combine_audio_for_book(book_path_str, voice_name=None): - """Combine audio chunks for a specific book (GUI-friendly version)""" - from pathlib import Path - book_path = Path(book_path_str) - - print(f"\n{CYAN}πŸ”— Combining Audio Chunks for: {book_path.name}{RESET}") - print("=" * 60) - - # Setup paths - tts_dir = book_path / "TTS" - audio_chunks_dir = tts_dir / "audio_chunks" - - if not audio_chunks_dir.exists(): - print(f"{RED}❌ No audio_chunks folder found in {book_path}{RESET}") - print(f"πŸ’‘ Make sure this book has been processed with TTS generation first.") - return False - - # Find audio chunks - chunk_paths = get_audio_files_in_directory(audio_chunks_dir) - - if not chunk_paths: - print(f"{RED}❌ No chunk_*.wav files found in {audio_chunks_dir}{RESET}") - print(f"πŸ’‘ Expected files like: chunk_00001.wav, chunk_00002.wav, etc.") - return False - - print(f"\nπŸ“¦ Found {GREEN}{len(chunk_paths)}{RESET} audio chunks") - - # Verify chunk sequence - missing_chunks = verify_chunk_sequence(chunk_paths) - if missing_chunks: - print(f"\n⚠️ {YELLOW}Warning: Missing chunks detected:{RESET}") - for chunk_num in missing_chunks[:10]: # Show first 10 missing - print(f" Missing: chunk_{chunk_num:05}.wav") - if len(missing_chunks) > 10: - print(f" ... and {len(missing_chunks) - 10} more") - print(f"{YELLOW}πŸ”„ Continuing with available chunks for GUI operation...{RESET}") - - # Display chunk info - total_duration = sum(get_wav_duration(chunk_path) for chunk_path in chunk_paths) - duration_str = str(timedelta(seconds=int(total_duration))) - - print(f"\nπŸ“Š Chunk Analysis:") - print(f" Total Chunks: {GREEN}{len(chunk_paths)}{RESET}") - print(f" Total Duration: {GREEN}{duration_str}{RESET}") - print(f" Average Chunk: {GREEN}{total_duration/len(chunk_paths):.1f}s{RESET}") - - # Perform the actual combine operation - return _perform_combine_operation(book_path, chunk_paths, total_duration, voice_name) - -def _perform_combine_operation(book_path, chunk_paths, total_duration, voice_name=None): - """Perform the actual audio combining operation""" - import time - from datetime import timedelta - - basename = book_path.name - - # Determine file naming based on voice - if voice_name: - file_suffix = f" [{voice_name}]" - else: - file_suffix = "_combined" - - # Start timing - start_time = time.time() - - # Create concat file and combine - print(f"\nπŸ”— Combining audio chunks...") - combined_wav_path = book_path / f"{basename}{file_suffix}.wav" - - try: - combine_audio_chunks(chunk_paths, combined_wav_path) - print(f"βœ… Combined WAV created: {combined_wav_path.name}") - except Exception as e: - print(f"{RED}❌ Failed to combine chunks: {e}{RESET}") - return False - - # Find metadata files - text_book_dir = TEXT_INPUT_ROOT / basename - book_files = find_book_files(text_book_dir) - text_files, cover_file, nfo_file = book_files['text'], book_files['cover'], book_files['nfo'] - - if not cover_file: - print(f"⚠️ {YELLOW}No cover image found in {text_book_dir}{RESET}") - else: - print(f"πŸ“Έ Using cover: {cover_file.name}") - - if not nfo_file: - print(f"⚠️ {YELLOW}No book.nfo metadata found in {text_book_dir}{RESET}") - else: - print(f"πŸ“ Using metadata: {nfo_file.name}") - - # M4B conversion - print(f"\nπŸ“± Converting to M4B audiobook...") - temp_m4b_path = book_path / "temp_output.m4b" - final_m4b_path = book_path / f"{basename}{file_suffix}.m4b" - - try: - convert_to_m4b(combined_wav_path, temp_m4b_path) - add_metadata_to_m4b(temp_m4b_path, final_m4b_path, cover_file, nfo_file) - print(f"βœ… M4B audiobook created: {final_m4b_path.name}") - except Exception as e: - print(f"{RED}❌ Failed to create M4B: {e}{RESET}") - return False - - # Calculate final timing - elapsed_total = time.time() - start_time - elapsed_td = timedelta(seconds=int(elapsed_total)) - - # Verify final file - if final_m4b_path.exists(): - final_size = final_m4b_path.stat().st_size / (1024 * 1024) # MB - print(f"πŸ“¦ Final file size: {GREEN}{final_size:.1f} MB{RESET}") - - # Calculate efficiency - realtime_factor = total_duration / elapsed_total if elapsed_total > 0 else 0 - duration_str = str(timedelta(seconds=int(total_duration))) - - print(f"\nπŸŽ‰ {GREEN}Combine completed successfully!{RESET}") - print(f"πŸ“Š Final Statistics:") - print(f" Audio Duration: {GREEN}{duration_str}{RESET}") - print(f" Processing Time: {GREEN}{elapsed_td}{RESET}") - print(f" Realtime Factor: {GREEN}{realtime_factor:.2f}x{RESET}") - print(f" Output Location: {GREEN}{final_m4b_path}{RESET}") - - # Clean up temp files - try: - if temp_m4b_path.exists(): - temp_m4b_path.unlink() - print(f"🧹 Cleaned up temporary file: {temp_m4b_path.name}") - except Exception as e: - print(f"⚠️ Could not clean up temp file: {e}") - - return True - else: - print(f"{RED}❌ Final M4B file was not created successfully{RESET}") - return False - -def run_combine_only_mode(): - """Combine existing chunks into audiobook (CLI version)""" - print(f"\n{CYAN}πŸ”— Combine-Only Mode: Assembling Existing Audio Chunks{RESET}") - print("=" * 60) - - # Show available audiobooks - books = sorted([d for d in AUDIOBOOK_ROOT.iterdir() if d.is_dir()]) - if not books: - print(f"{RED}❌ No folders found in Audiobook/ directory.{RESET}") - print(f"πŸ’‘ Make sure you have processed books with audio chunks to combine.") - return None - - print(f"{CYAN}Available audiobooks to combine:{RESET}") - for i, book in enumerate(books): - # Check if it has audio chunks - audio_chunks_dir = book / "TTS" / "audio_chunks" - if audio_chunks_dir.exists(): - chunk_count = len(list(audio_chunks_dir.glob('chunk_*.wav'))) - status = f"({chunk_count} chunks)" if chunk_count > 0 else "(no chunks)" - print(f" [{i}] {book.name} {status}") - else: - print(f" [{i}] {book.name} (no TTS folder)") - - # Book selection - while True: - try: - idx = int(input(f"\n{YELLOW}Select audiobook index: {RESET}")) - if 0 <= idx < len(books): - break - else: - print(f"{RED}Invalid selection. Please enter a number between 0 and {len(books)-1}.{RESET}") - except (ValueError, KeyboardInterrupt): - print(f"{RED}Invalid selection. Please try again.{RESET}") - except EOFError: - print(f"\n{RED}❌ Input error - unable to read selection.{RESET}") - return None - except Exception as e: - print(f"{RED}❌ Unexpected error: {e}{RESET}") - return None - - selected_book = books[idx] - basename = selected_book.name - - print(f"\n🎯 Selected: {BOLD}{basename}{RESET}") - - # Setup paths - tts_dir = selected_book / "TTS" - audio_chunks_dir = tts_dir / "audio_chunks" - - if not audio_chunks_dir.exists(): - print(f"{RED}❌ No audio_chunks folder found in {selected_book}{RESET}") - print(f"πŸ’‘ Make sure this book has been processed with TTS generation first.") - return None - - # Find audio chunks - chunk_paths = get_audio_files_in_directory(audio_chunks_dir) - - if not chunk_paths: - print(f"{RED}❌ No chunk_*.wav files found in {audio_chunks_dir}{RESET}") - print(f"πŸ’‘ Expected files like: chunk_00001.wav, chunk_00002.wav, etc.") - return None - - print(f"\nπŸ“¦ Found {GREEN}{len(chunk_paths)}{RESET} audio chunks") - - # Verify chunk sequence - missing_chunks = verify_chunk_sequence(chunk_paths) - if missing_chunks: - print(f"\n⚠️ {YELLOW}Warning: Missing chunks detected:{RESET}") - for chunk_num in missing_chunks[:10]: # Show first 10 missing - print(f" Missing: chunk_{chunk_num:05}.wav") - if len(missing_chunks) > 10: - print(f" ... and {len(missing_chunks) - 10} more") - - try: - continue_anyway = input(f"\n{YELLOW}Continue with incomplete chunks? [y/N]: {RESET}").strip().lower() - if continue_anyway != 'y': - print("πŸ›‘ Combine operation cancelled.") - return None - except (EOFError, KeyboardInterrupt): - print(f"\n{RED}πŸ›‘ Combine operation cancelled.{RESET}") - return None - - # Display chunk info - total_duration = sum(get_wav_duration(chunk_path) for chunk_path in chunk_paths) - duration_str = str(timedelta(seconds=int(total_duration))) - - print(f"\nπŸ“Š Chunk Analysis:") - print(f" Total Chunks: {GREEN}{len(chunk_paths)}{RESET}") - print(f" Total Duration: {GREEN}{duration_str}{RESET}") - print(f" Average Chunk: {GREEN}{total_duration/len(chunk_paths):.1f}s{RESET}") - - # Use the shared combine operation (CLI doesn't pass voice name) - success = _perform_combine_operation(selected_book, chunk_paths, total_duration) - - if success: - return selected_book / f"{basename}_combined.m4b" - else: - return None - -def verify_chunk_sequence(chunk_paths): - """Verify chunk sequence and return missing chunk numbers""" - chunk_numbers = [] - - for chunk_path in chunk_paths: - match = re.match(r"chunk_(\d+)\.wav", chunk_path.name) - if match: - chunk_numbers.append(int(match.group(1))) - - if not chunk_numbers: - return [] - - chunk_numbers.sort() - expected_range = range(1, max(chunk_numbers) + 1) - missing = [num for num in expected_range if num not in chunk_numbers] - - return missing - -def list_available_books_for_combine(): - """List books available for combine operation""" - books_info = [] - - if not AUDIOBOOK_ROOT.exists(): - return books_info - - for book_dir in AUDIOBOOK_ROOT.iterdir(): - if not book_dir.is_dir(): - continue - - audio_chunks_dir = book_dir / "TTS" / "audio_chunks" - if not audio_chunks_dir.exists(): - continue - - chunk_paths = get_audio_files_in_directory(audio_chunks_dir) - if not chunk_paths: - continue - - # Calculate total duration - try: - total_duration = sum(get_wav_duration(chunk_path) for chunk_path in chunk_paths) - duration_str = str(timedelta(seconds=int(total_duration))) - except: - duration_str = "Unknown" - - books_info.append({ - "name": book_dir.name, - "path": book_dir, - "chunk_count": len(chunk_paths), - "duration": duration_str - }) - - return books_info - -def quick_combine(book_name): - """Quick combine operation for specific book (CLI usage)""" - book_path = AUDIOBOOK_ROOT / book_name - - if not book_path.exists(): - print(f"{RED}❌ Book '{book_name}' not found in Audiobook directory{RESET}") - return None - - audio_chunks_dir = book_path / "TTS" / "audio_chunks" - chunk_paths = get_audio_files_in_directory(audio_chunks_dir) - - if not chunk_paths: - print(f"{RED}❌ No audio chunks found for '{book_name}'{RESET}") - return None - - print(f"πŸ”— Quick combining {len(chunk_paths)} chunks for '{book_name}'...") - - # Use same logic as main function but without interactive prompts - combined_wav_path = book_path / f"{book_name}_quick_combined.wav" - final_m4b_path = book_path / f"{book_name}_quick_combined.m4b" - - combine_audio_chunks(chunk_paths, combined_wav_path) - - temp_m4b_path = book_path / "temp_quick.m4b" - convert_to_m4b(combined_wav_path, temp_m4b_path) - - # Simple M4B without metadata for quick operation - temp_m4b_path.rename(final_m4b_path) - - print(f"βœ… Quick combine complete: {final_m4b_path}") - return final_m4b_path - -def apply_playback_speed_to_m4b(input_m4b_path, output_m4b_path, speed_factor): - """Apply playback speed adjustment to M4B file using ffmpeg""" - try: - print(f"πŸ”„ Applying {speed_factor}x speed to {Path(input_m4b_path).name}") - - # Check if ffmpeg is available - if not shutil.which('ffmpeg'): - print("❌ ffmpeg not found - required for M4B speed adjustment") - return False - - # Build ffmpeg command for speed adjustment - cmd = [ - 'ffmpeg', '-y', # -y to overwrite output file - '-i', str(input_m4b_path), - '-filter:a', f'atempo={speed_factor}', # Audio speed adjustment - '-c:a', 'aac', # Re-encode to AAC for M4B compatibility - '-b:a', '64k', # Audio bitrate - str(output_m4b_path) - ] - - print(f"Running: {' '.join(cmd)}") - - # Execute ffmpeg command - result = subprocess.run( - cmd, - capture_output=True, - text=True, - timeout=300 # 5 minute timeout - ) - - if result.returncode == 0: - print(f"βœ… Successfully created speed-adjusted M4B: {Path(output_m4b_path).name}") - return True - else: - print(f"❌ ffmpeg failed: {result.stderr}") - return False - - except subprocess.TimeoutExpired: - print("❌ M4B speed adjustment timed out") - return False - except Exception as e: - print(f"❌ Error adjusting M4B speed: {e}") - return False - -if __name__ == "__main__": - import sys - - if len(sys.argv) > 1: - # CLI usage: python combine_only.py "Book Name" - book_name = sys.argv[1] - quick_combine(book_name) - else: - # Interactive mode - run_combine_only_mode() diff --git a/HF_Deploy/utils/abbreviations.txt b/HF_Deploy/utils/abbreviations.txt deleted file mode 100644 index 7d5a830cc503c876da29132e19aaa77f8083e993..0000000000000000000000000000000000000000 --- a/HF_Deploy/utils/abbreviations.txt +++ /dev/null @@ -1,11 +0,0 @@ -Dr. -> Doctor -Mr. -> Mister -Mrs. -> Missus -Ms. -> Miss -U.S. -> US -U.K. -> UK -etc. -> et cetera -vs. -> versus -1st -> first -2nd -> second -3rd -> third diff --git a/HF_Deploy/utils/generate_from_json.py b/HF_Deploy/utils/generate_from_json.py deleted file mode 100644 index 32a1c7545878c3dacea2e2aec805310fbfcb7f32..0000000000000000000000000000000000000000 --- a/HF_Deploy/utils/generate_from_json.py +++ /dev/null @@ -1,146 +0,0 @@ -#!/usr/bin/env python3 -""" -Direct Audio Generation from JSON Tool - -This script allows for generating audiobook chunks directly from a pre-existing -`chunks_info.json` file. It is intended for debugging and testing purposes, -allowing a user to manually edit the TTS parameters in the JSON file and -hear the results without the VADER analysis step. -""" - -import torch -from pathlib import Path -import sys -from concurrent.futures import ThreadPoolExecutor, as_completed -import time -from datetime import timedelta - -# Add project root to path to allow module imports -project_root = Path(__file__).parent -sys.path.append(str(project_root)) - -from config.config import * -from modules.tts_engine import load_optimized_model, process_one_chunk, prewarm_model_with_voice -from modules.file_manager import setup_book_directories, list_voice_samples, ensure_voice_sample_compatibility -from wrapper.chunk_loader import load_chunks -from chatterbox.tts import punc_norm -from modules.progress_tracker import log_chunk_progress, log_run - -def main(): - """Main function to drive the generation process.""" - print(f"{BOLD}{CYAN}--- Direct Audio Generation from JSON Tool ---\{RESET}") - - # 1. Get Book Name - book_name = input("Enter the book name (e.g., 'london'): ").strip() - if not book_name: - print("❌ Book name cannot be empty.") - return - - # 2. Locate and Load JSON - book_audio_dir = AUDIOBOOK_ROOT / book_name - json_path = book_audio_dir / "TTS" / "text_chunks" / "chunks_info.json" - - if not json_path.exists(): - print(f"❌ Error: JSON file not found at {json_path}") - print("Please ensure you have run the 'Prepare text file' option for this book first.") - return - - print(f"πŸ“– Loading chunks from: {json_path}") - all_chunks = load_chunks(str(json_path)) - print(f"βœ… Found {len(all_chunks)} chunks.") - - # 3. Select Voice - voice_files = list_voice_samples() - if not voice_files: - print(f"❌ No voice samples found in {VOICE_SAMPLES_DIR}") - return - - print("\nAvailable voices:") - for i, voice_file in enumerate(voice_files, 1): - print(f" [{i}] {voice_file.stem}") - - while True: - try: - choice = input("Select voice number: ").strip() - idx = int(choice) - 1 - if 0 <= idx < len(voice_files): - voice_path = voice_files[idx] - break - print("Invalid selection.") - except (ValueError, IndexError): - print("Invalid selection.") - - # Ensure voice compatibility - voice_path = ensure_voice_sample_compatibility(voice_path) - - # 4. Setup Environment - if torch.cuda.is_available(): - device = "cuda" - elif torch.backends.mps.is_available(): - device = "mps" - else: - device = "cpu" - - print(f"\nπŸš€ Using device: {device}") - print(f"🎀 Using voice: {Path(voice_path).name}") - - # 5. Load Model - model = load_optimized_model(device) - - # 6. Pre-warm model to eliminate first chunk quality variations - print(f"πŸ”₯ Pre-warming model with voice sample: {Path(voice_path).name}") - from modules.tts_engine import prewarm_model_with_voice - compatible_voice = ensure_voice_sample_compatibility(voice_path) - # Use default TTS params for pre-warming since we don't have user params here - model = prewarm_model_with_voice(model, compatible_voice, None) - - # 7. Process Chunks - output_root, tts_dir, text_chunks_dir, audio_chunks_dir = setup_book_directories(Path(TEXT_INPUT_ROOT) / book_name) - - # Clean existing audio chunks - print("🧹 Clearing old audio chunks...") - for wav_file in audio_chunks_dir.glob("*.wav"): - wav_file.unlink() - - start_time = time.time() - total_chunks = len(all_chunks) - log_path = output_root / "debug_generation.log" - - print(f"\nπŸ”„ Generating {total_chunks} chunks...") - - with ThreadPoolExecutor(max_workers=2) as executor: # Test parallel processing - futures = [] - for i, chunk_data in enumerate(all_chunks): - # Extract exaggeration from JSON, force others to default - chunk_tts_params = { - "exaggeration": chunk_data.get("tts_params", {}).get("exaggeration", DEFAULT_EXAGGERATION), - "cfg_weight": DEFAULT_CFG_WEIGHT, - "temperature": DEFAULT_TEMPERATURE - } - - future = executor.submit( - process_one_chunk, - i, chunk_data['text'], text_chunks_dir, audio_chunks_dir, - voice_path, chunk_tts_params, start_time, total_chunks, - punc_norm, book_name, log_run, log_path, device, - model, None, all_chunks, chunk_data['boundary_type'] - ) - futures.append(future) - - for future in as_completed(futures): - try: - result = future.result() - if result: - idx, _ = result - log_chunk_progress(idx, total_chunks, start_time, 0) - except Exception as e: - print(f"\n❌ An error occurred while processing a chunk: {e}") - - elapsed_time = time.time() - start_time - print(f"\n{GREEN}βœ… Generation Complete!{RESET}") - print(f"⏱️ Total time: {timedelta(seconds=int(elapsed_time))}") - print(f"πŸ”Š Audio chunks are in: {audio_chunks_dir}") - print("You can now use Option 3 from the main menu to combine them.") - -if __name__ == "__main__": - main() diff --git a/HF_Deploy/wrapper/chunk_editor.py b/HF_Deploy/wrapper/chunk_editor.py deleted file mode 100644 index a94d19b65efa4119ddcbbf59d06de1c7b5b21b76..0000000000000000000000000000000000000000 --- a/HF_Deploy/wrapper/chunk_editor.py +++ /dev/null @@ -1,8 +0,0 @@ -def update_chunk(chunk, boundary_type=None, pause_duration=None, sentiment_score=None): - if boundary_type is not None: - chunk['boundary_type'] = boundary_type - if pause_duration is not None: - chunk['pause_duration'] = pause_duration - if sentiment_score is not None: - chunk['sentiment_score'] = sentiment_score - return chunk diff --git a/HF_Deploy/wrapper/chunk_loader.py b/HF_Deploy/wrapper/chunk_loader.py deleted file mode 100644 index edc1294e72858ccefe3fabc41060d1e6a411285d..0000000000000000000000000000000000000000 --- a/HF_Deploy/wrapper/chunk_loader.py +++ /dev/null @@ -1,72 +0,0 @@ -import json - -def load_chunks(path): - with open(path, 'r', encoding='utf-8') as f: - data = json.load(f) - - # Filter out metadata entries (they start with _metadata: True) - if isinstance(data, list): - chunks = [item for item in data if not (isinstance(item, dict) and item.get('_metadata', False))] - return chunks - - return data - -def load_metadata(path): - """Extract metadata from JSON file""" - try: - with open(path, 'r', encoding='utf-8') as f: - data = json.load(f) - - if isinstance(data, list) and data: - # Look for metadata in first element - first_item = data[0] - if isinstance(first_item, dict) and first_item.get('_metadata', False): - return first_item - - except Exception as e: - print(f"⚠️ Error loading metadata from {path}: {e}") - - return None - -def save_chunks(path, chunks): - # Validate and clean chunks before saving - from collections import OrderedDict - import copy - - cleaned_chunks = [] - for chunk in chunks: - if isinstance(chunk, dict) and 'text' in chunk: - original_text = chunk['text'] - # Clean up any quote corruption - cleaned_text = original_text.replace('\\"', '"').replace("\\'", "'") - - # Check for dialogue corruption patterns - if ('replied' in cleaned_text or 'said' in cleaned_text) and '"' in cleaned_text: - # Additional cleanup for dialogue - import re - cleaned_text = re.sub(r'(["\'])\s*,\s*(["\'])\s*\.', r'\1.', cleaned_text) # Fix ", ". pattern - cleaned_text = re.sub(r'(["\'])\s*,\s*(["\'])\s*$', r'\1.', cleaned_text) # Fix trailing ", " - - if cleaned_text != original_text: - print(f"πŸ”§ FIXED dialogue corruption:") - print(f" Before: {original_text}") - print(f" After: {cleaned_text}") - - # Preserve structure (OrderedDict or regular dict) - if isinstance(chunk, OrderedDict): - chunk_copy = OrderedDict() - for key, value in chunk.items(): - if key == 'text': - chunk_copy[key] = cleaned_text - else: - chunk_copy[key] = copy.deepcopy(value) - else: - chunk_copy = chunk.copy() - chunk_copy['text'] = cleaned_text - - cleaned_chunks.append(chunk_copy) - else: - cleaned_chunks.append(chunk) - - with open(path, 'w', encoding='utf-8') as f: - json.dump(cleaned_chunks, f, indent=2, ensure_ascii=False) diff --git a/HF_Deploy/wrapper/chunk_player.py b/HF_Deploy/wrapper/chunk_player.py deleted file mode 100644 index 2f915e4a6a7cca5afba4534643d6989091a22a47..0000000000000000000000000000000000000000 --- a/HF_Deploy/wrapper/chunk_player.py +++ /dev/null @@ -1,12 +0,0 @@ -import subprocess -import os - -def play_chunk_audio(path): - if not os.path.exists(path): - print(f"❌ Audio file not found: {path}") - return - try: - subprocess.run(["ffplay", "-nodisp", "-autoexit", path], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) - except Exception as e: - print(f"Error playing audio: {e}") - diff --git a/HF_Deploy/wrapper/chunk_revisions.py b/HF_Deploy/wrapper/chunk_revisions.py deleted file mode 100644 index 3432446c06ae8c38a191fe09ee37fd52f1787f37..0000000000000000000000000000000000000000 --- a/HF_Deploy/wrapper/chunk_revisions.py +++ /dev/null @@ -1,34 +0,0 @@ -import os -import shutil -from pathlib import Path -from config.config import AUDIOBOOK_ROOT -base = AUDIOBOOK_ROOT - - -def accept_revision(index, audio_dir): - """ - Archive original chunk and replace with revised version. - Assumes revised version is saved as: chunk_XXXXX_rev.wav - """ - base = Path(audio_dir) - # Use 1-based indexing and 5-digit format - original = base / f"chunk_{index+1:05d}.wav" - revised = base / f"chunk_{index+1:05d}_rev.wav" - archive_dir = base.parent.parent / "Audio_Revisions" - archive_dir.mkdir(exist_ok=True) - - if not revised.exists(): - print("❌ No revised file found. Cannot accept.") - return - - # Archive original if exists - if original.exists(): - archived = archive_dir / f"chunk_{index+1:05d}_orig.wav" - shutil.move(str(original), str(archived)) - print(f"πŸ“¦ Original chunk archived to {archived.name}") - else: - print(f"⚠️ Original chunk missing β€” no archive created.") - - # Move revised chunk to main filename - shutil.move(str(revised), str(original)) - print(f"βœ… Revised chunk accepted as {original.name}") diff --git a/HF_Deploy/wrapper/chunk_search.py b/HF_Deploy/wrapper/chunk_search.py deleted file mode 100644 index 81b4f5781dc355b168179364ef95709926dc98c2..0000000000000000000000000000000000000000 --- a/HF_Deploy/wrapper/chunk_search.py +++ /dev/null @@ -1,9 +0,0 @@ -def search_chunks(chunks, query): - results = [] - query_lower = query.lower() - - for chunk in chunks: - if query_lower in chunk['text'].lower(): - results.append(chunk) - - return results diff --git a/HF_Deploy/wrapper/chunk_synthesizer.py b/HF_Deploy/wrapper/chunk_synthesizer.py deleted file mode 100644 index 2567c55677a5d4a763f586921ee5190334378be9..0000000000000000000000000000000000000000 --- a/HF_Deploy/wrapper/chunk_synthesizer.py +++ /dev/null @@ -1,208 +0,0 @@ -from pathlib import Path -import torch -import time -import re -from pydub import AudioSegment - -from modules.tts_engine import load_optimized_model -from modules.file_manager import ensure_voice_sample_compatibility, list_voice_samples -from modules.audio_processor import apply_smart_fade_memory, smart_audio_validation_memory, process_audio_with_trimming_and_silence -from config.config import * - -def get_original_voice_from_log(book_name): - """Extract original voice name from run log""" - audiobook_root = Path(AUDIOBOOK_ROOT) - log_file = audiobook_root / book_name / "run.log" - - if log_file.exists(): - try: - with open(log_file, 'r', encoding='utf-8') as f: - for line in f: - line = line.strip() - if line.startswith("Voice: ") or line.startswith("Voice used: "): - voice_name = line.split(": ", 1)[1].strip() - print(f"πŸ“„ Found original voice in log: {voice_name}") - return voice_name - except Exception as e: - print(f"⚠️ Error reading run log: {e}") - - return None - -def get_original_voice_from_filename(book_name): - """Extract voice name from existing audiobook filename""" - audiobook_root = Path(AUDIOBOOK_ROOT) - book_dir = audiobook_root / book_name - - # Look for WAV files with voice pattern: BookName [VoiceName].wav - for wav_file in book_dir.glob("*.wav"): - match = re.search(r'\[([^\]]+)\]\.wav$', wav_file.name) - if match: - voice_name = match.group(1) - print(f"πŸ“ Found original voice in filename: {voice_name}") - return voice_name - - # Look for M4B files with voice pattern: BookName[VoiceName].m4b - for m4b_file in book_dir.glob("*.m4b"): - match = re.search(r'\[([^\]]+)\]\.m4b$', m4b_file.name) - if match: - voice_name = match.group(1) - print(f"πŸ“ Found original voice in M4B filename: {voice_name}") - return voice_name - - return None - -def find_voice_file_by_name(voice_name): - """Find voice file by name in Voice_Samples directory""" - voice_files = list_voice_samples() - - # Exact match first - for voice_file in voice_files: - if voice_file.stem == voice_name: - print(f"βœ… Found exact voice match: {voice_file.name}") - return voice_file - - # Partial match (case insensitive) - voice_name_lower = voice_name.lower() - for voice_file in voice_files: - if voice_name_lower in voice_file.stem.lower(): - print(f"βœ… Found partial voice match: {voice_file.name}") - return voice_file - - return None - -def get_tts_params_for_chunk(chunk): - """Extract TTS parameters from chunk data or prompt user""" - # Check if chunk has TTS params stored - if 'tts_params' in chunk: - tts_params = chunk['tts_params'] - print(f"πŸ“Š Using stored TTS params: exag={tts_params.get('exaggeration', 1.0)}, cfg={tts_params.get('cfg_weight', 0.7)}, temp={tts_params.get('temperature', 0.7)}") - return tts_params - - # Prompt user for TTS parameters - print(f"\nβš™οΈ TTS Parameters for chunk synthesis:") - - def get_float_input(prompt, default): - while True: - try: - value = input(f"{prompt} [{default}]: ").strip() - if not value: - return default - return float(value) - except ValueError: - print(f"❌ Invalid input. Please enter a valid number.") - - exaggeration = get_float_input("Exaggeration", DEFAULT_EXAGGERATION) - cfg_weight = get_float_input("CFG Weight", DEFAULT_CFG_WEIGHT) - temperature = get_float_input("Temperature", DEFAULT_TEMPERATURE) - - return { - 'exaggeration': exaggeration, - 'cfg_weight': cfg_weight, - 'temperature': temperature - } - -def synthesize_chunk(chunk, index, book_name, audio_dir, revision=False, chunks_json_path=None, override_voice_name=None): - """Generate audio for a single chunk using specified or detected voice and TTS parameters""" - filename = f"chunk_{index+1:05d}_rev.wav" if revision else f"chunk_{index+1:05d}.wav" - out_path = Path(audio_dir) / filename - - try: - # Get device - device = "cuda" if torch.cuda.is_available() else "cpu" - - # Load TTS model - print(f"πŸ€– Loading TTS model for chunk synthesis...") - model = load_optimized_model(device) - - # Determine voice to use - if override_voice_name: - # Use explicitly provided voice - print(f"🎀 Using explicitly selected voice: {override_voice_name}") - voice_path = find_voice_file_by_name(override_voice_name) - voice_name = override_voice_name - detection_method = "user_selected" - else: - # Use enhanced voice detection - print(f"πŸ” Detecting original voice for book: {book_name}") - from modules.voice_detector import detect_voice_for_book - - voice_name, voice_path, detection_method = detect_voice_for_book(book_name, chunks_json_path) - - # Fallback to first available voice if detection failed - if not voice_path: - print(f"⚠️ Voice not found, using first available voice") - voice_files = list_voice_samples() - if not voice_files: - print("❌ No voice samples found") - return None - voice_path = voice_files[0] - voice_name = voice_path.stem - detection_method = "fallback_first_available" - - print(f"🎀 Using voice: {voice_name} (method: {detection_method})") - compatible_voice = ensure_voice_sample_compatibility(voice_path) - - # Get TTS parameters for this chunk - tts_params = get_tts_params_for_chunk(chunk) - - # Prepare model with voice - model.prepare_conditionals(compatible_voice) - - # Get chunk text - chunk_text = chunk.get('text', '') - if not chunk_text: - print("❌ No text found in chunk") - return None - - print(f"🎀 Synthesizing: {chunk_text[:50]}...") - print(f"πŸ“Š TTS params: exag={tts_params['exaggeration']}, cfg={tts_params['cfg_weight']}, temp={tts_params['temperature']}") - - # Generate audio with specified parameters - with torch.no_grad(): - wav = model.generate(chunk_text, - exaggeration=tts_params['exaggeration'], - cfg_weight=tts_params['cfg_weight'], - temperature=tts_params['temperature']).detach().cpu() - - if wav.dim() == 1: - wav = wav.unsqueeze(0) - - # Convert tensor to AudioSegment for processing - import io - import soundfile as sf - - wav_np = wav.squeeze().numpy() - with io.BytesIO() as wav_buffer: - sf.write(wav_buffer, wav_np, model.sr, format='wav') - wav_buffer.seek(0) - audio_segment = AudioSegment.from_wav(wav_buffer) - - # Apply audio processing - audio_segment = apply_smart_fade_memory(audio_segment) - audio_segment, is_quarantined = smart_audio_validation_memory(audio_segment, model.sr) - - # Apply trimming and contextual silence based on boundary type - boundary_type = chunk.get('boundary_type', 'none') - if boundary_type and boundary_type != "none": - audio_segment = process_audio_with_trimming_and_silence(audio_segment, boundary_type) - else: - # Apply trimming even without boundary type if enabled - if ENABLE_AUDIO_TRIMMING: - from modules.audio_processor import trim_audio_endpoint - audio_segment = trim_audio_endpoint(audio_segment) - - # Save final audio - audio_segment.export(out_path, format="wav") - print(f"βœ… Saved synthesized chunk: {out_path.name}") - - # Clean up model - del model - torch.cuda.empty_cache() - - return str(out_path) - - except Exception as e: - print(f"❌ Failed to synthesize chunk: {e}") - import traceback - traceback.print_exc() - return None diff --git a/HF_Deploy/wrapper/chunk_tool.py b/HF_Deploy/wrapper/chunk_tool.py deleted file mode 100644 index c529bb13220f66e22013e8bbda393292c5a51a45..0000000000000000000000000000000000000000 --- a/HF_Deploy/wrapper/chunk_tool.py +++ /dev/null @@ -1,249 +0,0 @@ -from wrapper.chunk_loader import load_chunks, save_chunks -from wrapper.chunk_search import search_chunks -from wrapper.chunk_editor import update_chunk -from wrapper.chunk_player import play_chunk_audio -from wrapper.chunk_synthesizer import synthesize_chunk -from wrapper.chunk_revisions import accept_revision -import os -from config.config import AUDIOBOOK_ROOT -AUDIO_DIR = AUDIOBOOK_ROOT - -def select_book_for_repair(): - """Let user select which book to repair""" - from pathlib import Path - - # Look for books in both locations: TTS processing dirs and Text_Input - available_books = [] - - # First check TTS processing directories - audiobook_root = Path(AUDIOBOOK_ROOT) - if audiobook_root.exists(): - for book_dir in audiobook_root.iterdir(): - if book_dir.is_dir(): - tts_chunks_dir = book_dir / "TTS" / "text_chunks" - json_path = tts_chunks_dir / "chunks_info.json" - if json_path.exists(): - available_books.append((book_dir.name, json_path, "TTS")) - - # Then check Text_Input directory for fallback - text_input_dir = Path("Text_Input") - if text_input_dir.exists(): - for chunk_file in text_input_dir.glob("*_chunks.json"): - book_name = chunk_file.stem.replace("_chunks", "") - # Only add if not already found in TTS directories - if not any(book[0] == book_name for book in available_books): - available_books.append((book_name, chunk_file, "Text_Input")) - - if not available_books: - print("❌ No chunk files found in TTS processing directories or Text_Input/") - return None, None - - print("\nπŸ“š Available books for repair:") - for i, (book_name, json_path, source) in enumerate(available_books): - print(f" [{i}] {book_name} ({source}: {json_path.name})") - - while True: - try: - choice = input(f"\nSelect book index [0-{len(available_books)-1}]: ").strip() - idx = int(choice) - if 0 <= idx < len(available_books): - book_name, json_path, source = available_books[idx] - return book_name, json_path - else: - print(f"❌ Please enter a number between 0 and {len(available_books)-1}") - except (ValueError, EOFError, KeyboardInterrupt): - print("❌ Invalid selection or cancelled") - return None, None - -def run_chunk_repair_tool(): - print("\nπŸ› οΈ Chunk Repair & Revision Tool") - - # Ask user to select book - book_name, chunk_path = select_book_for_repair() - if not chunk_path: - return - - print(f"\nπŸ“– Loading chunks from: {chunk_path.name}") - chunks = load_chunks(str(chunk_path)) - - # Determine audio directory path based on book structure - from pathlib import Path - audiobook_root = Path(AUDIOBOOK_ROOT) - book_audio_dir = audiobook_root / book_name / "TTS" / "audio_chunks" - - if not book_audio_dir.exists(): - print(f"❌ Audio directory not found: {book_audio_dir}") - print(f"πŸ“ Looked for: {book_audio_dir}") - return - - print(f"πŸ“ Using audio directory: {book_audio_dir}") - - while True: - query = input("\nSearch for text fragment (or 'Q' to quit): ").strip() - if query.lower() == "q": - print("Exiting revision tool.") - break - - results = search_chunks(chunks, query) - if not results: - print("❌ No matching chunks found.") - continue - - print(f"\nπŸ” Found {len(results)} match(es):") - for i, chunk in enumerate(results): - print(f"[{i}] \"{chunk['text'][:60]}...\" | Index: {chunk['index']}") - - sel = input("Select chunk index to revise: ").strip() - if not sel.isdigit() or int(sel) >= len(results): - print("Invalid selection.") - continue - - chunk = results[int(sel)] - index = chunk['index'] - # Use 5-digit chunk numbering and correct directory path - chunk_audio_path = book_audio_dir / f"chunk_{index+1:05d}.wav" - chunk_audio_path_str = str(chunk_audio_path) - - while True: - print(f"\nπŸ“ Chunk: \"{chunk['text']}\"") - - # Display current chunk metadata - sentiment_compound = chunk.get('sentiment_compound', chunk.get('sentiment_score', 'N/A')) - tts_params = chunk.get('tts_params', {}) - - print(f" πŸ“ Index: {index}, Boundary: {chunk['boundary_type']}") - print(f" 😊 Sentiment: {sentiment_compound}") - print(f" πŸŽ›οΈ TTS Params: exag={tts_params.get('exaggeration', 'N/A')}, cfg={tts_params.get('cfg_weight', 'N/A')}, temp={tts_params.get('temperature', 'N/A')}") - print(f" πŸ“ Audio file: chunk_{index+1:05d}.wav") - print("\nOptions:") - print(" 1. Play original audio") - print(" 2. Edit text content") - print(" 3. Edit chunk metadata (boundary, sentiment)") - print(" 4. Edit TTS parameters (exaggeration, cfg_weight, temperature)") - print(" 5. Resynthesize audio with current settings") - print(" 6. Play revised audio") - print(" 7. Accept revision (replace original with revised)") - print(" 8. Back to search") - - try: - choice = input("\nπŸ’‘ Enter option number [1-8]: ").strip() - except (EOFError, KeyboardInterrupt): - print("\n❌ Input cancelled") - return - if choice == "1": - print(f"\nπŸ”Š Playing original audio: {chunk_audio_path.name}") - play_chunk_audio(chunk_audio_path_str) - elif choice == "2": - print("\n✏️ Edit Text Content:") - print(f"Current text: \"{chunk['text']}\"") - print("πŸ’‘ Enter new text (or Enter to cancel):") - new_text = input(">>> ").strip() - - if new_text: - chunk['text'] = new_text - chunk['word_count'] = len(new_text.split()) - save_chunks(str(chunk_path), chunks) - print("βœ… Text content updated successfully") - print(f"πŸ“Š New word count: {chunk['word_count']}") - else: - print("❌ No changes made") - elif choice == "3": - print("\n✏️ Edit Chunk Metadata:") - print(f"Current boundary type: {chunk['boundary_type']}") - boundary = input("New boundary type (none/paragraph_end/chapter_start/chapter_end/section_break) [Enter to skip]: ").strip() - - current_sentiment = chunk.get('sentiment_compound', chunk.get('sentiment_score', 'N/A')) - print(f"Current sentiment score: {current_sentiment}") - sentiment = input("New sentiment compound score (-1.0 to 1.0) [Enter to skip]: ").strip() - - try: - if boundary: - chunk['boundary_type'] = boundary - print(f"βœ… Updated boundary type to: {boundary}") - - if sentiment: - sentiment_val = float(sentiment) - if -1.0 <= sentiment_val <= 1.0: - chunk['sentiment_compound'] = sentiment_val - # Also update old key for compatibility - chunk['sentiment_score'] = sentiment_val - print(f"βœ… Updated sentiment score to: {sentiment_val}") - else: - print("❌ Sentiment score must be between -1.0 and 1.0") - - save_chunks(str(chunk_path), chunks) - print("βœ… Chunk metadata updated successfully") - except ValueError as e: - print(f"❌ Invalid input: {e}") - except Exception as e: - print(f"❌ Error updating chunk: {e}") - elif choice == "4": - print("\nπŸŽ›οΈ Edit TTS Parameters:") - current_tts_params = chunk.get('tts_params', {}) - - def get_float_input(param_name, current_val, min_val=None, max_val=None): - while True: - try: - prompt = f"New {param_name} [{current_val}]: " - value = input(prompt).strip() - if not value: - return current_val - new_val = float(value) - if min_val is not None and new_val < min_val: - print(f"❌ {param_name} must be >= {min_val}") - continue - if max_val is not None and new_val > max_val: - print(f"❌ {param_name} must be <= {max_val}") - continue - return new_val - except ValueError: - print(f"❌ Invalid input. Please enter a valid number.") - - # Edit TTS parameters - print(f"Current TTS parameters:") - current_exag = current_tts_params.get('exaggeration', 1.0) - current_cfg = current_tts_params.get('cfg_weight', 0.7) - current_temp = current_tts_params.get('temperature', 0.7) - - print(f" Exaggeration: {current_exag}") - print(f" CFG Weight: {current_cfg}") - print(f" Temperature: {current_temp}") - - new_exag = get_float_input("exaggeration", current_exag, 0.0, 3.0) - new_cfg = get_float_input("CFG weight", current_cfg, 0.0, 2.0) - new_temp = get_float_input("temperature", current_temp, 0.0, 2.0) - - # Update chunk TTS parameters - if 'tts_params' not in chunk: - chunk['tts_params'] = {} - - chunk['tts_params']['exaggeration'] = new_exag - chunk['tts_params']['cfg_weight'] = new_cfg - chunk['tts_params']['temperature'] = new_temp - - save_chunks(str(chunk_path), chunks) - print(f"βœ… TTS parameters updated: exag={new_exag}, cfg={new_cfg}, temp={new_temp}") - elif choice == "5": - print(f"\n🎀 Resynthesizing chunk {index+1:05d}...") - revised_path = synthesize_chunk(chunk, index, book_name, book_audio_dir, revision=True) - if revised_path: - print(f"βœ… Chunk resynthesized: {revised_path}") - else: - print("❌ Failed to resynthesize chunk") - elif choice == "6": - rev_path = book_audio_dir / f"chunk_{index+1:05d}_rev.wav" - print(f"\nπŸ”Š Playing revised audio: {rev_path.name}") - play_chunk_audio(str(rev_path)) - elif choice == "7": - print(f"\nπŸ“¦ Accepting revision for chunk {index+1:05d}...") - accept_revision(index, book_audio_dir) - print("βœ… Revision accepted successfully") - break - elif choice == "8": - print("πŸ”™ Returning to search...") - break - elif choice.lower() == 'q': - print("πŸšͺ Exiting chunk repair tool...") - return - else: - print(f"❌ Invalid option '{choice}'. Please enter a number 1-8 (or 'q' to quit).") diff --git a/HF_Deploy/wrapper/chunk_tool.py~ b/HF_Deploy/wrapper/chunk_tool.py~ deleted file mode 100644 index 91e8fda8e16ea37f99ff5111a4db7d4afb0129db..0000000000000000000000000000000000000000 --- a/HF_Deploy/wrapper/chunk_tool.py~ +++ /dev/null @@ -1,79 +0,0 @@ -from wrapper.chunk_loader import load_chunks, save_chunks -from wrapper.chunk_search import search_chunks -from wrapper.chunk_editor import update_chunk -from wrapper.chunk_player import play_chunk_audio -from wrapper.chunk_synthesizer import synthesize_chunk -from wrapper.chunk_revisions import accept_revision -import os -from config.config import AUDIOBOOK_ROOT -AUDIO_DIR = AUDIOBOOK_ROO - -CHUNK_PATH = "Text_Input/my_book_chunks.json" - - -def run_chunk_repair_tool(): - print("\nπŸ› οΈ Chunk Repair & Revision Tool") - chunks = load_chunks(CHUNK_PATH) - - while True: - query = input("\nSearch for text fragment (or 'Q' to quit): ").strip() - if query.lower() == "q": - print("Exiting revision tool.") - break - - results = search_chunks(chunks, query) - if not results: - print("❌ No matching chunks found.") - continue - - print(f"\nπŸ” Found {len(results)} match(es):") - for i, chunk in enumerate(results): - print(f"[{i}] \"{chunk['text'][:60]}...\" | Index: {chunk['index']}") - - sel = input("Select chunk index to revise: ").strip() - if not sel.isdigit() or int(sel) >= len(results): - print("Invalid selection.") - continue - - chunk = results[int(sel)] - index = chunk['index'] - chunk_path = os.path.join(AUDIO_DIR, f"chunk_{index:03}.wav") - - while True: - print(f"\nπŸ“ Chunk: \"{chunk['text']}\"") - print(f" Boundary: {chunk['boundary_type']}, Sentiment: {chunk.get('sentiment_score', 'N/A')}, Pause: {chunk.get('pause_duration', 'N/A')}") - print("\nOptions:") - print(" 1. Play original") - print(" 2. Edit values") - print(" 3. Resynthesize") - print(" 4. Play revised") - print(" 5. Accept revision") - print(" 6. Back to search") - - choice = input("Enter option number: ").strip() - if choice == "1": - play_chunk_audio(chunk_path) - elif choice == "2": - boundary = input("New boundary type (or Enter to skip): ").strip() - sentiment = input("New sentiment score (or Enter to skip): ").strip() - pause = input("New pause duration (or Enter to skip): ").strip() - - update_chunk( - chunk, - boundary_type=boundary if boundary else None, - sentiment_score=float(sentiment) if sentiment else None, - pause_duration=float(pause) if pause else None - ) - save_chunks(CHUNK_PATH, chunks) - elif choice == "3": - synthesize_chunk(chunk, index, revision=True) - elif choice == "4": - rev_path = os.path.join(AUDIO_DIR, f"chunk_{index:03}_rev.wav") - play_chunk_audio(rev_path) - elif choice == "5": - accept_revision(index) - break - elif choice == "6": - break - else: - print("Invalid input. Try again.") diff --git a/config/config.py b/config/config.py deleted file mode 100644 index cf1d3a0ed877541d16e542a5571ea6dcbf174fe4..0000000000000000000000000000000000000000 --- a/config/config.py +++ /dev/null @@ -1,159 +0,0 @@ -""" -GenTTS Configuration Module -Central location for all settings, paths, and feature toggles -""" - -import os -from pathlib import Path - -# ============================================================================ -# CORE DIRECTORIES -# ============================================================================ -TEXT_INPUT_ROOT = Path("Text_Input") -AUDIOBOOK_ROOT = Path("Audiobook") -VOICE_SAMPLES_DIR = Path("Voice_Samples") - -# ============================================================================ -# TEXT PROCESSING SETTINGS -# ============================================================================ -MAX_CHUNK_WORDS = 28 -MIN_CHUNK_WORDS = 4 - -# ============================================================================ -# WORKER AND PERFORMANCE SETTINGS -# ============================================================================ -MAX_WORKERS = 2 # Keep at 2 - GPU utilization already high -TEST_MAX_WORKERS = 6 # For experimentation -USE_DYNAMIC_WORKERS = False # Toggle for testing -VRAM_SAFETY_THRESHOLD = 6.5 # GB - -# ============================================================================ -# AUDIO QUALITY SETTINGS -# ============================================================================ -ENABLE_MID_DROP_CHECK = False -ENABLE_ASR = False -ASR_WORKERS = 4 # Parallel ASR on CPU threads - -# ============================================================================ -# TTS HUM DETECTION SETTINGS -# ============================================================================ -ENABLE_HUM_DETECTION = False # Disabled for speed (re-enable if quality issues) -HUM_FREQ_MIN = 50 # Hz - Lower frequency bound for hum detection -HUM_FREQ_MAX = 200 # Hz - Upper frequency bound for hum detection -HUM_ENERGY_THRESHOLD = 0.3 # Ratio of hum energy to total energy (0.1-0.5 range) -HUM_STEADY_THRESHOLD = 0.6 # Ratio of segments with steady amplitude (0.5-0.8 range) -HUM_AMPLITUDE_MIN = 0.005 # Minimum RMS for steady hum detection -HUM_AMPLITUDE_MAX = 0.1 # Maximum RMS for steady hum detection - -# ============================================================================ -# AUDIO TRIMMING SETTINGS -# ============================================================================ -ENABLE_AUDIO_TRIMMING = True # Enable automatic audio trimming after TTS -SPEECH_ENDPOINT_THRESHOLD = 0.005 # RMS threshold to detect end of speech (more aggressive) -TRIMMING_BUFFER_MS = 50 # Small buffer after detected speech endpoint - -# ============================================================================ -# SILENCE DURATION SETTINGS (milliseconds) -# ============================================================================ -SILENCE_CHAPTER_START = 500 # Half second for chapter beginnings -SILENCE_CHAPTER_END = 800 # Longer pause before new chapter -SILENCE_SECTION_BREAK = 600 # Section transitions -SILENCE_PARAGRAPH_END = 300 # Standard paragraph breaks - -# Punctuation-specific silence settings (milliseconds) -SILENCE_COMMA = 150 # Brief pause after commas -SILENCE_SEMICOLON = 250 # Medium pause after semicolons -SILENCE_COLON = 300 # Pause after colons -SILENCE_PERIOD = 400 # Sentence end pause -SILENCE_QUESTION_MARK = 450 # Question pause (slightly longer) -SILENCE_EXCLAMATION = 400 # Exclamation pause -SILENCE_DASH = 200 # Em dash pause -SILENCE_ELLIPSIS = 350 # Ellipsis pause (suspense) -SILENCE_QUOTE_END = 250 # End of quoted speech - -# Chunk-level silence settings -ENABLE_CHUNK_END_SILENCE = True # Add silence to end of every chunk -CHUNK_END_SILENCE_MS = 200 # Default silence at end of each chunk - -# Content boundary silence settings (milliseconds) -SILENCE_PARAGRAPH_FALLBACK = 500 # Original paragraph logic fallback - -# ============================================================================ -# AUDIO NORMALIZATION SETTINGS -# ============================================================================ -ENABLE_NORMALIZATION = True # Global ON/OFF switch for normalization -NORMALIZATION_TYPE = "peak" # Options: "loudness", "peak", "simple", "none" -TARGET_LUFS = -16 # Target loudness (LUFS) for broadcast standard -TARGET_PEAK_DB = -1.5 # Target peak level (dB) to prevent clipping -TARGET_LRA = 11 # Target loudness range for consistency - -# ============================================================================ -# AUDIO PLAYBACK SPEED SETTINGS -# ============================================================================ -ATEMPO_SPEED = 0.95 # Playback speed multiplier (0.5-2.0 range, 1.0 = normal speed) - -# ============================================================================ -# ENVIRONMENT SETUP -# ============================================================================ -os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" -os.environ["TRANSFORMERS_NO_PROGRESS_BAR"] = "1" -os.environ["HF_TRANSFORMERS_NO_TQDM"] = "1" -os.environ["TORCH_HUB_DIR"] = "/tmp/torch_hub_silent" - -# ============================================================================ -# COLOR CODES FOR TERMINAL OUTPUT -# ============================================================================ -RESET = "\033[0m" -BOLD = "\033[1m" -RED = "\033[91m" -GREEN = "\033[92m" -YELLOW = "\033[93m" -CYAN = "\033[96m" - -# ============================================================================ -# TTS MODEL PARAMETERS (DEFAULTS) -# ============================================================================ -DEFAULT_EXAGGERATION = 0.4 # Emotion intensity (0.0-2.0 range) -DEFAULT_CFG_WEIGHT = 0.5 # Faithfulness to text (0.0-1.0 range) -DEFAULT_TEMPERATURE = 0.9 # Randomness/creativity (0.0-1.0 range) - -# ============================================================================ -# VADER SENTIMENT TO TTS PARAMETER MAPPING -# ============================================================================ -# These settings control how VADER sentiment analysis dynamically adjusts TTS parameters. -# The formula used is: new_param = base_param + (compound_score * sensitivity) -# The result is then clamped within the defined MIN/MAX range. - -# --- Base TTS Parameters (used as the starting point) --- -# These are the same as the main defaults, but listed here for clarity. -BASE_EXAGGERATION = DEFAULT_EXAGGERATION # Default: 1.0 -BASE_CFG_WEIGHT = DEFAULT_CFG_WEIGHT # Default: 0.7 -BASE_TEMPERATURE = DEFAULT_TEMPERATURE # Default: 0.7 - -# --- Sensitivity --- -# How much VADER's compound score affects each parameter. -# Higher values mean more dramatic changes based on sentiment. -VADER_EXAGGERATION_SENSITIVITY = 0.5 # e.g., compound of 0.8 -> 1.0 + (0.8 * 0.5) = 1.4 -VADER_CFG_WEIGHT_SENSITIVITY = -0.2 # Negative: more emotional text is less strict -VADER_TEMPERATURE_SENSITIVITY = 0.15 # More emotional text gets slightly more creative - -# --- Min/Max Clamps --- -# Hard limits to prevent extreme, undesirable audio artifacts. -TTS_PARAM_MIN_EXAGGERATION = 0.1 -TTS_PARAM_MAX_EXAGGERATION = 2.0 -TTS_PARAM_MIN_CFG_WEIGHT = 0.1 -TTS_PARAM_MAX_CFG_WEIGHT = 1.0 - -TTS_PARAM_MIN_TEMPERATURE = 0.1 -TTS_PARAM_MAX_TEMPERATURE = 5.0 - -# ============================================================================ -# BATCH PROCESSING SETTINGS -# ============================================================================ -BATCH_SIZE = 250 # Larger batches for better speed (monitor VRAM) -CLEANUP_INTERVAL = 500 # Deep cleanup every N chunks (reduced frequency for speed) - -# ============================================================================ -# FEATURE TOGGLES -# ============================================================================ -shutdown_requested = False # Global shutdown flag diff --git a/gradio_main_interface.py b/gradio_main_interface.py index 448e91c21b755a590bcea159c22b3a30e87e12b0..1ec3ee90175a4c2157195bd917c593ace982793c 100644 --- a/gradio_main_interface.py +++ b/gradio_main_interface.py @@ -1,7 +1,40 @@ #!/usr/bin/env python3 """ -ChatterboxTTS DNXS-Spokneword Gradio Main Interface -Modular web interface with separate tab modules +ChatterboxTTS Gradio Web Interface - Main Entry Point +==================================================== + +OVERVIEW: +This is the main web interface for ChatterboxTTS, providing a user-friendly +Gradio-based GUI for audiobook generation. It serves as the primary entry point +for users who prefer web interfaces over command-line tools. + +ARCHITECTURE: +- MODULAR TAB SYSTEM: Each major function is a separate tab module +- GRACEFUL DEGRADATION: Missing tab modules show placeholder pages +- RESPONSIVE DESIGN: Works on desktop and mobile browsers +- IMPORT SAFETY: Handles missing dependencies gracefully + +AVAILABLE TABS: +1. Convert Book (Tab 1) - FUNCTIONAL: Main TTS conversion interface +2. Quick Convert (Tab 2) - PLACEHOLDER: Fast conversion for small texts +3. Voice Analysis (Tab 3) - PLACEHOLDER: Voice sample analysis tools +4. Batch Processing (Tab 4) - PLACEHOLDER: Multi-book processing +5. Audio Tools (Tab 5) - PLACEHOLDER: Audio editing and enhancement +6. Settings (Tab 6) - FUNCTIONAL: Configuration management +7. Chunk Tools (Tab 7) - PLACEHOLDER: Chunk editing and repair +8. Voice Training (Tab 8) - PLACEHOLDER: Voice cloning tools +9. System Monitor (Tab 9) - PLACEHOLDER: Performance monitoring + +DEPLOYMENT MODES: +- LOCAL: python3 gradio_main_interface.py (development) +- HUGGINGFACE SPACES: Called by app.py launcher (production) +- COLAB/RUNPOD: Automatic sharing and port configuration + +TECHNICAL FEATURES: +- Auto-detects HuggingFace Spaces environment +- Configurable sharing and port settings +- Error handling for missing tab modules +- Clean, professional interface design """ import gradio as gr @@ -10,6 +43,7 @@ import os from pathlib import Path # Add the current directory to Python path for imports +# This ensures tab modules can be imported regardless of working directory sys.path.append(str(Path(__file__).parent)) # Import tab modules @@ -20,6 +54,27 @@ except ImportError as e: print(f"⚠️ Tab 1 not available: {e}") TAB1_AVAILABLE = False +try: + from gradio_tabs.tab2_configuration import create_configuration_tab + TAB2_AVAILABLE = True +except ImportError as e: + print(f"⚠️ Tab 2 (Configuration) not available: {e}") + TAB2_AVAILABLE = False + +try: + from gradio_tabs.tab4_combine_audio import create_combine_audio_tab + TAB4_AVAILABLE = True +except ImportError as e: + print(f"⚠️ Tab 4 (Combine Audio) not available: {e}") + TAB4_AVAILABLE = False + +try: + from gradio_tabs.tab5_prepare_text import create_prepare_text_tab + TAB5_AVAILABLE = True +except ImportError as e: + print(f"⚠️ Tab 5 (Prepare Text) not available: {e}") + TAB5_AVAILABLE = False + try: from gradio_tabs.tab6_settings import create_settings_tab_interface TAB6_AVAILABLE = True @@ -27,6 +82,20 @@ except ImportError as e: print(f"⚠️ Tab 6 (Settings) not available: {e}") TAB6_AVAILABLE = False +try: + from gradio_tabs.tab7_chunk_tools import create_chunk_tools_tab + TAB7_AVAILABLE = True +except ImportError as e: + print(f"⚠️ Tab 7 (Chunk Tools) not available: {e}") + TAB7_AVAILABLE = False + +try: + from gradio_tabs.tab8_json_generate import create_json_generate_tab + TAB8_AVAILABLE = True +except ImportError as e: + print(f"⚠️ Tab 8 (JSON Generate) not available: {e}") + TAB8_AVAILABLE = False + def create_placeholder_tab(tab_name, tab_number): """Create a placeholder tab for future implementation""" with gr.Column(): @@ -65,18 +134,32 @@ def create_main_interface(): with gr.Tab("1. Convert Book"): create_placeholder_tab("Convert Book", 1) - # Tab 2-10: Placeholders for now - with gr.Tab("2. File Management"): - create_placeholder_tab("File Management", 2) + # Tab 2: Configuration Settings (Working) + if TAB2_AVAILABLE: + with gr.Tab("2. Configuration"): + create_configuration_tab() + else: + with gr.Tab("2. Configuration"): + create_placeholder_tab("Configuration Settings", 2) with gr.Tab("3. Voice Analysis"): create_placeholder_tab("Voice Analysis", 3) - with gr.Tab("4. Batch Processing"): - create_placeholder_tab("Batch Processing", 4) + # Tab 4: Combine Audio (Working) + if TAB4_AVAILABLE: + with gr.Tab("4. Combine Audio"): + create_combine_audio_tab() + else: + with gr.Tab("4. Combine Audio"): + create_placeholder_tab("Combine Audio", 4) - with gr.Tab("5. Audio Tools"): - create_placeholder_tab("Audio Tools", 5) + # Tab 5: Prepare Text (Working) + if TAB5_AVAILABLE: + with gr.Tab("5. Prepare Text"): + create_prepare_text_tab() + else: + with gr.Tab("5. Prepare Text"): + create_placeholder_tab("Prepare Text", 5) # Tab 6: Settings (Working) if TAB6_AVAILABLE: @@ -86,11 +169,21 @@ def create_main_interface(): with gr.Tab("6. Settings"): create_placeholder_tab("Settings", 6) - with gr.Tab("7. Chunk Tools"): - create_placeholder_tab("Chunk Tools", 7) + # Tab 7: Chunk Tools (Working) + if TAB7_AVAILABLE: + with gr.Tab("7. Chunk Tools"): + create_chunk_tools_tab() + else: + with gr.Tab("7. Chunk Tools"): + create_placeholder_tab("Chunk Tools", 7) - with gr.Tab("8. Voice Training"): - create_placeholder_tab("Voice Training", 8) + # Tab 8: JSON Generate (Working) + if TAB8_AVAILABLE: + with gr.Tab("8. JSON Generate"): + create_json_generate_tab() + else: + with gr.Tab("8. JSON Generate"): + create_placeholder_tab("JSON Generate", 8) with gr.Tab("9. System Monitor"): create_placeholder_tab("System Monitor", 9) @@ -112,7 +205,13 @@ def launch_interface(): print("πŸš€ ChatterboxTTS - Starting Main Interface") print("πŸ“Š Tab Status:") print(f" Tab 1 (Convert Book): {'βœ… Available' if TAB1_AVAILABLE else '❌ Not Available'}") - print(" Tabs 2-10: 🚧 Placeholder (Coming Soon)") + print(f" Tab 2 (Configuration): {'βœ… Available' if TAB2_AVAILABLE else '❌ Not Available'}") + print(f" Tab 4 (Combine Audio): {'βœ… Available' if TAB4_AVAILABLE else '❌ Not Available'}") + print(f" Tab 5 (Prepare Text): {'βœ… Available' if TAB5_AVAILABLE else '❌ Not Available'}") + print(f" Tab 6 (Settings): {'βœ… Available' if TAB6_AVAILABLE else '❌ Not Available'}") + print(f" Tab 7 (Chunk Tools): {'βœ… Available' if TAB7_AVAILABLE else '❌ Not Available'}") + print(f" Tab 8 (JSON Generate): {'βœ… Available' if TAB8_AVAILABLE else '❌ Not Available'}") + print(" Other Tabs: 🚧 Placeholder (Coming Soon)") print("-" * 50) demo = create_main_interface() diff --git a/gradio_tabs/tab1_convert_book.py b/gradio_tabs/tab1_convert_book.py index 732da1b22bc9c70094be2fa887e7141bee83795e..761af27a1fad5b9c97a291b4f64e595a249dfefa 100644 --- a/gradio_tabs/tab1_convert_book.py +++ b/gradio_tabs/tab1_convert_book.py @@ -21,36 +21,33 @@ from typing import List, Dict, Any, Optional, Tuple warnings.filterwarnings("ignore", category=FutureWarning, message=".*torch.backends.cuda.sdp_kernel.*") warnings.filterwarnings("ignore", category=FutureWarning, message=".*sdp_kernel.*") -# Import ChatterboxTTS modules and ensure all config variables are available -# First set defaults, then try to import from config -DEFAULT_EXAGGERATION = 0.4 -DEFAULT_CFG_WEIGHT = 0.5 -DEFAULT_TEMPERATURE = 0.9 -TTS_PARAM_MIN_EXAGGERATION = 0.0 -TTS_PARAM_MAX_EXAGGERATION = 2.0 -TTS_PARAM_MIN_CFG_WEIGHT = 0.0 -TTS_PARAM_MAX_CFG_WEIGHT = 1.0 -TTS_PARAM_MIN_TEMPERATURE = 0.0 -TTS_PARAM_MAX_TEMPERATURE = 5.0 -ENABLE_REGENERATION_LOOP = True -MAX_REGENERATION_ATTEMPTS = 3 -QUALITY_THRESHOLD = 0.7 -ENABLE_SENTIMENT_SMOOTHING = True -SENTIMENT_SMOOTHING_WINDOW = 3 -SENTIMENT_SMOOTHING_METHOD = "rolling" -ENABLE_MFCC_VALIDATION = False -ENABLE_OUTPUT_VALIDATION = False -SPECTRAL_ANOMALY_THRESHOLD = 0.8 -OUTPUT_VALIDATION_THRESHOLD = 0.85 - -# Try to import config and override defaults if available +# Import ChatterboxTTS modules try: from config.config import * CONFIG_AVAILABLE = True - print("βœ… Config loaded successfully") except ImportError: print("⚠️ Config not available - using defaults") CONFIG_AVAILABLE = False + # Default values from config + DEFAULT_EXAGGERATION = 0.4 + DEFAULT_CFG_WEIGHT = 0.5 + DEFAULT_TEMPERATURE = 0.9 + TTS_PARAM_MIN_EXAGGERATION = 0.0 + TTS_PARAM_MAX_EXAGGERATION = 2.0 + TTS_PARAM_MIN_CFG_WEIGHT = 0.0 + TTS_PARAM_MAX_CFG_WEIGHT = 1.0 + TTS_PARAM_MIN_TEMPERATURE = 0.0 + TTS_PARAM_MAX_TEMPERATURE = 5.0 + ENABLE_REGENERATION_LOOP = True + MAX_REGENERATION_ATTEMPTS = 3 + QUALITY_THRESHOLD = 0.7 + ENABLE_SENTIMENT_SMOOTHING = True + SENTIMENT_SMOOTHING_WINDOW = 3 + SENTIMENT_SMOOTHING_METHOD = "rolling" + ENABLE_MFCC_VALIDATION = False + ENABLE_OUTPUT_VALIDATION = False + SPECTRAL_ANOMALY_THRESHOLD = 0.8 + OUTPUT_VALIDATION_THRESHOLD = 0.85 # Import the actual conversion functions from GUI try: @@ -74,7 +71,8 @@ conversion_state = { 'vram_usage': '-- GB', 'current_chunk': '--', 'eta': '--', - 'elapsed': '--' + 'elapsed': '--', + 'needs_refresh': False } def parse_progress_stats(output_line): @@ -82,19 +80,19 @@ def parse_progress_stats(output_line): # Look for progress pattern: "πŸŒ€ Chunk 2/13 | ⏱ Elapsed: 0:01:31 | ETA: 0:09:54 | Remaining: 0:08:23 | Realtime: 0.11x | VRAM: 3.3GB" progress_pattern = r'πŸŒ€ Chunk (\d+)/(\d+).*?Realtime: ([\d.]+)x.*?VRAM: ([\d.]+)GB' match = re.search(progress_pattern, output_line) - + if match: current_chunk = int(match.group(1)) total_chunks = int(match.group(2)) realtime_factor = f"{match.group(3)}x" vram_usage = f"{match.group(4)} GB" - + # Update global state conversion_state['current_chunk'] = f"{current_chunk}/{total_chunks}" conversion_state['realtime_factor'] = realtime_factor conversion_state['vram_usage'] = vram_usage conversion_state['progress'] = int((current_chunk / total_chunks) * 100) if total_chunks > 0 else 0 - + print(f"πŸ“Š Stats Updated: Chunk {current_chunk}/{total_chunks}, {realtime_factor}, {vram_usage}") return True else: @@ -106,22 +104,22 @@ def parse_progress_stats(output_line): total_chunks = int(alt_match.group(2)) realtime_factor = f"{alt_match.group(3)}x" vram_usage = f"{alt_match.group(4)} GB" - + conversion_state['current_chunk'] = f"{current_chunk}/{total_chunks}" conversion_state['realtime_factor'] = realtime_factor conversion_state['vram_usage'] = vram_usage conversion_state['progress'] = int((current_chunk / total_chunks) * 100) if total_chunks > 0 else 0 - + print(f"πŸ“Š Stats Updated: Chunk {current_chunk}/{total_chunks}, {realtime_factor}, {vram_usage}") return True - + return False def get_progress_stats(): """Get current progress statistics for UI update""" return ( conversion_state['realtime_factor'], - conversion_state['vram_usage'], + conversion_state['vram_usage'], conversion_state['current_chunk'], conversion_state['progress'] ) @@ -131,28 +129,28 @@ def get_book_folders(): text_input_dir = Path("Text_Input") if not text_input_dir.exists(): return [] - + folders = [] for item in text_input_dir.iterdir(): if item.is_dir(): folders.append(item.name) # Show only folder name, not full path - + return sorted(folders) def get_text_files_in_folder(folder_name): """Get text files in selected book folder""" if not folder_name: return [] - + # Build full path from folder name folder = Path("Text_Input") / folder_name if not folder.exists(): return [] - + text_files = [] for file in folder.glob("*.txt"): text_files.append(file.name) - + return sorted(text_files) def get_voice_samples(): @@ -160,11 +158,11 @@ def get_voice_samples(): voice_dir = Path("Voice_Samples") if not voice_dir.exists(): return [] - + voices = [] for file in voice_dir.glob("*.wav"): voices.append(file.name) # Show only filename, not full path - + return sorted(voices) def find_generated_audiobook(book_folder_path, voice_sample_path): @@ -173,7 +171,7 @@ def find_generated_audiobook(book_folder_path, voice_sample_path): book_folder = Path(book_folder_path) voice_file = Path(voice_sample_path) voice_name = voice_file.stem - + # Look in Output/ directory first (final audiobooks) output_dir = Path("Output") if output_dir.exists(): @@ -181,12 +179,12 @@ def find_generated_audiobook(book_folder_path, voice_sample_path): for m4b_file in output_dir.glob(f"*[{voice_name}]*.m4b"): if m4b_file.exists(): return str(m4b_file), "M4B audiobook" - - # Look for WAV files with voice name + + # Look for WAV files with voice name for wav_file in output_dir.glob(f"*[{voice_name}]*.wav"): if wav_file.exists(): return str(wav_file), "WAV audiobook" - + # Look in Audiobook/ directory (processing output) audiobook_dir = Path("Audiobook") / book_folder.name if audiobook_dir.exists(): @@ -194,19 +192,19 @@ def find_generated_audiobook(book_folder_path, voice_sample_path): for m4b_file in audiobook_dir.glob(f"*[{voice_name}]*.m4b"): if m4b_file.exists(): return str(m4b_file), "M4B audiobook" - + # Look for WAV files for wav_file in audiobook_dir.glob(f"*[{voice_name}]*.wav"): if wav_file.exists(): return str(wav_file), "WAV audiobook" - + # Look for combined files for combined_file in audiobook_dir.glob("*_combined.*"): if combined_file.suffix in ['.wav', '.m4b', '.mp3']: return str(combined_file), f"{combined_file.suffix.upper()[1:]} combined audiobook" - + return None, "No audiobook found" - + except Exception as e: print(f"Error finding audiobook: {e}") return None, f"Error: {str(e)}" @@ -216,18 +214,18 @@ def run_book_conversion(book_path, text_file_path, voice_path, tts_params, quali try: # Import the real TTS engine function directly (avoid interface.py) from modules.tts_engine import process_book_folder - + # Extract enable_asr from tts_params (matching GUI exactly) enable_asr = tts_params.get('enable_asr', False) - + print(f"πŸš€ Starting book conversion with GUI parameters") print(f"πŸ“– Book: {book_path}") - print(f"πŸ“„ Text file: {text_file_path}") + print(f"πŸ“„ Text file: {text_file_path}") print(f"🎀 Voice: {voice_path}") print(f"πŸŽ›οΈ TTS Params: {tts_params}") print(f"πŸ”¬ Quality Params: {quality_params}") print(f"βš™οΈ Config Params: {config_params}") - + # Set up progress callback function def progress_callback(current_chunk, total_chunks, realtime_factor, vram_usage): """Callback function to update progress from TTS engine""" @@ -236,35 +234,30 @@ def run_book_conversion(book_path, text_file_path, voice_path, tts_params, quali conversion_state['vram_usage'] = f"{vram_usage} GB" conversion_state['progress'] = int((current_chunk / total_chunks) * 100) if total_chunks > 0 else 0 print(f"πŸ“Š Progress: {current_chunk}/{total_chunks} ({conversion_state['progress']}%) - {realtime_factor}x - {vram_usage}GB") - + # Add progress callback to config params config_params['progress_callback'] = progress_callback - + # Convert string paths to Path objects (required by TTS engine) book_dir_path = Path(book_path) voice_path_obj = Path(voice_path) - - # Auto-detect device with fallback to CPU - import torch - if torch.cuda.is_available(): - device = "cuda" - print("βœ… Using CUDA GPU for processing") - else: - device = "cpu" - print("πŸ’» Using CPU for processing (no GPU available)") - # Direct call to TTS engine (function only accepts: book_dir, voice_path, tts_params, device, skip_cleanup) + # Direct call to TTS engine (same as GUI) result = process_book_folder( book_dir=book_dir_path, voice_path=voice_path_obj, tts_params=tts_params, - device=device, - skip_cleanup=False + device="cuda", + skip_cleanup=False, + enable_asr=enable_asr, + quality_params=quality_params, + config_params=config_params, + specific_text_file=text_file_path ) - + print(f"βœ… Conversion completed successfully") return {'success': True, 'result': result} - + except Exception as e: print(f"❌ Conversion failed: {e}") import traceback @@ -275,17 +268,17 @@ def regenerate_m4b_file(selected_m4b, playback_speed): """Regenerate M4B file with new playback speed""" if not selected_m4b: return "❌ Please select an M4B file first", None - + try: print(f"πŸ”„ Regenerating M4B: {selected_m4b} at {playback_speed}x speed") - + # Import M4B regeneration tools from tools.combine_only import apply_playback_speed_to_m4b - + # Find the M4B file path audiobook_root = Path("Audiobook") m4b_path = None - + for book_dir in audiobook_root.iterdir(): if book_dir.is_dir(): for m4b_file in book_dir.glob("*.m4b"): @@ -294,40 +287,40 @@ def regenerate_m4b_file(selected_m4b, playback_speed): break if m4b_path: break - + if not m4b_path: return "❌ M4B file not found", None - + # Create new filename with speed suffix speed_suffix = f"_speed{playback_speed}x".replace(".", "p") new_name = m4b_path.stem + speed_suffix + ".m4b" output_path = m4b_path.parent / new_name - + # Apply speed change success = apply_playback_speed_to_m4b(str(m4b_path), str(output_path), playback_speed) - + if success: return f"βœ… Regenerated M4B at {playback_speed}x speed: {new_name}", str(output_path) else: return "❌ Failed to regenerate M4B", None - + except Exception as e: print(f"❌ M4B regeneration failed: {e}") return f"❌ Error: {str(e)}", None def create_convert_book_tab(): """Create Tab 1: Convert Book with all GUI functionality""" - + with gr.Column(): gr.Markdown("# πŸš€ Convert Book") gr.Markdown("*Main TTS conversion functionality - matches GUI Tab 1*") - + # Main Content Layout with gr.Row(): # Left Column - File Uploads with gr.Column(scale=2): gr.Markdown("### πŸ“š Book Selection") - + # Book text file upload only text_file_upload = gr.File( label="πŸ“š Upload Book Text File", @@ -335,9 +328,9 @@ def create_convert_book_tab(): file_count="single", interactive=True ) - + gr.Markdown("### 🎀 Voice Selection") - + # Single voice upload with integrated playback voice_file_upload = gr.File( label="🎀 Upload Voice Sample", @@ -345,7 +338,7 @@ def create_convert_book_tab(): file_count="single", interactive=True ) - + # Voice sample player (becomes active after upload) voice_audio = gr.Audio( label="Voice Sample Preview", @@ -353,30 +346,30 @@ def create_convert_book_tab(): show_download_button=False, visible=False ) - + # Right Column - All Settings with gr.Column(scale=1): gr.Markdown("### βš™οΈ Quick Settings") - + # VADER and ASR vader_enabled = gr.Checkbox( label="Use VADER sentiment analysis", value=True, info="Adjust TTS params per chunk based on emotion" ) - + # ASR System with intelligent model selection with gr.Row(): asr_enabled = gr.Checkbox( - label="🎀 Enable ASR validation", + label="🎀 Enable ASR validation", value=False, info="Smart quality control with automatic model selection" ) - + # ASR Configuration (initially hidden) with gr.Column(visible=False) as asr_config_group: gr.Markdown("#### πŸ” ASR Configuration") - + # System analysis display system_analysis = gr.Textbox( label="System Analysis", @@ -384,25 +377,25 @@ def create_convert_book_tab(): lines=3, interactive=False ) - + analyze_system_btn = gr.Button( "πŸ” Analyze System", size="sm", variant="secondary" ) - + # ASR Level Selection asr_level = gr.Radio( label="ASR Quality Level", choices=[ ("🟒 SAFE - Fast processing, basic accuracy", "safe"), - ("🟑 MODERATE - Balanced speed/accuracy (recommended)", "moderate"), + ("🟑 MODERATE - Balanced speed/accuracy (recommended)", "moderate"), ("πŸ”΄ INSANE - Best accuracy, may stress system", "insane") ], value="moderate", info="Automatically selects best models for your system" ) - + # Selected models display selected_models = gr.Textbox( label="Selected ASR Models", @@ -410,86 +403,86 @@ def create_convert_book_tab(): lines=2, interactive=False ) - + # Batch processing add_to_batch = gr.Checkbox( label="πŸ“¦ Add to batch queue", value=False, info="Queue for batch processing" ) - + gr.Markdown("### πŸ”„ Regeneration Settings") - + regeneration_enabled = gr.Checkbox( label="Enable automatic chunk regeneration", value=ENABLE_REGENERATION_LOOP, info="Retry failed chunks automatically" ) - + max_attempts = gr.Slider( label="Max Attempts", minimum=1, maximum=10, step=1, value=MAX_REGENERATION_ATTEMPTS ) - + quality_threshold = gr.Slider( - label="Quality Threshold", + label="Quality Threshold", minimum=0.1, maximum=1.0, step=0.05, value=QUALITY_THRESHOLD ) - + gr.Markdown("### πŸ“Š Sentiment Smoothing") - + sentiment_smoothing = gr.Checkbox( label="Enable sentiment smoothing", value=ENABLE_SENTIMENT_SMOOTHING, info="Smooth emotional transitions" ) - + smoothing_window = gr.Slider( label="Window Size", minimum=1, maximum=10, step=1, value=SENTIMENT_SMOOTHING_WINDOW ) - + smoothing_method = gr.Dropdown( label="Smoothing Method", choices=["rolling", "exp_decay"], value=SENTIMENT_SMOOTHING_METHOD ) - + gr.Markdown("### πŸ” Advanced Detection") - + mfcc_validation = gr.Checkbox( label="MFCC spectral analysis", value=ENABLE_MFCC_VALIDATION, info="Advanced audio quality detection" ) - + output_validation = gr.Checkbox( label="Output validation", value=ENABLE_OUTPUT_VALIDATION, info="Quality control clearinghouse for enabled checks" ) - + spectral_threshold = gr.Slider( label="Spectral Threshold", minimum=0.1, maximum=1.0, step=0.05, value=SPECTRAL_ANOMALY_THRESHOLD ) - + output_threshold = gr.Slider( - label="Output Threshold", + label="Output Threshold", minimum=0.1, maximum=1.0, step=0.05, value=OUTPUT_VALIDATION_THRESHOLD ) - - + + # TTS Parameters with gr.Row(): with gr.Column(): gr.Markdown("### πŸŽ›οΈ TTS Parameters") - + exaggeration = gr.Slider( label="Exaggeration", minimum=TTS_PARAM_MIN_EXAGGERATION, @@ -498,16 +491,16 @@ def create_convert_book_tab(): value=DEFAULT_EXAGGERATION, info="Emotional intensity" ) - + cfg_weight = gr.Slider( - label="CFG Weight", + label="CFG Weight", minimum=TTS_PARAM_MIN_CFG_WEIGHT, maximum=TTS_PARAM_MAX_CFG_WEIGHT, step=0.1, value=DEFAULT_CFG_WEIGHT, info="Text faithfulness" ) - + temperature = gr.Slider( label="Temperature", minimum=TTS_PARAM_MIN_TEMPERATURE, @@ -516,40 +509,31 @@ def create_convert_book_tab(): value=DEFAULT_TEMPERATURE, info="Creativity/randomness" ) - + with gr.Column(): gr.Markdown("### ⚑ Advanced Sampling") - + min_p = gr.Slider( label="Min-P", minimum=0.0, maximum=0.5, step=0.01, value=0.05, info="Minimum probability threshold" ) - + top_p = gr.Slider( label="Top-P", minimum=0.5, maximum=1.0, step=0.1, value=1.0, info="Nucleus sampling" ) - + repetition_penalty = gr.Slider( label="Repetition Penalty", minimum=1.0, maximum=3.0, step=0.1, value=2.0, info="Reduce repetition" ) - - gr.Markdown("### βš™οΈ Performance Settings") - - max_workers = gr.Number( - label="Max Workers", - minimum=1, maximum=8, step=1, - value=2, - info="⚠️ Only increase above 2 if CPU/GPU utilization < 70%" - ) - + # Action Buttons and Status with gr.Row(): with gr.Column(scale=2): @@ -559,7 +543,7 @@ def create_convert_book_tab(): size="lg", interactive=True ) - + # Status Display status_display = gr.Textbox( label="Status", @@ -567,40 +551,40 @@ def create_convert_book_tab(): interactive=False, lines=1 ) - + progress_display = gr.Number( label="Progress %", value=0, interactive=False, precision=0 ) - + with gr.Column(scale=1): gr.Markdown("### πŸ“Š Processing Stats") - + realtime_factor = gr.Textbox( label="Realtime Factor", value="--", interactive=False ) - + vram_usage = gr.Textbox( - label="VRAM Usage", + label="VRAM Usage", value="-- GB", interactive=False ) - + current_chunk = gr.Textbox( label="Current Chunk", value="--", interactive=False ) - + # Regenerate M4B Section (moved above audiobook player) with gr.Row(): with gr.Column(): gr.Markdown("### πŸ”„ Regenerate M4B") - + with gr.Row(): with gr.Column(scale=2): m4b_file_selector = gr.Dropdown( @@ -610,7 +594,7 @@ def create_convert_book_tab(): interactive=True, info="Choose from generated audiobook files" ) - + with gr.Column(scale=1): playback_speed = gr.Slider( label="Playback Speed", @@ -620,18 +604,18 @@ def create_convert_book_tab(): value=1.0, info="Speed adjustment for regeneration" ) - + regenerate_m4b_btn = gr.Button( - "πŸ”„ Regenerate M4B", + "πŸ”„ Regenerate M4B", variant="secondary", size="lg" ) - + # Generated Audiobook Player (simplified, play-only) with gr.Row(): with gr.Column(): gr.Markdown("### 🎧 Generated Audiobook Player") - + # Audiobook file selector dropdown audiobook_selector = gr.Dropdown( label="Select Audiobook", @@ -640,7 +624,7 @@ def create_convert_book_tab(): interactive=True, info="Choose from session audiobooks" ) - + # Main audio player - play only, no upload audio_player = gr.Audio( label="Audiobook Player", @@ -654,20 +638,20 @@ def create_convert_book_tab(): skip_length=10 ) ) - + # Event Handlers def handle_voice_upload(voice_file): """Handle voice file upload and show player""" if voice_file is None: return gr.update(value=None, visible=False) - + # Show the voice player with uploaded file return gr.update(value=voice_file, visible=True) - + def get_session_audiobooks(): """Get list of M4B files from current session, sorted by creation time (newest first)""" audiobooks = [] - + # Look in Audiobook directory for M4B files audiobook_root = Path("Audiobook") if audiobook_root.exists(): @@ -678,25 +662,25 @@ def create_convert_book_tab(): # Get creation time for sorting creation_time = m4b_file.stat().st_mtime audiobooks.append((str(m4b_file), m4b_file.name, creation_time)) - + # Also check Output directory output_root = Path("Output") if output_root.exists(): for m4b_file in output_root.glob("*.m4b"): creation_time = m4b_file.stat().st_mtime audiobooks.append((str(m4b_file), m4b_file.name, creation_time)) - + # Sort by creation time (newest first) audiobooks.sort(key=lambda x: x[2], reverse=True) - + # Return just path and name (drop creation time) return [(ab[0], ab[1]) for ab in audiobooks] - + def update_audiobook_dropdowns(latest_file=None): """Update audiobook dropdowns - after conversion both show latest, after regeneration only playback updates""" audiobooks = get_session_audiobooks() choices = [ab[1] for ab in audiobooks] # Just filenames for display - + # Determine what to set as selected if latest_file: # Use specific file if provided @@ -706,159 +690,158 @@ def create_convert_book_tab(): selected_file = choices[0] else: selected_file = None - + return ( gr.update(choices=choices, value=selected_file), # audiobook_selector (playback) gr.update(choices=choices, value=selected_file) # m4b_file_selector (regeneration source) ) - + def update_audiobook_dropdowns_after_conversion(): """Update both dropdowns to show the newest generated file after conversion""" return update_audiobook_dropdowns() - + def update_playback_only(new_file_name): """Update only the playback dropdown after regeneration""" audiobooks = get_session_audiobooks() choices = [ab[1] for ab in audiobooks] - + return ( gr.update(choices=choices, value=new_file_name), # audiobook_selector (playback) - new file gr.update() # m4b_file_selector (regeneration) - no change ) - + def load_selected_audiobook(selected_audiobook): """Load selected audiobook into player""" if not selected_audiobook: return None - + # Find the full path for the selected audiobook audiobooks = get_session_audiobooks() for full_path, filename in audiobooks: if filename == selected_audiobook: return full_path - + return None - + def handle_asr_toggle(asr_enabled_val): """Show/hide ASR configuration when ASR is toggled""" return gr.update(visible=asr_enabled_val) - + def analyze_system(): """Analyze system capabilities and return summary""" try: from modules.system_detector import get_system_profile, print_system_summary, categorize_system - + profile = get_system_profile() categories = categorize_system(profile) - + summary = f"πŸ–₯️ System Profile:\n" summary += f"VRAM: {profile['gpu']['total_mb']:,}MB total, {profile['available_vram_after_tts']:,}MB available after TTS ({categories['vram']} class)\n" summary += f"RAM: {profile['ram']['total_mb']:,}MB total, {profile['ram']['available_mb']:,}MB available ({categories['ram']} class)\n" summary += f"CPU: {profile['cpu_cores']} cores ({categories['cpu']} class)" - + if not profile['has_gpu']: summary += f"\n⚠️ No CUDA GPU detected - ASR will run on CPU only" - + return summary - + except Exception as e: return f"❌ Error analyzing system: {str(e)}" - + def update_asr_models(asr_level_val): """Update ASR model display based on selected level""" try: from modules.system_detector import get_system_profile, recommend_asr_models - + profile = get_system_profile() recommendations = recommend_asr_models(profile) - + if asr_level_val not in recommendations: return "❌ Invalid ASR level selected" - + config = recommendations[asr_level_val] primary = config['primary'] fallback = config['fallback'] - + result = f"Primary: {primary['model']} on {primary['device'].upper()}\n" result += f"Fallback: {fallback['model']} on {fallback['device'].upper()}" - + if asr_level_val == 'insane': result += f"\n⚠️ WARNING: INSANE mode may cause memory pressure" - + return result - + except Exception as e: return f"❌ Error getting models: {str(e)}" - - def start_conversion(text_file_upload, voice_file_upload, + + def start_conversion(text_file_upload, voice_file_upload, vader_val, asr_val, asr_level_val, add_to_batch_val, regen_enabled_val, max_attempts_val, quality_thresh_val, sentiment_smooth_val, smooth_window_val, smooth_method_val, mfcc_val, output_val, spectral_thresh_val, output_thresh_val, - exag_val, cfg_val, temp_val, min_p_val, top_p_val, rep_penalty_val, - max_workers_val): + exag_val, cfg_val, temp_val, min_p_val, top_p_val, rep_penalty_val): """Start the actual book conversion - file upload version""" - + # Validation if not text_file_upload: return "❌ Please upload a text file", 0, None, None if not voice_file_upload: return "❌ Please upload a voice sample", 0, None, None - + # Check if already running if conversion_state['running']: return "⚠️ Conversion already in progress", conversion_state['progress'], None, None - + try: # Create temporary book structure from uploads import tempfile import shutil from datetime import datetime - + # Generate unique book name from text file text_filename = Path(text_file_upload).name book_name = text_filename.replace('.txt', '').replace(' ', '_') timestamp = datetime.now().strftime("%H%M%S") unique_book_name = f"{book_name}_{timestamp}" - + # Create directory structure text_input_dir = Path("Text_Input") text_input_dir.mkdir(exist_ok=True) - + book_dir = text_input_dir / unique_book_name book_dir.mkdir(exist_ok=True) - + # Copy uploaded files to expected locations text_dest = book_dir / f"{unique_book_name}.txt" shutil.copy2(text_file_upload, text_dest) - + voice_samples_dir = Path("Voice_Samples") voice_samples_dir.mkdir(exist_ok=True) - + voice_filename = Path(voice_file_upload).name voice_dest = voice_samples_dir / voice_filename shutil.copy2(voice_file_upload, voice_dest) - + print(f"πŸ“ Created book structure: {book_dir}") print(f"πŸ“„ Text file: {text_dest}") print(f"🎀 Voice file: {voice_dest}") - + except Exception as e: return f"❌ Error setting up files: {e}", 0, None, None - + # Build ASR configuration first asr_config = {'enabled': False} if asr_val: try: from modules.system_detector import get_system_profile, recommend_asr_models - profile = get_system_profile() + profile = get_system_profile() recommendations = recommend_asr_models(profile) - + if asr_level_val in recommendations: selected_config = recommendations[asr_level_val] primary = selected_config['primary'] fallback = selected_config['fallback'] - + asr_config = { 'enabled': True, 'level': asr_level_val, @@ -870,7 +853,7 @@ def create_convert_book_tab(): except Exception as e: print(f"⚠️ Error configuring ASR: {e}") asr_config = {'enabled': False} - + # Prepare parameters (matching GUI structure exactly) tts_params = { 'exaggeration': exag_val, @@ -879,10 +862,9 @@ def create_convert_book_tab(): 'min_p': min_p_val, 'top_p': top_p_val, 'repetition_penalty': rep_penalty_val, - 'enable_asr': asr_config.get('enabled', False), # Match GUI pattern - 'max_workers': int(max_workers_val) # User-defined worker count + 'enable_asr': asr_config.get('enabled', False) # Match GUI pattern } - + quality_params = { 'regeneration_enabled': regen_enabled_val, 'max_attempts': max_attempts_val, @@ -895,49 +877,50 @@ def create_convert_book_tab(): 'spectral_threshold': spectral_thresh_val, 'output_threshold': output_thresh_val } - + config_params = { 'vader_enabled': vader_val, 'asr_enabled': asr_val, 'asr_config': asr_config, 'add_to_batch': add_to_batch_val } - + # Set conversion state conversion_state['running'] = True conversion_state['progress'] = 0 conversion_state['status'] = 'Starting conversion...' conversion_state['current_book'] = book_dir.name # Track current book - + try: # Run conversion using the modular backend in a separate thread import threading - + def run_conversion_thread(): try: result = run_book_conversion( str(book_dir), str(text_dest), str(voice_dest), tts_params, quality_params, config_params ) - + if result['success']: - conversion_state['status'] = 'πŸŽ‰ CONVERSION COMPLETE! M4B audiobook ready for playback.' + conversion_state['status'] = 'βœ… Conversion completed successfully!' conversion_state['progress'] = 100 - conversion_state['auto_refresh_needed'] = True # Flag for auto-refresh + # Trigger automatic refresh of audiobook dropdowns + conversion_state['needs_refresh'] = True else: conversion_state['status'] = f"❌ Conversion failed: {result.get('error', 'Unknown error')}" conversion_state['progress'] = 0 - + except Exception as e: conversion_state['status'] = f"❌ Error: {str(e)}" conversion_state['progress'] = 0 finally: conversion_state['running'] = False - + # Start conversion thread thread = threading.Thread(target=run_conversion_thread) thread.start() - + # Return immediate response - user will need to refresh to see final results return ( "πŸš€ Conversion started in background...", @@ -946,42 +929,42 @@ def create_convert_book_tab(): gr.update(), gr.update() ) - + except Exception as e: conversion_state['status'] = f"❌ Error: {str(e)}" return conversion_state['status'], 0, None, gr.update(), gr.update() finally: conversion_state['running'] = False - - + + # Connect event handlers - + # ASR event handlers asr_enabled.change( handle_asr_toggle, inputs=[asr_enabled], outputs=[asr_config_group] ) - + analyze_system_btn.click( analyze_system, inputs=[], outputs=[system_analysis] ) - + asr_level.change( update_asr_models, inputs=[asr_level], outputs=[selected_models] ) - + # Voice upload handler voice_file_upload.change( handle_voice_upload, inputs=[voice_file_upload], outputs=[voice_audio] ) - + # Main conversion handler convert_btn.click( start_conversion, @@ -991,24 +974,23 @@ def create_convert_book_tab(): regeneration_enabled, max_attempts, quality_threshold, sentiment_smoothing, smoothing_window, smoothing_method, mfcc_validation, output_validation, spectral_threshold, output_threshold, - exaggeration, cfg_weight, temperature, min_p, top_p, repetition_penalty, - max_workers + exaggeration, cfg_weight, temperature, min_p, top_p, repetition_penalty ], outputs=[status_display, progress_display, audio_player, audiobook_selector, m4b_file_selector] ) - + # Audiobook selector handler audiobook_selector.change( load_selected_audiobook, inputs=[audiobook_selector], outputs=[audio_player] ) - + # M4B regeneration handler def handle_m4b_regeneration(selected_m4b, speed): """Handle M4B regeneration and update player""" status_msg, new_m4b_path = regenerate_m4b_file(selected_m4b, speed) - + if new_m4b_path: # Load the new M4B in the player new_file_name = Path(new_m4b_path).name @@ -1018,13 +1000,13 @@ def create_convert_book_tab(): return status_msg, new_audio, audiobook_choices, m4b_choices else: return status_msg, None, gr.update(), gr.update() - + regenerate_m4b_btn.click( handle_m4b_regeneration, inputs=[m4b_file_selector, playback_speed], outputs=[status_display, audio_player, audiobook_selector, m4b_file_selector] ) - + # Progress monitoring with file-based approach def get_current_stats(): """Get current progress statistics by monitoring output files""" @@ -1033,11 +1015,11 @@ def create_convert_book_tab(): # Look for generated audio chunks to estimate progress book_name = conversion_state.get('current_book', 'unknown') audiobook_root = Path("Audiobook") / book_name / "TTS" / "audio_chunks" - + if audiobook_root.exists(): chunk_files = list(audiobook_root.glob("chunk_*.wav")) current_chunks = len(chunk_files) - + # Try to estimate total from JSON if available json_path = Path("Text_Input") / f"{book_name}_chunks.json" total_chunks = 0 @@ -1046,19 +1028,19 @@ def create_convert_book_tab(): with open(json_path, 'r') as f: data = json.load(f) total_chunks = len(data) - + if total_chunks > 0: progress = int((current_chunks / total_chunks) * 100) conversion_state['progress'] = progress conversion_state['current_chunk'] = f"{current_chunks}/{total_chunks}" - + return ( conversion_state.get('realtime_factor', '--'), conversion_state.get('vram_usage', '-- GB'), f"{current_chunks}/{total_chunks}", progress ) - + return ( conversion_state.get('realtime_factor', '--'), conversion_state.get('vram_usage', '-- GB'), @@ -1068,48 +1050,7 @@ def create_convert_book_tab(): except Exception as e: print(f"Error getting stats: {e}") return "--", "-- GB", "--", conversion_state.get('progress', 0) - - def auto_check_completion(): - """Automatically check for completion and refresh interface""" - # First get current stats - stats = get_current_stats() - - # Check if conversion just completed and needs auto-refresh - if (not conversion_state['running'] and - conversion_state['progress'] == 100 and - conversion_state.get('auto_refresh_needed', False)): - - # Clear the auto-refresh flag - conversion_state['auto_refresh_needed'] = False - print("πŸŽ‰ Auto-detected completion! Refreshing interface...") - - # Get completion results - status, progress, audio, audiobook_choices, m4b_choices = get_status_and_results() - - # Return combined stats + completion results - return ( - stats[0], # realtime_factor - stats[1], # vram_usage - stats[2], # current_chunk - 100, # progress (completed) - status, # completion status - audio, # audio player - audiobook_choices, # audiobook dropdown - m4b_choices # m4b dropdown - ) - else: - # Return stats + current status (no completion) - return ( - stats[0], # realtime_factor - stats[1], # vram_usage - stats[2], # current_chunk - stats[3], # progress - conversion_state.get('status', '⏸ Ready'), # current status - gr.update(), # no audio update - gr.update(), # no audiobook update - gr.update() # no m4b update - ) - + def get_status_and_results(): """Get conversion status and results after completion""" if not conversion_state['running'] and conversion_state['progress'] == 100: @@ -1118,7 +1059,22 @@ def create_convert_book_tab(): latest_audiobook = None if audiobook_choices['choices']: latest_audiobook = load_selected_audiobook(audiobook_choices['choices'][0]) - + + return ( + conversion_state['status'], + conversion_state['progress'], + latest_audiobook, + audiobook_choices, + m4b_choices + ) + elif conversion_state.get('needs_refresh', False): + # Auto-refresh requested + conversion_state['needs_refresh'] = False + audiobook_choices, m4b_choices = update_audiobook_dropdowns_after_conversion() + latest_audiobook = None + if audiobook_choices['choices']: + latest_audiobook = load_selected_audiobook(audiobook_choices['choices'][0]) + return ( conversion_state['status'], conversion_state['progress'], @@ -1134,31 +1090,22 @@ def create_convert_book_tab(): gr.update(), gr.update() ) - + # Create refresh buttons with gr.Row(): refresh_stats_btn = gr.Button("πŸ”„ Refresh Stats", size="sm", variant="secondary") check_completion_btn = gr.Button("πŸ“‹ Check Completion", size="sm", variant="secondary") - # Auto-refresh timer (checks every 5 seconds during conversion) - auto_timer = gr.Timer(5.0) # 5 second interval - refresh_stats_btn.click( - auto_check_completion, - outputs=[realtime_factor, vram_usage, current_chunk, progress_display, status_display, audio_player, audiobook_selector, m4b_file_selector] + get_current_stats, + outputs=[realtime_factor, vram_usage, current_chunk, progress_display] ) - + check_completion_btn.click( get_status_and_results, outputs=[status_display, progress_display, audio_player, audiobook_selector, m4b_file_selector] ) - # Auto-timer for progress monitoring and completion detection - auto_timer.tick( - auto_check_completion, - outputs=[realtime_factor, vram_usage, current_chunk, progress_display, status_display, audio_player, audiobook_selector, m4b_file_selector] - ) - return { 'convert_button': convert_btn, 'status_display': status_display, @@ -1169,5 +1116,5 @@ if __name__ == "__main__": # Test the tab with gr.Blocks() as demo: create_convert_book_tab() - - demo.launch() + + demo.launch() \ No newline at end of file diff --git a/gradio_tabs/tab2_configuration.py b/gradio_tabs/tab2_configuration.py new file mode 100644 index 0000000000000000000000000000000000000000..63992b34772d3c3ca72809480e33b7b22d7d4aac --- /dev/null +++ b/gradio_tabs/tab2_configuration.py @@ -0,0 +1,547 @@ +#!/usr/bin/env python3 +""" +Gradio Tab 2: Configuration Settings +Matches PyQt5 GUI Tab 2 functionality with all configuration options +""" + +import gradio as gr +import os +import sys +import json +from pathlib import Path +from typing import Dict, Any, Tuple, List + +# Import configuration +try: + from config.config import * + CONFIG_AVAILABLE = True + print("βœ… Config module loaded successfully") +except ImportError as e: + print(f"⚠️ Config not available: {e}") + CONFIG_AVAILABLE = False + # Default values if config not available + MAX_WORKERS = 2 + BATCH_SIZE = 100 + MIN_CHUNK_WORDS = 5 + MAX_CHUNK_WORDS = 25 + ENABLE_NORMALIZATION = True + TARGET_LUFS = -16 + ENABLE_AUDIO_TRIMMING = True + SPEECH_ENDPOINT_THRESHOLD = 0.005 + TRIMMING_BUFFER_MS = 100 + TTS_PARAM_MIN_EXAGGERATION = 0.0 + TTS_PARAM_MAX_EXAGGERATION = 2.0 + TTS_PARAM_MIN_CFG_WEIGHT = 0.0 + TTS_PARAM_MAX_CFG_WEIGHT = 1.0 + TTS_PARAM_MIN_TEMPERATURE = 0.0 + TTS_PARAM_MAX_TEMPERATURE = 5.0 + DEFAULT_EXAGGERATION = 0.5 + DEFAULT_CFG_WEIGHT = 0.5 + DEFAULT_TEMPERATURE = 0.8 + VADER_EXAGGERATION_SENSITIVITY = 0.3 + VADER_CFG_WEIGHT_SENSITIVITY = 0.3 + VADER_TEMPERATURE_SENSITIVITY = 0.3 + SILENCE_CHAPTER_START = 1000 + SILENCE_CHAPTER_END = 1500 + SILENCE_SECTION_BREAK = 800 + SILENCE_PARAGRAPH_END = 500 + SILENCE_COMMA = 200 + SILENCE_PERIOD = 400 + SILENCE_QUESTION_MARK = 500 + SILENCE_EXCLAMATION = 500 + CHUNK_END_SILENCE_MS = 0 + +def create_configuration_tab(): + """Create Tab 2: Configuration Settings with all GUI functionality""" + + with gr.Column(): + gr.Markdown("# βš™οΈ Configuration Settings") + gr.Markdown("*System configuration and parameter management - matches GUI Tab 2*") + + # Top Row - Core Settings with Group Boxes + with gr.Row(): + # Workers/Batch Settings Group + with gr.Column(): + gr.Markdown("### πŸ”§ Workers & Batch Settings") + gr.Markdown("*Set # workers for parallel processing. Too many workers will use up VRAM. Only increase if VRAM and GPU % are below 60% utilized. Batch size: set to determine when model is reloaded to flush VRAM and avoid recursive problems and slowdowns.*") + + workers_spin = gr.Slider( + label="Workers", + minimum=1, maximum=8, step=1, + value=MAX_WORKERS, + info="Number of parallel processing threads" + ) + + batch_size_spin = gr.Slider( + label="Batch Size", + minimum=50, maximum=500, step=50, + value=BATCH_SIZE, + info="Chunks processed before model reload" + ) + + # Min/Max Words Group + with gr.Column(): + gr.Markdown("### πŸ“ Chunk Word Limits") + gr.Markdown("*Set Min/Max for words in a text chunk. Too many words can lead to poor TTS.*") + + min_chunk_words_spin = gr.Slider( + label="Min Words", + minimum=1, maximum=50, step=1, + value=MIN_CHUNK_WORDS, + info="Minimum words per chunk" + ) + + max_chunk_words_spin = gr.Slider( + label="Max Words", + minimum=10, maximum=100, step=1, + value=MAX_CHUNK_WORDS, + info="Maximum words per chunk" + ) + + # Audio Detection Group + with gr.Column(): + gr.Markdown("### πŸ”Š Audio Processing") + gr.Markdown("*Detects when audio speech stops and just noise or silence follow at end of audio chunk. Use threshold to change detection range of voice. Buffer adds ms silence to end of audio chunk.*") + + normalization_check = gr.Checkbox( + label="Audio normalization", + value=ENABLE_NORMALIZATION, + info="Enable loudness normalization" + ) + + target_lufs_spin = gr.Slider( + label="Target LUFS (dB)", + minimum=-30, maximum=-6, step=1, + value=TARGET_LUFS, + info="Target loudness level" + ) + + audio_trimming_check = gr.Checkbox( + label="Automatic audio trimming", + value=ENABLE_AUDIO_TRIMMING, + info="Trim silence from audio chunks" + ) + + speech_threshold_spin = gr.Slider( + label="Speech Threshold", + minimum=0.001, maximum=0.1, step=0.001, + value=SPEECH_ENDPOINT_THRESHOLD, + info="Speech detection sensitivity" + ) + + trimming_buffer_spin = gr.Slider( + label="Buffer (ms)", + minimum=0, maximum=500, step=10, + value=TRIMMING_BUFFER_MS, + info="Silence buffer after speech" + ) + + # TTS Parameter Limits Group + with gr.Row(): + with gr.Column(): + gr.Markdown("### πŸ”’ TTS Parameter Limits") + gr.Markdown("*Set the upper and lower limits that TTS Params can be automatically adjusted to by VADER and other functions.*") + + with gr.Row(): + exag_min_spin = gr.Slider( + label="Exag Min", + minimum=0.0, maximum=2.0, step=0.05, + value=TTS_PARAM_MIN_EXAGGERATION, + info="Minimum exaggeration limit" + ) + + exag_max_spin = gr.Slider( + label="Exag Max", + minimum=0.0, maximum=2.0, step=0.05, + value=TTS_PARAM_MAX_EXAGGERATION, + info="Maximum exaggeration limit" + ) + + with gr.Row(): + cfg_min_spin = gr.Slider( + label="CFG Min", + minimum=0.0, maximum=1.0, step=0.05, + value=TTS_PARAM_MIN_CFG_WEIGHT, + info="Minimum CFG weight limit" + ) + + cfg_max_spin = gr.Slider( + label="CFG Max", + minimum=0.0, maximum=1.0, step=0.05, + value=TTS_PARAM_MAX_CFG_WEIGHT, + info="Maximum CFG weight limit" + ) + + with gr.Row(): + temp_min_spin = gr.Slider( + label="Temp Min", + minimum=0.0, maximum=5.0, step=0.05, + value=TTS_PARAM_MIN_TEMPERATURE, + info="Minimum temperature limit" + ) + + temp_max_spin = gr.Slider( + label="Temp Max", + minimum=0.0, maximum=5.0, step=0.05, + value=TTS_PARAM_MAX_TEMPERATURE, + info="Maximum temperature limit" + ) + + # TTS Defaults and VADER Sensitivity Section + with gr.Row(): + # TTS Defaults Group + with gr.Column(): + gr.Markdown("### 🎯 TTS Defaults") + gr.Markdown("*Default values: Exag: 0.50, CFG: 0.50, Temp: 0.80*") + + default_exag_spin = gr.Slider( + label="Default Exaggeration", + minimum=0.0, maximum=2.0, step=0.05, + value=DEFAULT_EXAGGERATION, + info="Base exaggeration value" + ) + + default_cfg_spin = gr.Slider( + label="Default CFG Weight", + minimum=0.0, maximum=1.0, step=0.05, + value=DEFAULT_CFG_WEIGHT, + info="Base CFG weight value" + ) + + default_temp_spin = gr.Slider( + label="Default Temperature", + minimum=0.0, maximum=5.0, step=0.05, + value=DEFAULT_TEMPERATURE, + info="Base temperature value" + ) + + # VADER Sensitivity Group + with gr.Column(): + gr.Markdown("### 🎭 VADER Sensitivity") + gr.Markdown("*Default values: Exag Sens: 0.30, CFG Sens: 0.30, Temp Sens: 0.30*") + gr.Markdown("*VADER Sensitivity sets how much VADER adjusts the TTS params based on emotional weight.*") + + vader_exag_sens_spin = gr.Slider( + label="Exag Sensitivity", + minimum=0.0, maximum=1.0, step=0.01, + value=VADER_EXAGGERATION_SENSITIVITY, + info="VADER exaggeration adjustment strength" + ) + + vader_cfg_sens_spin = gr.Slider( + label="CFG Sensitivity", + minimum=0.0, maximum=1.0, step=0.01, + value=VADER_CFG_WEIGHT_SENSITIVITY, + info="VADER CFG weight adjustment strength" + ) + + vader_temp_sens_spin = gr.Slider( + label="Temp Sensitivity", + minimum=0.0, maximum=1.0, step=0.01, + value=VADER_TEMPERATURE_SENSITIVITY, + info="VADER temperature adjustment strength" + ) + + # Silence Settings Group + with gr.Column(): + gr.Markdown("### πŸ”‡ Silence Settings") + gr.Markdown("*Set the silence added to audio chunks for each type of chunk. ie chapter start/end, period, paragraph. For each setting silence is added for pacing.*") + + # Chapter/Section silence + with gr.Row(): + silence_chapter_start_spin = gr.Slider( + label="Chapter Start (ms)", + minimum=0, maximum=9999, step=100, + value=SILENCE_CHAPTER_START, + info="Silence before chapter starts" + ) + + silence_chapter_end_spin = gr.Slider( + label="Chapter End (ms)", + minimum=0, maximum=9999, step=100, + value=SILENCE_CHAPTER_END, + info="Silence after chapter ends" + ) + + silence_section_spin = gr.Slider( + label="Section Break (ms)", + minimum=0, maximum=9999, step=100, + value=SILENCE_SECTION_BREAK, + info="Silence for section breaks" + ) + + silence_paragraph_spin = gr.Slider( + label="Paragraph End (ms)", + minimum=0, maximum=9999, step=50, + value=SILENCE_PARAGRAPH_END, + info="Silence after paragraphs" + ) + + # Punctuation silence + with gr.Row(): + silence_comma_spin = gr.Slider( + label="Comma (ms)", + minimum=0, maximum=9999, step=50, + value=SILENCE_COMMA, + info="Silence after commas" + ) + + silence_period_spin = gr.Slider( + label="Period (ms)", + minimum=0, maximum=9999, step=50, + value=SILENCE_PERIOD, + info="Silence after periods" + ) + + silence_question_spin = gr.Slider( + label="Question Mark (ms)", + minimum=0, maximum=9999, step=50, + value=SILENCE_QUESTION_MARK, + info="Silence after questions" + ) + + silence_exclamation_spin = gr.Slider( + label="Exclamation (ms)", + minimum=0, maximum=9999, step=50, + value=SILENCE_EXCLAMATION, + info="Silence after exclamations" + ) + + # Chunk silence settings + with gr.Row(): + chunk_end_silence_check = gr.Checkbox( + label="Enable Chunk End Silence", + value=CHUNK_END_SILENCE_MS > 0, + info="Add silence to end of every chunk" + ) + + chunk_end_silence_spin = gr.Slider( + label="Chunk End Silence (ms)", + minimum=0, maximum=9999, step=50, + value=CHUNK_END_SILENCE_MS, + info="Silence added to chunk ends" + ) + + # Config action buttons + with gr.Row(): + save_config_btn = gr.Button( + "πŸ’Ύ Save Configuration", + variant="primary", + size="lg" + ) + + reset_config_btn = gr.Button( + "πŸ”„ Reset to Defaults", + variant="secondary", + size="lg" + ) + + reload_config_btn = gr.Button( + "♻️ Reload Configuration", + variant="secondary", + size="lg" + ) + + # Status display + config_status = gr.Textbox( + label="Configuration Status", + value="Ready to save or load configuration", + interactive=False, + lines=2 + ) + + # Event Handlers + def save_configuration(*values): + """Save current configuration settings""" + try: + if not CONFIG_AVAILABLE: + return "❌ Configuration module not available" + + # Map values back to config variables + config_values = { + 'MAX_WORKERS': int(values[0]), + 'BATCH_SIZE': int(values[1]), + 'MIN_CHUNK_WORDS': int(values[2]), + 'MAX_CHUNK_WORDS': int(values[3]), + 'ENABLE_NORMALIZATION': values[4], + 'TARGET_LUFS': int(values[5]), + 'ENABLE_AUDIO_TRIMMING': values[6], + 'SPEECH_ENDPOINT_THRESHOLD': values[7], + 'TRIMMING_BUFFER_MS': int(values[8]), + 'TTS_PARAM_MIN_EXAGGERATION': values[9], + 'TTS_PARAM_MAX_EXAGGERATION': values[10], + 'TTS_PARAM_MIN_CFG_WEIGHT': values[11], + 'TTS_PARAM_MAX_CFG_WEIGHT': values[12], + 'TTS_PARAM_MIN_TEMPERATURE': values[13], + 'TTS_PARAM_MAX_TEMPERATURE': values[14], + 'DEFAULT_EXAGGERATION': values[15], + 'DEFAULT_CFG_WEIGHT': values[16], + 'DEFAULT_TEMPERATURE': values[17], + 'VADER_EXAGGERATION_SENSITIVITY': values[18], + 'VADER_CFG_WEIGHT_SENSITIVITY': values[19], + 'VADER_TEMPERATURE_SENSITIVITY': values[20], + 'SILENCE_CHAPTER_START': int(values[21]), + 'SILENCE_CHAPTER_END': int(values[22]), + 'SILENCE_SECTION_BREAK': int(values[23]), + 'SILENCE_PARAGRAPH_END': int(values[24]), + 'SILENCE_COMMA': int(values[25]), + 'SILENCE_PERIOD': int(values[26]), + 'SILENCE_QUESTION_MARK': int(values[27]), + 'SILENCE_EXCLAMATION': int(values[28]), + 'CHUNK_END_SILENCE_MS': int(values[30]) if values[29] else 0 + } + + # Import the config module and update values + from config import config + for key, value in config_values.items(): + if hasattr(config, key): + setattr(config, key, value) + + return "βœ… Configuration saved successfully!\nπŸ”„ Settings updated in memory. Restart application to persist changes." + + except Exception as e: + return f"❌ Error saving configuration: {str(e)}" + + def reset_configuration(): + """Reset all configuration values to defaults""" + try: + # Return default values for all controls + return ( + 2, # MAX_WORKERS + 100, # BATCH_SIZE + 5, # MIN_CHUNK_WORDS + 25, # MAX_CHUNK_WORDS + True, # ENABLE_NORMALIZATION + -16, # TARGET_LUFS + True, # ENABLE_AUDIO_TRIMMING + 0.005, # SPEECH_ENDPOINT_THRESHOLD + 100, # TRIMMING_BUFFER_MS + 0.0, # TTS_PARAM_MIN_EXAGGERATION + 2.0, # TTS_PARAM_MAX_EXAGGERATION + 0.0, # TTS_PARAM_MIN_CFG_WEIGHT + 1.0, # TTS_PARAM_MAX_CFG_WEIGHT + 0.0, # TTS_PARAM_MIN_TEMPERATURE + 5.0, # TTS_PARAM_MAX_TEMPERATURE + 0.5, # DEFAULT_EXAGGERATION + 0.5, # DEFAULT_CFG_WEIGHT + 0.8, # DEFAULT_TEMPERATURE + 0.3, # VADER_EXAGGERATION_SENSITIVITY + 0.3, # VADER_CFG_WEIGHT_SENSITIVITY + 0.3, # VADER_TEMPERATURE_SENSITIVITY + 1000, # SILENCE_CHAPTER_START + 1500, # SILENCE_CHAPTER_END + 800, # SILENCE_SECTION_BREAK + 500, # SILENCE_PARAGRAPH_END + 200, # SILENCE_COMMA + 400, # SILENCE_PERIOD + 500, # SILENCE_QUESTION_MARK + 500, # SILENCE_EXCLAMATION + False, # chunk_end_silence_check + 0, # CHUNK_END_SILENCE_MS + "πŸ”„ Configuration reset to default values" + ) + + except Exception as e: + return tuple([None] * 30 + [f"❌ Error resetting configuration: {str(e)}"]) + + def reload_configuration(): + """Reload configuration from file""" + try: + if not CONFIG_AVAILABLE: + return "❌ Configuration module not available" + + # Reload config module + import importlib + from config import config + importlib.reload(config) + + # Return reloaded values + return ( + config.MAX_WORKERS, + config.BATCH_SIZE, + config.MIN_CHUNK_WORDS, + config.MAX_CHUNK_WORDS, + config.ENABLE_NORMALIZATION, + config.TARGET_LUFS, + config.ENABLE_AUDIO_TRIMMING, + config.SPEECH_ENDPOINT_THRESHOLD, + config.TRIMMING_BUFFER_MS, + config.TTS_PARAM_MIN_EXAGGERATION, + config.TTS_PARAM_MAX_EXAGGERATION, + config.TTS_PARAM_MIN_CFG_WEIGHT, + config.TTS_PARAM_MAX_CFG_WEIGHT, + config.TTS_PARAM_MIN_TEMPERATURE, + config.TTS_PARAM_MAX_TEMPERATURE, + config.DEFAULT_EXAGGERATION, + config.DEFAULT_CFG_WEIGHT, + config.DEFAULT_TEMPERATURE, + config.VADER_EXAGGERATION_SENSITIVITY, + config.VADER_CFG_WEIGHT_SENSITIVITY, + config.VADER_TEMPERATURE_SENSITIVITY, + config.SILENCE_CHAPTER_START, + config.SILENCE_CHAPTER_END, + config.SILENCE_SECTION_BREAK, + config.SILENCE_PARAGRAPH_END, + config.SILENCE_COMMA, + config.SILENCE_PERIOD, + config.SILENCE_QUESTION_MARK, + config.SILENCE_EXCLAMATION, + config.CHUNK_END_SILENCE_MS > 0, + config.CHUNK_END_SILENCE_MS, + "βœ… Configuration reloaded from file" + ) + + except Exception as e: + return tuple([None] * 30 + [f"❌ Error reloading configuration: {str(e)}"]) + + # All input components for save operation + all_inputs = [ + workers_spin, batch_size_spin, min_chunk_words_spin, max_chunk_words_spin, + normalization_check, target_lufs_spin, audio_trimming_check, + speech_threshold_spin, trimming_buffer_spin, + exag_min_spin, exag_max_spin, cfg_min_spin, cfg_max_spin, + temp_min_spin, temp_max_spin, + default_exag_spin, default_cfg_spin, default_temp_spin, + vader_exag_sens_spin, vader_cfg_sens_spin, vader_temp_sens_spin, + silence_chapter_start_spin, silence_chapter_end_spin, + silence_section_spin, silence_paragraph_spin, + silence_comma_spin, silence_period_spin, + silence_question_spin, silence_exclamation_spin, + chunk_end_silence_check, chunk_end_silence_spin + ] + + # All output components for reset/reload operations + all_outputs = all_inputs + [config_status] + + # Connect event handlers + save_config_btn.click( + save_configuration, + inputs=all_inputs, + outputs=[config_status] + ) + + reset_config_btn.click( + reset_configuration, + inputs=[], + outputs=all_outputs + ) + + reload_config_btn.click( + reload_configuration, + inputs=[], + outputs=all_outputs + ) + + return { + 'save_button': save_config_btn, + 'reset_button': reset_config_btn, + 'reload_button': reload_config_btn, + 'status_display': config_status + } + +if __name__ == "__main__": + # Test the tab + with gr.Blocks() as demo: + create_configuration_tab() + + demo.launch() \ No newline at end of file diff --git a/gradio_tabs/tab4_combine_audio.py b/gradio_tabs/tab4_combine_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..3b090052637721d6875d5b4f7087155a96f95817 --- /dev/null +++ b/gradio_tabs/tab4_combine_audio.py @@ -0,0 +1,511 @@ +#!/usr/bin/env python3 +""" +Gradio Tab 4: Combine Audio +Combine processed audio chunks into final audiobook - matches PyQt5 GUI Tab 4 functionality +""" + +import gradio as gr +import os +import sys +import threading +import time +from pathlib import Path +from typing import List, Dict, Any, Optional, Tuple + +# Import backend functionality +try: + from tools.combine_only import combine_audio_for_book + from modules.file_manager import get_audio_files_in_directory + from modules.audio_processor import get_wav_duration + COMBINE_AVAILABLE = True + print("βœ… Audio combine functionality available") +except ImportError as e: + print(f"⚠️ Combine functionality not available: {e}") + COMBINE_AVAILABLE = False + +# Global state for combine operations +combine_state = { + 'running': False, + 'progress': 0, + 'status': 'Ready', + 'thread': None, + 'current_book': None +} + +def get_available_books(): + """Get list of books with audio chunks available for combining""" + books = [] + + # Look in Audiobook directory for processed books + audiobook_root = Path("Audiobook") + if audiobook_root.exists(): + for book_dir in audiobook_root.iterdir(): + if book_dir.is_dir(): + # Check for TTS/audio_chunks directory + audio_chunks_dir = book_dir / "TTS" / "audio_chunks" + if audio_chunks_dir.exists(): + chunk_files = list(audio_chunks_dir.glob("chunk_*.wav")) + if chunk_files: + # Get basic info about the chunks + chunk_count = len(chunk_files) + try: + total_duration = sum(get_wav_duration(chunk) for chunk in chunk_files) + duration_str = f"{int(total_duration//3600):02d}:{int((total_duration%3600)//60):02d}:{int(total_duration%60):02d}" + except: + duration_str = "Unknown" + + books.append({ + 'name': book_dir.name, + 'path': str(book_dir), + 'chunk_count': chunk_count, + 'duration': duration_str, + 'status': 'Ready to combine' + }) + + return sorted(books, key=lambda x: x['name']) + +def get_book_info(book_path_str): + """Get detailed information about a book's audio chunks""" + if not book_path_str: + return "No book selected" + + try: + book_path = Path(book_path_str) + audio_chunks_dir = book_path / "TTS" / "audio_chunks" + + if not audio_chunks_dir.exists(): + return f"❌ No audio chunks found in {book_path.name}" + + chunk_files = get_audio_files_in_directory(audio_chunks_dir) + if not chunk_files: + return f"❌ No valid chunk files found in {book_path.name}" + + # Calculate statistics + total_chunks = len(chunk_files) + try: + total_duration = sum(get_wav_duration(chunk) for chunk in chunk_files) + duration_str = f"{int(total_duration//3600):02d}:{int((total_duration%3600)//60):02d}:{int(total_duration%60):02d}" + avg_duration = total_duration / total_chunks if total_chunks > 0 else 0 + except Exception as e: + duration_str = "Error calculating" + avg_duration = 0 + + # Check for existing combined files + combined_wav = book_path / f"{book_path.name}_combined.wav" + combined_m4b = book_path / f"{book_path.name}_combined.m4b" + + existing_files = [] + if combined_wav.exists(): + size_mb = combined_wav.stat().st_size / (1024 * 1024) + existing_files.append(f"WAV: {size_mb:.1f}MB") + if combined_m4b.exists(): + size_mb = combined_m4b.stat().st_size / (1024 * 1024) + existing_files.append(f"M4B: {size_mb:.1f}MB") + + info = f"""πŸ“Š **Book Analysis: {book_path.name}** + +**Audio Chunks:** +β€’ Total Chunks: {total_chunks} +β€’ Total Duration: {duration_str} +β€’ Average Chunk: {avg_duration:.1f}s +β€’ Location: {audio_chunks_dir} + +**Existing Combined Files:** +{f"β€’ {', '.join(existing_files)}" if existing_files else "β€’ None found"} + +**Status:** Ready for combining""" + + return info + + except Exception as e: + return f"❌ Error analyzing book: {str(e)}" + +def run_combine_operation(book_path_str, voice_name=None): + """Run the audio combine operation""" + try: + if not COMBINE_AVAILABLE: + return {'success': False, 'error': 'Combine functionality not available'} + + print(f"πŸ”— Starting combine operation for: {book_path_str}") + + # Update combine state + combine_state['running'] = True + combine_state['progress'] = 10 + combine_state['status'] = 'Starting combine operation...' + combine_state['current_book'] = Path(book_path_str).name + + # Run the actual combine operation + result = combine_audio_for_book(book_path_str, voice_name) + + combine_state['progress'] = 100 + combine_state['status'] = 'βœ… Combine completed successfully!' if result else '❌ Combine failed' + + return {'success': result, 'message': 'Audio chunks combined successfully!' if result else 'Combine operation failed'} + + except Exception as e: + print(f"❌ Combine operation failed: {e}") + combine_state['status'] = f'❌ Error: {str(e)}' + combine_state['progress'] = 0 + return {'success': False, 'error': str(e)} + finally: + combine_state['running'] = False + +def create_combine_audio_tab(): + """Create Tab 4: Combine Audio with all GUI functionality""" + + with gr.Column(): + gr.Markdown("# πŸ”— Combine Audio Chunks") + gr.Markdown("*Combine processed audio chunks into final audiobook - matches GUI Tab 4*") + + # Important note + gr.Markdown(""" + ### πŸ“ **Important Note** + Select the main book folder (e.g., 'Audiobook/BookName'), NOT the TTS or audio_chunks subfolder. + This tool combines existing audio chunks that have already been generated by the TTS process. + """) + + # Book Selection Section + with gr.Row(): + with gr.Column(scale=2): + gr.Markdown("### πŸ“š Select Book to Combine") + + # Available books dropdown + available_books = get_available_books() + book_choices = [f"{book['name']} ({book['chunk_count']} chunks, {book['duration']})" + for book in available_books] + book_paths = {f"{book['name']} ({book['chunk_count']} chunks, {book['duration']})": book['path'] + for book in available_books} + + book_selector = gr.Dropdown( + label="Available Books with Audio Chunks", + choices=book_choices, + value=book_choices[0] if book_choices else None, + interactive=True, + info="Books with processed audio chunks ready for combining" + ) + + # Manual path input (for advanced users) + manual_path_input = gr.Textbox( + label="Or Enter Book Path Manually", + placeholder="e.g., /path/to/Audiobook/BookName", + interactive=True, + info="Full path to book folder containing TTS/audio_chunks" + ) + + # Refresh books button + refresh_books_btn = gr.Button( + "πŸ”„ Refresh Book List", + variant="secondary", + size="sm" + ) + + with gr.Column(scale=1): + # Book information display + book_info_display = gr.Markdown( + "Select a book to see detailed information", + label="Book Information" + ) + + # Optional Voice Name + with gr.Row(): + voice_name_input = gr.Textbox( + label="Voice Name (Optional)", + placeholder="e.g., NarratorName", + info="Used for output filename. If empty, uses '_combined' suffix", + interactive=True + ) + + # Action Buttons + with gr.Row(): + combine_btn = gr.Button( + "πŸ”— Combine Audio Chunks", + variant="primary", + size="lg", + interactive=True + ) + + stop_btn = gr.Button( + "⏹️ Stop Operation", + variant="secondary", + size="lg", + interactive=False + ) + + # Status and Progress + with gr.Row(): + with gr.Column(scale=2): + status_display = gr.Textbox( + label="Operation Status", + value="Ready to combine audio chunks", + interactive=False, + lines=2 + ) + + progress_display = gr.Number( + label="Progress %", + value=0, + interactive=False, + precision=0 + ) + + with gr.Column(scale=1): + # Operation details + current_book_display = gr.Textbox( + label="Current Book", + value="--", + interactive=False + ) + + operation_time_display = gr.Textbox( + label="Operation Time", + value="--:--:--", + interactive=False + ) + + # Output Files Section + with gr.Column(): + gr.Markdown("### πŸ“ Generated Files") + output_files_display = gr.Markdown( + "No files generated yet", + label="Output Files" + ) + + # Event Handlers + def update_book_info(selected_book): + """Update book information when selection changes""" + if selected_book and selected_book in book_paths: + book_path = book_paths[selected_book] + info = get_book_info(book_path) + return info + return "No book selected" + + def refresh_book_list(): + """Refresh the list of available books""" + books = get_available_books() + choices = [f"{book['name']} ({book['chunk_count']} chunks, {book['duration']})" + for book in books] + paths = {f"{book['name']} ({book['chunk_count']} chunks, {book['duration']})": book['path'] + for book in books} + + # Update global book_paths + nonlocal book_paths + book_paths = paths + + return gr.update(choices=choices, value=choices[0] if choices else None) + + def get_selected_book_path(selected_book, manual_path): + """Get the actual book path from selection or manual input""" + if manual_path.strip(): + return manual_path.strip() + elif selected_book and selected_book in book_paths: + return book_paths[selected_book] + return None + + def start_combine_operation(selected_book, manual_path, voice_name): + """Start the combine operation""" + # Validation + book_path = get_selected_book_path(selected_book, manual_path) + if not book_path: + return ( + "❌ Please select a book or enter a manual path", + 0, + "Error", + "--:--:--", + "No files generated", + gr.update(interactive=False), + gr.update(interactive=True) + ) + + # Check if already running + if combine_state['running']: + return ( + "⚠️ Combine operation already in progress", + combine_state['progress'], + combine_state.get('current_book', '--'), + "--:--:--", + "Operation in progress...", + gr.update(interactive=False), + gr.update(interactive=True) + ) + + try: + # Start combine operation in background thread + def run_combine_thread(): + start_time = time.time() + try: + result = run_combine_operation(book_path, voice_name.strip() or None) + elapsed = time.time() - start_time + combine_state['elapsed'] = elapsed + + if result['success']: + combine_state['status'] = 'βœ… Audio combining completed successfully!' + + # Find generated files + book_path_obj = Path(book_path) + suffix = f" [{voice_name.strip()}]" if voice_name.strip() else "_combined" + + generated_files = [] + wav_file = book_path_obj / f"{book_path_obj.name}{suffix}.wav" + m4b_file = book_path_obj / f"{book_path_obj.name}{suffix}.m4b" + + if wav_file.exists(): + size_mb = wav_file.stat().st_size / (1024 * 1024) + generated_files.append(f"**WAV**: {wav_file.name} ({size_mb:.1f}MB)") + + if m4b_file.exists(): + size_mb = m4b_file.stat().st_size / (1024 * 1024) + generated_files.append(f"**M4B**: {m4b_file.name} ({size_mb:.1f}MB)") + + combine_state['generated_files'] = "\n".join(generated_files) if generated_files else "No files found" + else: + combine_state['status'] = f"❌ Combine failed: {result.get('error', 'Unknown error')}" + combine_state['generated_files'] = "No files generated due to error" + + except Exception as e: + combine_state['status'] = f"❌ Error: {str(e)}" + combine_state['generated_files'] = "No files generated due to error" + finally: + combine_state['running'] = False + + # Start thread + thread = threading.Thread(target=run_combine_thread) + thread.start() + combine_state['thread'] = thread + + return ( + "πŸš€ Starting combine operation...", + 5, # Initial progress + Path(book_path).name, + "00:00:00", + "Starting operation...", + gr.update(interactive=False), # Disable combine button + gr.update(interactive=True) # Enable stop button + ) + + except Exception as e: + return ( + f"❌ Error starting combine: {str(e)}", + 0, + "Error", + "--:--:--", + "No files generated", + gr.update(interactive=True), + gr.update(interactive=False) + ) + + def stop_combine_operation(): + """Stop the current combine operation""" + if combine_state['running']: + combine_state['running'] = False + combine_state['status'] = '⏹️ Operation stopped by user' + combine_state['progress'] = 0 + + return ( + "⏹️ Operation stopped by user", + 0, + "--", + "--:--:--", + "Operation stopped", + gr.update(interactive=True), # Enable combine button + gr.update(interactive=False) # Disable stop button + ) + else: + return ( + "No operation to stop", + combine_state.get('progress', 0), + combine_state.get('current_book', '--'), + "--:--:--", + combine_state.get('generated_files', 'No files generated'), + gr.update(interactive=True), + gr.update(interactive=False) + ) + + def get_current_status(): + """Get current operation status for periodic updates""" + if combine_state['running']: + elapsed = time.time() - combine_state.get('start_time', time.time()) + elapsed_str = f"{int(elapsed//3600):02d}:{int((elapsed%3600)//60):02d}:{int(elapsed%60):02d}" + + return ( + combine_state.get('status', 'Processing...'), + combine_state.get('progress', 0), + combine_state.get('current_book', '--'), + elapsed_str, + combine_state.get('generated_files', 'Processing...'), + gr.update(interactive=False), + gr.update(interactive=True) + ) + else: + # Operation completed or not running + elapsed = combine_state.get('elapsed', 0) + elapsed_str = f"{int(elapsed//3600):02d}:{int((elapsed%3600)//60):02d}:{int(elapsed%60):02d}" if elapsed > 0 else "--:--:--" + + return ( + combine_state.get('status', 'Ready'), + combine_state.get('progress', 0), + combine_state.get('current_book', '--'), + elapsed_str, + combine_state.get('generated_files', 'No files generated'), + gr.update(interactive=True), + gr.update(interactive=False) + ) + + # Connect event handlers + book_selector.change( + update_book_info, + inputs=[book_selector], + outputs=[book_info_display] + ) + + refresh_books_btn.click( + refresh_book_list, + inputs=[], + outputs=[book_selector] + ) + + combine_btn.click( + start_combine_operation, + inputs=[book_selector, manual_path_input, voice_name_input], + outputs=[ + status_display, progress_display, current_book_display, + operation_time_display, output_files_display, + combine_btn, stop_btn + ] + ) + + stop_btn.click( + stop_combine_operation, + inputs=[], + outputs=[ + status_display, progress_display, current_book_display, + operation_time_display, output_files_display, + combine_btn, stop_btn + ] + ) + + # Status refresh button + with gr.Row(): + refresh_status_btn = gr.Button("πŸ”„ Refresh Status", size="sm", variant="secondary") + + refresh_status_btn.click( + get_current_status, + inputs=[], + outputs=[ + status_display, progress_display, current_book_display, + operation_time_display, output_files_display, + combine_btn, stop_btn + ] + ) + + return { + 'combine_button': combine_btn, + 'status_display': status_display, + 'progress': progress_display + } + +if __name__ == "__main__": + # Test the tab + with gr.Blocks() as demo: + create_combine_audio_tab() + + demo.launch() \ No newline at end of file diff --git a/gradio_tabs/tab5_prepare_text.py b/gradio_tabs/tab5_prepare_text.py new file mode 100644 index 0000000000000000000000000000000000000000..45d60cf120c2f16eaeb86583eacb50dbc1527eab --- /dev/null +++ b/gradio_tabs/tab5_prepare_text.py @@ -0,0 +1,658 @@ +#!/usr/bin/env python3 +""" +Gradio Tab 5: Prepare Text +Text file preparation and chunking with VADER analysis - matches PyQt5 GUI Tab 5 functionality +""" + +import gradio as gr +import os +import sys +import threading +import time +import json +from pathlib import Path +from typing import List, Dict, Any, Optional, Tuple + +# Import backend functionality +try: + from modules.tts_engine import generate_enriched_chunks + from config.config import ( + AUDIOBOOK_ROOT, TEXT_INPUT_ROOT, + BASE_EXAGGERATION, BASE_CFG_WEIGHT, BASE_TEMPERATURE, + DEFAULT_MIN_P, DEFAULT_TOP_P, DEFAULT_REPETITION_PENALTY, + ENABLE_SENTIMENT_SMOOTHING, SENTIMENT_SMOOTHING_WINDOW, SENTIMENT_SMOOTHING_METHOD, + VADER_EXAGGERATION_SENSITIVITY, VADER_CFG_WEIGHT_SENSITIVITY, VADER_TEMPERATURE_SENSITIVITY + ) + PREPARE_TEXT_AVAILABLE = True + print("βœ… Text preparation functionality available") +except ImportError as e: + print(f"⚠️ Text preparation functionality not available: {e}") + PREPARE_TEXT_AVAILABLE = False + # Default values if config not available + BASE_EXAGGERATION = 0.5 + BASE_CFG_WEIGHT = 0.5 + BASE_TEMPERATURE = 0.8 + DEFAULT_MIN_P = 0.1 + DEFAULT_TOP_P = 0.9 + DEFAULT_REPETITION_PENALTY = 1.0 + ENABLE_SENTIMENT_SMOOTHING = True + SENTIMENT_SMOOTHING_WINDOW = 3 + SENTIMENT_SMOOTHING_METHOD = "gaussian" + VADER_EXAGGERATION_SENSITIVITY = 0.3 + VADER_CFG_WEIGHT_SENSITIVITY = 0.3 + VADER_TEMPERATURE_SENSITIVITY = 0.3 + +# Global state for text preparation +prepare_state = { + 'preparation_running': False, + 'current_file': None, + 'progress': 0, + 'status': 'Ready', + 'generated_chunks': 0, + 'output_path': None +} + +def get_available_text_files(): + """Find available text files for preparation""" + text_files = [] + + if not PREPARE_TEXT_AVAILABLE: + return text_files + + # Look in Text_Input directory structure + text_input_root = Path(TEXT_INPUT_ROOT) if 'TEXT_INPUT_ROOT' in globals() else Path("Text_Input") + if text_input_root.exists(): + # Look for text files in subdirectories (book folders) + for book_dir in text_input_root.iterdir(): + if book_dir.is_dir(): + for text_file in book_dir.glob("*.txt"): + try: + # Check if file has content + if text_file.stat().st_size > 0: + text_files.append({ + 'name': f"{book_dir.name}/{text_file.name}", + 'path': str(text_file), + 'book_name': book_dir.name, + 'file_name': text_file.name, + 'size': text_file.stat().st_size, + 'display': f"{book_dir.name}/{text_file.name} ({text_file.stat().st_size // 1024}KB)" + }) + except: + pass + + # Also look for direct text files in Text_Input root + for text_file in text_input_root.glob("*.txt"): + try: + if text_file.stat().st_size > 0: + text_files.append({ + 'name': text_file.name, + 'path': str(text_file), + 'book_name': text_file.stem, + 'file_name': text_file.name, + 'size': text_file.stat().st_size, + 'display': f"{text_file.name} ({text_file.stat().st_size // 1024}KB)" + }) + except: + pass + + return sorted(text_files, key=lambda x: x['name']) + +def load_text_file_info(file_selection): + """Load information about selected text file""" + if not file_selection or file_selection == "-- Select Text File --": + return "No text file selected", "No file loaded" + + try: + # Find the selected file + text_files = get_available_text_files() + selected_file = None + for tf in text_files: + if tf['display'] == file_selection: + selected_file = tf + break + + if not selected_file: + return "❌ Selected file not found", "Error" + + prepare_state['current_file'] = selected_file + + # Analyze text file + text_path = Path(selected_file['path']) + with open(text_path, 'r', encoding='utf-8') as f: + content = f.read() + + # Basic statistics + char_count = len(content) + word_count = len(content.split()) + line_count = len(content.splitlines()) + paragraph_count = len([p for p in content.split('\n\n') if p.strip()]) + + # Estimate chunks (rough calculation) + estimated_chunks = max(1, word_count // 20) # Assuming ~20 words per chunk average + + # Check for existing processed version + book_name = selected_file['book_name'] + existing_json_path = Path(AUDIOBOOK_ROOT) / book_name / "TTS" / "text_chunks" / "chunks_info.json" + existing_status = "" + if existing_json_path.exists(): + try: + with open(existing_json_path, 'r') as f: + existing_data = json.load(f) + existing_chunks = len(existing_data) + existing_status = f"\n\n**⚠️ Already Processed:**\nβ€’ Existing JSON: {existing_chunks} chunks\nβ€’ Path: {existing_json_path}" + except: + existing_status = f"\n\n**⚠️ Partial Processing:**\nβ€’ JSON file exists but may be corrupted" + + info = f"""**πŸ“„ Text File Analysis:** +**File:** {selected_file['name']} +**Path:** {text_path} +**Book Name:** {book_name} + +**Content Statistics:** +β€’ File Size: {char_count:,} characters ({selected_file['size'] // 1024}KB) +β€’ Word Count: {word_count:,} words +β€’ Lines: {line_count:,} +β€’ Paragraphs: {paragraph_count:,} +β€’ Estimated Chunks: ~{estimated_chunks} + +**Processing Status:** +β€’ Ready for preparation: {'βœ… Yes' if word_count > 0 else '❌ Empty file'}{existing_status}""" + + current_file = f"πŸ“ Selected: {selected_file['name']} ({word_count:,} words)" + + return info, current_file + + except Exception as e: + return f"❌ Error loading text file: {str(e)}", "Error loading file" + +def start_text_preparation( + file_selection, + use_vader, exaggeration, cfg_weight, temperature, min_p, top_p, repetition_penalty, + sentiment_smoothing, smoothing_window, smoothing_method, + vader_exag_sens, vader_cfg_sens, vader_temp_sens +): + """Start text preparation with enriched chunking""" + if prepare_state['preparation_running']: + return "⚠️ Text preparation already in progress", 0, "Processing...", "Generation running..." + + if not file_selection or file_selection == "-- Select Text File --": + return "❌ Please select a text file", 0, "No file selected", "Ready" + + try: + # Find selected file + text_files = get_available_text_files() + selected_file = None + for tf in text_files: + if tf['display'] == file_selection: + selected_file = tf + break + + if not selected_file: + return "❌ Invalid file selection", 0, "Selection error", "Ready" + + prepare_state['current_file'] = selected_file + prepare_state['preparation_running'] = True + prepare_state['progress'] = 0 + prepare_state['status'] = 'Starting text preparation...' + + # Start preparation in background thread + def preparation_worker(): + try: + prepare_state['status'] = 'πŸ“ Analyzing text and generating chunks...' + prepare_state['progress'] = 10 + + text_path = Path(selected_file['path']) + book_name = selected_file['book_name'] + + # Create output directory + book_output_dir = Path(AUDIOBOOK_ROOT) / book_name / "TTS" / "text_chunks" + book_output_dir.mkdir(parents=True, exist_ok=True) + + prepare_state['progress'] = 20 + prepare_state['status'] = '🎭 Applying VADER sentiment analysis...' + + # Prepare TTS parameters + user_tts_params = { + 'exaggeration': exaggeration, + 'cfg_weight': cfg_weight, + 'temperature': temperature, + 'min_p': min_p, + 'top_p': top_p, + 'repetition_penalty': repetition_penalty, + 'use_vader': use_vader + } + + # Prepare quality parameters + quality_params = { + 'sentiment_smoothing': sentiment_smoothing, + 'smoothing_window': int(smoothing_window), + 'smoothing_method': smoothing_method + } + + # Prepare config parameters + config_params = { + 'vader_exag_sensitivity': vader_exag_sens, + 'vader_cfg_sensitivity': vader_cfg_sens, + 'vader_temp_sensitivity': vader_temp_sens + } + + prepare_state['progress'] = 50 + prepare_state['status'] = 'πŸ”¬ Generating enriched chunks with metadata...' + + # Generate enriched chunks + enriched_chunks = generate_enriched_chunks( + text_path, + book_output_dir, + user_tts_params, + quality_params, + config_params + ) + + prepare_state['progress'] = 90 + prepare_state['status'] = 'πŸ’Ύ Saving JSON metadata...' + + json_path = book_output_dir / "chunks_info.json" + prepare_state['generated_chunks'] = len(enriched_chunks) + prepare_state['output_path'] = str(json_path) + + prepare_state['progress'] = 100 + prepare_state['status'] = 'βœ… Text preparation completed successfully!' + + except Exception as e: + prepare_state['progress'] = 0 + prepare_state['status'] = f'❌ Preparation error: {str(e)}' + prepare_state['generated_chunks'] = 0 + prepare_state['output_path'] = None + finally: + prepare_state['preparation_running'] = False + + # Start worker thread + threading.Thread(target=preparation_worker, daemon=True).start() + + return ( + "πŸ“ Starting text preparation with VADER analysis...", + 10, + f"Processing: {selected_file['name']}", + "Preparation started" + ) + + except Exception as e: + prepare_state['preparation_running'] = False + return f"❌ Error starting preparation: {str(e)}", 0, "Preparation failed", "Error" + +def get_preparation_status(): + """Get current preparation status""" + if prepare_state['generated_chunks'] > 0: + chunk_info = f" ({prepare_state['generated_chunks']} chunks generated)" + output_info = f"JSON saved: {prepare_state['output_path']}" if prepare_state['output_path'] else "Processing..." + else: + chunk_info = "" + output_info = "No output yet" if prepare_state['preparation_running'] else "Ready for text preparation" + + return ( + prepare_state.get('status', 'Ready') + chunk_info, + prepare_state.get('progress', 0), + output_info, + "Processing..." if prepare_state['preparation_running'] else "Ready" + ) + +def stop_text_preparation(): + """Stop current preparation (if possible)""" + if prepare_state['preparation_running']: + prepare_state['preparation_running'] = False + prepare_state['status'] = '⏹️ Preparation stopped by user' + prepare_state['progress'] = 0 + return "⏹️ Preparation stopped", 0, "Preparation stopped", "Ready" + else: + return "No preparation to stop", prepare_state.get('progress', 0), prepare_state.get('status', 'Ready'), "Ready" + +def create_prepare_text_tab(): + """Create Tab 5: Prepare Text with all GUI functionality""" + + with gr.Column(): + gr.Markdown("# πŸ“ Prepare Text for Processing") + gr.Markdown("*Text file preparation with VADER sentiment analysis and chunking - matches GUI Tab 5*") + + if not PREPARE_TEXT_AVAILABLE: + gr.Markdown("### ❌ Text Preparation Not Available") + gr.Markdown("Missing required backend modules. Please ensure modules/tts_engine.py is available.") + return {} + + # Important guidance note + gr.Markdown(""" + ### πŸ’‘ **Important Usage Information** + + This tool prepares text files for TTS conversion by: + - **Chunking**: Breaking text into optimal segments for TTS processing + - **VADER Analysis**: Applying sentiment analysis to adjust TTS parameters per chunk + - **JSON Generation**: Creating metadata files ready for audiobook generation + + **Configure your parameters below, then select and process your text file.** + """) + + # Text File Selection Section + with gr.Row(): + with gr.Column(scale=2): + gr.Markdown("### πŸ“„ Text File Selection") + + text_files = get_available_text_files() + text_choices = ["-- Select Text File --"] + [tf['display'] for tf in text_files] + + text_file_selector = gr.Dropdown( + label="Text File to Prepare", + choices=text_choices, + value="-- Select Text File --", + interactive=True, + info="Select text file from Text_Input directory for preparation" + ) + + # Manual path input + text_manual_path = gr.Textbox( + label="Or Enter Text File Path Manually", + placeholder="e.g., /path/to/book.txt", + interactive=True, + info="Direct path to text file" + ) + + refresh_files_btn = gr.Button( + "πŸ”„ Refresh File List", + variant="secondary", + size="sm" + ) + + with gr.Column(scale=1): + text_file_info = gr.Markdown( + "No text file selected", + label="File Information" + ) + + # TTS Base Parameters Section + with gr.Column(): + gr.Markdown("### βš™οΈ Base TTS Parameters") + gr.Markdown("*These parameters will be used as the baseline, with VADER adjustments applied per chunk*") + + with gr.Row(): + use_vader_check = gr.Checkbox( + label="Enable VADER Sentiment Analysis", + value=True, + info="Apply sentiment-based TTS parameter adjustments" + ) + + with gr.Row(): + exaggeration_param = gr.Slider( + label="Base Exaggeration", + minimum=0.0, maximum=2.0, step=0.1, + value=BASE_EXAGGERATION, + interactive=True, + info="Base speech exaggeration level" + ) + + cfg_weight_param = gr.Slider( + label="Base CFG Weight", + minimum=0.0, maximum=1.0, step=0.1, + value=BASE_CFG_WEIGHT, + interactive=True, + info="Base CFG guidance strength" + ) + + temperature_param = gr.Slider( + label="Base Temperature", + minimum=0.0, maximum=2.0, step=0.1, + value=BASE_TEMPERATURE, + interactive=True, + info="Base TTS randomness/creativity" + ) + + with gr.Row(): + min_p_param = gr.Slider( + label="Min P", + minimum=0.0, maximum=1.0, step=0.01, + value=DEFAULT_MIN_P, + interactive=True, + info="Minimum probability threshold" + ) + + top_p_param = gr.Slider( + label="Top P", + minimum=0.0, maximum=1.0, step=0.01, + value=DEFAULT_TOP_P, + interactive=True, + info="Nucleus sampling parameter" + ) + + repetition_penalty_param = gr.Slider( + label="Repetition Penalty", + minimum=0.5, maximum=2.0, step=0.1, + value=DEFAULT_REPETITION_PENALTY, + interactive=True, + info="Penalty for repetitive speech" + ) + + # Sentiment Processing Section + with gr.Column(): + gr.Markdown("### 🎭 Sentiment Analysis Settings") + + with gr.Row(): + sentiment_smoothing_check = gr.Checkbox( + label="Enable Sentiment Smoothing", + value=ENABLE_SENTIMENT_SMOOTHING, + info="Smooth sentiment scores across adjacent chunks" + ) + + smoothing_window_param = gr.Slider( + label="Smoothing Window", + minimum=1, maximum=10, step=1, + value=SENTIMENT_SMOOTHING_WINDOW, + interactive=True, + info="Number of chunks to include in smoothing" + ) + + smoothing_method_param = gr.Dropdown( + label="Smoothing Method", + choices=["gaussian", "moving_average", "exponential"], + value=SENTIMENT_SMOOTHING_METHOD, + interactive=True, + info="Algorithm for sentiment smoothing" + ) + + # VADER Sensitivity Section + with gr.Column(): + gr.Markdown("### 🎚️ VADER Sensitivity Settings") + gr.Markdown("*Control how much sentiment analysis affects TTS parameters*") + + with gr.Row(): + vader_exag_sens_param = gr.Slider( + label="Exaggeration Sensitivity", + minimum=0.0, maximum=1.0, step=0.01, + value=VADER_EXAGGERATION_SENSITIVITY, + interactive=True, + info="How much sentiment affects exaggeration" + ) + + vader_cfg_sens_param = gr.Slider( + label="CFG Sensitivity", + minimum=0.0, maximum=1.0, step=0.01, + value=VADER_CFG_WEIGHT_SENSITIVITY, + interactive=True, + info="How much sentiment affects CFG weight" + ) + + vader_temp_sens_param = gr.Slider( + label="Temperature Sensitivity", + minimum=0.0, maximum=1.0, step=0.01, + value=VADER_TEMPERATURE_SENSITIVITY, + interactive=True, + info="How much sentiment affects temperature" + ) + + # Preparation Controls + with gr.Row(): + prepare_btn = gr.Button( + "πŸ“ Prepare Text for Chunking", + variant="primary", + size="lg", + interactive=True + ) + + stop_btn = gr.Button( + "⏹️ Stop Preparation", + variant="secondary", + size="lg", + interactive=True + ) + + # Progress and Status + with gr.Row(): + with gr.Column(scale=2): + preparation_status = gr.Textbox( + label="Preparation Status", + value="Ready for text preparation", + interactive=False, + lines=2 + ) + + progress_bar = gr.Slider( + label="Progress %", + minimum=0, maximum=100, step=1, + value=0, + interactive=False, + info="Preparation progress" + ) + + with gr.Column(scale=1): + current_file_display = gr.Textbox( + label="Current File", + value="No file selected", + interactive=False + ) + + operation_status = gr.Textbox( + label="Operation Status", + value="Ready", + interactive=False + ) + + # Output Information + with gr.Column(): + gr.Markdown("### πŸ“ Generated Output") + + output_info = gr.Textbox( + label="Generated JSON File", + value="No output generated yet", + interactive=False, + info="Location of generated chunks_info.json file" + ) + + with gr.Row(): + refresh_status_btn = gr.Button( + "πŸ”„ Refresh Status", + variant="secondary", + size="sm" + ) + + next_steps_btn = gr.Button( + "➑️ Next Steps Guide", + variant="secondary", + size="sm" + ) + + next_steps_info = gr.Markdown( + "*After preparation completes, use Tab 1 (Convert Book) or Tab 8 (JSON Generate) to create the audiobook.*", + visible=False + ) + + # Event Handlers + def refresh_file_list(): + """Refresh text files list""" + text_files = get_available_text_files() + choices = ["-- Select Text File --"] + [tf['display'] for tf in text_files] + return gr.update(choices=choices, value="-- Select Text File --") + + def show_next_steps(): + """Show next steps information""" + return """## πŸ“‹ Next Steps After Text Preparation + +**Your text has been prepared and is ready for audiobook generation!** + +### Option 1: Use Tab 1 (Convert Book) - **Recommended** +1. Go to **Tab 1: Convert Book** +2. Select your prepared book from the dropdown +3. Choose a voice sample +4. Click "Generate Audiobook" for full processing + +### Option 2: Use Tab 8 (JSON Generate) - **Advanced** +1. Go to **Tab 8: JSON Generate** +2. Select the generated JSON file +3. Choose a voice sample +4. Generate audiobook directly from JSON + +### Files Created: +- **JSON Chunks**: `Audiobook/[BookName]/TTS/text_chunks/chunks_info.json` +- **Metadata**: Includes sentiment analysis and TTS parameters per chunk +- **Ready**: For immediate audiobook generation + +### Benefits of Preparation: +- βœ… **VADER Analysis**: Sentiment-based TTS parameter adjustment +- βœ… **Optimized Chunks**: Smart text segmentation for better TTS +- βœ… **Metadata Rich**: Each chunk has custom TTS parameters +- βœ… **Faster Generation**: Skip text processing in future runs +""" + + # Connect event handlers + refresh_files_btn.click( + refresh_file_list, + inputs=[], + outputs=[text_file_selector] + ) + + text_file_selector.change( + load_text_file_info, + inputs=[text_file_selector], + outputs=[text_file_info, current_file_display] + ) + + prepare_btn.click( + start_text_preparation, + inputs=[ + text_file_selector, + use_vader_check, exaggeration_param, cfg_weight_param, temperature_param, + min_p_param, top_p_param, repetition_penalty_param, + sentiment_smoothing_check, smoothing_window_param, smoothing_method_param, + vader_exag_sens_param, vader_cfg_sens_param, vader_temp_sens_param + ], + outputs=[preparation_status, progress_bar, output_info, operation_status] + ) + + stop_btn.click( + stop_text_preparation, + inputs=[], + outputs=[preparation_status, progress_bar, output_info, operation_status] + ) + + refresh_status_btn.click( + get_preparation_status, + inputs=[], + outputs=[preparation_status, progress_bar, output_info, operation_status] + ) + + next_steps_btn.click( + show_next_steps, + inputs=[], + outputs=[next_steps_info] + ).then( + lambda: gr.update(visible=True), + outputs=[next_steps_info] + ) + + return { + 'file_selector': text_file_selector, + 'prepare_button': prepare_btn, + 'status_display': preparation_status + } + +if __name__ == "__main__": + # Test the tab + with gr.Blocks() as demo: + create_prepare_text_tab() + + demo.launch() \ No newline at end of file diff --git a/gradio_tabs/tab6_settings.py b/gradio_tabs/tab6_settings.py new file mode 100644 index 0000000000000000000000000000000000000000..74190e107859264bd6adfdf7cf86aeeb917b06d8 --- /dev/null +++ b/gradio_tabs/tab6_settings.py @@ -0,0 +1,384 @@ +#!/usr/bin/env python3 +""" +Gradio Tab 6: Settings +Configuration management with live reload functionality +""" + +import gradio as gr +import os +import sys +import json +import importlib +from pathlib import Path +from typing import Dict, Any, Tuple, List + +# Import configuration +try: + from config import config + CONFIG_MODULE = config + CONFIG_AVAILABLE = True + print("βœ… Config module loaded successfully") +except ImportError as e: + print(f"⚠️ Config not available: {e}") + CONFIG_AVAILABLE = False + CONFIG_MODULE = None + +class ConfigManager: + """Manages configuration reloading and validation.""" + + def __init__(self): + self.config_file_path = Path("config/config.py") + self.current_config = {} + self.load_current_config() + + def load_current_config(self): + """Load current configuration values.""" + if not CONFIG_AVAILABLE: + return + + # Extract current config values + config_attrs = [attr for attr in dir(CONFIG_MODULE) if not attr.startswith('_')] + + for attr in config_attrs: + value = getattr(CONFIG_MODULE, attr) + # Only include simple types that can be edited + if isinstance(value, (int, float, str, bool, Path)): + self.current_config[attr] = value + + def reload_config(self) -> Tuple[bool, str]: + """ + Reload the configuration module. + + Returns: + Tuple of (success, message) + """ + try: + # Reload the config module + if CONFIG_MODULE: + importlib.reload(CONFIG_MODULE) + + # Update current config + self.load_current_config() + + return True, "βœ… Configuration reloaded successfully!" + except Exception as e: + return False, f"❌ Error reloading config: {str(e)}" + + def save_config_value(self, key: str, value: Any) -> Tuple[bool, str]: + """ + Save a configuration value (in-memory for now). + + Args: + key: Configuration key + value: New value + + Returns: + Tuple of (success, message) + """ + try: + if CONFIG_MODULE and hasattr(CONFIG_MODULE, key): + # Set the value in the module + setattr(CONFIG_MODULE, key, value) + self.current_config[key] = value + return True, f"βœ… {key} updated to {value}" + else: + return False, f"❌ Configuration key '{key}' not found" + except Exception as e: + return False, f"❌ Error updating {key}: {str(e)}" + + def get_config_categories(self) -> Dict[str, List[str]]: + """Group configuration keys by category based on prefixes and naming.""" + categories = { + "Core Directories": [], + "Text Processing": [], + "Performance": [], + "TTS Parameters": [], + "Audio Quality": [], + "VADER Sentiment": [], + "File Paths": [] + } + + for key in self.current_config.keys(): + key_lower = key.lower() + + if any(x in key_lower for x in ['dir', 'root', 'path', 'folder']): + categories["Core Directories"].append(key) + elif any(x in key_lower for x in ['chunk', 'word', 'text']): + categories["Text Processing"].append(key) + elif any(x in key_lower for x in ['worker', 'thread', 'vram', 'memory', 'performance']): + categories["Performance"].append(key) + elif any(x in key_lower for x in ['tts_param', 'temperature', 'cfg', 'exaggeration']): + categories["TTS Parameters"].append(key) + elif any(x in key_lower for x in ['audio', 'quality', 'validation', 'threshold']): + categories["Audio Quality"].append(key) + elif any(x in key_lower for x in ['vader', 'sentiment']): + categories["VADER Sentiment"].append(key) + else: + categories["File Paths"].append(key) + + # Remove empty categories + return {cat: keys for cat, keys in categories.items() if keys} + +def create_config_editor(config_manager: ConfigManager): + """Create configuration editor interface.""" + + with gr.Column(): + gr.Markdown("## πŸ”§ Configuration Editor") + gr.Markdown("*Edit configuration values and reload the system*") + + # Reload button + reload_btn = gr.Button("πŸ”„ Reload Configuration", variant="primary") + reload_status = gr.Textbox( + label="Status", + value="Ready to reload configuration", + interactive=False + ) + + # Configuration categories + categories = config_manager.get_config_categories() + + config_inputs = {} + + for category, keys in categories.items(): + with gr.Accordion(f"πŸ“ {category}", open=False): + for key in keys: + current_value = config_manager.current_config.get(key, "") + + # Create appropriate input based on value type + if isinstance(current_value, bool): + config_inputs[key] = gr.Checkbox( + label=key, + value=current_value, + info=f"Current: {current_value}" + ) + elif isinstance(current_value, int): + config_inputs[key] = gr.Number( + label=key, + value=current_value, + precision=0, + info=f"Current: {current_value}" + ) + elif isinstance(current_value, float): + config_inputs[key] = gr.Number( + label=key, + value=current_value, + precision=3, + info=f"Current: {current_value}" + ) + else: + config_inputs[key] = gr.Textbox( + label=key, + value=str(current_value), + info=f"Current: {current_value}" + ) + + # Save all button + save_all_btn = gr.Button("πŸ’Ύ Save All Changes", variant="secondary") + save_status = gr.Textbox( + label="Save Status", + value="No changes to save", + interactive=False + ) + + # Reload functionality + def reload_config(): + success, message = config_manager.reload_config() + + if success: + # Update all input values with reloaded config + updates = {} + for key, input_component in config_inputs.items(): + new_value = config_manager.current_config.get(key, "") + updates[input_component] = gr.update(value=new_value, info=f"Current: {new_value}") + + return message, *updates.values() + else: + return message, *[gr.update() for _ in config_inputs] + + # Save all functionality + def save_all_changes(*values): + results = [] + all_success = True + + for i, (key, input_component) in enumerate(config_inputs.items()): + value = values[i] + success, message = config_manager.save_config_value(key, value) + results.append(message) + if not success: + all_success = False + + if all_success: + return "βœ… All configuration changes saved successfully!" + else: + return f"⚠️ Some changes failed:\n" + "\n".join(results) + + # Wire up the reload button + reload_outputs = [reload_status] + list(config_inputs.values()) + reload_btn.click( + fn=reload_config, + outputs=reload_outputs + ) + + # Wire up the save button + save_all_btn.click( + fn=save_all_changes, + inputs=list(config_inputs.values()), + outputs=[save_status] + ) + +def create_config_backup(): + """Create configuration backup interface.""" + + with gr.Column(): + gr.Markdown("## πŸ’Ύ Configuration Backup") + gr.Markdown("*Backup and restore configuration settings*") + + with gr.Row(): + backup_btn = gr.Button("πŸ“¦ Create Backup") + restore_btn = gr.Button("πŸ“‚ Restore from Backup") + + backup_status = gr.Textbox( + label="Backup Status", + value="No backup operations performed", + interactive=False + ) + + backup_file = gr.File( + label="Backup File", + file_types=[".json"], + type="filepath" + ) + + def create_backup(): + try: + if not CONFIG_AVAILABLE: + return "❌ Configuration not available for backup" + + config_data = {} + config_attrs = [attr for attr in dir(CONFIG_MODULE) if not attr.startswith('_')] + + for attr in config_attrs: + value = getattr(CONFIG_MODULE, attr) + if isinstance(value, (int, float, str, bool)): + config_data[attr] = value + elif isinstance(value, Path): + config_data[attr] = str(value) + + backup_path = Path("config_backup.json") + with open(backup_path, 'w') as f: + json.dump(config_data, f, indent=2) + + return f"βœ… Configuration backed up to {backup_path}" + + except Exception as e: + return f"❌ Backup failed: {str(e)}" + + def restore_backup(file_path): + try: + if not file_path: + return "❌ No backup file selected" + + with open(file_path, 'r') as f: + backup_data = json.load(f) + + restored_count = 0 + for key, value in backup_data.items(): + if CONFIG_MODULE and hasattr(CONFIG_MODULE, key): + setattr(CONFIG_MODULE, key, value) + restored_count += 1 + + return f"βœ… Restored {restored_count} configuration values from backup" + + except Exception as e: + return f"❌ Restore failed: {str(e)}" + + backup_btn.click( + fn=create_backup, + outputs=[backup_status] + ) + + restore_btn.click( + fn=restore_backup, + inputs=[backup_file], + outputs=[backup_status] + ) + +def create_system_info(): + """Create system information display.""" + + with gr.Column(): + gr.Markdown("## πŸ“Š System Information") + + # Get system info + def get_system_info(): + info = [] + + # Config availability + info.append(f"**Configuration Module**: {'βœ… Available' if CONFIG_AVAILABLE else '❌ Not Available'}") + + # Python version + info.append(f"**Python Version**: {sys.version.split()[0]}") + + # Working directory + info.append(f"**Working Directory**: {os.getcwd()}") + + # Config file path + config_path = Path("config/config.py") + info.append(f"**Config File**: {'βœ… Exists' if config_path.exists() else '❌ Missing'} ({config_path})") + + # Config count + if CONFIG_AVAILABLE: + config_count = len([attr for attr in dir(CONFIG_MODULE) if not attr.startswith('_')]) + info.append(f"**Configuration Items**: {config_count}") + + return "\n\n".join(info) + + system_info = gr.Markdown(get_system_info()) + + refresh_btn = gr.Button("πŸ”„ Refresh Info") + refresh_btn.click( + fn=get_system_info, + outputs=[system_info] + ) + +def create_settings_tab(): + """Create the main settings tab interface.""" + + if not CONFIG_AVAILABLE: + # Show error state + with gr.Column(): + gr.Markdown("# ⚠️ Configuration Not Available") + gr.Markdown(""" + The configuration module could not be loaded. This may be due to: + - Missing config/config.py file + - Import errors in the configuration + - Path issues + + Please check your installation and try again. + """) + return + + # Initialize config manager + config_manager = ConfigManager() + + with gr.Column(): + gr.Markdown("# βš™οΈ Settings & Configuration") + gr.Markdown("*Manage ChatterboxTTS configuration and system settings*") + + with gr.Tabs(): + # Configuration Editor + with gr.Tab("πŸ”§ Configuration"): + create_config_editor(config_manager) + + # Backup & Restore + with gr.Tab("πŸ’Ύ Backup"): + create_config_backup() + + # System Information + with gr.Tab("πŸ“Š System Info"): + create_system_info() + +# Export the main function +def create_settings_tab_interface(): + """Main entry point for the settings tab.""" + return create_settings_tab() \ No newline at end of file diff --git a/gradio_tabs/tab7_chunk_tools.py b/gradio_tabs/tab7_chunk_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..e82a6ce9f99d17075c24cef54a32bc1a16f4661a --- /dev/null +++ b/gradio_tabs/tab7_chunk_tools.py @@ -0,0 +1,760 @@ +#!/usr/bin/env python3 +""" +Gradio Tab 7: Chunk Tools +Interactive chunk editing, search, and audio regeneration - matches PyQt5 GUI Tab 7 functionality +""" + +import gradio as gr +import os +import sys +import threading +import time +import json +from pathlib import Path +from typing import List, Dict, Any, Optional, Tuple + +# Import backend functionality +try: + from wrapper.chunk_loader import load_chunks, save_chunks + from wrapper.chunk_search import search_chunks + from wrapper.chunk_editor import update_chunk + from wrapper.chunk_player import play_chunk_audio + from wrapper.chunk_synthesizer import synthesize_chunk + from wrapper.chunk_revisions import accept_revision + from config.config import AUDIOBOOK_ROOT + from modules.voice_detector import get_likely_voices_for_book + CHUNK_TOOLS_AVAILABLE = True + print("βœ… Chunk tools functionality available") +except ImportError as e: + print(f"⚠️ Chunk tools functionality not available: {e}") + CHUNK_TOOLS_AVAILABLE = False + +# Global state for chunk operations +chunk_state = { + 'loaded_chunks': None, + 'current_chunk': None, + 'book_path': None, + 'audio_dir': None, + 'search_results': [], + 'voice_candidates': [], + 'selected_voice': None, + 'operation_running': False +} + +def get_available_repair_books(): + """Get list of books available for chunk repair/editing""" + books = [] + + if not CHUNK_TOOLS_AVAILABLE: + return books + + # Check TTS processing directories first + audiobook_root = Path(AUDIOBOOK_ROOT) + if audiobook_root.exists(): + for book_dir in audiobook_root.iterdir(): + if book_dir.is_dir(): + tts_chunks_dir = book_dir / "TTS" / "text_chunks" + json_path = tts_chunks_dir / "chunks_info.json" + if json_path.exists(): + chunk_count = 0 + try: + with open(json_path, 'r') as f: + chunks_data = json.load(f) + chunk_count = len(chunks_data) + except: + pass + + books.append({ + 'name': book_dir.name, + 'path': str(json_path), + 'source': 'TTS', + 'chunk_count': chunk_count, + 'display': f"{book_dir.name} (TTS: {chunk_count} chunks)" + }) + + # Check Text_Input directory for fallback + text_input_dir = Path("Text_Input") + if text_input_dir.exists(): + for chunk_file in text_input_dir.glob("*_chunks.json"): + book_name = chunk_file.stem.replace("_chunks", "") + # Only add if not already found in TTS directories + if not any(book['name'] == book_name for book in books): + chunk_count = 0 + try: + with open(chunk_file, 'r') as f: + chunks_data = json.load(f) + chunk_count = len(chunks_data) + except: + pass + + books.append({ + 'name': book_name, + 'path': str(chunk_file), + 'source': 'Text_Input', + 'chunk_count': chunk_count, + 'display': f"{book_name} (Text_Input: {chunk_count} chunks)" + }) + + return sorted(books, key=lambda x: x['name']) + +def load_book_chunks(book_selection): + """Load chunks for selected book""" + if not book_selection or book_selection == "-- Select a Book --": + chunk_state['loaded_chunks'] = None + chunk_state['current_chunk'] = None + chunk_state['book_path'] = None + chunk_state['audio_dir'] = None + chunk_state['voice_candidates'] = [] + return ( + "No book selected", + "", # Clear search + "No chunks loaded", + "", # Clear chunk text + "none", # Reset boundary + 0.5, 0.5, 0.8, # Reset TTS params + "No chunk selected", + "No voice detected", + ["-- Please Select Voice --"] + ) + + try: + # Find the book data + books = get_available_repair_books() + selected_book = None + for book in books: + if book['display'] == book_selection: + selected_book = book + break + + if not selected_book: + return ( + "❌ Selected book not found", + "", "No chunks loaded", "", "none", 0.5, 0.5, 0.8, + "No chunk selected", "No voice detected", ["-- Please Select Voice --"] + ) + + # Load chunks + chunks = load_chunks(selected_book['path']) + + # Ensure chunks have index fields + for i, chunk in enumerate(chunks): + if 'index' not in chunk: + chunk['index'] = i + + chunk_state['loaded_chunks'] = chunks + chunk_state['book_path'] = Path(selected_book['path']) + chunk_state['current_chunk'] = None + chunk_state['search_results'] = [] + + # Determine audio directory + audiobook_root = Path(AUDIOBOOK_ROOT) + chunk_state['audio_dir'] = audiobook_root / selected_book['name'] / "TTS" / "audio_chunks" + + # Detect voice candidates + try: + likely_voices = get_likely_voices_for_book(selected_book['name'], chunk_state['book_path']) + chunk_state['voice_candidates'] = likely_voices + + voice_choices = ["-- Please Select Voice --"] + voice_info = "" + + if likely_voices: + for voice_name, voice_path, detection_method in likely_voices: + voice_choices.append(f"{voice_name} ({detection_method})") + voice_info = f"βœ… Found {len(likely_voices)} voice candidate(s). Please select voice before resynthesizing." + else: + voice_info = "❌ No voice candidates detected. Check JSON metadata or run.log." + + except Exception as e: + voice_choices = ["-- Please Select Voice --"] + voice_info = f"❌ Error detecting voices: {str(e)}" + chunk_state['voice_candidates'] = [] + + return ( + f"βœ… Loaded {len(chunks)} chunks from {selected_book['name']}", + "", # Clear search + "Chunks loaded successfully - use search to find specific chunks", + "", # Clear chunk text + "none", # Reset boundary + 0.5, 0.5, 0.8, # Reset TTS params + "No chunk selected - search and select a chunk to edit", + voice_info, + voice_choices + ) + + except Exception as e: + return ( + f"❌ Error loading chunks: {str(e)}", + "", "Failed to load chunks", "", "none", 0.5, 0.5, 0.8, + "Error loading chunks", "No voice detected", ["-- Please Select Voice --"] + ) + +def search_for_chunks(search_query): + """Search for chunks containing the query text""" + if not chunk_state['loaded_chunks']: + return "❌ No chunks loaded - please select a book first", "" + + if not search_query.strip(): + return "❌ Please enter text to search for", "" + + try: + results = search_chunks(chunk_state['loaded_chunks'], search_query.strip()) + chunk_state['search_results'] = results + + if results: + # Format results for display + results_text = f"**Found {len(results)} matching chunks:**\n\n" + for chunk in results: + text_preview = chunk['text'][:80] + "..." if len(chunk['text']) > 80 else chunk['text'] + results_text += f"**[{chunk['index']}]** {text_preview}\n\n" + + # Create dropdown choices + choices = [] + for chunk in results: + text_preview = chunk['text'][:60] + "..." if len(chunk['text']) > 60 else chunk['text'] + choices.append(f"[{chunk['index']}] {text_preview}") + + return results_text, gr.update(choices=choices, value=None, visible=True) + else: + return "No matching chunks found", gr.update(choices=[], visible=False) + + except Exception as e: + return f"❌ Error searching chunks: {str(e)}", gr.update(choices=[], visible=False) + +def select_chunk_for_editing(chunk_selection): + """Select a chunk for editing from search results""" + if not chunk_selection or not chunk_state['search_results']: + return ( + "", "none", 0.5, 0.5, 0.8, + "No chunk selected" + ) + + try: + # Find the selected chunk by parsing the selection string + chunk_index_str = chunk_selection.split(']')[0][1:] # Extract index from "[123] text..." + chunk_index = int(chunk_index_str) + + # Find the chunk with this index + selected_chunk = None + for chunk in chunk_state['search_results']: + if chunk['index'] == chunk_index: + selected_chunk = chunk + break + + if not selected_chunk: + return ( + "", "none", 0.5, 0.5, 0.8, + "❌ Selected chunk not found" + ) + + chunk_state['current_chunk'] = selected_chunk + + # Extract chunk data + text = selected_chunk.get('text', '') + boundary_type = selected_chunk.get('boundary_type', 'none') + + # Extract TTS parameters + tts_params = selected_chunk.get('tts_params', {}) + exaggeration = tts_params.get('exaggeration', 0.5) + cfg_weight = tts_params.get('cfg_weight', 0.5) + temperature = tts_params.get('temperature', 0.8) + + # Create info display + sentiment = selected_chunk.get('sentiment_compound', selected_chunk.get('sentiment_score', 'N/A')) + word_count = selected_chunk.get('word_count', 'N/A') + + info_text = f"""**Selected Chunk {chunk_index}** +**Boundary:** {boundary_type} | **Words:** {word_count} | **Sentiment:** {sentiment} +**TTS Params:** exag={exaggeration}, cfg={cfg_weight}, temp={temperature} +**Audio File:** chunk_{chunk_index+1:05d}.wav""" + + return ( + text, + boundary_type, + exaggeration, + cfg_weight, + temperature, + info_text + ) + + except Exception as e: + return ( + "", "none", 0.5, 0.5, 0.8, + f"❌ Error selecting chunk: {str(e)}" + ) + +def save_chunk_changes(chunk_text, boundary_type, exaggeration, cfg_weight, temperature): + """Save changes to the current chunk""" + if not chunk_state['current_chunk']: + return "❌ No chunk selected" + + try: + # Update chunk data + chunk_state['current_chunk']['text'] = chunk_text.strip() + chunk_state['current_chunk']['boundary_type'] = boundary_type + chunk_state['current_chunk']['tts_params'] = { + 'exaggeration': exaggeration, + 'cfg_weight': cfg_weight, + 'temperature': temperature + } + + # Update word count + word_count = len(chunk_text.strip().split()) + chunk_state['current_chunk']['word_count'] = word_count + + # Save to file + if chunk_state['book_path']: + save_chunks(chunk_state['loaded_chunks'], str(chunk_state['book_path'])) + return f"βœ… Chunk {chunk_state['current_chunk']['index']} saved successfully!" + else: + return "❌ No book path available for saving" + + except Exception as e: + return f"❌ Error saving chunk: {str(e)}" + +def play_original_audio(): + """Play the original audio for the current chunk""" + if not chunk_state['current_chunk'] or not chunk_state['audio_dir']: + return "❌ No chunk selected or audio directory not found" + + try: + chunk_index = chunk_state['current_chunk']['index'] + audio_file = chunk_state['audio_dir'] / f"chunk_{chunk_index+1:05d}.wav" + + if not audio_file.exists(): + return f"❌ Audio file not found: {audio_file.name}" + + # Play audio in background thread to avoid blocking UI + def play_audio(): + try: + play_chunk_audio(str(audio_file)) + except Exception as e: + print(f"Error playing audio: {e}") + + threading.Thread(target=play_audio, daemon=True).start() + return f"πŸ”Š Playing original audio: {audio_file.name}" + + except Exception as e: + return f"❌ Error playing audio: {str(e)}" + +def resynthesize_chunk_audio(voice_selection, chunk_text, boundary_type, exaggeration, cfg_weight, temperature): + """Regenerate audio for the current chunk with new parameters""" + if not chunk_state['current_chunk']: + return "❌ No chunk selected" + + if not voice_selection or voice_selection == "-- Please Select Voice --": + return "❌ Please select a voice before resynthesizing" + + if chunk_state['operation_running']: + return "⚠️ Another operation is already running" + + try: + # Find selected voice info + selected_voice_data = None + for voice_name, voice_path, detection_method in chunk_state['voice_candidates']: + if f"{voice_name} ({detection_method})" == voice_selection: + selected_voice_data = (voice_name, voice_path, detection_method) + break + + if not selected_voice_data: + return "❌ Selected voice not found in candidates" + + voice_name, voice_path, detection_method = selected_voice_data + + # Update chunk with current parameters first + chunk_state['current_chunk']['text'] = chunk_text.strip() + chunk_state['current_chunk']['boundary_type'] = boundary_type + chunk_state['current_chunk']['tts_params'] = { + 'exaggeration': exaggeration, + 'cfg_weight': cfg_weight, + 'temperature': temperature + } + + # Start resynthesis in background + def resynth_worker(): + chunk_state['operation_running'] = True + try: + chunk_index = chunk_state['current_chunk']['index'] + result = synthesize_chunk( + chunk_state['current_chunk'], + voice_path, + str(chunk_state['audio_dir']), + chunk_index + ) + chunk_state['operation_running'] = False + return result + except Exception as e: + chunk_state['operation_running'] = False + print(f"Error in resynthesis: {e}") + return False + + # Run in thread + threading.Thread(target=resynth_worker, daemon=True).start() + + return f"🎀 Starting resynthesis with voice '{voice_name}'...\n⏳ This may take a few moments." + + except Exception as e: + chunk_state['operation_running'] = False + return f"❌ Error resynthesizing chunk: {str(e)}" + +def play_revised_audio(): + """Play the revised audio for the current chunk""" + if not chunk_state['current_chunk'] or not chunk_state['audio_dir']: + return "❌ No chunk selected or audio directory not found" + + try: + chunk_index = chunk_state['current_chunk']['index'] + # Look for revised audio file (typically has _revised suffix or similar) + revised_file = chunk_state['audio_dir'] / f"chunk_{chunk_index+1:05d}_revised.wav" + if not revised_file.exists(): + # Fallback to regular file if revised doesn't exist + revised_file = chunk_state['audio_dir'] / f"chunk_{chunk_index+1:05d}.wav" + + if not revised_file.exists(): + return f"❌ Revised audio file not found: {revised_file.name}" + + def play_audio(): + try: + play_chunk_audio(str(revised_file)) + except Exception as e: + print(f"Error playing revised audio: {e}") + + threading.Thread(target=play_audio, daemon=True).start() + return f"πŸ”Š Playing revised audio: {revised_file.name}" + + except Exception as e: + return f"❌ Error playing revised audio: {str(e)}" + +def accept_chunk_revision(): + """Accept the current chunk revision""" + if not chunk_state['current_chunk']: + return "❌ No chunk selected" + + try: + chunk_index = chunk_state['current_chunk']['index'] + result = accept_revision(chunk_index, str(chunk_state['audio_dir'])) + + if result: + return f"βœ… Revision accepted for chunk {chunk_index}" + else: + return f"❌ Failed to accept revision for chunk {chunk_index}" + + except Exception as e: + return f"❌ Error accepting revision: {str(e)}" + +def create_chunk_tools_tab(): + """Create Tab 7: Chunk Tools with all GUI functionality""" + + with gr.Column(): + gr.Markdown("# πŸ”§ Chunk Repair and Editing Tool") + gr.Markdown("*Interactive chunk editing, search, and audio regeneration - matches GUI Tab 7*") + + if not CHUNK_TOOLS_AVAILABLE: + gr.Markdown("### ❌ Chunk Tools Not Available") + gr.Markdown("Missing required backend modules. Please ensure all wrapper modules are installed.") + return {} + + # Book Selection Section + with gr.Row(): + with gr.Column(scale=2): + gr.Markdown("### πŸ“š Book Selection") + + available_books = get_available_repair_books() + book_choices = ["-- Select a Book --"] + [book['display'] for book in available_books] + + book_selector = gr.Dropdown( + label="Select Book for Chunk Editing", + choices=book_choices, + value="-- Select a Book --", + interactive=True, + info="Books with processed chunks available for editing" + ) + + refresh_books_btn = gr.Button( + "πŸ”„ Refresh Book List", + variant="secondary", + size="sm" + ) + + load_status = gr.Textbox( + label="Load Status", + value="No book selected", + interactive=False, + lines=2 + ) + + with gr.Column(scale=1): + gr.Markdown("### 🎀 Voice Selection") + + voice_info_display = gr.Markdown( + "No voice detected", + label="Voice Detection Status" + ) + + voice_selector = gr.Dropdown( + label="Select Voice for Resynthesis", + choices=["-- Please Select Voice --"], + value="-- Please Select Voice --", + interactive=True, + info="Detected voice candidates for this book" + ) + + refresh_voices_btn = gr.Button( + "πŸ”„ Re-detect Voice Candidates", + variant="secondary", + size="sm" + ) + + # Search and Selection Section + with gr.Row(): + with gr.Column(): + gr.Markdown("### πŸ” Search and Select Chunks") + + search_input = gr.Textbox( + label="Search for Text Fragment", + placeholder="Enter text to search for in chunks...", + interactive=True, + info="Search through chunk text content" + ) + + search_btn = gr.Button( + "πŸ” Search Chunks", + variant="primary", + size="lg" + ) + + search_results_display = gr.Markdown( + "No search performed yet", + label="Search Results" + ) + + chunk_selector = gr.Dropdown( + label="Select Chunk to Edit", + choices=[], + value=None, + interactive=True, + visible=False, + info="Choose chunk from search results" + ) + + # Chunk Editor Section + with gr.Column(): + gr.Markdown("### ✏️ Edit Selected Chunk") + + chunk_info_display = gr.Markdown( + "No chunk selected", + label="Chunk Information" + ) + + # Text editing + with gr.Row(): + chunk_text_editor = gr.Textbox( + label="Chunk Text", + placeholder="Select a chunk to edit its text...", + interactive=True, + lines=4, + info="Edit the text content of the selected chunk" + ) + + # Metadata and TTS Parameters + with gr.Row(): + boundary_selector = gr.Dropdown( + label="Boundary Type", + choices=[ + "none", "paragraph_end", "chapter_start", "chapter_end", "section_break", + "period", "comma", "semicolon", "colon", "question_mark", "exclamation", + "dash", "ellipsis", "quote_end" + ], + value="none", + interactive=True, + info="Chunk boundary classification" + ) + + exag_param = gr.Slider( + label="TTS Exaggeration", + minimum=0.0, maximum=3.0, step=0.1, + value=0.5, + interactive=True, + info="Speech exaggeration level" + ) + + cfg_param = gr.Slider( + label="TTS CFG Weight", + minimum=0.0, maximum=2.0, step=0.1, + value=0.5, + interactive=True, + info="CFG guidance strength" + ) + + temp_param = gr.Slider( + label="TTS Temperature", + minimum=0.0, maximum=2.0, step=0.1, + value=0.8, + interactive=True, + info="TTS randomness/creativity" + ) + + # Action Buttons + with gr.Row(): + play_original_btn = gr.Button( + "πŸ”Š Play Original", + variant="secondary", + size="lg", + interactive=True + ) + + save_changes_btn = gr.Button( + "πŸ’Ύ Save Changes", + variant="primary", + size="lg", + interactive=True + ) + + resynthesize_btn = gr.Button( + "🎀 Resynthesize", + variant="primary", + size="lg", + interactive=True + ) + + play_revised_btn = gr.Button( + "πŸ”Š Play Revised", + variant="secondary", + size="lg", + interactive=True + ) + + accept_revision_btn = gr.Button( + "βœ… Accept Revision", + variant="primary", + size="lg", + interactive=True + ) + + # Operation Status + with gr.Row(): + operation_status = gr.Textbox( + label="Operation Status", + value="Ready for chunk editing operations", + interactive=False, + lines=2 + ) + + # Event Handlers + def refresh_book_list(): + """Refresh the available books list""" + books = get_available_repair_books() + choices = ["-- Select a Book --"] + [book['display'] for book in books] + return gr.update(choices=choices, value="-- Select a Book --") + + def refresh_voice_candidates(): + """Refresh voice candidates for current book""" + if chunk_state['book_path']: + # Re-run voice detection + try: + book_name = chunk_state['book_path'].parent.parent.name if chunk_state['book_path'] else "" + likely_voices = get_likely_voices_for_book(book_name, chunk_state['book_path']) + chunk_state['voice_candidates'] = likely_voices + + voice_choices = ["-- Please Select Voice --"] + voice_info = "" + + if likely_voices: + for voice_name, voice_path, detection_method in likely_voices: + voice_choices.append(f"{voice_name} ({detection_method})") + voice_info = f"βœ… Found {len(likely_voices)} voice candidate(s). Please select voice before resynthesizing." + else: + voice_info = "❌ No voice candidates detected. Check JSON metadata or run.log." + + return voice_info, gr.update(choices=voice_choices) + except Exception as e: + return f"❌ Error refreshing voices: {str(e)}", gr.update(choices=["-- Please Select Voice --"]) + else: + return "No book selected - cannot refresh voice candidates", gr.update(choices=["-- Please Select Voice --"]) + + # Connect event handlers + refresh_books_btn.click( + refresh_book_list, + inputs=[], + outputs=[book_selector] + ) + + book_selector.change( + load_book_chunks, + inputs=[book_selector], + outputs=[ + load_status, search_input, search_results_display, chunk_text_editor, + boundary_selector, exag_param, cfg_param, temp_param, + chunk_info_display, voice_info_display, voice_selector + ] + ) + + refresh_voices_btn.click( + refresh_voice_candidates, + inputs=[], + outputs=[voice_info_display, voice_selector] + ) + + search_btn.click( + search_for_chunks, + inputs=[search_input], + outputs=[search_results_display, chunk_selector] + ) + + search_input.submit( + search_for_chunks, + inputs=[search_input], + outputs=[search_results_display, chunk_selector] + ) + + chunk_selector.change( + select_chunk_for_editing, + inputs=[chunk_selector], + outputs=[ + chunk_text_editor, boundary_selector, exag_param, cfg_param, + temp_param, chunk_info_display + ] + ) + + save_changes_btn.click( + save_chunk_changes, + inputs=[chunk_text_editor, boundary_selector, exag_param, cfg_param, temp_param], + outputs=[operation_status] + ) + + play_original_btn.click( + play_original_audio, + inputs=[], + outputs=[operation_status] + ) + + resynthesize_btn.click( + resynthesize_chunk_audio, + inputs=[voice_selector, chunk_text_editor, boundary_selector, exag_param, cfg_param, temp_param], + outputs=[operation_status] + ) + + play_revised_btn.click( + play_revised_audio, + inputs=[], + outputs=[operation_status] + ) + + accept_revision_btn.click( + accept_chunk_revision, + inputs=[], + outputs=[operation_status] + ) + + return { + 'book_selector': book_selector, + 'search_button': search_btn, + 'operation_status': operation_status + } + +if __name__ == "__main__": + # Test the tab + with gr.Blocks() as demo: + create_chunk_tools_tab() + + demo.launch() \ No newline at end of file diff --git a/gradio_tabs/tab8_json_generate.py b/gradio_tabs/tab8_json_generate.py new file mode 100644 index 0000000000000000000000000000000000000000..b7e470a381e7f47baa8727f29a7dbe76d1dcbae0 --- /dev/null +++ b/gradio_tabs/tab8_json_generate.py @@ -0,0 +1,561 @@ +#!/usr/bin/env python3 +""" +Gradio Tab 8: JSON Generate +Generate audiobooks directly from JSON files with voice selection - matches PyQt5 GUI Tab 8 functionality +""" + +import gradio as gr +import os +import sys +import threading +import time +import json +import subprocess +from pathlib import Path +from typing import List, Dict, Any, Optional, Tuple + +# Import backend functionality +try: + from modules.gui_json_generator import generate_audiobook_from_json + from modules.file_manager import list_voice_samples + from config.config import AUDIOBOOK_ROOT + JSON_GENERATE_AVAILABLE = True + print("βœ… JSON generation functionality available") +except ImportError as e: + print(f"⚠️ JSON generation functionality not available: {e}") + JSON_GENERATE_AVAILABLE = False + +# Global state for JSON generation +json_state = { + 'generation_running': False, + 'current_json_file': None, + 'current_voice': None, + 'generated_audiobook': None, + 'audio_process': None, + 'audio_position': 0, + 'audio_duration': 0, + 'progress': 0, + 'status': 'Ready' +} + +def get_available_json_files(): + """Find available JSON chunk files for generation""" + json_files = [] + + if not JSON_GENERATE_AVAILABLE: + return json_files + + # Look in TTS processing directories + audiobook_root = Path(AUDIOBOOK_ROOT) + if audiobook_root.exists(): + for book_dir in audiobook_root.iterdir(): + if book_dir.is_dir(): + tts_chunks_dir = book_dir / "TTS" / "text_chunks" + json_path = tts_chunks_dir / "chunks_info.json" + if json_path.exists(): + try: + with open(json_path, 'r') as f: + chunks_data = json.load(f) + chunk_count = len(chunks_data) + + json_files.append({ + 'name': book_dir.name, + 'path': str(json_path), + 'chunk_count': chunk_count, + 'display': f"{book_dir.name} ({chunk_count} chunks)", + 'type': 'TTS' + }) + except: + pass + + # Look in Text_Input directory + text_input_dir = Path("Text_Input") + if text_input_dir.exists(): + for json_file in text_input_dir.glob("*_chunks.json"): + book_name = json_file.stem.replace("_chunks", "") + # Only add if not already found in TTS directories + if not any(jf['name'] == book_name for jf in json_files): + try: + with open(json_file, 'r') as f: + chunks_data = json.load(f) + chunk_count = len(chunks_data) + + json_files.append({ + 'name': book_name, + 'path': str(json_file), + 'chunk_count': chunk_count, + 'display': f"{book_name} ({chunk_count} chunks)", + 'type': 'Text_Input' + }) + except: + pass + + return sorted(json_files, key=lambda x: x['name']) + +def get_available_voices(): + """Get list of available voice samples""" + voices = [] + + if not JSON_GENERATE_AVAILABLE: + return voices + + try: + voice_files = list_voice_samples() + for voice_file in voice_files: + voices.append({ + 'name': voice_file.stem, + 'path': str(voice_file), + 'display': f"{voice_file.stem} ({voice_file.name})" + }) + except Exception as e: + print(f"Error getting voices: {e}") + + return sorted(voices, key=lambda x: x['name']) + +def load_json_file_info(file_selection): + """Load information about selected JSON file""" + if not file_selection or file_selection == "-- Select JSON File --": + return "No JSON file selected", "No file loaded" + + try: + # Find the selected file + json_files = get_available_json_files() + selected_file = None + for jf in json_files: + if jf['display'] == file_selection: + selected_file = jf + break + + if not selected_file: + return "❌ Selected file not found", "Error" + + json_state['current_json_file'] = selected_file + + # Load and analyze JSON + with open(selected_file['path'], 'r') as f: + chunks_data = json.load(f) + + chunk_count = len(chunks_data) + + # Calculate estimated metrics + total_words = sum(chunk.get('word_count', len(chunk.get('text', '').split())) for chunk in chunks_data) + estimated_duration_seconds = total_words * 0.4 # Rough estimate: 0.4 seconds per word + estimated_duration = f"{int(estimated_duration_seconds // 3600):02d}:{int((estimated_duration_seconds % 3600) // 60):02d}:{int(estimated_duration_seconds % 60):02d}" + + # Check for existing audio chunks + book_name = selected_file['name'] + audio_chunks_dir = Path(AUDIOBOOK_ROOT) / book_name / "TTS" / "audio_chunks" + existing_chunks = 0 + if audio_chunks_dir.exists(): + existing_chunks = len(list(audio_chunks_dir.glob("chunk_*.wav"))) + + info = f"""**πŸ“„ JSON File Analysis:** +**File:** {selected_file['path']} +**Book:** {book_name} +**Source:** {selected_file['type']} + +**Content:** +β€’ Total Chunks: {chunk_count} +β€’ Total Words: {total_words:,} +β€’ Estimated Duration: {estimated_duration} + +**Status:** +β€’ Existing Audio: {existing_chunks}/{chunk_count} chunks +β€’ Ready for Generation: {'βœ… Yes' if chunk_count > 0 else '❌ No chunks found'}""" + + current_file = f"πŸ“ Selected: {book_name} ({chunk_count} chunks)" + + return info, current_file + + except Exception as e: + return f"❌ Error loading JSON file: {str(e)}", "Error loading file" + +def start_json_generation(json_selection, voice_selection, temperature_override): + """Start JSON-to-audiobook generation""" + if json_state['generation_running']: + return "⚠️ Generation already in progress", 0, "Generation running...", "Processing..." + + if not json_selection or json_selection == "-- Select JSON File --": + return "❌ Please select a JSON file", 0, "No file selected", "Ready" + + if not voice_selection or voice_selection == "-- Select Voice --": + return "❌ Please select a voice", 0, "No voice selected", "Ready" + + try: + # Find selected files + json_files = get_available_json_files() + voices = get_available_voices() + + selected_json = None + for jf in json_files: + if jf['display'] == json_selection: + selected_json = jf + break + + selected_voice = None + for v in voices: + if v['display'] == voice_selection: + selected_voice = v + break + + if not selected_json or not selected_voice: + return "❌ Invalid selection", 0, "Selection error", "Ready" + + json_state['current_json_file'] = selected_json + json_state['current_voice'] = selected_voice + json_state['generation_running'] = True + json_state['progress'] = 0 + json_state['status'] = 'Starting generation...' + + # Start generation in background thread + def generation_worker(): + try: + json_state['status'] = '🎡 Generating audiobook from JSON...' + json_state['progress'] = 10 + + # Apply temperature override if specified + temp_setting = None + if temperature_override and temperature_override > 0: + temp_setting = temperature_override + + # Run the generation + success, message, audiobook_path = generate_audiobook_from_json( + selected_json['path'], + selected_voice['name'], + temp_setting + ) + + if success: + json_state['progress'] = 100 + json_state['status'] = 'βœ… Generation completed successfully!' + json_state['generated_audiobook'] = audiobook_path + else: + json_state['progress'] = 0 + json_state['status'] = f'❌ Generation failed: {message}' + json_state['generated_audiobook'] = None + + except Exception as e: + json_state['progress'] = 0 + json_state['status'] = f'❌ Generation error: {str(e)}' + json_state['generated_audiobook'] = None + finally: + json_state['generation_running'] = False + + # Start worker thread + threading.Thread(target=generation_worker, daemon=True).start() + + return ( + "🎡 Starting JSON audiobook generation...", + 10, + f"Generating: {selected_json['name']} with voice: {selected_voice['name']}", + "Generation started" + ) + + except Exception as e: + json_state['generation_running'] = False + return f"❌ Error starting generation: {str(e)}", 0, "Generation failed", "Error" + +def get_generation_status(): + """Get current generation status""" + return ( + json_state.get('status', 'Ready'), + json_state.get('progress', 0), + json_state.get('generated_audiobook', 'No audiobook generated') or 'No audiobook generated', + "Generation running..." if json_state['generation_running'] else "Ready" + ) + +def stop_json_generation(): + """Stop current generation (if possible)""" + if json_state['generation_running']: + json_state['generation_running'] = False + json_state['status'] = '⏹️ Generation stopped by user' + json_state['progress'] = 0 + return "⏹️ Generation stopped", 0, "Generation stopped", "Ready" + else: + return "No generation to stop", json_state.get('progress', 0), json_state.get('status', 'Ready'), "Ready" + +# Audio playback functions (simplified - web browsers handle audio playback) +def play_audio(): + """Play generated audiobook""" + if not json_state.get('generated_audiobook'): + return "❌ No audiobook generated to play" + + try: + audiobook_path = json_state['generated_audiobook'] + if isinstance(audiobook_path, str): + audiobook_path = Path(audiobook_path) + + if not audiobook_path.exists(): + return f"❌ Audio file not found: {audiobook_path}" + + # For web interface, we can't directly control audio playback + # User would need to download and play manually + return f"πŸ”Š Audio file ready for playback: {audiobook_path.name}" + + except Exception as e: + return f"❌ Error accessing audio: {str(e)}" + +def create_json_generate_tab(): + """Create Tab 8: JSON Generate with all GUI functionality""" + + with gr.Column(): + gr.Markdown("# πŸ“„ Generate Audiobook from JSON") + gr.Markdown("*Direct audiobook generation from preprocessed JSON files - matches GUI Tab 8*") + + if not JSON_GENERATE_AVAILABLE: + gr.Markdown("### ❌ JSON Generation Not Available") + gr.Markdown("Missing required backend modules. Please ensure modules/gui_json_generator.py is available.") + return {} + + # JSON File Selection Section + with gr.Row(): + with gr.Column(scale=2): + gr.Markdown("### πŸ“„ JSON File Selection") + + json_files = get_available_json_files() + json_choices = ["-- Select JSON File --"] + [jf['display'] for jf in json_files] + + json_file_selector = gr.Dropdown( + label="JSON Chunks File", + choices=json_choices, + value="-- Select JSON File --", + interactive=True, + info="Select preprocessed JSON file containing text chunks" + ) + + # Manual path input + json_manual_path = gr.Textbox( + label="Or Enter JSON Path Manually", + placeholder="e.g., /path/to/book_chunks.json", + interactive=True, + info="Direct path to JSON chunks file" + ) + + refresh_json_btn = gr.Button( + "πŸ”„ Refresh JSON Files", + variant="secondary", + size="sm" + ) + + with gr.Column(scale=1): + json_file_info = gr.Markdown( + "No JSON file selected", + label="File Information" + ) + + # Voice Selection Section + with gr.Row(): + with gr.Column(): + gr.Markdown("### 🎀 Voice Selection") + + voices = get_available_voices() + voice_choices = ["-- Select Voice --"] + [v['display'] for v in voices] + + voice_selector = gr.Dropdown( + label="Voice for Generation", + choices=voice_choices, + value="-- Select Voice --", + interactive=True, + info="Select voice sample for audiobook generation" + ) + + refresh_voices_btn = gr.Button( + "πŸ”„ Refresh Voice List", + variant="secondary", + size="sm" + ) + + # Generation Parameters + with gr.Row(): + with gr.Column(): + gr.Markdown("### βš™οΈ Generation Parameters") + + temperature_override = gr.Slider( + label="Temperature Override (Optional)", + minimum=0.0, maximum=2.0, step=0.1, + value=0.0, + interactive=True, + info="Override TTS temperature (0 = use JSON values)" + ) + + gr.Markdown("*Leave temperature at 0 to use individual chunk TTS parameters from JSON*") + + # Generation Controls + with gr.Row(): + generate_btn = gr.Button( + "🎡 Generate Audiobook from JSON", + variant="primary", + size="lg", + interactive=True + ) + + stop_btn = gr.Button( + "⏹️ Stop Generation", + variant="secondary", + size="lg", + interactive=True + ) + + # Progress and Status + with gr.Row(): + with gr.Column(scale=2): + generation_status = gr.Textbox( + label="Generation Status", + value="Ready for JSON audiobook generation", + interactive=False, + lines=2 + ) + + progress_bar = gr.Slider( + label="Progress %", + minimum=0, maximum=100, step=1, + value=0, + interactive=False, + info="Generation progress" + ) + + with gr.Column(scale=1): + current_file_display = gr.Textbox( + label="Current File", + value="No file selected", + interactive=False + ) + + operation_status = gr.Textbox( + label="Operation Status", + value="Ready", + interactive=False + ) + + # Audio Playback Section + with gr.Column(): + gr.Markdown("### πŸ”Š Generated Audiobook") + + generated_file_info = gr.Textbox( + label="Generated Audiobook", + value="No audiobook generated", + interactive=False, + info="Path to generated audiobook file" + ) + + # Simple playback controls (web-based) + with gr.Row(): + play_info_btn = gr.Button( + "πŸ“ Show File Info", + variant="secondary", + size="lg" + ) + + download_btn = gr.Button( + "⬇️ Download Instructions", + variant="secondary", + size="lg" + ) + + playback_info = gr.Markdown( + "*Generated audiobook files will be saved to the Output/ directory. Download and play using your preferred audio player.*", + visible=True + ) + + # Status refresh button + with gr.Row(): + refresh_status_btn = gr.Button( + "πŸ”„ Refresh Status", + variant="secondary", + size="sm" + ) + + # Event Handlers + def refresh_json_files(): + """Refresh JSON files list""" + json_files = get_available_json_files() + choices = ["-- Select JSON File --"] + [jf['display'] for jf in json_files] + return gr.update(choices=choices, value="-- Select JSON File --") + + def refresh_voice_list(): + """Refresh voice samples list""" + voices = get_available_voices() + choices = ["-- Select Voice --"] + [v['display'] for v in voices] + return gr.update(choices=choices, value="-- Select Voice --") + + def show_download_info(): + """Show download/playback instructions""" + return """## πŸ“ Audiobook File Access + +**Generated audiobooks are saved to:** +- `Output/[BookName]/` directory +- Files: `.wav` (uncompressed) and `.m4b` (audiobook format) + +**To play your audiobook:** +1. Navigate to the Output directory +2. Download the `.m4b` file for best audiobook experience +3. Use any audio player (VLC, iTunes, Audible app, etc.) +4. The `.m4b` format supports chapters and bookmarks + +**File locations:** +- Individual chunks: `Audiobook/[BookName]/TTS/audio_chunks/` +- Combined audiobook: `Output/[BookName]/` +""" + + # Connect event handlers + refresh_json_btn.click( + refresh_json_files, + inputs=[], + outputs=[json_file_selector] + ) + + refresh_voices_btn.click( + refresh_voice_list, + inputs=[], + outputs=[voice_selector] + ) + + json_file_selector.change( + load_json_file_info, + inputs=[json_file_selector], + outputs=[json_file_info, current_file_display] + ) + + generate_btn.click( + start_json_generation, + inputs=[json_file_selector, voice_selector, temperature_override], + outputs=[generation_status, progress_bar, generated_file_info, operation_status] + ) + + stop_btn.click( + stop_json_generation, + inputs=[], + outputs=[generation_status, progress_bar, generated_file_info, operation_status] + ) + + refresh_status_btn.click( + get_generation_status, + inputs=[], + outputs=[generation_status, progress_bar, generated_file_info, operation_status] + ) + + play_info_btn.click( + play_audio, + inputs=[], + outputs=[generated_file_info] + ) + + download_btn.click( + show_download_info, + inputs=[], + outputs=[playback_info] + ) + + return { + 'json_selector': json_file_selector, + 'voice_selector': voice_selector, + 'generate_button': generate_btn, + 'status_display': generation_status + } + +if __name__ == "__main__": + # Test the tab + with gr.Blocks() as demo: + create_json_generate_tab() + + demo.launch() \ No newline at end of file diff --git a/modules/audio_processor.py b/modules/audio_processor.py index 6709d0f1e9aa45c418a96dd2b9e9fe6d8443150c..8163f3c3fd0fe70c02fc384a75030bb64dfc7e4d 100644 --- a/modules/audio_processor.py +++ b/modules/audio_processor.py @@ -1,6 +1,42 @@ """ -Audio Processing Module -Handles audio validation, effects, cleanup, and quality control +ChatterboxTTS Audio Processing & Quality Control Module +====================================================== + +OVERVIEW: +This module provides comprehensive audio quality validation, enhancement, and +post-processing for TTS-generated audio. It ensures consistent quality across +audiobook chapters by detecting and handling common TTS artifacts. + +MAIN COMPONENTS: +1. QUALITY VALIDATION: Detects clipping, silence, flatness, and other artifacts +2. HUM DETECTION: Identifies and flags TTS-generated audio hum using frequency analysis +3. AUDIO ENHANCEMENT: Normalization, trimming, and quality improvements +4. ASR VALIDATION: Optional speech recognition for quality verification +5. SILENCE INSERTION: Adds appropriate pauses based on punctuation boundaries +6. AUDIO HEALTH CHECKS: Comprehensive audio file validation + +CRITICAL QUALITY FEATURES: +- TTS hum detection with configurable frequency thresholds +- Audio clipping detection and prevention +- Silence detection at beginning/end of chunks +- Flatness detection (monotone audio identification) +- ASR-based transcription accuracy validation +- Dynamic range and loudness assessment + +WORKFLOW: +Raw TTS Audio β†’ Quality Validation β†’ Artifact Detection β†’ +Enhancement Processing β†’ Silence Insertion β†’ Final Audio Output + +TECHNICAL DETAILS: +- Supports multiple audio formats (WAV, MP3, FLAC) +- Configurable quality thresholds for different validation types +- Integration with Whisper ASR for transcription validation +- Memory-efficient processing for large audio files +- Detailed logging for quality control debugging + +PERFORMANCE IMPACT: +Essential for maintaining consistent audiobook quality and preventing +distribution of low-quality audio with TTS artifacts or technical issues. """ import numpy as np @@ -13,6 +49,14 @@ from pathlib import Path from pydub import AudioSegment, silence from config.config import * +# Enhanced imports for spectral analysis +try: + import librosa + LIBROSA_AVAILABLE = True +except ImportError: + LIBROSA_AVAILABLE = False + logging.warning("librosa not available - enhanced spectral analysis disabled") + # ============================================================================ # AUDIO QUALITY DETECTION # ============================================================================ @@ -127,6 +171,259 @@ def has_mid_energy_drop(wav_tensor, sr, window_ms=250, threshold_ratio=None): return False +def detect_spectral_artifacts(audio_path_or_segment, use_mfcc=True): + """ + Enhanced spectral anomaly detection using MFCC analysis. + + Args: + audio_path_or_segment: Path to audio file or AudioSegment object + use_mfcc: Whether to use MFCC-based analysis (requires librosa) + + Returns: + float: Quality score (0.0-1.0, higher is better) + """ + try: + # Load audio data + if isinstance(audio_path_or_segment, (str, Path)): + y, sr = sf.read(str(audio_path_or_segment)) + elif isinstance(audio_path_or_segment, AudioSegment): + # Convert AudioSegment to numpy array + samples = np.array(audio_path_or_segment.get_array_of_samples()) + if audio_path_or_segment.channels == 2: + samples = samples.reshape((-1, 2)).mean(axis=1) + y = samples.astype(np.float32) / audio_path_or_segment.max_possible_amplitude + sr = audio_path_or_segment.frame_rate + else: + return 0.5 # Unknown format, neutral score + + # Ensure mono + if len(y.shape) > 1: + y = y[:, 0] + + # Basic energy-based anomaly detection (always available) + energy = np.abs(y) + energy_variance = np.var(energy) + + # Simple threshold-based scoring + basic_score = 1.0 - min(energy_variance / 0.1, 1.0) + + # Enhanced MFCC-based detection if librosa is available + if use_mfcc and LIBROSA_AVAILABLE and ENABLE_MFCC_VALIDATION: + try: + # Compute MFCC features + mfccs = librosa.feature.mfcc(y=y, sr=sr, n_mfcc=13) + + # Calculate spectral variance across time + mfcc_variance = np.var(mfccs, axis=1) + max_variance_jump = np.max(np.abs(np.diff(mfcc_variance))) + + # Normalize and score + mfcc_score = 1.0 - min(max_variance_jump / SPECTRAL_VARIANCE_LIMIT, 1.0) + + # Combine scores (weighted average) + final_score = 0.6 * mfcc_score + 0.4 * basic_score + + except Exception as e: + logging.debug(f"MFCC analysis failed: {e}") + final_score = basic_score + else: + final_score = basic_score + + return max(0.0, min(1.0, final_score)) + + except Exception as e: + logging.error(f"Spectral artifact detection failed: {e}") + return 0.5 # Neutral score on failure + +def evaluate_chunk_quality(audio_path_or_segment, reference_text=None, include_spectral=True, asr_model=None): + """ + Composite quality evaluation for a single audio chunk. + Acts as a clearinghouse - only runs individual checks when they are specifically enabled. + + Args: + audio_path_or_segment: Path to audio file or AudioSegment object + reference_text: Original text for comparison (optional) + include_spectral: Whether to include spectral analysis + asr_model: Pre-loaded ASR model to avoid duplicate loading + + Returns: + float: Composite quality score (0.0-1.0) + """ + # Skip all validation if output validation clearinghouse is disabled + if not ENABLE_OUTPUT_VALIDATION: + return 1.0 # Pass all chunks if validation is completely disabled + + scores = [] + + # Spectral anomaly detection (only if MFCC validation is enabled) + if include_spectral and ENABLE_MFCC_VALIDATION: + spectral_score = detect_spectral_artifacts(audio_path_or_segment) + scores.append(spectral_score) + + # ASR text validation (only if ASR is enabled AND reference text provided) + if reference_text and ENABLE_ASR: + text_validation_score = validate_output_matches_input(audio_path_or_segment, reference_text, asr_model) + scores.append(text_validation_score) + + # Basic audio health (if it's a file path) + if isinstance(audio_path_or_segment, (str, Path)): + try: + health_result = check_audio_health(audio_path_or_segment) + # Convert health result to score (assuming False = good, True = bad) + health_score = 0.2 if health_result else 0.8 + scores.append(health_score) + except Exception: + scores.append(0.5) # Neutral score on failure + + # Return average of all scores + return sum(scores) / len(scores) if scores else 0.5 + +def validate_output_matches_input(audio_path_or_segment, reference_text, asr_model=None): + """ + Validate that TTS audio output matches the input text using ASR transcription. + + Args: + audio_path_or_segment: Path to audio file or AudioSegment object + reference_text: Original input text that should have been synthesized + asr_model: Optional pre-loaded ASR model (will load whisper if None) + + Returns: + float: Validation score (0.0-1.0, higher means better match) + """ + try: + # Convert AudioSegment to temporary file if needed + temp_file = None + if isinstance(audio_path_or_segment, AudioSegment): + import tempfile + temp_file = tempfile.NamedTemporaryFile(suffix='.wav', delete=False) + audio_path_or_segment.export(temp_file.name, format='wav') + audio_path = temp_file.name + else: + audio_path = str(audio_path_or_segment) + + # Load ASR model if not provided + if asr_model is None: + try: + from modules.asr_manager import load_asr_model_adaptive + # Use adaptive manager for fallback ASR loading + asr_model, _ = load_asr_model_adaptive() + if asr_model is None: + logging.warning("ASR model loading failed in audio processor") + return 0.8 # Neutral score if ASR unavailable + except ImportError: + logging.warning("Whisper not available for output validation") + return 0.8 # Neutral score if ASR unavailable + + # Transcribe the audio + result = asr_model.transcribe(audio_path) + transcribed_text = result.get("text", "").strip() + + # Clean up temporary file + if temp_file: + import os + os.unlink(temp_file.name) + + # Calculate text similarity using F1 score + similarity_score = calculate_text_similarity(reference_text, transcribed_text) + + # Log significant mismatches for debugging + if similarity_score < OUTPUT_VALIDATION_THRESHOLD: + logging.warning(f"πŸ” Output validation failed (score: {similarity_score:.3f})") + logging.warning(f" Expected: {reference_text}") + logging.warning(f" Got: {transcribed_text}") + + return similarity_score + + except Exception as e: + logging.error(f"Output validation failed: {e}") + return 0.8 # Use neutral-good score to avoid regeneration on ASR errors + +def calculate_text_similarity(text1, text2): + """ + Calculate similarity between two texts using word-level F1 score. + + Args: + text1: Reference text + text2: Comparison text + + Returns: + float: F1 similarity score (0.0-1.0) + """ + # Normalize texts (lowercase, remove punctuation, split into words) + import re + + def normalize_text(text): + # Convert to lowercase and remove punctuation + text = re.sub(r'[^\w\s]', '', text.lower()) + # Split into words and filter empty strings + return [word for word in text.split() if word] + + words1 = set(normalize_text(text1)) + words2 = set(normalize_text(text2)) + + if not words1 and not words2: + return 1.0 # Both empty + + if not words1 or not words2: + return 0.0 # One empty, one not + + # Calculate precision, recall, and F1 + intersection = words1.intersection(words2) + precision = len(intersection) / len(words2) if words2 else 0 + recall = len(intersection) / len(words1) if words1 else 0 + + if precision + recall == 0: + return 0.0 + + f1_score = 2 * (precision * recall) / (precision + recall) + return f1_score + +def adjust_parameters_for_retry(params, quality_score, attempt_num): + """ + Adjust TTS parameters for regeneration attempts. + + Args: + params: Current TTS parameters dictionary + quality_score: Quality score from previous attempt (0.0-1.0) + attempt_num: Current attempt number (0-based) + + Returns: + dict: Adjusted parameters + """ + adjusted = params.copy() + + # Adjustment strategy based on quality score and attempt number + if quality_score < 0.3: + # Very poor quality - more aggressive adjustments + temp_adj = REGEN_TEMPERATURE_ADJUSTMENT * 2 + exag_adj = REGEN_EXAGGERATION_ADJUSTMENT * 2 + cfg_adj = REGEN_CFG_ADJUSTMENT * 2 + elif quality_score < 0.6: + # Moderate quality issues - standard adjustments + temp_adj = REGEN_TEMPERATURE_ADJUSTMENT + exag_adj = REGEN_EXAGGERATION_ADJUSTMENT + cfg_adj = REGEN_CFG_ADJUSTMENT + else: + # Minor quality issues - gentle adjustments + temp_adj = REGEN_TEMPERATURE_ADJUSTMENT * 0.5 + exag_adj = REGEN_EXAGGERATION_ADJUSTMENT * 0.5 + cfg_adj = REGEN_CFG_ADJUSTMENT * 0.5 + + # Apply adjustments based on attempt number + if attempt_num == 1: + # First retry: reduce temperature (less randomness) + adjusted['temperature'] = max(TTS_PARAM_MIN_TEMPERATURE, + adjusted['temperature'] - temp_adj) + elif attempt_num == 2: + # Second retry: adjust exaggeration (less emotion) + adjusted['exaggeration'] = max(TTS_PARAM_MIN_EXAGGERATION, + adjusted['exaggeration'] - exag_adj) + # Also increase cfg_weight (more faithful to text) + adjusted['cfg_weight'] = min(TTS_PARAM_MAX_CFG_WEIGHT, + adjusted['cfg_weight'] + cfg_adj) + + return adjusted + # ============================================================================ # PROBLEMATIC CHUNK HANDLING # ============================================================================ @@ -327,10 +624,10 @@ def smart_audio_validation_memory(audio_segment, sample_rate): # Basic validation - can be enhanced with hum detection later # For now, just return the audio as-is is_quarantined = False - + # Could add memory-based hum detection here # is_quarantined = detect_hum_memory(audio_segment, sample_rate) - + return audio_segment, is_quarantined def add_contextual_silence_memory(audio_segment, boundary_type): @@ -341,7 +638,7 @@ def add_contextual_silence_memory(audio_segment, boundary_type): SILENCE_COMMA, SILENCE_SEMICOLON, SILENCE_COLON, SILENCE_PERIOD, SILENCE_QUESTION_MARK, SILENCE_EXCLAMATION, SILENCE_DASH, SILENCE_ELLIPSIS, SILENCE_QUOTE_END ) - + silence_durations = { # Structural boundaries "chapter_start": SILENCE_CHAPTER_START, @@ -359,12 +656,12 @@ def add_contextual_silence_memory(audio_segment, boundary_type): "ellipsis": SILENCE_ELLIPSIS, "quote_end": SILENCE_QUOTE_END, } - + if boundary_type in silence_durations: duration = silence_durations[boundary_type] silence_segment = AudioSegment.silent(duration=duration) return audio_segment + silence_segment - + return audio_segment def smart_fade_out(wav_path, silence_thresh_db=-40, min_silence_len=300): @@ -403,12 +700,12 @@ def smart_fade_out(wav_path, silence_thresh_db=-40, min_silence_len=300): def trim_audio_endpoint(audio_segment, threshold=None, buffer_ms=None): """ Trim audio to the detected end of speech using RMS energy analysis. - + Args: audio_segment: pydub AudioSegment object threshold: RMS threshold for speech detection (from config if None) buffer_ms: Buffer to add after detected endpoint (from config if None) - + Returns: Trimmed AudioSegment """ @@ -416,88 +713,88 @@ def trim_audio_endpoint(audio_segment, threshold=None, buffer_ms=None): threshold = SPEECH_ENDPOINT_THRESHOLD if buffer_ms is None: buffer_ms = TRIMMING_BUFFER_MS - + # Convert to numpy array for analysis samples = np.array(audio_segment.get_array_of_samples()) if audio_segment.channels == 2: samples = samples.reshape((-1, 2)).mean(axis=1) - + # Normalize samples samples = samples.astype(np.float32) / audio_segment.max_possible_amplitude - + # Calculate RMS in sliding windows (50ms windows) window_size = int(0.05 * audio_segment.frame_rate) # 50ms rms_values = [] - + for i in range(0, len(samples) - window_size, window_size // 2): window = samples[i:i + window_size] rms = np.sqrt(np.mean(window ** 2)) rms_values.append(rms) - + # Find actual end of speech using energy decay detection speech_end_idx = 0 # Default to beginning if no speech found - + # Look for a significant and sustained drop in energy # Scan backwards to find where energy consistently stays above a higher threshold strong_speech_threshold = threshold * 3 # 3x threshold for "real" speech - + for i in range(len(rms_values) - 1, -1, -1): if rms_values[i] > strong_speech_threshold: # Found strong speech, check if it's sustained # Look forward to see if energy drops and stays low sustained_speech = True windows_ahead = min(10, len(rms_values) - i) # Look ahead up to 10 windows (250ms) - + # Check if most of the next windows have reasonable speech levels speech_count = 0 for j in range(i, min(i + windows_ahead, len(rms_values))): if rms_values[j] > threshold: speech_count += 1 - + # If this looks like the end of sustained speech content if speech_count >= max(1, windows_ahead * 0.3): # At least 30% speech in next windows speech_end_idx = i break - + # If no strong speech found, fall back to simple threshold method but be conservative if speech_end_idx == 0: for i in range(len(rms_values) - 1, -1, -1): if rms_values[i] > threshold * 2: # Use 2x threshold for fallback speech_end_idx = i break - + # Convert back to milliseconds and add buffer # Convert window index to sample position, then to milliseconds sample_position = speech_end_idx * (window_size // 2) speech_end_ms = int(sample_position * 1000 / audio_segment.frame_rate) trim_point_ms = min(speech_end_ms + buffer_ms, len(audio_segment)) - + return audio_segment[:trim_point_ms] def process_audio_with_trimming_and_silence(audio_segment, boundary_type, enable_trimming=None): """ Complete audio processing: trim to speech endpoint + add punctuation-based silence. - + Args: audio_segment: pydub AudioSegment object boundary_type: Boundary type from text processing enable_trimming: Whether to trim audio (from config if None) - + Returns: Processed AudioSegment with trimming and appropriate silence """ if enable_trimming is None: enable_trimming = ENABLE_AUDIO_TRIMMING - + processed_audio = audio_segment - + # Step 1: Trim to speech endpoint if enabled if enable_trimming: processed_audio = trim_audio_endpoint(processed_audio) - + # Step 2: Add punctuation-appropriate silence processed_audio = add_contextual_silence_memory(processed_audio, boundary_type) - + return processed_audio # ============================================================================ diff --git a/modules/file_manager.py b/modules/file_manager.py index d4c33fd4fed1a2ab769cce0e919e76423b5cbea1..45c2fbeb3f5450f7a6797c6ff6bff7f8e26b31f7 100644 --- a/modules/file_manager.py +++ b/modules/file_manager.py @@ -1,6 +1,50 @@ """ -File Manager Module -Handles I/O operations, M4B conversion, metadata, and FFmpeg operations +ChatterboxTTS File Management & Media Processing Module +====================================================== + +OVERVIEW: +This module handles all file system operations, media format conversions, and +metadata management for ChatterboxTTS. It manages the complex directory structure +for audiobook production and handles conversion to final distribution formats. + +MAIN COMPONENTS: +1. DIRECTORY MANAGEMENT: Creates and maintains audiobook processing directories +2. AUDIO DISCOVERY: Locates and validates audio files across directory structures +3. M4B CONVERSION: Converts WAV chunks to M4B audiobook format using FFmpeg +4. METADATA HANDLING: Adds cover art, chapters, and book information to audiobooks +5. FILE VALIDATION: Ensures audio file compatibility and format requirements +6. VOICE SAMPLE MANAGEMENT: Handles voice sample discovery and validation + +KEY OPERATIONS: +- Directory structure setup for new audiobooks +- Audio chunk discovery and organization +- WAV to M4B conversion with chapter markers +- Cover art integration and metadata embedding +- File compatibility checking (24kHz requirement for voice samples) +- Final audiobook packaging and organization + +DIRECTORY STRUCTURE MANAGED: +``` +Audiobook/[book_name]/ +β”œβ”€β”€ TTS/ +β”‚ β”œβ”€β”€ text_chunks/ # Individual text chunk files +β”‚ └── audio_chunks/ # Generated WAV audio chunks +β”œβ”€β”€ [book_name].m4b # Final audiobook file +β”œβ”€β”€ processing.log # Processing logs +└── metadata files # Cover art, chapter info +``` + +TECHNICAL FEATURES: +- FFmpeg integration for media processing +- Automatic cover art detection and integration +- Chapter marker generation from chunk structure +- Metadata preservation across format conversions +- File system safety with validation and error handling +- Cross-platform file operations (Windows/Linux/Mac) + +PERFORMANCE CONSIDERATIONS: +Handles large audio files efficiently with streaming processing +and manages disk space through temporary file cleanup. """ import subprocess @@ -18,7 +62,7 @@ from config.config import * def list_voice_samples(): """List available voice samples""" - return sorted(VOICE_SAMPLES_DIR.glob("*.wav")) + return sorted(VOICE_SAMPLES_DIR.glob("*.wav"), key=lambda x: x.stem.lower()) def ensure_voice_sample_compatibility(input_path, output_dir=None): """Ensure voice sample is compatible with TTS (24kHz mono)""" @@ -63,14 +107,17 @@ def run_ffmpeg(cmd): # M4B CONVERSION WITH NORMALIZATION # ============================================================================ -def convert_to_m4b_with_peak_normalization(wav_path, temp_m4b_path, target_db=-3.0): +def convert_to_m4b_with_peak_normalization(wav_path, temp_m4b_path, target_db=-3.0, custom_speed=None): """Convert WAV to M4B with peak normalization""" print("πŸš€ Converting to m4b with peak normalization...") # Build audio filter chain + speed_to_use = custom_speed if custom_speed is not None else ATEMPO_SPEED audio_filters = [f"loudnorm=I=-16:TP={target_db}:LRA=11"] - if ATEMPO_SPEED != 1.0: - audio_filters.append(f"atempo={ATEMPO_SPEED}") + if speed_to_use != 1.0: + audio_filters.append(f"atempo={speed_to_use}") + + print(f"πŸš€ Converting to m4b with peak normalization and speed {speed_to_use}x...") cmd = [ "ffmpeg", "-y", @@ -96,7 +143,7 @@ def convert_to_m4b_with_peak_normalization(wav_path, temp_m4b_path, target_db=-3 process.wait() print("\nβœ… Conversion with normalization complete.") -def convert_to_m4b_with_loudness_normalization(wav_path, temp_m4b_path): +def convert_to_m4b_with_loudness_normalization(wav_path, temp_m4b_path, custom_speed=None): """Convert WAV to M4B with two-pass loudness normalization""" import json @@ -130,10 +177,11 @@ def convert_to_m4b_with_loudness_normalization(wav_path, temp_m4b_path): # Step 2: Apply normalization with measured values print("πŸ”§ Applying normalization...") - # Build audio filter chain + # Build audio filter chain + speed_to_use = custom_speed if custom_speed is not None else ATEMPO_SPEED audio_filters = [f"loudnorm=I=-16:TP=-1.5:LRA=11:measured_I={loudness_data['input_i']}:measured_LRA={loudness_data['input_lra']}:measured_TP={loudness_data['input_tp']}:measured_thresh={loudness_data['input_thresh']}:offset={loudness_data['target_offset']}:linear=true:print_format=summary"] - if ATEMPO_SPEED != 1.0: - audio_filters.append(f"atempo={ATEMPO_SPEED}") + if speed_to_use != 1.0: + audio_filters.append(f"atempo={speed_to_use}") cmd = [ "ffmpeg", "-y", @@ -159,14 +207,15 @@ def convert_to_m4b_with_loudness_normalization(wav_path, temp_m4b_path): process.wait() print("\nβœ… Two-pass normalization complete.") -def convert_to_m4b_with_simple_normalization(wav_path, temp_m4b_path, target_db=-6.0): +def convert_to_m4b_with_simple_normalization(wav_path, temp_m4b_path, target_db=-6.0, custom_speed=None): """Convert WAV to M4B with simple peak normalization""" print("πŸš€ Converting to m4b with simple normalization...") # Build audio filter chain + speed_to_use = custom_speed if custom_speed is not None else ATEMPO_SPEED audio_filters = [f"volume={target_db}dB"] - if ATEMPO_SPEED != 1.0: - audio_filters.append(f"atempo={ATEMPO_SPEED}") + if speed_to_use != 1.0: + audio_filters.append(f"atempo={speed_to_use}") cmd = [ "ffmpeg", "-y", @@ -192,16 +241,19 @@ def convert_to_m4b_with_simple_normalization(wav_path, temp_m4b_path, target_db= process.wait() print("\nβœ… Simple normalization complete.") -def convert_to_m4b(wav_path, temp_m4b_path): - """Convert WAV to M4B with configurable normalization""" +def convert_to_m4b(wav_path, temp_m4b_path, custom_speed=None): + """Convert WAV to M4B with configurable normalization and optional custom speed""" + # Determine speed to use (custom speed overrides config) + speed_to_use = custom_speed if custom_speed is not None else ATEMPO_SPEED + if not ENABLE_NORMALIZATION or NORMALIZATION_TYPE == "none": # Original function without normalization - print("πŸš€ Converting to m4b...") + print(f"πŸš€ Converting to m4b with speed {speed_to_use}x...") # Build audio filter for atempo if needed audio_filter = [] - if ATEMPO_SPEED != 1.0: - audio_filter = ["-filter:a", f"atempo={ATEMPO_SPEED}"] + if speed_to_use != 1.0: + audio_filter = ["-filter:a", f"atempo={speed_to_use}"] cmd = [ "ffmpeg", "-y", @@ -213,22 +265,22 @@ def convert_to_m4b(wav_path, temp_m4b_path): elif NORMALIZATION_TYPE == "loudness": # EBU R128 loudness normalization (recommended for audiobooks) - return convert_to_m4b_with_loudness_normalization(wav_path, temp_m4b_path) + return convert_to_m4b_with_loudness_normalization(wav_path, temp_m4b_path, custom_speed) elif NORMALIZATION_TYPE == "peak": # Peak normalization - return convert_to_m4b_with_peak_normalization(wav_path, temp_m4b_path, TARGET_PEAK_DB) + return convert_to_m4b_with_peak_normalization(wav_path, temp_m4b_path, TARGET_PEAK_DB, custom_speed) elif NORMALIZATION_TYPE == "simple": # Simple volume adjustment - return convert_to_m4b_with_simple_normalization(wav_path, temp_m4b_path, TARGET_PEAK_DB) + return convert_to_m4b_with_simple_normalization(wav_path, temp_m4b_path, TARGET_PEAK_DB, custom_speed) else: # Fallback to no normalization # Build audio filter for atempo if needed audio_filter = [] - if ATEMPO_SPEED != 1.0: - audio_filter = ["-filter:a", f"atempo={ATEMPO_SPEED}"] + if speed_to_use != 1.0: + audio_filter = ["-filter:a", f"atempo={speed_to_use}"] cmd = [ "ffmpeg", "-y", diff --git a/modules/progress_tracker.py b/modules/progress_tracker.py index 5f73b45dca773ef96f2cb519059db8629a6ed9fd..0560c5e545ea36f76f4b860e69eac1a1016edc51 100644 --- a/modules/progress_tracker.py +++ b/modules/progress_tracker.py @@ -1,6 +1,37 @@ """ -Progress Tracker Module -Handles progress display, VRAM monitoring, logging systems, and performance tracking +ChatterboxTTS Progress Tracking & Performance Monitoring Module +============================================================== + +OVERVIEW: +This module provides comprehensive progress tracking, performance monitoring, and logging +for ChatterboxTTS audiobook generation. It handles real-time ETA calculations, VRAM usage +monitoring, and detailed logging for debugging and optimization. + +MAIN COMPONENTS: +1. LOGGING SYSTEM: File + console logging with color-coded output +2. PROGRESS TRACKING: Real-time progress display with ETA calculations +3. PERFORMANCE MONITORING: GPU memory usage, processing times, realtime factors +4. BATCH PROGRESS: Multi-chapter audiobook progress aggregation +5. SYSTEM MONITORING: VRAM safety thresholds and memory optimization + +KEY FEATURES: +- Real-time ETA updates during TTS processing +- VRAM usage monitoring with automatic cleanup +- Performance metrics (realtime factor, chunks/minute) +- Color-coded console output for different message types +- Detailed logging for troubleshooting and optimization +- Memory-safe processing with configurable thresholds + +PERFORMANCE ENHANCEMENTS: +- Added producer-consumer pipeline progress tracking +- Enhanced ETA calculations for more accurate estimates +- GPU memory monitoring prevents VRAM exhaustion +- Automatic memory cleanup and garbage collection +- Processing speed metrics for performance optimization + +USAGE: +Called by TTS engine during audiobook generation to provide user feedback +and monitor system resource usage for safe, efficient processing. """ import time @@ -203,13 +234,42 @@ def display_system_info(): # ============================================================================ class PerformanceTracker: - """Track performance metrics throughout processing""" + """ + PERFORMANCE TRACKING CLASS - Core metrics collection and analysis + =============================================================== + + PURPOSE: + This class provides comprehensive performance monitoring for TTS processing, + tracking timing, memory usage, and generating detailed performance reports + for optimization and debugging. + + TRACKED METRICS: + - Individual chunk processing times + - VRAM usage per chunk (allocated vs reserved memory) + - Batch processing times for multi-chapter books + - Overall processing statistics and trends + - Real-time factor calculations (audio time vs processing time) + + USAGE FLOW: + 1. Initialize at start of TTS session + 2. Log chunk completions during processing + 3. Track batch completions for multi-part books + 4. Generate final performance report + + BENEFITS: + - Identifies processing bottlenecks + - Monitors memory usage patterns + - Provides user feedback on progress + - Enables performance optimization + - Helps debug processing issues + """ def __init__(self): - self.start_time = time.time() - self.chunk_times = [] - self.vram_usage = [] - self.batch_times = [] + """Initialize performance tracking with baseline metrics""" + self.start_time = time.time() # Session start timestamp + self.chunk_times = [] # Individual chunk processing times + self.vram_usage = [] # VRAM usage snapshots (chunk_id, allocated, reserved) + self.batch_times = [] # Batch processing times for multi-chapter books def log_chunk_completion(self, chunk_index, audio_duration): """Log individual chunk completion""" diff --git a/modules/resume_handler.py b/modules/resume_handler.py index 950f71f1ffb3f20d0c3be4b12af07d50a9b44b20..5c8305f6f6b801afce1330bae118d30fe5b58357 100644 --- a/modules/resume_handler.py +++ b/modules/resume_handler.py @@ -1,6 +1,47 @@ """ -Resume Handler Module -Handles resume functionality for interrupted processing +ChatterboxTTS Resume Handler Module +=================================== + +OVERVIEW: +This module provides intelligent resume functionality for interrupted audiobook generation. +It analyzes existing progress, identifies missing chunks, and seamlessly continues processing +from where it left off, saving hours of work on long audiobooks. + +MAIN COMPONENTS: +1. CHUNK ANALYSIS: Scans existing audio chunks to determine completion status +2. RESUME POINT DETECTION: Identifies the exact point to continue processing +3. GAP DETECTION: Finds missing chunks in the sequence for targeted regeneration +4. STATE RECOVERY: Restores processing state and configuration +5. SEAMLESS CONTINUATION: Picks up processing without duplicating completed work + +KEY FEATURES: +- Intelligent chunk sequence analysis (handles missing/corrupted chunks) +- Resume from specific chunk numbers or percentage complete +- Gap detection and targeted regeneration +- Progress state preservation across sessions +- Memory-efficient resume without reprocessing completed chunks +- Compatible with both JSON and text-based chunk workflows + +CRITICAL USE CASES: +- Long audiobooks interrupted by system crashes or user stops +- Partial regeneration after chunk repair/editing +- Continuing processing on different systems/sessions +- Recovery from VRAM exhaustion or hardware issues +- Selective chunk regeneration for quality improvements + +PERFORMANCE BENEFITS: +- Saves hours on long audiobook regeneration +- Preserves completed high-quality chunks +- Reduces system resource usage +- Enables iterative quality improvements +- Supports distributed/interrupted processing workflows + +USAGE FLOW: +1. Analyze existing chunks directory +2. Determine resume point and missing chunks +3. Load original text/JSON configuration +4. Continue processing from resume point +5. Fill gaps and complete remaining chunks """ import torch @@ -20,7 +61,37 @@ from modules.audio_processor import get_chunk_audio_duration, pause_for_chunk_re from modules.progress_tracker import setup_logging, log_chunk_progress, log_run def analyze_existing_chunks(audio_chunks_dir): - """Analyze existing chunks to determine resume point""" + """ + CHUNK ANALYSIS FUNCTION - Core resume logic + ========================================== + + PURPOSE: + Analyzes the existing audio chunks directory to determine: + 1. How many chunks have been successfully generated + 2. Where to resume processing (next chunk number) + 3. Which chunks are missing from the sequence (gaps to fill) + 4. Overall completion status and progress percentage + + ANALYSIS PROCESS: + - Scans directory for chunk_XXXXX.wav files + - Extracts chunk numbers and sorts them + - Identifies highest completed chunk number + - Detects gaps in the sequence (missing chunk numbers) + - Calculates resume point and missing chunks list + + PARAMETERS: + - audio_chunks_dir: Path to directory containing generated audio chunks + + RETURNS: + - resume_chunk_number: Next chunk number to start processing from + - missing_chunks: List of chunk numbers that need regeneration + + EDGE CASES HANDLED: + - Empty directory (start from beginning) + - No valid chunks found (start from beginning) + - Gaps in sequence (targeted regeneration) + - Out-of-order chunk numbers (robust sorting) + """ if not audio_chunks_dir.exists(): return 0, [] diff --git a/modules/text_processor.py b/modules/text_processor.py index 225c39667ff6930f2696f8b25eb93fb4ffcbad2c..9197a38b3bb4c7767afe011a70efe4427a3bbe13 100644 --- a/modules/text_processor.py +++ b/modules/text_processor.py @@ -1,6 +1,31 @@ """ -Text Processing Module -Handles text chunking, abbreviations, and preprocessing for TTS +ChatterboxTTS Text Processing Module +==================================== + +OVERVIEW: +This module is the core text preprocessing system for ChatterboxTTS audiobook generation. +It handles intelligent text chunking, abbreviation replacement, and punctuation normalization +to prepare raw text for high-quality TTS synthesis. + +MAIN COMPONENTS: +1. ABBREVIATION SYSTEM: Converts TTS-unfriendly abbreviations (Dr. -> Doctor) +2. TEXT CHUNKING: Breaks text into optimal chunks respecting sentence boundaries +3. PUNCTUATION NORMALIZATION: Standardizes quotes, adds missing periods +4. BOUNDARY DETECTION: Identifies chapter/paragraph breaks for silence insertion + +CRITICAL ALGORITHM FIXES: +- Fixed sentence chunking to respect punctuation boundaries (not word counts) +- Enhanced dialogue handling to prevent quote corruption +- Improved abbreviation replacement with external file loading +- Added smart punctuation detection for precise silence timing + +USAGE FLOW: +Text Input β†’ Abbreviation Replacement β†’ Punctuation Normalization β†’ +Sentence Chunking β†’ Boundary Detection β†’ JSON Output for TTS + +PERFORMANCE IMPACT: +Proper chunking prevents TTS model confusion and maintains voice consistency +across long audiobooks by preserving natural speech boundaries. """ import re @@ -9,13 +34,40 @@ from pathlib import Path from config.config import MAX_CHUNK_WORDS, MIN_CHUNK_WORDS, YELLOW, RESET - # ============================================================================ # ABBREVIATION REPLACEMENT SYSTEM # ============================================================================ +# +# PURPOSE: Replace TTS-unfriendly abbreviations with pronounceable text +# EXAMPLES: "Dr. Smith" -> "Doctor Smith", "U.S.A." -> "USA" +# BENEFITS: Prevents awkward pronunciation and improves audio quality def load_abbreviations(file_path="utils/abbreviations.txt"): - """Load abbreviation replacements from external file""" + """ + Load abbreviation-to-replacement mappings from external text file. + + PURPOSE: + - Centralizes abbreviation management in an editable text file + - Allows users to customize TTS pronunciations without code changes + - Supports comment lines and flexible formatting + + FILE FORMAT: + # Comments start with # + Dr. -> Doctor + U.S. -> US + etc. -> et cetera + + PARAMETERS: + - file_path: Path to abbreviations file (default: utils/abbreviations.txt) + + RETURNS: + - dict: Mapping of abbreviation -> replacement text + + BEHAVIOR: + - Creates sample file if none exists + - Skips malformed lines with warnings + - Returns empty dict on file errors (graceful degradation) + """ replacements = {} abbrev_file = Path(file_path) @@ -265,20 +317,47 @@ def _is_apostrophe(text, pos): def sentence_chunk_text(text, max_words=MAX_CHUNK_WORDS, min_words=MIN_CHUNK_WORDS): """ - Simple and reliable text chunking that follows the exact rules: - - TEXT CHUNKING RULES: - 1. Break at sentence boundaries (. ! ?) first (highest priority) - 2. If sentence > max_words, break at punctuation working backwards (; β€” , in that order) - 3. If no punctuation available, preserve sentence intact to maintain coherence + CRITICAL CHUNKING ALGORITHM - Heart of the TTS preprocessing system + ================================================================ + + ALGORITHM OVERVIEW: + This function is the most important component for TTS quality. It breaks raw text + into optimal chunks that respect natural speech boundaries, preventing TTS model + confusion and maintaining consistent voice characteristics. + + CORE PRINCIPLE: SENTENCE BOUNDARIES FIRST, WORD COUNTS SECOND + - Always prioritize complete sentences over arbitrary word limits + - Break long sentences at natural pauses (punctuation hierarchy) + - Combine short chunks to meet minimum requirements + - Preserve semantic coherence and emotional consistency + + TEXT CHUNKING RULES (in priority order): + 1. Break at sentence boundaries (. ! ?) first (HIGHEST PRIORITY) + 2. If sentence > max_words, break at punctuation working backwards + 3. If no punctuation available, preserve sentence intact (coherence over limits) 4. Ensure all chunks meet min_words requirement by combining small chunks - PUNCTUATION HIERARCHY (for breaking long sentences): - 1. . ! ? (sentence boundaries) - handled at sentence level - 2. ; (semicolon) - major pause - 3. β€” – (dashes) - major pause - 4. , (comma) - minor pause - 5. Preserve overlong sentences if no punctuation available + PUNCTUATION HIERARCHY (for breaking overlong sentences): + 1. . ! ? (sentence boundaries) - handled at sentence level first + 2. ; (semicolon) - major pause, good break point + 3. β€” – (em/en dashes) - major pause, narrative breaks + 4. , (comma) - minor pause, last resort for breaks + 5. NO PUNCTUATION = preserve intact (maintains emotional/semantic unity) + + WHY THIS APPROACH: + - Prevents choppy, robotic speech from mid-sentence breaks + - Maintains narrative flow and character voice consistency + - Respects author's punctuation for natural pauses + - Reduces TTS model confusion from incomplete thoughts + - Essential for long-form audiobook quality + + PARAMETERS: + - text: Raw input text to be chunked + - max_words: Target maximum words per chunk (flexible for complete sentences) + - min_words: Minimum words per chunk (enforced by combining) + + RETURNS: + - List of (chunk_text, is_paragraph_end) tuples for TTS processing """ import re diff --git a/modules/tts_engine.py b/modules/tts_engine.py index 3339213c53f9c16b1f7ccf6f13a33247839b7e74..b59ff0065010d7def3f6be17e9b84ead197b8ea9 100644 --- a/modules/tts_engine.py +++ b/modules/tts_engine.py @@ -1,6 +1,45 @@ """ -TTS Engine Module -Handles ChatterboxTTS interface, model loading, and chunk processing coordination +ChatterboxTTS Engine Module - Core TTS Processing System +======================================================= + +OVERVIEW: +This is the heart of the ChatterboxTTS system, responsible for loading TTS models, +processing audio chunks, and managing the complete text-to-speech pipeline. +It handles voice embedding caching, memory optimization, and parallel processing +for efficient audiobook generation. + +MAIN COMPONENTS: +1. MODEL MANAGEMENT: Loading, caching, and optimizing ChatterboxTTS models +2. VOICE PROCESSING: Voice sample analysis and embedding caching +3. CHUNK PROCESSING: Individual text chunk β†’ audio conversion +4. MEMORY OPTIMIZATION: VRAM management and garbage collection +5. PARALLEL PROCESSING: Multi-threaded chunk processing with producer-consumer pattern +6. PERFORMANCE MONITORING: Real-time progress tracking and ETA calculations + +CRITICAL PERFORMANCE FEATURES: +- Voice embedding caching (5-10% speed improvement) +- GPU persistence mode for faster model loading +- In-memory processing pipeline (eliminates temp files) +- Producer-consumer threading for parallel processing +- Automatic memory management and VRAM monitoring +- Model reinitialization every 500 chunks for stability + +WORKFLOW: +Text Chunks β†’ Voice Embedding β†’ TTS Processing β†’ Audio Generation β†’ +Quality Validation β†’ Silence Insertion β†’ Final WAV Output + +TECHNICAL DETAILS: +- Supports ChatterboxTTS models with custom voice cloning +- Handles variable TTS parameters (temperature, CFG, exaggeration) +- Implements VADER sentiment-driven parameter adjustment +- Memory-safe processing with configurable VRAM thresholds +- Automatic fallback for CUDA memory issues + +USAGE CONTEXTS: +- Called by main processing scripts (GenTTS_Claude.py) +- Used by JSON generation utilities +- Integrated with chunk repair tools +- Supports both GUI and CLI interfaces """ import torch @@ -9,10 +48,16 @@ import time import logging import shutil import sys +import os +import subprocess +import psutil +import numpy as np from datetime import timedelta from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path import torchaudio as ta +import queue +import threading from config.config import * from modules.text_processor import smart_punctuate, sentence_chunk_text, detect_content_boundaries @@ -20,14 +65,14 @@ from modules.text_processor import smart_punctuate, sentence_chunk_text, detect_ def find_chunks_json_file(book_name): """Find the corresponding chunks JSON file for a book""" from config.config import AUDIOBOOK_ROOT - + # Look in the TTS processing directory tts_chunks_dir = AUDIOBOOK_ROOT / book_name / "TTS" / "text_chunks" json_path = tts_chunks_dir / "chunks_info.json" - + if json_path.exists(): return json_path - + # Also check old Text_Input location for backwards compatibility text_input_dir = Path("Text_Input") possible_names = [ @@ -35,12 +80,12 @@ def find_chunks_json_file(book_name): f"{book_name.lower()}_chunks.json", f"{book_name.replace(' ', '_')}_chunks.json" ] - + for name in possible_names: old_json_path = text_input_dir / name if old_json_path.exists(): return old_json_path - + return None from modules.audio_processor import ( smart_audio_validation, apply_smart_fade, add_chunk_end_silence, @@ -81,39 +126,191 @@ def monitor_vram_usage(operation_name=""): if allocated > VRAM_SAFETY_THRESHOLD: logging.warning(f"⚠️ High VRAM usage during {operation_name}: {allocated:.1f}GB allocated, {reserved:.1f}GB reserved") - optimize_memory_usage() + optimize_cuda_memory_usage() return allocated, reserved return 0, 0 -def get_optimal_workers(user_max_workers=None): - """Dynamic worker allocation based on device type and resources""" - # Check for user override first - if user_max_workers is not None: - print(f"πŸ‘€ Using user-defined workers: {user_max_workers}") - return int(user_max_workers) +# ============================================================================ +# PERFORMANCE OPTIMIZATION UTILITIES +# ============================================================================ + +def detect_deployment_environment(): + """Detect deployment environment for optimization adaptation""" + if os.getenv("RUNPOD_POD_ID"): + return "runpod" + elif os.getenv("SPACE_ID"): # Hugging Face Spaces + return "huggingface" + elif os.path.exists("/.dockerenv"): + return "container" + elif torch.cuda.is_available(): + return "local_gpu" + else: + return "local_cpu" + +def get_available_memory(): + """Get available system memory in MB""" + try: + memory = psutil.virtual_memory() + return memory.available // (1024 * 1024) + except: + return 8192 # Safe default of 8GB + +def has_nvidia_smi(): + """Check if nvidia-smi is available""" + try: + subprocess.run(['nvidia-smi', '--version'], capture_output=True, check=True) + return True + except: + return False + +def enable_gpu_persistence_mode(): + """Enable GPU persistence mode with proper fallbacks""" + if not ENABLE_GPU_PERSISTENCE_MODE: + return False - if not USE_DYNAMIC_WORKERS: - return MAX_WORKERS + try: + if torch.cuda.is_available() and has_nvidia_smi(): + for attempt in range(GPU_PERSISTENCE_RETRY_COUNT): + result = subprocess.run(['nvidia-smi', '-pm', '1'], + capture_output=True, text=True) + if result.returncode == 0: + logging.info("βœ… GPU persistence mode enabled") + return True + elif "Insufficient permissions" in result.stderr: + logging.warning("⚠️ GPU persistence mode failed (insufficient privileges)") + break + time.sleep(0.5) # Brief delay between attempts + + logging.warning("πŸ“ Continuing with standard GPU power management") + else: + logging.info("ℹ️ GPU persistence mode not applicable (no NVIDIA GPU detected)") + except Exception as e: + logging.warning(f"⚠️ GPU persistence mode failed: {e}") + + return False + +def setup_cuda_memory_pool(): + """Configure CUDA memory pool for enhanced performance and reduced fragmentation""" + if not ENABLE_CUDA_MEMORY_POOL or not torch.cuda.is_available(): + return False + + try: + # Get current device and memory info + device = torch.cuda.current_device() + total_memory = torch.cuda.get_device_properties(device).total_memory + total_memory_gb = total_memory / (1024**3) + + deployment_env = detect_deployment_environment() + + # Adaptive pool sizing based on environment and available memory + if ENABLE_ADAPTIVE_MEMORY_POOL: + if deployment_env == "runpod": + pool_fraction = min(CUDA_MEMORY_POOL_FRACTION, 0.85) # More conservative on RunPod + elif deployment_env == "huggingface": + pool_fraction = min(CUDA_MEMORY_POOL_FRACTION, 0.75) # Very conservative on HF Spaces + elif total_memory_gb < 8: + pool_fraction = min(CUDA_MEMORY_POOL_FRACTION, 0.8) # Conservative for <8GB GPUs + else: + pool_fraction = CUDA_MEMORY_POOL_FRACTION # Use full config for high-memory GPUs + else: + pool_fraction = CUDA_MEMORY_POOL_FRACTION + + # Calculate pool size + pool_size = int(total_memory * pool_fraction) + pool_size_gb = pool_size / (1024**3) + + # Configure memory pool allocator settings + # Set memory pool to reduce fragmentation and improve allocation speed + if hasattr(torch.cuda, 'memory') and hasattr(torch.cuda.memory, 'set_per_process_memory_fraction'): + torch.cuda.memory.set_per_process_memory_fraction(pool_fraction, device) + logging.info(f"βœ… CUDA memory pool configured: {pool_size_gb:.1f}GB ({pool_fraction*100:.0f}% of {total_memory_gb:.1f}GB)") + + # Configure allocator settings for better memory management + if hasattr(torch.cuda, 'empty_cache'): + # Clear any existing allocations before setting up pool + torch.cuda.empty_cache() + + # Enable memory pool optimizations if available in PyTorch version + try: + # Try to enable expandable segments for better memory utilization + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' + logging.info("βœ… CUDA expandable segments enabled") + except: + pass # Not available in all PyTorch versions + + # Warm up the memory pool with a small allocation + try: + warmup_tensor = torch.zeros(1024, 1024, device=device) + del warmup_tensor + torch.cuda.empty_cache() + logging.info("βœ… CUDA memory pool warmed up") + except Exception as e: + logging.warning(f"⚠️ Memory pool warmup failed: {e}") + + logging.info(f"πŸš€ CUDA memory pool setup complete - environment: {deployment_env}") + return True + + except Exception as e: + logging.error(f"❌ CUDA memory pool setup failed: {e}") + return False - # CPU-based worker calculation +def optimize_cuda_memory_usage(): + """Advanced CUDA memory optimization for better performance""" if not torch.cuda.is_available(): - import psutil - cpu_cores = psutil.cpu_count(logical=False) # Physical cores - available_memory = psutil.virtual_memory().available / 1024**3 # GB + return - # Each TTS model instance needs ~2-3GB RAM - # Conservative estimation: allow 1 worker per 4GB available RAM - memory_limited_workers = max(1, int(available_memory / 4)) + try: + # More aggressive cleanup for memory pool systems + torch.cuda.empty_cache() - # CPU-based calculation: use 50% of physical cores for intensive TTS work - cpu_limited_workers = max(1, int(cpu_cores * 0.5)) + # Synchronize to ensure all operations complete before cleanup + torch.cuda.synchronize() - optimal_workers = min(memory_limited_workers, cpu_limited_workers, MAX_WORKERS) - print(f"πŸ’» CPU mode: {cpu_cores} cores, {available_memory:.1f}GB RAM β†’ {optimal_workers} workers") - return optimal_workers - - # GPU-based worker calculation (existing logic) + # Additional memory pool optimization if available + if hasattr(torch.cuda, 'reset_peak_memory_stats'): + torch.cuda.reset_peak_memory_stats() + + except Exception as e: + logging.warning(f"⚠️ CUDA memory optimization failed: {e}") + +# Global voice embedding cache +_voice_embedding_cache = {} +_cache_memory_usage = 0 + +def get_voice_cache_key(voice_path, exaggeration): + """Generate cache key for voice embeddings""" + try: + # Use file path and modification time for cache invalidation + stat = os.stat(voice_path) + return f"{voice_path}:{stat.st_mtime}:{exaggeration}" + except: + return f"{voice_path}:{exaggeration}" + +def clear_voice_embedding_cache(): + """Clear voice embedding cache to free memory""" + global _voice_embedding_cache, _cache_memory_usage + _voice_embedding_cache.clear() + _cache_memory_usage = 0 + if torch.cuda.is_available(): + torch.cuda.empty_cache() + logging.info("πŸ—‘οΈ Voice embedding cache cleared") + +def estimate_cache_memory_mb(conds_object): + """Estimate memory usage of cached voice embeddings in MB""" + try: + if hasattr(conds_object, 't3') and hasattr(conds_object.t3, 'voice_embed'): + # Rough estimate based on typical voice embedding sizes + return 50 # Typical voice embedding ~50MB + return 30 # Conservative estimate + except: + return 30 + +def get_optimal_workers(): + """Dynamic worker allocation based on VRAM usage""" + if not USE_DYNAMIC_WORKERS: + return MAX_WORKERS + allocated_vram = torch.cuda.memory_allocated() / 1024**3 if allocated_vram < 5.0: @@ -123,52 +320,380 @@ def get_optimal_workers(user_max_workers=None): else: return 1 -def load_optimized_model(device): - """Load TTS model with memory optimizations and device detection""" - from chatterbox.tts import ChatterboxTTS +def prewarm_model_with_voice(model, voice_path, tts_params=None): + """ + Pre-warm the TTS model with a voice sample to eliminate cold start quality issues. - # Detect available device if not specified or if CUDA not available - if device == "cuda" and not torch.cuda.is_available(): - print("⚠️ CUDA not available, falling back to CPU") - device = "cpu" - elif device == "auto": + Args: + model: Loaded TTS model + voice_path: Path to voice sample file + tts_params: Optional TTS parameters for pre-warming (uses defaults if None) + + Returns: + model: The pre-warmed model (same object, but with cached conditioning) + """ + import tempfile + import os + from modules.file_manager import ensure_voice_sample_compatibility + + try: + print("πŸ”₯ Pre-warming model with voice sample...") + + # Prepare voice for TTS + compatible_voice = ensure_voice_sample_compatibility(voice_path) + + # Set up default TTS parameters if none provided + if tts_params is None: + tts_params = { + 'exaggeration': 0.5, + 'cfg_weight': 0.5, + 'temperature': 0.9 + } + + # Prepare voice conditionals + model.prepare_conditionals(compatible_voice) + + # Generate a short dummy audio to fully warm up the model + dummy_text = "The quick brown fox jumps over the lazy dog." + print(f"🎀 Generating warm-up audio: '{dummy_text}'") + + # Generate dummy audio with the voice and parameters + wav_np = model.generate( + dummy_text, + exaggeration=tts_params['exaggeration'], + cfg_weight=tts_params['cfg_weight'], + temperature=tts_params['temperature'] + ) + + print("βœ… Model pre-warming completed - first chunk quality optimized") + + # Clean up any temporary audio data (don't save the dummy audio) + del wav_np + + return model + + except Exception as e: + print(f"⚠️ Pre-warming failed: {e}") + print("πŸ“ Model will still work but first chunk may have quality variations") + return model + +def get_best_available_device(): + """Detect and return the best available device with proper fallback""" + try: if torch.cuda.is_available(): - device = "cuda" - print("βœ… CUDA detected, using GPU") - else: - device = "cpu" - print("πŸ’» No GPU detected, using CPU") + # Test CUDA with a simple operation + test_tensor = torch.tensor([1.0]).to("cuda") + del test_tensor + torch.cuda.empty_cache() + return "cuda" + except Exception as e: + logging.warning(f"CUDA test failed: {e}") - print(f"πŸ”§ Loading ChatterboxTTS model on device: {device}") + try: + if torch.backends.mps.is_available(): + # Test MPS with a simple operation + test_tensor = torch.tensor([1.0]).to("mps") + del test_tensor + return "mps" + except Exception as e: + logging.warning(f"MPS test failed: {e}") + + return "cpu" + +def load_optimized_model(device): + """Load TTS model with memory optimizations and device fallback""" + from src.chatterbox.tts import ChatterboxTTS + + # Validate device availability + original_device = device + try: + if device == "cuda": + # Test CUDA availability with a small operation + test_tensor = torch.tensor([1.0]).to("cuda") + del test_tensor + torch.cuda.empty_cache() + logging.info(f"βœ… CUDA device validated successfully") + elif device == "mps" and torch.backends.mps.is_available(): + # Test MPS availability + test_tensor = torch.tensor([1.0]).to("mps") + del test_tensor + logging.info(f"βœ… MPS device validated successfully") + except Exception as e: + logging.warning(f"⚠️ Device {device} failed validation: {e}") + logging.info("πŸ”„ Falling back to CPU") + device = "cpu" try: - # Load model (ChatterboxTTS.from_pretrained doesn't support torch_dtype parameter) + # Load model with validated device (ChatterboxTTS doesn't support torch_dtype parameter) model = ChatterboxTTS.from_pretrained(device=device) - logging.info(f"βœ… Loaded ChatterboxTTS model on {device}") + logging.info(f"βœ… Model loaded successfully on {device.upper()}") + + if original_device != device: + logging.info(f"πŸ“ Note: Requested {original_device.upper()} but using {device.upper()} due to availability") + except Exception as e: - print(f"❌ Failed to load model on {device}: {e}") - if device == "cuda": - print("πŸ”„ Retrying with CPU...") - try: - model = ChatterboxTTS.from_pretrained(device="cpu") - logging.info("βœ… Loaded model on CPU (GPU failed)") - device = "cpu" - except Exception as e2: - print(f"❌ Failed to load model on CPU: {e2}") - raise e2 + logging.error(f"❌ Failed to load model on {device}: {e}") + if device != "cpu": + logging.info("πŸ”„ Final fallback to CPU...") + device = "cpu" + model = ChatterboxTTS.from_pretrained(device=device) + logging.info("βœ… Model loaded on CPU as final fallback") else: - raise e + raise RuntimeError(f"Failed to load model even on CPU: {e}") # Only apply eval() and benchmark if the model has these attributes if hasattr(model, 'eval'): model.eval() - # Set CUDNN benchmark for performance (if available) - if torch.backends.cudnn.is_available(): + # Set CUDNN benchmark for performance (if available and using CUDA) + if device == "cuda" and torch.backends.cudnn.is_available(): torch.backends.cudnn.benchmark = True + logging.info("βœ… CUDNN benchmark enabled for performance") + + # Initialize CUDA memory pool if enabled and using CUDA + if device == "cuda" and ENABLE_CUDA_MEMORY_POOL: + memory_pool_success = setup_cuda_memory_pool() + if memory_pool_success: + logging.info("πŸš€ CUDA memory pool optimization enabled") + else: + logging.warning("⚠️ CUDA memory pool setup failed, continuing without optimization") return model +# ============================================================================ +# PRODUCER-CONSUMER PIPELINE (PHASE 4) +# ============================================================================ + +def chunk_producer_thread(all_chunks, chunk_queue, start_index=0, max_queue_size=10): + """ + Producer thread that pre-loads chunks into a queue for worker threads to consume. + This eliminates chunk loading overhead during TTS processing. + + Args: + all_chunks: List of chunk data (dict format with text, boundary_type, etc) + chunk_queue: Queue to place prepared chunk data + start_index: Index to start producing from (for resume functionality) + max_queue_size: Maximum queue size to prevent memory overflow + """ + try: + logging.info(f"🏭 Producer thread started - pre-loading chunks from index {start_index}") + + for i, chunk_data in enumerate(all_chunks[start_index:], start=start_index): + # Check if we should stop (via sentinel or shutdown) + if shutdown_requested: + break + + # Handle both dictionary and tuple formats for backward compatibility + if isinstance(chunk_data, dict): + chunk_text = chunk_data["text"] + boundary_type = chunk_data.get("boundary_type", "none") + chunk_tts_params = chunk_data.get("tts_params", None) + else: + # Handle old tuple format (text, is_para_end) + chunk_text = chunk_data[0] if len(chunk_data) > 0 else str(chunk_data) + is_old_para_end = chunk_data[1] if len(chunk_data) > 1 else False + boundary_type = "paragraph_end" if is_old_para_end else "none" + chunk_tts_params = None + + # Create standardized chunk package for workers + chunk_package = { + 'index': i, + 'text': chunk_text, + 'boundary_type': boundary_type, + 'tts_params': chunk_tts_params + } + + # Put chunk in queue (blocks if queue is full) + chunk_queue.put(chunk_package, timeout=30) + + # Log progress every 50 chunks to avoid spam + if (i + 1) % 50 == 0: + logging.info(f"πŸ“¦ Producer queued {i + 1} chunks") + + logging.info(f"βœ… Producer thread completed - {len(all_chunks) - start_index} chunks queued") + + except Exception as e: + logging.error(f"❌ Producer thread failed: {e}") + finally: + # Signal completion by adding sentinel value + try: + chunk_queue.put(None, timeout=5) # None = end of chunks signal + except queue.Full: + logging.warning("⚠️ Could not add completion signal - queue full") + +def process_chunks_with_pipeline( + all_chunks, batch_chunks, chunk_offset, text_chunks_dir, audio_chunks_dir, + voice_path, tts_params, start_time, total_chunks, punc_norm, book_name, + log_run_func, log_path, device, model, asr_model, asr_enabled, optimal_workers, + accumulated_audio_duration=0.0 +): + """ + Enhanced chunk processing with producer-consumer pipeline for 5-10% performance improvement. + + Args: + all_chunks: Complete list of all chunks (for context) + batch_chunks: Current batch of chunks to process + chunk_offset: Offset for global chunk indexing + ... (other parameters same as original ThreadPoolExecutor pattern) + + Returns: + Tuple of (batch_results, total_audio_duration) where: + - batch_results: List of (index, wav_path) tuples for successful chunks + - total_audio_duration: Total audio duration for batch (for progress tracking) + """ + try: + # Create thread-safe queue with size limit to prevent memory overflow + max_queue_size = min(optimal_workers * 3, 20) # 3x workers or 20, whichever is smaller + chunk_queue = queue.Queue(maxsize=max_queue_size) + + # Start producer thread to pre-load chunks + producer_thread = threading.Thread( + target=chunk_producer_thread, + args=(batch_chunks, chunk_queue, 0, max_queue_size), + daemon=True + ) + producer_thread.start() + + logging.info(f"πŸš€ Producer-consumer pipeline started with queue size {max_queue_size}") + + # Consumer pattern: workers pull from queue instead of sequential loading + batch_results = [] + futures = [] + + with ThreadPoolExecutor(max_workers=optimal_workers) as executor: + # Process chunks as they become available and handle results in real-time + chunks_submitted = 0 + completed_count = 0 + total_audio_duration = accumulated_audio_duration + + # Import audio processing functions + from modules.audio_processor import get_chunk_audio_duration + from modules.progress_tracker import log_chunk_progress + + while True: + try: + # Get next chunk from producer (blocks until available) + chunk_package = chunk_queue.get(timeout=10) + + # Check for completion signal + if chunk_package is None: + break + + # Check for shutdown request + if shutdown_requested: + logging.info("πŸ›‘ Shutdown requested - stopping chunk submission") + break + + # Extract chunk data from package + global_chunk_index = chunk_offset + chunk_package['index'] + chunk_text = chunk_package['text'] + boundary_type = chunk_package['boundary_type'] + chunk_tts_params = chunk_package.get('tts_params') or tts_params + + # Build context for chunk (all chunk texts) + all_chunk_texts = [] + for cd in all_chunks: + if isinstance(cd, dict): + all_chunk_texts.append(cd["text"]) + else: + all_chunk_texts.append(cd[0] if len(cd) > 0 else str(cd)) + + # Submit chunk to worker thread + future = executor.submit( + process_one_chunk, + global_chunk_index, chunk_text, text_chunks_dir, audio_chunks_dir, + voice_path, chunk_tts_params, start_time, total_chunks, + punc_norm, book_name, log_run_func, log_path, device, + model, asr_model, all_chunk_texts, boundary_type, + asr_enabled + ) + futures.append(future) + + chunks_submitted += 1 + chunk_queue.task_done() + + # Check for completed futures while submitting new ones + completed_futures = [] + for fut in futures: + if fut.done(): + completed_futures.append(fut) + + # Process completed futures + for fut in completed_futures: + try: + idx, wav_path = fut.result() + if wav_path and wav_path.exists(): + batch_results.append((idx, wav_path)) + + # Update totals for final batch calculation + chunk_duration = get_chunk_audio_duration(wav_path) + total_audio_duration += chunk_duration + completed_count += 1 + + futures.remove(fut) # Remove completed future + + except Exception as e: + logging.error(f"❌ Future failed during real-time processing: {e}") + futures.remove(fut) + + except queue.Empty: + # Timeout waiting for chunks - check if producer is done + if not producer_thread.is_alive(): + break + else: + # Producer still working - check for completed futures while waiting + completed_futures = [fut for fut in futures if fut.done()] + for fut in completed_futures: + try: + idx, wav_path = fut.result() + if wav_path and wav_path.exists(): + batch_results.append((idx, wav_path)) + + chunk_duration = get_chunk_audio_duration(wav_path) + total_audio_duration += chunk_duration + completed_count += 1 + + futures.remove(fut) + + except Exception as e: + logging.error(f"❌ Future failed during timeout processing: {e}") + futures.remove(fut) + continue + + except Exception as e: + logging.error(f"❌ Error in consumer loop: {e}") + break + + # Process any remaining futures + if futures: + for fut in as_completed(futures): + try: + idx, wav_path = fut.result() + if wav_path and wav_path.exists(): + batch_results.append((idx, wav_path)) + + # Update batch totals + chunk_duration = get_chunk_audio_duration(wav_path) + total_audio_duration += chunk_duration + completed_count += 1 + + except Exception as e: + logging.error(f"❌ Final future failed: {e}") + + # Wait for producer thread to complete cleanly + if producer_thread.is_alive(): + producer_thread.join(timeout=5) + + # Calculate batch-specific audio duration for return + batch_audio_duration = total_audio_duration - accumulated_audio_duration + logging.info(f"πŸŽ‰ Producer-consumer pipeline completed: {len(batch_results)} chunks processed") + return batch_results, batch_audio_duration + + except Exception as e: + logging.error(f"❌ Producer-consumer pipeline failed: {e}") + logging.info("πŸ”„ Falling back to sequential processing...") + return [], 0.0 # Return empty results to trigger fallback + # ============================================================================ # CHUNK PROCESSING # ============================================================================ @@ -189,7 +714,8 @@ def process_one_chunk( i, chunk, text_chunks_dir, audio_chunks_dir, voice_path, tts_params, start_time, total_chunks, punc_norm, basename, log_run_func, log_path, device, - model, asr_model, all_chunks, boundary_type="none" + model, asr_model, all_chunks, boundary_type="none", + enable_asr=None ): """Enhanced chunk processing with quality control, contextual silence, and deep cleanup""" import difflib @@ -269,10 +795,21 @@ def process_one_chunk( mid_drop_retries = 0 max_mid_drop_retries = 2 - for attempt_num in range(1, 3): - logging.info(f"πŸ” Starting TTS for chunk {chunk_id_str}, attempt {attempt_num}") + # Enhanced regeneration loop with quality validation + max_attempts = MAX_REGENERATION_ATTEMPTS if ENABLE_REGENERATION_LOOP else 2 + current_tts_params = tts_params.copy() + + # Debug: Log the initial parameters for this chunk + logging.info(f"πŸŽ›οΈ Chunk {chunk_id_str} initial TTS params: exag={current_tts_params.get('exaggeration', 'N/A'):.3f}, cfg={current_tts_params.get('cfg_weight', 'N/A'):.3f}, temp={current_tts_params.get('temperature', 'N/A'):.3f}, min_p={current_tts_params.get('min_p', 'N/A'):.3f}") + + for attempt_num in range(max_attempts): + logging.info(f"πŸ” Starting TTS for chunk {chunk_id_str}, attempt {attempt_num + 1}/{max_attempts}") + if attempt_num > 0: + logging.info(f"πŸ”§ Adjusted params: exag={current_tts_params.get('exaggeration', 'N/A'):.3f}, cfg={current_tts_params.get('cfg_weight', 'N/A'):.3f}, temp={current_tts_params.get('temperature', 'N/A'):.3f}") try: - tts_args = {k: v for k, v in tts_params.items() if k != "max_workers"} + # Filter to only supported ChatterboxTTS parameters + supported_params = {"exaggeration", "cfg_weight", "temperature", "min_p", "top_p", "repetition_penalty"} + tts_args = {k: v for k, v in current_tts_params.items() if k in supported_params} # monitor_gpu_activity(f"Before TTS chunk_{chunk_id_str}") # Disabled for speed with torch.no_grad(): @@ -282,81 +819,86 @@ def process_one_chunk( if wav.dim() == 1: wav = wav.unsqueeze(0) - # Retry if mid-energy drop is enabled and detected (check in memory) - if ENABLE_MID_DROP_CHECK and has_mid_energy_drop(wav, model.sr): - mid_drop_retries += 1 - if mid_drop_retries >= max_mid_drop_retries: - logging.info(f"⚠️ Mid-drop retry limit reached for {chunk_id_str}. Accepting audio.") - else: - logging.info(f"⚠️ Mid-chunk noise detected in {chunk_id_str}. Retrying...") - continue - # Convert tensor to AudioSegment for in-memory processing import io import soundfile as sf from pydub import AudioSegment - + # Convert wav tensor to AudioSegment (in memory) wav_np = wav.squeeze().numpy() with io.BytesIO() as wav_buffer: sf.write(wav_buffer, wav_np, model.sr, format='wav') wav_buffer.seek(0) audio_segment = AudioSegment.from_wav(wav_buffer) - - # Smart fade removed - replaced by precise audio trimming - # Audio health validation disabled for speed - - # Note: Audio trimming will handle end-of-speech cleanup more precisely - - # ASR validation (memory-based processing) - check user setting first - enable_asr_user = tts_params.get('enable_asr', False) - if (enable_asr_user or ENABLE_ASR) and asr_model is not None: - from modules.audio_processor import asr_f1_score - import io - import soundfile as sf - # monitor_gpu_activity(f"Before ASR chunk_{chunk_id_str}") # Disabled for speed + + # Enhanced quality validation + quality_score = 1.0 # Start with perfect score + + # Legacy mid-energy drop check (converted to score) + if ENABLE_MID_DROP_CHECK and has_mid_energy_drop(wav, model.sr): + quality_score *= 0.3 # Significant penalty for mid-drop + logging.info(f"⚠️ Mid-chunk energy drop detected in {chunk_id_str}") + + # Enhanced quality validation (if enabled) + if ENABLE_REGENERATION_LOOP: + from modules.audio_processor import evaluate_chunk_quality + # Pass existing ASR model to avoid loading duplicate + composite_score = evaluate_chunk_quality(audio_segment, chunk, include_spectral=True, asr_model=asr_model) + quality_score *= composite_score + logging.info(f"πŸ“Š Quality score for {chunk_id_str}: {quality_score:.3f} (composite: {composite_score:.3f})") + + # ASR validation (memory-based processing) + asr_score = 1.0 # Default to passed if ASR disabled + # Use parameter if provided, otherwise fall back to config + asr_enabled = enable_asr if enable_asr is not None else ENABLE_ASR + if asr_enabled and asr_model is not None: + from modules.audio_processor import calculate_text_similarity try: # Process ASR completely in memory - no disk writes - # Convert AudioSegment to numpy array for ASR samples = np.array(audio_segment.get_array_of_samples()) if audio_segment.channels == 2: samples = samples.reshape((-1, 2)).mean(axis=1) - + # Normalize to float32 for ASR model audio_np = samples.astype(np.float32) / audio_segment.max_possible_amplitude - - # Use ASR model directly on numpy array (if supported) - # Note: This depends on the ASR model's input capabilities result = asr_model.transcribe(audio_np) - + if not isinstance(result, dict) or "text" not in result: raise ValueError(f"Invalid ASR result type: {type(result)}") asr_text = result.get("text", "").strip() - sim_ratio = asr_f1_score(punc_norm(chunk), asr_text) + asr_score = calculate_text_similarity(punc_norm(chunk), asr_text) + logging.info(f"🎀 ASR similarity for chunk {chunk_id_str}: {asr_score:.3f} - Expected: '{punc_norm(chunk)}' Got: '{asr_text}'") except Exception as e: - print(f"❌ ASR failed for {chunk_id_str}: {e}") - log_run_func(f"ASR VALIDATION FAILED - Chunk {chunk_id_str}:\nExpected:\n{chunk}\nActual:\n\nSimilarity: -1.000\n" + "="*50, log_path) - sim_ratio = -1.0 - continue + logging.error(f"❌ ASR failed for {chunk_id_str}: {e}") + asr_score = 0.8 # Use neutral score instead of 0 to avoid regeneration - logging.info(f"ASR similarity for chunk {chunk_id_str}: {sim_ratio:.3f}") - if sim_ratio < 0.7: - continue + # Include ASR score in overall quality + quality_score *= asr_score - # Track best valid match - best_sim = sim_ratio - best_asr_text = asr_text - # monitor_gpu_activity(f"After ASR chunk_{chunk_id_str}") # Disabled for speed - - # Success - we have processed audio in memory - final_audio = audio_segment - break + # Final quality check with all validations + if quality_score >= QUALITY_THRESHOLD or attempt_num == max_attempts - 1: + if quality_score >= QUALITY_THRESHOLD: + logging.info(f"βœ… Quality acceptable for {chunk_id_str} on attempt {attempt_num + 1} (final score: {quality_score:.3f})") + else: + logging.info(f"⚠️ Max attempts reached for {chunk_id_str}, accepting best effort (final score: {quality_score:.3f})") + + # Quality acceptable or max attempts reached, continue with processing + final_audio = audio_segment + best_sim = asr_score if asr_enabled else 1.0 + best_asr_text = asr_text if asr_enabled and 'asr_text' in locals() else "" + break + else: + # Quality too low, adjust parameters for retry + logging.info(f"πŸ”„ Quality below threshold ({quality_score:.3f} < {QUALITY_THRESHOLD}), adjusting parameters for retry {attempt_num + 2}") + from modules.audio_processor import adjust_parameters_for_retry + current_tts_params = adjust_parameters_for_retry(current_tts_params, quality_score, attempt_num) + continue except Exception as e: import traceback - logging.error(f"Exception during TTS attempt {attempt_num} for chunk {chunk_id_str}: {e}") + logging.error(f"Exception during TTS attempt {attempt_num + 1} for chunk {chunk_id_str}: {e}") traceback.print_exc() continue @@ -366,7 +908,7 @@ def process_one_chunk( # Apply trimming and contextual silence in memory before final save from modules.audio_processor import process_audio_with_trimming_and_silence - + if boundary_type and boundary_type != "none": final_audio = process_audio_with_trimming_and_silence(final_audio, boundary_type) print(f"πŸ”‡ Added {boundary_type} silence to chunk {i+1:05}") @@ -387,23 +929,41 @@ def process_one_chunk( # No intermediate file cleanup needed - all processing done in memory # Log details - only log ASR failures - asr_active = enable_asr_user or ENABLE_ASR - if asr_active and best_sim < 0.8: + if asr_enabled and best_sim < 0.8: log_run_func(f"ASR VALIDATION FAILED - Chunk {chunk_id_str}:\nExpected:\n{chunk}\nActual:\n{best_asr_text}\nSimilarity: {best_sim:.3f}\n" + "="*50, log_path) - elif not asr_active: + elif not asr_enabled: log_run_func(f"Chunk {chunk_id_str}: Original text: {chunk}", log_path) # Silence already added in memory above - no disk processing needed # Enhanced regular cleanup (every chunk) del wav - optimize_memory_usage() + optimize_cuda_memory_usage() # Additional per-chunk cleanup for long runs if (i + 1) % 50 == 0: torch.cuda.empty_cache() gc.collect() + # Show ETA progress updates during actual processing (every 2 chunks) + if i % 2 == 0: + try: + from modules.audio_processor import get_chunk_audio_duration + from modules.progress_tracker import log_chunk_progress + + # Calculate running total audio duration by checking existing chunks + total_audio_duration = 0.0 + for j in range(i + 1): # Include current chunk + check_path = audio_chunks_dir / f"chunk_{j+1:05}.wav" + if check_path.exists(): + total_audio_duration += get_chunk_audio_duration(check_path) + + # Show ETA update with accumulated audio + log_chunk_progress(i, total_chunks, start_time, total_audio_duration) + except Exception as e: + # Don't let ETA calculation failures break chunk processing + pass + return i, final_path # ============================================================================ @@ -413,10 +973,63 @@ def process_one_chunk( from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer from wrapper.chunk_loader import save_chunks -def generate_enriched_chunks(text_file, output_dir, user_tts_params=None): +def smooth_sentiment_scores(scores, index, method="rolling", window=3): + """ + Apply sentiment smoothing to prevent harsh emotional transitions. + + Args: + scores: List of compound sentiment scores + index: Current chunk index + method: "rolling" for moving average, "exp_decay" for exponential decay + window: Number of previous chunks to consider + + Returns: + float: Smoothed sentiment score + """ + if index == 0: + return scores[0] + + start_idx = max(0, index - window + 1) + window_scores = scores[start_idx:index + 1] + + if method == "rolling": + return sum(window_scores) / len(window_scores) + elif method == "exp_decay": + weights = SENTIMENT_EXP_DECAY_WEIGHTS[:len(window_scores)] + weighted_sum = sum(w * s for w, s in zip(weights, reversed(window_scores))) + weight_sum = sum(weights[:len(window_scores)]) + return weighted_sum / weight_sum if weight_sum > 0 else window_scores[-1] + else: + return scores[index] # No smoothing + +def generate_enriched_chunks(text_file, output_dir, user_tts_params=None, quality_params=None, config_params=None, voice_name=None): """Reads a text file, performs VADER sentiment analysis, and returns enriched chunks.""" analyzer = SentimentIntensityAnalyzer() - + + # Extract quality parameters for JSON generation (GUI overrides config) + if quality_params: + enable_smoothing = quality_params.get('sentiment_smoothing', ENABLE_SENTIMENT_SMOOTHING) + smoothing_window = quality_params.get('smoothing_window', SENTIMENT_SMOOTHING_WINDOW) + smoothing_method = quality_params.get('smoothing_method', SENTIMENT_SMOOTHING_METHOD) + print(f"πŸ”§ JSON Generation: Using GUI smoothing settings - Enabled: {enable_smoothing}, Window: {smoothing_window}, Method: {smoothing_method}") + else: + enable_smoothing = ENABLE_SENTIMENT_SMOOTHING + smoothing_window = SENTIMENT_SMOOTHING_WINDOW + smoothing_method = SENTIMENT_SMOOTHING_METHOD + print(f"πŸ”§ JSON Generation: Using config smoothing settings - Enabled: {enable_smoothing}") + + # Extract VADER sensitivity parameters (GUI overrides config) + if config_params: + vader_exag_sensitivity = config_params.get('vader_exag_sensitivity', VADER_EXAGGERATION_SENSITIVITY) + vader_cfg_sensitivity = config_params.get('vader_cfg_sensitivity', VADER_CFG_WEIGHT_SENSITIVITY) + vader_temp_sensitivity = config_params.get('vader_temp_sensitivity', VADER_TEMPERATURE_SENSITIVITY) + print(f"πŸ”§ JSON Generation: Using GUI VADER sensitivity - Exag: {vader_exag_sensitivity}, CFG: {vader_cfg_sensitivity}, Temp: {vader_temp_sensitivity}") + else: + vader_exag_sensitivity = VADER_EXAGGERATION_SENSITIVITY + vader_cfg_sensitivity = VADER_CFG_WEIGHT_SENSITIVITY + vader_temp_sensitivity = VADER_TEMPERATURE_SENSITIVITY + print(f"πŸ”§ JSON Generation: Using config VADER sensitivity - Exag: {vader_exag_sensitivity}, CFG: {vader_cfg_sensitivity}, Temp: {vader_temp_sensitivity}") + raw_text = text_file.read_text(encoding='utf-8') cleaned = smart_punctuate(raw_text) chunks = sentence_chunk_text(cleaned) @@ -426,58 +1039,155 @@ def generate_enriched_chunks(text_file, output_dir, user_tts_params=None): base_exaggeration = user_tts_params.get('exaggeration', BASE_EXAGGERATION) base_cfg_weight = user_tts_params.get('cfg_weight', BASE_CFG_WEIGHT) base_temperature = user_tts_params.get('temperature', BASE_TEMPERATURE) + base_min_p = user_tts_params.get('min_p', DEFAULT_MIN_P) + base_top_p = user_tts_params.get('top_p', DEFAULT_TOP_P) + base_repetition_penalty = user_tts_params.get('repetition_penalty', DEFAULT_REPETITION_PENALTY) + use_vader = user_tts_params.get('use_vader', True) # Default to True for backward compatibility + else: base_exaggeration = BASE_EXAGGERATION base_cfg_weight = BASE_CFG_WEIGHT base_temperature = BASE_TEMPERATURE + base_min_p = DEFAULT_MIN_P + base_top_p = DEFAULT_TOP_P + base_repetition_penalty = DEFAULT_REPETITION_PENALTY + use_vader = True # Default behavior enriched = [] chunk_texts = [chunk_text for chunk_text, _ in chunks] - - for i, (chunk_text, is_para_end) in enumerate(chunks): - sentiment_scores = analyzer.polarity_scores(chunk_text) - compound_score = sentiment_scores['compound'] - exaggeration = base_exaggeration + (compound_score * VADER_EXAGGERATION_SENSITIVITY) - cfg_weight = base_cfg_weight + (compound_score * VADER_CFG_WEIGHT_SENSITIVITY) - temperature = base_temperature + (compound_score * VADER_TEMPERATURE_SENSITIVITY) + # First pass: collect all sentiment scores + raw_sentiment_scores = [] + for chunk_text, _ in chunks: + sentiment_scores = analyzer.polarity_scores(chunk_text) + raw_sentiment_scores.append(sentiment_scores['compound']) - # Clamp values to defined min/max - exaggeration = round(max(TTS_PARAM_MIN_EXAGGERATION, min(exaggeration, TTS_PARAM_MAX_EXAGGERATION)), 2) - cfg_weight = round(max(TTS_PARAM_MIN_CFG_WEIGHT, min(cfg_weight, TTS_PARAM_MAX_CFG_WEIGHT)), 2) - temperature = round(max(TTS_PARAM_MIN_TEMPERATURE, min(temperature, TTS_PARAM_MAX_TEMPERATURE)), 2) + # Second pass: apply smoothing and generate parameters + for i, (chunk_text, is_para_end) in enumerate(chunks): + # Get original sentiment score + raw_compound_score = raw_sentiment_scores[i] + + # Apply sentiment smoothing if enabled (uses GUI settings, not config) + if use_vader and enable_smoothing: + compound_score = smooth_sentiment_scores( + raw_sentiment_scores, + i, + method=smoothing_method, + window=smoothing_window + ) + # Debug: Log sentiment changes + if abs(compound_score - raw_compound_score) > 0.1: + logging.info(f"πŸ“Š Chunk {i+1:05}: sentiment smoothed {raw_compound_score:.3f} β†’ {compound_score:.3f}") + else: + compound_score = raw_compound_score + + if use_vader: + # Apply VADER sentiment adjustments using smoothed score + exaggeration = base_exaggeration + (compound_score * vader_exag_sensitivity) + cfg_weight = base_cfg_weight + (compound_score * vader_cfg_sensitivity) + temperature = base_temperature + (compound_score * vader_temp_sensitivity) + min_p = base_min_p + (compound_score * VADER_MIN_P_SENSITIVITY) + repetition_penalty = base_repetition_penalty + (compound_score * VADER_REPETITION_PENALTY_SENSITIVITY) + + # Clamp values to defined min/max (ensure JSON values respect bounds) + exaggeration = round(max(TTS_PARAM_MIN_EXAGGERATION, min(exaggeration, TTS_PARAM_MAX_EXAGGERATION)), 2) + cfg_weight = round(max(TTS_PARAM_MIN_CFG_WEIGHT, min(cfg_weight, TTS_PARAM_MAX_CFG_WEIGHT)), 2) + temperature = round(max(TTS_PARAM_MIN_TEMPERATURE, min(temperature, TTS_PARAM_MAX_TEMPERATURE)), 2) + min_p = round(max(TTS_PARAM_MIN_MIN_P, min(min_p, TTS_PARAM_MAX_MIN_P)), 3) + repetition_penalty = round(max(TTS_PARAM_MIN_REPETITION_PENALTY, min(repetition_penalty, TTS_PARAM_MAX_REPETITION_PENALTY)), 1) + + # Debug: Log VADER-adjusted parameters for significant changes + if abs(exaggeration - base_exaggeration) > 0.05 or abs(cfg_weight - base_cfg_weight) > 0.05: + logging.info(f"🎭 Chunk {i+1:05}: VADER adjusted params - exag: {base_exaggeration:.2f}β†’{exaggeration:.2f}, cfg: {base_cfg_weight:.2f}β†’{cfg_weight:.2f}, sentiment: {compound_score:.3f}") + else: + # Use fixed base values (no VADER adjustment) + exaggeration = base_exaggeration + cfg_weight = base_cfg_weight + temperature = base_temperature + min_p = base_min_p + repetition_penalty = base_repetition_penalty boundary_type = detect_content_boundaries(chunk_text, i, chunk_texts, is_para_end) - + enriched.append({ "index": i, "text": chunk_text, "word_count": len(chunk_text.split()), "boundary_type": boundary_type if boundary_type else "none", - "sentiment_compound": compound_score, + "sentiment_compound": compound_score, # Store smoothed score + "sentiment_raw": raw_compound_score, # Store original score for reference "tts_params": { "exaggeration": exaggeration, "cfg_weight": cfg_weight, - "temperature": temperature + "temperature": temperature, + "min_p": min_p, + "top_p": base_top_p, # Top-P remains constant (not adjusted by VADER) + "repetition_penalty": repetition_penalty } }) output_json_path = output_dir / "chunks_info.json" - save_chunks(output_json_path, enriched) + + # Add voice metadata if provided + if voice_name: + # Try metadata method first + try: + # Create metadata entry as first element + metadata = { + "_metadata": True, + "voice_used": voice_name, + "generation_timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), + "total_chunks": len(enriched) + } + enriched_with_metadata = [metadata] + enriched + save_chunks(output_json_path, enriched_with_metadata) + print(f"βœ… Saved voice metadata: {voice_name}") + except Exception as e: + # Fallback to comment method if metadata fails + print(f"⚠️ Metadata method failed, using comment fallback: {e}") + save_chunks(output_json_path, enriched) + + # Add voice as comment + from modules.voice_detector import add_voice_to_json + add_voice_to_json(output_json_path, voice_name, method="comment") + else: + save_chunks(output_json_path, enriched) + return enriched -def process_book_folder(book_dir, voice_path, tts_params, device, skip_cleanup=False): +def process_book_folder(book_dir, voice_path, tts_params, device, skip_cleanup=False, enable_asr=None, quality_params=None, config_params=None, specific_text_file=None): """Enhanced book processing with batch processing to prevent hangs""" print(f"πŸ” DEBUG: Entering process_book_folder with book_dir='{book_dir}', voice_path='{voice_path}'") - - from chatterbox.tts import punc_norm + + # Apply GUI quality parameters to override config defaults + if quality_params: + print(f"πŸ”§ Applying GUI quality parameters: {quality_params}") + + # Override config values with GUI settings + global ENABLE_REGENERATION_LOOP, ENABLE_SENTIMENT_SMOOTHING, ENABLE_MFCC_VALIDATION + global ENABLE_OUTPUT_VALIDATION, QUALITY_THRESHOLD, OUTPUT_VALIDATION_THRESHOLD + global SENTIMENT_SMOOTHING_WINDOW, SENTIMENT_SMOOTHING_METHOD, SPECTRAL_ANOMALY_THRESHOLD + + ENABLE_REGENERATION_LOOP = quality_params.get('regeneration_enabled', ENABLE_REGENERATION_LOOP) + ENABLE_SENTIMENT_SMOOTHING = quality_params.get('sentiment_smoothing', ENABLE_SENTIMENT_SMOOTHING) + ENABLE_MFCC_VALIDATION = quality_params.get('mfcc_validation', ENABLE_MFCC_VALIDATION) + ENABLE_OUTPUT_VALIDATION = quality_params.get('output_validation', ENABLE_OUTPUT_VALIDATION) + QUALITY_THRESHOLD = quality_params.get('quality_threshold', QUALITY_THRESHOLD) + OUTPUT_VALIDATION_THRESHOLD = quality_params.get('output_threshold', OUTPUT_VALIDATION_THRESHOLD) + SENTIMENT_SMOOTHING_WINDOW = quality_params.get('smoothing_window', SENTIMENT_SMOOTHING_WINDOW) + SENTIMENT_SMOOTHING_METHOD = quality_params.get('smoothing_method', SENTIMENT_SMOOTHING_METHOD) + SPECTRAL_ANOMALY_THRESHOLD = quality_params.get('spectral_threshold', SPECTRAL_ANOMALY_THRESHOLD) + + print(f"βœ… Quality settings applied - Regeneration: {ENABLE_REGENERATION_LOOP}, MFCC: {ENABLE_MFCC_VALIDATION}, Output Validation: {ENABLE_OUTPUT_VALIDATION}") + + from src.chatterbox.tts import punc_norm print(f"πŸ” DEBUG: Successfully imported punc_norm") # Setup directories print(f"πŸ” DEBUG: Calling setup_book_directories...") output_root, tts_dir, text_chunks_dir, audio_chunks_dir = setup_book_directories(book_dir) print(f"πŸ” DEBUG: Directory setup complete") - + # Clean previous processing files (but skip for resume operations) if skip_cleanup: print(f"πŸ”„ RESUME MODE: Skipping cleanup to preserve existing chunks") @@ -485,52 +1195,64 @@ def process_book_folder(book_dir, voice_path, tts_params, device, skip_cleanup=F else: print(f"🧹 FRESH PROCESSING: Cleaning previous processing files...") import glob - - # Clear text chunks + + # Clear text chunks for txt_file in text_chunks_dir.glob("*.txt"): txt_file.unlink(missing_ok=True) for json_file in text_chunks_dir.glob("*.json"): json_file.unlink(missing_ok=True) - + # Clear audio chunks for wav_file in audio_chunks_dir.glob("*.wav"): wav_file.unlink(missing_ok=True) - + # Clear logs for log_file in output_root.glob("*.log"): log_file.unlink(missing_ok=True) - + print(f"βœ… Cleanup complete") # Find book files print(f"πŸ” DEBUG: Calling find_book_files...") book_files = find_book_files(book_dir) - text_files = [book_files['text']] if book_files['text'] else [] + + # Use specific text file if provided (GUI selection), otherwise use auto-detected file + if specific_text_file: + text_file_to_use = Path(specific_text_file) + print(f"🎯 DEBUG: Using GUI-selected text file: {text_file_to_use}") + if not text_file_to_use.exists(): + logging.error(f"[{book_dir.name}] ERROR: Selected text file not found: {text_file_to_use}") + return None, None, [] + else: + text_file_to_use = book_files['text'] + print(f"πŸ” DEBUG: Using auto-detected text file: {text_file_to_use}") + if not text_file_to_use: + logging.info(f"[{book_dir.name}] ERROR: No .txt files found in the book folder.") + return None, None, [] + cover_file = book_files['cover'] nfo_file = book_files['nfo'] - print(f"πŸ” DEBUG: Found text files: {text_files}") - - if not text_files: - logging.info(f"[{book_dir.name}] ERROR: No .txt files found in the book folder.") - return None, None, [] setup_logging(output_root) - # Generate enriched chunks with VADER analysis using user parameters - all_chunks = generate_enriched_chunks(text_files[0], text_chunks_dir, tts_params) + # Extract voice name for logging and JSON metadata + voice_name_for_log = voice_path.stem if hasattr(voice_path, 'stem') else Path(voice_path).stem + + # Generate enriched chunks with VADER analysis using user parameters and GUI quality settings + print(f"πŸ” DEBUG: About to call generate_enriched_chunks with quality_params: {quality_params}") + print(f"πŸ” DEBUG: About to call generate_enriched_chunks with config_params: {config_params}") + print(f"πŸ” DEBUG: Using voice: {voice_name_for_log}") + all_chunks = generate_enriched_chunks(text_file_to_use, text_chunks_dir, tts_params, quality_params, config_params, voice_name_for_log) # Create run_log_lines print(f"πŸ” DEBUG: Creating run_log_lines...") print(f"πŸ” DEBUG: voice_path type: {type(voice_path)}, value: {voice_path}") - - # Extract voice name for logging - voice_name_for_log = voice_path.stem if hasattr(voice_path, 'stem') else Path(voice_path).stem - + run_log_lines = [ f"\n===== Processing: {book_dir.name} =====", f"Voice: {voice_name_for_log}", f"Started: {time.strftime('%Y-%m-%d %H:%M:%S')}", - f"Text files processed: {len(text_files)}", + f"Text file processed: {text_file_to_use.name}", f"Total chunks generated: {len(all_chunks)}" ] @@ -539,6 +1261,13 @@ def process_book_folder(book_dir, voice_path, tts_params, device, skip_cleanup=F log_path = output_root / "chunk_validation.log" total_audio_duration = 0.0 + # Initialize performance optimizations + deployment_env = detect_deployment_environment() + print(f"🌍 Deployment environment: {deployment_env}") + + # Enable GPU persistence mode for better performance + gpu_persistence_enabled = enable_gpu_persistence_mode() + # Batch processing print(f"πŸ“Š Processing {total_chunks} chunks in batches of {BATCH_SIZE}") @@ -553,93 +1282,135 @@ def process_book_folder(book_dir, voice_path, tts_params, device, skip_cleanup=F # Fresh model for each batch model = load_optimized_model(device) compatible_voice = ensure_voice_sample_compatibility(voice_path, output_dir=tts_dir) - model.prepare_conditionals(compatible_voice) + + # Pre-warm model to eliminate first chunk quality variations + model = prewarm_model_with_voice(model, compatible_voice, tts_params) - # Load ASR model once per batch if needed (check user settings first, then global config) + # Load ASR model once per batch if needed using adaptive manager asr_model = None - enable_asr_user = tts_params.get('enable_asr', False) - if enable_asr_user or ENABLE_ASR: - import whisper - print(f"🎀 Loading Whisper ASR model for batch... (user setting: {enable_asr_user})") - # Use same device as TTS model, with fallback to CPU - asr_device = device if torch.cuda.is_available() and device == "cuda" else "cpu" - print(f"🎀 Loading ASR model on device: {asr_device}") - asr_model = whisper.load_model("base", device=asr_device) - - futures = [] - batch_results = [] + asr_device_used = None + # Use parameter if provided, otherwise fall back to config + asr_enabled = enable_asr if enable_asr is not None else ENABLE_ASR + if asr_enabled: + from modules.asr_manager import load_asr_model_adaptive + + # Get ASR config from parameters + asr_config = config_params.get('asr_config', {}) if config_params else {} + + # Use adaptive ASR manager for intelligent loading + asr_model, asr_device_used = load_asr_model_adaptive(asr_config) + + if asr_model is None: + print(f"❌ ASR model loading failed completely - disabling ASR for this batch") + asr_enabled = False # Dynamic worker allocation - user_max_workers = tts_params.get('max_workers', None) - optimal_workers = get_optimal_workers(user_max_workers) + optimal_workers = get_optimal_workers() print(f"πŸ”§ Using {optimal_workers} workers for batch {batch_start+1}-{batch_end}") - with ThreadPoolExecutor(max_workers=optimal_workers) as executor: - for i, chunk_data in enumerate(batch_chunks): - global_chunk_index = batch_start + i - - # Check for shutdown request - if shutdown_requested: - print(f"\n⏹️ {YELLOW}Stopping submission of new chunks...{RESET}") - break - - # Handle both dictionary and tuple formats for chunk data - if isinstance(chunk_data, dict): - chunk = chunk_data["text"] - boundary_type = chunk_data.get("boundary_type", "none") - # Use chunk-specific TTS params if available, otherwise fall back to global - chunk_tts_params = chunk_data.get("tts_params", tts_params) + # Try producer-consumer pipeline first (Phase 4 optimization) + batch_results = [] + if ENABLE_PRODUCER_CONSUMER_PIPELINE: + try: + print(f"πŸš€ Producer-consumer pipeline for batch {batch_start+1}-{batch_end}") + pipeline_results = process_chunks_with_pipeline( + all_chunks, batch_chunks, batch_start, text_chunks_dir, audio_chunks_dir, + voice_path, tts_params, start_time, total_chunks, punc_norm, book_dir.name, + log_run, log_path, device, model, asr_model, asr_enabled, optimal_workers, + total_audio_duration # Pass accumulated duration for proper ETA calculation + ) + + # Handle tuple return from pipeline + if isinstance(pipeline_results, tuple) and len(pipeline_results) == 2: + batch_results, batch_audio_duration = pipeline_results + total_audio_duration += batch_audio_duration else: - # Handle old tuple format (text, is_para_end) - convert to boundary_type - chunk = chunk_data[0] if len(chunk_data) > 0 else str(chunk_data) - # Convert old is_paragraph_end to boundary_type - is_old_para_end = chunk_data[1] if len(chunk_data) > 1 else False - boundary_type = "paragraph_end" if is_old_para_end else "none" - chunk_tts_params = tts_params # Fallback for old format - - # Handle both dictionary and tuple formats for backward compatibility - all_chunk_texts = [] - for cd in all_chunks: - if isinstance(cd, dict): - all_chunk_texts.append(cd["text"]) + # Fallback for old return format + batch_results = pipeline_results + + if batch_results: + print(f"βœ… Producer-consumer pipeline completed: {len(batch_results)} chunks") + # Pipeline already handled progress logging internally + + except Exception as e: + logging.error(f"❌ Producer-consumer pipeline failed: {e}") + if not ENABLE_PIPELINE_FALLBACK: + raise + batch_results = [] # Clear failed results + + # Fallback to original sequential processing if pipeline disabled or failed + if not batch_results: + print(f"πŸ”„ Sequential processing fallback for batch {batch_start+1}-{batch_end}") + futures = [] + + with ThreadPoolExecutor(max_workers=optimal_workers) as executor: + for i, chunk_data in enumerate(batch_chunks): + global_chunk_index = batch_start + i + + # Check for shutdown request + if shutdown_requested: + print(f"\n⏹️ {YELLOW}Stopping submission of new chunks...{RESET}") + break + + # Handle both dictionary and tuple formats for chunk data + if isinstance(chunk_data, dict): + chunk = chunk_data["text"] + boundary_type = chunk_data.get("boundary_type", "none") + # Use chunk-specific TTS params if available, otherwise fall back to global + chunk_tts_params = chunk_data.get("tts_params", tts_params) else: - # Handle old tuple format (text, is_para_end) - all_chunk_texts.append(cd[0] if len(cd) > 0 else str(cd)) - - futures.append(executor.submit( - process_one_chunk, - global_chunk_index, chunk, text_chunks_dir, audio_chunks_dir, - voice_path, chunk_tts_params, start_time, total_chunks, - punc_norm, book_dir.name, log_run, log_path, device, - model, asr_model, all_chunk_texts, boundary_type - )) - - # Wait for batch to complete - print(f"πŸ”„ {CYAN}Waiting for batch {batch_start+1}-{batch_end} to complete...{RESET}") - completed_count = 0 - - for fut in as_completed(futures): - try: - idx, wav_path = fut.result() - if wav_path and wav_path.exists(): - # Measure actual audio duration for this chunk - chunk_duration = get_chunk_audio_duration(wav_path) - total_audio_duration += chunk_duration - batch_results.append((idx, wav_path)) - - # Update progress every 10 chunks within batch - completed_count += 1 - if completed_count % 10 == 0: - log_chunk_progress(batch_start + completed_count - 1, total_chunks, start_time, total_audio_duration) - - except Exception as e: - logging.error(f"Future failed in batch: {e}") + # Handle old tuple format (text, is_para_end) - convert to boundary_type + chunk = chunk_data[0] if len(chunk_data) > 0 else str(chunk_data) + # Convert old is_paragraph_end to boundary_type + is_old_para_end = chunk_data[1] if len(chunk_data) > 1 else False + boundary_type = "paragraph_end" if is_old_para_end else "none" + chunk_tts_params = tts_params # Fallback for old format + + # Handle both dictionary and tuple formats for backward compatibility + all_chunk_texts = [] + for cd in all_chunks: + if isinstance(cd, dict): + all_chunk_texts.append(cd["text"]) + else: + # Handle old tuple format (text, is_para_end) + all_chunk_texts.append(cd[0] if len(cd) > 0 else str(cd)) + + futures.append(executor.submit( + process_one_chunk, + global_chunk_index, chunk, text_chunks_dir, audio_chunks_dir, + voice_path, chunk_tts_params, start_time, total_chunks, + punc_norm, book_dir.name, log_run, log_path, device, + model, asr_model, all_chunk_texts, boundary_type, + asr_enabled + )) + + # Wait for batch to complete + print(f"πŸ”„ {CYAN}Waiting for batch {batch_start+1}-{batch_end} to complete...{RESET}") + completed_count = 0 + + for fut in as_completed(futures): + try: + idx, wav_path = fut.result() + if wav_path and wav_path.exists(): + # Measure actual audio duration for this chunk + chunk_duration = get_chunk_audio_duration(wav_path) + total_audio_duration += chunk_duration + batch_results.append((idx, wav_path)) + + # Update progress every 2 chunks within batch + completed_count += 1 + if completed_count % 2 == 0: + log_chunk_progress(batch_start + completed_count - 1, total_chunks, start_time, total_audio_duration) + + except Exception as e: + logging.error(f"Future failed in batch: {e}") # Clean up model after batch print(f"🧹 Cleaning up after batch {batch_start+1}-{batch_end}") del model if asr_model: - del asr_model + from modules.asr_manager import cleanup_asr_model + cleanup_asr_model(asr_model) torch.cuda.empty_cache() gc.collect() time.sleep(2) @@ -690,7 +1461,7 @@ def process_book_folder(book_dir, voice_path, tts_params, device, skip_cleanup=F f"Combined WAV: {combined_wav_path}", "--- Generation Settings ---", f"Batch Processing: Enabled ({BATCH_SIZE} chunks per batch)", - f"ASR Enabled: {enable_asr_user or ENABLE_ASR} (user: {enable_asr_user}, global: {ENABLE_ASR})", + f"ASR Enabled: {ENABLE_ASR}", f"Hum Detection: {ENABLE_HUM_DETECTION}", f"Dynamic Workers: {USE_DYNAMIC_WORKERS}", f"Voice used: {voice_name}", diff --git a/requirements.txt b/requirements.txt index 0fef73fab6ff0874881aa235ca25baa2ba0741a0..5d368d6fa90ffcfdac8f6a957c6a51f46afc8c1b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ -# ChatterboxTTS HuggingFace Spaces Requirements -# Optimized for HF Spaces environment with flexible versions +# ChatterboxTTS Requirements - Main file (works everywhere) +# Complete dependency list for local development -# Core ML and TTS - Essential (pinned versions for fast builds) +# Core ML and TTS - Essential torch==2.6.0 torchaudio==2.6.0 transformers==4.46.3 @@ -21,7 +21,7 @@ openai-whisper>=20231117 psutil>=5.8.0 pynvml>=11.0.0 -# Core scientific computing (pinned for fast builds) +# Core scientific computing numpy==2.2.0 scipy>=1.7.0 @@ -29,9 +29,13 @@ scipy>=1.7.0 regex>=2023.0.0 vaderSentiment>=3.3.0 -# Web interface - Gradio (let HF manage version) +# Web interface - Gradio gradio>=4.0.0 +# GUI Support - Local development +PyQt5>=5.15.0 +PyQt5-sip>=12.0.0 + # Progress and logging tqdm>=4.60.0 @@ -45,8 +49,10 @@ python-dotenv>=1.0.0 requests>=2.25.0 packaging>=21.0 +# Development tools +pre-commit>=2.0.0 + # Core ChatterboxTTS model dependencies -chatterbox-tts>=0.1.2 resemble-perth>=1.0.1 omegaconf>=2.3.0 einops>=0.6.0 diff --git a/utils/generate_from_json.py b/utils/generate_from_json.py deleted file mode 100644 index 32a1c7545878c3dacea2e2aec805310fbfcb7f32..0000000000000000000000000000000000000000 --- a/utils/generate_from_json.py +++ /dev/null @@ -1,146 +0,0 @@ -#!/usr/bin/env python3 -""" -Direct Audio Generation from JSON Tool - -This script allows for generating audiobook chunks directly from a pre-existing -`chunks_info.json` file. It is intended for debugging and testing purposes, -allowing a user to manually edit the TTS parameters in the JSON file and -hear the results without the VADER analysis step. -""" - -import torch -from pathlib import Path -import sys -from concurrent.futures import ThreadPoolExecutor, as_completed -import time -from datetime import timedelta - -# Add project root to path to allow module imports -project_root = Path(__file__).parent -sys.path.append(str(project_root)) - -from config.config import * -from modules.tts_engine import load_optimized_model, process_one_chunk, prewarm_model_with_voice -from modules.file_manager import setup_book_directories, list_voice_samples, ensure_voice_sample_compatibility -from wrapper.chunk_loader import load_chunks -from chatterbox.tts import punc_norm -from modules.progress_tracker import log_chunk_progress, log_run - -def main(): - """Main function to drive the generation process.""" - print(f"{BOLD}{CYAN}--- Direct Audio Generation from JSON Tool ---\{RESET}") - - # 1. Get Book Name - book_name = input("Enter the book name (e.g., 'london'): ").strip() - if not book_name: - print("❌ Book name cannot be empty.") - return - - # 2. Locate and Load JSON - book_audio_dir = AUDIOBOOK_ROOT / book_name - json_path = book_audio_dir / "TTS" / "text_chunks" / "chunks_info.json" - - if not json_path.exists(): - print(f"❌ Error: JSON file not found at {json_path}") - print("Please ensure you have run the 'Prepare text file' option for this book first.") - return - - print(f"πŸ“– Loading chunks from: {json_path}") - all_chunks = load_chunks(str(json_path)) - print(f"βœ… Found {len(all_chunks)} chunks.") - - # 3. Select Voice - voice_files = list_voice_samples() - if not voice_files: - print(f"❌ No voice samples found in {VOICE_SAMPLES_DIR}") - return - - print("\nAvailable voices:") - for i, voice_file in enumerate(voice_files, 1): - print(f" [{i}] {voice_file.stem}") - - while True: - try: - choice = input("Select voice number: ").strip() - idx = int(choice) - 1 - if 0 <= idx < len(voice_files): - voice_path = voice_files[idx] - break - print("Invalid selection.") - except (ValueError, IndexError): - print("Invalid selection.") - - # Ensure voice compatibility - voice_path = ensure_voice_sample_compatibility(voice_path) - - # 4. Setup Environment - if torch.cuda.is_available(): - device = "cuda" - elif torch.backends.mps.is_available(): - device = "mps" - else: - device = "cpu" - - print(f"\nπŸš€ Using device: {device}") - print(f"🎀 Using voice: {Path(voice_path).name}") - - # 5. Load Model - model = load_optimized_model(device) - - # 6. Pre-warm model to eliminate first chunk quality variations - print(f"πŸ”₯ Pre-warming model with voice sample: {Path(voice_path).name}") - from modules.tts_engine import prewarm_model_with_voice - compatible_voice = ensure_voice_sample_compatibility(voice_path) - # Use default TTS params for pre-warming since we don't have user params here - model = prewarm_model_with_voice(model, compatible_voice, None) - - # 7. Process Chunks - output_root, tts_dir, text_chunks_dir, audio_chunks_dir = setup_book_directories(Path(TEXT_INPUT_ROOT) / book_name) - - # Clean existing audio chunks - print("🧹 Clearing old audio chunks...") - for wav_file in audio_chunks_dir.glob("*.wav"): - wav_file.unlink() - - start_time = time.time() - total_chunks = len(all_chunks) - log_path = output_root / "debug_generation.log" - - print(f"\nπŸ”„ Generating {total_chunks} chunks...") - - with ThreadPoolExecutor(max_workers=2) as executor: # Test parallel processing - futures = [] - for i, chunk_data in enumerate(all_chunks): - # Extract exaggeration from JSON, force others to default - chunk_tts_params = { - "exaggeration": chunk_data.get("tts_params", {}).get("exaggeration", DEFAULT_EXAGGERATION), - "cfg_weight": DEFAULT_CFG_WEIGHT, - "temperature": DEFAULT_TEMPERATURE - } - - future = executor.submit( - process_one_chunk, - i, chunk_data['text'], text_chunks_dir, audio_chunks_dir, - voice_path, chunk_tts_params, start_time, total_chunks, - punc_norm, book_name, log_run, log_path, device, - model, None, all_chunks, chunk_data['boundary_type'] - ) - futures.append(future) - - for future in as_completed(futures): - try: - result = future.result() - if result: - idx, _ = result - log_chunk_progress(idx, total_chunks, start_time, 0) - except Exception as e: - print(f"\n❌ An error occurred while processing a chunk: {e}") - - elapsed_time = time.time() - start_time - print(f"\n{GREEN}βœ… Generation Complete!{RESET}") - print(f"⏱️ Total time: {timedelta(seconds=int(elapsed_time))}") - print(f"πŸ”Š Audio chunks are in: {audio_chunks_dir}") - print("You can now use Option 3 from the main menu to combine them.") - -if __name__ == "__main__": - main() diff --git a/wrapper/chunk_tool.py b/wrapper/chunk_tool.py deleted file mode 100644 index c529bb13220f66e22013e8bbda393292c5a51a45..0000000000000000000000000000000000000000 --- a/wrapper/chunk_tool.py +++ /dev/null @@ -1,249 +0,0 @@ -from wrapper.chunk_loader import load_chunks, save_chunks -from wrapper.chunk_search import search_chunks -from wrapper.chunk_editor import update_chunk -from wrapper.chunk_player import play_chunk_audio -from wrapper.chunk_synthesizer import synthesize_chunk -from wrapper.chunk_revisions import accept_revision -import os -from config.config import AUDIOBOOK_ROOT -AUDIO_DIR = AUDIOBOOK_ROOT - -def select_book_for_repair(): - """Let user select which book to repair""" - from pathlib import Path - - # Look for books in both locations: TTS processing dirs and Text_Input - available_books = [] - - # First check TTS processing directories - audiobook_root = Path(AUDIOBOOK_ROOT) - if audiobook_root.exists(): - for book_dir in audiobook_root.iterdir(): - if book_dir.is_dir(): - tts_chunks_dir = book_dir / "TTS" / "text_chunks" - json_path = tts_chunks_dir / "chunks_info.json" - if json_path.exists(): - available_books.append((book_dir.name, json_path, "TTS")) - - # Then check Text_Input directory for fallback - text_input_dir = Path("Text_Input") - if text_input_dir.exists(): - for chunk_file in text_input_dir.glob("*_chunks.json"): - book_name = chunk_file.stem.replace("_chunks", "") - # Only add if not already found in TTS directories - if not any(book[0] == book_name for book in available_books): - available_books.append((book_name, chunk_file, "Text_Input")) - - if not available_books: - print("❌ No chunk files found in TTS processing directories or Text_Input/") - return None, None - - print("\nπŸ“š Available books for repair:") - for i, (book_name, json_path, source) in enumerate(available_books): - print(f" [{i}] {book_name} ({source}: {json_path.name})") - - while True: - try: - choice = input(f"\nSelect book index [0-{len(available_books)-1}]: ").strip() - idx = int(choice) - if 0 <= idx < len(available_books): - book_name, json_path, source = available_books[idx] - return book_name, json_path - else: - print(f"❌ Please enter a number between 0 and {len(available_books)-1}") - except (ValueError, EOFError, KeyboardInterrupt): - print("❌ Invalid selection or cancelled") - return None, None - -def run_chunk_repair_tool(): - print("\nπŸ› οΈ Chunk Repair & Revision Tool") - - # Ask user to select book - book_name, chunk_path = select_book_for_repair() - if not chunk_path: - return - - print(f"\nπŸ“– Loading chunks from: {chunk_path.name}") - chunks = load_chunks(str(chunk_path)) - - # Determine audio directory path based on book structure - from pathlib import Path - audiobook_root = Path(AUDIOBOOK_ROOT) - book_audio_dir = audiobook_root / book_name / "TTS" / "audio_chunks" - - if not book_audio_dir.exists(): - print(f"❌ Audio directory not found: {book_audio_dir}") - print(f"πŸ“ Looked for: {book_audio_dir}") - return - - print(f"πŸ“ Using audio directory: {book_audio_dir}") - - while True: - query = input("\nSearch for text fragment (or 'Q' to quit): ").strip() - if query.lower() == "q": - print("Exiting revision tool.") - break - - results = search_chunks(chunks, query) - if not results: - print("❌ No matching chunks found.") - continue - - print(f"\nπŸ” Found {len(results)} match(es):") - for i, chunk in enumerate(results): - print(f"[{i}] \"{chunk['text'][:60]}...\" | Index: {chunk['index']}") - - sel = input("Select chunk index to revise: ").strip() - if not sel.isdigit() or int(sel) >= len(results): - print("Invalid selection.") - continue - - chunk = results[int(sel)] - index = chunk['index'] - # Use 5-digit chunk numbering and correct directory path - chunk_audio_path = book_audio_dir / f"chunk_{index+1:05d}.wav" - chunk_audio_path_str = str(chunk_audio_path) - - while True: - print(f"\nπŸ“ Chunk: \"{chunk['text']}\"") - - # Display current chunk metadata - sentiment_compound = chunk.get('sentiment_compound', chunk.get('sentiment_score', 'N/A')) - tts_params = chunk.get('tts_params', {}) - - print(f" πŸ“ Index: {index}, Boundary: {chunk['boundary_type']}") - print(f" 😊 Sentiment: {sentiment_compound}") - print(f" πŸŽ›οΈ TTS Params: exag={tts_params.get('exaggeration', 'N/A')}, cfg={tts_params.get('cfg_weight', 'N/A')}, temp={tts_params.get('temperature', 'N/A')}") - print(f" πŸ“ Audio file: chunk_{index+1:05d}.wav") - print("\nOptions:") - print(" 1. Play original audio") - print(" 2. Edit text content") - print(" 3. Edit chunk metadata (boundary, sentiment)") - print(" 4. Edit TTS parameters (exaggeration, cfg_weight, temperature)") - print(" 5. Resynthesize audio with current settings") - print(" 6. Play revised audio") - print(" 7. Accept revision (replace original with revised)") - print(" 8. Back to search") - - try: - choice = input("\nπŸ’‘ Enter option number [1-8]: ").strip() - except (EOFError, KeyboardInterrupt): - print("\n❌ Input cancelled") - return - if choice == "1": - print(f"\nπŸ”Š Playing original audio: {chunk_audio_path.name}") - play_chunk_audio(chunk_audio_path_str) - elif choice == "2": - print("\n✏️ Edit Text Content:") - print(f"Current text: \"{chunk['text']}\"") - print("πŸ’‘ Enter new text (or Enter to cancel):") - new_text = input(">>> ").strip() - - if new_text: - chunk['text'] = new_text - chunk['word_count'] = len(new_text.split()) - save_chunks(str(chunk_path), chunks) - print("βœ… Text content updated successfully") - print(f"πŸ“Š New word count: {chunk['word_count']}") - else: - print("❌ No changes made") - elif choice == "3": - print("\n✏️ Edit Chunk Metadata:") - print(f"Current boundary type: {chunk['boundary_type']}") - boundary = input("New boundary type (none/paragraph_end/chapter_start/chapter_end/section_break) [Enter to skip]: ").strip() - - current_sentiment = chunk.get('sentiment_compound', chunk.get('sentiment_score', 'N/A')) - print(f"Current sentiment score: {current_sentiment}") - sentiment = input("New sentiment compound score (-1.0 to 1.0) [Enter to skip]: ").strip() - - try: - if boundary: - chunk['boundary_type'] = boundary - print(f"βœ… Updated boundary type to: {boundary}") - - if sentiment: - sentiment_val = float(sentiment) - if -1.0 <= sentiment_val <= 1.0: - chunk['sentiment_compound'] = sentiment_val - # Also update old key for compatibility - chunk['sentiment_score'] = sentiment_val - print(f"βœ… Updated sentiment score to: {sentiment_val}") - else: - print("❌ Sentiment score must be between -1.0 and 1.0") - - save_chunks(str(chunk_path), chunks) - print("βœ… Chunk metadata updated successfully") - except ValueError as e: - print(f"❌ Invalid input: {e}") - except Exception as e: - print(f"❌ Error updating chunk: {e}") - elif choice == "4": - print("\nπŸŽ›οΈ Edit TTS Parameters:") - current_tts_params = chunk.get('tts_params', {}) - - def get_float_input(param_name, current_val, min_val=None, max_val=None): - while True: - try: - prompt = f"New {param_name} [{current_val}]: " - value = input(prompt).strip() - if not value: - return current_val - new_val = float(value) - if min_val is not None and new_val < min_val: - print(f"❌ {param_name} must be >= {min_val}") - continue - if max_val is not None and new_val > max_val: - print(f"❌ {param_name} must be <= {max_val}") - continue - return new_val - except ValueError: - print(f"❌ Invalid input. Please enter a valid number.") - - # Edit TTS parameters - print(f"Current TTS parameters:") - current_exag = current_tts_params.get('exaggeration', 1.0) - current_cfg = current_tts_params.get('cfg_weight', 0.7) - current_temp = current_tts_params.get('temperature', 0.7) - - print(f" Exaggeration: {current_exag}") - print(f" CFG Weight: {current_cfg}") - print(f" Temperature: {current_temp}") - - new_exag = get_float_input("exaggeration", current_exag, 0.0, 3.0) - new_cfg = get_float_input("CFG weight", current_cfg, 0.0, 2.0) - new_temp = get_float_input("temperature", current_temp, 0.0, 2.0) - - # Update chunk TTS parameters - if 'tts_params' not in chunk: - chunk['tts_params'] = {} - - chunk['tts_params']['exaggeration'] = new_exag - chunk['tts_params']['cfg_weight'] = new_cfg - chunk['tts_params']['temperature'] = new_temp - - save_chunks(str(chunk_path), chunks) - print(f"βœ… TTS parameters updated: exag={new_exag}, cfg={new_cfg}, temp={new_temp}") - elif choice == "5": - print(f"\n🎀 Resynthesizing chunk {index+1:05d}...") - revised_path = synthesize_chunk(chunk, index, book_name, book_audio_dir, revision=True) - if revised_path: - print(f"βœ… Chunk resynthesized: {revised_path}") - else: - print("❌ Failed to resynthesize chunk") - elif choice == "6": - rev_path = book_audio_dir / f"chunk_{index+1:05d}_rev.wav" - print(f"\nπŸ”Š Playing revised audio: {rev_path.name}") - play_chunk_audio(str(rev_path)) - elif choice == "7": - print(f"\nπŸ“¦ Accepting revision for chunk {index+1:05d}...") - accept_revision(index, book_audio_dir) - print("βœ… Revision accepted successfully") - break - elif choice == "8": - print("πŸ”™ Returning to search...") - break - elif choice.lower() == 'q': - print("πŸšͺ Exiting chunk repair tool...") - return - else: - print(f"❌ Invalid option '{choice}'. Please enter a number 1-8 (or 'q' to quit).")