"""
Robot Policy Evaluation Harness
Interactive HuggingFace Space — Bayesian + SPARC + STL on real robot data.
Based on: Kress-Gazit et al. (TRI/Cornell) arXiv:2409.09491
"""
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from scipy.fft import rfft, rfftfreq
from scipy import stats
from datasets import load_dataset
import gradio as gr
import io, json
# ── constants ────────────────────────────────────────────────────────────────
# Per-policy trace colours, assigned in order of appearance.
PALETTE = ["#60A5FA", "#FB923C", "#F87171", "#34D399", "#A78BFA"]

# Dark-theme colours shared by the Plotly figures and the CSS.
BG = "#0F172A"
CARD = "#1E293B"
BORDER = "#334155"
TEXT = "#F1F5F9"
SUBTEXT = "#94A3B8"
ACCENT = "#38BDF8"

# Base layout unpacked into ``update_layout`` by the cartesian figures.
PLOTLY_LAYOUT = {
    "paper_bgcolor": CARD,
    "plot_bgcolor": CARD,
    "font": {"color": TEXT, "family": "Inter, sans-serif"},
    "margin": {"l": 40, "r": 20, "t": 50, "b": 40},
    "legend": {"bgcolor": "rgba(0,0,0,0)", "bordercolor": BORDER},
    "xaxis": {"gridcolor": BORDER, "zerolinecolor": BORDER},
    "yaxis": {"gridcolor": BORDER, "zerolinecolor": BORDER},
}

FS = 50  # Hz — ALOHA dataset sampling rate
# ── signal extraction ─────────────────────────────────────────────────────────
def extract_episode(states, actions):
    """Derive the three per-episode signals used by the analysis.

    Both inputs are converted to float arrays with one row per frame.

    Returns a tuple ``(speed, effort, z)``:
      * ``speed``  – norm of the frame-to-frame state difference multiplied
        by ``FS``; it has one entry fewer than there are frames.
      * ``effort`` – per-frame norm of ``action - state`` over the leading
        columns the two arrays have in common.
      * ``z``      – state column 1, min-max normalised over the episode.
    """
    state_arr = np.array(states, dtype=float)
    action_arr = np.array(actions, dtype=float)

    # Finite difference between consecutive frames, scaled by the sample rate.
    step = np.diff(state_arr, axis=0)
    speed = np.linalg.norm(step, axis=1) * FS

    # State and action widths may differ (e.g. Franka 13 vs 15), so compare
    # only the columns both of them have.
    shared = min(state_arr.shape[1], action_arr.shape[1])
    gap = action_arr[:, :shared] - state_arr[:, :shared]
    effort = np.linalg.norm(gap, axis=1)

    # The small epsilon keeps a constant column from dividing by zero.
    column = state_arr[:, 1]
    lo, hi = column.min(), column.max()
    z = (column - lo) / (hi - lo + 1e-9)

    return speed, effort, z
# ── SPARC ─────────────────────────────────────────────────────────────────────
def sparc(speed, fs=FS, padlevel=4, fc=10.0, amp_th=0.05):
    """Spectral arc length (SPARC) smoothness of a speed profile.

    The magnitude spectrum of ``speed`` is normalised to a peak of 1, limited
    to an adaptive band, and the length of that curve is measured with the
    frequency axis rescaled to span exactly 1. The negated length is
    returned, so values closer to zero mean a smoother movement.

    Args:
        speed: 1-D sequence of speed samples.
        fs: sampling rate of ``speed`` (same unit as ``fc``).
        padlevel: the FFT length is the next power of two of ``len(speed)``
            multiplied by ``2 ** padlevel`` (zero padding).
        fc: upper limit of the frequency band considered.
        amp_th: normalised-magnitude threshold; the band ends at the highest
            frequency ``<= fc`` whose magnitude is still ``>= amp_th``.

    Returns:
        A float ``<= 0``. ``0.0`` is returned for an empty or all-zero
        profile and when the selected band collapses to a single bin.
    """
    speed = np.asarray(speed, dtype=float)
    # An empty profile (single-frame episode) or a motionless one has no
    # spectrum to measure.
    if speed.size == 0 or speed.max() == 0:
        return 0.0

    nfft = int(pow(2, np.ceil(np.log2(len(speed))) + padlevel))
    freqs = rfftfreq(nfft, d=1.0 / fs)
    spectrum = np.abs(rfft(speed, n=nfft))
    peak = spectrum.max()
    if peak == 0:
        return 0.0
    spectrum = spectrum / peak

    # Adaptive cut-off: last in-band bin that is still above the threshold.
    above = np.flatnonzero((freqs <= fc) & (spectrum >= amp_th))
    fc_used = freqs[above[-1]] if above.size else fc
    band = np.flatnonzero(freqs <= fc_used)
    if band.size < 2:
        return 0.0

    f_sel = freqs[band]
    m_sel = spectrum[band]
    bandwidth = f_sel[-1] - f_sel[0]

    # Arc length = sum of segment lengths sqrt(df^2 + dM^2); the square root
    # is taken per segment, with df expressed as a fraction of the bandwidth.
    d_freq = np.diff(f_sel) / bandwidth
    d_mag = np.diff(m_sel)
    return float(-np.sum(np.sqrt(d_freq ** 2 + d_mag ** 2)))
# ── STL manual robustness ─────────────────────────────────────────────────────
def stl_robustness(effort, z, threshold, z_min=0.3):
    """Manual robustness of the rule "whenever effort > threshold, z > z_min".

    Args:
        effort: per-frame effort values (array or plain sequence).
        z: per-frame z values, same length as ``effort``.
        threshold: effort level above which the rule applies.
        z_min: lower bound ``z`` must exceed while effort is high.

    Returns:
        If any frame has ``effort > threshold``: the minimum of ``z - z_min``
        over those frames (negative means the rule was violated).
        Otherwise: the minimum of ``threshold - effort`` over all frames,
        i.e. how far the episode stayed from triggering the rule.
    """
    effort = np.asarray(effort, dtype=float)
    z = np.asarray(z, dtype=float)

    high = effort > threshold
    if high.any():
        return float((z[high] - z_min).min())
    return float((threshold - effort).min())
# ── Bayesian helpers ──────────────────────────────────────────────────────────
def posterior(s, n, N=200_000):
    """Draw ``N`` samples from the Beta(1 + s, 1 + n - s) distribution.

    With ``s`` successes out of ``n`` trials this is the posterior of the
    success probability under a uniform prior. Sampling uses NumPy's global
    random state.
    """
    alpha = 1 + s
    beta = 1 + n - s
    return np.random.beta(alpha, beta, size=N)
def ci_lower(s, n, N=100_000):
    """Return the 5th percentile of the Beta(1 + s, 1 + n - s) distribution.

    The value is computed exactly from the inverse CDF, so repeated calls
    with the same arguments give the same bound.

    Args:
        s: number of successes.
        n: number of trials.
        N: unused; kept so existing callers that pass a sample count still
            work.

    Raises:
        ValueError: if either Beta shape parameter is not positive
            (for example ``s > n``).
    """
    a = 1 + s
    b = 1 + n - s
    if a <= 0 or b <= 0:
        raise ValueError(f"invalid Beta parameters a={a}, b={b}")
    return float(stats.beta.ppf(0.05, a, b))
# ── load ALOHA demo data ──────────────────────────────────────────────────────
_cache = {}


def load_aloha():
    """Load the ALOHA demo dataset and pre-compute per-episode signals.

    The first call reads the ``train`` split of
    ``lerobot/aloha_static_cups_open``, groups frames by ``episode_index`` and
    runs ``extract_episode`` on each group. The result is stored in the
    module-level ``_cache`` and returned directly on later calls.

    Returns:
        ``(ep_ids, extracted)`` where ``ep_ids`` is the sorted list of episode
        indices and ``extracted[ep]`` is a dict with keys ``speed``,
        ``effort`` and ``z``.
    """
    cached = _cache.get("aloha")
    if cached is not None:
        return cached

    dataset = load_dataset("lerobot/aloha_static_cups_open", split="train")

    # Collect frames per episode in dataset order.
    by_episode = {}
    for frame in dataset:
        bucket = by_episode.setdefault(
            frame["episode_index"], {"states": [], "actions": []}
        )
        bucket["states"].append(frame["observation.state"])
        bucket["actions"].append(frame["action"])

    ep_ids = sorted(by_episode)
    extracted = {
        ep: dict(zip(
            ("speed", "effort", "z"),
            extract_episode(by_episode[ep]["states"], by_episode[ep]["actions"]),
        ))
        for ep in ep_ids
    }

    result = (ep_ids, extracted)
    _cache["aloha"] = result
    return result
# ── core analysis ─────────────────────────────────────────────────────────────
def _hex_to_rgba(color, alpha):
    """Return *color* as a Plotly ``rgba(r,g,b,a)`` string with opacity *alpha*.

    ``#RRGGBB`` strings are decoded channel by channel, ``rgb(r,g,b)`` strings
    get the alpha appended, and anything else is returned unchanged.
    """
    if color.startswith("#") and len(color) == 7:
        r, g, b = (int(color[i:i + 2], 16) for i in (1, 3, 5))
        return f"rgba({r},{g},{b},{alpha})"
    if color.startswith("rgb("):
        return color.replace(")", f",{alpha})").replace("rgb", "rgba")
    return color


def run_analysis(policy_data):
    """
    Build every figure and the Markdown scorecard for a set of policies.

    policy_data: dict of name → {trials, speeds, efforts, zs}, where
    ``trials`` is a list of 0/1 outcomes and the other three are per-episode
    arrays (one entry per episode).

    Policies without any trial are left out of the analysis and listed at the
    end of the report. If no policy has data, eight ``None`` placeholders and
    a warning message are returned instead of figures.

    Returns a 9-tuple:
    (fig1, fig1b, fig2, fig2b, fig3, fig3b, fig4, fig4b, report_markdown).

    The composite score is
    0.40 * ci_lo + 0.20 * normalised smoothness + 0.25 * safe fraction
    + 0.15 * raw success rate.
    """
    skipped = [name for name, d in policy_data.items() if len(d["trials"]) == 0]
    policy_data = {name: d for name, d in policy_data.items() if len(d["trials"]) > 0}
    if not policy_data:
        return (None,) * 8 + ("⚠️ No episodes to analyse.",)

    names = list(policy_data.keys())
    # Cycle through the palette so more policies than colours still works.
    colors = {n: PALETTE[i % len(PALETTE)] for i, n in enumerate(names)}

    # The STL rule triggers above the 75th percentile of all effort samples,
    # pooled over every policy and episode.
    all_efforts = np.concatenate([e for d in policy_data.values() for e in d["efforts"]])
    effort_thresh = float(np.percentile(all_efforts, 75))

    # ── compute metrics ───────────────────────────────────────────────────────
    metrics = {}
    for name, d in policy_data.items():
        s = sum(d["trials"])
        n = len(d["trials"])
        sc = [sparc(sp) for sp in d["speeds"]]
        st = [stl_robustness(ef, z, effort_thresh)
              for ef, z in zip(d["efforts"], d["zs"])]
        metrics[name] = {
            "s": s, "n": n,
            "sparc": sc,
            "stl": st,
            "ci_lo": ci_lower(s, n),
            "safe": sum(1 for x in st if x >= 0) / len(st),
        }

    # ── fig 1: bayesian posteriors ────────────────────────────────────────────
    x = np.linspace(0, 1, 400)
    fig1 = go.Figure()
    for name in names:
        m = metrics[name]
        a, b = 1 + m["s"], 1 + m["n"] - m["s"]
        y = stats.beta.pdf(x, a, b)
        fig1.add_trace(go.Scatter(
            x=x, y=y, mode="lines", name=f"Policy {name} ({m['s']}/{m['n']})",
            line=dict(color=colors[name], width=2.5),
            fill="tozeroy", fillcolor=_hex_to_rgba(colors[name], 0.1),
        ))
        fig1.add_vline(x=m["ci_lo"], line_color=colors[name],
                       line_dash="dot", line_width=1.5)
    fig1.update_layout(**PLOTLY_LAYOUT,
                       title="① Bayesian Posteriors (dotted = 95 % CI lower bound)",
                       xaxis_title="Success probability p", yaxis_title="Density")

    # pairwise matrix: fraction of posterior draws where the row policy's
    # success probability exceeds the column policy's
    mat = np.zeros((len(names), len(names)))
    samps = {n: posterior(metrics[n]["s"], metrics[n]["n"]) for n in names}
    for i, a in enumerate(names):
        for j, b in enumerate(names):
            mat[i, j] = (samps[a] > samps[b]).mean()
    fig1b = go.Figure(go.Heatmap(
        z=mat, x=[f"Policy {n}" for n in names], y=[f"Policy {n}" for n in names],
        colorscale="RdYlGn", zmin=0, zmax=1,
        text=[[f"{mat[i,j]:.2f}" for j in range(len(names))] for i in range(len(names))],
        texttemplate="%{text}", textfont=dict(size=14),
    ))
    fig1b.update_layout(**PLOTLY_LAYOUT, title="P(row beats col)")

    # ── fig 2: SPARC ──────────────────────────────────────────────────────────
    fig2 = go.Figure()
    for name in names:
        sc = metrics[name]["sparc"]
        fig2.add_trace(go.Box(
            y=sc, name=f"Policy {name}",
            marker_color=colors[name], line_color=colors[name],
            boxmean=True, fillcolor=_hex_to_rgba(colors[name], 0.3),
        ))
    fig2.add_hline(y=np.mean([v for m in metrics.values() for v in m["sparc"]]),
                   line_dash="dot", line_color=SUBTEXT, annotation_text="global mean")
    fig2.update_layout(**PLOTLY_LAYOUT,
                       title="② SPARC Smoothness (less negative = smoother)",
                       yaxis_title="SPARC score")

    # sample speed profiles
    fig2b = go.Figure()
    for name in names:
        sp = policy_data[name]["speeds"][0]
        t = np.arange(len(sp)) / FS
        fig2b.add_trace(go.Scatter(x=t, y=sp, mode="lines",
                                   name=f"Policy {name}",
                                   line=dict(color=colors[name], width=1.8)))
    fig2b.update_layout(**PLOTLY_LAYOUT,
                        title="Joint-Space Speed Profile (first episode per policy)",
                        xaxis_title="Time (s)", yaxis_title="Speed (rad/s)")

    # ── fig 3: STL ────────────────────────────────────────────────────────────
    fig3 = go.Figure()
    for name in names:
        st = metrics[name]["stl"]
        fig3.add_trace(go.Scatter(
            x=[f"Policy {name}"] * len(st),
            y=st, mode="markers",
            name=f"Policy {name}",
            marker=dict(color=colors[name], size=9, opacity=0.7,
                        line=dict(color="white", width=0.5)),
        ))
        fig3.add_trace(go.Scatter(
            x=[f"Policy {name}", f"Policy {name}"],
            y=[np.mean(st), np.mean(st)],
            mode="lines", line=dict(color=colors[name], width=4),
            showlegend=False,
        ))
    fig3.add_hline(y=0, line_dash="dash", line_color="white", line_width=1.5,
                   annotation_text="violation boundary")
    fig3.update_layout(**PLOTLY_LAYOUT,
                       title="③ STL Safety Robustness (positive = constraint satisfied)",
                       yaxis_title="Robustness score")

    # violation bar
    viols = [sum(1 for x in metrics[n]["stl"] if x < 0) for n in names]
    totals = [metrics[n]["n"] for n in names]
    fig3b = go.Figure(go.Bar(
        x=[f"Policy {n}" for n in names], y=viols,
        marker_color=[colors[n] for n in names],
        text=[f"{v}/{t}" for v, t in zip(viols, totals)],
        textposition="outside",
    ))
    fig3b.update_layout(**PLOTLY_LAYOUT,
                        title="Constraint Violations per Policy",
                        yaxis_title="# violations")

    # ── fig 4: composite radar + bar ──────────────────────────────────────────
    def normalize(vals):
        lo, hi = min(vals), max(vals)
        return [(v - lo) / (hi - lo + 1e-9) for v in vals]

    # Min-max scale the mean SPARC across policies and flip it, so the
    # smoothest policy (least negative mean) scores closest to 1.
    sparc_norm = [1 - v for v in
                  normalize([-np.mean(metrics[n]["sparc"]) for n in names])]

    composite = {}
    for i, name in enumerate(names):
        m = metrics[name]
        composite[name] = (
            0.40 * m["ci_lo"] +
            0.20 * sparc_norm[i] +
            0.25 * m["safe"] +
            0.15 * (m["s"] / m["n"])
        )

    # Plotly renders "<br>" as a line break inside axis labels.
    cats = ["Success<br>(CI lb)", "Smoothness", "Safety<br>(STL)", "Success<br>rate"]
    fig4 = go.Figure()
    for i, name in enumerate(names):
        m = metrics[name]
        vals = [m["ci_lo"], sparc_norm[i], m["safe"], m["s"] / m["n"]]
        vals += vals[:1]  # repeat the first point to close the polygon
        theta = cats + [cats[0]]
        fig4.add_trace(go.Scatterpolar(
            r=vals, theta=theta, fill="toself", name=f"Policy {name}",
            line=dict(color=colors[name], width=2),
            fillcolor=_hex_to_rgba(colors[name], 0.15),
        ))
    fig4.update_layout(
        paper_bgcolor=CARD, font=dict(color=TEXT),
        polar=dict(
            bgcolor=CARD,
            radialaxis=dict(visible=True, range=[0, 1], gridcolor=BORDER, color=SUBTEXT),
            angularaxis=dict(gridcolor=BORDER, color=TEXT),
        ),
        title="④ Composite Radar",
        legend=dict(bgcolor="rgba(0,0,0,0)"),
        margin=dict(l=60, r=60, t=60, b=40),
    )

    cv = [composite[n] for n in names]
    fig4b = go.Figure(go.Bar(
        x=[f"Policy {n}" for n in names], y=cv,
        marker_color=[colors[n] for n in names],
        text=[f"{v:.3f}" for v in cv],
        textposition="outside",
    ))
    winner = names[int(np.argmax(cv))]
    fig4b.update_layout(**PLOTLY_LAYOUT,
                        title=f"④ Final Ranking (winner: Policy {winner})",
                        yaxis_title="Composite score",
                        yaxis_range=[0, max(cv) * 1.3])

    # ── scorecard text ────────────────────────────────────────────────────────
    rows = ["| Metric | " + " | ".join(f"Policy {n}" for n in names) + " |",
            "|" + "---|" * (len(names) + 1)]
    defs = [
        ("Episodes", lambda n: str(metrics[n]["n"])),
        ("Successes", lambda n: f"{metrics[n]['s']}/{metrics[n]['n']} ({metrics[n]['s']/metrics[n]['n']:.0%})"),
        ("95% CI lower", lambda n: f"{metrics[n]['ci_lo']:.1%}"),
        ("Mean SPARC", lambda n: f"{np.mean(metrics[n]['sparc']):.3f}"),
        ("Safe fraction", lambda n: f"{metrics[n]['safe']:.0%}"),
        ("Composite", lambda n: f"**{composite[n]:.3f}**"),
    ]
    for label, fn in defs:
        rows.append("| " + label + " | " + " | ".join(fn(n) for n in names) + " |")
    rows.append(f"\n🏆 **Recommended policy: {winner}**")
    rows.append(f"\nEffort threshold used for STL: `{effort_thresh:.4f}`")
    if skipped:
        rows.append("\nSkipped (no episodes): "
                    + ", ".join(f"Policy {n}" for n in skipped))

    return fig1, fig1b, fig2, fig2b, fig3, fig3b, fig4, fig4b, "\n".join(rows)
# ── demo analysis (ALOHA) ─────────────────────────────────────────────────────
def run_demo(n_A, n_B, n_C, sr_A, sr_B, sr_C, progress=gr.Progress()):
    """Run the full analysis on the ALOHA demo data split into three policies.

    Consecutive slices of the available episodes are assigned to policies A,
    B and C. Success labels are synthetic: each policy gets
    ``round(sr / 100 * n)`` ones, shuffled at random.

    Args:
        n_A, n_B, n_C: requested number of episodes per policy.
        sr_A, sr_B, sr_C: success rates in percent.
        progress: Gradio progress callback.

    Returns:
        The 9-item result of ``run_analysis``; or eight ``None`` placeholders
        plus a warning message when no policy ends up with any episode.
    """
    progress(0, desc="Loading ALOHA dataset from HuggingFace…")
    ep_ids, extracted = load_aloha()

    # UI inputs may arrive as floats; slicing below needs non-negative ints.
    n_A, n_B, n_C = (max(int(v), 0) for v in (n_A, n_B, n_C))

    available = len(ep_ids)
    if n_A + n_B + n_C > available:
        # Over budget: cap A and B at a third each; C takes what is left but
        # never more than was asked for.
        share = available // 3
        n_A = min(n_A, share)
        n_B = min(n_B, share)
        n_C = min(n_C, available - n_A - n_B)

    progress(0.3, desc="Extracting signals…")
    ids_A = ep_ids[:n_A]
    ids_B = ep_ids[n_A:n_A + n_B]
    ids_C = ep_ids[n_A + n_B:n_A + n_B + n_C]

    def make_policy(eids, sr):
        n = len(eids)
        ns = int(round(sr * n))
        t = [1] * ns + [0] * (n - ns)
        np.random.shuffle(t)
        return {
            "trials": t,
            "speeds": [extracted[ei]["speed"] for ei in eids],
            "efforts": [extracted[ei]["effort"] for ei in eids],
            "zs": [extracted[ei]["z"] for ei in eids],
        }

    policy_data = {
        "A": make_policy(ids_A, sr_A / 100),
        "B": make_policy(ids_B, sr_B / 100),
        "C": make_policy(ids_C, sr_C / 100),
    }
    # A policy with zero episodes has no metrics to compute; leave it out.
    policy_data = {name: d for name, d in policy_data.items() if d["trials"]}
    if not policy_data:
        return [None] * 8 + ["⚠️ Select at least one episode for a policy."]

    progress(0.6, desc="Running Bayesian + SPARC + STL analysis…")
    results = run_analysis(policy_data)
    progress(1.0, desc="Done!")
    return results
# ── upload analysis ────────────────────────────────────────────────────────────
def _episodes_to_policy(frames, state_cols, action_cols):
    """Convert one policy's frame table into a {trials, speeds, efforts, zs} dict."""
    policy = {"trials": [], "speeds": [], "efforts": [], "zs": []}
    for _, grp in frames.groupby("episode_id"):
        # Speed is a frame-to-frame difference, so a single-frame episode
        # yields an empty profile that cannot be scored.
        if len(grp) < 2:
            continue
        states = grp[state_cols].values
        # Without action columns the state doubles as the action.
        actions = grp[action_cols].values if action_cols else states
        speed, effort, z = extract_episode(states, actions)
        # The success flag is read from the episode's last frame.
        policy["trials"].append(int(grp["success"].iloc[-1]))
        policy["speeds"].append(speed)
        policy["efforts"].append(effort)
        policy["zs"].append(z)
    return policy


def run_upload(file):
    """Analyse an uploaded CSV of robot episodes.

    Expected columns: ``episode_id``, ``success``, ``state_0…state_N`` and
    optionally ``action_0…action_N`` and ``policy_name``. With a
    ``policy_name`` column each distinct value becomes its own policy;
    otherwise every episode is assigned to a single policy "A".

    Args:
        file: uploaded file object exposing ``.name``, a path string, or
            ``None``.

    Returns:
        The 9-item result of ``run_analysis``; on invalid input a list of
        eight ``None`` placeholders followed by a warning message.
    """
    if file is None:
        return [None] * 8 + ["⚠️ Please upload a CSV file."]

    path = getattr(file, "name", file)
    try:
        df = pd.read_csv(path)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
        return [None] * 8 + [f"⚠️ Could not read CSV: {err}"]

    required = {"episode_id", "success"}
    state_cols = [c for c in df.columns if c.startswith("state_")]
    action_cols = [c for c in df.columns if c.startswith("action_")]
    if not required.issubset(df.columns):
        return [None]*8 + [f"⚠️ CSV must have columns: episode_id, success, state_0…state_N, action_0…action_N\nFound: {list(df.columns)}"]
    if not state_cols:
        return [None]*8 + ["⚠️ No state columns found (expected state_0, state_1, …)"]
    if len(state_cols) < 2:
        # extract_episode reads the second state column as its z signal.
        return [None] * 8 + ["⚠️ Need at least two state columns (state_0, state_1, …)"]

    if "policy_name" in df.columns:
        policy_data = {
            str(pname): _episodes_to_policy(pdf, state_cols, action_cols)
            for pname, pdf in df.groupby("policy_name")
        }
    else:
        policy_data = {"A": _episodes_to_policy(df, state_cols, action_cols)}

    # Drop policies left without a usable episode.
    policy_data = {name: d for name, d in policy_data.items() if d["trials"]}
    if not policy_data:
        return [None] * 8 + ["⚠️ No episode with at least two frames found."]

    return run_analysis(policy_data)
# ── CSV template + sample downloads ───────────────────────────────────────────
def make_template():
    """Write a small random example CSV and return its path.

    The file holds 3 episodes of 20 frames each, labelled with policies "A",
    "B" and "C" respectively. ``success`` is 1 only on the last frame of an
    episode, and there are seven ``state_i`` / ``action_i`` column pairs
    filled with Gaussian noise scaled by 0.5 and rounded to 4 decimals.
    """
    policy_labels = ["A", "B", "C"]
    records = []
    for ep, label in enumerate(policy_labels):
        for frame in range(20):
            record = {
                "episode_id": ep,
                "policy_name": label,
                "success": int(frame == 19),
            }
            for joint in range(7):
                record[f"state_{joint}"] = round(np.random.randn() * 0.5, 4)
                record[f"action_{joint}"] = round(np.random.randn() * 0.5, 4)
            records.append(record)

    path = "/tmp/robot_eval_template.csv"
    pd.DataFrame(records).to_csv(path, index=False)
    return path
# Menu label → (dataset repo id, state column, action column, max episodes).
# Every entry reads the same two columns; only the repo and the cap differ.
SAMPLE_DATASETS = {
    label: (repo_id, "observation.state", "action", max_episodes)
    for label, repo_id, max_episodes in (
        ("ALOHA bimanual — cup opening (14-DOF)",
         "lerobot/aloha_static_cups_open", 20),
        ("Push-T real robot — tabletop push (8-DOF)",
         "lerobot/columbia_cairlab_pusht_real", 20),
        ("Franka Panda — free-play manipulation (13-DOF)",
         "lerobot/nyu_franka_play_dataset", 20),
        ("Unitree H1 humanoid — warehouse (19-DOF / 40-DOF action)",
         "lerobot/unitreeh1_warehouse", 12),
    )
}
def download_sample(choice, progress=gr.Progress()):
    """Export the first episodes of a sample dataset as a CSV file.

    Args:
        choice: a key of ``SAMPLE_DATASETS``; a falsy value returns ``None``.
        progress: Gradio progress callback.

    Returns:
        Path of the written CSV. Each row is one frame with ``episode_id``,
        ``policy_name`` (the part of ``choice`` before the dash),
        ``frame_id``, ``success`` and ``state_i`` / ``action_i`` columns
        rounded to 6 decimals.
    """
    if not choice:
        return None

    progress(0.1, desc=f"Loading {choice}…")
    repo_id, state_key, action_key, episode_cap = SAMPLE_DATASETS[choice]
    frames = load_dataset(repo_id, split="train").to_pandas()
    episode_ids = sorted(frames["episode_index"].unique())[:episode_cap]
    label = choice.split("—")[0].strip()

    progress(0.4, desc="Extracting episodes…")
    records = []
    for ep in episode_ids:
        ep_frames = frames[frames["episode_index"] == ep].reset_index(drop=True)

        # An episode counts as successful if any reward is positive; datasets
        # without a reward column are marked successful.
        if "next.reward" in ep_frames.columns:
            success = int(ep_frames["next.reward"].max() > 0)
        else:
            success = 1

        state_rows = np.vstack(ep_frames[state_key].values)
        action_rows = np.vstack(ep_frames[action_key].values)
        for frame_idx, (state, action) in enumerate(zip(state_rows, action_rows)):
            record = {
                "episode_id": int(ep),
                "policy_name": label,
                "frame_id": frame_idx,
                "success": success,
            }
            record.update(
                (f"state_{i}", round(float(v), 6)) for i, v in enumerate(state)
            )
            record.update(
                (f"action_{i}", round(float(v), 6)) for i, v in enumerate(action)
            )
            records.append(record)

    out_path = f"/tmp/sample_{repo_id.split('/')[-1]}.csv"
    pd.DataFrame(records).to_csv(out_path, index=False)
    progress(1.0, desc="Ready!")
    return out_path
# ── UI ────────────────────────────────────────────────────────────────────────
# Dark-theme stylesheet. The palette constants are interpolated once into CSS
# custom properties on :root and the rules below refer to them via var(...).
# Doubled braces are literal CSS braces inside the f-string.
# NOTE(review): presumably passed to the Gradio app as custom CSS — the
# consumer is outside this view.
CSS = f"""
:root {{
--bg: {BG}; --card: {CARD}; --border: {BORDER};
--text: {TEXT}; --sub: {SUBTEXT}; --accent: {ACCENT};
}}
body, .gradio-container {{ background: var(--bg) !important; color: var(--text) !important; }}
.gr-box, .gr-panel {{ background: var(--card) !important; border-color: var(--border) !important; }}
.gr-button-primary {{ background: var(--accent) !important; color: #0F172A !important; font-weight: 700; }}
.gr-button {{ border-color: var(--border) !important; color: var(--text) !important; }}
footer {{ display: none !important; }}
h1, h2, h3 {{ color: var(--text) !important; }}
label {{ color: var(--sub) !important; }}
.tab-nav button {{ color: var(--sub) !important; }}
.tab-nav button.selected {{ color: var(--accent) !important; border-color: var(--accent) !important; }}
"""
HEADER = """
Bayesian statistics · SPARC smoothness · STL safety constraints
Based on Kress-Gazit et al. (TRI/Cornell) · arXiv:2409.09491