feat: enhance observation processing and debugging in RLAgent for MetaWorld policies
Browse files
agent.py
CHANGED
|
@@ -45,6 +45,20 @@ class RLAgent(AgentInterface):
|
|
| 45 |
if hasattr(self.policy, 'bias'):
|
| 46 |
print(f"Policy bias: {self.policy.bias}")
|
| 47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
# Track episode state
|
| 49 |
self.episode_step = 0
|
| 50 |
self.max_episode_steps = kwargs.get("max_episode_steps", 200)
|
|
@@ -52,6 +66,9 @@ class RLAgent(AgentInterface):
|
|
| 52 |
# Policy scaling factor (can be adjusted if policy constants are too high)
|
| 53 |
self.policy_scale = kwargs.get("policy_scale", 1.0)
|
| 54 |
|
|
|
|
|
|
|
|
|
|
| 55 |
# Debug flags
|
| 56 |
self.debug_observations = True
|
| 57 |
self.debug_actions = True
|
|
@@ -83,12 +100,54 @@ class RLAgent(AgentInterface):
|
|
| 83 |
# Process observation to extract the format needed by the expert policy
|
| 84 |
processed_obs = self._process_observation(obs)
|
| 85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
# Debug processed observation (reduced frequency)
|
| 87 |
print(f"Processed obs: shape={processed_obs.shape}, dtype={processed_obs.dtype}")
|
| 88 |
print(f"Processed obs sample: {processed_obs[:10]}...") # First 10 values
|
| 89 |
|
| 90 |
-
#
|
| 91 |
-
action_numpy =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
# Debug raw policy output (reduced frequency)
|
| 94 |
print(f"Raw policy action: {action_numpy}, type: {type(action_numpy)}")
|
|
@@ -136,58 +195,151 @@ class RLAgent(AgentInterface):
|
|
| 136 |
"""
|
| 137 |
Helper method to process observations for the MetaWorld expert policy.
|
| 138 |
|
| 139 |
-
MetaWorld policies typically expect
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
"""
|
| 141 |
if isinstance(obs, dict):
|
| 142 |
-
#
|
| 143 |
-
|
| 144 |
-
"observation",
|
| 145 |
-
"obs",
|
| 146 |
-
"
|
| 147 |
-
"achieved_goal",
|
| 148 |
-
"
|
| 149 |
]
|
| 150 |
-
|
| 151 |
processed_obs = None
|
| 152 |
-
for key in
|
| 153 |
if key in obs:
|
| 154 |
processed_obs = obs[key]
|
| 155 |
-
print(f"Using observation key: {key}")
|
| 156 |
break
|
| 157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
if processed_obs is None:
|
| 159 |
-
#
|
| 160 |
-
|
|
|
|
|
|
|
|
|
|
| 161 |
for key, value in obs.items():
|
| 162 |
-
if isinstance(value,
|
| 163 |
-
flat_value =
|
| 164 |
-
|
| 165 |
-
print(f"
|
| 166 |
-
|
| 167 |
-
if
|
| 168 |
-
processed_obs = np.concatenate(
|
| 169 |
print(f"Concatenated observation shape: {processed_obs.shape}")
|
| 170 |
else:
|
| 171 |
-
# Last resort:
|
| 172 |
-
processed_obs =
|
| 173 |
-
print("No
|
| 174 |
else:
|
| 175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
|
| 177 |
-
|
| 178 |
-
if not isinstance(processed_obs, np.ndarray):
|
| 179 |
-
try:
|
| 180 |
-
processed_obs = np.array(processed_obs, dtype=np.float32)
|
| 181 |
-
except Exception as e:
|
| 182 |
-
print(f"Failed to convert observation to numpy array: {e}")
|
| 183 |
-
# Return default observation size for MetaWorld reach task
|
| 184 |
-
processed_obs = np.zeros(39, dtype=np.float32)
|
| 185 |
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
|
| 190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
|
| 192 |
def reset(self) -> None:
|
| 193 |
"""
|
|
|
|
| 45 |
if hasattr(self.policy, 'bias'):
|
| 46 |
print(f"Policy bias: {self.policy.bias}")
|
| 47 |
|
| 48 |
+
# Inspect policy methods to understand expected input format
|
| 49 |
+
if hasattr(self.policy, 'get_action'):
|
| 50 |
+
print(f"Policy has get_action method")
|
| 51 |
+
if hasattr(self.policy, '_get_obs'):
|
| 52 |
+
print(f"Policy has _get_obs method")
|
| 53 |
+
|
| 54 |
+
# Try to understand what observation format the policy expects
|
| 55 |
+
try:
|
| 56 |
+
# Some MetaWorld policies might have observation space info
|
| 57 |
+
if hasattr(self.policy, 'observation_space'):
|
| 58 |
+
print(f"Policy observation space: {self.policy.observation_space}")
|
| 59 |
+
except:
|
| 60 |
+
pass
|
| 61 |
+
|
| 62 |
# Track episode state
|
| 63 |
self.episode_step = 0
|
| 64 |
self.max_episode_steps = kwargs.get("max_episode_steps", 200)
|
|
|
|
| 66 |
# Policy scaling factor (can be adjusted if policy constants are too high)
|
| 67 |
self.policy_scale = kwargs.get("policy_scale", 1.0)
|
| 68 |
|
| 69 |
+
# Flag to try different observation processing strategies
|
| 70 |
+
self.try_alternative_obs = True
|
| 71 |
+
|
| 72 |
# Debug flags
|
| 73 |
self.debug_observations = True
|
| 74 |
self.debug_actions = True
|
|
|
|
| 100 |
# Process observation to extract the format needed by the expert policy
|
| 101 |
processed_obs = self._process_observation(obs)
|
| 102 |
|
| 103 |
+
# Optionally normalize observation
|
| 104 |
+
if self.try_alternative_obs:
|
| 105 |
+
processed_obs = self._normalize_observation(processed_obs)
|
| 106 |
+
|
| 107 |
+
# Debug: print all observation keys and their shapes to understand the structure
|
| 108 |
+
if isinstance(obs, dict):
|
| 109 |
+
print("Full observation keys and shapes:")
|
| 110 |
+
for key, value in obs.items():
|
| 111 |
+
if isinstance(value, np.ndarray):
|
| 112 |
+
print(f" {key}: shape={value.shape}, dtype={value.dtype}, range=[{value.min():.3f}, {value.max():.3f}]")
|
| 113 |
+
else:
|
| 114 |
+
print(f" {key}: {type(value)} = {value}")
|
| 115 |
+
|
| 116 |
# Debug processed observation (reduced frequency)
|
| 117 |
print(f"Processed obs: shape={processed_obs.shape}, dtype={processed_obs.dtype}")
|
| 118 |
print(f"Processed obs sample: {processed_obs[:10]}...") # First 10 values
|
| 119 |
|
| 120 |
+
# Try different approaches for the MetaWorld policy
|
| 121 |
+
action_numpy = None
|
| 122 |
+
|
| 123 |
+
# Strategy 1: Try with processed observation (39-dim flattened array)
|
| 124 |
+
try:
|
| 125 |
+
action_numpy = self.policy.get_action(processed_obs)
|
| 126 |
+
print(f"✓ Used processed 39-dim observation for policy")
|
| 127 |
+
except Exception as e1:
|
| 128 |
+
print(f"✗ Failed with processed observation: {e1}")
|
| 129 |
+
|
| 130 |
+
# Strategy 2: Try with raw observation if it's a dict
|
| 131 |
+
if action_numpy is None and isinstance(obs, dict):
|
| 132 |
+
try:
|
| 133 |
+
action_numpy = self.policy.get_action(obs)
|
| 134 |
+
print(f"✓ Used raw observation dictionary for policy")
|
| 135 |
+
except Exception as e2:
|
| 136 |
+
print(f"✗ Failed with raw observation dictionary: {e2}")
|
| 137 |
+
|
| 138 |
+
# Strategy 3: Try extracting specific MetaWorld observation components
|
| 139 |
+
try:
|
| 140 |
+
metaworld_obs = self._extract_metaworld_obs(obs)
|
| 141 |
+
if metaworld_obs is not None:
|
| 142 |
+
action_numpy = self.policy.get_action(metaworld_obs)
|
| 143 |
+
print(f"✓ Used extracted MetaWorld observation for policy")
|
| 144 |
+
except Exception as e3:
|
| 145 |
+
print(f"✗ Failed with extracted observation: {e3}")
|
| 146 |
+
|
| 147 |
+
# Final fallback
|
| 148 |
+
if action_numpy is None:
|
| 149 |
+
print("⚠ Using zero action as fallback")
|
| 150 |
+
action_numpy = np.zeros(4, dtype=np.float32)
|
| 151 |
|
| 152 |
# Debug raw policy output (reduced frequency)
|
| 153 |
print(f"Raw policy action: {action_numpy}, type: {type(action_numpy)}")
|
|
|
|
| 195 |
"""
|
| 196 |
Helper method to process observations for the MetaWorld expert policy.
|
| 197 |
|
| 198 |
+
MetaWorld reach task policies typically expect observations with:
|
| 199 |
+
- End effector position (3 values)
|
| 200 |
+
- Target position (3 values)
|
| 201 |
+
- Joint positions and velocities (various dimensions)
|
| 202 |
+
- Total around 39 dimensions for Sawyer reach task
|
| 203 |
"""
|
| 204 |
if isinstance(obs, dict):
|
| 205 |
+
# MetaWorld-specific observation keys for reach task
|
| 206 |
+
metaworld_keys = [
|
| 207 |
+
"observation", # Standard observation
|
| 208 |
+
"obs", # Alternative observation key
|
| 209 |
+
"state", # State observation
|
| 210 |
+
"achieved_goal", # For goal-based tasks
|
| 211 |
+
"desired_goal", # Target position
|
| 212 |
]
|
| 213 |
+
|
| 214 |
processed_obs = None
|
| 215 |
+
for key in metaworld_keys:
|
| 216 |
if key in obs:
|
| 217 |
processed_obs = obs[key]
|
| 218 |
+
print(f"Using MetaWorld observation key: {key}")
|
| 219 |
break
|
| 220 |
+
|
| 221 |
+
# If we found a specific key, ensure it's the right format
|
| 222 |
+
if processed_obs is not None:
|
| 223 |
+
if isinstance(processed_obs, np.ndarray):
|
| 224 |
+
# Ensure it's flattened and has the right dtype
|
| 225 |
+
processed_obs = processed_obs.flatten().astype(np.float32)
|
| 226 |
+
else:
|
| 227 |
+
processed_obs = np.array(processed_obs, dtype=np.float32).flatten()
|
| 228 |
+
|
| 229 |
if processed_obs is None:
|
| 230 |
+
# Fallback: concatenate relevant observation components
|
| 231 |
+
print("No standard MetaWorld key found, concatenating observation components")
|
| 232 |
+
|
| 233 |
+
# Look for position and velocity information
|
| 234 |
+
components = []
|
| 235 |
for key, value in obs.items():
|
| 236 |
+
if isinstance(value, np.ndarray) and len(value.flatten()) > 0:
|
| 237 |
+
flat_value = value.flatten().astype(np.float32)
|
| 238 |
+
components.append(flat_value)
|
| 239 |
+
print(f"Adding component {key}: shape={flat_value.shape}")
|
| 240 |
+
|
| 241 |
+
if components:
|
| 242 |
+
processed_obs = np.concatenate(components)
|
| 243 |
print(f"Concatenated observation shape: {processed_obs.shape}")
|
| 244 |
else:
|
| 245 |
+
# Last resort: create zeros
|
| 246 |
+
processed_obs = np.zeros(39, dtype=np.float32)
|
| 247 |
+
print("No valid observation components found, using zeros")
|
| 248 |
else:
|
| 249 |
+
# If obs is already an array, ensure it's properly formatted
|
| 250 |
+
processed_obs = np.array(obs, dtype=np.float32).flatten()
|
| 251 |
+
|
| 252 |
+
# Ensure we have the expected dimension for MetaWorld reach (typically 39)
|
| 253 |
+
if len(processed_obs) != 39:
|
| 254 |
+
print(f"Observation dimension mismatch: got {len(processed_obs)}, expected 39")
|
| 255 |
+
if len(processed_obs) < 39:
|
| 256 |
+
# Pad with zeros
|
| 257 |
+
padding = np.zeros(39 - len(processed_obs), dtype=np.float32)
|
| 258 |
+
processed_obs = np.concatenate([processed_obs, padding])
|
| 259 |
+
print(f"Padded observation to 39 dimensions")
|
| 260 |
+
else:
|
| 261 |
+
# Truncate
|
| 262 |
+
processed_obs = processed_obs[:39]
|
| 263 |
+
print(f"Truncated observation to 39 dimensions")
|
| 264 |
|
| 265 |
+
return processed_obs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
|
| 267 |
+
def _extract_metaworld_obs(self, obs):
|
| 268 |
+
"""
|
| 269 |
+
Extract MetaWorld-specific observation components for the reach task.
|
| 270 |
+
|
| 271 |
+
MetaWorld reach observations typically include:
|
| 272 |
+
- Joint positions (7 values for Sawyer)
|
| 273 |
+
- Joint velocities (7 values)
|
| 274 |
+
- End effector position (3 values)
|
| 275 |
+
- Target position (3 values)
|
| 276 |
+
- Other task-specific info
|
| 277 |
+
"""
|
| 278 |
+
if not isinstance(obs, dict):
|
| 279 |
+
return None
|
| 280 |
+
|
| 281 |
+
components = []
|
| 282 |
+
|
| 283 |
+
# Try to find joint positions
|
| 284 |
+
if 'qpos' in obs:
|
| 285 |
+
joint_pos = np.array(obs['qpos'], dtype=np.float32).flatten()
|
| 286 |
+
components.append(joint_pos)
|
| 287 |
+
print(f"Found joint positions: {joint_pos.shape}")
|
| 288 |
+
|
| 289 |
+
# Try to find joint velocities
|
| 290 |
+
if 'qvel' in obs:
|
| 291 |
+
joint_vel = np.array(obs['qvel'], dtype=np.float32).flatten()
|
| 292 |
+
components.append(joint_vel)
|
| 293 |
+
print(f"Found joint velocities: {joint_vel.shape}")
|
| 294 |
+
|
| 295 |
+
# Try to find end effector position
|
| 296 |
+
if 'eef_pos' in obs or 'achieved_goal' in obs:
|
| 297 |
+
eef_key = 'eef_pos' if 'eef_pos' in obs else 'achieved_goal'
|
| 298 |
+
eef_pos = np.array(obs[eef_key], dtype=np.float32).flatten()
|
| 299 |
+
if len(eef_pos) >= 3:
|
| 300 |
+
components.append(eef_pos[:3]) # Take first 3 values (x, y, z)
|
| 301 |
+
print(f"Found end effector position: {eef_pos[:3]}")
|
| 302 |
+
|
| 303 |
+
# Try to find target/goal position
|
| 304 |
+
if 'target_pos' in obs or 'desired_goal' in obs:
|
| 305 |
+
target_key = 'target_pos' if 'target_pos' in obs else 'desired_goal'
|
| 306 |
+
target_pos = np.array(obs[target_key], dtype=np.float32).flatten()
|
| 307 |
+
if len(target_pos) >= 3:
|
| 308 |
+
components.append(target_pos[:3]) # Take first 3 values (x, y, z)
|
| 309 |
+
print(f"Found target position: {target_pos[:3]}")
|
| 310 |
+
|
| 311 |
+
# If we found components, concatenate them
|
| 312 |
+
if components:
|
| 313 |
+
metaworld_obs = np.concatenate(components)
|
| 314 |
+
print(f"Extracted MetaWorld observation: {metaworld_obs.shape} dimensions")
|
| 315 |
+
return metaworld_obs
|
| 316 |
+
|
| 317 |
+
return None
|
| 318 |
+
|
| 319 |
+
def _normalize_observation(self, obs):
|
| 320 |
+
"""
|
| 321 |
+
Normalize observation if needed for MetaWorld policy.
|
| 322 |
|
| 323 |
+
Some MetaWorld policies expect normalized observations.
|
| 324 |
+
"""
|
| 325 |
+
if not isinstance(obs, np.ndarray):
|
| 326 |
+
return obs
|
| 327 |
+
|
| 328 |
+
# Check if observation values are in a reasonable range
|
| 329 |
+
obs_min, obs_max = obs.min(), obs.max()
|
| 330 |
+
|
| 331 |
+
# If values are very large or very small, they might need normalization
|
| 332 |
+
if abs(obs_max) > 10 or abs(obs_min) > 10:
|
| 333 |
+
print(f"Observation values seem large (min={obs_min:.3f}, max={obs_max:.3f}), normalizing...")
|
| 334 |
+
# Normalize to roughly [-1, 1] range
|
| 335 |
+
obs_mean = obs.mean()
|
| 336 |
+
obs_std = obs.std()
|
| 337 |
+
if obs_std > 0:
|
| 338 |
+
normalized_obs = (obs - obs_mean) / obs_std
|
| 339 |
+
print(f"Normalized observation range: [{normalized_obs.min():.3f}, {normalized_obs.max():.3f}]")
|
| 340 |
+
return normalized_obs
|
| 341 |
+
|
| 342 |
+
return obs
|
| 343 |
|
| 344 |
def reset(self) -> None:
|
| 345 |
"""
|