Commit 2b6d647 (verified) by fdugyt · 1 parent: d594ff3

Add files using upload-large-folder tool

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,281 @@
1
  ---
2
  license: apache-2.0
3
+ library_name: transformers
4
+ pipeline_tag: text-to-speech
5
+ tags:
6
+ - text-to-speech
7
+ - voice-cloning
8
+ - custom_code
9
+ - moss-tts
10
+ - moss-tts-local
11
+ - arxiv:2603.18090
12
+ language:
13
+ - zh
14
+ - yue
15
+ - en
16
+ - ar
17
+ - cs
18
+ - da
19
+ - de
20
+ - nl
21
+ - es
22
+ - fr
23
+ - fi
24
+ - el
25
+ - he
26
+ - hi
27
+ - hu
28
+ - ja
29
+ - it
30
+ - ko
31
+ - mk
32
+ - ms
33
+ - ru
34
+ - fa
35
+ - pl
36
+ - pt
37
+ - sv
38
+ - ro
39
+ - sw
40
+ - tl
41
+ - th
42
+ - tr
43
+ - vi
44
  ---
45
+ # MOSS-TTS Family
46
+
47
+ <br>
48
+
49
+ <p align="center">
50
+ &nbsp;&nbsp;&nbsp;&nbsp;
51
+ <img src="https://speech-demo.oss-cn-shanghai.aliyuncs.com/moss_tts_demo/tts_readme_imgaes_demo/openmoss_x_mosi" height="50" align="middle" />
52
+ </p>
53
+
54
+ <div align="center">
55
+ <a href="https://github.com/OpenMOSS/MOSS-TTS/tree/main"><img src="https://img.shields.io/badge/Project%20Page-GitHub-blue"></a>
56
+ <a href="https://huggingface.co/OpenMOSS-Team/MOSS-TTS-Local-Transformer-v1.5"><img src="https://img.shields.io/badge/HuggingFace-Model-yellow?logo=huggingface"></a>
57
+ <a href="https://modelscope.cn/collections/OpenMOSS-Team/MOSS-TTS"><img src="https://img.shields.io/badge/ModelScope-Models-lightgrey?logo=modelscope&amp"></a>
58
+ <a href="https://mosi.cn/#models"><img src="https://img.shields.io/badge/Blog-View-blue?logo=internet-explorer&amp"></a>
59
+
60
+ <a href="https://arxiv.org/abs/2603.18090"><img src="https://img.shields.io/badge/Arxiv-2603.18090-red?logo=Arxiv&amp"></a>
61
+ <a href="https://studio.mosi.cn"><img src="https://img.shields.io/badge/AIStudio-Try-green?logo=internet-explorer&amp"></a>
62
+ <a href="https://studio.mosi.cn/docs/moss-tts"><img src="https://img.shields.io/badge/API-Docs-00A3FF?logo=fastapi&amp"></a>
63
+ <a href="https://x.com/Open_MOSS"><img src="https://img.shields.io/badge/Twitter-Follow-black?logo=x&amp"></a>
64
+ <a href="https://discord.gg/fvm5TaWjU3"><img src="https://img.shields.io/badge/Discord-Join-5865F2?logo=discord&amp"></a>
65
+ </div>
66
+
67
+ # MOSS-TTS-Local-Transformer-v1.5
68
+
69
+ **MOSS-TTS-Local-Transformer-v1.5** is a continued-training release of [MOSS-TTS-Local-Transformer-v1.0](https://huggingface.co/OpenMOSS-Team/MOSS-TTS-Local-Transformer). It preserves the main v1.0 capabilities: zero-shot voice cloning, long-form speech generation, token-level duration control, Pinyin/IPA pronunciation control, multilingual synthesis, and code-switching. For the full v1.0 feature walkthrough, input schema, and evaluation tables, see the [MOSS-TTS-Local-Transformer-v1.0 README](https://huggingface.co/OpenMOSS-Team/MOSS-TTS-Local-Transformer).
70
+
71
+ Compared with [MOSS-TTS-Local-Transformer-v1.0](https://huggingface.co/OpenMOSS-Team/MOSS-TTS-Local-Transformer), v1.5 focuses on the following improvements:
72
+ - **Higher-fidelity stereo audio modeling**: v1.5 uses [MOSS-Audio-Tokenizer-v2](https://huggingface.co/OpenMOSS-Team/MOSS-Audio-Tokenizer-v2) as the audio tokenizer, supporting native 48 kHz stereo input and output for richer spatial detail and more natural perceived audio quality. Since the codec output is stereo, save the `[channels, samples]` tensor returned by `processor.decode(...)` directly.
73
+ - **Stronger multilingual synthesis with language tags**: when the `language` field is omitted, v1.5 may improve some languages and regress slightly on others compared with 1.0. When the language is specified, v1.5 is stronger than 1.0 on almost all supported languages. Set the tag when building the user message, for example `processor.build_user_message(text=text_fr, language="French")`.
74
+ - **More stable voice cloning**: v1.5 improves speaker similarity and reduces cloning variance, making repeated generations more consistent.
75
+ - **Better long-reference, short-text cloning**: v1.5 handles scenarios where the reference audio is much longer than the target text more reliably than 1.0.
76
+ - **More stable punctuation-following prosody**: v1.5 follows punctuation-driven pauses more closely, especially in long sentences.
77
+ - **Explicit pause control**: v1.5 supports inline pause markers such as `"[pause 3.2s]"`. For example, `我今天学习了一首中国的古诗,它的名字是[pause 3.2s]静夜思!` inserts an explicit 3.2s pause before `静夜思`.
78
+
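The pause-marker syntax described above can be sketched in a few lines of Python. This is only an illustration: the helper names `pause_marker` and `find_pauses` are hypothetical, and the marker grammar in `PAUSE_RE` is inferred from the single `[pause 3.2s]` example rather than taken from the model's processor.

```python
import re

# Assumed marker grammar, inferred from the README example "[pause 3.2s]".
PAUSE_RE = re.compile(r"\[pause (\d+(?:\.\d+)?)s\]")


def pause_marker(seconds: float) -> str:
    """Format a pause marker with one decimal place (hypothetical helper)."""
    if seconds <= 0:
        raise ValueError("pause length must be positive")
    return f"[pause {seconds:.1f}s]"


def find_pauses(text: str) -> list[float]:
    """Return the pause lengths, in seconds, of every marker found in text."""
    return [float(match.group(1)) for match in PAUSE_RE.finditer(text)]


# Rebuild the README example sentence from its two halves.
text_pause = "我今天学习了一首中国的古诗,它的名字是" + pause_marker(3.2) + "静夜思!"
pauses = find_pauses(text_pause)
```

Checking the markers before calling `processor.build_user_message(...)` is a cheap way to catch typos such as a missing `s` suffix, which would otherwise presumably be read aloud as ordinary text.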
79
+ ## Supported Languages
80
+
81
+ MOSS-TTS-Local-Transformer-v1.5 supports **31 languages**. It keeps the 20 languages supported by [MOSS-TTS-Local-Transformer-v1.0](https://huggingface.co/OpenMOSS-Team/MOSS-TTS-Local-Transformer) and adds 11 more through multilingual continued training: Cantonese, Dutch, Finnish, Hindi, Macedonian, Malay, Romanian, Swahili, Tagalog, Thai, and Vietnamese.
82
+
83
+ | Language | Code | Flag | Language | Code | Flag | Language | Code | Flag |
84
+ |---|---|---|---|---|---|---|---|---|
85
+ | Chinese | zh | 🇨🇳 | Cantonese | yue | 🇭🇰 | English | en | 🇺🇸 |
86
+ | Arabic | ar | 🇸🇦 | Czech | cs | 🇨🇿 | Danish | da | 🇩🇰 |
87
+ | Dutch | nl | 🇳🇱 | Finnish | fi | 🇫🇮 | French | fr | 🇫🇷 |
88
+ | German | de | 🇩🇪 | Greek | el | 🇬🇷 | Hebrew | he | 🇮🇱 |
89
+ | Hindi | hi | 🇮🇳 | Hungarian | hu | 🇭🇺 | Italian | it | 🇮🇹 |
90
+ | Japanese | ja | 🇯🇵 | Korean | ko | 🇰🇷 | Macedonian | mk | 🇲🇰 |
91
+ | Malay | ms | 🇲🇾 | Persian (Farsi) | fa | 🇮🇷 | Polish | pl | 🇵🇱 |
92
+ | Portuguese | pt | 🇵🇹 | Romanian | ro | 🇷🇴 | Russian | ru | 🇷🇺 |
93
+ | Spanish | es | 🇪🇸 | Swahili | sw | 🇹🇿 | Swedish | sv | 🇸🇪 |
94
+ | Tagalog | tl | 🇵🇭 | Thai | th | 🇹🇭 | Turkish | tr | 🇹🇷 |
95
+ | Vietnamese | vi | 🇻🇳 | | | | | | |
96
+
97
+ ## Quick Start
98
+
99
+ ### Environment Setup
100
+
101
+ We recommend a clean, isolated Python environment with **Transformers 5.0.0**, or a recent Transformers version with Qwen3 support, to avoid dependency conflicts.
102
+
103
+ ```bash
104
+ conda create -n moss-tts python=3.12 -y
105
+ conda activate moss-tts
106
+ ```
107
+
108
+ Install all required dependencies:
109
+
110
+ ```bash
111
+ git clone https://github.com/OpenMOSS/MOSS-TTS.git
112
+ cd MOSS-TTS
113
+ pip install --extra-index-url https://download.pytorch.org/whl/cu128 -e .
114
+ ```
115
+
116
+ #### (Optional) Install FlashAttention 2
117
+
118
+ For better speed and lower GPU memory usage, you can install FlashAttention 2 if your hardware supports it.
119
+
120
+ ```bash
121
+ pip install --extra-index-url https://download.pytorch.org/whl/cu128 -e ".[flash-attn]"
122
+ ```
123
+
124
+ If your machine has limited RAM and many CPU cores, you can cap build parallelism:
125
+
126
+ ```bash
127
+ MAX_JOBS=4 pip install --extra-index-url https://download.pytorch.org/whl/cu128 -e ".[flash-attn]"
128
+ ```
129
+
130
+ Notes:
131
+ - Dependencies are managed in `pyproject.toml`, which currently pins `torch==2.9.1+cu128` and `torchaudio==2.9.1+cu128`.
132
+ - If FlashAttention 2 fails to build on your machine, you can skip it and use the default attention backend.
133
+ - FlashAttention 2 is only available on supported GPUs and is typically used with `torch.float16` or `torch.bfloat16`.
134
+
135
+ ### Basic Usage
136
+
137
+ > Tip: MOSS-TTS-Local-Transformer-v1.5 uses a fixed 12-codebook RVQ depth. Do not set `n_vq_for_inference` to a value different from `config.n_vq`.
138
+
139
+ MOSS-TTS-Local-Transformer-v1.5 provides the standard Hugging Face `AutoProcessor` and `AutoModel` interface. The examples below cover:
140
+ 1. Direct generation with language tags
141
+ 2. Voice cloning
142
+ 3. Duration control
143
+ 4. Explicit pause control with `[pause X.Ys]`
144
+
145
+ ```python
146
+ from pathlib import Path
147
+ from tqdm import tqdm
148
+ import importlib.util
149
+
150
+ import torch
151
+ import torchaudio
152
+ from transformers import AutoModel, AutoProcessor
153
+
154
+ # Disable the cuDNN SDPA backend, which is broken on some CUDA/PyTorch combinations.
155
+ torch.backends.cuda.enable_cudnn_sdp(False)
156
+ # Keep these enabled as fallbacks.
157
+ torch.backends.cuda.enable_flash_sdp(True)
158
+ torch.backends.cuda.enable_mem_efficient_sdp(True)
159
+ torch.backends.cuda.enable_math_sdp(True)
160
+
161
+ pretrained_model_name_or_path = "OpenMOSS-Team/MOSS-TTS-Local-Transformer-v1.5"
162
+ device = "cuda" if torch.cuda.is_available() else "cpu"
163
+ dtype = torch.bfloat16 if device == "cuda" else torch.float32
164
+
165
+
166
+ def resolve_attn_implementation() -> str:
167
+     # Prefer FlashAttention 2 when package + device conditions are met.
168
+     if (
169
+         device == "cuda"
170
+         and importlib.util.find_spec("flash_attn") is not None
171
+         and dtype in {torch.float16, torch.bfloat16}
172
+     ):
173
+         major, _ = torch.cuda.get_device_capability()
174
+         if major >= 8:
175
+             return "flash_attention_2"
176
+     # CUDA fallback: use PyTorch SDPA kernels.
177
+     if device == "cuda":
178
+         return "sdpa"
179
+     # CPU fallback.
180
+     return "eager"
181
+
182
+
183
+ attn_implementation = resolve_attn_implementation()
184
+ print(f"[INFO] Using attn_implementation={attn_implementation}")
185
+
186
+ processor = AutoProcessor.from_pretrained(
187
+ pretrained_model_name_or_path,
188
+ trust_remote_code=True,
189
+ )
190
+ processor.audio_tokenizer = processor.audio_tokenizer.to(device)
191
+
192
+ text_zh = "亲爱的你,愿你的每一天都值得被记住,也值得被珍惜。"
193
+ text_en = "We stand on the threshold of the AI era, where intelligence becomes an extension of human creativity."
194
+ text_fr = "Bonjour, je voudrais essayer une voix française naturelle et stable."
195
+ text_pause = "我今天学习了一首中国的古诗,它的名字是[pause 3.2s]静夜思!"
196
+
197
+ # Use remote demo audio to avoid requiring local assets.
198
+ ref_audio_zh = "https://speech-demo.oss-cn-shanghai.aliyuncs.com/moss_tts_demo/tts_readme_demo/reference_zh.wav"
199
+ ref_audio_en = "https://speech-demo.oss-cn-shanghai.aliyuncs.com/moss_tts_demo/tts_readme_demo/reference_en.m4a"
200
+
201
+ conversations = [
202
+ # Direct TTS. Language tags are recommended in v1.5 when the language is known.
203
+ [processor.build_user_message(text=text_zh, language="Chinese")],
204
+ [processor.build_user_message(text=text_en, language="English")],
205
+ [processor.build_user_message(text=text_fr, language="French")],
206
+ # Explicit pause control. Use [pause X.Ys], such as [pause 3.2s].
207
+ [processor.build_user_message(text=text_pause, language="Chinese")],
208
+ # Voice cloning with a reference audio.
209
+ [processor.build_user_message(text=text_zh, reference=[ref_audio_zh], language="Chinese")],
210
+ [processor.build_user_message(text=text_en, reference=[ref_audio_en], language="English")],
211
+ # Duration control. At 12.5 frames per second, 125 frames is about 10 seconds.
212
+ [processor.build_user_message(text=text_en, tokens=125, language="English")],
213
+ ]
214
+
215
+ model = AutoModel.from_pretrained(
216
+ pretrained_model_name_or_path,
217
+ trust_remote_code=True,
218
+ attn_implementation=attn_implementation,
219
+ torch_dtype=dtype,
220
+ ).to(device)
221
+ model.eval()
222
+
223
+ batch_size = 1
224
+ save_dir = Path("inference_root_moss_tts_local_v1_5")
225
+ save_dir.mkdir(exist_ok=True, parents=True)
226
+ sample_idx = 0
227
+
228
+ with torch.no_grad():
229
+     for start in tqdm(range(0, len(conversations), batch_size)):
230
+         batch_conversations = conversations[start : start + batch_size]
231
+         batch = processor(batch_conversations, mode="generation")
232
+         input_ids = batch["input_ids"].to(device)
233
+         attention_mask = batch["attention_mask"].to(device)
234
+
235
+         outputs = model.generate(
236
+             input_ids=input_ids,
237
+             attention_mask=attention_mask,
238
+             max_new_tokens=4096,
239
+             do_sample=True,
240
+             audio_temperature=1.7,
241
+             audio_top_p=0.8,
242
+             audio_top_k=25,
243
+             audio_repetition_penalty=1.0,
244
+         )
245
+
246
+         for message in processor.decode(outputs):
247
+             if message is None:
248
+                 continue
249
+             audio = message.audio_codes_list[0]
250
+             out_path = save_dir / f"sample{sample_idx}.wav"
251
+             sample_idx += 1
252
+             # MOSS-TTS Local v1.5 codec returns stereo audio as [channels, samples].
253
+             # Save the two-channel tensor directly.
254
+             torchaudio.save(str(out_path), audio, processor.model_config.sampling_rate)
255
+ ```
256
+
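The duration-control comment in the example ("at 12.5 frames per second, 125 frames is about 10 seconds") amounts to a simple conversion. The sketch below makes it explicit; `seconds_to_tokens` and `tokens_to_seconds` are hypothetical helper names, and the 12.5 Hz rate is taken from that comment rather than read from the model configuration.

```python
FRAMES_PER_SECOND = 12.5  # frame rate quoted in the README example comment


def seconds_to_tokens(seconds: float) -> int:
    """Convert a target duration into a value for the `tokens` argument."""
    if seconds <= 0:
        raise ValueError("duration must be positive")
    # Round to the nearest whole frame, but never request zero frames.
    return max(1, round(seconds * FRAMES_PER_SECOND))


def tokens_to_seconds(tokens: int) -> float:
    """Convert a frame count back into an approximate duration in seconds."""
    return tokens / FRAMES_PER_SECOND


target_tokens = seconds_to_tokens(10.0)
```

The result could then be passed as `processor.build_user_message(text=text_en, tokens=target_tokens, language="English")`, mirroring the last entry of `conversations`.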
257
+ ## Generation Parameters
258
+
259
+ | Parameter | Recommended | Description |
260
+ |---|---:|---|
261
+ | `audio_temperature` | `1.7` | Sampling temperature for audio RVQ layers. |
262
+ | `audio_top_p` | `0.8` | Nucleus sampling cutoff for audio RVQ layers. |
263
+ | `audio_top_k` | `25` | Top-k sampling cutoff for audio RVQ layers. |
264
+ | `audio_repetition_penalty` | `1.0` | Penalty for repeated acoustic token patterns. |
265
+ | `n_vq_for_inference` | `12` | Fixed by this release. Values other than `config.n_vq` are rejected. |
266
+
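To give a feel for how the three sampling parameters in the table interact, here is a generic, pure-Python sketch of temperature scaling followed by top-k and top-p (nucleus) filtering. It is not the model's own sampler: the function name `filter_probs` is hypothetical, and the order of operations (temperature, then top-k, then top-p) is an assumption based on common practice.

```python
import math


def filter_probs(logits, temperature=1.7, top_k=25, top_p=0.8):
    """Return renormalised probabilities after temperature, top-k, then top-p."""
    # Temperature: divide logits before the softmax; values > 1 flatten the distribution.
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Top-k: keep only the k most likely indices.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep = order[:top_k]

    # Top-p: keep the smallest prefix whose renormalised cumulative mass reaches top_p.
    kept_mass = sum(probs[i] for i in keep)
    nucleus, cumulative = [], 0.0
    for i in keep:
        nucleus.append(i)
        cumulative += probs[i] / kept_mass
        if cumulative >= top_p:
            break

    mass = sum(probs[i] for i in nucleus)
    return {i: probs[i] / mass for i in nucleus}


example = filter_probs([2.0, 1.0, 0.0, -1.0], temperature=1.0, top_k=3, top_p=0.8)
```

In this toy case only the two most likely indices survive, because together they already cover more than 80% of the top-3 probability mass.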
267
+ ## Notes
268
+
269
+ - This repository uses Hugging Face remote code. Load it with `trust_remote_code=True`.
270
+ - The MOSS-TTS-Local-Transformer-v1.5 codec is stereo. `processor.decode(...)` returns audio tensors shaped as `[channels, samples]`, so save them directly with `torchaudio.save(path, audio, sampling_rate)`.
271
+ - Audio encoding and decoding use `OpenMOSS-Team/MOSS-Audio-Tokenizer-v2`.
272
+ - The model configuration sets `sampling_rate` to 48000 and `n_vq` to 12.
273
+ - If FlashAttention 2 is unavailable, the example falls back to SDPA on CUDA and eager attention on CPU.
274
+
275
+ ## More Usage
276
+
277
+ MOSS-TTS-Local-Transformer-v1.5 is API-compatible with MOSS-TTS-Local-Transformer-v1.0. For continuation with prefix audio, detailed `UserMessage` and `AssistantMessage` fields, generation hyperparameters, Pinyin/IPA preprocessing examples, and evaluation results, see the [MOSS-TTS-Local-Transformer-v1.0 README](https://huggingface.co/OpenMOSS-Team/MOSS-TTS-Local-Transformer).
278
+
279
+ ## Citation
280
+
281
+ If you use this model, please cite the [MOSS-TTS Technical Report](https://arxiv.org/abs/2603.18090).
__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from .configuration_moss_tts import MossTTSLocalConfig
2
+ from .modeling_moss_tts import MossTTSLocalModel
3
+ from .processing_moss_tts import MossTTSLocalProcessor
4
+
5
+ __all__ = [
6
+ "MossTTSLocalConfig",
7
+ "MossTTSLocalModel",
8
+ "MossTTSLocalProcessor",
9
+ ]
added_tokens.json ADDED
@@ -0,0 +1,31 @@
1
+ {
2
+ "</think>": 151668,
3
+ "</tool_call>": 151658,
4
+ "</tool_response>": 151666,
5
+ "<think>": 151667,
6
+ "<tool_call>": 151657,
7
+ "<tool_response>": 151665,
8
+ "<|audio_end|>": 151670,
9
+ "<|audio_pad|>": 151671,
10
+ "<|audio_start|>": 151669,
11
+ "<|box_end|>": 151649,
12
+ "<|box_start|>": 151648,
13
+ "<|endoftext|>": 151643,
14
+ "<|file_sep|>": 151664,
15
+ "<|fim_middle|>": 151660,
16
+ "<|fim_pad|>": 151662,
17
+ "<|fim_prefix|>": 151659,
18
+ "<|fim_suffix|>": 151661,
19
+ "<|im_end|>": 151645,
20
+ "<|im_start|>": 151644,
21
+ "<|image_pad|>": 151655,
22
+ "<|object_ref_end|>": 151647,
23
+ "<|object_ref_start|>": 151646,
24
+ "<|quad_end|>": 151651,
25
+ "<|quad_start|>": 151650,
26
+ "<|repo_name|>": 151663,
27
+ "<|video_pad|>": 151656,
28
+ "<|vision_end|>": 151653,
29
+ "<|vision_pad|>": 151654,
30
+ "<|vision_start|>": 151652
31
+ }
chat_template.jinja ADDED
@@ -0,0 +1,89 @@
1
+ {%- if tools %}
2
+ {{- '<|im_start|>system\n' }}
3
+ {%- if messages[0].role == 'system' %}
4
+ {{- messages[0].content + '\n\n' }}
5
+ {%- endif %}
6
+ {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
7
+ {%- for tool in tools %}
8
+ {{- "\n" }}
9
+ {{- tool | tojson }}
10
+ {%- endfor %}
11
+ {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
12
+ {%- else %}
13
+ {%- if messages[0].role == 'system' %}
14
+ {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
15
+ {%- endif %}
16
+ {%- endif %}
17
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
18
+ {%- for message in messages[::-1] %}
19
+ {%- set index = (messages|length - 1) - loop.index0 %}
20
+ {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
21
+ {%- set ns.multi_step_tool = false %}
22
+ {%- set ns.last_query_index = index %}
23
+ {%- endif %}
24
+ {%- endfor %}
25
+ {%- for message in messages %}
26
+ {%- if message.content is string %}
27
+ {%- set content = message.content %}
28
+ {%- else %}
29
+ {%- set content = '' %}
30
+ {%- endif %}
31
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
32
+ {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
33
+ {%- elif message.role == "assistant" %}
34
+ {%- set reasoning_content = '' %}
35
+ {%- if message.reasoning_content is string %}
36
+ {%- set reasoning_content = message.reasoning_content %}
37
+ {%- else %}
38
+ {%- if '</think>' in content %}
39
+ {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
40
+ {%- set content = content.split('</think>')[-1].lstrip('\n') %}
41
+ {%- endif %}
42
+ {%- endif %}
43
+ {%- if loop.index0 > ns.last_query_index %}
44
+ {%- if loop.last or (not loop.last and reasoning_content) %}
45
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
46
+ {%- else %}
47
+ {{- '<|im_start|>' + message.role + '\n' + content }}
48
+ {%- endif %}
49
+ {%- else %}
50
+ {{- '<|im_start|>' + message.role + '\n' + content }}
51
+ {%- endif %}
52
+ {%- if message.tool_calls %}
53
+ {%- for tool_call in message.tool_calls %}
54
+ {%- if (loop.first and content) or (not loop.first) %}
55
+ {{- '\n' }}
56
+ {%- endif %}
57
+ {%- if tool_call.function %}
58
+ {%- set tool_call = tool_call.function %}
59
+ {%- endif %}
60
+ {{- '<tool_call>\n{"name": "' }}
61
+ {{- tool_call.name }}
62
+ {{- '", "arguments": ' }}
63
+ {%- if tool_call.arguments is string %}
64
+ {{- tool_call.arguments }}
65
+ {%- else %}
66
+ {{- tool_call.arguments | tojson }}
67
+ {%- endif %}
68
+ {{- '}\n</tool_call>' }}
69
+ {%- endfor %}
70
+ {%- endif %}
71
+ {{- '<|im_end|>\n' }}
72
+ {%- elif message.role == "tool" %}
73
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
74
+ {{- '<|im_start|>user' }}
75
+ {%- endif %}
76
+ {{- '\n<tool_response>\n' }}
77
+ {{- content }}
78
+ {{- '\n</tool_response>' }}
79
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
80
+ {{- '<|im_end|>\n' }}
81
+ {%- endif %}
82
+ {%- endif %}
83
+ {%- endfor %}
84
+ {%- if add_generation_prompt %}
85
+ {{- '<|im_start|>assistant\n' }}
86
+ {%- if enable_thinking is defined and enable_thinking is false %}
87
+ {{- '<think>\n\n</think>\n\n' }}
88
+ {%- endif %}
89
+ {%- endif %}
config.json ADDED
@@ -0,0 +1,381 @@
1
+ {
2
+ "model_type": "moss_tts_local",
3
+ "architectures": [
4
+ "MossTTSLocalModel"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_moss_tts.MossTTSLocalConfig",
8
+ "AutoModel": "modeling_moss_tts.MossTTSLocalModel",
9
+ "AutoProcessor": "processing_moss_tts.MossTTSLocalProcessor"
10
+ },
11
+ "processor_class": "MossTTSLocalProcessor",
12
+ "qwen3_config": {
13
+ "_name_or_path": "",
14
+ "add_cross_attention": false,
15
+ "architectures": [
16
+ "Qwen3ForCausalLM"
17
+ ],
18
+ "attention_bias": false,
19
+ "attention_dropout": 0.0,
20
+ "bad_words_ids": null,
21
+ "begin_suppress_tokens": null,
22
+ "bos_token_id": 151643,
23
+ "chunk_size_feed_forward": 0,
24
+ "cross_attention_hidden_size": null,
25
+ "decoder_start_token_id": null,
26
+ "diversity_penalty": 0.0,
27
+ "do_sample": false,
28
+ "dtype": "bfloat16",
29
+ "early_stopping": false,
30
+ "encoder_no_repeat_ngram_size": 0,
31
+ "eos_token_id": 151643,
32
+ "exponential_decay_length_penalty": null,
33
+ "finetuning_task": null,
34
+ "forced_bos_token_id": null,
35
+ "forced_eos_token_id": null,
36
+ "gradient_checkpointing_use_reentrant": false,
37
+ "head_dim": 128,
38
+ "hidden_act": "silu",
39
+ "hidden_size": 2560,
40
+ "id2label": {
41
+ "0": "LABEL_0",
42
+ "1": "LABEL_1"
43
+ },
44
+ "initializer_range": 0.02,
45
+ "intermediate_size": 9728,
46
+ "is_decoder": false,
47
+ "is_encoder_decoder": false,
48
+ "label2id": {
49
+ "LABEL_0": 0,
50
+ "LABEL_1": 1
51
+ },
52
+ "layer_types": [
53
+ "full_attention",
54
+ "full_attention",
55
+ "full_attention",
56
+ "full_attention",
57
+ "full_attention",
58
+ "full_attention",
59
+ "full_attention",
60
+ "full_attention",
61
+ "full_attention",
62
+ "full_attention",
63
+ "full_attention",
64
+ "full_attention",
65
+ "full_attention",
66
+ "full_attention",
67
+ "full_attention",
68
+ "full_attention",
69
+ "full_attention",
70
+ "full_attention",
71
+ "full_attention",
72
+ "full_attention",
73
+ "full_attention",
74
+ "full_attention",
75
+ "full_attention",
76
+ "full_attention",
77
+ "full_attention",
78
+ "full_attention",
79
+ "full_attention",
80
+ "full_attention",
81
+ "full_attention",
82
+ "full_attention",
83
+ "full_attention",
84
+ "full_attention",
85
+ "full_attention",
86
+ "full_attention",
87
+ "full_attention",
88
+ "full_attention"
89
+ ],
90
+ "length_penalty": 1.0,
91
+ "max_length": 20,
92
+ "max_position_embeddings": 32768,
93
+ "max_window_layers": 36,
94
+ "min_length": 0,
95
+ "model_type": "qwen3",
96
+ "no_repeat_ngram_size": 0,
97
+ "num_attention_heads": 32,
98
+ "num_beam_groups": 1,
99
+ "num_beams": 1,
100
+ "num_hidden_layers": 36,
101
+ "num_key_value_heads": 8,
102
+ "num_return_sequences": 1,
103
+ "output_attentions": false,
104
+ "output_hidden_states": false,
105
+ "output_scores": false,
106
+ "pad_token_id": 151643,
107
+ "prefix": null,
108
+ "problem_type": null,
109
+ "pruned_heads": {},
110
+ "remove_invalid_values": false,
111
+ "repetition_penalty": 1.0,
112
+ "return_dict": true,
113
+ "return_dict_in_generate": false,
114
+ "rms_norm_eps": 1e-06,
115
+ "rope_scaling": null,
116
+ "rope_theta": 1000000,
117
+ "sep_token_id": null,
118
+ "sliding_window": null,
119
+ "suppress_tokens": null,
120
+ "task_specific_params": null,
121
+ "temperature": 1.0,
122
+ "tf_legacy_loss": false,
123
+ "tie_encoder_decoder": false,
124
+ "tie_word_embeddings": true,
125
+ "tokenizer_class": null,
126
+ "top_k": 50,
127
+ "top_p": 1.0,
128
+ "torchscript": false,
129
+ "transformers_version": "4.57.1",
130
+ "typical_p": 1.0,
131
+ "use_bfloat16": false,
132
+ "use_cache": false,
133
+ "use_sliding_window": false,
134
+ "vocab_size": 151936
135
+ },
136
+ "language_config": {
137
+ "_name_or_path": "",
138
+ "add_cross_attention": false,
139
+ "architectures": [
140
+ "Qwen3ForCausalLM"
141
+ ],
142
+ "attention_bias": false,
143
+ "attention_dropout": 0.0,
144
+ "bad_words_ids": null,
145
+ "begin_suppress_tokens": null,
146
+ "bos_token_id": 151643,
147
+ "chunk_size_feed_forward": 0,
148
+ "cross_attention_hidden_size": null,
149
+ "decoder_start_token_id": null,
150
+ "diversity_penalty": 0.0,
151
+ "do_sample": false,
152
+ "dtype": "bfloat16",
153
+ "early_stopping": false,
154
+ "encoder_no_repeat_ngram_size": 0,
155
+ "eos_token_id": 151643,
156
+ "exponential_decay_length_penalty": null,
157
+ "finetuning_task": null,
158
+ "forced_bos_token_id": null,
159
+ "forced_eos_token_id": null,
160
+ "gradient_checkpointing_use_reentrant": false,
161
+ "head_dim": 128,
162
+ "hidden_act": "silu",
163
+ "hidden_size": 2560,
164
+ "id2label": {
165
+ "0": "LABEL_0",
166
+ "1": "LABEL_1"
167
+ },
168
+ "initializer_range": 0.02,
169
+ "intermediate_size": 9728,
170
+ "is_decoder": false,
171
+ "is_encoder_decoder": false,
172
+ "label2id": {
173
+ "LABEL_0": 0,
174
+ "LABEL_1": 1
175
+ },
176
+ "layer_types": [
177
+ "full_attention",
178
+ "full_attention",
179
+ "full_attention",
180
+ "full_attention",
181
+ "full_attention",
182
+ "full_attention",
183
+ "full_attention",
184
+ "full_attention",
185
+ "full_attention",
186
+ "full_attention",
187
+ "full_attention",
188
+ "full_attention",
189
+ "full_attention",
190
+ "full_attention",
191
+ "full_attention",
192
+ "full_attention",
193
+ "full_attention",
194
+ "full_attention",
195
+ "full_attention",
196
+ "full_attention",
197
+ "full_attention",
198
+ "full_attention",
199
+ "full_attention",
200
+ "full_attention",
201
+ "full_attention",
202
+ "full_attention",
203
+ "full_attention",
204
+ "full_attention",
205
+ "full_attention",
206
+ "full_attention",
207
+ "full_attention",
208
+ "full_attention",
209
+ "full_attention",
210
+ "full_attention",
211
+ "full_attention",
212
+ "full_attention"
213
+ ],
214
+ "length_penalty": 1.0,
215
+ "max_length": 20,
216
+ "max_position_embeddings": 32768,
217
+ "max_window_layers": 36,
218
+ "min_length": 0,
219
+ "model_type": "qwen3",
220
+ "no_repeat_ngram_size": 0,
221
+ "num_attention_heads": 32,
222
+ "num_beam_groups": 1,
223
+ "num_beams": 1,
224
+ "num_hidden_layers": 36,
225
+ "num_key_value_heads": 8,
226
+ "num_return_sequences": 1,
227
+ "output_attentions": false,
228
+ "output_hidden_states": false,
229
+ "output_scores": false,
230
+ "pad_token_id": 151643,
231
+ "prefix": null,
232
+ "problem_type": null,
233
+ "pruned_heads": {},
234
+ "remove_invalid_values": false,
235
+ "repetition_penalty": 1.0,
236
+ "return_dict": true,
237
+ "return_dict_in_generate": false,
238
+ "rms_norm_eps": 1e-06,
239
+ "rope_scaling": null,
240
+ "rope_theta": 1000000,
241
+ "sep_token_id": null,
242
+ "sliding_window": null,
243
+ "suppress_tokens": null,
244
+ "task_specific_params": null,
245
+ "temperature": 1.0,
246
+ "tf_legacy_loss": false,
247
+ "tie_encoder_decoder": false,
248
+ "tie_word_embeddings": true,
249
+ "tokenizer_class": null,
250
+ "top_k": 50,
251
+ "top_p": 1.0,
252
+ "torchscript": false,
253
+ "transformers_version": "4.57.1",
254
+ "typical_p": 1.0,
255
+ "use_bfloat16": false,
256
+ "use_cache": false,
257
+ "use_sliding_window": false,
258
+ "vocab_size": 151936
259
+ },
260
+ "gpt2_config": {
261
+ "_name_or_path": "",
262
+ "activation_function": "silu",
263
+ "add_cross_attention": false,
264
+ "architectures": null,
265
+ "attn_pdrop": 0.0,
266
+ "bad_words_ids": null,
267
+ "begin_suppress_tokens": null,
268
+ "bos_token_id": null,
269
+ "chunk_size_feed_forward": 0,
270
+ "cross_attention_hidden_size": null,
271
+ "decoder_start_token_id": null,
272
+ "diversity_penalty": 0.0,
273
+ "do_sample": false,
274
+ "dtype": null,
275
+ "early_stopping": false,
276
+ "embd_pdrop": 0.0,
277
+ "encoder_no_repeat_ngram_size": 0,
278
+ "eos_token_id": 151645,
279
+ "exponential_decay_length_penalty": null,
280
+ "finetuning_task": null,
281
+ "forced_bos_token_id": null,
282
+ "forced_eos_token_id": null,
283
+ "id2label": {
284
+ "0": "LABEL_0",
285
+ "1": "LABEL_1"
286
+ },
287
+ "initializer_range": 0.02,
288
+ "is_decoder": false,
289
+ "is_encoder_decoder": false,
290
+ "label2id": {
291
+ "LABEL_0": 0,
292
+ "LABEL_1": 1
293
+ },
294
+ "layer_norm_epsilon": 1e-06,
295
+ "length_penalty": 1.0,
296
+ "max_length": 20,
297
+ "min_length": 0,
298
+ "model_type": "gpt2",
299
+ "n_ctx": 10240,
300
+ "n_embd": 2560,
301
+ "n_head": 32,
302
+ "n_inner": 9728,
303
+ "n_layer": 1,
304
+ "n_positions": 10240,
305
+ "no_repeat_ngram_size": 0,
306
+ "num_beam_groups": 1,
307
+ "num_beams": 1,
308
+ "num_return_sequences": 1,
309
+ "output_attentions": false,
310
+ "output_hidden_states": false,
311
+ "output_scores": false,
312
+ "pad_token_id": null,
313
+ "position_embedding_type": "rope",
314
+ "prefix": null,
315
+ "problem_type": null,
316
+ "pruned_heads": {},
317
+ "remove_invalid_values": false,
318
+ "reorder_and_upcast_attn": false,
319
+ "repetition_penalty": 1.0,
320
+ "resid_pdrop": 0.0,
321
+ "return_dict": true,
322
+ "return_dict_in_generate": false,
323
+ "rope_base": 1000000.0,
324
+ "scale_attn_by_inverse_layer_idx": false,
325
+ "scale_attn_weights": true,
326
+ "sep_token_id": null,
327
+ "summary_activation": null,
328
+ "summary_first_dropout": 0.1,
329
+ "summary_proj_to_labels": true,
330
+ "summary_type": "cls_index",
331
+ "summary_use_proj": true,
332
+ "suppress_tokens": null,
333
+ "task_specific_params": null,
334
+ "temperature": 1.0,
335
+ "tf_legacy_loss": false,
336
+ "tie_encoder_decoder": false,
337
+ "tie_word_embeddings": true,
338
+ "tokenizer_class": null,
339
+ "top_k": 50,
340
+ "top_p": 1.0,
341
+ "torchscript": false,
342
+ "transformers_version": "4.57.1",
343
+ "typical_p": 1.0,
344
+ "use_bfloat16": false,
345
+ "use_cache": true,
346
+ "vocab_size": 151936
347
+ },
348
+ "n_vq": 12,
349
+ "audio_vocab_size": 1024,
350
+ "audio_codebook_sizes": [
351
+ 1024,
352
+ 1024,
353
+ 1024,
354
+ 1024,
355
+ 1024,
356
+ 1024,
357
+ 1024,
358
+ 1024,
359
+ 1024,
360
+ 1024,
361
+ 1024,
362
+ 1024
363
+ ],
364
+ "audio_pad_token_id": 1024,
365
+ "audio_pad_code": 1024,
366
+ "pad_token_id": 151643,
367
+ "im_start_token_id": 151644,
368
+ "im_end_token_id": 151645,
369
+ "audio_start_token_id": 151669,
370
+ "audio_end_token_id": 151670,
371
+ "audio_user_slot_token_id": 151654,
372
+ "audio_assistant_slot_token_id": 151656,
373
+ "audio_assistant_gen_slot_token_id": 151656,
374
+ "sampling_rate": 48000,
375
+ "audio_tokenizer_name_or_path": "OpenMOSS-Team/MOSS-Audio-Tokenizer-v2",
376
+ "attn_implementation": "flash_attention_2",
377
+ "local_transformer_layers": 1,
378
+ "local_text_head_mode": "binary",
379
+ "use_static_local_kv_cache": true,
380
+ "initializer_range": 0.02
381
+ }
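The special-token IDs at the end of `config.json` can be cross-checked against `added_tokens.json` from the same commit. The sketch below copies the relevant values by hand into two dictionaries (the variable names are illustrative) and resolves each config field to its token string.

```python
# Values copied from added_tokens.json in this commit.
added_tokens = {
    "<|endoftext|>": 151643,
    "<|im_start|>": 151644,
    "<|im_end|>": 151645,
    "<|vision_pad|>": 151654,
    "<|video_pad|>": 151656,
    "<|audio_start|>": 151669,
    "<|audio_end|>": 151670,
}

# Top-level token-ID fields copied from config.json in this commit.
config_ids = {
    "pad_token_id": 151643,
    "im_start_token_id": 151644,
    "im_end_token_id": 151645,
    "audio_user_slot_token_id": 151654,
    "audio_assistant_slot_token_id": 151656,
    "audio_assistant_gen_slot_token_id": 151656,
    "audio_start_token_id": 151669,
    "audio_end_token_id": 151670,
}

# Invert the token table and resolve every config field to a token string.
id_to_token = {token_id: token for token, token_id in added_tokens.items()}
resolved = {name: id_to_token[token_id] for name, token_id in config_ids.items()}

for name, token in resolved.items():
    print(f"{name:36s} -> {token}")
```

Two things stand out from the mapping: the audio slot fields point at the tokens named `<|vision_pad|>` and `<|video_pad|>`, and `<|audio_pad|>` (151671) is not referenced by any of these fields. The commit does not say why, so treat this as an observation rather than a documented design decision.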
configuration_moss_tts.py ADDED
@@ -0,0 +1,158 @@
1
+ # coding=utf-8
2
+ """Configuration for the MOSS-TTS-Local-Transformer-v1.5 release."""
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import Any, Dict, Optional, Union
7
+
8
+ from transformers.configuration_utils import PretrainedConfig
9
+ from transformers.models.gpt2.configuration_gpt2 import GPT2Config
10
+ from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
11
+
12
+
13
+ SUPPORTED_ATTENTION_IMPLEMENTATIONS = {"flash_attention_2", "sdpa", "eager"}
14
+
15
+
16
+ def _normalize_attention_implementation(value: Optional[str], default: str = "flash_attention_2") -> str:
17
+ normalized = str(value or default).strip().lower()
18
+ if normalized in {"flash", "flash_attn", "flash-attn", "flash_attention"}:
19
+ normalized = "flash_attention_2"
20
+ if normalized not in SUPPORTED_ATTENTION_IMPLEMENTATIONS:
21
+ raise ValueError(
22
+ "attn_implementation must be one of "
23
+ f"{sorted(SUPPORTED_ATTENTION_IMPLEMENTATIONS)}, got {value!r}."
24
+ )
25
+ return normalized
26
+
27
+
28
+ class MossTTSLocalConfig(PretrainedConfig):
29
+ model_type = "moss_tts_local"
30
+ keys_to_ignore_at_inference = ["past_key_values"]
31
+
32
+ def __init__(
33
+ self,
34
+ qwen3_config: Optional[Union[Qwen3Config, Dict[str, Any]]] = None,
35
+ gpt2_config: Optional[Union[GPT2Config, Dict[str, Any]]] = None,
36
+ language_config: Optional[Union[Qwen3Config, Dict[str, Any]]] = None,
37
+ n_vq: int = 12,
38
+ audio_vocab_size: int = 1024,
39
+ audio_codebook_sizes: Optional[list[int]] = None,
40
+ audio_pad_token_id: int = 1024,
41
+ audio_pad_code: Optional[int] = None,
42
+ pad_token_id: int = 151643,
43
+ im_start_token_id: int = 151644,
44
+ im_end_token_id: int = 151645,
45
+ audio_start_token_id: int = 151669,
46
+ audio_end_token_id: int = 151670,
47
+ audio_user_slot_token_id: int = 151654,
48
+ audio_assistant_slot_token_id: int = 151656,
49
+ audio_assistant_gen_slot_token_id: Optional[int] = None,
50
+ sampling_rate: int = 48000,
51
+ audio_tokenizer_name_or_path: Optional[str] = None,
52
+ attn_implementation: str = "flash_attention_2",
53
+ local_transformer_attn_implementation: Optional[str] = None,
54
+ local_text_head_mode: str = "binary",
55
+ initializer_range: float = 0.02,
56
+ **kwargs: Any,
57
+ ) -> None:
58
+ if qwen3_config is None and language_config is not None:
59
+ qwen3_config = language_config
60
+ if isinstance(qwen3_config, dict):
61
+ self.qwen3_config = Qwen3Config(**qwen3_config)
62
+ elif qwen3_config is None:
63
+ self.qwen3_config = Qwen3Config()
64
+ else:
65
+ self.qwen3_config = qwen3_config
66
+
67
+ if isinstance(gpt2_config, dict):
68
+ self.gpt2_config = GPT2Config(**gpt2_config)
69
+ elif gpt2_config is None:
70
+ self.gpt2_config = GPT2Config(
71
+ vocab_size=int(self.qwen3_config.vocab_size),
72
+ n_embd=int(self.qwen3_config.hidden_size),
73
+ n_layer=1,
74
+ n_head=max(1, int(self.qwen3_config.hidden_size) // 80),
75
+ n_positions=int(n_vq) + 1,
76
+ n_ctx=int(n_vq) + 1,
77
+ activation_function="silu",
78
+ layer_norm_epsilon=1e-6,
79
+ resid_pdrop=0.0,
80
+ embd_pdrop=0.0,
81
+ attn_pdrop=0.0,
82
+ )
83
+ else:
84
+ self.gpt2_config = gpt2_config
85
+
86
+ self.n_vq = int(n_vq)
87
+ if self.n_vq <= 0:
88
+ raise ValueError("n_vq must be positive.")
89
+ if audio_codebook_sizes is None:
90
+ self.audio_codebook_sizes = [int(audio_vocab_size)] * self.n_vq
91
+ else:
92
+ self.audio_codebook_sizes = [int(size) for size in audio_codebook_sizes]
93
+ if len(self.audio_codebook_sizes) != self.n_vq:
94
+ raise ValueError(
95
+ f"audio_codebook_sizes must have length n_vq={self.n_vq}, "
96
+ f"got {len(self.audio_codebook_sizes)}."
97
+ )
98
+ if any(size <= 0 for size in self.audio_codebook_sizes):
99
+ raise ValueError("audio_codebook_sizes must contain positive integers.")
100
+ self.audio_vocab_size = int(max(int(audio_vocab_size), max(self.audio_codebook_sizes)))
101
+ self.audio_pad_token_id = int(audio_pad_code if audio_pad_code is not None else audio_pad_token_id)
102
+ self.audio_pad_code = self.audio_pad_token_id
103
+ if self.audio_pad_token_id < self.audio_vocab_size:
104
+ raise ValueError("audio_pad_token_id/audio_pad_code must be outside the audio vocab.")
105
+
106
+ self.pad_token_id = int(pad_token_id)
107
+ self.im_start_token_id = int(im_start_token_id)
108
+ self.im_end_token_id = int(im_end_token_id)
109
+ self.audio_start_token_id = int(audio_start_token_id)
110
+ self.audio_end_token_id = int(audio_end_token_id)
111
+ self.audio_user_slot_token_id = int(audio_user_slot_token_id)
112
+ self.audio_assistant_slot_token_id = int(
113
+ audio_assistant_slot_token_id
114
+ if audio_assistant_gen_slot_token_id is None
115
+ else audio_assistant_gen_slot_token_id
116
+ )
117
+ self.audio_assistant_gen_slot_token_id = self.audio_assistant_slot_token_id
118
+
119
+ self.sampling_rate = int(sampling_rate)
120
+ self.audio_tokenizer_name_or_path = audio_tokenizer_name_or_path
121
+ self.attn_implementation = _normalize_attention_implementation(attn_implementation)
122
+ self.local_transformer_attn_implementation = _normalize_attention_implementation(
123
+ local_transformer_attn_implementation,
124
+ default=self.attn_implementation,
125
+ )
126
+ self.initializer_range = float(initializer_range)
127
+
128
+ self.hidden_size = int(self.qwen3_config.hidden_size)
129
+ self.vocab_size = int(self.qwen3_config.vocab_size)
130
+ self.local_hidden_size = int(self.gpt2_config.hidden_size)
131
+ if self.local_hidden_size != self.hidden_size:
132
+ raise ValueError(
133
+ "This MOSS-TTS-Local-Transformer-v1.5 release expects local hidden size to "
134
+ "match Qwen3 hidden size so audio embeddings and heads are tied."
135
+ )
136
+
137
+ normalized_text_head_mode = str(local_text_head_mode or "full_vocab").strip().lower()
138
+ if normalized_text_head_mode in {"full", "full-vocab", "vocab"}:
139
+ normalized_text_head_mode = "full_vocab"
140
+ if normalized_text_head_mode not in {"full_vocab", "binary"}:
141
+ raise ValueError("local_text_head_mode must be 'full_vocab' or 'binary'.")
142
+ self.local_text_head_mode = normalized_text_head_mode
143
+
144
+ kwargs.setdefault("tie_word_embeddings", True)
145
+ super().__init__(pad_token_id=self.pad_token_id, **kwargs)
146
+
147
+ @property
148
+ def language_config(self) -> Qwen3Config:
149
+ return self.qwen3_config
150
+
151
+ def to_dict(self) -> Dict[str, Any]:
152
+ output = super().to_dict()
153
+ output["qwen3_config"] = self.qwen3_config.to_dict()
154
+ output["language_config"] = self.qwen3_config.to_dict()
155
+ output["gpt2_config"] = self.gpt2_config.to_dict()
156
+ output["audio_pad_code"] = self.audio_pad_token_id
157
+ output["audio_assistant_gen_slot_token_id"] = self.audio_assistant_slot_token_id
158
+ return output
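When no `gpt2_config` is passed, the constructor above derives the local transformer's shape from the Qwen3 backbone and `n_vq`, and it lets `audio_pad_code` and `audio_assistant_gen_slot_token_id` override their older counterparts. The sketch below reproduces that arithmetic without importing `transformers`. The helper names `derive_local_defaults` and `resolve_alias`, and the example hidden size of 2560, are assumptions chosen for illustration and are not values taken from this release.

```python
from typing import Optional


def derive_local_defaults(hidden_size: int, vocab_size: int, n_vq: int = 12) -> dict:
    """Hypothetical helper mirroring the GPT2Config defaults built in __init__."""
    return {
        "vocab_size": vocab_size,
        "n_embd": hidden_size,                # local width equals the backbone width
        "n_layer": 1,
        "n_head": max(1, hidden_size // 80),  # one head per 80 channels, at least one
        "n_positions": n_vq + 1,              # same formula as the constructor
    }


def resolve_alias(primary: int, override: Optional[int]) -> int:
    """Return the override when it is given, otherwise the primary value."""
    return primary if override is None else override


defaults = derive_local_defaults(hidden_size=2560, vocab_size=151936)
print(defaults["n_head"], defaults["n_positions"])            # prints: 32 13
print(resolve_alias(1024, None), resolve_alias(1024, 2048))   # prints: 1024 2048
```

One consequence worth noting: the attention module in `gpt2_decoder.py` requires `hidden_size` to be divisible by the head count, which always holds when `hidden_size` is a multiple of 80.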
gpt2_decoder.py ADDED
@@ -0,0 +1,721 @@
+ # coding=utf-8
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+ import torch.utils.checkpoint
+ from transformers.activations import ACT2FN
+ from transformers.modeling_outputs import BaseModelOutputWithPast
+ from transformers.models.gpt2.configuration_gpt2 import GPT2Config
+
+ try:
+     from flash_attn import flash_attn_func, flash_attn_varlen_func
+     from flash_attn.bert_padding import pad_input, unpad_input
+
+     _FLASH_ATTN_AVAILABLE = True
+ except Exception:
+     flash_attn_func = None
+     flash_attn_varlen_func = None
+     pad_input = None
+     unpad_input = None
+     _FLASH_ATTN_AVAILABLE = False
+
+
+ @dataclass
+ class PackedSequenceMetadata:
+     cu_seqlens: torch.Tensor
+     max_seqlen: int
+     indices: Optional[torch.Tensor] = None
+     batch_size: Optional[int] = None
+     seq_len: Optional[int] = None
+
+
+ def _is_static_kv_cache_layer(layer_past: object) -> bool:
+     return isinstance(layer_past, dict) and bool(layer_past.get("static_kv_cache", False))
+
+
+ class MossTTSNanoGPT2RotaryEmbedding(nn.Module):
+     def __init__(self, dim: int, base: float = 10000.0) -> None:
+         super().__init__()
+         if dim % 2 != 0:
+             raise ValueError(f"RoPE head_dim must be even, got {dim}")
+         self.dim = int(dim)
+         self.base = float(base)
+         self.register_buffer("inv_freq", self._compute_inv_freq(), persistent=False)
+
+     def _compute_inv_freq(self, device: Optional[torch.device] = None) -> torch.Tensor:
+         return 1.0 / (
+             self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
+         )
+
+     def forward(
+         self,
+         position_ids: torch.LongTensor,
+         *,
+         device: torch.device,
+         dtype: torch.dtype,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         if position_ids.ndim == 1:
+             position_ids = position_ids.unsqueeze(0)
+         inv_freq = self._compute_inv_freq(device=device)
+         freqs = torch.einsum("bs,d->bsd", position_ids.to(device=device, dtype=inv_freq.dtype), inv_freq)
+         cos = freqs.cos().repeat_interleave(2, dim=-1).unsqueeze(2).to(dtype=dtype)
+         sin = freqs.sin().repeat_interleave(2, dim=-1).unsqueeze(2).to(dtype=dtype)
+         return cos, sin
+
+
+ def rotate_half(hidden_states: torch.Tensor) -> torch.Tensor:
+     even = hidden_states[..., ::2]
+     odd = hidden_states[..., 1::2]
+     return torch.stack((-odd, even), dim=-1).reshape_as(hidden_states)
+
+
+ def apply_rotary_pos_emb(
+     hidden_states: torch.Tensor,
+     cos: torch.Tensor,
+     sin: torch.Tensor,
+ ) -> torch.Tensor:
+     return (hidden_states * cos) + (rotate_half(hidden_states) * sin)
+
+
+ class MossTTSNanoGPT2MLP(nn.Module):
+     def __init__(self, config: GPT2Config) -> None:
+         super().__init__()
+         hidden_size = int(config.hidden_size)
+         inner_size = int(config.n_inner or 4 * hidden_size)
+         self.fc_in = nn.Linear(hidden_size, inner_size)
+         self.fc_out = nn.Linear(inner_size, hidden_size)
+         self.act = ACT2FN[config.activation_function]
+         self.dropout = nn.Dropout(config.resid_pdrop)
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.fc_in(hidden_states)
+         hidden_states = self.act(hidden_states)
+         hidden_states = self.fc_out(hidden_states)
+         return self.dropout(hidden_states)
+
+
+ class MossTTSNanoGPT2Attention(nn.Module):
+     def __init__(self, config: GPT2Config, layer_idx: int, attn_implementation: str) -> None:
+         super().__init__()
+         hidden_size = int(config.hidden_size)
+         num_heads = int(config.num_attention_heads)
+         if hidden_size % num_heads != 0:
+             raise ValueError(f"hidden_size={hidden_size} must be divisible by num_attention_heads={num_heads}")
+
+         self.num_heads = num_heads
+         self.head_dim = hidden_size // num_heads
+         self.embed_dim = hidden_size
+         self.layer_idx = layer_idx
+         self.attn_implementation = attn_implementation
+         self.attn_dropout = float(config.attn_pdrop)
+         self.resid_dropout = nn.Dropout(config.resid_pdrop)
+         self.scale_attn_weights = bool(getattr(config, "scale_attn_weights", True))
+         self.scale_attn_by_inverse_layer_idx = bool(getattr(config, "scale_attn_by_inverse_layer_idx", False))
+         self.position_embedding_type = str(getattr(config, "position_embedding_type", "absolute")).lower()
+         if self.position_embedding_type not in {"absolute", "rope"}:
+             raise ValueError(f"Unsupported position_embedding_type={self.position_embedding_type!r}")
+
+         self.c_attn = nn.Linear(hidden_size, 3 * hidden_size)
+         self.c_proj = nn.Linear(hidden_size, hidden_size)
+         self.rotary_emb = None
+         if self.position_embedding_type == "rope":
+             self.rotary_emb = MossTTSNanoGPT2RotaryEmbedding(
+                 self.head_dim,
+                 base=float(getattr(config, "rope_base", 10000.0)),
+             )
+
+     def _split_heads(self, tensor: torch.Tensor) -> torch.Tensor:
+         if tensor.ndim == 3:
+             batch_size, seq_len, _ = tensor.shape
+             return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim)
+         if tensor.ndim == 2:
+             total_tokens, _ = tensor.shape
+             return tensor.view(total_tokens, self.num_heads, self.head_dim)
+         raise ValueError(f"Unsupported tensor rank for attention split: {tensor.ndim}")
+
+     def _merge_heads(self, tensor: torch.Tensor) -> torch.Tensor:
+         if tensor.ndim == 4:
+             batch_size, seq_len, _, _ = tensor.shape
+             return tensor.reshape(batch_size, seq_len, self.embed_dim)
+         if tensor.ndim == 3:
+             total_tokens, _, _ = tensor.shape
+             return tensor.reshape(total_tokens, self.embed_dim)
+         raise ValueError(f"Unsupported tensor rank for attention merge: {tensor.ndim}")
+
+     def _causal_attention_mask(
+         self,
+         attention_mask: Optional[torch.Tensor],
+         query_length: int,
+         key_length: int,
+         device: torch.device,
+     ) -> torch.Tensor:
+         query_positions = torch.arange(query_length, device=device, dtype=torch.long)
+         query_positions = query_positions + max(key_length - query_length, 0)
+         key_positions = torch.arange(key_length, device=device, dtype=torch.long)
+         causal = key_positions.unsqueeze(0) <= query_positions.unsqueeze(1)
+         causal = causal.unsqueeze(0).unsqueeze(0)
+         if attention_mask is None:
+             return causal
+         key_mask = attention_mask[:, None, None, :].to(dtype=torch.bool)
+         return causal & key_mask
+
+     def _eager_attention(
+         self,
+         query: torch.Tensor,
+         key: torch.Tensor,
+         value: torch.Tensor,
+         attention_mask: Optional[torch.Tensor],
+     ) -> torch.Tensor:
+         query = query.transpose(1, 2)
+         key = key.transpose(1, 2)
+         value = value.transpose(1, 2)
+
+         scale = 1.0
+         if self.scale_attn_weights:
+             scale /= self.head_dim ** 0.5
+         if self.scale_attn_by_inverse_layer_idx:
+             scale /= float(self.layer_idx + 1)
+
+         scores = torch.matmul(query, key.transpose(-1, -2)) * scale
+         causal_mask = self._causal_attention_mask(
+             attention_mask=attention_mask,
+             query_length=query.shape[-2],
+             key_length=key.shape[-2],
+             device=query.device,
+         )
+         scores = scores.masked_fill(~causal_mask, torch.finfo(scores.dtype).min)
+         probs = torch.softmax(scores, dim=-1)
+         if self.training and self.attn_dropout > 0:
+             probs = torch.dropout(probs, self.attn_dropout, train=True)
+         output = torch.matmul(probs, value)
+         return output.transpose(1, 2).contiguous()
+
+     def _sdpa_attention(
+         self,
+         query: torch.Tensor,
+         key: torch.Tensor,
+         value: torch.Tensor,
+         attention_mask: Optional[torch.Tensor],
+     ) -> torch.Tensor:
+         query = query.transpose(1, 2)
+         key = key.transpose(1, 2)
+         value = value.transpose(1, 2)
+         mask = None
+         if attention_mask is not None or query.shape[-2] != key.shape[-2]:
+             mask = self._causal_attention_mask(
+                 attention_mask=attention_mask,
+                 query_length=query.shape[-2],
+                 key_length=key.shape[-2],
+                 device=query.device,
+             )
+         output = torch.nn.functional.scaled_dot_product_attention(
+             query,
+             key,
+             value,
+             attn_mask=mask,
+             dropout_p=self.attn_dropout if self.training else 0.0,
+             is_causal=mask is None,
+         )
+         return output.transpose(1, 2).contiguous()
+
+     def _flash_attention(
+         self,
+         query: torch.Tensor,
+         key: torch.Tensor,
+         value: torch.Tensor,
+         attention_mask: Optional[torch.Tensor],
+         packed_metadata: Optional[PackedSequenceMetadata],
+     ) -> torch.Tensor:
+         if not _FLASH_ATTN_AVAILABLE:
+             raise ImportError("flash_attn is not installed, but attn_implementation='flash_attention_2' was requested.")
+         if query.device.type != "cuda":
+             raise ValueError("flash_attention_2 requires CUDA tensors.")
+         if query.dtype not in (torch.float16, torch.bfloat16):
+             raise ValueError(
+                 f"flash_attention_2 requires fp16/bf16 tensors, but received dtype={query.dtype}."
+             )
+
+         dropout_p = self.attn_dropout if self.training else 0.0
+         if packed_metadata is not None:
+             if packed_metadata.indices is not None:
+                 query = query.reshape(-1, self.num_heads, self.head_dim).index_select(0, packed_metadata.indices)
+                 key = key.reshape(-1, self.num_heads, self.head_dim).index_select(0, packed_metadata.indices)
+                 value = value.reshape(-1, self.num_heads, self.head_dim).index_select(0, packed_metadata.indices)
+             output = flash_attn_varlen_func(
+                 query,
+                 key,
+                 value,
+                 packed_metadata.cu_seqlens,
+                 packed_metadata.cu_seqlens,
+                 packed_metadata.max_seqlen,
+                 packed_metadata.max_seqlen,
+                 dropout_p=dropout_p,
+                 causal=True,
+             )
+             if packed_metadata.indices is None:
+                 return output
+             return pad_input(
+                 output,
+                 packed_metadata.indices,
+                 packed_metadata.batch_size,
+                 packed_metadata.seq_len,
+             )
+
+         if attention_mask is None or bool(attention_mask.all()):
+             return flash_attn_func(
+                 query,
+                 key,
+                 value,
+                 dropout_p=dropout_p,
+                 causal=True,
+             )
+
+         if query.shape[1] != key.shape[1]:
+             query_attention_mask = attention_mask[:, -query.shape[1] :]
+             unpadded_query, query_indices, cu_seqlens_q, max_seqlen_q, _ = unpad_input(
+                 query,
+                 query_attention_mask,
+             )
+             unpadded_key, _, cu_seqlens_k, max_seqlen_k, _ = unpad_input(key, attention_mask)
+             unpadded_value, _, _, _, _ = unpad_input(value, attention_mask)
+             output = flash_attn_varlen_func(
+                 unpadded_query,
+                 unpadded_key,
+                 unpadded_value,
+                 cu_seqlens_q,
+                 cu_seqlens_k,
+                 max_seqlen_q,
+                 max_seqlen_k,
+                 dropout_p=dropout_p,
+                 causal=True,
+             )
+             return pad_input(output, query_indices, query.shape[0], query.shape[1])
+
+         unpadded_query, indices, cu_seqlens, max_seqlen, _ = unpad_input(query, attention_mask)
+         unpadded_key, _, _, _, _ = unpad_input(key, attention_mask)
+         unpadded_value, _, _, _, _ = unpad_input(value, attention_mask)
+         output = flash_attn_varlen_func(
+             unpadded_query,
+             unpadded_key,
+             unpadded_value,
+             cu_seqlens,
+             cu_seqlens,
+             max_seqlen,
+             max_seqlen,
+             dropout_p=dropout_p,
+             causal=True,
+         )
+         return pad_input(output, indices, query.shape[0], query.shape[1])
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         packed_metadata: Optional[PackedSequenceMetadata] = None,
+         layer_past: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+         use_cache: bool = False,
+     ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
+         qkv = self.c_attn(hidden_states)
+         query, key, value = qkv.split(self.embed_dim, dim=-1)
+         query = self._split_heads(query)
+         key = self._split_heads(key)
+         value = self._split_heads(value)
+
+         if self.rotary_emb is not None:
+             if position_ids is None:
+                 raise ValueError("position_ids must be provided when position_embedding_type='rope'.")
+             cos, sin = self.rotary_emb(
+                 position_ids.to(device=query.device),
+                 device=query.device,
+                 dtype=query.dtype,
+             )
+             query = apply_rotary_pos_emb(query, cos, sin)
+             key = apply_rotary_pos_emb(key, cos, sin)
+
+         static_layer_past = layer_past is not None and _is_static_kv_cache_layer(layer_past)
+         if static_layer_past:
+             past_length = int(layer_past.get("length", 0))
+             new_length = past_length + int(key.shape[1])
+             key_cache = layer_past["key"]
+             value_cache = layer_past["value"]
+             if new_length > int(key_cache.shape[1]):
+                 raise ValueError(
+                     f"Static KV cache is too short: need {new_length}, capacity={int(key_cache.shape[1])}."
+                 )
+             key_cache[:, past_length:new_length].copy_(key)
+             value_cache[:, past_length:new_length].copy_(value)
+             key = key_cache[:, :new_length]
+             value = value_cache[:, :new_length]
+             layer_past["length"] = new_length
+         elif layer_past is not None:
+             past_key, past_value = layer_past
+             key = torch.cat([past_key.to(device=key.device, dtype=key.dtype), key], dim=1)
+             value = torch.cat([past_value.to(device=value.device, dtype=value.dtype), value], dim=1)
+
+         present = layer_past if (use_cache and static_layer_past) else ((key, value) if use_cache else None)
+
+         if self.attn_implementation == "flash_attention_2":
+             attn_output = self._flash_attention(
+                 query=query,
+                 key=key,
+                 value=value,
+                 attention_mask=attention_mask,
+                 packed_metadata=packed_metadata,
+             )
+         elif self.attn_implementation == "sdpa":
+             attn_output = self._sdpa_attention(
+                 query=query,
+                 key=key,
+                 value=value,
+                 attention_mask=attention_mask,
+             )
+         else:
+             attn_output = self._eager_attention(
+                 query=query,
+                 key=key,
+                 value=value,
+                 attention_mask=attention_mask,
+             )
+
+         attn_output = self._merge_heads(attn_output)
+         attn_output = self.c_proj(attn_output)
+         return self.resid_dropout(attn_output), present
+
+
+ class MossTTSNanoGPT2Block(nn.Module):
+     def __init__(self, config: GPT2Config, layer_idx: int, attn_implementation: str) -> None:
+         super().__init__()
+         hidden_size = int(config.hidden_size)
+         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+         self.attn = MossTTSNanoGPT2Attention(config, layer_idx=layer_idx, attn_implementation=attn_implementation)
+         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+         self.mlp = MossTTSNanoGPT2MLP(config)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         packed_metadata: Optional[PackedSequenceMetadata] = None,
+         layer_past: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+         use_cache: bool = False,
+     ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
+         attn_output, present = self.attn(
+             self.ln_1(hidden_states),
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             packed_metadata=packed_metadata,
+             layer_past=layer_past,
+             use_cache=use_cache,
+         )
+         hidden_states = hidden_states + attn_output
+         hidden_states = hidden_states + self.mlp(self.ln_2(hidden_states))
+         return hidden_states, present
+
+
+ class MossTTSNanoGPT2Model(nn.Module):
+     def __init__(self, config: GPT2Config, attn_implementation: str = "eager") -> None:
+         super().__init__()
+         self.config = config
+         self.attn_implementation = attn_implementation
+         self.position_embedding_type = str(getattr(config, "position_embedding_type", "absolute")).lower()
+         if self.position_embedding_type not in {"absolute", "rope"}:
+             raise ValueError(f"Unsupported position_embedding_type={self.position_embedding_type!r}")
+         hidden_size = int(config.hidden_size)
+         self.wte = nn.Embedding(config.vocab_size, hidden_size)
+         self.wpe = nn.Embedding(config.n_positions, hidden_size) if self.position_embedding_type == "absolute" else nn.Identity()
+         self.drop = nn.Dropout(config.embd_pdrop)
+         self.h = nn.ModuleList(
+             [MossTTSNanoGPT2Block(config, layer_idx=index, attn_implementation=attn_implementation) for index in range(config.n_layer)]
+         )
+         self.ln_f = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+         self.gradient_checkpointing = False
+         self._reset_parameters()
+
+     def _reset_parameters(self) -> None:
+         init_std = float(self.config.initializer_range)
+         for module in self.modules():
+             if isinstance(module, nn.Linear):
+                 nn.init.normal_(module.weight, mean=0.0, std=init_std)
+                 if module.bias is not None:
+                     nn.init.zeros_(module.bias)
+             elif isinstance(module, nn.Embedding):
+                 nn.init.normal_(module.weight, mean=0.0, std=init_std)
+             elif isinstance(module, nn.LayerNorm):
+                 nn.init.ones_(module.weight)
+                 nn.init.zeros_(module.bias)
+
+     @staticmethod
+     def _normalize_num_sequences(
+         cu_seqlens: torch.Tensor,
+         num_sequences: Optional[torch.Tensor],
+         device: torch.device,
+     ) -> torch.Tensor:
+         if cu_seqlens.ndim == 1:
+             cu_seqlens = cu_seqlens.unsqueeze(0)
+         if num_sequences is None:
+             diffs = cu_seqlens[:, 1:] - cu_seqlens[:, :-1]
+             return diffs.gt(0).sum(dim=-1).to(device=device, dtype=torch.long)
+         if num_sequences.ndim == 0:
+             num_sequences = num_sequences.unsqueeze(0)
+         return num_sequences.to(device=device, dtype=torch.long)
+
+     @staticmethod
+     def _packed_segments_from_cu_seqlens(
+         cu_seqlens: torch.Tensor,
+         num_sequences: Optional[torch.Tensor],
+         device: torch.device,
+     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         if cu_seqlens.ndim == 1:
+             cu_seqlens = cu_seqlens.unsqueeze(0)
+         cu_seqlens = cu_seqlens.to(device=device)
+         batch_size, boundary_count = cu_seqlens.shape
+         segment_slots = boundary_count - 1
+         if segment_slots <= 0:
+             empty = torch.empty(0, dtype=torch.long, device=device)
+             return empty, empty, empty
+
+         counts = MossTTSNanoGPT2Model._normalize_num_sequences(cu_seqlens, num_sequences, device=device)
+         counts = counts.clamp(min=0, max=segment_slots)
+         segment_slots = int(counts.max().item()) if counts.numel() > 0 else 0
+         if segment_slots <= 0:
+             empty = torch.empty(0, dtype=torch.long, device=device)
+             return empty, empty, empty
+         cu_seqlens = cu_seqlens[:, : segment_slots + 1]
+         slot_ids = torch.arange(segment_slots, device=device).unsqueeze(0)
+         valid_slots = slot_ids < counts.unsqueeze(1)
+
+         starts = cu_seqlens[:, :-1].to(dtype=torch.long)
+         ends = cu_seqlens[:, 1:].to(dtype=torch.long)
+         lengths = (ends - starts).clamp_min(0)
+         lengths = torch.where(valid_slots, lengths, torch.zeros((), dtype=torch.long, device=device))
+
+         batch_ids = torch.arange(batch_size, device=device, dtype=torch.long).unsqueeze(1).expand(batch_size, segment_slots)
+         batch_ids = batch_ids.reshape(-1)
+         starts = starts.reshape(-1)
+         lengths = lengths.reshape(-1)
+
+         valid_segments = lengths.gt(0)
+         valid_count = int(valid_segments.to(dtype=torch.long).sum().item())
+         if valid_count <= 0:
+             empty = torch.empty(0, dtype=torch.long, device=device)
+             return empty, empty, empty
+         if valid_count == lengths.numel():
+             return batch_ids, starts, lengths
+
+         valid_order = torch.argsort(valid_segments.to(dtype=torch.long), descending=True, stable=True)[:valid_count]
+         return (
+             batch_ids.index_select(0, valid_order),
+             starts.index_select(0, valid_order),
+             lengths.index_select(0, valid_order),
+         )
+
+     @staticmethod
+     def _packed_token_indices(
+         batch_ids: torch.Tensor,
+         starts: torch.Tensor,
+         lengths: torch.Tensor,
+         seq_len: int,
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         total_tokens = int(lengths.sum().item())
+         if total_tokens <= 0:
+             empty = torch.empty(0, dtype=torch.long, device=lengths.device)
+             return empty, empty
+
+         segment_ids = torch.repeat_interleave(
+             torch.arange(lengths.numel(), device=lengths.device, dtype=torch.long),
+             lengths,
+             output_size=total_tokens,
+         )
+         segment_starts = torch.cumsum(lengths, dim=0) - lengths
+         positions = torch.arange(total_tokens, device=lengths.device, dtype=torch.long) - segment_starts[segment_ids]
+         indices = batch_ids[segment_ids] * seq_len + starts[segment_ids] + positions
+         return indices, positions
+
+     @staticmethod
+     def build_packed_position_ids(
+         attention_mask: Optional[torch.Tensor],
+         cu_seqlens: torch.Tensor,
+         num_sequences: Optional[torch.Tensor],
+         sequence_length: Optional[int] = None,
+     ) -> torch.Tensor:
+         if cu_seqlens.ndim == 1:
+             cu_seqlens = cu_seqlens.unsqueeze(0)
+         batch_size = cu_seqlens.shape[0]
+         seq_len = int(sequence_length or (cu_seqlens.shape[1] - 1))
+         device = cu_seqlens.device
+         position_ids = torch.zeros((batch_size, seq_len), dtype=torch.long, device=device)
+         batch_ids, starts, lengths = MossTTSNanoGPT2Model._packed_segments_from_cu_seqlens(
+             cu_seqlens,
+             num_sequences,
+             device,
+         )
+         if lengths.numel() > 0:
+             indices, positions = MossTTSNanoGPT2Model._packed_token_indices(batch_ids, starts, lengths, seq_len)
+             position_ids.view(-1).scatter_(0, indices, positions)
+         if attention_mask is not None:
+             position_ids = position_ids * attention_mask.to(dtype=position_ids.dtype)
+         return position_ids
+
+     @staticmethod
+     def build_packed_metadata(
+         hidden_states: torch.Tensor,
+         cu_seqlens: torch.Tensor,
+         num_sequences: Optional[torch.Tensor],
+     ) -> PackedSequenceMetadata:
+         if cu_seqlens.ndim == 1:
+             cu_seqlens = cu_seqlens.unsqueeze(0)
+         device = hidden_states.device
+         seq_len = hidden_states.shape[1]
+         batch_ids, starts, lengths = MossTTSNanoGPT2Model._packed_segments_from_cu_seqlens(
+             cu_seqlens,
+             num_sequences,
+             device,
+         )
+         if lengths.numel() == 0:
+             raise ValueError("cu_seqlens did not describe any non-empty packed sequences.")
+
+         indices, _ = MossTTSNanoGPT2Model._packed_token_indices(batch_ids, starts, lengths, seq_len)
+         cumulative = torch.empty(lengths.numel() + 1, dtype=torch.int32, device=device)
+         cumulative[0] = 0
+         cumulative[1:] = lengths.to(dtype=torch.int32).cumsum(dim=0)
+         return PackedSequenceMetadata(
+             cu_seqlens=cumulative,
+             max_seqlen=int(lengths.max().item()),
+             indices=indices,
+             batch_size=hidden_states.shape[0],
+             seq_len=hidden_states.shape[1],
+         )
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[tuple[tuple[torch.Tensor, torch.Tensor], ...]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: bool = True,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        num_sequences: Optional[torch.Tensor] = None,
+    ) -> BaseModelOutputWithPast:
+        del input_ids, output_attentions
+
+        if inputs_embeds is None:
+            raise ValueError("inputs_embeds must be provided.")
+
+        use_cache = bool(use_cache)
+        if use_cache and cu_seqlens is not None:
+            raise ValueError("use_cache=True is not supported together with cu_seqlens packing.")
+
+        hidden_states = inputs_embeds
+        query_attention_mask = None
+        if attention_mask is not None:
+            attention_mask = attention_mask.to(dtype=torch.bool, device=hidden_states.device)
+            query_attention_mask = attention_mask[:, -hidden_states.shape[1] :]
+
+        packed_metadata = None
+        if position_ids is None:
+            if cu_seqlens is not None:
+                if attention_mask is None:
+                    raise ValueError("attention_mask must be provided with cu_seqlens packing.")
+                position_ids = self.build_packed_position_ids(
+                    attention_mask=attention_mask,
+                    cu_seqlens=cu_seqlens.to(device=hidden_states.device),
+                    num_sequences=num_sequences.to(device=hidden_states.device) if num_sequences is not None else None,
+                    sequence_length=hidden_states.shape[1],
+                )
+            elif attention_mask is not None:
+                position_ids = attention_mask.long().cumsum(dim=-1) - 1
+                position_ids = position_ids.masked_fill(~attention_mask, 0)
+                position_ids = position_ids[:, -hidden_states.shape[1] :]
+            else:
+                past_length = 0
+                if past_key_values is not None and len(past_key_values) > 0:
+                    first_layer_past = past_key_values[0]
+                    if _is_static_kv_cache_layer(first_layer_past):
+                        past_length = int(first_layer_past.get("length", 0))
+                    else:
+                        past_length = first_layer_past[0].shape[1]
+                position_ids = torch.arange(hidden_states.shape[1], device=hidden_states.device, dtype=torch.long)
+                position_ids = position_ids + past_length
+                position_ids = position_ids.unsqueeze(0).expand(hidden_states.shape[0], -1)
+
+        if cu_seqlens is not None and self.attn_implementation == "flash_attention_2":
+            packed_metadata = self.build_packed_metadata(
+                hidden_states=hidden_states,
+                cu_seqlens=cu_seqlens.to(device=hidden_states.device),
+                num_sequences=num_sequences.to(device=hidden_states.device) if num_sequences is not None else None,
+            )
+
+        if self.position_embedding_type == "absolute":
+            hidden_states = hidden_states + self.wpe(position_ids)
+        hidden_states = self.drop(hidden_states)
+        if query_attention_mask is not None:
+            hidden_states = hidden_states * query_attention_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
+
+        all_hidden_states = () if output_hidden_states else None
+        presents = [] if use_cache else None
+        for layer_index, block in enumerate(self.h):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                if use_cache:
+                    raise ValueError("use_cache=True is not supported when gradient checkpointing is enabled during training.")
+
+                def custom_forward(*inputs):
+                    output, _ = block(
+                        inputs[0],
+                        attention_mask=inputs[1],
+                        position_ids=inputs[2],
+                        packed_metadata=packed_metadata,
+                        layer_past=None,
+                        use_cache=False,
+                    )
+                    return output
+
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    custom_forward,
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                    use_reentrant=False,
+                )
+                present = None
+            else:
+                hidden_states, present = block(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    packed_metadata=packed_metadata,
+                    layer_past=None if past_key_values is None else past_key_values[layer_index],
+                    use_cache=use_cache,
+                )
+            if query_attention_mask is not None:
+                hidden_states = hidden_states * query_attention_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
+            if presents is not None:
+                presents.append(present)
+
+        hidden_states = self.ln_f(hidden_states)
+        if query_attention_mask is not None:
+            hidden_states = hidden_states * query_attention_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return (hidden_states, tuple(presents) if presents is not None else None, all_hidden_states, None)
+
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=tuple(presents) if presents is not None else None,
+            hidden_states=all_hidden_states,
+            attentions=None,
+        )
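Review note: when `position_ids` is not supplied and no `cu_seqlens` packing is used, the `forward` above derives positions from the attention mask with a cumulative sum, then zeroes the padded slots. The following is a minimal, dependency-free sketch of that arithmetic only. Plain Python lists stand in for tensors, and `position_ids_from_mask` is a hypothetical helper name that does not exist in the repository.

```python
from itertools import accumulate


def position_ids_from_mask(mask_row):
    """Sketch of `mask.long().cumsum(-1) - 1` followed by `masked_fill(~mask, 0)`.

    `mask_row` is one row of an attention mask (truthy = real token, falsy = padding).
    """
    mask_row = [bool(m) for m in mask_row]
    # Running count of real tokens seen so far, one entry per position.
    running = accumulate(int(m) for m in mask_row)
    # Real tokens get (count - 1); padded slots are forced to position 0.
    return [count - 1 if is_real else 0 for count, is_real in zip(running, mask_row)]


# Left-padded row: two pad slots, then three real tokens.
print(position_ids_from_mask([0, 0, 1, 1, 1]))
```

With left padding, the first real token therefore always receives position 0, regardless of how many pad slots precede it.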
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:608f1ff64bc6caa9be836060fc7c78a15c4658c4a07b8d73c78d6f70d1b39c23
+size 9100859544
modeling_moss_tts.py ADDED
@@ -0,0 +1,623 @@
+# coding=utf-8
+"""Modeling code for the MOSS-TTS-Local-Transformer-v1.5 HuggingFace release."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, Optional, Union
+
+import torch
+import torch.nn as nn
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.gpt2.configuration_gpt2 import GPT2Config
+from transformers.utils import ModelOutput
+
+from .configuration_moss_tts import MossTTSLocalConfig
+from .gpt2_decoder import MossTTSNanoGPT2Model
+from .qwen3_decoder import MossQwen3Model
+
+
+@dataclass
+class MossTTSLocalOutput(ModelOutput):
+    last_hidden_state: Optional[torch.FloatTensor] = None
+    past_key_values: Optional[tuple[tuple[torch.Tensor, torch.Tensor], ...]] = None
+    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
+    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
+
+
+def _find_last_equal(input_ids: torch.LongTensor, value: int) -> torch.LongTensor:
+    matches = input_ids.eq(int(value))
+    if not bool(matches.any(dim=1).all().item()):
+        raise ValueError(f"Every sample must contain token id {int(value)}.")
+    positions = torch.arange(input_ids.shape[1], device=input_ids.device, dtype=torch.long)
+    masked_positions = positions.unsqueeze(0).masked_fill(~matches, -1)
+    return masked_positions.max(dim=1).values
+
+
+class MossTTSLocalPreTrainedModel(PreTrainedModel):
+    config_class = MossTTSLocalConfig
+    base_model_prefix = "transformer"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["MossTTSNanoGPT2Block", "MossQwen3DecoderLayer"]
+    _supports_flash_attn_2 = True
+    _supports_sdpa = True
+    _supports_cache_class = True
+
+    def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False) -> None:
+        if isinstance(module, MossTTSNanoGPT2Model) or isinstance(module, MossQwen3Model):
+            module.gradient_checkpointing = value
+
+
+class MossTTSLocalModel(MossTTSLocalPreTrainedModel):
+    _tied_weights_keys = None
+
+    def __init__(self, config: MossTTSLocalConfig) -> None:
+        super().__init__(config)
+        self._tied_weights_keys = self._build_tied_weights_keys(config)
+
+        config.qwen3_config.pad_token_id = config.pad_token_id
+        config.qwen3_config._attn_implementation = config.attn_implementation
+        local_gpt2_config = config.gpt2_config.to_dict()
+        local_gpt2_config["n_layer"] = int(getattr(config, "local_transformer_layers", config.gpt2_config.n_layer))
+        local_gpt2_config["n_positions"] = int(config.n_vq) + 1
+        local_gpt2_config["n_ctx"] = int(config.n_vq) + 1
+        local_gpt2_config = GPT2Config(**local_gpt2_config)
+        local_gpt2_config.pad_token_id = config.pad_token_id
+        local_gpt2_config._attn_implementation = config.local_transformer_attn_implementation
+
+        self.transformer = MossQwen3Model(config.qwen3_config)
+        self.local_transformer = MossTTSNanoGPT2Model(
+            local_gpt2_config,
+            attn_implementation=config.local_transformer_attn_implementation,
+        )
+        self.local_transformer.wte = nn.Identity()
+
+        hidden_size = int(config.hidden_size)
+        self.audio_embeddings = nn.ModuleList(
+            [
+                nn.Embedding(int(config.audio_codebook_sizes[index]), hidden_size)
+                for index in range(config.n_vq)
+            ]
+        )
+        self.text_lm_head = nn.Linear(hidden_size, int(config.vocab_size), bias=False)
+        self.audio_lm_heads = nn.ModuleList(
+            [
+                nn.Linear(hidden_size, int(config.audio_codebook_sizes[index]), bias=False)
+                for index in range(config.n_vq)
+            ]
+        )
+        self.local_text_lm_head = (
+            nn.Linear(hidden_size, 2, bias=False)
+            if self._use_binary_local_text_head()
+            else None
+        )
+
+        self.post_init()
+        self.tie_weights()
+        self.initialize_local_text_lm_head_from_text_lm_head()
+
+    def can_generate(self) -> bool:
+        return True
+
+    @staticmethod
+    def _build_tied_weights_keys(config: MossTTSLocalConfig) -> dict[str, str]:
+        tied_weights = {"text_lm_head.weight": "transformer.embed_tokens.weight"}
+        tied_weights.update(
+            {
+                f"audio_lm_heads.{index}.weight": f"audio_embeddings.{index}.weight"
+                for index in range(config.n_vq)
+            }
+        )
+        return tied_weights
+
+    def tie_weights(self, *args, **kwargs) -> None:
+        del args, kwargs
+        self.text_lm_head.weight = self.transformer.embed_tokens.weight
+        for embedding, head in zip(self.audio_embeddings, self.audio_lm_heads):
+            head.weight = embedding.weight
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.transformer.embed_tokens
+
+    def set_input_embeddings(self, value: nn.Embedding) -> None:
+        self.transformer.embed_tokens = value
+        self.tie_weights()
+        self.initialize_local_text_lm_head_from_text_lm_head()
+
+    def get_output_embeddings(self) -> nn.Linear:
+        return self.text_lm_head
+
+    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
+        self.text_lm_head = new_embeddings
+        self.tie_weights()
+        self.initialize_local_text_lm_head_from_text_lm_head()
+
+    def _use_binary_local_text_head(self) -> bool:
+        return str(getattr(self.config, "local_text_head_mode", "full_vocab")).strip().lower() == "binary"
+
+    def _local_text_candidate_ids(self, device: torch.device) -> torch.LongTensor:
+        return torch.tensor(
+            [
+                int(self.config.audio_assistant_slot_token_id),
+                int(self.config.audio_end_token_id),
+            ],
+            dtype=torch.long,
+            device=device,
+        )
+
+    def initialize_local_text_lm_head_from_text_lm_head(self) -> None:
+        if not self._use_binary_local_text_head() or self.local_text_lm_head is None:
+            return
+        candidate_ids = self._local_text_candidate_ids(self.text_lm_head.weight.device)
+        with torch.no_grad():
+            source_weight = self.text_lm_head.weight.index_select(0, candidate_ids)
+            if tuple(source_weight.shape) == tuple(self.local_text_lm_head.weight.shape):
+                self.local_text_lm_head.weight.copy_(
+                    source_weight.to(
+                        device=self.local_text_lm_head.weight.device,
+                        dtype=self.local_text_lm_head.weight.dtype,
+                    )
+                )
+
+    def _resolve_fixed_nq(
+        self,
+        n_vq_for_inference: Optional[int] = None,
+        nq: Optional[int] = None,
+    ) -> int:
+        requested = n_vq_for_inference if n_vq_for_inference is not None else nq
+        config_nq = int(self.config.n_vq)
+        if requested is not None and int(requested) != config_nq:
+            raise ValueError(
+                "This MOSS-TTS-Local-Transformer-v1.5 release is trained with a fixed RVQ depth. "
+                f"Expected n_vq={config_nq}, got {int(requested)}."
+            )
+        return config_nq
+
+    def _build_inputs_embeds(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
+        if input_ids.ndim != 3 or input_ids.shape[-1] != self.config.n_vq + 1:
+            raise ValueError(
+                f"Expected input_ids shape [batch, seq, {self.config.n_vq + 1}], "
+                f"got {tuple(input_ids.shape)}."
+            )
+        text_ids = input_ids[..., 0]
+        inputs_embeds = self.transformer.embed_tokens(text_ids)
+        for channel_index, embedding in enumerate(self.audio_embeddings):
+            channel_ids = input_ids[..., channel_index + 1]
+            valid_mask = channel_ids.ne(self.config.audio_pad_token_id)
+            safe_ids = channel_ids.masked_fill(~valid_mask, 0)
+            audio_embeds = embedding(safe_ids) * valid_mask.unsqueeze(-1)
+            inputs_embeds = inputs_embeds + audio_embeds
+        return inputs_embeds
+
+    def _global_hidden_to_local(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+        return hidden_states
+
+    @staticmethod
+    def _local_past_length(past_key_values: Optional[tuple[Any, ...]]) -> int:
+        if past_key_values is None or len(past_key_values) == 0:
+            return 0
+        first_layer_past = past_key_values[0]
+        if isinstance(first_layer_past, dict) and bool(first_layer_past.get("static_kv_cache", False)):
+            return int(first_layer_past.get("length", 0))
+        return int(first_layer_past[0].shape[1])
+
+    def _new_static_local_past_key_values(
+        self,
+        batch_size: int,
+        max_length: int,
+        device: torch.device,
+        dtype: torch.dtype,
+    ) -> tuple[dict[str, Any], ...]:
+        layers = []
+        for block in self.local_transformer.h:
+            attn = block.attn
+            cache_shape = (
+                int(batch_size),
+                int(max_length),
+                int(attn.num_heads),
+                int(attn.head_dim),
+            )
+            layers.append(
+                {
+                    "static_kv_cache": True,
+                    "key": torch.empty(cache_shape, device=device, dtype=dtype),
+                    "value": torch.empty(cache_shape, device=device, dtype=dtype),
+                    "length": 0,
+                }
+            )
+        return tuple(layers)
+
+    def _decode_local_hidden_states_with_cache(
+        self,
+        local_inputs_embeds: torch.FloatTensor,
+        past_key_values: Optional[tuple[Any, ...]] = None,
+    ) -> tuple[torch.FloatTensor, Optional[tuple[Any, ...]]]:
+        if (
+            past_key_values is None
+            and not self.training
+            and bool(getattr(self.config, "use_static_local_kv_cache", True))
+        ):
+            max_length = max(int(getattr(self.config, "n_vq", 0)) + 1, int(local_inputs_embeds.shape[1]))
+            past_key_values = self._new_static_local_past_key_values(
+                batch_size=int(local_inputs_embeds.shape[0]),
+                max_length=max_length,
+                device=local_inputs_embeds.device,
+                dtype=local_inputs_embeds.dtype,
+            )
+        past_length = self._local_past_length(past_key_values)
+        local_seq_len = int(local_inputs_embeds.shape[1])
+        local_position_ids = torch.arange(
+            past_length,
+            past_length + local_seq_len,
+            device=local_inputs_embeds.device,
+            dtype=torch.long,
+        ).unsqueeze(0)
+        if int(local_inputs_embeds.shape[0]) != 1:
+            local_position_ids = local_position_ids.expand(int(local_inputs_embeds.shape[0]), -1)
+        local_outputs = self.local_transformer(
+            input_ids=None,
+            past_key_values=past_key_values,
+            attention_mask=None,
+            position_ids=local_position_ids,
+            inputs_embeds=local_inputs_embeds,
+            use_cache=True,
+            output_attentions=False,
+            output_hidden_states=False,
+            return_dict=True,
+            cu_seqlens=None,
+            num_sequences=None,
+        )
+        return local_outputs.last_hidden_state, local_outputs.past_key_values
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[tuple[tuple[torch.Tensor, torch.Tensor], ...]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = True,
+        **kwargs,
+    ) -> Union[tuple, MossTTSLocalOutput]:
+        del kwargs
+        if inputs_embeds is None:
+            if input_ids is None:
+                raise ValueError("Either input_ids or inputs_embeds must be provided.")
+            inputs_embeds = self._build_inputs_embeds(input_ids)
+        outputs = self.transformer(
+            input_ids=None,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=True,
+            cu_seqlens=None,
+            num_sequences=None,
+        )
+        if not return_dict:
+            return (
+                outputs.last_hidden_state,
+                outputs.past_key_values,
+                outputs.hidden_states,
+                outputs.attentions,
+            )
+        return MossTTSLocalOutput(
+            last_hidden_state=outputs.last_hidden_state,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def _decode_local_last_hidden_state(
+        self,
+        local_inputs_embeds: torch.FloatTensor,
+    ) -> torch.FloatTensor:
+        local_seq_len = int(local_inputs_embeds.shape[1])
+        local_position_ids = torch.arange(
+            0,
+            local_seq_len,
+            device=local_inputs_embeds.device,
+            dtype=torch.long,
+        ).unsqueeze(0)
+        if int(local_inputs_embeds.shape[0]) != 1:
+            local_position_ids = local_position_ids.expand(int(local_inputs_embeds.shape[0]), -1)
+        local_outputs = self.local_transformer(
+            input_ids=None,
+            attention_mask=None,
+            position_ids=local_position_ids,
+            inputs_embeds=local_inputs_embeds,
+            use_cache=False,
+            output_attentions=False,
+            output_hidden_states=False,
+            return_dict=True,
+            cu_seqlens=None,
+            num_sequences=None,
+        )
+        return local_outputs.last_hidden_state[:, -1, :]
+
+    def _filter_logits(
+        self,
+        logits: torch.FloatTensor,
+        top_k: Optional[int],
+        top_p: Optional[float],
+    ) -> torch.FloatTensor:
+        scores = logits
+        if top_k is not None and int(top_k) > 0 and int(top_k) < scores.shape[-1]:
+            kth = torch.topk(scores, int(top_k), dim=-1).values[..., -1, None]
+            scores = scores.masked_fill(scores < kth, -torch.inf)
+        if top_p is not None and 0.0 < float(top_p) < 1.0:
+            sorted_scores, sorted_indices = torch.sort(scores, descending=True, dim=-1)
+            sorted_probs = torch.softmax(sorted_scores, dim=-1)
+            cumulative_probs = sorted_probs.cumsum(dim=-1)
+            sorted_mask = cumulative_probs > float(top_p)
+            sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
+            sorted_mask[..., 0] = False
+            remove_mask = torch.zeros_like(scores, dtype=torch.bool)
+            remove_mask.scatter_(dim=-1, index=sorted_indices, src=sorted_mask)
+            scores = scores.masked_fill(remove_mask, -torch.inf)
+        return scores
+
+    def _apply_repetition_penalty(
+        self,
+        scores: torch.FloatTensor,
+        previous_token_ids: Optional[torch.LongTensor],
+        penalty: float,
+    ) -> torch.FloatTensor:
+        if previous_token_ids is None or float(penalty) == 1.0:
+            return scores
+        if previous_token_ids.ndim == 1:
+            previous_token_ids = previous_token_ids.unsqueeze(0)
+        updated = scores.clone()
+        for batch_index in range(updated.shape[0]):
+            unique_token_ids = torch.unique(previous_token_ids[batch_index])
+            unique_token_ids = unique_token_ids[
+                (unique_token_ids >= 0) & (unique_token_ids < updated.shape[-1])
+            ]
+            if unique_token_ids.numel() == 0:
+                continue
+            token_scores = updated[batch_index].index_select(0, unique_token_ids)
+            token_scores = torch.where(
+                token_scores < 0,
+                token_scores * float(penalty),
+                token_scores / float(penalty),
+            )
+            updated[batch_index].scatter_(0, unique_token_ids, token_scores)
+        return updated
+
+    def _sample_next_token(
+        self,
+        logits: torch.FloatTensor,
+        do_sample: bool,
+        temperature: float,
+        top_k: Optional[int],
+        top_p: Optional[float],
+        previous_token_ids: Optional[torch.LongTensor] = None,
+        repetition_penalty: float = 1.0,
+    ) -> torch.LongTensor:
+        scores = logits.float()
+        scores = self._apply_repetition_penalty(scores, previous_token_ids, repetition_penalty)
+        if not do_sample:
+            return torch.argmax(scores, dim=-1)
+        if float(temperature) <= 0:
+            raise ValueError("temperature must be positive when do_sample=True.")
+        scores = scores / float(temperature)
+        scores = self._filter_logits(scores, top_k=top_k, top_p=top_p)
+        probs = torch.softmax(scores, dim=-1)
+        return torch.multinomial(probs, num_samples=1).squeeze(-1)
+
+    def _sample_next_assistant_text_token(
+        self,
+        local_hidden_states: torch.FloatTensor,
+        do_sample: bool,
+        temperature: float,
+        top_k: Optional[int],
+        top_p: Optional[float],
+    ) -> torch.LongTensor:
+        if self._use_binary_local_text_head() and self.local_text_lm_head is not None:
+            logits = self.local_text_lm_head(local_hidden_states)
+            sampled_indices = self._sample_next_token(
+                logits=logits,
+                do_sample=do_sample,
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+            )
+            candidate_ids = self._local_text_candidate_ids(logits.device)
+            return candidate_ids[sampled_indices]
+
+        candidate_ids = self._local_text_candidate_ids(local_hidden_states.device)
+        logits = self.text_lm_head(local_hidden_states).index_select(dim=-1, index=candidate_ids)
+        sampled_indices = self._sample_next_token(
+            logits=logits,
+            do_sample=do_sample,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+        )
+        return candidate_ids[sampled_indices]
+
+    def _build_generation_row(
+        self,
+        batch_size: int,
+        device: torch.device,
+        audio_token_ids: torch.LongTensor,
+    ) -> torch.LongTensor:
+        row = torch.full(
+            (batch_size, 1, self.config.n_vq + 1),
+            int(self.config.audio_pad_token_id),
+            dtype=torch.long,
+            device=device,
+        )
+        row[:, :, 0] = int(self.config.audio_assistant_slot_token_id)
+        row[:, :, 1:] = audio_token_ids.unsqueeze(1)
+        return row
+
+    @torch.inference_mode()
+    def generate(
+        self,
+        input_ids: torch.LongTensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        max_new_tokens: Optional[int] = None,
+        max_new_frames: Optional[int] = None,
+        do_sample: bool = True,
+        text_temperature: float = 1.0,
+        text_top_p: float = 1.0,
+        text_top_k: int = 50,
+        audio_temperature: Optional[float] = None,
+        audio_top_p: Optional[float] = None,
+        audio_top_k: Optional[int] = None,
+        audio_repetition_penalty: Optional[float] = None,
+        temperature: float = 1.0,
+        top_p: float = 0.95,
+        top_k: int = 50,
+        repetition_penalty: float = 1.0,
+        use_kv_cache: bool = True,
+        n_vq_for_inference: Optional[int] = None,
+        nq: Optional[int] = None,
+        **kwargs,
+    ) -> list[tuple[int, torch.LongTensor]]:
+        del kwargs
+        self._resolve_fixed_nq(n_vq_for_inference=n_vq_for_inference, nq=nq)
+
+        if input_ids.ndim == 2:
+            input_ids = input_ids.unsqueeze(0)
+        if input_ids.ndim != 3:
+            raise ValueError(f"Expected input_ids with 3 dims, got {tuple(input_ids.shape)}.")
+        if input_ids.shape[-1] != self.config.n_vq + 1:
+            raise ValueError(
+                f"Expected {self.config.n_vq + 1} channels from config.n_vq, got {input_ids.shape[-1]}."
+            )
+        if attention_mask is None:
+            attention_mask = torch.ones(input_ids.shape[:2], dtype=torch.bool, device=input_ids.device)
+        elif attention_mask.ndim == 1:
+            attention_mask = attention_mask.unsqueeze(0)
+        attention_mask = attention_mask.to(device=input_ids.device, dtype=torch.bool)
+
+        frame_budget = max_new_frames if max_new_frames is not None else max_new_tokens
+        if frame_budget is None:
+            frame_budget = 4096
+        frame_budget = int(frame_budget)
+
+        audio_temperature = float(temperature if audio_temperature is None else audio_temperature)
+        audio_top_p = float(top_p if audio_top_p is None else audio_top_p)
+        audio_top_k = int(top_k if audio_top_k is None else audio_top_k)
+        audio_repetition_penalty = float(
+            repetition_penalty if audio_repetition_penalty is None else audio_repetition_penalty
+        )
+
+        batch_size = input_ids.shape[0]
+        input_ids_length = input_ids.shape[1]
+        current_input_ids = input_ids
+        current_attention_mask = attention_mask
+        current_model_input_ids = current_input_ids
+        generated_frames: list[torch.LongTensor] = []
+        finished = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
+        past_key_values = None
+        local_dtype = self.local_transformer.ln_f.weight.dtype
+
+        for _ in range(frame_budget):
+            generated_audio_history = torch.stack(generated_frames, dim=1) if generated_frames else None
+            global_inputs_embeds = self._build_inputs_embeds(current_model_input_ids)
+            global_outputs = self.transformer(
+                input_ids=None,
+                past_key_values=past_key_values,
+                attention_mask=current_attention_mask,
+                position_ids=None,
+                inputs_embeds=global_inputs_embeds,
+                use_cache=use_kv_cache,
+                output_attentions=False,
+                output_hidden_states=False,
+                return_dict=True,
+                cu_seqlens=None,
+                num_sequences=None,
+            )
+            global_hidden_states = global_outputs.last_hidden_state[:, -1, :]
+            local_global_hidden_states = self._global_hidden_to_local(global_hidden_states).to(dtype=local_dtype)
+
+            local_prefix_hidden_states, local_prefix_past_key_values = self._decode_local_hidden_states_with_cache(
+                local_global_hidden_states.unsqueeze(1)
+            )
+            local_hidden_states = local_prefix_hidden_states[:, -1, :]
+            next_text_tokens = self._sample_next_assistant_text_token(
+                local_hidden_states=local_hidden_states,
+                do_sample=do_sample,
+                temperature=text_temperature,
+                top_k=text_top_k,
+                top_p=text_top_p,
+            )
+            should_continue = next_text_tokens.eq(int(self.config.audio_assistant_slot_token_id)) & ~finished
+            finished = finished | next_text_tokens.eq(int(self.config.audio_end_token_id))
+            if not bool(should_continue.any().item()):
+                break
+
+            next_frame_tokens = []
+            for channel_index in range(int(self.config.n_vq)):
+                channel_logits = self.audio_lm_heads[channel_index](local_hidden_states)
+                channel_token = self._sample_next_token(
+                    logits=channel_logits,
+                    do_sample=do_sample,
+                    temperature=audio_temperature,
+                    top_k=audio_top_k,
+                    top_p=audio_top_p,
+                    previous_token_ids=(
+                        None
+                        if generated_audio_history is None
+                        else generated_audio_history[:, :, channel_index]
+                    ),
+                    repetition_penalty=audio_repetition_penalty,
+                )
+                next_frame_tokens.append(channel_token)
+                if channel_index + 1 < int(self.config.n_vq):
+                    current_local_input = self.audio_embeddings[channel_index](channel_token).to(dtype=local_dtype)
+                    local_token_hidden_states, local_prefix_past_key_values = (
+                        self._decode_local_hidden_states_with_cache(
+                            current_local_input.unsqueeze(1),
+                            past_key_values=local_prefix_past_key_values,
+                        )
+                    )
+                    local_hidden_states = local_token_hidden_states[:, -1, :]
+
+            next_frame = torch.stack(next_frame_tokens, dim=-1)
+            next_frame = next_frame.masked_fill(
+                ~should_continue.unsqueeze(-1),
+                int(self.config.audio_pad_token_id),
+            )
+            generated_frames.append(next_frame)
+
+            next_row = self._build_generation_row(
+                batch_size=batch_size,
+                device=input_ids.device,
+                audio_token_ids=next_frame,
+            )
+            if bool((~should_continue).any().item()):
+                next_row[~should_continue, 0, 0] = int(self.config.pad_token_id)
+                next_row[~should_continue, 0, 1:] = int(self.config.audio_pad_token_id)
+
+            current_input_ids = torch.cat([current_input_ids, next_row], dim=1)
+            current_attention_mask = torch.cat(
+                [current_attention_mask, should_continue.unsqueeze(1)],
+                dim=1,
+            )
+            if use_kv_cache:
+                current_model_input_ids = next_row
+                past_key_values = global_outputs.past_key_values
+            else:
+                current_model_input_ids = current_input_ids
+
+        start_indices = _find_last_equal(input_ids[..., 0], int(self.config.audio_start_token_id))
+        start_lengths = input_ids_length - start_indices - 1
+        outputs: list[tuple[int, torch.LongTensor]] = []
+        for start_index, start_length, generation_ids in zip(
+            start_indices.tolist(),
+            start_lengths.tolist(),
+            current_input_ids,
+        ):
+            outputs.append((int(start_length), generation_ids[int(start_index):].detach().cpu()))
+        return outputs
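Review note: the sampling path in `modeling_moss_tts.py` applies a repetition penalty, then temperature, then top-k and top-p (nucleus) filtering before drawing a token. The sketch below restates the penalty and filtering rules on plain Python lists so they can be checked by hand. It is an illustration under stated assumptions, not the repository's code: `apply_repetition_penalty` and `filter_logits` are hypothetical stand-alone names, a single unbatched score row is assumed, and no tensors are involved.

```python
import math


def apply_repetition_penalty(scores, previous_ids, penalty):
    """Penalize ids already generated: negative scores are multiplied, others divided."""
    updated = list(scores)
    if penalty == 1.0:
        return updated
    for token_id in set(previous_ids):
        # Ids outside the vocabulary (e.g. padding values) are ignored.
        if 0 <= token_id < len(updated):
            score = updated[token_id]
            updated[token_id] = score * penalty if score < 0 else score / penalty
    return updated


def filter_logits(scores, top_k=None, top_p=None):
    """Mask scores outside the top-k set and outside the top-p nucleus with -inf."""
    scores = list(scores)
    if top_k is not None and 0 < top_k < len(scores):
        kth = sorted(scores, reverse=True)[top_k - 1]
        scores = [s if s >= kth else -math.inf for s in scores]
    if top_p is not None and 0.0 < top_p < 1.0:
        order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        peak = scores[order[0]]
        weights = [math.exp(scores[i] - peak) for i in order]  # numerically stable softmax
        total = sum(weights)
        cumulative = 0.0
        exceeded = False  # did the mass of higher-ranked tokens already pass top_p?
        for rank, index in enumerate(order):
            if exceeded:
                scores[index] = -math.inf
            cumulative += weights[rank] / total
            exceeded = cumulative > top_p
    return scores


print(filter_logits([2.0, 1.0, 0.0, -1.0], top_k=2))
```

The one-step shift in the nucleus loop keeps the token that crosses the `top_p` threshold, which matches the shifted mask in `_filter_logits` and guarantees that at least the highest-scoring token survives.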
processing_moss_tts.py ADDED
@@ -0,0 +1,899 @@
1
+ # coding=utf-8
2
+ """Processor for the MOSS-TTS-Local-Transformer-v1.5 HuggingFace release."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import os
7
+ import re
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
11
+
12
+ import torch
13
+ import torchaudio
14
+ from transformers import (
15
+ AutoConfig,
16
+ AutoModel,
17
+ AutoTokenizer,
18
+ BatchFeature,
19
+ PreTrainedTokenizerBase,
20
+ ProcessorMixin,
21
+ logging,
22
+ processing_utils,
23
+ )
24
+
25
+ from .configuration_moss_tts import MossTTSLocalConfig
26
+
27
+
28
+ if hasattr(processing_utils, "MODALITY_TO_BASE_CLASS_MAPPING"):
29
+ processing_utils.MODALITY_TO_BASE_CLASS_MAPPING["audio_tokenizer"] = "PreTrainedModel"
30
+ else:
31
+ processing_utils.AUTO_TO_BASE_CLASS_MAPPING["AutoModel"] = "PreTrainedModel"
32
+ logger = logging.get_logger(__name__)
33
+
34
+ AUDIO_PLACEHOLDER = "<|audio|>"
35
+ USER_ROLE_PREFIX = "user\n"
36
+ USER_TEMPLATE_REFERENCE_PREFIX = (
37
+ "<user_inst>\n"
38
+ "- Reference(s):\n"
39
+ )
40
+ USER_TEMPLATE_AFTER_REFERENCE_SUFFIX = (
41
+ "\n"
42
+ "- Text:\n"
43
+ )
44
+ USER_TEMPLATE_SUFFIX = "\n</user_inst>"
45
+ ASSISTANT_TURN_PREFIX = "\n"
46
+ ASSISTANT_ROLE_PREFIX = "assistant\n"
47
+ USER_MESSAGE_FIELDS = (
48
+ "text",
49
+ "reference",
50
+ "instruction",
51
+ "tokens",
52
+ "quality",
53
+ "sound_event",
54
+ "ambient_sound",
55
+ "language",
56
+ )
57
+
58
+
59
+ def _normalize_template_value(value: Any) -> str:
60
+ if value is None:
61
+ return "None"
62
+ resolved = str(value).strip()
63
+ return resolved or "None"
64
+
65
+
66
+ def _render_user_prompt_after_reference(
67
+ language_code: object | None = None,
68
+ prompt_fields: Optional[Dict[str, Any]] = None,
69
+ ) -> str:
70
+ fields = dict(prompt_fields or {})
71
+ return (
72
+ "\n- Instruction:\n"
73
+ + _normalize_template_value(fields.get("instruction"))
74
+ + "\n- Tokens:\n"
75
+ + _normalize_template_value(fields.get("tokens"))
76
+ + "\n- Quality:\n"
77
+ + _normalize_template_value(fields.get("quality"))
78
+ + "\n- Sound Event:\n"
79
+ + _normalize_template_value(fields.get("sound_event"))
80
+ + "\n- Ambient Sound:\n"
81
+ + _normalize_template_value(fields.get("ambient_sound"))
82
+ + "\n- Language:\n"
83
+ + _normalize_template_value(fields.get("language", language_code))
84
+ + USER_TEMPLATE_AFTER_REFERENCE_SUFFIX
85
+ )
86
+
87
+
88
+ @dataclass
89
+ class Message:
90
+ def to_dict(self) -> Dict[str, Any]:
91
+ raise NotImplementedError
92
+
93
+
94
+ @dataclass
95
+ class UserMessage(Message):
96
+ text: Optional[str] = None
97
+ reference: Optional[List[Optional[Union[str, os.PathLike, torch.Tensor]]]] = None
98
+ instruction: Optional[str] = None
99
+ tokens: Optional[int] = None
100
+ quality: Optional[str] = None
101
+ sound_event: Optional[str] = None
102
+ ambient_sound: Optional[str] = None
103
+ language: Optional[str] = None
104
+
105
+ def __post_init__(self) -> None:
106
+ template = """<user_inst>
107
+ - Reference(s):
108
+ {reference}
109
+ - Instruction:
110
+ {instruction}
111
+ - Tokens:
112
+ {tokens}
113
+ - Quality:
114
+ {quality}
115
+ - Sound Event:
116
+ {sound_event}
117
+ - Ambient Sound:
118
+ {ambient_sound}
119
+ - Language:
120
+ {language}
121
+ - Text:
122
+ {text}
123
+ </user_inst>"""
124
+
125
+ audio_codes_list: list[Union[str, os.PathLike, torch.Tensor]] = []
126
+ if self.reference is None:
127
+ reference = "None"
128
+ else:
129
+ reference_items: list[str] = []
130
+ for speaker_reference in self.reference:
131
+ if speaker_reference is None:
132
+ continue
133
+ # Keep raw audio placeholders directly under "- Reference(s):".
134
+ # Speaker labels such as "[S1]:" change the token sequence and
135
+ # can affect voice-clone conditioning.
136
+ reference_items.append(AUDIO_PLACEHOLDER)
137
+ audio_codes_list.append(speaker_reference)
138
+ reference = "\n".join(reference_items) if reference_items else "None"
139
+
140
+ self._content = (
141
+ template.replace("{reference}", str(reference))
142
+ .replace("{instruction}", str(self.instruction))
143
+ .replace("{tokens}", str(self.tokens))
144
+ .replace("{quality}", str(self.quality))
145
+ .replace("{sound_event}", str(self.sound_event))
146
+ .replace("{ambient_sound}", str(self.ambient_sound))
147
+ .replace("{language}", str(self.language))
148
+ .replace("{text}", str(self.text))
149
+ )
150
+ self._audio_codes_list = audio_codes_list
151
+
152
+ def to_dict(self) -> Dict[str, Any]:
153
+ return {
154
+ "role": "user",
155
+ "content": self._content,
156
+ "audio_codes_list": self._audio_codes_list,
157
+ "text": self.text,
158
+ "instruction": self.instruction,
159
+ "tokens": self.tokens,
160
+ "quality": self.quality,
161
+ "sound_event": self.sound_event,
162
+ "ambient_sound": self.ambient_sound,
163
+ "language": self.language,
164
+ }
165
+
166
+
167
+ @dataclass
168
+ class AssistantMessage(Message):
169
+ audio_codes_list: List[Union[str, os.PathLike, torch.Tensor]]
170
+ content: str = AUDIO_PLACEHOLDER
171
+
172
+ def to_dict(self) -> Dict[str, Any]:
173
+ return {
174
+ "role": "assistant",
175
+ "content": self.content,
176
+ "audio_codes_list": self.audio_codes_list,
177
+ }
178
+
179
+
180
+ class MossTTSLocalProcessor(ProcessorMixin):
181
+ attributes = ["tokenizer"]
182
+ tokenizer_class = "AutoTokenizer"
183
+ audio_tokenizer_class = "AutoModel"
184
+
185
+ tokenizer: PreTrainedTokenizerBase
186
+ audio_tokenizer: Any
187
+
188
+ def __init__(
189
+ self,
190
+ tokenizer: PreTrainedTokenizerBase,
191
+ audio_tokenizer: Any = None,
192
+ model_config: Optional[MossTTSLocalConfig] = None,
193
+ **kwargs,
194
+ ) -> None:
195
+ super().__init__(tokenizer=tokenizer, audio_tokenizer=audio_tokenizer, **kwargs)
196
+ self.tokenizer = tokenizer
197
+ self.audio_tokenizer = audio_tokenizer
198
+ self.model_config = model_config or MossTTSLocalConfig()
199
+
200
+ def _id_to_token(token_id: int) -> str:
201
+ token = tokenizer.convert_ids_to_tokens(int(token_id))
202
+ if isinstance(token, list):
203
+ return token[0] if token else ""
204
+ return cast(str, token)
205
+
206
+ self.audio_user_slot_token = _id_to_token(self.model_config.audio_user_slot_token_id)
207
+ self.audio_assistant_slot_token = _id_to_token(self.model_config.audio_assistant_slot_token_id)
208
+ self.audio_start_token = _id_to_token(self.model_config.audio_start_token_id)
209
+ self.audio_end_token = _id_to_token(self.model_config.audio_end_token_id)
210
+
211
+ @classmethod
212
+ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
213
+ trust_remote_code = kwargs.pop("trust_remote_code", True)
214
+ kwargs.pop("_from_auto", None)
215
+ codec_path = kwargs.pop("codec_path", None)
216
+
217
+ model_ref = Path(str(pretrained_model_name_or_path))
218
+ model_ref_or_name = model_ref if model_ref.exists() else pretrained_model_name_or_path
219
+ model_config = cast(
220
+ MossTTSLocalConfig,
221
+ AutoConfig.from_pretrained(
222
+ model_ref_or_name,
223
+ *args,
224
+ trust_remote_code=trust_remote_code,
225
+ **kwargs,
226
+ ),
227
+ )
228
+
229
+ if codec_path is None:
230
+ try:
231
+ processor_dict, _ = cls.get_processor_dict(
232
+ pretrained_model_name_or_path,
233
+ **dict(kwargs),
234
+ )
235
+ codec_path = processor_dict.get("audio_tokenizer_name_or_path")
236
+ audio_tokenizer_dict = processor_dict.get("audio_tokenizer", {})
237
+ if isinstance(audio_tokenizer_dict, dict):
238
+ codec_path = audio_tokenizer_dict.get("audio_tokenizer_name_or_path") or codec_path
239
+ except Exception:
240
+ codec_path = None
241
+ if codec_path is None:
242
+ codec_path = getattr(model_config, "audio_tokenizer_name_or_path", None)
243
+ if codec_path is None:
244
+ codec_path = "OpenMOSS-Team/MOSS-Audio-Tokenizer-v2"
245
+
246
+ tokenizer = AutoTokenizer.from_pretrained(
247
+ model_ref_or_name,
248
+ *args,
249
+ trust_remote_code=trust_remote_code,
250
+ **kwargs,
251
+ )
252
+ audio_tokenizer = AutoModel.from_pretrained(
253
+ codec_path,
254
+ trust_remote_code=trust_remote_code,
255
+ **kwargs,
256
+ )
257
+ return cls(
258
+ tokenizer=tokenizer,
259
+ audio_tokenizer=audio_tokenizer,
260
+ model_config=model_config,
261
+ **kwargs,
262
+ )
263
+
264
+ @staticmethod
265
+ def build_user_message(
266
+ text: Optional[str] = None,
267
+ reference: Optional[List[Optional[Union[str, os.PathLike, torch.Tensor]]]] = None,
268
+ instruction: Optional[str] = None,
269
+ tokens: Optional[int] = None,
270
+ quality: Optional[str] = None,
271
+ sound_event: Optional[str] = None,
272
+ ambient_sound: Optional[str] = None,
273
+ language: Optional[str] = None,
274
+ ) -> Dict[str, Any]:
275
+ if reference is not None and not isinstance(reference, list):
276
+ reference = [reference]
277
+ return UserMessage(
278
+ text=text,
279
+ reference=reference,
280
+ instruction=instruction,
281
+ tokens=tokens,
282
+ quality=quality,
283
+ sound_event=sound_event,
284
+ ambient_sound=ambient_sound,
285
+ language=language,
286
+ ).to_dict()
287
+
288
+ @staticmethod
289
+ def build_assistant_message(
290
+ audio_codes_list: List[Union[str, os.PathLike, torch.Tensor]],
291
+ content: str = AUDIO_PLACEHOLDER,
292
+ ) -> Dict[str, Any]:
293
+ return AssistantMessage(audio_codes_list=audio_codes_list, content=content).to_dict()
294
+
295
+ def _assert_fixed_nq(self, n_vq: Optional[int]) -> int:
296
+ config_nq = int(self.model_config.n_vq)
297
+ if n_vq is not None and int(n_vq) != config_nq:
298
+ raise ValueError(
299
+ "This MOSS-TTS-Local-Transformer-v1.5 release uses the RVQ depth stored in the model config. "
300
+ f"Expected n_vq={config_nq}, got {int(n_vq)}."
301
+ )
302
+ return config_nq
303
+
304
+ def _encode_text(self, text: str) -> list[int]:
305
+ try:
306
+ return list(self.tokenizer.encode(text, add_special_tokens=False))
307
+ except TypeError:
308
+ return list(self.tokenizer.encode(text))
309
+
310
+ def _build_text_rows(self, token_ids: Sequence[int], *, device: Optional[torch.device] = None) -> torch.Tensor:
311
+ rows = torch.full(
312
+ (len(token_ids), int(self.model_config.n_vq) + 1),
313
+ int(self.model_config.audio_pad_token_id),
314
+ dtype=torch.long,
315
+ device=device,
316
+ )
317
+ if token_ids:
318
+ rows[:, 0] = torch.tensor([int(token_id) for token_id in token_ids], dtype=torch.long, device=rows.device)
319
+ return rows
320
+
321
+ def _build_audio_rows(self, audio_tokens: torch.Tensor, slot_token_id: int) -> torch.Tensor:
322
+ rows = torch.full(
323
+ (int(audio_tokens.shape[0]), int(self.model_config.n_vq) + 1),
324
+ int(self.model_config.audio_pad_token_id),
325
+ dtype=torch.long,
326
+ device=audio_tokens.device,
327
+ )
328
+ if rows.shape[0] > 0:
329
+ rows[:, 0] = int(slot_token_id)
330
+ rows[:, 1:] = audio_tokens.to(dtype=torch.long)
331
+ return rows
332
+
333
+ def _user_prompt_prefix_ids(self) -> list[int]:
334
+ return (
335
+ [int(self.model_config.im_start_token_id)]
336
+ + self._encode_text(USER_ROLE_PREFIX)
337
+ + self._encode_text(USER_TEMPLATE_REFERENCE_PREFIX)
338
+ )
339
+
340
+ def _user_prompt_after_reference_ids(
341
+ self,
342
+ language_code: object | None,
343
+ prompt_fields: Optional[Dict[str, Any]],
344
+ ) -> list[int]:
345
+ return self._encode_text(
346
+ _render_user_prompt_after_reference(
347
+ language_code=language_code,
348
+ prompt_fields=prompt_fields,
349
+ )
350
+ )
351
+
352
+ def _assistant_prompt_prefix_ids(self) -> list[int]:
353
+ return (
354
+ self._encode_text(USER_TEMPLATE_SUFFIX)
355
+ + [int(self.model_config.im_end_token_id)]
356
+ + self._encode_text(ASSISTANT_TURN_PREFIX)
357
+ + [int(self.model_config.im_start_token_id)]
358
+ + self._encode_text(ASSISTANT_ROLE_PREFIX)
359
+ )
360
+
361
+ def _prompt_fields_from_user_message(self, message: Dict[str, Any]) -> dict[str, Any]:
362
+ fields = {}
363
+ for key in ("instruction", "tokens", "quality", "sound_event", "ambient_sound"):
364
+ if key in message and message.get(key) is not None:
365
+ fields[key] = message.get(key)
366
+ if "language" in message and message.get("language") is not None:
367
+ fields["language"] = message.get("language")
368
+ return fields
369
+
370
+ def _build_generation_or_voice_clone_codes(
371
+ self,
372
+ message: Dict[str, Any],
373
+ n_vq: int,
374
+ ) -> torch.Tensor:
375
+ if "text" not in message:
376
+ raise ValueError("Direct MOSS-TTS-Local-Transformer-v1.5 generation requires messages built by build_user_message(...).")
377
+ text = "" if message.get("text") is None else str(message.get("text"))
378
+ prompt_fields = self._prompt_fields_from_user_message(message)
379
+ language_code = message.get("language")
380
+ audio_codes_list = self._resolve_audio_items(message.get("audio_codes_list", []), n_vq)
381
+ text_token_ids = self._encode_text(text)
382
+
383
+ if audio_codes_list:
384
+ parts: list[torch.Tensor] = [self._build_text_rows(
385
+ self._user_prompt_prefix_ids(),
386
+ device=audio_codes_list[0].device,
387
+ )]
388
+ for reference_codes in audio_codes_list:
389
+ parts.append(self._build_text_rows([int(self.model_config.audio_start_token_id)], device=reference_codes.device))
390
+ parts.append(self._build_audio_rows(reference_codes, int(self.model_config.audio_user_slot_token_id)))
391
+ parts.append(self._build_text_rows([int(self.model_config.audio_end_token_id)], device=reference_codes.device))
392
+ parts.append(
393
+ self._build_text_rows(
394
+ self._user_prompt_after_reference_ids(language_code, prompt_fields)
395
+ + text_token_ids
396
+ + self._assistant_prompt_prefix_ids()
397
+ + [int(self.model_config.audio_start_token_id)],
398
+ device=audio_codes_list[0].device,
399
+ )
400
+ )
401
+ return torch.cat(parts, dim=0)
402
+
403
+ prompt_token_ids = (
404
+ self._user_prompt_prefix_ids()
405
+ + self._encode_text("None")
406
+ + self._user_prompt_after_reference_ids(language_code, prompt_fields)
407
+ + text_token_ids
408
+ + self._assistant_prompt_prefix_ids()
409
+ + [int(self.model_config.audio_start_token_id)]
410
+ )
411
+ return self._build_text_rows(prompt_token_ids)
412
+
413
+ def _build_continuation_codes(
414
+ self,
415
+ conversation: list[Dict[str, Any]],
416
+ n_vq: int,
417
+ ) -> torch.Tensor:
418
+ if len(conversation) < 2:
419
+ raise ValueError("continuation mode requires a user message followed by an assistant audio message.")
420
+ user_message = conversation[-2]
421
+ assistant_message = conversation[-1]
422
+ if user_message.get("role") != "user" or assistant_message.get("role") != "assistant":
423
+ raise ValueError("continuation mode requires the last two messages to be user, assistant.")
424
+ if "text" not in user_message:
425
+ raise ValueError("Direct MOSS-TTS-Local-Transformer-v1.5 continuation requires user messages built by build_user_message(...).")
426
+
427
+ text = "" if user_message.get("text") is None else str(user_message.get("text"))
428
+ prompt_fields = self._prompt_fields_from_user_message(user_message)
429
+ language_code = user_message.get("language")
430
+ prompt_token_ids = (
431
+ self._user_prompt_prefix_ids()
432
+ + self._encode_text("None")
433
+ + self._user_prompt_after_reference_ids(language_code, prompt_fields)
434
+ + self._encode_text(text)
435
+ + self._assistant_prompt_prefix_ids()
436
+ + [int(self.model_config.audio_start_token_id)]
437
+ )
438
+ audio_codes_list = self._resolve_audio_items(assistant_message.get("audio_codes_list", []), n_vq)
439
+ if not audio_codes_list:
440
+ return self._build_text_rows(prompt_token_ids)
441
+ if len(audio_codes_list) != 1:
442
+ raise ValueError("The MOSS-TTS-Local-Transformer-v1.5 continuation path expects exactly one prompt audio item.")
443
+ prompt_audio_codes = audio_codes_list[0]
444
+ return torch.cat(
445
+ [
446
+ self._build_text_rows(prompt_token_ids, device=prompt_audio_codes.device),
447
+ self._build_audio_rows(prompt_audio_codes, int(self.model_config.audio_assistant_slot_token_id)),
448
+ ],
449
+ dim=0,
450
+ )
451
+
452
+ def _try_build_direct_codes(
453
+ self,
454
+ conversation: list[Dict[str, Any]],
455
+ mode: str,
456
+ n_vq: int,
457
+ ) -> Optional[torch.Tensor]:
458
+ if mode == "generation" and len(conversation) == 1 and conversation[-1].get("role") == "user":
459
+ if "text" in conversation[-1]:
460
+ return self._build_generation_or_voice_clone_codes(conversation[-1], n_vq)
461
+ return None
462
+ if mode == "continuation" and len(conversation) >= 2:
463
+ if "text" in conversation[-2]:
464
+ return self._build_continuation_codes(conversation, n_vq)
465
+ return None
466
+ return None
467
+
468
+ def __call__(self, *args, **kwargs) -> BatchFeature:
469
+ conversations = args[0] if args else kwargs.pop("conversations")
470
+ mode: str = kwargs.pop("mode", "generation")
471
+ apply_chat_template: bool = kwargs.pop("apply_chat_template", True)
472
+ n_vq = self._assert_fixed_nq(kwargs.pop("n_vq", None))
473
+
474
+ kwargs.pop("return_tensors", None)
475
+ kwargs.pop("padding", None)
476
+ kwargs.pop("truncation", None)
477
+
478
+ if mode not in {"generation", "continuation", "computing_loss"}:
479
+ raise ValueError(f"Unsupported mode: {mode}")
480
+ if isinstance(conversations, (Message, dict)):
481
+ conversations = [conversations]
482
+ elif isinstance(conversations, list) and conversations and all(
483
+ isinstance(item, (Message, dict)) for item in conversations
484
+ ):
485
+ conversations = [conversations]
486
+
487
+ input_ids_list: list[torch.Tensor] = []
488
+ for conversation in conversations:
489
+ if isinstance(conversation, (Message, dict)):
490
+ conversation = [conversation]
491
+ conversation = [self._normalize_message(message) for message in conversation]
492
+
493
+ if (mode == "generation") ^ (conversation[-1]["role"] == "user"):
494
+ raise ValueError("generation mode must end with a user message; other modes must not end with a user message.")
495
+ if mode == "continuation" and conversation[-1]["role"] != "assistant":
496
+ raise ValueError("continuation mode must end with an assistant message.")
497
+
498
+ direct_codes = self._try_build_direct_codes(conversation, mode, n_vq)
499
+ if direct_codes is not None:
500
+ input_ids_list.append(direct_codes)
501
+ continue
502
+
503
+ unified_parts = []
504
+ for message_idx, message in enumerate(conversation):
505
+ content = str(message["content"])
506
+ if apply_chat_template:
507
+ add_generation_prompt = mode == "generation" and message_idx == len(conversation) - 1
508
+ try:
509
+ content = self.tokenizer.apply_chat_template(
510
+ [{"role": message["role"], "content": content}],
511
+ add_generation_prompt=add_generation_prompt,
512
+ tokenize=False,
513
+ )
514
+ except Exception:
515
+ logger.warning("apply_chat_template failed; falling back to raw message content.")
516
+
517
+ raw_audio_items = message.get("audio_codes_list", [])
518
+ audio_codes_list = self._resolve_audio_items(raw_audio_items, n_vq)
519
+ unified_parts.append(
520
+ self._get_unified_codes(
521
+ role=message["role"],
522
+ content=content,
523
+ audio_codes_list=audio_codes_list,
524
+ truncation=(mode == "continuation"),
525
+ )
526
+ )
527
+
528
+ unified_codes = torch.cat(unified_parts, dim=0)
529
+ if mode == "generation":
530
+ audio_start_row = torch.full(
531
+ (1, n_vq + 1),
532
+ int(self.model_config.audio_pad_token_id),
533
+ dtype=unified_codes.dtype,
534
+ device=unified_codes.device,
535
+ )
536
+ audio_start_row[:, 0] = int(self.model_config.audio_start_token_id)
537
+ unified_codes = torch.cat([unified_codes, audio_start_row], dim=0)
538
+ input_ids_list.append(unified_codes)
539
+
540
+ return BatchFeature(data=self._pad(input_ids_list))
541
+
542
+ def _normalize_message(self, message: Union[Message, Dict[str, Any]]) -> Dict[str, Any]:
543
+ if isinstance(message, Message):
544
+ return message.to_dict()
545
+ if not isinstance(message, dict):
546
+ raise TypeError("Each message must be a Message or dict.")
547
+ if "content" in message and "audio_codes_list" in message:
548
+ return message
549
+ role = message.get("role")
550
+ if role == "user":
551
+ return self.build_user_message(**{key: message.get(key) for key in USER_MESSAGE_FIELDS})
552
+ if role == "assistant":
553
+ return self.build_assistant_message(
554
+ audio_codes_list=message.get("audio_codes_list", []),
555
+ content=message.get("content", AUDIO_PLACEHOLDER),
556
+ )
557
+ raise ValueError(f"Unsupported role: {role}")
558
+
559
+ def _resolve_audio_items(
560
+ self,
561
+ raw_audio_items: list[Any],
562
+ n_vq: int,
563
+ ) -> list[torch.Tensor]:
564
+ if not raw_audio_items:
565
+ return []
566
+ resolved: list[Optional[torch.Tensor]] = [None] * len(raw_audio_items)
567
+ paths: list[str] = []
568
+ path_positions: list[int] = []
569
+ for index, item in enumerate(raw_audio_items):
570
+ if isinstance(item, torch.Tensor):
571
+ if item.ndim != 2 or int(item.shape[1]) != n_vq:
572
+ raise ValueError(f"audio code tensor must have shape [T, {n_vq}], got {tuple(item.shape)}.")
573
+ resolved[index] = item.to(dtype=torch.long).cpu()
574
+ elif isinstance(item, (str, os.PathLike)):
575
+ paths.append(str(item))
576
+ path_positions.append(index)
577
+ else:
578
+ raise TypeError("Audio items must be tensors or path-like values.")
579
+ if paths:
580
+ encoded = self.encode_audios_from_path(paths, n_vq=n_vq)
581
+ for position, codes in zip(path_positions, encoded):
582
+ resolved[position] = codes
583
+ return [cast(torch.Tensor, item) for item in resolved]
584
+
585
+ def _pad(self, input_ids_list: list[torch.Tensor]) -> Dict[str, torch.Tensor]:
586
+ device = input_ids_list[0].device
587
+ lengths = torch.tensor([item.shape[0] for item in input_ids_list], device=device)
588
+ padded = torch.nn.utils.rnn.pad_sequence(
589
+ input_ids_list,
590
+ batch_first=True,
591
+ padding_value=int(self.model_config.audio_pad_token_id),
592
+ padding_side="left",
593
+ )
594
+ left_pad_mask = (padded.shape[1] - lengths).unsqueeze(1) > torch.arange(
595
+ padded.shape[1],
596
+ device=device,
597
+ ).unsqueeze(0)
598
+ padded[..., 0][left_pad_mask] = int(self.model_config.pad_token_id)
599
+ attention_mask = torch.zeros(padded.shape[:2], dtype=torch.bool, device=device)
600
+ attention_mask[~left_pad_mask] = True
601
+ return {"input_ids": padded, "attention_mask": attention_mask}
602
+
603
+ @staticmethod
604
+ def _replace_audio_placeholders(
605
+ content: str,
606
+ lengths: list[int],
607
+ slot_token: str,
608
+ audio_start_token: str,
609
+ audio_end_token: str,
610
+ ) -> str:
611
+ placeholder_count = content.count(AUDIO_PLACEHOLDER)
612
+ if placeholder_count != len(lengths):
613
+ raise ValueError(
614
+ f"Number of {AUDIO_PLACEHOLDER} ({placeholder_count}) does not match "
615
+ f"audio item count ({len(lengths)})."
616
+ )
617
+ lengths_iter = iter(lengths)
618
+
619
+ def replacer(_: re.Match) -> str:
620
+ length = int(next(lengths_iter))
621
+ if length <= 0:
622
+ return f"{audio_start_token}{audio_end_token}"
623
+ return f"{audio_start_token}{slot_token * length}{audio_end_token}"
624
+
625
+ return re.sub(re.escape(AUDIO_PLACEHOLDER), replacer, content)
626
+
627
+ def _get_unified_codes(
628
+ self,
629
+ role: str,
630
+ content: str,
631
+ audio_codes_list: list[torch.Tensor],
632
+ truncation: bool,
633
+ ) -> torch.Tensor:
634
+ n_vq = int(self.model_config.n_vq)
635
+ slot_token = self.audio_user_slot_token if role == "user" else self.audio_assistant_slot_token
636
+ content = self._replace_audio_placeholders(
637
+ content=content,
638
+ lengths=[int(codes.shape[0]) for codes in audio_codes_list],
639
+ slot_token=slot_token,
640
+ audio_start_token=self.audio_start_token,
641
+ audio_end_token=self.audio_end_token,
642
+ )
643
+ text_codes = torch.tensor(
644
+ self.tokenizer.encode(content),
645
+ dtype=torch.long,
646
+ device=audio_codes_list[0].device if audio_codes_list else None,
647
+ )
648
+
649
+ audio_start_indices = torch.where(text_codes == int(self.model_config.audio_start_token_id))[0]
650
+ audio_end_indices = torch.where(text_codes == int(self.model_config.audio_end_token_id))[0]
651
+ if len(audio_start_indices) != len(audio_codes_list) or len(audio_end_indices) != len(audio_codes_list):
652
+ raise ValueError("Audio placeholders do not match the encoded audio spans.")
653
+
654
+ if not audio_codes_list:
655
+ audio_codes = torch.full(
656
+ (len(text_codes), n_vq),
657
+ int(self.model_config.audio_pad_token_id),
658
+ dtype=torch.long,
659
+ device=text_codes.device,
660
+ )
661
+ else:
662
+ pieces: list[torch.Tensor] = []
663
+ prefix_idx = 0
664
+ for start_t, end_t, codes in zip(audio_start_indices, audio_end_indices, audio_codes_list):
665
+ start_idx = int(start_t.item())
666
+ end_idx = int(end_t.item())
667
+ pad_before = torch.full(
668
+ (start_idx - prefix_idx + 1, n_vq),
669
+ int(self.model_config.audio_pad_token_id),
670
+ dtype=torch.long,
671
+ device=codes.device,
672
+ )
673
+ pieces.extend([pad_before, codes.to(dtype=torch.long)])
674
+ prefix_idx = end_idx
675
+ if truncation:
676
+ trailing = torch.zeros(
677
+ (0, n_vq),
678
+ dtype=torch.long,
679
+ device=audio_codes_list[0].device,
680
+ )
681
+ else:
682
+ last_end = int(audio_end_indices[-1].item())
683
+ trailing = torch.full(
684
+ (len(text_codes) - last_end, n_vq),
685
+ int(self.model_config.audio_pad_token_id),
686
+ dtype=torch.long,
687
+ device=audio_codes_list[0].device,
688
+ )
689
+ pieces.append(trailing)
690
+ audio_codes = torch.cat(pieces, dim=0)
691
+
692
+ if text_codes.shape[0] != audio_codes.shape[0]:
693
+ min_len = min(text_codes.shape[0], audio_codes.shape[0])
694
+ text_codes = text_codes[:min_len]
695
+ audio_codes = audio_codes[:min_len]
696
+ return torch.cat([text_codes.unsqueeze(1), audio_codes], dim=1)
697
+
698
+ def _parse_text_codes(self, start_length: int, text_codes: torch.LongTensor) -> str:
699
+ text = cast(str, self.tokenizer.decode(text_codes))
700
+ prefix = cast(str, self.tokenizer.decode(text_codes[:start_length]))
701
+ text = text[len(prefix):]
702
+ audio_pattern = re.compile(
703
+ rf"(?:{re.escape(self.audio_start_token)})?"
704
+ rf"(?:{re.escape(self.audio_assistant_slot_token)})*"
705
+ rf"{re.escape(self.audio_end_token)}"
706
+ )
707
+ return audio_pattern.sub(
708
+ lambda match: AUDIO_PLACEHOLDER if self.audio_assistant_slot_token in match.group(0) else "",
709
+ text,
710
+ )
711
+
712
+ def _parse_audio_codes(
713
+ self,
714
+ start_length: int,
715
+ audio_codes: torch.LongTensor,
716
+ *,
717
+ return_stereo: bool = True,
718
+ ) -> list[torch.Tensor]:
719
+ is_pad = audio_codes.eq(int(self.model_config.audio_pad_token_id)).all(dim=1)
720
+ non_pad = ~is_pad
721
+ if not bool(non_pad.any().item()):
722
+ return []
723
+ idx = torch.nonzero(non_pad).squeeze(1)
724
+ breaks = torch.where(idx[1:] != idx[:-1] + 1)[0] + 1
725
+ segment_indices = [idx] if breaks.numel() == 0 else list(torch.tensor_split(idx, breaks.cpu().tolist()))
726
+ code_segments = [audio_codes[segment] for segment in segment_indices]
727
+ decoded = self.decode_audio_codes(code_segments, return_stereo=return_stereo)
728
+
729
+ if start_length > 0 and code_segments and decoded:
730
+ first_code_length = int(code_segments[0].shape[0])
731
+ if first_code_length > 0:
732
+ trim_ratio = max(0.0, min(float(start_length) / float(first_code_length), 1.0))
733
+ if trim_ratio >= 1.0:
734
+ decoded = decoded[1:]
735
+ elif trim_ratio > 0.0:
736
+ trim_samples = int(decoded[0].shape[-1] * trim_ratio)
737
+ decoded[0] = decoded[0][..., trim_samples:]
738
+ return decoded
739
+
740
+ def decode(self, output: Any, *, return_stereo: bool = True) -> list[Optional[AssistantMessage]]:
741
+ generated_messages: list[Optional[AssistantMessage]] = []
742
+ for start_length, generation_ids in output:
743
+ content = self._parse_text_codes(int(start_length), generation_ids[:, 0])
744
+ audio_codes_list = self._parse_audio_codes(
745
+ int(start_length),
746
+ generation_ids[:, 1:],
747
+ return_stereo=return_stereo,
748
+ )
749
+ if content == "":
750
+ generated_messages.append(None)
751
+ else:
752
+ generated_messages.append(
753
+ AssistantMessage(
754
+ content=content,
755
+ audio_codes_list=cast(list[Union[str, torch.Tensor]], audio_codes_list),
756
+ )
757
+ )
758
+ return generated_messages
759
+
760
+ @staticmethod
761
+ def loudness_normalize(
762
+ wav: torch.Tensor,
763
+ target_dbfs: float = -20.0,
764
+ gain_range: tuple[float, float] = (-3.0, 3.0),
765
+ ) -> torch.Tensor:
766
+ wav = wav.to(torch.float32)
767
+ if wav.numel() == 0:
768
+ return wav
769
+ current_dbfs = 10.0 * torch.log10(torch.mean(wav**2) + 1e-9)
770
+ gain = max(gain_range[0], min(float(target_dbfs - current_dbfs), gain_range[1]))
771
+ return wav * (10.0 ** (gain / 20.0))
772
+
773
+ def _get_audio_tokenizer_device(self) -> torch.device:
774
+ audio_tokenizer = getattr(self, "audio_tokenizer", None)
775
+ if audio_tokenizer is None:
776
+ raise RuntimeError("audio_tokenizer is not set.")
777
+ try:
778
+ return next(audio_tokenizer.parameters()).device
779
+ except StopIteration:
780
+ return torch.device("cpu")
781
+
782
+ def encode_audios_from_wav(
783
+ self,
784
+ wav_list: Union[torch.Tensor, list[torch.Tensor]],
785
+ sampling_rate: int,
786
+ n_vq: Optional[int] = None,
787
+ ) -> list[torch.Tensor]:
788
+ n_vq = self._assert_fixed_nq(n_vq)
789
+ if self.audio_tokenizer is None:
790
+ raise RuntimeError("audio_tokenizer is not set.")
791
+ if isinstance(wav_list, torch.Tensor):
792
+ wav_list = [wav_list]
793
+ target_sr = int(self.model_config.sampling_rate)
794
+ device = self._get_audio_tokenizer_device()
795
+ prepared = []
796
+ for wav in wav_list:
797
+ if wav.ndim == 1:
798
+ wav = wav.unsqueeze(0)
799
+ if wav.shape[0] == 1:
800
+ wav = wav.repeat(2, 1)
801
+ elif wav.shape[0] > 2:
802
+ wav = wav[:2]
803
+ if int(sampling_rate) != target_sr:
804
+ wav = torchaudio.functional.resample(wav, int(sampling_rate), target_sr)
805
+ prepared.append(self.loudness_normalize(wav).to(device))
806
+
807
+ if hasattr(self.audio_tokenizer, "batch_encode"):
808
+ encoded = self.audio_tokenizer.batch_encode(prepared, num_quantizers=n_vq)
809
+ audio_codes = encoded.audio_codes
810
+ audio_lengths = encoded.audio_codes_lengths
811
+ else:
812
+ max_len = max(int(wav.shape[-1]) for wav in prepared)
813
+ input_values = torch.zeros(len(prepared), 2, max_len, dtype=torch.float32, device=device)
814
+ padding_mask = torch.zeros(len(prepared), max_len, dtype=torch.bool, device=device)
815
+ for index, wav in enumerate(prepared):
816
+ input_values[index, :, : wav.shape[-1]] = wav
817
+ padding_mask[index, : wav.shape[-1]] = True
818
+ encoded = self.audio_tokenizer.encode(
819
+ input_values,
820
+ padding_mask=padding_mask,
821
+ num_quantizers=n_vq,
822
+ return_dict=True,
823
+ )
824
+ audio_codes = encoded.audio_codes
825
+ audio_lengths = encoded.audio_codes_lengths
826
+
827
+ if audio_codes is None or audio_lengths is None:
828
+ raise RuntimeError("audio_tokenizer did not return audio_codes/audio_codes_lengths.")
829
+ result = []
830
+ for index in range(int(audio_codes.shape[1])):
831
+ length = int(audio_lengths[index].item())
832
+ result.append(audio_codes[:, index, :length].transpose(0, 1).contiguous().cpu().long())
833
+ return result
+
+    def encode_audios_from_path(
+        self,
+        wav_path_list: Union[str, os.PathLike, list[Union[str, os.PathLike]]],
+        n_vq: Optional[int] = None,
+    ) -> list[torch.Tensor]:
+        if isinstance(wav_path_list, (str, os.PathLike)):
+            wav_path_list = [wav_path_list]
+        wavs = []
+        target_sr = int(self.model_config.sampling_rate)
+        for wav_path in wav_path_list:
+            wav, sr = torchaudio.load(str(wav_path))
+            if int(sr) != target_sr:
+                wav = torchaudio.functional.resample(wav, int(sr), target_sr)
+            wavs.append(wav)
+        return self.encode_audios_from_wav(wavs, target_sr, n_vq=n_vq)
+
+    def decode_audio_codes(
+        self,
+        audio_tokens_list: Union[torch.Tensor, list[torch.Tensor]],
+        *,
+        return_stereo: bool = True,
+    ) -> list[torch.Tensor]:
+        if self.audio_tokenizer is None:
+            raise RuntimeError("audio_tokenizer is not set.")
+        if isinstance(audio_tokens_list, torch.Tensor):
+            audio_tokens_list = [audio_tokens_list]
+        if not audio_tokens_list:
+            return []
+
+        n_vq = int(self.model_config.n_vq)
+        device = self._get_audio_tokenizer_device()
+        codes_list = [
+            codes[:, :n_vq].transpose(0, 1).contiguous().to(device=device, dtype=torch.long)
+            for codes in audio_tokens_list
+        ]
+        max_len = max(int(codes.shape[1]) for codes in codes_list)
+        audio_codes = torch.zeros(n_vq, len(codes_list), max_len, device=device, dtype=torch.long)
+        padding_mask = torch.zeros(len(codes_list), max_len, device=device, dtype=torch.bool)
+        for index, codes in enumerate(codes_list):
+            length = int(codes.shape[1])
+            audio_codes[:, index, :length] = codes
+            padding_mask[index, :length] = True
+
+        decoded = self.audio_tokenizer.decode(
+            audio_codes,
+            padding_mask=padding_mask,
+            num_quantizers=n_vq,
+            return_dict=True,
+            chunk_duration=8,
+        )
+        audio = decoded.audio
+        audio_lengths = decoded.audio_lengths
+        if audio is None or audio_lengths is None:
+            raise RuntimeError("audio_tokenizer.decode did not return audio/audio_lengths.")
+        wavs = []
+        for index in range(int(audio.shape[0])):
+            length = int(audio_lengths[index].item())
+            wav = audio[index, :, :length].contiguous().cpu().to(torch.float32)
+            if not return_stereo:
+                if wav.shape[0] == 1:
+                    wav = wav.squeeze(0)
+                else:
+                    wav = wav.mean(dim=0)
+            wavs.append(wav)
+        return wavs
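The decode path above pads variable-length code sequences into one batch, marks the valid positions in a boolean mask, and trims each result back to its own length. A minimal pure-Python sketch of that pad, mask and trim pattern follows. The names `pad_batch` and `trim_batch` are hypothetical helpers for illustration only and are not part of the processor.

```python
from typing import List, Tuple


def pad_batch(seqs: List[List[int]], pad_value: int = 0) -> Tuple[List[List[int]], List[List[bool]]]:
    """Right-pad sequences to a common length and build a validity mask."""
    max_len = max((len(seq) for seq in seqs), default=0)
    padded = [seq + [pad_value] * (max_len - len(seq)) for seq in seqs]
    mask = [[position < len(seq) for position in range(max_len)] for seq in seqs]
    return padded, mask


def trim_batch(padded: List[List[int]], mask: List[List[bool]]) -> List[List[int]]:
    """Recover the original sequences by counting valid positions per row."""
    lengths = [sum(row) for row in mask]
    return [row[:length] for row, length in zip(padded, lengths)]


# Hypothetical example data: two code sequences of different lengths.
batch, mask = pad_batch([[5, 6, 7], [8]])
restored = trim_batch(batch, mask)
```

The same idea carries over to tensors: the mask tells the codec which frames are real, and the returned lengths are used to slice off the padding afterwards.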
processor_config.json ADDED
@@ -0,0 +1,7 @@
+{
+  "processor_class": "MossTTSLocalProcessor",
+  "audio_tokenizer_name_or_path": "OpenMOSS-Team/MOSS-Audio-Tokenizer-v2",
+  "auto_map": {
+    "AutoProcessor": "processing_moss_tts.MossTTSLocalProcessor"
+  }
+}
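The `auto_map` entry above has the form `module.Class`: it points `AutoProcessor` at a class defined in a Python file shipped with the repository, which is why loading such a checkpoint generally requires `trust_remote_code=True`. The sketch below only parses the config with the standard library to show how the entry splits. `split_auto_map_entry` is a hypothetical helper, not a Transformers API.

```python
import json
from typing import Tuple

# The processor config from this commit, embedded as a string for illustration.
PROCESSOR_CONFIG = """
{
  "processor_class": "MossTTSLocalProcessor",
  "audio_tokenizer_name_or_path": "OpenMOSS-Team/MOSS-Audio-Tokenizer-v2",
  "auto_map": {
    "AutoProcessor": "processing_moss_tts.MossTTSLocalProcessor"
  }
}
"""


def split_auto_map_entry(entry: str) -> Tuple[str, str]:
    """Split 'module.Class' into its module name and class name."""
    module_name, _, class_name = entry.rpartition(".")
    return module_name, class_name


config = json.loads(PROCESSOR_CONFIG)
module_name, class_name = split_auto_map_entry(config["auto_map"]["AutoProcessor"])
```

Here `module_name` names the file (`processing_moss_tts.py`) and `class_name` matches the `processor_class` field.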
qwen3_decoder.py ADDED
@@ -0,0 +1,582 @@
+# coding=utf-8
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from safetensors.torch import load_file
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import BaseModelOutputWithPast
+
+from .gpt2_decoder import PackedSequenceMetadata, MossTTSNanoGPT2Model
+
+try:
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
+    from flash_attn.bert_padding import pad_input, unpad_input
+
+    _FLASH_ATTN_AVAILABLE = True
+except Exception:
+    flash_attn_func = None
+    flash_attn_varlen_func = None
+    pad_input = None
+    unpad_input = None
+    _FLASH_ATTN_AVAILABLE = False
+
+
+class MossQwen3RMSNorm(nn.Module):
+    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+
+class MossQwen3RotaryEmbedding(nn.Module):
+    def __init__(self, config) -> None:
+        super().__init__()
+        head_dim = int(getattr(config, "head_dim", config.hidden_size // config.num_attention_heads))
+        rope_theta = getattr(config, "rope_theta", None)
+        if rope_theta is None:
+            rope_scaling = getattr(config, "rope_scaling", None)
+            if isinstance(rope_scaling, dict):
+                rope_theta = rope_scaling.get("rope_theta")
+        if rope_theta is None:
+            rope_theta = 1000000.0
+        rope_theta = float(rope_theta)
+        self.head_dim = head_dim
+        self.rope_theta = rope_theta
+        self.register_buffer("inv_freq", self._compute_inv_freq(), persistent=False)
+
+    def _compute_inv_freq(self, device: Optional[torch.device] = None) -> torch.Tensor:
+        return 1.0 / (
+            self.rope_theta ** (torch.arange(0, self.head_dim, 2, device=device, dtype=torch.float32) / self.head_dim)
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: torch.LongTensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        inv_freq = self._compute_inv_freq(device=hidden_states.device)
+        freqs = torch.einsum(
+            "bs,d->bsd",
+            position_ids.to(device=hidden_states.device, dtype=inv_freq.dtype),
+            inv_freq,
+        )
+        emb = torch.cat((freqs, freqs), dim=-1)
+        return emb.cos().to(dtype=hidden_states.dtype), emb.sin().to(dtype=hidden_states.dtype)
+
+
+def rotate_half(hidden_states: torch.Tensor) -> torch.Tensor:
+    first_half = hidden_states[..., : hidden_states.shape[-1] // 2]
+    second_half = hidden_states[..., hidden_states.shape[-1] // 2 :]
+    return torch.cat((-second_half, first_half), dim=-1)
+
+
+def apply_rotary_pos_emb(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    cos = cos.unsqueeze(-2)
+    sin = sin.unsqueeze(-2)
+    query = (query * cos) + (rotate_half(query) * sin)
+    key = (key * cos) + (rotate_half(key) * sin)
+    return query, key
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    if n_rep == 1:
+        return hidden_states
+    batch, seq_len, num_key_value_heads, head_dim = hidden_states.shape
+    hidden_states = hidden_states[:, :, :, None, :].expand(batch, seq_len, num_key_value_heads, n_rep, head_dim)
+    return hidden_states.reshape(batch, seq_len, num_key_value_heads * n_rep, head_dim)
+
+
+class MossQwen3MLP(nn.Module):
+    def __init__(self, config) -> None:
+        super().__init__()
+        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
+        self.act_fn = ACT2FN[getattr(config, "hidden_act", "silu")]
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        return self.down_proj(self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
+
+
+class MossQwen3Attention(nn.Module):
+    def __init__(self, config, layer_idx: int) -> None:
+        super().__init__()
+        self.config = config
+        self.layer_idx = int(layer_idx)
+        self.hidden_size = int(config.hidden_size)
+        self.num_heads = int(config.num_attention_heads)
+        self.num_key_value_heads = int(config.num_key_value_heads)
+        self.head_dim = int(getattr(config, "head_dim", self.hidden_size // self.num_heads))
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        self.scaling = self.head_dim ** -0.5
+        self.attention_dropout = float(getattr(config, "attention_dropout", 0.0))
+        self.attn_implementation = str(getattr(config, "_attn_implementation", "eager"))
+
+        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=bool(config.attention_bias))
+        self.k_proj = nn.Linear(
+            self.hidden_size,
+            self.num_key_value_heads * self.head_dim,
+            bias=bool(config.attention_bias),
+        )
+        self.v_proj = nn.Linear(
+            self.hidden_size,
+            self.num_key_value_heads * self.head_dim,
+            bias=bool(config.attention_bias),
+        )
+        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=bool(config.attention_bias))
+        self.q_norm = MossQwen3RMSNorm(self.head_dim, eps=float(config.rms_norm_eps))
+        self.k_norm = MossQwen3RMSNorm(self.head_dim, eps=float(config.rms_norm_eps))
+
+    def _causal_attention_mask(
+        self,
+        attention_mask: Optional[torch.Tensor],
+        query_length: int,
+        key_length: int,
+        device: torch.device,
+    ) -> torch.Tensor:
+        query_positions = torch.arange(query_length, device=device, dtype=torch.long)
+        query_positions = query_positions + max(key_length - query_length, 0)
+        key_positions = torch.arange(key_length, device=device, dtype=torch.long)
+        causal = key_positions.unsqueeze(0) <= query_positions.unsqueeze(1)
+        causal = causal.unsqueeze(0).unsqueeze(0)
+        if attention_mask is None:
+            return causal
+        key_mask = attention_mask[:, None, None, :].to(dtype=torch.bool)
+        return causal & key_mask
+
+    def _eager_attention(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        key = repeat_kv(key, self.num_key_value_groups)
+        value = repeat_kv(value, self.num_key_value_groups)
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+        scores = torch.matmul(query, key.transpose(-1, -2)) * self.scaling
+        mask = self._causal_attention_mask(
+            attention_mask=attention_mask,
+            query_length=query.shape[-2],
+            key_length=key.shape[-2],
+            device=query.device,
+        )
+        scores = scores.masked_fill(~mask, torch.finfo(scores.dtype).min)
+        probs = torch.softmax(scores, dim=-1)
+        if self.training and self.attention_dropout > 0:
+            probs = torch.dropout(probs, self.attention_dropout, train=True)
+        output = torch.matmul(probs, value)
+        return output.transpose(1, 2).contiguous()
+
+    def _sdpa_attention(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        key = repeat_kv(key, self.num_key_value_groups)
+        value = repeat_kv(value, self.num_key_value_groups)
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+        mask = None
+        if attention_mask is not None or query.shape[-2] != key.shape[-2]:
+            mask = self._causal_attention_mask(
+                attention_mask=attention_mask,
+                query_length=query.shape[-2],
+                key_length=key.shape[-2],
+                device=query.device,
+            )
+        output = torch.nn.functional.scaled_dot_product_attention(
+            query,
+            key,
+            value,
+            attn_mask=mask,
+            dropout_p=self.attention_dropout if self.training else 0.0,
+            is_causal=mask is None,
+            scale=self.scaling,
+        )
+        return output.transpose(1, 2).contiguous()
+
+    def _flash_attention(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+        packed_metadata: Optional[PackedSequenceMetadata],
+    ) -> torch.Tensor:
+        if not _FLASH_ATTN_AVAILABLE:
+            raise ImportError("flash_attn is not installed, but attn_implementation='flash_attention_2' was requested.")
+        if query.device.type != "cuda":
+            raise ValueError("flash_attention_2 requires CUDA tensors.")
+        if query.dtype not in (torch.float16, torch.bfloat16):
+            raise ValueError(f"flash_attention_2 requires fp16/bf16 tensors, got dtype={query.dtype}.")
+
+        dropout_p = self.attention_dropout if self.training else 0.0
+        if packed_metadata is not None:
+            if packed_metadata.indices is not None:
+                query = query.reshape(-1, self.num_heads, self.head_dim).index_select(0, packed_metadata.indices)
+                key = key.reshape(-1, self.num_key_value_heads, self.head_dim).index_select(0, packed_metadata.indices)
+                value = value.reshape(-1, self.num_key_value_heads, self.head_dim).index_select(0, packed_metadata.indices)
+            output = flash_attn_varlen_func(
+                query,
+                key,
+                value,
+                packed_metadata.cu_seqlens,
+                packed_metadata.cu_seqlens,
+                packed_metadata.max_seqlen,
+                packed_metadata.max_seqlen,
+                dropout_p=dropout_p,
+                causal=True,
+            )
+            if packed_metadata.indices is None:
+                return output
+            return pad_input(
+                output,
+                packed_metadata.indices,
+                packed_metadata.batch_size,
+                packed_metadata.seq_len,
+            )
+
+        if attention_mask is None or bool(attention_mask.all()):
+            return flash_attn_func(query, key, value, dropout_p=dropout_p, causal=True)
+
+        if query.shape[1] != key.shape[1]:
+            query_attention_mask = attention_mask[:, -query.shape[1] :]
+            unpadded_query, query_indices, cu_seqlens_q, max_seqlen_q, _ = unpad_input(
+                query,
+                query_attention_mask,
+            )
+            unpadded_key, _, cu_seqlens_k, max_seqlen_k, _ = unpad_input(key, attention_mask)
+            unpadded_value, _, _, _, _ = unpad_input(value, attention_mask)
+            output = flash_attn_varlen_func(
+                unpadded_query,
+                unpadded_key,
+                unpadded_value,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                max_seqlen_q,
+                max_seqlen_k,
+                dropout_p=dropout_p,
+                causal=True,
+            )
+            return pad_input(output, query_indices, query.shape[0], query.shape[1])
+
+        unpadded_query, indices, cu_seqlens, max_seqlen, _ = unpad_input(query, attention_mask)
+        unpadded_key, _, _, _, _ = unpad_input(key, attention_mask)
+        unpadded_value, _, _, _, _ = unpad_input(value, attention_mask)
+        output = flash_attn_varlen_func(
+            unpadded_query,
+            unpadded_key,
+            unpadded_value,
+            cu_seqlens,
+            cu_seqlens,
+            max_seqlen,
+            max_seqlen,
+            dropout_p=dropout_p,
+            causal=True,
+        )
+        return pad_input(output, indices, query.shape[0], query.shape[1])
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor] = None,
+        packed_metadata: Optional[PackedSequenceMetadata] = None,
+        layer_past: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        use_cache: bool = False,
+    ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
+        input_shape = hidden_states.shape[:-1]
+        query_states = self.q_norm(
+            self.q_proj(hidden_states).view(*input_shape, self.num_heads, self.head_dim)
+        )
+        key_states = self.k_norm(
+            self.k_proj(hidden_states).view(*input_shape, self.num_key_value_heads, self.head_dim)
+        )
+        value_states = self.v_proj(hidden_states).view(*input_shape, self.num_key_value_heads, self.head_dim)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if layer_past is not None:
+            past_key, past_value = layer_past
+            key_states = torch.cat([past_key.to(device=key_states.device, dtype=key_states.dtype), key_states], dim=1)
+            value_states = torch.cat(
+                [past_value.to(device=value_states.device, dtype=value_states.dtype), value_states],
+                dim=1,
+            )
+
+        present = (key_states, value_states) if use_cache else None
+        if self.attn_implementation == "flash_attention_2":
+            attn_output = self._flash_attention(
+                query=query_states,
+                key=key_states,
+                value=value_states,
+                attention_mask=attention_mask,
+                packed_metadata=packed_metadata,
+            )
+        elif self.attn_implementation == "sdpa":
+            attn_output = self._sdpa_attention(
+                query=query_states,
+                key=key_states,
+                value=value_states,
+                attention_mask=attention_mask,
+            )
+        else:
+            attn_output = self._eager_attention(
+                query=query_states,
+                key=key_states,
+                value=value_states,
+                attention_mask=attention_mask,
+            )
+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
+        return self.o_proj(attn_output), present
+
+
+class MossQwen3DecoderLayer(nn.Module):
+    def __init__(self, config, layer_idx: int) -> None:
+        super().__init__()
+        self.self_attn = MossQwen3Attention(config=config, layer_idx=layer_idx)
+        self.mlp = MossQwen3MLP(config)
+        self.input_layernorm = MossQwen3RMSNorm(config.hidden_size, eps=float(config.rms_norm_eps))
+        self.post_attention_layernorm = MossQwen3RMSNorm(config.hidden_size, eps=float(config.rms_norm_eps))
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        packed_metadata: Optional[PackedSequenceMetadata] = None,
+        layer_past: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        use_cache: bool = False,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        attn_output, present = self.self_attn(
+            hidden_states=hidden_states,
+            position_embeddings=position_embeddings,
+            attention_mask=attention_mask,
+            packed_metadata=packed_metadata,
+            layer_past=layer_past,
+            use_cache=use_cache,
+        )
+        hidden_states = residual + attn_output
+
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = residual + self.mlp(hidden_states)
+        return hidden_states, present
+
+
+class MossQwen3Model(nn.Module):
+    def __init__(self, config) -> None:
+        super().__init__()
+        self.config = config
+        self.attn_implementation = str(getattr(config, "_attn_implementation", "eager"))
+        self.padding_idx = getattr(config, "pad_token_id", None)
+        self.vocab_size = int(config.vocab_size)
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.layers = nn.ModuleList(
+            [MossQwen3DecoderLayer(config, layer_idx=index) for index in range(config.num_hidden_layers)]
+        )
+        self.norm = MossQwen3RMSNorm(config.hidden_size, eps=float(config.rms_norm_eps))
+        self.rotary_emb = MossQwen3RotaryEmbedding(config)
+        self.gradient_checkpointing = False
+        self.gradient_checkpointing_use_reentrant = bool(
+            getattr(config, "gradient_checkpointing_use_reentrant", False)
+        )
+        self._reset_parameters()
+
+    def _reset_parameters(self) -> None:
+        init_std = float(getattr(self.config, "initializer_range", 0.02))
+        for module in self.modules():
+            if isinstance(module, nn.Linear):
+                nn.init.normal_(module.weight, mean=0.0, std=init_std)
+                if module.bias is not None:
+                    nn.init.zeros_(module.bias)
+            elif isinstance(module, nn.Embedding):
+                nn.init.normal_(module.weight, mean=0.0, std=init_std)
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.embed_tokens = value
+
+    def load_qwen3_pretrained_weights(self, pretrained_path: str) -> None:
+        model_dir = Path(pretrained_path)
+        index_path = model_dir / "model.safetensors.index.json"
+        if not index_path.exists():
+            raise FileNotFoundError(f"Missing Qwen3 safetensors index: {index_path}")
+        with index_path.open("r", encoding="utf-8") as handle:
+            index = json.load(handle)
+        weight_map = index.get("weight_map", {})
+        shard_to_keys: dict[str, list[str]] = {}
+        for key, shard in weight_map.items():
+            if not key.startswith("model."):
+                continue
+            shard_to_keys.setdefault(str(shard), []).append(key)
+
+        state_dict = self.state_dict()
+        loaded_state = {}
+        for shard, keys in sorted(shard_to_keys.items()):
+            shard_tensors = load_file(str(model_dir / shard), device="cpu")
+            for key in keys:
+                target_key = key[len("model.") :]
+                if target_key not in state_dict:
+                    continue
+                tensor = shard_tensors[key]
+                if tuple(tensor.shape) != tuple(state_dict[target_key].shape):
+                    raise ValueError(
+                        f"Shape mismatch while loading Qwen3 weight {key}: "
+                        f"checkpoint={tuple(tensor.shape)} model={tuple(state_dict[target_key].shape)}"
+                    )
+                loaded_state[target_key] = tensor
+
+        missing, unexpected = self.load_state_dict(loaded_state, strict=False)
+        unexpected = [key for key in unexpected if key]
+        if unexpected:
+            raise RuntimeError(f"Unexpected Qwen3 pretrained keys after load: {unexpected[:10]}")
+        missing = [key for key in missing if key not in loaded_state]
+        if missing:
+            raise RuntimeError(f"Missing Qwen3 pretrained keys after load: {missing[:10]}")
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[tuple[tuple[torch.Tensor, torch.Tensor], ...]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: bool = True,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        num_sequences: Optional[torch.Tensor] = None,
+    ) -> BaseModelOutputWithPast:
+        del input_ids, output_attentions
+        if inputs_embeds is None:
+            raise ValueError("inputs_embeds must be provided.")
+
+        use_cache = bool(use_cache)
+        if use_cache and cu_seqlens is not None:
+            raise ValueError("use_cache=True is not supported together with cu_seqlens packing.")
+
+        hidden_states = inputs_embeds
+        # Number of key positions already held in the KV cache, if one was passed in.
+        past_length = 0
+        if past_key_values is not None and len(past_key_values) > 0:
+            past_length = int(past_key_values[0][0].shape[1])
+        if attention_mask is None:
+            # The default mask must cover the cached positions too; otherwise position ids
+            # restart at 0 on every cached decoding step and the key mask is too short.
+            attention_mask = torch.ones(
+                (hidden_states.shape[0], past_length + hidden_states.shape[1]),
+                dtype=torch.bool,
+                device=hidden_states.device,
+            )
+        else:
+            attention_mask = attention_mask.to(dtype=torch.bool, device=hidden_states.device)
+        query_attention_mask = attention_mask[:, -hidden_states.shape[1] :]
+
+        packed_metadata = None
+        if position_ids is None:
+            if cu_seqlens is not None:
+                position_ids = MossTTSNanoGPT2Model.build_packed_position_ids(
+                    attention_mask=attention_mask,
+                    cu_seqlens=cu_seqlens.to(device=hidden_states.device),
+                    num_sequences=num_sequences.to(device=hidden_states.device) if num_sequences is not None else None,
+                    sequence_length=hidden_states.shape[1],
+                )
+            else:
+                # attention_mask is always set by this point, so positions come from it:
+                # count valid tokens, zero out padding, keep the trailing query positions.
+                position_ids = attention_mask.long().cumsum(dim=-1) - 1
+                position_ids = position_ids.masked_fill(~attention_mask, 0)
+                position_ids = position_ids[:, -hidden_states.shape[1] :]
+
+        if cu_seqlens is not None and self.attn_implementation == "flash_attention_2":
+            packed_metadata = MossTTSNanoGPT2Model.build_packed_metadata(
+                hidden_states=hidden_states,
+                cu_seqlens=cu_seqlens.to(device=hidden_states.device),
+                num_sequences=num_sequences.to(device=hidden_states.device) if num_sequences is not None else None,
+            )
+
+        hidden_states = hidden_states * query_attention_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        all_hidden_states = () if output_hidden_states else None
+        presents = [] if use_cache else None
+        for layer_index, decoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                if use_cache:
+                    raise ValueError("use_cache=True is not supported when gradient checkpointing is enabled during training.")
+
+                # Bind the current layer as a default argument. Checkpointing re-runs this
+                # function during backward, after the loop variable has moved on to the last layer.
+                def custom_forward(*inputs, layer=decoder_layer):
+                    output, _ = layer(
+                        hidden_states=inputs[0],
+                        attention_mask=inputs[1],
+                        packed_metadata=packed_metadata,
+                        layer_past=None,
+                        use_cache=False,
+                        position_embeddings=position_embeddings,
+                    )
+                    return output
+
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    custom_forward,
+                    hidden_states,
+                    attention_mask,
+                    use_reentrant=self.gradient_checkpointing_use_reentrant,
+                )
+                present = None
+            else:
+                hidden_states, present = decoder_layer(
+                    hidden_states=hidden_states,
+                    attention_mask=attention_mask,
+                    packed_metadata=packed_metadata,
+                    layer_past=None if past_key_values is None else past_key_values[layer_index],
+                    use_cache=use_cache,
+                    position_embeddings=position_embeddings,
+                )
+
+            hidden_states = hidden_states * query_attention_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
+            if presents is not None:
+                presents.append(present)
+
+        hidden_states = self.norm(hidden_states)
+        hidden_states = hidden_states * query_attention_mask.unsqueeze(-1).to(dtype=hidden_states.dtype)
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return (hidden_states, tuple(presents) if presents is not None else None, all_hidden_states, None)
+
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=tuple(presents) if presents is not None else None,
+            hidden_states=all_hidden_states,
+            attentions=None,
+        )
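The decoder above derives position ids from the attention mask and offsets the causal mask by the number of cached keys. The pure-Python sketch below mirrors that arithmetic on plain lists so it can be checked by hand. `position_ids_from_mask` and `causal_mask` are hypothetical names for illustration, and the sketch assumes a single sequence with `query_len >= 1`.

```python
from typing import List


def position_ids_from_mask(mask: List[bool], query_len: int) -> List[int]:
    """Equivalent of cumsum(mask) - 1 with padding set to 0, keeping the last query_len entries."""
    positions: List[int] = []
    seen = 0  # number of valid tokens encountered so far
    for valid in mask:
        if valid:
            positions.append(seen)
            seen += 1
        else:
            positions.append(0)
    return positions[-query_len:]


def causal_mask(query_len: int, key_len: int, key_valid: List[bool]) -> List[List[bool]]:
    """Query i may attend to key k if k <= i + offset and key k is not padding."""
    offset = max(key_len - query_len, 0)  # cached keys sit before the new queries
    return [
        [k <= q + offset and key_valid[k] for k in range(key_len)]
        for q in range(query_len)
    ]


# One left-padded sequence: one padding slot, three real tokens, decoding the last one.
step_positions = position_ids_from_mask([False, True, True, True], query_len=1)
step_mask = causal_mask(query_len=1, key_len=4, key_valid=[False, True, True, True])
prefill_mask = causal_mask(query_len=2, key_len=2, key_valid=[True, True])
```

With a cache, the single new query is placed at the last key position, so it sees every valid key while padding stays masked out.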
special_tokens_map.json ADDED
@@ -0,0 +1,21 @@
+{
+  "additional_special_tokens": [
+    "<|audio_start|>",
+    "<|audio_end|>",
+    "<|audio_pad|>"
+  ],
+  "eos_token": {
+    "content": "<|im_end|>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "pad_token": {
+    "content": "<|endoftext|>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  }
+}
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:06902d1fb775216338802205886a24bc715ccc606bd872a892e3d3c83ca1b9e2
+size 11423220
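The three lines added for `tokenizer.json` are a Git LFS pointer, not the tokenizer itself: the repository stores only the spec version, the SHA-256 of the real file, and its size in bytes. A minimal sketch of reading such a pointer with the standard library follows. `parse_lfs_pointer` is a hypothetical helper written for illustration, not part of Git LFS or `huggingface_hub`.

```python
from typing import Dict, Union

# The pointer content from this commit, embedded as a string for illustration.
POINTER_TEXT = """version https://git-lfs.github.com/spec/v1
oid sha256:06902d1fb775216338802205886a24bc715ccc606bd872a892e3d3c83ca1b9e2
size 11423220
"""


def parse_lfs_pointer(text: str) -> Dict[str, Union[str, int]]:
    """Parse the 'key value' lines of a Git LFS pointer file."""
    fields = dict(line.split(" ", 1) for line in text.splitlines() if line.strip())
    hash_algo, _, digest = fields["oid"].partition(":")
    return {
        "version": fields["version"],
        "hash_algo": hash_algo,
        "digest": digest,
        "size": int(fields["size"]),
    }


pointer = parse_lfs_pointer(POINTER_TEXT)
```

The `size` field shows the actual tokenizer file is about 11 MB, which is why the commit also adds a `tokenizer.json` rule to `.gitattributes`.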
tokenizer_config.json ADDED
@@ -0,0 +1,253 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
+ "151655": {
+ "content": "<|image_pad|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151656": {
+ "content": "<|video_pad|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151657": {
+ "content": "<tool_call>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151658": {
+ "content": "</tool_call>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151659": {
+ "content": "<|fim_prefix|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151660": {
+ "content": "<|fim_middle|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151661": {
+ "content": "<|fim_suffix|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151662": {
+ "content": "<|fim_pad|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151663": {
+ "content": "<|repo_name|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151664": {
+ "content": "<|file_sep|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151665": {
+ "content": "<tool_response>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151666": {
+ "content": "</tool_response>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151667": {
+ "content": "<think>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151668": {
+ "content": "</think>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": false
+ },
+ "151669": {
+ "content": "<|audio_start|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151670": {
+ "content": "<|audio_end|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151671": {
+ "content": "<|audio_pad|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ }
+ },
+ "additional_special_tokens": [
+ "<|audio_start|>",
+ "<|audio_end|>",
+ "<|audio_pad|>"
+ ],
+ "bos_token": null,
+ "clean_up_tokenization_spaces": false,
+ "eos_token": "<|im_end|>",
+ "errors": "replace",
+ "extra_special_tokens": {},
+ "model_max_length": 131072,
+ "pad_token": "<|endoftext|>",
+ "split_special_tokens": false,
+ "tokenizer_class": "Qwen2Tokenizer",
+ "unk_token": null
+ }
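
In the configuration above, the `special` flag differs between entries: the audio tokens are marked `true`, while `<tool_call>`, `<think>`, and the `fim_*` tokens are marked `false`. A minimal sketch of how one might separate the two groups when inspecting such a file, using only the standard library. The embedded fragment reproduces just three entries from the diff, and the helper name `split_by_special` is hypothetical, not part of any library:

```python
import json

# Three entries copied from the tokenizer_config.json diff above.
# The real file contains many more; this fragment is for illustration only.
CONFIG_FRAGMENT = """
{
  "added_tokens_decoder": {
    "151657": {"content": "<tool_call>", "special": false},
    "151667": {"content": "<think>", "special": false},
    "151669": {"content": "<|audio_start|>", "special": true}
  },
  "eos_token": "<|im_end|>",
  "pad_token": "<|endoftext|>"
}
"""


def split_by_special(config: dict) -> tuple[list[str], list[str]]:
    """Return (special, non_special) token strings, ordered by token id.

    Hypothetical helper: it only reads the `special` flag of each entry
    in `added_tokens_decoder`.
    """
    special, regular = [], []
    entries = sorted(config["added_tokens_decoder"].items(), key=lambda kv: int(kv[0]))
    for _token_id, entry in entries:
        (special if entry["special"] else regular).append(entry["content"])
    return special, regular


config = json.loads(CONFIG_FRAGMENT)
special_tokens, regular_tokens = split_by_special(config)
print("special:", special_tokens)
print("non-special:", regular_tokens)
```

To run the same check on the full file, replace `CONFIG_FRAGMENT` with the contents of a locally downloaded `tokenizer_config.json`.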
vocab.json ADDED
The diff for this file is too large to render. See raw diff