"""
Explainability Pipeline: Tokenization Visualization, Embeddings PCA, and Cosine Similarity.
"""

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from typing import Dict, List, Tuple, Optional
from transformers import RobertaTokenizer


class ExplainabilityEngine:
    """
    Extracts and visualizes model internals: tokens, embeddings, and similarity.
    """

    def __init__(self, tokenizer_name: str = "roberta-base", emotion_dict: Optional[Dict] = None):
        """
        Initialize explainability engine.

        Args:
            tokenizer_name: HuggingFace tokenizer name passed to
                RobertaTokenizer.from_pretrained
            emotion_dict: Plutchik emotion dictionary (for names); stored as-is,
                defaults to an empty dict
        """
        self.tokenizer = RobertaTokenizer.from_pretrained(tokenizer_name)
        self.emotion_dict = emotion_dict or {}
        # PCA models from the most recent 2-D / 3-D reduce_embeddings_pca call.
        self.pca_2d: Optional[PCA] = None
        self.pca_3d: Optional[PCA] = None
        # (scaler, pca) pairs keyed by n_components. Each PCA was fitted on
        # standardized data, so new points must go through the same scaler first.
        self.pca_models: Dict[int, Tuple[StandardScaler, PCA]] = {}
        # emotion name -> centroid vector, filled by register_emotion_centroids.
        self.emotion_centroids: Dict[str, np.ndarray] = {}

    def tokenize_with_visualization(self, text: str, max_length: int = 256) -> Dict:
        """
        Tokenize text and prepare for visualization.

        Args:
            text: Input text
            max_length: Length the sequence is truncated/padded to (default 256)

        Returns:
            Dict with:
                token_ids: all ``max_length`` ids, padding included (Python ints)
                tokens: token strings for every id
                active_tokens / active_token_ids: only positions whose
                    attention mask is 1 (padding removed for display)
                attention_mask: the mask as a list
        """
        encoding = self.tokenizer(
            text,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Index the single batch row instead of squeeze(), which would collapse
        # a length-1 sequence to a 0-d array.
        input_ids = encoding["input_ids"][0].cpu().numpy()
        attention_mask = encoding["attention_mask"][0].cpu().numpy()

        tokens = self.tokenizer.convert_ids_to_tokens(input_ids.tolist())

        # Remove padding tokens for display.
        active = attention_mask == 1
        active_tokens = [tok for tok, keep in zip(tokens, active) if keep]
        # Python ints, matching "token_ids" (np.int64 is not JSON-serializable).
        active_ids = [int(i) for i in input_ids[active]]

        return {
            "token_ids": input_ids.tolist(),
            "tokens": tokens,
            "active_tokens": active_tokens,
            "active_token_ids": active_ids,
            "attention_mask": attention_mask.tolist(),
        }

    def extract_embeddings(self, last_hidden_state: torch.Tensor,
                           attention_mask: torch.Tensor) -> Dict:
        """
        Extract and process embeddings from last hidden state.

        Args:
            last_hidden_state: [batch_size, seq_length, hidden_dim] or
                [seq_length, hidden_dim]. For 3-D input only the first batch
                item is used (this method explains a single example).
            attention_mask: [batch_size, seq_length] or [seq_length]; non-zero
                entries mark real (non-padding) tokens. For 2-D input only the
                first row is used.

        Returns:
            Dict with (all numpy arrays):
                cls_embedding: [hidden_dim], the first token's vector
                mean_embedding: [hidden_dim], mask-weighted mean over tokens
                all_token_embeddings: [seq_length, hidden_dim]
                attention_mask: [seq_length], in the mask's original dtype
        """
        if last_hidden_state.dim() == 3:
            # Same as squeeze(0) for batch_size == 1. For larger batches
            # squeeze(0) was a no-op and produced mis-shaped outputs.
            last_hidden_state = last_hidden_state[0]
        if attention_mask.dim() == 2:
            attention_mask = attention_mask[0]

        hidden = last_hidden_state.detach()
        if hidden.dtype in (torch.float16, torch.bfloat16):
            # numpy has no bfloat16, and half-precision sums lose accuracy.
            hidden = hidden.float()
        # Match the hidden state's dtype and device so the weighted sum works on
        # GPU tensors and with integer/bool masks.
        mask = attention_mask.detach().to(device=hidden.device, dtype=hidden.dtype)

        # CLS token (first token)
        cls_embedding = hidden[0]  # [hidden_dim]

        # Mean pooling over non-padding tokens; the clamp avoids 0/0 for an all-zero mask.
        sum_embeddings = (hidden * mask.unsqueeze(-1)).sum(dim=0)
        token_count = torch.clamp(mask.sum(), min=1e-9)
        mean_embedding = sum_embeddings / token_count  # [hidden_dim]

        return {
            "cls_embedding": cls_embedding.cpu().numpy(),
            "mean_embedding": mean_embedding.cpu().numpy(),
            "all_token_embeddings": hidden.cpu().numpy(),
            "attention_mask": attention_mask.detach().cpu().numpy(),
        }

    def reduce_embeddings_pca(self, embeddings_list: List[np.ndarray],
                              n_components: int = 2) -> Tuple[np.ndarray, PCA]:
        """
        Standardize embeddings, then apply PCA to reduce their dimensionality.

        The fitted scaler and PCA are kept in ``self.pca_models[n_components]``.
        ``self.pca_2d`` / ``self.pca_3d`` are also updated for 2 or 3
        components. Use project_embeddings_pca to place new points in the same space.

        Args:
            embeddings_list: List of [hidden_dim] embeddings
            n_components: Number of PCA components (typically 2 or 3)

        Returns:
            (reduced_embeddings [n_samples, n_components], fitted_pca).
            Note that fitted_pca expects *standardized* input.
        """
        embeddings_array = np.vstack(embeddings_list)  # [n_samples, hidden_dim]

        # Standardize
        scaler = StandardScaler()
        embeddings_scaled = scaler.fit_transform(embeddings_array)

        # PCA
        pca = PCA(n_components=n_components)
        reduced = pca.fit_transform(embeddings_scaled)

        self.pca_models[n_components] = (scaler, pca)
        if n_components == 2:
            self.pca_2d = pca
        elif n_components == 3:
            self.pca_3d = pca

        return reduced, pca

    def project_embeddings_pca(self, embeddings_list: List[np.ndarray],
                               n_components: int = 2) -> np.ndarray:
        """
        Project new embeddings into a PCA space fitted by reduce_embeddings_pca.

        Args:
            embeddings_list: List of [hidden_dim] embeddings
            n_components: Which fitted space to use

        Returns:
            [n_samples, n_components] coordinates

        Raises:
            ValueError: if no PCA with ``n_components`` has been fitted yet
        """
        if n_components not in self.pca_models:
            raise ValueError(
                f"No PCA with n_components={n_components} has been fitted; "
                "call reduce_embeddings_pca first."
            )
        scaler, pca = self.pca_models[n_components]
        return pca.transform(scaler.transform(np.vstack(embeddings_list)))

    def visualize_embedding_heatmap(self, embeddings: np.ndarray,
                                    n_sample_dims: int = 30) -> Dict:
        """
        Prepare embeddings as heatmap data (useful for a Plotly heatmap).

        Each hidden dimension (column) is min-max normalized to [0, 1]
        independently. Constant columns map to 0.

        Args:
            embeddings: [seq_length, hidden_dim]
            n_sample_dims: Number of evenly spaced hidden dimensions to keep
                (capped at hidden_dim so no column is repeated)

        Returns:
            Dict with heatmap_data [seq_length, n_sampled], seq_length,
            hidden_dim, and the sampled column indices
        """
        em_min = embeddings.min(axis=0, keepdims=True)
        em_max = embeddings.max(axis=0, keepdims=True)
        span = em_max - em_min
        # Zero range -> divide by 1 so constant columns become 0 rather than NaN.
        span = np.where(span == 0, 1.0, span)
        normalized = (embeddings - em_min) / span

        # Sample columns for visualization (hidden_dim is usually too many to show).
        n_cols = min(n_sample_dims, embeddings.shape[1])
        sample_cols = np.linspace(0, embeddings.shape[1] - 1, n_cols, dtype=int)
        sampled = normalized[:, sample_cols]

        return {
            "heatmap_data": sampled,
            "seq_length": embeddings.shape[0],
            "hidden_dim": embeddings.shape[1],
            "sampled_dims": sample_cols.tolist(),
        }

    def compute_cosine_similarity(self, embedding: np.ndarray, centroid: np.ndarray) -> float:
        """
        Compute cosine similarity between an embedding and an emotion centroid.

        Args:
            embedding: [hidden_dim] numpy array
            centroid: [hidden_dim] numpy array

        Returns:
            Cosine similarity linearly mapped from [-1, 1] to [0, 1]. A zero
            vector yields 0.5 because the denominator is padded with 1e-8.
        """
        denom = np.linalg.norm(embedding) * np.linalg.norm(centroid) + 1e-8
        cosine = float(np.dot(embedding, centroid) / denom)
        # Floating-point rounding can push the value slightly outside [-1, 1].
        cosine = min(1.0, max(-1.0, cosine))
        return (cosine + 1.0) / 2.0

    def register_emotion_centroids(self, emotion_embeddings: Dict[str, List[np.ndarray]]):
        """
        Register emotion centroids from training data.

        Each emotion's centroid is the element-wise mean of its embeddings.
        Emotions with no embeddings are skipped. Existing centroids for the
        same name are overwritten.

        Args:
            emotion_embeddings: Dict[emotion_name -> List[embeddings]]
        """
        for emotion, embeddings_list in emotion_embeddings.items():
            if len(embeddings_list) > 0:
                self.emotion_centroids[emotion] = np.mean(embeddings_list, axis=0)

    def explain_prediction(self, input_text: str, model_output: Dict,
                           predicted_emotion: str, top_k: int = 3) -> Dict:
        """
        Generate a full explainability report for a single prediction.

        Args:
            input_text: Original input text
            model_output: Output dict from the model. Must contain
                "last_hidden_state"; "attention_mask" is optional and defaults
                to all ones.
            predicted_emotion: Predicted emotion name
            top_k: Number of most similar emotion centroids to report

        Returns:
            Dict with tokens, embeddings, heatmap (real tokens only),
            cosine_similarities, top_k_similar_emotions, predicted_emotion,
            and prediction_confidence. The last is the predicted emotion's
            centroid similarity, or 0.0 if it has no centroid.
        """
        # Tokenization
        token_info = self.tokenize_with_visualization(input_text)

        # Embeddings
        hidden = model_output["last_hidden_state"]
        mask = model_output.get("attention_mask")
        if mask is None:
            # Built only when needed and on the hidden state's device. A dict.get
            # default is always evaluated and always lands on CPU.
            mask = torch.ones(hidden.shape[:-1], device=hidden.device)
        embedding_info = self.extract_embeddings(hidden, mask)
        cls_emb = embedding_info["cls_embedding"]

        # Heatmap over real tokens only, so padding rows do not skew the
        # per-dimension min/max and rows line up with active_tokens.
        token_embeddings = embedding_info["all_token_embeddings"]
        active_rows = np.asarray(embedding_info["attention_mask"]) != 0
        if active_rows.shape[0] == token_embeddings.shape[0] and active_rows.any():
            token_embeddings = token_embeddings[active_rows]
        heatmap_info = self.visualize_embedding_heatmap(token_embeddings)

        # Cosine similarity to emotion centroids
        similarities = {
            emotion: self.compute_cosine_similarity(cls_emb, centroid)
            for emotion, centroid in self.emotion_centroids.items()
        }

        # Top K
        sorted_sims = sorted(similarities.items(), key=lambda item: item[1], reverse=True)
        top_k_emotions = sorted_sims[:top_k]

        return {
            "tokens": token_info,
            "embeddings": embedding_info,
            "heatmap": heatmap_info,
            "cosine_similarities": similarities,
            "top_k_similar_emotions": top_k_emotions,
            "predicted_emotion": predicted_emotion,
            "prediction_confidence": float(similarities.get(predicted_emotion, 0.0)),
        }