MOSS-TTS-Local-Transformer-v1.5 / modeling_moss_tts.py
fdugyt's picture
Add files using upload-large-folder tool
2b6d647 verified
Raw
History Blame Contribute Delete
26.4 kB
# coding=utf-8
"""Modeling code for the MOSS-TTS-Local-Transformer-v1.5 HuggingFace release."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Optional, Union
import torch
import torch.nn as nn
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.utils import ModelOutput
from .configuration_moss_tts import MossTTSLocalConfig
from .gpt2_decoder import MossTTSNanoGPT2Model
from .qwen3_decoder import MossQwen3Model
@dataclass
class MossTTSLocalOutput(ModelOutput):
last_hidden_state: Optional[torch.FloatTensor] = None
past_key_values: Optional[tuple[tuple[torch.Tensor, torch.Tensor], ...]] = None
hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
attentions: Optional[tuple[torch.FloatTensor, ...]] = None
def _find_last_equal(input_ids: torch.LongTensor, value: int) -> torch.LongTensor:
matches = input_ids.eq(int(value))
if not bool(matches.any(dim=1).all().item()):
raise ValueError(f"Every sample must contain token id {int(value)}.")
positions = torch.arange(input_ids.shape[1], device=input_ids.device, dtype=torch.long)
masked_positions = positions.unsqueeze(0).masked_fill(~matches, -1)
return masked_positions.max(dim=1).values
class MossTTSLocalPreTrainedModel(PreTrainedModel):
config_class = MossTTSLocalConfig
base_model_prefix = "transformer"
supports_gradient_checkpointing = True
_no_split_modules = ["MossTTSNanoGPT2Block", "MossQwen3DecoderLayer"]
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False) -> None:
if isinstance(module, MossTTSNanoGPT2Model) or isinstance(module, MossQwen3Model):
module.gradient_checkpointing = value
class MossTTSLocalModel(MossTTSLocalPreTrainedModel):
_tied_weights_keys = None
def __init__(self, config: MossTTSLocalConfig) -> None:
super().__init__(config)
self._tied_weights_keys = self._build_tied_weights_keys(config)
config.qwen3_config.pad_token_id = config.pad_token_id
config.qwen3_config._attn_implementation = config.attn_implementation
local_gpt2_config = config.gpt2_config.to_dict()
local_gpt2_config["n_layer"] = int(getattr(config, "local_transformer_layers", config.gpt2_config.n_layer))
local_gpt2_config["n_positions"] = int(config.n_vq) + 1
local_gpt2_config["n_ctx"] = int(config.n_vq) + 1
local_gpt2_config = GPT2Config(**local_gpt2_config)
local_gpt2_config.pad_token_id = config.pad_token_id
local_gpt2_config._attn_implementation = config.local_transformer_attn_implementation
self.transformer = MossQwen3Model(config.qwen3_config)
self.local_transformer = MossTTSNanoGPT2Model(
local_gpt2_config,
attn_implementation=config.local_transformer_attn_implementation,
)
self.local_transformer.wte = nn.Identity()
hidden_size = int(config.hidden_size)
self.audio_embeddings = nn.ModuleList(
[
nn.Embedding(int(config.audio_codebook_sizes[index]), hidden_size)
for index in range(config.n_vq)
]
)
self.text_lm_head = nn.Linear(hidden_size, int(config.vocab_size), bias=False)
self.audio_lm_heads = nn.ModuleList(
[
nn.Linear(hidden_size, int(config.audio_codebook_sizes[index]), bias=False)
for index in range(config.n_vq)
]
)
self.local_text_lm_head = (
nn.Linear(hidden_size, 2, bias=False)
if self._use_binary_local_text_head()
else None
)
self.post_init()
self.tie_weights()
self.initialize_local_text_lm_head_from_text_lm_head()
def can_generate(self) -> bool:
return True
@staticmethod
def _build_tied_weights_keys(config: MossTTSLocalConfig) -> dict[str, str]:
tied_weights = {"text_lm_head.weight": "transformer.embed_tokens.weight"}
tied_weights.update(
{
f"audio_lm_heads.{index}.weight": f"audio_embeddings.{index}.weight"
for index in range(config.n_vq)
}
)
return tied_weights
def tie_weights(self, *args, **kwargs) -> None:
del args, kwargs
self.text_lm_head.weight = self.transformer.embed_tokens.weight
for embedding, head in zip(self.audio_embeddings, self.audio_lm_heads):
head.weight = embedding.weight
def get_input_embeddings(self) -> nn.Embedding:
return self.transformer.embed_tokens
def set_input_embeddings(self, value: nn.Embedding) -> None:
self.transformer.embed_tokens = value
self.tie_weights()
self.initialize_local_text_lm_head_from_text_lm_head()
def get_output_embeddings(self) -> nn.Linear:
return self.text_lm_head
def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
self.text_lm_head = new_embeddings
self.tie_weights()
self.initialize_local_text_lm_head_from_text_lm_head()
def _use_binary_local_text_head(self) -> bool:
return str(getattr(self.config, "local_text_head_mode", "full_vocab")).strip().lower() == "binary"
def _local_text_candidate_ids(self, device: torch.device) -> torch.LongTensor:
return torch.tensor(
[
int(self.config.audio_assistant_slot_token_id),
int(self.config.audio_end_token_id),
],
dtype=torch.long,
device=device,
)
def initialize_local_text_lm_head_from_text_lm_head(self) -> None:
if not self._use_binary_local_text_head() or self.local_text_lm_head is None:
return
candidate_ids = self._local_text_candidate_ids(self.text_lm_head.weight.device)
with torch.no_grad():
source_weight = self.text_lm_head.weight.index_select(0, candidate_ids)
if tuple(source_weight.shape) == tuple(self.local_text_lm_head.weight.shape):
self.local_text_lm_head.weight.copy_(
source_weight.to(
device=self.local_text_lm_head.weight.device,
dtype=self.local_text_lm_head.weight.dtype,
)
)
def _resolve_fixed_nq(
self,
n_vq_for_inference: Optional[int] = None,
nq: Optional[int] = None,
) -> int:
requested = n_vq_for_inference if n_vq_for_inference is not None else nq
config_nq = int(self.config.n_vq)
if requested is not None and int(requested) != config_nq:
raise ValueError(
"This MOSS-TTS-Local-Transformer-v1.5 release is trained with a fixed RVQ depth. "
f"Expected n_vq={config_nq}, got {int(requested)}."
)
return config_nq
def _build_inputs_embeds(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
if input_ids.ndim != 3 or input_ids.shape[-1] != self.config.n_vq + 1:
raise ValueError(
f"Expected input_ids shape [batch, seq, {self.config.n_vq + 1}], "
f"got {tuple(input_ids.shape)}."
)
text_ids = input_ids[..., 0]
inputs_embeds = self.transformer.embed_tokens(text_ids)
for channel_index, embedding in enumerate(self.audio_embeddings):
channel_ids = input_ids[..., channel_index + 1]
valid_mask = channel_ids.ne(self.config.audio_pad_token_id)
safe_ids = channel_ids.masked_fill(~valid_mask, 0)
audio_embeds = embedding(safe_ids) * valid_mask.unsqueeze(-1)
inputs_embeds = inputs_embeds + audio_embeds
return inputs_embeds
def _global_hidden_to_local(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
return hidden_states
@staticmethod
def _local_past_length(past_key_values: Optional[tuple[Any, ...]]) -> int:
if past_key_values is None or len(past_key_values) == 0:
return 0
first_layer_past = past_key_values[0]
if isinstance(first_layer_past, dict) and bool(first_layer_past.get("static_kv_cache", False)):
return int(first_layer_past.get("length", 0))
return int(first_layer_past[0].shape[1])
def _new_static_local_past_key_values(
self,
batch_size: int,
max_length: int,
device: torch.device,
dtype: torch.dtype,
) -> tuple[dict[str, Any], ...]:
layers = []
for block in self.local_transformer.h:
attn = block.attn
cache_shape = (
int(batch_size),
int(max_length),
int(attn.num_heads),
int(attn.head_dim),
)
layers.append(
{
"static_kv_cache": True,
"key": torch.empty(cache_shape, device=device, dtype=dtype),
"value": torch.empty(cache_shape, device=device, dtype=dtype),
"length": 0,
}
)
return tuple(layers)
def _decode_local_hidden_states_with_cache(
self,
local_inputs_embeds: torch.FloatTensor,
past_key_values: Optional[tuple[Any, ...]] = None,
) -> tuple[torch.FloatTensor, Optional[tuple[Any, ...]]]:
if (
past_key_values is None
and not self.training
and bool(getattr(self.config, "use_static_local_kv_cache", True))
):
max_length = max(int(getattr(self.config, "n_vq", 0)) + 1, int(local_inputs_embeds.shape[1]))
past_key_values = self._new_static_local_past_key_values(
batch_size=int(local_inputs_embeds.shape[0]),
max_length=max_length,
device=local_inputs_embeds.device,
dtype=local_inputs_embeds.dtype,
)
past_length = self._local_past_length(past_key_values)
local_seq_len = int(local_inputs_embeds.shape[1])
local_position_ids = torch.arange(
past_length,
past_length + local_seq_len,
device=local_inputs_embeds.device,
dtype=torch.long,
).unsqueeze(0)
if int(local_inputs_embeds.shape[0]) != 1:
local_position_ids = local_position_ids.expand(int(local_inputs_embeds.shape[0]), -1)
local_outputs = self.local_transformer(
input_ids=None,
past_key_values=past_key_values,
attention_mask=None,
position_ids=local_position_ids,
inputs_embeds=local_inputs_embeds,
use_cache=True,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
cu_seqlens=None,
num_sequences=None,
)
return local_outputs.last_hidden_state, local_outputs.past_key_values
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[tuple[tuple[torch.Tensor, torch.Tensor], ...]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = True,
**kwargs,
) -> Union[tuple, MossTTSLocalOutput]:
del kwargs
if inputs_embeds is None:
if input_ids is None:
raise ValueError("Either input_ids or inputs_embeds must be provided.")
inputs_embeds = self._build_inputs_embeds(input_ids)
outputs = self.transformer(
input_ids=None,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
cu_seqlens=None,
num_sequences=None,
)
if not return_dict:
return (
outputs.last_hidden_state,
outputs.past_key_values,
outputs.hidden_states,
outputs.attentions,
)
return MossTTSLocalOutput(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def _decode_local_last_hidden_state(
self,
local_inputs_embeds: torch.FloatTensor,
) -> torch.FloatTensor:
local_seq_len = int(local_inputs_embeds.shape[1])
local_position_ids = torch.arange(
0,
local_seq_len,
device=local_inputs_embeds.device,
dtype=torch.long,
).unsqueeze(0)
if int(local_inputs_embeds.shape[0]) != 1:
local_position_ids = local_position_ids.expand(int(local_inputs_embeds.shape[0]), -1)
local_outputs = self.local_transformer(
input_ids=None,
attention_mask=None,
position_ids=local_position_ids,
inputs_embeds=local_inputs_embeds,
use_cache=False,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
cu_seqlens=None,
num_sequences=None,
)
return local_outputs.last_hidden_state[:, -1, :]
def _filter_logits(
self,
logits: torch.FloatTensor,
top_k: Optional[int],
top_p: Optional[float],
) -> torch.FloatTensor:
scores = logits
if top_k is not None and int(top_k) > 0 and int(top_k) < scores.shape[-1]:
kth = torch.topk(scores, int(top_k), dim=-1).values[..., -1, None]
scores = scores.masked_fill(scores < kth, -torch.inf)
if top_p is not None and 0.0 < float(top_p) < 1.0:
sorted_scores, sorted_indices = torch.sort(scores, descending=True, dim=-1)
sorted_probs = torch.softmax(sorted_scores, dim=-1)
cumulative_probs = sorted_probs.cumsum(dim=-1)
sorted_mask = cumulative_probs > float(top_p)
sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
sorted_mask[..., 0] = False
remove_mask = torch.zeros_like(scores, dtype=torch.bool)
remove_mask.scatter_(dim=-1, index=sorted_indices, src=sorted_mask)
scores = scores.masked_fill(remove_mask, -torch.inf)
return scores
def _apply_repetition_penalty(
self,
scores: torch.FloatTensor,
previous_token_ids: Optional[torch.LongTensor],
penalty: float,
) -> torch.FloatTensor:
if previous_token_ids is None or float(penalty) == 1.0:
return scores
if previous_token_ids.ndim == 1:
previous_token_ids = previous_token_ids.unsqueeze(0)
updated = scores.clone()
for batch_index in range(updated.shape[0]):
unique_token_ids = torch.unique(previous_token_ids[batch_index])
unique_token_ids = unique_token_ids[
(unique_token_ids >= 0) & (unique_token_ids < updated.shape[-1])
]
if unique_token_ids.numel() == 0:
continue
token_scores = updated[batch_index].index_select(0, unique_token_ids)
token_scores = torch.where(
token_scores < 0,
token_scores * float(penalty),
token_scores / float(penalty),
)
updated[batch_index].scatter_(0, unique_token_ids, token_scores)
return updated
def _sample_next_token(
self,
logits: torch.FloatTensor,
do_sample: bool,
temperature: float,
top_k: Optional[int],
top_p: Optional[float],
previous_token_ids: Optional[torch.LongTensor] = None,
repetition_penalty: float = 1.0,
) -> torch.LongTensor:
scores = logits.float()
scores = self._apply_repetition_penalty(scores, previous_token_ids, repetition_penalty)
if not do_sample:
return torch.argmax(scores, dim=-1)
if float(temperature) <= 0:
raise ValueError("temperature must be positive when do_sample=True.")
scores = scores / float(temperature)
scores = self._filter_logits(scores, top_k=top_k, top_p=top_p)
probs = torch.softmax(scores, dim=-1)
return torch.multinomial(probs, num_samples=1).squeeze(-1)
def _sample_next_assistant_text_token(
self,
local_hidden_states: torch.FloatTensor,
do_sample: bool,
temperature: float,
top_k: Optional[int],
top_p: Optional[float],
) -> torch.LongTensor:
if self._use_binary_local_text_head() and self.local_text_lm_head is not None:
logits = self.local_text_lm_head(local_hidden_states)
sampled_indices = self._sample_next_token(
logits=logits,
do_sample=do_sample,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
candidate_ids = self._local_text_candidate_ids(logits.device)
return candidate_ids[sampled_indices]
candidate_ids = self._local_text_candidate_ids(local_hidden_states.device)
logits = self.text_lm_head(local_hidden_states).index_select(dim=-1, index=candidate_ids)
sampled_indices = self._sample_next_token(
logits=logits,
do_sample=do_sample,
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
return candidate_ids[sampled_indices]
def _build_generation_row(
self,
batch_size: int,
device: torch.device,
audio_token_ids: torch.LongTensor,
) -> torch.LongTensor:
row = torch.full(
(batch_size, 1, self.config.n_vq + 1),
int(self.config.audio_pad_token_id),
dtype=torch.long,
device=device,
)
row[:, :, 0] = int(self.config.audio_assistant_slot_token_id)
row[:, :, 1:] = audio_token_ids.unsqueeze(1)
return row
@torch.inference_mode()
def generate(
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
max_new_tokens: Optional[int] = None,
max_new_frames: Optional[int] = None,
do_sample: bool = True,
text_temperature: float = 1.0,
text_top_p: float = 1.0,
text_top_k: int = 50,
audio_temperature: Optional[float] = None,
audio_top_p: Optional[float] = None,
audio_top_k: Optional[int] = None,
audio_repetition_penalty: Optional[float] = None,
temperature: float = 1.0,
top_p: float = 0.95,
top_k: int = 50,
repetition_penalty: float = 1.0,
use_kv_cache: bool = True,
n_vq_for_inference: Optional[int] = None,
nq: Optional[int] = None,
**kwargs,
) -> list[tuple[int, torch.LongTensor]]:
del kwargs
self._resolve_fixed_nq(n_vq_for_inference=n_vq_for_inference, nq=nq)
if input_ids.ndim == 2:
input_ids = input_ids.unsqueeze(0)
if input_ids.ndim != 3:
raise ValueError(f"Expected input_ids with 3 dims, got {tuple(input_ids.shape)}.")
if input_ids.shape[-1] != self.config.n_vq + 1:
raise ValueError(
f"Expected {self.config.n_vq + 1} channels from config.n_vq, got {input_ids.shape[-1]}."
)
if attention_mask is None:
attention_mask = torch.ones(input_ids.shape[:2], dtype=torch.bool, device=input_ids.device)
elif attention_mask.ndim == 1:
attention_mask = attention_mask.unsqueeze(0)
attention_mask = attention_mask.to(device=input_ids.device, dtype=torch.bool)
frame_budget = max_new_frames if max_new_frames is not None else max_new_tokens
if frame_budget is None:
frame_budget = 4096
frame_budget = int(frame_budget)
audio_temperature = float(temperature if audio_temperature is None else audio_temperature)
audio_top_p = float(top_p if audio_top_p is None else audio_top_p)
audio_top_k = int(top_k if audio_top_k is None else audio_top_k)
audio_repetition_penalty = float(
repetition_penalty if audio_repetition_penalty is None else audio_repetition_penalty
)
batch_size = input_ids.shape[0]
input_ids_length = input_ids.shape[1]
current_input_ids = input_ids
current_attention_mask = attention_mask
current_model_input_ids = current_input_ids
generated_frames: list[torch.LongTensor] = []
finished = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
past_key_values = None
local_dtype = self.local_transformer.ln_f.weight.dtype
for _ in range(frame_budget):
generated_audio_history = torch.stack(generated_frames, dim=1) if generated_frames else None
global_inputs_embeds = self._build_inputs_embeds(current_model_input_ids)
global_outputs = self.transformer(
input_ids=None,
past_key_values=past_key_values,
attention_mask=current_attention_mask,
position_ids=None,
inputs_embeds=global_inputs_embeds,
use_cache=use_kv_cache,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
cu_seqlens=None,
num_sequences=None,
)
global_hidden_states = global_outputs.last_hidden_state[:, -1, :]
local_global_hidden_states = self._global_hidden_to_local(global_hidden_states).to(dtype=local_dtype)
local_prefix_hidden_states, local_prefix_past_key_values = self._decode_local_hidden_states_with_cache(
local_global_hidden_states.unsqueeze(1)
)
local_hidden_states = local_prefix_hidden_states[:, -1, :]
next_text_tokens = self._sample_next_assistant_text_token(
local_hidden_states=local_hidden_states,
do_sample=do_sample,
temperature=text_temperature,
top_k=text_top_k,
top_p=text_top_p,
)
should_continue = next_text_tokens.eq(int(self.config.audio_assistant_slot_token_id)) & ~finished
finished = finished | next_text_tokens.eq(int(self.config.audio_end_token_id))
if not bool(should_continue.any().item()):
break
next_frame_tokens = []
for channel_index in range(int(self.config.n_vq)):
channel_logits = self.audio_lm_heads[channel_index](local_hidden_states)
channel_token = self._sample_next_token(
logits=channel_logits,
do_sample=do_sample,
temperature=audio_temperature,
top_k=audio_top_k,
top_p=audio_top_p,
previous_token_ids=(
None
if generated_audio_history is None
else generated_audio_history[:, :, channel_index]
),
repetition_penalty=audio_repetition_penalty,
)
next_frame_tokens.append(channel_token)
if channel_index + 1 < int(self.config.n_vq):
current_local_input = self.audio_embeddings[channel_index](channel_token).to(dtype=local_dtype)
local_token_hidden_states, local_prefix_past_key_values = (
self._decode_local_hidden_states_with_cache(
current_local_input.unsqueeze(1),
past_key_values=local_prefix_past_key_values,
)
)
local_hidden_states = local_token_hidden_states[:, -1, :]
next_frame = torch.stack(next_frame_tokens, dim=-1)
next_frame = next_frame.masked_fill(
~should_continue.unsqueeze(-1),
int(self.config.audio_pad_token_id),
)
generated_frames.append(next_frame)
next_row = self._build_generation_row(
batch_size=batch_size,
device=input_ids.device,
audio_token_ids=next_frame,
)
if bool((~should_continue).any().item()):
next_row[~should_continue, 0, 0] = int(self.config.pad_token_id)
next_row[~should_continue, 0, 1:] = int(self.config.audio_pad_token_id)
current_input_ids = torch.cat([current_input_ids, next_row], dim=1)
current_attention_mask = torch.cat(
[current_attention_mask, should_continue.unsqueeze(1)],
dim=1,
)
if use_kv_cache:
current_model_input_ids = next_row
past_key_values = global_outputs.past_key_values
else:
current_model_input_ids = current_input_ids
start_indices = _find_last_equal(input_ids[..., 0], int(self.config.audio_start_token_id))
start_lengths = input_ids_length - start_indices - 1
outputs: list[tuple[int, torch.LongTensor]] = []
for start_index, start_length, generation_ids in zip(
start_indices.tolist(),
start_lengths.tolist(),
current_input_ids,
):
outputs.append((int(start_length), generation_ids[int(start_index):].detach().cpu()))
return outputs