#!/usr/bin/env python
"""
Gradio UI for single-event MLPF inference.

Launch with:

    python app.py [--device cpu] [--share]

The UI lets you:

1. Load an event from a parquet file (pick file + event index), **or** paste
   hit / track / particle data in CSV format.
2. (Optionally) load pre-trained model checkpoints.
3. Run inference → view predicted particles and the hit→cluster mapping.
"""

import argparse
import os
import shutil
import traceback

import gradio as gr
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from huggingface_hub import hf_hub_download

# ---------------------------------------------------------------------------
# Auto-download demo files from Hugging Face Hub if they are not present
# ---------------------------------------------------------------------------

_HF_REPO_ID = "gregorkrzmanc/hitpf_demo_files"
_DEMO_FILES = [
    "model_clustering.ckpt",
    "model_e_pid.ckpt",
    "test_data.parquet",
]


def _ensure_demo_files(dest_dir: str = ".") -> None:
    """Download demo files from Hugging Face Hub if they don't already exist.

    Download failures are printed and otherwise ignored (best effort), so the
    UI still starts and the user can point the text boxes at local files.
    """
    for fname in _DEMO_FILES:
        dest = os.path.join(dest_dir, fname)
        if os.path.isfile(dest):
            continue
        try:
            print(f"Downloading {fname} from HF Hub ({_HF_REPO_ID}) …")
            downloaded = hf_hub_download(
                repo_id=_HF_REPO_ID,
                filename=fname,
                repo_type="dataset",
            )
            shutil.copy(downloaded, dest)
            print(f" → saved to {dest}")
        except Exception as exc:  # deliberate best effort: never block start-up
            print(f" ⚠️ Could not download {fname}: {exc}")


_ensure_demo_files()

# ---------------------------------------------------------------------------
# Global state – filled lazily
# ---------------------------------------------------------------------------

_MODEL = None
_ARGS = None
_DEVICE = "cpu"


def _set_device(device: str):
    """Set the default device (e.g. from the ``--device`` command-line flag)."""
    global _DEVICE
    _DEVICE = device


# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------


def load_model_ui(clustering_ckpt: str, energy_pid_ckpt: str, device: str):
    """Load model from checkpoint paths (called by the UI button).

    Parameters
    ----------
    clustering_ckpt : str
        Path to the clustering checkpoint (required).
    energy_pid_ckpt : str
        Path to the energy / PID correction checkpoint. If it is empty or the
        file does not exist, only the clustering model is loaded.
    device : str
        Target device; falls back to the current default device when empty.

    Returns
    -------
    str
        A Markdown status message for the UI.

    The globals ``_MODEL``, ``_ARGS`` and ``_DEVICE`` are only updated when
    loading succeeds. A failed attempt therefore never leaves a previously
    loaded model paired with a device it was not loaded onto.
    """
    global _MODEL, _ARGS, _DEVICE
    target_device = device or _DEVICE

    if not clustering_ckpt or not os.path.isfile(clustering_ckpt):
        return "⚠️ Please provide a valid path to the clustering checkpoint."

    energy_pid = energy_pid_ckpt if (energy_pid_ckpt and os.path.isfile(energy_pid_ckpt)) else None

    try:
        from src.inference import load_model

        model, args = load_model(
            clustering_ckpt=clustering_ckpt,
            energy_pid_ckpt=energy_pid,
            device=target_device,
        )
    except Exception:
        return f"❌ Failed to load model:\n```\n{traceback.format_exc()}\n```"

    _MODEL, _ARGS, _DEVICE = model, args, target_device

    msg = f"✅ Model loaded on **{_DEVICE}**"
    if energy_pid:
        msg += " (clustering + energy/PID correction)"
    else:
        msg += " (clustering only — no energy/PID correction)"
        if energy_pid_ckpt:
            # A path was given but does not exist: tell the user instead of
            # silently degrading to clustering-only.
            msg += f"\n\n⚠️ Energy/PID checkpoint not found: `{energy_pid_ckpt}`"
    return msg


# ---------------------------------------------------------------------------
# Event loading helpers
# ---------------------------------------------------------------------------


def _count_events_in_parquet(parquet_path: str) -> str:
    """Return a short Markdown info string about the parquet file."""
    if not parquet_path or not os.path.isfile(parquet_path):
        return "No file selected"
    try:
        from src.data.fileio import _read_parquet

        table = _read_parquet(parquet_path)
        n = len(table["X_track"])
    except Exception as e:
        return f"Error reading file: {e}"
    if n == 0:
        return "File has **0** events"
    return f"File has **{n}** events (indices 0–{n-1})"


def _array_to_csv(arr: np.ndarray) -> str:
    """Render a 2-D array as header-less CSV (one row per line); "" otherwise."""
    if arr.ndim != 2:
        return ""
    return "\n".join(",".join(str(v) for v in row) for row in arr)


def _int_array_to_csv(arr: np.ndarray) -> str:
    """Render a 1-D array as one comma-separated line of integers ("" if empty)."""
    if len(arr) == 0:
        return ""
    return ",".join(str(int(v)) for v in arr)


def _n_rows(arr: np.ndarray) -> int:
    """Number of rows of a 2-D array, 0 for arrays of any other rank."""
    return arr.shape[0] if arr.ndim == 2 else 0


def _parse_int_csv_line(line: str) -> np.ndarray:
    """Parse one comma-separated line of integers; blank tokens are ignored."""
    return np.array([int(tok) for tok in line.split(",") if tok.strip()], dtype=np.int64)


def _load_event_into_csv(parquet_path: str, event_index: int):
    """Load an event from a parquet file and return CSV strings for the text fields.

    Returns
    -------
    tuple
        ``(hits_csv, tracks_csv, particles_csv, pandora_csv, pfo_links_csv,
        status_markdown)``, in the order of the outputs wired up in
        :func:`build_app`. On failure the five CSV strings are empty and the
        status message describes the problem.
    """
    if not parquet_path or not os.path.isfile(parquet_path):
        return "", "", "", "", "", "⚠️ Please provide a valid parquet file path."
    try:
        from src.inference import load_event_from_parquet

        idx = int(event_index)
        event = load_event_from_parquet(parquet_path, idx)

        hits_arr = np.asarray(event.get("X_hit", []))
        tracks_arr = np.asarray(event.get("X_track", []))
        particles_arr = np.asarray(event.get("X_gen", []))
        pandora_arr = np.asarray(event.get("X_pandora", []))

        calohit_csv = _int_array_to_csv(np.asarray(event.get("pfo_calohit", []), dtype=np.int64))
        track_csv = _int_array_to_csv(np.asarray(event.get("pfo_track", []), dtype=np.int64))
        # Line 1 = calo-hit links, line 2 = track links. With track links but
        # no calo-hit links, line 1 is left empty so the track links stay on
        # line 2 (``_parse_csv_event`` relies on the line position).
        pfo_links_csv = f"{calohit_csv}\n{track_csv}" if track_csv else calohit_csv

        status = (
            f"✅ Loaded event **{idx}**: "
            f"{_n_rows(hits_arr)} hits, "
            f"{_n_rows(tracks_arr)} tracks, "
            f"{_n_rows(particles_arr)} MC particles, "
            f"{_n_rows(pandora_arr)} Pandora PFOs"
        )
        return (
            _array_to_csv(hits_arr),
            _array_to_csv(tracks_arr),
            _array_to_csv(particles_arr),
            _array_to_csv(pandora_arr),
            pfo_links_csv,
            status,
        )
    except Exception as e:
        return "", "", "", "", "", f"❌ Error loading event: {e}"


# ---------------------------------------------------------------------------
# Plotting helpers
# ---------------------------------------------------------------------------

_MIN_MARKER_SIZE = 3
_MAX_MARKER_SIZE = 15


def _empty_figure(title: str) -> go.Figure:
    """Return an empty figure whose title explains why nothing is plotted."""
    fig = go.Figure()
    fig.update_layout(title=title, height=600)
    return fig


def _energy_marker_sizes(energies: np.ndarray) -> np.ndarray:
    """Map hit energies linearly onto marker sizes in [3, 15].

    When all energies are equal, every marker gets the midpoint size.
    """
    e = np.asarray(energies, dtype=float)
    e_min, e_max = float(e.min()), float(e.max())
    if e_max > e_min:
        norm_e = (e - e_min) / (e_max - e_min)
    else:
        norm_e = np.full_like(e, 0.5)
    return _MIN_MARKER_SIZE + norm_e * (_MAX_MARKER_SIZE - _MIN_MARKER_SIZE)


def _hover_text(df: pd.DataFrame, id_col: str, id_label: str,
                coord_cols: tuple, coord_fmt: str) -> pd.Series:
    """Build one plain-string hover label per hit.

    Plain strings avoid mixed-type customdata serialization issues. Lines
    are joined with ``<br>``, which Plotly renders as a line break.
    """
    header = df["hit_type"].astype(str) + " hit #" + df["hit_index"].astype(int).astype(str)
    id_line = f"{id_label}: " + df[id_col].astype(int).astype(str)
    energy_line = "Energy: " + df["hit_energy"].map(lambda v: f"{v:.4f}")
    coord_line = None
    for col in coord_cols:
        part = f"{col}: " + df[col].map(lambda v: format(v, coord_fmt))
        coord_line = part if coord_line is None else coord_line + ", " + part
    return header + "<br>" + id_line + "<br>" + energy_line + "<br>" + coord_line


def _build_hit_scatter(
    hit_cluster_df: pd.DataFrame,
    *,
    group_col: str,
    group_hover_label: str,
    coord_cols: tuple,
    coord_fmt: str,
    trace_label,
    title: str,
    legend_title: str,
    no_data_title: str,
    no_valid_data_title: str,
) -> go.Figure:
    """Shared implementation of the 3-D hit scatter plots.

    Draws one Scatter3d trace per distinct integer value of *group_col*,
    using *coord_cols* as the (x, y, z) coordinates. Marker size scales
    with ``hit_energy``. Rows whose coordinates, energy or group id are
    non-numeric, NaN or ±Inf are dropped. *trace_label* maps a group id to
    the legend name.
    """
    required = (group_col, *coord_cols)
    if hit_cluster_df.empty or any(col not in hit_cluster_df.columns for col in required):
        return _empty_figure(no_data_title)

    numeric_cols = [*coord_cols, "hit_energy", group_col]
    df = hit_cluster_df.copy()
    for col in numeric_cols:
        df[col] = pd.to_numeric(df[col], errors="coerce")
    df = df.replace([np.inf, -np.inf], np.nan).dropna(subset=numeric_cols)
    if df.empty:
        return _empty_figure(no_valid_data_title)

    marker_sizes = _energy_marker_sizes(df["hit_energy"].to_numpy(dtype=float))
    df["_hover"] = _hover_text(df, group_col, group_hover_label, coord_cols, coord_fmt)

    group_ids = df[group_col].to_numpy()
    x_col, y_col, z_col = coord_cols
    fig = go.Figure()
    for gid in sorted({int(g) for g in group_ids}):
        mask = group_ids == gid
        subset = df[mask]
        fig.add_trace(go.Scatter3d(
            x=subset[x_col].tolist(),
            y=subset[y_col].tolist(),
            z=subset[z_col].tolist(),
            mode="markers",
            name=trace_label(gid),
            marker=dict(size=marker_sizes[mask].tolist(), opacity=0.8),
            hovertext=subset["_hover"].tolist(),
            hoverinfo="text",
        ))

    fig.update_layout(
        title=title,
        scene=dict(xaxis_title=x_col, yaxis_title=y_col, zaxis_title=z_col),
        legend_title=legend_title,
        height=600,
        margin=dict(l=0, r=0, t=40, b=0),
    )
    return fig


def _hitpf_trace_label(cid: int) -> str:
    """Legend name for a HitPF cluster id (0 is the noise cluster)."""
    return "noise" if cid == 0 else f"cluster {cid}"


def _build_cluster_plot(hit_cluster_df: pd.DataFrame) -> go.Figure:
    """Build an interactive 3D scatter plot of hits colored by cluster ID."""
    return _build_hit_scatter(
        hit_cluster_df,
        group_col="cluster_id",
        group_hover_label="Cluster",
        coord_cols=("x", "y", "z"),
        coord_fmt=".2f",
        trace_label=_hitpf_trace_label,
        title="Hit → Cluster 3D Map",
        legend_title="Cluster",
        no_data_title="No hit data available",
        no_valid_data_title="No valid hit data (all NaN/Inf)",
    )


def _build_pandora_cluster_plot(hit_cluster_df: pd.DataFrame) -> go.Figure:
    """Build an interactive 3D scatter plot of hits colored by Pandora cluster ID.

    Hits with ``pandora_cluster_id == -1`` are kept and shown as a separate
    "unassigned" trace.
    """
    return _build_hit_scatter(
        hit_cluster_df,
        group_col="pandora_cluster_id",
        group_hover_label="Pandora cluster",
        coord_cols=("x", "y", "z"),
        coord_fmt=".2f",
        trace_label=lambda cid: "unassigned" if cid == -1 else f"PFO {cid}",
        title="Hit → Pandora Cluster 3D Map",
        legend_title="Pandora PFO",
        no_data_title="No Pandora cluster data available",
        no_valid_data_title="No valid hit data for Pandora plot (all NaN/Inf)",
    )


def _build_clustering_space_plot(hit_cluster_df: pd.DataFrame) -> go.Figure:
    """Build an interactive 3D scatter plot of hits in the learned clustering space."""
    return _build_hit_scatter(
        hit_cluster_df,
        group_col="cluster_id",
        group_hover_label="Cluster",
        coord_cols=("cluster_x", "cluster_y", "cluster_z"),
        coord_fmt=".4f",
        trace_label=_hitpf_trace_label,
        title="Clustering Space 3D Map (GATr regressed coordinates)",
        legend_title="Cluster",
        no_data_title="No clustering-space data available",
        no_valid_data_title="No valid clustering-space data (all NaN/Inf)",
    )


# ---------------------------------------------------------------------------
# Main inference entry point for the UI
# ---------------------------------------------------------------------------


def _compute_inv_mass(df, e_col, px_col, py_col, pz_col):
    """Compute the invariant mass of a system of particles in GeV.

    Returns the scalar invariant mass
    m = sqrt(max((ΣE)²−(Σpx)²−(Σpy)²−(Σpz)², 0)),
    or *None* when *df* is empty or the required columns are absent.
    """
    if df.empty:
        return None
    if any(col not in df.columns for col in (e_col, px_col, py_col, pz_col)):
        return None
    E = float(df[e_col].sum())
    px = float(df[px_col].sum())
    py = float(df[py_col].sum())
    pz = float(df[pz_col].sum())
    m2 = E ** 2 - px ** 2 - py ** 2 - pz ** 2
    # Clamp: rounding (or mismeasured energies) can make m² slightly negative.
    return float(np.sqrt(max(m2, 0.0)))


def _fmt_mass(val):
    """Format an invariant-mass value (float or None) as a GeV string."""
    return f"{val:.4f} GeV" if val is not None else "N/A"


def _error_outputs(message: str):
    """Return the 7 UI outputs of ``run_inference_ui`` for an error message."""
    return (
        pd.DataFrame({"error": [message]}),
        go.Figure(),
        go.Figure(),
        go.Figure(),
        pd.DataFrame(),
        pd.DataFrame(),
        "",
    )


def run_inference_ui(
    parquet_path: str,
    event_index: int,
    csv_hits: str,
    csv_tracks: str,
    csv_particles: str,
    csv_pandora: str,
    csv_pfo_links: str = "",
):
    """Run inference on a single event; return particles, 3D plots and mass summary.

    Pasted CSV hits take precedence over the parquet file. This lets a
    pasted or edited event override the default parquet path.

    Returns
    -------
    particles_df : pandas.DataFrame
    cluster_fig : plotly.graph_objects.Figure
    clustering_space_fig : plotly.graph_objects.Figure
    pandora_cluster_fig : plotly.graph_objects.Figure
    mc_particles_df : pandas.DataFrame
    pandora_particles_df : pandas.DataFrame
    inv_mass_summary : str
    """
    if _MODEL is None:
        return _error_outputs("Model not loaded. Please load a model first.")

    try:
        from src.inference import load_event_from_parquet, run_single_event_inference

        # Decide input source
        use_parquet = bool(parquet_path and os.path.isfile(parquet_path))
        use_csv = bool(csv_hits and csv_hits.strip())

        if not use_parquet and not use_csv:
            return _error_outputs("Provide a parquet file or paste CSV hit data.")

        if use_csv:
            event = _parse_csv_event(csv_hits, csv_tracks, csv_particles, csv_pandora, csv_pfo_links)
        else:
            event = load_event_from_parquet(parquet_path, int(event_index))

        particles_df, hit_cluster_df, mc_particles_df, pandora_particles_df = run_single_event_inference(
            event, _MODEL, _ARGS, device=_DEVICE,
        )

        if particles_df.empty:
            particles_df = pd.DataFrame({"info": ["Event produced no clusters (empty graph)."]})

        cluster_fig = _build_cluster_plot(hit_cluster_df)
        clustering_space_fig = _build_clustering_space_plot(hit_cluster_df)
        pandora_cluster_fig = _build_pandora_cluster_plot(hit_cluster_df)

        # Compute invariant masses [GeV]
        m_true = _compute_inv_mass(mc_particles_df, "energy", "px", "py", "pz")
        # HitPF uses corrected_energy when available, otherwise energy_sum_hits
        hitpf_e_col = "corrected_energy" if "corrected_energy" in particles_df.columns else "energy_sum_hits"
        m_reco_hitpf = _compute_inv_mass(particles_df, hitpf_e_col, "px", "py", "pz")
        m_reco_pandora = _compute_inv_mass(pandora_particles_df, "energy", "px", "py", "pz")

        inv_mass_summary = (
            f"**Invariant mass (sum of all particle 4-vectors)**\n\n"
            f"| Algorithm | m [GeV] |\n"
            f"|---|---|\n"
            f"| m_true (MC truth) | {_fmt_mass(m_true)} |\n"
            f"| m_reco (HitPF) | {_fmt_mass(m_reco_hitpf)} |\n"
            f"| m_reco (Pandora) | {_fmt_mass(m_reco_pandora)} |"
        )

        return (
            particles_df,
            cluster_fig,
            clustering_space_fig,
            pandora_cluster_fig,
            mc_particles_df,
            pandora_particles_df,
            inv_mass_summary,
        )
    except Exception:
        return _error_outputs(traceback.format_exc())


def _parse_csv_event(csv_hits: str, csv_tracks: str, csv_particles: str,
                     csv_pandora: str = "", csv_pfo_links: str = ""):
    """Parse user-provided CSV text into the dict-of-arrays format expected by ``create_graph``.

    Expected CSV columns for hits (X_hit) — 11 columns:
        0: hit_x — hit position x [mm]
        1: hit_y — hit position y [mm]
        2: hit_z — hit position z [mm]
        3: hit_px — hit momentum px [GeV] (0 for calo hits)
        4: hit_py — hit momentum py [GeV] (0 for calo hits)
        5: hit_energy — hit energy deposit [GeV]
        6: hit_x_calo — hit position x at calorimeter surface [mm] (used as 3D position by the model)
        7: hit_y_calo — hit position y at calorimeter surface [mm]
        8: hit_z_calo — hit position z at calorimeter surface [mm]
        9: (unused) — reserved column (set to 0)
        10: hit_type — hit sub-detector type: 1 = ECAL, 2 = HCAL, 3 = muon system

    Expected CSV columns for tracks (X_track) — 25 columns (padded with zeros
    if fewer are provided; minimum 17):
        0: elemtype — element type (always 1 for tracks)
        1–4: (unused) — reserved columns (set to 0)
        5: p — track momentum magnitude |p| [GeV]
        6: px_IP — track px at interaction point [GeV]
        7: py_IP — track py at interaction point [GeV]
        8: pz_IP — track pz at interaction point [GeV]
        9–11: (unused) — reserved columns (set to 0)
        12: ref_x_calo — track reference-point x at calorimeter [mm]
        13: ref_y_calo — track reference-point y at calorimeter [mm]
        14: ref_z_calo — track reference-point z at calorimeter [mm]
        15: chi2 — track-fit chi-squared
        16: ndf — track-fit number of degrees of freedom
        17–21: (unused) — reserved columns (set to 0)
        22: px_calo — track momentum x component at calorimeter [GeV]
        23: py_calo — track momentum y component at calorimeter [GeV]
        24: pz_calo — track momentum z component at calorimeter [GeV]

    Expected CSV columns for particles / MC truth (X_gen) — 18 columns:
        0: pid — PDG particle ID (e.g. 211, 22, 11, 13)
        1: gen_status — generator status code
        2: isDecayedInCalo — 1 if decayed in calorimeter, else 0
        3: isDecayedInTracker — 1 if decayed in tracker, else 0
        4: theta — polar angle [rad]
        5: phi — azimuthal angle [rad]
        6: (unused) — reserved (set to 0)
        7: (unused) — reserved (set to 0)
        8: energy — true particle energy [GeV]
        9: (unused) — reserved (set to 0)
        10: mass — particle mass [GeV]
        11: momentum — momentum magnitude |p| [GeV]
        12: px — momentum x component [GeV]
        13: py — momentum y component [GeV]
        14: pz — momentum z component [GeV]
        15: vx — production vertex x [mm]
        16: vy — production vertex y [mm]
        17: vz — production vertex z [mm]

    Pandora PFOs (X_pandora) — 9 columns, in the order shown in the UI
    placeholder: pid, px, py, pz, ref_x, ref_y, ref_z, energy, momentum.

    PFO links (csv_pfo_links) — two lines of comma-separated integers:
        Line 1: pfo_calohit — one PFO index per calorimeter hit (-1 = unassigned)
        Line 2: pfo_track — one PFO index per track (-1 = unassigned)
    The line position is significant: an empty line 1 means "no calo-hit
    links" while line 2 still holds the track links.

    MC links (``ygen_hit`` / ``ygen_track``) are unknown for pasted data and
    are filled with -1.
    """
    import io

    def _read(text, min_cols=1):
        """Parse header-less CSV into float64; empty text gives shape (0, min_cols)."""
        if not text or not text.strip():
            return np.zeros((0, min_cols), dtype=np.float64)
        df = pd.read_csv(io.StringIO(text), header=None)
        return df.values.astype(np.float64)

    hits_arr = _read(csv_hits, 11)
    tracks_arr = _read(csv_tracks, 25)
    particles_arr = _read(csv_particles, 18)
    pandora_arr = _read(csv_pandora, 9)

    # Pad tracks to 25 columns if needed
    if tracks_arr.shape[0] > 0 and tracks_arr.shape[1] < 25:
        pad = np.zeros((tracks_arr.shape[0], 25 - tracks_arr.shape[1]))
        tracks_arr = np.concatenate([tracks_arr, pad], axis=1)

    # Build ygen_hit / ygen_track (particle link per hit — use -1 for unknown)
    ygen_hit = np.full(len(hits_arr), -1, dtype=np.int64)
    ygen_track = np.full(len(tracks_arr), -1, dtype=np.int64)

    # Parse PFO link arrays (hit → Pandora cluster mapping)
    pfo_calohit = np.array([], dtype=np.int64)
    pfo_track = np.array([], dtype=np.int64)
    if csv_pfo_links and csv_pfo_links.strip():
        # Split BEFORE stripping: stripping the whole text would drop an empty
        # first line and shift the track links onto the calo-hit line.
        lines = csv_pfo_links.splitlines()
        if len(lines) >= 1:
            pfo_calohit = _parse_int_csv_line(lines[0])
        if len(lines) >= 2:
            pfo_track = _parse_int_csv_line(lines[1])

    return {
        "X_hit": hits_arr,
        "X_track": tracks_arr,
        "X_gen": particles_arr,
        "X_pandora": pandora_arr,
        "ygen_hit": ygen_hit,
        "ygen_track": ygen_track,
        "pfo_calohit": pfo_calohit,
        "pfo_track": pfo_track,
    }


# ---------------------------------------------------------------------------
# Build the Gradio interface
# ---------------------------------------------------------------------------


def build_app():
    """Build and return the Gradio Blocks app (model → event → results)."""
    # Default the device selector to the current default device (set from
    # ``--device``), so the command-line flag is not silently ignored.
    device_choices = ["cpu", "cuda:0", "cuda:1"]
    if _DEVICE not in device_choices:
        device_choices.append(_DEVICE)

    with gr.Blocks(title="HitPF — Single-event MLPF Inference") as demo:
        gr.Markdown(
            "# HitPF — Single-event MLPF Inference\n"
            "Run the GATr-based particle-flow reconstruction on a single event.\n\n"
            "**Steps:** 1) Load model checkpoints 2) Select an event 3) Run inference"
        )

        # ---- Model loading ----
        with gr.Accordion("1 · Load Model", open=True):
            with gr.Row():
                clustering_ckpt = gr.Textbox(
                    label="Clustering checkpoint (.ckpt)",
                    value="model_clustering.ckpt",
                    placeholder="/path/to/clustering.ckpt",
                )
                energy_pid_ckpt = gr.Textbox(
                    label="Energy / PID checkpoint (.ckpt) — optional",
                    value="model_e_pid.ckpt",
                    placeholder="/path/to/energy_pid.ckpt",
                )
                device_dd = gr.Dropdown(
                    choices=device_choices,
                    value=_DEVICE,
                    label="Device",
                )
            load_btn = gr.Button("Load model")
            load_status = gr.Markdown("")
            load_btn.click(
                fn=load_model_ui,
                inputs=[clustering_ckpt, energy_pid_ckpt, device_dd],
                outputs=load_status,
            )

        # ---- Event selection ----
        with gr.Accordion("2 · Select Event", open=True):
            gr.Markdown("**Option A** — from a parquet file:")
            with gr.Row():
                parquet_path = gr.Textbox(
                    label="Parquet file path",
                    value="test_data.parquet",
                    placeholder="/path/to/events.parquet",
                )
                event_idx = gr.Number(label="Event index", value=0, precision=0)
            parquet_info = gr.Markdown("")
            parquet_path.change(
                fn=_count_events_in_parquet,
                inputs=parquet_path,
                outputs=parquet_info,
            )
            load_event_btn = gr.Button("Load event from parquet")
            load_event_status = gr.Markdown("")

            gr.Markdown(
                "---\n**Option B** — paste CSV data (one row per hit/track/particle, "
                "no header, comma-separated):\n"
            )
            csv_hits = gr.Textbox(
                label="Hits CSV (11 columns)",
                lines=4,
                placeholder=(
                    "Example (one ECAL hit, one HCAL hit):\n"
                    "0,0,0,0,0,1.23,1800.5,200.3,100.1,0,1\n"
                    "0,0,0,0,0,0.45,1900.2,-50.1,300.7,0,2"
                ),
            )
            csv_tracks = gr.Textbox(
                label="Tracks CSV (25 columns; leave empty if none)",
                lines=3,
                placeholder=(
                    "Example (one track with p≈5 GeV):\n"
                    "1,0,0,0,0,5.0,3.0,2.0,3.3,0,0,0,1800.0,150.0,90.0,12.5,8,0,0,0,0,0,2.9,1.9,3.2"
                ),
            )
            csv_particles = gr.Textbox(
                label="Particles (MC truth) CSV (18 columns; optional)",
                lines=3,
                placeholder=(
                    "Example (one pion, one photon):\n"
                    "211,1,0,0,1.2,0.5,0,0,5.2,0,0.1396,5.198,3.1,2.0,3.3,0,0,0\n"
                    "22,1,0,0,0.8,2.1,0,0,1.5,0,0,1.5,0.5,-0.3,1.38,0,0,0"
                ),
            )
            csv_pandora = gr.Textbox(
                label="Pandora PFOs CSV (9 columns; optional)",
                lines=3,
                placeholder=(
                    "Columns: pid, px, py, pz, ref_x, ref_y, ref_z, energy, momentum\n"
                    "Example (one charged pion PFO):\n"
                    "211,3.0,2.0,3.3,1800.0,150.0,90.0,5.2,5.198"
                ),
            )
            csv_pfo_links = gr.Textbox(
                label="Hit → Pandora Cluster links (optional; loaded from parquet)",
                lines=2,
                placeholder=(
                    "Line 1: PFO index per calo hit (comma-separated, -1 = unassigned)\n"
                    "Line 2: PFO index per track (comma-separated, -1 = unassigned)"
                ),
            )

            load_event_btn.click(
                fn=_load_event_into_csv,
                inputs=[parquet_path, event_idx],
                outputs=[csv_hits, csv_tracks, csv_particles, csv_pandora, csv_pfo_links, load_event_status],
            )

        # ---- Run inference ----
        with gr.Accordion("3 · Results", open=True):
            run_btn = gr.Button("▶ Run Inference", variant="primary")
            inv_mass_output = gr.Markdown("")
            gr.Markdown("### Predicted Particles (HitPF)")
            particles_table = gr.Dataframe(label="Predicted particles")
            gr.Markdown("### MC Truth Particles")
            mc_particles_table = gr.Dataframe(label="MC truth particles (for comparison)")
            gr.Markdown("### Pandora Particles")
            pandora_particles_table = gr.Dataframe(label="Pandora PFO particles (for comparison)")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Hit → HitPF Cluster 3D Map")
                    cluster_plot = gr.Plot(label="Hit-cluster 3D scatter (color = HitPF cluster, size = energy)")
                with gr.Column():
                    gr.Markdown("### Hit → Pandora Cluster 3D Map")
                    pandora_cluster_plot = gr.Plot(label="Hit-cluster 3D scatter (color = Pandora PFO, size = energy)")
            gr.Markdown("### Clustering Space 3D Map")
            clustering_space_plot = gr.Plot(label="Clustering space 3D scatter (GATr regressed coordinates)")

            run_btn.click(
                fn=run_inference_ui,
                inputs=[parquet_path, event_idx, csv_hits, csv_tracks, csv_particles, csv_pandora, csv_pfo_links],
                outputs=[particles_table, cluster_plot, clustering_space_plot, pandora_cluster_plot,
                         mc_particles_table, pandora_particles_table, inv_mass_output],
            )

        # ``.change`` does not fire for the initial textbox value, so fill the
        # parquet info line once when the page loads.
        demo.load(
            fn=_count_events_in_parquet,
            inputs=parquet_path,
            outputs=parquet_info,
        )

    return demo


# ---------------------------------------------------------------------------

if __name__ == "__main__":
    ap = argparse.ArgumentParser(description="HitPF Gradio UI")
    ap.add_argument("--device", default="cpu", help="Default device (cpu / cuda:0 / …)")
    ap.add_argument("--share", action="store_true", help="Create a public Gradio link")
    cli_args = ap.parse_args()
    _set_device(cli_args.device)
    demo = build_app()
    demo.launch(share=cli_args.share)