Masaaki Kawata committed on
Commit ·
e07602f
1
Parent(s): 44e5c9c
initial commit
Browse files- .dockerignore +12 -0
- .gitignore +167 -0
- Dockerfile +47 -0
- README.md +4 -3
- app.py +288 -0
- docker-compose.yml +27 -0
- faster_qwen3_tts/__init__.py +7 -0
- faster_qwen3_tts/cli.py +407 -0
- faster_qwen3_tts/generate.py +215 -0
- faster_qwen3_tts/model.py +1370 -0
- faster_qwen3_tts/predictor_graph.py +214 -0
- faster_qwen3_tts/sampling.py +66 -0
- faster_qwen3_tts/streaming.py +359 -0
- faster_qwen3_tts/talker_graph.py +214 -0
- faster_qwen3_tts/utils.py +30 -0
- main.py +144 -0
- qwen_tts/__init__.py +24 -0
- qwen_tts/__main__.py +24 -0
- qwen_tts/core/__init__.py +19 -0
- qwen_tts/core/models/__init__.py +18 -0
- qwen_tts/core/models/configuration_qwen3_tts.py +502 -0
- qwen_tts/core/models/modeling_qwen3_tts.py +0 -0
- qwen_tts/core/models/processing_qwen3_tts.py +106 -0
- qwen_tts/core/tokenizer_12hz/configuration_qwen3_tts_tokenizer_v2.py +172 -0
- qwen_tts/core/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py +1025 -0
- qwen_tts/core/tokenizer_25hz/configuration_qwen3_tts_tokenizer_v1.py +332 -0
- qwen_tts/core/tokenizer_25hz/modeling_qwen3_tts_tokenizer_v1.py +1528 -0
- qwen_tts/core/tokenizer_25hz/vq/assets/mel_filters.npz +3 -0
- qwen_tts/core/tokenizer_25hz/vq/core_vq.py +523 -0
- qwen_tts/core/tokenizer_25hz/vq/speech_vq.py +357 -0
- qwen_tts/core/tokenizer_25hz/vq/whisper_encoder.py +406 -0
- qwen_tts/inference/qwen3_tts_model.py +877 -0
- qwen_tts/inference/qwen3_tts_tokenizer.py +411 -0
- requirements.txt +15 -0
.dockerignore
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.DS_Store
|
| 2 |
+
.env
|
| 3 |
+
.git
|
| 4 |
+
.gitignore
|
| 5 |
+
__pycache__/
|
| 6 |
+
*.py[cod]
|
| 7 |
+
.cache/
|
| 8 |
+
.venv/
|
| 9 |
+
venv/
|
| 10 |
+
logs/
|
| 11 |
+
data/
|
| 12 |
+
models/
|
.gitignore
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py,cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
|
| 74 |
+
# PyBuilder
|
| 75 |
+
.pybuilder/
|
| 76 |
+
target/
|
| 77 |
+
|
| 78 |
+
# Jupyter Notebook
|
| 79 |
+
.ipynb_checkpoints
|
| 80 |
+
|
| 81 |
+
# IPython
|
| 82 |
+
profile_default/
|
| 83 |
+
ipython_config.py
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 88 |
+
# .python-version
|
| 89 |
+
|
| 90 |
+
# pipenv
|
| 91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 94 |
+
# install all needed dependencies.
|
| 95 |
+
#Pipfile.lock
|
| 96 |
+
|
| 97 |
+
# poetry
|
| 98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 100 |
+
# commonly ignored for libraries.
|
| 101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 102 |
+
#poetry.lock
|
| 103 |
+
|
| 104 |
+
# pdm
|
| 105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 106 |
+
#pdm.lock
|
| 107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 108 |
+
# in version control.
|
| 109 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
| 110 |
+
.pdm.toml
|
| 111 |
+
.pdm-python
|
| 112 |
+
.pdm-build/
|
| 113 |
+
|
| 114 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 115 |
+
__pypackages__/
|
| 116 |
+
|
| 117 |
+
# Celery stuff
|
| 118 |
+
celerybeat-schedule
|
| 119 |
+
celerybeat.pid
|
| 120 |
+
|
| 121 |
+
# SageMath parsed files
|
| 122 |
+
*.sage.py
|
| 123 |
+
|
| 124 |
+
# Environments
|
| 125 |
+
.env
|
| 126 |
+
.venv
|
| 127 |
+
env/
|
| 128 |
+
venv/
|
| 129 |
+
ENV/
|
| 130 |
+
env.bak/
|
| 131 |
+
venv.bak/
|
| 132 |
+
|
| 133 |
+
# Spyder project settings
|
| 134 |
+
.spyderproject
|
| 135 |
+
.spyproject
|
| 136 |
+
|
| 137 |
+
# Rope project settings
|
| 138 |
+
.ropeproject
|
| 139 |
+
|
| 140 |
+
# mkdocs documentation
|
| 141 |
+
/site
|
| 142 |
+
|
| 143 |
+
# mypy
|
| 144 |
+
.mypy_cache/
|
| 145 |
+
.dmypy.json
|
| 146 |
+
dmypy.json
|
| 147 |
+
|
| 148 |
+
# Pyre type checker
|
| 149 |
+
.pyre/
|
| 150 |
+
|
| 151 |
+
# pytype static type analyzer
|
| 152 |
+
.pytype/
|
| 153 |
+
|
| 154 |
+
# Cython debug symbols
|
| 155 |
+
cython_debug/
|
| 156 |
+
|
| 157 |
+
# PyCharm
|
| 158 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 159 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 160 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 161 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 162 |
+
#.idea/
|
| 163 |
+
|
| 164 |
+
# Generated by MacOS
|
| 165 |
+
.DS_Store
|
| 166 |
+
|
| 167 |
+
#GPT_SoVITS/text/ja_userdic/
|
Dockerfile
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim
|
| 2 |
+
|
| 3 |
+
ENV DEBIAN_FRONTEND=noninteractive \
|
| 4 |
+
PYTHONUNBUFFERED=1 \
|
| 5 |
+
PIP_NO_CACHE_DIR=1 \
|
| 6 |
+
HF_HOME=/data/huggingface \
|
| 7 |
+
WHISPER_CACHE_DIR=/data/whisper \
|
| 8 |
+
GRADIO_SERVER_NAME=0.0.0.0 \
|
| 9 |
+
GRADIO_SERVER_PORT=7860 \
|
| 10 |
+
NVIDIA_VISIBLE_DEVICES=all \
|
| 11 |
+
NVIDIA_DRIVER_CAPABILITIES=compute,utility
|
| 12 |
+
|
| 13 |
+
RUN apt-get update \
|
| 14 |
+
&& apt-get install -y --no-install-recommends \
|
| 15 |
+
build-essential \
|
| 16 |
+
curl \
|
| 17 |
+
ffmpeg \
|
| 18 |
+
git \
|
| 19 |
+
libsndfile1 \
|
| 20 |
+
sox \
|
| 21 |
+
&& apt-get clean \
|
| 22 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 23 |
+
|
| 24 |
+
WORKDIR /app
|
| 25 |
+
|
| 26 |
+
COPY requirements.txt .
|
| 27 |
+
RUN python -m pip install --upgrade pip setuptools wheel \
|
| 28 |
+
&& python -m pip install \
|
| 29 |
+
--index-url https://download.pytorch.org/whl/cu128 \
|
| 30 |
+
torch==2.10.0+cu128 \
|
| 31 |
+
torchaudio==2.10.0+cu128 \
|
| 32 |
+
&& sed '/^torch==/d; /^torchaudio==/d' requirements.txt > /tmp/requirements-no-torch.txt \
|
| 33 |
+
&& python -m pip install -r /tmp/requirements-no-torch.txt
|
| 34 |
+
|
| 35 |
+
COPY app.py .
|
| 36 |
+
COPY faster_qwen3_tts ./faster_qwen3_tts
|
| 37 |
+
COPY qwen_tts ./qwen_tts
|
| 38 |
+
|
| 39 |
+
RUN useradd --create-home --uid 1000 appuser \
|
| 40 |
+
&& mkdir -p /data/huggingface /data/whisper \
|
| 41 |
+
&& chown -R appuser:appuser /app /data
|
| 42 |
+
|
| 43 |
+
USER appuser
|
| 44 |
+
|
| 45 |
+
EXPOSE 7860
|
| 46 |
+
|
| 47 |
+
CMD ["python", "app.py"]
|
README.md
CHANGED
|
@@ -1,13 +1,14 @@
|
|
| 1 |
---
|
| 2 |
title: Merkurius
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 6.14.0
|
| 8 |
python_version: '3.12'
|
| 9 |
app_file: app.py
|
| 10 |
pinned: false
|
|
|
|
| 11 |
---
|
| 12 |
|
| 13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
| 1 |
---
|
| 2 |
title: Merkurius
|
| 3 |
+
emoji: 🌟
|
| 4 |
+
colorFrom: pink
|
| 5 |
+
colorTo: yellow
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 6.14.0
|
| 8 |
python_version: '3.12'
|
| 9 |
app_file: app.py
|
| 10 |
pinned: false
|
| 11 |
+
short_description: milchchan.com
|
| 12 |
---
|
| 13 |
|
| 14 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#import subprocess
|
| 2 |
+
#subprocess.run('pip install flash-attn==2.7.4.post1', shell=True)
|
| 3 |
+
import io
|
| 4 |
+
import re
|
| 5 |
+
import os
|
| 6 |
+
import json
|
| 7 |
+
import hashlib
|
| 8 |
+
import threading
|
| 9 |
+
import time
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import spaces
|
| 13 |
+
import whisper
|
| 14 |
+
import gradio as gr
|
| 15 |
+
from gradio.themes.base import Base
|
| 16 |
+
from gradio.themes.utils import colors, fonts, sizes
|
| 17 |
+
from typing import Iterable
|
| 18 |
+
from dotenv import load_dotenv
|
| 19 |
+
from urllib.request import urlopen, Request
|
| 20 |
+
from scipy.signal import resample_poly
|
| 21 |
+
#from huggingface_hub import snapshot_download
|
| 22 |
+
#from qwen_tts import Qwen3TTSModel
|
| 23 |
+
from faster_qwen3_tts import FasterQwen3TTS
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Read environment settings through python-dotenv before any os.environ lookup below.
load_dotenv(verbose=False)

#TTS_MODEL = Qwen3TTSModel.from_pretrained(snapshot_download('Qwen/Qwen3-TTS-12Hz-1.7B-Base', token=os.environ['HF_TOKEN']), device_map=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), dtype=torch.bfloat16, token=os.environ['HF_TOKEN'], attn_implementation='kernels-community/flash-attn3')
# Speech synthesis model, created once at import time.
TTS_MODEL = FasterQwen3TTS.from_pretrained('Qwen/Qwen3-TTS-12Hz-1.7B-Base')
# Whisper model used to transcribe reference audio. It is created on the CPU;
# WHISPER_CACHE_DIR (may be unset) selects the download directory.
WHISPER_MODEL = whisper.load_model('turbo', device='cpu', download_root=os.environ.get('WHISPER_CACHE_DIR'))
# Transcription cache: reference-audio hash -> (last access time, transcript, detected language).
REFERENCE_AUDIO_TRANSCRIPTION_CACHE: dict[str, tuple[float, str, str]] = {}
# Lock taken around accesses to the transcription cache.
REFERENCE_AUDIO_TRANSCRIPTION_CACHE_LOCK = threading.Lock()
# Maximum number of cached transcripts; the environment may override the default of 100,
# and the value is never allowed below 1.
REFERENCE_AUDIO_TRANSCRIPTION_CACHE_LIMIT = max(1, int(os.environ.get('REFERENCE_AUDIO_TRANSCRIPTION_CACHE_LIMIT', 100)))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class Theme(Base):
    """Application theme: the Gradio base theme with custom accent colours.

    All constructor arguments are keyword-only and are forwarded unchanged to
    ``Base.__init__``; afterwards a fixed set of colour variables is applied
    through ``set``.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.neutral,
        secondary_hue: colors.Color | str = colors.neutral,
        neutral_hue: colors.Color | str = colors.neutral,
        spacing_size: sizes.Size | str = sizes.spacing_md,
        radius_size: sizes.Size | str = sizes.radius_md,
        text_size: sizes.Size | str = sizes.text_md,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (fonts.GoogleFont('Barlow'), 'ui-sans-serif', 'sans-serif'),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (fonts.GoogleFont('IBM Plex Mono'), 'ui-monospace', 'monospace'),
    ):
        base_options = {
            'primary_hue': primary_hue,
            'secondary_hue': secondary_hue,
            'neutral_hue': neutral_hue,
            'spacing_size': spacing_size,
            'radius_size': radius_size,
            'text_size': text_size,
            'font': font,
            'font_mono': font_mono,
        }
        super().__init__(**base_options)

        # Each colour is spelled once here and reused for the light and dark variants.
        accent = 'rgb(0 231 255 / 1)'
        accent_hover = 'rgb(0 231 255 / .75)'
        loader = 'rgb(255 199 229 / 1)'
        button_text = '#ffffff'

        super().set(
            color_accent=accent,
            slider_color=accent,
            slider_color_dark=accent,
            button_primary_background_fill=accent,
            button_primary_background_fill_hover=accent_hover,
            button_primary_text_color=button_text,
            button_primary_background_fill_dark=accent,
            button_primary_background_fill_hover_dark=accent_hover,
            button_primary_text_color_dark=button_text,
            loader_color=loader,
            loader_color_dark=loader,
        )
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _normalize_audio(wav, eps=1e-12, clip=True):
    """Return ``wav`` as float32 samples scaled into the [-1, 1] range.

    Signed integers are divided by the largest magnitude of their dtype,
    unsigned integers are re-centred on the dtype midpoint, and floats are
    rescaled only when their peak exceeds 1.0. Input with more than one
    dimension is averaged over its last axis. Any other dtype yields ``None``.

    Args:
        wav: Array-like audio samples.
        eps: Added to the peak before dividing float input.
        clip: When true, clamp the result to [-1, 1].
    """
    samples = np.asarray(wav)
    sample_type = samples.dtype

    if np.issubdtype(sample_type, np.integer):
        limits = np.iinfo(sample_type)
        as_float = samples.astype(np.float32)

        if limits.min < 0:
            scaled = as_float / max(abs(limits.min), limits.max)
        else:
            # Unsigned data is centred on half of the value range.
            midpoint = (limits.max + 1) / 2.0
            scaled = (as_float - midpoint) / midpoint
    elif np.issubdtype(sample_type, np.floating):
        scaled = samples.astype(np.float32)
        peak = np.max(np.abs(scaled)) if scaled.size else 0.0

        # Floats already within range are left untouched.
        if peak > 1.0 + 1e-6:
            scaled = scaled / (peak + eps)
    else:
        return None

    if clip:
        scaled = np.clip(scaled, -1.0, 1.0)

    if scaled.ndim <= 1:
        return scaled

    return np.mean(scaled, axis=-1).astype(np.float32)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def _resample(x: np.ndarray, original_sample_rate: int, target_sample_rate: int, axis: int = 0) -> np.ndarray:
    """Resample ``x`` along ``axis`` from one sample rate to another.

    The two rates are reduced by their greatest common divisor to obtain the
    up- and down-sampling factors handed to ``scipy.signal.resample_poly``.
    """
    common_divisor = np.gcd(original_sample_rate, target_sample_rate)
    up_factor = target_sample_rate // common_divisor
    down_factor = original_sample_rate // common_divisor

    return resample_poly(x, up=up_factor, down=down_factor, axis=axis)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _reference_audio_hash(reference_audio: tuple[np.ndarray, int]) -> str:
    """Return a SHA-256 hex digest identifying a reference clip.

    The digest covers the sample rate, dtype and shape as well as the raw
    sample bytes. Hashing the bytes alone would give the same key to clips
    that share a byte pattern but differ in sample rate or layout, and the
    key is used to look up cached transcripts.

    Args:
        reference_audio: ``(samples, sample_rate)`` pair.

    Returns:
        A 64-character hexadecimal digest.
    """
    audio = np.ascontiguousarray(np.asarray(reference_audio[0]))
    # Tolerate a one-element sequence so a missing rate does not raise here.
    sample_rate = reference_audio[1] if len(reference_audio) > 1 else None

    digest = hashlib.sha256()
    digest.update(f'{sample_rate}|{audio.dtype.str}|{audio.shape}|'.encode('utf-8'))
    digest.update(audio.tobytes())

    return digest.hexdigest()
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def _detect_reference_text_and_language(reference_audio: tuple[np.ndarray, int], sample_rate: int) -> tuple[str, str]:
    """Transcribe a reference clip with Whisper and detect its language.

    Two-dimensional audio is averaged over axis 1, resampled to 16000 Hz when
    ``sample_rate`` differs, clipped to [-1, 1] and passed through Whisper's
    language detection and decoding. Line breaks in the transcript are
    removed. When the detected language is ``'ja'`` the transcript is sent to
    ``generate_text`` and replaced by its result unless that result is ``None``.

    Returns:
        ``(reference_text, detected_language)``.
    """
    target_rate = 16000
    samples = np.asarray(reference_audio[0])

    if samples.ndim == 2:
        samples = samples.mean(axis=1)

    if sample_rate != target_rate:
        samples = _resample(samples, sample_rate, target_rate).astype(np.float32)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = WHISPER_MODEL.to(device=device)

    samples = whisper.pad_or_trim(np.clip(samples, -1.0, 1.0))
    mel = whisper.log_mel_spectrogram(samples, n_mels=model.dims.n_mels).to(model.device)

    _, language_probabilities = model.detect_language(mel)
    detected_language = max(language_probabilities, key=language_probabilities.get)

    decoded = whisper.decode(model, mel, whisper.DecodingOptions())
    reference_text = re.sub(r'\s*\n\s*', '', decoded.text)

    if detected_language != 'ja':
        return reference_text, detected_language

    # Japanese transcripts are converted by generate_text; keep the raw text if it returns None.
    converted_text = generate_text(reference_text)

    if converted_text is None:
        return reference_text, detected_language

    return converted_text, detected_language
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def _get_reference_text_and_language(reference_audio: tuple[np.ndarray, int], sample_rate: int) -> tuple[str, str]:
    """Return the transcript and language of a reference clip, using the cache.

    A cache hit refreshes the entry's timestamp and returns the stored values.
    On a miss the clip is transcribed (outside the lock), stored, and the
    oldest entries are dropped until the cache is back within
    ``REFERENCE_AUDIO_TRANSCRIPTION_CACHE_LIMIT``.

    Returns:
        ``(reference_text, detected_language)``.
    """
    cache = REFERENCE_AUDIO_TRANSCRIPTION_CACHE
    cache_key = _reference_audio_hash(reference_audio)

    with REFERENCE_AUDIO_TRANSCRIPTION_CACHE_LOCK:
        hit = cache.get(cache_key)

        if hit is not None:
            # Re-store the entry with the current time so it counts as recently used.
            cache[cache_key] = (time.time(), hit[1], hit[2])

            return hit[1], hit[2]

    reference_text, detected_language = _detect_reference_text_and_language(reference_audio, sample_rate)

    with REFERENCE_AUDIO_TRANSCRIPTION_CACHE_LOCK:
        cache[cache_key] = (time.time(), reference_text, detected_language)
        surplus = len(cache) - REFERENCE_AUDIO_TRANSCRIPTION_CACHE_LIMIT

        if surplus > 0:
            oldest_first = sorted(cache, key=lambda key: cache[key][0])

            for stale_key in oldest_first[:surplus]:
                del cache[stale_key]

    return reference_text, detected_language
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def generate_text(prompt: str) -> str | None:
    """Ask the OpenAI Responses API to rewrite Japanese text as kana readings.

    The request carries a fixed developer prompt plus ``prompt`` as user
    input. The model name comes from ``OPENAI_MODEL`` and the API key from
    ``OPENAI_API_KEY``. The call is bounded by ``OPENAI_TIMEOUT`` seconds
    (default 60) so that a stalled connection cannot block the caller forever.

    Args:
        prompt: Text to convert.

    Returns:
        The text of the first ``output_text`` item inside a ``message``
        output, or ``None`` when the response contains no such item.

    Raises:
        KeyError: ``OPENAI_API_KEY`` is not set.
        OSError: The request fails or times out (raised by ``urlopen``).
    """
    system_prompt = '''あなたは日本語テキストを「読み(かな)」だけに変換する変換器です。

出力に含めてよい文字は ひらがな・カタカナ・長音記号ー・空白 のみです。改行も禁止(1行で出力)。
入力に含まれる 漢字は必ずかなにする。
英数字・記号は、可能な範囲で日本語のカナ読みにする(例:AI→えーあい、LLM→えるえるえむ、2026→にせんにじゅうろく)。
出力は 変換後の本文のみ。説明、注釈、引用符、箇条書き、コードブロックは一切禁止。
最後に必ず自己検査を行う:出力が ^[ぁ-ゟ゠-ヿー ]+$ に一致しない場合、条件を満たすまで修正してから出力する。
それでも読めない文字がある場合は、意味を落としてよいので「最も近いかな」に置き換える(記号は省略よりも読みを優先。ただし許可文字以外は絶対に出さない)。'''
    payload = {
        'model': os.environ.get('OPENAI_MODEL', 'gpt-5.4-mini'),
        'input': [
            {
                'role': 'developer',
                'content': system_prompt
            },
            {
                'role': 'user',
                'content': [
                    {
                        'type': 'input_text',
                        'text': prompt
                    }
                ]
            }
        ],
        'temperature': 1,
        'reasoning': {'effort': 'none'},
    }
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {os.environ["OPENAI_API_KEY"]}'
    }
    request = Request('https://api.openai.com/v1/responses', data=json.dumps(payload).encode('utf-8'), method='POST', headers=headers)
    # urlopen has no timeout by default; without one a hung socket never returns.
    timeout = float(os.environ.get('OPENAI_TIMEOUT', 60))

    with urlopen(request, timeout=timeout) as response:
        result = json.loads(response.read().decode('utf-8'))

    # Missing keys are treated as "no text" instead of raising KeyError.
    for output in result.get('output') or []:
        if output.get('type') != 'message':
            continue

        for content in output.get('content') or []:
            if content.get('type') == 'output_text':
                return content.get('text')

    return None
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
@spaces.GPU(duration=30)
def generate_voice_clone(input_text: str, language: str | None, reference_audio: tuple[int, np.ndarray] | dict | None, reference_text: str | None, temperature: float, progress: gr.Progress=gr.Progress(track_tqdm=True)) -> tuple[tuple[int, np.ndarray], str | None, str | None]:
    """Synthesize ``input_text`` in the voice of the reference clip.

    Args:
        input_text: Text to speak.
        language: ``'Auto'``, ``None``, a key of the language table
            (``'en'``, ``'ja'``), or any other value, which is passed to the
            model unchanged.
        reference_audio: Either a ``(sample_rate, samples)`` tuple or a dict
            with ``'sampling_rate'`` and ``'data'`` keys.
        reference_text: Transcript of the reference clip. When missing or
            blank, the clip is transcribed and its language detected.
        temperature: Sampling temperature forwarded to the model.
        progress: Gradio progress tracker.

    Returns:
        ``((sample_rate, int16_samples), transcribed_text, detected_language)``.
        The last two items are ``None`` when no transcription was performed.

    Raises:
        gr.Error: The reference audio is missing, has an unrecognised
            structure, or uses an unsupported sample dtype.
    """
    language_codes = {'en': 'English', 'ja': 'Japanese'}
    transcribed_text = None
    detected_language = None

    if isinstance(reference_audio, tuple) and len(reference_audio) == 2 and isinstance(reference_audio[0], (int, np.integer)):
        sample_rate, wav = reference_audio
    elif isinstance(reference_audio, dict) and 'sampling_rate' in reference_audio and 'data' in reference_audio:
        sample_rate, wav = reference_audio['sampling_rate'], reference_audio['data']
    else:
        # Previously this fell through and failed later with UnboundLocalError on sample_rate.
        raise gr.Error('Reference audio is required.')

    sample_rate = int(sample_rate)
    normalized_audio = _normalize_audio(wav)

    if normalized_audio is None:
        raise gr.Error('Unsupported reference audio format.')

    reference_audio = (normalized_audio, sample_rate)

    # Blank (including whitespace-only) reference text means "transcribe the clip".
    if reference_text is None or not reference_text.strip():
        reference_text, detected_language = _get_reference_text_and_language(reference_audio, sample_rate)
        transcribed_text = reference_text

    if language is None or language == 'Auto':
        # detected_language is None when nothing was transcribed, which resolves to 'Auto'.
        language = language_codes.get(detected_language, 'Auto')
    elif language in language_codes:
        language = language_codes[language]

    if sample_rate != 48000:
        reference_audio = (_resample(reference_audio[0], sample_rate, 48000), 48000)

    wavs, output_sample_rate = TTS_MODEL.generate_voice_clone(text=input_text.strip(), language=language, ref_audio=reference_audio, ref_text=reference_text.strip(), temperature=temperature, append_silence=False)
    #wavs, sample_rate = TTS_MODEL.generate_voice_clone(text=input_text.strip(), language=language, ref_audio=reference_audio, ref_text=reference_text, max_new_tokens=2048, temperature=temperature)

    # Scale by 32767: a +1.0 sample times 32768 would fall outside the int16 range.
    pcm = (np.clip(wavs[0], -1.0, 1.0) * 32767.0).round().astype(np.int16)

    return (output_sample_rate, pcm), transcribed_text, detected_language
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
# Gradio interface definition.
# NOTE(review): the nesting below follows the blank-line grouping of the statements — confirm against the deployed layout.
with gr.Blocks() as demo:
    with gr.Row():
        # Input column: reference clip, its optional transcript, and the synthesis settings.
        with gr.Column(scale=2):
            with gr.Group():
                tts_reference_audio = gr.Audio(label='Reference Audio', type='numpy', buttons=['download'], waveform_options={'waveform_color': 'rgb(0 231 255 / 1)', 'waveform_progress_color': 'rgb(255 199 229 / 1)'})
                tts_reference_text = gr.Textbox(label='Reference Text', value='', lines=1)

            tts_input_text = gr.Textbox(label='Input', lines=4)
            # Dropdown values ('Auto', 'en', 'ja') are what generate_voice_clone receives as `language`.
            tts_language = gr.Dropdown(label='Language', choices=[('Automatic', 'Auto'), ('English', 'en'), ('Japanese', 'ja')], value='Auto', interactive=True)
            tts_temperature_slider = gr.Slider(minimum=0.1, maximum=2.0, value=0.9, step=0.1, label='Temperature')
            tts_generate_button = gr.Button('Generate', variant='primary')

        # Output column: generated audio plus the transcript and language returned by the handler.
        with gr.Column(scale=2):
            tts_audio_output = gr.Audio(label='Output', type='numpy', buttons=['download'], waveform_options={'waveform_color': 'rgb(0 231 255 / 1)', 'waveform_progress_color': 'rgb(255 199 229 / 1)'})
            tts_transcribed_text = gr.Label(label='Transcript', value='')
            tts_detected_language = gr.Label(label='Language', value='')

    # The input order matches the positional parameters of generate_voice_clone.
    tts_generate_button.click(fn=generate_voice_clone, inputs=[tts_input_text, tts_language, tts_reference_audio, tts_reference_text, tts_temperature_slider], outputs=[tts_audio_output, tts_transcribed_text, tts_detected_language], api_name='synthesize')
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
# Start the Gradio server only when this file is run as a script.
if __name__ == '__main__':
    demo.launch(
        # Bind address: GRADIO_SERVER_NAME, falling back to 0.0.0.0.
        server_name=os.environ.get('GRADIO_SERVER_NAME', '0.0.0.0'),
        # Port lookup order: GRADIO_SERVER_PORT, then PORT, then 7860.
        server_port=int(os.environ.get('GRADIO_SERVER_PORT', os.environ.get('PORT', 7860))),
        theme=Theme(),
        # Sets a zero border width on .block elements matched by this selector.
        css='.column>.row>.column:first-of-type .block { border-width: 0px !important; }'
    )
|
docker-compose.yml
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
services:
|
| 2 |
+
ai:
|
| 3 |
+
container_name: "milchchanai"
|
| 4 |
+
build:
|
| 5 |
+
context: .
|
| 6 |
+
restart: unless-stopped
|
| 7 |
+
tty: true
|
| 8 |
+
env_file:
|
| 9 |
+
- .env
|
| 10 |
+
environment:
|
| 11 |
+
GRADIO_SERVER_NAME: 0.0.0.0
|
| 12 |
+
GRADIO_SERVER_PORT: 7860
|
| 13 |
+
HF_HOME: /data/huggingface
|
| 14 |
+
WHISPER_CACHE_DIR: /data/whisper
|
| 15 |
+
volumes:
|
| 16 |
+
- hf-cache:/data/huggingface
|
| 17 |
+
- whisper-cache:/data/whisper
|
| 18 |
+
ports:
|
| 19 |
+
- "7860:7860"
|
| 20 |
+
deploy:
|
| 21 |
+
resources:
|
| 22 |
+
reservations:
|
| 23 |
+
devices:
|
| 24 |
+
- capabilities: [gpu]
|
| 25 |
+
volumes:
|
| 26 |
+
hf-cache:
|
| 27 |
+
whisper-cache:
|
faster_qwen3_tts/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
faster-qwen3-tts: Real-time Qwen3-TTS inference using CUDA graphs
|
| 3 |
+
"""
|
| 4 |
+
from .model import FasterQwen3TTS
|
| 5 |
+
|
| 6 |
+
__version__ = "0.2.5"
|
| 7 |
+
__all__ = ["FasterQwen3TTS"]
|
faster_qwen3_tts/cli.py
ADDED
|
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""CLI for FasterQwen3TTS."""
|
| 3 |
+
import argparse
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
import time
|
| 7 |
+
import numpy as np
|
| 8 |
+
import soundfile as sf
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from faster_qwen3_tts import FasterQwen3TTS
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _load_model(model_id: str, device: str, dtype: str):
    """Load a FasterQwen3TTS model with the precision named by ``dtype``.

    ``"bf16"`` selects bfloat16 and ``"fp16"`` selects float16; every other
    value selects float32. The model is always requested with the ``"sdpa"``
    attention implementation and ``max_seq_len=2048``.
    """
    precision_by_flag = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
    }
    torch_dtype = precision_by_flag.get(dtype, torch.float32)

    return FasterQwen3TTS.from_pretrained(
        model_id,
        device=device,
        dtype=torch_dtype,
        attn_implementation="sdpa",
        max_seq_len=2048,
    )
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _write_audio(out_path: str, audio: np.ndarray, sr: int):
    """Write ``audio`` to ``out_path`` at sample rate ``sr``, creating the parent directory first."""
    # A bare file name has an empty dirname; use the current directory in that case.
    parent_dir = os.path.dirname(out_path) or "."
    os.makedirs(parent_dir, exist_ok=True)
    sf.write(out_path, audio, sr)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _stream_to_audio(gen):
    """Collect a stream of ``(chunk, sample_rate, extra)`` items into one array.

    Returns:
        ``(audio, sample_rate)`` where ``audio`` is the concatenation of all
        chunks and ``sample_rate`` comes from the last item. An empty stream
        yields a single float32 zero sample and a rate of 24000.
    """
    sample_rate = None
    pieces = []

    for piece, sample_rate, _extra in gen:
        pieces.append(piece)

    if pieces:
        return np.concatenate(pieces), sample_rate

    return np.zeros(1, dtype=np.float32), 24000
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def cmd_clone(args):
    """Handle the ``clone`` subcommand: synthesize ``args.text`` from reference audio.

    Loads the model, runs either the streaming or the one-shot voice-clone
    generator depending on ``args.streaming``, writes the result to
    ``args.output`` and prints the audio duration together with the
    audio-duration / wall-time ratio (labelled ``RTF``).
    """
    model = _load_model(args.model, args.device, args.dtype)

    # Keyword arguments shared by the streaming and one-shot entry points.
    request = dict(
        text=args.text,
        language=args.language,
        ref_audio=args.ref_audio,
        ref_text=args.ref_text,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        do_sample=not args.greedy,
        repetition_penalty=args.repetition_penalty,
        xvec_only=args.xvec_only,
        non_streaming_mode=args.non_streaming_mode,
    )

    started = time.perf_counter()
    if args.streaming:
        chunk_stream = model.generate_voice_clone_streaming(
            chunk_size=args.chunk_size,
            **request,
        )
        audio, sr = _stream_to_audio(chunk_stream)
    else:
        audio_list, sr = model.generate_voice_clone(**request)
        audio = audio_list[0]
    total_time = time.perf_counter() - started

    # Guard both divisions so a missing sample rate or a zero interval prints 0.
    audio_dur = 0.0
    if sr:
        audio_dur = len(audio) / sr
    rtf = 0.0
    if total_time > 0:
        rtf = audio_dur / total_time

    _write_audio(args.output, audio, sr)
    print(f"Wrote {args.output} (dur {audio_dur:.2f}s, RTF {rtf:.2f})")
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def cmd_custom(args):
    """Handle the ``custom`` subcommand: synthesize ``args.text`` with a speaker ID.

    With ``args.list_speakers`` set, the speaker names reported by
    ``model.model.get_supported_speakers()`` are printed one per line and
    nothing is synthesized. Otherwise ``args.speaker`` is required: the
    streaming or one-shot custom-voice generator is run (per
    ``args.streaming``), the audio is written to ``args.output`` and the
    duration and audio-duration / wall-time ratio (labelled ``RTF``) are
    printed.

    Exits with status 2, before the model is loaded, when neither
    ``--speaker`` nor ``--list-speakers`` was given.
    """
    # Validate first: loading the model is expensive, so a usage error should
    # not have to wait for it. Diagnostics go to stderr, not stdout.
    if not args.list_speakers and not args.speaker:
        print("ERROR: --speaker is required (use --list-speakers)", file=sys.stderr)
        sys.exit(2)

    model = _load_model(args.model, args.device, args.dtype)

    if args.list_speakers:
        speakers = model.model.get_supported_speakers() or []
        print("\n".join(speakers))
        return

    # Keyword arguments shared by the streaming and one-shot entry points.
    request = dict(
        text=args.text,
        speaker=args.speaker,
        language=args.language,
        instruct=args.instruct,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        do_sample=not args.greedy,
        repetition_penalty=args.repetition_penalty,
    )

    start = time.perf_counter()
    if args.streaming:
        gen = model.generate_custom_voice_streaming(chunk_size=args.chunk_size, **request)
        audio, sr = _stream_to_audio(gen)
    else:
        audio_list, sr = model.generate_custom_voice(**request)
        audio = audio_list[0]
    total_time = time.perf_counter() - start

    audio_dur = len(audio) / sr if sr else 0.0
    rtf = audio_dur / total_time if total_time > 0 else 0.0

    _write_audio(args.output, audio, sr)
    print(f"Wrote {args.output} (dur {audio_dur:.2f}s, RTF {rtf:.2f})")
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def cmd_design(args):
    """Handle the ``design`` subcommand: synthesize ``args.text`` from a voice instruction.

    Loads the model, runs either the streaming or the one-shot voice-design
    generator depending on ``args.streaming``, writes the result to
    ``args.output`` and prints the audio duration together with the
    audio-duration / wall-time ratio (labelled ``RTF``).
    """
    model = _load_model(args.model, args.device, args.dtype)

    # Keyword arguments shared by the streaming and one-shot entry points.
    request = dict(
        text=args.text,
        instruct=args.instruct,
        language=args.language,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        do_sample=not args.greedy,
        repetition_penalty=args.repetition_penalty,
    )

    started = time.perf_counter()
    if args.streaming:
        chunk_stream = model.generate_voice_design_streaming(
            chunk_size=args.chunk_size,
            **request,
        )
        audio, sr = _stream_to_audio(chunk_stream)
    else:
        audio_list, sr = model.generate_voice_design(**request)
        audio = audio_list[0]
    total_time = time.perf_counter() - started

    # Guard both divisions so a missing sample rate or a zero interval prints 0.
    audio_dur = 0.0
    if sr:
        audio_dur = len(audio) / sr
    rtf = 0.0
    if total_time > 0:
        rtf = audio_dur / total_time

    _write_audio(args.output, audio, sr)
    print(f"Wrote {args.output} (dur {audio_dur:.2f}s, RTF {rtf:.2f})")
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def cmd_serve(args):
    """Handle the ``serve`` subcommand: keep one model loaded and synthesize stdin lines.

    Each non-empty line read from stdin is synthesized in the mode chosen by
    ``args.mode`` (``clone``, ``custom`` or ``design``) and written to
    ``<args.output_dir>/out_NNNN.wav``, with NNNN counting up from 0001. After
    each file the duration and the audio-duration / wall-time ratio (labelled
    ``RTF``; the interval includes writing the file) are printed. A line equal
    to ``exit``, ``quit`` or ``stop`` (case-insensitive) ends the loop, as
    does end of input.

    Exits with status 2, before the model is loaded, when a flag required by
    the selected mode is missing: ``--ref-audio``/``--ref-text`` for clone,
    ``--speaker`` for custom, ``--instruct`` for design.
    """
    # Validate first: loading the model is expensive, so a usage error should
    # not have to wait for it. Diagnostics go to stderr, not stdout.
    if args.mode == "clone" and not (args.ref_audio and args.ref_text):
        print("ERROR: --ref-audio and --ref-text are required for clone mode", file=sys.stderr)
        sys.exit(2)
    if args.mode == "custom" and not args.speaker:
        print("ERROR: --speaker is required for custom mode", file=sys.stderr)
        sys.exit(2)
    if args.mode == "design" and not args.instruct:
        print("ERROR: --instruct is required for design mode", file=sys.stderr)
        sys.exit(2)

    model = _load_model(args.model, args.device, args.dtype)

    # Everything except the text is fixed for the lifetime of the server, so
    # the generator and its keyword arguments are chosen once, up front.
    streaming = args.streaming
    request = dict(
        language=args.language,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
        do_sample=not args.greedy,
        repetition_penalty=args.repetition_penalty,
    )
    if args.mode == "clone":
        generate = model.generate_voice_clone_streaming if streaming else model.generate_voice_clone
        request.update(
            ref_audio=args.ref_audio,
            ref_text=args.ref_text,
            xvec_only=False,  # serve exposes no --xvec-only flag
            non_streaming_mode=args.non_streaming_mode,
        )
    elif args.mode == "custom":
        generate = model.generate_custom_voice_streaming if streaming else model.generate_custom_voice
        request.update(speaker=args.speaker, instruct=args.instruct)
    else:
        generate = model.generate_voice_design_streaming if streaming else model.generate_voice_design
        request.update(instruct=args.instruct)
    if streaming:
        # Only the streaming generators take a chunk size.
        request["chunk_size"] = args.chunk_size

    print("Server started. Enter text per line. Type 'exit' or 'quit' to stop.")
    idx = 1
    for line in sys.stdin:
        text = line.strip()
        if not text:
            continue
        if text.lower() in ("exit", "quit", "stop"):
            break

        out_path = os.path.join(args.output_dir, f"out_{idx:04d}.wav")
        idx += 1

        start = time.perf_counter()
        if streaming:
            audio, sr = _stream_to_audio(generate(text=text, **request))
        else:
            audio_list, sr = generate(text=text, **request)
            audio = audio_list[0]

        _write_audio(out_path, audio, sr)
        total_time = time.perf_counter() - start
        audio_dur = len(audio) / sr if sr else 0.0
        rtf = audio_dur / total_time if total_time > 0 else 0.0
        print(f"Wrote {out_path} (dur {audio_dur:.2f}s, RTF {rtf:.2f})")
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def build_parser():
    """Build the ``faster-qwen3-tts`` argument parser.

    The top-level parser takes ``--device`` and ``--dtype`` and requires one of
    four subcommands: ``clone``, ``custom``, ``design`` and ``serve``. Each
    subcommand stores its handler function under ``fn``.

    ``non_streaming_mode`` defaults to True for ``custom`` and ``design`` and
    to False for ``clone`` and ``serve``.
    """
    parser = argparse.ArgumentParser(prog="faster-qwen3-tts", description="FasterQwen3TTS CLI")
    parser.add_argument("--device", default="cuda", help="Device (cuda or cpu)")
    parser.add_argument("--dtype", default="bf16", choices=["bf16", "fp16", "fp32"], help="Model dtype")
    subcommands = parser.add_subparsers(dest="command", required=True)

    def add_sampling_flags(target):
        # Decoding controls shared by every subcommand.
        target.add_argument("--max-new-tokens", type=int, default=2048)
        target.add_argument("--temperature", type=float, default=0.9)
        target.add_argument("--top-k", type=int, default=50)
        target.add_argument("--repetition-penalty", type=float, default=1.05)
        target.add_argument("--greedy", action="store_true", help="Disable sampling")

    def add_streaming_flags(target, nsm_default):
        # --streaming, the paired on/off switches for non_streaming_mode, and --chunk-size.
        target.add_argument("--streaming", action="store_true", help="Use streaming generation")
        toggle = target.add_mutually_exclusive_group()
        toggle.add_argument(
            "--non-streaming-mode",
            dest="non_streaming_mode",
            action="store_true",
            help="Prefill full text before decode",
        )
        toggle.add_argument(
            "--no-non-streaming-mode",
            dest="non_streaming_mode",
            action="store_false",
            help="Use upstream step-by-step text feeding during decode",
        )
        target.set_defaults(non_streaming_mode=nsm_default)
        target.add_argument("--chunk-size", type=int, default=8, help="Streaming chunk size")

    def add_one_shot_flags(target):
        # Flags common to the single-request subcommands (clone/custom/design).
        target.add_argument("--text", required=True, help="Text to synthesize")
        target.add_argument("--language", default="Auto", help="Language (Auto, English, French, ...)")
        target.add_argument("--output", required=True, help="Output wav path")
        target.add_argument("--model", required=True, help="Model id or local path")
        add_sampling_flags(target)
        add_streaming_flags(target, True)

    clone = subcommands.add_parser("clone", help="Voice cloning (reference audio)")
    add_one_shot_flags(clone)
    clone.add_argument("--ref-audio", required=True, help="Reference audio path")
    clone.add_argument("--ref-text", required=True, help="Reference transcript")
    clone.add_argument(
        "--xvec-only",
        action="store_true",
        help="Use speaker embedding only instead of upstream-default ICL mode",
    )
    # Clone overrides the shared default for non_streaming_mode.
    clone.set_defaults(non_streaming_mode=False)
    clone.set_defaults(fn=cmd_clone)

    custom = subcommands.add_parser("custom", help="CustomVoice model (speaker IDs)")
    add_one_shot_flags(custom)
    custom.add_argument("--speaker", help="Speaker ID")
    custom.add_argument("--instruct", default="", help="Optional instruction")
    custom.add_argument("--list-speakers", action="store_true", help="List available speaker IDs")
    custom.set_defaults(fn=cmd_custom)

    design = subcommands.add_parser("design", help="VoiceDesign model (instruction-based)")
    add_one_shot_flags(design)
    design.add_argument("--instruct", required=True, help="Voice/style instruction")
    design.set_defaults(fn=cmd_design)

    serve = subcommands.add_parser("serve", help="Keep model hot and generate multiple requests from stdin")
    serve.add_argument("--mode", required=True, choices=["clone", "custom", "design"])
    serve.add_argument("--model", required=True, help="Model id or local path")
    serve.add_argument("--language", default="Auto", help="Language (Auto, English, French, ...)")
    serve.add_argument("--ref-audio", help="Reference audio path (clone)")
    serve.add_argument("--ref-text", help="Reference transcript (clone)")
    serve.add_argument("--speaker", help="Speaker ID (custom)")
    serve.add_argument("--instruct", default="", help="Instruction (custom/design)")
    add_streaming_flags(serve, False)
    add_sampling_flags(serve)
    serve.add_argument("--output-dir", default="outputs", help="Directory for output wavs")
    serve.set_defaults(fn=cmd_serve)

    return parser
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
def main():
    """Parse the command line and run the handler of the selected subcommand."""
    args = build_parser().parse_args()
    handler = args.fn
    handler(args)


if __name__ == "__main__":
    main()
|
faster_qwen3_tts/generate.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Non-streaming generation loop using CUDA graphs for both predictor and talker.
|
| 4 |
+
"""
|
| 5 |
+
import time
|
| 6 |
+
from typing import Optional, Tuple
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from .predictor_graph import PredictorGraph
|
| 11 |
+
from .sampling import apply_repetition_penalty, sample_logits
|
| 12 |
+
from .talker_graph import TalkerGraph
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@torch.inference_mode()
def fast_generate(
    talker,
    talker_input_embeds: torch.Tensor,
    attention_mask: torch.Tensor,
    trailing_text_hiddens: torch.Tensor,
    tts_pad_embed: torch.Tensor,
    config,
    predictor_graph: PredictorGraph,
    talker_graph: TalkerGraph,
    max_new_tokens: int = 2048,
    min_new_tokens: int = 2,
    temperature: float = 0.9,
    top_k: int = 50,
    top_p: float = 1.0,
    do_sample: bool = True,
    repetition_penalty: float = 1.05,
    subtalker_dosample: Optional[bool] = None,
    subtalker_top_k: Optional[int] = None,
    subtalker_top_p: Optional[float] = None,
    subtalker_temperature: Optional[float] = None,
    parity_mode: bool = False,
) -> Tuple[Optional[torch.Tensor], dict]:
    """
    Fast autoregressive generation with CUDA-graphed predictor and talker.

    The prompt is prefilled with one regular ``talker.forward`` call; each
    decode step then runs ``predictor_graph`` for the remaining codebooks and
    ``talker_graph`` for the next hidden state. First-codebook tokens in the
    last 1024 vocabulary ids are never sampled, except the EOS id, which is
    additionally blocked until ``min_new_tokens`` steps have been emitted.

    With ``parity_mode=True`` the graphs are bypassed and ``talker.generate``
    is called instead; the ``subtalker_*`` overrides apply only on that path
    and default to the corresponding talker sampling settings when ``None``.

    Args:
        talker: Model providing ``forward``, ``generate``, ``code_predictor``,
            ``codec_head`` and ``get_input_embeddings``.
        talker_input_embeds: Prompt embeddings for prefill.
        attention_mask: Mask matching ``talker_input_embeds``.
        trailing_text_hiddens: Text hidden states added to the decode input,
            indexed along dim 1 by the generation step.
        tts_pad_embed: Embedding added once the trailing text is exhausted.
        config: Supplies ``codec_eos_token_id``, ``num_code_groups`` and
            ``vocab_size``.
        predictor_graph: Graphed code predictor; ``run`` returns the
            remaining codebook ids for one step.
        talker_graph: Graphed talker decode step holding a static KV cache.
        max_new_tokens: Upper bound on decode steps.
        min_new_tokens: Number of steps during which EOS is suppressed.
        temperature, top_k, top_p, do_sample: Sampling settings for the
            first codebook.
        repetition_penalty: Penalty applied over previously emitted
            first-codebook tokens; ``1.0`` disables it.

    Returns:
        ``(codes, timing)``. ``codes`` stacks one row of codebook ids per
        generated step (EOS excluded), or is ``None`` if nothing was
        generated. ``timing`` has the keys ``prefill_ms``, ``decode_s``,
        ``steps``, ``ms_per_step`` and ``steps_per_s``; in parity mode
        ``prefill_ms`` is 0.0 and ``decode_s`` covers the whole call.
    """
    eos_id = config.codec_eos_token_id
    num_code_groups = config.num_code_groups
    vocab_size = config.vocab_size
    device = talker_input_embeds.device
    suppress_start = max(0, vocab_size - 1024)

    if parity_mode:
        suppress_tokens = [i for i in range(suppress_start, vocab_size) if i != eos_id]
        t_start = time.perf_counter()
        talker_result = talker.generate(
            inputs_embeds=talker_input_embeds,
            attention_mask=attention_mask,
            trailing_text_hidden=trailing_text_hiddens,
            tts_pad_embed=tts_pad_embed,
            max_new_tokens=max_new_tokens,
            min_new_tokens=min_new_tokens,
            do_sample=do_sample,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            eos_token_id=eos_id,
            suppress_tokens=suppress_tokens,
            subtalker_dosample=subtalker_dosample if subtalker_dosample is not None else do_sample,
            subtalker_top_k=subtalker_top_k if subtalker_top_k is not None else top_k,
            subtalker_top_p=subtalker_top_p if subtalker_top_p is not None else top_p,
            subtalker_temperature=subtalker_temperature if subtalker_temperature is not None else temperature,
            output_hidden_states=True,
            return_dict_in_generate=True,
        )
        # The last entry of each step's hidden-state tuple is stacked as that
        # step's codes; steps where it is None are skipped.
        talker_codes = torch.stack(
            [hid[-1] for hid in talker_result.hidden_states if hid[-1] is not None],
            dim=1,
        )
        # Truncate every sequence at its first EOS in the first codebook.
        first_codebook = talker_codes[:, :, 0]
        is_stop_token = first_codebook == eos_id
        stop_indices = torch.argmax(is_stop_token.int(), dim=1)
        has_stop_token = is_stop_token.any(dim=1)
        effective_lengths = torch.where(has_stop_token, stop_indices, talker_codes.shape[1])
        talker_codes_list = [talker_codes[i, :length, :] for i, length in enumerate(effective_lengths)]

        torch.cuda.synchronize()
        total_time = time.perf_counter() - t_start
        steps = int(talker_codes_list[0].shape[0]) if talker_codes_list else 0
        timing = {
            'prefill_ms': 0.0,
            'decode_s': total_time,
            'steps': steps,
            'ms_per_step': (total_time / steps * 1000) if steps > 0 else 0.0,
            'steps_per_s': (steps / total_time) if total_time > 0 else 0.0,
        }
        return talker_codes_list[0] if talker_codes_list else None, timing

    # Mask of token ids that may never be sampled: the last 1024 ids except
    # EOS. Filled with one slice assignment rather than one device write per
    # id, and only built on the path that actually uses it.
    suppress_mask = torch.zeros(vocab_size, dtype=torch.bool, device=device)
    suppress_mask[suppress_start:] = True
    if eos_id is not None and suppress_start <= eos_id < vocab_size:
        suppress_mask[eos_id] = False

    predictor = talker.code_predictor
    talker_codec_embed = talker.get_input_embeddings()
    talker_codec_head = talker.codec_head
    predictor_codec_embeds = predictor.get_input_embeddings()

    # === PREFILL (still uses HF forward for variable-length prefill) ===
    t_start = time.perf_counter()

    out = talker.forward(
        inputs_embeds=talker_input_embeds,
        attention_mask=attention_mask,
        use_cache=True,
        output_hidden_states=True,
        return_dict=True,
        trailing_text_hidden=trailing_text_hiddens,
        tts_pad_embed=tts_pad_embed,
        generation_step=None,
        past_hidden=None,
        past_key_values=None,
    )

    talker_past_kv = out.past_key_values
    past_hidden = out.past_hidden
    gen_step = out.generation_step

    logits = out.logits[:, -1, :]
    suppress_eos = min_new_tokens > 0
    token = sample_logits(
        logits,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=do_sample,
        suppress_mask=suppress_mask,
        suppress_tokens=[eos_id] if suppress_eos else None,
    )

    # Copy prefill KV cache into talker graph's static cache
    prefill_len = talker_graph.prefill_kv(talker_past_kv)
    # Sync padding mask + rope deltas for decode parity
    rope_deltas = getattr(talker, "rope_deltas", None)
    talker_graph.set_generation_state(attention_mask, rope_deltas)

    torch.cuda.synchronize()
    t_prefill = time.perf_counter() - t_start

    # === DECODE LOOP ===
    t_decode_start = time.perf_counter()
    all_codec_ids = []

    for step_idx in range(max_new_tokens):
        if token.item() == eos_id:
            break

        # --- CUDA-Graphed Code Predictor ---
        last_id_hidden = talker_codec_embed(token.unsqueeze(1))  # [1, 1, H]
        pred_input = torch.cat((past_hidden, last_id_hidden), dim=1)  # [1, 2, H]
        codebook_token_ids = predictor_graph.run(pred_input)  # [15] long tensor

        # Build full codec: [first_cb, cb1, ..., cb15]
        all_cb = torch.cat([token.view(1), codebook_token_ids])  # [16]
        all_codec_ids.append(all_cb.detach())

        # --- Build input embedding for talker ---
        # Sum of the embeddings of every codebook token of this step.
        codec_hiddens = [last_id_hidden]
        for i in range(num_code_groups - 1):
            codec_hiddens.append(predictor_codec_embeds[i](codebook_token_ids[i].unsqueeze(0).unsqueeze(0)))
        inputs_embeds = torch.cat(codec_hiddens, dim=1).sum(1, keepdim=True)

        # Add the next trailing text hidden while any remain, then padding.
        if gen_step < trailing_text_hiddens.shape[1]:
            inputs_embeds = inputs_embeds + trailing_text_hiddens[:, gen_step].unsqueeze(1)
        else:
            inputs_embeds = inputs_embeds + tts_pad_embed

        # --- CUDA-Graphed Talker decode step ---
        current_pos = prefill_len + step_idx
        if current_pos >= talker_graph.max_seq_len - 1:
            # Stop if we exceed max_seq_len
            break

        hidden_states = talker_graph.run(inputs_embeds, position=current_pos)
        # hidden_states is the static output buffer - use it immediately

        logits = talker_codec_head(hidden_states[:, -1, :]).unsqueeze(0)

        if repetition_penalty != 1.0 and len(all_codec_ids) > 0:
            # Penalize only against the first-codebook token of each step so far.
            history = torch.stack([c[0] for c in all_codec_ids])
            logits = apply_repetition_penalty(logits, history, repetition_penalty)

        suppress_eos = len(all_codec_ids) < min_new_tokens
        token = sample_logits(
            logits.squeeze(0),
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=do_sample,
            suppress_mask=suppress_mask,
            suppress_tokens=[eos_id] if suppress_eos else None,
        )
        past_hidden = hidden_states[:, -1:, :].clone()  # clone since it's the static buffer
        gen_step += 1

    torch.cuda.synchronize()
    t_decode = time.perf_counter() - t_decode_start

    n_steps = len(all_codec_ids)
    timing = {
        'prefill_ms': t_prefill * 1000,
        'decode_s': t_decode,
        'steps': n_steps,
        'ms_per_step': (t_decode / n_steps * 1000) if n_steps > 0 else 0,
        'steps_per_s': (n_steps / t_decode) if t_decode > 0 else 0,
    }

    if all_codec_ids:
        return torch.stack(all_codec_ids), timing
    return None, timing
|
faster_qwen3_tts/model.py
ADDED
|
@@ -0,0 +1,1370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FasterQwen3TTS: Real-time TTS using CUDA graph capture.
|
| 3 |
+
|
| 4 |
+
Wrapper class that provides a Qwen3-TTS API while using
|
| 5 |
+
CUDA graphs for 6-10x speedup.
|
| 6 |
+
"""
|
| 7 |
+
import logging
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import soundfile as sf
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
from .utils import suppress_flash_attn_warning
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class FasterQwen3TTS:
|
| 23 |
+
"""
|
| 24 |
+
Qwen3-TTS model with CUDA graphs for real-time inference.
|
| 25 |
+
|
| 26 |
+
Compatible API with Qwen3TTSModel, but uses CUDA graph
|
| 27 |
+
capture for 6-10x speedup on NVIDIA GPUs.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
base_model,
|
| 33 |
+
predictor_graph,
|
| 34 |
+
talker_graph,
|
| 35 |
+
device: str = "cuda",
|
| 36 |
+
dtype: torch.dtype = torch.bfloat16,
|
| 37 |
+
max_seq_len: int = 2048,
|
| 38 |
+
):
|
| 39 |
+
self.model = base_model # The qwen-tts Qwen3TTSModel instance
|
| 40 |
+
self.predictor_graph = predictor_graph
|
| 41 |
+
self.talker_graph = talker_graph
|
| 42 |
+
self.device = device
|
| 43 |
+
self.dtype = dtype
|
| 44 |
+
self.max_seq_len = max_seq_len
|
| 45 |
+
self.sample_rate = self._infer_sample_rate(base_model)
|
| 46 |
+
self._warmed_up = False
|
| 47 |
+
self._voice_prompt_cache = {} # Cache (ref_audio, ref_text) -> (vcp, ref_ids)
|
| 48 |
+
|
| 49 |
+
@staticmethod
|
| 50 |
+
def _get_speech_tokenizer(base_model):
|
| 51 |
+
"""Return the nested qwen-tts speech tokenizer when available."""
|
| 52 |
+
return getattr(getattr(base_model, "model", None), "speech_tokenizer", None)
|
| 53 |
+
|
| 54 |
+
@property
|
| 55 |
+
def speech_tokenizer(self):
|
| 56 |
+
"""Expose the codec decoder on the wrapper's public surface."""
|
| 57 |
+
speech_tokenizer = self._get_speech_tokenizer(self.model)
|
| 58 |
+
if speech_tokenizer is None:
|
| 59 |
+
raise AttributeError("Underlying model does not expose a speech_tokenizer")
|
| 60 |
+
return speech_tokenizer
|
| 61 |
+
|
| 62 |
+
@staticmethod
|
| 63 |
+
def _infer_sample_rate(base_model) -> int:
|
| 64 |
+
"""Infer output audio sample rate from qwen-tts internals."""
|
| 65 |
+
# Qwen3-TTS model IDs include "12Hz", but that is codec frame-rate (tokens/s),
|
| 66 |
+
# not waveform sampling rate. Generated audio is 24kHz.
|
| 67 |
+
sample_rate = None
|
| 68 |
+
|
| 69 |
+
speech_tokenizer = FasterQwen3TTS._get_speech_tokenizer(base_model)
|
| 70 |
+
if speech_tokenizer is not None:
|
| 71 |
+
sample_rate = getattr(speech_tokenizer, "sample_rate", None)
|
| 72 |
+
|
| 73 |
+
if sample_rate is None:
|
| 74 |
+
sample_rate = getattr(base_model, "sample_rate", None)
|
| 75 |
+
|
| 76 |
+
if sample_rate is None:
|
| 77 |
+
logger.warning(
|
| 78 |
+
"Could not infer sample rate from base model; defaulting to 24000 Hz."
|
| 79 |
+
)
|
| 80 |
+
return 24000
|
| 81 |
+
|
| 82 |
+
return int(sample_rate)
|
| 83 |
+
|
| 84 |
+
@classmethod
|
| 85 |
+
def from_pretrained(
|
| 86 |
+
cls,
|
| 87 |
+
model_name: str,
|
| 88 |
+
device: str = "cuda",
|
| 89 |
+
dtype: Union[str, torch.dtype] = torch.bfloat16,
|
| 90 |
+
attn_implementation: str = "sdpa",
|
| 91 |
+
max_seq_len: int = 2048,
|
| 92 |
+
):
|
| 93 |
+
"""
|
| 94 |
+
Load Qwen3-TTS model and prepare CUDA graphs.
|
| 95 |
+
|
| 96 |
+
Args:
|
| 97 |
+
model_name: Model path or HuggingFace Hub ID
|
| 98 |
+
device: Device to use ("cuda" or "cpu")
|
| 99 |
+
dtype: Data type for inference
|
| 100 |
+
attn_implementation: Attention implementation ("sdpa" or "flash_attention_2")
|
| 101 |
+
max_seq_len: Maximum sequence length for static cache
|
| 102 |
+
|
| 103 |
+
Returns:
|
| 104 |
+
FasterQwen3TTS instance
|
| 105 |
+
"""
|
| 106 |
+
if isinstance(dtype, str):
|
| 107 |
+
dtype = getattr(torch, dtype)
|
| 108 |
+
|
| 109 |
+
if not device.startswith("cuda") or not torch.cuda.is_available():
|
| 110 |
+
raise ValueError("CUDA graphs require CUDA device")
|
| 111 |
+
|
| 112 |
+
logger.info(f"Loading Qwen3-TTS model: {model_name}")
|
| 113 |
+
|
| 114 |
+
# Import here to avoid dependency issues (and suppress flash-attn warning)
|
| 115 |
+
with suppress_flash_attn_warning():
|
| 116 |
+
from qwen_tts import Qwen3TTSModel
|
| 117 |
+
from .predictor_graph import PredictorGraph
|
| 118 |
+
from .talker_graph import TalkerGraph
|
| 119 |
+
# Load base model using qwen-tts library
|
| 120 |
+
base_model = Qwen3TTSModel.from_pretrained(
|
| 121 |
+
model_name,
|
| 122 |
+
device_map=device,
|
| 123 |
+
torch_dtype=dtype,
|
| 124 |
+
attn_implementation=attn_implementation,
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
talker = base_model.model.talker
|
| 128 |
+
talker_config = base_model.model.config.talker_config
|
| 129 |
+
|
| 130 |
+
# Extract predictor config from loaded model
|
| 131 |
+
predictor = talker.code_predictor
|
| 132 |
+
pred_config = predictor.model.config
|
| 133 |
+
talker_hidden = talker_config.hidden_size
|
| 134 |
+
|
| 135 |
+
# Build CUDA graphs
|
| 136 |
+
logger.info("Building CUDA graphs...")
|
| 137 |
+
predictor_graph = PredictorGraph(
|
| 138 |
+
predictor,
|
| 139 |
+
pred_config,
|
| 140 |
+
talker_hidden,
|
| 141 |
+
device=device,
|
| 142 |
+
dtype=dtype,
|
| 143 |
+
do_sample=True, # subtalker_dosample (Default: True)
|
| 144 |
+
top_k=50, # subtalker_top_k (Default: 50)
|
| 145 |
+
top_p=1.0, # subtalker_top_p (Default: 1.0)
|
| 146 |
+
temperature=0.2, # subtalker_temperature (Default: 0.9)
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
talker_graph = TalkerGraph(
|
| 150 |
+
talker.model,
|
| 151 |
+
talker_config,
|
| 152 |
+
device=device,
|
| 153 |
+
dtype=dtype,
|
| 154 |
+
max_seq_len=max_seq_len,
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
logger.info("CUDA graphs initialized (will capture on first run)")
|
| 158 |
+
|
| 159 |
+
return cls(
|
| 160 |
+
base_model=base_model,
|
| 161 |
+
predictor_graph=predictor_graph,
|
| 162 |
+
talker_graph=talker_graph,
|
| 163 |
+
device=device,
|
| 164 |
+
dtype=dtype,
|
| 165 |
+
max_seq_len=max_seq_len,
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
def _warmup(self, prefill_len: int):
|
| 169 |
+
"""Warm up and capture CUDA graphs with given prefill length."""
|
| 170 |
+
if self._warmed_up:
|
| 171 |
+
return
|
| 172 |
+
|
| 173 |
+
logger.info("Warming up CUDA graphs...")
|
| 174 |
+
self.predictor_graph.capture(num_warmup=3)
|
| 175 |
+
self.talker_graph.capture(prefill_len=prefill_len, num_warmup=3)
|
| 176 |
+
self._warmed_up = True
|
| 177 |
+
logger.info("CUDA graphs captured and ready")
|
| 178 |
+
|
| 179 |
+
def generate(
|
| 180 |
+
self,
|
| 181 |
+
text: str,
|
| 182 |
+
language: str = "English",
|
| 183 |
+
max_new_tokens: int = 2048,
|
| 184 |
+
temperature: float = 0.9,
|
| 185 |
+
top_k: int = 50,
|
| 186 |
+
do_sample: bool = True,
|
| 187 |
+
repetition_penalty: float = 1.05,
|
| 188 |
+
) -> Tuple[list, int]:
|
| 189 |
+
"""
|
| 190 |
+
Generate speech from text using default voice.
|
| 191 |
+
|
| 192 |
+
Not yet implemented - use generate_voice_clone() instead.
|
| 193 |
+
"""
|
| 194 |
+
raise NotImplementedError(
|
| 195 |
+
"Default voice generation not yet implemented. "
|
| 196 |
+
"Use generate_voice_clone() with reference audio."
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
def _load_ref_audio_with_silence(self, ref_audio: Union[str, Path, tuple], silence_secs: float = 0.5) -> Tuple[np.ndarray, int]:
|
| 200 |
+
"""Load reference audio and optionally append trailing silence.
|
| 201 |
+
|
| 202 |
+
The ICL voice-cloning prompt ends with the last codec token of the reference
|
| 203 |
+
audio, so the model's first generated token is conditioned on whatever phoneme
|
| 204 |
+
the reference ends with. Appending a short silence makes the last tokens
|
| 205 |
+
encode silence instead, preventing that phoneme from bleeding into the start
|
| 206 |
+
of the generated speech. Set silence_secs=0 to disable this behavior.
|
| 207 |
+
"""
|
| 208 |
+
if isinstance(ref_audio, tuple):
|
| 209 |
+
audio, sr = ref_audio
|
| 210 |
+
else:
|
| 211 |
+
audio, sr = sf.read(str(ref_audio), dtype="float32", always_2d=False)
|
| 212 |
+
|
| 213 |
+
if audio.ndim > 1:
|
| 214 |
+
audio = audio.mean(axis=1) # convert to mono
|
| 215 |
+
if silence_secs > 0:
|
| 216 |
+
silence = np.zeros(int(silence_secs * sr), dtype=np.float32)
|
| 217 |
+
audio = np.concatenate([audio, silence])
|
| 218 |
+
return audio, sr
|
| 219 |
+
|
| 220 |
+
def _resolve_voice_clone_prompt(
|
| 221 |
+
self,
|
| 222 |
+
input_ids,
|
| 223 |
+
ref_audio: Optional[Union[str, Path, tuple]],
|
| 224 |
+
ref_text: str,
|
| 225 |
+
xvec_only: bool,
|
| 226 |
+
append_silence: bool,
|
| 227 |
+
voice_clone_prompt: Optional[Union[Dict[str, Any], List[Any]]],
|
| 228 |
+
) -> Tuple[Dict[str, Any], list, bool]:
|
| 229 |
+
"""Resolve voice clone prompt data and return (prompt, ref_ids, using_icl_mode)."""
|
| 230 |
+
if voice_clone_prompt is not None:
|
| 231 |
+
return self._resolve_precomputed_voice_clone_prompt(
|
| 232 |
+
input_ids=input_ids,
|
| 233 |
+
ref_text=ref_text,
|
| 234 |
+
voice_clone_prompt=voice_clone_prompt,
|
| 235 |
+
)
|
| 236 |
+
if ref_audio is None:
|
| 237 |
+
raise ValueError("ref_audio is required when voice_clone_prompt is not provided")
|
| 238 |
+
|
| 239 |
+
return self._resolve_voice_clone_prompt_from_reference(
|
| 240 |
+
input_ids=input_ids,
|
| 241 |
+
ref_audio=ref_audio,
|
| 242 |
+
ref_text=ref_text,
|
| 243 |
+
xvec_only=xvec_only,
|
| 244 |
+
append_silence=append_silence,
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
def _resolve_precomputed_voice_clone_prompt(
|
| 248 |
+
self,
|
| 249 |
+
input_ids,
|
| 250 |
+
ref_text: str,
|
| 251 |
+
voice_clone_prompt: Union[Dict[str, Any], List[Any]],
|
| 252 |
+
) -> Tuple[Dict[str, Any], list, bool]:
|
| 253 |
+
if isinstance(voice_clone_prompt, list):
|
| 254 |
+
if len(voice_clone_prompt) != len(input_ids):
|
| 255 |
+
raise ValueError(
|
| 256 |
+
f"voice_clone_prompt must have length {len(input_ids)}, got {len(voice_clone_prompt)}"
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
vcp = self.model._prompt_items_to_voice_clone_prompt(voice_clone_prompt)
|
| 260 |
+
ref_ids = []
|
| 261 |
+
for item in voice_clone_prompt:
|
| 262 |
+
if bool(item.icl_mode):
|
| 263 |
+
item_ref_text = item.ref_text if item.ref_text else ref_text
|
| 264 |
+
if not item_ref_text:
|
| 265 |
+
raise ValueError(
|
| 266 |
+
"ref_text is required when voice_clone_prompt uses ICL mode."
|
| 267 |
+
)
|
| 268 |
+
ref_id = self.model._tokenize_texts(
|
| 269 |
+
[self.model._build_ref_text(item_ref_text)]
|
| 270 |
+
)[0]
|
| 271 |
+
ref_ids.append(ref_id)
|
| 272 |
+
else:
|
| 273 |
+
ref_ids.append(None)
|
| 274 |
+
|
| 275 |
+
return vcp, ref_ids, any(vcp["icl_mode"])
|
| 276 |
+
|
| 277 |
+
required_keys = ("ref_spk_embedding",)
|
| 278 |
+
missing = [k for k in required_keys if k not in voice_clone_prompt]
|
| 279 |
+
if missing:
|
| 280 |
+
raise ValueError(
|
| 281 |
+
f"voice_clone_prompt missing required keys: {missing}. "
|
| 282 |
+
f"Expected keys: {list(required_keys)}"
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
list_keys = ("ref_spk_embedding", "x_vector_only_mode", "icl_mode", "ref_code")
|
| 286 |
+
for key in list_keys:
|
| 287 |
+
if key not in voice_clone_prompt:
|
| 288 |
+
continue
|
| 289 |
+
value = voice_clone_prompt[key]
|
| 290 |
+
if not isinstance(value, list) or len(value) != len(input_ids):
|
| 291 |
+
raise ValueError(
|
| 292 |
+
f"voice_clone_prompt[{key!r}] must be a list with length {len(input_ids)}"
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
xvec_modes = voice_clone_prompt.get("x_vector_only_mode", [True] * len(input_ids))
|
| 296 |
+
if "icl_mode" in voice_clone_prompt:
|
| 297 |
+
icl_modes = [bool(v) for v in voice_clone_prompt["icl_mode"]]
|
| 298 |
+
for i, (xvec_mode, icl_mode) in enumerate(zip(xvec_modes, icl_modes)):
|
| 299 |
+
if bool(xvec_mode) == bool(icl_mode):
|
| 300 |
+
raise ValueError(
|
| 301 |
+
f"voice_clone_prompt has inconsistent mode flags at index {i}: "
|
| 302 |
+
"x_vector_only_mode and icl_mode must be opposites"
|
| 303 |
+
)
|
| 304 |
+
else:
|
| 305 |
+
icl_modes = [not bool(v) for v in xvec_modes]
|
| 306 |
+
|
| 307 |
+
ref_codes = voice_clone_prompt.get("ref_code", [None] * len(input_ids))
|
| 308 |
+
for i, (xvec_mode, icl_mode, ref_code) in enumerate(zip(xvec_modes, icl_modes, ref_codes)):
|
| 309 |
+
if bool(xvec_mode) and ref_code is not None:
|
| 310 |
+
raise ValueError(
|
| 311 |
+
f"voice_clone_prompt index {i}: ref_code must be None in x_vector_only mode"
|
| 312 |
+
)
|
| 313 |
+
if bool(icl_mode) and ref_code is None:
|
| 314 |
+
raise ValueError(
|
| 315 |
+
f"voice_clone_prompt index {i}: ref_code is required in ICL mode"
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
vcp = dict(
|
| 319 |
+
ref_code=ref_codes,
|
| 320 |
+
ref_spk_embedding=voice_clone_prompt["ref_spk_embedding"],
|
| 321 |
+
x_vector_only_mode=[bool(v) for v in xvec_modes],
|
| 322 |
+
icl_mode=[bool(v) for v in icl_modes],
|
| 323 |
+
)
|
| 324 |
+
using_icl_mode = any(vcp["icl_mode"])
|
| 325 |
+
|
| 326 |
+
if using_icl_mode:
|
| 327 |
+
if not ref_text:
|
| 328 |
+
raise ValueError(
|
| 329 |
+
"ref_text is required when voice_clone_prompt uses ICL mode."
|
| 330 |
+
)
|
| 331 |
+
ref_texts = [self.model._build_ref_text(ref_text)]
|
| 332 |
+
# NOTE: single ref_text is shared across all ICL items in the batch.
|
| 333 |
+
ref_id = self.model._tokenize_texts(ref_texts)[0]
|
| 334 |
+
ref_ids = [ref_id if is_icl else None for is_icl in vcp["icl_mode"]]
|
| 335 |
+
else:
|
| 336 |
+
ref_ids = [None] * len(input_ids)
|
| 337 |
+
|
| 338 |
+
return vcp, ref_ids, using_icl_mode
|
| 339 |
+
|
| 340 |
+
def _resolve_voice_clone_prompt_from_reference(
|
| 341 |
+
self,
|
| 342 |
+
input_ids,
|
| 343 |
+
ref_audio: Union[str, Path, tuple],
|
| 344 |
+
ref_text: str,
|
| 345 |
+
xvec_only: bool,
|
| 346 |
+
append_silence: bool,
|
| 347 |
+
) -> Tuple[Dict[str, Any], list, bool]:
|
| 348 |
+
using_icl_mode = not xvec_only
|
| 349 |
+
cache_key = (str(ref_audio), ref_text, xvec_only, append_silence)
|
| 350 |
+
if cache_key in self._voice_prompt_cache:
|
| 351 |
+
vcp, ref_ids = self._voice_prompt_cache[cache_key]
|
| 352 |
+
return vcp, ref_ids, using_icl_mode
|
| 353 |
+
|
| 354 |
+
if xvec_only:
|
| 355 |
+
prompt_items = self.model.create_voice_clone_prompt(
|
| 356 |
+
ref_audio=str(ref_audio),
|
| 357 |
+
ref_text="",
|
| 358 |
+
x_vector_only_mode=True,
|
| 359 |
+
)
|
| 360 |
+
spk_emb = prompt_items[0].ref_spk_embedding
|
| 361 |
+
vcp = dict(
|
| 362 |
+
ref_code=[None],
|
| 363 |
+
ref_spk_embedding=[spk_emb],
|
| 364 |
+
x_vector_only_mode=[True],
|
| 365 |
+
icl_mode=[False],
|
| 366 |
+
)
|
| 367 |
+
ref_ids = [None] * len(input_ids)
|
| 368 |
+
self._voice_prompt_cache[cache_key] = (vcp, ref_ids)
|
| 369 |
+
return vcp, ref_ids, using_icl_mode
|
| 370 |
+
|
| 371 |
+
silence_secs = 0.5 if append_silence else 0.0
|
| 372 |
+
ref_audio_input = self._load_ref_audio_with_silence(ref_audio, silence_secs=silence_secs)
|
| 373 |
+
prompt_items = self.model.create_voice_clone_prompt(
|
| 374 |
+
ref_audio=ref_audio_input,
|
| 375 |
+
ref_text=ref_text
|
| 376 |
+
)
|
| 377 |
+
vcp = self.model._prompt_items_to_voice_clone_prompt(prompt_items)
|
| 378 |
+
|
| 379 |
+
ref_ids = []
|
| 380 |
+
rt = prompt_items[0].ref_text
|
| 381 |
+
if rt:
|
| 382 |
+
ref_texts = [self.model._build_ref_text(rt)]
|
| 383 |
+
ref_ids.append(self.model._tokenize_texts(ref_texts)[0])
|
| 384 |
+
else:
|
| 385 |
+
ref_ids.append(None)
|
| 386 |
+
|
| 387 |
+
self._voice_prompt_cache[cache_key] = (vcp, ref_ids)
|
| 388 |
+
return vcp, ref_ids, using_icl_mode
|
| 389 |
+
|
| 390 |
+
def _prepare_generation(
    self,
    text: str,
    ref_audio: Optional[Union[str, Path, tuple]] = None,
    ref_text: str = "",
    language: str = "English",
    xvec_only: bool = False,
    non_streaming_mode: bool = False,
    append_silence: bool = True,
    voice_clone_prompt: Optional[Union[Dict[str, Any], List[Any]]] = None,
    instruct: Optional[str] = None,
):
    """Build the talker inputs for one voice-clone request.

    Used by both the streaming and the non-streaming generation paths.
    Triggers graph warm-up on the first call.

    Args:
        text: Text to synthesize.
        ref_audio: Reference audio (path or ``(audio, sample_rate)`` tuple);
            required unless ``voice_clone_prompt`` is given.
        ref_text: Transcript of the reference audio. Ignored for
            x-vector-only prompts, required for ICL prompts.
        language: Target language name; None is treated as "Auto".
        xvec_only: When True, clone the voice from the speaker embedding
            (x-vector) alone rather than the full ICL acoustic prompt. This
            keeps the model from continuing the reference audio's last
            phoneme and allows natural language switching. The default,
            False, matches upstream ICL behavior where the reference codec
            tokens are part of the context.
        non_streaming_mode: Forwarded to the talker input builder.
        append_silence: Whether trailing silence is appended to the
            reference audio in ICL mode.
        voice_clone_prompt: Optional precomputed prompt from
            `create_voice_clone_prompt`/`_prompt_items_to_voice_clone_prompt`.
            When given, `xvec_only` is ignored. Both x-vector-only prompts
            (`ref_spk_embedding` only) and ICL prompts (`ref_spk_embedding`
            + `ref_code` + mode flags) are supported.
        instruct: Optional style/language instruction (for example
            "请用纯正广东话朗读", i.e. "read aloud in authentic Cantonese").
            Its embedding is placed ahead of the TTS turn.

    Returns:
        Tuple ``(m, talker, config, tie, tam, tth, tpe, ref_codes)``: the
        inner model, its talker, the talker config, the talker input
        embeddings, attention mask, trailing text hiddens, TTS pad
        embedding, and the first item's reference codes (None outside ICL
        mode).
    """
    wrapper = self.model
    input_ids = wrapper._tokenize_texts([wrapper._build_assistant_text(text)])

    if instruct:
        instruct_ids = [wrapper._tokenize_texts([wrapper._build_instruct_text(instruct)])[0]]
    else:
        instruct_ids = [None]

    vcp, ref_ids, using_icl_mode = self._resolve_voice_clone_prompt(
        input_ids=input_ids,
        ref_audio=ref_audio,
        ref_text=ref_text,
        xvec_only=xvec_only,
        append_silence=append_silence,
        voice_clone_prompt=voice_clone_prompt,
    )

    if instruct and not using_icl_mode:
        logger.warning(
            "Base-model instruct with x-vector-only voice cloning is experimental. "
            "Upstream Qwen3-TTS itself does not follow instructions reliably in this "
            "mode. Prefer xvec_only=False (ICL mode) when using instruct for voice "
            "cloning."
        )

    m = wrapper.model
    languages = ["Auto"] if language is None else [language]

    tie, tam, tth, tpe = self._build_talker_inputs_local(
        m=m,
        input_ids=input_ids,
        ref_ids=ref_ids,
        voice_clone_prompt=vcp,
        languages=languages,
        speakers=None,
        non_streaming_mode=non_streaming_mode,
        instruct_ids=instruct_ids,
    )

    # Graphs are captured lazily, using this request's prefill length.
    if not self._warmed_up:
        self._warmup(tie.shape[1])

    talker = m.talker
    config = m.config.talker_config
    talker.rope_deltas = None

    # In ICL mode, hand back the reference codes so the decoder can use them
    # as acoustic context.
    ref_codes = None
    if using_icl_mode:
        codes = vcp.get("ref_code")
        if codes and codes[0] is not None:
            ref_codes = codes[0]

    return m, talker, config, tie, tam, tth, tpe, ref_codes
|
| 469 |
+
|
| 470 |
+
def _prepare_generation_custom(
    self,
    text: str,
    language: str,
    speaker: Optional[str],
    instruct: Optional[str] = None,
    non_streaming_mode: bool = True,
):
    """Build the talker inputs for a named-speaker (non-cloning) request.

    Args:
        text: Text to synthesize.
        language: Target language name; None is treated as "Auto".
        speaker: Speaker name looked up in the talker config, or None/"".
        instruct: Optional instruction text; None or "" means no instruction.
        non_streaming_mode: Forwarded to the talker input builder.

    Returns:
        Tuple ``(m, talker, config, tie, tam, tth, tpe)``: the inner model,
        its talker, the talker config, the talker input embeddings,
        attention mask, trailing text hiddens and TTS pad embedding.
    """
    wrapper = self.model
    input_ids = wrapper._tokenize_texts([wrapper._build_assistant_text(text)])

    has_instruct = not (instruct is None or instruct == "")
    if has_instruct:
        instruct_ids = [wrapper._tokenize_texts([wrapper._build_instruct_text(instruct)])[0]]
    else:
        instruct_ids = [None]

    m = wrapper.model
    languages = ["Auto"] if language is None else [language]

    tie, tam, tth, tpe = self._build_talker_inputs_local(
        m=m,
        input_ids=input_ids,
        ref_ids=[None],
        voice_clone_prompt=None,
        languages=languages,
        speakers=[speaker],
        non_streaming_mode=non_streaming_mode,
        instruct_ids=instruct_ids,
    )

    # Graphs are captured lazily, using this request's prefill length.
    if not self._warmed_up:
        self._warmup(tie.shape[1])

    talker = m.talker
    config = m.config.talker_config
    talker.rope_deltas = None

    return m, talker, config, tie, tam, tth, tpe
|
| 507 |
+
|
| 508 |
+
def _build_talker_inputs_local(
    self,
    m,
    input_ids,
    ref_ids,
    voice_clone_prompt,
    languages,
    speakers,
    non_streaming_mode: bool,
    instruct_ids=None,
):
    """Local copy of upstream talker input building for qwen-tts main repo.

    For each item, assembles the optional instruct embedding, the role and
    codec prefix, and either the ICL prompt or the text start. Sequences
    are then left-padded into a batch.

    Returns:
        ``(talker_input_embeds, talker_attention_mask, trailing_text_hiddens,
        tts_pad_embed)``. ``tts_pad_embed`` is the one computed for the last
        item.
    """
    talker = m.talker
    tcfg = m.config.talker_config
    num_items = len(input_ids)

    def embed_text(token_ids):
        # Text embedding followed by the talker's text projection.
        return talker.text_projection(talker.get_text_embeddings()(token_ids))

    def embed_codec(values, like):
        # Codec/input embedding of literal ids, on the talker device with
        # the same integer dtype as ``like``.
        return talker.get_input_embeddings()(
            torch.tensor(values, device=talker.device, dtype=like.dtype)
        )

    per_item_parts = [[] for _ in range(num_items)]

    voice_clone_spk_embeds = None
    if voice_clone_prompt is not None:
        voice_clone_spk_embeds = m.generate_speaker_prompt(voice_clone_prompt)

    if instruct_ids is not None:
        for index, instruct_id in enumerate(instruct_ids):
            if instruct_id is None:
                continue
            per_item_parts[index].append(embed_text(instruct_id))

    if speakers is None:
        speakers = [None] * num_items

    trailing_list = []
    tts_pad_embed = None

    for index, (input_id, language, speaker) in enumerate(zip(input_ids, languages, speakers)):
        # ---- speaker embedding ----
        if voice_clone_spk_embeds is not None:
            uses_clone = (
                voice_clone_prompt["x_vector_only_mode"][index]
                or voice_clone_prompt["icl_mode"][index]
            )
            speaker_embed = voice_clone_spk_embeds[index] if uses_clone else None
        elif speaker == "" or speaker is None:
            speaker_embed = None
        else:
            if speaker.lower() not in tcfg.spk_id:
                raise NotImplementedError(f"Speaker {speaker} not implemented")
            speaker_embed = embed_codec(tcfg.spk_id[speaker.lower()], input_id)

        # ---- language id ----
        assert language is not None
        lang_key = language.lower()
        if lang_key == "auto":
            language_id = None
        else:
            if lang_key not in tcfg.codec_language_id:
                raise NotImplementedError(f"Language {language} not implemented")
            language_id = tcfg.codec_language_id[lang_key]

        # A dialect speaker overrides the language id for Chinese/auto.
        if (
            lang_key in ["chinese", "auto"]
            and speaker not in ("", None)
            and tcfg.spk_is_dialect[speaker.lower()]
        ):
            dialect = tcfg.spk_is_dialect[speaker.lower()]
            language_id = tcfg.codec_language_id[dialect]

        # ---- special text embeddings ----
        tts_bos_embed, tts_eos_embed, tts_pad_embed = embed_text(
            torch.tensor(
                [[m.config.tts_bos_token_id, m.config.tts_eos_token_id, m.config.tts_pad_token_id]],
                device=talker.device,
                dtype=input_id.dtype,
            )
        ).chunk(3, dim=1)

        # ---- codec prefix ----
        if language_id is None:
            prefill_ids = [[
                tcfg.codec_nothink_id,
                tcfg.codec_think_bos_id,
                tcfg.codec_think_eos_id,
            ]]
        else:
            prefill_ids = [[
                tcfg.codec_think_id,
                tcfg.codec_think_bos_id,
                language_id,
                tcfg.codec_think_eos_id,
            ]]

        codec_head = embed_codec(prefill_ids, input_id)
        codec_tail = embed_codec([[tcfg.codec_pad_id, tcfg.codec_bos_id]], input_id)
        if speaker_embed is None:
            codec_embed = torch.cat([codec_head, codec_tail], dim=1)
        else:
            codec_embed = torch.cat(
                [codec_head, speaker_embed.view(1, 1, -1), codec_tail], dim=1
            )

        role_embed = embed_text(input_id[:, :3])
        # Pad/bos text embeddings summed with all but the last codec position.
        text_side = torch.cat(
            (
                tts_pad_embed.expand(-1, codec_embed.shape[1] - 2, -1),
                tts_bos_embed,
            ),
            dim=1,
        )
        prefix_embed = text_side + codec_embed[:, :-1]

        item_embed = torch.cat((role_embed, prefix_embed), dim=1)

        is_icl_item = (
            voice_clone_prompt is not None
            and voice_clone_prompt.get("ref_code", None) is not None
            and voice_clone_prompt["icl_mode"][index]
        )
        if is_icl_item:
            icl_input_embed, trailing = m.generate_icl_prompt(
                text_id=input_id[:, 3:-5],
                ref_id=ref_ids[index][:, 3:-2],
                ref_code=voice_clone_prompt["ref_code"][index].to(talker.device).clone(),  # escape inference_mode context
                tts_pad_embed=tts_pad_embed,
                tts_eos_embed=tts_eos_embed,
                non_streaming_mode=non_streaming_mode,
            )
            item_embed = torch.cat([item_embed, icl_input_embed], dim=1)
        else:
            first_text = embed_text(input_id[:, 3:4]) + codec_embed[:, -1:]
            item_embed = torch.cat([item_embed, first_text], dim=1)
            if non_streaming_mode:
                # Drop the position just added and append the whole text
                # (plus eos) over codec pad ids, then pad + codec bos.
                item_embed = item_embed[:, :-1]
                body_ids = input_id[:, 3:-5]
                full_text = torch.cat((embed_text(body_ids), tts_eos_embed), dim=1) + embed_codec(
                    [[tcfg.codec_pad_id] * (body_ids.shape[1] + 1)], input_id
                )
                bos_step = tts_pad_embed + embed_codec([[tcfg.codec_bos_id]], input_id)
                item_embed = torch.cat([item_embed, full_text, bos_step], dim=1)
                trailing = tts_pad_embed
            else:
                trailing = torch.cat(
                    (embed_text(input_id[:, 4:-5]), tts_eos_embed),
                    dim=1,
                )

        per_item_parts[index].append(item_embed)
        trailing_list.append(trailing)

    # Join each item's parts along the sequence axis.
    joined = [
        torch.cat([part for part in parts if part is not None], dim=1)
        for parts in per_item_parts
    ]

    # Left-pad by reversing, right-padding, and reversing back.
    original_lengths = torch.tensor([t.shape[1] for t in joined])
    flipped = [t.squeeze(0).flip(dims=[0]) for t in joined]
    padded_flipped = torch.nn.utils.rnn.pad_sequence(
        flipped,
        batch_first=True,
        padding_value=0.0,
    )
    talker_input_embeds = padded_flipped.flip(dims=[1])

    batch_size, max_len = talker_input_embeds.shape[0], talker_input_embeds.shape[1]
    positions = torch.arange(max_len).expand(batch_size, -1)
    num_pads = max_len - original_lengths
    talker_attention_mask = (positions >= num_pads.unsqueeze(1)).long().to(talker_input_embeds.device)

    # Right-pad the trailing text hiddens, filling padding with the pad embedding.
    pad_vector = tts_pad_embed.squeeze()
    trailing_seqs = [t.squeeze(0) for t in trailing_list]
    trailing_lengths = [s.shape[0] for s in trailing_seqs]
    padded_hiddens = torch.nn.utils.rnn.pad_sequence(
        trailing_seqs,
        batch_first=True,
        padding_value=0.0,
    )
    steps = torch.arange(max(trailing_lengths), device=padded_hiddens.device).expand(
        len(trailing_lengths), -1
    )
    lengths_column = torch.tensor(trailing_lengths, device=padded_hiddens.device).unsqueeze(1)
    padded_hiddens[steps >= lengths_column] = pad_vector
    trailing_text_hiddens = padded_hiddens

    return talker_input_embeds, talker_attention_mask, trailing_text_hiddens, tts_pad_embed
|
| 731 |
+
|
| 732 |
+
@torch.inference_mode()
def generate_voice_clone(
    self,
    text: str,
    language: str,
    ref_audio: Optional[Union[str, Path, tuple]] = None,
    ref_text: str = "",
    max_new_tokens: int = 2048,
    min_new_tokens: int = 2,
    temperature: float = 0.9,
    top_k: int = 50,
    top_p: float = 1.0,
    do_sample: bool = True,
    repetition_penalty: float = 1.05,
    xvec_only: bool = False,
    non_streaming_mode: bool = False,
    append_silence: bool = True,
    instruct: Optional[str] = None,
    voice_clone_prompt: Optional[Union[Dict[str, Any], List[Any]]] = None,
) -> Tuple[list, int]:
    """
    Synthesize `text` in the voice of a reference recording (one-shot, non-streaming).

    Args:
        text: Text to synthesize.
        language: Target language.
        ref_audio: Reference audio. Required unless `voice_clone_prompt` is given.
        ref_text: Transcription of the reference audio.
        max_new_tokens: Upper bound on generated codec steps.
        min_new_tokens: Number of steps that must be produced before EOS is allowed.
        temperature: Sampling temperature.
        top_k: Top-k sampling cutoff.
        top_p: Top-p (nucleus) sampling cutoff.
        do_sample: Whether to sample.
        repetition_penalty: Repetition penalty.
        xvec_only: Use only the speaker embedding for cloning (no reference audio
            in context). Defaults to False, i.e. the upstream ICL behavior.
        non_streaming_mode: Selects the text-feeding layout; False matches the
            upstream step-by-step text feeding during decode.
        append_silence: Forwarded unchanged to `_prepare_generation`.
        instruct: Optional style/dialect instruction (for example, a request to
            read in a particular dialect). Experimental together with x-vector-only
            cloning; prefer `xvec_only=False`.
        voice_clone_prompt: Optional precomputed prompt. When given, `xvec_only`
            is ignored and no prompt is extracted from `ref_audio`. Supports
            x-vector-only prompts (`ref_spk_embedding` only) and ICL prompts
            (`ref_spk_embedding` + `ref_code` + mode flags); `ref_text` is
            ignored for the former and required for the latter.

    Returns:
        Tuple of ([audio_waveform], sample_rate). If generation yields no codes,
        a single one-sample zero waveform and `self.sample_rate` are returned.
    """
    from .generate import fast_generate

    def _flat_numpy(wave):
        # Torch tensors expose `.cpu`; anything else is assumed to be array-like already.
        if hasattr(wave, 'cpu'):
            return wave.flatten().cpu().numpy()
        return wave.flatten() if hasattr(wave, 'flatten') else wave

    prepared = self._prepare_generation(
        text=text,
        language=language,
        ref_audio=ref_audio,
        ref_text=ref_text,
        xvec_only=xvec_only,
        non_streaming_mode=non_streaming_mode,
        append_silence=append_silence,
        voice_clone_prompt=voice_clone_prompt,
        instruct=instruct,
    )
    m, talker, config, tie, tam, tth, tpe, ref_codes = prepared

    sampling_options = dict(
        max_new_tokens=max_new_tokens,
        min_new_tokens=min_new_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=do_sample,
        repetition_penalty=repetition_penalty,
    )
    codec_ids, timing = fast_generate(
        talker=talker,
        talker_input_embeds=tie,
        attention_mask=tam,
        trailing_text_hiddens=tth,
        tts_pad_embed=tpe,
        config=config,
        predictor_graph=self.predictor_graph,
        talker_graph=self.talker_graph,
        **sampling_options,
    )

    if codec_ids is None:
        logger.warning("Generation returned no tokens")
        return [np.zeros(1, dtype=np.float32)], self.sample_rate

    # ICL mode: decode with the reference codes in front so the codec decoder
    # sees the reference audio as acoustic context.
    speech_tokenizer = m.speech_tokenizer
    has_reference = ref_codes is not None
    if has_reference:
        codes_for_decode = torch.cat([ref_codes.to(codec_ids.device), codec_ids], dim=0)
    else:
        codes_for_decode = codec_ids
    audio_list, sr = speech_tokenizer.decode({"audio_codes": codes_for_decode.unsqueeze(0)})

    # Drop the leading share of each waveform that corresponds to the reference
    # codes (proportional to the reference/total frame ratio).
    ref_len = ref_codes.shape[0] if has_reference else 0
    total_len = codes_for_decode.shape[0]
    audio_arrays = []
    for decoded in audio_list:
        waveform = _flat_numpy(decoded)
        if ref_len > 0:
            cut = int(ref_len / max(total_len, 1) * len(waveform))
            waveform = waveform[cut:]
        audio_arrays.append(waveform)

    n_steps = timing['steps']
    audio_duration = n_steps / 12.0  # 12 Hz codec
    total_time = timing['prefill_ms']/1000 + timing['decode_s']
    if total_time > 0:
        rtf = audio_duration / total_time
    else:
        rtf = 0

    logger.info(
        f"Generated {audio_duration:.2f}s audio in {total_time:.2f}s "
        f"({timing['ms_per_step']:.1f}ms/step, RTF: {rtf:.2f})"
    )

    return audio_arrays, sr
|
| 856 |
+
|
| 857 |
+
@torch.inference_mode()
def generate_voice_clone_streaming(
    self,
    text: str,
    language: str,
    ref_audio: Optional[Union[str, Path]] = None,
    ref_text: str = "",
    max_new_tokens: int = 2048,
    min_new_tokens: int = 2,
    temperature: float = 0.9,
    top_k: int = 50,
    top_p: float = 1.0,
    do_sample: bool = True,
    repetition_penalty: float = 1.05,
    chunk_size: int = 12,
    xvec_only: bool = False,
    non_streaming_mode: bool = False,
    append_silence: bool = True,
    parity_mode: bool = False,
    instruct: Optional[str] = None,
    voice_clone_prompt: Optional[Union[Dict[str, Any], List[Any]]] = None,
) -> Generator[Tuple[np.ndarray, int, dict], None, None]:
    """
    Streaming counterpart of generate_voice_clone().

    Yields one (audio_chunk, sample_rate, timing) tuple per chunk of codec steps
    produced by the underlying streaming generator (roughly chunk_size/12 seconds
    of audio each).

    Args:
        text: Text to synthesize.
        language: Target language.
        ref_audio: Reference audio. Required unless `voice_clone_prompt` is given.
        ref_text: Transcription of the reference audio.
        max_new_tokens: Upper bound on generated codec steps.
        min_new_tokens: Number of steps that must be produced before EOS is allowed.
        temperature: Sampling temperature.
        top_k: Top-k sampling cutoff.
        top_p: Top-p (nucleus) sampling cutoff.
        do_sample: Whether to sample.
        repetition_penalty: Repetition penalty.
        chunk_size: Codec steps per chunk (12 is about one second).
        xvec_only: Use only the speaker embedding for cloning. Defaults to False,
            i.e. the upstream ICL behavior (reference audio in context).
        non_streaming_mode: False matches upstream text feeding during decode;
            True prefills the full target text before streaming decode.
        append_silence: Forwarded unchanged to `_prepare_generation`.
        parity_mode: When True, use `parity_generate_streaming` (no CUDA graphs
            are passed) instead of `fast_generate_streaming`.
        instruct: Optional style/dialect instruction. Experimental together with
            x-vector-only cloning; prefer `xvec_only=False`.
        voice_clone_prompt: Optional precomputed prompt. When given, `xvec_only`
            is ignored and no prompt is extracted from `ref_audio`. `ref_text` is
            ignored for x-vector-only prompts and required for ICL prompts.

    Yields:
        Tuple of (audio_chunk_numpy, sample_rate, timing_dict)
    """
    from .streaming import fast_generate_streaming, parity_generate_streaming

    def _flat_numpy(wave):
        # Torch tensors expose `.cpu`; anything else is assumed to be array-like already.
        if hasattr(wave, 'cpu'):
            return wave.flatten().cpu().numpy()
        return wave.flatten() if hasattr(wave, 'flatten') else wave

    m, talker, config, tie, tam, tth, tpe, ref_codes = self._prepare_generation(
        text=text,
        language=language,
        ref_audio=ref_audio,
        ref_text=ref_text,
        xvec_only=xvec_only,
        non_streaming_mode=non_streaming_mode,
        append_silence=append_silence,
        voice_clone_prompt=voice_clone_prompt,
        instruct=instruct,
    )

    speech_tokenizer = m.speech_tokenizer

    # Hybrid decode strategy:
    #   1. Decode everything accumulated so far until enough frames exist to
    #      measure samples_per_frame.
    #   2. Afterwards decode only a sliding window (new frames plus up to 25
    #      frames of left context) so the per-chunk cost stays bounded while
    #      chunk boundaries keep their acoustic context.
    context_frames = 25
    min_calibration_frames = max(context_frames, chunk_size)
    all_codes = []
    emitted_gen_samples = 0  # samples of generated (non-reference) audio already yielded
    samples_per_frame = None

    stream_kwargs = {
        "talker": talker,
        "talker_input_embeds": tie,
        "attention_mask": tam,
        "trailing_text_hiddens": tth,
        "tts_pad_embed": tpe,
        "config": config,
        "max_new_tokens": max_new_tokens,
        "min_new_tokens": min_new_tokens,
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
        "do_sample": do_sample,
        "repetition_penalty": repetition_penalty,
        "chunk_size": chunk_size,
    }
    if parity_mode:
        stream_fn = parity_generate_streaming
    else:
        stream_fn = fast_generate_streaming
        stream_kwargs["predictor_graph"] = self.predictor_graph
        stream_kwargs["talker_graph"] = self.talker_graph

    for codec_chunk, timing in stream_fn(**stream_kwargs):
        all_codes.append(codec_chunk)
        n_new = codec_chunk.shape[0]
        all_flat = torch.cat(all_codes, dim=0)
        n_total = all_flat.shape[0]

        if samples_per_frame is not None:
            # Phase 2: sliding window with left context.
            ctx_start = max(0, n_total - n_new - context_frames)
            window = all_flat[ctx_start:]
            n_ctx = window.shape[0] - n_new

            audio_list, sr = speech_tokenizer.decode(
                {"audio_codes": window.unsqueeze(0)}
            )
            audio = _flat_numpy(audio_list[0])

            # Discard the samples belonging to the context frames.
            new_audio = audio[int(round(n_ctx * samples_per_frame)):] if n_ctx > 0 else audio
            yield new_audio, sr, timing
            continue

        # Phase 1: accumulated decode until calibration is possible. In ICL mode
        # the reference codes are prepended so the codec decoder has acoustic
        # context from the reference audio.
        have_reference = ref_codes is not None
        if have_reference:
            codes_input = torch.cat([ref_codes.to(all_flat.device), all_flat], dim=0)
        else:
            codes_input = all_flat
        audio_list, sr = speech_tokenizer.decode(
            {"audio_codes": codes_input.unsqueeze(0)}
        )
        audio = _flat_numpy(audio_list[0])

        # Strip the reference share; positions are tracked in generated audio only.
        if have_reference:
            ref_len = ref_codes.shape[0]
            total_len = codes_input.shape[0]
            ref_audio_cut = int(ref_len / max(total_len, 1) * len(audio))
            gen_audio = audio[ref_audio_cut:]
        else:
            gen_audio = audio

        new_audio = gen_audio[emitted_gen_samples:]
        emitted_gen_samples = len(gen_audio)

        if n_total >= min_calibration_frames:
            samples_per_frame = len(gen_audio) / n_total

        yield new_audio, sr, timing
|
| 1023 |
+
|
| 1024 |
+
@torch.inference_mode()
def generate_custom_voice(
    self,
    text: str,
    speaker: str,
    language: str,
    instruct: Optional[str] = None,
    non_streaming_mode: bool = True,
    max_new_tokens: int = 2048,
    min_new_tokens: int = 2,
    temperature: float = 0.9,
    top_k: int = 50,
    top_p: float = 1.0,
    do_sample: bool = True,
    repetition_penalty: float = 1.05,
) -> Tuple[list, int]:
    """
    Synthesize `text` with one of the model's built-in speakers (non-streaming).

    Args:
        text: Text to synthesize.
        speaker: Speaker name; checked with `self.model._validate_speakers`.
        language: Target language; checked with `self.model._validate_languages`.
        instruct: Optional style instruction. Discarded when the loaded model's
            `tts_model_size` is "0b6".
        non_streaming_mode: Forwarded to `_prepare_generation_custom`.
        max_new_tokens: Upper bound on generated codec steps.
        min_new_tokens: Number of steps that must be produced before EOS is allowed.
        temperature: Sampling temperature.
        top_k: Top-k sampling cutoff.
        top_p: Top-p (nucleus) sampling cutoff.
        do_sample: Whether to sample.
        repetition_penalty: Repetition penalty.

    Returns:
        Tuple of ([audio_waveform, ...], sample_rate). If generation yields no
        codes, a single one-sample zero waveform and `self.sample_rate` are returned.

    Raises:
        ValueError: If the loaded model's `tts_model_type` is not "custom_voice".
    """
    if self.model.model.tts_model_type != "custom_voice":
        raise ValueError("Loaded model does not support custom voice generation")

    self.model._validate_languages([language])
    self.model._validate_speakers([speaker])

    # Compare with `==`: a string `in` test is a substring check, which would
    # also match "", "0", "b6", ... and raise TypeError for a non-string size.
    if self.model.model.tts_model_size == "0b6":
        instruct = None

    from .generate import fast_generate

    m, talker, config, tie, tam, tth, tpe = self._prepare_generation_custom(
        text=text,
        language=language,
        speaker=speaker,
        instruct=instruct,
        non_streaming_mode=non_streaming_mode,
    )

    codec_ids, timing = fast_generate(
        talker=talker,
        talker_input_embeds=tie,
        attention_mask=tam,
        trailing_text_hiddens=tth,
        tts_pad_embed=tpe,
        config=config,
        predictor_graph=self.predictor_graph,
        talker_graph=self.talker_graph,
        max_new_tokens=max_new_tokens,
        min_new_tokens=min_new_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=do_sample,
        repetition_penalty=repetition_penalty,
    )

    if codec_ids is None:
        logger.warning("Generation returned no tokens")
        return [np.zeros(1, dtype=np.float32)], self.sample_rate

    speech_tokenizer = m.speech_tokenizer
    audio_list, sr = speech_tokenizer.decode({"audio_codes": codec_ids.unsqueeze(0)})

    # Normalize every decoded item to a flat numpy array; torch tensors are
    # recognized by their `.cpu` attribute.
    audio_arrays = []
    for a in audio_list:
        if hasattr(a, "cpu"):
            audio_arrays.append(a.flatten().cpu().numpy())
        else:
            audio_arrays.append(a.flatten() if hasattr(a, "flatten") else a)

    n_steps = timing["steps"]
    audio_duration = n_steps / 12.0  # duration assumes 12 codec steps per second
    total_time = timing["prefill_ms"] / 1000 + timing["decode_s"]
    rtf = audio_duration / total_time if total_time > 0 else 0

    logger.info(
        f"Generated {audio_duration:.2f}s audio in {total_time:.2f}s "
        f"({timing['ms_per_step']:.1f}ms/step, RTF: {rtf:.2f})"
    )

    return audio_arrays, sr
|
| 1102 |
+
|
| 1103 |
+
@torch.inference_mode()
def generate_custom_voice_streaming(
    self,
    text: str,
    speaker: str,
    language: str,
    instruct: Optional[str] = None,
    non_streaming_mode: bool = True,
    max_new_tokens: int = 2048,
    min_new_tokens: int = 2,
    temperature: float = 0.9,
    top_k: int = 50,
    top_p: float = 1.0,
    do_sample: bool = True,
    repetition_penalty: float = 1.05,
    chunk_size: int = 12,
) -> Generator[Tuple[np.ndarray, int, dict], None, None]:
    """
    Streaming counterpart of generate_custom_voice().

    Args:
        text: Text to synthesize.
        speaker: Speaker name; checked with `self.model._validate_speakers`.
        language: Target language; checked with `self.model._validate_languages`.
        instruct: Optional style instruction. Discarded when the loaded model's
            `tts_model_size` is "0b6".
        non_streaming_mode: Forwarded to `_prepare_generation_custom`.
        max_new_tokens: Upper bound on generated codec steps.
        min_new_tokens: Number of steps that must be produced before EOS is allowed.
        temperature: Sampling temperature.
        top_k: Top-k sampling cutoff.
        top_p: Top-p (nucleus) sampling cutoff.
        do_sample: Whether to sample.
        repetition_penalty: Repetition penalty.
        chunk_size: Codec steps per chunk, forwarded to `fast_generate_streaming`.

    Yields:
        Tuple of (audio_chunk_numpy, sample_rate, timing_dict), one per codec
        chunk produced by `fast_generate_streaming`.

    Raises:
        ValueError: If the loaded model's `tts_model_type` is not "custom_voice".
            Because this is a generator, the error surfaces on first iteration.
    """
    if self.model.model.tts_model_type != "custom_voice":
        raise ValueError("Loaded model does not support custom voice generation")

    self.model._validate_languages([language])
    self.model._validate_speakers([speaker])

    # Compare with `==`: a string `in` test is a substring check, which would
    # also match "", "0", "b6", ... and raise TypeError for a non-string size.
    if self.model.model.tts_model_size == "0b6":
        instruct = None

    from .streaming import fast_generate_streaming

    def _flat_numpy(wave):
        # Torch tensors expose `.cpu`; anything else is assumed to be array-like already.
        if hasattr(wave, "cpu"):
            return wave.flatten().cpu().numpy()
        return wave.flatten() if hasattr(wave, "flatten") else wave

    m, talker, config, tie, tam, tth, tpe = self._prepare_generation_custom(
        text=text,
        language=language,
        speaker=speaker,
        instruct=instruct,
        non_streaming_mode=non_streaming_mode,
    )

    speech_tokenizer = m.speech_tokenizer

    # Two-phase decode: decode all accumulated codes until enough frames exist
    # to measure samples_per_frame, then decode only a window made of the new
    # frames plus up to `context_frames` frames of left context.
    context_frames = 25
    min_calibration_frames = max(context_frames, chunk_size)
    all_codes = []
    prev_audio_len = 0
    samples_per_frame = None

    for codec_chunk, timing in fast_generate_streaming(
        talker=talker,
        talker_input_embeds=tie,
        attention_mask=tam,
        trailing_text_hiddens=tth,
        tts_pad_embed=tpe,
        config=config,
        predictor_graph=self.predictor_graph,
        talker_graph=self.talker_graph,
        max_new_tokens=max_new_tokens,
        min_new_tokens=min_new_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=do_sample,
        repetition_penalty=repetition_penalty,
        chunk_size=chunk_size,
    ):
        all_codes.append(codec_chunk)
        n_new = codec_chunk.shape[0]
        all_flat = torch.cat(all_codes, dim=0)
        n_total = all_flat.shape[0]

        if samples_per_frame is None:
            # Phase 1: full decode; emit only the samples not yielded before.
            audio_list, sr = speech_tokenizer.decode({"audio_codes": all_flat.unsqueeze(0)})
            audio = _flat_numpy(audio_list[0])

            new_audio = audio[prev_audio_len:]
            prev_audio_len = len(audio)

            if n_total >= min_calibration_frames:
                samples_per_frame = len(audio) / n_total
        else:
            # Phase 2: windowed decode; drop the samples of the context frames.
            ctx_start = max(0, n_total - n_new - context_frames)
            window = all_flat[ctx_start:]
            n_ctx = window.shape[0] - n_new

            audio_list, sr = speech_tokenizer.decode({"audio_codes": window.unsqueeze(0)})
            audio = _flat_numpy(audio_list[0])

            if n_ctx > 0:
                ctx_samples = int(round(n_ctx * samples_per_frame))
                new_audio = audio[ctx_samples:]
            else:
                new_audio = audio

        yield new_audio, sr, timing
|
| 1202 |
+
|
| 1203 |
+
@torch.inference_mode()
def generate_voice_design(
    self,
    text: str,
    instruct: str,
    language: str,
    non_streaming_mode: bool = True,
    max_new_tokens: int = 2048,
    min_new_tokens: int = 2,
    temperature: float = 0.9,
    top_k: int = 50,
    top_p: float = 1.0,
    do_sample: bool = True,
    repetition_penalty: float = 1.05,
) -> Tuple[list, int]:
    """
    Synthesize `text` with a voice described by `instruct` (non-streaming).

    Args:
        text: Text to synthesize.
        instruct: Instruction text, forwarded to `_prepare_generation_custom`
            (no speaker is passed).
        language: Target language; checked with `self.model._validate_languages`.
        non_streaming_mode: Forwarded to `_prepare_generation_custom`.
        max_new_tokens: Upper bound on generated codec steps.
        min_new_tokens: Number of steps that must be produced before EOS is allowed.
        temperature: Sampling temperature.
        top_k: Top-k sampling cutoff.
        top_p: Top-p (nucleus) sampling cutoff.
        do_sample: Whether to sample.
        repetition_penalty: Repetition penalty.

    Returns:
        Tuple of ([audio_waveform, ...], sample_rate). If generation yields no
        codes, a single one-sample zero waveform and `self.sample_rate` are returned.

    Raises:
        ValueError: If the loaded model's `tts_model_type` is not "voice_design".
    """
    if self.model.model.tts_model_type != "voice_design":
        raise ValueError("Loaded model does not support voice design generation")

    self.model._validate_languages([language])

    from .generate import fast_generate

    def _flat_numpy(wave):
        # Torch tensors expose `.cpu`; anything else is assumed to be array-like already.
        if hasattr(wave, "cpu"):
            return wave.flatten().cpu().numpy()
        return wave.flatten() if hasattr(wave, "flatten") else wave

    prepared = self._prepare_generation_custom(
        text=text,
        language=language,
        speaker=None,
        instruct=instruct,
        non_streaming_mode=non_streaming_mode,
    )
    m, talker, config, tie, tam, tth, tpe = prepared

    sampling_options = dict(
        max_new_tokens=max_new_tokens,
        min_new_tokens=min_new_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=do_sample,
        repetition_penalty=repetition_penalty,
    )
    codec_ids, timing = fast_generate(
        talker=talker,
        talker_input_embeds=tie,
        attention_mask=tam,
        trailing_text_hiddens=tth,
        tts_pad_embed=tpe,
        config=config,
        predictor_graph=self.predictor_graph,
        talker_graph=self.talker_graph,
        **sampling_options,
    )

    if codec_ids is None:
        logger.warning("Generation returned no tokens")
        return [np.zeros(1, dtype=np.float32)], self.sample_rate

    speech_tokenizer = m.speech_tokenizer
    audio_list, sr = speech_tokenizer.decode({"audio_codes": codec_ids.unsqueeze(0)})

    audio_arrays = [_flat_numpy(decoded) for decoded in audio_list]

    n_steps = timing["steps"]
    audio_duration = n_steps / 12.0  # duration assumes 12 codec steps per second
    total_time = timing["prefill_ms"] / 1000 + timing["decode_s"]
    if total_time > 0:
        rtf = audio_duration / total_time
    else:
        rtf = 0

    logger.info(
        f"Generated {audio_duration:.2f}s audio in {total_time:.2f}s "
        f"({timing['ms_per_step']:.1f}ms/step, RTF: {rtf:.2f})"
    )

    return audio_arrays, sr
|
| 1276 |
+
|
| 1277 |
+
@torch.inference_mode()
def generate_voice_design_streaming(
    self,
    text: str,
    instruct: str,
    language: str,
    non_streaming_mode: bool = True,
    max_new_tokens: int = 2048,
    min_new_tokens: int = 2,
    temperature: float = 0.9,
    top_k: int = 50,
    top_p: float = 1.0,
    do_sample: bool = True,
    repetition_penalty: float = 1.05,
    chunk_size: int = 12,
) -> Generator[Tuple[np.ndarray, int, dict], None, None]:
    """
    Streaming counterpart of generate_voice_design().

    Args:
        text: Text to synthesize.
        instruct: Instruction text, forwarded to `_prepare_generation_custom`
            (no speaker is passed).
        language: Target language; checked with `self.model._validate_languages`.
        non_streaming_mode: Forwarded to `_prepare_generation_custom`.
        max_new_tokens: Upper bound on generated codec steps.
        min_new_tokens: Number of steps that must be produced before EOS is allowed.
        temperature: Sampling temperature.
        top_k: Top-k sampling cutoff.
        top_p: Top-p (nucleus) sampling cutoff.
        do_sample: Whether to sample.
        repetition_penalty: Repetition penalty.
        chunk_size: Codec steps per chunk, forwarded to `fast_generate_streaming`.

    Yields:
        Tuple of (audio_chunk_numpy, sample_rate, timing_dict), one per codec
        chunk produced by `fast_generate_streaming`.

    Raises:
        ValueError: If the loaded model's `tts_model_type` is not "voice_design".
            Because this is a generator, the error surfaces on first iteration.
    """
    if self.model.model.tts_model_type != "voice_design":
        raise ValueError("Loaded model does not support voice design generation")

    self.model._validate_languages([language])

    from .streaming import fast_generate_streaming

    def _flat_numpy(wave):
        # Torch tensors expose `.cpu`; anything else is assumed to be array-like already.
        if hasattr(wave, "cpu"):
            return wave.flatten().cpu().numpy()
        return wave.flatten() if hasattr(wave, "flatten") else wave

    m, talker, config, tie, tam, tth, tpe = self._prepare_generation_custom(
        text=text,
        language=language,
        speaker=None,
        instruct=instruct,
        non_streaming_mode=non_streaming_mode,
    )

    speech_tokenizer = m.speech_tokenizer

    # Two-phase decode: decode all accumulated codes until enough frames exist
    # to measure samples_per_frame, then decode only a window made of the new
    # frames plus up to `context_frames` frames of left context.
    context_frames = 25
    min_calibration_frames = max(context_frames, chunk_size)
    all_codes = []
    emitted_samples = 0
    samples_per_frame = None

    code_stream = fast_generate_streaming(
        talker=talker,
        talker_input_embeds=tie,
        attention_mask=tam,
        trailing_text_hiddens=tth,
        tts_pad_embed=tpe,
        config=config,
        predictor_graph=self.predictor_graph,
        talker_graph=self.talker_graph,
        max_new_tokens=max_new_tokens,
        min_new_tokens=min_new_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=do_sample,
        repetition_penalty=repetition_penalty,
        chunk_size=chunk_size,
    )
    for codec_chunk, timing in code_stream:
        all_codes.append(codec_chunk)
        n_new = codec_chunk.shape[0]
        all_flat = torch.cat(all_codes, dim=0)
        n_total = all_flat.shape[0]

        if samples_per_frame is not None:
            # Calibrated: windowed decode, then drop the samples of the context frames.
            ctx_start = max(0, n_total - n_new - context_frames)
            window = all_flat[ctx_start:]
            n_ctx = window.shape[0] - n_new

            audio_list, sr = speech_tokenizer.decode({"audio_codes": window.unsqueeze(0)})
            audio = _flat_numpy(audio_list[0])

            new_audio = audio[int(round(n_ctx * samples_per_frame)):] if n_ctx > 0 else audio
            yield new_audio, sr, timing
            continue

        # Not calibrated yet: full decode; emit only the samples not yielded before.
        audio_list, sr = speech_tokenizer.decode({"audio_codes": all_flat.unsqueeze(0)})
        audio = _flat_numpy(audio_list[0])

        new_audio = audio[emitted_samples:]
        emitted_samples = len(audio)

        if n_total >= min_calibration_frames:
            samples_per_frame = len(audio) / n_total

        yield new_audio, sr, timing
|
faster_qwen3_tts/predictor_graph.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
CUDA graph capture for the code predictor's 15-step decode loop,
|
| 4 |
+
using transformers StaticCache.
|
| 5 |
+
|
| 6 |
+
The predictor generates 15 codebooks autoregressively:
|
| 7 |
+
- Step 0: prefill with 2 tokens (past_hidden + first_codebook_embed), get logits[0]
|
| 8 |
+
- Steps 1-14: decode 1 token at a time using previous codebook token's embedding
|
| 9 |
+
|
| 10 |
+
Strategy:
|
| 11 |
+
- Use transformers StaticCache for KV cache management
|
| 12 |
+
- Use the predictor's inner model forward (handles mask, RoPE, attention internally)
|
| 13 |
+
- Unroll the full 15-step loop for deterministic shapes
|
| 14 |
+
- Capture the entire loop as a single CUDA graph
|
| 15 |
+
"""
|
| 16 |
+
import torch
|
| 17 |
+
from transformers import StaticCache
|
| 18 |
+
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
|
| 19 |
+
|
| 20 |
+
from .sampling import sample_logits
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class PredictorGraph:
|
| 24 |
+
"""
|
| 25 |
+
Captures the full predictor 15-step loop as a CUDA graph,
|
| 26 |
+
using the model's forward with transformers StaticCache.
|
| 27 |
+
|
| 28 |
+
Usage:
|
| 29 |
+
mpg = PredictorGraph(code_predictor, pred_config, talker_hidden_size)
|
| 30 |
+
mpg.capture()
|
| 31 |
+
codebook_tokens = mpg.run(pred_input) # pred_input: [1, 2, H]
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
def __init__(self, code_predictor, pred_config, talker_hidden_size, device='cuda', dtype=torch.bfloat16,
             do_sample=True, top_k=50, top_p=1.0, temperature=0.9):
    """
    Store references to the predictor's modules and allocate static buffers.

    Nothing is captured here; `graph` stays None and `captured` False.

    Args:
        code_predictor: Predictor module providing `small_to_mtp_projection`,
            `model` (with `codec_embedding` and `config`) and `lm_head`.
        pred_config: Predictor config; `num_hidden_layers`, `hidden_size` and
            `num_code_groups` are read, and it is passed to `StaticCache`.
        talker_hidden_size: Last dimension of the [1, 2, H] input buffer.
        device: Device for all pre-allocated tensors.
        dtype: Dtype of the input buffer.
        do_sample, top_k, top_p, temperature: Sampling settings stored on the instance.
    """
    self.device = device
    # Fall back to the current CUDA device when `device` carries no explicit index.
    resolved_index = torch.device(device).index
    if resolved_index is None:
        resolved_index = torch.cuda.current_device()
    self.device_index = resolved_index

    self.dtype = dtype
    self.num_layers = pred_config.num_hidden_layers
    self.hidden_size = pred_config.hidden_size
    self.num_code_groups = pred_config.num_code_groups
    # One codebook per code group except the first.
    self.num_codebooks = self.num_code_groups - 1
    # Two prefill positions plus one position per predicted codebook.
    self.max_seq = 2 + self.num_codebooks
    self.do_sample = do_sample
    self.top_k = top_k
    self.top_p = top_p
    self.temperature = temperature

    # Model components are held by reference, not copied.
    cp = code_predictor
    self.small_to_mtp = cp.small_to_mtp_projection
    self.pred_model = cp.model        # inner transformer model
    self.lm_heads = cp.lm_head        # one head per codebook
    self.codec_embeds = cp.model.codec_embedding
    self.has_sliding_layers = "sliding_attention" in getattr(self.pred_model.config, "layer_types", [])

    # KV cache for the predictor, sized for the whole loop.
    self.static_cache = StaticCache(config=pred_config, max_cache_len=self.max_seq)

    # cache_position tensors are created up front so no host->device copy is
    # needed inside the captured graph.
    self.prefill_cache_pos = torch.arange(2, device=device)
    self.decode_cache_positions = [
        torch.tensor([position], device=device)
        for position in range(2, self.num_codebooks + 1)
    ]

    # I/O buffers
    self.input_buf = torch.zeros(1, 2, talker_hidden_size, dtype=dtype, device=device)
    self.output_tokens = torch.zeros(self.num_codebooks, dtype=torch.long, device=device)

    self.graph = None
    self.captured = False
    self.prefill_attn = None
    self.decode_attn = None
|
| 77 |
+
|
| 78 |
+
def _init_cache_layers(self):
|
| 79 |
+
"""Force lazy initialization of StaticCache layers before graph capture."""
|
| 80 |
+
config = self.pred_model.config
|
| 81 |
+
num_kv_heads = getattr(config, 'num_key_value_heads', config.num_attention_heads)
|
| 82 |
+
head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads)
|
| 83 |
+
dummy_k = torch.zeros(1, num_kv_heads, 1, head_dim, dtype=self.dtype, device=self.device)
|
| 84 |
+
for layer in self.static_cache.layers:
|
| 85 |
+
if not layer.is_initialized:
|
| 86 |
+
layer.lazy_initialization(dummy_k)
|
| 87 |
+
|
| 88 |
+
def _make_attn_mask(self, input_embeds: torch.Tensor, cache_position: torch.Tensor):
    """Build the attention-mask mapping for one forward pass of the predictor.

    Args:
        input_embeds: Embeddings tensor forwarded to the mask builders.
        cache_position: Cache positions written by this pass.

    Returns:
        A dict with key ``"full_attention"``; when the predictor config lists
        sliding-attention layers it also carries ``"sliding_attention"``.
    """
    # Both mask builders receive exactly the same arguments.
    shared = dict(
        config=self.pred_model.config,
        input_embeds=input_embeds,
        attention_mask=None,
        cache_position=cache_position,
        past_key_values=self.static_cache,
    )
    masks = {"full_attention": create_causal_mask(**shared)}
    if self.has_sliding_layers:
        masks["sliding_attention"] = create_sliding_window_causal_mask(**shared)
    return masks
|
| 106 |
+
|
| 107 |
+
def _build_attention_masks(self):
|
| 108 |
+
dummy_prefill = torch.zeros(1, 2, self.hidden_size, dtype=self.dtype, device=self.device)
|
| 109 |
+
dummy_decode = torch.zeros(1, 1, self.hidden_size, dtype=self.dtype, device=self.device)
|
| 110 |
+
self.prefill_attn = self._make_attn_mask(dummy_prefill, self.prefill_cache_pos)
|
| 111 |
+
self.decode_attn = []
|
| 112 |
+
for pos in self.decode_cache_positions:
|
| 113 |
+
self.decode_attn.append(self._make_attn_mask(dummy_decode, pos))
|
| 114 |
+
|
| 115 |
+
def _full_loop(self):
    """Run the complete predictor pass over the static buffers.

    Prefills the two-token input held in ``self.input_buf``, samples the
    first codebook token, then performs one single-token decode per
    remaining codebook. Every sampled token is written into
    ``self.output_tokens``, which is returned (the buffer itself, not a copy).
    """
    sampling = dict(
        temperature=self.temperature,
        top_k=self.top_k,
        top_p=self.top_p,
        do_sample=self.do_sample,
    )

    def advance(embeds, mask, positions, head_idx):
        # Forward through the inner transformer, then project the last
        # position with the codebook-specific head and sample from it.
        out = self.pred_model(
            inputs_embeds=embeds,
            attention_mask=mask,
            past_key_values=self.static_cache,
            cache_position=positions,
            use_cache=True,
        )
        last = out.last_hidden_state[:, -1:, :]
        scores = self.lm_heads[head_idx](last)
        picked = sample_logits(scores[:, 0, :], **sampling)
        self.output_tokens[head_idx] = picked[0]
        return picked

    # Prefill: project talker-sized input down to the predictor's hidden size.
    tok = advance(
        self.small_to_mtp(self.input_buf),
        self.prefill_attn,
        self.prefill_cache_pos,
        0,
    )

    # Remaining codebooks: embed the previous token with that codebook's
    # embedding table, project it, and decode a single position.
    for cb_idx in range(1, self.num_codebooks):
        prev = cb_idx - 1
        emb = self.small_to_mtp(self.codec_embeds[prev](tok.unsqueeze(0)))
        tok = advance(emb, self.decode_attn[prev], self.decode_cache_positions[prev], cb_idx)

    return self.output_tokens
|
| 168 |
+
|
| 169 |
+
@torch.inference_mode()
def capture(self, num_warmup=3):
    """Warm up the predictor loop and record it into a CUDA graph.

    Args:
        num_warmup: Number of eager warmup passes run before capture.

    Side effects: initializes the static cache layers, builds the attention
    masks, stores the recorded graph on ``self.graph`` and sets
    ``self.captured`` to True.
    """
    print(f"Warming up predictor ({num_warmup} runs)...")

    # Cache layers and masks must exist before anything is recorded.
    self._init_cache_layers()
    self._build_attention_masks()

    def fresh_pass():
        self.static_cache.reset()
        self._full_loop()

    for _ in range(num_warmup):
        fresh_pass()
    torch.cuda.synchronize()

    print("Capturing CUDA graph for predictor...")

    with torch.cuda.device(self.device_index):
        side_stream = torch.cuda.Stream()
        side_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side_stream):
            self.graph = torch.cuda.CUDAGraph()
            # One more warmup pass, this time on the side stream.
            fresh_pass()
            torch.cuda.synchronize()

            self.static_cache.reset()
            with torch.cuda.graph(self.graph):
                self._full_loop()

        torch.cuda.current_stream().wait_stream(side_stream)
        torch.cuda.synchronize()
    self.captured = True
    print("CUDA graph captured!")
|
| 203 |
+
|
| 204 |
+
@torch.inference_mode()
def run(self, pred_input: torch.Tensor) -> torch.Tensor:
    """Replay the captured predictor graph on a new input.

    Args:
        pred_input: Tensor copied into the static input buffer; per the
            original docstring it is ``[1, 2, talker_hidden_size]``
            (past_hidden concatenated with the first-codebook embedding).

    Returns:
        A clone of the static output-token buffer (one long token id per
        predicted codebook), safe to keep across later replays.

    Raises:
        RuntimeError: If no graph has been captured yet.
    """
    # Fail early with a clear message, before any static buffer is touched.
    if self.graph is None:
        raise RuntimeError(
            "Predictor CUDA graph has not been captured; call capture() before run()."
        )
    self.input_buf.copy_(pred_input)
    self.static_cache.reset()
    self.graph.replay()
    # Clone so the caller's result is not overwritten by the next replay.
    return self.output_tokens.clone()
|
faster_qwen3_tts/sampling.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared sampling helpers for talker and predictor generation."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from typing import Iterable, Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def apply_repetition_penalty(
    logits: torch.Tensor,
    token_history: torch.Tensor,
    repetition_penalty: float,
) -> torch.Tensor:
    """Penalize already-generated tokens; mutates ``logits`` and returns it.

    Positive logits of previously seen tokens are divided by the penalty and
    non-positive ones are multiplied by it.

    Args:
        logits: Tensor whose last dimension indexes the vocabulary
            (documented upstream as [1, 1, vocab] or [1, vocab]).
        token_history: Tensor of previously generated token ids; duplicates
            are collapsed before the penalty is applied.
        repetition_penalty: HF-style repetition penalty (>1.0). A value of
            exactly 1.0, or an empty history, leaves ``logits`` untouched.
    """
    nothing_to_do = repetition_penalty == 1.0 or token_history.numel() == 0
    if nothing_to_do:
        return logits
    seen = token_history.unique()
    picked = logits[..., seen]
    shrunk = picked / repetition_penalty
    grown = picked * repetition_penalty
    logits[..., seen] = torch.where(picked > 0, shrunk, grown)
    return logits
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def sample_logits(
    logits: torch.Tensor,
    *,
    temperature: float,
    top_k: int,
    top_p: float,
    do_sample: bool,
    suppress_mask: Optional[torch.Tensor] = None,
    suppress_tokens: Optional[Iterable[int]] = None,
) -> torch.Tensor:
    """Sample a token id from logits.

    Mirrors HF order: suppress -> temperature -> top-k -> top-p -> sample.
    The input tensor is never modified; all work happens on a clone.

    Args:
        logits: Tensor whose last dimension indexes the vocabulary.
        temperature: Softmax temperature. A value <= 0 selects greedy
            decoding instead of dividing by zero.
        top_k: Keep only the ``top_k`` highest logits; values <= 0 disable it.
        top_p: Nucleus threshold; values >= 1.0 disable it. The kept set is
            the smallest prefix of the sorted distribution whose mass
            reaches ``top_p`` (the token that crosses the threshold is kept).
        do_sample: When False, return the argmax after suppression.
        suppress_mask: Optional boolean mask over the vocabulary; True
            entries are set to -inf.
        suppress_tokens: Optional token ids to set to -inf.

    Returns:
        Tensor of sampled ids with the vocabulary dimension removed.
    """
    logits = logits.clone()
    neg_inf = float("-inf")
    if suppress_mask is not None:
        logits[..., suppress_mask] = neg_inf
    if suppress_tokens:
        logits[..., list(suppress_tokens)] = neg_inf
    # Greedy path; also covers temperature <= 0, which would otherwise
    # produce inf/NaN logits below.
    if not do_sample or temperature <= 0:
        return torch.argmax(logits, dim=-1)
    logits = logits / temperature
    if top_k > 0:
        topk_vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits = torch.where(logits < topk_vals[..., -1:], torch.full_like(logits, neg_inf), logits)
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # A token is dropped only if the mass of the tokens ranked before it
        # already reaches top_p. Shifting the mask right by one keeps the
        # token that crosses the threshold and always keeps the top token.
        reached = cumulative_probs >= top_p
        sorted_indices_to_remove = torch.zeros_like(reached)
        sorted_indices_to_remove[..., 1:] = reached[..., :-1]
        sorted_logits = sorted_logits.masked_fill(sorted_indices_to_remove, neg_inf)
        # Undo the sort so logits line up with vocabulary ids again.
        logits = torch.full_like(logits, neg_inf)
        logits.scatter_(-1, sorted_indices, sorted_logits)
    return torch.multinomial(F.softmax(logits, dim=-1), 1).squeeze(-1)
|
faster_qwen3_tts/streaming.py
ADDED
|
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Streaming generation with CUDA graphs for both predictor and talker.
|
| 4 |
+
|
| 5 |
+
Yields codec ID chunks during generation instead of collecting all at once.
|
| 6 |
+
CUDA graph usage is identical to non-streaming — same per-step performance.
|
| 7 |
+
"""
|
| 8 |
+
import time
|
| 9 |
+
from typing import Generator, Tuple
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
from .predictor_graph import PredictorGraph
|
| 14 |
+
from .sampling import apply_repetition_penalty, sample_logits
|
| 15 |
+
from .talker_graph import TalkerGraph
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@torch.inference_mode()
def fast_generate_streaming(
    talker,
    talker_input_embeds: torch.Tensor,
    attention_mask: torch.Tensor,
    trailing_text_hiddens: torch.Tensor,
    tts_pad_embed: torch.Tensor,
    config,
    predictor_graph: PredictorGraph,
    talker_graph: TalkerGraph,
    max_new_tokens: int = 2048,
    min_new_tokens: int = 2,
    temperature: float = 0.9,
    top_k: int = 50,
    top_p: float = 1.0,
    do_sample: bool = True,
    repetition_penalty: float = 1.05,
    chunk_size: int = 12,
) -> Generator[Tuple[torch.Tensor, dict], None, None]:
    """
    Streaming autoregressive generation with CUDA-graphed predictor and talker.

    Yields (codec_chunk, timing_info) tuples every chunk_size steps.
    codec_chunk: [chunk_steps, 16] tensor of codec IDs (the talker's token
    followed by the predictor's codebook tokens for each step).
    The final chunk may be shorter than chunk_size.

    timing_info keys: 'chunk_index', 'chunk_steps', 'prefill_ms' (non-zero
    only on the first chunk), 'decode_ms', 'total_steps_so_far', 'is_final'.

    NOTE: 'is_final' is True only on a trailing partial chunk. When
    generation stops exactly on a chunk boundary, no chunk carries
    is_final=True; consumers should also treat generator exhaustion as the
    end of the stream.
    """
    eos_id = config.codec_eos_token_id
    vocab_size = config.vocab_size
    device = talker_input_embeds.device

    # Suppress the last (up to) 1024 vocabulary ids except EOS. Built with a
    # single slice write instead of one device write per id.
    suppress_mask = torch.zeros(vocab_size, dtype=torch.bool, device=device)
    suppress_start = max(0, vocab_size - 1024)
    suppress_mask[suppress_start:] = True
    if eos_id is not None and suppress_start <= eos_id < vocab_size:
        suppress_mask[eos_id] = False

    predictor = talker.code_predictor
    talker_codec_embed = talker.get_input_embeddings()
    talker_codec_head = talker.codec_head
    predictor_codec_embeds = predictor.get_input_embeddings()
    num_code_groups = config.num_code_groups

    # === PREFILL (still uses HF forward for variable-length prefill) ===
    t_start = time.time()

    out = talker.forward(
        inputs_embeds=talker_input_embeds,
        attention_mask=attention_mask,
        use_cache=True,
        output_hidden_states=True,
        return_dict=True,
        trailing_text_hidden=trailing_text_hiddens,
        tts_pad_embed=tts_pad_embed,
        generation_step=None,
        past_hidden=None,
        past_key_values=None,
    )

    talker_past_kv = out.past_key_values
    past_hidden = out.past_hidden
    gen_step = out.generation_step

    logits = out.logits[:, -1, :]
    suppress_eos = min_new_tokens > 0
    token = sample_logits(
        logits,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=do_sample,
        suppress_mask=suppress_mask,
        suppress_tokens=[eos_id] if suppress_eos else None,
    )

    # Hand the prefill KV over to the static cache used by the talker graph.
    prefill_len = talker_graph.prefill_kv(talker_past_kv)
    rope_deltas = getattr(talker, "rope_deltas", None)
    talker_graph.set_generation_state(attention_mask, rope_deltas)

    torch.cuda.synchronize()
    t_prefill = time.time() - t_start

    # === DECODE LOOP — yield chunks ===
    chunk_buffer = []
    all_first_tokens = []  # for repetition penalty across chunks
    total_steps = 0
    chunk_count = 0
    chunk_start = time.time()

    for step_idx in range(max_new_tokens):
        if token.item() == eos_id:
            break

        # --- CUDA-Graphed Code Predictor ---
        last_id_hidden = talker_codec_embed(token.unsqueeze(1))
        pred_input = torch.cat((past_hidden, last_id_hidden), dim=1)
        codebook_token_ids = predictor_graph.run(pred_input)

        all_cb = torch.cat([token.view(1), codebook_token_ids])
        chunk_buffer.append(all_cb.detach())
        all_first_tokens.append(token.detach())

        # --- Build input embedding for talker ---
        codec_hiddens = [last_id_hidden]
        for i in range(num_code_groups - 1):
            codec_hiddens.append(predictor_codec_embeds[i](codebook_token_ids[i].unsqueeze(0).unsqueeze(0)))
        inputs_embeds = torch.cat(codec_hiddens, dim=1).sum(1, keepdim=True)

        # Add the next trailing text hidden while any remain, else the pad embedding.
        if gen_step < trailing_text_hiddens.shape[1]:
            inputs_embeds = inputs_embeds + trailing_text_hiddens[:, gen_step].unsqueeze(1)
        else:
            inputs_embeds = inputs_embeds + tts_pad_embed

        # --- CUDA-Graphed Talker decode step ---
        current_pos = prefill_len + step_idx
        # Stop before running past the talker graph's static cache length.
        if current_pos >= talker_graph.max_seq_len - 1:
            break

        hidden_states = talker_graph.run(inputs_embeds, position=current_pos)

        logits = talker_codec_head(hidden_states[:, -1, :]).unsqueeze(0)

        if repetition_penalty != 1.0 and all_first_tokens:
            history = torch.stack(all_first_tokens)
            logits = apply_repetition_penalty(logits, history, repetition_penalty)

        suppress_eos = len(all_first_tokens) < min_new_tokens
        token = sample_logits(
            logits.squeeze(0),
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=do_sample,
            suppress_mask=suppress_mask,
            suppress_tokens=[eos_id] if suppress_eos else None,
        )
        past_hidden = hidden_states[:, -1:, :].clone()
        gen_step += 1

        # --- Yield chunk when buffer is full ---
        if len(chunk_buffer) >= chunk_size:
            torch.cuda.synchronize()
            chunk_decode_time = time.time() - chunk_start
            total_steps += len(chunk_buffer)

            yield torch.stack(chunk_buffer), {
                'chunk_index': chunk_count,
                'chunk_steps': len(chunk_buffer),
                'prefill_ms': t_prefill * 1000 if chunk_count == 0 else 0,
                'decode_ms': chunk_decode_time * 1000,
                'total_steps_so_far': total_steps,
                'is_final': False,
            }

            chunk_buffer = []
            chunk_count += 1
            chunk_start = time.time()

    # --- Yield final partial chunk ---
    if chunk_buffer:
        torch.cuda.synchronize()
        chunk_decode_time = time.time() - chunk_start
        total_steps += len(chunk_buffer)

        yield torch.stack(chunk_buffer), {
            'chunk_index': chunk_count,
            'chunk_steps': len(chunk_buffer),
            'prefill_ms': t_prefill * 1000 if chunk_count == 0 else 0,
            'decode_ms': chunk_decode_time * 1000,
            'total_steps_so_far': total_steps,
            'is_final': True,
        }
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
@torch.inference_mode()
def parity_generate_streaming(
    talker,
    talker_input_embeds: torch.Tensor,
    attention_mask: torch.Tensor,
    trailing_text_hiddens: torch.Tensor,
    tts_pad_embed: torch.Tensor,
    config,
    max_new_tokens: int = 2048,
    min_new_tokens: int = 2,
    temperature: float = 0.9,
    top_k: int = 50,
    top_p: float = 1.0,
    do_sample: bool = True,
    repetition_penalty: float = 1.05,
    chunk_size: int = 12,
) -> Generator[Tuple[torch.Tensor, dict], None, None]:
    """
    Streaming generation without CUDA graphs (dynamic cache).

    Yields (codec_chunk, timing_info) tuples every chunk_size steps.

    timing_info keys: 'chunk_index', 'chunk_steps', 'prefill_ms' (non-zero
    only on the first chunk), 'decode_ms', 'total_steps_so_far', 'is_final'.

    NOTE: 'is_final' is True only on a trailing partial chunk. When
    generation stops exactly on a chunk boundary, no chunk carries
    is_final=True.
    """
    # NOTE: This function intentionally mirrors fast_generate_streaming. The core
    # decode loop is duplicated so we can swap CUDA graphs/static cache for the
    # dynamic-cache path while keeping sampling/chunking identical. If you edit
    # the fast path, check parity_generate_streaming for matching changes.
    eos_id = config.codec_eos_token_id
    vocab_size = config.vocab_size
    device = talker_input_embeds.device

    # Suppress the last (up to) 1024 vocabulary ids except EOS. Built with a
    # single slice write instead of one device write per id.
    suppress_mask = torch.zeros(vocab_size, dtype=torch.bool, device=device)
    suppress_start = max(0, vocab_size - 1024)
    suppress_mask[suppress_start:] = True
    if eos_id is not None and suppress_start <= eos_id < vocab_size:
        suppress_mask[eos_id] = False

    # === PREFILL ===
    t_start = time.time()

    out = talker.forward(
        inputs_embeds=talker_input_embeds,
        attention_mask=attention_mask,
        use_cache=True,
        output_hidden_states=True,
        return_dict=True,
        trailing_text_hidden=trailing_text_hiddens,
        tts_pad_embed=tts_pad_embed,
        generation_step=None,
        past_hidden=None,
        past_key_values=None,
    )

    talker_past_kv = out.past_key_values
    past_hidden = out.past_hidden
    gen_step = out.generation_step

    logits = out.logits[:, -1, :]
    suppress_eos = min_new_tokens > 0
    token = sample_logits(
        logits,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=do_sample,
        suppress_mask=suppress_mask,
        suppress_tokens=[eos_id] if suppress_eos else None,
    )

    # Work on a private copy; the mask is extended by one column per step below.
    if attention_mask is not None:
        attention_mask = attention_mask.clone()

    torch.cuda.synchronize()
    t_prefill = time.time() - t_start

    # === DECODE LOOP — yield chunks ===
    chunk_buffer = []
    all_first_tokens = []
    total_steps = 0
    chunk_count = 0
    chunk_start = time.time()

    for _ in range(max_new_tokens):
        if token.item() == eos_id:
            break

        cache_position = None
        if attention_mask is not None:
            # Grow the mask by one attended position for the token being fed.
            attention_mask = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
                dim=1,
            )
            cache_position = torch.tensor([attention_mask.shape[1] - 1], device=attention_mask.device)

        out = talker.forward(
            input_ids=token.view(1, 1),
            attention_mask=attention_mask,
            use_cache=True,
            output_hidden_states=True,
            return_dict=True,
            trailing_text_hidden=trailing_text_hiddens,
            tts_pad_embed=tts_pad_embed,
            generation_step=gen_step,
            past_hidden=past_hidden,
            past_key_values=talker_past_kv,
            subtalker_dosample=do_sample,
            subtalker_top_k=top_k,
            subtalker_top_p=top_p,
            subtalker_temperature=temperature,
            cache_position=cache_position,
        )

        # The codec ids for this step are read from hidden_states[1].
        codec_ids = out.hidden_states[1]
        if codec_ids is None:
            break

        chunk_buffer.append(codec_ids.squeeze(0).detach())
        all_first_tokens.append(token.detach())

        logits = out.logits[:, -1, :]
        if repetition_penalty != 1.0 and all_first_tokens:
            history = torch.stack(all_first_tokens)
            logits = apply_repetition_penalty(logits, history, repetition_penalty)

        suppress_eos = len(all_first_tokens) < min_new_tokens
        token = sample_logits(
            logits,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=do_sample,
            suppress_mask=suppress_mask,
            suppress_tokens=[eos_id] if suppress_eos else None,
        )

        talker_past_kv = out.past_key_values
        past_hidden = out.past_hidden
        gen_step = out.generation_step

        if len(chunk_buffer) >= chunk_size:
            torch.cuda.synchronize()
            chunk_decode_time = time.time() - chunk_start
            total_steps += len(chunk_buffer)

            yield torch.stack(chunk_buffer), {
                'chunk_index': chunk_count,
                'chunk_steps': len(chunk_buffer),
                'prefill_ms': t_prefill * 1000 if chunk_count == 0 else 0,
                'decode_ms': chunk_decode_time * 1000,
                'total_steps_so_far': total_steps,
                'is_final': False,
            }

            chunk_buffer = []
            chunk_count += 1
            chunk_start = time.time()

    if chunk_buffer:
        torch.cuda.synchronize()
        chunk_decode_time = time.time() - chunk_start
        total_steps += len(chunk_buffer)

        yield torch.stack(chunk_buffer), {
            'chunk_index': chunk_count,
            'chunk_steps': len(chunk_buffer),
            'prefill_ms': t_prefill * 1000 if chunk_count == 0 else 0,
            'decode_ms': chunk_decode_time * 1000,
            'total_steps_so_far': total_steps,
            'is_final': True,
        }
|
faster_qwen3_tts/talker_graph.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
CUDA graph capture for the talker's single-token decode step,
|
| 4 |
+
using transformers StaticCache.
|
| 5 |
+
|
| 6 |
+
The talker has 28 transformer layers. Instead of reimplementing the
|
| 7 |
+
forward pass manually, we use the model's own forward with StaticCache.
|
| 8 |
+
The StaticCache provides fixed-size KV tensors compatible with CUDA graphs.
|
| 9 |
+
|
| 10 |
+
Strategy:
|
| 11 |
+
- Use transformers StaticCache for KV cache management
|
| 12 |
+
- Use the model's forward method (handles mask, RoPE, attention internally)
|
| 13 |
+
- Capture the single-token decode as a CUDA graph
|
| 14 |
+
- Update cache_position buffer between replays
|
| 15 |
+
"""
|
| 16 |
+
import torch
|
| 17 |
+
from transformers import StaticCache
|
| 18 |
+
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class TalkerGraph:
|
| 22 |
+
"""
|
| 23 |
+
Captures the talker's single-token decode step as a CUDA graph,
|
| 24 |
+
using the model's own forward with transformers StaticCache.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(self, talker_model, talker_config, device='cuda', dtype=torch.bfloat16,
             max_seq_len=512):
    """Allocate the static cache and fixed I/O buffers used by the talker graph.

    Args:
        talker_model: Model invoked for each single-token decode step.
        talker_config: Config supplying ``hidden_size`` and
            ``num_hidden_layers``; also used to build the StaticCache.
        device: Device on which all buffers are allocated.
        dtype: dtype of the input/output hidden-state buffers.
        max_seq_len: Length of the static KV cache.
    """
    self.device = device
    # Fall back to the current CUDA device when `device` carries no index.
    explicit_index = torch.device(device).index
    self.device_index = (
        torch.cuda.current_device() if explicit_index is None else explicit_index
    )

    self.dtype = dtype
    self.max_seq_len = max_seq_len
    self.hidden_size = talker_config.hidden_size
    self.num_layers = talker_config.num_hidden_layers

    # Reference to the model whose forward is captured.
    self.model = talker_model

    # Fixed-size KV storage.
    self.static_cache = StaticCache(config=talker_config, max_cache_len=max_seq_len)

    def on_device(*shape, kind):
        return torch.zeros(*shape, dtype=kind, device=device)

    # Static hidden-state buffers read and written by the captured graph.
    self.input_buf = on_device(1, 1, self.hidden_size, kind=dtype)
    self.output_buf = on_device(1, 1, self.hidden_size, kind=dtype)

    # Updated before each replay.
    self.cache_position = on_device(1, kind=torch.long)
    # Rope deltas from prefill and the position-ids buffer.
    self.rope_deltas = on_device(1, 1, kind=torch.float32)
    self.position_ids = on_device(3, 1, 1, kind=torch.float32)

    # Capture state; filled in by capture() and the mask builders.
    self.graph = None
    self.captured = False
    self.attn_mask = None
    self.attn_mask_table = None
    self._mask_key = None
|
| 60 |
+
|
| 61 |
+
def _init_cache_layers(self):
|
| 62 |
+
"""Force lazy initialization of StaticCache layers before graph capture."""
|
| 63 |
+
config = self.model.config
|
| 64 |
+
num_kv_heads = getattr(config, 'num_key_value_heads', config.num_attention_heads)
|
| 65 |
+
head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads)
|
| 66 |
+
dummy_k = torch.zeros(1, num_kv_heads, 1, head_dim, dtype=self.dtype, device=self.device)
|
| 67 |
+
for layer in self.static_cache.layers:
|
| 68 |
+
if not layer.is_initialized:
|
| 69 |
+
layer.lazy_initialization(dummy_k)
|
| 70 |
+
|
| 71 |
+
def _build_attention_masks(self, attention_mask: torch.Tensor | None = None):
    """Precompute one decode mask per cache position and prime ``self.attn_mask``.

    Args:
        attention_mask: Optional mask forwarded unchanged to the mask builder.

    Fills ``self.attn_mask_table`` with ``max_seq_len`` entries. The static
    ``self.attn_mask`` buffer is created from entry 0 on first use and
    overwritten in place afterwards.
    """
    cfg = self.model.config
    probe = torch.zeros(1, 1, self.hidden_size, dtype=self.dtype, device=self.device)

    # Sliding-window models need the sliding-window variant of the mask.
    if cfg.sliding_window is None:
        build = create_causal_mask
    else:
        build = create_sliding_window_causal_mask

    self.attn_mask_table = [
        build(
            config=cfg,
            input_embeds=probe,
            attention_mask=attention_mask,
            cache_position=torch.tensor([slot], device=self.device),
            past_key_values=self.static_cache,
        )
        for slot in range(self.max_seq_len)
    ]

    first = self.attn_mask_table[0]
    if self.attn_mask is None:
        self.attn_mask = first.clone()
    else:
        # Keep the existing tensor object and overwrite its contents.
        self.attn_mask.copy_(first)
|
| 93 |
+
|
| 94 |
+
def _set_attention_mask(self, position: int):
|
| 95 |
+
self.attn_mask.copy_(self.attn_mask_table[position])
|
| 96 |
+
|
| 97 |
+
def _decode_step(self):
|
| 98 |
+
"""Single-token decode through the model's forward."""
|
| 99 |
+
out = self.model(
|
| 100 |
+
inputs_embeds=self.input_buf,
|
| 101 |
+
attention_mask=self.attn_mask,
|
| 102 |
+
past_key_values=self.static_cache,
|
| 103 |
+
cache_position=self.cache_position,
|
| 104 |
+
position_ids=self.position_ids,
|
| 105 |
+
use_cache=True,
|
| 106 |
+
)
|
| 107 |
+
self.output_buf.copy_(out.last_hidden_state)
|
| 108 |
+
|
| 109 |
+
@torch.inference_mode()
def capture(self, prefill_len=100, num_warmup=3):
    """
    Capture CUDA graph for single-token decode.

    prefill_len: simulated prefill length for warmup (graph is position-independent).
        Clamped into ``[0, max_seq_len - 1]`` so that a small ``max_seq_len``
        does not index past the precomputed mask table.
    num_warmup: number of eager decode steps run before capture.

    Side effects: initializes the static cache layers, builds the mask
    table, stores the recorded graph on ``self.graph`` and sets
    ``self.captured`` to True.
    """
    print(f"Warming up talker graph ({num_warmup} runs)...")

    # Force cache initialization before graph capture
    self._init_cache_layers()
    self._build_attention_masks()

    # The mask table has max_seq_len entries; keep the warmup position inside it.
    warm_pos = max(0, min(prefill_len, self.max_seq_len - 1))
    self.cache_position[0] = warm_pos
    self._set_attention_mask(warm_pos)

    for _ in range(num_warmup):
        self._decode_step()
    torch.cuda.synchronize()

    print("Capturing CUDA graph for talker decode...")

    with torch.cuda.device(self.device_index):
        self.graph = torch.cuda.CUDAGraph()

        side_stream = torch.cuda.Stream()
        side_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side_stream):
            # Warmup in capture stream
            self._decode_step()
            torch.cuda.synchronize()

            with torch.cuda.graph(self.graph):
                self._decode_step()

        torch.cuda.current_stream().wait_stream(side_stream)
        torch.cuda.synchronize()
    self.captured = True
    print("Talker CUDA graph captured!")
|
| 148 |
+
|
| 149 |
+
def reset(self, prefill_len: int):
    """Clear the static KV cache before a new sequence.

    ``prefill_len`` is accepted but is not read by this method.
    """
    kv_cache = self.static_cache
    kv_cache.reset()
|
| 152 |
+
|
| 153 |
+
def prefill_kv(self, past_key_values):
    """
    Load prefill keys/values into the static cache.

    The static cache is reset, then for each of ``self.num_layers`` layers the
    ``(key, value)`` pair read from ``past_key_values[layer]`` is written at
    cache positions ``0 .. seq_len - 1``.

    Args:
        past_key_values: indexable by layer, yielding ``(key, value)`` pairs
            whose third dimension is the sequence length
            (documented upstream as [1, kv_heads, seq_len, head_dim]).

    Returns:
        The sequence length of the last layer processed (0 if there are no layers).

    Raises:
        RuntimeError: if a layer's sequence length exceeds ``self.max_seq_len``.
    """
    cache = self.static_cache
    cache.reset()

    filled = 0
    for layer_idx in range(self.num_layers):
        keys, values = past_key_values[layer_idx]
        filled = keys.shape[2]
        if filled > self.max_seq_len:
            raise RuntimeError(
                f"Input is too long: prefill has {filled} tokens but max_seq_len={self.max_seq_len}. "
                "Use shorter text or shorter reference audio."
            )
        positions = torch.arange(filled, device=self.device)
        cache.update(keys, values, layer_idx, {"cache_position": positions})

    return filled
|
| 171 |
+
|
| 172 |
+
def set_generation_state(self, attention_mask: torch.Tensor, rope_deltas: torch.Tensor | None):
|
| 173 |
+
"""Set padding-aware attention mask and rope deltas for decode parity."""
|
| 174 |
+
mask_key = None
|
| 175 |
+
full_attention_mask = None
|
| 176 |
+
if attention_mask is not None:
|
| 177 |
+
pad_counts = (attention_mask == 0).sum(dim=-1)
|
| 178 |
+
mask_key = tuple(pad_counts.tolist())
|
| 179 |
+
full_attention_mask = torch.ones(
|
| 180 |
+
attention_mask.shape[0],
|
| 181 |
+
self.max_seq_len,
|
| 182 |
+
dtype=attention_mask.dtype,
|
| 183 |
+
device=attention_mask.device,
|
| 184 |
+
)
|
| 185 |
+
for b, pads in enumerate(pad_counts.tolist()):
|
| 186 |
+
if pads > 0:
|
| 187 |
+
full_attention_mask[b, :pads] = 0
|
| 188 |
+
if self.attn_mask_table is None or mask_key != self._mask_key:
|
| 189 |
+
self._build_attention_masks(full_attention_mask)
|
| 190 |
+
self._mask_key = mask_key
|
| 191 |
+
if rope_deltas is None:
|
| 192 |
+
self.rope_deltas.zero_()
|
| 193 |
+
else:
|
| 194 |
+
if rope_deltas.dim() == 1:
|
| 195 |
+
rope_deltas = rope_deltas.unsqueeze(1)
|
| 196 |
+
self.rope_deltas.copy_(rope_deltas.to(self.rope_deltas.device, dtype=self.rope_deltas.dtype))
|
| 197 |
+
|
| 198 |
+
@torch.inference_mode()
def run(self, input_embeds: torch.Tensor, position: int) -> torch.Tensor:
    """
    Replay the captured graph for one decode step.

    Args:
        input_embeds: [1, 1, hidden_size] embedding copied into the static input buffer.
        position: current sequence position; selects the attention mask and
            is written to ``cache_position[0]``.

    Returns:
        ``self.output_buf`` ([1, 1, hidden_size] hidden states). This is the
        static buffer itself, so use it immediately or clone it.
    """
    self.input_buf.copy_(input_embeds)
    self.cache_position[0] = position
    self._set_attention_mask(position)

    # position_ids = arange(seq_len=1) + cache_position + rope_deltas
    offset = self.cache_position[0].to(self.rope_deltas.dtype)
    rope_position = self.rope_deltas + offset
    self.position_ids.copy_(rope_position.unsqueeze(0).expand(3, -1, -1))

    self.graph.replay()
    return self.output_buf
|
faster_qwen3_tts/utils.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import contextlib
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class _FilteredStdout:
    """Text-stream proxy that silently drops writes containing given substrings.

    Only ``write`` is filtered. ``flush`` and every other public attribute
    (``isatty``, ``encoding``, ``fileno``, ...) are forwarded to the wrapped
    stream, so code that inspects ``sys.stdout`` keeps working while the
    proxy is installed.

    Args:
        stream: The underlying writable text stream.
        suppress_substrings: Iterable of strings; a write whose data contains
            any of them is discarded.
    """

    def __init__(self, stream, suppress_substrings):
        self._stream = stream
        self._suppress = suppress_substrings

    def write(self, data):
        """Write ``data`` unless it contains a suppressed substring.

        Returns the wrapped stream's result, or ``len(data)`` for a dropped
        write so the caller sees it as fully written.
        """
        if any(s in data for s in self._suppress):
            return len(data)
        return self._stream.write(data)

    def flush(self):
        """Flush the wrapped stream."""
        return self._stream.flush()

    def __getattr__(self, name):
        """Forward any attribute not defined on the proxy to the wrapped stream."""
        # Private names are never forwarded; this also avoids infinite
        # recursion if ``_stream`` itself has not been set yet.
        if name.startswith("_"):
            raise AttributeError(name)
        return getattr(self._stream, name)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@contextlib.contextmanager
def suppress_flash_attn_warning():
    """Context manager that hides flash-attn notices written to stdout.

    While active, ``sys.stdout`` is replaced by a ``_FilteredStdout`` wrapping
    the stream that was current on entry; any single write containing one of
    the listed fragments is dropped.
    """
    noisy_fragments = (
        "flash-attn is not installed",
        "manual PyTorch version",
        "Please install flash-attn",
    )
    proxy = _FilteredStdout(sys.stdout, suppress_substrings=noisy_fragments)
    with contextlib.redirect_stdout(proxy):
        yield
|
main.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# PYTORCH_ENABLE_MPS_FALLBACK=1 uvicorn main:app --host 0.0.0.0 --port 8888 --reload
|
| 2 |
+
# PYTORCH_ENABLE_MPS_FALLBACK=1 gunicorn main:app -b 0.0.0.0:8000 -w 4 -k uvicorn.workers.UvicornWorker
|
| 3 |
+
import io
|
| 4 |
+
import re
|
| 5 |
+
import os
|
| 6 |
+
import logging
|
| 7 |
+
import json
|
| 8 |
+
from time import gmtime
|
| 9 |
+
from datetime import datetime, timezone
|
| 10 |
+
from scipy.io import wavfile
|
| 11 |
+
from dotenv import load_dotenv
|
| 12 |
+
from contextlib import asynccontextmanager
|
| 13 |
+
from tts import synthesize, device
|
| 14 |
+
from huggingface_hub import hf_hub_download
|
| 15 |
+
from llama_cpp import Llama
|
| 16 |
+
|
| 17 |
+
from fastapi import FastAPI, Response, Body, UploadFile, HTTPException
|
| 18 |
+
from starlette.middleware.cors import CORSMiddleware
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
load_dotenv(verbose=False)

LOGGING_DIRECTORY = os.getenv('LOGGING_DIRECTORY', 'logs')

# exist_ok=True makes directory creation safe when several worker processes
# import this module at the same time (a separate isdir() check can race).
os.makedirs(LOGGING_DIRECTORY, exist_ok=True)

# Append to <LOGGING_DIRECTORY>/api.log with UTC, ISO-8601 style timestamps.
file_handler = logging.FileHandler(os.path.join(LOGGING_DIRECTORY, 'api.log'), mode='a', encoding='utf-8')
formatter = logging.Formatter(fmt='%(asctime)s.%(msecs)03dZ - %(levelname)s - %(message)s', datefmt='%Y-%m-%dT%H:%M:%S')
formatter.converter = gmtime  # render asctime in UTC to match the trailing "Z"
file_handler.setFormatter(formatter)
#logger = logging.getLogger('uvicorn')
logger = logging.getLogger('gunicorn.error')
logger.addHandler(file_handler)

# Prompt template selector ('Llama' or anything else) and optional local model path.
llm_prompt_format = os.getenv('LLM_PROMPT_FORMAT', None)
model_path = os.environ.get('LLAMACPP_PATH', None)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup hook: run sample syntheses and resolve the llama.cpp model path.

    For every ``.wav`` file found in a sub-directory of ``data/`` (the
    sub-directory name is passed as the language), ``synthesize`` is called
    once with that directory's ``prompt.txt`` and ``input.txt``; the result is
    discarded (presumably a warm-up run). If ``model_path`` was not provided
    through the environment, the model file is downloaded from the Hugging
    Face Hub into ``./models``.
    """
    global model_path

    base_directory = 'data'

    for language in os.listdir(base_directory):
        language_dir = os.path.join(base_directory, language)
        if not os.path.isdir(language_dir):
            continue

        for filename in os.listdir(language_dir):
            suffix = os.path.splitext(filename)[1]
            if suffix.lower() != '.wav':
                continue

            with open(os.path.join(language_dir, filename), mode='rb') as f:
                with io.BytesIO() as wave_bytes:
                    with open(os.path.join(language_dir, 'prompt.txt'), 'r', encoding='utf-8') as prompt_file:
                        with open(os.path.join(language_dir, 'input.txt'), 'r', encoding='utf-8') as input_file:
                            wave_bytes.write(f.read())
                            wave_bytes.seek(0)

                            synthesize(
                                prompt_wave=wave_bytes,
                                prompt_text=prompt_file.read(),
                                prompt_language=language,
                                input_text=input_file.read(),
                                input_language=language,
                                top_p=1,
                                temperature=1,
                            )

    if model_path is None:
        model_path = hf_hub_download(
            repo_id=os.environ['LLAMACPP_REPO_ID'],
            filename=os.environ['LLAMACPP_FILENAME'],
            local_dir='./models',
        )

    yield
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# Application instance; `lifespan` runs its startup work before requests are served.
app = FastAPI(lifespan=lifespan)
# CORS is fully open: any origin, method and header is accepted, with credentials allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very permissive — confirm this is intended outside development.
app.add_middleware(CORSMiddleware, allow_origins=['*'], allow_credentials=True, allow_methods=['*'], allow_headers=['*'])
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
@app.get("/device")
async def read_device():
    """Return the device imported from the ``tts`` module and the current UTC epoch time in seconds."""
    now_utc = datetime.now(timezone.utc)
    return {'device': str(device), 'timestamp': int(now_utc.timestamp())}
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@app.post("/generate", status_code=201)
def create_generated_text(messages: list[dict[str, str]] = Body(...), temperature: float = Body(default=1.0)):
    """Generate an assistant reply for a chat history with the llama.cpp model.

    The history is rendered into one prompt string. When ``llm_prompt_format``
    is ``'Llama'`` the ``<|start_header_id|>`` / ``<|eot_id|>`` markers are
    used; otherwise the ``<start_of_turn>`` / ``<end_of_turn>`` markers are
    used, with ``system`` messages rendered as ``user`` turns. Messages whose
    role is not ``system``, ``user`` or ``assistant`` are ignored.

    The model is called with ``echo=True``, so the returned text contains the
    prompt followed by the completion; the reply is taken from the last
    assistant turn found in that text.

    Args:
        messages: Chat history; each item needs ``'role'`` and ``'content'`` keys.
        temperature: Sampling temperature forwarded to the model.

    Returns:
        ``{'choices': [{'role': 'assistant', 'content': ...}, ...], 'timestamp': <UTC epoch seconds>}``.

    Raises:
        HTTPException: 400 if the rendered prompt is empty.
    """
    if llm_prompt_format == 'Llama':
        turns = [
            f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n{message['content']}<|eot_id|>"
            for message in messages
            if message['role'] in ('system', 'user', 'assistant')
        ]
        reply_header = '<|start_header_id|>assistant<|end_header_id|>\n\n'
        turn_end = '<|eot_id|>'
    else:
        speakers = {'system': 'user', 'user': 'user', 'assistant': 'model'}
        turns = [
            f"<start_of_turn>{speakers[message['role']]}\n{message['content']}<end_of_turn>\n"
            for message in messages
            if message['role'] in speakers
        ]
        reply_header = '<start_of_turn>model\n'
        turn_end = '<end_of_turn>'

    input_text = ''.join(turns) + reply_header

    # The markers contain regex metacharacters (notably "|"); left unescaped
    # they are parsed as alternation and the capture group matches the wrong
    # text, so the pattern is assembled from escaped literals.
    pattern = re.escape(reply_header) + r'(.+?)(?:' + re.escape(turn_end) + r'|$)'

    # NOTE(review): reply_header is always appended, so this guard cannot
    # trigger today; if empty `messages` should be rejected, test that instead.
    if not input_text:
        raise HTTPException(status_code=400)

    llm = Llama(model_path=model_path, n_ctx=8192, n_gpu_layers=-1, n_batch=32, verbose=False)
    choices = []

    try:
        completion = llm(input_text, max_tokens=2048, temperature=temperature, top_p=0.95, echo=True)

        for choice in completion['choices']:
            matches = re.findall(pattern, choice['text'], re.DOTALL)

            if matches:
                # The last assistant turn is the newly generated one.
                choices.append({'role': 'assistant', 'content': matches[-1]})

    finally:
        # Always release the model, even if generation fails.
        llm.close()

    return {'choices': choices, 'timestamp': int(datetime.now(timezone.utc).timestamp())}
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
@app.post("/synthesize", status_code=201)
def create_uploaded_file(file: UploadFile, data = Body(...)):
    """Synthesize speech from an uploaded WAV prompt and a JSON parameter string.

    Args:
        file: Uploaded prompt audio; its content type must be ``audio/wav``.
        data: JSON-encoded object with required keys ``language`` and
            ``input`` and optional keys ``prompt`` (defaults to ``None``),
            ``top_p`` and ``temperature`` (both default to ``1.0``).

    Returns:
        A ``Response`` whose body is the synthesized audio as WAV bytes.

    Raises:
        HTTPException: 400 when the content type is not ``audio/wav`` or when
            parsing or synthesis fails (the error text is returned as ``detail``).
    """
    if file.content_type != 'audio/wav':
        raise HTTPException(status_code=400)

    try:
        data = json.loads(data)

        with io.BytesIO() as prompt_wave_bytes, io.BytesIO() as output_wave_bytes:
            prompt_wave_bytes.write(file.file.read())
            prompt_wave_bytes.seek(0)

            output, sample_rate = synthesize(
                prompt_wave=prompt_wave_bytes,
                prompt_text=data.get('prompt'),
                prompt_language=data['language'],
                input_text=data['input'],
                input_language=data['language'],
                top_p=data.get('top_p', 1.0),
                temperature=data.get('temperature', 1.0),
            )

            wavfile.write(output_wave_bytes, sample_rate, output)
            output_wave_bytes.seek(0)

            # Read the bytes before the buffers are closed by the with-block.
            return Response(content=output_wave_bytes.read(), media_type="audio/wav")

    except Exception as e:
        # Log through the module logger so the failure (with traceback)
        # reaches the api.log file handler rather than only the root logger.
        logger.exception('%s', e)

        raise HTTPException(status_code=400, detail=str(e)) from e
|
qwen_tts/__init__.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2026 The Alibaba Qwen team.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
"""
|
| 18 |
+
qwen_tts: Qwen-TTS package.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from .inference.qwen3_tts_model import Qwen3TTSModel, VoiceClonePromptItem
|
| 22 |
+
from .inference.qwen3_tts_tokenizer import Qwen3TTSTokenizer
|
| 23 |
+
|
| 24 |
+
# Every name listed in __all__ must exist in this module, otherwise
# `from qwen_tts import *` raises AttributeError. Export the public classes
# imported above.
__all__ = ["Qwen3TTSModel", "VoiceClonePromptItem", "Qwen3TTSTokenizer"]
|
qwen_tts/__main__.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2026 The Alibaba Qwen team.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
def main():
    """Print a short notice that points users at the CLI entry points."""
    notice_lines = (
        "qwen_tts package.",
        "Use CLI entrypoints:",
        " - qwen-tts-demo",
        "",
    )
    print("\n".join(notice_lines))


if __name__ == "__main__":
    main()
|
qwen_tts/core/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2026 The Alibaba Qwen team.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
from .tokenizer_25hz.configuration_qwen3_tts_tokenizer_v1 import Qwen3TTSTokenizerV1Config
|
| 17 |
+
from .tokenizer_25hz.modeling_qwen3_tts_tokenizer_v1 import Qwen3TTSTokenizerV1Model
|
| 18 |
+
from .tokenizer_12hz.configuration_qwen3_tts_tokenizer_v2 import Qwen3TTSTokenizerV2Config
|
| 19 |
+
from .tokenizer_12hz.modeling_qwen3_tts_tokenizer_v2 import Qwen3TTSTokenizerV2Model
|
qwen_tts/core/models/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2026 The Alibaba Qwen team.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
from .configuration_qwen3_tts import Qwen3TTSConfig
|
| 17 |
+
from .modeling_qwen3_tts import Qwen3TTSForConditionalGeneration
|
| 18 |
+
from .processing_qwen3_tts import Qwen3TTSProcessor
|
qwen_tts/core/models/configuration_qwen3_tts.py
ADDED
|
@@ -0,0 +1,502 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
from transformers.configuration_utils import PretrainedConfig, layer_type_validation
|
| 16 |
+
from transformers.modeling_rope_utils import rope_config_validation
|
| 17 |
+
from transformers.utils import logging
|
| 18 |
+
|
| 19 |
+
logger = logging.get_logger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class Qwen3TTSSpeakerEncoderConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen3TTSSpeakerEncoder`].
    It is used to instantiate a Qwen3TTS speaker encoder model according to the specified arguments, defining the model
    architecture. The architecture is based on the ECAPA-TDNN model.

    Args:
        mel_dim (`int`, *optional*, defaults to 128):
            The dimension of the input mel-spectrogram.
        enc_dim (`int`, *optional*, defaults to 1024):
            The dimension of the final speaker embedding.
        enc_channels (`list[int]`, *optional*, defaults to `[512, 512, 512, 512, 1536]`):
            A list of output channels for each TDNN/SERes2Net layer in the encoder. The first channel size is for the initial TDNN layer,
            the intermediate ones for the `SqueezeExcitationRes2NetBlock` layers, and the last one for the multi-layer feature aggregation.
        enc_kernel_sizes (`list[int]`, *optional*, defaults to `[5, 3, 3, 3, 1]`):
            A list of kernel sizes for each layer in the encoder, corresponding to `enc_channels`.
        enc_dilations (`list[int]`, *optional*, defaults to `[1, 2, 3, 4, 1]`):
            A list of dilations for each layer in the encoder, corresponding to `enc_channels`.
        enc_attention_channels (`int`, *optional*, defaults to 128):
            The number of attention channels in the `AttentiveStatisticsPooling` layer.
        enc_res2net_scale (`int`, *optional*, defaults to 8):
            The scale of the `Res2NetBlock` in the encoder.
        enc_se_channels (`int`, *optional*, defaults to 128):
            The number of channels in the squeeze part of the `SqueezeExcitationBlock`.
        sample_rate (`int`, *optional*, defaults to 24000):
            Stored on the config as `sample_rate`. Presumably the sampling rate in Hz of the audio the
            encoder consumes — confirm against the speaker encoder implementation.
    """
    def __init__(
        self,
        mel_dim=128,
        enc_dim=1024,
        enc_channels=None,
        enc_kernel_sizes=None,
        enc_dilations=None,
        enc_attention_channels=128,
        enc_res2net_scale=8,
        enc_se_channels=128,
        sample_rate=24000,
    ):
        # List defaults are created per call: a list literal in the signature
        # would be a single object shared (and mutable) across all instances.
        if enc_channels is None:
            enc_channels = [512, 512, 512, 512, 1536]
        if enc_kernel_sizes is None:
            enc_kernel_sizes = [5, 3, 3, 3, 1]
        if enc_dilations is None:
            enc_dilations = [1, 2, 3, 4, 1]

        self.mel_dim = mel_dim
        self.enc_dim = enc_dim
        self.enc_channels = enc_channels
        self.enc_kernel_sizes = enc_kernel_sizes
        self.enc_dilations = enc_dilations
        self.enc_attention_channels = enc_attention_channels
        self.enc_res2net_scale = enc_res2net_scale
        self.enc_se_channels = enc_se_channels
        self.sample_rate = sample_rate
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class Qwen3TTSTalkerCodePredictorConfig(PretrainedConfig):
|
| 71 |
+
r"""
|
| 72 |
+
This is the configuration class to store the configuration of a [`Qwen3TTSTalkerCodePredictorModel`]. It is used to instantiate a
|
| 73 |
+
Qwen3TTSTalkerCodePredictor model according to the specified arguments, defining the model architecture.
|
| 74 |
+
|
| 75 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 76 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
vocab_size (`int`, *optional*, defaults to 151936):
|
| 81 |
+
Vocabulary size of the Qwen3TTSTalkerCodePredictor model. Defines the number of different tokens that can be represented by the
|
| 82 |
+
`inputs_ids` passed when calling [`Qwen3TTSTalkerCodePredictorModel`]
|
| 83 |
+
hidden_size (`int`, *optional*, defaults to 4096):
|
| 84 |
+
Dimension of the hidden representations.
|
| 85 |
+
intermediate_size (`int`, *optional*, defaults to 22016):
|
| 86 |
+
Dimension of the MLP representations.
|
| 87 |
+
num_hidden_layers (`int`, *optional*, defaults to 32):
|
| 88 |
+
Number of hidden layers in the Transformer encoder.
|
| 89 |
+
num_attention_heads (`int`, *optional*, defaults to 32):
|
| 90 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 91 |
+
num_key_value_heads (`int`, *optional*, defaults to 32):
|
| 92 |
+
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
| 93 |
+
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
| 94 |
+
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
| 95 |
+
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
| 96 |
+
by meanpooling all the original heads within that group. For more details, check out [this
|
| 97 |
+
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`.
|
| 98 |
+
head_dim (`int`, *optional*, defaults to 128):
|
| 99 |
+
The attention head dimension.
|
| 100 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
| 101 |
+
The non-linear activation function (function or string) in the decoder.
|
| 102 |
+
max_position_embeddings (`int`, *optional*, defaults to 32768):
|
| 103 |
+
The maximum sequence length that this model might ever be used with.
|
| 104 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 105 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 106 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
| 107 |
+
The epsilon used by the rms normalization layers.
|
| 108 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
| 109 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
| 110 |
+
relevant if `config.is_decoder=True`.
|
| 111 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
| 112 |
+
Whether the model's input and output word embeddings should be tied.
|
| 113 |
+
rope_theta (`float`, *optional*, defaults to 10000.0):
|
| 114 |
+
The base period of the RoPE embeddings.
|
| 115 |
+
rope_scaling (`Dict`, *optional*):
|
| 116 |
+
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
| 117 |
+
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
| 118 |
+
accordingly.
|
| 119 |
+
Expected contents:
|
| 120 |
+
`rope_type` (`str`):
|
| 121 |
+
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
| 122 |
+
'llama3'], with 'default' being the original RoPE implementation.
|
| 123 |
+
`factor` (`float`, *optional*):
|
| 124 |
+
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
| 125 |
+
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
| 126 |
+
original maximum pre-trained length.
|
| 127 |
+
`original_max_position_embeddings` (`int`, *optional*):
|
| 128 |
+
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
| 129 |
+
pretraining.
|
| 130 |
+
`attention_factor` (`float`, *optional*):
|
| 131 |
+
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
| 132 |
+
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
| 133 |
+
`factor` field to infer the suggested value.
|
| 134 |
+
`beta_fast` (`float`, *optional*):
|
| 135 |
+
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
| 136 |
+
ramp function. If unspecified, it defaults to 32.
|
| 137 |
+
`beta_slow` (`float`, *optional*):
|
| 138 |
+
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
| 139 |
+
ramp function. If unspecified, it defaults to 1.
|
| 140 |
+
`short_factor` (`list[float]`, *optional*):
|
| 141 |
+
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
| 142 |
+
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
| 143 |
+
size divided by the number of attention heads divided by 2
|
| 144 |
+
`long_factor` (`list[float]`, *optional*):
|
| 145 |
+
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
| 146 |
+
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
| 147 |
+
size divided by the number of attention heads divided by 2
|
| 148 |
+
`low_freq_factor` (`float`, *optional*):
|
| 149 |
+
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
| 150 |
+
`high_freq_factor` (`float`, *optional*):
|
| 151 |
+
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
| 152 |
+
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
| 153 |
+
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
| 154 |
+
use_sliding_window (`bool`, *optional*, defaults to `False`):
|
| 155 |
+
Whether to use sliding window attention.
|
| 156 |
+
sliding_window (`int`, *optional*, defaults to 4096):
|
| 157 |
+
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
|
| 158 |
+
max_window_layers (`int`, *optional*, defaults to 28):
|
| 159 |
+
The number of layers using full attention. The first `max_window_layers` layers will use full attention, while any
|
| 160 |
+
additional layer afterwards will use SWA (Sliding Window Attention).
|
| 161 |
+
layer_types (`list`, *optional*):
|
| 162 |
+
Attention pattern for each layer.
|
| 163 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 164 |
+
The dropout ratio for the attention probabilities.
|
| 165 |
+
|
| 166 |
+
"""
|
| 167 |
+
|
| 168 |
+
model_type = "qwen3_tts_talker_code_predictor"
|
| 169 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 170 |
+
|
| 171 |
+
# Default tensor parallel plan for base model `Qwen3TTSTalkerCodePredictor`
|
| 172 |
+
base_model_tp_plan = {
|
| 173 |
+
"layers.*.self_attn.q_proj": "colwise",
|
| 174 |
+
"layers.*.self_attn.k_proj": "colwise",
|
| 175 |
+
"layers.*.self_attn.v_proj": "colwise",
|
| 176 |
+
"layers.*.self_attn.o_proj": "rowwise",
|
| 177 |
+
"layers.*.mlp.gate_proj": "colwise",
|
| 178 |
+
"layers.*.mlp.up_proj": "colwise",
|
| 179 |
+
"layers.*.mlp.down_proj": "rowwise",
|
| 180 |
+
}
|
| 181 |
+
base_model_pp_plan = {
|
| 182 |
+
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
| 183 |
+
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
| 184 |
+
"norm": (["hidden_states"], ["hidden_states"]),
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
def __init__(
    self,
    vocab_size=2048,
    hidden_size=1024,
    intermediate_size=3072,
    num_hidden_layers=5,
    num_attention_heads=16,
    num_key_value_heads=8,
    head_dim=128,
    hidden_act="silu",
    max_position_embeddings=32768,
    initializer_range=0.02,
    rms_norm_eps=0.000001,
    use_cache=True,
    tie_word_embeddings=False,
    rope_theta=10000,
    rope_scaling=None,
    attention_bias=False,
    use_sliding_window=False,
    sliding_window=4096,
    max_window_layers=28,
    layer_types=None,
    attention_dropout=0,
    num_code_groups=32,
    **kwargs,
):
    """Store the code-predictor hyper-parameters on the configuration object.

    `tie_word_embeddings` and any extra `kwargs` are forwarded to the parent
    constructor; every other argument is saved as an attribute of the same
    name. `sliding_window` is only kept when `use_sliding_window` is true,
    `num_key_value_heads` falls back to `num_attention_heads` when `None`,
    and `layer_types` is derived from the sliding-window settings when it is
    not given explicitly.
    """
    super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

    # Core transformer dimensions.
    self.vocab_size = vocab_size
    self.max_position_embeddings = max_position_embeddings
    self.hidden_size = hidden_size
    self.intermediate_size = intermediate_size
    self.num_hidden_layers = num_hidden_layers
    self.num_attention_heads = num_attention_heads

    # The window size is only meaningful when sliding-window attention is on.
    self.use_sliding_window = use_sliding_window
    if use_sliding_window:
        self.sliding_window = sliding_window
    else:
        self.sliding_window = None
    self.max_window_layers = max_window_layers

    # Backward compatibility: a missing KV-head count means plain multi-head attention.
    self.num_key_value_heads = (
        num_attention_heads if num_key_value_heads is None else num_key_value_heads
    )

    self.head_dim = head_dim
    self.hidden_act = hidden_act
    self.initializer_range = initializer_range
    self.rms_norm_eps = rms_norm_eps
    self.use_cache = use_cache
    self.rope_theta = rope_theta
    self.rope_scaling = rope_scaling
    self.attention_bias = attention_bias
    self.attention_dropout = attention_dropout

    # Backward compatibility: mirror a legacy 'type' key into 'rope_type',
    # then let the shared helper validate the RoPE parameters.
    scaling = self.rope_scaling
    if scaling is not None and "type" in scaling:
        scaling["rope_type"] = scaling["type"]
    rope_config_validation(self)

    # Without an explicit pattern, layers at index >= max_window_layers use
    # sliding attention (only if a window is configured); the rest use full attention.
    if layer_types is None:
        has_window = self.sliding_window is not None
        layer_types = []
        for layer_idx in range(self.num_hidden_layers):
            if has_window and layer_idx >= self.max_window_layers:
                layer_types.append("sliding_attention")
            else:
                layer_types.append("full_attention")
    self.layer_types = layer_types
    layer_type_validation(self.layer_types)

    self.num_code_groups = num_code_groups
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
class Qwen3TTSTalkerConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen3TTSTalkerModel`]. It is used to instantiate a
    Qwen3TTSTalker model according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    NOTE(review): the default codec special-token ids (4196-4205) are larger than the default `vocab_size` (3072);
    real checkpoints are presumably expected to override these values -- confirm against the released configs.

    Args:
        code_predictor_config (`dict` or [`Qwen3TTSTalkerCodePredictorConfig`], *optional*):
            Configuration of the code predictor sub-model. A `dict` is expanded into a
            [`Qwen3TTSTalkerCodePredictorConfig`], an existing instance is used as is, and `None` creates one with
            default values.
        vocab_size (`int`, *optional*, defaults to 3072):
            Vocabulary size of the Qwen3TTSTalker model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`Qwen3TTSTalkerModel`]
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 2048):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 20):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 2):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `2`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly. A legacy `type` key is mirrored into `rope_type` in place; no further validation is run here.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`list[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`list[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (>
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention (SWA) window size. Stored as `None` unless `use_sliding_window` is `True`.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        num_code_groups (`int`, *optional*, defaults to 32):
            Number of code groups; stored unchanged (presumably the number of codebooks per frame -- confirm
            against the modeling code).
        text_hidden_size (`int`, *optional*, defaults to 2048):
            Stored unchanged (presumably the hidden size of the text embeddings fed to the talker -- confirm).
        codec_eos_token_id (`int`, *optional*, defaults to 4198):
            Codec end-of-sequence token id.
        codec_think_id (`int`, *optional*, defaults to 4202):
            Codec "think" token id.
        codec_nothink_id (`int`, *optional*, defaults to 4203):
            Codec "no-think" token id.
        codec_think_bos_id (`int`, *optional*, defaults to 4204):
            Codec "think" begin-of-sequence token id.
        codec_think_eos_id (`int`, *optional*, defaults to 4205):
            Codec "think" end-of-sequence token id.
        codec_pad_id (`int`, *optional*, defaults to 4196):
            Codec padding token id.
        codec_bos_id (`int`, *optional*, defaults to 4197):
            Codec begin-of-sequence token id.
        spk_id (*optional*):
            Speaker id information; stored unchanged (structure not visible here -- confirm against callers).
        spk_is_dialect (*optional*):
            Per-speaker dialect information; stored unchanged (structure not visible here -- confirm against callers).
        codec_language_id (*optional*):
            Language id information for the codec; stored unchanged (structure not visible here -- confirm).
    """

    model_type = "qwen3_tts_talker"
    keys_to_ignore_at_inference = ["past_key_values"]

    # Default tensor parallel plan for base model `Qwen3TTSTalker`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_proj": "colwise",
        "layers.*.mlp.up_proj": "colwise",
        "layers.*.mlp.down_proj": "rowwise",
    }
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }
    sub_configs = {"code_predictor_config": Qwen3TTSTalkerCodePredictorConfig}

    def __init__(
        self,
        code_predictor_config=None,
        vocab_size=3072,
        hidden_size=1024,
        intermediate_size=2048,
        num_hidden_layers=20,
        num_attention_heads=16,
        num_key_value_heads=2,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=0.000001,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000,
        rope_scaling=None,
        attention_bias=False,
        use_sliding_window=False,
        sliding_window=4096,
        attention_dropout=0,
        num_code_groups=32,
        text_hidden_size=2048,
        codec_eos_token_id=4198,
        codec_think_id=4202,
        codec_nothink_id=4203,
        codec_think_bos_id=4204,
        codec_think_eos_id=4205,
        codec_pad_id=4196,
        codec_bos_id=4197,
        spk_id=None,
        spk_is_dialect=None,
        codec_language_id=None,
        **kwargs,
    ):
        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        # The window size is only kept when sliding-window attention is enabled.
        self.sliding_window = sliding_window if use_sliding_window else None

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        # BC: if there is a 'type' field, mirror it to 'rope_type'.
        # Unlike the code predictor config, no RoPE validation helper is called here.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]

        # Accept nothing (defaults), a ready-made config object, or a plain dict of arguments.
        if code_predictor_config is None:
            self.code_predictor_config = Qwen3TTSTalkerCodePredictorConfig()
            logger.info("code_predictor_config is None. Initializing code_predictor model with default values")
        elif isinstance(code_predictor_config, Qwen3TTSTalkerCodePredictorConfig):
            self.code_predictor_config = code_predictor_config
        else:
            self.code_predictor_config = Qwen3TTSTalkerCodePredictorConfig(**code_predictor_config)

        self.num_code_groups = num_code_groups
        self.text_hidden_size = text_hidden_size

        # Codec special-token ids.
        self.codec_eos_token_id = codec_eos_token_id
        self.codec_think_id = codec_think_id
        self.codec_language_id = codec_language_id
        self.codec_nothink_id = codec_nothink_id
        self.codec_think_bos_id = codec_think_bos_id
        self.codec_think_eos_id = codec_think_eos_id
        self.codec_pad_id = codec_pad_id
        self.codec_bos_id = codec_bos_id

        # Speaker metadata.
        self.spk_id = spk_id
        self.spk_is_dialect = spk_is_dialect
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
class Qwen3TTSConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`Qwen3TTSForConditionalGeneration`].

    It is a composite configuration that bundles the talker and speaker-encoder sub-configurations together with
    a few top-level token ids and descriptive fields.

    Args:
        talker_config (`dict` or [`Qwen3TTSTalkerConfig`], *optional*):
            Configuration of the talker sub-model. A `dict` is expanded into a [`Qwen3TTSTalkerConfig`], an
            existing instance is used as is, and `None` creates one with default values.
        speaker_encoder_config (`dict` or [`Qwen3TTSSpeakerEncoderConfig`], *optional*):
            Configuration of the speaker encoder sub-model, handled the same way as `talker_config`.
        tokenizer_type (*optional*):
            Stored unchanged (presumably identifies the speech tokenizer variant -- confirm against callers).
        tts_model_size (*optional*):
            Stored unchanged.
        tts_model_type (*optional*):
            Stored unchanged.
        im_start_token_id (`int`, *optional*, defaults to 151644):
            Id of the `im_start` token.
        im_end_token_id (`int`, *optional*, defaults to 151645):
            Id of the `im_end` token.
        tts_pad_token_id (`int`, *optional*, defaults to 151671):
            Id of the TTS padding token.
        tts_bos_token_id (`int`, *optional*, defaults to 151672):
            Id of the TTS begin-of-sequence token.
        tts_eos_token_id (`int`, *optional*, defaults to 151673):
            Id of the TTS end-of-sequence token.
    """

    model_type = "qwen3_tts"
    sub_configs = {
        "talker_config": Qwen3TTSTalkerConfig,
        "speaker_encoder_config": Qwen3TTSSpeakerEncoderConfig,
    }

    def __init__(
        self,
        talker_config=None,
        speaker_encoder_config=None,
        tokenizer_type=None,
        tts_model_size=None,
        tts_model_type=None,
        im_start_token_id=151644,
        im_end_token_id=151645,
        tts_pad_token_id=151671,
        tts_bos_token_id=151672,
        tts_eos_token_id=151673,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if talker_config is None:
            talker_config = {}
            logger.info("talker_config is None. Initializing talker model with default values")
        if speaker_encoder_config is None:
            speaker_encoder_config = {}
            logger.info("speaker_encoder_config is None. Initializing speaker encoder model with default values")

        # Accept ready-made config objects as well as plain dicts; `**` expansion
        # of a config instance would raise a TypeError.
        if not isinstance(talker_config, Qwen3TTSTalkerConfig):
            talker_config = Qwen3TTSTalkerConfig(**talker_config)
        if not isinstance(speaker_encoder_config, Qwen3TTSSpeakerEncoderConfig):
            speaker_encoder_config = Qwen3TTSSpeakerEncoderConfig(**speaker_encoder_config)

        self.talker_config = talker_config
        self.speaker_encoder_config = speaker_encoder_config

        self.tokenizer_type = tokenizer_type
        self.tts_model_size = tts_model_size
        self.tts_model_type = tts_model_type

        self.im_start_token_id = im_start_token_id
        self.im_end_token_id = im_end_token_id
        self.tts_pad_token_id = tts_pad_token_id
        self.tts_bos_token_id = tts_bos_token_id
        self.tts_eos_token_id = tts_eos_token_id
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
# Public API of this module. The code predictor config is exported as well because
# it is a public sub-configuration referenced by `Qwen3TTSTalkerConfig.sub_configs`.
__all__ = [
    "Qwen3TTSConfig",
    "Qwen3TTSTalkerConfig",
    "Qwen3TTSTalkerCodePredictorConfig",
    "Qwen3TTSSpeakerEncoderConfig",
]
|
qwen_tts/core/models/modeling_qwen3_tts.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
qwen_tts/core/models/processing_qwen3_tts.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
from transformers.feature_extraction_utils import BatchFeature
|
| 16 |
+
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class Qwen3TTSProcessorKwargs(ProcessingKwargs, total=False):
    """Keyword-argument schema for `Qwen3TTSProcessor` with its default text settings."""

    # Defaults merged into the tokenizer call: no padding, and left-side padding when enabled.
    _defaults = {
        "text_kwargs": dict(padding=False, padding_side="left"),
    }
|
| 26 |
+
|
| 27 |
+
class Qwen3TTSProcessor(ProcessorMixin):
    r"""
    Constructs a Qwen3TTS processor.

    The processor wraps a single text tokenizer; it has no audio or image feature extractor.

    Args:
        tokenizer ([`Qwen2TokenizerFast`], *optional*):
            The text tokenizer.
        chat_template (`Optional[str]`, *optional*):
            The Jinja template to use for formatting the conversation. If not provided, the default chat template is used.
    """

    attributes = ["tokenizer"]
    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")

    def __init__(self, tokenizer=None, chat_template=None):
        super().__init__(tokenizer, chat_template=chat_template)

    def __call__(self, text=None, **kwargs) -> BatchFeature:
        """
        Main method to prepare one or several text sequence(s) for the model. This method forwards the `text`
        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] to encode the text.

        Args:
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
            **kwargs:
                Additional keyword arguments, merged with the defaults of [`Qwen3TTSProcessorKwargs`] and the
                tokenizer's init kwargs. `return_tensors` additionally selects the tensor type of the result.

        Returns:
            [`BatchFeature`]: The tokenizer outputs.

        Raises:
            ValueError: If `text` is `None`.
        """

        if text is None:
            raise ValueError("You need to specify a `text` input to process.")

        output_kwargs = self._merge_kwargs(
            Qwen3TTSProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )
        # A single (non-list) input is treated as a batch of one.
        if not isinstance(text, list):
            text = [text]

        texts_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])

        return BatchFeature(
            data={**texts_inputs},
            tensor_type=kwargs.get("return_tensors"),
        )

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    def apply_chat_template(self, conversations, chat_template=None, **kwargs):
        """
        Apply the chat template to one conversation or a batch of conversations.

        A single conversation (a list whose first element is a `dict`) is wrapped into a batch of one before
        being forwarded to [`ProcessorMixin.apply_chat_template`].
        """
        if isinstance(conversations[0], dict):
            conversations = [conversations]
        return super().apply_chat_template(conversations, chat_template, **kwargs)

    @property
    def model_input_names(self):
        """Return the tokenizer's model input names, de-duplicated while keeping their order."""
        tokenizer_input_names = self.tokenizer.model_input_names
        return list(dict.fromkeys(tokenizer_input_names))
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# Public API of the processing module.
__all__ = [
    "Qwen3TTSProcessor",
]
|
qwen_tts/core/tokenizer_12hz/configuration_qwen3_tts_tokenizer_v2.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Qwen3TTSTokenizerV2 model configuration"""
|
| 16 |
+
|
| 17 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 18 |
+
from transformers.utils import logging
|
| 19 |
+
|
| 20 |
+
from transformers import MimiConfig
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
logger = logging.get_logger(__name__)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class Qwen3TTSTokenizerV2DecoderConfig(PretrainedConfig):
    r"""
    Configuration of the decoder sub-model of the Qwen3TTS tokenizer V2.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        codebook_size (`int`, *optional*, defaults to 2048):
            Number of entries in each residual codebook used for acoustic token quantization.
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimensionality of the hidden states and embeddings in the autoregressive transformer decoder.
        latent_dim (`int`, *optional*, defaults to 1024):
            Stored unchanged (presumably the dimensionality of the latent representation -- confirm against the
            modeling code).
        max_position_embeddings (`int`, *optional*, defaults to 8000):
            Maximum sequence length that the autoregressive decoder can handle. Determines positional embedding size.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period for rotary position embeddings (RoPE) applied to attention layers.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the decoder.
        num_key_value_heads (`int`, *optional*, defaults to 16):
            Number of key and value attention heads used in grouped-query attention (if applicable).
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in the attention projection layers.
        sliding_window (`int`, *optional*, defaults to 72):
            Window size for local attention mechanism, limiting attention context to improve efficiency.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the feed-forward (intermediate) layer in each transformer block.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function used in the feed-forward layers. Supports `"silu"`, `"relu"`, `"gelu"`, etc.
        layer_scale_initial_scale (`float`, *optional*, defaults to 0.01):
            Initial value for LayerScale applied in transformer blocks, helping stabilize training.
        rms_norm_eps (`float`, *optional*, defaults to 1e-5):
            Epsilon value for RMS normalization layers to prevent division by zero.
        num_hidden_layers (`int`, *optional*, defaults to 8):
            Number of transformer blocks in the autoregressive decoder.
        num_quantizers (`int`, *optional*, defaults to 16):
            Number of residual vector quantizers used in the vocoder for fine-grained audio reconstruction.
        upsample_rates (`Tuple[int]`, *optional*, defaults to `(8, 5, 4, 3)`):
            Rate at which features are upsampled in the final waveform synthesis stage.
        upsampling_ratios (`Tuple[int]`, *optional*, defaults to `(2, 2)`):
            Ratios used in transposed convolutional layers to progressively upsample feature maps to waveform.
        decoder_dim (`int`, *optional*, defaults to 1536):
            Final dimensionality of the decoder's output before waveform generation.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability applied to attention weights in the decoder.
    """

    def __init__(
        self,
        codebook_size=2048,
        hidden_size=1024,
        latent_dim=1024,
        max_position_embeddings=8000,
        rope_theta=10000,
        num_attention_heads=16,
        num_key_value_heads=16,
        attention_bias=False,
        sliding_window=72,
        intermediate_size=3072,
        hidden_act="silu",
        layer_scale_initial_scale=0.01,
        rms_norm_eps=1e-5,
        num_hidden_layers=8,
        num_quantizers=16,
        upsample_rates=(8, 5, 4, 3),
        upsampling_ratios=(2, 2),
        decoder_dim=1536,
        attention_dropout=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Quantizer settings.
        self.codebook_size = codebook_size
        self.num_quantizers = num_quantizers

        # Transformer dimensions.
        self.hidden_size = hidden_size
        self.latent_dim = latent_dim
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.hidden_act = hidden_act
        self.layer_scale_initial_scale = layer_scale_initial_scale
        self.rms_norm_eps = rms_norm_eps

        # Attention and positional settings.
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.sliding_window = sliding_window

        # Upsampling / waveform synthesis settings.
        self.upsample_rates = upsample_rates
        self.upsampling_ratios = upsampling_ratios
        self.decoder_dim = decoder_dim

    @property
    def layer_types(self):
        """
        Every layer in code2wav uses sliding attention, so the pattern is one
        "sliding_attention" entry per hidden layer.
        """
        return ["sliding_attention" for _ in range(self.num_hidden_layers)]
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class Qwen3TTSTokenizerV2Config(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a Qwen3TTSTokenizerV2Model. It is used to
    instantiate the model according to the specified sub-models configurations, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        encoder_config (`dict` or [`MimiConfig`], *optional*):
            Configuration of the underlying encoder sub-model. A `dict` is expanded into a [`MimiConfig`], an
            existing instance is used as is, and `None` creates one with default values.
        decoder_config (`dict` or [`Qwen3TTSTokenizerV2DecoderConfig`], *optional*):
            Configuration of the underlying decoder sub-model, handled the same way as `encoder_config`.
        encoder_valid_num_quantizers (`int`, *optional*, defaults to 16):
            Stored unchanged (presumably the number of encoder quantizers actually used -- confirm against the
            modeling code).
        input_sample_rate (`int`, *optional*, defaults to 24000):
            Sample rate of the input audio.
        output_sample_rate (`int`, *optional*, defaults to 24000):
            Sample rate of the output audio.
        decode_upsample_rate (`int`, *optional*, defaults to 1920):
            Upsampling rate applied when decoding; stored unchanged.
        encode_downsample_rate (`int`, *optional*, defaults to 1920):
            Downsampling rate applied when encoding; stored unchanged.
    """

    model_type = "qwen3_tts_tokenizer_12hz"
    sub_configs = {
        "encoder_config": MimiConfig,
        "decoder_config": Qwen3TTSTokenizerV2DecoderConfig,
    }

    def __init__(
        self,
        encoder_config=None,
        decoder_config=None,
        encoder_valid_num_quantizers=16,
        input_sample_rate=24000,
        output_sample_rate=24000,
        decode_upsample_rate=1920,
        encode_downsample_rate=1920,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if encoder_config is None:
            encoder_config = {}
            logger.info("encoder_config is None. Initializing encoder with default values")
        if decoder_config is None:
            decoder_config = {}
            logger.info("decoder_config is None. Initializing decoder with default values")

        # Accept ready-made config objects as well as plain dicts; `**` expansion
        # of a config instance would raise a TypeError.
        if not isinstance(encoder_config, MimiConfig):
            encoder_config = MimiConfig(**encoder_config)
        if not isinstance(decoder_config, Qwen3TTSTokenizerV2DecoderConfig):
            decoder_config = Qwen3TTSTokenizerV2DecoderConfig(**decoder_config)

        self.encoder_config = encoder_config
        self.decoder_config = decoder_config

        self.encoder_valid_num_quantizers = encoder_valid_num_quantizers
        self.input_sample_rate = input_sample_rate
        self.output_sample_rate = output_sample_rate
        self.decode_upsample_rate = decode_upsample_rate
        self.encode_downsample_rate = encode_downsample_rate
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
# Public API of the tokenizer V2 configuration module.
__all__ = [
    "Qwen3TTSTokenizerV2Config",
    "Qwen3TTSTokenizerV2DecoderConfig",
]
|
qwen_tts/core/tokenizer_12hz/modeling_qwen3_tts_tokenizer_v2.py
ADDED
|
@@ -0,0 +1,1025 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch Qwen3TTSTokenizerV2 model."""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Callable, Optional, Union, List
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
from torch import nn
|
| 24 |
+
from torch.nn import Parameter
|
| 25 |
+
from torch.nn import functional as F
|
| 26 |
+
from transformers import MimiConfig, MimiModel
|
| 27 |
+
from transformers.activations import ACT2FN
|
| 28 |
+
from transformers.cache_utils import Cache, DynamicCache
|
| 29 |
+
from transformers.integrations import use_kernel_forward_from_hub
|
| 30 |
+
from transformers.masking_utils import (
|
| 31 |
+
create_causal_mask,
|
| 32 |
+
create_sliding_window_causal_mask,
|
| 33 |
+
)
|
| 34 |
+
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
| 35 |
+
from transformers.modeling_layers import GradientCheckpointingLayer
|
| 36 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast
|
| 37 |
+
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 38 |
+
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 39 |
+
from transformers.processing_utils import Unpack
|
| 40 |
+
from transformers.utils import ModelOutput, auto_docstring, logging
|
| 41 |
+
from transformers.utils.deprecation import deprecate_kwarg
|
| 42 |
+
from transformers.utils.generic import check_model_inputs
|
| 43 |
+
|
| 44 |
+
from .configuration_qwen3_tts_tokenizer_v2 import (
|
| 45 |
+
Qwen3TTSTokenizerV2Config,
|
| 46 |
+
Qwen3TTSTokenizerV2DecoderConfig,
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
logger = logging.get_logger(__name__)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@dataclass
@auto_docstring
class Qwen3TTSTokenizerV2EncoderOutput(ModelOutput):
    r"""
    audio_codes (`List[torch.LongTensor]`):
        Discrete code embeddings computed using `model.encode`, each tensor has shape (codes_length_i, num_quantizers).
    """

    # The field defaults to None, so the annotation is Optional to match.
    audio_codes: Optional[List[torch.LongTensor]] = None
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@dataclass
@auto_docstring
class Qwen3TTSTokenizerV2DecoderOutput(ModelOutput):
    r"""
    audio_values (`List[torch.FloatTensor]`):
        Decoded audio values, obtained using the decoder part of Qwen3TTSTokenizerV2.
        Each tensor has shape (segment_length_i).
    """

    # The field defaults to None, so the annotation is Optional to match.
    audio_values: Optional[List[torch.FloatTensor]] = None
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the half that moves to the front.

    The split point is `x.shape[-1] // 2`, so for an odd-sized last dimension the
    trailing (negated) part is one element longer than the leading part.
    """
    split_at = x.shape[-1] // 2
    leading, trailing = x[..., :split_at], x[..., split_at:]
    return torch.cat((-trailing, leading), dim=-1)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate the query and key tensors with a rotary position embedding.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused; kept only so existing call sites keep working.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension at which a singleton axis is inserted into `cos` and `sin` so that they
            broadcast against `q` and `k`. For example, with `cos`/`sin` shaped
            [batch_size, seq_len, head_dim], use `unsqueeze_dim=1` when `q`/`k` are
            [batch_size, heads, seq_len, head_dim] and `unsqueeze_dim=2` when they are
            [batch_size, seq_len, heads, head_dim].

    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors, in that order.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Same rotation for queries and keys: x * cos + rotate_half(x) * sin.
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head `n_rep` times along the head dimension.

    Equivalent to `torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep)`: the input of shape
    (batch, num_key_value_heads, seqlen, head_dim) becomes (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    When `n_rep == 1` the input tensor itself is returned.
    """
    # Unpacking first means a non-4D input fails here even when n_rep == 1.
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    with_repeat_axis = hidden_states[:, :, None, :, :]
    expanded = with_repeat_axis.expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Plain matmul/softmax attention used when no optimized attention backend is selected.

    Args:
        module: Owning attention module; `num_key_value_groups` and `training` are read from it.
        query: Query states.
        key: Key states; heads are repeated `module.num_key_value_groups` times before use.
        value: Value states; repeated the same way as `key`.
        attention_mask: Optional additive mask. Its last axis is cut to the key length before being added.
        scaling: Factor applied to the raw query-key scores.
        dropout: Dropout probability applied to the attention weights (active only in training mode).
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        `(attn_output, attn_weights)`, where `attn_output` has axes 1 and 2 swapped relative to the
        head-major matmul result and is made contiguous.
    """
    groups = module.num_key_value_groups
    key_states = repeat_kv(key, groups)
    value_states = repeat_kv(value, groups)

    scores = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask[:, :, :, : key_states.shape[-2]]

    # Softmax is computed in float32 and cast back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, value_states).transpose(1, 2).contiguous()
    return context, probs
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
@auto_docstring
class Qwen3TTSTokenizerV2DecoderPreTrainedModel(PreTrainedModel):
    # Base class for the decoder-side models in this file. It adds no methods;
    # it only declares class-level settings on top of `PreTrainedModel`.
    # NOTE(review): the attributes below are interpreted by the transformers
    # `PreTrainedModel` machinery; their exact semantics are defined there, not here.

    # Configuration class paired with this model family.
    config: Qwen3TTSTokenizerV2DecoderConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    # Attention backends this model declares support for.
    _supports_flash_attn = True
    _supports_sdpa = True
    _can_compile_fullgraph = False
    _supports_attention_backend = True
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class Qwen3TTSTokenizerV2CausalConvNet(nn.Module):
    """`nn.Conv1d` wrapper that zero-pads the time axis itself before convolving.

    The left padding is `effective_kernel - stride`; the right padding is the smallest
    amount that lets the last stride window be complete (see `_get_extra_padding_for_conv1d`).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation=1,
        stride=1,
        groups=1,
    ):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
        )
        self.stride = stride
        self.dilation = dilation
        # `kernel_size` stores the dilated (effective) span, not the raw kernel size.
        self.kernel_size = (kernel_size - 1) * dilation + 1
        self.padding = self.kernel_size - self.stride

    def _get_extra_padding_for_conv1d(self, hidden_state: torch.Tensor) -> int:
        """Return the right padding needed so the final stride window is fully covered."""
        length = hidden_state.shape[-1]
        uncovered = self.kernel_size - self.padding
        frame_count = math.ceil((length - uncovered) / self.stride + 1)
        target_length = (frame_count - 1) * self.stride + uncovered
        return target_length - length

    def forward(self, hidden_state):
        """Pad the last axis (left: `self.padding`, right: computed) and apply the convolution."""
        right_pad = self._get_extra_padding_for_conv1d(hidden_state)
        padded = F.pad(hidden_state, (self.padding, right_pad), mode="constant", value=0)
        return self.conv(padded).contiguous()
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class Qwen3TTSTokenizerV2CausalTransConvNet(nn.Module):
    """`nn.ConvTranspose1d` wrapper that trims `kernel_size - stride` samples from both ends of the output."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
        super().__init__()
        self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride=stride)

        trim = math.ceil(kernel_size - stride)
        self.left_pad = trim
        # NOTE(review): the original statement was `self.right_pad = pad = self.left_pad`,
        # which makes the right trim equal to the left trim. It resembles a typo for
        # `pad - self.left_pad` (which would give 0), but changing it would alter output
        # lengths, so the symmetric trim is preserved here. Confirm against the checkpoint.
        self.right_pad = trim

    def forward(self, hidden_state):
        """Apply the transposed convolution, then drop `left_pad` / `right_pad` samples from the last axis."""
        upsampled = self.conv(hidden_state)
        stop = upsampled.shape[-1] - self.right_pad
        return upsampled[..., self.left_pad : stop].contiguous()
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class Qwen3TTSTokenizerV2ConvNeXtBlock(nn.Module):
    """Residual block: depthwise conv, LayerNorm, two pointwise linears with GELU, learned per-channel scale."""

    def __init__(self, dim: int):
        super().__init__()
        # groups=dim makes the kernel-7 convolution depthwise (one filter per channel).
        self.dwconv = Qwen3TTSTokenizerV2CausalConvNet(
            dim,
            dim,
            kernel_size=7,
            groups=dim,
            dilation=1,
        )
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        # Pointwise expansion to 4 * dim and projection back to dim.
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        # Per-channel scale on the residual branch, initialised to 1e-6.
        self.gamma = nn.Parameter(1e-6 * torch.ones(dim))

    def forward(self, hidden_states):
        """Return `hidden_states` plus the scaled output of the conv/MLP branch."""
        residual = hidden_states

        # The convolution works with channels on axis 1; LayerNorm and the linears
        # act on the last axis, so axes 1 and 2 are swapped around them.
        branch = self.dwconv(hidden_states).permute(0, 2, 1)
        branch = self.pwconv2(self.act(self.pwconv1(self.norm(branch))))
        branch = (self.gamma * branch).permute(0, 2, 1)

        return residual + branch
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
class Qwen3TTSTokenizerV2DecoderRotatoryEmbedding(nn.Module):
    """Computes the rotary-embedding cos/sin tables for a batch of position ids."""

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: Qwen3TTSTokenizerV2DecoderConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        rope_scaling = getattr(config, "rope_scaling", None)
        if isinstance(rope_scaling, dict):
            self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
        else:
            self.rope_type = "default"

        self.config = config
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        # Non-persistent: the buffer is rebuilt from the config and not saved in the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return `(cos, sin)` for `position_ids`, cast to the dtype of `x`.

        `x` is only used for its device and dtype.
        """
        batch = position_ids.shape[0]
        inv_freq = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        positions = position_ids[:, None, :].float()

        # autocast is disabled below; "mps" and non-string device types fall back to "cpu".
        device_type = x.device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"

        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq.float() @ positions.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
class Qwen3TTSTokenizerV2DecoderAttention(nn.Module):
    """Multi-head self-attention with rotary position embeddings and an optional KV cache."""

    def __init__(self, config: Qwen3TTSTokenizerV2DecoderConfig, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        hidden_size = config.hidden_size
        use_bias = config.attention_bias
        query_width = config.num_attention_heads * self.head_dim
        kv_width = config.num_key_value_heads * self.head_dim

        self.q_proj = nn.Linear(hidden_size, query_width, bias=use_bias)
        self.k_proj = nn.Linear(hidden_size, kv_width, bias=use_bias)
        self.v_proj = nn.Linear(hidden_size, kv_width, bias=use_bias)
        self.o_proj = nn.Linear(query_width, hidden_size, bias=use_bias)
        # Query/key normalisation is a no-op in this model.
        self.q_norm = nn.Identity()
        self.k_norm = nn.Identity()
        self.sliding_window = config.sliding_window

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run self-attention over `hidden_states`.

        Args:
            hidden_states: Layer input; the last axis is the hidden size.
            position_embeddings: `(cos, sin)` pair applied to queries and keys.
            attention_mask: Mask handed unchanged to the attention backend (may be `None`).
            past_key_values: Optional cache; when given it is updated with this layer's keys/values.
            cache_position: Passed to the cache update together with `sin`/`cos`.
            **kwargs: Forwarded to the attention backend.

        Returns:
            `(attn_output, attn_weights)` as produced by the backend, with `attn_output`
            projected back through `o_proj`.
        """
        leading_shape = hidden_states.shape[:-1]
        per_head_shape = (*leading_shape, -1, self.head_dim)

        queries = self.q_norm(self.q_proj(hidden_states).view(per_head_shape)).transpose(1, 2)
        keys = self.k_norm(self.k_proj(hidden_states).view(per_head_shape)).transpose(1, 2)
        values = self.v_proj(hidden_states).view(per_head_shape).transpose(1, 2)

        cos, sin = position_embeddings
        queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin)

        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            keys, values = past_key_values.update(keys, values, self.layer_idx, cache_kwargs)

        implementation = self.config._attn_implementation
        attention_interface: Callable = (
            eager_attention_forward if implementation == "eager" else ALL_ATTENTION_FUNCTIONS[implementation]
        )

        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window,  # diff with Llama
            **kwargs,
        )

        merged = attn_output.reshape(*leading_shape, -1).contiguous()
        return self.o_proj(merged), attn_weights
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
class Qwen3TTSTokenizerV2DecoderMlp(nn.Module):
    """Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`, all projections bias-free."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        # Activation is looked up by name from the config.
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the gated projection to `x` and map back to the hidden size."""
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
@use_kernel_forward_from_hub("RMSNorm")
class Qwen3TTSTokenizerV2DecoderRMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned per-channel weight."""

    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        """
        Qwen3TTSTokenizerV2DecoderRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalise in float32, cast back to the input dtype, then scale by `weight`."""
        input_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalised = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalised.to(input_dtype)

    def extra_repr(self):
        """Show the weight shape and epsilon in the module repr."""
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
class Qwen3TTSTokenizerV2DecoderLayerScale(nn.Module):
    """Layer scale from [Touvron et al 2021] (https://huggingface.co/papers/2103.17239).

    Multiplies its input by a learned per-channel vector of length `config.hidden_size`,
    initialised to `config.layer_scale_initial_scale`.
    """

    def __init__(self, config):
        super().__init__()
        initial = torch.full((config.hidden_size,), config.layer_scale_initial_scale, requires_grad=True)
        self.scale = nn.Parameter(initial)

    def forward(self, x: torch.Tensor):
        """Return `x` scaled channel-wise (broadcast over the last axis)."""
        scaled = self.scale * x
        return scaled
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
class Qwen3TTSTokenizerV2DecoderTransformerLayer(GradientCheckpointingLayer):
    """Pre-norm transformer layer: self-attention then MLP, each layer-scaled and added to a residual."""

    def __init__(self, config: Qwen3TTSTokenizerV2DecoderConfig, layer_idx):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Qwen3TTSTokenizerV2DecoderAttention(config, layer_idx)
        self.mlp = Qwen3TTSTokenizerV2DecoderMlp(config)
        self.input_layernorm = Qwen3TTSTokenizerV2DecoderRMSNorm(config.hidden_size, config.rms_norm_eps)
        self.post_attention_layernorm = Qwen3TTSTokenizerV2DecoderRMSNorm(config.hidden_size, config.rms_norm_eps)
        self.self_attn_layer_scale = Qwen3TTSTokenizerV2DecoderLayerScale(config)
        self.mlp_layer_scale = Qwen3TTSTokenizerV2DecoderLayerScale(config)
        # Key used by the parent model to pick this layer's attention mask.
        self.attention_type = "sliding_attention"

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> torch.FloatTensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            position_ids (`torch.LongTensor`, *optional*):
                Forwarded unchanged to the attention module.
            past_key_values (`Cache`, *optional*): cached past key and value projection states
            use_cache (`bool`, *optional*):
                Forwarded unchanged to the attention module.
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence
            kwargs (`dict`, *optional*):
                Extra keyword arguments forwarded to the attention module. The attention module's required
                `position_embeddings` argument is not a named parameter here, so it must arrive through `kwargs`.

        Returns:
            `torch.FloatTensor`: the layer output, a single tensor (attention weights are discarded).
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = residual + self.self_attn_layer_scale(hidden_states)

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + self.mlp_layer_scale(hidden_states)

        return hidden_states
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
@auto_docstring
class Qwen3TTSTokenizerV2DecoderTransformerModel(Qwen3TTSTokenizerV2DecoderPreTrainedModel):
    # Transformer stack that works on continuous latents: inputs are projected from
    # `config.latent_dim` to `config.hidden_size`, run through the decoder layers and a
    # final RMSNorm, then projected back to `config.latent_dim`. There is no token
    # embedding table, so `input_ids` is rejected.
    _can_record_outputs = {
        "hidden_states": Qwen3TTSTokenizerV2DecoderTransformerLayer,
        "attentions": Qwen3TTSTokenizerV2DecoderAttention,
    }

    def __init__(self, config: Qwen3TTSTokenizerV2DecoderConfig):
        super().__init__(config)
        self.layers = nn.ModuleList(
            [Qwen3TTSTokenizerV2DecoderTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen3TTSTokenizerV2DecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen3TTSTokenizerV2DecoderRotatoryEmbedding(config=config)
        self.gradient_checkpointing = False
        # NOTE(review): every layer sets attention_type = "sliding_attention" and forward()
        # indexes the mask mapping with it, so a config whose `layer_types` lacks
        # "sliding_attention" would fail with a KeyError there. Confirm the config default.
        self.has_sliding_layers = "sliding_attention" in self.config.layer_types
        self.window_size = config.sliding_window

        self.input_proj = nn.Linear(config.latent_dim, config.hidden_size)
        self.output_proj = nn.Linear(config.hidden_size, config.latent_dim)

        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs()
    @auto_docstring
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        cache_position=None,
        **kwargs,
    ) -> BaseModelOutputWithPast:
        # `input_ids` exists only for signature compatibility; `inputs_embeds` is mandatory.
        # The original also had an `inputs_embeds is None` fallback calling `self.embed_tokens`,
        # which was unreachable (and the attribute does not exist); the two checks below raise
        # the same errors for the same inputs.
        if input_ids is not None:
            raise ValueError("input_ids is not expected")
        if inputs_embeds is None:
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        inputs_embeds = self.input_proj(inputs_embeds)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if cache_position is None:
            # Continue numbering after whatever the cache already holds.
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # It may already have been prepared by e.g. `generate`
        causal_mask_mapping = attention_mask
        if not isinstance(causal_mask_mapping, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": position_ids,
            }
            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
            }
            # The sliding window alternating layers are not always activated depending on the config
            if self.has_sliding_layers:
                causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        hidden_states = self.output_proj(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
        )
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
class SnakeBeta(nn.Module):
    """
    Snake-style periodic activation with separate frequency and magnitude parameters.

    Both parameters are stored in log scale: `forward` exponentiates them before use.
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://huggingface.co/papers/2006.08195
    """

    def __init__(self, in_features, alpha=1.0):
        super().__init__()
        self.in_features = in_features

        # One log-scale value per channel. Each parameter gets its own tensor.
        # Scaling a zero tensor by `alpha` still yields zeros, so for any finite
        # `alpha` the initial frequency and magnitude are both exp(0) = 1.
        self.alpha, self.beta = (Parameter(torch.zeros(in_features) * alpha) for _ in range(2))

        # Small constant added to the denominator in `forward`.
        self.no_div_by_zero = 0.000000001

    def forward(self, hidden_states):
        """
        Apply the activation elementwise.

        SnakeBeta := x + 1/b * sin^2 (xa), with a = exp(alpha) and b = exp(beta).
        """
        # Reshape the per-channel parameters to [1, C, 1] so they broadcast over [B, C, T].
        frequency = torch.exp(self.alpha.unsqueeze(0).unsqueeze(-1))
        magnitude = torch.exp(self.beta.unsqueeze(0).unsqueeze(-1))

        inverse_magnitude = 1.0 / (magnitude + self.no_div_by_zero)
        periodic = torch.pow(torch.sin(hidden_states * frequency), 2)
        return hidden_states + inverse_magnitude * periodic
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
class Qwen3TTSTokenizerV2DecoderDecoderResidualUnit(nn.Module):
    """Residual unit: SnakeBeta -> causal conv (kernel 7, dilated) -> SnakeBeta -> causal conv (kernel 1), plus skip."""

    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()

        # Attribute names are part of the checkpoint layout; keep them stable.
        self.act1 = SnakeBeta(dim)
        self.conv1 = Qwen3TTSTokenizerV2CausalConvNet(dim, dim, kernel_size=7, dilation=dilation)
        self.act2 = SnakeBeta(dim)
        self.conv2 = Qwen3TTSTokenizerV2CausalConvNet(dim, dim, kernel_size=1)

    def forward(self, hidden_state):
        """Run the four stages in order and add the untouched input back."""
        transformed = hidden_state
        for stage in (self.act1, self.conv1, self.act2, self.conv2):
            transformed = stage(transformed)
        return transformed + hidden_state
|
| 635 |
+
|
| 636 |
+
|
| 637 |
+
class Qwen3TTSTokenizerV2DecoderDecoderBlock(Qwen3TTSTokenizerV2DecoderPreTrainedModel):
    """One decoder stage: activation, transposed-conv upsampling, then three dilated residual units."""

    def __init__(self, config: Qwen3TTSTokenizerV2DecoderConfig, layer_idx):
        super().__init__(config)
        # The channel count halves at every stage.
        in_dim = config.decoder_dim // 2**layer_idx
        out_dim = config.decoder_dim // 2 ** (layer_idx + 1)
        upsample_rate = config.upsample_rates[layer_idx]

        self.block = nn.ModuleList(
            [
                SnakeBeta(in_dim),
                # Kernel size is twice the stride.
                Qwen3TTSTokenizerV2CausalTransConvNet(in_dim, out_dim, 2 * upsample_rate, upsample_rate),
                *(Qwen3TTSTokenizerV2DecoderDecoderResidualUnit(out_dim, dilation) for dilation in (1, 3, 9)),
            ]
        )

    def forward(self, hidden):
        """Apply every sub-module of `self.block` sequentially."""
        for layer in self.block:
            hidden = layer(hidden)
        return hidden
|
| 658 |
+
|
| 659 |
+
|
| 660 |
+
class EuclideanCodebook(nn.Module):
    """Decode-only codebook stored as running sums (`embedding_sum`) and per-entry usage counts (`cluster_usage`)."""

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        epsilon: float = 1e-5,
    ):
        super().__init__()
        self.dim = dim
        self.codebook_size = codebook_size
        self.epsilon = epsilon

        # The effective embedding of entry i is embedding_sum[i] / cluster_usage[i].
        self.cluster_usage = nn.Parameter(torch.ones(codebook_size))
        self.embedding_sum = nn.Parameter(torch.zeros(codebook_size, dim))

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Look up the vectors for integer `codes`; the result gains a trailing axis of size `dim`."""
        # Clamp the counts so that unused entries do not divide by zero.
        usage = self.cluster_usage.clamp(min=self.epsilon)
        codebook = self.embedding_sum / usage[:, None]
        return F.embedding(codes, codebook)
|
| 679 |
+
|
| 680 |
+
|
| 681 |
+
class VectorQuantization(nn.Module):
    """A single codebook with an optional linear projection back to the model dimension."""

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: Optional[int] = None,
        epsilon: float = 1e-5,
    ):
        super().__init__()
        codebook_dim = dim if codebook_dim is None else codebook_dim

        # A projection is only needed when the codebook lives in a different dimension.
        if codebook_dim != dim:
            self.project_out = nn.Linear(codebook_dim, dim)
        else:
            self.project_out = nn.Identity()
        self.epsilon = epsilon
        self._codebook = EuclideanCodebook(
            dim=codebook_dim,
            codebook_size=codebook_size,
            epsilon=epsilon,
        )
        self.codebook_size = codebook_size

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode `codes` through the codebook and projection, then swap axes 1 and 2 of the result."""
        vectors = self._codebook.decode(codes)
        return self.project_out(vectors).transpose(1, 2)
|
| 711 |
+
|
| 712 |
+
|
| 713 |
+
class ResidualVectorQuantization(nn.Module):
    """A stack of `VectorQuantization` layers whose decoded outputs are summed."""

    def __init__(self, *, num_quantizers: int, **kwargs):
        super().__init__()
        quantizers = []
        for _ in range(num_quantizers):
            # Every layer is built from the same keyword arguments.
            quantizers.append(VectorQuantization(**kwargs))
        self.layers = nn.ModuleList(quantizers)

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Sum the per-layer decodings; `codes` is iterated along its first axis, one slice per layer."""
        # A 0-dim zero on the codes' device serves as the running total.
        total = torch.zeros([1], device=codes.device)[0]
        for layer_index, codes_for_layer in enumerate(codes):
            quantizer = self.layers[layer_index]
            assert isinstance(quantizer, VectorQuantization)
            total = total + quantizer.decode(codes_for_layer)
        return total
|
| 727 |
+
|
| 728 |
+
|
| 729 |
+
class ResidualVectorQuantizer(nn.Module):
    """Residual vector quantizer (decode path only) with optional 1x1 conv input/output projections.

    `q_dropout`, `no_quantization_rate` and `decay` are only stored on the instance;
    nothing in this class reads them.
    """

    def __init__(
        self,
        dimension: int = 128,
        input_dimension: Optional[int] = None,
        output_dimension: Optional[int] = None,
        n_q: int = 8,
        q_dropout: bool = False,
        no_quantization_rate: float = 0.0,
        bins: int = 1024,
        decay: float = 0.99,
        force_projection: bool = False,
    ):
        super().__init__()
        self.max_n_q = n_q
        self.n_q = n_q
        self.q_dropout = q_dropout
        self.no_quantization_rate = no_quantization_rate
        self.dimension = dimension
        self.input_dimension = input_dimension or dimension
        self.output_dimension = output_dimension or dimension
        self.bins = bins
        self.decay = decay

        def make_projection(in_channels: int, out_channels: int) -> torch.nn.Module:
            # Identity when the sizes already match, unless a projection is forced.
            if in_channels == out_channels and not force_projection:
                return torch.nn.Identity()
            return torch.nn.Conv1d(in_channels, out_channels, 1, bias=False)

        self.input_proj: torch.nn.Module = make_projection(self.input_dimension, self.dimension)
        self.output_proj: torch.nn.Module = make_projection(self.dimension, self.output_dimension)
        self.vq = ResidualVectorQuantization(
            dim=self.dimension,
            codebook_size=self.bins,
            num_quantizers=self.n_q,
        )

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode `codes` after swapping their first two axes, then apply the output projection."""
        per_layer_codes = codes.transpose(0, 1)
        return self.output_proj(self.vq.decode(per_layer_codes))
|
| 777 |
+
|
| 778 |
+
|
| 779 |
+
class SplitResidualVectorQuantizer(nn.Module):
    """Residual Vector Quantizer with separate projections for the first quantizer(s) and the rest.

    Args:
        n_q (int): Total number of residual vector quantizers.
        n_q_semantic (int): Number of leading quantizers handled by the semantic quantizer (`rvq_first`).
        **kwargs: Arguments to the constructor of `ResidualVectorQuantizer` that are shared between both.
            A `q_dropout` entry, if present, is only forwarded to the acoustic quantizer (`rvq_rest`).
    """

    def __init__(
        self,
        *,
        n_q: int = 8,
        n_q_semantic: int = 1,
        **kwargs,
    ):
        super().__init__()
        assert n_q > n_q_semantic, (
            f"Number of quantizers {n_q} must be larger "
            f"than the number of semantic quantizers {n_q_semantic}."
        )
        n_q_acoustic = n_q - n_q_semantic
        self.max_n_q = n_q
        self.n_q_semantic = n_q_semantic
        self.n_q_acoustic = n_q_acoustic

        # The semantic quantizer never uses dropout; the acoustic one inherits the caller's setting.
        acoustic_q_dropout = kwargs.pop("q_dropout", False)
        self.rvq_first = ResidualVectorQuantizer(
            n_q=n_q_semantic, force_projection=True, q_dropout=False, **kwargs
        )
        self.rvq_rest = ResidualVectorQuantizer(
            n_q=n_q_acoustic,
            force_projection=True,
            q_dropout=acoustic_q_dropout,
            **kwargs,
        )

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation."""
        # codes is [B, K, T], with T frames, K nb of codebooks.
        split = self.n_q_semantic
        quantized = self.rvq_first.decode(codes[:, :split])
        if codes.shape[1] <= split:
            # Only semantic codebooks were supplied.
            return quantized
        quantized += self.rvq_rest.decode(codes[:, split:])
        return quantized
|
| 821 |
+
|
| 822 |
+
|
| 823 |
+
class Qwen3TTSTokenizerV2Decoder(Qwen3TTSTokenizerV2DecoderPreTrainedModel):
    """Decoder that turns discrete codes into a waveform.

    Pipeline (see `forward`): split RVQ decode -> causal pre-conv -> transformer ->
    transposed-conv/ConvNeXt upsampling -> SnakeBeta/causal-conv decoder stack -> clamp to [-1, 1].
    """

    def __init__(self, config: Qwen3TTSTokenizerV2DecoderConfig):
        super().__init__(config)
        # Product of every upsampling factor; `chunked_decode` uses it to convert
        # a number of code frames into a number of output samples.
        self.total_upsample = np.prod(config.upsample_rates + config.upsampling_ratios)
        self.pre_transformer = Qwen3TTSTokenizerV2DecoderTransformerModel._from_config(config)

        self.quantizer = SplitResidualVectorQuantizer(
            dimension=config.codebook_dim // 2,
            n_q=config.num_quantizers,
            n_q_semantic=1,
            bins=config.codebook_size,
            input_dimension=config.codebook_dim,
            output_dimension=config.codebook_dim,
        )

        self.pre_conv = Qwen3TTSTokenizerV2CausalConvNet(
            config.codebook_dim,
            config.latent_dim,
            kernel_size=3,
        )

        # One (transposed conv, ConvNeXt block) pair per entry of `upsampling_ratios`.
        upsample = []
        for factor in config.upsampling_ratios:
            upsample.append(
                nn.ModuleList(
                    [
                        Qwen3TTSTokenizerV2CausalTransConvNet(config.latent_dim, config.latent_dim, factor, factor),
                        Qwen3TTSTokenizerV2ConvNeXtBlock(config.latent_dim),
                    ]
                )
            )
        self.upsample = nn.ModuleList(upsample)

        # Input conv, one decoder block per entry of `upsample_rates`, then a
        # final activation and a conv down to a single output channel.
        decoder = [Qwen3TTSTokenizerV2CausalConvNet(config.latent_dim, config.decoder_dim, 7)]
        for i in range(len(config.upsample_rates)):
            decoder.append(Qwen3TTSTokenizerV2DecoderDecoderBlock(config, i))
        output_dim = config.decoder_dim // 2 ** len(config.upsample_rates)
        decoder += [
            SnakeBeta(output_dim),
            Qwen3TTSTokenizerV2CausalConvNet(output_dim, 1, 7),
        ]
        self.decoder = nn.ModuleList(decoder)

        self.post_init()

    def forward(self, codes):
        """Decode `codes` into a waveform clamped to [-1, 1].

        Args:
            codes: Integer codes whose axis 1 must have `config.num_quantizers` entries.

        Raises:
            ValueError: If `codes.shape[1]` differs from `config.num_quantizers`.
        """
        if codes.shape[1] != self.config.num_quantizers:
            raise ValueError(f"Expected {self.config.num_quantizers} layer of codes, got {codes.shape[1]}")

        hidden = self.quantizer.decode(codes)
        # The transformer consumes `inputs_embeds` with the last two axes swapped
        # relative to the convolutional layers; swap before and after it.
        hidden = self.pre_conv(hidden).transpose(1, 2)

        hidden = self.pre_transformer(inputs_embeds=hidden).last_hidden_state
        hidden = hidden.permute(0, 2, 1)
        for blocks in self.upsample:
            for block in blocks:
                hidden = block(hidden)
        wav = hidden
        for block in self.decoder:
            wav = block(wav)
        return wav.clamp(min=-1, max=1)

    def chunked_decode(self, codes, chunk_size=300, left_context_size=25):
        """Decode `codes` chunk by chunk along the last axis and concatenate the waveforms.

        Each chunk is decoded together with up to `left_context_size` preceding
        frames; the samples produced for that context are discarded afterwards.

        Args:
            codes: Codes accepted by `forward`; chunking happens along the last axis.
            chunk_size: Number of new frames decoded per chunk. Must be positive.
            left_context_size: Maximum number of preceding frames prepended as context.
                Must be non-negative.

        Raises:
            ValueError: If `chunk_size` is not positive or `left_context_size` is negative.
        """
        # A non-positive chunk size would never advance `start_index` (endless loop),
        # and a negative context would yield nonsensical slices.
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")
        if left_context_size < 0:
            raise ValueError(f"left_context_size must be non-negative, got {left_context_size}")

        total_frames = codes.shape[-1]
        wavs = []
        start_index = 0
        while start_index < total_frames:
            end_index = min(start_index + chunk_size, total_frames)
            # No more context can be used than the frames that precede this chunk.
            context_size = min(left_context_size, start_index)
            codes_chunk = codes[..., start_index - context_size : end_index]
            wav_chunk = self(codes_chunk)
            # Drop the samples that correspond to the context frames.
            wavs.append(wav_chunk[..., context_size * self.total_upsample :])
            start_index = end_index
        return torch.cat(wavs, dim=-1)
|
| 896 |
+
|
| 897 |
+
|
| 898 |
+
class Qwen3TTSTokenizerV2Encoder(MimiModel):
    """`MimiModel` subclass whose `upsample`, `decoder_transformer` and `decoder` attributes are set to None."""

    def __init__(self, config: MimiConfig):
        super().__init__(config)
        self.config = config

        # NOTE(review): presumably these are dropped because only the encoding path is used here — confirm.
        for unused_attribute in ("upsample", "decoder_transformer", "decoder"):
            setattr(self, unused_attribute, None)

        self.post_init()
|
| 908 |
+
|
| 909 |
+
|
| 910 |
+
@auto_docstring
class Qwen3TTSTokenizerV2PreTrainedModel(PreTrainedModel):
    # Configuration class associated with this model family.
    config: Qwen3TTSTokenizerV2Config
    base_model_prefix = "model"
    _skip_keys_device_placement = "past_key_values"

    # Feature flags read by the `PreTrainedModel` base class; see the
    # transformers documentation for the exact meaning of each one.
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_attention_backend = True
    _can_compile_fullgraph = False
|
| 920 |
+
|
| 921 |
+
|
| 922 |
+
@auto_docstring(
    custom_intro="""
    The Qwen3TTSTokenizerV2 model.
    """
)
class Qwen3TTSTokenizerV2Model(Qwen3TTSTokenizerV2PreTrainedModel):
    def __init__(self, config: Qwen3TTSTokenizerV2Config):
        super().__init__(config)
        self.config = config

        # Number of leading codebooks of the encoder output that are kept by `encode`.
        self.encoder_valid_num_quantizers = config.encoder_valid_num_quantizers

        self.input_sample_rate = config.input_sample_rate
        self.output_sample_rate = config.output_sample_rate

        self.decode_upsample_rate = config.decode_upsample_rate
        self.encode_downsample_rate = config.encode_downsample_rate

        self.encoder = Qwen3TTSTokenizerV2Encoder._from_config(self.config.encoder_config)
        self.decoder = Qwen3TTSTokenizerV2Decoder._from_config(self.config.decoder_config)

        self.post_init()

    def get_model_type(self):
        """Return `config.model_type`."""
        return self.config.model_type

    def get_input_sample_rate(self):
        """Return the input sample rate taken from the config."""
        return self.input_sample_rate

    def get_output_sample_rate(self):
        """Return the output sample rate taken from the config."""
        return self.output_sample_rate

    def get_encode_downsample_rate(self):
        """Return the encoder downsample rate taken from the config."""
        return self.encode_downsample_rate

    def get_decode_upsample_rate(self):
        """Return the decoder upsample rate taken from the config."""
        return self.decode_upsample_rate

    def encode(
        self,
        input_values: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor, Optional[torch.Tensor]], Qwen3TTSTokenizerV2EncoderOutput]:
        """
        Encodes the input audio waveform into discrete codes.

        Args:
            input_values (`torch.Tensor` of shape `(batch_size, sequence_length)`):
                Float values of the input audio waveform.
            padding_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Indicates which inputs are to be ignored due to padding, where elements are either 1 for *not masked* or 0
                for *masked*. It is only used to trim each item's codes to its valid length; it is not forwarded to
                the encoder. When omitted, every sample is treated as valid.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Returns:
            A list with one code tensor per batch item (wrapped in a 1-tuple or in a
            `Qwen3TTSTokenizerV2EncoderOutput`), each with the codebook axis last.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if padding_mask is None:
            # The default used to crash in the `zip` below; treat the whole input as unpadded instead.
            padding_mask = torch.ones_like(input_values, dtype=torch.long)

        encoded_frames = self.encoder.encode(input_values=input_values.unsqueeze(1),
                                             return_dict=True)
        audio_codes = encoded_frames.audio_codes[:, :self.encoder_valid_num_quantizers]
        # Keep ceil(valid_samples / encode_downsample_rate) frames per item; -(-a // b) is ceiling division.
        audio_codes = [
            code[..., :-(-mask.sum() // self.encode_downsample_rate)].transpose(0, 1)
            for code, mask in zip(audio_codes, padding_mask)
        ]

        if not return_dict:
            return (
                audio_codes,
            )

        return Qwen3TTSTokenizerV2EncoderOutput(audio_codes)

    def decode(
        self,
        audio_codes: torch.Tensor,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor, torch.Tensor], Qwen3TTSTokenizerV2DecoderOutput]:
        """
        Decodes the given frames into an output audio waveform.

        Note that the output might be a bit bigger than the input. In that case, any extra steps at the end can be
        trimmed.

        Args:
            audio_codes (`torch.LongTensor` of shape `(batch_size, codes_length, num_quantizers)`, *optional*):
                Discrete code embeddings computed using `model.encode`.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Returns:
            A list with one trimmed waveform per batch item (wrapped in a 1-tuple or in a
            `Qwen3TTSTokenizerV2DecoderOutput`).
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # The decoder expects the quantizer axis before the time axis.
        audio_values = self.decoder.chunked_decode(audio_codes.transpose(1, 2)).squeeze(1)

        # Valid length per item = number of frames whose first-codebook entry is > 0.
        # NOTE(review): a genuine code value of 0 in the first codebook is counted as
        # padding by this test — confirm the padding convention used by callers.
        audio_lengths = (audio_codes[..., 0] > 0).sum(1) * self.decode_upsample_rate
        audio_values = [a[:l] for a, l in zip(audio_values, audio_lengths)]

        if not return_dict:
            return (
                audio_values,
            )

        return Qwen3TTSTokenizerV2DecoderOutput(audio_values)
|
| 1023 |
+
|
| 1024 |
+
|
| 1025 |
+
__all__ = ["Qwen3TTSTokenizerV2Model", "Qwen3TTSTokenizerV2PreTrainedModel"]
|
qwen_tts/core/tokenizer_25hz/configuration_qwen3_tts_tokenizer_v1.py
ADDED
|
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""Qwen3TTSTokenizerV1 model configuration"""
|
| 16 |
+
|
| 17 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 18 |
+
from transformers.utils import logging
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
logger = logging.get_logger(__name__)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Qwen3TTSTokenizerV1DecoderDiTConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of the Qwen3TTSTokenizerV1DecoderToken2WavDiT.
    It defines the architecture of the DiT model, which is used for generating mel-spectrograms from tokens.

    Args:
        hidden_size (`int`, *optional*, defaults to 1024):
            The dimension of the model.
        num_hidden_layers (`int`, *optional*, defaults to 22):
            The number of transformer blocks in the DiT model.
        num_attention_heads (`int`, *optional*, defaults to 16):
            The number of attention heads in each transformer block.
        ff_mult (`int`, *optional*, defaults to 2):
            The multiplier for the feedforward layer in each transformer block.
        emb_dim (`int`, *optional*, defaults to 512):
            The dimension of the embedding layer.
        head_dim (`int`, *optional*, defaults to 64):
            The dimension of each attention head.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            Presumably the base period of the rotary position embeddings; confirm against the DiT model code.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            Presumably the maximum sequence length supported by the position embeddings; confirm against the
            DiT model code.
        block_size (`int`, *optional*, defaults to 24):
            Block size consumed by the DiT model (implementation-dependent).
        look_ahead_layers (`list[int]`, *optional*, defaults to `[10]`):
            Layer indices consumed by the DiT model (implementation-dependent).
        look_backward_layers (`list[int]`, *optional*, defaults to `[0, 20]`):
            Layer indices consumed by the DiT model (implementation-dependent).
        repeats (`int`, *optional*, defaults to 2):
            The number of times the codec embeddings are repeated.
        num_embeds (`int`, *optional*, defaults to 8193):
            The number of unique embeddings in the codec.
        mel_dim (`int`, *optional*, defaults to 80):
            The dimension of the mel-spectrogram.
        dropout (`float`, *optional*, defaults to 0.1):
            The dropout rate for the transformer blocks.

        enc_emb_dim (`int`, *optional*, defaults to 192):
            The dimension of the pre-trained speaker embedding.
        enc_dim (`int`, *optional*, defaults to 128):
            The dimension of the encoder output.
        enc_channels (`list[int]`, *optional*, defaults to `[256, 256, 256, 256, 768]`):
            A list of output channels for each TDNN/SERes2Net layer in the encoder.
        enc_kernel_sizes (`list[int]`, *optional*, defaults to `[5, 3, 3, 3, 1]`):
            A list of kernel sizes for each layer in the encoder.
        enc_dilations (`list[int]`, *optional*, defaults to `[1, 2, 3, 4, 1]`):
            A list of dilations for each layer in the encoder.
        enc_attention_channels (`int`, *optional*, defaults to 64):
            The number of attention channels in the SqueezeExcitationBlock.
        enc_res2net_scale (`int`, *optional*, defaults to 2):
            The scale of the Res2Net block in the encoder.
        enc_se_channels (`int`, *optional*, defaults to 64):
            The number of output channels after squeeze in the SqueezeExcitationBlock.
    """

    model_type = "qwen3_tts_tokenizer_v1_decoder_dit"

    def __init__(
        self,
        hidden_size=1024,
        num_hidden_layers=22,
        num_attention_heads=16,
        ff_mult=2,
        emb_dim=512,
        head_dim=64,
        rope_theta=10000.0,
        max_position_embeddings=32768,
        block_size=24,
        look_ahead_layers=[10],
        look_backward_layers=[0, 20],
        repeats=2,
        num_embeds=8193,
        mel_dim=80,
        dropout=0.1,
        enc_emb_dim=192,
        enc_dim=128,
        enc_channels=[256, 256, 256, 256, 768],
        enc_kernel_sizes=[5, 3, 3, 3, 1],
        enc_dilations=[1, 2, 3, 4, 1],
        enc_attention_channels=64,
        enc_res2net_scale=2,
        enc_se_channels=64,
        **kwargs,
    ):
        def _own(value):
            # The list defaults in the signature are shared by every call. Store a
            # private copy so that mutating one config cannot alter the defaults
            # (or another config). Non-list values are stored unchanged.
            if isinstance(value, list):
                return [_own(item) for item in value]
            return value

        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.ff_mult = ff_mult
        self.emb_dim = emb_dim
        self.head_dim = head_dim
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.block_size = block_size
        self.look_ahead_layers = _own(look_ahead_layers)
        self.look_backward_layers = _own(look_backward_layers)
        self.repeats = repeats
        self.num_embeds = num_embeds
        self.mel_dim = mel_dim
        self.dropout = dropout
        self.enc_emb_dim = enc_emb_dim
        self.enc_dim = enc_dim
        self.enc_channels = _own(enc_channels)
        self.enc_kernel_sizes = _own(enc_kernel_sizes)
        self.enc_dilations = _own(enc_dilations)
        self.enc_attention_channels = enc_attention_channels
        self.enc_res2net_scale = enc_res2net_scale
        self.enc_se_channels = enc_se_channels
        super().__init__(**kwargs)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class Qwen3TTSTokenizerV1DecoderBigVGANConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of the Qwen3TTSTokenizerV1DecoderToken2WavBigVGAN module.
    It defines the architecture of the BigVGAN model, which is used for converting mel-spectrograms to waveforms.

    Args:
        mel_dim (`int`, *optional*, defaults to 80):
            The dimension of the mel-spectrogram.
        upsample_initial_channel (`int`, *optional*, defaults to 1536):
            The number of channels in the initial upsampling layer.
        resblock_kernel_sizes (`list[int]`, *optional*, defaults to `[3, 7, 11]`):
            A list of kernel sizes for each residual block.
        resblock_dilation_sizes (`list[list[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`):
            A list of dilation sizes for each residual block.
        upsample_rates (`list[int]`, *optional*, defaults to `[5, 3, 2, 2, 2, 2]`):
            A list of upsampling rates for each upsampling layer.
        upsample_kernel_sizes (`list[int]`, *optional*, defaults to `[11, 7, 4, 4, 4, 4]`):
            A list of kernel sizes for each upsampling layer.
    """

    model_type = "qwen3_tts_tokenizer_v1_decoder_bigvgan"

    def __init__(
        self,
        mel_dim=80,
        upsample_initial_channel=1536,
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        upsample_rates=[5, 3, 2, 2, 2, 2],
        upsample_kernel_sizes=[11, 7, 4, 4, 4, 4],
        **kwargs,
    ):
        def _own(value):
            # The list defaults in the signature are shared by every call. Store a
            # private (recursively copied) list so that mutating one config cannot
            # alter the defaults or another config. Non-list values pass through.
            if isinstance(value, list):
                return [_own(item) for item in value]
            return value

        self.mel_dim = mel_dim
        self.upsample_initial_channel = upsample_initial_channel
        self.resblock_kernel_sizes = _own(resblock_kernel_sizes)
        self.resblock_dilation_sizes = _own(resblock_dilation_sizes)
        self.upsample_rates = _own(upsample_rates)
        self.upsample_kernel_sizes = _own(upsample_kernel_sizes)
        super().__init__(**kwargs)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class Qwen3TTSTokenizerV1DecoderConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of the Qwen3TTSTokenizerV1 decoder. It groups the
    configurations of its two sub-modules.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        dit_config (`dict` or [`Qwen3TTSTokenizerV1DecoderDiTConfig`], *optional*):
            Configuration for the Diffusion Transformer (DiT) module responsible for generating mel-spectrograms.
            A mapping is expanded into the constructor of the DiT config; `None` selects its defaults.
        bigvgan_config (`dict` or [`Qwen3TTSTokenizerV1DecoderBigVGANConfig`], *optional*):
            Configuration for the BigVGAN module responsible for converting mel-spectrograms to waveforms.
            A mapping is expanded into the constructor of the BigVGAN config; `None` selects its defaults.
    """

    model_type = "qwen3_tts_tokenizer_v1_decoder"
    sub_configs = {
        "dit_config": Qwen3TTSTokenizerV1DecoderDiTConfig,
        "bigvgan_config": Qwen3TTSTokenizerV1DecoderBigVGANConfig,
    }

    def __init__(self, dit_config=None, bigvgan_config=None, **kwargs):
        if dit_config is None:
            dit_config = {}
        if bigvgan_config is None:
            bigvgan_config = {}
        # Already-built config objects are used as they are (expanding them with
        # `**` would raise a TypeError); anything else is treated as keyword arguments.
        if not isinstance(dit_config, Qwen3TTSTokenizerV1DecoderDiTConfig):
            dit_config = Qwen3TTSTokenizerV1DecoderDiTConfig(**dit_config)
        if not isinstance(bigvgan_config, Qwen3TTSTokenizerV1DecoderBigVGANConfig):
            bigvgan_config = Qwen3TTSTokenizerV1DecoderBigVGANConfig(**bigvgan_config)
        self.dit_config = dit_config
        self.bigvgan_config = bigvgan_config
        super().__init__(**kwargs)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class Qwen3TTSTokenizerV1EncoderConfig(PretrainedConfig):
    r"""
    Configuration class for the Qwen3TTSTokenizerV1 Encoder.

    The encoder typically takes mel-spectrogram features and produces high-level audio representations, then (optionally)
    applies an Audio-VQ module (e.g., GRVQ) to discretize continuous representations into codes.

    Args:
        n_mels (`int`, *optional*, defaults to 128):
            Number of mel bins in the input mel-spectrogram.
        n_ctx (`int`, *optional*, defaults to 1500):
            Maximum input sequence length (in frames/tokens) for the encoder.
        n_state (`int`, *optional*, defaults to 1280):
            Hidden size (model dimension) of the encoder transformer.
        n_head (`int`, *optional*, defaults to 20):
            Number of attention heads in each transformer layer.
        n_layer (`int`, *optional*, defaults to 32):
            Number of transformer layers.
        n_window (`int`, *optional*, defaults to 100):
            Window size used by the model for local attention / chunking (implementation-dependent).
        output_dim (`int`, *optional*, defaults to 3584):
            Output feature dimension produced by the encoder head (before/after projection, implementation-dependent).

        grad_checkpointing (`bool`, *optional*, defaults to `False`):
            Whether to enable gradient checkpointing to reduce memory usage during training.
        enable_mp (`bool`, *optional*, defaults to `False`):
            Whether to enable model parallel features (implementation-dependent).
        audio_sequence_parallel (`bool`, *optional*, defaults to `False`):
            Whether to enable sequence parallelism for audio branch (implementation-dependent).

        audio_vq_type (`str`, *optional*, defaults to `"GRVQ"`):
            Type of audio vector-quantization module. Common choices: `"GRVQ"`, `"RVQ"`, etc.
        audio_vq_layers (`int`, *optional*, defaults to 6):
            Number of VQ layers / quantizers (e.g., number of residual quantizers for RVQ/GRVQ-like designs).
        audio_vq_codebook_size (`int`, *optional*, defaults to 32768):
            Size of each codebook (number of entries).
        audio_vq_codebook_dim (`int`, *optional*, defaults to 1280):
            Dimension of codebook vectors (often equals encoder hidden size).
        audio_vq_pe (`bool`, *optional*, defaults to `True`):
            Whether to use positional encoding (or position embeddings) inside the VQ module.
        audio_vq_ds_rate (`int`, *optional*, defaults to 2):
            Downsampling rate applied before VQ (e.g., temporal downsample factor).
    """

    model_type = "qwen3_tts_tokenizer_v1_encoder"

    def __init__(
        self,
        n_mels=128,
        n_ctx=1500,
        n_state=1280,
        n_head=20,
        n_layer=32,
        n_window=100,
        output_dim=3584,
        grad_checkpointing=False,
        enable_mp=False,
        audio_sequence_parallel=False,
        audio_vq_type="GRVQ",
        audio_vq_layers=6,
        audio_vq_codebook_size=32768,
        audio_vq_codebook_dim=1280,
        audio_vq_pe=True,
        audio_vq_ds_rate=2,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Transformer encoder geometry.
        self.n_mels, self.n_ctx, self.n_state = n_mels, n_ctx, n_state
        self.n_head, self.n_layer, self.n_window = n_head, n_layer, n_window
        self.output_dim = output_dim

        # Training / parallelism switches.
        self.grad_checkpointing, self.enable_mp = grad_checkpointing, enable_mp
        self.audio_sequence_parallel = audio_sequence_parallel

        # Audio vector-quantization settings.
        self.audio_vq_type, self.audio_vq_layers = audio_vq_type, audio_vq_layers
        self.audio_vq_codebook_size, self.audio_vq_codebook_dim = audio_vq_codebook_size, audio_vq_codebook_dim
        self.audio_vq_pe, self.audio_vq_ds_rate = audio_vq_pe, audio_vq_ds_rate
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
class Qwen3TTSTokenizerV1Config(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a `Qwen3TTSTokenizerV1Model`. It is used to
    instantiate the model according to the specified sub-model configurations, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        encoder_config (`dict` or [`Qwen3TTSTokenizerV1EncoderConfig`], *optional*):
            Configuration of the underlying encoder sub-model. A `dict` is expanded into
            [`Qwen3TTSTokenizerV1EncoderConfig`]; an existing instance is stored as-is. When `None`, the encoder
            config is built with its default values.
        decoder_config (`dict` or [`Qwen3TTSTokenizerV1DecoderConfig`], *optional*):
            Configuration of the underlying decoder sub-model. A `dict` is expanded into
            [`Qwen3TTSTokenizerV1DecoderConfig`]; an existing instance is stored as-is. When `None`, the decoder
            config is built with its default values.
        input_sample_rate (`int`, *optional*, defaults to 24000):
            Stored unchanged on the config. Presumably the sample rate (Hz) of audio given to the encoder
            (TODO confirm against the model code).
        output_sample_rate (`int`, *optional*, defaults to 24000):
            Stored unchanged on the config. Presumably the sample rate (Hz) of decoded audio (TODO confirm).
        decode_upsample_rate (`int`, *optional*, defaults to 1920):
            Stored unchanged on the config. Presumably the number of output samples per code (TODO confirm).
        encode_downsample_rate (`int`, *optional*, defaults to 1920):
            Stored unchanged on the config. Presumably the number of input samples per code (TODO confirm).
    """

    model_type = "qwen3_tts_tokenizer_25hz"
    sub_configs = {
        "encoder_config": Qwen3TTSTokenizerV1EncoderConfig,
        "decoder_config": Qwen3TTSTokenizerV1DecoderConfig,
    }

    def __init__(
        self,
        encoder_config=None,
        decoder_config=None,
        input_sample_rate=24000,
        output_sample_rate=24000,
        decode_upsample_rate=1920,
        encode_downsample_rate=1920,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if encoder_config is None:
            encoder_config = {}
            logger.info("encoder_config is None. Initializing encoder with default values")
        if decoder_config is None:
            decoder_config = {}
            logger.info("decoder_config is None. Initializing decoder with default values")

        # Accept already-built sub-configs as well as plain dicts: `Cls(**cfg)` only works for mappings.
        self.encoder_config = (
            encoder_config
            if isinstance(encoder_config, Qwen3TTSTokenizerV1EncoderConfig)
            else Qwen3TTSTokenizerV1EncoderConfig(**encoder_config)
        )
        self.decoder_config = (
            decoder_config
            if isinstance(decoder_config, Qwen3TTSTokenizerV1DecoderConfig)
            else Qwen3TTSTokenizerV1DecoderConfig(**decoder_config)
        )

        self.input_sample_rate = input_sample_rate
        self.output_sample_rate = output_sample_rate
        self.decode_upsample_rate = decode_upsample_rate
        self.encode_downsample_rate = encode_downsample_rate
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
__all__ = [
|
| 327 |
+
"Qwen3TTSTokenizerV1Config",
|
| 328 |
+
"Qwen3TTSTokenizerV1EncoderConfig",
|
| 329 |
+
"Qwen3TTSTokenizerV1DecoderConfig",
|
| 330 |
+
"Qwen3TTSTokenizerV1DecoderBigVGANConfig",
|
| 331 |
+
"Qwen3TTSTokenizerV1DecoderDiTConfig"
|
| 332 |
+
]
|
qwen_tts/core/tokenizer_25hz/modeling_qwen3_tts_tokenizer_v1.py
ADDED
|
@@ -0,0 +1,1528 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch Qwen3TTSTokenizerV1 model."""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Optional, Union, List
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
from torch import nn
|
| 24 |
+
from torch.nn import Parameter
|
| 25 |
+
from torch.nn import functional as F
|
| 26 |
+
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 27 |
+
from transformers.utils import ModelOutput, auto_docstring, logging
|
| 28 |
+
from transformers.utils.hub import cached_file
|
| 29 |
+
|
| 30 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 31 |
+
|
| 32 |
+
from .vq.whisper_encoder import get_mel_audio, get_T_after_cnn
|
| 33 |
+
from .vq.speech_vq import WhisperEncoderVQ, XVectorExtractor
|
| 34 |
+
|
| 35 |
+
from .configuration_qwen3_tts_tokenizer_v1 import (
|
| 36 |
+
Qwen3TTSTokenizerV1Config,
|
| 37 |
+
Qwen3TTSTokenizerV1EncoderConfig,
|
| 38 |
+
Qwen3TTSTokenizerV1DecoderConfig,
|
| 39 |
+
Qwen3TTSTokenizerV1DecoderBigVGANConfig,
|
| 40 |
+
Qwen3TTSTokenizerV1DecoderDiTConfig
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
logger = logging.get_logger(__name__)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@dataclass
@auto_docstring
class Qwen3TTSTokenizerV1EncoderOutput(ModelOutput):
    r"""
    audio_codes (`List[torch.LongTensor]`):
        Discrete code embeddings computed using `model.encode`, each tensor has shape (codes_length_i,).
    xvectors (`List[torch.FloatTensor]`):
        X-vector embeddings computed using `model.encode`, each tensor has shape (xvector_dim,).
    ref_mels (`List[torch.FloatTensor]`):
        Reference mel spectrogram computed using `model.encode`, each tensor has shape (mel_length_i, mel_dim,).
    """

    # Every field defaults to None, so the annotations are Optional.
    audio_codes: Optional[List[torch.LongTensor]] = None
    xvectors: Optional[List[torch.FloatTensor]] = None
    ref_mels: Optional[List[torch.FloatTensor]] = None
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@dataclass
@auto_docstring
class Qwen3TTSTokenizerV1DecoderOutput(ModelOutput):
    r"""
    audio_values (`List[torch.FloatTensor]`):
        Decoded audio values, obtained using the decoder part of Qwen3TTSTokenizerV1.
        Each tensor has shape (segment_length_i).
    """

    # The field defaults to None, so the annotation is Optional.
    audio_values: Optional[List[torch.FloatTensor]] = None
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@auto_docstring
class Qwen3TTSTokenizerV1DecoderPreTrainedModel(PreTrainedModel):
    # `PreTrainedModel` subclass bound to `Qwen3TTSTokenizerV1DecoderConfig`.
    # It declares class-level settings only; their meaning is defined by the
    # transformers `PreTrainedModel` machinery.
    config: Qwen3TTSTokenizerV1DecoderConfig
    base_model_prefix = "model"
    _skip_keys_device_placement = "past_key_values"

    # Capability flags.
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_attention_backend = True
    _can_compile_fullgraph = False
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
@auto_docstring
class Qwen3TTSTokenizerV1EncoderPreTrainedModel(PreTrainedModel):
    # `PreTrainedModel` subclass bound to `Qwen3TTSTokenizerV1EncoderConfig`.
    # It declares class-level settings only; their meaning is defined by the
    # transformers `PreTrainedModel` machinery.
    config: Qwen3TTSTokenizerV1EncoderConfig
    base_model_prefix = "model"
    _skip_keys_device_placement = "past_key_values"

    # Capability flags.
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_attention_backend = True
    _can_compile_fullgraph = False
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class Qwen3TTSTokenizerV1DecoderDiTRotaryEmbedding(nn.Module):
    """Build cos/sin rotary tables for the first two dimensions of an input.

    `inv_freq` holds `1 / base ** (2i / dim)` for `i` in `range(dim // 2)`.
    `forward` returns two tensors of shape `(x.shape[0], x.shape[1], 2 * len(inv_freq))`
    in which every frequency appears twice in adjacent slots.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, dim, base=10000):
        super().__init__()

        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (base**exponents))

    def forward(self, x):
        """Return `(cos, sin)` tables cast to `x.dtype`; only the shape, device and dtype of `x` are used."""
        batch_size, seq_len = x.shape[:2]
        positions = torch.arange(seq_len, device=x.device)

        # "mps" is mapped to "cpu" for the autocast context; autocast is disabled inside it.
        autocast_device = "cpu" if x.device.type == "mps" else x.device.type
        with torch.autocast(device_type=autocast_device, enabled=False):
            # (seq_len, 1) @ (1, dim // 2) -> one angle per position/frequency pair
            angles = positions.unsqueeze(1).float() @ self.inv_freq.unsqueeze(0).float()
            # duplicate each frequency into adjacent slots: [a, b] -> [a, a, b, b]
            angles = angles.repeat_interleave(2, dim=-1)
            angles = angles.unsqueeze(0).repeat(batch_size, 1, 1)
            cos, sin = angles.cos(), angles.sin()

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class TimeDelayNetBlock(nn.Module):
    """A dilated `nn.Conv1d` with "same" reflect padding, followed by ReLU."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        dilation,
    ):
        super().__init__()
        conv_options = {
            "kernel_size": kernel_size,
            "dilation": dilation,
            "padding": "same",
            "padding_mode": "reflect",
        }
        self.conv = nn.Conv1d(in_channels, out_channels, **conv_options)
        self.activation = nn.ReLU()

    def forward(self, hidden_states: torch.Tensor):
        """Apply the convolution, then the ReLU, to `hidden_states`."""
        convolved = self.conv(hidden_states)
        return self.activation(convolved)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class Res2NetBlock(torch.nn.Module):
    """Split the channels into `scale` groups and process them hierarchically.

    The first group passes through unchanged. The second goes through its own
    `TimeDelayNetBlock`. Each later group is summed with the previous group's
    output before going through its own block. All outputs are concatenated
    back along the channel dimension.
    """

    def __init__(self, in_channels, out_channels, scale=8, kernel_size=3, dilation=1):
        super().__init__()

        branch_in = in_channels // scale
        branch_out = out_channels // scale

        # one block per group except the first, which is an identity path
        branches = []
        for _ in range(scale - 1):
            branches.append(
                TimeDelayNetBlock(
                    branch_in,
                    branch_out,
                    kernel_size=kernel_size,
                    dilation=dilation,
                )
            )
        self.blocks = nn.ModuleList(branches)
        self.scale = scale

    def forward(self, hidden_states):
        """Run the hierarchical group processing on `hidden_states` (channels on dim 1)."""
        groups = torch.chunk(hidden_states, self.scale, dim=1)

        collected = [groups[0]]
        previous = None
        for branch, group in zip(self.blocks, groups[1:]):
            # the first processed group has no predecessor output to add
            branch_input = group if previous is None else group + previous
            previous = branch(branch_input)
            collected.append(previous)

        return torch.cat(collected, dim=1)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class SqueezeExcitationBlock(nn.Module):
    """Rescale the input by a sigmoid gate derived from its mean over dim 2.

    The mean is passed through two kernel-size-1 convolutions, with a ReLU
    after the first and a sigmoid after the second.
    """

    def __init__(self, in_channels, se_channels, out_channels):
        super().__init__()

        pointwise = {"kernel_size": 1, "padding": "same", "padding_mode": "reflect"}
        self.conv1 = nn.Conv1d(in_channels, se_channels, **pointwise)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(se_channels, out_channels, **pointwise)
        self.sigmoid = nn.Sigmoid()

    def forward(self, hidden_states):
        """Return `hidden_states` multiplied by the gate."""
        squeezed = hidden_states.mean(dim=2, keepdim=True)
        gate = self.sigmoid(self.conv2(self.relu(self.conv1(squeezed))))
        return hidden_states * gate
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class AttentiveStatisticsPooling(nn.Module):
    """Attentive statistics pooling computed per channel.

    Returns the attention-weighted mean and standard deviation of the input,
    concatenated along the channel dimension.
    """

    def __init__(self, channels, attention_channels=128):
        super().__init__()

        self.eps = 1e-12
        self.tdnn = TimeDelayNetBlock(channels * 3, attention_channels, 1, 1)
        self.tanh = nn.Tanh()
        self.conv = nn.Conv1d(
            attention_channels,
            channels,
            kernel_size=1,
            padding="same",
            padding_mode="reflect",
        )

    def _length_to_mask(self, length, max_len=None, dtype=None, device=None):
        """Create a binary mask for each sequence.

        Reference: https://discuss.pytorch.org/t/how-to-generate-variable-length-mask/23397/3

        Arguments
        ---------
        length : torch.LongTensor
            Length of each sequence in the batch. Must be 1D.
        max_len : int
            Mask width (size of the second dimension). Defaults to `length.max()`.
        dtype : torch.dtype, default: None
            The dtype of the generated mask.
        device: torch.device, default: None
            The device to put the mask on.

        Returns
        -------
        mask : tensor
            `mask[i, j]` is true (non-zero) exactly when `j < length[i]`.
        """
        if max_len is None:
            max_len = length.max().long().item()

        positions = torch.arange(max_len, device=length.device, dtype=length.dtype)
        within_length = positions.expand(len(length), max_len) < length.unsqueeze(1)
        return torch.as_tensor(within_length, dtype=dtype, device=device)

    def _compute_statistics(self, x, m, dim=2):
        """Return the `m`-weighted mean and standard deviation of `x` along `dim`."""
        mean = (m * x).sum(dim)
        centered_sq = (x - mean.unsqueeze(dim)).pow(2)
        # clamp the variance at eps before taking the square root
        std = torch.sqrt((m * centered_sq).sum(dim).clamp(self.eps))
        return mean, std

    def forward(self, hidden_states):
        """Pool `hidden_states` over its last dimension; the result keeps a trailing dimension of size 1."""
        batch = hidden_states.shape[0]
        seq_length = hidden_states.shape[-1]

        # Every sequence is treated as full length, so the [N, 1, L] mask is all ones.
        full_lengths = torch.ones(batch, device=hidden_states.device) * seq_length
        mask = self._length_to_mask(
            full_lengths,
            max_len=seq_length,
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        ).unsqueeze(1)

        # Global (uniformly weighted) statistics widen the temporal context
        # seen by the attention layers.
        uniform_weights = mask / mask.sum(dim=2, keepdim=True)
        global_mean, global_std = self._compute_statistics(hidden_states, uniform_weights)
        context = torch.cat(
            [
                hidden_states,
                global_mean.unsqueeze(2).repeat(1, 1, seq_length),
                global_std.unsqueeze(2).repeat(1, 1, seq_length),
            ],
            dim=1,
        )

        # Attention scores, with masked positions excluded from the softmax.
        scores = self.conv(self.tanh(self.tdnn(context)))
        scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = F.softmax(scores, dim=2)

        # Attention-weighted statistics, concatenated as [mean, std].
        mean, std = self._compute_statistics(hidden_states, weights)
        return torch.cat((mean, std), dim=1).unsqueeze(2)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
class SqueezeExcitationRes2NetBlock(nn.Module):
    """ECAPA-TDNN building block: TDNN -> Res2Net -> TDNN -> squeeze-excitation.

    The block input is added back to the result, so input and output shapes
    must match. `kernel_size` and `dilation` are forwarded to the Res2Net
    stage only; both TDNN stages use kernel size 1 and dilation 1.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        res2net_scale=8,
        se_channels=128,
        kernel_size=1,
        dilation=1,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.tdnn1 = TimeDelayNetBlock(in_channels, out_channels, kernel_size=1, dilation=1)
        self.res2net_block = Res2NetBlock(out_channels, out_channels, res2net_scale, kernel_size, dilation)
        self.tdnn2 = TimeDelayNetBlock(out_channels, out_channels, kernel_size=1, dilation=1)
        self.se_block = SqueezeExcitationBlock(out_channels, se_channels, out_channels)

    def forward(self, hidden_state):
        """Apply the four stages in order and add the block input as a residual."""
        transformed = hidden_state
        for stage in (self.tdnn1, self.res2net_block, self.tdnn2, self.se_block):
            transformed = stage(transformed)
        return transformed + hidden_state
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
class ECAPA_TimeDelayNet(torch.nn.Module):
|
| 343 |
+
"""An implementation of the speaker embedding model in a paper.
|
| 344 |
+
"ECAPA-TDNN: Emphasized Channel Attention, Propagation and Aggregation in
|
| 345 |
+
TDNN Based Speaker Verification" (https://huggingface.co/papers/2005.07143).
|
| 346 |
+
"""
|
| 347 |
+
|
| 348 |
+
def __init__(self, config: Qwen3TTSTokenizerV1DecoderBigVGANConfig):
|
| 349 |
+
super().__init__()
|
| 350 |
+
if len(config.enc_channels) != len(config.enc_kernel_sizes) or len(config.enc_channels) != len(
|
| 351 |
+
config.enc_dilations
|
| 352 |
+
):
|
| 353 |
+
raise ValueError("enc_channels, enc_kernel_sizes and enc_dilations should have same length")
|
| 354 |
+
self.channels = config.enc_channels
|
| 355 |
+
self.blocks = nn.ModuleList()
|
| 356 |
+
|
| 357 |
+
# The initial TDNN layer
|
| 358 |
+
self.blocks.append(
|
| 359 |
+
TimeDelayNetBlock(
|
| 360 |
+
config.mel_dim,
|
| 361 |
+
config.enc_channels[0],
|
| 362 |
+
config.enc_kernel_sizes[0],
|
| 363 |
+
config.enc_dilations[0],
|
| 364 |
+
)
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
# SE-Res2Net layers
|
| 368 |
+
for i in range(1, len(config.enc_channels) - 1):
|
| 369 |
+
self.blocks.append(
|
| 370 |
+
SqueezeExcitationRes2NetBlock(
|
| 371 |
+
config.enc_channels[i - 1],
|
| 372 |
+
config.enc_channels[i],
|
| 373 |
+
res2net_scale=config.enc_res2net_scale,
|
| 374 |
+
se_channels=config.enc_se_channels,
|
| 375 |
+
kernel_size=config.enc_kernel_sizes[i],
|
| 376 |
+
dilation=config.enc_dilations[i],
|
| 377 |
+
)
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
# Multi-layer feature aggregation
|
| 381 |
+
self.mfa = TimeDelayNetBlock(
|
| 382 |
+
config.enc_channels[-1],
|
| 383 |
+
config.enc_channels[-1],
|
| 384 |
+
config.enc_kernel_sizes[-1],
|
| 385 |
+
config.enc_dilations[-1],
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
# Attentive Statistical Pooling
|
| 389 |
+
self.asp = AttentiveStatisticsPooling(
|
| 390 |
+
config.enc_channels[-1],
|
| 391 |
+
attention_channels=config.enc_attention_channels,
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
# Final linear transformation
|
| 395 |
+
self.fc = nn.Conv1d(
|
| 396 |
+
in_channels=config.enc_channels[-1] * 2,
|
| 397 |
+
out_channels=config.enc_dim,
|
| 398 |
+
kernel_size=1,
|
| 399 |
+
padding="same",
|
| 400 |
+
padding_mode="reflect",
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
def forward(self, hidden_states):
|
| 404 |
+
# Minimize transpose for efficiency
|
| 405 |
+
hidden_states = hidden_states.transpose(1, 2)
|
| 406 |
+
|
| 407 |
+
hidden_states_list = []
|
| 408 |
+
for layer in self.blocks:
|
| 409 |
+
hidden_states = layer(hidden_states)
|
| 410 |
+
hidden_states_list.append(hidden_states)
|
| 411 |
+
|
| 412 |
+
# Multi-layer feature aggregation
|
| 413 |
+
hidden_states = torch.cat(hidden_states_list[1:], dim=1)
|
| 414 |
+
hidden_states = self.mfa(hidden_states)
|
| 415 |
+
|
| 416 |
+
# Attentive Statistical Pooling
|
| 417 |
+
hidden_states = self.asp(hidden_states)
|
| 418 |
+
|
| 419 |
+
# Final linear transformation
|
| 420 |
+
hidden_states = self.fc(hidden_states)
|
| 421 |
+
|
| 422 |
+
hidden_states = hidden_states.squeeze(-1)
|
| 423 |
+
return hidden_states
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
class DiTInputEmbedding(nn.Module):
|
| 427 |
+
def __init__(self, config: Qwen3TTSTokenizerV1DecoderBigVGANConfig):
|
| 428 |
+
super().__init__()
|
| 429 |
+
self.proj = nn.Linear(
|
| 430 |
+
config.mel_dim + config.enc_dim + config.enc_emb_dim + config.emb_dim,
|
| 431 |
+
config.hidden_size,
|
| 432 |
+
)
|
| 433 |
+
self.spk_encoder = ECAPA_TimeDelayNet(config)
|
| 434 |
+
|
| 435 |
+
def forward(
|
| 436 |
+
self,
|
| 437 |
+
hidden_states: torch.Tensor,
|
| 438 |
+
speaker_embedding: torch.Tensor,
|
| 439 |
+
condition_vector: torch.Tensor,
|
| 440 |
+
code_embed: torch.Tensor,
|
| 441 |
+
drop_audio_cond: Optional[bool] = False,
|
| 442 |
+
code_embed_uncond: Optional[bool] = None,
|
| 443 |
+
apply_cfg: Optional[bool] = True,
|
| 444 |
+
):
|
| 445 |
+
if apply_cfg:
|
| 446 |
+
hidden_states = torch.cat([hidden_states, hidden_states], dim=0)
|
| 447 |
+
speaker_embedding = torch.cat([speaker_embedding, torch.zeros_like(speaker_embedding)], dim=0)
|
| 448 |
+
condition_vector = torch.cat([condition_vector, torch.zeros_like(condition_vector)], dim=0)
|
| 449 |
+
code_embed = torch.cat([code_embed, code_embed_uncond], dim=0)
|
| 450 |
+
elif drop_audio_cond: # cfg for cond audio
|
| 451 |
+
condition_vector = torch.zeros_like(condition_vector)
|
| 452 |
+
speaker_embedding = torch.zeros_like(speaker_embedding)
|
| 453 |
+
condition_vector = self.spk_encoder(condition_vector).unsqueeze(1).repeat(1, hidden_states.size(1), 1)
|
| 454 |
+
hidden_states = self.proj(torch.cat((hidden_states, condition_vector, code_embed, speaker_embedding), dim=-1))
|
| 455 |
+
|
| 456 |
+
return hidden_states
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
# Transformer backbone using DiT blocks
|
| 460 |
+
class DiTCodecEmbedding(nn.Module):
    """Embed codec indices and repeat every embedding `repeats` times along dim 1.

    The table has `codec_num_embeds + 1` rows. With `drop_code=True` every
    index is replaced by 0 before the lookup.
    """

    def __init__(self, codec_num_embeds, codec_dim, repeats):
        super().__init__()
        self.repeats = repeats
        self.codec_embed = nn.Embedding(codec_num_embeds + 1, codec_dim)

    def forward(self, code, drop_code=False):
        """Look up `code` (or all-zero indices when `drop_code`) and repeat along dim 1."""
        indices = torch.zeros_like(code) if drop_code else code
        embedded = self.codec_embed(indices)
        return embedded.repeat_interleave(self.repeats, dim=1)
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
# AdaLayerNormZero
|
| 476 |
+
# return with modulated x for attn input, and params for later mlp modulation
|
| 477 |
+
class AdaLayerNormZero(nn.Module):
    """Adaptive layer norm that yields modulated input plus four modulation terms.

    The conditioning embedding is passed through SiLU and a linear layer of
    width ``6 * dim``, then split into six equal chunks along dim 1.
    """

    def __init__(self, dim):
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, dim * 6)

        # LayerNorm without learnable affine parameters; scale/shift come from ``emb``.
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, hidden_states, emb=None):
        """Normalize and modulate ``hidden_states`` with parameters derived from ``emb``.

        Args:
            hidden_states: Tensor to normalize over its last dimension.
            emb: Conditioning embedding. Despite the ``None`` default it is
                fed straight into the SiLU, so it must be a tensor.

        Returns:
            Tuple ``(modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp)``.
        """
        modulation = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=1)

        normed = self.norm(hidden_states)
        modulated = normed * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
# AdaLayerNormZero for final layer
|
| 495 |
+
# return only with modulated x for attn input, cuz no more mlp modulation
|
| 496 |
+
class AdaLayerNormZero_Final(nn.Module):
    """Adaptive layer norm for the last layer: returns only the modulated input.

    The conditioning embedding is projected to ``2 * dim`` and split into a
    scale chunk followed by a shift chunk.
    """

    def __init__(self, dim):
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, dim * 2)

        # LayerNorm without learnable affine parameters; scale/shift come from ``emb``.
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, hidden_states, emb):
        """Return ``norm(hidden_states) * (1 + scale) + shift``.

        ``scale`` and ``shift`` are derived from ``emb`` and broadcast over a
        new axis inserted at position 1.
        """
        projected = self.linear(self.silu(emb))
        scale, shift = projected.chunk(2, dim=1)

        normed = self.norm(hidden_states)
        return normed * (1 + scale).unsqueeze(1) + shift.unsqueeze(1)
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
# FeedForward
|
| 514 |
+
class DiTMLP(nn.Module):
    """Feed-forward block: Linear -> tanh-approximated GELU -> Dropout -> Linear.

    Args:
        dim: Input and output width.
        mult: Multiplier for the hidden width (``int(dim * mult)``).
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        hidden_width = int(dim * mult)

        # Kept as an indexed ModuleList so parameter names stay ``ff.0`` / ``ff.3``.
        stages = [
            nn.Linear(dim, hidden_width),
            nn.GELU(approximate="tanh"),
            nn.Dropout(dropout),
            nn.Linear(hidden_width, dim),
        ]
        self.ff = nn.ModuleList(stages)

    def forward(self, hidden_states):
        """Apply the stages of ``self.ff`` in order and return the result."""
        output = hidden_states
        for stage in self.ff:
            output = stage(output)
        return output
|
| 532 |
+
|
| 533 |
+
|
| 534 |
+
# Modified from Llama with a different rotate function; will be fixed in the next release
|
| 535 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embedding to the query and key tensors.

    The rotation works on interleaved pairs of the last axis: each consecutive
    pair ``(a, b)`` is mapped to ``(-b, a)`` before being weighted by ``sin``.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis along which `cos` and `sin` are unsqueezed so they broadcast
            against `q` and `k`. For example, with `cos`/`sin` of shape
            [batch_size, seq_len, head_dim] and `q`/`k` of shape
            [batch_size, heads, seq_len, head_dim], use `unsqueeze_dim=1`; for
            `q`/`k` of shape [batch_size, seq_len, heads, head_dim], use
            `unsqueeze_dim=2`.

    Returns:
        `tuple(torch.Tensor)` with the rotated query and key tensors.
    """

    def rotate_half_codec(x):
        # Group the last axis into (a, b) pairs, then emit (-b, a) for each pair.
        pairs = x.reshape(*x.shape[:-1], -1, 2)
        rotated = torch.stack((-pairs[..., 1], pairs[..., 0]), dim=-1)
        return rotated.reshape(*rotated.shape[:-2], -1)

    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(tensor):
        return (tensor * cos_b) + (rotate_half_codec(tensor) * sin_b)

    return _rotate(q), _rotate(k)
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
class DiTAttention(nn.Module):
    """Non-causal multi-head self-attention with rotary position embedding.

    NOTE(review): the constructor is annotated with the BigVGAN config class,
    but it reads ``hidden_size``, ``num_attention_heads``, ``head_dim`` and
    ``dropout`` - confirm that the annotation is the intended one.
    """

    def __init__(self, config: Qwen3TTSTokenizerV1DecoderBigVGANConfig):
        super().__init__()

        self.config = config
        self.dim = config.hidden_size
        self.heads = config.num_attention_heads
        self.inner_dim = config.head_dim * config.num_attention_heads
        self.dropout = config.dropout
        self.is_causal = False

        # Separate projections for query, key and value.
        self.to_q = nn.Linear(config.hidden_size, self.inner_dim)
        self.to_k = nn.Linear(config.hidden_size, self.inner_dim)
        self.to_v = nn.Linear(config.hidden_size, self.inner_dim)

        # Output projection followed by dropout, addressed by index in forward().
        self.to_out = nn.ModuleList([nn.Linear(self.inner_dim, config.hidden_size), nn.Dropout(config.dropout)])

    def forward(
        self,
        hidden_states,  # noised input x
        position_embeddings=None,  # rotary position embedding for x
        attention_mask=None,
    ) -> torch.Tensor:
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: Input tensor; dim 0 is treated as the batch axis.
            position_embeddings: ``(cos, sin)`` pair. Despite the ``None``
                default it is unpacked unconditionally, so it is required.
            attention_mask: Mask forwarded unchanged to the attention backend.

        Returns:
            The attention result after the output projection and dropout.
        """
        bsz = hidden_states.shape[0]

        projected = (self.to_q(hidden_states), self.to_k(hidden_states), self.to_v(hidden_states))

        # Per-head width is derived from the key projection's last dimension.
        per_head = projected[1].shape[-1] // self.heads
        query, key, value = (
            tensor.view(bsz, -1, self.heads, per_head).transpose(1, 2) for tensor in projected
        )

        # Rotary position embedding is applied to the full query/key tensors here.
        # NOTE(review): an earlier comment on this code said only the first head
        # receives RoPE because of how the model was trained - confirm which
        # behavior matches the released weights.
        cos, sin = position_embeddings
        query, key = apply_rotary_pos_emb(query, key, cos, sin)

        # Dispatch to the attention backend selected in the config.
        attend = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        attended, _ = attend(
            self,
            query,
            key,
            value,
            attention_mask=attention_mask,
            is_causal=False,
        )

        # Merge heads back into one feature axis and match the query dtype.
        merged = attended.reshape(bsz, -1, self.heads * per_head).to(query.dtype)

        # Linear projection, then dropout.
        projected_out = self.to_out[0](merged)
        return self.to_out[1](projected_out)
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
# time step conditioning embedding
|
| 634 |
+
class SinusPositionEmbedding(nn.Module):
    """Sinusoidal embedding of a 1-D tensor of scalar values (e.g. time steps).

    Args:
        dim: Output width; the first half holds sines, the second half cosines.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, hidden_states, scale=1000):
        """Return ``[sin(scale * x * f), cos(scale * x * f)]`` concatenated on the last axis.

        ``f`` is a geometric frequency ladder ``exp(-i * log(10000) / (half_dim - 1))``
        for ``i`` in ``range(half_dim)``. The result is cast to the dtype of
        ``hidden_states``.
        """
        half_dim = self.dim // 2
        log_step = math.log(10000) / (half_dim - 1)
        positions = torch.arange(half_dim, device=hidden_states.device).float()
        frequencies = torch.exp(positions * -log_step)

        angles = scale * hidden_states.unsqueeze(1) * frequencies.unsqueeze(0)
        features = torch.cat((angles.sin(), angles.cos()), dim=-1)
        return features.type_as(hidden_states)
|
| 647 |
+
|
| 648 |
+
|
| 649 |
+
class DiTTimestepEmbedding(nn.Module):
    """Map time steps to conditioning vectors: sinusoidal features followed by an MLP.

    Args:
        dim: Output width of the MLP.
        freq_embed_dim: Width of the sinusoidal features fed into the MLP.
    """

    def __init__(self, dim, freq_embed_dim=256):
        super().__init__()
        self.time_embed = SinusPositionEmbedding(freq_embed_dim)
        # Linear -> SiLU -> Linear, stored as an indexed ModuleList.
        self.time_mlp = nn.ModuleList([nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim)])

    def forward(self, timestep):
        """Embed ``timestep`` and return the MLP output (one row per time step)."""
        # Cast the sinusoidal features back to the dtype of the incoming time steps.
        features = self.time_embed(timestep).to(timestep.dtype)
        for stage in self.time_mlp:
            features = stage(features)  # b d
        return features
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
class DiTDecoderLayer(nn.Module):
    """One DiT block: adaptive-norm attention followed by a modulated feed-forward.

    Args:
        config: Model configuration (``hidden_size``, ``ff_mult`` and
            ``dropout`` are read here; the attention module reads more).
        look_ahead_block: Largest positive block offset a position may attend to.
        look_backward_block: Largest negative block offset a position may attend to.
    """

    def __init__(self, config: Qwen3TTSTokenizerV1DecoderBigVGANConfig, look_ahead_block=0, look_backward_block=0):
        super().__init__()
        self.attn_norm = AdaLayerNormZero(config.hidden_size)

        self.attn = DiTAttention(config)
        self.look_ahead_block = look_ahead_block
        self.look_backward_block = look_backward_block
        self.ff_norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False, eps=1e-6)
        self.ff = DiTMLP(dim=config.hidden_size, mult=config.ff_mult, dropout=config.dropout)

    def forward(
        self, hidden_states, timestep, position_embeddings=None, block_diff=None
    ):  # x: noised input, t: time embedding
        """Apply the block to ``hidden_states`` conditioned on ``timestep``.

        Args:
            hidden_states: Noised input.
            timestep: Time embedding used to derive all modulation terms.
            position_embeddings: Rotary ``(cos, sin)`` pair passed to attention.
            block_diff: Tensor of block-index differences. Despite the ``None``
                default it is compared unconditionally, so it is required.

        Returns:
            The updated hidden states.
        """
        # Pre-norm and modulation for the attention input.
        modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(hidden_states, emb=timestep)

        # A position pair is visible when its block difference lies in
        # [-look_backward_block, look_ahead_block].
        not_too_far_back = block_diff >= -float(self.look_backward_block)
        not_too_far_ahead = block_diff <= float(self.look_ahead_block)
        visibility = not_too_far_back & not_too_far_ahead

        attended = self.attn(
            hidden_states=modulated,
            position_embeddings=position_embeddings,
            attention_mask=visibility,
        )

        # Gated residual around attention.
        hidden_states = hidden_states + gate_msa.unsqueeze(1) * attended

        # Modulated feed-forward with its own gated residual.
        ff_input = self.ff_norm(hidden_states) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        hidden_states = hidden_states + gate_mlp.unsqueeze(1) * self.ff(ff_input)

        return hidden_states
|
| 696 |
+
|
| 697 |
+
|
| 698 |
+
class SnakeBeta(nn.Module):
    """
    A modified Snake activation with separate parameters for the frequency and
    the magnitude of the periodic component.

    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
    Both parameters are stored in log space; ``forward`` exponentiates them.
    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://huggingface.co/papers/2006.08195
    """

    def __init__(self, in_features, alpha=1.0):
        super().__init__()
        self.in_features = in_features

        # Both parameters start from a zeros tensor multiplied by ``alpha``, so
        # for any finite ``alpha`` the initial values are all zero.
        self.alpha = Parameter(torch.zeros(in_features) * alpha)
        self.beta = Parameter(torch.zeros(in_features) * alpha)

        # Small constant added to the denominator in forward().
        self.no_div_by_zero = 0.000000001

    def forward(self, hidden_states):
        """
        Apply the activation elementwise.
        SnakeBeta := x + 1/b * sin^2(x * a), with a = exp(alpha) and b = exp(beta).
        """
        # Reshape the per-channel parameters to (1, C, 1) so they line up with [B, C, T].
        frequency = torch.exp(self.alpha[None, :, None])
        magnitude = torch.exp(self.beta[None, :, None])

        periodic = torch.pow(torch.sin(hidden_states * frequency), 2)
        return hidden_states + (1.0 / (magnitude + self.no_div_by_zero)) * periodic
|
| 737 |
+
|
| 738 |
+
|
| 739 |
+
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):
    """Build a 1D Kaiser-windowed sinc low-pass filter.

    Args:
        cutoff (float): Normalized cutoff frequency (0 to 0.5).
        half_width (float): Transition bandwidth.
        kernel_size (int): Number of filter taps.

    Returns:
        torch.Tensor: A float32 tensor of shape (1, 1, kernel_size). It is all
        zeros when ``cutoff == 0``; otherwise the taps are scaled to sum to 1.
    """
    half_size = kernel_size // 2

    # Derive the Kaiser window shape parameter from the estimated attenuation.
    delta_f = 4 * half_width
    attenuation = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95

    if attenuation > 50.0:
        beta = 0.1102 * (attenuation - 8.7)
    elif attenuation >= 21.0:
        beta = 0.5842 * (attenuation - 21) ** 0.4 + 0.07886 * (attenuation - 21.0)
    else:
        beta = 0.0

    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False, dtype=torch.float32)

    # A zero cutoff yields an all-zero filter of the correct shape.
    if cutoff == 0:
        return torch.zeros((1, 1, kernel_size), dtype=torch.float32)

    # Tap positions centered on zero; even kernels sit on half-integer offsets.
    if kernel_size % 2 == 0:
        time_indices = torch.arange(-half_size, half_size) + 0.5
    else:
        time_indices = torch.arange(kernel_size) - half_size

    taps = 2 * cutoff * window * torch.sinc(2 * cutoff * time_indices)

    # Normalize so the taps sum to 1 (avoids leakage of the constant component).
    taps = taps / taps.sum()

    return taps.view(1, 1, kernel_size)
|
| 783 |
+
|
| 784 |
+
|
| 785 |
+
class UpSample1d(nn.Module):
    """Upsample along the last axis with a Kaiser-windowed sinc filter.

    Args:
        ratio: Upsampling factor (also used as the transposed-conv stride).
        kernel_size: Number of filter taps; defaults to ``int(6 * ratio // 2) * 2``.
    """

    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        self.ratio = ratio
        if kernel_size is None:
            self.kernel_size = int(6 * ratio // 2) * 2
        else:
            self.kernel_size = kernel_size
        self.stride = ratio

        # Replicate-padding width, and how many samples to trim from each side afterwards.
        self.pad = self.kernel_size // ratio - 1
        overhang = self.kernel_size - self.stride
        self.pad_left = self.pad * self.stride + overhang // 2
        self.pad_right = self.pad * self.stride + (overhang + 1) // 2

        lowpass = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size)
        # Non-persistent: the filter is rebuilt from the arguments, not loaded from a state dict.
        self.register_buffer("filter", lowpass, persistent=False)

    def forward(self, hidden_states):
        """Return ``hidden_states`` upsampled along its last axis.

        Dim 1 is treated as the channel axis; every channel is filtered
        independently with the same kernel.
        """
        num_channels = hidden_states.shape[1]

        padded = F.pad(hidden_states, (self.pad, self.pad), mode="replicate")
        # Depthwise transposed convolution; the result is scaled by ``ratio``.
        upsampled = self.ratio * F.conv_transpose1d(
            padded, self.filter.expand(num_channels, -1, -1), stride=self.stride, groups=num_channels
        )

        # Trim the edge samples introduced by padding and the transposed convolution.
        return upsampled[..., self.pad_left : -self.pad_right]
|
| 808 |
+
|
| 809 |
+
|
| 810 |
+
class DownSample1d(nn.Module):
    """Downsample along the last axis with a Kaiser-windowed sinc low-pass filter.

    Args:
        ratio: Downsampling factor (used as the convolution stride).
        kernel_size: Number of filter taps. When ``None`` it defaults to
            ``int(6 * ratio // 2) * 2``, the same rule ``UpSample1d`` applies.

    Raises:
        ValueError: If the derived cutoff ``0.5 / ratio`` is negative or above 0.5.
    """

    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        cutoff = 0.5 / ratio
        half_width = 0.6 / ratio

        if cutoff < 0.0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")

        # The signature advertises ``kernel_size=None``; resolve it before the
        # arithmetic below, which would otherwise raise TypeError on ``None``.
        if kernel_size is None:
            kernel_size = int(6 * ratio // 2) * 2

        # Even kernels pad one sample less on the left.
        self.even = kernel_size % 2 == 0
        self.pad_left = kernel_size // 2 - int(self.even)
        self.pad_right = kernel_size // 2
        self.stride = ratio

        lowpass = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        # Non-persistent: the filter is rebuilt from the arguments, not loaded from a state dict.
        self.register_buffer("filter", lowpass, persistent=False)

    def forward(self, hidden_states):
        """Return ``hidden_states`` low-pass filtered and strided along its last axis.

        Dim 1 is treated as the channel axis; every channel is filtered
        independently with the same kernel.
        """
        channels = hidden_states.shape[1]
        padded = F.pad(hidden_states, (self.pad_left, self.pad_right), mode="replicate")
        return F.conv1d(padded, self.filter.expand(channels, -1, -1), stride=self.stride, groups=channels)
|
| 833 |
+
|
| 834 |
+
|
| 835 |
+
class TorchActivation1d(nn.Module):
    """Apply an activation between an upsampling and a downsampling stage.

    Args:
        activation: Callable applied to the upsampled signal.
        up_ratio: Upsampling factor.
        down_ratio: Downsampling factor.
        up_kernel_size: Filter length of the upsampler.
        down_kernel_size: Filter length of the downsampler.

    Raises:
        TypeError: If ``activation`` is not callable.
    """

    def __init__(
        self,
        activation,
        up_ratio: int = 2,
        down_ratio: int = 2,
        up_kernel_size: int = 12,
        down_kernel_size: int = 12,
    ):
        super().__init__()
        if not callable(activation):
            raise TypeError("Activation function must be callable")
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    def forward(self, hidden_states):
        """Return ``downsample(act(upsample(hidden_states)))``."""
        upsampled = self.upsample(hidden_states)
        activated = self.act(upsampled)
        return self.downsample(activated)
|
| 857 |
+
|
| 858 |
+
|
| 859 |
+
class CausalConv1d(nn.Conv1d):
    """``nn.Conv1d`` that left-pads its input so each output sees no future samples.

    The pad width is ``dilation * (kernel_size - 1)``, applied only on the
    left of the last axis.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        dilation, kernel = self.dilation[0], self.kernel_size[0]
        self.causal_padding = dilation * (kernel - 1)

    def forward(self, x):
        """Left-pad ``x`` by ``self.causal_padding`` and run the convolution."""
        padded = F.pad(x, [self.causal_padding, 0])
        return self._conv_forward(padded, self.weight, self.bias)
|
| 866 |
+
|
| 867 |
+
|
| 868 |
+
class AMPBlock(torch.nn.Module):
    """Residual block of three (activation, conv, activation, conv) stages.

    Args:
        channels: Channel count of every convolution in the block.
        kernel_size: Kernel size of every convolution.
        dilation: Dilations of the three first-stage causal convolutions
            (indices 0, 1 and 2 are read).
        causal_type: ``'1'`` uses symmetric-padded ``nn.Conv1d`` for the second
            convolutions; any other value uses ``CausalConv1d``. ``'2'``
            additionally enables a pre-convolution and pre-activation.
    """

    def __init__(
        self,
        channels,
        kernel_size=3,
        dilation=(1, 3, 5),
        causal_type='1',
    ):
        super().__init__()

        # First convolutions: causal, one per dilation value.
        self.convs1 = nn.ModuleList(
            [
                CausalConv1d(channels, channels, kernel_size, 1, dilation=dilation[stage])
                for stage in range(3)
            ]
        )

        # Second convolutions: dilation 1, either symmetric-padded or causal.
        if causal_type == '1':
            second_convs = [
                nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=1,
                    padding=self._get_padding(kernel_size, 1),
                )
                for _ in range(3)
            ]
        else:
            second_convs = [
                CausalConv1d(channels, channels, kernel_size, 1, dilation=1)
                for _ in range(3)
            ]
        self.convs2 = nn.ModuleList(second_convs)

        self.num_layers = len(self.convs1) + len(self.convs2)  # total number of conv layers

        # One activation per convolution; even indices pair with convs1, odd with convs2.
        self.activations = nn.ModuleList(
            [TorchActivation1d(activation=SnakeBeta(channels)) for _ in range(self.num_layers)]
        )

        if causal_type == '2':
            self.pre_conv = nn.Conv1d(
                channels,
                channels,
                kernel_size,
                stride=1,
                padding=self._get_padding(kernel_size, 1),
            )
            self.pre_act = TorchActivation1d(activation=SnakeBeta(channels))
        else:
            self.pre_conv = nn.Identity()
            self.pre_act = nn.Identity()

    def _get_padding(self, kernel_size, dilation=1):
        """Return the symmetric padding for the given kernel size and dilation."""
        return int((kernel_size * dilation - dilation) / 2)

    def forward(self, x):
        """Run the block and return ``x`` plus the output of every stage.

        The intermediate signal is carried from one stage to the next (it is
        not reset to ``x``), while each stage's output is added onto ``x``.
        """
        carried = self.pre_act(self.pre_conv(x))
        for stage, (first_conv, second_conv) in enumerate(zip(self.convs1, self.convs2)):
            first_act = self.activations[2 * stage]
            second_act = self.activations[2 * stage + 1]
            carried = first_conv(first_act(carried))
            carried = second_conv(second_act(carried))
            x = x + carried
        return x
|
| 993 |
+
|
| 994 |
+
|
| 995 |
+
# Vocoder: converts a mel spectrogram into a waveform through a stack of
# transposed-convolution upsamplers, each followed by averaged AMP residual blocks.
@auto_docstring
class Qwen3TTSTokenizerV1DecoderBigVGANModel(Qwen3TTSTokenizerV1DecoderPreTrainedModel):
    config: Qwen3TTSTokenizerV1DecoderBigVGANConfig

    def __init__(self, config: Qwen3TTSTokenizerV1DecoderBigVGANConfig):
        super().__init__(config)
        self.num_residual_blocks = len(config.resblock_kernel_sizes)
        self.num_upsample_layers = len(config.upsample_rates)

        self.conv_pre = nn.Conv1d(config.mel_dim, config.upsample_initial_channel, 5, 1, padding=2)

        # Each upsampler halves the channel count. The single-element ModuleList
        # wrapper is kept on purpose: removing it breaks the official state dict.
        upsamplers = []
        for layer_idx, (stride, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)):
            transposed_conv = nn.ConvTranspose1d(
                config.upsample_initial_channel // (2**layer_idx),
                config.upsample_initial_channel // (2 ** (layer_idx + 1)),
                kernel_size,
                stride,
                padding=(kernel_size - stride) // 2,
            )
            upsamplers.append(nn.ModuleList([transposed_conv]))
        self.ups = nn.ModuleList(upsamplers)

        # Residual blocks are stored flat, grouped by upsample layer. Layers 0
        # and 1 use causal type '2', later layers use '1'.
        residual_blocks = []
        for layer_idx in range(self.num_upsample_layers):
            for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
                residual_blocks.append(
                    AMPBlock(
                        config.upsample_initial_channel // (2 ** (layer_idx + 1)),
                        kernel_size,
                        dilation,
                        '1' if layer_idx > 1 else '2',
                    )
                )
        self.resblocks = nn.ModuleList(residual_blocks)

        final_channels = config.upsample_initial_channel // (2**self.num_upsample_layers)
        self.activation_post = TorchActivation1d(activation=SnakeBeta(final_channels))
        self.conv_post = nn.Conv1d(final_channels, 1, 7, 1, padding=3, bias=False)

    def normalize_spectrogram(self, spectrogram, max_value, min_db):
        """Rescale a dB spectrogram linearly and clamp it to ``[-max_value, max_value]``."""
        rescaled = (2 * max_value) * ((spectrogram - min_db) / (-min_db)) - max_value
        return torch.clamp(rescaled, -max_value, max_value)

    def amplitude_to_db(self, amplitude, min_db_level):
        """Return ``20 * log10(amplitude)`` with amplitudes floored at the level matching ``min_db_level``."""
        floor = torch.exp(
            torch.tensor(min_db_level / 20.0 * np.log(10), device=amplitude.device, dtype=amplitude.dtype)
        )
        return 20 * torch.log10(torch.clamp(amplitude, min=floor))

    def process_mel_spectrogram(self, mel_spectrogram):
        """Exponentiate the input, convert to dB (floor -115, offset -20) and normalize to [-1, 1]."""
        amplitude = torch.exp(mel_spectrogram)
        decibels = self.amplitude_to_db(amplitude, -115) - 20
        return self.normalize_spectrogram(decibels, 1, -115)

    def forward(self, mel_spectrogram):
        """Synthesize a waveform from ``mel_spectrogram``.

        Returns:
            The output clamped to [-1, 1] with dim 1 squeezed.
        """
        hidden = self.conv_pre(self.process_mel_spectrogram(mel_spectrogram))

        for layer_index in range(self.num_upsample_layers):
            hidden = self.ups[layer_index][0](hidden)

            # Average the outputs of this layer's residual blocks.
            first_block = layer_index * self.num_residual_blocks
            accumulated = 0
            for block_index in range(self.num_residual_blocks):
                accumulated = accumulated + self.resblocks[first_block + block_index](hidden)
            hidden = accumulated / self.num_residual_blocks

        hidden = self.activation_post(hidden)
        waveform = self.conv_post(hidden)
        return torch.clamp(waveform, min=-1.0, max=1.0).squeeze(1)
|
| 1068 |
+
|
| 1069 |
+
|
| 1070 |
+
# DiT backbone: predicts the velocity used to integrate noise into a mel
# spectrogram, conditioned on codec ids, a reference mel and a conditioning vector.
@auto_docstring
class Qwen3TTSTokenizerV1DecoderDiTModel(Qwen3TTSTokenizerV1DecoderPreTrainedModel):
    config: Qwen3TTSTokenizerV1DecoderDiTConfig
    _no_split_modules = ["DiTDecoderLayer"]

    def __init__(self, config: Qwen3TTSTokenizerV1DecoderDiTConfig):
        super().__init__(config)
        self.mel_dim = config.mel_dim
        self.repeats = config.repeats
        self.time_embed = DiTTimestepEmbedding(config.hidden_size)

        self.text_embed = DiTCodecEmbedding(config.num_embeds, config.emb_dim, config.repeats)
        self.input_embed = DiTInputEmbedding(config)

        self.rotary_embed = Qwen3TTSTokenizerV1DecoderDiTRotaryEmbedding(config.head_dim)

        self.hidden_size = config.hidden_size
        self.layers = config.num_hidden_layers
        self.block_size = config.block_size
        self.num_attention_heads = config.num_attention_heads

        # Layers listed in the config may look one block ahead and/or one block back.
        self.transformer_blocks = nn.ModuleList()
        for i in range(config.num_hidden_layers):
            self.transformer_blocks.append(
                DiTDecoderLayer(
                    config,
                    look_ahead_block=1 if i in config.look_ahead_layers else 0,
                    look_backward_block=1 if i in config.look_backward_layers else 0,
                )
            )

        self.norm_out = AdaLayerNormZero_Final(config.hidden_size)  # final modulation
        self.proj_out = nn.Linear(config.hidden_size, config.mel_dim)

    def _create_block_diff(self, hidden_states):
        """Return block-index differences ``block(j) - block(i)`` for all position pairs.

        Positions are grouped into blocks of ``self.block_size``; the result is
        expanded to ``(batch, num_attention_heads, seq_len, seq_len)``.
        """
        batch, seq_len = hidden_states.shape[0], hidden_states.shape[1]
        block_indices = torch.arange(seq_len, device=hidden_states.device) // self.block_size  # [seq_length]

        block_i = block_indices.unsqueeze(1)  # [seq_length, 1]
        block_j = block_indices.unsqueeze(0)  # [1, seq_length]
        block_diff = block_j - block_i  # (n, n)

        return block_diff.expand(batch, self.num_attention_heads, seq_len, seq_len)

    def forward(
        self,
        hidden_states,
        condition_vector,
        speaker_embedding,
        quantized_code,
        time_step,
        drop_audio_conditioning=False,
        drop_code=False,
        apply_cfg=True,
    ):
        """Predict the model output for the current noised state.

        Args:
            hidden_states: Current noised state.
            condition_vector: Conditioning input forwarded to the input embedding.
            speaker_embedding: Speaker input forwarded to the input embedding.
            quantized_code: Codec ids to embed.
            time_step: Scalar or per-sample time step(s).
            drop_audio_conditioning: Forwarded to the input embedding as ``drop_audio_cond``.
            drop_code: Drop the code conditioning; only honored when ``apply_cfg`` is false.
            apply_cfg: When true, the input embedding stacks a conditional and an
                unconditional copy along the batch axis, so the output has twice
                the input batch size.

        Returns:
            Tensor projected to ``mel_dim`` on the last axis.
        """
        # The effective batch is doubled only when the CFG pair is stacked. A
        # scalar time step must match it, or the modulation terms will not
        # broadcast against the hidden states.
        batch_size = hidden_states.shape[0] * (2 if apply_cfg else 1)
        if time_step.ndim == 0:
            time_step = time_step.repeat(batch_size)

        # Compute embeddings
        time_embedding = self.time_embed(time_step)
        text_embedding = self.text_embed(quantized_code, drop_code=False if apply_cfg else drop_code)
        text_embedding_unconditioned = self.text_embed(quantized_code, drop_code=True) if apply_cfg else None

        hidden_states = self.input_embed(
            hidden_states,
            speaker_embedding,
            condition_vector,
            text_embedding,
            drop_audio_cond=drop_audio_conditioning,
            code_embed_uncond=text_embedding_unconditioned,
            apply_cfg=apply_cfg,
        )

        # Compute positional encodings
        position_embeddings = self.rotary_embed(hidden_states)
        blockwise_difference = self._create_block_diff(hidden_states)

        # Transformer blocks
        for transformer_block in self.transformer_blocks:
            hidden_states = transformer_block(
                hidden_states,
                time_embedding,
                position_embeddings=position_embeddings,
                block_diff=blockwise_difference,
            )

        hidden_states = self.norm_out(hidden_states, time_embedding)
        output = self.proj_out(hidden_states)

        return output

    def optimized_scale(self, positive_flat, negative_flat):
        """Return ``<pos, neg> / (||neg||^2 + 1e-8)`` computed row-wise along dim 1."""
        # Dot product between the conditional and unconditional rows.
        dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
        # Squared norm of the unconditional rows (epsilon avoids division by zero).
        squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
        # st_star = v_cond^T * v_uncond / ||v_uncond||^2
        st_star = dot_product / squared_norm
        return st_star

    @torch.no_grad()
    def sample(
        self,
        conditioning_vector,
        reference_mel_spectrogram,
        quantized_code,
        num_steps=10,
        guidance_scale=0.5,
        sway_coefficient=-1.0,
    ):
        """Generate a mel spectrogram by Euler-integrating the model output from noise.

        Args:
            conditioning_vector: Per-sample vector; repeated along a new time
                axis and passed to ``forward`` as ``speaker_embedding``.
            reference_mel_spectrogram: Passed to ``forward`` as ``condition_vector``.
            quantized_code: Codec ids; the output length is
                ``quantized_code.shape[1] * self.repeats`` frames.
            num_steps: Number of points on the time grid from 0 to 1.
            guidance_scale: Classifier-free guidance weight; values below 1e-5
                disable guidance and run a single conditional pass per step.
            sway_coefficient: Coefficient of the time-grid warp; ``None`` skips it.

        Returns:
            The final state permuted to ``(batch, mel_dim, frames)``.
        """
        # Noise is drawn at a fixed length of 30000 frames and then sliced, so
        # the initial state can never hold more than 30000 frames.
        noise_initialization = torch.randn([quantized_code.shape[0], 30000, self.mel_dim], dtype=reference_mel_spectrogram.dtype)
        maximum_duration = quantized_code.shape[1] * self.repeats
        initial_state = noise_initialization[:, :maximum_duration].to(quantized_code.device)
        conditioning_vector = conditioning_vector.unsqueeze(1).repeat(1, maximum_duration, 1)

        def ode_function(time_step, hidden_states):
            if guidance_scale < 1e-5:
                # No guidance: run a single conditional pass. ``apply_cfg`` must
                # be disabled here, otherwise forward() returns a doubled batch
                # that cannot serve as the velocity for ``hidden_states``.
                prediction = self(
                    hidden_states=hidden_states,
                    speaker_embedding=conditioning_vector,
                    condition_vector=reference_mel_spectrogram,
                    quantized_code=quantized_code,
                    time_step=time_step,
                    drop_audio_conditioning=False,
                    drop_code=False,
                    apply_cfg=False,
                )
                return prediction

            # Guided: one stacked pass yields the conditional and unconditional halves.
            model_output = self(
                hidden_states=hidden_states,
                quantized_code=quantized_code,
                speaker_embedding=conditioning_vector,
                condition_vector=reference_mel_spectrogram,
                time_step=time_step,
                apply_cfg=True,
            )
            guided_prediction, null_prediction = torch.chunk(model_output, 2, dim=0)

            return guided_prediction + (guided_prediction - null_prediction) * guidance_scale

        initial_time = 0
        time_grid = torch.linspace(
            initial_time, 1, num_steps, device=quantized_code.device, dtype=conditioning_vector.dtype
        )

        # Optional warp of the time grid.
        if sway_coefficient is not None:
            time_grid += sway_coefficient * (torch.cos(torch.pi / 2 * time_grid) - 1 + time_grid)

        # Explicit Euler integration over consecutive grid points.
        values = initial_state.clone()
        for t0, t1 in zip(time_grid[:-1], time_grid[1:]):
            dt = t1 - t0
            vt = ode_function(t0, values)
            values = values + vt * dt

        generated_mel_spectrogram = values.permute(0, 2, 1)
        return generated_mel_spectrogram
|
| 1227 |
+
|
| 1228 |
+
|
| 1229 |
+
# Full decoder: a DiT model samples a mel spectrogram from codec ids, and a
# BigVGAN model turns that spectrogram into a waveform.
@auto_docstring
class Qwen3TTSTokenizerV1Decoder(Qwen3TTSTokenizerV1DecoderPreTrainedModel):
    config: Qwen3TTSTokenizerV1DecoderConfig
    base_model_prefix = "model"
    _no_split_modules = ["Qwen3TTSTokenizerV1DecoderDiTModel", "Qwen3TTSTokenizerV1DecoderBigVGANModel"]

    def __init__(self, config: Qwen3TTSTokenizerV1DecoderConfig):
        super().__init__(config)

        # Attention implementations that are replaced by "sdpa", each with the
        # warning logged once when the replacement happens.
        fallback_warnings = {
            "flash_attention_2": (
                "Qwen3TTSTokenizerV1Decoder must inference with fp32, but flash_attention_2 only supports fp16 and bf16, "
                "attention implementation of Qwen3TTSTokenizerV1Decoder will fallback to sdpa."
            ),
            "eager": (
                "Qwen3TTSTokenizerV1Decoder does not support eager attention implementation, fall back to sdpa"
            ),
        }

        attn_impl = config._attn_implementation
        warning = fallback_warnings.get(attn_impl)
        if warning is not None:
            logger.warning_once(warning)
            attn_impl = "sdpa"

        self.dit = Qwen3TTSTokenizerV1DecoderDiTModel._from_config(
            config.dit_config, attn_implementation=attn_impl
        )
        self.bigvgan = Qwen3TTSTokenizerV1DecoderBigVGANModel._from_config(
            config.bigvgan_config, attn_implementation=attn_impl
        )

    def forward(
        self,
        code,
        conditioning,
        reference_mel,
        num_steps=10,
        guidance_scale=0.5,
        sway_coefficient=-1.0,
        **kwargs,
    ):
        """Generate a waveform from input code and conditioning parameters.

        ``num_steps``, ``guidance_scale`` and ``sway_coefficient`` are forwarded
        to the DiT sampler; extra keyword arguments are accepted and ignored.
        """
        # Stage 1: sample a mel spectrogram; stage 2: vocode it.
        sampled_mel = self.dit.sample(
            conditioning,
            reference_mel,
            code,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            sway_coefficient=sway_coefficient,
        )
        return self.bigvgan(sampled_mel)
|
| 1280 |
+
|
| 1281 |
+
|
| 1282 |
+
class Qwen3TTSTokenizerV1Encoder(Qwen3TTSTokenizerV1EncoderPreTrainedModel):
    """Speech encoder that wraps a `WhisperEncoderVQ` to turn waveforms into discrete codes."""

    config: Qwen3TTSTokenizerV1EncoderConfig

    def __init__(self, config: Qwen3TTSTokenizerV1EncoderConfig):
        super().__init__(config)

        # Each of these config fields is forwarded verbatim as the keyword
        # argument of the same name.
        forwarded_fields = (
            "n_mels",
            "n_ctx",
            "n_state",
            "n_head",
            "n_layer",
            "n_window",
            "output_dim",
            "grad_checkpointing",
            "enable_mp",
            "audio_sequence_parallel",
            "audio_vq_type",
            "audio_vq_layers",
            "audio_vq_codebook_size",
            "audio_vq_codebook_dim",
            "audio_vq_pe",
            "audio_vq_ds_rate",
        )
        self.tokenizer = WhisperEncoderVQ(**{name: getattr(config, name) for name in forwarded_fields})

        self.padding = True
        self.audio_vq_ds_rate = self.tokenizer.audio_vq_ds_rate

    def speech2mel(self, speechs):
        """Convert each waveform in `speechs` to a mel tensor.

        Every mel is cast to the dtype of its source waveform and moved to the
        device that holds `self.tokenizer.conv1.weight`.

        Returns:
            A list with one mel tensor per input waveform.
        """
        mels = []
        for speech in speechs:
            mel = get_mel_audio(speech, padding=self.padding, audio_vq_ds_rate=self.audio_vq_ds_rate)
            mel = mel.to(speech.dtype)
            mels.append(mel.to(self.tokenizer.conv1.weight.device))
        return mels

    def mel2code(self, mels):
        """Quantize a list of mel tensors into a padded batch of code indices.

        Returns:
            A tuple `(indices, indice_lens)`: `indices` is the per-utterance code
            sequences padded with 0 (batch first), and `indice_lens` is the list of
            unpadded lengths.
        """
        mel_lens = [mel.size(-1) for mel in mels]
        cnn_lens = [get_T_after_cnn(mel_len) for mel_len in mel_lens]
        # NOTE(review): the "+ 2" presumably accounts for two extra positions the
        # tokenizer adds per utterance -- confirm against `WhisperEncoderVQ`.
        seq_lens = [cnn_len + 2 for cnn_len in cnn_lens]

        with torch.no_grad():
            _, flat_indices = self.tokenizer(
                x_list=mels,
                audio_mellens=mel_lens,
                audio_aftercnnlens=cnn_lens,
                audio_seqlens=seq_lens,
                return_indices=True,
            )

        indice_lens = [cnn_len // self.tokenizer.audio_vq_ds_rate for cnn_len in cnn_lens]
        # The tokenizer output is split back into one chunk per utterance before padding.
        per_utterance = torch.split(flat_indices, indice_lens)
        indices = pad_sequence(per_utterance, batch_first=True, padding_value=0)

        return indices, indice_lens

    def quantize_speech(self, speechs):
        """Run `speech2mel` followed by `mel2code` and return `(indices, indice_lens)`."""
        return self.mel2code(self.speech2mel(speechs))
|
| 1341 |
+
|
| 1342 |
+
|
| 1343 |
+
@auto_docstring
class Qwen3TTSTokenizerV1PreTrainedModel(PreTrainedModel):
    # Base class for the tokenizer model: it only declares the config type and
    # the class-level flags below; it defines no methods of its own.
    config: Qwen3TTSTokenizerV1Config
    base_model_prefix = "model"
    _skip_keys_device_placement = "past_key_values"

    # Capability flags.
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_attention_backend = True
    _can_compile_fullgraph = False
|
| 1353 |
+
|
| 1354 |
+
|
| 1355 |
+
@auto_docstring(
    custom_intro="""
    The Qwen3TTSTokenizerV1 model.
    """
)
class Qwen3TTSTokenizerV1Model(Qwen3TTSTokenizerV1PreTrainedModel):
    def __init__(self, config: Qwen3TTSTokenizerV1Config):
        super().__init__(config)
        self.config = config

        self.input_sample_rate = config.input_sample_rate
        self.output_sample_rate = config.output_sample_rate

        self.decode_upsample_rate = config.decode_upsample_rate
        self.encode_downsample_rate = config.encode_downsample_rate

        self.encoder = Qwen3TTSTokenizerV1Encoder._from_config(self.config.encoder_config)
        self.decoder = Qwen3TTSTokenizerV1Decoder._from_config(self.config.decoder_config)

        # Stays None until `load_encoder_xvector_extractor` is called
        # (`from_pretrained` does this); `encode` requires it to be set.
        self.encoder_xvector_extractor = None

        self.post_init()

    def load_encoder_xvector_extractor(self, model_path):
        """Build the `XVectorExtractor` from `model_path` and attach it to the model."""
        self.encoder_xvector_extractor = XVectorExtractor(model_path)

    def get_model_type(self):
        """Return `config.model_type`."""
        return self.config.model_type

    def get_input_sample_rate(self):
        """Return the input sample rate read from the config."""
        return self.input_sample_rate

    def get_output_sample_rate(self):
        """Return the output sample rate read from the config."""
        return self.output_sample_rate

    def get_encode_downsample_rate(self):
        """Return the encode downsample rate read from the config."""
        return self.encode_downsample_rate

    def get_decode_upsample_rate(self):
        """Return the decode upsample rate read from the config."""
        return self.decode_upsample_rate

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        *model_args,
        config=None,
        cache_dir=None,
        ignore_mismatched_sizes=False,
        force_download=False,
        local_files_only=False,
        token=None,
        revision="main",
        use_safetensors=None,
        weights_only=True,
        **kwargs,
    ):
        """Load the model weights, then resolve and load the `campplus.onnx` x-vector extractor.

        All arguments are forwarded unchanged to `PreTrainedModel.from_pretrained`.
        The same hub settings (`cache_dir`, `force_download`, `local_files_only`,
        `token`, `revision`, plus `subfolder` / `proxies` / `resume_download` from
        `kwargs`) are then used to locate `campplus.onnx` next to the weights.

        Raises:
            ValueError: If `campplus.onnx` could not be resolved.
        """
        model = super().from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            use_safetensors=use_safetensors,
            weights_only=weights_only,
            **kwargs,
        )

        # `cache_dir`, `force_download`, `local_files_only`, `token` and `revision`
        # are explicit parameters of this method, so they can never be found in
        # `kwargs`; pass the parameters themselves so the ONNX file is fetched with
        # the same settings as the weights. The legacy `use_auth_token` kwarg is
        # still honoured when no `token` was given.
        xvector_token = token if token is not None else kwargs.get("use_auth_token")
        encoder_xvector_extractor_path = cached_file(
            pretrained_model_name_or_path,
            "campplus.onnx",
            subfolder=kwargs.get("subfolder"),
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=kwargs.get("proxies"),
            resume_download=kwargs.get("resume_download"),
            local_files_only=local_files_only,
            token=xvector_token,
            revision=revision,
        )
        if encoder_xvector_extractor_path is None:
            # Name the missing file explicitly: the resolved path is None here.
            raise ValueError(f"""{pretrained_model_name_or_path}/campplus.onnx not exists""")
        model.load_encoder_xvector_extractor(encoder_xvector_extractor_path)

        return model

    def encode(
        self,
        input_values: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor, Optional[torch.Tensor]], Qwen3TTSTokenizerV1EncoderOutput]:
        """
        Encodes the input audio waveform into discrete codes.

        Args:
            input_values (`torch.Tensor` of shape `(batch_size, sequence_length)`):
                Float values of the input audio waveform.
            padding_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Indicates which inputs are to be ignored due to padding, where elements are either 1 for *not masked* or 0
                for *masked*. Each waveform is truncated to its first `mask.sum()` samples. When omitted, every
                waveform is used in full.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Returns:
            `(codes, xvectors, ref_mels)`, each a list with one tensor per input waveform, either as a plain tuple
            or wrapped in a `Qwen3TTSTokenizerV1EncoderOutput`.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if padding_mask is None:
            # No mask supplied: treat every sample of every waveform as valid.
            wavs = list(input_values)
        else:
            wavs = [value[:mask.sum()] for value, mask in zip(input_values, padding_mask)]

        codes, codes_lens = self.encoder.quantize_speech(wavs)
        # Strip the batch padding so each entry holds only its own codes.
        codes = [code[:code_len] for code, code_len in zip(codes, codes_lens)]

        xvectors = []
        ref_mels = []
        for wav in wavs:
            # The extractor consumes a NumPy array; its outputs are converted back
            # to tensors matching the waveform's dtype and device.
            xvector, ref_mel = self.encoder_xvector_extractor.extract_code(wav.cpu().numpy())
            xvector = torch.tensor(xvector).to(wav.dtype).to(wav.device)
            ref_mel = torch.tensor(ref_mel).to(wav.dtype).to(wav.device)
            xvectors.append(xvector)
            ref_mels.append(ref_mel)

        if not return_dict:
            return (
                codes,
                xvectors,
                ref_mels
            )

        return Qwen3TTSTokenizerV1EncoderOutput(codes, xvectors, ref_mels)

    def decode(
        self,
        audio_codes: torch.Tensor,
        xvectors: torch.Tensor,
        ref_mels: torch.Tensor,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor, torch.Tensor], Qwen3TTSTokenizerV1DecoderOutput]:
        """
        Decodes the given frames into an output audio waveform.

        Note that the output might be a bit bigger than the input. In that case, any extra steps at the end can be
        trimmed.

        Args:
            audio_codes (`torch.LongTensor` of shape `(batch_size, codes_length)`, *optional*):
                Discrete code embeddings computed using `model.encode`.
            xvectors (`torch.FloatTensor` of shape `(batch_size, xvector_dim)`, *optional*):
                X-vector embeddings computed using `model.encode`.
            ref_mels (`torch.FloatTensor` of shape `(batch_size, mel_length, mel_dim)`, *optional*):
                Reference mel spectrogram computed using `model.encode`.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Returns:
            A list of per-item waveforms trimmed to their computed length, either as a one-element tuple or wrapped
            in a `Qwen3TTSTokenizerV1DecoderOutput`.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        audio_values = self.decoder(code=audio_codes,
                                    reference_mel=ref_mels,
                                    conditioning=xvectors)

        # Only codes strictly greater than 0 count towards the output length, so
        # code value 0 is effectively treated as padding here.
        audio_lengths = (audio_codes > 0).sum(1) * self.decode_upsample_rate
        audio_values = [audio[:length] for audio, length in zip(audio_values, audio_lengths)]

        if not return_dict:
            return (
                audio_values,
            )

        return Qwen3TTSTokenizerV1DecoderOutput(audio_values)
|
| 1526 |
+
|
| 1527 |
+
|
| 1528 |
+
__all__ = ["Qwen3TTSTokenizerV1Model", "Qwen3TTSTokenizerV1PreTrainedModel"]
|
qwen_tts/core/tokenizer_25hz/vq/assets/mel_filters.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7450ae70723a5ef9d341e3cee628c7cb0177f36ce42c44b7ed2bf3325f0f6d4c
|
| 3 |
+
size 4271
|
qwen_tts/core/tokenizer_25hz/vq/core_vq.py
ADDED
|
@@ -0,0 +1,523 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
# This implementation is inspired from
|
| 8 |
+
# https://github.com/lucidrains/vector-quantize-pytorch
|
| 9 |
+
# which is released under MIT License. Hereafter, the original license:
|
| 10 |
+
# MIT License
|
| 11 |
+
#
|
| 12 |
+
# Copyright (c) 2020 Phil Wang
|
| 13 |
+
#
|
| 14 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 15 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 16 |
+
# in the Software without restriction, including without limitation the rights
|
| 17 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 18 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 19 |
+
# furnished to do so, subject to the following conditions:
|
| 20 |
+
#
|
| 21 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 22 |
+
# copies or substantial portions of the Software.
|
| 23 |
+
#
|
| 24 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 25 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 26 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 27 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 28 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 29 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 30 |
+
# SOFTWARE.
|
| 31 |
+
|
| 32 |
+
"""Core vector quantization implementation."""
|
| 33 |
+
import random
|
| 34 |
+
import typing as tp
|
| 35 |
+
from random import randrange
|
| 36 |
+
|
| 37 |
+
import numpy as np
|
| 38 |
+
from einops import rearrange, repeat
|
| 39 |
+
from math import ceil
|
| 40 |
+
import torch
|
| 41 |
+
from torch import nn
|
| 42 |
+
import torch.nn.functional as F
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def round_up_multiple(num, mult):
    """Round `num` up to a multiple of `mult` (for positive `mult`)."""
    n_chunks = ceil(num / mult)
    return n_chunks * mult
|
| 47 |
+
|
| 48 |
+
def default(val: tp.Any, d: tp.Any) -> tp.Any:
    """Return `val`, or the fallback `d` when `val` is None."""
    if val is None:
        return d
    return val
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def ema_inplace(moving_avg, new, decay: float):
    """Update `moving_avg` in place to `decay * moving_avg + (1 - decay) * new`."""
    running = moving_avg.data
    running.mul_(decay)
    running.add_(new, alpha=(1 - decay))
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    """Return `(x + epsilon) / (x.sum() + n_categories * epsilon)`."""
    smoothed_counts = x + epsilon
    smoothed_total = x.sum() + n_categories * epsilon
    return smoothed_counts / smoothed_total
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def uniform_init(*shape: int):
    """Return a new tensor of the given shape filled by `nn.init.kaiming_uniform_`."""
    # `kaiming_uniform_` fills the tensor in place and also returns it.
    return nn.init.kaiming_uniform_(torch.empty(shape))
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def sample_vectors(samples, num: int):
    """Pick `num` rows of `samples` at random.

    Rows are drawn without replacement (a truncated random permutation) when
    `samples` has at least `num` rows, and with replacement otherwise.
    """
    available = samples.shape[0]
    target_device = samples.device

    if available < num:
        picks = torch.randint(0, available, (num,), device=target_device)
    else:
        picks = torch.randperm(available, device=target_device)[:num]

    return samples[picks]
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@torch.no_grad()
def kmeans(samples, num_clusters: int, num_iters: int = 10):
    """Run `num_iters` rounds of k-means on the rows of `samples`.

    Centroids start as rows drawn by `sample_vectors`. A centroid that receives
    no samples in a round keeps its previous value.

    Returns:
        A tuple `(means, bins)`: the centroids and the per-cluster sample counts
        from the last round.
    """
    feat_dim = samples.shape[-1]
    out_dtype = samples.dtype

    means = sample_vectors(samples, num_clusters)

    for _ in range(num_iters):
        # Negated squared Euclidean distance, expanded as -(|s|^2 - 2 s.m + |m|^2),
        # so the nearest centroid is the arg-max.
        neg_sq_dists = -(
            samples.pow(2).sum(1, keepdim=True)
            - 2 * torch.matmul(samples, means.t())
            + means.t().pow(2).sum(0, keepdim=True)
        )

        buckets = neg_sq_dists.max(dim=-1).indices
        del neg_sq_dists
        bins = torch.bincount(buckets, minlength=num_clusters)
        empty_clusters = bins == 0
        # Empty clusters are divided by 1 instead of 0; their result is discarded below.
        safe_counts = bins.masked_fill(empty_clusters, 1)

        cluster_sums = buckets.new_zeros(num_clusters, feat_dim, dtype=out_dtype)
        cluster_sums.scatter_add_(0, repeat(buckets, "n -> n d", d=feat_dim), samples)
        updated_means = cluster_sums / safe_counts[..., None]

        means = torch.where(empty_clusters[..., None], means, updated_means)

    return means, bins
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def preprocess(x):
    """Flatten all leading axes of `x` into one, keeping the last axis as is."""
    flattened = rearrange(x, "... d -> (...) d")
    return flattened
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def postprocess_emb(embed_ind, shape):
    """View `embed_ind` with every axis of `shape` except the last one."""
    leading_axes = shape[:-1]
    return embed_ind.view(*leading_axes)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.

    The codebook state (`inited`, `cluster_size`, `embed`, `embed_avg`) is not
    created by this module: `encode`, `decode` and `forward` receive it as
    `buffers` on every call and bind it to the instance before doing any work.

    Args:
        dim (int): Dimension. Not used by this class.
        codebook_size (int): Codebook size.
        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
            If set to true, run the k-means algorithm on the first training batch and use
            the learned centroids as initialization. Not read by this class; whether
            initialization runs is decided by the `inited` buffer supplied by the caller.
        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: float = 2.0,
    ):
        super().__init__()
        self.decay = decay
        self.codebook_size = codebook_size
        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code

        # Placeholders; replaced by the caller-supplied buffers on every
        # `encode` / `decode` / `forward` call.
        self.inited = None
        self.cluster_size = None
        self.embed = None
        self.embed_avg = None
        self.training = True

    def init_embed_(self, data):
        """Initialize the codebook with k-means over `data`, unless already initialized."""
        if self.inited:
            return

        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        # `copy_` already copies the values, so no intermediate clone is needed.
        self.embed_avg.data.copy_(embed)
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(torch.Tensor([True]))
        # Make sure all buffers across workers are in sync after initialization
        # distrib.broadcast_tensors([self.embed, self.embed_avg, self.cluster_size, self.inited])

    def replace_(self, samples, mask):
        """Overwrite the codebook rows selected by `mask` with rows sampled from `samples`."""
        modified_codebook = torch.where(
            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
        )
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        """Replace codes whose normalized cluster size is below the dead-code threshold.

        Does nothing when the threshold is 0 or when no code is below it.
        """
        if self.threshold_ema_dead_code == 0:
            return

        # Use tensor reductions: Python's builtin `sum` would iterate the tensor
        # element by element.
        cluster_size = self.cluster_size / self.cluster_size.sum() * self.codebook_size
        expired_codes = cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return

        print(f"VQ expire infos: num_expire={expired_codes.sum()}, cluster_size[:5]={cluster_size[:5]}")

        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace_(batch_samples, mask=expired_codes)
        # sync buffers outside for efficiency
        # distrib.broadcast_tensors(self.buffers())

    def quantize(self, x):
        """Return, for each row of `x`, the index of the nearest codebook entry."""
        embed = self.embed.t()
        # Negated squared Euclidean distance, so the nearest entry is the arg-max.
        dist = -(
            x.pow(2).sum(1, keepdim=True)
            - 2 * x @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def dequantize(self, embed_ind):
        """Look up the codebook vectors for the given indices."""
        quantize = F.embedding(embed_ind, self.embed)
        return quantize

    def encode(self, x, buffers):
        """Bind `buffers` and return the nearest-code indices for `x`.

        The result has the shape of `x` without its last axis.
        """
        self.inited, self.cluster_size, self.embed, self.embed_avg = buffers

        shape = x.shape
        # pre-process
        x = preprocess(x)
        # quantize
        embed_ind = self.quantize(x)
        # post-process
        embed_ind = postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind, buffers):
        """Bind `buffers` and return the codebook vectors for `embed_ind`."""
        self.inited, self.cluster_size, self.embed, self.embed_avg = buffers

        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x, buffers):
        """Quantize `x` and, in training mode, update the codebook statistics.

        Returns:
            A tuple `(quantize, embed_ind)`: the selected codebook vectors and
            their indices.
        """
        self.inited, self.cluster_size, self.embed, self.embed_avg = buffers

        shape, dtype = x.shape, x.dtype
        x = preprocess(x)

        self.init_embed_(x)
        if self.training:
            # We do the expiry of code at that point as buffers are in sync
            # and all the workers will take the same decision.
            self.expire_codes_(x)

        embed_ind = self.quantize(x)
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = postprocess_emb(embed_ind, shape)
        quantize = self.dequantize(embed_ind)

        if self.training:
            # EMA update of per-code usage counts and per-code vector sums, then
            # re-derive the codebook from the smoothed statistics.
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = x.t() @ embed_onehot
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
            cluster_size = (
                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
                * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)
            # Note: after ema update, there is a very small difference between codebooks on GPUs.
            # The impact can be very small, ignore it.

        return quantize, embed_ind
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
class VectorQuantization(nn.Module):
    """Vector quantization implementation.
    Currently, supports only euclidean distance.

    Inputs are optionally projected into the codebook dimension, quantized by an
    `EuclideanCodebook`, and projected back. The codebook state is passed in as
    `buffers` on every call.

    Args:
        dim (int): Dimension
        codebook_size (int): Codebook size
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
        commitment_weight (float): Weight for commitment loss.
    """
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: float = 2.0,
        commitment_weight: float = 1.,
    ):
        super().__init__()
        _codebook_dim: int = default(codebook_dim, dim)

        # Linear projections are only needed when the codebook lives in a
        # different dimension than the input.
        if _codebook_dim != dim:
            self.project_in = nn.Linear(dim, _codebook_dim)
            self.project_out = nn.Linear(_codebook_dim, dim)
        else:
            self.project_in = nn.Identity()
            self.project_out = nn.Identity()

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight

        self._codebook = EuclideanCodebook(
            dim=_codebook_dim,
            codebook_size=codebook_size,
            kmeans_init=kmeans_init,
            kmeans_iters=kmeans_iters,
            decay=decay,
            epsilon=epsilon,
            threshold_ema_dead_code=threshold_ema_dead_code,
        )
        self.codebook_size = codebook_size
        self.training = True

    @property
    def codebook(self):
        """The embedding tensor currently bound to the underlying codebook."""
        return self._codebook.embed

    def encode(self, x, buffers):
        """Project `x` into the codebook space and return the nearest-code indices."""
        # The input layout is used as-is; no "b d n -> b n d" transpose is applied.
        projected = self.project_in(x)
        return self._codebook.encode(projected, buffers)

    def decode(self, embed_ind, buffers):
        """Look up the codes for `embed_ind` and project them back to the input dimension."""
        # The output layout is returned as-is; no "b n d -> b d n" transpose is applied.
        looked_up = self._codebook.decode(embed_ind, buffers)
        return self.project_out(looked_up)

    def forward(self, x, buffers):
        """Quantize `x` and compute the commitment loss.

        Returns:
            A tuple `(quantize, embed_ind, loss)`. `loss` is a one-element tensor
            that stays at 0 outside training or when `commitment_weight` is not
            positive.
        """
        device = x.device
        x = self.project_in(x)

        quantize, embed_ind = self._codebook(x, buffers)

        loss = torch.tensor([0.0], device=device, requires_grad=self.training)

        if self.training:
            # Straight-through estimator: forward value is the quantized vector,
            # gradient flows to `x` only.
            quantize = x + (quantize - x).detach()
            if self.commitment_weight > 0:
                commit_loss = F.mse_loss(quantize.detach(), x)
                loss = loss + commit_loss * self.commitment_weight

        quantize = self.project_out(quantize)
        return quantize, embed_ind, loss
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
class DistributedResidualVectorQuantization(nn.Module):
|
| 335 |
+
"""Efficient distributed residual vector quantization implementation.
|
| 336 |
+
Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
|
| 337 |
+
"""
|
| 338 |
+
def __init__(self, *,
|
| 339 |
+
num_quantizers,
|
| 340 |
+
quantize_dropout: bool = False,
|
| 341 |
+
rand_num_quant: tp.Optional[tp.List] = None,
|
| 342 |
+
**kwargs):
|
| 343 |
+
super().__init__()
|
| 344 |
+
"""
|
| 345 |
+
dim: int,
|
| 346 |
+
codebook_size: int,
|
| 347 |
+
codebook_dim: tp.Optional[int] = None,
|
| 348 |
+
"""
|
| 349 |
+
codebook_size, codebook_dim = kwargs["codebook_size"], kwargs["codebook_dim"] if kwargs["codebook_dim"] else kwargs["dim"]
|
| 350 |
+
kmeans_init = kwargs["kmeans_init"]
|
| 351 |
+
if isinstance(kmeans_init, bool):
|
| 352 |
+
if not kwargs["kmeans_init"]:
|
| 353 |
+
# use uniform init
|
| 354 |
+
embed = uniform_init(num_quantizers, codebook_size, codebook_dim)
|
| 355 |
+
inited = True
|
| 356 |
+
else:
|
| 357 |
+
# to perform kmeans init on first batch
|
| 358 |
+
embed = torch.zeros(num_quantizers, codebook_size, codebook_dim)
|
| 359 |
+
inited = False
|
| 360 |
+
elif isinstance(kmeans_init, str):
|
| 361 |
+
# use prepared kmeans init
|
| 362 |
+
embed = np.load(kmeans_init)
|
| 363 |
+
embed = torch.from_numpy(embed)
|
| 364 |
+
if embed.dim() == 2:
|
| 365 |
+
embed = embed.unsqueeze(0)
|
| 366 |
+
inited = True
|
| 367 |
+
else:
|
| 368 |
+
raise TypeError("kmeans_init should be either a bool or string path to init weights.")
|
| 369 |
+
|
| 370 |
+
self.register_buffer("inited", torch.Tensor([[inited] for _ in range(num_quantizers)]))
|
| 371 |
+
self.register_buffer("cluster_size", torch.zeros(num_quantizers, codebook_size))
|
| 372 |
+
self.register_buffer("embed", embed)
|
| 373 |
+
self.register_buffer("embed_avg", embed.clone())
|
| 374 |
+
|
| 375 |
+
self.q0_ds_ratio = 1
|
| 376 |
+
if "q0_ds_ratio" in kwargs:
|
| 377 |
+
self.q0_ds_ratio = kwargs.pop("q0_ds_ratio")
|
| 378 |
+
|
| 379 |
+
self.layers = nn.ModuleList()
|
| 380 |
+
for i in range(num_quantizers):
|
| 381 |
+
vq_args = dict(**kwargs)
|
| 382 |
+
vq = VectorQuantization(**vq_args)
|
| 383 |
+
self.layers.append(vq)
|
| 384 |
+
|
| 385 |
+
self.quantize_dropout = quantize_dropout
|
| 386 |
+
self.rand_num_quant = rand_num_quant
|
| 387 |
+
|
| 388 |
+
def forward(self, x, n_q: tp.Optional[int] = None):
    """Run residual quantization over ``x`` with the first ``n_q`` layers.

    Each layer is called with the current residual and a list holding that
    layer's slices of the ``inited``, ``cluster_size``, ``embed`` and
    ``embed_avg`` buffers. The layer's quantized output is subtracted from
    the residual and added to the running sum.

    Args:
        x: 3-D input tensor (unpacked as three dimensions; the last one is
            the length that the first-layer resampling operates on).
        n_q: Number of layers to use. ``None`` (or 0) means all layers.

    Returns:
        Tuple ``(quantized_out, out_indices, out_losses)`` where the indices
        and losses are the per-layer results stacked along a new dim 0.
    """
    quantized_out = torch.zeros_like(x)
    residual = x
    _, _, tt = x.shape
    device = x.device

    all_losses = []
    all_indices = []
    n_q = n_q or len(self.layers)

    should_quantize_dropout = self.training and self.quantize_dropout and self.rand_num_quant is not None
    if should_quantize_dropout:
        rand_quantize_dropout_index = random.choice(self.rand_num_quant)

        # Placeholders emitted for every layer that is dropped.
        null_indices = torch.full((x.shape[0], x.shape[2]), -1, device=device, dtype=torch.long)
        null_loss = torch.full((1,), 0., device=device, dtype=x.dtype)

    for quantizer_index, layer in enumerate(self.layers[:n_q]):
        # Layers at or beyond the sampled index are skipped during dropout.
        if should_quantize_dropout and quantizer_index >= rand_quantize_dropout_index:
            all_indices.append(null_indices)
            all_losses.append(null_loss)
            continue

        resample_first = self.q0_ds_ratio > 1 and quantizer_index == 0

        quant_in = residual
        if resample_first:
            # NOTE(review): the target length is hard-coded to tt // 2 even
            # though the switch is q0_ds_ratio > 1 - confirm whether
            # tt // self.q0_ds_ratio was intended before changing it.
            quant_in = F.interpolate(quant_in, size=[tt // 2])
        quantized, indices, loss = layer(quant_in, [
            self.inited[quantizer_index],
            self.cluster_size[quantizer_index],
            self.embed[quantizer_index],
            self.embed_avg[quantizer_index]
        ])
        if resample_first:
            # Bring the first layer's output back to the input length.
            quantized = F.interpolate(quantized, size=[tt])
            indices = F.interpolate(indices.unsqueeze(1).float(), size=[tt]).squeeze(1).long()
        residual = residual - quantized
        quantized_out = quantized_out + quantized

        all_indices.append(indices)
        all_losses.append(loss)

    # The per-layer quantized tensors used to be collected and stacked here
    # as well, but the stacked result was never returned or read.
    out_losses = torch.stack(all_losses)
    out_indices = torch.stack(all_indices)

    return quantized_out, out_indices, out_losses
|
| 440 |
+
|
| 441 |
+
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
    """Return the stacked per-layer indices for ``x``.

    For each of the first ``n_q`` layers (all layers when ``n_q`` is ``None``
    or 0) the current residual is encoded, the indices are decoded again and
    the decoded tensor is subtracted from the residual before moving on.
    """
    def layer_state(k):
        # Fresh list per call, matching what each layer call receives.
        return [self.inited[k], self.cluster_size[k], self.embed[k], self.embed_avg[k]]

    depth = n_q or len(self.layers)
    remaining = x
    collected = []
    for k, quantizer in enumerate(self.layers[:depth]):
        codes = quantizer.encode(remaining, layer_state(k))
        reconstructed = quantizer.decode(codes, layer_state(k))
        remaining = remaining - reconstructed
        collected.append(codes)
    return torch.stack(collected)
|
| 462 |
+
|
| 463 |
+
def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
    """Sum the decoded output of every layer for its slice of ``q_indices``.

    ``q_indices`` is iterated along dim 0; entry ``k`` is decoded by layer
    ``k`` together with that layer's buffer slices.
    """
    # Scalar start value so the first addition adopts the layer output's shape.
    total = torch.tensor(0.0, device=q_indices.device)
    for k, codes in enumerate(q_indices):
        state = [self.inited[k], self.cluster_size[k], self.embed[k], self.embed_avg[k]]
        total = total + self.layers[k].decode(codes, state)
    return total
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
class DistributedGroupResidualVectorQuantization(nn.Module):
    """Grouped residual vector quantization.

    Follows Algorithm 1 in https://arxiv.org/abs/2305.02765: the input is
    first split into ``num_groups`` chunks along dim 1, then every chunk is
    handled by its own residual vector quantizer.
    """

    def __init__(self, *,
                 num_groups,
                 num_quantizers,
                 quantize_dropout: bool = False,
                 rand_num_quant: tp.Optional[tp.List] = None,
                 **kwargs):
        super().__init__()
        # One independent residual quantizer per group; the remaining keyword
        # arguments are forwarded unchanged to each of them.
        per_group = []
        for _ in range(num_groups):
            per_group.append(
                DistributedResidualVectorQuantization(
                    num_quantizers=num_quantizers,
                    quantize_dropout=quantize_dropout,
                    rand_num_quant=rand_num_quant,
                    **kwargs
                )
            )
        self.rvqs = nn.ModuleList(per_group)
        self.num_groups = num_groups

    def _split(self, tensor):
        """Chunk ``tensor`` into ``num_groups`` pieces along dim 1."""
        return torch.chunk(tensor, chunks=self.num_groups, dim=1)

    def forward(self, x, n_q: tp.Optional[int] = None):
        """Quantize every group and merge the results.

        Returns the group outputs concatenated along dim 1, the group indices
        stacked along a new dim 1, and the group losses averaged over groups.
        """
        results = [rvq(part, n_q) for rvq, part in zip(self.rvqs, self._split(x))]

        mean_losses = torch.stack([res[2] for res in results], dim=1).mean(dim=1)
        merged_quantized = torch.cat([res[0] for res in results], dim=1)
        merged_indices = torch.stack([res[1] for res in results], dim=1)
        return merged_quantized, merged_indices, mean_losses

    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        """Encode every group and stack the group indices along a new dim 1."""
        per_group_codes = []
        for rvq, part in zip(self.rvqs, self._split(x)):
            per_group_codes.append(rvq.encode(part, n_q))
        return torch.stack(per_group_codes, dim=1)

    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        """Decode the indices of every group and concatenate along dim 1."""
        decoded = []
        for rvq, codes in zip(self.rvqs, self._split(q_indices)):
            # Each chunk keeps a group axis at dim 1; it is squeezed away
            # before handing the codes to the group's quantizer.
            decoded.append(rvq.decode(codes.squeeze(1)))
        return torch.cat(decoded, dim=1)
|
qwen_tts/core/tokenizer_25hz/vq/speech_vq.py
ADDED
|
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2026 The Alibaba Qwen team.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
import sox
|
| 17 |
+
import copy
|
| 18 |
+
import torch
|
| 19 |
+
import operator
|
| 20 |
+
import onnxruntime
|
| 21 |
+
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
import torch.nn.functional as F
|
| 24 |
+
import torchaudio.compliance.kaldi as kaldi
|
| 25 |
+
|
| 26 |
+
from librosa.filters import mel as librosa_mel_fn
|
| 27 |
+
from itertools import accumulate
|
| 28 |
+
from typing import List
|
| 29 |
+
from torch import Tensor
|
| 30 |
+
|
| 31 |
+
from .core_vq import DistributedGroupResidualVectorQuantization
|
| 32 |
+
from .whisper_encoder import WhisperEncoder, Conv1d, ConvTranspose1d
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Return ``log(max(x, clip_val) * C)`` element-wise.

    Values below ``clip_val`` are raised to ``clip_val`` before scaling by
    ``C`` so the logarithm never sees a value below ``clip_val * C``.
    """
    floored = torch.clamp(x, min=clip_val)
    scaled = floored * C
    return torch.log(scaled)
|
| 37 |
+
|
| 38 |
+
def spectral_normalize_torch(magnitudes):
    """Apply log dynamic-range compression (default settings) to ``magnitudes``."""
    return dynamic_range_compression_torch(magnitudes)
|
| 41 |
+
|
| 42 |
+
class MelSpectrogramFeatures(nn.Module):
    """Compute a BigVGAN-style log-mel spectrogram of an input signal.

    Args:
        filter_length (int): FFT size. Default is 1024.
        hop_length (int): Samples between successive frames. Default is 160.
        win_length (int): Length of the Hann window. Default is 640.
        n_mel_channels (int): Number of mel bins. Default is 80.
        mel_fmin (int): Lowest mel filter frequency in Hz. Default is 0.
        mel_fmax (int): Highest mel filter frequency in Hz. Default is 8000.
        sampling_rate (int): Sampling rate used to build the mel filterbank.
            Default is 16000.
        sampling_rate_org (int, optional): Stored as ``sampling_rate_org``
            (falls back to ``sampling_rate``); not read by ``extract``.
        padding (str): Must be 'center' or 'same'. The value is validated and
            stored but ``extract`` always applies the same reflect padding.
        use_db (bool): Accepted for interface compatibility; currently unused.

    Returns:
        torch.Tensor: Log-compressed mel spectrogram.
    """

    def __init__(self,
                 filter_length=1024,
                 hop_length=160,
                 win_length=640,
                 n_mel_channels=80,
                 mel_fmin=0,
                 mel_fmax=8000,
                 sampling_rate=16000,
                 sampling_rate_org=None,
                 padding='center',
                 use_db = False,
                 ):
        super().__init__()
        if padding not in ["center", "same"]:
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding

        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_mel_channels = n_mel_channels
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.sampling_rate = sampling_rate
        self.sampling_rate_org = sampling_rate_org if sampling_rate_org is not None else sampling_rate
        # Per-device caches, keyed by "<mel_fmax>_<device>" and "<device>".
        self.mel_basis = {}
        self.hann_window = {}

    def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor:
        """Run ``extract`` with gradient tracking disabled."""
        with torch.no_grad():
            feats = self.extract(audio, **kwargs)
        return feats

    def extract(self, audio, **kwargs):
        """Return the log-mel spectrogram of ``audio``.

        ``audio`` may be 2-D, or 3-D with a singleton dim 1 (otherwise dim 2
        is squeezed); anything that is not 2-D afterwards fails the assertion.
        """
        if len(audio.shape) == 3:
            audio = audio.squeeze(1) if audio.shape[1] == 1 else audio.squeeze(2)
        assert len(audio.shape) == 2

        y = audio
        device_key = str(y.device)
        mel_key = str(self.mel_fmax) + '_' + device_key

        # Fill the caches per key. Checking only "is the cache empty" would
        # leave a second device without entries and fail with KeyError below.
        if mel_key not in self.mel_basis:
            mel = librosa_mel_fn(sr=self.sampling_rate, n_fft=self.filter_length, n_mels=self.n_mel_channels, fmin=self.mel_fmin, fmax=self.mel_fmax)
            self.mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)
        if device_key not in self.hann_window:
            self.hann_window[device_key] = torch.hann_window(self.win_length).to(y.device)

        # Manual reflect padding; the STFT below runs with center=False.
        pad = int((self.filter_length - self.hop_length) / 2)
        y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode='reflect')
        y = y.squeeze(1)

        spec = torch.stft(y, self.filter_length, hop_length=self.hop_length, win_length=self.win_length, window=self.hann_window[device_key],
                          center=False, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
        spec = torch.view_as_real(spec)
        # Magnitude with a small epsilon to keep the square root well-behaved.
        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

        spec = torch.matmul(self.mel_basis[mel_key], spec)
        spec = spectral_normalize_torch(spec)

        return spec
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class XVectorExtractor(nn.Module):
    """Extract a normalized embedding and a reference mel from a waveform.

    Wraps an ONNX Runtime session (CPU provider, one intra-op thread), a sox
    transformer configured with ``norm(db_level=-6)`` and a
    ``MelSpectrogramFeatures`` front end.
    """

    def __init__(self, audio_codec_with_xvector):
        super().__init__()
        session_options = onnxruntime.SessionOptions()
        session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        session_options.intra_op_num_threads = 1
        self.ort_session = onnxruntime.InferenceSession(
            audio_codec_with_xvector,
            sess_options=session_options,
            providers=["CPUExecutionProvider"],
        )

        self.tfm = sox.Transformer()
        self.tfm.norm(db_level=-6)

        self.mel_ext = MelSpectrogramFeatures(
            filter_length=1024,
            hop_length=160,
            win_length=640,
            n_mel_channels=80,
            mel_fmin=0,
            mel_fmax=8000,
            sampling_rate=16000
        )

    def extract_code(self, audio):
        """Return ``(embedding, ref_mel)`` as NumPy arrays for ``audio``.

        The audio is level-normalized with sox, turned into mean-subtracted
        Kaldi fbank features for the ONNX model, and separately passed to the
        mel extractor.
        """
        with torch.no_grad():
            normalized = self.sox_norm(audio)
            # Deep copy before wrapping the sox output in a tensor.
            wave = torch.from_numpy(copy.deepcopy(normalized)).unsqueeze(0)

            fbank = kaldi.fbank(wave,
                                num_mel_bins=80,
                                dither=0,
                                sample_frequency=16000)
            fbank = fbank - fbank.mean(dim=0, keepdim=True)

            input_name = self.ort_session.get_inputs()[0].name
            model_input = fbank.unsqueeze(dim=0).cpu().numpy()
            raw_embedding = self.ort_session.run(None, {input_name: model_input})[0].flatten()
            embedding = F.normalize(torch.from_numpy(raw_embedding), dim=0)

            ref_mel = self.mel_ext.extract(audio=wave)

        return embedding.numpy(), ref_mel.permute(0, 2, 1).squeeze(0).numpy()

    def sox_norm(self, audio):
        """Apply the configured sox transformer to ``audio`` (input rate 16000)."""
        return self.tfm.build_array(input_array=audio, sample_rate_in=16000)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class WhisperEncoderVQ(WhisperEncoder):
    """Whisper encoder with a vector-quantization step after a chosen block.

    After encoder block number ``audio_vq_layers`` (1-based) the hidden states
    are optionally downsampled, quantized with a grouped residual vector
    quantizer, optionally combined with positional embeddings, and upsampled
    again before the remaining blocks run.
    """

    def __init__(
        self,
        n_mels: int,
        n_ctx: int,
        n_state: int,
        n_head: int,
        n_layer: int,
        n_window: int = 1500,
        output_dim: int = 512,
        grad_checkpointing: bool = False,
        enable_mp: bool = False,
        audio_sequence_parallel: bool = False,
        audio_vq_layers: int = -1,
        audio_vq_type: str = "NULL",
        audio_vq_codebook_size: int = 4096,
        audio_vq_pe: bool = False,
        audio_vq_commit_loss: float = 0.0,
        audio_vq_out_commit_loss: float = 0.0,
        audio_vq_no_quantize: bool = False,
        audio_vq_ff_layer: int = 0,
        audio_vq_threshold_ema_dead_code: float = 0.1,
        audio_vq_codebook_dim: int = None,
        audio_vq_ds_rate: int = None,
    ):
        """Build the base encoder and the VQ sub-modules.

        Raises:
            NotImplementedError: if ``audio_vq_layers`` is not positive or
                ``audio_vq_type`` is not "GRVQ".
        """
        super().__init__(n_mels, n_ctx, n_state, n_head, n_layer, n_window, output_dim, grad_checkpointing, enable_mp, audio_sequence_parallel)

        self.audio_vq_layers = audio_vq_layers
        self.audio_vq_type = audio_vq_type
        self.audio_vq_codebook_size = audio_vq_codebook_size
        self.audio_vq_pe = audio_vq_pe
        self.audio_vq_commit_loss = audio_vq_commit_loss
        self.audio_vq_out_commit_loss = audio_vq_out_commit_loss
        self.audio_vq_no_quantize = audio_vq_no_quantize
        self.audio_vq_ff_layer = audio_vq_ff_layer

        if audio_vq_layers > 0:
            # presumably n_state is set by the parent constructor - verify
            self.vq_feature_dim = self.n_state
            self.audio_vq_ds_rate = 1
        else:
            raise NotImplementedError(f"Unsupported audio_vq_layers: {audio_vq_layers}")

        # The default None used to reach ``None % 1`` below and raise a
        # TypeError; treat it as "no extra downsampling" instead.
        if audio_vq_ds_rate is None:
            audio_vq_ds_rate = self.audio_vq_ds_rate

        if self.audio_vq_ds_rate == audio_vq_ds_rate:
            self.audio_vq_downsample = nn.Identity()
            self.audio_vq_upsample = nn.Identity()
        else:
            assert audio_vq_ds_rate % self.audio_vq_ds_rate == 0
            stride = audio_vq_ds_rate // self.audio_vq_ds_rate
            self.audio_vq_downsample = Conv1d(self.vq_feature_dim, self.vq_feature_dim, kernel_size=stride, stride=stride)
            self.audio_vq_upsample = ConvTranspose1d(self.vq_feature_dim, self.vq_feature_dim, kernel_size=stride, stride=stride)
            self.audio_vq_ds_rate = audio_vq_ds_rate

        if audio_vq_type == "GRVQ":
            self.audio_quantizer = DistributedGroupResidualVectorQuantization(
                codebook_size = audio_vq_codebook_size,
                dim = self.vq_feature_dim,
                # NOTE(review): self.vq_codebook_dim is not assigned anywhere
                # in this class - confirm the parent defines it, otherwise
                # audio_vq_codebook_dim=None raises AttributeError here.
                codebook_dim = self.vq_codebook_dim if audio_vq_codebook_dim is None else audio_vq_codebook_dim,
                num_groups=1,
                num_quantizers=1,
                kmeans_init=False,
                threshold_ema_dead_code = audio_vq_threshold_ema_dead_code
            )
        else:
            raise NotImplementedError(f"Unsupported audio_vq_type: {audio_vq_type}")

        if self.audio_vq_pe:
            self.project_after_vq_pe = nn.Linear(self.n_state, self.n_state)

    def _calc_quantize_activities(self, indices):
        """Return codebook usage statistics for ``indices``.

        ``vq_num_activities`` is the number of distinct codes that occur and
        ``vq_num_tokens`` is the total number of indices counted.
        """
        indices_onehot = F.one_hot(indices.long().flatten(), self.audio_vq_codebook_size).sum(dim=0)
        # Tensor reductions instead of Python's built-in sum(), which walks
        # the codebook one element at a time.
        vq_num_activities = (indices_onehot > 0).sum()
        vq_num_tokens = indices_onehot.sum()
        return {
            "vq_num_activities": vq_num_activities,
            "vq_num_tokens": vq_num_tokens,
        }

    def _do_quantize(self, x, pe=None, y=None):
        """Quantize the hidden states ``x``.

        x: torch.Tensor, shape = (T, D)
        q: torch.Tensor, shape = (T, D)
        i: torch.Tensor, shape = (T)

        ``pe`` is added to the quantized states when ``audio_vq_pe`` is set;
        ``y`` is unused. Returns ``(x, indices, vq_stats)``.

        Raises:
            NotImplementedError: when called in training mode.
        """
        if self.audio_vq_out_commit_loss > 0:
            x_teacher = x.clone()
        x = x.unsqueeze(0)

        x = self.audio_vq_downsample(x.transpose(1, 2))
        x = x.transpose(1, 2)

        vq_stats = {}

        if self.audio_vq_type == "GRVQ":
            if self.training:
                raise NotImplementedError
            else:
                indices = self.audio_quantizer.encode(x)
                x = self.audio_quantizer.decode(indices)
                # Drop the two singleton axes at dims 2 and 1 of the indices.
                indices = indices.squeeze(2).squeeze(1)

        vq_stats.update(self._calc_quantize_activities(indices))

        x, indices = x.squeeze(0), indices.squeeze(0)
        if self.audio_vq_pe:
            x = x + pe
            x = self.project_after_vq_pe(x)

        x = self.audio_vq_upsample(x.unsqueeze(0).transpose(1, 2))
        x = x.transpose(1, 2).squeeze(0)

        if self.audio_vq_out_commit_loss > 0:
            vq_out_commit_loss = F.mse_loss(x_teacher.detach(), x)
            vq_stats["vq_out_commit_loss"] = vq_out_commit_loss * self.audio_vq_out_commit_loss

        return x, indices, vq_stats

    def forward(self, x_list: List[Tensor], audio_mellens:List[int], audio_aftercnnlens:List[int], audio_seqlens:List[int], return_indices=False, audio_pitchs=None):
        """Encode a list of mel spectrograms.

        x : torch.Tensor, shape = (n_mels, n_ctx)
            the mel spectrogram of the audio

        ``audio_mellens`` and ``audio_pitchs`` are accepted but not read.
        With ``return_indices=True`` the method returns ``(x, indices)``
        straight after the quantization step; otherwise it returns
        ``(output, vq_stats)`` (or ``output`` alone if ``audio_vq_type`` is
        "NULL").
        """
        aftercnn_x_list = []
        pe_for_vq_list = []
        for each_x in x_list:
            # Process each spectrogram in windows of 2 * n_window frames.
            each_x_split_list = each_x.split(self.n_window * 2, dim=1)
            for each_x_split in each_x_split_list:
                each_x_split = F.gelu(self.conv1(each_x_split))
                each_x_split = F.gelu(self.conv2(each_x_split))
                each_x_split = each_x_split.permute(1, 0)  # L,D

                each_positional_embedding_split = self.positional_embedding[:each_x_split.shape[0]]
                aftercnn_x_list.append(each_x_split + each_positional_embedding_split.to(each_x_split.dtype))

                pe_for_vq_split = self.positional_embedding[:each_x_split.shape[0] // self.audio_vq_ds_rate]
                pe_for_vq_list.append(pe_for_vq_split.to(each_x_split.dtype))

        pe_for_vq = torch.cat(pe_for_vq_list, dim=0)
        x = torch.cat(aftercnn_x_list, dim=0)

        # Split every post-CNN length into windows of at most n_window.
        output_list = []
        for item in audio_aftercnnlens:
            while item > self.n_window:
                output_list.append(self.n_window)
                item -= self.n_window
            output_list.append(item)

        cu_seqlens = list(accumulate(output_list, func=operator.add, initial=0))
        # Build the int32 tensor directly rather than via a float32 tensor,
        # which cannot represent every integer above 2**24.
        cu_seqlens = torch.tensor(cu_seqlens, device=x.device, dtype=torch.int32)

        for layer_id, block in enumerate(self.blocks, start=1):
            x = block(x, cu_seqlens=cu_seqlens)

            if self.audio_vq_layers == layer_id:  # vq inside encoder
                x, indices, vq_stats = self._do_quantize(x, pe_for_vq)
                if return_indices:
                    return x, indices

        if self.avg_pooler:
            pooled = []
            for segment in x.split(audio_aftercnnlens, dim=0):
                segment = segment.permute(1, 0)
                segment = self.avg_pooler(segment)
                segment = segment.permute(1, 0)
                pooled.append(segment)
            x = torch.cat(pooled, dim=0)

        x = self.ln_post(x)

        x = self.proj(x)

        # Two extra rows per audio for the begin/end marker embeddings.
        output = torch.zeros(
            (x.size(0) + len(audio_seqlens) * 2, x.size(1)),
            device=x.device, dtype=x.dtype
        )

        audio_seqlens_acc = list(accumulate(audio_seqlens, func=operator.add, initial=0))
        start_ids = torch.tensor(audio_seqlens_acc[:-1], device=x.device, dtype=torch.int32)
        end_ids = torch.tensor(audio_seqlens_acc[1:], device=x.device, dtype=torch.int32) - 1

        audio_tokens_mask = torch.ones(output.size(0), device=x.device, dtype=torch.bool)
        audio_tokens_mask[start_ids] = False
        audio_tokens_mask[end_ids] = False
        output[start_ids] = self.audio_bos_eos_token.weight[0].to(x.dtype)
        output[end_ids] = self.audio_bos_eos_token.weight[1].to(x.dtype)
        output[audio_tokens_mask] = x

        if self.audio_vq_type != "NULL":
            return output, vq_stats
        return output
|
qwen_tts/core/tokenizer_25hz/vq/whisper_encoder.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2026 The Alibaba Qwen team.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
import os
|
| 17 |
+
import math
|
| 18 |
+
import torch
|
| 19 |
+
import operator
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
|
| 24 |
+
from functools import lru_cache
|
| 25 |
+
from typing import Optional, Union, List
|
| 26 |
+
from torch import nn, Tensor
|
| 27 |
+
from itertools import accumulate
|
| 28 |
+
|
| 29 |
+
try:
|
| 30 |
+
from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func
|
| 31 |
+
except ImportError:
|
| 32 |
+
try:
|
| 33 |
+
from flash_attn.flash_attn_interface import flash_attn_unpadded_func as flash_attn_varlen_func
|
| 34 |
+
except ImportError:
|
| 35 |
+
print("\n********\nWarning: flash-attn is not installed. Will only run the manual PyTorch version. Please install flash-attn for faster inference.\n********\n ")
|
| 36 |
+
flash_attn_varlen_func = None
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
N_FFT = 400
|
| 40 |
+
HOP_LENGTH = 160
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int) -> torch.Tensor:
    """
    Load the mel filterbank matrix used to project an STFT onto mel bins.

    The matrices ship in ``assets/mel_filters.npz`` next to this module, which
    avoids a librosa dependency; the archive was produced with:

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
            mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
        )

    Results are memoized per ``(device, n_mels)`` pair.
    """
    supported = {80, 128}
    assert n_mels in supported, f"Unsupported n_mels: {n_mels}"

    module_dir = os.path.dirname(__file__)
    archive_path = os.path.join(module_dir, "assets", "mel_filters.npz")
    with np.load(archive_path, allow_pickle=False) as archive:
        bank = archive[f"mel_{n_mels}"]
    return torch.from_numpy(bank).to(device)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def log_mel_spectrogram(
    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int = 80,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
):
    """
    Compute the log-Mel spectrogram of an audio waveform.

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        The audio waveform. Anything that is not already a Tensor is passed
        to ``torch.from_numpy``, so in practice a NumPy array or Tensor.

    n_mels: int
        The number of Mel-frequency filters (forwarded to ``mel_filters``)

    padding: int
        Number of zero samples to pad to the right

    device: Optional[Union[str, torch.device]]
        If given, the audio tensor is moved to this device before STFT

    Returns
    -------
    torch.Tensor, shape = (n_mels, n_frames)
        A Tensor that contains the Mel spectrogram
    """
    waveform = audio if torch.is_tensor(audio) else torch.from_numpy(audio)

    if device is not None:
        waveform = waveform.to(device)
    if padding > 0:
        waveform = F.pad(waveform, (0, padding))

    hann = torch.hann_window(N_FFT).to(waveform.device)
    spectrum = torch.stft(waveform, N_FFT, HOP_LENGTH, window=hann, return_complex=True)
    # Power spectrum without the final frame.
    power = spectrum[..., :-1].abs() ** 2

    mel_power = mel_filters(waveform.device, n_mels) @ power

    log_mel = torch.clamp(mel_power, min=1e-10).log10()
    # Limit the range to 8 below the maximum, then rescale.
    floor = log_mel.max() - 8.0
    log_mel = torch.maximum(log_mel, floor)
    return (log_mel + 4.0) / 4.0
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def get_T_after_cnn(L_in, dilation=1):
    """Return the sequence length after the two-layer convolutional stem.

    Applies the standard 1-D convolution output-length formula for two
    layers with ``(padding, kernel_size, stride)`` of ``(1, 3, 1)`` and
    ``(1, 3, 2)``.

    Args:
        L_in: Input length.
        dilation: Dilation used in both layers.

    Returns:
        The output length after both layers.
    """
    # Literal layer specs; these were previously produced by eval() on a
    # constant string, which added nothing but an eval call.
    conv_specs = ((1, 3, 1), (1, 3, 2))
    for padding, kernel_size, stride in conv_specs:
        L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
        L_out = 1 + L_out // stride
        L_in = L_out
    return L_out
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def get_mel_audio(audio, padding=False, audio_vq_ds_rate = 1, n_mels = 128):
    """Return the log-mel spectrogram of ``audio`` ([F, T]).

    With ``padding`` enabled the waveform is right-padded so its length
    becomes a multiple of ``160 * 2 * audio_vq_ds_rate`` samples.
    """
    audio_len = len(audio)
    if not padding:
        return log_mel_spectrogram(audio, n_mels=n_mels)  # [F,T]

    reduction = 160 * 2 * audio_vq_ds_rate
    padded_len = math.ceil(audio_len / reduction) * reduction
    return log_mel_spectrogram(audio, n_mels=n_mels, padding=padded_len - audio_len)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def sinusoids(length, channels, max_timescale=10000):
    """Build a fixed sinusoidal positional-embedding table.

    The first ``channels // 2`` columns hold sines and the remaining columns
    hold cosines of ``position * inverse_timescale``, with the inverse
    timescales spaced geometrically from 1 down to ``1 / max_timescale``.

    Args:
        length: Number of positions (rows).
        channels: Embedding width; must be even.
        max_timescale: Largest timescale of the geometric progression.

    Returns:
        Tensor of shape ``(length, channels)``.
    """
    assert channels % 2 == 0
    half = channels // 2
    log_step = np.log(max_timescale) / (half - 1)
    inverse_timescales = torch.exp(-log_step * torch.arange(half))
    positions = torch.arange(length)
    angles = positions[:, np.newaxis] * inverse_timescales[np.newaxis, :]
    sin_part = torch.sin(angles)
    cos_part = torch.cos(angles)
    return torch.cat([sin_part, cos_part], dim=1)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class Conv1d(nn.Conv1d):
    """``nn.Conv1d`` that casts its parameters to the input dtype on every call."""

    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
    ) -> Tensor:
        # Run the convolution in the dtype of the activations, whatever dtype
        # the stored parameters have.
        target_dtype = x.dtype
        cast_bias = bias.to(target_dtype) if bias is not None else None
        return super()._conv_forward(x, weight.to(target_dtype), cast_bias)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class ConvTranspose1d(nn.ConvTranspose1d):
    """``nn.ConvTranspose1d`` carrying the same dtype-casting hook as ``Conv1d``.

    NOTE(review): ``nn.ConvTranspose1d.forward`` appears to call the functional
    transposed convolution directly rather than ``_conv_forward``, so this
    override may never be invoked -- confirm against the installed PyTorch.
    """

    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
    ) -> Tensor:
        target_dtype = x.dtype
        cast_bias = bias.to(target_dtype) if bias is not None else None
        return super()._conv_forward(x, weight.to(target_dtype), cast_bias)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class Linear(nn.Linear):
    """``nn.Linear`` that casts weight and bias to the input dtype on every call."""

    def forward(self, x: Tensor) -> Tensor:
        target_dtype = x.dtype
        bias = self.bias
        if bias is not None:
            bias = bias.to(target_dtype)
        return F.linear(x, self.weight.to(target_dtype), bias)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over packed, variable-length sequences.

    All sequences are concatenated along the first axis, so activations have
    shape ``(n_ctx, n_state)`` where ``n_ctx`` is the total number of positions.
    ``cu_seqlens`` holds the cumulative sequence boundaries
    (``[0, len_0, len_0 + len_1, ...]``); attention never crosses a boundary.
    """

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

        # Sticky switch: forward() clears it the first time it sees a dtype
        # other than fp16/bf16 while the flash-attention kernel is available.
        self.use_flash_attention = True

    def forward(
        self,
        x: Tensor,
        cu_seqlens = None,
    ):
        """Project ``x`` to q/k/v, attend within each sequence, and project back.

        Args:
            x: Packed activations of shape ``(n_ctx, n_state)``.
            cu_seqlens: 1-D tensor of cumulative sequence boundaries. Both
                attention paths index into it, so it is effectively required.

        Returns:
            Tensor of shape ``(n_ctx, n_state)``.
        """
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        use_flash = self.use_flash_attention and flash_attn_varlen_func is not None
        if use_flash and q.dtype not in [torch.float16, torch.bfloat16]:
            # Only fp16/bf16 is sent to the flash kernel; fall back permanently.
            self.use_flash_attention = False
            use_flash = False

        if use_flash:
            x = self.qkv_flash_attention(q, k, v, cu_seqlens=cu_seqlens)
        else:
            x = self.qkv_attention_manual(q, k, v, cu_seqlens=cu_seqlens)

        return self.out(x)

    def qkv_flash_attention(
        self, q: Tensor, k: Tensor, v: Tensor, cu_seqlens=None
    ):
        """Attention through ``flash_attn_varlen_func`` on the packed layout."""
        n_ctx, n_state = q.shape
        q = q.view(n_ctx, self.n_head, -1)  # (total_positions, n_head, head_dim)
        k = k.view(n_ctx, self.n_head, -1)
        v = v.view(n_ctx, self.n_head, -1)

        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

        x = flash_attn_varlen_func(
            q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, dropout_p=0.0
        )
        return x.reshape(n_ctx, n_state)

    def qkv_attention_manual(
        self, q: Tensor, k: Tensor, v: Tensor, cu_seqlens: Tensor
    ):
        """Reference attention: unpack to a padded batch, attend, and re-pack.

        Padded key positions are excluded from the softmax, so every sequence
        attends only to its own positions regardless of the other sequence
        lengths in the batch.

        Args:
            q, k, v: Packed projections of shape ``(n_ctx, n_state)``.
            cu_seqlens: 1-D tensor of cumulative sequence boundaries.

        Returns:
            Tensor of shape ``(n_ctx, n_state)``.
        """
        n_ctx, n_state = q.shape
        head_dim = n_state // self.n_head
        scale = head_dim ** -0.5

        q = q.view(n_ctx, self.n_head, head_dim)
        k = k.view(n_ctx, self.n_head, head_dim)
        v = v.view(n_ctx, self.n_head, head_dim)

        # Pull the boundaries to the host once; they drive Python-level slicing.
        bounds = cu_seqlens.tolist()
        seqlens = [end - start for start, end in zip(bounds[:-1], bounds[1:])]
        batch_size = len(seqlens)
        max_seqlen = max(seqlens)

        q_padded = torch.zeros(batch_size, max_seqlen, self.n_head, head_dim, dtype=q.dtype, device=q.device)
        k_padded = torch.zeros_like(q_padded)
        v_padded = torch.zeros_like(q_padded)

        for i, seq_len in enumerate(seqlens):
            start_idx = bounds[i]
            end_idx = bounds[i + 1]
            q_padded[i, :seq_len] = q[start_idx:end_idx]
            k_padded[i, :seq_len] = k[start_idx:end_idx]
            v_padded[i, :seq_len] = v[start_idx:end_idx]

        # (batch, n_head, max_seqlen, head_dim)
        q_padded = q_padded.transpose(1, 2)
        k_padded = k_padded.transpose(1, 2)
        v_padded = v_padded.transpose(1, 2)

        # True where the key position belongs to the sequence.
        key_is_valid = (
            torch.arange(max_seqlen, device=q.device)[None, :]
            < torch.tensor(seqlens, device=q.device)[:, None]
        )
        key_is_padding = ~key_is_valid[:, None, None, :]  # broadcast over heads and queries

        attn_scores = torch.matmul(q_padded, k_padded.transpose(-2, -1)) * scale
        # Mask the scores themselves. Filling the boolean mask with a large
        # negative number only casts that number to True, which leaves the
        # padded keys unmasked.
        attn_scores = attn_scores.masked_fill(key_is_padding, torch.finfo(attn_scores.dtype).min)
        attn_weights = F.softmax(attn_scores, dim=-1)

        context = torch.matmul(attn_weights, v_padded)
        context = context.transpose(1, 2).contiguous().view(batch_size, max_seqlen, n_state)

        # Drop the padded rows and restore the packed layout.
        output_packed = torch.cat([context[i, :seqlens[i]] for i in range(batch_size)], dim=0)

        assert output_packed.shape == (n_ctx, n_state)

        return output_packed
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention then a 4x-wide GELU MLP.

    ``enable_mp`` and ``sequence_parallel`` are accepted but not used by this
    implementation.
    """

    def __init__(self, n_state: int, n_head: int,
                 enable_mp: bool = False, sequence_parallel: bool = False):
        super().__init__()
        self.attn_ln = nn.LayerNorm(n_state)
        self.mlp_ln = nn.LayerNorm(n_state)

        self.attn = MultiHeadAttention(n_state, n_head)

        hidden_width = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, hidden_width),
            nn.GELU(),
            Linear(hidden_width, n_state),
        )

    def forward(
        self,
        x: Tensor,
        cu_seqlens = None
    ):
        """Apply the attention and MLP sub-layers, each with a residual connection."""
        attended = self.attn(self.attn_ln(x), cu_seqlens=cu_seqlens)
        x = x + attended
        return x + self.mlp(self.mlp_ln(x))
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
class WhisperEncoder(nn.Module):
    """Whisper-style audio encoder operating on packed, windowed mel chunks.

    Each mel spectrogram is cut into chunks of ``2 * n_window`` frames, passed
    through a two-layer convolutional stem, given positional embeddings, and
    concatenated. The transformer blocks attend within windows of at most
    ``n_window`` positions. The result is average-pooled by 2, normalised,
    projected to ``output_dim``, and wrapped in learned BOS/EOS rows per audio.
    """

    def __init__(
        self,
        n_mels: int,
        n_ctx: int,
        n_state: int,
        n_head: int,
        n_layer: int,
        n_window: int = 1500,
        output_dim: int = 512,
        grad_checkpointing: bool = False,
        enable_mp: bool = False,
        audio_sequence_parallel: bool = False,
    ):
        super().__init__()

        # Sub-modules are created in a fixed order so parameter registration
        # (and random initialisation under a fixed seed) stays the same.
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
        self.n_layer = n_layer
        self.n_mels = n_mels

        self.blocks = nn.ModuleList(
            [
                ResidualAttentionBlock(
                    n_state, n_head, enable_mp=enable_mp, sequence_parallel=audio_sequence_parallel
                )
                for _ in range(n_layer)
            ]
        )
        self.ln_post = nn.LayerNorm(n_state)
        self.avg_pooler = nn.AvgPool1d(2, stride=2)

        self.proj = torch.nn.Linear(n_state, output_dim)

        # Row 0 is written at the start of each audio, row 1 at its end.
        self.audio_bos_eos_token = nn.Embedding(2, output_dim)

        # Plain configuration attributes.
        self.output_dim = output_dim
        self.grad_checkpointing = grad_checkpointing
        self.enable_mp = enable_mp
        self.n_head = n_head
        self.n_state = n_state
        self.n_window = n_window

        self.audio_sequence_parallel = audio_sequence_parallel

        self.tp_world_size = 1

        self.set_audio_sync()

    def set_audio_sync(self):
        """Tag every parameter outside ``blocks`` with ``audio_sync = True``."""
        for param_name, param in self.named_parameters():
            if param_name.startswith("blocks"):
                continue
            setattr(param, "audio_sync", True)

    def forward(self, x_list: List[Tensor], audio_mellens:List[int], audio_aftercnnlens:List[int], audio_seqlens:List[int]):
        """Encode a list of mel spectrograms into one packed token matrix.

        Args:
            x_list: Mel spectrograms, each split along dim 1 into chunks of
                ``2 * n_window`` frames and fed to ``conv1`` (so dim 0 is the
                mel axis).
            audio_mellens: Not used by this method.
            audio_aftercnnlens: Per-audio lengths after the convolutional stem;
                used to build the attention windows and to split the encoded
                sequence back into audios.
            audio_seqlens: Per-audio lengths of the final output including the
                two BOS/EOS rows; used to locate those rows.

        Returns:
            Tensor with ``pooled_tokens + 2 * len(audio_seqlens)`` rows and
            ``output_dim`` columns.
        """
        chunk_frames = self.n_window * 2

        # Convolutional stem + positional embedding, chunk by chunk.
        embedded_chunks = []
        for mel in x_list:
            for chunk in mel.split(chunk_frames, dim=1):
                chunk = F.gelu(self.conv1(chunk))
                chunk = F.gelu(self.conv2(chunk))
                chunk = chunk.permute(1, 0)  # L,D
                positions = self.positional_embedding[:chunk.shape[0]]
                embedded_chunks.append(chunk + positions.to(chunk.dtype))

        x = torch.cat(embedded_chunks, dim=0)

        # Attention windows: each audio is cut into pieces of at most n_window.
        window_lengths = []
        for remaining in audio_aftercnnlens:
            while remaining > self.n_window:
                window_lengths.append(self.n_window)
                remaining -= self.n_window
            window_lengths.append(remaining)

        cu_seqlens = list(accumulate(window_lengths, func=operator.add, initial=0))
        cu_seqlens = torch.Tensor(cu_seqlens).to(device=x.device, dtype=torch.int32)

        for block in self.blocks:
            x = block(x, cu_seqlens=cu_seqlens)

        if self.avg_pooler:
            # Pool each audio separately so pooling never mixes two audios.
            pooled = []
            for segment in x.split(audio_aftercnnlens, dim=0):
                segment = self.avg_pooler(segment.permute(1, 0))
                pooled.append(segment.permute(1, 0))
            x = torch.cat(pooled, dim=0)

        x = self.proj(self.ln_post(x))

        n_rows = x.size(0) + len(audio_seqlens) * 2
        output = torch.zeros((n_rows, x.size(1)), device=x.device, dtype=x.dtype)

        boundaries = list(accumulate(audio_seqlens, func=operator.add, initial=0))
        start_ids = torch.tensor(boundaries[:-1], device=x.device, dtype=torch.int32)
        end_ids = torch.tensor(boundaries[1:], device=x.device, dtype=torch.int32) - 1

        # Every row that is neither a BOS nor an EOS slot receives an audio token.
        audio_tokens_mask = torch.ones(output.size(0), device=x.device, dtype=torch.bool)
        audio_tokens_mask[start_ids] = False
        audio_tokens_mask[end_ids] = False

        bos_eos = self.audio_bos_eos_token.weight
        output[start_ids] = bos_eos[0].to(x.dtype)
        output[end_ids] = bos_eos[1].to(x.dtype)
        output[audio_tokens_mask] = x
        return output

    def lock(self, layers: int):
        """Freeze the convolutional stem and the first ``layers`` transformer blocks."""
        self.conv1.requires_grad_(False)
        self.conv2.requires_grad_(False)
        for index, block in enumerate(self.blocks):
            if index >= layers:
                break
            block.requires_grad_(False)
|
qwen_tts/inference/qwen3_tts_model.py
ADDED
|
@@ -0,0 +1,877 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2026 The Alibaba Qwen team.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
import base64
|
| 17 |
+
import io
|
| 18 |
+
import urllib.request
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 21 |
+
from urllib.parse import urlparse
|
| 22 |
+
|
| 23 |
+
import librosa
|
| 24 |
+
import numpy as np
|
| 25 |
+
import soundfile as sf
|
| 26 |
+
import torch
|
| 27 |
+
from transformers import AutoConfig, AutoModel, AutoProcessor
|
| 28 |
+
|
| 29 |
+
from ..core.models import Qwen3TTSConfig, Qwen3TTSForConditionalGeneration, Qwen3TTSProcessor
|
| 30 |
+
|
| 31 |
+
# Accepted forms of a single audio input. A bare np.ndarray is part of the
# alias, but `_normalize_audio_inputs` rejects it unless it comes paired with
# a sampling rate as a (waveform, sr) tuple.
AudioLike = Union[
    str,  # wav path, URL, base64
    np.ndarray,  # waveform (requires sr)
    Tuple[np.ndarray, int],  # (waveform, sr)
]

# Either a single value or a list of values.
MaybeList = Union[Any, List[Any]]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@dataclass
class VoiceClonePromptItem:
    """
    Container for one sample's voice-clone prompt information that can be fed to the model.

    Fields are aligned with `Qwen3TTSForConditionalGeneration.generate(..., voice_clone_prompt=...)`.
    """
    # Reference speech codes. Optional -- presumably None when only the speaker
    # embedding is used (x_vector_only_mode); confirm against the code that builds items.
    ref_code: Optional[torch.Tensor]  # (T, Q) or (T,) depending on tokenizer 25Hz/12Hz
    # Speaker embedding extracted from the reference audio.
    ref_spk_embedding: torch.Tensor  # (D,)
    # True when only the speaker embedding should be used for cloning.
    x_vector_only_mode: bool
    # True when the in-context-learning (reference text + codes) mode is used.
    icl_mode: bool
    # Transcript of the reference audio, if any.
    ref_text: Optional[str] = None
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class Qwen3TTSModel:
|
| 55 |
+
"""
|
| 56 |
+
A HuggingFace-style wrapper for Qwen3 TTS models (CustomVoice/VoiceDesign/Base) that provides:
|
| 57 |
+
- from_pretrained() initialization via AutoModel/AutoProcessor
|
| 58 |
+
- generation APIs for:
|
| 59 |
+
* CustomVoice: generate_custom_voice()
|
| 60 |
+
* VoiceDesign: generate_voice_design()
|
| 61 |
+
* Base: generate_voice_clone() + create_voice_clone_prompt()
|
| 62 |
+
- consistent output: (wavs: List[np.ndarray], sample_rate: int)
|
| 63 |
+
|
| 64 |
+
Notes:
|
| 65 |
+
- This wrapper expects the underlying model class to be `Qwen3TTSForConditionalGeneration`
|
| 66 |
+
- Language / speaker validation is done via model methods:
|
| 67 |
+
model.get_supported_languages(), model.get_supported_speakers()
|
| 68 |
+
"""
|
| 69 |
+
|
| 70 |
+
def __init__(self, model: Qwen3TTSForConditionalGeneration, processor, generate_defaults: Optional[Dict[str, Any]] = None):
    """Wrap an already-loaded model and processor.

    Args:
        model: The underlying TTS model.
        processor: Processor used to tokenize text.
        generate_defaults: Default generation arguments; a falsy value becomes
            an empty dict.
    """
    self.model = model
    self.processor = processor
    self.generate_defaults = generate_defaults or {}

    # Prefer the model's own `device`; otherwise use the device of its first
    # parameter, and CPU when the model has no parameters at all.
    device = getattr(model, "device", None)
    if device is None:
        try:
            device = next(model.parameters()).device
        except StopIteration:
            device = torch.device("cpu")
    self.device = device
|
| 81 |
+
|
| 82 |
+
@classmethod
def from_pretrained(
    cls,
    pretrained_model_name_or_path: str,
    **kwargs,
) -> "Qwen3TTSModel":
    """
    Load a Qwen3 TTS model and its processor in HuggingFace `from_pretrained` style.

    Steps:
        1) Register the Qwen3 TTS config, model and processor classes with the
           `AutoConfig` / `AutoModel` / `AutoProcessor` factories.
        2) Load the model via `AutoModel.from_pretrained(...)`, forwarding `kwargs` unchanged.
        3) Load the processor via `AutoProcessor.from_pretrained(...)`.
        4) Take the generation defaults from `model.generate_config`.

    Args:
        pretrained_model_name_or_path (str):
            HuggingFace repo id or local directory of the model.
        **kwargs:
            Forwarded as-is into `AutoModel.from_pretrained(...)`.
            Typical examples: device_map="cuda:0", dtype=torch.bfloat16, attn_implementation="flash_attention_2".

    Returns:
        Qwen3TTSModel:
            Wrapper instance containing `model`, `processor`, and generation defaults.

    Raises:
        TypeError: If `AutoModel` returns something other than
            `Qwen3TTSForConditionalGeneration`.
    """
    AutoConfig.register("qwen3_tts", Qwen3TTSConfig)
    AutoModel.register(Qwen3TTSConfig, Qwen3TTSForConditionalGeneration)
    AutoProcessor.register(Qwen3TTSConfig, Qwen3TTSProcessor)

    tts_model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
    if isinstance(tts_model, Qwen3TTSForConditionalGeneration):
        tts_processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, fix_mistral_regex=True,)
        return cls(
            model=tts_model,
            processor=tts_processor,
            generate_defaults=tts_model.generate_config,
        )

    raise TypeError(
        f"AutoModel returned {type(tts_model)}, expected Qwen3TTSForConditionalGeneration. "
    )
|
| 122 |
+
|
| 123 |
+
def _supported_languages_set(self) -> Optional[set]:
    """Return the model's supported languages as a lower-cased set.

    Returns None when the model has no callable `get_supported_languages`
    or when that method itself returns None.
    """
    getter = getattr(self.model, "get_supported_languages", None)
    if not callable(getter):
        return None
    names = getter()
    if names is None:
        return None
    return {str(name).lower() for name in names}
|
| 131 |
+
|
| 132 |
+
def _supported_speakers_set(self) -> Optional[set]:
    """Return the model's supported speakers as a lower-cased set.

    Returns None when the model has no callable `get_supported_speakers`
    or when that method itself returns None.
    """
    getter = getattr(self.model, "get_supported_speakers", None)
    if not callable(getter):
        return None
    names = getter()
    if names is None:
        return None
    return {str(name).lower() for name in names}
|
| 140 |
+
|
| 141 |
+
def _validate_languages(self, languages: List[str]) -> None:
    """
    Check that every requested language is supported by the model.

    The comparison is case-insensitive. A `None` entry counts as unsupported.
    Nothing is checked when the model does not report its supported languages.

    Args:
        languages (List[str]): Language names for each sample.

    Raises:
        ValueError: If any language is not supported.
    """
    supported = self._supported_languages_set()
    if supported is None:
        return

    bad = [
        lang
        for lang in languages
        if lang is None or str(lang).lower() not in supported
    ]
    if bad:
        raise ValueError(f"Unsupported languages: {bad}. Supported: {sorted(supported)}")
|
| 164 |
+
|
| 165 |
+
def _validate_speakers(self, speakers: List[Optional[str]]) -> None:
    """
    Check that every requested speaker is supported by the model.

    The comparison is case-insensitive. `None` and empty-string entries are
    skipped. Nothing is checked when the model does not report its supported
    speakers.

    Args:
        speakers (List[Optional[str]]): Speaker names for each sample.

    Raises:
        ValueError: If any speaker is not supported.
    """
    supported = self._supported_speakers_set()
    if supported is None:
        return

    bad = [
        spk
        for spk in speakers
        if not (spk is None or spk == "") and str(spk).lower() not in supported
    ]
    if bad:
        raise ValueError(f"Unsupported speakers: {bad}. Supported: {sorted(supported)}")
|
| 187 |
+
|
| 188 |
+
def _is_probably_base64(self, s: str) -> bool:
    """Heuristically decide whether `s` is base64-encoded audio rather than a path.

    A string is treated as base64 when any of the following holds:
      - it starts with the `data:audio` data-URI prefix;
      - it is longer than 256 characters and contains no path separator;
      - it is longer than 256 characters, consists solely of characters from
        the standard base64 alphabet (plus `=` padding and line breaks), and
        its payload length is a multiple of 4.

    The last rule exists because `/` is itself a base64 character, so a raw
    payload usually contains it and would otherwise be mistaken for a path.

    Args:
        s: Candidate string.

    Returns:
        True if the string looks like base64 data, False otherwise.
    """
    if s.startswith("data:audio"):
        return True
    if len(s) <= 256:
        return False
    if "/" not in s and "\\" not in s:
        return True

    base64_chars = frozenset(
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=\r\n"
    )
    if not set(s) <= base64_chars:
        # Real paths almost always carry a character outside the alphabet,
        # such as the "." of a file extension.
        return False
    payload_len = len(s) - s.count("\n") - s.count("\r")
    return payload_len % 4 == 0
|
| 194 |
+
|
| 195 |
+
def _is_url(self, s: str) -> bool:
    """Return True when `s` parses as an http(s) URL with a network location."""
    try:
        parts = urlparse(s)
        has_host = bool(parts.netloc)
        return has_host and parts.scheme in ("http", "https")
    except Exception:
        # Anything that cannot be parsed is simply not a URL.
        return False
|
| 201 |
+
|
| 202 |
+
def _decode_base64_to_wav_bytes(self, b64: str) -> bytes:
    """Decode a base64 string, dropping a leading `data:...,` header if present."""
    payload = b64
    if b64.strip().startswith("data:") and "," in b64:
        # Keep only what follows the first comma of the data URI.
        _header, _comma, payload = b64.partition(",")
    return base64.b64decode(payload)
|
| 206 |
+
|
| 207 |
+
def _load_audio_to_np(self, x: str) -> Tuple[np.ndarray, int]:
    """Load audio from a URL, a base64 string, or a local path.

    The source kind is decided in that order. URL and base64 payloads are
    decoded from memory with soundfile; anything else is handed to
    `librosa.load` at its native sampling rate. Multi-channel audio is
    averaged over the last axis.

    Args:
        x: URL, base64 string, or file path.

    Returns:
        Tuple of (float32 waveform, sampling rate).
    """
    encoded_bytes = None
    if self._is_url(x):
        with urllib.request.urlopen(x) as resp:
            encoded_bytes = resp.read()
    elif self._is_probably_base64(x):
        encoded_bytes = self._decode_base64_to_wav_bytes(x)

    if encoded_bytes is None:
        audio, sr = librosa.load(x, sr=None, mono=True)
    else:
        with io.BytesIO(encoded_bytes) as f:
            audio, sr = sf.read(f, dtype="float32", always_2d=False)

    if audio.ndim > 1:
        audio = np.mean(audio, axis=-1)

    return audio.astype(np.float32), int(sr)
|
| 224 |
+
|
| 225 |
+
def _normalize_audio_inputs(self, audios: "Union[AudioLike, List[AudioLike]]") -> List[Tuple[np.ndarray, int]]:
    """
    Normalize audio inputs into a list of (waveform, sr).

    Supported forms:
      - str: wav path / URL / base64 audio string
      - (np.ndarray, sr): waveform + sampling rate
      - list of the above

    Multi-channel waveforms are reduced to mono by averaging over the last
    axis (this assumes a channels-last layout -- confirm for callers that pass
    channels-first arrays).

    Args:
        audios:
            Audio input(s).

    Returns:
        List[Tuple[np.ndarray, int]]:
            List of (float32 mono waveform, original sr).

    Raises:
        ValueError: If a numpy waveform is provided without sr.
        TypeError: If an item has an unsupported type.
    """
    items = audios if isinstance(audios, list) else [audios]

    out: List[Tuple[np.ndarray, int]] = []
    for item in items:
        if isinstance(item, str):
            wav, sr = self._load_audio_to_np(item)
        elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[0], np.ndarray):
            wav, sr = item[0].astype(np.float32), int(item[1])
        elif isinstance(item, np.ndarray):
            raise ValueError("For numpy waveform input, pass a tuple (audio, sr).")
        else:
            raise TypeError(f"Unsupported audio input type: {type(item)}")

        # Down-mix on the local array: the (waveform, sr) pairs are tuples and
        # cannot be modified in place after they have been stored.
        if wav.ndim > 1:
            wav = np.mean(wav, axis=-1).astype(np.float32)
        out.append((wav, sr))
    return out
|
| 265 |
+
|
| 266 |
+
def _ensure_list(self, x: MaybeList) -> List[Any]:
    """Return `x` unchanged when it is already a list, otherwise wrap it in one."""
    if isinstance(x, list):
        return x
    return [x]
|
| 268 |
+
|
| 269 |
+
def _build_assistant_text(self, text: str) -> str:
    """Wrap `text` as a finished assistant turn followed by an open assistant turn."""
    template = "<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n"
    return template.format(text=text)
|
| 271 |
+
|
| 272 |
+
def _build_ref_text(self, text: str) -> str:
    """Wrap a reference transcript as a finished assistant turn."""
    template = "<|im_start|>assistant\n{text}<|im_end|>\n"
    return template.format(text=text)
|
| 274 |
+
|
| 275 |
+
def _build_instruct_text(self, instruct: str) -> str:
    """Wrap an instruction as a finished user turn."""
    template = "<|im_start|>user\n{instruct}<|im_end|>\n"
    return template.format(instruct=instruct)
|
| 277 |
+
|
| 278 |
+
def _tokenize_texts(self, texts: List[str]) -> List[torch.Tensor]:
    """Tokenize each text separately and return one id tensor per text.

    Every text goes through `self.processor` on its own; the resulting
    `input_ids` are moved to `self.device` and given a leading batch
    dimension when the processor returned a 1-D tensor.

    Args:
        texts: Texts to tokenize.

    Returns:
        One tensor of token ids per input text, in the same order.
    """
    token_batches = []
    for text in texts:
        encoded = self.processor(text=text, return_tensors="pt", padding=True)
        ids = encoded["input_ids"].to(self.device)
        if ids.dim() == 1:
            ids = ids.unsqueeze(0)
        token_batches.append(ids)
    return token_batches
|
| 286 |
+
|
| 287 |
+
def _merge_generate_kwargs(
    self,
    do_sample: Optional[bool] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: Optional[float] = None,
    repetition_penalty: Optional[float] = None,
    subtalker_dosample: Optional[bool] = None,
    subtalker_top_k: Optional[int] = None,
    subtalker_top_p: Optional[float] = None,
    subtalker_temperature: Optional[float] = None,
    max_new_tokens: Optional[int] = None,
    **kwargs,
) -> Dict[str, Any]:
    """
    Resolve the final generation arguments.

    For each of the named generation parameters the value is chosen as:
      1. the caller's value, when it is not None;
      2. otherwise the entry in `self.generate_defaults`, when present;
      3. otherwise a built-in fallback.

    Args:
        do_sample, top_k, top_p, temperature, repetition_penalty,
        subtalker_dosample, subtalker_top_k, subtalker_top_p, subtalker_temperature, max_new_tokens:
            Common generation parameters.
        **kwargs:
            Other arguments, passed through untouched.

    Returns:
        Dict[str, Any]: `kwargs` plus the resolved value of every named parameter.
    """
    fallbacks = {
        "do_sample": True,
        "top_k": 50,
        "top_p": 1.0,
        "temperature": 0.9,
        "repetition_penalty": 1.05,
        "subtalker_dosample": True,
        "subtalker_top_k": 50,
        "subtalker_top_p": 1.0,
        "subtalker_temperature": 0.9,
        "max_new_tokens": 2048,
    }
    requested = {
        "do_sample": do_sample,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "repetition_penalty": repetition_penalty,
        "subtalker_dosample": subtalker_dosample,
        "subtalker_top_k": subtalker_top_k,
        "subtalker_top_p": subtalker_top_p,
        "subtalker_temperature": subtalker_temperature,
        "max_new_tokens": max_new_tokens,
    }

    merged = dict(kwargs)
    for name, user_val in requested.items():
        if user_val is not None:
            merged[name] = user_val
        elif name in self.generate_defaults:
            merged[name] = self.generate_defaults[name]
        else:
            merged[name] = fallbacks[name]
    return merged
|
| 353 |
+
|
| 354 |
+
# voice clone model
|
| 355 |
+
@torch.inference_mode()
def create_voice_clone_prompt(
    self,
    ref_audio: Union[AudioLike, List[AudioLike]],
    ref_text: Optional[Union[str, List[Optional[str]]]] = None,
    x_vector_only_mode: Union[bool, List[bool]] = False,
) -> List[VoiceClonePromptItem]:
    """
    Build voice-clone prompt items from reference audio (and optionally reference text) using Base model.

    Modes:
        - x_vector_only_mode=True:
            Only the speaker embedding is kept; the item's `ref_code` is None and ICL is disabled.
        - x_vector_only_mode=False:
            ICL mode (icl_mode=True). `ref_text` is required and the reference speech codes are kept.

    Batch behavior:
        - ref_audio can be a single item or a list.
        - ref_text and x_vector_only_mode can be scalars or lists. Scalars and
          length-1 lists are broadcast to the number of reference audios.
        - Lists with length > 1 must match the number of reference audios.

    Args:
        ref_audio:
            Reference audio(s), normalized by `self._normalize_audio_inputs(...)`. Used to extract:
            - ref_code via `model.speech_tokenizer.encode(...)`
            - ref_spk_embedding via `model.extract_speaker_embedding(...)`, after resampling
              to `model.speaker_encoder_sample_rate` when needed.
        ref_text:
            Reference transcript(s). Required when x_vector_only_mode=False (ICL mode).
        x_vector_only_mode:
            Whether to use speaker embedding only. If False, ICL mode will be used.

    Returns:
        List[VoiceClonePromptItem]:
            One prompt item per reference audio.

    Raises:
        ValueError:
            - If the loaded model is not a "base" model.
            - If x_vector_only_mode=False but ref_text is missing or empty.
            - If batch lengths mismatch.
    """
    if self.model.tts_model_type != "base":
        raise ValueError(
            f"model with \ntokenizer_type: {self.model.tokenizer_type}\n"
            f"tts_model_size: {self.model.tts_model_size}\n"
            f"tts_model_type: {self.model.tts_model_type}\n"
            "does not support create_voice_clone_prompt, Please check Model Card or Readme for more details."
        )

    ref_audio_list = self._ensure_list(ref_audio)
    batch_size = len(ref_audio_list)

    def _per_sample(value):
        # Scalars and length-1 lists apply to every reference audio;
        # longer lists are taken as-is and length-checked below.
        if not isinstance(value, list):
            return [value] * batch_size
        values = self._ensure_list(value)
        if len(values) == 1 and batch_size > 1:
            values = values * batch_size
        return values

    ref_text_list = _per_sample(ref_text)
    xvec_list = _per_sample(x_vector_only_mode)

    if len(ref_text_list) != batch_size or len(xvec_list) != batch_size:
        raise ValueError(
            f"Batch size mismatch: ref_audio={len(ref_audio_list)}, ref_text={len(ref_text_list)}, x_vector_only_mode={len(xvec_list)}"
        )

    # Fail fast: reject missing transcripts before any audio is loaded or encoded.
    for i, (rtext, xvec_only) in enumerate(zip(ref_text_list, xvec_list)):
        if not xvec_only and (rtext is None or rtext == ""):
            raise ValueError(f"ref_text is required when x_vector_only_mode=False (ICL mode). Bad index={i}")

    normalized = self._normalize_audio_inputs(ref_audio_list)
    ref_wavs_for_code: List[np.ndarray] = [wav for wav, _ in normalized]
    ref_sr_for_code: List[int] = [sr for _, sr in normalized]

    if len(set(ref_sr_for_code)) == 1:
        # All references share one sampling rate: encode them in a single call.
        enc = self.model.speech_tokenizer.encode(ref_wavs_for_code, sr=ref_sr_for_code[0])
        ref_codes = enc.audio_codes
    else:
        # Mixed sampling rates: encode each reference on its own.
        ref_codes = [
            self.model.speech_tokenizer.encode(wav, sr=sr).audio_codes[0]
            for wav, sr in normalized
        ]

    spk_sr = self.model.speaker_encoder_sample_rate
    items: List[VoiceClonePromptItem] = []
    for (wav, sr), code, rtext, xvec_only in zip(normalized, ref_codes, ref_text_list, xvec_list):
        wav_resample = wav
        if sr != spk_sr:
            wav_resample = librosa.resample(
                y=wav_resample.astype(np.float32),
                orig_sr=int(sr),
                target_sr=spk_sr,
            )

        spk_emb = self.model.extract_speaker_embedding(audio=wav_resample, sr=spk_sr)

        items.append(
            VoiceClonePromptItem(
                ref_code=None if xvec_only else code,
                ref_spk_embedding=spk_emb,
                x_vector_only_mode=bool(xvec_only),
                icl_mode=bool(not xvec_only),
                ref_text=rtext,
            )
        )
    return items
|
| 459 |
+
|
| 460 |
+
def _prompt_items_to_voice_clone_prompt(self, items: List[VoiceClonePromptItem]) -> Dict[str, Any]:
|
| 461 |
+
return dict(
|
| 462 |
+
ref_code=[it.ref_code for it in items],
|
| 463 |
+
ref_spk_embedding=[it.ref_spk_embedding for it in items],
|
| 464 |
+
x_vector_only_mode=[it.x_vector_only_mode for it in items],
|
| 465 |
+
icl_mode=[it.icl_mode for it in items],
|
| 466 |
+
)
|
| 467 |
+
|
| 468 |
+
# voice clone model
|
| 469 |
+
@torch.no_grad()
def generate_voice_clone(
    self,
    text: Union[str, List[str]],
    language: Union[str, List[str]] = None,
    ref_audio: Optional[Union[AudioLike, List[AudioLike]]] = None,
    ref_text: Optional[Union[str, List[Optional[str]]]] = None,
    x_vector_only_mode: Union[bool, List[bool]] = False,
    voice_clone_prompt: Optional[Union[Dict[str, Any], List[VoiceClonePromptItem]]] = None,
    non_streaming_mode: bool = False,
    **kwargs,
) -> Tuple[List[np.ndarray], int]:
    """
    Voice clone speech using the Base model.

    The prompt can be supplied in one of three ways:
        - (ref_audio, ref_text, x_vector_only_mode): the prompt is built here via
          `create_voice_clone_prompt`.
        - a list of `VoiceClonePromptItem` (as returned by `create_voice_clone_prompt`).
        - an already assembled prompt dict, which is forwarded unchanged; in that case
          no reference-text ids are produced (`ref_ids=None`).

    Batching:
        - text/language can be scalar or list; a single language is repeated for every text.
        - A single prompt item is repeated for every text; otherwise lengths must match.

    Args:
        text:
            Text(s) to synthesize.
        language:
            Language(s) for each sample. None means "Auto" for every sample.
        ref_audio:
            Reference audio(s) for prompt building. Required if voice_clone_prompt is not provided.
        ref_text:
            Reference text(s) used for ICL mode (required when x_vector_only_mode=False).
        x_vector_only_mode:
            If True, only speaker embedding is used. If False, ICL mode is used.
        voice_clone_prompt:
            Prompt items or prompt dict; takes precedence over ref_audio/ref_text.
        non_streaming_mode:
            Forwarded to `model.generate`. Per the upstream documentation, setting it to
            `false` only simulates streaming text input; it does not enable true streaming
            input or streaming generation.
        **kwargs:
            Generation options (do_sample, top_k, top_p, temperature, repetition_penalty,
            subtalker_dosample, subtalker_top_k, subtalker_top_p, subtalker_temperature,
            max_new_tokens, plus anything else accepted by `model.generate`). They are
            resolved through `self._merge_generate_kwargs`.

    Returns:
        Tuple[List[np.ndarray], int]:
            (wavs, sample_rate)

    Raises:
        ValueError:
            If the model is not a "base" model, batch sizes mismatch, or neither
            `voice_clone_prompt` nor `ref_audio` is given.
    """
    if self.model.tts_model_type != "base":
        raise ValueError(
            f"model with \ntokenizer_type: {self.model.tokenizer_type}\n"
            f"tts_model_size: {self.model.tts_model_size}\n"
            f"tts_model_type: {self.model.tts_model_type}\n"
            "does not support generate_voice_clone, Please check Model Card or Readme for more details."
        )

    texts = self._ensure_list(text)
    n_texts = len(texts)

    if isinstance(language, list):
        languages = self._ensure_list(language)
    elif language is not None:
        languages = [language] * n_texts
    else:
        languages = ["Auto"] * n_texts
    if len(languages) == 1 and n_texts > 1:
        languages = languages * n_texts
    if n_texts != len(languages):
        raise ValueError(f"Batch size mismatch: text={len(texts)}, language={len(languages)}")

    self._validate_languages(languages)

    # Resolve the prompt source into a list of items, or None for a ready-made dict.
    if voice_clone_prompt is None:
        if ref_audio is None:
            raise ValueError("Either `voice_clone_prompt` or `ref_audio` must be provided.")
        prompt_items = self.create_voice_clone_prompt(
            ref_audio=ref_audio, ref_text=ref_text, x_vector_only_mode=x_vector_only_mode
        )
    elif isinstance(voice_clone_prompt, list):
        prompt_items = voice_clone_prompt
    else:
        prompt_items = None

    if prompt_items is None:
        voice_clone_prompt_dict = voice_clone_prompt
        ref_texts_for_ids = None
    else:
        if len(prompt_items) == 1 and n_texts > 1:
            prompt_items = prompt_items * n_texts
        if len(prompt_items) != n_texts:
            raise ValueError(f"Batch size mismatch: prompt={len(prompt_items)}, text={len(texts)}")
        voice_clone_prompt_dict = self._prompt_items_to_voice_clone_prompt(prompt_items)
        ref_texts_for_ids = [item.ref_text for item in prompt_items]

    input_ids = self._tokenize_texts([self._build_assistant_text(t) for t in texts])

    # Missing or empty reference transcripts stay as None placeholders.
    ref_ids = None
    if ref_texts_for_ids is not None:
        ref_ids = [
            None if (rt is None or rt == "") else self._tokenize_texts([self._build_ref_text(rt)])[0]
            for rt in ref_texts_for_ids
        ]

    gen_kwargs = self._merge_generate_kwargs(**kwargs)

    talker_codes_list, _ = self.model.generate(
        input_ids=input_ids,
        ref_ids=ref_ids,
        voice_clone_prompt=voice_clone_prompt_dict,
        languages=languages,
        non_streaming_mode=non_streaming_mode,
        **gen_kwargs,
    )

    def _ref_code_at(index):
        # Reference codes for one sample, or None when the prompt carries none.
        ref_code_list = voice_clone_prompt_dict.get("ref_code", None)
        if ref_code_list is None:
            return None
        return ref_code_list[index]

    # Prepend the reference codes so they are decoded together with the generated codes.
    codes_for_decode = []
    for idx, codes in enumerate(talker_codes_list):
        ref_code = _ref_code_at(idx)
        if ref_code is None:
            codes_for_decode.append(codes)
        else:
            codes_for_decode.append(torch.cat([ref_code.to(codes.device), codes], dim=0))

    wavs_all, fs = self.model.speech_tokenizer.decode([{"audio_codes": c} for c in codes_for_decode])

    # Drop the leading part of each waveform in proportion to the reference share of the codes.
    wavs_out: List[np.ndarray] = []
    for idx, wav in enumerate(wavs_all):
        ref_code = _ref_code_at(idx)
        if ref_code is None:
            wavs_out.append(wav)
            continue
        ref_len = int(ref_code.shape[0])
        total_len = int(codes_for_decode[idx].shape[0])
        cut = int(ref_len / max(total_len, 1) * wav.shape[0])
        wavs_out.append(wav[cut:])

    return wavs_out, fs
|
| 634 |
+
|
| 635 |
+
# voice design model
|
| 636 |
+
@torch.no_grad()
def generate_voice_design(
    self,
    text: Union[str, List[str]],
    instruct: Union[str, List[str]],
    language: Union[str, List[str]] = None,
    non_streaming_mode: bool = True,
    **kwargs,
) -> Tuple[List[np.ndarray], int]:
    """
    Generate speech with the VoiceDesign model using natural-language style instructions.

    Args:
        text:
            Text(s) to synthesize.
        instruct:
            Instruction(s) describing desired voice/style. An empty string (or None entry)
            is allowed and results in no instruction ids for that sample.
        language:
            Language(s) for each sample. None means "Auto" for every sample.
        non_streaming_mode:
            Forwarded to `model.generate`. Per the upstream documentation, setting it to
            `false` only simulates streaming text input; it does not enable true streaming
            input or streaming generation.
        **kwargs:
            Generation options (do_sample, top_k, top_p, temperature, repetition_penalty,
            subtalker_dosample, subtalker_top_k, subtalker_top_p, subtalker_temperature,
            max_new_tokens, plus anything else accepted by `model.generate`). They are
            resolved through `self._merge_generate_kwargs`.

    Returns:
        Tuple[List[np.ndarray], int]:
            (wavs, sample_rate)

    Raises:
        ValueError:
            If the model is not a "voice_design" model or batch sizes mismatch.
    """
    if self.model.tts_model_type != "voice_design":
        raise ValueError(
            f"model with \ntokenizer_type: {self.model.tokenizer_type}\n"
            f"tts_model_size: {self.model.tts_model_size}\n"
            f"tts_model_type: {self.model.tts_model_type}\n"
            "does not support generate_voice_design, Please check Model Card or Readme for more details."
        )

    texts = self._ensure_list(text)
    n_texts = len(texts)

    if isinstance(language, list):
        languages = self._ensure_list(language)
    elif language is not None:
        languages = [language] * n_texts
    else:
        languages = ["Auto"] * n_texts
    instructs = self._ensure_list(instruct)

    # A single language / instruction applies to every text in the batch.
    if len(languages) == 1 and n_texts > 1:
        languages = languages * n_texts
    if len(instructs) == 1 and n_texts > 1:
        instructs = instructs * n_texts

    if len(languages) != n_texts or len(instructs) != n_texts:
        raise ValueError(f"Batch size mismatch: text={len(texts)}, language={len(languages)}, instruct={len(instructs)}")

    self._validate_languages(languages)

    input_ids = self._tokenize_texts([self._build_assistant_text(t) for t in texts])

    # None marks samples without an instruction.
    instruct_ids: List[Optional[torch.Tensor]] = [
        None if (ins is None or ins == "") else self._tokenize_texts([self._build_instruct_text(ins)])[0]
        for ins in instructs
    ]

    gen_kwargs = self._merge_generate_kwargs(**kwargs)

    talker_codes_list, _ = self.model.generate(
        input_ids=input_ids,
        instruct_ids=instruct_ids,
        languages=languages,
        non_streaming_mode=non_streaming_mode,
        **gen_kwargs,
    )

    decode_inputs = [{"audio_codes": codes} for codes in talker_codes_list]
    wavs, fs = self.model.speech_tokenizer.decode(decode_inputs)
    return wavs, fs
|
| 729 |
+
|
| 730 |
+
# custom voice model
|
| 731 |
+
@torch.no_grad()
def generate_custom_voice(
    self,
    text: Union[str, List[str]],
    speaker: Union[str, List[str]],
    language: Union[str, List[str]] = None,
    instruct: Optional[Union[str, List[str]]] = None,
    non_streaming_mode: bool = True,
    **kwargs,
) -> Tuple[List[np.ndarray], int]:
    """
    Generate speech with the CustomVoice model using a predefined speaker id, optionally controlled by instruction text.

    Args:
        text:
            Text(s) to synthesize.
        speaker:
            Speaker name(s); checked by `self._validate_speakers(...)`.
        language:
            Language(s) for each sample. None means "Auto" for every sample.
        instruct:
            Optional instruction(s). If None, treated as empty (no instruction).
            Ignored entirely when `model.tts_model_size == "0b6"`, because that
            model does not support instructions.
        non_streaming_mode:
            Forwarded to `model.generate`. Per the upstream documentation, setting it to
            `false` only simulates streaming text input; it does not enable true streaming
            input or streaming generation.
        **kwargs:
            Generation options (do_sample, top_k, top_p, temperature, repetition_penalty,
            subtalker_dosample, subtalker_top_k, subtalker_top_p, subtalker_temperature,
            max_new_tokens, plus anything else accepted by `model.generate`). They are
            resolved through `self._merge_generate_kwargs`.

    Returns:
        Tuple[List[np.ndarray], int]:
            (wavs, sample_rate)

    Raises:
        ValueError:
            If the model is not a "custom_voice" model, any speaker/language is
            unsupported, or batch sizes mismatch.
    """
    if self.model.tts_model_type != "custom_voice":
        raise ValueError(
            f"model with \ntokenizer_type: {self.model.tokenizer_type}\n"
            f"tts_model_size: {self.model.tts_model_size}\n"
            f"tts_model_type: {self.model.tts_model_type}\n"
            "does not support generate_custom_voice, Please check Model Card or Readme for more details."
        )

    texts = self._ensure_list(text)
    n_texts = len(texts)

    if isinstance(language, list):
        languages = self._ensure_list(language)
    elif language is not None:
        languages = [language] * n_texts
    else:
        languages = ["Auto"] * n_texts
    speakers = self._ensure_list(speaker)

    # The 0b6 model does not support instructions. Compare for equality:
    # `size in "0b6"` is a substring test and would also match "", "0", "b6", ...
    if self.model.tts_model_size == "0b6":
        instruct = None

    if isinstance(instruct, list):
        instructs = self._ensure_list(instruct)
    elif instruct is not None:
        instructs = [instruct] * n_texts
    else:
        instructs = [""] * n_texts

    # A single language / speaker / instruction applies to every text in the batch.
    if len(languages) == 1 and n_texts > 1:
        languages = languages * n_texts
    if len(speakers) == 1 and n_texts > 1:
        speakers = speakers * n_texts
    if len(instructs) == 1 and n_texts > 1:
        instructs = instructs * n_texts

    if not (len(texts) == len(languages) == len(speakers) == len(instructs)):
        raise ValueError(
            f"Batch size mismatch: text={len(texts)}, language={len(languages)}, speaker={len(speakers)}, instruct={len(instructs)}"
        )

    self._validate_languages(languages)
    self._validate_speakers(speakers)

    input_ids = self._tokenize_texts([self._build_assistant_text(t) for t in texts])

    # None marks samples without an instruction.
    instruct_ids: List[Optional[torch.Tensor]] = [
        None if (ins is None or ins == "") else self._tokenize_texts([self._build_instruct_text(ins)])[0]
        for ins in instructs
    ]

    gen_kwargs = self._merge_generate_kwargs(**kwargs)

    talker_codes_list, _ = self.model.generate(
        input_ids=input_ids,
        instruct_ids=instruct_ids,
        languages=languages,
        speakers=speakers,
        non_streaming_mode=non_streaming_mode,
        **gen_kwargs,
    )

    wavs, fs = self.model.speech_tokenizer.decode([{"audio_codes": c} for c in talker_codes_list])
    return wavs, fs
|
| 840 |
+
|
| 841 |
+
|
| 842 |
+
def get_supported_speakers(self) -> Optional[List[str]]:
    """
    List supported speaker names for the current model.

    Returns:
        Optional[List[str]]:
            - The entries of `self._supported_speakers_set()` as a sorted list.
            - None if that helper returns None (no speaker information available).
    """
    speaker_names = self._supported_speakers_set()
    return None if speaker_names is None else sorted(speaker_names)
|
| 859 |
+
|
| 860 |
+
|
| 861 |
+
def get_supported_languages(self) -> Optional[List[str]]:
    """
    List supported language names for the current model.

    Returns:
        Optional[List[str]]:
            - The entries of `self._supported_languages_set()` as a sorted list.
            - None if that helper returns None (no language information available).
    """
    language_names = self._supported_languages_set()
    return None if language_names is None else sorted(language_names)
|
qwen_tts/inference/qwen3_tts_tokenizer.py
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2026 The Alibaba Qwen team.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
import base64
|
| 17 |
+
import io
|
| 18 |
+
import urllib.request
|
| 19 |
+
from typing import List, Optional, Tuple, Union
|
| 20 |
+
from urllib.parse import urlparse
|
| 21 |
+
|
| 22 |
+
import librosa
|
| 23 |
+
import numpy as np
|
| 24 |
+
import soundfile as sf
|
| 25 |
+
import torch
|
| 26 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 27 |
+
from transformers import AutoConfig, AutoFeatureExtractor, AutoModel
|
| 28 |
+
|
| 29 |
+
from ..core import (
|
| 30 |
+
Qwen3TTSTokenizerV1Config,
|
| 31 |
+
Qwen3TTSTokenizerV1Model,
|
| 32 |
+
Qwen3TTSTokenizerV2Config,
|
| 33 |
+
Qwen3TTSTokenizerV2Model,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
# Accepted audio inputs:
#   - str: a wav path or a base64 audio string
#   - np.ndarray: a waveform (1-D float array)
#   - a list of either of the above
AudioInput = Union[str, np.ndarray, List[str], List[np.ndarray]]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class Qwen3TTSTokenizer:
|
| 45 |
+
"""
|
| 46 |
+
A wrapper for Qwen3 TTS Tokenizer 25Hz/12Hz with HuggingFace-style loading.
|
| 47 |
+
|
| 48 |
+
- from_pretrained(): loads speech tokenizer model via AutoModel and feature_extractor via AutoFeatureExtractor.
|
| 49 |
+
- encode(): supports wav path(s), base64 audio string(s), numpy array(s).
|
| 50 |
+
- decode(): accepts either the raw model encode output, or a minimal dict/list-of-dicts.
|
| 51 |
+
|
| 52 |
+
Notes:
|
| 53 |
+
- For numpy array input, you must pass `sr` so the audio can be resampled to model sample rate.
|
| 54 |
+
- Returned audio is float32 numpy arrays and the output sample rate.
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
def __init__(self):
|
| 58 |
+
self.model = None
|
| 59 |
+
self.feature_extractor = None
|
| 60 |
+
self.config = None
|
| 61 |
+
self.device = None
|
| 62 |
+
|
| 63 |
+
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> "Qwen3TTSTokenizer":
    """
    Initialize tokenizer with HuggingFace `from_pretrained` style.

    Registers the 25Hz (V1) and 12Hz (V2) tokenizer config/model classes with
    `AutoConfig` / `AutoModel`, then loads the feature extractor and the model.

    Args:
        pretrained_model_name_or_path (str):
            HuggingFace repo id or local directory.
        **kwargs (Any):
            Forwarded to `AutoModel.from_pretrained(...)` directly.
            Typical examples: device_map="cuda:0", dtype=torch.bfloat16, attn_implementation="eager".

    Returns:
        Qwen3TTSTokenizer:
            Instance with `model`, `feature_extractor`, `config` and `device` set.
    """
    tokenizer = cls()

    registrations = (
        ("qwen3_tts_tokenizer_25hz", Qwen3TTSTokenizerV1Config, Qwen3TTSTokenizerV1Model),
        ("qwen3_tts_tokenizer_12hz", Qwen3TTSTokenizerV2Config, Qwen3TTSTokenizerV2Model),
    )
    for model_type, config_cls, model_cls in registrations:
        AutoConfig.register(model_type, config_cls)
        AutoModel.register(config_cls, model_cls)

    tokenizer.feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path)
    tokenizer.model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
    tokenizer.config = tokenizer.model.config

    device = getattr(tokenizer.model, "device", None)
    if device is None:
        # Fallback: use the device of the first parameter, or CPU if there are none.
        try:
            device = next(tokenizer.model.parameters()).device
        except StopIteration:
            device = torch.device("cpu")
    tokenizer.device = device

    return tokenizer
|
| 100 |
+
|
| 101 |
+
def _is_probably_base64(self, s: str) -> bool:
|
| 102 |
+
if s.startswith("data:audio"):
|
| 103 |
+
return True
|
| 104 |
+
# Heuristic: no filesystem path separators and long enough.
|
| 105 |
+
if ("/" not in s and "\\" not in s) and len(s) > 256:
|
| 106 |
+
return True
|
| 107 |
+
return False
|
| 108 |
+
|
| 109 |
+
def _is_url(self, s: str) -> bool:
|
| 110 |
+
try:
|
| 111 |
+
u = urlparse(s)
|
| 112 |
+
return u.scheme in ("http", "https") and bool(u.netloc)
|
| 113 |
+
except Exception:
|
| 114 |
+
return False
|
| 115 |
+
|
| 116 |
+
def _decode_base64_to_wav_bytes(self, b64: str) -> bytes:
|
| 117 |
+
# Accept both "data:audio/wav;base64,...." and raw base64
|
| 118 |
+
if "," in b64 and b64.strip().startswith("data:"):
|
| 119 |
+
b64 = b64.split(",", 1)[1]
|
| 120 |
+
return base64.b64decode(b64)
|
| 121 |
+
|
| 122 |
+
def load_audio(
    self,
    x: str,
    target_sr: int,
) -> np.ndarray:
    """
    Load one audio clip and resample it to `target_sr`.

    The source is chosen in this order:
        1. http(s) URL (per `_is_url`): downloaded, then parsed with soundfile
        2. base64 string (per `_is_probably_base64`): decoded, then parsed with soundfile
        3. anything else: handed to `librosa.load` as a path, at its native rate, as mono

    Args:
        x (str):
            A URL, a base64 audio string (raw or data URL), or an audio file path.
        target_sr (int):
            Target sampling rate.

    Returns:
        np.ndarray:
            float32 waveform at target_sr; multi-dimensional audio is averaged
            over its last axis.
    """
    if self._is_url(x):
        # NOTE(review): no timeout is passed to urlopen here.
        with urllib.request.urlopen(x) as resp:
            payload = resp.read()
    elif self._is_probably_base64(x):
        payload = self._decode_base64_to_wav_bytes(x)
    else:
        payload = None

    if payload is None:
        audio, sr = librosa.load(x, sr=None, mono=True)
    else:
        with io.BytesIO(payload) as buffer:
            audio, sr = sf.read(buffer, dtype="float32", always_2d=False)

    # Collapse multi-channel audio by averaging over the last axis.
    if audio.ndim > 1:
        audio = np.mean(audio, axis=-1)

    if sr != target_sr:
        audio = librosa.resample(y=audio, orig_sr=sr, target_sr=target_sr)

    return audio.astype(np.float32)
|
| 159 |
+
|
| 160 |
+
def _normalize_audio_inputs(
|
| 161 |
+
self,
|
| 162 |
+
audios: AudioInput,
|
| 163 |
+
sr: Optional[int],
|
| 164 |
+
) -> List[np.ndarray]:
|
| 165 |
+
"""
|
| 166 |
+
Normalize all supported input types into a list of 1-D numpy float32 waveforms
|
| 167 |
+
at `self.feature_extractor.sampling_rate`.
|
| 168 |
+
|
| 169 |
+
Args:
|
| 170 |
+
audios (AudioInput):
|
| 171 |
+
- str: wav path OR base64 audio string
|
| 172 |
+
- np.ndarray: raw waveform (sr must be provided)
|
| 173 |
+
- list[str] / list[np.ndarray]
|
| 174 |
+
sr (Optional[int]):
|
| 175 |
+
Sampling rate for raw numpy input. Required if input is np.ndarray or list[np.ndarray].
|
| 176 |
+
|
| 177 |
+
Returns:
|
| 178 |
+
List[np.ndarray]:
|
| 179 |
+
List of float32 waveforms resampled to model input SR.
|
| 180 |
+
"""
|
| 181 |
+
target_sr = int(self.feature_extractor.sampling_rate)
|
| 182 |
+
|
| 183 |
+
if isinstance(audios, (str, np.ndarray)):
|
| 184 |
+
audios = [audios]
|
| 185 |
+
|
| 186 |
+
if len(audios) == 0:
|
| 187 |
+
return []
|
| 188 |
+
|
| 189 |
+
if isinstance(audios[0], str):
|
| 190 |
+
# wav path list or base64 list
|
| 191 |
+
return [self.load_audio(x, target_sr=target_sr) for x in audios] # type: ignore[arg-type]
|
| 192 |
+
|
| 193 |
+
# numpy list
|
| 194 |
+
if sr is None:
|
| 195 |
+
raise ValueError("For numpy waveform input, you must provide `sr` (original sampling rate).")
|
| 196 |
+
|
| 197 |
+
out: List[np.ndarray] = []
|
| 198 |
+
for a in audios: # type: ignore[assignment]
|
| 199 |
+
if not isinstance(a, np.ndarray):
|
| 200 |
+
raise TypeError("Mixed input types are not supported. Use all paths/base64 or all numpy arrays.")
|
| 201 |
+
if a.ndim > 1:
|
| 202 |
+
a = np.mean(a, axis=-1)
|
| 203 |
+
if int(sr) != target_sr:
|
| 204 |
+
a = librosa.resample(y=a.astype(np.float32), orig_sr=int(sr), target_sr=target_sr)
|
| 205 |
+
out.append(a.astype(np.float32))
|
| 206 |
+
return out
|
| 207 |
+
|
| 208 |
+
def encode(
    self,
    audios: AudioInput,
    sr: Optional[int] = None,
    return_dict: bool = True,
):
    """
    Encode a batch of audio into discrete codes (plus optional conditioning, depending on 25Hz/12Hz).

    Inputs are first normalized by `_normalize_audio_inputs`, then run through
    `self.feature_extractor`, moved to `self.device` / `self.model.dtype`, and
    finally passed to `self.model.encode` under `torch.inference_mode()`.

    Args:
        audios (AudioInput):
            Supported forms:
            - np.ndarray: waveform (requires sr)
            - list[np.ndarray]: waveforms (requires sr)
            - str: wav path OR base64 audio string
            - list[str]: wav paths and/or base64 strings
        sr (Optional[int], default=None):
            Original sampling rate for numpy waveform input.
        return_dict (bool, default=True):
            Forwarded to model.encode(...). If True, returns ModelOutput.

    Returns:
        25Hz:
            Qwen3TTSTokenizerV1EncoderOutput (if return_dict=True) with fields:
              - audio_codes: List[torch.LongTensor] each (codes_len,)
              - xvectors: List[torch.FloatTensor] each (xvector_dim,)
              - ref_mels: List[torch.FloatTensor] each (mel_len, mel_dim)
        12Hz:
            Qwen3TTSTokenizerV2EncoderOutput (if return_dict=True) with fields:
              - audio_codes: List[torch.LongTensor] each (codes_len, num_quantizers)

        If return_dict=False, returns the raw tuple from model.encode.
    """
    waveforms = self._normalize_audio_inputs(audios, sr=sr)
    extractor = self.feature_extractor

    features = extractor(
        raw_audio=waveforms,
        sampling_rate=int(extractor.sampling_rate),
        return_tensors="pt",
    )
    features = features.to(self.device).to(self.model.dtype)

    with torch.inference_mode():
        # model.encode expects (B, T) and (B, T), hence the squeeze on axis 1.
        input_values = features["input_values"].squeeze(1)
        padding_mask = features["padding_mask"].squeeze(1)
        return self.model.encode(
            input_values,
            padding_mask,
            return_dict=return_dict,
        )
|
| 258 |
+
|
| 259 |
+
def decode(
    self,
    encoded,
) -> Tuple[List[np.ndarray], int]:
    """
    Turn encoded audio back into waveforms.

    Accepted forms of `encoded`:
        1) The output of `encode(...)` (recommended): any object exposing an
           `audio_codes` attribute; `xvectors` / `ref_mels` are read when present.
           - 25Hz: expects fields audio_codes, xvectors, ref_mels
           - 12Hz: expects field audio_codes
        2) A dict, or a list of dicts (one per sample), for custom pipelines:
           - 25Hz dict keys: {"audio_codes", "xvectors", "ref_mels"}
           - 12Hz dict keys: {"audio_codes"}
           Values can be torch tensors or numpy arrays. For a list of dicts, the
           optional keys are looked up on the first element only.

    Args:
        encoded (Any):
            - ModelOutput returned by `encode()`, OR
            - dict, OR
            - list[dict]

    Returns:
        Tuple[List[np.ndarray], int]:
            - wavs: list of float32 numpy arrays, one per entry of the model's `audio_values`
            - sample_rate: int, model output sampling rate

    Raises:
        TypeError: If `encoded` is none of the accepted forms.
        ValueError: If the model is 25Hz and `xvectors` or `ref_mels` is missing,
            or if the model type is not recognised.
    """
    model_type = self.model.get_model_type()

    def _as_tensor(value, dtype=None):
        # Existing tensors pass through untouched (no dtype cast);
        # everything else is converted via numpy.
        if isinstance(value, torch.Tensor):
            return value
        converted = torch.from_numpy(np.asarray(value))
        return converted if dtype is None else converted.to(dtype)

    # Pull the three possible fields out of whichever container was given.
    if hasattr(encoded, "audio_codes"):
        codes = encoded.audio_codes
        xvectors = getattr(encoded, "xvectors", None)
        ref_mels = getattr(encoded, "ref_mels", None)
    elif isinstance(encoded, dict):
        codes = encoded["audio_codes"]
        xvectors = encoded.get("xvectors", None)
        ref_mels = encoded.get("ref_mels", None)
    elif isinstance(encoded, list):
        codes = [item["audio_codes"] for item in encoded]
        xvectors = [item["xvectors"] for item in encoded] if ("xvectors" in encoded[0]) else None
        ref_mels = [item["ref_mels"] for item in encoded] if ("ref_mels" in encoded[0]) else None
    else:
        raise TypeError("`encoded` must be an encode output, a dict, or a list of dicts.")

    # Build the batched code tensor.
    if isinstance(codes, torch.Tensor):
        # 1-D (25Hz single sample) and 2-D (12Hz single sample) tensors get a
        # leading batch axis; any other rank is used as given.
        batched_codes = codes.unsqueeze(0) if codes.dim() in (1, 2) else codes
        audio_codes_padded = batched_codes.to(self.device)
    else:
        per_sample_codes = [_as_tensor(c, dtype=torch.long) for c in codes]
        audio_codes_padded = pad_sequence(per_sample_codes, batch_first=True, padding_value=0).to(self.device)

    with torch.inference_mode():
        if model_type == "qwen3_tts_tokenizer_25hz":
            if xvectors is None or ref_mels is None:
                raise ValueError("25Hz decode requires `xvectors` and `ref_mels`.")

            if isinstance(xvectors, torch.Tensor):
                # (D,) -> (1, D)
                xvectors_batch = xvectors.unsqueeze(0) if xvectors.dim() == 1 else xvectors
            else:
                xvector_tensors = [_as_tensor(v, dtype=torch.float32) for v in xvectors]
                xvectors_batch = torch.stack(xvector_tensors, dim=0)
            xvectors_batch = xvectors_batch.to(self.device).to(self.model.dtype)

            if isinstance(ref_mels, torch.Tensor):
                # (T, M) -> (1, T, M)
                ref_mels_padded = ref_mels.unsqueeze(0) if ref_mels.dim() == 2 else ref_mels
            else:
                mel_tensors = [_as_tensor(m, dtype=torch.float32) for m in ref_mels]
                ref_mels_padded = pad_sequence(mel_tensors, batch_first=True, padding_value=0)
            ref_mels_padded = ref_mels_padded.to(self.device).to(self.model.dtype)

            decoded = self.model.decode(audio_codes_padded, xvectors_batch, ref_mels_padded, return_dict=True)
        elif model_type == "qwen3_tts_tokenizer_12hz":
            decoded = self.model.decode(audio_codes_padded, return_dict=True)
        else:
            raise ValueError(f"Unknown model type: {model_type}")

        wav_tensors = decoded.audio_values

    wavs = []
    for wav in wav_tensors:
        wavs.append(wav.to(torch.float32).detach().cpu().numpy())
    sample_rate = int(self.model.get_output_sample_rate())
    return wavs, sample_rate
|
| 366 |
+
|
| 367 |
+
def get_model_type(self) -> str:
    """
    Report which tokenizer variant is loaded.

    Returns:
        str: Whatever `self.model.get_model_type()` returns
            (e.g. "qwen3_tts_tokenizer_25hz" / "qwen3_tts_tokenizer_12hz").
    """
    model_type = self.model.get_model_type()
    return model_type
|
| 376 |
+
|
| 377 |
+
def get_input_sample_rate(self) -> int:
    """
    Report the sample rate the model expects for encoding.

    Returns:
        int: Value of `self.model.get_input_sample_rate()` coerced to int (Hz).
    """
    rate = self.model.get_input_sample_rate()
    return int(rate)
|
| 385 |
+
|
| 386 |
+
def get_output_sample_rate(self) -> int:
    """
    Report the sample rate of decoded waveforms.

    Returns:
        int: Value of `self.model.get_output_sample_rate()` coerced to int (Hz).
    """
    rate = self.model.get_output_sample_rate()
    return int(rate)
|
| 394 |
+
|
| 395 |
+
def get_encode_downsample_rate(self) -> int:
    """
    Report the encoder downsample rate (waveform samples per code step).

    Returns:
        int: Value of `self.model.get_encode_downsample_rate()` coerced to int.
    """
    factor = self.model.get_encode_downsample_rate()
    return int(factor)
|
| 403 |
+
|
| 404 |
+
def get_decode_upsample_rate(self) -> int:
    """
    Report the decoder upsample rate (waveform samples per code step).

    Returns:
        int: Value of `self.model.get_decode_upsample_rate()` coerced to int.
    """
    factor = self.model.get_decode_upsample_rate()
    return int(factor)
|
requirements.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
python-dotenv
|
| 2 |
+
torch==2.8.0
|
| 3 |
+
torchaudio==2.8.0
|
| 4 |
+
transformers==4.57.3
|
| 5 |
+
accelerate==1.12.0
|
| 6 |
+
einops
|
| 7 |
+
gradio
|
| 8 |
+
librosa
|
| 9 |
+
soundfile
|
| 10 |
+
sox
|
| 11 |
+
onnxruntime
|
| 12 |
+
spaces
|
| 13 |
+
numpy
|
| 14 |
+
kernels
|
| 15 |
+
openai-whisper
|