Shubham-Rasal Claude Sonnet 4.6 committed on
Commit
654b826
·
1 Parent(s): 317e86c

Rewrite PhAIL loader: static.json-driven episode discovery

Browse files

Path structure is sample/inference/<batch>/<episode>/<signal>.parquet
not sample/inference/<model>/... — model name lives in static.json.

New approach:
- Find all static.json files under sample/inference/
- Load each for model name + outcome (success label)
- Load robot_state.q.parquet (7-DOF joint positions) as states
- Load robot_commands.pose.parquet as actions (fallback to states)
- Build policy_data keyed by actual model name from metadata

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +55 -82
app.py CHANGED
@@ -90,42 +90,15 @@ PHAIL_POLICIES = {
90
  "smolvla": "SmolVLA",
91
  }
92
 
93
- def _pick_joint_cols(ep):
94
- """Return (state_cols, action_cols) from a PhAIL episode dataframe."""
95
- cols = ep.columns.tolist()
96
- # Print once for debugging (visible in Space logs)
97
- print(f"[PhAIL] columns ({len(cols)}): {cols[:20]}")
98
-
99
- keywords_state = ["joint_pos", "q_pos", "position", "state", "obs"]
100
- keywords_action = ["joint_cmd", "q_cmd", "command", "action", "target"]
101
-
102
- def match(kws):
103
- return [c for c in cols if any(k in c.lower() for k in kws)]
104
-
105
- sc = match(keywords_state)
106
- ac = match(keywords_action)
107
-
108
- # Avoid overlap: if both lists share columns, prefer more specific
109
- if sc and ac and set(sc) & set(ac):
110
- ac = [c for c in ac if c not in sc]
111
-
112
- # Absolute fallback: split numeric columns in half
113
- if not sc:
114
- num = ep.select_dtypes(include=[np.number]).columns.tolist()
115
- mid = max(1, len(num) // 2)
116
- sc = num[:mid]
117
- ac = num[mid:] if not ac else ac
118
-
119
- if not ac:
120
- ac = sc # use state as action proxy (effort will be ~0)
121
-
122
- return sc, ac
123
-
124
-
125
  def load_phail_sample(progress=None):
126
  """
127
- Download the 20-episode stratified sample from phail-anon/phail-v1.0.
128
- Returns policy_data dict ready for run_analysis().
 
 
 
 
 
129
  """
130
  if "phail" in _cache:
131
  return _cache["phail"]
@@ -134,82 +107,82 @@ def load_phail_sample(progress=None):
134
  if progress is not None:
135
  progress(frac, desc=desc)
136
 
137
- _prog(0.05, "Listing PhAIL sample files on HuggingFace Hub…")
138
 
139
  all_files = list(list_repo_files("phail-anon/phail-v1.0", repo_type="dataset"))
140
- sample_parquets = sorted([f for f in all_files
141
- if f.startswith("sample/inference/") and f.endswith(".parquet")])
142
 
143
- print(f"[PhAIL] found {len(sample_parquets)} sample parquets")
144
- if not sample_parquets:
145
- raise ValueError("No parquet files found under sample/inference/ in phail-anon/phail-v1.0. "
146
- f"All files: {all_files[:30]}")
 
 
 
147
 
148
- _prog(0.15, f"Found {len(sample_parquets)} episodes — downloading…")
149
 
150
- policy_data = {label: {"trials": [], "speeds": [], "efforts": [], "zs": []}
151
- for label in PHAIL_POLICIES.values()}
152
 
153
- col_schema_logged = False
 
154
 
155
- for i, fpath in enumerate(sample_parquets):
156
- _prog(0.15 + 0.7 * (i / len(sample_parquets)),
157
- f"Episode {i+1}/{len(sample_parquets)}…")
158
 
159
- parts = fpath.split("/")
160
- # path: sample/inference/<model>/batch_X/episode_Y/something.parquet
161
- model_key = parts[2] if len(parts) > 2 else None
162
- label = PHAIL_POLICIES.get(model_key)
163
- if label is None:
164
- print(f"[PhAIL] skip {fpath} — model_key={model_key!r} not in PHAIL_POLICIES")
 
 
165
  continue
166
 
 
 
 
 
 
 
 
 
167
  try:
168
- local = hf_hub_download(repo_id="phail-anon/phail-v1.0",
169
- filename=fpath, repo_type="dataset")
170
- ep = pd.read_parquet(local)
171
  except Exception as exc:
172
- print(f"[PhAIL] failed to load {fpath}: {exc}")
173
  continue
174
 
175
- if not col_schema_logged:
176
- print(f"[PhAIL] full column list: {ep.columns.tolist()}")
177
- col_schema_logged = True
 
 
 
 
 
178
 
179
- sc, ac = _pick_joint_cols(ep)
180
- states = ep[sc].values.astype(float)
181
- actions = ep[ac].values.astype(float)
182
 
183
  if len(states) < 4:
184
- print(f"[PhAIL] skip {fpath} — only {len(states)} rows")
185
  continue
186
 
187
  speed, effort, z = extract_episode(states, actions)
188
 
189
- # Load success label from static.json in the same episode directory
190
- success = 0
191
- try:
192
- static_path = "/".join(parts[:-1]) + "/static.json"
193
- meta_local = hf_hub_download(repo_id="phail-anon/phail-v1.0",
194
- filename=static_path, repo_type="dataset")
195
- with open(meta_local) as f:
196
- meta = json.load(f)
197
- outcome = meta.get("eval", {}).get("outcome", "")
198
- success = 1 if outcome == "Success" else 0
199
- except Exception as exc:
200
- print(f"[PhAIL] no static.json for {fpath}: {exc}")
201
-
202
  policy_data[label]["trials"].append(success)
203
  policy_data[label]["speeds"].append(speed)
204
  policy_data[label]["efforts"].append(effort)
205
  policy_data[label]["zs"].append(z)
206
 
207
- policy_data = {k: v for k, v in policy_data.items() if v["trials"]}
208
- print(f"[PhAIL] loaded policies: { {k: len(v['trials']) for k, v in policy_data.items()} }")
209
 
210
  if not policy_data:
211
- raise ValueError(
212
- "PhAIL: no episodes loaded. Check Space logs for column names and path structure.")
213
 
214
  _prog(0.95, "Finalising…")
215
  _cache["phail"] = policy_data
 
90
  "smolvla": "SmolVLA",
91
  }
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  def load_phail_sample(progress=None):
94
  """
95
+ Load PhAIL sample — path structure:
96
+ sample/inference/<batch_id>/<episode_id>/<signal>.parquet
97
+ sample/inference/<batch_id>/<episode_id>/static.json ← model name + outcome
98
+
99
+ Per-episode signals used:
100
+ robot_state.q.parquet — 7-DOF Franka joint positions (states)
101
+ robot_commands.pose.parquet — commanded EE pose (actions proxy)
102
  """
103
  if "phail" in _cache:
104
  return _cache["phail"]
 
107
  if progress is not None:
108
  progress(frac, desc=desc)
109
 
110
+ _prog(0.05, "Listing PhAIL sample files…")
111
 
112
  all_files = list(list_repo_files("phail-anon/phail-v1.0", repo_type="dataset"))
 
 
113
 
114
+ # Collect episode dirs that have static.json
115
+ static_files = sorted([f for f in all_files
116
+ if f.startswith("sample/inference/") and f.endswith("/static.json")])
117
+ print(f"[PhAIL] found {len(static_files)} episodes (static.json)")
118
+
119
+ if not static_files:
120
+ raise ValueError("No static.json files found under sample/inference/")
121
 
122
+ _prog(0.1, f"Found {len(static_files)} episodes — loading…")
123
 
124
+ policy_data = {} # built dynamically from actual model names in static.json
 
125
 
126
+ for i, sf in enumerate(static_files):
127
+ _prog(0.1 + 0.8 * (i / len(static_files)), f"Episode {i+1}/{len(static_files)}…")
128
 
129
+ ep_dir = sf[: -len("/static.json")] # e.g. sample/inference/000.../000... (no trailing slash)
 
 
130
 
131
+ # ── load metadata ──────────────────────────────────────────────────────
132
+ try:
133
+ meta_local = hf_hub_download(repo_id="phail-anon/phail-v1.0",
134
+ filename=sf, repo_type="dataset")
135
+ with open(meta_local) as f:
136
+ meta = json.load(f)
137
+ except Exception as exc:
138
+ print(f"[PhAIL] skip {sf}: {exc}")
139
  continue
140
 
141
+ model = meta.get("model", meta.get("source", "unknown"))
142
+ outcome = meta.get("eval", {}).get("outcome", meta.get("outcome", ""))
143
+ success = 1 if outcome == "Success" else 0
144
+ # Map model key → display label
145
+ label = PHAIL_POLICIES.get(model, model)
146
+
147
+ # ── load joint positions (states) ──────────────────────────────────────
148
+ q_path = ep_dir + "/robot_state.q.parquet"
149
  try:
150
+ q_local = hf_hub_download(repo_id="phail-anon/phail-v1.0",
151
+ filename=q_path, repo_type="dataset")
152
+ q_df = pd.read_parquet(q_local)
153
  except Exception as exc:
154
+ print(f"[PhAIL] no robot_state.q for {ep_dir}: {exc}")
155
  continue
156
 
157
+ # ── load commands (actions proxy) ──────────────────────────────────────
158
+ cmd_path = ep_dir + "/robot_commands.pose.parquet"
159
+ try:
160
+ cmd_local = hf_hub_download(repo_id="phail-anon/phail-v1.0",
161
+ filename=cmd_path, repo_type="dataset")
162
+ cmd_df = pd.read_parquet(cmd_local)
163
+ except Exception:
164
+ cmd_df = q_df # fall back to state = action (effort ≈ 0)
165
 
166
+ states = q_df.select_dtypes(include=[np.number]).values.astype(float)
167
+ actions = cmd_df.select_dtypes(include=[np.number]).values.astype(float)
 
168
 
169
  if len(states) < 4:
170
+ print(f"[PhAIL] skip {ep_dir} — only {len(states)} rows")
171
  continue
172
 
173
  speed, effort, z = extract_episode(states, actions)
174
 
175
+ if label not in policy_data:
176
+ policy_data[label] = {"trials": [], "speeds": [], "efforts": [], "zs": []}
 
 
 
 
 
 
 
 
 
 
 
177
  policy_data[label]["trials"].append(success)
178
  policy_data[label]["speeds"].append(speed)
179
  policy_data[label]["efforts"].append(effort)
180
  policy_data[label]["zs"].append(z)
181
 
182
+ print(f"[PhAIL] loaded: { {k: len(v['trials']) for k, v in policy_data.items()} }")
 
183
 
184
  if not policy_data:
185
+ raise ValueError("PhAIL: no episodes loaded — check logs for path/schema details.")
 
186
 
187
  _prog(0.95, "Finalising…")
188
  _cache["phail"] = policy_data