Hugo Farajallah
feat(animation): now the animation clearly features the temporal dependencies.
fb785a4 | import numpy as np | |
| import torch | |
| import transformers | |
# Offset used by WavLMPhonemeFrIt.forward when it returns a plain tuple: the entries
# of the base model output from this index onwards are appended after the logits.
_HIDDEN_STATES_START_POSITION = 2
def add_language_to_hidden(hidden_state: torch.Tensor, language: torch.Tensor):
    """
    Append the language identifier to every frame as one extra feature.

    :param torch.Tensor hidden_state: Features of shape (batch, frames, features).
    :param torch.Tensor language: Language identifier of each batch item, given as a
        scalar, a (batch,) vector or a (batch, 1) column. A single value is shared by
        the whole batch.
    :return torch.Tensor: Tensor of shape (batch, frames, features + 1) whose last
        feature holds the language identifier.
    """
    batch_size, n_frames = hidden_state.shape[:2]
    # Align dtype and device with the features so the concatenation neither fails on a
    # device mismatch nor silently promotes half-precision features.
    language = torch.as_tensor(language, dtype=hidden_state.dtype, device=hidden_state.device)
    # One value per batch item, broadcast over the frames without copying.
    language_feature = language.reshape(-1, 1, 1).expand(batch_size, n_frames, 1)
    return torch.cat([hidden_state, language_feature], dim=2)
def language_classifer(language):
    """
    Return a float identifying each known language.

    "fr" has a value of 0 and "it" a value of 1.
    Every other code is read as a base-27 fraction over its letters ("a" is digit 1,
    "z" is digit 26). This gives each code its own value strictly between 0 and 1,
    increasing in lexicographic order, so no code can collide with "fr" or "it".

    :param str language: Language to identify, should be two letters. Case is ignored.
    :return float: Unique identifier, between 0 and 1.
    :raises TypeError: If ``language`` is not a string.
    :raises ValueError: If the code is empty or contains anything but the letters a-z.
    """
    if not isinstance(language, str):
        raise TypeError(f"language must be a str, got {type(language).__name__}")
    code = language.lower()
    if code == "fr":
        return 0.0
    if code == "it":
        return 1.0
    if not code or any(not "a" <= letter <= "z" for letter in code):
        raise ValueError(f"Unsupported language code: {language!r}")
    # Digits run from 1 to 26, hence base 27: digit 0 is never used, which keeps a code
    # distinct from its own prefixes and the result strictly above 0.
    base = ord("z") - ord("a") + 2
    return sum(
        (ord(letter) - ord("a") + 1) / base ** position
        for position, letter in enumerate(code, start=1)
    )
class WavLMPhonemeFrIt(transformers.WavLMForCTC):
    """
    PhonemeRecognizer: WavLM + Linear layer for speech recognition.

    It natively separates French and Italian: a language identifier is appended as one
    extra feature to the encoder output before the final linear layer.

    For a more professional implementation, view
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/wav2vec2/modeling_wav2vec2.py
    """

    def __init__(self, config, tokenizer=None):
        """
        Create the model and replace its head by one accepting the language feature.

        :param config: Model config.
        :param tokenizer: Optional tokenizer, only stored on the model as ``tokenizer``.
        """
        super().__init__(config)
        output_hidden_size = (
            config.output_hidden_size if getattr(config, "add_adapter", False) else config.hidden_size
        )

        # Replace head and add multilingualism: one extra input for the language feature.
        self.lm_head = torch.nn.Linear(output_hidden_size + 1, config.vocab_size)
        self.tokenizer = tokenizer

    def forward(
        self,
        input_values: torch.Tensor,
        attention_mask: torch.Tensor = None,
        language: torch.Tensor = None,
        output_attentions: bool = None,
        output_hidden_states: bool = None,
        return_dict: bool = None,
        labels: torch.Tensor = None,
    ):
        """
        Compute phoneme logits for the audio and, when labels are given, the CTC loss.

        Stolen from
        https://github.com/huggingface/transformers/blob/6ba8a1ff4550b4450a22a0b0d907312955ce0fd5/src/transformers/models/wavlm/modeling_wavlm.py#L1196

        :param torch.Tensor input_values: Audio input, passed unchanged to the WavLM base model.
        :param torch.Tensor attention_mask: Optional mask forwarded to the base model; when
            labels are given and it is missing, every input value counts as attended.
        :param torch.Tensor language: Language identifier of each batch item, appended to the
            hidden states before the head. Required despite its ``None`` default.
        :param bool output_attentions: Forwarded to the base model.
        :param bool output_hidden_states: Forwarded to the base model.
        :param bool return_dict: Whether to return a ``CausalLMOutput`` instead of a tuple;
            defaults to ``config.use_return_dict``.
        :param torch.Tensor labels: Optional CTC targets; negative entries are treated as padding.
        :return: ``CausalLMOutput`` with loss, logits, hidden states and attentions, or
            the tuple ``(loss,) + (logits, ...)`` (without loss when there are no labels).
        :raises ValueError: If ``language`` is missing, or if a label is not below the
            vocabulary size.
        """
        # Fail early with an explicit message rather than deep inside the concatenation.
        if language is None:
            raise ValueError("`language` is required: pass one language identifier per batch item.")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None and labels.max() >= self.config.vocab_size:
            raise ValueError(f"Label values must be < vocab_size: {self.config.vocab_size}")

        outputs = self.wavlm(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states)

        hidden_with_lang = add_language_to_hidden(hidden_states, language)
        logits = self.lm_head(hidden_with_lang)

        loss = None
        if labels is not None:
            # retrieve loss input_lengths from attention_mask
            attention_mask = (
                attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
            )
            input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)

            # assuming that padded tokens are filled with -100
            # when not being attended to
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # ctc_loss doesn't support fp16
            log_probs = torch.nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

            with torch.backends.cudnn.flags(enabled=False):
                loss = torch.nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return transformers.modeling_outputs.CausalLMOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )

    def freeze_feature_encoder_only(self):
        """
        Make every WavLM parameter trainable, then freeze the feature encoder again.
        """
        # Unfreeze base model
        for param in self.wavlm.parameters():
            param.requires_grad = True
        # Now freeze the first layer
        self.freeze_feature_encoder()
def freeze_layer(layer, freeze=True):
    """
    Freeze or unfreeze every parameter of a layer.

    :param layer: Object exposing ``parameters()``.
    :param bool freeze: True to disable gradients on the parameters, False to enable them.
    """
    trainable = not freeze
    for parameter in layer.parameters():
        parameter.requires_grad = trainable
    # Record the same state on the layer object itself.
    layer._requires_grad = trainable
def get_wavlm_phoneme_fr_it(tokenizer, freeze_hidden_layers=False):
    """
    Build a WavLMPhonemeFrIt from the "microsoft/wavlm-base-plus" checkpoint.

    :param tokenizer: Tokenizer providing ``pad_token_id``; its length is used as the
        vocabulary size. It is also attached to the returned model.
    :param bool freeze_hidden_layers: When true, ``freeze_base_model()`` is called on the model.
    :return: The loaded model.
    """
    config_overrides = {
        "ctc_loss_reduction": "mean",
        "pad_token_id": tokenizer.pad_token_id,
        "vocab_size": len(tokenizer),
    }
    model = WavLMPhonemeFrIt.from_pretrained("microsoft/wavlm-base-plus", **config_overrides)
    model.tokenizer = tokenizer

    if not freeze_hidden_layers:
        return model
    model.freeze_base_model()
    return model