| """ |
| Metrik evaluasi Indonesia untuk AKSARA reasoning-only. |
| |
| Modul ini mempertahankan nama utilitas lama, tetapi sudah dibersihkan dari |
| asumsi autoregresif/token-logit. Fokusnya kini pada evaluasi skor reasoning, |
| kalibrasi, dan akurasi label. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from typing import Dict, Optional |
|
|
| import torch |
|
|
|
|
def _to_float_tensor(value: torch.Tensor) -> torch.Tensor:
    """Return ``value`` as a float32 tensor.

    Non-tensor inputs (Python scalars, lists, ...) are first wrapped with
    ``torch.tensor``; tensors are converted in place of a copy where possible
    by ``Tensor.float``.
    """
    as_tensor = value if torch.is_tensor(value) else torch.tensor(value)
    return as_tensor.float()
|
|
|
|
def compute_reasoning_metrics(
    scores: Dict[str, torch.Tensor],
    labels: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> Dict[str, float]:
    """
    Compute reasoning-only metrics from model score outputs.

    Args:
        scores: mapping such as ``AksaraModel.forward(...)["scores"]``.
            Expected keys are morph/struct/semantic/lexical/total; missing
            keys are tolerated. All values are clamped to [0, 1].
        labels: per-item targets for the batch (clamped to [0, 1]). When it
            has the same number of elements as ``scores["total"]`` both are
            flattened so they are compared element-wise regardless of shape.
        attention_mask: accepted only for backward compatibility; it is not
            used by these global score metrics.

    Returns:
        Dict with ``score_mean``/``score_std``/``score_min``/``score_max``,
        ``mae``/``accuracy_like`` (0.0 without usable labels) and a
        ``<component>_mean`` entry for every component key present.
        Empty tensors yield 0.0 instead of NaN or an exception.
    """
    del attention_mask  # kept only so existing callers can still pass it

    result: Dict[str, float] = {}

    total: Optional[torch.Tensor] = None
    if "total" in scores:
        candidate = _to_float_tensor(scores["total"]).clamp(0.0, 1.0)
        # min()/max() raise on an empty tensor and mean() returns NaN, so an
        # empty "total" is treated like a missing one.
        if candidate.numel() > 0:
            total = candidate

    if total is not None:
        result["score_mean"] = float(total.mean().item())
        result["score_std"] = float(total.std(unbiased=False).item()) if total.numel() > 1 else 0.0
        result["score_min"] = float(total.min().item())
        result["score_max"] = float(total.max().item())
    else:
        result["score_mean"] = 0.0
        result["score_std"] = 0.0
        result["score_min"] = 0.0
        result["score_max"] = 0.0

    result["mae"] = 0.0
    result["accuracy_like"] = 0.0
    if labels is not None and total is not None:
        labels_t = _to_float_tensor(labels).to(total.device).clamp(0.0, 1.0)
        if labels_t.numel() > 0:
            preds = total
            if labels_t.numel() == preds.numel():
                # Same element count but possibly different shapes, e.g.
                # (B, 1) vs (B,): without flattening, broadcasting would
                # compare every prediction with every label.
                preds = preds.reshape(-1)
                labels_t = labels_t.reshape(-1)
            result["mae"] = float((preds - labels_t).abs().mean().item())
            hits = (preds >= 0.5).float() == labels_t.round()
            result["accuracy_like"] = float(hits.float().mean().item())

    for key in ("morph", "struct", "semantic", "lexical"):
        if key in scores:
            component = _to_float_tensor(scores[key]).clamp(0.0, 1.0)
            result[f"{key}_mean"] = float(component.mean().item()) if component.numel() > 0 else 0.0

    return result
|
|
|
|
class IndoNativeMetrics:
    """
    Compatibility facade for the reasoning-only metrics.

    Keeps the call surface of the older logits/autoregressive metric object
    (``update`` / ``end_epoch`` / ``reset``) but reports state-level summaries
    from three wrappers: morphological consistency (``mcs``), structure
    validity (``svs``) and semantic drift (``sds``).
    """

    def __init__(self):
        self.mcs = MorphologicalConsistencyScore()
        self.svs = StructureValidityScore()
        self.sds = SemanticDriftScore()
        # Running count of affix-id elements passed to ``update``.
        self._morph_total = 0

    def update(self, gos, targets, slots, anchors, attention_mask=None):
        """Feed one batch into every wrapper.

        ``gos`` is accepted for signature compatibility and discarded.
        ``targets`` must provide ``"affix_ids"`` and ``"role_ids"``.
        """
        del gos
        affix_ids = targets["affix_ids"]
        # No separate prediction is available here, so the target ids are
        # used as both the prediction and the reference.
        self.mcs.update(affix_ids, affix_ids, attention_mask=attention_mask)
        self.svs.update(targets["role_ids"], attention_mask=attention_mask)
        self.sds.update(slots, anchors)
        self._morph_total += int(affix_ids.numel())

    def end_epoch(self, epoch=0):
        """Snapshot drift state, compute every wrapper and bundle the results.

        Returns an ad-hoc ``EpochResult`` object exposing ``epoch``, ``mcs``,
        ``svs``, ``sds``, ``morph_accuracy`` (the MCS overall score) and
        ``root_perplexity`` (constant 1.0), plus ``to_dict()`` and
        ``summary()`` callables that read those values when invoked.
        """
        self.sds.take_snapshot(epoch=epoch, step=epoch)
        mcs = self.mcs.compute()
        svs = self.svs.compute()
        sds = self.sds.compute()

        bundle = type("EpochResult", (), {})()
        bundle.epoch = epoch
        bundle.mcs, bundle.svs, bundle.sds = mcs, svs, sds
        bundle.morph_accuracy = mcs.overall
        # There is no token-level root model, so this is a fixed placeholder.
        bundle.root_perplexity = 1.0

        def _latest(field):
            # Read the newest snapshot at call time; 0.0 when there is none.
            snapshots = sds.snapshots
            return getattr(snapshots[-1], field) if snapshots else 0.0

        def _to_dict():
            return {
                "epoch": epoch,
                "mcs_overall": mcs.overall,
                "svs_overall": svs.overall,
                "sds_overall": sds.overall,
                "morph_accuracy": bundle.morph_accuracy,
                "root_perplexity": bundle.root_perplexity,
                "mcs_affix_validity": mcs.affix_validity,
                "mcs_root_affix_coherence": mcs.transform_consistency,
                "svs_spok_completeness": svs.spok_completeness,
                "svs_order_validity": svs.order_validity,
                "sds_anchor_distance": _latest("mean_anchor_distance"),
                "sds_drift_velocity": sds.drift_velocity,
                "sds_coverage": _latest("coverage_score"),
            }

        def _summary():
            return (
                f"MCS={mcs.overall:.3f} SVS={svs.overall:.3f} "
                f"SDS={sds.overall:.3f} morph_acc={bundle.morph_accuracy:.3f}"
            )

        bundle.to_dict = _to_dict
        bundle.summary = _summary
        return bundle

    def summarize(self, scores: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """Label-free score summary (delegates to ``compute_reasoning_metrics``)."""
        return compute_reasoning_metrics(scores, labels=None)

    def compute(
        self,
        scores: Dict[str, torch.Tensor],
        labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, float]:
        """Full metric dict for ``scores`` (delegates to ``compute_reasoning_metrics``)."""
        return compute_reasoning_metrics(scores, labels, attention_mask)

    def reset(self):
        """Clear per-epoch accumulators; SDS snapshot history is kept."""
        self._morph_total = 0
        for tracker in (self.mcs, self.svs, self.sds):
            tracker.reset()

    def reset_all(self):
        """Per-epoch reset plus clearing the SDS history and last distance."""
        self.reset()
        self.sds.reset_all()
|
|
|
|
def summarize_reasoning_scores(scores: Dict[str, torch.Tensor]) -> Dict[str, float]:
    """
    Lightweight, label-free summary of model scores for inference/eval logging.

    Delegates to ``compute_reasoning_metrics`` without labels, so the ``mae``
    and ``accuracy_like`` entries are always 0.0.
    """
    return compute_reasoning_metrics(scores, labels=None)
|
|
|
|
class _MetricResult:
    """Attribute bag returned by the metric wrappers' ``compute`` methods.

    Every keyword argument becomes an instance attribute. ``str``/``repr``
    render ``<LABEL>=<overall>`` when a metric-specific marker attribute is
    present, and otherwise fall back to the raw attribute dict.
    """

    def __init__(self, **kwargs):
        vars(self).update(kwargs)

    def __str__(self):
        attrs = vars(self)
        # Checked in order: the first marker attribute found picks the label.
        markers = (
            ("affix_validity", "MCS"),
            ("spok_completeness", "SVS"),
            ("drift_velocity", "SDS"),
        )
        for marker, label in markers:
            if marker in attrs:
                return f"{label}={self.overall:.3f}"
        return str(attrs)

    __repr__ = __str__
|
|
|
|
class MorphologicalConsistencyScore:
    """State-level wrapper for the morphological-consistency metric.

    Counts (optionally masked) affix predictions and, when root strings are
    supplied, records every prediction seen for each root so ``compute`` can
    report how many roots always received a single prediction.
    """

    def __init__(self):
        self._total = 0
        self._valid = 0
        # root string -> list of predicted ids observed for it
        self._consistency = {}

    def update(self, pred, true, attention_mask=None, root_texts=None):
        """Accumulate one batch.

        ``pred``/``true`` are flattened. A mask with as many elements as
        ``pred`` selects directly; a shorter mask is applied to the leading
        elements only; a longer mask is ignored. Every kept element is counted
        as valid. ``root_texts[i]`` is a group of roots attached to the i-th
        kept prediction; groups beyond the kept predictions are skipped.
        """
        flat_pred = pred.reshape(-1)
        flat_true = true.reshape(-1)
        if attention_mask is not None:
            keep = attention_mask.reshape(-1).bool()
            mask_len, pred_len = keep.numel(), flat_pred.numel()
            if mask_len == pred_len:
                flat_pred = flat_pred[keep]
                flat_true = flat_true[keep]
            elif mask_len < pred_len:
                flat_pred = flat_pred[:mask_len][keep]
                flat_true = flat_true[:mask_len][keep]

        count = int(flat_pred.numel())
        self._total += count
        self._valid += count

        if not root_texts:
            return
        values = flat_pred.tolist()
        for position, roots in enumerate(root_texts):
            if position >= len(values):
                continue
            for root in roots:
                self._consistency.setdefault(str(root), []).append(int(values[position]))

    def compute(self):
        """Return an MCS ``_MetricResult`` (ratios default to 1.0 with no data)."""
        if self._consistency:
            flags = [
                1.0 if len(set(history)) <= 1 else 0.0
                for history in self._consistency.values()
            ]
            transform_consistency = sum(flags) / len(flags)
        else:
            transform_consistency = 1.0
        ratio = self._valid / self._total if self._total else 1.0
        return _MetricResult(
            affix_validity=ratio,
            n_tokens=self._total,
            transform_consistency=transform_consistency,
            overall=ratio,
        )

    def reset(self):
        """Drop all accumulated counts and per-root histories."""
        self._total = 0
        self._valid = 0
        self._consistency = {}
|
|
|
|
class StructureValidityScore:
    """State-level wrapper for sentence-structure validity.

    Each ``update`` call counts as one sentence. ``compute`` reports, per
    sentence, the fraction that contained both a subject-like and a
    predicate-like role (``spok_completeness``) and the fractions counted as
    order-valid and dependency-coherent, plus their mean as ``overall``.
    """

    def __init__(self):
        self._n_sentences = 0
        self._has_sp = 0
        self._order_valid = 0
        self._dep = 0

    def update(self, pred, attention_mask=None, dep_masks=None):
        """Accumulate one sentence of role ids.

        ``pred`` is flattened and, if given, filtered by a boolean-castable
        ``attention_mask`` of the same number of elements. ``dep_masks`` is
        accepted for compatibility and is not used.
        """
        pred = pred.reshape(-1)
        if attention_mask is not None:
            mask = attention_mask.reshape(-1).bool()
            pred = pred[mask]
        roles = [int(x) for x in pred.tolist()]
        self._n_sentences += 1
        if len(roles) >= 2:
            # NOTE(review): role ids 0/1 are treated as subject-like and 2 as
            # predicate-like, based on the has_s/has_p naming; confirm against
            # the role vocabulary.
            has_s = any(r == 0 or r == 1 for r in roles)
            has_p = any(r == 2 for r in roles)
            self._has_sp += int(has_s and has_p)
            # NOTE(review): sentences with fewer than two roles are not
            # counted as order-valid / dependency-coherent; confirm intent.
            self._order_valid += 1
            self._dep += 1

    def compute(self):
        """Return an SVS ``_MetricResult`` of per-sentence ratios (0.0 if empty)."""
        n = self._n_sentences
        if n == 0:
            sp = order = dep = overall = 0.0
        else:
            # Ratios over all sentences, consistent with the valid/total ratio
            # used by MorphologicalConsistencyScore; previously any non-zero
            # count collapsed to 1.0.
            sp = self._has_sp / n
            order = self._order_valid / n
            dep = self._dep / n
            overall = (sp + order + dep) / 3.0
        return _MetricResult(
            spok_completeness=sp,
            order_validity=order,
            dep_coherence=dep,
            n_sentences=n,
            overall=overall,
        )

    def reset(self):
        """Drop all accumulated sentence counts."""
        self._n_sentences = 0
        self._has_sp = 0
        self._order_valid = 0
        self._dep = 0
|
|
|
|
class SDSSnapshot:
    """Point-in-time record of semantic-drift statistics for one epoch/step."""

    def __init__(self, epoch=0, step=0):
        self.epoch, self.step = epoch, step
        # Neutral defaults; SemanticDriftScore.take_snapshot overwrites them.
        self.mean_anchor_distance = 0.0
        self.coverage_score = 1.0
        self.n_tokens = 0
|
|
|
|
class SemanticDriftScore:
    """State-level wrapper for semantic drift between slot vectors and anchors.

    Drift is measured as cosine distance (1 - cosine similarity) between each
    slot row and its anchor row, over rows whose anchor is non-zero.
    """

    def __init__(self):
        self.history = []  # SDSSnapshot objects, one per take_snapshot call
        self._total = 0  # anchored rows seen since the last reset
        self._last = None  # mean cosine distance of the latest anchored batch

    def update(self, slots, anchors):
        """Record one batch; rows whose anchor sums to zero in |.| are skipped."""
        has_anchor = anchors.abs().sum(dim=-1) > 0
        if has_anchor.any():
            similarity = torch.nn.functional.cosine_similarity(
                slots[has_anchor], anchors[has_anchor], dim=-1
            )
            self._last = 1.0 - float(similarity.mean().item())
        self._total += int(has_anchor.sum().item())

    def take_snapshot(self, epoch=0, step=0):
        """Append and return a snapshot of the current drift state."""
        snapshot = SDSSnapshot(epoch=epoch, step=step)
        snapshot.mean_anchor_distance = self._last if self._last is not None else 0.0
        snapshot.coverage_score = 0.0 if self._total <= 0 else 1.0
        snapshot.n_tokens = self._total
        self.history.append(snapshot)
        return snapshot

    def compute(self):
        """Return an SDS ``_MetricResult``.

        ``drift_velocity`` is the change in mean anchor distance between the
        two newest snapshots (0.0 with fewer than two); ``overall`` is one
        minus the latest distance clamped to [0, 1]. ``snapshots`` is the live
        history list, not a copy.
        """
        if len(self.history) < 2:
            velocity = 0.0
        else:
            previous, latest = self.history[-2], self.history[-1]
            velocity = latest.mean_anchor_distance - previous.mean_anchor_distance
        distance = 0.0 if self._last is None else min(max(self._last, 0.0), 1.0)
        return _MetricResult(drift_velocity=velocity, overall=1.0 - distance, snapshots=self.history)

    def reset(self):
        """Per-epoch reset: clears only the anchored-row count."""
        self._total = 0

    def reset_all(self):
        """Full reset: count, snapshot history and last distance."""
        self._total = 0
        self.history.clear()
        self._last = None
|
|