"""State-native evaluator for AKSARA training. This module evaluates state traces, structural coherence, constraint satisfaction, curriculum progression, and calibration quality across validation splits. """ from __future__ import annotations from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, Iterable, Optional import json import torch from aksara.training.state_objectives import ( ConstraintSatisfactionLoss, GOSCoherenceLoss, MultiStateMarginLoss, SemanticBindingLoss, StateAlignmentLoss, compute_state_objective_loss, CurriculumObjectiveBundle, ) @dataclass class StateEvalMetrics: """Aggregate metrics used by curriculum gating.""" state_accuracy: float = 0.0 constraint_f1: float = 0.0 semantic_separation: float = 0.0 trace_coherence: float = 0.0 drift_score: float = 0.0 calibration_gap: float = 0.0 loss: float = 0.0 extra: Dict[str, float] = field(default_factory=dict) class StateEvaluator: """Evaluator for validation and curriculum gating.""" def __init__(self, objective_bundle: Optional[CurriculumObjectiveBundle] = None): self.objective_bundle = objective_bundle or CurriculumObjectiveBundle() def _to_tensor(self, value: Any, device: Optional[torch.device] = None) -> torch.Tensor: if isinstance(value, torch.Tensor): return value.to(device=device) if device is not None else value if isinstance(value, (list, tuple)): return torch.tensor(value, dtype=torch.float32, device=device) if isinstance(value, (int, float)): return torch.tensor(float(value), dtype=torch.float32, device=device) return torch.tensor(0.0, dtype=torch.float32, device=device) def _summarize_output(self, model_output: Dict[str, Any]) -> Dict[str, float]: summary = {} if "state_eval" in model_output and isinstance(model_output["state_eval"], dict): state_eval = model_output["state_eval"] if "energi_total" in state_eval: summary["state_accuracy"] = float(self._to_tensor(state_eval["energi_total"]).detach().cpu().item()) if "kelengkapan_struktur" in state_eval: summary["structure_score"] = float(self._to_tensor(state_eval["kelengkapan_struktur"]).detach().cpu().item()) if "gos_trace" in model_output and isinstance(model_output["gos_trace"], dict): trace_scores = model_output["gos_trace"].get("trace_scores") if trace_scores is not None: ts = self._to_tensor(trace_scores) summary["trace_mean"] = float(ts.mean().detach().cpu().item()) summary["trace_std"] = float(ts.std(unbiased=False).detach().cpu().item()) if ts.numel() > 1 else 0.0 return summary def evaluate(self, model: Any, batch: Optional[Dict] = None) -> StateEvalMetrics: batch = batch or {} if hasattr(model, "eval"): model.eval() model_output = {} if hasattr(model, "forward"): with torch.no_grad(): try: model_output = model.forward(batch) except TypeError: model_output = model.forward() elif callable(model): with torch.no_grad(): model_output = model(batch) losses = { "state": StateAlignmentLoss()(model_output, batch), "constraint": ConstraintSatisfactionLoss()(model_output, batch), "semantic": SemanticBindingLoss()(model_output, batch), "trace": GOSCoherenceLoss()(model_output, batch), "margin": MultiStateMarginLoss()(model_output, batch), } total_loss = sum(losses.values()) summary = self._summarize_output(model_output) return StateEvalMetrics( state_accuracy=summary.get("state_accuracy", 0.0), constraint_f1=max(0.0, 1.0 - float(losses["constraint"].detach().cpu().item())), semantic_separation=max(0.0, 1.0 - float(losses["semantic"].detach().cpu().item())), trace_coherence=max(0.0, 1.0 - float(losses["trace"].detach().cpu().item())), drift_score=float(losses["margin"].detach().cpu().item()), calibration_gap=max(0.0, float(losses["state"].detach().cpu().item())), loss=float(total_loss.detach().cpu().item()), extra=summary, ) def _load_validation_rows(self, dataset: Any) -> Iterable[dict]: if isinstance(dataset, str): path = Path(dataset) if not path.exists(): raise FileNotFoundError(f"Validation dataset not found: {dataset}") if path.suffix.lower() == ".jsonl": with open(path, encoding="utf-8") as f: for line in f: line = line.strip() if line: yield json.loads(line) return if path.suffix.lower() == ".json": with open(path, encoding="utf-8") as f: rows = json.load(f) if isinstance(rows, list): yield from rows return raise ValueError("Validation JSON must contain a list of rows") raise ValueError("Validation dataset must be .json or .jsonl") if isinstance(dataset, Iterable): yield from dataset return raise TypeError("Unsupported validation dataset type") def evaluate_validation_split(self, model: Any, dataset: Any) -> StateEvalMetrics: rows = list(self._load_validation_rows(dataset)) if not rows: return StateEvalMetrics() metrics = [] for row in rows: metrics.append(self.evaluate(model, row)) count = float(len(metrics)) aggregated = StateEvalMetrics( state_accuracy=sum(m.state_accuracy for m in metrics) / count, constraint_f1=sum(m.constraint_f1 for m in metrics) / count, semantic_separation=sum(m.semantic_separation for m in metrics) / count, trace_coherence=sum(m.trace_coherence for m in metrics) / count, drift_score=sum(m.drift_score for m in metrics) / count, calibration_gap=sum(m.calibration_gap for m in metrics) / count, loss=sum(m.loss for m in metrics) / count, ) aggregated.extra["num_rows"] = float(len(rows)) return aggregated def should_advance_phase(self, metrics: StateEvalMetrics, config: Optional[Dict] = None) -> bool: config = config or {} thresholds = config.get("thresholds", {}) return ( metrics.state_accuracy >= float(thresholds.get("state_accuracy", 0.0)) and metrics.constraint_f1 >= float(thresholds.get("constraint_f1", 0.0)) and metrics.trace_coherence >= float(thresholds.get("trace_coherence", 0.0)) ) def summarize_metrics(self, metrics: StateEvalMetrics) -> Dict[str, float]: return { "state_accuracy": metrics.state_accuracy, "constraint_f1": metrics.constraint_f1, "semantic_separation": metrics.semantic_separation, "trace_coherence": metrics.trace_coherence, "drift_score": metrics.drift_score, "calibration_gap": metrics.calibration_gap, "loss": metrics.loss, **metrics.extra, } def evaluate_state_batch(*args, **kwargs): evaluator = StateEvaluator() return evaluator.evaluate(*args, **kwargs) def evaluate_validation_split(*args, **kwargs): evaluator = StateEvaluator() return evaluator.evaluate_validation_split(*args, **kwargs) def should_advance_phase(*args, **kwargs): evaluator = StateEvaluator() return evaluator.should_advance_phase(*args, **kwargs)