Hugo Farajallah commited on
Commit ·
0cde9d4
1
Parent(s): 54520eb
ui(charts): better display of the data.
Browse files- dataset_process.py +52 -39
- main.py +23 -9
dataset_process.py
CHANGED
|
@@ -126,58 +126,71 @@ def display_matrix_result(path_matrix, matching, prediction, target, processor=N
|
|
| 126 |
|
| 127 |
Returns the figure instead of showing it directly for use in Gradio.
|
| 128 |
"""
|
| 129 |
-
fig, axis = plt.subplots(figsize=(
|
| 130 |
|
| 131 |
if processor is None:
|
| 132 |
_model, processor = common.get_model()
|
| 133 |
|
| 134 |
# Display the matrix
|
| 135 |
im = axis.matshow(path_matrix.T, aspect="auto", cmap='Blues')
|
| 136 |
-
plt.colorbar(im, ax=axis)
|
|
|
|
| 137 |
|
| 138 |
-
# Set the labels for the axes
|
| 139 |
-
axis.set_xlabel('Predicted
|
| 140 |
-
axis.
|
|
|
|
|
|
|
| 141 |
|
| 142 |
-
#
|
| 143 |
predicted_labels = tuple(map(processor.decode, torch.argmax(prediction, -1)[0]))
|
| 144 |
-
axis.set_xticks(
|
| 145 |
-
[i for i, label in enumerate(predicted_labels) if label == ""],
|
| 146 |
-
labels=[label for label in predicted_labels if label == ""]
|
| 147 |
-
)
|
| 148 |
-
axis.set_xticks(
|
| 149 |
-
[i for i, label in enumerate(predicted_labels) if label not in ("[PAD]", "")],
|
| 150 |
-
labels=[label for label in predicted_labels if label not in ("[PAD]", "")],
|
| 151 |
-
minor=True
|
| 152 |
-
)
|
| 153 |
-
|
| 154 |
-
axis.set_ylabel('Target String', fontsize=12)
|
| 155 |
target_labels = tuple(map(processor.decode, torch.argmax(target, -1)[0]))
|
| 156 |
-
axis.set_yticks(
|
| 157 |
-
[i for i, label in enumerate(target_labels) if label == ""],
|
| 158 |
-
labels=[label for label in target_labels if label == ""]
|
| 159 |
-
)
|
| 160 |
-
axis.set_yticks(
|
| 161 |
-
[i for i, label in enumerate(target_labels) if label != ""],
|
| 162 |
-
labels=[label for label in target_labels if label != ""],
|
| 163 |
-
minor=True
|
| 164 |
-
)
|
| 165 |
|
| 166 |
-
axis
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
)
|
| 179 |
|
| 180 |
-
axis.legend()
|
| 181 |
plt.tight_layout()
|
| 182 |
|
| 183 |
return fig
|
|
|
|
| 126 |
|
| 127 |
Returns the figure instead of showing it directly for use in Gradio.
|
| 128 |
"""
|
| 129 |
+
fig, axis = plt.subplots(figsize=(12, 8))
|
| 130 |
|
| 131 |
if processor is None:
|
| 132 |
_model, processor = common.get_model()
|
| 133 |
|
| 134 |
# Display the matrix
|
| 135 |
im = axis.matshow(path_matrix.T, aspect="auto", cmap='Blues')
|
| 136 |
+
cbar = plt.colorbar(im, ax=axis)
|
| 137 |
+
cbar.set_label('Alignment Cost', rotation=270, labelpad=20, fontsize=11)
|
| 138 |
|
| 139 |
+
# Set the labels for the axes with clearer names
|
| 140 |
+
axis.set_xlabel('Predicted Phoneme Sequence', fontsize=12)
|
| 141 |
+
axis.set_ylabel('Target Phoneme Sequence', fontsize=12)
|
| 142 |
+
axis.set_title('Phoneme Alignment Matrix\n(Blue = Lower Cost, Red Line = Optimal Path)',
|
| 143 |
+
fontsize=14, pad=20)
|
| 144 |
|
| 145 |
+
# Get phoneme labels for both axes
|
| 146 |
predicted_labels = tuple(map(processor.decode, torch.argmax(prediction, -1)[0]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
target_labels = tuple(map(processor.decode, torch.argmax(target, -1)[0]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
+
# Set x-axis ticks (predicted phonemes)
|
| 150 |
+
non_empty_pred_indices = [i for i, label in enumerate(predicted_labels) if label not in ("", "[PAD]")]
|
| 151 |
+
non_empty_pred_labels = [label for i, label in enumerate(predicted_labels) if label not in ("", "[PAD]")]
|
| 152 |
+
|
| 153 |
+
if non_empty_pred_indices:
|
| 154 |
+
axis.set_xticks(non_empty_pred_indices)
|
| 155 |
+
axis.set_xticklabels(non_empty_pred_labels, rotation=45, ha='right', fontsize=10)
|
| 156 |
+
|
| 157 |
+
# Set y-axis ticks (target phonemes)
|
| 158 |
+
non_empty_target_indices = [i for i, label in enumerate(target_labels) if label not in ("", "[PAD]")]
|
| 159 |
+
non_empty_target_labels = [label for i, label in enumerate(target_labels) if label not in ("", "[PAD]")]
|
| 160 |
+
|
| 161 |
+
if non_empty_target_indices:
|
| 162 |
+
axis.set_yticks(non_empty_target_indices)
|
| 163 |
+
axis.set_yticklabels(non_empty_target_labels, fontsize=10)
|
| 164 |
+
|
| 165 |
+
# Add subtle grid
|
| 166 |
+
axis.grid(which="major", color="gray", alpha=0.2, linestyle="-")
|
| 167 |
+
|
| 168 |
+
# Plot the optimal path in red with better visibility
|
| 169 |
+
if matching:
|
| 170 |
+
axis.plot(
|
| 171 |
+
[val[0] for val in matching],
|
| 172 |
+
[val[1] for val in matching],
|
| 173 |
+
color="red",
|
| 174 |
+
linewidth=3,
|
| 175 |
+
marker='o',
|
| 176 |
+
markersize=4,
|
| 177 |
+
markerfacecolor='white',
|
| 178 |
+
markeredgecolor='red',
|
| 179 |
+
markeredgewidth=2,
|
| 180 |
+
label="Optimal Alignment Path",
|
| 181 |
+
alpha=0.9
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
# Add legend with better positioning
|
| 185 |
+
axis.legend(loc='upper right', bbox_to_anchor=(1.0, 1.0), fontsize=11)
|
| 186 |
+
|
| 187 |
+
# Add text annotations for better understanding
|
| 188 |
+
axis.text(
|
| 189 |
+
0.02, 0.98, 'Lower values indicate\nbetter alignment',
|
| 190 |
+
transform=axis.transAxes, fontsize=9, va='top', ha='left',
|
| 191 |
+
bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8)
|
| 192 |
)
|
| 193 |
|
|
|
|
| 194 |
plt.tight_layout()
|
| 195 |
|
| 196 |
return fig
|
main.py
CHANGED
|
@@ -20,25 +20,36 @@ def fake_model(chunk):
|
|
| 20 |
return np.random.rand(output_length, vocab_size)
|
| 21 |
|
| 22 |
|
| 23 |
-
def update_frame(frames, ax, matrix_plot, tokenizer=None):
|
| 24 |
ax.clear()
|
| 25 |
ax.set_title(
|
| 26 |
"Activation levels for WavLM Base +'s hidden layers\n"
|
| 27 |
-
f"Layer = {frames[0]}, T = {frames[1]}s"
|
| 28 |
)
|
| 29 |
-
ax.set_xlabel("
|
| 30 |
-
ax.set_ylabel("
|
| 31 |
data = frames[2].detach().clone()
|
| 32 |
-
matrix_plot = ax.matshow(data, vmin=0, vmax=1)
|
|
|
|
| 33 |
if tokenizer is not None:
|
| 34 |
label_ids = torch.argmax(data, -1)
|
| 35 |
labels = tokenizer.batch_decode(label_ids)
|
| 36 |
ax.set_xticks([i for v, i in tokenizer.vocab.items() if v in labels])
|
| 37 |
-
ax.set_xticklabels([v for v, i in tokenizer.vocab.items() if v in labels])
|
| 38 |
ax.set_yticks([i for i, v in enumerate(labels) if v not in ("", "[PAD]")])
|
| 39 |
ax.set_yticklabels([v for i, v in enumerate(labels) if v not in ("", "[PAD]")])
|
| 40 |
-
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
return ax, matrix_plot
|
| 43 |
|
| 44 |
|
|
@@ -100,7 +111,10 @@ def main(record_mic=False):
|
|
| 100 |
]
|
| 101 |
fig, ax = plt.subplots(animated=True)
|
| 102 |
ax.set_title("Animation Preview")
|
| 103 |
-
matrix_plot = ax.matshow(logit_groups[0][0], animated=True, vmin=0, vmax=1)
|
|
|
|
|
|
|
|
|
|
| 104 |
logits_list = []
|
| 105 |
masks = inputs["attention_mask"].sum(dim=1) / common.SAMPLING_RATE
|
| 106 |
for i, chunk in enumerate(chunks):
|
|
|
|
| 20 |
return np.random.rand(output_length, vocab_size)
|
| 21 |
|
| 22 |
|
| 23 |
+
def update_frame(frames, ax, matrix_plot, tokenizer=None, colorbar=None):
|
| 24 |
ax.clear()
|
| 25 |
ax.set_title(
|
| 26 |
"Activation levels for WavLM Base +'s hidden layers\n"
|
| 27 |
+
f"Layer = {frames[0] + 1}, T = {frames[1]}s"
|
| 28 |
)
|
| 29 |
+
ax.set_xlabel("Phoneme Vocabulary")
|
| 30 |
+
ax.set_ylabel("Time Steps, and Selected Phoneme")
|
| 31 |
data = frames[2].detach().clone()
|
| 32 |
+
matrix_plot = ax.matshow(data, vmin=0, vmax=1, cmap='Blues')
|
| 33 |
+
|
| 34 |
if tokenizer is not None:
|
| 35 |
label_ids = torch.argmax(data, -1)
|
| 36 |
labels = tokenizer.batch_decode(label_ids)
|
| 37 |
ax.set_xticks([i for v, i in tokenizer.vocab.items() if v in labels])
|
| 38 |
+
ax.set_xticklabels([v for v, i in tokenizer.vocab.items() if v in labels], rotation=45, ha='right')
|
| 39 |
ax.set_yticks([i for i, v in enumerate(labels) if v not in ("", "[PAD]")])
|
| 40 |
ax.set_yticklabels([v for i, v in enumerate(labels) if v not in ("", "[PAD]")])
|
| 41 |
+
|
| 42 |
+
# Position the decoded text below the plot with proper spacing
|
| 43 |
+
decoded_text = tokenizer.decode(label_ids)
|
| 44 |
+
if len(decoded_text) > 50:
|
| 45 |
+
decoded_text = decoded_text[:50] + "..."
|
| 46 |
+
ax.text(
|
| 47 |
+
0.5, -0.15, f"Decoded: {decoded_text}",
|
| 48 |
+
transform=ax.transAxes, ha='center', va='top',
|
| 49 |
+
bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray", alpha=0.8)
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
plt.tight_layout()
|
| 53 |
return ax, matrix_plot
|
| 54 |
|
| 55 |
|
|
|
|
| 111 |
]
|
| 112 |
fig, ax = plt.subplots(animated=True)
|
| 113 |
ax.set_title("Animation Preview")
|
| 114 |
+
matrix_plot = ax.matshow(logit_groups[0][0], animated=True, vmin=0, vmax=1, cmap='Blues')
|
| 115 |
+
|
| 116 |
+
# Add colorbar once for the entire animation
|
| 117 |
+
colorbar = plt.colorbar(matrix_plot, ax=ax, label='Activation Level')
|
| 118 |
logits_list = []
|
| 119 |
masks = inputs["attention_mask"].sum(dim=1) / common.SAMPLING_RATE
|
| 120 |
for i, chunk in enumerate(chunks):
|