""" Pengendali Dinamik (PD) untuk AKSARA reasoning-only. Modul ini mempertahankan nama publik lama agar kompatibel dengan trainer, tetapi seluruh logika telah dipindahkan dari asumsi autoregresif/token-logit ke adaptasi bobot loss evaluator berbasis correctness. """ from __future__ import annotations from dataclasses import dataclass from typing import Dict @dataclass class PDConfig: """Konfigurasi Pengendali Dinamik untuk loss evaluator.""" lambda_min: float = 0.05 lambda_max: float = 2.0 ema_beta: float = 0.9 adapt_rate: float = 0.04 eps: float = 1e-8 # Batas aman untuk komponen yang diadaptasi trainer. lambda_sem_max: float = 2.0 lambda_ctx_max: float = 2.0 lambda_morph_max: float = 1.5 lambda_struct_min: float = 0.3 struct_sem_ratio_floor: float = 0.2 # Bobot awal untuk komponen correctness loss lambda_morph: float = 0.85 lambda_struct: float = 1.0 lambda_semantic: float = 1.1 lambda_lexical: float = 1.0 lambda_margin_boost: float = 1.15 lambda_confidence_boost: float = 1.1 lambda_semantic_gate_boost: float = 1.12 class PengendaliDinamik: """ Adaptasi bobot loss untuk training reasoning-only. Tidak ada ketergantungan pada output_logits, perplexity, atau signal autoregresif. PD ini hanya memantau komponen loss correctness: l_binary, l_margin, l_consist, l_calibrate, l_confidence. """ def __init__(self, config: PDConfig = None): self.config = config or PDConfig() self._lambdas = { "morph": float(self.config.lambda_morph), "struct": float(self.config.lambda_struct), "sem": float(self.config.lambda_semantic), "ctx": float(self.config.lambda_lexical), } self._ema_losses: Dict[str, float] = { "l_binary": 0.0, "l_margin": 0.0, "l_consist": 0.0, "l_calibrate": 0.0, "l_confidence": 0.0, "l_uncertainty": 0.0, "l_structural": 0.0, "l_hard_neg_struct": 0.0, "l_decision": 0.0, "l_commit": 0.0, "l_semantic_gate": 0.0, "l_semantic_gap": 0.0, } self._step = 0 def _clamp_lambda(self, key: str, value: float) -> float: min_by_key = { "struct": self.config.lambda_struct_min, }.get(key, self.config.lambda_min) max_by_key = { "sem": self.config.lambda_sem_max, "ctx": self.config.lambda_ctx_max, "morph": self.config.lambda_morph_max, }.get(key, self.config.lambda_max) return max(min_by_key, min(max_by_key, value)) def _mean_loss(self, losses: Dict[str, float]) -> float: vals = [float(v) for k, v in losses.items() if k in self._ema_losses] return sum(vals) / max(len(vals), 1) def _has_valid_signals(self) -> bool: return any(value > self.config.eps for value in self._ema_losses.values()) def _apply_guarded_update(self, key: str, value: float) -> None: self._lambdas[key] = self._clamp_lambda(key, value) def step_update(self, losses: Dict[str, float], optimizer=None): """ Update EMA loss dan bobot adaptif. Args: losses: dict hasil training step, idealnya berisi komponen l_binary/l_margin/l_consist/l_calibrate/l_confidence. optimizer: dipertahankan untuk kompatibilitas call-site, tetapi tidak digunakan. """ self._step += 1 beta = self.config.ema_beta observed_any = False for key in self._ema_losses: if key not in losses: continue value = float(losses[key]) if value != value: continue observed_any = True prev = self._ema_losses[key] self._ema_losses[key] = beta * prev + (1.0 - beta) * value if not observed_any or not self._has_valid_signals(): return self._lambdas mean_loss = self._mean_loss(self._ema_losses) if mean_loss <= self.config.eps: return self._lambdas adapt_rate = self.config.adapt_rate margin_ema = self._ema_losses["l_margin"] confidence_ema = self._ema_losses["l_confidence"] binary_ema = self._ema_losses["l_binary"] structural_ema = self._ema_losses["l_structural"] decision_ema = self._ema_losses["l_decision"] semantic_gate_ema = self._ema_losses["l_semantic_gate"] semantic_gap_ema = self._ema_losses["l_semantic_gap"] for key in self._ema_losses: target = self._ema_losses[key] / max(mean_loss, self.config.eps) bounded_target = min(max(target, 0.75), 1.25) if key == "l_binary": self._apply_guarded_update( "morph", self._lambdas["morph"] * (1.0 + adapt_rate * 0.35 * (bounded_target - 1.0)), ) elif key == "l_margin": boost = 1.0 + (self.config.lambda_margin_boost - 1.0) * max(bounded_target - 1.0, 0.0) if margin_ema > binary_ema: boost *= 1.05 self._apply_guarded_update("sem", self._lambdas["sem"] * boost) self._apply_guarded_update("ctx", self._lambdas["ctx"] * boost) elif key == "l_consist": struct_scale = 1.0 + adapt_rate * 0.25 * (bounded_target - 1.0) self._apply_guarded_update("struct", min(self._lambdas["struct"] * struct_scale, 1.5)) elif key == "l_structural": struct_scale = 1.0 + adapt_rate * 0.75 * max(bounded_target - 1.0, 0.0) if structural_ema >= binary_ema: struct_scale *= 1.05 self._apply_guarded_update("struct", self._lambdas["struct"] * struct_scale) elif key == "l_hard_neg_struct": struct_scale = 1.0 + adapt_rate * 0.55 * max(bounded_target - 1.0, 0.0) self._apply_guarded_update("struct", self._lambdas["struct"] * struct_scale) elif key == "l_calibrate": self._apply_guarded_update( "sem", self._lambdas["sem"] * (1.0 + adapt_rate * 0.6 * (bounded_target - 1.0)), ) elif key == "l_confidence": boost = 1.0 + (self.config.lambda_confidence_boost - 1.0) * max(bounded_target - 1.0, 0.0) if confidence_ema < 0.05: boost *= 1.05 self._apply_guarded_update("sem", self._lambdas["sem"] * boost) self._apply_guarded_update("ctx", self._lambdas["ctx"] * boost) elif key == "l_uncertainty": boost = 1.0 + adapt_rate * 0.6 * max(bounded_target - 1.0, 0.0) self._apply_guarded_update("sem", self._lambdas["sem"] * boost) self._apply_guarded_update("ctx", self._lambdas["ctx"] * boost) elif key == "l_decision": struct_scale = 1.0 + adapt_rate * 0.65 * max(bounded_target - 1.0, 0.0) if decision_ema >= binary_ema: struct_scale *= 1.05 self._apply_guarded_update("struct", self._lambdas["struct"] * struct_scale) elif key == "l_commit": boost = 1.0 + adapt_rate * 0.35 * max(bounded_target - 1.0, 0.0) self._apply_guarded_update("sem", self._lambdas["sem"] * boost) self._apply_guarded_update("ctx", self._lambdas["ctx"] * boost) elif key == "l_semantic_gate": boost = 1.0 + (self.config.lambda_semantic_gate_boost - 1.0) * max(bounded_target - 1.0, 0.0) if semantic_gate_ema > confidence_ema: boost *= 1.05 self._apply_guarded_update("sem", self._lambdas["sem"] * boost) self._apply_guarded_update("ctx", self._lambdas["ctx"] * boost) elif key == "l_semantic_gap": boost = 1.0 + adapt_rate * 0.45 * max(bounded_target - 1.0, 0.0) if semantic_gap_ema < 0.1: boost *= 1.03 self._apply_guarded_update("sem", self._lambdas["sem"] * boost) self._apply_guarded_update("ctx", self._lambdas["ctx"] * boost) self._lambdas["morph"] = self._clamp_lambda("morph", self._lambdas["morph"]) self._lambdas["sem"] = self._clamp_lambda("sem", self._lambdas["sem"]) self._lambdas["ctx"] = self._clamp_lambda("ctx", self._lambdas["ctx"]) self._lambdas["sem"] = self._clamp_lambda( "sem", max(self._lambdas["sem"], self._lambdas["morph"]), ) self._lambdas["ctx"] = self._clamp_lambda( "ctx", max(self._lambdas["ctx"], self._lambdas["morph"]), ) struct_floor = max( self.config.lambda_struct_min, self._lambdas["sem"] * self.config.struct_sem_ratio_floor, ) self._lambdas["struct"] = self._clamp_lambda("struct", max(self._lambdas["struct"], struct_floor)) return self._lambdas def get_lambdas(self) -> Dict[str, float]: """Ambil bobot aktif saat ini.""" return dict(self._lambdas) def get_diagnostics(self) -> Dict[str, Dict[str, float]]: """Ringkasan state PD untuk logging/checkpoint.""" return { "step": self._step, "lambdas": self.get_lambdas(), "ema_losses": dict(self._ema_losses), } def state_dict(self) -> Dict: return { "config": self.config.__dict__, "lambdas": self.get_lambdas(), "ema_losses": dict(self._ema_losses), "step": self._step, } def load_state_dict(self, state: Dict): cfg = state.get("config") if isinstance(cfg, dict): self.config = PDConfig(**cfg) self._lambdas.update(state.get("lambdas", {})) self._ema_losses.update(state.get("ema_losses", {})) self._step = int(state.get("step", 0)) class AksaraPD(PengendaliDinamik): """Alias backward-compatibility."""