rishiad commited on
Commit
d36955c
·
unverified ·
1 Parent(s): 7fdb0a1

enhance: add policy scaling and improved debug output in RLAgent

Browse files
Files changed (1) hide show
  1. agent.py +38 -28
agent.py CHANGED
@@ -37,10 +37,21 @@ class RLAgent(AgentInterface):
37
  self.policy = SawyerReachV3Policy()
38
  print("Successfully initialized SawyerReachV3Policy")
39
 
 
 
 
 
 
 
 
 
40
  # Track episode state
41
  self.episode_step = 0
42
  self.max_episode_steps = kwargs.get("max_episode_steps", 200)
43
-
 
 
 
44
  # Debug flags
45
  self.debug_observations = True
46
  self.debug_actions = True
@@ -59,32 +70,29 @@ class RLAgent(AgentInterface):
59
  action: Action tensor to take in the environment
60
  """
61
  try:
62
- # Debug observation structure
63
- if self.debug_observations and self.episode_step % 20 == 0:
64
- print(f"Raw observation structure: {type(obs)}")
65
- if isinstance(obs, dict):
66
- print(f"Observation keys: {list(obs.keys())}")
67
- for key, value in obs.items():
68
- if isinstance(value, np.ndarray):
69
- print(f" {key}: shape={value.shape}, dtype={value.dtype}")
70
- else:
71
- print(f" {key}: {type(value)} = {value}")
72
 
73
  # Process observation to extract the format needed by the expert policy
74
  processed_obs = self._process_observation(obs)
75
 
76
- # Debug processed observation
77
- if self.debug_observations and self.episode_step % 20 == 0:
78
- print(f"Processed obs: shape={processed_obs.shape}, dtype={processed_obs.dtype}")
79
- print(f"Processed obs sample: {processed_obs[:10]}...") # First 10 values
80
 
81
  # Use the expert policy
82
  action_numpy = self.policy.get_action(processed_obs)
83
 
84
- # Debug raw policy output
85
- if self.debug_actions and self.episode_step % 20 == 0:
86
- print(f"Raw policy action: {action_numpy}, type: {type(action_numpy)}")
87
- print(f"Action shape: {np.array(action_numpy).shape}")
88
 
89
  # Convert to tensor
90
  if isinstance(action_numpy, (list, tuple)):
@@ -92,6 +100,12 @@ class RLAgent(AgentInterface):
92
  else:
93
  action_tensor = torch.from_numpy(np.array(action_numpy)).float()
94
 
 
 
 
 
 
 
95
  # Ensure correct action dimensionality
96
  if self.action_space and hasattr(self.action_space, 'shape'):
97
  expected_shape = self.action_space.shape[0]
@@ -104,9 +118,8 @@ class RLAgent(AgentInterface):
104
  else:
105
  action_tensor = action_tensor[:expected_shape]
106
 
107
- # Debug final action
108
- if self.debug_actions and self.episode_step % 20 == 0:
109
- print(f"Final action tensor: {action_tensor}")
110
 
111
  self.episode_step += 1
112
  return action_tensor
@@ -139,8 +152,7 @@ class RLAgent(AgentInterface):
139
  for key in possible_keys:
140
  if key in obs:
141
  processed_obs = obs[key]
142
- if self.debug_observations and self.episode_step % 50 == 0:
143
- print(f"Using observation key: {key}")
144
  break
145
 
146
  if processed_obs is None:
@@ -150,13 +162,11 @@ class RLAgent(AgentInterface):
150
  if isinstance(value, (np.ndarray, list, tuple)):
151
  flat_value = np.array(value).flatten()
152
  numeric_values.append(flat_value)
153
- if self.debug_observations and self.episode_step % 50 == 0:
154
- print(f"Concatenating key {key}: shape={flat_value.shape}")
155
 
156
  if numeric_values:
157
  processed_obs = np.concatenate(numeric_values)
158
- if self.debug_observations and self.episode_step % 50 == 0:
159
- print(f"Concatenated observation shape: {processed_obs.shape}")
160
  else:
161
  # Last resort: use first value
162
  processed_obs = next(iter(obs.values()))
 
37
  self.policy = SawyerReachV3Policy()
38
  print("Successfully initialized SawyerReachV3Policy")
39
 
40
+ # Check if policy has any scaling attributes that might need adjustment
41
+ if hasattr(self.policy, 'action_space'):
42
+ print(f"Policy action space: {self.policy.action_space}")
43
+ if hasattr(self.policy, 'scale'):
44
+ print(f"Policy scale: {self.policy.scale}")
45
+ if hasattr(self.policy, 'bias'):
46
+ print(f"Policy bias: {self.policy.bias}")
47
+
48
  # Track episode state
49
  self.episode_step = 0
50
  self.max_episode_steps = kwargs.get("max_episode_steps", 200)
51
+
52
+ # Policy scaling factor (can be adjusted if policy constants are too high)
53
+ self.policy_scale = kwargs.get("policy_scale", 1.0)
54
+
55
  # Debug flags
56
  self.debug_observations = True
57
  self.debug_actions = True
 
70
  action: Action tensor to take in the environment
71
  """
72
  try:
73
+ # Debug observation structure (reduced frequency)
74
+ print(f"Raw observation structure: {type(obs)}")
75
+ if isinstance(obs, dict):
76
+ print(f"Observation keys: {list(obs.keys())}")
77
+ for key, value in obs.items():
78
+ if isinstance(value, np.ndarray):
79
+ print(f" {key}: shape={value.shape}, dtype={value.dtype}")
80
+ else:
81
+ print(f" {key}: {type(value)} = {value}")
 
82
 
83
  # Process observation to extract the format needed by the expert policy
84
  processed_obs = self._process_observation(obs)
85
 
86
+ # Debug processed observation (reduced frequency)
87
+ print(f"Processed obs: shape={processed_obs.shape}, dtype={processed_obs.dtype}")
88
+ print(f"Processed obs sample: {processed_obs[:10]}...") # First 10 values
 
89
 
90
  # Use the expert policy
91
  action_numpy = self.policy.get_action(processed_obs)
92
 
93
+ # Debug raw policy output (reduced frequency)
94
+ print(f"Raw policy action: {action_numpy}, type: {type(action_numpy)}")
95
+ print(f"Action shape: {np.array(action_numpy).shape}")
 
96
 
97
  # Convert to tensor
98
  if isinstance(action_numpy, (list, tuple)):
 
100
  else:
101
  action_tensor = torch.from_numpy(np.array(action_numpy)).float()
102
 
103
+ # Apply scaling factor if needed (helps with policy constants that may be too high)
104
+ action_tensor = action_tensor * self.policy_scale
105
+
106
+ # Clip actions to [-1, 1] range to handle policy constants that may be too high
107
+ action_tensor = torch.clamp(action_tensor, -1.0, 1.0)
108
+
109
  # Ensure correct action dimensionality
110
  if self.action_space and hasattr(self.action_space, 'shape'):
111
  expected_shape = self.action_space.shape[0]
 
118
  else:
119
  action_tensor = action_tensor[:expected_shape]
120
 
121
+ # Debug final action (reduced frequency)
122
+ print(f"Final action tensor: {action_tensor}")
 
123
 
124
  self.episode_step += 1
125
  return action_tensor
 
152
  for key in possible_keys:
153
  if key in obs:
154
  processed_obs = obs[key]
155
+ print(f"Using observation key: {key}")
 
156
  break
157
 
158
  if processed_obs is None:
 
162
  if isinstance(value, (np.ndarray, list, tuple)):
163
  flat_value = np.array(value).flatten()
164
  numeric_values.append(flat_value)
165
+ print(f"Concatenating key {key}: shape={flat_value.shape}")
 
166
 
167
  if numeric_values:
168
  processed_obs = np.concatenate(numeric_values)
169
+ print(f"Concatenated observation shape: {processed_obs.shape}")
 
170
  else:
171
  # Last resort: use first value
172
  processed_obs = next(iter(obs.values()))