Feature Extraction
Transformers
Safetensors
meralion_bestrq
speech
best-rq
meralion
meralion-2
custom_code
Instructions to use MERaLiON/MERaLiON-SpeechEncoder-2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use MERaLiON/MERaLiON-SpeechEncoder-2 with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="MERaLiON/MERaLiON-SpeechEncoder-2", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("MERaLiON/MERaLiON-SpeechEncoder-2", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import os | |
| import torch | |
| import math | |
| from torch import nn | |
| from typing import Optional, Tuple, Union | |
| from transformers.modeling_utils import PreTrainedModel | |
| from transformers.activations import ACT2FN | |
| from transformers.modeling_outputs import BaseModelOutput, Wav2Vec2BaseModelOutput, CausalLMOutput | |
| from safetensors.torch import load_file | |
| from .configuration_bestrq_conformer import MeralionBestRqConformerConfig | |
# Position right after (last_hidden_state, extract_features) in the tuple returned by
# MeralionBestRqModel.forward(return_dict=False); with output_hidden_states=True this is
# where the encoder's per-layer hidden states begin. Not referenced in the visible code —
# presumably consumed by a downstream head; confirm before changing the output tuple layout.
_HIDDEN_STATES_START_POSITION: int = 2
def lengths_to_padding_mask(lens: torch.LongTensor) -> torch.BoolTensor:
    """
    Build a boolean padding mask from per-sequence lengths.

    Args:
        lens (`torch.LongTensor`):
            1-D tensor of shape `(batch_size,)` holding the valid length of each sequence.

    Returns:
        `torch.BoolTensor`:
            Mask of shape `(batch_size, max(lens))`. Entry `[b, t]` is `True` when `t >= lens[b]`,
            i.e. when position `t` is padding for sequence `b`.
    """
    batch_size = lens.size(0)
    width = torch.max(lens).item()
    # Row of time indices broadcast against a column of per-sequence lengths.
    positions = torch.arange(width).to(lens.device).view(1, width).expand(batch_size, -1)
    limits = lens.view(batch_size, 1).expand(-1, width)
    return positions >= limits
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make a boolean mask marking the padded positions of a batch.

    Args:
        lengths (torch.Tensor): 1-D batch of sequence lengths, shape (B,). A Python
            sequence of ints is also accepted and converted with ``torch.as_tensor``.
        max_len (int): Width of the mask. If ``<= 0`` (the default), the largest value
            in ``lengths`` is used (0 for an empty batch).

    Returns:
        torch.Tensor: Bool tensor of shape (B, max_len) on ``lengths``' device; ``True``
        marks padded positions (``t >= lengths[b]``).

    Examples:
        >>> make_pad_mask(torch.tensor([5, 3, 2])).int().tolist()
        [[0, 0, 0, 0, 0], [0, 0, 0, 1, 1], [0, 0, 1, 1, 1]]
    """
    lengths = torch.as_tensor(lengths)
    batch_size = lengths.size(0)
    if max_len <= 0:
        # An empty batch has no maximum; fall back to a zero-width mask instead of raising.
        max_len = int(lengths.max().item()) if lengths.numel() > 0 else 0
    seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    return seq_range_expand >= seq_length_expand
class Conv2dSubsampling(nn.Module):
    """
    Two stacked kernel-3, stride-2 2D convolutions (each followed by ReLU) that shrink a
    feature sequence roughly 4x along time and feature axes.

    Args:
        config: Model configuration. Reads `config.input_channels` (input channels of the
            first convolution; `forward` feeds a single channel via `unsqueeze(1)`, so this is
            expected to be 1) and `config.hidden_size` (channels produced by both convolutions).

    Inputs:
        inputs (batch, time, dim): Input feature sequence.
        input_lengths (batch): Number of valid frames in each item.

    Returns:
        outputs (batch, time', hidden_size * dim'): Convolution output with the channel and
            reduced feature axes flattened together.
        output_lengths (batch): Valid length of each item after subsampling.
    """

    def __init__(self, config):
        super().__init__()
        first_conv = nn.Conv2d(config.input_channels, config.hidden_size, kernel_size=3, stride=2)
        second_conv = nn.Conv2d(config.hidden_size, config.hidden_size, kernel_size=3, stride=2)
        self.sequential = nn.Sequential(first_conv, nn.ReLU(), second_conv, nn.ReLU())

    def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor):
        _, time_in, _ = inputs.size()
        # (B, T, D) -> (B, 1, T, D): treat the feature map as a one-channel image.
        feats = self.sequential(inputs.unsqueeze(1))
        batch_size, channels, time_out, freq_out = feats.size()
        # (B, C, T', D') -> (B, T', C * D')
        feats = feats.permute(0, 2, 1, 3).contiguous().view(batch_size, time_out, channels * freq_out)

        # Empirical down-sampling ratio, rounded to the nearest integer.
        factor = int(time_in * 1.0 / time_out + 0.5)
        scaled_lengths = (input_lengths.float() / factor).ceil().long()
        # Never report more valid frames than the convolution actually produced.
        frame_cap = torch.full_like(scaled_lengths, feats.size(1))
        return feats, torch.min(scaled_lengths, frame_cap)
class ConformerRelPositionalEmbedding(nn.Module):
    """Sinusoidal relative positional encoding table (Transformer-XL style).

    Holds encodings for relative offsets ``+(n-1) ... 0 ... -(n-1)`` (``2n - 1`` rows) and
    slices out the window needed for the current sequence length.

    Args:
        config: Model configuration. Reads ``config.max_source_positions`` (length the table
            is pre-built for) and ``config.hidden_size`` (embedding dimension).
    """

    def __init__(self, config):
        super().__init__()
        self.max_len = config.max_source_positions
        self.d_model = config.hidden_size
        self.pe = None
        # Pre-build the table for max_len positions using a dummy (1, max_len) tensor.
        self.extend_pe(torch.tensor(0.0).expand(1, self.max_len))

    def extend_pe(self, x):
        """Ensure ``self.pe`` covers ``x.size(1)`` positions in ``x``'s dtype and device.

        The table is rebuilt only when it is too short; otherwise it is cast/moved if needed.
        """
        length = x.size(1)
        if self.pe is not None and self.pe.size(1) >= length * 2 - 1:
            if self.pe.dtype != x.dtype or self.pe.device != x.device:
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
            return

        # For query position i and key position j, positive offsets are used when the key
        # lies to the left (i > j) and negative offsets otherwise (i < j).
        offsets = torch.arange(0, length, dtype=torch.float32).unsqueeze(1)
        inv_freq = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        left = torch.zeros(length, self.d_model)
        right = torch.zeros(length, self.d_model)
        left[:, 0::2] = torch.sin(offsets * inv_freq)
        left[:, 1::2] = torch.cos(offsets * inv_freq)
        right[:, 0::2] = torch.sin(-1 * offsets * inv_freq)
        right[:, 1::2] = torch.cos(-1 * offsets * inv_freq)

        # Reverse the positive half (offsets n-1 ... 0) and append the negative half without
        # its duplicate zero row; this layout supports the relative-shift trick of
        # https://arxiv.org/abs/1901.02860
        table = torch.cat([left.flip(0), right[1:]], dim=0).unsqueeze(0)
        self.pe = table.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor):
        """Return the relative positional embeddings matching ``x``'s length.

        Args:
            x: Input tensor T X B X C (only its length, dtype and device are used).

        Returns:
            torch.Tensor of shape (2T - 1, 1, C): encodings for offsets T-1 ... -(T-1).
            ``x`` itself is not modified or added to.
        """
        seq_len = x.size(0)
        self.extend_pe(x.transpose(0, 1))
        centre = self.pe.size(1) // 2
        window = self.pe[:, centre - seq_len + 1 : centre + seq_len]
        return window.transpose(0, 1)
class ConformerRotaryPositionalEmbedding(nn.Module):
    """Rotary positional embedding.

    Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf

    Args:
        config: Reads ``hidden_size`` and ``num_attention_heads`` (the rotated dimension is
            ``hidden_size // num_attention_heads``) and ``rotary_embedding_base``.
    """

    def __init__(self, config):
        super().__init__()
        dim = config.hidden_size // config.num_attention_heads
        base = config.rotary_embedding_base
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.cached_sequence_length = None
        self.cached_rotary_positional_embedding = None

    def forward(self, hidden_states):
        """Return stacked ``[cos, sin]`` tables for the sequence length of ``hidden_states``.

        Args:
            hidden_states: Tensor whose dim 1 is taken as the sequence length.
                NOTE(review): a (T, B, C) tensor would make dim 1 the batch size — confirm
                the layout callers pass.

        Returns:
            Tensor of shape (2, seq_len, 1, 1, head_size) in ``hidden_states``' dtype and device.
        """
        sequence_length = hidden_states.shape[1]
        cached = self.cached_rotary_positional_embedding
        # Reuse the cache only when length, dtype and device all match; otherwise a dtype or
        # device change (e.g. model.half() / model.to("cuda")) would return a stale tensor.
        if (
            cached is not None
            and sequence_length == self.cached_sequence_length
            and cached.dtype == hidden_states.dtype
            and cached.device == hidden_states.device
        ):
            return cached

        self.cached_sequence_length = sequence_length
        # Embeddings are computed in the dtype of the inv_freq constant
        time_stamps = torch.arange(sequence_length, device=self.inv_freq.device).type_as(self.inv_freq)
        freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq)
        embeddings = torch.cat((freqs, freqs), dim=-1)
        cos_embeddings = embeddings.cos()[:, None, None, :]
        sin_embeddings = embeddings.sin()[:, None, None, :]
        # Computed embeddings are cast to the dtype/device of the hidden state inputs
        self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings]).to(
            dtype=hidden_states.dtype, device=hidden_states.device
        )
        return self.cached_rotary_positional_embedding
class ConformerInputFeatureProjection(nn.Module):
    """
    Linear projection from the flattened subsampler output to ``config.hidden_size``,
    followed by dropout (``config.feat_proj_dropout``).

    The input width is ``hidden_size * (((input_dim - 1) // 2 - 1) // 2)``: the channel count
    times the feature width left after two kernel-3, stride-2 convolutions.
    """

    def __init__(self, config):
        super().__init__()
        reduced_width = ((config.input_dim - 1) // 2 - 1) // 2
        self.projection = nn.Linear(config.hidden_size * reduced_width, config.hidden_size)
        self.dropout = nn.Dropout(config.feat_proj_dropout)

    def forward(self, hidden_states):
        """
        Args:
            hidden_states: Tensor whose last dimension is the flattened subsampler width
                (T X B X C in the encoder).

        Returns:
            Tensor with the same leading dimensions and last dimension ``hidden_size``.
        """
        projected = self.projection(hidden_states)
        return self.dropout(projected)
class ConformerFeedForward(nn.Module):
    """Position-wise feed-forward block: Linear -> activation -> dropout -> Linear -> dropout.

    Args:
        config: Reads ``hidden_size``, ``ffn_dim``, ``hidden_act`` (an ``ACT2FN`` key or a
            callable), ``activation_dropout`` and ``hidden_dropout``.
    """

    def __init__(self, config):
        super().__init__()
        self.intermediate_dropout = nn.Dropout(config.activation_dropout)
        self.intermediate_dense = nn.Linear(config.hidden_size, config.ffn_dim)
        act = config.hidden_act
        # A string names an activation in ACT2FN; anything else is used as the callable itself.
        self.intermediate_act_fn = ACT2FN[act] if isinstance(act, str) else act
        self.output_dense = nn.Linear(config.ffn_dim, config.hidden_size)
        self.output_dropout = nn.Dropout(config.hidden_dropout)

    def forward(self, hidden_states):
        """
        Args:
            hidden_states: Tensor whose last dimension is ``hidden_size`` (T X B X C in the encoder).

        Returns:
            Tensor of the same shape.
        """
        inner = self.intermediate_act_fn(self.intermediate_dense(hidden_states))
        inner = self.intermediate_dropout(inner)
        return self.output_dropout(self.output_dense(inner))
class ConformerConvolutionModule(nn.Module):
    """Convolution block used in the conformer block.

    LayerNorm -> pointwise conv (2x channels) -> GLU -> depthwise conv -> BatchNorm ->
    activation -> pointwise conv -> dropout, on (B, T, C) input.

    Args:
        config: Reads ``hidden_size``, ``conv_depthwise_kernel_size`` (must be odd so the
            'SAME' padding is symmetric), ``hidden_act`` (an ``ACT2FN`` key or a callable, as
            in ``ConformerFeedForward``) and ``conformer_conv_dropout``.

    Raises:
        ValueError: if ``conv_depthwise_kernel_size`` is even.
    """

    def __init__(self, config):
        super().__init__()
        if (config.conv_depthwise_kernel_size - 1) % 2 == 1:
            raise ValueError("`config.conv_depthwise_kernel_size` should be a odd number for 'SAME' padding")
        self.layer_norm = nn.LayerNorm(config.hidden_size)
        self.pointwise_conv1 = nn.Conv1d(
            config.hidden_size,
            2 * config.hidden_size,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.glu = nn.GLU(dim=1)
        self.depthwise_conv = nn.Conv1d(
            config.hidden_size,
            config.hidden_size,
            config.conv_depthwise_kernel_size,
            stride=1,
            padding=(config.conv_depthwise_kernel_size - 1) // 2,
            groups=config.hidden_size,
            bias=False,
        )
        self.batch_norm = nn.BatchNorm1d(config.hidden_size)
        # Accept a callable activation as well as an ACT2FN key, so a config that works for
        # ConformerFeedForward does not crash here.
        act = config.hidden_act
        self.activation = ACT2FN[act] if isinstance(act, str) else act
        self.pointwise_conv2 = nn.Conv1d(
            config.hidden_size,
            config.hidden_size,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.dropout = nn.Dropout(config.conformer_conv_dropout)

    def forward(self, hidden_states):
        """
        Args:
            hidden_states: Input of shape B X T X C

        Returns:
            Tensor of shape B X T X C
        """
        hidden_states = self.layer_norm(hidden_states)
        # Conv1d expects channels in dim 1: (B, T, C) -> (B, C, T)
        hidden_states = hidden_states.transpose(1, 2)
        # GLU mechanism: (B, 2C, T) -> (B, C, T)
        hidden_states = self.pointwise_conv1(hidden_states)
        hidden_states = self.glu(hidden_states)
        # 1D depthwise conv; symmetric padding keeps the time length unchanged
        hidden_states = self.depthwise_conv(hidden_states)
        hidden_states = self.batch_norm(hidden_states)
        hidden_states = self.activation(hidden_states)
        hidden_states = self.pointwise_conv2(hidden_states)
        hidden_states = self.dropout(hidden_states)
        # (B, C, T) -> (B, T, C)
        return hidden_states.transpose(1, 2)
class ConformerSelfAttention(nn.Module):
    """Multi-head self-attention used inside a Conformer block.

    ``config.position_embeddings_type`` selects how positions enter the scores: ``"rotary"``
    rotates the query/key input, ``"relative"`` adds Transformer-XL style relative-position
    terms (https://arxiv.org/abs/1901.02860), and any other value gives plain scaled
    dot-product attention with no position embeddings required.

    Args:
        config: Reads ``hidden_size``, ``num_attention_heads``, ``position_embeddings_type``
            and ``attention_dropout``.
    """

    def __init__(self, config):
        super().__init__()
        self.head_size = config.hidden_size // config.num_attention_heads
        self.num_heads = config.num_attention_heads
        self.position_embeddings_type = config.position_embeddings_type
        self.linear_q = nn.Linear(config.hidden_size, config.hidden_size)
        self.linear_k = nn.Linear(config.hidden_size, config.hidden_size)
        self.linear_v = nn.Linear(config.hidden_size, config.hidden_size)
        self.linear_out = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(p=config.attention_dropout)
        if self.position_embeddings_type == "relative":
            # linear transformation for positional encoding
            self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
            # these two learnable bias are used in matrix c and matrix d
            # as described in https://arxiv.org/abs/1901.02860 Section 3.3
            self.pos_bias_u = nn.Parameter(torch.Tensor(self.num_heads, self.head_size))
            self.pos_bias_v = nn.Parameter(torch.Tensor(self.num_heads, self.head_size))
            torch.nn.init.xavier_uniform_(self.pos_bias_u)
            torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def forward(
        self,
        hidden_states: torch.Tensor,  # [T, B, C]
        attention_mask: Optional[torch.Tensor] = None,
        relative_position_embeddings: Optional[torch.Tensor] = None,  # [T, B, C]
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Attend over ``hidden_states``.

        Args:
            hidden_states: Input of shape (T, B, C).
            attention_mask: Optional (B, T) mask. Key positions where it is True/nonzero get a
                score of -inf, so they receive no attention.
            relative_position_embeddings: Time-first position embeddings, transposed to
                batch-first before use. For ``"relative"`` the shift trick expects 2T - 1
                positions on the time axis (e.g. shape (2T - 1, 1, C)). May be None when no
                position embedding type is configured.
                NOTE(review): the same transpose is applied to rotary embeddings, whose stacked
                (2, T, 1, 1, head_size) layout it does not fit — confirm before using "rotary".
            output_attentions: Accepted for API compatibility; probabilities are always returned.

        Returns:
            Tuple of attended hidden states, shape (T, B, C), and attention probabilities,
            shape (B, num_heads, T, T).
        """
        # self-attention mechanism: (T, B, C) -> (B, T, C)
        hidden_states = hidden_states.transpose(0, 1)
        if relative_position_embeddings is not None:
            # Guarded: embeddings are legitimately absent when no position embedding type is set.
            relative_position_embeddings = relative_position_embeddings.transpose(0, 1)  # [B, T, C]
        batch_size, _, _ = hidden_states.size()

        # make sure query/key states can be != value states
        query_key_states = hidden_states
        value_states = hidden_states

        if self.position_embeddings_type == "rotary":
            if relative_position_embeddings is None:
                raise ValueError(
                    "`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'"
                )
            query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings)

        # project query_key_states and value_states, then split heads:
        # (B, T, C) -> (B, T, H, d_k) -> (B, H, T, d_k)
        query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
        key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size)
        value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size)
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        if self.position_embeddings_type == "relative":
            if relative_position_embeddings is None:
                raise ValueError(
                    "`relative_position_embeddings` has to be defined when `self.position_embeddings_type =="
                    " 'relative'"
                )
            # apply relative_position_embeddings to qk scores
            # as proposed in Transformer_XL: https://arxiv.org/abs/1901.02860
            scores = self._apply_relative_embeddings(
                query=query, key=key, relative_position_embeddings=relative_position_embeddings
            )
        else:
            scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_size)

        # apply attention_mask if necessary: (B, T) -> (B, 1, 1, T) masks whole key columns
        if attention_mask is not None:
            scores = scores.masked_fill(
                attention_mask.unsqueeze(1).unsqueeze(2).to(bool),
                float("-inf"),  # (batch, head, time1, time2)
            )

        # => (batch, head, time1, time2)
        probs = torch.softmax(scores, dim=-1)
        probs = self.dropout(probs)
        # => (batch, head, time1, d_k)
        hidden_states = torch.matmul(probs, value)
        # => (batch, time1, hidden_size)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size)
        hidden_states = self.linear_out(hidden_states)
        # => (time1, batch, hidden_size)
        hidden_states = hidden_states.transpose(0, 1)
        return hidden_states, probs

    def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings):
        """Rotate (B, T, C) ``hidden_states`` per head using the stacked [cos, sin] tables."""
        batch_size, sequence_length, hidden_size = hidden_states.size()
        hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size)
        cos = relative_position_embeddings[0, :sequence_length, ...]
        sin = relative_position_embeddings[1, :sequence_length, ...]
        # rotate hidden_states with rotary embeddings
        hidden_states = hidden_states.transpose(0, 1)
        rotated_states_begin = hidden_states[..., : self.head_size // 2]
        rotated_states_end = hidden_states[..., self.head_size // 2 :]
        rotated_states = torch.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1)
        hidden_states = (hidden_states * cos) + (rotated_states * sin)
        hidden_states = hidden_states.transpose(0, 1)
        hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size)
        return hidden_states

    def _apply_relative_embeddings(self, query, key, relative_position_embeddings):
        """Transformer-XL attention scores (matrices a+c and shifted b+d), scaled by sqrt(d_k)."""
        # 1. project positional embeddings
        # => (batch, head, d_k, 2*time1-1)
        proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings)
        proj_relative_position_embeddings = proj_relative_position_embeddings.view(
            relative_position_embeddings.size(0), -1, self.num_heads, self.head_size  # (batch, 2*time1-1, head, d_k)
        )
        proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)
        proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3)  # (batch, head, d_k, 2*time1-1)
        # 2. Add bias to query
        # => (batch, head, time1, d_k)
        query = query.transpose(1, 2)  # (batch, time1, head, d_k)
        q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
        q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
        # 3. attention score: first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # => (batch, head, time1, time2)
        scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
        # 4. then compute matrix b and matrix d
        # => (batch, head, time1, 2*time1-1)
        scores_bd = torch.matmul(q_with_bias_v, proj_relative_position_embeddings)
        # 5. shift matrix b and matrix d (relative-shift trick)
        zero_pad = torch.zeros((*scores_bd.size()[:3], 1), device=scores_bd.device, dtype=scores_bd.dtype)
        scores_bd_padded = torch.cat([zero_pad, scores_bd], dim=-1)
        scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2])
        scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape)
        scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd)
        scores_bd = scores_bd[:, :, :, : scores_bd.size(-1) // 2 + 1]
        # 6. sum matrices
        # => (batch, head, time1, time2)
        scores = (scores_ac + scores_bd) / math.sqrt(self.head_size)
        return scores
class ConformerEncoderLayer(nn.Module):
    """Conformer block based on https://arxiv.org/abs/2005.08100, operating on (T, B, C) tensors.

    Half-step feed-forward -> self-attention -> convolution module -> half-step feed-forward
    -> final LayerNorm, each sub-layer wrapped in a residual connection.
    """

    def __init__(self, config):
        super().__init__()
        embed_dim = config.hidden_size
        # Feed-forward 1 (added back at half weight)
        self.ffn1_layer_norm = nn.LayerNorm(embed_dim)
        self.ffn1 = ConformerFeedForward(config)
        # Self-attention; the dropout on its output reuses config.attention_dropout
        self.self_attn_layer_norm = nn.LayerNorm(embed_dim)
        self.self_attn_dropout = nn.Dropout(config.attention_dropout)
        self.self_attn = ConformerSelfAttention(config)
        # Convolution module (no LayerNorm here: the module normalises its own input)
        self.conv_module = ConformerConvolutionModule(config)
        # Feed-forward 2 (added back at half weight)
        self.ffn2_layer_norm = nn.LayerNorm(embed_dim)
        self.ffn2 = ConformerFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(embed_dim)

    def forward(
        self,
        hidden_states,  # [T, B, C]
        attention_mask: Optional[torch.Tensor] = None,
        relative_position_embeddings: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ):
        """Run one Conformer block.

        Returns:
            Tuple ``(hidden_states, attn_weights)``: the (T, B, C) block output and the
            attention probabilities returned by ``self.self_attn``.
        """
        # 1. Feed-forward 1 with half-step residual
        ffn1_out = self.ffn1(self.ffn1_layer_norm(hidden_states))
        hidden_states = ffn1_out * 0.5 + hidden_states

        # 2. Self-attention with residual
        attn_out, attn_weights = self.self_attn(
            hidden_states=self.self_attn_layer_norm(hidden_states),
            attention_mask=attention_mask,
            relative_position_embeddings=relative_position_embeddings,
            output_attentions=output_attentions,
        )
        hidden_states = self.self_attn_dropout(attn_out) + hidden_states

        # 3. Convolution module works batch-first: [T,B,C] -> [B,T,C] -> [T,B,C]
        conv_out = self.conv_module(hidden_states.transpose(0, 1)).transpose(0, 1)
        hidden_states = hidden_states + conv_out

        # 4. Feed-forward 2 with half-step residual, then the final LayerNorm
        ffn2_out = self.ffn2(self.ffn2_layer_norm(hidden_states))
        hidden_states = self.final_layer_norm(ffn2_out * 0.5 + hidden_states)
        return hidden_states, attn_weights
class ConformerEncoder(nn.Module):
    """
    The Conformer encoder module. It scales the subsampled features, derives position
    embeddings, projects to ``hidden_size`` and runs the stack of Conformer layers. It can
    optionally apply CTC self-conditioning after selected layers.

    Args:
        config ([`MeralionBestRqConformerConfig`]):
            The configuration object for the model.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Inputs are scaled by sqrt(hidden_size) unless `no_scale_embedding` is set.
        self.embed_scale = math.sqrt(config.hidden_size)
        if config.no_scale_embedding:
            self.embed_scale = 1.0
        if config.position_embeddings_type == "relative":
            self.embed_positions = ConformerRelPositionalEmbedding(config)
        elif config.position_embeddings_type == "rotary":
            self.embed_positions = ConformerRotaryPositionalEmbedding(config)
        else:
            self.embed_positions = None
        self.input_projection = ConformerInputFeatureProjection(config)  # [T,B,C]
        self.layers = nn.ModuleList([ConformerEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False
        self.self_condition_layers = self.config.self_condition_layers
        self.conditioning_layer = None
        self.conditioning_softmax = None
        self.out_projection = None
        if self.self_condition_layers:
            # If self-conditioning is enabled, we need these layers
            if self.config.vocab_size is None:
                raise ValueError("output_size (i.e., vocab_size) must be provided for self-conditioning.")
            self.conditioning_layer = nn.Linear(self.config.vocab_size, self.config.hidden_size)
            self.conditioning_softmax = nn.Softmax(dim=-1)

    def forward(
        self,
        hidden_states,  # conv_out
        attention_mask=None,  # encoder_padding_mask
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        ctc_decoder=None,
    ):
        """Encode time-first subsampled features.

        Args:
            hidden_states: Subsampled features, shape (T, B, C_in).
            attention_mask: Padding mask passed unchanged to every layer.
            output_attentions: Collect each layer's attention probabilities (``None`` for a
                layer skipped by LayerDrop).
            output_hidden_states: Collect the projected input and each layer's output, (B, T, C).
            return_dict: Return a ``BaseModelOutput`` instead of a tuple.
            ctc_decoder: Module applied to the (B, T, C) output of each self-conditioning layer.
                Its output feeds ``conditioning_layer``, so its last dim must be ``vocab_size``.
                Required when ``self_condition_layers`` is set.

        Returns:
            ``BaseModelOutput`` (or a tuple without ``None`` entries) whose ``last_hidden_state``
            is (B, T, C). With self-conditioning, the per-layer CTC logits are appended to
            ``hidden_states``; they are returned even when ``output_hidden_states`` is False.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        ctc_outputs = () if self.self_condition_layers else None
        # Tolerate a None/empty config value instead of failing on `i in None`.
        condition_layers = self.self_condition_layers or ()

        hidden_states = self.embed_scale * hidden_states
        if self.embed_positions is not None:
            relative_position_embeddings = self.embed_positions(hidden_states)  # [T,B,C]
        else:
            relative_position_embeddings = None
        hidden_states = self.input_projection(hidden_states)  # [T,B,C]
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states.transpose(0, 1),)

        for i, layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description).
            # The random draw happens even in eval mode; moving it would change global RNG use.
            dropout_probability = torch.rand([])
            skip_the_layer = bool(self.training and (dropout_probability < self.config.layerdrop))
            if skip_the_layer:
                layer_outputs = (None, None)
            else:
                layer_outputs = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    relative_position_embeddings=relative_position_embeddings,
                    output_attentions=output_attentions,
                )
                hidden_states = layer_outputs[0]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.transpose(0, 1),)  # [T,B,C] -> [B,T,C]
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
            if i in condition_layers:
                assert isinstance(ctc_decoder, nn.Module), "A CTC decoder must be passed in for self conditioning"
                # Apply CTC decoder to the output of this layer
                ctc_logits = ctc_decoder(hidden_states.transpose(0, 1))
                ctc_outputs = ctc_outputs + (ctc_logits,)
                ctc_probs = self.conditioning_softmax(ctc_logits)
                conditioning_embedding = self.conditioning_layer(ctc_probs).transpose(0, 1)  # [T, B, C]
                hidden_states = hidden_states + conditioning_embedding  # Additive conditioning residual

        hidden_states = hidden_states.transpose(0, 1)  # [B,T,C]
        if ctc_outputs is not None:
            # CTC logits travel in the hidden_states slot; start from an empty tuple when
            # per-layer hidden states were not requested instead of failing on None + tuple.
            base = all_hidden_states if all_hidden_states is not None else ()
            all_hidden_states = base + ctc_outputs

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
class MeralionBestRqModel(PreTrainedModel):
    """
    The core BEST-RQ Conformer model: convolutional subsampling followed by the Conformer
    encoder. It takes mel-spectrogram features and outputs the final encoder hidden states.

    This model inherits from [`PreTrainedModel`]. For the available methods and functionalities, see the
    documentation in [`PreTrainedModel`].

    Args:
        config ([`MeralionBestRqConformerConfig`]):
            Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
    """

    config_class = MeralionBestRqConformerConfig
    base_model_prefix = "bestrq_encoder"

    def __init__(self, config: MeralionBestRqConformerConfig):
        super().__init__(config)
        self.config = config
        self.conv_subsample = Conv2dSubsampling(config)
        self.encoder = ConformerEncoder(config)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_values: Optional[torch.Tensor],  # [B,C,T]
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        ctc_decoder: Optional[nn.Module] = None,
    ) -> Union[Tuple, Wav2Vec2BaseModelOutput]:
        r"""
        Performs the forward pass of the BEST-RQ Conformer model.

        Args:
            input_values (`torch.FloatTensor` of shape `(batch_size, num_features, sequence_length)`):
                Float values of mel features extracted from the raw speech signal.
            attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                1 for valid frames and 0 for padding; only its per-row sum (the valid length) is used.
                If omitted, every frame is treated as valid.
            mask_time_indices (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Currently unused.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            ctc_decoder (`nn.Module`, *optional*):
                CTC decoder used for self-conditioning at the configured intermediate layers.

        Returns:
            [`Wav2Vec2BaseModelOutput`] or `tuple`:
                With `return_dict=True`, a [`Wav2Vec2BaseModelOutput`] with:
                - **last_hidden_state** (`(batch_size, subsampled_length, hidden_size)`): output of the last layer.
                - **extract_features**: the flattened output of the convolutional subsampler.
                - **hidden_states** (*optional*): embedding output plus each layer's output, each
                  `(batch_size, subsampled_length, hidden_size)`, followed by any self-conditioning CTC logits.
                - **attentions** (*optional*): one `(batch_size, num_heads, subsampled_length, subsampled_length)`
                  tensor per layer.
                - **output_lengths** (`torch.LongTensor` of shape `(batch_size,)`): valid length after subsampling.
                With `return_dict=False`, the tuple `(last_hidden_state, extract_features, *encoder_extras)`;
                `output_lengths` is not included in the tuple form.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if attention_mask is None:
            # No mask given: every frame along the last (time) axis of input_values is valid.
            input_lengths = torch.full(
                (input_values.size(0),),
                input_values.size(-1),
                dtype=torch.long,
                device=input_values.device,
            )
        else:
            input_lengths = attention_mask.sum(dim=-1)

        input_values = input_values.transpose(2, 1)  # [B,C,T] -> [B,T,C]
        conv_outputs, output_lengths = self.conv_subsample(input_values, input_lengths)  # returns [B,T,C]
        x = conv_outputs.transpose(0, 1)  # [T,B,C]
        # True marks padded frames after subsampling.
        encoder_padding_mask = make_pad_mask(output_lengths, max_len=x.shape[0])
        encoder_outputs = self.encoder(
            x,
            attention_mask=encoder_padding_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            ctc_decoder=ctc_decoder,
        )
        hidden_states = encoder_outputs[0]
        if not return_dict:
            return (hidden_states, conv_outputs) + encoder_outputs[1:]
        output = Wav2Vec2BaseModelOutput(
            last_hidden_state=hidden_states,
            extract_features=conv_outputs,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
        output["output_lengths"] = output_lengths
        return output
class MeralionBestRqModelForCTC(PreTrainedModel):
    """
    BEST-RQ Conformer model with a CTC head on top for Connectionist Temporal Classification. This model can
    also use a weighted sum of the encoder's hidden states.

    This model inherits from [`PreTrainedModel`]. For the available methods and functionalities, see the
    documentation in [`PreTrainedModel`].

    Args:
        config ([`MeralionBestRqConformerConfig`]):
            Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
        target_lang (`str`, *optional*):
            Stored as `self.target_lang`; not otherwise used by this class.
    """

    # Adapted from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC
    config_class = MeralionBestRqConformerConfig
    base_model_prefix = "bestrq_encoder"

    def __init__(self, config, target_lang: Optional[str] = None, **kwargs):
        super().__init__(config)
        self.bestrq_encoder = MeralionBestRqModel(config)
        self.dropout = nn.Dropout(config.final_dropout)
        self.target_lang = target_lang
        if self.config.use_weighted_sum:
            # One learnable logit per encoder layer; softmax-normalised in `_weighted_sum`.
            self.weights = nn.Parameter(torch.zeros(self.config.num_hidden_layers))
            self.softmax = nn.Softmax(dim=-1)
        if config.vocab_size is None:
            raise ValueError(
                f"You are trying to instantiate {self.__class__} with a configuration that "
                "does not define the vocabulary size of the language model head. Please "
                "instantiate the model as follows: `MeralionBestRqModelForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
                "or define `vocab_size` of your model's configuration."
            )
        output_hidden_size = (
            config.output_hidden_size if getattr(config, "add_adapter", False) else config.hidden_size
        )
        self.lm_head = nn.Linear(output_hidden_size, config.vocab_size)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, CausalLMOutput]:
        r"""
        Performs the forward pass of the BEST-RQ Conformer model with a CTC head.

        Args:
            input_values (`torch.FloatTensor` of shape `(batch_size, num_features, sequence_length)`):
                Float values of mel features extracted from the raw speech signal.
            attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Mask to avoid performing attention on padding token indices. Mask values are in `[0, 1]`, where 1 for
                tokens that are not masked and 0 for tokens that are masked.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. Falls back to
                `config.output_hidden_states` when `None`. Must resolve to `True` when `config.use_weighted_sum`
                is set.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
                Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal
                to the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
                All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
                config.vocab_size - 1]`.

        Returns:
            [`CausalLMOutput`] or `tuple`:
            A [`CausalLMOutput`] (if `return_dict=True`) or a tuple of tensors (if `return_dict=False`)
            comprising the following elements:
            - **loss** (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
                CTC loss.
            - **logits** (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
                Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
            - **hidden_states** (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
                Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
                of shape `(batch_size, sequence_length, hidden_size)`.
            - **attentions** (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
                Tuple of `torch.FloatTensor` (one for each layer) of shape
                `(batch_size, num_heads, sequence_length, sequence_length)`.

        Raises:
            ValueError: if a label is `>= config.vocab_size`, or if `config.use_weighted_sum` is set while
                hidden states are not requested.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        if labels is not None and labels.max() >= self.config.vocab_size:
            raise ValueError(f"Label values must be < vocab_size: {self.config.vocab_size}")
        if self.config.use_weighted_sum and not output_hidden_states:
            raise ValueError("output_hidden_states must be True when using use_weighted_sum")

        # Always ask the encoder for a ModelOutput. The tuple form has no `output_lengths`, and attribute
        # access (`last_hidden_state`, `hidden_states`) would fail on it. The caller's `return_dict`
        # choice is honoured when building this method's own return value below.
        outputs = self.bestrq_encoder(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        if self.config.use_weighted_sum:
            # Skip the first hidden state as that is collected before the first encoder layer forward
            hidden_states = outputs.hidden_states[1 : self.config.num_hidden_layers + 1]
            hidden_states = self._weighted_sum(hidden_states)
        else:
            hidden_states = outputs.last_hidden_state
        hidden_states = self.dropout(hidden_states)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # assuming that padded tokens are filled with -100
            # when not being attended to
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)
            # ctc_loss doesn't support fp16
            log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
            with torch.backends.cudnn.flags(enabled=False):
                loss = nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    outputs.output_lengths,  # lengths after initial CNN downsampling
                    target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        if not return_dict:
            # Tuple layout: (loss?, logits, hidden_states?, attentions?). Entries that were not requested
            # (None) are dropped, following the usual transformers tuple convention.
            extras = tuple(v for v in (outputs.hidden_states, outputs.attentions) if v is not None)
            output = (logits,) + extras
            return ((loss,) + output) if loss is not None else output
        return CausalLMOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )

    def _weighted_sum(self, embeddings, normalize=False):
        """
        Combine per-layer hidden states using softmax-normalised learnable weights (`self.weights`).

        Args:
            embeddings (`tuple` or `list` of `torch.FloatTensor`):
                Exactly `config.num_hidden_layers` tensors, all of the same shape.
            normalize (`bool`, *optional*, defaults to `False`):
                If `True`, layer-normalise each hidden state over its last dimension before weighting.

        Returns:
            `torch.FloatTensor`: the weighted sum, with the same shape as each input tensor.

        Raises:
            TypeError: if `embeddings` is not a list or tuple.
            ValueError: if the number of tensors differs from `config.num_hidden_layers`.
        """
        if not isinstance(embeddings, (list, tuple)):
            raise TypeError(f"embeddings must be a list or tuple of tensors, got {type(embeddings)}")
        if len(embeddings) != self.config.num_hidden_layers:
            raise ValueError(
                f"Number of embeddings: {len(embeddings)} does not match number of layers: {self.config.num_hidden_layers}"
            )
        stacked_hs = torch.stack(embeddings, dim=0)  # [num_layers, *per_layer_shape]
        if normalize:
            stacked_hs = nn.functional.layer_norm(stacked_hs, (stacked_hs.shape[-1],))
        norm_weights = self.softmax(self.weights)
        # Reshape to [num_layers, 1, ..., 1] so the weights broadcast over every non-layer dimension.
        norm_weights = norm_weights.view(-1, *([1] * (stacked_hs.dim() - 1)))
        return (norm_weights * stacked_hs).sum(dim=0)
class LSTMCTCHead(nn.Module):
    """
    A CTC head that includes LSTM layers before the final projection. This is used for the
    `MeralionBestRqModelForLSTMCTC` model.

    A bidirectional, batch-first LSTM runs over the encoder output. A linear layer then maps the
    concatenated forward/backward LSTM states (`2 * config.lstm_dim` features) to per-frame vocabulary logits.

    Args:
        config: Model configuration. The following attributes are read:
            - `hidden_size`, or `output_hidden_size` when `add_adapter` is truthy
            - `lstm_dim`
            - `lstm_num_layers`
            - `lstm_dropout_prob`
            - `vocab_size`
    """

    def __init__(self, config):
        super().__init__()
        output_hidden_size = (
            config.output_hidden_size if getattr(config, "add_adapter", False) else config.hidden_size
        )
        self.lstm = nn.LSTM(
            output_hidden_size,
            config.lstm_dim,
            num_layers=config.lstm_num_layers,
            dropout=config.lstm_dropout_prob,
            batch_first=True,
            bidirectional=True,
        )
        # Bidirectional LSTM output concatenates both directions, hence `lstm_dim * 2`.
        self.lm_head = nn.Linear(config.lstm_dim * 2, config.vocab_size)

    def forward(self, hidden_states):
        """
        Project encoder features to CTC logits.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, input_size)`):
                Encoder output (batch-first, matching the LSTM's `batch_first=True`).

        Returns:
            `torch.FloatTensor` of shape `(batch_size, sequence_length, vocab_size)`: unnormalised logits.
        """
        lstm_out, _ = self.lstm(hidden_states)
        return self.lm_head(lstm_out)
class MeralionBestRqModelForLSTMCTC(PreTrainedModel):
    """
    BEST-RQ Conformer model with an LSTM-CTC head on top for Connectionist Temporal Classification. This model can
    also use a weighted sum of the encoder's hidden states.

    This model inherits from [`PreTrainedModel`]. For the available methods and functionalities, see the
    documentation in [`PreTrainedModel`].

    Args:
        config ([`MeralionBestRqConformerConfig`]):
            Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
        target_lang (`str`, *optional*):
            Stored as `self.target_lang`; not otherwise used by this class.
    """

    # Adapted from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC
    config_class = MeralionBestRqConformerConfig
    base_model_prefix = "bestrq_encoder"
    main_input_name = "input_values"

    def __init__(self, config, target_lang: Optional[str] = None, **kwargs):
        super().__init__(config)
        self.config = config
        self.bestrq_encoder = MeralionBestRqModel(config)
        self.dropout = nn.Dropout(config.final_dropout)
        self.target_lang = target_lang
        if config.vocab_size is None:
            raise ValueError(
                f"You are trying to instantiate {self.__class__} with a configuration that "
                "does not define the vocabulary size of the language model head. Please "
                "instantiate the model as follows: `MeralionBestRqModelForLSTMCTC.from_pretrained(..., vocab_size=vocab_size)`. "
                "or define `vocab_size` of your model's configuration."
            )
        if self.config.use_weighted_sum:
            # One learnable logit per encoder layer; softmax-normalised in `_weighted_sum`.
            self.weights = nn.Parameter(torch.zeros(self.config.num_hidden_layers))
            self.softmax = nn.Softmax(dim=-1)
        self.lstm_ctc_decoder = LSTMCTCHead(config)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = True,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, CausalLMOutput]:
        r"""
        Performs the forward pass of the BEST-RQ Conformer model with an LSTM-CTC head.

        Args:
            input_values (`torch.FloatTensor` of shape `(batch_size, num_features, sequence_length)`):
                Float values of mel features extracted from the raw speech signal.
            attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Mask to avoid performing attention on padding token indices. Mask values are in `[0, 1]`, where 1 for
                tokens that are not masked and 0 for tokens that are masked.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*, defaults to `True`):
                Whether or not to return the hidden states of all layers. Falls back to
                `config.output_hidden_states` when explicitly passed `None`. Must resolve to `True` when
                `config.use_weighted_sum` is set.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
                Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal
                to the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
                All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
                config.vocab_size - 1]`.

        Returns:
            [`CausalLMOutput`] or `tuple`:
            A [`CausalLMOutput`] (if `return_dict=True`) or a tuple of tensors (if `return_dict=False`)
            comprising the following elements:
            - **loss** (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
                CTC loss.
            - **logits** (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
                Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
            - **hidden_states** (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
                Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
                of shape `(batch_size, sequence_length, hidden_size)`.
            - **attentions** (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True`):
                Tuple of `torch.FloatTensor` (one for each layer) of shape
                `(batch_size, num_heads, sequence_length, sequence_length)`.

        Raises:
            ValueError: if a label is `>= config.vocab_size`, or if `config.use_weighted_sum` is set while
                hidden states are not requested.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        if labels is not None and labels.max() >= self.config.vocab_size:
            raise ValueError(f"Label values must be < vocab_size: {self.config.vocab_size}")
        if self.config.use_weighted_sum and not output_hidden_states:
            raise ValueError("output_hidden_states must be True when using use_weighted_sum")

        # Always ask the encoder for a ModelOutput. The tuple form has no `output_lengths`, and attribute
        # access (`last_hidden_state`, `hidden_states`) would fail on it. The caller's `return_dict`
        # choice is honoured when building this method's own return value below.
        outputs = self.bestrq_encoder(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            ctc_decoder=self.lstm_ctc_decoder,
        )

        if self.config.use_weighted_sum:
            # Skip the first hidden state as that is collected before the first encoder layer forward
            hidden_states = outputs.hidden_states[1 : self.config.num_hidden_layers + 1]
            hidden_states = self._weighted_sum(hidden_states)
        else:
            hidden_states = outputs.last_hidden_state
        hidden_states = self.dropout(hidden_states)
        logits = self.lstm_ctc_decoder(hidden_states)

        loss = None
        if labels is not None:
            # assuming that padded tokens are filled with -100
            # when not being attended to
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)
            # ctc_loss doesn't support fp16
            log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
            with torch.backends.cudnn.flags(enabled=False):
                loss = nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    outputs.output_lengths,  # lengths after initial CNN downsampling
                    target_lengths,
                    # NOTE(review): blank index 0 is used here, whereas MeralionBestRqModelForCTC uses
                    # config.pad_token_id. Confirm which one LSTM-CTC checkpoints were trained with.
                    blank=0,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        if not return_dict:
            # Tuple layout: (loss?, logits, hidden_states?, attentions?). Entries that were not requested
            # (None) are dropped, following the usual transformers tuple convention.
            extras = tuple(v for v in (outputs.hidden_states, outputs.attentions) if v is not None)
            output = (logits,) + extras
            return ((loss,) + output) if loss is not None else output
        return CausalLMOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )

    def _weighted_sum(self, embeddings, normalize=False):
        """
        Combine per-layer hidden states using softmax-normalised learnable weights (`self.weights`).

        Args:
            embeddings (`tuple` or `list` of `torch.FloatTensor`):
                Exactly `config.num_hidden_layers` tensors, all of the same shape.
            normalize (`bool`, *optional*, defaults to `False`):
                If `True`, layer-normalise each hidden state over its last dimension before weighting.

        Returns:
            `torch.FloatTensor`: the weighted sum, with the same shape as each input tensor.

        Raises:
            TypeError: if `embeddings` is not a list or tuple.
            ValueError: if the number of tensors differs from `config.num_hidden_layers`.
        """
        if not isinstance(embeddings, (list, tuple)):
            raise TypeError(f"embeddings must be a list or tuple of tensors, got {type(embeddings)}")
        if len(embeddings) != self.config.num_hidden_layers:
            raise ValueError(
                f"Number of embeddings: {len(embeddings)} does not match number of layers: {self.config.num_hidden_layers}"
            )
        stacked_hs = torch.stack(embeddings, dim=0)  # [num_layers, *per_layer_shape]
        if normalize:
            stacked_hs = nn.functional.layer_norm(stacked_hs, (stacked_hs.shape[-1],))
        norm_weights = self.softmax(self.weights)
        # Reshape to [num_layers, 1, ..., 1] so the weights broadcast over every non-layer dimension.
        norm_weights = norm_weights.view(-1, *([1] * (stacked_hs.dim() - 1)))
        return (norm_weights * stacked_hs).sum(dim=0)