# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 """Asset Harvester Gradio demo — single-image upload to 3D Gaussian splat.""" from __future__ import annotations import gc import logging import os import random import tempfile import threading import uuid from functools import partial import gradio as gr import imageio import numpy as np import torch import torchvision.transforms as T from diffusers.schedulers import DPMSolverMultistepScheduler from huggingface_hub import snapshot_download from PIL import Image class _SpacesStub: @staticmethod def GPU(*args, **kwargs): def decorator(fn): return fn if args and callable(args[0]): return args[0] return decorator try: import spaces _HAS_SPACES = True except ImportError: _HAS_SPACES = False spaces = _SpacesStub() # type: ignore[assignment] if os.getenv("SPACE_ID") is None: _HAS_SPACES = False spaces = _SpacesStub() # type: ignore[assignment] logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) HF_CHECKPOINT_REPO = "nvidia/asset-harvester" CHECKPOINTS_DIR = "/app/checkpoints" if os.path.isdir("/app") else os.path.join(os.getcwd(), "checkpoints") MV_CKPT = "AH_multiview_diffusion.safetensors" TOKENGS_CKPT = "AH_tokengs_lifting.safetensors" AHC_CKPT = "AH_camera_estimator.safetensors" SEG_CKPT = "AH_object_seg_jit.pt" DEFAULT_NUM_STEPS = 30 DEFAULT_CFG_SCALE = 2.0 IMAGE_SIZE = 512 GRAY_VALUE = 128 SEG_INPUT_SIZE = (384, 384) MIN_MASK_AREA_FRAC = 0.01 MAX_MASK_AREA_FRAC = 0.95 MIN_UPLOAD_SIDE = 256 _MODELS_LOCK = threading.Lock() _MODELS: dict = {} _CKPT_PATHS: dict[str, str] = {} _SESSION_MVDATA: dict[str, object] = {} def _load_seg_estimator_class(): """Load Mask2FormerSegmentationEstimator directly from its source file, bypassing `asset_harvester.ncore_parser.__init__` which pulls in the private `ncore` module.""" import importlib.util import asset_harvester pkg_root = os.path.dirname(asset_harvester.__file__) source = os.path.join(pkg_root, "ncore_parser", "image_segmentation.py") spec = importlib.util.spec_from_file_location("_ah_image_segmentation", source) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) return module.Mask2FormerSegmentationEstimator def _download_checkpoints() -> None: if _CKPT_PATHS: return hf_token = os.getenv("HF_TOKEN") local_dir = snapshot_download( repo_id=HF_CHECKPOINT_REPO, allow_patterns=[MV_CKPT, TOKENGS_CKPT, AHC_CKPT, SEG_CKPT], local_dir=CHECKPOINTS_DIR, token=hf_token, ) for key, filename in (("mv", MV_CKPT), ("tokengs", TOKENGS_CKPT), ("ahc", AHC_CKPT), ("seg", SEG_CKPT)): path = os.path.join(local_dir, filename) if not os.path.isfile(path): raise FileNotFoundError(f"Missing {filename} in {local_dir}") _CKPT_PATHS[key] = path logger.info("Checkpoints ready in %s", local_dir) def _load_models(device: str) -> dict: with _MODELS_LOCK: if _MODELS: return _MODELS from asset_harvester.camera_estimator.inference import AHCEstimator from asset_harvester.multiview_diffusion.pipelines import SparseViewDiTPipeline from asset_harvester.multiview_diffusion.utils.model_builder import get_models from asset_harvester.tokengs.lifting_inference import TokengsLiftingRunner Mask2FormerSegmentationEstimator = _load_seg_estimator_class() _download_checkpoints() dtype = torch.bfloat16 if device.startswith("cuda") else torch.float32 logger.info("Loading MVD (+ VAE, c-radio)...") vae, cradio_model, cradio_image_processor, transformer = get_models( _CKPT_PATHS["mv"], device=device, dtype=dtype, ) scheduler = DPMSolverMultistepScheduler( num_train_timesteps=1000, beta_schedule="scaled_linear", prediction_type="flow_prediction", flow_shift=1.0, use_flow_sigmas=True, ) pipeline = SparseViewDiTPipeline( vae=vae, text_encoder=None, tokenizer=None, scheduler=scheduler, transformer=transformer, image_encoder=cradio_model, image_processor=cradio_image_processor, ).to(dtype) logger.info("Loading AHC (shared c-radio)...") ahc = AHCEstimator( checkpoint_path=_CKPT_PATHS["ahc"], device=device, cradio_model=cradio_model, cradio_image_processor=cradio_image_processor, ) logger.info("Loading segmentation JIT...") seg = Mask2FormerSegmentationEstimator( model_path=_CKPT_PATHS["seg"], device=device, input_size=SEG_INPUT_SIZE, ) logger.info("Loading TokenGS lifting...") lifter = TokengsLiftingRunner( _CKPT_PATHS["tokengs"], bbox_size=0.8, dtype=dtype, render_img_size=IMAGE_SIZE, ) _MODELS.update(pipeline=pipeline, ahc=ahc, seg=seg, lifter=lifter, dtype=dtype, device=device) return _MODELS def _segment(seg, image_pil: Image.Image) -> np.ndarray: """Return a uint8 binary mask at the native image resolution.""" _, instance_seg = seg.predict(image_pil) if len(instance_seg["classes"]) == 0: return np.zeros((image_pil.height, image_pil.width), dtype=np.uint8) mh, mw = SEG_INPUT_SIZE unpacked = np.unpackbits(instance_seg["instance_masks"]).reshape( len(instance_seg["classes"]), mh, mw, ) areas = unpacked.sum(axis=(1, 2)) biggest = unpacked[int(np.argmax(areas))].astype(np.uint8) * 255 mask_pil = Image.fromarray(biggest, mode="L").resize( (image_pil.width, image_pil.height), Image.NEAREST, ) return np.array(mask_pil) def _recenter_and_pad(image_pil: Image.Image, mask_np: np.ndarray) -> tuple[Image.Image, Image.Image]: """Translate image+mask so the mask centroid lands at frame center, square-pad, resize to 512. Image padding uses GRAY_VALUE (matches AHC's apply_mask background). Mask padding uses 0. Raises ValueError on degenerate masks. """ H, W = mask_np.shape ys, xs = np.where(mask_np > 0) if ys.size == 0: raise ValueError("No object detected in the input image. Try a cleaner photo with a single subject.") area_frac = ys.size / (H * W) if area_frac < MIN_MASK_AREA_FRAC: raise ValueError(f"Detected object is too small ({area_frac * 100:.1f}% of image).") if area_frac > MAX_MASK_AREA_FRAC: raise ValueError( f"Detected object fills nearly the whole image ({area_frac * 100:.1f}%); provide a wider-angle photo." ) y0, y1 = int(ys.min()), int(ys.max()) x0, x1 = int(xs.min()), int(xs.max()) if y0 == 0 and y1 == H - 1 and x0 == 0 and x1 == W - 1: raise ValueError("Object touches all four edges; provide an image showing the full object.") cy = float(ys.mean()) cx = float(xs.mean()) side_y = int(np.ceil(2 * max(cy, H - cy))) side_x = int(np.ceil(2 * max(cx, W - cx))) side = max(side_y, side_x, H, W) paste_y = side // 2 - int(round(cy)) paste_x = side // 2 - int(round(cx)) canvas_img = np.full((side, side, 3), GRAY_VALUE, dtype=np.uint8) canvas_msk = np.zeros((side, side), dtype=np.uint8) img_np = np.array(image_pil.convert("RGB")) canvas_img[paste_y : paste_y + H, paste_x : paste_x + W] = img_np canvas_msk[paste_y : paste_y + H, paste_x : paste_x + W] = mask_np out_img = Image.fromarray(canvas_img).resize((IMAGE_SIZE, IMAGE_SIZE), Image.BILINEAR) out_msk = Image.fromarray(canvas_msk).resize((IMAGE_SIZE, IMAGE_SIZE), Image.NEAREST) return out_img, out_msk def _run_image_guard(image_pil: Image.Image, device: str, dtype: torch.dtype) -> None: from asset_harvester.utils.image_guard import ImageGuard guard = ImageGuard(device=device, dtype=dtype) try: guard.load() result = guard.check_image(image_pil) finally: guard.unload() gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() if not result.passed: raise gr.Error(f"Image rejected by safety check (label={result.label}, score={result.score:.2f}).") def _build_mvdata(image_pil: Image.Image, mask_pil: Image.Image, ahc): from asset_harvester.multiview_diffusion.data.nre_preproc import MVData tmp = tempfile.mkdtemp(prefix="ah_upload_") frame_p = os.path.join(tmp, "frame_0.jpg") mask_p = os.path.join(tmp, "mask_0.png") image_pil.save(frame_p, quality=95) mask_pil.save(mask_p) cam_data = ahc.run([(frame_p, mask_p)]) return MVData( clip_id="upload", obj_id="0", frames=[np.array(image_pil)], cam_poses=np.array(cam_data["cam_poses"], dtype=np.float32), dists=np.array(cam_data["dists"], dtype=np.float32), fov=np.array(cam_data["fov"], dtype=np.float32), npct="vehicle", lwh=np.array(cam_data["lwh"], dtype=np.float32), masks=[np.array(mask_pil)], auto_label=None, ) def _encode_mp4(frames_np, path: str, fps: int = 24) -> None: imageio.v2.mimwrite(path, frames_np, fps=fps, macro_block_size=1) @spaces.GPU(duration=60) def run_segmentation(image_pil, is_example: bool = False, progress=gr.Progress()): """First stage: safety check + segmentation + recentering + camera estimation. Returns (mask_preview, state) where state is handed to `run_3d`. Progress shown only on the mask image output. """ if image_pil is None: raise gr.Error("Please upload an image.") if min(image_pil.size) < MIN_UPLOAD_SIDE: raise gr.Error(f"Image too small ({image_pil.size[0]}x{image_pil.size[1]}); min {MIN_UPLOAD_SIDE}px per side.") device = "cuda" if torch.cuda.is_available() else "cpu" progress(0.1, desc="Loading models…") models = _load_models(device) dtype = models["dtype"] image_pil = image_pil.convert("RGB") if is_example: progress(0.3, desc="Skipping safety check (curated example)…") else: progress(0.3, desc="Running safety check…") _run_image_guard(image_pil, device, dtype) progress(0.6, desc="Segmenting object…") mask_np = _segment(models["seg"], image_pil) progress(0.8, desc="Recentering and estimating camera…") try: centered_img, centered_mask = _recenter_and_pad(image_pil, mask_np) except ValueError as e: raise gr.Error(str(e)) rgb = np.array(image_pil) fg = (mask_np > 0).astype(np.uint8)[:, :, None] mask_preview = Image.fromarray(np.where(fg, rgb, np.full_like(rgb, GRAY_VALUE)).astype(np.uint8)) mvdata = _build_mvdata(centered_img, centered_mask, models["ahc"]) uid = str(uuid.uuid4()) _SESSION_MVDATA[uid] = mvdata progress(1.0, desc="Done") return mask_preview, uid @spaces.GPU(duration=180) def run_3d(state, progress=gr.Progress()): """Second stage: multiview diffusion + TokenGS lifting. Returns (orbit_mp4_path, ply_path) matching outputs=[video_out, ply_out]. """ if not state or state not in _SESSION_MVDATA: raise gr.Error("Segmentation must run first.") device = "cuda" if torch.cuda.is_available() else "cpu" models = _load_models(device) pipeline = models["pipeline"] lifter = models["lifter"] mvdata = _SESSION_MVDATA.pop(state) from asset_harvester.multiview_diffusion.data.inference_utils import build_eval_cams from asset_harvester.multiview_diffusion.data.nre_preproc import preproc progress(0.05, desc="Preparing multiview conditioning…") transform = T.Compose( [T.Resize(IMAGE_SIZE), T.ToTensor(), T.Normalize([0.5], [0.5])] ) inference_preproc = partial( preproc, image_transform=transform, resolution=IMAGE_SIZE, conditioning_mode="n", eval_mode=True, eval_cam_sampler=build_eval_cams, ) data_dict = inference_preproc(mvdata) max_length = data_dict.n_target + min(4, len(data_dict.x) - data_dict.n_target) for attr in ("x", "c2w_relatives", "x_white_background", "dists", "fovs", "plucker_image", "relative_brightness"): if hasattr(data_dict, attr): setattr(data_dict, attr, getattr(data_dict, attr)[:max_length]) if hasattr(data_dict, "intrinsics") and data_dict.intrinsics.shape[0] > max_length: data_dict.intrinsics = data_dict.intrinsics[:max_length] progress(0.15, desc="Generating multiview images…") with torch.no_grad(): output = pipeline( data_dict=data_dict, num_inference_steps=DEFAULT_NUM_STEPS, guidance_scale=DEFAULT_CFG_SCALE, flow_shift=1.0, output_type="pil", ) images_np = [np.array(img) for img in output["images"]] progress(0.55, desc="Lifting to 3D Gaussian splat…") output_dir = tempfile.mkdtemp(prefix="ah_out_") offload_ok = False try: if torch.cuda.is_available(): for name in ("vae", "transformer", "image_encoder"): m = getattr(pipeline, name, None) if m is not None: m.to("cpu") pipeline.to("cpu") offload_ok = True gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() fov = float(data_dict.fovs[0].item()) dist = float(data_dict.dists[0].item()) lwh = data_dict.lwh if hasattr(data_dict, "lwh") and data_dict.lwh is not None else [1.0, 1.0, 1.0] with torch.no_grad(): gaussians = lifter.run_lifting(images_np, fov, dist, lwh) progress(0.85, desc="Rendering orbit views of the lifted splat…") with torch.no_grad(): rendered = lifter.render_orbit_views(gaussians, fov, dist, lwh) rendered_np = [im.permute(1, 2, 0).cpu().numpy() for im in rendered] orbit_mp4 = os.path.join(output_dir, "lifting.mp4") _encode_mp4(rendered_np, orbit_mp4) progress(0.95, desc="Saving Gaussian splat…") ply_path = os.path.join(output_dir, "gaussians.ply") lifter.save_ply(gaussians, ply_path) finally: if offload_ok and torch.cuda.is_available(): for name in ("vae", "transformer", "image_encoder"): m = getattr(pipeline, name, None) if m is not None: m.to(device) pipeline.to(device) gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() progress(1.0, desc="Done") return orbit_mp4, ply_path HEADER_MD = """ ## Image to 3D Asset with [Asset Harvester](https://github.com/NVIDIA/asset-harvester) [**Paper**](https://arxiv.org/abs/2604.18468) | [**Project Page**](https://research.nvidia.com/labs/sil/projects/asset-harvester/) | [**Code**](https://github.com/NVIDIA/asset-harvester) | [**Model**](https://huggingface.co/nvidia/asset-harvester) | [**Data**](https://huggingface.co/datasets/nvidia/PhysicalAI-Autonomous-Vehicles-NCore) **Upload a single image of one object — a vehicle, pedestrian, cyclist, or other road object — to generate a 3D Gaussian splat asset. The assumed inputs are images cropped and rectified from AV datasets, like the example images below. However, you can also challenge the model with internet photos.** The inference pipeline consists of: - **Object Segmentation** — isolates the object from the background. - **Camera Estimation** — predicts the viewing direction, distance, field of view, and object dimensions. - **Multiview Diffusion** — generates 16 novel orbit views. - **3D Lifting** — reconstructs the generated views into a 3D Gaussian splat (downloadable PLY). """ def build_ui(): theme = gr.themes.Default(primary_hue="green", neutral_hue="slate") app_css = """ /* Base typography */ .gradio-container { font-size: 20px !important; } .gradio-container .prose p, .gradio-container .prose li, .gradio-container .md p, .gradio-container .md li { font-size: 1.2rem !important; line-height: 1.6 !important; } .gradio-container .prose h2, .gradio-container .md h2 { font-size: 2rem !important; } .gradio-container .block-label, .gradio-container button { font-size: 1.1rem !important; } /* Fluid media — images/videos fill their column, keep aspect ratio */ .gradio-container .image-container img, .gradio-container .video-container video { max-width: 100% !important; max-height: 100% !important; width: auto !important; height: auto !important; object-fit: contain !important; } .gradio-container .image-container, .gradio-container .video-container { display: flex !important; align-items: center !important; justify-content: center !important; } /* Narrow viewports: let columns wrap instead of cramming */ @media (max-width: 1024px) { .gradio-container .prose h2, .gradio-container .md h2 { font-size: 1.7rem !important; } .gradio-container .prose p, .gradio-container .prose li, .gradio-container .md p, .gradio-container .md li { font-size: 1.1rem !important; } } @media (max-width: 720px) { .gradio-container { font-size: 18px !important; } /* Force columns in the main Row to take full width, stack vertically */ .gradio-container .grid-wrap { grid-template-columns: 1fr !important; } } """ with gr.Blocks(title="Asset Harvester", css=app_css) as demo: gr.Markdown(HEADER_MD) image_in = gr.Image( label="Image Prompt", type="pil", height=360, sources=["upload", "clipboard"], render=False, ) examples_dir = os.path.join(os.path.dirname(__file__), "examples") all_examples = [ [os.path.join(examples_dir, f)] for f in sorted(os.listdir(examples_dir)) if f.lower().endswith((".jpeg", ".jpg", ".png")) ] with gr.Row(): with gr.Column(scale=2, min_width=200): examples_ds = gr.Dataset( components=[image_in], samples=all_examples, samples_per_page=18, label="Example images", ) with gr.Column(scale=4, min_width=360): image_in.render() gr.Markdown( "**Notes:**\n\n" "* **For best results, please upload clear, object-centric images " "where the camera is level with the object, similar to rectified " "ego-viewpoint images in our AV setting.**\n" "* The uploaded images are screened with " "[Llama Guard 3 Vision](https://huggingface.co/meta-llama/Llama-Guard-3-11B-Vision) " "to filter out harmful content." ) run_btn = gr.Button("Generate 3D Asset", variant="primary") gr.Markdown( "
" "Disclaimer: Asset Harvester is trained for the AV domain, " "and its performance is not guaranteed on arbitrary images." "
" ) with gr.Column(scale=5, min_width=400): mask_out = gr.Image( label="Object Segmentation", type="pil", height=400, ) video_out = gr.Video( label="3D Gaussian Splat — Orbit Render", height=400, autoplay=True, loop=True, ) ply_out = gr.DownloadButton( label="Download PLY", ) stage_state = gr.State() is_example = gr.State(False) def _pick_example(sample): return sample[0] if isinstance(sample, (list, tuple)) else sample examples_ds.click( _pick_example, inputs=examples_ds, outputs=image_in ).then(lambda: True, outputs=is_example) image_in.input(lambda: False, outputs=is_example) image_in.clear(lambda: False, outputs=is_example) def _shuffled_examples(): shuffled = all_examples.copy() random.shuffle(shuffled) return gr.update(samples=shuffled) demo.load(_shuffled_examples, inputs=None, outputs=examples_ds) run_btn.click( fn=run_segmentation, inputs=[image_in, is_example], outputs=[mask_out, stage_state], show_progress="full", show_progress_on=[mask_out], concurrency_id="seg", concurrency_limit=2, ).then( fn=run_3d, inputs=[stage_state], outputs=[video_out, ply_out], show_progress="full", concurrency_id="gpu3d", concurrency_limit=1, ) demo.queue(default_concurrency_limit=1, max_size=30) return demo, theme def _prefetch_all(device: str) -> None: """Warm checkpoints and load the main pipeline models into memory at startup. Image guard (Llama Guard 3 Vision) weights are prefetched to disk cache only — they are load/unloaded per-call because the model is large (~22 GB on GPU). """ logger.info("Prefetching asset-harvester checkpoints...") _download_checkpoints() logger.info("Loading pipeline / AHC / segmentation / TokenGS into memory...") _load_models(device) logger.info("Prefetching Llama Guard 3 Vision weights to disk cache...") try: snapshot_download( repo_id="meta-llama/Llama-Guard-3-11B-Vision", allow_patterns=["*.json", "*.safetensors", "*.txt", "*.model", "tokenizer*"], token=os.getenv("HF_TOKEN"), ) logger.info("Image guard weights cached.") except Exception as e: logger.warning( "Could not prefetch Llama Guard weights (will download on first safety check): %s", e, ) logger.info("Startup prefetch complete.") if os.getenv("AH_PREFETCH", "1") == "1": _startup_device = "cuda" if torch.cuda.is_available() else "cpu" _prefetch_all(_startup_device) demo, _theme = build_ui() if __name__ == "__main__": demo.launch(server_name="0.0.0.0", server_port=7860, max_threads=40, theme=_theme)