danneauxs commited on
Commit
3cb0dc4
·
1 Parent(s): 3aa3268
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. BATCH_IMPLEMENTATION_PLAN.md +0 -58
  2. HF_Deploy/.gitattributes +36 -0
  3. HF_Deploy/.gitignore +2 -0
  4. HF_Deploy/README.md +14 -0
  5. HF_Deploy/Text_Input/Goliath/test1.txt +7 -0
  6. HF_Deploy/Text_Input/README.md +40 -0
  7. HF_Deploy/Text_Input/test +20 -0
  8. app.py.20250811-120000.bak → HF_Deploy/app.py +0 -0
  9. HF_Deploy/config/__init__.py +0 -0
  10. HF_Deploy/config/config.py +159 -0
  11. HF_Deploy/gradio_main_interface.py +148 -0
  12. HF_Deploy/gradio_tabs/__init__.py +7 -0
  13. HF_Deploy/gradio_tabs/tab1_convert_book.py +1173 -0
  14. HF_Deploy/modules/__init__.py +0 -0
  15. HF_Deploy/modules/asr_manager.py +233 -0
  16. HF_Deploy/modules/audio_processor.py +569 -0
  17. HF_Deploy/modules/batch_processor.py +31 -0
  18. HF_Deploy/modules/file_manager.py +431 -0
  19. HF_Deploy/modules/gui_json_generator.py +217 -0
  20. HF_Deploy/modules/path_manager.py +19 -0
  21. HF_Deploy/modules/progress_tracker.py +306 -0
  22. HF_Deploy/modules/resume_handler.py +596 -0
  23. HF_Deploy/modules/system_detector.py +231 -0
  24. HF_Deploy/modules/text_processor.py +745 -0
  25. HF_Deploy/modules/tts_engine.py +710 -0
  26. HF_Deploy/modules/voice_detector.py +240 -0
  27. HF_Deploy/requirements.txt +56 -0
  28. HF_Deploy/src/chatterbox/__init__.py +2 -0
  29. HF_Deploy/src/chatterbox/models/s3gen/__init__.py +2 -0
  30. HF_Deploy/src/chatterbox/models/s3gen/const.py +1 -0
  31. HF_Deploy/src/chatterbox/models/s3gen/decoder.py +317 -0
  32. HF_Deploy/src/chatterbox/models/s3gen/f0_predictor.py +55 -0
  33. HF_Deploy/src/chatterbox/models/s3gen/flow.py +242 -0
  34. HF_Deploy/src/chatterbox/models/s3gen/flow_matching.py +228 -0
  35. HF_Deploy/src/chatterbox/models/s3gen/hifigan.py +474 -0
  36. HF_Deploy/src/chatterbox/models/s3gen/matcha/decoder.py +443 -0
  37. HF_Deploy/src/chatterbox/models/s3gen/matcha/flow_matching.py +129 -0
  38. HF_Deploy/src/chatterbox/models/s3gen/matcha/text_encoder.py +413 -0
  39. HF_Deploy/src/chatterbox/models/s3gen/matcha/transformer.py +316 -0
  40. HF_Deploy/src/chatterbox/models/s3gen/s3gen.py +305 -0
  41. HF_Deploy/src/chatterbox/models/s3gen/transformer/__init__.py +0 -0
  42. HF_Deploy/src/chatterbox/models/s3gen/transformer/activation.py +84 -0
  43. HF_Deploy/src/chatterbox/models/s3gen/transformer/attention.py +330 -0
  44. HF_Deploy/src/chatterbox/models/s3gen/transformer/convolution.py +145 -0
  45. HF_Deploy/src/chatterbox/models/s3gen/transformer/embedding.py +294 -0
  46. HF_Deploy/src/chatterbox/models/s3gen/transformer/encoder_layer.py +236 -0
  47. HF_Deploy/src/chatterbox/models/s3gen/transformer/positionwise_feed_forward.py +115 -0
  48. HF_Deploy/src/chatterbox/models/s3gen/transformer/subsampling.py +383 -0
  49. HF_Deploy/src/chatterbox/models/s3gen/transformer/upsample_encoder.py +318 -0
  50. HF_Deploy/src/chatterbox/models/s3gen/utils/class_utils.py +71 -0
BATCH_IMPLEMENTATION_PLAN.md DELETED
@@ -1,58 +0,0 @@
1
- # Plan for Implementing High-Performance Batch Processing
2
-
3
- This document outlines the necessary code modifications to implement a high-performance batch processing mode that can be toggled by the "Use VADER" checkbox in the GUI.
4
-
5
- The goal is to create two distinct modes:
6
- - **VADER On (Nuanced Mode):** Slower, processes chunks one-by-one with unique TTS parameters for nuanced delivery.
7
- - **VADER Off (Batch Mode):** Significantly faster, processes chunks in batches with a single set of TTS parameters.
8
-
9
- ---
10
-
11
- ## 1. File to Modify: `src/chatterbox/tts.py`
12
-
13
- * **Purpose:** To enable the core TTS model to handle batches of text.
14
- * **Changes Needed:**
15
- * A new method, `generate_batch(self, texts: list, **tts_params)`, needs to be created within the `ChatterboxTTS` class.
16
- * This method must perform the following steps:
17
- 1. Accept a list of text strings (`texts`).
18
- 2. Tokenize each text string in the list.
19
- 3. Pad the tokenized sequences to ensure they all have the same length, creating a single batch tensor. `torch.nn.utils.rnn.pad_sequence` is suitable for this.
20
- 4. Feed the complete batch tensor to the underlying model (`self.t3.inference` and `self.s3gen.inference`).
21
- 5. Return a list of the resulting audio waveforms.
22
-
23
- ---
24
-
25
- ## 2. File to Modify: `modules/tts_engine.py`
26
-
27
- * **Purpose:** To orchestrate the new batching workflow and choose the processing mode.
28
- * **Changes Needed:**
29
-
30
- ### a. Create a New Worker Function
31
- * Add a new function: `process_batch(batch_of_chunks, model, ...)`
32
- * This function will:
33
- 1. Accept a list of chunk objects (e.g., a batch of 16).
34
- 2. Extract the text from each chunk into a simple list.
35
- 3. Call the new `model.generate_batch()` with the list of texts and the shared TTS parameters.
36
- 4. Receive a list of audio waveforms back.
37
- 5. Loop through the audio waves, apply the existing silence trimming and padding logic to each one, and save them to their respective `chunk_...wav` files.
38
-
39
- ### b. Modify the Main `process_book_folder` Function
40
- * Locate the `use_vader` flag which is determined from the GUI options.
41
- * Wrap the core processing loop in an `if/else` block based on this flag.
42
- * **`if use_vader:` (Nuanced Mode):**
43
- * Keep the existing code that iterates through chunks one-by-one and submits them to the `process_one_chunk` function.
44
- * **`else:` (Batch Mode):**
45
- * Add the new logic here.
46
- * Group the `all_chunks` list into fixed-size batches based on `TTS_BATCH_SIZE` from the config.
47
- * Use the existing `ThreadPoolExecutor` to submit these new **batches** to the new `process_batch` worker function.
48
-
49
- ---
50
-
51
- ## 3. Files to Modify: `config/config.py` and `chatterbox_gui.py`
52
-
53
- * **Purpose:** To provide user control over the batch size for performance tuning.
54
- * **Changes Needed:**
55
- * **In `config/config.py`:**
56
- * Add a new configuration variable: `TTS_BATCH_SIZE = 16` (or another sensible default).
57
- * **In `chatterbox_gui.py`:**
58
- * On the "Config" tab, add a new `QSpinBox` (numeric input field) that is linked to the `TTS_BATCH_SIZE` variable. This will allow the user to change the batch size without editing code.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
HF_Deploy/.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.wav filter=lfs diff=lfs merge=lfs -text
HF_Deploy/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ __pycache__/
2
+ *.pyc
HF_Deploy/README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: ChatterboxTTS DNXS Spokenword
3
+ emoji: 🌖
4
+ colorFrom: blue
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 5.39.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ short_description: 'ChatterboxTTS Gradio interface for custom workflow. '
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
HF_Deploy/Text_Input/Goliath/test1.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ My dear fellow, I spent a considerable length of time in the Peninsula. I was with a British rifle brigade when I met Sir Arthur Wellesley. And I was a prisoner of the French at Salamanca - 18 12 I think it was. I always find it’s best to see both sides of both sides, if you see what I mean.
2
+
3
+ I didn’t really think you approved of war sir, said Benton sadly.
4
+
5
+ The Doctor turned his attention back to the twisting country lane. He sighed as he changed gear for another sharp corner. Sometimes it’s inevitable, he noted with genuine sadness. I’m a man of peace, but I seem to spend much of my time caught up in conflict. The central paradox of my life, perhaps.
6
+
7
+ Benton leant back in the seat. What’s the central paradox of mine? he asked, fascinated.
HF_Deploy/Text_Input/README.md ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Text Input Directory
2
+
3
+ Place your book text files here for audiobook generation.
4
+
5
+ ## Directory Structure
6
+ Create a subdirectory for each book:
7
+ ```
8
+ Text_Input/
9
+ ├── Book Name 1/
10
+ │ ├── book.txt # Main text file
11
+ │ ├── cover.jpg # Book cover image (optional)
12
+ │ └── book.nfo # Metadata file (optional)
13
+ ├── Book Name 2/
14
+ │ ├── another_book.txt
15
+ │ └── cover.png
16
+ └── ...
17
+ ```
18
+
19
+ ## Text File Requirements
20
+ - **Format**: Plain text (.txt) files
21
+ - **Encoding**: UTF-8
22
+ - **Content**: Clean text without excessive formatting
23
+ - **Structure**: Use paragraph breaks for natural speech flow
24
+
25
+ ## Optional Files
26
+ - **cover.jpg/png**: Book cover image for M4B metadata
27
+ - **book.nfo**: XML metadata file with book information (title, author, etc.)
28
+
29
+ ## Text Preparation Tips
30
+ - Remove table of contents, page numbers, headers/footers
31
+ - Keep chapter headings (e.g., "Chapter 1")
32
+ - Use proper punctuation for natural speech
33
+ - Remove excessive line breaks or formatting
34
+ - Ensure UTF-8 encoding for special characters
35
+
36
+ ## Processing
37
+ 1. Add your book directory to Text_Input/
38
+ 2. Run the main program and select your book
39
+ 3. The system will chunk the text and generate JSON metadata
40
+ 4. Use the generated chunks for TTS audiobook creation
HF_Deploy/Text_Input/test ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ She stood alone in the hallway. The lights flickered overhead. "I don't like this," she whispered. "Too quiet. Too cold."
2
+
3
+
4
+
5
+ ***
6
+
7
+ Chapter 1
8
+
9
+ A crash echoed from somewhere far off.
10
+ He turned. "Was that you?"
11
+
12
+ "No," she said. "It wasn't me."
13
+
14
+ ---
15
+
16
+ They moved cautiously down the corridor. Every step sounded like thunder. Each shadow seemed to breathe.
17
+
18
+ Chapter 2
19
+
20
+ Something moved behind the curtain.
app.py.20250811-120000.bak → HF_Deploy/app.py RENAMED
File without changes
HF_Deploy/config/__init__.py ADDED
File without changes
HF_Deploy/config/config.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GenTTS Configuration Module
3
+ Central location for all settings, paths, and feature toggles
4
+ """
5
+
6
+ import os
7
+ from pathlib import Path
8
+
9
+ # ============================================================================
10
+ # CORE DIRECTORIES
11
+ # ============================================================================
12
+ TEXT_INPUT_ROOT = Path("Text_Input")
13
+ AUDIOBOOK_ROOT = Path("Audiobook")
14
+ VOICE_SAMPLES_DIR = Path("Voice_Samples")
15
+
16
+ # ============================================================================
17
+ # TEXT PROCESSING SETTINGS
18
+ # ============================================================================
19
+ MAX_CHUNK_WORDS = 28
20
+ MIN_CHUNK_WORDS = 4
21
+
22
+ # ============================================================================
23
+ # WORKER AND PERFORMANCE SETTINGS
24
+ # ============================================================================
25
+ MAX_WORKERS = 2 # Keep at 2 - GPU utilization already high
26
+ TEST_MAX_WORKERS = 6 # For experimentation
27
+ USE_DYNAMIC_WORKERS = False # Toggle for testing
28
+ VRAM_SAFETY_THRESHOLD = 6.5 # GB
29
+
30
+ # ============================================================================
31
+ # AUDIO QUALITY SETTINGS
32
+ # ============================================================================
33
+ ENABLE_MID_DROP_CHECK = False
34
+ ENABLE_ASR = False
35
+ ASR_WORKERS = 4 # Parallel ASR on CPU threads
36
+
37
+ # ============================================================================
38
+ # TTS HUM DETECTION SETTINGS
39
+ # ============================================================================
40
+ ENABLE_HUM_DETECTION = False # Disabled for speed (re-enable if quality issues)
41
+ HUM_FREQ_MIN = 50 # Hz - Lower frequency bound for hum detection
42
+ HUM_FREQ_MAX = 200 # Hz - Upper frequency bound for hum detection
43
+ HUM_ENERGY_THRESHOLD = 0.3 # Ratio of hum energy to total energy (0.1-0.5 range)
44
+ HUM_STEADY_THRESHOLD = 0.6 # Ratio of segments with steady amplitude (0.5-0.8 range)
45
+ HUM_AMPLITUDE_MIN = 0.005 # Minimum RMS for steady hum detection
46
+ HUM_AMPLITUDE_MAX = 0.1 # Maximum RMS for steady hum detection
47
+
48
+ # ============================================================================
49
+ # AUDIO TRIMMING SETTINGS
50
+ # ============================================================================
51
+ ENABLE_AUDIO_TRIMMING = True # Enable automatic audio trimming after TTS
52
+ SPEECH_ENDPOINT_THRESHOLD = 0.005 # RMS threshold to detect end of speech (more aggressive)
53
+ TRIMMING_BUFFER_MS = 50 # Small buffer after detected speech endpoint
54
+
55
+ # ============================================================================
56
+ # SILENCE DURATION SETTINGS (milliseconds)
57
+ # ============================================================================
58
+ SILENCE_CHAPTER_START = 500 # Half second for chapter beginnings
59
+ SILENCE_CHAPTER_END = 800 # Longer pause before new chapter
60
+ SILENCE_SECTION_BREAK = 600 # Section transitions
61
+ SILENCE_PARAGRAPH_END = 300 # Standard paragraph breaks
62
+
63
+ # Punctuation-specific silence settings (milliseconds)
64
+ SILENCE_COMMA = 150 # Brief pause after commas
65
+ SILENCE_SEMICOLON = 250 # Medium pause after semicolons
66
+ SILENCE_COLON = 300 # Pause after colons
67
+ SILENCE_PERIOD = 400 # Sentence end pause
68
+ SILENCE_QUESTION_MARK = 450 # Question pause (slightly longer)
69
+ SILENCE_EXCLAMATION = 400 # Exclamation pause
70
+ SILENCE_DASH = 200 # Em dash pause
71
+ SILENCE_ELLIPSIS = 350 # Ellipsis pause (suspense)
72
+ SILENCE_QUOTE_END = 250 # End of quoted speech
73
+
74
+ # Chunk-level silence settings
75
+ ENABLE_CHUNK_END_SILENCE = True # Add silence to end of every chunk
76
+ CHUNK_END_SILENCE_MS = 200 # Default silence at end of each chunk
77
+
78
+ # Content boundary silence settings (milliseconds)
79
+ SILENCE_PARAGRAPH_FALLBACK = 500 # Original paragraph logic fallback
80
+
81
+ # ============================================================================
82
+ # AUDIO NORMALIZATION SETTINGS
83
+ # ============================================================================
84
+ ENABLE_NORMALIZATION = True # Global ON/OFF switch for normalization
85
+ NORMALIZATION_TYPE = "peak" # Options: "loudness", "peak", "simple", "none"
86
+ TARGET_LUFS = -16 # Target loudness (LUFS) for broadcast standard
87
+ TARGET_PEAK_DB = -1.5 # Target peak level (dB) to prevent clipping
88
+ TARGET_LRA = 11 # Target loudness range for consistency
89
+
90
+ # ============================================================================
91
+ # AUDIO PLAYBACK SPEED SETTINGS
92
+ # ============================================================================
93
+ ATEMPO_SPEED = 0.95 # Playback speed multiplier (0.5-2.0 range, 1.0 = normal speed)
94
+
95
+ # ============================================================================
96
+ # ENVIRONMENT SETUP
97
+ # ============================================================================
98
+ os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
99
+ os.environ["TRANSFORMERS_NO_PROGRESS_BAR"] = "1"
100
+ os.environ["HF_TRANSFORMERS_NO_TQDM"] = "1"
101
+ os.environ["TORCH_HUB_DIR"] = "/tmp/torch_hub_silent"
102
+
103
+ # ============================================================================
104
+ # COLOR CODES FOR TERMINAL OUTPUT
105
+ # ============================================================================
106
+ RESET = "\033[0m"
107
+ BOLD = "\033[1m"
108
+ RED = "\033[91m"
109
+ GREEN = "\033[92m"
110
+ YELLOW = "\033[93m"
111
+ CYAN = "\033[96m"
112
+
113
+ # ============================================================================
114
+ # TTS MODEL PARAMETERS (DEFAULTS)
115
+ # ============================================================================
116
+ DEFAULT_EXAGGERATION = 0.4 # Emotion intensity (0.0-2.0 range)
117
+ DEFAULT_CFG_WEIGHT = 0.5 # Faithfulness to text (0.0-1.0 range)
118
+ DEFAULT_TEMPERATURE = 0.9 # Randomness/creativity (0.0-1.0 range)
119
+
120
+ # ============================================================================
121
+ # VADER SENTIMENT TO TTS PARAMETER MAPPING
122
+ # ============================================================================
123
+ # These settings control how VADER sentiment analysis dynamically adjusts TTS parameters.
124
+ # The formula used is: new_param = base_param + (compound_score * sensitivity)
125
+ # The result is then clamped within the defined MIN/MAX range.
126
+
127
+ # --- Base TTS Parameters (used as the starting point) ---
128
+ # These are the same as the main defaults, but listed here for clarity.
129
+ BASE_EXAGGERATION = DEFAULT_EXAGGERATION # Default: 1.0
130
+ BASE_CFG_WEIGHT = DEFAULT_CFG_WEIGHT # Default: 0.7
131
+ BASE_TEMPERATURE = DEFAULT_TEMPERATURE # Default: 0.7
132
+
133
+ # --- Sensitivity ---
134
+ # How much VADER's compound score affects each parameter.
135
+ # Higher values mean more dramatic changes based on sentiment.
136
+ VADER_EXAGGERATION_SENSITIVITY = 0.5 # e.g., compound of 0.8 -> 1.0 + (0.8 * 0.5) = 1.4
137
+ VADER_CFG_WEIGHT_SENSITIVITY = -0.2 # Negative: more emotional text is less strict
138
+ VADER_TEMPERATURE_SENSITIVITY = 0.15 # More emotional text gets slightly more creative
139
+
140
+ # --- Min/Max Clamps ---
141
+ # Hard limits to prevent extreme, undesirable audio artifacts.
142
+ TTS_PARAM_MIN_EXAGGERATION = 0.1
143
+ TTS_PARAM_MAX_EXAGGERATION = 2.0
144
+ TTS_PARAM_MIN_CFG_WEIGHT = 0.1
145
+ TTS_PARAM_MAX_CFG_WEIGHT = 1.0
146
+
147
+ TTS_PARAM_MIN_TEMPERATURE = 0.1
148
+ TTS_PARAM_MAX_TEMPERATURE = 5.0
149
+
150
+ # ============================================================================
151
+ # BATCH PROCESSING SETTINGS
152
+ # ============================================================================
153
+ BATCH_SIZE = 250 # Larger batches for better speed (monitor VRAM)
154
+ CLEANUP_INTERVAL = 500 # Deep cleanup every N chunks (reduced frequency for speed)
155
+
156
+ # ============================================================================
157
+ # FEATURE TOGGLES
158
+ # ============================================================================
159
+ shutdown_requested = False # Global shutdown flag
HF_Deploy/gradio_main_interface.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ ChatterboxTTS DNXS-Spokneword Gradio Main Interface
4
+ Modular web interface with separate tab modules
5
+ """
6
+
7
+ import gradio as gr
8
+ import sys
9
+ import os
10
+ from pathlib import Path
11
+
12
+ # Add the current directory to Python path for imports
13
+ sys.path.append(str(Path(__file__).parent))
14
+
15
+ # Import tab modules
16
+ try:
17
+ from gradio_tabs.tab1_convert_book import create_convert_book_tab
18
+ TAB1_AVAILABLE = True
19
+ except ImportError as e:
20
+ print(f"⚠️ Tab 1 not available: {e}")
21
+ TAB1_AVAILABLE = False
22
+
23
+ try:
24
+ from gradio_tabs.tab6_settings import create_settings_tab_interface
25
+ TAB6_AVAILABLE = True
26
+ except ImportError as e:
27
+ print(f"⚠️ Tab 6 (Settings) not available: {e}")
28
+ TAB6_AVAILABLE = False
29
+
30
+ def create_placeholder_tab(tab_name, tab_number):
31
+ """Create a placeholder tab for future implementation"""
32
+ with gr.Column():
33
+ gr.Markdown(f"# 🚧 {tab_name}")
34
+ gr.Markdown(f"*Tab {tab_number} - Coming Soon*")
35
+ gr.Markdown("This tab will be implemented in a future update.")
36
+
37
+ gr.Button("Placeholder Button", interactive=False)
38
+
39
+ def create_main_interface():
40
+ """Create the main ChatterboxTTS Gradio interface with all tabs"""
41
+
42
+ with gr.Blocks(
43
+ title="ChatterboxTTS - Complete Interface",
44
+ theme=gr.themes.Soft(),
45
+ css="""
46
+ .gradio-container {
47
+ max-width: 1200px !important;
48
+ }
49
+ """
50
+ ) as demo:
51
+
52
+ # Header
53
+ gr.Markdown("""
54
+ # 🎤 ChatterboxTTS - Complete Web Interface
55
+ *Modular audiobook generation system with advanced TTS capabilities*
56
+ """)
57
+
58
+ # Tab interface
59
+ with gr.Tabs():
60
+ # Tab 1: Convert Book (Working)
61
+ if TAB1_AVAILABLE:
62
+ with gr.Tab("1. Convert Book"):
63
+ create_convert_book_tab()
64
+ else:
65
+ with gr.Tab("1. Convert Book"):
66
+ create_placeholder_tab("Convert Book", 1)
67
+
68
+ # Tab 2-10: Placeholders for now
69
+ with gr.Tab("2. File Management"):
70
+ create_placeholder_tab("File Management", 2)
71
+
72
+ with gr.Tab("3. Voice Analysis"):
73
+ create_placeholder_tab("Voice Analysis", 3)
74
+
75
+ with gr.Tab("4. Batch Processing"):
76
+ create_placeholder_tab("Batch Processing", 4)
77
+
78
+ with gr.Tab("5. Audio Tools"):
79
+ create_placeholder_tab("Audio Tools", 5)
80
+
81
+ # Tab 6: Settings (Working)
82
+ if TAB6_AVAILABLE:
83
+ with gr.Tab("6. Settings"):
84
+ create_settings_tab_interface()
85
+ else:
86
+ with gr.Tab("6. Settings"):
87
+ create_placeholder_tab("Settings", 6)
88
+
89
+ with gr.Tab("7. Chunk Tools"):
90
+ create_placeholder_tab("Chunk Tools", 7)
91
+
92
+ with gr.Tab("8. Voice Training"):
93
+ create_placeholder_tab("Voice Training", 8)
94
+
95
+ with gr.Tab("9. System Monitor"):
96
+ create_placeholder_tab("System Monitor", 9)
97
+
98
+ with gr.Tab("10. About"):
99
+ create_placeholder_tab("About", 10)
100
+
101
+ # Footer
102
+ gr.Markdown("""
103
+ ---
104
+ *ChatterboxTTS Gradio Interface - Modular Design*
105
+ Each tab is a separate module for easy maintenance and development.
106
+ """)
107
+
108
+ return demo
109
+
110
+ def launch_interface():
111
+ """Launch the main interface"""
112
+ print("🚀 ChatterboxTTS - Starting Main Interface")
113
+ print("📊 Tab Status:")
114
+ print(f" Tab 1 (Convert Book): {'✅ Available' if TAB1_AVAILABLE else '❌ Not Available'}")
115
+ print(" Tabs 2-10: 🚧 Placeholder (Coming Soon)")
116
+ print("-" * 50)
117
+
118
+ demo = create_main_interface()
119
+
120
+ # Launch configuration
121
+ launch_kwargs = {
122
+ 'server_name': '0.0.0.0',
123
+ 'server_port': 7860,
124
+ 'show_error': True,
125
+ 'quiet': False
126
+ }
127
+
128
+ # Detect cloud environments
129
+ if os.getenv("RUNPOD_POD_ID"):
130
+ print("☁️ RunPod deployment detected")
131
+ launch_kwargs['share'] = True
132
+ elif os.getenv("COLAB_GPU"):
133
+ print("☁️ Google Colab detected")
134
+ launch_kwargs['share'] = True
135
+ else:
136
+ print("💻 Local deployment")
137
+ launch_kwargs['share'] = False
138
+
139
+ print(f"🌐 Interface will be available at: http://localhost:{launch_kwargs['server_port']}")
140
+
141
+ try:
142
+ demo.launch(**launch_kwargs)
143
+ except Exception as e:
144
+ print(f"❌ Error launching interface: {e}")
145
+ raise
146
+
147
+ if __name__ == "__main__":
148
+ launch_interface()
HF_Deploy/gradio_tabs/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ """
2
+ ChatterboxTTS Gradio Tabs Package
3
+ Modular tab system for the web interface
4
+ """
5
+
6
+ # Make this directory a Python package
7
+ __version__ = "1.0.0"
HF_Deploy/gradio_tabs/tab1_convert_book.py ADDED
@@ -0,0 +1,1173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Gradio Tab 1: Convert Book
4
+ Exact replica of PyQt5 GUI Tab 1 functionality
5
+ """
6
+
7
+ import gradio as gr
8
+ import os
9
+ import sys
10
+ import threading
11
+ import subprocess
12
+ import tempfile
13
+ import json
14
+ import warnings
15
+ import re
16
+ import time
17
+ from pathlib import Path
18
+ from typing import List, Dict, Any, Optional, Tuple
19
+
20
+ # Suppress CUDA deprecation warnings
21
+ warnings.filterwarnings("ignore", category=FutureWarning, message=".*torch.backends.cuda.sdp_kernel.*")
22
+ warnings.filterwarnings("ignore", category=FutureWarning, message=".*sdp_kernel.*")
23
+
24
+ # Import ChatterboxTTS modules and ensure all config variables are available
25
+ # First set defaults, then try to import from config
26
+ DEFAULT_EXAGGERATION = 0.4
27
+ DEFAULT_CFG_WEIGHT = 0.5
28
+ DEFAULT_TEMPERATURE = 0.9
29
+ TTS_PARAM_MIN_EXAGGERATION = 0.0
30
+ TTS_PARAM_MAX_EXAGGERATION = 2.0
31
+ TTS_PARAM_MIN_CFG_WEIGHT = 0.0
32
+ TTS_PARAM_MAX_CFG_WEIGHT = 1.0
33
+ TTS_PARAM_MIN_TEMPERATURE = 0.0
34
+ TTS_PARAM_MAX_TEMPERATURE = 5.0
35
+ ENABLE_REGENERATION_LOOP = True
36
+ MAX_REGENERATION_ATTEMPTS = 3
37
+ QUALITY_THRESHOLD = 0.7
38
+ ENABLE_SENTIMENT_SMOOTHING = True
39
+ SENTIMENT_SMOOTHING_WINDOW = 3
40
+ SENTIMENT_SMOOTHING_METHOD = "rolling"
41
+ ENABLE_MFCC_VALIDATION = False
42
+ ENABLE_OUTPUT_VALIDATION = False
43
+ SPECTRAL_ANOMALY_THRESHOLD = 0.8
44
+ OUTPUT_VALIDATION_THRESHOLD = 0.85
45
+
46
+ # Try to import config and override defaults if available
47
+ try:
48
+ from config.config import *
49
+ CONFIG_AVAILABLE = True
50
+ print("✅ Config loaded successfully")
51
+ except ImportError:
52
+ print("⚠️ Config not available - using defaults")
53
+ CONFIG_AVAILABLE = False
54
+
55
+ # Import the actual conversion functions from GUI
56
+ try:
57
+ # We need to import the actual conversion logic
58
+ import importlib.util
59
+ gui_spec = importlib.util.spec_from_file_location("chatterbox_gui", "chatterbox_gui.py")
60
+ gui_module = importlib.util.module_from_spec(gui_spec)
61
+ # We'll access the GUI's conversion methods
62
+ GUI_AVAILABLE = True
63
+ except Exception as e:
64
+ print(f"⚠️ GUI module not available: {e}")
65
+ GUI_AVAILABLE = False
66
+
67
+ # Global state for conversion with enhanced stats
68
+ conversion_state = {
69
+ 'running': False,
70
+ 'progress': 0,
71
+ 'status': 'Ready',
72
+ 'thread': None,
73
+ 'realtime_factor': '--',
74
+ 'vram_usage': '-- GB',
75
+ 'current_chunk': '--',
76
+ 'eta': '--',
77
+ 'elapsed': '--'
78
+ }
79
+
80
+ def parse_progress_stats(output_line):
81
+ """Parse progress statistics from TTS engine output"""
82
+ # Look for progress pattern: "🌀 Chunk 2/13 | ⏱ Elapsed: 0:01:31 | ETA: 0:09:54 | Remaining: 0:08:23 | Realtime: 0.11x | VRAM: 3.3GB"
83
+ progress_pattern = r'🌀 Chunk (\d+)/(\d+).*?Realtime: ([\d.]+)x.*?VRAM: ([\d.]+)GB'
84
+ match = re.search(progress_pattern, output_line)
85
+
86
+ if match:
87
+ current_chunk = int(match.group(1))
88
+ total_chunks = int(match.group(2))
89
+ realtime_factor = f"{match.group(3)}x"
90
+ vram_usage = f"{match.group(4)} GB"
91
+
92
+ # Update global state
93
+ conversion_state['current_chunk'] = f"{current_chunk}/{total_chunks}"
94
+ conversion_state['realtime_factor'] = realtime_factor
95
+ conversion_state['vram_usage'] = vram_usage
96
+ conversion_state['progress'] = int((current_chunk / total_chunks) * 100) if total_chunks > 0 else 0
97
+
98
+ print(f"📊 Stats Updated: Chunk {current_chunk}/{total_chunks}, {realtime_factor}, {vram_usage}")
99
+ return True
100
+ else:
101
+ # Try alternative patterns in case the format is different
102
+ alt_pattern = r'Chunk (\d+)/(\d+).*?Realtime: ([\d.]+)x.*?VRAM: ([\d.]+)GB'
103
+ alt_match = re.search(alt_pattern, output_line)
104
+ if alt_match:
105
+ current_chunk = int(alt_match.group(1))
106
+ total_chunks = int(alt_match.group(2))
107
+ realtime_factor = f"{alt_match.group(3)}x"
108
+ vram_usage = f"{alt_match.group(4)} GB"
109
+
110
+ conversion_state['current_chunk'] = f"{current_chunk}/{total_chunks}"
111
+ conversion_state['realtime_factor'] = realtime_factor
112
+ conversion_state['vram_usage'] = vram_usage
113
+ conversion_state['progress'] = int((current_chunk / total_chunks) * 100) if total_chunks > 0 else 0
114
+
115
+ print(f"📊 Stats Updated: Chunk {current_chunk}/{total_chunks}, {realtime_factor}, {vram_usage}")
116
+ return True
117
+
118
+ return False
119
+
120
+ def get_progress_stats():
121
+ """Get current progress statistics for UI update"""
122
+ return (
123
+ conversion_state['realtime_factor'],
124
+ conversion_state['vram_usage'],
125
+ conversion_state['current_chunk'],
126
+ conversion_state['progress']
127
+ )
128
+
129
+ def get_book_folders():
130
+ """Get available book folders from Text_Input directory"""
131
+ text_input_dir = Path("Text_Input")
132
+ if not text_input_dir.exists():
133
+ return []
134
+
135
+ folders = []
136
+ for item in text_input_dir.iterdir():
137
+ if item.is_dir():
138
+ folders.append(item.name) # Show only folder name, not full path
139
+
140
+ return sorted(folders)
141
+
142
+ def get_text_files_in_folder(folder_name):
143
+ """Get text files in selected book folder"""
144
+ if not folder_name:
145
+ return []
146
+
147
+ # Build full path from folder name
148
+ folder = Path("Text_Input") / folder_name
149
+ if not folder.exists():
150
+ return []
151
+
152
+ text_files = []
153
+ for file in folder.glob("*.txt"):
154
+ text_files.append(file.name)
155
+
156
+ return sorted(text_files)
157
+
158
+ def get_voice_samples():
159
+ """Get available voice samples from Voice_Samples directory"""
160
+ voice_dir = Path("Voice_Samples")
161
+ if not voice_dir.exists():
162
+ return []
163
+
164
+ voices = []
165
+ for file in voice_dir.glob("*.wav"):
166
+ voices.append(file.name) # Show only filename, not full path
167
+
168
+ return sorted(voices)
169
+
170
+ def find_generated_audiobook(book_folder_path, voice_sample_path):
171
+ """Find the generated audiobook files"""
172
+ try:
173
+ book_folder = Path(book_folder_path)
174
+ voice_file = Path(voice_sample_path)
175
+ voice_name = voice_file.stem
176
+
177
+ # Look in Output/ directory first (final audiobooks)
178
+ output_dir = Path("Output")
179
+ if output_dir.exists():
180
+ # Look for M4B files with voice name
181
+ for m4b_file in output_dir.glob(f"*[{voice_name}]*.m4b"):
182
+ if m4b_file.exists():
183
+ return str(m4b_file), "M4B audiobook"
184
+
185
+ # Look for WAV files with voice name
186
+ for wav_file in output_dir.glob(f"*[{voice_name}]*.wav"):
187
+ if wav_file.exists():
188
+ return str(wav_file), "WAV audiobook"
189
+
190
+ # Look in Audiobook/ directory (processing output)
191
+ audiobook_dir = Path("Audiobook") / book_folder.name
192
+ if audiobook_dir.exists():
193
+ # Look for M4B files
194
+ for m4b_file in audiobook_dir.glob(f"*[{voice_name}]*.m4b"):
195
+ if m4b_file.exists():
196
+ return str(m4b_file), "M4B audiobook"
197
+
198
+ # Look for WAV files
199
+ for wav_file in audiobook_dir.glob(f"*[{voice_name}]*.wav"):
200
+ if wav_file.exists():
201
+ return str(wav_file), "WAV audiobook"
202
+
203
+ # Look for combined files
204
+ for combined_file in audiobook_dir.glob("*_combined.*"):
205
+ if combined_file.suffix in ['.wav', '.m4b', '.mp3']:
206
+ return str(combined_file), f"{combined_file.suffix.upper()[1:]} combined audiobook"
207
+
208
+ return None, "No audiobook found"
209
+
210
+ except Exception as e:
211
+ print(f"Error finding audiobook: {e}")
212
+ return None, f"Error: {str(e)}"
213
+
214
+ def run_book_conversion(book_path, text_file_path, voice_path, tts_params, quality_params, config_params):
215
+ """Run the actual book conversion - Direct call to TTS engine with progress monitoring"""
216
+ try:
217
+ # Import the real TTS engine function directly (avoid interface.py)
218
+ from modules.tts_engine import process_book_folder
219
+
220
+ # Extract enable_asr from tts_params (matching GUI exactly)
221
+ enable_asr = tts_params.get('enable_asr', False)
222
+
223
+ print(f"🚀 Starting book conversion with GUI parameters")
224
+ print(f"📖 Book: {book_path}")
225
+ print(f"📄 Text file: {text_file_path}")
226
+ print(f"🎤 Voice: {voice_path}")
227
+ print(f"🎛️ TTS Params: {tts_params}")
228
+ print(f"🔬 Quality Params: {quality_params}")
229
+ print(f"⚙️ Config Params: {config_params}")
230
+
231
+ # Set up progress callback function
232
+ def progress_callback(current_chunk, total_chunks, realtime_factor, vram_usage):
233
+ """Callback function to update progress from TTS engine"""
234
+ conversion_state['current_chunk'] = f"{current_chunk}/{total_chunks}"
235
+ conversion_state['realtime_factor'] = f"{realtime_factor}x"
236
+ conversion_state['vram_usage'] = f"{vram_usage} GB"
237
+ conversion_state['progress'] = int((current_chunk / total_chunks) * 100) if total_chunks > 0 else 0
238
+ print(f"📊 Progress: {current_chunk}/{total_chunks} ({conversion_state['progress']}%) - {realtime_factor}x - {vram_usage}GB")
239
+
240
+ # Add progress callback to config params
241
+ config_params['progress_callback'] = progress_callback
242
+
243
+ # Convert string paths to Path objects (required by TTS engine)
244
+ book_dir_path = Path(book_path)
245
+ voice_path_obj = Path(voice_path)
246
+
247
+ # Auto-detect device with fallback to CPU
248
+ import torch
249
+ if torch.cuda.is_available():
250
+ device = "cuda"
251
+ print("✅ Using CUDA GPU for processing")
252
+ else:
253
+ device = "cpu"
254
+ print("💻 Using CPU for processing (no GPU available)")
255
+
256
+ # Direct call to TTS engine (function only accepts: book_dir, voice_path, tts_params, device, skip_cleanup)
257
+ result = process_book_folder(
258
+ book_dir=book_dir_path,
259
+ voice_path=voice_path_obj,
260
+ tts_params=tts_params,
261
+ device=device,
262
+ skip_cleanup=False
263
+ )
264
+
265
+ print(f"✅ Conversion completed successfully")
266
+ return {'success': True, 'result': result}
267
+
268
+ except Exception as e:
269
+ print(f"❌ Conversion failed: {e}")
270
+ import traceback
271
+ traceback.print_exc()
272
+ return {'success': False, 'error': str(e)}
273
+
274
+ def regenerate_m4b_file(selected_m4b, playback_speed):
275
+ """Regenerate M4B file with new playback speed"""
276
+ if not selected_m4b:
277
+ return "❌ Please select an M4B file first", None
278
+
279
+ try:
280
+ print(f"🔄 Regenerating M4B: {selected_m4b} at {playback_speed}x speed")
281
+
282
+ # Import M4B regeneration tools
283
+ from tools.combine_only import apply_playback_speed_to_m4b
284
+
285
+ # Find the M4B file path
286
+ audiobook_root = Path("Audiobook")
287
+ m4b_path = None
288
+
289
+ for book_dir in audiobook_root.iterdir():
290
+ if book_dir.is_dir():
291
+ for m4b_file in book_dir.glob("*.m4b"):
292
+ if m4b_file.name == selected_m4b:
293
+ m4b_path = m4b_file
294
+ break
295
+ if m4b_path:
296
+ break
297
+
298
+ if not m4b_path:
299
+ return "❌ M4B file not found", None
300
+
301
+ # Create new filename with speed suffix
302
+ speed_suffix = f"_speed{playback_speed}x".replace(".", "p")
303
+ new_name = m4b_path.stem + speed_suffix + ".m4b"
304
+ output_path = m4b_path.parent / new_name
305
+
306
+ # Apply speed change
307
+ success = apply_playback_speed_to_m4b(str(m4b_path), str(output_path), playback_speed)
308
+
309
+ if success:
310
+ return f"✅ Regenerated M4B at {playback_speed}x speed: {new_name}", str(output_path)
311
+ else:
312
+ return "❌ Failed to regenerate M4B", None
313
+
314
+ except Exception as e:
315
+ print(f"❌ M4B regeneration failed: {e}")
316
+ return f"❌ Error: {str(e)}", None
317
+
318
+ def create_convert_book_tab():
319
+ """Create Tab 1: Convert Book with all GUI functionality"""
320
+
321
+ with gr.Column():
322
+ gr.Markdown("# 🚀 Convert Book")
323
+ gr.Markdown("*Main TTS conversion functionality - matches GUI Tab 1*")
324
+
325
+ # Main Content Layout
326
+ with gr.Row():
327
+ # Left Column - File Uploads
328
+ with gr.Column(scale=2):
329
+ gr.Markdown("### 📚 Book Selection")
330
+
331
+ # Book text file upload only
332
+ text_file_upload = gr.File(
333
+ label="📚 Upload Book Text File",
334
+ file_types=[".txt"],
335
+ file_count="single",
336
+ interactive=True
337
+ )
338
+
339
+ gr.Markdown("### 🎤 Voice Selection")
340
+
341
+ # Single voice upload with integrated playback
342
+ voice_file_upload = gr.File(
343
+ label="🎤 Upload Voice Sample",
344
+ file_types=[".wav", ".mp3", ".m4a"],
345
+ file_count="single",
346
+ interactive=True
347
+ )
348
+
349
+ # Voice sample player (becomes active after upload)
350
+ voice_audio = gr.Audio(
351
+ label="Voice Sample Preview",
352
+ interactive=False,
353
+ show_download_button=False,
354
+ visible=False
355
+ )
356
+
357
+ # Right Column - All Settings
358
+ with gr.Column(scale=1):
359
+ gr.Markdown("### ⚙️ Quick Settings")
360
+
361
+ # VADER and ASR
362
+ vader_enabled = gr.Checkbox(
363
+ label="Use VADER sentiment analysis",
364
+ value=True,
365
+ info="Adjust TTS params per chunk based on emotion"
366
+ )
367
+
368
+ # ASR System with intelligent model selection
369
+ with gr.Row():
370
+ asr_enabled = gr.Checkbox(
371
+ label="🎤 Enable ASR validation",
372
+ value=False,
373
+ info="Smart quality control with automatic model selection"
374
+ )
375
+
376
+ # ASR Configuration (initially hidden)
377
+ with gr.Column(visible=False) as asr_config_group:
378
+ gr.Markdown("#### 🔍 ASR Configuration")
379
+
380
+ # System analysis display
381
+ system_analysis = gr.Textbox(
382
+ label="System Analysis",
383
+ value="Click 'Analyze System' to detect capabilities",
384
+ lines=3,
385
+ interactive=False
386
+ )
387
+
388
+ analyze_system_btn = gr.Button(
389
+ "🔍 Analyze System",
390
+ size="sm",
391
+ variant="secondary"
392
+ )
393
+
394
+ # ASR Level Selection
395
+ asr_level = gr.Radio(
396
+ label="ASR Quality Level",
397
+ choices=[
398
+ ("🟢 SAFE - Fast processing, basic accuracy", "safe"),
399
+ ("🟡 MODERATE - Balanced speed/accuracy (recommended)", "moderate"),
400
+ ("🔴 INSANE - Best accuracy, may stress system", "insane")
401
+ ],
402
+ value="moderate",
403
+ info="Automatically selects best models for your system"
404
+ )
405
+
406
+ # Selected models display
407
+ selected_models = gr.Textbox(
408
+ label="Selected ASR Models",
409
+ value="Select level to see model configuration",
410
+ lines=2,
411
+ interactive=False
412
+ )
413
+
414
+ # Batch processing
415
+ add_to_batch = gr.Checkbox(
416
+ label="📦 Add to batch queue",
417
+ value=False,
418
+ info="Queue for batch processing"
419
+ )
420
+
421
+ gr.Markdown("### 🔄 Regeneration Settings")
422
+
423
+ regeneration_enabled = gr.Checkbox(
424
+ label="Enable automatic chunk regeneration",
425
+ value=ENABLE_REGENERATION_LOOP,
426
+ info="Retry failed chunks automatically"
427
+ )
428
+
429
+ max_attempts = gr.Slider(
430
+ label="Max Attempts",
431
+ minimum=1, maximum=10, step=1,
432
+ value=MAX_REGENERATION_ATTEMPTS
433
+ )
434
+
435
+ quality_threshold = gr.Slider(
436
+ label="Quality Threshold",
437
+ minimum=0.1, maximum=1.0, step=0.05,
438
+ value=QUALITY_THRESHOLD
439
+ )
440
+
441
+ gr.Markdown("### 📊 Sentiment Smoothing")
442
+
443
+ sentiment_smoothing = gr.Checkbox(
444
+ label="Enable sentiment smoothing",
445
+ value=ENABLE_SENTIMENT_SMOOTHING,
446
+ info="Smooth emotional transitions"
447
+ )
448
+
449
+ smoothing_window = gr.Slider(
450
+ label="Window Size",
451
+ minimum=1, maximum=10, step=1,
452
+ value=SENTIMENT_SMOOTHING_WINDOW
453
+ )
454
+
455
+ smoothing_method = gr.Dropdown(
456
+ label="Smoothing Method",
457
+ choices=["rolling", "exp_decay"],
458
+ value=SENTIMENT_SMOOTHING_METHOD
459
+ )
460
+
461
+ gr.Markdown("### 🔍 Advanced Detection")
462
+
463
+ mfcc_validation = gr.Checkbox(
464
+ label="MFCC spectral analysis",
465
+ value=ENABLE_MFCC_VALIDATION,
466
+ info="Advanced audio quality detection"
467
+ )
468
+
469
+ output_validation = gr.Checkbox(
470
+ label="Output validation",
471
+ value=ENABLE_OUTPUT_VALIDATION,
472
+ info="Quality control clearinghouse for enabled checks"
473
+ )
474
+
475
+ spectral_threshold = gr.Slider(
476
+ label="Spectral Threshold",
477
+ minimum=0.1, maximum=1.0, step=0.05,
478
+ value=SPECTRAL_ANOMALY_THRESHOLD
479
+ )
480
+
481
+ output_threshold = gr.Slider(
482
+ label="Output Threshold",
483
+ minimum=0.1, maximum=1.0, step=0.05,
484
+ value=OUTPUT_VALIDATION_THRESHOLD
485
+ )
486
+
487
+
488
+ # TTS Parameters
489
+ with gr.Row():
490
+ with gr.Column():
491
+ gr.Markdown("### 🎛️ TTS Parameters")
492
+
493
+ exaggeration = gr.Slider(
494
+ label="Exaggeration",
495
+ minimum=TTS_PARAM_MIN_EXAGGERATION,
496
+ maximum=TTS_PARAM_MAX_EXAGGERATION,
497
+ step=0.1,
498
+ value=DEFAULT_EXAGGERATION,
499
+ info="Emotional intensity"
500
+ )
501
+
502
+ cfg_weight = gr.Slider(
503
+ label="CFG Weight",
504
+ minimum=TTS_PARAM_MIN_CFG_WEIGHT,
505
+ maximum=TTS_PARAM_MAX_CFG_WEIGHT,
506
+ step=0.1,
507
+ value=DEFAULT_CFG_WEIGHT,
508
+ info="Text faithfulness"
509
+ )
510
+
511
+ temperature = gr.Slider(
512
+ label="Temperature",
513
+ minimum=TTS_PARAM_MIN_TEMPERATURE,
514
+ maximum=TTS_PARAM_MAX_TEMPERATURE,
515
+ step=0.1,
516
+ value=DEFAULT_TEMPERATURE,
517
+ info="Creativity/randomness"
518
+ )
519
+
520
+ with gr.Column():
521
+ gr.Markdown("### ⚡ Advanced Sampling")
522
+
523
+ min_p = gr.Slider(
524
+ label="Min-P",
525
+ minimum=0.0, maximum=0.5, step=0.01,
526
+ value=0.05,
527
+ info="Minimum probability threshold"
528
+ )
529
+
530
+ top_p = gr.Slider(
531
+ label="Top-P",
532
+ minimum=0.5, maximum=1.0, step=0.1,
533
+ value=1.0,
534
+ info="Nucleus sampling"
535
+ )
536
+
537
+ repetition_penalty = gr.Slider(
538
+ label="Repetition Penalty",
539
+ minimum=1.0, maximum=3.0, step=0.1,
540
+ value=2.0,
541
+ info="Reduce repetition"
542
+ )
543
+
544
+ gr.Markdown("### ⚙️ Performance Settings")
545
+
546
+ max_workers = gr.Number(
547
+ label="Max Workers",
548
+ minimum=1, maximum=8, step=1,
549
+ value=2,
550
+ info="⚠️ Only increase above 2 if CPU/GPU utilization < 70%"
551
+ )
552
+
553
+ # Action Buttons and Status
554
+ with gr.Row():
555
+ with gr.Column(scale=2):
556
+ convert_btn = gr.Button(
557
+ "🚀 Start Conversion",
558
+ variant="primary",
559
+ size="lg",
560
+ interactive=True
561
+ )
562
+
563
+ # Status Display
564
+ status_display = gr.Textbox(
565
+ label="Status",
566
+ value="⏸ Ready",
567
+ interactive=False,
568
+ lines=1
569
+ )
570
+
571
+ progress_display = gr.Number(
572
+ label="Progress %",
573
+ value=0,
574
+ interactive=False,
575
+ precision=0
576
+ )
577
+
578
+ with gr.Column(scale=1):
579
+ gr.Markdown("### 📊 Processing Stats")
580
+
581
+ realtime_factor = gr.Textbox(
582
+ label="Realtime Factor",
583
+ value="--",
584
+ interactive=False
585
+ )
586
+
587
+ vram_usage = gr.Textbox(
588
+ label="VRAM Usage",
589
+ value="-- GB",
590
+ interactive=False
591
+ )
592
+
593
+ current_chunk = gr.Textbox(
594
+ label="Current Chunk",
595
+ value="--",
596
+ interactive=False
597
+ )
598
+
599
+ # Regenerate M4B Section (moved above audiobook player)
600
+ with gr.Row():
601
+ with gr.Column():
602
+ gr.Markdown("### 🔄 Regenerate M4B")
603
+
604
+ with gr.Row():
605
+ with gr.Column(scale=2):
606
+ m4b_file_selector = gr.Dropdown(
607
+ label="Select M4B File to Regenerate",
608
+ choices=[],
609
+ value=None,
610
+ interactive=True,
611
+ info="Choose from generated audiobook files"
612
+ )
613
+
614
+ with gr.Column(scale=1):
615
+ playback_speed = gr.Slider(
616
+ label="Playback Speed",
617
+ minimum=0.5,
618
+ maximum=2.0,
619
+ step=0.1,
620
+ value=1.0,
621
+ info="Speed adjustment for regeneration"
622
+ )
623
+
624
+ regenerate_m4b_btn = gr.Button(
625
+ "🔄 Regenerate M4B",
626
+ variant="secondary",
627
+ size="lg"
628
+ )
629
+
630
+ # Generated Audiobook Player (simplified, play-only)
631
+ with gr.Row():
632
+ with gr.Column():
633
+ gr.Markdown("### 🎧 Generated Audiobook Player")
634
+
635
+ # Audiobook file selector dropdown
636
+ audiobook_selector = gr.Dropdown(
637
+ label="Select Audiobook",
638
+ choices=[],
639
+ value=None,
640
+ interactive=True,
641
+ info="Choose from session audiobooks"
642
+ )
643
+
644
+ # Main audio player - play only, no upload
645
+ audio_player = gr.Audio(
646
+ label="Audiobook Player",
647
+ value=None,
648
+ interactive=False,
649
+ show_download_button=True,
650
+ show_share_button=False,
651
+ waveform_options=gr.WaveformOptions(
652
+ show_controls=True,
653
+ show_recording_waveform=False,
654
+ skip_length=10
655
+ )
656
+ )
657
+
658
+ # Event Handlers
659
+ def handle_voice_upload(voice_file):
660
+ """Handle voice file upload and show player"""
661
+ if voice_file is None:
662
+ return gr.update(value=None, visible=False)
663
+
664
+ # Show the voice player with uploaded file
665
+ return gr.update(value=voice_file, visible=True)
666
+
667
+ def get_session_audiobooks():
668
+ """Get list of M4B files from current session, sorted by creation time (newest first)"""
669
+ audiobooks = []
670
+
671
+ # Look in Audiobook directory for M4B files
672
+ audiobook_root = Path("Audiobook")
673
+ if audiobook_root.exists():
674
+ for book_dir in audiobook_root.iterdir():
675
+ if book_dir.is_dir():
676
+ # Look for M4B files in book directory
677
+ for m4b_file in book_dir.glob("*.m4b"):
678
+ # Get creation time for sorting
679
+ creation_time = m4b_file.stat().st_mtime
680
+ audiobooks.append((str(m4b_file), m4b_file.name, creation_time))
681
+
682
+ # Also check Output directory
683
+ output_root = Path("Output")
684
+ if output_root.exists():
685
+ for m4b_file in output_root.glob("*.m4b"):
686
+ creation_time = m4b_file.stat().st_mtime
687
+ audiobooks.append((str(m4b_file), m4b_file.name, creation_time))
688
+
689
+ # Sort by creation time (newest first)
690
+ audiobooks.sort(key=lambda x: x[2], reverse=True)
691
+
692
+ # Return just path and name (drop creation time)
693
+ return [(ab[0], ab[1]) for ab in audiobooks]
694
+
695
+ def update_audiobook_dropdowns(latest_file=None):
696
+ """Update audiobook dropdowns - after conversion both show latest, after regeneration only playback updates"""
697
+ audiobooks = get_session_audiobooks()
698
+ choices = [ab[1] for ab in audiobooks] # Just filenames for display
699
+
700
+ # Determine what to set as selected
701
+ if latest_file:
702
+ # Use specific file if provided
703
+ selected_file = latest_file
704
+ elif choices:
705
+ # Default to newest file (first in sorted list)
706
+ selected_file = choices[0]
707
+ else:
708
+ selected_file = None
709
+
710
+ return (
711
+ gr.update(choices=choices, value=selected_file), # audiobook_selector (playback)
712
+ gr.update(choices=choices, value=selected_file) # m4b_file_selector (regeneration source)
713
+ )
714
+
715
+ def update_audiobook_dropdowns_after_conversion():
716
+ """Update both dropdowns to show the newest generated file after conversion"""
717
+ return update_audiobook_dropdowns()
718
+
719
+ def update_playback_only(new_file_name):
720
+ """Update only the playback dropdown after regeneration"""
721
+ audiobooks = get_session_audiobooks()
722
+ choices = [ab[1] for ab in audiobooks]
723
+
724
+ return (
725
+ gr.update(choices=choices, value=new_file_name), # audiobook_selector (playback) - new file
726
+ gr.update() # m4b_file_selector (regeneration) - no change
727
+ )
728
+
729
+ def load_selected_audiobook(selected_audiobook):
730
+ """Load selected audiobook into player"""
731
+ if not selected_audiobook:
732
+ return None
733
+
734
+ # Find the full path for the selected audiobook
735
+ audiobooks = get_session_audiobooks()
736
+ for full_path, filename in audiobooks:
737
+ if filename == selected_audiobook:
738
+ return full_path
739
+
740
+ return None
741
+
742
+ def handle_asr_toggle(asr_enabled_val):
743
+ """Show/hide ASR configuration when ASR is toggled"""
744
+ return gr.update(visible=asr_enabled_val)
745
+
746
+ def analyze_system():
747
+ """Analyze system capabilities and return summary"""
748
+ try:
749
+ from modules.system_detector import get_system_profile, print_system_summary, categorize_system
750
+
751
+ profile = get_system_profile()
752
+ categories = categorize_system(profile)
753
+
754
+ summary = f"🖥️ System Profile:\n"
755
+ summary += f"VRAM: {profile['gpu']['total_mb']:,}MB total, {profile['available_vram_after_tts']:,}MB available after TTS ({categories['vram']} class)\n"
756
+ summary += f"RAM: {profile['ram']['total_mb']:,}MB total, {profile['ram']['available_mb']:,}MB available ({categories['ram']} class)\n"
757
+ summary += f"CPU: {profile['cpu_cores']} cores ({categories['cpu']} class)"
758
+
759
+ if not profile['has_gpu']:
760
+ summary += f"\n⚠️ No CUDA GPU detected - ASR will run on CPU only"
761
+
762
+ return summary
763
+
764
+ except Exception as e:
765
+ return f"❌ Error analyzing system: {str(e)}"
766
+
767
+ def update_asr_models(asr_level_val):
768
+ """Update ASR model display based on selected level"""
769
+ try:
770
+ from modules.system_detector import get_system_profile, recommend_asr_models
771
+
772
+ profile = get_system_profile()
773
+ recommendations = recommend_asr_models(profile)
774
+
775
+ if asr_level_val not in recommendations:
776
+ return "❌ Invalid ASR level selected"
777
+
778
+ config = recommendations[asr_level_val]
779
+ primary = config['primary']
780
+ fallback = config['fallback']
781
+
782
+ result = f"Primary: {primary['model']} on {primary['device'].upper()}\n"
783
+ result += f"Fallback: {fallback['model']} on {fallback['device'].upper()}"
784
+
785
+ if asr_level_val == 'insane':
786
+ result += f"\n⚠️ WARNING: INSANE mode may cause memory pressure"
787
+
788
+ return result
789
+
790
+ except Exception as e:
791
+ return f"❌ Error getting models: {str(e)}"
792
+
793
+ def start_conversion(text_file_upload, voice_file_upload,
794
+ vader_val, asr_val, asr_level_val, add_to_batch_val,
795
+ regen_enabled_val, max_attempts_val, quality_thresh_val,
796
+ sentiment_smooth_val, smooth_window_val, smooth_method_val,
797
+ mfcc_val, output_val, spectral_thresh_val, output_thresh_val,
798
+ exag_val, cfg_val, temp_val, min_p_val, top_p_val, rep_penalty_val,
799
+ max_workers_val):
800
+ """Start the actual book conversion - file upload version"""
801
+
802
+ # Validation
803
+ if not text_file_upload:
804
+ return "❌ Please upload a text file", 0, None, None
805
+ if not voice_file_upload:
806
+ return "❌ Please upload a voice sample", 0, None, None
807
+
808
+ # Check if already running
809
+ if conversion_state['running']:
810
+ return "⚠️ Conversion already in progress", conversion_state['progress'], None, None
811
+
812
+ try:
813
+ # Create temporary book structure from uploads
814
+ import tempfile
815
+ import shutil
816
+ from datetime import datetime
817
+
818
+ # Generate unique book name from text file
819
+ text_filename = Path(text_file_upload).name
820
+ book_name = text_filename.replace('.txt', '').replace(' ', '_')
821
+ timestamp = datetime.now().strftime("%H%M%S")
822
+ unique_book_name = f"{book_name}_{timestamp}"
823
+
824
+ # Create directory structure
825
+ text_input_dir = Path("Text_Input")
826
+ text_input_dir.mkdir(exist_ok=True)
827
+
828
+ book_dir = text_input_dir / unique_book_name
829
+ book_dir.mkdir(exist_ok=True)
830
+
831
+ # Copy uploaded files to expected locations
832
+ text_dest = book_dir / f"{unique_book_name}.txt"
833
+ shutil.copy2(text_file_upload, text_dest)
834
+
835
+ voice_samples_dir = Path("Voice_Samples")
836
+ voice_samples_dir.mkdir(exist_ok=True)
837
+
838
+ voice_filename = Path(voice_file_upload).name
839
+ voice_dest = voice_samples_dir / voice_filename
840
+ shutil.copy2(voice_file_upload, voice_dest)
841
+
842
+ print(f"📁 Created book structure: {book_dir}")
843
+ print(f"📄 Text file: {text_dest}")
844
+ print(f"🎤 Voice file: {voice_dest}")
845
+
846
+ except Exception as e:
847
+ return f"❌ Error setting up files: {e}", 0, None, None
848
+
849
+ # Build ASR configuration first
850
+ asr_config = {'enabled': False}
851
+ if asr_val:
852
+ try:
853
+ from modules.system_detector import get_system_profile, recommend_asr_models
854
+ profile = get_system_profile()
855
+ recommendations = recommend_asr_models(profile)
856
+
857
+ if asr_level_val in recommendations:
858
+ selected_config = recommendations[asr_level_val]
859
+ primary = selected_config['primary']
860
+ fallback = selected_config['fallback']
861
+
862
+ asr_config = {
863
+ 'enabled': True,
864
+ 'level': asr_level_val,
865
+ 'primary_model': primary['model'],
866
+ 'primary_device': primary['device'],
867
+ 'fallback_model': fallback['model'],
868
+ 'fallback_device': fallback['device']
869
+ }
870
+ except Exception as e:
871
+ print(f"⚠️ Error configuring ASR: {e}")
872
+ asr_config = {'enabled': False}
873
+
874
+ # Prepare parameters (matching GUI structure exactly)
875
+ tts_params = {
876
+ 'exaggeration': exag_val,
877
+ 'cfg_weight': cfg_val,
878
+ 'temperature': temp_val,
879
+ 'min_p': min_p_val,
880
+ 'top_p': top_p_val,
881
+ 'repetition_penalty': rep_penalty_val,
882
+ 'enable_asr': asr_config.get('enabled', False), # Match GUI pattern
883
+ 'max_workers': int(max_workers_val) # User-defined worker count
884
+ }
885
+
886
+ quality_params = {
887
+ 'regeneration_enabled': regen_enabled_val,
888
+ 'max_attempts': max_attempts_val,
889
+ 'quality_threshold': quality_thresh_val,
890
+ 'sentiment_smoothing': sentiment_smooth_val,
891
+ 'smoothing_window': smooth_window_val,
892
+ 'smoothing_method': smooth_method_val,
893
+ 'mfcc_validation': mfcc_val,
894
+ 'output_validation': output_val,
895
+ 'spectral_threshold': spectral_thresh_val,
896
+ 'output_threshold': output_thresh_val
897
+ }
898
+
899
+ config_params = {
900
+ 'vader_enabled': vader_val,
901
+ 'asr_enabled': asr_val,
902
+ 'asr_config': asr_config,
903
+ 'add_to_batch': add_to_batch_val
904
+ }
905
+
906
+ # Set conversion state
907
+ conversion_state['running'] = True
908
+ conversion_state['progress'] = 0
909
+ conversion_state['status'] = 'Starting conversion...'
910
+ conversion_state['current_book'] = book_dir.name # Track current book
911
+
912
+ try:
913
+ # Run conversion using the modular backend in a separate thread
914
+ import threading
915
+
916
+ def run_conversion_thread():
917
+ try:
918
+ result = run_book_conversion(
919
+ str(book_dir), str(text_dest), str(voice_dest),
920
+ tts_params, quality_params, config_params
921
+ )
922
+
923
+ if result['success']:
924
+ conversion_state['status'] = '🎉 CONVERSION COMPLETE! M4B audiobook ready for playback.'
925
+ conversion_state['progress'] = 100
926
+ conversion_state['auto_refresh_needed'] = True # Flag for auto-refresh
927
+ else:
928
+ conversion_state['status'] = f"❌ Conversion failed: {result.get('error', 'Unknown error')}"
929
+ conversion_state['progress'] = 0
930
+
931
+ except Exception as e:
932
+ conversion_state['status'] = f"❌ Error: {str(e)}"
933
+ conversion_state['progress'] = 0
934
+ finally:
935
+ conversion_state['running'] = False
936
+
937
+ # Start conversion thread
938
+ thread = threading.Thread(target=run_conversion_thread)
939
+ thread.start()
940
+
941
+ # Return immediate response - user will need to refresh to see final results
942
+ return (
943
+ "🚀 Conversion started in background...",
944
+ 5, # Initial progress
945
+ None,
946
+ gr.update(),
947
+ gr.update()
948
+ )
949
+
950
+ except Exception as e:
951
+ conversion_state['status'] = f"❌ Error: {str(e)}"
952
+ return conversion_state['status'], 0, None, gr.update(), gr.update()
953
+ finally:
954
+ conversion_state['running'] = False
955
+
956
+
957
+ # Connect event handlers
958
+
959
+ # ASR event handlers
960
+ asr_enabled.change(
961
+ handle_asr_toggle,
962
+ inputs=[asr_enabled],
963
+ outputs=[asr_config_group]
964
+ )
965
+
966
+ analyze_system_btn.click(
967
+ analyze_system,
968
+ inputs=[],
969
+ outputs=[system_analysis]
970
+ )
971
+
972
+ asr_level.change(
973
+ update_asr_models,
974
+ inputs=[asr_level],
975
+ outputs=[selected_models]
976
+ )
977
+
978
+ # Voice upload handler
979
+ voice_file_upload.change(
980
+ handle_voice_upload,
981
+ inputs=[voice_file_upload],
982
+ outputs=[voice_audio]
983
+ )
984
+
985
+ # Main conversion handler
986
+ convert_btn.click(
987
+ start_conversion,
988
+ inputs=[
989
+ text_file_upload, voice_file_upload,
990
+ vader_enabled, asr_enabled, asr_level, add_to_batch,
991
+ regeneration_enabled, max_attempts, quality_threshold,
992
+ sentiment_smoothing, smoothing_window, smoothing_method,
993
+ mfcc_validation, output_validation, spectral_threshold, output_threshold,
994
+ exaggeration, cfg_weight, temperature, min_p, top_p, repetition_penalty,
995
+ max_workers
996
+ ],
997
+ outputs=[status_display, progress_display, audio_player, audiobook_selector, m4b_file_selector]
998
+ )
999
+
1000
+ # Audiobook selector handler
1001
+ audiobook_selector.change(
1002
+ load_selected_audiobook,
1003
+ inputs=[audiobook_selector],
1004
+ outputs=[audio_player]
1005
+ )
1006
+
1007
+ # M4B regeneration handler
1008
+ def handle_m4b_regeneration(selected_m4b, speed):
1009
+ """Handle M4B regeneration and update player"""
1010
+ status_msg, new_m4b_path = regenerate_m4b_file(selected_m4b, speed)
1011
+
1012
+ if new_m4b_path:
1013
+ # Load the new M4B in the player
1014
+ new_file_name = Path(new_m4b_path).name
1015
+ new_audio = load_selected_audiobook(new_file_name)
1016
+ # Update only playback dropdown, keep regeneration dropdown on source file
1017
+ audiobook_choices, m4b_choices = update_playback_only(new_file_name)
1018
+ return status_msg, new_audio, audiobook_choices, m4b_choices
1019
+ else:
1020
+ return status_msg, None, gr.update(), gr.update()
1021
+
1022
+ regenerate_m4b_btn.click(
1023
+ handle_m4b_regeneration,
1024
+ inputs=[m4b_file_selector, playback_speed],
1025
+ outputs=[status_display, audio_player, audiobook_selector, m4b_file_selector]
1026
+ )
1027
+
1028
+ # Progress monitoring with file-based approach
1029
+ def get_current_stats():
1030
+ """Get current progress statistics by monitoring output files"""
1031
+ try:
1032
+ if conversion_state['running']:
1033
+ # Look for generated audio chunks to estimate progress
1034
+ book_name = conversion_state.get('current_book', 'unknown')
1035
+ audiobook_root = Path("Audiobook") / book_name / "TTS" / "audio_chunks"
1036
+
1037
+ if audiobook_root.exists():
1038
+ chunk_files = list(audiobook_root.glob("chunk_*.wav"))
1039
+ current_chunks = len(chunk_files)
1040
+
1041
+ # Try to estimate total from JSON if available
1042
+ json_path = Path("Text_Input") / f"{book_name}_chunks.json"
1043
+ total_chunks = 0
1044
+ if json_path.exists():
1045
+ import json
1046
+ with open(json_path, 'r') as f:
1047
+ data = json.load(f)
1048
+ total_chunks = len(data)
1049
+
1050
+ if total_chunks > 0:
1051
+ progress = int((current_chunks / total_chunks) * 100)
1052
+ conversion_state['progress'] = progress
1053
+ conversion_state['current_chunk'] = f"{current_chunks}/{total_chunks}"
1054
+
1055
+ return (
1056
+ conversion_state.get('realtime_factor', '--'),
1057
+ conversion_state.get('vram_usage', '-- GB'),
1058
+ f"{current_chunks}/{total_chunks}",
1059
+ progress
1060
+ )
1061
+
1062
+ return (
1063
+ conversion_state.get('realtime_factor', '--'),
1064
+ conversion_state.get('vram_usage', '-- GB'),
1065
+ conversion_state.get('current_chunk', '--'),
1066
+ conversion_state.get('progress', 0)
1067
+ )
1068
+ except Exception as e:
1069
+ print(f"Error getting stats: {e}")
1070
+ return "--", "-- GB", "--", conversion_state.get('progress', 0)
1071
+
1072
+ def auto_check_completion():
1073
+ """Automatically check for completion and refresh interface"""
1074
+ # First get current stats
1075
+ stats = get_current_stats()
1076
+
1077
+ # Check if conversion just completed and needs auto-refresh
1078
+ if (not conversion_state['running'] and
1079
+ conversion_state['progress'] == 100 and
1080
+ conversion_state.get('auto_refresh_needed', False)):
1081
+
1082
+ # Clear the auto-refresh flag
1083
+ conversion_state['auto_refresh_needed'] = False
1084
+ print("🎉 Auto-detected completion! Refreshing interface...")
1085
+
1086
+ # Get completion results
1087
+ status, progress, audio, audiobook_choices, m4b_choices = get_status_and_results()
1088
+
1089
+ # Return combined stats + completion results
1090
+ return (
1091
+ stats[0], # realtime_factor
1092
+ stats[1], # vram_usage
1093
+ stats[2], # current_chunk
1094
+ 100, # progress (completed)
1095
+ status, # completion status
1096
+ audio, # audio player
1097
+ audiobook_choices, # audiobook dropdown
1098
+ m4b_choices # m4b dropdown
1099
+ )
1100
+ else:
1101
+ # Return stats + current status (no completion)
1102
+ return (
1103
+ stats[0], # realtime_factor
1104
+ stats[1], # vram_usage
1105
+ stats[2], # current_chunk
1106
+ stats[3], # progress
1107
+ conversion_state.get('status', '⏸ Ready'), # current status
1108
+ gr.update(), # no audio update
1109
+ gr.update(), # no audiobook update
1110
+ gr.update() # no m4b update
1111
+ )
1112
+
1113
+ def get_status_and_results():
1114
+ """Get conversion status and results after completion"""
1115
+ if not conversion_state['running'] and conversion_state['progress'] == 100:
1116
+ # Conversion completed, update dropdowns
1117
+ audiobook_choices, m4b_choices = update_audiobook_dropdowns_after_conversion()
1118
+ latest_audiobook = None
1119
+ if audiobook_choices['choices']:
1120
+ latest_audiobook = load_selected_audiobook(audiobook_choices['choices'][0])
1121
+
1122
+ return (
1123
+ conversion_state['status'],
1124
+ conversion_state['progress'],
1125
+ latest_audiobook,
1126
+ audiobook_choices,
1127
+ m4b_choices
1128
+ )
1129
+ else:
1130
+ return (
1131
+ conversion_state['status'],
1132
+ conversion_state['progress'],
1133
+ None,
1134
+ gr.update(),
1135
+ gr.update()
1136
+ )
1137
+
1138
+ # Create refresh buttons
1139
+ with gr.Row():
1140
+ refresh_stats_btn = gr.Button("🔄 Refresh Stats", size="sm", variant="secondary")
1141
+ check_completion_btn = gr.Button("📋 Check Completion", size="sm", variant="secondary")
1142
+
1143
+ # Auto-refresh timer (checks every 5 seconds during conversion)
1144
+ auto_timer = gr.Timer(5.0) # 5 second interval
1145
+
1146
+ refresh_stats_btn.click(
1147
+ auto_check_completion,
1148
+ outputs=[realtime_factor, vram_usage, current_chunk, progress_display, status_display, audio_player, audiobook_selector, m4b_file_selector]
1149
+ )
1150
+
1151
+ check_completion_btn.click(
1152
+ get_status_and_results,
1153
+ outputs=[status_display, progress_display, audio_player, audiobook_selector, m4b_file_selector]
1154
+ )
1155
+
1156
+ # Auto-timer for progress monitoring and completion detection
1157
+ auto_timer.tick(
1158
+ auto_check_completion,
1159
+ outputs=[realtime_factor, vram_usage, current_chunk, progress_display, status_display, audio_player, audiobook_selector, m4b_file_selector]
1160
+ )
1161
+
1162
+ return {
1163
+ 'convert_button': convert_btn,
1164
+ 'status_display': status_display,
1165
+ 'progress': progress_display
1166
+ }
1167
+
1168
+ if __name__ == "__main__":
1169
+ # Test the tab
1170
+ with gr.Blocks() as demo:
1171
+ create_convert_book_tab()
1172
+
1173
+ demo.launch()
HF_Deploy/modules/__init__.py ADDED
File without changes
HF_Deploy/modules/asr_manager.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ASR Manager Module
3
+ Centralized ASR model loading with adaptive GPU/CPU fallback and real-time VRAM monitoring
4
+ """
5
+
6
+ import torch
7
+ import logging
8
+ from pathlib import Path
9
+ from config.config import DEFAULT_ASR_MODEL, ASR_MODEL_VRAM_MB, ASR_MODEL_RAM_MB
10
+
11
+ def get_real_time_vram_status():
12
+ """Get current GPU memory usage in real-time"""
13
+ try:
14
+ if torch.cuda.is_available():
15
+ gpu_count = torch.cuda.device_count()
16
+ if gpu_count > 0:
17
+ # Use first GPU
18
+ total_vram = torch.cuda.get_device_properties(0).total_memory
19
+ allocated_vram = torch.cuda.memory_allocated(0)
20
+ reserved_vram = torch.cuda.memory_reserved(0)
21
+ available_vram = total_vram - allocated_vram
22
+
23
+ return {
24
+ 'total_mb': total_vram // 1024 // 1024,
25
+ 'allocated_mb': allocated_vram // 1024 // 1024,
26
+ 'reserved_mb': reserved_vram // 1024 // 1024,
27
+ 'available_mb': available_vram // 1024 // 1024,
28
+ 'has_gpu': True
29
+ }
30
+ except Exception as e:
31
+ logging.warning(f"Failed to get real-time VRAM status: {e}")
32
+
33
+ return {
34
+ 'total_mb': 0,
35
+ 'allocated_mb': 0,
36
+ 'reserved_mb': 0,
37
+ 'available_mb': 0,
38
+ 'has_gpu': False
39
+ }
40
+
41
+ def calculate_available_vram_for_asr(safety_buffer_mb=500):
42
+ """Calculate VRAM available for ASR with safety buffer"""
43
+ vram_status = get_real_time_vram_status()
44
+
45
+ if not vram_status['has_gpu']:
46
+ return 0
47
+
48
+ # Available VRAM minus safety buffer for stability
49
+ available_with_buffer = max(0, vram_status['available_mb'] - safety_buffer_mb)
50
+
51
+ return available_with_buffer
52
+
53
+ def can_model_fit_gpu(model_name, available_vram_mb):
54
+ """Check if a specific ASR model can fit in available VRAM"""
55
+ required_vram = ASR_MODEL_VRAM_MB.get(model_name, 0)
56
+ return available_vram_mb >= required_vram
57
+
58
+ def try_load_model_with_fallback(model_name, primary_device, fallback_device="cpu"):
59
+ """Try to load model on primary device, fallback to secondary if it fails"""
60
+ import whisper
61
+
62
+ # Convert device names for whisper compatibility
63
+ def convert_device_name(device):
64
+ if device.lower() == "gpu":
65
+ return "cuda"
66
+ return device.lower()
67
+
68
+ primary_device_whisper = convert_device_name(primary_device)
69
+ fallback_device_whisper = convert_device_name(fallback_device)
70
+
71
+ try:
72
+ print(f"🎯 Attempting to load {model_name} on {primary_device.upper()}")
73
+ model = whisper.load_model(model_name, device=primary_device_whisper)
74
+ print(f"✅ Successfully loaded {model_name} on {primary_device.upper()}")
75
+ return model, primary_device
76
+
77
+ except Exception as e:
78
+ print(f"⚠️ {model_name} failed on {primary_device} ({str(e)[:50]}...)")
79
+
80
+ if fallback_device_whisper != primary_device_whisper:
81
+ try:
82
+ print(f"🔄 Trying {model_name} on {fallback_device.upper()}")
83
+ model = whisper.load_model(model_name, device=fallback_device_whisper)
84
+ print(f"✅ Successfully loaded {model_name} on {fallback_device.upper()}")
85
+ return model, fallback_device
86
+
87
+ except Exception as fallback_e:
88
+ print(f"❌ {model_name} also failed on {fallback_device} ({str(fallback_e)[:50]}...)")
89
+
90
+ # Both failed
91
+ raise Exception(f"Model {model_name} failed on both {primary_device} and {fallback_device}")
92
+
93
+ def load_asr_model_adaptive(asr_config=None):
94
+ """
95
+ Adaptive ASR model loading with real-time VRAM checking and intelligent fallback
96
+
97
+ Args:
98
+ asr_config: ASR configuration dict from interfaces (None for GUI fallback)
99
+
100
+ Returns:
101
+ tuple: (asr_model, actual_device_used) or (None, None) if all loading fails
102
+ """
103
+ print(f"🔍 Starting adaptive ASR model loading...")
104
+
105
+ # Get current VRAM status
106
+ vram_status = get_real_time_vram_status()
107
+ available_vram = calculate_available_vram_for_asr()
108
+
109
+ print(f"🖥️ Real-time VRAM status:")
110
+ print(f" Total: {vram_status['total_mb']:,}MB")
111
+ print(f" Allocated: {vram_status['allocated_mb']:,}MB")
112
+ print(f" Available for ASR: {available_vram:,}MB (with 500MB safety buffer)")
113
+
114
+ # Determine what models to try based on config
115
+ if asr_config and asr_config.get('enabled') and 'primary_model' in asr_config:
116
+ # Intelligent selection from CLI/Gradio
117
+ primary_model = asr_config['primary_model']
118
+ primary_device = asr_config['primary_device']
119
+ fallback_model = asr_config['fallback_model']
120
+ fallback_device = asr_config['fallback_device']
121
+
122
+ print(f"🧠 Using intelligent ASR config:")
123
+ print(f" Primary: {primary_model} on {primary_device.upper()}")
124
+ print(f" Fallback: {fallback_model} on {fallback_device.upper()}")
125
+
126
+ # Real-time VRAM check for primary model
127
+ if primary_device.lower() == 'gpu':
128
+ if not vram_status['has_gpu']:
129
+ print(f"⚠️ No GPU available, forcing CPU mode")
130
+ primary_device = 'cpu'
131
+ elif not can_model_fit_gpu(primary_model, available_vram):
132
+ required = ASR_MODEL_VRAM_MB.get(primary_model, 0)
133
+ print(f"⚠️ Insufficient VRAM for {primary_model} (need {required}MB, have {available_vram}MB)")
134
+ print(f"🔄 Switching primary to CPU")
135
+ primary_device = 'cpu'
136
+
137
+ # Try primary model
138
+ try:
139
+ return try_load_model_with_fallback(primary_model, primary_device, primary_device)
140
+ except:
141
+ # Primary failed, try fallback model
142
+ print(f"🔄 Primary model failed, trying fallback configuration...")
143
+
144
+ # Real-time VRAM check for fallback model
145
+ if fallback_device.lower() == 'gpu':
146
+ if not vram_status['has_gpu']:
147
+ print(f"⚠️ No GPU available for fallback, using CPU")
148
+ fallback_device = 'cpu'
149
+ elif not can_model_fit_gpu(fallback_model, available_vram):
150
+ required = ASR_MODEL_VRAM_MB.get(fallback_model, 0)
151
+ print(f"⚠️ Insufficient VRAM for fallback {fallback_model} (need {required}MB, have {available_vram}MB)")
152
+ fallback_device = 'cpu'
153
+
154
+ try:
155
+ return try_load_model_with_fallback(fallback_model, fallback_device, 'cpu')
156
+ except:
157
+ print(f"❌ Both configured models failed!")
158
+
159
+ else:
160
+ # Fallback mode for GUI or missing config
161
+ print(f"🔧 Using fallback mode: {DEFAULT_ASR_MODEL}")
162
+
163
+ # Last resort: try default model with adaptive device selection
164
+ print(f"🆘 Last resort: trying {DEFAULT_ASR_MODEL} with adaptive device selection")
165
+
166
+ # Choose device based on real-time VRAM availability
167
+ if vram_status['has_gpu'] and can_model_fit_gpu(DEFAULT_ASR_MODEL, available_vram):
168
+ device = 'cuda' # Use cuda directly for whisper
169
+ device_display = 'GPU'
170
+ print(f"✅ Using GPU for {DEFAULT_ASR_MODEL}")
171
+ else:
172
+ device = 'cpu'
173
+ device_display = 'CPU'
174
+ print(f"🔄 Using CPU for {DEFAULT_ASR_MODEL}")
175
+
176
+ try:
177
+ import whisper
178
+ model = whisper.load_model(DEFAULT_ASR_MODEL, device=device)
179
+ print(f"✅ Successfully loaded {DEFAULT_ASR_MODEL} on {device_display}")
180
+ return model, device_display.lower()
181
+ except Exception as e:
182
+ print(f"❌ Critical failure: Could not load {DEFAULT_ASR_MODEL} on {device}: {e}")
183
+
184
+ # Ultimate fallback to CPU if GPU failed
185
+ if device == 'cuda':
186
+ try:
187
+ print(f"🆘 Ultimate fallback: {DEFAULT_ASR_MODEL} on CPU")
188
+ model = whisper.load_model(DEFAULT_ASR_MODEL, device='cpu')
189
+ print(f"✅ Successfully loaded {DEFAULT_ASR_MODEL} on CPU")
190
+ return model, 'cpu'
191
+ except Exception as cpu_e:
192
+ print(f"💀 Complete failure: {cpu_e}")
193
+
194
+ return None, None
195
+
196
+ def cleanup_asr_model(asr_model):
197
+ """Clean up ASR model to free memory"""
198
+ if asr_model is not None:
199
+ try:
200
+ del asr_model
201
+ if torch.cuda.is_available():
202
+ torch.cuda.empty_cache()
203
+ print(f"🧹 ASR model cleaned up")
204
+ except Exception as e:
205
+ logging.warning(f"Failed to cleanup ASR model: {e}")
206
+
207
+ def get_asr_memory_info():
208
+ """Get memory information for ASR debugging"""
209
+ vram_status = get_real_time_vram_status()
210
+ available_vram = calculate_available_vram_for_asr()
211
+
212
+ info = {
213
+ 'vram_total_mb': vram_status['total_mb'],
214
+ 'vram_allocated_mb': vram_status['allocated_mb'],
215
+ 'vram_available_for_asr_mb': available_vram,
216
+ 'has_gpu': vram_status['has_gpu']
217
+ }
218
+
219
+ return info
220
+
221
+ if __name__ == "__main__":
222
+ # Test the adaptive loading
223
+ print("Testing ASR Manager...")
224
+ info = get_asr_memory_info()
225
+ print(f"Memory info: {info}")
226
+
227
+ # Test adaptive loading
228
+ model, device = load_asr_model_adaptive()
229
+ if model:
230
+ print(f"Test successful: Model loaded on {device}")
231
+ cleanup_asr_model(model)
232
+ else:
233
+ print("Test failed: No model loaded")
HF_Deploy/modules/audio_processor.py ADDED
@@ -0,0 +1,569 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Audio Processing Module
3
+ Handles audio validation, effects, cleanup, and quality control
4
+ """
5
+
6
+ import numpy as np
7
+ import soundfile as sf
8
+ import logging
9
+ import shutil
10
+ import re
11
+ import time
12
+ from pathlib import Path
13
+ from pydub import AudioSegment, silence
14
+ from config.config import *
15
+
16
+ # ============================================================================
17
+ # AUDIO QUALITY DETECTION
18
+ # ============================================================================
19
+
20
+ def check_audio_health(wav_path):
21
+ """Enhanced audio health checking"""
22
+ data, samplerate = sf.read(str(wav_path))
23
+ if len(data.shape) > 1:
24
+ data = data[:, 0] # mono only
25
+
26
+ clipping = np.mean(np.abs(data) > 0.98)
27
+ silence_ratio = np.mean(np.abs(data) < 1e-4)
28
+ rms = np.sqrt(np.mean(data**2))
29
+ mean_abs = np.mean(np.abs(data))
30
+ flatness = mean_abs / (rms + 1e-8)
31
+
32
+ return {
33
+ "clipping_ratio": round(clipping, 4),
34
+ "silence_ratio": round(silence_ratio, 4),
35
+ "flatness": round(flatness, 4),
36
+ }
37
+
38
+ def detect_tts_hum_artifact(wav_path):
39
+ """
40
+ Detect low-frequency TTS confusion hum using configurable parameters
41
+ """
42
+ if not ENABLE_HUM_DETECTION:
43
+ return False, {}
44
+
45
+ data, sr = sf.read(str(wav_path))
46
+ if data.ndim > 1:
47
+ data = data[:, 0] # Mono
48
+
49
+ # FFT analysis for frequency content
50
+ fft = np.fft.rfft(data)
51
+ freqs = np.fft.rfftfreq(len(data), 1/sr)
52
+
53
+ # Focus on hum frequency range (configurable at top of file)
54
+ hum_mask = (freqs >= HUM_FREQ_MIN) & (freqs <= HUM_FREQ_MAX)
55
+ hum_energy = np.sum(np.abs(fft[hum_mask]))
56
+ total_energy = np.sum(np.abs(fft))
57
+
58
+ # Check for sustained low-level amplitude (steady hum characteristic)
59
+ segment_size = sr // 4 # 250ms segments
60
+ segments = [data[i:i+segment_size] for i in range(0, len(data)-segment_size, segment_size)]
61
+
62
+ steady_segments = 0
63
+ for segment in segments:
64
+ rms = np.sqrt(np.mean(segment**2))
65
+ if HUM_AMPLITUDE_MIN < rms < HUM_AMPLITUDE_MAX:
66
+ steady_segments += 1
67
+
68
+ # Calculate hum indicators using configurable thresholds
69
+ hum_ratio = hum_energy / (total_energy + 1e-10)
70
+ steady_ratio = steady_segments / len(segments) if segments else 0
71
+
72
+ # Detection logic using configurable thresholds
73
+ has_hum = (hum_ratio > HUM_ENERGY_THRESHOLD) and (steady_ratio > HUM_STEADY_THRESHOLD)
74
+
75
+ if has_hum:
76
+ logging.info(f"🔍 TTS hum detected: {wav_path.name}")
77
+ logging.info(f" Frequency range: {HUM_FREQ_MIN}-{HUM_FREQ_MAX}Hz")
78
+ logging.info(f" Hum energy ratio: {hum_ratio:.3f} (threshold: {HUM_ENERGY_THRESHOLD})")
79
+ logging.info(f" Steady segments: {steady_ratio:.3f} (threshold: {HUM_STEADY_THRESHOLD})")
80
+
81
+ return has_hum, {
82
+ "hum_ratio": hum_ratio,
83
+ "steady_ratio": steady_ratio,
84
+ "freq_range": f"{HUM_FREQ_MIN}-{HUM_FREQ_MAX}Hz"
85
+ }
86
+
87
+ def smart_audio_validation(wav_path):
88
+ """Comprehensive audio validation with intelligent responses"""
89
+ # Standard health check
90
+ health = check_audio_health(wav_path)
91
+
92
+ # TTS hum detection (if enabled)
93
+ has_hum, hum_metrics = detect_tts_hum_artifact(wav_path)
94
+
95
+ # Decision matrix
96
+ if health["clipping_ratio"] > 0.05:
97
+ return handle_problematic_chunks(wav_path, "clipping", health)
98
+ elif health["flatness"] > 0.9:
99
+ return handle_problematic_chunks(wav_path, "corrupted", health)
100
+ elif has_hum:
101
+ return handle_problematic_chunks(wav_path, "tts_hum", hum_metrics)
102
+ else:
103
+ return wav_path # Passed all checks
104
+
105
+ def has_mid_energy_drop(wav_tensor, sr, window_ms=250, threshold_ratio=None):
106
+ """Detect mid-chunk energy drops"""
107
+ wav = wav_tensor.squeeze().numpy()
108
+ win_samples = int(sr * window_ms / 1000)
109
+ segments = [wav[i:i+win_samples] for i in range(0, len(wav) - win_samples, win_samples)]
110
+
111
+ rms_vals = [np.sqrt(np.mean(seg**2)) for seg in segments]
112
+ rms_avg = np.mean(rms_vals)
113
+ dynamic_thresh = threshold_ratio or max(0.02, 0.1 if rms_avg < 0.01 else 0.2)
114
+
115
+ drop_sequence = 0
116
+ consecutive_required = 2
117
+
118
+ for i, rms in enumerate(rms_vals):
119
+ if i < 3:
120
+ continue
121
+ if rms < rms_avg * dynamic_thresh:
122
+ drop_sequence += 1
123
+ if drop_sequence >= consecutive_required:
124
+ return True
125
+ else:
126
+ drop_sequence = 0
127
+
128
+ return False
129
+
130
+ # ============================================================================
131
+ # PROBLEMATIC CHUNK HANDLING
132
+ # ============================================================================
133
+
134
+ def handle_problematic_chunks(wav_path, issue_type, metrics):
135
+ """Handle chunks with audio issues - quarantine for review"""
136
+ quarantine_dir = wav_path.parent / "quarantine"
137
+ quarantine_dir.mkdir(exist_ok=True)
138
+
139
+ # Move to quarantine with descriptive name
140
+ quarantine_path = quarantine_dir / f"{wav_path.stem}_{issue_type}.wav"
141
+ shutil.move(str(wav_path), str(quarantine_path))
142
+
143
+ # Log for user review
144
+ logging.warning(f"🚨 Quarantined {issue_type}: {wav_path.name} → {quarantine_path.name}")
145
+ logging.warning(f" Metrics: {metrics}")
146
+
147
+ return quarantine_path
148
+
149
+ def pause_for_chunk_review(quarantine_dir):
150
+ """Pause processing to allow manual chunk review/editing with proper workflow"""
151
+ quarantined_files = list(quarantine_dir.glob("*.wav"))
152
+
153
+ if not quarantined_files:
154
+ return # No quarantined files, continue normally
155
+
156
+ print(f"\n⚠️ {len(quarantined_files)} chunks quarantined in: {quarantine_dir}")
157
+ print("\nQuarantined chunks:")
158
+ for qfile in quarantined_files:
159
+ print(f" 📁 {qfile.name}")
160
+
161
+ print("\n🔧 Options:")
162
+ print("1. Continue processing (use quarantined chunks as-is)")
163
+ print("2. Pause to manually review/edit chunks")
164
+
165
+ while True:
166
+ choice = input("\nEnter choice [1/2]: ").strip()
167
+ if choice in ['1', '2']:
168
+ break
169
+ print("❌ Invalid choice. Please enter 1 or 2.")
170
+
171
+ if choice == "2":
172
+ print(f"\n🛑 Processing paused for manual review.")
173
+ print(f"📂 Quarantined chunks are in: {quarantine_dir}")
174
+ print("\n📝 Instructions:")
175
+ print(" 1. Edit the audio files in the quarantine folder")
176
+ print(" 2. Keep the original filenames (chunk numbering intact)")
177
+ print(" 3. Leave edited files IN the quarantine folder")
178
+ print(" 4. Press Enter below to continue processing")
179
+
180
+ input("\n⏸️ Press Enter when you've finished editing...")
181
+
182
+ # Verify files still exist after user editing
183
+ edited_files = list(quarantine_dir.glob("*.wav"))
184
+ if not edited_files:
185
+ print("⚠️ No files found in quarantine folder after editing!")
186
+ return
187
+
188
+ print(f"✅ Found {len(edited_files)} edited files, continuing...")
189
+
190
+ # Move all chunks back to main audio folder (whether edited or not)
191
+ moved_count = 0
192
+ for qfile in quarantine_dir.glob("*.wav"):
193
+ # Extract original chunk name from quarantine filename - FIXED LINE:
194
+ original_name = re.sub(r'_(clipping|corrupted|tts_hum)$', '', qfile.stem) + ".wav"
195
+ main_path = qfile.parent.parent / original_name
196
+
197
+ try:
198
+ shutil.move(str(qfile), str(main_path))
199
+ moved_count += 1
200
+ print(f"↩️ Restored: {original_name}")
201
+ except Exception as e:
202
+ logging.error(f"❌ Failed to restore {qfile.name}: {e}")
203
+
204
+ print(f"\n✅ Restored {moved_count} chunks to main audio folder")
205
+
206
+ # Clean up empty quarantine directory
207
+ if not any(quarantine_dir.iterdir()):
208
+ quarantine_dir.rmdir()
209
+
210
+ return moved_count
211
+
212
+ # ============================================================================
213
+ # AUDIO EFFECTS AND PROCESSING
214
+ # ============================================================================
215
+
216
+ def detect_end_artifact(wav_path, window_ms=100):
217
+ """Enhanced artifact detection"""
218
+ data, sr = sf.read(str(wav_path))
219
+ if data.ndim > 1:
220
+ data = data[:, 0]
221
+
222
+ win_samples = int(window_ms / 1000 * sr)
223
+ if len(data) < win_samples * 2:
224
+ return False
225
+
226
+ end = data[-win_samples:]
227
+ middle = data[len(data)//2 : len(data)//2 + win_samples]
228
+
229
+ rms_end = np.sqrt(np.mean(end**2))
230
+ rms_mid = np.sqrt(np.mean(middle**2)) + 1e-10
231
+ rms_ratio = rms_end / rms_mid
232
+
233
+ zcr = np.mean(np.diff(np.sign(end)) != 0)
234
+
235
+ fft = np.fft.rfft(end)
236
+ freqs = np.fft.rfftfreq(len(end), 1/sr)
237
+ low_band = fft[freqs < 150]
238
+ low_energy = np.sum(np.abs(low_band)) / (np.sum(np.abs(fft)) + 1e-10)
239
+
240
+ logging.info(f"{GREEN}[DEBUG]{RESET} Artifact metrics - {YELLOW}RMS ratio: {rms_ratio:.3f}{RESET}, "
241
+ f"{GREEN}ZCR: {zcr:.3f}{RESET}, {CYAN}LowEnergy: {low_energy:.3f}{RESET}")
242
+
243
+ return rms_ratio > 0.6 or zcr > 0.2 or low_energy > 0.4
244
+
245
+ def find_end_of_speech(wav_path, sr=16000):
246
+ """Find end of speech using Silero VAD"""
247
+ import torch
248
+ import os
249
+
250
+ # Set environment variables to suppress PyTorch Hub verbosity
251
+ old_vars = {}
252
+ suppress_vars = {
253
+ 'TORCH_HUB_VERBOSE': '0',
254
+ 'PYTHONWARNINGS': 'ignore',
255
+ 'TF_CPP_MIN_LOG_LEVEL': '3'
256
+ }
257
+
258
+ # Save old values and set new ones
259
+ for key, value in suppress_vars.items():
260
+ old_vars[key] = os.environ.get(key)
261
+ os.environ[key] = value
262
+
263
+ # Temporarily disable logging for this operation
264
+ old_level = logging.getLogger().level
265
+ logging.getLogger().setLevel(logging.ERROR)
266
+
267
+ try:
268
+ model, utils = torch.hub.load(
269
+ repo_or_dir='snakers4/silero-vad',
270
+ model='silero_vad',
271
+ force_reload=False,
272
+ verbose=False
273
+ )
274
+ (get_speech_timestamps, _, read_audio, _, _) = utils
275
+
276
+ wav = read_audio(str(wav_path), sampling_rate=sr)
277
+ speech_segments = get_speech_timestamps(wav, model, sampling_rate=sr)
278
+
279
+ if not speech_segments:
280
+ return None
281
+
282
+ last_seg_end = speech_segments[-1]['end']
283
+ return int(last_seg_end * 1000 / sr)
284
+
285
+ finally:
286
+ # Restore everything
287
+ logging.getLogger().setLevel(old_level)
288
+ for key, old_value in old_vars.items():
289
+ if old_value is None:
290
+ os.environ.pop(key, None)
291
+ else:
292
+ os.environ[key] = old_value
293
+
294
+ def fade_out_wav(wav_path, output_path=None, fade_ms=20):
295
+ """Apply fade-out to audio"""
296
+ data, sr = sf.read(str(wav_path))
297
+ if data.ndim > 1:
298
+ data = data[:, 0]
299
+
300
+ fade_samples = int(sr * fade_ms / 1000)
301
+ if len(data) < fade_samples:
302
+ return
303
+
304
+ debug_path = wav_path.parent / f"{wav_path.stem}_pre_fade.wav"
305
+ sf.write(str(debug_path), data, sr)
306
+
307
+ fade_curve = np.linspace(1.0, 0.0, fade_samples)
308
+ data[-fade_samples:] *= fade_curve
309
+
310
+ sf.write(str(output_path or wav_path), data, sr)
311
+
312
+ def apply_smart_fade(wav_path):
313
+ """Apply smart fade with artifact detection"""
314
+ eos_ms = find_end_of_speech(wav_path)
315
+
316
+ if detect_end_artifact(wav_path):
317
+ fade_out_wav(wav_path)
318
+
319
+ def apply_smart_fade_memory(audio_segment):
320
+ """Apply smart fade with artifact detection - in memory version"""
321
+ # For now, apply a gentle fade to all audio to prevent clicks
322
+ # TODO: Add proper artifact detection for memory processing
323
+ return audio_segment.fade_out(50) # 50ms fade out
324
+
325
+ def smart_audio_validation_memory(audio_segment, sample_rate):
326
+ """Enhanced audio validation in memory - returns (audio, is_quarantined)"""
327
+ # Basic validation - can be enhanced with hum detection later
328
+ # For now, just return the audio as-is
329
+ is_quarantined = False
330
+
331
+ # Could add memory-based hum detection here
332
+ # is_quarantined = detect_hum_memory(audio_segment, sample_rate)
333
+
334
+ return audio_segment, is_quarantined
335
+
336
+ def add_contextual_silence_memory(audio_segment, boundary_type):
337
+ """Add appropriate silence based on content boundary type - in memory"""
338
+ from pydub import AudioSegment
339
+ from config.config import (
340
+ SILENCE_CHAPTER_START, SILENCE_CHAPTER_END, SILENCE_SECTION_BREAK, SILENCE_PARAGRAPH_END,
341
+ SILENCE_COMMA, SILENCE_SEMICOLON, SILENCE_COLON, SILENCE_PERIOD, SILENCE_QUESTION_MARK,
342
+ SILENCE_EXCLAMATION, SILENCE_DASH, SILENCE_ELLIPSIS, SILENCE_QUOTE_END
343
+ )
344
+
345
+ silence_durations = {
346
+ # Structural boundaries
347
+ "chapter_start": SILENCE_CHAPTER_START,
348
+ "chapter_end": SILENCE_CHAPTER_END,
349
+ "section_break": SILENCE_SECTION_BREAK,
350
+ "paragraph_end": SILENCE_PARAGRAPH_END,
351
+ # Punctuation boundaries
352
+ "comma": SILENCE_COMMA,
353
+ "semicolon": SILENCE_SEMICOLON,
354
+ "colon": SILENCE_COLON,
355
+ "period": SILENCE_PERIOD,
356
+ "question_mark": SILENCE_QUESTION_MARK,
357
+ "exclamation": SILENCE_EXCLAMATION,
358
+ "dash": SILENCE_DASH,
359
+ "ellipsis": SILENCE_ELLIPSIS,
360
+ "quote_end": SILENCE_QUOTE_END,
361
+ }
362
+
363
+ if boundary_type in silence_durations:
364
+ duration = silence_durations[boundary_type]
365
+ silence_segment = AudioSegment.silent(duration=duration)
366
+ return audio_segment + silence_segment
367
+
368
+ return audio_segment
369
+
370
+ def smart_fade_out(wav_path, silence_thresh_db=-40, min_silence_len=300):
371
+ """Smart fade-out for natural audio endings"""
372
+ audio = AudioSegment.from_wav(wav_path)
373
+ tail_window_ms = 2000
374
+
375
+ if len(audio) < tail_window_ms:
376
+ logging.info(f"⚠️ {YELLOW}Skipping fade: {wav_path.name} too short ({len(audio)}ms < {tail_window_ms}ms){RESET}")
377
+ return
378
+
379
+ tail = audio[-tail_window_ms:]
380
+ silent_ranges = silence.detect_silence(tail, min_silence_len=min_silence_len, silence_thresh=silence_thresh_db)
381
+
382
+ min_tail_energy = max(tail.get_array_of_samples())
383
+ if not silent_ranges or min_tail_energy > audio.max_possible_amplitude * 0.1:
384
+ logging.info(f"✅ {GREEN}No fade needed for {wav_path.name} (no valid trailing silence){RESET}")
385
+ return
386
+
387
+ fade_start_ms = silent_ranges[0][0]
388
+ fade_length_ms = tail_window_ms - fade_start_ms
389
+
390
+ if fade_length_ms < 100:
391
+ logging.info(f"✅ {GREEN}No fade needed for {wav_path.name} (fade too short: {fade_length_ms}ms){RESET}")
392
+ return
393
+
394
+ fade_start_point = silent_ranges[0][0]
395
+ logging.info(f"⚠️ {RED}Fading tail of {wav_path.name} from {fade_start_point}ms to end{RESET}")
396
+ faded = audio[:fade_start_point] + audio[fade_start_point:].fade_out(duration=fade_length_ms)
397
+ faded.export(wav_path, format="wav")
398
+
399
+ # ============================================================================
400
+ # AUDIO TRIMMING
401
+ # ============================================================================
402
+
403
+ def trim_audio_endpoint(audio_segment, threshold=None, buffer_ms=None):
404
+ """
405
+ Trim audio to the detected end of speech using RMS energy analysis.
406
+
407
+ Args:
408
+ audio_segment: pydub AudioSegment object
409
+ threshold: RMS threshold for speech detection (from config if None)
410
+ buffer_ms: Buffer to add after detected endpoint (from config if None)
411
+
412
+ Returns:
413
+ Trimmed AudioSegment
414
+ """
415
+ if threshold is None:
416
+ threshold = SPEECH_ENDPOINT_THRESHOLD
417
+ if buffer_ms is None:
418
+ buffer_ms = TRIMMING_BUFFER_MS
419
+
420
+ # Convert to numpy array for analysis
421
+ samples = np.array(audio_segment.get_array_of_samples())
422
+ if audio_segment.channels == 2:
423
+ samples = samples.reshape((-1, 2)).mean(axis=1)
424
+
425
+ # Normalize samples
426
+ samples = samples.astype(np.float32) / audio_segment.max_possible_amplitude
427
+
428
+ # Calculate RMS in sliding windows (50ms windows)
429
+ window_size = int(0.05 * audio_segment.frame_rate) # 50ms
430
+ rms_values = []
431
+
432
+ for i in range(0, len(samples) - window_size, window_size // 2):
433
+ window = samples[i:i + window_size]
434
+ rms = np.sqrt(np.mean(window ** 2))
435
+ rms_values.append(rms)
436
+
437
+ # Find actual end of speech using energy decay detection
438
+ speech_end_idx = 0 # Default to beginning if no speech found
439
+
440
+ # Look for a significant and sustained drop in energy
441
+ # Scan backwards to find where energy consistently stays above a higher threshold
442
+ strong_speech_threshold = threshold * 3 # 3x threshold for "real" speech
443
+
444
+ for i in range(len(rms_values) - 1, -1, -1):
445
+ if rms_values[i] > strong_speech_threshold:
446
+ # Found strong speech, check if it's sustained
447
+ # Look forward to see if energy drops and stays low
448
+ sustained_speech = True
449
+ windows_ahead = min(10, len(rms_values) - i) # Look ahead up to 10 windows (250ms)
450
+
451
+ # Check if most of the next windows have reasonable speech levels
452
+ speech_count = 0
453
+ for j in range(i, min(i + windows_ahead, len(rms_values))):
454
+ if rms_values[j] > threshold:
455
+ speech_count += 1
456
+
457
+ # If this looks like the end of sustained speech content
458
+ if speech_count >= max(1, windows_ahead * 0.3): # At least 30% speech in next windows
459
+ speech_end_idx = i
460
+ break
461
+
462
+ # If no strong speech found, fall back to simple threshold method but be conservative
463
+ if speech_end_idx == 0:
464
+ for i in range(len(rms_values) - 1, -1, -1):
465
+ if rms_values[i] > threshold * 2: # Use 2x threshold for fallback
466
+ speech_end_idx = i
467
+ break
468
+
469
+ # Convert back to milliseconds and add buffer
470
+ # Convert window index to sample position, then to milliseconds
471
+ sample_position = speech_end_idx * (window_size // 2)
472
+ speech_end_ms = int(sample_position * 1000 / audio_segment.frame_rate)
473
+ trim_point_ms = min(speech_end_ms + buffer_ms, len(audio_segment))
474
+
475
+ return audio_segment[:trim_point_ms]
476
+
477
+ def process_audio_with_trimming_and_silence(audio_segment, boundary_type, enable_trimming=None):
478
+ """
479
+ Complete audio processing: trim to speech endpoint + add punctuation-based silence.
480
+
481
+ Args:
482
+ audio_segment: pydub AudioSegment object
483
+ boundary_type: Boundary type from text processing
484
+ enable_trimming: Whether to trim audio (from config if None)
485
+
486
+ Returns:
487
+ Processed AudioSegment with trimming and appropriate silence
488
+ """
489
+ if enable_trimming is None:
490
+ enable_trimming = ENABLE_AUDIO_TRIMMING
491
+
492
+ processed_audio = audio_segment
493
+
494
+ # Step 1: Trim to speech endpoint if enabled
495
+ if enable_trimming:
496
+ processed_audio = trim_audio_endpoint(processed_audio)
497
+
498
+ # Step 2: Add punctuation-appropriate silence
499
+ processed_audio = add_contextual_silence_memory(processed_audio, boundary_type)
500
+
501
+ return processed_audio
502
+
503
+ # ============================================================================
504
+ # SILENCE AND CONTEXTUAL AUDIO
505
+ # ============================================================================
506
+
507
+ def add_contextual_silence(wav_path, boundary_type):
508
+ """Add appropriate silence based on content boundary type"""
509
+ silence_durations = {
510
+ # Structural boundaries
511
+ "chapter_start": SILENCE_CHAPTER_START,
512
+ "chapter_end": SILENCE_CHAPTER_END,
513
+ "section_break": SILENCE_SECTION_BREAK,
514
+ "paragraph_end": SILENCE_PARAGRAPH_END,
515
+ # Punctuation boundaries
516
+ "comma": SILENCE_COMMA,
517
+ "semicolon": SILENCE_SEMICOLON,
518
+ "colon": SILENCE_COLON,
519
+ "period": SILENCE_PERIOD,
520
+ "question_mark": SILENCE_QUESTION_MARK,
521
+ "exclamation": SILENCE_EXCLAMATION,
522
+ "dash": SILENCE_DASH,
523
+ "ellipsis": SILENCE_ELLIPSIS,
524
+ "quote_end": SILENCE_QUOTE_END,
525
+ }
526
+
527
+ if boundary_type in silence_durations:
528
+ duration = silence_durations[boundary_type]
529
+ audio = AudioSegment.from_wav(wav_path)
530
+ silence_segment = AudioSegment.silent(duration=duration)
531
+ extended_audio = audio + silence_segment
532
+ extended_audio.export(wav_path, format="wav")
533
+
534
+ logging.info(f"🔇 Added {duration}ms silence for {boundary_type}: {wav_path.name}")
535
+
536
+ def add_chunk_end_silence(wav_path):
537
+ """Add configurable silence to end of chunk if enabled"""
538
+ if not ENABLE_CHUNK_END_SILENCE or CHUNK_END_SILENCE_MS <= 0:
539
+ return
540
+
541
+ try:
542
+ audio = AudioSegment.from_wav(wav_path)
543
+ silence_segment = AudioSegment.silent(duration=CHUNK_END_SILENCE_MS)
544
+ audio_with_silence = audio + silence_segment
545
+ audio_with_silence.export(wav_path, format="wav")
546
+ logging.info(f"➕ Added {CHUNK_END_SILENCE_MS}ms end silence to {wav_path.name}")
547
+ except Exception as e:
548
+ logging.warning(f"⚠️ Failed to add end silence to {wav_path.name}: {e}")
549
+
550
+ # ============================================================================
551
+ # AUDIO UTILITY FUNCTIONS
552
+ # ============================================================================
553
+
554
+ def get_wav_duration(wav_path):
555
+ """Get WAV file duration"""
556
+ import wave
557
+ with wave.open(str(wav_path), 'rb') as wf:
558
+ frames = wf.getnframes()
559
+ rate = wf.getframerate()
560
+ return frames / float(rate)
561
+
562
+ def get_chunk_audio_duration(wav_path):
563
+ """Get actual audio duration from WAV file"""
564
+ try:
565
+ data, sr = sf.read(str(wav_path))
566
+ return len(data) / sr
567
+ except:
568
+ # Fallback to wave module
569
+ return get_wav_duration(wav_path)
HF_Deploy/modules/batch_processor.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Batch Processing Module
3
+ Handles multi-book batch processing operations
4
+ """
5
+
6
+ import torch
7
+ from modules.tts_engine import process_book_folder
8
+
9
+ def pipeline_book_processing(book_queue):
10
+ """Process multiple books in sequence"""
11
+ completed_books = []
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+
14
+ for book_info in book_queue:
15
+ book_dir = book_info['book_dir']
16
+ voice_path = book_info['voice_path']
17
+ tts_params = book_info['tts_params']
18
+
19
+ print(f"\n🎯 Processing: {book_dir.name}")
20
+
21
+ try:
22
+ result = process_book_folder(book_dir, voice_path, tts_params, device)
23
+ if result[0]: # Check if final_m4b_path exists
24
+ completed_books.append(book_info)
25
+ print(f"✅ Completed: {book_dir.name}")
26
+ else:
27
+ print(f"❌ Failed: {book_dir.name}")
28
+ except Exception as e:
29
+ print(f"❌ Error processing {book_dir.name}: {e}")
30
+
31
+ return completed_books
HF_Deploy/modules/file_manager.py ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ File Manager Module
3
+ Handles I/O operations, M4B conversion, metadata, and FFmpeg operations
4
+ """
5
+
6
+ import subprocess
7
+ import soundfile as sf
8
+ import os
9
+ import re
10
+ import time
11
+ import logging
12
+ from pathlib import Path
13
+ from config.config import *
14
+
15
+ # ============================================================================
16
+ # VOICE SAMPLE MANAGEMENT
17
+ # ============================================================================
18
+
19
+ def list_voice_samples():
20
+ """List available voice samples"""
21
+ return sorted(VOICE_SAMPLES_DIR.glob("*.wav"))
22
+
23
+ def ensure_voice_sample_compatibility(input_path, output_dir=None):
24
+ """Ensure voice sample is compatible with TTS (24kHz mono)"""
25
+ input_path = str(input_path)
26
+ ext = os.path.splitext(input_path)[1].lower()
27
+ basename = os.path.splitext(os.path.basename(input_path))[0]
28
+ output_dir = output_dir or os.path.dirname(input_path)
29
+ output_path = os.path.join(output_dir, basename + "_ttsready.wav")
30
+
31
+ try:
32
+ info = sf.info(input_path)
33
+ if (ext == '.wav' and info.samplerate == 24000 and info.channels == 1):
34
+ return input_path
35
+ except Exception:
36
+ pass
37
+
38
+ cmd = [
39
+ "ffmpeg", "-y",
40
+ "-i", input_path,
41
+ "-ar", "24000",
42
+ "-ac", "1",
43
+ output_path
44
+ ]
45
+ subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
46
+ return output_path
47
+
48
+ # ============================================================================
49
+ # FFMPEG OPERATIONS
50
+ # ============================================================================
51
+
52
+ def run_ffmpeg(cmd):
53
+ """Run FFmpeg command with error handling"""
54
+ try:
55
+ subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
56
+ except subprocess.CalledProcessError as e:
57
+ logging.info(f"FFmpeg command failed: {' '.join(cmd)}")
58
+ logging.info(f"Error: {e}")
59
+ subprocess.run(cmd)
60
+ raise
61
+
62
+ # ============================================================================
63
+ # M4B CONVERSION WITH NORMALIZATION
64
+ # ============================================================================
65
+
66
+ def convert_to_m4b_with_peak_normalization(wav_path, temp_m4b_path, target_db=-3.0):
67
+ """Convert WAV to M4B with peak normalization"""
68
+ print("🚀 Converting to m4b with peak normalization...")
69
+
70
+ # Build audio filter chain
71
+ audio_filters = [f"loudnorm=I=-16:TP={target_db}:LRA=11"]
72
+ if ATEMPO_SPEED != 1.0:
73
+ audio_filters.append(f"atempo={ATEMPO_SPEED}")
74
+
75
+ cmd = [
76
+ "ffmpeg", "-y",
77
+ "-i", str(wav_path),
78
+ "-af", ",".join(audio_filters),
79
+ "-c:a", "aac",
80
+ str(temp_m4b_path)
81
+ ]
82
+
83
+ start_time = time.time()
84
+ process = subprocess.Popen(cmd, stderr=subprocess.PIPE, text=True)
85
+
86
+ audio_secs = 0.0
87
+ for line in process.stderr:
88
+ match = re.search(r"time=(\d{2}):(\d{2}):(\d{2})\.(\d{2})", line)
89
+ if match:
90
+ h, m, s, ms = map(int, match.groups())
91
+ audio_secs = h * 3600 + m * 60 + s + ms / 100
92
+ elapsed = time.time() - start_time
93
+ factor = audio_secs / elapsed if elapsed > 0 else 0.0
94
+ print(f"📼 FFmpeg (normalizing): {match.group(0)} | {factor:.2f}x realtime", end='\r')
95
+
96
+ process.wait()
97
+ print("\n✅ Conversion with normalization complete.")
98
+
99
+ def convert_to_m4b_with_loudness_normalization(wav_path, temp_m4b_path):
100
+ """Convert WAV to M4B with two-pass loudness normalization"""
101
+ import json
102
+
103
+ print("🚀 Converting to m4b with loudness normalization...")
104
+
105
+ # Step 1: Analyze audio loudness
106
+ print("📊 Analyzing audio loudness...")
107
+ analyze_cmd = [
108
+ "ffmpeg", "-y",
109
+ "-i", str(wav_path),
110
+ "-af", "loudnorm=I=-16:TP=-1.5:LRA=11:print_format=json",
111
+ "-f", "null", "-"
112
+ ]
113
+
114
+ result = subprocess.run(analyze_cmd, capture_output=True, text=True)
115
+
116
+ # Extract loudness measurements from stderr
117
+ loudness_data = None
118
+ for line in result.stderr.split('\n'):
119
+ if line.strip().startswith('{'):
120
+ try:
121
+ loudness_data = json.loads(line.strip())
122
+ break
123
+ except:
124
+ continue
125
+
126
+ if not loudness_data:
127
+ print("⚠️ Could not analyze loudness, falling back to single-pass...")
128
+ return convert_to_m4b_with_peak_normalization(wav_path, temp_m4b_path)
129
+
130
+ # Step 2: Apply normalization with measured values
131
+ print("🔧 Applying normalization...")
132
+
133
+ # Build audio filter chain
134
+ audio_filters = [f"loudnorm=I=-16:TP=-1.5:LRA=11:measured_I={loudness_data['input_i']}:measured_LRA={loudness_data['input_lra']}:measured_TP={loudness_data['input_tp']}:measured_thresh={loudness_data['input_thresh']}:offset={loudness_data['target_offset']}:linear=true:print_format=summary"]
135
+ if ATEMPO_SPEED != 1.0:
136
+ audio_filters.append(f"atempo={ATEMPO_SPEED}")
137
+
138
+ cmd = [
139
+ "ffmpeg", "-y",
140
+ "-i", str(wav_path),
141
+ "-af", ",".join(audio_filters),
142
+ "-c:a", "aac",
143
+ str(temp_m4b_path)
144
+ ]
145
+
146
+ start_time = time.time()
147
+ process = subprocess.Popen(cmd, stderr=subprocess.PIPE, text=True)
148
+
149
+ audio_secs = 0.0
150
+ for line in process.stderr:
151
+ match = re.search(r"time=(\d{2}):(\d{2}):(\d{2})\.(\d{2})", line)
152
+ if match:
153
+ h, m, s, ms = map(int, match.groups())
154
+ audio_secs = h * 3600 + m * 60 + s + ms / 100
155
+ elapsed = time.time() - start_time
156
+ factor = audio_secs / elapsed if elapsed > 0 else 0.0
157
+ print(f"📼 FFmpeg (normalizing): {match.group(0)} | {factor:.2f}x realtime", end='\r')
158
+
159
+ process.wait()
160
+ print("\n✅ Two-pass normalization complete.")
161
+
162
+ def convert_to_m4b_with_simple_normalization(wav_path, temp_m4b_path, target_db=-6.0):
163
+ """Convert WAV to M4B with simple peak normalization"""
164
+ print("🚀 Converting to m4b with simple normalization...")
165
+
166
+ # Build audio filter chain
167
+ audio_filters = [f"volume={target_db}dB"]
168
+ if ATEMPO_SPEED != 1.0:
169
+ audio_filters.append(f"atempo={ATEMPO_SPEED}")
170
+
171
+ cmd = [
172
+ "ffmpeg", "-y",
173
+ "-i", str(wav_path),
174
+ "-af", ",".join(audio_filters),
175
+ "-c:a", "aac",
176
+ str(temp_m4b_path)
177
+ ]
178
+
179
+ start_time = time.time()
180
+ process = subprocess.Popen(cmd, stderr=subprocess.PIPE, text=True)
181
+
182
+ audio_secs = 0.0
183
+ for line in process.stderr:
184
+ match = re.search(r"time=(\d{2}):(\d{2}):(\d{2})\.(\d{2})", line)
185
+ if match:
186
+ h, m, s, ms = map(int, match.groups())
187
+ audio_secs = h * 3600 + m * 60 + s + ms / 100
188
+ elapsed = time.time() - start_time
189
+ factor = audio_secs / elapsed if elapsed > 0 else 0.0
190
+ print(f"📼 FFmpeg (normalizing): {match.group(0)} | {factor:.2f}x realtime", end='\r')
191
+
192
+ process.wait()
193
+ print("\n✅ Simple normalization complete.")
194
+
195
+ def convert_to_m4b(wav_path, temp_m4b_path):
196
+ """Convert WAV to M4B with configurable normalization"""
197
+ if not ENABLE_NORMALIZATION or NORMALIZATION_TYPE == "none":
198
+ # Original function without normalization
199
+ print("🚀 Converting to m4b...")
200
+
201
+ # Build audio filter for atempo if needed
202
+ audio_filter = []
203
+ if ATEMPO_SPEED != 1.0:
204
+ audio_filter = ["-filter:a", f"atempo={ATEMPO_SPEED}"]
205
+
206
+ cmd = [
207
+ "ffmpeg", "-y",
208
+ "-i", str(wav_path)
209
+ ] + audio_filter + [
210
+ "-c:a", "aac",
211
+ str(temp_m4b_path)
212
+ ]
213
+
214
+ elif NORMALIZATION_TYPE == "loudness":
215
+ # EBU R128 loudness normalization (recommended for audiobooks)
216
+ return convert_to_m4b_with_loudness_normalization(wav_path, temp_m4b_path)
217
+
218
+ elif NORMALIZATION_TYPE == "peak":
219
+ # Peak normalization
220
+ return convert_to_m4b_with_peak_normalization(wav_path, temp_m4b_path, TARGET_PEAK_DB)
221
+
222
+ elif NORMALIZATION_TYPE == "simple":
223
+ # Simple volume adjustment
224
+ return convert_to_m4b_with_simple_normalization(wav_path, temp_m4b_path, TARGET_PEAK_DB)
225
+
226
+ else:
227
+ # Fallback to no normalization
228
+ # Build audio filter for atempo if needed
229
+ audio_filter = []
230
+ if ATEMPO_SPEED != 1.0:
231
+ audio_filter = ["-filter:a", f"atempo={ATEMPO_SPEED}"]
232
+
233
+ cmd = [
234
+ "ffmpeg", "-y",
235
+ "-i", str(wav_path)
236
+ ] + audio_filter + [
237
+ "-c:a", "aac",
238
+ str(temp_m4b_path)
239
+ ]
240
+
241
+ # Run the conversion (if not handled by specialized functions above)
242
+ start_time = time.time()
243
+ process = subprocess.Popen(cmd, stderr=subprocess.PIPE, text=True)
244
+
245
+ audio_secs = 0.0
246
+ for line in process.stderr:
247
+ match = re.search(r"time=(\d{2}):(\d{2}):(\d{2})\.(\d{2})", line)
248
+ if match:
249
+ h, m, s, ms = map(int, match.groups())
250
+ audio_secs = h * 3600 + m * 60 + s + ms / 100
251
+ elapsed = time.time() - start_time
252
+ factor = audio_secs / elapsed if elapsed > 0 else 0.0
253
+ print(f"📼 FFmpeg: {match.group(0)} | {factor:.2f}x realtime", end='\r')
254
+
255
+ process.wait()
256
+ print("\n✅ Conversion complete.")
257
+
258
+ def add_metadata_to_m4b(temp_m4b_path, final_m4b_path, cover_path=None, nfo_path=None):
259
+ """Add metadata and cover to M4B"""
260
+ cmd = ["ffmpeg", "-y", "-i", str(temp_m4b_path)]
261
+
262
+ if cover_path and cover_path.exists():
263
+ cmd.extend(["-i", str(cover_path), "-map", "0", "-map", "1", "-c", "copy", "-disposition:v:0", "attached_pic"])
264
+ else:
265
+ cmd.extend(["-map", "0", "-c", "copy"])
266
+
267
+ if nfo_path and nfo_path.exists():
268
+ with open(nfo_path, 'r', encoding='utf-8') as f:
269
+ for line in f:
270
+ if ':' in line:
271
+ key, val = line.strip().split(':', 1)
272
+ cmd.extend(["-metadata", f"{key.strip()}={val.strip()}"])
273
+
274
+ cmd.append(str(final_m4b_path))
275
+ run_ffmpeg(cmd)
276
+ temp_m4b_path.unlink(missing_ok=True)
277
+
278
+ # ============================================================================
279
+ # FILE UTILITIES
280
+ # ============================================================================
281
+
282
+ def chunk_sort_key(f):
283
+ """Extracts the chunk number for natural sorting"""
284
+ m = re.match(r"chunk_(\d+)\.wav", f.name)
285
+ return int(m.group(1)) if m else 0
286
+
287
+ def create_concat_file(chunk_paths, output_path):
288
+ """Create FFmpeg concat file for audio chunks"""
289
+ with open(output_path, 'w') as f:
290
+ for p in chunk_paths:
291
+ # Use absolute path to ensure FFmpeg can find the files
292
+ f.write(f"file '{str(p.resolve())}'\n")
293
+
294
+ logging.info(f"concat.txt written with {len(chunk_paths)} chunks.")
295
+ return output_path
296
+
297
+ def cleanup_temp_files(directory, patterns):
298
+ """Clean up temporary files matching patterns"""
299
+ files_cleaned = 0
300
+ for pattern in patterns:
301
+ for temp_file in directory.glob(pattern):
302
+ temp_file.unlink(missing_ok=True)
303
+ files_cleaned += 1
304
+
305
+ return files_cleaned
306
+
307
+ # ============================================================================
308
+ # DIRECTORY MANAGEMENT
309
+ # ============================================================================
310
+
311
+ def setup_book_directories(book_dir):
312
+ """Set up directory structure for book processing"""
313
+ basename = book_dir.name
314
+ output_root = AUDIOBOOK_ROOT / basename
315
+ tts_dir = output_root / "TTS"
316
+ text_chunks_dir = tts_dir / "text_chunks"
317
+ audio_chunks_dir = tts_dir / "audio_chunks"
318
+
319
+ # Create directories
320
+ for d in [output_root, tts_dir, text_chunks_dir, audio_chunks_dir]:
321
+ d.mkdir(parents=True, exist_ok=True)
322
+
323
+ return output_root, tts_dir, text_chunks_dir, audio_chunks_dir
324
+
325
+ def find_book_files(book_dir):
326
+ """Find text files, cover, and metadata for a book"""
327
+ text_files = sorted(book_dir.glob("*.txt"))
328
+ nfo_file = book_dir / "book.nfo"
329
+ cover_jpg = book_dir / "cover.jpg"
330
+ cover_png = book_dir / "cover.png"
331
+ cover_file = cover_jpg if cover_jpg.exists() else cover_png if cover_png.exists() else None
332
+
333
+ return {
334
+ 'text': text_files[0] if text_files else None,
335
+ 'cover': cover_file,
336
+ 'nfo': nfo_file if nfo_file.exists() else None
337
+ }
338
+
339
+ # ============================================================================
340
+ # AUDIO FILE OPERATIONS
341
+ # ============================================================================
342
+
343
+ def combine_audio_chunks(chunk_paths, output_path):
344
+ """Combine audio chunks into single file using FFmpeg"""
345
+ concat_list_path = output_path.parent / "concat.txt"
346
+ create_concat_file(chunk_paths, concat_list_path)
347
+
348
+ run_ffmpeg([
349
+ "ffmpeg", "-y", "-f", "concat", "-safe", "0",
350
+ "-i", str(concat_list_path.resolve()),
351
+ "-c", "copy", str(output_path.resolve())
352
+ ])
353
+
354
+ return output_path
355
+
356
+ def get_audio_files_in_directory(directory, pattern="chunk_*.wav"):
357
+ """Get sorted list of audio files matching pattern"""
358
+ chunk_paths = sorted([f for f in directory.glob(pattern)
359
+ if re.fullmatch(r'chunk_\d{3,}\.wav', f.name)],
360
+ key=chunk_sort_key)
361
+ return chunk_paths
362
+
363
+ # ============================================================================
364
+ # VALIDATION AND VERIFICATION
365
+ # ============================================================================
366
+
367
+ def verify_audio_file(wav_path):
368
+ """Verify audio file is valid and readable"""
369
+ try:
370
+ info = sf.info(str(wav_path))
371
+ return info.frames > 0 and info.samplerate > 0
372
+ except Exception as e:
373
+ logging.error(f"Invalid audio file {wav_path}: {e}")
374
+ return False
375
+
376
+ def verify_chunk_completeness(audio_chunks_dir, expected_count):
377
+ """Verify all expected chunks exist and are valid"""
378
+ missing_chunks = []
379
+ invalid_chunks = []
380
+
381
+ for i in range(1, expected_count + 1):
382
+ chunk_path = audio_chunks_dir / f"chunk_{i:05}.wav"
383
+
384
+ if not chunk_path.exists():
385
+ missing_chunks.append(i)
386
+ elif not verify_audio_file(chunk_path):
387
+ invalid_chunks.append(i)
388
+
389
+ return missing_chunks, invalid_chunks
390
+
391
+ # ============================================================================
392
+ # EXPORT AND IMPORT FUNCTIONS
393
+ # ============================================================================
394
+
395
+ def export_processing_log(output_dir, processing_info):
396
+ """Export comprehensive processing log"""
397
+ log_path = output_dir / "processing_complete.log"
398
+
399
+ with open(log_path, 'w', encoding='utf-8') as f:
400
+ f.write("GenTTS Processing Complete\n")
401
+ f.write("=" * 50 + "\n\n")
402
+
403
+ for key, value in processing_info.items():
404
+ f.write(f"{key}: {value}\n")
405
+
406
+ return log_path
407
+
408
+ def save_chunk_info(text_chunks_dir, chunks_info):
409
+ """Save chunk information for debugging/resume"""
410
+ info_path = text_chunks_dir / "chunks_info.json"
411
+
412
+ import json
413
+ with open(info_path, 'w', encoding='utf-8') as f:
414
+ json.dump(chunks_info, f, indent=2, ensure_ascii=False)
415
+
416
+ return info_path
417
+
418
+ def load_chunk_info(text_chunks_dir):
419
+ """Load chunk information if available"""
420
+ info_path = text_chunks_dir / "chunks_info.json"
421
+
422
+ if not info_path.exists():
423
+ return None
424
+
425
+ import json
426
+ try:
427
+ with open(info_path, 'r', encoding='utf-8') as f:
428
+ return json.load(f)
429
+ except Exception as e:
430
+ logging.warning(f"Could not load chunk info: {e}")
431
+ return None
HF_Deploy/modules/gui_json_generator.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ GUI JSON Audio Generation Module
4
+
5
+ This module provides JSON-to-audiobook generation specifically for GUI use.
6
+ It's based on utils/generate_from_json.py but adapted for GUI integration.
7
+ """
8
+
9
+ import torch
10
+ from pathlib import Path
11
+ import sys
12
+ from concurrent.futures import ThreadPoolExecutor, as_completed
13
+ import time
14
+ from datetime import timedelta
15
+
16
+ # Add project root to path to allow module imports
17
+ project_root = Path(__file__).parent.parent
18
+ sys.path.append(str(project_root))
19
+
20
+ from config.config import *
21
+ from modules.tts_engine import load_optimized_model, process_one_chunk
22
+ from modules.file_manager import setup_book_directories, list_voice_samples, ensure_voice_sample_compatibility
23
+ from wrapper.chunk_loader import load_chunks
24
+ from src.chatterbox.tts import punc_norm
25
+ from modules.progress_tracker import log_chunk_progress, log_run
26
+ from tools.combine_only import combine_audio_for_book
27
+
28
+
29
+ def generate_audiobook_from_json(json_path, voice_name, temp_setting=None):
30
+ """
31
+ Generate complete audiobook from JSON chunks file.
32
+
33
+ Args:
34
+ json_path (str): Path to the JSON chunks file
35
+ voice_name (str): Name of the voice to use (without .wav extension)
36
+ temp_setting (float, optional): Temperature override for TTS
37
+
38
+ Returns:
39
+ tuple: (success: bool, message: str, audiobook_path: str or None)
40
+ """
41
+ try:
42
+ print(f"🎵 GUI JSON Generator: Starting audiobook generation")
43
+ print(f"📄 JSON file: {json_path}")
44
+ print(f"🎤 Voice: {voice_name}")
45
+ if temp_setting:
46
+ print(f"🌡️ Temperature override: {temp_setting}")
47
+
48
+ # Determine book name from JSON path
49
+ json_file = Path(json_path)
50
+
51
+ # Try to extract book name from path structure
52
+ if 'Audiobook' in json_file.parts:
53
+ audiobook_index = json_file.parts.index('Audiobook')
54
+ if audiobook_index + 1 < len(json_file.parts):
55
+ book_name = json_file.parts[audiobook_index + 1]
56
+ print(f"📚 Detected book name from path: {book_name}")
57
+ else:
58
+ raise Exception("Cannot determine book name from Audiobook path")
59
+ elif json_file.stem.endswith('_chunks'):
60
+ book_name = json_file.stem.replace('_chunks', '')
61
+ print(f"📚 Detected book name from filename: {book_name}")
62
+ else:
63
+ book_name = json_file.stem
64
+ print(f"📚 Using filename as book name: {book_name}")
65
+
66
+ # Load JSON chunks (READ ONLY - never modify the original)
67
+ print(f"📖 Loading chunks from: {json_path}")
68
+ all_chunks = load_chunks(str(json_path))
69
+ print(f"✅ Found {len(all_chunks)} chunks.")
70
+
71
+ # Find voice file
72
+ voice_files = list_voice_samples()
73
+ voice_path = None
74
+ for voice_file in voice_files:
75
+ if voice_file.stem == voice_name:
76
+ voice_path = voice_file
77
+ break
78
+
79
+ if not voice_path:
80
+ available_voices = [vf.stem for vf in voice_files]
81
+ return False, f"Voice '{voice_name}' not found. Available: {available_voices}", None
82
+
83
+ # Ensure voice compatibility
84
+ voice_path = ensure_voice_sample_compatibility(voice_path)
85
+ if isinstance(voice_path, str):
86
+ voice_path = Path(voice_path)
87
+
88
+ print(f"🎤 Using voice: {voice_path.name}")
89
+
90
+ # Setup device
91
+ if torch.cuda.is_available():
92
+ device = "cuda"
93
+ elif torch.backends.mps.is_available():
94
+ device = "mps"
95
+ else:
96
+ device = "cpu"
97
+
98
+ print(f"🚀 Using device: {device}")
99
+
100
+ # Load TTS model
101
+ print(f"🤖 Loading TTS model...")
102
+ model = load_optimized_model(device)
103
+
104
+ # Prepare voice conditionals
105
+ print(f"🎤 Preparing voice conditionals...")
106
+ model.prepare_conditionals(voice_path)
107
+
108
+ # Setup output directories
109
+ output_root = AUDIOBOOK_ROOT / book_name
110
+ tts_dir = output_root / "TTS"
111
+ text_chunks_dir = tts_dir / "text_chunks"
112
+ audio_chunks_dir = tts_dir / "audio_chunks"
113
+
114
+ # Create directories
115
+ for dir_path in [output_root, tts_dir, text_chunks_dir, audio_chunks_dir]:
116
+ dir_path.mkdir(parents=True, exist_ok=True)
117
+
118
+ # Clean existing audio chunks
119
+ print("🧹 Clearing old audio chunks...")
120
+ for wav_file in audio_chunks_dir.glob("*.wav"):
121
+ wav_file.unlink()
122
+
123
+ # Process chunks
124
+ start_time = time.time()
125
+ total_chunks = len(all_chunks)
126
+ log_path = output_root / "gui_json_generation.log"
127
+
128
+ print(f"🔄 Generating {total_chunks} audio chunks...")
129
+
130
+ with ThreadPoolExecutor(max_workers=2) as executor:
131
+ futures = []
132
+ for i, chunk_data in enumerate(all_chunks):
133
+ # Use chunk's TTS params, with temperature override if provided
134
+ chunk_tts_params = chunk_data.get("tts_params", {}).copy()
135
+ if temp_setting is not None:
136
+ chunk_tts_params["temperature"] = temp_setting
137
+
138
+ # Ensure required TTS params exist
139
+ chunk_tts_params.setdefault("exaggeration", DEFAULT_EXAGGERATION)
140
+ chunk_tts_params.setdefault("cfg_weight", DEFAULT_CFG_WEIGHT)
141
+ chunk_tts_params.setdefault("temperature", DEFAULT_TEMPERATURE)
142
+
143
+ future = executor.submit(
144
+ process_one_chunk,
145
+ i, chunk_data['text'], text_chunks_dir, audio_chunks_dir,
146
+ voice_path, chunk_tts_params, start_time, total_chunks,
147
+ punc_norm, book_name, log_run, log_path, device,
148
+ model, None, all_chunks, chunk_data.get('boundary_type', 'none')
149
+ )
150
+ futures.append(future)
151
+
152
+ # Wait for all chunks to complete
153
+ completed_chunks = 0
154
+ for future in as_completed(futures):
155
+ try:
156
+ result = future.result()
157
+ if result:
158
+ idx, _ = result
159
+ completed_chunks += 1
160
+ log_chunk_progress(idx, total_chunks, start_time, 0)
161
+ print(f"✅ Completed chunk {completed_chunks}/{total_chunks}")
162
+ except Exception as e:
163
+ print(f"❌ Error processing chunk: {e}")
164
+
165
+ elapsed_time = time.time() - start_time
166
+ print(f"✅ Audio generation complete in {timedelta(seconds=int(elapsed_time))}")
167
+ print(f"🔊 Audio chunks generated in: {audio_chunks_dir}")
168
+
169
+ # Combine chunks into final audiobook
170
+ print("🔗 Combining audio chunks into final audiobook...")
171
+ try:
172
+ success = combine_audio_for_book(str(output_root), voice_name)
173
+ if success:
174
+ # Look for the created audiobook file with voice name
175
+ final_m4b = output_root / f"{book_name} [{voice_name}].m4b"
176
+ if final_m4b.exists():
177
+ print(f"🎉 Audiobook created: {final_m4b.name}")
178
+ return True, "Audiobook generation completed successfully", str(final_m4b)
179
+ else:
180
+ return False, "Combine succeeded but final audiobook file not found", None
181
+ else:
182
+ return False, "Failed to combine audio chunks", None
183
+ except Exception as e:
184
+ return False, f"Error combining audio chunks: {e}", None
185
+
186
+ except Exception as e:
187
+ error_msg = f"JSON generation error: {e}"
188
+ print(f"❌ {error_msg}")
189
+ return False, error_msg, None
190
+
191
+
192
+ def get_book_name_from_json_path(json_path):
193
+ """
194
+ Extract book name from JSON file path.
195
+
196
+ Args:
197
+ json_path (str): Path to JSON file
198
+
199
+ Returns:
200
+ str: Detected book name
201
+ """
202
+ json_file = Path(json_path)
203
+
204
+ if 'Audiobook' in json_file.parts:
205
+ audiobook_index = json_file.parts.index('Audiobook')
206
+ if audiobook_index + 1 < len(json_file.parts):
207
+ return json_file.parts[audiobook_index + 1]
208
+
209
+ if json_file.stem.endswith('_chunks'):
210
+ return json_file.stem.replace('_chunks', '')
211
+
212
+ return json_file.stem
213
+
214
+
215
+ if __name__ == "__main__":
216
+ # CLI compatibility for testing
217
+ print("GUI JSON Generator - use from GUI or import as module")
HF_Deploy/modules/path_manager.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from config.config import AUDIOBOOK_ROOT
3
+
4
+ def get_book_paths(book_name):
5
+ """Return standardized paths for a given book name"""
6
+ base = AUDIOBOOK_ROOT / book_name
7
+ tts_dir = base / "TTS"
8
+ return {
9
+ "book_folder": base,
10
+ "tts_dir": tts_dir,
11
+ "text_chunks": tts_dir / "text_chunks",
12
+ "audio_chunks": tts_dir / "audio_chunks",
13
+ "combined_wav": base / f"{book_name}.wav",
14
+ "final_m4b": base / f"{book_name}.m4b",
15
+ "concat_list": tts_dir / "audio_chunks" / "concat.txt",
16
+ "quarantine": tts_dir / "audio_chunks" / "quarantine",
17
+ "run_log": base / "run.log",
18
+ "chunk_log": base / "chunk_validation.log"
19
+ }
HF_Deploy/modules/progress_tracker.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Progress Tracker Module
3
+ Handles progress display, VRAM monitoring, logging systems, and performance tracking
4
+ """
5
+
6
+ import time
7
+ import sys
8
+ import logging
9
+ from datetime import timedelta
10
+ from pathlib import Path
11
+ from config.config import *
12
+
13
+ # ============================================================================
14
+ # LOGGING SETUP
15
+ # ============================================================================
16
+
17
+ def setup_logging(log_dir):
18
+ """Setup logging configuration"""
19
+ log_file = log_dir / "chunk_validation.log"
20
+
21
+ # Clear existing log
22
+ open(log_file, 'w').close()
23
+
24
+ logging.basicConfig(
25
+ filename=str(log_file),
26
+ level=logging.INFO,
27
+ format="%(asctime)s - %(levelname)s - %(message)s",
28
+ filemode='w' # Overwrite existing log
29
+ )
30
+
31
+ # Also log to console for important messages
32
+ console_handler = logging.StreamHandler()
33
+ console_handler.setLevel(logging.WARNING)
34
+ formatter = logging.Formatter('%(levelname)s - %(message)s')
35
+ console_handler.setFormatter(formatter)
36
+ logging.getLogger().addHandler(console_handler)
37
+
38
+ def log_console(message, color=None):
39
+ """Log to both console and file with optional color"""
40
+ color_codes = {
41
+ "RED": RED, "GREEN": GREEN, "YELLOW": YELLOW,
42
+ "CYAN": CYAN, "BOLD": BOLD, "RESET": RESET
43
+ }
44
+
45
+ prefix = color_codes.get(color, "")
46
+ suffix = RESET if color else ""
47
+
48
+ print(f"{prefix}{message}{suffix}")
49
+ logging.info(message)
50
+
51
+ def log_run(message, log_path):
52
+ """Log to run file"""
53
+ with open(log_path, "a", encoding="utf-8") as logf:
54
+ logf.write(message + "\n")
55
+
56
+ # ============================================================================
57
+ # PROGRESS TRACKING
58
+ # ============================================================================
59
+
60
+ def log_chunk_progress(i, total_chunks, start_time, total_audio_duration=0.0):
61
+ """Enhanced progress logging with accurate realtime factor"""
62
+ elapsed = time.time() - start_time
63
+ avg_time = elapsed / (i + 1)
64
+ eta = avg_time * total_chunks
65
+ remaining = eta - elapsed
66
+
67
+ def fmt(seconds):
68
+ return str(timedelta(seconds=int(seconds)))
69
+
70
+ # Show VRAM usage in progress
71
+ allocated, _ = monitor_vram_usage("chunk_progress")
72
+
73
+ # Calculate ACCURATE realtime factor using actual audio duration
74
+ if total_audio_duration > 0 and elapsed > 0:
75
+ actual_realtime = total_audio_duration / elapsed
76
+ realtime_str = f"{GREEN}{actual_realtime:.2f}x{RESET}"
77
+ audio_str = f" | Audio: {GREEN}{fmt(total_audio_duration)}{RESET}"
78
+ else:
79
+ actual_realtime = 0.0 # Default value when calculating
80
+ realtime_str = f"{YELLOW}Calculating...{RESET}"
81
+ audio_str = ""
82
+
83
+ # Force immediate output with explicit flushing
84
+ progress_msg = (f"\n🌀 Chunk {i+1}/{total_chunks} | ⏱ Elapsed: {CYAN}{fmt(elapsed)}{RESET} | "
85
+ f"ETA: {CYAN}{fmt(eta)}{RESET} | Remaining: {YELLOW}{fmt(remaining)}{RESET} | "
86
+ f"Realtime: {realtime_str} | VRAM: {GREEN}{allocated:.1f}GB{RESET}{audio_str}")
87
+
88
+ print(progress_msg)
89
+ sys.stdout.flush() # Force immediate output
90
+
91
+ # Create clean status message for GUI (without ANSI color codes)
92
+ realtime_display = f"{actual_realtime:.2f}x" if actual_realtime > 0 else "Calculating..."
93
+ clean_status = (f"Elapsed: {fmt(elapsed)} | ETA: {fmt(eta)} | Remaining: {fmt(remaining)} | "
94
+ f"Realtime: {realtime_display} | VRAM: {allocated:.1f}GB" +
95
+ (f" | Audio: {fmt(total_audio_duration)}" if total_audio_duration > 0 else ""))
96
+
97
+ # Emit status to GUI if callback is available
98
+ if hasattr(log_chunk_progress, '_status_callback') and log_chunk_progress._status_callback:
99
+ log_chunk_progress._status_callback(clean_status)
100
+
101
+ # Also log to file for debugging
102
+ realtime_log = f"{actual_realtime:.2f}x" if actual_realtime > 0 else "N/A"
103
+ logging.info(f"Progress: Chunk {i+1}/{total_chunks}, Elapsed: {fmt(elapsed)}, "
104
+ f"ETA: {fmt(eta)}, Realtime: {realtime_log}, "
105
+ f"Audio Duration: {fmt(total_audio_duration)}, VRAM: {allocated:.1f}GB")
106
+
107
+ def display_batch_progress(batch_start, batch_end, total_chunks):
108
+ """Display batch processing progress"""
109
+ batch_progress = (batch_end / total_chunks) * 100
110
+ print(f"\n📊 Batch Progress: {batch_start+1}-{batch_end}/{total_chunks} ({batch_progress:.1f}%)")
111
+
112
+ def display_final_summary(elapsed_time, audio_duration, chunk_count, realtime_factor):
113
+ """Display final processing summary"""
114
+ elapsed_td = timedelta(seconds=int(elapsed_time))
115
+ audio_td = timedelta(seconds=int(audio_duration))
116
+
117
+ print(f"\n🎉 {GREEN}Processing Complete!{RESET}")
118
+ print(f"📊 Final Statistics:")
119
+ print(f" ⏱️ Processing Time: {CYAN}{elapsed_td}{RESET}")
120
+ print(f" 🎵 Audio Duration: {GREEN}{audio_td}{RESET}")
121
+ print(f" 📦 Total Chunks: {YELLOW}{chunk_count}{RESET}")
122
+ print(f" 🚀 Realtime Factor: {BOLD}{realtime_factor:.2f}x{RESET}")
123
+ print(f" 💾 Memory Efficiency: {GREEN}Optimized{RESET}")
124
+
125
+ # ============================================================================
126
+ # VRAM AND PERFORMANCE MONITORING
127
+ # ============================================================================
128
+
129
+ def monitor_vram_usage(operation_name=""):
130
+ """Real-time VRAM monitoring with threshold warnings"""
131
+ import torch
132
+
133
+ if not torch.cuda.is_available():
134
+ return 0, 0
135
+
136
+ allocated = torch.cuda.memory_allocated() / 1024**3
137
+ reserved = torch.cuda.memory_reserved() / 1024**3
138
+
139
+ if allocated > VRAM_SAFETY_THRESHOLD:
140
+ logging.warning(f"⚠️ High VRAM usage during {operation_name}: {allocated:.1f}GB allocated, {reserved:.1f}GB reserved")
141
+ # Trigger memory optimization if available
142
+ optimize_memory_if_needed()
143
+
144
+ return allocated, reserved
145
+
146
+ def monitor_gpu_utilization():
147
+ """Monitor GPU utilization if pynvml is available"""
148
+ try:
149
+ import pynvml
150
+ pynvml.nvmlInit()
151
+ handle = pynvml.nvmlDeviceGetHandleByIndex(0)
152
+ util = pynvml.nvmlDeviceGetUtilizationRates(handle)
153
+ temp = pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)
154
+
155
+ return {
156
+ "gpu_util": util.gpu,
157
+ "memory_util": util.memory,
158
+ "temperature": temp
159
+ }
160
+ except:
161
+ return {"gpu_util": "N/A", "memory_util": "N/A", "temperature": "N/A"}
162
+
163
+ def optimize_memory_if_needed():
164
+ """Trigger memory optimization when thresholds are exceeded"""
165
+ try:
166
+ # Try to use the enhanced CUDA memory optimization if available
167
+ from modules.tts_engine import optimize_cuda_memory_usage
168
+ optimize_cuda_memory_usage()
169
+ except ImportError:
170
+ # Fallback to basic optimization
171
+ import torch
172
+ import gc
173
+ torch.cuda.empty_cache()
174
+ gc.collect()
175
+ if torch.cuda.is_available():
176
+ torch.cuda.ipc_collect()
177
+
178
+ def display_system_info():
179
+ """Display system information at startup"""
180
+ import torch
181
+
182
+ print(f"\n🖥️ {CYAN}System Information:{RESET}")
183
+
184
+ # CUDA info
185
+ if torch.cuda.is_available():
186
+ gpu_name = torch.cuda.get_device_name(0)
187
+ total_vram = torch.cuda.get_device_properties(0).total_memory / 1024**3
188
+ print(f" GPU: {GREEN}{gpu_name}{RESET}")
189
+ print(f" VRAM: {GREEN}{total_vram:.1f}GB{RESET}")
190
+ print(f" CUDA Version: {GREEN}{torch.version.cuda}{RESET}")
191
+ else:
192
+ print(f" GPU: {RED}Not Available{RESET}")
193
+
194
+ # Memory threshold
195
+ print(f" VRAM Safety Threshold: {YELLOW}{VRAM_SAFETY_THRESHOLD}GB{RESET}")
196
+
197
+ # Worker configuration
198
+ print(f" Max Workers: {YELLOW}{MAX_WORKERS}{RESET}")
199
+ print(f" Dynamic Workers: {YELLOW}{USE_DYNAMIC_WORKERS}{RESET}")
200
+
201
+ # ============================================================================
202
+ # PERFORMANCE TRACKING
203
+ # ============================================================================
204
+
205
+ class PerformanceTracker:
206
+ """Track performance metrics throughout processing"""
207
+
208
+ def __init__(self):
209
+ self.start_time = time.time()
210
+ self.chunk_times = []
211
+ self.vram_usage = []
212
+ self.batch_times = []
213
+
214
+ def log_chunk_completion(self, chunk_index, audio_duration):
215
+ """Log individual chunk completion"""
216
+ current_time = time.time()
217
+ chunk_time = current_time - (self.start_time + sum(self.chunk_times))
218
+
219
+ self.chunk_times.append(chunk_time)
220
+
221
+ # Track VRAM
222
+ allocated, reserved = monitor_vram_usage()
223
+ self.vram_usage.append((chunk_index, allocated, reserved))
224
+
225
+ def log_batch_completion(self, batch_size):
226
+ """Log batch completion"""
227
+ if len(self.chunk_times) >= batch_size:
228
+ batch_time = sum(self.chunk_times[-batch_size:])
229
+ self.batch_times.append(batch_time)
230
+
231
+ def get_performance_summary(self):
232
+ """Get comprehensive performance summary"""
233
+ total_time = time.time() - self.start_time
234
+ avg_chunk_time = sum(self.chunk_times) / len(self.chunk_times) if self.chunk_times else 0
235
+
236
+ vram_peak = max([usage[1] for usage in self.vram_usage]) if self.vram_usage else 0
237
+ vram_avg = sum([usage[1] for usage in self.vram_usage]) / len(self.vram_usage) if self.vram_usage else 0
238
+
239
+ return {
240
+ "total_time": total_time,
241
+ "avg_chunk_time": avg_chunk_time,
242
+ "total_chunks": len(self.chunk_times),
243
+ "vram_peak": vram_peak,
244
+ "vram_average": vram_avg,
245
+ "batch_count": len(self.batch_times)
246
+ }
247
+
248
+ # ============================================================================
249
+ # ERROR AND WARNING TRACKING
250
+ # ============================================================================
251
+
252
+ def log_processing_error(chunk_id, error_message, error_type="GENERAL"):
253
+ """Log processing errors with categorization"""
254
+ timestamp = time.strftime('%Y-%m-%d %H:%M:%S')
255
+ error_log = f"[{timestamp}] {error_type} ERROR - Chunk {chunk_id}: {error_message}"
256
+
257
+ logging.error(error_log)
258
+ print(f"{RED}❌ Error in chunk {chunk_id}: {error_message}{RESET}")
259
+
260
+ def log_processing_warning(chunk_id, warning_message, warning_type="GENERAL"):
261
+ """Log processing warnings with categorization"""
262
+ timestamp = time.strftime('%Y-%m-%d %H:%M:%S')
263
+ warning_log = f"[{timestamp}] {warning_type} WARNING - Chunk {chunk_id}: {warning_message}"
264
+
265
+ logging.warning(warning_log)
266
+ print(f"{YELLOW}⚠️ Warning in chunk {chunk_id}: {warning_message}{RESET}")
267
+
268
+ # ============================================================================
269
+ # REAL-TIME STATUS DISPLAY
270
+ # ============================================================================
271
+
272
+ def create_status_line(current_chunk, total_chunks, elapsed_time, realtime_factor, vram_usage):
273
+ """Create a single-line status for real-time updates"""
274
+ progress_percent = (current_chunk / total_chunks) * 100
275
+ elapsed_str = str(timedelta(seconds=int(elapsed_time)))
276
+
277
+ status = (f"🔄 {current_chunk}/{total_chunks} ({progress_percent:.1f}%) | "
278
+ f"⏱️ {elapsed_str} | 🚀 {realtime_factor:.2f}x | 💾 {vram_usage:.1f}GB")
279
+
280
+ return status
281
+
282
+ def update_status_line(status_message):
283
+ """Update status line in place"""
284
+ print(f"\r{status_message}", end='', flush=True)
285
+
286
+ # ============================================================================
287
+ # EXPORT FUNCTIONS
288
+ # ============================================================================
289
+
290
+ def export_performance_report(output_dir, performance_data):
291
+ """Export detailed performance report"""
292
+ report_path = output_dir / "performance_report.txt"
293
+
294
+ with open(report_path, 'w', encoding='utf-8') as f:
295
+ f.write("GenTTS Performance Report\n")
296
+ f.write("=" * 50 + "\n\n")
297
+
298
+ f.write(f"Processing Summary:\n")
299
+ f.write(f" Total Processing Time: {timedelta(seconds=int(performance_data['total_time']))}\n")
300
+ f.write(f" Average Chunk Time: {performance_data['avg_chunk_time']:.2f}s\n")
301
+ f.write(f" Total Chunks Processed: {performance_data['total_chunks']}\n")
302
+ f.write(f" Peak VRAM Usage: {performance_data['vram_peak']:.2f}GB\n")
303
+ f.write(f" Average VRAM Usage: {performance_data['vram_average']:.2f}GB\n")
304
+ f.write(f" Batch Count: {performance_data['batch_count']}\n")
305
+
306
+ return report_path
HF_Deploy/modules/resume_handler.py ADDED
@@ -0,0 +1,596 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Resume Handler Module
3
+ Handles resume functionality for interrupted processing
4
+ """
5
+
6
+ import torch
7
+ import time
8
+ import logging
9
+ from datetime import timedelta
10
+ from pathlib import Path
11
+
12
+ from config.config import *
13
+ from modules.text_processor import smart_punctuate, sentence_chunk_text
14
+ from modules.file_manager import (
15
+ setup_book_directories, find_book_files, list_voice_samples,
16
+ ensure_voice_sample_compatibility, get_audio_files_in_directory,
17
+ combine_audio_chunks, convert_to_m4b, add_metadata_to_m4b
18
+ )
19
+ from modules.audio_processor import get_chunk_audio_duration, pause_for_chunk_review
20
+ from modules.progress_tracker import setup_logging, log_chunk_progress, log_run
21
+
22
+ def analyze_existing_chunks(audio_chunks_dir):
23
+ """Analyze existing chunks to determine resume point"""
24
+ if not audio_chunks_dir.exists():
25
+ return 0, []
26
+
27
+ chunk_paths = get_audio_files_in_directory(audio_chunks_dir)
28
+
29
+ if not chunk_paths:
30
+ return 0, []
31
+
32
+ # Find the highest chunk number
33
+ chunk_numbers = []
34
+ for chunk_path in chunk_paths:
35
+ import re
36
+ match = re.match(r"chunk_(\d+)\.wav", chunk_path.name)
37
+ if match:
38
+ chunk_numbers.append(int(match.group(1)))
39
+
40
+ if not chunk_numbers:
41
+ return 0, []
42
+
43
+ chunk_numbers.sort()
44
+ last_chunk_number = max(chunk_numbers)
45
+
46
+ # Check for gaps in sequence
47
+ missing_chunks = []
48
+ for i in range(1, last_chunk_number + 1):
49
+ if i not in chunk_numbers:
50
+ missing_chunks.append(i)
51
+
52
+ print(f"📊 Existing chunks analysis:")
53
+ print(f" Total chunks found: {GREEN}{len(chunk_numbers)}{RESET}")
54
+ print(f" Highest chunk number: {GREEN}{last_chunk_number}{RESET}")
55
+ if missing_chunks:
56
+ print(f" Missing chunks: {YELLOW}{len(missing_chunks)}{RESET}")
57
+ if len(missing_chunks) <= 10:
58
+ print(f" Missing: {missing_chunks}")
59
+ else:
60
+ print(f" Missing: {missing_chunks[:10]}... (+{len(missing_chunks)-10} more)")
61
+
62
+ return last_chunk_number, missing_chunks
63
+
64
+ def suggest_resume_point(last_chunk, missing_chunks):
65
+ """Suggest optimal resume point based on existing chunks"""
66
+ if not missing_chunks:
67
+ # No gaps, can resume from next chunk
68
+ return last_chunk + 1
69
+
70
+ # If there are missing chunks, suggest resuming from first missing
71
+ first_missing = min(missing_chunks)
72
+
73
+ print(f"\n💡 Resume suggestions:")
74
+ print(f" Resume from chunk {GREEN}{last_chunk + 1}{RESET} (continue from last)")
75
+ print(f" Resume from chunk {YELLOW}{first_missing}{RESET} (fill gaps first)")
76
+
77
+ return first_missing
78
+
79
+ def validate_resume_point(start_chunk, total_expected_chunks):
80
+ """Validate that resume point makes sense"""
81
+ if start_chunk < 1:
82
+ print(f"{RED}❌ Invalid resume point: {start_chunk}. Must be >= 1{RESET}")
83
+ return False
84
+
85
+ if start_chunk > total_expected_chunks:
86
+ print(f"{RED}❌ Resume point {start_chunk} exceeds expected total chunks {total_expected_chunks}{RESET}")
87
+ return False
88
+
89
+ return True
90
+
91
+ def process_book_folder_resume(book_dir, voice_path, tts_params, device, start_chunk=1):
92
+ """Enhanced book processing with resume capability"""
93
+ from modules.tts_engine import process_one_chunk, load_optimized_model, get_optimal_workers
94
+ from src.chatterbox.tts import punc_norm
95
+ from concurrent.futures import ThreadPoolExecutor, as_completed
96
+
97
+ # Setup directories
98
+ output_root, tts_dir, text_chunks_dir, audio_chunks_dir = setup_book_directories(book_dir)
99
+
100
+ # Find book files
101
+ book_files = find_book_files(book_dir)
102
+ text_file = book_files['text']
103
+ cover_file = book_files['cover']
104
+ nfo_file = book_files['nfo']
105
+
106
+ if not text_file:
107
+ logging.info(f"[{book_dir.name}] ERROR: No .txt files found in the book folder.")
108
+ return None, None, []
109
+
110
+ text_files = [text_file] # Convert to list for compatibility
111
+
112
+ # IMPORTANT: Don't delete existing directories if resuming
113
+ print(f"🔍 DEBUG: start_chunk = {start_chunk}")
114
+ if start_chunk == 1:
115
+ print(f"⚠️ WARNING: start_chunk is 1 - this will clear existing chunks!")
116
+ print(f"📁 About to clear: {audio_chunks_dir}")
117
+
118
+ # Only clear on fresh start
119
+ import shutil
120
+ for d in [text_chunks_dir, audio_chunks_dir]:
121
+ if d.exists() and d.is_dir():
122
+ print(f"🗑️ CLEARING DIRECTORY: {d}")
123
+ shutil.rmtree(d)
124
+
125
+ for d in [output_root, tts_dir, text_chunks_dir, audio_chunks_dir]:
126
+ d.mkdir(parents=True, exist_ok=True)
127
+ else:
128
+ print(f"✅ RESUME MODE: Preserving existing chunks in {audio_chunks_dir}")
129
+ # Ensure directories exist for resume
130
+ for d in [output_root, tts_dir, text_chunks_dir, audio_chunks_dir]:
131
+ d.mkdir(parents=True, exist_ok=True)
132
+
133
+ setup_logging(output_root)
134
+
135
+ # Load existing chunks from JSON (resume should use preprocessed data)
136
+ from modules.tts_engine import find_chunks_json_file
137
+
138
+ json_file = find_chunks_json_file(book_dir.name)
139
+ if json_file:
140
+ print(f"📖 Loading preprocessed chunks from: {json_file.name}")
141
+ from wrapper.chunk_loader import load_chunks
142
+ all_chunks = load_chunks(str(json_file))
143
+ print(f"✅ Loaded {len(all_chunks)} chunks with metadata")
144
+ else:
145
+ print(f"❌ No preprocessed chunks found for {book_dir.name}")
146
+ print(f"💡 Use Option 1 to process this book from the beginning first.")
147
+ return None, None, []
148
+
149
+ # Validate resume point
150
+ if not validate_resume_point(start_chunk, len(all_chunks)):
151
+ return None, None, []
152
+
153
+ # Filter chunks to process (resume logic)
154
+ if start_chunk > 1:
155
+ print(f"🔄 Resuming from chunk {start_chunk}")
156
+ print(f"📊 Skipping chunks 1-{start_chunk-1} (already completed)")
157
+
158
+ # Check which chunks already exist
159
+ existing_chunks = []
160
+ for i in range(start_chunk-1):
161
+ chunk_path = audio_chunks_dir / f"chunk_{i+1:05}.wav"
162
+ if chunk_path.exists():
163
+ existing_chunks.append(i+1)
164
+
165
+ print(f"✅ Found {len(existing_chunks)} existing chunks")
166
+
167
+ # Only process remaining chunks
168
+ chunks_to_process = all_chunks[start_chunk-1:]
169
+ chunk_offset = start_chunk - 1
170
+ else:
171
+ chunks_to_process = all_chunks
172
+ chunk_offset = 0
173
+
174
+ run_log_lines = [
175
+ f"\n===== RESUME Processing: {book_dir.name} =====",
176
+ f"Voice: {voice_path.name}",
177
+ f"Started: {time.strftime('%Y-%m-%d %H:%M:%S')}",
178
+ f"Resume from chunk: {start_chunk}",
179
+ f"Text files processed: {len(text_files)}",
180
+ f"Total chunks generated: {len(all_chunks)}",
181
+ f"Chunks to process: {len(chunks_to_process)}"
182
+ ]
183
+
184
+ # Write initial run info immediately
185
+ initial_log = run_log_lines + [
186
+ f"--- Generation Settings ---",
187
+ f"Batch Processing: Enabled ({BATCH_SIZE} chunks per batch)",
188
+ f"ASR Enabled: {ENABLE_ASR}",
189
+ f"Hum Detection: {ENABLE_HUM_DETECTION}",
190
+ f"Dynamic Workers: {USE_DYNAMIC_WORKERS}",
191
+ f"Voice used: {voice_path.name}",
192
+ f"Exaggeration: {tts_params['exaggeration']}",
193
+ f"CFG weight: {tts_params['cfg_weight']}",
194
+ f"Temperature: {tts_params['temperature']}",
195
+ f"Processing Status: IN PROGRESS...",
196
+ f"="*50
197
+ ]
198
+
199
+ log_run("\n".join(initial_log), output_root / "run.log")
200
+ print(f"📝 Initial run info written to: {output_root / 'run.log'}")
201
+
202
+ start_time = time.time()
203
+ total_chunks = len(all_chunks)
204
+ remaining_chunks = len(chunks_to_process)
205
+ log_path = output_root / "chunk_validation.log"
206
+
207
+ # Calculate existing audio duration for accurate progress
208
+ total_audio_duration = 0.0
209
+ if start_chunk > 1:
210
+ print("📊 Calculating existing audio duration...")
211
+ for i in range(start_chunk-1):
212
+ chunk_path = audio_chunks_dir / f"chunk_{i+1:05}.wav"
213
+ if chunk_path.exists():
214
+ total_audio_duration += get_chunk_audio_duration(chunk_path)
215
+ print(f"📊 Existing audio: {timedelta(seconds=int(total_audio_duration))}")
216
+
217
+ # Initialize performance optimizations
218
+ from modules.tts_engine import detect_deployment_environment, enable_gpu_persistence_mode
219
+ deployment_env = detect_deployment_environment()
220
+ print(f"🌍 Deployment environment: {deployment_env}")
221
+
222
+ # Enable GPU persistence mode for better performance
223
+ gpu_persistence_enabled = enable_gpu_persistence_mode()
224
+
225
+ # Batch processing for remaining chunks
226
+ print(f"📊 Processing {remaining_chunks} remaining chunks in batches of {BATCH_SIZE}")
227
+
228
+ all_results = []
229
+
230
+ for batch_start in range(0, remaining_chunks, BATCH_SIZE):
231
+ batch_end = min(batch_start + BATCH_SIZE, remaining_chunks)
232
+ batch_chunks = chunks_to_process[batch_start:batch_end]
233
+
234
+ actual_start_chunk = chunk_offset + batch_start + 1
235
+ actual_end_chunk = chunk_offset + batch_end
236
+
237
+ print(f"\n🔄 Processing batch: chunks {actual_start_chunk}-{actual_end_chunk}")
238
+
239
+ # Fresh model for each batch
240
+ model = load_optimized_model(device)
241
+ compatible_voice = ensure_voice_sample_compatibility(voice_path, output_dir=tts_dir)
242
+
243
+ # Pre-warm model to eliminate first chunk quality variations
244
+ from modules.tts_engine import prewarm_model_with_voice
245
+ model = prewarm_model_with_voice(model, compatible_voice, tts_params)
246
+
247
+ # Load ASR model once per batch if needed using adaptive manager
248
+ asr_model = None
249
+ asr_device_used = None
250
+ if ENABLE_ASR:
251
+ from modules.asr_manager import load_asr_model_adaptive
252
+ print(f"🎤 Loading ASR model for resume mode...")
253
+ # Resume mode uses fallback config (no intelligent selection)
254
+ asr_model, asr_device_used = load_asr_model_adaptive()
255
+
256
+ futures = []
257
+ batch_results = []
258
+
259
+ # Dynamic worker allocation
260
+ optimal_workers = get_optimal_workers()
261
+ print(f"🔧 Using {optimal_workers} workers for batch {actual_start_chunk}-{actual_end_chunk}")
262
+
263
+ # Try producer-consumer pipeline first (Phase 4 optimization)
264
+ batch_results = []
265
+ if ENABLE_PRODUCER_CONSUMER_PIPELINE:
266
+ try:
267
+ print(f"🚀 Attempting producer-consumer pipeline for resume batch {actual_start_chunk}-{actual_end_chunk}")
268
+ from modules.tts_engine import process_chunks_with_pipeline
269
+ pipeline_results = process_chunks_with_pipeline(
270
+ all_chunks, batch_chunks, chunk_offset, text_chunks_dir, audio_chunks_dir,
271
+ voice_path, tts_params, start_time, total_chunks, punc_norm, book_dir.name,
272
+ log_run, log_path, device, model, asr_model, True, optimal_workers, # asr_enabled=True for resume
273
+ total_audio_duration # Pass accumulated duration for proper ETA calculation
274
+ )
275
+
276
+ # Handle tuple return from pipeline
277
+ if isinstance(pipeline_results, tuple) and len(pipeline_results) == 2:
278
+ batch_results, batch_audio_duration = pipeline_results
279
+ total_audio_duration += batch_audio_duration
280
+ else:
281
+ # Fallback for old return format
282
+ batch_results = pipeline_results
283
+
284
+ if batch_results:
285
+ print(f"✅ Producer-consumer pipeline completed resume batch: {len(batch_results)} chunks")
286
+ # Pipeline already handled progress logging internally
287
+
288
+ except Exception as e:
289
+ logging.error(f"❌ Producer-consumer pipeline failed in resume: {e}")
290
+ if not ENABLE_PIPELINE_FALLBACK:
291
+ raise
292
+ batch_results = [] # Clear failed results
293
+
294
+ # Fallback to original sequential processing if pipeline disabled or failed
295
+ if not batch_results:
296
+ print(f"🔄 Using sequential processing fallback for resume batch {actual_start_chunk}-{actual_end_chunk}")
297
+ futures = []
298
+
299
+ with ThreadPoolExecutor(max_workers=optimal_workers) as executor:
300
+ for i, chunk_data in enumerate(batch_chunks):
301
+ global_chunk_index = chunk_offset + i
302
+
303
+ # Check for shutdown request
304
+ if shutdown_requested:
305
+ print(f"\n⏹️ {YELLOW}Stopping submission of new chunks...{RESET}")
306
+ break
307
+
308
+ chunk = chunk_data["text"]
309
+ all_chunk_texts = [cd["text"] for cd in all_chunks]
310
+ boundary_type = chunk_data.get("boundary_type", "none")
311
+
312
+ futures.append(executor.submit(
313
+ process_one_chunk,
314
+ global_chunk_index, chunk, text_chunks_dir, audio_chunks_dir,
315
+ voice_path, tts_params, start_time, total_chunks,
316
+ punc_norm, book_dir.name, log_run, log_path, device,
317
+ model, asr_model, all_chunk_texts, boundary_type
318
+ ))
319
+
320
+ # Wait for batch to complete
321
+ print(f"🔄 {CYAN}Waiting for batch {actual_start_chunk}-{actual_end_chunk} to complete...{RESET}")
322
+ completed_count = 0
323
+
324
+ for fut in as_completed(futures):
325
+ try:
326
+ idx, wav_path = fut.result()
327
+ if wav_path and wav_path.exists():
328
+ # Measure actual audio duration for this chunk
329
+ chunk_duration = get_chunk_audio_duration(wav_path)
330
+ total_audio_duration += chunk_duration
331
+ batch_results.append((idx, wav_path))
332
+
333
+ # Update progress every 10 chunks within batch
334
+ completed_count += 1
335
+ if completed_count % 10 == 0:
336
+ current_chunk = chunk_offset + completed_count
337
+ log_chunk_progress(current_chunk - 1, total_chunks, start_time, total_audio_duration)
338
+
339
+ except Exception as e:
340
+ logging.error(f"Future failed in batch: {e}")
341
+
342
+ # Clean up model after batch
343
+ print(f"🧹 Cleaning up after batch {actual_start_chunk}-{actual_end_chunk}")
344
+ del model
345
+ if asr_model:
346
+ from modules.asr_manager import cleanup_asr_model
347
+ cleanup_asr_model(asr_model)
348
+ torch.cuda.empty_cache()
349
+ import gc
350
+ gc.collect()
351
+ time.sleep(2)
352
+
353
+ all_results.extend(batch_results)
354
+ print(f"✅ Batch {actual_start_chunk}-{actual_end_chunk} completed ({len(batch_results)} chunks)")
355
+
356
+ # Final processing - combine ALL chunks (existing + new)
357
+ quarantine_dir = audio_chunks_dir / "quarantine"
358
+ pause_for_chunk_review(quarantine_dir)
359
+
360
+ # Collect ALL chunk paths (both existing and newly created)
361
+ chunk_paths = []
362
+ for i in range(total_chunks):
363
+ chunk_path = audio_chunks_dir / f"chunk_{i+1:05}.wav"
364
+ if chunk_path.exists():
365
+ chunk_paths.append(chunk_path)
366
+ else:
367
+ logging.warning(f"Missing chunk file: chunk_{i+1:05}.wav")
368
+
369
+ if not chunk_paths:
370
+ logging.info(f"{RED}❌ No valid audio chunks found. Skipping concatenation and conversion.{RESET}")
371
+ return None, None, []
372
+
373
+ print(f"📊 Found {len(chunk_paths)} total chunks for final audiobook")
374
+
375
+ # Calculate timing
376
+ elapsed_total = time.time() - start_time
377
+ elapsed_td = timedelta(seconds=int(elapsed_total))
378
+
379
+ # Get total audio duration from ALL chunks
380
+ total_audio_duration_final = sum(get_chunk_audio_duration(chunk_path) for chunk_path in chunk_paths)
381
+ audio_duration_td = timedelta(seconds=int(total_audio_duration_final))
382
+ realtime_factor = total_audio_duration_final / elapsed_total if elapsed_total > 0 else 0.0
383
+
384
+ print(f"\n⏱️ Resume Processing Complete:")
385
+ print(f" Elapsed Time: {CYAN}{str(elapsed_td)}{RESET}")
386
+ print(f" Audio Duration: {GREEN}{str(audio_duration_td)}{RESET}")
387
+ print(f" Realtime Factor: {YELLOW}{realtime_factor:.2f}x{RESET}")
388
+
389
+ # Combine audio
390
+ combined_wav_path = output_root / f"{book_dir.name} [{voice_path.stem}].wav"
391
+ print("\n💾 Saving WAV file...")
392
+ combine_audio_chunks(chunk_paths, combined_wav_path)
393
+
394
+ # M4B conversion
395
+ temp_m4b_path = output_root / "output.m4b"
396
+ final_m4b_path = output_root / f"{book_dir.name}[{voice_path.stem}].m4b"
397
+ convert_to_m4b(combined_wav_path, temp_m4b_path)
398
+ add_metadata_to_m4b(temp_m4b_path, final_m4b_path, cover_file, nfo_file)
399
+
400
+ logging.info(f"Audiobook created: {final_m4b_path}")
401
+
402
+ # Append final completion info
403
+ completion_log = [
404
+ f"\n--- Resume Processing Complete ---",
405
+ f"Completed: {time.strftime('%Y-%m-%d %H:%M:%S')}",
406
+ f"Processing Time: {str(elapsed_td)}",
407
+ f"Audio Duration: {str(audio_duration_td)}",
408
+ f"Realtime Factor: {realtime_factor:.2f}x",
409
+ f"Total Chunks: {len(chunk_paths)}",
410
+ f"Combined WAV: {combined_wav_path}",
411
+ f"Final M4B: {final_m4b_path}"
412
+ ]
413
+
414
+ # Append to existing log
415
+ log_run("\n".join(completion_log), output_root / "run.log")
416
+ print(f"📝 Final completion info appended to: {output_root / 'run.log'}")
417
+
418
+ return final_m4b_path, combined_wav_path, run_log_lines
419
+
420
+ def resume_book_from_chunk(start_chunk):
421
+ """Interactive resume function for stuck book"""
422
+ print(f"\n🔄 Resume Book Processing from Chunk {start_chunk}")
423
+ print("=" * 50)
424
+
425
+ # Show available books from Audiobook directory (books that have started processing)
426
+ audiobook_root = Path(AUDIOBOOK_ROOT)
427
+ if not audiobook_root.exists():
428
+ print(f"{RED}No audiobook directory found at {AUDIOBOOK_ROOT}.{RESET}")
429
+ return None
430
+
431
+ book_dirs = sorted([d for d in audiobook_root.iterdir() if d.is_dir() and d.name != "Audio_Revisions"])
432
+ if not book_dirs:
433
+ print(f"{RED}No books found in {AUDIOBOOK_ROOT}/ - no books have started processing.{RESET}")
434
+ print(f"💡 Use Option 1 to start processing a new book first.")
435
+ return None
436
+
437
+ print("Available books (in progress or completed):")
438
+ for i, book_dir in enumerate(book_dirs):
439
+ # All books in Audiobook/ should have processing data
440
+ audio_chunks_dir = book_dir / "TTS" / "audio_chunks"
441
+ if audio_chunks_dir.exists():
442
+ last_chunk, missing = analyze_existing_chunks(audio_chunks_dir)
443
+ if missing:
444
+ status = f"(last chunk: {last_chunk}, {len(missing)} missing)"
445
+ else:
446
+ status = f"(completed: {last_chunk} chunks)"
447
+ else:
448
+ status = "(processing started but no chunks yet)"
449
+
450
+ print(f" [{i}] {book_dir.name} {status}")
451
+
452
+ while True:
453
+ try:
454
+ book_idx = int(input("Select book index: "))
455
+ if 0 <= book_idx < len(book_dirs):
456
+ audiobook_dir = book_dirs[book_idx]
457
+ # Find corresponding Text_Input directory
458
+ text_input_book_dir = TEXT_INPUT_ROOT / audiobook_dir.name
459
+ if text_input_book_dir.exists():
460
+ book_dir = text_input_book_dir
461
+ else:
462
+ print(f"❌ Text_Input directory not found for {audiobook_dir.name}")
463
+ print(f"💡 The original book files may have been moved or deleted.")
464
+ continue
465
+ break
466
+ except Exception:
467
+ pass
468
+ print("Invalid selection. Try again.")
469
+
470
+ # Analyze existing chunks for selected book
471
+ audiobook_dir = AUDIOBOOK_ROOT / book_dir.name
472
+ if audiobook_dir.exists():
473
+ audio_chunks_dir = audiobook_dir / "TTS" / "audio_chunks"
474
+ if audio_chunks_dir.exists():
475
+ last_chunk, missing = analyze_existing_chunks(audio_chunks_dir)
476
+ suggested_resume = suggest_resume_point(last_chunk, missing)
477
+
478
+ print(f"\nSuggested resume point: {GREEN}{suggested_resume}{RESET}")
479
+
480
+ # Allow user to override
481
+ user_input = input(f"Resume from chunk [{suggested_resume}]: ").strip()
482
+ if user_input:
483
+ try:
484
+ start_chunk = int(user_input)
485
+ except ValueError:
486
+ print(f"Invalid input, using suggested: {suggested_resume}")
487
+ start_chunk = suggested_resume
488
+ else:
489
+ start_chunk = suggested_resume
490
+
491
+ # Show available voices
492
+ voice_files = list_voice_samples()
493
+ if not voice_files:
494
+ print(f"{RED}No voice samples found.{RESET}")
495
+ return None
496
+
497
+ print("\nAvailable voices:")
498
+ for i, voice in enumerate(voice_files):
499
+ print(f" [{i}] {voice.name}")
500
+
501
+ while True:
502
+ try:
503
+ voice_idx = int(input("Select voice index: "))
504
+ if 0 <= voice_idx < len(voice_files):
505
+ voice_path = voice_files[voice_idx]
506
+ break
507
+ except Exception:
508
+ pass
509
+ print("Invalid selection. Try again.")
510
+
511
+ # Get TTS parameters
512
+ def prompt_float(prompt, default):
513
+ val = input(f"{prompt} [{default}]: ").strip()
514
+ return float(val) if val else default
515
+
516
+ exaggeration = prompt_float("Enter exaggeration (emotion intensity)", DEFAULT_EXAGGERATION)
517
+ cfg_weight = prompt_float("Enter cfg_weight (faithfulness to text)", DEFAULT_CFG_WEIGHT)
518
+ temperature = prompt_float("Enter temperature (randomness)", DEFAULT_TEMPERATURE)
519
+
520
+ tts_params = dict(exaggeration=exaggeration, cfg_weight=cfg_weight, temperature=temperature)
521
+
522
+ # Determine device with proper validation
523
+ from modules.tts_engine import get_best_available_device
524
+ device = get_best_available_device()
525
+
526
+ print(f"\n🚀 Resuming {book_dir.name} from chunk {start_chunk}")
527
+ print(f"🎤 Voice: {voice_path.name}")
528
+ print(f"⚙️ Parameters: {tts_params}")
529
+
530
+ # Process with resume
531
+ return process_book_folder_resume(book_dir, voice_path, tts_params, device, start_chunk)
532
+
533
+ def find_incomplete_books():
534
+ """Find books that appear to be incomplete"""
535
+ incomplete_books = []
536
+
537
+ for book_dir in TEXT_INPUT_ROOT.iterdir():
538
+ if not book_dir.is_dir():
539
+ continue
540
+
541
+ audiobook_dir = AUDIOBOOK_ROOT / book_dir.name
542
+ if not audiobook_dir.exists():
543
+ continue
544
+
545
+ audio_chunks_dir = audiobook_dir / "TTS" / "audio_chunks"
546
+ if not audio_chunks_dir.exists():
547
+ continue
548
+
549
+ # Check if there's a final M4B
550
+ m4b_files = list(audiobook_dir.glob("*.m4b"))
551
+ wav_files = list(audiobook_dir.glob("*.wav"))
552
+
553
+ if not m4b_files and not wav_files:
554
+ # No final output, likely incomplete
555
+ last_chunk, missing = analyze_existing_chunks(audio_chunks_dir)
556
+ if last_chunk > 0:
557
+ incomplete_books.append({
558
+ "name": book_dir.name,
559
+ "last_chunk": last_chunk,
560
+ "missing_chunks": len(missing),
561
+ "path": book_dir
562
+ })
563
+
564
+ return incomplete_books
565
+
566
+ def auto_resume_incomplete():
567
+ """Automatically suggest resume for incomplete books"""
568
+ incomplete = find_incomplete_books()
569
+
570
+ if not incomplete:
571
+ print(f"{GREEN}✅ No incomplete books found!{RESET}")
572
+ return
573
+
574
+ print(f"{YELLOW}📋 Found {len(incomplete)} incomplete books:{RESET}")
575
+ for i, book in enumerate(incomplete):
576
+ print(f" [{i}] {book['name']} (last chunk: {book['last_chunk']}, missing: {book['missing_chunks']})")
577
+
578
+ choice = input(f"\nSelect book to resume [0-{len(incomplete)-1}] or 'q' to quit: ").strip()
579
+
580
+ if choice.lower() == 'q':
581
+ return
582
+
583
+ try:
584
+ idx = int(choice)
585
+ if 0 <= idx < len(incomplete):
586
+ selected_book = incomplete[idx]
587
+ suggested_resume = selected_book['last_chunk'] + 1
588
+
589
+ print(f"\n🎯 Selected: {selected_book['name']}")
590
+ print(f"💡 Suggested resume point: chunk {suggested_resume}")
591
+
592
+ return resume_book_from_chunk(suggested_resume)
593
+ except ValueError:
594
+ print("Invalid selection.")
595
+
596
+ return None
HF_Deploy/modules/system_detector.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ System Resource Detection Module
3
+ Detects VRAM, RAM, CPU cores and recommends appropriate ASR models
4
+ """
5
+
6
+ import psutil
7
+ import torch
8
+ import os
9
+ import sys
10
+ from pathlib import Path
11
+
12
+ # Add project root to path for imports
13
+ if __name__ == "__main__":
14
+ sys.path.insert(0, str(Path(__file__).parent.parent))
15
+
16
+ from config.config import ASR_MODEL_VRAM_MB, ASR_MODEL_RAM_MB
17
+
18
+ def get_gpu_memory():
19
+ """Get total and available GPU memory in MB"""
20
+ try:
21
+ if torch.cuda.is_available():
22
+ gpu_count = torch.cuda.device_count()
23
+ if gpu_count > 0:
24
+ # Use first GPU
25
+ total_vram = torch.cuda.get_device_properties(0).total_memory
26
+ allocated_vram = torch.cuda.memory_allocated(0)
27
+ available_vram = total_vram - allocated_vram
28
+
29
+ return {
30
+ 'total_mb': total_vram // 1024 // 1024,
31
+ 'available_mb': available_vram // 1024 // 1024,
32
+ 'allocated_mb': allocated_vram // 1024 // 1024
33
+ }
34
+ except:
35
+ pass
36
+
37
+ return {'total_mb': 0, 'available_mb': 0, 'allocated_mb': 0}
38
+
39
+ def get_system_memory():
40
+ """Get total and available system RAM in MB"""
41
+ try:
42
+ memory = psutil.virtual_memory()
43
+ return {
44
+ 'total_mb': memory.total // 1024 // 1024,
45
+ 'available_mb': memory.available // 1024 // 1024,
46
+ 'used_mb': memory.used // 1024 // 1024
47
+ }
48
+ except:
49
+ return {'total_mb': 0, 'available_mb': 0, 'used_mb': 0}
50
+
51
+ def get_cpu_cores():
52
+ """Get number of CPU cores"""
53
+ try:
54
+ return psutil.cpu_count(logical=False) or psutil.cpu_count()
55
+ except:
56
+ return 1
57
+
58
+ def estimate_tts_vram_usage():
59
+ """Estimate VRAM usage by ChatterboxTTS (updated based on real usage)"""
60
+ return 5500 # 5.5GB in MB (was 7GB, adjusted based on actual 3.5GB usage + buffer)
61
+
62
+ def get_system_profile():
63
+ """Get complete system resource profile"""
64
+ gpu_info = get_gpu_memory()
65
+ ram_info = get_system_memory()
66
+ cpu_cores = get_cpu_cores()
67
+
68
+ # Estimate available resources after TTS loading
69
+ tts_vram_estimate = estimate_tts_vram_usage()
70
+ available_vram_after_tts = max(0, gpu_info['available_mb'] - tts_vram_estimate)
71
+
72
+ return {
73
+ 'gpu': gpu_info,
74
+ 'ram': ram_info,
75
+ 'cpu_cores': cpu_cores,
76
+ 'available_vram_after_tts': available_vram_after_tts,
77
+ 'has_gpu': gpu_info['total_mb'] > 0
78
+ }
79
+
80
+ def categorize_system(profile):
81
+ """Categorize system capabilities"""
82
+ gpu_total = profile['gpu']['total_mb']
83
+ ram_total = profile['ram']['total_mb']
84
+ cpu_cores = profile['cpu_cores']
85
+
86
+ # VRAM categories
87
+ if gpu_total < 4000:
88
+ vram_category = "low"
89
+ elif gpu_total <= 12000:
90
+ vram_category = "medium"
91
+ else:
92
+ vram_category = "high"
93
+
94
+ # RAM categories
95
+ if ram_total < 16000:
96
+ ram_category = "low"
97
+ elif ram_total <= 64000:
98
+ ram_category = "medium"
99
+ else:
100
+ ram_category = "high"
101
+
102
+ # CPU categories
103
+ if cpu_cores < 6:
104
+ cpu_category = "low"
105
+ elif cpu_cores <= 16:
106
+ cpu_category = "medium"
107
+ else:
108
+ cpu_category = "high"
109
+
110
+ return {
111
+ 'vram': vram_category,
112
+ 'ram': ram_category,
113
+ 'cpu': cpu_category
114
+ }
115
+
116
+ def get_safe_asr_models(profile):
117
+ """Get ASR models that can safely run on GPU with available VRAM"""
118
+ available_vram = profile['available_vram_after_tts']
119
+ safe_models = []
120
+
121
+ for model, vram_req in ASR_MODEL_VRAM_MB.items():
122
+ if vram_req <= available_vram:
123
+ safe_models.append(model)
124
+
125
+ return safe_models
126
+
127
+ def get_safe_cpu_models(profile):
128
+ """Get ASR models that can safely run on CPU with available RAM"""
129
+ available_ram = profile['ram']['available_mb']
130
+ safe_models = []
131
+
132
+ for model, ram_req in ASR_MODEL_RAM_MB.items():
133
+ if ram_req <= available_ram:
134
+ safe_models.append(model)
135
+
136
+ return safe_models
137
+
138
+ def recommend_asr_models(profile):
139
+ """Recommend Safe/Moderate/Insane ASR model configurations"""
140
+ categories = categorize_system(profile)
141
+ safe_gpu_models = get_safe_asr_models(profile)
142
+ safe_cpu_models = get_safe_cpu_models(profile)
143
+
144
+ recommendations = {}
145
+
146
+ # Model priority order (best to worst)
147
+ model_priority = ["large-v3", "large", "large-v2", "medium", "small", "base", "tiny"]
148
+
149
+ # Safe: Conservative choice
150
+ safe_gpu = None
151
+ safe_cpu = None
152
+
153
+ for model in reversed(model_priority): # Start from smallest
154
+ if model in safe_gpu_models and not safe_gpu:
155
+ safe_gpu = model
156
+ if model in safe_cpu_models and not safe_cpu:
157
+ safe_cpu = model
158
+ if safe_gpu and safe_cpu:
159
+ break
160
+
161
+ # Moderate: Balanced choice
162
+ moderate_gpu = None
163
+ moderate_cpu = None
164
+
165
+ # Try to get a model 1-2 steps up from safe
166
+ safe_idx = model_priority.index(safe_gpu) if safe_gpu else len(model_priority)
167
+ moderate_idx = max(0, safe_idx - 2)
168
+
169
+ for i in range(moderate_idx, len(model_priority)):
170
+ model = model_priority[i]
171
+ if model in safe_gpu_models and not moderate_gpu:
172
+ moderate_gpu = model
173
+ if model in safe_cpu_models and not moderate_cpu:
174
+ moderate_cpu = model
175
+ if moderate_gpu and moderate_cpu:
176
+ break
177
+
178
+ # Insane: Push the limits (best available models)
179
+ insane_gpu = None
180
+ insane_cpu = None
181
+
182
+ # Get the best (largest) models that are safe
183
+ for model in model_priority: # Start from best
184
+ if model in safe_gpu_models and not insane_gpu:
185
+ insane_gpu = model
186
+ if model in safe_cpu_models and not insane_cpu:
187
+ insane_cpu = model
188
+ if insane_gpu and insane_cpu:
189
+ break
190
+
191
+ # Build recommendations
192
+ recommendations['safe'] = {
193
+ 'primary': {'model': safe_gpu or safe_cpu, 'device': 'gpu' if safe_gpu else 'cpu'},
194
+ 'fallback': {'model': safe_cpu, 'device': 'cpu'}
195
+ }
196
+
197
+ recommendations['moderate'] = {
198
+ 'primary': {'model': moderate_gpu or moderate_cpu, 'device': 'gpu' if moderate_gpu else 'cpu'},
199
+ 'fallback': {'model': moderate_cpu, 'device': 'cpu'}
200
+ }
201
+
202
+ recommendations['insane'] = {
203
+ 'primary': {'model': insane_gpu or insane_cpu, 'device': 'gpu' if insane_gpu else 'cpu'},
204
+ 'fallback': {'model': insane_cpu, 'device': 'cpu'}
205
+ }
206
+
207
+ return recommendations
208
+
209
+ def print_system_summary(profile):
210
+ """Print a human-readable system summary"""
211
+ categories = categorize_system(profile)
212
+
213
+ print(f"🖥️ System Profile:")
214
+ print(f" VRAM: {profile['gpu']['total_mb']:,}MB total, {profile['available_vram_after_tts']:,}MB available after TTS ({categories['vram']} class)")
215
+ print(f" RAM: {profile['ram']['total_mb']:,}MB total, {profile['ram']['available_mb']:,}MB available ({categories['ram']} class)")
216
+ print(f" CPU: {profile['cpu_cores']} cores ({categories['cpu']} class)")
217
+
218
+ if not profile['has_gpu']:
219
+ print(f" ⚠️ No CUDA GPU detected - ASR will run on CPU only")
220
+
221
+ if __name__ == "__main__":
222
+ # Test the detection
223
+ profile = get_system_profile()
224
+ print_system_summary(profile)
225
+
226
+ recommendations = recommend_asr_models(profile)
227
+ print(f"\nASR Model Recommendations:")
228
+ for level, config in recommendations.items():
229
+ primary = config['primary']
230
+ fallback = config['fallback']
231
+ print(f"🟢 {level.upper()}: {primary['model']} ({primary['device']}) + {fallback['model']} (cpu fallback)")
HF_Deploy/modules/text_processor.py ADDED
@@ -0,0 +1,745 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Text Processing Module
3
+ Handles text chunking, abbreviations, and preprocessing for TTS
4
+ """
5
+
6
+ import re
7
+ import logging
8
+ from pathlib import Path
9
+ from config.config import MAX_CHUNK_WORDS, MIN_CHUNK_WORDS, YELLOW, RESET
10
+
11
+
12
+
13
+ # ============================================================================
14
+ # ABBREVIATION REPLACEMENT SYSTEM
15
+ # ============================================================================
16
+
17
+ def load_abbreviations(file_path="utils/abbreviations.txt"):
18
+ """Load abbreviation replacements from external file"""
19
+ replacements = {}
20
+ abbrev_file = Path(file_path)
21
+
22
+ if not abbrev_file.exists():
23
+ print(f"⚠️ {YELLOW}Abbreviations file not found: {file_path}{RESET}")
24
+ print(f"📝 Creating sample file...")
25
+ create_sample_abbreviations_file(abbrev_file)
26
+ return replacements
27
+
28
+ try:
29
+ with open(abbrev_file, 'r', encoding='utf-8') as f:
30
+ for line_num, line in enumerate(f, 1):
31
+ line = line.strip()
32
+
33
+ # Skip empty lines and comments
34
+ if not line or line.startswith('#'):
35
+ continue
36
+
37
+ # Parse "abbrev -> replacement" format
38
+ if ' -> ' in line:
39
+ abbrev, replacement = line.split(' -> ', 1)
40
+ replacements[abbrev.strip()] = replacement.strip()
41
+ else:
42
+ print(f"⚠️ Invalid format on line {line_num}: {line}")
43
+
44
+ print(f"✅ Loaded {len(replacements)} abbreviation replacements from {file_path}")
45
+
46
+ except Exception as e:
47
+ print(f"❌ Error loading abbreviations: {e}")
48
+
49
+ return replacements
50
+
51
+ def create_sample_abbreviations_file(file_path):
52
+ """Create a sample abbreviations file with common replacements"""
53
+ sample_content = """# Abbreviation Replacements for TTS
54
+ # Format: abbreviation -> replacement
55
+ # Lines starting with # are comments
56
+
57
+ # Common titles and abbreviations
58
+ Dr. -> Doctor
59
+ Mr. -> Mister
60
+ Mrs. -> Missus
61
+ Ms. -> Miss
62
+ Prof. -> Professor
63
+ Rev. -> Reverend
64
+ Lt. -> Lieutenant
65
+ Capt. -> Captain
66
+ Gen. -> General
67
+ Col. -> Colonel
68
+ Jr. -> Junior
69
+ Sr. -> Senior
70
+
71
+ # Political and organizations
72
+ M.P. -> MP
73
+ U.S. -> US
74
+ U.K. -> UK
75
+ U.N. -> UN
76
+ F.B.I. -> FBI
77
+ C.I.A. -> CIA
78
+ N.A.S.A. -> NASA
79
+
80
+ # Common abbreviations
81
+ etc. -> et cetera
82
+ vs. -> versus
83
+ e.g. -> for example
84
+ i.e. -> that is
85
+ Inc. -> Incorporated
86
+ Corp. -> Corporation
87
+ Ltd. -> Limited
88
+ Co. -> Company
89
+
90
+ # Numbers and ordinals
91
+ 1st -> first
92
+ 2nd -> second
93
+ 3rd -> third
94
+ 4th -> fourth
95
+ 5th -> fifth
96
+ 10th -> tenth
97
+ 20th -> twentieth
98
+ 21st -> twenty-first
99
+ 30th -> thirtieth
100
+ 40th -> fortieth
101
+ 50th -> fiftieth
102
+ 60th -> sixtieth
103
+ 70th -> seventieth
104
+ 80th -> eightieth
105
+ 90th -> ninetieth
106
+ 100th -> one hundredth
107
+
108
+ # Time abbreviations
109
+ a.m. -> AM
110
+ p.m. -> PM
111
+ A.M. -> AM
112
+ P.M. -> PM
113
+ """
114
+
115
+ try:
116
+ with open(file_path, 'w', encoding='utf-8') as f:
117
+ f.write(sample_content)
118
+ print(f"📝 Created sample abbreviations file: {file_path}")
119
+ print(f"💡 Edit this file to add your own replacements!")
120
+ except Exception as e:
121
+ print(f"❌ Error creating sample file: {e}")
122
+
123
+ def preprocess_abbreviations(text, replacements):
124
+ """Replace abbreviations with TTS-friendly versions"""
125
+ if not replacements:
126
+ return text
127
+
128
+ original_text = text
129
+ replacements_made = 0
130
+
131
+ # Apply replacements (order matters for overlapping patterns)
132
+ for abbrev, replacement in replacements.items():
133
+ if abbrev in text:
134
+ text = text.replace(abbrev, replacement)
135
+ replacements_made += 1
136
+
137
+ if replacements_made > 0:
138
+ logging.info(f"📝 Applied {replacements_made} abbreviation replacements")
139
+
140
+ return text
141
+
142
+ # ============================================================================
143
+ # TEXT PREPROCESSING AND CHUNKING
144
+ # ============================================================================
145
+
146
+ def smart_punctuate(text):
147
+ """
148
+ Enhanced punctuation normalization with abbreviation replacement.
149
+
150
+ PROCESSING REQUIREMENTS:
151
+ - Load and apply abbreviation replacements (Dr. -> Doctor, etc.)
152
+ - Add periods to lines that don't end with punctuation
153
+ - Replace Unicode smart quotes with ASCII quotes (", ')
154
+ - Remove problematic formatting (bold markdown, underlines)
155
+ - Preserve paragraph breaks (empty lines)
156
+
157
+ This prepares text for consistent TTS processing.
158
+ """
159
+
160
+ # Load abbreviations and apply them
161
+ abbreviation_replacements = load_abbreviations()
162
+ text = preprocess_abbreviations(text, abbreviation_replacements)
163
+
164
+ # Then continue with existing punctuation logic
165
+ lines = text.splitlines()
166
+ out = []
167
+
168
+ for l in lines:
169
+ stripped = l.strip()
170
+
171
+ # Preserve empty lines (paragraph breaks)
172
+ if not stripped:
173
+ out.append("") # Keep the blank line
174
+ # Process non-empty lines
175
+ elif not re.search(r'[.!?]$', stripped) and not re.search(r'[.!?]["\']$', stripped):
176
+ out.append(stripped + ".")
177
+ else:
178
+ out.append(stripped)
179
+
180
+ result = "\n".join(out)
181
+
182
+ # Enhanced text preprocessing - replace curly quotes with straight quotes
183
+ result = result.replace('\u201c', '"').replace('\u201d', '"') # Replace smart double quotes " "
184
+ result = result.replace('\u2018', "'").replace('\u2019', "'") # Replace smart single quotes ' '
185
+
186
+ # Remove problematic formatting
187
+ result = re.sub(r'\*\*([^*]+)\*\*', r'\1', result) # Remove bold markdown
188
+ result = re.sub(r'_{2,}', '', result) # Remove underlines
189
+
190
+ # Fix any escaped quotes that might appear in the text
191
+ result = result.replace('\\"', '"').replace("\\'", "'")
192
+
193
+ # Additional quote normalization to prevent recurring dialogue corruption
194
+ result = re.sub(r'(["\'])\s*,\s*(["\'])', r'\1, \2', result) # Fix quote spacing around commas
195
+ result = re.sub(r'(["\'])\s*\.\s*(["\'])', r'\1. \2', result) # Fix quote spacing around periods
196
+ result = re.sub(r'(["\'])\s*([,.])\s*(["\'])\s*([,.])', r'\1\2 \3', result) # Remove duplicate punctuation
197
+
198
+ # Debug logging for dialogue patterns
199
+ if '"' in result and ('replied' in result or 'said' in result):
200
+ print(f"🗣️ DEBUG: Dialogue detected in smart_punctuate: {result[:100]}...")
201
+
202
+ return result
203
+
204
+ def fix_short_sentence_artifacts(chunk_text):
205
+ """
206
+ Fix multiple short sentences that cause TTS errors.
207
+ Example: "Yes. No. Maybe." → "Yes, no, maybe."
208
+ "Right." → "Right," (if it's a single-word chunk)
209
+ """
210
+ # Handle full chunk that is just one short sentence
211
+ words = chunk_text.strip().split()
212
+ if len(words) == 1 and chunk_text.strip().endswith('.'):
213
+ return chunk_text.strip()[:-1] + ',' # Replace period with comma
214
+
215
+ parts = re.split(r'([.!?])', chunk_text.strip())
216
+ if len(parts) < 2:
217
+ return chunk_text # nothing to fix
218
+
219
+ # Reconstruct sentence-punctuation pairs
220
+ sentences = []
221
+ for i in range(0, len(parts)-1, 2):
222
+ sentence = parts[i].strip()
223
+ punct = parts[i+1]
224
+ if sentence:
225
+ word_count = len(sentence.split())
226
+ sentences.append((sentence, punct, word_count))
227
+
228
+ # Handle multiple short sentences
229
+ short_count = sum(1 for _, _, wc in sentences if wc <= 3)
230
+
231
+ if short_count >= 2 and len(sentences) >= 2:
232
+ merged = ", ".join(s for s, _, _ in sentences) + "."
233
+ return merged
234
+
235
+ # Handle case where first sentence is a single word
236
+ if len(sentences) >= 2 and sentences[0][2] == 1 and sentences[0][1] == ".":
237
+ # Replace period with comma
238
+ first, second = sentences[0][0], sentences[1][0]
239
+ rest = " ".join(s for s, _, _ in sentences[2:])
240
+ new_text = f"{first}, {second}"
241
+ if rest:
242
+ new_text += " " + rest
243
+ return new_text
244
+
245
+ return chunk_text
246
+
247
+ def _is_apostrophe(text, pos):
248
+ """Check if a single quote at position pos is likely an apostrophe (not speech quote)"""
249
+ if pos == 0 or pos >= len(text) - 1:
250
+ return False
251
+
252
+ # Check characters before and after
253
+ before = text[pos - 1] if pos > 0 else ' '
254
+ after = text[pos + 1] if pos < len(text) - 1 else ' '
255
+
256
+ # It's likely an apostrophe if:
257
+ # 1. Preceded and followed by letters (contractions like "don't", possessives like "John's")
258
+ # 2. Or preceded by letter and followed by 's' or 't' (common contractions)
259
+ if before.isalpha() and after.isalpha():
260
+ return True
261
+ if before.isalpha() and after in 's':
262
+ return True
263
+
264
+ return False
265
+
266
+ def sentence_chunk_text(text, max_words=MAX_CHUNK_WORDS, min_words=MIN_CHUNK_WORDS):
267
+ """
268
+ Simple and reliable text chunking that follows the exact rules:
269
+
270
+ TEXT CHUNKING RULES:
271
+ 1. Break at sentence boundaries (. ! ?) first (highest priority)
272
+ 2. If sentence > max_words, break at punctuation working backwards (; — , in that order)
273
+ 3. If no punctuation available, preserve sentence intact to maintain coherence
274
+ 4. Ensure all chunks meet min_words requirement by combining small chunks
275
+
276
+ PUNCTUATION HIERARCHY (for breaking long sentences):
277
+ 1. . ! ? (sentence boundaries) - handled at sentence level
278
+ 2. ; (semicolon) - major pause
279
+ 3. — – (dashes) - major pause
280
+ 4. , (comma) - minor pause
281
+ 5. Preserve overlong sentences if no punctuation available
282
+ """
283
+ import re
284
+
285
+ # Process text paragraph by paragraph to preserve structure
286
+ paragraphs = text.split('\n\n')
287
+ all_final_chunks = []
288
+
289
+ for paragraph in paragraphs:
290
+ paragraph = paragraph.strip()
291
+ if not paragraph:
292
+ continue
293
+
294
+ # Check if this is a chapter/section header
295
+ para_lower = paragraph.lower().strip()
296
+ is_chapter_header = (
297
+ any(word in para_lower for word in ['chapter', 'section', 'part', 'prologue', 'epilogue']) and
298
+ len(paragraph.split()) <= 10
299
+ )
300
+
301
+ if is_chapter_header:
302
+ # Chapter headers are their own chunks and always paragraph ends
303
+ all_final_chunks.append((paragraph, True))
304
+ continue
305
+
306
+ # Split into sentences using periods, exclamation marks, question marks
307
+ # This avoids the complex quote detection that was causing problems
308
+ sentences = re.split(r'([.!?])\s+', paragraph.strip())
309
+
310
+ # Reconstruct sentences with their punctuation
311
+ reconstructed_sentences = []
312
+ for i in range(0, len(sentences) - 1, 2):
313
+ sentence = sentences[i].strip()
314
+ if i + 1 < len(sentences):
315
+ punct = sentences[i + 1]
316
+ sentence += punct
317
+ if sentence:
318
+ reconstructed_sentences.append(sentence)
319
+
320
+ # Handle any remaining text (no ending punctuation)
321
+ if sentences and sentences[-1].strip():
322
+ last_part = sentences[-1].strip()
323
+ if last_part and not last_part in '.!?':
324
+ reconstructed_sentences.append(last_part)
325
+
326
+ # Process each sentence
327
+ paragraph_chunks = []
328
+ for sent_idx, sentence in enumerate(reconstructed_sentences):
329
+ is_last_sentence = (sent_idx == len(reconstructed_sentences) - 1)
330
+ words = sentence.split()
331
+
332
+ if len(words) <= max_words:
333
+ # Sentence fits, use as-is
334
+ paragraph_chunks.append((sentence.strip(), is_last_sentence))
335
+ else:
336
+ # Sentence too long, break it using punctuation
337
+ broken_chunks = _break_long_sentence_simple(sentence, max_words)
338
+ # Only mark the last broken chunk as sentence end
339
+ for i, chunk in enumerate(broken_chunks):
340
+ is_chunk_end = (is_last_sentence and i == len(broken_chunks) - 1)
341
+ paragraph_chunks.append((chunk.strip(), is_chunk_end))
342
+
343
+ all_final_chunks.extend(paragraph_chunks)
344
+
345
+ # Combine small chunks that don't meet min_words requirement
346
+ combined_chunks = _combine_small_chunks(all_final_chunks, min_words, max_words)
347
+
348
+ return combined_chunks
349
+
350
+ def _break_long_sentence_simple(sentence, max_words):
351
+ """Break a long sentence at punctuation marks, working backwards"""
352
+ import re
353
+
354
+ # Punctuation patterns in priority order
355
+ patterns = [
356
+ r';\s*', # semicolon + optional space
357
+ r'—\s*', # em dash + optional space
358
+ r'–\s*', # en dash + optional space
359
+ r',\s*', # comma + optional space
360
+ ]
361
+
362
+ chunks = []
363
+ remaining = sentence.strip()
364
+
365
+ while remaining:
366
+ words = remaining.split()
367
+ if len(words) <= max_words:
368
+ chunks.append(remaining)
369
+ break
370
+
371
+ # Find best break point working backwards
372
+ best_break = -1
373
+
374
+ # Try each punctuation pattern
375
+ for pattern in patterns:
376
+ matches = list(re.finditer(pattern, remaining))
377
+ if matches:
378
+ # Find rightmost match that results in chunk <= max_words
379
+ for match in reversed(matches):
380
+ test_chunk = remaining[:match.end()].strip()
381
+ if len(test_chunk.split()) <= max_words:
382
+ best_break = match.end()
383
+ break
384
+ if best_break != -1:
385
+ break
386
+
387
+ if best_break != -1:
388
+ # Found good break point
389
+ chunk = remaining[:best_break].strip()
390
+ chunks.append(chunk)
391
+ remaining = remaining[best_break:].strip()
392
+ else:
393
+ # No punctuation found - preserve sentence coherence by keeping it intact
394
+ # This prevents splitting sentences with potentially different sentiment
395
+ chunks.append(remaining)
396
+ break
397
+
398
+ return chunks
399
+
400
+ def _combine_small_chunks(chunks, min_words, max_words):
401
+ """Combine chunks that are too small"""
402
+ combined = []
403
+ current_chunk = ""
404
+ current_is_para_end = False
405
+
406
+ for chunk_text, is_para_end in chunks:
407
+ chunk_words = len(chunk_text.split())
408
+ current_words = len(current_chunk.split()) if current_chunk else 0
409
+
410
+ if not current_chunk:
411
+ # First chunk
412
+ current_chunk = chunk_text
413
+ current_is_para_end = is_para_end
414
+ elif current_words + chunk_words <= max_words:
415
+ # Can combine
416
+ current_chunk = current_chunk + " " + chunk_text
417
+ current_is_para_end = is_para_end # Use the latest para_end flag
418
+ else:
419
+ # Can't combine, flush current and start new
420
+ if current_words >= min_words:
421
+ combined.append((current_chunk, current_is_para_end))
422
+ current_chunk = chunk_text
423
+ current_is_para_end = is_para_end
424
+ else:
425
+ # Current chunk too small, force combine anyway
426
+ current_chunk = current_chunk + " " + chunk_text
427
+ current_is_para_end = is_para_end
428
+
429
+ # Handle remaining chunk
430
+ if current_chunk:
431
+ combined.append((current_chunk, current_is_para_end))
432
+
433
+ return combined
434
+
435
+ def break_long_sentence_backwards(sentence, max_words, min_words):
436
+ """
437
+ Break a long sentence working backwards from the end to find natural punctuation.
438
+
439
+ ALGORITHM:
440
+ 1. Start from sentence end, work backwards to find punctuation within max_words
441
+ 2. Break at the latest (rightmost) punctuation that keeps chunk <= max_words
442
+ 3. This preserves natural pauses and speech rhythm
443
+ 4. Continue processing remaining text normally
444
+
445
+ PUNCTUATION HIERARCHY (in order of preference):
446
+ 1. . ! ? (sentence boundaries) - highest priority
447
+ 2. ; (semicolon) - major pause
448
+ 3. — (em dash) - major pause
449
+ 4. , (comma) - minor pause
450
+ 5. Force break at word limit (last resort)
451
+ """
452
+
453
+ # Punctuation patterns to search for (in order of preference)
454
+ punctuation_patterns = [
455
+ r'[.!?]\s+', # sentence boundaries + required space (highest priority)
456
+ r';\s*', # semicolon + optional space
457
+ r'—\s*', # em dash + optional space
458
+ r'–\s*', # en dash + optional space
459
+ r',\s*', # comma + optional space
460
+ ]
461
+
462
+ chunks = []
463
+ remaining_text = sentence.strip()
464
+
465
+ while remaining_text:
466
+ words = remaining_text.split()
467
+
468
+ if len(words) <= max_words:
469
+ # Remaining text fits within limit
470
+ chunks.append(remaining_text.strip())
471
+ break
472
+
473
+ # Text exceeds max_words - find backwards break point
474
+ # Search for punctuation within the current 'remaining_text' up to max_words
475
+ # We need to find the *last* punctuation mark that results in a chunk <= max_words
476
+ best_break_index = -1 # Index in 'words' list
477
+ best_break_pos_in_text = -1 # Character position in 'remaining_text'
478
+
479
+ # Iterate backwards from max_words down to min_words (or 1 if min_words is very small)
480
+ # to find the latest punctuation that keeps the chunk within limits.
481
+ for i in range(min(max_words, len(words)) -1, 0, -1):
482
+ sub_text = " ".join(words[:i+1]) # Text up to current word
483
+
484
+ found_punctuation = False
485
+ for pattern in punctuation_patterns:
486
+ matches = list(re.finditer(pattern, sub_text))
487
+ if matches:
488
+ # Take the rightmost match in this sub_text
489
+ last_match = matches[-1]
490
+ # Ensure the break is within the max_words limit
491
+ if len(sub_text[:last_match.end()].split()) <= max_words:
492
+ best_break_index = i # Store word index
493
+ best_break_pos_in_text = last_match.end() # Store char position
494
+ found_punctuation = True
495
+ break # Found a good break for this sub_text, move to next i
496
+ if found_punctuation:
497
+ break # Found the best break for the overall chunk, exit outer loop
498
+
499
+ if best_break_pos_in_text != -1:
500
+ # Found punctuation - break after it, keeping punctuation with preceding text
501
+ chunk_text = remaining_text[:best_break_pos_in_text].strip()
502
+ chunks.append(chunk_text)
503
+ remaining_text = remaining_text[best_break_pos_in_text:].strip()
504
+ else:
505
+ # No punctuation found within the desired range - keep sentence intact
506
+ # This preserves sentence coherence over word count limits
507
+ chunks.append(remaining_text.strip())
508
+ break
509
+
510
+ return chunks
511
+
512
+ # ============================================================================
513
+ # CONTENT BOUNDARY DETECTION
514
+ # ============================================================================
515
+
516
+ def detect_punctuation_boundary(chunk_text):
517
+ """
518
+ Detect the ending punctuation of a text chunk for precise silence insertion.
519
+
520
+ Returns specific punctuation boundary types:
521
+ - "comma" -> Brief pause after commas
522
+ - "semicolon" -> Medium pause after semicolons
523
+ - "colon" -> Pause after colons
524
+ - "period" -> Sentence end pause
525
+ - "question_mark" -> Question pause
526
+ - "exclamation" -> Exclamation pause
527
+ - "dash" -> Em dash pause
528
+ - "ellipsis" -> Ellipsis pause (suspense)
529
+ - "quote_end" -> End of quoted speech
530
+ - None -> No specific punctuation detected
531
+ """
532
+ # Strip whitespace and newlines for accurate detection
533
+ text = chunk_text.strip()
534
+
535
+ if not text:
536
+ return None
537
+
538
+ # Check ending punctuation patterns (in order of specificity)
539
+ if text.endswith('...'):
540
+ return "ellipsis"
541
+ elif text.endswith('"') or text.endswith("'"):
542
+ return "quote_end"
543
+ elif text.endswith('!'):
544
+ return "exclamation"
545
+ elif text.endswith('?'):
546
+ return "question_mark"
547
+ elif text.endswith('.'):
548
+ return "period"
549
+ elif text.endswith(':'):
550
+ return "colon"
551
+ elif text.endswith(';'):
552
+ return "semicolon"
553
+ elif text.endswith(','):
554
+ return "comma"
555
+ elif text.endswith('—') or text.endswith('–'):
556
+ return "dash"
557
+
558
+ return None
559
+
560
+ def detect_content_boundaries(chunk_text, chunk_index, all_chunks, is_paragraph_end=False):
561
+ """
562
+ Detect chapter breaks and paragraph endings for appropriate silence insertion.
563
+ Now enhanced with punctuation-specific boundary detection.
564
+
565
+ BOUNDARY DETECTION REQUIREMENTS:
566
+ - Chapter start: "Chapter N", "Ch. N", "I.", "1." patterns
567
+ - Chapter end: Next chunk is a chapter start
568
+ - Section break: Multiple asterisks, hashes, or em-dashes
569
+ - Paragraph end: Detected via chunking process flag or content analysis
570
+ - Punctuation: Specific ending punctuation for precise silence timing
571
+
572
+ Returns boundary_type for silence insertion:
573
+ - "chapter_start" -> Long pause before chapter
574
+ - "chapter_end" -> Long pause after chapter
575
+ - "section_break" -> Medium pause for section breaks
576
+ - "paragraph_end" -> Short pause for paragraph breaks
577
+ - Punctuation types: "comma", "period", "question_mark", etc.
578
+ - None -> No special boundary detected
579
+ """
580
+ boundary_type = None
581
+
582
+ # Chapter detection (flexible patterns)
583
+ chapter_patterns = [
584
+ r'^(Chapter \d+|CHAPTER \d+)',
585
+ r'^(Ch\. \d+|CH\. \d+)',
586
+ r'^\d+\.', # Simple "1." numbering
587
+ r'^[IVX]+\.', # Roman numerals "I.", "II.", etc.
588
+ ]
589
+
590
+ for pattern in chapter_patterns:
591
+ if re.search(pattern, chunk_text.strip(), re.MULTILINE):
592
+ boundary_type = "chapter_start"
593
+ break
594
+
595
+ # Look ahead for chapter start (current chunk ends chapter)
596
+ if chunk_index + 1 < len(all_chunks):
597
+ next_chunk = all_chunks[chunk_index + 1]
598
+ for pattern in chapter_patterns:
599
+ if re.search(pattern, next_chunk.strip()):
600
+ boundary_type = "chapter_end"
601
+ break
602
+
603
+ # Section breaks (asterisks, multiple line breaks)
604
+ if re.search(r'\*{3,}|\#{3,}|—{3,}', chunk_text):
605
+ boundary_type = "section_break"
606
+
607
+ # Paragraph ending detection
608
+ # Use the is_paragraph_end flag from chunking process since newlines are stripped
609
+ if is_paragraph_end and boundary_type is None:
610
+ boundary_type = "paragraph_end"
611
+
612
+ # If no major structural boundary found, check punctuation
613
+ if boundary_type is None:
614
+ boundary_type = detect_punctuation_boundary(chunk_text)
615
+
616
+ return boundary_type
617
+
618
+ def _split_long_dialogue(sentence, max_words, recursion_depth=0):
619
+ """
620
+ Split long dialogue sections that exceed word limits.
621
+ Tries to break at natural points: attribution, internal punctuation, then word boundaries.
622
+ """
623
+ # Prevent infinite recursion
624
+ if recursion_depth > 3:
625
+ # Force word boundary split if recursion gets too deep
626
+ words = sentence.split()
627
+ sentences = []
628
+ start = 0
629
+ while start < len(words):
630
+ end = min(start + max_words, len(words))
631
+ chunk_words = words[start:end]
632
+ sentences.append(' '.join(chunk_words))
633
+ start = end
634
+ return sentences
635
+
636
+ words = sentence.split()
637
+ if len(words) <= max_words:
638
+ return [sentence]
639
+
640
+ sentences = []
641
+
642
+ # Strategy 1: Break at dialogue attribution (he said, she replied, etc.)
643
+ attribution_pattern = r'(\s+(?:he|she|I|they|[A-Z][a-z]+)\s+(?:said|replied|asked|shouted|whispered|continued|added|interrupted)[^.!?]*?[.!?]?\s*)'
644
+ attribution_matches = list(re.finditer(attribution_pattern, sentence, re.IGNORECASE))
645
+
646
+ if attribution_matches:
647
+ start = 0
648
+ for match in attribution_matches:
649
+ # Check if breaking here keeps chunks under limit
650
+ before_attr = sentence[start:match.end()].strip()
651
+ if before_attr and len(before_attr.split()) <= max_words:
652
+ sentences.append(before_attr)
653
+ start = match.end()
654
+
655
+ # Add remaining text
656
+ if start < len(sentence):
657
+ remaining = sentence[start:].strip()
658
+ if remaining:
659
+ if len(remaining.split()) > max_words:
660
+ # Recursively split if still too long, but with depth tracking
661
+ sentences.extend(_split_long_dialogue(remaining, max_words, recursion_depth + 1))
662
+ else:
663
+ sentences.append(remaining)
664
+
665
+ if sentences: # If we successfully split, return result
666
+ return sentences
667
+
668
+ # Strategy 2: Break at internal punctuation (commas, semicolons within quotes)
669
+ punct_pattern = r'([,;:]\s+)'
670
+ parts = re.split(punct_pattern, sentence)
671
+
672
+ current_chunk = ""
673
+ sentences = []
674
+ for i, part in enumerate(parts):
675
+ test_chunk = current_chunk + part
676
+ if len(test_chunk.split()) > max_words and current_chunk:
677
+ sentences.append(current_chunk.strip())
678
+ current_chunk = part
679
+ else:
680
+ current_chunk = test_chunk
681
+
682
+ if current_chunk.strip():
683
+ sentences.append(current_chunk.strip())
684
+
685
+ # Check if any resulting chunk is still too long and needs further splitting
686
+ final_sentences = []
687
+ for chunk in sentences:
688
+ if len(chunk.split()) > max_words:
689
+ # Split oversized chunks using word boundaries
690
+ chunk_words = chunk.split()
691
+ start = 0
692
+ while start < len(chunk_words):
693
+ end = min(start + max_words, len(chunk_words))
694
+ sub_chunk_words = chunk_words[start:end]
695
+ final_sentences.append(' '.join(sub_chunk_words))
696
+ start = end
697
+ else:
698
+ final_sentences.append(chunk)
699
+
700
+ if len(final_sentences) > 1: # If we successfully split, return result
701
+ return final_sentences
702
+
703
+ # Strategy 3: Force break at word boundaries (guaranteed to work)
704
+ sentences = []
705
+ start = 0
706
+ while start < len(words):
707
+ end = min(start + max_words, len(words))
708
+ chunk_words = words[start:end]
709
+ sentences.append(' '.join(chunk_words))
710
+ start = end
711
+
712
+ return sentences
713
+
714
+ # ============================================================================
715
+ # UTILITY FUNCTIONS
716
+ # ============================================================================
717
+
718
+ def reload_abbreviations():
719
+ """Reload abbreviations from file (useful for testing changes)"""
720
+ return load_abbreviations()
721
+
722
+ def test_abbreviations(test_text="Dr. Smith met with the M.P. at 3:30 p.m. on the 21st."):
723
+ """Test abbreviation replacements on sample text"""
724
+ abbreviation_replacements = load_abbreviations()
725
+ print(f"Original: {test_text}")
726
+ processed = preprocess_abbreviations(test_text, abbreviation_replacements)
727
+ print(f"Processed: {processed}")
728
+ return processed
729
+
730
+ def test_chunking(test_text=None, max_words=20, min_words=4):
731
+ """Test the enhanced chunking with sample or custom text"""
732
+ if test_text is None:
733
+ test_text = '''Though perfectly worldly-wise, and able, as she expressed it, to take care of herself, there was yet something curiously ingenuous in her single-minded attitude towards life, and her whole-hearted determination to "make good." This glimpse of a world unknown to me was not without its charm, and I enjoyed seeing her vivid little face light up as she talked.'''
734
+
735
+ chunks = sentence_chunk_text(test_text, max_words=max_words, min_words=min_words)
736
+
737
+ print("Enhanced Chunking Results:")
738
+ for i, (chunk, is_para) in enumerate(chunks):
739
+ word_count = len(chunk.split())
740
+ print(f"Chunk {i+1} ({word_count} words): {chunk}")
741
+ if word_count > max_words:
742
+ print(f" ✅ Over {max_words} words but complete sentence (follows punctuation rules)")
743
+ print()
744
+
745
+ return chunks
HF_Deploy/modules/tts_engine.py ADDED
@@ -0,0 +1,710 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TTS Engine Module
3
+ Handles ChatterboxTTS interface, model loading, and chunk processing coordination
4
+ """
5
+
6
+ import torch
7
+ import gc
8
+ import time
9
+ import logging
10
+ import shutil
11
+ import sys
12
+ from datetime import timedelta
13
+ from concurrent.futures import ThreadPoolExecutor, as_completed
14
+ from pathlib import Path
15
+ import torchaudio as ta
16
+
17
+ from config.config import *
18
+ from modules.text_processor import smart_punctuate, sentence_chunk_text, detect_content_boundaries
19
+
20
+ def find_chunks_json_file(book_name):
21
+ """Find the corresponding chunks JSON file for a book"""
22
+ from config.config import AUDIOBOOK_ROOT
23
+
24
+ # Look in the TTS processing directory
25
+ tts_chunks_dir = AUDIOBOOK_ROOT / book_name / "TTS" / "text_chunks"
26
+ json_path = tts_chunks_dir / "chunks_info.json"
27
+
28
+ if json_path.exists():
29
+ return json_path
30
+
31
+ # Also check old Text_Input location for backwards compatibility
32
+ text_input_dir = Path("Text_Input")
33
+ possible_names = [
34
+ f"{book_name}_chunks.json",
35
+ f"{book_name.lower()}_chunks.json",
36
+ f"{book_name.replace(' ', '_')}_chunks.json"
37
+ ]
38
+
39
+ for name in possible_names:
40
+ old_json_path = text_input_dir / name
41
+ if old_json_path.exists():
42
+ return old_json_path
43
+
44
+ return None
45
+ from modules.audio_processor import (
46
+ smart_audio_validation, apply_smart_fade, add_chunk_end_silence,
47
+ add_contextual_silence, pause_for_chunk_review, get_chunk_audio_duration,
48
+ has_mid_energy_drop, apply_smart_fade_memory, smart_audio_validation_memory
49
+ )
50
+ from modules.file_manager import (
51
+ setup_book_directories, find_book_files, ensure_voice_sample_compatibility,
52
+ combine_audio_chunks, get_audio_files_in_directory, convert_to_m4b, add_metadata_to_m4b
53
+ )
54
+ from modules.progress_tracker import setup_logging, log_chunk_progress, log_run
55
+
56
+ # ============================================================================
57
+ # MEMORY AND MODEL MANAGEMENT
58
+ # ============================================================================
59
+
60
+ def monitor_gpu_activity(operation_name):
61
+ """Lightweight GPU monitoring for high-speed processing"""
62
+ # Disabled expensive pynvml queries to free up GPU cycles
63
+ if torch.cuda.is_available():
64
+ allocated = torch.cuda.memory_allocated() / 1024**3
65
+ # Skip GPU utilization queries during production runs
66
+ return allocated, 0
67
+ return 0, 0
68
+
69
+ def optimize_memory_usage():
70
+ """Aggressive memory management for 8GB VRAM"""
71
+ torch.cuda.empty_cache()
72
+ gc.collect()
73
+ if torch.cuda.is_available():
74
+ torch.cuda.ipc_collect()
75
+
76
+ def monitor_vram_usage(operation_name=""):
77
+ """Real-time VRAM monitoring"""
78
+ if torch.cuda.is_available():
79
+ allocated = torch.cuda.memory_allocated() / 1024**3
80
+ reserved = torch.cuda.memory_reserved() / 1024**3
81
+
82
+ if allocated > VRAM_SAFETY_THRESHOLD:
83
+ logging.warning(f"⚠️ High VRAM usage during {operation_name}: {allocated:.1f}GB allocated, {reserved:.1f}GB reserved")
84
+ optimize_memory_usage()
85
+
86
+ return allocated, reserved
87
+ return 0, 0
88
+
89
+ def get_optimal_workers(user_max_workers=None):
90
+ """Dynamic worker allocation based on device type and resources"""
91
+ # Check for user override first
92
+ if user_max_workers is not None:
93
+ print(f"👤 Using user-defined workers: {user_max_workers}")
94
+ return int(user_max_workers)
95
+
96
+ if not USE_DYNAMIC_WORKERS:
97
+ return MAX_WORKERS
98
+
99
+ # CPU-based worker calculation
100
+ if not torch.cuda.is_available():
101
+ import psutil
102
+ cpu_cores = psutil.cpu_count(logical=False) # Physical cores
103
+ available_memory = psutil.virtual_memory().available / 1024**3 # GB
104
+
105
+ # Each TTS model instance needs ~2-3GB RAM
106
+ # Conservative estimation: allow 1 worker per 4GB available RAM
107
+ memory_limited_workers = max(1, int(available_memory / 4))
108
+
109
+ # CPU-based calculation: use 50% of physical cores for intensive TTS work
110
+ cpu_limited_workers = max(1, int(cpu_cores * 0.5))
111
+
112
+ optimal_workers = min(memory_limited_workers, cpu_limited_workers, MAX_WORKERS)
113
+ print(f"💻 CPU mode: {cpu_cores} cores, {available_memory:.1f}GB RAM → {optimal_workers} workers")
114
+ return optimal_workers
115
+
116
+ # GPU-based worker calculation (existing logic)
117
+ allocated_vram = torch.cuda.memory_allocated() / 1024**3
118
+
119
+ if allocated_vram < 5.0:
120
+ return min(TEST_MAX_WORKERS, MAX_WORKERS)
121
+ elif allocated_vram < VRAM_SAFETY_THRESHOLD:
122
+ return min(2, MAX_WORKERS)
123
+ else:
124
+ return 1
125
+
126
+ def load_optimized_model(device):
127
+ """Load TTS model with memory optimizations and device detection"""
128
+ from chatterbox.tts import ChatterboxTTS
129
+
130
+ # Detect available device if not specified or if CUDA not available
131
+ if device == "cuda" and not torch.cuda.is_available():
132
+ print("⚠️ CUDA not available, falling back to CPU")
133
+ device = "cpu"
134
+ elif device == "auto":
135
+ if torch.cuda.is_available():
136
+ device = "cuda"
137
+ print("✅ CUDA detected, using GPU")
138
+ else:
139
+ device = "cpu"
140
+ print("💻 No GPU detected, using CPU")
141
+
142
+ print(f"🔧 Loading ChatterboxTTS model on device: {device}")
143
+
144
+ try:
145
+ # Load model (ChatterboxTTS.from_pretrained doesn't support torch_dtype parameter)
146
+ model = ChatterboxTTS.from_pretrained(device=device)
147
+ logging.info(f"✅ Loaded ChatterboxTTS model on {device}")
148
+ except Exception as e:
149
+ print(f"❌ Failed to load model on {device}: {e}")
150
+ if device == "cuda":
151
+ print("🔄 Retrying with CPU...")
152
+ try:
153
+ model = ChatterboxTTS.from_pretrained(device="cpu")
154
+ logging.info("✅ Loaded model on CPU (GPU failed)")
155
+ device = "cpu"
156
+ except Exception as e2:
157
+ print(f"❌ Failed to load model on CPU: {e2}")
158
+ raise e2
159
+ else:
160
+ raise e
161
+
162
+ # Only apply eval() and benchmark if the model has these attributes
163
+ if hasattr(model, 'eval'):
164
+ model.eval()
165
+
166
+ # Set CUDNN benchmark for performance (if available)
167
+ if torch.backends.cudnn.is_available():
168
+ torch.backends.cudnn.benchmark = True
169
+
170
+ return model
171
+
172
+ # ============================================================================
173
+ # CHUNK PROCESSING
174
+ # ============================================================================
175
+
176
+ def patch_alignment_layer(tfmr, alignment_layer_idx=12):
177
+ """Patch alignment layer to avoid recursion"""
178
+ from types import MethodType
179
+ target_layer = tfmr.layers[alignment_layer_idx].self_attn
180
+ original_forward = target_layer.forward
181
+
182
+ def patched_forward(self, *args, **kwargs):
183
+ kwargs['output_attentions'] = True
184
+ return original_forward(*args, **kwargs)
185
+
186
+ target_layer.forward = MethodType(patched_forward, target_layer)
187
+
188
+ def process_one_chunk(
189
+ i, chunk, text_chunks_dir, audio_chunks_dir,
190
+ voice_path, tts_params, start_time, total_chunks,
191
+ punc_norm, basename, log_run_func, log_path, device,
192
+ model, asr_model, all_chunks, boundary_type="none"
193
+ ):
194
+ """Enhanced chunk processing with quality control, contextual silence, and deep cleanup"""
195
+ import difflib
196
+ from pydub import AudioSegment
197
+
198
+ chunk_id_str = f"{i+1:05}"
199
+ chunk_path = text_chunks_dir / f"chunk_{chunk_id_str}.txt"
200
+ with open(chunk_path, 'w', encoding='utf-8') as cf:
201
+ cf.write(chunk)
202
+
203
+ chunk_audio_path = audio_chunks_dir / f"chunk_{chunk_id_str}.wav"
204
+
205
+ # ============================================================================
206
+ # ENHANCED PERIODIC DEEP CLEANUP
207
+ # ============================================================================
208
+ cleanup_interval = CLEANUP_INTERVAL
209
+
210
+ # Skip cleanup on model reinitialization chunks to avoid conflicts
211
+ if (i + 1) % cleanup_interval == 0 and (i + 1) % BATCH_SIZE != 0:
212
+ print(f"\n🧹 {YELLOW}DEEP CLEANUP at chunk {i+1}/{total_chunks}...{RESET}")
213
+
214
+ # Enhanced VRAM monitoring before cleanup
215
+ allocated_before = torch.cuda.memory_allocated() / 1024**3 if torch.cuda.is_available() else 0
216
+ reserved_before = torch.cuda.memory_reserved() / 1024**3 if torch.cuda.is_available() else 0
217
+
218
+ print(f" Before: VRAM Allocated: {allocated_before:.1f}GB | Reserved: {reserved_before:.1f}GB")
219
+
220
+ # Bulk temp file cleanup
221
+ print(" 🗑️ Cleaning bulk temporary files...")
222
+ temp_patterns = ["*_try*.wav", "*_pre.wav", "*_fade*.wav", "*_debug*.wav", "*_temp*.wav", "*_backup*.wav"]
223
+ total_temp_files = 0
224
+ for pattern in temp_patterns:
225
+ temp_files = list(audio_chunks_dir.glob(pattern))
226
+ for temp_file in temp_files:
227
+ temp_file.unlink(missing_ok=True)
228
+ total_temp_files += len(temp_files)
229
+
230
+ if total_temp_files > 0:
231
+ print(f" 🗑️ Removed {total_temp_files} temporary audio files")
232
+
233
+ # Aggressive CUDA context reset
234
+ print(" 🔄 Performing aggressive CUDA context reset...")
235
+ torch.cuda.synchronize()
236
+ torch.cuda.empty_cache()
237
+ torch.cuda.ipc_collect()
238
+
239
+ # Force CUDA context reset
240
+ if hasattr(torch.cuda, 'reset_peak_memory_stats'):
241
+ torch.cuda.reset_peak_memory_stats()
242
+ if hasattr(torch._C, '_cuda_clearCublasWorkspaces'):
243
+ torch._C._cuda_clearCublasWorkspaces()
244
+
245
+ # Force garbage collection multiple times
246
+ for _ in range(3):
247
+ gc.collect()
248
+
249
+ # Clear model cache if it has one
250
+ if hasattr(model, 'clear_cache'):
251
+ model.clear_cache()
252
+ elif hasattr(model, 'reset_states'):
253
+ model.reset_states()
254
+
255
+ # Brief pause to let GPU settle
256
+ time.sleep(1.0)
257
+
258
+ # Monitor after cleanup
259
+ allocated_after = torch.cuda.memory_allocated() / 1024**3 if torch.cuda.is_available() else 0
260
+ reserved_after = torch.cuda.memory_reserved() / 1024**3 if torch.cuda.is_available() else 0
261
+
262
+ print(f" After: VRAM Allocated: {allocated_after:.1f}GB | Reserved: {reserved_after:.1f}GB")
263
+ print(f" Freed: {allocated_before - allocated_after:.1f}GB allocated, {reserved_before - reserved_after:.1f}GB reserved")
264
+ print(f"🧹 {GREEN}Deep cleanup complete!{RESET}\n")
265
+
266
+ best_sim, best_asr_text = -1, ""
267
+ wav_path_active = None
268
+ attempt_paths = []
269
+ mid_drop_retries = 0
270
+ max_mid_drop_retries = 2
271
+
272
+ for attempt_num in range(1, 3):
273
+ logging.info(f"🔁 Starting TTS for chunk {chunk_id_str}, attempt {attempt_num}")
274
+ try:
275
+ tts_args = {k: v for k, v in tts_params.items() if k != "max_workers"}
276
+
277
+ # monitor_gpu_activity(f"Before TTS chunk_{chunk_id_str}") # Disabled for speed
278
+ with torch.no_grad():
279
+ wav = model.generate(chunk, **tts_args).detach().cpu()
280
+ # monitor_gpu_activity(f"After TTS chunk_{chunk_id_str}") # Disabled for speed
281
+
282
+ if wav.dim() == 1:
283
+ wav = wav.unsqueeze(0)
284
+
285
+ # Retry if mid-energy drop is enabled and detected (check in memory)
286
+ if ENABLE_MID_DROP_CHECK and has_mid_energy_drop(wav, model.sr):
287
+ mid_drop_retries += 1
288
+ if mid_drop_retries >= max_mid_drop_retries:
289
+ logging.info(f"⚠️ Mid-drop retry limit reached for {chunk_id_str}. Accepting audio.")
290
+ else:
291
+ logging.info(f"⚠️ Mid-chunk noise detected in {chunk_id_str}. Retrying...")
292
+ continue
293
+
294
+ # Convert tensor to AudioSegment for in-memory processing
295
+ import io
296
+ import soundfile as sf
297
+ from pydub import AudioSegment
298
+
299
+ # Convert wav tensor to AudioSegment (in memory)
300
+ wav_np = wav.squeeze().numpy()
301
+ with io.BytesIO() as wav_buffer:
302
+ sf.write(wav_buffer, wav_np, model.sr, format='wav')
303
+ wav_buffer.seek(0)
304
+ audio_segment = AudioSegment.from_wav(wav_buffer)
305
+
306
+ # Smart fade removed - replaced by precise audio trimming
307
+ # Audio health validation disabled for speed
308
+
309
+ # Note: Audio trimming will handle end-of-speech cleanup more precisely
310
+
311
+ # ASR validation (memory-based processing) - check user setting first
312
+ enable_asr_user = tts_params.get('enable_asr', False)
313
+ if (enable_asr_user or ENABLE_ASR) and asr_model is not None:
314
+ from modules.audio_processor import asr_f1_score
315
+ import io
316
+ import soundfile as sf
317
+ # monitor_gpu_activity(f"Before ASR chunk_{chunk_id_str}") # Disabled for speed
318
+ try:
319
+ # Process ASR completely in memory - no disk writes
320
+ # Convert AudioSegment to numpy array for ASR
321
+ samples = np.array(audio_segment.get_array_of_samples())
322
+ if audio_segment.channels == 2:
323
+ samples = samples.reshape((-1, 2)).mean(axis=1)
324
+
325
+ # Normalize to float32 for ASR model
326
+ audio_np = samples.astype(np.float32) / audio_segment.max_possible_amplitude
327
+
328
+ # Use ASR model directly on numpy array (if supported)
329
+ # Note: This depends on the ASR model's input capabilities
330
+ result = asr_model.transcribe(audio_np)
331
+
332
+ if not isinstance(result, dict) or "text" not in result:
333
+ raise ValueError(f"Invalid ASR result type: {type(result)}")
334
+
335
+ asr_text = result.get("text", "").strip()
336
+ sim_ratio = asr_f1_score(punc_norm(chunk), asr_text)
337
+
338
+ except Exception as e:
339
+ print(f"❌ ASR failed for {chunk_id_str}: {e}")
340
+ log_run_func(f"ASR VALIDATION FAILED - Chunk {chunk_id_str}:\nExpected:\n{chunk}\nActual:\n<ASR Failure: {e}>\nSimilarity: -1.000\n" + "="*50, log_path)
341
+ sim_ratio = -1.0
342
+ continue
343
+
344
+ logging.info(f"ASR similarity for chunk {chunk_id_str}: {sim_ratio:.3f}")
345
+ if sim_ratio < 0.7:
346
+ continue
347
+
348
+ # Track best valid match
349
+ best_sim = sim_ratio
350
+ best_asr_text = asr_text
351
+ # monitor_gpu_activity(f"After ASR chunk_{chunk_id_str}") # Disabled for speed
352
+
353
+ # Success - we have processed audio in memory
354
+ final_audio = audio_segment
355
+ break
356
+
357
+ except Exception as e:
358
+ import traceback
359
+ logging.error(f"Exception during TTS attempt {attempt_num} for chunk {chunk_id_str}: {e}")
360
+ traceback.print_exc()
361
+ continue
362
+
363
+ if 'final_audio' not in locals():
364
+ logging.info(f"❌ Chunk {chunk_id_str} failed all attempts.")
365
+ return None, None
366
+
367
+ # Apply trimming and contextual silence in memory before final save
368
+ from modules.audio_processor import process_audio_with_trimming_and_silence
369
+
370
+ if boundary_type and boundary_type != "none":
371
+ final_audio = process_audio_with_trimming_and_silence(final_audio, boundary_type)
372
+ print(f"🔇 Added {boundary_type} silence to chunk {i+1:05}")
373
+ else:
374
+ # Apply trimming even without boundary type if enabled
375
+ if ENABLE_AUDIO_TRIMMING:
376
+ from modules.audio_processor import trim_audio_endpoint
377
+ final_audio = trim_audio_endpoint(final_audio)
378
+
379
+ # Note: ENABLE_CHUNK_END_SILENCE is now handled by punctuation-specific silence
380
+ # The new system provides more precise silence based on actual punctuation
381
+
382
+ # Final save - only disk write in entire process
383
+ final_path = audio_chunks_dir / f"chunk_{chunk_id_str}.wav"
384
+ final_audio.export(final_path, format="wav")
385
+ logging.info(f"✅ Saved final chunk: {final_path.name}")
386
+
387
+ # No intermediate file cleanup needed - all processing done in memory
388
+
389
+ # Log details - only log ASR failures
390
+ asr_active = enable_asr_user or ENABLE_ASR
391
+ if asr_active and best_sim < 0.8:
392
+ log_run_func(f"ASR VALIDATION FAILED - Chunk {chunk_id_str}:\nExpected:\n{chunk}\nActual:\n{best_asr_text}\nSimilarity: {best_sim:.3f}\n" + "="*50, log_path)
393
+ elif not asr_active:
394
+ log_run_func(f"Chunk {chunk_id_str}: Original text: {chunk}", log_path)
395
+
396
+ # Silence already added in memory above - no disk processing needed
397
+
398
+ # Enhanced regular cleanup (every chunk)
399
+ del wav
400
+ optimize_memory_usage()
401
+
402
+ # Additional per-chunk cleanup for long runs
403
+ if (i + 1) % 50 == 0:
404
+ torch.cuda.empty_cache()
405
+ gc.collect()
406
+
407
+ return i, final_path
408
+
409
+ # ============================================================================
410
+ # MAIN BOOK PROCESSING FUNCTION
411
+ # ============================================================================
412
+
413
+ from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
414
+ from wrapper.chunk_loader import save_chunks
415
+
416
+ def generate_enriched_chunks(text_file, output_dir, user_tts_params=None):
417
+ """Reads a text file, performs VADER sentiment analysis, and returns enriched chunks."""
418
+ analyzer = SentimentIntensityAnalyzer()
419
+
420
+ raw_text = text_file.read_text(encoding='utf-8')
421
+ cleaned = smart_punctuate(raw_text)
422
+ chunks = sentence_chunk_text(cleaned)
423
+
424
+ # Use user-provided parameters as base, or fall back to config defaults
425
+ if user_tts_params:
426
+ base_exaggeration = user_tts_params.get('exaggeration', BASE_EXAGGERATION)
427
+ base_cfg_weight = user_tts_params.get('cfg_weight', BASE_CFG_WEIGHT)
428
+ base_temperature = user_tts_params.get('temperature', BASE_TEMPERATURE)
429
+ else:
430
+ base_exaggeration = BASE_EXAGGERATION
431
+ base_cfg_weight = BASE_CFG_WEIGHT
432
+ base_temperature = BASE_TEMPERATURE
433
+
434
+ enriched = []
435
+ chunk_texts = [chunk_text for chunk_text, _ in chunks]
436
+
437
+ for i, (chunk_text, is_para_end) in enumerate(chunks):
438
+ sentiment_scores = analyzer.polarity_scores(chunk_text)
439
+ compound_score = sentiment_scores['compound']
440
+
441
+ exaggeration = base_exaggeration + (compound_score * VADER_EXAGGERATION_SENSITIVITY)
442
+ cfg_weight = base_cfg_weight + (compound_score * VADER_CFG_WEIGHT_SENSITIVITY)
443
+ temperature = base_temperature + (compound_score * VADER_TEMPERATURE_SENSITIVITY)
444
+
445
+ # Clamp values to defined min/max
446
+ exaggeration = round(max(TTS_PARAM_MIN_EXAGGERATION, min(exaggeration, TTS_PARAM_MAX_EXAGGERATION)), 2)
447
+ cfg_weight = round(max(TTS_PARAM_MIN_CFG_WEIGHT, min(cfg_weight, TTS_PARAM_MAX_CFG_WEIGHT)), 2)
448
+ temperature = round(max(TTS_PARAM_MIN_TEMPERATURE, min(temperature, TTS_PARAM_MAX_TEMPERATURE)), 2)
449
+
450
+ boundary_type = detect_content_boundaries(chunk_text, i, chunk_texts, is_para_end)
451
+
452
+ enriched.append({
453
+ "index": i,
454
+ "text": chunk_text,
455
+ "word_count": len(chunk_text.split()),
456
+ "boundary_type": boundary_type if boundary_type else "none",
457
+ "sentiment_compound": compound_score,
458
+ "tts_params": {
459
+ "exaggeration": exaggeration,
460
+ "cfg_weight": cfg_weight,
461
+ "temperature": temperature
462
+ }
463
+ })
464
+
465
+ output_json_path = output_dir / "chunks_info.json"
466
+ save_chunks(output_json_path, enriched)
467
+ return enriched
468
+
469
+ def process_book_folder(book_dir, voice_path, tts_params, device, skip_cleanup=False):
470
+ """Enhanced book processing with batch processing to prevent hangs"""
471
+ print(f"🔍 DEBUG: Entering process_book_folder with book_dir='{book_dir}', voice_path='{voice_path}'")
472
+
473
+ from chatterbox.tts import punc_norm
474
+ print(f"🔍 DEBUG: Successfully imported punc_norm")
475
+
476
+ # Setup directories
477
+ print(f"🔍 DEBUG: Calling setup_book_directories...")
478
+ output_root, tts_dir, text_chunks_dir, audio_chunks_dir = setup_book_directories(book_dir)
479
+ print(f"🔍 DEBUG: Directory setup complete")
480
+
481
+ # Clean previous processing files (but skip for resume operations)
482
+ if skip_cleanup:
483
+ print(f"🔄 RESUME MODE: Skipping cleanup to preserve existing chunks")
484
+ print(f"📁 Preserving: {text_chunks_dir}, {audio_chunks_dir}")
485
+ else:
486
+ print(f"🧹 FRESH PROCESSING: Cleaning previous processing files...")
487
+ import glob
488
+
489
+ # Clear text chunks
490
+ for txt_file in text_chunks_dir.glob("*.txt"):
491
+ txt_file.unlink(missing_ok=True)
492
+ for json_file in text_chunks_dir.glob("*.json"):
493
+ json_file.unlink(missing_ok=True)
494
+
495
+ # Clear audio chunks
496
+ for wav_file in audio_chunks_dir.glob("*.wav"):
497
+ wav_file.unlink(missing_ok=True)
498
+
499
+ # Clear logs
500
+ for log_file in output_root.glob("*.log"):
501
+ log_file.unlink(missing_ok=True)
502
+
503
+ print(f"✅ Cleanup complete")
504
+
505
+ # Find book files
506
+ print(f"🔍 DEBUG: Calling find_book_files...")
507
+ book_files = find_book_files(book_dir)
508
+ text_files = [book_files['text']] if book_files['text'] else []
509
+ cover_file = book_files['cover']
510
+ nfo_file = book_files['nfo']
511
+ print(f"🔍 DEBUG: Found text files: {text_files}")
512
+
513
+ if not text_files:
514
+ logging.info(f"[{book_dir.name}] ERROR: No .txt files found in the book folder.")
515
+ return None, None, []
516
+
517
+ setup_logging(output_root)
518
+
519
+ # Generate enriched chunks with VADER analysis using user parameters
520
+ all_chunks = generate_enriched_chunks(text_files[0], text_chunks_dir, tts_params)
521
+
522
+ # Create run_log_lines
523
+ print(f"🔍 DEBUG: Creating run_log_lines...")
524
+ print(f"🔍 DEBUG: voice_path type: {type(voice_path)}, value: {voice_path}")
525
+
526
+ # Extract voice name for logging
527
+ voice_name_for_log = voice_path.stem if hasattr(voice_path, 'stem') else Path(voice_path).stem
528
+
529
+ run_log_lines = [
530
+ f"\n===== Processing: {book_dir.name} =====",
531
+ f"Voice: {voice_name_for_log}",
532
+ f"Started: {time.strftime('%Y-%m-%d %H:%M:%S')}",
533
+ f"Text files processed: {len(text_files)}",
534
+ f"Total chunks generated: {len(all_chunks)}"
535
+ ]
536
+
537
+ start_time = time.time()
538
+ total_chunks = len(all_chunks)
539
+ log_path = output_root / "chunk_validation.log"
540
+ total_audio_duration = 0.0
541
+
542
+ # Batch processing
543
+ print(f"📊 Processing {total_chunks} chunks in batches of {BATCH_SIZE}")
544
+
545
+ all_results = []
546
+
547
+ for batch_start in range(0, total_chunks, BATCH_SIZE):
548
+ batch_end = min(batch_start + BATCH_SIZE, total_chunks)
549
+ batch_chunks = all_chunks[batch_start:batch_end]
550
+
551
+ print(f"\n🔄 Processing batch: chunks {batch_start+1}-{batch_end}")
552
+
553
+ # Fresh model for each batch
554
+ model = load_optimized_model(device)
555
+ compatible_voice = ensure_voice_sample_compatibility(voice_path, output_dir=tts_dir)
556
+ model.prepare_conditionals(compatible_voice)
557
+
558
+ # Load ASR model once per batch if needed (check user settings first, then global config)
559
+ asr_model = None
560
+ enable_asr_user = tts_params.get('enable_asr', False)
561
+ if enable_asr_user or ENABLE_ASR:
562
+ import whisper
563
+ print(f"🎤 Loading Whisper ASR model for batch... (user setting: {enable_asr_user})")
564
+ # Use same device as TTS model, with fallback to CPU
565
+ asr_device = device if torch.cuda.is_available() and device == "cuda" else "cpu"
566
+ print(f"🎤 Loading ASR model on device: {asr_device}")
567
+ asr_model = whisper.load_model("base", device=asr_device)
568
+
569
+ futures = []
570
+ batch_results = []
571
+
572
+ # Dynamic worker allocation
573
+ user_max_workers = tts_params.get('max_workers', None)
574
+ optimal_workers = get_optimal_workers(user_max_workers)
575
+ print(f"🔧 Using {optimal_workers} workers for batch {batch_start+1}-{batch_end}")
576
+
577
+ with ThreadPoolExecutor(max_workers=optimal_workers) as executor:
578
+ for i, chunk_data in enumerate(batch_chunks):
579
+ global_chunk_index = batch_start + i
580
+
581
+ # Check for shutdown request
582
+ if shutdown_requested:
583
+ print(f"\n⏹️ {YELLOW}Stopping submission of new chunks...{RESET}")
584
+ break
585
+
586
+ # Handle both dictionary and tuple formats for chunk data
587
+ if isinstance(chunk_data, dict):
588
+ chunk = chunk_data["text"]
589
+ boundary_type = chunk_data.get("boundary_type", "none")
590
+ # Use chunk-specific TTS params if available, otherwise fall back to global
591
+ chunk_tts_params = chunk_data.get("tts_params", tts_params)
592
+ else:
593
+ # Handle old tuple format (text, is_para_end) - convert to boundary_type
594
+ chunk = chunk_data[0] if len(chunk_data) > 0 else str(chunk_data)
595
+ # Convert old is_paragraph_end to boundary_type
596
+ is_old_para_end = chunk_data[1] if len(chunk_data) > 1 else False
597
+ boundary_type = "paragraph_end" if is_old_para_end else "none"
598
+ chunk_tts_params = tts_params # Fallback for old format
599
+
600
+ # Handle both dictionary and tuple formats for backward compatibility
601
+ all_chunk_texts = []
602
+ for cd in all_chunks:
603
+ if isinstance(cd, dict):
604
+ all_chunk_texts.append(cd["text"])
605
+ else:
606
+ # Handle old tuple format (text, is_para_end)
607
+ all_chunk_texts.append(cd[0] if len(cd) > 0 else str(cd))
608
+
609
+ futures.append(executor.submit(
610
+ process_one_chunk,
611
+ global_chunk_index, chunk, text_chunks_dir, audio_chunks_dir,
612
+ voice_path, chunk_tts_params, start_time, total_chunks,
613
+ punc_norm, book_dir.name, log_run, log_path, device,
614
+ model, asr_model, all_chunk_texts, boundary_type
615
+ ))
616
+
617
+ # Wait for batch to complete
618
+ print(f"🔄 {CYAN}Waiting for batch {batch_start+1}-{batch_end} to complete...{RESET}")
619
+ completed_count = 0
620
+
621
+ for fut in as_completed(futures):
622
+ try:
623
+ idx, wav_path = fut.result()
624
+ if wav_path and wav_path.exists():
625
+ # Measure actual audio duration for this chunk
626
+ chunk_duration = get_chunk_audio_duration(wav_path)
627
+ total_audio_duration += chunk_duration
628
+ batch_results.append((idx, wav_path))
629
+
630
+ # Update progress every 10 chunks within batch
631
+ completed_count += 1
632
+ if completed_count % 10 == 0:
633
+ log_chunk_progress(batch_start + completed_count - 1, total_chunks, start_time, total_audio_duration)
634
+
635
+ except Exception as e:
636
+ logging.error(f"Future failed in batch: {e}")
637
+
638
+ # Clean up model after batch
639
+ print(f"🧹 Cleaning up after batch {batch_start+1}-{batch_end}")
640
+ del model
641
+ if asr_model:
642
+ del asr_model
643
+ torch.cuda.empty_cache()
644
+ gc.collect()
645
+ time.sleep(2)
646
+
647
+ all_results.extend(batch_results)
648
+ print(f"✅ Batch {batch_start+1}-{batch_end} completed ({len(batch_results)} chunks)")
649
+
650
+ # Final processing
651
+ quarantine_dir = audio_chunks_dir / "quarantine"
652
+ pause_for_chunk_review(quarantine_dir)
653
+
654
+ # Collect final chunk paths
655
+ chunk_paths = get_audio_files_in_directory(audio_chunks_dir)
656
+
657
+ if not chunk_paths:
658
+ logging.info(f"{RED}❌ No valid audio chunks found. Skipping concatenation and conversion.{RESET}")
659
+ return None, None, []
660
+
661
+ # Calculate timing
662
+ elapsed_total = time.time() - start_time
663
+ elapsed_td = timedelta(seconds=int(elapsed_total))
664
+
665
+ total_audio_duration_final = sum(get_chunk_audio_duration(chunk_path) for chunk_path in chunk_paths)
666
+ audio_duration_td = timedelta(seconds=int(total_audio_duration_final))
667
+ realtime_factor = total_audio_duration_final / elapsed_total if elapsed_total > 0 else 0.0
668
+
669
+ print(f"\n⏱️ TTS Processing Complete:")
670
+ print(f" Elapsed Time: {CYAN}{str(elapsed_td)}{RESET}")
671
+ print(f" Audio Duration: {GREEN}{str(audio_duration_td)}{RESET}")
672
+ print(f" Realtime Factor: {YELLOW}{realtime_factor:.2f}x{RESET}")
673
+
674
+ # Combine audio
675
+ voice_name = voice_path.stem if hasattr(voice_path, 'stem') else Path(voice_path).stem
676
+ combined_wav_path = output_root / f"{book_dir.name} [{voice_name}].wav"
677
+ print("\n💾 Saving WAV file...")
678
+ combine_audio_chunks(chunk_paths, combined_wav_path)
679
+
680
+ # M4B conversion with normalization
681
+ temp_m4b_path = output_root / "output.m4b"
682
+ final_m4b_path = output_root / f"{book_dir.name}[{voice_name}].m4b"
683
+ convert_to_m4b(combined_wav_path, temp_m4b_path)
684
+ add_metadata_to_m4b(temp_m4b_path, final_m4b_path, cover_file, nfo_file)
685
+
686
+ logging.info(f"Audiobook created: {final_m4b_path}")
687
+
688
+ # Add final info to run log
689
+ run_log_lines.extend([
690
+ f"Combined WAV: {combined_wav_path}",
691
+ "--- Generation Settings ---",
692
+ f"Batch Processing: Enabled ({BATCH_SIZE} chunks per batch)",
693
+ f"ASR Enabled: {enable_asr_user or ENABLE_ASR} (user: {enable_asr_user}, global: {ENABLE_ASR})",
694
+ f"Hum Detection: {ENABLE_HUM_DETECTION}",
695
+ f"Dynamic Workers: {USE_DYNAMIC_WORKERS}",
696
+ f"Voice used: {voice_name}",
697
+ f"Exaggeration: {tts_params['exaggeration']}",
698
+ f"CFG weight: {tts_params['cfg_weight']}",
699
+ f"Temperature: {tts_params['temperature']}",
700
+ f"Processing Time: {str(elapsed_td)}",
701
+ f"Audio Duration: {str(audio_duration_td)}",
702
+ f"Realtime Factor: {realtime_factor:.2f}x",
703
+ f"Total Chunks: {len(chunk_paths)}"
704
+ ])
705
+
706
+ # Write the run log
707
+ log_run("\n".join(run_log_lines), output_root / "run.log")
708
+ print(f"📝 Run log written to: {output_root / 'run.log'}")
709
+
710
+ return final_m4b_path, combined_wav_path, run_log_lines
HF_Deploy/modules/voice_detector.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Voice Detection Module
3
+ Handles voice detection from multiple sources: JSON metadata, log files, filenames
4
+ """
5
+
6
+ import re
7
+ import json
8
+ from pathlib import Path
9
+ from config.config import AUDIOBOOK_ROOT
10
+ from modules.file_manager import list_voice_samples
11
+
12
+
13
+ def get_likely_voices_for_book(book_name, chunks_json_path=None):
14
+ """
15
+ Get the most likely voice candidates for a book using the 3 detection methods:
16
+ 1. JSON metadata/comments (if available)
17
+ 2. run.log file
18
+ 3. Generated audiobook filenames (may return multiple)
19
+
20
+ Returns: list of (voice_name, voice_path, detection_method) tuples
21
+ """
22
+ print(f"🔍 Finding likely voices for book: {book_name}")
23
+ likely_voices = []
24
+
25
+ # Method 1: Check JSON metadata and comments
26
+ if chunks_json_path:
27
+ voice_from_json = get_voice_from_json(chunks_json_path)
28
+ if voice_from_json:
29
+ voice_path = find_voice_file_by_name(voice_from_json)
30
+ if voice_path:
31
+ likely_voices.append((voice_from_json, voice_path, "json_metadata"))
32
+ print(f"✅ Voice found in JSON: {voice_from_json}")
33
+
34
+ # Method 2: Check run.log file
35
+ voice_from_log = get_voice_from_log(book_name)
36
+ if voice_from_log:
37
+ voice_path = find_voice_file_by_name(voice_from_log)
38
+ if voice_path:
39
+ # Avoid duplicates
40
+ if not any(v[0] == voice_from_log for v in likely_voices):
41
+ likely_voices.append((voice_from_log, voice_path, "run_log"))
42
+ print(f"✅ Voice found in run.log: {voice_from_log}")
43
+
44
+ # Method 3: Check generated filename patterns (may find multiple)
45
+ voices_from_files = get_voices_from_filenames(book_name)
46
+ for voice_name in voices_from_files:
47
+ voice_path = find_voice_file_by_name(voice_name)
48
+ if voice_path:
49
+ # Avoid duplicates
50
+ if not any(v[0] == voice_name for v in likely_voices):
51
+ likely_voices.append((voice_name, voice_path, "filename_pattern"))
52
+ print(f"✅ Voice found in filename: {voice_name}")
53
+
54
+ if not likely_voices:
55
+ print(f"⚠️ No likely voices detected for {book_name}")
56
+ else:
57
+ print(f"📋 Found {len(likely_voices)} likely voice candidates")
58
+
59
+ return likely_voices
60
+
61
+ def detect_voice_for_book(book_name, chunks_json_path=None):
62
+ """
63
+ Detect the most likely voice for a book (returns first candidate)
64
+ For backwards compatibility with existing code
65
+ """
66
+ likely_voices = get_likely_voices_for_book(book_name, chunks_json_path)
67
+ if likely_voices:
68
+ return likely_voices[0] # Return the first (most likely) candidate
69
+ return None, None, "not_found"
70
+
71
+
72
+ def get_voice_from_json(json_path):
73
+ """Extract voice information from JSON metadata"""
74
+ try:
75
+ with open(json_path, 'r', encoding='utf-8') as f:
76
+ content = f.read()
77
+
78
+ # Check for voice metadata in JSON
79
+ if '"voice_used":' in content:
80
+ data = json.loads(content)
81
+ if isinstance(data, dict) and 'voice_used' in data:
82
+ return data['voice_used']
83
+ elif isinstance(data, list) and data and 'voice_used' in data[0]:
84
+ return data[0]['voice_used']
85
+
86
+ # Check for voice as comment in JSON (fallback option)
87
+ voice_comment_match = re.search(r'//\s*voice:\s*([^\n]+)', content, re.IGNORECASE)
88
+ if voice_comment_match:
89
+ return voice_comment_match.group(1).strip()
90
+
91
+ except Exception as e:
92
+ print(f"⚠️ Error reading JSON for voice info: {e}")
93
+
94
+ return None
95
+
96
+
97
+ def get_voice_from_log(book_name):
98
+ """Extract voice information from run.log file"""
99
+ audiobook_root = Path(AUDIOBOOK_ROOT)
100
+ log_file = audiobook_root / book_name / "run.log"
101
+
102
+ if log_file.exists():
103
+ try:
104
+ with open(log_file, 'r', encoding='utf-8') as f:
105
+ for line in f:
106
+ line = line.strip()
107
+ if line.startswith("Voice: ") or line.startswith("Voice used: "):
108
+ voice_name = line.split(": ", 1)[1].strip()
109
+ return voice_name
110
+ except Exception as e:
111
+ print(f"⚠️ Error reading run log: {e}")
112
+
113
+ return None
114
+
115
+
116
+ def get_voices_from_filenames(book_name):
117
+ """Extract voice names from existing audiobook filename patterns (may return multiple)"""
118
+ audiobook_root = Path(AUDIOBOOK_ROOT)
119
+ book_dir = audiobook_root / book_name
120
+
121
+ if not book_dir.exists():
122
+ return []
123
+
124
+ found_voices = []
125
+
126
+ # Look for WAV files with voice pattern: BookName [VoiceName].wav
127
+ for wav_file in book_dir.glob("*.wav"):
128
+ match = re.search(r'\[([^\]]+)\]\.wav$', wav_file.name)
129
+ if match:
130
+ voice_name = match.group(1)
131
+ if voice_name not in found_voices:
132
+ found_voices.append(voice_name)
133
+
134
+ # Look for M4B files with voice pattern: BookName[VoiceName].m4b
135
+ for m4b_file in book_dir.glob("*.m4b"):
136
+ match = re.search(r'\[([^\]]+)\]\.m4b$', m4b_file.name)
137
+ if match:
138
+ voice_name = match.group(1)
139
+ if voice_name not in found_voices:
140
+ found_voices.append(voice_name)
141
+
142
+ return found_voices
143
+
144
+ def get_voice_from_filename(book_name):
145
+ """Extract voice name from existing audiobook filename patterns (backwards compatibility)"""
146
+ voices = get_voices_from_filenames(book_name)
147
+ return voices[0] if voices else None
148
+
149
+
150
+ def find_voice_file_by_name(voice_name):
151
+ """Find voice file by name in Voice_Samples directory"""
152
+ voice_files = list_voice_samples()
153
+
154
+ # Exact match first
155
+ for voice_file in voice_files:
156
+ if voice_file.stem == voice_name:
157
+ return voice_file
158
+
159
+ # Partial match (case insensitive)
160
+ voice_name_lower = voice_name.lower()
161
+ for voice_file in voice_files:
162
+ if voice_name_lower in voice_file.stem.lower():
163
+ return voice_file
164
+
165
+ return None
166
+
167
+
168
+
169
+
170
+ def add_voice_to_json(json_path, voice_name, method="metadata"):
171
+ """
172
+ Add voice information to JSON file
173
+
174
+ method options:
175
+ - "metadata": Add as top-level metadata
176
+ - "comment": Add as comment that doesn't affect parsing
177
+ """
178
+ try:
179
+ with open(json_path, 'r', encoding='utf-8') as f:
180
+ content = f.read()
181
+
182
+ if method == "metadata":
183
+ # Add voice as metadata to JSON structure
184
+ data = json.loads(content)
185
+
186
+ if isinstance(data, list):
187
+ # For list format, add metadata as first element or update existing
188
+ if data and isinstance(data[0], dict) and not any(key.startswith('text') for key in data[0].keys()):
189
+ # First element is already metadata
190
+ data[0]['voice_used'] = voice_name
191
+ else:
192
+ # Insert metadata as first element
193
+ metadata = {"voice_used": voice_name, "_metadata": True}
194
+ data.insert(0, metadata)
195
+ elif isinstance(data, dict):
196
+ # For dict format, add to top level
197
+ data['voice_used'] = voice_name
198
+
199
+ # Save updated JSON
200
+ with open(json_path, 'w', encoding='utf-8') as f:
201
+ json.dump(data, f, indent=2, ensure_ascii=False)
202
+
203
+ elif method == "comment":
204
+ # Add voice as comment at the top of file
205
+ voice_comment = f"// voice: {voice_name}\n"
206
+
207
+ if not content.startswith("// voice:"):
208
+ content = voice_comment + content
209
+ with open(json_path, 'w', encoding='utf-8') as f:
210
+ f.write(content)
211
+
212
+ print(f"✅ Added voice '{voice_name}' to {json_path.name} using {method} method")
213
+ return True
214
+
215
+ except Exception as e:
216
+ print(f"❌ Error adding voice to JSON: {e}")
217
+ return False
218
+
219
+
220
+ def remove_voice_comment_from_json(json_path):
221
+ """Remove voice comment from JSON file for clean processing"""
222
+ try:
223
+ with open(json_path, 'r', encoding='utf-8') as f:
224
+ content = f.read()
225
+
226
+ # Remove voice comment lines
227
+ lines = content.split('\n')
228
+ filtered_lines = [line for line in lines if not line.strip().startswith('// voice:')]
229
+
230
+ if len(filtered_lines) != len(lines):
231
+ # Comments were removed, save cleaned version
232
+ cleaned_content = '\n'.join(filtered_lines)
233
+ with open(json_path, 'w', encoding='utf-8') as f:
234
+ f.write(cleaned_content)
235
+ return True
236
+
237
+ except Exception as e:
238
+ print(f"⚠️ Error cleaning JSON comments: {e}")
239
+
240
+ return False
HF_Deploy/requirements.txt ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ChatterboxTTS HuggingFace Spaces Requirements
2
+ # Optimized for HF Spaces environment with flexible versions
3
+
4
+ # Core ML and TTS - Essential (pinned versions for fast builds)
5
+ torch==2.6.0
6
+ torchaudio==2.6.0
7
+ transformers==4.46.3
8
+ huggingface_hub>=0.15.0
9
+ safetensors>=0.3.0
10
+
11
+ # Audio processing - Required
12
+ soundfile>=0.12.0
13
+ librosa>=0.9.0
14
+ pydub>=0.25.0
15
+ audioread>=3.0.0
16
+
17
+ # ASR System - Intelligent ASR with fallback
18
+ openai-whisper>=20231117
19
+
20
+ # System monitoring and resource detection
21
+ psutil>=5.8.0
22
+ pynvml>=11.0.0
23
+
24
+ # Core scientific computing (pinned for fast builds)
25
+ numpy==2.2.0
26
+ scipy>=1.7.0
27
+
28
+ # Text processing
29
+ regex>=2023.0.0
30
+ vaderSentiment>=3.3.0
31
+
32
+ # Web interface - Gradio (let HF manage version)
33
+ gradio>=4.0.0
34
+
35
+ # Progress and logging
36
+ tqdm>=4.60.0
37
+
38
+ # File handling
39
+ pathlib2>=2.3.0
40
+
41
+ # Configuration and utilities
42
+ python-dotenv>=1.0.0
43
+
44
+ # Optional utilities
45
+ requests>=2.25.0
46
+ packaging>=21.0
47
+
48
+ # Core ChatterboxTTS model dependencies
49
+ chatterbox-tts>=0.1.2
50
+ resemble-perth>=1.0.1
51
+ omegaconf>=2.3.0
52
+ einops>=0.6.0
53
+ diffusers>=0.21.0
54
+ tokenizers>=0.13.0
55
+ conformer>=0.3.0
56
+ s3tokenizer==0.2.0
HF_Deploy/src/chatterbox/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .tts import ChatterboxTTS
2
+ from .vc import ChatterboxVC
HF_Deploy/src/chatterbox/models/s3gen/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .s3gen import S3Token2Wav as S3Gen
2
+ from .const import S3GEN_SR
HF_Deploy/src/chatterbox/models/s3gen/const.py ADDED
@@ -0,0 +1 @@
 
 
1
+ S3GEN_SR = 24000
HF_Deploy/src/chatterbox/models/s3gen/decoder.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from einops import pack, rearrange, repeat
18
+
19
+ from .utils.mask import add_optional_chunk_mask
20
+ from .matcha.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, \
21
+ TimestepEmbedding, Upsample1D
22
+ from .matcha.transformer import BasicTransformerBlock
23
+
24
+
25
+ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
26
+ assert mask.dtype == torch.bool
27
+ assert dtype in [torch.float32, torch.bfloat16, torch.float16]
28
+ mask = mask.to(dtype)
29
+ # attention mask bias
30
+ # NOTE(Mddct): torch.finfo jit issues
31
+ # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
32
+ mask = (1.0 - mask) * -1.0e+10
33
+ return mask
34
+
35
+
36
+
37
+ class Transpose(torch.nn.Module):
38
+ def __init__(self, dim0: int, dim1: int):
39
+ super().__init__()
40
+ self.dim0 = dim0
41
+ self.dim1 = dim1
42
+
43
+ def forward(self, x: torch.Tensor):
44
+ x = torch.transpose(x, self.dim0, self.dim1)
45
+ return x
46
+
47
+
48
+ class CausalBlock1D(Block1D):
49
+ def __init__(self, dim: int, dim_out: int):
50
+ super(CausalBlock1D, self).__init__(dim, dim_out)
51
+ self.block = torch.nn.Sequential(
52
+ CausalConv1d(dim, dim_out, 3),
53
+ Transpose(1, 2),
54
+ nn.LayerNorm(dim_out),
55
+ Transpose(1, 2),
56
+ nn.Mish(),
57
+ )
58
+
59
+ def forward(self, x: torch.Tensor, mask: torch.Tensor):
60
+ output = self.block(x * mask)
61
+ return output * mask
62
+
63
+
64
+ class CausalResnetBlock1D(ResnetBlock1D):
65
+ def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
66
+ super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
67
+ self.block1 = CausalBlock1D(dim, dim_out)
68
+ self.block2 = CausalBlock1D(dim_out, dim_out)
69
+
70
+
71
+ class CausalConv1d(torch.nn.Conv1d):
72
+ def __init__(
73
+ self,
74
+ in_channels: int,
75
+ out_channels: int,
76
+ kernel_size: int,
77
+ stride: int = 1,
78
+ dilation: int = 1,
79
+ groups: int = 1,
80
+ bias: bool = True,
81
+ padding_mode: str = 'zeros',
82
+ device=None,
83
+ dtype=None
84
+ ) -> None:
85
+ super(CausalConv1d, self).__init__(in_channels, out_channels,
86
+ kernel_size, stride,
87
+ padding=0, dilation=dilation,
88
+ groups=groups, bias=bias,
89
+ padding_mode=padding_mode,
90
+ device=device, dtype=dtype)
91
+ assert stride == 1
92
+ self.causal_padding = (kernel_size - 1, 0)
93
+
94
+ def forward(self, x: torch.Tensor):
95
+ x = F.pad(x, self.causal_padding)
96
+ x = super(CausalConv1d, self).forward(x)
97
+ return x
98
+
99
+
100
+ class ConditionalDecoder(nn.Module):
101
+ def __init__(
102
+ self,
103
+ in_channels=320,
104
+ out_channels=80,
105
+ causal=True,
106
+ channels=[256],
107
+ dropout=0.0,
108
+ attention_head_dim=64,
109
+ n_blocks=4,
110
+ num_mid_blocks=12,
111
+ num_heads=8,
112
+ act_fn="gelu",
113
+ ):
114
+ """
115
+ This decoder requires an input with the same shape of the target. So, if your text content
116
+ is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
117
+ """
118
+ super().__init__()
119
+ channels = tuple(channels)
120
+ self.in_channels = in_channels
121
+ self.out_channels = out_channels
122
+ self.causal = causal
123
+ self.time_embeddings = SinusoidalPosEmb(in_channels)
124
+ time_embed_dim = channels[0] * 4
125
+ self.time_mlp = TimestepEmbedding(
126
+ in_channels=in_channels,
127
+ time_embed_dim=time_embed_dim,
128
+ act_fn="silu",
129
+ )
130
+ self.down_blocks = nn.ModuleList([])
131
+ self.mid_blocks = nn.ModuleList([])
132
+ self.up_blocks = nn.ModuleList([])
133
+
134
+ # NOTE jrm: `static_chunk_size` is missing?
135
+ self.static_chunk_size = 0
136
+
137
+ output_channel = in_channels
138
+ for i in range(len(channels)): # pylint: disable=consider-using-enumerate
139
+ input_channel = output_channel
140
+ output_channel = channels[i]
141
+ is_last = i == len(channels) - 1
142
+ resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
143
+ ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
144
+ transformer_blocks = nn.ModuleList(
145
+ [
146
+ BasicTransformerBlock(
147
+ dim=output_channel,
148
+ num_attention_heads=num_heads,
149
+ attention_head_dim=attention_head_dim,
150
+ dropout=dropout,
151
+ activation_fn=act_fn,
152
+ )
153
+ for _ in range(n_blocks)
154
+ ]
155
+ )
156
+ downsample = (
157
+ Downsample1D(output_channel) if not is_last else
158
+ CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
159
+ )
160
+ self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
161
+
162
+ for _ in range(num_mid_blocks):
163
+ input_channel = channels[-1]
164
+ out_channels = channels[-1]
165
+ resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim) if self.causal else \
166
+ ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
167
+
168
+ transformer_blocks = nn.ModuleList(
169
+ [
170
+ BasicTransformerBlock(
171
+ dim=output_channel,
172
+ num_attention_heads=num_heads,
173
+ attention_head_dim=attention_head_dim,
174
+ dropout=dropout,
175
+ activation_fn=act_fn,
176
+ )
177
+ for _ in range(n_blocks)
178
+ ]
179
+ )
180
+
181
+ self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
182
+
183
+ channels = channels[::-1] + (channels[0],)
184
+ for i in range(len(channels) - 1):
185
+ input_channel = channels[i] * 2
186
+ output_channel = channels[i + 1]
187
+ is_last = i == len(channels) - 2
188
+ resnet = CausalResnetBlock1D(
189
+ dim=input_channel,
190
+ dim_out=output_channel,
191
+ time_emb_dim=time_embed_dim,
192
+ ) if self.causal else ResnetBlock1D(
193
+ dim=input_channel,
194
+ dim_out=output_channel,
195
+ time_emb_dim=time_embed_dim,
196
+ )
197
+ transformer_blocks = nn.ModuleList(
198
+ [
199
+ BasicTransformerBlock(
200
+ dim=output_channel,
201
+ num_attention_heads=num_heads,
202
+ attention_head_dim=attention_head_dim,
203
+ dropout=dropout,
204
+ activation_fn=act_fn,
205
+ )
206
+ for _ in range(n_blocks)
207
+ ]
208
+ )
209
+ upsample = (
210
+ Upsample1D(output_channel, use_conv_transpose=True)
211
+ if not is_last
212
+ else CausalConv1d(output_channel, output_channel, 3) if self.causal else nn.Conv1d(output_channel, output_channel, 3, padding=1)
213
+ )
214
+ self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
215
+ self.final_block = CausalBlock1D(channels[-1], channels[-1]) if self.causal else Block1D(channels[-1], channels[-1])
216
+ self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
217
+ self.initialize_weights()
218
+
219
+ def initialize_weights(self):
220
+ for m in self.modules():
221
+ if isinstance(m, nn.Conv1d):
222
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
223
+ if m.bias is not None:
224
+ nn.init.constant_(m.bias, 0)
225
+ elif isinstance(m, nn.GroupNorm):
226
+ nn.init.constant_(m.weight, 1)
227
+ nn.init.constant_(m.bias, 0)
228
+ elif isinstance(m, nn.Linear):
229
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
230
+ if m.bias is not None:
231
+ nn.init.constant_(m.bias, 0)
232
+
233
+ def forward(self, x, mask, mu, t, spks=None, cond=None):
234
+ """Forward pass of the UNet1DConditional model.
235
+
236
+ Args:
237
+ x (torch.Tensor): shape (batch_size, in_channels, time)
238
+ mask (_type_): shape (batch_size, 1, time)
239
+ t (_type_): shape (batch_size)
240
+ spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
241
+ cond (_type_, optional): placeholder for future use. Defaults to None.
242
+
243
+ Raises:
244
+ ValueError: _description_
245
+ ValueError: _description_
246
+
247
+ Returns:
248
+ _type_: _description_
249
+ """
250
+
251
+ t = self.time_embeddings(t).to(t.dtype)
252
+ t = self.time_mlp(t)
253
+
254
+ x = pack([x, mu], "b * t")[0]
255
+
256
+ if spks is not None:
257
+ spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
258
+ x = pack([x, spks], "b * t")[0]
259
+ if cond is not None:
260
+ x = pack([x, cond], "b * t")[0]
261
+
262
+ hiddens = []
263
+ masks = [mask]
264
+ for resnet, transformer_blocks, downsample in self.down_blocks:
265
+ mask_down = masks[-1]
266
+ x = resnet(x, mask_down, t)
267
+ x = rearrange(x, "b c t -> b t c").contiguous()
268
+ # attn_mask = torch.matmul(mask_down.transpose(1, 2).contiguous(), mask_down)
269
+ attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
270
+ attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
271
+ for transformer_block in transformer_blocks:
272
+ x = transformer_block(
273
+ hidden_states=x,
274
+ attention_mask=attn_mask,
275
+ timestep=t,
276
+ )
277
+ x = rearrange(x, "b t c -> b c t").contiguous()
278
+ hiddens.append(x) # Save hidden states for skip connections
279
+ x = downsample(x * mask_down)
280
+ masks.append(mask_down[:, :, ::2])
281
+ masks = masks[:-1]
282
+ mask_mid = masks[-1]
283
+
284
+ for resnet, transformer_blocks in self.mid_blocks:
285
+ x = resnet(x, mask_mid, t)
286
+ x = rearrange(x, "b c t -> b t c").contiguous()
287
+ # attn_mask = torch.matmul(mask_mid.transpose(1, 2).contiguous(), mask_mid)
288
+ attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
289
+ attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
290
+ for transformer_block in transformer_blocks:
291
+ x = transformer_block(
292
+ hidden_states=x,
293
+ attention_mask=attn_mask,
294
+ timestep=t,
295
+ )
296
+ x = rearrange(x, "b t c -> b c t").contiguous()
297
+
298
+ for resnet, transformer_blocks, upsample in self.up_blocks:
299
+ mask_up = masks.pop()
300
+ skip = hiddens.pop()
301
+ x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
302
+ x = resnet(x, mask_up, t)
303
+ x = rearrange(x, "b c t -> b t c").contiguous()
304
+ # attn_mask = torch.matmul(mask_up.transpose(1, 2).contiguous(), mask_up)
305
+ attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
306
+ attn_mask = mask_to_bias(attn_mask == 1, x.dtype)
307
+ for transformer_block in transformer_blocks:
308
+ x = transformer_block(
309
+ hidden_states=x,
310
+ attention_mask=attn_mask,
311
+ timestep=t,
312
+ )
313
+ x = rearrange(x, "b t c -> b c t").contiguous()
314
+ x = upsample(x * mask_up)
315
+ x = self.final_block(x, mask_up)
316
+ output = self.final_proj(x * mask_up)
317
+ return output * mask
HF_Deploy/src/chatterbox/models/s3gen/f0_predictor.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.nn.utils.parametrizations import weight_norm
17
+
18
+
19
+ class ConvRNNF0Predictor(nn.Module):
20
+ def __init__(self,
21
+ num_class: int = 1,
22
+ in_channels: int = 80,
23
+ cond_channels: int = 512
24
+ ):
25
+ super().__init__()
26
+
27
+ self.num_class = num_class
28
+ self.condnet = nn.Sequential(
29
+ weight_norm(
30
+ nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
31
+ ),
32
+ nn.ELU(),
33
+ weight_norm(
34
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
35
+ ),
36
+ nn.ELU(),
37
+ weight_norm(
38
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
39
+ ),
40
+ nn.ELU(),
41
+ weight_norm(
42
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
43
+ ),
44
+ nn.ELU(),
45
+ weight_norm(
46
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
47
+ ),
48
+ nn.ELU(),
49
+ )
50
+ self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
51
+
52
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
53
+ x = self.condnet(x)
54
+ x = x.transpose(1, 2)
55
+ return torch.abs(self.classifier(x).squeeze(-1))
HF_Deploy/src/chatterbox/models/s3gen/flow.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import logging
15
+ import random
16
+ from typing import Dict, Optional
17
+ import torch
18
+ import torch.nn as nn
19
+ from torch.nn import functional as F
20
+ from omegaconf import DictConfig
21
+ from .utils.mask import make_pad_mask
22
+
23
+
24
+ class MaskedDiffWithXvec(torch.nn.Module):
25
+ def __init__(self,
26
+ input_size: int = 512,
27
+ output_size: int = 80,
28
+ spk_embed_dim: int = 192,
29
+ output_type: str = "mel",
30
+ vocab_size: int = 4096,
31
+ input_frame_rate: int = 50,
32
+ only_mask_loss: bool = True,
33
+ encoder: torch.nn.Module = None,
34
+ length_regulator: torch.nn.Module = None,
35
+ decoder: torch.nn.Module = None,
36
+ decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
37
+ 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
38
+ 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
39
+ 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
40
+ 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
41
+ mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
42
+ 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
43
+ super().__init__()
44
+ self.input_size = input_size
45
+ self.output_size = output_size
46
+ self.decoder_conf = decoder_conf
47
+ self.mel_feat_conf = mel_feat_conf
48
+ self.vocab_size = vocab_size
49
+ self.output_type = output_type
50
+ self.input_frame_rate = input_frame_rate
51
+ logging.info(f"input frame rate={self.input_frame_rate}")
52
+ self.input_embedding = nn.Embedding(vocab_size, input_size)
53
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
54
+ self.encoder = encoder
55
+ self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
56
+ self.decoder = decoder
57
+ self.length_regulator = length_regulator
58
+ self.only_mask_loss = only_mask_loss
59
+
60
+ def forward(
61
+ self,
62
+ batch: dict,
63
+ device: torch.device,
64
+ ) -> Dict[str, Optional[torch.Tensor]]:
65
+ token = batch['speech_token'].to(device)
66
+ token_len = batch['speech_token_len'].to(device)
67
+ feat = batch['speech_feat'].to(device)
68
+ feat_len = batch['speech_feat_len'].to(device)
69
+ embedding = batch['embedding'].to(device)
70
+
71
+ # xvec projection
72
+ embedding = F.normalize(embedding, dim=1)
73
+ embedding = self.spk_embed_affine_layer(embedding)
74
+
75
+ # concat text and prompt_text
76
+ mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
77
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
78
+
79
+ # text encode
80
+ h, h_lengths = self.encoder(token, token_len)
81
+ h = self.encoder_proj(h)
82
+ h, h_lengths = self.length_regulator(h, feat_len)
83
+
84
+ # get conditions
85
+ conds = torch.zeros(feat.shape, device=token.device)
86
+ for i, j in enumerate(feat_len):
87
+ if random.random() < 0.5:
88
+ continue
89
+ index = random.randint(0, int(0.3 * j))
90
+ conds[i, :index] = feat[i, :index]
91
+ conds = conds.transpose(1, 2)
92
+
93
+ mask = (~make_pad_mask(feat_len)).to(h)
94
+ feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
95
+ loss, _ = self.decoder.compute_loss(
96
+ feat.transpose(1, 2).contiguous(),
97
+ mask.unsqueeze(1),
98
+ h.transpose(1, 2).contiguous(),
99
+ embedding,
100
+ cond=conds
101
+ )
102
+ return {'loss': loss}
103
+
104
+ @torch.inference_mode()
105
+ def inference(self,
106
+ token,
107
+ token_len,
108
+ prompt_token,
109
+ prompt_token_len,
110
+ prompt_feat,
111
+ prompt_feat_len,
112
+ embedding,
113
+ flow_cache):
114
+ if self.fp16 is True:
115
+ prompt_feat = prompt_feat.half()
116
+ embedding = embedding.half()
117
+
118
+ assert token.shape[0] == 1
119
+ # xvec projection
120
+ embedding = F.normalize(embedding, dim=1)
121
+ embedding = self.spk_embed_affine_layer(embedding)
122
+
123
+ # concat text and prompt_text
124
+ token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
125
+ token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
126
+ mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
127
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
128
+
129
+ # text encode
130
+ h, h_lengths = self.encoder(token, token_len)
131
+ h = self.encoder_proj(h)
132
+ mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
133
+ h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
134
+
135
+ # get conditions
136
+ conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
137
+ conds[:, :mel_len1] = prompt_feat
138
+ conds = conds.transpose(1, 2)
139
+
140
+ mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
141
+ feat, flow_cache = self.decoder(
142
+ mu=h.transpose(1, 2).contiguous(),
143
+ mask=mask.unsqueeze(1),
144
+ spks=embedding,
145
+ cond=conds,
146
+ n_timesteps=10,
147
+ prompt_len=mel_len1,
148
+ flow_cache=flow_cache
149
+ )
150
+ feat = feat[:, :, mel_len1:]
151
+ assert feat.shape[2] == mel_len2
152
+ return feat.float(), flow_cache
153
+
154
+
155
+ class CausalMaskedDiffWithXvec(torch.nn.Module):
156
+ def __init__(self,
157
+ input_size: int = 512,
158
+ output_size: int = 80,
159
+ spk_embed_dim: int = 192,
160
+ output_type: str = "mel",
161
+ vocab_size: int = 6561,
162
+ input_frame_rate: int = 25,
163
+ only_mask_loss: bool = True,
164
+ token_mel_ratio: int = 2,
165
+ pre_lookahead_len: int = 3,
166
+ encoder: torch.nn.Module = None,
167
+ decoder: torch.nn.Module = None,
168
+ decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
169
+ 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
170
+ 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
171
+ 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
172
+ 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
173
+ mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
174
+ 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
175
+ super().__init__()
176
+ self.input_size = input_size
177
+ self.output_size = output_size
178
+ self.decoder_conf = decoder_conf
179
+ self.mel_feat_conf = mel_feat_conf
180
+ self.vocab_size = vocab_size
181
+ self.output_type = output_type
182
+ self.input_frame_rate = input_frame_rate
183
+ logging.info(f"input frame rate={self.input_frame_rate}")
184
+ self.input_embedding = nn.Embedding(vocab_size, input_size)
185
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
186
+ self.encoder = encoder
187
+ self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
188
+ self.decoder = decoder
189
+ self.only_mask_loss = only_mask_loss
190
+ self.token_mel_ratio = token_mel_ratio
191
+ self.pre_lookahead_len = pre_lookahead_len
192
+
193
+ # FIXME: this was missing - just putting it in as false
194
+ self.fp16 = False
195
+
196
+ @torch.inference_mode()
197
+ def inference(self,
198
+ token,
199
+ token_len,
200
+ prompt_token,
201
+ prompt_token_len,
202
+ prompt_feat,
203
+ prompt_feat_len,
204
+ embedding,
205
+ finalize):
206
+ if self.fp16 is True:
207
+ prompt_feat = prompt_feat.half()
208
+ embedding = embedding.half()
209
+
210
+ assert token.shape[0] == 1
211
+ # xvec projection
212
+ embedding = F.normalize(embedding, dim=1)
213
+ embedding = self.spk_embed_affine_layer(embedding)
214
+
215
+ # concat text and prompt_text
216
+ token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
217
+ mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
218
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
219
+
220
+ # text encode
221
+ h, h_lengths = self.encoder(token, token_len)
222
+ if finalize is False:
223
+ h = h[:, :-self.pre_lookahead_len * self.token_mel_ratio]
224
+ mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
225
+ h = self.encoder_proj(h)
226
+
227
+ # get conditions
228
+ conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
229
+ conds[:, :mel_len1] = prompt_feat
230
+ conds = conds.transpose(1, 2)
231
+
232
+ mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
233
+ feat, _ = self.decoder(
234
+ mu=h.transpose(1, 2).contiguous(),
235
+ mask=mask.unsqueeze(1),
236
+ spks=embedding,
237
+ cond=conds,
238
+ n_timesteps=10
239
+ )
240
+ feat = feat[:, :, mel_len1:]
241
+ assert feat.shape[2] == mel_len2
242
+ return feat.float(), None # NOTE jrm: why are they returning None here?
HF_Deploy/src/chatterbox/models/s3gen/flow_matching.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import threading
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from .matcha.flow_matching import BASECFM
18
+ from omegaconf import OmegaConf
19
+
20
+
21
+ CFM_PARAMS = OmegaConf.create({
22
+ "sigma_min": 1e-06,
23
+ "solver": "euler",
24
+ "t_scheduler": "cosine",
25
+ "training_cfg_rate": 0.2,
26
+ "inference_cfg_rate": 0.7,
27
+ "reg_loss_type": "l1"
28
+ })
29
+
30
+
31
+ class ConditionalCFM(BASECFM):
32
+ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
33
+ super().__init__(
34
+ n_feats=in_channels,
35
+ cfm_params=cfm_params,
36
+ n_spks=n_spks,
37
+ spk_emb_dim=spk_emb_dim,
38
+ )
39
+ self.t_scheduler = cfm_params.t_scheduler
40
+ self.training_cfg_rate = cfm_params.training_cfg_rate
41
+ self.inference_cfg_rate = cfm_params.inference_cfg_rate
42
+ in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
43
+ # Just change the architecture of the estimator here
44
+ self.estimator = estimator
45
+ self.lock = threading.Lock()
46
+
47
+ @torch.inference_mode()
48
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
49
+ """Forward diffusion
50
+
51
+ Args:
52
+ mu (torch.Tensor): output of encoder
53
+ shape: (batch_size, n_feats, mel_timesteps)
54
+ mask (torch.Tensor): output_mask
55
+ shape: (batch_size, 1, mel_timesteps)
56
+ n_timesteps (int): number of diffusion steps
57
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
58
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
59
+ shape: (batch_size, spk_emb_dim)
60
+ cond: Not used but kept for future purposes
61
+
62
+ Returns:
63
+ sample: generated mel-spectrogram
64
+ shape: (batch_size, n_feats, mel_timesteps)
65
+ """
66
+
67
+ z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
68
+ cache_size = flow_cache.shape[2]
69
+ # fix prompt and overlap part mu and z
70
+ if cache_size != 0:
71
+ z[:, :, :cache_size] = flow_cache[:, :, :, 0]
72
+ mu[:, :, :cache_size] = flow_cache[:, :, :, 1]
73
+ z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
74
+ mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
75
+ flow_cache = torch.stack([z_cache, mu_cache], dim=-1)
76
+
77
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
78
+ if self.t_scheduler == 'cosine':
79
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
80
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache
81
+
82
+ def solve_euler(self, x, t_span, mu, mask, spks, cond):
83
+ """
84
+ Fixed euler solver for ODEs.
85
+ Args:
86
+ x (torch.Tensor): random noise
87
+ t_span (torch.Tensor): n_timesteps interpolated
88
+ shape: (n_timesteps + 1,)
89
+ mu (torch.Tensor): output of encoder
90
+ shape: (batch_size, n_feats, mel_timesteps)
91
+ mask (torch.Tensor): output_mask
92
+ shape: (batch_size, 1, mel_timesteps)
93
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
94
+ shape: (batch_size, spk_emb_dim)
95
+ cond: Not used but kept for future purposes
96
+ """
97
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
98
+ t = t.unsqueeze(dim=0)
99
+
100
+ # I am storing this because I can later plot it by putting a debugger here and saving it to a file
101
+ # Or in future might add like a return_all_steps flag
102
+ sol = []
103
+
104
+ # Do not use concat, it may cause memory format changed and trt infer with wrong results!
105
+ x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
106
+ mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=x.dtype)
107
+ mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
108
+ t_in = torch.zeros([2], device=x.device, dtype=x.dtype)
109
+ spks_in = torch.zeros([2, 80], device=x.device, dtype=x.dtype)
110
+ cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=x.dtype)
111
+ for step in range(1, len(t_span)):
112
+ # Classifier-Free Guidance inference introduced in VoiceBox
113
+ x_in[:] = x
114
+ mask_in[:] = mask
115
+ mu_in[0] = mu
116
+ t_in[:] = t.unsqueeze(0)
117
+ spks_in[0] = spks
118
+ cond_in[0] = cond
119
+ dphi_dt = self.forward_estimator(
120
+ x_in, mask_in,
121
+ mu_in, t_in,
122
+ spks_in,
123
+ cond_in
124
+ )
125
+ dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
126
+ dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
127
+ x = x + dt * dphi_dt
128
+ t = t + dt
129
+ sol.append(x)
130
+ if step < len(t_span) - 1:
131
+ dt = t_span[step + 1] - t
132
+
133
+ return sol[-1].float()
134
+
135
+ def forward_estimator(self, x, mask, mu, t, spks, cond):
136
+ if isinstance(self.estimator, torch.nn.Module):
137
+ return self.estimator.forward(x, mask, mu, t, spks, cond)
138
+ else:
139
+ with self.lock:
140
+ self.estimator.set_input_shape('x', (2, 80, x.size(2)))
141
+ self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
142
+ self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
143
+ self.estimator.set_input_shape('t', (2,))
144
+ self.estimator.set_input_shape('spks', (2, 80))
145
+ self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
146
+ # run trt engine
147
+ self.estimator.execute_v2([x.contiguous().data_ptr(),
148
+ mask.contiguous().data_ptr(),
149
+ mu.contiguous().data_ptr(),
150
+ t.contiguous().data_ptr(),
151
+ spks.contiguous().data_ptr(),
152
+ cond.contiguous().data_ptr(),
153
+ x.data_ptr()])
154
+ return x
155
+
156
+ def compute_loss(self, x1, mask, mu, spks=None, cond=None):
157
+ """Computes diffusion loss
158
+
159
+ Args:
160
+ x1 (torch.Tensor): Target
161
+ shape: (batch_size, n_feats, mel_timesteps)
162
+ mask (torch.Tensor): target mask
163
+ shape: (batch_size, 1, mel_timesteps)
164
+ mu (torch.Tensor): output of encoder
165
+ shape: (batch_size, n_feats, mel_timesteps)
166
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
167
+ shape: (batch_size, spk_emb_dim)
168
+
169
+ Returns:
170
+ loss: conditional flow matching loss
171
+ y: conditional flow
172
+ shape: (batch_size, n_feats, mel_timesteps)
173
+ """
174
+ b, _, t = mu.shape
175
+
176
+ # random timestep
177
+ t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
178
+ if self.t_scheduler == 'cosine':
179
+ t = 1 - torch.cos(t * 0.5 * torch.pi)
180
+ # sample noise p(x_0)
181
+ z = torch.randn_like(x1)
182
+
183
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
184
+ u = x1 - (1 - self.sigma_min) * z
185
+
186
+ # during training, we randomly drop condition to trade off mode coverage and sample fidelity
187
+ if self.training_cfg_rate > 0:
188
+ cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
189
+ mu = mu * cfg_mask.view(-1, 1, 1)
190
+ spks = spks * cfg_mask.view(-1, 1)
191
+ cond = cond * cfg_mask.view(-1, 1, 1)
192
+
193
+ pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
194
+ loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
195
+ return loss, y
196
+
197
+
198
+ class CausalConditionalCFM(ConditionalCFM):
199
+ def __init__(self, in_channels=240, cfm_params=CFM_PARAMS, n_spks=1, spk_emb_dim=80, estimator=None):
200
+ super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
201
+ self.rand_noise = torch.randn([1, 80, 50 * 300])
202
+
203
+ @torch.inference_mode()
204
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
205
+ """Forward diffusion
206
+
207
+ Args:
208
+ mu (torch.Tensor): output of encoder
209
+ shape: (batch_size, n_feats, mel_timesteps)
210
+ mask (torch.Tensor): output_mask
211
+ shape: (batch_size, 1, mel_timesteps)
212
+ n_timesteps (int): number of diffusion steps
213
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
214
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
215
+ shape: (batch_size, spk_emb_dim)
216
+ cond: Not used but kept for future purposes
217
+
218
+ Returns:
219
+ sample: generated mel-spectrogram
220
+ shape: (batch_size, n_feats, mel_timesteps)
221
+ """
222
+
223
+ z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
224
+ # fix prompt and overlap part mu and z
225
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
226
+ if self.t_scheduler == 'cosine':
227
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
228
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), None
HF_Deploy/src/chatterbox/models/s3gen/hifigan.py ADDED
@@ -0,0 +1,474 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # jrm: adapted from CosyVoice/cosyvoice/hifigan/generator.py
2
+ # most modules should be reusable, but I found their SineGen changed a git.
3
+
4
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ """HIFI-GAN"""
19
+
20
+ from typing import Dict, Optional, List
21
+ import numpy as np
22
+ from scipy.signal import get_window
23
+ import torch
24
+ import torch.nn.functional as F
25
+ from torch.nn import Conv1d
26
+ from torch.nn import ConvTranspose1d
27
+ from torch.nn.utils import remove_weight_norm
28
+ from torch.nn.utils.parametrizations import weight_norm
29
+ from torch.distributions.uniform import Uniform
30
+ from torch import nn, sin, pow
31
+ from torch.nn import Parameter
32
+
33
+
34
+ class Snake(nn.Module):
35
+ '''
36
+ Implementation of a sine-based periodic activation function
37
+ Shape:
38
+ - Input: (B, C, T)
39
+ - Output: (B, C, T), same shape as the input
40
+ Parameters:
41
+ - alpha - trainable parameter
42
+ References:
43
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
44
+ https://arxiv.org/abs/2006.08195
45
+ Examples:
46
+ >>> a1 = snake(256)
47
+ >>> x = torch.randn(256)
48
+ >>> x = a1(x)
49
+ '''
50
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
51
+ '''
52
+ Initialization.
53
+ INPUT:
54
+ - in_features: shape of the input
55
+ - alpha: trainable parameter
56
+ alpha is initialized to 1 by default, higher values = higher-frequency.
57
+ alpha will be trained along with the rest of your model.
58
+ '''
59
+ super(Snake, self).__init__()
60
+ self.in_features = in_features
61
+
62
+ # initialize alpha
63
+ self.alpha_logscale = alpha_logscale
64
+ if self.alpha_logscale: # log scale alphas initialized to zeros
65
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
66
+ else: # linear scale alphas initialized to ones
67
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
68
+
69
+ self.alpha.requires_grad = alpha_trainable
70
+
71
+ self.no_div_by_zero = 0.000000001
72
+
73
+ def forward(self, x):
74
+ '''
75
+ Forward pass of the function.
76
+ Applies the function to the input elementwise.
77
+ Snake ∶= x + 1/a * sin^2 (xa)
78
+ '''
79
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
80
+ if self.alpha_logscale:
81
+ alpha = torch.exp(alpha)
82
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
83
+
84
+ return x
85
+
86
+
87
+
88
+ def get_padding(kernel_size, dilation=1):
89
+ return int((kernel_size * dilation - dilation) / 2)
90
+
91
+ def init_weights(m, mean=0.0, std=0.01):
92
+ classname = m.__class__.__name__
93
+ if classname.find("Conv") != -1:
94
+ m.weight.data.normal_(mean, std)
95
+
96
+
97
+ """hifigan based generator implementation.
98
+
99
+ This code is modified from https://github.com/jik876/hifi-gan
100
+ ,https://github.com/kan-bayashi/ParallelWaveGAN and
101
+ https://github.com/NVIDIA/BigVGAN
102
+
103
+ """
104
+
105
+
106
+ class ResBlock(torch.nn.Module):
107
+ """Residual block module in HiFiGAN/BigVGAN."""
108
+ def __init__(
109
+ self,
110
+ channels: int = 512,
111
+ kernel_size: int = 3,
112
+ dilations: List[int] = [1, 3, 5],
113
+ ):
114
+ super(ResBlock, self).__init__()
115
+ self.convs1 = nn.ModuleList()
116
+ self.convs2 = nn.ModuleList()
117
+
118
+ for dilation in dilations:
119
+ self.convs1.append(
120
+ weight_norm(
121
+ Conv1d(
122
+ channels,
123
+ channels,
124
+ kernel_size,
125
+ 1,
126
+ dilation=dilation,
127
+ padding=get_padding(kernel_size, dilation)
128
+ )
129
+ )
130
+ )
131
+ self.convs2.append(
132
+ weight_norm(
133
+ Conv1d(
134
+ channels,
135
+ channels,
136
+ kernel_size,
137
+ 1,
138
+ dilation=1,
139
+ padding=get_padding(kernel_size, 1)
140
+ )
141
+ )
142
+ )
143
+ self.convs1.apply(init_weights)
144
+ self.convs2.apply(init_weights)
145
+ self.activations1 = nn.ModuleList([
146
+ Snake(channels, alpha_logscale=False)
147
+ for _ in range(len(self.convs1))
148
+ ])
149
+ self.activations2 = nn.ModuleList([
150
+ Snake(channels, alpha_logscale=False)
151
+ for _ in range(len(self.convs2))
152
+ ])
153
+
154
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
155
+ for idx in range(len(self.convs1)):
156
+ xt = self.activations1[idx](x)
157
+ xt = self.convs1[idx](xt)
158
+ xt = self.activations2[idx](xt)
159
+ xt = self.convs2[idx](xt)
160
+ x = xt + x
161
+ return x
162
+
163
+ def remove_weight_norm(self):
164
+ for idx in range(len(self.convs1)):
165
+ remove_weight_norm(self.convs1[idx])
166
+ remove_weight_norm(self.convs2[idx])
167
+
168
+
169
+ class SineGen(torch.nn.Module):
170
+ """ Definition of sine generator
171
+ SineGen(samp_rate, harmonic_num = 0,
172
+ sine_amp = 0.1, noise_std = 0.003,
173
+ voiced_threshold = 0,
174
+ flag_for_pulse=False)
175
+ samp_rate: sampling rate in Hz
176
+ harmonic_num: number of harmonic overtones (default 0)
177
+ sine_amp: amplitude of sine-wavefrom (default 0.1)
178
+ noise_std: std of Gaussian noise (default 0.003)
179
+ voiced_thoreshold: F0 threshold for U/V classification (default 0)
180
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
181
+ Note: when flag_for_pulse is True, the first time step of a voiced
182
+ segment is always sin(np.pi) or cos(0)
183
+ """
184
+
185
+ def __init__(self, samp_rate, harmonic_num=0,
186
+ sine_amp=0.1, noise_std=0.003,
187
+ voiced_threshold=0):
188
+ super(SineGen, self).__init__()
189
+ self.sine_amp = sine_amp
190
+ self.noise_std = noise_std
191
+ self.harmonic_num = harmonic_num
192
+ self.sampling_rate = samp_rate
193
+ self.voiced_threshold = voiced_threshold
194
+
195
+ def _f02uv(self, f0):
196
+ # generate uv signal
197
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
198
+ return uv
199
+
200
+ @torch.no_grad()
201
+ def forward(self, f0):
202
+ """
203
+ :param f0: [B, 1, sample_len], Hz
204
+ :return: [B, 1, sample_len]
205
+ """
206
+
207
+ F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
208
+ for i in range(self.harmonic_num + 1):
209
+ F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
210
+
211
+ theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
212
+ u_dist = Uniform(low=-np.pi, high=np.pi)
213
+ phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
214
+ phase_vec[:, 0, :] = 0
215
+
216
+ # generate sine waveforms
217
+ sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
218
+
219
+ # generate uv signal
220
+ uv = self._f02uv(f0)
221
+
222
+ # noise: for unvoiced should be similar to sine_amp
223
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
224
+ # . for voiced regions is self.noise_std
225
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
226
+ noise = noise_amp * torch.randn_like(sine_waves)
227
+
228
+ # first: set the unvoiced part to 0 by uv
229
+ # then: additive noise
230
+ sine_waves = sine_waves * uv + noise
231
+ return sine_waves, uv, noise
232
+
233
+
234
+ class SourceModuleHnNSF(torch.nn.Module):
235
+ """ SourceModule for hn-nsf
236
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
237
+ add_noise_std=0.003, voiced_threshod=0)
238
+ sampling_rate: sampling_rate in Hz
239
+ harmonic_num: number of harmonic above F0 (default: 0)
240
+ sine_amp: amplitude of sine source signal (default: 0.1)
241
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
242
+ note that amplitude of noise in unvoiced is decided
243
+ by sine_amp
244
+ voiced_threshold: threhold to set U/V given F0 (default: 0)
245
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
246
+ F0_sampled (batchsize, length, 1)
247
+ Sine_source (batchsize, length, 1)
248
+ noise_source (batchsize, length 1)
249
+ uv (batchsize, length, 1)
250
+ """
251
+
252
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
253
+ add_noise_std=0.003, voiced_threshod=0):
254
+ super(SourceModuleHnNSF, self).__init__()
255
+
256
+ self.sine_amp = sine_amp
257
+ self.noise_std = add_noise_std
258
+
259
+ # to produce sine waveforms
260
+ self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
261
+ sine_amp, add_noise_std, voiced_threshod)
262
+
263
+ # to merge source harmonics into a single excitation
264
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
265
+ self.l_tanh = torch.nn.Tanh()
266
+
267
+ def forward(self, x):
268
+ """
269
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
270
+ F0_sampled (batchsize, length, 1)
271
+ Sine_source (batchsize, length, 1)
272
+ noise_source (batchsize, length 1)
273
+ """
274
+ # source for harmonic branch
275
+ with torch.no_grad():
276
+ sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
277
+ sine_wavs = sine_wavs.transpose(1, 2)
278
+ uv = uv.transpose(1, 2)
279
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
280
+
281
+ # source for noise branch, in the same shape as uv
282
+ noise = torch.randn_like(uv) * self.sine_amp / 3
283
+ return sine_merge, noise, uv
284
+
285
+
286
+ class HiFTGenerator(nn.Module):
287
+ """
288
+ HiFTNet Generator: Neural Source Filter + ISTFTNet
289
+ https://arxiv.org/abs/2309.09493
290
+ """
291
+ def __init__(
292
+ self,
293
+ in_channels: int = 80,
294
+ base_channels: int = 512,
295
+ nb_harmonics: int = 8,
296
+ sampling_rate: int = 22050,
297
+ nsf_alpha: float = 0.1,
298
+ nsf_sigma: float = 0.003,
299
+ nsf_voiced_threshold: float = 10,
300
+ upsample_rates: List[int] = [8, 8],
301
+ upsample_kernel_sizes: List[int] = [16, 16],
302
+ istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
303
+ resblock_kernel_sizes: List[int] = [3, 7, 11],
304
+ resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
305
+ source_resblock_kernel_sizes: List[int] = [7, 11],
306
+ source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
307
+ lrelu_slope: float = 0.1,
308
+ audio_limit: float = 0.99,
309
+ f0_predictor: torch.nn.Module = None,
310
+ ):
311
+ super(HiFTGenerator, self).__init__()
312
+
313
+ self.out_channels = 1
314
+ self.nb_harmonics = nb_harmonics
315
+ self.sampling_rate = sampling_rate
316
+ self.istft_params = istft_params
317
+ self.lrelu_slope = lrelu_slope
318
+ self.audio_limit = audio_limit
319
+
320
+ self.num_kernels = len(resblock_kernel_sizes)
321
+ self.num_upsamples = len(upsample_rates)
322
+ self.m_source = SourceModuleHnNSF(
323
+ sampling_rate=sampling_rate,
324
+ upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
325
+ harmonic_num=nb_harmonics,
326
+ sine_amp=nsf_alpha,
327
+ add_noise_std=nsf_sigma,
328
+ voiced_threshod=nsf_voiced_threshold)
329
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
330
+
331
+ self.conv_pre = weight_norm(
332
+ Conv1d(in_channels, base_channels, 7, 1, padding=3)
333
+ )
334
+
335
+ # Up
336
+ self.ups = nn.ModuleList()
337
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
338
+ self.ups.append(
339
+ weight_norm(
340
+ ConvTranspose1d(
341
+ base_channels // (2**i),
342
+ base_channels // (2**(i + 1)),
343
+ k,
344
+ u,
345
+ padding=(k - u) // 2,
346
+ )
347
+ )
348
+ )
349
+
350
+ # Down
351
+ self.source_downs = nn.ModuleList()
352
+ self.source_resblocks = nn.ModuleList()
353
+ downsample_rates = [1] + upsample_rates[::-1][:-1]
354
+ downsample_cum_rates = np.cumprod(downsample_rates)
355
+ for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
356
+ if u == 1:
357
+ self.source_downs.append(
358
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
359
+ )
360
+ else:
361
+ self.source_downs.append(
362
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
363
+ )
364
+
365
+ self.source_resblocks.append(
366
+ ResBlock(base_channels // (2 ** (i + 1)), k, d)
367
+ )
368
+
369
+ self.resblocks = nn.ModuleList()
370
+ for i in range(len(self.ups)):
371
+ ch = base_channels // (2**(i + 1))
372
+ for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
373
+ self.resblocks.append(ResBlock(ch, k, d))
374
+
375
+ self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
376
+ self.ups.apply(init_weights)
377
+ self.conv_post.apply(init_weights)
378
+ self.reflection_pad = nn.ReflectionPad1d((1, 0))
379
+ self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
380
+ self.f0_predictor = f0_predictor
381
+
382
+ def remove_weight_norm(self):
383
+ print('Removing weight norm...')
384
+ for l in self.ups:
385
+ remove_weight_norm(l)
386
+ for l in self.resblocks:
387
+ l.remove_weight_norm()
388
+ remove_weight_norm(self.conv_pre)
389
+ remove_weight_norm(self.conv_post)
390
+ self.m_source.remove_weight_norm()
391
+ for l in self.source_downs:
392
+ remove_weight_norm(l)
393
+ for l in self.source_resblocks:
394
+ l.remove_weight_norm()
395
+
396
+ def _stft(self, x):
397
+ spec = torch.stft(
398
+ x,
399
+ self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
400
+ return_complex=True)
401
+ spec = torch.view_as_real(spec) # [B, F, TT, 2]
402
+ return spec[..., 0], spec[..., 1]
403
+
404
+ def _istft(self, magnitude, phase):
405
+ magnitude = torch.clip(magnitude, max=1e2)
406
+ real = magnitude * torch.cos(phase)
407
+ img = magnitude * torch.sin(phase)
408
+ inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
409
+ self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
410
+ return inverse_transform
411
+
412
+ def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
413
+ s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
414
+ s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
415
+
416
+ x = self.conv_pre(x)
417
+ for i in range(self.num_upsamples):
418
+ x = F.leaky_relu(x, self.lrelu_slope)
419
+ x = self.ups[i](x)
420
+
421
+ if i == self.num_upsamples - 1:
422
+ x = self.reflection_pad(x)
423
+
424
+ # fusion
425
+ si = self.source_downs[i](s_stft)
426
+ si = self.source_resblocks[i](si)
427
+ x = x + si
428
+
429
+ xs = None
430
+ for j in range(self.num_kernels):
431
+ if xs is None:
432
+ xs = self.resblocks[i * self.num_kernels + j](x)
433
+ else:
434
+ xs += self.resblocks[i * self.num_kernels + j](x)
435
+ x = xs / self.num_kernels
436
+
437
+ x = F.leaky_relu(x)
438
+ x = self.conv_post(x)
439
+ magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
440
+ phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy
441
+
442
+ x = self._istft(magnitude, phase)
443
+ x = torch.clamp(x, -self.audio_limit, self.audio_limit)
444
+ return x
445
+
446
+ def forward(
447
+ self,
448
+ batch: dict,
449
+ device: torch.device,
450
+ ) -> Dict[str, Optional[torch.Tensor]]:
451
+ speech_feat = batch['speech_feat'].transpose(1, 2).to(device)
452
+ # mel->f0
453
+ f0 = self.f0_predictor(speech_feat)
454
+ # f0->source
455
+ s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
456
+ s, _, _ = self.m_source(s)
457
+ s = s.transpose(1, 2)
458
+ # mel+source->speech
459
+ generated_speech = self.decode(x=speech_feat, s=s)
460
+ return generated_speech, f0
461
+
462
+ @torch.inference_mode()
463
+ def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
464
+ # mel->f0
465
+ f0 = self.f0_predictor(speech_feat)
466
+ # f0->source
467
+ s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
468
+ s, _, _ = self.m_source(s)
469
+ s = s.transpose(1, 2)
470
+ # use cache_source to avoid glitch
471
+ if cache_source.shape[2] != 0:
472
+ s[:, :, :cache_source.shape[2]] = cache_source
473
+ generated_speech = self.decode(x=speech_feat, s=s)
474
+ return generated_speech, s
HF_Deploy/src/chatterbox/models/s3gen/matcha/decoder.py ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from conformer import ConformerBlock
8
+ from diffusers.models.activations import get_activation
9
+ from einops import pack, rearrange, repeat
10
+
11
+ from .transformer import BasicTransformerBlock
12
+
13
+
14
+ class SinusoidalPosEmb(torch.nn.Module):
15
+ def __init__(self, dim):
16
+ super().__init__()
17
+ self.dim = dim
18
+ assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
19
+
20
+ def forward(self, x, scale=1000):
21
+ if x.ndim < 1:
22
+ x = x.unsqueeze(0)
23
+ device = x.device
24
+ half_dim = self.dim // 2
25
+ emb = math.log(10000) / (half_dim - 1)
26
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
27
+ emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
28
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
29
+ return emb
30
+
31
+
32
+ class Block1D(torch.nn.Module):
33
+ def __init__(self, dim, dim_out, groups=8):
34
+ super().__init__()
35
+ self.block = torch.nn.Sequential(
36
+ torch.nn.Conv1d(dim, dim_out, 3, padding=1),
37
+ torch.nn.GroupNorm(groups, dim_out),
38
+ nn.Mish(),
39
+ )
40
+
41
+ def forward(self, x, mask):
42
+ output = self.block(x * mask)
43
+ return output * mask
44
+
45
+
46
+ class ResnetBlock1D(torch.nn.Module):
47
+ def __init__(self, dim, dim_out, time_emb_dim, groups=8):
48
+ super().__init__()
49
+ self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out))
50
+
51
+ self.block1 = Block1D(dim, dim_out, groups=groups)
52
+ self.block2 = Block1D(dim_out, dim_out, groups=groups)
53
+
54
+ self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)
55
+
56
+ def forward(self, x, mask, time_emb):
57
+ h = self.block1(x, mask)
58
+ h += self.mlp(time_emb).unsqueeze(-1)
59
+ h = self.block2(h, mask)
60
+ output = h + self.res_conv(x * mask)
61
+ return output
62
+
63
+
64
+ class Downsample1D(nn.Module):
65
+ def __init__(self, dim):
66
+ super().__init__()
67
+ self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)
68
+
69
+ def forward(self, x):
70
+ return self.conv(x)
71
+
72
+
73
+ class TimestepEmbedding(nn.Module):
74
+ def __init__(
75
+ self,
76
+ in_channels: int,
77
+ time_embed_dim: int,
78
+ act_fn: str = "silu",
79
+ out_dim: int = None,
80
+ post_act_fn: Optional[str] = None,
81
+ cond_proj_dim=None,
82
+ ):
83
+ super().__init__()
84
+
85
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim)
86
+
87
+ if cond_proj_dim is not None:
88
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
89
+ else:
90
+ self.cond_proj = None
91
+
92
+ self.act = get_activation(act_fn)
93
+
94
+ if out_dim is not None:
95
+ time_embed_dim_out = out_dim
96
+ else:
97
+ time_embed_dim_out = time_embed_dim
98
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
99
+
100
+ if post_act_fn is None:
101
+ self.post_act = None
102
+ else:
103
+ self.post_act = get_activation(post_act_fn)
104
+
105
+ def forward(self, sample, condition=None):
106
+ if condition is not None:
107
+ sample = sample + self.cond_proj(condition)
108
+ sample = self.linear_1(sample)
109
+
110
+ if self.act is not None:
111
+ sample = self.act(sample)
112
+
113
+ sample = self.linear_2(sample)
114
+
115
+ if self.post_act is not None:
116
+ sample = self.post_act(sample)
117
+ return sample
118
+
119
+
120
+ class Upsample1D(nn.Module):
121
+ """A 1D upsampling layer with an optional convolution.
122
+
123
+ Parameters:
124
+ channels (`int`):
125
+ number of channels in the inputs and outputs.
126
+ use_conv (`bool`, default `False`):
127
+ option to use a convolution.
128
+ use_conv_transpose (`bool`, default `False`):
129
+ option to use a convolution transpose.
130
+ out_channels (`int`, optional):
131
+ number of output channels. Defaults to `channels`.
132
+ """
133
+
134
+ def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"):
135
+ super().__init__()
136
+ self.channels = channels
137
+ self.out_channels = out_channels or channels
138
+ self.use_conv = use_conv
139
+ self.use_conv_transpose = use_conv_transpose
140
+ self.name = name
141
+
142
+ self.conv = None
143
+ if use_conv_transpose:
144
+ self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
145
+ elif use_conv:
146
+ self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
147
+
148
+ def forward(self, inputs):
149
+ assert inputs.shape[1] == self.channels
150
+ if self.use_conv_transpose:
151
+ return self.conv(inputs)
152
+
153
+ outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
154
+
155
+ if self.use_conv:
156
+ outputs = self.conv(outputs)
157
+
158
+ return outputs
159
+
160
+
161
+ class ConformerWrapper(ConformerBlock):
162
+ def __init__( # pylint: disable=useless-super-delegation
163
+ self,
164
+ *,
165
+ dim,
166
+ dim_head=64,
167
+ heads=8,
168
+ ff_mult=4,
169
+ conv_expansion_factor=2,
170
+ conv_kernel_size=31,
171
+ attn_dropout=0,
172
+ ff_dropout=0,
173
+ conv_dropout=0,
174
+ conv_causal=False,
175
+ ):
176
+ super().__init__(
177
+ dim=dim,
178
+ dim_head=dim_head,
179
+ heads=heads,
180
+ ff_mult=ff_mult,
181
+ conv_expansion_factor=conv_expansion_factor,
182
+ conv_kernel_size=conv_kernel_size,
183
+ attn_dropout=attn_dropout,
184
+ ff_dropout=ff_dropout,
185
+ conv_dropout=conv_dropout,
186
+ conv_causal=conv_causal,
187
+ )
188
+
189
+ def forward(
190
+ self,
191
+ hidden_states,
192
+ attention_mask,
193
+ encoder_hidden_states=None,
194
+ encoder_attention_mask=None,
195
+ timestep=None,
196
+ ):
197
+ return super().forward(x=hidden_states, mask=attention_mask.bool())
198
+
199
+
200
+ class Decoder(nn.Module):
201
+ def __init__(
202
+ self,
203
+ in_channels,
204
+ out_channels,
205
+ channels=(256, 256),
206
+ dropout=0.05,
207
+ attention_head_dim=64,
208
+ n_blocks=1,
209
+ num_mid_blocks=2,
210
+ num_heads=4,
211
+ act_fn="snake",
212
+ down_block_type="transformer",
213
+ mid_block_type="transformer",
214
+ up_block_type="transformer",
215
+ ):
216
+ super().__init__()
217
+ channels = tuple(channels)
218
+ self.in_channels = in_channels
219
+ self.out_channels = out_channels
220
+
221
+ self.time_embeddings = SinusoidalPosEmb(in_channels)
222
+ time_embed_dim = channels[0] * 4
223
+ self.time_mlp = TimestepEmbedding(
224
+ in_channels=in_channels,
225
+ time_embed_dim=time_embed_dim,
226
+ act_fn="silu",
227
+ )
228
+
229
+ self.down_blocks = nn.ModuleList([])
230
+ self.mid_blocks = nn.ModuleList([])
231
+ self.up_blocks = nn.ModuleList([])
232
+
233
+ output_channel = in_channels
234
+ for i in range(len(channels)): # pylint: disable=consider-using-enumerate
235
+ input_channel = output_channel
236
+ output_channel = channels[i]
237
+ is_last = i == len(channels) - 1
238
+ resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
239
+ transformer_blocks = nn.ModuleList(
240
+ [
241
+ self.get_block(
242
+ down_block_type,
243
+ output_channel,
244
+ attention_head_dim,
245
+ num_heads,
246
+ dropout,
247
+ act_fn,
248
+ )
249
+ for _ in range(n_blocks)
250
+ ]
251
+ )
252
+ downsample = (
253
+ Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
254
+ )
255
+
256
+ self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
257
+
258
+ for i in range(num_mid_blocks):
259
+ input_channel = channels[-1]
260
+ out_channels = channels[-1]
261
+
262
+ resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
263
+
264
+ transformer_blocks = nn.ModuleList(
265
+ [
266
+ self.get_block(
267
+ mid_block_type,
268
+ output_channel,
269
+ attention_head_dim,
270
+ num_heads,
271
+ dropout,
272
+ act_fn,
273
+ )
274
+ for _ in range(n_blocks)
275
+ ]
276
+ )
277
+
278
+ self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
279
+
280
+ channels = channels[::-1] + (channels[0],)
281
+ for i in range(len(channels) - 1):
282
+ input_channel = channels[i]
283
+ output_channel = channels[i + 1]
284
+ is_last = i == len(channels) - 2
285
+
286
+ resnet = ResnetBlock1D(
287
+ dim=2 * input_channel,
288
+ dim_out=output_channel,
289
+ time_emb_dim=time_embed_dim,
290
+ )
291
+ transformer_blocks = nn.ModuleList(
292
+ [
293
+ self.get_block(
294
+ up_block_type,
295
+ output_channel,
296
+ attention_head_dim,
297
+ num_heads,
298
+ dropout,
299
+ act_fn,
300
+ )
301
+ for _ in range(n_blocks)
302
+ ]
303
+ )
304
+ upsample = (
305
+ Upsample1D(output_channel, use_conv_transpose=True)
306
+ if not is_last
307
+ else nn.Conv1d(output_channel, output_channel, 3, padding=1)
308
+ )
309
+
310
+ self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
311
+
312
+ self.final_block = Block1D(channels[-1], channels[-1])
313
+ self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
314
+
315
+ self.initialize_weights()
316
+ # nn.init.normal_(self.final_proj.weight)
317
+
318
+ @staticmethod
319
+ def get_block(block_type, dim, attention_head_dim, num_heads, dropout, act_fn):
320
+ if block_type == "conformer":
321
+ block = ConformerWrapper(
322
+ dim=dim,
323
+ dim_head=attention_head_dim,
324
+ heads=num_heads,
325
+ ff_mult=1,
326
+ conv_expansion_factor=2,
327
+ ff_dropout=dropout,
328
+ attn_dropout=dropout,
329
+ conv_dropout=dropout,
330
+ conv_kernel_size=31,
331
+ )
332
+ elif block_type == "transformer":
333
+ block = BasicTransformerBlock(
334
+ dim=dim,
335
+ num_attention_heads=num_heads,
336
+ attention_head_dim=attention_head_dim,
337
+ dropout=dropout,
338
+ activation_fn=act_fn,
339
+ )
340
+ else:
341
+ raise ValueError(f"Unknown block type {block_type}")
342
+
343
+ return block
344
+
345
+ def initialize_weights(self):
346
+ for m in self.modules():
347
+ if isinstance(m, nn.Conv1d):
348
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
349
+
350
+ if m.bias is not None:
351
+ nn.init.constant_(m.bias, 0)
352
+
353
+ elif isinstance(m, nn.GroupNorm):
354
+ nn.init.constant_(m.weight, 1)
355
+ nn.init.constant_(m.bias, 0)
356
+
357
+ elif isinstance(m, nn.Linear):
358
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
359
+
360
+ if m.bias is not None:
361
+ nn.init.constant_(m.bias, 0)
362
+
363
+ def forward(self, x, mask, mu, t, spks=None, cond=None):
364
+ """Forward pass of the UNet1DConditional model.
365
+
366
+ Args:
367
+ x (torch.Tensor): shape (batch_size, in_channels, time)
368
+ mask (_type_): shape (batch_size, 1, time)
369
+ t (_type_): shape (batch_size)
370
+ spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
371
+ cond (_type_, optional): placeholder for future use. Defaults to None.
372
+
373
+ Raises:
374
+ ValueError: _description_
375
+ ValueError: _description_
376
+
377
+ Returns:
378
+ _type_: _description_
379
+ """
380
+
381
+ t = self.time_embeddings(t)
382
+ t = self.time_mlp(t)
383
+
384
+ x = pack([x, mu], "b * t")[0]
385
+
386
+ if spks is not None:
387
+ spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
388
+ x = pack([x, spks], "b * t")[0]
389
+
390
+ hiddens = []
391
+ masks = [mask]
392
+ for resnet, transformer_blocks, downsample in self.down_blocks:
393
+ mask_down = masks[-1]
394
+ x = resnet(x, mask_down, t)
395
+ x = rearrange(x, "b c t -> b t c")
396
+ mask_down = rearrange(mask_down, "b 1 t -> b t")
397
+ for transformer_block in transformer_blocks:
398
+ x = transformer_block(
399
+ hidden_states=x,
400
+ attention_mask=mask_down,
401
+ timestep=t,
402
+ )
403
+ x = rearrange(x, "b t c -> b c t")
404
+ mask_down = rearrange(mask_down, "b t -> b 1 t")
405
+ hiddens.append(x) # Save hidden states for skip connections
406
+ x = downsample(x * mask_down)
407
+ masks.append(mask_down[:, :, ::2])
408
+
409
+ masks = masks[:-1]
410
+ mask_mid = masks[-1]
411
+
412
+ for resnet, transformer_blocks in self.mid_blocks:
413
+ x = resnet(x, mask_mid, t)
414
+ x = rearrange(x, "b c t -> b t c")
415
+ mask_mid = rearrange(mask_mid, "b 1 t -> b t")
416
+ for transformer_block in transformer_blocks:
417
+ x = transformer_block(
418
+ hidden_states=x,
419
+ attention_mask=mask_mid,
420
+ timestep=t,
421
+ )
422
+ x = rearrange(x, "b t c -> b c t")
423
+ mask_mid = rearrange(mask_mid, "b t -> b 1 t")
424
+
425
+ for resnet, transformer_blocks, upsample in self.up_blocks:
426
+ mask_up = masks.pop()
427
+ x = resnet(pack([x, hiddens.pop()], "b * t")[0], mask_up, t)
428
+ x = rearrange(x, "b c t -> b t c")
429
+ mask_up = rearrange(mask_up, "b 1 t -> b t")
430
+ for transformer_block in transformer_blocks:
431
+ x = transformer_block(
432
+ hidden_states=x,
433
+ attention_mask=mask_up,
434
+ timestep=t,
435
+ )
436
+ x = rearrange(x, "b t c -> b c t")
437
+ mask_up = rearrange(mask_up, "b t -> b 1 t")
438
+ x = upsample(x * mask_up)
439
+
440
+ x = self.final_block(x, mask_up)
441
+ output = self.final_proj(x * mask_up)
442
+
443
+ return output * mask
HF_Deploy/src/chatterbox/models/s3gen/matcha/flow_matching.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from .decoder import Decoder
7
+
8
+
9
+ class BASECFM(torch.nn.Module, ABC):
10
+ def __init__(
11
+ self,
12
+ n_feats,
13
+ cfm_params,
14
+ n_spks=1,
15
+ spk_emb_dim=128,
16
+ ):
17
+ super().__init__()
18
+ self.n_feats = n_feats
19
+ self.n_spks = n_spks
20
+ self.spk_emb_dim = spk_emb_dim
21
+ self.solver = cfm_params.solver
22
+ if hasattr(cfm_params, "sigma_min"):
23
+ self.sigma_min = cfm_params.sigma_min
24
+ else:
25
+ self.sigma_min = 1e-4
26
+
27
+ self.estimator = None
28
+
29
+ @torch.inference_mode()
30
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
31
+ """Forward diffusion
32
+
33
+ Args:
34
+ mu (torch.Tensor): output of encoder
35
+ shape: (batch_size, n_feats, mel_timesteps)
36
+ mask (torch.Tensor): output_mask
37
+ shape: (batch_size, 1, mel_timesteps)
38
+ n_timesteps (int): number of diffusion steps
39
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
40
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
41
+ shape: (batch_size, spk_emb_dim)
42
+ cond: Not used but kept for future purposes
43
+
44
+ Returns:
45
+ sample: generated mel-spectrogram
46
+ shape: (batch_size, n_feats, mel_timesteps)
47
+ """
48
+ z = torch.randn_like(mu) * temperature
49
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
50
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
51
+
52
+ def solve_euler(self, x, t_span, mu, mask, spks, cond):
53
+ """
54
+ Fixed euler solver for ODEs.
55
+ Args:
56
+ x (torch.Tensor): random noise
57
+ t_span (torch.Tensor): n_timesteps interpolated
58
+ shape: (n_timesteps + 1,)
59
+ mu (torch.Tensor): output of encoder
60
+ shape: (batch_size, n_feats, mel_timesteps)
61
+ mask (torch.Tensor): output_mask
62
+ shape: (batch_size, 1, mel_timesteps)
63
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
64
+ shape: (batch_size, spk_emb_dim)
65
+ cond: Not used but kept for future purposes
66
+ """
67
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
68
+
69
+ # I am storing this because I can later plot it by putting a debugger here and saving it to a file
70
+ # Or in future might add like a return_all_steps flag
71
+ sol = []
72
+
73
+ for step in range(1, len(t_span)):
74
+ dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
75
+
76
+ x = x + dt * dphi_dt
77
+ t = t + dt
78
+ sol.append(x)
79
+ if step < len(t_span) - 1:
80
+ dt = t_span[step + 1] - t
81
+
82
+ return sol[-1]
83
+
84
+ def compute_loss(self, x1, mask, mu, spks=None, cond=None):
85
+ """Computes diffusion loss
86
+
87
+ Args:
88
+ x1 (torch.Tensor): Target
89
+ shape: (batch_size, n_feats, mel_timesteps)
90
+ mask (torch.Tensor): target mask
91
+ shape: (batch_size, 1, mel_timesteps)
92
+ mu (torch.Tensor): output of encoder
93
+ shape: (batch_size, n_feats, mel_timesteps)
94
+ spks (torch.Tensor, optional): speaker embedding. Defaults to None.
95
+ shape: (batch_size, spk_emb_dim)
96
+
97
+ Returns:
98
+ loss: conditional flow matching loss
99
+ y: conditional flow
100
+ shape: (batch_size, n_feats, mel_timesteps)
101
+ """
102
+ b, _, t = mu.shape
103
+
104
+ # random timestep
105
+ t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
106
+ # sample noise p(x_0)
107
+ z = torch.randn_like(x1)
108
+
109
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
110
+ u = x1 - (1 - self.sigma_min) * z
111
+
112
+ loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / (
113
+ torch.sum(mask) * u.shape[1]
114
+ )
115
+ return loss, y
116
+
117
+
118
+ class CFM(BASECFM):
119
+ def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64):
120
+ super().__init__(
121
+ n_feats=in_channels,
122
+ cfm_params=cfm_params,
123
+ n_spks=n_spks,
124
+ spk_emb_dim=spk_emb_dim,
125
+ )
126
+
127
+ in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0)
128
+ # Just change the architecture of the estimator here
129
+ self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params)
HF_Deploy/src/chatterbox/models/s3gen/matcha/text_encoder.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ from https://github.com/jaywalnut310/glow-tts """
2
+
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from einops import rearrange
8
+
9
+
10
+ def sequence_mask(length, max_length=None):
11
+ if max_length is None:
12
+ max_length = length.max()
13
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
14
+ return x.unsqueeze(0) < length.unsqueeze(1)
15
+
16
+
17
+
18
+ class LayerNorm(nn.Module):
19
+ def __init__(self, channels, eps=1e-4):
20
+ super().__init__()
21
+ self.channels = channels
22
+ self.eps = eps
23
+
24
+ self.gamma = torch.nn.Parameter(torch.ones(channels))
25
+ self.beta = torch.nn.Parameter(torch.zeros(channels))
26
+
27
+ def forward(self, x):
28
+ n_dims = len(x.shape)
29
+ mean = torch.mean(x, 1, keepdim=True)
30
+ variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
31
+
32
+ x = (x - mean) * torch.rsqrt(variance + self.eps)
33
+
34
+ shape = [1, -1] + [1] * (n_dims - 2)
35
+ x = x * self.gamma.view(*shape) + self.beta.view(*shape)
36
+ return x
37
+
38
+
39
+ class ConvReluNorm(nn.Module):
40
+ def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
41
+ super().__init__()
42
+ self.in_channels = in_channels
43
+ self.hidden_channels = hidden_channels
44
+ self.out_channels = out_channels
45
+ self.kernel_size = kernel_size
46
+ self.n_layers = n_layers
47
+ self.p_dropout = p_dropout
48
+
49
+ self.conv_layers = torch.nn.ModuleList()
50
+ self.norm_layers = torch.nn.ModuleList()
51
+ self.conv_layers.append(torch.nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
52
+ self.norm_layers.append(LayerNorm(hidden_channels))
53
+ self.relu_drop = torch.nn.Sequential(torch.nn.ReLU(), torch.nn.Dropout(p_dropout))
54
+ for _ in range(n_layers - 1):
55
+ self.conv_layers.append(
56
+ torch.nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
57
+ )
58
+ self.norm_layers.append(LayerNorm(hidden_channels))
59
+ self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
60
+ self.proj.weight.data.zero_()
61
+ self.proj.bias.data.zero_()
62
+
63
+ def forward(self, x, x_mask):
64
+ x_org = x
65
+ for i in range(self.n_layers):
66
+ x = self.conv_layers[i](x * x_mask)
67
+ x = self.norm_layers[i](x)
68
+ x = self.relu_drop(x)
69
+ x = x_org + self.proj(x)
70
+ return x * x_mask
71
+
72
+
73
+ class DurationPredictor(nn.Module):
74
+ def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
75
+ super().__init__()
76
+ self.in_channels = in_channels
77
+ self.filter_channels = filter_channels
78
+ self.p_dropout = p_dropout
79
+
80
+ self.drop = torch.nn.Dropout(p_dropout)
81
+ self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
82
+ self.norm_1 = LayerNorm(filter_channels)
83
+ self.conv_2 = torch.nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
84
+ self.norm_2 = LayerNorm(filter_channels)
85
+ self.proj = torch.nn.Conv1d(filter_channels, 1, 1)
86
+
87
+ def forward(self, x, x_mask):
88
+ x = self.conv_1(x * x_mask)
89
+ x = torch.relu(x)
90
+ x = self.norm_1(x)
91
+ x = self.drop(x)
92
+ x = self.conv_2(x * x_mask)
93
+ x = torch.relu(x)
94
+ x = self.norm_2(x)
95
+ x = self.drop(x)
96
+ x = self.proj(x * x_mask)
97
+ return x * x_mask
98
+
99
+
100
+ class RotaryPositionalEmbeddings(nn.Module):
101
+ """
102
+ ## RoPE module
103
+
104
+ Rotary encoding transforms pairs of features by rotating in the 2D plane.
105
+ That is, it organizes the $d$ features as $\frac{d}{2}$ pairs.
106
+ Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it
107
+ by an angle depending on the position of the token.
108
+ """
109
+
110
+ def __init__(self, d: int, base: int = 10_000):
111
+ r"""
112
+ * `d` is the number of features $d$
113
+ * `base` is the constant used for calculating $\Theta$
114
+ """
115
+ super().__init__()
116
+
117
+ self.base = base
118
+ self.d = int(d)
119
+ self.cos_cached = None
120
+ self.sin_cached = None
121
+
122
+ def _build_cache(self, x: torch.Tensor):
123
+ r"""
124
+ Cache $\cos$ and $\sin$ values
125
+ """
126
+ # Return if cache is already built
127
+ if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
128
+ return
129
+
130
+ # Get sequence length
131
+ seq_len = x.shape[0]
132
+
133
+ # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
134
+ theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
135
+
136
+ # Create position indexes `[0, 1, ..., seq_len - 1]`
137
+ seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
138
+
139
+ # Calculate the product of position index and $\theta_i$
140
+ idx_theta = torch.einsum("n,d->nd", seq_idx, theta)
141
+
142
+ # Concatenate so that for row $m$ we have
143
+ # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
144
+ idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
145
+
146
+ # Cache them
147
+ self.cos_cached = idx_theta2.cos()[:, None, None, :]
148
+ self.sin_cached = idx_theta2.sin()[:, None, None, :]
149
+
150
+ def _neg_half(self, x: torch.Tensor):
151
+ # $\frac{d}{2}$
152
+ d_2 = self.d // 2
153
+
154
+ # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
155
+ return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
156
+
157
+ def forward(self, x: torch.Tensor):
158
+ """
159
+ * `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]`
160
+ """
161
+ # Cache $\cos$ and $\sin$ values
162
+ x = rearrange(x, "b h t d -> t b h d")
163
+
164
+ self._build_cache(x)
165
+
166
+ # Split the features, we can choose to apply rotary embeddings only to a partial set of features.
167
+ x_rope, x_pass = x[..., : self.d], x[..., self.d :]
168
+
169
+ # Calculate
170
+ # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
171
+ neg_half_x = self._neg_half(x_rope)
172
+
173
+ x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]])
174
+
175
+ return rearrange(torch.cat((x_rope, x_pass), dim=-1), "t b h d -> b h t d")
176
+
177
+
178
+ class MultiHeadAttention(nn.Module):
179
+ def __init__(
180
+ self,
181
+ channels,
182
+ out_channels,
183
+ n_heads,
184
+ heads_share=True,
185
+ p_dropout=0.0,
186
+ proximal_bias=False,
187
+ proximal_init=False,
188
+ ):
189
+ super().__init__()
190
+ assert channels % n_heads == 0
191
+
192
+ self.channels = channels
193
+ self.out_channels = out_channels
194
+ self.n_heads = n_heads
195
+ self.heads_share = heads_share
196
+ self.proximal_bias = proximal_bias
197
+ self.p_dropout = p_dropout
198
+ self.attn = None
199
+
200
+ self.k_channels = channels // n_heads
201
+ self.conv_q = torch.nn.Conv1d(channels, channels, 1)
202
+ self.conv_k = torch.nn.Conv1d(channels, channels, 1)
203
+ self.conv_v = torch.nn.Conv1d(channels, channels, 1)
204
+
205
+ # from https://nn.labml.ai/transformers/rope/index.html
206
+ self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
207
+ self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
208
+
209
+ self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
210
+ self.drop = torch.nn.Dropout(p_dropout)
211
+
212
+ torch.nn.init.xavier_uniform_(self.conv_q.weight)
213
+ torch.nn.init.xavier_uniform_(self.conv_k.weight)
214
+ if proximal_init:
215
+ self.conv_k.weight.data.copy_(self.conv_q.weight.data)
216
+ self.conv_k.bias.data.copy_(self.conv_q.bias.data)
217
+ torch.nn.init.xavier_uniform_(self.conv_v.weight)
218
+
219
+ def forward(self, x, c, attn_mask=None):
220
+ q = self.conv_q(x)
221
+ k = self.conv_k(c)
222
+ v = self.conv_v(c)
223
+
224
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
225
+
226
+ x = self.conv_o(x)
227
+ return x
228
+
229
+ def attention(self, query, key, value, mask=None):
230
+ b, d, t_s, t_t = (*key.size(), query.size(2))
231
+ query = rearrange(query, "b (h c) t-> b h t c", h=self.n_heads)
232
+ key = rearrange(key, "b (h c) t-> b h t c", h=self.n_heads)
233
+ value = rearrange(value, "b (h c) t-> b h t c", h=self.n_heads)
234
+
235
+ query = self.query_rotary_pe(query)
236
+ key = self.key_rotary_pe(key)
237
+
238
+ scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
239
+
240
+ if self.proximal_bias:
241
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
242
+ scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
243
+ if mask is not None:
244
+ scores = scores.masked_fill(mask == 0, -1e4)
245
+ p_attn = torch.nn.functional.softmax(scores, dim=-1)
246
+ p_attn = self.drop(p_attn)
247
+ output = torch.matmul(p_attn, value)
248
+ output = output.transpose(2, 3).contiguous().view(b, d, t_t)
249
+ return output, p_attn
250
+
251
+ @staticmethod
252
+ def _attention_bias_proximal(length):
253
+ r = torch.arange(length, dtype=torch.float32)
254
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
255
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
256
+
257
+
258
+ class FFN(nn.Module):
259
+ def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0):
260
+ super().__init__()
261
+ self.in_channels = in_channels
262
+ self.out_channels = out_channels
263
+ self.filter_channels = filter_channels
264
+ self.kernel_size = kernel_size
265
+ self.p_dropout = p_dropout
266
+
267
+ self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
268
+ self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
269
+ self.drop = torch.nn.Dropout(p_dropout)
270
+
271
+ def forward(self, x, x_mask):
272
+ x = self.conv_1(x * x_mask)
273
+ x = torch.relu(x)
274
+ x = self.drop(x)
275
+ x = self.conv_2(x * x_mask)
276
+ return x * x_mask
277
+
278
+
279
+ class Encoder(nn.Module):
280
+ def __init__(
281
+ self,
282
+ hidden_channels,
283
+ filter_channels,
284
+ n_heads,
285
+ n_layers,
286
+ kernel_size=1,
287
+ p_dropout=0.0,
288
+ **kwargs,
289
+ ):
290
+ super().__init__()
291
+ self.hidden_channels = hidden_channels
292
+ self.filter_channels = filter_channels
293
+ self.n_heads = n_heads
294
+ self.n_layers = n_layers
295
+ self.kernel_size = kernel_size
296
+ self.p_dropout = p_dropout
297
+
298
+ self.drop = torch.nn.Dropout(p_dropout)
299
+ self.attn_layers = torch.nn.ModuleList()
300
+ self.norm_layers_1 = torch.nn.ModuleList()
301
+ self.ffn_layers = torch.nn.ModuleList()
302
+ self.norm_layers_2 = torch.nn.ModuleList()
303
+ for _ in range(self.n_layers):
304
+ self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout))
305
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
306
+ self.ffn_layers.append(
307
+ FFN(
308
+ hidden_channels,
309
+ hidden_channels,
310
+ filter_channels,
311
+ kernel_size,
312
+ p_dropout=p_dropout,
313
+ )
314
+ )
315
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
316
+
317
+ def forward(self, x, x_mask):
318
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
319
+ for i in range(self.n_layers):
320
+ x = x * x_mask
321
+ y = self.attn_layers[i](x, x, attn_mask)
322
+ y = self.drop(y)
323
+ x = self.norm_layers_1[i](x + y)
324
+ y = self.ffn_layers[i](x, x_mask)
325
+ y = self.drop(y)
326
+ x = self.norm_layers_2[i](x + y)
327
+ x = x * x_mask
328
+ return x
329
+
330
+
331
+ class TextEncoder(nn.Module):
332
+ def __init__(
333
+ self,
334
+ encoder_type,
335
+ encoder_params,
336
+ duration_predictor_params,
337
+ n_vocab,
338
+ n_spks=1,
339
+ spk_emb_dim=128,
340
+ ):
341
+ super().__init__()
342
+ self.encoder_type = encoder_type
343
+ self.n_vocab = n_vocab
344
+ self.n_feats = encoder_params.n_feats
345
+ self.n_channels = encoder_params.n_channels
346
+ self.spk_emb_dim = spk_emb_dim
347
+ self.n_spks = n_spks
348
+
349
+ self.emb = torch.nn.Embedding(n_vocab, self.n_channels)
350
+ torch.nn.init.normal_(self.emb.weight, 0.0, self.n_channels**-0.5)
351
+
352
+ if encoder_params.prenet:
353
+ self.prenet = ConvReluNorm(
354
+ self.n_channels,
355
+ self.n_channels,
356
+ self.n_channels,
357
+ kernel_size=5,
358
+ n_layers=3,
359
+ p_dropout=0.5,
360
+ )
361
+ else:
362
+ self.prenet = lambda x, x_mask: x
363
+
364
+ self.encoder = Encoder(
365
+ encoder_params.n_channels + (spk_emb_dim if n_spks > 1 else 0),
366
+ encoder_params.filter_channels,
367
+ encoder_params.n_heads,
368
+ encoder_params.n_layers,
369
+ encoder_params.kernel_size,
370
+ encoder_params.p_dropout,
371
+ )
372
+
373
+ self.proj_m = torch.nn.Conv1d(self.n_channels + (spk_emb_dim if n_spks > 1 else 0), self.n_feats, 1)
374
+ self.proj_w = DurationPredictor(
375
+ self.n_channels + (spk_emb_dim if n_spks > 1 else 0),
376
+ duration_predictor_params.filter_channels_dp,
377
+ duration_predictor_params.kernel_size,
378
+ duration_predictor_params.p_dropout,
379
+ )
380
+
381
+ def forward(self, x, x_lengths, spks=None):
382
+ """Run forward pass to the transformer based encoder and duration predictor
383
+
384
+ Args:
385
+ x (torch.Tensor): text input
386
+ shape: (batch_size, max_text_length)
387
+ x_lengths (torch.Tensor): text input lengths
388
+ shape: (batch_size,)
389
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
390
+ shape: (batch_size,)
391
+
392
+ Returns:
393
+ mu (torch.Tensor): average output of the encoder
394
+ shape: (batch_size, n_feats, max_text_length)
395
+ logw (torch.Tensor): log duration predicted by the duration predictor
396
+ shape: (batch_size, 1, max_text_length)
397
+ x_mask (torch.Tensor): mask for the text input
398
+ shape: (batch_size, 1, max_text_length)
399
+ """
400
+ x = self.emb(x) * math.sqrt(self.n_channels)
401
+ x = torch.transpose(x, 1, -1)
402
+ x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
403
+
404
+ x = self.prenet(x, x_mask)
405
+ if self.n_spks > 1:
406
+ x = torch.cat([x, spks.unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
407
+ x = self.encoder(x, x_mask)
408
+ mu = self.proj_m(x) * x_mask
409
+
410
+ x_dp = torch.detach(x)
411
+ logw = self.proj_w(x_dp, x_mask)
412
+
413
+ return mu, logw, x_mask
HF_Deploy/src/chatterbox/models/s3gen/matcha/transformer.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from diffusers.models.attention import (
6
+ GEGLU,
7
+ GELU,
8
+ AdaLayerNorm,
9
+ AdaLayerNormZero,
10
+ ApproximateGELU,
11
+ )
12
+ from diffusers.models.attention_processor import Attention
13
+ from diffusers.models.lora import LoRACompatibleLinear
14
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
15
+
16
+
17
+ class SnakeBeta(nn.Module):
18
+ """
19
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
20
+ Shape:
21
+ - Input: (B, C, T)
22
+ - Output: (B, C, T), same shape as the input
23
+ Parameters:
24
+ - alpha - trainable parameter that controls frequency
25
+ - beta - trainable parameter that controls magnitude
26
+ References:
27
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
28
+ https://arxiv.org/abs/2006.08195
29
+ Examples:
30
+ >>> a1 = snakebeta(256)
31
+ >>> x = torch.randn(256)
32
+ >>> x = a1(x)
33
+ """
34
+
35
+ def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
36
+ """
37
+ Initialization.
38
+ INPUT:
39
+ - in_features: shape of the input
40
+ - alpha - trainable parameter that controls frequency
41
+ - beta - trainable parameter that controls magnitude
42
+ alpha is initialized to 1 by default, higher values = higher-frequency.
43
+ beta is initialized to 1 by default, higher values = higher-magnitude.
44
+ alpha will be trained along with the rest of your model.
45
+ """
46
+ super().__init__()
47
+ self.in_features = out_features if isinstance(out_features, list) else [out_features]
48
+ self.proj = LoRACompatibleLinear(in_features, out_features)
49
+
50
+ # initialize alpha
51
+ self.alpha_logscale = alpha_logscale
52
+ if self.alpha_logscale: # log scale alphas initialized to zeros
53
+ self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha)
54
+ self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha)
55
+ else: # linear scale alphas initialized to ones
56
+ self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha)
57
+ self.beta = nn.Parameter(torch.ones(self.in_features) * alpha)
58
+
59
+ self.alpha.requires_grad = alpha_trainable
60
+ self.beta.requires_grad = alpha_trainable
61
+
62
+ self.no_div_by_zero = 0.000000001
63
+
64
+ def forward(self, x):
65
+ """
66
+ Forward pass of the function.
67
+ Applies the function to the input elementwise.
68
+ SnakeBeta ∶= x + 1/b * sin^2 (xa)
69
+ """
70
+ x = self.proj(x)
71
+ if self.alpha_logscale:
72
+ alpha = torch.exp(self.alpha)
73
+ beta = torch.exp(self.beta)
74
+ else:
75
+ alpha = self.alpha
76
+ beta = self.beta
77
+
78
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)
79
+
80
+ return x
81
+
82
+
83
+ class FeedForward(nn.Module):
84
+ r"""
85
+ A feed-forward layer.
86
+
87
+ Parameters:
88
+ dim (`int`): The number of channels in the input.
89
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
90
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
91
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
92
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
93
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
94
+ """
95
+
96
+ def __init__(
97
+ self,
98
+ dim: int,
99
+ dim_out: Optional[int] = None,
100
+ mult: int = 4,
101
+ dropout: float = 0.0,
102
+ activation_fn: str = "geglu",
103
+ final_dropout: bool = False,
104
+ ):
105
+ super().__init__()
106
+ inner_dim = int(dim * mult)
107
+ dim_out = dim_out if dim_out is not None else dim
108
+
109
+ if activation_fn == "gelu":
110
+ act_fn = GELU(dim, inner_dim)
111
+ if activation_fn == "gelu-approximate":
112
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
113
+ elif activation_fn == "geglu":
114
+ act_fn = GEGLU(dim, inner_dim)
115
+ elif activation_fn == "geglu-approximate":
116
+ act_fn = ApproximateGELU(dim, inner_dim)
117
+ elif activation_fn == "snakebeta":
118
+ act_fn = SnakeBeta(dim, inner_dim)
119
+
120
+ self.net = nn.ModuleList([])
121
+ # project in
122
+ self.net.append(act_fn)
123
+ # project dropout
124
+ self.net.append(nn.Dropout(dropout))
125
+ # project out
126
+ self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
127
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
128
+ if final_dropout:
129
+ self.net.append(nn.Dropout(dropout))
130
+
131
+ def forward(self, hidden_states):
132
+ for module in self.net:
133
+ hidden_states = module(hidden_states)
134
+ return hidden_states
135
+
136
+
137
+ @maybe_allow_in_graph
138
+ class BasicTransformerBlock(nn.Module):
139
+ r"""
140
+ A basic Transformer block.
141
+
142
+ Parameters:
143
+ dim (`int`): The number of channels in the input and output.
144
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
145
+ attention_head_dim (`int`): The number of channels in each head.
146
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
147
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
148
+ only_cross_attention (`bool`, *optional*):
149
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
150
+ double_self_attention (`bool`, *optional*):
151
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
152
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
153
+ num_embeds_ada_norm (:
154
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
155
+ attention_bias (:
156
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
157
+ """
158
+
159
+ def __init__(
160
+ self,
161
+ dim: int,
162
+ num_attention_heads: int,
163
+ attention_head_dim: int,
164
+ dropout=0.0,
165
+ cross_attention_dim: Optional[int] = None,
166
+ activation_fn: str = "geglu",
167
+ num_embeds_ada_norm: Optional[int] = None,
168
+ attention_bias: bool = False,
169
+ only_cross_attention: bool = False,
170
+ double_self_attention: bool = False,
171
+ upcast_attention: bool = False,
172
+ norm_elementwise_affine: bool = True,
173
+ norm_type: str = "layer_norm",
174
+ final_dropout: bool = False,
175
+ ):
176
+ super().__init__()
177
+ self.only_cross_attention = only_cross_attention
178
+
179
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
180
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
181
+
182
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
183
+ raise ValueError(
184
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
185
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
186
+ )
187
+
188
+ # Define 3 blocks. Each block has its own normalization layer.
189
+ # 1. Self-Attn
190
+ if self.use_ada_layer_norm:
191
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
192
+ elif self.use_ada_layer_norm_zero:
193
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
194
+ else:
195
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
196
+ self.attn1 = Attention(
197
+ query_dim=dim,
198
+ heads=num_attention_heads,
199
+ dim_head=attention_head_dim,
200
+ dropout=dropout,
201
+ bias=attention_bias,
202
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
203
+ upcast_attention=upcast_attention,
204
+ )
205
+
206
+ # 2. Cross-Attn
207
+ if cross_attention_dim is not None or double_self_attention:
208
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
209
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
210
+ # the second cross attention block.
211
+ self.norm2 = (
212
+ AdaLayerNorm(dim, num_embeds_ada_norm)
213
+ if self.use_ada_layer_norm
214
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
215
+ )
216
+ self.attn2 = Attention(
217
+ query_dim=dim,
218
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
219
+ heads=num_attention_heads,
220
+ dim_head=attention_head_dim,
221
+ dropout=dropout,
222
+ bias=attention_bias,
223
+ upcast_attention=upcast_attention,
224
+ # scale_qk=False, # uncomment this to not to use flash attention
225
+ ) # is self-attn if encoder_hidden_states is none
226
+ else:
227
+ self.norm2 = None
228
+ self.attn2 = None
229
+
230
+ # 3. Feed-forward
231
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
232
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
233
+
234
+ # let chunk size default to None
235
+ self._chunk_size = None
236
+ self._chunk_dim = 0
237
+
238
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
239
+ # Sets chunk feed-forward
240
+ self._chunk_size = chunk_size
241
+ self._chunk_dim = dim
242
+
243
+ def forward(
244
+ self,
245
+ hidden_states: torch.FloatTensor,
246
+ attention_mask: Optional[torch.FloatTensor] = None,
247
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
248
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
249
+ timestep: Optional[torch.LongTensor] = None,
250
+ cross_attention_kwargs: Dict[str, Any] = None,
251
+ class_labels: Optional[torch.LongTensor] = None,
252
+ ):
253
+ # Notice that normalization is always applied before the real computation in the following blocks.
254
+ # 1. Self-Attention
255
+ if self.use_ada_layer_norm:
256
+ norm_hidden_states = self.norm1(hidden_states, timestep)
257
+ elif self.use_ada_layer_norm_zero:
258
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
259
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
260
+ )
261
+ else:
262
+ norm_hidden_states = self.norm1(hidden_states)
263
+
264
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
265
+
266
+ attn_output = self.attn1(
267
+ norm_hidden_states,
268
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
269
+ attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
270
+ **cross_attention_kwargs,
271
+ )
272
+ if self.use_ada_layer_norm_zero:
273
+ attn_output = gate_msa.unsqueeze(1) * attn_output
274
+ hidden_states = attn_output + hidden_states
275
+
276
+ # 2. Cross-Attention
277
+ if self.attn2 is not None:
278
+ norm_hidden_states = (
279
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
280
+ )
281
+
282
+ attn_output = self.attn2(
283
+ norm_hidden_states,
284
+ encoder_hidden_states=encoder_hidden_states,
285
+ attention_mask=encoder_attention_mask,
286
+ **cross_attention_kwargs,
287
+ )
288
+ hidden_states = attn_output + hidden_states
289
+
290
+ # 3. Feed-forward
291
+ norm_hidden_states = self.norm3(hidden_states)
292
+
293
+ if self.use_ada_layer_norm_zero:
294
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
295
+
296
+ if self._chunk_size is not None:
297
+ # "feed_forward_chunk_size" can be used to save memory
298
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
299
+ raise ValueError(
300
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
301
+ )
302
+
303
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
304
+ ff_output = torch.cat(
305
+ [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
306
+ dim=self._chunk_dim,
307
+ )
308
+ else:
309
+ ff_output = self.ff(norm_hidden_states)
310
+
311
+ if self.use_ada_layer_norm_zero:
312
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
313
+
314
+ hidden_states = ff_output + hidden_states
315
+
316
+ return hidden_states
HF_Deploy/src/chatterbox/models/s3gen/s3gen.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from CosyVoice https://github.com/FunAudioLLM/CosyVoice
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import logging
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torchaudio as ta
20
+ from functools import lru_cache
21
+ from typing import Optional
22
+ from omegaconf import DictConfig
23
+
24
+ from ..s3tokenizer import S3_SR, SPEECH_VOCAB_SIZE, S3Tokenizer
25
+ from .const import S3GEN_SR
26
+ from .flow import CausalMaskedDiffWithXvec
27
+ from .xvector import CAMPPlus
28
+ from .utils.mel import mel_spectrogram
29
+ from .f0_predictor import ConvRNNF0Predictor
30
+ from .hifigan import HiFTGenerator
31
+ from .transformer.upsample_encoder import UpsampleConformerEncoder
32
+ from .flow_matching import CausalConditionalCFM
33
+ from .decoder import ConditionalDecoder
34
+
35
+
36
+ def drop_invalid_tokens(x):
37
+ assert len(x.shape) <= 2 and x.shape[0] == 1, "only batch size of one allowed for now"
38
+ return x[x < SPEECH_VOCAB_SIZE]
39
+
40
+
41
+ # TODO: global resampler cache
42
+ @lru_cache(100)
43
+ def get_resampler(src_sr, dst_sr, device):
44
+ return ta.transforms.Resample(src_sr, dst_sr).to(device)
45
+
46
+
47
+ class S3Token2Mel(torch.nn.Module):
48
+ """
49
+ CosyVoice2's CFM decoder maps S3 speech tokens to mel-spectrograms.
50
+
51
+ TODO: make these modules configurable?
52
+ """
53
+ def __init__(self):
54
+ super().__init__()
55
+ self.tokenizer = S3Tokenizer("speech_tokenizer_v2_25hz")
56
+ self.mel_extractor = mel_spectrogram # TODO: make it a torch module?
57
+ self.speaker_encoder = CAMPPlus() # use default args
58
+
59
+ encoder = UpsampleConformerEncoder(
60
+ output_size=512,
61
+ attention_heads=8,
62
+ linear_units=2048,
63
+ num_blocks=6,
64
+ dropout_rate=0.1,
65
+ positional_dropout_rate=0.1,
66
+ attention_dropout_rate=0.1,
67
+ normalize_before=True,
68
+ input_layer='linear',
69
+ pos_enc_layer_type='rel_pos_espnet',
70
+ selfattention_layer_type='rel_selfattn',
71
+ input_size=512,
72
+ use_cnn_module=False,
73
+ macaron_style=False,
74
+ )
75
+
76
+ estimator = ConditionalDecoder(
77
+ in_channels=320,
78
+ out_channels=80,
79
+ causal=True,
80
+ channels=[256],
81
+ dropout=0.0,
82
+ attention_head_dim=64,
83
+ n_blocks=4,
84
+ num_mid_blocks=12,
85
+ num_heads=8,
86
+ act_fn='gelu',
87
+ )
88
+ cfm_params = DictConfig({
89
+ "sigma_min": 1e-06,
90
+ "solver": 'euler',
91
+ "t_scheduler": 'cosine',
92
+ "training_cfg_rate": 0.2,
93
+ "inference_cfg_rate": 0.7,
94
+ "reg_loss_type": 'l1',
95
+ })
96
+ decoder = CausalConditionalCFM(
97
+ spk_emb_dim=80,
98
+ cfm_params=cfm_params,
99
+ estimator=estimator,
100
+ )
101
+
102
+ self.flow = CausalMaskedDiffWithXvec(
103
+ encoder=encoder,
104
+ decoder=decoder
105
+ )
106
+
107
+ self.resamplers = {}
108
+
109
+ @property
110
+ def device(self):
111
+ params = self.tokenizer.parameters()
112
+ return next(params).device
113
+
114
+ def embed_ref(
115
+ self,
116
+ ref_wav: torch.Tensor,
117
+ ref_sr: int,
118
+ device="auto",
119
+ ref_fade_out=True,
120
+ ):
121
+ device = self.device if device == "auto" else device
122
+ if isinstance(ref_wav, np.ndarray):
123
+ ref_wav = torch.from_numpy(ref_wav).float()
124
+
125
+ if ref_wav.device != device:
126
+ ref_wav = ref_wav.to(device)
127
+
128
+ if len(ref_wav.shape) == 1:
129
+ ref_wav = ref_wav.unsqueeze(0) # (B, L)
130
+
131
+ if ref_wav.size(1) > 10 * ref_sr:
132
+ print("WARNING: cosydec received ref longer than 10s")
133
+
134
+ ref_wav_24 = ref_wav
135
+ if ref_sr != S3GEN_SR:
136
+ ref_wav_24 = get_resampler(ref_sr, S3GEN_SR, device)(ref_wav)
137
+
138
+ ref_mels_24 = self.mel_extractor(ref_wav_24).transpose(1, 2).to(device)
139
+ ref_mels_24_len = None
140
+
141
+ # Resample to 16kHz
142
+ ref_wav_16 = get_resampler(ref_sr, S3_SR, device)(ref_wav).to(device)
143
+
144
+ # Speaker embedding
145
+ ref_x_vector = self.speaker_encoder.inference(ref_wav_16)
146
+
147
+ # Tokenize 16khz reference
148
+ ref_speech_tokens, ref_speech_token_lens = self.tokenizer(ref_wav_16)
149
+
150
+ # Make sure mel_len = 2 * stoken_len (happens when the input is not padded to multiple of 40ms)
151
+ if ref_mels_24.shape[1] != 2 * ref_speech_tokens.shape[1]:
152
+ logging.warning(
153
+ "Reference mel length is not equal to 2 * reference token length.\n"
154
+ )
155
+ ref_speech_tokens = ref_speech_tokens[:, :ref_mels_24.shape[1] // 2]
156
+ ref_speech_token_lens[0] = ref_speech_tokens.shape[1]
157
+
158
+ return dict(
159
+ prompt_token=ref_speech_tokens.to(device),
160
+ prompt_token_len=ref_speech_token_lens,
161
+ prompt_feat=ref_mels_24,
162
+ prompt_feat_len=ref_mels_24_len,
163
+ embedding=ref_x_vector,
164
+ )
165
+
166
+ def forward(
167
+ self,
168
+ speech_tokens: torch.LongTensor,
169
+ # locally-computed ref embedding (mutex with ref_dict)
170
+ ref_wav: Optional[torch.Tensor],
171
+ ref_sr: Optional[int],
172
+ # pre-computed ref embedding (prod API)
173
+ ref_dict: Optional[dict] = None,
174
+ finalize: bool = False,
175
+ ):
176
+ """
177
+ Generate waveforms from S3 speech tokens and a reference waveform, which the speaker timbre is inferred from.
178
+
179
+ NOTE:
180
+ - The speaker encoder accepts 16 kHz waveform.
181
+ - S3TokenizerV2 accepts 16 kHz waveform.
182
+ - The mel-spectrogram for the reference assumes 24 kHz input signal.
183
+ - This function is designed for batch_size=1 only.
184
+
185
+ Args
186
+ ----
187
+ - `speech_tokens`: S3 speech tokens [B=1, T]
188
+ - `ref_wav`: reference waveform (`torch.Tensor` with shape=[B=1, T])
189
+ - `ref_sr`: reference sample rate
190
+ - `finalize`: whether streaming is finished or not. Note that if False, the last 3 tokens will be ignored.
191
+ """
192
+ assert (ref_wav is None) ^ (ref_dict is None), f"Must provide exactly one of ref_wav or ref_dict (got {ref_wav} and {ref_dict})"
193
+
194
+ if ref_dict is None:
195
+ ref_dict = self.embed_ref(ref_wav, ref_sr)
196
+ else:
197
+ # type/device casting (all values will be numpy if it's from a prod API call)
198
+ for rk in list(ref_dict):
199
+ if isinstance(ref_dict[rk], np.ndarray):
200
+ ref_dict[rk] = torch.from_numpy(ref_dict[rk])
201
+ if torch.is_tensor(ref_dict[rk]):
202
+ ref_dict[rk] = ref_dict[rk].to(self.device)
203
+
204
+ if len(speech_tokens.shape) == 1:
205
+ speech_tokens = speech_tokens.unsqueeze(0)
206
+
207
+ # assert speech_tokens.shape[0] == 1, "only batch size of one allowed for now"
208
+ speech_token_lens = torch.LongTensor([speech_tokens.size(1)]).to(self.device)
209
+
210
+ output_mels, _ = self.flow.inference(
211
+ token=speech_tokens,
212
+ token_len=speech_token_lens,
213
+ finalize=finalize,
214
+ **ref_dict,
215
+ )
216
+ return output_mels
217
+
218
+
219
+ class S3Token2Wav(S3Token2Mel):
220
+ """
221
+ The decoder of CosyVoice2 is a concat of token-to-mel (CFM) and a mel-to-waveform (HiFiGAN) modules.
222
+
223
+ TODO: make these modules configurable?
224
+ """
225
+
226
+ def __init__(self):
227
+ super().__init__()
228
+
229
+ f0_predictor = ConvRNNF0Predictor()
230
+ self.mel2wav = HiFTGenerator(
231
+ sampling_rate=S3GEN_SR,
232
+ upsample_rates=[8, 5, 3],
233
+ upsample_kernel_sizes=[16, 11, 7],
234
+ source_resblock_kernel_sizes=[7, 7, 11],
235
+ source_resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
236
+ f0_predictor=f0_predictor,
237
+ )
238
+
239
+ # silence out a few ms and fade audio in to reduce artifacts
240
+ n_trim = S3GEN_SR // 50 # 20ms = half of a frame
241
+ trim_fade = torch.zeros(2 * n_trim)
242
+ trim_fade[n_trim:] = (torch.cos(torch.linspace(torch.pi, 0, n_trim)) + 1) / 2
243
+ self.register_buffer("trim_fade", trim_fade, persistent=False) # (buffers get automatic device casting)
244
+
245
+ def forward(
246
+ self,
247
+ speech_tokens,
248
+ # locally-computed ref embedding (mutex with ref_dict)
249
+ ref_wav: Optional[torch.Tensor],
250
+ ref_sr: Optional[int],
251
+ # pre-computed ref embedding (prod API)
252
+ ref_dict: Optional[dict] = None,
253
+ finalize: bool = False
254
+ ):
255
+ output_mels = super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
256
+
257
+ # TODO jrm: ignoring the speed control (mel interpolation) and the HiFTGAN caching mechanisms for now.
258
+ hift_cache_source = torch.zeros(1, 1, 0).to(self.device)
259
+
260
+ output_wavs, *_ = self.mel2wav.inference(speech_feat=output_mels, cache_source=hift_cache_source)
261
+
262
+ if not self.training:
263
+ # NOTE: ad-hoc method to reduce "spillover" from the reference clip.
264
+ output_wavs[:, :len(self.trim_fade)] *= self.trim_fade
265
+
266
+ return output_wavs
267
+
268
+ @torch.inference_mode()
269
+ def flow_inference(
270
+ self,
271
+ speech_tokens,
272
+ # locally-computed ref embedding (mutex with ref_dict)
273
+ ref_wav: Optional[torch.Tensor] = None,
274
+ ref_sr: Optional[int] = None,
275
+ # pre-computed ref embedding (prod API)
276
+ ref_dict: Optional[dict] = None,
277
+ finalize: bool = False,
278
+ ):
279
+ return super().forward(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
280
+
281
+ @torch.inference_mode()
282
+ def hift_inference(self, speech_feat, cache_source: torch.Tensor = None):
283
+ if cache_source is None:
284
+ cache_source = torch.zeros(1, 1, 0).to(self.device)
285
+ return self.mel2wav.inference(speech_feat=speech_feat, cache_source=cache_source)
286
+
287
+ @torch.inference_mode()
288
+ def inference(
289
+ self,
290
+ speech_tokens,
291
+ # locally-computed ref embedding (mutex with ref_dict)
292
+ ref_wav: Optional[torch.Tensor] = None,
293
+ ref_sr: Optional[int] = None,
294
+ # pre-computed ref embedding (prod API)
295
+ ref_dict: Optional[dict] = None,
296
+ cache_source: torch.Tensor = None, # NOTE: this arg is for streaming, it can probably be removed here
297
+ finalize: bool = True,
298
+ ):
299
+ output_mels = self.flow_inference(speech_tokens, ref_wav=ref_wav, ref_sr=ref_sr, ref_dict=ref_dict, finalize=finalize)
300
+ output_wavs, output_sources = self.hift_inference(output_mels, cache_source)
301
+
302
+ # NOTE: ad-hoc method to reduce "spillover" from the reference clip.
303
+ output_wavs[:, :len(self.trim_fade)] *= self.trim_fade
304
+
305
+ return output_wavs, output_sources
HF_Deploy/src/chatterbox/models/s3gen/transformer/__init__.py ADDED
File without changes
HF_Deploy/src/chatterbox/models/s3gen/transformer/activation.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
2
+ # 2020 Northwestern Polytechnical University (Pengcheng Guo)
3
+ # 2020 Mobvoi Inc (Binbin Zhang)
4
+ # 2024 Alibaba Inc (Xiang Lyu)
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Swish() activation function for Conformer."""
18
+
19
+ import torch
20
+ from torch import nn, sin, pow
21
+ from torch.nn import Parameter
22
+
23
+
24
+ class Swish(torch.nn.Module):
25
+ """Construct an Swish object."""
26
+
27
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
28
+ """Return Swish activation function."""
29
+ return x * torch.sigmoid(x)
30
+
31
+
32
+ # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
33
+ # LICENSE is in incl_licenses directory.
34
+ class Snake(nn.Module):
35
+ '''
36
+ Implementation of a sine-based periodic activation function
37
+ Shape:
38
+ - Input: (B, C, T)
39
+ - Output: (B, C, T), same shape as the input
40
+ Parameters:
41
+ - alpha - trainable parameter
42
+ References:
43
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
44
+ https://arxiv.org/abs/2006.08195
45
+ Examples:
46
+ >>> a1 = snake(256)
47
+ >>> x = torch.randn(256)
48
+ >>> x = a1(x)
49
+ '''
50
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
51
+ '''
52
+ Initialization.
53
+ INPUT:
54
+ - in_features: shape of the input
55
+ - alpha: trainable parameter
56
+ alpha is initialized to 1 by default, higher values = higher-frequency.
57
+ alpha will be trained along with the rest of your model.
58
+ '''
59
+ super(Snake, self).__init__()
60
+ self.in_features = in_features
61
+
62
+ # initialize alpha
63
+ self.alpha_logscale = alpha_logscale
64
+ if self.alpha_logscale: # log scale alphas initialized to zeros
65
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
66
+ else: # linear scale alphas initialized to ones
67
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
68
+
69
+ self.alpha.requires_grad = alpha_trainable
70
+
71
+ self.no_div_by_zero = 0.000000001
72
+
73
+ def forward(self, x):
74
+ '''
75
+ Forward pass of the function.
76
+ Applies the function to the input elementwise.
77
+ Snake ∶= x + 1/a * sin^2 (xa)
78
+ '''
79
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
80
+ if self.alpha_logscale:
81
+ alpha = torch.exp(alpha)
82
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
83
+
84
+ return x
HF_Deploy/src/chatterbox/models/s3gen/transformer/attention.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ # 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
4
+ # 2024 Alibaba Inc (Xiang Lyu)
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Multi-Head Attention layer definition."""
18
+
19
+ import math
20
+ from typing import Tuple
21
+
22
+ import torch
23
+ from torch import nn
24
+
25
+
26
+ class MultiHeadedAttention(nn.Module):
27
+ """Multi-Head Attention layer.
28
+
29
+ Args:
30
+ n_head (int): The number of heads.
31
+ n_feat (int): The number of features.
32
+ dropout_rate (float): Dropout rate.
33
+
34
+ """
35
+
36
+ def __init__(self,
37
+ n_head: int,
38
+ n_feat: int,
39
+ dropout_rate: float,
40
+ key_bias: bool = True):
41
+ """Construct an MultiHeadedAttention object."""
42
+ super().__init__()
43
+ assert n_feat % n_head == 0
44
+ # We assume d_v always equals d_k
45
+ self.d_k = n_feat // n_head
46
+ self.h = n_head
47
+ self.linear_q = nn.Linear(n_feat, n_feat)
48
+ self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
49
+ self.linear_v = nn.Linear(n_feat, n_feat)
50
+ self.linear_out = nn.Linear(n_feat, n_feat)
51
+ self.dropout = nn.Dropout(p=dropout_rate)
52
+
53
+ def forward_qkv(
54
+ self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
55
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
56
+ """Transform query, key and value.
57
+
58
+ Args:
59
+ query (torch.Tensor): Query tensor (#batch, time1, size).
60
+ key (torch.Tensor): Key tensor (#batch, time2, size).
61
+ value (torch.Tensor): Value tensor (#batch, time2, size).
62
+
63
+ Returns:
64
+ torch.Tensor: Transformed query tensor, size
65
+ (#batch, n_head, time1, d_k).
66
+ torch.Tensor: Transformed key tensor, size
67
+ (#batch, n_head, time2, d_k).
68
+ torch.Tensor: Transformed value tensor, size
69
+ (#batch, n_head, time2, d_k).
70
+
71
+ """
72
+ n_batch = query.size(0)
73
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
74
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
75
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
76
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
77
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
78
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
79
+
80
+ return q, k, v
81
+
82
+ def forward_attention(
83
+ self,
84
+ value: torch.Tensor,
85
+ scores: torch.Tensor,
86
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
87
+ ) -> torch.Tensor:
88
+ """Compute attention context vector.
89
+
90
+ Args:
91
+ value (torch.Tensor): Transformed value, size
92
+ (#batch, n_head, time2, d_k).
93
+ scores (torch.Tensor): Attention score, size
94
+ (#batch, n_head, time1, time2).
95
+ mask (torch.Tensor): Mask, size (#batch, 1, time2) or
96
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
97
+
98
+ Returns:
99
+ torch.Tensor: Transformed value (#batch, time1, d_model)
100
+ weighted by the attention score (#batch, time1, time2).
101
+
102
+ """
103
+ n_batch = value.size(0)
104
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
105
+ # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
106
+ # 1st chunk to ease the onnx export.]
107
+ # 2. pytorch training
108
+ if mask.size(2) > 0: # time2 > 0
109
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
110
+ # For last chunk, time2 might be larger than scores.size(-1)
111
+ mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
112
+ scores = scores.masked_fill(mask, -float('inf'))
113
+ attn = torch.softmax(scores, dim=-1).masked_fill(
114
+ mask, 0.0) # (batch, head, time1, time2)
115
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
116
+ # 1. onnx(16/-1, -1/-1, 16/0)
117
+ # 2. jit (16/-1, -1/-1, 16/0, 16/4)
118
+ else:
119
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
120
+
121
+ p_attn = self.dropout(attn)
122
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
123
+ x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
124
+ self.h * self.d_k)
125
+ ) # (batch, time1, d_model)
126
+
127
+ return self.linear_out(x) # (batch, time1, d_model)
128
+
129
+ def forward(
130
+ self,
131
+ query: torch.Tensor,
132
+ key: torch.Tensor,
133
+ value: torch.Tensor,
134
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
135
+ pos_emb: torch.Tensor = torch.empty(0),
136
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
137
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
138
+ """Compute scaled dot product attention.
139
+
140
+ Args:
141
+ query (torch.Tensor): Query tensor (#batch, time1, size).
142
+ key (torch.Tensor): Key tensor (#batch, time2, size).
143
+ value (torch.Tensor): Value tensor (#batch, time2, size).
144
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
145
+ (#batch, time1, time2).
146
+ 1.When applying cross attention between decoder and encoder,
147
+ the batch padding mask for input is in (#batch, 1, T) shape.
148
+ 2.When applying self attention of encoder,
149
+ the mask is in (#batch, T, T) shape.
150
+ 3.When applying self attention of decoder,
151
+ the mask is in (#batch, L, L) shape.
152
+ 4.If the different position in decoder see different block
153
+ of the encoder, such as Mocha, the passed in mask could be
154
+ in (#batch, L, T) shape. But there is no such case in current
155
+ CosyVoice.
156
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
157
+ where `cache_t == chunk_size * num_decoding_left_chunks`
158
+ and `head * d_k == size`
159
+
160
+
161
+ Returns:
162
+ torch.Tensor: Output tensor (#batch, time1, d_model).
163
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
164
+ where `cache_t == chunk_size * num_decoding_left_chunks`
165
+ and `head * d_k == size`
166
+
167
+ """
168
+ q, k, v = self.forward_qkv(query, key, value)
169
+
170
+ # NOTE(xcsong):
171
+ # when export onnx model, for 1st chunk, we feed
172
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
173
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
174
+ # In all modes, `if cache.size(0) > 0` will alwayse be `True`
175
+ # and we will always do splitting and
176
+ # concatnation(this will simplify onnx export). Note that
177
+ # it's OK to concat & split zero-shaped tensors(see code below).
178
+ # when export jit model, for 1st chunk, we always feed
179
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
180
+ # >>> a = torch.ones((1, 2, 0, 4))
181
+ # >>> b = torch.ones((1, 2, 3, 4))
182
+ # >>> c = torch.cat((a, b), dim=2)
183
+ # >>> torch.equal(b, c) # True
184
+ # >>> d = torch.split(a, 2, dim=-1)
185
+ # >>> torch.equal(d[0], d[1]) # True
186
+ if cache.size(0) > 0:
187
+ key_cache, value_cache = torch.split(cache,
188
+ cache.size(-1) // 2,
189
+ dim=-1)
190
+ k = torch.cat([key_cache, k], dim=2)
191
+ v = torch.cat([value_cache, v], dim=2)
192
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
193
+ # non-trivial to calculate `next_cache_start` here.
194
+ new_cache = torch.cat((k, v), dim=-1)
195
+
196
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
197
+ return self.forward_attention(v, scores, mask), new_cache
198
+
199
+
200
+ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
201
+ """Multi-Head Attention layer with relative position encoding.
202
+ Paper: https://arxiv.org/abs/1901.02860
203
+ Args:
204
+ n_head (int): The number of heads.
205
+ n_feat (int): The number of features.
206
+ dropout_rate (float): Dropout rate.
207
+ """
208
+
209
+ def __init__(self,
210
+ n_head: int,
211
+ n_feat: int,
212
+ dropout_rate: float,
213
+ key_bias: bool = True):
214
+ """Construct an RelPositionMultiHeadedAttention object."""
215
+ super().__init__(n_head, n_feat, dropout_rate, key_bias)
216
+ # linear transformation for positional encoding
217
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
218
+ # these two learnable bias are used in matrix c and matrix d
219
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
220
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
221
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
222
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
223
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
224
+
225
+ def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
226
+ """Compute relative positional encoding.
227
+
228
+ Args:
229
+ x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
230
+ time1 means the length of query vector.
231
+
232
+ Returns:
233
+ torch.Tensor: Output tensor.
234
+
235
+ """
236
+ zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
237
+ device=x.device,
238
+ dtype=x.dtype)
239
+ x_padded = torch.cat([zero_pad, x], dim=-1)
240
+
241
+ x_padded = x_padded.view(x.size()[0],
242
+ x.size()[1],
243
+ x.size(3) + 1, x.size(2))
244
+ x = x_padded[:, :, 1:].view_as(x)[
245
+ :, :, :, : x.size(-1) // 2 + 1
246
+ ] # only keep the positions from 0 to time2
247
+ return x
248
+
249
+ def forward(
250
+ self,
251
+ query: torch.Tensor,
252
+ key: torch.Tensor,
253
+ value: torch.Tensor,
254
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
255
+ pos_emb: torch.Tensor = torch.empty(0),
256
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
257
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
258
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
259
+ Args:
260
+ query (torch.Tensor): Query tensor (#batch, time1, size).
261
+ key (torch.Tensor): Key tensor (#batch, time2, size).
262
+ value (torch.Tensor): Value tensor (#batch, time2, size).
263
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
264
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
265
+ pos_emb (torch.Tensor): Positional embedding tensor
266
+ (#batch, time2, size).
267
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
268
+ where `cache_t == chunk_size * num_decoding_left_chunks`
269
+ and `head * d_k == size`
270
+ Returns:
271
+ torch.Tensor: Output tensor (#batch, time1, d_model).
272
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
273
+ where `cache_t == chunk_size * num_decoding_left_chunks`
274
+ and `head * d_k == size`
275
+ """
276
+ q, k, v = self.forward_qkv(query, key, value)
277
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
278
+
279
+ # NOTE(xcsong):
280
+ # when export onnx model, for 1st chunk, we feed
281
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
282
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
283
+ # In all modes, `if cache.size(0) > 0` will alwayse be `True`
284
+ # and we will always do splitting and
285
+ # concatnation(this will simplify onnx export). Note that
286
+ # it's OK to concat & split zero-shaped tensors(see code below).
287
+ # when export jit model, for 1st chunk, we always feed
288
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
289
+ # >>> a = torch.ones((1, 2, 0, 4))
290
+ # >>> b = torch.ones((1, 2, 3, 4))
291
+ # >>> c = torch.cat((a, b), dim=2)
292
+ # >>> torch.equal(b, c) # True
293
+ # >>> d = torch.split(a, 2, dim=-1)
294
+ # >>> torch.equal(d[0], d[1]) # True
295
+ if cache.size(0) > 0:
296
+ key_cache, value_cache = torch.split(cache,
297
+ cache.size(-1) // 2,
298
+ dim=-1)
299
+ k = torch.cat([key_cache, k], dim=2)
300
+ v = torch.cat([value_cache, v], dim=2)
301
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
302
+ # non-trivial to calculate `next_cache_start` here.
303
+ new_cache = torch.cat((k, v), dim=-1)
304
+
305
+ n_batch_pos = pos_emb.size(0)
306
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
307
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
308
+
309
+ # (batch, head, time1, d_k)
310
+ q_with_bias_u = (q + self.pos_bias_u.to(q.device)).transpose(1, 2)
311
+ # (batch, head, time1, d_k)
312
+ q_with_bias_v = (q + self.pos_bias_v.to(q.device)).transpose(1, 2)
313
+
314
+ # compute attention score
315
+ # first compute matrix a and matrix c
316
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
317
+ # (batch, head, time1, time2)
318
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
319
+
320
+ # compute matrix b and matrix d
321
+ # (batch, head, time1, time2)
322
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
323
+ # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
324
+ if matrix_ac.shape != matrix_bd.shape:
325
+ matrix_bd = self.rel_shift(matrix_bd)
326
+
327
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
328
+ self.d_k) # (batch, head, time1, time2)
329
+
330
+ return self.forward_attention(v, scores, mask), new_cache
HF_Deploy/src/chatterbox/models/s3gen/transformer/convolution.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
+ # 2024 Alibaba Inc (Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """ConvolutionModule definition."""
17
+
18
+ from typing import Tuple
19
+
20
+ import torch
21
+ from torch import nn
22
+
23
+
24
+ class ConvolutionModule(nn.Module):
25
+ """ConvolutionModule in Conformer model."""
26
+
27
+ def __init__(self,
28
+ channels: int,
29
+ kernel_size: int = 15,
30
+ activation: nn.Module = nn.ReLU(),
31
+ norm: str = "batch_norm",
32
+ causal: bool = False,
33
+ bias: bool = True):
34
+ """Construct an ConvolutionModule object.
35
+ Args:
36
+ channels (int): The number of channels of conv layers.
37
+ kernel_size (int): Kernel size of conv layers.
38
+ causal (int): Whether use causal convolution or not
39
+ """
40
+ super().__init__()
41
+
42
+ self.pointwise_conv1 = nn.Conv1d(
43
+ channels,
44
+ 2 * channels,
45
+ kernel_size=1,
46
+ stride=1,
47
+ padding=0,
48
+ bias=bias,
49
+ )
50
+ # self.lorder is used to distinguish if it's a causal convolution,
51
+ # if self.lorder > 0: it's a causal convolution, the input will be
52
+ # padded with self.lorder frames on the left in forward.
53
+ # else: it's a symmetrical convolution
54
+ if causal:
55
+ padding = 0
56
+ self.lorder = kernel_size - 1
57
+ else:
58
+ # kernel_size should be an odd number for none causal convolution
59
+ assert (kernel_size - 1) % 2 == 0
60
+ padding = (kernel_size - 1) // 2
61
+ self.lorder = 0
62
+ self.depthwise_conv = nn.Conv1d(
63
+ channels,
64
+ channels,
65
+ kernel_size,
66
+ stride=1,
67
+ padding=padding,
68
+ groups=channels,
69
+ bias=bias,
70
+ )
71
+
72
+ assert norm in ['batch_norm', 'layer_norm']
73
+ if norm == "batch_norm":
74
+ self.use_layer_norm = False
75
+ self.norm = nn.BatchNorm1d(channels)
76
+ else:
77
+ self.use_layer_norm = True
78
+ self.norm = nn.LayerNorm(channels)
79
+
80
+ self.pointwise_conv2 = nn.Conv1d(
81
+ channels,
82
+ channels,
83
+ kernel_size=1,
84
+ stride=1,
85
+ padding=0,
86
+ bias=bias,
87
+ )
88
+ self.activation = activation
89
+
90
+ def forward(
91
+ self,
92
+ x: torch.Tensor,
93
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
94
+ cache: torch.Tensor = torch.zeros((0, 0, 0)),
95
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
96
+ """Compute convolution module.
97
+ Args:
98
+ x (torch.Tensor): Input tensor (#batch, time, channels).
99
+ mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
100
+ (0, 0, 0) means fake mask.
101
+ cache (torch.Tensor): left context cache, it is only
102
+ used in causal convolution (#batch, channels, cache_t),
103
+ (0, 0, 0) meas fake cache.
104
+ Returns:
105
+ torch.Tensor: Output tensor (#batch, time, channels).
106
+ """
107
+ # exchange the temporal dimension and the feature dimension
108
+ x = x.transpose(1, 2) # (#batch, channels, time)
109
+
110
+ # mask batch padding
111
+ if mask_pad.size(2) > 0: # time > 0
112
+ x.masked_fill_(~mask_pad, 0.0)
113
+
114
+ if self.lorder > 0:
115
+ if cache.size(2) == 0: # cache_t == 0
116
+ x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
117
+ else:
118
+ assert cache.size(0) == x.size(0) # equal batch
119
+ assert cache.size(1) == x.size(1) # equal channel
120
+ x = torch.cat((cache, x), dim=2)
121
+ assert (x.size(2) > self.lorder)
122
+ new_cache = x[:, :, -self.lorder:]
123
+ else:
124
+ # It's better we just return None if no cache is required,
125
+ # However, for JIT export, here we just fake one tensor instead of
126
+ # None.
127
+ new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
128
+
129
+ # GLU mechanism
130
+ x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
131
+ x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
132
+
133
+ # 1D Depthwise Conv
134
+ x = self.depthwise_conv(x)
135
+ if self.use_layer_norm:
136
+ x = x.transpose(1, 2)
137
+ x = self.activation(self.norm(x))
138
+ if self.use_layer_norm:
139
+ x = x.transpose(1, 2)
140
+ x = self.pointwise_conv2(x)
141
+ # mask batch padding
142
+ if mask_pad.size(2) > 0: # time > 0
143
+ x.masked_fill_(~mask_pad, 0.0)
144
+
145
+ return x.transpose(1, 2), new_cache
HF_Deploy/src/chatterbox/models/s3gen/transformer/embedding.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
+ # 2024 Alibaba Inc (Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Positonal Encoding Module."""
17
+
18
+ import math
19
+ from typing import Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ import numpy as np
24
+
25
+
26
+ class PositionalEncoding(torch.nn.Module):
27
+ """Positional encoding.
28
+
29
+ :param int d_model: embedding dim
30
+ :param float dropout_rate: dropout rate
31
+ :param int max_len: maximum input length
32
+
33
+ PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
34
+ PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
35
+ """
36
+
37
+ def __init__(self,
38
+ d_model: int,
39
+ dropout_rate: float,
40
+ max_len: int = 5000,
41
+ reverse: bool = False):
42
+ """Construct an PositionalEncoding object."""
43
+ super().__init__()
44
+ self.d_model = d_model
45
+ self.xscale = math.sqrt(self.d_model)
46
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
47
+ self.max_len = max_len
48
+
49
+ self.pe = torch.zeros(self.max_len, self.d_model)
50
+ position = torch.arange(0, self.max_len,
51
+ dtype=torch.float32).unsqueeze(1)
52
+ div_term = torch.exp(
53
+ torch.arange(0, self.d_model, 2, dtype=torch.float32) *
54
+ -(math.log(10000.0) / self.d_model))
55
+ self.pe[:, 0::2] = torch.sin(position * div_term)
56
+ self.pe[:, 1::2] = torch.cos(position * div_term)
57
+ self.pe = self.pe.unsqueeze(0)
58
+
59
+ def forward(self,
60
+ x: torch.Tensor,
61
+ offset: Union[int, torch.Tensor] = 0) \
62
+ -> Tuple[torch.Tensor, torch.Tensor]:
63
+ """Add positional encoding.
64
+
65
+ Args:
66
+ x (torch.Tensor): Input. Its shape is (batch, time, ...)
67
+ offset (int, torch.tensor): position offset
68
+
69
+ Returns:
70
+ torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
71
+ torch.Tensor: for compatibility to RelPositionalEncoding
72
+ """
73
+
74
+ self.pe = self.pe.to(x.device)
75
+ pos_emb = self.position_encoding(offset, x.size(1), False)
76
+ x = x * self.xscale + pos_emb
77
+ return self.dropout(x), self.dropout(pos_emb)
78
+
79
+ def position_encoding(self,
80
+ offset: Union[int, torch.Tensor],
81
+ size: int,
82
+ apply_dropout: bool = True) -> torch.Tensor:
83
+ """ For getting encoding in a streaming fashion
84
+
85
+ Attention!!!!!
86
+ we apply dropout only once at the whole utterance level in a none
87
+ streaming way, but will call this function several times with
88
+ increasing input size in a streaming scenario, so the dropout will
89
+ be applied several times.
90
+
91
+ Args:
92
+ offset (int or torch.tensor): start offset
93
+ size (int): required size of position encoding
94
+
95
+ Returns:
96
+ torch.Tensor: Corresponding encoding
97
+ """
98
+ # How to subscript a Union type:
99
+ # https://github.com/pytorch/pytorch/issues/69434
100
+ if isinstance(offset, int):
101
+ assert offset + size <= self.max_len
102
+ pos_emb = self.pe[:, offset:offset + size]
103
+ elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar
104
+ assert offset + size <= self.max_len
105
+ pos_emb = self.pe[:, offset:offset + size]
106
+ else: # for batched streaming decoding on GPU
107
+ assert torch.max(offset) + size <= self.max_len
108
+ index = offset.unsqueeze(1) + \
109
+ torch.arange(0, size).to(offset.device) # B X T
110
+ flag = index > 0
111
+ # remove negative offset
112
+ index = index * flag
113
+ pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model
114
+
115
+ if apply_dropout:
116
+ pos_emb = self.dropout(pos_emb)
117
+ return pos_emb
118
+
119
+
120
+ class RelPositionalEncoding(PositionalEncoding):
121
+ """Relative positional encoding module.
122
+ See : Appendix B in https://arxiv.org/abs/1901.02860
123
+ Args:
124
+ d_model (int): Embedding dimension.
125
+ dropout_rate (float): Dropout rate.
126
+ max_len (int): Maximum input length.
127
+ """
128
+
129
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
130
+ """Initialize class."""
131
+ super().__init__(d_model, dropout_rate, max_len, reverse=True)
132
+
133
+ def forward(self,
134
+ x: torch.Tensor,
135
+ offset: Union[int, torch.Tensor] = 0) \
136
+ -> Tuple[torch.Tensor, torch.Tensor]:
137
+ """Compute positional encoding.
138
+ Args:
139
+ x (torch.Tensor): Input tensor (batch, time, `*`).
140
+ Returns:
141
+ torch.Tensor: Encoded tensor (batch, time, `*`).
142
+ torch.Tensor: Positional embedding tensor (1, time, `*`).
143
+ """
144
+ self.pe = self.pe.to(x.device)
145
+ x = x * self.xscale
146
+ pos_emb = self.position_encoding(offset, x.size(1), False)
147
+ return self.dropout(x), self.dropout(pos_emb)
148
+
149
+
150
+ class WhisperPositionalEncoding(PositionalEncoding):
151
+ """ Sinusoids position encoding used in openai-whisper.encoder
152
+ """
153
+
154
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500):
155
+ super().__init__(d_model, dropout_rate, max_len)
156
+ self.xscale = 1.0
157
+ log_timescale_increment = np.log(10000) / (d_model // 2 - 1)
158
+ inv_timescales = torch.exp(-log_timescale_increment *
159
+ torch.arange(d_model // 2))
160
+ scaled_time = torch.arange(max_len)[:, np.newaxis] * \
161
+ inv_timescales[np.newaxis, :]
162
+ pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
163
+ delattr(self, "pe")
164
+ self.register_buffer("pe", pe.unsqueeze(0))
165
+
166
+
167
+ class LearnablePositionalEncoding(PositionalEncoding):
168
+ """ Learnable position encoding used in openai-whisper.decoder
169
+ """
170
+
171
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448):
172
+ super().__init__(d_model, dropout_rate, max_len)
173
+ # NOTE(xcsong): overwrite self.pe & self.xscale
174
+ self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model))
175
+ self.xscale = 1.0
176
+
177
+
178
+ class NoPositionalEncoding(torch.nn.Module):
179
+ """ No position encoding
180
+ """
181
+
182
+ def __init__(self, d_model: int, dropout_rate: float):
183
+ super().__init__()
184
+ self.d_model = d_model
185
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
186
+
187
+ def forward(self,
188
+ x: torch.Tensor,
189
+ offset: Union[int, torch.Tensor] = 0) \
190
+ -> Tuple[torch.Tensor, torch.Tensor]:
191
+ """ Just return zero vector for interface compatibility
192
+ """
193
+ pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device)
194
+ return self.dropout(x), pos_emb
195
+
196
+ def position_encoding(self, offset: Union[int, torch.Tensor],
197
+ size: int) -> torch.Tensor:
198
+ return torch.zeros(1, size, self.d_model)
199
+
200
+
201
+ class EspnetRelPositionalEncoding(torch.nn.Module):
202
+ """Relative positional encoding module (new implementation).
203
+
204
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
205
+
206
+ See : Appendix B in https://arxiv.org/abs/1901.02860
207
+
208
+ Args:
209
+ d_model (int): Embedding dimension.
210
+ dropout_rate (float): Dropout rate.
211
+ max_len (int): Maximum input length.
212
+
213
+ """
214
+
215
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
216
+ """Construct an PositionalEncoding object."""
217
+ super(EspnetRelPositionalEncoding, self).__init__()
218
+ self.d_model = d_model
219
+ self.xscale = math.sqrt(self.d_model)
220
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
221
+ self.pe = None
222
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
223
+
224
+ def extend_pe(self, x: torch.Tensor):
225
+ """Reset the positional encodings."""
226
+ if self.pe is not None:
227
+ # self.pe contains both positive and negative parts
228
+ # the length of self.pe is 2 * input_len - 1
229
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
230
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
231
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
232
+ return
233
+ # Suppose `i` means to the position of query vecotr and `j` means the
234
+ # position of key vector. We use position relative positions when keys
235
+ # are to the left (i>j) and negative relative positions otherwise (i<j).
236
+ pe_positive = torch.zeros(x.size(1), self.d_model)
237
+ pe_negative = torch.zeros(x.size(1), self.d_model)
238
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
239
+ div_term = torch.exp(
240
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
241
+ * -(math.log(10000.0) / self.d_model)
242
+ )
243
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
244
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
245
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
246
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
247
+
248
+ # Reserve the order of positive indices and concat both positive and
249
+ # negative indices. This is used to support the shifting trick
250
+ # as in https://arxiv.org/abs/1901.02860
251
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
252
+ pe_negative = pe_negative[1:].unsqueeze(0)
253
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
254
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
255
+
256
+ def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
257
+ -> Tuple[torch.Tensor, torch.Tensor]:
258
+ """Add positional encoding.
259
+
260
+ Args:
261
+ x (torch.Tensor): Input tensor (batch, time, `*`).
262
+
263
+ Returns:
264
+ torch.Tensor: Encoded tensor (batch, time, `*`).
265
+
266
+ """
267
+ self.extend_pe(x)
268
+ x = x * self.xscale
269
+ pos_emb = self.position_encoding(size=x.size(1), offset=offset)
270
+ return self.dropout(x), self.dropout(pos_emb)
271
+
272
+ def position_encoding(self,
273
+ offset: Union[int, torch.Tensor],
274
+ size: int) -> torch.Tensor:
275
+ """ For getting encoding in a streaming fashion
276
+
277
+ Attention!!!!!
278
+ we apply dropout only once at the whole utterance level in a none
279
+ streaming way, but will call this function several times with
280
+ increasing input size in a streaming scenario, so the dropout will
281
+ be applied several times.
282
+
283
+ Args:
284
+ offset (int or torch.tensor): start offset
285
+ size (int): required size of position encoding
286
+
287
+ Returns:
288
+ torch.Tensor: Corresponding encoding
289
+ """
290
+ pos_emb = self.pe[
291
+ :,
292
+ self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
293
+ ]
294
+ return pos_emb
HF_Deploy/src/chatterbox/models/s3gen/transformer/encoder_layer.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
+ # 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Encoder self-attention layer definition."""
17
+
18
+ from typing import Optional, Tuple
19
+
20
+ import torch
21
+ from torch import nn
22
+
23
+
24
+ class TransformerEncoderLayer(nn.Module):
25
+ """Encoder layer module.
26
+
27
+ Args:
28
+ size (int): Input dimension.
29
+ self_attn (torch.nn.Module): Self-attention module instance.
30
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
31
+ instance can be used as the argument.
32
+ feed_forward (torch.nn.Module): Feed-forward module instance.
33
+ `PositionwiseFeedForward`, instance can be used as the argument.
34
+ dropout_rate (float): Dropout rate.
35
+ normalize_before (bool):
36
+ True: use layer_norm before each sub-block.
37
+ False: to use layer_norm after each sub-block.
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ size: int,
43
+ self_attn: torch.nn.Module,
44
+ feed_forward: torch.nn.Module,
45
+ dropout_rate: float,
46
+ normalize_before: bool = True,
47
+ ):
48
+ """Construct an EncoderLayer object."""
49
+ super().__init__()
50
+ self.self_attn = self_attn
51
+ self.feed_forward = feed_forward
52
+ self.norm1 = nn.LayerNorm(size, eps=1e-12)
53
+ self.norm2 = nn.LayerNorm(size, eps=1e-12)
54
+ self.dropout = nn.Dropout(dropout_rate)
55
+ self.size = size
56
+ self.normalize_before = normalize_before
57
+
58
+ def forward(
59
+ self,
60
+ x: torch.Tensor,
61
+ mask: torch.Tensor,
62
+ pos_emb: torch.Tensor,
63
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
64
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
65
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
66
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
67
+ """Compute encoded features.
68
+
69
+ Args:
70
+ x (torch.Tensor): (#batch, time, size)
71
+ mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
72
+ (0, 0, 0) means fake mask.
73
+ pos_emb (torch.Tensor): just for interface compatibility
74
+ to ConformerEncoderLayer
75
+ mask_pad (torch.Tensor): does not used in transformer layer,
76
+ just for unified api with conformer.
77
+ att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
78
+ (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
79
+ cnn_cache (torch.Tensor): Convolution cache in conformer layer
80
+ (#batch=1, size, cache_t2), not used here, it's for interface
81
+ compatibility to ConformerEncoderLayer.
82
+ Returns:
83
+ torch.Tensor: Output tensor (#batch, time, size).
84
+ torch.Tensor: Mask tensor (#batch, time, time).
85
+ torch.Tensor: att_cache tensor,
86
+ (#batch=1, head, cache_t1 + time, d_k * 2).
87
+ torch.Tensor: cnn_cahce tensor (#batch=1, size, cache_t2).
88
+
89
+ """
90
+ residual = x
91
+ if self.normalize_before:
92
+ x = self.norm1(x)
93
+ x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb=pos_emb, cache=att_cache)
94
+ x = residual + self.dropout(x_att)
95
+ if not self.normalize_before:
96
+ x = self.norm1(x)
97
+
98
+ residual = x
99
+ if self.normalize_before:
100
+ x = self.norm2(x)
101
+ x = residual + self.dropout(self.feed_forward(x))
102
+ if not self.normalize_before:
103
+ x = self.norm2(x)
104
+
105
+ fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
106
+ return x, mask, new_att_cache, fake_cnn_cache
107
+
108
+
109
+ class ConformerEncoderLayer(nn.Module):
110
+ """Encoder layer module.
111
+ Args:
112
+ size (int): Input dimension.
113
+ self_attn (torch.nn.Module): Self-attention module instance.
114
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
115
+ instance can be used as the argument.
116
+ feed_forward (torch.nn.Module): Feed-forward module instance.
117
+ `PositionwiseFeedForward` instance can be used as the argument.
118
+ feed_forward_macaron (torch.nn.Module): Additional feed-forward module
119
+ instance.
120
+ `PositionwiseFeedForward` instance can be used as the argument.
121
+ conv_module (torch.nn.Module): Convolution module instance.
122
+ `ConvlutionModule` instance can be used as the argument.
123
+ dropout_rate (float): Dropout rate.
124
+ normalize_before (bool):
125
+ True: use layer_norm before each sub-block.
126
+ False: use layer_norm after each sub-block.
127
+ """
128
+
129
+ def __init__(
130
+ self,
131
+ size: int,
132
+ self_attn: torch.nn.Module,
133
+ feed_forward: Optional[nn.Module] = None,
134
+ feed_forward_macaron: Optional[nn.Module] = None,
135
+ conv_module: Optional[nn.Module] = None,
136
+ dropout_rate: float = 0.1,
137
+ normalize_before: bool = True,
138
+ ):
139
+ """Construct an EncoderLayer object."""
140
+ super().__init__()
141
+ self.self_attn = self_attn
142
+ self.feed_forward = feed_forward
143
+ self.feed_forward_macaron = feed_forward_macaron
144
+ self.conv_module = conv_module
145
+ self.norm_ff = nn.LayerNorm(size, eps=1e-12) # for the FNN module
146
+ self.norm_mha = nn.LayerNorm(size, eps=1e-12) # for the MHA module
147
+ if feed_forward_macaron is not None:
148
+ self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
149
+ self.ff_scale = 0.5
150
+ else:
151
+ self.ff_scale = 1.0
152
+ if self.conv_module is not None:
153
+ self.norm_conv = nn.LayerNorm(size, eps=1e-12) # for the CNN module
154
+ self.norm_final = nn.LayerNorm(
155
+ size, eps=1e-12) # for the final output of the block
156
+ self.dropout = nn.Dropout(dropout_rate)
157
+ self.size = size
158
+ self.normalize_before = normalize_before
159
+
160
+ def forward(
161
+ self,
162
+ x: torch.Tensor,
163
+ mask: torch.Tensor,
164
+ pos_emb: torch.Tensor,
165
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
166
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
167
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
168
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
169
+ """Compute encoded features.
170
+
171
+ Args:
172
+ x (torch.Tensor): (#batch, time, size)
173
+ mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
174
+ (0, 0, 0) means fake mask.
175
+ pos_emb (torch.Tensor): positional encoding, must not be None
176
+ for ConformerEncoderLayer.
177
+ mask_pad (torch.Tensor): batch padding mask used for conv module.
178
+ (#batch, 1,time), (0, 0, 0) means fake mask.
179
+ att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
180
+ (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
181
+ cnn_cache (torch.Tensor): Convolution cache in conformer layer
182
+ (#batch=1, size, cache_t2)
183
+ Returns:
184
+ torch.Tensor: Output tensor (#batch, time, size).
185
+ torch.Tensor: Mask tensor (#batch, time, time).
186
+ torch.Tensor: att_cache tensor,
187
+ (#batch=1, head, cache_t1 + time, d_k * 2).
188
+ torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2).
189
+ """
190
+
191
+ # whether to use macaron style
192
+ if self.feed_forward_macaron is not None:
193
+ residual = x
194
+ if self.normalize_before:
195
+ x = self.norm_ff_macaron(x)
196
+ x = residual + self.ff_scale * self.dropout(
197
+ self.feed_forward_macaron(x))
198
+ if not self.normalize_before:
199
+ x = self.norm_ff_macaron(x)
200
+
201
+ # multi-headed self-attention module
202
+ residual = x
203
+ if self.normalize_before:
204
+ x = self.norm_mha(x)
205
+ x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
206
+ att_cache)
207
+ x = residual + self.dropout(x_att)
208
+ if not self.normalize_before:
209
+ x = self.norm_mha(x)
210
+
211
+ # convolution module
212
+ # Fake new cnn cache here, and then change it in conv_module
213
+ new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
214
+ if self.conv_module is not None:
215
+ residual = x
216
+ if self.normalize_before:
217
+ x = self.norm_conv(x)
218
+ x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
219
+ x = residual + self.dropout(x)
220
+
221
+ if not self.normalize_before:
222
+ x = self.norm_conv(x)
223
+
224
+ # feed forward module
225
+ residual = x
226
+ if self.normalize_before:
227
+ x = self.norm_ff(x)
228
+
229
+ x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
230
+ if not self.normalize_before:
231
+ x = self.norm_ff(x)
232
+
233
+ if self.conv_module is not None:
234
+ x = self.norm_final(x)
235
+
236
+ return x, mask, new_att_cache, new_cnn_cache
HF_Deploy/src/chatterbox/models/s3gen/transformer/positionwise_feed_forward.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Positionwise feed forward layer definition."""
16
+
17
+ import torch
18
+
19
+
20
+ class PositionwiseFeedForward(torch.nn.Module):
21
+ """Positionwise feed forward layer.
22
+
23
+ FeedForward are appied on each position of the sequence.
24
+ The output dim is same with the input dim.
25
+
26
+ Args:
27
+ idim (int): Input dimenstion.
28
+ hidden_units (int): The number of hidden units.
29
+ dropout_rate (float): Dropout rate.
30
+ activation (torch.nn.Module): Activation function
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ idim: int,
36
+ hidden_units: int,
37
+ dropout_rate: float,
38
+ activation: torch.nn.Module = torch.nn.ReLU(),
39
+ ):
40
+ """Construct a PositionwiseFeedForward object."""
41
+ super(PositionwiseFeedForward, self).__init__()
42
+ self.w_1 = torch.nn.Linear(idim, hidden_units)
43
+ self.activation = activation
44
+ self.dropout = torch.nn.Dropout(dropout_rate)
45
+ self.w_2 = torch.nn.Linear(hidden_units, idim)
46
+
47
+ def forward(self, xs: torch.Tensor) -> torch.Tensor:
48
+ """Forward function.
49
+
50
+ Args:
51
+ xs: input tensor (B, L, D)
52
+ Returns:
53
+ output tensor, (B, L, D)
54
+ """
55
+ return self.w_2(self.dropout(self.activation(self.w_1(xs))))
56
+
57
+
58
+ class MoEFFNLayer(torch.nn.Module):
59
+ """
60
+ Mixture of expert with Positionwise feed forward layer
61
+ See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf
62
+ The output dim is same with the input dim.
63
+
64
+ Modified from https://github.com/Lightning-AI/lit-gpt/pull/823
65
+ https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
66
+ Args:
67
+ n_expert: number of expert.
68
+ n_expert_per_token: The actual number of experts used for each frame
69
+ idim (int): Input dimenstion.
70
+ hidden_units (int): The number of hidden units.
71
+ dropout_rate (float): Dropout rate.
72
+ activation (torch.nn.Module): Activation function
73
+ """
74
+
75
+ def __init__(
76
+ self,
77
+ n_expert: int,
78
+ n_expert_per_token: int,
79
+ idim: int,
80
+ hidden_units: int,
81
+ dropout_rate: float,
82
+ activation: torch.nn.Module = torch.nn.ReLU(),
83
+ ):
84
+ super(MoEFFNLayer, self).__init__()
85
+ self.gate = torch.nn.Linear(idim, n_expert, bias=False)
86
+ self.experts = torch.nn.ModuleList(
87
+ PositionwiseFeedForward(idim, hidden_units, dropout_rate,
88
+ activation) for _ in range(n_expert))
89
+ self.n_expert_per_token = n_expert_per_token
90
+
91
+ def forward(self, xs: torch.Tensor) -> torch.Tensor:
92
+ """Foward function.
93
+ Args:
94
+ xs: input tensor (B, L, D)
95
+ Returns:
96
+ output tensor, (B, L, D)
97
+
98
+ """
99
+ B, L, D = xs.size(
100
+ ) # batch size, sequence length, embedding dimension (idim)
101
+ xs = xs.view(-1, D) # (B*L, D)
102
+ router = self.gate(xs) # (B*L, n_expert)
103
+ logits, indices = torch.topk(
104
+ router, self.n_expert_per_token
105
+ ) # probs:(B*L, n_expert), indices: (B*L, n_expert)
106
+ weights = torch.nn.functional.softmax(
107
+ logits, dim=1,
108
+ dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token)
109
+ output = torch.zeros_like(xs) # (B*L, D)
110
+ for i, expert in enumerate(self.experts):
111
+ mask = indices == i
112
+ batch_idx, ith_expert = torch.where(mask)
113
+ output[batch_idx] += weights[batch_idx, ith_expert, None] * expert(
114
+ xs[batch_idx])
115
+ return output.view(B, L, D)
HF_Deploy/src/chatterbox/models/s3gen/transformer/subsampling.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
+ # 2024 Alibaba Inc (Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Subsampling layer definition."""
17
+
18
+ from typing import Tuple, Union
19
+
20
+ import torch
21
+
22
+
23
+ class BaseSubsampling(torch.nn.Module):
24
+
25
+ def __init__(self):
26
+ super().__init__()
27
+ self.right_context = 0
28
+ self.subsampling_rate = 1
29
+
30
+ def position_encoding(self, offset: Union[int, torch.Tensor],
31
+ size: int) -> torch.Tensor:
32
+ return self.pos_enc.position_encoding(offset, size)
33
+
34
+
35
+ class EmbedinigNoSubsampling(BaseSubsampling):
36
+ """Embedding input without subsampling
37
+ """
38
+
39
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
40
+ pos_enc_class: torch.nn.Module):
41
+ super().__init__()
42
+ self.embed = torch.nn.Embedding(idim, odim)
43
+ self.pos_enc = pos_enc_class
44
+
45
+ def forward(
46
+ self,
47
+ x: torch.Tensor,
48
+ x_mask: torch.Tensor,
49
+ offset: Union[int, torch.Tensor] = 0
50
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
51
+ """Input x.
52
+
53
+ Args:
54
+ x (torch.Tensor): Input tensor (#batch, time, idim).
55
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
56
+
57
+ Returns:
58
+ torch.Tensor: linear input tensor (#batch, time', odim),
59
+ where time' = time .
60
+ torch.Tensor: linear input mask (#batch, 1, time'),
61
+ where time' = time .
62
+
63
+ """
64
+ x = self.embed(x)
65
+ x, pos_emb = self.pos_enc(x, offset)
66
+ return x, pos_emb, x_mask
67
+
68
+
69
+ class LinearNoSubsampling(BaseSubsampling):
70
+ """Linear transform the input without subsampling
71
+
72
+ Args:
73
+ idim (int): Input dimension.
74
+ odim (int): Output dimension.
75
+ dropout_rate (float): Dropout rate.
76
+
77
+ """
78
+
79
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
80
+ pos_enc_class: torch.nn.Module):
81
+ """Construct an linear object."""
82
+ super().__init__()
83
+ self.out = torch.nn.Sequential(
84
+ torch.nn.Linear(idim, odim),
85
+ torch.nn.LayerNorm(odim, eps=1e-5),
86
+ torch.nn.Dropout(dropout_rate),
87
+ )
88
+ self.pos_enc = pos_enc_class
89
+ self.right_context = 0
90
+ self.subsampling_rate = 1
91
+
92
+ def forward(
93
+ self,
94
+ x: torch.Tensor,
95
+ x_mask: torch.Tensor,
96
+ offset: Union[int, torch.Tensor] = 0
97
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
98
+ """Input x.
99
+
100
+ Args:
101
+ x (torch.Tensor): Input tensor (#batch, time, idim).
102
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
103
+
104
+ Returns:
105
+ torch.Tensor: linear input tensor (#batch, time', odim),
106
+ where time' = time .
107
+ torch.Tensor: linear input mask (#batch, 1, time'),
108
+ where time' = time .
109
+
110
+ """
111
+ x = self.out(x)
112
+ x, pos_emb = self.pos_enc(x, offset)
113
+ return x, pos_emb, x_mask
114
+
115
+
116
+ class Conv1dSubsampling2(BaseSubsampling):
117
+ """Convolutional 1D subsampling (to 1/2 length).
118
+ It is designed for Whisper, ref:
119
+ https://github.com/openai/whisper/blob/main/whisper/model.py
120
+
121
+ Args:
122
+ idim (int): Input dimension.
123
+ odim (int): Output dimension.
124
+ dropout_rate (float): Dropout rate.
125
+
126
+ """
127
+
128
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
129
+ pos_enc_class: torch.nn.Module):
130
+ """Construct an Conv1dSubsampling2 object."""
131
+ super().__init__()
132
+ self.conv = torch.nn.Sequential(
133
+ torch.nn.Conv1d(idim, odim, kernel_size=3, padding=1),
134
+ torch.nn.GELU(),
135
+ torch.nn.Conv1d(odim, odim, kernel_size=3, stride=2, padding=1),
136
+ torch.nn.GELU(),
137
+ )
138
+ self.pos_enc = pos_enc_class
139
+ # The right context for every conv layer is computed by:
140
+ # (kernel_size - 1) * frame_rate_of_this_layer
141
+ self.subsampling_rate = 2
142
+ # 4 = (3 - 1) * 1 + (3 - 1) * 1
143
+ self.right_context = 4
144
+
145
+ def forward(
146
+ self,
147
+ x: torch.Tensor,
148
+ x_mask: torch.Tensor,
149
+ offset: Union[int, torch.Tensor] = 0
150
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
151
+ """Subsample x.
152
+
153
+ Args:
154
+ x (torch.Tensor): Input tensor (#batch, time, idim).
155
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
156
+
157
+ Returns:
158
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
159
+ where time' = time // 2.
160
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
161
+ where time' = time // 2.
162
+ torch.Tensor: positional encoding
163
+
164
+ """
165
+ time = x.size(1)
166
+ x = x.transpose(1, 2) # (b, f, t)
167
+ x = self.conv(x)
168
+ x = x.transpose(1, 2) # (b, t, f)
169
+ x, pos_emb = self.pos_enc(x, offset)
170
+ return x, pos_emb, x_mask[:, :, (time + 1) % 2::2]
171
+
172
+
173
+ class Conv2dSubsampling4(BaseSubsampling):
174
+ """Convolutional 2D subsampling (to 1/4 length).
175
+
176
+ Args:
177
+ idim (int): Input dimension.
178
+ odim (int): Output dimension.
179
+ dropout_rate (float): Dropout rate.
180
+
181
+ """
182
+
183
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
184
+ pos_enc_class: torch.nn.Module):
185
+ """Construct an Conv2dSubsampling4 object."""
186
+ super().__init__()
187
+ self.conv = torch.nn.Sequential(
188
+ torch.nn.Conv2d(1, odim, 3, 2),
189
+ torch.nn.ReLU(),
190
+ torch.nn.Conv2d(odim, odim, 3, 2),
191
+ torch.nn.ReLU(),
192
+ )
193
+ self.out = torch.nn.Sequential(
194
+ torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
195
+ self.pos_enc = pos_enc_class
196
+ # The right context for every conv layer is computed by:
197
+ # (kernel_size - 1) * frame_rate_of_this_layer
198
+ self.subsampling_rate = 4
199
+ # 6 = (3 - 1) * 1 + (3 - 1) * 2
200
+ self.right_context = 6
201
+
202
+ def forward(
203
+ self,
204
+ x: torch.Tensor,
205
+ x_mask: torch.Tensor,
206
+ offset: Union[int, torch.Tensor] = 0
207
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
208
+ """Subsample x.
209
+
210
+ Args:
211
+ x (torch.Tensor): Input tensor (#batch, time, idim).
212
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
213
+
214
+ Returns:
215
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
216
+ where time' = time // 4.
217
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
218
+ where time' = time // 4.
219
+ torch.Tensor: positional encoding
220
+
221
+ """
222
+ x = x.unsqueeze(1) # (b, c=1, t, f)
223
+ x = self.conv(x)
224
+ b, c, t, f = x.size()
225
+ x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
226
+ x, pos_emb = self.pos_enc(x, offset)
227
+ return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2]
228
+
229
+
230
+ class Conv2dSubsampling6(BaseSubsampling):
231
+ """Convolutional 2D subsampling (to 1/6 length).
232
+ Args:
233
+ idim (int): Input dimension.
234
+ odim (int): Output dimension.
235
+ dropout_rate (float): Dropout rate.
236
+ pos_enc (torch.nn.Module): Custom position encoding layer.
237
+ """
238
+
239
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
240
+ pos_enc_class: torch.nn.Module):
241
+ """Construct an Conv2dSubsampling6 object."""
242
+ super().__init__()
243
+ self.conv = torch.nn.Sequential(
244
+ torch.nn.Conv2d(1, odim, 3, 2),
245
+ torch.nn.ReLU(),
246
+ torch.nn.Conv2d(odim, odim, 5, 3),
247
+ torch.nn.ReLU(),
248
+ )
249
+ self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3),
250
+ odim)
251
+ self.pos_enc = pos_enc_class
252
+ # 10 = (3 - 1) * 1 + (5 - 1) * 2
253
+ self.subsampling_rate = 6
254
+ self.right_context = 10
255
+
256
+ def forward(
257
+ self,
258
+ x: torch.Tensor,
259
+ x_mask: torch.Tensor,
260
+ offset: Union[int, torch.Tensor] = 0
261
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
262
+ """Subsample x.
263
+ Args:
264
+ x (torch.Tensor): Input tensor (#batch, time, idim).
265
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
266
+
267
+ Returns:
268
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
269
+ where time' = time // 6.
270
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
271
+ where time' = time // 6.
272
+ torch.Tensor: positional encoding
273
+ """
274
+ x = x.unsqueeze(1) # (b, c, t, f)
275
+ x = self.conv(x)
276
+ b, c, t, f = x.size()
277
+ x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
278
+ x, pos_emb = self.pos_enc(x, offset)
279
+ return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]
280
+
281
+
282
+ class Conv2dSubsampling8(BaseSubsampling):
283
+ """Convolutional 2D subsampling (to 1/8 length).
284
+
285
+ Args:
286
+ idim (int): Input dimension.
287
+ odim (int): Output dimension.
288
+ dropout_rate (float): Dropout rate.
289
+
290
+ """
291
+
292
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
293
+ pos_enc_class: torch.nn.Module):
294
+ """Construct an Conv2dSubsampling8 object."""
295
+ super().__init__()
296
+ self.conv = torch.nn.Sequential(
297
+ torch.nn.Conv2d(1, odim, 3, 2),
298
+ torch.nn.ReLU(),
299
+ torch.nn.Conv2d(odim, odim, 3, 2),
300
+ torch.nn.ReLU(),
301
+ torch.nn.Conv2d(odim, odim, 3, 2),
302
+ torch.nn.ReLU(),
303
+ )
304
+ self.linear = torch.nn.Linear(
305
+ odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
306
+ self.pos_enc = pos_enc_class
307
+ self.subsampling_rate = 8
308
+ # 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
309
+ self.right_context = 14
310
+
311
+ def forward(
312
+ self,
313
+ x: torch.Tensor,
314
+ x_mask: torch.Tensor,
315
+ offset: Union[int, torch.Tensor] = 0
316
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
317
+ """Subsample x.
318
+
319
+ Args:
320
+ x (torch.Tensor): Input tensor (#batch, time, idim).
321
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
322
+
323
+ Returns:
324
+ torch.Tensor: Subsampled tensor (#batch, time', odim),
325
+ where time' = time // 8.
326
+ torch.Tensor: Subsampled mask (#batch, 1, time'),
327
+ where time' = time // 8.
328
+ torch.Tensor: positional encoding
329
+ """
330
+ x = x.unsqueeze(1) # (b, c, t, f)
331
+ x = self.conv(x)
332
+ b, c, t, f = x.size()
333
+ x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
334
+ x, pos_emb = self.pos_enc(x, offset)
335
+ return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]
336
+
337
+
338
+ class LegacyLinearNoSubsampling(BaseSubsampling):
339
+ """Linear transform the input without subsampling
340
+
341
+ Args:
342
+ idim (int): Input dimension.
343
+ odim (int): Output dimension.
344
+ dropout_rate (float): Dropout rate.
345
+
346
+ """
347
+
348
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
349
+ pos_enc_class: torch.nn.Module):
350
+ """Construct an linear object."""
351
+ super().__init__()
352
+ self.out = torch.nn.Sequential(
353
+ torch.nn.Linear(idim, odim),
354
+ torch.nn.LayerNorm(odim, eps=1e-5),
355
+ torch.nn.Dropout(dropout_rate),
356
+ torch.nn.ReLU(),
357
+ )
358
+ self.pos_enc = pos_enc_class
359
+ self.right_context = 0
360
+ self.subsampling_rate = 1
361
+
362
+ def forward(
363
+ self,
364
+ x: torch.Tensor,
365
+ x_mask: torch.Tensor,
366
+ offset: Union[int, torch.Tensor] = 0
367
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
368
+ """Input x.
369
+
370
+ Args:
371
+ x (torch.Tensor): Input tensor (#batch, time, idim).
372
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
373
+
374
+ Returns:
375
+ torch.Tensor: linear input tensor (#batch, time', odim),
376
+ where time' = time .
377
+ torch.Tensor: linear input mask (#batch, 1, time'),
378
+ where time' = time .
379
+
380
+ """
381
+ x = self.out(x)
382
+ x, pos_emb = self.pos_enc(x, offset)
383
+ return x, pos_emb, x_mask
HF_Deploy/src/chatterbox/models/s3gen/transformer/upsample_encoder.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
+ # 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
3
+ # 2024 Alibaba Inc (Xiang Lyu)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # Modified from ESPnet(https://github.com/espnet/espnet)
17
+ """Encoder definition."""
18
+ from typing import Tuple
19
+
20
+ import torch
21
+ from torch import nn
22
+ from torch.nn import functional as F
23
+
24
+ from .convolution import ConvolutionModule
25
+ from .encoder_layer import ConformerEncoderLayer
26
+ from .positionwise_feed_forward import PositionwiseFeedForward
27
+ from ..utils.class_utils import (
28
+ COSYVOICE_EMB_CLASSES,
29
+ COSYVOICE_SUBSAMPLE_CLASSES,
30
+ COSYVOICE_ATTENTION_CLASSES,
31
+ COSYVOICE_ACTIVATION_CLASSES,
32
+ )
33
+ from ..utils.mask import make_pad_mask
34
+ from ..utils.mask import add_optional_chunk_mask
35
+
36
+
37
+ class Upsample1D(nn.Module):
38
+ """A 1D upsampling layer with an optional convolution.
39
+
40
+ Parameters:
41
+ channels (`int`):
42
+ number of channels in the inputs and outputs.
43
+ use_conv (`bool`, default `False`):
44
+ option to use a convolution.
45
+ use_conv_transpose (`bool`, default `False`):
46
+ option to use a convolution transpose.
47
+ out_channels (`int`, optional):
48
+ number of output channels. Defaults to `channels`.
49
+ """
50
+
51
+ def __init__(self, channels: int, out_channels: int, stride: int = 2):
52
+ super().__init__()
53
+ self.channels = channels
54
+ self.out_channels = out_channels
55
+ self.stride = stride
56
+ # In this mode, first repeat interpolate, than conv with stride=1
57
+ self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)
58
+
59
+ def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor):
60
+ outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
61
+ outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
62
+ outputs = self.conv(outputs)
63
+ return outputs, input_lengths * self.stride
64
+
65
+
66
+ class PreLookaheadLayer(nn.Module):
67
+ def __init__(self, channels: int, pre_lookahead_len: int = 1):
68
+ super().__init__()
69
+ self.channels = channels
70
+ self.pre_lookahead_len = pre_lookahead_len
71
+ self.conv1 = nn.Conv1d(
72
+ channels, channels,
73
+ kernel_size=pre_lookahead_len + 1,
74
+ stride=1, padding=0,
75
+ )
76
+ self.conv2 = nn.Conv1d(
77
+ channels, channels,
78
+ kernel_size=3, stride=1, padding=0,
79
+ )
80
+
81
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
82
+ """
83
+ inputs: (batch_size, seq_len, channels)
84
+ """
85
+ outputs = inputs.transpose(1, 2).contiguous()
86
+ # look ahead
87
+ outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
88
+ outputs = F.leaky_relu(self.conv1(outputs))
89
+ # outputs
90
+ outputs = F.pad(outputs, (2, 0), mode='constant', value=0.0)
91
+ outputs = self.conv2(outputs)
92
+ outputs = outputs.transpose(1, 2).contiguous()
93
+
94
+ # residual connection
95
+ outputs = outputs + inputs
96
+ return outputs
97
+
98
+
99
+ class UpsampleConformerEncoder(torch.nn.Module):
100
+
101
+ def __init__(
102
+ self,
103
+ input_size: int = 512,
104
+ output_size: int = 512,
105
+ attention_heads: int = 8,
106
+ linear_units: int = 2048,
107
+ num_blocks: int = 6,
108
+ dropout_rate: float = 0.1,
109
+ positional_dropout_rate: float = 0.1,
110
+ attention_dropout_rate: float = 0.1,
111
+ input_layer: str = "linear",
112
+ pos_enc_layer_type: str = "rel_pos_espnet",
113
+ normalize_before: bool = True,
114
+ static_chunk_size: int = 0,
115
+ use_dynamic_chunk: bool = False,
116
+ global_cmvn: torch.nn.Module = None,
117
+ use_dynamic_left_chunk: bool = False,
118
+ positionwise_conv_kernel_size: int = 1,
119
+ macaron_style: bool = False,
120
+ selfattention_layer_type: str = "rel_selfattn",
121
+ activation_type: str = "swish",
122
+ use_cnn_module: bool = False,
123
+ cnn_module_kernel: int = 15,
124
+ causal: bool = False,
125
+ cnn_module_norm: str = "batch_norm",
126
+ key_bias: bool = True,
127
+ gradient_checkpointing: bool = False,
128
+ ):
129
+ """
130
+ Args:
131
+ input_size (int): input dim
132
+ output_size (int): dimension of attention
133
+ attention_heads (int): the number of heads of multi head attention
134
+ linear_units (int): the hidden units number of position-wise feed
135
+ forward
136
+ num_blocks (int): the number of decoder blocks
137
+ dropout_rate (float): dropout rate
138
+ attention_dropout_rate (float): dropout rate in attention
139
+ positional_dropout_rate (float): dropout rate after adding
140
+ positional encoding
141
+ input_layer (str): input layer type.
142
+ optional [linear, conv2d, conv2d6, conv2d8]
143
+ pos_enc_layer_type (str): Encoder positional encoding layer type.
144
+ opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos]
145
+ normalize_before (bool):
146
+ True: use layer_norm before each sub-block of a layer.
147
+ False: use layer_norm after each sub-block of a layer.
148
+ static_chunk_size (int): chunk size for static chunk training and
149
+ decoding
150
+ use_dynamic_chunk (bool): whether use dynamic chunk size for
151
+ training or not, You can only use fixed chunk(chunk_size > 0)
152
+ or dyanmic chunk size(use_dynamic_chunk = True)
153
+ global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
154
+ use_dynamic_left_chunk (bool): whether use dynamic left chunk in
155
+ dynamic chunk training
156
+ key_bias: whether use bias in attention.linear_k, False for whisper models.
157
+ gradient_checkpointing: rerunning a forward-pass segment for each
158
+ checkpointed segment during backward.
159
+ """
160
+ super().__init__()
161
+ self._output_size = output_size
162
+
163
+ self.global_cmvn = global_cmvn
164
+ self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
165
+ input_size,
166
+ output_size,
167
+ dropout_rate,
168
+ COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
169
+ positional_dropout_rate),
170
+ )
171
+
172
+ self.normalize_before = normalize_before
173
+ self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
174
+ self.static_chunk_size = static_chunk_size
175
+ self.use_dynamic_chunk = use_dynamic_chunk
176
+ self.use_dynamic_left_chunk = use_dynamic_left_chunk
177
+ self.gradient_checkpointing = gradient_checkpointing
178
+ activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
179
+ # self-attention module definition
180
+ encoder_selfattn_layer_args = (
181
+ attention_heads,
182
+ output_size,
183
+ attention_dropout_rate,
184
+ key_bias,
185
+ )
186
+ # feed-forward module definition
187
+ positionwise_layer_args = (
188
+ output_size,
189
+ linear_units,
190
+ dropout_rate,
191
+ activation,
192
+ )
193
+ # convolution module definition
194
+ convolution_layer_args = (output_size, cnn_module_kernel, activation,
195
+ cnn_module_norm, causal)
196
+ self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
197
+ self.encoders = torch.nn.ModuleList([
198
+ ConformerEncoderLayer(
199
+ output_size,
200
+ COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
201
+ *encoder_selfattn_layer_args),
202
+ PositionwiseFeedForward(*positionwise_layer_args),
203
+ PositionwiseFeedForward(
204
+ *positionwise_layer_args) if macaron_style else None,
205
+ ConvolutionModule(
206
+ *convolution_layer_args) if use_cnn_module else None,
207
+ dropout_rate,
208
+ normalize_before,
209
+ ) for _ in range(num_blocks)
210
+ ])
211
+ self.up_layer = Upsample1D(channels=512, out_channels=512, stride=2)
212
+ self.up_embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
213
+ input_size,
214
+ output_size,
215
+ dropout_rate,
216
+ COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
217
+ positional_dropout_rate),
218
+ )
219
+ self.up_encoders = torch.nn.ModuleList([
220
+ ConformerEncoderLayer(
221
+ output_size,
222
+ COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
223
+ *encoder_selfattn_layer_args),
224
+ PositionwiseFeedForward(*positionwise_layer_args),
225
+ PositionwiseFeedForward(
226
+ *positionwise_layer_args) if macaron_style else None,
227
+ ConvolutionModule(
228
+ *convolution_layer_args) if use_cnn_module else None,
229
+ dropout_rate,
230
+ normalize_before,
231
+ ) for _ in range(4)
232
+ ])
233
+
234
+ def output_size(self) -> int:
235
+ return self._output_size
236
+
237
+ def forward(
238
+ self,
239
+ xs: torch.Tensor,
240
+ xs_lens: torch.Tensor,
241
+ decoding_chunk_size: int = 0,
242
+ num_decoding_left_chunks: int = -1,
243
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
244
+ """Embed positions in tensor.
245
+
246
+ Args:
247
+ xs: padded input tensor (B, T, D)
248
+ xs_lens: input length (B)
249
+ decoding_chunk_size: decoding chunk size for dynamic chunk
250
+ 0: default for training, use random dynamic chunk.
251
+ <0: for decoding, use full chunk.
252
+ >0: for decoding, use fixed chunk size as set.
253
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
254
+ the chunk size is decoding_chunk_size.
255
+ >=0: use num_decoding_left_chunks
256
+ <0: use all left chunks
257
+ Returns:
258
+ encoder output tensor xs, and subsampled masks
259
+ xs: padded output tensor (B, T' ~= T/subsample_rate, D)
260
+ masks: torch.Tensor batch padding mask after subsample
261
+ (B, 1, T' ~= T/subsample_rate)
262
+ NOTE(xcsong):
263
+ We pass the `__call__` method of the modules instead of `forward` to the
264
+ checkpointing API because `__call__` attaches all the hooks of the module.
265
+ https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
266
+ """
267
+ T = xs.size(1)
268
+ masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
269
+ if self.global_cmvn is not None:
270
+ xs = self.global_cmvn(xs)
271
+ xs, pos_emb, masks = self.embed(xs, masks)
272
+ mask_pad = masks # (B, 1, T/subsample_rate)
273
+ chunk_masks = add_optional_chunk_mask(xs, masks,
274
+ self.use_dynamic_chunk,
275
+ self.use_dynamic_left_chunk,
276
+ decoding_chunk_size,
277
+ self.static_chunk_size,
278
+ num_decoding_left_chunks)
279
+ # lookahead + conformer encoder
280
+ xs = self.pre_lookahead_layer(xs)
281
+ xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
282
+
283
+ # upsample + conformer encoder
284
+ xs = xs.transpose(1, 2).contiguous()
285
+ xs, xs_lens = self.up_layer(xs, xs_lens)
286
+ xs = xs.transpose(1, 2).contiguous()
287
+ T = xs.size(1)
288
+ masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
289
+ xs, pos_emb, masks = self.up_embed(xs, masks)
290
+ mask_pad = masks # (B, 1, T/subsample_rate)
291
+ chunk_masks = add_optional_chunk_mask(xs, masks,
292
+ self.use_dynamic_chunk,
293
+ self.use_dynamic_left_chunk,
294
+ decoding_chunk_size,
295
+ self.static_chunk_size * self.up_layer.stride,
296
+ num_decoding_left_chunks)
297
+ xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
298
+
299
+ if self.normalize_before:
300
+ xs = self.after_norm(xs)
301
+ # Here we assume the mask is not changed in encoder layers, so just
302
+ # return the masks before encoder layers, and the masks will be used
303
+ # for cross attention with decoder later
304
+ return xs, masks
305
+
306
+ def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
307
+ pos_emb: torch.Tensor,
308
+ mask_pad: torch.Tensor) -> torch.Tensor:
309
+ for layer in self.encoders:
310
+ xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
311
+ return xs
312
+
313
+ def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
314
+ pos_emb: torch.Tensor,
315
+ mask_pad: torch.Tensor) -> torch.Tensor:
316
+ for layer in self.up_encoders:
317
+ xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
318
+ return xs
HF_Deploy/src/chatterbox/models/s3gen/utils/class_utils.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright [2023-11-28] <sxc19@mails.tsinghua.edu.cn, Xingchen Song>
2
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import torch
16
+
17
+ from ..transformer.activation import Swish
18
+ from ..transformer.subsampling import (
19
+ LinearNoSubsampling,
20
+ EmbedinigNoSubsampling,
21
+ Conv1dSubsampling2,
22
+ Conv2dSubsampling4,
23
+ Conv2dSubsampling6,
24
+ Conv2dSubsampling8,
25
+ )
26
+ from ..transformer.embedding import (
27
+ PositionalEncoding,
28
+ RelPositionalEncoding,
29
+ WhisperPositionalEncoding,
30
+ LearnablePositionalEncoding,
31
+ NoPositionalEncoding)
32
+ from ..transformer.attention import (MultiHeadedAttention,
33
+ RelPositionMultiHeadedAttention)
34
+ from ..transformer.embedding import EspnetRelPositionalEncoding
35
+ from ..transformer.subsampling import LegacyLinearNoSubsampling
36
+
37
+
38
+ COSYVOICE_ACTIVATION_CLASSES = {
39
+ "hardtanh": torch.nn.Hardtanh,
40
+ "tanh": torch.nn.Tanh,
41
+ "relu": torch.nn.ReLU,
42
+ "selu": torch.nn.SELU,
43
+ "swish": getattr(torch.nn, "SiLU", Swish),
44
+ "gelu": torch.nn.GELU,
45
+ }
46
+
47
+ COSYVOICE_SUBSAMPLE_CLASSES = {
48
+ "linear": LinearNoSubsampling,
49
+ "linear_legacy": LegacyLinearNoSubsampling,
50
+ "embed": EmbedinigNoSubsampling,
51
+ "conv1d2": Conv1dSubsampling2,
52
+ "conv2d": Conv2dSubsampling4,
53
+ "conv2d6": Conv2dSubsampling6,
54
+ "conv2d8": Conv2dSubsampling8,
55
+ 'paraformer_dummy': torch.nn.Identity
56
+ }
57
+
58
+ COSYVOICE_EMB_CLASSES = {
59
+ "embed": PositionalEncoding,
60
+ "abs_pos": PositionalEncoding,
61
+ "rel_pos": RelPositionalEncoding,
62
+ "rel_pos_espnet": EspnetRelPositionalEncoding,
63
+ "no_pos": NoPositionalEncoding,
64
+ "abs_pos_whisper": WhisperPositionalEncoding,
65
+ "embed_learnable_pe": LearnablePositionalEncoding,
66
+ }
67
+
68
+ COSYVOICE_ATTENTION_CLASSES = {
69
+ "selfattn": MultiHeadedAttention,
70
+ "rel_selfattn": RelPositionMultiHeadedAttention,
71
+ }