Hugo Farajallah
chore(general): slight polish to the code, adds documentation.
82e797f
import functools
import os
import matplotlib.animation
import matplotlib.pyplot as plt
import numpy as np
import sounddevice as sd
import torch
from torchcodec.decoders import AudioDecoder
import wavlm_phoneme_fr_it
import json
import common
def fake_model(chunk):
    """
    Produce random activations with the shape of a model output, for testing.

    :param chunk: Audio chunk; only ``chunk.shape[0]`` is used.
    :return numpy.ndarray: Uniform random values in [0, 1) of shape
        ``(int(chunk.shape[0] * 0.02), len(vocab) + 3)``, where ``vocab`` is the
        JSON content of ``vocab.json`` in the current working directory.
    """
    n_rows = int(chunk.shape[0] * 0.02)
    with open("vocab.json", "r") as handle:
        vocabulary = json.load(handle)
    # Output width: every vocabulary entry plus three extra columns.
    return np.random.rand(n_rows, len(vocabulary) + 3)
def update_frame(frames, ax, matrix_plot, tokenizer=None, colorbar=None, max_text_length=50):
    """
    Redraw ``ax`` for one animation frame.

    :param tuple frames: ``(layer_index, timestamp, probabilities)``; ``probabilities``
        is a 2-D tensor drawn with ``matshow`` (rows are time steps, columns are
        vocabulary entries, per the axis labels).
    :param ax: Matplotlib axes to redraw; it is cleared first.
    :param matrix_plot: Image artist from the previous frame; replaced by the new one.
    :param tokenizer: Optional tokenizer. When given, the ticks label the arg-max token
        of each row and the decoded string is written under the plot.
    :param colorbar: Unused; kept so existing callers keep working.
    :param int max_text_length: Decoded text longer than this is truncated with "...".
    :return tuple: ``(ax, matrix_plot)`` with the newly drawn image.
    """
    layer_index, timestamp, probabilities = frames[0], frames[1], frames[2]
    ax.clear()
    ax.set_title(
        "Activation levels for WavLM Base +'s hidden layers\n"
        f"Layer = {layer_index + 1}, T = {timestamp}s"
    )
    ax.set_xlabel("Phoneme Vocabulary")
    ax.set_ylabel("Time Steps, and Selected Phoneme")
    data = probabilities.detach().clone()
    matrix_plot = ax.matshow(data, vmin=0, vmax=1, cmap='Blues')
    if tokenizer is not None:
        label_ids = torch.argmax(data, -1)
        # One decoded token per row (time step).
        labels = tokenizer.batch_decode(label_ids)
        # Set lookup instead of scanning the label list once per vocabulary entry.
        present = set(labels)
        used_vocab = [(index, token) for token, index in tokenizer.vocab.items() if token in present]
        ax.set_xticks([index for index, _ in used_vocab])
        ax.set_xticklabels([token for _, token in used_vocab], rotation=45, ha='right')
        # Blank / padding predictions get no y tick, so only emitted phonemes are labelled.
        shown_rows = [(row, token) for row, token in enumerate(labels) if token not in ("", "[PAD]")]
        ax.set_yticks([row for row, _ in shown_rows])
        ax.set_yticklabels([token for _, token in shown_rows])
        # Position the decoded text below the plot with proper spacing
        decoded_text = tokenizer.decode(label_ids)
        if len(decoded_text) > max_text_length:
            decoded_text = decoded_text[:max_text_length] + "..."
        ax.text(
            0.5, -0.15, f"Decoded: {decoded_text}",
            transform=ax.transAxes, ha='center', va='top',
            bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray", alpha=0.8)
        )
    # Lay out the figure that owns ax, not whichever figure pyplot considers current.
    ax.figure.tight_layout()
    return ax, matrix_plot
def _load_waveform(record_mic, audio_duration):
    """Return the waveform from the microphone or the demo file, indexed as [channel, sample]."""
    if record_mic:
        print("Recording the microphone...")
        waveform = sd.rec(
            int(audio_duration * common.SAMPLING_RATE),
            samplerate=common.SAMPLING_RATE,
            channels=1
        ).T
        sd.wait()  # Wait until recording is finished
        print("Recording finished.")
        return waveform
    audio_file = "ceci est un test.wav"
    decoded = AudioDecoder(audio_file).get_all_samples()
    # Explicit check instead of assert: asserts vanish under `python -O`.
    if decoded.sample_rate != common.SAMPLING_RATE:
        raise ValueError(f"Bad audio frequency {decoded.sample_rate}")
    return decoded.data.numpy()


def _split_cumulative(waveform, split_length):
    """Cut the first channel into growing prefixes, one every ``split_length`` seconds.

    :return tuple: ``(chunks, end_times)``; ``chunks[k]`` is ``waveform[0, :end_k]`` and
        ``end_times[k]`` is ``end_k`` converted to seconds (rounded to the millisecond).
    :raises ValueError: If the waveform contains no samples.
    """
    n_samples = int(waveform.shape[1])
    if n_samples == 0:
        raise ValueError("Empty audio, nothing to animate.")
    # round() rather than int(): int(0.3 / 0.1) == 2 because of float error.
    n_chunks = max(1, round(n_samples / common.SAMPLING_RATE / split_length))
    # n_chunks + 1 evenly spaced boundaries, minus the leading 0, give n_chunks prefixes,
    # the last one being the full clip.
    boundaries = np.linspace(0, n_samples, n_chunks + 1, dtype=np.int64)[1:]
    chunks = [waveform[0, :end] for end in boundaries]
    end_times = [round(int(end) / common.SAMPLING_RATE, 3) for end in boundaries]
    return chunks, end_times


def _run_inference(chunks):
    """Run the model on all chunks at once.

    :return tuple: ``(layer_probs, valid_frames, processor)``. ``layer_probs`` holds the
        softmax of ``lm_head`` applied to every hidden state, followed by the softmax of
        the final logits. ``valid_frames[k]`` is the number of leading output frames that
        belong to chunk ``k`` (the rest come from padding).
    """
    model, processor = common.get_model()
    inputs = processor(
        chunks,
        return_attention_mask=True,
        sampling_rate=common.SAMPLING_RATE,
        padding=True
    )
    inputs.update({
        "input_values": torch.tensor(np.array(inputs["input_values"])),
        "attention_mask": torch.tensor(np.array(inputs["attention_mask"])),
        "language": torch.tensor([[0]] * len(chunks))
    })
    with torch.no_grad():
        output = model(**inputs, output_hidden_states=True)
        layer_probs = [
            torch.softmax(
                model.lm_head(wavlm_phoneme_fr_it.add_language_to_hidden(hidden, inputs["language"])),
                dim=-1
            )
            for hidden in output.hidden_states
        ]
    layer_probs.append(torch.softmax(output.logits, dim=-1))
    # Fraction of each padded input row that is real audio. Scaling the padded frame count
    # by it estimates how many output frames belong to that chunk. (Dividing by the sample
    # rate instead would give seconds, which is not a frame ratio.)
    mask = inputs["attention_mask"]
    valid_fraction = mask.sum(dim=1) / mask.shape[1]
    n_frames = layer_probs[-1].shape[1]
    valid_frames = [int(n_frames * fraction) for fraction in valid_fraction.tolist()]
    return layer_probs, valid_frames, processor


def _build_frames(layer_probs, valid_frames, end_times):
    """Flatten into ``(layer_index, end_time, probabilities)`` tuples, layer-major.

    Frames past each chunk's valid length are left at zero so padding is not displayed.
    """
    frames = []
    for layer_index, probs in enumerate(layer_probs):
        for chunk_index, (n_valid, end_time) in enumerate(zip(valid_frames, end_times)):
            frame = torch.zeros((probs.shape[1], probs.shape[2]))
            frame[:n_valid] = probs[chunk_index, :n_valid]
            frames.append((layer_index, end_time, frame))
    return frames


def _next_output_path(dir_path="outputs"):
    """Return ``dir_path/animated.webm``, or the first free ``animated_(i).webm`` if taken."""
    os.makedirs(dir_path, exist_ok=True)
    file_name = f"{dir_path}/animated.webm"
    i = 1
    while os.path.exists(file_name):
        file_name = f"{dir_path}/animated_({i}).webm"
        i += 1
    return file_name


def main(record_mic=False):
    """
    Record an inference run of the model.

    The audio is cut into growing prefixes (one every ``split_length`` seconds). The model
    is run on all of them, and every layer's per-phoneme probabilities are animated prefix
    by prefix, then layer by layer. The animation is shown, then saved under ``outputs/``.

    :param bool record_mic: True to record from the microphone, False to use dummy file.
    :return str: Path of the output file.
    :raises ValueError: If the demo file has the wrong sample rate or the audio is empty.
    """
    audio_duration = 5  # seconds recorded from the microphone
    split_length = 0.1  # seconds of audio added between two frames
    waveform = _load_waveform(record_mic, audio_duration)
    chunks, end_times = _split_cumulative(waveform, split_length)
    layer_probs, valid_frames, processor = _run_inference(chunks)
    flattened = _build_frames(layer_probs, valid_frames, end_times)

    fig, ax = plt.subplots(animated=True)
    ax.set_title("Animation Preview")
    matrix_plot = ax.matshow(
        torch.zeros_like(flattened[0][2]), animated=True, vmin=0, vmax=1, cmap='Blues'
    )
    # Add colorbar once for the entire animation
    plt.colorbar(matrix_plot, ax=ax, label='Activation Level')

    # Module-level reference to the animation object; the __main__ block initialises it.
    global animation
    animation = matplotlib.animation.FuncAnimation(
        fig,
        functools.partial(
            update_frame,
            ax=ax,
            matrix_plot=matrix_plot,
            tokenizer=processor.tokenizer
        ),
        flattened,
        interval=100,
        repeat=False,
    )
    plt.show()

    file_name = _next_output_path()
    animation.save(file_name)
    return file_name
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Animate the model's per-layer phoneme activations and save them as a .webm file."
    )
    parser.add_argument(
        "--mic",
        action="store_true",
        help="Record from the microphone instead of using the bundled test file.",
    )
    args = parser.parse_args()
    # Module-level name that main() rebinds through its `global animation` statement.
    animation = None
    output_path = main(record_mic=args.mic)
    print(f"Animation saved to {output_path}")