Text-to-Speech
Transformers
Safetensors
moss_tts_local
image-feature-extraction
voice-cloning
custom_code
moss-tts
moss-tts-local
Instructions to use OpenMOSS-Team/MOSS-TTS-Local-Transformer-v1.5 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use OpenMOSS-Team/MOSS-TTS-Local-Transformer-v1.5 with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("text-to-speech", model="OpenMOSS-Team/MOSS-TTS-Local-Transformer-v1.5", trust_remote_code=True)# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("OpenMOSS-Team/MOSS-TTS-Local-Transformer-v1.5", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
Add files using upload-large-folder tool
Browse files- .gitattributes +1 -0
- README.md +278 -0
- __init__.py +9 -0
- added_tokens.json +31 -0
- chat_template.jinja +89 -0
- config.json +381 -0
- configuration_moss_tts.py +158 -0
- gpt2_decoder.py +721 -0
- merges.txt +0 -0
- model.safetensors +3 -0
- modeling_moss_tts.py +623 -0
- processing_moss_tts.py +899 -0
- processor_config.json +7 -0
- qwen3_decoder.py +582 -0
- special_tokens_map.json +21 -0
- tokenizer.json +3 -0
- tokenizer_config.json +253 -0
- vocab.json +0 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,3 +1,281 @@
|
|
| 1 |
---
|
| 2 |
license: apache-2.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
license: apache-2.0
|
| 3 |
+
library_name: transformers
|
| 4 |
+
pipeline_tag: text-to-speech
|
| 5 |
+
tags:
|
| 6 |
+
- text-to-speech
|
| 7 |
+
- voice-cloning
|
| 8 |
+
- custom_code
|
| 9 |
+
- moss-tts
|
| 10 |
+
- moss-tts-local
|
| 11 |
+
- arxiv:2603.18090
|
| 12 |
+
language:
|
| 13 |
+
- zh
|
| 14 |
+
- yue
|
| 15 |
+
- en
|
| 16 |
+
- ar
|
| 17 |
+
- cs
|
| 18 |
+
- da
|
| 19 |
+
- de
|
| 20 |
+
- nl
|
| 21 |
+
- es
|
| 22 |
+
- fr
|
| 23 |
+
- fi
|
| 24 |
+
- el
|
| 25 |
+
- he
|
| 26 |
+
- hi
|
| 27 |
+
- hu
|
| 28 |
+
- ja
|
| 29 |
+
- it
|
| 30 |
+
- ko
|
| 31 |
+
- mk
|
| 32 |
+
- ms
|
| 33 |
+
- ru
|
| 34 |
+
- fa
|
| 35 |
+
- pl
|
| 36 |
+
- pt
|
| 37 |
+
- sv
|
| 38 |
+
- ro
|
| 39 |
+
- sw
|
| 40 |
+
- tl
|
| 41 |
+
- th
|
| 42 |
+
- tr
|
| 43 |
+
- vi
|
| 44 |
---
|
| 45 |
+
# MOSS-TTS Family
|
| 46 |
+
|
| 47 |
+
<br>
|
| 48 |
+
|
| 49 |
+
<p align="center">
|
| 50 |
+
|
| 51 |
+
<img src="https://speech-demo.oss-cn-shanghai.aliyuncs.com/moss_tts_demo/tts_readme_imgaes_demo/openmoss_x_mosi" height="50" align="middle" />
|
| 52 |
+
</p>
|
| 53 |
+
|
| 54 |
+
<div align="center">
|
| 55 |
+
<a href="https://github.com/OpenMOSS/MOSS-TTS/tree/main"><img src="https://img.shields.io/badge/Project%20Page-GitHub-blue"></a>
|
| 56 |
+
<a href="https://huggingface.co/OpenMOSS-Team/MOSS-TTS-Local-Transformer-v1.5"><img src="https://img.shields.io/badge/HuggingFace-Model-yellow?logo=huggingface"></a>
|
| 57 |
+
<a href="https://modelscope.cn/collections/OpenMOSS-Team/MOSS-TTS"><img src="https://img.shields.io/badge/ModelScope-Models-lightgrey?logo=modelscope&"></a>
|
| 58 |
+
<a href="https://mosi.cn/#models"><img src="https://img.shields.io/badge/Blog-View-blue?logo=internet-explorer&"></a>
|
| 59 |
+
|
| 60 |
+
<a href="https://arxiv.org/abs/2603.18090"><img src="https://img.shields.io/badge/Arxiv-2603.18090-red?logo=Arxiv&"></a>
|
| 61 |
+
<a href="https://studio.mosi.cn"><img src="https://img.shields.io/badge/AIStudio-Try-green?logo=internet-explorer&"></a>
|
| 62 |
+
<a href="https://studio.mosi.cn/docs/moss-tts"><img src="https://img.shields.io/badge/API-Docs-00A3FF?logo=fastapi&"></a>
|
| 63 |
+
<a href="https://x.com/Open_MOSS"><img src="https://img.shields.io/badge/Twitter-Follow-black?logo=x&"></a>
|
| 64 |
+
<a href="https://discord.gg/fvm5TaWjU3"><img src="https://img.shields.io/badge/Discord-Join-5865F2?logo=discord&"></a>
|
| 65 |
+
</div>
|
| 66 |
+
|
| 67 |
+
# MOSS-TTS-Local-Transformer-v1.5
|
| 68 |
+
|
| 69 |
+
**MOSS-TTS-Local-Transformer-v1.5** is continued from [MOSS-TTS-Local-Transformer-v1.0](https://huggingface.co/OpenMOSS-Team/MOSS-TTS-Local-Transformer). It preserves the main 1.0 capabilities, including zero-shot voice cloning, long-form speech generation, token-level duration control, Pinyin/IPA pronunciation control, multilingual synthesis, and code-switching. For the full 1.0 feature walkthrough, input schema, and evaluation tables, please refer to the [MOSS-TTS-Local-Transformer-v1.0 README](https://huggingface.co/OpenMOSS-Team/MOSS-TTS-Local-Transformer).
|
| 70 |
+
|
| 71 |
+
Compared with [MOSS-TTS-Local-Transformer-v1.0](https://huggingface.co/OpenMOSS-Team/MOSS-TTS-Local-Transformer), v1.5 focuses on the following improvements:
|
| 72 |
+
- **Higher-fidelity stereo audio modeling**: v1.5 uses [MOSS-Audio-Tokenizer-v2](https://huggingface.co/OpenMOSS-Team/MOSS-Audio-Tokenizer-v2) as the audio tokenizer, supporting native 48 kHz stereo input and output for richer spatial detail and more natural perceived audio quality. Since the codec output is stereo, save the `[channels, samples]` tensor returned by `processor.decode(...)` directly.
|
| 73 |
+
- **Stronger multilingual synthesis with language tags**: when the `language` field is omitted, v1.5 may improve some languages and regress slightly on others compared with 1.0. When the language is specified, v1.5 is stronger than 1.0 on almost all supported languages. Set the tag when building the user message, for example `processor.build_user_message(text=text_fr, language="French")`.
|
| 74 |
+
- **More stable voice cloning**: v1.5 improves speaker similarity and reduces cloning variance, making repeated generations more consistent.
|
| 75 |
+
- **Better long-reference, short-text cloning**: v1.5 handles scenarios where the reference audio is much longer than the target text more reliably than 1.0.
|
| 76 |
+
- **More stable punctuation-following prosody**: v1.5 follows punctuation-driven pauses more closely, especially in long sentences.
|
| 77 |
+
- **Explicit pause control**: v1.5 supports inline pause markers such as `"[pause 3.2s]"`. For example, `我今天学习了一首中国的古诗,它的名字是[pause 3.2s]静夜思!` inserts an explicit 3.2s pause before `静夜思`.
|
| 78 |
+
|
| 79 |
+
## Supported Languages
|
| 80 |
+
|
| 81 |
+
MOSS-TTS Local Transformer v1.5 supports **31 languages**. It keeps the 20 languages supported by [MOSS-TTS-Local-Transformer-v1.0](https://huggingface.co/OpenMOSS-Team/MOSS-TTS-Local-Transformer) and extends multilingual continued training to additional languages including Cantonese, Dutch, Finnish, Hindi, Macedonian, Malay, Romanian, Swahili, Tagalog, Thai, and Vietnamese.
|
| 82 |
+
|
| 83 |
+
| Language | Code | Flag | Language | Code | Flag | Language | Code | Flag |
|
| 84 |
+
|---|---|---|---|---|---|---|---|---|
|
| 85 |
+
| Chinese | zh | 🇨🇳 | Cantonese | yue | 🇭🇰 | English | en | 🇺🇸 |
|
| 86 |
+
| Arabic | ar | 🇸🇦 | Czech | cs | 🇨🇿 | Danish | da | 🇩🇰 |
|
| 87 |
+
| Dutch | nl | 🇳🇱 | Finnish | fi | 🇫🇮 | French | fr | 🇫🇷 |
|
| 88 |
+
| German | de | 🇩🇪 | Greek | el | 🇬🇷 | Hebrew | he | 🇮🇱 |
|
| 89 |
+
| Hindi | hi | 🇮🇳 | Hungarian | hu | 🇭🇺 | Italian | it | 🇮🇹 |
|
| 90 |
+
| Japanese | ja | 🇯🇵 | Korean | ko | 🇰🇷 | Macedonian | mk | 🇲🇰 |
|
| 91 |
+
| Malay | ms | 🇲🇾 | Persian (Farsi) | fa | 🇮🇷 | Polish | pl | 🇵🇱 |
|
| 92 |
+
| Portuguese | pt | 🇵🇹 | Romanian | ro | 🇷🇴 | Russian | ru | 🇷🇺 |
|
| 93 |
+
| Spanish | es | 🇪🇸 | Swahili | sw | 🇹🇿 | Swedish | sv | 🇸🇪 |
|
| 94 |
+
| Tagalog | tl | 🇵🇭 | Thai | th | 🇹🇭 | Turkish | tr | 🇹🇷 |
|
| 95 |
+
| Vietnamese | vi | 🇻🇳 | | | | | | |
|
| 96 |
+
|
| 97 |
+
## Quick Start
|
| 98 |
+
|
| 99 |
+
### Environment Setup
|
| 100 |
+
|
| 101 |
+
We recommend a clean, isolated Python environment with **Transformers 5.0.0**, or a recent Transformers version with Qwen3 support, to avoid dependency conflicts.
|
| 102 |
+
|
| 103 |
+
```bash
|
| 104 |
+
conda create -n moss-tts python=3.12 -y
|
| 105 |
+
conda activate moss-tts
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
Install all required dependencies:
|
| 109 |
+
|
| 110 |
+
```bash
|
| 111 |
+
git clone https://github.com/OpenMOSS/MOSS-TTS.git
|
| 112 |
+
cd MOSS-TTS
|
| 113 |
+
pip install --extra-index-url https://download.pytorch.org/whl/cu128 -e .
|
| 114 |
+
```
|
| 115 |
+
|
| 116 |
+
#### (Optional) Install FlashAttention 2
|
| 117 |
+
|
| 118 |
+
For better speed and lower GPU memory usage, you can install FlashAttention 2 if your hardware supports it.
|
| 119 |
+
|
| 120 |
+
```bash
|
| 121 |
+
pip install --extra-index-url https://download.pytorch.org/whl/cu128 -e ".[flash-attn]"
|
| 122 |
+
```
|
| 123 |
+
|
| 124 |
+
If your machine has limited RAM and many CPU cores, you can cap build parallelism:
|
| 125 |
+
|
| 126 |
+
```bash
|
| 127 |
+
MAX_JOBS=4 pip install --extra-index-url https://download.pytorch.org/whl/cu128 -e ".[flash-attn]"
|
| 128 |
+
```
|
| 129 |
+
|
| 130 |
+
Notes:
|
| 131 |
+
- Dependencies are managed in `pyproject.toml`, which currently pins `torch==2.9.1+cu128` and `torchaudio==2.9.1+cu128`.
|
| 132 |
+
- If FlashAttention 2 fails to build on your machine, you can skip it and use the default attention backend.
|
| 133 |
+
- FlashAttention 2 is only available on supported GPUs and is typically used with `torch.float16` or `torch.bfloat16`.
|
| 134 |
+
|
| 135 |
+
### Basic Usage
|
| 136 |
+
|
| 137 |
+
> Tip: MOSS-TTS-Local-Transformer-v1.5 uses a fixed 12-codebook RVQ depth. Do not set `n_vq_for_inference` to a value different from `config.n_vq`.
|
| 138 |
+
|
| 139 |
+
MOSS-TTS-Local-Transformer-v1.5 provides the standard Hugging Face `AutoProcessor` and `AutoModel` interface. The examples below cover:
|
| 140 |
+
1. Direct generation with language tags
|
| 141 |
+
2. Voice cloning
|
| 142 |
+
3. Duration control
|
| 143 |
+
4. Explicit pause control with `[pause X.Ys]`
|
| 144 |
+
|
| 145 |
+
```python
|
| 146 |
+
from pathlib import Path
|
| 147 |
+
from tqdm import tqdm
|
| 148 |
+
import importlib.util
|
| 149 |
+
|
| 150 |
+
import torch
|
| 151 |
+
import torchaudio
|
| 152 |
+
from transformers import AutoModel, AutoProcessor
|
| 153 |
+
|
| 154 |
+
# Disable the broken cuDNN SDPA backend on some CUDA/PyTorch combinations.
|
| 155 |
+
torch.backends.cuda.enable_cudnn_sdp(False)
|
| 156 |
+
# Keep these enabled as fallbacks.
|
| 157 |
+
torch.backends.cuda.enable_flash_sdp(True)
|
| 158 |
+
torch.backends.cuda.enable_mem_efficient_sdp(True)
|
| 159 |
+
torch.backends.cuda.enable_math_sdp(True)
|
| 160 |
+
|
| 161 |
+
pretrained_model_name_or_path = "OpenMOSS-Team/MOSS-TTS-Local-Transformer-v1.5"
|
| 162 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 163 |
+
dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def resolve_attn_implementation() -> str:
|
| 167 |
+
# Prefer FlashAttention 2 when package + device conditions are met.
|
| 168 |
+
if (
|
| 169 |
+
device == "cuda"
|
| 170 |
+
and importlib.util.find_spec("flash_attn") is not None
|
| 171 |
+
and dtype in {torch.float16, torch.bfloat16}
|
| 172 |
+
):
|
| 173 |
+
major, _ = torch.cuda.get_device_capability()
|
| 174 |
+
if major >= 8:
|
| 175 |
+
return "flash_attention_2"
|
| 176 |
+
# CUDA fallback: use PyTorch SDPA kernels.
|
| 177 |
+
if device == "cuda":
|
| 178 |
+
return "sdpa"
|
| 179 |
+
# CPU fallback.
|
| 180 |
+
return "eager"
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
attn_implementation = resolve_attn_implementation()
|
| 184 |
+
print(f"[INFO] Using attn_implementation={attn_implementation}")
|
| 185 |
+
|
| 186 |
+
processor = AutoProcessor.from_pretrained(
|
| 187 |
+
pretrained_model_name_or_path,
|
| 188 |
+
trust_remote_code=True,
|
| 189 |
+
)
|
| 190 |
+
processor.audio_tokenizer = processor.audio_tokenizer.to(device)
|
| 191 |
+
|
| 192 |
+
text_zh = "亲爱的你,愿你的每一天都值得被记住,也值得被珍惜。"
|
| 193 |
+
text_en = "We stand on the threshold of the AI era, where intelligence becomes an extension of human creativity."
|
| 194 |
+
text_fr = "Bonjour, je voudrais essayer une voix francaise naturelle et stable."
|
| 195 |
+
text_pause = "我今天学习了一首中国的古诗,它的名字是[pause 3.2s]静夜思!"
|
| 196 |
+
|
| 197 |
+
# Use remote demo audio to avoid requiring local assets.
|
| 198 |
+
ref_audio_zh = "https://speech-demo.oss-cn-shanghai.aliyuncs.com/moss_tts_demo/tts_readme_demo/reference_zh.wav"
|
| 199 |
+
ref_audio_en = "https://speech-demo.oss-cn-shanghai.aliyuncs.com/moss_tts_demo/tts_readme_demo/reference_en.m4a"
|
| 200 |
+
|
| 201 |
+
conversations = [
|
| 202 |
+
# Direct TTS. Language tags are recommended in v1.5 when the language is known.
|
| 203 |
+
[processor.build_user_message(text=text_zh, language="Chinese")],
|
| 204 |
+
[processor.build_user_message(text=text_en, language="English")],
|
| 205 |
+
[processor.build_user_message(text=text_fr, language="French")],
|
| 206 |
+
# Explicit pause control. Use [pause X.Ys], such as [pause 3.2s].
|
| 207 |
+
[processor.build_user_message(text=text_pause, language="Chinese")],
|
| 208 |
+
# Voice cloning with a reference audio.
|
| 209 |
+
[processor.build_user_message(text=text_zh, reference=[ref_audio_zh], language="Chinese")],
|
| 210 |
+
[processor.build_user_message(text=text_en, reference=[ref_audio_en], language="English")],
|
| 211 |
+
# Duration control. At 12.5 frames per second, 125 frames is about 10 seconds.
|
| 212 |
+
[processor.build_user_message(text=text_en, tokens=125, language="English")],
|
| 213 |
+
]
|
| 214 |
+
|
| 215 |
+
model = AutoModel.from_pretrained(
|
| 216 |
+
pretrained_model_name_or_path,
|
| 217 |
+
trust_remote_code=True,
|
| 218 |
+
attn_implementation=attn_implementation,
|
| 219 |
+
torch_dtype=dtype,
|
| 220 |
+
).to(device)
|
| 221 |
+
model.eval()
|
| 222 |
+
|
| 223 |
+
batch_size = 1
|
| 224 |
+
save_dir = Path("inference_root_moss_tts_local_v1_5")
|
| 225 |
+
save_dir.mkdir(exist_ok=True, parents=True)
|
| 226 |
+
sample_idx = 0
|
| 227 |
+
|
| 228 |
+
with torch.no_grad():
|
| 229 |
+
for start in tqdm(range(0, len(conversations), batch_size)):
|
| 230 |
+
batch_conversations = conversations[start : start + batch_size]
|
| 231 |
+
batch = processor(batch_conversations, mode="generation")
|
| 232 |
+
input_ids = batch["input_ids"].to(device)
|
| 233 |
+
attention_mask = batch["attention_mask"].to(device)
|
| 234 |
+
|
| 235 |
+
outputs = model.generate(
|
| 236 |
+
input_ids=input_ids,
|
| 237 |
+
attention_mask=attention_mask,
|
| 238 |
+
max_new_tokens=4096,
|
| 239 |
+
do_sample=True,
|
| 240 |
+
audio_temperature=1.7,
|
| 241 |
+
audio_top_p=0.8,
|
| 242 |
+
audio_top_k=25,
|
| 243 |
+
audio_repetition_penalty=1.0,
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
for message in processor.decode(outputs):
|
| 247 |
+
if message is None:
|
| 248 |
+
continue
|
| 249 |
+
audio = message.audio_codes_list[0]
|
| 250 |
+
out_path = save_dir / f"sample{sample_idx}.wav"
|
| 251 |
+
sample_idx += 1
|
| 252 |
+
# MOSS-TTS Local v1.5 codec returns stereo audio as [channels, samples].
|
| 253 |
+
# Save the two-channel tensor directly.
|
| 254 |
+
torchaudio.save(str(out_path), audio, processor.model_config.sampling_rate)
|
| 255 |
+
```
|
| 256 |
+
|
| 257 |
+
## Generation Parameters
|
| 258 |
+
|
| 259 |
+
| Parameter | Recommended | Description |
|
| 260 |
+
|---|---:|---|
|
| 261 |
+
| `audio_temperature` | `1.7` | Sampling temperature for audio RVQ layers. |
|
| 262 |
+
| `audio_top_p` | `0.8` | Nucleus sampling cutoff for audio RVQ layers. |
|
| 263 |
+
| `audio_top_k` | `25` | Top-k sampling cutoff for audio RVQ layers. |
|
| 264 |
+
| `audio_repetition_penalty` | `1.0` | Penalty for repeated acoustic token patterns. |
|
| 265 |
+
| `n_vq_for_inference` | `12` | Fixed by this release. Values other than `config.n_vq` are rejected. |
|
| 266 |
+
|
| 267 |
+
## Notes
|
| 268 |
+
|
| 269 |
+
- This repository uses Hugging Face remote code. Load it with `trust_remote_code=True`.
|
| 270 |
+
- The MOSS-TTS-Local-Transformer-v1.5 codec is stereo. `processor.decode(...)` returns audio tensors shaped as `[channels, samples]`, so save them directly with `torchaudio.save(path, audio, sampling_rate)`.
|
| 271 |
+
- Audio encoding and decoding use `OpenMOSS-Team/MOSS-Audio-Tokenizer-v2`.
|
| 272 |
+
- The model configuration sets `sampling_rate` to 48000 and `n_vq` to 12.
|
| 273 |
+
- If FlashAttention 2 is unavailable, the example falls back to SDPA on CUDA and eager attention on CPU.
|
| 274 |
+
|
| 275 |
+
## More Usage
|
| 276 |
+
|
| 277 |
+
MOSS-TTS-Local-Transformer-v1.5 is API-compatible with MOSS-TTS-Local-Transformer-v1.0. For continuation with prefix audio, detailed `UserMessage` and `AssistantMessage` fields, generation hyperparameters, Pinyin/IPA preprocessing examples, and evaluation results, see the [MOSS-TTS-Local-Transformer-v1.0](https://huggingface.co/OpenMOSS-Team/MOSS-TTS-Local-Transformer).
|
| 278 |
+
|
| 279 |
+
## Citation
|
| 280 |
+
|
| 281 |
+
If you use this model, please cite the [MOSS-TTS Technical Report](https://arxiv.org/abs/2603.18090).
|
__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .configuration_moss_tts import MossTTSLocalConfig
|
| 2 |
+
from .modeling_moss_tts import MossTTSLocalModel
|
| 3 |
+
from .processing_moss_tts import MossTTSLocalProcessor
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
"MossTTSLocalConfig",
|
| 7 |
+
"MossTTSLocalModel",
|
| 8 |
+
"MossTTSLocalProcessor",
|
| 9 |
+
]
|
added_tokens.json
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"</think>": 151668,
|
| 3 |
+
"</tool_call>": 151658,
|
| 4 |
+
"</tool_response>": 151666,
|
| 5 |
+
"<think>": 151667,
|
| 6 |
+
"<tool_call>": 151657,
|
| 7 |
+
"<tool_response>": 151665,
|
| 8 |
+
"<|audio_end|>": 151670,
|
| 9 |
+
"<|audio_pad|>": 151671,
|
| 10 |
+
"<|audio_start|>": 151669,
|
| 11 |
+
"<|box_end|>": 151649,
|
| 12 |
+
"<|box_start|>": 151648,
|
| 13 |
+
"<|endoftext|>": 151643,
|
| 14 |
+
"<|file_sep|>": 151664,
|
| 15 |
+
"<|fim_middle|>": 151660,
|
| 16 |
+
"<|fim_pad|>": 151662,
|
| 17 |
+
"<|fim_prefix|>": 151659,
|
| 18 |
+
"<|fim_suffix|>": 151661,
|
| 19 |
+
"<|im_end|>": 151645,
|
| 20 |
+
"<|im_start|>": 151644,
|
| 21 |
+
"<|image_pad|>": 151655,
|
| 22 |
+
"<|object_ref_end|>": 151647,
|
| 23 |
+
"<|object_ref_start|>": 151646,
|
| 24 |
+
"<|quad_end|>": 151651,
|
| 25 |
+
"<|quad_start|>": 151650,
|
| 26 |
+
"<|repo_name|>": 151663,
|
| 27 |
+
"<|video_pad|>": 151656,
|
| 28 |
+
"<|vision_end|>": 151653,
|
| 29 |
+
"<|vision_pad|>": 151654,
|
| 30 |
+
"<|vision_start|>": 151652
|
| 31 |
+
}
|
chat_template.jinja
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{%- if tools %}
|
| 2 |
+
{{- '<|im_start|>system\n' }}
|
| 3 |
+
{%- if messages[0].role == 'system' %}
|
| 4 |
+
{{- messages[0].content + '\n\n' }}
|
| 5 |
+
{%- endif %}
|
| 6 |
+
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
|
| 7 |
+
{%- for tool in tools %}
|
| 8 |
+
{{- "\n" }}
|
| 9 |
+
{{- tool | tojson }}
|
| 10 |
+
{%- endfor %}
|
| 11 |
+
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
|
| 12 |
+
{%- else %}
|
| 13 |
+
{%- if messages[0].role == 'system' %}
|
| 14 |
+
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
|
| 15 |
+
{%- endif %}
|
| 16 |
+
{%- endif %}
|
| 17 |
+
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
|
| 18 |
+
{%- for message in messages[::-1] %}
|
| 19 |
+
{%- set index = (messages|length - 1) - loop.index0 %}
|
| 20 |
+
{%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
|
| 21 |
+
{%- set ns.multi_step_tool = false %}
|
| 22 |
+
{%- set ns.last_query_index = index %}
|
| 23 |
+
{%- endif %}
|
| 24 |
+
{%- endfor %}
|
| 25 |
+
{%- for message in messages %}
|
| 26 |
+
{%- if message.content is string %}
|
| 27 |
+
{%- set content = message.content %}
|
| 28 |
+
{%- else %}
|
| 29 |
+
{%- set content = '' %}
|
| 30 |
+
{%- endif %}
|
| 31 |
+
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
|
| 32 |
+
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
|
| 33 |
+
{%- elif message.role == "assistant" %}
|
| 34 |
+
{%- set reasoning_content = '' %}
|
| 35 |
+
{%- if message.reasoning_content is string %}
|
| 36 |
+
{%- set reasoning_content = message.reasoning_content %}
|
| 37 |
+
{%- else %}
|
| 38 |
+
{%- if '</think>' in content %}
|
| 39 |
+
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
|
| 40 |
+
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
|
| 41 |
+
{%- endif %}
|
| 42 |
+
{%- endif %}
|
| 43 |
+
{%- if loop.index0 > ns.last_query_index %}
|
| 44 |
+
{%- if loop.last or (not loop.last and reasoning_content) %}
|
| 45 |
+
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
|
| 46 |
+
{%- else %}
|
| 47 |
+
{{- '<|im_start|>' + message.role + '\n' + content }}
|
| 48 |
+
{%- endif %}
|
| 49 |
+
{%- else %}
|
| 50 |
+
{{- '<|im_start|>' + message.role + '\n' + content }}
|
| 51 |
+
{%- endif %}
|
| 52 |
+
{%- if message.tool_calls %}
|
| 53 |
+
{%- for tool_call in message.tool_calls %}
|
| 54 |
+
{%- if (loop.first and content) or (not loop.first) %}
|
| 55 |
+
{{- '\n' }}
|
| 56 |
+
{%- endif %}
|
| 57 |
+
{%- if tool_call.function %}
|
| 58 |
+
{%- set tool_call = tool_call.function %}
|
| 59 |
+
{%- endif %}
|
| 60 |
+
{{- '<tool_call>\n{"name": "' }}
|
| 61 |
+
{{- tool_call.name }}
|
| 62 |
+
{{- '", "arguments": ' }}
|
| 63 |
+
{%- if tool_call.arguments is string %}
|
| 64 |
+
{{- tool_call.arguments }}
|
| 65 |
+
{%- else %}
|
| 66 |
+
{{- tool_call.arguments | tojson }}
|
| 67 |
+
{%- endif %}
|
| 68 |
+
{{- '}\n</tool_call>' }}
|
| 69 |
+
{%- endfor %}
|
| 70 |
+
{%- endif %}
|
| 71 |
+
{{- '<|im_end|>\n' }}
|
| 72 |
+
{%- elif message.role == "tool" %}
|
| 73 |
+
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
|
| 74 |
+
{{- '<|im_start|>user' }}
|
| 75 |
+
{%- endif %}
|
| 76 |
+
{{- '\n<tool_response>\n' }}
|
| 77 |
+
{{- content }}
|
| 78 |
+
{{- '\n</tool_response>' }}
|
| 79 |
+
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
|
| 80 |
+
{{- '<|im_end|>\n' }}
|
| 81 |
+
{%- endif %}
|
| 82 |
+
{%- endif %}
|
| 83 |
+
{%- endfor %}
|
| 84 |
+
{%- if add_generation_prompt %}
|
| 85 |
+
{{- '<|im_start|>assistant\n' }}
|
| 86 |
+
{%- if enable_thinking is defined and enable_thinking is false %}
|
| 87 |
+
{{- '<think>\n\n</think>\n\n' }}
|
| 88 |
+
{%- endif %}
|
| 89 |
+
{%- endif %}
|
config.json
ADDED
|
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "moss_tts_local",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"MossTTSLocalModel"
|
| 5 |
+
],
|
| 6 |
+
"auto_map": {
|
| 7 |
+
"AutoConfig": "configuration_moss_tts.MossTTSLocalConfig",
|
| 8 |
+
"AutoModel": "modeling_moss_tts.MossTTSLocalModel",
|
| 9 |
+
"AutoProcessor": "processing_moss_tts.MossTTSLocalProcessor"
|
| 10 |
+
},
|
| 11 |
+
"processor_class": "MossTTSLocalProcessor",
|
| 12 |
+
"qwen3_config": {
|
| 13 |
+
"_name_or_path": "",
|
| 14 |
+
"add_cross_attention": false,
|
| 15 |
+
"architectures": [
|
| 16 |
+
"Qwen3ForCausalLM"
|
| 17 |
+
],
|
| 18 |
+
"attention_bias": false,
|
| 19 |
+
"attention_dropout": 0.0,
|
| 20 |
+
"bad_words_ids": null,
|
| 21 |
+
"begin_suppress_tokens": null,
|
| 22 |
+
"bos_token_id": 151643,
|
| 23 |
+
"chunk_size_feed_forward": 0,
|
| 24 |
+
"cross_attention_hidden_size": null,
|
| 25 |
+
"decoder_start_token_id": null,
|
| 26 |
+
"diversity_penalty": 0.0,
|
| 27 |
+
"do_sample": false,
|
| 28 |
+
"dtype": "bfloat16",
|
| 29 |
+
"early_stopping": false,
|
| 30 |
+
"encoder_no_repeat_ngram_size": 0,
|
| 31 |
+
"eos_token_id": 151643,
|
| 32 |
+
"exponential_decay_length_penalty": null,
|
| 33 |
+
"finetuning_task": null,
|
| 34 |
+
"forced_bos_token_id": null,
|
| 35 |
+
"forced_eos_token_id": null,
|
| 36 |
+
"gradient_checkpointing_use_reentrant": false,
|
| 37 |
+
"head_dim": 128,
|
| 38 |
+
"hidden_act": "silu",
|
| 39 |
+
"hidden_size": 2560,
|
| 40 |
+
"id2label": {
|
| 41 |
+
"0": "LABEL_0",
|
| 42 |
+
"1": "LABEL_1"
|
| 43 |
+
},
|
| 44 |
+
"initializer_range": 0.02,
|
| 45 |
+
"intermediate_size": 9728,
|
| 46 |
+
"is_decoder": false,
|
| 47 |
+
"is_encoder_decoder": false,
|
| 48 |
+
"label2id": {
|
| 49 |
+
"LABEL_0": 0,
|
| 50 |
+
"LABEL_1": 1
|
| 51 |
+
},
|
| 52 |
+
"layer_types": [
|
| 53 |
+
"full_attention",
|
| 54 |
+
"full_attention",
|
| 55 |
+
"full_attention",
|
| 56 |
+
"full_attention",
|
| 57 |
+
"full_attention",
|
| 58 |
+
"full_attention",
|
| 59 |
+
"full_attention",
|
| 60 |
+
"full_attention",
|
| 61 |
+
"full_attention",
|
| 62 |
+
"full_attention",
|
| 63 |
+
"full_attention",
|
| 64 |
+
"full_attention",
|
| 65 |
+
"full_attention",
|
| 66 |
+
"full_attention",
|
| 67 |
+
"full_attention",
|
| 68 |
+
"full_attention",
|
| 69 |
+
"full_attention",
|
| 70 |
+
"full_attention",
|
| 71 |
+
"full_attention",
|
| 72 |
+
"full_attention",
|
| 73 |
+
"full_attention",
|
| 74 |
+
"full_attention",
|
| 75 |
+
"full_attention",
|
| 76 |
+
"full_attention",
|
| 77 |
+
"full_attention",
|
| 78 |
+
"full_attention",
|
| 79 |
+
"full_attention",
|
| 80 |
+
"full_attention",
|
| 81 |
+
"full_attention",
|
| 82 |
+
"full_attention",
|
| 83 |
+
"full_attention",
|
| 84 |
+
"full_attention",
|
| 85 |
+
"full_attention",
|
| 86 |
+
"full_attention",
|
| 87 |
+
"full_attention",
|
| 88 |
+
"full_attention"
|
| 89 |
+
],
|
| 90 |
+
"length_penalty": 1.0,
|
| 91 |
+
"max_length": 20,
|
| 92 |
+
"max_position_embeddings": 32768,
|
| 93 |
+
"max_window_layers": 36,
|
| 94 |
+
"min_length": 0,
|
| 95 |
+
"model_type": "qwen3",
|
| 96 |
+
"no_repeat_ngram_size": 0,
|
| 97 |
+
"num_attention_heads": 32,
|
| 98 |
+
"num_beam_groups": 1,
|
| 99 |
+
"num_beams": 1,
|
| 100 |
+
"num_hidden_layers": 36,
|
| 101 |
+
"num_key_value_heads": 8,
|
| 102 |
+
"num_return_sequences": 1,
|
| 103 |
+
"output_attentions": false,
|
| 104 |
+
"output_hidden_states": false,
|
| 105 |
+
"output_scores": false,
|
| 106 |
+
"pad_token_id": 151643,
|
| 107 |
+
"prefix": null,
|
| 108 |
+
"problem_type": null,
|
| 109 |
+
"pruned_heads": {},
|
| 110 |
+
"remove_invalid_values": false,
|
| 111 |
+
"repetition_penalty": 1.0,
|
| 112 |
+
"return_dict": true,
|
| 113 |
+
"return_dict_in_generate": false,
|
| 114 |
+
"rms_norm_eps": 1e-06,
|
| 115 |
+
"rope_scaling": null,
|
| 116 |
+
"rope_theta": 1000000,
|
| 117 |
+
"sep_token_id": null,
|
| 118 |
+
"sliding_window": null,
|
| 119 |
+
"suppress_tokens": null,
|
| 120 |
+
"task_specific_params": null,
|
| 121 |
+
"temperature": 1.0,
|
| 122 |
+
"tf_legacy_loss": false,
|
| 123 |
+
"tie_encoder_decoder": false,
|
| 124 |
+
"tie_word_embeddings": true,
|
| 125 |
+
"tokenizer_class": null,
|
| 126 |
+
"top_k": 50,
|
| 127 |
+
"top_p": 1.0,
|
| 128 |
+
"torchscript": false,
|
| 129 |
+
"transformers_version": "4.57.1",
|
| 130 |
+
"typical_p": 1.0,
|
| 131 |
+
"use_bfloat16": false,
|
| 132 |
+
"use_cache": false,
|
| 133 |
+
"use_sliding_window": false,
|
| 134 |
+
"vocab_size": 151936
|
| 135 |
+
},
|
| 136 |
+
"language_config": {
|
| 137 |
+
"_name_or_path": "",
|
| 138 |
+
"add_cross_attention": false,
|
| 139 |
+
"architectures": [
|
| 140 |
+
"Qwen3ForCausalLM"
|
| 141 |
+
],
|
| 142 |
+
"attention_bias": false,
|
| 143 |
+
"attention_dropout": 0.0,
|
| 144 |
+
"bad_words_ids": null,
|
| 145 |
+
"begin_suppress_tokens": null,
|
| 146 |
+
"bos_token_id": 151643,
|
| 147 |
+
"chunk_size_feed_forward": 0,
|
| 148 |
+
"cross_attention_hidden_size": null,
|
| 149 |
+
"decoder_start_token_id": null,
|
| 150 |
+
"diversity_penalty": 0.0,
|
| 151 |
+
"do_sample": false,
|
| 152 |
+
"dtype": "bfloat16",
|
| 153 |
+
"early_stopping": false,
|
| 154 |
+
"encoder_no_repeat_ngram_size": 0,
|
| 155 |
+
"eos_token_id": 151643,
|
| 156 |
+
"exponential_decay_length_penalty": null,
|
| 157 |
+
"finetuning_task": null,
|
| 158 |
+
"forced_bos_token_id": null,
|
| 159 |
+
"forced_eos_token_id": null,
|
| 160 |
+
"gradient_checkpointing_use_reentrant": false,
|
| 161 |
+
"head_dim": 128,
|
| 162 |
+
"hidden_act": "silu",
|
| 163 |
+
"hidden_size": 2560,
|
| 164 |
+
"id2label": {
|
| 165 |
+
"0": "LABEL_0",
|
| 166 |
+
"1": "LABEL_1"
|
| 167 |
+
},
|
| 168 |
+
"initializer_range": 0.02,
|
| 169 |
+
"intermediate_size": 9728,
|
| 170 |
+
"is_decoder": false,
|
| 171 |
+
"is_encoder_decoder": false,
|
| 172 |
+
"label2id": {
|
| 173 |
+
"LABEL_0": 0,
|
| 174 |
+
"LABEL_1": 1
|
| 175 |
+
},
|
| 176 |
+
"layer_types": [
|
| 177 |
+
"full_attention",
|
| 178 |
+
"full_attention",
|
| 179 |
+
"full_attention",
|
| 180 |
+
"full_attention",
|
| 181 |
+
"full_attention",
|
| 182 |
+
"full_attention",
|
| 183 |
+
"full_attention",
|
| 184 |
+
"full_attention",
|
| 185 |
+
"full_attention",
|
| 186 |
+
"full_attention",
|
| 187 |
+
"full_attention",
|
| 188 |
+
"full_attention",
|
| 189 |
+
"full_attention",
|
| 190 |
+
"full_attention",
|
| 191 |
+
"full_attention",
|
| 192 |
+
"full_attention",
|
| 193 |
+
"full_attention",
|
| 194 |
+
"full_attention",
|
| 195 |
+
"full_attention",
|
| 196 |
+
"full_attention",
|
| 197 |
+
"full_attention",
|
| 198 |
+
"full_attention",
|
| 199 |
+
"full_attention",
|
| 200 |
+
"full_attention",
|
| 201 |
+
"full_attention",
|
| 202 |
+
"full_attention",
|
| 203 |
+
"full_attention",
|
| 204 |
+
"full_attention",
|
| 205 |
+
"full_attention",
|
| 206 |
+
"full_attention",
|
| 207 |
+
"full_attention",
|
| 208 |
+
"full_attention",
|
| 209 |
+
"full_attention",
|
| 210 |
+
"full_attention",
|
| 211 |
+
"full_attention",
|
| 212 |
+
"full_attention"
|
| 213 |
+
],
|
| 214 |
+
"length_penalty": 1.0,
|
| 215 |
+
"max_length": 20,
|
| 216 |
+
"max_position_embeddings": 32768,
|
| 217 |
+
"max_window_layers": 36,
|
| 218 |
+
"min_length": 0,
|
| 219 |
+
"model_type": "qwen3",
|
| 220 |
+
"no_repeat_ngram_size": 0,
|
| 221 |
+
"num_attention_heads": 32,
|
| 222 |
+
"num_beam_groups": 1,
|
| 223 |
+
"num_beams": 1,
|
| 224 |
+
"num_hidden_layers": 36,
|
| 225 |
+
"num_key_value_heads": 8,
|
| 226 |
+
"num_return_sequences": 1,
|
| 227 |
+
"output_attentions": false,
|
| 228 |
+
"output_hidden_states": false,
|
| 229 |
+
"output_scores": false,
|
| 230 |
+
"pad_token_id": 151643,
|
| 231 |
+
"prefix": null,
|
| 232 |
+
"problem_type": null,
|
| 233 |
+
"pruned_heads": {},
|
| 234 |
+
"remove_invalid_values": false,
|
| 235 |
+
"repetition_penalty": 1.0,
|
| 236 |
+
"return_dict": true,
|
| 237 |
+
"return_dict_in_generate": false,
|
| 238 |
+
"rms_norm_eps": 1e-06,
|
| 239 |
+
"rope_scaling": null,
|
| 240 |
+
"rope_theta": 1000000,
|
| 241 |
+
"sep_token_id": null,
|
| 242 |
+
"sliding_window": null,
|
| 243 |
+
"suppress_tokens": null,
|
| 244 |
+
"task_specific_params": null,
|
| 245 |
+
"temperature": 1.0,
|
| 246 |
+
"tf_legacy_loss": false,
|
| 247 |
+
"tie_encoder_decoder": false,
|
| 248 |
+
"tie_word_embeddings": true,
|
| 249 |
+
"tokenizer_class": null,
|
| 250 |
+
"top_k": 50,
|
| 251 |
+
"top_p": 1.0,
|
| 252 |
+
"torchscript": false,
|
| 253 |
+
"transformers_version": "4.57.1",
|
| 254 |
+
"typical_p": 1.0,
|
| 255 |
+
"use_bfloat16": false,
|
| 256 |
+
"use_cache": false,
|
| 257 |
+
"use_sliding_window": false,
|
| 258 |
+
"vocab_size": 151936
|
| 259 |
+
},
|
| 260 |
+
"gpt2_config": {
|
| 261 |
+
"_name_or_path": "",
|
| 262 |
+
"activation_function": "silu",
|
| 263 |
+
"add_cross_attention": false,
|
| 264 |
+
"architectures": null,
|
| 265 |
+
"attn_pdrop": 0.0,
|
| 266 |
+
"bad_words_ids": null,
|
| 267 |
+
"begin_suppress_tokens": null,
|
| 268 |
+
"bos_token_id": null,
|
| 269 |
+
"chunk_size_feed_forward": 0,
|
| 270 |
+
"cross_attention_hidden_size": null,
|
| 271 |
+
"decoder_start_token_id": null,
|
| 272 |
+
"diversity_penalty": 0.0,
|
| 273 |
+
"do_sample": false,
|
| 274 |
+
"dtype": null,
|
| 275 |
+
"early_stopping": false,
|
| 276 |
+
"embd_pdrop": 0.0,
|
| 277 |
+
"encoder_no_repeat_ngram_size": 0,
|
| 278 |
+
"eos_token_id": 151645,
|
| 279 |
+
"exponential_decay_length_penalty": null,
|
| 280 |
+
"finetuning_task": null,
|
| 281 |
+
"forced_bos_token_id": null,
|
| 282 |
+
"forced_eos_token_id": null,
|
| 283 |
+
"id2label": {
|
| 284 |
+
"0": "LABEL_0",
|
| 285 |
+
"1": "LABEL_1"
|
| 286 |
+
},
|
| 287 |
+
"initializer_range": 0.02,
|
| 288 |
+
"is_decoder": false,
|
| 289 |
+
"is_encoder_decoder": false,
|
| 290 |
+
"label2id": {
|
| 291 |
+
"LABEL_0": 0,
|
| 292 |
+
"LABEL_1": 1
|
| 293 |
+
},
|
| 294 |
+
"layer_norm_epsilon": 1e-06,
|
| 295 |
+
"length_penalty": 1.0,
|
| 296 |
+
"max_length": 20,
|
| 297 |
+
"min_length": 0,
|
| 298 |
+
"model_type": "gpt2",
|
| 299 |
+
"n_ctx": 10240,
|
| 300 |
+
"n_embd": 2560,
|
| 301 |
+
"n_head": 32,
|
| 302 |
+
"n_inner": 9728,
|
| 303 |
+
"n_layer": 1,
|
| 304 |
+
"n_positions": 10240,
|
| 305 |
+
"no_repeat_ngram_size": 0,
|
| 306 |
+
"num_beam_groups": 1,
|
| 307 |
+
"num_beams": 1,
|
| 308 |
+
"num_return_sequences": 1,
|
| 309 |
+
"output_attentions": false,
|
| 310 |
+
"output_hidden_states": false,
|
| 311 |
+
"output_scores": false,
|
| 312 |
+
"pad_token_id": null,
|
| 313 |
+
"position_embedding_type": "rope",
|
| 314 |
+
"prefix": null,
|
| 315 |
+
"problem_type": null,
|
| 316 |
+
"pruned_heads": {},
|
| 317 |
+
"remove_invalid_values": false,
|
| 318 |
+
"reorder_and_upcast_attn": false,
|
| 319 |
+
"repetition_penalty": 1.0,
|
| 320 |
+
"resid_pdrop": 0.0,
|
| 321 |
+
"return_dict": true,
|
| 322 |
+
"return_dict_in_generate": false,
|
| 323 |
+
"rope_base": 1000000.0,
|
| 324 |
+
"scale_attn_by_inverse_layer_idx": false,
|
| 325 |
+
"scale_attn_weights": true,
|
| 326 |
+
"sep_token_id": null,
|
| 327 |
+
"summary_activation": null,
|
| 328 |
+
"summary_first_dropout": 0.1,
|
| 329 |
+
"summary_proj_to_labels": true,
|
| 330 |
+
"summary_type": "cls_index",
|
| 331 |
+
"summary_use_proj": true,
|
| 332 |
+
"suppress_tokens": null,
|
| 333 |
+
"task_specific_params": null,
|
| 334 |
+
"temperature": 1.0,
|
| 335 |
+
"tf_legacy_loss": false,
|
| 336 |
+
"tie_encoder_decoder": false,
|
| 337 |
+
"tie_word_embeddings": true,
|
| 338 |
+
"tokenizer_class": null,
|
| 339 |
+
"top_k": 50,
|
| 340 |
+
"top_p": 1.0,
|
| 341 |
+
"torchscript": false,
|
| 342 |
+
"transformers_version": "4.57.1",
|
| 343 |
+
"typical_p": 1.0,
|
| 344 |
+
"use_bfloat16": false,
|
| 345 |
+
"use_cache": true,
|
| 346 |
+
"vocab_size": 151936
|
| 347 |
+
},
|
| 348 |
+
"n_vq": 12,
|
| 349 |
+
"audio_vocab_size": 1024,
|
| 350 |
+
"audio_codebook_sizes": [
|
| 351 |
+
1024,
|
| 352 |
+
1024,
|
| 353 |
+
1024,
|
| 354 |
+
1024,
|
| 355 |
+
1024,
|
| 356 |
+
1024,
|
| 357 |
+
1024,
|
| 358 |
+
1024,
|
| 359 |
+
1024,
|
| 360 |
+
1024,
|
| 361 |
+
1024,
|
| 362 |
+
1024
|
| 363 |
+
],
|
| 364 |
+
"audio_pad_token_id": 1024,
|
| 365 |
+
"audio_pad_code": 1024,
|
| 366 |
+
"pad_token_id": 151643,
|
| 367 |
+
"im_start_token_id": 151644,
|
| 368 |
+
"im_end_token_id": 151645,
|
| 369 |
+
"audio_start_token_id": 151669,
|
| 370 |
+
"audio_end_token_id": 151670,
|
| 371 |
+
"audio_user_slot_token_id": 151654,
|
| 372 |
+
"audio_assistant_slot_token_id": 151656,
|
| 373 |
+
"audio_assistant_gen_slot_token_id": 151656,
|
| 374 |
+
"sampling_rate": 48000,
|
| 375 |
+
"audio_tokenizer_name_or_path": "OpenMOSS-Team/MOSS-Audio-Tokenizer-v2",
|
| 376 |
+
"attn_implementation": "flash_attention_2",
|
| 377 |
+
"local_transformer_layers": 1,
|
| 378 |
+
"local_text_head_mode": "binary",
|
| 379 |
+
"use_static_local_kv_cache": true,
|
| 380 |
+
"initializer_range": 0.02
|
| 381 |
+
}
|
configuration_moss_tts.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
"""Configuration for the MOSS-TTS-Local-Transformer-v1.5 release."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
from typing import Any, Dict, Optional, Union
|
| 7 |
+
|
| 8 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 9 |
+
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
| 10 |
+
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
SUPPORTED_ATTENTION_IMPLEMENTATIONS = {"flash_attention_2", "sdpa", "eager"}
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _normalize_attention_implementation(value: Optional[str], default: str = "flash_attention_2") -> str:
|
| 17 |
+
normalized = str(value or default).strip().lower()
|
| 18 |
+
if normalized in {"flash", "flash_attn", "flash-attn", "flash_attention"}:
|
| 19 |
+
normalized = "flash_attention_2"
|
| 20 |
+
if normalized not in SUPPORTED_ATTENTION_IMPLEMENTATIONS:
|
| 21 |
+
raise ValueError(
|
| 22 |
+
"attn_implementation must be one of "
|
| 23 |
+
f"{sorted(SUPPORTED_ATTENTION_IMPLEMENTATIONS)}, got {value!r}."
|
| 24 |
+
)
|
| 25 |
+
return normalized
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class MossTTSLocalConfig(PretrainedConfig):
|
| 29 |
+
model_type = "moss_tts_local"
|
| 30 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 31 |
+
|
| 32 |
+
def __init__(
|
| 33 |
+
self,
|
| 34 |
+
qwen3_config: Optional[Union[Qwen3Config, Dict[str, Any]]] = None,
|
| 35 |
+
gpt2_config: Optional[Union[GPT2Config, Dict[str, Any]]] = None,
|
| 36 |
+
language_config: Optional[Union[Qwen3Config, Dict[str, Any]]] = None,
|
| 37 |
+
n_vq: int = 12,
|
| 38 |
+
audio_vocab_size: int = 1024,
|
| 39 |
+
audio_codebook_sizes: Optional[list[int]] = None,
|
| 40 |
+
audio_pad_token_id: int = 1024,
|
| 41 |
+
audio_pad_code: Optional[int] = None,
|
| 42 |
+
pad_token_id: int = 151643,
|
| 43 |
+
im_start_token_id: int = 151644,
|
| 44 |
+
im_end_token_id: int = 151645,
|
| 45 |
+
audio_start_token_id: int = 151669,
|
| 46 |
+
audio_end_token_id: int = 151670,
|
| 47 |
+
audio_user_slot_token_id: int = 151654,
|
| 48 |
+
audio_assistant_slot_token_id: int = 151656,
|
| 49 |
+
audio_assistant_gen_slot_token_id: Optional[int] = None,
|
| 50 |
+
sampling_rate: int = 48000,
|
| 51 |
+
audio_tokenizer_name_or_path: Optional[str] = None,
|
| 52 |
+
attn_implementation: str = "flash_attention_2",
|
| 53 |
+
local_transformer_attn_implementation: Optional[str] = None,
|
| 54 |
+
local_text_head_mode: str = "binary",
|
| 55 |
+
initializer_range: float = 0.02,
|
| 56 |
+
**kwargs: Any,
|
| 57 |
+
) -> None:
|
| 58 |
+
if qwen3_config is None and language_config is not None:
|
| 59 |
+
qwen3_config = language_config
|
| 60 |
+
if isinstance(qwen3_config, dict):
|
| 61 |
+
self.qwen3_config = Qwen3Config(**qwen3_config)
|
| 62 |
+
elif qwen3_config is None:
|
| 63 |
+
self.qwen3_config = Qwen3Config()
|
| 64 |
+
else:
|
| 65 |
+
self.qwen3_config = qwen3_config
|
| 66 |
+
|
| 67 |
+
if isinstance(gpt2_config, dict):
|
| 68 |
+
self.gpt2_config = GPT2Config(**gpt2_config)
|
| 69 |
+
elif gpt2_config is None:
|
| 70 |
+
self.gpt2_config = GPT2Config(
|
| 71 |
+
vocab_size=int(self.qwen3_config.vocab_size),
|
| 72 |
+
n_embd=int(self.qwen3_config.hidden_size),
|
| 73 |
+
n_layer=1,
|
| 74 |
+
n_head=max(1, int(self.qwen3_config.hidden_size) // 80),
|
| 75 |
+
n_positions=int(n_vq) + 1,
|
| 76 |
+
n_ctx=int(n_vq) + 1,
|
| 77 |
+
activation_function="silu",
|
| 78 |
+
layer_norm_epsilon=1e-6,
|
| 79 |
+
resid_pdrop=0.0,
|
| 80 |
+
embd_pdrop=0.0,
|
| 81 |
+
attn_pdrop=0.0,
|
| 82 |
+
)
|
| 83 |
+
else:
|
| 84 |
+
self.gpt2_config = gpt2_config
|
| 85 |
+
|
| 86 |
+
self.n_vq = int(n_vq)
|
| 87 |
+
if self.n_vq <= 0:
|
| 88 |
+
raise ValueError("n_vq must be positive.")
|
| 89 |
+
if audio_codebook_sizes is None:
|
| 90 |
+
self.audio_codebook_sizes = [int(audio_vocab_size)] * self.n_vq
|
| 91 |
+
else:
|
| 92 |
+
self.audio_codebook_sizes = [int(size) for size in audio_codebook_sizes]
|
| 93 |
+
if len(self.audio_codebook_sizes) != self.n_vq:
|
| 94 |
+
raise ValueError(
|
| 95 |
+
f"audio_codebook_sizes must have length n_vq={self.n_vq}, "
|
| 96 |
+
f"got {len(self.audio_codebook_sizes)}."
|
| 97 |
+
)
|
| 98 |
+
if any(size <= 0 for size in self.audio_codebook_sizes):
|
| 99 |
+
raise ValueError("audio_codebook_sizes must contain positive integers.")
|
| 100 |
+
self.audio_vocab_size = int(max(int(audio_vocab_size), max(self.audio_codebook_sizes)))
|
| 101 |
+
self.audio_pad_token_id = int(audio_pad_code if audio_pad_code is not None else audio_pad_token_id)
|
| 102 |
+
self.audio_pad_code = self.audio_pad_token_id
|
| 103 |
+
if self.audio_pad_token_id < self.audio_vocab_size:
|
| 104 |
+
raise ValueError("audio_pad_token_id/audio_pad_code must be outside the audio vocab.")
|
| 105 |
+
|
| 106 |
+
self.pad_token_id = int(pad_token_id)
|
| 107 |
+
self.im_start_token_id = int(im_start_token_id)
|
| 108 |
+
self.im_end_token_id = int(im_end_token_id)
|
| 109 |
+
self.audio_start_token_id = int(audio_start_token_id)
|
| 110 |
+
self.audio_end_token_id = int(audio_end_token_id)
|
| 111 |
+
self.audio_user_slot_token_id = int(audio_user_slot_token_id)
|
| 112 |
+
self.audio_assistant_slot_token_id = int(
|
| 113 |
+
audio_assistant_slot_token_id
|
| 114 |
+
if audio_assistant_gen_slot_token_id is None
|
| 115 |
+
else audio_assistant_gen_slot_token_id
|
| 116 |
+
)
|
| 117 |
+
self.audio_assistant_gen_slot_token_id = self.audio_assistant_slot_token_id
|
| 118 |
+
|
| 119 |
+
self.sampling_rate = int(sampling_rate)
|
| 120 |
+
self.audio_tokenizer_name_or_path = audio_tokenizer_name_or_path
|
| 121 |
+
self.attn_implementation = _normalize_attention_implementation(attn_implementation)
|
| 122 |
+
self.local_transformer_attn_implementation = _normalize_attention_implementation(
|
| 123 |
+
local_transformer_attn_implementation,
|
| 124 |
+
default=self.attn_implementation,
|
| 125 |
+
)
|
| 126 |
+
self.initializer_range = float(initializer_range)
|
| 127 |
+
|
| 128 |
+
self.hidden_size = int(self.qwen3_config.hidden_size)
|
| 129 |
+
self.vocab_size = int(self.qwen3_config.vocab_size)
|
| 130 |
+
self.local_hidden_size = int(self.gpt2_config.hidden_size)
|
| 131 |
+
if self.local_hidden_size != self.hidden_size:
|
| 132 |
+
raise ValueError(
|
| 133 |
+
"This MOSS-TTS-Local-Transformer-v1.5 release expects local hidden size to "
|
| 134 |
+
"match Qwen3 hidden size so audio embeddings and heads are tied."
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
normalized_text_head_mode = str(local_text_head_mode or "full_vocab").strip().lower()
|
| 138 |
+
if normalized_text_head_mode in {"full", "full-vocab", "vocab"}:
|
| 139 |
+
normalized_text_head_mode = "full_vocab"
|
| 140 |
+
if normalized_text_head_mode not in {"full_vocab", "binary"}:
|
| 141 |
+
raise ValueError("local_text_head_mode must be 'full_vocab' or 'binary'.")
|
| 142 |
+
self.local_text_head_mode = normalized_text_head_mode
|
| 143 |
+
|
| 144 |
+
kwargs.setdefault("tie_word_embeddings", True)
|
| 145 |
+
super().__init__(pad_token_id=self.pad_token_id, **kwargs)
|
| 146 |
+
|
| 147 |
+
@property
|
| 148 |
+
def language_config(self) -> Qwen3Config:
|
| 149 |
+
return self.qwen3_config
|
| 150 |
+
|
| 151 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 152 |
+
output = super().to_dict()
|
| 153 |
+
output["qwen3_config"] = self.qwen3_config.to_dict()
|
| 154 |
+
output["language_config"] = self.qwen3_config.to_dict()
|
| 155 |
+
output["gpt2_config"] = self.gpt2_config.to_dict()
|
| 156 |
+
output["audio_pad_code"] = self.audio_pad_token_id
|
| 157 |
+
output["audio_assistant_gen_slot_token_id"] = self.audio_assistant_slot_token_id
|
| 158 |
+
return output
|
gpt2_decoder.py
ADDED
|
@@ -0,0 +1,721 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.utils.checkpoint
|
| 10 |
+
from transformers.activations import ACT2FN
|
| 11 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast
|
| 12 |
+
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
| 13 |
+
|
| 14 |
+
try:
|
| 15 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
| 16 |
+
from flash_attn.bert_padding import pad_input, unpad_input
|
| 17 |
+
|
| 18 |
+
_FLASH_ATTN_AVAILABLE = True
|
| 19 |
+
except Exception:
|
| 20 |
+
flash_attn_func = None
|
| 21 |
+
flash_attn_varlen_func = None
|
| 22 |
+
pad_input = None
|
| 23 |
+
unpad_input = None
|
| 24 |
+
_FLASH_ATTN_AVAILABLE = False
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
|
| 28 |
+
class PackedSequenceMetadata:
|
| 29 |
+
cu_seqlens: torch.Tensor
|
| 30 |
+
max_seqlen: int
|
| 31 |
+
indices: Optional[torch.Tensor] = None
|
| 32 |
+
batch_size: Optional[int] = None
|
| 33 |
+
seq_len: Optional[int] = None
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _is_static_kv_cache_layer(layer_past: object) -> bool:
|
| 37 |
+
return isinstance(layer_past, dict) and bool(layer_past.get("static_kv_cache", False))
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class MossTTSNanoGPT2RotaryEmbedding(nn.Module):
|
| 41 |
+
def __init__(self, dim: int, base: float = 10000.0) -> None:
|
| 42 |
+
super().__init__()
|
| 43 |
+
if dim % 2 != 0:
|
| 44 |
+
raise ValueError(f"RoPE head_dim must be even, got {dim}")
|
| 45 |
+
self.dim = int(dim)
|
| 46 |
+
self.base = float(base)
|
| 47 |
+
self.register_buffer("inv_freq", self._compute_inv_freq(), persistent=False)
|
| 48 |
+
|
| 49 |
+
def _compute_inv_freq(self, device: Optional[torch.device] = None) -> torch.Tensor:
|
| 50 |
+
return 1.0 / (
|
| 51 |
+
self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
def forward(
|
| 55 |
+
self,
|
| 56 |
+
position_ids: torch.LongTensor,
|
| 57 |
+
*,
|
| 58 |
+
device: torch.device,
|
| 59 |
+
dtype: torch.dtype,
|
| 60 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 61 |
+
if position_ids.ndim == 1:
|
| 62 |
+
position_ids = position_ids.unsqueeze(0)
|
| 63 |
+
inv_freq = self._compute_inv_freq(device=device)
|
| 64 |
+
freqs = torch.einsum("bs,d->bsd", position_ids.to(device=device, dtype=inv_freq.dtype), inv_freq)
|
| 65 |
+
cos = freqs.cos().repeat_interleave(2, dim=-1).unsqueeze(2).to(dtype=dtype)
|
| 66 |
+
sin = freqs.sin().repeat_interleave(2, dim=-1).unsqueeze(2).to(dtype=dtype)
|
| 67 |
+
return cos, sin
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def rotate_half(hidden_states: torch.Tensor) -> torch.Tensor:
|
| 71 |
+
even = hidden_states[..., ::2]
|
| 72 |
+
odd = hidden_states[..., 1::2]
|
| 73 |
+
return torch.stack((-odd, even), dim=-1).reshape_as(hidden_states)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def apply_rotary_pos_emb(
|
| 77 |
+
hidden_states: torch.Tensor,
|
| 78 |
+
cos: torch.Tensor,
|
| 79 |
+
sin: torch.Tensor,
|
| 80 |
+
) -> torch.Tensor:
|
| 81 |
+
return (hidden_states * cos) + (rotate_half(hidden_states) * sin)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class MossTTSNanoGPT2MLP(nn.Module):
|
| 85 |
+
def __init__(self, config: GPT2Config) -> None:
|
| 86 |
+
super().__init__()
|
| 87 |
+
hidden_size = int(config.hidden_size)
|
| 88 |
+
inner_size = int(config.n_inner or 4 * hidden_size)
|
| 89 |
+
self.fc_in = nn.Linear(hidden_size, inner_size)
|
| 90 |
+
self.fc_out = nn.Linear(inner_size, hidden_size)
|
| 91 |
+
self.act = ACT2FN[config.activation_function]
|
| 92 |
+
self.dropout = nn.Dropout(config.resid_pdrop)
|
| 93 |
+
|
| 94 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 95 |
+
hidden_states = self.fc_in(hidden_states)
|
| 96 |
+
hidden_states = self.act(hidden_states)
|
| 97 |
+
hidden_states = self.fc_out(hidden_states)
|
| 98 |
+
return self.dropout(hidden_states)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class MossTTSNanoGPT2Attention(nn.Module):
|
| 102 |
+
def __init__(self, config: GPT2Config, layer_idx: int, attn_implementation: str) -> None:
|
| 103 |
+
super().__init__()
|
| 104 |
+
hidden_size = int(config.hidden_size)
|
| 105 |
+
num_heads = int(config.num_attention_heads)
|
| 106 |
+
if hidden_size % num_heads != 0:
|
| 107 |
+
raise ValueError(f"hidden_size={hidden_size} must be divisible by num_attention_heads={num_heads}")
|
| 108 |
+
|
| 109 |
+
self.num_heads = num_heads
|
| 110 |
+
self.head_dim = hidden_size // num_heads
|
| 111 |
+
self.embed_dim = hidden_size
|
| 112 |
+
self.layer_idx = layer_idx
|
| 113 |
+
self.attn_implementation = attn_implementation
|
| 114 |
+
self.attn_dropout = float(config.attn_pdrop)
|
| 115 |
+
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
| 116 |
+
self.scale_attn_weights = bool(getattr(config, "scale_attn_weights", True))
|
| 117 |
+
self.scale_attn_by_inverse_layer_idx = bool(getattr(config, "scale_attn_by_inverse_layer_idx", False))
|
| 118 |
+
self.position_embedding_type = str(getattr(config, "position_embedding_type", "absolute")).lower()
|
| 119 |
+
if self.position_embedding_type not in {"absolute", "rope"}:
|
| 120 |
+
raise ValueError(f"Unsupported position_embedding_type={self.position_embedding_type!r}")
|
| 121 |
+
|
| 122 |
+
self.c_attn = nn.Linear(hidden_size, 3 * hidden_size)
|
| 123 |
+
self.c_proj = nn.Linear(hidden_size, hidden_size)
|
| 124 |
+
self.rotary_emb = None
|
| 125 |
+
if self.position_embedding_type == "rope":
|
| 126 |
+
self.rotary_emb = MossTTSNanoGPT2RotaryEmbedding(
|
| 127 |
+
self.head_dim,
|
| 128 |
+
base=float(getattr(config, "rope_base", 10000.0)),
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
def _split_heads(self, tensor: torch.Tensor) -> torch.Tensor:
|
| 132 |
+
if tensor.ndim == 3:
|
| 133 |
+
batch_size, seq_len, _ = tensor.shape
|
| 134 |
+
return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim)
|
| 135 |
+
if tensor.ndim == 2:
|
| 136 |
+
total_tokens, _ = tensor.shape
|
| 137 |
+
return tensor.view(total_tokens, self.num_heads, self.head_dim)
|
| 138 |
+
raise ValueError(f"Unsupported tensor rank for attention split: {tensor.ndim}")
|
| 139 |
+
|
| 140 |
+
def _merge_heads(self, tensor: torch.Tensor) -> torch.Tensor:
|
| 141 |
+
if tensor.ndim == 4:
|
| 142 |
+
batch_size, seq_len, _, _ = tensor.shape
|
| 143 |
+
return tensor.reshape(batch_size, seq_len, self.embed_dim)
|
| 144 |
+
if tensor.ndim == 3:
|
| 145 |
+
total_tokens, _, _ = tensor.shape
|
| 146 |
+
return tensor.reshape(total_tokens, self.embed_dim)
|
| 147 |
+
raise ValueError(f"Unsupported tensor rank for attention merge: {tensor.ndim}")
|
| 148 |
+
|
| 149 |
+
def _causal_attention_mask(
|
| 150 |
+
self,
|
| 151 |
+
attention_mask: Optional[torch.Tensor],
|
| 152 |
+
query_length: int,
|
| 153 |
+
key_length: int,
|
| 154 |
+
device: torch.device,
|
| 155 |
+
) -> torch.Tensor:
|
| 156 |
+
query_positions = torch.arange(query_length, device=device, dtype=torch.long)
|
| 157 |
+
query_positions = query_positions + max(key_length - query_length, 0)
|
| 158 |
+
key_positions = torch.arange(key_length, device=device, dtype=torch.long)
|
| 159 |
+
causal = key_positions.unsqueeze(0) <= query_positions.unsqueeze(1)
|
| 160 |
+
causal = causal.unsqueeze(0).unsqueeze(0)
|
| 161 |
+
if attention_mask is None:
|
| 162 |
+
return causal
|
| 163 |
+
key_mask = attention_mask[:, None, None, :].to(dtype=torch.bool)
|
| 164 |
+
return causal & key_mask
|
| 165 |
+
|
| 166 |
+
def _eager_attention(
|
| 167 |
+
self,
|
| 168 |
+
query: torch.Tensor,
|
| 169 |
+
key: torch.Tensor,
|
| 170 |
+
value: torch.Tensor,
|
| 171 |
+
attention_mask: Optional[torch.Tensor],
|
| 172 |
+
) -> torch.Tensor:
|
| 173 |
+
query = query.transpose(1, 2)
|
| 174 |
+
key = key.transpose(1, 2)
|
| 175 |
+
value = value.transpose(1, 2)
|
| 176 |
+
|
| 177 |
+
scale = 1.0
|
| 178 |
+
if self.scale_attn_weights:
|
| 179 |
+
scale /= self.head_dim ** 0.5
|
| 180 |
+
if self.scale_attn_by_inverse_layer_idx:
|
| 181 |
+
scale /= float(self.layer_idx + 1)
|
| 182 |
+
|
| 183 |
+
scores = torch.matmul(query, key.transpose(-1, -2)) * scale
|
| 184 |
+
causal_mask = self._causal_attention_mask(
|
| 185 |
+
attention_mask=attention_mask,
|
| 186 |
+
query_length=query.shape[-2],
|
| 187 |
+
key_length=key.shape[-2],
|
| 188 |
+
device=query.device,
|
| 189 |
+
)
|
| 190 |
+
scores = scores.masked_fill(~causal_mask, torch.finfo(scores.dtype).min)
|
| 191 |
+
probs = torch.softmax(scores, dim=-1)
|
| 192 |
+
if self.training and self.attn_dropout > 0:
|
| 193 |
+
probs = torch.dropout(probs, self.attn_dropout, train=True)
|
| 194 |
+
output = torch.matmul(probs, value)
|
| 195 |
+
return output.transpose(1, 2).contiguous()
|
| 196 |
+
|
| 197 |
+
def _sdpa_attention(
|
| 198 |
+
self,
|
| 199 |
+
query: torch.Tensor,
|
| 200 |
+
key: torch.Tensor,
|
| 201 |
+
value: torch.Tensor,
|
| 202 |
+
attention_mask: Optional[torch.Tensor],
|
| 203 |
+
) -> torch.Tensor:
|
| 204 |
+
query = query.transpose(1, 2)
|
| 205 |
+
key = key.transpose(1, 2)
|
| 206 |
+
value = value.transpose(1, 2)
|
| 207 |
+
mask = None
|
| 208 |
+
if attention_mask is not None or query.shape[-2] != key.shape[-2]:
|
| 209 |
+
mask = self._causal_attention_mask(
|
| 210 |
+
attention_mask=attention_mask,
|
| 211 |
+
query_length=query.shape[-2],
|
| 212 |
+
key_length=key.shape[-2],
|
| 213 |
+
device=query.device,
|
| 214 |
+
)
|
| 215 |
+
output = torch.nn.functional.scaled_dot_product_attention(
|
| 216 |
+
query,
|
| 217 |
+
key,
|
| 218 |
+
value,
|
| 219 |
+
attn_mask=mask,
|
| 220 |
+
dropout_p=self.attn_dropout if self.training else 0.0,
|
| 221 |
+
is_causal=mask is None,
|
| 222 |
+
)
|
| 223 |
+
return output.transpose(1, 2).contiguous()
|
| 224 |
+
|
| 225 |
+
def _flash_attention(
|
| 226 |
+
self,
|
| 227 |
+
query: torch.Tensor,
|
| 228 |
+
key: torch.Tensor,
|
| 229 |
+
value: torch.Tensor,
|
| 230 |
+
attention_mask: Optional[torch.Tensor],
|
| 231 |
+
packed_metadata: Optional[PackedSequenceMetadata],
|
| 232 |
+
) -> torch.Tensor:
|
| 233 |
+
if not _FLASH_ATTN_AVAILABLE:
|
| 234 |
+
raise ImportError("flash_attn is not installed, but attn_implementation='flash_attention_2' was requested.")
|
| 235 |
+
if query.device.type != "cuda":
|
| 236 |
+
raise ValueError("flash_attention_2 requires CUDA tensors.")
|
| 237 |
+
if query.dtype not in (torch.float16, torch.bfloat16):
|
| 238 |
+
raise ValueError(
|
| 239 |
+
f"flash_attention_2 requires fp16/bf16 tensors, but received dtype={query.dtype}."
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
dropout_p = self.attn_dropout if self.training else 0.0
|
| 243 |
+
if packed_metadata is not None:
|
| 244 |
+
if packed_metadata.indices is not None:
|
| 245 |
+
query = query.reshape(-1, self.num_heads, self.head_dim).index_select(0, packed_metadata.indices)
|
| 246 |
+
key = key.reshape(-1, self.num_heads, self.head_dim).index_select(0, packed_metadata.indices)
|
| 247 |
+
value = value.reshape(-1, self.num_heads, self.head_dim).index_select(0, packed_metadata.indices)
|
| 248 |
+
output = flash_attn_varlen_func(
|
| 249 |
+
query,
|
| 250 |
+
key,
|
| 251 |
+
value,
|
| 252 |
+
packed_metadata.cu_seqlens,
|
| 253 |
+
packed_metadata.cu_seqlens,
|
| 254 |
+
packed_metadata.max_seqlen,
|
| 255 |
+
packed_metadata.max_seqlen,
|
| 256 |
+
dropout_p=dropout_p,
|
| 257 |
+
causal=True,
|
| 258 |
+
)
|
| 259 |
+
if packed_metadata.indices is None:
|
| 260 |
+
return output
|
| 261 |
+
return pad_input(
|
| 262 |
+
output,
|
| 263 |
+
packed_metadata.indices,
|
| 264 |
+
packed_metadata.batch_size,
|
| 265 |
+
packed_metadata.seq_len,
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
if attention_mask is None or bool(attention_mask.all()):
|
| 269 |
+
return flash_attn_func(
|
| 270 |
+
query,
|
| 271 |
+
key,
|
| 272 |
+
value,
|
| 273 |
+
dropout_p=dropout_p,
|
| 274 |
+
causal=True,
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
if query.shape[1] != key.shape[1]:
|
| 278 |
+
query_attention_mask = attention_mask[:, -query.shape[1] :]
|
| 279 |
+
unpadded_query, query_indices, cu_seqlens_q, max_seqlen_q, _ = unpad_input(
|
| 280 |
+
query,
|
| 281 |
+
query_attention_mask,
|
| 282 |
+
)
|
| 283 |
+
unpadded_key, _, cu_seqlens_k, max_seqlen_k, _ = unpad_input(key, attention_mask)
|
| 284 |
+
unpadded_value, _, _, _, _ = unpad_input(value, attention_mask)
|
| 285 |
+
output = flash_attn_varlen_func(
|
| 286 |
+
unpadded_query,
|
| 287 |
+
unpadded_key,
|
| 288 |
+
unpadded_value,
|
| 289 |
+
cu_seqlens_q,
|
| 290 |
+
cu_seqlens_k,
|
| 291 |
+
max_seqlen_q,
|
| 292 |
+
max_seqlen_k,
|
| 293 |
+
dropout_p=dropout_p,
|
| 294 |
+
causal=True,
|
| 295 |
+
)
|
| 296 |
+
return pad_input(output, query_indices, query.shape[0], query.shape[1])
|
| 297 |
+
|
| 298 |
+
unpadded_query, indices, cu_seqlens, max_seqlen, _ = unpad_input(query, attention_mask)
|
| 299 |
+
unpadded_key, _, _, _, _ = unpad_input(key, attention_mask)
|
| 300 |
+
unpadded_value, _, _, _, _ = unpad_input(value, attention_mask)
|
| 301 |
+
output = flash_attn_varlen_func(
|
| 302 |
+
unpadded_query,
|
| 303 |
+
unpadded_key,
|
| 304 |
+
unpadded_value,
|
| 305 |
+
cu_seqlens,
|
| 306 |
+
cu_seqlens,
|
| 307 |
+
max_seqlen,
|
| 308 |
+
max_seqlen,
|
| 309 |
+
dropout_p=dropout_p,
|
| 310 |
+
causal=True,
|
| 311 |
+
)
|
| 312 |
+
return pad_input(output, indices, query.shape[0], query.shape[1])
|
| 313 |
+
|
| 314 |
+
def forward(
|
| 315 |
+
self,
|
| 316 |
+
hidden_states: torch.Tensor,
|
| 317 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 318 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 319 |
+
packed_metadata: Optional[PackedSequenceMetadata] = None,
|
| 320 |
+
layer_past: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
| 321 |
+
use_cache: bool = False,
|
| 322 |
+
) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
|
| 323 |
+
qkv = self.c_attn(hidden_states)
|
| 324 |
+
query, key, value = qkv.split(self.embed_dim, dim=-1)
|
| 325 |
+
query = self._split_heads(query)
|
| 326 |
+
key = self._split_heads(key)
|
| 327 |
+
value = self._split_heads(value)
|
| 328 |
+
|
| 329 |
+
if self.rotary_emb is not None:
|
| 330 |
+
if position_ids is None:
|
| 331 |
+
raise ValueError("position_ids must be provided when position_embedding_type='rope'.")
|
| 332 |
+
cos, sin = self.rotary_emb(
|
| 333 |
+
position_ids.to(device=query.device),
|
| 334 |
+
device=query.device,
|
| 335 |
+
dtype=query.dtype,
|
| 336 |
+
)
|
| 337 |
+
query = apply_rotary_pos_emb(query, cos, sin)
|
| 338 |
+
key = apply_rotary_pos_emb(key, cos, sin)
|
| 339 |
+
|
| 340 |
+
static_layer_past = layer_past is not None and _is_static_kv_cache_layer(layer_past)
|
| 341 |
+
if static_layer_past:
|
| 342 |
+
past_length = int(layer_past.get("length", 0))
|
| 343 |
+
new_length = past_length + int(key.shape[1])
|
| 344 |
+
key_cache = layer_past["key"]
|
| 345 |
+
value_cache = layer_past["value"]
|
| 346 |
+
if new_length > int(key_cache.shape[1]):
|
| 347 |
+
raise ValueError(
|
| 348 |
+
f"Static KV cache is too short: need {new_length}, capacity={int(key_cache.shape[1])}."
|
| 349 |
+
)
|
| 350 |
+
key_cache[:, past_length:new_length].copy_(key)
|
| 351 |
+
value_cache[:, past_length:new_length].copy_(value)
|
| 352 |
+
key = key_cache[:, :new_length]
|
| 353 |
+
value = value_cache[:, :new_length]
|
| 354 |
+
layer_past["length"] = new_length
|
| 355 |
+
elif layer_past is not None:
|
| 356 |
+
past_key, past_value = layer_past
|
| 357 |
+
key = torch.cat([past_key.to(device=key.device, dtype=key.dtype), key], dim=1)
|
| 358 |
+
value = torch.cat([past_value.to(device=value.device, dtype=value.dtype), value], dim=1)
|
| 359 |
+
|
| 360 |
+
present = layer_past if (use_cache and static_layer_past) else ((key, value) if use_cache else None)
|
| 361 |
+
|
| 362 |
+
if self.attn_implementation == "flash_attention_2":
|
| 363 |
+
attn_output = self._flash_attention(
|
| 364 |
+
query=query,
|
| 365 |
+
key=key,
|
| 366 |
+
value=value,
|
| 367 |
+
attention_mask=attention_mask,
|
| 368 |
+
packed_metadata=packed_metadata,
|
| 369 |
+
)
|
| 370 |
+
elif self.attn_implementation == "sdpa":
|
| 371 |
+
attn_output = self._sdpa_attention(
|
| 372 |
+
query=query,
|
| 373 |
+
key=key,
|
| 374 |
+
value=value,
|
| 375 |
+
attention_mask=attention_mask,
|
| 376 |
+
)
|
| 377 |
+
else:
|
| 378 |
+
attn_output = self._eager_attention(
|
| 379 |
+
query=query,
|
| 380 |
+
key=key,
|
| 381 |
+
value=value,
|
| 382 |
+
attention_mask=attention_mask,
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
attn_output = self._merge_heads(attn_output)
|
| 386 |
+
attn_output = self.c_proj(attn_output)
|
| 387 |
+
return self.resid_dropout(attn_output), present
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
class MossTTSNanoGPT2Block(nn.Module):
|
| 391 |
+
def __init__(self, config: GPT2Config, layer_idx: int, attn_implementation: str) -> None:
|
| 392 |
+
super().__init__()
|
| 393 |
+
hidden_size = int(config.hidden_size)
|
| 394 |
+
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 395 |
+
self.attn = MossTTSNanoGPT2Attention(config, layer_idx=layer_idx, attn_implementation=attn_implementation)
|
| 396 |
+
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 397 |
+
self.mlp = MossTTSNanoGPT2MLP(config)
|
| 398 |
+
|
| 399 |
+
def forward(
|
| 400 |
+
self,
|
| 401 |
+
hidden_states: torch.Tensor,
|
| 402 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 403 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 404 |
+
packed_metadata: Optional[PackedSequenceMetadata] = None,
|
| 405 |
+
layer_past: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
| 406 |
+
use_cache: bool = False,
|
| 407 |
+
) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
|
| 408 |
+
attn_output, present = self.attn(
|
| 409 |
+
self.ln_1(hidden_states),
|
| 410 |
+
attention_mask=attention_mask,
|
| 411 |
+
position_ids=position_ids,
|
| 412 |
+
packed_metadata=packed_metadata,
|
| 413 |
+
layer_past=layer_past,
|
| 414 |
+
use_cache=use_cache,
|
| 415 |
+
)
|
| 416 |
+
hidden_states = hidden_states + attn_output
|
| 417 |
+
hidden_states = hidden_states + self.mlp(self.ln_2(hidden_states))
|
| 418 |
+
return hidden_states, present
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
class MossTTSNanoGPT2Model(nn.Module):
|
| 422 |
+
def __init__(self, config: GPT2Config, attn_implementation: str = "eager") -> None:
|
| 423 |
+
super().__init__()
|
| 424 |
+
self.config = config
|
| 425 |
+
self.attn_implementation = attn_implementation
|
| 426 |
+
self.position_embedding_type = str(getattr(config, "position_embedding_type", "absolute")).lower()
|
| 427 |
+
if self.position_embedding_type not in {"absolute", "rope"}:
|
| 428 |
+
raise ValueError(f"Unsupported position_embedding_type={self.position_embedding_type!r}")
|
| 429 |
+
hidden_size = int(config.hidden_size)
|
| 430 |
+
self.wte = nn.Embedding(config.vocab_size, hidden_size)
|
| 431 |
+
self.wpe = nn.Embedding(config.n_positions, hidden_size) if self.position_embedding_type == "absolute" else nn.Identity()
|
| 432 |
+
self.drop = nn.Dropout(config.embd_pdrop)
|
| 433 |
+
self.h = nn.ModuleList(
|
| 434 |
+
[MossTTSNanoGPT2Block(config, layer_idx=index, attn_implementation=attn_implementation) for index in range(config.n_layer)]
|
| 435 |
+
)
|
| 436 |
+
self.ln_f = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 437 |
+
self.gradient_checkpointing = False
|
| 438 |
+
self._reset_parameters()
|
| 439 |
+
|
| 440 |
+
def _reset_parameters(self) -> None:
|
| 441 |
+
init_std = float(self.config.initializer_range)
|
| 442 |
+
for module in self.modules():
|
| 443 |
+
if isinstance(module, nn.Linear):
|
| 444 |
+
nn.init.normal_(module.weight, mean=0.0, std=init_std)
|
| 445 |
+
if module.bias is not None:
|
| 446 |
+
nn.init.zeros_(module.bias)
|
| 447 |
+
elif isinstance(module, nn.Embedding):
|
| 448 |
+
nn.init.normal_(module.weight, mean=0.0, std=init_std)
|
| 449 |
+
elif isinstance(module, nn.LayerNorm):
|
| 450 |
+
nn.init.ones_(module.weight)
|
| 451 |
+
nn.init.zeros_(module.bias)
|
| 452 |
+
|
| 453 |
+
@staticmethod
|
| 454 |
+
def _normalize_num_sequences(
|
| 455 |
+
cu_seqlens: torch.Tensor,
|
| 456 |
+
num_sequences: Optional[torch.Tensor],
|
| 457 |
+
device: torch.device,
|
| 458 |
+
) -> torch.Tensor:
|
| 459 |
+
if cu_seqlens.ndim == 1:
|
| 460 |
+
cu_seqlens = cu_seqlens.unsqueeze(0)
|
| 461 |
+
if num_sequences is None:
|
| 462 |
+
diffs = cu_seqlens[:, 1:] - cu_seqlens[:, :-1]
|
| 463 |
+
return diffs.gt(0).sum(dim=-1).to(device=device, dtype=torch.long)
|
| 464 |
+
if num_sequences.ndim == 0:
|
| 465 |
+
num_sequences = num_sequences.unsqueeze(0)
|
| 466 |
+
return num_sequences.to(device=device, dtype=torch.long)
|
| 467 |
+
|
| 468 |
+
@staticmethod
|
| 469 |
+
def _packed_segments_from_cu_seqlens(
|
| 470 |
+
cu_seqlens: torch.Tensor,
|
| 471 |
+
num_sequences: Optional[torch.Tensor],
|
| 472 |
+
device: torch.device,
|
| 473 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 474 |
+
if cu_seqlens.ndim == 1:
|
| 475 |
+
cu_seqlens = cu_seqlens.unsqueeze(0)
|
| 476 |
+
cu_seqlens = cu_seqlens.to(device=device)
|
| 477 |
+
batch_size, boundary_count = cu_seqlens.shape
|
| 478 |
+
segment_slots = boundary_count - 1
|
| 479 |
+
if segment_slots <= 0:
|
| 480 |
+
empty = torch.empty(0, dtype=torch.long, device=device)
|
| 481 |
+
return empty, empty, empty
|
| 482 |
+
|
| 483 |
+
counts = MossTTSNanoGPT2Model._normalize_num_sequences(cu_seqlens, num_sequences, device=device)
|
| 484 |
+
counts = counts.clamp(min=0, max=segment_slots)
|
| 485 |
+
segment_slots = int(counts.max().item()) if counts.numel() > 0 else 0
|
| 486 |
+
if segment_slots <= 0:
|
| 487 |
+
empty = torch.empty(0, dtype=torch.long, device=device)
|
| 488 |
+
return empty, empty, empty
|
| 489 |
+
cu_seqlens = cu_seqlens[:, : segment_slots + 1]
|
| 490 |
+
slot_ids = torch.arange(segment_slots, device=device).unsqueeze(0)
|
| 491 |
+
valid_slots = slot_ids < counts.unsqueeze(1)
|
| 492 |
+
|
| 493 |
+
starts = cu_seqlens[:, :-1].to(dtype=torch.long)
|
| 494 |
+
ends = cu_seqlens[:, 1:].to(dtype=torch.long)
|
| 495 |
+
lengths = (ends - starts).clamp_min(0)
|
| 496 |
+
lengths = torch.where(valid_slots, lengths, torch.zeros((), dtype=torch.long, device=device))
|
| 497 |
+
|
| 498 |
+
batch_ids = torch.arange(batch_size, device=device, dtype=torch.long).unsqueeze(1).expand(batch_size, segment_slots)
|
| 499 |
+
batch_ids = batch_ids.reshape(-1)
|
| 500 |
+
starts = starts.reshape(-1)
|
| 501 |
+
lengths = lengths.reshape(-1)
|
| 502 |
+
|
| 503 |
+
valid_segments = lengths.gt(0)
|
| 504 |
+
valid_count = int(valid_segments.to(dtype=torch.long).sum().item())
|
| 505 |
+
if valid_count <= 0:
|
| 506 |
+
empty = torch.empty(0, dtype=torch.long, device=device)
|
| 507 |
+
return empty, empty, empty
|
| 508 |
+
if valid_count == lengths.numel():
|
| 509 |
+
return batch_ids, starts, lengths
|
| 510 |
+
|
| 511 |
+
valid_order = torch.argsort(valid_segments.to(dtype=torch.long), descending=True, stable=True)[:valid_count]
|
| 512 |
+
return (
|
| 513 |
+
batch_ids.index_select(0, valid_order),
|
| 514 |
+
starts.index_select(0, valid_order),
|
| 515 |
+
lengths.index_select(0, valid_order),
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
@staticmethod
|
| 519 |
+
def _packed_token_indices(
|
| 520 |
+
batch_ids: torch.Tensor,
|
| 521 |
+
starts: torch.Tensor,
|
| 522 |
+
lengths: torch.Tensor,
|
| 523 |
+
seq_len: int,
|
| 524 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 525 |
+
total_tokens = int(lengths.sum().item())
|
| 526 |
+
if total_tokens <= 0:
|
| 527 |
+
empty = torch.empty(0, dtype=torch.long, device=lengths.device)
|
| 528 |
+
return empty, empty
|
| 529 |
+
|
| 530 |
+
segment_ids = torch.repeat_interleave(
|
| 531 |
+
torch.arange(lengths.numel(), device=lengths.device, dtype=torch.long),
|
| 532 |
+
lengths,
|
| 533 |
+
output_size=total_tokens,
|
| 534 |
+
)
|
| 535 |
+
segment_starts = torch.cumsum(lengths, dim=0) - lengths
|
| 536 |
+
positions = torch.arange(total_tokens, device=lengths.device, dtype=torch.long) - segment_starts[segment_ids]
|
| 537 |
+
indices = batch_ids[segment_ids] * seq_len + starts[segment_ids] + positions
|
| 538 |
+
return indices, positions
|
| 539 |
+
|
| 540 |
+
@staticmethod
|
| 541 |
+
def build_packed_position_ids(
|
| 542 |
+
attention_mask: Optional[torch.Tensor],
|
| 543 |
+
cu_seqlens: torch.Tensor,
|
| 544 |
+
num_sequences: Optional[torch.Tensor],
|
| 545 |
+
sequence_length: Optional[int] = None,
|
| 546 |
+
) -> torch.Tensor:
|
| 547 |
+
if cu_seqlens.ndim == 1:
|
| 548 |
+
cu_seqlens = cu_seqlens.unsqueeze(0)
|
| 549 |
+
batch_size = cu_seqlens.shape[0]
|
| 550 |
+
seq_len = int(sequence_length or (cu_seqlens.shape[1] - 1))
|
| 551 |
+
device = cu_seqlens.device
|
| 552 |
+
position_ids = torch.zeros((batch_size, seq_len), dtype=torch.long, device=device)
|
| 553 |
+
batch_ids, starts, lengths = MossTTSNanoGPT2Model._packed_segments_from_cu_seqlens(
|
| 554 |
+
cu_seqlens,
|
| 555 |
+
num_sequences,
|
| 556 |
+
device,
|
| 557 |
+
)
|
| 558 |
+
if lengths.numel() > 0:
|
| 559 |
+
indices, positions = MossTTSNanoGPT2Model._packed_token_indices(batch_ids, starts, lengths, seq_len)
|
| 560 |
+
position_ids.view(-1).scatter_(0, indices, positions)
|
| 561 |
+
if attention_mask is not None:
|
| 562 |
+
position_ids = position_ids * attention_mask.to(dtype=position_ids.dtype)
|
| 563 |
+
return position_ids
|
| 564 |
+
|
| 565 |
+
@staticmethod
|
| 566 |
+
def build_packed_metadata(
|
| 567 |
+
hidden_states: torch.Tensor,
|
| 568 |
+
cu_seqlens: torch.Tensor,
|
| 569 |
+
num_sequences: Optional[torch.Tensor],
|
| 570 |
+
) -> PackedSequenceMetadata:
|
| 571 |
+
if cu_seqlens.ndim == 1:
|
| 572 |
+
cu_seqlens = cu_seqlens.unsqueeze(0)
|
| 573 |
+
device = hidden_states.device
|
| 574 |
+
seq_len = hidden_states.shape[1]
|
| 575 |
+
batch_ids, starts, lengths = MossTTSNanoGPT2Model._packed_segments_from_cu_seqlens(
|
| 576 |
+
cu_seqlens,
|
| 577 |
+
num_sequences,
|
| 578 |
+
device,
|
| 579 |
+
)
|
| 580 |
+
if lengths.numel() == 0:
|
| 581 |
+
raise ValueError("cu_seqlens did not describe any non-empty packed sequences.")
|
| 582 |
+
|
| 583 |
+
indices, _ = MossTTSNanoGPT2Model._packed_token_indices(batch_ids, starts, lengths, seq_len)
|
| 584 |
+
cumulative = torch.empty(lengths.numel() + 1, dtype=torch.int32, device=device)
|
| 585 |
+
cumulative[0] = 0
|
| 586 |
+
cumulative[1:] = lengths.to(dtype=torch.int32).cumsum(dim=0)
|
| 587 |
+
return PackedSequenceMetadata(
|
| 588 |
+
cu_seqlens=cumulative,
|
| 589 |
+
max_seqlen=int(lengths.max().item()),
|
| 590 |
+
indices=indices,
|
| 591 |
+
batch_size=hidden_states.shape[0],
|
| 592 |
+
seq_len=hidden_states.shape[1],
|
| 593 |
+
)
|
| 594 |
+
|
| 595 |
+
def forward(
|
| 596 |
+
self,
|
| 597 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 598 |
+
past_key_values: Optional[tuple[tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
| 599 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 600 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 601 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 602 |
+
use_cache: Optional[bool] = None,
|
| 603 |
+
output_attentions: Optional[bool] = None,
|
| 604 |
+
output_hidden_states: Optional[bool] = None,
|
| 605 |
+
return_dict: bool = True,
|
| 606 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
| 607 |
+
num_sequences: Optional[torch.Tensor] = None,
|
| 608 |
+
) -> BaseModelOutputWithPast:
|
| 609 |
+
del input_ids, output_attentions
|
| 610 |
+
|
| 611 |
+
if inputs_embeds is None:
|
| 612 |
+
raise ValueError("inputs_embeds must be provided.")
|
| 613 |
+
|
| 614 |
+
use_cache = bool(use_cache)
|
| 615 |
+
if use_cache and cu_seqlens is not None:
|
| 616 |
+
raise ValueError("use_cache=True is not supported together with cu_seqlens packing.")
|
| 617 |
+
|
| 618 |
+
hidden_states = inputs_embeds
|
| 619 |
+
query_attention_mask = None
|
| 620 |
+
if attention_mask is not None:
|
| 621 |
+
attention_mask = attention_mask.to(dtype=torch.bool, device=hidden_states.device)
|
| 622 |
+
query_attention_mask = attention_mask[:, -hidden_states.shape[1] :]
|
| 623 |
+
|
| 624 |
+
packed_metadata = None
|
| 625 |
+
if position_ids is None:
|
| 626 |
+
if cu_seqlens is not None:
|
| 627 |
+
if attention_mask is None:
|
| 628 |
+
raise ValueError("attention_mask must be provided with cu_seqlens packing.")
|
| 629 |
+
position_ids = self.build_packed_position_ids(
|
| 630 |
+
attention_mask=attention_mask,
|
| 631 |
+
cu_seqlens=cu_seqlens.to(device=hidden_states.device),
|
| 632 |
+
num_sequences=num_sequences.to(device=hidden_states.device) if num_sequences is not None else None,
|
| 633 |
+
sequence_length=hidden_states.shape[1],
|
| 634 |
+
)
|
| 635 |
+
elif attention_mask is not None:
|
| 636 |
+
position_ids = attention_mask.long().cumsum(dim=-1) - 1
|
| 637 |
+
position_ids = position_ids.masked_fill(~attention_mask, 0)
|
| 638 |
+
position_ids = position_ids[:, -hidden_states.shape[1] :]
|
| 639 |
+
else:
|
| 640 |
+
past_length = 0
|
| 641 |
+
if past_key_values is not None and len(past_key_values) > 0:
|
| 642 |
+
first_layer_past = past_key_values[0]
|
| 643 |
+
if _is_static_kv_cache_layer(first_layer_past):
|
| 644 |
+
past_length = int(first_layer_past.get("length", 0))
|
| 645 |
+
else:
|
| 646 |
+
past_length = first_layer_past[0].shape[1]
|
| 647 |
+
position_ids = torch.arange(hidden_states.shape[1], device=hidden_states.device, dtype=torch.long)
|
| 648 |
+
position_ids = position_ids + past_length
|
| 649 |
+
position_ids = position_ids.unsqueeze(0).expand(hidden_states.shape[0], -1)
|
| 650 |
+
|
| 651 |
+
if cu_seqlens is not None and self.attn_implementation == "flash_attention_2":
|
| 652 |
+
packed_metadata = self.build_packed_metadata(
|
| 653 |
+
hidden_states=hidden_states,
|
| 654 |
+
cu_seqlens=cu_seqlens.to(device=hidden_states.device),
|
| 655 |
+
num_sequences=num_sequences.to(device=hidden_states.device) if num_sequences is not None else None,
|
| 656 |
+
)
|
| 657 |
+
|
| 658 |
+
if self.position_embedding_type == "absolute":
|
| 659 |
+
hidden_states = hidden_states + self.wpe(position_ids)
|
| 660 |
+
hidden_states = self.drop(hidden_states)
|
| 661 |
+
if query_attention_mask is not None:
|
| 662 |
+
hidden_states = hidden_states * query_attention_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
|
| 663 |
+
|
| 664 |
+
all_hidden_states = () if output_hidden_states else None
|
| 665 |
+
presents = [] if use_cache else None
|
| 666 |
+
for layer_index, block in enumerate(self.h):
|
| 667 |
+
if output_hidden_states:
|
| 668 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 669 |
+
|
| 670 |
+
if self.gradient_checkpointing and self.training:
|
| 671 |
+
if use_cache:
|
| 672 |
+
raise ValueError("use_cache=True is not supported when gradient checkpointing is enabled during training.")
|
| 673 |
+
|
| 674 |
+
def custom_forward(*inputs):
|
| 675 |
+
output, _ = block(
|
| 676 |
+
inputs[0],
|
| 677 |
+
attention_mask=inputs[1],
|
| 678 |
+
position_ids=inputs[2],
|
| 679 |
+
packed_metadata=packed_metadata,
|
| 680 |
+
layer_past=None,
|
| 681 |
+
use_cache=False,
|
| 682 |
+
)
|
| 683 |
+
return output
|
| 684 |
+
|
| 685 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 686 |
+
custom_forward,
|
| 687 |
+
hidden_states,
|
| 688 |
+
attention_mask,
|
| 689 |
+
position_ids,
|
| 690 |
+
use_reentrant=False,
|
| 691 |
+
)
|
| 692 |
+
present = None
|
| 693 |
+
else:
|
| 694 |
+
hidden_states, present = block(
|
| 695 |
+
hidden_states,
|
| 696 |
+
attention_mask=attention_mask,
|
| 697 |
+
position_ids=position_ids,
|
| 698 |
+
packed_metadata=packed_metadata,
|
| 699 |
+
layer_past=None if past_key_values is None else past_key_values[layer_index],
|
| 700 |
+
use_cache=use_cache,
|
| 701 |
+
)
|
| 702 |
+
if query_attention_mask is not None:
|
| 703 |
+
hidden_states = hidden_states * query_attention_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
|
| 704 |
+
if presents is not None:
|
| 705 |
+
presents.append(present)
|
| 706 |
+
|
| 707 |
+
hidden_states = self.ln_f(hidden_states)
|
| 708 |
+
if query_attention_mask is not None:
|
| 709 |
+
hidden_states = hidden_states * query_attention_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
|
| 710 |
+
if output_hidden_states:
|
| 711 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 712 |
+
|
| 713 |
+
if not return_dict:
|
| 714 |
+
return (hidden_states, tuple(presents) if presents is not None else None, all_hidden_states, None)
|
| 715 |
+
|
| 716 |
+
return BaseModelOutputWithPast(
|
| 717 |
+
last_hidden_state=hidden_states,
|
| 718 |
+
past_key_values=tuple(presents) if presents is not None else None,
|
| 719 |
+
hidden_states=all_hidden_states,
|
| 720 |
+
attentions=None,
|
| 721 |
+
)
|
merges.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:608f1ff64bc6caa9be836060fc7c78a15c4658c4a07b8d73c78d6f70d1b39c23
|
| 3 |
+
size 9100859544
|
modeling_moss_tts.py
ADDED
|
@@ -0,0 +1,623 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
"""Modeling code for the MOSS-TTS-Local-Transformer-v1.5 HuggingFace release."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from typing import Any, Optional, Union
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast
|
| 12 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 13 |
+
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
| 14 |
+
from transformers.utils import ModelOutput
|
| 15 |
+
|
| 16 |
+
from .configuration_moss_tts import MossTTSLocalConfig
|
| 17 |
+
from .gpt2_decoder import MossTTSNanoGPT2Model
|
| 18 |
+
from .qwen3_decoder import MossQwen3Model
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
|
| 22 |
+
class MossTTSLocalOutput(ModelOutput):
|
| 23 |
+
last_hidden_state: Optional[torch.FloatTensor] = None
|
| 24 |
+
past_key_values: Optional[tuple[tuple[torch.Tensor, torch.Tensor], ...]] = None
|
| 25 |
+
hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
|
| 26 |
+
attentions: Optional[tuple[torch.FloatTensor, ...]] = None
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _find_last_equal(input_ids: torch.LongTensor, value: int) -> torch.LongTensor:
|
| 30 |
+
matches = input_ids.eq(int(value))
|
| 31 |
+
if not bool(matches.any(dim=1).all().item()):
|
| 32 |
+
raise ValueError(f"Every sample must contain token id {int(value)}.")
|
| 33 |
+
positions = torch.arange(input_ids.shape[1], device=input_ids.device, dtype=torch.long)
|
| 34 |
+
masked_positions = positions.unsqueeze(0).masked_fill(~matches, -1)
|
| 35 |
+
return masked_positions.max(dim=1).values
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class MossTTSLocalPreTrainedModel(PreTrainedModel):
|
| 39 |
+
config_class = MossTTSLocalConfig
|
| 40 |
+
base_model_prefix = "transformer"
|
| 41 |
+
supports_gradient_checkpointing = True
|
| 42 |
+
_no_split_modules = ["MossTTSNanoGPT2Block", "MossQwen3DecoderLayer"]
|
| 43 |
+
_supports_flash_attn_2 = True
|
| 44 |
+
_supports_sdpa = True
|
| 45 |
+
_supports_cache_class = True
|
| 46 |
+
|
| 47 |
+
def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False) -> None:
|
| 48 |
+
if isinstance(module, MossTTSNanoGPT2Model) or isinstance(module, MossQwen3Model):
|
| 49 |
+
module.gradient_checkpointing = value
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class MossTTSLocalModel(MossTTSLocalPreTrainedModel):
|
| 53 |
+
_tied_weights_keys = None
|
| 54 |
+
|
| 55 |
+
def __init__(self, config: MossTTSLocalConfig) -> None:
|
| 56 |
+
super().__init__(config)
|
| 57 |
+
self._tied_weights_keys = self._build_tied_weights_keys(config)
|
| 58 |
+
|
| 59 |
+
config.qwen3_config.pad_token_id = config.pad_token_id
|
| 60 |
+
config.qwen3_config._attn_implementation = config.attn_implementation
|
| 61 |
+
local_gpt2_config = config.gpt2_config.to_dict()
|
| 62 |
+
local_gpt2_config["n_layer"] = int(getattr(config, "local_transformer_layers", config.gpt2_config.n_layer))
|
| 63 |
+
local_gpt2_config["n_positions"] = int(config.n_vq) + 1
|
| 64 |
+
local_gpt2_config["n_ctx"] = int(config.n_vq) + 1
|
| 65 |
+
local_gpt2_config = GPT2Config(**local_gpt2_config)
|
| 66 |
+
local_gpt2_config.pad_token_id = config.pad_token_id
|
| 67 |
+
local_gpt2_config._attn_implementation = config.local_transformer_attn_implementation
|
| 68 |
+
|
| 69 |
+
self.transformer = MossQwen3Model(config.qwen3_config)
|
| 70 |
+
self.local_transformer = MossTTSNanoGPT2Model(
|
| 71 |
+
local_gpt2_config,
|
| 72 |
+
attn_implementation=config.local_transformer_attn_implementation,
|
| 73 |
+
)
|
| 74 |
+
self.local_transformer.wte = nn.Identity()
|
| 75 |
+
|
| 76 |
+
hidden_size = int(config.hidden_size)
|
| 77 |
+
self.audio_embeddings = nn.ModuleList(
|
| 78 |
+
[
|
| 79 |
+
nn.Embedding(int(config.audio_codebook_sizes[index]), hidden_size)
|
| 80 |
+
for index in range(config.n_vq)
|
| 81 |
+
]
|
| 82 |
+
)
|
| 83 |
+
self.text_lm_head = nn.Linear(hidden_size, int(config.vocab_size), bias=False)
|
| 84 |
+
self.audio_lm_heads = nn.ModuleList(
|
| 85 |
+
[
|
| 86 |
+
nn.Linear(hidden_size, int(config.audio_codebook_sizes[index]), bias=False)
|
| 87 |
+
for index in range(config.n_vq)
|
| 88 |
+
]
|
| 89 |
+
)
|
| 90 |
+
self.local_text_lm_head = (
|
| 91 |
+
nn.Linear(hidden_size, 2, bias=False)
|
| 92 |
+
if self._use_binary_local_text_head()
|
| 93 |
+
else None
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
self.post_init()
|
| 97 |
+
self.tie_weights()
|
| 98 |
+
self.initialize_local_text_lm_head_from_text_lm_head()
|
| 99 |
+
|
| 100 |
+
def can_generate(self) -> bool:
|
| 101 |
+
return True
|
| 102 |
+
|
| 103 |
+
@staticmethod
|
| 104 |
+
def _build_tied_weights_keys(config: MossTTSLocalConfig) -> dict[str, str]:
|
| 105 |
+
tied_weights = {"text_lm_head.weight": "transformer.embed_tokens.weight"}
|
| 106 |
+
tied_weights.update(
|
| 107 |
+
{
|
| 108 |
+
f"audio_lm_heads.{index}.weight": f"audio_embeddings.{index}.weight"
|
| 109 |
+
for index in range(config.n_vq)
|
| 110 |
+
}
|
| 111 |
+
)
|
| 112 |
+
return tied_weights
|
| 113 |
+
|
| 114 |
+
def tie_weights(self, *args, **kwargs) -> None:
|
| 115 |
+
del args, kwargs
|
| 116 |
+
self.text_lm_head.weight = self.transformer.embed_tokens.weight
|
| 117 |
+
for embedding, head in zip(self.audio_embeddings, self.audio_lm_heads):
|
| 118 |
+
head.weight = embedding.weight
|
| 119 |
+
|
| 120 |
+
def get_input_embeddings(self) -> nn.Embedding:
|
| 121 |
+
return self.transformer.embed_tokens
|
| 122 |
+
|
| 123 |
+
def set_input_embeddings(self, value: nn.Embedding) -> None:
|
| 124 |
+
self.transformer.embed_tokens = value
|
| 125 |
+
self.tie_weights()
|
| 126 |
+
self.initialize_local_text_lm_head_from_text_lm_head()
|
| 127 |
+
|
| 128 |
+
def get_output_embeddings(self) -> nn.Linear:
|
| 129 |
+
return self.text_lm_head
|
| 130 |
+
|
| 131 |
+
def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
|
| 132 |
+
self.text_lm_head = new_embeddings
|
| 133 |
+
self.tie_weights()
|
| 134 |
+
self.initialize_local_text_lm_head_from_text_lm_head()
|
| 135 |
+
|
| 136 |
+
def _use_binary_local_text_head(self) -> bool:
|
| 137 |
+
return str(getattr(self.config, "local_text_head_mode", "full_vocab")).strip().lower() == "binary"
|
| 138 |
+
|
| 139 |
+
def _local_text_candidate_ids(self, device: torch.device) -> torch.LongTensor:
|
| 140 |
+
return torch.tensor(
|
| 141 |
+
[
|
| 142 |
+
int(self.config.audio_assistant_slot_token_id),
|
| 143 |
+
int(self.config.audio_end_token_id),
|
| 144 |
+
],
|
| 145 |
+
dtype=torch.long,
|
| 146 |
+
device=device,
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
def initialize_local_text_lm_head_from_text_lm_head(self) -> None:
|
| 150 |
+
if not self._use_binary_local_text_head() or self.local_text_lm_head is None:
|
| 151 |
+
return
|
| 152 |
+
candidate_ids = self._local_text_candidate_ids(self.text_lm_head.weight.device)
|
| 153 |
+
with torch.no_grad():
|
| 154 |
+
source_weight = self.text_lm_head.weight.index_select(0, candidate_ids)
|
| 155 |
+
if tuple(source_weight.shape) == tuple(self.local_text_lm_head.weight.shape):
|
| 156 |
+
self.local_text_lm_head.weight.copy_(
|
| 157 |
+
source_weight.to(
|
| 158 |
+
device=self.local_text_lm_head.weight.device,
|
| 159 |
+
dtype=self.local_text_lm_head.weight.dtype,
|
| 160 |
+
)
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
def _resolve_fixed_nq(
|
| 164 |
+
self,
|
| 165 |
+
n_vq_for_inference: Optional[int] = None,
|
| 166 |
+
nq: Optional[int] = None,
|
| 167 |
+
) -> int:
|
| 168 |
+
requested = n_vq_for_inference if n_vq_for_inference is not None else nq
|
| 169 |
+
config_nq = int(self.config.n_vq)
|
| 170 |
+
if requested is not None and int(requested) != config_nq:
|
| 171 |
+
raise ValueError(
|
| 172 |
+
"This MOSS-TTS-Local-Transformer-v1.5 release is trained with a fixed RVQ depth. "
|
| 173 |
+
f"Expected n_vq={config_nq}, got {int(requested)}."
|
| 174 |
+
)
|
| 175 |
+
return config_nq
|
| 176 |
+
|
| 177 |
+
def _build_inputs_embeds(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
|
| 178 |
+
if input_ids.ndim != 3 or input_ids.shape[-1] != self.config.n_vq + 1:
|
| 179 |
+
raise ValueError(
|
| 180 |
+
f"Expected input_ids shape [batch, seq, {self.config.n_vq + 1}], "
|
| 181 |
+
f"got {tuple(input_ids.shape)}."
|
| 182 |
+
)
|
| 183 |
+
text_ids = input_ids[..., 0]
|
| 184 |
+
inputs_embeds = self.transformer.embed_tokens(text_ids)
|
| 185 |
+
for channel_index, embedding in enumerate(self.audio_embeddings):
|
| 186 |
+
channel_ids = input_ids[..., channel_index + 1]
|
| 187 |
+
valid_mask = channel_ids.ne(self.config.audio_pad_token_id)
|
| 188 |
+
safe_ids = channel_ids.masked_fill(~valid_mask, 0)
|
| 189 |
+
audio_embeds = embedding(safe_ids) * valid_mask.unsqueeze(-1)
|
| 190 |
+
inputs_embeds = inputs_embeds + audio_embeds
|
| 191 |
+
return inputs_embeds
|
| 192 |
+
|
| 193 |
+
def _global_hidden_to_local(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
| 194 |
+
return hidden_states
|
| 195 |
+
|
| 196 |
+
@staticmethod
|
| 197 |
+
def _local_past_length(past_key_values: Optional[tuple[Any, ...]]) -> int:
|
| 198 |
+
if past_key_values is None or len(past_key_values) == 0:
|
| 199 |
+
return 0
|
| 200 |
+
first_layer_past = past_key_values[0]
|
| 201 |
+
if isinstance(first_layer_past, dict) and bool(first_layer_past.get("static_kv_cache", False)):
|
| 202 |
+
return int(first_layer_past.get("length", 0))
|
| 203 |
+
return int(first_layer_past[0].shape[1])
|
| 204 |
+
|
| 205 |
+
def _new_static_local_past_key_values(
|
| 206 |
+
self,
|
| 207 |
+
batch_size: int,
|
| 208 |
+
max_length: int,
|
| 209 |
+
device: torch.device,
|
| 210 |
+
dtype: torch.dtype,
|
| 211 |
+
) -> tuple[dict[str, Any], ...]:
|
| 212 |
+
layers = []
|
| 213 |
+
for block in self.local_transformer.h:
|
| 214 |
+
attn = block.attn
|
| 215 |
+
cache_shape = (
|
| 216 |
+
int(batch_size),
|
| 217 |
+
int(max_length),
|
| 218 |
+
int(attn.num_heads),
|
| 219 |
+
int(attn.head_dim),
|
| 220 |
+
)
|
| 221 |
+
layers.append(
|
| 222 |
+
{
|
| 223 |
+
"static_kv_cache": True,
|
| 224 |
+
"key": torch.empty(cache_shape, device=device, dtype=dtype),
|
| 225 |
+
"value": torch.empty(cache_shape, device=device, dtype=dtype),
|
| 226 |
+
"length": 0,
|
| 227 |
+
}
|
| 228 |
+
)
|
| 229 |
+
return tuple(layers)
|
| 230 |
+
|
| 231 |
+
def _decode_local_hidden_states_with_cache(
|
| 232 |
+
self,
|
| 233 |
+
local_inputs_embeds: torch.FloatTensor,
|
| 234 |
+
past_key_values: Optional[tuple[Any, ...]] = None,
|
| 235 |
+
) -> tuple[torch.FloatTensor, Optional[tuple[Any, ...]]]:
|
| 236 |
+
if (
|
| 237 |
+
past_key_values is None
|
| 238 |
+
and not self.training
|
| 239 |
+
and bool(getattr(self.config, "use_static_local_kv_cache", True))
|
| 240 |
+
):
|
| 241 |
+
max_length = max(int(getattr(self.config, "n_vq", 0)) + 1, int(local_inputs_embeds.shape[1]))
|
| 242 |
+
past_key_values = self._new_static_local_past_key_values(
|
| 243 |
+
batch_size=int(local_inputs_embeds.shape[0]),
|
| 244 |
+
max_length=max_length,
|
| 245 |
+
device=local_inputs_embeds.device,
|
| 246 |
+
dtype=local_inputs_embeds.dtype,
|
| 247 |
+
)
|
| 248 |
+
past_length = self._local_past_length(past_key_values)
|
| 249 |
+
local_seq_len = int(local_inputs_embeds.shape[1])
|
| 250 |
+
local_position_ids = torch.arange(
|
| 251 |
+
past_length,
|
| 252 |
+
past_length + local_seq_len,
|
| 253 |
+
device=local_inputs_embeds.device,
|
| 254 |
+
dtype=torch.long,
|
| 255 |
+
).unsqueeze(0)
|
| 256 |
+
if int(local_inputs_embeds.shape[0]) != 1:
|
| 257 |
+
local_position_ids = local_position_ids.expand(int(local_inputs_embeds.shape[0]), -1)
|
| 258 |
+
local_outputs = self.local_transformer(
|
| 259 |
+
input_ids=None,
|
| 260 |
+
past_key_values=past_key_values,
|
| 261 |
+
attention_mask=None,
|
| 262 |
+
position_ids=local_position_ids,
|
| 263 |
+
inputs_embeds=local_inputs_embeds,
|
| 264 |
+
use_cache=True,
|
| 265 |
+
output_attentions=False,
|
| 266 |
+
output_hidden_states=False,
|
| 267 |
+
return_dict=True,
|
| 268 |
+
cu_seqlens=None,
|
| 269 |
+
num_sequences=None,
|
| 270 |
+
)
|
| 271 |
+
return local_outputs.last_hidden_state, local_outputs.past_key_values
|
| 272 |
+
|
| 273 |
+
def forward(
|
| 274 |
+
self,
|
| 275 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 276 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 277 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 278 |
+
past_key_values: Optional[tuple[tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
| 279 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 280 |
+
use_cache: Optional[bool] = None,
|
| 281 |
+
output_attentions: Optional[bool] = None,
|
| 282 |
+
output_hidden_states: Optional[bool] = None,
|
| 283 |
+
return_dict: Optional[bool] = True,
|
| 284 |
+
**kwargs,
|
| 285 |
+
) -> Union[tuple, MossTTSLocalOutput]:
|
| 286 |
+
del kwargs
|
| 287 |
+
if inputs_embeds is None:
|
| 288 |
+
if input_ids is None:
|
| 289 |
+
raise ValueError("Either input_ids or inputs_embeds must be provided.")
|
| 290 |
+
inputs_embeds = self._build_inputs_embeds(input_ids)
|
| 291 |
+
outputs = self.transformer(
|
| 292 |
+
input_ids=None,
|
| 293 |
+
attention_mask=attention_mask,
|
| 294 |
+
position_ids=position_ids,
|
| 295 |
+
past_key_values=past_key_values,
|
| 296 |
+
inputs_embeds=inputs_embeds,
|
| 297 |
+
use_cache=use_cache,
|
| 298 |
+
output_attentions=output_attentions,
|
| 299 |
+
output_hidden_states=output_hidden_states,
|
| 300 |
+
return_dict=True,
|
| 301 |
+
cu_seqlens=None,
|
| 302 |
+
num_sequences=None,
|
| 303 |
+
)
|
| 304 |
+
if not return_dict:
|
| 305 |
+
return (
|
| 306 |
+
outputs.last_hidden_state,
|
| 307 |
+
outputs.past_key_values,
|
| 308 |
+
outputs.hidden_states,
|
| 309 |
+
outputs.attentions,
|
| 310 |
+
)
|
| 311 |
+
return MossTTSLocalOutput(
|
| 312 |
+
last_hidden_state=outputs.last_hidden_state,
|
| 313 |
+
past_key_values=outputs.past_key_values,
|
| 314 |
+
hidden_states=outputs.hidden_states,
|
| 315 |
+
attentions=outputs.attentions,
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
def _decode_local_last_hidden_state(
|
| 319 |
+
self,
|
| 320 |
+
local_inputs_embeds: torch.FloatTensor,
|
| 321 |
+
) -> torch.FloatTensor:
|
| 322 |
+
local_seq_len = int(local_inputs_embeds.shape[1])
|
| 323 |
+
local_position_ids = torch.arange(
|
| 324 |
+
0,
|
| 325 |
+
local_seq_len,
|
| 326 |
+
device=local_inputs_embeds.device,
|
| 327 |
+
dtype=torch.long,
|
| 328 |
+
).unsqueeze(0)
|
| 329 |
+
if int(local_inputs_embeds.shape[0]) != 1:
|
| 330 |
+
local_position_ids = local_position_ids.expand(int(local_inputs_embeds.shape[0]), -1)
|
| 331 |
+
local_outputs = self.local_transformer(
|
| 332 |
+
input_ids=None,
|
| 333 |
+
attention_mask=None,
|
| 334 |
+
position_ids=local_position_ids,
|
| 335 |
+
inputs_embeds=local_inputs_embeds,
|
| 336 |
+
use_cache=False,
|
| 337 |
+
output_attentions=False,
|
| 338 |
+
output_hidden_states=False,
|
| 339 |
+
return_dict=True,
|
| 340 |
+
cu_seqlens=None,
|
| 341 |
+
num_sequences=None,
|
| 342 |
+
)
|
| 343 |
+
return local_outputs.last_hidden_state[:, -1, :]
|
| 344 |
+
|
| 345 |
+
def _filter_logits(
|
| 346 |
+
self,
|
| 347 |
+
logits: torch.FloatTensor,
|
| 348 |
+
top_k: Optional[int],
|
| 349 |
+
top_p: Optional[float],
|
| 350 |
+
) -> torch.FloatTensor:
|
| 351 |
+
scores = logits
|
| 352 |
+
if top_k is not None and int(top_k) > 0 and int(top_k) < scores.shape[-1]:
|
| 353 |
+
kth = torch.topk(scores, int(top_k), dim=-1).values[..., -1, None]
|
| 354 |
+
scores = scores.masked_fill(scores < kth, -torch.inf)
|
| 355 |
+
if top_p is not None and 0.0 < float(top_p) < 1.0:
|
| 356 |
+
sorted_scores, sorted_indices = torch.sort(scores, descending=True, dim=-1)
|
| 357 |
+
sorted_probs = torch.softmax(sorted_scores, dim=-1)
|
| 358 |
+
cumulative_probs = sorted_probs.cumsum(dim=-1)
|
| 359 |
+
sorted_mask = cumulative_probs > float(top_p)
|
| 360 |
+
sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
|
| 361 |
+
sorted_mask[..., 0] = False
|
| 362 |
+
remove_mask = torch.zeros_like(scores, dtype=torch.bool)
|
| 363 |
+
remove_mask.scatter_(dim=-1, index=sorted_indices, src=sorted_mask)
|
| 364 |
+
scores = scores.masked_fill(remove_mask, -torch.inf)
|
| 365 |
+
return scores
|
| 366 |
+
|
| 367 |
+
def _apply_repetition_penalty(
|
| 368 |
+
self,
|
| 369 |
+
scores: torch.FloatTensor,
|
| 370 |
+
previous_token_ids: Optional[torch.LongTensor],
|
| 371 |
+
penalty: float,
|
| 372 |
+
) -> torch.FloatTensor:
|
| 373 |
+
if previous_token_ids is None or float(penalty) == 1.0:
|
| 374 |
+
return scores
|
| 375 |
+
if previous_token_ids.ndim == 1:
|
| 376 |
+
previous_token_ids = previous_token_ids.unsqueeze(0)
|
| 377 |
+
updated = scores.clone()
|
| 378 |
+
for batch_index in range(updated.shape[0]):
|
| 379 |
+
unique_token_ids = torch.unique(previous_token_ids[batch_index])
|
| 380 |
+
unique_token_ids = unique_token_ids[
|
| 381 |
+
(unique_token_ids >= 0) & (unique_token_ids < updated.shape[-1])
|
| 382 |
+
]
|
| 383 |
+
if unique_token_ids.numel() == 0:
|
| 384 |
+
continue
|
| 385 |
+
token_scores = updated[batch_index].index_select(0, unique_token_ids)
|
| 386 |
+
token_scores = torch.where(
|
| 387 |
+
token_scores < 0,
|
| 388 |
+
token_scores * float(penalty),
|
| 389 |
+
token_scores / float(penalty),
|
| 390 |
+
)
|
| 391 |
+
updated[batch_index].scatter_(0, unique_token_ids, token_scores)
|
| 392 |
+
return updated
|
| 393 |
+
|
| 394 |
+
def _sample_next_token(
|
| 395 |
+
self,
|
| 396 |
+
logits: torch.FloatTensor,
|
| 397 |
+
do_sample: bool,
|
| 398 |
+
temperature: float,
|
| 399 |
+
top_k: Optional[int],
|
| 400 |
+
top_p: Optional[float],
|
| 401 |
+
previous_token_ids: Optional[torch.LongTensor] = None,
|
| 402 |
+
repetition_penalty: float = 1.0,
|
| 403 |
+
) -> torch.LongTensor:
|
| 404 |
+
scores = logits.float()
|
| 405 |
+
scores = self._apply_repetition_penalty(scores, previous_token_ids, repetition_penalty)
|
| 406 |
+
if not do_sample:
|
| 407 |
+
return torch.argmax(scores, dim=-1)
|
| 408 |
+
if float(temperature) <= 0:
|
| 409 |
+
raise ValueError("temperature must be positive when do_sample=True.")
|
| 410 |
+
scores = scores / float(temperature)
|
| 411 |
+
scores = self._filter_logits(scores, top_k=top_k, top_p=top_p)
|
| 412 |
+
probs = torch.softmax(scores, dim=-1)
|
| 413 |
+
return torch.multinomial(probs, num_samples=1).squeeze(-1)
|
| 414 |
+
|
| 415 |
+
def _sample_next_assistant_text_token(
|
| 416 |
+
self,
|
| 417 |
+
local_hidden_states: torch.FloatTensor,
|
| 418 |
+
do_sample: bool,
|
| 419 |
+
temperature: float,
|
| 420 |
+
top_k: Optional[int],
|
| 421 |
+
top_p: Optional[float],
|
| 422 |
+
) -> torch.LongTensor:
|
| 423 |
+
if self._use_binary_local_text_head() and self.local_text_lm_head is not None:
|
| 424 |
+
logits = self.local_text_lm_head(local_hidden_states)
|
| 425 |
+
sampled_indices = self._sample_next_token(
|
| 426 |
+
logits=logits,
|
| 427 |
+
do_sample=do_sample,
|
| 428 |
+
temperature=temperature,
|
| 429 |
+
top_k=top_k,
|
| 430 |
+
top_p=top_p,
|
| 431 |
+
)
|
| 432 |
+
candidate_ids = self._local_text_candidate_ids(logits.device)
|
| 433 |
+
return candidate_ids[sampled_indices]
|
| 434 |
+
|
| 435 |
+
candidate_ids = self._local_text_candidate_ids(local_hidden_states.device)
|
| 436 |
+
logits = self.text_lm_head(local_hidden_states).index_select(dim=-1, index=candidate_ids)
|
| 437 |
+
sampled_indices = self._sample_next_token(
|
| 438 |
+
logits=logits,
|
| 439 |
+
do_sample=do_sample,
|
| 440 |
+
temperature=temperature,
|
| 441 |
+
top_k=top_k,
|
| 442 |
+
top_p=top_p,
|
| 443 |
+
)
|
| 444 |
+
return candidate_ids[sampled_indices]
|
| 445 |
+
|
| 446 |
+
def _build_generation_row(
|
| 447 |
+
self,
|
| 448 |
+
batch_size: int,
|
| 449 |
+
device: torch.device,
|
| 450 |
+
audio_token_ids: torch.LongTensor,
|
| 451 |
+
) -> torch.LongTensor:
|
| 452 |
+
row = torch.full(
|
| 453 |
+
(batch_size, 1, self.config.n_vq + 1),
|
| 454 |
+
int(self.config.audio_pad_token_id),
|
| 455 |
+
dtype=torch.long,
|
| 456 |
+
device=device,
|
| 457 |
+
)
|
| 458 |
+
row[:, :, 0] = int(self.config.audio_assistant_slot_token_id)
|
| 459 |
+
row[:, :, 1:] = audio_token_ids.unsqueeze(1)
|
| 460 |
+
return row
|
| 461 |
+
|
| 462 |
+
@torch.inference_mode()
|
| 463 |
+
def generate(
|
| 464 |
+
self,
|
| 465 |
+
input_ids: torch.LongTensor,
|
| 466 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 467 |
+
max_new_tokens: Optional[int] = None,
|
| 468 |
+
max_new_frames: Optional[int] = None,
|
| 469 |
+
do_sample: bool = True,
|
| 470 |
+
text_temperature: float = 1.0,
|
| 471 |
+
text_top_p: float = 1.0,
|
| 472 |
+
text_top_k: int = 50,
|
| 473 |
+
audio_temperature: Optional[float] = None,
|
| 474 |
+
audio_top_p: Optional[float] = None,
|
| 475 |
+
audio_top_k: Optional[int] = None,
|
| 476 |
+
audio_repetition_penalty: Optional[float] = None,
|
| 477 |
+
temperature: float = 1.0,
|
| 478 |
+
top_p: float = 0.95,
|
| 479 |
+
top_k: int = 50,
|
| 480 |
+
repetition_penalty: float = 1.0,
|
| 481 |
+
use_kv_cache: bool = True,
|
| 482 |
+
n_vq_for_inference: Optional[int] = None,
|
| 483 |
+
nq: Optional[int] = None,
|
| 484 |
+
**kwargs,
|
| 485 |
+
) -> list[tuple[int, torch.LongTensor]]:
|
| 486 |
+
del kwargs
|
| 487 |
+
self._resolve_fixed_nq(n_vq_for_inference=n_vq_for_inference, nq=nq)
|
| 488 |
+
|
| 489 |
+
if input_ids.ndim == 2:
|
| 490 |
+
input_ids = input_ids.unsqueeze(0)
|
| 491 |
+
if input_ids.ndim != 3:
|
| 492 |
+
raise ValueError(f"Expected input_ids with 3 dims, got {tuple(input_ids.shape)}.")
|
| 493 |
+
if input_ids.shape[-1] != self.config.n_vq + 1:
|
| 494 |
+
raise ValueError(
|
| 495 |
+
f"Expected {self.config.n_vq + 1} channels from config.n_vq, got {input_ids.shape[-1]}."
|
| 496 |
+
)
|
| 497 |
+
if attention_mask is None:
|
| 498 |
+
attention_mask = torch.ones(input_ids.shape[:2], dtype=torch.bool, device=input_ids.device)
|
| 499 |
+
elif attention_mask.ndim == 1:
|
| 500 |
+
attention_mask = attention_mask.unsqueeze(0)
|
| 501 |
+
attention_mask = attention_mask.to(device=input_ids.device, dtype=torch.bool)
|
| 502 |
+
|
| 503 |
+
frame_budget = max_new_frames if max_new_frames is not None else max_new_tokens
|
| 504 |
+
if frame_budget is None:
|
| 505 |
+
frame_budget = 4096
|
| 506 |
+
frame_budget = int(frame_budget)
|
| 507 |
+
|
| 508 |
+
audio_temperature = float(temperature if audio_temperature is None else audio_temperature)
|
| 509 |
+
audio_top_p = float(top_p if audio_top_p is None else audio_top_p)
|
| 510 |
+
audio_top_k = int(top_k if audio_top_k is None else audio_top_k)
|
| 511 |
+
audio_repetition_penalty = float(
|
| 512 |
+
repetition_penalty if audio_repetition_penalty is None else audio_repetition_penalty
|
| 513 |
+
)
|
| 514 |
+
|
| 515 |
+
batch_size = input_ids.shape[0]
|
| 516 |
+
input_ids_length = input_ids.shape[1]
|
| 517 |
+
current_input_ids = input_ids
|
| 518 |
+
current_attention_mask = attention_mask
|
| 519 |
+
current_model_input_ids = current_input_ids
|
| 520 |
+
generated_frames: list[torch.LongTensor] = []
|
| 521 |
+
finished = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
|
| 522 |
+
past_key_values = None
|
| 523 |
+
local_dtype = self.local_transformer.ln_f.weight.dtype
|
| 524 |
+
|
| 525 |
+
for _ in range(frame_budget):
|
| 526 |
+
generated_audio_history = torch.stack(generated_frames, dim=1) if generated_frames else None
|
| 527 |
+
global_inputs_embeds = self._build_inputs_embeds(current_model_input_ids)
|
| 528 |
+
global_outputs = self.transformer(
|
| 529 |
+
input_ids=None,
|
| 530 |
+
past_key_values=past_key_values,
|
| 531 |
+
attention_mask=current_attention_mask,
|
| 532 |
+
position_ids=None,
|
| 533 |
+
inputs_embeds=global_inputs_embeds,
|
| 534 |
+
use_cache=use_kv_cache,
|
| 535 |
+
output_attentions=False,
|
| 536 |
+
output_hidden_states=False,
|
| 537 |
+
return_dict=True,
|
| 538 |
+
cu_seqlens=None,
|
| 539 |
+
num_sequences=None,
|
| 540 |
+
)
|
| 541 |
+
global_hidden_states = global_outputs.last_hidden_state[:, -1, :]
|
| 542 |
+
local_global_hidden_states = self._global_hidden_to_local(global_hidden_states).to(dtype=local_dtype)
|
| 543 |
+
|
| 544 |
+
local_prefix_hidden_states, local_prefix_past_key_values = self._decode_local_hidden_states_with_cache(
|
| 545 |
+
local_global_hidden_states.unsqueeze(1)
|
| 546 |
+
)
|
| 547 |
+
local_hidden_states = local_prefix_hidden_states[:, -1, :]
|
| 548 |
+
next_text_tokens = self._sample_next_assistant_text_token(
|
| 549 |
+
local_hidden_states=local_hidden_states,
|
| 550 |
+
do_sample=do_sample,
|
| 551 |
+
temperature=text_temperature,
|
| 552 |
+
top_k=text_top_k,
|
| 553 |
+
top_p=text_top_p,
|
| 554 |
+
)
|
| 555 |
+
should_continue = next_text_tokens.eq(int(self.config.audio_assistant_slot_token_id)) & ~finished
|
| 556 |
+
finished = finished | next_text_tokens.eq(int(self.config.audio_end_token_id))
|
| 557 |
+
if not bool(should_continue.any().item()):
|
| 558 |
+
break
|
| 559 |
+
|
| 560 |
+
next_frame_tokens = []
|
| 561 |
+
for channel_index in range(int(self.config.n_vq)):
|
| 562 |
+
channel_logits = self.audio_lm_heads[channel_index](local_hidden_states)
|
| 563 |
+
channel_token = self._sample_next_token(
|
| 564 |
+
logits=channel_logits,
|
| 565 |
+
do_sample=do_sample,
|
| 566 |
+
temperature=audio_temperature,
|
| 567 |
+
top_k=audio_top_k,
|
| 568 |
+
top_p=audio_top_p,
|
| 569 |
+
previous_token_ids=(
|
| 570 |
+
None
|
| 571 |
+
if generated_audio_history is None
|
| 572 |
+
else generated_audio_history[:, :, channel_index]
|
| 573 |
+
),
|
| 574 |
+
repetition_penalty=audio_repetition_penalty,
|
| 575 |
+
)
|
| 576 |
+
next_frame_tokens.append(channel_token)
|
| 577 |
+
if channel_index + 1 < int(self.config.n_vq):
|
| 578 |
+
current_local_input = self.audio_embeddings[channel_index](channel_token).to(dtype=local_dtype)
|
| 579 |
+
local_token_hidden_states, local_prefix_past_key_values = (
|
| 580 |
+
self._decode_local_hidden_states_with_cache(
|
| 581 |
+
current_local_input.unsqueeze(1),
|
| 582 |
+
past_key_values=local_prefix_past_key_values,
|
| 583 |
+
)
|
| 584 |
+
)
|
| 585 |
+
local_hidden_states = local_token_hidden_states[:, -1, :]
|
| 586 |
+
|
| 587 |
+
next_frame = torch.stack(next_frame_tokens, dim=-1)
|
| 588 |
+
next_frame = next_frame.masked_fill(
|
| 589 |
+
~should_continue.unsqueeze(-1),
|
| 590 |
+
int(self.config.audio_pad_token_id),
|
| 591 |
+
)
|
| 592 |
+
generated_frames.append(next_frame)
|
| 593 |
+
|
| 594 |
+
next_row = self._build_generation_row(
|
| 595 |
+
batch_size=batch_size,
|
| 596 |
+
device=input_ids.device,
|
| 597 |
+
audio_token_ids=next_frame,
|
| 598 |
+
)
|
| 599 |
+
if bool((~should_continue).any().item()):
|
| 600 |
+
next_row[~should_continue, 0, 0] = int(self.config.pad_token_id)
|
| 601 |
+
next_row[~should_continue, 0, 1:] = int(self.config.audio_pad_token_id)
|
| 602 |
+
|
| 603 |
+
current_input_ids = torch.cat([current_input_ids, next_row], dim=1)
|
| 604 |
+
current_attention_mask = torch.cat(
|
| 605 |
+
[current_attention_mask, should_continue.unsqueeze(1)],
|
| 606 |
+
dim=1,
|
| 607 |
+
)
|
| 608 |
+
if use_kv_cache:
|
| 609 |
+
current_model_input_ids = next_row
|
| 610 |
+
past_key_values = global_outputs.past_key_values
|
| 611 |
+
else:
|
| 612 |
+
current_model_input_ids = current_input_ids
|
| 613 |
+
|
| 614 |
+
start_indices = _find_last_equal(input_ids[..., 0], int(self.config.audio_start_token_id))
|
| 615 |
+
start_lengths = input_ids_length - start_indices - 1
|
| 616 |
+
outputs: list[tuple[int, torch.LongTensor]] = []
|
| 617 |
+
for start_index, start_length, generation_ids in zip(
|
| 618 |
+
start_indices.tolist(),
|
| 619 |
+
start_lengths.tolist(),
|
| 620 |
+
current_input_ids,
|
| 621 |
+
):
|
| 622 |
+
outputs.append((int(start_length), generation_ids[int(start_index):].detach().cpu()))
|
| 623 |
+
return outputs
|
processing_moss_tts.py
ADDED
|
@@ -0,0 +1,899 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
"""Processor for the MOSS-TTS-Local-Transformer-v1.5 HuggingFace release."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import re
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torchaudio
|
| 14 |
+
from transformers import (
|
| 15 |
+
AutoConfig,
|
| 16 |
+
AutoModel,
|
| 17 |
+
AutoTokenizer,
|
| 18 |
+
BatchFeature,
|
| 19 |
+
PreTrainedTokenizerBase,
|
| 20 |
+
ProcessorMixin,
|
| 21 |
+
logging,
|
| 22 |
+
processing_utils,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
from .configuration_moss_tts import MossTTSLocalConfig
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
if hasattr(processing_utils, "MODALITY_TO_BASE_CLASS_MAPPING"):
|
| 29 |
+
processing_utils.MODALITY_TO_BASE_CLASS_MAPPING["audio_tokenizer"] = "PreTrainedModel"
|
| 30 |
+
else:
|
| 31 |
+
processing_utils.AUTO_TO_BASE_CLASS_MAPPING["AutoModel"] = "PreTrainedModel"
|
| 32 |
+
logger = logging.get_logger(__name__)
|
| 33 |
+
|
| 34 |
+
AUDIO_PLACEHOLDER = "<|audio|>"
|
| 35 |
+
USER_ROLE_PREFIX = "user\n"
|
| 36 |
+
USER_TEMPLATE_REFERENCE_PREFIX = (
|
| 37 |
+
"<user_inst>\n"
|
| 38 |
+
"- Reference(s):\n"
|
| 39 |
+
)
|
| 40 |
+
USER_TEMPLATE_AFTER_REFERENCE_SUFFIX = (
|
| 41 |
+
"\n"
|
| 42 |
+
"- Text:\n"
|
| 43 |
+
)
|
| 44 |
+
USER_TEMPLATE_SUFFIX = "\n</user_inst>"
|
| 45 |
+
ASSISTANT_TURN_PREFIX = "\n"
|
| 46 |
+
ASSISTANT_ROLE_PREFIX = "assistant\n"
|
| 47 |
+
USER_MESSAGE_FIELDS = (
|
| 48 |
+
"text",
|
| 49 |
+
"reference",
|
| 50 |
+
"instruction",
|
| 51 |
+
"tokens",
|
| 52 |
+
"quality",
|
| 53 |
+
"sound_event",
|
| 54 |
+
"ambient_sound",
|
| 55 |
+
"language",
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _normalize_template_value(value: Any) -> str:
|
| 60 |
+
if value is None:
|
| 61 |
+
return "None"
|
| 62 |
+
resolved = str(value).strip()
|
| 63 |
+
return resolved or "None"
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _render_user_prompt_after_reference(
|
| 67 |
+
language_code: object | None = None,
|
| 68 |
+
prompt_fields: Optional[Dict[str, Any]] = None,
|
| 69 |
+
) -> str:
|
| 70 |
+
fields = dict(prompt_fields or {})
|
| 71 |
+
return (
|
| 72 |
+
"\n- Instruction:\n"
|
| 73 |
+
+ _normalize_template_value(fields.get("instruction"))
|
| 74 |
+
+ "\n- Tokens:\n"
|
| 75 |
+
+ _normalize_template_value(fields.get("tokens"))
|
| 76 |
+
+ "\n- Quality:\n"
|
| 77 |
+
+ _normalize_template_value(fields.get("quality"))
|
| 78 |
+
+ "\n- Sound Event:\n"
|
| 79 |
+
+ _normalize_template_value(fields.get("sound_event"))
|
| 80 |
+
+ "\n- Ambient Sound:\n"
|
| 81 |
+
+ _normalize_template_value(fields.get("ambient_sound"))
|
| 82 |
+
+ "\n- Language:\n"
|
| 83 |
+
+ _normalize_template_value(fields.get("language", language_code))
|
| 84 |
+
+ USER_TEMPLATE_AFTER_REFERENCE_SUFFIX
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
@dataclass
|
| 89 |
+
class Message:
|
| 90 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 91 |
+
raise NotImplementedError
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
@dataclass
|
| 95 |
+
class UserMessage(Message):
|
| 96 |
+
text: Optional[str] = None
|
| 97 |
+
reference: Optional[List[Optional[Union[str, os.PathLike, torch.Tensor]]]] = None
|
| 98 |
+
instruction: Optional[str] = None
|
| 99 |
+
tokens: Optional[int] = None
|
| 100 |
+
quality: Optional[str] = None
|
| 101 |
+
sound_event: Optional[str] = None
|
| 102 |
+
ambient_sound: Optional[str] = None
|
| 103 |
+
language: Optional[str] = None
|
| 104 |
+
|
| 105 |
+
def __post_init__(self) -> None:
|
| 106 |
+
template = """<user_inst>
|
| 107 |
+
- Reference(s):
|
| 108 |
+
{reference}
|
| 109 |
+
- Instruction:
|
| 110 |
+
{instruction}
|
| 111 |
+
- Tokens:
|
| 112 |
+
{tokens}
|
| 113 |
+
- Quality:
|
| 114 |
+
{quality}
|
| 115 |
+
- Sound Event:
|
| 116 |
+
{sound_event}
|
| 117 |
+
- Ambient Sound:
|
| 118 |
+
{ambient_sound}
|
| 119 |
+
- Language:
|
| 120 |
+
{language}
|
| 121 |
+
- Text:
|
| 122 |
+
{text}
|
| 123 |
+
</user_inst>"""
|
| 124 |
+
|
| 125 |
+
audio_codes_list: list[Union[str, os.PathLike, torch.Tensor]] = []
|
| 126 |
+
if self.reference is None:
|
| 127 |
+
reference = "None"
|
| 128 |
+
else:
|
| 129 |
+
reference_items: list[str] = []
|
| 130 |
+
for speaker_idx, speaker_reference in enumerate(self.reference):
|
| 131 |
+
if speaker_reference is None:
|
| 132 |
+
continue
|
| 133 |
+
# Keep raw audio placeholders directly under "- Reference(s):".
|
| 134 |
+
# Speaker labels such as "[S1]:" change the token sequence and
|
| 135 |
+
# can affect voice-clone conditioning.
|
| 136 |
+
reference_items.append(AUDIO_PLACEHOLDER)
|
| 137 |
+
audio_codes_list.append(speaker_reference)
|
| 138 |
+
reference = "\n".join(reference_items) if reference_items else "None"
|
| 139 |
+
|
| 140 |
+
self._content = (
|
| 141 |
+
template.replace("{reference}", str(reference))
|
| 142 |
+
.replace("{instruction}", str(self.instruction))
|
| 143 |
+
.replace("{tokens}", str(self.tokens))
|
| 144 |
+
.replace("{quality}", str(self.quality))
|
| 145 |
+
.replace("{sound_event}", str(self.sound_event))
|
| 146 |
+
.replace("{ambient_sound}", str(self.ambient_sound))
|
| 147 |
+
.replace("{language}", str(self.language))
|
| 148 |
+
.replace("{text}", str(self.text))
|
| 149 |
+
)
|
| 150 |
+
self._audio_codes_list = audio_codes_list
|
| 151 |
+
|
| 152 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 153 |
+
return {
|
| 154 |
+
"role": "user",
|
| 155 |
+
"content": self._content,
|
| 156 |
+
"audio_codes_list": self._audio_codes_list,
|
| 157 |
+
"text": self.text,
|
| 158 |
+
"instruction": self.instruction,
|
| 159 |
+
"tokens": self.tokens,
|
| 160 |
+
"quality": self.quality,
|
| 161 |
+
"sound_event": self.sound_event,
|
| 162 |
+
"ambient_sound": self.ambient_sound,
|
| 163 |
+
"language": self.language,
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
@dataclass
|
| 168 |
+
class AssistantMessage(Message):
|
| 169 |
+
audio_codes_list: List[Union[str, os.PathLike, torch.Tensor]]
|
| 170 |
+
content: str = AUDIO_PLACEHOLDER
|
| 171 |
+
|
| 172 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 173 |
+
return {
|
| 174 |
+
"role": "assistant",
|
| 175 |
+
"content": self.content,
|
| 176 |
+
"audio_codes_list": self.audio_codes_list,
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class MossTTSLocalProcessor(ProcessorMixin):
|
| 181 |
+
attributes = ["tokenizer"]
|
| 182 |
+
tokenizer_class = "AutoTokenizer"
|
| 183 |
+
audio_tokenizer_class = "AutoModel"
|
| 184 |
+
|
| 185 |
+
tokenizer: PreTrainedTokenizerBase
|
| 186 |
+
audio_tokenizer: Any
|
| 187 |
+
|
| 188 |
+
def __init__(
|
| 189 |
+
self,
|
| 190 |
+
tokenizer: PreTrainedTokenizerBase,
|
| 191 |
+
audio_tokenizer: Any = None,
|
| 192 |
+
model_config: Optional[MossTTSLocalConfig] = None,
|
| 193 |
+
**kwargs,
|
| 194 |
+
) -> None:
|
| 195 |
+
super().__init__(tokenizer=tokenizer, audio_tokenizer=audio_tokenizer, **kwargs)
|
| 196 |
+
self.tokenizer = tokenizer
|
| 197 |
+
self.audio_tokenizer = audio_tokenizer
|
| 198 |
+
self.model_config = model_config or MossTTSLocalConfig()
|
| 199 |
+
|
| 200 |
+
def _id_to_token(token_id: int) -> str:
|
| 201 |
+
token = tokenizer.convert_ids_to_tokens(int(token_id))
|
| 202 |
+
if isinstance(token, list):
|
| 203 |
+
return token[0] if token else ""
|
| 204 |
+
return cast(str, token)
|
| 205 |
+
|
| 206 |
+
self.audio_user_slot_token = _id_to_token(self.model_config.audio_user_slot_token_id)
|
| 207 |
+
self.audio_assistant_slot_token = _id_to_token(self.model_config.audio_assistant_slot_token_id)
|
| 208 |
+
self.audio_start_token = _id_to_token(self.model_config.audio_start_token_id)
|
| 209 |
+
self.audio_end_token = _id_to_token(self.model_config.audio_end_token_id)
|
| 210 |
+
|
| 211 |
+
@classmethod
|
| 212 |
+
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
|
| 213 |
+
trust_remote_code = kwargs.pop("trust_remote_code", True)
|
| 214 |
+
kwargs.pop("_from_auto", None)
|
| 215 |
+
codec_path = kwargs.pop("codec_path", None)
|
| 216 |
+
|
| 217 |
+
model_ref = Path(str(pretrained_model_name_or_path))
|
| 218 |
+
model_ref_or_name = model_ref if model_ref.exists() else pretrained_model_name_or_path
|
| 219 |
+
model_config = cast(
|
| 220 |
+
MossTTSLocalConfig,
|
| 221 |
+
AutoConfig.from_pretrained(
|
| 222 |
+
model_ref_or_name,
|
| 223 |
+
*args,
|
| 224 |
+
trust_remote_code=trust_remote_code,
|
| 225 |
+
**kwargs,
|
| 226 |
+
),
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
if codec_path is None:
|
| 230 |
+
try:
|
| 231 |
+
processor_dict, _ = cls.get_processor_dict(
|
| 232 |
+
pretrained_model_name_or_path,
|
| 233 |
+
**dict(kwargs),
|
| 234 |
+
)
|
| 235 |
+
codec_path = processor_dict.get("audio_tokenizer_name_or_path")
|
| 236 |
+
audio_tokenizer_dict = processor_dict.get("audio_tokenizer", {})
|
| 237 |
+
if isinstance(audio_tokenizer_dict, dict):
|
| 238 |
+
codec_path = audio_tokenizer_dict.get("audio_tokenizer_name_or_path") or codec_path
|
| 239 |
+
except Exception:
|
| 240 |
+
codec_path = None
|
| 241 |
+
if codec_path is None:
|
| 242 |
+
codec_path = getattr(model_config, "audio_tokenizer_name_or_path", None)
|
| 243 |
+
if codec_path is None:
|
| 244 |
+
codec_path = "OpenMOSS-Team/MOSS-Audio-Tokenizer-v2"
|
| 245 |
+
|
| 246 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 247 |
+
model_ref_or_name,
|
| 248 |
+
*args,
|
| 249 |
+
trust_remote_code=trust_remote_code,
|
| 250 |
+
**kwargs,
|
| 251 |
+
)
|
| 252 |
+
audio_tokenizer = AutoModel.from_pretrained(
|
| 253 |
+
codec_path,
|
| 254 |
+
trust_remote_code=trust_remote_code,
|
| 255 |
+
**kwargs,
|
| 256 |
+
)
|
| 257 |
+
return cls(
|
| 258 |
+
tokenizer=tokenizer,
|
| 259 |
+
audio_tokenizer=audio_tokenizer,
|
| 260 |
+
model_config=model_config,
|
| 261 |
+
**kwargs,
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
@staticmethod
|
| 265 |
+
def build_user_message(
|
| 266 |
+
text: Optional[str] = None,
|
| 267 |
+
reference: Optional[List[Optional[Union[str, os.PathLike, torch.Tensor]]]] = None,
|
| 268 |
+
instruction: Optional[str] = None,
|
| 269 |
+
tokens: Optional[int] = None,
|
| 270 |
+
quality: Optional[str] = None,
|
| 271 |
+
sound_event: Optional[str] = None,
|
| 272 |
+
ambient_sound: Optional[str] = None,
|
| 273 |
+
language: Optional[str] = None,
|
| 274 |
+
) -> Dict[str, Any]:
|
| 275 |
+
if reference is not None and not isinstance(reference, list):
|
| 276 |
+
reference = [reference]
|
| 277 |
+
return UserMessage(
|
| 278 |
+
text=text,
|
| 279 |
+
reference=reference,
|
| 280 |
+
instruction=instruction,
|
| 281 |
+
tokens=tokens,
|
| 282 |
+
quality=quality,
|
| 283 |
+
sound_event=sound_event,
|
| 284 |
+
ambient_sound=ambient_sound,
|
| 285 |
+
language=language,
|
| 286 |
+
).to_dict()
|
| 287 |
+
|
| 288 |
+
@staticmethod
|
| 289 |
+
def build_assistant_message(
|
| 290 |
+
audio_codes_list: List[Union[str, os.PathLike, torch.Tensor]],
|
| 291 |
+
content: str = AUDIO_PLACEHOLDER,
|
| 292 |
+
) -> Dict[str, Any]:
|
| 293 |
+
return AssistantMessage(audio_codes_list=audio_codes_list, content=content).to_dict()
|
| 294 |
+
|
| 295 |
+
def _assert_fixed_nq(self, n_vq: Optional[int]) -> int:
|
| 296 |
+
config_nq = int(self.model_config.n_vq)
|
| 297 |
+
if n_vq is not None and int(n_vq) != config_nq:
|
| 298 |
+
raise ValueError(
|
| 299 |
+
"This MOSS-TTS-Local-Transformer-v1.5 release uses the RVQ depth stored in the model config. "
|
| 300 |
+
f"Expected n_vq={config_nq}, got {int(n_vq)}."
|
| 301 |
+
)
|
| 302 |
+
return config_nq
|
| 303 |
+
|
| 304 |
+
def _encode_text(self, text: str) -> list[int]:
|
| 305 |
+
try:
|
| 306 |
+
return list(self.tokenizer.encode(text, add_special_tokens=False))
|
| 307 |
+
except TypeError:
|
| 308 |
+
return list(self.tokenizer.encode(text))
|
| 309 |
+
|
| 310 |
+
def _build_text_rows(self, token_ids: Sequence[int], *, device: Optional[torch.device] = None) -> torch.Tensor:
|
| 311 |
+
rows = torch.full(
|
| 312 |
+
(len(token_ids), int(self.model_config.n_vq) + 1),
|
| 313 |
+
int(self.model_config.audio_pad_token_id),
|
| 314 |
+
dtype=torch.long,
|
| 315 |
+
device=device,
|
| 316 |
+
)
|
| 317 |
+
if token_ids:
|
| 318 |
+
rows[:, 0] = torch.tensor([int(token_id) for token_id in token_ids], dtype=torch.long, device=rows.device)
|
| 319 |
+
return rows
|
| 320 |
+
|
| 321 |
+
def _build_audio_rows(self, audio_tokens: torch.Tensor, slot_token_id: int) -> torch.Tensor:
|
| 322 |
+
rows = torch.full(
|
| 323 |
+
(int(audio_tokens.shape[0]), int(self.model_config.n_vq) + 1),
|
| 324 |
+
int(self.model_config.audio_pad_token_id),
|
| 325 |
+
dtype=torch.long,
|
| 326 |
+
device=audio_tokens.device,
|
| 327 |
+
)
|
| 328 |
+
if rows.shape[0] > 0:
|
| 329 |
+
rows[:, 0] = int(slot_token_id)
|
| 330 |
+
rows[:, 1:] = audio_tokens.to(dtype=torch.long)
|
| 331 |
+
return rows
|
| 332 |
+
|
| 333 |
+
def _user_prompt_prefix_ids(self) -> list[int]:
|
| 334 |
+
return (
|
| 335 |
+
[int(self.model_config.im_start_token_id)]
|
| 336 |
+
+ self._encode_text(USER_ROLE_PREFIX)
|
| 337 |
+
+ self._encode_text(USER_TEMPLATE_REFERENCE_PREFIX)
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
def _user_prompt_after_reference_ids(
|
| 341 |
+
self,
|
| 342 |
+
language_code: object | None,
|
| 343 |
+
prompt_fields: Optional[Dict[str, Any]],
|
| 344 |
+
) -> list[int]:
|
| 345 |
+
return self._encode_text(
|
| 346 |
+
_render_user_prompt_after_reference(
|
| 347 |
+
language_code=language_code,
|
| 348 |
+
prompt_fields=prompt_fields,
|
| 349 |
+
)
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
def _assistant_prompt_prefix_ids(self) -> list[int]:
|
| 353 |
+
return (
|
| 354 |
+
self._encode_text(USER_TEMPLATE_SUFFIX)
|
| 355 |
+
+ [int(self.model_config.im_end_token_id)]
|
| 356 |
+
+ self._encode_text(ASSISTANT_TURN_PREFIX)
|
| 357 |
+
+ [int(self.model_config.im_start_token_id)]
|
| 358 |
+
+ self._encode_text(ASSISTANT_ROLE_PREFIX)
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
def _prompt_fields_from_user_message(self, message: Dict[str, Any]) -> dict[str, Any]:
|
| 362 |
+
fields = {}
|
| 363 |
+
for key in ("instruction", "tokens", "quality", "sound_event", "ambient_sound"):
|
| 364 |
+
if key in message and message.get(key) is not None:
|
| 365 |
+
fields[key] = message.get(key)
|
| 366 |
+
if "language" in message and message.get("language") is not None:
|
| 367 |
+
fields["language"] = message.get("language")
|
| 368 |
+
return fields
|
| 369 |
+
|
| 370 |
+
def _build_generation_or_voice_clone_codes(
|
| 371 |
+
self,
|
| 372 |
+
message: Dict[str, Any],
|
| 373 |
+
n_vq: int,
|
| 374 |
+
) -> torch.Tensor:
|
| 375 |
+
if "text" not in message:
|
| 376 |
+
raise ValueError("Direct MOSS-TTS-Local-Transformer-v1.5 generation requires messages built by build_user_message(...).")
|
| 377 |
+
text = "" if message.get("text") is None else str(message.get("text"))
|
| 378 |
+
prompt_fields = self._prompt_fields_from_user_message(message)
|
| 379 |
+
language_code = message.get("language")
|
| 380 |
+
audio_codes_list = self._resolve_audio_items(message.get("audio_codes_list", []), n_vq)
|
| 381 |
+
text_token_ids = self._encode_text(text)
|
| 382 |
+
|
| 383 |
+
if audio_codes_list:
|
| 384 |
+
parts: list[torch.Tensor] = [self._build_text_rows(
|
| 385 |
+
self._user_prompt_prefix_ids(),
|
| 386 |
+
device=audio_codes_list[0].device,
|
| 387 |
+
)]
|
| 388 |
+
for reference_codes in audio_codes_list:
|
| 389 |
+
parts.append(self._build_text_rows([int(self.model_config.audio_start_token_id)], device=reference_codes.device))
|
| 390 |
+
parts.append(self._build_audio_rows(reference_codes, int(self.model_config.audio_user_slot_token_id)))
|
| 391 |
+
parts.append(self._build_text_rows([int(self.model_config.audio_end_token_id)], device=reference_codes.device))
|
| 392 |
+
parts.append(
|
| 393 |
+
self._build_text_rows(
|
| 394 |
+
self._user_prompt_after_reference_ids(language_code, prompt_fields)
|
| 395 |
+
+ text_token_ids
|
| 396 |
+
+ self._assistant_prompt_prefix_ids()
|
| 397 |
+
+ [int(self.model_config.audio_start_token_id)],
|
| 398 |
+
device=audio_codes_list[0].device,
|
| 399 |
+
)
|
| 400 |
+
)
|
| 401 |
+
return torch.cat(parts, dim=0)
|
| 402 |
+
|
| 403 |
+
prompt_token_ids = (
|
| 404 |
+
self._user_prompt_prefix_ids()
|
| 405 |
+
+ self._encode_text("None")
|
| 406 |
+
+ self._user_prompt_after_reference_ids(language_code, prompt_fields)
|
| 407 |
+
+ text_token_ids
|
| 408 |
+
+ self._assistant_prompt_prefix_ids()
|
| 409 |
+
+ [int(self.model_config.audio_start_token_id)]
|
| 410 |
+
)
|
| 411 |
+
return self._build_text_rows(prompt_token_ids)
|
| 412 |
+
|
| 413 |
+
def _build_continuation_codes(
|
| 414 |
+
self,
|
| 415 |
+
conversation: list[Dict[str, Any]],
|
| 416 |
+
n_vq: int,
|
| 417 |
+
) -> torch.Tensor:
|
| 418 |
+
if len(conversation) < 2:
|
| 419 |
+
raise ValueError("continuation mode requires a user message followed by an assistant audio message.")
|
| 420 |
+
user_message = conversation[-2]
|
| 421 |
+
assistant_message = conversation[-1]
|
| 422 |
+
if user_message.get("role") != "user" or assistant_message.get("role") != "assistant":
|
| 423 |
+
raise ValueError("continuation mode requires the last two messages to be user, assistant.")
|
| 424 |
+
if "text" not in user_message:
|
| 425 |
+
raise ValueError("Direct MOSS-TTS-Local-Transformer-v1.5 continuation requires user messages built by build_user_message(...).")
|
| 426 |
+
|
| 427 |
+
text = "" if user_message.get("text") is None else str(user_message.get("text"))
|
| 428 |
+
prompt_fields = self._prompt_fields_from_user_message(user_message)
|
| 429 |
+
language_code = user_message.get("language")
|
| 430 |
+
prompt_token_ids = (
|
| 431 |
+
self._user_prompt_prefix_ids()
|
| 432 |
+
+ self._encode_text("None")
|
| 433 |
+
+ self._user_prompt_after_reference_ids(language_code, prompt_fields)
|
| 434 |
+
+ self._encode_text(text)
|
| 435 |
+
+ self._assistant_prompt_prefix_ids()
|
| 436 |
+
+ [int(self.model_config.audio_start_token_id)]
|
| 437 |
+
)
|
| 438 |
+
audio_codes_list = self._resolve_audio_items(assistant_message.get("audio_codes_list", []), n_vq)
|
| 439 |
+
if not audio_codes_list:
|
| 440 |
+
return self._build_text_rows(prompt_token_ids)
|
| 441 |
+
if len(audio_codes_list) != 1:
|
| 442 |
+
raise ValueError("The MOSS-TTS-Local-Transformer-v1.5 continuation path expects exactly one prompt audio item.")
|
| 443 |
+
prompt_audio_codes = audio_codes_list[0]
|
| 444 |
+
return torch.cat(
|
| 445 |
+
[
|
| 446 |
+
self._build_text_rows(prompt_token_ids, device=prompt_audio_codes.device),
|
| 447 |
+
self._build_audio_rows(prompt_audio_codes, int(self.model_config.audio_assistant_slot_token_id)),
|
| 448 |
+
],
|
| 449 |
+
dim=0,
|
| 450 |
+
)
|
| 451 |
+
|
| 452 |
+
def _try_build_direct_codes(
|
| 453 |
+
self,
|
| 454 |
+
conversation: list[Dict[str, Any]],
|
| 455 |
+
mode: str,
|
| 456 |
+
n_vq: int,
|
| 457 |
+
) -> Optional[torch.Tensor]:
|
| 458 |
+
if mode == "generation" and len(conversation) == 1 and conversation[-1].get("role") == "user":
|
| 459 |
+
if "text" in conversation[-1]:
|
| 460 |
+
return self._build_generation_or_voice_clone_codes(conversation[-1], n_vq)
|
| 461 |
+
return None
|
| 462 |
+
if mode == "continuation" and len(conversation) >= 2:
|
| 463 |
+
if "text" in conversation[-2]:
|
| 464 |
+
return self._build_continuation_codes(conversation, n_vq)
|
| 465 |
+
return None
|
| 466 |
+
return None
|
| 467 |
+
|
| 468 |
+
def __call__(self, *args, **kwargs) -> BatchFeature:
|
| 469 |
+
conversations = args[0] if args else kwargs.pop("conversations")
|
| 470 |
+
mode: str = kwargs.pop("mode", "generation")
|
| 471 |
+
apply_chat_template: bool = kwargs.pop("apply_chat_template", True)
|
| 472 |
+
n_vq = self._assert_fixed_nq(kwargs.pop("n_vq", None))
|
| 473 |
+
|
| 474 |
+
kwargs.pop("return_tensors", None)
|
| 475 |
+
kwargs.pop("padding", None)
|
| 476 |
+
kwargs.pop("truncation", None)
|
| 477 |
+
|
| 478 |
+
if mode not in {"generation", "continuation", "computing_loss"}:
|
| 479 |
+
raise ValueError(f"Unsupported mode: {mode}")
|
| 480 |
+
if isinstance(conversations, (Message, dict)):
|
| 481 |
+
conversations = [conversations]
|
| 482 |
+
elif isinstance(conversations, list) and conversations and all(
|
| 483 |
+
isinstance(item, (Message, dict)) for item in conversations
|
| 484 |
+
):
|
| 485 |
+
conversations = [conversations]
|
| 486 |
+
|
| 487 |
+
input_ids_list: list[torch.Tensor] = []
|
| 488 |
+
for conversation in conversations:
|
| 489 |
+
if isinstance(conversation, (Message, dict)):
|
| 490 |
+
conversation = [conversation]
|
| 491 |
+
conversation = [self._normalize_message(message) for message in conversation]
|
| 492 |
+
|
| 493 |
+
if (mode == "generation") ^ (conversation[-1]["role"] == "user"):
|
| 494 |
+
raise ValueError("generation mode must end with a user message.")
|
| 495 |
+
if mode == "continuation" and conversation[-1]["role"] != "assistant":
|
| 496 |
+
raise ValueError("continuation mode must end with an assistant message.")
|
| 497 |
+
|
| 498 |
+
direct_codes = self._try_build_direct_codes(conversation, mode, n_vq)
|
| 499 |
+
if direct_codes is not None:
|
| 500 |
+
input_ids_list.append(direct_codes)
|
| 501 |
+
continue
|
| 502 |
+
|
| 503 |
+
unified_parts = []
|
| 504 |
+
for message_idx, message in enumerate(conversation):
|
| 505 |
+
content = str(message["content"])
|
| 506 |
+
if apply_chat_template:
|
| 507 |
+
add_generation_prompt = mode == "generation" and message_idx == len(conversation) - 1
|
| 508 |
+
try:
|
| 509 |
+
content = self.tokenizer.apply_chat_template(
|
| 510 |
+
[{"role": message["role"], "content": content}],
|
| 511 |
+
add_generation_prompt=add_generation_prompt,
|
| 512 |
+
tokenize=False,
|
| 513 |
+
)
|
| 514 |
+
except Exception:
|
| 515 |
+
logger.warning("apply_chat_template failed; falling back to raw message content.")
|
| 516 |
+
|
| 517 |
+
raw_audio_items = message.get("audio_codes_list", [])
|
| 518 |
+
audio_codes_list = self._resolve_audio_items(raw_audio_items, n_vq)
|
| 519 |
+
unified_parts.append(
|
| 520 |
+
self._get_unified_codes(
|
| 521 |
+
role=message["role"],
|
| 522 |
+
content=content,
|
| 523 |
+
audio_codes_list=audio_codes_list,
|
| 524 |
+
truncation=(mode == "continuation"),
|
| 525 |
+
)
|
| 526 |
+
)
|
| 527 |
+
|
| 528 |
+
unified_codes = torch.cat(unified_parts, dim=0)
|
| 529 |
+
if mode == "generation":
|
| 530 |
+
audio_start_row = torch.full(
|
| 531 |
+
(1, n_vq + 1),
|
| 532 |
+
int(self.model_config.audio_pad_token_id),
|
| 533 |
+
dtype=unified_codes.dtype,
|
| 534 |
+
device=unified_codes.device,
|
| 535 |
+
)
|
| 536 |
+
audio_start_row[:, 0] = int(self.model_config.audio_start_token_id)
|
| 537 |
+
unified_codes = torch.cat([unified_codes, audio_start_row], dim=0)
|
| 538 |
+
input_ids_list.append(unified_codes)
|
| 539 |
+
|
| 540 |
+
return BatchFeature(data=self._pad(input_ids_list))
|
| 541 |
+
|
| 542 |
+
def _normalize_message(self, message: Union[Message, Dict[str, Any]]) -> Dict[str, Any]:
|
| 543 |
+
if isinstance(message, Message):
|
| 544 |
+
return message.to_dict()
|
| 545 |
+
if not isinstance(message, dict):
|
| 546 |
+
raise TypeError("Each message must be a Message or dict.")
|
| 547 |
+
if "content" in message and "audio_codes_list" in message:
|
| 548 |
+
return message
|
| 549 |
+
role = message.get("role")
|
| 550 |
+
if role == "user":
|
| 551 |
+
return self.build_user_message(**{key: message.get(key) for key in USER_MESSAGE_FIELDS})
|
| 552 |
+
if role == "assistant":
|
| 553 |
+
return self.build_assistant_message(
|
| 554 |
+
audio_codes_list=message.get("audio_codes_list", []),
|
| 555 |
+
content=message.get("content", AUDIO_PLACEHOLDER),
|
| 556 |
+
)
|
| 557 |
+
raise ValueError(f"Unsupported role: {role}")
|
| 558 |
+
|
| 559 |
+
def _resolve_audio_items(
|
| 560 |
+
self,
|
| 561 |
+
raw_audio_items: list[Any],
|
| 562 |
+
n_vq: int,
|
| 563 |
+
) -> list[torch.Tensor]:
|
| 564 |
+
if not raw_audio_items:
|
| 565 |
+
return []
|
| 566 |
+
resolved: list[Optional[torch.Tensor]] = [None] * len(raw_audio_items)
|
| 567 |
+
paths: list[str] = []
|
| 568 |
+
path_positions: list[int] = []
|
| 569 |
+
for index, item in enumerate(raw_audio_items):
|
| 570 |
+
if isinstance(item, torch.Tensor):
|
| 571 |
+
if item.ndim != 2 or int(item.shape[1]) != n_vq:
|
| 572 |
+
raise ValueError(f"audio code tensor must have shape [T, {n_vq}], got {tuple(item.shape)}.")
|
| 573 |
+
resolved[index] = item.to(dtype=torch.long).cpu()
|
| 574 |
+
elif isinstance(item, (str, os.PathLike)):
|
| 575 |
+
paths.append(str(item))
|
| 576 |
+
path_positions.append(index)
|
| 577 |
+
else:
|
| 578 |
+
raise TypeError("Audio items must be tensors or path-like values.")
|
| 579 |
+
if paths:
|
| 580 |
+
encoded = self.encode_audios_from_path(paths, n_vq=n_vq)
|
| 581 |
+
for position, codes in zip(path_positions, encoded):
|
| 582 |
+
resolved[position] = codes
|
| 583 |
+
return [cast(torch.Tensor, item) for item in resolved]
|
| 584 |
+
|
| 585 |
+
def _pad(self, input_ids_list: list[torch.Tensor]) -> Dict[str, torch.Tensor]:
|
| 586 |
+
device = input_ids_list[0].device
|
| 587 |
+
lengths = torch.tensor([item.shape[0] for item in input_ids_list], device=device)
|
| 588 |
+
padded = torch.nn.utils.rnn.pad_sequence(
|
| 589 |
+
input_ids_list,
|
| 590 |
+
batch_first=True,
|
| 591 |
+
padding_value=int(self.model_config.audio_pad_token_id),
|
| 592 |
+
padding_side="left",
|
| 593 |
+
)
|
| 594 |
+
left_pad_mask = (padded.shape[1] - lengths).unsqueeze(1) > torch.arange(
|
| 595 |
+
padded.shape[1],
|
| 596 |
+
device=device,
|
| 597 |
+
).unsqueeze(0)
|
| 598 |
+
padded[..., 0][left_pad_mask] = int(self.model_config.pad_token_id)
|
| 599 |
+
attention_mask = torch.zeros(padded.shape[:2], dtype=torch.bool, device=device)
|
| 600 |
+
attention_mask[~left_pad_mask] = True
|
| 601 |
+
return {"input_ids": padded, "attention_mask": attention_mask}
|
| 602 |
+
|
| 603 |
+
@staticmethod
|
| 604 |
+
def _replace_audio_placeholders(
|
| 605 |
+
content: str,
|
| 606 |
+
lengths: list[int],
|
| 607 |
+
slot_token: str,
|
| 608 |
+
audio_start_token: str,
|
| 609 |
+
audio_end_token: str,
|
| 610 |
+
) -> str:
|
| 611 |
+
placeholder_count = content.count(AUDIO_PLACEHOLDER)
|
| 612 |
+
if placeholder_count != len(lengths):
|
| 613 |
+
raise ValueError(
|
| 614 |
+
f"Number of {AUDIO_PLACEHOLDER} ({placeholder_count}) does not match "
|
| 615 |
+
f"audio item count ({len(lengths)})."
|
| 616 |
+
)
|
| 617 |
+
lengths_iter = iter(lengths)
|
| 618 |
+
|
| 619 |
+
def replacer(_: re.Match) -> str:
|
| 620 |
+
length = int(next(lengths_iter))
|
| 621 |
+
if length <= 0:
|
| 622 |
+
return f"{audio_start_token}{audio_end_token}"
|
| 623 |
+
return f"{audio_start_token}{slot_token * length}{audio_end_token}"
|
| 624 |
+
|
| 625 |
+
return re.sub(re.escape(AUDIO_PLACEHOLDER), replacer, content)
|
| 626 |
+
|
| 627 |
+
def _get_unified_codes(
|
| 628 |
+
self,
|
| 629 |
+
role: str,
|
| 630 |
+
content: str,
|
| 631 |
+
audio_codes_list: list[torch.Tensor],
|
| 632 |
+
truncation: bool,
|
| 633 |
+
) -> torch.Tensor:
|
| 634 |
+
n_vq = int(self.model_config.n_vq)
|
| 635 |
+
slot_token = self.audio_user_slot_token if role == "user" else self.audio_assistant_slot_token
|
| 636 |
+
content = self._replace_audio_placeholders(
|
| 637 |
+
content=content,
|
| 638 |
+
lengths=[int(codes.shape[0]) for codes in audio_codes_list],
|
| 639 |
+
slot_token=slot_token,
|
| 640 |
+
audio_start_token=self.audio_start_token,
|
| 641 |
+
audio_end_token=self.audio_end_token,
|
| 642 |
+
)
|
| 643 |
+
text_codes = torch.tensor(
|
| 644 |
+
self.tokenizer.encode(content),
|
| 645 |
+
dtype=torch.long,
|
| 646 |
+
device=audio_codes_list[0].device if audio_codes_list else None,
|
| 647 |
+
)
|
| 648 |
+
|
| 649 |
+
audio_start_indices = torch.where(text_codes == int(self.model_config.audio_start_token_id))[0]
|
| 650 |
+
audio_end_indices = torch.where(text_codes == int(self.model_config.audio_end_token_id))[0]
|
| 651 |
+
if len(audio_start_indices) != len(audio_codes_list) or len(audio_end_indices) != len(audio_codes_list):
|
| 652 |
+
raise ValueError("Audio placeholders do not match the encoded audio spans.")
|
| 653 |
+
|
| 654 |
+
if not audio_codes_list:
|
| 655 |
+
audio_codes = torch.full(
|
| 656 |
+
(len(text_codes), n_vq),
|
| 657 |
+
int(self.model_config.audio_pad_token_id),
|
| 658 |
+
dtype=torch.long,
|
| 659 |
+
device=text_codes.device,
|
| 660 |
+
)
|
| 661 |
+
else:
|
| 662 |
+
pieces: list[torch.Tensor] = []
|
| 663 |
+
prefix_idx = 0
|
| 664 |
+
for start_t, end_t, codes in zip(audio_start_indices, audio_end_indices, audio_codes_list):
|
| 665 |
+
start_idx = int(start_t.item())
|
| 666 |
+
end_idx = int(end_t.item())
|
| 667 |
+
pad_before = torch.full(
|
| 668 |
+
(start_idx - prefix_idx + 1, n_vq),
|
| 669 |
+
int(self.model_config.audio_pad_token_id),
|
| 670 |
+
dtype=torch.long,
|
| 671 |
+
device=codes.device,
|
| 672 |
+
)
|
| 673 |
+
pieces.extend([pad_before, codes.to(dtype=torch.long)])
|
| 674 |
+
prefix_idx = end_idx
|
| 675 |
+
if truncation:
|
| 676 |
+
trailing = torch.zeros(
|
| 677 |
+
(0, n_vq),
|
| 678 |
+
dtype=torch.long,
|
| 679 |
+
device=audio_codes_list[0].device,
|
| 680 |
+
)
|
| 681 |
+
else:
|
| 682 |
+
last_end = int(audio_end_indices[-1].item())
|
| 683 |
+
trailing = torch.full(
|
| 684 |
+
(len(text_codes) - last_end, n_vq),
|
| 685 |
+
int(self.model_config.audio_pad_token_id),
|
| 686 |
+
dtype=torch.long,
|
| 687 |
+
device=audio_codes_list[0].device,
|
| 688 |
+
)
|
| 689 |
+
pieces.append(trailing)
|
| 690 |
+
audio_codes = torch.cat(pieces, dim=0)
|
| 691 |
+
|
| 692 |
+
if text_codes.shape[0] != audio_codes.shape[0]:
|
| 693 |
+
min_len = min(text_codes.shape[0], audio_codes.shape[0])
|
| 694 |
+
text_codes = text_codes[:min_len]
|
| 695 |
+
audio_codes = audio_codes[:min_len]
|
| 696 |
+
return torch.cat([text_codes.unsqueeze(1), audio_codes], dim=1)
|
| 697 |
+
|
| 698 |
+
def _parse_text_codes(self, start_length: int, text_codes: torch.LongTensor) -> str:
|
| 699 |
+
text = cast(str, self.tokenizer.decode(text_codes))
|
| 700 |
+
prefix = cast(str, self.tokenizer.decode(text_codes[:start_length]))
|
| 701 |
+
text = text[len(prefix):]
|
| 702 |
+
audio_pattern = re.compile(
|
| 703 |
+
rf"(?:{re.escape(self.audio_start_token)})?"
|
| 704 |
+
rf"(?:{re.escape(self.audio_assistant_slot_token)})*"
|
| 705 |
+
rf"{re.escape(self.audio_end_token)}"
|
| 706 |
+
)
|
| 707 |
+
return audio_pattern.sub(
|
| 708 |
+
lambda match: AUDIO_PLACEHOLDER if self.audio_assistant_slot_token in match.group(0) else "",
|
| 709 |
+
text,
|
| 710 |
+
)
|
| 711 |
+
|
| 712 |
+
def _parse_audio_codes(
|
| 713 |
+
self,
|
| 714 |
+
start_length: int,
|
| 715 |
+
audio_codes: torch.LongTensor,
|
| 716 |
+
*,
|
| 717 |
+
return_stereo: bool = True,
|
| 718 |
+
) -> list[torch.Tensor]:
|
| 719 |
+
is_pad = audio_codes.eq(int(self.model_config.audio_pad_token_id)).all(dim=1)
|
| 720 |
+
non_pad = ~is_pad
|
| 721 |
+
if not bool(non_pad.any().item()):
|
| 722 |
+
return []
|
| 723 |
+
idx = torch.nonzero(non_pad).squeeze(1)
|
| 724 |
+
breaks = torch.where(idx[1:] != idx[:-1] + 1)[0] + 1
|
| 725 |
+
segment_indices = [idx] if breaks.numel() == 0 else list(torch.tensor_split(idx, breaks.cpu().tolist()))
|
| 726 |
+
code_segments = [audio_codes[segment] for segment in segment_indices]
|
| 727 |
+
decoded = self.decode_audio_codes(code_segments, return_stereo=return_stereo)
|
| 728 |
+
|
| 729 |
+
if start_length > 0 and code_segments and decoded:
|
| 730 |
+
first_code_length = int(code_segments[0].shape[0])
|
| 731 |
+
if first_code_length > 0:
|
| 732 |
+
trim_ratio = max(0.0, min(float(start_length) / float(first_code_length), 1.0))
|
| 733 |
+
if trim_ratio >= 1.0:
|
| 734 |
+
decoded = decoded[1:]
|
| 735 |
+
elif trim_ratio > 0.0:
|
| 736 |
+
trim_samples = int(decoded[0].shape[-1] * trim_ratio)
|
| 737 |
+
decoded[0] = decoded[0][..., trim_samples:]
|
| 738 |
+
return decoded
|
| 739 |
+
|
| 740 |
+
def decode(self, output: Any, *, return_stereo: bool = True) -> list[Optional[AssistantMessage]]:
|
| 741 |
+
generated_messages: list[Optional[AssistantMessage]] = []
|
| 742 |
+
for start_length, generation_ids in output:
|
| 743 |
+
content = self._parse_text_codes(int(start_length), generation_ids[:, 0])
|
| 744 |
+
audio_codes_list = self._parse_audio_codes(
|
| 745 |
+
int(start_length),
|
| 746 |
+
generation_ids[:, 1:],
|
| 747 |
+
return_stereo=return_stereo,
|
| 748 |
+
)
|
| 749 |
+
if content == "":
|
| 750 |
+
generated_messages.append(None)
|
| 751 |
+
else:
|
| 752 |
+
generated_messages.append(
|
| 753 |
+
AssistantMessage(
|
| 754 |
+
content=content,
|
| 755 |
+
audio_codes_list=cast(list[Union[str, torch.Tensor]], audio_codes_list),
|
| 756 |
+
)
|
| 757 |
+
)
|
| 758 |
+
return generated_messages
|
| 759 |
+
|
| 760 |
+
@staticmethod
|
| 761 |
+
def loudness_normalize(
|
| 762 |
+
wav: torch.Tensor,
|
| 763 |
+
target_dbfs: float = -20.0,
|
| 764 |
+
gain_range: tuple[float, float] = (-3.0, 3.0),
|
| 765 |
+
) -> torch.Tensor:
|
| 766 |
+
wav = wav.to(torch.float32)
|
| 767 |
+
if wav.numel() == 0:
|
| 768 |
+
return wav
|
| 769 |
+
current_dbfs = 10.0 * torch.log10(torch.mean(wav**2) + 1e-9)
|
| 770 |
+
gain = max(gain_range[0], min(float(target_dbfs - current_dbfs), gain_range[1]))
|
| 771 |
+
return wav * (10.0 ** (gain / 20.0))
|
| 772 |
+
|
| 773 |
+
def _get_audio_tokenizer_device(self) -> torch.device:
|
| 774 |
+
audio_tokenizer = getattr(self, "audio_tokenizer", None)
|
| 775 |
+
if audio_tokenizer is None:
|
| 776 |
+
raise RuntimeError("audio_tokenizer is not set.")
|
| 777 |
+
try:
|
| 778 |
+
return next(audio_tokenizer.parameters()).device
|
| 779 |
+
except StopIteration:
|
| 780 |
+
return torch.device("cpu")
|
| 781 |
+
|
| 782 |
+
def encode_audios_from_wav(
|
| 783 |
+
self,
|
| 784 |
+
wav_list: Union[torch.Tensor, list[torch.Tensor]],
|
| 785 |
+
sampling_rate: int,
|
| 786 |
+
n_vq: Optional[int] = None,
|
| 787 |
+
) -> list[torch.Tensor]:
|
| 788 |
+
n_vq = self._assert_fixed_nq(n_vq)
|
| 789 |
+
if self.audio_tokenizer is None:
|
| 790 |
+
raise RuntimeError("audio_tokenizer is not set.")
|
| 791 |
+
if isinstance(wav_list, torch.Tensor):
|
| 792 |
+
wav_list = [wav_list]
|
| 793 |
+
target_sr = int(self.model_config.sampling_rate)
|
| 794 |
+
device = self._get_audio_tokenizer_device()
|
| 795 |
+
prepared = []
|
| 796 |
+
for wav in wav_list:
|
| 797 |
+
if wav.ndim == 1:
|
| 798 |
+
wav = wav.unsqueeze(0)
|
| 799 |
+
if wav.shape[0] == 1:
|
| 800 |
+
wav = wav.repeat(2, 1)
|
| 801 |
+
elif wav.shape[0] > 2:
|
| 802 |
+
wav = wav[:2]
|
| 803 |
+
if int(sampling_rate) != target_sr:
|
| 804 |
+
wav = torchaudio.functional.resample(wav, int(sampling_rate), target_sr)
|
| 805 |
+
prepared.append(self.loudness_normalize(wav).to(device))
|
| 806 |
+
|
| 807 |
+
if hasattr(self.audio_tokenizer, "batch_encode"):
|
| 808 |
+
encoded = self.audio_tokenizer.batch_encode(prepared, num_quantizers=n_vq)
|
| 809 |
+
audio_codes = encoded.audio_codes
|
| 810 |
+
audio_lengths = encoded.audio_codes_lengths
|
| 811 |
+
else:
|
| 812 |
+
max_len = max(int(wav.shape[-1]) for wav in prepared)
|
| 813 |
+
input_values = torch.zeros(len(prepared), 1, max_len, dtype=torch.float32, device=device)
|
| 814 |
+
padding_mask = torch.zeros(len(prepared), max_len, dtype=torch.bool, device=device)
|
| 815 |
+
for index, wav in enumerate(prepared):
|
| 816 |
+
input_values[index, 0, : wav.shape[-1]] = wav
|
| 817 |
+
padding_mask[index, : wav.shape[-1]] = True
|
| 818 |
+
encoded = self.audio_tokenizer.encode(
|
| 819 |
+
input_values,
|
| 820 |
+
padding_mask=padding_mask,
|
| 821 |
+
num_quantizers=n_vq,
|
| 822 |
+
return_dict=True,
|
| 823 |
+
)
|
| 824 |
+
audio_codes = encoded.audio_codes
|
| 825 |
+
audio_lengths = encoded.audio_codes_lengths
|
| 826 |
+
|
| 827 |
+
if audio_codes is None or audio_lengths is None:
|
| 828 |
+
raise RuntimeError("audio_tokenizer did not return audio_codes/audio_codes_lengths.")
|
| 829 |
+
result = []
|
| 830 |
+
for index in range(int(audio_codes.shape[1])):
|
| 831 |
+
length = int(audio_lengths[index].item())
|
| 832 |
+
result.append(audio_codes[:, index, :length].transpose(0, 1).contiguous().cpu().long())
|
| 833 |
+
return result
|
| 834 |
+
|
| 835 |
+
def encode_audios_from_path(
|
| 836 |
+
self,
|
| 837 |
+
wav_path_list: Union[str, os.PathLike, list[Union[str, os.PathLike]]],
|
| 838 |
+
n_vq: Optional[int] = None,
|
| 839 |
+
) -> list[torch.Tensor]:
|
| 840 |
+
if isinstance(wav_path_list, (str, os.PathLike)):
|
| 841 |
+
wav_path_list = [wav_path_list]
|
| 842 |
+
wavs = []
|
| 843 |
+
target_sr = int(self.model_config.sampling_rate)
|
| 844 |
+
for wav_path in wav_path_list:
|
| 845 |
+
wav, sr = torchaudio.load(str(wav_path))
|
| 846 |
+
if int(sr) != target_sr:
|
| 847 |
+
wav = torchaudio.functional.resample(wav, int(sr), target_sr)
|
| 848 |
+
wavs.append(wav)
|
| 849 |
+
return self.encode_audios_from_wav(wavs, target_sr, n_vq=n_vq)
|
| 850 |
+
|
| 851 |
+
def decode_audio_codes(
|
| 852 |
+
self,
|
| 853 |
+
audio_tokens_list: Union[torch.Tensor, list[torch.Tensor]],
|
| 854 |
+
*,
|
| 855 |
+
return_stereo: bool = True,
|
| 856 |
+
) -> list[torch.Tensor]:
|
| 857 |
+
if self.audio_tokenizer is None:
|
| 858 |
+
raise RuntimeError("audio_tokenizer is not set.")
|
| 859 |
+
if isinstance(audio_tokens_list, torch.Tensor):
|
| 860 |
+
audio_tokens_list = [audio_tokens_list]
|
| 861 |
+
if not audio_tokens_list:
|
| 862 |
+
return []
|
| 863 |
+
|
| 864 |
+
n_vq = int(self.model_config.n_vq)
|
| 865 |
+
device = self._get_audio_tokenizer_device()
|
| 866 |
+
codes_list = [
|
| 867 |
+
codes[:, :n_vq].transpose(0, 1).contiguous().to(device=device, dtype=torch.long)
|
| 868 |
+
for codes in audio_tokens_list
|
| 869 |
+
]
|
| 870 |
+
max_len = max(int(codes.shape[1]) for codes in codes_list)
|
| 871 |
+
audio_codes = torch.zeros(n_vq, len(codes_list), max_len, device=device, dtype=torch.long)
|
| 872 |
+
padding_mask = torch.zeros(len(codes_list), max_len, device=device, dtype=torch.bool)
|
| 873 |
+
for index, codes in enumerate(codes_list):
|
| 874 |
+
length = int(codes.shape[1])
|
| 875 |
+
audio_codes[:, index, :length] = codes
|
| 876 |
+
padding_mask[index, :length] = True
|
| 877 |
+
|
| 878 |
+
decoded = self.audio_tokenizer.decode(
|
| 879 |
+
audio_codes,
|
| 880 |
+
padding_mask=padding_mask,
|
| 881 |
+
num_quantizers=n_vq,
|
| 882 |
+
return_dict=True,
|
| 883 |
+
chunk_duration=8,
|
| 884 |
+
)
|
| 885 |
+
audio = decoded.audio
|
| 886 |
+
audio_lengths = decoded.audio_lengths
|
| 887 |
+
if audio is None or audio_lengths is None:
|
| 888 |
+
raise RuntimeError("audio_tokenizer.decode did not return audio/audio_lengths.")
|
| 889 |
+
wavs = []
|
| 890 |
+
for index in range(int(audio.shape[0])):
|
| 891 |
+
length = int(audio_lengths[index].item())
|
| 892 |
+
wav = audio[index, :, :length].contiguous().cpu().to(torch.float32)
|
| 893 |
+
if not return_stereo:
|
| 894 |
+
if wav.shape[0] == 1:
|
| 895 |
+
wav = wav.squeeze(0)
|
| 896 |
+
else:
|
| 897 |
+
wav = wav.mean(dim=0)
|
| 898 |
+
wavs.append(wav)
|
| 899 |
+
return wavs
|
processor_config.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"processor_class": "MossTTSLocalProcessor",
|
| 3 |
+
"audio_tokenizer_name_or_path": "OpenMOSS-Team/MOSS-Audio-Tokenizer-v2",
|
| 4 |
+
"auto_map": {
|
| 5 |
+
"AutoProcessor": "processing_moss_tts.MossTTSLocalProcessor"
|
| 6 |
+
}
|
| 7 |
+
}
|
qwen3_decoder.py
ADDED
|
@@ -0,0 +1,582 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import json
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.utils.checkpoint
|
| 11 |
+
from safetensors.torch import load_file
|
| 12 |
+
from transformers.activations import ACT2FN
|
| 13 |
+
from transformers.modeling_outputs import BaseModelOutputWithPast
|
| 14 |
+
|
| 15 |
+
from .gpt2_decoder import PackedSequenceMetadata, MossTTSNanoGPT2Model
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
| 19 |
+
from flash_attn.bert_padding import pad_input, unpad_input
|
| 20 |
+
|
| 21 |
+
_FLASH_ATTN_AVAILABLE = True
|
| 22 |
+
except Exception:
|
| 23 |
+
flash_attn_func = None
|
| 24 |
+
flash_attn_varlen_func = None
|
| 25 |
+
pad_input = None
|
| 26 |
+
unpad_input = None
|
| 27 |
+
_FLASH_ATTN_AVAILABLE = False
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class MossQwen3RMSNorm(nn.Module):
|
| 31 |
+
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
|
| 32 |
+
super().__init__()
|
| 33 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 34 |
+
self.variance_epsilon = eps
|
| 35 |
+
|
| 36 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 37 |
+
input_dtype = hidden_states.dtype
|
| 38 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 39 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 40 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 41 |
+
return self.weight * hidden_states.to(input_dtype)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class MossQwen3RotaryEmbedding(nn.Module):
|
| 45 |
+
def __init__(self, config) -> None:
|
| 46 |
+
super().__init__()
|
| 47 |
+
head_dim = int(getattr(config, "head_dim", config.hidden_size // config.num_attention_heads))
|
| 48 |
+
rope_theta = getattr(config, "rope_theta", None)
|
| 49 |
+
if rope_theta is None:
|
| 50 |
+
rope_scaling = getattr(config, "rope_scaling", None)
|
| 51 |
+
if isinstance(rope_scaling, dict):
|
| 52 |
+
rope_theta = rope_scaling.get("rope_theta")
|
| 53 |
+
if rope_theta is None:
|
| 54 |
+
rope_theta = 1000000.0
|
| 55 |
+
rope_theta = float(rope_theta)
|
| 56 |
+
self.head_dim = head_dim
|
| 57 |
+
self.rope_theta = rope_theta
|
| 58 |
+
self.register_buffer("inv_freq", self._compute_inv_freq(), persistent=False)
|
| 59 |
+
|
| 60 |
+
def _compute_inv_freq(self, device: Optional[torch.device] = None) -> torch.Tensor:
|
| 61 |
+
return 1.0 / (
|
| 62 |
+
self.rope_theta ** (torch.arange(0, self.head_dim, 2, device=device, dtype=torch.float32) / self.head_dim)
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
def forward(
|
| 66 |
+
self,
|
| 67 |
+
hidden_states: torch.Tensor,
|
| 68 |
+
position_ids: torch.LongTensor,
|
| 69 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 70 |
+
inv_freq = self._compute_inv_freq(device=hidden_states.device)
|
| 71 |
+
freqs = torch.einsum(
|
| 72 |
+
"bs,d->bsd",
|
| 73 |
+
position_ids.to(device=hidden_states.device, dtype=inv_freq.dtype),
|
| 74 |
+
inv_freq,
|
| 75 |
+
)
|
| 76 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 77 |
+
return emb.cos().to(dtype=hidden_states.dtype), emb.sin().to(dtype=hidden_states.dtype)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def rotate_half(hidden_states: torch.Tensor) -> torch.Tensor:
|
| 81 |
+
first_half = hidden_states[..., : hidden_states.shape[-1] // 2]
|
| 82 |
+
second_half = hidden_states[..., hidden_states.shape[-1] // 2 :]
|
| 83 |
+
return torch.cat((-second_half, first_half), dim=-1)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def apply_rotary_pos_emb(
|
| 87 |
+
query: torch.Tensor,
|
| 88 |
+
key: torch.Tensor,
|
| 89 |
+
cos: torch.Tensor,
|
| 90 |
+
sin: torch.Tensor,
|
| 91 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 92 |
+
cos = cos.unsqueeze(-2)
|
| 93 |
+
sin = sin.unsqueeze(-2)
|
| 94 |
+
query = (query * cos) + (rotate_half(query) * sin)
|
| 95 |
+
key = (key * cos) + (rotate_half(key) * sin)
|
| 96 |
+
return query, key
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 100 |
+
if n_rep == 1:
|
| 101 |
+
return hidden_states
|
| 102 |
+
batch, seq_len, num_key_value_heads, head_dim = hidden_states.shape
|
| 103 |
+
hidden_states = hidden_states[:, :, :, None, :].expand(batch, seq_len, num_key_value_heads, n_rep, head_dim)
|
| 104 |
+
return hidden_states.reshape(batch, seq_len, num_key_value_heads * n_rep, head_dim)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class MossQwen3MLP(nn.Module):
|
| 108 |
+
def __init__(self, config) -> None:
|
| 109 |
+
super().__init__()
|
| 110 |
+
self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
|
| 111 |
+
self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
|
| 112 |
+
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
|
| 113 |
+
self.act_fn = ACT2FN[getattr(config, "hidden_act", "silu")]
|
| 114 |
+
|
| 115 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 116 |
+
return self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class MossQwen3Attention(nn.Module):
|
| 120 |
+
def __init__(self, config, layer_idx: int) -> None:
|
| 121 |
+
super().__init__()
|
| 122 |
+
self.config = config
|
| 123 |
+
self.layer_idx = int(layer_idx)
|
| 124 |
+
self.hidden_size = int(config.hidden_size)
|
| 125 |
+
self.num_heads = int(config.num_attention_heads)
|
| 126 |
+
self.num_key_value_heads = int(config.num_key_value_heads)
|
| 127 |
+
self.head_dim = int(getattr(config, "head_dim", self.hidden_size // self.num_heads))
|
| 128 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
| 129 |
+
self.scaling = self.head_dim ** -0.5
|
| 130 |
+
self.attention_dropout = float(getattr(config, "attention_dropout", 0.0))
|
| 131 |
+
self.attn_implementation = str(getattr(config, "_attn_implementation", "eager"))
|
| 132 |
+
|
| 133 |
+
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=bool(config.attention_bias))
|
| 134 |
+
self.k_proj = nn.Linear(
|
| 135 |
+
self.hidden_size,
|
| 136 |
+
self.num_key_value_heads * self.head_dim,
|
| 137 |
+
bias=bool(config.attention_bias),
|
| 138 |
+
)
|
| 139 |
+
self.v_proj = nn.Linear(
|
| 140 |
+
self.hidden_size,
|
| 141 |
+
self.num_key_value_heads * self.head_dim,
|
| 142 |
+
bias=bool(config.attention_bias),
|
| 143 |
+
)
|
| 144 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=bool(config.attention_bias))
|
| 145 |
+
self.q_norm = MossQwen3RMSNorm(self.head_dim, eps=float(config.rms_norm_eps))
|
| 146 |
+
self.k_norm = MossQwen3RMSNorm(self.head_dim, eps=float(config.rms_norm_eps))
|
| 147 |
+
|
| 148 |
+
def _causal_attention_mask(
|
| 149 |
+
self,
|
| 150 |
+
attention_mask: Optional[torch.Tensor],
|
| 151 |
+
query_length: int,
|
| 152 |
+
key_length: int,
|
| 153 |
+
device: torch.device,
|
| 154 |
+
) -> torch.Tensor:
|
| 155 |
+
query_positions = torch.arange(query_length, device=device, dtype=torch.long)
|
| 156 |
+
query_positions = query_positions + max(key_length - query_length, 0)
|
| 157 |
+
key_positions = torch.arange(key_length, device=device, dtype=torch.long)
|
| 158 |
+
causal = key_positions.unsqueeze(0) <= query_positions.unsqueeze(1)
|
| 159 |
+
causal = causal.unsqueeze(0).unsqueeze(0)
|
| 160 |
+
if attention_mask is None:
|
| 161 |
+
return causal
|
| 162 |
+
key_mask = attention_mask[:, None, None, :].to(dtype=torch.bool)
|
| 163 |
+
return causal & key_mask
|
| 164 |
+
|
| 165 |
+
def _eager_attention(
|
| 166 |
+
self,
|
| 167 |
+
query: torch.Tensor,
|
| 168 |
+
key: torch.Tensor,
|
| 169 |
+
value: torch.Tensor,
|
| 170 |
+
attention_mask: Optional[torch.Tensor],
|
| 171 |
+
) -> torch.Tensor:
|
| 172 |
+
key = repeat_kv(key, self.num_key_value_groups)
|
| 173 |
+
value = repeat_kv(value, self.num_key_value_groups)
|
| 174 |
+
query = query.transpose(1, 2)
|
| 175 |
+
key = key.transpose(1, 2)
|
| 176 |
+
value = value.transpose(1, 2)
|
| 177 |
+
scores = torch.matmul(query, key.transpose(-1, -2)) * self.scaling
|
| 178 |
+
mask = self._causal_attention_mask(
|
| 179 |
+
attention_mask=attention_mask,
|
| 180 |
+
query_length=query.shape[-2],
|
| 181 |
+
key_length=key.shape[-2],
|
| 182 |
+
device=query.device,
|
| 183 |
+
)
|
| 184 |
+
scores = scores.masked_fill(~mask, torch.finfo(scores.dtype).min)
|
| 185 |
+
probs = torch.softmax(scores, dim=-1)
|
| 186 |
+
if self.training and self.attention_dropout > 0:
|
| 187 |
+
probs = torch.dropout(probs, self.attention_dropout, train=True)
|
| 188 |
+
output = torch.matmul(probs, value)
|
| 189 |
+
return output.transpose(1, 2).contiguous()
|
| 190 |
+
|
| 191 |
+
def _sdpa_attention(
|
| 192 |
+
self,
|
| 193 |
+
query: torch.Tensor,
|
| 194 |
+
key: torch.Tensor,
|
| 195 |
+
value: torch.Tensor,
|
| 196 |
+
attention_mask: Optional[torch.Tensor],
|
| 197 |
+
) -> torch.Tensor:
|
| 198 |
+
key = repeat_kv(key, self.num_key_value_groups)
|
| 199 |
+
value = repeat_kv(value, self.num_key_value_groups)
|
| 200 |
+
query = query.transpose(1, 2)
|
| 201 |
+
key = key.transpose(1, 2)
|
| 202 |
+
value = value.transpose(1, 2)
|
| 203 |
+
mask = None
|
| 204 |
+
if attention_mask is not None or query.shape[-2] != key.shape[-2]:
|
| 205 |
+
mask = self._causal_attention_mask(
|
| 206 |
+
attention_mask=attention_mask,
|
| 207 |
+
query_length=query.shape[-2],
|
| 208 |
+
key_length=key.shape[-2],
|
| 209 |
+
device=query.device,
|
| 210 |
+
)
|
| 211 |
+
output = torch.nn.functional.scaled_dot_product_attention(
|
| 212 |
+
query,
|
| 213 |
+
key,
|
| 214 |
+
value,
|
| 215 |
+
attn_mask=mask,
|
| 216 |
+
dropout_p=self.attention_dropout if self.training else 0.0,
|
| 217 |
+
is_causal=mask is None,
|
| 218 |
+
scale=self.scaling,
|
| 219 |
+
)
|
| 220 |
+
return output.transpose(1, 2).contiguous()
|
| 221 |
+
|
| 222 |
+
def _flash_attention(
|
| 223 |
+
self,
|
| 224 |
+
query: torch.Tensor,
|
| 225 |
+
key: torch.Tensor,
|
| 226 |
+
value: torch.Tensor,
|
| 227 |
+
attention_mask: Optional[torch.Tensor],
|
| 228 |
+
packed_metadata: Optional[PackedSequenceMetadata],
|
| 229 |
+
) -> torch.Tensor:
|
| 230 |
+
if not _FLASH_ATTN_AVAILABLE:
|
| 231 |
+
raise ImportError("flash_attn is not installed, but attn_implementation='flash_attention_2' was requested.")
|
| 232 |
+
if query.device.type != "cuda":
|
| 233 |
+
raise ValueError("flash_attention_2 requires CUDA tensors.")
|
| 234 |
+
if query.dtype not in (torch.float16, torch.bfloat16):
|
| 235 |
+
raise ValueError(f"flash_attention_2 requires fp16/bf16 tensors, got dtype={query.dtype}.")
|
| 236 |
+
|
| 237 |
+
dropout_p = self.attention_dropout if self.training else 0.0
|
| 238 |
+
if packed_metadata is not None:
|
| 239 |
+
if packed_metadata.indices is not None:
|
| 240 |
+
query = query.reshape(-1, self.num_heads, self.head_dim).index_select(0, packed_metadata.indices)
|
| 241 |
+
key = key.reshape(-1, self.num_key_value_heads, self.head_dim).index_select(0, packed_metadata.indices)
|
| 242 |
+
value = value.reshape(-1, self.num_key_value_heads, self.head_dim).index_select(0, packed_metadata.indices)
|
| 243 |
+
output = flash_attn_varlen_func(
|
| 244 |
+
query,
|
| 245 |
+
key,
|
| 246 |
+
value,
|
| 247 |
+
packed_metadata.cu_seqlens,
|
| 248 |
+
packed_metadata.cu_seqlens,
|
| 249 |
+
packed_metadata.max_seqlen,
|
| 250 |
+
packed_metadata.max_seqlen,
|
| 251 |
+
dropout_p=dropout_p,
|
| 252 |
+
causal=True,
|
| 253 |
+
)
|
| 254 |
+
if packed_metadata.indices is None:
|
| 255 |
+
return output
|
| 256 |
+
return pad_input(
|
| 257 |
+
output,
|
| 258 |
+
packed_metadata.indices,
|
| 259 |
+
packed_metadata.batch_size,
|
| 260 |
+
packed_metadata.seq_len,
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
if attention_mask is None or bool(attention_mask.all()):
|
| 264 |
+
return flash_attn_func(query, key, value, dropout_p=dropout_p, causal=True)
|
| 265 |
+
|
| 266 |
+
if query.shape[1] != key.shape[1]:
|
| 267 |
+
query_attention_mask = attention_mask[:, -query.shape[1] :]
|
| 268 |
+
unpadded_query, query_indices, cu_seqlens_q, max_seqlen_q, _ = unpad_input(
|
| 269 |
+
query,
|
| 270 |
+
query_attention_mask,
|
| 271 |
+
)
|
| 272 |
+
unpadded_key, _, cu_seqlens_k, max_seqlen_k, _ = unpad_input(key, attention_mask)
|
| 273 |
+
unpadded_value, _, _, _, _ = unpad_input(value, attention_mask)
|
| 274 |
+
output = flash_attn_varlen_func(
|
| 275 |
+
unpadded_query,
|
| 276 |
+
unpadded_key,
|
| 277 |
+
unpadded_value,
|
| 278 |
+
cu_seqlens_q,
|
| 279 |
+
cu_seqlens_k,
|
| 280 |
+
max_seqlen_q,
|
| 281 |
+
max_seqlen_k,
|
| 282 |
+
dropout_p=dropout_p,
|
| 283 |
+
causal=True,
|
| 284 |
+
)
|
| 285 |
+
return pad_input(output, query_indices, query.shape[0], query.shape[1])
|
| 286 |
+
|
| 287 |
+
unpadded_query, indices, cu_seqlens, max_seqlen, _ = unpad_input(query, attention_mask)
|
| 288 |
+
unpadded_key, _, _, _, _ = unpad_input(key, attention_mask)
|
| 289 |
+
unpadded_value, _, _, _, _ = unpad_input(value, attention_mask)
|
| 290 |
+
output = flash_attn_varlen_func(
|
| 291 |
+
unpadded_query,
|
| 292 |
+
unpadded_key,
|
| 293 |
+
unpadded_value,
|
| 294 |
+
cu_seqlens,
|
| 295 |
+
cu_seqlens,
|
| 296 |
+
max_seqlen,
|
| 297 |
+
max_seqlen,
|
| 298 |
+
dropout_p=dropout_p,
|
| 299 |
+
causal=True,
|
| 300 |
+
)
|
| 301 |
+
return pad_input(output, indices, query.shape[0], query.shape[1])
|
| 302 |
+
|
| 303 |
+
def forward(
|
| 304 |
+
self,
|
| 305 |
+
hidden_states: torch.Tensor,
|
| 306 |
+
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
| 307 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 308 |
+
packed_metadata: Optional[PackedSequenceMetadata] = None,
|
| 309 |
+
layer_past: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
| 310 |
+
use_cache: bool = False,
|
| 311 |
+
) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
|
| 312 |
+
input_shape = hidden_states.shape[:-1]
|
| 313 |
+
query_states = self.q_norm(
|
| 314 |
+
self.q_proj(hidden_states).view(*input_shape, self.num_heads, self.head_dim)
|
| 315 |
+
)
|
| 316 |
+
key_states = self.k_norm(
|
| 317 |
+
self.k_proj(hidden_states).view(*input_shape, self.num_key_value_heads, self.head_dim)
|
| 318 |
+
)
|
| 319 |
+
value_states = self.v_proj(hidden_states).view(*input_shape, self.num_key_value_heads, self.head_dim)
|
| 320 |
+
|
| 321 |
+
cos, sin = position_embeddings
|
| 322 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
| 323 |
+
|
| 324 |
+
if layer_past is not None:
|
| 325 |
+
past_key, past_value = layer_past
|
| 326 |
+
key_states = torch.cat([past_key.to(device=key_states.device, dtype=key_states.dtype), key_states], dim=1)
|
| 327 |
+
value_states = torch.cat(
|
| 328 |
+
[past_value.to(device=value_states.device, dtype=value_states.dtype), value_states],
|
| 329 |
+
dim=1,
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
present = (key_states, value_states) if use_cache else None
|
| 333 |
+
if self.attn_implementation == "flash_attention_2":
|
| 334 |
+
attn_output = self._flash_attention(
|
| 335 |
+
query=query_states,
|
| 336 |
+
key=key_states,
|
| 337 |
+
value=value_states,
|
| 338 |
+
attention_mask=attention_mask,
|
| 339 |
+
packed_metadata=packed_metadata,
|
| 340 |
+
)
|
| 341 |
+
elif self.attn_implementation == "sdpa":
|
| 342 |
+
attn_output = self._sdpa_attention(
|
| 343 |
+
query=query_states,
|
| 344 |
+
key=key_states,
|
| 345 |
+
value=value_states,
|
| 346 |
+
attention_mask=attention_mask,
|
| 347 |
+
)
|
| 348 |
+
else:
|
| 349 |
+
attn_output = self._eager_attention(
|
| 350 |
+
query=query_states,
|
| 351 |
+
key=key_states,
|
| 352 |
+
value=value_states,
|
| 353 |
+
attention_mask=attention_mask,
|
| 354 |
+
)
|
| 355 |
+
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 356 |
+
return self.o_proj(attn_output), present
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
class MossQwen3DecoderLayer(nn.Module):
|
| 360 |
+
def __init__(self, config, layer_idx: int) -> None:
|
| 361 |
+
super().__init__()
|
| 362 |
+
self.self_attn = MossQwen3Attention(config=config, layer_idx=layer_idx)
|
| 363 |
+
self.mlp = MossQwen3MLP(config)
|
| 364 |
+
self.input_layernorm = MossQwen3RMSNorm(config.hidden_size, eps=float(config.rms_norm_eps))
|
| 365 |
+
self.post_attention_layernorm = MossQwen3RMSNorm(config.hidden_size, eps=float(config.rms_norm_eps))
|
| 366 |
+
|
| 367 |
+
def forward(
|
| 368 |
+
self,
|
| 369 |
+
hidden_states: torch.Tensor,
|
| 370 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 371 |
+
packed_metadata: Optional[PackedSequenceMetadata] = None,
|
| 372 |
+
layer_past: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
| 373 |
+
use_cache: bool = False,
|
| 374 |
+
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
|
| 375 |
+
) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
|
| 376 |
+
residual = hidden_states
|
| 377 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 378 |
+
attn_output, present = self.self_attn(
|
| 379 |
+
hidden_states=hidden_states,
|
| 380 |
+
position_embeddings=position_embeddings,
|
| 381 |
+
attention_mask=attention_mask,
|
| 382 |
+
packed_metadata=packed_metadata,
|
| 383 |
+
layer_past=layer_past,
|
| 384 |
+
use_cache=use_cache,
|
| 385 |
+
)
|
| 386 |
+
hidden_states = residual + attn_output
|
| 387 |
+
|
| 388 |
+
residual = hidden_states
|
| 389 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 390 |
+
hidden_states = residual + self.mlp(hidden_states)
|
| 391 |
+
return hidden_states, present
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
class MossQwen3Model(nn.Module):
|
| 395 |
+
def __init__(self, config) -> None:
|
| 396 |
+
super().__init__()
|
| 397 |
+
self.config = config
|
| 398 |
+
self.attn_implementation = str(getattr(config, "_attn_implementation", "eager"))
|
| 399 |
+
self.padding_idx = getattr(config, "pad_token_id", None)
|
| 400 |
+
self.vocab_size = int(config.vocab_size)
|
| 401 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 402 |
+
self.layers = nn.ModuleList(
|
| 403 |
+
[MossQwen3DecoderLayer(config, layer_idx=index) for index in range(config.num_hidden_layers)]
|
| 404 |
+
)
|
| 405 |
+
self.norm = MossQwen3RMSNorm(config.hidden_size, eps=float(config.rms_norm_eps))
|
| 406 |
+
self.rotary_emb = MossQwen3RotaryEmbedding(config)
|
| 407 |
+
self.gradient_checkpointing = False
|
| 408 |
+
self.gradient_checkpointing_use_reentrant = bool(
|
| 409 |
+
getattr(config, "gradient_checkpointing_use_reentrant", False)
|
| 410 |
+
)
|
| 411 |
+
self._reset_parameters()
|
| 412 |
+
|
| 413 |
+
def _reset_parameters(self) -> None:
|
| 414 |
+
init_std = float(getattr(self.config, "initializer_range", 0.02))
|
| 415 |
+
for module in self.modules():
|
| 416 |
+
if isinstance(module, nn.Linear):
|
| 417 |
+
nn.init.normal_(module.weight, mean=0.0, std=init_std)
|
| 418 |
+
if module.bias is not None:
|
| 419 |
+
nn.init.zeros_(module.bias)
|
| 420 |
+
elif isinstance(module, nn.Embedding):
|
| 421 |
+
nn.init.normal_(module.weight, mean=0.0, std=init_std)
|
| 422 |
+
|
| 423 |
+
def get_input_embeddings(self):
|
| 424 |
+
return self.embed_tokens
|
| 425 |
+
|
| 426 |
+
def set_input_embeddings(self, value):
|
| 427 |
+
self.embed_tokens = value
|
| 428 |
+
|
| 429 |
+
def load_qwen3_pretrained_weights(self, pretrained_path: str) -> None:
|
| 430 |
+
model_dir = Path(pretrained_path)
|
| 431 |
+
index_path = model_dir / "model.safetensors.index.json"
|
| 432 |
+
if not index_path.exists():
|
| 433 |
+
raise FileNotFoundError(f"Missing Qwen3 safetensors index: {index_path}")
|
| 434 |
+
with index_path.open("r", encoding="utf-8") as handle:
|
| 435 |
+
index = json.load(handle)
|
| 436 |
+
weight_map = index.get("weight_map", {})
|
| 437 |
+
shard_to_keys: dict[str, list[str]] = {}
|
| 438 |
+
for key, shard in weight_map.items():
|
| 439 |
+
if not key.startswith("model."):
|
| 440 |
+
continue
|
| 441 |
+
shard_to_keys.setdefault(str(shard), []).append(key)
|
| 442 |
+
|
| 443 |
+
state_dict = self.state_dict()
|
| 444 |
+
loaded_state = {}
|
| 445 |
+
for shard, keys in sorted(shard_to_keys.items()):
|
| 446 |
+
shard_tensors = load_file(str(model_dir / shard), device="cpu")
|
| 447 |
+
for key in keys:
|
| 448 |
+
target_key = key[len("model.") :]
|
| 449 |
+
if target_key not in state_dict:
|
| 450 |
+
continue
|
| 451 |
+
tensor = shard_tensors[key]
|
| 452 |
+
if tuple(tensor.shape) != tuple(state_dict[target_key].shape):
|
| 453 |
+
raise ValueError(
|
| 454 |
+
f"Shape mismatch while loading Qwen3 weight {key}: "
|
| 455 |
+
f"checkpoint={tuple(tensor.shape)} model={tuple(state_dict[target_key].shape)}"
|
| 456 |
+
)
|
| 457 |
+
loaded_state[target_key] = tensor
|
| 458 |
+
|
| 459 |
+
missing, unexpected = self.load_state_dict(loaded_state, strict=False)
|
| 460 |
+
unexpected = [key for key in unexpected if key]
|
| 461 |
+
if unexpected:
|
| 462 |
+
raise RuntimeError(f"Unexpected Qwen3 pretrained keys after load: {unexpected[:10]}")
|
| 463 |
+
missing = [key for key in missing if key not in loaded_state]
|
| 464 |
+
if missing:
|
| 465 |
+
raise RuntimeError(f"Missing Qwen3 pretrained keys after load: {missing[:10]}")
|
| 466 |
+
|
| 467 |
+
def forward(
|
| 468 |
+
self,
|
| 469 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 470 |
+
past_key_values: Optional[tuple[tuple[torch.Tensor, torch.Tensor], ...]] = None,
|
| 471 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 472 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 473 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 474 |
+
use_cache: Optional[bool] = None,
|
| 475 |
+
output_attentions: Optional[bool] = None,
|
| 476 |
+
output_hidden_states: Optional[bool] = None,
|
| 477 |
+
return_dict: bool = True,
|
| 478 |
+
cu_seqlens: Optional[torch.Tensor] = None,
|
| 479 |
+
num_sequences: Optional[torch.Tensor] = None,
|
| 480 |
+
) -> BaseModelOutputWithPast:
|
| 481 |
+
del input_ids, output_attentions
|
| 482 |
+
if inputs_embeds is None:
|
| 483 |
+
raise ValueError("inputs_embeds must be provided.")
|
| 484 |
+
|
| 485 |
+
use_cache = bool(use_cache)
|
| 486 |
+
if use_cache and cu_seqlens is not None:
|
| 487 |
+
raise ValueError("use_cache=True is not supported together with cu_seqlens packing.")
|
| 488 |
+
|
| 489 |
+
hidden_states = inputs_embeds
|
| 490 |
+
if attention_mask is None:
|
| 491 |
+
attention_mask = torch.ones(hidden_states.shape[:2], dtype=torch.bool, device=hidden_states.device)
|
| 492 |
+
else:
|
| 493 |
+
attention_mask = attention_mask.to(dtype=torch.bool, device=hidden_states.device)
|
| 494 |
+
query_attention_mask = attention_mask[:, -hidden_states.shape[1] :]
|
| 495 |
+
|
| 496 |
+
packed_metadata = None
|
| 497 |
+
if position_ids is None:
|
| 498 |
+
if cu_seqlens is not None:
|
| 499 |
+
position_ids = MossTTSNanoGPT2Model.build_packed_position_ids(
|
| 500 |
+
attention_mask=attention_mask,
|
| 501 |
+
cu_seqlens=cu_seqlens.to(device=hidden_states.device),
|
| 502 |
+
num_sequences=num_sequences.to(device=hidden_states.device) if num_sequences is not None else None,
|
| 503 |
+
sequence_length=hidden_states.shape[1],
|
| 504 |
+
)
|
| 505 |
+
elif attention_mask is not None:
|
| 506 |
+
position_ids = attention_mask.long().cumsum(dim=-1) - 1
|
| 507 |
+
position_ids = position_ids.masked_fill(~attention_mask, 0)
|
| 508 |
+
position_ids = position_ids[:, -hidden_states.shape[1] :]
|
| 509 |
+
else:
|
| 510 |
+
past_length = 0
|
| 511 |
+
if past_key_values is not None and len(past_key_values) > 0:
|
| 512 |
+
past_length = past_key_values[0][0].shape[1]
|
| 513 |
+
position_ids = torch.arange(hidden_states.shape[1], device=hidden_states.device, dtype=torch.long)
|
| 514 |
+
position_ids = position_ids + past_length
|
| 515 |
+
position_ids = position_ids.unsqueeze(0).expand(hidden_states.shape[0], -1)
|
| 516 |
+
|
| 517 |
+
if cu_seqlens is not None and self.attn_implementation == "flash_attention_2":
|
| 518 |
+
packed_metadata = MossTTSNanoGPT2Model.build_packed_metadata(
|
| 519 |
+
hidden_states=hidden_states,
|
| 520 |
+
cu_seqlens=cu_seqlens.to(device=hidden_states.device),
|
| 521 |
+
num_sequences=num_sequences.to(device=hidden_states.device) if num_sequences is not None else None,
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
hidden_states = hidden_states * query_attention_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
|
| 525 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 526 |
+
|
| 527 |
+
all_hidden_states = () if output_hidden_states else None
|
| 528 |
+
presents = [] if use_cache else None
|
| 529 |
+
for layer_index, decoder_layer in enumerate(self.layers):
|
| 530 |
+
if output_hidden_states:
|
| 531 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 532 |
+
|
| 533 |
+
if self.gradient_checkpointing and self.training:
|
| 534 |
+
if use_cache:
|
| 535 |
+
raise ValueError("use_cache=True is not supported when gradient checkpointing is enabled during training.")
|
| 536 |
+
|
| 537 |
+
def custom_forward(*inputs):
|
| 538 |
+
output, _ = decoder_layer(
|
| 539 |
+
hidden_states=inputs[0],
|
| 540 |
+
attention_mask=inputs[1],
|
| 541 |
+
packed_metadata=packed_metadata,
|
| 542 |
+
layer_past=None,
|
| 543 |
+
use_cache=False,
|
| 544 |
+
position_embeddings=position_embeddings,
|
| 545 |
+
)
|
| 546 |
+
return output
|
| 547 |
+
|
| 548 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 549 |
+
custom_forward,
|
| 550 |
+
hidden_states,
|
| 551 |
+
attention_mask,
|
| 552 |
+
use_reentrant=self.gradient_checkpointing_use_reentrant,
|
| 553 |
+
)
|
| 554 |
+
present = None
|
| 555 |
+
else:
|
| 556 |
+
hidden_states, present = decoder_layer(
|
| 557 |
+
hidden_states=hidden_states,
|
| 558 |
+
attention_mask=attention_mask,
|
| 559 |
+
packed_metadata=packed_metadata,
|
| 560 |
+
layer_past=None if past_key_values is None else past_key_values[layer_index],
|
| 561 |
+
use_cache=use_cache,
|
| 562 |
+
position_embeddings=position_embeddings,
|
| 563 |
+
)
|
| 564 |
+
|
| 565 |
+
hidden_states = hidden_states * query_attention_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
|
| 566 |
+
if presents is not None:
|
| 567 |
+
presents.append(present)
|
| 568 |
+
|
| 569 |
+
hidden_states = self.norm(hidden_states)
|
| 570 |
+
hidden_states = hidden_states * query_attention_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
|
| 571 |
+
if output_hidden_states:
|
| 572 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 573 |
+
|
| 574 |
+
if not return_dict:
|
| 575 |
+
return (hidden_states, tuple(presents) if presents is not None else None, all_hidden_states, None)
|
| 576 |
+
|
| 577 |
+
return BaseModelOutputWithPast(
|
| 578 |
+
last_hidden_state=hidden_states,
|
| 579 |
+
past_key_values=tuple(presents) if presents is not None else None,
|
| 580 |
+
hidden_states=all_hidden_states,
|
| 581 |
+
attentions=None,
|
| 582 |
+
)
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"additional_special_tokens": [
|
| 3 |
+
"<|audio_start|>",
|
| 4 |
+
"<|audio_end|>",
|
| 5 |
+
"<|audio_pad|>"
|
| 6 |
+
],
|
| 7 |
+
"eos_token": {
|
| 8 |
+
"content": "<|im_end|>",
|
| 9 |
+
"lstrip": false,
|
| 10 |
+
"normalized": false,
|
| 11 |
+
"rstrip": false,
|
| 12 |
+
"single_word": false
|
| 13 |
+
},
|
| 14 |
+
"pad_token": {
|
| 15 |
+
"content": "<|endoftext|>",
|
| 16 |
+
"lstrip": false,
|
| 17 |
+
"normalized": false,
|
| 18 |
+
"rstrip": false,
|
| 19 |
+
"single_word": false
|
| 20 |
+
}
|
| 21 |
+
}
|
tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:06902d1fb775216338802205886a24bc715ccc606bd872a892e3d3c83ca1b9e2
|
| 3 |
+
size 11423220
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_bos_token": false,
|
| 3 |
+
"add_prefix_space": false,
|
| 4 |
+
"added_tokens_decoder": {
|
| 5 |
+
"151643": {
|
| 6 |
+
"content": "<|endoftext|>",
|
| 7 |
+
"lstrip": false,
|
| 8 |
+
"normalized": false,
|
| 9 |
+
"rstrip": false,
|
| 10 |
+
"single_word": false,
|
| 11 |
+
"special": true
|
| 12 |
+
},
|
| 13 |
+
"151644": {
|
| 14 |
+
"content": "<|im_start|>",
|
| 15 |
+
"lstrip": false,
|
| 16 |
+
"normalized": false,
|
| 17 |
+
"rstrip": false,
|
| 18 |
+
"single_word": false,
|
| 19 |
+
"special": true
|
| 20 |
+
},
|
| 21 |
+
"151645": {
|
| 22 |
+
"content": "<|im_end|>",
|
| 23 |
+
"lstrip": false,
|
| 24 |
+
"normalized": false,
|
| 25 |
+
"rstrip": false,
|
| 26 |
+
"single_word": false,
|
| 27 |
+
"special": true
|
| 28 |
+
},
|
| 29 |
+
"151646": {
|
| 30 |
+
"content": "<|object_ref_start|>",
|
| 31 |
+
"lstrip": false,
|
| 32 |
+
"normalized": false,
|
| 33 |
+
"rstrip": false,
|
| 34 |
+
"single_word": false,
|
| 35 |
+
"special": true
|
| 36 |
+
},
|
| 37 |
+
"151647": {
|
| 38 |
+
"content": "<|object_ref_end|>",
|
| 39 |
+
"lstrip": false,
|
| 40 |
+
"normalized": false,
|
| 41 |
+
"rstrip": false,
|
| 42 |
+
"single_word": false,
|
| 43 |
+
"special": true
|
| 44 |
+
},
|
| 45 |
+
"151648": {
|
| 46 |
+
"content": "<|box_start|>",
|
| 47 |
+
"lstrip": false,
|
| 48 |
+
"normalized": false,
|
| 49 |
+
"rstrip": false,
|
| 50 |
+
"single_word": false,
|
| 51 |
+
"special": true
|
| 52 |
+
},
|
| 53 |
+
"151649": {
|
| 54 |
+
"content": "<|box_end|>",
|
| 55 |
+
"lstrip": false,
|
| 56 |
+
"normalized": false,
|
| 57 |
+
"rstrip": false,
|
| 58 |
+
"single_word": false,
|
| 59 |
+
"special": true
|
| 60 |
+
},
|
| 61 |
+
"151650": {
|
| 62 |
+
"content": "<|quad_start|>",
|
| 63 |
+
"lstrip": false,
|
| 64 |
+
"normalized": false,
|
| 65 |
+
"rstrip": false,
|
| 66 |
+
"single_word": false,
|
| 67 |
+
"special": true
|
| 68 |
+
},
|
| 69 |
+
"151651": {
|
| 70 |
+
"content": "<|quad_end|>",
|
| 71 |
+
"lstrip": false,
|
| 72 |
+
"normalized": false,
|
| 73 |
+
"rstrip": false,
|
| 74 |
+
"single_word": false,
|
| 75 |
+
"special": true
|
| 76 |
+
},
|
| 77 |
+
"151652": {
|
| 78 |
+
"content": "<|vision_start|>",
|
| 79 |
+
"lstrip": false,
|
| 80 |
+
"normalized": false,
|
| 81 |
+
"rstrip": false,
|
| 82 |
+
"single_word": false,
|
| 83 |
+
"special": true
|
| 84 |
+
},
|
| 85 |
+
"151653": {
|
| 86 |
+
"content": "<|vision_end|>",
|
| 87 |
+
"lstrip": false,
|
| 88 |
+
"normalized": false,
|
| 89 |
+
"rstrip": false,
|
| 90 |
+
"single_word": false,
|
| 91 |
+
"special": true
|
| 92 |
+
},
|
| 93 |
+
"151654": {
|
| 94 |
+
"content": "<|vision_pad|>",
|
| 95 |
+
"lstrip": false,
|
| 96 |
+
"normalized": false,
|
| 97 |
+
"rstrip": false,
|
| 98 |
+
"single_word": false,
|
| 99 |
+
"special": true
|
| 100 |
+
},
|
| 101 |
+
"151655": {
|
| 102 |
+
"content": "<|image_pad|>",
|
| 103 |
+
"lstrip": false,
|
| 104 |
+
"normalized": false,
|
| 105 |
+
"rstrip": false,
|
| 106 |
+
"single_word": false,
|
| 107 |
+
"special": true
|
| 108 |
+
},
|
| 109 |
+
"151656": {
|
| 110 |
+
"content": "<|video_pad|>",
|
| 111 |
+
"lstrip": false,
|
| 112 |
+
"normalized": false,
|
| 113 |
+
"rstrip": false,
|
| 114 |
+
"single_word": false,
|
| 115 |
+
"special": true
|
| 116 |
+
},
|
| 117 |
+
"151657": {
|
| 118 |
+
"content": "<tool_call>",
|
| 119 |
+
"lstrip": false,
|
| 120 |
+
"normalized": false,
|
| 121 |
+
"rstrip": false,
|
| 122 |
+
"single_word": false,
|
| 123 |
+
"special": false
|
| 124 |
+
},
|
| 125 |
+
"151658": {
|
| 126 |
+
"content": "</tool_call>",
|
| 127 |
+
"lstrip": false,
|
| 128 |
+
"normalized": false,
|
| 129 |
+
"rstrip": false,
|
| 130 |
+
"single_word": false,
|
| 131 |
+
"special": false
|
| 132 |
+
},
|
| 133 |
+
"151659": {
|
| 134 |
+
"content": "<|fim_prefix|>",
|
| 135 |
+
"lstrip": false,
|
| 136 |
+
"normalized": false,
|
| 137 |
+
"rstrip": false,
|
| 138 |
+
"single_word": false,
|
| 139 |
+
"special": false
|
| 140 |
+
},
|
| 141 |
+
"151660": {
|
| 142 |
+
"content": "<|fim_middle|>",
|
| 143 |
+
"lstrip": false,
|
| 144 |
+
"normalized": false,
|
| 145 |
+
"rstrip": false,
|
| 146 |
+
"single_word": false,
|
| 147 |
+
"special": false
|
| 148 |
+
},
|
| 149 |
+
"151661": {
|
| 150 |
+
"content": "<|fim_suffix|>",
|
| 151 |
+
"lstrip": false,
|
| 152 |
+
"normalized": false,
|
| 153 |
+
"rstrip": false,
|
| 154 |
+
"single_word": false,
|
| 155 |
+
"special": false
|
| 156 |
+
},
|
| 157 |
+
"151662": {
|
| 158 |
+
"content": "<|fim_pad|>",
|
| 159 |
+
"lstrip": false,
|
| 160 |
+
"normalized": false,
|
| 161 |
+
"rstrip": false,
|
| 162 |
+
"single_word": false,
|
| 163 |
+
"special": false
|
| 164 |
+
},
|
| 165 |
+
"151663": {
|
| 166 |
+
"content": "<|repo_name|>",
|
| 167 |
+
"lstrip": false,
|
| 168 |
+
"normalized": false,
|
| 169 |
+
"rstrip": false,
|
| 170 |
+
"single_word": false,
|
| 171 |
+
"special": false
|
| 172 |
+
},
|
| 173 |
+
"151664": {
|
| 174 |
+
"content": "<|file_sep|>",
|
| 175 |
+
"lstrip": false,
|
| 176 |
+
"normalized": false,
|
| 177 |
+
"rstrip": false,
|
| 178 |
+
"single_word": false,
|
| 179 |
+
"special": false
|
| 180 |
+
},
|
| 181 |
+
"151665": {
|
| 182 |
+
"content": "<tool_response>",
|
| 183 |
+
"lstrip": false,
|
| 184 |
+
"normalized": false,
|
| 185 |
+
"rstrip": false,
|
| 186 |
+
"single_word": false,
|
| 187 |
+
"special": false
|
| 188 |
+
},
|
| 189 |
+
"151666": {
|
| 190 |
+
"content": "</tool_response>",
|
| 191 |
+
"lstrip": false,
|
| 192 |
+
"normalized": false,
|
| 193 |
+
"rstrip": false,
|
| 194 |
+
"single_word": false,
|
| 195 |
+
"special": false
|
| 196 |
+
},
|
| 197 |
+
"151667": {
|
| 198 |
+
"content": "<think>",
|
| 199 |
+
"lstrip": false,
|
| 200 |
+
"normalized": false,
|
| 201 |
+
"rstrip": false,
|
| 202 |
+
"single_word": false,
|
| 203 |
+
"special": false
|
| 204 |
+
},
|
| 205 |
+
"151668": {
|
| 206 |
+
"content": "</think>",
|
| 207 |
+
"lstrip": false,
|
| 208 |
+
"normalized": false,
|
| 209 |
+
"rstrip": false,
|
| 210 |
+
"single_word": false,
|
| 211 |
+
"special": false
|
| 212 |
+
},
|
| 213 |
+
"151669": {
|
| 214 |
+
"content": "<|audio_start|>",
|
| 215 |
+
"lstrip": false,
|
| 216 |
+
"normalized": false,
|
| 217 |
+
"rstrip": false,
|
| 218 |
+
"single_word": false,
|
| 219 |
+
"special": true
|
| 220 |
+
},
|
| 221 |
+
"151670": {
|
| 222 |
+
"content": "<|audio_end|>",
|
| 223 |
+
"lstrip": false,
|
| 224 |
+
"normalized": false,
|
| 225 |
+
"rstrip": false,
|
| 226 |
+
"single_word": false,
|
| 227 |
+
"special": true
|
| 228 |
+
},
|
| 229 |
+
"151671": {
|
| 230 |
+
"content": "<|audio_pad|>",
|
| 231 |
+
"lstrip": false,
|
| 232 |
+
"normalized": false,
|
| 233 |
+
"rstrip": false,
|
| 234 |
+
"single_word": false,
|
| 235 |
+
"special": true
|
| 236 |
+
}
|
| 237 |
+
},
|
| 238 |
+
"additional_special_tokens": [
|
| 239 |
+
"<|audio_start|>",
|
| 240 |
+
"<|audio_end|>",
|
| 241 |
+
"<|audio_pad|>"
|
| 242 |
+
],
|
| 243 |
+
"bos_token": null,
|
| 244 |
+
"clean_up_tokenization_spaces": false,
|
| 245 |
+
"eos_token": "<|im_end|>",
|
| 246 |
+
"errors": "replace",
|
| 247 |
+
"extra_special_tokens": {},
|
| 248 |
+
"model_max_length": 131072,
|
| 249 |
+
"pad_token": "<|endoftext|>",
|
| 250 |
+
"split_special_tokens": false,
|
| 251 |
+
"tokenizer_class": "Qwen2Tokenizer",
|
| 252 |
+
"unk_token": null
|
| 253 |
+
}
|
vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|