"""Training configuration dataclass and model recipes."""

from __future__ import annotations

from dataclasses import asdict, dataclass, fields


@dataclass
class TrainConfig:
    """All hyperparameters for one training run.

    The defaults reproduce the "m5-smart" entry of ``RECIPES``; call
    ``apply_recipe`` to switch a config to another named recipe.
    """

    # Model. The default is the "m5-smart" recipe: about 55-60M parameters
    # (model_param_count() estimates ~56.3M for these values).
    vocab_size: int = 24000
    d_model: int = 512
    n_heads: int = 8
    n_layers: int = 6
    dim_feedforward: int = 2048
    dropout: float = 0.10
    max_len: int = 512

    # Training. Batch 2 is the default for the 56M recipe; the watchdog keeps
    # the process about 7 GB below a 32 GB unified-memory ceiling.
    batch_size: int = 2
    grad_accum: int = 8
    epochs: int = 8
    lr: float = 3e-4
    min_lr_ratio: float = 0.08
    warmup_steps: int = 800
    weight_decay: float = 0.01
    label_smoothing: float = 0.05
    grad_clip: float = 1.0

    # Data.
    max_pairs: int = 120000
    val_split: float = 0.03
    seed: int = 1337

    # Runtime safety.
    eval_every_steps: int = 1000
    save_every_steps: int = 1000
    log_every_steps: int = 25
    empty_cache_every: int = 10
    # NOTE(review): presumably the memory fraction at which the watchdog stops
    # training; 0.78 * 32 GB ~= 25 GB, consistent with the ~7 GB headroom
    # mentioned above -- confirm in the training loop.
    memory_stop_fraction: float = 0.78
    num_workers: int = 0

    # Decoding defaults.
    beam_size: int = 4
    length_penalty: float = 0.7
    repetition_penalty: float = 1.08


# Named presets. Each entry overrides only the TrainConfig fields it lists;
# every recipe keeps d_model / n_heads == 64.
RECIPES: dict[str, dict] = {
    # Smallest recipe. Good for testing the full pipeline.
    "survivor": dict(
        vocab_size=16000,
        d_model=384,
        n_heads=6,
        n_layers=4,
        dim_feedforward=1536,
        max_len=256,
        batch_size=2,
        grad_accum=8,
        max_pairs=80000,
        epochs=6,
    ),
    # Default (same values as TrainConfig's defaults). Best balance for an
    # M5 Mac with 32 GB RAM.
    "m5-smart": dict(
        vocab_size=24000,
        d_model=512,
        n_heads=8,
        n_layers=6,
        dim_feedforward=2048,
        max_len=512,
        batch_size=2,
        grad_accum=8,
        max_pairs=120000,
        epochs=8,
    ),
    # Heavier. Use only if m5-smart trains overnight without memory stops.
    "m5-large": dict(
        vocab_size=24000,
        d_model=640,
        n_heads=10,
        n_layers=8,
        dim_feedforward=2560,
        max_len=384,
        batch_size=1,
        grad_accum=16,
        max_pairs=180000,
        epochs=10,
    ),
}


def apply_recipe(name: str, cfg: TrainConfig | None = None) -> TrainConfig:
    """Overwrite ``cfg`` in place with the settings of recipe ``name``.

    Args:
        name: Key into ``RECIPES``.
        cfg: Config to update. When omitted, a fresh ``TrainConfig()`` is
            used. Fields the recipe does not list keep their current values.

    Returns:
        The updated config (the same object as ``cfg`` when one was given).

    Raises:
        SystemExit: If ``name`` is not a known recipe, or if the recipe sets
            a key that is not a ``TrainConfig`` field.
    """
    if cfg is None:
        cfg = TrainConfig()
    if name not in RECIPES:
        raise SystemExit(
            f"Unknown recipe '{name}'. Choose: {', '.join(RECIPES)}"
        )
    recipe = RECIPES[name]
    # setattr() on a dataclass instance accepts any attribute name, so a
    # misspelled recipe key would silently leave the real field unchanged.
    # Validate every key first so a bad recipe leaves cfg untouched.
    known = {f.name for f in fields(cfg)}
    unknown = sorted(set(recipe) - known)
    if unknown:
        raise SystemExit(
            f"Recipe '{name}' sets unknown config field(s): {', '.join(unknown)}"
        )
    for key, value in recipe.items():
        setattr(cfg, key, value)
    return cfg


def model_param_count(cfg: TrainConfig) -> int:
    """Estimate the model's parameter count from ``cfg``.

    Close estimate only; the exact count is printed after model construction.
    Only weight matrices are counted: biases, LayerNorm parameters and any
    positional-embedding table (no ``max_len`` term) are ignored. The encoder
    and decoder are both assumed to have ``cfg.n_layers`` layers. The
    ``vocab_size x d_model`` embedding is counted once -- presumably shared
    with the output projection; confirm against the model code.
    """
    e = cfg.d_model
    ff = cfg.dim_feedforward
    attention = 4 * e * e  # four e x e projection matrices per attention block
    feed_forward = 2 * e * ff  # e -> ff -> e
    enc_layer = attention + feed_forward
    dec_layer = 2 * attention + feed_forward  # self-attention + cross-attention
    return cfg.vocab_size * e + cfg.n_layers * (enc_layer + dec_layer)