"""
CorrectnessLoss — training loss for AKSARA as a sentence-correctness evaluator.

Components returned by ``CorrectnessLoss.forward`` (dict keys, in order):
  l_binary          : BCE between the total score and the labels.
  l_margin          : pairwise margin ranking between positive and negative
                      examples (falls back to the mean distance to soft
                      targets when the batch contains only one class).
  l_hard_neg        : hardest-negative vs. weakest-positive margin, plus a small
                      floor term when the positive/negative mean gap is tiny.
  l_consist         : variance across the four sub-scores
                      (morph, struct, semantic, lexical).
  l_semantic_gate   : BCE + soft-target MSE on the semantic sub-score, plus a
                      penalty when the batch-mean semantic score is low.
  l_semantic_gap    : mean |total - semantic|.
  l_calibrate       : penalty that is largest at 0.5, discouraging collapse of
                      the total score to the midpoint.
  l_confidence      : MSE to soft targets plus anti-collapse and semantic terms.
  l_uncertainty     : hinge on the batch-mean confidence of the total score.
  l_structural      : BCE on the struct sub-score plus a commitment hinge.
  l_hard_neg_struct : pairwise margin ranking on the struct sub-score.
  l_decision        : two-class NLL on the total score.
  l_commit          : penalty for total scores near 0.5.
  total             : weighted sum of all of the above.
"""

from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F


class CorrectnessLoss(nn.Module):
    """
    Training loss for AKSARA as a sentence-correctness evaluator.

    Labels may be a mix of 0/1; they are clamped to [0, 1] and any value
    > 0.5 is treated as a positive (correct) example. The main term is BCE on
    the total score; the auxiliary terms add stronger, more stable separation
    signals between correct and incorrect sentences.
    """

    def __init__(
        self,
        margin: float = 0.35,
        lambda_binary: float = 1.0,
        lambda_margin: float = 1.25,
        lambda_consist: float = 0.08,
        lambda_calibrate: float = 0.03,
        lambda_confidence: float = 0.35,
        target_low: float = 0.08,
        target_high: float = 0.92,
        confidence_floor: float = 0.15,
        hard_neg_floor: float = 0.05,
        uncertainty_floor: float = 0.05,
        lambda_uncertainty: float = 0.4,
        lambda_structural: float = 0.5,
        lambda_decision: float = 0.6,
        lambda_commit: float = 0.35,
        struct_confidence_floor: float = 0.3,
    ):
        super().__init__()
        self.margin = float(margin)
        self.lambda_binary = float(lambda_binary)
        self.lambda_margin = float(lambda_margin)
        self.lambda_consist = float(lambda_consist)
        self.lambda_calibrate = float(lambda_calibrate)
        self.lambda_confidence = float(lambda_confidence)
        self.target_low = float(target_low)
        self.target_high = float(target_high)
        self.confidence_floor = float(confidence_floor)
        self.hard_neg_floor = float(hard_neg_floor)
        self.uncertainty_floor = float(uncertainty_floor)
        self.lambda_uncertainty = float(lambda_uncertainty)
        self.lambda_structural = float(lambda_structural)
        self.lambda_decision = float(lambda_decision)
        self.lambda_commit = float(lambda_commit)
        self.struct_confidence_floor = float(struct_confidence_floor)

        # Fixed (non-configurable) internal hyper-parameters.
        self.semantic_floor = 0.45
        self.semantic_gap_floor = 0.18
        # NOTE: not read by forward(); the semantic-gate weight in the total is
        # the literal 0.65. Kept so existing code reading the attribute works.
        self.semantic_gate_weight = 0.5
        self.bce = nn.BCELoss()
        self.hard_negative_margin_floor = 0.02
        # Upper bounds applied to the configurable lambdas in forward().
        self.max_lambda_guard = 1.5
        self.dynamic_lambda_cap = 2.0

    def _soft_targets(self, pos_mask: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
        """Return target_high where ``pos_mask`` is True, else target_low (dtype/device of ``like``)."""
        return torch.where(
            pos_mask,
            torch.full_like(like, self.target_high),
            torch.full_like(like, self.target_low),
        )

    def _pairwise_hinge(self, pos: torch.Tensor, neg: torch.Tensor) -> torch.Tensor:
        """Mean of relu(margin - pos_i + neg_j) over every (positive, negative) pair."""
        return F.relu(self.margin - pos.unsqueeze(1) + neg.unsqueeze(0)).mean()

    def forward(
        self,
        score_total: torch.Tensor,
        scores: Dict[str, torch.Tensor],
        labels: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """
        Compute every loss component and their weighted total.

        Args:
            score_total : (B,) combined score from CorrectnessHead, expected in [0, 1].
            scores      : dict that must contain "morph", "struct", "semantic" and
                          "lexical" sub-score tensors of the same shape as
                          score_total (they are stacked together).
            labels      : (B,) targets — 1.0 correct, 0.0 incorrect.

        Returns:
            dict with one entry per component (see the module docstring) plus
            "total"; NaNs in every entry are replaced by 0.

        Raises:
            KeyError: if a required sub-score key is missing from ``scores``.
        """
        labels = labels.float().clamp(0.0, 1.0)
        # Clamp away from exact 0/1 so the BCE/log terms stay finite.
        score_total = score_total.float().clamp(1e-6, 1.0 - 1e-6)
        pos_mask = labels > 0.5
        neg_mask = ~pos_mask

        losses: Dict[str, torch.Tensor] = {}

        # ── L_binary: primary signal from the batch labels ────────────────
        l_binary = self.bce(score_total, labels)
        losses["l_binary"] = torch.nan_to_num(l_binary, nan=0.0)

        # ── L_margin / L_hard_neg: positive vs. negative separation ───────
        pos_scores = score_total[pos_mask]
        neg_scores = score_total[neg_mask]
        if pos_scores.numel() and neg_scores.numel():
            l_margin = self._pairwise_hinge(pos_scores, neg_scores)
            # Penalize the closest negative against the weakest positive.
            hardest_neg = neg_scores.max()
            weakest_pos = pos_scores.min()
            raw_hard_neg = F.relu(self.margin - weakest_pos + hardest_neg)
            # Keep pressure alive even when the raw margin is already satisfied.
            mean_gap = torch.abs(pos_scores.mean() - neg_scores.mean())
            gap_floor = torch.clamp(self.hard_negative_margin_floor - mean_gap, min=0.0)
            l_hard_neg = raw_hard_neg + gap_floor
        else:
            # Single-class batch: no pairs exist, so pull toward soft targets instead.
            target = self._soft_targets(pos_mask, score_total)
            l_margin = torch.mean(torch.abs(score_total - target))
            # Constant (no gradient) floor; stays NaN -> 0 for an empty batch.
            l_hard_neg = l_margin.detach() * 0.0 + self.hard_neg_floor
        losses["l_margin"] = torch.nan_to_num(l_margin, nan=0.0)
        losses["l_hard_neg"] = torch.nan_to_num(l_hard_neg, nan=0.0)

        # ── L_consist: sub-scores must not diverge wildly ─────────────────
        morph_scores = scores["morph"].float()
        struct_scores = scores["struct"].float()
        semantic_scores = scores["semantic"].float()
        lexical_scores = scores["lexical"].float()
        sub_scores = torch.stack(
            [morph_scores, struct_scores, semantic_scores, lexical_scores],
            dim=1,
        )
        l_consist = sub_scores.var(dim=1, unbiased=False).mean()
        losses["l_consist"] = torch.nan_to_num(l_consist, nan=0.0)

        # ── Semantic gate: morph/struct must not override semantic ────────
        semantic_target = self._soft_targets(pos_mask, semantic_scores)
        l_semantic_bce = F.binary_cross_entropy(
            semantic_scores.clamp(1e-6, 1.0 - 1e-6),
            labels,
        )
        l_semantic_target = torch.mean((semantic_scores - semantic_target) ** 2)
        semantic_gap = torch.mean(torch.abs(score_total - semantic_scores))
        semantic_floor_penalty = F.relu(self.semantic_floor - torch.mean(semantic_scores))
        l_sem_gate = l_semantic_bce + 0.5 * l_semantic_target + 0.5 * semantic_floor_penalty
        losses["l_semantic_gate"] = torch.nan_to_num(l_sem_gate, nan=0.0)
        losses["l_semantic_gap"] = torch.nan_to_num(semantic_gap, nan=0.0)

        # ── L_calibrate: push scores away from the grey midpoint 0.5 ──────
        # s * (1 - s) == 0.25 - (s - 0.5)^2: maximal at 0.5, zero at 0 and 1,
        # so minimizing it discourages collapse to 0.5 (the previous
        # (s - 0.5)^2 form did the opposite and pulled scores toward 0.5).
        l_calibrate = torch.mean(score_total * (1.0 - score_total))
        losses["l_calibrate"] = torch.nan_to_num(l_calibrate, nan=0.0)

        # ── L_confidence: reward confident correct, penalize mid collapse ─
        target = self._soft_targets(pos_mask, score_total)
        mse_conf = torch.mean((score_total - target) ** 2)
        collapse_penalty = torch.mean(F.relu(self.confidence_floor - torch.abs(score_total - 0.5)) ** 2)
        semantic_alignment = torch.mean((semantic_scores - labels) ** 2)
        semantic_gap_penalty = torch.mean(
            F.relu(self.semantic_gap_floor - torch.abs(score_total - semantic_scores))
        )
        l_confidence = (
            mse_conf
            + 0.5 * collapse_penalty
            + 0.35 * semantic_alignment
            + 0.5 * semantic_gap_penalty
        )
        losses["l_confidence"] = torch.nan_to_num(l_confidence, nan=0.0)

        # ── L_uncertainty: hinge on batch-mean distance from 0.5 ──────────
        confidence_signal = torch.mean(torch.abs(score_total - 0.5) * 2.0)
        l_uncertainty = F.relu(self.uncertainty_floor - confidence_signal)
        losses["l_uncertainty"] = torch.nan_to_num(l_uncertainty, nan=0.0)

        # ── Structural supervision: structure must actively participate ───
        l_structural = F.binary_cross_entropy(struct_scores.clamp(1e-6, 1.0 - 1e-6), labels)
        struct_conf = torch.mean(torch.abs(struct_scores - 0.5) * 2.0)
        l_struct_commit = F.relu(self.struct_confidence_floor - struct_conf)
        losses["l_structural"] = torch.nan_to_num(l_structural + 0.5 * l_struct_commit, nan=0.0)

        # ── Structural hard negative: label-aware separation on struct ────
        struct_pos = struct_scores[pos_mask]
        struct_neg = struct_scores[neg_mask]
        if struct_pos.numel() and struct_neg.numel():
            l_hard_neg_struct = self._pairwise_hinge(struct_pos, struct_neg)
        else:
            l_hard_neg_struct = score_total.new_tensor(self.hard_neg_floor)
        losses["l_hard_neg_struct"] = torch.nan_to_num(l_hard_neg_struct, nan=0.0)

        # ── Decision supervision: two-class NLL on the total score ────────
        decision_prob = torch.stack([1.0 - score_total, score_total], dim=1)
        # Use the same > 0.5 threshold as every other term; labels.long() would
        # truncate soft positives such as 0.9 to class 0.
        decision_targets = pos_mask.long()
        l_decision = F.nll_loss(torch.log(decision_prob.clamp(1e-6, 1.0)), decision_targets)
        losses["l_decision"] = torch.nan_to_num(l_decision, nan=0.0)

        # ── Commit pressure: punish indecisive midpoint behaviour ─────────
        l_commit = torch.mean((0.5 - torch.abs(score_total - 0.5)) ** 2)
        losses["l_commit"] = torch.nan_to_num(l_commit, nan=0.0)

        # Cap the configurable weights so a bad config cannot dominate the total.
        lambda_margin = min(self.lambda_margin, self.dynamic_lambda_cap)
        lambda_consist = min(self.lambda_consist, self.dynamic_lambda_cap)
        lambda_calibrate = min(self.lambda_calibrate, self.dynamic_lambda_cap)
        lambda_confidence = min(self.lambda_confidence, self.dynamic_lambda_cap)
        lambda_uncertainty = min(self.lambda_uncertainty, self.max_lambda_guard)
        lambda_structural = min(self.lambda_structural, self.max_lambda_guard)
        lambda_decision = min(self.lambda_decision, self.max_lambda_guard)
        lambda_commit = min(self.lambda_commit, self.max_lambda_guard)

        total = (
            self.lambda_binary * losses["l_binary"]
            + lambda_margin * losses["l_margin"]
            + lambda_consist * losses["l_consist"]
            + lambda_calibrate * losses["l_calibrate"]
            + lambda_confidence * losses["l_confidence"]
            + 0.75 * losses["l_hard_neg"]
            + 0.5 * losses["l_hard_neg_struct"]
            + 0.65 * losses["l_semantic_gate"]
            + 0.35 * losses["l_semantic_gap"]
            + lambda_uncertainty * losses["l_uncertainty"]
            + lambda_structural * losses["l_structural"]
            + lambda_decision * losses["l_decision"]
            + lambda_commit * losses["l_commit"]
        )
        losses["total"] = torch.nan_to_num(total, nan=0.0)
        return losses


class AksaraLoss(CorrectnessLoss):
    """Backward-compatible alias. Use CorrectnessLoss instead."""