File size: 8,713 Bytes
3311661
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
"""
Explainability Pipeline: Tokenization Visualization, Embeddings PCA, and Cosine Similarity.
"""

import numpy as np
import torch
import torch.nn.functional as F
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from typing import Dict, List, Tuple, Optional
from transformers import RobertaTokenizer


class ExplainabilityEngine:
    """
    Extracts and visualizes model internals: tokens, embeddings, and similarity.
    """
    
    def __init__(self, tokenizer_name: str = "roberta-base", emotion_dict: Optional[Dict] = None):
        """
        Load the tokenizer and create the engine's (initially empty) state.

        Args:
            tokenizer_name: Name or path handed to RobertaTokenizer.from_pretrained
            emotion_dict: Plutchik emotion dictionary (for names); any falsy
                value is replaced by an empty dict
        """
        self.tokenizer = RobertaTokenizer.from_pretrained(tokenizer_name)

        # Emotion name -> centroid vector; filled by register_emotion_centroids.
        self.emotion_centroids = {}
        # Placeholders for fitted PCA models; nothing in this constructor fits them.
        self.pca_2d, self.pca_3d = None, None
        self.emotion_dict = emotion_dict if emotion_dict else {}
    
    def tokenize_with_visualization(self, text: str) -> Dict:
        """
        Tokenize text and prepare for visualization.
        
        Args:
            text: Input text
        
        Returns:
            Dict with token_ids, tokens, token_strings
        """
        encoding = self.tokenizer(
            text,
            max_length=256,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        input_ids = encoding["input_ids"].squeeze().cpu().numpy()
        
        # Convert token IDs to tokens
        tokens = self.tokenizer.convert_ids_to_tokens(input_ids)
        
        # Remove padding tokens for display
        attention_mask = encoding["attention_mask"].squeeze().cpu().numpy()
        active_tokens = [tokens[i] for i in range(len(tokens)) if attention_mask[i] == 1]
        active_ids = [input_ids[i] for i in range(len(input_ids)) if attention_mask[i] == 1]
        
        return {
            "token_ids": input_ids.tolist(),
            "tokens": tokens,
            "active_tokens": active_tokens,
            "active_token_ids": active_ids,
            "attention_mask": attention_mask.tolist()
        }
    
    def extract_embeddings(self, last_hidden_state: torch.Tensor,
                          attention_mask: torch.Tensor) -> Dict:
        """
        Extract CLS, mean-pooled and per-token embeddings for one sequence.

        Args:
            last_hidden_state: [1, seq_length, hidden_dim] or [seq_length, hidden_dim]
            attention_mask: [1, seq_length] or [seq_length]; non-zero marks
                positions that take part in the mean pooling

        Returns:
            Dict of numpy arrays:
                cls_embedding: hidden state of the first position, [hidden_dim]
                mean_embedding: mask-weighted mean over positions, [hidden_dim]
                all_token_embeddings: [seq_length, hidden_dim]
                attention_mask: the mask as passed in, without the batch axis

        Raises:
            ValueError: If a batched input contains more than one sequence
        """
        # Only a single sequence is supported. squeeze(0) would silently leave
        # a larger batch untouched and the pooling below would then reduce
        # over the batch axis, so reject that case explicitly.
        if last_hidden_state.dim() == 3:
            if last_hidden_state.size(0) != 1:
                raise ValueError(
                    "extract_embeddings expects a single sequence, "
                    f"got batch size {last_hidden_state.size(0)}"
                )
            last_hidden_state = last_hidden_state[0]
        if attention_mask.dim() == 2:
            if attention_mask.size(0) != 1:
                raise ValueError(
                    "extract_embeddings expects a single attention mask, "
                    f"got batch size {attention_mask.size(0)}"
                )
            attention_mask = attention_mask[0]

        # CLS token (first token)
        cls_embedding = last_hidden_state[0]  # [hidden_dim]

        # Mean pooling over non-padding tokens. The mask is cast to the hidden
        # state's dtype and device so integer/bool masks and masks created on
        # another device both work, and clamp() operates on floats.
        weights = attention_mask.to(
            device=last_hidden_state.device, dtype=last_hidden_state.dtype
        ).unsqueeze(-1)  # [seq_length, 1]
        sum_embeddings = (last_hidden_state * weights).sum(dim=0)
        # Clamp so an all-zero mask yields a zero vector instead of NaN
        token_count = weights.sum(dim=0).clamp(min=1e-9)
        mean_embedding = sum_embeddings / token_count  # [hidden_dim]

        return {
            "cls_embedding": cls_embedding.detach().cpu().numpy(),
            "mean_embedding": mean_embedding.detach().cpu().numpy(),
            "all_token_embeddings": last_hidden_state.detach().cpu().numpy(),
            "attention_mask": attention_mask.detach().cpu().numpy()
        }
    
    def reduce_embeddings_pca(self, embeddings_list: List[np.ndarray],
                             n_components: int = 2) -> Tuple[np.ndarray, PCA]:
        """
        Project a collection of embeddings onto their leading principal components.

        The embeddings are stacked row-wise, standardized per feature and
        then reduced with PCA.

        Args:
            embeddings_list: Embeddings to stack with np.vstack (one row each)
            n_components: Number of principal components to keep (2 or 3)

        Returns:
            (reduced_embeddings, fitted_pca)

        NOTE: the scaler is local to this call and is not returned; only the
        PCA fitted on the standardized data is.
        """
        stacked = np.vstack(embeddings_list)  # [n_samples, hidden_dim]

        # Standardize each feature before the decomposition
        standardized = StandardScaler().fit_transform(stacked)

        reducer = PCA(n_components=n_components)
        projected = reducer.fit_transform(standardized)

        return projected, reducer
    
    def visualize_embedding_heatmap(self, embeddings: np.ndarray,
                                    n_sample_dims: int = 30) -> Dict:
        """
        Prepare embedding as heatmap data (useful for Plotly heatmap).

        Each hidden dimension (column) is min-max normalized to [0, 1]
        independently, then an evenly spaced subset of columns is kept.

        Args:
            embeddings: [seq_length, hidden_dim]
            n_sample_dims: Maximum number of hidden dimensions to keep; capped
                at hidden_dim so no column is ever repeated

        Returns:
            Dict with:
                heatmap_data: [seq_length, n_kept] normalized values
                seq_length: number of rows in ``embeddings``
                hidden_dim: number of columns in ``embeddings``
                sampled_dims: indices of the kept columns (list of int)

        Raises:
            ValueError: If ``embeddings`` is not 2-dimensional
        """
        embeddings = np.asarray(embeddings)
        if embeddings.ndim != 2:
            raise ValueError(
                f"embeddings must be 2-D [seq_length, hidden_dim], got shape {embeddings.shape}"
            )
        seq_length, hidden_dim = embeddings.shape

        # Normalize to [0, 1] for visualization
        col_min = embeddings.min(axis=0, keepdims=True)
        col_range = embeddings.max(axis=0, keepdims=True) - col_min
        # A constant column has range 0; divide by 1 instead so it maps to 0
        col_range[col_range == 0] = 1
        normalized = (embeddings - col_min) / col_range

        # Sample columns for visualization (768 is too many). Capping at
        # hidden_dim keeps the integer-truncated linspace free of duplicates.
        n_cols = max(0, min(n_sample_dims, hidden_dim))
        sample_cols = np.linspace(0, hidden_dim - 1, n_cols, dtype=int)
        sampled = normalized[:, sample_cols]

        return {
            "heatmap_data": sampled,
            "seq_length": seq_length,
            "hidden_dim": hidden_dim,
            "sampled_dims": sample_cols.tolist()
        }
    
    def compute_cosine_similarity(self, embedding: np.ndarray,
                                 centroid: np.ndarray) -> float:
        """
        Score how closely an embedding points in the direction of a centroid.

        Args:
            embedding: [hidden_dim] numpy array
            centroid: [hidden_dim] numpy array

        Returns:
            Cosine similarity rescaled from [-1, 1] to [0, 1]
            (0.5 means orthogonal, or that one of the vectors is all zeros)
        """
        # The small constant keeps the division defined for zero-length vectors
        denominator = np.linalg.norm(embedding) * np.linalg.norm(centroid) + 1e-8
        raw_cosine = np.dot(embedding, centroid) / denominator

        # Shift and scale: -1 -> 0, 0 -> 0.5, 1 -> 1
        return float((raw_cosine + 1.0) / 2.0)
    
    def register_emotion_centroids(self, emotion_embeddings: Dict[str, List[np.ndarray]]):
        """
        Register emotion centroids from training data.
        Each emotion has a list of embeddings; compute mean as centroid.
        
        Args:
            emotion_embeddings: Dict[emotion_name -> List[embeddings]]
        """
        for emotion, embeddings_list in emotion_embeddings.items():
            if len(embeddings_list) > 0:
                centroid = np.mean(embeddings_list, axis=0)
                self.emotion_centroids[emotion] = centroid
    
    def explain_prediction(self, input_text: str, model_output: Dict,
                          predicted_emotion: str, top_k: int = 3) -> Dict:
        """
        Generate full explainability report for a single prediction.
        
        Args:
            input_text: Original input text
            model_output: Output dict from model (contains embeddings)
            predicted_emotion: Predicted emotion name
            top_k: Top K similar emotions in embedding space
        
        Returns:
            Comprehensive explainability dict
        """
        # Tokenization
        token_info = self.tokenize_with_visualization(input_text)
        
        # Embeddings
        embedding_info = self.extract_embeddings(
            model_output["last_hidden_state"],
            model_output.get("attention_mask", torch.ones(model_output["last_hidden_state"].shape[:-1]))
        )
        
        cls_emb = embedding_info["cls_embedding"]
        
        # Heatmap
        heatmap_info = self.visualize_embedding_heatmap(embedding_info["all_token_embeddings"])
        
        # Cosine similarity to emotion centroids
        similarities = {}
        for emotion, centroid in self.emotion_centroids.items():
            sim = self.compute_cosine_similarity(cls_emb, centroid)
            similarities[emotion] = sim
        
        # Top K
        sorted_sims = sorted(similarities.items(), key=lambda x: x[1], reverse=True)
        top_k_emotions = sorted_sims[:top_k]
        
        return {
            "tokens": token_info,
            "embeddings": embedding_info,
            "heatmap": heatmap_info,
            "cosine_similarities": similarities,
            "top_k_similar_emotions": top_k_emotions,
            "predicted_emotion": predicted_emotion,
            "prediction_confidence": float(similarities.get(predicted_emotion, 0.0))
        }