Hugo Farajallah commited on
Commit
54520eb
·
1 Parent(s): b6bd379

fix(HF): the displayed alignment matrix was not correct.

Browse files
Files changed (2) hide show
  1. dataset_process.py +23 -11
  2. hf_space.py +19 -4
dataset_process.py CHANGED
@@ -283,6 +283,25 @@ def score_phoneme_deletion(matching, prediction, target, threshold):
283
  return 0
284
 
285
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
  def get_alignment_score(
287
  prediction,
288
  target,
@@ -304,17 +323,10 @@ def get_alignment_score(
304
  :param common.Scoring scoring: Type of scoring to use
305
  :return int: Scoring score.
306
  """
307
-
308
- logits = torch.softmax(
309
- torch.as_tensor(prediction) / weights[3],
310
- dim=-1
311
- )
312
-
313
- reduced_logits = logits[torch.argmax(logits, -1) != pad_token_id]
314
- reduced_logits = reduced_logits.reshape((1, reduced_logits.shape[0], reduced_logits.shape[1]))
315
 
316
  matching, alignment_score = bellman_matching(
317
- reduced_logits,
318
  target,
319
  insertion_cost=weights[0],
320
  deletion_cost=weights[1],
@@ -323,9 +335,9 @@ def get_alignment_score(
323
  np_matching = np.array(matching)
324
 
325
  if scoring is common.Scoring.NUMBER_CORRECT:
326
- return score_correct(np_matching, reduced_logits, target, weights[2])
327
 
328
  if scoring is common.Scoring.PHONEME_DELETION:
329
- return score_phoneme_deletion(np_matching, reduced_logits, target, weights[2])
330
 
331
  raise NotImplementedError("Unknown scoring method.")
 
283
  return 0
284
 
285
 
286
+ def remove_pad_tokens(prediction, pad_token_id, temperature):
287
+ """
288
+ Remove the pad token from a prediction to decrease temporal effects.
289
+
290
+ :param prediction: Predicted logits.
291
+ :param int pad_token_id: ID of the pad token.
292
+ :param float temperature: Temperature to pass to the SoftMax.
293
+ :return torch.Tensor: Probabilities where no row has a pad token id as an argmax.
294
+ """
295
+ logits = torch.softmax(
296
+ torch.as_tensor(prediction) / temperature,
297
+ dim=-1
298
+ )
299
+
300
+ reduced_logits = logits[torch.argmax(logits, -1) != pad_token_id]
301
+ reduced_logits = reduced_logits.reshape((1, reduced_logits.shape[0], reduced_logits.shape[1]))
302
+ return reduced_logits
303
+
304
+
305
  def get_alignment_score(
306
  prediction,
307
  target,
 
323
  :param common.Scoring scoring: Type of scoring to use
324
  :return int: Scoring score.
325
  """
326
+ collapsed_prediction = remove_pad_tokens(prediction, pad_token_id, weights[3])
 
 
 
 
 
 
 
327
 
328
  matching, alignment_score = bellman_matching(
329
+ collapsed_prediction,
330
  target,
331
  insertion_cost=weights[0],
332
  deletion_cost=weights[1],
 
335
  np_matching = np.array(matching)
336
 
337
  if scoring is common.Scoring.NUMBER_CORRECT:
338
+ return score_correct(np_matching, collapsed_prediction, target, weights[2])
339
 
340
  if scoring is common.Scoring.PHONEME_DELETION:
341
+ return score_phoneme_deletion(np_matching, collapsed_prediction, target, weights[2])
342
 
343
  raise NotImplementedError("Unknown scoring method.")
hf_space.py CHANGED
@@ -84,20 +84,35 @@ def process_audio_advanced(audio_data, target_word, language, advanced_mode, ins
84
  prediction_logits,
85
  target_encoded,
86
  weights,
87
- 94,
88
  scoring=scoring_enum
89
  )
90
 
91
- # Generate alignment plot
 
 
 
 
 
92
  path_matrix = dataset_process.compute_path_matrix(
93
- prediction_logits,
94
  target_encoded,
95
  dataset_process.l2_logit_norm,
96
  insertion_cost,
97
  deletion_cost
98
  )
 
 
 
 
 
 
 
 
 
 
99
  alignment_plot_fig = dataset_process.display_matrix_result(
100
- path_matrix, matching, prediction_logits, target_encoded, processor
101
  )
102
 
103
  alignment_result = f"**🔬 Advanced Alignment Analysis:**\n\n"
 
84
  prediction_logits,
85
  target_encoded,
86
  weights,
87
+ processor.tokenizer.pad_token_id,
88
  scoring=scoring_enum
89
  )
90
 
91
+ # Use reduced prediction tensor for alignment plot (remove temporal effects)
92
+ reduced_prediction = dataset_process.remove_pad_tokens(
93
+ prediction_logits, processor.tokenizer.pad_token_id, temperature
94
+ )
95
+
96
+ # Generate alignment plot with reduced prediction
97
  path_matrix = dataset_process.compute_path_matrix(
98
+ reduced_prediction,
99
  target_encoded,
100
  dataset_process.l2_logit_norm,
101
  insertion_cost,
102
  deletion_cost
103
  )
104
+
105
+ # Re-compute matching with reduced prediction for visualization
106
+ matching_for_plot, _ = dataset_process.bellman_matching(
107
+ reduced_prediction,
108
+ target_encoded,
109
+ insertion_cost=insertion_cost,
110
+ deletion_cost=deletion_cost,
111
+ metric=dataset_process.l2_logit_norm
112
+ )
113
+
114
  alignment_plot_fig = dataset_process.display_matrix_result(
115
+ path_matrix, matching_for_plot, reduced_prediction, target_encoded, processor
116
  )
117
 
118
  alignment_result = f"**🔬 Advanced Alignment Analysis:**\n\n"