Hugo Farajallah commited on
Commit
1b910a7
·
0 Parent(s):

feat(animation): initial animation of the logits with WavLM.

Browse files
.gitattributes ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ *.png filter=lfs diff=lfs merge=lfs -text
2
+ *.gif filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python-generated files
2
+ __pycache__/
3
+ *.py[oc]
4
+ build/
5
+ dist/
6
+ wheels/
7
+ *.egg-info
8
+
9
+ # Virtual environments
10
+ .venv
11
+
12
+ # Ignored generated figures
13
+ figures/*.png
.idea/.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.12
README.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # WavLM Demo
2
+
3
+ A simple utility script that shows how WavLM works and how to use it.
4
+ It is all based on WavLM Base + Phonemizer FR-IT
5
+
6
+ ## Idea
7
+
8
+ - [x] Show activation logits of WavLM (fake model for now)
9
+ - [ ] Compare performances with Wav2Vec 2.0-Phonemizer-FR
10
+ - [x] Animate activation logits over time.
11
+ - [ ] Show the result from the feature encoder.
ceci est un test.wav ADDED
Binary file (33.9 kB). View file
 
figures/.gitkeep ADDED
File without changes
main.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+
3
+ import matplotlib.animation
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import transformers
7
+ # import wavlm_phoneme_fr_it
8
+
9
# Number of audio samples per unit of ``audio_length`` in main()
# (presumably samples per second, i.e. 16 kHz -- TODO confirm).
SAMPLING_RATE = 16_000
# Width (second dimension) of the random logits returned by fake_model().
VOCAB_SIZE = 100
11
+
12
+
13
def get_model():
    """
    Load the pretrained phoneme recognizer and its processor.

    The ``wavlm_phoneme_fr_it`` module is imported here rather than at the top
    of the file: the module-level import is commented out, so without a local
    import this function would fail with a ``NameError``.

    :return tuple: ``(model, processor)`` loaded from the same checkpoint.
    """
    # Local import: keeps the module optional for code paths that only use
    # the fake model, while making this function actually callable.
    import wavlm_phoneme_fr_it

    checkpoint = "hugofara/wavlm-base-plus-phonemizer-fr-it"
    processor = transformers.AutoProcessor.from_pretrained(
        checkpoint, unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|"
    )

    model = wavlm_phoneme_fr_it.WavLMPhonemeFrIt.from_pretrained(
        checkpoint
    )
    return model, processor
23
+
24
+
25
def fake_model(chunk):
    """
    Stand-in for the real model: produce random logits for an audio chunk.

    :param chunk: Array whose first dimension is the number of audio samples.
    :return: Random array of shape ``(int(0.02 * n_samples), VOCAB_SIZE)``
        with values drawn by ``np.random.rand``.
    """
    n_frames = int(0.02 * chunk.shape[0])
    return np.random.rand(n_frames, VOCAB_SIZE)
28
+
29
+
30
def update_frame(data, ax):
    """
    Draw one animation frame on the given axes.

    The axes are cleared before drawing so that images from previous frames
    do not accumulate on the axes as the animation progresses.

    :param data: 2D array-like to display with ``matshow``.
    :param ax: Axes to draw on.
    :return tuple: One-element tuple containing ``ax``.
    """
    # Remove the previous frame's image before adding the new one.
    ax.clear()
    ax.matshow(data)
    return ax,
33
+
34
+
35
def main():
    """
    Generate fake logits for growing prefixes of a random signal and animate them.

    Steps:

    1. Build a random signal of ``audio_length * SAMPLING_RATE`` samples.
    2. Cut it into prefixes of increasing length (each chunk starts at 0).
    3. Run ``fake_model`` on every prefix and store the result in a
       zero-padded array so that all frames share the same shape.
    4. Save one PNG per prefix in ``figures/`` and an animation in
       ``animated.webm``.

    The animation object is stored in the module-level ``animation`` name.
    """
    global animation

    # model, processor = get_model()
    audio_length = 5
    split_length = 0.1
    audio_file = np.random.rand(audio_length * SAMPLING_RATE)
    # TODO: normalize audio

    # Split audio: every chunk is a prefix ending at one of the boundaries.
    boundaries = np.linspace(
        0, audio_file.shape[0], int(audio_length / split_length), dtype=np.uint64
    )
    # A boundary of 0 would give an empty prefix, skip it.
    chunks = [audio_file[:end] for end in boundaries if end != 0]

    # Inference time
    # All frames are padded to the size of the logits of the longest chunk.
    max_frames = int(chunks[-1].shape[0] * 0.02)
    logit_groups = [np.zeros((max_frames, VOCAB_SIZE)) for _ in chunks]
    fig, ax = plt.subplots(1, 1)
    for i, chunk in enumerate(chunks):
        logits = fake_model(chunk)
        logit_groups[i][:logits.shape[0]] = logits

        # Draw the current logits before saving, otherwise the saved
        # figure only contains empty axes.
        ax.clear()
        ax.matshow(logit_groups[i])
        fig.savefig(f"figures/test{i}.png")

    # Animate
    animation = matplotlib.animation.FuncAnimation(
        fig,
        functools.partial(update_frame, ax=ax),
        logit_groups,
        # blit=True
    )
    animation.save("animated.webm")
69
+
70
+
71
if __name__ == "__main__":
    # Module-level name that main() rebinds through ``global animation``;
    # presumably kept so the animation object outlives main() -- TODO confirm.
    animation = None
    main()
pyproject.toml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "wavlm-demo"
3
+ version = "0.1.0"
4
+ description = "Demonstration project for WavLM"
5
+ readme = "README.md"
6
+ requires-python = ">=3.12"
7
+ dependencies = [
8
+ "matplotlib>=3.10.6",
9
+ "numpy>=2.3.3",
10
+ "pyqt6>=6.9.1",
11
+ "transformers>=4.56.1",
12
+ ]
uv.lock ADDED
The diff for this file is too large to render. See raw diff
 
wavlm_phoneme_fr_it.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import transformers
4
+
5
+
6
# Index from which the base model outputs are appended after the logits
# when WavLMPhonemeFrIt.forward() returns a tuple instead of a dict.
_HIDDEN_STATES_START_POSITION = 2
7
+
8
+
9
def add_language_to_hidden(input_values, language):
    """
    Append a language channel to the last dimension of a 3D tensor.

    The result has the same first two dimensions as ``input_values`` and one
    extra feature at the end, filled with the language value.

    :param torch.Tensor input_values: 3D tensor to extend.
    :param language: Language value(s) to write in the new channel:

        - ``None``: the channel is filled with zeros;
        - ``list`` / ``tuple``: one value per element of the first dimension,
          repeated along the second dimension;
        - ``np.ndarray``: converted to a tensor, then broadcast;
        - anything else (number, ``torch.Tensor``): assigned as is, using
          the usual broadcasting rules against the first two dimensions.
    :return torch.Tensor: New tensor of shape
        ``(input_values.shape[0], input_values.shape[1], input_values.shape[2] + 1)``.
    :raises TypeError: If ``language`` is a string.
    """
    if isinstance(language, str):
        raise TypeError("Language should be None, list of torch.Tensor, not str")

    if language is None:
        lang_val = 0
    elif isinstance(language, (list, tuple)):
        # One value per sample: shape it as a column so that it broadcasts
        # along the second dimension.
        lang_val = torch.stack([
            torch.as_tensor(lang, dtype=input_values.dtype, device=input_values.device)
            for lang in language
        ]).reshape(-1, 1)
    elif isinstance(language, np.ndarray):
        lang_val = torch.as_tensor(
            language, dtype=input_values.dtype, device=input_values.device
        )
    else:
        lang_val = language

    input_batch = torch.empty(
        (input_values.shape[0], input_values.shape[1], input_values.shape[2] + 1),
        dtype=input_values.dtype,
        device=input_values.device
    )

    input_batch[:, :, :-1] = input_values
    input_batch[:, :, -1] = lang_val

    return input_batch
37
+
38
+
39
def language_classifer(language):
    """
    Map a language code to a number.

    "fr" maps to 0 and "it" maps to 1. For any other code, every letter
    contributes its position in the alphabet scaled to [0, 1] ("a" is 0,
    "z" is 1) plus its index in the code; the sum is then shifted and halved
    so that a two-letter lowercase code ends up in [0, 1].

    Note that different codes can share a value (for instance "aa" gives 0
    like "fr", and "zz" gives 1 like "it").

    :param str language: Language to identify, should be two letters.
    :return: 0 for "fr", 1 for "it", otherwise a float.
    """
    for known_code, value in (("fr", 0), ("it", 1)):
        if language == known_code:
            return value

    alphabet_span = ord("z") - ord("a")
    # For a two-letter code the total goes from 1 ("aa") to 3 ("zz")
    total = 0
    for position, letter in enumerate(language):
        total += (ord(letter) - ord("a")) / alphabet_span + position
    # Transform to [0, 1]
    return (total - 1) / 2
62
+
63
+
64
class WavLMPhonemeFrIt(transformers.WavLMForCTC):
    """
    PhonemeRecognizer: WavLM + Linear layer for speech recognition.

    It natively separates French and Italian: a language value is appended
    to every hidden state before the final linear layer.

    For a more professional implementation, view
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/wav2vec2/modeling_wav2vec2.py
    """

    def __init__(self, config, tokenizer=None):
        """
        Create the model and replace its head by a language-aware one.

        :param config: Model config.
        :param tokenizer: Optional tokenizer, stored on the model as is.
        """
        super().__init__(config)
        output_hidden_size = (
            config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size
        )
        # Replace head and add multilingualism: one extra input feature
        # carries the language value.
        self.lm_head = torch.nn.Linear(output_hidden_size + 1, config.vocab_size)
        self.tokenizer = tokenizer

    def forward(
            self,
            input_values: torch.Tensor,
            attention_mask: torch.Tensor = None,
            language: torch.Tensor = None,
            output_attentions: bool = None,
            output_hidden_states: bool = None,
            return_dict: bool = None,
            labels: torch.Tensor = None,
    ):
        """
        Classify audio to a chain of phonemes of the same length.

        Stolen from
        https://github.com/huggingface/transformers/blob/6ba8a1ff4550b4450a22a0b0d907312955ce0fd5/src/transformers/models/wavlm/modeling_wavlm.py#L1196

        :param input_values: Audio input given to the base WavLM model.
        :param attention_mask: Optional mask, also used to compute the input
            lengths of the CTC loss.
        :param language: Language value appended to every hidden state.
            Either a single value (shared by the whole batch), one value per
            batch element, or ``None`` (0 is used).
        :param output_attentions: Forwarded to the base model.
        :param output_hidden_states: Forwarded to the base model.
        :param return_dict: Return a ``CausalLMOutput`` instead of a tuple;
            defaults to ``config.use_return_dict``.
        :param labels: Optional targets; when given, the CTC loss is computed.
        :return: ``CausalLMOutput`` or tuple ``(loss?, logits, *extra outputs)``.
        :raises ValueError: If a label is greater than or equal to the
            vocabulary size.
        :raises RuntimeError: If ``language`` has neither one element nor one
            element per batch item.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None and labels.max() >= self.config.vocab_size:
            raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")

        outputs = self.wavlm(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states)

        # Build the language channel with the same batch and time dimensions
        # as the hidden states, then append it as the last feature.
        batch_size, n_frames = hidden_states.shape[:2]
        if language is None:
            language = 0.0
        language = torch.as_tensor(
            language, dtype=hidden_states.dtype, device=hidden_states.device
        )
        if language.numel() == 1:
            # Same language for every element of the batch
            lang_channel = language.reshape(1, 1, 1).expand(batch_size, n_frames, 1)
        else:
            # One language per element of the batch
            lang_channel = language.reshape(batch_size, 1, 1).expand(batch_size, n_frames, 1)
        hidden_with_lang = torch.cat([hidden_states, lang_channel], dim=2)

        logits = self.lm_head(hidden_with_lang)

        loss = None
        if labels is not None:
            # retrieve loss input_lengths from attention_mask
            attention_mask = (
                attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
            )
            input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)

            # assuming that padded tokens are filled with -100
            # when not being attended to
            labels_mask = labels >= 0
            target_lengths = labels_mask.sum(-1)
            flattened_targets = labels.masked_select(labels_mask)

            # ctc_loss doesn't support fp16
            log_probs = torch.nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)

            with torch.backends.cudnn.flags(enabled=False):
                loss = torch.nn.functional.ctc_loss(
                    log_probs,
                    flattened_targets,
                    input_lengths,
                    target_lengths,
                    blank=self.config.pad_token_id,
                    reduction=self.config.ctc_loss_reduction,
                    zero_infinity=self.config.ctc_zero_infinity,
                )

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return transformers.modeling_outputs.CausalLMOutput(
            loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
        )

    def freeze_feature_encoder_only(self):
        """Make the whole base model trainable, except its feature encoder."""
        # Unfreeze base model
        for param in self.wavlm.parameters():
            param.requires_grad = True
        # Now freeze the first layer
        self.freeze_feature_encoder()
170
+
171
+
172
def freeze_layer(layer, freeze=True):
    """
    Freeze or unfreeze every parameter of a layer.

    :param layer: Object exposing ``parameters()``.
    :param bool freeze: ``True`` to disable gradients, ``False`` to enable them.
    """
    trainable = not freeze
    for weight in layer.parameters():
        weight.requires_grad_(trainable)
    # Also keep the state on the layer itself
    layer._requires_grad = trainable
176
+
177
+
178
def get_wavlm_phoneme_fr_it(tokenizer, freeze_hidden_layers=False):
    """
    Build a ``WavLMPhonemeFrIt`` from the "microsoft/wavlm-base-plus" checkpoint.

    :param tokenizer: Tokenizer giving the padding token id and the
        vocabulary size; it is also attached to the returned model.
    :param bool freeze_hidden_layers: If true, ``freeze_base_model()`` is
        called on the model before returning it.
    :return WavLMPhonemeFrIt: The configured model.
    """
    pretrained_options = {
        "ctc_loss_reduction": "mean",
        "pad_token_id": tokenizer.pad_token_id,
        "vocab_size": len(tokenizer),
    }
    phoneme_model = WavLMPhonemeFrIt.from_pretrained(
        "microsoft/wavlm-base-plus", **pretrained_options
    )
    phoneme_model.tokenizer = tokenizer
    if freeze_hidden_layers:
        phoneme_model.freeze_base_model()
    return phoneme_model