import numpy as np
import torch
import transformers

# Position, in the tuple returned by ``self.wavlm(...)`` when ``return_dict`` is
# False, from which the optional hidden states / attentions start.
_HIDDEN_STATES_START_POSITION = 2


def add_language_to_hidden(hidden_state: torch.Tensor, language: torch.Tensor) -> torch.Tensor:
    """
    Append the language identifier as one extra feature to every frame.

    :param torch.Tensor hidden_state: Encoder output, indexed as
        (batch, frames, features).
    :param torch.Tensor language: Language identifier(s): one value per batch
        item (shape ``(batch,)`` or ``(batch, 1)``), or a single value (tensor or
        Python number) shared by the whole batch.
    :return torch.Tensor: ``hidden_state`` with the identifier concatenated on
        the last axis, i.e. shape (batch, frames, features + 1).
    :raises ValueError: If ``language`` holds neither 1 nor ``batch`` values.
    """
    batch_size, n_frames = hidden_state.shape[0], hidden_state.shape[1]
    # Match dtype/device of the hidden states so the concatenation (and the
    # following linear layer) does not get promoted or fail on device mismatch.
    lang = torch.as_tensor(language, dtype=hidden_state.dtype, device=hidden_state.device)
    # Flatten to (n, 1, 1): works for (batch,), (batch, 1) and scalar inputs alike.
    # (``repeat((1, frames))`` on a 1-D tensor would produce (1, batch * frames).)
    lang = lang.reshape(-1, 1, 1)
    if lang.shape[0] not in (1, batch_size):
        raise ValueError(
            f"language must hold 1 or {batch_size} values, got {lang.shape[0]}"
        )
    lang = lang.expand(batch_size, n_frames, 1)
    return torch.cat([hidden_state, lang], dim=2)


def language_classifier(language: str) -> float:
    """
    Map a two-letter language code to a float in [0, 1].

    "fr" maps to 0.0 and "it" to 1.0. Any other lowercase two-letter code maps
    to the mean of its letters' normalised alphabet positions ("a" -> 0,
    "z" -> 1), so "aa" -> 0.0 and "zz" -> 1.0.

    NOTE(review): this encoding is neither unique nor lexicographic: "ab" and
    "ba" both give 0.02, and "aa"/"zz" collide with "fr"/"it". It is kept as-is
    because trained models depend on these exact values.

    :param str language: Language to identify, expected to be two lowercase
        letters (other inputs produce values outside [0, 1]).
    :return float: Identifier between 0 and 1.
    """
    if language == "fr":
        return 0.0
    if language == "it":
        return 1.0
    span = ord("z") - ord("a")
    # Each letter is normalised to [0, 1] and offset by its position, so
    # "aa" sums to 0 + 1 = 1 and "zz" to 1 + 2 = 3.
    total = sum((ord(letter) - ord("a")) / span + i for i, letter in enumerate(language))
    # Rescale the two-letter range [1, 3] to [0, 1].
    return (total - 1) / 2


# Backward-compatible alias for the original (misspelled) name.
language_classifer = language_classifier


class WavLMPhonemeFrIt(transformers.WavLMForCTC):
    """
    Phoneme recognizer: WavLM encoder + language-conditioned linear CTC head.

    The CTC head receives one extra input feature, a language identifier, so a
    single model can separate French and Italian.

    For a more professional implementation, view
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/wav2vec2/modeling_wav2vec2.py
    """

    def __init__(self, config, tokenizer=None):
        """
        Build the WavLM CTC model and replace its head with a language-aware one.

        :param config: Model config.
        :param tokenizer: Optional tokenizer, stored as ``self.tokenizer``
            (not used by the methods of this class).
        """
        super().__init__(config)
        output_hidden_size = (
            config.output_hidden_size
            if getattr(config, "add_adapter", False)
            else config.hidden_size
        )
        # Replace the head: one extra input feature for the language identifier.
        self.lm_head = torch.nn.Linear(output_hidden_size + 1, config.vocab_size)
        self.tokenizer = tokenizer

    def forward(
        self,
        input_values: torch.Tensor,
        attention_mask: torch.Tensor = None,
        language: torch.Tensor = None,
        output_attentions: bool = None,
        output_hidden_states: bool = None,
        return_dict: bool = None,
        labels: torch.Tensor = None,
    ):
        """
        Classify audio to a chain of phonemes, one prediction per encoder frame.

        Adapted from
        https://github.com/huggingface/transformers/blob/6ba8a1ff4550b4450a22a0b0d907312955ce0fd5/src/transformers/models/wavlm/modeling_wavlm.py#L1196

        :param input_values: Raw audio batch fed to the WavLM encoder.
        :param attention_mask: Optional mask over ``input_values``; also used to
            derive CTC input lengths.
        :param language: Language identifier(s) (see ``add_language_to_hidden``
            for accepted shapes). Required despite the ``None`` default.
        :param output_attentions: Forwarded to the encoder.
        :param output_hidden_states: Forwarded to the encoder.
        :param return_dict: Return a ``CausalLMOutput`` instead of a tuple;
            defaults to ``config.use_return_dict``.
        :param labels: Optional target ids; entries < 0 are treated as padding.
            When given, a CTC loss is computed.
        :return: ``CausalLMOutput`` or a tuple ``([loss,] logits, ...)``.
        :raises ValueError: If ``language`` is missing or a label is out of range.
        """
        if language is None:
            # Fail fast instead of crashing inside the head after the encoder ran.
            raise ValueError("`language` is required: the CTC head expects a language identifier.")
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None and labels.max() >= self.config.vocab_size:
            raise ValueError(f"Label values must be < vocab_size: {self.config.vocab_size}")

        outputs = self.wavlm(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states)
        hidden_with_lang = add_language_to_hidden(hidden_states, language)

        logits = self.lm_head(hidden_with_lang)

        loss = None
        if labels is not None:
            # retrieve loss input_lengths from attention_mask
            attention_mask = (
                attention_mask
                if attention_mask is not None
                else torch.ones_like(input_values, dtype=torch.long)
            )
            input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)

            # assuming that padded tokens are filled with -100
            # when not being attended to
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # ctc_loss doesn't support fp16; it also expects (time, batch, classes)
            log_probs = torch.nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

            with torch.backends.cudnn.flags(enabled=False):
                loss = torch.nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return transformers.modeling_outputs.CausalLMOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )

    def freeze_feature_encoder_only(self):
        """
        Make every parameter of ``self.wavlm`` trainable, then freeze only the
        feature encoder (via ``freeze_feature_encoder``).

        ``self.lm_head`` is outside ``self.wavlm`` and is left untouched.
        """
        for param in self.wavlm.parameters():
            param.requires_grad = True
        self.freeze_feature_encoder()


def freeze_layer(layer, freeze=True):
    """
    Freeze (or unfreeze) every parameter of a module.

    :param torch.nn.Module layer: Module to (un)freeze.
    :param bool freeze: True to freeze, False to make the module trainable.
    """
    trainable = not freeze
    for param in layer.parameters():
        param.requires_grad = trainable
    # NOTE(review): presumably mirrors the private ``_requires_grad`` flag that
    # some transformers modules (e.g. feature encoders) check in forward — confirm.
    layer._requires_grad = trainable


def get_wavlm_phoneme_fr_it(tokenizer, freeze_hidden_layers=False):
    """
    Load ``microsoft/wavlm-base-plus`` with the language-aware CTC head.

    :param tokenizer: Tokenizer providing ``pad_token_id`` (used as CTC blank)
        and ``len()`` (used as vocabulary size).
    :param bool freeze_hidden_layers: If True, call ``model.freeze_base_model()``.
    :return WavLMPhonemeFrIt: The model, with ``tokenizer`` attached.
    """
    model = WavLMPhonemeFrIt.from_pretrained(
        "microsoft/wavlm-base-plus",
        ctc_loss_reduction="mean",
        pad_token_id=tokenizer.pad_token_id,
        vocab_size=len(tokenizer),
    )
    model.tokenizer = tokenizer
    if freeze_hidden_layers:
        model.freeze_base_model()
    return model