Hugo Farajallah committed on
Commit
0cde9d4
·
1 Parent(s): 54520eb

ui(charts): better display of the data.

Browse files
Files changed (2) hide show
  1. dataset_process.py +52 -39
  2. main.py +23 -9
dataset_process.py CHANGED
@@ -126,58 +126,71 @@ def display_matrix_result(path_matrix, matching, prediction, target, processor=N
126
 
127
  Returns the figure instead of showing it directly for use in Gradio.
128
  """
129
- fig, axis = plt.subplots(figsize=(10, 6))
130
 
131
  if processor is None:
132
  _model, processor = common.get_model()
133
 
134
  # Display the matrix
135
  im = axis.matshow(path_matrix.T, aspect="auto", cmap='Blues')
136
- plt.colorbar(im, ax=axis)
 
137
 
138
- # Set the labels for the axes
139
- axis.set_xlabel('Predicted String', fontsize=12)
140
- axis.set_title('Alignment Matrix: Predicted vs Target Phonemes', fontsize=14, pad=20)
 
 
141
 
142
- # String for the x-axis
143
  predicted_labels = tuple(map(processor.decode, torch.argmax(prediction, -1)[0]))
144
- axis.set_xticks(
145
- [i for i, label in enumerate(predicted_labels) if label == ""],
146
- labels=[label for label in predicted_labels if label == ""]
147
- )
148
- axis.set_xticks(
149
- [i for i, label in enumerate(predicted_labels) if label not in ("[PAD]", "")],
150
- labels=[label for label in predicted_labels if label not in ("[PAD]", "")],
151
- minor=True
152
- )
153
-
154
- axis.set_ylabel('Target String', fontsize=12)
155
  target_labels = tuple(map(processor.decode, torch.argmax(target, -1)[0]))
156
- axis.set_yticks(
157
- [i for i, label in enumerate(target_labels) if label == ""],
158
- labels=[label for label in target_labels if label == ""]
159
- )
160
- axis.set_yticks(
161
- [i for i, label in enumerate(target_labels) if label != ""],
162
- labels=[label for label in target_labels if label != ""],
163
- minor=True
164
- )
165
 
166
- axis.grid(which="major", color="black", alpha=0.3)
167
- axis.grid(which="minor", linestyle="--", alpha=0.2)
168
-
169
- # Plot the optimal path in red
170
- axis.plot(
171
- [val[0] for val in matching],
172
- [val[1] for val in matching],
173
- color="red",
174
- linewidth=2,
175
- marker='o',
176
- markersize=3,
177
- label="Optimal Alignment Path"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  )
179
 
180
- axis.legend()
181
  plt.tight_layout()
182
 
183
  return fig
 
126
 
127
  Returns the figure instead of showing it directly for use in Gradio.
128
  """
129
+ fig, axis = plt.subplots(figsize=(12, 8))
130
 
131
  if processor is None:
132
  _model, processor = common.get_model()
133
 
134
  # Display the matrix
135
  im = axis.matshow(path_matrix.T, aspect="auto", cmap='Blues')
136
+ cbar = plt.colorbar(im, ax=axis)
137
+ cbar.set_label('Alignment Cost', rotation=270, labelpad=20, fontsize=11)
138
 
139
+ # Set the labels for the axes with clearer names
140
+ axis.set_xlabel('Predicted Phoneme Sequence', fontsize=12)
141
+ axis.set_ylabel('Target Phoneme Sequence', fontsize=12)
142
+ axis.set_title('Phoneme Alignment Matrix\n(Blue = Lower Cost, Red Line = Optimal Path)',
143
+ fontsize=14, pad=20)
144
 
145
+ # Get phoneme labels for both axes
146
  predicted_labels = tuple(map(processor.decode, torch.argmax(prediction, -1)[0]))
 
 
 
 
 
 
 
 
 
 
 
147
  target_labels = tuple(map(processor.decode, torch.argmax(target, -1)[0]))
 
 
 
 
 
 
 
 
 
148
 
149
+ # Set x-axis ticks (predicted phonemes)
150
+ non_empty_pred_indices = [i for i, label in enumerate(predicted_labels) if label not in ("", "[PAD]")]
151
+ non_empty_pred_labels = [label for i, label in enumerate(predicted_labels) if label not in ("", "[PAD]")]
152
+
153
+ if non_empty_pred_indices:
154
+ axis.set_xticks(non_empty_pred_indices)
155
+ axis.set_xticklabels(non_empty_pred_labels, rotation=45, ha='right', fontsize=10)
156
+
157
+ # Set y-axis ticks (target phonemes)
158
+ non_empty_target_indices = [i for i, label in enumerate(target_labels) if label not in ("", "[PAD]")]
159
+ non_empty_target_labels = [label for i, label in enumerate(target_labels) if label not in ("", "[PAD]")]
160
+
161
+ if non_empty_target_indices:
162
+ axis.set_yticks(non_empty_target_indices)
163
+ axis.set_yticklabels(non_empty_target_labels, fontsize=10)
164
+
165
+ # Add subtle grid
166
+ axis.grid(which="major", color="gray", alpha=0.2, linestyle="-")
167
+
168
+ # Plot the optimal path in red with better visibility
169
+ if matching:
170
+ axis.plot(
171
+ [val[0] for val in matching],
172
+ [val[1] for val in matching],
173
+ color="red",
174
+ linewidth=3,
175
+ marker='o',
176
+ markersize=4,
177
+ markerfacecolor='white',
178
+ markeredgecolor='red',
179
+ markeredgewidth=2,
180
+ label="Optimal Alignment Path",
181
+ alpha=0.9
182
+ )
183
+
184
+ # Add legend with better positioning
185
+ axis.legend(loc='upper right', bbox_to_anchor=(1.0, 1.0), fontsize=11)
186
+
187
+ # Add text annotations for better understanding
188
+ axis.text(
189
+ 0.02, 0.98, 'Lower values indicate\nbetter alignment',
190
+ transform=axis.transAxes, fontsize=9, va='top', ha='left',
191
+ bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8)
192
  )
193
 
 
194
  plt.tight_layout()
195
 
196
  return fig
main.py CHANGED
@@ -20,25 +20,36 @@ def fake_model(chunk):
20
  return np.random.rand(output_length, vocab_size)
21
 
22
 
23
- def update_frame(frames, ax, matrix_plot, tokenizer=None):
24
  ax.clear()
25
  ax.set_title(
26
  "Activation levels for WavLM Base +'s hidden layers\n"
27
- f"Layer = {frames[0]}, T = {frames[1]}s"
28
  )
29
- ax.set_xlabel("Phonemes list")
30
- ax.set_ylabel("Selected phoneme by Timestamp")
31
  data = frames[2].detach().clone()
32
- matrix_plot = ax.matshow(data, vmin=0, vmax=1)
 
33
  if tokenizer is not None:
34
  label_ids = torch.argmax(data, -1)
35
  labels = tokenizer.batch_decode(label_ids)
36
  ax.set_xticks([i for v, i in tokenizer.vocab.items() if v in labels])
37
- ax.set_xticklabels([v for v, i in tokenizer.vocab.items() if v in labels])
38
  ax.set_yticks([i for i, v in enumerate(labels) if v not in ("", "[PAD]")])
39
  ax.set_yticklabels([v for i, v in enumerate(labels) if v not in ("", "[PAD]")])
40
- ax.text(0, data.shape[0] + 15, "Decoded: " + tokenizer.decode(label_ids))
41
- # matrix_plot.set_data(data)
 
 
 
 
 
 
 
 
 
 
42
  return ax, matrix_plot
43
 
44
 
@@ -100,7 +111,10 @@ def main(record_mic=False):
100
  ]
101
  fig, ax = plt.subplots(animated=True)
102
  ax.set_title("Animation Preview")
103
- matrix_plot = ax.matshow(logit_groups[0][0], animated=True, vmin=0, vmax=1)
 
 
 
104
  logits_list = []
105
  masks = inputs["attention_mask"].sum(dim=1) / common.SAMPLING_RATE
106
  for i, chunk in enumerate(chunks):
 
20
  return np.random.rand(output_length, vocab_size)
21
 
22
 
23
+ def update_frame(frames, ax, matrix_plot, tokenizer=None, colorbar=None):
24
  ax.clear()
25
  ax.set_title(
26
  "Activation levels for WavLM Base +'s hidden layers\n"
27
+ f"Layer = {frames[0] + 1}, T = {frames[1]}s"
28
  )
29
+ ax.set_xlabel("Phoneme Vocabulary")
30
+ ax.set_ylabel("Time Steps, and Selected Phoneme")
31
  data = frames[2].detach().clone()
32
+ matrix_plot = ax.matshow(data, vmin=0, vmax=1, cmap='Blues')
33
+
34
  if tokenizer is not None:
35
  label_ids = torch.argmax(data, -1)
36
  labels = tokenizer.batch_decode(label_ids)
37
  ax.set_xticks([i for v, i in tokenizer.vocab.items() if v in labels])
38
+ ax.set_xticklabels([v for v, i in tokenizer.vocab.items() if v in labels], rotation=45, ha='right')
39
  ax.set_yticks([i for i, v in enumerate(labels) if v not in ("", "[PAD]")])
40
  ax.set_yticklabels([v for i, v in enumerate(labels) if v not in ("", "[PAD]")])
41
+
42
+ # Position the decoded text below the plot with proper spacing
43
+ decoded_text = tokenizer.decode(label_ids)
44
+ if len(decoded_text) > 50:
45
+ decoded_text = decoded_text[:50] + "..."
46
+ ax.text(
47
+ 0.5, -0.15, f"Decoded: {decoded_text}",
48
+ transform=ax.transAxes, ha='center', va='top',
49
+ bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray", alpha=0.8)
50
+ )
51
+
52
+ plt.tight_layout()
53
  return ax, matrix_plot
54
 
55
 
 
111
  ]
112
  fig, ax = plt.subplots(animated=True)
113
  ax.set_title("Animation Preview")
114
+ matrix_plot = ax.matshow(logit_groups[0][0], animated=True, vmin=0, vmax=1, cmap='Blues')
115
+
116
+ # Add colorbar once for the entire animation
117
+ colorbar = plt.colorbar(matrix_plot, ax=ax, label='Activation Level')
118
  logits_list = []
119
  masks = inputs["attention_mask"].sum(dim=1) / common.SAMPLING_RATE
120
  for i, chunk in enumerate(chunks):