"""
Explainability Pipeline: Tokenization Visualization, Embeddings PCA, and Cosine Similarity.
"""
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from typing import Dict, List, Tuple, Optional
from transformers import RobertaTokenizer
class ExplainabilityEngine:
    """
    Extracts and visualizes model internals: tokens, embeddings, and similarity.

    Attributes:
        tokenizer: RoBERTa tokenizer loaded from ``tokenizer_name``.
        emotion_dict: Emotion dictionary passed by the caller (empty dict if omitted).
        pca_2d, pca_3d: Initialised to ``None``; no method of this class assigns them.
        emotion_centroids: Mapping emotion name -> centroid vector, filled by
            ``register_emotion_centroids``.
    """

    def __init__(self, tokenizer_name: str = "roberta-base", emotion_dict: Optional[Dict] = None):
        """
        Initialize explainability engine.

        Args:
            tokenizer_name: HuggingFace tokenizer name or path
            emotion_dict: Plutchik emotion dictionary (for names)
        """
        self.tokenizer = RobertaTokenizer.from_pretrained(tokenizer_name)
        self.emotion_dict = emotion_dict or {}
        self.pca_2d = None
        self.pca_3d = None
        self.emotion_centroids = {}  # Populated by register_emotion_centroids

    def tokenize_with_visualization(self, text: str, max_length: int = 256) -> Dict:
        """
        Tokenize text and prepare the result for visualization.

        Args:
            text: Input text
            max_length: Length the sequence is padded/truncated to (default 256)

        Returns:
            Dict with:
                token_ids: all token ids, padding included
                tokens: token strings for every id, padding included
                active_tokens: token strings where the attention mask is 1
                active_token_ids: ids where the attention mask is 1
                attention_mask: the attention mask as a list
        """
        encoding = self.tokenizer(
            text,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # Index the single batch row explicitly; a blanket squeeze() would also
        # collapse the sequence axis when max_length == 1.
        token_ids = encoding["input_ids"][0].cpu().numpy().tolist()
        attention_mask = encoding["attention_mask"][0].cpu().numpy().tolist()

        tokens = self.tokenizer.convert_ids_to_tokens(token_ids)

        # Keep only positions the attention mask marks as real (non-padding).
        active_tokens = [tok for tok, keep in zip(tokens, attention_mask) if keep == 1]
        active_ids = [tid for tid, keep in zip(token_ids, attention_mask) if keep == 1]

        return {
            "token_ids": token_ids,
            "tokens": tokens,
            "active_tokens": active_tokens,
            "active_token_ids": active_ids,
            "attention_mask": attention_mask,
        }

    def extract_embeddings(self, last_hidden_state: torch.Tensor,
                           attention_mask: torch.Tensor) -> Dict:
        """
        Extract and process embeddings from last hidden state.

        When a batch dimension is present only the first sequence of the batch
        is used, so the returned arrays always describe a single sequence.

        Args:
            last_hidden_state: [batch_size, seq_length, hidden_dim] or [seq_length, hidden_dim]
            attention_mask: [batch_size, seq_length] or [seq_length]; 1 marks real tokens

        Returns:
            Dict with:
                cls_embedding: [hidden_dim] embedding of the first token
                mean_embedding: [hidden_dim] mask-weighted mean over tokens
                all_token_embeddings: [seq_length, hidden_dim]
                attention_mask: [seq_length], in its original dtype
        """
        # Select the first sequence rather than squeeze(0): squeeze is a no-op
        # for batch_size > 1 and would leave a 3-D tensor behind.
        hidden = last_hidden_state[0] if last_hidden_state.dim() == 3 else last_hidden_state
        mask = attention_mask[0] if attention_mask.dim() == 2 else attention_mask

        # CLS token (first token)
        cls_embedding = hidden[0]  # [hidden_dim]

        # Mean pooling over non-padding tokens. The mask is cast to the hidden
        # state's dtype/device so the arithmetic never runs on integer tensors.
        weights = mask.to(device=hidden.device, dtype=hidden.dtype).unsqueeze(-1)  # [seq, 1]
        summed = (hidden * weights).sum(dim=0)
        token_count = weights.sum().clamp(min=1e-9)  # guards an all-zero mask
        mean_embedding = summed / token_count  # [hidden_dim]

        return {
            "cls_embedding": cls_embedding.detach().cpu().numpy(),
            "mean_embedding": mean_embedding.detach().cpu().numpy(),
            "all_token_embeddings": hidden.detach().cpu().numpy(),
            "attention_mask": mask.detach().cpu().numpy(),
        }

    def reduce_embeddings_pca(self, embeddings_list: List[np.ndarray],
                              n_components: int = 2) -> "Tuple[np.ndarray, PCA]":
        """
        Apply PCA to reduce high-dimensional embeddings.

        The embeddings are standardized per feature before PCA is fitted. The
        scaler is local to this call and is not returned.

        Args:
            embeddings_list: List of [hidden_dim] embeddings
            n_components: Number of principal components to keep (2 or 3)

        Returns:
            (reduced_embeddings, fitted_pca)
        """
        embeddings_array = np.vstack(embeddings_list)  # [n_samples, hidden_dim]

        # Standardize
        embeddings_scaled = StandardScaler().fit_transform(embeddings_array)

        # PCA
        pca = PCA(n_components=n_components)
        reduced = pca.fit_transform(embeddings_scaled)
        return reduced, pca

    def visualize_embedding_heatmap(self, embeddings: np.ndarray, max_dims: int = 30) -> Dict:
        """
        Prepare embedding as heatmap data (useful for Plotly heatmap).

        Each hidden dimension (column) is min-max normalized to [0, 1]
        independently; a constant column maps to all zeros. At most
        ``max_dims`` evenly spaced columns are kept.

        Args:
            embeddings: [seq_length, hidden_dim]
            max_dims: Maximum number of hidden dimensions to sample (default 30)

        Returns:
            Dict with:
                heatmap_data: [seq_length, n_sampled] normalized array
                seq_length: number of rows in ``embeddings``
                hidden_dim: number of columns in ``embeddings``
                sampled_dims: list of the column indices that were kept
        """
        col_min = embeddings.min(axis=0, keepdims=True)
        col_max = embeddings.max(axis=0, keepdims=True)

        # A zero span means the column is constant; divide by 1 so it maps to 0.
        span = col_max - col_min
        span = np.where(span == 0, 1.0, span)
        normalized = (embeddings - col_min) / span

        # Sample columns for visualization (768 is too many). Never ask for
        # more columns than exist, otherwise linspace repeats indices.
        hidden_dim = normalized.shape[1]
        n_sampled = min(max_dims, hidden_dim)
        sample_cols = np.linspace(0, hidden_dim - 1, n_sampled, dtype=int)
        sampled = normalized[:, sample_cols]

        return {
            "heatmap_data": sampled,
            "seq_length": embeddings.shape[0],
            "hidden_dim": embeddings.shape[1],
            "sampled_dims": sample_cols.tolist(),
        }

    def compute_cosine_similarity(self, embedding: np.ndarray,
                                  centroid: np.ndarray) -> float:
        """
        Compute cosine similarity between an embedding and emotion centroid.

        Args:
            embedding: [hidden_dim] numpy array
            centroid: [hidden_dim] numpy array

        Returns:
            Cosine similarity rescaled from [-1, 1] to [0, 1]. If either vector
            has zero norm the direction is undefined and 0.5 is returned.
        """
        denom = np.linalg.norm(embedding) * np.linalg.norm(centroid)
        if denom == 0:
            return 0.5

        # Clip so floating-point round-off cannot push the score outside [-1, 1].
        cosine = float(np.clip(np.dot(embedding, centroid) / denom, -1.0, 1.0))

        # Map from [-1, 1] to [0, 1]
        return (cosine + 1.0) / 2.0

    def register_emotion_centroids(self, emotion_embeddings: Dict[str, List[np.ndarray]]):
        """
        Register emotion centroids from training data.

        Each emotion has a list of embeddings; the element-wise mean is stored
        as its centroid in ``self.emotion_centroids``. Emotions with an empty
        list are skipped; existing entries for other emotions are kept.

        Args:
            emotion_embeddings: Dict[emotion_name -> List[embeddings]]
        """
        for emotion, embeddings_list in emotion_embeddings.items():
            # len() rather than truthiness: the value may be a numpy array.
            if len(embeddings_list) > 0:
                self.emotion_centroids[emotion] = np.mean(embeddings_list, axis=0)

    def explain_prediction(self, input_text: str, model_output: Dict,
                           predicted_emotion: str, top_k: int = 3) -> Dict:
        """
        Generate full explainability report for a single prediction.

        Args:
            input_text: Original input text
            model_output: Output dict from model; must contain "last_hidden_state"
                and may contain "attention_mask" (all ones is assumed otherwise)
            predicted_emotion: Predicted emotion name
            top_k: Top K similar emotions in embedding space

        Returns:
            Dict with keys "tokens", "embeddings", "heatmap",
            "cosine_similarities", "top_k_similar_emotions",
            "predicted_emotion" and "prediction_confidence". The confidence is
            the cosine similarity of the CLS embedding to the predicted
            emotion's centroid, or 0.0 when no such centroid is registered.
        """
        # Tokenization
        token_info = self.tokenize_with_visualization(input_text)

        # Embeddings
        hidden = model_output["last_hidden_state"]
        mask = model_output.get("attention_mask")
        if mask is None:
            # Build the fallback mask on the hidden state's device so the
            # pooling arithmetic does not mix devices.
            mask = torch.ones(hidden.shape[:-1], device=hidden.device)
        embedding_info = self.extract_embeddings(hidden, mask)
        cls_emb = embedding_info["cls_embedding"]

        # Heatmap
        heatmap_info = self.visualize_embedding_heatmap(embedding_info["all_token_embeddings"])

        # Cosine similarity to emotion centroids
        similarities = {
            emotion: self.compute_cosine_similarity(cls_emb, centroid)
            for emotion, centroid in self.emotion_centroids.items()
        }

        # Top K
        sorted_sims = sorted(similarities.items(), key=lambda item: item[1], reverse=True)
        top_k_emotions = sorted_sims[:top_k]

        return {
            "tokens": token_info,
            "embeddings": embedding_info,
            "heatmap": heatmap_info,
            "cosine_similarities": similarities,
            "top_k_similar_emotions": top_k_emotions,
            "predicted_emotion": predicted_emotion,
            "prediction_confidence": float(similarities.get(predicted_emotion, 0.0)),
        }
|