Spaces:
Sleeping
Sleeping
| """ | |
| Plutchik ERC v2.1 — Antigravity Training Harness | |
| Supports FP16, Macro-F1 Checkpointing, and CSV-based ingestion. | |
| """ | |
| import os | |
| import torch | |
| from torch.utils.data import DataLoader, ConcatDataset | |
| import sys | |
| from pathlib import Path | |
| import json | |
| import random | |
| import numpy as np | |
# Bug 6 Fix: seed every RNG the pipeline touches so runs are reproducible.
SEED = 42


def _seed_everything(seed):
    """Seed Python's, NumPy's and PyTorch's random number generators.

    When CUDA is available, all CUDA devices are seeded as well and cuDNN is
    switched to deterministic mode with autotuning (benchmark) disabled.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


_seed_everything(SEED)
# Project directory (the folder containing this script).
# Resolved to an absolute path: if __file__ is relative, Path(__file__).parent
# can be "." and "." has itself as parent, which would break every later
# `project_dir.parent / ...` lookup.
project_dir = Path(__file__).resolve().parent
| from models.multitask_emotion_model import PluTchikMultiTaskModel, MultiTaskLoss | |
| from utils.preprocessing import build_dataset_from_csv, load_contrastive_pairs, PlutchikERCDataset | |
| from utils.trainer import PluTchikTrainer | |
# ============== CONFIGURATION ==============
# Single source of run settings read by run_antigravity_training().
CONFIG = {
    # --- Paths ---
    "csv_path": "data/processed/ERC/plutchik_v2_production.csv",
    "model_dir": "my_plutchik_model",

    # --- Optimisation ---
    # Batch size reduced to 4 to fix MPS OOM on M1.
    "batch_size": 4,
    # One epoch only: fast prototype generation.
    "epochs": 1,
    "lr": 5e-5,
    # Shorter sequences for faster processing.
    "max_len": 128,

    # --- Loss / adversarial settings ---
    "iaa_weighting": True,
    "adv_weight": 0.3,
    "warmup_epochs": 0,
    "grl_lambda_max": 0.5,
}
| # ============== PLUTCHIK CONSTANTS ============== | |
| from utils.constants import PLUTCHIK, NUM_EMOTIONS | |
def _select_device():
    """Return the torch device string to train on: "cuda", "mps" or "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps is missing on PyTorch builds that predate the MPS
    # backend; treat its absence as "MPS not available" instead of crashing.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"


def _resolve_csv_path(csv_path):
    """Locate the training CSV and return its path.

    Lookup order: relative to the current working directory, relative to the
    parent of the project directory, then relative to the project directory.

    Raises:
        FileNotFoundError: if the file exists at none of those locations.
    """
    configured = Path(csv_path)
    candidates = [
        configured.resolve(),
        project_dir.parent / configured,
        project_dir / configured,
    ]
    for candidate in candidates:
        if candidate.is_file():
            return candidate
    tried = ", ".join(str(c) for c in candidates)
    raise FileNotFoundError(f"Training CSV not found. Tried: {tried}")


def _merge_verified_cda(train_ds):
    """Optionally append human-verified CDA samples to the training set.

    Controlled by two environment variables:
      * PLUTCHIK_CDA_JSONL     - path to the contrastive-pair JSONL file;
                                 when unset/blank, ``train_ds`` is returned as is.
      * PLUTCHIK_CDA_MIN_PAIRS - minimum number of verified rows required
                                 before the merge happens (default 200).

    Returns:
        ``train_ds`` unchanged, or a ConcatDataset of ``train_ds`` and the
        CDA samples when the file is found and the minimum is met.

    Raises:
        ValueError: if PLUTCHIK_CDA_MIN_PAIRS is set to a non-integer value.
    """
    cda_path = os.environ.get("PLUTCHIK_CDA_JSONL", "").strip()
    if not cda_path:
        return train_ds

    # Parsed only when CDA is actually requested, so a stray value cannot
    # break runs that do not use CDA at all.
    raw_min = os.environ.get("PLUTCHIK_CDA_MIN_PAIRS", "200")
    try:
        cda_min = int(raw_min)
    except ValueError as err:
        raise ValueError(
            f"PLUTCHIK_CDA_MIN_PAIRS must be an integer, got {raw_min!r}"
        ) from err

    candidates = [Path(cda_path), project_dir / cda_path, project_dir.parent / cda_path]
    cda_file = next((p for p in candidates if p.is_file()), None)
    if cda_file is None:
        print(f"β CDA file not found (tried): {cda_path}")
        return train_ds

    cda_samples = load_contrastive_pairs(str(cda_file), PLUTCHIK, only_verified=True)
    if len(cda_samples) < cda_min:
        print(
            f"β PLUTCHIK_CDA_JSONL set but only {len(cda_samples)} verified rows "
            f"(need β₯{cda_min}). Skipping CDA merge β export more via pair_verifier."
        )
        return train_ds

    print(f"β Merged {len(cda_samples)} human-verified CDA rows (gate β₯{cda_min}).")
    return ConcatDataset([train_ds, PlutchikERCDataset(cda_samples)])


def run_antigravity_training():
    """Run the full training pipeline configured by the module-level CONFIG.

    Steps: pick a device, build the train/val datasets from the production
    CSV (selected via its ``split`` argument), optionally merge verified CDA
    samples, create the data loaders, model, loss and trainer, then call
    ``trainer.fit``.

    Returns:
        Whatever ``PluTchikTrainer.fit`` returns (the training history).

    Raises:
        FileNotFoundError: if the training CSV cannot be located.
        ValueError: if the training set is empty, or PLUTCHIK_CDA_MIN_PAIRS
            is not an integer while PLUTCHIK_CDA_JSONL is set.
    """
    banner = "=" * 60
    print(banner)
    print("π ANTIGRAVITY TRAINING HARNESS β PLUTCHIK ERC v2.1")
    print(banner)

    device = _select_device()
    print(f"\nβ Target Device: {device}")

    # 1. Build Dataset from Production CSV
    print(f"\nπ Ingesting production data from {CONFIG['csv_path']}...")
    csv_abs_path = _resolve_csv_path(CONFIG["csv_path"])

    # Bug 1 Fix: Load train and val separately based on the 'split' column
    train_ds = build_dataset_from_csv(
        str(csv_abs_path),
        PLUTCHIK,
        tokenizer_name="roberta-base",
        split="train",
    )
    val_ds = build_dataset_from_csv(
        str(csv_abs_path),
        PLUTCHIK,
        tokenizer_name="roberta-base",
        split="val",
    )
    print(f"β Train set: {len(train_ds)} samples")
    print(f"β Val set: {len(val_ds)} samples")

    train_ds = _merge_verified_cda(train_ds)

    # Fail early with a clear message rather than deep inside the DataLoader
    # or trainer when the 'train' split yielded nothing.
    if len(train_ds) == 0:
        raise ValueError(
            f"No training samples found for split='train' in {csv_abs_path}"
        )

    # 2. Setup DataLoaders
    train_loader = DataLoader(train_ds, batch_size=CONFIG["batch_size"], shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=CONFIG["batch_size"], shuffle=False)

    # 3. Model Initialization
    print("\nπ§ Initializing RoBERTa-base Multi-Task Model...")
    model = PluTchikMultiTaskModel(num_emotions=NUM_EMOTIONS)

    # 4. Loss Function (IAA Aware)
    loss_fn = MultiTaskLoss(
        emotion_weight=1.0,
        sarcasm_weight=0.7,
        intensity_weight=0.5,
        adv_weight=CONFIG["adv_weight"],
        iaa_weighting=CONFIG["iaa_weighting"],
    )

    # 5. Trainer (Hardware-Aware with FP16)
    # A relative model_dir is anchored to the project directory, not the CWD.
    _model_dir = Path(CONFIG["model_dir"])
    if not _model_dir.is_absolute():
        _model_dir = project_dir / _model_dir
    trainer = PluTchikTrainer(
        model=model,
        loss_fn=loss_fn,
        device=device,
        save_dir=str(_model_dir),
    )

    # 6. Execute Training
    print(f"\nπ₯ Commencing training for {CONFIG['epochs']} epochs...")
    history = trainer.fit(
        train_loader,
        val_loader,
        epochs=CONFIG["epochs"],
        learning_rate=CONFIG["lr"],
        warmup_epochs=CONFIG["warmup_epochs"],
        grl_lambda_max=CONFIG["grl_lambda_max"],
    )

    # 7. Finalize
    print("\n" + banner)
    print("β¨ TRAINING COMPLETE")
    print(f"Best model saved to: {_model_dir / 'best_model.pt'}")
    print(banner)
    return history
# Script entry point: start training only when this file is executed
# directly, never as a side effect of importing it.
if __name__ == "__main__":
    run_antigravity_training()