"""State-native dataset utilities for AKSARA training.

This module is a design-time placeholder/spec scaffold for the new
state-native training pipeline. It defines the expected data contracts and
helper entry points without implementing the full training logic.
"""

from __future__ import annotations

import json
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)


def _record_text(record: Dict[str, Any]) -> str:
    """Return ``record["text"]`` as a string; a missing or null value becomes ``""``.

    Without this, ``str(None)`` would yield the literal one-token string
    ``"None"``, which would both pass length filters and leak into batches.
    """
    text = record.get("text")
    return "" if text is None else str(text)


@dataclass
class StateTrainingRecord:
    """Single state-native training record.

    Expected JSONL fields:
    - text: raw surface string
    - label: 0/1 or other scalar supervision
    - state: optional nested state supervision
    - trace: optional evolution trace
    - source: optional provenance tag
    - meta: free-form per-record metadata (defaults to an empty dict)
    """

    text: str
    label: Optional[float] = None
    state: Optional[Dict] = None
    trace: Optional[List[Dict]] = None
    source: Optional[str] = None
    meta: Dict = field(default_factory=dict)


class StateDataset:
    """Container for state-native records.

    At construction, records whose ``text`` has fewer than ``min_length``
    whitespace-separated tokens are dropped. A missing or null ``text``
    counts as zero tokens. The surviving dicts are stored as-is (not copied)
    in ``self.records``.

    The final implementation should expose:
    - __len__
    - __getitem__
    - filtering/normalization helpers
    - compatibility with a custom collate function
    """

    def __init__(self, records: List[Dict], min_length: int = 2):
        # str.split() with no argument already ignores leading/trailing
        # whitespace, so no separate strip() is needed.
        self.records: List[Dict] = [
            record
            for record in records
            if len(_record_text(record).split()) >= min_length
        ]

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, idx: int) -> Dict:
        return self.records[idx]


def load_state_jsonl(path: str, limit: int = 0) -> List[Dict]:
    """Load JSONL records for state-native training.

    Parses one JSON value per line. Only JSON objects (dicts) with a truthy
    ``text`` field are kept. Nested ``state``/``trace`` values are preserved
    untouched. Blank lines and non-object values are skipped, and so are
    lines that are not valid JSON. The number of invalid-JSON lines is
    logged as a warning rather than silently discarded.

    Args:
        path: Path to a UTF-8 JSONL file. A leading byte-order mark is tolerated.
        limit: Stop once this many records have been collected. ``0`` (or any
            non-positive value) means "no limit".

    Returns:
        The parsed record dicts, in file order.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    records: List[Dict] = []
    malformed = 0
    # "utf-8-sig" decodes plain UTF-8 identically but strips a leading BOM.
    # Otherwise json.loads would reject the first line and drop that record.
    with open(path, encoding="utf-8-sig") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                malformed += 1
                continue
            if isinstance(obj, dict) and obj.get("text"):
                records.append(obj)
                # Guard on > 0: a negative limit must not stop after one record.
                if limit and limit > 0 and len(records) >= limit:
                    break
    if malformed:
        logger.warning("Skipped %d malformed JSON line(s) in %s", malformed, path)
    return records


def load_state_corpus(path: str, limit: int = 0) -> List[Dict]:
    """Compatibility wrapper around load_state_jsonl()."""
    return load_state_jsonl(path, limit=limit)


def state_collate_fn(batch: List[Dict], root_vocab: Dict[str, int], max_length: int = 32) -> Dict:
    """Build the state-native batch contract.

    Current behaviour: gathers per-item fields into parallel lists. ``text``
    is coerced to ``str``; a missing or null text becomes ``""``. ``label`` and
    ``state`` are passed through, and a missing value becomes ``None``.
    ``meta`` is shallow-copied, and a missing or null meta becomes ``{}``.
    ``root_vocab`` and ``max_length`` are only forwarded inside ``lps_input``.
    ``attention_mask`` and ``dep_masks`` are placeholders (``None``).

    Final implementation should return:
    - lps_input
    - labels
    - state_targets
    - meta
    - attention_mask
    - dep_masks
    """
    texts = [_record_text(item) for item in batch]
    labels = [item.get("label") for item in batch]
    state_targets = [item.get("state") for item in batch]
    # `or {}` tolerates an explicit JSON null, which dict(None) would reject.
    # The copy is shallow, so nested values are still shared with the record.
    meta = [dict(item.get("meta") or {}) for item in batch]
    return {
        "texts": texts,
        "labels": labels,
        "state_targets": state_targets,
        "meta": meta,
        "lps_input": {
            "texts": texts,
            "root_vocab": root_vocab,
            "max_length": max_length,
        },
        "attention_mask": None,
        "dep_masks": None,
    }