"""State-native curriculum trainer for AKSARA. This module provides the runtime training entry point for the state-native pipeline. It trains on linguistic/state representations (BSU / MEB / GOS / AksaraState) and does not use state-native heuristics, next-token softmax, or autoregressive decoding mechanics. """ from __future__ import annotations from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, Iterable, Optional import json import random import torch import torch.nn.functional as F from torch.utils.data import DataLoader, Dataset from aksara.primitives.sfm.lexicon import LexiconLoader from aksara.primitives.sfm.manifold import SemanticManifold from aksara.primitives.lps.morfem import Morfem, KelasKata from aksara.core.bsu import BSUConfig, BahasaStateUnit from aksara.core.gos import MeaningOutputBridge from aksara.core.meb import MEBConfig, MesinEvolusiBahasa from aksara.training.state_checkpoint import ( StateCheckpointManager, TrainingCheckpointState, export_final_checkpoint, ) from aksara.training.state_evaluator import StateEvaluator, StateEvalMetrics from aksara.training.state_objectives import ( CurriculumObjectiveBundle, build_state_objective_bundle, compute_state_objective_loss, StateAlignmentLoss, ConstraintSatisfactionLoss, SemanticBindingLoss, GOSCoherenceLoss, MultiStateMarginLoss, ) @dataclass class StateTrainingConfig: """Runtime configuration for state-native training.""" seed: int = 42 device: str = "cpu" epochs: int = 1 batch_size: int = 8 lr: float = 3e-4 kbbi_path: str = "kbbi_core_v2.json" weight_decay: float = 1e-4 grad_clip: float = 1.0 log_every: int = 20 eval_every: int = 1 save_every: int = 1 output_dir: str = "./aksara_output_train" export_final: Optional[str] = None resume: Optional[str] = None strict_resume: bool = True data_path: Optional[str] = None val_path: Optional[str] = None test_path: Optional[str] = None curriculum: Dict[str, Any] = field(default_factory=dict) model: Dict[str, Any] = field(default_factory=dict) class StateCorpusDataset(Dataset): """JSONL dataset with state-native supervision.""" def __init__(self, rows: Iterable[dict]): self.rows = list(rows) def __len__(self) -> int: return len(self.rows) def __getitem__(self, idx: int) -> dict: return self.rows[idx] def load_state_jsonl(path: str) -> list[dict]: with open(path, encoding="utf-8") as f: return [json.loads(line) for line in f if line.strip()] def load_state_corpus(path: str) -> list[dict]: p = Path(path) if not p.exists(): raise FileNotFoundError(f"State corpus not found: {path}") if p.suffix.lower() == ".jsonl": return load_state_jsonl(path) if p.suffix.lower() == ".json": with open(path, encoding="utf-8") as f: data = json.load(f) if isinstance(data, list): return data raise ValueError("State corpus JSON must be a list of objects") raise ValueError("State corpus must be .json or .jsonl") def state_collate_fn(batch: list[dict]) -> dict: return { "items": batch, "texts": [row.get("text", "") for row in batch], "states": [row.get("state") for row in batch], "targets": [row.get("target") for row in batch], } class _AllComponentsWrapper(torch.nn.Module): """Wrapper untuk menyimpan BSU+MEB+GOS dalam satu state_dict.""" def __init__(self, components: dict): super().__init__() self.bsu = components["bsu"] self.meb = components["meb"] self.gos = components["gos"] def _build_runtime_model(config: StateTrainingConfig): bsu_cfg = BSUConfig(**config.model.get("bsu_config", {})) meb_cfg = MEBConfig(**config.model.get("meb_config", {})) meb_cfg.bsu_config = bsu_cfg bsu = BahasaStateUnit( bsu_cfg, vocab_size=config.model.get("vocab_size", 5000), affix_vocab_size=config.model.get("affix_vocab_size", 40), kbbi_input_dim=config.model.get("kbbi_input_dim", 16), ) meb = MesinEvolusiBahasa(meb_cfg, affix_vocab_size=config.model.get("affix_vocab_size", 40)) gos = MeaningOutputBridge(bsu_cfg.d_total) return {"bsu": bsu, "meb": meb, "gos": gos} class StateTrainingRunner: """Runtime state-native training runner.""" def __init__( self, config: StateTrainingConfig, checkpoint_manager: Optional[StateCheckpointManager] = None, evaluator: Optional[StateEvaluator] = None, ): self.config = config self.device = torch.device(config.device) random.seed(config.seed) torch.manual_seed(config.seed) self.checkpoint_manager = checkpoint_manager or StateCheckpointManager(config.output_dir) self.evaluator = evaluator or StateEvaluator() self.objectives: CurriculumObjectiveBundle = build_state_objective_bundle(config.curriculum) self.components = _build_runtime_model(config) for comp in self.components.values(): if hasattr(comp, "to"): comp.to(self.device) # SFM dari KBBI nyata — sumber representasi linguistik self._sfm: Optional[SemanticManifold] = None self._sfm_kbbi_dim: int = 31 # dimensi vektor SFM try: _lex = LexiconLoader() _lex.muat_kbbi(config.kbbi_path) self._sfm = SemanticManifold(_lex) except Exception: self._sfm = None # fallback: gunakan nol jika file tidak ada self.optimizer = torch.optim.AdamW( list(self.components["bsu"].parameters()) + list(self.components["meb"].parameters()) + list(self.components["gos"].parameters()), lr=config.lr, weight_decay=config.weight_decay, ) self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( self.optimizer, mode="min", factor=0.5, patience=2 ) def _update_running_diagnostics(self, batch_info: Dict[str, Any]) -> None: if not hasattr(self, "_epoch_diagnostics"): self._epoch_diagnostics = {"running_terms": {}, "last_batch": {}, "validation": {}, "test": {}} self._epoch_diagnostics["last_batch"] = batch_info.get("terms", {}) running_terms = self._epoch_diagnostics.setdefault("running_terms", {}) for key, value in batch_info.get("terms", {}).items(): running_terms[key] = running_terms.get(key, 0.0) + float(value) def _prepare_epoch_diagnostics(self) -> None: self._epoch_diagnostics = {"running_terms": {}, "last_batch": {}, "validation": {}, "test": {}} def _token_to_id(self, token: str) -> int: value = sum(ord(ch) for ch in token.lower()) % 4096 return max(1, value) @staticmethod def _detect_affix_id(token: str) -> int: """ Deteksi afiks TBBBI pada token dan kembalikan ID afiks (0=tanpa afiks). Justifikasi linguistik: verba gramatikal Indonesia WAJIB membawa afiks meN- (aktif) atau di- (pasif). Kalimat tanpa afiks verba adalah pelanggaran TBBBI. Pemetaan ID (1-39) → afiks_vocab_size=40: 1-8 : prefiks tunggal (me,men,mem,meng,meny,di,ber,ter) 9-12 : sufiks tunggal (-kan,-i,-an,-nya) 13-20 : sirkumfiks meN-+-kan/i (me-kan=13, men-kan=14, mem-kan=15, meng-kan=16, meny-kan=17, di-kan=18, me-i=19, mem-i=20) 21-28 : sirkumfiks lain (meng-i=21,di-i=22,ber-kan=23,ber-i=24, ter-kan=25,ke-=26,se-=27,pe-=28) 29-32 : reduplikasi/ulang (ulang=29), pe-an=30,per-=31,ter-=8 """ t = token.lower() # Tabel prefiks (urutan: paling panjang dulu agar tidak false-match) PREFIX = [ ("meny", 5), ("meng", 4), ("mem", 3), ("men", 2), ("me", 1), ("di", 6), ("ber", 7), ("ter", 8), ("ke", 26), ("se", 27), ("pe", 28), ] SUFFIX = [ ("kan", 9), ("an", 11), ("nya", 12), ("lah", 12), ] # -i hanya valid sebagai sufiks verba jika ada prefiks (cegah false-positive # pada kata pinjaman berakhir -i: publikasi, sanksi, promosi) SUFFIX_WITH_PREFIX = [("i", 10)] # Sirkumfiks me(N)-...-kan / me(N)-...-i CIRCUMFIX_TABLE = { (1, 9): 13, (2, 9): 14, (3, 9): 15, (4, 9): 16, (5, 9): 17, (6, 9): 18, (1, 10): 19, (2, 10): 19, (3, 10): 20, (4, 10): 21, (5, 10): 19, (6, 10): 22, (7, 9): 23, (7, 10): 24, (8, 9): 25, } VOWELS = set("aiueo") pre_id = 0 suf_id = 0 for pfx, pid in PREFIX: if t.startswith(pfx) and len(t) > len(pfx) + 2: remaining = t[len(pfx):] # Root Indonesia valid: dimulai vokal atau C+V (bukan CC cluster) # Filter false-positive: 'ke'+'rjakan' → 'rj' = CC → bukan prefiks if len(remaining) >= 2 and remaining[0] not in VOWELS and remaining[1] not in VOWELS: continue pre_id = pid t = remaining break active_suffix = SUFFIX + (SUFFIX_WITH_PREFIX if pre_id else []) for sfx, sid in active_suffix: if t.endswith(sfx) and len(t) > len(sfx) + 3: # root min 4 char suf_id = sid break if pre_id and suf_id: return CIRCUMFIX_TABLE.get((pre_id, suf_id), pre_id) if pre_id: return pre_id if suf_id: return suf_id return 0 @staticmethod def _get_kbbi_root(token: str) -> str: """ Normalisasi token ke bentuk dasar untuk lookup KBBI. Justifikasi: KBBI menyimpan bentuk dasar (root), bukan bentuk berimbuhan. Tanpa normalisasi: 'mempublikasikan' → kbbi_vec=0 (tidak ada di KBBI), tapi 'publikasi' → kbbi_vec≠0 → sinyal TERBALIK (invalid terkesan lebih valid). Dengan normalisasi: keduanya lookup 'publikasi' → kbbi_vec sama → sinyal netral; perbedaan gramatikal ditangani oleh affix_emb saja. """ t = token.lower() # Strip prefix (panjang string, tanpa memperhatikan luluh) PREFIX_LEN = [ ("meny", 4), ("meng", 4), ("mem", 3), ("men", 3), ("me", 2), ("di", 2), ("ber", 3), ("ter", 3), ("ke", 2), ("se", 2), ("pe", 2), ] VOWELS = set("aiueo") pre_len = 0 for pfx, plen in PREFIX_LEN: if t.startswith(pfx) and len(t) > plen + 2: remaining = t[plen:] if len(remaining) < 2 or remaining[0] in VOWELS or remaining[1] in VOWELS: pre_len = plen t = remaining break # Strip suffix (-kan, -i, -an, -nya, -lah, -kah) SUFFIX_LEN = [("kan", 3), ("an", 2), ("nya", 3), ("lah", 3), ("kah", 3), ("i", 1)] for sfx, slen in SUFFIX_LEN: if t.endswith(sfx) and len(t) > slen + 3: t = t[:-slen] break return t if t else token.lower() def _build_sfm_vectors(self, tokens: list[str]) -> torch.Tensor: """ Bangun vektor semantik dari KBBI via SemanticManifold (SFM). Output: (1, seq_len, 31) — domain[20]+kelas[9]+register[1]+kepastian[1] Setiap dimensi punya interpretasi linguistik eksplisit. Ini menggantikan _build_kbbi_vectors() lama yang pakai sin/cos noise matematika. """ vectors = [] kbbi_input_dim = int( getattr(self.components["bsu"], "kbbi_proj", None) and getattr(self.components["bsu"].kbbi_proj, "in_features", self._sfm_kbbi_dim) or self._sfm_kbbi_dim ) for token in tokens or [""]: if self._sfm is not None: kbbi_root = self._get_kbbi_root(token) m = Morfem( indeks=0, teks_asli=token, root=kbbi_root, kelas_kata=KelasKata.TIDAK_DIKETAHUI, ada_di_kbbi=self._sfm.leksikon.ada(kbbi_root), ) state = self._sfm.encode_morfem(m) vec_31 = state.vektor_lengkap.float() else: # Fallback: nol vektor jika SFM tidak tersedia vec_31 = torch.zeros(self._sfm_kbbi_dim, dtype=torch.float32) # Sesuaikan dimensi ke kbbi_input_dim yang diharapkan BSU if vec_31.shape[0] < kbbi_input_dim: vec = torch.nn.functional.pad(vec_31, (0, kbbi_input_dim - vec_31.shape[0])) else: vec = vec_31[:kbbi_input_dim] vectors.append(vec) tensor = torch.stack(vectors, dim=0) return tensor.unsqueeze(0).to(dtype=torch.float32) def _build_role_ids(self, target: dict, seq_len: int) -> torch.Tensor: role_name = str(target.get("role", "root")).lower() role_seed = sum(ord(ch) for ch in role_name) values = [(role_seed + idx) % 8 for idx in range(seq_len)] return torch.tensor([values], dtype=torch.long) def build_training_components(self) -> Dict[str, Any]: return self.components def resume_from_checkpoint(self) -> Dict[str, Any]: if not self.config.resume: return {} return self.checkpoint_manager.load(self.config.resume, strict=self.config.strict_resume) def _build_text_verification(self, batch_item: dict, model_output: Dict[str, Any]) -> Dict[str, Any]: text = batch_item.get("text", "") target_state = batch_item.get("state", {}) or {} ringkasan_makna = { "teks": text, "makna": target_state.get("meaning", target_state.get("summary", "")), "role": target_state.get("role", "root"), "register": target_state.get("register", "formal"), } inferensi_terdeteksi = { "kesimpulan": target_state.get("inference", target_state.get("reasoning", "")), "sumber": "state-native reasoning", } keputusan_valid = bool(target_state.get("valid", True)) sentence_output = model_output.get("sentence_output", {}) if isinstance(model_output, dict) else {} alasan = [ f"teks_input={text}", f"ringkasan_makna={ringkasan_makna}", f"inferensi_terdeteksi={inferensi_terdeteksi}", f"keputusan_valid={keputusan_valid}", ] if sentence_output: alasan.extend( [ f"kalimat_terhasil={sentence_output.get('kalimat', '')}", f"makna_terhasil={sentence_output.get('makna', '')}", f"reasoning_terhasil={sentence_output.get('reasoning', '')}", f"bukti_state_terhasil={sentence_output.get('bukti_state', {})}", ] ) if not keputusan_valid: alasan.append(f"penjelasan={target_state.get('explanation', target_state.get('reason', ''))}") return { "teks_input": text, "ringkasan_makna": ringkasan_makna, "inferensi_terdeteksi": inferensi_terdeteksi, "keputusan_valid": keputusan_valid, "sentence_output": sentence_output, "alasan": alasan, "state_eval": model_output.get("state_eval", {}), } def validate_meaning_based_conversation(self, prompt: str) -> Dict[str, Any]: batch_item = {"text": prompt, "state": {"meaning": "", "summary": "", "role": "root", "register": "formal", "valid": True}, "target": {}} model_output = self._infer_state_output(batch_item) verification = model_output.get("text_verification", {}) return { "prompt": prompt, "teks_input": verification.get("teks_input", prompt), "ringkasan_makna": verification.get("ringkasan_makna", {}), "inferensi_terdeteksi": verification.get("inferensi_terdeteksi", {}), "keputusan_valid": verification.get("keputusan_valid", False), "alasan": verification.get("alasan", []), "state_eval": verification.get("state_eval", {}), } def _infer_state_output(self, batch_item: dict) -> Dict[str, Any]: text = batch_item.get("text", "") target_state = batch_item.get("state", {}) or {} target = batch_item.get("target", {}) or {} tokens = text.split() batch_size = 1 seq_len = max(len(tokens), 1) morpheme_ids = torch.tensor([[self._token_to_id(token) for token in tokens] or [0]], dtype=torch.long) affix_ids = torch.tensor( [[self._detect_affix_id(tok) for tok in tokens] or [0]], dtype=torch.long ) kbbi_vectors = self._build_sfm_vectors(tokens) role_ids = self._build_role_ids(target, seq_len) if kbbi_vectors.shape[1] != seq_len: if kbbi_vectors.shape[1] == 1: kbbi_vectors = kbbi_vectors.repeat(1, seq_len, 1) else: kbbi_vectors = kbbi_vectors[:, :seq_len, :] kbbi_input_dim = getattr(self.components["bsu"], "kbbi_proj", None) kbbi_input_dim = int(getattr(kbbi_input_dim, "in_features", kbbi_vectors.shape[-1])) if kbbi_vectors.shape[-1] != kbbi_input_dim: if kbbi_vectors.shape[-1] < kbbi_input_dim: kbbi_vectors = torch.nn.functional.pad( kbbi_vectors, (0, kbbi_input_dim - kbbi_vectors.shape[-1]), ) else: kbbi_vectors = kbbi_vectors[:, :, :kbbi_input_dim] bsu_states, bsu_slots = self.components["bsu"]( morpheme_ids=morpheme_ids.to(self.device), affix_ids=affix_ids.to(self.device), kbbi_vectors=kbbi_vectors.to(self.device), role_ids=role_ids.to(self.device), ) meb_states, meb_trace = self.components["meb"]( bsu_states, affix_ids=affix_ids.to(self.device), kbbi_anchors=kbbi_vectors.to(self.device), return_all_layers=True, ) gos_output = self.components["gos"]( meb_states, bsu_states, ) role_signal = gos_output.get("role_signal") if role_signal is None: role_signal = gos_output.get("relation_signal") if role_signal is None: role_signal = torch.zeros_like(gos_output["global_signal"]) coherence_val = float(torch.sigmoid(gos_output["coherence_signal"].mean()).detach().cpu().item()) state_eval = { "energi_total": float(torch.sigmoid(meb_states.mean()).detach().cpu().item()), "kelengkapan_struktur": coherence_val, "constraint_satisfaction": coherence_val, "pelanggaran": [], # kosong: gunakan cs MSE path di ConstraintSatisfactionLoss "register": target_state.get("register", "formal"), } # Multi-state margin: pisahkan valid (label=1) dari invalid (label=0) # Berdasarkan label dataset, koherensi model harus tinggi untuk kalimat valid label = int((target or {}).get("label", 1)) coherence_score = torch.sigmoid(gos_output["coherence_signal"].mean()) if label == 1: positive_score = coherence_score negative_score = (1.0 - coherence_score).detach() else: positive_score = (1.0 - coherence_score).detach() negative_score = coherence_score model_output = { "text": text, "target_state": target_state, "target": target, "bsu_repr": bsu_states.mean(dim=1), "meb_repr": meb_states.mean(dim=1), "gos_repr": gos_output["h_final"].mean(dim=1) if gos_output["h_final"].dim() == 3 else gos_output["h_final"], "positive_score": positive_score, "negative_score": negative_score, "state_eval": state_eval, "sentence_output": gos_output.get("sentence_output"), "gos_trace": { # trace_scores: (n_layers,) — scalar per layer untuk backward compat "trace_scores": torch.stack([layer.mean(dim=1).mean(dim=-1) for layer in meb_trace], dim=0) if meb_trace else meb_states.mean(dim=-1).unsqueeze(0), # trace_vectors: (n_layers, d_total) — vektor per layer untuk cosine diversity loss # Justifikasi: setiap layer MEB harus berkontribusi representasi berbeda "trace_vectors": torch.stack([layer.mean(dim=1) for layer in meb_trace], dim=0) if meb_trace else meb_states.mean(dim=1).unsqueeze(0), }, "state_profile": { "semantic_signal": gos_output["semantic_signal"], "coherence_signal": gos_output["coherence_signal"], "global_signal": gos_output["global_signal"], "reasoning_signal": gos_output["reasoning_signal"], "role_signal": role_signal, "relation_signal": gos_output.get("relation_signal", role_signal), "state_summary": gos_output["state_summary"], "root_state": gos_output["root_state"], }, "aux": { "bsu": bsu_states, "bsu_slots": bsu_slots, "meb": meb_states, "gos": gos_output, }, } model_output["text_verification"] = self._build_text_verification(batch_item, model_output) return model_output def _forward_batch(self, batch: dict) -> tuple[torch.Tensor, dict]: texts = batch["texts"] states = batch["states"] targets = batch["targets"] outputs = [] batch_loss = torch.zeros((), device=self.device, requires_grad=True) term_sums: Dict[str, float] = { "state_consistency": 0.0, "constraint_satisfaction": 0.0, "semantic_alignment": 0.0, "gos_coherence": 0.0, "multi_state_margin": 0.0, } for text, state, target in zip(texts, states, targets): batch_item = {"text": text, "state": state, "target": target} model_output = self._infer_state_output(batch_item) loss, terms = compute_state_objective_loss(self.objectives, model_output, batch_item) # Direct cs_regression: latih cs_head eksplisit dari label # Justifikasi linguistik: kalimat valid harus punya koherensi tinggi (0.72), # kalimat invalid harus rendah (0.28) — target di tengah sigmoid agar gradient besar label = int((target or {}).get("label", 1)) cs_tgt_val = 0.72 if label == 1 else 0.28 cs_tgt = torch.tensor(cs_tgt_val, device=self.device, dtype=torch.float32) cs_tensor = model_output.get("positive_score" if label == 1 else "negative_score") if cs_tensor is not None and isinstance(cs_tensor, torch.Tensor) and cs_tensor.requires_grad: cs_reg = torch.nn.functional.mse_loss(cs_tensor.squeeze(), cs_tgt) loss = loss + cs_reg * 1.0 term_sums["constraint_satisfaction"] = term_sums.get("constraint_satisfaction", 0.0) + float(cs_reg.detach()) outputs.append(model_output) batch_loss = batch_loss + loss for key, value in terms.items(): if key != "constraint_satisfaction": # sudah di-update dari cs_reg di atas term_sums[key] = term_sums.get(key, 0.0) + float(value) if outputs: batch_loss = batch_loss / float(len(outputs)) term_sums = {k: v / float(len(outputs)) for k, v in term_sums.items()} return batch_loss, {"outputs": outputs, "terms": term_sums} def train_state_native(self) -> Dict[str, Any]: if not self.config.data_path: raise ValueError("State training requires config.data_path") data = load_state_corpus(self.config.data_path) dataset = StateCorpusDataset(data) loader = DataLoader(dataset, batch_size=self.config.batch_size, shuffle=True, collate_fn=state_collate_fn) history = [] for epoch in range(self.config.epochs): self._prepare_epoch_diagnostics() self.components["bsu"].train() self.components["meb"].train() self.components["gos"].train() epoch_losses = [] for batch_index, batch in enumerate(loader, start=1): loss, batch_info = self._forward_batch(batch) self.optimizer.zero_grad(set_to_none=True) loss.backward() torch.nn.utils.clip_grad_norm_( list(self.components["bsu"].parameters()) + list(self.components["meb"].parameters()) + list(self.components["gos"].parameters()), self.config.grad_clip, ) self.optimizer.step() self.scheduler.step(float(loss.detach().cpu().item())) loss_value = float(loss.detach().cpu().item()) epoch_losses.append(loss_value) if self.config.log_every and batch_index % self.config.log_every == 0: print( "[AKSARA][train] " f"epoch={epoch + 1} batch={batch_index} loss={loss_value:.6f} " f"terms={batch_info.get('terms', {})}" ) if self.config.eval_every and batch_index % self.config.eval_every == 0 and self.config.val_path: metrics = self.evaluator.evaluate_validation_split(self.components, self.config.val_path) if self.evaluator.should_advance_phase(metrics, self.config.curriculum): self.config.curriculum["phase_advanced"] = True self._update_running_diagnostics(batch_info) avg_loss = sum(epoch_losses) / max(len(epoch_losses), 1) history.append(avg_loss) diagnostics = self._epoch_diagnostics diagnostics["avg_loss"] = avg_loss if (epoch + 1) % self.config.save_every == 0: # Simpan semua komponen terlatih (BSU+MEB+GOS) dalam satu state_dict _all_components = _AllComponentsWrapper(self.components) self.checkpoint_manager.save( model=_all_components, optimizer=self.optimizer, scheduler=self.scheduler, state=TrainingCheckpointState( phase_index=int(self.config.curriculum.get("phase_index", 0)), epoch_index=epoch + 1, global_step=(epoch + 1) * max(len(loader), 1), best_metric=min(history), metrics={ "loss": avg_loss, "running_terms": diagnostics.get("running_terms", {}), "last_batch": diagnostics.get("last_batch", {}), }, config_snapshot={ "seed": self.config.seed, "device": self.config.device, "curriculum": self.config.curriculum, "model": self.config.model, }, ), metadata={ "loss": avg_loss, "diagnostics": diagnostics, }, ) if self.config.val_path: val_metrics = self.evaluator.evaluate_validation_split(self.components, self.config.val_path) diagnostics["validation"] = self.evaluator.summarize_metrics(val_metrics) if self.evaluator.should_advance_phase(val_metrics, self.config.curriculum): self.config.curriculum["phase_index"] = int(self.config.curriculum.get("phase_index", 0)) + 1 diagnostics["phase_advanced"] = True if self.config.test_path: test_metrics = self.evaluator.evaluate_validation_split(self.components, self.config.test_path) diagnostics["test"] = self.evaluator.summarize_metrics(test_metrics) text_verification = [] for sample in data[: min(len(data), 8)]: model_output = self._infer_state_output(sample) text_verification.append(model_output.get("text_verification", {})) result = { "loss_history": history, "epochs": self.config.epochs, "curriculum": self.config.curriculum, "diagnostics": getattr(self, "_epoch_diagnostics", {}), "text_verification": text_verification, "sentence_output_samples": [item.get("sentence_output", {}) for item in text_verification], } if self.config.export_final: export_final_checkpoint( self.checkpoint_manager, self.config.export_final, model=_AllComponentsWrapper(self.components), metadata=result, ) return result def train_curriculum_state_native(self) -> Dict[str, Any]: return self.train_state_native() def evaluate_and_maybe_save(self, *args, **kwargs) -> Dict[str, Any]: metrics: StateEvalMetrics = self.evaluator.evaluate_validation_split(self.components, self.config.val_path) return {"metrics": metrics} def load_state_training_config(path: str) -> StateTrainingConfig: with open(path, encoding="utf-8") as f: raw = json.load(f) return StateTrainingConfig(**raw) def build_state_training_runner(config: StateTrainingConfig) -> StateTrainingRunner: return StateTrainingRunner(config) def train_state_native(*args, **kwargs) -> Dict[str, Any]: config = kwargs.get("config") if "config" in kwargs else args[0] runner = build_state_training_runner(config) return runner.train_state_native() def train_curriculum_state_native(*args, **kwargs) -> Dict[str, Any]: config = kwargs.get("config") if "config" in kwargs else args[0] runner = build_state_training_runner(config) return runner.train_curriculum_state_native() def validate_meaning_based_conversation(*args, **kwargs) -> Dict[str, Any]: config = kwargs.get("config") if "config" in kwargs else args[0] prompt = kwargs.get("prompt") if "prompt" in kwargs else args[1] runner = build_state_training_runner(config) return runner.validate_meaning_based_conversation(prompt)