AKSARA-CLM-v1 / aksara /utils /trainer.py
emylton's picture
Upload folder using huggingface_hub
9338a41 verified
Raw
History Blame Contribute Delete
22.5 kB
"""
AksaraTrainer - Training loop untuk AKSARA reasoning-only.
Mengintegrasikan PD (Pengendali Dinamik) untuk adaptasi fokus loss evaluator.
"""
import json
import math
import os
import time
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Callable

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Sampler

from aksara.core.model import AksaraModel
from aksara.training.pd import PengendaliDinamik, PDConfig
from aksara.data.dataset import AksaraBatch, collate_fn
from aksara.validation.validator import ValidationStatus
from aksara.utils.metrics import AksaraMetrics
@dataclass
class TrainerConfig:
    """Hyper-parameters and runtime switches consumed by ``AksaraTrainer``.

    Field order is part of the positional constructor signature; the
    comments below only group the fields by concern.
    """

    # -- output ----------------------------------------------------------
    # Checkpoints are written under this directory (created by the trainer).
    output_dir: str = "aksara_output"

    # -- optimisation ----------------------------------------------------
    num_epochs: int = 10
    batch_size: int = 16
    learning_rate: float = 1e-3  # peak LR: reached after warmup, then cosine-decayed
    weight_decay: float = 0.01  # AdamW weight decay
    warmup_steps: int = 100  # linear warmup length, in optimiser steps
    max_grad_norm: float = 1.0  # gradient-norm clipping threshold

    # -- cadence, counted in optimiser steps ------------------------------
    save_every_n_steps: int = 500
    log_every_n_steps: int = 50
    eval_every_n_steps: int = 200

    # -- runtime ---------------------------------------------------------
    # Evaluated once, when this class body executes (i.e. at import time).
    device: str = ("cuda" if torch.cuda.is_available() else "cpu")
    use_pd: bool = True  # enable Pengendali Dinamik updates after each step
    pd_config: PDConfig = field(default_factory=PDConfig)
    fp16: bool = False  # mixed precision; the trainer only honours it on CUDA
    seed: int = 42

    # -- curriculum ------------------------------------------------------
    use_curriculum: bool = True  # sort each batch by ascending length
    # Batches whose longest sample exceeds this length advance the curriculum stage.
    curriculum_boundary: int = 10
class BalancedPairSampler(Sampler):
    """Sampler that yields positive/negative index pairs when the dataset offers them.

    If ``dataset`` has a ``build_balanced_index_pairs()`` method returning a
    non-empty iterable, each yielded item is one element of it (presumably a
    ``(positive_idx, negative_idx)`` pair -- TODO confirm against the
    dataset's ``__getitem__``). Otherwise plain integer indices
    ``0 .. len(dataset) - 1`` are yielded.

    With ``shuffle=True`` the order is a permutation seeded with
    ``seed + epoch``. ``epoch`` starts at 0 and advances by one each time a
    new pass begins, so consecutive epochs get different but reproducible
    orders. The first pass matches a permutation seeded with ``seed`` alone.
    Call :meth:`set_epoch` to pin the epoch explicitly.
    """

    def __init__(self, dataset, shuffle: bool = True, seed: int = 42):
        self.dataset = dataset
        self.shuffle = shuffle
        self.seed = int(seed)
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch used to seed the next pass's permutation."""
        self.epoch = int(epoch)

    def _build_pairs(self) -> list:
        """Return the dataset's balanced index pairs, or [] if it provides none."""
        if hasattr(self.dataset, "build_balanced_index_pairs"):
            return list(self.dataset.build_balanced_index_pairs())
        return []

    def _order(self, items: list) -> list:
        """Return *items* in this pass's order (a seeded permutation when shuffling)."""
        if not self.shuffle:
            return items
        generator = torch.Generator()
        generator.manual_seed(self.seed + self.epoch)
        perm = torch.randperm(len(items), generator=generator).tolist()
        return [items[i] for i in perm]

    def __iter__(self):
        pairs = self._build_pairs()
        # Fall back to plain sample indices when no pairs are available.
        items = pairs if pairs else list(range(len(self.dataset)))
        ordered = self._order(items)
        # Advance so the next pass (the next epoch) is shuffled differently;
        # previously every pass reused the same seed and thus the same order.
        self.epoch += 1
        yield from ordered

    def __len__(self):
        pairs = self._build_pairs()
        return len(pairs) if pairs else len(self.dataset)
class AksaraTrainer:
    """
    Training loop for AksaraModel (reasoning-only objective).

    Integrates:
    - the loss dict returned by ``AksaraModel`` (``outputs["losses"]``); its
      ``"total"`` entry is the quantity that is back-propagated
    - Pengendali Dinamik (PD) for adaptive lambda weights, updated after every
      optimiser step when ``config.use_pd`` is True
    - linear warmup followed by a cosine decay down to 10% of the peak LR
    - gradient-norm clipping, with optional CUDA mixed precision (fp16)
    - best-effort checkpoint saving (failures are printed, never raised)

    Training and evaluation metrics are accumulated in separate
    ``AksaraMetrics`` instances so periodic evaluation does not leak into the
    training epoch summary.
    """

    # Metric keys copied from ``AksaraMetrics.compute()`` into each step's
    # stats dict, in this order. The first group is copied as returned by
    # ``.get(key)`` (None when compute() does not report it); the second
    # group defaults to 0.0 when absent.
    _NULLABLE_METRIC_KEYS = (
        "morph_accuracy",
        "struct_accuracy",
        "kbbi_alignment",
        "score_calibration",
    )
    _DEFAULTED_METRIC_KEYS = (
        "sentence_semantic_coherence",
        "anti_loop_score",
        "novelty_score",
        "generalization_score",
        "reasoning_score",
        "meaning_discriminator",
        "global_constraint",
    )

    def __init__(
        self,
        model: AksaraModel,
        train_dataset,
        eval_dataset=None,
        config: TrainerConfig = None,
        callbacks: Optional[List[Callable]] = None,
    ):
        """Move *model* to the configured device and build optimiser, PD, samplers.

        ``callbacks`` are invoked after every training step as
        ``cb(trainer, step_losses, global_step)``.
        """
        self.model = model
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.config = config or TrainerConfig()
        self.callbacks = callbacks or []
        self.device = torch.device(self.config.device)
        self.model = self.model.to(self.device)
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay,
        )
        self.pd = PengendaliDinamik(self.config.pd_config) if self.config.use_pd else None
        # Mixed precision is only enabled on CUDA; elsewhere fp16 is ignored.
        self.scaler = (
            torch.cuda.amp.GradScaler()
            if self.config.fp16 and self.device.type == "cuda"
            else None
        )
        self.global_step = 0
        self.best_eval_loss = float("inf")
        self._train_losses: List[Dict] = []
        self._curriculum_stage = 0
        self.metrics = AksaraMetrics()  # training metrics, reset at each epoch start
        self._eval_metrics = AksaraMetrics()  # evaluation metrics, reset per eval pass
        self._last_epoch_summary: Dict[str, object] = {}
        # Total optimiser steps for the LR schedule. train() sets it from the
        # DataLoader length; None means "estimate from the dataset size".
        self._total_steps: Optional[int] = None
        # NOTE(review): not read anywhere in this class -- guardrail alerts come
        # from AksaraMetrics.epoch_summary(). Kept for backward compatibility.
        self._guardrail_thresholds = {
            "total": 5.0,
            "score_mean": 0.25,
            "score_calibration": 0.7,
            "morph_accuracy": 0.8,
            "struct_accuracy": 0.75,
            "kbbi_alignment": 0.2,
            "sentence_semantic_coherence": 0.35,
            "anti_loop_score": 0.15,
        }
        os.makedirs(self.config.output_dir, exist_ok=True)
        torch.manual_seed(self.config.seed)
        self._train_sampler = BalancedPairSampler(self.train_dataset, shuffle=True, seed=self.config.seed)
        self._eval_sampler = (
            BalancedPairSampler(self.eval_dataset, shuffle=False, seed=self.config.seed)
            if self.eval_dataset is not None
            else None
        )

    # ------------------------------------------------------------------
    # Learning-rate schedule
    # ------------------------------------------------------------------
    def _get_lr(self, step: int) -> float:
        """Return the LR for 0-based *step*: linear warmup, then cosine to 10% of peak."""
        base_lr = self.config.learning_rate
        warmup = self.config.warmup_steps
        if step < warmup:
            return base_lr * (step + 1) / warmup
        total_steps = getattr(self, "_total_steps", None)
        if total_steps is None:
            # Fallback estimate when called before train() sized the schedule.
            total_steps = max(1, self.config.num_epochs * len(self.train_dataset) // self.config.batch_size)
        progress = (step - warmup) / max(total_steps - warmup, 1)
        # Clamp so steps past the planned end stay at the LR floor instead of
        # climbing back up the cosine curve.
        progress = min(max(progress, 0.0), 1.0)
        return base_lr * (0.1 + 0.9 * 0.5 * (1.0 + math.cos(math.pi * progress)))

    def _set_lr(self, lr: float):
        """Apply *lr* to every optimiser parameter group."""
        for pg in self.optimizer.param_groups:
            pg["lr"] = lr

    # ------------------------------------------------------------------
    # Batch preparation
    # ------------------------------------------------------------------
    def _build_dep_masks(self, batch: AksaraBatch) -> torch.Tensor:
        """Build a ``(B, L, L)`` boolean dependency mask on ``self.device``.

        Each sample's mask comes from ``self.model.lps.build_dep_mask`` applied
        to placeholder tokens of the sample's length; rows and columns at or
        beyond that length (presumably padding) are cleared.
        """
        B = batch.morpheme_ids.size(0)
        L = batch.morpheme_ids.size(1)
        dep_masks = torch.zeros(B, L, L, dtype=torch.bool, device=self.device)
        for i in range(B):
            actual_len = int(batch.lengths[i].item())
            dummy_tokens = ["_"] * actual_len
            mask_i = self.model.lps.build_dep_mask(dummy_tokens, L)
            mask_i[actual_len:, :] = False
            mask_i[:, actual_len:] = False
            dep_masks[i] = mask_i.to(self.device)
        return dep_masks

    def _sort_batch_by_length(self, batch: AksaraBatch) -> AksaraBatch:
        """Return *batch* reordered by ascending ``lengths`` when curriculum is enabled.

        Only the fields passed below are carried into the new AksaraBatch, and
        ``batch.labels`` must not be None.
        NOTE(review): if the loss relies on positional pairing (e.g. the
        positive/negative pairs from BalancedPairSampler sitting next to each
        other), this reordering breaks that adjacency -- confirm.
        """
        if not self.config.use_curriculum:
            return batch
        order = torch.argsort(batch.lengths, descending=False)
        return AksaraBatch(
            morpheme_ids=batch.morpheme_ids[order],
            affix_ids=batch.affix_ids[order],
            role_ids=batch.role_ids[order],
            attention_mask=batch.attention_mask[order],
            lengths=batch.lengths[order],
            texts=[batch.texts[i] for i in order.tolist()],
            labels=batch.labels[order],
        )

    def _curriculum_maybe_advance(self, batch: AksaraBatch):
        """Move to curriculum stage 1 once a batch exceeds ``curriculum_boundary``."""
        if not self.config.use_curriculum:
            return
        max_len = int(batch.lengths.max().item())
        if max_len > self.config.curriculum_boundary:
            self._curriculum_stage = 1

    def _build_lps_output(self, batch: AksaraBatch) -> Dict[str, object]:
        """Assemble the model input dict for an already device-resident *batch*."""
        dep_masks = self._build_dep_masks(batch)
        return {
            "morpheme_ids": batch.morpheme_ids,
            "affix_ids": batch.affix_ids,
            "role_ids": batch.role_ids,
            "dep_masks": dep_masks,
            "attention_mask": batch.attention_mask,
            "lengths": batch.lengths,
            "max_len": batch.morpheme_ids.size(1),
        }

    @staticmethod
    def _require_labels(batch: AksaraBatch, message: str):
        """Return ``batch.labels`` or raise ``ValueError(message)`` if missing/None."""
        labels = getattr(batch, "labels", None)
        if labels is None:
            raise ValueError(message)
        return labels

    def _collect_step_stats(self, batch: AksaraBatch, outputs: Dict, labels: torch.Tensor, metrics) -> Dict:
        """Convert model outputs to plain numbers and fold them into *metrics*.

        Returns the losses (tensors converted via ``.item()``), score
        statistics when ``outputs["scores"]["total"]`` is present, and the
        metric keys listed in ``_NULLABLE_METRIC_KEYS`` / ``_DEFAULTED_METRIC_KEYS``.
        """
        step_losses = {k: v.item() if torch.is_tensor(v) else v for k, v in outputs["losses"].items()}

        scores = outputs.get("scores", {})
        total_score = scores.get("total") if isinstance(scores, dict) else None
        if torch.is_tensor(total_score):
            step_losses["score_mean"] = float(total_score.mean().item())
            step_losses["score_min"] = float(total_score.min().item())
            step_losses["score_max"] = float(total_score.max().item())
        elif isinstance(total_score, (int, float)):
            scalar = float(total_score)
            step_losses["score_mean"] = scalar
            step_losses["score_min"] = scalar
            step_losses["score_max"] = scalar

        # Fallbacks when the model does not expose semantic slots / KBBI anchors.
        semantic_slots = outputs.get("meb_out")
        if semantic_slots is None:
            semantic_slots = batch.morpheme_ids.float().unsqueeze(-1)
        kbbi_anchors = outputs.get("kbbi_anchors")
        if kbbi_anchors is None:
            kbbi_anchors = semantic_slots

        metrics.update(
            targets={
                "affix_ids": batch.affix_ids.detach(),
                "role_ids": batch.role_ids.detach(),
            },
            semantic_slots=semantic_slots,
            kbbi_anchors=kbbi_anchors,
            attention_mask=batch.attention_mask,
            score_total=total_score,
            labels=labels,
            losses=outputs.get("losses"),
            kbbi_mask=outputs.get("kbbi_mask"),
        )
        metric_summary = metrics.compute()
        for key in self._NULLABLE_METRIC_KEYS:
            step_losses[key] = metric_summary.get(key)
        for key in self._DEFAULTED_METRIC_KEYS:
            step_losses[key] = metric_summary.get(key, 0.0)
        return step_losses

    # ------------------------------------------------------------------
    # Steps
    # ------------------------------------------------------------------
    def _training_step(self, batch: AksaraBatch) -> Dict:
        """Run one optimiser step on *batch* and return its loss/metric stats.

        Raises:
            ValueError: if the batch carries no labels.
        """
        self.model.train()
        train_msg = "AksaraTrainer membutuhkan batch.labels untuk training reasoning-only."
        # Validate before sorting: _sort_batch_by_length indexes batch.labels.
        self._require_labels(batch, train_msg)
        batch = self._sort_batch_by_length(batch).to(self.device)
        self._curriculum_maybe_advance(batch)
        lps_output = self._build_lps_output(batch)
        labels = self._require_labels(batch, train_msg).to(self.device).float()

        if self.scaler is not None:
            # Only the forward pass runs under autocast; backward and the
            # scaler bookkeeping happen outside it (standard AMP recipe).
            with torch.cuda.amp.autocast():
                outputs = self.model(lps_output, labels=labels)
                loss = outputs["losses"]["total"]
            self.scaler.scale(loss).backward()
            # Unscale first so clipping sees true gradient magnitudes.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            outputs = self.model(lps_output, labels=labels)
            loss = outputs["losses"]["total"]
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)
            self.optimizer.step()
        # Clear gradients for both precision paths.
        self.optimizer.zero_grad()

        if self.pd is not None:
            self.pd.step_update(outputs["losses"], optimizer=self.optimizer)

        return self._collect_step_stats(batch, outputs, labels, self.metrics)

    @torch.no_grad()
    def _eval_step(self, batch: AksaraBatch) -> Dict:
        """Evaluate one batch without gradients; stats go to the eval metrics.

        Raises:
            ValueError: if the batch carries no labels.
        """
        self.model.eval()
        batch = batch.to(self.device)
        lps_output = self._build_lps_output(batch)
        labels = self._require_labels(
            batch, "AksaraTrainer membutuhkan batch.labels untuk evaluasi reasoning-only."
        ).to(self.device).float()
        outputs = self.model(lps_output, labels=labels)
        return self._collect_step_stats(batch, outputs, labels, self._eval_metrics)

    # ------------------------------------------------------------------
    # Guardrails
    # ------------------------------------------------------------------
    def _format_guardrail_alerts(self, summary: Dict[str, object]) -> List[str]:
        """Return ``summary["guardrail_alerts"]`` as strings ([] if absent or not a list/tuple)."""
        alerts = summary.get("guardrail_alerts") or []
        if not isinstance(alerts, (list, tuple)):
            return []
        return [str(a) for a in alerts]

    def _emit_guardrail_alerts(self, epoch: int, summary: Dict[str, object]):
        """Print the epoch's guardrail status line."""
        alerts = self._format_guardrail_alerts(summary)
        if not alerts:
            print(f"[Epoch {epoch+1}] Guardrail: OK")
            return
        print(f"[Epoch {epoch+1}] Guardrail alert ({len(alerts)}): " + " | ".join(alerts))

    # ------------------------------------------------------------------
    # Main loop
    # ------------------------------------------------------------------
    def train(self):
        """Run the full training loop and return the per-step stats of all epochs.

        Raises:
            ValueError: if an eval dataset was given but is empty.
        """
        train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.config.batch_size,
            sampler=self._train_sampler,
            collate_fn=collate_fn,
            num_workers=0,
        )
        if self.eval_dataset is not None and len(self.eval_dataset) == 0:
            raise ValueError("Eval dataset kosong setelah preprocessing/filtering. Hentikan training dan perbaiki pipeline data split.")
        eval_loader = None
        if self.eval_dataset:
            eval_loader = DataLoader(
                self.eval_dataset,
                batch_size=self.config.batch_size,
                sampler=self._eval_sampler,
                collate_fn=collate_fn,
                num_workers=0,
            )

        # Size the cosine schedule by the real number of optimiser steps: the
        # loader length follows the sampler (pairs or samples) and counts the
        # final partial batch, unlike len(train_dataset) // batch_size.
        self._total_steps = max(1, self.config.num_epochs * len(train_loader))

        print(f"[AKSARA] Mulai training — {self.model.num_parameters['trainable']:,} parameter")
        print(f"[AKSARA] Device: {self.device} | Epochs: {self.config.num_epochs}")
        print(f"[AKSARA] KBBI coverage: {self.model.lsk.kbbi_coverage:.1%}")

        for epoch in range(self.config.num_epochs):
            # Let epoch-aware samplers reshuffle with a different seed each epoch.
            if hasattr(self._train_sampler, "set_epoch"):
                self._train_sampler.set_epoch(epoch)
            epoch_losses = []
            t0 = time.time()
            self.metrics.reset()
            for batch in train_loader:
                lr = self._get_lr(self.global_step)
                self._set_lr(lr)
                step_losses = self._training_step(batch)
                epoch_losses.append(step_losses)
                self._train_losses.append(step_losses)
                self.global_step += 1
                if self.global_step % self.config.log_every_n_steps == 0:
                    self._log_step(step_losses, epoch, lr)
                if eval_loader is not None and self.global_step % self.config.eval_every_n_steps == 0:
                    self._run_periodic_eval(eval_loader)
                if self.global_step % self.config.save_every_n_steps == 0:
                    self._safe_checkpoint(f"step_{self.global_step}")
                for cb in self.callbacks:
                    cb(self, step_losses, self.global_step)

            avg = self._average_losses(epoch_losses)
            epoch_summary = self.metrics.epoch_summary()
            epoch_summary.update(avg)
            self._last_epoch_summary = epoch_summary
            self._log_epoch(epoch, epoch_summary, avg, time.time() - t0)
            self._emit_guardrail_alerts(epoch, epoch_summary)
            if self.pd is not None:
                diag = self.pd.get_diagnostics()
                print(
                    f"  PD λ: morph={diag['lambdas']['morph']:.3f} "
                    f"struct={diag['lambdas']['struct']:.3f} "
                    f"sem={diag['lambdas']['sem']:.3f} "
                    f"ctx={diag['lambdas']['ctx']:.3f}"
                )

        # No periodic eval happened (or none improved on inf): evaluate once now.
        if self.best_eval_loss == float("inf") and eval_loader is not None:
            self.best_eval_loss = self._evaluate(eval_loader)["total"]
            self._safe_checkpoint("best", {"total": self.best_eval_loss})
        self._safe_checkpoint("final")
        print(f"\n[AKSARA] Training selesai. Best eval loss: {self.best_eval_loss:.4f}")
        return self._train_losses

    def _run_periodic_eval(self, eval_loader: DataLoader) -> Dict:
        """Evaluate, checkpoint under the current step, and track the best eval loss."""
        eval_loss = self._evaluate(eval_loader)
        print(f"[Eval step {self.global_step}] total={eval_loss['total']:.4f}")
        self._safe_checkpoint(str(self.global_step), eval_loss)
        if eval_loss["total"] < self.best_eval_loss:
            self.best_eval_loss = eval_loss["total"]
            self._safe_checkpoint("best", eval_loss)
        return eval_loss

    def _evaluate(self, loader: DataLoader) -> Dict:
        """Run every batch of *loader* through ``_eval_step`` and average the stats.

        The eval metrics accumulator is reset first so each pass is independent
        and the training metrics stay untouched.
        """
        self._eval_metrics.reset()
        all_losses = [self._eval_step(batch) for batch in loader]
        return self._average_losses(all_losses)

    # ------------------------------------------------------------------
    # Aggregation / logging
    # ------------------------------------------------------------------
    @staticmethod
    def _as_float(value, default: Optional[float] = 0.0) -> Optional[float]:
        """Convert *value* to float; return *default* for None or non-convertible values."""
        if value is None:
            return default
        try:
            return float(value)
        except (TypeError, ValueError, RuntimeError):
            return default

    @staticmethod
    def _average_losses(losses_list: List[Dict]) -> Dict:
        """Average each key over the steps that report a numeric value for it.

        Keys are the ordered union over all dicts (excluding ``"lambdas"``).
        A key with no numeric value in any step maps to None. An empty input
        returns {}.
        """
        if not losses_list:
            return {}
        keys = dict.fromkeys(k for d in losses_list for k in d if k != "lambdas")
        averaged: Dict[str, Optional[float]] = {}
        for key in keys:
            values = [AksaraTrainer._as_float(d.get(key), None) for d in losses_list]
            numeric = [v for v in values if v is not None]
            averaged[key] = sum(numeric) / len(numeric) if numeric else None
        return averaged

    def _log_step(self, losses: Dict, epoch: int, lr: float):
        """Print a one-line summary of a training step (None values shown as 0)."""
        pd_info = ""
        if self.pd is not None:
            lam = self.pd.get_lambdas()
            pd_info = f" | λ_m={lam['morph']:.2f} λ_s={lam['struct']:.2f} λ_k={lam['sem']:.2f} λ_c={lam['ctx']:.2f}"
        num = self._as_float
        print(
            f"[Step {self.global_step} | E{epoch+1}] "
            f"total={num(losses.get('total')):.4f} "
            f"binary={num(losses.get('l_binary')):.4f} "
            f"margin={num(losses.get('l_margin')):.4f} "
            f"hard_neg={num(losses.get('l_hard_neg')):.4f} "
            f"consist={num(losses.get('l_consist')):.4f} "
            f"confidence={num(losses.get('l_confidence')):.4f} "
            f"score_mean={num(losses.get('score_mean')):.4f} "
            f"sem_coh={num(losses.get('sentence_semantic_coherence')):.4f} "
            f"anti_loop={num(losses.get('anti_loop_score')):.4f} "
            f"lr={lr:.2e}{pd_info}"
        )

    def _log_epoch(self, epoch: int, summary: Dict[str, object], avg: Dict, elapsed: float):
        """Print the end-of-epoch summary line (missing/None values shown as 0)."""

        def value(key: str) -> float:
            return self._as_float(summary.get(key, avg.get(key)))

        print(
            f"\n[Epoch {epoch+1}/{self.config.num_epochs}] "
            f"total={value('total'):.4f} | "
            f"binary={value('l_binary'):.4f} | "
            f"margin={value('l_margin'):.4f} | "
            f"hard_neg={value('l_hard_neg'):.4f} | "
            f"consist={value('l_consist'):.4f} | "
            f"confidence={value('l_confidence'):.4f} | "
            f"score_mean={value('score_mean'):.4f} | "
            f"morph_acc={value('morph_accuracy'):.4f} | "
            f"struct_acc={value('struct_accuracy'):.4f} | "
            f"kbbi_align={value('kbbi_alignment'):.4f} | "
            f"score_cal={value('score_calibration'):.4f} | "
            f"sem_coh={value('sentence_semantic_coherence'):.4f} | "
            f"anti_loop={value('anti_loop_score'):.4f} | "
            f"batches={int(value('epoch_batches'))} | "
            f"samples={int(value('epoch_samples'))} | "
            f"{elapsed:.1f}s"
        )

    # ------------------------------------------------------------------
    # Checkpointing
    # ------------------------------------------------------------------
    def _safe_checkpoint(self, tag: str, eval_loss: Optional[Dict] = None) -> bool:
        """Save the model (plus PD diagnostics / eval loss JSON) under ``checkpoint_{tag}``.

        Best-effort: save/serialisation failures are printed and reported by
        returning False instead of interrupting training.

        Returns:
            True on success, False if anything failed to save.
        """
        path = os.path.join(self.config.output_dir, f"checkpoint_{tag}")
        try:
            self.model.save(path)
            if self.pd is not None:
                diag = self.pd.get_diagnostics()
                with open(os.path.join(path, "pd_state.json"), "w", encoding="utf-8") as f:
                    json.dump(diag, f, indent=2)
            if isinstance(eval_loss, dict):
                with open(os.path.join(path, "eval_loss.json"), "w", encoding="utf-8") as f:
                    json.dump(eval_loss, f, indent=2)
        # TypeError/ValueError come from json.dump on non-serialisable or
        # circular payloads; they must not abort a best-effort checkpoint.
        except (RuntimeError, OSError, TypeError, ValueError) as exc:
            print(f"[AKSARA] Warning: checkpoint '{tag}' gagal disimpan: {exc}")
            return False
        return True