"""Checkpoint manager for state-native AKSARA training.

This module persists training state, model weights, optimizer/scheduler state,
and exportable checkpoint artifacts for the state-native pipeline.
"""

from __future__ import annotations

import json
import os
import shutil
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Optional

# Names of the artifacts written into a checkpoint directory.
_MANIFEST_NAME = "checkpoint.json"
_MODEL_NAME = "model.pt"
_OPTIMIZER_NAME = "optimizer.pt"
_SCHEDULER_NAME = "scheduler.pt"
_EXPORT_METADATA_NAME = "metadata.json"

# Artifacts copied by ``StateCheckpointManager.export``, in copy order.
_EXPORTED_ARTIFACTS = (
    _MANIFEST_NAME,
    _MODEL_NAME,
    _OPTIMIZER_NAME,
    _SCHEDULER_NAME,
)


def _write_json_atomic(target: Path, payload: Dict[str, Any]) -> None:
    """Write *payload* as JSON to *target* without clobbering it on failure.

    The payload is serialized to a string before any file is opened, so a
    value that JSON cannot encode raises ``TypeError`` while the existing
    file is still untouched. The text is then written to a sibling
    ``<name>.tmp`` file and moved into place with ``os.replace``.
    """
    text = json.dumps(payload, ensure_ascii=False, indent=2)
    scratch = target.with_name(target.name + ".tmp")
    try:
        with open(scratch, "w", encoding="utf-8") as f:
            f.write(text)
        os.replace(scratch, target)
    except OSError:
        # Do not leave a half-written scratch file behind.
        try:
            scratch.unlink()
        except FileNotFoundError:
            pass
        raise


def _save_state_dict(obj: Any, target: Path) -> bool:
    """Save ``obj.state_dict()`` to *target* with ``torch.save``.

    Returns ``True`` when something was written and ``False`` when *obj* is
    ``None`` or has no ``state_dict`` attribute.
    """
    if obj is None or not hasattr(obj, "state_dict"):
        return False
    # Imported lazily so this module can be imported (and JSON-only
    # checkpoints written) without importing torch.
    import torch

    torch.save(obj.state_dict(), target)
    return True


@dataclass
class TrainingCheckpointState:
    """Serializable resume state for state-native training."""

    phase_index: int = 0
    epoch_index: int = 0
    global_step: int = 0
    best_metric: float = 0.0
    metrics: Dict[str, Any] = field(default_factory=dict)
    config_snapshot: Dict[str, Any] = field(default_factory=dict)


class StateCheckpointManager:
    """Checkpoint manager for state-native training.

    All artifacts live in a single directory: a ``checkpoint.json`` manifest
    plus optional ``model.pt``, ``optimizer.pt`` and ``scheduler.pt`` files.
    """

    def __init__(self, checkpoint_dir: str):
        """Remember the directory that checkpoints are written to and read from."""
        self.checkpoint_dir = checkpoint_dir

    def save(
        self,
        model: Any,
        optimizer: Optional[Any] = None,
        scheduler: Optional[Any] = None,
        state: Optional[TrainingCheckpointState] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Persist the manifest and any available state dicts.

        Args:
            model: Object whose ``state_dict()`` is saved to ``model.pt``;
                skipped when it has no ``state_dict`` attribute.
            optimizer: Optional object saved to ``optimizer.pt``.
            scheduler: Optional object saved to ``scheduler.pt``.
            state: Resume state; default field values are written when omitted.
            metadata: Extra JSON-serializable data stored under ``"metadata"``.

        Raises:
            TypeError: If ``state`` or ``metadata`` holds a value that JSON
                cannot encode. The previous manifest is left intact.
        """
        checkpoint_path = Path(self.checkpoint_dir)
        checkpoint_path.mkdir(parents=True, exist_ok=True)

        if state is None:
            state = TrainingCheckpointState()
        payload = {
            "state": {
                "phase_index": state.phase_index,
                "epoch_index": state.epoch_index,
                "global_step": state.global_step,
                "best_metric": state.best_metric,
                "metrics": state.metrics,
                "config_snapshot": state.config_snapshot,
            },
            "metadata": metadata or {},
        }
        # Atomic write: a failed save must not destroy the last good manifest.
        _write_json_atomic(checkpoint_path / _MANIFEST_NAME, payload)

        _save_state_dict(model, checkpoint_path / _MODEL_NAME)
        _save_state_dict(optimizer, checkpoint_path / _OPTIMIZER_NAME)
        _save_state_dict(scheduler, checkpoint_path / _SCHEDULER_NAME)

    def load(self) -> Dict[str, Any]:
        """Return the parsed ``checkpoint.json``, or ``{}`` when none exists."""
        manifest = Path(self.checkpoint_dir) / _MANIFEST_NAME
        try:
            with open(manifest, encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError:
            return {}

    def export(
        self,
        output_dir: str,
        model: Any,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Copy the saved artifacts into *output_dir*.

        Args:
            output_dir: Destination directory; created if missing.
            model: Accepted for interface compatibility; not used here, the
                weights are copied from the already-saved ``model.pt``.
            metadata: When non-empty, written to ``metadata.json`` in
                *output_dir*.
        """
        src = Path(self.checkpoint_dir)
        dst = Path(output_dir)
        dst.mkdir(parents=True, exist_ok=True)

        # Copying a file onto itself raises shutil.SameFileError; when
        # exporting into the checkpoint directory the artifacts are already
        # in place, so there is nothing to copy.
        if src.resolve() != dst.resolve():
            for name in _EXPORTED_ARTIFACTS:
                src_file = src / name
                if src_file.exists():
                    shutil.copy2(src_file, dst / name)

        if metadata:
            _write_json_atomic(dst / _EXPORT_METADATA_NAME, metadata)


def save_training_checkpoint(*args, **kwargs):
    """Compatibility wrapper contract.

    Raises:
        NotImplementedError: Always; this is a placeholder.
    """
    raise NotImplementedError("save_training_checkpoint is a design-time contract only.")


def load_training_checkpoint(*args, **kwargs):
    """Compatibility wrapper contract.

    Raises:
        NotImplementedError: Always; this is a placeholder.
    """
    raise NotImplementedError("load_training_checkpoint is a design-time contract only.")


def export_final_checkpoint(
    manager,
    output_dir: str,
    model: Any = None,
    metadata: Optional[Dict[str, Any]] = None,
):
    """Export final checkpoint from the active checkpoint manager.

    Delegates to ``manager.export(output_dir, model, metadata=metadata)`` and
    returns whatever that call returns.
    """
    return manager.export(output_dir, model, metadata=metadata)