MOSS-TTS-Local-Transformer-v1.5 / configuration_moss_tts.py
fdugyt's picture
Add files using upload-large-folder tool
2b6d647 verified
Raw
History Blame Contribute Delete
7.16 kB
# coding=utf-8
"""Configuration for the MOSS-TTS-Local-Transformer-v1.5 release."""
from __future__ import annotations
from typing import Any, Dict, Optional, Union
from transformers.configuration_utils import PretrainedConfig
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
SUPPORTED_ATTENTION_IMPLEMENTATIONS = {"flash_attention_2", "sdpa", "eager"}
def _normalize_attention_implementation(value: Optional[str], default: str = "flash_attention_2") -> str:
normalized = str(value or default).strip().lower()
if normalized in {"flash", "flash_attn", "flash-attn", "flash_attention"}:
normalized = "flash_attention_2"
if normalized not in SUPPORTED_ATTENTION_IMPLEMENTATIONS:
raise ValueError(
"attn_implementation must be one of "
f"{sorted(SUPPORTED_ATTENTION_IMPLEMENTATIONS)}, got {value!r}."
)
return normalized
class MossTTSLocalConfig(PretrainedConfig):
model_type = "moss_tts_local"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
qwen3_config: Optional[Union[Qwen3Config, Dict[str, Any]]] = None,
gpt2_config: Optional[Union[GPT2Config, Dict[str, Any]]] = None,
language_config: Optional[Union[Qwen3Config, Dict[str, Any]]] = None,
n_vq: int = 12,
audio_vocab_size: int = 1024,
audio_codebook_sizes: Optional[list[int]] = None,
audio_pad_token_id: int = 1024,
audio_pad_code: Optional[int] = None,
pad_token_id: int = 151643,
im_start_token_id: int = 151644,
im_end_token_id: int = 151645,
audio_start_token_id: int = 151669,
audio_end_token_id: int = 151670,
audio_user_slot_token_id: int = 151654,
audio_assistant_slot_token_id: int = 151656,
audio_assistant_gen_slot_token_id: Optional[int] = None,
sampling_rate: int = 48000,
audio_tokenizer_name_or_path: Optional[str] = None,
attn_implementation: str = "flash_attention_2",
local_transformer_attn_implementation: Optional[str] = None,
local_text_head_mode: str = "binary",
initializer_range: float = 0.02,
**kwargs: Any,
) -> None:
if qwen3_config is None and language_config is not None:
qwen3_config = language_config
if isinstance(qwen3_config, dict):
self.qwen3_config = Qwen3Config(**qwen3_config)
elif qwen3_config is None:
self.qwen3_config = Qwen3Config()
else:
self.qwen3_config = qwen3_config
if isinstance(gpt2_config, dict):
self.gpt2_config = GPT2Config(**gpt2_config)
elif gpt2_config is None:
self.gpt2_config = GPT2Config(
vocab_size=int(self.qwen3_config.vocab_size),
n_embd=int(self.qwen3_config.hidden_size),
n_layer=1,
n_head=max(1, int(self.qwen3_config.hidden_size) // 80),
n_positions=int(n_vq) + 1,
n_ctx=int(n_vq) + 1,
activation_function="silu",
layer_norm_epsilon=1e-6,
resid_pdrop=0.0,
embd_pdrop=0.0,
attn_pdrop=0.0,
)
else:
self.gpt2_config = gpt2_config
self.n_vq = int(n_vq)
if self.n_vq <= 0:
raise ValueError("n_vq must be positive.")
if audio_codebook_sizes is None:
self.audio_codebook_sizes = [int(audio_vocab_size)] * self.n_vq
else:
self.audio_codebook_sizes = [int(size) for size in audio_codebook_sizes]
if len(self.audio_codebook_sizes) != self.n_vq:
raise ValueError(
f"audio_codebook_sizes must have length n_vq={self.n_vq}, "
f"got {len(self.audio_codebook_sizes)}."
)
if any(size <= 0 for size in self.audio_codebook_sizes):
raise ValueError("audio_codebook_sizes must contain positive integers.")
self.audio_vocab_size = int(max(int(audio_vocab_size), max(self.audio_codebook_sizes)))
self.audio_pad_token_id = int(audio_pad_code if audio_pad_code is not None else audio_pad_token_id)
self.audio_pad_code = self.audio_pad_token_id
if self.audio_pad_token_id < self.audio_vocab_size:
raise ValueError("audio_pad_token_id/audio_pad_code must be outside the audio vocab.")
self.pad_token_id = int(pad_token_id)
self.im_start_token_id = int(im_start_token_id)
self.im_end_token_id = int(im_end_token_id)
self.audio_start_token_id = int(audio_start_token_id)
self.audio_end_token_id = int(audio_end_token_id)
self.audio_user_slot_token_id = int(audio_user_slot_token_id)
self.audio_assistant_slot_token_id = int(
audio_assistant_slot_token_id
if audio_assistant_gen_slot_token_id is None
else audio_assistant_gen_slot_token_id
)
self.audio_assistant_gen_slot_token_id = self.audio_assistant_slot_token_id
self.sampling_rate = int(sampling_rate)
self.audio_tokenizer_name_or_path = audio_tokenizer_name_or_path
self.attn_implementation = _normalize_attention_implementation(attn_implementation)
self.local_transformer_attn_implementation = _normalize_attention_implementation(
local_transformer_attn_implementation,
default=self.attn_implementation,
)
self.initializer_range = float(initializer_range)
self.hidden_size = int(self.qwen3_config.hidden_size)
self.vocab_size = int(self.qwen3_config.vocab_size)
self.local_hidden_size = int(self.gpt2_config.hidden_size)
if self.local_hidden_size != self.hidden_size:
raise ValueError(
"This MOSS-TTS-Local-Transformer-v1.5 release expects local hidden size to "
"match Qwen3 hidden size so audio embeddings and heads are tied."
)
normalized_text_head_mode = str(local_text_head_mode or "full_vocab").strip().lower()
if normalized_text_head_mode in {"full", "full-vocab", "vocab"}:
normalized_text_head_mode = "full_vocab"
if normalized_text_head_mode not in {"full_vocab", "binary"}:
raise ValueError("local_text_head_mode must be 'full_vocab' or 'binary'.")
self.local_text_head_mode = normalized_text_head_mode
kwargs.setdefault("tie_word_embeddings", True)
super().__init__(pad_token_id=self.pad_token_id, **kwargs)
@property
def language_config(self) -> Qwen3Config:
return self.qwen3_config
def to_dict(self) -> Dict[str, Any]:
output = super().to_dict()
output["qwen3_config"] = self.qwen3_config.to_dict()
output["language_config"] = self.qwen3_config.to_dict()
output["gpt2_config"] = self.gpt2_config.to_dict()
output["audio_pad_code"] = self.audio_pad_token_id
output["audio_assistant_gen_slot_token_id"] = self.audio_assistant_slot_token_id
return output