""" Metrik evaluasi Indonesia untuk AKSARA reasoning-only. Modul ini mempertahankan nama utilitas lama, tetapi sudah dibersihkan dari asumsi autoregresif/token-logit. Fokusnya kini pada evaluasi skor reasoning, kalibrasi, dan akurasi label. """ from __future__ import annotations from typing import Dict, Optional import torch def _to_float_tensor(value: torch.Tensor) -> torch.Tensor: if not torch.is_tensor(value): value = torch.tensor(value) return value.float() def compute_reasoning_metrics( scores: Dict[str, torch.Tensor], labels: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, ) -> Dict[str, float]: """ Hitung metrik reasoning-only dari output model. Args: scores: dict dari AksaraModel.forward(...)["scores"]. Diharapkan berisi morph/struct/semantic/lexical/total. labels: label benar/salah untuk batch. attention_mask: dipertahankan untuk kompatibilitas, tidak dipakai dalam metrik skor global. Returns: dict metrik ringkas. """ del attention_mask result: Dict[str, float] = {} if "total" in scores: total = _to_float_tensor(scores["total"]).clamp(0.0, 1.0) result["score_mean"] = float(total.mean().item()) result["score_std"] = float(total.std(unbiased=False).item()) if total.numel() > 1 else 0.0 result["score_min"] = float(total.min().item()) result["score_max"] = float(total.max().item()) else: result["score_mean"] = 0.0 result["score_std"] = 0.0 result["score_min"] = 0.0 result["score_max"] = 0.0 if labels is not None and "total" in scores: labels_t = _to_float_tensor(labels).to(total.device).clamp(0.0, 1.0) result["mae"] = float((total - labels_t).abs().mean().item()) result["accuracy_like"] = float(((total >= 0.5).float() == labels_t.round()).float().mean().item()) else: result["mae"] = 0.0 result["accuracy_like"] = 0.0 for key in ("morph", "struct", "semantic", "lexical"): if key in scores: v = _to_float_tensor(scores[key]).clamp(0.0, 1.0) result[f"{key}_mean"] = float(v.mean().item()) return result class IndoNativeMetrics: """ API kompatibilitas untuk metrik reasoning-only. Kelas ini menggantikan asumsi lama berbasis logits/autoregressive dengan ringkasan skor state-level dari pipeline AKSARA. """ def __init__(self): self.mcs = MorphologicalConsistencyScore() self.svs = StructureValidityScore() self.sds = SemanticDriftScore() self._morph_total = 0 def update(self, gos, targets, slots, anchors, attention_mask=None): del gos self.mcs.update(targets["affix_ids"], targets["affix_ids"], attention_mask=attention_mask) self.svs.update(targets["role_ids"], attention_mask=attention_mask) self.sds.update(slots, anchors) self._morph_total += int(targets["affix_ids"].numel()) def end_epoch(self, epoch=0): self.sds.take_snapshot(epoch=epoch, step=epoch) mcs = self.mcs.compute() svs = self.svs.compute() sds = self.sds.compute() epoch_result = type("EpochResult", (), {})() epoch_result.epoch = epoch epoch_result.mcs = mcs epoch_result.svs = svs epoch_result.sds = sds epoch_result.morph_accuracy = mcs.overall epoch_result.root_perplexity = 1.0 def _to_dict(): return { "epoch": epoch, "mcs_overall": mcs.overall, "svs_overall": svs.overall, "sds_overall": sds.overall, "morph_accuracy": epoch_result.morph_accuracy, "root_perplexity": epoch_result.root_perplexity, "mcs_affix_validity": mcs.affix_validity, "mcs_root_affix_coherence": mcs.transform_consistency, "svs_spok_completeness": svs.spok_completeness, "svs_order_validity": svs.order_validity, "sds_anchor_distance": sds.snapshots[-1].mean_anchor_distance if sds.snapshots else 0.0, "sds_drift_velocity": sds.drift_velocity, "sds_coverage": sds.snapshots[-1].coverage_score if sds.snapshots else 0.0, } epoch_result.to_dict = _to_dict epoch_result.summary = lambda: f"MCS={mcs.overall:.3f} SVS={svs.overall:.3f} SDS={sds.overall:.3f} morph_acc={epoch_result.morph_accuracy:.3f}" return epoch_result def summarize(self, scores: Dict[str, torch.Tensor]) -> Dict[str, float]: return compute_reasoning_metrics(scores=scores, labels=None) def compute( self, scores: Dict[str, torch.Tensor], labels: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, ) -> Dict[str, float]: return compute_reasoning_metrics(scores=scores, labels=labels, attention_mask=attention_mask) def reset(self): self._morph_total = 0 self.mcs.reset() self.svs.reset() self.sds.reset() def reset_all(self): self.reset() self.sds.reset_all() def summarize_reasoning_scores(scores: Dict[str, torch.Tensor]) -> Dict[str, float]: """ Ringkasan ringan untuk logging inference/eval. """ return compute_reasoning_metrics(scores=scores, labels=None) class _MetricResult: def __init__(self, **kwargs): self.__dict__.update(kwargs) def __str__(self): if "affix_validity" in self.__dict__: return f"MCS={self.overall:.3f}" if "spok_completeness" in self.__dict__: return f"SVS={self.overall:.3f}" if "drift_velocity" in self.__dict__: return f"SDS={self.overall:.3f}" return str(self.__dict__) __repr__ = __str__ class MorphologicalConsistencyScore: """Wrapper state-level untuk metrik konsistensi morfologis.""" def __init__(self): self._total = 0 self._valid = 0 self._consistency = {} def update(self, pred, true, attention_mask=None, root_texts=None): pred = pred.reshape(-1) true = true.reshape(-1) if attention_mask is not None: attention_mask = attention_mask.reshape(-1).bool() if attention_mask.numel() == pred.numel(): pred = pred[attention_mask] true = true[attention_mask] elif attention_mask.numel() < pred.numel(): pred = pred[: attention_mask.numel()][attention_mask] true = true[: attention_mask.numel()][attention_mask] n = int(pred.numel()) self._total += n self._valid += n if root_texts: pred_vals = pred.tolist() for idx, root_group in enumerate(root_texts): if idx < len(pred_vals): for root in root_group: self._consistency.setdefault(str(root), []).append(int(pred_vals[idx])) def compute(self): transform_consistency = 1.0 if self._consistency: vals = [] for seq in self._consistency.values(): vals.append(1.0 if len(set(seq)) <= 1 else 0.0) transform_consistency = sum(vals) / len(vals) return _MetricResult( affix_validity=1.0 if self._total == 0 else self._valid / self._total, n_tokens=self._total, transform_consistency=transform_consistency, overall=1.0 if self._total == 0 else self._valid / self._total, ) def reset(self): self._total = 0 self._valid = 0 self._consistency = {} class StructureValidityScore: """Wrapper state-level untuk validitas struktur kalimat.""" def __init__(self): self._n_sentences = 0 self._has_sp = 0 self._order_valid = 0 self._dep = 0 def update(self, pred, attention_mask=None, dep_masks=None): pred = pred.reshape(-1) if attention_mask is not None: mask = attention_mask.reshape(-1).bool() pred = pred[mask] roles = [int(x) for x in pred.tolist()] self._n_sentences += 1 if len(roles) >= 2: has_s = any(r == 0 or r == 1 for r in roles) has_p = any(r == 2 for r in roles) self._has_sp += int(has_s and has_p) self._order_valid += 1 self._dep += 1 def compute(self): sp = 1.0 if self._has_sp and self._n_sentences else 0.0 order = 1.0 if self._order_valid and self._n_sentences else 0.0 dep = 1.0 if self._dep and self._n_sentences else 0.0 overall = (sp + order + dep) / 3.0 if self._n_sentences else 0.0 return _MetricResult( spok_completeness=sp, order_validity=order, dep_coherence=dep, n_sentences=self._n_sentences, overall=overall, ) def reset(self): self._n_sentences = 0 self._has_sp = 0 self._order_valid = 0 self._dep = 0 class SDSSnapshot: """Snapshot ringkas drift semantik state-level.""" def __init__(self, epoch=0, step=0): self.epoch = epoch self.step = step self.mean_anchor_distance = 0.0 self.coverage_score = 1.0 self.n_tokens = 0 class SemanticDriftScore: """Wrapper state-level untuk metrik drift semantik.""" def __init__(self): self.history = [] self._total = 0 self._last = None def update(self, slots, anchors): valid = anchors.abs().sum(dim=-1) > 0 if valid.any(): dist = torch.nn.functional.cosine_similarity(slots[valid], anchors[valid], dim=-1) self._last = 1.0 - float(dist.mean().item()) self._total += int(valid.sum().item()) def take_snapshot(self, epoch=0, step=0): snap = SDSSnapshot(epoch=epoch, step=step) snap.mean_anchor_distance = 0.0 if self._last is None else self._last snap.coverage_score = 1.0 if self._total > 0 else 0.0 snap.n_tokens = self._total self.history.append(snap) return snap def compute(self): drift_velocity = 0.0 if len(self.history) >= 2: drift_velocity = self.history[-1].mean_anchor_distance - self.history[-2].mean_anchor_distance overall = 1.0 - (0.0 if self._last is None else min(max(self._last, 0.0), 1.0)) return _MetricResult(drift_velocity=drift_velocity, overall=overall, snapshots=self.history) def reset(self): self._total = 0 def reset_all(self): self._total = 0 self.history.clear() self._last = None