Masaaki Kawata committed on
Commit
e07602f
·
1 Parent(s): 44e5c9c

initial commit

Browse files
.dockerignore ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .DS_Store
2
+ .env
3
+ .git
4
+ .gitignore
5
+ __pycache__/
6
+ *.py[cod]
7
+ .cache/
8
+ .venv/
9
+ venv/
10
+ logs/
11
+ data/
12
+ models/
.gitignore ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110
+ .pdm.toml
111
+ .pdm-python
112
+ .pdm-build/
113
+
114
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115
+ __pypackages__/
116
+
117
+ # Celery stuff
118
+ celerybeat-schedule
119
+ celerybeat.pid
120
+
121
+ # SageMath parsed files
122
+ *.sage.py
123
+
124
+ # Environments
125
+ .env
126
+ .venv
127
+ env/
128
+ venv/
129
+ ENV/
130
+ env.bak/
131
+ venv.bak/
132
+
133
+ # Spyder project settings
134
+ .spyderproject
135
+ .spyproject
136
+
137
+ # Rope project settings
138
+ .ropeproject
139
+
140
+ # mkdocs documentation
141
+ /site
142
+
143
+ # mypy
144
+ .mypy_cache/
145
+ .dmypy.json
146
+ dmypy.json
147
+
148
+ # Pyre type checker
149
+ .pyre/
150
+
151
+ # pytype static type analyzer
152
+ .pytype/
153
+
154
+ # Cython debug symbols
155
+ cython_debug/
156
+
157
+ # PyCharm
158
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
161
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
+ #.idea/
163
+
164
+ # Generated by MacOS
165
+ .DS_Store
166
+
167
+ #GPT_SoVITS/text/ja_userdic/
Dockerfile ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ ENV DEBIAN_FRONTEND=noninteractive \
4
+ PYTHONUNBUFFERED=1 \
5
+ PIP_NO_CACHE_DIR=1 \
6
+ HF_HOME=/data/huggingface \
7
+ WHISPER_CACHE_DIR=/data/whisper \
8
+ GRADIO_SERVER_NAME=0.0.0.0 \
9
+ GRADIO_SERVER_PORT=7860 \
10
+ NVIDIA_VISIBLE_DEVICES=all \
11
+ NVIDIA_DRIVER_CAPABILITIES=compute,utility
12
+
13
+ RUN apt-get update \
14
+ && apt-get install -y --no-install-recommends \
15
+ build-essential \
16
+ curl \
17
+ ffmpeg \
18
+ git \
19
+ libsndfile1 \
20
+ sox \
21
+ && apt-get clean \
22
+ && rm -rf /var/lib/apt/lists/*
23
+
24
+ WORKDIR /app
25
+
26
+ COPY requirements.txt .
27
+ RUN python -m pip install --upgrade pip setuptools wheel \
28
+ && python -m pip install \
29
+ --index-url https://download.pytorch.org/whl/cu128 \
30
+ torch==2.10.0+cu128 \
31
+ torchaudio==2.10.0+cu128 \
32
+ && sed '/^torch==/d; /^torchaudio==/d' requirements.txt > /tmp/requirements-no-torch.txt \
33
+ && python -m pip install -r /tmp/requirements-no-torch.txt
34
+
35
+ COPY app.py .
36
+ COPY faster_qwen3_tts ./faster_qwen3_tts
37
+ COPY qwen_tts ./qwen_tts
38
+
39
+ RUN useradd --create-home --uid 1000 appuser \
40
+ && mkdir -p /data/huggingface /data/whisper \
41
+ && chown -R appuser:appuser /app /data
42
+
43
+ USER appuser
44
+
45
+ EXPOSE 7860
46
+
47
+ CMD ["python", "app.py"]
README.md CHANGED
@@ -1,13 +1,14 @@
1
  ---
2
  title: Merkurius
3
- emoji: 🦀
4
- colorFrom: yellow
5
- colorTo: pink
6
  sdk: gradio
7
  sdk_version: 6.14.0
8
  python_version: '3.12'
9
  app_file: app.py
10
  pinned: false
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: Merkurius
3
+ emoji: 🌟
4
+ colorFrom: pink
5
+ colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 6.14.0
8
  python_version: '3.12'
9
  app_file: app.py
10
  pinned: false
11
+ short_description: milchchan.com
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #import subprocess
2
+ #subprocess.run('pip install flash-attn==2.7.4.post1', shell=True)
3
+ import io
4
+ import re
5
+ import os
6
+ import json
7
+ import hashlib
8
+ import threading
9
+ import time
10
+ import numpy as np
11
+ import torch
12
+ import spaces
13
+ import whisper
14
+ import gradio as gr
15
+ from gradio.themes.base import Base
16
+ from gradio.themes.utils import colors, fonts, sizes
17
+ from typing import Iterable
18
+ from dotenv import load_dotenv
19
+ from urllib.request import urlopen, Request
20
+ from scipy.signal import resample_poly
21
+ #from huggingface_hub import snapshot_download
22
+ #from qwen_tts import Qwen3TTSModel
23
+ from faster_qwen3_tts import FasterQwen3TTS
24
+
25
+
26
# Load variables from a .env file (if present) before any os.environ lookups below.
load_dotenv(verbose=False)

#TTS_MODEL = Qwen3TTSModel.from_pretrained(snapshot_download('Qwen/Qwen3-TTS-12Hz-1.7B-Base', token=os.environ['HF_TOKEN']), device_map=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), dtype=torch.bfloat16, token=os.environ['HF_TOKEN'], attn_implementation='kernels-community/flash-attn3')
# Text-to-speech model used by generate_voice_clone(); loaded once at import time.
TTS_MODEL = FasterQwen3TTS.from_pretrained('Qwen/Qwen3-TTS-12Hz-1.7B-Base')
# Whisper model used to transcribe reference audio. It is created on the CPU here;
# _detect_reference_text_and_language() moves it to CUDA when a GPU is available.
WHISPER_MODEL = whisper.load_model('turbo', device='cpu', download_root=os.environ.get('WHISPER_CACHE_DIR'))
# In-memory transcription cache: audio hash -> (last access time, transcript, detected language).
REFERENCE_AUDIO_TRANSCRIPTION_CACHE: dict[str, tuple[float, str, str]] = {}
# Guards every read and write of the cache above.
REFERENCE_AUDIO_TRANSCRIPTION_CACHE_LOCK = threading.Lock()
# Maximum number of cached transcriptions (at least 1); overridable through the environment.
REFERENCE_AUDIO_TRANSCRIPTION_CACHE_LIMIT = max(1, int(os.environ.get('REFERENCE_AUDIO_TRANSCRIPTION_CACHE_LIMIT', 100)))
34
+
35
+
36
class Theme(Base):
    """Gradio theme: neutral hues, Barlow / IBM Plex Mono fonts and a cyan accent colour."""

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.neutral,
        secondary_hue: colors.Color | str = colors.neutral,
        neutral_hue: colors.Color | str = colors.neutral,
        spacing_size: sizes.Size | str = sizes.spacing_md,
        radius_size: sizes.Size | str = sizes.radius_md,
        text_size: sizes.Size | str = sizes.text_md,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (fonts.GoogleFont('Barlow'), 'ui-sans-serif', 'sans-serif'),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (fonts.GoogleFont('IBM Plex Mono'), 'ui-monospace', 'monospace'),
    ):
        """Forward the base theme options unchanged, then apply the colour overrides."""
        base_options = {
            'primary_hue': primary_hue,
            'secondary_hue': secondary_hue,
            'neutral_hue': neutral_hue,
            'spacing_size': spacing_size,
            'radius_size': radius_size,
            'text_size': text_size,
            'font': font,
            'font_mono': font_mono,
        }
        super().__init__(**base_options)

        # The same values are used for light and dark mode.
        accent = 'rgb(0 231 255 / 1)'
        accent_hover = 'rgb(0 231 255 / .75)'
        button_text = '#ffffff'
        loader = 'rgb(255 199 229 / 1)'

        overrides = {
            'color_accent': accent,
            'slider_color': accent,
            'slider_color_dark': accent,
            'button_primary_background_fill': accent,
            'button_primary_background_fill_hover': accent_hover,
            'button_primary_text_color': button_text,
            'button_primary_background_fill_dark': accent,
            'button_primary_background_fill_hover_dark': accent_hover,
            'button_primary_text_color_dark': button_text,
            'loader_color': loader,
            'loader_color_dark': loader,
        }
        super().set(**overrides)
72
+
73
+
74
def _normalize_audio(wav, eps=1e-12, clip=True):
    """Convert integer or floating-point audio to mono float32 samples.

    Signed integers are divided by the magnitude of the dtype's range, unsigned
    integers are re-centred around the midpoint first, and floating-point input
    is only rescaled when its peak exceeds 1.0. With ``clip`` enabled the result
    is limited to [-1, 1]. Arrays with more than one dimension are averaged over
    the last axis.

    Returns ``None`` when the dtype is neither integer nor floating point.
    """
    samples = np.asarray(wav)
    dtype = samples.dtype

    if np.issubdtype(dtype, np.integer):
        limits = np.iinfo(dtype)
        as_float = samples.astype(np.float32)

        if limits.min < 0:
            scaled = as_float / max(abs(limits.min), limits.max)
        else:
            # Unsigned PCM: shift the midpoint of the range to zero.
            half_range = (limits.max + 1) / 2.0
            scaled = (as_float - half_range) / half_range
    elif np.issubdtype(dtype, np.floating):
        scaled = samples.astype(np.float32)
        peak = np.max(np.abs(scaled)) if scaled.size else 0.0

        # Only touch float audio that is actually out of range.
        if peak > 1.0 + 1e-6:
            scaled = scaled / (peak + eps)
    else:
        return None

    if clip:
        scaled = np.clip(scaled, -1.0, 1.0)

    if scaled.ndim > 1:
        scaled = np.mean(scaled, axis=-1).astype(np.float32)

    return scaled
103
+
104
+
105
def _resample(x: np.ndarray, original_sample_rate: int, target_sample_rate: int, axis: int = 0) -> np.ndarray:
    """Resample ``x`` along ``axis`` from one sample rate to another with a polyphase filter."""
    # Reduce the rate ratio to lowest terms so the up/down factors stay small.
    common = np.gcd(original_sample_rate, target_sample_rate)
    up_factor = target_sample_rate // common
    down_factor = original_sample_rate // common

    return resample_poly(x, up=up_factor, down=down_factor, axis=axis)
109
+
110
+
111
def _reference_audio_hash(reference_audio: tuple[np.ndarray, int]) -> str:
    """Return a SHA-256 hex digest identifying a reference clip.

    Args:
        reference_audio: ``(samples, sample_rate)``; a bare one-element tuple is
            tolerated and hashed with an unknown sample rate.

    Returns:
        A 64-character hexadecimal digest.

    The digest covers the sample rate, dtype and shape in addition to the raw
    sample bytes. Hashing the bytes alone would make clips that share a byte
    pattern but differ in rate, sample format or channel layout collide.
    """
    audio = np.ascontiguousarray(np.asarray(reference_audio[0]))
    sample_rate = reference_audio[1] if len(reference_audio) > 1 else None

    digest = hashlib.sha256()
    # Metadata header first, so identical buffers with different meaning differ.
    digest.update(f'{sample_rate}|{audio.dtype.str}|{audio.shape}|'.encode('utf-8'))
    digest.update(audio.tobytes())

    return digest.hexdigest()
118
+
119
+
120
def _detect_reference_text_and_language(reference_audio: tuple[np.ndarray, int], sample_rate: int) -> tuple[str, str]:
    """Transcribe the reference clip with Whisper and detect its language.

    Args:
        reference_audio: ``(samples, sample_rate)``; only the samples are read here.
        sample_rate: Sample rate of ``reference_audio[0]`` in Hz.

    Returns:
        ``(reference_text, detected_language)`` where ``detected_language`` is the
        most probable key of the probabilities returned by ``detect_language``.
        For Japanese (``'ja'``) the transcript is replaced by the output of
        ``generate_text`` when that returns a value.
    """
    audio = np.asarray(reference_audio[0])

    # Two-dimensional input is averaged over axis 1 to obtain a single channel.
    if audio.ndim == 2:
        audio = audio.mean(axis=1)

    if sample_rate != 16000:
        audio = _resample(audio, sample_rate, 16000)

    # Cast unconditionally: previously audio already at 16 kHz skipped the
    # float32 conversion and reached Whisper in whatever dtype it arrived.
    audio = np.clip(audio, -1.0, 1.0).astype(np.float32)

    model = WHISPER_MODEL.to(device='cuda' if torch.cuda.is_available() else 'cpu')
    audio = whisper.pad_or_trim(audio)
    mel = whisper.log_mel_spectrogram(audio, n_mels=model.dims.n_mels).to(model.device)
    _, probs = model.detect_language(mel)
    detected_language = max(probs, key=probs.get)

    # Reuse the detected language so decode() does not detect it a second time,
    # and only request half precision when the model actually sits on a GPU.
    options = whisper.DecodingOptions(language=detected_language, fp16=model.device.type == 'cuda')
    result = whisper.decode(model, mel, options)
    # Join lines, dropping the whitespace around each line break.
    reference_text = re.sub(r'\s*\n\s*', '', result.text)

    if detected_language == 'ja':
        converted_reference_text = generate_text(reference_text)

        if converted_reference_text is not None:
            reference_text = converted_reference_text

    return reference_text, detected_language
145
+
146
+
147
def _get_reference_text_and_language(reference_audio: tuple[np.ndarray, int], sample_rate: int) -> tuple[str, str]:
    """Return the transcript and language of a reference clip, using the in-memory cache.

    A cache hit refreshes the entry's timestamp. A miss runs the transcription
    outside the lock, stores the result and then drops the entries with the
    oldest timestamps until the cache is back within its limit.
    """
    key = _reference_audio_hash(reference_audio)
    cache = REFERENCE_AUDIO_TRANSCRIPTION_CACHE

    with REFERENCE_AUDIO_TRANSCRIPTION_CACHE_LOCK:
        hit = cache.get(key)

        if hit is not None:
            _, reference_text, detected_language = hit
            # Touch the entry so it counts as recently used.
            cache[key] = (time.time(), reference_text, detected_language)

    if hit is not None:
        return reference_text, detected_language

    # Transcription is slow; it deliberately runs without holding the lock.
    reference_text, detected_language = _detect_reference_text_and_language(reference_audio, sample_rate)

    with REFERENCE_AUDIO_TRANSCRIPTION_CACHE_LOCK:
        cache[key] = (time.time(), reference_text, detected_language)
        overflow = len(cache) - REFERENCE_AUDIO_TRANSCRIPTION_CACHE_LIMIT

        if overflow > 0:
            oldest_first = sorted(cache, key=lambda cached_key: cache[cached_key][0])

            for stale_key in oldest_first[:overflow]:
                del cache[stale_key]

    return reference_text, detected_language
175
+
176
+
177
+ def generate_text(prompt: str) -> str | None:
178
+ system_prompt = '''あなたは日本語テキストを「読み(かな)」だけに変換する変換器です。
179
+
180
+ 出力に含めてよい文字は ひらがな・カタカナ・長音記号ー・空白 のみです。改行も禁止(1行で出力)。
181
+ 入力に含まれる 漢字は必ずかなにする。
182
+ 英数字・記号は、可能な範囲で日本語のカナ読みにする(例:AI→えーあい、LLM→えるえるえむ、2026→にせんにじゅうろく)。
183
+ 出力は 変換後の本文のみ。説明、注釈、引用符、箇条書き、コードブロックは一切禁止。
184
+ 最後に必ず自己検査を行う:出力が ^[ぁ-ゟ゠-ヿー ]+$ に一致しない場合、条件を満たすまで修正してから出力する。
185
+ それでも読めない文字がある場合は、意味を落としてよいので「最も近いかな」に置き換える(記号は省略よりも読みを優先。ただし許可文字以外は絶対に出さない)。'''
186
+ request = Request('https://api.openai.com/v1/responses', data=json.dumps({
187
+ 'model': os.environ.get('OPENAI_MODEL', 'gpt-5.4-mini'),
188
+ 'input': [{
189
+ 'role': 'developer',
190
+ 'content': system_prompt
191
+ },
192
+ {
193
+ 'role': 'user',
194
+ 'content': [
195
+ {
196
+ 'type': 'input_text',
197
+ 'text': prompt
198
+ }
199
+ ]
200
+ }],
201
+ 'temperature': 1,
202
+ 'reasoning': {'effort': 'none'},
203
+ }).encode('utf-8'), method='POST', headers={'Content-Type': 'application/json', 'Authorization': f'Bearer {os.environ["OPENAI_API_KEY"]}'})
204
+
205
+ with urlopen(request) as response:
206
+ result = json.loads(response.read().decode('utf-8'))
207
+
208
+ for output in result['output']:
209
+ if 'type' in output and output['type'] == 'message':
210
+ for content in output['content']:
211
+ if 'type' in content and content['type'] == 'output_text':
212
+ return content['text']
213
+
214
+ return None
215
+
216
+
217
@spaces.GPU(duration=30)
def generate_voice_clone(input_text: str, language: str | None, reference_audio: tuple[int, np.ndarray] | dict | None, reference_text: str | None, temperature: float, progress: gr.Progress=gr.Progress(track_tqdm=True)) -> tuple[tuple[int, np.ndarray], str | None, str | None]:
    """Synthesize ``input_text`` in the voice of the reference clip.

    Args:
        input_text: Text to speak.
        language: ``'Auto'``, ``None``, a code from ``language_codes`` (``'en'``,
            ``'ja'``) or any other value, which is passed to the model unchanged.
        reference_audio: Either ``(sample_rate, samples)`` or a dict with
            ``'sampling_rate'`` and ``'data'`` keys.
        reference_text: Transcript of the reference clip. When empty, the clip is
            transcribed and its language detected automatically.
        temperature: Sampling temperature forwarded to the TTS model.
        progress: Gradio progress tracker.

    Returns:
        ``((sample_rate, int16_samples), transcribed_text, detected_language)``.
        The last two are ``None`` when a reference text was supplied.

    Raises:
        gr.Error: If the reference audio is missing, has an unsupported form, or
            cannot be normalized.
    """
    language_codes = {'en': 'English', 'ja': 'Japanese'}
    transcribed_text = None
    detected_language = None
    sample_rate = None
    wav = None

    if isinstance(reference_audio, tuple) and len(reference_audio) == 2 and isinstance(reference_audio[0], (int, np.integer)):
        sample_rate, wav = int(reference_audio[0]), reference_audio[1]
    elif isinstance(reference_audio, dict) and 'sampling_rate' in reference_audio and 'data' in reference_audio:
        sample_rate, wav = int(reference_audio['sampling_rate']), reference_audio['data']

    # Previously a missing clip fell through and crashed later on an unbound
    # ``sample_rate``; report the problem to the user instead.
    if sample_rate is None:
        raise gr.Error('Reference audio is required.')

    normalized_audio = _normalize_audio(wav)

    if normalized_audio is None or normalized_audio.size == 0:
        raise gr.Error('Reference audio is empty or has an unsupported sample format.')

    reference_audio = (normalized_audio, sample_rate)

    if reference_text is None or len(reference_text) == 0:
        reference_text, detected_language = _get_reference_text_and_language(reference_audio, sample_rate)
        transcribed_text = reference_text

    if language is None or language == 'Auto':
        # Use the detected language when it is one we can name, otherwise let the
        # model decide. Without a transcription ``detected_language`` is None.
        language = language_codes.get(detected_language, 'Auto')
    else:
        language = language_codes.get(language, language)

    if sample_rate != 48000:
        reference_audio = (_resample(normalized_audio, sample_rate, 48000), 48000)

    wavs, sample_rate = TTS_MODEL.generate_voice_clone(text=input_text.strip(), language=language, ref_audio=reference_audio, ref_text=reference_text.strip(), temperature=temperature, append_silence=False)
    #wavs, sample_rate = TTS_MODEL.generate_voice_clone(text=input_text.strip(), language=language, ref_audio=reference_audio, ref_text=reference_text, max_new_tokens=2048, temperature=temperature)

    # Convert the first generated waveform to 16-bit PCM for the Gradio audio output.
    return (sample_rate, (np.clip(wavs[0], -1.0, 1.0) * 32768.0).round().astype(np.int16)), transcribed_text, detected_language
260
+
261
+
262
# UI definition: one row with two equally weighted columns. The first column holds
# the inputs, the second the generated audio and the transcription results.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=2):
            # Reference clip and its (optional) transcript are grouped together.
            with gr.Group():
                tts_reference_audio = gr.Audio(label='Reference Audio', type='numpy', buttons=['download'], waveform_options={'waveform_color': 'rgb(0 231 255 / 1)', 'waveform_progress_color': 'rgb(255 199 229 / 1)'})
                tts_reference_text = gr.Textbox(label='Reference Text', value='', lines=1)

            tts_input_text = gr.Textbox(label='Input', lines=4)
            # Dropdown values ('Auto', 'en', 'ja') are what generate_voice_clone receives as `language`.
            tts_language = gr.Dropdown(label='Language', choices=[('Automatic', 'Auto'), ('English', 'en'), ('Japanese', 'ja')], value='Auto', interactive=True)
            tts_temperature_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.9, step=0.1, label='Temperature')
            tts_generate_button = gr.Button('Generate', variant='primary')

        with gr.Column(scale=2):
            tts_audio_output = gr.Audio(label='Output', type='numpy', buttons=['download'], waveform_options={'waveform_color': 'rgb(0 231 255 / 1)', 'waveform_progress_color': 'rgb(255 199 229 / 1)'})
            tts_transcribed_text = gr.Label(label='Transcript', value='')
            tts_detected_language = gr.Label(label='Language', value='')

    # The input order must match the positional parameters of generate_voice_clone;
    # the three outputs match the three values it returns.
    tts_generate_button.click(fn=generate_voice_clone, inputs=[tts_input_text, tts_language, tts_reference_audio, tts_reference_text, tts_temperature_slider], outputs=[tts_audio_output, tts_transcribed_text, tts_detected_language], api_name='synthesize')
280
+
281
+
282
if __name__ == '__main__':
    # Bind address and port come from the environment; GRADIO_SERVER_PORT takes
    # precedence over PORT, and 7860 is the fallback.
    host = os.environ.get('GRADIO_SERVER_NAME', '0.0.0.0')
    port = int(os.environ.get('GRADIO_SERVER_PORT', os.environ.get('PORT', 7860)))

    demo.launch(
        server_name=host,
        server_port=port,
        theme=Theme(),
        css='.column>.row>.column:first-of-type .block { border-width: 0px !important; }',
    )
docker-compose.yml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ services:
2
+ ai:
3
+ container_name: "milchchanai"
4
+ build:
5
+ context: .
6
+ restart: unless-stopped
7
+ tty: true
8
+ env_file:
9
+ - .env
10
+ environment:
11
+ GRADIO_SERVER_NAME: 0.0.0.0
12
+ GRADIO_SERVER_PORT: 7860
13
+ HF_HOME: /data/huggingface
14
+ WHISPER_CACHE_DIR: /data/whisper
15
+ volumes:
16
+ - hf-cache:/data/huggingface
17
+ - whisper-cache:/data/whisper
18
+ ports:
19
+ - "7860:7860"
20
+ deploy:
21
+ resources:
22
+ reservations:
23
+ devices:
24
+ - capabilities: [gpu]
25
+ volumes:
26
+ hf-cache:
27
+ whisper-cache:
faster_qwen3_tts/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ """
2
+ faster-qwen3-tts: Real-time Qwen3-TTS inference using CUDA graphs
3
+ """
4
+ from .model import FasterQwen3TTS
5
+
6
+ __version__ = "0.2.5"
7
+ __all__ = ["FasterQwen3TTS"]
faster_qwen3_tts/cli.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """CLI for FasterQwen3TTS."""
3
+ import argparse
4
+ import os
5
+ import sys
6
+ import time
7
+ import numpy as np
8
+ import soundfile as sf
9
+ import torch
10
+
11
+ from faster_qwen3_tts import FasterQwen3TTS
12
+
13
+
14
def _load_model(model_id: str, device: str, dtype: str):
    """Load a FasterQwen3TTS model with SDPA attention and a 2048-token sequence limit.

    ``dtype`` is the CLI precision flag: ``"bf16"`` and ``"fp16"`` select the
    matching torch dtype, anything else falls back to float32.
    """
    precision_by_flag = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
    }
    torch_dtype = precision_by_flag.get(dtype, torch.float32)

    return FasterQwen3TTS.from_pretrained(
        model_id,
        device=device,
        dtype=torch_dtype,
        attn_implementation="sdpa",
        max_seq_len=2048,
    )
29
+
30
+
31
def _write_audio(out_path: str, audio: np.ndarray, sr: int):
    """Write ``audio`` to ``out_path`` at sample rate ``sr``, creating the parent directory."""
    parent_dir = os.path.dirname(out_path)
    if not parent_dir:
        # A bare file name has no directory component; use the working directory.
        parent_dir = "."
    os.makedirs(parent_dir, exist_ok=True)
    sf.write(out_path, audio, sr)
34
+
35
+
36
def _stream_to_audio(gen):
    """Drain a streaming generator into one waveform.

    ``gen`` yields ``(audio_chunk, sample_rate, extra)`` triples. Returns the
    concatenated chunks together with the sample rate of the last triple, or a
    single float32 zero sample at 24000 Hz when nothing was yielded.
    """
    pieces = []
    sample_rate = None

    for item in gen:
        piece, sample_rate, _ = item
        pieces.append(piece)

    if pieces:
        return np.concatenate(pieces), sample_rate

    return np.zeros(1, dtype=np.float32), 24000
44
+
45
+
46
def cmd_clone(args):
    """Handle the ``clone`` subcommand: synthesize with a reference voice and write a wav file.

    Prints the output path, the audio duration and the real-time factor
    (audio seconds divided by wall-clock seconds).
    """
    model = _load_model(args.model, args.device, args.dtype)

    # Arguments shared by the streaming and the one-shot entry points.
    request = dict(
        text=args.text,
        language=args.language,
        ref_audio=args.ref_audio,
        ref_text=args.ref_text,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        do_sample=not args.greedy,
        repetition_penalty=args.repetition_penalty,
        xvec_only=args.xvec_only,
        non_streaming_mode=args.non_streaming_mode,
    )

    started_at = time.perf_counter()
    if args.streaming:
        stream = model.generate_voice_clone_streaming(chunk_size=args.chunk_size, **request)
        audio, sr = _stream_to_audio(stream)
    else:
        clips, sr = model.generate_voice_clone(**request)
        audio = clips[0]
    total_time = time.perf_counter() - started_at

    audio_dur = len(audio) / sr if sr else 0.0
    rtf = audio_dur / total_time if total_time > 0 else 0.0

    _write_audio(args.output, audio, sr)
    print(f"Wrote {args.output} (dur {audio_dur:.2f}s, RTF {rtf:.2f})")
91
+
92
+
93
def cmd_custom(args):
    """Handle the ``custom`` subcommand: synthesize with a built-in speaker ID.

    With ``--list-speakers`` the supported speaker IDs are printed, one per
    line, and nothing is synthesized. Otherwise ``--speaker`` is mandatory; the
    process exits with status 2 when it is missing.
    """
    model = _load_model(args.model, args.device, args.dtype)

    if args.list_speakers:
        speaker_ids = model.model.get_supported_speakers() or []
        print("\n".join(speaker_ids))
        return

    if not args.speaker:
        print("ERROR: --speaker is required (use --list-speakers)")
        sys.exit(2)

    # Arguments shared by the streaming and the one-shot entry points.
    request = dict(
        text=args.text,
        speaker=args.speaker,
        language=args.language,
        instruct=args.instruct,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        do_sample=not args.greedy,
        repetition_penalty=args.repetition_penalty,
    )

    started_at = time.perf_counter()
    if args.streaming:
        stream = model.generate_custom_voice_streaming(chunk_size=args.chunk_size, **request)
        audio, sr = _stream_to_audio(stream)
    else:
        clips, sr = model.generate_custom_voice(**request)
        audio = clips[0]
    total_time = time.perf_counter() - started_at

    audio_dur = len(audio) / sr if sr else 0.0
    rtf = audio_dur / total_time if total_time > 0 else 0.0

    _write_audio(args.output, audio, sr)
    print(f"Wrote {args.output} (dur {audio_dur:.2f}s, RTF {rtf:.2f})")
143
+
144
+
145
def cmd_design(args):
    """Handle the ``design`` subcommand: synthesize a voice described by an instruction.

    Prints the output path, the audio duration and the real-time factor
    (audio seconds divided by wall-clock seconds).
    """
    model = _load_model(args.model, args.device, args.dtype)

    # Arguments shared by the streaming and the one-shot entry points.
    request = dict(
        text=args.text,
        instruct=args.instruct,
        language=args.language,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        do_sample=not args.greedy,
        repetition_penalty=args.repetition_penalty,
    )

    started_at = time.perf_counter()
    if args.streaming:
        stream = model.generate_voice_design_streaming(chunk_size=args.chunk_size, **request)
        audio, sr = _stream_to_audio(stream)
    else:
        clips, sr = model.generate_voice_design(**request)
        audio = clips[0]
    total_time = time.perf_counter() - started_at

    audio_dur = len(audio) / sr if sr else 0.0
    rtf = audio_dur / total_time if total_time > 0 else 0.0

    _write_audio(args.output, audio, sr)
    print(f"Wrote {args.output} (dur {audio_dur:.2f}s, RTF {rtf:.2f})")
184
+
185
+
186
def cmd_serve(args):
    """Handle the ``serve`` subcommand: synthesize one wav file per line read from stdin.

    ``args.mode`` selects clone, custom or design generation; any other value is
    treated as design. Blank lines are skipped and ``exit``, ``quit`` or ``stop``
    ends the loop. Files are written to ``args.output_dir`` as ``out_0001.wav``,
    ``out_0002.wav`` and so on.

    Exits with status 2 when an argument required by the selected mode is
    missing. This is checked before the model is loaded, so a usage error is
    reported immediately instead of after a full model load.
    """
    if args.mode == "clone" and (not args.ref_audio or not args.ref_text):
        print("ERROR: --ref-audio and --ref-text are required for clone mode")
        sys.exit(2)
    if args.mode == "custom" and not args.speaker:
        print("ERROR: --speaker is required for custom mode")
        sys.exit(2)
    if args.mode == "design" and not args.instruct:
        print("ERROR: --instruct is required for design mode")
        sys.exit(2)

    model = _load_model(args.model, args.device, args.dtype)

    # Everything except the text is the same for every line, so resolve the
    # generation method and its keyword arguments once, outside the loop.
    request = dict(
        language=args.language,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        do_sample=not args.greedy,
        repetition_penalty=args.repetition_penalty,
    )
    if args.mode == "clone":
        method_name = "generate_voice_clone"
        request.update(
            ref_audio=args.ref_audio,
            ref_text=args.ref_text,
            xvec_only=False,
            non_streaming_mode=args.non_streaming_mode,
        )
    elif args.mode == "custom":
        method_name = "generate_custom_voice"
        request.update(speaker=args.speaker, instruct=args.instruct)
    else:
        method_name = "generate_voice_design"
        request.update(instruct=args.instruct)

    if args.streaming:
        generate = getattr(model, method_name + "_streaming")
        request["chunk_size"] = args.chunk_size
    else:
        generate = getattr(model, method_name)

    print("Server started. Enter text per line. Type 'exit' or 'quit' to stop.")
    idx = 1
    for line in sys.stdin:
        text = line.strip()
        if not text:
            continue
        if text.lower() in ("exit", "quit", "stop"):
            break

        out_path = os.path.join(args.output_dir, f"out_{idx:04d}.wav")
        idx += 1

        start = time.perf_counter()

        if args.streaming:
            audio, sr = _stream_to_audio(generate(text=text, **request))
        else:
            audio_list, sr = generate(text=text, **request)
            audio = audio_list[0]

        _write_audio(out_path, audio, sr)
        # The reported time includes writing the file, as before.
        total_time = time.perf_counter() - start
        audio_dur = len(audio) / sr if sr else 0.0
        rtf = audio_dur / total_time if total_time > 0 else 0.0
        print(f"Wrote {out_path} (dur {audio_dur:.2f}s, RTF {rtf:.2f})")
306
+
307
+
308
def build_parser():
    """Build the ``faster-qwen3-tts`` command-line parser.

    The top-level parser takes ``--device`` and ``--dtype`` followed by a
    required sub-command (``clone``, ``custom``, ``design`` or ``serve``).
    Every sub-parser stores its handler under ``fn`` so the caller can
    dispatch with ``args.fn(args)``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """

    def add_sampling_options(target):
        # Decoding knobs shared by the one-shot commands and ``serve``.
        target.add_argument("--max-new-tokens", type=int, default=2048)
        target.add_argument("--temperature", type=float, default=0.9)
        target.add_argument("--top-k", type=int, default=50)
        target.add_argument("--repetition-penalty", type=float, default=1.05)
        target.add_argument("--greedy", action="store_true", help="Disable sampling")

    def add_streaming_options(target, nsm_default):
        # --streaming, the mutually exclusive (no-)non-streaming-mode pair,
        # and --chunk-size; ``nsm_default`` is the value used when neither
        # flag of the pair is given.
        target.add_argument("--streaming", action="store_true", help="Use streaming generation")
        toggle = target.add_mutually_exclusive_group()
        toggle.add_argument(
            "--non-streaming-mode",
            dest="non_streaming_mode",
            action="store_true",
            help="Prefill full text before decode",
        )
        toggle.add_argument(
            "--no-non-streaming-mode",
            dest="non_streaming_mode",
            action="store_false",
            help="Use upstream step-by-step text feeding during decode",
        )
        target.set_defaults(non_streaming_mode=nsm_default)
        target.add_argument("--chunk-size", type=int, default=8, help="Streaming chunk size")

    def add_common(target):
        # Options every one-shot generation command accepts.
        target.add_argument("--text", required=True, help="Text to synthesize")
        target.add_argument("--language", default="Auto", help="Language (Auto, English, French, ...)")
        target.add_argument("--output", required=True, help="Output wav path")
        target.add_argument("--model", required=True, help="Model id or local path")
        add_sampling_options(target)
        add_streaming_options(target, nsm_default=True)

    root = argparse.ArgumentParser(prog="faster-qwen3-tts", description="FasterQwen3TTS CLI")
    root.add_argument("--device", default="cuda", help="Device (cuda or cpu)")
    root.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"], help="Model dtype")
    commands = root.add_subparsers(dest="command", required=True)

    clone = commands.add_parser("clone", help="Voice cloning (reference audio)")
    add_common(clone)
    clone.add_argument("--ref-audio", required=True, help="Reference audio path")
    clone.add_argument("--ref-text", required=True, help="Reference transcript")
    clone.add_argument(
        "--xvec-only",
        action="store_true",
        help="Use speaker embedding only instead of upstream-default ICL mode",
    )
    # Cloning overrides the shared default for non_streaming_mode.
    clone.set_defaults(non_streaming_mode=False, fn=cmd_clone)

    custom = commands.add_parser("custom", help="CustomVoice model (speaker IDs)")
    add_common(custom)
    custom.add_argument("--speaker", help="Speaker ID")
    custom.add_argument("--instruct", default="", help="Optional instruction")
    custom.add_argument("--list-speakers", action="store_true", help="List available speaker IDs")
    custom.set_defaults(fn=cmd_custom)

    design = commands.add_parser("design", help="VoiceDesign model (instruction-based)")
    add_common(design)
    design.add_argument("--instruct", required=True, help="Voice/style instruction")
    design.set_defaults(fn=cmd_design)

    serve = commands.add_parser("serve", help="Keep model hot and generate multiple requests from stdin")
    serve.add_argument("--mode", required=True, choices=["clone", "custom", "design"])
    serve.add_argument("--model", required=True, help="Model id or local path")
    serve.add_argument("--language", default="Auto", help="Language (Auto, English, French, ...)")
    serve.add_argument("--ref-audio", help="Reference audio path (clone)")
    serve.add_argument("--ref-text", help="Reference transcript (clone)")
    serve.add_argument("--speaker", help="Speaker ID (custom)")
    serve.add_argument("--instruct", default="", help="Instruction (custom/design)")
    add_streaming_options(serve, nsm_default=False)
    add_sampling_options(serve)
    serve.add_argument("--output-dir", default="outputs", help="Directory for output wavs")
    serve.set_defaults(fn=cmd_serve)

    return root
398
+
399
+
400
def main():
    """Parse the command line and run the handler of the chosen sub-command."""
    cli_args = build_parser().parse_args()
    # Each sub-parser registered its handler under ``fn``.
    handler = cli_args.fn
    handler(cli_args)


if __name__ == "__main__":
    main()
faster_qwen3_tts/generate.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Non-streaming generation loop using CUDA graphs for both predictor and talker.
4
+ """
5
+ import time
6
+ from typing import Optional, Tuple
7
+
8
+ import torch
9
+
10
+ from .predictor_graph import PredictorGraph
11
+ from .sampling import apply_repetition_penalty, sample_logits
12
+ from .talker_graph import TalkerGraph
13
+
14
+
15
@torch.inference_mode()
def fast_generate(
    talker,
    talker_input_embeds: torch.Tensor,
    attention_mask: torch.Tensor,
    trailing_text_hiddens: torch.Tensor,
    tts_pad_embed: torch.Tensor,
    config,
    predictor_graph: PredictorGraph,
    talker_graph: TalkerGraph,
    max_new_tokens: int = 2048,
    min_new_tokens: int = 2,
    temperature: float = 0.9,
    top_k: int = 50,
    top_p: float = 1.0,
    do_sample: bool = True,
    repetition_penalty: float = 1.05,
    subtalker_dosample: Optional[bool] = None,
    subtalker_top_k: Optional[int] = None,
    subtalker_top_p: Optional[float] = None,
    subtalker_temperature: Optional[float] = None,
    parity_mode: bool = False,
) -> Tuple[Optional[torch.Tensor], dict]:
    """
    Fast autoregressive generation with CUDA-graphed predictor and talker.

    The prefill runs through ``talker.forward``; every later step runs the
    predictor graph (remaining codebooks) and the talker graph (next hidden
    state), then samples the next first-codebook token.

    Args:
        talker: Talker module exposing ``code_predictor``, ``codec_head``,
            ``get_input_embeddings()``, ``forward`` and ``generate``.
        talker_input_embeds: Prefill input embeddings.
        attention_mask: Attention mask matching ``talker_input_embeds``.
        trailing_text_hiddens: Text hiddens added to the codec embedding one
            position per decode step until exhausted (indexed on dim 1).
        tts_pad_embed: Embedding added once the trailing text is exhausted.
        config: Talker config providing ``codec_eos_token_id``,
            ``num_code_groups`` and ``vocab_size``.
        predictor_graph: Graph runner producing the remaining codebook ids.
        talker_graph: Graph runner for single-token talker decode steps.
        max_new_tokens: Upper bound on decode steps.
        min_new_tokens: EOS is suppressed while fewer frames were emitted.
        temperature, top_k, top_p, do_sample: Sampling settings for the
            first codebook.
        repetition_penalty: Penalty over previously emitted first-codebook
            tokens; ``1.0`` disables it.
        subtalker_*: Predictor sampling overrides; only forwarded in
            ``parity_mode`` (each falls back to the talker value when None).
        parity_mode: When True, bypass the graphs and call
            ``talker.generate`` directly.

    Returns:
        ``(codes, timing)``. ``codes`` stacks one row per generated frame
        (first-codebook token followed by the predictor's ids) or is None
        when nothing was generated. ``timing`` has the keys ``prefill_ms``,
        ``decode_s``, ``steps``, ``ms_per_step`` and ``steps_per_s``.
    """
    eos_id = config.codec_eos_token_id
    num_code_groups = config.num_code_groups
    vocab_size = config.vocab_size
    device = talker_input_embeds.device
    # The last 1024 vocabulary ids are never sampled, except EOS itself.
    suppress_start = max(0, vocab_size - 1024)

    if parity_mode:
        suppress_tokens = [i for i in range(suppress_start, vocab_size) if i != eos_id]
        t_start = time.perf_counter()
        talker_result = talker.generate(
            inputs_embeds=talker_input_embeds,
            attention_mask=attention_mask,
            trailing_text_hidden=trailing_text_hiddens,
            tts_pad_embed=tts_pad_embed,
            max_new_tokens=max_new_tokens,
            min_new_tokens=min_new_tokens,
            do_sample=do_sample,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            eos_token_id=eos_id,
            suppress_tokens=suppress_tokens,
            subtalker_dosample=subtalker_dosample if subtalker_dosample is not None else do_sample,
            subtalker_top_k=subtalker_top_k if subtalker_top_k is not None else top_k,
            subtalker_top_p=subtalker_top_p if subtalker_top_p is not None else top_p,
            subtalker_temperature=subtalker_temperature if subtalker_temperature is not None else temperature,
            output_hidden_states=True,
            return_dict_in_generate=True,
        )
        # The last entry of each per-step hidden_states item is stacked as
        # that step's codes (entries that are None are skipped).
        talker_codes = torch.stack(
            [hid[-1] for hid in talker_result.hidden_states if hid[-1] is not None],
            dim=1,
        )
        # Trim each sequence at the first EOS seen in the first codebook.
        first_codebook = talker_codes[:, :, 0]
        is_stop_token = first_codebook == eos_id
        stop_indices = torch.argmax(is_stop_token.int(), dim=1)
        has_stop_token = is_stop_token.any(dim=1)
        effective_lengths = torch.where(has_stop_token, stop_indices, talker_codes.shape[1])
        talker_codes_list = [talker_codes[i, :length, :] for i, length in enumerate(effective_lengths)]

        torch.cuda.synchronize()
        total_time = time.perf_counter() - t_start
        steps = int(talker_codes_list[0].shape[0]) if talker_codes_list else 0
        timing = {
            'prefill_ms': 0.0,
            'decode_s': total_time,
            'steps': steps,
            'ms_per_step': (total_time / steps * 1000) if steps > 0 else 0.0,
            'steps_per_s': (steps / total_time) if total_time > 0 else 0.0,
        }
        return talker_codes_list[0] if talker_codes_list else None, timing

    # Build the suppression mask with one slice write instead of one device
    # write per vocabulary id; only the graph path needs it.
    suppress_mask = torch.zeros(vocab_size, dtype=torch.bool, device=device)
    suppress_mask[suppress_start:] = True
    if eos_id is not None and suppress_start <= eos_id < vocab_size:
        suppress_mask[eos_id] = False

    predictor = talker.code_predictor
    talker_codec_embed = talker.get_input_embeddings()
    talker_codec_head = talker.codec_head
    predictor_codec_embeds = predictor.get_input_embeddings()

    # === PREFILL (still uses HF forward for variable-length prefill) ===
    t_start = time.perf_counter()

    out = talker.forward(
        inputs_embeds=talker_input_embeds,
        attention_mask=attention_mask,
        use_cache=True,
        output_hidden_states=True,
        return_dict=True,
        trailing_text_hidden=trailing_text_hiddens,
        tts_pad_embed=tts_pad_embed,
        generation_step=None,
        past_hidden=None,
        past_key_values=None,
    )

    talker_past_kv = out.past_key_values
    past_hidden = out.past_hidden
    gen_step = out.generation_step

    logits = out.logits[:, -1, :]
    suppress_eos = min_new_tokens > 0
    token = sample_logits(
        logits,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=do_sample,
        suppress_mask=suppress_mask,
        suppress_tokens=[eos_id] if suppress_eos else None,
    )

    # Copy prefill KV cache into talker graph's static cache
    prefill_len = talker_graph.prefill_kv(talker_past_kv)
    # Sync padding mask + rope deltas for decode parity
    rope_deltas = getattr(talker, "rope_deltas", None)
    talker_graph.set_generation_state(attention_mask, rope_deltas)

    torch.cuda.synchronize()
    t_prefill = time.perf_counter() - t_start

    # === DECODE LOOP ===
    t_decode_start = time.perf_counter()
    all_codec_ids = []
    # Running buffer of emitted first-codebook tokens for the repetition
    # penalty, so the history is not rebuilt from every frame on each step.
    use_penalty = repetition_penalty != 1.0
    first_cb_history = None

    for step_idx in range(max_new_tokens):
        if token.item() == eos_id:
            break

        # --- CUDA-Graphed Code Predictor ---
        last_id_hidden = talker_codec_embed(token.unsqueeze(1))  # [1, 1, H]
        pred_input = torch.cat((past_hidden, last_id_hidden), dim=1)  # [1, 2, H]
        codebook_token_ids = predictor_graph.run(pred_input)  # [15] long tensor

        # Build full codec: [first_cb, cb1, ..., cb15]
        all_cb = torch.cat([token.view(1), codebook_token_ids])  # [16]
        all_codec_ids.append(all_cb.detach())
        if use_penalty:
            if first_cb_history is None:
                first_cb_history = torch.empty(
                    max_new_tokens, dtype=all_cb.dtype, device=all_cb.device
                )
            # Every non-breaking iteration appends exactly one frame, so
            # step_idx is also the index of the frame just stored.
            first_cb_history[step_idx] = all_cb[0]

        # Stop once the static cache is full; the frame above is still kept.
        current_pos = prefill_len + step_idx
        if current_pos >= talker_graph.max_seq_len - 1:
            break

        # --- Build input embedding for talker ---
        codec_hiddens = [last_id_hidden]
        for i in range(num_code_groups - 1):
            codec_hiddens.append(predictor_codec_embeds[i](codebook_token_ids[i].unsqueeze(0).unsqueeze(0)))
        inputs_embeds = torch.cat(codec_hiddens, dim=1).sum(1, keepdim=True)

        if gen_step < trailing_text_hiddens.shape[1]:
            inputs_embeds = inputs_embeds + trailing_text_hiddens[:, gen_step].unsqueeze(1)
        else:
            inputs_embeds = inputs_embeds + tts_pad_embed

        # --- CUDA-Graphed Talker decode step ---
        hidden_states = talker_graph.run(inputs_embeds, position=current_pos)
        # hidden_states is the static output buffer - use it immediately

        logits = talker_codec_head(hidden_states[:, -1, :]).unsqueeze(0)

        if use_penalty:
            history = first_cb_history[: step_idx + 1]
            logits = apply_repetition_penalty(logits, history, repetition_penalty)

        suppress_eos = len(all_codec_ids) < min_new_tokens
        token = sample_logits(
            logits.squeeze(0),
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=do_sample,
            suppress_mask=suppress_mask,
            suppress_tokens=[eos_id] if suppress_eos else None,
        )
        past_hidden = hidden_states[:, -1:, :].clone()  # clone since it's the static buffer
        gen_step += 1

    torch.cuda.synchronize()
    t_decode = time.perf_counter() - t_decode_start

    n_steps = len(all_codec_ids)
    timing = {
        'prefill_ms': t_prefill * 1000,
        'decode_s': t_decode,
        'steps': n_steps,
        'ms_per_step': (t_decode / n_steps * 1000) if n_steps > 0 else 0,
        'steps_per_s': (n_steps / t_decode) if t_decode > 0 else 0,
    }

    if all_codec_ids:
        return torch.stack(all_codec_ids), timing
    return None, timing
faster_qwen3_tts/model.py ADDED
@@ -0,0 +1,1370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FasterQwen3TTS: Real-time TTS using CUDA graph capture.
3
+
4
+ Wrapper class that provides a Qwen3-TTS API while using
5
+ CUDA graphs for 6-10x speedup.
6
+ """
7
+ import logging
8
+ from pathlib import Path
9
+ from typing import Any, Dict, Generator, List, Optional, Tuple, Union
10
+
11
+ import numpy as np
12
+ import soundfile as sf
13
+ import torch
14
+
15
+ from .utils import suppress_flash_attn_warning
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+
21
+
22
+ class FasterQwen3TTS:
23
+ """
24
+ Qwen3-TTS model with CUDA graphs for real-time inference.
25
+
26
+ Compatible API with Qwen3TTSModel, but uses CUDA graph
27
+ capture for 6-10x speedup on NVIDIA GPUs.
28
+ """
29
+
30
+ def __init__(
31
+ self,
32
+ base_model,
33
+ predictor_graph,
34
+ talker_graph,
35
+ device: str = "cuda",
36
+ dtype: torch.dtype = torch.bfloat16,
37
+ max_seq_len: int = 2048,
38
+ ):
39
+ self.model = base_model # The qwen-tts Qwen3TTSModel instance
40
+ self.predictor_graph = predictor_graph
41
+ self.talker_graph = talker_graph
42
+ self.device = device
43
+ self.dtype = dtype
44
+ self.max_seq_len = max_seq_len
45
+ self.sample_rate = self._infer_sample_rate(base_model)
46
+ self._warmed_up = False
47
+ self._voice_prompt_cache = {} # Cache (ref_audio, ref_text) -> (vcp, ref_ids)
48
+
49
+ @staticmethod
50
+ def _get_speech_tokenizer(base_model):
51
+ """Return the nested qwen-tts speech tokenizer when available."""
52
+ return getattr(getattr(base_model, "model", None), "speech_tokenizer", None)
53
+
54
+ @property
55
+ def speech_tokenizer(self):
56
+ """Expose the codec decoder on the wrapper's public surface."""
57
+ speech_tokenizer = self._get_speech_tokenizer(self.model)
58
+ if speech_tokenizer is None:
59
+ raise AttributeError("Underlying model does not expose a speech_tokenizer")
60
+ return speech_tokenizer
61
+
62
+ @staticmethod
63
+ def _infer_sample_rate(base_model) -> int:
64
+ """Infer output audio sample rate from qwen-tts internals."""
65
+ # Qwen3-TTS model IDs include "12Hz", but that is codec frame-rate (tokens/s),
66
+ # not waveform sampling rate. Generated audio is 24kHz.
67
+ sample_rate = None
68
+
69
+ speech_tokenizer = FasterQwen3TTS._get_speech_tokenizer(base_model)
70
+ if speech_tokenizer is not None:
71
+ sample_rate = getattr(speech_tokenizer, "sample_rate", None)
72
+
73
+ if sample_rate is None:
74
+ sample_rate = getattr(base_model, "sample_rate", None)
75
+
76
+ if sample_rate is None:
77
+ logger.warning(
78
+ "Could not infer sample rate from base model; defaulting to 24000 Hz."
79
+ )
80
+ return 24000
81
+
82
+ return int(sample_rate)
83
+
84
+ @classmethod
85
+ def from_pretrained(
86
+ cls,
87
+ model_name: str,
88
+ device: str = "cuda",
89
+ dtype: Union[str, torch.dtype] = torch.bfloat16,
90
+ attn_implementation: str = "sdpa",
91
+ max_seq_len: int = 2048,
92
+ ):
93
+ """
94
+ Load Qwen3-TTS model and prepare CUDA graphs.
95
+
96
+ Args:
97
+ model_name: Model path or HuggingFace Hub ID
98
+ device: Device to use ("cuda" or "cpu")
99
+ dtype: Data type for inference
100
+ attn_implementation: Attention implementation ("sdpa" or "flash_attention_2")
101
+ max_seq_len: Maximum sequence length for static cache
102
+
103
+ Returns:
104
+ FasterQwen3TTS instance
105
+ """
106
+ if isinstance(dtype, str):
107
+ dtype = getattr(torch, dtype)
108
+
109
+ if not device.startswith("cuda") or not torch.cuda.is_available():
110
+ raise ValueError("CUDA graphs require CUDA device")
111
+
112
+ logger.info(f"Loading Qwen3-TTS model: {model_name}")
113
+
114
+ # Import here to avoid dependency issues (and suppress flash-attn warning)
115
+ with suppress_flash_attn_warning():
116
+ from qwen_tts import Qwen3TTSModel
117
+ from .predictor_graph import PredictorGraph
118
+ from .talker_graph import TalkerGraph
119
+ # Load base model using qwen-tts library
120
+ base_model = Qwen3TTSModel.from_pretrained(
121
+ model_name,
122
+ device_map=device,
123
+ torch_dtype=dtype,
124
+ attn_implementation=attn_implementation,
125
+ )
126
+
127
+ talker = base_model.model.talker
128
+ talker_config = base_model.model.config.talker_config
129
+
130
+ # Extract predictor config from loaded model
131
+ predictor = talker.code_predictor
132
+ pred_config = predictor.model.config
133
+ talker_hidden = talker_config.hidden_size
134
+
135
+ # Build CUDA graphs
136
+ logger.info("Building CUDA graphs...")
137
+ predictor_graph = PredictorGraph(
138
+ predictor,
139
+ pred_config,
140
+ talker_hidden,
141
+ device=device,
142
+ dtype=dtype,
143
+ do_sample=True, # subtalker_dosample (Default: True)
144
+ top_k=50, # subtalker_top_k (Default: 50)
145
+ top_p=1.0, # subtalker_top_p (Default: 1.0)
146
+ temperature=0.2, # subtalker_temperature (Default: 0.9)
147
+ )
148
+
149
+ talker_graph = TalkerGraph(
150
+ talker.model,
151
+ talker_config,
152
+ device=device,
153
+ dtype=dtype,
154
+ max_seq_len=max_seq_len,
155
+ )
156
+
157
+ logger.info("CUDA graphs initialized (will capture on first run)")
158
+
159
+ return cls(
160
+ base_model=base_model,
161
+ predictor_graph=predictor_graph,
162
+ talker_graph=talker_graph,
163
+ device=device,
164
+ dtype=dtype,
165
+ max_seq_len=max_seq_len,
166
+ )
167
+
168
+ def _warmup(self, prefill_len: int):
169
+ """Warm up and capture CUDA graphs with given prefill length."""
170
+ if self._warmed_up:
171
+ return
172
+
173
+ logger.info("Warming up CUDA graphs...")
174
+ self.predictor_graph.capture(num_warmup=3)
175
+ self.talker_graph.capture(prefill_len=prefill_len, num_warmup=3)
176
+ self._warmed_up = True
177
+ logger.info("CUDA graphs captured and ready")
178
+
179
+ def generate(
180
+ self,
181
+ text: str,
182
+ language: str = "English",
183
+ max_new_tokens: int = 2048,
184
+ temperature: float = 0.9,
185
+ top_k: int = 50,
186
+ do_sample: bool = True,
187
+ repetition_penalty: float = 1.05,
188
+ ) -> Tuple[list, int]:
189
+ """
190
+ Generate speech from text using default voice.
191
+
192
+ Not yet implemented - use generate_voice_clone() instead.
193
+ """
194
+ raise NotImplementedError(
195
+ "Default voice generation not yet implemented. "
196
+ "Use generate_voice_clone() with reference audio."
197
+ )
198
+
199
+ def _load_ref_audio_with_silence(self, ref_audio: Union[str, Path, tuple], silence_secs: float = 0.5) -> Tuple[np.ndarray, int]:
200
+ """Load reference audio and optionally append trailing silence.
201
+
202
+ The ICL voice-cloning prompt ends with the last codec token of the reference
203
+ audio, so the model's first generated token is conditioned on whatever phoneme
204
+ the reference ends with. Appending a short silence makes the last tokens
205
+ encode silence instead, preventing that phoneme from bleeding into the start
206
+ of the generated speech. Set silence_secs=0 to disable this behavior.
207
+ """
208
+ if isinstance(ref_audio, tuple):
209
+ audio, sr = ref_audio
210
+ else:
211
+ audio, sr = sf.read(str(ref_audio), dtype="float32", always_2d=False)
212
+
213
+ if audio.ndim > 1:
214
+ audio = audio.mean(axis=1) # convert to mono
215
+ if silence_secs > 0:
216
+ silence = np.zeros(int(silence_secs * sr), dtype=np.float32)
217
+ audio = np.concatenate([audio, silence])
218
+ return audio, sr
219
+
220
+ def _resolve_voice_clone_prompt(
221
+ self,
222
+ input_ids,
223
+ ref_audio: Optional[Union[str, Path, tuple]],
224
+ ref_text: str,
225
+ xvec_only: bool,
226
+ append_silence: bool,
227
+ voice_clone_prompt: Optional[Union[Dict[str, Any], List[Any]]],
228
+ ) -> Tuple[Dict[str, Any], list, bool]:
229
+ """Resolve voice clone prompt data and return (prompt, ref_ids, using_icl_mode)."""
230
+ if voice_clone_prompt is not None:
231
+ return self._resolve_precomputed_voice_clone_prompt(
232
+ input_ids=input_ids,
233
+ ref_text=ref_text,
234
+ voice_clone_prompt=voice_clone_prompt,
235
+ )
236
+ if ref_audio is None:
237
+ raise ValueError("ref_audio is required when voice_clone_prompt is not provided")
238
+
239
+ return self._resolve_voice_clone_prompt_from_reference(
240
+ input_ids=input_ids,
241
+ ref_audio=ref_audio,
242
+ ref_text=ref_text,
243
+ xvec_only=xvec_only,
244
+ append_silence=append_silence,
245
+ )
246
+
247
+ def _resolve_precomputed_voice_clone_prompt(
248
+ self,
249
+ input_ids,
250
+ ref_text: str,
251
+ voice_clone_prompt: Union[Dict[str, Any], List[Any]],
252
+ ) -> Tuple[Dict[str, Any], list, bool]:
253
+ if isinstance(voice_clone_prompt, list):
254
+ if len(voice_clone_prompt) != len(input_ids):
255
+ raise ValueError(
256
+ f"voice_clone_prompt must have length {len(input_ids)}, got {len(voice_clone_prompt)}"
257
+ )
258
+
259
+ vcp = self.model._prompt_items_to_voice_clone_prompt(voice_clone_prompt)
260
+ ref_ids = []
261
+ for item in voice_clone_prompt:
262
+ if bool(item.icl_mode):
263
+ item_ref_text = item.ref_text if item.ref_text else ref_text
264
+ if not item_ref_text:
265
+ raise ValueError(
266
+ "ref_text is required when voice_clone_prompt uses ICL mode."
267
+ )
268
+ ref_id = self.model._tokenize_texts(
269
+ [self.model._build_ref_text(item_ref_text)]
270
+ )[0]
271
+ ref_ids.append(ref_id)
272
+ else:
273
+ ref_ids.append(None)
274
+
275
+ return vcp, ref_ids, any(vcp["icl_mode"])
276
+
277
+ required_keys = ("ref_spk_embedding",)
278
+ missing = [k for k in required_keys if k not in voice_clone_prompt]
279
+ if missing:
280
+ raise ValueError(
281
+ f"voice_clone_prompt missing required keys: {missing}. "
282
+ f"Expected keys: {list(required_keys)}"
283
+ )
284
+
285
+ list_keys = ("ref_spk_embedding", "x_vector_only_mode", "icl_mode", "ref_code")
286
+ for key in list_keys:
287
+ if key not in voice_clone_prompt:
288
+ continue
289
+ value = voice_clone_prompt[key]
290
+ if not isinstance(value, list) or len(value) != len(input_ids):
291
+ raise ValueError(
292
+ f"voice_clone_prompt[{key!r}] must be a list with length {len(input_ids)}"
293
+ )
294
+
295
+ xvec_modes = voice_clone_prompt.get("x_vector_only_mode", [True] * len(input_ids))
296
+ if "icl_mode" in voice_clone_prompt:
297
+ icl_modes = [bool(v) for v in voice_clone_prompt["icl_mode"]]
298
+ for i, (xvec_mode, icl_mode) in enumerate(zip(xvec_modes, icl_modes)):
299
+ if bool(xvec_mode) == bool(icl_mode):
300
+ raise ValueError(
301
+ f"voice_clone_prompt has inconsistent mode flags at index {i}: "
302
+ "x_vector_only_mode and icl_mode must be opposites"
303
+ )
304
+ else:
305
+ icl_modes = [not bool(v) for v in xvec_modes]
306
+
307
+ ref_codes = voice_clone_prompt.get("ref_code", [None] * len(input_ids))
308
+ for i, (xvec_mode, icl_mode, ref_code) in enumerate(zip(xvec_modes, icl_modes, ref_codes)):
309
+ if bool(xvec_mode) and ref_code is not None:
310
+ raise ValueError(
311
+ f"voice_clone_prompt index {i}: ref_code must be None in x_vector_only mode"
312
+ )
313
+ if bool(icl_mode) and ref_code is None:
314
+ raise ValueError(
315
+ f"voice_clone_prompt index {i}: ref_code is required in ICL mode"
316
+ )
317
+
318
+ vcp = dict(
319
+ ref_code=ref_codes,
320
+ ref_spk_embedding=voice_clone_prompt["ref_spk_embedding"],
321
+ x_vector_only_mode=[bool(v) for v in xvec_modes],
322
+ icl_mode=[bool(v) for v in icl_modes],
323
+ )
324
+ using_icl_mode = any(vcp["icl_mode"])
325
+
326
+ if using_icl_mode:
327
+ if not ref_text:
328
+ raise ValueError(
329
+ "ref_text is required when voice_clone_prompt uses ICL mode."
330
+ )
331
+ ref_texts = [self.model._build_ref_text(ref_text)]
332
+ # NOTE: single ref_text is shared across all ICL items in the batch.
333
+ ref_id = self.model._tokenize_texts(ref_texts)[0]
334
+ ref_ids = [ref_id if is_icl else None for is_icl in vcp["icl_mode"]]
335
+ else:
336
+ ref_ids = [None] * len(input_ids)
337
+
338
+ return vcp, ref_ids, using_icl_mode
339
+
340
+ def _resolve_voice_clone_prompt_from_reference(
341
+ self,
342
+ input_ids,
343
+ ref_audio: Union[str, Path, tuple],
344
+ ref_text: str,
345
+ xvec_only: bool,
346
+ append_silence: bool,
347
+ ) -> Tuple[Dict[str, Any], list, bool]:
348
+ using_icl_mode = not xvec_only
349
+ cache_key = (str(ref_audio), ref_text, xvec_only, append_silence)
350
+ if cache_key in self._voice_prompt_cache:
351
+ vcp, ref_ids = self._voice_prompt_cache[cache_key]
352
+ return vcp, ref_ids, using_icl_mode
353
+
354
+ if xvec_only:
355
+ prompt_items = self.model.create_voice_clone_prompt(
356
+ ref_audio=str(ref_audio),
357
+ ref_text="",
358
+ x_vector_only_mode=True,
359
+ )
360
+ spk_emb = prompt_items[0].ref_spk_embedding
361
+ vcp = dict(
362
+ ref_code=[None],
363
+ ref_spk_embedding=[spk_emb],
364
+ x_vector_only_mode=[True],
365
+ icl_mode=[False],
366
+ )
367
+ ref_ids = [None] * len(input_ids)
368
+ self._voice_prompt_cache[cache_key] = (vcp, ref_ids)
369
+ return vcp, ref_ids, using_icl_mode
370
+
371
+ silence_secs = 0.5 if append_silence else 0.0
372
+ ref_audio_input = self._load_ref_audio_with_silence(ref_audio, silence_secs=silence_secs)
373
+ prompt_items = self.model.create_voice_clone_prompt(
374
+ ref_audio=ref_audio_input,
375
+ ref_text=ref_text
376
+ )
377
+ vcp = self.model._prompt_items_to_voice_clone_prompt(prompt_items)
378
+
379
+ ref_ids = []
380
+ rt = prompt_items[0].ref_text
381
+ if rt:
382
+ ref_texts = [self.model._build_ref_text(rt)]
383
+ ref_ids.append(self.model._tokenize_texts(ref_texts)[0])
384
+ else:
385
+ ref_ids.append(None)
386
+
387
+ self._voice_prompt_cache[cache_key] = (vcp, ref_ids)
388
+ return vcp, ref_ids, using_icl_mode
389
+
390
def _prepare_generation(
    self,
    text: str,
    ref_audio: Optional[Union[str, Path, tuple]] = None,
    ref_text: str = "",
    language: str = "English",
    xvec_only: bool = False,
    non_streaming_mode: bool = False,
    append_silence: bool = True,
    voice_clone_prompt: Optional[Union[Dict[str, Any], List[Any]]] = None,
    instruct: Optional[str] = None,
):
    """Prepare inputs for generation (shared by streaming and non-streaming).

    Args:
        text: Text to synthesize.
        ref_audio: Reference audio (path or ``(samples, sample_rate)``);
            required unless ``voice_clone_prompt`` is given.
        ref_text: Transcript of the reference audio.
        language: Language name; None is treated as "Auto".
        xvec_only: When True, use only the speaker embedding (x-vector) for
            voice cloning instead of the full ICL acoustic prompt. Default
            False to match upstream ICL behavior, where the reference audio
            codec tokens are included in context.
        non_streaming_mode: Forwarded to the talker input builder.
        append_silence: Forwarded to the reference-based prompt resolution.
        voice_clone_prompt: Optional precomputed prompt. When provided,
            ``xvec_only`` is ignored; ``ref_text`` is ignored for
            x-vector-only prompts and required for ICL prompts.
        instruct: Optional instruction string guiding style/language; its
            embeddings are placed ahead of the rest of the talker input.

    Returns:
        ``(m, talker, config, tie, tam, tth, tpe, ref_codes)``: the inner
        model, its talker, the talker config, the four values returned by
        the talker input builder, and the first item's reference codes in
        ICL mode (None otherwise).
    """
    wrapper = self.model
    input_ids = wrapper._tokenize_texts([wrapper._build_assistant_text(text)])

    if instruct:
        instruct_ids = [wrapper._tokenize_texts([wrapper._build_instruct_text(instruct)])[0]]
    else:
        instruct_ids = [None]

    prompt, ref_ids, icl_active = self._resolve_voice_clone_prompt(
        input_ids=input_ids,
        ref_audio=ref_audio,
        ref_text=ref_text,
        xvec_only=xvec_only,
        append_silence=append_silence,
        voice_clone_prompt=voice_clone_prompt,
    )

    if instruct and not icl_active:
        logger.warning(
            "Base-model instruct with x-vector-only voice cloning is experimental. "
            "Upstream Qwen3-TTS itself does not follow instructions reliably in this "
            "mode. Prefer xvec_only=False (ICL mode) when using instruct for voice "
            "cloning."
        )

    core = wrapper.model
    languages = ["Auto"] if language is None else [language]

    embeds, mask, trailing_hiddens, pad_embed = self._build_talker_inputs_local(
        m=core,
        input_ids=input_ids,
        ref_ids=ref_ids,
        voice_clone_prompt=prompt,
        languages=languages,
        speakers=None,
        non_streaming_mode=non_streaming_mode,
        instruct_ids=instruct_ids,
    )

    # Graph capture needs a prefill length, so it happens on first use.
    if not self._warmed_up:
        self._warmup(embeds.shape[1])

    talker = core.talker
    config = core.config.talker_config
    talker.rope_deltas = None

    # For ICL mode: return ref_codes so the decoder can use them as acoustic context
    codes = prompt.get("ref_code") if icl_active else None
    ref_codes = codes[0] if codes and codes[0] is not None else None

    return core, talker, config, embeds, mask, trailing_hiddens, pad_embed, ref_codes
469
+
470
def _prepare_generation_custom(
    self,
    text: str,
    language: str,
    speaker: Optional[str],
    instruct: Optional[str] = None,
    non_streaming_mode: bool = True,
):
    """Prepare talker inputs for speaker-ID based generation (no voice clone prompt).

    Args:
        text: Text to synthesize.
        language: Language name; None is treated as "Auto".
        speaker: Speaker ID forwarded to the talker input builder.
        instruct: Optional instruction; None or "" means no instruction.
        non_streaming_mode: Forwarded to the talker input builder.

    Returns:
        ``(m, talker, config, tie, tam, tth, tpe)``: the inner model, its
        talker, the talker config and the four values returned by the
        talker input builder.
    """
    wrapper = self.model
    input_ids = wrapper._tokenize_texts([wrapper._build_assistant_text(text)])

    no_instruct = instruct is None or instruct == ""
    if no_instruct:
        instruct_ids = [None]
    else:
        instruct_ids = [wrapper._tokenize_texts([wrapper._build_instruct_text(instruct)])[0]]

    core = wrapper.model
    languages = ["Auto"] if language is None else [language]
    embeds, mask, trailing_hiddens, pad_embed = self._build_talker_inputs_local(
        m=core,
        input_ids=input_ids,
        ref_ids=[None],
        voice_clone_prompt=None,
        languages=languages,
        speakers=[speaker],
        non_streaming_mode=non_streaming_mode,
        instruct_ids=instruct_ids,
    )

    # Graph capture needs a prefill length, so it happens on first use.
    if not self._warmed_up:
        self._warmup(embeds.shape[1])

    talker = core.talker
    config = core.config.talker_config
    talker.rope_deltas = None

    return core, talker, config, embeds, mask, trailing_hiddens, pad_embed
507
+
508
def _build_talker_inputs_local(
    self,
    m,
    input_ids,
    ref_ids,
    voice_clone_prompt,
    languages,
    speakers,
    non_streaming_mode: bool,
    instruct_ids=None,
):
    """Local copy of upstream talker input building for qwen-tts main repo.

    Assembles, per batch item, the embedding sequence fed to the talker
    (optional instruct embeddings, role prefix, codec "think"/language prefix,
    optional speaker embedding, then either an ICL prompt or the first text
    token), and batches the per-item results.

    Args:
        m: Inner model object; must expose ``talker``, ``config``,
            ``generate_speaker_prompt`` and ``generate_icl_prompt``.
        input_ids: Per-item token id tensors; indexed as ``input_id[:, a:b]``,
            so each is at least 2-D.
        ref_ids: Per-item reference-text token ids; only read for items in ICL mode.
        voice_clone_prompt: Mapping with ``"x_vector_only_mode"``, ``"icl_mode"``
            and optionally ``"ref_code"`` entries, or ``None``.
        languages: Per-item language names (``"auto"`` is accepted, case-insensitive).
        speakers: Per-item speaker names, or ``None`` for "no named speaker".
        non_streaming_mode: When True (and the item is not in ICL mode) the whole
            target text is placed in the prefill; otherwise the remaining text is
            returned as trailing hiddens.
        instruct_ids: Optional per-item instruct token ids (entries may be ``None``).

    Returns:
        ``(talker_input_embeds, talker_attention_mask, trailing_text_hiddens,
        tts_pad_embed)``. The embeddings are left-padded with zeros, the mask is
        1 on real positions, and the trailing hiddens are right-padded with the
        pad embedding.

    Raises:
        NotImplementedError: Unknown speaker or language name.
        AssertionError: A language entry is ``None``.
    """
    # One list of embedding pieces per batch item; concatenated after the main loop.
    talker_input_embeds = [[] for _ in range(len(input_ids))]

    voice_clone_spk_embeds = None
    if voice_clone_prompt is not None:
        voice_clone_spk_embeds = m.generate_speaker_prompt(voice_clone_prompt)

    # Instruct embeddings (if any) go first in each item's sequence.
    if instruct_ids is not None:
        for index, instruct_id in enumerate(instruct_ids):
            if instruct_id is not None:
                talker_input_embeds[index].append(
                    m.talker.text_projection(m.talker.get_text_embeddings()(instruct_id))
                )

    if speakers is None:
        speakers = [None] * len(input_ids)

    trailing_text_hiddens = []
    tts_pad_embed = None

    for index, (input_id, language, speaker) in enumerate(zip(input_ids, languages, speakers)):
        # --- Speaker embedding: named speaker lookup, or the voice-clone embedding.
        if voice_clone_spk_embeds is None:
            if speaker == "" or speaker is None:
                speaker_embed = None
            else:
                if speaker.lower() not in m.config.talker_config.spk_id:
                    raise NotImplementedError(f"Speaker {speaker} not implemented")
                spk_id = m.config.talker_config.spk_id[speaker.lower()]
                speaker_embed = m.talker.get_input_embeddings()(
                    torch.tensor(spk_id, device=m.talker.device, dtype=input_id.dtype)
                )
        else:
            if voice_clone_prompt["x_vector_only_mode"][index] or voice_clone_prompt["icl_mode"][index]:
                speaker_embed = voice_clone_spk_embeds[index]
            else:
                speaker_embed = None

        # --- Language id: None means "auto" (no explicit language token).
        assert language is not None
        if language.lower() == "auto":
            language_id = None
        else:
            if language.lower() not in m.config.talker_config.codec_language_id:
                raise NotImplementedError(f"Language {language} not implemented")
            language_id = m.config.talker_config.codec_language_id[language.lower()]

        # A speaker with a truthy spk_is_dialect entry overrides the language id
        # when the requested language is Chinese or auto.
        if (
            language.lower() in ["chinese", "auto"]
            and speaker not in ("", None)
            and m.config.talker_config.spk_is_dialect[speaker.lower()]
        ):
            dialect = m.config.talker_config.spk_is_dialect[speaker.lower()]
            language_id = m.config.talker_config.codec_language_id[dialect]

        # Projected embeddings of the three special TTS tokens (bos, eos, pad),
        # split along the sequence dimension.
        tts_bos_embed, tts_eos_embed, tts_pad_embed = m.talker.text_projection(
            m.talker.get_text_embeddings()(
                torch.tensor(
                    [[m.config.tts_bos_token_id, m.config.tts_eos_token_id, m.config.tts_pad_token_id]],
                    device=m.talker.device,
                    dtype=input_id.dtype,
                )
            )
        ).chunk(3, dim=1)

        # Codec-side prefix: "nothink" variant without a language, "think" variant
        # carrying the language id otherwise.
        if language_id is None:
            codec_prefill_list = [[
                m.config.talker_config.codec_nothink_id,
                m.config.talker_config.codec_think_bos_id,
                m.config.talker_config.codec_think_eos_id,
            ]]
        else:
            codec_prefill_list = [[
                m.config.talker_config.codec_think_id,
                m.config.talker_config.codec_think_bos_id,
                language_id,
                m.config.talker_config.codec_think_eos_id,
            ]]

        codec_input_emebdding_0 = m.talker.get_input_embeddings()(
            torch.tensor(codec_prefill_list, device=m.talker.device, dtype=input_id.dtype)
        )
        codec_input_emebdding_1 = m.talker.get_input_embeddings()(
            torch.tensor(
                [[m.config.talker_config.codec_pad_id, m.config.talker_config.codec_bos_id]],
                device=m.talker.device,
                dtype=input_id.dtype,
            )
        )
        # The speaker embedding, when present, sits between the prefix and pad/bos.
        if speaker_embed is None:
            codec_input_emebdding = torch.cat([codec_input_emebdding_0, codec_input_emebdding_1], dim=1)
        else:
            codec_input_emebdding = torch.cat([codec_input_emebdding_0, speaker_embed.view(1, 1, -1), codec_input_emebdding_1], dim=1)

        # First three token positions of the prompt.
        # NOTE(review): presumably the chat-role prefix -- confirm against _build_assistant_text.
        _talker_input_embed_role = m.talker.text_projection(
            m.talker.get_text_embeddings()(input_id[:, :3])
        )
        # Text side (pad ... pad, bos) added element-wise to every codec embedding
        # except the last one.
        _talker_input_embed = torch.cat(
            (
                tts_pad_embed.expand(-1, codec_input_emebdding.shape[1] - 2, -1),
                tts_bos_embed,
            ),
            dim=1,
        ) + codec_input_emebdding[:, :-1]

        talker_input_embed = torch.cat((_talker_input_embed_role, _talker_input_embed), dim=1)

        if (
            voice_clone_prompt is not None
            and voice_clone_prompt.get("ref_code", None) is not None
            and voice_clone_prompt["icl_mode"][index]
        ):
            # ICL mode: the model builds the prompt from reference text + codes.
            # NOTE(review): the slices drop 3 leading and 5 (text) / 2 (ref) trailing
            # positions -- presumably chat-template tokens; confirm against the tokenizer.
            icl_input_embed, trailing_text_hidden = m.generate_icl_prompt(
                text_id=input_id[:, 3:-5],
                ref_id=ref_ids[index][:, 3:-2],
                ref_code=voice_clone_prompt["ref_code"][index].to(m.talker.device).clone(),  # escape inference_mode context
                tts_pad_embed=tts_pad_embed,
                tts_eos_embed=tts_eos_embed,
                non_streaming_mode=non_streaming_mode,
            )
            talker_input_embed = torch.cat([talker_input_embed, icl_input_embed], dim=1)
        else:
            # Non-ICL: append the first text token paired with the last codec embedding.
            talker_input_embed = torch.cat(
                [
                    talker_input_embed,
                    m.talker.text_projection(
                        m.talker.get_text_embeddings()(input_id[:, 3:4])
                    )
                    + codec_input_emebdding[:, -1:],
                ],
                dim=1,
            )
            if non_streaming_mode:
                # Drop the position just appended, then prefill the full text
                # (+ eos) over codec pad ids, followed by pad-text + codec bos.
                talker_input_embed = talker_input_embed[:, :-1]
                talker_input_embed = torch.cat(
                    [
                        talker_input_embed,
                        torch.cat(
                            (
                                m.talker.text_projection(
                                    m.talker.get_text_embeddings()(input_id[:, 3:-5])
                                ),
                                tts_eos_embed,
                            ),
                            dim=1,
                        )
                        + m.talker.get_input_embeddings()(
                            torch.tensor(
                                [[m.config.talker_config.codec_pad_id] * (input_id[:, 3:-5].shape[1] + 1)],
                                device=m.talker.device,
                                dtype=input_id.dtype,
                            )
                        ),
                        tts_pad_embed
                        + m.talker.get_input_embeddings()(
                            torch.tensor(
                                [[m.config.talker_config.codec_bos_id]],
                                device=m.talker.device,
                                dtype=input_id.dtype,
                            )
                        ),
                    ],
                    dim=1,
                )
                # All text is already in the prefill, so only the pad embedding trails.
                trailing_text_hidden = tts_pad_embed
            else:
                # Remaining text (after the first token) plus eos is returned
                # separately as trailing hiddens.
                trailing_text_hidden = torch.cat(
                    (
                        m.talker.text_projection(
                            m.talker.get_text_embeddings()(input_id[:, 4:-5])
                        ),
                        tts_eos_embed,
                    ),
                    dim=1,
                )

        talker_input_embeds[index].append(talker_input_embed)
        trailing_text_hiddens.append(trailing_text_hidden)

    # Join each item's pieces (instruct part, if any, then the main sequence).
    for index, talker_input_embed in enumerate(talker_input_embeds):
        talker_input_embeds[index] = torch.cat([item for item in talker_input_embed if item is not None], dim=1)

    # Left-pad the batch: reverse each sequence, right-pad with zeros, reverse back.
    original_lengths = torch.tensor([t.shape[1] for t in talker_input_embeds])
    sequences = [t.squeeze(0) for t in talker_input_embeds]
    sequences_reversed = [t.flip(dims=[0]) for t in sequences]
    padded_reversed = torch.nn.utils.rnn.pad_sequence(
        sequences_reversed,
        batch_first=True,
        padding_value=0.0,
    )
    talker_input_embeds = padded_reversed.flip(dims=[1])

    # Mask is 1 for real positions: those at or after the per-row pad count.
    batch_size, max_len = talker_input_embeds.shape[0], talker_input_embeds.shape[1]
    indices = torch.arange(max_len).expand(batch_size, -1)
    num_pads = max_len - original_lengths
    talker_attention_mask = (indices >= num_pads.unsqueeze(1)).long().to(talker_input_embeds.device)

    # Right-pad the trailing hiddens, then overwrite the padded slots with the
    # pad embedding (taken from the last loop iteration) instead of zeros.
    pad_embedding_vector = tts_pad_embed.squeeze()
    sequences_to_pad = [t.squeeze(0) for t in trailing_text_hiddens]
    trailing_text_original_lengths = [s.shape[0] for s in sequences_to_pad]
    padded_hiddens = torch.nn.utils.rnn.pad_sequence(
        sequences_to_pad,
        batch_first=True,
        padding_value=0.0,
    )
    arange_tensor = torch.arange(max(trailing_text_original_lengths), device=padded_hiddens.device).expand(
        len(trailing_text_original_lengths), -1
    )
    lengths_tensor = torch.tensor(trailing_text_original_lengths, device=padded_hiddens.device).unsqueeze(1)
    padding_mask = arange_tensor >= lengths_tensor
    padded_hiddens[padding_mask] = pad_embedding_vector
    trailing_text_hiddens = padded_hiddens

    return talker_input_embeds, talker_attention_mask, trailing_text_hiddens, tts_pad_embed
731
+
732
@torch.inference_mode()
def generate_voice_clone(
    self,
    text: str,
    language: str,
    ref_audio: Optional[Union[str, Path, tuple]] = None,
    ref_text: str = "",
    max_new_tokens: int = 2048,
    min_new_tokens: int = 2,
    temperature: float = 0.9,
    top_k: int = 50,
    top_p: float = 1.0,
    do_sample: bool = True,
    repetition_penalty: float = 1.05,
    xvec_only: bool = False,
    non_streaming_mode: bool = False,
    append_silence: bool = True,
    instruct: Optional[str] = None,
    voice_clone_prompt: Optional[Union[Dict[str, Any], List[Any]]] = None,
) -> Tuple[list, int]:
    """
    Synthesize ``text`` in the voice of a reference recording (one-shot, non-streaming).

    Args:
        text: Text to synthesize.
        language: Target language.
        ref_audio: Reference audio. Required unless ``voice_clone_prompt`` is given.
        ref_text: Transcription of the reference audio.
        max_new_tokens: Upper bound on generated codec steps.
        min_new_tokens: Number of steps before EOS is allowed.
        temperature: Sampling temperature.
        top_k: Top-k sampling cutoff.
        top_p: Top-p (nucleus) sampling cutoff.
        do_sample: Whether to sample.
        repetition_penalty: Repetition penalty.
        xvec_only: Use only the speaker embedding for cloning. This avoids phoneme
            bleed-through from the reference and allows clean language switching.
            The default (False) matches upstream ICL behavior, where the reference
            audio is part of the context.
        non_streaming_mode: Text-feeding layout. The default (False) matches
            upstream step-by-step text feeding during decode.
        append_silence: Forwarded unchanged to ``_prepare_generation``.
        instruct: Optional instruction guiding style or dialect (for example a
            request to read in Cantonese). It is prepended as a user turn before
            the TTS assistant turn. Experimental with x-vector-only cloning;
            prefer ``xvec_only=False``.
        voice_clone_prompt: Optional precomputed voice-clone prompt. When given,
            ``xvec_only`` is ignored and nothing is extracted from ``ref_audio``.
            Supports x-vector-only prompts (``ref_spk_embedding`` only) and ICL
            prompts (``ref_spk_embedding`` + ``ref_code`` + mode flags).
            ``ref_text`` is ignored for x-vector-only and required for ICL.

    Returns:
        Tuple of ([audio_waveform], sample_rate). When generation yields no
        tokens, a single one-sample silent waveform is returned with
        ``self.sample_rate``.
    """
    from .generate import fast_generate

    def _flat_numpy(clip):
        # Torch tensors expose .cpu(); anything else is flattened when possible.
        if hasattr(clip, 'cpu'):
            return clip.flatten().cpu().numpy()
        return clip.flatten() if hasattr(clip, 'flatten') else clip

    prepared = self._prepare_generation(
        text=text,
        language=language,
        ref_audio=ref_audio,
        ref_text=ref_text,
        xvec_only=xvec_only,
        non_streaming_mode=non_streaming_mode,
        append_silence=append_silence,
        voice_clone_prompt=voice_clone_prompt,
        instruct=instruct,
    )
    m, talker, config, tie, tam, tth, tpe, ref_codes = prepared

    sampling_options = {
        "max_new_tokens": max_new_tokens,
        "min_new_tokens": min_new_tokens,
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
        "do_sample": do_sample,
        "repetition_penalty": repetition_penalty,
    }
    codec_ids, timing = fast_generate(
        talker=talker,
        talker_input_embeds=tie,
        attention_mask=tam,
        trailing_text_hiddens=tth,
        tts_pad_embed=tpe,
        config=config,
        predictor_graph=self.predictor_graph,
        talker_graph=self.talker_graph,
        **sampling_options,
    )

    if codec_ids is None:
        logger.warning("Generation returned no tokens")
        return [np.zeros(1, dtype=np.float32)], self.sample_rate

    # ICL mode: decode with the reference codes in front so the codec decoder
    # has acoustic context from the reference audio (matches official implementation).
    has_reference = ref_codes is not None
    if has_reference:
        codes_for_decode = torch.cat([ref_codes.to(codec_ids.device), codec_ids], dim=0)
    else:
        codes_for_decode = codec_ids
    decoded, sr = m.speech_tokenizer.decode({"audio_codes": codes_for_decode.unsqueeze(0)})

    # The reference share of the waveform is estimated proportionally from the
    # frame counts and cut off the front of each decoded clip.
    ref_len = ref_codes.shape[0] if has_reference else 0
    total_len = codes_for_decode.shape[0]
    audio_arrays = []
    for clip in decoded:
        wave = _flat_numpy(clip)
        if ref_len > 0:
            wave = wave[int(ref_len / max(total_len, 1) * len(wave)):]
        audio_arrays.append(wave)

    audio_duration = timing['steps'] / 12.0  # 12 Hz codec
    total_time = timing['prefill_ms'] / 1000 + timing['decode_s']
    rtf = audio_duration / total_time if total_time > 0 else 0

    logger.info(
        f"Generated {audio_duration:.2f}s audio in {total_time:.2f}s "
        f"({timing['ms_per_step']:.1f}ms/step, RTF: {rtf:.2f})"
    )

    return audio_arrays, sr
856
+
857
@torch.inference_mode()
def generate_voice_clone_streaming(
    self,
    text: str,
    language: str,
    ref_audio: Optional[Union[str, Path]] = None,
    ref_text: str = "",
    max_new_tokens: int = 2048,
    min_new_tokens: int = 2,
    temperature: float = 0.9,
    top_k: int = 50,
    top_p: float = 1.0,
    do_sample: bool = True,
    repetition_penalty: float = 1.05,
    chunk_size: int = 12,
    xvec_only: bool = False,
    non_streaming_mode: bool = False,
    append_silence: bool = True,
    parity_mode: bool = False,
    instruct: Optional[str] = None,
    voice_clone_prompt: Optional[Union[Dict[str, Any], List[Any]]] = None,
) -> Generator[Tuple[np.ndarray, int, dict], None, None]:
    """
    Streaming counterpart of generate_voice_clone().

    Yields one (audio_chunk, sample_rate, timing) tuple per batch of
    ``chunk_size`` codec steps (~chunk_size/12 seconds of audio).

    Args:
        text: Text to synthesize.
        language: Target language.
        ref_audio: Reference audio. Required unless ``voice_clone_prompt`` is given.
        ref_text: Transcription of the reference audio.
        max_new_tokens: Upper bound on generated codec steps.
        min_new_tokens: Number of steps before EOS is allowed.
        temperature: Sampling temperature.
        top_k: Top-k sampling cutoff.
        top_p: Top-p (nucleus) sampling cutoff.
        do_sample: Whether to sample.
        repetition_penalty: Repetition penalty.
        chunk_size: Codec steps per yielded chunk (12 = ~1 second).
        xvec_only: Use only the speaker embedding for cloning. This avoids phoneme
            bleed-through from the reference and allows clean language switching.
            The default (False) matches upstream ICL behavior, where the reference
            audio is part of the context.
        non_streaming_mode: The default (False) matches upstream text feeding
            during decode; True prefills the full target text before decoding.
        append_silence: Forwarded unchanged to ``_prepare_generation``.
        parity_mode: When True, disables CUDA graphs and uses dynamic cache streaming.
        instruct: Optional instruction guiding style or dialect (for example a
            request to read in Cantonese). It is prepended as a user turn before
            the TTS assistant turn. Experimental with x-vector-only cloning;
            prefer ``xvec_only=False``.
        voice_clone_prompt: Optional precomputed voice-clone prompt. When given,
            ``xvec_only`` is ignored and nothing is extracted from ``ref_audio``.
            Supports x-vector-only prompts (``ref_spk_embedding`` only) and ICL
            prompts (``ref_spk_embedding`` + ``ref_code`` + mode flags).
            ``ref_text`` is ignored for x-vector-only and required for ICL.

    Yields:
        Tuple of (audio_chunk_numpy, sample_rate, timing_dict)
    """
    from .streaming import fast_generate_streaming, parity_generate_streaming

    m, talker, config, tie, tam, tth, tpe, ref_codes = self._prepare_generation(
        text=text,
        language=language,
        ref_audio=ref_audio,
        ref_text=ref_text,
        xvec_only=xvec_only,
        non_streaming_mode=non_streaming_mode,
        append_silence=append_silence,
        voice_clone_prompt=voice_clone_prompt,
        instruct=instruct,
    )

    speech_tokenizer = m.speech_tokenizer

    def _decode_first(codes):
        # Decode a code sequence and return (flat numpy audio of first clip, sample rate).
        clips, rate = speech_tokenizer.decode({"audio_codes": codes.unsqueeze(0)})
        clip = clips[0]
        if hasattr(clip, 'cpu'):
            return clip.flatten().cpu().numpy(), rate
        return (clip.flatten() if hasattr(clip, 'flatten') else clip), rate

    # Hybrid decode strategy:
    #   1. Re-decode everything accumulated so far for the early chunks; this is
    #      exact and lets us calibrate samples_per_frame.
    #   2. After calibration, decode only a sliding window with 25 frames of left
    #      context, so the per-chunk cost stays constant.
    # This avoids boundary artifacts (pops) while keeping decode cost bounded.
    context_frames = 25
    min_calibration_frames = max(context_frames, chunk_size)
    all_codes = []
    prev_gen_audio_len = 0  # samples of generated (non-reference) audio already yielded
    samples_per_frame = None

    stream_kwargs = {
        "talker": talker,
        "talker_input_embeds": tie,
        "attention_mask": tam,
        "trailing_text_hiddens": tth,
        "tts_pad_embed": tpe,
        "config": config,
        "max_new_tokens": max_new_tokens,
        "min_new_tokens": min_new_tokens,
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
        "do_sample": do_sample,
        "repetition_penalty": repetition_penalty,
        "chunk_size": chunk_size,
    }
    if parity_mode:
        stream_fn = parity_generate_streaming
    else:
        stream_fn = fast_generate_streaming
        # Only the fast path takes the captured graphs.
        stream_kwargs.update(
            predictor_graph=self.predictor_graph,
            talker_graph=self.talker_graph,
        )

    for codec_chunk, timing in stream_fn(**stream_kwargs):
        all_codes.append(codec_chunk)
        n_new = codec_chunk.shape[0]
        all_flat = torch.cat(all_codes, dim=0)
        n_total = all_flat.shape[0]

        if samples_per_frame is None:
            # Phase 1: accumulated decode until calibration is possible.
            if ref_codes is None:
                audio, sr = _decode_first(all_flat)
                gen_audio = audio
            else:
                # ICL mode: prepend reference codes so the codec decoder has
                # acoustic context (matches official implementation), then cut
                # the proportional reference share off the front.
                codes_input = torch.cat([ref_codes.to(all_flat.device), all_flat], dim=0)
                audio, sr = _decode_first(codes_input)
                ref_audio_cut = int(ref_codes.shape[0] / max(codes_input.shape[0], 1) * len(audio))
                gen_audio = audio[ref_audio_cut:]

            new_audio = gen_audio[prev_gen_audio_len:]
            prev_gen_audio_len = len(gen_audio)

            if n_total >= min_calibration_frames:
                samples_per_frame = len(gen_audio) / n_total
        else:
            # Phase 2: sliding window with left context; drop the context samples.
            ctx_start = max(0, n_total - n_new - context_frames)
            window = all_flat[ctx_start:]
            n_ctx = window.shape[0] - n_new

            audio, sr = _decode_first(window)
            new_audio = audio[int(round(n_ctx * samples_per_frame)):] if n_ctx > 0 else audio

        yield new_audio, sr, timing
1023
+
1024
@torch.inference_mode()
def generate_custom_voice(
    self,
    text: str,
    speaker: str,
    language: str,
    instruct: Optional[str] = None,
    non_streaming_mode: bool = True,
    max_new_tokens: int = 2048,
    min_new_tokens: int = 2,
    temperature: float = 0.9,
    top_k: int = 50,
    top_p: float = 1.0,
    do_sample: bool = True,
    repetition_penalty: float = 1.05,
) -> Tuple[list, int]:
    """
    Synthesize ``text`` with one of the model's built-in speakers (non-streaming).

    Args:
        text: Text to synthesize.
        speaker: Speaker name; checked with ``self.model._validate_speakers``.
        language: Target language; checked with ``self.model._validate_languages``.
        instruct: Optional style instruction. It is dropped when the loaded
            model's ``tts_model_size`` is exactly ``"0b6"``.
        non_streaming_mode: Forwarded to ``_prepare_generation_custom``.
        max_new_tokens: Upper bound on generated codec steps.
        min_new_tokens: Number of steps before EOS is allowed.
        temperature: Sampling temperature.
        top_k: Top-k sampling cutoff.
        top_p: Top-p (nucleus) sampling cutoff.
        do_sample: Whether to sample.
        repetition_penalty: Repetition penalty.

    Returns:
        Tuple of ([audio_waveform], sample_rate). When generation yields no
        tokens, a single one-sample silent waveform is returned with
        ``self.sample_rate``.

    Raises:
        ValueError: The loaded model is not a ``"custom_voice"`` model.
    """
    if self.model.model.tts_model_type != "custom_voice":
        raise ValueError("Loaded model does not support custom voice generation")

    self.model._validate_languages([language])
    self.model._validate_speakers([speaker])

    # Exact comparison on purpose: `size in "0b6"` is a *substring* test, so it
    # would also match "", "0", "b6", ... and raise TypeError for a None size.
    if self.model.model.tts_model_size == "0b6":
        instruct = None

    m, talker, config, tie, tam, tth, tpe = self._prepare_generation_custom(
        text=text,
        language=language,
        speaker=speaker,
        instruct=instruct,
        non_streaming_mode=non_streaming_mode,
    )

    # Imported right before first use.
    from .generate import fast_generate

    codec_ids, timing = fast_generate(
        talker=talker,
        talker_input_embeds=tie,
        attention_mask=tam,
        trailing_text_hiddens=tth,
        tts_pad_embed=tpe,
        config=config,
        predictor_graph=self.predictor_graph,
        talker_graph=self.talker_graph,
        max_new_tokens=max_new_tokens,
        min_new_tokens=min_new_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=do_sample,
        repetition_penalty=repetition_penalty,
    )

    if codec_ids is None:
        logger.warning("Generation returned no tokens")
        return [np.zeros(1, dtype=np.float32)], self.sample_rate

    speech_tokenizer = m.speech_tokenizer
    audio_list, sr = speech_tokenizer.decode({"audio_codes": codec_ids.unsqueeze(0)})

    # Normalize every decoded clip to a flat numpy array (tensors expose .cpu()).
    audio_arrays = []
    for a in audio_list:
        if hasattr(a, "cpu"):
            audio_arrays.append(a.flatten().cpu().numpy())
        else:
            audio_arrays.append(a.flatten() if hasattr(a, "flatten") else a)

    n_steps = timing["steps"]
    audio_duration = n_steps / 12.0  # 12 codec steps per second of audio
    total_time = timing["prefill_ms"] / 1000 + timing["decode_s"]
    rtf = audio_duration / total_time if total_time > 0 else 0

    logger.info(
        f"Generated {audio_duration:.2f}s audio in {total_time:.2f}s "
        f"({timing['ms_per_step']:.1f}ms/step, RTF: {rtf:.2f})"
    )

    return audio_arrays, sr
1102
+
1103
@torch.inference_mode()
def generate_custom_voice_streaming(
    self,
    text: str,
    speaker: str,
    language: str,
    instruct: Optional[str] = None,
    non_streaming_mode: bool = True,
    max_new_tokens: int = 2048,
    min_new_tokens: int = 2,
    temperature: float = 0.9,
    top_k: int = 50,
    top_p: float = 1.0,
    do_sample: bool = True,
    repetition_penalty: float = 1.05,
    chunk_size: int = 12,
) -> Generator[Tuple[np.ndarray, int, dict], None, None]:
    """
    Stream speech for one of the model's built-in speakers.

    Args:
        text: Text to synthesize.
        speaker: Speaker name; checked with ``self.model._validate_speakers``.
        language: Target language; checked with ``self.model._validate_languages``.
        instruct: Optional style instruction. It is dropped when the loaded
            model's ``tts_model_size`` is exactly ``"0b6"``.
        non_streaming_mode: Forwarded to ``_prepare_generation_custom``.
        max_new_tokens: Upper bound on generated codec steps.
        min_new_tokens: Number of steps before EOS is allowed.
        temperature: Sampling temperature.
        top_k: Top-k sampling cutoff.
        top_p: Top-p (nucleus) sampling cutoff.
        do_sample: Whether to sample.
        repetition_penalty: Repetition penalty.
        chunk_size: Codec steps per yielded chunk.

    Yields:
        Tuple of (audio_chunk_numpy, sample_rate, timing_dict)

    Raises:
        ValueError: The loaded model is not a ``"custom_voice"`` model (raised
            when the generator is first advanced).
    """
    if self.model.model.tts_model_type != "custom_voice":
        raise ValueError("Loaded model does not support custom voice generation")

    self.model._validate_languages([language])
    self.model._validate_speakers([speaker])

    # Exact comparison on purpose: `size in "0b6"` is a *substring* test, so it
    # would also match "", "0", "b6", ... and raise TypeError for a None size.
    if self.model.model.tts_model_size == "0b6":
        instruct = None

    m, talker, config, tie, tam, tth, tpe = self._prepare_generation_custom(
        text=text,
        language=language,
        speaker=speaker,
        instruct=instruct,
        non_streaming_mode=non_streaming_mode,
    )

    # Imported right before first use.
    from .streaming import fast_generate_streaming

    speech_tokenizer = m.speech_tokenizer

    # Hybrid decode: re-decode everything until samples_per_frame can be
    # calibrated, then decode only a sliding window with left context.
    context_frames = 25
    min_calibration_frames = max(context_frames, chunk_size)
    all_codes = []
    prev_audio_len = 0  # samples already yielded during the accumulated phase
    samples_per_frame = None

    for codec_chunk, timing in fast_generate_streaming(
        talker=talker,
        talker_input_embeds=tie,
        attention_mask=tam,
        trailing_text_hiddens=tth,
        tts_pad_embed=tpe,
        config=config,
        predictor_graph=self.predictor_graph,
        talker_graph=self.talker_graph,
        max_new_tokens=max_new_tokens,
        min_new_tokens=min_new_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=do_sample,
        repetition_penalty=repetition_penalty,
        chunk_size=chunk_size,
    ):
        all_codes.append(codec_chunk)
        n_new = codec_chunk.shape[0]
        all_flat = torch.cat(all_codes, dim=0)
        n_total = all_flat.shape[0]

        if samples_per_frame is None:
            # Phase 1: decode all codes so far and yield only the unseen tail.
            audio_list, sr = speech_tokenizer.decode({"audio_codes": all_flat.unsqueeze(0)})
            audio = audio_list[0]
            if hasattr(audio, "cpu"):
                audio = audio.flatten().cpu().numpy()
            else:
                audio = audio.flatten() if hasattr(audio, "flatten") else audio

            new_audio = audio[prev_audio_len:]
            prev_audio_len = len(audio)

            if n_total >= min_calibration_frames:
                samples_per_frame = len(audio) / n_total
        else:
            # Phase 2: decode the new frames plus up to context_frames of left
            # context, then drop the samples that belong to the context.
            ctx_start = max(0, n_total - n_new - context_frames)
            window = all_flat[ctx_start:]
            n_ctx = window.shape[0] - n_new

            audio_list, sr = speech_tokenizer.decode({"audio_codes": window.unsqueeze(0)})
            audio = audio_list[0]
            if hasattr(audio, "cpu"):
                audio = audio.flatten().cpu().numpy()
            else:
                audio = audio.flatten() if hasattr(audio, "flatten") else audio

            if n_ctx > 0:
                ctx_samples = int(round(n_ctx * samples_per_frame))
                new_audio = audio[ctx_samples:]
            else:
                new_audio = audio

        yield new_audio, sr, timing
1202
+
1203
@torch.inference_mode()
def generate_voice_design(
    self,
    text: str,
    instruct: str,
    language: str,
    non_streaming_mode: bool = True,
    max_new_tokens: int = 2048,
    min_new_tokens: int = 2,
    temperature: float = 0.9,
    top_k: int = 50,
    top_p: float = 1.0,
    do_sample: bool = True,
    repetition_penalty: float = 1.05,
) -> Tuple[list, int]:
    """
    Synthesize ``text`` with a voice described by ``instruct`` (non-streaming).

    No named speaker is used: ``speaker=None`` is passed to
    ``_prepare_generation_custom``.

    Args:
        text: Text to synthesize.
        instruct: Instruction text forwarded to ``_prepare_generation_custom``.
        language: Target language; checked with ``self.model._validate_languages``.
        non_streaming_mode: Forwarded to ``_prepare_generation_custom``.
        max_new_tokens: Upper bound on generated codec steps.
        min_new_tokens: Number of steps before EOS is allowed.
        temperature: Sampling temperature.
        top_k: Top-k sampling cutoff.
        top_p: Top-p (nucleus) sampling cutoff.
        do_sample: Whether to sample.
        repetition_penalty: Repetition penalty.

    Returns:
        Tuple of ([audio_waveform], sample_rate). When generation yields no
        tokens, a single one-sample silent waveform is returned with
        ``self.sample_rate``.

    Raises:
        ValueError: The loaded model is not a ``"voice_design"`` model.
    """
    if self.model.model.tts_model_type != "voice_design":
        raise ValueError("Loaded model does not support voice design generation")

    self.model._validate_languages([language])

    from .generate import fast_generate

    def _flat_numpy(clip):
        # Torch tensors expose .cpu(); anything else is flattened when possible.
        if hasattr(clip, "cpu"):
            return clip.flatten().cpu().numpy()
        return clip.flatten() if hasattr(clip, "flatten") else clip

    prepared = self._prepare_generation_custom(
        text=text,
        language=language,
        speaker=None,
        instruct=instruct,
        non_streaming_mode=non_streaming_mode,
    )
    m, talker, config, tie, tam, tth, tpe = prepared

    sampling_options = {
        "max_new_tokens": max_new_tokens,
        "min_new_tokens": min_new_tokens,
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
        "do_sample": do_sample,
        "repetition_penalty": repetition_penalty,
    }
    codec_ids, timing = fast_generate(
        talker=talker,
        talker_input_embeds=tie,
        attention_mask=tam,
        trailing_text_hiddens=tth,
        tts_pad_embed=tpe,
        config=config,
        predictor_graph=self.predictor_graph,
        talker_graph=self.talker_graph,
        **sampling_options,
    )

    if codec_ids is None:
        logger.warning("Generation returned no tokens")
        return [np.zeros(1, dtype=np.float32)], self.sample_rate

    decoded, sr = m.speech_tokenizer.decode({"audio_codes": codec_ids.unsqueeze(0)})
    audio_arrays = [_flat_numpy(clip) for clip in decoded]

    audio_duration = timing["steps"] / 12.0  # 12 codec steps per second of audio
    total_time = timing["prefill_ms"] / 1000 + timing["decode_s"]
    rtf = audio_duration / total_time if total_time > 0 else 0

    logger.info(
        f"Generated {audio_duration:.2f}s audio in {total_time:.2f}s "
        f"({timing['ms_per_step']:.1f}ms/step, RTF: {rtf:.2f})"
    )

    return audio_arrays, sr
1276
+
1277
@torch.inference_mode()
def generate_voice_design_streaming(
    self,
    text: str,
    instruct: str,
    language: str,
    non_streaming_mode: bool = True,
    max_new_tokens: int = 2048,
    min_new_tokens: int = 2,
    temperature: float = 0.9,
    top_k: int = 50,
    top_p: float = 1.0,
    do_sample: bool = True,
    repetition_penalty: float = 1.05,
    chunk_size: int = 12,
) -> Generator[Tuple[np.ndarray, int, dict], None, None]:
    """
    Stream speech for a voice described by ``instruct``.

    No named speaker is used: ``speaker=None`` is passed to
    ``_prepare_generation_custom``.

    Args:
        text: Text to synthesize.
        instruct: Instruction text forwarded to ``_prepare_generation_custom``.
        language: Target language; checked with ``self.model._validate_languages``.
        non_streaming_mode: Forwarded to ``_prepare_generation_custom``.
        max_new_tokens: Upper bound on generated codec steps.
        min_new_tokens: Number of steps before EOS is allowed.
        temperature: Sampling temperature.
        top_k: Top-k sampling cutoff.
        top_p: Top-p (nucleus) sampling cutoff.
        do_sample: Whether to sample.
        repetition_penalty: Repetition penalty.
        chunk_size: Codec steps per yielded chunk.

    Yields:
        Tuple of (audio_chunk_numpy, sample_rate, timing_dict)

    Raises:
        ValueError: The loaded model is not a ``"voice_design"`` model (raised
            when the generator is first advanced).
    """
    if self.model.model.tts_model_type != "voice_design":
        raise ValueError("Loaded model does not support voice design generation")

    self.model._validate_languages([language])

    from .streaming import fast_generate_streaming

    m, talker, config, tie, tam, tth, tpe = self._prepare_generation_custom(
        text=text,
        language=language,
        speaker=None,
        instruct=instruct,
        non_streaming_mode=non_streaming_mode,
    )

    speech_tokenizer = m.speech_tokenizer

    def _decode_first(codes):
        # Decode a code sequence and return (flat numpy audio of first clip, sample rate).
        clips, rate = speech_tokenizer.decode({"audio_codes": codes.unsqueeze(0)})
        clip = clips[0]
        if hasattr(clip, "cpu"):
            return clip.flatten().cpu().numpy(), rate
        return (clip.flatten() if hasattr(clip, "flatten") else clip), rate

    # Hybrid decode: re-decode everything until samples_per_frame can be
    # calibrated, then decode only a sliding window with left context.
    context_frames = 25
    min_calibration_frames = max(context_frames, chunk_size)
    all_codes = []
    prev_audio_len = 0  # samples already yielded during the accumulated phase
    samples_per_frame = None

    chunk_stream = fast_generate_streaming(
        talker=talker,
        talker_input_embeds=tie,
        attention_mask=tam,
        trailing_text_hiddens=tth,
        tts_pad_embed=tpe,
        config=config,
        predictor_graph=self.predictor_graph,
        talker_graph=self.talker_graph,
        max_new_tokens=max_new_tokens,
        min_new_tokens=min_new_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=do_sample,
        repetition_penalty=repetition_penalty,
        chunk_size=chunk_size,
    )

    for codec_chunk, timing in chunk_stream:
        all_codes.append(codec_chunk)
        n_new = codec_chunk.shape[0]
        all_flat = torch.cat(all_codes, dim=0)
        n_total = all_flat.shape[0]

        if samples_per_frame is None:
            # Phase 1: decode all codes so far and yield only the unseen tail.
            audio, sr = _decode_first(all_flat)
            new_audio = audio[prev_audio_len:]
            prev_audio_len = len(audio)

            if n_total >= min_calibration_frames:
                samples_per_frame = len(audio) / n_total
        else:
            # Phase 2: decode the new frames plus up to context_frames of left
            # context, then drop the samples that belong to the context.
            ctx_start = max(0, n_total - n_new - context_frames)
            window = all_flat[ctx_start:]
            n_ctx = window.shape[0] - n_new

            audio, sr = _decode_first(window)
            new_audio = audio[int(round(n_ctx * samples_per_frame)):] if n_ctx > 0 else audio

        yield new_audio, sr, timing
faster_qwen3_tts/predictor_graph.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ CUDA graph capture for the code predictor's 15-step decode loop,
4
+ using transformers StaticCache.
5
+
6
+ The predictor generates 15 codebooks autoregressively:
7
+ - Step 0: prefill with 2 tokens (past_hidden + first_codebook_embed), get logits[0]
8
+ - Steps 1-14: decode 1 token at a time using previous codebook token's embedding
9
+
10
+ Strategy:
11
+ - Use transformers StaticCache for KV cache management
12
+ - Use the predictor's inner model forward (handles mask, RoPE, attention internally)
13
+ - Unroll the full 15-step loop for deterministic shapes
14
+ - Capture the entire loop as a single CUDA graph
15
+ """
16
+ import torch
17
+ from transformers import StaticCache
18
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
19
+
20
+ from .sampling import sample_logits
21
+
22
+
23
class PredictorGraph:
    """
    Captures the full predictor decode loop as a single CUDA graph,
    using the model's forward with transformers StaticCache.

    The loop produces ``num_code_groups - 1`` codebook tokens: one from a
    2-token prefill and the rest from single-token decode steps. Sampling
    parameters are plain Python values read inside ``_full_loop``, so they
    are fixed at the time ``capture()`` records the graph.

    Usage:
        mpg = PredictorGraph(code_predictor, pred_config, talker_hidden_size)
        mpg.capture()
        codebook_tokens = mpg.run(pred_input)  # pred_input: [1, 2, H]
    """

    def __init__(self, code_predictor, pred_config, talker_hidden_size, device='cuda', dtype=torch.bfloat16,
                 do_sample=True, top_k=50, top_p=1.0, temperature=0.9):
        self.device = device
        device_index = torch.device(device).index
        device_index = device_index if device_index is not None else torch.cuda.current_device()
        self.device_index = device_index

        self.dtype = dtype
        self.num_layers = pred_config.num_hidden_layers
        self.hidden_size = pred_config.hidden_size
        self.num_code_groups = pred_config.num_code_groups
        self.num_codebooks = self.num_code_groups - 1
        # 2 prefill positions plus one slot per codebook step.
        self.max_seq = 2 + self.num_codebooks
        self.do_sample = do_sample
        self.top_k = top_k
        self.top_p = top_p
        self.temperature = temperature

        # Extract model components (references, not copies)
        cp = code_predictor
        self.small_to_mtp = cp.small_to_mtp_projection
        self.pred_model = cp.model  # Inner transformer model
        self.lm_heads = cp.lm_head  # Indexed per codebook step
        self.codec_embeds = cp.model.codec_embedding  # Indexed per codebook step
        # `layer_types` may be missing or explicitly None on the config; treat
        # both as "no sliding-attention layers" instead of raising TypeError.
        layer_types = getattr(self.pred_model.config, "layer_types", None) or ()
        self.has_sliding_layers = "sliding_attention" in layer_types

        # Transformers StaticCache for the predictor
        self.static_cache = StaticCache(config=pred_config, max_cache_len=self.max_seq)

        # Pre-allocate cache_position tensors for each step (avoids CPU->GPU in graph)
        self.prefill_cache_pos = torch.arange(2, device=device)
        self.decode_cache_positions = [
            torch.tensor([2 + i], device=device) for i in range(self.num_codebooks - 1)
        ]

        # I/O buffers
        self.input_buf = torch.zeros(1, 2, talker_hidden_size, dtype=dtype, device=device)
        self.output_tokens = torch.zeros(self.num_codebooks, dtype=torch.long, device=device)

        self.graph = None
        self.captured = False
        self.prefill_attn = None
        self.decode_attn = None

    def _init_cache_layers(self):
        """Force lazy initialization of StaticCache layers before graph capture."""
        config = self.pred_model.config
        num_kv_heads = getattr(config, 'num_key_value_heads', config.num_attention_heads)
        head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads)
        dummy_k = torch.zeros(1, num_kv_heads, 1, head_dim, dtype=self.dtype, device=self.device)
        for layer in self.static_cache.layers:
            if not layer.is_initialized:
                layer.lazy_initialization(dummy_k)

    def _make_attn_mask(self, input_embeds: torch.Tensor, cache_position: torch.Tensor):
        """Build the attention-mask mapping passed to the predictor model for one step.

        Returns a dict with a "full_attention" entry and, when the model has
        sliding-attention layers, a "sliding_attention" entry as well.
        """
        mask = create_causal_mask(
            config=self.pred_model.config,
            input_embeds=input_embeds,
            attention_mask=None,
            cache_position=cache_position,
            past_key_values=self.static_cache,
        )
        if self.has_sliding_layers:
            sliding = create_sliding_window_causal_mask(
                config=self.pred_model.config,
                input_embeds=input_embeds,
                attention_mask=None,
                cache_position=cache_position,
                past_key_values=self.static_cache,
            )
            return {"full_attention": mask, "sliding_attention": sliding}
        return {"full_attention": mask}

    def _build_attention_masks(self):
        """Precompute the prefill mask and one decode mask per decode cache position."""
        dummy_prefill = torch.zeros(1, 2, self.hidden_size, dtype=self.dtype, device=self.device)
        dummy_decode = torch.zeros(1, 1, self.hidden_size, dtype=self.dtype, device=self.device)
        self.prefill_attn = self._make_attn_mask(dummy_prefill, self.prefill_cache_pos)
        self.decode_attn = []
        for pos in self.decode_cache_positions:
            self.decode_attn.append(self._make_attn_mask(dummy_decode, pos))

    def _full_loop(self):
        """Run the full predictor loop on the static buffers and return ``output_tokens``."""
        # Project input from talker hidden size to predictor hidden size
        h = self.small_to_mtp(self.input_buf)  # [1, 2, hidden]

        # Prefill: 2 tokens through all layers
        out = self.pred_model(
            inputs_embeds=h,
            attention_mask=self.prefill_attn,
            past_key_values=self.static_cache,
            cache_position=self.prefill_cache_pos,
            use_cache=True,
        )
        h = out.last_hidden_state

        # First codebook: logits from last position
        logits = self.lm_heads[0](h[:, -1:, :])
        tok = sample_logits(
            logits[:, 0, :],
            temperature=self.temperature,
            top_k=self.top_k,
            top_p=self.top_p,
            do_sample=self.do_sample,
        )
        self.output_tokens[0] = tok[0]

        # Remaining codebooks
        for cb_idx in range(1, self.num_codebooks):
            # Embed previous token using codebook-specific embedding
            emb = self.codec_embeds[cb_idx - 1](tok.unsqueeze(0))
            emb = self.small_to_mtp(emb)

            # Single-token decode through all layers
            out = self.pred_model(
                inputs_embeds=emb,
                attention_mask=self.decode_attn[cb_idx - 1],
                past_key_values=self.static_cache,
                cache_position=self.decode_cache_positions[cb_idx - 1],
                use_cache=True,
            )
            h = out.last_hidden_state

            logits = self.lm_heads[cb_idx](h[:, -1:, :])
            tok = sample_logits(
                logits[:, 0, :],
                temperature=self.temperature,
                top_k=self.top_k,
                top_p=self.top_p,
                do_sample=self.do_sample,
            )
            self.output_tokens[cb_idx] = tok[0]

        return self.output_tokens

    @torch.inference_mode()
    def capture(self, num_warmup=3):
        """Warmup and capture the CUDA graph."""
        print(f"Warming up predictor ({num_warmup} runs)...")

        # Force cache initialization before graph capture
        self._init_cache_layers()
        self._build_attention_masks()

        for _ in range(num_warmup):
            self.static_cache.reset()
            self._full_loop()
        torch.cuda.synchronize()

        print("Capturing CUDA graph for predictor...")

        with torch.cuda.device(self.device_index):
            s = torch.cuda.Stream()
            s.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(s):
                self.graph = torch.cuda.CUDAGraph()
                # Warmup in capture stream
                self.static_cache.reset()
                self._full_loop()
                torch.cuda.synchronize()

                self.static_cache.reset()
                with torch.cuda.graph(self.graph):
                    self._full_loop()

            torch.cuda.current_stream().wait_stream(s)
            torch.cuda.synchronize()
        self.captured = True
        print("CUDA graph captured!")

    @torch.inference_mode()
    def run(self, pred_input: torch.Tensor) -> torch.Tensor:
        """
        Run the captured graph.
        pred_input: [1, 2, talker_hidden_size] (past_hidden cat first_codebook_embed)
        Returns: 1-D long tensor with one token per predictor codebook (a clone
        of the static output buffer, so it survives the next replay).

        Raises:
            RuntimeError: if called before ``capture()`` has completed.
        """
        # Fail with a clear message rather than an AttributeError on None.replay().
        if not self.captured or self.graph is None:
            raise RuntimeError("PredictorGraph.run() called before capture(); call capture() first.")
        self.input_buf.copy_(pred_input)
        self.static_cache.reset()
        self.graph.replay()
        return self.output_tokens.clone()
faster_qwen3_tts/sampling.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Shared sampling helpers for talker and predictor generation."""
2
+ from __future__ import annotations
3
+
4
+ from typing import Iterable, Optional
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+
10
def apply_repetition_penalty(
    logits: torch.Tensor,
    token_history: torch.Tensor,
    repetition_penalty: float,
) -> torch.Tensor:
    """Penalize already-generated tokens by editing ``logits`` in place.

    Positive logits of previously seen tokens are divided by the penalty and
    non-positive ones are multiplied by it. The same tensor object that was
    passed in is returned.

    Args:
        logits: Tensor indexed by token id along its last dimension.
        token_history: Tensor of previously generated token ids; duplicates
            are collapsed before indexing.
        repetition_penalty: Penalty factor; 1.0 leaves the logits untouched.
    """
    nothing_to_do = repetition_penalty == 1.0 or token_history.numel() == 0
    if nothing_to_do:
        return logits

    seen_ids = torch.unique(token_history)
    selected = logits[..., seen_ids]
    shrunk = selected / repetition_penalty
    amplified = selected * repetition_penalty
    logits[..., seen_ids] = torch.where(selected > 0, shrunk, amplified)
    return logits
30
+
31
+
32
def sample_logits(
    logits: torch.Tensor,
    *,
    temperature: float,
    top_k: int,
    top_p: float,
    do_sample: bool,
    suppress_mask: Optional[torch.Tensor] = None,
    suppress_tokens: Optional[Iterable[int]] = None,
) -> torch.Tensor:
    """Sample a token from logits.

    Order of operations: suppress -> temperature -> top-k -> top-p -> sample.
    The input tensor is cloned first, so the caller's logits are not modified.

    Args:
        logits: Tensor with the vocabulary along its last dimension.
        temperature: Divisor applied to the logits when sampling.
        top_k: Keep only the ``top_k`` highest logits; values <= 0 disable it.
        top_p: Nucleus threshold; values >= 1.0 disable it.
        do_sample: When False, return the argmax after suppression only.
        suppress_mask: Optional boolean mask over the vocabulary; True entries
            are set to -inf.
        suppress_tokens: Optional token ids to set to -inf.

    Returns:
        Tensor of sampled token ids with the vocabulary dimension removed.
    """
    logits = logits.clone()
    if suppress_mask is not None:
        logits[..., suppress_mask] = float("-inf")
    if suppress_tokens:
        logits[..., list(suppress_tokens)] = float("-inf")
    if not do_sample:
        return torch.argmax(logits, dim=-1)
    logits = logits / temperature
    if top_k > 0:
        topk_vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits = torch.where(logits < topk_vals[..., -1:], torch.full_like(logits, float("-inf")), logits)
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        probs = F.softmax(sorted_logits, dim=-1)
        cumulative_probs = torch.cumsum(probs, dim=-1)
        exceeds = cumulative_probs > top_p
        # Shift the mask right by one so the token that first crosses the
        # threshold is kept: the nucleus is the smallest prefix whose mass
        # reaches top_p. The shift also guarantees the top token survives.
        sorted_indices_to_remove = torch.zeros_like(exceeds)
        sorted_indices_to_remove[..., 1:] = exceeds[..., :-1]
        sorted_logits = sorted_logits.masked_fill(sorted_indices_to_remove, float("-inf"))
        logits = torch.full_like(logits, float("-inf"))
        logits.scatter_(-1, sorted_indices, sorted_logits)
    return torch.multinomial(F.softmax(logits, dim=-1), 1).squeeze(-1)
faster_qwen3_tts/streaming.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Streaming generation with CUDA graphs for both predictor and talker.
4
+
5
+ Yields codec ID chunks during generation instead of collecting all at once.
6
+ CUDA graph usage is identical to non-streaming — same per-step performance.
7
+ """
8
+ import time
9
+ from typing import Generator, Tuple
10
+
11
+ import torch
12
+
13
+ from .predictor_graph import PredictorGraph
14
+ from .sampling import apply_repetition_penalty, sample_logits
15
+ from .talker_graph import TalkerGraph
16
+
17
+
18
@torch.inference_mode()
def fast_generate_streaming(
    talker,
    talker_input_embeds: torch.Tensor,
    attention_mask: torch.Tensor,
    trailing_text_hiddens: torch.Tensor,
    tts_pad_embed: torch.Tensor,
    config,
    predictor_graph: PredictorGraph,
    talker_graph: TalkerGraph,
    max_new_tokens: int = 2048,
    min_new_tokens: int = 2,
    temperature: float = 0.9,
    top_k: int = 50,
    top_p: float = 1.0,
    do_sample: bool = True,
    repetition_penalty: float = 1.05,
    chunk_size: int = 12,
) -> Generator[Tuple[torch.Tensor, dict], None, None]:
    """
    Streaming autoregressive generation with CUDA-graphed predictor and talker.

    Yields (codec_chunk, timing_info) tuples every chunk_size steps.
    codec_chunk: [chunk_steps, 1 + number of predictor codebooks] tensor of
    codec IDs (the talker token followed by the predictor tokens).
    The final chunk may be shorter than chunk_size. When generation stops
    exactly on a chunk boundary, no chunk with 'is_final' True is yielded.
    """
    eos_id = config.codec_eos_token_id
    vocab_size = config.vocab_size
    device = talker_input_embeds.device

    # Suppress the last (up to) 1024 vocabulary ids except EOS. Built with a
    # single slice assignment instead of one device write per id.
    suppress_mask = torch.zeros(vocab_size, dtype=torch.bool, device=device)
    suppress_start = max(0, vocab_size - 1024)
    suppress_mask[suppress_start:] = True
    if eos_id is not None and suppress_start <= eos_id < vocab_size:
        suppress_mask[eos_id] = False

    predictor = talker.code_predictor
    talker_codec_embed = talker.get_input_embeddings()
    talker_codec_head = talker.codec_head
    predictor_codec_embeds = predictor.get_input_embeddings()
    num_code_groups = config.num_code_groups

    # === PREFILL (still uses HF forward for variable-length prefill) ===
    t_start = time.perf_counter()

    out = talker.forward(
        inputs_embeds=talker_input_embeds,
        attention_mask=attention_mask,
        use_cache=True,
        output_hidden_states=True,
        return_dict=True,
        trailing_text_hidden=trailing_text_hiddens,
        tts_pad_embed=tts_pad_embed,
        generation_step=None,
        past_hidden=None,
        past_key_values=None,
    )

    talker_past_kv = out.past_key_values
    past_hidden = out.past_hidden
    gen_step = out.generation_step

    logits = out.logits[:, -1, :]
    suppress_eos = min_new_tokens > 0
    token = sample_logits(
        logits,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=do_sample,
        suppress_mask=suppress_mask,
        suppress_tokens=[eos_id] if suppress_eos else None,
    )

    # Hand the prefill KV cache and padding/rope state to the graphed talker.
    prefill_len = talker_graph.prefill_kv(talker_past_kv)
    rope_deltas = getattr(talker, "rope_deltas", None)
    talker_graph.set_generation_state(attention_mask, rope_deltas)

    torch.cuda.synchronize()
    t_prefill = time.perf_counter() - t_start

    # === DECODE LOOP — yield chunks ===
    chunk_buffer = []
    all_first_tokens = []  # for repetition penalty across chunks
    total_steps = 0
    chunk_count = 0
    chunk_start = time.perf_counter()

    for step_idx in range(max_new_tokens):
        if token.item() == eos_id:
            break

        # --- CUDA-Graphed Code Predictor ---
        last_id_hidden = talker_codec_embed(token.unsqueeze(1))
        pred_input = torch.cat((past_hidden, last_id_hidden), dim=1)
        codebook_token_ids = predictor_graph.run(pred_input)

        all_cb = torch.cat([token.view(1), codebook_token_ids])
        chunk_buffer.append(all_cb.detach())
        all_first_tokens.append(token.detach())

        # --- Build input embedding for talker ---
        codec_hiddens = [last_id_hidden]
        for i in range(num_code_groups - 1):
            codec_hiddens.append(predictor_codec_embeds[i](codebook_token_ids[i].unsqueeze(0).unsqueeze(0)))
        inputs_embeds = torch.cat(codec_hiddens, dim=1).sum(1, keepdim=True)

        if gen_step < trailing_text_hiddens.shape[1]:
            inputs_embeds = inputs_embeds + trailing_text_hiddens[:, gen_step].unsqueeze(1)
        else:
            inputs_embeds = inputs_embeds + tts_pad_embed

        # --- CUDA-Graphed Talker decode step ---
        current_pos = prefill_len + step_idx
        # Stop before running past the static cache; frames already buffered
        # are still emitted by the final-chunk yield below.
        if current_pos >= talker_graph.max_seq_len - 1:
            break

        hidden_states = talker_graph.run(inputs_embeds, position=current_pos)

        logits = talker_codec_head(hidden_states[:, -1, :]).unsqueeze(0)

        if repetition_penalty != 1.0 and all_first_tokens:
            history = torch.stack(all_first_tokens)
            logits = apply_repetition_penalty(logits, history, repetition_penalty)

        suppress_eos = len(all_first_tokens) < min_new_tokens
        token = sample_logits(
            logits.squeeze(0),
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=do_sample,
            suppress_mask=suppress_mask,
            suppress_tokens=[eos_id] if suppress_eos else None,
        )
        # talker_graph.run returns its static output buffer, so clone the slice
        # before the next replay overwrites it.
        past_hidden = hidden_states[:, -1:, :].clone()
        gen_step += 1

        # --- Yield chunk when buffer is full ---
        if len(chunk_buffer) >= chunk_size:
            torch.cuda.synchronize()
            chunk_decode_time = time.perf_counter() - chunk_start
            total_steps += len(chunk_buffer)

            yield torch.stack(chunk_buffer), {
                'chunk_index': chunk_count,
                'chunk_steps': len(chunk_buffer),
                'prefill_ms': t_prefill * 1000 if chunk_count == 0 else 0,
                'decode_ms': chunk_decode_time * 1000,
                'total_steps_so_far': total_steps,
                'is_final': False,
            }

            chunk_buffer = []
            chunk_count += 1
            chunk_start = time.perf_counter()

    # --- Yield final partial chunk ---
    if chunk_buffer:
        torch.cuda.synchronize()
        chunk_decode_time = time.perf_counter() - chunk_start
        total_steps += len(chunk_buffer)

        yield torch.stack(chunk_buffer), {
            'chunk_index': chunk_count,
            'chunk_steps': len(chunk_buffer),
            'prefill_ms': t_prefill * 1000 if chunk_count == 0 else 0,
            'decode_ms': chunk_decode_time * 1000,
            'total_steps_so_far': total_steps,
            'is_final': True,
        }
189
+
190
+
191
@torch.inference_mode()
def parity_generate_streaming(
    talker,
    talker_input_embeds: torch.Tensor,
    attention_mask: torch.Tensor,
    trailing_text_hiddens: torch.Tensor,
    tts_pad_embed: torch.Tensor,
    config,
    max_new_tokens: int = 2048,
    min_new_tokens: int = 2,
    temperature: float = 0.9,
    top_k: int = 50,
    top_p: float = 1.0,
    do_sample: bool = True,
    repetition_penalty: float = 1.05,
    chunk_size: int = 12,
) -> Generator[Tuple[torch.Tensor, dict], None, None]:
    """
    Streaming generation without CUDA graphs (dynamic cache).

    Yields (codec_chunk, timing_info) tuples every chunk_size steps. The final
    chunk may be shorter than chunk_size; when generation stops exactly on a
    chunk boundary, no chunk with 'is_final' True is yielded.
    """
    # NOTE: This function intentionally mirrors fast_generate_streaming. The core
    # decode loop is duplicated so we can swap CUDA graphs/static cache for the
    # dynamic-cache path while keeping sampling/chunking identical. If you edit
    # the fast path, check parity_generate_streaming for matching changes.
    eos_id = config.codec_eos_token_id
    vocab_size = config.vocab_size
    device = talker_input_embeds.device

    # Suppress the last (up to) 1024 vocabulary ids except EOS. Built with a
    # single slice assignment instead of one device write per id.
    suppress_mask = torch.zeros(vocab_size, dtype=torch.bool, device=device)
    suppress_start = max(0, vocab_size - 1024)
    suppress_mask[suppress_start:] = True
    if eos_id is not None and suppress_start <= eos_id < vocab_size:
        suppress_mask[eos_id] = False

    # === PREFILL ===
    t_start = time.perf_counter()

    out = talker.forward(
        inputs_embeds=talker_input_embeds,
        attention_mask=attention_mask,
        use_cache=True,
        output_hidden_states=True,
        return_dict=True,
        trailing_text_hidden=trailing_text_hiddens,
        tts_pad_embed=tts_pad_embed,
        generation_step=None,
        past_hidden=None,
        past_key_values=None,
    )

    talker_past_kv = out.past_key_values
    past_hidden = out.past_hidden
    gen_step = out.generation_step

    logits = out.logits[:, -1, :]
    suppress_eos = min_new_tokens > 0
    token = sample_logits(
        logits,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=do_sample,
        suppress_mask=suppress_mask,
        suppress_tokens=[eos_id] if suppress_eos else None,
    )

    # Work on a private copy; the mask is extended by one column per step below.
    if attention_mask is not None:
        attention_mask = attention_mask.clone()

    torch.cuda.synchronize()
    t_prefill = time.perf_counter() - t_start

    # === DECODE LOOP — yield chunks ===
    chunk_buffer = []
    all_first_tokens = []
    total_steps = 0
    chunk_count = 0
    chunk_start = time.perf_counter()

    for _ in range(max_new_tokens):
        if token.item() == eos_id:
            break

        cache_position = None
        if attention_mask is not None:
            attention_mask = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
                dim=1,
            )
            cache_position = torch.tensor([attention_mask.shape[1] - 1], device=attention_mask.device)

        out = talker.forward(
            input_ids=token.view(1, 1),
            attention_mask=attention_mask,
            use_cache=True,
            output_hidden_states=True,
            return_dict=True,
            trailing_text_hidden=trailing_text_hiddens,
            tts_pad_embed=tts_pad_embed,
            generation_step=gen_step,
            past_hidden=past_hidden,
            past_key_values=talker_past_kv,
            subtalker_dosample=do_sample,
            subtalker_top_k=top_k,
            subtalker_top_p=top_p,
            subtalker_temperature=temperature,
            cache_position=cache_position,
        )

        # The codec ids for this step are read from the second entry of
        # hidden_states; a None entry ends generation.
        codec_ids = out.hidden_states[1]
        if codec_ids is None:
            break

        chunk_buffer.append(codec_ids.squeeze(0).detach())
        all_first_tokens.append(token.detach())

        logits = out.logits[:, -1, :]
        if repetition_penalty != 1.0 and all_first_tokens:
            history = torch.stack(all_first_tokens)
            logits = apply_repetition_penalty(logits, history, repetition_penalty)

        suppress_eos = len(all_first_tokens) < min_new_tokens
        token = sample_logits(
            logits,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=do_sample,
            suppress_mask=suppress_mask,
            suppress_tokens=[eos_id] if suppress_eos else None,
        )

        talker_past_kv = out.past_key_values
        past_hidden = out.past_hidden
        gen_step = out.generation_step

        if len(chunk_buffer) >= chunk_size:
            torch.cuda.synchronize()
            chunk_decode_time = time.perf_counter() - chunk_start
            total_steps += len(chunk_buffer)

            yield torch.stack(chunk_buffer), {
                'chunk_index': chunk_count,
                'chunk_steps': len(chunk_buffer),
                'prefill_ms': t_prefill * 1000 if chunk_count == 0 else 0,
                'decode_ms': chunk_decode_time * 1000,
                'total_steps_so_far': total_steps,
                'is_final': False,
            }

            chunk_buffer = []
            chunk_count += 1
            chunk_start = time.perf_counter()

    if chunk_buffer:
        torch.cuda.synchronize()
        chunk_decode_time = time.perf_counter() - chunk_start
        total_steps += len(chunk_buffer)

        yield torch.stack(chunk_buffer), {
            'chunk_index': chunk_count,
            'chunk_steps': len(chunk_buffer),
            'prefill_ms': t_prefill * 1000 if chunk_count == 0 else 0,
            'decode_ms': chunk_decode_time * 1000,
            'total_steps_so_far': total_steps,
            'is_final': True,
        }
faster_qwen3_tts/talker_graph.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ CUDA graph capture for the talker's single-token decode step,
4
+ using transformers StaticCache.
5
+
6
+ The talker has 28 transformer layers. Instead of reimplementing the
7
+ forward pass manually, we use the model's own forward with StaticCache.
8
+ The StaticCache provides fixed-size KV tensors compatible with CUDA graphs.
9
+
10
+ Strategy:
11
+ - Use transformers StaticCache for KV cache management
12
+ - Use the model's forward method (handles mask, RoPE, attention internally)
13
+ - Capture the single-token decode as a CUDA graph
14
+ - Update cache_position buffer between replays
15
+ """
16
+ import torch
17
+ from transformers import StaticCache
18
+ from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
19
+
20
+
21
+ class TalkerGraph:
22
+ """
23
+ Captures the talker's single-token decode step as a CUDA graph,
24
+ using the model's own forward with transformers StaticCache.
25
+ """
26
+
27
+ def __init__(self, talker_model, talker_config, device='cuda', dtype=torch.bfloat16,
28
+ max_seq_len=512):
29
+ self.device = device
30
+ device_index = torch.device(device).index
31
+ device_index = device_index if device_index is not None else torch.cuda.current_device()
32
+ self.device_index = device_index
33
+
34
+ self.dtype = dtype
35
+ self.max_seq_len = max_seq_len
36
+ self.hidden_size = talker_config.hidden_size
37
+ self.num_layers = talker_config.num_hidden_layers
38
+
39
+ # Keep reference to the inner model (transformer backbone)
40
+ self.model = talker_model
41
+
42
+ # Transformers StaticCache — handles index_copy_ and fixed-size KV internally
43
+ self.static_cache = StaticCache(config=talker_config, max_cache_len=max_seq_len)
44
+
45
+ # Static I/O buffers for CUDA graph
46
+ self.input_buf = torch.zeros(1, 1, self.hidden_size, dtype=dtype, device=device)
47
+ self.output_buf = torch.zeros(1, 1, self.hidden_size, dtype=dtype, device=device)
48
+
49
+ # Cache position buffer — updated before each graph replay
50
+ self.cache_position = torch.zeros(1, dtype=torch.long, device=device)
51
+ # Rope deltas from prefill (shape [batch, 1]) and position ids buffer.
52
+ self.rope_deltas = torch.zeros(1, 1, dtype=torch.float32, device=device)
53
+ self.position_ids = torch.zeros(3, 1, 1, dtype=torch.float32, device=device)
54
+
55
+ self.graph = None
56
+ self.captured = False
57
+ self.attn_mask = None
58
+ self.attn_mask_table = None
59
+ self._mask_key = None
60
+
61
+ def _init_cache_layers(self):
62
+ """Force lazy initialization of StaticCache layers before graph capture."""
63
+ config = self.model.config
64
+ num_kv_heads = getattr(config, 'num_key_value_heads', config.num_attention_heads)
65
+ head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads)
66
+ dummy_k = torch.zeros(1, num_kv_heads, 1, head_dim, dtype=self.dtype, device=self.device)
67
+ for layer in self.static_cache.layers:
68
+ if not layer.is_initialized:
69
+ layer.lazy_initialization(dummy_k)
70
+
71
+ def _build_attention_masks(self, attention_mask: torch.Tensor | None = None):
72
+ dummy = torch.zeros(1, 1, self.hidden_size, dtype=self.dtype, device=self.device)
73
+ max_len = self.max_seq_len
74
+ self.attn_mask_table = [None] * max_len
75
+
76
+ mask_fn = create_causal_mask if self.model.config.sliding_window is None else create_sliding_window_causal_mask
77
+
78
+ for i in range(max_len):
79
+ pos = torch.tensor([i], device=self.device)
80
+ full = mask_fn(
81
+ config=self.model.config,
82
+ input_embeds=dummy,
83
+ attention_mask=attention_mask,
84
+ cache_position=pos,
85
+ past_key_values=self.static_cache,
86
+ )
87
+ self.attn_mask_table[i] = full
88
+
89
+ if self.attn_mask is None:
90
+ self.attn_mask = self.attn_mask_table[0].clone()
91
+ else:
92
+ self.attn_mask.copy_(self.attn_mask_table[0])
93
+
94
+ def _set_attention_mask(self, position: int):
95
+ self.attn_mask.copy_(self.attn_mask_table[position])
96
+
97
+ def _decode_step(self):
98
+ """Single-token decode through the model's forward."""
99
+ out = self.model(
100
+ inputs_embeds=self.input_buf,
101
+ attention_mask=self.attn_mask,
102
+ past_key_values=self.static_cache,
103
+ cache_position=self.cache_position,
104
+ position_ids=self.position_ids,
105
+ use_cache=True,
106
+ )
107
+ self.output_buf.copy_(out.last_hidden_state)
108
+
109
+ @torch.inference_mode()
110
+ def capture(self, prefill_len=100, num_warmup=3):
111
+ """
112
+ Capture CUDA graph for single-token decode.
113
+ prefill_len: simulated prefill length for warmup (graph is position-independent).
114
+ """
115
+ print(f"Warming up talker graph ({num_warmup} runs)...")
116
+
117
+ # Force cache initialization before graph capture
118
+ self._init_cache_layers()
119
+ self._build_attention_masks()
120
+
121
+ # Set cache_position for warmup
122
+ self.cache_position[0] = prefill_len
123
+ self._set_attention_mask(prefill_len)
124
+
125
+ for _ in range(num_warmup):
126
+ self._decode_step()
127
+ torch.cuda.synchronize()
128
+
129
+ print("Capturing CUDA graph for talker decode...")
130
+
131
+ with torch.cuda.device(self.device_index):
132
+ self.graph = torch.cuda.CUDAGraph()
133
+
134
+ s = torch.cuda.Stream()
135
+ s.wait_stream(torch.cuda.current_stream())
136
+ with torch.cuda.stream(s):
137
+ # Warmup in capture stream
138
+ self._decode_step()
139
+ torch.cuda.synchronize()
140
+
141
+ with torch.cuda.graph(self.graph):
142
+ self._decode_step()
143
+
144
+ torch.cuda.current_stream().wait_stream(s)
145
+ torch.cuda.synchronize()
146
+ self.captured = True
147
+ print("Talker CUDA graph captured!")
148
+
149
+ def reset(self, prefill_len: int):
150
+ """Reset cache for new sequence."""
151
+ self.static_cache.reset()
152
+
153
+ def prefill_kv(self, past_key_values):
154
+ """
155
+ Copy HF DynamicCache from prefill into our StaticCache.
156
+ past_key_values: DynamicCache with num_layers layers of [1, kv_heads, seq_len, head_dim]
157
+ """
158
+ self.static_cache.reset()
159
+ seq_len = 0
160
+ for li in range(self.num_layers):
161
+ k, v = past_key_values[li] # each [1, kv_heads, seq_len, head_dim]
162
+ seq_len = k.shape[2]
163
+ if seq_len > self.max_seq_len:
164
+ raise RuntimeError(
165
+ f"Input is too long: prefill has {seq_len} tokens but max_seq_len={self.max_seq_len}. "
166
+ "Use shorter text or shorter reference audio."
167
+ )
168
+ cache_pos = torch.arange(seq_len, device=self.device)
169
+ self.static_cache.update(k, v, li, {"cache_position": cache_pos})
170
+ return seq_len
171
+
172
+ def set_generation_state(self, attention_mask: torch.Tensor, rope_deltas: torch.Tensor | None):
173
+ """Set padding-aware attention mask and rope deltas for decode parity."""
174
+ mask_key = None
175
+ full_attention_mask = None
176
+ if attention_mask is not None:
177
+ pad_counts = (attention_mask == 0).sum(dim=-1)
178
+ mask_key = tuple(pad_counts.tolist())
179
+ full_attention_mask = torch.ones(
180
+ attention_mask.shape[0],
181
+ self.max_seq_len,
182
+ dtype=attention_mask.dtype,
183
+ device=attention_mask.device,
184
+ )
185
+ for b, pads in enumerate(pad_counts.tolist()):
186
+ if pads > 0:
187
+ full_attention_mask[b, :pads] = 0
188
+ if self.attn_mask_table is None or mask_key != self._mask_key:
189
+ self._build_attention_masks(full_attention_mask)
190
+ self._mask_key = mask_key
191
+ if rope_deltas is None:
192
+ self.rope_deltas.zero_()
193
+ else:
194
+ if rope_deltas.dim() == 1:
195
+ rope_deltas = rope_deltas.unsqueeze(1)
196
+ self.rope_deltas.copy_(rope_deltas.to(self.rope_deltas.device, dtype=self.rope_deltas.dtype))
197
+
198
+ @torch.inference_mode()
199
+ def run(self, input_embeds: torch.Tensor, position: int) -> torch.Tensor:
200
+ """
201
+ Run one decode step.
202
+ input_embeds: [1, 1, hidden_size]
203
+ position: current sequence position
204
+ Returns: [1, 1, hidden_size] hidden states
205
+ """
206
+ self.input_buf.copy_(input_embeds)
207
+ self.cache_position[0] = position
208
+ self._set_attention_mask(position)
209
+ # position_ids = arange(seq_len=1) + cache_position + rope_deltas
210
+ delta = self.rope_deltas + self.cache_position[0].to(self.rope_deltas.dtype)
211
+ self.position_ids.copy_(delta.unsqueeze(0).expand(3, -1, -1))
212
+ self.graph.replay()
213
+
214
+ return self.output_buf # static buffer — caller should use immediately or clone
faster_qwen3_tts/utils.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
+ import sys
3
+
4
+
5
+ class _FilteredStdout:
6
+ def __init__(self, stream, suppress_substrings):
7
+ self._stream = stream
8
+ self._suppress = suppress_substrings
9
+
10
+ def write(self, data):
11
+ if any(s in data for s in self._suppress):
12
+ return len(data)
13
+ return self._stream.write(data)
14
+
15
+ def flush(self):
16
+ return self._stream.flush()
17
+
18
+
19
@contextlib.contextmanager
def suppress_flash_attn_warning():
    """Context manager that hides flash-attn notices printed to stdout.

    While active, ``sys.stdout`` is replaced by a ``_FilteredStdout`` proxy that
    discards any write containing one of the fragments below; all other output
    passes through to the stream that was ``sys.stdout`` on entry.
    """
    muted_fragments = (
        "flash-attn is not installed",
        "manual PyTorch version",
        "Please install flash-attn",
    )
    stdout_proxy = _FilteredStdout(sys.stdout, suppress_substrings=muted_fragments)
    with contextlib.redirect_stdout(stdout_proxy):
        yield
main.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PYTORCH_ENABLE_MPS_FALLBACK=1 uvicorn main:app --host 0.0.0.0 --port 8888 --reload
2
+ # PYTORCH_ENABLE_MPS_FALLBACK=1 gunicorn main:app -b 0.0.0.0:8000 -w 4 -k uvicorn.workers.UvicornWorker
3
+ import io
4
+ import re
5
+ import os
6
+ import logging
7
+ import json
8
+ from time import gmtime
9
+ from datetime import datetime, timezone
10
+ from scipy.io import wavfile
11
+ from dotenv import load_dotenv
12
+ from contextlib import asynccontextmanager
13
+ from tts import synthesize, device
14
+ from huggingface_hub import hf_hub_download
15
+ from llama_cpp import Llama
16
+
17
+ from fastapi import FastAPI, Response, Body, UploadFile, HTTPException
18
+ from starlette.middleware.cors import CORSMiddleware
19
+
20
+
21
# Load settings from a .env file (if present) before reading any of them.
load_dotenv(verbose=False)

LOGGING_DIRECTORY = os.getenv('LOGGING_DIRECTORY', 'logs')

# exist_ok=True instead of an isdir() check: several workers importing this
# module at the same time would otherwise race between the check and the create.
os.makedirs(LOGGING_DIRECTORY, exist_ok=True)

# Append-mode file log with UTC timestamps in ISO-8601 form (millisecond precision).
file_handler = logging.FileHandler(os.path.join(LOGGING_DIRECTORY, 'api.log'), mode='a', encoding='utf-8')
formatter = logging.Formatter(fmt='%(asctime)s.%(msecs)03dZ - %(levelname)s - %(message)s', datefmt='%Y-%m-%dT%H:%M:%S')
formatter.converter = gmtime
file_handler.setFormatter(formatter)
# NOTE: the handler is attached to gunicorn's error logger; when serving with
# plain uvicorn, logging.getLogger('uvicorn') is the equivalent logger.
logger = logging.getLogger('gunicorn.error')
logger.addHandler(file_handler)

# Prompt template selector ('Llama' is the only value checked elsewhere in this
# file) and an optional local path to the llama.cpp model file.
llm_prompt_format = os.getenv('LLM_PROMPT_FORMAT', None)
model_path = os.getenv('LLAMACPP_PATH', None)
38
+
39
+
40
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application startup hook: warm up TTS and resolve the LLM model file.

    For every ``data/<language>/`` directory, each ``.wav`` file found there is
    synthesized once together with that directory's ``prompt.txt`` and
    ``input.txt``; the synthesis result is discarded. Afterwards, if no model
    path was configured, the model file is downloaded via ``hf_hub_download``
    using the ``LLAMACPP_REPO_ID`` / ``LLAMACPP_FILENAME`` environment variables.

    Raises:
        KeyError: if a download is required and either environment variable is unset.
        OSError: if a language directory contains a ``.wav`` file but lacks
            ``prompt.txt`` or ``input.txt``.
    """
    global model_path

    base_directory = 'data'

    if os.path.isdir(base_directory):
        for language in os.listdir(base_directory):
            path = os.path.join(base_directory, language)

            if not os.path.isdir(path):
                continue

            for filename in os.listdir(path):
                _, extension = os.path.splitext(filename)

                if extension.lower() != '.wav':
                    continue

                with open(os.path.join(path, filename), mode='rb') as f:
                    wave_data = f.read()

                with open(os.path.join(path, 'prompt.txt'), 'r', encoding='utf-8') as prompt_file:
                    prompt_text = prompt_file.read()

                with open(os.path.join(path, 'input.txt'), 'r', encoding='utf-8') as input_file:
                    input_text = input_file.read()

                with io.BytesIO(wave_data) as wave_bytes:
                    synthesize(prompt_wave=wave_bytes, prompt_text=prompt_text, prompt_language=language, input_text=input_text, input_language=language, top_p=1, temperature=1)

    else:
        # The warm-up data is optional; a missing directory must not prevent the
        # API (including the endpoints that never touch it) from starting.
        logger.warning('Warm-up skipped: directory %r does not exist', base_directory)

    if model_path is None:
        model_path = hf_hub_download(repo_id=os.environ['LLAMACPP_REPO_ID'], filename=os.environ['LLAMACPP_FILENAME'], local_dir='./models')

    yield
64
+
65
+
66
# ASGI application; `lifespan` runs the startup work before requests are served.
app = FastAPI(lifespan=lifespan)

# NOTE(review): every origin, method and header is allowed *and* credentials are
# enabled -- confirm this permissive CORS policy is intended outside development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)
68
+
69
+
70
@app.get("/device")
async def read_device():
    """Return the `device` imported from the `tts` module (as a string) and the current Unix time in whole seconds."""
    now_utc = datetime.now(timezone.utc).replace(tzinfo=timezone.utc)
    payload = {
        'device': str(device),
        'timestamp': int(now_utc.timestamp()),
    }
    return payload
73
+
74
+
75
def _render_chat_prompt(messages, prompt_format):
    """Render chat messages into a raw prompt string.

    Returns a ``(prompt, rendered_turns, reply_pattern)`` tuple: the prompt ends
    with the header that opens a new assistant/model turn, ``rendered_turns``
    counts the messages that contributed text (unknown roles are ignored), and
    ``reply_pattern`` is a compiled regex capturing the text of an
    assistant/model turn up to its end marker or the end of the string.
    """
    if prompt_format == 'Llama':
        role_tags = {'system': 'system', 'user': 'user', 'assistant': 'assistant'}
        turn_template = "<|start_header_id|>{tag}<|end_header_id|>\n\n{content}<|eot_id|>"
        generation_header = '<|start_header_id|>assistant<|end_header_id|>\n\n'
        end_marker = '<|eot_id|>'
    else:
        # This template has no separate system role; system text is sent as a user turn.
        role_tags = {'system': 'user', 'user': 'user', 'assistant': 'model'}
        turn_template = "<start_of_turn>{tag}\n{content}<end_of_turn>\n"
        generation_header = '<start_of_turn>model\n'
        end_marker = '<end_of_turn>'

    rendered = []

    for message in messages:
        tag = role_tags.get(message['role'])

        if tag is not None:
            rendered.append(turn_template.replace('{tag}', tag).replace('{content}', message['content']))

    # The markers contain regex metacharacters ('|' in the Llama tokens), so they
    # must be escaped; unescaped they are parsed as alternation and match fragments.
    reply_pattern = re.compile(re.escape(generation_header) + r'(.+?)(?:' + re.escape(end_marker) + r'|$)', re.DOTALL)

    return ''.join(rendered) + generation_header, len(rendered), reply_pattern


@app.post("/generate", status_code=201)
def create_generated_text(messages: list[dict[str, str]] = Body(...), temperature: float = Body(default=1.0)):
    """Generate an assistant reply for a chat history with a llama.cpp model.

    Args:
        messages: Chat turns as dicts with ``role`` (``system``, ``user`` or
            ``assistant``; other roles are ignored) and ``content`` keys.
        temperature: Sampling temperature forwarded to the model.

    Returns:
        ``{'choices': [{'role': 'assistant', 'content': ...}, ...], 'timestamp': <unix seconds>}``.

    Raises:
        HTTPException: 400 when a message lacks ``role``/``content`` or when no
            message contributes any prompt text.
    """
    try:
        input_text, rendered_turns, pattern = _render_chat_prompt(messages, llm_prompt_format)

    except KeyError as e:
        raise HTTPException(status_code=400, detail=f'message is missing key {e}') from e

    # The generation header is always appended, so the prompt is never empty;
    # count rendered turns to detect a request without usable messages.
    if rendered_turns == 0:
        raise HTTPException(status_code=400)

    # The model is loaded per request and released in `finally`.
    llm = Llama(model_path=model_path, n_ctx=8192, n_gpu_layers=-1, n_batch=32, verbose=False)
    choices = []

    try:
        # echo=True returns prompt + completion, so the last assistant turn in
        # the text is the newly generated reply.
        for choice in llm(input_text, max_tokens=2048, temperature=temperature, top_p=0.95, echo=True)['choices']:
            matches = pattern.findall(choice['text'])

            if matches:
                choices.append({'role': 'assistant', 'content': matches[-1]})

    finally:
        llm.close()

    return {'choices': choices, 'timestamp': int(datetime.now(timezone.utc).timestamp())}
119
+
120
+
121
@app.post("/synthesize", status_code=201)
def create_uploaded_file(file: UploadFile, data = Body(...)):
    """Synthesize speech from an uploaded WAV prompt and JSON parameters.

    Args:
        file: Uploaded prompt audio; its content type must be ``audio/wav``.
        data: JSON-encoded object with required keys ``language`` and ``input``
            and optional keys ``prompt`` (default ``None``), ``top_p`` and
            ``temperature`` (both default ``1.0``).

    Returns:
        A ``Response`` whose body is the synthesized audio as ``audio/wav``.

    Raises:
        HTTPException: 400 for a non-WAV upload, and 400 (with the error text as
            detail) for any failure while parsing the parameters or synthesizing.
    """
    if file.content_type != 'audio/wav':
        raise HTTPException(status_code=400)

    try:
        data = json.loads(data)

        with io.BytesIO(file.file.read()) as prompt_wave_bytes, io.BytesIO() as output_wave_bytes:
            output, sample_rate = synthesize(
                prompt_wave=prompt_wave_bytes,
                prompt_text=data.get('prompt'),
                prompt_language=data['language'],
                input_text=data['input'],
                input_language=data['language'],
                top_p=data.get('top_p', 1.0),
                temperature=data.get('temperature', 1.0),
            )

            wavfile.write(output_wave_bytes, sample_rate, output)

            return Response(content=output_wave_bytes.getvalue(), media_type="audio/wav")

    except Exception as e:
        # Use the module's configured logger (it owns the api.log file handler)
        # and keep the traceback; the root `logging` module has no such handler.
        logger.exception('Speech synthesis request failed')

        raise HTTPException(status_code=400, detail=str(e)) from e
qwen_tts/__init__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 The Alibaba Qwen team.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """
18
+ qwen_tts: Qwen-TTS package.
19
+ """
20
+
21
+ from .inference.qwen3_tts_model import Qwen3TTSModel, VoiceClonePromptItem
22
+ from .inference.qwen3_tts_tokenizer import Qwen3TTSTokenizer
23
+
24
# Public API of the package. Every name listed here must be bound in this module,
# otherwise `from qwen_tts import *` raises AttributeError.
__all__ = ["Qwen3TTSModel", "VoiceClonePromptItem", "Qwen3TTSTokenizer"]
qwen_tts/__main__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 The Alibaba Qwen team.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
def main() -> None:
    """Print a short usage hint naming the package's CLI entry point."""
    usage_hint = (
        "qwen_tts package.\n"
        "Use CLI entrypoints:\n"
        " - qwen-tts-demo\n"
    )
    print(usage_hint)


if __name__ == "__main__":
    main()
qwen_tts/core/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 The Alibaba Qwen team.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ from .tokenizer_25hz.configuration_qwen3_tts_tokenizer_v1 import Qwen3TTSTokenizerV1Config
17
+ from .tokenizer_25hz.modeling_qwen3_tts_tokenizer_v1 import Qwen3TTSTokenizerV1Model
18
+ from .tokenizer_12hz.configuration_qwen3_tts_tokenizer_v2 import Qwen3TTSTokenizerV2Config
19
+ from .tokenizer_12hz.modeling_qwen3_tts_tokenizer_v2 import Qwen3TTSTokenizerV2Model
qwen_tts/core/models/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 The Alibaba Qwen team.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ from .configuration_qwen3_tts import Qwen3TTSConfig
17
+ from .modeling_qwen3_tts import Qwen3TTSForConditionalGeneration
18
+ from .processing_qwen3_tts import Qwen3TTSProcessor
qwen_tts/core/models/configuration_qwen3_tts.py ADDED
@@ -0,0 +1,502 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from transformers.configuration_utils import PretrainedConfig, layer_type_validation
16
+ from transformers.modeling_rope_utils import rope_config_validation
17
+ from transformers.utils import logging
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
+
22
class Qwen3TTSSpeakerEncoderConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen3TTSSpeakerEncoder`].
    It is used to instantiate a Qwen3TTS speaker encoder model according to the specified arguments, defining the model
    architecture. The architecture is based on the ECAPA-TDNN model.

    Args:
        mel_dim (`int`, *optional*, defaults to 128):
            The dimension of the input mel-spectrogram.
        enc_dim (`int`, *optional*, defaults to 1024):
            The dimension of the final speaker embedding.
        enc_channels (`list[int]`, *optional*, defaults to `[512, 512, 512, 512, 1536]`):
            A list of output channels for each TDNN/SERes2Net layer in the encoder. The first channel size is for the
            initial TDNN layer, the intermediate ones for the `SqueezeExcitationRes2NetBlock` layers, and the last one
            for the multi-layer feature aggregation.
        enc_kernel_sizes (`list[int]`, *optional*, defaults to `[5, 3, 3, 3, 1]`):
            A list of kernel sizes for each layer in the encoder, corresponding to `enc_channels`.
        enc_dilations (`list[int]`, *optional*, defaults to `[1, 2, 3, 4, 1]`):
            A list of dilations for each layer in the encoder, corresponding to `enc_channels`.
        enc_attention_channels (`int`, *optional*, defaults to 128):
            The number of attention channels in the `AttentiveStatisticsPooling` layer.
        enc_res2net_scale (`int`, *optional*, defaults to 8):
            The scale of the `Res2NetBlock` in the encoder.
        enc_se_channels (`int`, *optional*, defaults to 128):
            The number of channels in the squeeze part of the `SqueezeExcitationBlock`.
        sample_rate (`int`, *optional*, defaults to 24000):
            Stored as `self.sample_rate`; presumably the audio sample rate in Hz -- TODO confirm against the encoder.
    """
    def __init__(
        self,
        mel_dim=128,
        enc_dim=1024,
        enc_channels=(512, 512, 512, 512, 1536),
        enc_kernel_sizes=(5, 3, 3, 3, 1),
        enc_dilations=(1, 2, 3, 4, 1),
        enc_attention_channels=128,
        enc_res2net_scale=8,
        enc_se_channels=128,
        sample_rate=24000,
    ):
        # NOTE(review): PretrainedConfig.__init__ is not called here (unchanged
        # from the original) -- confirm that this is intentional.
        self.mel_dim = mel_dim
        self.enc_dim = enc_dim
        # Store private list copies so no two config instances (nor the caller's
        # argument) share one mutable sequence.
        self.enc_channels = self._copy_sequence(enc_channels)
        self.enc_kernel_sizes = self._copy_sequence(enc_kernel_sizes)
        self.enc_dilations = self._copy_sequence(enc_dilations)
        self.enc_attention_channels = enc_attention_channels
        self.enc_res2net_scale = enc_res2net_scale
        self.enc_se_channels = enc_se_channels
        self.sample_rate = sample_rate

    @staticmethod
    def _copy_sequence(value):
        """Return a fresh list for list/tuple input; pass any other value through unchanged."""
        return list(value) if isinstance(value, (list, tuple)) else value
68
+
69
+
70
class Qwen3TTSTalkerCodePredictorConfig(PretrainedConfig):
    r"""
    Configuration of a [`Qwen3TTSTalkerCodePredictorModel`]. Instantiating it with the arguments below defines the
    architecture of the Qwen3TTSTalkerCodePredictor model.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 2048):
            Vocabulary size, i.e. the number of different tokens that can be represented by the `inputs_ids` passed
            when calling [`Qwen3TTSTalkerCodePredictorModel`].
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 5):
            Number of hidden layers.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer.
        num_key_value_heads (`int`, *optional*, defaults to 8):
            Number of key/value heads used for Grouped Query Attention. Equal to `num_attention_heads` means Multi
            Head Attention, `1` means Multi Query Attention, anything else GQA (see
            [this paper](https://huggingface.co/papers/2305.13245)). `None` falls back to `num_attention_heads`.
        head_dim (`int`, *optional*, defaults to 128):
            The attention head dimension.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Scaling configuration for the RoPE embeddings. A legacy `type` key is copied to `rope_type` before
            validation. Expected contents:
                `rope_type` (`str`):
                    One of ['default', 'linear', 'dynamic', 'yarn', 'longrope', 'llama3'], with 'default' being the
                    original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. In most scaling types, a `factor` of x will enable the
                    model to handle sequences of length x * original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. Scaling factor applied on the attention computation.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Boundary for extrapolation (only) in the linear ramp function; defaults
                    to 32 if unspecified.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Boundary for interpolation (only) in the linear ramp function; defaults
                    to 1 if unspecified.
                `short_factor` (`list[float]`, *optional*):
                    Only used with 'longrope'. Scaling factor applied to short contexts. Must have length
                    hidden size / number of attention heads / 2.
                `long_factor` (`list[float]`, *optional*):
                    Only used with 'longrope'. Scaling factor applied to long contexts. Must have length
                    hidden size / number of attention heads / 2.
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention. When `False`, `sliding_window` is stored as `None`.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention (SWA) window size.
        max_window_layers (`int`, *optional*, defaults to 28):
            The first `max_window_layers` layers use full attention; later layers use SWA when a sliding window is
            active. Only consulted when `layer_types` is not given.
        layer_types (`list`, *optional*):
            Attention pattern for each layer. Derived from the sliding-window settings when omitted.
        attention_dropout (`float`, *optional*, defaults to 0):
            The dropout ratio for the attention probabilities.
        num_code_groups (`int`, *optional*, defaults to 32):
            Stored as `self.num_code_groups`; presumably the number of codebook groups handled by the predictor --
            TODO confirm against the model code.
    """

    model_type = "qwen3_tts_talker_code_predictor"
    keys_to_ignore_at_inference = ["past_key_values"]

    # Default tensor parallel plan for base model `Qwen3TTSTalkerCodePredictor`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size=2048,
        hidden_size=1024,
        intermediate_size=3072,
        num_hidden_layers=5,
        num_attention_heads=16,
        num_key_value_heads=8,
        head_dim=128,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=0.000001,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000,
        rope_scaling=None,
        attention_bias=False,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        layer_types=None,
        attention_dropout=0,
        num_code_groups=32,
        **kwargs,
    ):
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

        # Core dimensions.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # Sliding-window settings; the window size is kept only when SWA is enabled.
        self.use_sliding_window = use_sliding_window
        if self.use_sliding_window:
            self.sliding_window = sliding_window
        else:
            self.sliding_window = None
        self.max_window_layers = max_window_layers

        # for backward compatibility: no KV head count means plain multi-head attention
        self.num_key_value_heads = num_attention_heads if num_key_value_heads is None else num_key_value_heads

        self.head_dim = head_dim
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        # Validate the rotary position embedding parameters.
        # BC: a legacy 'type' field is mirrored into 'rope_type' first.
        has_legacy_rope_key = self.rope_scaling is not None and "type" in self.rope_scaling
        if has_legacy_rope_key:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        # Derive the per-layer attention pattern when the caller did not give one.
        self.layer_types = layer_types
        if self.layer_types is None:
            derived_types = []
            for layer_idx in range(self.num_hidden_layers):
                if self.sliding_window is not None and layer_idx >= self.max_window_layers:
                    derived_types.append("sliding_attention")
                else:
                    derived_types.append("full_attention")
            self.layer_types = derived_types
        layer_type_validation(self.layer_types)

        self.num_code_groups = num_code_groups
257
+
258
+
259
+ class Qwen3TTSTalkerConfig(PretrainedConfig):
260
+ r"""
261
+ This is the configuration class to store the configuration of a [`Qwen3TTSTalkerModel`]. It is used to instantiate a
262
+ Qwen3TTSTalker model according to the specified arguments, defining the model architecture.
263
+
264
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
265
+ documentation from [`PretrainedConfig`] for more information.
266
+
267
+
268
+ Args:
269
+ vocab_size (`int`, *optional*, defaults to 151936):
270
+ Vocabulary size of the Qwen3TTSTalker model. Defines the number of different tokens that can be represented by the
271
+ `inputs_ids` passed when calling [`Qwen3TTSTalkerModel`]
272
+ hidden_size (`int`, *optional*, defaults to 2048):
273
+ Dimension of the hidden representations.
274
+ intermediate_size (`int`, *optional*, defaults to 6144):
275
+ Dimension of the MLP representations.
276
+ num_hidden_layers (`int`, *optional*, defaults to 24):
277
+ Number of hidden layers in the Transformer encoder.
278
+ num_attention_heads (`int`, *optional*, defaults to 32):
279
+ Number of attention heads for each attention layer in the Transformer encoder.
280
+ num_key_value_heads (`int`, *optional*, defaults to 4):
281
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
282
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
283
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
284
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
285
+ by meanpooling all the original heads within that group. For more details, check out [this
286
+ paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`.
287
+
288
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
289
+ The non-linear activation function (function or string) in the decoder.
290
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
291
+ The maximum sequence length that this model might ever be used with.
292
+ initializer_range (`float`, *optional*, defaults to 0.02):
293
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
294
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
295
+ The epsilon used by the rms normalization layers.
296
+ use_cache (`bool`, *optional*, defaults to `True`):
297
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
298
+ relevant if `config.is_decoder=True`.
299
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
300
+ Whether the model's input and output word embeddings should be tied.
301
+ rope_theta (`float`, *optional*, defaults to 10000.0):
302
+ The base period of the RoPE embeddings.
303
+ rope_scaling (`Dict`, *optional*):
304
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
305
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
306
+ accordingly.
307
+ Expected contents:
308
+ `rope_type` (`str`):
309
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
310
+ 'llama3'], with 'default' being the original RoPE implementation.
311
+ `factor` (`float`, *optional*):
312
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
313
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
314
+ original maximum pre-trained length.
315
+ `original_max_position_embeddings` (`int`, *optional*):
316
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
317
+ pretraining.
318
+ `attention_factor` (`float`, *optional*):
319
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
320
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
321
+ `factor` field to infer the suggested value.
322
+ `beta_fast` (`float`, *optional*):
323
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
324
+ ramp function. If unspecified, it defaults to 32.
325
+ `beta_slow` (`float`, *optional*):
326
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
327
+ ramp function. If unspecified, it defaults to 1.
328
+ `short_factor` (`list[float]`, *optional*):
329
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
330
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
331
+ size divided by the number of attention heads divided by 2
332
+ `long_factor` (`list[float]`, *optional*):
333
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
334
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
335
+ size divided by the number of attention heads divided by 2
336
+ `low_freq_factor` (`float`, *optional*):
337
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
338
+ `high_freq_factor` (`float`, *optional*):
339
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
340
+ attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
341
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
342
+ use_sliding_window (`bool`, *optional*, defaults to `False`):
343
+ Whether to use sliding window attention.
344
+ sliding_window (`int`, *optional*, defaults to 4096):
345
+ Sliding window attention (SWA) window size. If not specified, will default to `4096`.
346
+ attention_dropout (`float`, *optional*, defaults to 0.0):
347
+ The dropout ratio for the attention probabilities.
348
+ """
349
+
350
+ model_type = "qwen3_tts_talker"
351
+ keys_to_ignore_at_inference = ["past_key_values"]
352
+
353
+ # Default tensor parallel plan for base model `Qwen3TTSTalker`
354
+ base_model_tp_plan = {
355
+ "layers.*.self_attn.q_proj": "colwise",
356
+ "layers.*.self_attn.k_proj": "colwise",
357
+ "layers.*.self_attn.v_proj": "colwise",
358
+ "layers.*.self_attn.o_proj": "rowwise",
359
+ "layers.*.mlp.gate_proj": "colwise",
360
+ "layers.*.mlp.up_proj": "colwise",
361
+ "layers.*.mlp.down_proj": "rowwise",
362
+ }
363
+ base_model_pp_plan = {
364
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
365
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
366
+ "norm": (["hidden_states"], ["hidden_states"]),
367
+ }
368
+ sub_configs = {"code_predictor_config": Qwen3TTSTalkerCodePredictorConfig}
369
+
370
def __init__(
    self,
    code_predictor_config=None,
    vocab_size=3072,
    hidden_size=1024,
    intermediate_size=2048,
    num_hidden_layers=20,
    num_attention_heads=16,
    num_key_value_heads=2,
    hidden_act="silu",
    max_position_embeddings=32768,
    initializer_range=0.02,
    rms_norm_eps=0.000001,
    use_cache=True,
    tie_word_embeddings=False,
    rope_theta=10000,
    rope_scaling=None,
    attention_bias=False,
    use_sliding_window=False,
    sliding_window=4096,
    attention_dropout=0,
    num_code_groups=32,
    text_hidden_size=2048,
    codec_eos_token_id=4198,
    codec_think_id=4202,
    codec_nothink_id=4203,
    codec_think_bos_id=4204,
    codec_think_eos_id=4205,
    codec_pad_id=4196,
    codec_bos_id=4197,
    spk_id=None,
    spk_is_dialect=None,
    codec_language_id=None,
    **kwargs,
):
    """Store the talker hyper-parameters on the configuration object.

    Every argument is saved under an attribute of the same name, with three exceptions:

    * `sliding_window` is stored as `None` unless `use_sliding_window` is truthy.
    * If `rope_scaling` contains the legacy key `"type"`, its value is copied to
      `"rope_type"` (the dict is updated in place).
    * `code_predictor_config` may be `None` (a default
      `Qwen3TTSTalkerCodePredictorConfig` is created), an existing
      `Qwen3TTSTalkerCodePredictorConfig` (used as-is), or a mapping of keyword
      arguments for `Qwen3TTSTalkerCodePredictorConfig`.

    `tie_word_embeddings` and any extra `kwargs` are forwarded to the parent
    constructor.
    """
    super().__init__(
        tie_word_embeddings=tie_word_embeddings,
        **kwargs,
    )
    self.vocab_size = vocab_size
    self.max_position_embeddings = max_position_embeddings
    self.hidden_size = hidden_size
    self.intermediate_size = intermediate_size
    self.num_hidden_layers = num_hidden_layers
    self.num_attention_heads = num_attention_heads
    self.use_sliding_window = use_sliding_window
    # The window size is only meaningful when sliding-window attention is enabled.
    self.sliding_window = sliding_window if use_sliding_window else None

    self.num_key_value_heads = num_key_value_heads
    self.hidden_act = hidden_act
    self.initializer_range = initializer_range
    self.rms_norm_eps = rms_norm_eps
    self.use_cache = use_cache
    self.rope_theta = rope_theta
    self.rope_scaling = rope_scaling
    self.attention_bias = attention_bias
    self.attention_dropout = attention_dropout
    # BC: if there is a 'type' field, mirror it into 'rope_type'.
    if self.rope_scaling is not None and "type" in self.rope_scaling:
        self.rope_scaling["rope_type"] = self.rope_scaling["type"]

    if code_predictor_config is None:
        # No settings supplied: fall back to the sub-config's own defaults.
        self.code_predictor_config = Qwen3TTSTalkerCodePredictorConfig()
        logger.info("code_predictor_config is None. Initializing code_predictor model with default values")
    elif isinstance(code_predictor_config, Qwen3TTSTalkerCodePredictorConfig):
        self.code_predictor_config = code_predictor_config
    else:
        self.code_predictor_config = Qwen3TTSTalkerCodePredictorConfig(**code_predictor_config)

    self.num_code_groups = num_code_groups
    self.text_hidden_size = text_hidden_size
    self.codec_eos_token_id = codec_eos_token_id
    self.codec_think_id = codec_think_id
    self.codec_language_id = codec_language_id
    self.codec_nothink_id = codec_nothink_id
    self.codec_think_bos_id = codec_think_bos_id
    self.codec_think_eos_id = codec_think_eos_id
    self.codec_pad_id = codec_pad_id
    self.codec_bos_id = codec_bos_id
    self.spk_id = spk_id
    self.spk_is_dialect = spk_is_dialect
452
+
453
+
454
class Qwen3TTSConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`Qwen3TTSForConditionalGeneration`].

    It is a composite configuration holding one `Qwen3TTSTalkerConfig` and one
    `Qwen3TTSSpeakerEncoderConfig`, plus a few top-level fields.

    Args:
        talker_config (`dict` or `Qwen3TTSTalkerConfig`, *optional*):
            Settings of the talker sub-model. `None` uses the sub-config defaults.
        speaker_encoder_config (`dict` or `Qwen3TTSSpeakerEncoderConfig`, *optional*):
            Settings of the speaker encoder sub-model. `None` uses the sub-config defaults.
        tokenizer_type (*optional*):
            Stored unchanged on the config.
        tts_model_size (*optional*):
            Stored unchanged on the config.
        tts_model_type (*optional*):
            Stored unchanged on the config.
        im_start_token_id (`int`, *optional*, defaults to 151644):
            Stored as `im_start_token_id`.
        im_end_token_id (`int`, *optional*, defaults to 151645):
            Stored as `im_end_token_id`.
        tts_pad_token_id (`int`, *optional*, defaults to 151671):
            Stored as `tts_pad_token_id`.
        tts_bos_token_id (`int`, *optional*, defaults to 151672):
            Stored as `tts_bos_token_id`.
        tts_eos_token_id (`int`, *optional*, defaults to 151673):
            Stored as `tts_eos_token_id`.
    """

    model_type = "qwen3_tts"
    sub_configs = {
        "talker_config": Qwen3TTSTalkerConfig,
        "speaker_encoder_config": Qwen3TTSSpeakerEncoderConfig,
    }

    def __init__(
        self,
        talker_config=None,
        speaker_encoder_config=None,
        tokenizer_type=None,
        tts_model_size=None,
        tts_model_type=None,
        im_start_token_id=151644,
        im_end_token_id=151645,
        tts_pad_token_id=151671,
        tts_bos_token_id=151672,
        tts_eos_token_id=151673,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if talker_config is None:
            talker_config = {}
            logger.info("talker_config is None. Initializing talker model with default values")
        if speaker_encoder_config is None:
            speaker_encoder_config = {}
            logger.info("speaker_encoder_config is None. Initializing speaker encoder with default values")

        # Accept ready-made config objects as well as plain keyword mappings, so that a
        # caller holding an instance does not hit a `**` unpacking error.
        if isinstance(talker_config, Qwen3TTSTalkerConfig):
            self.talker_config = talker_config
        else:
            self.talker_config = Qwen3TTSTalkerConfig(**talker_config)

        if isinstance(speaker_encoder_config, Qwen3TTSSpeakerEncoderConfig):
            self.speaker_encoder_config = speaker_encoder_config
        else:
            self.speaker_encoder_config = Qwen3TTSSpeakerEncoderConfig(**speaker_encoder_config)

        self.tokenizer_type = tokenizer_type
        self.tts_model_size = tts_model_size
        self.tts_model_type = tts_model_type

        self.im_start_token_id = im_start_token_id
        self.im_end_token_id = im_end_token_id
        self.tts_pad_token_id = tts_pad_token_id
        self.tts_bos_token_id = tts_bos_token_id
        self.tts_eos_token_id = tts_eos_token_id
500
+
501
+
502
+ __all__ = ["Qwen3TTSConfig", "Qwen3TTSTalkerConfig", "Qwen3TTSSpeakerEncoderConfig"]
qwen_tts/core/models/modeling_qwen3_tts.py ADDED
The diff for this file is too large to render. See raw diff
 
qwen_tts/core/models/processing_qwen3_tts.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from transformers.feature_extraction_utils import BatchFeature
16
+ from transformers.processing_utils import ProcessingKwargs, ProcessorMixin
17
+
18
+
19
# Default keyword arguments for `Qwen3TTSProcessor`. The processor hands this class to
# `_merge_kwargs` and forwards the merged "text_kwargs" entry to the tokenizer call.
class Qwen3TTSProcessorKwargs(ProcessingKwargs, total=False):
    _defaults = {
        "text_kwargs": {
            # No padding unless the caller asks for it.
            "padding": False,
            # Side on which padding is added when it is enabled.
            "padding_side": "left",
        }
    }
26
+
27
class Qwen3TTSProcessor(ProcessorMixin):
    r"""
    Constructs a Qwen3TTS processor.

    The processor only wraps a text tokenizer: calling it tokenizes text, and `decode` /
    `batch_decode` are forwarded to the tokenizer.

    Args:
        tokenizer ([`Qwen2TokenizerFast`], *optional*):
            The text tokenizer.
        chat_template (`Optional[str]`, *optional*):
            The Jinja template to use for formatting the conversation. If not provided, the default chat template is used.
    """

    attributes = ["tokenizer"]
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(
        self, tokenizer=None, chat_template=None
    ):
        super().__init__(tokenizer, chat_template=chat_template)

    def __call__(self, text=None, **kwargs) -> BatchFeature:
        """
        Tokenize one or several text sequences. This method forwards `text` and the merged text keyword
        arguments to the tokenizer's `__call__` and wraps the result in a [`BatchFeature`].

        Args:
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).

        Returns:
            [`BatchFeature`]: the tokenizer outputs; `tensor_type` is taken from the `return_tensors` keyword
            argument when it is given.

        Raises:
            ValueError: If `text` is `None`.
        """

        if text is None:
            raise ValueError("You need to specify a `text` input to process.")

        output_kwargs = self._merge_kwargs(
            Qwen3TTSProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        # A single (non-list) input is wrapped so the tokenizer always receives a batch.
        if not isinstance(text, list):
            text = [text]

        texts_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])

        return BatchFeature(
            data={**texts_inputs},
            tensor_type=kwargs.get("return_tensors"),
        )

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    def apply_chat_template(self, conversations, chat_template=None, **kwargs):
        """
        Apply the chat template to one conversation or a batch of conversations.

        A single conversation (a list whose first element is a `dict`) is wrapped into a one-element batch
        before being passed to the parent implementation.
        """
        if isinstance(conversations[0], dict):
            conversations = [conversations]
        return super().apply_chat_template(conversations, chat_template, **kwargs)

    @property
    def model_input_names(self):
        """Tokenizer input names with duplicates removed, in their original order."""
        tokenizer_input_names = self.tokenizer.model_input_names
        return list(dict.fromkeys(tokenizer_input_names))
104
+
105
+
106
+ __all__ = ["Qwen3TTSProcessor"]
qwen_tts/core/tokenizer_12hz/configuration_qwen3_tts_tokenizer_v2.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Qwen3TTSTokenizerV2 model configuration"""
16
+
17
+ from transformers.configuration_utils import PretrainedConfig
18
+ from transformers.utils import logging
19
+
20
+ from transformers import MimiConfig
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
class Qwen3TTSTokenizerV2DecoderConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of the decoder sub-model of
    Qwen3TTSTokenizerV2. It is used as the `decoder_config` of [`Qwen3TTSTokenizerV2Config`].

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        codebook_size (`int`, *optional*, defaults to 2048):
            Number of entries in each residual codebook used for acoustic token quantization.
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimensionality of the hidden states and embeddings in the autoregressive transformer decoder.
        latent_dim (`int`, *optional*, defaults to 1024):
            Stored unchanged on the config. Presumably the dimensionality of the latent representation used by
            the decoder; confirm against the modeling code.
        max_position_embeddings (`int`, *optional*, defaults to 8000):
            Maximum sequence length that the autoregressive decoder can handle. Determines positional embedding size.
        rope_theta (`float`, *optional*, defaults to 10000):
            The base period for rotary position embeddings (RoPE) applied to attention layers.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the decoder.
        num_key_value_heads (`int`, *optional*, defaults to 16):
            Number of key and value attention heads used in grouped-query attention (if applicable).
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in the attention projection layers.
        sliding_window (`int`, *optional*, defaults to 72):
            Window size for local attention mechanism, limiting attention context to improve efficiency.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the feed-forward (intermediate) layer in each transformer block.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function used in the feed-forward layers. Supports `"silu"`, `"relu"`, `"gelu"`, etc.
        layer_scale_initial_scale (`float`, *optional*, defaults to 0.01):
            Initial value for LayerScale applied in transformer blocks, helping stabilize training.
        rms_norm_eps (`float`, *optional*, defaults to 1e-5):
            Epsilon value for RMS normalization layers to prevent division by zero.
        num_hidden_layers (`int`, *optional*, defaults to 8):
            Number of transformer blocks in the autoregressive decoder.
        num_quantizers (`int`, *optional*, defaults to 16):
            Number of residual vector quantizers used in the vocoder for fine-grained audio reconstruction.
        upsample_rates (`Tuple[int]`, *optional*, defaults to `(8, 5, 4, 3)`):
            Rate at which features are upsampled in the final waveform synthesis stage.
        upsampling_ratios (`Tuple[int]`, *optional*, defaults to `(2, 2)`):
            Ratios used in transposed convolutional layers to progressively upsample feature maps to waveform.
        decoder_dim (`int`, *optional*, defaults to 1536):
            Final dimensionality of the decoder's output before waveform generation.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability applied to attention weights in the decoder.
    """

    def __init__(
        self,
        codebook_size=2048,
        hidden_size=1024,
        latent_dim=1024,
        max_position_embeddings=8000,
        rope_theta=10000,
        num_attention_heads=16,
        num_key_value_heads=16,
        attention_bias=False,
        sliding_window=72,
        intermediate_size=3072,
        hidden_act="silu",
        layer_scale_initial_scale=0.01,
        rms_norm_eps=1e-5,
        num_hidden_layers=8,
        num_quantizers=16,
        upsample_rates=(8, 5, 4, 3),
        upsampling_ratios=(2, 2),
        decoder_dim=1536,
        attention_dropout=0.0,
        **kwargs,
    ):
        # Extra keyword arguments are handled by the `PretrainedConfig` base class.
        super().__init__(**kwargs)
        # Each argument is stored under an attribute of the same name, without validation.
        self.codebook_size = codebook_size
        self.hidden_size = hidden_size
        self.latent_dim = latent_dim
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.attention_bias = attention_bias
        self.sliding_window = sliding_window
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.layer_scale_initial_scale = layer_scale_initial_scale
        self.rms_norm_eps = rms_norm_eps
        self.num_hidden_layers = num_hidden_layers
        self.num_quantizers = num_quantizers
        self.upsample_rates = upsample_rates
        self.upsampling_ratios = upsampling_ratios
        self.decoder_dim = decoder_dim
        self.attention_dropout = attention_dropout

    @property
    def layer_types(self):
        """
        Attention type of every decoder layer: a list of `num_hidden_layers` entries, all
        `"sliding_attention"` (every layer in code2wav uses sliding attention).
        """
        return ["sliding_attention"] * self.num_hidden_layers
122
+
123
+
124
class Qwen3TTSTokenizerV2Config(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a Qwen3TTSTokenizerV2Model. It is used to
    instantiate the model according to the specified sub-model configurations, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        encoder_config (`dict` or `MimiConfig`, *optional*):
            Configuration of the underlying encoder sub-model. `None` uses the `MimiConfig` defaults.
        decoder_config (`dict` or `Qwen3TTSTokenizerV2DecoderConfig`, *optional*):
            Configuration of the underlying decoder sub-model. `None` uses the decoder config defaults.
        encoder_valid_num_quantizers (`int`, *optional*, defaults to 16):
            Stored unchanged on the config.
        input_sample_rate (`int`, *optional*, defaults to 24000):
            Stored unchanged on the config.
        output_sample_rate (`int`, *optional*, defaults to 24000):
            Stored unchanged on the config.
        decode_upsample_rate (`int`, *optional*, defaults to 1920):
            Stored unchanged on the config.
        encode_downsample_rate (`int`, *optional*, defaults to 1920):
            Stored unchanged on the config.
    """

    model_type = "qwen3_tts_tokenizer_12hz"
    sub_configs = {
        "encoder_config": MimiConfig,
        "decoder_config": Qwen3TTSTokenizerV2DecoderConfig,
    }

    def __init__(
        self,
        encoder_config=None,
        decoder_config=None,
        encoder_valid_num_quantizers=16,
        input_sample_rate=24000,
        output_sample_rate=24000,
        decode_upsample_rate=1920,
        encode_downsample_rate=1920,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if encoder_config is None:
            encoder_config = {}
            logger.info("encoder_config is None. Initializing encoder with default values")
        if decoder_config is None:
            decoder_config = {}
            logger.info("decoder_config is None. Initializing decoder with default values")

        # Accept ready-made config objects as well as plain keyword mappings, so that a
        # caller holding an instance does not hit a `**` unpacking error.
        if isinstance(encoder_config, MimiConfig):
            self.encoder_config = encoder_config
        else:
            self.encoder_config = MimiConfig(**encoder_config)

        if isinstance(decoder_config, Qwen3TTSTokenizerV2DecoderConfig):
            self.decoder_config = decoder_config
        else:
            self.decoder_config = Qwen3TTSTokenizerV2DecoderConfig(**decoder_config)

        self.encoder_valid_num_quantizers = encoder_valid_num_quantizers
        self.input_sample_rate = input_sample_rate
        self.output_sample_rate = output_sample_rate
        self.decode_upsample_rate = decode_upsample_rate
        self.encode_downsample_rate = encode_downsample_rate
170
+
171
+
172
+ __all__ = ["Qwen3TTSTokenizerV2Config", "Qwen3TTSTokenizerV2DecoderConfig"]
qwen_tts/core/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py ADDED
@@ -0,0 +1,1025 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Qwen3TTSTokenizerV2 model."""
16
+
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Callable, Optional, Union, List
20
+
21
+ import numpy as np
22
+ import torch
23
+ from torch import nn
24
+ from torch.nn import Parameter
25
+ from torch.nn import functional as F
26
+ from transformers import MimiConfig, MimiModel
27
+ from transformers.activations import ACT2FN
28
+ from transformers.cache_utils import Cache, DynamicCache
29
+ from transformers.integrations import use_kernel_forward_from_hub
30
+ from transformers.masking_utils import (
31
+ create_causal_mask,
32
+ create_sliding_window_causal_mask,
33
+ )
34
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
35
+ from transformers.modeling_layers import GradientCheckpointingLayer
36
+ from transformers.modeling_outputs import BaseModelOutputWithPast
37
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
38
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
39
+ from transformers.processing_utils import Unpack
40
+ from transformers.utils import ModelOutput, auto_docstring, logging
41
+ from transformers.utils.deprecation import deprecate_kwarg
42
+ from transformers.utils.generic import check_model_inputs
43
+
44
+ from .configuration_qwen3_tts_tokenizer_v2 import (
45
+ Qwen3TTSTokenizerV2Config,
46
+ Qwen3TTSTokenizerV2DecoderConfig,
47
+ )
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+
52
@dataclass
@auto_docstring
class Qwen3TTSTokenizerV2EncoderOutput(ModelOutput):
    r"""
    audio_codes (`List[torch.LongTensor]`):
        Discrete audio codes computed using `model.encode`, each tensor has shape (codes_length_i, num_quantizers).
    """

    # One tensor per encoded item; defaults to None.
    audio_codes: List[torch.LongTensor] = None
61
+
62
+
63
@dataclass
@auto_docstring
class Qwen3TTSTokenizerV2DecoderOutput(ModelOutput):
    r"""
    audio_values (`List[torch.FloatTensor]`):
        Decoded audio values, obtained using the decoder part of Qwen3TTSTokenizerV2.
        Each tensor has shape (segment_length_i).
    """

    # One waveform tensor per decoded item; defaults to None.
    audio_values: List[torch.FloatTensor] = None
73
+
74
+
75
def rotate_half(x):
    """Swap the two halves of the last dimension and negate the half that moves to the front.

    The last dimension is split at `x.shape[-1] // 2`; the result is
    `concat(-second_half, first_half)` along that dimension.
    """
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat((-back, front), dim=-1)
80
+
81
+
82
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate the query and key tensors with the given rotary position embedding.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension at which a singleton axis is inserted into `cos` and `sin` so that they broadcast
            against `q` and `k`. For example, with `cos`/`sin` of shape [batch_size, seq_len, head_dim],
            use `unsqueeze_dim=1` when `q` and `k` are [batch_size, heads, seq_len, head_dim], and
            `unsqueeze_dim=2` when they are [batch_size, seq_len, heads, head_dim].

    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors, in that order.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    rotated = tuple((t * cos_b) + (rotate_half(t) * sin_b) for t in (q, k))
    return rotated[0], rotated[1]
107
+
108
+
109
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head dimension.

    Equivalent to `torch.repeat_interleave(x, dim=1, repeats=n_rep)`: the input goes from
    (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    With `n_rep == 1` the input tensor itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the head axis, broadcast it, then fold it into the heads.
    tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return tiled.reshape(bsz, kv_heads * n_rep, seq_len, dim)
119
+
120
+
121
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Plain (non-fused) scaled dot-product attention.

    Args:
        module: Attention module; its `num_key_value_groups` and `training` attributes are read.
        query: Query tensor.
        key: Key tensor; its heads are repeated `module.num_key_value_groups` times.
        value: Value tensor; repeated like `key`.
        attention_mask: Optional additive mask, cropped on its last axis to the key length.
        scaling: Factor applied to the raw attention scores.
        dropout: Dropout probability applied to the attention weights when `module.training` is set.
        **kwargs: Accepted and ignored.

    Returns:
        `(attn_output, attn_weights)`, where `attn_output` has axes 1 and 2 swapped and is contiguous.
    """
    groups = module.num_key_value_groups
    key_states = repeat_kv(key, groups)
    value_states = repeat_kv(value, groups)

    scores = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        kv_len = key_states.shape[-2]
        scores = scores + attention_mask[:, :, :, :kv_len]

    # Softmax is computed in float32 and cast back to the query dtype.
    attn_weights = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states).transpose(1, 2).contiguous()

    return attn_output, attn_weights
145
+
146
+
147
# Base `PreTrainedModel` subclass typed with `Qwen3TTSTokenizerV2DecoderConfig`. It defines no
# methods here; it only declares class-level flags. How each flag is interpreted is decided by
# the transformers framework and is not visible in this file.
@auto_docstring
class Qwen3TTSTokenizerV2DecoderPreTrainedModel(PreTrainedModel):
    config: Qwen3TTSTokenizerV2DecoderConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    # Attention backends declared as supported.
    _supports_flash_attn = True
    _supports_sdpa = True
    # Full-graph compilation is declared as unsupported.
    _can_compile_fullgraph = False
    _supports_attention_backend = True
157
+
158
+
159
class Qwen3TTSTokenizerV2CausalConvNet(nn.Module):
    """1-D convolution whose input is zero-padded mostly on the left.

    `kernel_size` is stored as the effective (dilated) kernel width
    `(kernel_size - 1) * dilation + 1`, and `padding` as that width minus the stride.
    In `forward`, `padding` zeros are added on the left, plus however many zeros are
    needed on the right for the last convolution frame to be complete.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation=1,
        stride=1,
        groups=1,
    ):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
        )
        effective_kernel = (kernel_size - 1) * dilation + 1
        self.stride = stride
        self.kernel_size = effective_kernel
        self.dilation = dilation
        self.padding = effective_kernel - stride

    def _get_extra_padding_for_conv1d(self, hidden_state: torch.Tensor) -> int:
        """Return the right-side padding that makes the final frame complete."""
        seq_len = hidden_state.shape[-1]
        # Possibly fractional number of frames for the left-padded input.
        frames = (seq_len - self.kernel_size + self.padding) / self.stride + 1
        target_len = (math.ceil(frames) - 1) * self.stride + (self.kernel_size - self.padding)
        return target_len - seq_len

    def forward(self, hidden_state):
        right_pad = self._get_extra_padding_for_conv1d(hidden_state)
        padded = F.pad(hidden_state, (self.padding, right_pad), mode="constant", value=0)
        return self.conv(padded).contiguous()
193
+
194
+
195
class Qwen3TTSTokenizerV2CausalTransConvNet(nn.Module):
    """Transposed 1-D convolution followed by trimming of both ends of the output.

    `kernel_size - stride` samples are removed from the left and the same number from
    the right of the transposed-convolution output.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
        super().__init__()
        self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride=stride)

        trim = math.ceil(kernel_size - stride)
        self.left_pad = trim
        # NOTE(review): the original statement was `self.right_pad = pad = self.left_pad`,
        # which makes both trims equal. It may have been meant as `pad - self.left_pad`;
        # the existing behaviour is kept because changing it would alter output lengths.
        self.right_pad = trim

    def forward(self, hidden_state):
        upsampled = self.conv(hidden_state)
        end = upsampled.shape[-1] - self.right_pad
        return upsampled[..., self.left_pad : end].contiguous()
208
+
209
+
210
class Qwen3TTSTokenizerV2ConvNeXtBlock(nn.Module):
    """Residual block: depthwise causal conv, LayerNorm, two linear layers with GELU, learned scale.

    The convolution runs on the input layout; the tensor is then permuted with `(0, 2, 1)`
    so that LayerNorm and the linear layers act on the last axis, and permuted back before
    the residual addition.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dwconv = Qwen3TTSTokenizerV2CausalConvNet(
            dim,
            dim,
            kernel_size=7,
            groups=dim,
            dilation=1,
        )
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        # Per-channel scale of the residual branch, initialised to 1e-6.
        self.gamma = nn.Parameter(1e-6 * torch.ones(dim))

    def forward(self, hidden_states):
        residual = hidden_states

        branch = self.dwconv(hidden_states).permute(0, 2, 1)
        branch = self.pwconv2(self.act(self.pwconv1(self.norm(branch))))
        branch = (self.gamma * branch).permute(0, 2, 1)

        return residual + branch
243
+
244
+
245
class Qwen3TTSTokenizerV2DecoderRotatoryEmbedding(nn.Module):
    """Computes the cosine and sine tables of the rotary position embedding.

    The inverse frequencies come from the `ROPE_INIT_FUNCTIONS` entry selected by the
    config's `rope_scaling` dict (`"rope_type"`, falling back to the legacy `"type"` key),
    or from the `"default"` entry when the config has no `rope_scaling` dict.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: Qwen3TTSTokenizerV2DecoderConfig, device=None):
        super().__init__()
        rope_scaling = getattr(config, "rope_scaling", None)
        if isinstance(rope_scaling, dict):
            # BC: "rope_type" was originally "type"
            self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
        else:
            self.rope_type = "default"

        max_positions = config.max_position_embeddings
        self.max_seq_len_cached = max_positions
        self.original_max_seq_len = max_positions

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Non-persistent buffer: kept out of the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return `(cos, sin)` for `position_ids`, cast to the dtype of `x`."""
        batch = position_ids.shape[0]
        inv_freq = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        positions = position_ids[:, None, :].float()

        if isinstance(x.device.type, str) and x.device.type != "mps":
            autocast_device = x.device.type
        else:
            autocast_device = "cpu"

        with torch.autocast(device_type=autocast_device, enabled=False):  # Force float32
            angles = (inv_freq.float() @ positions.float()).transpose(1, 2)
            emb = torch.cat((angles, angles), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
279
+
280
+
281
class Qwen3TTSTokenizerV2DecoderAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings.

    Query/key/value projections are sized from the config's head counts, so the
    number of key/value heads may be smaller than the number of query heads.
    ``q_norm`` and ``k_norm`` are identity modules in this variant.
    """

    def __init__(self, config: Qwen3TTSTokenizerV2DecoderConfig, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        # Fall back to hidden_size / num_heads when the config has no explicit head_dim.
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        model_width = config.hidden_size
        query_width = config.num_attention_heads * self.head_dim
        key_value_width = config.num_key_value_heads * self.head_dim
        use_bias = config.attention_bias

        self.q_proj = nn.Linear(model_width, query_width, bias=use_bias)
        self.k_proj = nn.Linear(model_width, key_value_width, bias=use_bias)
        self.v_proj = nn.Linear(model_width, key_value_width, bias=use_bias)
        self.o_proj = nn.Linear(query_width, model_width, bias=use_bias)
        self.q_norm = nn.Identity()
        self.k_norm = nn.Identity()
        self.sliding_window = config.sliding_window

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: Input activations; the last dimension is projected to heads.
            position_embeddings: ``(cos, sin)`` rotary tables applied to queries and keys.
            attention_mask: Mask forwarded unchanged to the attention backend.
            past_key_values: Optional cache; when given it is updated with this layer's
                keys/values and the returned (cached) tensors are attended over.
            cache_position: Positions forwarded to the cache update.
            **kwargs: Extra arguments passed through to the attention backend.

        Returns:
            ``(attn_output, attn_weights)`` as produced by the backend, with the output
            passed through ``o_proj``.
        """
        leading_shape = hidden_states.shape[:-1]
        per_head_shape = (*leading_shape, -1, self.head_dim)

        projected_q = self.q_proj(hidden_states).view(per_head_shape)
        projected_k = self.k_proj(hidden_states).view(per_head_shape)
        projected_v = self.v_proj(hidden_states).view(per_head_shape)

        query_states = self.q_norm(projected_q).transpose(1, 2)
        key_states = self.k_norm(projected_k).transpose(1, 2)
        value_states = projected_v.transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            key_states, value_states = past_key_values.update(
                key_states,
                value_states,
                self.layer_idx,
                {"sin": sin, "cos": cos, "cache_position": cache_position},
            )

        implementation = self.config._attn_implementation
        if implementation == "eager":
            attention_interface: Callable = eager_attention_forward
        else:
            attention_interface = ALL_ATTENTION_FUNCTIONS[implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window,  # diff with Llama
            **kwargs,
        )

        merged_heads = attn_output.reshape(*leading_shape, -1).contiguous()
        return self.o_proj(merged_heads), attn_weights
354
+
355
+
356
class Qwen3TTSTokenizerV2DecoderMlp(nn.Module):
    """Gated feed-forward block: ``down_proj(act(gate_proj(x)) * up_proj(x))``, all bias-free."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        narrow, wide = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(narrow, wide, bias=False)
        self.up_proj = nn.Linear(narrow, wide, bias=False)
        self.down_proj = nn.Linear(wide, narrow, bias=False)
        # Activation is looked up by name from the config.
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the gated projection to ``x`` and return a tensor of the input width."""
        gate = self.act_fn(self.gate_proj(x))
        return self.down_proj(gate * self.up_proj(x))
370
+
371
+
372
@use_kernel_forward_from_hub("RMSNorm")
class Qwen3TTSTokenizerV2DecoderRMSNorm(nn.Module):
    """Root-mean-square layer norm with a learned per-channel weight and no bias."""

    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        """
        Qwen3TTSTokenizerV2DecoderRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalize over the last dimension in float32, then rescale in the input dtype."""
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        # Cast back before applying the weight, so the weight multiply runs in the input dtype.
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
391
+
392
+
393
class Qwen3TTSTokenizerV2DecoderLayerScale(nn.Module):
    """Layer scale from [Touvron et al 2021] (https://huggingface.co/papers/2103.17239).

    Multiplies its input by a learned per-channel vector. The vector has
    ``config.hidden_size`` entries, all initialised to
    ``config.layer_scale_initial_scale``.
    """

    def __init__(self, config):
        super().__init__()
        initial_values = torch.full(
            (config.hidden_size,),
            config.layer_scale_initial_scale,
            requires_grad=True,
        )
        self.scale = nn.Parameter(initial_values)

    def forward(self, x: torch.Tensor):
        """Return ``x`` scaled channel-wise (broadcast over the last dimension)."""
        return torch.mul(self.scale, x)
406
+
407
+
408
class Qwen3TTSTokenizerV2DecoderTransformerLayer(GradientCheckpointingLayer):
    """Pre-norm transformer layer whose attention and MLP branches are layer-scaled.

    Each branch is ``residual + layer_scale(branch(rms_norm(x)))``.
    """

    def __init__(self, config: Qwen3TTSTokenizerV2DecoderConfig, layer_idx):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Qwen3TTSTokenizerV2DecoderAttention(config, layer_idx)
        self.mlp = Qwen3TTSTokenizerV2DecoderMlp(config)
        self.input_layernorm = Qwen3TTSTokenizerV2DecoderRMSNorm(config.hidden_size, config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3TTSTokenizerV2DecoderRMSNorm(config.hidden_size, config.rms_norm_eps)
        self.self_attn_layer_scale = Qwen3TTSTokenizerV2DecoderLayerScale(config)
        self.mlp_layer_scale = Qwen3TTSTokenizerV2DecoderLayerScale(config)
        # Selects which entry of the model's mask mapping this layer receives.
        self.attention_type = "sliding_attention"

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            position_ids (`torch.LongTensor`, *optional*):
                Position indices, forwarded unchanged to the attention module.
            past_key_values (`Cache`, *optional*): cached past key and value projection states
            use_cache (`bool`, *optional*):
                Forwarded unchanged to the attention module.
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence
            kwargs (`dict`, *optional*):
                Forwarded to the attention module. Must contain `position_embeddings` (the rotary
                `(cos, sin)` pair), which the attention module requires.

        Returns:
            `torch.Tensor`: the layer output, same shape as `hidden_states`. Attention weights are
            discarded and no tuple is returned.
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = residual + self.self_attn_layer_scale(hidden_states)

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + self.mlp_layer_scale(hidden_states)

        return hidden_states
472
+
473
+
474
@auto_docstring
class Qwen3TTSTokenizerV2DecoderTransformerModel(Qwen3TTSTokenizerV2DecoderPreTrainedModel):
    # Transformer stack over continuous latents: inputs are projected from ``latent_dim`` to
    # ``hidden_size``, passed through the decoder layers and a final RMSNorm, then projected
    # back to ``latent_dim``. There is no token embedding table, so only ``inputs_embeds``
    # is accepted.
    _can_record_outputs = {
        "hidden_states": Qwen3TTSTokenizerV2DecoderTransformerLayer,
        "attentions": Qwen3TTSTokenizerV2DecoderAttention,
    }

    def __init__(self, config: Qwen3TTSTokenizerV2DecoderConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [Qwen3TTSTokenizerV2DecoderTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen3TTSTokenizerV2DecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3TTSTokenizerV2DecoderRotatoryEmbedding(config=config)
        self.gradient_checkpointing = False
        self.has_sliding_layers = "sliding_attention" in self.config.layer_types
        self.window_size = config.sliding_window

        self.input_proj = nn.Linear(config.latent_dim, config.hidden_size)
        self.output_proj = nn.Linear(config.hidden_size, config.latent_dim)

        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs()
    @auto_docstring
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        cache_position=None,
        **kwargs,
    ) -> BaseModelOutputWithPast:
        # ``input_ids`` exists only for signature compatibility: this model defines no
        # ``embed_tokens`` table, so token ids can never be embedded here.
        if input_ids is not None:
            raise ValueError("input_ids is not expected")
        if inputs_embeds is None:
            raise ValueError("inputs_embeds must be provided; this model does not accept input_ids")

        inputs_embeds = self.input_proj(inputs_embeds)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            # Continue numbering after whatever the cache already holds.
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
            }
            # The sliding window alternating layers are not always activated depending on the config
            if self.has_sliding_layers:
                causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        hidden_states = self.output_proj(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )
575
+
576
+
577
class SnakeBeta(nn.Module):
    """
    Snake activation variant with separate per-channel frequency and magnitude parameters.

    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
    Both are stored in log space and exponentiated in ``forward``.
    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://huggingface.co/papers/2006.08195
    """

    def __init__(self, in_features, alpha=1.0):
        super().__init__()
        self.in_features = in_features

        # Both parameters start as zeros; multiplying zeros by ``alpha`` leaves them at zero,
        # so the initial effective frequency and magnitude are exp(0) = 1.
        self.alpha = Parameter(torch.zeros(in_features) * alpha)
        self.beta = Parameter(torch.zeros(in_features) * alpha)

        # Added to the exponentiated beta before taking its reciprocal.
        self.no_div_by_zero = 0.000000001

    def forward(self, hidden_states):
        """
        Apply the activation elementwise.

        SnakeBeta := x + 1/b * sin^2 (xa), with a = exp(alpha) and b = exp(beta).
        """
        # Reshape the per-channel vectors to (1, C, 1) so they broadcast against [B, C, T].
        frequency = torch.exp(self.alpha.unsqueeze(0).unsqueeze(-1))
        magnitude = torch.exp(self.beta.unsqueeze(0).unsqueeze(-1))

        inverse_magnitude = 1.0 / (magnitude + self.no_div_by_zero)
        periodic = torch.pow(torch.sin(hidden_states * frequency), 2)
        return hidden_states + inverse_magnitude * periodic
616
+
617
+
618
class Qwen3TTSTokenizerV2DecoderDecoderResidualUnit(nn.Module):
    """Residual unit: SnakeBeta -> dilated kernel-7 conv -> SnakeBeta -> kernel-1 conv, plus skip."""

    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()

        self.act1 = SnakeBeta(dim)
        self.conv1 = Qwen3TTSTokenizerV2CausalConvNet(dim, dim, kernel_size=7, dilation=dilation)
        self.act2 = SnakeBeta(dim)
        self.conv2 = Qwen3TTSTokenizerV2CausalConvNet(dim, dim, kernel_size=1)

    def forward(self, hidden_state):
        """Return the two-stage activation/convolution branch output added to the input."""
        skip = hidden_state
        branch = self.conv1(self.act1(hidden_state))
        branch = self.conv2(self.act2(branch))
        return branch + skip
635
+
636
+
637
class Qwen3TTSTokenizerV2DecoderDecoderBlock(Qwen3TTSTokenizerV2DecoderPreTrainedModel):
    """One upsampling stage of the waveform decoder.

    Halves the channel count (``decoder_dim // 2**layer_idx`` in,
    ``decoder_dim // 2**(layer_idx + 1)`` out) with a transposed convolution whose
    stride is ``config.upsample_rates[layer_idx]``, followed by three residual units
    with dilations 1, 3 and 9.
    """

    def __init__(self, config: Qwen3TTSTokenizerV2DecoderConfig, layer_idx):
        super().__init__(config)
        in_dim = config.decoder_dim // 2**layer_idx
        out_dim = config.decoder_dim // 2 ** (layer_idx + 1)
        upsample_rate = config.upsample_rates[layer_idx]

        stages = [
            SnakeBeta(in_dim),
            Qwen3TTSTokenizerV2CausalTransConvNet(in_dim, out_dim, 2 * upsample_rate, upsample_rate),
        ]
        stages.extend(
            Qwen3TTSTokenizerV2DecoderDecoderResidualUnit(out_dim, dilation) for dilation in (1, 3, 9)
        )

        self.block = nn.ModuleList(stages)

    def forward(self, hidden):
        """Apply every stage of the block in order."""
        for stage in self.block:
            hidden = stage(hidden)
        return hidden
658
+
659
+
660
class EuclideanCodebook(nn.Module):
    """Codebook stored as running sums; entries are ``embedding_sum / cluster_usage``."""

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        epsilon: float = 1e-5,
    ):
        super().__init__()
        self.dim = dim
        self.codebook_size = codebook_size
        self.epsilon = epsilon

        self.cluster_usage = nn.Parameter(torch.ones(codebook_size))
        self.embedding_sum = nn.Parameter(torch.zeros(codebook_size, dim))

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Look up the embedding of each index in ``codes``.

        Usage counts are clamped to at least ``epsilon`` before dividing, so an
        unused entry cannot cause a division by zero.
        """
        usage = self.cluster_usage.clamp(min=self.epsilon).unsqueeze(-1)
        table = self.embedding_sum / usage
        return F.embedding(codes, table)
679
+
680
+
681
class VectorQuantization(nn.Module):
    """Single codebook plus an optional linear projection back to ``dim``."""

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: Optional[int] = None,
        epsilon: float = 1e-5,
    ):
        super().__init__()
        # The codebook lives in ``dim`` unless a different width is requested.
        if codebook_dim is None:
            codebook_dim = dim

        if codebook_dim == dim:
            self.project_out = nn.Identity()
        else:
            self.project_out = nn.Linear(codebook_dim, dim)

        self.epsilon = epsilon
        self._codebook = EuclideanCodebook(
            dim=codebook_dim,
            codebook_size=codebook_size,
            epsilon=epsilon,
        )
        self.codebook_size = codebook_size

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Embed ``codes``, project to ``dim`` and swap axes 1 and 2 of the result."""
        embedded = self._codebook.decode(codes)
        projected = self.project_out(embedded)
        return projected.transpose(1, 2)
711
+
712
+
713
class ResidualVectorQuantization(nn.Module):
    """Stack of ``num_quantizers`` identically configured ``VectorQuantization`` layers."""

    def __init__(self, *, num_quantizers: int, **kwargs):
        super().__init__()
        quantizers = [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
        self.layers = nn.ModuleList(quantizers)

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Sum the decoded embeddings of every codebook.

        ``codes`` is iterated along its first axis; slice ``i`` is decoded by layer ``i``.
        """
        # 0-dim zero on the codes' device; broadcasts against the first decoded tensor.
        total = torch.zeros([1], device=codes.device)[0]
        for position, codes_for_layer in enumerate(codes):
            quantizer = self.layers[position]
            assert isinstance(quantizer, VectorQuantization)
            total = total + quantizer.decode(codes_for_layer)
        return total
727
+
728
+
729
class ResidualVectorQuantizer(nn.Module):
    """Residual vector quantizer with optional 1x1 convolution projections in and out.

    A projection is created when the external dimension differs from the internal
    ``dimension``, or unconditionally when ``force_projection`` is set; otherwise the
    projection is an identity.
    """

    def __init__(
        self,
        dimension: int = 128,
        input_dimension: Optional[int] = None,
        output_dimension: Optional[int] = None,
        n_q: int = 8,
        q_dropout: bool = False,
        no_quantization_rate: float = 0.0,
        bins: int = 1024,
        decay: float = 0.99,
        force_projection: bool = False,
    ):
        super().__init__()
        self.max_n_q = n_q
        self.n_q = n_q
        self.q_dropout = q_dropout
        self.no_quantization_rate = no_quantization_rate
        self.dimension = dimension
        # A falsy (None or 0) external dimension falls back to the internal one.
        self.input_dimension = input_dimension or dimension
        self.output_dimension = output_dimension or dimension
        self.bins = bins
        self.decay = decay

        self.input_proj: torch.nn.Module
        self.output_proj: torch.nn.Module

        if force_projection or self.input_dimension != self.dimension:
            self.input_proj = torch.nn.Conv1d(self.input_dimension, self.dimension, 1, bias=False)
        else:
            self.input_proj = torch.nn.Identity()

        if force_projection or self.output_dimension != self.dimension:
            self.output_proj = torch.nn.Conv1d(self.dimension, self.output_dimension, 1, bias=False)
        else:
            self.output_proj = torch.nn.Identity()

        self.vq = ResidualVectorQuantization(
            dim=self.dimension,
            codebook_size=self.bins,
            num_quantizers=self.n_q,
        )

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode ``codes`` to the summed embeddings and apply the output projection.

        The first two axes of ``codes`` are swapped before it is handed to the
        inner quantizer stack.
        """
        codebook_first = codes.transpose(0, 1)
        summed = self.vq.decode(codebook_first)
        return self.output_proj(summed)
777
+
778
+
779
class SplitResidualVectorQuantizer(nn.Module):
    """Residual Vector Quantizer with separate projections for the first quantizer(s) and the rest.

    Args:
        n_q (int): Total number of residual vector quantizers.
        n_q_semantic (int): How many of the leading quantizers form the semantic quantizer.
        **kwargs: Arguments to the constructor of `ResidualVectorQuantizer` that are shared between both.
            A `q_dropout` entry is only applied to the non-semantic quantizer.
    """

    def __init__(
        self,
        *,
        n_q: int = 8,
        n_q_semantic: int = 1,
        **kwargs,
    ):
        super().__init__()
        assert n_q > n_q_semantic, (
            f"Number of quantizers {n_q} must be larger "
            f"than the number of semantic quantizers {n_q_semantic}."
        )
        n_acoustic = n_q - n_q_semantic

        self.max_n_q = n_q
        self.n_q_semantic = n_q_semantic
        self.n_q_acoustic = n_acoustic

        # Pull q_dropout out of the shared kwargs: the semantic quantizer never uses it.
        acoustic_q_dropout = kwargs.pop("q_dropout", False)

        self.rvq_first = ResidualVectorQuantizer(
            n_q=n_q_semantic, force_projection=True, q_dropout=False, **kwargs
        )
        self.rvq_rest = ResidualVectorQuantizer(
            n_q=n_acoustic,
            force_projection=True,
            q_dropout=acoustic_q_dropout,
            **kwargs,
        )

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation."""
        # codes is [B, K, T], with T frames, K nb of codebooks.
        split_at = self.n_q_semantic
        quantized = self.rvq_first.decode(codes[:, :split_at])
        has_acoustic_codes = codes.shape[1] > split_at
        if has_acoustic_codes:
            quantized += self.rvq_rest.decode(codes[:, split_at:])
        return quantized
821
+
822
+
823
class Qwen3TTSTokenizerV2Decoder(Qwen3TTSTokenizerV2DecoderPreTrainedModel):
    """Turns discrete codes into a waveform.

    Pipeline: split residual quantizer -> causal pre-conv -> transformer over latents ->
    upsampling stages (transposed conv + ConvNeXt block each) -> convolutional decoder
    ending in a single output channel, clamped to [-1, 1].
    """

    def __init__(self, config: Qwen3TTSTokenizerV2DecoderConfig):
        super().__init__(config)
        # Product of every upsampling factor applied in forward; used by chunked_decode.
        self.total_upsample = np.prod(config.upsample_rates + config.upsampling_ratios)
        self.pre_transformer = Qwen3TTSTokenizerV2DecoderTransformerModel._from_config(config)

        self.quantizer = SplitResidualVectorQuantizer(
            dimension=config.codebook_dim // 2,
            n_q=config.num_quantizers,
            n_q_semantic=1,
            bins=config.codebook_size,
            input_dimension=config.codebook_dim,
            output_dimension=config.codebook_dim,
        )

        self.pre_conv = Qwen3TTSTokenizerV2CausalConvNet(
            config.codebook_dim,
            config.latent_dim,
            kernel_size=3,
        )

        latent_dim = config.latent_dim
        self.upsample = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        Qwen3TTSTokenizerV2CausalTransConvNet(latent_dim, latent_dim, factor, factor),
                        Qwen3TTSTokenizerV2ConvNeXtBlock(latent_dim),
                    ]
                )
                for factor in config.upsampling_ratios
            ]
        )

        num_blocks = len(config.upsample_rates)
        stages = [Qwen3TTSTokenizerV2CausalConvNet(config.latent_dim, config.decoder_dim, 7)]
        stages.extend(Qwen3TTSTokenizerV2DecoderDecoderBlock(config, index) for index in range(num_blocks))
        output_dim = config.decoder_dim // 2**num_blocks
        stages.append(SnakeBeta(output_dim))
        stages.append(Qwen3TTSTokenizerV2CausalConvNet(output_dim, 1, 7))
        self.decoder = nn.ModuleList(stages)

        self.post_init()

    def forward(self, codes):
        """Decode ``codes`` (axis 1 must have ``config.num_quantizers`` entries) to a waveform.

        Raises:
            ValueError: if ``codes.shape[1]`` differs from ``config.num_quantizers``.
        """
        expected_layers = self.config.num_quantizers
        if codes.shape[1] != expected_layers:
            raise ValueError(f"Expected {self.config.num_quantizers} layer of codes, got {codes.shape[1]}")

        latents = self.pre_conv(self.quantizer.decode(codes)).transpose(1, 2)
        latents = self.pre_transformer(inputs_embeds=latents).last_hidden_state
        wav = latents.permute(0, 2, 1)

        for stage in self.upsample:
            for layer in stage:
                wav = layer(wav)
        for layer in self.decoder:
            wav = layer(wav)
        return wav.clamp(min=-1, max=1)

    def chunked_decode(self, codes, chunk_size=300, left_context_size=25):
        """Decode ``codes`` in chunks along the last axis and concatenate the results.

        Each chunk after the first is decoded together with up to ``left_context_size``
        preceding frames; the samples produced for that context are dropped before
        concatenation.
        """
        total_frames = codes.shape[-1]
        pieces = []
        start_index = 0
        while start_index < total_frames:
            end_index = min(start_index + chunk_size, total_frames)
            # Never reach back past the beginning of the sequence.
            context_size = min(start_index, left_context_size)
            wav_chunk = self(codes[..., start_index - context_size : end_index])
            pieces.append(wav_chunk[..., context_size * self.total_upsample :])
            start_index = end_index
        return torch.cat(pieces, dim=-1)
896
+
897
+
898
class Qwen3TTSTokenizerV2Encoder(MimiModel):
    """``MimiModel`` with its ``upsample``, ``decoder_transformer`` and ``decoder`` attributes set to ``None``."""

    def __init__(self, config: MimiConfig):
        super().__init__(config)
        self.config = config

        # Clear the three attributes on the decoding side.
        for unused_attribute in ("upsample", "decoder_transformer", "decoder"):
            setattr(self, unused_attribute, None)

        self.post_init()
908
+
909
+
910
@auto_docstring
class Qwen3TTSTokenizerV2PreTrainedModel(PreTrainedModel):
    # Base class for the tokenizer models in this file. It only declares class-level
    # settings; their exact semantics are defined by the `PreTrainedModel` base class
    # in transformers — consult its documentation before changing any of them.
    config: Qwen3TTSTokenizerV2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    # Attention backends this model declares support for.
    _supports_flash_attn = True
    _supports_sdpa = True
    _can_compile_fullgraph = False
    _supports_attention_backend = True
920
+
921
+
922
@auto_docstring(
    custom_intro="""
    The Qwen3TTSTokenizerV2 model.
    """
)
class Qwen3TTSTokenizerV2Model(Qwen3TTSTokenizerV2PreTrainedModel):
    def __init__(self, config: Qwen3TTSTokenizerV2Config):
        super().__init__(config)
        self.config = config

        self.encoder_valid_num_quantizers = config.encoder_valid_num_quantizers

        self.input_sample_rate = config.input_sample_rate
        self.output_sample_rate = config.output_sample_rate

        self.decode_upsample_rate = config.decode_upsample_rate
        self.encode_downsample_rate = config.encode_downsample_rate

        self.encoder = Qwen3TTSTokenizerV2Encoder._from_config(self.config.encoder_config)
        self.decoder = Qwen3TTSTokenizerV2Decoder._from_config(self.config.decoder_config)

        self.post_init()

    def get_model_type(self):
        """Return `config.model_type`."""
        return self.config.model_type

    def get_input_sample_rate(self):
        """Return the configured input sample rate."""
        return self.input_sample_rate

    def get_output_sample_rate(self):
        """Return the configured output sample rate."""
        return self.output_sample_rate

    def get_encode_downsample_rate(self):
        """Return the number of input samples represented by one code frame."""
        return self.encode_downsample_rate

    def get_decode_upsample_rate(self):
        """Return the number of output samples produced per code frame."""
        return self.decode_upsample_rate

    def encode(
        self,
        input_values: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor, Optional[torch.Tensor]], Qwen3TTSTokenizerV2EncoderOutput]:
        """
        Encodes the input audio waveform into discrete codes.

        Args:
            input_values (`torch.Tensor` of shape `(batch_size, sequence_length)`):
                Float values of the input audio waveform.
            padding_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Indicates which inputs are to be ignored due to padding, where elements are either 1 for *not masked* or 0
                for *masked*. When omitted, no frames are trimmed and every sample keeps the full code length.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Returns:
            A one-element tuple, or a `Qwen3TTSTokenizerV2EncoderOutput`, holding a list with one code tensor per
            batch item. Each tensor has its frame axis first and its quantizer axis second.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        encoded_frames = self.encoder.encode(input_values=input_values.unsqueeze(1), return_dict=True)
        # Keep only the leading quantizers that are valid for this tokenizer.
        audio_codes = encoded_frames.audio_codes[:, : self.encoder_valid_num_quantizers]

        if padding_mask is None:
            # No mask: nothing to trim. Previously this path crashed in zip(..., None).
            audio_codes = [code.transpose(0, 1) for code in audio_codes]
        else:
            # -(-n // r) is ceil(n / r): the number of frames covering the unmasked samples.
            audio_codes = [
                code[..., : -(-mask.sum() // self.encode_downsample_rate)].transpose(0, 1)
                for code, mask in zip(audio_codes, padding_mask)
            ]

        if not return_dict:
            return (audio_codes,)

        return Qwen3TTSTokenizerV2EncoderOutput(audio_codes)

    def decode(
        self,
        audio_codes: torch.Tensor,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor, torch.Tensor], Qwen3TTSTokenizerV2DecoderOutput]:
        """
        Decodes the given frames into an output audio waveform.

        Note that the output might be a bit bigger than the input. In that case, any extra steps at the end can be
        trimmed.

        Args:
            audio_codes (`torch.LongTensor` of shape `(batch_size, codes_length, num_quantizers)`):
                Discrete code embeddings computed using `model.encode`.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Returns:
            A one-element tuple, or a `Qwen3TTSTokenizerV2DecoderOutput`, holding a list with one waveform per
            batch item, each trimmed to its own length.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        audio_values = self.decoder.chunked_decode(audio_codes.transpose(1, 2)).squeeze(1)

        # A frame counts towards the length only if its first-quantizer code is > 0.
        audio_lengths = (audio_codes[..., 0] > 0).sum(1) * self.decode_upsample_rate
        audio_values = [waveform[:length] for waveform, length in zip(audio_values, audio_lengths)]

        if not return_dict:
            return (audio_values,)

        return Qwen3TTSTokenizerV2DecoderOutput(audio_values)
1023
+
1024
+
1025
# Public API of this module: the composite tokenizer model and its pretrained base class.
__all__ = ["Qwen3TTSTokenizerV2Model", "Qwen3TTSTokenizerV2PreTrainedModel"]
qwen_tts/core/tokenizer_25hz/configuration_qwen3_tts_tokenizer_v1.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Qwen3TTSTokenizerV1 model configuration"""
16
+
17
+ from transformers.configuration_utils import PretrainedConfig
18
+ from transformers.utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
class Qwen3TTSTokenizerV1DecoderDiTConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of the Qwen3TTSTokenizerV1DecoderToken2WavDiT.
    It defines the architecture of the DiT model, which is used for generating mel-spectrograms from tokens.

    List-valued arguments are copied before being stored, so instances never share (or mutate) the
    default lists declared in the signature.

    Args:
        hidden_size (`int`, *optional*, defaults to 1024):
            The dimension of the model.
        num_hidden_layers (`int`, *optional*, defaults to 22):
            The number of transformer blocks in the DiT model.
        num_attention_heads (`int`, *optional*, defaults to 16):
            The number of attention heads in each transformer block.
        ff_mult (`int`, *optional*, defaults to 2):
            The multiplier for the feedforward layer in each transformer block.
        emb_dim (`int`, *optional*, defaults to 512):
            The dimension of the embedding layer.
        head_dim (`int`, *optional*, defaults to 64):
            The dimension of each attention head.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            Stored as-is; presumably the base of the rotary position embedding (confirm against the model).
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            Stored as-is; presumably the maximum supported sequence length (confirm against the model).
        block_size (`int`, *optional*, defaults to 24):
            Stored as-is; its exact meaning is defined by the DiT model code, which is not visible here.
        look_ahead_layers (`list[int]`, *optional*, defaults to `[10]`):
            Stored as-is (copied); presumably layer indices using look-ahead attention (confirm).
        look_backward_layers (`list[int]`, *optional*, defaults to `[0, 20]`):
            Stored as-is (copied); presumably layer indices using look-backward attention (confirm).
        repeats (`int`, *optional*, defaults to 2):
            The number of times the codec embeddings are repeated.
        num_embeds (`int`, *optional*, defaults to 8193):
            The number of unique embeddings in the codec.
        mel_dim (`int`, *optional*, defaults to 80):
            The dimension of the mel-spectrogram.
        dropout (`float`, *optional*, defaults to 0.1):
            The dropout rate for the transformer blocks.

        enc_emb_dim (`int`, *optional*, defaults to 192):
            The dimension of the pre-trained speaker embedding.
        enc_dim (`int`, *optional*, defaults to 128):
            The dimension of the encoder output.
        enc_channels (`list[int]`, *optional*, defaults to `[256, 256, 256, 256, 768]`):
            A list of output channels for each TDNN/SERes2Net layer in the encoder.
        enc_kernel_sizes (`list[int]`, *optional*, defaults to `[5, 3, 3, 3, 1]`):
            A list of kernel sizes for each layer in the encoder.
        enc_dilations (`list[int]`, *optional*, defaults to `[1, 2, 3, 4, 1]`):
            A list of dilations for each layer in the encoder.
        enc_attention_channels (`int`, *optional*, defaults to 64):
            The number of attention channels in the SqueezeExcitationBlock.
        enc_res2net_scale (`int`, *optional*, defaults to 2):
            The scale of the Res2Net block in the encoder.
        enc_se_channels (`int`, *optional*, defaults to 64):
            The number of output channels after squeeze in the SqueezeExcitationBlock.
    """

    model_type = "qwen3_tts_tokenizer_v1_decoder_dit"

    def __init__(
        self,
        hidden_size=1024,
        num_hidden_layers=22,
        num_attention_heads=16,
        ff_mult=2,
        emb_dim=512,
        head_dim=64,
        rope_theta=10000.0,
        max_position_embeddings=32768,
        block_size=24,
        look_ahead_layers=[10],
        look_backward_layers=[0, 20],
        repeats=2,
        num_embeds=8193,
        mel_dim=80,
        dropout=0.1,
        enc_emb_dim=192,
        enc_dim=128,
        enc_channels=[256, 256, 256, 256, 768],
        enc_kernel_sizes=[5, 3, 3, 3, 1],
        enc_dilations=[1, 2, 3, 4, 1],
        enc_attention_channels=64,
        enc_res2net_scale=2,
        enc_se_channels=64,
        **kwargs,
    ):
        def _own_list(value):
            # Default lists in the signature are created once and shared by every call;
            # storing a copy keeps one instance's edits from leaking into the others.
            return list(value) if isinstance(value, list) else value

        # DiT transformer geometry.
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.ff_mult = ff_mult
        self.emb_dim = emb_dim
        self.head_dim = head_dim
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.block_size = block_size
        self.look_ahead_layers = _own_list(look_ahead_layers)
        self.look_backward_layers = _own_list(look_backward_layers)
        self.repeats = repeats
        self.num_embeds = num_embeds
        self.mel_dim = mel_dim
        self.dropout = dropout

        # Speaker encoder (ECAPA-TDNN) settings.
        self.enc_emb_dim = enc_emb_dim
        self.enc_dim = enc_dim
        self.enc_channels = _own_list(enc_channels)
        self.enc_kernel_sizes = _own_list(enc_kernel_sizes)
        self.enc_dilations = _own_list(enc_dilations)
        self.enc_attention_channels = enc_attention_channels
        self.enc_res2net_scale = enc_res2net_scale
        self.enc_se_channels = enc_se_channels
        super().__init__(**kwargs)
122
+
123
+
124
class Qwen3TTSTokenizerV1DecoderBigVGANConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of the Qwen3TTSTokenizerV1DecoderToken2WavBigVGAN module.
    It defines the architecture of the BigVGAN model, which is used for converting mel-spectrograms to waveforms.

    List-valued arguments are deep-copied before being stored, so instances never share (or mutate)
    the default lists declared in the signature.

    Args:
        mel_dim (`int`, *optional*, defaults to 80):
            The dimension of the mel-spectrogram.
        upsample_initial_channel (`int`, *optional*, defaults to 1536):
            The number of channels in the initial upsampling layer.
        resblock_kernel_sizes (`list[int]`, *optional*, defaults to `[3, 7, 11]`):
            A list of kernel sizes for each residual block.
        resblock_dilation_sizes (`list[list[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`):
            A list of dilation sizes for each residual block.
        upsample_rates (`list[int]`, *optional*, defaults to `[5, 3, 2, 2, 2, 2]`):
            A list of upsampling rates for each upsampling layer.
        upsample_kernel_sizes (`list[int]`, *optional*, defaults to `[11, 7, 4, 4, 4, 4]`):
            A list of kernel sizes for each upsampling layer.
    """

    model_type = "qwen3_tts_tokenizer_v1_decoder_bigvgan"

    def __init__(
        self,
        mel_dim=80,
        upsample_initial_channel=1536,
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        upsample_rates=[5, 3, 2, 2, 2, 2],
        upsample_kernel_sizes=[11, 7, 4, 4, 4, 4],
        **kwargs,
    ):
        import copy

        def _own_list(value):
            # Default lists in the signature are created once and shared by every call;
            # a deep copy also detaches the nested dilation lists.
            return copy.deepcopy(value) if isinstance(value, list) else value

        self.mel_dim = mel_dim
        self.upsample_initial_channel = upsample_initial_channel
        self.resblock_kernel_sizes = _own_list(resblock_kernel_sizes)
        self.resblock_dilation_sizes = _own_list(resblock_dilation_sizes)
        self.upsample_rates = _own_list(upsample_rates)
        self.upsample_kernel_sizes = _own_list(upsample_kernel_sizes)
        super().__init__(**kwargs)
163
+
164
+
165
class Qwen3TTSTokenizerV1DecoderConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of the Qwen3TTSTokenizerV1 decoder, which is
    composed of a DiT sub-configuration and a BigVGAN sub-configuration.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        dit_config (`dict` or [`Qwen3TTSTokenizerV1DecoderDiTConfig`], *optional*):
            Configuration for the Diffusion Transformer (DiT) module responsible for generating mel-spectrograms.
            A `dict` is expanded into keyword arguments; an existing config instance is stored as-is;
            `None` selects the defaults.
        bigvgan_config (`dict` or [`Qwen3TTSTokenizerV1DecoderBigVGANConfig`], *optional*):
            Configuration for the BigVGAN module responsible for converting mel-spectrograms to waveforms.
            Handled the same way as `dit_config`.
    """

    model_type = "qwen3_tts_tokenizer_v1_decoder"
    sub_configs = {
        "dit_config": Qwen3TTSTokenizerV1DecoderDiTConfig,
        "bigvgan_config": Qwen3TTSTokenizerV1DecoderBigVGANConfig,
    }

    def __init__(self, dit_config=None, bigvgan_config=None, **kwargs):
        if dit_config is None:
            dit_config = {}
        if bigvgan_config is None:
            bigvgan_config = {}

        # Accept ready-made sub-config objects as well as plain dicts; `**` would fail on an instance.
        if not isinstance(dit_config, Qwen3TTSTokenizerV1DecoderDiTConfig):
            dit_config = Qwen3TTSTokenizerV1DecoderDiTConfig(**dit_config)
        if not isinstance(bigvgan_config, Qwen3TTSTokenizerV1DecoderBigVGANConfig):
            bigvgan_config = Qwen3TTSTokenizerV1DecoderBigVGANConfig(**bigvgan_config)

        self.dit_config = dit_config
        self.bigvgan_config = bigvgan_config
        super().__init__(**kwargs)
193
+
194
+
195
class Qwen3TTSTokenizerV1EncoderConfig(PretrainedConfig):
    r"""
    Configuration holder for the Qwen3TTSTokenizerV1 encoder.

    The encoder typically takes mel-spectrogram features and produces high-level audio representations, then
    (optionally) applies an Audio-VQ module (e.g., GRVQ) to discretize continuous representations into codes.
    Every argument is stored unchanged as an attribute of the same name.

    Args:
        n_mels (`int`, *optional*, defaults to 128):
            Number of mel bins in the input mel-spectrogram.
        n_ctx (`int`, *optional*, defaults to 1500):
            Maximum input sequence length (in frames/tokens) for the encoder.
        n_state (`int`, *optional*, defaults to 1280):
            Hidden size (model dimension) of the encoder transformer.
        n_head (`int`, *optional*, defaults to 20):
            Number of attention heads in each transformer layer.
        n_layer (`int`, *optional*, defaults to 32):
            Number of transformer layers.
        n_window (`int`, *optional*, defaults to 100):
            Window size used by the model for local attention / chunking (implementation-dependent).
        output_dim (`int`, *optional*, defaults to 3584):
            Output feature dimension produced by the encoder head (implementation-dependent).

        grad_checkpointing (`bool`, *optional*, defaults to `False`):
            Whether to enable gradient checkpointing to reduce memory usage during training.
        enable_mp (`bool`, *optional*, defaults to `False`):
            Whether to enable model parallel features (implementation-dependent).
        audio_sequence_parallel (`bool`, *optional*, defaults to `False`):
            Whether to enable sequence parallelism for the audio branch (implementation-dependent).

        audio_vq_type (`str`, *optional*, defaults to `"GRVQ"`):
            Type of audio vector-quantization module. Common choices: `"GRVQ"`, `"RVQ"`, etc.
        audio_vq_layers (`int`, *optional*, defaults to 6):
            Number of VQ layers / quantizers.
        audio_vq_codebook_size (`int`, *optional*, defaults to 32768):
            Size of each codebook (number of entries).
        audio_vq_codebook_dim (`int`, *optional*, defaults to 1280):
            Dimension of codebook vectors.
        audio_vq_pe (`bool`, *optional*, defaults to `True`):
            Whether to use positional encoding (or position embeddings) inside the VQ module.
        audio_vq_ds_rate (`int`, *optional*, defaults to 2):
            Downsampling rate applied before VQ (e.g., temporal downsample factor).
    """

    model_type = "qwen3_tts_tokenizer_v1_encoder"

    def __init__(
        self,
        n_mels=128,
        n_ctx=1500,
        n_state=1280,
        n_head=20,
        n_layer=32,
        n_window=100,
        output_dim=3584,
        grad_checkpointing=False,
        enable_mp=False,
        audio_sequence_parallel=False,
        audio_vq_type="GRVQ",
        audio_vq_layers=6,
        audio_vq_codebook_size=32768,
        audio_vq_codebook_dim=1280,
        audio_vq_pe=True,
        audio_vq_ds_rate=2,
        **kwargs,
    ):
        # The base class consumes the generic kwargs first; the fields below are set afterwards.
        super().__init__(**kwargs)

        # Transformer encoder geometry.
        self.n_mels = n_mels
        self.n_ctx = n_ctx
        self.n_state = n_state
        self.n_head = n_head
        self.n_layer = n_layer
        self.n_window = n_window
        self.output_dim = output_dim

        # Training / parallelism switches.
        self.grad_checkpointing = grad_checkpointing
        self.enable_mp = enable_mp
        self.audio_sequence_parallel = audio_sequence_parallel

        # Audio vector-quantizer settings.
        self.audio_vq_type = audio_vq_type
        self.audio_vq_layers = audio_vq_layers
        self.audio_vq_codebook_size = audio_vq_codebook_size
        self.audio_vq_codebook_dim = audio_vq_codebook_dim
        self.audio_vq_pe = audio_vq_pe
        self.audio_vq_ds_rate = audio_vq_ds_rate
278
+
279
+
280
class Qwen3TTSTokenizerV1Config(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a Qwen3TTSTokenizerV1Model. It is used to
    instantiate the model according to the specified sub-model configurations, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        encoder_config (`dict` or [`Qwen3TTSTokenizerV1EncoderConfig`], *optional*):
            Configuration of the underlying encoder sub-model. A `dict` is expanded into keyword arguments;
            an existing config instance is stored as-is; `None` selects the defaults.
        decoder_config (`dict` or [`Qwen3TTSTokenizerV1DecoderConfig`], *optional*):
            Configuration of the underlying decoder sub-model. Handled the same way as `encoder_config`.
        input_sample_rate (`int`, *optional*, defaults to 24000):
            Stored as-is; presumably the sampling rate (Hz) of the audio fed to the encoder (confirm).
        output_sample_rate (`int`, *optional*, defaults to 24000):
            Stored as-is; presumably the sampling rate (Hz) of the decoded audio (confirm).
        decode_upsample_rate (`int`, *optional*, defaults to 1920):
            Stored as-is; presumably the number of output samples per code frame (confirm).
        encode_downsample_rate (`int`, *optional*, defaults to 1920):
            Stored as-is; presumably the number of input samples per code frame (confirm).
    """

    model_type = "qwen3_tts_tokenizer_25hz"
    sub_configs = {
        "encoder_config": Qwen3TTSTokenizerV1EncoderConfig,
        "decoder_config": Qwen3TTSTokenizerV1DecoderConfig,
    }

    def __init__(
        self,
        encoder_config=None,
        decoder_config=None,
        input_sample_rate=24000,
        output_sample_rate=24000,
        decode_upsample_rate=1920,
        encode_downsample_rate=1920,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if encoder_config is None:
            encoder_config = {}
            logger.info("encoder_config is None. Initializing encoder with default values")
        if decoder_config is None:
            decoder_config = {}
            logger.info("decoder_config is None. Initializing decoder with default values")

        # Accept ready-made sub-config objects as well as plain dicts; `**` would fail on an instance.
        if not isinstance(encoder_config, Qwen3TTSTokenizerV1EncoderConfig):
            encoder_config = Qwen3TTSTokenizerV1EncoderConfig(**encoder_config)
        if not isinstance(decoder_config, Qwen3TTSTokenizerV1DecoderConfig):
            decoder_config = Qwen3TTSTokenizerV1DecoderConfig(**decoder_config)

        self.encoder_config = encoder_config
        self.decoder_config = decoder_config

        self.input_sample_rate = input_sample_rate
        self.output_sample_rate = output_sample_rate
        self.decode_upsample_rate = decode_upsample_rate
        self.encode_downsample_rate = encode_downsample_rate
324
+
325
+
326
+ __all__ = [
327
+ "Qwen3TTSTokenizerV1Config",
328
+ "Qwen3TTSTokenizerV1EncoderConfig",
329
+ "Qwen3TTSTokenizerV1DecoderConfig",
330
+ "Qwen3TTSTokenizerV1DecoderBigVGANConfig",
331
+ "Qwen3TTSTokenizerV1DecoderDiTConfig"
332
+ ]
qwen_tts/core/tokenizer_25hz/modeling_qwen3_tts_tokenizer_v1.py ADDED
@@ -0,0 +1,1528 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Qwen3TTSTokenizerV1 model."""
16
+
17
+ import math
18
+ from dataclasses import dataclass
19
+ from typing import Optional, Union, List
20
+
21
+ import numpy as np
22
+ import torch
23
+ from torch import nn
24
+ from torch.nn import Parameter
25
+ from torch.nn import functional as F
26
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
27
+ from transformers.utils import ModelOutput, auto_docstring, logging
28
+ from transformers.utils.hub import cached_file
29
+
30
+ from torch.nn.utils.rnn import pad_sequence
31
+
32
+ from .vq.whisper_encoder import get_mel_audio, get_T_after_cnn
33
+ from .vq.speech_vq import WhisperEncoderVQ, XVectorExtractor
34
+
35
+ from .configuration_qwen3_tts_tokenizer_v1 import (
36
+ Qwen3TTSTokenizerV1Config,
37
+ Qwen3TTSTokenizerV1EncoderConfig,
38
+ Qwen3TTSTokenizerV1DecoderConfig,
39
+ Qwen3TTSTokenizerV1DecoderBigVGANConfig,
40
+ Qwen3TTSTokenizerV1DecoderDiTConfig
41
+ )
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+
46
@dataclass
@auto_docstring
class Qwen3TTSTokenizerV1EncoderOutput(ModelOutput):
    r"""
    audio_codes (`List[torch.LongTensor]`):
        Discrete codes computed using `model.encode`; each tensor has shape (codes_length_i,).
    xvectors (`List[torch.FloatTensor]`):
        X-vector embeddings computed using `model.encode`; each tensor has shape (xvector_dim,).
    ref_mels (`List[torch.FloatTensor]`):
        Reference mel spectrograms computed using `model.encode`; each tensor has shape (mel_length_i, mel_dim,).
    """

    # Every field defaults to None, so an output object can be built with any subset of them.
    audio_codes: List[torch.LongTensor] = None
    xvectors: List[torch.FloatTensor] = None
    ref_mels: List[torch.FloatTensor] = None
61
+
62
+
63
@dataclass
@auto_docstring
class Qwen3TTSTokenizerV1DecoderOutput(ModelOutput):
    r"""
    audio_values (`List[torch.FloatTensor]`):
        Decoded audio values produced by the decoder part of Qwen3TTSTokenizerV1.
        Each tensor has shape (segment_length_i).
    """

    # Defaults to None so the output object can be created empty.
    audio_values: List[torch.FloatTensor] = None
73
+
74
+
75
@auto_docstring
class Qwen3TTSTokenizerV1DecoderPreTrainedModel(PreTrainedModel):
    # Base class for the decoder-side models; it only declares class-level settings.
    # NOTE(review): the flags below are presumably read by the transformers `PreTrainedModel`
    # machinery -- their exact effect is defined there, not in this file.
    config: Qwen3TTSTokenizerV1DecoderConfig
    base_model_prefix = "model"

    # Training-related capability.
    supports_gradient_checkpointing = True

    # Device-placement behaviour.
    _skip_keys_device_placement = "past_key_values"

    # Attention backends / compilation support.
    _supports_flash_attn = True
    _supports_sdpa = True
    _can_compile_fullgraph = False
    _supports_attention_backend = True
85
+
86
+
87
@auto_docstring
class Qwen3TTSTokenizerV1EncoderPreTrainedModel(PreTrainedModel):
    # Base class for the encoder-side models; it only declares class-level settings.
    # NOTE(review): the flags below are presumably read by the transformers `PreTrainedModel`
    # machinery -- their exact effect is defined there, not in this file.
    config: Qwen3TTSTokenizerV1EncoderConfig
    base_model_prefix = "model"

    # Training-related capability.
    supports_gradient_checkpointing = True

    # Device-placement behaviour.
    _skip_keys_device_placement = "past_key_values"

    # Attention backends / compilation support.
    _supports_flash_attn = True
    _supports_sdpa = True
    _can_compile_fullgraph = False
    _supports_attention_backend = True
97
+
98
+
99
class Qwen3TTSTokenizerV1DecoderDiTRotaryEmbedding(nn.Module):
    """Produce the cosine / sine tables of a rotary position embedding.

    `forward` only uses the first two dimensions of its input (batch size and sequence length)
    plus its device and dtype. It returns `(cos, sin)`, each of shape `(batch, seq_len, dim)`,
    where every frequency is duplicated into two adjacent feature slots.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, dim, base=10000):
        super().__init__()

        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (base**exponents))

    def forward(self, x):
        batch_size = x.shape[0]
        seq_len = x.shape[1]
        positions = torch.arange(seq_len, device=x.device)

        # Autocast is entered with the "cpu" device type when the input lives on "mps".
        autocast_device = "cpu" if x.device.type == "mps" else x.device.type
        with torch.autocast(device_type=autocast_device, enabled=False):
            # (seq_len, 1) @ (1, dim / 2) -> one angle per position / frequency pair.
            angles = positions.unsqueeze(1).float() @ self.inv_freq.unsqueeze(0).float()
            # Duplicate every frequency into two neighbouring feature slots: f0, f0, f1, f1, ...
            angles = angles.repeat_interleave(2, dim=-1)
            # Same table for every batch element.
            angles = angles.unsqueeze(0).repeat(batch_size, 1, 1)
            cos_table = angles.cos()
            sin_table = angles.sin()

        return cos_table.to(dtype=x.dtype), sin_table.to(dtype=x.dtype)
122
+
123
+
124
class TimeDelayNetBlock(nn.Module):
    """A dilated 1-D convolution followed by ReLU.

    The convolution uses `padding="same"` with reflect padding, so the last (time) dimension of the
    output has the same length as the input.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            dilation=dilation,
            padding="same",
            padding_mode="reflect",
        )
        self.activation = nn.ReLU()

    def forward(self, hidden_states: torch.Tensor):
        convolved = self.conv(hidden_states)
        return self.activation(convolved)
145
+
146
+
147
class Res2NetBlock(torch.nn.Module):
    """Res2Net-style block operating on channel groups.

    The input is split into `scale` chunks along the channel dimension. The first chunk passes
    through untouched, the second goes through the first TDNN block, and every later chunk is added
    to the previous chunk's output before going through its own TDNN block. The per-chunk results
    are concatenated back along the channel dimension.
    """

    def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1):
        super().__init__()

        per_group_in = in_channels // scale
        per_group_out = out_channels // scale

        # One TDNN block for every chunk except the first.
        tdnn_blocks = []
        for _ in range(scale - 1):
            tdnn_blocks.append(
                TimeDelayNetBlock(
                    per_group_in,
                    per_group_out,
                    kernel_size=kernel_size,
                    dilation=dilation,
                )
            )
        self.blocks = nn.ModuleList(tdnn_blocks)
        self.scale = scale

    def forward(self, hidden_states):
        chunks = torch.chunk(hidden_states, self.scale, dim=1)

        previous = chunks[0]
        collected = [previous]
        for index in range(1, len(chunks)):
            chunk = chunks[index]
            # Only chunks after the second one receive the previous output as an additive input.
            block_input = chunk if index == 1 else chunk + previous
            previous = self.blocks[index - 1](block_input)
            collected.append(previous)

        return torch.cat(collected, dim=1)
179
+
180
+
181
class SqueezeExcitationBlock(nn.Module):
    """Channel gating: scale the input by a sigmoid gate computed from its time-averaged features.

    The input is averaged over its last dimension, passed through two kernel-size-1 convolutions
    (ReLU in between, sigmoid at the end) and the resulting per-channel gate multiplies the input.
    """

    def __init__(self, in_channels, se_channels, out_channels):
        super().__init__()

        # Squeeze: in_channels -> se_channels.
        self.conv1 = nn.Conv1d(
            in_channels,
            se_channels,
            kernel_size=1,
            padding="same",
            padding_mode="reflect",
        )
        self.relu = nn.ReLU(inplace=True)
        # Excite: se_channels -> out_channels.
        self.conv2 = nn.Conv1d(
            se_channels,
            out_channels,
            kernel_size=1,
            padding="same",
            padding_mode="reflect",
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, hidden_states):
        gate = hidden_states.mean(dim=2, keepdim=True)
        gate = self.conv1(gate)
        gate = self.relu(gate)
        gate = self.conv2(gate)
        gate = self.sigmoid(gate)
        return hidden_states * gate
209
+
210
+
211
class AttentiveStatisticsPooling(nn.Module):
    """Attentive statistics pooling over the last (time) dimension.

    For every channel, attention weights over time are computed and used to form a weighted mean
    and standard deviation. The result is their concatenation along the channel dimension, with a
    trailing dimension of size one.
    """

    def __init__(self, channels, attention_channels=128):
        super().__init__()

        self.eps = 1e-12
        # The attention network sees the input concatenated with its global mean and std.
        self.tdnn = TimeDelayNetBlock(channels * 3, attention_channels, 1, 1)
        self.tanh = nn.Tanh()
        self.conv = nn.Conv1d(
            attention_channels,
            channels,
            kernel_size=1,
            padding="same",
            padding_mode="reflect",
        )

    def _length_to_mask(self, length, max_len=None, dtype=None, device=None):
        """Build a per-sequence mask that is true for positions below each length.

        Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3

        Arguments
        ---------
        length : torch.LongTensor
            Length of each sequence in the batch. Must be 1D.
        max_len : int
            Size of the mask's second dimension; defaults to the largest length.
        dtype : torch.dtype, default: None
            The dtype of the generated mask.
        device: torch.device, default: None
            The device to put the mask on.

        Returns
        -------
        mask : tensor
            The mask, of shape (len(length), max_len).
        """
        if max_len is None:
            max_len = length.max().long().item()

        positions = torch.arange(max_len, device=length.device, dtype=length.dtype)
        is_valid = positions.expand(len(length), max_len) < length.unsqueeze(1)
        return torch.as_tensor(is_valid, dtype=dtype, device=device)

    def _compute_statistics(self, x, m, dim=2):
        """Return the `m`-weighted mean and standard deviation of `x` along `dim`."""
        weighted_mean = (m * x).sum(dim)
        squared_dev = (x - weighted_mean.unsqueeze(dim)).pow(2)
        # The variance is clamped from below by eps before the square root.
        variance = (m * squared_dev).sum(dim).clamp(self.eps)
        return weighted_mean, torch.sqrt(variance)

    def forward(self, hidden_states):
        seq_length = hidden_states.shape[-1]
        # Every sequence is given a relative length of one, i.e. the full time axis.
        relative_lengths = torch.ones(hidden_states.shape[0], device=hidden_states.device)

        # Mask of shape [N, 1, L].
        mask = self._length_to_mask(
            relative_lengths * seq_length,
            max_len=seq_length,
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        ).unsqueeze(1)

        # Global (uniformly weighted) statistics give the attention network utterance-level context.
        valid_count = mask.sum(dim=2, keepdim=True)
        global_mean, global_std = self._compute_statistics(hidden_states, mask / valid_count)
        global_mean = global_mean.unsqueeze(2).repeat(1, 1, seq_length)
        global_std = global_std.unsqueeze(2).repeat(1, 1, seq_length)
        attention_input = torch.cat([hidden_states, global_mean, global_std], dim=1)

        scores = self.tdnn(attention_input)
        scores = self.tanh(scores)
        scores = self.conv(scores)

        # Masked-out positions get zero weight after the softmax.
        scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = F.softmax(scores, dim=2)

        pooled_mean, pooled_std = self._compute_statistics(hidden_states, weights)
        return torch.cat((pooled_mean, pooled_std), dim=1).unsqueeze(2)
298
+
299
+
300
class SqueezeExcitationRes2NetBlock(nn.Module):
    """Building block of ECAPA-TDNN: TDNN -> Res2Net -> TDNN -> SqueezeExcitation, plus a residual.

    The block input is added to the output of the four stages, so the stages must produce a tensor
    that can be added to the input.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        res2net_scale=8,
        se_channels=128,
        kernel_size=1,
        dilation=1,
    ):
        super().__init__()
        self.out_channels = out_channels

        # Point-wise TDNN layers wrap the dilated Res2Net core.
        self.tdnn1 = TimeDelayNetBlock(in_channels, out_channels, kernel_size=1, dilation=1)
        self.res2net_block = Res2NetBlock(out_channels, out_channels, res2net_scale, kernel_size, dilation)
        self.tdnn2 = TimeDelayNetBlock(out_channels, out_channels, kernel_size=1, dilation=1)
        self.se_block = SqueezeExcitationBlock(out_channels, se_channels, out_channels)

    def forward(self, hidden_state):
        transformed = hidden_state
        for stage in (self.tdnn1, self.res2net_block, self.tdnn2, self.se_block):
            transformed = stage(transformed)

        # Residual connection around the four stages.
        return transformed + hidden_state
340
+
341
+
342
class ECAPA_TimeDelayNet(torch.nn.Module):
    """An implementation of the speaker embedding model in a paper.
    "ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
    TDNN Based Speaker Verification" (https://huggingface.co/papers/2005.07143).

    The constructor reads `mel_dim`, `enc_dim`, `enc_channels`, `enc_kernel_sizes`, `enc_dilations`,
    `enc_attention_channels`, `enc_res2net_scale` and `enc_se_channels` from `config`; these fields
    are the ones declared on the DiT decoder configuration.

    Raises:
        ValueError: if `enc_channels`, `enc_kernel_sizes` and `enc_dilations` differ in length.
    """

    def __init__(self, config: "Qwen3TTSTokenizerV1DecoderDiTConfig"):
        super().__init__()
        num_layers = len(config.enc_channels)
        if num_layers != len(config.enc_kernel_sizes) or num_layers != len(config.enc_dilations):
            raise ValueError("enc_channels, enc_kernel_sizes and enc_dilations should have same length")
        self.channels = config.enc_channels
        self.blocks = nn.ModuleList()

        # The initial TDNN layer
        self.blocks.append(
            TimeDelayNetBlock(
                config.mel_dim,
                config.enc_channels[0],
                config.enc_kernel_sizes[0],
                config.enc_dilations[0],
            )
        )

        # SE-Res2Net layers (every entry except the first and the last)
        for i in range(1, num_layers - 1):
            self.blocks.append(
                SqueezeExcitationRes2NetBlock(
                    config.enc_channels[i - 1],
                    config.enc_channels[i],
                    res2net_scale=config.enc_res2net_scale,
                    se_channels=config.enc_se_channels,
                    kernel_size=config.enc_kernel_sizes[i],
                    dilation=config.enc_dilations[i],
                )
            )

        # Multi-layer feature aggregation.
        # NOTE(review): its input is the channel-wise concatenation of the SE-Res2Net outputs, so this
        # presumably requires sum(enc_channels[1:-1]) == enc_channels[-1] -- true for the defaults.
        self.mfa = TimeDelayNetBlock(
            config.enc_channels[-1],
            config.enc_channels[-1],
            config.enc_kernel_sizes[-1],
            config.enc_dilations[-1],
        )

        # Attentive Statistical Pooling
        self.asp = AttentiveStatisticsPooling(
            config.enc_channels[-1],
            attention_channels=config.enc_attention_channels,
        )

        # Final linear transformation (pooling doubles the channel count: mean + std)
        self.fc = nn.Conv1d(
            in_channels=config.enc_channels[-1] * 2,
            out_channels=config.enc_dim,
            kernel_size=1,
            padding="same",
            padding_mode="reflect",
        )

    def forward(self, hidden_states):
        """Map the input to one embedding per batch item.

        Dimensions 1 and 2 of `hidden_states` are swapped before the first layer, so the layers
        operate with channels in dimension 1. The trailing size-one dimension left by the pooling
        is squeezed away before returning.
        """
        # Minimize transpose for efficiency
        hidden_states = hidden_states.transpose(1, 2)

        layer_outputs = []
        for layer in self.blocks:
            hidden_states = layer(hidden_states)
            layer_outputs.append(hidden_states)

        # Multi-layer feature aggregation: every output except the initial TDNN layer's.
        hidden_states = torch.cat(layer_outputs[1:], dim=1)
        hidden_states = self.mfa(hidden_states)

        # Attentive Statistical Pooling
        hidden_states = self.asp(hidden_states)

        # Final linear transformation
        hidden_states = self.fc(hidden_states)

        return hidden_states.squeeze(-1)
424
+
425
+
426
class DiTInputEmbedding(nn.Module):
    """Fuse the noised input with its conditioning signals into the DiT hidden size.

    The projection consumes the concatenation (on the last axis) of the noised
    input, the speaker-encoder output computed from `condition_vector`, the
    codec embedding and the speaker embedding, i.e.
    `mel_dim + enc_dim + emb_dim + enc_emb_dim` features per frame.
    """

    def __init__(self, config: Qwen3TTSTokenizerV1DecoderBigVGANConfig):
        super().__init__()
        self.proj = nn.Linear(
            config.mel_dim + config.enc_dim + config.enc_emb_dim + config.emb_dim,
            config.hidden_size,
        )
        self.spk_encoder = ECAPA_TimeDelayNet(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        speaker_embedding: torch.Tensor,
        condition_vector: torch.Tensor,
        code_embed: torch.Tensor,
        drop_audio_cond: Optional[bool] = False,
        code_embed_uncond: Optional[torch.Tensor] = None,
        apply_cfg: Optional[bool] = True,
    ):
        """Build the transformer input.

        Args:
            hidden_states: Noised input (presumably (batch, frames, mel_dim) -- TODO confirm).
            speaker_embedding: Per-frame speaker conditioning, concatenated on the last axis.
            condition_vector: Reference features fed to the speaker encoder.
            code_embed: Embedded codec tokens for the conditional branch.
            drop_audio_cond: When `apply_cfg` is false, zero out both audio conditionings.
            code_embed_uncond: Embedded codec tokens for the unconditional branch;
                required when `apply_cfg` is true.
            apply_cfg: Stack a conditional and an unconditional copy along the batch
                axis (classifier-free guidance), doubling the batch size.

        Returns:
            The projected hidden states; the batch axis is doubled when `apply_cfg` is true.

        Raises:
            ValueError: If `apply_cfg` is true but `code_embed_uncond` is missing.
        """
        if apply_cfg:
            if code_embed_uncond is None:
                raise ValueError("`code_embed_uncond` must be provided when `apply_cfg` is True.")
            # First half of the batch is conditional, second half has its
            # audio conditioning zeroed and uses the unconditional code embedding.
            hidden_states = torch.cat([hidden_states, hidden_states], dim=0)
            speaker_embedding = torch.cat([speaker_embedding, torch.zeros_like(speaker_embedding)], dim=0)
            condition_vector = torch.cat([condition_vector, torch.zeros_like(condition_vector)], dim=0)
            code_embed = torch.cat([code_embed, code_embed_uncond], dim=0)
        elif drop_audio_cond:  # cfg for cond audio
            condition_vector = torch.zeros_like(condition_vector)
            speaker_embedding = torch.zeros_like(speaker_embedding)

        # One encoder vector per item, repeated for every frame of the input.
        condition_vector = self.spk_encoder(condition_vector).unsqueeze(1).repeat(1, hidden_states.size(1), 1)
        hidden_states = self.proj(torch.cat((hidden_states, condition_vector, code_embed, speaker_embedding), dim=-1))

        return hidden_states
457
+
458
+
459
+ # Transformer backbone using DiT blocks
460
class DiTCodecEmbedding(nn.Module):
    """Embed codec token ids and stretch them along the sequence axis.

    The table holds `codec_num_embeds + 1` rows. When codes are dropped,
    every position is looked up at index 0 instead of its own id.
    """

    def __init__(self, codec_num_embeds, codec_dim, repeats):
        super().__init__()
        self.repeats = repeats
        self.codec_embed = nn.Embedding(codec_num_embeds + 1, codec_dim)

    def forward(self, code, drop_code=False):
        """Return embeddings of `code`, each repeated `self.repeats` times along dim 1."""
        # Dropping the code replaces every id by 0 before the lookup.
        indices = torch.zeros_like(code) if drop_code else code
        embedded = self.codec_embed(indices)
        return torch.repeat_interleave(embedded, repeats=self.repeats, dim=1)
473
+
474
+
475
+ # AdaLayerNormZero
476
+ # return with modulated x for attn input, and params for later mlp modulation
477
class AdaLayerNormZero(nn.Module):
    """Adaptive layer norm that also yields the modulation terms for the MLP branch.

    The conditioning embedding is passed through SiLU and a linear layer
    producing six chunks: shift/scale/gate for attention and shift/scale/gate
    for the feed-forward branch.
    """

    def __init__(self, dim):
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, dim * 6)

        # Affine-free norm: scale and shift come from the conditioning instead.
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, hidden_states, emb=None):
        """Return the modulated input plus (gate_msa, shift_mlp, scale_mlp, gate_mlp)."""
        modulation = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=1)

        # Insert an axis at dim 1 so the per-item terms broadcast over the sequence.
        normalized = self.norm(hidden_states)
        modulated = normalized * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        return modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp
492
+
493
+
494
+ # AdaLayerNormZero for final layer
495
+ # return only with modulated x for attn input, cuz no more mlp modulation
496
class AdaLayerNormZero_Final(nn.Module):
    """Adaptive layer norm for the last layer: only scale and shift, no gates."""

    def __init__(self, dim):
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, dim * 2)

        # Affine-free norm: scale and shift come from the conditioning instead.
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, hidden_states, emb):
        """Normalize `hidden_states` and modulate it with terms derived from `emb`."""
        modulation = self.linear(self.silu(emb))
        # The first half of the projection is the scale, the second half the shift.
        scale, shift = modulation.chunk(2, dim=1)

        normalized = self.norm(hidden_states)
        return normalized * (1 + scale)[:, None, :] + shift[:, None, :]
511
+
512
+
513
+ # FeedForward
514
class DiTMLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU(tanh) -> Dropout -> Linear.

    The stages are kept in a `ModuleList` named `ff`, so the two linear layers
    live at indices 0 and 3.
    """

    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        hidden_width = int(dim * mult)

        expand = nn.Linear(dim, hidden_width)
        activation = nn.GELU(approximate="tanh")
        regularizer = nn.Dropout(dropout)
        contract = nn.Linear(hidden_width, dim)
        self.ff = nn.ModuleList([expand, activation, regularizer, contract])

    def forward(self, hidden_states):
        """Apply the stages of `self.ff` one after another."""
        output = hidden_states
        for stage in self.ff:
            output = stage(output)
        return output
532
+
533
+
534
+ # Modified from Llama with a different rotate function; will be fixed in the next release
535
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate query and key tensors with a Rotary Position Embedding.

    Unlike the Llama variant, the rotation pairs *adjacent* features
    (0 with 1, 2 with 3, ...) rather than the two halves of the last axis.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis inserted into `cos` and `sin` so they broadcast against `q` and `k`.
            Use 1 when `q`/`k` are [batch_size, heads, seq_len, head_dim] and 2 when
            they are [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors.
    """

    def _rotate_adjacent_pairs(tensor):
        # View the last axis as (pairs, 2) and map each (a, b) to (-b, a).
        paired = tensor.reshape(*tensor.shape[:-1], -1, 2)
        first, second = paired.unbind(dim=-1)
        swapped = torch.stack((-second, first), dim=-1)
        return swapped.reshape(*swapped.shape[:-2], -1)

    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    rotated_q = (q * cos_b) + (_rotate_adjacent_pairs(q) * sin_b)
    rotated_k = (k * cos_b) + (_rotate_adjacent_pairs(k) * sin_b)
    return rotated_q, rotated_k
568
+
569
+
570
class DiTAttention(nn.Module):
    """Non-causal multi-head self-attention with rotary position embeddings.

    The attention kernel is looked up in `ALL_ATTENTION_FUNCTIONS` using the
    implementation named by the config.
    """

    def __init__(self, config: Qwen3TTSTokenizerV1DecoderBigVGANConfig):
        super().__init__()

        self.config = config
        self.dim = config.hidden_size
        self.heads = config.num_attention_heads
        self.inner_dim = config.head_dim * config.num_attention_heads
        self.dropout = config.dropout
        self.is_causal = False

        # Separate projections for queries, keys and values.
        self.to_q = nn.Linear(config.hidden_size, self.inner_dim)
        self.to_k = nn.Linear(config.hidden_size, self.inner_dim)
        self.to_v = nn.Linear(config.hidden_size, self.inner_dim)

        # Output projection followed by dropout, kept as a two-element list.
        self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, config.hidden_size), nn.Dropout(config.dropout)])

    def forward(
        self,
        hidden_states,  # noised input x
        position_embeddings=None,  # rotary position embedding for x
        attention_mask=None,
    ) -> torch.Tensor:
        """Attend over `hidden_states` and project the result back to the hidden size.

        Args:
            hidden_states: Input whose first axis is the batch.
            position_embeddings: `(cos, sin)` pair consumed by `apply_rotary_pos_emb`.
            attention_mask: Forwarded unchanged to the attention implementation.
        """
        batch_size = hidden_states.shape[0]

        projected_q = self.to_q(hidden_states)
        projected_k = self.to_k(hidden_states)
        projected_v = self.to_v(hidden_states)

        # Split the projection width evenly across the heads.
        head_dim = projected_k.shape[-1] // self.heads

        def split_heads(tensor):
            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
            return tensor.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)

        query = split_heads(projected_q)
        key = split_heads(projected_k)
        value = split_heads(projected_v)

        # NOTE(review): an older remark here said RoPE is applied to the first
        # head only; the call below rotates the full query/key tensors --
        # confirm which behaviour matches the released weights.
        cos, sin = position_embeddings
        query, key = apply_rotary_pos_emb(query, key, cos, sin)

        attend = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        context, _ = attend(
            self,
            query,
            key,
            value,
            attention_mask=attention_mask,
            is_causal=False,
        )

        # Merge the heads back into one feature axis and restore the query dtype.
        context = context.reshape(batch_size, -1, self.heads * head_dim)
        context = context.to(query.dtype)

        # Linear projection, then dropout.
        out_proj, out_dropout = self.to_out[0], self.to_out[1]
        return out_dropout(out_proj(context))
631
+
632
+
633
+ # time step conditioning embedding
634
class SinusPositionEmbedding(nn.Module):
    """Sinusoidal embedding of scalar values (used here for diffusion time steps)."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, hidden_states, scale=1000):
        """Return `[sin(angles), cos(angles)]` with `self.dim // 2` frequencies each.

        `hidden_states` is expanded with a trailing axis, multiplied by `scale`
        and by geometrically decaying frequencies running from 1 down to 1/10000.
        The result is cast back to the input's type.
        """
        half_dim = self.dim // 2
        # Log-spaced decay rate between consecutive frequencies.
        decay = math.log(10000) / (half_dim - 1)
        positions = torch.arange(half_dim, device=hidden_states.device).float()
        frequencies = torch.exp(positions * -decay)

        angles = scale * hidden_states.unsqueeze(1) * frequencies.unsqueeze(0)
        embedding = torch.cat((angles.sin(), angles.cos()), dim=-1)
        return embedding.type_as(hidden_states)
647
+
648
+
649
class DiTTimestepEmbedding(nn.Module):
    """Map diffusion time steps to conditioning vectors of width `dim`.

    A sinusoidal embedding of width `freq_embed_dim` is refined by a
    Linear -> SiLU -> Linear stack stored in `time_mlp`.
    """

    def __init__(self, dim, freq_embed_dim=256):
        super().__init__()
        self.time_embed = SinusPositionEmbedding(freq_embed_dim)
        self.time_mlp = nn.ModuleList([nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)])

    def forward(self, timestep):
        """Return one embedding row per entry of `timestep`."""
        # Keep the sinusoidal features in the dtype of the incoming time steps.
        embedding = self.time_embed(timestep).to(timestep.dtype)
        for stage in self.time_mlp:
            embedding = stage(embedding)  # b d
        return embedding
661
+
662
+
663
class DiTDecoderLayer(nn.Module):
    """One DiT block: time-modulated self-attention followed by a modulated MLP.

    Attention is restricted by block distance: a position may attend to
    positions whose block offset lies in
    `[-look_backward_block, look_ahead_block]`.
    """

    def __init__(self, config: Qwen3TTSTokenizerV1DecoderBigVGANConfig, look_ahead_block=0, look_backward_block=0):
        super().__init__()
        self.attn_norm = AdaLayerNormZero(config.hidden_size)

        self.attn = DiTAttention(config)
        self.look_ahead_block = look_ahead_block
        self.look_backward_block = look_backward_block
        self.ff_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False, eps=1e-6)
        self.ff = DiTMLP(dim=config.hidden_size, mult=config.ff_mult, dropout=config.dropout)

    def forward(
        self, hidden_states, timestep, position_embeddings=None, block_diff=None
    ):  # x: noised input, t: time embedding
        """Apply the gated attention and feed-forward residual branches.

        Args:
            hidden_states: Noised input.
            timestep: Time embedding driving the adaptive normalization.
            position_embeddings: Rotary `(cos, sin)` pair passed to the attention.
            block_diff: Pairwise block offsets used to build the attention mask.
        """
        # Pre-norm and modulation for the attention input; the MLP terms are kept for later.
        modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(hidden_states, emb=timestep)

        # Boolean window over block offsets: not too far back, not too far ahead.
        within_backward = block_diff >= -float(self.look_backward_block)
        within_ahead = block_diff <= float(self.look_ahead_block)
        attn_output = self.attn(
            hidden_states=modulated,
            position_embeddings=position_embeddings,
            attention_mask=within_backward & within_ahead,
        )

        # Gated residual connection around the attention branch.
        hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_output

        # Modulated feed-forward branch with its own gated residual.
        ff_input = self.ff_norm(hidden_states) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        ff_output = self.ff(ff_input)
        return hidden_states + gate_mlp.unsqueeze(1) * ff_output
696
+
697
+
698
class SnakeBeta(nn.Module):
    """
    Snake activation variant with separate frequency and magnitude parameters.

    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
      Both are stored in log scale (they are exponentiated in `forward`) and
      start at zero, so the effective frequency and magnitude start at 1.
    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://huggingface.co/papers/2006.08195
    """

    def __init__(self, in_features, alpha=1.0):
        super().__init__()
        self.in_features = in_features

        # One log-scale value per channel; multiplying zeros by `alpha` leaves them at zero.
        self.alpha = Parameter(torch.zeros(in_features) * alpha)
        self.beta = Parameter(torch.zeros(in_features) * alpha)

        # Small constant keeping the reciprocal of beta finite.
        self.no_div_by_zero = 0.000000001

    def forward(self, hidden_states):
        """
        Apply the activation elementwise.
        SnakeBeta := x + 1/b * sin^2 (xa)
        """
        # Reshape the per-channel parameters to (1, C, 1) so they line up with [B, C, T].
        frequency = torch.exp(self.alpha.unsqueeze(0).unsqueeze(-1))
        magnitude = torch.exp(self.beta.unsqueeze(0).unsqueeze(-1))

        periodic = torch.pow(torch.sin(hidden_states * frequency), 2)
        return hidden_states + (1.0 / (magnitude + self.no_div_by_zero)) * periodic
737
+
738
+
739
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):
    """Build a 1D low-pass filter: a sinc kernel shaped by a Kaiser window.

    Args:
        cutoff (float): Normalized cutoff frequency (0 to 0.5).
        half_width (float): Transition bandwidth.
        kernel_size (int): Number of filter taps.

    Returns:
        torch.Tensor: A float32 tensor of shape (1, 1, kernel_size). The taps sum
        to one, except for `cutoff == 0`, where an all-zero filter is returned.
    """
    half = kernel_size // 2

    # Kaiser window design: estimate the attenuation, then pick beta from it.
    transition = 4 * half_width
    attenuation = 2.285 * (half - 1) * math.pi * transition + 7.95
    if attenuation > 50.0:
        window_beta = 0.1102 * (attenuation - 8.7)
    elif attenuation >= 21.0:
        window_beta = 0.5842 * (attenuation - 21) ** 0.4 + 0.07886 * (attenuation - 21.0)
    else:
        window_beta = 0.0

    window = torch.kaiser_window(kernel_size, beta=window_beta, periodic=False, dtype=torch.float32)

    # Tap positions centred on zero; even sizes sit on half-integer offsets.
    if kernel_size % 2 == 0:
        tap_positions = torch.arange(-half, half) + 0.5
    else:
        tap_positions = torch.arange(kernel_size) - half

    if cutoff == 0:
        return torch.zeros((1, 1, kernel_size), dtype=torch.float32)  # keeps the expected shape

    taps = 2 * cutoff * window * torch.sinc(2 * cutoff * tap_positions)

    # Unit DC gain so the constant component is passed through unchanged.
    taps = taps / taps.sum()

    return taps.view(1, 1, kernel_size)
783
+
784
+
785
class UpSample1d(nn.Module):
    """Upsample along the last axis by `ratio` using a Kaiser-windowed sinc filter.

    When `kernel_size` is omitted it defaults to `int(6 * ratio // 2) * 2`.
    The filter is stored as a non-persistent buffer named `filter`.
    """

    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        if kernel_size is None:
            self.kernel_size = int(6 * ratio // 2) * 2
        else:
            self.kernel_size = kernel_size
        self.stride = ratio

        # Input-side padding, and how much of the transposed-conv output to trim per side.
        self.pad = self.kernel_size // ratio - 1
        surplus = self.kernel_size - self.stride
        self.pad_left = self.pad * self.stride + surplus // 2
        self.pad_right = self.pad * self.stride + (surplus + 1) // 2

        kernel = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size)
        self.register_buffer("filter", kernel, persistent=False)

    def forward(self, hidden_states):
        """Return the upsampled signal; dim 1 is treated as the channel axis."""
        channels = hidden_states.shape[1]

        padded = F.pad(hidden_states, (self.pad, self.pad), mode="replicate")
        # Same filter for every channel, applied depthwise (groups == channels).
        depthwise_kernel = self.filter.expand(channels, -1, -1)
        upsampled = self.ratio * F.conv_transpose1d(
            padded, depthwise_kernel, stride=self.stride, groups=channels
        )
        # Trim the extra samples produced by the padding and the kernel overlap.
        return upsampled[..., self.pad_left : -self.pad_right]
808
+
809
+
810
class DownSample1d(nn.Module):
    """Low-pass filter and decimate along the last axis by `ratio`.

    The anti-aliasing filter is a Kaiser-windowed sinc stored as a
    non-persistent buffer named `filter`.

    Args:
        ratio: Decimation factor, also used as the convolution stride.
        kernel_size: Number of filter taps. When omitted it defaults to
            `int(6 * ratio // 2) * 2`, the same default `UpSample1d` uses
            (previously the `None` default raised a `TypeError`).

    Raises:
        ValueError: If the derived cutoff `0.5 / ratio` falls outside [0, 0.5].
    """

    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        if kernel_size is None:
            kernel_size = int(6 * ratio // 2) * 2

        cutoff = 0.5 / ratio
        half_width = 0.6 / ratio

        if cutoff < 0.0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")

        # Even kernels have no centre tap, so pad one sample less on the left.
        self.even = kernel_size % 2 == 0
        self.pad_left = kernel_size // 2 - int(self.even)
        self.pad_right = kernel_size // 2
        self.stride = ratio
        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        self.register_buffer("filter", filter, persistent=False)

    def forward(self, hidden_states):
        """Return the filtered, strided signal; dim 1 is treated as the channel axis."""
        channels = hidden_states.shape[1]
        hidden_states = F.pad(hidden_states, (self.pad_left, self.pad_right), mode="replicate")
        # Same filter for every channel, applied depthwise (groups == channels).
        out = F.conv1d(hidden_states, self.filter.expand(channels, -1, -1), stride=self.stride, groups=channels)
        return out
833
+
834
+
835
class TorchActivation1d(nn.Module):
    """Apply an activation between an upsampling and a downsampling stage.

    Raises:
        TypeError: If `activation` is not callable.
    """

    def __init__(
        self,
        activation,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
    ):
        super().__init__()
        if not callable(activation):
            raise TypeError("Activation function must be callable")
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    def forward(self, hidden_states):
        """Upsample, apply the activation, then downsample."""
        upsampled = self.upsample(hidden_states)
        activated = self.act(upsampled)
        return self.downsample(activated)
857
+
858
+
859
class CausalConv1d(nn.Conv1d):
    """`nn.Conv1d` that pads only on the left, so outputs never see later samples.

    The left padding equals `dilation * (kernel_size - 1)`; nothing is added
    on the right.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        dilation = self.dilation[0]
        taps = self.kernel_size[0]
        self.causal_padding = dilation * (taps - 1)

    def forward(self, x):
        """Left-pad `x` along the last axis and run the underlying convolution."""
        left_padded = F.pad(x, [self.causal_padding, 0])
        return self._conv_forward(left_padded, self.weight, self.bias)
866
+
867
+
868
+ class AMPBlock(torch.nn.Module):
# Residual block built from three (activation, conv1, activation, conv2) stages,
# each activation being a SnakeBeta wrapped in TorchActivation1d.
869
+ def __init__(
870
+ self,
871
+ channels,
872
+ kernel_size=3,
873
+ dilation=(1, 3, 5),
874
+ causal_type='1',
875
+ ):
876
+ super().__init__()
877
+
# convs1: three left-padded (causal) convolutions, one per entry of `dilation`.
878
+ self.convs1 = nn.ModuleList(
879
+ [
880
+ CausalConv1d(
881
+ channels,
882
+ channels,
883
+ kernel_size,
884
+ 1,
885
+ dilation=dilation[0],
886
+ ),
887
+ CausalConv1d(
888
+ channels,
889
+ channels,
890
+ kernel_size,
891
+ 1,
892
+ dilation=dilation[1],
893
+ ),
894
+ CausalConv1d(
895
+ channels,
896
+ channels,
897
+ kernel_size,
898
+ 1,
899
+ dilation=dilation[2],
900
+ ),
901
+ ]
902
+ )
903
+
# convs2: three dilation-1 convolutions. causal_type '1' uses symmetrically
# padded nn.Conv1d; any other value uses CausalConv1d.
904
+ if causal_type == '1':
905
+ self.convs2 = nn.ModuleList(
906
+ [
907
+ nn.Conv1d(
908
+ channels,
909
+ channels,
910
+ kernel_size,
911
+ 1,
912
+ dilation=1,
913
+ padding=self._get_padding(kernel_size, 1),
914
+ ),
915
+ nn.Conv1d(
916
+ channels,
917
+ channels,
918
+ kernel_size,
919
+ 1,
920
+ dilation=1,
921
+ padding=self._get_padding(kernel_size, 1),
922
+ ),
923
+ nn.Conv1d(
924
+ channels,
925
+ channels,
926
+ kernel_size,
927
+ 1,
928
+ dilation=1,
929
+ padding=self._get_padding(kernel_size, 1),
930
+ ),
931
+ ]
932
+ )
933
+ else:
934
+ self.convs2 = nn.ModuleList(
935
+ [
936
+ CausalConv1d(
937
+ channels,
938
+ channels,
939
+ kernel_size,
940
+ 1,
941
+ dilation=1,
942
+ ),
943
+ CausalConv1d(
944
+ channels,
945
+ channels,
946
+ kernel_size,
947
+ 1,
948
+ dilation=1,
949
+ ),
950
+ CausalConv1d(
951
+ channels,
952
+ channels,
953
+ kernel_size,
954
+ 1,
955
+ dilation=1,
956
+ ),
957
+ ]
958
+ )
959
+
960
+ self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
961
+
# One activation module per convolution; forward() pairs even indices with
# convs1 and odd indices with convs2.
962
+ self.activations = nn.ModuleList(
963
+ [TorchActivation1d(activation=SnakeBeta(channels)) for _ in range(self.num_layers)]
964
+ )
965
+
# causal_type '2' adds a learned pre-convolution and activation; every other
# value passes the input through unchanged.
966
+ if causal_type == '2':
967
+ self.pre_conv = nn.Conv1d(
968
+ channels,
969
+ channels,
970
+ kernel_size,
971
+ stride=1,
972
+ padding=self._get_padding(kernel_size, 1),
973
+ )
974
+ self.pre_act = TorchActivation1d(activation=SnakeBeta(channels))
975
+ else:
976
+ self.pre_conv = nn.Identity()
977
+ self.pre_act = nn.Identity()
978
+
# Symmetric padding of dilation * (kernel_size - 1) / 2, truncated to int;
# this keeps the length unchanged at stride 1 when kernel_size is odd.
979
+ def _get_padding(self, kernel_size, dilation=1):
980
+ return int((kernel_size * dilation - dilation) / 2)
981
+
# Runs the optional pre-conv/pre-act, then the three conv stages, and adds the
# result to the input `x`.
# NOTE(review): indentation is lost in this listing, so it is unclear whether
# `x = x + hidden_states` runs on every loop iteration or once after the loop.
# Confirm against the original file before restructuring this method.
982
+ def forward(self, x):
983
+ hidden_states = self.pre_conv(x)
984
+ hidden_states = self.pre_act(hidden_states)
985
+ acts1, acts2 = self.activations[::2], self.activations[1::2]
986
+ for conv1, conv2, act1, act2 in zip(self.convs1, self.convs2, acts1, acts2):
987
+ hidden_states = act1(hidden_states)
988
+ hidden_states = conv1(hidden_states)
989
+ hidden_states = act2(hidden_states)
990
+ hidden_states = conv2(hidden_states)
991
+ x = x + hidden_states
992
+ return x
993
+
994
+
995
@auto_docstring
class Qwen3TTSTokenizerV1DecoderBigVGANModel(Qwen3TTSTokenizerV1DecoderPreTrainedModel):
    # Vocoder: turns a mel spectrogram into a waveform through a stack of
    # transposed convolutions, each followed by averaged AMP residual blocks.
    config: Qwen3TTSTokenizerV1DecoderBigVGANConfig

    def __init__(self, config: Qwen3TTSTokenizerV1DecoderBigVGANConfig):
        super().__init__(config)
        self.num_residual_blocks = len(config.resblock_kernel_sizes)
        self.num_upsample_layers = len(config.upsample_rates)

        base_channels = config.upsample_initial_channel
        self.conv_pre = nn.Conv1d(config.mel_dim, base_channels, 5, 1, padding=2)

        # Each upsampler is wrapped in its own single-element ModuleList:
        # removing that extra level breaks the official state dict.
        upsamplers = []
        for layer_idx, (stride, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)):
            transposed = nn.ConvTranspose1d(
                base_channels // (2**layer_idx),
                base_channels // (2 ** (layer_idx + 1)),
                kernel_size,
                stride,
                padding=(kernel_size - stride) // 2,
            )
            upsamplers.append(nn.ModuleList([transposed]))
        self.ups = nn.ModuleList(upsamplers)

        # Residual blocks are stored flat, layer-major: all blocks of layer 0 first.
        # The first two upsample layers use causal type '2', later ones type '1'.
        residual_blocks = []
        for layer_idx in range(self.num_upsample_layers):
            width = base_channels // (2 ** (layer_idx + 1))
            causal_type = '1' if layer_idx > 1 else '2'
            for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
                residual_blocks.append(AMPBlock(width, kernel_size, dilation, causal_type))
        self.resblocks = nn.ModuleList(residual_blocks)

        final_channels = config.upsample_initial_channel // (2**self.num_upsample_layers)
        self.activation_post = TorchActivation1d(activation=SnakeBeta(final_channels))
        self.conv_post = nn.Conv1d(
            config.upsample_initial_channel // (2**self.num_upsample_layers), 1, 7, 1, padding=3, bias=False
        )

    def normalize_spectrogram(self, spectrogram, max_value, min_db):
        """Linearly map `[min_db, 0]` onto `[-max_value, max_value]` and clamp to that range."""
        unit_interval = (spectrogram - min_db) / (-min_db)
        rescaled = (2 * max_value) * unit_interval - max_value
        return torch.clamp(rescaled, -max_value, max_value)

    def amplitude_to_db(self, amplitude, min_db_level):
        """Convert amplitudes to decibels, flooring them at the level `min_db_level` corresponds to."""
        # exp(dB / 20 * ln 10) is the amplitude matching `min_db_level`.
        log_floor = torch.tensor(min_db_level / 20.0 * np.log(10), device=amplitude.device, dtype=amplitude.dtype)
        floor = torch.exp(log_floor)
        return 20 * torch.log10(torch.clamp(amplitude, min=floor))

    def process_mel_spectrogram(self, mel_spectrogram):
        """Exponentiate the input, convert to dB (floor -115, offset -20) and normalize to [-1, 1]."""
        amplitude = torch.exp(mel_spectrogram)
        decibels = self.amplitude_to_db(amplitude, -115) - 20
        return self.normalize_spectrogram(decibels, 1, -115)

    def forward(self, mel_spectrogram):
        """Synthesize a waveform clamped to [-1, 1], with the channel axis (dim 1) squeezed out."""
        hidden = self.conv_pre(self.process_mel_spectrogram(mel_spectrogram))

        blocks_per_layer = self.num_residual_blocks
        for layer_index in range(self.num_upsample_layers):
            hidden = self.ups[layer_index][0](hidden)
            # Average the outputs of this layer's residual blocks, all fed the same input.
            first_block = layer_index * blocks_per_layer
            block_sum = sum(
                self.resblocks[first_block + block_index](hidden)
                for block_index in range(blocks_per_layer)
            )
            hidden = block_sum / blocks_per_layer

        hidden = self.activation_post(hidden)
        waveform = self.conv_post(hidden)
        return torch.clamp(waveform, min=-1.0, max=1.0).squeeze(1)
1068
+
1069
+
1070
@auto_docstring
class Qwen3TTSTokenizerV1DecoderDiTModel(Qwen3TTSTokenizerV1DecoderPreTrainedModel):
    # Diffusion transformer that predicts an update direction for a noised mel
    # spectrogram, conditioned on codec tokens, a reference spectrogram and a
    # conditioning vector; `sample` integrates it to generate a spectrogram.
    config: Qwen3TTSTokenizerV1DecoderDiTConfig
    _no_split_modules = ["DiTDecoderLayer"]

    def __init__(self, config: Qwen3TTSTokenizerV1DecoderDiTConfig):
        super().__init__(config)
        self.mel_dim = config.mel_dim
        self.repeats = config.repeats
        self.time_embed = DiTTimestepEmbedding(config.hidden_size)

        self.text_embed = DiTCodecEmbedding(config.num_embeds, config.emb_dim, config.repeats)
        self.input_embed = DiTInputEmbedding(config)

        self.rotary_embed = Qwen3TTSTokenizerV1DecoderDiTRotaryEmbedding(config.head_dim)

        self.hidden_size = config.hidden_size
        self.layers = config.num_hidden_layers
        self.block_size = config.block_size
        self.num_attention_heads = config.num_attention_heads

        # Layers listed in the config may look one block ahead and/or one block back.
        self.transformer_blocks = nn.ModuleList()
        for i in range(config.num_hidden_layers):
            self.transformer_blocks.append(
                DiTDecoderLayer(
                    config,
                    look_ahead_block=1 if i in config.look_ahead_layers else 0,
                    look_backward_block=1 if i in config.look_backward_layers else 0,
                )
            )

        self.norm_out = AdaLayerNormZero_Final(config.hidden_size)  # final modulation
        self.proj_out = nn.Linear(config.hidden_size, config.mel_dim)

    def _create_block_diff(self, hidden_states):
        """Return pairwise block offsets, expanded to (batch, heads, seq_len, seq_len).

        Positions are grouped into blocks of `self.block_size`; entry `[i, j]`
        holds `block(j) - block(i)`.
        """
        batch, seq_len = hidden_states.shape[0], hidden_states.shape[1]
        block_indices = torch.arange(seq_len, device=hidden_states.device) // self.block_size  # [seq_length]

        block_i = block_indices.unsqueeze(1)  # [seq_length, 1]
        block_j = block_indices.unsqueeze(0)  # [1, seq_length]
        block_diff = block_j - block_i  # (n, n)

        return block_diff.expand(batch, self.num_attention_heads, seq_len, seq_len)

    def forward(
        self,
        hidden_states,
        condition_vector,
        speaker_embedding,
        quantized_code,
        time_step,
        drop_audio_conditioning=False,
        drop_code=False,
        apply_cfg=True,
    ):
        """Predict the mel-space output for one time step.

        Args:
            hidden_states: Current noised state.
            condition_vector: Reference features passed to the input embedding's speaker encoder.
            speaker_embedding: Per-frame conditioning concatenated in the input embedding.
            quantized_code: Codec token ids.
            time_step: Scalar or per-item time step; a scalar is repeated over the batch.
            drop_audio_conditioning: Without CFG, zero out the audio conditioning.
            drop_code: Without CFG, replace the codes by index 0.
            apply_cfg: Evaluate a conditional and an unconditional copy in one pass;
                the output batch is then twice the input batch.

        Returns:
            Tensor of width `mel_dim` on the last axis.
        """
        # The input embedding doubles the batch only under CFG, so the scalar
        # time step has to be repeated to match the batch the blocks will see.
        batch_size = hidden_states.shape[0] * (2 if apply_cfg else 1)
        if time_step.ndim == 0:
            time_step = time_step.repeat(batch_size)

        # Compute embeddings
        time_embedding = self.time_embed(time_step)
        text_embedding = self.text_embed(quantized_code, drop_code=False if apply_cfg else drop_code)
        text_embedding_unconditioned = self.text_embed(quantized_code, drop_code=True) if apply_cfg else None

        hidden_states = self.input_embed(
            hidden_states,
            speaker_embedding,
            condition_vector,
            text_embedding,
            drop_audio_cond=drop_audio_conditioning,
            code_embed_uncond=text_embedding_unconditioned,
            apply_cfg=apply_cfg,
        )

        # Compute positional encodings
        position_embeddings = self.rotary_embed(hidden_states)
        blockwise_difference = self._create_block_diff(hidden_states)

        # Transformer blocks
        for transformer_block in self.transformer_blocks:
            hidden_states = transformer_block(
                hidden_states,
                time_embedding,
                position_embeddings=position_embeddings,
                block_diff=blockwise_difference,
            )

        hidden_states = self.norm_out(hidden_states, time_embedding)
        output = self.proj_out(hidden_states)

        return output

    def optimized_scale(self, positive_flat, negative_flat):
        """Return the per-row projection coefficient of `positive_flat` onto `negative_flat`."""
        # Calculate dot production
        dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
        # Squared norm of uncondition (epsilon avoids division by zero)
        squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
        # st_star = v_cond^T * v_uncond / ||v_uncond||^2
        st_star = dot_product / squared_norm
        return st_star

    @torch.no_grad()
    def sample(
        self,
        conditioning_vector,
        reference_mel_spectrogram,
        quantized_code,
        num_steps=10,
        guidance_scale=0.5,
        sway_coefficient=-1.0,
    ):
        """Generate a mel spectrogram by Euler integration from noise.

        Args:
            conditioning_vector: One vector per item, repeated over all output frames.
            reference_mel_spectrogram: Reference features for the speaker encoder.
            quantized_code: Codec token ids; output length is `codes * self.repeats` frames.
            num_steps: Number of time points between 0 and 1 (so `num_steps - 1` updates).
            guidance_scale: Classifier-free guidance strength; values below 1e-5
                disable guidance and run a single conditional pass per step.
            sway_coefficient: Strength of the cosine warp applied to the time grid,
                or `None` to keep it uniform.

        Returns:
            Tensor with the frame and feature axes swapped, i.e. (batch, mel_dim, frames).
        """
        # NOTE(review): noise is drawn for a fixed 30000 frames and then sliced, so
        # longer sequences would be silently truncated here. Kept as-is because
        # changing the draw size would alter seeded outputs.
        noise_initialization = torch.randn([quantized_code.shape[0], 30000, self.mel_dim], dtype=reference_mel_spectrogram.dtype)
        maximum_duration = quantized_code.shape[1] * self.repeats
        initial_state = noise_initialization[:, :maximum_duration].to(quantized_code.device)
        conditioning_vector = conditioning_vector.unsqueeze(1).repeat(1, maximum_duration, 1)

        def ode_function(time_step, hidden_states):
            if guidance_scale < 1e-5:
                # No guidance: one conditional pass. `apply_cfg=False` is required,
                # otherwise the model returns a doubled batch that cannot be added
                # back onto the state in the Euler update below.
                prediction = self(
                    hidden_states=hidden_states,
                    speaker_embedding=conditioning_vector,
                    condition_vector=reference_mel_spectrogram,
                    quantized_code=quantized_code,
                    time_step=time_step,
                    drop_audio_conditioning=False,
                    drop_code=False,
                    apply_cfg=False,
                )
                return prediction

            model_output = self(
                hidden_states=hidden_states,
                quantized_code=quantized_code,
                speaker_embedding=conditioning_vector,
                condition_vector=reference_mel_spectrogram,
                time_step=time_step,
                apply_cfg=True,
            )
            # First half of the batch is the conditional branch, second half the unconditional one.
            guided_prediction, null_prediction = torch.chunk(model_output, 2, dim=0)

            return guided_prediction + (guided_prediction - null_prediction) * guidance_scale

        initial_time = 0
        time_embedding = torch.linspace(
            initial_time, 1, num_steps, device=quantized_code.device, dtype=conditioning_vector.dtype
        )

        if sway_coefficient is not None:
            # Warp the uniform grid; the end points 0 and 1 are left unchanged.
            time_embedding += sway_coefficient * (torch.cos(torch.pi / 2 * time_embedding) - 1 + time_embedding)

        # Explicit Euler steps over consecutive time points.
        values = initial_state.clone()
        for t0, t1 in zip(time_embedding[:-1], time_embedding[1:]):
            dt = t1 - t0
            vt = ode_function(t0, values)
            values = values + vt * dt

        generated_mel_spectrogram = values.permute(0, 2, 1)
        return generated_mel_spectrogram
1227
+
1228
+
1229
@auto_docstring
class Qwen3TTSTokenizerV1Decoder(Qwen3TTSTokenizerV1DecoderPreTrainedModel):
    # Two-stage decoder: a DiT samples a mel spectrogram from codec tokens,
    # then a BigVGAN vocoder renders it as a waveform.
    config: Qwen3TTSTokenizerV1DecoderConfig
    base_model_prefix = "model"
    _no_split_modules = ["Qwen3TTSTokenizerV1DecoderDiTModel", "Qwen3TTSTokenizerV1DecoderBigVGANModel"]

    def __init__(self, config: Qwen3TTSTokenizerV1DecoderConfig):
        super().__init__(config)

        # Both unsupported implementations fall back to sdpa, each with its own warning.
        requested_impl = config._attn_implementation
        if requested_impl == "flash_attention_2":
            logger.warning_once(
                "Qwen3TTSTokenizerV1Decoder must inference with fp32, but flash_attention_2 only supports fp16 and bf16, "
                "attention implementation of Qwen3TTSTokenizerV1Decoder will fallback to sdpa."
            )
        elif requested_impl == "eager":
            logger.warning_once(
                "Qwen3TTSTokenizerV1Decoder does not support eager attention implementation, fall back to sdpa"
            )
        attn_impl = "sdpa" if requested_impl in ("flash_attention_2", "eager") else requested_impl

        self.dit = Qwen3TTSTokenizerV1DecoderDiTModel._from_config(
            config.dit_config, attn_implementation=attn_impl
        )
        self.bigvgan = Qwen3TTSTokenizerV1DecoderBigVGANModel._from_config(
            config.bigvgan_config, attn_implementation=attn_impl
        )

    def forward(
        self,
        code,
        conditioning,
        reference_mel,
        num_steps=10,
        guidance_scale=0.5,
        sway_coefficient=-1.0,
        **kwargs,
    ):
        """Generates a waveform from input code and conditioning parameters."""
        # Stage 1: sample a mel spectrogram with the diffusion transformer.
        sampled_mel = self.dit.sample(
            conditioning,
            reference_mel,
            code,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            sway_coefficient=sway_coefficient,
        )

        # Stage 2: vocode the spectrogram.
        return self.bigvgan(sampled_mel)
1280
+
1281
+
1282
class Qwen3TTSTokenizerV1Encoder(Qwen3TTSTokenizerV1EncoderPreTrainedModel):
    """Speech tokenizer front end: waveform -> mel spectrogram -> VQ code indices."""

    config: Qwen3TTSTokenizerV1EncoderConfig

    def __init__(self, config: Qwen3TTSTokenizerV1EncoderConfig):
        super().__init__(config)

        self.tokenizer = WhisperEncoderVQ(
            n_mels=config.n_mels,
            n_ctx=config.n_ctx,
            n_state=config.n_state,
            n_head=config.n_head,
            n_layer=config.n_layer,
            n_window=config.n_window,
            output_dim=config.output_dim,
            grad_checkpointing=config.grad_checkpointing,
            enable_mp=config.enable_mp,
            audio_sequence_parallel=config.audio_sequence_parallel,
            audio_vq_type=config.audio_vq_type,
            audio_vq_layers=config.audio_vq_layers,
            audio_vq_codebook_size=config.audio_vq_codebook_size,
            audio_vq_codebook_dim=config.audio_vq_codebook_dim,
            audio_vq_pe=config.audio_vq_pe,
            audio_vq_ds_rate=config.audio_vq_ds_rate,
        )

        self.padding = True
        self.audio_vq_ds_rate = self.tokenizer.audio_vq_ds_rate

    def speech2mel(self, speechs):
        """Return one mel spectrogram per waveform, in the waveform's dtype and on the tokenizer's device."""
        mels = []
        for speech in speechs:
            mel = get_mel_audio(speech, padding=self.padding, audio_vq_ds_rate=self.audio_vq_ds_rate)
            mel = mel.to(speech.dtype).to(self.tokenizer.conv1.weight.device)
            mels.append(mel)
        return mels

    def mel2code(self, mels):
        """Quantize mel spectrograms.

        Returns:
            A `(indices, indice_lens)` pair: the zero-padded, batch-first code
            indices and the number of valid codes for each item.
        """
        # Length bookkeeping: mel frames, frames after the CNN, and that length plus 2.
        mel_lengths = [mel.size(-1) for mel in mels]
        post_cnn_lengths = [get_T_after_cnn(length) for length in mel_lengths]
        sequence_lengths = [length + 2 for length in post_cnn_lengths]

        with torch.no_grad():
            _, flat_indices = self.tokenizer(
                x_list=mels,
                audio_mellens=mel_lengths,
                audio_aftercnnlens=post_cnn_lengths,
                audio_seqlens=sequence_lengths,
                return_indices=True,
            )

        # The tokenizer returns all items concatenated; split per item, then pad.
        indice_lens = [length // self.tokenizer.audio_vq_ds_rate for length in post_cnn_lengths]
        per_item = torch.split(flat_indices, indice_lens)
        indices = pad_sequence(per_item, batch_first=True, padding_value=0)

        return indices, indice_lens

    def quantize_speech(self, speechs):
        """Convert waveforms straight to padded code indices and their lengths."""
        return self.mel2code(self.speech2mel(speechs))
1341
+
1342
+
1343
@auto_docstring
class Qwen3TTSTokenizerV1PreTrainedModel(PreTrainedModel):
    # Configuration class associated with this model family.
    config: Qwen3TTSTokenizerV1Config

    # Naming / placement settings read by the `PreTrainedModel` base class.
    base_model_prefix = "model"
    _skip_keys_device_placement = "past_key_values"

    # Capability flags (enabled).
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_attention_backend = True

    # Capability flags (disabled).
    _can_compile_fullgraph = False
1353
+
1354
+
1355
@auto_docstring(
    custom_intro="""
    The Qwen3TTSTokenizerV1 model.
    """
)
class Qwen3TTSTokenizerV1Model(Qwen3TTSTokenizerV1PreTrainedModel):
    def __init__(self, config: Qwen3TTSTokenizerV1Config):
        super().__init__(config)
        self.config = config

        self.input_sample_rate = config.input_sample_rate
        self.output_sample_rate = config.output_sample_rate

        self.decode_upsample_rate = config.decode_upsample_rate
        self.encode_downsample_rate = config.encode_downsample_rate

        self.encoder = Qwen3TTSTokenizerV1Encoder._from_config(self.config.encoder_config)
        self.decoder = Qwen3TTSTokenizerV1Decoder._from_config(self.config.decoder_config)

        # Not built from the config: attached afterwards through
        # `load_encoder_xvector_extractor` (done by `from_pretrained`).
        self.encoder_xvector_extractor = None

        self.post_init()

    def load_encoder_xvector_extractor(self, model_path):
        """Create the `XVectorExtractor` from `model_path` and attach it to the model."""
        self.encoder_xvector_extractor = XVectorExtractor(model_path)

    def get_model_type(self):
        """Return `config.model_type`."""
        return self.config.model_type

    def get_input_sample_rate(self):
        """Return the input sample rate copied from the config."""
        return self.input_sample_rate

    def get_output_sample_rate(self):
        """Return the output sample rate copied from the config."""
        return self.output_sample_rate

    def get_encode_downsample_rate(self):
        """Return the encode downsample rate copied from the config."""
        return self.encode_downsample_rate

    def get_decode_upsample_rate(self):
        """Return the decode upsample rate copied from the config."""
        return self.decode_upsample_rate

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        *model_args,
        config=None,
        cache_dir=None,
        ignore_mismatched_sizes=False,
        force_download=False,
        local_files_only=False,
        token=None,
        revision="main",
        use_safetensors=None,
        weights_only=True,
        **kwargs,
    ):
        """Load the model weights, then resolve and load the `campplus.onnx` x-vector extractor.

        The extractor file is looked up with the same `cache_dir`,
        `force_download`, `local_files_only`, `token` and `revision` that were
        used for the model weights.

        Raises:
            ValueError: If `campplus.onnx` cannot be resolved.
        """
        model = super().from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            use_safetensors=use_safetensors,
            weights_only=weights_only,
            **kwargs,
        )

        # These options are named parameters of this method, so they never
        # appear in `kwargs`; they must be forwarded explicitly or the
        # extractor lookup would silently fall back to defaults.
        resolved_token = token if token is not None else kwargs.get("use_auth_token")
        encoder_xvector_extractor_path = cached_file(
            pretrained_model_name_or_path,
            "campplus.onnx",
            subfolder=kwargs.get("subfolder"),
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=kwargs.get("proxies"),
            resume_download=kwargs.get("resume_download"),
            local_files_only=local_files_only,
            token=resolved_token,
            revision=revision,
        )
        if encoder_xvector_extractor_path is None:
            raise ValueError(f"campplus.onnx not found under {pretrained_model_name_or_path}")
        model.load_encoder_xvector_extractor(encoder_xvector_extractor_path)

        return model

    def encode(
        self,
        input_values: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor, Optional[torch.Tensor]], Qwen3TTSTokenizerV1EncoderOutput]:
        """
        Encodes the input audio waveform into discrete codes.

        Args:
            input_values (`torch.Tensor` of shape `(batch_size, sequence_length)`):
                Float values of the input audio waveform.
            padding_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Indicates which inputs are to be ignored due to padding, where elements are either 1 for *not masked* or 0
                for *masked*. When omitted, every waveform is used at its full length.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Returns:
            Per-utterance lists of codes, x-vectors and reference mels, either as a tuple or wrapped in a
            `Qwen3TTSTokenizerV1EncoderOutput`.

        Raises:
            RuntimeError: If the x-vector extractor has not been loaded.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if self.encoder_xvector_extractor is None:
            raise RuntimeError(
                "The x-vector extractor is not loaded; call `load_encoder_xvector_extractor` "
                "or create the model with `from_pretrained`."
            )

        if padding_mask is None:
            wavs = list(input_values)
        else:
            # Keep the first `mask.sum()` samples of each waveform.
            wavs = [value[:mask.sum()] for value, mask in zip(input_values, padding_mask)]

        codes, codes_lens = self.encoder.quantize_speech(wavs)
        # Drop the padding added when the code sequences were batched.
        codes = [c[:l] for c, l in zip(codes, codes_lens)]

        xvectors = []
        ref_mels = []
        for wav in wavs:
            # The extractor is fed a NumPy array; its outputs are converted
            # back to tensors matching the waveform's dtype and device.
            xvector, ref_mel = self.encoder_xvector_extractor.extract_code(wav.cpu().numpy())
            xvectors.append(torch.tensor(xvector).to(wav.dtype).to(wav.device))
            ref_mels.append(torch.tensor(ref_mel).to(wav.dtype).to(wav.device))

        if not return_dict:
            return (
                codes,
                xvectors,
                ref_mels,
            )

        return Qwen3TTSTokenizerV1EncoderOutput(codes, xvectors, ref_mels)

    def decode(
        self,
        audio_codes: torch.Tensor,
        xvectors: torch.Tensor,
        ref_mels: torch.Tensor,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor, torch.Tensor], Qwen3TTSTokenizerV1DecoderOutput]:
        """
        Decodes the given frames into an output audio waveform.

        Note that the output might be a bit bigger than the input. In that case, any extra steps at the end can be
        trimmed.

        Args:
            audio_codes (`torch.LongTensor` of shape `(batch_size, codes_length)`, *optional*):
                Discrete code embeddings computed using `model.encode`.
            xvectors (`torch.FloatTensor` of shape `(batch_size, xvector_dim)`, *optional*):
                X-vector embeddings computed using `model.encode`.
            ref_mels (`torch.FloatTensor` of shape `(batch_size, mel_length, mel_dim)`, *optional*):
                Reference mel spectrogram computed using `model.encode`.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        audio_values = self.decoder(code=audio_codes,
                                    reference_mel=ref_mels,
                                    conditioning=xvectors)

        # Each waveform is trimmed to (number of codes > 0) * upsample rate.
        # NOTE(review): a genuine code with index 0 is not counted here and so
        # shortens the output -- confirm that 0 is reserved for padding.
        audio_lengths = (audio_codes > 0).sum(1) * self.decode_upsample_rate
        audio_values = [a[:l] for a, l in zip(audio_values, audio_lengths)]

        if not return_dict:
            return (
                audio_values,
            )

        return Qwen3TTSTokenizerV1DecoderOutput(audio_values)
1526
+
1527
+
1528
+ __all__ = ["Qwen3TTSTokenizerV1Model", "Qwen3TTSTokenizerV1PreTrainedModel"]
qwen_tts/core/tokenizer_25hz/vq/assets/mel_filters.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7450ae70723a5ef9d341e3cee628c7cb0177f36ce42c44b7ed2bf3325f0f6d4c
3
+ size 4271
qwen_tts/core/tokenizer_25hz/vq/core_vq.py ADDED
@@ -0,0 +1,523 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ # This implementation is inspired from
8
+ # https://github.com/lucidrains/vector-quantize-pytorch
9
+ # which is released under MIT License. Hereafter, the original license:
10
+ # MIT License
11
+ #
12
+ # Copyright (c) 2020 Phil Wang
13
+ #
14
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
15
+ # of this software and associated documentation files (the "Software"), to deal
16
+ # in the Software without restriction, including without limitation the rights
17
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18
+ # copies of the Software, and to permit persons to whom the Software is
19
+ # furnished to do so, subject to the following conditions:
20
+ #
21
+ # The above copyright notice and this permission notice shall be included in all
22
+ # copies or substantial portions of the Software.
23
+ #
24
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
27
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
29
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30
+ # SOFTWARE.
31
+
32
+ """Core vector quantization implementation."""
33
+ import random
34
+ import typing as tp
35
+ from random import randrange
36
+
37
+ import numpy as np
38
+ from einops import rearrange, repeat
39
+ from math import ceil
40
+ import torch
41
+ from torch import nn
42
+ import torch.nn.functional as F
43
+
44
+
45
def round_up_multiple(num, mult):
    """Round `num` up to a multiple of `mult`, computed as `ceil(num / mult) * mult`."""
    n_chunks = ceil(num / mult)
    return n_chunks * mult
47
+
48
def default(val: tp.Any, d: tp.Any) -> tp.Any:
    """Return `val`, or the fallback `d` when `val` is `None`."""
    if val is None:
        return d
    return val
50
+
51
+
52
def ema_inplace(moving_avg, new, decay: float):
    """Update `moving_avg` in place to `decay * moving_avg + (1 - decay) * new`."""
    storage = moving_avg.data
    storage.mul_(decay)
    storage.add_(new, alpha=(1 - decay))
54
+
55
+
56
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    """Return `(x + epsilon) / (x.sum() + n_categories * epsilon)`."""
    smoothed = x + epsilon
    total = x.sum() + n_categories * epsilon
    return smoothed / total
58
+
59
+
60
def uniform_init(*shape: int):
    """Return a new tensor of the given shape filled by `nn.init.kaiming_uniform_`."""
    # `kaiming_uniform_` fills its argument in place and returns that same tensor.
    return nn.init.kaiming_uniform_(torch.empty(shape))
64
+
65
+
66
def sample_vectors(samples, num: int):
    """Pick `num` rows of `samples` at random.

    Rows are drawn without replacement when enough rows are available, and
    with replacement otherwise.
    """
    device = samples.device
    available = samples.shape[0]

    if available < num:
        picks = torch.randint(0, available, (num,), device=device)
    else:
        picks = torch.randperm(available, device=device)[:num]

    return samples[picks]
75
+
76
+
77
@torch.no_grad()
def kmeans(samples, num_clusters: int, num_iters: int = 10):
    """Run k-means on the rows of `samples`.

    Centroids start as rows picked by `sample_vectors` and are refined for
    `num_iters` iterations. A cluster that receives no sample in an iteration
    keeps its previous centroid.

    Args:
        samples: 2-D tensor whose rows are the points to cluster.
        num_clusters: Number of centroids.
        num_iters: Number of refinement iterations.

    Returns:
        A tuple `(means, bins)`: the centroids and the number of samples
        assigned to each cluster in the last iteration (all zeros when
        `num_iters` is 0).
    """
    dim, dtype = samples.shape[-1], samples.dtype

    means = sample_vectors(samples, num_clusters)
    # Bound before the loop so that `num_iters=0` returns the initial
    # centroids instead of failing on an unbound `bins`.
    bins = torch.zeros(num_clusters, dtype=torch.long, device=samples.device)

    for _ in range(num_iters):
        # Negated squared Euclidean distance, expanded as |s|^2 - 2 s.m + |m|^2.
        dists = -(
            samples.pow(2).sum(1, keepdim=True)
            - 2 * torch.matmul(samples, means.t())
            + means.t().pow(2).sum(0, keepdim=True)
        )

        # The largest negated distance is the nearest centroid.
        buckets = dists.max(dim=-1).indices
        del dists
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        # Avoid dividing by zero for empty clusters; their result is
        # discarded by the `torch.where` below.
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
        new_means = new_means / bins_min_clamped[..., None]

        means = torch.where(zero_mask[..., None], means, new_means)
    return means, bins
102
+
103
+
104
def preprocess(x):
    """Flatten all leading axes of `x` into one, keeping the last axis."""
    return rearrange(x, "... d -> (...) d")
107
+
108
+
109
def postprocess_emb(embed_ind, shape):
    """View `embed_ind` with every dimension of `shape` except the last one."""
    leading_dims = shape[:-1]
    return embed_ind.view(*leading_dims)
111
+
112
+
113
class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.

    The codebook state (`inited`, `cluster_size`, `embed`, `embed_avg`) is not
    created by this module. Callers pass it as `buffers` to `encode`, `decode`
    and `forward`, which bind it onto the instance before using it.

    Args:
        dim (int): Dimension. Accepted but not used by this class.
        codebook_size (int): Codebook size.
        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
            Accepted but not read by this class; initialization is driven by
            the `inited` buffer supplied by the caller.
        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: float = 2.0,
    ):
        super().__init__()
        self.decay = decay
        self.codebook_size = codebook_size
        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code

        # Placeholders for the externally owned state; filled on every call
        # to `encode`, `decode` or `forward`.
        self.inited = None
        self.cluster_size = None
        self.embed = None
        self.embed_avg = None
        self.training = True

    def init_embed_(self, data):
        """Initialize the codebook from `data` with k-means unless `inited` is already set."""
        if self.inited:
            return

        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(torch.Tensor([True]))
        # Make sure all buffers across workers are in sync after initialization
        # distrib.broadcast_tensors([self.embed, self.embed_avg, self.cluster_size, self.inited])

    def replace_(self, samples, mask):
        """Overwrite the codebook entries selected by `mask` with rows sampled from `samples`."""
        modified_codebook = torch.where(
            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
        )
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        """Replace codes whose normalized cluster size is below `threshold_ema_dead_code`.

        Does nothing when the threshold is 0 or when no code is below it.
        """
        if self.threshold_ema_dead_code == 0:
            return

        # Tensor reductions are used here; the builtin `sum` would iterate
        # the tensor element by element in Python.
        cluster_size = self.cluster_size / self.cluster_size.sum() * self.codebook_size
        expired_codes = cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return
        print(f"VQ expire infos: num_expire={expired_codes.sum()}, cluster_size[:5]={cluster_size[:5]}")

        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace_(batch_samples, mask=expired_codes)
        # sync buffers outside for efficiency
        # distrib.broadcast_tensors(self.buffers())

    def quantize(self, x):
        """Return, for each row of `x`, the index of the nearest codebook entry."""
        embed = self.embed.t()
        # Negated squared Euclidean distance, so the maximum is the nearest entry.
        dist = -(
            x.pow(2).sum(1, keepdim=True)
            - 2 * x @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def dequantize(self, embed_ind):
        """Look up the codebook vectors for the indices `embed_ind`."""
        quantize = F.embedding(embed_ind, self.embed)
        return quantize

    def encode(self, x, buffers):
        """Bind `buffers` and return codebook indices for `x`, shaped like `x` without its last axis."""
        self.inited, self.cluster_size, self.embed, self.embed_avg = buffers

        shape = x.shape
        # pre-process
        x = preprocess(x)
        # quantize
        embed_ind = self.quantize(x)
        # post-process
        embed_ind = postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind, buffers):
        """Bind `buffers` and return the codebook vectors for `embed_ind`."""
        self.inited, self.cluster_size, self.embed, self.embed_avg = buffers

        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x, buffers):
        """Quantize `x`; in training mode also update the codebook state in place.

        Returns:
            A tuple `(quantize, embed_ind)`: the quantized vectors and the
            selected codebook indices.
        """
        self.inited, self.cluster_size, self.embed, self.embed_avg = buffers

        shape, dtype = x.shape, x.dtype
        x = preprocess(x)

        self.init_embed_(x)
        if self.training:
            # We do the expiry of code at that point as buffers are in sync
            # and all the workers will take the same decision.
            self.expire_codes_(x)

        embed_ind = self.quantize(x)
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = postprocess_emb(embed_ind, shape)
        quantize = self.dequantize(embed_ind)

        if self.training:
            # EMA update of the cluster sizes and of the per-code vector sums,
            # followed by a re-normalization into the codebook itself.
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = x.t() @ embed_onehot
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
            cluster_size = (
                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
                * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)
            # Note: after ema update, there is a very small difference between codebooks on GPUs.
            # The impact can be very small, ignore it.

        return quantize, embed_ind
249
+
250
+
251
class VectorQuantization(nn.Module):
    """Single vector-quantization layer built on `EuclideanCodebook`.

    Only Euclidean distance is supported. When `codebook_dim` differs from
    `dim`, inputs are projected into the codebook dimension and the quantized
    vectors are projected back; otherwise both projections are identities.

    Args:
        dim (int): Dimension
        codebook_size (int): Codebook size
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
        commitment_weight (float): Weight for commitment loss.
    """
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: float = 2.0,
        commitment_weight: float = 1.,
    ):
        super().__init__()
        _codebook_dim: int = default(codebook_dim, dim)

        if _codebook_dim != dim:
            self.project_in = nn.Linear(dim, _codebook_dim)
            self.project_out = nn.Linear(_codebook_dim, dim)
        else:
            self.project_in = nn.Identity()
            self.project_out = nn.Identity()

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight

        self._codebook = EuclideanCodebook(
            dim=_codebook_dim,
            codebook_size=codebook_size,
            kmeans_init=kmeans_init,
            kmeans_iters=kmeans_iters,
            decay=decay,
            epsilon=epsilon,
            threshold_ema_dead_code=threshold_ema_dead_code,
        )
        self.codebook_size = codebook_size
        self.training = True

    @property
    def codebook(self):
        """The `embed` tensor currently bound to the underlying codebook."""
        return self._codebook.embed

    def encode(self, x, buffers):
        """Project `x` into the codebook space and return its codebook indices."""
        # x = rearrange(x, "b d n -> b n d")
        projected = self.project_in(x)
        return self._codebook.encode(projected, buffers)

    def decode(self, embed_ind, buffers):
        """Look up the vectors for `embed_ind` and project them back out."""
        vectors = self._codebook.decode(embed_ind, buffers)
        # quantize = rearrange(quantize, "b n d -> b d n")
        return self.project_out(vectors)

    def forward(self, x, buffers):
        """Quantize `x` and return `(quantize, embed_ind, loss)`.

        In training mode the quantized output uses a straight-through
        estimator and `loss` holds the weighted commitment loss; otherwise
        `loss` is a zero tensor.
        """
        device = x.device
        training = self.training
        # x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)

        quantize, embed_ind = self._codebook(x, buffers)

        if training:
            # Forward value is the quantized vector; gradients flow to `x`.
            quantize = x + (quantize - x).detach()

        loss = torch.tensor([0.0], device=device, requires_grad=training)

        if training and self.commitment_weight > 0:
            commit_loss = F.mse_loss(quantize.detach(), x)
            loss = loss + commit_loss * self.commitment_weight

        # quantize = rearrange(quantize, "b n d -> b d n")
        return self.project_out(quantize), embed_ind, loss
332
+
333
+
334
class DistributedResidualVectorQuantization(nn.Module):
    """Efficient distributed residual vector quantization implementation.
    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf

    The codebook state of all quantizers is stored in four buffers on this
    module (`inited`, `cluster_size`, `embed`, `embed_avg`), indexed by
    quantizer along their first axis, and handed to each layer on every call.

    Args:
        num_quantizers: Number of `VectorQuantization` layers.
        quantize_dropout: Enable quantizer dropout in training mode.
        rand_num_quant: Candidate dropout indices; one is drawn per forward
            pass when quantizer dropout is active.
        **kwargs: Forwarded to every `VectorQuantization` layer. Must contain
            `dim` and `codebook_size`. `codebook_dim` falls back to `dim` when
            missing or falsy. `kmeans_init` (default `True`) may be a bool, or
            a string path to an array loadable with `np.load`. An optional
            `q0_ds_ratio` is consumed here and not forwarded.

    Raises:
        TypeError: If `kmeans_init` is neither a bool nor a string.
    """
    def __init__(self, *,
                 num_quantizers,
                 quantize_dropout: bool = False,
                 rand_num_quant: tp.Optional[tp.List] = None,
                 **kwargs):
        super().__init__()
        codebook_size = kwargs["codebook_size"]
        # `.get` keeps the constructor usable when the caller relies on the
        # `VectorQuantization` defaults for these two options.
        codebook_dim = kwargs.get("codebook_dim") or kwargs["dim"]
        kmeans_init = kwargs.get("kmeans_init", True)
        if isinstance(kmeans_init, bool):
            if not kmeans_init:
                # use uniform init
                embed = uniform_init(num_quantizers, codebook_size, codebook_dim)
                inited = True
            else:
                # to perform kmeans init on first batch
                embed = torch.zeros(num_quantizers, codebook_size, codebook_dim)
                inited = False
        elif isinstance(kmeans_init, str):
            # use prepared kmeans init
            embed = np.load(kmeans_init)
            embed = torch.from_numpy(embed)
            if embed.dim() == 2:
                embed = embed.unsqueeze(0)
            inited = True
        else:
            raise TypeError("kmeans_init should be either a bool or string path to init weights.")

        self.register_buffer("inited", torch.Tensor([[inited] for _ in range(num_quantizers)]))
        self.register_buffer("cluster_size", torch.zeros(num_quantizers, codebook_size))
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())

        # Removed from `kwargs` so it is not passed on to the layers.
        self.q0_ds_ratio = kwargs.pop("q0_ds_ratio", 1)

        self.layers = nn.ModuleList()
        for _ in range(num_quantizers):
            self.layers.append(VectorQuantization(**kwargs))

        self.quantize_dropout = quantize_dropout
        self.rand_num_quant = rand_num_quant

    def _layer_buffers(self, index):
        """Return the `[inited, cluster_size, embed, embed_avg]` slices of quantizer `index`."""
        return [
            self.inited[index],
            self.cluster_size[index],
            self.embed[index],
            self.embed_avg[index],
        ]

    def forward(self, x, n_q: tp.Optional[int] = None):
        """Quantize the 3-D input `x` with the first `n_q` layers (all layers by default).

        Returns:
            A tuple `(quantized_out, out_indices, out_losses)`: the sum of the
            per-layer quantized outputs, and the per-layer indices and losses
            stacked along a new first axis. Layers skipped by quantizer
            dropout contribute indices of -1 and a zero loss.
        """
        quantized_out = torch.zeros_like(x)
        residual = x
        _, _, tt = x.shape
        device = x.device

        all_losses = []
        all_indices = []
        n_q = n_q or len(self.layers)

        should_quantize_dropout = self.training and self.quantize_dropout and self.rand_num_quant is not None
        if should_quantize_dropout:
            rand_quantize_dropout_index = random.choice(self.rand_num_quant)

            # Placeholders reported for the dropped quantizers.
            null_indices = torch.full((x.shape[0], x.shape[2]), -1, device=device, dtype=torch.long)
            null_loss = torch.full((1,), 0., device=device, dtype=x.dtype)

        for quantizer_index, layer in enumerate(self.layers[:n_q]):
            # Skip every quantizer at or beyond the sampled dropout index.
            if should_quantize_dropout and quantizer_index >= rand_quantize_dropout_index:
                all_indices.append(null_indices)
                all_losses.append(null_loss)
                continue

            downsample_first = self.q0_ds_ratio > 1 and quantizer_index == 0
            quant_in = residual
            if downsample_first:
                # NOTE(review): the target length is tt // 2 whatever the
                # value of `q0_ds_ratio` -- confirm this is intended.
                quant_in = F.interpolate(quant_in, size=[tt//2])
            quantized, indices, loss = layer(quant_in, self._layer_buffers(quantizer_index))
            if downsample_first:
                # Bring the first quantizer's outputs back to the input length.
                quantized = F.interpolate(quantized, size=[tt])
                indices = F.interpolate(indices.unsqueeze(1).float(), size=[tt]).squeeze(1).long()
            residual = residual - quantized
            quantized_out = quantized_out + quantized

            all_indices.append(indices)
            all_losses.append(loss)

        # sync buffers after one forward step
        # distrib.broadcast_tensors(self.buffers())
        out_losses = torch.stack(all_losses)
        out_indices = torch.stack(all_indices)

        return quantized_out, out_indices, out_losses

    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        """Return the indices chosen by the first `n_q` layers, stacked along a new first axis.

        Each layer encodes the residual left by the layers before it.
        """
        residual = x
        all_indices = []
        n_q = n_q or len(self.layers)
        for i, layer in enumerate(self.layers[:n_q]):
            buffers = self._layer_buffers(i)
            indices = layer.encode(residual, buffers)
            quantized = layer.decode(indices, buffers)
            residual = residual - quantized
            all_indices.append(indices)
        out_indices = torch.stack(all_indices)
        return out_indices

    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        """Sum the vectors decoded by each layer from its slice of `q_indices`."""
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        for i, indices in enumerate(q_indices):
            quantized = self.layers[i].decode(indices, self._layer_buffers(i))
            quantized_out = quantized_out + quantized
        return quantized_out
475
+
476
+
477
class DistributedGroupResidualVectorQuantization(nn.Module):
    """Efficient distributed group residual vector quantization implementation.
    Follows Algorithm 1. in https://arxiv.org/abs/2305.02765
    Group Then rvq

    The input is split into `num_groups` chunks along dim 1 and each chunk is
    handled by its own `DistributedResidualVectorQuantization`.
    """
    def __init__(self, *,
                 num_groups,
                 num_quantizers,
                 quantize_dropout: bool = False,
                 rand_num_quant: tp.Optional[tp.List] = None,
                 **kwargs):
        super().__init__()
        group_quantizers = []
        for _ in range(num_groups):
            group_quantizers.append(
                DistributedResidualVectorQuantization(
                    num_quantizers=num_quantizers,
                    quantize_dropout=quantize_dropout,
                    rand_num_quant=rand_num_quant,
                    **kwargs
                )
            )
        self.rvqs = nn.ModuleList(group_quantizers)
        self.num_groups = num_groups

    def forward(self, x, n_q: tp.Optional[int] = None):
        """Quantize each group of `x` and combine the per-group results.

        Returns:
            A tuple of the quantized outputs concatenated along dim 1, the
            indices stacked along dim 1, and the losses averaged over groups.
        """
        groups = torch.chunk(x, chunks=self.num_groups, dim=1)
        per_group = [rvq(group, n_q) for rvq, group in zip(self.rvqs, groups)]
        quantized_parts, index_parts, loss_parts = (list(part) for part in zip(*per_group))

        out_losses = torch.stack(loss_parts, dim=1).mean(dim=1)
        quantized_out = torch.cat(quantized_parts, dim=1)
        out_indices = torch.stack(index_parts, dim=1)

        return quantized_out, out_indices, out_losses

    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        """Encode each group of `x` and stack the per-group indices along dim 1."""
        groups = torch.chunk(x, chunks=self.num_groups, dim=1)
        encoded = []
        for rvq, group in zip(self.rvqs, groups):
            encoded.append(rvq.encode(group, n_q))
        return torch.stack(encoded, dim=1)

    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        """Decode each group's indices and concatenate the results along dim 1."""
        index_groups = torch.chunk(q_indices, chunks=self.num_groups, dim=1)
        decoded = []
        for rvq, group_indices in zip(self.rvqs, index_groups):
            decoded.append(rvq.decode(group_indices.squeeze(1)))
        return torch.cat(decoded, dim=1)
qwen_tts/core/tokenizer_25hz/vq/speech_vq.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 The Alibaba Qwen team.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import sox
17
+ import copy
18
+ import torch
19
+ import operator
20
+ import onnxruntime
21
+
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ import torchaudio.compliance.kaldi as kaldi
25
+
26
+ from librosa.filters import mel as librosa_mel_fn
27
+ from itertools import accumulate
28
+ from typing import List
29
+ from torch import Tensor
30
+
31
+ from .core_vq import DistributedGroupResidualVectorQuantization
32
+ from .whisper_encoder import WhisperEncoder, Conv1d, ConvTranspose1d
33
+
34
+
35
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Return ``log(max(x, clip_val) * C)`` element-wise.

    Values below ``clip_val`` are raised to ``clip_val`` first so the
    logarithm never sees zero or negative input.
    """
    floored = torch.clamp(x, min=clip_val)
    scaled = floored * C
    return torch.log(scaled)
37
+
38
def spectral_normalize_torch(magnitudes):
    """Log-compress ``magnitudes`` with the default compression settings."""
    return dynamic_range_compression_torch(magnitudes)
41
+
42
class MelSpectrogramFeatures(nn.Module):
    """
    Calculate the BigVGAN style mel spectrogram of an input signal.

    Args:
        filter_length (int): FFT size used for the STFT. Default is 1024.
        hop_length (int): Number of samples between successive frames (stride of the STFT). Default is 160.
        win_length (int): Length of the Hann window applied to each frame. Default is 640.
        n_mel_channels (int): Number of mel bands produced. Default is 80.
        mel_fmin (int): Lowest frequency (Hz) of the mel filter bank. Default is 0.
        mel_fmax (int): Highest frequency (Hz) of the mel filter bank. Default is 8000.
        sampling_rate (int): Sampling rate (Hz) used to build the mel filter bank. Default is 16000.
        sampling_rate_org (int, optional): Stored as an attribute; falls back to ``sampling_rate`` when None.
        padding (str): Must be 'center' or 'same'. The value is validated and stored, but ``extract``
            does not branch on it.
        use_db (bool): Accepted for interface compatibility; not used by this class.

    Returns:
        torch.Tensor: Log-compressed mel spectrogram.
    """
    def __init__(self,
                 filter_length=1024,
                 hop_length=160,
                 win_length=640,
                 n_mel_channels=80,
                 mel_fmin=0,
                 mel_fmax=8000,
                 sampling_rate=16000,
                 sampling_rate_org=None,
                 padding='center',
                 use_db = False,
                 ):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding

        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_mel_channels = n_mel_channels
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.sampling_rate = sampling_rate
        self.sampling_rate_org = sampling_rate_org if sampling_rate_org is not None else sampling_rate
        # Per-device caches, keyed by "<mel_fmax>_<device>" and "<device>".
        self.mel_basis = {}
        self.hann_window = {}

    def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor:
        """Run ``extract`` with gradient tracking disabled."""
        with torch.no_grad():
            feats = self.extract(audio, **kwargs)
        return feats

    def extract(self, audio, **kwargs):
        """Compute the log-mel spectrogram of ``audio``.

        Args:
            audio: 2-D tensor, or a 3-D tensor with a singleton dim 1 or dim 2
                that is squeezed away first.
            **kwargs: Ignored.

        Returns:
            The mel filter bank applied to the STFT magnitudes, log-compressed
            by ``spectral_normalize_torch``.
        """
        if len(audio.shape) == 3:
            audio = audio.squeeze(1) if audio.shape[1] == 1 else audio.squeeze(2)
        assert len(audio.shape) == 2

        y = audio
        device_key = str(y.device)
        mel_key = str(self.mel_fmax) + '_' + device_key

        # Build the cached tensors per device. The previous check only fired
        # while the cache was completely empty, so a second call on another
        # device failed with KeyError.
        if mel_key not in self.mel_basis:
            mel = librosa_mel_fn(sr=self.sampling_rate, n_fft=self.filter_length, n_mels=self.n_mel_channels, fmin=self.mel_fmin, fmax=self.mel_fmax)
            self.mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)
        if device_key not in self.hann_window:
            self.hann_window[device_key] = torch.hann_window(self.win_length).to(y.device)

        # Manual reflect padding; the STFT below therefore runs with center=False.
        pad = int((self.filter_length - self.hop_length) / 2)
        y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode='reflect')
        y = y.squeeze(1)

        spec = torch.stft(y, self.filter_length, hop_length=self.hop_length, win_length=self.win_length, window=self.hann_window[device_key],
                          center=False, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
        spec = torch.view_as_real(spec)
        # Magnitude with a small epsilon under the square root.
        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

        spec = torch.matmul(self.mel_basis[mel_key], spec)
        spec = spectral_normalize_torch(spec)

        return spec
116
+
117
+
118
class XVectorExtractor(nn.Module):
    """Extract a speaker embedding and a reference mel spectrogram from audio.

    The embedding comes from an ONNX model run on the CPU execution provider;
    audio is level-normalised with a sox transformer (``norm`` at -6 dB)
    before feature extraction.

    Args:
        audio_codec_with_xvector: Model argument handed to
            ``onnxruntime.InferenceSession``.
    """

    def __init__(self, audio_codec_with_xvector):
        super().__init__()
        session_options = onnxruntime.SessionOptions()
        session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        session_options.intra_op_num_threads = 1
        self.ort_session = onnxruntime.InferenceSession(
            audio_codec_with_xvector,
            sess_options=session_options,
            providers=["CPUExecutionProvider"],
        )

        self.tfm = sox.Transformer()
        self.tfm.norm(db_level=-6)

        self.mel_ext = MelSpectrogramFeatures(
            filter_length=1024,
            hop_length=160,
            win_length=640,
            n_mel_channels=80,
            mel_fmin=0,
            mel_fmax=8000,
            sampling_rate=16000
        )

    def extract_code(self, audio):
        """Return ``(speaker_embedding, reference_mel)`` as NumPy arrays.

        ``audio`` is passed to ``sox_norm``, which treats it as a 16 kHz
        array. The embedding is the flattened first ONNX output, normalised
        along dim 0; the mel has its last two dims swapped and dim 0 squeezed.
        """
        with torch.no_grad():
            normalized = self.sox_norm(audio)
            wav = torch.from_numpy(copy.deepcopy(normalized)).unsqueeze(0)

            fbank = kaldi.fbank(wav,
                                num_mel_bins=80,
                                dither=0,
                                sample_frequency=16000)
            # Subtract the per-bin mean over frames.
            fbank = fbank - fbank.mean(dim=0, keepdim=True)

            input_name = self.ort_session.get_inputs()[0].name
            onnx_inputs = {input_name: fbank.unsqueeze(dim=0).cpu().numpy()}
            raw_embedding = self.ort_session.run(None, onnx_inputs)[0].flatten()
            embedding = F.normalize(torch.from_numpy(raw_embedding), dim=0)

            mel = self.mel_ext.extract(audio=wav)

        return embedding.numpy(), mel.permute(0, 2, 1).squeeze(0).numpy()

    def sox_norm(self, audio):
        """Apply the configured sox transformer to ``audio`` at 16 kHz."""
        return self.tfm.build_array(input_array=audio, sample_rate_in=16000)
160
+
161
+
162
class WhisperEncoderVQ(WhisperEncoder):
    """Whisper-style encoder with a vector quantizer inserted after one block.

    After transformer block number ``audio_vq_layers`` (1-based) the hidden
    states are optionally downsampled, quantized with a group residual VQ
    (``audio_vq_type == "GRVQ"``), upsampled again and passed on to the
    remaining blocks.

    Args:
        n_mels, n_ctx, n_state, n_head, n_layer, n_window, output_dim,
        grad_checkpointing, enable_mp, audio_sequence_parallel: Forwarded to
            ``WhisperEncoder``.
        audio_vq_layers: 1-based index of the block after which quantization
            runs. Must be > 0, otherwise ``NotImplementedError`` is raised.
        audio_vq_type: Only ``"GRVQ"`` is supported; anything else raises
            ``NotImplementedError``.
        audio_vq_codebook_size: Codebook size of the quantizer.
        audio_vq_pe: If true, a positional embedding is added after
            quantization and the sum goes through a linear projection.
        audio_vq_commit_loss: Stored only.
        audio_vq_out_commit_loss: Weight of the MSE between the pre-quantization
            and post-upsampling features; the loss is computed only if > 0.
        audio_vq_no_quantize, audio_vq_ff_layer: Stored only.
        audio_vq_threshold_ema_dead_code: Forwarded to the quantizer.
        audio_vq_codebook_dim: Codebook dimension; ``None`` means "same as the
            feature dimension".
        audio_vq_ds_rate: Temporal downsampling factor applied around the
            quantizer; ``None`` means no downsampling.
    """

    def __init__(
        self,
        n_mels: int,
        n_ctx: int,
        n_state: int,
        n_head: int,
        n_layer: int,
        n_window: int = 1500,
        output_dim: int = 512,
        grad_checkpointing: bool = False,
        enable_mp: bool = False,
        audio_sequence_parallel: bool = False,
        audio_vq_layers: int = -1,
        audio_vq_type: str = "NULL",
        audio_vq_codebook_size: int = 4096,
        audio_vq_pe: bool = False,
        audio_vq_commit_loss: float = 0.0,
        audio_vq_out_commit_loss: float = 0.0,
        audio_vq_no_quantize: bool = False,
        audio_vq_ff_layer: int = 0,
        audio_vq_threshold_ema_dead_code: float = 0.1,
        audio_vq_codebook_dim: int = None,
        audio_vq_ds_rate: int = None,
    ):
        super().__init__(n_mels, n_ctx, n_state, n_head, n_layer, n_window, output_dim, grad_checkpointing, enable_mp, audio_sequence_parallel)

        self.audio_vq_layers = audio_vq_layers
        self.audio_vq_type = audio_vq_type
        self.audio_vq_codebook_size = audio_vq_codebook_size
        self.audio_vq_pe = audio_vq_pe
        self.audio_vq_commit_loss = audio_vq_commit_loss
        self.audio_vq_out_commit_loss = audio_vq_out_commit_loss
        self.audio_vq_no_quantize = audio_vq_no_quantize
        self.audio_vq_ff_layer = audio_vq_ff_layer

        if audio_vq_layers > 0:
            self.vq_feature_dim = self.n_state
            self.audio_vq_ds_rate = 1
        else:
            raise NotImplementedError(f"Unsupported audio_vq_layers: {audio_vq_layers}")

        # The declared default (None) used to fall into the modulo below and
        # raise TypeError; treat it as "keep the native rate".
        if audio_vq_ds_rate is None:
            audio_vq_ds_rate = self.audio_vq_ds_rate

        if self.audio_vq_ds_rate == audio_vq_ds_rate:
            self.audio_vq_downsample = nn.Identity()
            self.audio_vq_upsample = nn.Identity()
        else:
            assert audio_vq_ds_rate % self.audio_vq_ds_rate == 0
            stride = audio_vq_ds_rate // self.audio_vq_ds_rate
            self.audio_vq_downsample = Conv1d(self.vq_feature_dim, self.vq_feature_dim, kernel_size=stride, stride=stride)
            self.audio_vq_upsample = ConvTranspose1d(self.vq_feature_dim, self.vq_feature_dim, kernel_size=stride, stride=stride)
            self.audio_vq_ds_rate = audio_vq_ds_rate

        if audio_vq_type == "GRVQ":
            # The fallback used to read `self.vq_codebook_dim`, an attribute
            # that is never set, so the default (None) raised AttributeError.
            codebook_dim = self.vq_feature_dim if audio_vq_codebook_dim is None else audio_vq_codebook_dim
            self.audio_quantizer = DistributedGroupResidualVectorQuantization(
                codebook_size = audio_vq_codebook_size,
                dim = self.vq_feature_dim,
                codebook_dim = codebook_dim,
                num_groups=1,
                num_quantizers=1,
                kmeans_init=False,
                threshold_ema_dead_code = audio_vq_threshold_ema_dead_code
            )
        else:
            raise NotImplementedError(f"Unsupported audio_vq_type: {audio_vq_type}")

        if self.audio_vq_pe:
            self.project_after_vq_pe = nn.Linear(self.n_state, self.n_state)

    def _calc_quantize_activities(self, indices):
        """Count codebook usage for ``indices``.

        Returns:
            Dict with ``vq_num_activities`` (number of distinct codes used) and
            ``vq_num_tokens`` (total number of indices), both 0-dim tensors.
        """
        usage = F.one_hot(indices.long().flatten(), self.audio_vq_codebook_size).sum(dim=0)
        # Tensor reductions instead of Python's builtin sum(), which iterates
        # over every codebook entry in the interpreter.
        return {
            "vq_num_activities": (usage > 0).sum(),
            "vq_num_tokens": usage.sum(),
        }

    def _do_quantize(self, x, pe=None, y=None):
        """Quantize hidden states (inference only).

        Args:
            x: Hidden states, shape (T, D) per the original author's note.
            pe: Positional embedding added after quantization when
                ``audio_vq_pe`` is set.
            y: Unused.

        Returns:
            ``(quantized_features, indices, vq_stats)``.

        Raises:
            NotImplementedError: If the module is in training mode.
        """
        if self.audio_vq_out_commit_loss > 0:
            x_teacher = x.clone()
        x = x.unsqueeze(0)

        x = self.audio_vq_downsample(x.transpose(1, 2))
        x = x.transpose(1, 2)

        vq_stats = {}

        if self.audio_vq_type == "GRVQ":
            if self.training:
                raise NotImplementedError
            else:
                indices = self.audio_quantizer.encode(x)
                x = self.audio_quantizer.decode(indices)
                indices = indices.squeeze(2).squeeze(1)

        vq_stats.update(self._calc_quantize_activities(indices))

        x, indices = x.squeeze(0), indices.squeeze(0)
        if self.audio_vq_pe:
            x = x + pe
            x = self.project_after_vq_pe(x)

        x = self.audio_vq_upsample(x.unsqueeze(0).transpose(1, 2))
        x = x.transpose(1, 2).squeeze(0)

        if self.audio_vq_out_commit_loss > 0:
            vq_out_commit_loss = F.mse_loss(x_teacher.detach(), x)
            vq_stats["vq_out_commit_loss"] = vq_out_commit_loss * self.audio_vq_out_commit_loss

        return x, indices, vq_stats

    def forward(self, x_list: List[Tensor], audio_mellens:List[int], audio_aftercnnlens:List[int], audio_seqlens:List[int], return_indices=False, audio_pitchs=None):
        """Encode a list of mel spectrograms, quantizing after the VQ layer.

        Args:
            x_list: Mel spectrograms, each of shape (n_mels, n_frames).
            audio_mellens: Unused.
            audio_aftercnnlens: Per-audio length after the conv stem; used to
                build the attention segments and to split before pooling.
            audio_seqlens: Per-audio output length including the BOS/EOS rows.
            return_indices: If true, return ``(features, indices)`` straight
                after quantization instead of running the remaining layers.
            audio_pitchs: Unused.

        Returns:
            ``(features, indices)`` when ``return_indices`` is set and the VQ
            layer is reached; otherwise ``(output, vq_stats)``, or just
            ``output`` when ``audio_vq_type`` is ``"NULL"``.
        """
        aftercnn_x_list = []
        pe_for_vq_list = []
        for each_x in x_list:
            # The conv stem is applied per window of 2 * n_window mel frames.
            for each_x_split in each_x.split(self.n_window * 2, dim=1):
                each_x_split = F.gelu(self.conv1(each_x_split))
                each_x_split = F.gelu(self.conv2(each_x_split))
                each_x_split = each_x_split.permute(1, 0) # L,D
                n_frames = each_x_split.shape[0]

                positional = self.positional_embedding[:n_frames]
                aftercnn_x_list.append(each_x_split + positional.to(each_x_split.dtype))

                pe_for_vq_split = self.positional_embedding[:n_frames // self.audio_vq_ds_rate]
                pe_for_vq_list.append(pe_for_vq_split.to(each_x_split.dtype))

        pe_for_vq = torch.cat(pe_for_vq_list, dim=0)
        x = torch.cat(aftercnn_x_list, dim=0)

        # Attention segments: every audio is cut into pieces of at most n_window.
        segment_lens = []
        for item in audio_aftercnnlens:
            while item > self.n_window:
                segment_lens.append(self.n_window)
                item -= self.n_window
            segment_lens.append(item)

        # Built directly as int32 rather than through a float32 torch.Tensor.
        cu_seqlens = torch.tensor(
            list(accumulate(segment_lens, func=operator.add, initial=0)),
            device=x.device, dtype=torch.int32,
        )

        # Stays empty if audio_vq_layers exceeds the number of blocks; the
        # final return previously hit an unbound local in that case.
        vq_stats = {}

        for layer_id, block in enumerate(self.blocks, start=1):
            x = block(x, cu_seqlens=cu_seqlens)

            if self.audio_vq_layers == layer_id: # vq inside encoder
                x, indices, vq_stats = self._do_quantize(x, pe_for_vq)
                if return_indices:
                    return x, indices

        if self.avg_pooler:
            pooled = []
            for segment in x.split(audio_aftercnnlens, dim=0):
                segment = self.avg_pooler(segment.permute(1, 0))
                pooled.append(segment.permute(1, 0))
            x = torch.cat(pooled, dim=0)

        x = self.ln_post(x)

        x = self.proj(x)

        # Two extra rows per audio hold the BOS and EOS embeddings.
        output = torch.zeros(
            (x.size(0) + len(audio_seqlens) * 2, x.size(1)),
            device=x.device, dtype=x.dtype
        )

        audio_seqlens_acc = list(accumulate(audio_seqlens, func=operator.add, initial=0))
        start_ids = torch.tensor(audio_seqlens_acc[:-1], device=x.device, dtype=torch.int32)
        end_ids = torch.tensor(audio_seqlens_acc[1:], device=x.device, dtype=torch.int32) - 1

        audio_tokens_mask = torch.ones(output.size(0), device=x.device, dtype=torch.bool)
        audio_tokens_mask[start_ids] = False
        audio_tokens_mask[end_ids] = False
        output[start_ids] = self.audio_bos_eos_token.weight[0].to(x.dtype)
        output[end_ids] = self.audio_bos_eos_token.weight[1].to(x.dtype)
        output[audio_tokens_mask] = x

        if self.audio_vq_type != "NULL":
            return output, vq_stats
        return output
qwen_tts/core/tokenizer_25hz/vq/whisper_encoder.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 The Alibaba Qwen team.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import os
17
+ import math
18
+ import torch
19
+ import operator
20
+
21
+ import numpy as np
22
+ import torch.nn.functional as F
23
+
24
+ from functools import lru_cache
25
+ from typing import Optional, Union, List
26
+ from torch import nn, Tensor
27
+ from itertools import accumulate
28
+
29
+ try:
30
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func
31
+ except ImportError:
32
+ try:
33
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_func as flash_attn_varlen_func
34
+ except ImportError:
35
+ print("\n********\nWarning: flash-attn is not installed. Will only run the manual PyTorch version. Please install flash-attn for faster inference.\n********\n ")
36
+ flash_attn_varlen_func = None
37
+
38
+
39
# STFT settings used by log_mel_spectrogram below.
N_FFT = 400  # FFT size, also the Hann window length
HOP_LENGTH = 160  # samples between successive STFT frames
41
+
42
+
43
@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int) -> torch.Tensor:
    """
    Load the mel filterbank matrix that projects an STFT onto mel bands.

    The matrix is read from ``assets/mel_filters.npz`` next to this module and
    moved to ``device``; results are cached per ``(device, n_mels)``.
    Shipping the file removes the need for librosa here; it was saved using:

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
            mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
        )
    """
    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"

    assets_dir = os.path.join(os.path.dirname(__file__), "assets")
    npz_path = os.path.join(assets_dir, "mel_filters.npz")
    with np.load(npz_path, allow_pickle=False) as archive:
        bank = archive[f"mel_{n_mels}"]
        return torch.from_numpy(bank).to(device)
60
+
61
+
62
def log_mel_spectrogram(
    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int = 80,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
):
    """
    Compute the log-Mel spectrogram of an audio waveform.

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        The waveform as a Tensor, or a NumPy array that is converted with
        ``torch.from_numpy``. The filter bank was built for 16 kHz audio.
        NOTE(review): ``str`` appears in the annotation, but this function
        never loads a file; confirm whether path input is handled by a caller.

    n_mels: int
        The number of Mel-frequency filters (``mel_filters`` accepts 80 or 128)

    padding: int
        Number of zero samples to pad to the right

    device: Optional[Union[str, torch.device]]
        If given, the audio tensor is moved to this device before STFT

    Returns
    -------
    torch.Tensor, shape = (n_mels, n_frames)
        A Tensor that contains the Mel spectrogram
    """
    waveform = audio if torch.is_tensor(audio) else torch.from_numpy(audio)

    if device is not None:
        waveform = waveform.to(device)
    if padding > 0:
        waveform = F.pad(waveform, (0, padding))

    hann = torch.hann_window(N_FFT).to(waveform.device)
    spectrum = torch.stft(waveform, N_FFT, HOP_LENGTH, window=hann, return_complex=True)
    # Drop the last frame and take the power spectrum.
    power = spectrum[..., :-1].abs() ** 2

    mel_power = mel_filters(waveform.device, n_mels) @ power

    log_mel = torch.clamp(mel_power, min=1e-10).log10()
    # Limit the range to 8 (log10 units) below the global maximum.
    floor = log_mel.max() - 8.0
    log_mel = torch.maximum(log_mel, floor)
    return (log_mel + 4.0) / 4.0
108
+
109
+
110
def get_T_after_cnn(L_in, dilation=1):
    """Return the sequence length after the two-layer conv stem.

    Applies the standard 1-D convolution output-length formula for two
    layers given as ``(padding, kernel_size, stride)``: ``(1, 3, 1)``
    followed by ``(1, 3, 2)``.

    Args:
        L_in: Input length in frames.
        dilation: Dilation assumed for both layers.

    Returns:
        The output length in frames.
    """
    # Literal tuple instead of eval() on a constant string.
    conv_layers = ((1, 3, 1), (1, 3, 2))
    for padding, kernel_size, stride in conv_layers:
        L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
        L_out = 1 + L_out // stride
        L_in = L_out
    return L_out
116
+
117
+
118
def get_mel_audio(audio, padding=False, audio_vq_ds_rate = 1, n_mels = 128):
    """Return the log-mel spectrogram ([F, T]) of ``audio``.

    With ``padding`` enabled, the waveform is zero-padded on the right so its
    length becomes a multiple of ``160 * 2 * audio_vq_ds_rate`` samples.
    """
    n_samples = len(audio)
    if not padding:
        return log_mel_spectrogram(audio, n_mels=n_mels)  # [F,T]

    multiple = 160 * 2 * audio_vq_ds_rate
    pad_right = math.ceil(n_samples / multiple) * multiple - n_samples
    return log_mel_spectrogram(audio, n_mels=n_mels, padding=pad_right)
127
+
128
+
129
def sinusoids(length, channels, max_timescale=10000):
    """Build a (length, channels) sinusoidal positional embedding.

    The first half of the channels holds sines and the second half cosines,
    with inverse timescales spaced geometrically between 1 and
    ``1 / max_timescale``. ``channels`` must be even.
    """
    assert channels % 2 == 0
    half = channels // 2
    log_timescale_increment = np.log(max_timescale) / (half - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(half))
    positions = torch.arange(length)
    angles = positions.unsqueeze(1) * inv_timescales.unsqueeze(0)
    return torch.cat((angles.sin(), angles.cos()), dim=1)
136
+
137
+
138
class Conv1d(nn.Conv1d):
    """``nn.Conv1d`` that casts its weight and bias to the input's dtype."""

    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
    ) -> Tensor:
        cast_weight = weight.to(x.dtype)
        cast_bias = bias.to(x.dtype) if bias is not None else None
        return super()._conv_forward(x, cast_weight, cast_bias)
145
+
146
+
147
class ConvTranspose1d(nn.ConvTranspose1d):
    """``nn.ConvTranspose1d`` that casts its weight and bias to the input's dtype.

    The cast lives in ``forward`` because ``nn.ConvTranspose1d.forward`` does
    not go through ``_conv_forward``; the earlier ``_conv_forward`` override
    was never called, so mixed-dtype inputs were not handled.
    """

    def forward(self, x: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
        """Apply the transposed convolution in ``x``'s dtype.

        Args:
            x: Input tensor.
            output_size: Optional explicit output size, as accepted by
                ``nn.ConvTranspose1d.forward``.
        """
        weight = self.weight.to(x.dtype)
        bias = None if self.bias is None else self.bias.to(x.dtype)
        if output_size is None:
            output_padding = self.output_padding
        else:
            # Same helper the stock forward() uses; 1 = number of spatial dims.
            output_padding = self._output_padding(
                x, output_size, self.stride, self.padding, self.kernel_size, 1, self.dilation
            )
        return F.conv_transpose1d(
            x, weight, bias, self.stride, self.padding,
            output_padding, self.groups, self.dilation,
        )
154
+
155
+
156
class Linear(nn.Linear):
    """``nn.Linear`` that casts its weight and bias to the input's dtype."""

    def forward(self, x: Tensor) -> Tensor:
        weight = self.weight.to(x.dtype)
        bias = self.bias
        if bias is not None:
            bias = bias.to(x.dtype)
        return F.linear(x, weight, bias)
159
+
160
+
161
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over packed variable-length sequences.

    Inputs are packed as ``(total_tokens, n_state)`` and the sequence
    boundaries are given by ``cu_seqlens`` (cumulative lengths starting at 0).
    flash-attn is used when it is importable and the activations are
    float16/bfloat16; otherwise a padded PyTorch implementation runs.
    """

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

        self.use_flash_attention = True

    def forward(
        self,
        x: Tensor,
        cu_seqlens = None,
    ):
        """Attend within each packed sequence and apply the output projection."""
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        if self.use_flash_attention and flash_attn_varlen_func is not None:
            if q.dtype in (torch.float16, torch.bfloat16):
                x = self.qkv_flash_attention(q, k, v, cu_seqlens=cu_seqlens)
            else:
                # Unsupported dtype for flash-attn: fall back now and stop
                # trying on later calls.
                self.use_flash_attention = False
                x = self.qkv_attention_manual(q, k, v, cu_seqlens=cu_seqlens)
        else:
            x = self.qkv_attention_manual(q, k, v, cu_seqlens=cu_seqlens)

        return self.out(x)

    def qkv_flash_attention(
        self, q: Tensor, k: Tensor, v: Tensor, cu_seqlens=None
    ):
        """Variable-length attention through flash-attn."""
        n_ctx, n_state = q.shape
        # Packed layout: (total_tokens, n_head, head_dim).
        q = q.view(n_ctx, self.n_head, -1)
        k = k.view(n_ctx, self.n_head, -1)
        v = v.view(n_ctx, self.n_head, -1)

        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

        x = flash_attn_varlen_func(
            q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, dropout_p=0.0
        )
        return x.reshape(n_ctx, n_state)

    def qkv_attention_manual(
        self, q: Tensor, k: Tensor, v: Tensor, cu_seqlens: Tensor
    ):
        """Variable-length attention in plain PyTorch.

        The packed sequences are copied into a zero-padded batch, attended
        with padded key positions masked out, and packed again.

        Args:
            q, k, v: Packed projections of shape (total_tokens, n_state).
            cu_seqlens: Cumulative sequence lengths, starting at 0.

        Returns:
            Packed attention output of shape (total_tokens, n_state).
        """
        n_ctx, n_state = q.shape
        head_dim = n_state // self.n_head
        scale = head_dim ** -0.5

        q = q.view(n_ctx, self.n_head, head_dim)
        k = k.view(n_ctx, self.n_head, head_dim)
        v = v.view(n_ctx, self.n_head, head_dim)

        bounds = cu_seqlens.tolist()
        seqlens = [end - start for start, end in zip(bounds[:-1], bounds[1:])]
        batch_size = len(seqlens)
        max_seqlen = max(seqlens)

        q_padded = torch.zeros(batch_size, max_seqlen, self.n_head, head_dim, dtype=q.dtype, device=q.device)
        k_padded = torch.zeros_like(q_padded)
        v_padded = torch.zeros_like(q_padded)

        for i, seq_len in enumerate(seqlens):
            start_idx, end_idx = bounds[i], bounds[i + 1]
            q_padded[i, :seq_len] = q[start_idx:end_idx]
            k_padded[i, :seq_len] = k[start_idx:end_idx]
            v_padded[i, :seq_len] = v[start_idx:end_idx]

        q_padded = q_padded.transpose(1, 2)
        k_padded = k_padded.transpose(1, 2)
        v_padded = v_padded.transpose(1, 2)

        # True where a key position holds a real token: (batch, 1, 1, max_seqlen).
        key_is_real = (
            torch.arange(max_seqlen, device=q.device)[None, :]
            < torch.tensor(seqlens, device=q.device)[:, None]
        )
        key_is_real = key_is_real.unsqueeze(1).unsqueeze(2)

        attn_scores = torch.matmul(q_padded, k_padded.transpose(-2, -1)) * scale
        # Mask the scores themselves. The earlier code called masked_fill on
        # the boolean mask, which turned the large negative fill value into
        # True, so padded keys were never suppressed.
        attn_scores = attn_scores.masked_fill(~key_is_real, torch.finfo(q.dtype).min)
        attn_weights = F.softmax(attn_scores, dim=-1)

        context = torch.matmul(attn_weights, v_padded)

        context = context.transpose(1, 2).contiguous().view(batch_size, max_seqlen, n_state)

        output_packed = torch.cat([context[i, :seqlens[i]] for i in range(batch_size)], dim=0)

        assert output_packed.shape == (n_ctx, n_state)

        return output_packed
263
+
264
+
265
class ResidualAttentionBlock(nn.Module):
    """Pre-LayerNorm transformer block: self-attention, then a 4x GELU MLP.

    ``enable_mp`` and ``sequence_parallel`` are accepted but not used here.
    """

    def __init__(self, n_state: int, n_head: int,
                 enable_mp: bool = False, sequence_parallel: bool = False):
        super().__init__()
        hidden_width = n_state * 4
        self.attn_ln = nn.LayerNorm(n_state)
        self.mlp_ln = nn.LayerNorm(n_state)

        self.attn = MultiHeadAttention(n_state, n_head)
        self.mlp = nn.Sequential(
            Linear(n_state, hidden_width),
            nn.GELU(),
            Linear(hidden_width, n_state),
        )

    def forward(
        self,
        x: Tensor,
        cu_seqlens = None
    ):
        """Apply both residual sub-layers to the packed sequence ``x``."""
        attended = self.attn(self.attn_ln(x), cu_seqlens=cu_seqlens)
        x = x + attended
        transformed = self.mlp(self.mlp_ln(x))
        return x + transformed
286
+
287
+
288
class WhisperEncoder(nn.Module):
    """Whisper-style audio encoder over packed, windowed mel spectrograms.

    A two-layer conv stem (the second layer has stride 2) is followed by
    ``n_layer`` residual attention blocks, stride-2 average pooling, a final
    LayerNorm and a linear projection to ``output_dim``. Learned BOS/EOS
    embeddings are placed around every audio in the output.

    Args:
        n_mels: Number of mel bins of the input.
        n_ctx: Number of positions in the sinusoidal positional embedding.
        n_state: Hidden width.
        n_head: Number of attention heads.
        n_layer: Number of transformer blocks.
        n_window: Maximum attention segment length (in post-conv frames); the
            conv stem sees windows of ``2 * n_window`` mel frames.
        output_dim: Width of the output projection.
        grad_checkpointing, enable_mp, audio_sequence_parallel: Stored; the
            latter two are also forwarded to the blocks.
    """

    def __init__(
        self,
        n_mels: int,
        n_ctx: int,
        n_state: int,
        n_head: int,
        n_layer: int,
        n_window: int = 1500,
        output_dim: int = 512,
        grad_checkpointing: bool = False,
        enable_mp: bool = False,
        audio_sequence_parallel: bool = False,
    ):
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
        self.n_layer = n_layer
        self.n_mels = n_mels

        self.blocks = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head, enable_mp=enable_mp, sequence_parallel=audio_sequence_parallel)
                for _ in range(n_layer)]
        )
        self.ln_post = nn.LayerNorm(n_state)
        self.avg_pooler = nn.AvgPool1d(2, stride=2)

        self.proj = torch.nn.Linear(n_state, output_dim)

        # Row 0 is the BOS embedding, row 1 the EOS embedding.
        self.audio_bos_eos_token = nn.Embedding(2, output_dim)

        self.output_dim = output_dim
        self.grad_checkpointing = grad_checkpointing
        self.enable_mp = enable_mp
        self.n_head = n_head
        self.n_state = n_state
        self.n_window = n_window

        self.audio_sequence_parallel = audio_sequence_parallel

        self.tp_world_size = 1

        self.set_audio_sync()

    def set_audio_sync(self):
        """Set ``audio_sync = True`` on every parameter outside ``blocks``."""
        for name, param in self.named_parameters():
            if not name.startswith("blocks"):
                setattr(param, "audio_sync", True)

    def forward(self, x_list: List[Tensor], audio_mellens:List[int], audio_aftercnnlens:List[int], audio_seqlens:List[int]):
        """Encode a list of mel spectrograms into one packed output.

        Args:
            x_list: Mel spectrograms, each of shape (n_mels, n_frames).
            audio_mellens: Unused.
            audio_aftercnnlens: Per-audio length after the conv stem; used to
                build the attention segments and to split before pooling.
            audio_seqlens: Per-audio output length including the BOS/EOS rows.

        Returns:
            Tensor of shape (sum(audio_seqlens), output_dim) where each audio's
            features are framed by the BOS and EOS embeddings.
        """
        aftercnn_x_list = []
        for each_x in x_list:
            # The conv stem is applied per window of 2 * n_window mel frames.
            for each_x_split in each_x.split(self.n_window * 2, dim=1):
                each_x_split = F.gelu(self.conv1(each_x_split))
                each_x_split = F.gelu(self.conv2(each_x_split))
                each_x_split = each_x_split.permute(1, 0) # L,D
                positional = self.positional_embedding[:each_x_split.shape[0]]
                aftercnn_x_list.append(each_x_split + positional.to(each_x_split.dtype))

        x = torch.cat(aftercnn_x_list, dim=0)

        # Attention segments: every audio is cut into pieces of at most n_window.
        segment_lens = []
        for item in audio_aftercnnlens:
            while item > self.n_window:
                segment_lens.append(self.n_window)
                item -= self.n_window
            segment_lens.append(item)

        # Built directly as int32 rather than through a float32 torch.Tensor.
        cu_seqlens = torch.tensor(
            list(accumulate(segment_lens, func=operator.add, initial=0)),
            device=x.device, dtype=torch.int32,
        )

        for block in self.blocks:
            x = block(x, cu_seqlens=cu_seqlens)

        if self.avg_pooler:
            pooled = []
            for segment in x.split(audio_aftercnnlens, dim=0):
                segment = self.avg_pooler(segment.permute(1, 0))
                pooled.append(segment.permute(1, 0))
            x = torch.cat(pooled, dim=0)

        x = self.ln_post(x)
        x = self.proj(x)

        # Two extra rows per audio hold the BOS and EOS embeddings.
        output = torch.zeros(
            (x.size(0) + len(audio_seqlens) * 2, x.size(1)),
            device=x.device, dtype=x.dtype
        )

        audio_seqlens_acc = list(accumulate(audio_seqlens, func=operator.add, initial=0))
        start_ids = torch.tensor(audio_seqlens_acc[:-1], device=x.device, dtype=torch.int32)
        end_ids = torch.tensor(audio_seqlens_acc[1:], device=x.device, dtype=torch.int32) - 1

        audio_tokens_mask = torch.ones(output.size(0), device=x.device, dtype=torch.bool)
        audio_tokens_mask[start_ids] = False
        audio_tokens_mask[end_ids] = False
        output[start_ids] = self.audio_bos_eos_token.weight[0].to(x.dtype)
        output[end_ids] = self.audio_bos_eos_token.weight[1].to(x.dtype)
        output[audio_tokens_mask] = x
        return output

    def lock(self, layers: int):
        """Freeze the conv stem and the first ``layers`` transformer blocks."""
        self.conv1.requires_grad_(False)
        self.conv2.requires_grad_(False)
        for block in self.blocks[:min(layers, len(self.blocks))]:
            block.requires_grad_(False)
qwen_tts/inference/qwen3_tts_model.py ADDED
@@ -0,0 +1,877 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 The Alibaba Qwen team.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import base64
17
+ import io
18
+ import urllib.request
19
+ from dataclasses import dataclass
20
+ from typing import Any, Dict, List, Optional, Tuple, Union
21
+ from urllib.parse import urlparse
22
+
23
+ import librosa
24
+ import numpy as np
25
+ import soundfile as sf
26
+ import torch
27
+ from transformers import AutoConfig, AutoModel, AutoProcessor
28
+
29
+ from ..core.models import Qwen3TTSConfig, Qwen3TTSForConditionalGeneration, Qwen3TTSProcessor
30
+
31
# Accepted forms for one audio input:
#   - str: wav path, URL, or base64 string
#   - np.ndarray: waveform (requires sr)
#   - (np.ndarray, int): (waveform, sr)
AudioLike = Union[str, np.ndarray, Tuple[np.ndarray, int]]

# A value that may be passed either on its own or as a list of values.
MaybeList = Union[Any, List[Any]]
38
+
39
+
40
@dataclass
class VoiceClonePromptItem:
    """
    Voice-clone prompt data for a single sample.

    The fields line up with the `voice_clone_prompt` argument of
    `Qwen3TTSForConditionalGeneration.generate(...)`.
    """
    # Reference speech codes: (T, Q) or (T,) depending on tokenizer 25Hz/12Hz.
    ref_code: Optional[torch.Tensor]
    # Speaker embedding of the reference audio: (D,).
    ref_spk_embedding: torch.Tensor
    # True when only the speaker embedding should drive the cloning.
    x_vector_only_mode: bool
    # True when the reference text/codes are used as an in-context (ICL) prompt.
    icl_mode: bool
    # Transcript of the reference audio, if one was supplied.
    ref_text: Optional[str] = None
52
+
53
+
54
+ class Qwen3TTSModel:
55
+ """
56
+ A HuggingFace-style wrapper for Qwen3 TTS models (CustomVoice/VoiceDesign/Base) that provides:
57
+ - from_pretrained() initialization via AutoModel/AutoProcessor
58
+ - generation APIs for:
59
+ * CustomVoice: generate_custom_voice()
60
+ * VoiceDesign: generate_voice_design()
61
+ * Base: generate_voice_clone() + create_voice_clone_prompt()
62
+ - consistent output: (wavs: List[np.ndarray], sample_rate: int)
63
+
64
+ Notes:
65
+ - This wrapper expects the underlying model class to be `Qwen3TTSForConditionalGeneration`
66
+ - Language / speaker validation is done via model methods:
67
+ model.get_supported_languages(), model.get_supported_speakers()
68
+ """
69
+
70
+ def __init__(self, model: Qwen3TTSForConditionalGeneration, processor, generate_defaults: Optional[Dict[str, Any]] = None):
71
+ self.model = model
72
+ self.processor = processor
73
+ self.generate_defaults = generate_defaults or {}
74
+
75
+ self.device = getattr(model, "device", None)
76
+ if self.device is None:
77
+ try:
78
+ self.device = next(model.parameters()).device
79
+ except StopIteration:
80
+ self.device = torch.device("cpu")
81
+
82
@classmethod
def from_pretrained(
    cls,
    pretrained_model_name_or_path: str,
    **kwargs,
) -> "Qwen3TTSModel":
    """
    Build a wrapper from a HuggingFace repo id or a local model directory.

    Steps:
        1) Register the Qwen3 TTS config, model and processor classes with
           AutoConfig / AutoModel / AutoProcessor.
        2) Load the model with `AutoModel.from_pretrained(...)`, passing `kwargs` through unchanged.
        3) Load the processor with `AutoProcessor.from_pretrained(...)`.
        4) Use the loaded model's `generate_config` attribute as the generation defaults.

    Args:
        pretrained_model_name_or_path (str):
            HuggingFace repo id or local directory of the model.
        **kwargs:
            Forwarded as-is into `AutoModel.from_pretrained(...)`.
            Typical examples: device_map="cuda:0", dtype=torch.bfloat16, attn_implementation="flash_attention_2".

    Returns:
        Qwen3TTSModel:
            Wrapper instance holding `model`, `processor`, and generation defaults.

    Raises:
        TypeError: If AutoModel returns something other than `Qwen3TTSForConditionalGeneration`.
    """
    AutoConfig.register("qwen3_tts", Qwen3TTSConfig)
    for auto_cls, target_cls in (
        (AutoModel, Qwen3TTSForConditionalGeneration),
        (AutoProcessor, Qwen3TTSProcessor),
    ):
        auto_cls.register(Qwen3TTSConfig, target_cls)

    loaded = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
    if isinstance(loaded, Qwen3TTSForConditionalGeneration):
        processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, fix_mistral_regex=True)
        return cls(model=loaded, processor=processor, generate_defaults=loaded.generate_config)

    raise TypeError(
        f"AutoModel returned {type(loaded)}, expected Qwen3TTSForConditionalGeneration. "
    )
122
+
123
+ def _supported_languages_set(self) -> Optional[set]:
124
+ langs = getattr(self.model, "get_supported_languages", None)
125
+ if callable(langs):
126
+ v = langs()
127
+ if v is None:
128
+ return None
129
+ return set([str(x).lower() for x in v])
130
+ return None
131
+
132
+ def _supported_speakers_set(self) -> Optional[set]:
133
+ spks = getattr(self.model, "get_supported_speakers", None)
134
+ if callable(spks):
135
+ v = spks()
136
+ if v is None:
137
+ return None
138
+ return set([str(x).lower() for x in v])
139
+ return None
140
+
141
+ def _validate_languages(self, languages: List[str]) -> None:
142
+ """
143
+ Validate that requested languages are supported by the model.
144
+
145
+ Args:
146
+ languages (List[str]): Language names for each sample.
147
+
148
+ Raises:
149
+ ValueError: If any language is not supported.
150
+ """
151
+ supported = self._supported_languages_set()
152
+ if supported is None:
153
+ return
154
+
155
+ bad = []
156
+ for lang in languages:
157
+ if lang is None:
158
+ bad.append(lang)
159
+ continue
160
+ if str(lang).lower() not in supported:
161
+ bad.append(lang)
162
+ if bad:
163
+ raise ValueError(f"Unsupported languages: {bad}. Supported: {sorted(supported)}")
164
+
165
+ def _validate_speakers(self, speakers: List[Optional[str]]) -> None:
166
+ """
167
+ Validate that requested speakers are supported by the Instruct model.
168
+
169
+ Args:
170
+ speakers (List[Optional[str]]): Speaker names for each sample.
171
+
172
+ Raises:
173
+ ValueError: If any speaker is not supported.
174
+ """
175
+ supported = self._supported_speakers_set()
176
+ if supported is None:
177
+ return
178
+
179
+ bad = []
180
+ for spk in speakers:
181
+ if spk is None or spk == "":
182
+ continue
183
+ if str(spk).lower() not in supported:
184
+ bad.append(spk)
185
+ if bad:
186
+ raise ValueError(f"Unsupported speakers: {bad}. Supported: {sorted(supported)}")
187
+
188
+ def _is_probably_base64(self, s: str) -> bool:
189
+ if s.startswith("data:audio"):
190
+ return True
191
+ if ("/" not in s and "\\" not in s) and len(s) > 256:
192
+ return True
193
+ return False
194
+
195
+ def _is_url(self, s: str) -> bool:
196
+ try:
197
+ u = urlparse(s)
198
+ return u.scheme in ("http", "https") and bool(u.netloc)
199
+ except Exception:
200
+ return False
201
+
202
+ def _decode_base64_to_wav_bytes(self, b64: str) -> bytes:
203
+ if "," in b64 and b64.strip().startswith("data:"):
204
+ b64 = b64.split(",", 1)[1]
205
+ return base64.b64decode(b64)
206
+
207
def _load_audio_to_np(self, x: str) -> Tuple[np.ndarray, int]:
    """
    Load audio described by a string into a mono float32 waveform.

    The string is treated, in this order, as:
        - an http(s) URL: downloaded, then decoded with soundfile
        - a base64 / data-URI string: decoded, then read with soundfile
        - anything else: passed to `librosa.load` as a path, at its native rate, as mono

    Returns:
        Tuple[np.ndarray, int]: (float32 waveform, sampling rate).
    """
    if self._is_url(x):
        with urllib.request.urlopen(x) as resp:
            payload = resp.read()
    elif self._is_probably_base64(x):
        payload = self._decode_base64_to_wav_bytes(x)
    else:
        payload = None

    if payload is None:
        audio, sr = librosa.load(x, sr=None, mono=True)
    else:
        with io.BytesIO(payload) as buf:
            audio, sr = sf.read(buf, dtype="float32", always_2d=False)

    # Multi-channel data is averaged over the last axis to get a mono signal.
    if audio.ndim > 1:
        audio = np.mean(audio, axis=-1)

    return audio.astype(np.float32), int(sr)
224
+
225
+ def _normalize_audio_inputs(self, audios: Union[AudioLike, List[AudioLike]]) -> List[Tuple[np.ndarray, int]]:
226
+ """
227
+ Normalize audio inputs into a list of (waveform, sr).
228
+
229
+ Supported forms:
230
+ - str: wav path / URL / base64 audio string
231
+ - (np.ndarray, sr): waveform + sampling rate
232
+ - list of the above
233
+
234
+ Args:
235
+ audios:
236
+ Audio input(s).
237
+
238
+ Returns:
239
+ List[Tuple[np.ndarray, int]]:
240
+ List of (float32 waveform, original sr).
241
+
242
+ Raises:
243
+ ValueError: If a numpy waveform is provided without sr.
244
+ """
245
+ if isinstance(audios, list):
246
+ items = audios
247
+ else:
248
+ items = [audios]
249
+
250
+ out: List[Tuple[np.ndarray, int]] = []
251
+ for a in items:
252
+ if isinstance(a, str):
253
+ out.append(self._load_audio_to_np(a))
254
+ elif isinstance(a, tuple) and len(a) == 2 and isinstance(a[0], np.ndarray):
255
+ out.append((a[0].astype(np.float32), int(a[1])))
256
+ elif isinstance(a, np.ndarray):
257
+ raise ValueError("For numpy waveform input, pass a tuple (audio, sr).")
258
+ else:
259
+ raise TypeError(f"Unsupported audio input type: {type(a)}")
260
+ for i, a in enumerate(out):
261
+ if a[0].ndim > 1:
262
+ a[0] = np.mean(a[0], axis=-1).astype(np.float32)
263
+ out[i] = (a[0], a[1])
264
+ return out
265
+
266
+ def _ensure_list(self, x: MaybeList) -> List[Any]:
267
+ return x if isinstance(x, list) else [x]
268
+
269
+ def _build_assistant_text(self, text: str) -> str:
270
+ return f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n"
271
+
272
+ def _build_ref_text(self, text: str) -> str:
273
+ return f"<|im_start|>assistant\n{text}<|im_end|>\n"
274
+
275
+ def _build_instruct_text(self, instruct: str) -> str:
276
+ return f"<|im_start|>user\n{instruct}<|im_end|>\n"
277
+
278
+ def _tokenize_texts(self, texts: List[str]) -> List[torch.Tensor]:
279
+ input_ids = []
280
+ for text in texts:
281
+ input = self.processor(text=text, return_tensors="pt", padding=True)
282
+ input_id = input["input_ids"].to(self.device)
283
+ input_id = input_id.unsqueeze(0) if input_id.dim() == 1 else input_id
284
+ input_ids.append(input_id)
285
+ return input_ids
286
+
287
+ def _merge_generate_kwargs(
288
+ self,
289
+ do_sample: Optional[bool] = None,
290
+ top_k: Optional[int] = None,
291
+ top_p: Optional[float] = None,
292
+ temperature: Optional[float] = None,
293
+ repetition_penalty: Optional[float] = None,
294
+ subtalker_dosample: Optional[bool] = None,
295
+ subtalker_top_k: Optional[int] = None,
296
+ subtalker_top_p: Optional[float] = None,
297
+ subtalker_temperature: Optional[float] = None,
298
+ max_new_tokens: Optional[int] = None,
299
+ **kwargs,
300
+ ) -> Dict[str, Any]:
301
+ """
302
+ Merge user-provided generation arguments with defaults from `generate_config.json`.
303
+
304
+ Rule:
305
+ - If the user explicitly passes a value (not None), use it.
306
+ - Otherwise, use the value from generate_config.json if present.
307
+ - Otherwise, fall back to the hard defaults.
308
+
309
+ Args:
310
+ do_sample, top_k, top_p, temperature, repetition_penalty,
311
+ subtalker_dosample, subtalker_top_k, subtalker_top_p, subtalker_temperature, max_new_tokens:
312
+ Common generation parameters.
313
+ **kwargs:
314
+ Other arguments forwarded to model.generate().
315
+
316
+ Returns:
317
+ Dict[str, Any]: Final kwargs to pass into model.generate().
318
+ """
319
+ hard_defaults = dict(
320
+ do_sample=True,
321
+ top_k=50,
322
+ top_p=1.0,
323
+ temperature=0.9,
324
+ repetition_penalty=1.05,
325
+ subtalker_dosample=True,
326
+ subtalker_top_k=50,
327
+ subtalker_top_p=1.0,
328
+ subtalker_temperature=0.9,
329
+ max_new_tokens=2048,
330
+ )
331
+
332
+ def pick(name: str, user_val: Any) -> Any:
333
+ if user_val is not None:
334
+ return user_val
335
+ if name in self.generate_defaults:
336
+ return self.generate_defaults[name]
337
+ return hard_defaults[name]
338
+
339
+ merged = dict(kwargs)
340
+ merged.update(
341
+ do_sample=pick("do_sample", do_sample),
342
+ top_k=pick("top_k", top_k),
343
+ top_p=pick("top_p", top_p),
344
+ temperature=pick("temperature", temperature),
345
+ repetition_penalty=pick("repetition_penalty", repetition_penalty),
346
+ subtalker_dosample=pick("subtalker_dosample", subtalker_dosample),
347
+ subtalker_top_k=pick("subtalker_top_k", subtalker_top_k),
348
+ subtalker_top_p=pick("subtalker_top_p", subtalker_top_p),
349
+ subtalker_temperature=pick("subtalker_temperature", subtalker_temperature),
350
+ max_new_tokens=pick("max_new_tokens", max_new_tokens),
351
+ )
352
+ return merged
353
+
354
+ # voice clone model
355
@torch.inference_mode()
def create_voice_clone_prompt(
    self,
    ref_audio: Union[AudioLike, List[AudioLike]],
    ref_text: Optional[Union[str, List[Optional[str]]]] = None,
    x_vector_only_mode: Union[bool, List[bool]] = False,
) -> List[VoiceClonePromptItem]:
    """
    Turn reference audio (plus optional transcripts) into voice-clone prompt items. Base model only.

    Modes:
        - x_vector_only_mode=True:
            Only the speaker embedding is used; the item's ref_code is set to None.
        - x_vector_only_mode=False:
            ICL mode (icl_mode=True). A non-empty ref_text is required for that sample.

    Batch behavior:
        - ref_audio can be a single item or a list.
        - A scalar ref_text / x_vector_only_mode is repeated for every reference audio.
        - A list must have the same length as the list of reference audios.

    Args:
        ref_audio:
            Reference audio(s): str (local wav path / URL / base64) or (np.ndarray, sr).
            Used to extract:
                - ref_code via `model.speech_tokenizer.encode(...)`
                - ref_spk_embedding via `model.extract_speaker_embedding(...)`, after resampling
                  to `model.speaker_encoder_sample_rate` when the rates differ
        ref_text:
            Reference transcript(s). Required when x_vector_only_mode=False (ICL mode).
        x_vector_only_mode:
            Whether to use speaker embedding only. If False, ICL mode will be used.

    Returns:
        List[VoiceClonePromptItem]:
            One prompt item per reference audio.

    Raises:
        ValueError:
            - If the model is not a "base" model.
            - If x_vector_only_mode=False but ref_text is missing.
            - If batch lengths mismatch.
    """
    if self.model.tts_model_type != "base":
        raise ValueError(
            f"model with \ntokenizer_type: {self.model.tokenizer_type}\n"
            f"tts_model_size: {self.model.tts_model_size}\n"
            f"tts_model_type: {self.model.tts_model_type}\n"
            "does not support create_voice_clone_prompt, Please check Model Card or Readme for more details."
        )

    ref_audio_list = self._ensure_list(ref_audio)
    batch = len(ref_audio_list)
    # Scalars are repeated for every reference audio; lists are used as given.
    ref_text_list = ref_text if isinstance(ref_text, list) else [ref_text] * batch
    xvec_list = x_vector_only_mode if isinstance(x_vector_only_mode, list) else [x_vector_only_mode] * batch

    if not (len(ref_text_list) == batch and len(xvec_list) == batch):
        raise ValueError(
            f"Batch size mismatch: ref_audio={len(ref_audio_list)}, ref_text={len(ref_text_list)}, x_vector_only_mode={len(xvec_list)}"
        )

    normalized = self._normalize_audio_inputs(ref_audio_list)
    waveforms = [wav for wav, _ in normalized]
    rates = [sr for _, sr in normalized]

    # A single shared sampling rate allows one batched encode call;
    # otherwise each waveform is encoded on its own.
    if len(set(rates)) == 1:
        ref_codes = self.model.speech_tokenizer.encode(waveforms, sr=rates[0]).audio_codes
    else:
        ref_codes = [
            self.model.speech_tokenizer.encode(wav, sr=sr).audio_codes[0]
            for wav, sr in normalized
        ]

    items: List[VoiceClonePromptItem] = []
    for idx, ((wav, sr), code, rtext, xvec_only) in enumerate(zip(normalized, ref_codes, ref_text_list, xvec_list)):
        if not xvec_only and (rtext is None or rtext == ""):
            raise ValueError(f"ref_text is required when x_vector_only_mode=False (ICL mode). Bad index={idx}")

        # The speaker encoder gets audio at its own sampling rate.
        speaker_sr = self.model.speaker_encoder_sample_rate
        if sr == speaker_sr:
            speaker_wav = wav
        else:
            speaker_wav = librosa.resample(y=wav.astype(np.float32),
                                           orig_sr=int(sr),
                                           target_sr=speaker_sr)

        spk_emb = self.model.extract_speaker_embedding(audio=speaker_wav, sr=speaker_sr)

        items.append(
            VoiceClonePromptItem(
                ref_code=None if xvec_only else code,
                ref_spk_embedding=spk_emb,
                x_vector_only_mode=bool(xvec_only),
                icl_mode=bool(not xvec_only),
                ref_text=rtext,
            )
        )
    return items
459
+
460
+ def _prompt_items_to_voice_clone_prompt(self, items: List[VoiceClonePromptItem]) -> Dict[str, Any]:
461
+ return dict(
462
+ ref_code=[it.ref_code for it in items],
463
+ ref_spk_embedding=[it.ref_spk_embedding for it in items],
464
+ x_vector_only_mode=[it.x_vector_only_mode for it in items],
465
+ icl_mode=[it.icl_mode for it in items],
466
+ )
467
+
468
+ # voice clone model
469
@torch.no_grad()
def generate_voice_clone(
    self,
    text: Union[str, List[str]],
    language: Union[str, List[str]] = None,
    ref_audio: Optional[Union[AudioLike, List[AudioLike]]] = None,
    ref_text: Optional[Union[str, List[Optional[str]]]] = None,
    x_vector_only_mode: Union[bool, List[bool]] = False,
    voice_clone_prompt: Optional[Union[Dict[str, Any], List[VoiceClonePromptItem]]] = None,
    non_streaming_mode: bool = False,
    **kwargs,
) -> Tuple[List[np.ndarray], int]:
    """
    Synthesize speech in a cloned voice using the Base model.

    The voice prompt can be supplied in one of these ways:
        - (ref_audio, ref_text, x_vector_only_mode): the prompt is built here via
          `create_voice_clone_prompt`, OR
        - a list of `VoiceClonePromptItem` returned by `create_voice_clone_prompt`, OR
        - a ready-made `voice_clone_prompt` dict, which is passed to the model as-is.

    `ref_audio` supported forms:
        - str: wav path / URL / base64 audio string
        - (np.ndarray, sr): waveform + sampling rate
        - list of the above

    Batching:
        - text/language can be scalar or list; a missing language becomes "Auto".
        - A single language or single prompt item is repeated for every text.
        - Otherwise lengths must match.

    Args:
        text:
            Text(s) to synthesize.
        language:
            Language(s) for each sample.
        ref_audio:
            Reference audio(s) for prompt building. Required if voice_clone_prompt is not provided.
        ref_text:
            Reference text(s) used for ICL mode (required when x_vector_only_mode=False).
        x_vector_only_mode:
            If True, only speaker embedding is used (ignores ref_text/ref_code).
            If False, ICL mode is used automatically.
        voice_clone_prompt:
            list[VoiceClonePromptItem] from `create_voice_clone_prompt`, or a prompt dict.
        non_streaming_mode:
            Using non-streaming text input, this option currently only simulates streaming text input when set to `false`,
            rather than enabling true streaming input or streaming generation.
        **kwargs:
            Generation options, merged with defaults by `_merge_generate_kwargs` and forwarded to
            `Qwen3TTSForConditionalGeneration.generate(...)`:
                do_sample: Whether to use sampling, recommended to be set to `true` for most use cases.
                top_k / top_p: Top-k / top-p sampling parameters.
                temperature: Sampling temperature; higher => more random.
                repetition_penalty: Penalty to reduce repeated tokens/codes.
                subtalker_dosample / subtalker_top_k / subtalker_top_p / subtalker_temperature:
                    Sub-talker sampling options (only valid for qwen3-tts-tokenizer-v2).
                max_new_tokens: Maximum number of new codec tokens to generate.
            Any other keyword arguments are forwarded unchanged.

    Returns:
        Tuple[List[np.ndarray], int]:
            (wavs, sample_rate)

    Raises:
        ValueError:
            If the model is not a "base" model, batch sizes mismatch, or required prompt inputs are missing.
    """
    if self.model.tts_model_type != "base":
        raise ValueError(
            f"model with \ntokenizer_type: {self.model.tokenizer_type}\n"
            f"tts_model_size: {self.model.tts_model_size}\n"
            f"tts_model_type: {self.model.tts_model_type}\n"
            "does not support generate_voice_clone, Please check Model Card or Readme for more details."
        )

    texts = self._ensure_list(text)
    batch = len(texts)
    if isinstance(language, list):
        languages = self._ensure_list(language)
    elif language is not None:
        languages = [language] * batch
    else:
        languages = ["Auto"] * batch
    if len(languages) == 1 and batch > 1:
        languages = languages * batch
    if batch != len(languages):
        raise ValueError(f"Batch size mismatch: text={len(texts)}, language={len(languages)}")

    self._validate_languages(languages)

    # Obtain prompt items when the prompt is not already a dict.
    if voice_clone_prompt is None:
        if ref_audio is None:
            raise ValueError("Either `voice_clone_prompt` or `ref_audio` must be provided.")
        prompt_items = self.create_voice_clone_prompt(ref_audio=ref_audio, ref_text=ref_text, x_vector_only_mode=x_vector_only_mode)
    elif isinstance(voice_clone_prompt, list):
        prompt_items = voice_clone_prompt
    else:
        prompt_items = None

    if prompt_items is None:
        # A dict prompt is used as given; no reference text is tokenized for it.
        voice_clone_prompt_dict = voice_clone_prompt
        ref_texts_for_ids = None
    else:
        if len(prompt_items) == 1 and batch > 1:
            prompt_items = prompt_items * batch
        if len(prompt_items) != batch:
            raise ValueError(f"Batch size mismatch: prompt={len(prompt_items)}, text={len(texts)}")
        voice_clone_prompt_dict = self._prompt_items_to_voice_clone_prompt(prompt_items)
        ref_texts_for_ids = [item.ref_text for item in prompt_items]

    input_ids = self._tokenize_texts([self._build_assistant_text(t) for t in texts])

    ref_ids = None
    if ref_texts_for_ids is not None:
        ref_ids = [
            None if (rt is None or rt == "") else self._tokenize_texts([self._build_ref_text(rt)])[0]
            for rt in ref_texts_for_ids
        ]

    gen_kwargs = self._merge_generate_kwargs(**kwargs)

    talker_codes_list, _ = self.model.generate(
        input_ids=input_ids,
        ref_ids=ref_ids,
        voice_clone_prompt=voice_clone_prompt_dict,
        languages=languages,
        non_streaming_mode=non_streaming_mode,
        **gen_kwargs,
    )

    ref_code_list = voice_clone_prompt_dict.get("ref_code", None)

    def ref_code_at(idx: int):
        """Return the reference codes of sample `idx`, or None when there are none."""
        return None if ref_code_list is None else ref_code_list[idx]

    # Samples with reference codes are decoded with those codes prepended.
    codes_for_decode = []
    for idx, codes in enumerate(talker_codes_list):
        ref_code = ref_code_at(idx)
        if ref_code is None:
            codes_for_decode.append(codes)
        else:
            codes_for_decode.append(torch.cat([ref_code.to(codes.device), codes], dim=0))

    wavs_all, fs = self.model.speech_tokenizer.decode([{"audio_codes": c} for c in codes_for_decode])

    # Drop the leading part of each waveform that corresponds to the prepended
    # reference codes, in proportion to their share of the decoded codes.
    wavs_out: List[np.ndarray] = []
    for idx, wav in enumerate(wavs_all):
        ref_code = ref_code_at(idx)
        if ref_code is None:
            wavs_out.append(wav)
            continue
        ref_len = int(ref_code.shape[0])
        total_len = int(codes_for_decode[idx].shape[0])
        cut = int(ref_len / max(total_len, 1) * wav.shape[0])
        wavs_out.append(wav[cut:])

    return wavs_out, fs
634
+
635
+ # voice design model
636
@torch.no_grad()
def generate_voice_design(
    self,
    text: Union[str, List[str]],
    instruct: Union[str, List[str]],
    language: Union[str, List[str]] = None,
    non_streaming_mode: bool = True,
    **kwargs,
) -> Tuple[List[np.ndarray], int]:
    """
    Synthesize speech with the VoiceDesign model, steered by natural-language style instructions.

    A single language or instruction is repeated for every text; otherwise lengths must match.
    A missing language becomes "Auto".

    Args:
        text:
            Text(s) to synthesize.
        instruct:
            Instruction(s) describing desired voice/style. Empty string is allowed (treated as no instruction).
        language:
            Language(s) for each sample.
        non_streaming_mode:
            Using non-streaming text input, this option currently only simulates streaming text input when set to `false`,
            rather than enabling true streaming input or streaming generation.
        **kwargs:
            Generation options, merged with defaults by `_merge_generate_kwargs` and forwarded to
            `Qwen3TTSForConditionalGeneration.generate(...)`:
                do_sample: Whether to use sampling, recommended to be set to `true` for most use cases.
                top_k / top_p: Top-k / top-p sampling parameters.
                temperature: Sampling temperature; higher => more random.
                repetition_penalty: Penalty to reduce repeated tokens/codes.
                subtalker_dosample / subtalker_top_k / subtalker_top_p / subtalker_temperature:
                    Sub-talker sampling options (only valid for qwen3-tts-tokenizer-v2).
                max_new_tokens: Maximum number of new codec tokens to generate.
            Any other keyword arguments are forwarded unchanged.

    Returns:
        Tuple[List[np.ndarray], int]:
            (wavs, sample_rate)

    Raises:
        ValueError:
            If the model is not a "voice_design" model, a language is unsupported, or batch sizes mismatch.
    """
    if self.model.tts_model_type != "voice_design":
        raise ValueError(
            f"model with \ntokenizer_type: {self.model.tokenizer_type}\n"
            f"tts_model_size: {self.model.tts_model_size}\n"
            f"tts_model_type: {self.model.tts_model_type}\n"
            "does not support generate_voice_design, Please check Model Card or Readme for more details."
        )

    texts = self._ensure_list(text)
    batch = len(texts)
    if isinstance(language, list):
        languages = self._ensure_list(language)
    elif language is not None:
        languages = [language] * batch
    else:
        languages = ["Auto"] * batch
    instructs = self._ensure_list(instruct)

    # Repeat single-element lists so they cover the whole batch.
    if batch > 1:
        if len(languages) == 1:
            languages = languages * batch
        if len(instructs) == 1:
            instructs = instructs * batch

    if not (batch == len(languages) == len(instructs)):
        raise ValueError(f"Batch size mismatch: text={len(texts)}, language={len(languages)}, instruct={len(instructs)}")

    self._validate_languages(languages)

    input_ids = self._tokenize_texts([self._build_assistant_text(t) for t in texts])

    # Empty or missing instructions are passed to the model as None.
    instruct_ids: List[Optional[torch.Tensor]] = [
        None if (ins is None or ins == "") else self._tokenize_texts([self._build_instruct_text(ins)])[0]
        for ins in instructs
    ]

    gen_kwargs = self._merge_generate_kwargs(**kwargs)

    talker_codes_list, _ = self.model.generate(
        input_ids=input_ids,
        instruct_ids=instruct_ids,
        languages=languages,
        non_streaming_mode=non_streaming_mode,
        **gen_kwargs,
    )

    wavs, fs = self.model.speech_tokenizer.decode([{"audio_codes": c} for c in talker_codes_list])
    return wavs, fs
729
+
730
+ # custom voice model
731
@torch.no_grad()
def generate_custom_voice(
    self,
    text: Union[str, List[str]],
    speaker: Union[str, List[str]],
    language: Union[str, List[str]] = None,
    instruct: Optional[Union[str, List[str]]] = None,
    non_streaming_mode: bool = True,
    **kwargs,
) -> Tuple[List[np.ndarray], int]:
    """
    Generate speech with the CustomVoice model using a predefined speaker id, optionally controlled by instruction text.

    A single language, speaker or instruction is repeated for every text; otherwise lengths must match.
    A missing language becomes "Auto". For the "0b6" model size the instruction is ignored.

    Args:
        text:
            Text(s) to synthesize.
        speaker:
            Speaker name(s). Will be validated against `model.get_supported_speakers()` (case-insensitive).
        language:
            Language(s) for each sample.
        instruct:
            Optional instruction(s). If None, treated as empty (no instruction).
        non_streaming_mode:
            Using non-streaming text input, this option currently only simulates streaming text input when set to `false`,
            rather than enabling true streaming input or streaming generation.
        **kwargs:
            Generation options, merged with defaults by `_merge_generate_kwargs` and forwarded to
            `Qwen3TTSForConditionalGeneration.generate(...)`:
                do_sample: Whether to use sampling, recommended to be set to `true` for most use cases.
                top_k / top_p: Top-k / top-p sampling parameters.
                temperature: Sampling temperature; higher => more random.
                repetition_penalty: Penalty to reduce repeated tokens/codes.
                subtalker_dosample / subtalker_top_k / subtalker_top_p / subtalker_temperature:
                    Sub-talker sampling options (only valid for qwen3-tts-tokenizer-v2).
                max_new_tokens: Maximum number of new codec tokens to generate.
            Any other keyword arguments are forwarded unchanged.

    Returns:
        Tuple[List[np.ndarray], int]:
            (wavs, sample_rate)

    Raises:
        ValueError:
            If the model is not a "custom_voice" model, any speaker/language is unsupported,
            or batch sizes mismatch.
    """
    if self.model.tts_model_type != "custom_voice":
        raise ValueError(
            f"model with \ntokenizer_type: {self.model.tokenizer_type}\n"
            f"tts_model_size: {self.model.tts_model_size}\n"
            f"tts_model_type: {self.model.tts_model_type}\n"
            "does not support generate_custom_voice, Please check Model Card or Readme for more details."
        )

    texts = self._ensure_list(text)
    batch = len(texts)
    if isinstance(language, list):
        languages = self._ensure_list(language)
    elif language is not None:
        languages = [language] * batch
    else:
        languages = ["Auto"] * batch
    speakers = self._ensure_list(speaker)

    # The 0b6 model does not support instructions. Compare with `==`:
    # `x in "0b6"` is a substring test that also matches "", "0" or "b6"
    # and raises TypeError when the size is None.
    if self.model.tts_model_size == "0b6":
        instruct = None
    if isinstance(instruct, list):
        instructs = self._ensure_list(instruct)
    elif instruct is not None:
        instructs = [instruct] * batch
    else:
        instructs = [""] * batch

    # Repeat single-element lists so they cover the whole batch.
    if batch > 1:
        if len(languages) == 1:
            languages = languages * batch
        if len(speakers) == 1:
            speakers = speakers * batch
        if len(instructs) == 1:
            instructs = instructs * batch

    if not (batch == len(languages) == len(speakers) == len(instructs)):
        raise ValueError(
            f"Batch size mismatch: text={len(texts)}, language={len(languages)}, speaker={len(speakers)}, instruct={len(instructs)}"
        )

    self._validate_languages(languages)
    self._validate_speakers(speakers)

    input_ids = self._tokenize_texts([self._build_assistant_text(t) for t in texts])

    # Empty or missing instructions are passed to the model as None.
    instruct_ids: List[Optional[torch.Tensor]] = [
        None if (ins is None or ins == "") else self._tokenize_texts([self._build_instruct_text(ins)])[0]
        for ins in instructs
    ]

    gen_kwargs = self._merge_generate_kwargs(**kwargs)

    talker_codes_list, _ = self.model.generate(
        input_ids=input_ids,
        instruct_ids=instruct_ids,
        languages=languages,
        speakers=speakers,
        non_streaming_mode=non_streaming_mode,
        **gen_kwargs,
    )

    wavs, fs = self.model.speech_tokenizer.decode([{"audio_codes": c} for c in talker_codes_list])
    return wavs, fs
840
+
841
+
842
def get_supported_speakers(self) -> Optional[List[str]]:
    """
    Return the speaker names the current model supports.

    Built on `_supported_speakers_set()`, which reads `model.get_supported_speakers()`.

    Returns:
        Optional[List[str]]:
            - A sorted list of supported speaker names (lowercased), if available.
            - None if the model does not provide supported speakers.
    """
    supported = self._supported_speakers_set()
    return None if supported is None else sorted(supported)
859
+
860
+
861
def get_supported_languages(self) -> Optional[List[str]]:
    """
    Return the language names the current model supports.

    Built on `_supported_languages_set()`, which reads `model.get_supported_languages()`.

    Returns:
        Optional[List[str]]:
            - A sorted list of supported language names (lowercased), if available.
            - None if the model does not provide supported languages.
    """
    supported = self._supported_languages_set()
    return None if supported is None else sorted(supported)
qwen_tts/inference/qwen3_tts_tokenizer.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2026 The Alibaba Qwen team.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import base64
17
+ import io
18
+ import urllib.request
19
+ from typing import List, Optional, Tuple, Union
20
+ from urllib.parse import urlparse
21
+
22
+ import librosa
23
+ import numpy as np
24
+ import soundfile as sf
25
+ import torch
26
+ from torch.nn.utils.rnn import pad_sequence
27
+ from transformers import AutoConfig, AutoFeatureExtractor, AutoModel
28
+
29
+ from ..core import (
30
+ Qwen3TTSTokenizerV1Config,
31
+ Qwen3TTSTokenizerV1Model,
32
+ Qwen3TTSTokenizerV2Config,
33
+ Qwen3TTSTokenizerV2Model,
34
+ )
35
+
36
# Every audio input form accepted by `Qwen3TTSTokenizer.encode`.
AudioInput = Union[
    str,  # wav path, http(s) URL, or base64 audio string (raw or data URL)
    np.ndarray,  # float waveform; arrays with more than one dim are averaged over the last axis
    List[str],
    List[np.ndarray],
]
42
+
43
+
44
class Qwen3TTSTokenizer:
    """
    A wrapper for Qwen3 TTS Tokenizer 25Hz/12Hz with HuggingFace-style loading.

    - from_pretrained(): loads speech tokenizer model via AutoModel and feature_extractor via AutoFeatureExtractor.
    - encode(): supports wav path(s), http(s) URL(s), base64 audio string(s), numpy array(s).
    - decode(): accepts either the raw model encode output, or a minimal dict/list-of-dicts.

    Notes:
        - For numpy array input, you must pass `sr` so the audio can be resampled to model sample rate.
        - Returned audio is float32 numpy arrays and the output sample rate.
    """

    # A raw (non data-URL) string must be longer than this to be treated as base64.
    _MIN_RAW_BASE64_LEN = 256

    # Standard base64 alphabet plus padding and line-break characters.
    _BASE64_CHARS = frozenset(
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=\r\n"
    )

    def __init__(self):
        # All attributes are populated by `from_pretrained()`.
        self.model = None
        self.feature_extractor = None
        self.config = None
        self.device = None

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> "Qwen3TTSTokenizer":
        """
        Initialize tokenizer with HuggingFace `from_pretrained` style.

        Args:
            pretrained_model_name_or_path (str):
                HuggingFace repo id or local directory.
            **kwargs (Any):
                Forwarded to `AutoModel.from_pretrained(...)` directly.
                Typical examples: device_map="cuda:0", dtype=torch.bfloat16, attn_implementation="eager".

        Returns:
            Qwen3TTSTokenizer:
                Initialized instance with `model`, `feature_extractor`, `config`.
        """
        inst = cls()

        # Make both tokenizer variants resolvable through the Auto* factories.
        AutoConfig.register("qwen3_tts_tokenizer_25hz", Qwen3TTSTokenizerV1Config)
        AutoModel.register(Qwen3TTSTokenizerV1Config, Qwen3TTSTokenizerV1Model)

        AutoConfig.register("qwen3_tts_tokenizer_12hz", Qwen3TTSTokenizerV2Config)
        AutoModel.register(Qwen3TTSTokenizerV2Config, Qwen3TTSTokenizerV2Model)

        inst.feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path)
        inst.model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
        inst.config = inst.model.config

        inst.device = getattr(inst.model, "device", None)
        if inst.device is None:
            # fallback: infer from first parameter device
            try:
                inst.device = next(inst.model.parameters()).device
            except StopIteration:
                inst.device = torch.device("cpu")

        return inst

    def _is_probably_base64(self, s: str) -> bool:
        """
        Heuristically decide whether `s` is base64 audio rather than a file path.

        Args:
            s (str): Candidate string.

        Returns:
            bool: True for `data:audio...` URLs and for long strings that either
            contain no path separators or consist solely of base64 characters.
        """
        if s.startswith("data:audio"):
            return True
        if len(s) <= self._MIN_RAW_BASE64_LEN:
            return False
        # Heuristic: no filesystem path separators and long enough.
        if "/" not in s and "\\" not in s:
            return True
        # "/" is part of the standard base64 alphabet, so raw base64 audio
        # almost always contains it. Accept long strings made up only of
        # base64 characters; real paths normally contain ".", "_", "-", etc.
        return self._BASE64_CHARS.issuperset(s)

    def _is_url(self, s: str) -> bool:
        """Return True when `s` parses as an http(s) URL with a network location."""
        try:
            u = urlparse(s)
            return u.scheme in ("http", "https") and bool(u.netloc)
        except (ValueError, TypeError, AttributeError):
            # Malformed URL text or a non-string argument: not a URL.
            return False

    def _decode_base64_to_wav_bytes(self, b64: str) -> bytes:
        """Decode a base64 audio string (raw or `data:...;base64,` URL) to bytes."""
        # Accept both "data:audio/wav;base64,...." and raw base64
        if "," in b64 and b64.strip().startswith("data:"):
            b64 = b64.split(",", 1)[1]
        return base64.b64decode(b64)

    def load_audio(
        self,
        x: str,
        target_sr: int,
    ) -> np.ndarray:
        """
        Load audio from an http(s) URL, wav path or base64 string, then resample to target_sr.

        Args:
            x (str):
                An http(s) URL, a wav file path, or a base64 audio string (raw or data URL).
            target_sr (int):
                Target sampling rate.

        Returns:
            np.ndarray:
                1-D float32 waveform at target_sr.
        """
        if self._is_url(x):
            # NOTE: fetches a caller-supplied URL; only pass trusted locations.
            with urllib.request.urlopen(x) as resp:
                payload = resp.read()
        elif self._is_probably_base64(x):
            payload = self._decode_base64_to_wav_bytes(x)
        else:
            payload = None

        if payload is None:
            audio, sr = librosa.load(x, sr=None, mono=True)
        else:
            with io.BytesIO(payload) as f:
                audio, sr = sf.read(f, dtype="float32", always_2d=False)

        # soundfile returns (frames, channels) for multi-channel data: downmix.
        if audio.ndim > 1:
            audio = np.mean(audio, axis=-1)

        if sr != target_sr:
            audio = librosa.resample(y=audio, orig_sr=sr, target_sr=target_sr)

        return audio.astype(np.float32)

    def _normalize_audio_inputs(
        self,
        audios: "AudioInput",
        sr: Optional[int],
    ) -> List[np.ndarray]:
        """
        Normalize all supported input types into a list of 1-D numpy float32 waveforms
        at `self.feature_extractor.sampling_rate`.

        Args:
            audios (AudioInput):
                - str: wav path, http(s) URL OR base64 audio string
                - np.ndarray: raw waveform (sr must be provided)
                - list[str] / list[np.ndarray]
            sr (Optional[int]):
                Sampling rate for raw numpy input. Required if input is np.ndarray or list[np.ndarray].

        Returns:
            List[np.ndarray]:
                List of float32 waveforms resampled to model input SR.

        Raises:
            ValueError: If numpy input is given without `sr`.
            TypeError: If a list mixes strings and numpy arrays.
        """
        target_sr = int(self.feature_extractor.sampling_rate)

        if isinstance(audios, (str, np.ndarray)):
            audios = [audios]

        if len(audios) == 0:
            return []

        if isinstance(audios[0], str):
            # wav path list or base64 list; reject mixed lists up front so the
            # failure is a clear TypeError instead of an error inside load_audio.
            if not all(isinstance(x, str) for x in audios):
                raise TypeError("Mixed input types are not supported. Use all paths/base64 or all numpy arrays.")
            return [self.load_audio(x, target_sr=target_sr) for x in audios]

        # numpy list
        if sr is None:
            raise ValueError("For numpy waveform input, you must provide `sr` (original sampling rate).")

        orig_sr = int(sr)
        out: List[np.ndarray] = []
        for a in audios:
            if not isinstance(a, np.ndarray):
                raise TypeError("Mixed input types are not supported. Use all paths/base64 or all numpy arrays.")
            if a.ndim > 1:
                a = np.mean(a, axis=-1)
            # Always copies, so the caller's array is never aliased.
            a = a.astype(np.float32)
            if orig_sr != target_sr:
                a = librosa.resample(y=a, orig_sr=orig_sr, target_sr=target_sr)
            out.append(a.astype(np.float32, copy=False))
        return out

    def encode(
        self,
        audios: "AudioInput",
        sr: Optional[int] = None,
        return_dict: bool = True,
    ):
        """
        Batch-encode audio into discrete codes (and optional conditioning, depending on 25Hz/12Hz).

        Args:
            audios (AudioInput):
                Supported forms:
                - np.ndarray: waveform (requires sr)
                - list[np.ndarray]: waveforms (requires sr)
                - str: wav path, http(s) URL OR base64 audio string
                - list[str]: wav paths, URLs and/or base64 strings
            sr (Optional[int], default=None):
                Original sampling rate for numpy waveform input.
            return_dict (bool, default=True):
                Forwarded to model.encode(...). If True, returns ModelOutput.

        Returns:
            25Hz:
                Qwen3TTSTokenizerV1EncoderOutput (if return_dict=True) with fields:
                  - audio_codes: List[torch.LongTensor] each (codes_len,)
                  - xvectors:   List[torch.FloatTensor] each (xvector_dim,)
                  - ref_mels:   List[torch.FloatTensor] each (mel_len, mel_dim)
            12Hz:
                Qwen3TTSTokenizerV2EncoderOutput (if return_dict=True) with fields:
                  - audio_codes: List[torch.LongTensor] each (codes_len, num_quantizers)

            If return_dict=False, returns the raw tuple from model.encode.
        """
        wavs = self._normalize_audio_inputs(audios, sr=sr)

        inputs = self.feature_extractor(
            raw_audio=wavs,
            sampling_rate=int(self.feature_extractor.sampling_rate),
            return_tensors="pt",
        )
        inputs = inputs.to(self.device).to(self.model.dtype)

        with torch.inference_mode():
            # model.encode expects (B, T) and (B, T)
            enc = self.model.encode(
                inputs["input_values"].squeeze(1),
                inputs["padding_mask"].squeeze(1),
                return_dict=return_dict,
            )
        return enc

    def decode(
        self,
        encoded,
    ) -> Tuple[List[np.ndarray], int]:
        """
        Decode back to waveform.

        Usage:
            1) Pass the raw output of `encode(...)` directly (recommended).
               - 25Hz: expects fields audio_codes, xvectors, ref_mels
               - 12Hz: expects field audio_codes
            2) Pass a dict or list[dict] (minimal form) for custom pipelines:
               - 25Hz dict keys: {"audio_codes", "xvectors", "ref_mels"}
               - 12Hz dict keys: {"audio_codes"}
               Values can be torch tensors or numpy arrays.

        Args:
            encoded (Any):
                - ModelOutput returned by `encode()`, OR
                - dict, OR
                - list[dict]

        Returns:
            Tuple[List[np.ndarray], int]:
                - wavs: list of 1-D float32 numpy arrays
                - sample_rate: int, model output sampling rate

        Raises:
            TypeError: If `encoded` is none of the supported forms.
            ValueError: If `encoded` is an empty list, if the 25Hz model is missing
                `xvectors`/`ref_mels`, or if the model type is unknown.
        """
        model_type = self.model.get_model_type()
        is_25hz = model_type == "qwen3_tts_tokenizer_25hz"

        def _to_tensor(x, dtype=None):
            # Tensors pass through untouched; anything else goes via numpy.
            if isinstance(x, torch.Tensor):
                return x
            t = torch.from_numpy(np.asarray(x))
            if dtype is not None:
                t = t.to(dtype)
            return t

        # Normalize `encoded` into the same shapes as the official demo uses.
        if hasattr(encoded, "audio_codes"):
            # ModelOutput from encode()
            audio_codes_list = encoded.audio_codes
            xvectors_list = getattr(encoded, "xvectors", None)
            ref_mels_list = getattr(encoded, "ref_mels", None)
        elif isinstance(encoded, dict):
            audio_codes_list = encoded["audio_codes"]
            xvectors_list = encoded.get("xvectors", None)
            ref_mels_list = encoded.get("ref_mels", None)
        elif isinstance(encoded, list):
            # list of dicts
            if not encoded:
                raise ValueError("`encoded` list must not be empty.")
            audio_codes_list = [e["audio_codes"] for e in encoded]
            xvectors_list = [e["xvectors"] for e in encoded] if ("xvectors" in encoded[0]) else None
            ref_mels_list = [e["ref_mels"] for e in encoded] if ("ref_mels" in encoded[0]) else None
        else:
            raise TypeError("`encoded` must be an encode output, a dict, or a list of dicts.")

        # Ensure batch form for the audio codes.
        if isinstance(audio_codes_list, torch.Tensor):
            # Could be a single sample tensor or an already padded batch tensor.
            t = audio_codes_list
            if t.dim() == 1:
                # 25Hz single sample: (C,) -> (1, C)
                t = t.unsqueeze(0)
            elif t.dim() == 2 and not is_25hz:
                # 12Hz single sample: (C, Q) -> (1, C, Q).
                # For 25Hz a 2-D tensor is already a (B, C) batch and must not
                # gain another axis, or it stops matching xvectors/ref_mels.
                t = t.unsqueeze(0)
            audio_codes_padded = t.to(self.device)
        else:
            # List[Tensor/np]
            audio_codes_list = [_to_tensor(c, dtype=torch.long) for c in audio_codes_list]
            audio_codes_padded = pad_sequence(audio_codes_list, batch_first=True, padding_value=0).to(self.device)

        with torch.inference_mode():
            if is_25hz:
                if xvectors_list is None or ref_mels_list is None:
                    raise ValueError("25Hz decode requires `xvectors` and `ref_mels`.")

                if isinstance(xvectors_list, torch.Tensor):
                    xvectors_batch = xvectors_list
                    if xvectors_batch.dim() == 1:  # (D,) -> (1, D)
                        xvectors_batch = xvectors_batch.unsqueeze(0)
                    xvectors_batch = xvectors_batch.to(self.device).to(self.model.dtype)
                else:
                    xvectors_list = [_to_tensor(x, dtype=torch.float32) for x in xvectors_list]
                    xvectors_batch = torch.stack(xvectors_list, dim=0).to(self.device).to(self.model.dtype)

                if isinstance(ref_mels_list, torch.Tensor):
                    ref_mels_padded = ref_mels_list
                    if ref_mels_padded.dim() == 2:  # (T, M) -> (1, T, M)
                        ref_mels_padded = ref_mels_padded.unsqueeze(0)
                    ref_mels_padded = ref_mels_padded.to(self.device).to(self.model.dtype)
                else:
                    ref_mels_list = [_to_tensor(m, dtype=torch.float32) for m in ref_mels_list]
                    ref_mels_padded = (
                        pad_sequence(ref_mels_list, batch_first=True, padding_value=0)
                        .to(self.device)
                        .to(self.model.dtype)
                    )

                dec = self.model.decode(audio_codes_padded, xvectors_batch, ref_mels_padded, return_dict=True)
                wav_tensors = dec.audio_values

            elif model_type == "qwen3_tts_tokenizer_12hz":
                dec = self.model.decode(audio_codes_padded, return_dict=True)
                wav_tensors = dec.audio_values

            else:
                raise ValueError(f"Unknown model type: {model_type}")

        wavs = [w.to(torch.float32).detach().cpu().numpy() for w in wav_tensors]
        return wavs, int(self.model.get_output_sample_rate())

    def get_model_type(self) -> str:
        """
        Get the underlying tokenizer model type.

        Returns:
            str: Model type string reported by `self.model.get_model_type()`
                 (e.g. "qwen3_tts_tokenizer_25hz" / "qwen3_tts_tokenizer_12hz").
        """
        return self.model.get_model_type()

    def get_input_sample_rate(self) -> int:
        """
        Get the expected input sample rate for encoding.

        Returns:
            int: Input sample rate (Hz).
        """
        return int(self.model.get_input_sample_rate())

    def get_output_sample_rate(self) -> int:
        """
        Get the output sample rate for decoded waveforms.

        Returns:
            int: Output sample rate (Hz).
        """
        return int(self.model.get_output_sample_rate())

    def get_encode_downsample_rate(self) -> int:
        """
        Get the encoder downsample rate (waveform samples per code step).

        Returns:
            int: Encode downsample rate.
        """
        return int(self.model.get_encode_downsample_rate())

    def get_decode_upsample_rate(self) -> int:
        """
        Get the decoder upsample rate (waveform samples per code step).

        Returns:
            int: Decode upsample rate.
        """
        return int(self.model.get_decode_upsample_rate())
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ python-dotenv
2
+ torch==2.8.0
3
+ torchaudio==2.8.0
4
+ transformers==4.57.3
5
+ accelerate==1.12.0
6
+ einops
7
+ gradio
8
+ librosa
9
+ soundfile
10
+ sox
11
+ onnxruntime
12
+ spaces
13
+ numpy
14
+ kernels
15
+ openai-whisper