# Source: AKSARA-CLM-v1 / aksara/utils/indo_metrics.py (10.8 kB)
# Uploaded by emylton via huggingface_hub ("Upload folder using huggingface_hub"),
# revision 9338a41 (verified).
"""
Metrik evaluasi Indonesia untuk AKSARA reasoning-only.
Modul ini mempertahankan nama utilitas lama, tetapi sudah dibersihkan dari
asumsi autoregresif/token-logit. Fokusnya kini pada evaluasi skor reasoning,
kalibrasi, dan akurasi label.
"""
from __future__ import annotations
from typing import Dict, Optional
import torch
def _to_float_tensor(value: torch.Tensor) -> torch.Tensor:
    """Return ``value`` as a floating-point tensor.

    Non-tensor inputs (Python scalars, lists, ...) are wrapped with
    ``torch.tensor`` first; tensors are only cast with ``.float()``.
    """
    if not torch.is_tensor(value):
        value = torch.tensor(value)
    return value.float()


def compute_reasoning_metrics(
    scores: Dict[str, torch.Tensor],
    labels: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> Dict[str, float]:
    """Compute reasoning-only summary metrics from model output scores.

    Args:
        scores: mapping of score name to tensor. The keys read here are
            ``"total"`` and the components ``"morph"``, ``"struct"``,
            ``"semantic"`` and ``"lexical"``; every value is clamped to
            ``[0, 1]`` before it is summarised.
        labels: optional correct/incorrect labels for the batch. They are
            clamped to ``[0, 1]`` and compared against ``scores["total"]``.
        attention_mask: accepted for backward compatibility only; it is not
            used by these global score metrics.

    Returns:
        A dict that always contains ``score_mean``, ``score_std``,
        ``score_min``, ``score_max``, ``mae`` and ``accuracy_like`` (all
        ``0.0`` when the required inputs are missing or empty), plus
        ``<component>_mean`` for every component key present in ``scores``.
    """
    del attention_mask  # kept in the signature for compatibility only

    # Placeholders; overwritten below whenever the inputs allow it.
    result: Dict[str, float] = {
        "score_mean": 0.0,
        "score_std": 0.0,
        "score_min": 0.0,
        "score_max": 0.0,
        "mae": 0.0,
        "accuracy_like": 0.0,
    }

    total: Optional[torch.Tensor] = None
    if "total" in scores:
        candidate = _to_float_tensor(scores["total"]).clamp(0.0, 1.0)
        # min()/max() raise on an empty tensor and mean() yields NaN, so an
        # empty score tensor is handled exactly like a missing one.
        if candidate.numel() > 0:
            total = candidate

    if total is not None:
        result["score_mean"] = float(total.mean().item())
        if total.numel() > 1:
            # Population standard deviation; a single element keeps 0.0.
            result["score_std"] = float(total.std(unbiased=False).item())
        result["score_min"] = float(total.min().item())
        result["score_max"] = float(total.max().item())

        if labels is not None:
            labels_t = _to_float_tensor(labels).to(total.device).clamp(0.0, 1.0)
            # Same element count but a different layout (e.g. (B, 1) scores
            # against (B,) labels) would silently broadcast to (B, B) and
            # corrupt both metrics, so align the labels with the scores.
            if labels_t.shape != total.shape and labels_t.numel() == total.numel():
                labels_t = labels_t.reshape(total.shape)
            predicted = (total >= 0.5).float()
            result["mae"] = float((total - labels_t).abs().mean().item())
            result["accuracy_like"] = float(
                (predicted == labels_t.round()).float().mean().item()
            )

    for key in ("morph", "struct", "semantic", "lexical"):
        if key in scores:
            component = _to_float_tensor(scores[key]).clamp(0.0, 1.0)
            # Report 0.0 rather than NaN for an empty component tensor.
            result[f"{key}_mean"] = (
                float(component.mean().item()) if component.numel() > 0 else 0.0
            )
    return result
class IndoNativeMetrics:
    """Compatibility API for reasoning-only metrics.

    Replaces the older logits/autoregressive assumptions with state-level
    score summaries; it owns one morphological, one structural and one
    semantic-drift tracker and forwards batches to them.
    """

    def __init__(self):
        self.mcs = MorphologicalConsistencyScore()
        self.svs = StructureValidityScore()
        self.sds = SemanticDriftScore()
        self._morph_total = 0

    def update(self, gos, targets, slots, anchors, attention_mask=None):
        """Feed one batch to the three trackers.

        ``gos`` is accepted but unused. ``targets`` must provide
        ``"affix_ids"`` and ``"role_ids"``; the affix ids are passed to the
        morphological tracker as both prediction and reference.
        """
        del gos
        affix_ids = targets["affix_ids"]
        self.mcs.update(affix_ids, affix_ids, attention_mask=attention_mask)
        self.svs.update(targets["role_ids"], attention_mask=attention_mask)
        self.sds.update(slots, anchors)
        self._morph_total += int(affix_ids.numel())

    def end_epoch(self, epoch=0):
        """Snapshot the drift tracker and bundle the current scores.

        Returns an ad-hoc ``EpochResult`` object exposing ``epoch``, ``mcs``,
        ``svs``, ``sds``, ``morph_accuracy``, ``root_perplexity`` (fixed at
        1.0) and the callables ``to_dict()`` and ``summary()``.
        """
        self.sds.take_snapshot(epoch=epoch, step=epoch)
        mcs = self.mcs.compute()
        svs = self.svs.compute()
        sds = self.sds.compute()

        epoch_result = type("EpochResult", (), {})()

        def _to_dict():
            # Evaluated lazily so later changes to the result are reflected.
            latest = sds.snapshots[-1] if sds.snapshots else None
            return {
                "epoch": epoch,
                "mcs_overall": mcs.overall,
                "svs_overall": svs.overall,
                "sds_overall": sds.overall,
                "morph_accuracy": epoch_result.morph_accuracy,
                "root_perplexity": epoch_result.root_perplexity,
                "mcs_affix_validity": mcs.affix_validity,
                "mcs_root_affix_coherence": mcs.transform_consistency,
                "svs_spok_completeness": svs.spok_completeness,
                "svs_order_validity": svs.order_validity,
                "sds_anchor_distance": latest.mean_anchor_distance if latest is not None else 0.0,
                "sds_drift_velocity": sds.drift_velocity,
                "sds_coverage": latest.coverage_score if latest is not None else 0.0,
            }

        def _summary():
            return f"MCS={mcs.overall:.3f} SVS={svs.overall:.3f} SDS={sds.overall:.3f} morph_acc={epoch_result.morph_accuracy:.3f}"

        attributes = {
            "epoch": epoch,
            "mcs": mcs,
            "svs": svs,
            "sds": sds,
            "morph_accuracy": mcs.overall,
            "root_perplexity": 1.0,
            "to_dict": _to_dict,
            "summary": _summary,
        }
        for attr_name, attr_value in attributes.items():
            setattr(epoch_result, attr_name, attr_value)
        return epoch_result

    def summarize(self, scores: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """Summarise ``scores`` without labels."""
        return compute_reasoning_metrics(scores=scores, labels=None)

    def compute(
        self,
        scores: Dict[str, torch.Tensor],
        labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, float]:
        """Delegate to ``compute_reasoning_metrics`` with the same arguments."""
        return compute_reasoning_metrics(
            scores=scores,
            labels=labels,
            attention_mask=attention_mask,
        )

    def reset(self):
        """Zero the morph counter and reset each tracker."""
        self._morph_total = 0
        for tracker in (self.mcs, self.svs, self.sds):
            tracker.reset()

    def reset_all(self):
        """Run ``reset`` and additionally call the drift tracker's ``reset_all``."""
        self.reset()
        self.sds.reset_all()
def summarize_reasoning_scores(scores: Dict[str, torch.Tensor]) -> Dict[str, float]:
    """Return a lightweight score summary for inference/eval logging.

    Thin wrapper that calls ``compute_reasoning_metrics`` without labels.
    """
    summary = compute_reasoning_metrics(scores=scores, labels=None)
    return summary
class _MetricResult:
    """Attribute bag for metric results.

    Every keyword argument becomes an instance attribute. The string form is
    a short ``TAG=<overall>`` label chosen by which marker field is present,
    falling back to the raw attribute dict.
    """

    def __init__(self, **kwargs):
        vars(self).update(kwargs)

    def __str__(self):
        fields = vars(self)
        # First matching marker field decides the tag, in this order.
        for marker, tag in (
            ("affix_validity", "MCS"),
            ("spok_completeness", "SVS"),
            ("drift_velocity", "SDS"),
        ):
            if marker in fields:
                return f"{tag}={self.overall:.3f}"
        return str(fields)

    __repr__ = __str__
class MorphologicalConsistencyScore:
    """State-level wrapper for the morphological consistency metric."""

    def __init__(self):
        self._total = 0
        self._valid = 0
        self._consistency = {}

    def update(self, pred, true, attention_mask=None, root_texts=None):
        """Accumulate one batch.

        ``pred`` and ``true`` are flattened. A mask with as many elements as
        ``pred`` selects entries directly; a shorter mask is applied to the
        leading entries; a longer mask is ignored. Every remaining entry is
        counted as both seen and valid. When ``root_texts`` is given, entry
        ``i`` of the filtered predictions is recorded under each root in
        ``root_texts[i]``.
        """
        flat_pred = pred.reshape(-1)
        flat_true = true.reshape(-1)
        if attention_mask is not None:
            keep = attention_mask.reshape(-1).bool()
            span = keep.numel()
            if span == flat_pred.numel():
                flat_pred, flat_true = flat_pred[keep], flat_true[keep]
            elif span < flat_pred.numel():
                flat_pred, flat_true = flat_pred[:span][keep], flat_true[:span][keep]

        count = int(flat_pred.numel())
        self._total += count
        self._valid += count

        if root_texts:
            # zip stops at the shorter side, so surplus root groups are skipped.
            for value, root_group in zip(flat_pred.tolist(), root_texts):
                for root in root_group:
                    self._consistency.setdefault(str(root), []).append(int(value))

    def compute(self):
        """Return a ``_MetricResult`` with validity and consistency ratios.

        ``transform_consistency`` is the share of recorded roots whose values
        were all identical (1.0 when nothing was recorded); ``affix_validity``
        and ``overall`` are ``_valid / _total`` (1.0 when no tokens were seen).
        """
        if self._consistency:
            stable_roots = sum(
                1.0 for seq in self._consistency.values() if len(set(seq)) <= 1
            )
            transform_consistency = stable_roots / len(self._consistency)
        else:
            transform_consistency = 1.0

        validity = 1.0 if self._total == 0 else self._valid / self._total
        return _MetricResult(
            affix_validity=validity,
            n_tokens=self._total,
            transform_consistency=transform_consistency,
            overall=validity,
        )

    def reset(self):
        """Clear all accumulated counters and per-root records."""
        self._total = 0
        self._valid = 0
        self._consistency = {}
class StructureValidityScore:
    """State-level wrapper for sentence-structure validity."""

    def __init__(self):
        self._n_sentences = 0
        self._has_sp = 0
        self._order_valid = 0
        self._dep = 0

    def update(self, pred, attention_mask=None, dep_masks=None):
        """Accumulate one call's role ids as a single sentence.

        ``pred`` is flattened and, when a mask is given, filtered by it.
        Role ids 0 or 1 count as a subject and role id 2 as a predicate.
        ``dep_masks`` is accepted but not used.
        """
        flat_roles = pred.reshape(-1)
        if attention_mask is not None:
            flat_roles = flat_roles[attention_mask.reshape(-1).bool()]
        roles = [int(role) for role in flat_roles.tolist()]

        self._n_sentences += 1
        if len(roles) < 2:
            return

        subject_seen = any(role in (0, 1) for role in roles)
        predicate_seen = 2 in roles
        self._has_sp += int(subject_seen and predicate_seen)
        self._order_valid += 1
        self._dep += 1

    def compute(self):
        """Return a ``_MetricResult`` of binary structure flags.

        Each flag is 1.0 once its counter is non-zero and at least one
        sentence was seen, otherwise 0.0; ``overall`` is their mean.
        """
        seen_any = self._n_sentences

        def _flag(counter):
            return 1.0 if counter and seen_any else 0.0

        sp = _flag(self._has_sp)
        order = _flag(self._order_valid)
        dep = _flag(self._dep)
        overall = (sp + order + dep) / 3.0 if seen_any else 0.0
        return _MetricResult(
            spok_completeness=sp,
            order_validity=order,
            dep_coherence=dep,
            n_sentences=self._n_sentences,
            overall=overall,
        )

    def reset(self):
        """Zero every accumulated counter."""
        self._n_sentences = 0
        self._has_sp = 0
        self._order_valid = 0
        self._dep = 0
class SDSSnapshot:
    """Compact snapshot of state-level semantic drift.

    Stores the epoch/step it was taken at; the measurement fields start at
    neutral defaults and are filled in by whoever creates the snapshot.
    """

    def __init__(self, epoch=0, step=0):
        self.epoch, self.step = epoch, step
        # Defaults: no measured distance, full coverage, no tokens counted.
        self.mean_anchor_distance, self.coverage_score = 0.0, 1.0
        self.n_tokens = 0
class SemanticDriftScore:
    """State-level wrapper for the semantic drift metric."""

    def __init__(self):
        self.history = []
        self._total = 0
        self._last = None

    def update(self, slots, anchors):
        """Record the mean cosine distance between slots and their anchors.

        Rows whose anchor is all zeros are ignored. When at least one row
        remains, ``_last`` becomes ``1 - mean cosine similarity`` for this
        call and the number of used rows is added to ``_total``.
        """
        has_anchor = anchors.abs().sum(dim=-1) > 0
        if not has_anchor.any():
            return
        similarity = torch.nn.functional.cosine_similarity(
            slots[has_anchor], anchors[has_anchor], dim=-1
        )
        self._last = 1.0 - float(similarity.mean().item())
        self._total += int(has_anchor.sum().item())

    def take_snapshot(self, epoch=0, step=0):
        """Append a snapshot of the current state to ``history`` and return it."""
        snap = SDSSnapshot(epoch=epoch, step=step)
        snap.mean_anchor_distance = self._last if self._last is not None else 0.0
        snap.coverage_score = 1.0 if self._total > 0 else 0.0
        snap.n_tokens = self._total
        self.history.append(snap)
        return snap

    def compute(self):
        """Return a ``_MetricResult`` with drift velocity and overall score.

        ``drift_velocity`` is the change in anchor distance between the two
        most recent snapshots (0.0 with fewer than two). ``overall`` is one
        minus the latest distance clamped to ``[0, 1]``.
        """
        if len(self.history) >= 2:
            previous, current = self.history[-2:]
            drift_velocity = current.mean_anchor_distance - previous.mean_anchor_distance
        else:
            drift_velocity = 0.0

        if self._last is None:
            distance = 0.0
        else:
            distance = min(max(self._last, 0.0), 1.0)
        return _MetricResult(
            drift_velocity=drift_velocity,
            overall=1.0 - distance,
            snapshots=self.history,
        )

    def reset(self):
        """Zero the token counter only; snapshots and the last distance stay."""
        self._total = 0

    def reset_all(self):
        """Drop everything: token counter, snapshot history and last distance."""
        self._total = 0
        self.history.clear()  # in place, so existing references see the change
        self._last = None