"""Tests for ``RewardBundle`` and its use by ``OrbitalThrusterEnvironment``."""

import unittest

from models import OrbitalThrusterAction
from server.orbital_thruster_environment import OrbitalThrusterEnvironment
from server.reward import RewardBundle
from server.tasks import get_task


class RewardBundleTests(unittest.TestCase):
    """Unit tests for ``RewardBundle`` weights, components and overall score."""

    def setUp(self) -> None:
        self.task = get_task("detumble_satellite")
        # Telemetry for a single well-formed step. Tests derive variants from
        # it with ``dict(self.telemetry, ...)`` so the baseline stays intact.
        self.telemetry = {
            "action_type": "fire_pitch_pos_small",
            "reason": "Brake pitch error with a small pulse.",
            "previous_error_norm": 12.0,
            "error_norm": 8.0,
            "rate_norm": 0.25,
            "fuel_used_step": 0.02,
            "fuel_remaining": 99.98,
            "fuel_used": 0.02,
            "overshoot_increment": 0.0,
            "overshoot_total": 0.0,
            "on_target_streak": 0,
            "step_number": 1,
        }

    def test_weights_sum_to_one(self) -> None:
        """The component weights declared on ``RewardBundle`` sum to 1.0."""
        self.assertAlmostEqual(sum(RewardBundle.weights.values()), 1.0)

    def test_components_and_score_are_bounded(self) -> None:
        """Every named component and the combined score lie in [0, 1]."""
        bundle = RewardBundle()
        components = bundle.components(self.task, self.telemetry)

        self.assertEqual(
            set(components),
            {
                "primary_objective",
                "process_quality",
                "format_compliance",
                "efficiency",
                "constraint_satisfaction",
            },
        )
        # subTest reports which component broke the bound instead of only
        # the first failing value.
        for name, value in components.items():
            with self.subTest(component=name):
                self.assertGreaterEqual(value, 0.0)
                self.assertLessEqual(value, 1.0)

        score = bundle.score(self.task, self.telemetry)
        self.assertGreaterEqual(score, 0.0)
        self.assertLessEqual(score, 1.0)

    def test_format_compliance_rejects_invalid_action_and_empty_reason(self) -> None:
        """An unknown action name plus an empty reason scores below valid input."""
        bundle = RewardBundle()
        good_score = bundle.reward_format_compliance(self.task, self.telemetry)
        malformed = dict(self.telemetry, action_type="fire_pitch_wrong", reason="")
        bad_score = bundle.reward_format_compliance(self.task, malformed)
        self.assertGreater(good_score, bad_score)


class EnvironmentRewardIntegrationTests(unittest.TestCase):
    """Checks that the environment surfaces ``RewardBundle`` output."""

    def test_environment_supports_concurrent_sessions(self) -> None:
        self.assertTrue(OrbitalThrusterEnvironment.SUPPORTS_CONCURRENT_SESSIONS)

    def test_step_feedback_exposes_reward_components(self) -> None:
        """One step yields a reward in [0, 1] and feedback naming the components."""
        env = OrbitalThrusterEnvironment()
        env.reset(task_id="detumble_satellite")
        observation = env.step(
            OrbitalThrusterAction(
                action_type="fire_pitch_neg_small",
                reason="Reduce deployment pitch spin with a small pulse.",
            )
        )

        # Without this, a None reward would surface as a TypeError from the
        # comparisons below rather than as a readable assertion failure.
        self.assertIsNotNone(observation.reward)
        self.assertGreaterEqual(observation.reward, 0.0)
        self.assertLessEqual(observation.reward, 1.0)

        feedback = observation.last_feedback or ""
        for marker in ("primary=", "process=", "constraints="):
            with self.subTest(marker=marker):
                self.assertIn(marker, feedback)


if __name__ == "__main__":
    unittest.main()