"""PG-MAP demo Space — Gradio app for SD 1.5 / SDXL / SD3.5-medium.

Single-file Gradio app deployed at
https://huggingface.co/spaces/sophialan/pg-map-demo.

Pipeline is loaded lazily on first use (per backbone) so the Space spins up
without immediately downloading 7 GB of weights. Backbones are unloaded when
the user switches via the dropdown to keep VRAM under the A10G small-tier
24 GB budget.
"""

from __future__ import annotations

import gc
import os
from dataclasses import replace

import gradio as gr
import torch

# Lazy loading: only instantiate when the user picks a backbone.
# "backbone" holds the dropdown label of the cached pipeline, "obj" the
# pipeline itself; both are None while nothing is loaded.
_PIPE = {"backbone": None, "obj": None}

# Dropdown label -> (base model repo, custom pipeline repo, extra kwargs
# forwarded to ``DiffusionPipeline.from_pretrained``). The keys double as the
# choices shown in the UI, so the dropdown and the loader cannot drift apart.
_BACKBONE_SPECS = {
    "SD 1.5 (512²)": (
        "stable-diffusion-v1-5/stable-diffusion-v1-5",
        "sophialan/pg-map-sd15",
        {},
    ),
    "SDXL (1024²)": (
        "stabilityai/stable-diffusion-xl-base-1.0",
        "sophialan/pg-map-sdxl",
        {"variant": "fp16"},
    ),
    "SD3.5-medium (1024²)": (
        "stabilityai/stable-diffusion-3.5-medium",
        "sophialan/pg-map-sd3",
        {},
    ),
}

# Seed used when the UI field is left empty; matches the gr.Number default.
_DEFAULT_SEED = 42


def _load(backbone: str):
    """Load (or swap) the appropriate PG-MAP pipeline.

    Returns the cached pipeline when ``backbone`` is already loaded.
    Otherwise the previously cached pipeline (if any) is released and the
    requested one is instantiated, moved to CUDA when available (CPU
    otherwise), cached in ``_PIPE`` and returned.

    Raises:
        KeyError: if ``backbone`` is not a key of ``_BACKBONE_SPECS``. The
            lookup happens before anything is freed, so an invalid label
            does not evict a working pipeline.
        Exception: whatever ``DiffusionPipeline.from_pretrained`` / ``.to``
            raise (download, OOM, ...). The cache is left empty in that case.
    """
    if _PIPE["backbone"] == backbone and _PIPE["obj"] is not None:
        return _PIPE["obj"]

    # Resolve the spec first: failing here must not cost us the loaded model.
    model_id, custom_pipe, extra = _BACKBONE_SPECS[backbone]

    # Free the previous backbone first so both never sit in VRAM together.
    if _PIPE["obj"] is not None:
        _PIPE["obj"] = None
        # Keep the cache self-consistent if the load below fails.
        _PIPE["backbone"] = None
        gc.collect()
        torch.cuda.empty_cache()

    # Imported lazily so that starting the app does not pull in diffusers.
    from diffusers import DiffusionPipeline

    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = DiffusionPipeline.from_pretrained(
        model_id,
        custom_pipeline=custom_pipe,
        torch_dtype=torch.float16,
        safety_checker=None,
        requires_safety_checker=False,
        **extra,
    ).to(device)

    _PIPE["backbone"] = backbone
    _PIPE["obj"] = pipe
    return pipe


def generate(prompt, backbone, seed, steps, guidance,
             lambda_reward, eta_z, K_inner, enable_pgmap,
             progress=gr.Progress()):
    """Run a single generation.

    Args:
        prompt: Text prompt; empty / whitespace-only prompts are rejected.
        backbone: Dropdown label selecting the pipeline (see
            ``_BACKBONE_SPECS``).
        seed: Integer seed. ``None`` (the Number field was cleared) falls
            back to ``_DEFAULT_SEED``.
        steps: Number of denoising steps.
        guidance: Classifier-free guidance scale.
        lambda_reward: Reward weight written to ``cfg.reward.lambda_reward``.
        eta_z: Latent step size written to ``cfg.refinement.eta_z``.
        K_inner: Inner gradient steps written to ``cfg.refinement.K``.
        enable_pgmap: When false, run the pipeline without a PG-MAP config.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        ``(image, status_markdown)``. ``image`` is ``None`` when nothing was
        generated (empty prompt, no CUDA device, or pipeline load failure).
    """
    if not prompt or not prompt.strip():
        return None, "Please enter a prompt."
    if not torch.cuda.is_available():
        return None, ("This Space is running on the free **CPU-basic** tier, which cannot run "
                      "CUDA generation. To generate images, duplicate this Space (or open "
                      "**Settings → Hardware**) and select a GPU runtime — A10G or larger "
                      "(SD 1.5 fits on T4).")

    # gr.Number hands over None when the field is emptied; int(None) would
    # raise TypeError, so fall back to the UI default instead.
    seed = _DEFAULT_SEED if seed is None else int(seed)
    steps = int(steps)
    guidance = float(guidance)

    progress(0.0, desc=f"Loading {backbone}…")
    try:
        pipe = _load(backbone)
    except Exception as e:
        return None, f"Pipeline load failed: {e!r}"

    progress(0.2, desc="Running PG-MAP…" if enable_pgmap else "Running baseline…")

    if not enable_pgmap:
        # Only the baseline path takes an explicit generator; the PG-MAP
        # path receives the seed through its config object instead.
        g = torch.Generator(device="cuda").manual_seed(seed)
        out = pipe(
            prompt=prompt,
            num_inference_steps=steps,
            guidance_scale=guidance,
            generator=g,
        )
        return out.images[0], f"Vanilla {backbone} baseline (no PG-MAP)."

    # PG-MAP path
    from pgmap_config import (
        sd15_defaults,
        sdxl_defaults,
    )
    presets = {
        "SD 1.5 (512²)": sd15_defaults,
        "SDXL (1024²)": sdxl_defaults,
        "SD3.5-medium (1024²)": sdxl_defaults,  # SD3.5 reads K_inner / eta_z from config too
    }
    cfg = presets[backbone]()
    cfg = replace(cfg, num_steps=steps, seed=seed, guidance_scale=guidance)
    # NOTE(review): ``replace`` copies shallowly, so ``refinement`` / ``reward``
    # are the same objects the preset factory returned. That is harmless as
    # long as the factories build fresh nested objects per call — confirm.
    cfg.refinement.K = int(K_inner)
    cfg.refinement.eta_z = float(eta_z)
    cfg.reward.lambda_reward = float(lambda_reward)
    # For SD3.5 default to UG-FM (z-only, data-side); to switch to full PG-MAP-FM, set optimize_c.
    if backbone.startswith("SD3.5"):
        cfg.optimize_c = False
        cfg.optimize_z = True

    out = pipe(
        prompt=prompt,
        pg_map_config=cfg,
        num_inference_steps=steps,
        guidance_scale=guidance,
    )
    return out.images[0], f"PG-MAP on {backbone}: λ={lambda_reward}, η_z={eta_z}, K={K_inner}"


DESCRIPTION = """\
# PG-MAP Demo  ·  NeurIPS 2026 (under review)

**Inference-time alignment for diffusion + flow-matching** — re-optimize the conditioning $c$
and the latent $z_t$ at every denoising step under a trajectory-level Gibbs-MAP / proximal
energy objective. No training required.

🔗 Code: [github.com/sophialanlan/PG-MAP](https://github.com/sophialanlan/PG-MAP) ·
Paper: [arXiv:2606.22958](https://arxiv.org/abs/2606.22958) ·
HF Pipelines: [sd15](https://huggingface.co/sophialan/pg-map-sd15) ·
[sdxl](https://huggingface.co/sophialan/pg-map-sdxl) ·
[sd3](https://huggingface.co/sophialan/pg-map-sd3)

Pick a backbone, write a prompt, hit **Generate**. Toggle PG-MAP off to compare against the
static baseline at the same seed. Default hyperparameters match the paper table; the sliders
expose the productive ranges.
"""

# One single-element row per example: gr.Examples fills only the prompt box.
EXAMPLES = [
    ["a phoenix rising from ashes, vivid orange and red feathers, dramatic lighting"],
    ["a tea cup with a tiny galaxy swirling inside"],
    ["a cinematic photo of a red panda astronaut in a white space suit"],
    ["an old sailboat sailing through a thunderstorm with massive lightning bolts overhead"],
    ["a swordsman mid-leap slashing through a glowing magical barrier"],
]


def build_app():
    """Assemble the Gradio Blocks UI and wire the button to ``generate``.

    Returns:
        The ``gr.Blocks`` instance, not yet queued or launched.
    """
    with gr.Blocks(title="PG-MAP Demo · NeurIPS 2026 (under review)",
                   theme=gr.themes.Soft()) as demo:
        gr.Markdown(DESCRIPTION)

        with gr.Row():
            # Left column: inputs and controls.
            with gr.Column(scale=2):
                prompt = gr.Textbox(label="Prompt", lines=2, max_lines=4,
                                    placeholder="a phoenix rising from ashes…")
                with gr.Row():
                    backbone = gr.Dropdown(
                        # Same labels (and order) the loader understands.
                        list(_BACKBONE_SPECS),
                        value="SDXL (1024²)",
                        label="Backbone",
                    )
                    enable = gr.Checkbox(value=True,
                                         label="Enable PG-MAP (uncheck = vanilla baseline)")
                with gr.Accordion("Generation settings", open=False):
                    seed = gr.Number(value=42, label="Seed", precision=0)
                    steps = gr.Slider(8, 60, value=30, step=1, label="Denoising steps")
                    guidance = gr.Slider(1.0, 15.0, value=5.0, step=0.5, label="CFG scale")
                with gr.Accordion("PG-MAP hyperparameters", open=False):
                    lambda_reward = gr.Slider(0.0, 0.5, value=0.10, step=0.01,
                                              label="λ (reward weight)")
                    eta_z = gr.Slider(0.0, 0.5, value=0.005, step=0.001,
                                      label="η_z (latent step size)")
                    K_inner = gr.Slider(1, 6, value=2, step=1,
                                        label="K (inner gradient steps per denoising step)")
                btn = gr.Button("Generate", variant="primary")
                gr.Examples(EXAMPLES, inputs=prompt)

            # Right column: generated image and status line.
            with gr.Column(scale=3):
                out_img = gr.Image(label="Output", height=512)
                out_status = gr.Markdown()

        # Input order must match the positional parameters of ``generate``.
        btn.click(
            generate,
            inputs=[prompt, backbone, seed, steps, guidance,
                    lambda_reward, eta_z, K_inner, enable],
            outputs=[out_img, out_status],
        )
    return demo


if __name__ == "__main__":
    # Bind on all interfaces; the port comes from $PORT, falling back to 7860.
    build_app().queue().launch(server_name="0.0.0.0",
                               server_port=int(os.environ.get("PORT", 7860)))