"""Plot GRPO training progress from saved trainer logs.

Reads ``log_history`` from ``training_summary.json`` or ``trainer_state.json``
in the run's output directory and writes two PNG figures into the plots
directory: ``reward_curve.png`` (reward per step against a baseline) and
``before_after.png`` (baseline vs. trained bar chart).
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Any

# Files that may hold a ``log_history`` list, in lookup order.
_LOG_FILENAMES = ("training_summary.json", "trainer_state.json")


def _load_log_history(output_dir: Path) -> list[dict[str, Any]]:
    """Return the ``log_history`` list stored in *output_dir*.

    ``training_summary.json`` is consulted first, then ``trainer_state.json``.
    A file that exists but holds no log entries does not end the search, so
    an empty summary cannot hide the entries recorded in the trainer state.

    Args:
        output_dir: Directory containing the training artefacts.

    Returns:
        The first non-empty ``log_history`` found, or an empty list when the
        file(s) exist but none of them contains any entries.

    Raises:
        FileNotFoundError: If neither file exists in *output_dir*.
    """
    found_any = False
    for filename in _LOG_FILENAMES:
        path = output_dir / filename
        if not path.exists():
            continue
        found_any = True
        payload = json.loads(path.read_text(encoding="utf-8"))
        log_history = payload.get("log_history", [])
        if log_history:
            return log_history
    if not found_any:
        raise FileNotFoundError(
            f"No training_summary.json or trainer_state.json found in {output_dir}"
        )
    return []


def _series(
    log_history: list[dict[str, Any]],
    metric_names: tuple[str, ...],
) -> tuple[list[int], list[float]]:
    """Extract one ``(steps, values)`` series from *log_history*.

    For every row, the first name in *metric_names* that carries a value is
    used; rows without any of the metrics are ignored.

    Args:
        log_history: Log rows, each a mapping of metric name to value.
        metric_names: Candidate metric keys, in priority order.

    Returns:
        Two parallel lists: the step of each selected row and the metric
        value as a float.
    """
    steps: list[int] = []
    values: list[float] = []
    for row in log_history:
        for name in metric_names:
            value = row.get(name)
            if value is None:
                # Missing key or an explicit null: try the next candidate
                # instead of failing on float(None).
                continue
            # Rows without a "step" get a 1-based position within the series.
            steps.append(int(row.get("step", len(steps) + 1)))
            values.append(float(value))
            break
    return steps, values


def main() -> None:
    """Parse the command line and write the two training plots.

    ``--baseline`` and ``--trained`` are required floats; ``--output-dir``
    and ``--plots-dir`` have defaults. The plots directory is created if it
    does not exist.

    Raises:
        FileNotFoundError: If no log file exists in ``--output-dir``.
        ValueError: If the logs contain none of the reward-like metrics.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--output-dir", default="outputs/grpo_gpt_oss")
    parser.add_argument("--plots-dir", default="plots")
    parser.add_argument("--baseline", type=float, required=True)
    parser.add_argument("--trained", type=float, required=True)
    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    plots_dir = Path(args.plots_dir)
    plots_dir.mkdir(parents=True, exist_ok=True)

    log_history = _load_log_history(output_dir)
    reward_steps, rewards = _series(
        log_history,
        ("reward", "train/reward", "episode/reward", "episode/success_rate"),
    )
    if not rewards:
        raise ValueError("No reward-like metric found in training logs.")

    # Imported here rather than at module level so the log-parsing helpers
    # stay importable on machines without the plotting stack.
    import matplotlib.pyplot as plt

    # Figure 1: reward over training steps with the baseline as a reference.
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(reward_steps, rewards, label="Trained agent", linewidth=2)
    ax.axhline(args.baseline, color="gray", linestyle=":", label="Untrained baseline")
    ax.set_xlabel("Training Step")
    ax.set_ylabel("Reward (0-1)")
    ax.set_title("OrbitalThrusterEnv: GRPO Training Progress")
    ax.legend()
    fig.tight_layout()
    fig.savefig(plots_dir / "reward_curve.png", dpi=150)
    plt.close(fig)

    # Figure 2: baseline vs. trained mean reward as a bar chart.
    fig, ax = plt.subplots(figsize=(7, 5))
    ax.bar(
        ["Untrained", "Trained"],
        [args.baseline, args.trained],
        color=["#70757a", "#1967d2"],
    )
    ax.set_ylim(0, 1)
    ax.set_ylabel("Mean Episode Reward")
    ax.set_title("Before vs After GRPO")
    for idx, value in enumerate([args.baseline, args.trained]):
        # Cap the label height so it stays inside the fixed 0-1 axis.
        ax.text(idx, min(value + 0.03, 0.97), f"{value:.3f}", ha="center")
    fig.tight_layout()
    fig.savefig(plots_dir / "before_after.png", dpi=150)
    plt.close(fig)


if __name__ == "__main__":
    main()