#!/usr/bin/env python3 """Run the Embedl-quantized SAM-3D-Body backbone from the `.pt2` graph. Loads `embedl_sam3dbody_int8.pt2` with `torch.export.load`, runs it on a person crop, reports the feature-map statistics + latency, and saves a PCA visualization of the 16x16 patch features (the classic DINOv3 "what the backbone sees" image). pip install torch pillow numpy python infer_pt2.py --image sample_input.png --save-pca features_pca.png """ from __future__ import annotations import argparse import time from pathlib import Path import numpy as np import torch from PIL import Image HERE = Path(__file__).resolve().parent PT2 = HERE / "embedl_sam3dbody_int8.pt2" INPUT_SIZE = 512 # DINOv3 patch_size 16 -> 16x16 tokens IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32) IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32) def preprocess(path: Path) -> torch.Tensor: """Person crop -> (1, 3, 256, 256) ImageNet-normalized float32 tensor.""" img = Image.open(path).convert("RGB").resize((INPUT_SIZE, INPUT_SIZE), Image.BILINEAR) arr = np.asarray(img, dtype=np.float32) / 255.0 arr = (arr - IMAGENET_MEAN) / IMAGENET_STD return torch.from_numpy(arr.transpose(2, 0, 1)[None]) # (1, 3, H, W) def feature_pca_rgb(feats: np.ndarray, out_hw: int = 256) -> Image.Image: """(C, H, W) feature map -> RGB image via PCA of the patch descriptors.""" c, h, w = feats.shape x = feats.reshape(c, h * w).T # (H*W, C) x = x - x.mean(0, keepdims=True) # top-3 principal components via SVD _, _, vt = np.linalg.svd(x, full_matrices=False) comps = x @ vt[:3].T # (H*W, 3) lo, hi = np.percentile(comps, 2, axis=0), np.percentile(comps, 98, axis=0) comps = np.clip((comps - lo) / (hi - lo + 1e-8), 0, 1) rgb = (comps.reshape(h, w, 3) * 255).astype(np.uint8) return Image.fromarray(rgb).resize((out_hw, out_hw), Image.NEAREST) def main() -> None: ap = argparse.ArgumentParser(description=__doc__) ap.add_argument("--image", default=str(HERE / "sample_input.png")) ap.add_argument("--pt2", default=str(PT2)) ap.add_argument("--save-pca", default=None, help="path to write the PCA feature image") ap.add_argument("--iters", type=int, default=20) args = ap.parse_args() device = "cuda" if torch.cuda.is_available() else "cpu" print(f"device: {device}") print(f"loading {Path(args.pt2).name} (this can take 10-30 s) ...") # A loaded ExportedProgram's module is already inference-captured; some torch # versions raise NotImplementedError on .eval(), so don't call it. module = torch.export.load(args.pt2).module().to(device) x = preprocess(Path(args.image)).to(device) print(f"input: {tuple(x.shape)} {x.dtype} range [{x.min():.3f}, {x.max():.3f}]") with torch.no_grad(): feats = module(x) feats = feats.float().cpu() print(f"features: {tuple(feats.shape)} mean={feats.mean():.4f} std={feats.std():.4f}") # latency times = [] with torch.no_grad(): for i in range(args.iters): if device == "cuda": torch.cuda.synchronize() t0 = time.perf_counter() module(x) if device == "cuda": torch.cuda.synchronize() times.append((time.perf_counter() - t0) * 1000) if len(times) > 1: s = sorted(times[1:]) print(f"latency: median {s[len(s) // 2]:.1f} ms over {len(times) - 1} runs (excl. warmup)") if args.save_pca: img = feature_pca_rgb(feats[0].numpy()) img.save(args.save_pca) print(f"wrote PCA feature visualization -> {args.save_pca}") if __name__ == "__main__": main()