| """ |
| AksaraTrainer - Training loop untuk AKSARA reasoning-only. |
| Mengintegrasikan PD (Pengendali Dinamik) untuk adaptasi fokus loss evaluator. |
| """ |
|
|
| import os |
| import time |
| import json |
| from dataclasses import dataclass, field |
| from typing import Dict, List, Optional, Callable |
|
|
| import torch |
| import torch.nn as nn |
| from torch.utils.data import DataLoader, Sampler |
|
|
| from aksara.core.model import AksaraModel |
| from aksara.training.pd import PengendaliDinamik, PDConfig |
| from aksara.data.dataset import AksaraBatch, collate_fn |
| from aksara.validation.validator import ValidationStatus |
| from aksara.utils.metrics import AksaraMetrics |
|
|
|
|
@dataclass
class TrainerConfig:
    """Hyperparameters and runtime switches consumed by ``AksaraTrainer``."""

    # Where checkpoints are written (created by the trainer if missing).
    output_dir: str = "aksara_output"

    # Schedule length and batching.
    num_epochs: int = 10
    batch_size: int = 16

    # AdamW optimizer and LR schedule settings.
    learning_rate: float = 1e-3
    weight_decay: float = 0.01
    warmup_steps: int = 100
    max_grad_norm: float = 1.0

    # Step cadences (measured in optimizer steps).
    save_every_n_steps: int = 500
    log_every_n_steps: int = 50
    eval_every_n_steps: int = 200

    # Evaluated once, when this class body executes (i.e. at import time).
    device: str = "cpu" if not torch.cuda.is_available() else "cuda"

    # Pengendali Dinamik (adaptive loss weighting); each config gets its own PDConfig.
    use_pd: bool = True
    pd_config: PDConfig = field(default_factory=PDConfig)

    # Mixed precision; the trainer only enables it on CUDA devices.
    fp16: bool = False
    seed: int = 42

    # Length-based curriculum settings.
    use_curriculum: bool = True
    curriculum_boundary: int = 10
|
|
|
|
class BalancedPairSampler(Sampler):
    """Sampler that yields balanced positive/negative pairs when the dataset provides them.

    If the dataset exposes ``build_balanced_index_pairs()`` and it returns a
    non-empty sequence, each yielded item is one element of that sequence
    (presumably a ``(positive_idx, negative_idx)`` pair -- confirm against the
    dataset). Otherwise plain integer indices ``0 .. len(dataset) - 1`` are yielded.

    When ``shuffle`` is True the order is permuted by a ``torch.Generator``
    seeded with ``seed + epoch``. The epoch counter advances on every iteration
    (and can be set explicitly with :meth:`set_epoch`), so each epoch gets a
    different but reproducible order. The first iteration uses ``seed`` itself.
    """

    def __init__(self, dataset, shuffle: bool = True, seed: int = 42):
        self.dataset = dataset
        self.shuffle = shuffle
        self.seed = int(seed)
        # Offsets the shuffle seed so consecutive epochs are not identical.
        self._epoch = 0

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch whose seed (``seed + epoch``) the next iteration uses."""
        self._epoch = int(epoch)

    def _load_pairs(self) -> list:
        """Return the dataset's balanced pairs, or an empty list if unsupported."""
        if hasattr(self.dataset, "build_balanced_index_pairs"):
            return list(self.dataset.build_balanced_index_pairs())
        return []

    def _ordered(self, items: list, epoch: int) -> list:
        """Return *items* permuted for *epoch* when shuffling, otherwise unchanged."""
        if not self.shuffle:
            return items
        generator = torch.Generator()
        generator.manual_seed(self.seed + epoch)
        perm = torch.randperm(len(items), generator=generator).tolist()
        return [items[i] for i in perm]

    def __iter__(self):
        epoch = self._epoch
        self._epoch += 1
        pairs = self._load_pairs()
        # Fall back to plain indices when no pairs are available.
        items = pairs if pairs else list(range(len(self.dataset)))
        yield from self._ordered(items, epoch)

    def __len__(self):
        pairs = self._load_pairs()
        return len(pairs) if pairs else len(self.dataset)
|
|
|
|
class AksaraTrainer:
    """
    Training loop for AksaraModel (reasoning-only).

    Integrates:
    - CorrectnessLoss via AksaraModel (read from ``outputs["losses"]``)
    - Pengendali Dinamik (PD) for adaptive lambda weights
    - Linear warmup followed by cosine decay down to 10% of the base LR
    - Gradient clipping (with an AMP GradScaler when fp16 runs on CUDA)
    - Checkpoint saving

    Training and evaluation accumulate metrics in separate ``AksaraMetrics``
    instances, so periodic evaluation does not leak into the training epoch summary.
    """

    # Metric keys copied from ``AksaraMetrics.compute()`` into each step's stats.
    # These are stored as returned (they may be None).
    _METRIC_KEYS_NO_DEFAULT = (
        "morph_accuracy",
        "struct_accuracy",
        "kbbi_alignment",
        "score_calibration",
    )
    # These fall back to 0.0 when absent from the metric summary.
    _METRIC_KEYS_ZERO_DEFAULT = (
        "sentence_semantic_coherence",
        "anti_loop_score",
        "novelty_score",
        "generalization_score",
        "reasoning_score",
        "meaning_discriminator",
        "global_constraint",
    )

    def __init__(
        self,
        model: AksaraModel,
        train_dataset,
        eval_dataset=None,
        config: TrainerConfig = None,
        callbacks: Optional[List[Callable]] = None,
    ):
        self.model = model
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.config = config or TrainerConfig()
        self.callbacks = callbacks or []

        self.device = torch.device(self.config.device)
        self.model = self.model.to(self.device)

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay,
        )

        self.pd = PengendaliDinamik(self.config.pd_config) if self.config.use_pd else None
        # Loss scaling is only used for fp16 on CUDA.
        self.scaler = torch.cuda.amp.GradScaler() if self.config.fp16 and self.device.type == "cuda" else None

        self.global_step = 0
        self.best_eval_loss = float("inf")
        self._train_losses: List[Dict] = []
        self._curriculum_stage = 0
        self.metrics = AksaraMetrics()
        # Evaluation gets its own accumulator; it is reset at the start of every _evaluate().
        self._eval_metrics = AksaraMetrics()
        # Schedule length (optimizer steps); computed in train() and read by _get_lr().
        self._total_steps: Optional[int] = None
        self._last_epoch_summary: Dict[str, object] = {}
        # NOTE(review): these thresholds are not read anywhere in this class; the alerts
        # printed per epoch come from the metrics summary's "guardrail_alerts" entry.
        self._guardrail_thresholds = {
            "total": 5.0,
            "score_mean": 0.25,
            "score_calibration": 0.7,
            "morph_accuracy": 0.8,
            "struct_accuracy": 0.75,
            "kbbi_alignment": 0.2,
            "sentence_semantic_coherence": 0.35,
            "anti_loop_score": 0.15,
        }

        os.makedirs(self.config.output_dir, exist_ok=True)
        torch.manual_seed(self.config.seed)
        self._train_sampler = BalancedPairSampler(self.train_dataset, shuffle=True, seed=self.config.seed)
        self._eval_sampler = (
            BalancedPairSampler(self.eval_dataset, shuffle=False, seed=self.config.seed)
            if self.eval_dataset is not None
            else None
        )

    def _estimate_total_steps(self) -> int:
        """Return the number of optimizer steps over all epochs (at least 1).

        Uses the sampler length, which counts pairs when the dataset provides
        them, and rounds up because the DataLoader keeps the last partial batch.
        """
        import math

        steps_per_epoch = math.ceil(len(self._train_sampler) / self.config.batch_size)
        return max(1, self.config.num_epochs * steps_per_epoch)

    def _get_lr(self, step: int) -> float:
        """Return the learning rate for the 0-based optimizer *step*.

        The LR ramps linearly up to ``learning_rate`` over ``warmup_steps``. It then
        follows a cosine from ``learning_rate`` down to ``0.1 * learning_rate`` at
        the final step, and stays at that floor afterwards.
        """
        import math

        base_lr = self.config.learning_rate
        warmup = self.config.warmup_steps
        if step < warmup:
            return base_lr * (step + 1) / warmup
        total_steps = getattr(self, "_total_steps", None)
        if total_steps is None:
            total_steps = self._estimate_total_steps()
        progress = (step - warmup) / max(total_steps - warmup, 1)
        # Clamp so the cosine never wraps back up past the end of the schedule.
        progress = min(max(progress, 0.0), 1.0)
        return base_lr * (0.1 + 0.9 * 0.5 * (1.0 + math.cos(math.pi * progress)))

    def _set_lr(self, lr: float):
        """Apply *lr* to every optimizer parameter group."""
        for pg in self.optimizer.param_groups:
            pg["lr"] = lr

    def _build_dep_masks(self, batch: AksaraBatch) -> torch.Tensor:
        """Build a boolean ``(B, L, L)`` dependency mask, zeroed outside each sample's length.

        The per-sample mask comes from ``self.model.lps.build_dep_mask``, called with
        placeholder tokens of the sample's actual length.
        """
        B = batch.morpheme_ids.size(0)
        L = batch.morpheme_ids.size(1)
        dep_masks = torch.zeros(B, L, L, dtype=torch.bool, device=self.device)

        for i, length in enumerate(batch.lengths.tolist()):
            actual_len = int(length)
            dummy_tokens = ["_"] * actual_len
            mask_i = self.model.lps.build_dep_mask(dummy_tokens, L)
            # Block every position past the real sequence length.
            mask_i[actual_len:, :] = False
            mask_i[:, actual_len:] = False
            dep_masks[i] = mask_i.to(self.device)

        return dep_masks

    def _sort_batch_by_length(self, batch: AksaraBatch) -> AksaraBatch:
        """Reorder the samples inside *batch* by ascending length when curriculum is enabled.

        NOTE(review): this reorders samples within a batch. If the loss relies on
        positive/negative pairs staying adjacent, confirm the reordering is harmless.
        Only the fields listed below are carried over to the new AksaraBatch.
        """
        if not self.config.use_curriculum:
            return batch
        order = torch.argsort(batch.lengths, descending=False)
        return AksaraBatch(
            morpheme_ids=batch.morpheme_ids[order],
            affix_ids=batch.affix_ids[order],
            role_ids=batch.role_ids[order],
            attention_mask=batch.attention_mask[order],
            lengths=batch.lengths[order],
            texts=[batch.texts[i] for i in order.tolist()],
            labels=batch.labels[order],
        )

    def _curriculum_maybe_advance(self, batch: AksaraBatch):
        """Move to curriculum stage 1 once a batch exceeds ``curriculum_boundary`` in length."""
        if not self.config.use_curriculum:
            return
        max_len = int(batch.lengths.max().item())
        if max_len > self.config.curriculum_boundary:
            self._curriculum_stage = 1

    def _build_lps_output(self, batch: AksaraBatch) -> Dict:
        """Assemble the LPS-style input dict the model's forward pass expects."""
        return {
            "morpheme_ids": batch.morpheme_ids,
            "affix_ids": batch.affix_ids,
            "role_ids": batch.role_ids,
            "dep_masks": self._build_dep_masks(batch),
            "attention_mask": batch.attention_mask,
            "lengths": batch.lengths,
            "max_len": batch.morpheme_ids.size(1),
        }

    def _require_labels(self, batch: AksaraBatch, message: str) -> torch.Tensor:
        """Return ``batch.labels`` as a float tensor on the device.

        Raises:
            ValueError: with *message* when the batch has no labels.
        """
        labels = batch.labels if hasattr(batch, "labels") else None
        if labels is None:
            raise ValueError(message)
        return labels.to(self.device).float()

    def _collect_step_stats(self, outputs: Dict, batch: AksaraBatch, labels: torch.Tensor, metrics) -> Dict:
        """Convert model outputs into a flat stats dict and update *metrics*.

        Tensor losses become Python numbers. ``scores["total"]``, when present, is
        summarised as mean/min/max. The metric values come from the cumulative
        ``metrics.compute()`` result.
        """
        step_losses = {k: v.item() if torch.is_tensor(v) else v for k, v in outputs["losses"].items()}
        scores = outputs.get("scores", {})
        if isinstance(scores, dict) and scores:
            total_score = scores.get("total")
            if torch.is_tensor(total_score):
                step_losses["score_mean"] = float(total_score.mean().item())
                step_losses["score_min"] = float(total_score.min().item())
                step_losses["score_max"] = float(total_score.max().item())
            elif isinstance(total_score, (int, float)):
                total_score = float(total_score)
                step_losses["score_mean"] = total_score
                step_losses["score_min"] = total_score
                step_losses["score_max"] = total_score

        # Fallbacks when the model does not expose semantic slots / KBBI anchors.
        semantic_slots = outputs.get("meb_out")
        if semantic_slots is None:
            semantic_slots = batch.morpheme_ids.float().unsqueeze(-1)
        kbbi_anchors = outputs.get("kbbi_anchors")
        if kbbi_anchors is None:
            kbbi_anchors = semantic_slots
        metrics.update(
            targets={
                "affix_ids": batch.affix_ids.detach(),
                "role_ids": batch.role_ids.detach(),
            },
            semantic_slots=semantic_slots,
            kbbi_anchors=kbbi_anchors,
            attention_mask=batch.attention_mask,
            score_total=scores.get("total") if isinstance(scores, dict) else None,
            labels=labels,
            losses=outputs.get("losses"),
            kbbi_mask=outputs.get("kbbi_mask"),
        )

        metric_summary = metrics.compute()
        for key in self._METRIC_KEYS_NO_DEFAULT:
            step_losses[key] = metric_summary.get(key)
        for key in self._METRIC_KEYS_ZERO_DEFAULT:
            step_losses[key] = metric_summary.get(key, 0.0)
        return step_losses

    def _training_step(self, batch: AksaraBatch) -> Dict:
        """Run one forward/backward/optimizer step and return the step statistics."""
        self.model.train()
        batch = self._sort_batch_by_length(batch).to(self.device)
        self._curriculum_maybe_advance(batch)

        lps_output = self._build_lps_output(batch)
        labels = self._require_labels(
            batch, "AksaraTrainer membutuhkan batch.labels untuk training reasoning-only."
        )

        if self.scaler:
            with torch.cuda.amp.autocast():
                outputs = self.model(lps_output, labels=labels)
                loss = outputs["losses"]["total"]
            self.scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient magnitudes.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            outputs = self.model(lps_output, labels=labels)
            loss = outputs["losses"]["total"]
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
            self.optimizer.step()

        self.optimizer.zero_grad()

        if self.pd:
            self.pd.step_update(outputs["losses"], optimizer=self.optimizer)

        return self._collect_step_stats(outputs, batch, labels, self.metrics)

    @torch.no_grad()
    def _eval_step(self, batch: AksaraBatch) -> Dict:
        """Run a gradient-free forward pass and return step statistics.

        Metrics accumulate into the evaluation-only accumulator.
        """
        self.model.eval()
        batch = batch.to(self.device)

        lps_output = self._build_lps_output(batch)
        labels = self._require_labels(
            batch, "AksaraTrainer membutuhkan batch.labels untuk evaluasi reasoning-only."
        )
        outputs = self.model(lps_output, labels=labels)
        return self._collect_step_stats(outputs, batch, labels, self._eval_metrics)

    def _format_guardrail_alerts(self, summary: Dict[str, object]) -> List[str]:
        """Return the summary's ``guardrail_alerts`` as strings ([] if missing or not a list)."""
        alerts = summary.get("guardrail_alerts") or []
        if not isinstance(alerts, list):
            return []
        return [str(a) for a in alerts]

    def _emit_guardrail_alerts(self, epoch: int, summary: Dict[str, object]):
        """Print either an OK line or the joined guardrail alerts for *epoch*."""
        alerts = self._format_guardrail_alerts(summary)
        if not alerts:
            print(f"[Epoch {epoch+1}] Guardrail: OK")
            return
        print(f"[Epoch {epoch+1}] Guardrail alert ({len(alerts)}): " + " | ".join(alerts))

    def train(self):
        """Run the full training loop and return the list of per-step statistics.

        Raises:
            ValueError: if an eval dataset was given but is empty.
        """
        train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.config.batch_size,
            sampler=self._train_sampler,
            collate_fn=collate_fn,
            num_workers=0,
        )
        self._total_steps = self._estimate_total_steps()

        eval_loader = None
        if self.eval_dataset is not None and len(self.eval_dataset) == 0:
            raise ValueError("Eval dataset kosong setelah preprocessing/filtering. Hentikan training dan perbaiki pipeline data split.")
        if self.eval_dataset is not None:
            eval_loader = DataLoader(
                self.eval_dataset,
                batch_size=self.config.batch_size,
                sampler=self._eval_sampler,
                collate_fn=collate_fn,
                num_workers=0,
            )

        print(f"[AKSARA] Mulai training — {self.model.num_parameters['trainable']:,} parameter")
        print(f"[AKSARA] Device: {self.device} | Epochs: {self.config.num_epochs}")
        print(f"[AKSARA] KBBI coverage: {self.model.lsk.kbbi_coverage:.1%}")

        for epoch in range(self.config.num_epochs):
            epoch_losses = []
            t0 = time.time()
            self.metrics.reset()

            for batch in train_loader:
                lr = self._get_lr(self.global_step)
                self._set_lr(lr)

                step_losses = self._training_step(batch)
                epoch_losses.append(step_losses)
                self._train_losses.append(step_losses)
                self.global_step += 1

                if self.global_step % self.config.log_every_n_steps == 0:
                    self._log_step(step_losses, epoch, lr)

                if eval_loader is not None and self.global_step % self.config.eval_every_n_steps == 0:
                    eval_loss = self._evaluate(eval_loader)
                    print(f"[Eval step {self.global_step}] total={eval_loss['total']:.4f}")
                    self._safe_checkpoint(self.global_step, eval_loss)
                    if eval_loss["total"] < self.best_eval_loss:
                        self.best_eval_loss = eval_loss["total"]
                        self._safe_checkpoint("best", eval_loss)

                if self.global_step % self.config.save_every_n_steps == 0:
                    self._safe_checkpoint(f"step_{self.global_step}")

                for cb in self.callbacks:
                    cb(self, step_losses, self.global_step)

            avg = self._average_losses(epoch_losses)
            epoch_summary = self.metrics.epoch_summary()
            # Averaged step losses take precedence over same-named metric entries.
            epoch_summary.update(avg)
            self._last_epoch_summary = epoch_summary
            elapsed = time.time() - t0

            def _val(key: str) -> float:
                # None-safe: metrics such as struct_accuracy may be None.
                return self._as_float(epoch_summary.get(key))

            print(
                f"\n[Epoch {epoch+1}/{self.config.num_epochs}] "
                f"total={_val('total'):.4f} | "
                f"binary={_val('l_binary'):.4f} | "
                f"margin={_val('l_margin'):.4f} | "
                f"hard_neg={_val('l_hard_neg'):.4f} | "
                f"consist={_val('l_consist'):.4f} | "
                f"confidence={_val('l_confidence'):.4f} | "
                f"score_mean={_val('score_mean'):.4f} | "
                f"morph_acc={_val('morph_accuracy'):.4f} | "
                f"struct_acc={_val('struct_accuracy'):.4f} | "
                f"kbbi_align={_val('kbbi_alignment'):.4f} | "
                f"score_cal={_val('score_calibration'):.4f} | "
                f"sem_coh={_val('sentence_semantic_coherence'):.4f} | "
                f"anti_loop={_val('anti_loop_score'):.4f} | "
                f"batches={int(_val('epoch_batches'))} | "
                f"samples={int(_val('epoch_samples'))} | "
                f"{elapsed:.1f}s"
            )

            self._emit_guardrail_alerts(epoch, epoch_summary)

            if self.pd:
                diag = self.pd.get_diagnostics()
                print(
                    f" PD λ: morph={diag['lambdas']['morph']:.3f} "
                    f"struct={diag['lambdas']['struct']:.3f} "
                    f"sem={diag['lambdas']['sem']:.3f} "
                    f"ctx={diag['lambdas']['ctx']:.3f}"
                )

        # If no periodic evaluation ever ran, evaluate once so "best" exists.
        if self.best_eval_loss == float("inf") and eval_loader is not None:
            self.best_eval_loss = self._evaluate(eval_loader)["total"]
            self._safe_checkpoint("best", {"total": self.best_eval_loss})
        self._safe_checkpoint("final")
        print(f"\n[AKSARA] Training selesai. Best eval loss: {self.best_eval_loss:.4f}")
        return self._train_losses

    def _evaluate(self, loader: DataLoader) -> Dict:
        """Evaluate over *loader* using fresh evaluation metrics; return averaged stats."""
        self._eval_metrics.reset()
        all_losses = [self._eval_step(batch) for batch in loader]
        return self._average_losses(all_losses)

    @staticmethod
    def _as_float(value, default: float = 0.0) -> float:
        """Convert *value* to float, using *default* for None or unconvertible values."""
        try:
            return float(value if value is not None else default)
        except (TypeError, ValueError):
            return default

    @staticmethod
    def _average_losses(losses_list: List[Dict]) -> Dict:
        """Average each key of the first dict across *losses_list* (``"lambdas"`` excluded).

        A key missing from an entry counts as 0. Entries whose value is None or not
        convertible to float are skipped. A key with no numeric values maps to None.
        """
        if not losses_list:
            return {}
        keys = [k for k in losses_list[0] if k != "lambdas"]
        averaged: Dict = {}
        for key in keys:
            numeric = []
            for entry in losses_list:
                value = entry.get(key, 0)
                if value is None:
                    continue
                try:
                    numeric.append(float(value))
                except (TypeError, ValueError):
                    continue
            averaged[key] = sum(numeric) / len(numeric) if numeric else None
        return averaged

    def _log_step(self, losses: Dict, epoch: int, lr: float):
        """Print a one-line step summary, including the PD lambdas when PD is enabled."""
        pd_info = ""
        if self.pd:
            lam = self.pd.get_lambdas()
            pd_info = f" | λ_m={lam['morph']:.2f} λ_s={lam['struct']:.2f} λ_k={lam['sem']:.2f} λ_c={lam['ctx']:.2f}"

        num = self._as_float
        print(
            f"[Step {self.global_step} | E{epoch+1}] "
            f"total={num(losses.get('total')):.4f} "
            f"binary={num(losses.get('l_binary')):.4f} "
            f"margin={num(losses.get('l_margin')):.4f} "
            f"hard_neg={num(losses.get('l_hard_neg')):.4f} "
            f"consist={num(losses.get('l_consist')):.4f} "
            f"confidence={num(losses.get('l_confidence')):.4f} "
            f"score_mean={num(losses.get('score_mean')):.4f} "
            f"sem_coh={num(losses.get('sentence_semantic_coherence')):.4f} "
            f"anti_loop={num(losses.get('anti_loop_score')):.4f} "
            f"lr={lr:.2e}{pd_info}"
        )

    def _safe_checkpoint(self, tag: str, eval_loss: Optional[Dict] = None):
        """Save model, PD diagnostics and eval losses under ``output_dir/checkpoint_<tag>``.

        *tag* may be any value with a string form (the step number is passed as an int).
        Returns True on success. On I/O, runtime or JSON-serialization errors it prints
        a warning and returns False instead of aborting training.
        """
        tag = str(tag)
        path = os.path.join(self.config.output_dir, f"checkpoint_{tag}")
        try:
            self.model.save(path)
            if self.pd:
                diag = self.pd.get_diagnostics()
                with open(os.path.join(path, "pd_state.json"), "w", encoding="utf-8") as f:
                    json.dump(diag, f, indent=2)
            if eval_loss is not None and isinstance(eval_loss, dict):
                with open(os.path.join(path, "eval_loss.json"), "w", encoding="utf-8") as f:
                    json.dump(eval_loss, f, indent=2)
        except (RuntimeError, OSError, TypeError, ValueError) as exc:
            # TypeError/ValueError: json.dump on values it cannot serialize.
            print(f"[AKSARA] Warning: checkpoint '{tag}' gagal disimpan: {exc}")
            return False
        return True
|
|