| import functools |
| import os |
|
|
| import matplotlib.animation |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import sounddevice as sd |
| import torch |
| from torchcodec.decoders import AudioDecoder |
| import wavlm_phoneme_fr_it |
| import json |
|
|
| import common |
|
|
|
|
def fake_model(chunk, vocab_path="vocab.json"):
    """
    Produce random activations shaped like a model output for ``chunk``.

    :param chunk: Array-like object; only ``chunk.shape[0]`` is read, to size
        the first dimension of the output.
    :param str vocab_path: JSON vocabulary file. Only its number of entries
        is used.
    :return numpy.ndarray: Uniform random values in ``[0, 1)`` with shape
        ``(int(chunk.shape[0] * 0.02), len(vocab) + 3)``.
    """
    output_length = int(chunk.shape[0] * 0.02)
    # JSON text is UTF-8; being explicit keeps non-ASCII vocabulary entries
    # readable on platforms whose default text encoding is something else.
    with open(vocab_path, "r", encoding="utf-8") as vocab_file:
        vocab = json.load(vocab_file)
    # NOTE(review): the +3 presumably covers extra tokens that are not listed
    # in the vocabulary file -- confirm against the real model's head size.
    vocab_size = len(vocab) + 3
    return np.random.rand(output_length, vocab_size)
|
|
|
|
def update_frame(frames, ax, matrix_plot, tokenizer=None, colorbar=None):
    """
    Redraw the axes for one animation frame.

    :param frames: Indexable frame description: ``frames[0]`` is the zero-based
        layer index, ``frames[1]`` the time shown in the title and ``frames[2]``
        the tensor of activations to plot.
    :param ax: Matplotlib axes to clear and draw on.
    :param matrix_plot: Accepted for the caller's convenience; the incoming
        value is never read, a new image is created on every call.
    :param tokenizer: Optional tokenizer used to label the axes and to print
        the decoded text. When ``None`` only the matrix is drawn.
    :param colorbar: Accepted but never read by this function.
    :return tuple: ``(ax, image)`` where ``image`` is the result of ``ax.matshow``.
    """
    ax.clear()

    layer_number = frames[0] + 1
    time_stamp = frames[1]
    heading = "Activation levels for WavLM Base +'s hidden layers\n"
    ax.set_title(heading + f"Layer = {layer_number}, T = {time_stamp}s")
    ax.set_xlabel("Phoneme Vocabulary")
    ax.set_ylabel("Time Steps, and Selected Phoneme")

    # Work on a detached copy so plotting never touches the caller's tensor.
    data = frames[2].detach().clone()
    matrix_plot = ax.matshow(data, vmin=0, vmax=1, cmap="Blues")

    if tokenizer is not None:
        # Most active vocabulary entry for every row of the matrix.
        label_ids = torch.argmax(data, -1)
        labels = tokenizer.batch_decode(label_ids)

        # X axis: only tick the vocabulary entries that were selected somewhere.
        selected = [
            (token, column)
            for token, column in tokenizer.vocab.items()
            if token in labels
        ]
        ax.set_xticks([column for _, column in selected])
        ax.set_xticklabels(
            [token for token, _ in selected], rotation=45, ha="right"
        )

        # Y axis: label each row with its selected token, skipping blanks/padding.
        visible_rows = [
            (row, token)
            for row, token in enumerate(labels)
            if token not in ("", "[PAD]")
        ]
        ax.set_yticks([row for row, _ in visible_rows])
        ax.set_yticklabels([token for _, token in visible_rows])

        # Caption with the decoded sequence, shortened to 50 characters.
        decoded_text = tokenizer.decode(label_ids)
        if len(decoded_text) > 50:
            decoded_text = f"{decoded_text[:50]}..."
        caption_box = {
            "boxstyle": "round,pad=0.3",
            "facecolor": "lightgray",
            "alpha": 0.8,
        }
        ax.text(
            0.5,
            -0.15,
            f"Decoded: {decoded_text}",
            transform=ax.transAxes,
            ha="center",
            va="top",
            bbox=caption_box,
        )

    plt.tight_layout()
    return ax, matrix_plot
|
|
|
|
def _cumulative_chunks(waveform, num_splits):
    """Return growing prefixes of channel 0 that end on evenly spaced sample indices."""
    total_samples = waveform.shape[1]
    # num_splits intervals need num_splits + 1 boundaries; the leading 0
    # boundary would give an empty chunk, so it is dropped.
    boundaries = np.linspace(0, total_samples, num_splits + 1, dtype=np.int64)[1:]
    return [waveform[0, :end] for end in boundaries if end > 0]


def _next_output_path(dir_path):
    """Create ``dir_path`` if needed and return the first unused animation file name in it."""
    os.makedirs(dir_path, exist_ok=True)
    candidate = f"{dir_path}/animated.webm"
    suffix = 0
    while os.path.exists(candidate):
        suffix += 1
        candidate = f"{dir_path}/animated_({suffix}).webm"
    return candidate


def main(record_mic=False):
    """
    Record an inference run of the model.

    The audio is cut into progressively longer prefixes, every prefix is run
    through the model in one batch, and the per-layer phoneme probabilities
    are shown as an animation which is then saved under ``outputs/``.

    :param bool record_mic: True to record from the microphone, False to use dummy file.
    :return str: Path of the output file.
    """
    # Keeps the FuncAnimation object referenced at module level after main()
    # returns.
    global animation

    audio_duration = 5  # seconds recorded from the microphone
    split_length = 0.1  # seconds added by each successive chunk
    num_splits = int(audio_duration / split_length)

    if record_mic:
        print("Recording the microphone...")
        # sd.rec returns immediately; the buffer is complete after sd.wait().
        waveform = sd.rec(
            int(audio_duration * common.SAMPLING_RATE),
            samplerate=common.SAMPLING_RATE,
            channels=1
        ).T
        sd.wait()
        print("Recording finished.")
    else:
        audio_file = "ceci est un test.wav"
        decoded = AudioDecoder(audio_file).get_all_samples()
        waveform = decoded.data.numpy()
        assert decoded.sample_rate == common.SAMPLING_RATE, f"Bad audio frequency {decoded.sample_rate}"

    # Growing prefixes of the audio, labelled with the time at which each ends.
    chunks = _cumulative_chunks(waveform, num_splits)
    chunk_end_times = [
        round(len(chunk) / common.SAMPLING_RATE, 2) for chunk in chunks
    ]

    model, processor = common.get_model()
    inputs = processor(
        chunks,
        return_attention_mask=True,
        sampling_rate=common.SAMPLING_RATE,
        padding=True
    )
    inputs.update({
        "input_values": torch.tensor(np.array(inputs["input_values"])),
        "attention_mask": torch.tensor(np.array(inputs["attention_mask"])),
        # One language id (0) per chunk, shape (len(chunks), 1).
        "language": torch.zeros((len(chunks), 1), dtype=torch.long)
    })

    # Probabilities for every hidden layer, followed by the final output.
    with torch.no_grad():
        output = model(**inputs, output_hidden_states=True)
        layer_probabilities = []
        for hidden in output.hidden_states:
            hidden_balanced = wavlm_phoneme_fr_it.add_language_to_hidden(hidden, inputs["language"])
            layer_probabilities.append(torch.softmax(model.lm_head(hidden_balanced), dim=-1))
        layer_probabilities.append(torch.softmax(output.logits, dim=-1))

    num_frames, vocab_size = layer_probabilities[-1].shape[1:]

    # Share of each padded input that is real audio. Scaling the output length
    # by this fraction (not by the duration in seconds) keeps the number of
    # copied frames within the un-padded part for recordings of any length.
    attention_mask = inputs["attention_mask"]
    valid_fraction = attention_mask.sum(dim=1) / attention_mask.shape[1]
    valid_frames = [int(num_frames * fraction) for fraction in valid_fraction]

    # One animation frame per (layer, chunk); rows past the valid part stay zero.
    flattened = []
    for layer_index, layer in enumerate(layer_probabilities):
        for chunk_index, end_time in enumerate(chunk_end_times):
            kept = valid_frames[chunk_index]
            frame = torch.zeros((num_frames, vocab_size))
            frame[:kept] = layer[chunk_index, :kept]
            flattened.append((layer_index, end_time, frame))

    fig, ax = plt.subplots(animated=True)
    ax.set_title("Animation Preview")
    matrix_plot = ax.matshow(
        torch.zeros((num_frames, vocab_size)),
        animated=True, vmin=0, vmax=1, cmap='Blues'
    )
    plt.colorbar(matrix_plot, ax=ax, label='Activation Level')

    animation = matplotlib.animation.FuncAnimation(
        fig,
        functools.partial(
            update_frame,
            ax=ax,
            matrix_plot=matrix_plot,
            tokenizer=processor.tokenizer
        ),
        flattened,
        interval=100,
        repeat=False,
    )
    plt.show()

    file_name = _next_output_path("outputs")
    animation.save(file_name)
    return file_name
|
|
|
|
if __name__ == "__main__":
    # Module-level name that main() rebinds through its `global` statement.
    animation = None
    # record_mic defaults to False, i.e. the audio file is used instead of
    # the microphone.
    main()
|
|