"""Gradio web demo for satellite change detection. Upload before/after satellite image pairs, select a model, and view the predicted change mask, overlay, and change-area statistics. Auto-detects available checkpoints — no manual path entry needed. Usage: python app.py """ import logging import os from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import gradio as gr import numpy as np import torch import yaml from huggingface_hub import hf_hub_download from data.dataset import IMAGENET_MEAN, IMAGENET_STD from inference import sliding_window_inference from models import get_model from utils.visualization import overlay_changes logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Globals # --------------------------------------------------------------------------- _cached_model: Optional[torch.nn.Module] = None _cached_model_key: Optional[str] = None _device = torch.device("cuda" if torch.cuda.is_available() else "cpu") _config: Optional[Dict[str, Any]] = None _hf_model_repo_id: Optional[str] = os.getenv("HF_MODEL_REPO") _hf_model_revision: Optional[str] = os.getenv("HF_MODEL_REVISION") # Search these directories for checkpoint files _CHECKPOINT_SEARCH_DIRS = [ Path("checkpoints"), Path("/home/user/app/checkpoints"), Path("/kaggle/working/checkpoints"), Path("/content/drive/MyDrive/change-detection/checkpoints"), ] # Map model names to expected checkpoint filenames _MODEL_CHECKPOINT_NAMES = { "siamese_cnn": "siamese_cnn_best.pth", "unet_pp": "unet_pp_best.pth", "changeformer": "changeformer_best.pth", } def _download_checkpoint_from_hf(model_name: str) -> Optional[Path]: """Download checkpoint from Hugging Face Hub if configured. Uses env var ``HF_MODEL_REPO`` as the source model repository and downloads to ``./checkpoints`` cache. Args: model_name: One of the supported model keys. Returns: Local path to downloaded checkpoint, or ``None`` if unavailable. """ if not _hf_model_repo_id: return None filename = _MODEL_CHECKPOINT_NAMES.get(model_name) if filename is None: return None try: local_path = hf_hub_download( repo_id=_hf_model_repo_id, filename=filename, revision=_hf_model_revision, local_dir="checkpoints", local_dir_use_symlinks=False, ) logger.info("Downloaded %s from %s", filename, _hf_model_repo_id) return Path(local_path) except Exception as exc: # pragma: no cover - best-effort fallback logger.warning("Could not download %s from HF Hub: %s", filename, exc) return None # --------------------------------------------------------------------------- # Config / model loading # --------------------------------------------------------------------------- def _load_config() -> Dict[str, Any]: """Load and cache the project config. Returns: Full config dict. """ global _config if _config is None: config_path = Path("configs/config.yaml") with open(config_path, "r") as fh: _config = yaml.safe_load(fh) return _config def _find_checkpoint(model_name: str) -> Optional[Path]: """Auto-detect the checkpoint file for a given model. Searches multiple directories for the expected checkpoint filename. Args: model_name: One of ``siamese_cnn``, ``unet_pp``, ``changeformer``. Returns: Path to the checkpoint if found, ``None`` otherwise. """ filename = _MODEL_CHECKPOINT_NAMES.get(model_name) if filename is None: return None for search_dir in _CHECKPOINT_SEARCH_DIRS: candidate = search_dir / filename if candidate.exists(): return candidate downloaded = _download_checkpoint_from_hf(model_name) if downloaded is not None and downloaded.exists(): return downloaded return None def _get_available_models() -> List[str]: """Return a list of model names that have checkpoints available. Returns: List of model name strings with detected checkpoints. """ available = [] for model_name in _MODEL_CHECKPOINT_NAMES: if _find_checkpoint(model_name) is not None: available.append(model_name) return available def _load_model(model_name: str) -> torch.nn.Module: """Load a model using auto-detected checkpoint. Args: model_name: Architecture name. Returns: Model in eval mode on the current device. Raises: FileNotFoundError: If no checkpoint is found. """ global _cached_model, _cached_model_key if _cached_model is not None and _cached_model_key == model_name: return _cached_model ckpt_path = _find_checkpoint(model_name) if ckpt_path is None: raise FileNotFoundError( f"No checkpoint found for '{model_name}'. " f"Expected '{_MODEL_CHECKPOINT_NAMES[model_name]}' in one of: " f"{[str(d) for d in _CHECKPOINT_SEARCH_DIRS]}" ) config = _load_config() model = get_model(model_name, config).to(_device) ckpt = torch.load(ckpt_path, map_location=_device) model.load_state_dict(ckpt["model_state_dict"]) model.eval() _cached_model = model _cached_model_key = model_name logger.info("Loaded %s from %s", model_name, ckpt_path) return model # --------------------------------------------------------------------------- # Preprocessing # --------------------------------------------------------------------------- def _numpy_to_tensor( img: np.ndarray, patch_size: int = 256, ) -> Tuple[torch.Tensor, Tuple[int, int]]: """Convert a uint8 RGB numpy image to a normalised, padded tensor. Args: img: Input image ``[H, W, 3]``, uint8, RGB. patch_size: Pad to a multiple of this value. Returns: Tuple of ``(tensor [1, 3, H_pad, W_pad], (orig_h, orig_w))``. """ orig_h, orig_w = img.shape[:2] pad_h = (patch_size - orig_h % patch_size) % patch_size pad_w = (patch_size - orig_w % patch_size) % patch_size if pad_h > 0 or pad_w > 0: img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), mode="reflect") img_f = img.astype(np.float32) / 255.0 mean = np.array(IMAGENET_MEAN, dtype=np.float32) std = np.array(IMAGENET_STD, dtype=np.float32) img_f = (img_f - mean) / std tensor = torch.from_numpy(img_f).permute(2, 0, 1).unsqueeze(0).float() return tensor, (orig_h, orig_w) # --------------------------------------------------------------------------- # Prediction # --------------------------------------------------------------------------- def predict( before_image: Optional[np.ndarray], after_image: Optional[np.ndarray], model_name: str, threshold: float, ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], str]: """Run change detection and return visualisations + summary text. Args: before_image: Before image as numpy ``[H, W, 3]`` RGB uint8. after_image: After image as numpy ``[H, W, 3]`` RGB uint8. model_name: Architecture name. threshold: Binarisation threshold for predictions. Returns: Tuple of ``(change_mask, overlay_image, summary_text)``. """ if before_image is None or after_image is None: return None, None, "Please upload both **before** and **after** images." config = _load_config() patch_size: int = config.get("dataset", {}).get("patch_size", 256) # Load model (auto-detects checkpoint) try: model = _load_model(model_name) except FileNotFoundError as exc: return None, None, f"**Error:** {exc}" # Preprocess tensor_a, (orig_h, orig_w) = _numpy_to_tensor(before_image, patch_size) tensor_b, _ = _numpy_to_tensor(after_image, patch_size) # Tiled inference prob_map = sliding_window_inference(model, tensor_a, tensor_b, patch_size, _device) prob_map = prob_map[:, :, :orig_h, :orig_w] prob_np = prob_map.squeeze().numpy() # Binary change mask binary_mask = (prob_np > threshold).astype(np.uint8) * 255 # Overlay on after image pred_tensor = (prob_map.squeeze(0) >= threshold).float() img_b_tensor = tensor_b.squeeze()[:, :orig_h, :orig_w] overlay_rgb = overlay_changes( img_after=img_b_tensor, mask_pred=pred_tensor, alpha=0.4, color=(255, 0, 0), ) # Change statistics total_pixels = orig_h * orig_w changed_pixels = int(binary_mask.sum() // 255) pct_changed = (changed_pixels / total_pixels) * 100.0 ckpt_path = _find_checkpoint(model_name) summary = ( f"### Change Detection Results\n\n" f"| Metric | Value |\n" f"|---|---|\n" f"| **Model** | {model_name} |\n" f"| **Image size** | {orig_w} x {orig_h} |\n" f"| **Total pixels** | {total_pixels:,} |\n" f"| **Changed pixels** | {changed_pixels:,} |\n" f"| **Area changed** | {pct_changed:.2f}% |\n" f"| **Threshold** | {threshold} |\n" f"| **Checkpoint** | {ckpt_path.name if ckpt_path else 'N/A'} |\n" f"| **Device** | {_device} |" ) return binary_mask, overlay_rgb, summary # --------------------------------------------------------------------------- # Gradio UI # --------------------------------------------------------------------------- def build_demo() -> gr.Blocks: """Construct the Gradio Blocks interface. Returns: A ``gr.Blocks`` application ready to ``.launch()``. """ available = _get_available_models() all_models = list(_MODEL_CHECKPOINT_NAMES.keys()) # Show which models are available status_lines = [] for m in all_models: ckpt = _find_checkpoint(m) if ckpt: status_lines.append(f"- **{m}**: {ckpt.name}") else: status_lines.append(f"- **{m}**: not found") model_status = "\n".join(status_lines) default_model = available[0] if available else "changeformer" with gr.Blocks(title="Military Base Change Detection") as demo: gr.Markdown( "# Military Base Change Detection\n" "Upload **before** and **after** satellite images to detect " "construction, infrastructure changes, and runway development.\n\n" "**Available models:**\n" + model_status ) # ---- Inputs --------------------------------------------------- with gr.Row(): with gr.Column(scale=1): before_img = gr.Image( label="Before Image (older)", type="numpy", sources=["upload", "clipboard"], ) with gr.Column(scale=1): after_img = gr.Image( label="After Image (newer)", type="numpy", sources=["upload", "clipboard"], ) # ---- Controls ------------------------------------------------- with gr.Row(): model_dropdown = gr.Dropdown( choices=available if available else all_models, value=default_model, label="Model Architecture", ) threshold_slider = gr.Slider( minimum=0.1, maximum=0.9, value=0.5, step=0.05, label="Detection Threshold", ) detect_btn = gr.Button("Detect Changes", variant="primary", size="lg") # ---- Outputs -------------------------------------------------- with gr.Row(): with gr.Column(scale=1): change_mask_out = gr.Image(label="Change Mask") with gr.Column(scale=1): overlay_out = gr.Image(label="Overlay (changes in red)") summary_out = gr.Markdown(label="Summary") # ---- Wiring --------------------------------------------------- detect_btn.click( fn=predict, inputs=[before_img, after_img, model_dropdown, threshold_slider], outputs=[change_mask_out, overlay_out, summary_out], ) return demo # --------------------------------------------------------------------------- # Entry point # --------------------------------------------------------------------------- def main() -> None: """Launch the Gradio demo server.""" logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S", ) config = _load_config() gradio_cfg = config.get("gradio", {}) demo = build_demo() in_hf_space = os.getenv("SPACE_ID") is not None demo.launch( server_name="0.0.0.0" if in_hf_space else "127.0.0.1", server_port=gradio_cfg.get("server_port", 7860), share=False if in_hf_space else gradio_cfg.get("share", False), ) if __name__ == "__main__": main()