wavlm-phonemizer-word-detection / wavlm_phoneme_fr_it.py
Hugo Farajallah
feat(animation): now the animation clearly features the temporal dependencies.
fb785a4
import numpy as np
import torch
import transformers
# Index into the tuple returned by ``self.wavlm(..., return_dict=False)`` from which
# the optional extra outputs (hidden states / attentions) start. ``forward`` slices the
# encoder outputs from this position when it builds its own tuple output.
# Intended to match the same-named constant in transformers' modeling_wavlm.py —
# re-check it if the upstream output layout changes.
_HIDDEN_STATES_START_POSITION = 2
def add_language_to_hidden(hidden_state: torch.Tensor, language: torch.Tensor):
    """
    Append the language code as one extra feature on every frame.

    :param hidden_state: Encoder output of shape (batch, frames, features).
    :param language: Language code(s): a scalar, a (batch,) vector or a
        (batch, 1) column. A single code is broadcast to the whole batch.
    :return: Tensor of shape (batch, frames, features + 1) whose last feature
        holds the language code of each sequence, in ``hidden_state``'s dtype
        and on its device.
    :raises ValueError: If ``language`` has neither 1 nor ``batch`` elements.
    """
    batch_size, n_frames = hidden_state.shape[:2]
    # Align dtype/device with the hidden states: a float64 (or fp32 vs fp16) code
    # would otherwise promote the concatenation and mismatch the linear head.
    language = torch.as_tensor(language, dtype=hidden_state.dtype, device=hidden_state.device)
    if language.numel() not in (1, batch_size):
        raise ValueError(
            f"Expected 1 or {batch_size} language codes, got {language.numel()}"
        )
    # (batch or 1, 1, 1) -> (batch, frames, 1); expand is a view, no copy.
    language_column = language.reshape(-1, 1, 1).expand(batch_size, n_frames, 1)
    return torch.cat([hidden_state, language_column], dim=2)
def language_classifer(language):
    """
    Return a float identifying each known language.

    "fr" has a value of 0 and "it" a value of 1 (the two languages the model
    natively separates). Any other two-letter code is mapped to a value in
    [0, 1] that strictly increases in lexicographic order, so distinct codes
    get distinct values. "aa" maps to 0 and "zz" to 1, so those two share
    their value with "fr" and "it" respectively.

    :param str language: Two-letter language code, case-insensitive (e.g. "fr").
    :return float: Identifier between 0 and 1.
    :raises ValueError: If ``language`` is not made of exactly two ASCII letters.
    """
    code = language.lower()
    if code == "fr":
        return 0.0
    if code == "it":
        return 1.0
    if len(code) != 2 or not all("a" <= letter <= "z" for letter in code):
        raise ValueError(f"Expected a two-letter language code, got {language!r}")
    # Read the code as a two-digit base-26 number (0 for "aa", 26 * 26 - 1 for
    # "zz"): the first letter dominates, which gives lexicographic order.
    first, second = (ord(letter) - ord("a") for letter in code)
    return (first * 26 + second) / (26 * 26 - 1)
class WavLMPhonemeFrIt(transformers.WavLMForCTC):
    """
    Phoneme recognizer: WavLM encoder + a linear CTC head conditioned on language.

    The head receives the WavLM hidden states concatenated with one extra
    feature holding a scalar language code ("fr" -> 0, "it" -> 1), so that a
    single model natively separates French and Italian.
    For a more complete reference implementation, see
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/wav2vec2/modeling_wav2vec2.py
    """

    def __init__(self, config, tokenizer=None):
        """
        Build the WavLM CTC model and replace its head with a language-aware one.

        :param config: Model config (a ``transformers`` WavLM config).
        :param tokenizer: Optional tokenizer stored as ``self.tokenizer``;
            ``forward`` does not use it.
        """
        super().__init__(config)
        # Same width as the parent's head: the adapter output size when an
        # adapter is configured, the encoder hidden size otherwise.
        output_hidden_size = (
            config.output_hidden_size if getattr(config, "add_adapter", False) else config.hidden_size
        )
        # Replace the head: one extra input feature carries the language code.
        self.lm_head = torch.nn.Linear(output_hidden_size + 1, config.vocab_size)
        self.tokenizer = tokenizer

    def forward(
        self,
        input_values: torch.Tensor,
        attention_mask: torch.Tensor = None,
        language: torch.Tensor = None,
        output_attentions: bool = None,
        output_hidden_states: bool = None,
        return_dict: bool = None,
        labels: torch.Tensor = None,
    ):
        """
        Classify audio to a chain of phonemes of the same length, conditioned on language.

        Adapted from
        https://github.com/huggingface/transformers/blob/6ba8a1ff4550b4450a22a0b0d907312955ce0fd5/src/transformers/models/wavlm/modeling_wavlm.py#L1196

        :param input_values: Audio batch passed to the WavLM encoder.
        :param attention_mask: Optional mask over ``input_values``; also used to
            compute the CTC input lengths when ``labels`` is given.
        :param language: Language code(s) appended to every frame's hidden
            state before the head. Required despite the ``None`` default.
        :param output_attentions: Forwarded to the encoder.
        :param output_hidden_states: Forwarded to the encoder.
        :param return_dict: Return a ``CausalLMOutput`` instead of a tuple;
            defaults to ``config.use_return_dict``.
        :param labels: Optional target ids; positions with a negative value
            (e.g. -100) are treated as padding. When given, a CTC loss is computed.
        :return: ``CausalLMOutput``, or the tuple ``(loss, logits, *extras)``
            (``loss`` omitted when no labels are given).
        :raises ValueError: If ``language`` is missing or a label id is not
            smaller than ``config.vocab_size``.
        """
        if language is None:
            raise ValueError("`language` is required: pass the language code(s) of the batch.")
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None and labels.max() >= self.config.vocab_size:
            raise ValueError(f"Label values must be < vocab_size: {self.config.vocab_size}")

        outputs = self.wavlm(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = self.dropout(outputs[0])
        # Match the head's input dtype/device: a float64 or CPU language tensor
        # would otherwise promote the concatenation or fail the device check.
        language = torch.as_tensor(language, dtype=hidden_states.dtype, device=hidden_states.device)
        hidden_with_lang = add_language_to_hidden(hidden_states, language)
        logits = self.lm_head(hidden_with_lang)

        loss = None
        if labels is not None:
            # Retrieve the CTC input lengths from the attention mask; without a
            # mask every sample is treated as full length.
            attention_mask = (
                attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
            )
            input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
            # Padded label positions are assumed to hold a negative value (e.g. -100).
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)
            # ctc_loss doesn't support fp16: compute log-probs in float32, time-major.
            log_probs = torch.nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
            # As upstream, bypass cuDNN for the CTC computation.
            with torch.backends.cudnn.flags(enabled=False):
                loss = torch.nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output
        return transformers.modeling_outputs.CausalLMOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )

    def freeze_feature_encoder_only(self):
        """
        Make the whole WavLM base model trainable except its feature encoder.

        Useful to undo ``freeze_base_model`` while keeping the feature encoder frozen.
        The head (``lm_head``) is not touched.
        """
        # Unfreeze every parameter of the base model...
        for param in self.wavlm.parameters():
            param.requires_grad = True
        # ...then re-freeze the feature encoder only.
        self.freeze_feature_encoder()
def freeze_layer(layer, freeze=True):
    """
    Toggle whether the parameters of ``layer`` receive gradient updates.

    :param layer: Module whose parameters are (un)frozen.
    :param bool freeze: True to stop gradient updates, False to re-enable them.
    """
    trainable = not freeze
    for parameter in layer.parameters():
        parameter.requires_grad = trainable
    # Also record the state in the private ``_requires_grad`` flag; presumably
    # mirrors what transformers' feature encoders set when frozen — verify the
    # flag is meaningful for the layer types this is called on.
    layer._requires_grad = trainable
def get_wavlm_phoneme_fr_it(
    tokenizer,
    freeze_hidden_layers=False,
    pretrained_model="microsoft/wavlm-base-plus",
):
    """
    Load a pretrained WavLM checkpoint as a French/Italian phoneme recognizer.

    The CTC head is sized from the tokenizer: its vocabulary size sets the
    number of outputs and its pad token id is used as the CTC blank.

    :param tokenizer: Tokenizer providing ``pad_token_id`` and ``len()``; it is
        also stored on the model as ``model.tokenizer``.
    :param bool freeze_hidden_layers: If True, freeze the whole WavLM base model
        (``freeze_base_model``) so only the head is trained.
    :param str pretrained_model: Hugging Face hub id or local path of the WavLM
        checkpoint to start from.
    :return WavLMPhonemeFrIt: The model.
    """
    model = WavLMPhonemeFrIt.from_pretrained(
        pretrained_model,
        ctc_loss_reduction="mean",
        pad_token_id=tokenizer.pad_token_id,
        vocab_size=len(tokenizer),
    )
    model.tokenizer = tokenizer
    if freeze_hidden_layers:
        model.freeze_base_model()
    return model