""" AksaraTrainer - Training loop untuk AKSARA reasoning-only. Mengintegrasikan PD (Pengendali Dinamik) untuk adaptasi fokus loss evaluator. """ import os import time import json from dataclasses import dataclass, field from typing import Dict, List, Optional, Callable import torch import torch.nn as nn from torch.utils.data import DataLoader, Sampler from aksara.core.model import AksaraModel from aksara.training.pd import PengendaliDinamik, PDConfig from aksara.data.dataset import AksaraBatch, collate_fn from aksara.validation.validator import ValidationStatus from aksara.utils.metrics import AksaraMetrics @dataclass class TrainerConfig: output_dir: str = "aksara_output" num_epochs: int = 10 batch_size: int = 16 learning_rate: float = 1e-3 weight_decay: float = 0.01 warmup_steps: int = 100 max_grad_norm: float = 1.0 save_every_n_steps: int = 500 log_every_n_steps: int = 50 eval_every_n_steps: int = 200 device: str = "cuda" if torch.cuda.is_available() else "cpu" use_pd: bool = True pd_config: PDConfig = field(default_factory=PDConfig) fp16: bool = False seed: int = 42 use_curriculum: bool = True curriculum_boundary: int = 10 class BalancedPairSampler(Sampler): """Sampler yang menjamin setiap item batch memuat pasangan positif-negatif jika tersedia.""" def __init__(self, dataset, shuffle: bool = True, seed: int = 42): self.dataset = dataset self.shuffle = shuffle self.seed = int(seed) def __iter__(self): pairs = [] if hasattr(self.dataset, "build_balanced_index_pairs"): pairs = list(self.dataset.build_balanced_index_pairs()) if not pairs: indices = list(range(len(self.dataset))) if self.shuffle: generator = torch.Generator() generator.manual_seed(self.seed) perm = torch.randperm(len(indices), generator=generator).tolist() indices = [indices[i] for i in perm] for idx in indices: yield idx return if self.shuffle: generator = torch.Generator() generator.manual_seed(self.seed) perm = torch.randperm(len(pairs), generator=generator).tolist() pairs = [pairs[i] for i in perm] for pair in pairs: yield pair def __len__(self): pairs = [] if hasattr(self.dataset, "build_balanced_index_pairs"): pairs = list(self.dataset.build_balanced_index_pairs()) return len(pairs) if pairs else len(self.dataset) class AksaraTrainer: """ Training loop untuk AksaraModel. Mengintegrasikan: - CorrectnessLoss via AksaraModel - Pengendali Dinamik (PD) untuk adaptive lambda weights - Warmup + cosine LR schedule - Gradient clipping - Checkpoint saving """ def __init__( self, model: AksaraModel, train_dataset, eval_dataset=None, config: TrainerConfig = None, callbacks: Optional[List[Callable]] = None, ): self.model = model self.train_dataset = train_dataset self.eval_dataset = eval_dataset self.config = config or TrainerConfig() self.callbacks = callbacks or [] self.device = torch.device(self.config.device) self.model = self.model.to(self.device) self.optimizer = torch.optim.AdamW( self.model.parameters(), lr=self.config.learning_rate, weight_decay=self.config.weight_decay, ) self.pd = PengendaliDinamik(self.config.pd_config) if self.config.use_pd else None self.scaler = torch.cuda.amp.GradScaler() if self.config.fp16 and self.device.type == "cuda" else None self.global_step = 0 self.best_eval_loss = float("inf") self._train_losses: List[Dict] = [] self._curriculum_stage = 0 self.metrics = AksaraMetrics() self._last_epoch_summary: Dict[str, object] = {} self._guardrail_thresholds = { "total": 5.0, "score_mean": 0.25, "score_calibration": 0.7, "morph_accuracy": 0.8, "struct_accuracy": 0.75, "kbbi_alignment": 0.2, "sentence_semantic_coherence": 0.35, "anti_loop_score": 0.15, } os.makedirs(self.config.output_dir, exist_ok=True) torch.manual_seed(self.config.seed) self._train_sampler = BalancedPairSampler(self.train_dataset, shuffle=True, seed=self.config.seed) self._eval_sampler = ( BalancedPairSampler(self.eval_dataset, shuffle=False, seed=self.config.seed) if self.eval_dataset is not None else None ) def _get_lr(self, step: int) -> float: warmup = self.config.warmup_steps if step < warmup: return self.config.learning_rate * (step + 1) / warmup total_steps = max(1, self.config.num_epochs * len(self.train_dataset) // self.config.batch_size) progress = (step - warmup) / max(total_steps - warmup, 1) import math return self.config.learning_rate * (0.1 + 0.9 * 0.5 * (1.0 + math.cos(math.pi * progress))) def _set_lr(self, lr: float): for pg in self.optimizer.param_groups: pg["lr"] = lr def _build_dep_masks(self, batch: AksaraBatch) -> torch.Tensor: B = batch.morpheme_ids.size(0) L = batch.morpheme_ids.size(1) dep_masks = torch.zeros(B, L, L, dtype=torch.bool, device=self.device) for i in range(B): actual_len = int(batch.lengths[i].item()) dummy_tokens = ["_"] * actual_len mask_i = self.model.lps.build_dep_mask(dummy_tokens, L) mask_i[actual_len:, :] = False mask_i[:, actual_len:] = False dep_masks[i] = mask_i.to(self.device) return dep_masks def _sort_batch_by_length(self, batch: AksaraBatch) -> AksaraBatch: if not self.config.use_curriculum: return batch order = torch.argsort(batch.lengths, descending=False) return AksaraBatch( morpheme_ids=batch.morpheme_ids[order], affix_ids=batch.affix_ids[order], role_ids=batch.role_ids[order], attention_mask=batch.attention_mask[order], lengths=batch.lengths[order], texts=[batch.texts[i] for i in order.tolist()], labels=batch.labels[order], ) def _curriculum_maybe_advance(self, batch: AksaraBatch): if not self.config.use_curriculum: return max_len = int(batch.lengths.max().item()) if max_len > self.config.curriculum_boundary: self._curriculum_stage = 1 def _training_step(self, batch: AksaraBatch) -> Dict: self.model.train() batch = self._sort_batch_by_length(batch).to(self.device) self._curriculum_maybe_advance(batch) dep_masks = self._build_dep_masks(batch) lps_output = { "morpheme_ids": batch.morpheme_ids, "affix_ids": batch.affix_ids, "role_ids": batch.role_ids, "dep_masks": dep_masks, "attention_mask": batch.attention_mask, "lengths": batch.lengths, "max_len": batch.morpheme_ids.size(1), } labels = batch.labels if hasattr(batch, "labels") else None if labels is None: raise ValueError("AksaraTrainer membutuhkan batch.labels untuk training reasoning-only.") labels = labels.to(self.device).float() if self.scaler: with torch.cuda.amp.autocast(): outputs = self.model(lps_output, labels=labels) loss = outputs["losses"]["total"] self.scaler.scale(loss).backward() self.scaler.unscale_(self.optimizer) torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm) self.scaler.step(self.optimizer) self.scaler.update() else: outputs = self.model(lps_output, labels=labels) loss = outputs["losses"]["total"] loss.backward() torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm) self.optimizer.step() self.optimizer.zero_grad() if self.pd: self.pd.step_update(outputs["losses"], optimizer=self.optimizer) step_losses = {k: v.item() if torch.is_tensor(v) else v for k, v in outputs["losses"].items()} scores = outputs.get("scores", {}) if isinstance(scores, dict) and scores: total_score = scores.get("total") if torch.is_tensor(total_score): step_losses["score_mean"] = float(total_score.mean().item()) step_losses["score_min"] = float(total_score.min().item()) step_losses["score_max"] = float(total_score.max().item()) elif isinstance(total_score, (int, float)): total_score = float(total_score) step_losses["score_mean"] = total_score step_losses["score_min"] = total_score step_losses["score_max"] = total_score metric_source = { "affix_ids": batch.affix_ids.detach(), "role_ids": batch.role_ids.detach(), } semantic_slots = outputs.get("meb_out") if semantic_slots is None: semantic_slots = batch.morpheme_ids.float().unsqueeze(-1) kbbi_anchors = outputs.get("kbbi_anchors") if kbbi_anchors is None: kbbi_anchors = semantic_slots kbbi_mask = outputs.get("kbbi_mask") self.metrics.update( targets=metric_source, semantic_slots=semantic_slots, kbbi_anchors=kbbi_anchors, attention_mask=batch.attention_mask, score_total=scores.get("total") if isinstance(scores, dict) else None, labels=labels, losses=outputs.get("losses"), kbbi_mask=kbbi_mask, ) metric_summary = self.metrics.compute() step_losses["morph_accuracy"] = metric_summary.get("morph_accuracy") step_losses["struct_accuracy"] = metric_summary.get("struct_accuracy") step_losses["kbbi_alignment"] = metric_summary.get("kbbi_alignment") step_losses["score_calibration"] = metric_summary.get("score_calibration") step_losses["sentence_semantic_coherence"] = metric_summary.get("sentence_semantic_coherence", 0.0) step_losses["anti_loop_score"] = metric_summary.get("anti_loop_score", 0.0) step_losses["novelty_score"] = metric_summary.get("novelty_score", 0.0) step_losses["generalization_score"] = metric_summary.get("generalization_score", 0.0) step_losses["reasoning_score"] = metric_summary.get("reasoning_score", 0.0) step_losses["meaning_discriminator"] = metric_summary.get("meaning_discriminator", 0.0) step_losses["global_constraint"] = metric_summary.get("global_constraint", 0.0) return step_losses @torch.no_grad() def _eval_step(self, batch: AksaraBatch) -> Dict: self.model.eval() batch = batch.to(self.device) dep_masks = self._build_dep_masks(batch) lps_output = { "morpheme_ids": batch.morpheme_ids, "affix_ids": batch.affix_ids, "role_ids": batch.role_ids, "dep_masks": dep_masks, "attention_mask": batch.attention_mask, "lengths": batch.lengths, "max_len": batch.morpheme_ids.size(1), } labels = batch.labels if hasattr(batch, "labels") else None if labels is None: raise ValueError("AksaraTrainer membutuhkan batch.labels untuk evaluasi reasoning-only.") labels = labels.to(self.device).float() outputs = self.model(lps_output, labels=labels) step_losses = {k: v.item() if torch.is_tensor(v) else v for k, v in outputs["losses"].items()} scores = outputs.get("scores", {}) if isinstance(scores, dict) and scores: total_score = scores.get("total") if torch.is_tensor(total_score): step_losses["score_mean"] = float(total_score.mean().item()) step_losses["score_min"] = float(total_score.min().item()) step_losses["score_max"] = float(total_score.max().item()) elif isinstance(total_score, (int, float)): total_score = float(total_score) step_losses["score_mean"] = total_score step_losses["score_min"] = total_score step_losses["score_max"] = total_score metric_source = { "affix_ids": batch.affix_ids.detach(), "role_ids": batch.role_ids.detach(), } semantic_slots = outputs.get("meb_out") if semantic_slots is None: semantic_slots = batch.morpheme_ids.float().unsqueeze(-1) kbbi_anchors = outputs.get("kbbi_anchors") if kbbi_anchors is None: kbbi_anchors = semantic_slots kbbi_mask = outputs.get("kbbi_mask") self.metrics.update( targets=metric_source, semantic_slots=semantic_slots, kbbi_anchors=kbbi_anchors, attention_mask=batch.attention_mask, score_total=scores.get("total") if isinstance(scores, dict) else None, labels=labels, losses=outputs.get("losses"), kbbi_mask=kbbi_mask, ) metric_summary = self.metrics.compute() step_losses["morph_accuracy"] = metric_summary.get("morph_accuracy") step_losses["struct_accuracy"] = metric_summary.get("struct_accuracy") step_losses["kbbi_alignment"] = metric_summary.get("kbbi_alignment") step_losses["score_calibration"] = metric_summary.get("score_calibration") step_losses["sentence_semantic_coherence"] = metric_summary.get("sentence_semantic_coherence", 0.0) step_losses["anti_loop_score"] = metric_summary.get("anti_loop_score", 0.0) step_losses["novelty_score"] = metric_summary.get("novelty_score", 0.0) step_losses["generalization_score"] = metric_summary.get("generalization_score", 0.0) step_losses["reasoning_score"] = metric_summary.get("reasoning_score", 0.0) step_losses["meaning_discriminator"] = metric_summary.get("meaning_discriminator", 0.0) step_losses["global_constraint"] = metric_summary.get("global_constraint", 0.0) return step_losses def _format_guardrail_alerts(self, summary: Dict[str, object]) -> List[str]: alerts = summary.get("guardrail_alerts") or [] if not isinstance(alerts, list): return [] return [str(a) for a in alerts] def _emit_guardrail_alerts(self, epoch: int, summary: Dict[str, object]): alerts = self._format_guardrail_alerts(summary) if not alerts: print(f"[Epoch {epoch+1}] Guardrail: OK") return print(f"[Epoch {epoch+1}] Guardrail alert ({len(alerts)}): " + " | ".join(alerts)) def train(self): train_loader = DataLoader( self.train_dataset, batch_size=self.config.batch_size, sampler=self._train_sampler, collate_fn=collate_fn, num_workers=0, ) eval_loader = None if self.eval_dataset is not None and len(self.eval_dataset) == 0: raise ValueError("Eval dataset kosong setelah preprocessing/filtering. Hentikan training dan perbaiki pipeline data split.") if self.eval_dataset: eval_loader = DataLoader( self.eval_dataset, batch_size=self.config.batch_size, sampler=self._eval_sampler, collate_fn=collate_fn, num_workers=0, ) print(f"[AKSARA] Mulai training — {self.model.num_parameters['trainable']:,} parameter") print(f"[AKSARA] Device: {self.device} | Epochs: {self.config.num_epochs}") print(f"[AKSARA] KBBI coverage: {self.model.lsk.kbbi_coverage:.1%}") for epoch in range(self.config.num_epochs): epoch_losses = [] t0 = time.time() self.metrics.reset() for batch in train_loader: lr = self._get_lr(self.global_step) self._set_lr(lr) step_losses = self._training_step(batch) epoch_losses.append(step_losses) self._train_losses.append(step_losses) self.global_step += 1 if self.global_step % self.config.log_every_n_steps == 0: self._log_step(step_losses, epoch, lr) if eval_loader and self.global_step % self.config.eval_every_n_steps == 0: eval_loss = self._evaluate(eval_loader) print(f"[Eval step {self.global_step}] total={eval_loss['total']:.4f}") self._safe_checkpoint(self.global_step, eval_loss) if eval_loss["total"] < self.best_eval_loss: self.best_eval_loss = eval_loss["total"] self._safe_checkpoint("best", eval_loss) if self.global_step % self.config.save_every_n_steps == 0: self._safe_checkpoint(f"step_{self.global_step}") for cb in self.callbacks: cb(self, step_losses, self.global_step) avg = self._average_losses(epoch_losses) epoch_summary = self.metrics.epoch_summary() epoch_summary.update(avg) self._last_epoch_summary = epoch_summary elapsed = time.time() - t0 print( f"\n[Epoch {epoch+1}/{self.config.num_epochs}] " f"total={epoch_summary.get('total', avg.get('total', 0)):.4f} | " f"binary={epoch_summary.get('l_binary', avg.get('l_binary', 0)):.4f} | " f"margin={epoch_summary.get('l_margin', avg.get('l_margin', 0)):.4f} | " f"hard_neg={epoch_summary.get('l_hard_neg', avg.get('l_hard_neg', 0)):.4f} | " f"consist={epoch_summary.get('l_consist', avg.get('l_consist', 0)):.4f} | " f"confidence={epoch_summary.get('l_confidence', avg.get('l_confidence', 0)):.4f} | " f"score_mean={epoch_summary.get('score_mean', avg.get('score_mean', 0)):.4f} | " f"morph_acc={epoch_summary.get('morph_accuracy', 0):.4f} | " f"struct_acc={(epoch_summary.get('struct_accuracy') or 0):.4f} | " f"kbbi_align={epoch_summary.get('kbbi_alignment', 0):.4f} | " f"score_cal={epoch_summary.get('score_calibration', 0):.4f} | " f"sem_coh={epoch_summary.get('sentence_semantic_coherence', 0):.4f} | " f"anti_loop={epoch_summary.get('anti_loop_score', 0):.4f} | " f"batches={int(epoch_summary.get('epoch_batches', 0))} | " f"samples={int(epoch_summary.get('epoch_samples', 0))} | " f"{elapsed:.1f}s" ) self._emit_guardrail_alerts(epoch, epoch_summary) if self.pd: diag = self.pd.get_diagnostics() print( f" PD λ: morph={diag['lambdas']['morph']:.3f} " f"struct={diag['lambdas']['struct']:.3f} " f"sem={diag['lambdas']['sem']:.3f} " f"ctx={diag['lambdas']['ctx']:.3f}" ) if self.best_eval_loss == float("inf") and eval_loader: self.best_eval_loss = self._evaluate(eval_loader)["total"] self._safe_checkpoint("best", {"total": self.best_eval_loss}) self._safe_checkpoint("final") print(f"\n[AKSARA] Training selesai. Best eval loss: {self.best_eval_loss:.4f}") return self._train_losses def _evaluate(self, loader: DataLoader) -> Dict: all_losses = [] for batch in loader: losses = self._eval_step(batch) all_losses.append(losses) return self._average_losses(all_losses) @staticmethod def _average_losses(losses_list: List[Dict]) -> Dict: if not losses_list: return {} keys = [k for k in losses_list[0] if k != "lambdas"] return {k: sum(d.get(k, 0) for d in losses_list) / len(losses_list) for k in keys} def _log_step(self, losses: Dict, epoch: int, lr: float): pd_info = "" if self.pd: lam = self.pd.get_lambdas() pd_info = f" | λ_m={lam['morph']:.2f} λ_s={lam['struct']:.2f} λ_k={lam['sem']:.2f} λ_c={lam['ctx']:.2f}" def _fmt(value, default=0.0): try: return float(value if value is not None else default) except (TypeError, ValueError): return default print( f"[Step {self.global_step} | E{epoch+1}] " f"total={_fmt(losses.get('total')):.4f} " f"binary={_fmt(losses.get('l_binary')):.4f} " f"margin={_fmt(losses.get('l_margin')):.4f} " f"hard_neg={_fmt(losses.get('l_hard_neg')):.4f} " f"consist={_fmt(losses.get('l_consist')):.4f} " f"confidence={_fmt(losses.get('l_confidence')):.4f} " f"score_mean={_fmt(losses.get('score_mean')):.4f} " f"sem_coh={_fmt(losses.get('sentence_semantic_coherence')):.4f} " f"anti_loop={_fmt(losses.get('anti_loop_score')):.4f} " f"lr={lr:.2e}{pd_info}" ) def _safe_checkpoint(self, tag: str, eval_loss: Optional[Dict] = None): path = os.path.join(self.config.output_dir, f"checkpoint_{tag}") try: self.model.save(path) if self.pd: diag = self.pd.get_diagnostics() with open(os.path.join(path, "pd_state.json"), "w", encoding="utf-8") as f: json.dump(diag, f, indent=2) if eval_loss is not None and isinstance(eval_loss, dict): with open(os.path.join(path, "eval_loss.json"), "w", encoding="utf-8") as f: json.dump(eval_loss, f, indent=2) except (RuntimeError, OSError) as exc: print(f"[AKSARA] Warning: checkpoint '{tag}' gagal disimpan: {exc}") return False return True