"""State-native objective stack for AKSARA training.

This module implements the curriculum-aligned objective surface used by the
state-native pipeline. It combines supervision over AksaraState traces,
constraint satisfaction, semantic binding, and GOS structural coherence.

The objective design intentionally avoids token-level next prediction and
instead optimizes the internal linguistic-state graph, trace consistency, and
phase-aware structural regularization.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, Optional

import torch
import torch.nn.functional as F


# Names of the objectives that have a dedicated field on the bundle; any other
# key found in the config is carried along in ``extra``.
_CORE_OBJECTIVES = (
    "state_consistency",
    "constraint_satisfaction",
    "semantic_alignment",
    "gos_coherence",
    "multi_state_margin",
)


@dataclass
class CurriculumObjectiveBundle:
    """Container for enabled objectives and curriculum weights."""

    state_consistency: float = 1.0
    constraint_satisfaction: float = 1.0
    semantic_alignment: float = 1.0
    gos_coherence: float = 1.0
    multi_state_margin: float = 1.0
    extra: Dict[str, float] = field(default_factory=dict)

    def weights(self) -> Dict[str, float]:
        """Return every weight (core objectives plus ``extra``) as floats.

        Entries in ``extra`` are applied last, so an ``extra`` key that
        collides with a core objective name overrides it.
        """
        weights = {
            "state_consistency": float(self.state_consistency),
            "constraint_satisfaction": float(self.constraint_satisfaction),
            "semantic_alignment": float(self.semantic_alignment),
            "gos_coherence": float(self.gos_coherence),
            "multi_state_margin": float(self.multi_state_margin),
        }
        weights.update({k: float(v) for k, v in self.extra.items()})
        return weights


def _infer_device(model_output: Any) -> Optional[torch.device]:
    """Return the device of the first tensor value in ``model_output``.

    Returns ``None`` when ``model_output`` is not a dict or holds no tensor.
    """
    if isinstance(model_output, dict):
        for value in model_output.values():
            if isinstance(value, torch.Tensor):
                return value.device
    return None


def _zero_loss(device: Optional[torch.device] = None) -> torch.Tensor:
    """Return a scalar zero that requires grad, so ``backward()`` stays legal."""
    return torch.zeros((), device=device, requires_grad=True)


def _as_tensor(value: Any, device: Optional[torch.device] = None) -> torch.Tensor:
    """Coerce ``value`` to a tensor.

    Tensors are returned as-is (moved to ``device`` when one is given); lists,
    tuples and plain numbers become float32 tensors; any other type silently
    maps to a scalar ``0.0``.
    """
    if isinstance(value, torch.Tensor):
        return value.to(device=device) if device is not None else value
    if isinstance(value, (list, tuple)):
        return torch.tensor(value, dtype=torch.float32, device=device)
    if isinstance(value, (int, float)):
        return torch.tensor(float(value), dtype=torch.float32, device=device)
    return torch.tensor(0.0, dtype=torch.float32, device=device)


def _mean_state_values(
    state: Any,
    keys: Iterable[str],
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Average the values stored under ``keys`` in ``state`` into one scalar.

    Each present value is first reduced to a scalar with ``mean()`` so that
    list- or vector-valued entries are accepted (a bare ``reshape(())`` would
    raise on anything with more than one element). Missing keys and empty
    values are skipped. Returns ``0.0`` on ``device`` when nothing is found or
    ``state`` is not a dict.
    """
    if not isinstance(state, dict):
        return torch.tensor(0.0, device=device)
    scalars = []
    for key in keys:
        if key not in state:
            continue
        value = _as_tensor(state[key], device=device).float()
        if value.numel() == 0:
            continue
        scalars.append(value.mean())
    if not scalars:
        return torch.tensor(0.0, device=device)
    return torch.stack(scalars).mean()


def _cosine_distance(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Return ``1 - cosine_similarity`` of the two flattened tensors (in [0, 2])."""
    unit_a = F.normalize(a.float().reshape(1, -1), dim=-1)
    unit_b = F.normalize(b.float().reshape(1, -1), dim=-1)
    return 1.0 - (unit_a * unit_b).sum()


class StateAlignmentLoss:
    """Align state scores to supervision values."""

    _ENERGY_KEYS = ("energi_total", "energy", "energi")
    _STRUCTURE_KEYS = ("kelengkapan_struktur", "structure", "completeness")

    def __call__(self, model_output: Dict, targets: Optional[Dict] = None) -> torch.Tensor:
        """Return MSE(energy) + MSE(structure) between prediction and target.

        The prediction is read from ``model_output["state_eval"]`` and the
        target from ``targets["state"]``; a zero loss is returned when either
        is missing.
        """
        device = _infer_device(model_output)
        pred_state = model_output.get("state_eval") if isinstance(model_output, dict) else None
        target_state = targets.get("state") if targets else None
        if pred_state is None or target_state is None:
            return _zero_loss(device)

        # Build both sides on the same device so mse_loss never mixes devices.
        pred_energy = _mean_state_values(pred_state, self._ENERGY_KEYS, device)
        target_energy = _mean_state_values(target_state, self._ENERGY_KEYS, device)
        pred_structure = _mean_state_values(pred_state, self._STRUCTURE_KEYS, device)
        target_structure = _mean_state_values(target_state, self._STRUCTURE_KEYS, device)
        return F.mse_loss(pred_energy, target_energy) + F.mse_loss(pred_structure, target_structure)


class ConstraintSatisfactionLoss:
    """Penalize constraint violations in AksaraState space."""

    _VIOLATION_KEYS = ("pelanggaran", "violations", "constraint_violations")

    def __call__(self, model_output: Dict, targets: Optional[Dict] = None) -> torch.Tensor:
        """Return the mean violation score, or an MSE fallback against targets.

        Explicit violation entries in ``model_output["state_eval"]`` take
        priority. When none are present, the predicted
        ``constraint_satisfaction`` is compared against
        ``targets["target"]["constraint_satisfaction"]`` (falling back to
        ``"skor_linguistik"``). A zero loss is returned when neither path
        applies.
        """
        device = _infer_device(model_output)
        pred_state = model_output.get("state_eval") if isinstance(model_output, dict) else None
        if pred_state is None:
            return _zero_loss(device)

        pred_is_dict = isinstance(pred_state, dict)
        violations = []
        if pred_is_dict:
            for key in self._VIOLATION_KEYS:
                if key not in pred_state:
                    continue
                value = pred_state[key]
                if isinstance(value, (list, tuple)):
                    # An empty list is not a violation, so do not append 0.0 for it.
                    if value:
                        violations.append(torch.tensor(float(len(value)), device=device))
                else:
                    violations.append(_as_tensor(value, device=device).float())
        if violations:
            return torch.stack([v.reshape(()) for v in violations]).mean()

        # Fallback: compare the model's constraint_satisfaction with the dataset
        # target. Active when explicit linguistic violations are unavailable.
        cs_pred = pred_state.get("constraint_satisfaction") if pred_is_dict else None
        target_dict = targets.get("target", {}) if isinstance(targets, dict) else {}
        cs_target = None
        if isinstance(target_dict, dict):
            # Explicit None checks: a target of 0.0 is a real value, and `or`
            # would both discard it and raise on multi-element tensors.
            cs_target = target_dict.get("constraint_satisfaction")
            if cs_target is None:
                cs_target = target_dict.get("skor_linguistik")
        if cs_pred is not None and cs_target is not None:
            return F.mse_loss(
                _as_tensor(cs_pred, device=device).float().reshape(()),
                _as_tensor(cs_target, device=device).float().reshape(()),
            )
        return _zero_loss(device)


class SemanticBindingLoss:
    """Semantic complementarity loss across the BSU -> MEB -> GOS layers.

    PRINCIPLE (anti-MSE):
        The earlier MSE formulation was wrong: it forced BSU = MEB = GOS in
        value, destroying the differentiation between layers that each have a
        distinct role.

        It is replaced with cosine complementarity:
        - Each layer MAY have a different representation (different values)
        - But they must point in the same semantic direction (high cosine
          similarity)

    Rationale:
        BSU = raw morpheme representation from KBBI
        MEB = meaning evolution after constraint propagation
        GOS = abstraction of global discourse meaning
        All three talk about the same thing -> same direction.
        Their values do not have to be equal.
    """

    def __call__(self, model_output: Dict, targets: Optional[Dict] = None) -> torch.Tensor:
        """Return the mean cosine distance of the BSU-MEB and MEB-GOS pairs.

        Reads ``bsu_repr``, ``meb_repr`` and ``gos_repr`` from
        ``model_output``; returns a zero loss if any of them is missing.
        """
        is_dict = isinstance(model_output, dict)
        bsu_repr = model_output.get("bsu_repr") if is_dict else None
        meb_repr = model_output.get("meb_repr") if is_dict else None
        gos_repr = model_output.get("gos_repr") if is_dict else None
        if bsu_repr is None or meb_repr is None or gos_repr is None:
            return _zero_loss(_infer_device(model_output))

        # Cosine distance = 1 - cosine_similarity, in [0, 2].
        # Minimizing it pushes the layers to align in direction, not to be equal.
        loss_bsu_meb = _cosine_distance(bsu_repr, meb_repr)
        loss_meb_gos = _cosine_distance(meb_repr, gos_repr)
        return (loss_bsu_meb + loss_meb_gos) * 0.5


class GOSCoherenceLoss:
    """Representation diversity across MEB layers.

    Linguistic rationale (anti-Transformer):
        Each MEB layer handles a different kind of constraint:
            Layer 0: morphology (affix, root)
            Layer 1: syntax (roles, dependencies)
            Layer 2+: semantics (grounding, context)

        If the representations of consecutive layers are too similar
        (cosine sim ~ 1), the layer contributes nothing real to understanding
        the meaning.

        We MINIMIZE the cosine similarity between layers:
        -> pushes every layer to add a different linguistic perspective
        -> not about token prediction, but about depth of understanding

    Unlike Transformer attention, which is O(n^2) and implicit:
        Here every layer MUST have an explicit, distinct role.
    """

    def __call__(self, model_output: Dict, targets: Optional[Dict] = None) -> torch.Tensor:
        """Return the layer-similarity penalty computed from ``gos_trace``.

        Uses ``gos_trace["trace_vectors"]`` when present (mean clamped cosine
        similarity of consecutive rows); otherwise falls back to
        ``gos_trace["trace_scores"]`` (mean squared normalized delta between
        consecutive entries). Returns a zero loss when the trace is missing or
        has fewer than two entries along its first dimension.
        """
        device = _infer_device(model_output)
        gos_trace = model_output.get("gos_trace") if isinstance(model_output, dict) else None
        if not isinstance(gos_trace, dict):
            return _zero_loss(device)

        if "trace_vectors" not in gos_trace:
            # Fall back to trace_scores when trace_vectors is not available.
            if "trace_scores" not in gos_trace:
                return _zero_loss(device)
            # .float() so integer scores do not break mean().
            trace_scores = _as_tensor(gos_trace["trace_scores"], device=device).float()
            if trace_scores.dim() == 0 or trace_scores.shape[0] < 2:
                return _zero_loss(device)
            # Normalized delta: how large the relative change between layers is.
            mean_abs = trace_scores.detach().abs().mean().clamp(min=1e-6)
            delta = trace_scores[1:] - trace_scores[:-1]
            return (delta / mean_abs).pow(2).mean()

        # Use trace_vectors when available; presumably (n_layers, D) -- TODO confirm.
        trace_vecs = gos_trace["trace_vectors"]
        if (
            not isinstance(trace_vecs, torch.Tensor)
            or trace_vecs.dim() == 0  # shape[0] would raise on a 0-dim tensor
            or trace_vecs.shape[0] < 2
        ):
            return _zero_loss(device)

        # Cosine similarity between consecutive rows. After the clamp the value
        # lies in [0, 1]: cosine similarity can be negative, and clamping to 0
        # penalizes only rows that are too similar.
        flat = trace_vecs.reshape(trace_vecs.shape[0], -1).float()
        unit = F.normalize(flat, dim=-1, eps=1e-8)
        cos_sim = (unit[:-1] * unit[1:]).sum(dim=-1)
        return cos_sim.clamp(min=0.0).mean()


class MultiStateMarginLoss:
    """Separate positive vs negative or stable vs unstable state traces."""

    def __call__(self, model_output: Dict, targets: Optional[Dict] = None) -> torch.Tensor:
        """Return ``relu(1 - (positive_score - negative_score))``.

        Scores are read from ``model_output``; a zero loss is returned when
        either is missing. Tensor scores are used unchanged and the result
        keeps their (broadcast) shape. Non-tensor scores (e.g. Python floats)
        are converted with ``_as_tensor`` instead of raising.
        """
        is_dict = isinstance(model_output, dict)
        positive = model_output.get("positive_score") if is_dict else None
        negative = model_output.get("negative_score") if is_dict else None
        if positive is None or negative is None:
            return _zero_loss(_infer_device(model_output))

        device = _infer_device(model_output)
        if not isinstance(positive, torch.Tensor):
            positive = _as_tensor(positive, device=device)
        if not isinstance(negative, torch.Tensor):
            negative = _as_tensor(negative, device=device)
        margin = 1.0 - (positive.float() - negative.float())
        return F.relu(margin)


def build_state_objective_bundle(config: Optional[Dict] = None) -> CurriculumObjectiveBundle:
    """Build the objective bundle from config.

    Weights are read from ``config["objectives"]``; missing core objectives
    default to ``1.0`` and unknown keys are stored in ``extra``. A default
    bundle is returned when ``objectives`` is absent or not a dict.
    """
    config = config or {}
    objectives = config.get("objectives", {})
    if not isinstance(objectives, dict):
        return CurriculumObjectiveBundle()

    core = {name: float(objectives.get(name, 1.0)) for name in _CORE_OBJECTIVES}
    extra = {k: float(v) for k, v in objectives.items() if k not in _CORE_OBJECTIVES}
    return CurriculumObjectiveBundle(extra=extra, **core)


def compute_state_objective_loss(
    bundle: CurriculumObjectiveBundle,
    model_output: Dict[str, Any],
    targets: Optional[Dict[str, Any]] = None,
) -> tuple[torch.Tensor, Dict[str, float]]:
    """Combine all objective terms into a single scalar loss.

    Returns:
        ``(total, parts)`` where ``total`` is the weighted sum of all terms
        (a 0-dim tensor) and ``parts`` maps each objective name to its
        unweighted value as a Python float.
    """
    device = _infer_device(model_output)
    terms = {
        "state_consistency": StateAlignmentLoss(),
        "constraint_satisfaction": ConstraintSatisfactionLoss(),
        "semantic_alignment": SemanticBindingLoss(),
        "gos_coherence": GOSCoherenceLoss(),
        "multi_state_margin": MultiStateMarginLoss(),
    }

    losses: Dict[str, torch.Tensor] = {}
    for name, objective in terms.items():
        loss = objective(model_output, targets)
        # A term may come back per-sample (e.g. batched margin scores); reduce
        # it so the total stays scalar and .item() below is always valid.
        if loss.dim() > 0:
            loss = loss.mean()
        losses[name] = loss

    weights = bundle.weights()
    # Start from a grad-requiring zero so backward() works even if no term
    # carries a gradient.
    total = _zero_loss(device)
    for name, loss in losses.items():
        total = total + loss * float(weights.get(name, 1.0))
    return total, {name: float(loss.detach().cpu().item()) for name, loss in losses.items()}