adityss commited on
Commit
b81683f
·
1 Parent(s): 0af208b

feat: implement multi-component dense reward function and environmental logic for GridMind-RL

Browse files
Files changed (2) hide show
  1. env/environment.go +2 -2
  2. env/rewards.go +3 -3
env/environment.go CHANGED
@@ -84,7 +84,7 @@ func (e *Environment) Reset(req ResetRequest) ResetResponse {
84
 
85
  // Apply task and difficulty
86
  e.taskID = req.TaskID
87
- if e.taskID < 1 || e.taskID > 3 {
88
  e.taskID = 1
89
  }
90
  e.difficulty = req.Difficulty
@@ -94,7 +94,7 @@ func (e *Environment) Reset(req ResetRequest) ResetResponse {
94
  e.difficulty = "easy"
95
  case 2:
96
  e.difficulty = "medium"
97
- case 3:
98
  e.difficulty = "hard"
99
  }
100
  }
 
84
 
85
  // Apply task and difficulty
86
  e.taskID = req.TaskID
87
+ if e.taskID < 1 || e.taskID > 4 {
88
  e.taskID = 1
89
  }
90
  e.difficulty = req.Difficulty
 
94
  e.difficulty = "easy"
95
  case 2:
96
  e.difficulty = "medium"
97
+ case 3, 4:
98
  e.difficulty = "hard"
99
  }
100
  }
env/rewards.go CHANGED
@@ -116,11 +116,11 @@ func ComputeReward(inp ComputeRewardInput) RewardComponents {
116
  }
117
 
118
  // ── Aggregate ────────────────────────────────────────────────────────────
119
- // Total includes all 9 components with fault_mitigation weighted at 0.05
120
- // Reduce StabilityPenalty weight by 0.05 to keep sum = 1.0
121
  rc.Total = rc.CostSavings + rc.TempConstraint + rc.GridResponse +
122
  rc.DeadlinePenalty + rc.EfficiencyBonus + rc.StabilityPenalty + rc.CarbonReward +
123
- rc.InstructionReward + rc.FaultMitigation*0.05 + rc.FaultMitigation*0.95
124
 
125
  return rc
126
  }
 
116
  }
117
 
118
  // ── Aggregate ────────────────────────────────────────────────────────────
119
+ // Total is the sum of all 9 reward components. Each component is computed
120
+ // independently above and contributes directly to the total signal.
121
  rc.Total = rc.CostSavings + rc.TempConstraint + rc.GridResponse +
122
  rc.DeadlinePenalty + rc.EfficiencyBonus + rc.StabilityPenalty + rc.CarbonReward +
123
+ rc.InstructionReward + rc.FaultMitigation
124
 
125
  return rc
126
  }