"""
Explainability Engine v2.1 — Feature Attribution via Captum

Implements Integrated Gradients for token-level contribution analysis.
"""

import re
from typing import Dict, List, Tuple

import numpy as np
import torch
from captum.attr import IntegratedGradients, LayerIntegratedGradients

# Tag names used as bracket markup inside ``full_input``
# (e.g. "[CURRENT] ... [/CURRENT]").
_SPAN_TAG_NAMES = ("CURRENT", "CONTEXT", "SCENARIO", "TOPIC")
# Matches exactly one opening or closing tag. Group 1 is "/" for a closing tag
# (empty otherwise), group 2 is the tag name. Case-insensitive, matching the
# previous upper-casing heuristic, but it only fires on a complete tag rather
# than on any word that happens to contain e.g. "URRENT".
_SPAN_TAG_RE = re.compile(
    r"\[(/?)(" + "|".join(_SPAN_TAG_NAMES) + r")\]", re.IGNORECASE
)
# Leading-space marker used by byte-level BPE tokenizers (e.g. RoBERTa).
_BPE_SPACE = "Ġ"
# Tokens made only of bracket/slash punctuation carry no useful signal for
# the dashboard and are dropped, as before.
_BRACKET_ONLY_TOKENS = frozenset({"[", "]", "/", "[/"})


def _token_char_spans(tokens: List[str]) -> Tuple[str, List[Tuple[int, int]]]:
    """Rebuild text from ``tokens`` and return it with each token's [start, end) range.

    Each token contributes its string with the BPE space marker turned into a
    real space. The markup tags are ASCII, so they are reproduced exactly and
    can be located with ``_SPAN_TAG_RE`` regardless of how the tokenizer split
    them into pieces.
    """
    pieces = [(token or "").replace(_BPE_SPACE, " ") for token in tokens]
    spans: List[Tuple[int, int]] = []
    offset = 0
    for piece in pieces:
        spans.append((offset, offset + len(piece)))
        offset += len(piece)
    return "".join(pieces), spans


def _locate_markup(text: str) -> Tuple[List[Tuple[int, int]], Tuple[int, int]]:
    """Return the character ranges of all tags in ``text`` and of the current span.

    The current span starts right after the first ``[CURRENT]`` tag, or at the
    beginning of ``text`` if there is none. It ends at the first
    ``[/CURRENT]`` tag found at or after that start, or at the end of ``text``
    (e.g. when the closing tag was truncated away).
    """
    matches = list(_SPAN_TAG_RE.finditer(text))
    tag_ranges = [(m.start(), m.end()) for m in matches]
    current_tags = [m for m in matches if m.group(2).upper() == "CURRENT"]
    start = next((m.end() for m in current_tags if not m.group(1)), 0)
    end = next(
        (m.start() for m in current_tags if m.group(1) and m.start() >= start),
        len(text),
    )
    return tag_ranges, (start, end)


class CaptumExplainer:
    """
    Advanced explainability using Captum attribution.

    ``model`` is called as ``model(input_ids, attention_mask)`` and must return
    a mapping containing ``"emotion_logits"``. It must also expose a
    ``roberta.embeddings`` submodule, which is the layer attributions are
    taken at. ``tokenizer`` is used for encoding (``return_tensors='pt'``),
    ``convert_ids_to_tokens``, ``pad_token_id``/``pad_token`` and, when
    available, ``all_special_tokens``.
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def attribute_tokens(
        self,
        full_input: str,
        target_class: int,
        n_steps: int = 20,
        *,
        max_length: int = 256,
        top_k: int = 10,
    ) -> Dict[str, object]:
        """
        Compute Integrated Gradients attribution for tokens in the full input string.

        Attributions are computed with ``LayerIntegratedGradients`` at the
        output of ``model.roberta.embeddings``, against a baseline made of
        ``pad_token_id`` everywhere. They are summed over the last (embedding)
        dimension and L2-normalised across the sequence.

        Markup tags ([CURRENT], [/CURRENT], [CONTEXT], [SCENARIO], [TOPIC] and
        their closing forms, case-insensitive) are excluded from the output,
        as are special tokens and padding. Tokens between the first [CURRENT]
        and the next [/CURRENT] form the current span. Every other kept token
        is context. Without a [CURRENT] tag, all kept tokens count as current.

        Args:
            full_input: Text to explain, optionally containing the tags above.
            target_class: Index into ``emotion_logits`` to attribute.
            n_steps: Number of Integrated Gradients integration steps.
            max_length: Tokenizer truncation/padding length.
            top_k: Number of entries kept in each ``*_span_top`` list.

        Returns:
            A dict with:
              ``token_attributions``: every kept token, in sequence order, as
                ``{"token": str, "score": float}``;
              ``context_span_top`` / ``current_span_top``: up to ``top_k``
                entries from each span, sorted by ``abs(score)`` descending;
              ``convergence_delta``: largest absolute Captum convergence delta.
        """
        self.model.eval()
        self.model.zero_grad()

        device = next(self.model.parameters()).device
        encoding = self.tokenizer(
            full_input,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
            padding="max_length",
        )
        input_ids = encoding["input_ids"].to(device)
        attention_mask = encoding["attention_mask"].to(device)

        # Captum needs a forward that returns a plain tensor, so select the
        # emotion head from the model's output mapping.
        def model_forward_wrapper(ids, mask):
            outputs = self.model(ids, mask)
            return outputs["emotion_logits"]

        lig = LayerIntegratedGradients(
            model_forward_wrapper, self.model.roberta.embeddings
        )
        baseline_ids = torch.full_like(input_ids, self.tokenizer.pad_token_id)

        attributions, delta = lig.attribute(
            inputs=input_ids,
            baselines=baseline_ids,
            additional_forward_args=(attention_mask,),
            target=target_class,
            n_steps=n_steps,
            return_convergence_delta=True,
        )

        # Summarize across embedding dimensions, then normalize over tokens.
        attributions = attributions.sum(dim=-1).squeeze(0)
        attributions = attributions / (torch.norm(attributions) + 1e-8)
        # One device->host transfer instead of a .item() per token.
        scores = attributions.detach().cpu().tolist()

        # Index row 0 explicitly: ``squeeze()`` would turn a length-1
        # sequence into a scalar and break the token list.
        tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
        keep_mask = attention_mask[0].tolist()

        text, char_spans = _token_char_spans(tokens)
        tag_ranges, (current_start, current_end) = _locate_markup(text)

        special_tokens = set(getattr(self.tokenizer, "all_special_tokens", ()))
        pad_token = self.tokenizer.pad_token

        all_results: List[Dict[str, object]] = []
        context_results: List[Dict[str, object]] = []
        current_results: List[Dict[str, object]] = []

        for i, (token, (start, end)) in enumerate(zip(tokens, char_spans)):
            if (
                not token
                or not keep_mask[i]
                or token in special_tokens
                or token == pad_token
            ):
                continue
            # Drop any token that overlaps a tag, e.g. "Ġ[", "C", "URRENT", "]".
            if any(start < tag_end and tag_start < end for tag_start, tag_end in tag_ranges):
                continue
            display = token.replace(_BPE_SPACE, " ")
            if display.strip() in _BRACKET_ONLY_TOKENS:
                continue

            entry = {"token": display, "score": float(scores[i])}
            all_results.append(entry)
            if current_start <= start < current_end:
                current_results.append(entry)
            else:
                context_results.append(entry)

        def by_magnitude(entries: List[Dict[str, object]]) -> List[Dict[str, object]]:
            return sorted(entries, key=lambda x: abs(x["score"]), reverse=True)[:top_k]

        return {
            "token_attributions": all_results,
            "context_span_top": by_magnitude(context_results),
            "current_span_top": by_magnitude(current_results),
            "convergence_delta": float(delta.detach().abs().max().item()),
        }