"""GRPO training script for the Orbital Thruster attitude-control environment.

Wraps the ``OrbitalThrusterEnv`` client in a tool-style environment class,
builds a prompt dataset over ``TASKS``, loads a 4-bit model with PEFT adapters
through Unsloth, and trains it with TRL's ``GRPOTrainer``.
"""

from __future__ import annotations

import argparse
import json
import os
from pathlib import Path
from typing import Any

from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer

from client import OrbitalThrusterEnv
from models import OrbitalThrusterAction

# Task identifiers accepted by ``OrbitalThrusterToolEnv.reset``; the first
# entry doubles as the fallback when a task id is missing or unknown.
TASKS = (
    "detumble_satellite",
    "retarget_180_flip",
    "long_horizon_precision_hold",
)

SYSTEM_PROMPT = (
    "You are a spacecraft attitude-control agent. "
    "Use the available tools to fire one thruster pulse or idle. "
    "Goal: minimize signed attitude error, damp angular velocity, avoid overshoot, conserve fuel. "
    "Call tools until the episode is done or the spacecraft is stably on target."
)


def _as_dict(value: Any) -> dict[str, Any]:
    """Return ``value`` as a plain dict.

    Objects exposing ``model_dump`` are dumped through it, dicts are returned
    unchanged, and anything else yields an empty dict.
    """
    if hasattr(value, "model_dump"):
        return value.model_dump()
    if isinstance(value, dict):
        return value
    return {}


def format_observation(observation: Any) -> str:
    """Render an observation as newline-separated ``key=value`` telemetry.

    Args:
        observation: Anything ``_as_dict`` understands; missing keys are
            rendered as ``None``.

    Returns:
        One line per telemetry field, joined with newlines.
    """
    obs = _as_dict(observation)
    fields = [
        f"task={obs.get('task_id')}",
        f"difficulty={obs.get('difficulty')}",
        f"phase={obs.get('mission_phase')}",
        f"attitude_error_deg={obs.get('attitude_error_deg')}",
        f"angular_velocity_dps={obs.get('current_angular_velocity_dps')}",
        f"target_attitude_deg={obs.get('target_attitude_deg')}",
        f"fuel_remaining={obs.get('fuel_remaining')}",
        f"steps={obs.get('steps_used')}/{obs.get('step_budget')}",
        f"last_feedback={obs.get('last_feedback')}",
    ]
    return "\n".join(fields)


class OrbitalThrusterToolEnv:
    # Tool-style wrapper around a synchronous OrbitalThrusterEnv client. It
    # tracks per-episode bookkeeping (reward, steps, success, done) that
    # ``reward_func`` reads after a rollout.
    #
    # NOTE(review): this class is passed to GRPOTrainer as
    # ``environment_factory``; the public methods and their docstrings are
    # presumably turned into the tool schema shown to the model -- confirm
    # against the TRL docs before renaming methods, adding public methods, or
    # rewording the docstrings.

    def __init__(self) -> None:
        # The server URL comes from ENV_URL (main() exports it before training).
        self.base_url = os.environ.get("ENV_URL", "http://localhost:7860").rstrip("/")
        self.client = OrbitalThrusterEnv(base_url=self.base_url).sync()
        self.client.connect()
        self.reward = 0.0
        self.total_reward = 0.0
        self.steps = 0
        self.success = False
        self.done = False
        self.task_id = TASKS[0]
        self._last_observation: Any = None

    def reset(self, **kwargs: Any) -> str | None:
        """Start a new episode and return the initial telemetry text.

        Args:
            **kwargs: May carry ``task_id`` as a string or a list whose first
                item is used. A missing, empty, or unknown task id falls back
                to ``TASKS[0]``.

        Returns:
            The formatted initial observation.
        """
        task_id = kwargs.get("task_id") or TASKS[0]
        if isinstance(task_id, list):
            task_id = task_id[0]
        if task_id not in TASKS:
            task_id = TASKS[0]
        self.task_id = str(task_id)

        # Clear the per-episode bookkeeping before talking to the server.
        self.reward = 0.0
        self.total_reward = 0.0
        self.steps = 0
        self.success = False
        self.done = False

        result = self.client.reset(task_id=self.task_id)
        self._last_observation = result.observation
        return format_observation(result.observation)

    def fire_pitch(self, direction: str, size: str, reason: str = "") -> str:
        """Fire a pitch-axis thruster pulse.

        Args:
            direction: Either 'pos' or 'neg'.
            size: Either 'small' or 'large'.
            reason: Short control rationale.

        Returns:
            Latest spacecraft telemetry and reward feedback.
        """
        return self._fire("pitch", direction, size, reason)

    def fire_roll(self, direction: str, size: str, reason: str = "") -> str:
        """Fire a roll-axis thruster pulse.

        Args:
            direction: Either 'pos' or 'neg'.
            size: Either 'small' or 'large'.
            reason: Short control rationale.

        Returns:
            Latest spacecraft telemetry and reward feedback.
        """
        return self._fire("roll", direction, size, reason)

    def fire_yaw(self, direction: str, size: str, reason: str = "") -> str:
        """Fire a yaw-axis thruster pulse.

        Args:
            direction: Either 'pos' or 'neg'.
            size: Either 'small' or 'large'.
            reason: Short control rationale.

        Returns:
            Latest spacecraft telemetry and reward feedback.
        """
        return self._fire("yaw", direction, size, reason)

    def idle(self, reason: str = "") -> str:
        """Use no thruster this step.

        Args:
            reason: Short control rationale.

        Returns:
            Latest spacecraft telemetry and reward feedback.
        """
        return self._step("idle", reason)

    def _fire(self, axis: str, direction: str, size: str, reason: str) -> str:
        """Validate a pulse request and forward it as ``fire_<axis>_<direction>_<size>``.

        Raises:
            ValueError: If ``direction`` is not 'pos'/'neg' or ``size`` is not
                'small'/'large' (after trimming and lower-casing).
        """
        # Coerce through str() so a non-string argument produces the
        # descriptive ValueError below instead of an AttributeError.
        direction = str(direction).strip().lower()
        size = str(size).strip().lower()
        if direction not in {"pos", "neg"}:
            raise ValueError("direction must be 'pos' or 'neg'")
        if size not in {"small", "large"}:
            raise ValueError("size must be 'small' or 'large'")
        return self._step(f"fire_{axis}_{direction}_{size}", reason)

    def _step(self, action_type: str, reason: str) -> str:
        """Send one action to the environment and refresh the episode state.

        Args:
            action_type: Action identifier forwarded to the environment.
            reason: Free-text rationale; truncated to 240 characters.

        Returns:
            The formatted observation produced by the step.

        Raises:
            ValueError: If the episode has already finished.
        """
        if self.done:
            raise ValueError("Episode already done.")

        # Tolerate a missing/None/non-string rationale rather than failing on the slice.
        rationale = "" if reason is None else str(reason)
        action = OrbitalThrusterAction(action_type=action_type, reason=rationale[:240])
        result = self.client.step(action)
        self._last_observation = result.observation
        obs = _as_dict(result.observation)

        # A step result may carry no reward; treat that as zero instead of
        # letting float(None) raise.
        step_reward = 0.0 if result.reward is None else float(result.reward)

        # Use explicit None checks: dict.get()'s default is not applied when
        # the key is present with a None value, which model_dump() produces
        # for unset optional fields.
        steps_used = obs.get("steps_used")
        self.steps = self.steps + 1 if steps_used is None else int(steps_used)

        reward_so_far = obs.get("reward_so_far")
        if reward_so_far is None:
            self.total_reward = self.total_reward + step_reward
        else:
            self.total_reward = float(reward_so_far)

        self.success = bool(obs.get("success", False))
        self.done = bool(result.done or obs.get("done", False))

        # Episode score: mean per-step reward plus a 0.15 bonus on success,
        # clamped to [0, 1].
        average_reward = self.total_reward / max(self.steps, 1)
        bonus = 0.15 if self.success else 0.0
        self.reward = min(1.0, max(0.0, average_reward + bonus))
        return format_observation(result.observation)


def reward_func(environments: list[OrbitalThrusterToolEnv], **kwargs: Any) -> list[float]:
    """Return the current ``reward`` of each environment.

    Args:
        environments: Environments whose rollouts have been played.
        **kwargs: If it contains a callable ``log_metric``, the batch success
            rate and mean step count are reported through it.

    Returns:
        One float reward per environment, in input order.
    """
    rewards = [float(env.reward) for env in environments]
    log_metric = kwargs.get("log_metric")
    if callable(log_metric) and environments:
        count = len(environments)
        log_metric("episode/success_rate", sum(1.0 for env in environments if env.success) / count)
        log_metric("episode/mean_steps", sum(float(env.steps) for env in environments) / count)
    return rewards


def get_prompt_dataset(samples_per_task: int) -> Dataset:
    """Build a dataset with ``samples_per_task`` identical prompts per task.

    Args:
        samples_per_task: Number of rows generated for each entry of ``TASKS``.

    Returns:
        A ``Dataset`` with a ``prompt`` column (system + user messages) and a
        ``task_id`` column.
    """
    prompts: list[list[dict[str, str]]] = []
    task_ids: list[str] = []
    for task_id in TASKS:
        for _ in range(samples_per_task):
            prompts.append(
                [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {
                        "role": "user",
                        "content": (
                            f"Task: {task_id}. Operate the spacecraft by calling tools. "
                            "Prefer small pulses near target, brake rates before overshoot, and preserve fuel."
                        ),
                    },
                ]
            )
            task_ids.append(task_id)
    return Dataset.from_dict({"prompt": prompts, "task_id": task_ids})


def load_unsloth_model(model_id: str, max_seq_length: int) -> tuple[Any, Any]:
    """Load a 4-bit model through Unsloth and attach PEFT adapters.

    Args:
        model_id: Model name passed to ``FastLanguageModel.from_pretrained``.
        max_seq_length: Maximum sequence length passed to the loader.

    Returns:
        The ``(model, tokenizer)`` pair, with the model wrapped by
        ``FastLanguageModel.get_peft_model``.
    """
    # Imported lazily so the module can be imported without unsloth installed.
    # NOTE(review): this import runs after ``trl`` was imported at module top;
    # Unsloth reportedly prefers to be imported first -- confirm the ordering.
    from unsloth import FastLanguageModel

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_id,
        max_seq_length=max_seq_length,
        dtype=None,
        load_in_4bit=True,
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=16,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        lora_alpha=16,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=42,
    )
    return model, tokenizer


def main() -> None:
    """Parse CLI/environment settings, run GRPO training, and save the results.

    Every option defaults to an environment variable of the matching name.
    After training, the adapters and tokenizer are saved under
    ``<output-dir>/final_adapters``, a ``training_summary.json`` is written,
    and the model is pushed to the Hub unless ``--no-push`` is given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-id", default=os.environ.get("MODEL_ID", "unsloth/gpt-oss-20b-unsloth-bnb-4bit"))
    parser.add_argument("--env-url", default=os.environ.get("ENV_URL", "http://localhost:7860"))
    parser.add_argument(
        "--hub-model-id",
        default=os.environ.get("HUB_MODEL_ID", "pixxel-phantom/orbital-thruster-gpt-oss-20b-grpo"),
    )
    parser.add_argument("--output-dir", default=os.environ.get("OUTPUT_DIR", "outputs/grpo_gpt_oss"))
    parser.add_argument("--max-steps", type=int, default=int(os.environ.get("MAX_STEPS", "200")))
    parser.add_argument("--samples-per-task", type=int, default=int(os.environ.get("SAMPLES_PER_TASK", "128")))
    parser.add_argument("--max-seq-length", type=int, default=int(os.environ.get("MAX_SEQ_LENGTH", "4096")))
    parser.add_argument(
        "--max-completion-length",
        type=int,
        default=int(os.environ.get("MAX_COMPLETION_LENGTH", "1536")),
    )
    parser.add_argument("--num-generations", type=int, default=int(os.environ.get("NUM_GENERATIONS", "4")))
    parser.add_argument(
        "--gradient-accumulation-steps",
        type=int,
        default=int(os.environ.get("GRADIENT_ACCUMULATION_STEPS", "4")),
    )
    parser.add_argument("--learning-rate", type=float, default=float(os.environ.get("LEARNING_RATE", "5e-6")))
    parser.add_argument("--report-to", default=os.environ.get("REPORT_TO", "trackio"))
    parser.add_argument("--run-name", default=os.environ.get("RUN_NAME", "orbital-thruster-gpt-oss-20b-grpo"))
    parser.add_argument("--no-push", action="store_true")
    args = parser.parse_args()

    # OrbitalThrusterToolEnv reads its server URL from ENV_URL, so export the
    # CLI value before the trainer instantiates any environment.
    os.environ["ENV_URL"] = args.env_url.rstrip("/")
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    dataset = get_prompt_dataset(args.samples_per_task)
    model, tokenizer = load_unsloth_model(args.model_id, args.max_seq_length)

    # "", "none" and "off" (case-insensitive) disable experiment reporting.
    report_to = [] if args.report_to.lower() in {"", "none", "off"} else [args.report_to]
    training_args = GRPOConfig(
        output_dir=args.output_dir,
        max_steps=args.max_steps,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        warmup_ratio=0.1,
        logging_steps=1,
        save_steps=50,
        save_total_limit=2,
        num_generations=args.num_generations,
        max_completion_length=args.max_completion_length,
        log_completions=True,
        report_to=report_to,
        run_name=args.run_name,
        push_to_hub=not args.no_push,
        hub_model_id=args.hub_model_id if not args.no_push else None,
    )
    trainer = GRPOTrainer(
        model=model,
        processing_class=tokenizer,
        reward_funcs=reward_func,
        train_dataset=dataset,
        args=training_args,
        environment_factory=OrbitalThrusterToolEnv,
    )
    trainer.train()

    adapter_dir = Path(args.output_dir) / "final_adapters"
    trainer.model.save_pretrained(adapter_dir)
    tokenizer.save_pretrained(adapter_dir)

    summary_path = Path(args.output_dir) / "training_summary.json"
    summary_path.write_text(
        json.dumps(
            {
                "model_id": args.model_id,
                "hub_model_id": None if args.no_push else args.hub_model_id,
                "env_url": args.env_url,
                "max_steps": args.max_steps,
                "samples_per_task": args.samples_per_task,
                "num_generations": args.num_generations,
                "log_history": trainer.state.log_history,
            },
            indent=2,
        ),
        encoding="utf-8",
    )

    if not args.no_push:
        trainer.push_to_hub()


if __name__ == "__main__":
    main()