"""Inference example — Phase I HASP agent (landuse grid, DQN).

Hugging Face repository: https://huggingface.co/bestak/uav-navigation-hasp

Two modes are supported:

Mode A — Standalone (no drone-navigation package needed)
    pip install stable-baselines3 huggingface-hub torch gymnasium

Mode B — Full package (runs the actual environment)
    pip install git+https://gitlab.ciirc.cvut.cz/bestavoj/information-driven-uav-navigation.git
    # also requires the landuse map data in data/ (run drone-prepare-data first)
"""

from __future__ import annotations

import importlib.util
import sys

import numpy as np

REPO_ID = "bestak/uav-navigation-hasp"


# ──────────────────────────────────────────────────────────────────────────────
# 1. Load the feature extractor from the HF repo (standalone — no package import)
# ──────────────────────────────────────────────────────────────────────────────


def _module_available(name: str) -> bool:
    """Return True if the top-level module *name* can really be imported.

    A placeholder created with ``types.ModuleType`` has ``__spec__`` set to
    None, which makes ``importlib.util.find_spec`` raise ``ValueError``. That
    case counts as "not available", so stubs registered by an earlier call are
    never mistaken for the real package.
    """
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        return False


def _load_feature_extractor(repo_id: str):
    """Return the ``LanduseFeaturesExtractor`` class the pickled model refers to.

    Mode B (drone-navigation installed): the package's own class is imported
    and returned, and ``sys.modules`` is left untouched.

    Mode A (package absent): ``feature_extractor.py`` is downloaded from the HF
    repo and executed as module ``_drone_nav_fe``. Placeholder
    ``drone_navigation.*`` modules are then registered, and the downloaded
    class is attached at the original pickle path so cloudpickle can resolve
    it.

    Args:
        repo_id: Hugging Face repository holding ``feature_extractor.py``.

    Returns:
        The ``LanduseFeaturesExtractor`` class.
    """
    if _module_available("drone_navigation"):
        # Registering stub modules here would shadow the real package: a plain
        # ModuleType has no __path__, so later imports such as
        # drone_navigation.config (see run_with_env) would fail with
        # "'drone_navigation' is not a package".
        from drone_navigation.models.feature_extractor_landuse import LanduseFeaturesExtractor

        return LanduseFeaturesExtractor

    import types

    from huggingface_hub import hf_hub_download

    fe_path = hf_hub_download(repo_id, "feature_extractor.py")
    spec = importlib.util.spec_from_file_location("_drone_nav_fe", fe_path)
    mod = importlib.util.module_from_spec(spec)
    sys.modules["_drone_nav_fe"] = mod  # register so cloudpickle can find it
    spec.loader.exec_module(mod)

    # Stub out drone_navigation so cloudpickle can deserialise ALL objects stored
    # in the zip (lr_schedule, policy_kwargs, etc.) without needing the package.
    for _name in [
        "drone_navigation",
        "drone_navigation.models",
        "drone_navigation.models.feature_extractor_aerial",
        "drone_navigation.models.feature_extractor_landuse",
    ]:
        sys.modules.setdefault(_name, types.ModuleType(_name))

    # Wire the loaded class to the original pickle path so cloudpickle resolves it.
    sys.modules["drone_navigation.models.feature_extractor_landuse"].LanduseFeaturesExtractor = (
        mod.LanduseFeaturesExtractor
    )
    return mod.LanduseFeaturesExtractor


# ──────────────────────────────────────────────────────────────────────────────
# 2. Load the SB3 model
# ──────────────────────────────────────────────────────────────────────────────


def load_model(repo_id: str = REPO_ID, device: str = "cpu"):
    """Download and load the DQN model from Hugging Face.

    The feature extractor class is obtained from ``_load_feature_extractor``.
    It comes from the installed drone-navigation package when present, and
    otherwise from the copy published in the HF repo. As a result, the
    drone-navigation package does NOT need to be installed.

    Args:
        repo_id: Hugging Face repository holding ``best_model.zip``.
        device: PyTorch device to load the policy onto (e.g. "cpu", "cuda").

    Returns:
        The loaded stable-baselines3 ``DQN`` model.
    """
    from huggingface_hub import hf_hub_download
    from stable_baselines3 import DQN

    FeaturesExtractor = _load_feature_extractor(repo_id)
    # NOTE(review): SB3 matches ``custom_objects`` against top-level saved keys;
    # if the class is stored only nested inside ``policy_kwargs``, it is the
    # sys.modules wiring in _load_feature_extractor that resolves it — confirm.
    model = DQN.load(
        hf_hub_download(repo_id, "best_model.zip"),
        custom_objects={"features_extractor_class": FeaturesExtractor},
        device=device,
    )
    print(f"Model loaded from {repo_id}")
    print(f" Policy : {type(model.policy).__name__}")
    return model


# ──────────────────────────────────────────────────────────────────────────────
# 3a. Run inference — standalone dummy observation (no environment needed)
# ──────────────────────────────────────────────────────────────────────────────


def predict_dummy(model, frame_stack: int = 4, camera_size: int = 11):
    """Run one deterministic forward pass on an all-zero observation.

    Observation layout built below (defaults frame_stack=4, camera_size=11):
        camera       : float32 (frame_stack, camera_size, camera_size) — landuse edge-map crop
        visited_mask : float32 (frame_stack, camera_size, camera_size) — cells visited this episode
        goal_info    : float32 (24,) — [dist_norm, sin_θ, cos_θ, step_progress, steps_left_norm, reached]

    NOTE(review): six goal features are listed, but the vector has 24 entries,
    and that size is fixed rather than derived from ``frame_stack``. The
    features are presumably stacked over 4 frames — confirm before calling
    with a non-default ``frame_stack``.

    Action space: Discrete(4) — 0=right, 1=down, 2=left, 3=up

    Args:
        model: Loaded SB3 model (see ``load_model``).
        frame_stack: Leading dimension of ``camera`` and ``visited_mask``.
        camera_size: Spatial size of the square ``camera``/``visited_mask`` crops.

    Returns:
        The action returned by ``model.predict``.
    """
    obs = {
        "camera": np.zeros((frame_stack, camera_size, camera_size), dtype=np.float32),
        "visited_mask": np.zeros((frame_stack, camera_size, camera_size), dtype=np.float32),
        "goal_info": np.zeros(24, dtype=np.float32),
    }
    action, _ = model.predict(obs, deterministic=True)
    directions = {0: "right", 1: "down", 2: "left", 3: "up"}
    print(f"Dummy action: {action} ({directions.get(int(action), '?')})")
    return action


# ──────────────────────────────────────────────────────────────────────────────
# 3b. Run inference — full environment (requires drone-navigation package + data)
# ──────────────────────────────────────────────────────────────────────────────


def run_with_env(model, n_episodes: int = 3, repo_id: str = REPO_ID):
    """Run the model against the actual landuse environment.

    The experiment config (``config.json``) is downloaded from *repo_id*. The
    environment is built with a single env and no rendering, and is always
    closed, even if an episode raises.

    Requirements
    ------------
    pip install git+https://gitlab.ciirc.cvut.cz/bestavoj/information-driven-uav-navigation.git
    # Landuse map (CZ_10m_3035_tiled.tif) must be present under data/
    # Run: drone-prepare-data --only landuse

    Args:
        model: Loaded SB3 model (see ``load_model``).
        n_episodes: Number of episodes to roll out.
        repo_id: Hugging Face repository holding ``config.json``. This should
            match the repository the model was loaded from.
    """
    from huggingface_hub import hf_hub_download

    from drone_navigation.config.experiment_config import ExperimentConfig
    from drone_navigation.envs.factory import create_env

    cfg = ExperimentConfig.from_json(hf_hub_download(repo_id, "config.json"))
    cfg.n_envs = 1
    cfg.render_mode = None
    env = create_env(cfg)
    try:
        print(f"Environment : {type(env).__name__}")
        print(f"Obs space : {env.observation_space}")
        print(f"Action space: {env.action_space}")

        for ep in range(n_episodes):
            obs, _ = env.reset()
            total_reward = 0.0
            done = False
            steps = 0
            # Gymnasium-style 5-tuple step API, as unpacked below.
            while not done:
                action, _ = model.predict(obs, deterministic=True)
                obs, reward, terminated, truncated, info = env.step(action.item())
                total_reward += reward
                steps += 1
                done = terminated or truncated
            print(
                f" Episode {ep + 1:2d}: steps={steps:4d}, "
                f"reward={total_reward:.1f}, "
                f"is_target_reached={info.get('is_target_reached', False)}"
            )
    finally:
        # Release environment resources even when an episode fails.
        env.close()


# ──────────────────────────────────────────────────────────────────────────────
# Entry point
# ──────────────────────────────────────────────────────────────────────────────

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Inference — Phase I HASP agent (landuse grid, DQN)")
    parser.add_argument(
        "--with-env",
        action="store_true",
        default=False,
        help="Run against the full drone-navigation environment instead of a dummy observation.",
    )
    parser.add_argument("--n-episodes", type=int, default=3, help="Number of episodes (only used with --with-env).")
    parser.add_argument("--device", default="cpu", help="PyTorch device (cpu / cuda).")
    args = parser.parse_args()

    model = load_model(device=args.device)

    if args.with_env:
        run_with_env(model, n_episodes=args.n_episodes)
    else:
        predict_dummy(model)