"""
AksaraMetrics - evaluation metrics for the AKSARA workflow, based on semantic
correctness/coherence.

The metrics accumulate these signals:
1. morphological and structural correctness
2. semantic alignment to the anchors/KBBI
3. calibration of the decision score
4. anti-loop / anti-repetition pressure
5. the relevant epoch-level loss summary
"""

from typing import Any, Dict, List, Optional
import math

import torch
import torch.nn.functional as F

from aksara.linguistic.lps import MorfologiAnalyzer
from aksara.base.state import AksaraState


class AksaraMetrics:
    """
    Evaluation metrics for the AKSARA model.

    Focus areas:
    1. Morph correctness: agreement of predicted affixes with the targets
    2. Structural correctness: agreement of predicted roles with the targets
    3. KBBI alignment: cosine similarity of semantic slots to their anchors
    4. Score calibration: absolute error of the evaluator score vs. labels
    5. Loop/repetition pressure: penalty reported for repetitive output

    Call ``update`` once per batch, then ``compute`` or ``epoch_summary`` to
    read the aggregated values, and ``reset`` to start a new epoch.
    """

    # A batch whose loop penalty exceeds this value is counted as "repetitive".
    _REPETITION_THRESHOLD = 0.35

    # One row per state-level signal:
    #   (accumulator name, attribute read from a state object,
    #    metadata keys in priority order, state_dict keys in priority order)
    # The accumulator name maps to the ``_<name>_sum`` / ``_<name>_count``
    # fields created in ``_reset``.
    _STATE_SIGNALS = (
        (
            "semantic_coherence",
            "semantic_coherence",
            ("sentence_semantic_coherence", "semantic_coherence", "coherence_score"),
            ("semantic_coherence", "sentence_semantic_coherence", "coherence_score"),
        ),
        (
            "novelty",
            "novelty_score",
            ("novelty", "generation_novelty", "novelty_score"),
            ("novelty_score", "novelty", "generation_novelty"),
        ),
        (
            "generalization",
            "generalization_score",
            ("generalization", "pattern_generalization", "generalization_score"),
            ("generalization_score", "generalization", "pattern_generalization"),
        ),
        (
            "reasoning",
            "reasoning_score",
            ("reasoning_score", "relation_reasoning"),
            ("reasoning_score", "relation_reasoning"),
        ),
        (
            "meaning_discriminator",
            "meaning_discriminator",
            ("meaning_discriminator", "meaning_disc_score", "global_meaning_score"),
            ("meaning_discriminator", "meaning_disc_score", "global_meaning_score"),
        ),
        (
            "global_constraint",
            "global_constraint",
            ("global_constraint", "global_coherence", "relational_meaning"),
            ("global_constraint", "global_coherence", "relational_meaning"),
        ),
        (
            "anti_loop",
            "anti_loop_score",
            ("anti_loop_score", "anti_loop", "loop_suppression"),
            ("anti_loop_score", "anti_loop", "loop_suppression"),
        ),
    )

    # (summary key, alert label, threshold, alert_when_above) in report order.
    # The labels are runtime strings and are emitted verbatim.
    _GUARDRAIL_RULES = (
        ("total", "loss tinggi", 5.0, True),
        ("score_mean", "score_mean rendah", 0.25, False),
        ("score_calibration", "calibration lemah", 0.7, False),
        ("morph_accuracy", "morph_accuracy rendah", 0.8, False),
        ("struct_accuracy", "struct_accuracy rendah", 0.75, False),
        ("kbbi_alignment", "kbbi_alignment rendah", 0.2, False),
        ("loop_penalty_mean", "loop_penalty tinggi", 0.25, True),
        ("repetition_flag_rate", "repetition_flag_rate tinggi", 0.20, True),
    )

    def __init__(self):
        self.analyzer = MorfologiAnalyzer()
        self._reset()

    def _reset(self):
        """Zero every accumulator so a new epoch starts from a clean slate."""
        self._morph_correct = 0
        self._morph_total = 0
        self._struct_correct = 0
        self._struct_total = 0
        self._morph_seen = 0
        self._struct_seen = 0
        self._kbbi_sim_sum = 0.0
        self._kbbi_count = 0
        self._score_abs_error_sum = 0.0
        self._score_count = 0
        self._loop_penalty_sum = 0.0
        self._loop_count = 0
        self._repetition_flag_count = 0
        self._epoch_batches = 0
        self._epoch_samples = 0
        self._epoch_loss_sums: Dict[str, float] = {}
        # Number of values accumulated per loss key, so each key is averaged
        # over the batches that actually reported it.
        self._epoch_loss_counts: Dict[str, int] = {}
        self._guardrail_flags: List[str] = []
        self._semantic_coherence_sum = 0.0
        self._semantic_coherence_count = 0
        self._novelty_sum = 0.0
        self._novelty_count = 0
        self._generalization_sum = 0.0
        self._generalization_count = 0
        self._reasoning_sum = 0.0
        self._reasoning_count = 0
        self._meaning_discriminator_sum = 0.0
        self._meaning_discriminator_count = 0
        self._global_constraint_sum = 0.0
        self._global_constraint_count = 0
        self._anti_loop_sum = 0.0
        self._anti_loop_count = 0

    def update(
        self,
        struct_output: Optional[Dict[str, torch.Tensor]],
        targets: Dict[str, torch.Tensor],
        semantic_slots: torch.Tensor,
        kbbi_anchors: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        score_total: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        losses: Optional[Dict[str, torch.Tensor]] = None,
        kbbi_mask: Optional[torch.Tensor] = None,
        state: Optional[AksaraState] = None,
        state_dict: Optional[Dict[str, Any]] = None,
    ):
        """
        Update the metrics from one batch.

        Args:
            struct_output: Model outputs; ``affix_ids`` and ``role_ids`` are
                compared with ``targets``, and loop-penalty keys are read from
                it. ``None`` skips those metrics.
            targets: Ground truth; ``affix_ids`` / ``role_ids`` are read only
                when the matching prediction is present.
            semantic_slots: Predicted semantic vectors.
            kbbi_anchors: Anchor vectors; alignment is measured only when the
                shape equals that of ``semantic_slots``.
            attention_mask: Optional mask; truthy entries mark the positions
                that are scored.
            score_total: Evaluator score, used for calibration only when
                neither ``state`` nor ``state_dict`` is given.
            labels: Calibration targets (clamped to [0, 1]); also used to
                count the samples of the epoch.
            losses: Per-batch losses forwarded to ``update_epoch_losses``.
            kbbi_mask: Optional explicit mask of positions that have an
                anchor; otherwise non-zero anchors are used.
            state: State object providing ``skor_linguistik`` and the
                state-level signals. Takes precedence over ``state_dict``.
            state_dict: Dictionary form of the same information.
        """
        with torch.no_grad():
            if struct_output is not None:
                self._update_structure_accuracy(struct_output, targets, attention_mask)

            self._update_kbbi_alignment(
                semantic_slots, kbbi_anchors, attention_mask, kbbi_mask
            )
            self._update_score_calibration(score_total, labels, state, state_dict)

            if struct_output is not None:
                self._update_loop_pressure(struct_output)

            if state is not None:
                self._update_state_signals(state)
            elif state_dict is not None:
                self._update_state_dict_signals(state_dict)

            if losses is not None:
                self.update_epoch_losses(losses)

            self._epoch_batches += 1
            if labels is not None:
                self._epoch_samples += int(labels.numel())

    def _update_structure_accuracy(
        self,
        struct_output: Dict[str, Any],
        targets: Dict[str, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
    ) -> None:
        """Accumulate affix (morph) and role (struct) hit counts."""
        pred_affix = struct_output.get("affix_ids")
        if pred_affix is not None:
            true_affix = targets["affix_ids"]
            if attention_mask is not None:
                valid = attention_mask.bool()
                hits = (pred_affix[valid] == true_affix[valid]).sum().item()
                seen = int(valid.sum().item())
            else:
                hits = (pred_affix == true_affix).sum().item()
                seen = int(pred_affix.numel())
            self._morph_correct += hits
            self._morph_total += seen
            self._morph_seen += seen

        pred_role = struct_output.get("role_ids")
        if pred_role is not None:
            true_role = targets["role_ids"]
            # Positions whose target role is <= 0 are excluded from scoring.
            role_valid = true_role > 0
            if attention_mask is not None:
                role_valid = role_valid & attention_mask.bool()
            if role_valid.any():
                seen = int(role_valid.sum().item())
                self._struct_correct += (
                    (pred_role[role_valid] == true_role[role_valid]).sum().item()
                )
                self._struct_total += seen
                self._struct_seen += seen

    def _update_kbbi_alignment(
        self,
        semantic_slots: torch.Tensor,
        kbbi_anchors: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        kbbi_mask: Optional[torch.Tensor],
    ) -> None:
        """Accumulate cosine similarity between slots and their anchors."""
        if semantic_slots.shape != kbbi_anchors.shape:
            return

        cos_sim = F.cosine_similarity(semantic_slots, kbbi_anchors, dim=-1)
        if kbbi_mask is not None:
            has_kbbi = kbbi_mask.bool()
        else:
            # Without an explicit mask, an all-zero anchor means "no anchor".
            has_kbbi = kbbi_anchors.abs().sum(dim=-1) > 0
        if attention_mask is not None:
            has_kbbi = has_kbbi & attention_mask.bool()

        if has_kbbi.any():
            self._kbbi_sim_sum += cos_sim[has_kbbi].sum().item()
            self._kbbi_count += has_kbbi.sum().item()

    def _update_score_calibration(
        self,
        score_total: Optional[torch.Tensor],
        labels: Optional[torch.Tensor],
        state: Optional[AksaraState],
        state_dict: Optional[Dict[str, Any]],
    ) -> None:
        """Accumulate |score - label|; source priority: state, state_dict, score_total."""
        if labels is None:
            return

        if state is not None:
            self._accumulate_state_score(state.skor_linguistik, labels)
        elif state_dict is not None:
            state_score_value = state_dict.get("skor_linguistik")
            if state_score_value is not None:
                self._accumulate_state_score(state_score_value, labels)
        elif score_total is not None:
            predicted = score_total.float().clamp(0.0, 1.0)
            target = labels.float().clamp(0.0, 1.0).to(predicted.device)
            # Flatten when the element counts agree so that e.g. (B,) vs
            # (B, 1) is compared element-wise instead of broadcasting to a
            # (B, B) error matrix.
            if predicted.numel() == target.numel():
                predicted = predicted.reshape(-1)
                target = target.reshape(-1)
            self._score_abs_error_sum += (predicted - target).abs().sum().item()
            self._score_count += predicted.numel()

    def _accumulate_state_score(self, score_value: Any, labels: torch.Tensor) -> None:
        """Compare a scalar state score with ``labels`` when they hold one value."""
        target = labels.float().clamp(0.0, 1.0).reshape(-1)
        # The scalar is created on the labels' device so the subtraction
        # cannot fail with a device mismatch.
        state_score = torch.tensor([float(score_value)], device=target.device)
        if state_score.numel() == target.numel():
            self._score_abs_error_sum += (state_score - target).abs().sum().item()
            self._score_count += state_score.numel()

    def _update_loop_pressure(self, struct_output: Dict[str, Any]) -> None:
        """Accumulate the batch loop penalty and flag high-repetition batches."""
        loop_penalty = self._extract_loop_penalty(struct_output)
        if loop_penalty is None:
            return
        self._loop_penalty_sum += loop_penalty
        self._loop_count += 1
        if loop_penalty > self._REPETITION_THRESHOLD:
            self._repetition_flag_count += 1

    def update_epoch_losses(self, losses: Dict[str, Any]):
        """
        Accumulate per-batch loss/metric values for epoch-level aggregation.

        Tensors are reduced to their mean, numbers are taken as-is, and
        non-empty lists/tuples are averaged. The ``"lambdas"`` key, empty
        tensors, and values of any other type are skipped.
        """
        for key, value in losses.items():
            if key == "lambdas":
                continue
            if torch.is_tensor(value):
                if value.numel() == 0:
                    continue  # the mean of an empty tensor is NaN
                # Cast first: mean() is not defined for integer tensors.
                numeric = float(value.detach().float().mean().item())
            elif isinstance(value, (int, float)):
                numeric = float(value)
            elif isinstance(value, (list, tuple)) and value:
                numeric = float(sum(float(v) for v in value) / len(value))
            else:
                continue
            self._epoch_loss_sums[key] = self._epoch_loss_sums.get(key, 0.0) + numeric
            self._epoch_loss_counts[key] = self._epoch_loss_counts.get(key, 0) + 1

    def _mean_from_epoch_loss_sums(self) -> Dict[str, float]:
        """Return the mean of every accumulated loss key.

        Each key is divided by the number of values recorded for it; when no
        count is available the number of batches is used instead. Keys with
        no usable divisor are omitted.
        """
        means: Dict[str, float] = {}
        for key, total in self._epoch_loss_sums.items():
            divisor = self._epoch_loss_counts.get(key) or self._epoch_batches
            if divisor > 0:
                means[key] = total / divisor
        return means

    def _compute_guardrails(self, results: Dict[str, float]) -> List[str]:
        """Return one alert string per summary value that crosses its threshold.

        Keys missing from ``results`` (or set to ``None``) are ignored.
        """
        alerts: List[str] = []
        for key, label, threshold, alert_when_above in self._GUARDRAIL_RULES:
            value = results.get(key)
            if value is None:
                continue
            crossed = value > threshold if alert_when_above else value < threshold
            if crossed:
                alerts.append(f"{label}={value:.4f}")
        return alerts

    def epoch_summary(self) -> Dict[str, float]:
        """Explicit summary of the epoch-level metrics.

        Combines ``compute()``, the mean losses, the batch/sample counters,
        and the guardrail alerts (``guardrail_alerts`` holds a list of
        strings, ``guardrail_alert_count`` its length).
        """
        results = self.compute()
        results.update(self._mean_from_epoch_loss_sums())
        results["epoch_batches"] = float(self._epoch_batches)
        results["epoch_samples"] = float(self._epoch_samples)
        alerts = self._compute_guardrails(results)
        results["guardrail_alert_count"] = float(len(alerts))
        results["guardrail_alerts"] = alerts  # type: ignore[assignment]
        return results

    def guardrail_alerts(self, summary: Optional[Dict[str, float]] = None) -> List[str]:
        """Evaluate the guardrails on ``summary`` and remember the result.

        A missing or empty ``summary`` is replaced by ``epoch_summary()``.
        """
        summary = summary or self.epoch_summary()
        alerts = self._compute_guardrails(summary)
        self._guardrail_flags = alerts
        return alerts

    @staticmethod
    def _safe_mean(total: float, count: float) -> float:
        """Return ``total / count``, or 0.0 when nothing was counted."""
        return total / count if count > 0 else 0.0

    def compute(self) -> Dict[str, float]:
        """Compute all metrics from the accumulated updates.

        Every ratio falls back to 0.0 when its denominator is zero.
        """
        mean = self._safe_mean
        results: Dict[str, float] = {}

        results["sentence_semantic_coherence"] = mean(
            self._semantic_coherence_sum, self._semantic_coherence_count
        )
        results["novelty_score"] = mean(self._novelty_sum, self._novelty_count)
        results["generalization_score"] = mean(
            self._generalization_sum, self._generalization_count
        )
        results["reasoning_score"] = mean(self._reasoning_sum, self._reasoning_count)
        results["meaning_discriminator"] = mean(
            self._meaning_discriminator_sum, self._meaning_discriminator_count
        )
        results["global_constraint"] = mean(
            self._global_constraint_sum, self._global_constraint_count
        )
        results["anti_loop_score"] = mean(self._anti_loop_sum, self._anti_loop_count)

        results["morph_accuracy"] = mean(self._morph_correct, self._morph_total)
        results["struct_accuracy"] = mean(self._struct_correct, self._struct_total)
        results["morph_seen"] = float(self._morph_seen)
        results["struct_seen"] = float(self._struct_seen)

        results["kbbi_alignment"] = mean(self._kbbi_sim_sum, self._kbbi_count)

        # Calibration is 1 - mean absolute error, floored at 0.
        results["score_calibration"] = (
            max(0.0, 1.0 - (self._score_abs_error_sum / self._score_count))
            if self._score_count > 0
            else 0.0
        )

        results["loop_penalty_mean"] = mean(self._loop_penalty_sum, self._loop_count)
        results["repetition_flag_rate"] = mean(
            self._repetition_flag_count, self._loop_count
        )
        results["loop_count"] = float(self._loop_count)

        return results

    def reset(self):
        """Reset all accumulators."""
        self._reset()

    @staticmethod
    def _as_finite_float(value: Any) -> Optional[float]:
        """Convert ``value`` to a finite float, or return ``None`` if impossible."""
        if value is None:
            return None
        try:
            number = float(value)
        except (TypeError, ValueError, RuntimeError):
            # Non-numeric objects and multi-element tensors cannot be scored.
            return None
        return number if math.isfinite(number) else None

    @staticmethod
    def _first_present(mapping: Dict[str, Any], keys: Any) -> Any:
        """Return the first non-``None`` value found under ``keys``, else ``None``."""
        for key in keys:
            value = mapping.get(key)
            if value is not None:
                return value
        return None

    def _record_signal(self, name: str, value: Any) -> None:
        """Add one sample to the ``_<name>_sum`` / ``_<name>_count`` pair.

        Values that are absent, non-numeric, or non-finite are ignored, so the
        sum and the count always move together.
        """
        number = self._as_finite_float(value)
        if number is None:
            return
        sum_field = f"_{name}_sum"
        count_field = f"_{name}_count"
        setattr(self, sum_field, getattr(self, sum_field) + number)
        setattr(self, count_field, getattr(self, count_field) + 1)

    def _update_state_signals(self, state: AksaraState):
        """Accumulate the state-level signals exposed by a state object.

        Each signal is read from the state attribute and from
        ``state.metadata`` (first matching key). Every source that actually
        supplies a value contributes one sample; a source that lacks the
        signal contributes nothing, so missing signals no longer enter the
        mean as zeros.
        """
        metadata = getattr(state, "metadata", None)
        if not isinstance(metadata, dict):
            metadata = {}

        for name, attribute, metadata_keys, _ in self._STATE_SIGNALS:
            self._record_signal(name, getattr(state, attribute, None))
            self._record_signal(name, self._first_present(metadata, metadata_keys))

    def _update_state_dict_signals(self, state_dict: Dict[str, Any]):
        """Accumulate the state-level signals found in a state dictionary.

        For each signal the first matching key supplies the sample; signals
        without any matching key are not counted.
        """
        for name, _, _, dict_keys in self._STATE_SIGNALS:
            self._record_signal(name, self._first_present(state_dict, dict_keys))

    @staticmethod
    def _extract_loop_penalty(struct_output: Dict[str, Any]) -> Optional[float]:
        """Return the first usable loop-penalty value from ``struct_output``.

        The candidate keys are tried in order. Tensors are reduced to their
        mean. Empty tensors and non-finite values (tensor or scalar) are
        skipped so that NaN/inf cannot enter the running sum. Returns
        ``None`` when no candidate yields a finite number.
        """
        candidate_keys = (
            "loop_penalty",
            "repetition_penalty",
            "anti_loop_score",
            "generated_loop_penalty",
            "loop_ratio",
        )
        for key in candidate_keys:
            value = struct_output.get(key)
            if torch.is_tensor(value):
                if value.numel() == 0:
                    continue
                penalty = float(value.detach().float().mean().item())
                if math.isfinite(penalty):
                    return penalty
                continue
            if isinstance(value, (int, float)) and math.isfinite(float(value)):
                return float(value)
        return None

    @staticmethod
    def evaluate_morfologi(
        texts: List[str],
        affixes: List[List[str]],
        id_to_affix: Dict[int, str],
    ) -> Dict[str, float]:
        """Evaluate morphology standalone (without a model).

        Each text is split on whitespace and every word's predicted affix is
        compared with the second element returned by
        ``MorfologiAnalyzer.best``. Pairs are formed with ``zip``, so surplus
        words or predictions are ignored. ``id_to_affix`` is accepted for
        interface compatibility but not used.

        Returns:
            ``{"morph_accuracy": correct / total}``, or 0.0 when no word was
            compared.
        """
        del id_to_affix
        analyzer = MorfologiAnalyzer()
        correct = 0
        total = 0

        for text, predicted_affixes in zip(texts, affixes):
            for word, predicted in zip(text.split(), predicted_affixes):
                _, expected = analyzer.best(word)
                if predicted == expected:
                    correct += 1
                total += 1

        return {"morph_accuracy": correct / max(total, 1)}