Hugo Farajallah commited on
Commit ·
54520eb
1
Parent(s): b6bd379
fix(HF): the displayed alignment matrix was not correct.
Browse files- dataset_process.py +23 -11
- hf_space.py +19 -4
dataset_process.py
CHANGED
|
@@ -283,6 +283,25 @@ def score_phoneme_deletion(matching, prediction, target, threshold):
|
|
| 283 |
return 0
|
| 284 |
|
| 285 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
def get_alignment_score(
|
| 287 |
prediction,
|
| 288 |
target,
|
|
@@ -304,17 +323,10 @@ def get_alignment_score(
|
|
| 304 |
:param common.Scoring scoring: Type of scoring to use
|
| 305 |
:return int: Scoring score.
|
| 306 |
"""
|
| 307 |
-
|
| 308 |
-
logits = torch.softmax(
|
| 309 |
-
torch.as_tensor(prediction) / weights[3],
|
| 310 |
-
dim=-1
|
| 311 |
-
)
|
| 312 |
-
|
| 313 |
-
reduced_logits = logits[torch.argmax(logits, -1) != pad_token_id]
|
| 314 |
-
reduced_logits = reduced_logits.reshape((1, reduced_logits.shape[0], reduced_logits.shape[1]))
|
| 315 |
|
| 316 |
matching, alignment_score = bellman_matching(
|
| 317 |
-
|
| 318 |
target,
|
| 319 |
insertion_cost=weights[0],
|
| 320 |
deletion_cost=weights[1],
|
|
@@ -323,9 +335,9 @@ def get_alignment_score(
|
|
| 323 |
np_matching = np.array(matching)
|
| 324 |
|
| 325 |
if scoring is common.Scoring.NUMBER_CORRECT:
|
| 326 |
-
return score_correct(np_matching,
|
| 327 |
|
| 328 |
if scoring is common.Scoring.PHONEME_DELETION:
|
| 329 |
-
return score_phoneme_deletion(np_matching,
|
| 330 |
|
| 331 |
raise NotImplementedError("Unknown scoring method.")
|
|
|
|
| 283 |
return 0
|
| 284 |
|
| 285 |
|
| 286 |
+
def remove_pad_tokens(prediction, pad_token_id, temperature):
|
| 287 |
+
"""
|
| 288 |
+
Remove the pad token from a prediction to decrease temporal effects.
|
| 289 |
+
|
| 290 |
+
:param prediction: Predicted logits.
|
| 291 |
+
:param int pad_token_id: ID of the pad token.
|
| 292 |
+
:param float temperature: Temperature to pass to the SoftMax.
|
| 293 |
+
:return torch.Tensor: Probabilities where no row has a pad token id as an argmax.
|
| 294 |
+
"""
|
| 295 |
+
logits = torch.softmax(
|
| 296 |
+
torch.as_tensor(prediction) / temperature,
|
| 297 |
+
dim=-1
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
reduced_logits = logits[torch.argmax(logits, -1) != pad_token_id]
|
| 301 |
+
reduced_logits = reduced_logits.reshape((1, reduced_logits.shape[0], reduced_logits.shape[1]))
|
| 302 |
+
return reduced_logits
|
| 303 |
+
|
| 304 |
+
|
| 305 |
def get_alignment_score(
|
| 306 |
prediction,
|
| 307 |
target,
|
|
|
|
| 323 |
:param common.Scoring scoring: Type of scoring to use
|
| 324 |
:return int: Scoring score.
|
| 325 |
"""
|
| 326 |
+
collapsed_prediction = remove_pad_tokens(prediction, pad_token_id, weights[3])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 327 |
|
| 328 |
matching, alignment_score = bellman_matching(
|
| 329 |
+
collapsed_prediction,
|
| 330 |
target,
|
| 331 |
insertion_cost=weights[0],
|
| 332 |
deletion_cost=weights[1],
|
|
|
|
| 335 |
np_matching = np.array(matching)
|
| 336 |
|
| 337 |
if scoring is common.Scoring.NUMBER_CORRECT:
|
| 338 |
+
return score_correct(np_matching, collapsed_prediction, target, weights[2])
|
| 339 |
|
| 340 |
if scoring is common.Scoring.PHONEME_DELETION:
|
| 341 |
+
return score_phoneme_deletion(np_matching, collapsed_prediction, target, weights[2])
|
| 342 |
|
| 343 |
raise NotImplementedError("Unknown scoring method.")
|
hf_space.py
CHANGED
|
@@ -84,20 +84,35 @@ def process_audio_advanced(audio_data, target_word, language, advanced_mode, ins
|
|
| 84 |
prediction_logits,
|
| 85 |
target_encoded,
|
| 86 |
weights,
|
| 87 |
-
|
| 88 |
scoring=scoring_enum
|
| 89 |
)
|
| 90 |
|
| 91 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
path_matrix = dataset_process.compute_path_matrix(
|
| 93 |
-
|
| 94 |
target_encoded,
|
| 95 |
dataset_process.l2_logit_norm,
|
| 96 |
insertion_cost,
|
| 97 |
deletion_cost
|
| 98 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
alignment_plot_fig = dataset_process.display_matrix_result(
|
| 100 |
-
path_matrix,
|
| 101 |
)
|
| 102 |
|
| 103 |
alignment_result = f"**🔬 Advanced Alignment Analysis:**\n\n"
|
|
|
|
| 84 |
prediction_logits,
|
| 85 |
target_encoded,
|
| 86 |
weights,
|
| 87 |
+
processor.tokenizer.pad_token_id,
|
| 88 |
scoring=scoring_enum
|
| 89 |
)
|
| 90 |
|
| 91 |
+
# Use reduced prediction tensor for alignment plot (remove temporal effects)
|
| 92 |
+
reduced_prediction = dataset_process.remove_pad_tokens(
|
| 93 |
+
prediction_logits, processor.tokenizer.pad_token_id, temperature
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
# Generate alignment plot with reduced prediction
|
| 97 |
path_matrix = dataset_process.compute_path_matrix(
|
| 98 |
+
reduced_prediction,
|
| 99 |
target_encoded,
|
| 100 |
dataset_process.l2_logit_norm,
|
| 101 |
insertion_cost,
|
| 102 |
deletion_cost
|
| 103 |
)
|
| 104 |
+
|
| 105 |
+
# Re-compute matching with reduced prediction for visualization
|
| 106 |
+
matching_for_plot, _ = dataset_process.bellman_matching(
|
| 107 |
+
reduced_prediction,
|
| 108 |
+
target_encoded,
|
| 109 |
+
insertion_cost=insertion_cost,
|
| 110 |
+
deletion_cost=deletion_cost,
|
| 111 |
+
metric=dataset_process.l2_logit_norm
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
alignment_plot_fig = dataset_process.display_matrix_result(
|
| 115 |
+
path_matrix, matching_for_plot, reduced_prediction, target_encoded, processor
|
| 116 |
)
|
| 117 |
|
| 118 |
alignment_result = f"**🔬 Advanced Alignment Analysis:**\n\n"
|