"""Align predicted phoneme probabilities with a target sequence and score the result."""

from collections.abc import Callable, Sequence

import matplotlib.pyplot as plt
import numpy as np
import torch
import transformers

import common

# A metric takes one prediction vector and one target vector and returns a cost.
_Metric = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]

# Decoded labels that carry no information and are therefore not drawn as ticks.
_HIDDEN_TICK_LABELS = ("", "[PAD]")


def encode_phonemes(
    phonemes: list[str] | str, tokenizer: transformers.Wav2Vec2Tokenizer
) -> torch.Tensor:
    """
    From list of phonemes to a logits-like (one-hot) matrix.

    :param list[str] | str phonemes: List of individual phonemes to encode.
    :param tokenizer: The tokenizer to use.
    :return torch.Tensor: ``uint8`` encodings of size
        (1, len(phonemes), len(tokenizer.encoder) + 2).
    """
    # NOTE(review): the "+ 2" presumably accounts for special tokens that are not
    # listed in ``tokenizer.encoder`` -- confirm against the model's vocabulary.
    encodings = torch.zeros(
        (1, len(phonemes), len(tokenizer.encoder) + 2), dtype=torch.uint8
    )
    # NOTE(review): assumes ``tokenizer.encode`` yields at most one id per element
    # of ``phonemes``; more ids than that would index past the second dimension.
    for position, label_id in enumerate(tokenizer.encode(phonemes)):
        encodings[0, position, label_id] = 1
    return encodings


def l2_logit_norm(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Apply L2 distance between two vectors, scaled by 1 / 1.414 (about sqrt(2)).

    Is 0 for identical vectors; two different one-hot vectors are sqrt(2) apart,
    so the scaled value is close to 1 for them.

    :param prediction: Predicted vector.
    :param target: Target vector, same shape as ``prediction``.
    :return torch.Tensor: Scalar tensor holding the scaled distance.
    """
    return torch.norm(prediction - target) / 1.414


def cosine_similarity(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Cosine-based distance between softmax(prediction) and the target.

    Despite the name the value is a distance: ``(1 - cos) / 2``, i.e. 0 when both
    vectors point in the same direction and 1 when they are opposite.

    :param prediction: Un-normalised prediction vector (softmax is applied on dim 0).
    :param target: Target vector, compared along dim 0.
    :return torch.Tensor: Distance in [0, 1].
    """
    normed = torch.softmax(prediction, 0)
    cosine = torch.nn.CosineSimilarity(dim=0)(normed, target)
    return (1 - cosine) / 2


def argmax_selection(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Select 1 - prediction[argmax(target)].

    No normalisation happens here, so ``prediction`` is expected to already be
    normalised by the caller. With a normalised prediction the result is 0 when
    all the mass sits on the target's argmax and 1 when none does.

    :param prediction: Prediction vector.
    :param target: Target vector; only its argmax is used.
    :return torch.Tensor: Scalar tensor, 0 for a perfect match.
    """
    return 1 - prediction[torch.argmax(target)]


def plot_metric(metric: _Metric, prediction: torch.Tensor, target: torch.Tensor) -> None:
    """
    Plot the result of a metric for a single prediction/target pair.

    Shows softmax(prediction), the position of the target's argmax and a
    horizontal line at ``1 - metric(prediction, target)``. Blocks on ``plt.show()``.

    :param metric: Metric to evaluate.
    :param prediction: Prediction vector (1-D).
    :param target: Target vector (1-D).
    """
    _fig, ax = plt.subplots()
    _model, processor = common.get_model()
    n_tokens = prediction.shape[0]
    # The last 4 ids get an empty label -- presumably special tokens; TODO confirm.
    # A list (not a generator) is built so the labels can be consumed more than once.
    predicted_labels = [
        processor.decode(i) if i < n_tokens - 4 else "" for i in range(n_tokens)
    ]
    normed = torch.softmax(prediction, 0)
    ax.plot(normed, label="Normed prediction")
    ax.scatter([torch.argmax(target).item()], [1], label="Target", marker="X")
    ax.set_xticks(range(n_tokens), predicted_labels)
    value = 1 - metric(prediction, target)
    ax.plot([0, normed.shape[0]], [value, value], label="1 - Metric value (1 = perfect)")
    plt.legend()
    plt.show()


def compute_path_matrix(
    prediction: torch.Tensor,
    target: torch.Tensor,
    metric: _Metric,
    insertion_cost: float,
    deletion_cost: float,
) -> torch.Tensor:
    """
    Compute the alignment (cumulative cost) matrix of two matrices.

    Only the first batch element of each input is used. Cell (i, j) holds the
    cheapest cost of aligning the first i + 1 prediction rows with the first
    j + 1 target rows.

    NOTE(review): here a step along the *target* axis is charged ``insertion_cost``
    and a step along the *prediction* axis ``deletion_cost``, whereas the scoring
    functions call a target-only step a deletion. The naming looks swapped;
    behaviour is left untouched -- confirm which one is intended.

    :param prediction: Predicted sequence, indexed as (batch, step, token).
    :param target: Target sequence, indexed as (batch, step, token).
    :param metric: Cost of matching one prediction row with one target row.
    :param float insertion_cost: Cost of advancing along the target only.
    :param float deletion_cost: Cost of advancing along the prediction only.
    :return torch.Tensor: Matrix of shape (prediction.shape[1], target.shape[1]).
    """
    # Every cell is written below, so an uninitialised buffer is fine.
    path_matrix = torch.empty((prediction.shape[1], target.shape[1]))
    for i, pred_column in enumerate(prediction[0]):
        for j, target_column in enumerate(target[0]):
            if i == 0 and j == 0:
                path_matrix[i, j] = 0
            elif i == 0:
                path_matrix[0, j] = j * insertion_cost
            elif j == 0:
                path_matrix[i, 0] = i * deletion_cost
            else:
                path_matrix[i, j] = min(
                    path_matrix[i - 1, j - 1] + metric(pred_column, target_column),
                    path_matrix[i - 1, j] + deletion_cost,
                    path_matrix[i, j - 1] + insertion_cost,
                )
    return path_matrix


def solve_path(
    prediction: torch.Tensor, target: torch.Tensor, path_matrix: torch.Tensor
) -> list[tuple[int, int]]:
    """
    Find the matching path between a prediction, a target and a path matrix.

    Starting from the bottom-right cell, repeatedly move to the cheapest of the
    diagonal, upper and left neighbours (ties go to the first in that order)
    until cell (0, 0) is reached.

    :param prediction: Predicted sequence; only ``shape[1]`` is read.
    :param target: Target sequence; only ``shape[1]`` is read.
    :param path_matrix: Matrix produced by :func:`compute_path_matrix`.
    :return list: (prediction index, target index) pairs ordered from start to
        end. Cell (0, 0) itself is never included.
    """
    line, col = prediction.shape[1] - 1, target.shape[1] - 1
    matching = []
    while line > 0 or col > 0:
        matching.append((line, col))
        directions = []
        if line > 0 and col > 0:
            directions.append((line - 1, col - 1))
        if line > 0:
            directions.append((line - 1, col))
        if col > 0:
            directions.append((line, col - 1))
        best_score = float("inf")
        # -1 keeps the last candidate if no cost compares below infinity.
        dir_index = -1
        for i, (next_line, next_col) in enumerate(directions):
            candidate = path_matrix[next_line][next_col]
            if candidate < best_score:
                best_score = candidate
                dir_index = i
        line, col = directions[dir_index]
    matching.reverse()
    return matching


def _visible_ticks(labels: Sequence[str]) -> tuple[list[int], list[str]]:
    """Return positions and texts of the labels that are neither empty nor "[PAD]"."""
    kept = [
        (index, label)
        for index, label in enumerate(labels)
        if label not in _HIDDEN_TICK_LABELS
    ]
    return [index for index, _ in kept], [label for _, label in kept]


def display_matrix_result(path_matrix, matching, prediction, target, processor=None):
    """
    Display all the information resulting from a Bellman matching of matrices.

    Returns the figure instead of showing it directly for use in Gradio.

    :param path_matrix: Cumulative cost matrix (prediction steps x target steps).
    :param matching: Sequence of (prediction index, target index) pairs to draw.
    :param prediction: Predicted sequence, indexed as (batch, step, token).
    :param target: Target sequence, indexed as (batch, step, token).
    :param processor: Object whose ``decode`` turns a token id into a label;
        loaded through ``common.get_model()`` when omitted.
    :return: The matplotlib figure.
    """
    if processor is None:
        _model, processor = common.get_model()
    fig, axis = plt.subplots(figsize=(12, 8))

    # Transposed so that the prediction runs along x and the target along y.
    im = axis.matshow(path_matrix.T, aspect="auto", cmap='Blues')
    cbar = plt.colorbar(im, ax=axis)
    cbar.set_label('Alignment Cost', rotation=270, labelpad=20, fontsize=11)

    axis.set_xlabel('Predicted Phoneme Sequence', fontsize=12)
    axis.set_ylabel('Target Phoneme Sequence', fontsize=12)
    axis.set_title(
        'Phoneme Alignment Matrix\n(Blue = Lower Cost, Red Line = Optimal Path)',
        fontsize=14,
        pad=20,
    )

    # Most likely token at every step of the first batch element.
    predicted_labels = tuple(map(processor.decode, torch.argmax(prediction, -1)[0]))
    target_labels = tuple(map(processor.decode, torch.argmax(target, -1)[0]))

    pred_positions, pred_texts = _visible_ticks(predicted_labels)
    if pred_positions:
        axis.set_xticks(pred_positions)
        axis.set_xticklabels(pred_texts, rotation=45, ha='right', fontsize=10)

    target_positions, target_texts = _visible_ticks(target_labels)
    if target_positions:
        axis.set_yticks(target_positions)
        axis.set_yticklabels(target_texts, fontsize=10)

    axis.grid(which="major", color="gray", alpha=0.2, linestyle="-")

    if matching:
        axis.plot(
            [val[0] for val in matching],
            [val[1] for val in matching],
            color="red",
            linewidth=3,
            marker='o',
            markersize=4,
            markerfacecolor='white',
            markeredgecolor='red',
            markeredgewidth=2,
            label="Optimal Alignment Path",
            alpha=0.9,
        )
        # The path is the only labelled artist, so the legend is built only when
        # it was drawn.
        axis.legend(loc='upper right', bbox_to_anchor=(1.0, 1.0), fontsize=11)

    axis.text(
        0.02,
        0.98,
        'Lower values indicate\nbetter alignment',
        transform=axis.transAxes,
        fontsize=9,
        va='top',
        ha='left',
        bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8),
    )

    plt.tight_layout()
    return fig


def bellman_matching(
    prediction: torch.Tensor,
    target: torch.Tensor,
    insertion_cost: float = 1.3,
    deletion_cost: float = 3,
    metric: _Metric = l2_logit_norm,
) -> tuple[list[tuple[int, int]], float]:
    """
    Match two sequences with Bellman's algorithm.

    :param prediction: Actual prediction, indexed as (batch, step, token).
    :param target: Target list of values, indexed as (batch, step, token).
    :param float insertion_cost: Something was added in prediction.
    :param float deletion_cost: Something was missing in prediction.
    :param Callable metric: The metric to use.
    :return tuple(list, float): Best alignment as index pairs
        [(prediction index, target index), ...] into the *unpadded* inputs, and
        the total cost of the full alignment.
    """
    # Add padding: start matching on letters (do not penalize kids starting with
    # insertions or long audio).
    padded_target = torch.zeros((target.shape[0], target.shape[1] + 1, target.shape[2]))
    padded_target[0, 1:] = target
    padded_prediction = torch.zeros(
        (prediction.shape[0], prediction.shape[1] + 1, prediction.shape[2])
    )
    padded_prediction[0, 1:] = prediction

    path_matrix = compute_path_matrix(
        padded_prediction, padded_target, metric, insertion_cost, deletion_cost
    )
    padded_matching = solve_path(padded_prediction, padded_target, path_matrix)

    last_target_index = padded_target.shape[1] - 1
    short_matching = []
    for pred_index, target_index in padded_matching:
        # Pairs that involve the padding row/column are not real matches.
        if pred_index == 0 or target_index == 0:
            continue
        short_matching.append((pred_index - 1, target_index - 1))
        # Stop at the first pair that reaches the end of the target; what follows
        # is only trailing prediction.
        if target_index == last_target_index:
            break

    # The full path always ends in the bottom-right cell. Reading that cell
    # directly (instead of ``padded_matching[-1]``) also works when both inputs
    # are empty and the path therefore contains no step at all.
    score = path_matrix[-1, -1]
    return short_matching, score.item()


def _count_alignment_errors(
    matching: np.ndarray, prediction: torch.Tensor, target: torch.Tensor, threshold: float
) -> tuple[int, int, int]:
    """
    Classify every transition between consecutive aligned pairs.

    A step of (0, 1) is a deletion, a step of (1, 0) an insertion; for any other
    step the pair is a substitution when ``prediction[argmax(target)]`` is below
    ``threshold``.

    NOTE(review): only transitions are inspected, so the first aligned pair is
    never tested for substitution -- confirm this is intended.

    :return tuple: (insertions, deletions, substitutions).
    """
    insertions = deletions = substitutions = 0
    for previous, current in zip(matching[:-1], matching[1:]):
        step = current - previous
        if np.array_equal(step, [0, 1]):
            deletions += 1
        elif np.array_equal(step, [1, 0]):
            insertions += 1
        else:
            # Match probability, 1 == good match.
            match_value = 1 - argmax_selection(
                prediction[0, current[0]], target[0, current[1]]
            )
            if match_value < threshold:
                substitutions += 1
    return insertions, deletions, substitutions


def score_correct(
    matching: np.ndarray, prediction: torch.Tensor, target: torch.Tensor, threshold: float
) -> int:
    """
    Count the number of correct phonemes in the target.

    :param matching: Array of (prediction index, target index) rows.
    :param prediction: Probabilities, indexed as (batch, step, token).
    :param target: Target encodings, indexed as (batch, step, token).
    :param float threshold: Minimal probability for a pair to count as a match.
    :return int: ``target.shape[1]`` minus the number of insertions, deletions and
        substitutions, floored at 0. An empty alignment scores 0.
    """
    # Nothing was aligned (e.g. every frame was a pad token): no phoneme can be
    # correct, so do not fall through to "zero errors == everything correct".
    if len(matching) == 0:
        return 0
    insertions, deletions, substitutions = _count_alignment_errors(
        matching, prediction, target, threshold
    )
    return max(0, target.shape[1] - insertions - deletions - substitutions)


def score_phoneme_deletion(
    matching: np.ndarray, prediction: torch.Tensor, target: torch.Tensor, threshold: float
) -> int:
    """
    Score a phoneme-deletion item as 0, 1 or 2.

    :param matching: Array of (prediction index, target index) rows.
    :param prediction: Probabilities, indexed as (batch, step, token).
    :param target: Target encodings, indexed as (batch, step, token).
    :param float threshold: Probability threshold used for both checks below.
    :return int: 0 when the alignment is empty or when a pair on prediction index 0
        matches its target above ``threshold``; otherwise 2 for no error, 1 for a
        single error and 0 for more.
    """
    # An empty alignment has no first pair to inspect and cannot be a success.
    if len(matching) == 0:
        return 0
    insertions, deletions, substitutions = _count_alignment_errors(
        matching, prediction, target, threshold
    )

    # First phoneme should NOT match.
    if 0 in matching[0]:
        for row in np.argwhere(matching[:, 0] == 0).flatten():
            match_value = 1 - argmax_selection(
                prediction[0, matching[row, 0]], target[0, matching[row, 1]]
            )
            if match_value > threshold:
                return 0

    errors = insertions + deletions + substitutions
    if errors == 0:
        return 2
    if errors == 1:
        return 1
    return 0


def remove_pad_tokens(prediction, pad_token_id: int, temperature: float) -> torch.Tensor:
    """
    Remove the pad token from a prediction to decrease temporal effects.

    :param prediction: Predicted logits.
    :param int pad_token_id: ID of the pad token.
    :param float temperature: Temperature to pass to the SoftMax.
    :return torch.Tensor: Probabilities of shape (1, kept steps, tokens) where no
        row has the pad token id as its argmax.
    """
    probabilities = torch.softmax(torch.as_tensor(prediction) / temperature, dim=-1)
    # Boolean-mask away every step whose most likely token is the pad token.
    kept = probabilities[torch.argmax(probabilities, -1) != pad_token_id]
    # Restore the leading batch dimension dropped by the masking.
    return kept.reshape((1, kept.shape[0], kept.shape[1]))


def get_alignment_score(
    prediction, target, weights, pad_token_id=58, scoring=common.Scoring.NUMBER_CORRECT
):
    """
    Get a score from a prediction and a target.

    Both the prediction and the target should be logits. The result depends on the
    type of scoring: ``NUMBER_CORRECT`` yields a count between 0 and the number of
    target phonemes, ``PHONEME_DELETION`` yields 0, 1 or 2.

    :param prediction: The output of the model, without activation function.
    :param target: The logits we have to match.
    :param weights: A sequence of weights to apply: ``weights[0]`` insertion cost,
        ``weights[1]`` deletion cost, ``weights[2]`` match threshold,
        ``weights[3]`` softmax temperature.
    :param int pad_token_id: Index of elements in the sequence that should be
        ignored. The default of 58 presumably matches the project's tokenizer --
        verify when using another vocabulary.
    :param common.Scoring scoring: Type of scoring to use.
    :return int: Scoring score.
    :raises NotImplementedError: If ``scoring`` is not a supported method.
    """
    collapsed_prediction = remove_pad_tokens(prediction, pad_token_id, weights[3])
    matching, _alignment_score = bellman_matching(
        collapsed_prediction,
        target,
        insertion_cost=weights[0],
        deletion_cost=weights[1],
        metric=l2_logit_norm,
    )
    np_matching = np.array(matching)
    if scoring is common.Scoring.NUMBER_CORRECT:
        return score_correct(np_matching, collapsed_prediction, target, weights[2])
    if scoring is common.Scoring.PHONEME_DELETION:
        return score_phoneme_deletion(np_matching, collapsed_prediction, target, weights[2])
    raise NotImplementedError("Unknown scoring method.")