""" Deep ERC Preprocessing Pipeline with Metadata and Context Windows. Handles dialogue context, metadata augmentation, and dataset preparation. """ import pandas as pd import numpy as np import torch from torch.utils.data import Dataset from transformers import RobertaTokenizer from typing import Dict, List, Tuple, Optional class ERCPreprocessor: """ Preprocesses ERC dialogue data with scenario, topic, and context window. """ def __init__(self, plutchik_dict: Dict, tokenizer_name: str = "roberta-base"): """ Initialize preprocessor. Args: plutchik_dict: Dictionary mapping emotion names to metadata tokenizer_name: HuggingFace tokenizer identifier """ self.plutchik = plutchik_dict self.tokenizer = RobertaTokenizer.from_pretrained(tokenizer_name) self.emotion_to_idx = {emotion: idx for idx, emotion in enumerate(sorted(plutchik_dict.keys()))} self.idx_to_emotion = {v: k for k, v in self.emotion_to_idx.items()} self.scenarios = ["workplace", "friendship", "family", "romance", "support", "academic", "conflict", "casual", "social", "travel", "technology", "creative", "wellbeing", "community"] # Binary Domain Mapping for Adversarial Hardening # Based on data-derived sarcasm rates (median split) # Group 0: High-Sarcasm scenarios, Group 1: Low-Sarcasm scenarios # Audit: workplace(47%), social(32%), conflict(24%), casual(22%), friendship(13%), romance(9%) self.high_sarcasm_scenarios = { "workplace", "social", "conflict", "casual", "friendship", "romance" } self.scenario_to_idx = {s: (0 if s in self.high_sarcasm_scenarios else 1) for s in self.scenarios} def augment_with_metadata(self, text: str, scenario: str, topic: str) -> str: """ Prepend scenario and topic metadata to text. Format: [SCENARIO] workplace [/SCENARIO] [TOPIC] termination [/TOPIC] Args: text: Original utterance scenario: Workplace, friendship, family, etc. topic: conversation topic Returns: Augmented text with metadata """ augmented = f"[SCENARIO] {scenario} [/SCENARIO] [TOPIC] {topic} [/TOPIC] {text}" return augmented def get_context_window(self, dialogues: List[Tuple], current_idx: int, window_size: int = 2) -> str: """ Retrieve previous N turns from the dialogue to capture emotional shift. Args: dialogues: List of (speaker, text, emotion, sarcasm_flag, emotion_cause) tuples current_idx: Index of current utterance window_size: Number of previous turns to include Returns: Concatenated context string """ context_turns = [] # Include previous turns (up to window_size) start_idx = max(0, current_idx - window_size) for idx in range(start_idx, current_idx): speaker, text, _, _, _ = dialogues[idx] # Abbreviate speaker for context speaker_abbr = speaker.split('_')[0][:3] context_turns.append(f"{speaker_abbr}: {text}") context_window = " | ".join(context_turns) if context_turns else "[NO_CONTEXT]" return context_window def prepare_sample(self, speaker: str, text: str, emotion: str, sarcasm_flag: bool, emotion_cause: Optional[str], scenario: str, topic: str, dialogues: List[Tuple], current_idx: int, iaa_score: float = 0.75, row_data: Dict = None) -> Dict: """ Prepare a single training sample with all augmentations. Args: speaker: Speaker name text: Utterance text emotion: Target emotion label sarcasm_flag: Whether utterance contains sarcasm emotion_cause: Explanation of emotion trigger scenario: Dialogue scenario topic: Dialogue topic dialogues: Full dialogue list (for context) current_idx: Current utterance index iaa_score: Inter-annotator agreement score for weighting Returns: Dict with processed sample """ # Augment with scenario and topic augmented_text = self.augment_with_metadata(text, scenario, topic) # Get context window context = self.get_context_window(dialogues, current_idx, window_size=2) # Combine augmented text with context full_input = f"[CONTEXT] {context} [/CONTEXT] [CURRENT] {augmented_text} [/CURRENT]" # Tokenize encoding = self.tokenizer( full_input, max_length=256, padding='max_length', truncation=True, return_tensors='pt' ) # Get emotion index emotion_idx = self.emotion_to_idx[emotion] # Sarcasm as binary (0 or 1) sarcasm_idx = int(sarcasm_flag) # Intensity: map primary emotions to 0.5, intense to 1.0, mild to 0.25 ring = self.plutchik[emotion].get("ring", "primary") if ring == "intense": intensity = 1.0 elif ring == "primary": intensity = 0.5 elif ring == "mild": intensity = 0.25 elif ring == "dyadic": intensity = 0.6 else: intensity = 0.5 return { "input_ids": encoding["input_ids"].squeeze(), "attention_mask": encoding["attention_mask"].squeeze(), "emotion_label": emotion_idx, "sarcasm_label": sarcasm_idx, "intensity_label": intensity, "iaa_weight": iaa_score, "emotion_name": emotion, "speaker": speaker, "scenario": scenario, "scenario_label": self.scenario_to_idx.get(scenario, 0), "topic": topic, "emotion_cause": emotion_cause if pd.notna(emotion_cause) else "Not specified", "full_text": full_input, "split": row_data.get("split", "train") if isinstance(row_data, dict) else "train" } class PlutchikERCDataset(Dataset): """ PyTorch Dataset for Plutchik ERC data with multi-task labels. """ def __init__(self, samples: List[Dict]): """ Args: samples: List of preprocessed samples from ERCPreprocessor """ self.samples = samples def __len__(self): return len(self.samples) def __getitem__(self, idx): sample = self.samples[idx] batch = { "input_ids": sample["input_ids"], "attention_mask": sample["attention_mask"], "emotion_label": torch.tensor(sample["emotion_label"], dtype=torch.long), "sarcasm_label": torch.tensor(sample["sarcasm_label"], dtype=torch.long), "intensity_label": torch.tensor(sample["intensity_label"], dtype=torch.float32).unsqueeze(0), "iaa_weight": torch.tensor(sample["iaa_weight"], dtype=torch.float32), "scenario_label": torch.tensor(sample["scenario_label"], dtype=torch.long), } # Optional Dissonance Head Fields if "context_input_ids" in sample: batch["context_input_ids"] = sample["context_input_ids"] batch["context_attention_mask"] = sample["context_attention_mask"] batch["dissonance_label"] = torch.tensor(sample["dissonance_label"], dtype=torch.float32) return batch def build_dataset_from_dialogues(dialogues_list: List[Dict], plutchik_dict: Dict, tokenizer_name: str = "roberta-base") -> PlutchikERCDataset: """ Build complete dataset from dialogue list (DIALOGUES constant format). """ preprocessor = ERCPreprocessor(plutchik_dict, tokenizer_name) samples = [] for dialogue_dict in dialogues_list: scenario = dialogue_dict["scenario"] topic = dialogue_dict["topic"] utterances = dialogue_dict["utterances"] for idx, utterance in enumerate(utterances): speaker, text, emotion, sarcasm_flag, emotion_cause = utterance sample = preprocessor.prepare_sample( speaker=speaker, text=text, emotion=emotion, sarcasm_flag=sarcasm_flag, emotion_cause=emotion_cause, scenario=scenario, topic=topic, dialogues=utterances, current_idx=idx, iaa_score=0.80 ) samples.append(sample) return PlutchikERCDataset(samples) def build_dataset_from_csv(csv_path: str, plutchik_dict: Dict, tokenizer_name: str = "roberta-base", split: str = None) -> PlutchikERCDataset: """ Build dataset by loading CSV and grouping by dialogue_id for context. Optionally filters by split (train/val/test). """ df = pd.read_csv(csv_path) if split: df = df[df["split"] == split] preprocessor = ERCPreprocessor(plutchik_dict, tokenizer_name) samples = [] # Group by dialogue_id to preserve context dialogues = df.groupby("dialogue_id") for _, group in dialogues: group = group.sort_values("turn_id") utterances_list = [] for _, row in group.iterrows(): utterances_list.append(( row["speaker"], row["text"], row["emotion"], row["sarcasm_flag"], row["emotion_cause"] )) for idx, row in group.reset_index().iterrows(): # Pass full row as dict to prepare_sample for extra metadata (like split) row_dict = row.to_dict() sample = preprocessor.prepare_sample( speaker=row["speaker"], text=row["text"], emotion=row["emotion"], sarcasm_flag=row["sarcasm_flag"], emotion_cause=row["emotion_cause"], scenario=row["scenario"], topic=row["topic"], dialogues=utterances_list, current_idx=idx, iaa_score=row["inter_annotator_agreement"], row_data=row_dict ) samples.append(sample) # FIX: was missing — caused empty dataset on all CSV loads assert len(samples) > 0, ( f"build_dataset_from_csv produced 0 samples from {csv_path}. " "Check that the CSV has the expected columns and that the split filter is correct." ) return PlutchikERCDataset(samples) def load_contrastive_pairs(jsonl_path: str, plutchik_dict: Dict, tokenizer_name: str = "roberta-base") -> List[Dict]: """ Load human-verified contrastive pairs for dissonance head training. """ import json import os preprocessor = ERCPreprocessor(plutchik_dict, tokenizer_name) samples = [] if not os.path.exists(jsonl_path): return [] with open(jsonl_path, 'r') as f: for line in f: pair = json.loads(line) # pair_verifier.py appends both 'pair' and 'twin' to the file. dissonance_score = pair.get('dissonance_score', 1.0) is_dissonant = dissonance_score > 0.5 context = pair['original_context'] if is_dissonant else pair['twin_context'] emotion = pair['original_emotion'] if is_dissonant else pair['twin_emotion'] # Prepare sample augmented_text = preprocessor.augment_with_metadata(pair['text'], pair['scenario'], "general") full_input = f"[CONTEXT] {context} [/CONTEXT] [CURRENT] {augmented_text} [/CURRENT]" encoding = preprocessor.tokenizer( full_input, max_length=256, padding='max_length', truncation=True, return_tensors='pt' ) # For the Dual-Encoder dissonance head, we also need the context *alone* ctx_encoding = preprocessor.tokenizer( context, max_length=128, padding='max_length', truncation=True, return_tensors='pt' ) samples.append({ "input_ids": encoding["input_ids"].squeeze(), "attention_mask": encoding["attention_mask"].squeeze(), "context_input_ids": ctx_encoding["input_ids"].squeeze(), "context_attention_mask": ctx_encoding["attention_mask"].squeeze(), "emotion_label": preprocessor.emotion_to_idx.get(emotion, 0), "sarcasm_label": 1 if is_dissonant else 0, "intensity_label": 0.8 if is_dissonant else 0.4, "dissonance_label": dissonance_score, "iaa_weight": 2.0, "scenario_label": preprocessor.scenario_to_idx.get(pair['scenario'], 0), "split": "train" }) return samples