from __future__ import annotations

import numpy as np
import torch
import torch.nn as nn
from gymnasium import spaces
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class LanduseFeaturesExtractor(BaseFeaturesExtractor):
    """Extract a joint embedding from stacked camera views and goal info.

    The camera stream (stacked camera frames concatenated channel-wise with
    the visited-mask) is processed by a small CNN and projected to
    ``view_emb_dim``; the goal-info stream goes through an MLP producing
    ``goal_emb_dim`` features. Both embeddings are concatenated into a single
    feature vector of size ``view_emb_dim + goal_emb_dim``.

    Args:
        observation_space: Dict space that provides at least the ``"camera"``
            and ``"goal_info"`` entries. The camera shape is read as
            ``(stacked_frames, height, width)``.
        view_emb_dim: Size of the projected CNN embedding.
        goal_emb_dim: Size of the goal-info embedding.
    """

    def __init__(
        self,
        observation_space: spaces.Dict,
        view_emb_dim: int = 128,
        goal_emb_dim: int = 32,
    ):
        output_dim = view_emb_dim + goal_emb_dim
        super().__init__(observation_space, features_dim=output_dim)

        cam_shape = observation_space.spaces["camera"].shape
        self.stacked_frames = cam_shape[0]
        self.camera_h, self.camera_w = cam_shape[1], cam_shape[2]

        # camera + visited_mask
        # NOTE(review): assumes "visited_mask" has the same channel count as
        # "camera"; forward() concatenates the two along dim 1 -- confirm
        # against the environment's observation space.
        total_input_channels = 2 * self.stacked_frames

        self.cnn = nn.Sequential(
            nn.Conv2d(total_input_channels, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Probe the CNN once with a dummy batch to learn the flattened size
        # instead of hand-computing it from the conv hyper-parameters.
        with torch.no_grad():
            dummy = torch.zeros(
                1, total_input_channels, self.camera_h, self.camera_w
            )
            raw_cnn_dim = self.cnn(dummy).shape[1]
        # Tag the message with the actual class name (it previously named a
        # different class than the one emitting it).
        print(f"[{type(self).__name__}] CNN output dim: {raw_cnn_dim}")

        self.cnn_projection = nn.Sequential(
            nn.Linear(raw_cnn_dim, view_emb_dim),
            nn.ReLU(),
            nn.LayerNorm(view_emb_dim),
        )

        goal_shape = observation_space.spaces["goal_info"].shape
        # np.prod returns a NumPy scalar (a float for an empty shape);
        # nn.Linear needs a plain Python int.
        total_goal_input_dim = int(np.prod(goal_shape))

        self.goal_fc = nn.Sequential(
            nn.Linear(total_goal_input_dim, goal_emb_dim),
            nn.ReLU(),
            nn.Linear(goal_emb_dim, goal_emb_dim),
            nn.ReLU(),
            nn.LayerNorm(goal_emb_dim),
        )

        self.apply(self._init_weights)

    @staticmethod
    def _init_weights(module: nn.Module):
        """Orthogonally initialise Linear/Conv2d weights and zero their biases.

        Uses the ReLU gain for the orthogonal init; every other module type
        is left untouched.
        """
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            nn.init.orthogonal_(module.weight, gain=nn.init.calculate_gain("relu"))
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, observations: dict[str, torch.Tensor]) -> torch.Tensor:
        """Compute the joint feature vector for a batch of observations.

        Args:
            observations: Mapping that provides the ``"camera"``,
                ``"visited_mask"`` and ``"goal_info"`` tensors, each with the
                batch on dim 0.

        Returns:
            The CNN embedding followed by the goal embedding, concatenated
            along dim 1.
        """
        camera = observations["camera"]
        mask = observations["visited_mask"]
        cnn_input = torch.cat([camera, mask], dim=1)

        raw_cnn = self.cnn(cnn_input)
        cnn_feat = self.cnn_projection(raw_cnn)

        # goal_fc was sized for the flattened goal vector, so collapse every
        # non-batch dimension first; this is a no-op for (batch, dim) inputs.
        goal_info = observations["goal_info"].flatten(start_dim=1)
        goal_emb = self.goal_fc(goal_info)

        return torch.cat([cnn_feat, goal_emb], dim=1)