MSG committed on
Commit
871f869
·
1 Parent(s): 8c6b423

Feat/last sprint (#12)

* hf docker build

* gradio sdk deploy for zerogpu mode

* gradio structure wip

* server gpu and tasks

* fix sdk mode

* wip studio fix gradio

* fix studio

* index fix ui

* ui coach

* ui voice

* index ui

* cleaning stuff

* clean

* clean stuff experiment

* ui css app

* fix css

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .cursor/plans/gradio_sdk_deploy_58daaf6e.plan.md +268 -0
  2. .cursor/plans/hf_space_publish_e8a57bab.plan.md +208 -0
  3. .env.example +0 -10
  4. .gitignore +1 -0
  5. Dockerfile +2 -0
  6. README.md +13 -14
  7. USAGE.md +86 -68
  8. app.py +6 -0
  9. apps/gradio-space/src/gradio_space/model_loading.py +3 -0
  10. apps/gradio-space/src/gradio_space/research_helpers.py +2 -0
  11. apps/gradio-space/src/gradio_space/server.py +3 -1
  12. apps/gradio-space/src/gradio_space/spaces_runtime.py +37 -0
  13. apps/gradio-space/src/gradio_space/tabs/echo_coach.py +2 -0
  14. apps/gradio-space/src/gradio_space/tabs/education_pptx.py +3 -0
  15. apps/gradio-space/src/gradio_space/tabs/research_mind.py +4 -0
  16. apps/gradio-space/src/gradio_space/tabs/teacher_voice.py +3 -0
  17. apps/gradio-space/static/studio/index.html +111 -82
  18. apps/gradio-space/static/studio/studio.css +283 -3
  19. models.yaml +0 -6
  20. packages.txt +2 -0
  21. pyproject.toml +0 -8
  22. requirements.txt +32 -0
  23. research/README.md +6 -10
  24. research/USAGE.md +8 -99
  25. research/docs/overview.md +11 -49
  26. research/ensemble/README.md +0 -113
  27. research/ensemble/pyproject.toml +0 -16
  28. research/ensemble/scripts/smoke.sh +0 -35
  29. research/ensemble/src/ensemble/__init__.py +0 -15
  30. research/ensemble/src/ensemble/backends.py +0 -418
  31. research/ensemble/src/ensemble/bridge.py +0 -28
  32. research/ensemble/src/ensemble/checkpoint.py +0 -149
  33. research/ensemble/src/ensemble/config.py +0 -163
  34. research/ensemble/src/ensemble/energy.py +0 -45
  35. research/ensemble/src/ensemble/eval/__init__.py +0 -1
  36. research/ensemble/src/ensemble/eval/jepa_harness.py +0 -266
  37. research/ensemble/src/ensemble/eval/metrics.py +0 -42
  38. research/ensemble/src/ensemble/eval/world_harness.py +0 -174
  39. research/ensemble/src/ensemble/eval_harness.py +0 -309
  40. research/ensemble/src/ensemble/jepa.py +0 -75
  41. research/ensemble/src/ensemble/jepa_ensemble.py +0 -232
  42. research/ensemble/src/ensemble/llm_emb_jepa_ensemble_pluggable.py +0 -507
  43. research/ensemble/src/ensemble/memory.py +0 -46
  44. research/ensemble/src/ensemble/pretrain.py +0 -198
  45. research/ensemble/src/ensemble/world_ensemble.py +0 -228
  46. research/ensemble/src/ensemble/world_model.py +0 -40
  47. research/ensemble/src/ensemble/world_model_ensemble.py +0 -499
  48. research/eval_harness.py +0 -6
  49. research/evals/USAGE.md +2 -14
  50. research/evals/configs/ensemble_jepa_lesson.yaml +0 -24
.cursor/plans/gradio_sdk_deploy_58daaf6e.plan.md ADDED
@@ -0,0 +1,268 @@
+ ---
+ name: Gradio SDK Deploy
+ overview: "Add Gradio SDK deployment files on main alongside the existing Dockerfile, switch README to `sdk: gradio` for ZeroGPU Spaces, and add `@spaces.GPU` wrappers on LLM entry points so the full Studio + Classic app runs on HF without Docker."
+ todos:
+   - id: root-gradio-files
+     content: Add root app.py, requirements.txt, packages.txt with editable workspace installs and Debian deps
+     status: completed
+   - id: readme-gradio-sdk
+     content: "Fix README YAML frontmatter and switch to sdk: gradio (sdk_version 6.16.0, app_file: app.py)"
+     status: completed
+   - id: spaces-runtime
+     content: Add gradio_space/spaces_runtime.py with gpu_task decorator and is_hf_gradio_runtime()
+     status: completed
+   - id: zerogpu-decorators
+     content: Apply @gpu_task to LLM entry points in model_loading, research_helpers, and tab handlers; skip preload on HF Gradio runtime in server.py
+     status: completed
+   - id: usage-docs
+     content: Update USAGE.md with Gradio SDK + ZeroGPU deploy steps; demote Docker section to later phase
+     status: completed
+   - id: local-smoke
+     content: Validate pip install + python app.py locally before pushing to HF Space
+     status: completed
+   - id: hf-space-create
+     content: Create Gradio Space under build-small-hackathon with ZeroGPU hardware and env vars; verify Studio + Classic smoke tests
+     status: cancelled
+ isProject: false
+ ---
+
+ # Gradio SDK + ZeroGPU deployment (same branch as Docker)
+
+ ## Goal
+
+ Ship the **full app** (Studio at `/`, Classic at `/classic`, all tabs) via **Gradio SDK** on Hugging Face with **ZeroGPU**, while keeping [`Dockerfile`](Dockerfile) on `main` untouched for a later Docker Space phase.
+
+ ## Same-branch constraint (important)
+
+ HF reads **one** `sdk:` value from root [`README.md`](README.md). Both deploy paths can live on the same branch as files, but **only one SDK is active per branch at a time**:
+
+ | Files on `main` | Active when README says |
+ |-----------------|-------------------------|
+ | `app.py`, `requirements.txt`, `packages.txt` | `sdk: gradio` |
+ | [`Dockerfile`](Dockerfile) | `sdk: docker` + `app_port: 7860` |
+
+ **Now (Phases 1–5):** set `sdk: gradio` — Gradio Space builds from `app.py`.
+ **Later (Phase 6):** flip README to `sdk: docker` for Docker Space, or use a **second HF Space on a second branch** if you need both live at once.
+
+ ```mermaid
+ flowchart TB
+   subgraph repo [main branch]
+     AppPy[app.py]
+     ReqTxt[requirements.txt]
+     DockerFile[Dockerfile]
+     Shared[apps/gradio-space + libs + skills]
+   end
+   subgraph phase1 [Active now]
+     ReadmeG[sdk: gradio in README]
+     HFGradio[HF Gradio SDK build]
+     ZeroGPU[ZeroGPU hardware]
+   end
+   subgraph phase2 [Docker later]
+     ReadmeD[sdk: docker in README]
+     HFDocker[HF Docker build]
+     GPUBasic[GPU Basic hardware]
+   end
+   Shared --> AppPy
+   Shared --> DockerFile
+   ReadmeG --> HFGradio --> ZeroGPU
+   ReadmeD --> HFDocker --> GPUBasic
+ ```
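
The same-branch rule above comes down to a single `sdk:` key in the README frontmatter. As a rough illustration only, a pre-push check could report which build path HF would pick. The `active_sdk` helper below is hypothetical (it is not part of this commit) and uses a naive line scan rather than a real YAML parser:

```python
from typing import Optional


def active_sdk(readme_text: str) -> Optional[str]:
    """Return the `sdk:` value from a README's YAML frontmatter, or None.

    Hypothetical helper: a naive line scan, not a full YAML parser.
    """
    lines = readme_text.splitlines()
    if not lines or lines[0].strip() != "---":
        return None  # frontmatter must start on the very first line
    for line in lines[1:]:
        stripped = line.strip()
        if stripped == "---":
            break  # end of frontmatter
        key, sep, value = stripped.partition(":")
        if sep and key.strip() == "sdk":
            return value.strip().strip("\"'")
    return None


gradio_readme = "---\ntitle: Lesson Agent\nsdk: gradio\napp_file: app.py\n---\n# Lesson Agent\n"
docker_readme = "---\ntitle: Lesson Agent\nsdk: docker\napp_port: 7860\n---\n"

print(active_sdk(gradio_readme))
print(active_sdk(docker_readme))
```

A check like this could run before a push to confirm that the README matches the deploy path you intend to trigger.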
+
+ ---
+
+ ## Phase 1 — Add Gradio SDK root files
+
+ ### 1. Root [`app.py`](app.py)
+
+ Thin entry point that reuses the existing server (no UI rewrite):
+
+ ```python
+ from gradio_space.server import main
+
+ if __name__ == "__main__":
+     main()
+ ```
+
+ HF Gradio SDK executes `app.py`; [`server.py`](apps/gradio-space/src/gradio_space/server.py) already calls `server.launch()` on port 7860 with Studio + `/classic`.
+
+ ### 2. Root [`requirements.txt`](requirements.txt)
+
+ Pip-install workspace packages via editable paths (HF clones the full repo):
+
+ ```text
+ -e ./libs/inference
+ -e ./libs/researchmind
+ -e ./libs/agent
+ -e ./libs/echocoach[piper,whisper]
+ -e ./apps/gradio-space
+ # plus transitive deps from libs/*/pyproject.toml (torch, transformers, sentence-transformers, python-pptx, etc.)
+ ```
+
+ Rules (per [HF Spaces dependencies](https://huggingface.co/docs/hub/spaces-dependencies)):
+
+ - **Do not pin** `gradio`, `spaces`, or `huggingface_hub` — HF preinstalls them.
+ - Pin heavy libs that matter for reproducibility: `torch`, `transformers`, `accelerate`, `sentence-transformers`, etc.
+ - Keep `llama-cpp-python` for preset parity (HF image has `cmake`; build may be slow).
+
+ Optional: add [`scripts/sync-requirements.sh`](scripts/sync-requirements.sh) later to regenerate from `pyproject.toml` files — not required for v1.
+
+ ### 3. Root [`packages.txt`](packages.txt)
+
+ Debian deps beyond HF defaults (mirror [`Dockerfile`](Dockerfile) apt lines):
+
+ ```text
+ ffmpeg
+ libsndfile1
+ ```
+
+ ### 4. Fix + switch README frontmatter
+
+ Current [`README.md`](README.md) has a blank line after `---` and still declares Docker. Update to:
+
+ ```yaml
+ ---
+ title: Lesson Agent
+ emoji: 📚
+ colorFrom: blue
+ colorTo: green
+ sdk: gradio
+ sdk_version: "6.16.0"
+ app_file: app.py
+ python_version: "3.12"
+ pinned: false
+ license: apache-2.0
+ ---
+ ```
+
+ Remove `app_port` (Docker-only). Keep [`Dockerfile`](Dockerfile) in repo for the later Docker phase (Phase 6).
+
+ ---
+
+ ## Phase 2 — ZeroGPU runtime hooks
+
+ ZeroGPU requires all CUDA work inside `@spaces.GPU`. The decorator is a **no-op** locally and on dedicated GPU Spaces, so it is safe to apply everywhere.
+
+ ### New module: [`apps/gradio-space/src/gradio_space/spaces_runtime.py`](apps/gradio-space/src/gradio_space/spaces_runtime.py)
+
+ ```python
+ def gpu_task(*, duration: int = 180, size: str = "large"):
+     """Apply @spaces.GPU when the HF spaces runtime is present."""
+     ...
+
+ def is_hf_gradio_runtime() -> bool:
+     """True on HF Gradio SDK Spaces (skip startup model preload)."""
+     ...
+ ```
+
+ Use `duration=180`–`300` for agent/slide flows; `duration=60` for simple chat.
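
The stubs above leave the bodies open. One possible shape is sketched here under stated assumptions; the committed `spaces_runtime.py` is not shown in this view, so this is not its actual contents. It assumes that `spaces` is importable only where HF preinstalls it, and that a Space sets the `SPACE_ID` environment variable. The `size` argument from the stub is omitted, and `chat` is a placeholder handler:

```python
import os
from typing import Callable, Mapping, Optional

try:
    import spaces  # preinstalled on HF Gradio SDK Spaces; usually absent locally
except ImportError:
    spaces = None


def gpu_task(*, duration: int = 180) -> Callable[[Callable], Callable]:
    """Return a decorator that applies spaces.GPU when `spaces` is importable."""

    def decorate(fn: Callable) -> Callable:
        if spaces is None:
            return fn  # local run without the package: leave the function untouched
        return spaces.GPU(duration=duration)(fn)

    return decorate


def is_hf_gradio_runtime(env: Optional[Mapping[str, str]] = None) -> bool:
    """Assumption: a Space sets SPACE_ID and the Gradio SDK image ships `spaces`."""
    env = os.environ if env is None else env
    return spaces is not None and bool(env.get("SPACE_ID"))


@gpu_task(duration=60)
def chat(prompt: str) -> str:
    # Placeholder body; the real handler would call the model backend.
    return f"echo: {prompt}"


print(chat("hello"))
```

Keeping the import guard in one module means the tab handlers can use `@gpu_task(...)` unconditionally, whether or not the `spaces` package is installed.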
+
+ ### Skip startup preload on HF Gradio runtime
+
+ [`server.py`](apps/gradio-space/src/gradio_space/server.py) currently calls `preload_active_model()` before launch — this fails on ZeroGPU (no GPU at process start):
+
+ ```69:69:apps/gradio-space/src/gradio_space/server.py
+ preload_active_model()
+ ```
+
+ Change to:
+
+ ```python
+ if not is_hf_gradio_runtime():
+     preload_active_model()
+ ```
+
+ First user request lazy-loads inside a `@spaces.GPU`-decorated handler.
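
Lazy loading on first request can be sketched as a guarded, load-once cache. The names below (`get_model`, `fake_loader`) are illustrative assumptions; the repository's real loader in `model_loading.py` is not shown in this diff and may be structured differently:

```python
import threading
from typing import Callable, Optional

_model: Optional[object] = None
_model_lock = threading.Lock()


def get_model(loader: Callable[[], object]) -> object:
    """Load the model on first use and reuse it afterwards."""
    global _model
    if _model is None:
        with _model_lock:
            if _model is None:  # re-check after acquiring the lock
                _model = loader()
    return _model


load_calls = []


def fake_loader() -> object:
    # Stand-in for the real (slow) weight loading.
    load_calls.append(1)
    return {"name": "minicpm5-1b"}


first = get_model(fake_loader)
second = get_model(fake_loader)
print(first is second, len(load_calls))
```

In the app, the call to `get_model` would sit inside a `@gpu_task`-decorated handler, so the first load happens while a GPU is allocated.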
+
+ ### Decorate LLM entry points (not every `backend.chat` call)
+
+ Wrap **top-level handlers** so multi-step agent loops run inside one GPU allocation:
+
+ | Module | Functions to decorate |
+ |--------|----------------------|
+ | [`model_loading.py`](apps/gradio-space/src/gradio_space/model_loading.py) | `chat`, `reload_model` |
+ | [`research_helpers.py`](apps/gradio-space/src/gradio_space/research_helpers.py) | `run_research_question`, `rag_aware_chat` |
+ | [`tabs/education_pptx.py`](apps/gradio-space/src/gradio_space/tabs/education_pptx.py) | `generate_lesson_slides`, `discover_lesson_sources` |
+ | [`tabs/research_mind.py`](apps/gradio-space/src/gradio_space/tabs/research_mind.py) | `discover_sources`, `ask_question`, `auto_search_ingest` |
+ | [`tabs/echo_coach.py`](apps/gradio-space/src/gradio_space/tabs/echo_coach.py) | `analyze_pitch` |
+ | [`tabs/teacher_voice.py`](apps/gradio-space/src/gradio_space/tabs/teacher_voice.py) | text/audio turn handlers |
+
+ Studio APIs in [`api/studio.py`](apps/gradio-space/src/gradio_space/api/studio.py) call these helpers — decorating the tab/helper layer avoids duplicating decorators on ~20 API wrappers.
+
+ **Generator caveat:** `generate_lesson_slides` uses `yield` for progress. If ZeroGPU rejects generator handlers, extract GPU work into a plain `@gpu_task` function and keep the outer generator for UI progress only (test on Space Logs after first deploy).
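
The fallback described in the generator caveat can be sketched as follows. The signatures and the `build_outline` helper are hypothetical; the real `generate_lesson_slides` takes different arguments and does real model work. In the app, `build_outline` is the function that would carry `@gpu_task`:

```python
from typing import Iterator, List


def build_outline(topic: str, n_slides: int) -> List[str]:
    """Plain (non-generator) function: the part that would carry @gpu_task."""
    return [f"{topic}: slide {i}" for i in range(1, n_slides + 1)]


def generate_lesson_slides(topic: str, n_slides: int) -> Iterator[str]:
    """Outer generator: yields progress text only, no GPU work of its own."""
    yield "Generating outline..."
    outline = build_outline(topic, n_slides)
    yield f"Outline ready ({len(outline)} slides)"
    for title in outline:
        yield f"Rendered {title}"


for message in generate_lesson_slides("Photosynthesis", 2):
    print(message)
```

The trade-off is coarser progress reporting: the UI only updates before and after the decorated call, not during it.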
+
+ **Embeddings (ResearchMind ingest):** sentence-transformers can stay on CPU for v1; only LLM paths need `@spaces.GPU` initially.
+
+ ---
+
+ ## Phase 3 — Space configuration
+
+ Create Space under [build-small-hackathon](https://huggingface.co/build-small-hackathon):
+
+ | Setting | Value |
+ |---------|-------|
+ | SDK | **Gradio** (Blank template) |
+ | Hardware | **ZeroGPU** (creator needs PRO/Team) |
+ | Repo | GitHub `main` (or push to Space git) |
+
+ **Environment variables** (Settings → Variables):
+
+ | Variable | Value |
+ |----------|-------|
+ | `ACTIVE_MODEL` | `minicpm5-1b` |
+ | `ALLOW_MODEL_SWITCH` | `false` |
+ | `RESEARCHMIND_DATA_DIR` | `/tmp/researchmind` |
+
+ Default preset in [`models.yaml`](models.yaml) is already `minicpm5-1b` (transformers) — good fit for ZeroGPU.
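
For illustration, the three variables in the table could be read with defaults as below. `SpaceSettings` and `load_settings` are hypothetical names, and the accepted truthy spellings are an assumption; how the app actually parses these variables is not shown in this diff:

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class SpaceSettings:
    active_model: str
    allow_model_switch: bool
    researchmind_data_dir: str


def load_settings(env: Optional[Mapping[str, str]] = None) -> SpaceSettings:
    """Read the three Space variables, falling back to the values in the table."""
    env = os.environ if env is None else env
    switch = env.get("ALLOW_MODEL_SWITCH", "false").strip().lower()
    return SpaceSettings(
        active_model=env.get("ACTIVE_MODEL", "minicpm5-1b"),
        allow_model_switch=switch in {"1", "true", "yes"},
        researchmind_data_dir=env.get("RESEARCHMIND_DATA_DIR", "/tmp/researchmind"),
    )


settings = load_settings({"ACTIVE_MODEL": "minicpm5-1b", "ALLOW_MODEL_SWITCH": "false"})
print(settings)
```

Passing the environment as a mapping keeps the parsing testable without touching `os.environ`.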
+
+ ---
+
+ ## Phase 4 — Docs and local smoke test
+
+ Update [`USAGE.md`](USAGE.md):
+
+ - New **Gradio SDK deployment** section (primary path): `app.py`, `requirements.txt`, ZeroGPU, env vars.
+ - Move existing Docker section to **"Docker SDK (later)"** — note README must switch to `sdk: docker` + `app_port: 7860`.
+ - Local Gradio SDK smoke test:
+
+ ```bash
+ python -m venv .venv && source .venv/bin/activate
+ pip install -r requirements.txt
+ ACTIVE_MODEL=minicpm5-1b ALLOW_MODEL_SWITCH=false python app.py
+ ```
+
+ Keep existing `uv run` workflow for day-to-day dev unchanged.
+
+ Update [`.cursor/plans/hf_space_publish_e8a57bab.plan.md`](.cursor/plans/hf_space_publish_e8a57bab.plan.md) todos to reflect Gradio-first ordering.
+
+ ---
+
+ ## Phase 5 — Verify on Space
+
+ 1. **Logs** — pip install succeeds; app starts on `0.0.0.0:7860`.
+ 2. **`/` Studio** — loads static UI.
+ 3. **`/classic`** — all tabs render.
+ 4. **Smoke flows** — slides generation, research chat, EchoCoach sample clip, teacher voice text turn.
+ 5. **ZeroGPU** — first LLM request allocates GPU (may be slow on cold start); watch for "No CUDA GPUs" (means handler is outside `@spaces.GPU`).
+
+ ---
+
+ ## Phase 6 — Docker later (no code removal)
+
+ When ready for Docker Space:
+
+ 1. Change README to `sdk: docker`, `app_port: 7860` (remove `sdk_version` / `app_file`).
+ 2. Create a **second Space** (or reuse after README flip) with **GPU Basic** hardware.
+ 3. Existing [`Dockerfile`](Dockerfile) + `uv sync` path unchanged; no `@spaces.GPU` needed on dedicated GPU.
+
+ Both file sets remain on `main`; only README `sdk:` toggles which build HF runs.
+
+ ---
+
+ ## Risk notes
+
+ | Risk | Mitigation |
+ |------|------------|
+ | `pip install llama-cpp-python` slow/fails on HF | Accept slow build; default `minicpm5-1b` avoids GGUF at runtime |
+ | EchoCoach deps (piper, whisper) heavy | Full scope requested; pin versions; fix from Space Logs if needed |
+ | ZeroGPU + generator slide progress | Refactor GPU block to non-generator helper if build succeeds but inference fails |
+ | Two live Spaces same branch | Not supported with different SDKs — use README flip or second branch for concurrent Docker + Gradio |
.cursor/plans/hf_space_publish_e8a57bab.plan.md ADDED
@@ -0,0 +1,208 @@
+ ---
+ name: HF Space Publish
+ overview: Fix two repo blockers (README Space card YAML and missing `researchmind` in Dockerfile), validate locally with Docker, push to GitHub, then create a Docker Space under build-small-hackathon linked to GitHub with GPU hardware and MiniCPM5-1B env vars.
+ todos:
+   - id: fix-readme-yaml
+     content: "Fix root README.md frontmatter: change `## title:` to `title:`"
+     status: completed
+   - id: fix-dockerfile-researchmind
+     content: Add libs/researchmind COPY lines to Dockerfile
+     status: completed
+   - id: local-docker-smoke
+     content: Run docker build + docker run locally on port 7860 with ACTIVE_MODEL=minicpm5-1b
+     status: in_progress
+   - id: push-github
+     content: Push fixed branch to GitHub repo
+     status: pending
+   - id: create-space
+     content: Create Docker Space under build-small-hackathon, link GitHub, set GPU basic
+     status: pending
+   - id: configure-env
+     content: "Set Space secrets: ACTIVE_MODEL=minicpm5-1b, ALLOW_MODEL_SWITCH=false, RESEARCHMIND_DATA_DIR=/tmp/researchmind"
+     status: pending
+   - id: verify-live
+     content: Check Space Logs, test / and /classic, confirm slide generation works
+     status: pending
+ isProject: false
+ ---
+
+ # Publish Gradio app to Hugging Face Space
+
+ ## Current state
+
+ Your repo is **mostly ready** for a Docker Space:
+
+ - Root [`Dockerfile`](Dockerfile) exposes port **7860** and runs `python -m gradio_space.app`
+ - Root [`README.md`](README.md) has Space metadata (`sdk: docker`, `app_port: 7860`)
+ - Default model in [`models.yaml`](models.yaml) is **`minicpm5-1b`** (transformers, `openbmb/MiniCPM5-1B`)
+
+ Two issues will likely **break the Space build or card** until fixed:
+
+ ### Blocker 1 — README YAML is malformed
+
+ The Space card frontmatter must use `title:`, not a markdown heading:
+
+ ```yaml
+ # Current (wrong)
+ ## title: Lesson Agent
+
+ # Required (correct)
+ title: Lesson Agent
+ ```
+
+ HF reads YAML from the **root** [`README.md`](README.md) only. Keep [`apps/gradio-space/README.md`](apps/gradio-space/README.md) as dev docs.
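
A line such as `## title: Lesson Agent` starts with `#`, so YAML treats it as a comment and the `title` key is never defined. A small lint along these lines could catch that before a push. `lint_frontmatter` is a hypothetical helper, the required-key list reflects what this plan relies on rather than an official HF requirement, and the scan is deliberately naive:

```python
from typing import List


def lint_frontmatter(readme_text: str) -> List[str]:
    """Return a list of problems found in a README's Space-card frontmatter."""
    lines = readme_text.splitlines()
    if not lines or lines[0].strip() != "---":
        return ["README must start with a '---' line"]
    problems: List[str] = []
    keys = set()
    closed = False
    for number, line in enumerate(lines[1:], start=2):
        stripped = line.strip()
        if stripped == "---":
            closed = True
            break
        if stripped.startswith("#"):
            # YAML reads this as a comment, so any key on the line is lost.
            problems.append(f"line {number}: '#' line is ignored by YAML: {stripped!r}")
            continue
        key, sep, _ = stripped.partition(":")
        if sep:
            keys.add(key.strip())
    if not closed:
        problems.append("frontmatter is never closed with '---'")
    for required in ("title", "sdk"):
        if required not in keys:
            problems.append(f"missing key: {required}")
    return problems


broken = "---\n\n## title: Lesson Agent\nsdk: docker\napp_port: 7860\n\n# Lesson Agent\n"
fixed = "---\ntitle: Lesson Agent\nsdk: docker\napp_port: 7860\n---\n\n# Lesson Agent\n"

for problem in lint_frontmatter(broken):
    print(problem)
print(lint_frontmatter(fixed))
```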
+
+ ### Blocker 2 — Dockerfile missing `researchmind`
+
+ [`libs/agent`](libs/agent/pyproject.toml) depends on `researchmind`, but the Dockerfile only copies `inference`, `agent`, and `echocoach`. `uv sync` inside the image will fail without:
+
+ ```dockerfile
+ COPY libs/researchmind/pyproject.toml libs/researchmind/README.md libs/researchmind/
+ COPY libs/researchmind/src libs/researchmind/src
+ ```
+
+ Add these lines alongside the other `libs/*` COPY blocks in [`Dockerfile`](Dockerfile).
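
This class of omission can be caught mechanically by comparing the workspace libraries against the Dockerfile's `COPY` lines. The helper below is a hypothetical sketch using plain substring matching; in practice the library names would come from listing the `libs/` directory:

```python
from typing import Iterable, List


def libs_missing_from_dockerfile(dockerfile_text: str, lib_names: Iterable[str]) -> List[str]:
    """Return workspace libs whose src directory no COPY instruction mentions."""
    copy_lines = [
        line.strip()
        for line in dockerfile_text.splitlines()
        if line.strip().upper().startswith("COPY ")
    ]
    missing = []
    for name in lib_names:
        src_path = f"libs/{name}/src"
        if not any(src_path in line for line in copy_lines):
            missing.append(name)
    return missing


dockerfile_before = """\
COPY libs/inference/src libs/inference/src
COPY libs/agent/src libs/agent/src
COPY libs/echocoach/src libs/echocoach/src
"""
libs = ["inference", "agent", "echocoach", "researchmind"]
print(libs_missing_from_dockerfile(dockerfile_before, libs))
```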
+
+ ---
+
+ ## Architecture (what gets deployed)
+
+ ```mermaid
+ flowchart LR
+   subgraph hf [HuggingFaceSpace]
+     DockerBuild[DockerBuild]
+     Container[Container_port7860]
+   end
+   GitHub[GitHub_repo] --> DockerBuild
+   DockerBuild --> Container
+   Container --> StudioUI["/ Studio UI"]
+   Container --> ClassicUI["/classic Gradio tabs"]
+   Container --> HubModel["Hub: openbmb/MiniCPM5-1B"]
+   HubModel --> Container
+ ```
+
+ Entrypoint (unchanged):
+
+ ```44:44:Dockerfile
+ CMD ["uv", "run", "--package", "gradio-space", "python", "-m", "gradio_space.app"]
+ ```
+
+ This launches [`gradio_space.server`](apps/gradio-space/src/gradio_space/server.py): Studio at `/`, Classic tabs at `/classic`.
+
+ ---
+
+ ## Phase 1 — Fix repo files (before push)
+
+ | File | Change |
+ |------|--------|
+ | [`README.md`](README.md) | Fix frontmatter: `title: Lesson Agent` (remove `##`); keep `sdk: docker`, `app_port: 7860` |
+ | [`Dockerfile`](Dockerfile) | Add `libs/researchmind` pyproject + src COPY lines |
+
+ Optional but recommended in README frontmatter (already present except title):
+
+ ```yaml
+ ---
+ title: Lesson Agent
+ emoji: 📚
+ colorFrom: blue
+ colorTo: green
+ sdk: docker
+ app_port: 7860
+ pinned: false
+ license: apache-2.0
+ ---
+ ```
+
+ ---
+
+ ## Phase 2 — Validate locally with Docker
+
+ From repo root:
+
+ ```bash
+ docker build -t hackathon-space .
+ docker run --rm -p 7860:7860 \
+   -e ACTIVE_MODEL=minicpm5-1b \
+   -e ALLOW_MODEL_SWITCH=false \
+   hackathon-space
+ ```
+
+ Open [http://localhost:7860](http://localhost:7860) (`/` Studio, `/classic` tabs). First model load downloads weights from Hub — expect several minutes on first run.
+
+ If build fails, check Logs for `researchmind` or `uv sync` errors (confirms Blocker 2 fix).
+
+ ---
+
+ ## Phase 3 — Push to GitHub
+
+ 1. Create a GitHub repo (if not already linked)
+ 2. Push `main` with at minimum:
+    - `Dockerfile`, `README.md`, `pyproject.toml`, `uv.lock`
+    - `apps/gradio-space/`, `libs/`, `skills/`, `models.yaml`, `voice_models.yaml`
+
+ Do **not** commit `.env`, local `models/*.gguf`, or large artifacts (`.dockerignore` already excludes these).
+
+ ---
+
+ ## Phase 4 — Create and link the Space
+
+ 1. Go to [build-small-hackathon](https://huggingface.co/build-small-hackathon) → **New Space**
+ 2. Settings:
+    - **Name:** e.g. `lesson-agent` or `small-model-hackathon`
+    - **SDK:** **Docker** (not Gradio SDK — monorepo needs root Dockerfile)
+    - **Hardware:** **GPU basic** (required for transformers `minicpm5-1b`)
+ 3. Under **Repository** → connect your GitHub repo and branch (`main`)
+ 4. HF will auto-build from root `Dockerfile` on each push
+
+ ---
+
+ ## Phase 5 — Space environment variables
+
+ In Space **Settings → Variables and secrets** (Repository secrets, not `.env` in git):
+
+ | Variable | Value | Why |
+ |----------|-------|-----|
+ | `ACTIVE_MODEL` | `minicpm5-1b` | Pins model for visitors |
+ | `ALLOW_MODEL_SWITCH` | `false` | Hides dev model dropdown |
+ | `AGENT_OUTPUTS_DIR` | `/tmp/agent_outputs` | Already set in Dockerfile; optional override |
+ | `RESEARCHMIND_DATA_DIR` | `/tmp/researchmind` | Ephemeral RAG store on Space (recommended) |
+
+ No secrets required for the default MiniCPM5 preset unless you switch to a gated model.
+
+ ---
+
+ ## Phase 6 — Verify publish
+
+ 1. Open Space **Logs** — wait for `Running on local URL: http://0.0.0.0:7860`
+ 2. Open the Space URL
+ 3. Smoke test:
+    - `/` — Studio loads
+    - Generate slides with a simple topic (e.g. "Photosynthesis, grade 8, 5 slides")
+    - `/classic` — tabs render
+ 4. First inference may be slow while `openbmb/MiniCPM5-1B` downloads
+
+ ### Optional: faster restarts
+
+ If cold starts are painful, add a **Storage Bucket** in Space settings so Hub model cache persists across restarts.
+
+ ---
+
+ ## Troubleshooting
+
+ | Symptom | Fix |
+ |---------|-----|
+ | Space card shows wrong title / no Docker | Fix README YAML (`title:` not `## title:`) |
+ | Docker build fails at `uv sync` | Add `researchmind` to Dockerfile |
+ | Build OK but app crashes on Research tab | Confirm `researchmind` src is copied |
+ | First request very slow | Normal — model download; use Storage Bucket |
+ | OOM on GPU | Try smaller batch or switch preset to GGUF on CPU |
+
+ Full reference: [`USAGE.md`](USAGE.md) sections "Docker smoke test" and "Hugging Face Space deployment".
+
+ ---
+
+ ## What you do NOT need
+
+ - Plain Gradio SDK (`app.py` + `requirements.txt` at root) — wrong fit for this monorepo
+ - Committing GGUF files — models download from Hub at runtime via `ACTIVE_MODEL` / `models.yaml`
+ - Changing the CMD — current entrypoint already serves Studio + Classic
.env.example CHANGED
@@ -66,14 +66,4 @@ ALLOW_MODEL_SWITCH=false
  # For Cohere Transcribe ASR: huggingface-cli login + accept model terms, then:
  # ECHOCOACH_ASR_PRESET=cohere-transcribe

- # --- Ensemble research (research/ensemble/) ---
- # Base LLM resolution (first match wins): ENSEMBLE_LLM, LLM_PATH, BASE, MODEL_ID, ACTIVE_MODEL
- # LLM_PATH=./models/finetuned/minicpm5-1b-lora-merged
- # ENSEMBLE_LLM=Qwen/Qwen2.5-0.5B-Instruct
- # ENSEMBLE_PRESET=minicpm5-1b
- # ENSEMBLE_OUT=./models/ensemble/minicpm5-1b-jepa-pretrain
- # ENSEMBLE_QA=./research/data/benchmark-qa.jsonl
- # ENSEMBLE_KB=./research/data/benchmark-kb.jsonl
- # ENSEMBLE_CKPT=./models/ensemble/jepa-lesson-pretrain
-
  BASE=openbmb/MiniCPM5-1B
.gitignore CHANGED
@@ -1,4 +1,5 @@
  .venv/
+ .venv-gradio/
  __pycache__/
  *.py[cod]
  .env
Dockerfile CHANGED
@@ -20,11 +20,13 @@ COPY apps/gradio-space/pyproject.toml apps/gradio-space/README.md apps/gradio-sp
  COPY libs/inference/pyproject.toml libs/inference/README.md libs/inference/
  COPY libs/agent/pyproject.toml libs/agent/README.md libs/agent/
  COPY libs/echocoach/pyproject.toml libs/echocoach/README.md libs/echocoach/
+ COPY libs/researchmind/pyproject.toml libs/researchmind/README.md libs/researchmind/
  COPY apps/gradio-space/src apps/gradio-space/src
  COPY apps/gradio-space/static apps/gradio-space/static
  COPY libs/inference/src libs/inference/src
  COPY libs/agent/src libs/agent/src
  COPY libs/echocoach/src libs/echocoach/src
+ COPY libs/researchmind/src libs/researchmind/src
  COPY skills skills

  RUN useradd -m -u 1000 user && \
README.md CHANGED
@@ -1,13 +1,15 @@
  ---
-
- ## title: Lesson Agent
+ title: Lesson Agent
  emoji: 📚
  colorFrom: blue
  colorTo: green
- sdk: docker
- app_port: 7860
+ sdk: gradio
+ sdk_version: "6.16.0"
+ app_file: app.py
+ python_version: "3.12"
  pinned: false
  license: apache-2.0
+ ---

  # Lesson Agent

@@ -15,7 +17,7 @@ license: apache-2.0

  A local skill-based agent helps a teacher you know turn a **topic + grade level** into a downloadable **PowerPoint** — powered by a small transformers model (`MiniCPM5-1B` by default), no cloud LLM API.

- See **[USAGE.md](USAGE.md)** for local run, Docker smoke test, and HF Space deployment.
+ See **[USAGE.md](USAGE.md)** for local run, Gradio SDK / ZeroGPU Space deployment, and Docker (later).

  ## Prerequisites

@@ -59,7 +61,7 @@ libs/agent/ # Skill agent runner, tools, trace recorder
  libs/researchmind/ # Scraper, chunk/embed, MemRAG SQLite store, retrieval
  libs/inference/ # Transformers + llama.cpp backends
  skills/ # SKILL.md + references/ + scripts/ per task
- research/ # Fine-tune, ensemble experiments, agentic evals (optional)
+ research/ # Fine-tune and agentic evals (optional)
  ```

  ### ResearchMind (offline after ingest)
@@ -87,15 +89,12 @@ See [`.env.example`](.env.example) and [`models.yaml`](models.yaml) for model pr

  ## Hugging Face Space deployment

- 1. Create a Space under [build-small-hackathon](https://huggingface.co/build-small-hackathon) with **Docker** SDK.
- 2. Link this repository (root `Dockerfile` + root `README.md` YAML above).
- 3. Hardware: **GPU basic** recommended for transformers (`minicpm5-1b`).
- 4. Optional secrets: `ACTIVE_MODEL`, `N_GPU_LAYERS` (if using GGUF preset).
+ 1. Create a Space under [build-small-hackathon](https://huggingface.co/build-small-hackathon) with **Gradio** SDK (Blank template).
+ 2. Link this repository — HF builds from root `app.py` + `requirements.txt` (README YAML above).
+ 3. Hardware: **ZeroGPU** for burst GPU inference, or **GPU basic** for always-on GPU.
+ 4. Set `ACTIVE_MODEL=minicpm5-1b`, `ALLOW_MODEL_SWITCH=false`, `RESEARCHMIND_DATA_DIR=/tmp/researchmind`.

- ```bash
- docker build -t hackathon-space .
- docker run --rm -p 7860:7860 -e ACTIVE_MODEL=minicpm5-1b hackathon-space
- ```
+ A root `Dockerfile` is kept for a later **Docker SDK** deploy (flip README to `sdk: docker`). See [USAGE.md](USAGE.md).

  ## Hackathon checklist

USAGE.md CHANGED
@@ -1,6 +1,6 @@
  # Usage

- How to run the **Lesson Agent** Gradio app locally, test it in Docker, and deploy to a Hugging Face Space for the [Build Small Hackathon](https://huggingface.co/build-small-hackathon).

  The primary UI is the **Lesson slides** tab (topic → local model outline → downloadable `.pptx`). Use **ResearchMind** for corpus Q&A, **TeacherVoice** for spoken back-and-forth tutoring, **EchoCoach** for one-shot pitch analysis, or ground lessons directly from the Lesson tab. The **Chat (debug)** tab tests the underlying model.

@@ -223,98 +223,121 @@ INFERENCE_BACKEND=transformers MODEL_ID=Qwen/Qwen2.5-3B-Instruct \
223
 
224
  ---
225
 
226
- ## Docker (local prod-like test)
227
 
228
- Run the same container image HF Spaces will build:
229
 
230
  ```bash
231
- docker build -t hackathon-space .
232
- docker run --rm -p 7860:7860 \
233
- -e MODEL_REPO=Qwen/Qwen2.5-3B-Instruct-GGUF \
234
- -e MODEL_FILE=qwen2.5-3b-instruct-q4_k_m.gguf \
235
- -e N_CTX=4096 \
236
- -e N_GPU_LAYERS=0 \
237
- hackathon-space
238
  ```
239
 
240
- Open [http://localhost:7860](http://localhost:7860) — Studio at `/`, Classic tabs at `/classic`. Stop with `Ctrl+C`.
241
-
242
- To use a pre-downloaded local model inside Docker, mount it and set `MODEL_PATH`:
243
 
244
- ```bash
245
- docker run --rm -p 7860:7860 \
246
- -v "$(pwd)/models:/app/models:ro" \
247
- -e MODEL_PATH=/app/models/qwen2.5-3b-instruct-q4_k_m.gguf \
248
- hackathon-space
249
- ```
250
 
251
  ---
252
 
253
- ## Hugging Face Space deployment
254
 
255
- This repo uses the **Docker SDK**. The Space card metadata lives in the YAML frontmatter at the top of [README.md](README.md).
256
 
257
  ### 1. Push code to GitHub
258
 
259
- Make sure `main` (or your deploy branch) contains at minimum:
260
 
261
- - `Dockerfile`
262
- - `README.md` (with `sdk: docker` and `app_port: 7860`)
263
- - `pyproject.toml`, `uv.lock`
264
- - `apps/gradio-space/` and `libs/inference/`
 
 
265
 
266
  ### 2. Create the Space
267
 
268
  1. Go to [build-small-hackathon](https://huggingface.co/build-small-hackathon)
269
  2. **New Space**
270
- 3. Name: e.g. `small-model-hackathon`
271
- 4. SDK: **Docker**
272
- 5. Link your GitHub repo, or push directly to the Space repo
 
273
 
274
  CLI alternative (if you have `hf` installed and org access):
275
 
276
  ```bash
277
  hf repo create build-small-hackathon/<your-space-name> \
278
  --repo-type space \
279
- --space_sdk docker
280
  ```
281
 
282
- ### 3. Configure hardware
 
 
 
 
 
 
 
 
283
 
 
284
 
285
- | Setting | Recommendation |
286
- | -------- | ------------------------------------------------------------ |
287
- | Hardware | **CPU basic** to start (llama.cpp with `N_GPU_LAYERS=0`) |
288
- | Upgrade | GPU Space if you set `N_GPU_LAYERS > 0` for faster inference |
289
 
 
290
 
291
- ### 4. Set Space environment variables
 
292
 
293
- In the Space **Settings → Variables and secrets**:
294
 
 
 
 
 
295
 
296
- | Variable | Value |
297
- | ------------------- | --------------------------------- |
298
- | `INFERENCE_BACKEND` | `llama_cpp` |
299
- | `MODEL_REPO` | `Qwen/Qwen2.5-3B-Instruct-GGUF` |
300
- | `MODEL_FILE` | `qwen2.5-3b-instruct-q4_k_m.gguf` |
301
- | `N_CTX` | `4096` |
302
- | `N_GPU_LAYERS` | `0` (or higher on GPU hardware) |
303
 
 
304
 
305
- ### 5. Build and verify
 
 
 
 
 
 
306
 
307
- HF builds from the root `Dockerfile` and runs:
 
 
 
 
 
 
 
 
308
 
309
  ```bash
310
- uv run --package gradio-space python -m gradio_space.app
 
 
 
 
 
311
  ```
312
 
313
- Check the **Logs** tab while the Space builds. Once running, open the Space URL and send a test chat message. The first message may take several minutes on CPU while the GGUF downloads.
314
 
315
- ### 6. Optional: persistent model cache
316
 
317
- If cold starts are too slow, attach a **Storage Bucket** in Space settings so downloaded GGUF files survive restarts.
 
 
 
 
 
318
 
319
  ---
320
 
@@ -323,29 +346,24 @@ If cold starts are too slow, attach a **Storage Bucket** in Space settings so do
323
 
324
  | Symptom | Likely cause | Fix |
325
  | ---------------------------------------- | --------------------------------- | -------------------------------------------------------------------- |
326
- | First chat hangs / slow | GGUF downloading from Hub | Pre-download locally; on Space, wait or use Storage Bucket |
327
- | `Failed to load model` in chat | Wrong `MODEL_REPO` / `MODEL_FILE` | Check env vars match a valid GGUF on Hub |
328
- | Docker build fails on `llama-cpp-python` | Missing build tools | Dockerfile already installs `build-essential` and `cmake` |
329
- | Space build fails | Missing `uv.lock` or README YAML | Ensure `sdk: docker` is in root `README.md` frontmatter |
330
- | `transformers` backend error | Optional deps not installed | Run `uv sync --package inference --extra transformers` |
331
- | Port already in use locally | Another process on 7860 | `PORT=7861 uv run --package gradio-space python -m gradio_space.app` |
 
332
 
333
 
334
  ---
335
 
336
  ## Entrypoint summary
337
 
338
- All three environments use the same command:
339
-
340
- ```bash
341
- uv run --package gradio-space python -m gradio_space.app
342
- ```
343
-
344
-
345
- | Environment | How to run |
346
- | ----------- | ---------------------------------------------------------- |
347
- | Local dev | `uv run --package gradio-space python -m gradio_space.app` |
348
- | Docker | `docker run -p 7860:7860 hackathon-space` |
349
- | HF Space | Built and started automatically from `Dockerfile` `CMD` |
350
 
351
 
 
1
  # Usage
2
 
3
+ How to run the **Lesson Agent** Gradio app for the [Build Small Hackathon](https://huggingface.co/build-small-hackathon): locally, on a Hugging Face Space (Gradio SDK + ZeroGPU), and optionally in Docker later.
4
 
5
  The primary UI is the **Lesson slides** tab (topic → local model outline → downloadable `.pptx`). Use **ResearchMind** for corpus Q&A, **TeacherVoice** for spoken back-and-forth tutoring, **EchoCoach** for one-shot pitch analysis, or ground lessons directly from the Lesson tab. The **Chat (debug)** tab tests the underlying model.
6
 
 
223
 
224
  ---
225
 
226
+ ## Gradio SDK local smoke test (matches HF Space build)
227
 
228
+ Before pushing to Hugging Face, verify the Gradio SDK entry point:
229
 
230
  ```bash
231
+ python -m venv .venv-gradio && source .venv-gradio/bin/activate
232
+ pip install -r requirements.txt
233
+ ACTIVE_MODEL=minicpm5-1b ALLOW_MODEL_SWITCH=false python app.py
 
 
 
 
234
  ```
235
 
236
+ Open [http://localhost:7860](http://localhost:7860) — Studio at `/`, Classic at `/classic`.
 
 
237
 
238
+ Day-to-day development can still use `uv run` (see above); this path mirrors what HF installs from `requirements.txt`.
 
 
 
 
 
239
 
240
  ---
241
 
242
+ ## Hugging Face Space deployment (Gradio SDK + ZeroGPU)
243
 
244
+ The Space card metadata lives in the YAML frontmatter at the top of [README.md](README.md) (`sdk: gradio`, `app_file: app.py`).
245
 
246
  ### 1. Push code to GitHub
247
 
248
+ Make sure `main` contains at minimum:
249
 
250
+ - `app.py`, `requirements.txt`, `packages.txt`
251
+ - `README.md` (with `sdk: gradio`, `sdk_version`, `app_file: app.py`)
252
+ - `models.yaml`, `skills/`
253
+ - `apps/gradio-space/` and all `libs/*` packages
254
+
255
+ The root `Dockerfile` stays in the repo for a later Docker SDK deploy (see below).
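
The checklist above could be automated before pushing. The following is a minimal sketch, assuming it runs from the repo root; `check_space_layout` and the naive frontmatter parser are hypothetical helpers for illustration, not part of this repo, and the parser only handles flat `key: value` pairs.

```python
from pathlib import Path

# Paths the checklist above says the deploy branch must contain.
REQUIRED_PATHS = [
    "app.py",
    "requirements.txt",
    "packages.txt",
    "README.md",
    "models.yaml",
    "skills",
    "libs",
    "apps/gradio-space",
]
# Frontmatter values the Gradio SDK deploy path expects.
REQUIRED_KEYS = {"sdk": "gradio", "app_file": "app.py"}


def read_frontmatter(readme_text: str) -> dict[str, str]:
    """Naively parse flat `key: value` pairs between the leading `---` fences."""
    lines = readme_text.splitlines()
    if not lines or lines[0].strip() != "---":
        return {}
    meta: dict[str, str] = {}
    for line in lines[1:]:
        if line.strip() == "---":
            break
        key, sep, value = line.partition(":")
        if sep:
            meta[key.strip()] = value.strip().strip("'\"")
    return meta


def check_space_layout(root: Path) -> list[str]:
    """Return human-readable problems; an empty list means the layout looks OK."""
    problems = [f"missing: {p}" for p in REQUIRED_PATHS if not (root / p).exists()]
    readme = root / "README.md"
    meta = read_frontmatter(readme.read_text(encoding="utf-8")) if readme.exists() else {}
    for key, expected in REQUIRED_KEYS.items():
        if meta.get(key) != expected:
            problems.append(f"README frontmatter: expected {key}: {expected}")
    if "sdk_version" not in meta:
        problems.append("README frontmatter: sdk_version is not set")
    return problems
```

Calling `check_space_layout(Path("."))` and printing the result gives a quick pre-push report; it does not replace checking the Space **Logs** tab after the build.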
256
 
257
  ### 2. Create the Space
258
 
259
  1. Go to [build-small-hackathon](https://huggingface.co/build-small-hackathon)
260
  2. **New Space**
261
+ 3. Name: e.g. `lesson-agent` or `small-model-hackathon`
262
+ 4. SDK: **Gradio** (Blank template)
263
+ 5. Hardware: **ZeroGPU** (creator needs PRO/Team) or **GPU basic**
264
+ 6. Link your GitHub repo, or push directly to the Space git remote
265
 
266
  CLI alternative (if you have `hf` installed and org access):
267
 
268
  ```bash
269
  hf repo create build-small-hackathon/<your-space-name> \
270
  --repo-type space \
271
+ --space_sdk gradio
272
  ```
273
 
274
+ ### 3. Set Space environment variables
275
+
276
+ In the Space **Settings → Variables and secrets**:
277
+
278
+ | Variable | Value |
279
+ | -------- | ----- |
280
+ | `ACTIVE_MODEL` | `minicpm5-1b` |
281
+ | `ALLOW_MODEL_SWITCH` | `false` |
282
+ | `RESEARCHMIND_DATA_DIR` | `/tmp/researchmind` |
283
 
284
+ Default preset in [`models.yaml`](models.yaml) is `minicpm5-1b` (transformers) — suitable for ZeroGPU.
285
 
286
+ ### 4. Build and verify
 
 
 
287
 
288
+ HF installs from `requirements.txt` and runs root `app.py`. Check the **Logs** tab for:
289
 
290
+ - Successful pip install (first build may take several minutes — `llama-cpp-python` compiles)
291
+ - `Running on local URL: 0.0.0.0:7860`
292
 
293
+ Smoke test on the live Space:
294
 
295
+ 1. **`/`** — Studio UI loads
296
+ 2. **`/classic`** — all tabs render
297
+ 3. Generate slides with a simple topic (e.g. "Photosynthesis, grade 8, 5 slides")
298
+ 4. First LLM request may be slow (model download + ZeroGPU queue)
299
 
300
+ ### 5. ZeroGPU notes
 
 
 
 
 
 
301
 
302
+ LLM handlers use `@spaces.GPU` via [`gradio_space/spaces_runtime.py`](apps/gradio-space/src/gradio_space/spaces_runtime.py). If you see **No CUDA GPUs are available**, an inference path is running outside a decorated handler.
303
 
304
+ Startup model preload is skipped on HF Gradio runtime; the first user request loads the model inside a GPU task.
305
+
306
+ ### 6. Optional: persistent model cache
307
+
308
+ Attach a **Storage Bucket** in Space settings so Hub model weights survive restarts.
309
+
310
+ ---
311
 
312
+ ## Docker SDK deployment (later)
313
+
314
+ Both deploy paths live on the same branch. HF reads **one** `sdk:` from README — switch to Docker when you are ready for a dedicated-GPU Space.
315
+
316
+ 1. Change [README.md](README.md) frontmatter to `sdk: docker`, `app_port: 7860` (remove `sdk_version` / `app_file`)
317
+ 2. Create or reconfigure a Space with **Docker** SDK and **GPU basic** hardware
318
+ 3. Set the same env vars (`ACTIVE_MODEL=minicpm5-1b`, etc.)
319
+
320
+ ### Local Docker smoke test
321
 
322
  ```bash
323
+ docker build -t hackathon-space .
324
+ docker run --rm -p 7860:7860 \
325
+ -e ACTIVE_MODEL=minicpm5-1b \
326
+ -e ALLOW_MODEL_SWITCH=false \
327
+ -e RESEARCHMIND_DATA_DIR=/tmp/researchmind \
328
+ hackathon-space
329
  ```
330
 
331
+ Open [http://localhost:7860](http://localhost:7860): Studio at `/`, Classic tabs at `/classic`. Stop with `Ctrl+C`.
332
 
333
+ To use a pre-downloaded local GGUF model inside Docker, mount it and set `MODEL_PATH`:
334
 
335
+ ```bash
336
+ docker run --rm -p 7860:7860 \
337
+ -v "$(pwd)/models:/app/models:ro" \
338
+ -e MODEL_PATH=/app/models/qwen2.5-3b-instruct-q4_k_m.gguf \
339
+ hackathon-space
340
+ ```
341
 
342
  ---
343
 
 
346
 
347
  | Symptom | Likely cause | Fix |
348
  | ---------------------------------------- | --------------------------------- | -------------------------------------------------------------------- |
349
+ | First chat hangs / slow | Model downloading from Hub | Wait on Space; use Storage Bucket for cache |
350
+ | `Failed to load model` in chat | Wrong `ACTIVE_MODEL` preset | Use `minicpm5-1b` or valid key from `models.yaml` |
351
+ | Space build fails on pip install | `llama-cpp-python` compile | Check Logs; default preset avoids GGUF at runtime |
352
+ | Space build fails | Malformed README YAML | Ensure `sdk: gradio` and `app_file: app.py` in README frontmatter |
353
+ | No CUDA GPUs on ZeroGPU | Handler outside `@spaces.GPU` | Decorate every LLM entry point with `gpu_task` from `spaces_runtime.py` |
354
+ | Docker build fails on `llama-cpp-python` | Missing build tools | Dockerfile installs `build-essential` and `cmake` |
355
+ | Port already in use locally | Another process on 7860 | `PORT=7861 python app.py` or `uv run ...` |
356
 
357
 
358
  ---
359
 
360
  ## Entrypoint summary
361
 
362
+ | Environment | How to run |
363
+ | ----------- | ---------- |
364
+ | Local dev (uv) | `uv run --package gradio-space python -m gradio_space.app` |
365
+ | Local Gradio SDK smoke | `pip install -r requirements.txt && python app.py` |
366
+ | HF Gradio Space | HF runs root `app.py` automatically |
367
+ | Docker (later) | `docker run -p 7860:7860 hackathon-space` (after README `sdk: docker`) |
 
 
 
 
 
 
368
 
369
 
app.py ADDED
@@ -0,0 +1,6 @@
+ """Hugging Face Gradio SDK entry point (ZeroGPU / Gradio Spaces)."""
+
+ from gradio_space.server import main
+
+ if __name__ == "__main__":
+     main()
apps/gradio-space/src/gradio_space/model_loading.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from inference.config import get_app_config, get_model_config
2
  from inference.factory import get_backend, reset_backend
3
  from inference.response_clean import strip_reasoning_output
@@ -74,6 +75,7 @@ def warmup(model_key: str | None = None) -> str:
74
  )
75
 
76
 
 
77
  def reload_model(model_key: str) -> str:
78
  """Clear cached backend and reload weights for settings panel."""
79
  global _current_model_key
@@ -120,6 +122,7 @@ def _history_to_messages(history: list) -> list[dict[str, str]]:
120
  return messages
121
 
122
 
 
123
  def chat(message: str, history: list, model_key: str) -> str:
124
  load_error = ensure_model_loaded(model_key)
125
  if load_error:
 
1
+ from gradio_space.spaces_runtime import gpu_task
2
  from inference.config import get_app_config, get_model_config
3
  from inference.factory import get_backend, reset_backend
4
  from inference.response_clean import strip_reasoning_output
 
75
  )
76
 
77
 
78
+ @gpu_task(duration=120)
79
  def reload_model(model_key: str) -> str:
80
  """Clear cached backend and reload weights for settings panel."""
81
  global _current_model_key
 
122
  return messages
123
 
124
 
125
+ @gpu_task(duration=60)
126
  def chat(message: str, history: list, model_key: str) -> str:
127
  load_error = ensure_model_loaded(model_key)
128
  if load_error:
apps/gradio-space/src/gradio_space/research_helpers.py CHANGED
@@ -8,6 +8,7 @@ import gradio as gr
8
  from agent.models import ResearchIngestResult
9
  from agent.runner import AgentRunner
10
  from gradio_space.model_loading import chat, ensure_model_loaded, get_active_model_key
 
11
  from inference.factory import get_backend
12
  from researchmind.ingest import IngestPipeline
13
 
@@ -209,6 +210,7 @@ def rag_scope_hint(session_id: str, doc_ids: list[str] | None) -> str:
209
  return "RAG scope: **entire** indexed corpus (all sessions)."
210
 
211
 
 
212
  def run_research_question(
213
  question: str,
214
  *,
 
8
  from agent.models import ResearchIngestResult
9
  from agent.runner import AgentRunner
10
  from gradio_space.model_loading import chat, ensure_model_loaded, get_active_model_key
11
+ from gradio_space.spaces_runtime import gpu_task
12
  from inference.factory import get_backend
13
  from researchmind.ingest import IngestPipeline
14
 
 
210
  return "RAG scope: **entire** indexed corpus (all sessions)."
211
 
212
 
213
+ @gpu_task(duration=180)
214
  def run_research_question(
215
  question: str,
216
  *,
apps/gradio-space/src/gradio_space/server.py CHANGED
@@ -12,6 +12,7 @@ from gradio import mount_gradio_app
12
  from gradio_space.api.studio import register_studio_apis
13
  from gradio_space.app import build_demo
14
  from gradio_space.model_loading import preload_active_model
 
15
  from gradio_space.tabs.education_pptx import gradio_allowed_paths
16
  from gradio_space.tabs.echo_coach import echo_coach_allowed_paths
17
  from gradio_space.tabs.research_mind import researchmind_allowed_paths
@@ -66,7 +67,8 @@ def create_server() -> gr.Server:
66
 
67
 
68
  def main() -> None:
69
- preload_active_model()
 
70
  server = create_server()
71
  port = int(os.environ.get("PORT", "7860"))
72
  server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
 
12
  from gradio_space.api.studio import register_studio_apis
13
  from gradio_space.app import build_demo
14
  from gradio_space.model_loading import preload_active_model
15
+ from gradio_space.spaces_runtime import is_hf_gradio_runtime
16
  from gradio_space.tabs.education_pptx import gradio_allowed_paths
17
  from gradio_space.tabs.echo_coach import echo_coach_allowed_paths
18
  from gradio_space.tabs.research_mind import researchmind_allowed_paths
 
67
 
68
 
69
  def main() -> None:
70
+ if not is_hf_gradio_runtime():
71
+ preload_active_model()
72
  server = create_server()
73
  port = int(os.environ.get("PORT", "7860"))
74
  server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
apps/gradio-space/src/gradio_space/spaces_runtime.py ADDED
@@ -0,0 +1,37 @@
+ """Hugging Face Spaces ZeroGPU helpers."""
+
+ from __future__ import annotations
+
+ import os
+ from collections.abc import Callable
+ from typing import ParamSpec, TypeVar
+
+ P = ParamSpec("P")
+ R = TypeVar("R")
+
+
+ def is_hf_gradio_runtime() -> bool:
+     """True on Hugging Face Gradio SDK Spaces (skip startup model preload)."""
+     try:
+         import spaces  # noqa: F401
+     except ImportError:
+         return False
+     return bool(os.environ.get("SPACE_ID"))
+
+
+ def gpu_task(
+     *,
+     duration: int = 180,
+     size: str = "large",
+ ) -> Callable[[Callable[P, R]], Callable[P, R]]:
+     """Apply @spaces.GPU when the HF spaces runtime is present (no-op elsewhere)."""
+
+     def decorator(fn: Callable[P, R]) -> Callable[P, R]:
+         try:
+             import spaces
+
+             return spaces.GPU(duration=duration, size=size)(fn)
+         except ImportError:
+             return fn
+
+     return decorator
apps/gradio-space/src/gradio_space/tabs/echo_coach.py CHANGED
@@ -7,6 +7,7 @@ import gradio as gr
7
  from echocoach.config import get_echo_coach_config
8
  from echocoach.pipeline import run_echo_coach
9
  from gradio_space.model_loading import ensure_model_loaded, get_active_model_key
 
10
  from gradio_space.ui.components import (
11
  build_advanced_panel,
12
  build_recording_block,
@@ -64,6 +65,7 @@ def load_sample_pitch() -> tuple[str | None, str]:
64
  )
65
 
66
 
 
67
  def analyze_pitch(
68
  audio_path: str | None,
69
  language: str,
 
7
  from echocoach.config import get_echo_coach_config
8
  from echocoach.pipeline import run_echo_coach
9
  from gradio_space.model_loading import ensure_model_loaded, get_active_model_key
10
+ from gradio_space.spaces_runtime import gpu_task
11
  from gradio_space.ui.components import (
12
  build_advanced_panel,
13
  build_recording_block,
 
65
  )
66
 
67
 
68
+ @gpu_task(duration=180)
69
  def analyze_pitch(
70
  audio_path: str | None,
71
  language: str,
apps/gradio-space/src/gradio_space/tabs/education_pptx.py CHANGED
@@ -16,6 +16,7 @@ from gradio_space.research_helpers import (
16
  resolve_session,
17
  resolve_topic,
18
  )
 
19
  from gradio_space.ui.components import build_advanced_panel, DOC_CHOICE_LIST_CLASSES, WorkspaceWidgets
20
  from inference.factory import get_backend
21
  from researchmind.config import get_config
@@ -158,6 +159,7 @@ def update_source_visibility(source_mode_label: str, search_workflow_label: str)
158
  )
159
 
160
 
 
161
  def discover_lesson_sources(
162
  topic: str,
163
  session_id: str,
@@ -208,6 +210,7 @@ def discover_lesson_sources(
208
  return msg, gr.update(choices=[], value=[]), refresh_sessions(session_id)
209
 
210
 
 
211
  def generate_lesson_slides(
212
  topic: str,
213
  grade: str,
 
16
  resolve_session,
17
  resolve_topic,
18
  )
19
+ from gradio_space.spaces_runtime import gpu_task
20
  from gradio_space.ui.components import build_advanced_panel, DOC_CHOICE_LIST_CLASSES, WorkspaceWidgets
21
  from inference.factory import get_backend
22
  from researchmind.config import get_config
 
159
  )
160
 
161
 
162
+ @gpu_task(duration=120)
163
  def discover_lesson_sources(
164
  topic: str,
165
  session_id: str,
 
210
  return msg, gr.update(choices=[], value=[]), refresh_sessions(session_id)
211
 
212
 
213
+ @gpu_task(duration=300)
214
  def generate_lesson_slides(
215
  topic: str,
216
  grade: str,
apps/gradio-space/src/gradio_space/tabs/research_mind.py CHANGED
@@ -23,6 +23,7 @@ from gradio_space.research_helpers import (
23
  run_research_question,
24
  trace_summary_markdown,
25
  )
 
26
  from gradio_space.ui.components import build_advanced_panel, DOC_CHOICE_LIST_CLASSES, WorkspaceWidgets
27
  from inference.factory import get_backend
28
 
@@ -35,6 +36,7 @@ def _require_topic(topic: str | None) -> str | None:
35
  return None
36
 
37
 
 
38
  def discover_sources(
39
  topic: str,
40
  session_id: str,
@@ -118,6 +120,7 @@ def discover_sources(
118
  )
119
 
120
 
 
121
  def auto_search_ingest(
122
  topic: str,
123
  session_id: str,
@@ -279,6 +282,7 @@ def ingest_selected(
279
  )
280
 
281
 
 
282
  def ask_question(
283
  question: str,
284
  session_id: str,
 
23
  run_research_question,
24
  trace_summary_markdown,
25
  )
26
+ from gradio_space.spaces_runtime import gpu_task
27
  from gradio_space.ui.components import build_advanced_panel, DOC_CHOICE_LIST_CLASSES, WorkspaceWidgets
28
  from inference.factory import get_backend
29
 
 
36
  return None
37
 
38
 
39
+ @gpu_task(duration=120)
40
  def discover_sources(
41
  topic: str,
42
  session_id: str,
 
120
  )
121
 
122
 
123
+ @gpu_task(duration=180)
124
  def auto_search_ingest(
125
  topic: str,
126
  session_id: str,
 
282
  )
283
 
284
 
285
+ @gpu_task(duration=180)
286
  def ask_question(
287
  question: str,
288
  session_id: str,
apps/gradio-space/src/gradio_space/tabs/teacher_voice.py CHANGED
@@ -18,6 +18,7 @@ from gradio_space.research_helpers import (
18
  resolve_topic,
19
  trace_as_dict,
20
  )
 
21
  from gradio_space.tabs.research_mind import (
22
  auto_search_ingest,
23
  discover_sources,
@@ -87,6 +88,7 @@ def _turn_error(history: list | None, message: str) -> tuple:
87
  )
88
 
89
 
 
90
  def send_turn(
91
  audio_path: str | None,
92
  history: list,
@@ -142,6 +144,7 @@ def send_turn(
142
  return _turn_result(result)
143
 
144
 
 
145
  def send_text_turn(
146
  message: str,
147
  history: list,
 
18
  resolve_topic,
19
  trace_as_dict,
20
  )
21
+ from gradio_space.spaces_runtime import gpu_task
22
  from gradio_space.tabs.research_mind import (
23
  auto_search_ingest,
24
  discover_sources,
 
88
  )
89
 
90
 
91
+ @gpu_task(duration=180)
92
  def send_turn(
93
  audio_path: str | None,
94
  history: list,
 
144
  return _turn_result(result)
145
 
146
 
147
+ @gpu_task(duration=180)
148
  def send_text_turn(
149
  message: str,
150
  history: list,
apps/gradio-space/static/studio/index.html CHANGED
@@ -291,83 +291,107 @@
291
  </section>
292
 
293
  <section class="col col-studio">
294
- <h2 class="section-label">Step 3 · Studio Controls</h2>
295
- <div class="card">
296
- <p class="card-title">RAG Scope</p>
297
- <label class="toggle-row">
298
- <span>Cross-Reference Sources</span>
299
- <input id="use-rag" type="checkbox" checked />
300
- </label>
301
- <p class="status-text">Session and documents use workspace defaults above unless overridden per tool.</p>
302
- </div>
303
- <div class="card">
304
- <p class="card-title">Teacher Voice Mode</p>
305
- <div class="mode-cards" id="voice-modes">
306
- <button type="button" class="mode-card" data-mode="explain">Explain</button>
307
- <button type="button" class="mode-card active" data-mode="lesson">Coach</button>
308
- <button type="button" class="mode-card" data-mode="pitch">Practice</button>
309
- </div>
310
- <label class="field voice-topic-wrap" id="voice-topic-wrap">
311
- <span>Focus topic</span>
312
- <input id="voice-topic" type="text" class="input" placeholder="Uses workspace topic when empty" />
313
- </label>
314
- <details class="voice-rag-sources" id="voice-rag-sources">
315
- <summary>ResearchMind sources (optional)</summary>
316
- <p class="status-text">Set focus topic, then discover or ingest sources. Enable RAG above to ground answers in your library.</p>
317
- <div class="ingest-action-row">
318
- <button type="button" id="btn-voice-discover" class="btn btn-secondary">Discover on web</button>
319
- <button type="button" id="btn-voice-auto-ingest" class="btn btn-secondary">Auto-ingest from web</button>
320
  </div>
321
- <div id="voice-url-choices-panel" class="url-choices-panel hidden">
322
- <div id="voice-url-choices-list" class="url-choices-list"></div>
 
323
  </div>
324
- <label class="field">
325
- <span>Paste URLs (one per line)</span>
326
- <textarea id="voice-urls-text" class="input" rows="2" placeholder="https://…"></textarea>
327
- </label>
328
- <label class="upload-zone upload-zone-compact">
329
- <input id="voice-ingest-file" type="file" accept=".pdf,.docx" multiple hidden />
330
- <span class="material-symbols-outlined">upload_file</span>
331
- <span>Upload PDF or Doc</span>
332
- </label>
333
- <button type="button" id="btn-voice-ingest" class="btn btn-secondary btn-block">Ingest sources</button>
334
- <p id="voice-ingest-status" class="status-text"></p>
335
- </details>
336
- <div id="voice-chat-messages" class="research-chat-messages voice-chat-messages">
337
- <p class="research-chat-empty">Type a message or record audio, then send.</p>
338
- </div>
339
- <label class="field voice-panel" id="voice-panel">
340
- <span>Ask the teacher</span>
341
- <textarea id="voice-message" class="input" rows="3" placeholder="What is the difference between pretraining and finetuning a small model?"></textarea>
342
- <div class="recording-row">
343
- <button type="button" id="btn-voice-record-start" class="btn btn-secondary">Start mic</button>
344
- <button type="button" id="btn-voice-record-stop" class="btn btn-secondary" disabled>Stop mic</button>
345
- <input id="voice-audio-upload" type="file" accept="audio/*" class="input input-compact" />
 
 
 
 
 
 
 
 
 
 
 
 
 
346
  </div>
347
- <p id="voice-record-status" class="status-text"></p>
348
- <button type="button" id="btn-voice-send" class="btn btn-secondary btn-block">Send text</button>
349
- <button type="button" id="btn-voice-audio-send" class="btn btn-primary btn-block">Send voice turn</button>
350
- </label>
351
- <p id="voice-turn-status" class="status-text"></p>
352
- <div class="voice-replay-row">
353
- <button type="button" id="btn-voice-speak-full" class="btn btn-secondary">Speak full reply</button>
354
- <button type="button" id="btn-voice-speak-quick" class="btn btn-secondary">Speak first sentence</button>
355
- <button type="button" id="btn-voice-clear" class="btn btn-ghost">Clear conversation</button>
356
  </div>
357
- <div id="voice-audio-out" class="voice-audio-out"></div>
358
  </div>
359
- <div class="card coach-panel-wrap">
360
- <h2 class="section-label">EchoCoach Feedback</h2>
361
- <div class="recording-row">
362
- <button type="button" id="btn-coach-record-start" class="btn btn-secondary">Start mic</button>
363
- <button type="button" id="btn-coach-record-stop" class="btn btn-secondary" disabled>Stop mic</button>
 
 
 
 
 
 
 
 
 
 
 
 
 
364
  </div>
365
- <p id="coach-record-status" class="status-text"></p>
366
- <button type="button" id="btn-coach-sample" class="btn btn-ghost btn-block">Load sample clip</button>
367
- <label class="field">
368
- <span>Or upload pitch (WAV)</span>
369
- <input id="coach-audio" type="file" accept="audio/*" />
370
- </label>
371
  <div class="controls-grid coach-presets">
372
  <label class="field">
373
  <span>Language</span>
@@ -378,19 +402,22 @@
378
  <select id="coach-asr" class="input"></select>
379
  </label>
380
  </div>
381
- <label class="toggle-row">
382
  <span>Speak full rewrite (VoiceOut)</span>
383
  <input id="coach-speak-rewrite" type="checkbox" />
384
  </label>
385
- <button type="button" id="btn-analyze" class="btn btn-secondary btn-block">Analyze pitch</button>
386
- <div id="coach-panel"></div>
387
  </div>
388
  </section>
389
 
390
  <section class="col col-debug">
391
- <div class="card card-tall">
392
- <h2 class="section-label">Chat (debug)</h2>
393
- <p class="status-text">Plain chat or corpus-grounded answers — traces appear below when RAG is on.</p>
 
 
 
394
  <label class="toggle-row">
395
  <span>Use ResearchMind RAG</span>
396
  <input id="debug-use-rag" type="checkbox" />
@@ -418,11 +445,13 @@
418
  <div id="debug-chat-messages" class="research-chat-messages debug-chat-messages">
419
  <p class="research-chat-empty">Send a message to test the active local model.</p>
420
  </div>
421
- <label class="field">
422
- <span>Message</span>
423
- <textarea id="debug-message" class="input" rows="3" placeholder="Hello, model…"></textarea>
424
- </label>
425
- <button type="button" id="btn-debug-send" class="btn btn-primary btn-block">Send</button>
 
 
426
  <details class="studio-debug-trace" id="debug-trace-details">
427
  <summary>Debug trace</summary>
428
  <div id="debug-trace-panel"></div>
 
291
  </section>
292
 
293
  <section class="col col-studio">
294
+ <div class="voice-layout view-voice-only">
295
+ <aside class="voice-rail">
296
+ <div class="card voice-rag-card">
297
+ <p class="card-title">RAG Scope</p>
298
+ <label class="toggle-row">
299
+ <span>Cross-Reference Sources</span>
300
+ <input id="use-rag" type="checkbox" checked />
301
+ </label>
302
+ <p class="status-text">Uses workspace session and documents unless overridden below.</p>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303
  </div>
304
+ <div class="card voice-rail-controls">
305
+ <p class="card-title">Mode</p>
306
+ <div class="mode-cards voice-mode-cards" id="voice-modes">
307
+ <button type="button" class="mode-card" data-mode="explain">Explain</button>
308
+ <button type="button" class="mode-card active" data-mode="lesson">Coach</button>
309
+ <button type="button" class="mode-card" data-mode="pitch">Practice</button>
310
+ </div>
311
+ <label class="field voice-topic-wrap" id="voice-topic-wrap">
312
+ <span>Focus topic</span>
313
+ <input id="voice-topic" type="text" class="input" placeholder="Uses workspace topic when empty" />
314
+ </label>
315
+ <details class="voice-rag-sources" id="voice-rag-sources">
316
+ <summary>Add sources (optional)</summary>
317
+ <p class="status-text">Discover or ingest sources to ground answers in your library.</p>
318
+ <div class="ingest-action-row">
319
+ <button type="button" id="btn-voice-discover" class="btn btn-secondary">Discover on web</button>
320
+ <button type="button" id="btn-voice-auto-ingest" class="btn btn-secondary">Auto-ingest</button>
321
+ </div>
322
+ <div id="voice-url-choices-panel" class="url-choices-panel hidden">
323
+ <div id="voice-url-choices-list" class="url-choices-list"></div>
324
+ </div>
325
+ <label class="field">
326
+ <span>Paste URLs (one per line)</span>
327
+ <textarea id="voice-urls-text" class="input" rows="2" placeholder="https://…"></textarea>
328
+ </label>
329
+ <label class="upload-zone upload-zone-compact">
330
+ <input id="voice-ingest-file" type="file" accept=".pdf,.docx" multiple hidden />
331
+ <span class="material-symbols-outlined">upload_file</span>
332
+ <span>Upload PDF or Doc</span>
333
+ </label>
334
+ <button type="button" id="btn-voice-ingest" class="btn btn-secondary btn-block">Ingest sources</button>
335
+ <p id="voice-ingest-status" class="status-text"></p>
336
+ </details>
337
  </div>
338
+ </aside>
339
+ <div class="voice-main">
340
+ <div class="card voice-main-card">
341
+ <div class="voice-card-head">
342
+ <h2 class="section-label">Teacher Voice</h2>
343
+ <p class="voice-card-desc">Talk with the teacher using text or voice — grounded in your sources when RAG is on.</p>
344
+ </div>
345
+ <div id="voice-chat-messages" class="research-chat-messages voice-chat-messages">
346
+ <p class="research-chat-empty">Type a message or record audio, then send.</p>
347
+ </div>
348
+ <div class="voice-compose" id="voice-panel">
349
+ <label class="field">
350
+ <span>Ask the teacher</span>
351
+ <textarea id="voice-message" class="input" rows="2" placeholder="What is the difference between pretraining and finetuning a small model?"></textarea>
352
+ </label>
353
+ <div class="voice-input-toolbar">
354
+ <div class="recording-row voice-recording-row">
355
+ <button type="button" id="btn-voice-record-start" class="btn btn-secondary">Start mic</button>
356
+ <button type="button" id="btn-voice-record-stop" class="btn btn-secondary" disabled>Stop mic</button>
357
+ <input id="voice-audio-upload" type="file" accept="audio/*" class="input input-compact" />
358
+ </div>
359
+ <p id="voice-record-status" class="status-text voice-record-status"></p>
360
+ </div>
361
+ <div class="voice-send-row">
362
+ <button type="button" id="btn-voice-send" class="btn btn-secondary">Send text</button>
363
+ <button type="button" id="btn-voice-audio-send" class="btn btn-primary">Send voice turn</button>
364
+ </div>
365
+ <p id="voice-turn-status" class="status-text"></p>
366
+ <div class="voice-replay-row">
367
+ <button type="button" id="btn-voice-speak-full" class="btn btn-secondary">Speak full reply</button>
368
+ <button type="button" id="btn-voice-speak-quick" class="btn btn-secondary">Speak first sentence</button>
369
+ <button type="button" id="btn-voice-clear" class="btn btn-ghost">Clear conversation</button>
370
+ </div>
371
+ <div id="voice-audio-out" class="voice-audio-out"></div>
372
+ </div>
373
  </div>
 
 
 
 
 
 
 
 
 
374
  </div>
 
375
  </div>
376
+ <div class="card coach-panel-wrap view-coach-only">
377
+ <div class="coach-card-head">
378
+ <h2 class="section-label">EchoCoach · Pitch analysis</h2>
379
+ <p class="coach-card-desc">Record or upload a short pitch for pace, filler highlights, and spoken feedback.</p>
380
+ </div>
381
+ <div class="coach-capture-row">
382
+ <div class="coach-capture-controls">
383
+ <div class="recording-row coach-recording-row">
384
+ <button type="button" id="btn-coach-record-start" class="btn btn-secondary">Start mic</button>
385
+ <button type="button" id="btn-coach-record-stop" class="btn btn-secondary" disabled>Stop mic</button>
386
+ <button type="button" id="btn-coach-sample" class="btn btn-ghost">Load sample</button>
387
+ </div>
388
+ <p id="coach-record-status" class="status-text coach-record-status"></p>
389
+ </div>
390
+ <label class="field coach-upload-field">
391
+ <span>Upload pitch (WAV)</span>
392
+ <input id="coach-audio" type="file" accept="audio/*" />
393
+ </label>
394
  </div>
 
 
 
 
 
 
395
  <div class="controls-grid coach-presets">
396
  <label class="field">
397
  <span>Language</span>
 
402
  <select id="coach-asr" class="input"></select>
403
  </label>
404
  </div>
405
+ <label class="toggle-row coach-voiceout-toggle">
406
  <span>Speak full rewrite (VoiceOut)</span>
407
  <input id="coach-speak-rewrite" type="checkbox" />
408
  </label>
409
+ <button type="button" id="btn-analyze" class="btn btn-primary btn-block coach-analyze-btn">Analyze pitch</button>
410
+ <div id="coach-panel" class="coach-results-panel"></div>
411
  </div>
412
  </section>
413
 
414
  <section class="col col-debug">
415
+ <div class="card card-tall coach-debug-card">
416
+ <div class="coach-card-head">
417
+ <h2 class="section-label">Chat (debug)</h2>
418
+ <p class="coach-card-desc view-coach-only">Plain chat or corpus-grounded answers — traces appear below when RAG is on.</p>
419
+ <p class="status-text view-debug-only">Plain chat or corpus-grounded answers — traces appear below when RAG is on.</p>
420
+ </div>
421
  <label class="toggle-row">
422
  <span>Use ResearchMind RAG</span>
423
  <input id="debug-use-rag" type="checkbox" />
 
445
  <div id="debug-chat-messages" class="research-chat-messages debug-chat-messages">
446
  <p class="research-chat-empty">Send a message to test the active local model.</p>
447
  </div>
448
+ <div class="coach-debug-compose">
449
+ <label class="field">
450
+ <span>Message</span>
451
+ <textarea id="debug-message" class="input" rows="2" placeholder="Hello, model…"></textarea>
452
+ </label>
453
+ <button type="button" id="btn-debug-send" class="btn btn-primary btn-block">Send</button>
454
+ </div>
455
  <details class="studio-debug-trace" id="debug-trace-details">
456
  <summary>Debug trace</summary>
457
  <div id="debug-trace-panel"></div>
apps/gradio-space/static/studio/studio.css CHANGED
@@ -960,11 +960,288 @@ body {
960
 
961
  .workspace[data-view="voice"] .col-research,
962
  .workspace[data-view="voice"] .col-slides { display: none; }
963
- .workspace[data-view="voice"] { grid-template-columns: 1fr; max-width: 520px; margin-left: auto; margin-right: auto; }
964
 
965
  .workspace[data-view="coach"] .col-research,
966
  .workspace[data-view="coach"] .col-slides { display: none; }
967
- .workspace[data-view="coach"] { grid-template-columns: 1fr; max-width: 520px; margin-left: auto; margin-right: auto; }
968
 
969
  @media (max-width: 768px) {
970
  :root { --sidebar-w: 0px; }
@@ -1229,9 +1506,12 @@ body {
1229
  }
1230
 
1231
  .coach-presets {
1232
- margin-top: 0.5rem;
1233
  }
1234
 
 
 
 
1235
  .workspace[data-view="debug"] .col-research,
1236
  .workspace[data-view="debug"] .col-slides,
1237
  .workspace[data-view="debug"] .col-studio { display: none; }
 
960
 
961
  .workspace[data-view="voice"] .col-research,
962
  .workspace[data-view="voice"] .col-slides { display: none; }
963
+
964
+ .workspace[data-view="voice"] .col-debug,
965
+ .workspace[data-view="voice"] .view-coach-only { display: none; }
966
+
967
+ .view-voice-only { display: none; }
968
+
969
+ .workspace[data-view="voice"] {
970
+ grid-template-columns: minmax(0, 1fr);
971
+ max-width: 1280px;
972
+ gap: 1.25rem;
973
+ }
974
+
975
+ .workspace[data-view="voice"] .col-studio {
976
+ grid-column: 1 / -1;
977
+ width: 100%;
978
+ min-width: 0;
979
+ }
980
+
981
+ .workspace[data-view="voice"] .voice-layout {
982
+ display: grid;
983
+ grid-template-columns: minmax(260px, 0.78fr) minmax(0, 1.22fr);
984
+ gap: 1.25rem;
985
+ align-items: start;
986
+ width: 100%;
987
+ }
988
+
989
+ .workspace[data-view="voice"] .voice-rail {
990
+ display: flex;
991
+ flex-direction: column;
992
+ gap: 1rem;
993
+ min-width: 0;
994
+ }
995
+
996
+ .workspace[data-view="voice"] .voice-main {
997
+ min-width: 0;
998
+ }
999
+
1000
+ .workspace[data-view="voice"] .voice-main-card {
1001
+ display: flex;
1002
+ flex-direction: column;
1003
+ }
1004
+
1005
+ .workspace[data-view="voice"] .voice-compose {
1006
+ display: flex;
1007
+ flex-direction: column;
1008
+ gap: 0.5rem;
1009
+ }
1010
+
1011
+ .workspace[data-view="voice"] .voice-compose .field {
1012
+ margin: 0;
1013
+ }
1014
+
1015
+ .workspace[data-view="voice"] .voice-compose textarea {
1016
+ min-height: 3.25rem;
1017
+ resize: vertical;
1018
+ }
1019
+
1020
+ .workspace[data-view="voice"] .voice-rail .voice-mode-cards {
1021
+ flex-direction: row;
1022
+ flex-wrap: wrap;
1023
+ gap: 0.35rem;
1024
+ margin-bottom: 0.75rem;
1025
+ }
1026
+
1027
+ .workspace[data-view="voice"] .voice-rail .voice-mode-cards .mode-card {
1028
+ flex: 1 1 calc(33.333% - 0.35rem);
1029
+ text-align: center;
1030
+ justify-content: center;
1031
+ min-width: 0;
1032
+ padding-left: 0.5rem;
1033
+ padding-right: 0.5rem;
1034
+ }
1035
+
1036
+ .workspace[data-view="voice"] .voice-rail-controls .voice-topic-wrap {
1037
+ margin: 0 0 0.75rem;
1038
+ }
1039
+
1040
+ .workspace[data-view="voice"] .voice-rag-sources {
1041
+ margin: 0;
1042
+ }
1043
+
1044
+ .workspace[data-view="voice"] .voice-rag-sources summary {
1045
+ cursor: pointer;
1046
+ font-weight: 600;
1047
+ font-size: 0.82rem;
1048
+ }
1049
+
1050
+ .workspace[data-view="voice"] .voice-chat-messages {
1051
+ min-height: 160px;
1052
+ max-height: min(260px, 32vh);
1053
+ margin: 0 0 0.75rem;
1054
+ }
1055
+
1056
+ .workspace[data-view="voice"] .voice-input-toolbar {
1057
+ padding: 0.65rem 0.75rem;
1058
+ border: 1px solid var(--outline-variant);
1059
+ border-radius: var(--radius-lg);
1060
+ background: var(--surface-container-low);
1061
+ margin-bottom: 0.65rem;
1062
+ }
1063
+
1064
+ .workspace[data-view="voice"] .voice-recording-row {
1065
+ margin: 0;
1066
+ }
1067
+
1068
+ .workspace[data-view="voice"] .voice-record-status {
1069
+ margin: 0.35rem 0 0;
1070
+ min-height: 1.1rem;
1071
+ }
1072
+
1073
+ .workspace[data-view="voice"] .voice-send-row {
1074
+ display: grid;
1075
+ grid-template-columns: 1fr 1fr;
1076
+ gap: 0.5rem;
1077
+ margin-bottom: 0.35rem;
1078
+ }
1079
+
1080
+ .workspace[data-view="voice"] .voice-card-head {
1081
+ margin-bottom: 0.85rem;
1082
+ }
1083
+
1084
+ .workspace[data-view="voice"] .voice-card-head .section-label {
1085
+ margin-bottom: 0.35rem;
1086
+ }
1087
+
1088
+ .voice-card-desc {
1089
+ margin: 0;
1090
+ font-size: 0.84rem;
1091
+ line-height: 1.45;
1092
+ color: var(--secondary);
1093
+ }
1094
+
1095
+ @media (max-width: 960px) {
1096
+ .workspace[data-view="voice"] .voice-layout {
1097
+ grid-template-columns: 1fr;
1098
+ max-width: 640px;
1099
+ margin-left: auto;
1100
+ margin-right: auto;
1101
+ }
1102
+
1103
+ .workspace[data-view="voice"] .voice-rail .voice-mode-cards {
1104
+ flex-direction: column;
1105
+ }
1106
+
1107
+ .workspace[data-view="voice"] .voice-rail .voice-mode-cards .mode-card {
1108
+ flex: 1 1 auto;
1109
+ text-align: left;
1110
+ justify-content: space-between;
1111
+ }
1112
+
1113
+ .workspace[data-view="voice"] .voice-send-row {
1114
+ grid-template-columns: 1fr;
1115
+ }
1116
+ }
1117
 
1118
  .workspace[data-view="coach"] .col-research,
1119
  .workspace[data-view="coach"] .col-slides { display: none; }
1120
+
1121
+ .workspace[data-view="coach"] .view-voice-only { display: none; }
1122
+
1123
+ .workspace[data-view="slides"] .col-studio,
1124
+ .workspace[data-view="research"] .col-debug { display: none; }
1125
+
1126
+ .workspace[data-view="coach"] {
1127
+ grid-template-columns: minmax(0, 1.05fr) minmax(0, 0.95fr);
1128
+ max-width: 1280px;
1129
+ gap: 1.25rem;
1130
+ align-items: start;
1131
+ }
1132
+
1133
+ .workspace[data-view="coach"] .coach-panel-wrap,
1134
+ .workspace[data-view="coach"] .coach-debug-card {
1135
+ display: flex;
1136
+ flex-direction: column;
1137
+ }
1138
+
1139
+ .workspace[data-view="coach"] .coach-results-panel {
1140
+ flex: 1;
1141
+ min-height: 120px;
1142
+ margin-top: 0.75rem;
1143
+ overflow-y: auto;
1144
+ }
1145
+
1146
+ .workspace[data-view="coach"] .coach-results-panel:not(:empty) {
1147
+ border-top: 1px solid var(--outline-variant);
1148
+ padding-top: 0.75rem;
1149
+ }
1150
+
1151
+ .workspace[data-view="coach"] .debug-chat-messages {
1152
+ min-height: 140px;
1153
+ max-height: min(240px, 30vh);
1154
+ margin-bottom: 0.5rem;
1155
+ }
1156
+
1157
+ .workspace[data-view="coach"] .coach-debug-compose {
1158
+ padding-top: 0;
1159
+ border-top: none;
1160
+ }
1161
+
1162
+ .workspace[data-view="coach"] .coach-debug-compose textarea {
1163
+ min-height: 3.5rem;
1164
+ resize: vertical;
1165
+ }
1166
+
1167
+ .workspace[data-view="coach"] .coach-debug-card .studio-debug-trace {
1168
+ flex-shrink: 0;
1169
+ margin-top: 0.5rem;
1170
+ }
1171
+
1172
+ .workspace[data-view="coach"] .coach-debug-card .toggle-row,
1173
+ .workspace[data-view="coach"] .coach-debug-card .debug-rag-scope {
1174
+ flex-shrink: 0;
1175
+ }
1176
+
1177
+ .view-coach-only { display: none; }
1178
+ .workspace[data-view="coach"] .view-coach-only:not(.coach-panel-wrap) { display: block; }
1179
+ .workspace[data-view="coach"] .view-debug-only { display: none; }
1180
+
1181
+ .coach-card-head {
1182
+ margin-bottom: 0.85rem;
1183
+ }
1184
+
1185
+ .coach-card-head .section-label {
1186
+ margin-bottom: 0.35rem;
1187
+ }
1188
+
1189
+ .coach-card-desc {
1190
+ margin: 0;
1191
+ font-size: 0.84rem;
1192
+ line-height: 1.45;
1193
+ color: var(--secondary);
1194
+ }
1195
+
1196
+ .coach-capture-row {
1197
+ display: grid;
1198
+ grid-template-columns: minmax(0, 1.2fr) minmax(0, 1fr);
1199
+ gap: 0.75rem;
1200
+ align-items: start;
1201
+ margin-bottom: 0.75rem;
1202
+ padding: 0.75rem;
1203
+ border: 1px solid var(--outline-variant);
1204
+ border-radius: var(--radius-lg);
1205
+ background: var(--surface-container-low);
1206
+ }
1207
+
1208
+ .coach-recording-row {
1209
+ margin: 0;
1210
+ }
1211
+
1212
+ .coach-record-status {
1213
+ margin: 0.35rem 0 0;
1214
+ min-height: 1.25rem;
1215
+ }
1216
+
1217
+ .coach-upload-field {
1218
+ margin: 0;
1219
+ }
1220
+
1221
+ .coach-upload-field input[type="file"] {
1222
+ font-size: 0.78rem;
1223
+ }
1224
+
1225
+ .coach-voiceout-toggle {
1226
+ margin: 0.75rem 0;
1227
+ }
1228
+
1229
+ .coach-analyze-btn {
1230
+ margin-top: 0.25rem;
1231
+ }
1232
+
1233
+ @media (max-width: 960px) {
1234
+ .workspace[data-view="coach"] {
1235
+ grid-template-columns: 1fr;
1236
+ max-width: 640px;
1237
+ margin-left: auto;
1238
+ margin-right: auto;
1239
+ }
1240
+
1241
+ .coach-capture-row {
1242
+ grid-template-columns: 1fr;
1243
+ }
1244
+ }
1245
 
1246
  @media (max-width: 768px) {
1247
  :root { --sidebar-w: 0px; }
 
1506
  }
1507
 
1508
  .coach-presets {
1509
+ margin-top: 0;
1510
  }
1511
 
1512
+ .workspace[data-view="debug"] .view-coach-only { display: none; }
1513
+ .workspace[data-view="debug"] .view-debug-only { display: block; }
1514
+
1515
  .workspace[data-view="debug"] .col-research,
1516
  .workspace[data-view="debug"] .col-slides,
1517
  .workspace[data-view="debug"] .col-studio { display: none; }
models.yaml CHANGED
@@ -67,9 +67,3 @@ models:
67
  backend: transformers
68
  model_id: ./models/finetuned/minicpm5-1b-lora-merged
69
  trust_remote_code: true
70
-
71
- jepa-ensemble-lesson:
72
- label: JEPA ensemble (LLM + emb + JEPA) lesson pretrain
73
- backend: transformers
74
- model_id: ./models/ensemble/jepa-lesson-pretrain
75
- trust_remote_code: true
 
67
  backend: transformers
68
  model_id: ./models/finetuned/minicpm5-1b-lora-merged
69
  trust_remote_code: true
 
 
 
 
 
 
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ ffmpeg
2
+ libsndfile1
pyproject.toml CHANGED
@@ -7,7 +7,6 @@ requires-python = ">=3.12"
7
  dependencies = [
8
  "agent",
9
  "echocoach",
10
- "ensemble",
11
  "gradio-space",
12
  "inference",
13
  "researchmind",
@@ -23,11 +22,6 @@ finetune = [
23
  "datasets>=3.0.0",
24
  "peft>=0.14.0",
25
  ]
26
- ensemble = [
27
- "accelerate>=1.2.0",
28
- "peft>=0.14.0",
29
- "transformers>=5.7.0",
30
- ]
31
  evals = [
32
  "slm-evals",
33
  ]
@@ -39,14 +33,12 @@ lm-eval = [
39
  members = [
40
  "apps/*",
41
  "libs/*",
42
- "research/ensemble",
43
  "research/evals",
44
  ]
45
 
46
  [tool.uv.sources]
47
  agent = { workspace = true }
48
  echocoach = { workspace = true }
49
- ensemble = { workspace = true }
50
  gradio-space = { workspace = true }
51
  inference = { workspace = true }
52
  researchmind = { workspace = true }
 
7
  dependencies = [
8
  "agent",
9
  "echocoach",
 
10
  "gradio-space",
11
  "inference",
12
  "researchmind",
 
22
  "datasets>=3.0.0",
23
  "peft>=0.14.0",
24
  ]
 
 
 
 
 
25
  evals = [
26
  "slm-evals",
27
  ]
 
33
  members = [
34
  "apps/*",
35
  "libs/*",
 
36
  "research/evals",
37
  ]
38
 
39
  [tool.uv.sources]
40
  agent = { workspace = true }
41
  echocoach = { workspace = true }
 
42
  gradio-space = { workspace = true }
43
  inference = { workspace = true }
44
  researchmind = { workspace = true }
requirements.txt ADDED
@@ -0,0 +1,32 @@
1
+ # Workspace packages (HF clones full repo)
2
+ -e ./libs/inference
3
+ -e ./libs/researchmind
4
+ -e ./libs/agent
5
+ -e ./libs/echocoach[piper,whisper]
6
+ -e ./apps/gradio-space
7
+
8
+ # Pinned runtime deps (do not pin gradio, spaces, or huggingface_hub — HF preinstalls them)
9
+ accelerate==1.13.0
10
+ torch==2.12.0
11
+ torchvision==0.27.0
12
+ transformers==5.10.2
13
+ peft==0.19.1
14
+ llama-cpp-python==0.3.26
15
+ sentence-transformers==5.5.1
16
+ pydantic>=2.0.0
17
+ pyyaml>=6.0.2
18
+ pillow>=10.0.0
19
+ python-pptx>=1.0.0
20
+ python-docx>=1.1.0
21
+ httpx>=0.28.0
22
+ numpy>=2.0.0
23
+ ddgs==9.14.4
24
+ googlesearch-python>=1.3.0
25
+ pypdf>=5.0.0
26
+ trafilatura==2.1.0
27
+ matplotlib==3.11.0
28
+ soundfile>=0.12.0
29
+ sounddevice>=0.5.0
30
+ librosa==0.11.0
31
+ piper-tts==1.4.2
32
+ pywhispercpp==1.5.0
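The file above mixes three kinds of entries: editable workspace installs (`-e ./libs/...`), exact pins (`==`), and lower bounds (`>=`). As a rough sketch under stated assumptions, the snippet below groups requirement lines by kind so the pinning policy is easy to audit. The function names `classify` and `summarize` are hypothetical, the sample is only a subset of the lines shown in the diff, and the parsing is deliberately naive (it does not handle environment markers, URLs, or hashes).

```python
import re
from collections import defaultdict

# A subset of the lines from the requirements.txt hunk, for illustration only.
SAMPLE = """\
# Workspace packages (HF clones full repo)
-e ./libs/inference
-e ./libs/echocoach[piper,whisper]

torch==2.12.0
transformers==5.10.2
pydantic>=2.0.0
numpy>=2.0.0
"""

RANGE_OPERATOR = re.compile(r">=|<=|~=|!=|<|>")


def classify(line):
    """Return the kind of a requirement line, or None for blanks and comments."""
    text = line.split("#", 1)[0].strip()  # drop trailing or full-line comments
    if not text:
        return None
    if text.startswith("-e "):
        return "editable"
    if "==" in text:
        return "pinned"
    if RANGE_OPERATOR.search(text):
        return "ranged"
    return "unconstrained"


def summarize(text):
    """Group the non-empty requirement lines of `text` by kind."""
    groups = defaultdict(list)
    for line in text.splitlines():
        kind = classify(line)
        if kind is not None:
            groups[kind].append(line.strip())
    return dict(groups)


summary = summarize(SAMPLE)
for kind, lines in summary.items():
    print(f"{kind}: {len(lines)}")
```

Lines classified as `ranged` are the ones whose resolved version can drift between Space rebuilds, which is the practical reason to review them separately from the exact pins.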
research/README.md CHANGED
@@ -1,27 +1,25 @@
1
  # Research
2
 
3
- Experimental code for **fine-tuning**, **ensemble architectures**, and **agentic benchmarks**. Nothing here is wired into the Gradio Lesson Agent by default — use it to train models, probe JEPA/world-model ideas, and score checkpoints against public benchmarks.
4
 
5
  | Path | Purpose |
6
  | ---- | ------- |
7
  | [`finetune.py`](finetune.py) | LoRA / QLoRA / full fine-tune on chat or instruction data |
8
- | [`ensemble/`](ensemble/) | JEPA + world-model ensemble experiments (uv package `ensemble`) |
9
  | [`evals/`](evals/) | SLM agentic benchmark suite — BFCL, τ-bench, GAIA, SWE-bench (uv package `slm-evals`) |
10
- | [`data/`](data/) | Shared JSONL datasets for finetune and ensemble harnesses |
11
 
12
  ## Quick links
13
 
14
  - **[USAGE.md](USAGE.md)** — install groups, commands, and typical workflows
15
  - **[docs/overview.md](docs/overview.md)** — how the pieces fit together
16
- - **[ensemble/README.md](ensemble/README.md)** — ensemble smoke tests and harnesses
17
  - **[evals/USAGE.md](evals/USAGE.md)** — benchmark CLI, configs, and results
18
  - **[evals/docs/benchmarks.md](evals/docs/benchmarks.md)** — what each benchmark measures
19
 
20
  ## Install (from repo root)
21
 
22
  ```bash
23
- # Everything you need for research scripts
24
- uv sync --group finetune --group ensemble --group evals
25
  ```
26
 
27
  Individual groups:
@@ -29,8 +27,8 @@ Individual groups:
29
  | Group | Command | Enables |
30
  | ----- | ------- | ------- |
31
  | `finetune` | `uv sync --group finetune` | `research/finetune.py` (LoRA, QLoRA, merge) |
32
- | `ensemble` | `uv sync --group ensemble` | `research/ensemble/` package |
33
  | `evals` | `uv sync --group evals` | `research/evals/` package (`slm-benchmark`) |
 
34
 
35
  ## Typical workflow
36
 
@@ -40,9 +38,7 @@ research/data/education-lesson-chat.jsonl
40
 
41
  research/finetune.py ──► models/finetuned/<preset>-lora/
42
 
43
- ──► research/evals/ (BFCL, τ-bench, GAIA, SWE-bench)
44
-
45
- └──► research/ensemble/ (JEPA / world-model ablations)
46
  ```
47
 
48
  See [USAGE.md](USAGE.md) for copy-paste commands.
 
1
  # Research
2
 
3
+ Experimental code for **fine-tuning** and **agentic benchmarks**. Nothing here is wired into the Gradio Lesson Agent by default — use it to train models and score checkpoints against public benchmarks.
4
 
5
  | Path | Purpose |
6
  | ---- | ------- |
7
  | [`finetune.py`](finetune.py) | LoRA / QLoRA / full fine-tune on chat or instruction data |
 
8
  | [`evals/`](evals/) | SLM agentic benchmark suite — BFCL, τ-bench, GAIA, SWE-bench (uv package `slm-evals`) |
9
+ | [`data/`](data/) | Shared JSONL datasets for finetune and evals |
10
 
11
  ## Quick links
12
 
13
  - **[USAGE.md](USAGE.md)** — install groups, commands, and typical workflows
14
  - **[docs/overview.md](docs/overview.md)** — how the pieces fit together
 
15
  - **[evals/USAGE.md](evals/USAGE.md)** — benchmark CLI, configs, and results
16
  - **[evals/docs/benchmarks.md](evals/docs/benchmarks.md)** — what each benchmark measures
17
 
18
  ## Install (from repo root)
19
 
20
  ```bash
21
+ # All research tooling
22
+ uv sync --group finetune --group evals --group lm-eval
23
  ```
24
 
25
  Individual groups:
 
27
  | Group | Command | Enables |
28
  | ----- | ------- | ------- |
29
  | `finetune` | `uv sync --group finetune` | `research/finetune.py` (LoRA, QLoRA, merge) |
 
30
  | `evals` | `uv sync --group evals` | `research/evals/` package (`slm-benchmark`) |
31
+ | `lm-eval` | `uv sync --group lm-eval` | `slm-lm-eval` CLI (GSM8K, ARC, HellaSwag, …) |
32
 
33
  ## Typical workflow
34
 
 
38
 
39
  research/finetune.py ──► models/finetuned/<preset>-lora/
40
 
41
+ ──► research/evals/ (BFCL, τ-bench, GAIA, SWE-bench, lm-eval)
 
 
42
  ```
43
 
44
  See [USAGE.md](USAGE.md) for copy-paste commands.
research/USAGE.md CHANGED
@@ -1,24 +1,23 @@
1
  # Research usage
2
 
3
- How to run fine-tuning, ensemble experiments, and agentic benchmarks under `research/`. All commands assume the **repo root** as the working directory unless noted.
4
 
5
  The Lesson Agent app lives in `apps/gradio-space/` — see root [USAGE.md](../USAGE.md). Research code is optional and isolated here.
6
 
7
  ## Prerequisites
8
 
9
  - [uv](https://docs.astral.sh/uv/) and Python 3.12
10
- - GPU recommended for real-model runs (CPU works for smoke tests and `tiny` backends)
11
  - Hugging Face Hub access for model downloads and some benchmark datasets
12
 
13
  ## Install dependency groups
14
 
15
  ```bash
16
  # All research tooling
17
- uv sync --group finetune --group ensemble --group evals --group lm-eval
18
 
19
  # Or one at a time
20
  uv sync --group finetune
21
- uv sync --group ensemble
22
  uv sync --group evals
23
  uv sync --group lm-eval
24
  ```
@@ -26,7 +25,6 @@ uv sync --group lm-eval
26
  | Group | Package / script | What it adds |
27
  | ----- | ---------------- | ------------ |
28
  | `finetune` | `research/finetune.py` | `peft`, `datasets`, `bitsandbytes` (QLoRA) |
29
- | `ensemble` | `ensemble` workspace member | JEPA / world-model ensemble + harnesses |
30
  | `evals` | `slm-evals` workspace member | `slm-benchmark` CLI |
31
  | `lm-eval` | `slm-evals[lm-eval]` | `slm-lm-eval` CLI (GSM8K, ARC, HellaSwag, …) |
32
 
@@ -94,83 +92,7 @@ Training writes to `<out>/` (default `./models/finetuned/<preset>-<mode>/`):
94
 
95
  ---
96
 
97
- ## 2. Ensemble experiments (`research/ensemble/`)
98
-
99
- JEPA and world-model ensemble prototypes: small LLM + embedding memory + latent predictors + energy-based draft selection. **Not connected to the Gradio app.**
100
-
101
- Install: `uv sync --group ensemble`
102
-
103
- ### Tier 1 — CPU smoke (no Hub download)
104
-
105
- ```bash
106
- uv run --package ensemble python -m ensemble.jepa_ensemble tiny
107
- uv run --package ensemble python -m ensemble.world_ensemble tiny
108
- bash research/ensemble/scripts/smoke.sh
109
- ```
110
-
111
- ### Tier 2 — Real small model
112
-
113
- ```bash
114
- uv run --package ensemble python -m ensemble.jepa_ensemble Qwen/Qwen2.5-0.5B-Instruct
115
- uv run --package ensemble python -m ensemble.world_ensemble Qwen/Qwen2.5-0.5B-Instruct
116
- ```
117
-
118
- ### Pretrain + save (LLM + emb + JEPA)
119
-
120
- ```bash
121
- # Default LLM: ENSEMBLE_LLM → LLM_PATH → BASE → MODEL_ID → ACTIVE_MODEL (models.yaml)
122
- uv run --package ensemble ensemble-pretrain --steps 200
123
-
124
- # Or override
125
- uv run --package ensemble ensemble-pretrain \
126
- --llm Qwen/Qwen2.5-0.5B-Instruct \
127
- --steps 200
128
-
129
- # Benchmark saved ensemble with slm-evals (compare to base HF model)
130
- uv run --package slm-evals slm-benchmark \
131
- --model ./models/ensemble/jepa-lesson-pretrain \
132
- --model-type ensemble \
133
- --benchmarks bfcl tau_bench --max-samples 20
134
- ```
135
-
136
- Checkpoint files: `manifest.json`, `aux.pt`, `llm/` (PEFT adapters), optional `store.pt`.
137
-
138
- ### Tier 3 — Benchmark harnesses
139
-
140
- Uses `research/data/benchmark-qa.jsonl` (questions) and `benchmark-kb.jsonl` (retrieval snippets).
141
-
142
- ```bash
143
- # JEPA track — toy
144
- uv run --package ensemble python -m ensemble.eval.jepa_harness \
145
- --llm tiny --toy --limit 20 --n_drafts 8
146
-
147
- # JEPA track — education QA
148
- uv run --package ensemble python -m ensemble.eval.jepa_harness \
149
- --llm Qwen/Qwen2.5-0.5B-Instruct \
150
- --qa research/data/benchmark-qa.jsonl \
151
- --kb research/data/benchmark-kb.jsonl \
152
- --limit 50 --n_drafts 8
153
-
154
- # World-model track
155
- uv run --package ensemble python -m ensemble.eval.world_harness \
156
- --llm tiny --toy --limit 20 --n_drafts 8
157
- ```
158
-
159
- More detail: [ensemble/README.md](ensemble/README.md), [docs/overview.md](docs/overview.md).
160
-
161
- ### Legacy shims
162
-
163
- Top-level files re-export the package for old scripts:
164
-
165
- - `research/llm_emb_jepa_ensemble_pluggable.py` → `ensemble.jepa_ensemble`
166
- - `research/world_model_ensemble.py` → `ensemble.world_ensemble`
167
- - `research/eval_harness.py` → `ensemble.eval.jepa_harness`
168
-
169
- Prefer `uv run --package ensemble python -m ensemble.<module>`.
170
-
171
- ---
172
-
173
- ## 3. Agentic benchmarks (`research/evals/`)
174
 
175
  Evaluate local HuggingFace checkpoints on BFCL, τ-bench, GAIA, and SWE-bench Verified.
176
 
@@ -192,9 +114,9 @@ Full reference: [evals/USAGE.md](evals/USAGE.md).
192
 
193
  ---
194
 
195
- ## 4. Academic benchmarks (`slm-lm-eval`)
196
 
197
- Standard lm-evaluation-harness tasks (ARC, HellaSwag, GSM8K, …) for base presets, LoRA adapters, merged checkpoints, and ensemble manifests.
198
 
199
  Install: `uv sync --group lm-eval`
200
 
@@ -222,12 +144,6 @@ uv run --package slm-evals slm-lm-eval \
222
  --preset minicpm5-1b-lesson-lora \
223
  --experiment-name minicpm5-1b-lora__v1 \
224
  --compare-to results/lm_eval/minicpm5-1b__baseline/results.json
225
-
226
- # Ensemble checkpoint
227
- uv run --package slm-evals slm-lm-eval \
228
- --config research/evals/configs/lm_eval_smoke.yaml \
229
- --model ./models/ensemble/jepa-lesson-pretrain \
230
- --experiment-name ensemble-jepa__lm-eval
231
  ```
232
 
233
  Post-training hook:
@@ -248,8 +164,8 @@ Full reference: [evals/USAGE.md](evals/USAGE.md#lm-evaluation-harness-slm-lm-eva
248
  | File | Used by | Format |
249
  | ---- | ------- | ------ |
250
  | `education-lesson-chat.jsonl` | `finetune.py` default | Chat messages for lesson agent |
251
- | `benchmark-qa.jsonl` | Ensemble harnesses | `question`, `answer`, `domain` |
252
- | `benchmark-kb.jsonl` | Ensemble harnesses | Retrieval snippets for memory routing |
253
 
254
  ---
255
 
@@ -283,18 +199,12 @@ Full reference: [evals/USAGE.md](evals/USAGE.md#lm-evaluation-harness-slm-lm-eva
283
  --compare-to results/lm_eval/minicpm5-1b__baseline/results.json
284
  ```
285
 
286
- 5. **Optional** — probe ensemble ideas on the same QA/KB files:
287
- ```bash
288
- bash research/ensemble/scripts/smoke.sh
289
- ```
290
-
291
  ### Verification checklist
292
 
293
  - Use the **same** lm-eval YAML (`tasks`, `num_fewshot`, `limit`, `seed`) for baseline and candidate runs.
294
  - Compare lm-eval `results.json` files with `--compare-to`; do not compare `training_results.json` `result_score` to lm-eval accuracy.
295
  - For LoRA checkpoints, prefer `--preset minicpm5-1b-lesson-lora` (base + adapter) over passing the adapter dir alone to `--model`.
296
  - Report mean ± std only after multiple training seeds; single-seed deltas are indicative, not conclusive.
297
- - Ensemble `loglikelihood` tasks score the underlying LLM head; generative tasks (`gsm8k`) use the full JEPA+RAG stack.
298
 
299
  ---
300
 
@@ -302,7 +212,6 @@ Full reference: [evals/USAGE.md](evals/USAGE.md#lm-evaluation-harness-slm-lm-eva
302
 
303
  | Symptom | Fix |
304
  | ------- | --- |
305
- | `No module named 'ensemble'` | `uv sync --group ensemble` |
306
  | `slm-benchmark: command not found` | `uv sync --group evals` |
307
  | `slm-lm-eval: command not found` | `uv sync --group lm-eval` |
308
  | CUDA OOM during finetune | Use `--mode qlora` or reduce batch size in script args |
 
1
  # Research usage
2
 
3
+ How to run fine-tuning and agentic benchmarks under `research/`. All commands assume the **repo root** as the working directory unless noted.
4
 
5
  The Lesson Agent app lives in `apps/gradio-space/` — see root [USAGE.md](../USAGE.md). Research code is optional and isolated here.
6
 
7
  ## Prerequisites
8
 
9
  - [uv](https://docs.astral.sh/uv/) and Python 3.12
10
+ - GPU recommended for real-model runs (CPU works for smoke tests)
11
  - Hugging Face Hub access for model downloads and some benchmark datasets
12
 
13
  ## Install dependency groups
14
 
15
  ```bash
16
  # All research tooling
17
+ uv sync --group finetune --group evals --group lm-eval
18
 
19
  # Or one at a time
20
  uv sync --group finetune
 
21
  uv sync --group evals
22
  uv sync --group lm-eval
23
  ```
 
25
  | Group | Package / script | What it adds |
26
  | ----- | ---------------- | ------------ |
27
  | `finetune` | `research/finetune.py` | `peft`, `datasets`, `bitsandbytes` (QLoRA) |
 
28
  | `evals` | `slm-evals` workspace member | `slm-benchmark` CLI |
29
  | `lm-eval` | `slm-evals[lm-eval]` | `slm-lm-eval` CLI (GSM8K, ARC, HellaSwag, …) |
30
 
 
92
 
93
  ---
94
 
95
+ ## 2. Agentic benchmarks (`research/evals/`)
96
 
97
  Evaluate local HuggingFace checkpoints on BFCL, τ-bench, GAIA, and SWE-bench Verified.
98
 
 
114
 
115
  ---
116
 
117
+ ## 3. Academic benchmarks (`slm-lm-eval`)
118
 
119
+ Standard lm-evaluation-harness tasks (ARC, HellaSwag, GSM8K, …) for base presets, LoRA adapters, and merged checkpoints.
120
 
121
  Install: `uv sync --group lm-eval`
122
 
 
144
  --preset minicpm5-1b-lesson-lora \
145
  --experiment-name minicpm5-1b-lora__v1 \
146
  --compare-to results/lm_eval/minicpm5-1b__baseline/results.json
 
 
 
 
 
 
147
  ```
148
 
149
  Post-training hook:
 
164
  | File | Used by | Format |
165
  | ---- | ------- | ------ |
166
  | `education-lesson-chat.jsonl` | `finetune.py` default | Chat messages for lesson agent |
167
+ | `benchmark-qa.jsonl` | Optional domain QA evals | `question`, `answer`, `domain` |
168
+ | `benchmark-kb.jsonl` | Optional retrieval snippets | KB entries for domain QA |
169
 
170
  ---
171
 
 
199
  --compare-to results/lm_eval/minicpm5-1b__baseline/results.json
200
  ```
201
 
 
 
 
 
 
202
  ### Verification checklist
203
 
204
  - Use the **same** lm-eval YAML (`tasks`, `num_fewshot`, `limit`, `seed`) for baseline and candidate runs.
205
  - Compare lm-eval `results.json` files with `--compare-to`; do not compare `training_results.json` `result_score` to lm-eval accuracy.
206
  - For LoRA checkpoints, prefer `--preset minicpm5-1b-lesson-lora` (base + adapter) over passing the adapter dir alone to `--model`.
207
  - Report mean ± std only after multiple training seeds; single-seed deltas are indicative, not conclusive.
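The last checklist item can be made concrete with a few lines of standard-library Python. This is a minimal sketch, not the repository's tooling: the function name `summarize_seeds` is hypothetical, and the scores are placeholder numbers rather than results from any run.

```python
from statistics import mean, stdev


def summarize_seeds(scores):
    """Return (mean, sample standard deviation) for per-seed scores.

    Refuses a single score, because one seed gives no estimate of spread.
    """
    if len(scores) < 2:
        raise ValueError("need at least two seeds to report a standard deviation")
    return mean(scores), stdev(scores)


# Placeholder accuracies for three training seeds (illustrative values only).
baseline_scores = [0.40, 0.42, 0.44]
candidate_scores = [0.45, 0.47, 0.49]

for label, scores in (("baseline", baseline_scores), ("candidate", candidate_scores)):
    avg, spread = summarize_seeds(scores)
    print(f"{label}: {avg:.3f} ± {spread:.3f} (n={len(scores)})")
```

`stdev` is the sample standard deviation (n - 1 in the denominator), which is the usual choice when the seeds are a small sample of possible training runs.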
 
208
 
209
  ---
210
 
 
212
 
213
  | Symptom | Fix |
214
  | ------- | --- |
 
215
  | `slm-benchmark: command not found` | `uv sync --group evals` |
216
  | `slm-lm-eval: command not found` | `uv sync --group lm-eval` |
217
  | CUDA OOM during finetune | Use `--mode qlora` or reduce batch size in script args |
research/docs/overview.md CHANGED
@@ -13,13 +13,12 @@ small-model-hackathon/
13
  └── research/ ← experiments (this tree)
14
  ├── finetune.py
15
  ├── data/
16
- ├── ensemble/ ← uv workspace package
17
  └── evals/ ← uv workspace package
18
  ```
19
 
20
- Research code is a **uv workspace sibling** of `apps/*` and `libs/*`. Root `pyproject.toml` declares optional dependency groups (`finetune`, `ensemble`, `evals`) so the Docker Space image does not need to install torch-heavy extras unless you opt in locally.
21
 
22
- ## Three tracks
23
 
24
  ### Fine-tuning
25
 
@@ -27,38 +26,12 @@ Research code is a **uv workspace sibling** of `apps/*` and `libs/*`. Root `pypr
27
 
28
  Outputs land in `models/finetuned/` — you can register a new preset in `models.yaml` pointing at merged weights for the **Well-Tuned** hackathon badge.
29
 
30
- ### Ensemble (JEPA / world model)
31
 
32
- `research/ensemble/` explores a modular stack inspired by LeCun-style architectures:
33
 
34
- ```text
35
- Input ──► Embedder + VectorStore (retrieval memory)
36
-
37
-
38
- JEPA encoder ──► latent state
39
-
40
- ├──► World model (multi-step latent rollout)
41
-
42
- └──► Energy model (scores LLM draft continuations)
43
-
44
-
45
- Small LLM generates N drafts → pick lowest energy
46
- ```
47
-
48
- Two entry ensembles:
49
-
50
- | Module | File | Critic |
51
- | ------ | ---- | ------ |
52
- | JEPA track | `ensemble.jepa_ensemble` | JEPA latent prediction |
53
- | World track | `ensemble.world_ensemble` | Energy model over world-model rollouts |
54
-
55
- `TinyBackend` runs on CPU with random weights for smoke tests. `HFBackend` loads real Hub models via `transformers` + optional `peft` LoRA banks.
56
-
57
- Eval harnesses (`ensemble.eval.jepa_harness`, `ensemble.eval.world_harness`) measure draft-selection accuracy on `research/data/benchmark-qa.jsonl` with optional KB retrieval from `benchmark-kb.jsonl`.
58
-
59
- ### Agentic evals
60
-
61
- `research/evals/` (`slm-evals` package) scores **whole models** on public agent benchmarks — function calling, multi-turn tool use, GAIA tasks, and SWE-bench patches. This complements ensemble harnesses: evals test end-to-end model behavior; ensemble harnesses test internal selection mechanisms on a small custom QA set.
62
 
63
  ## Data flow
64
 
@@ -79,20 +52,12 @@ flowchart LR
79
  tau[tau-bench]
80
  gaia[GAIA]
81
  swe[SWE-bench]
82
- end
83
-
84
- subgraph ens [ensemble]
85
- jepa[JEPA harness]
86
- world[World harness]
87
  end
88
 
89
  lesson --> train
90
  train --> ckpt
91
  ckpt --> evals
92
- qa --> jepa
93
- kb --> jepa
94
- qa --> world
95
- kb --> world
96
  ```
97
 
98
  ## When to use which tool
@@ -101,14 +66,11 @@ flowchart LR
101
  | ---- | ---- |
102
  | Improve lesson slide quality on your data | `finetune.py` + optional eval before/after |
103
  | Compare base vs LoRA on public agent tasks | `slm-benchmark` |
104
- | Prototype latent draft selection | `ensemble` smoke → harness |
105
  | Ship in Gradio Space | `apps/gradio-space` only — wire new weights via `models.yaml` |
106
 
107
- ## Workspace packages
108
-
109
- Both subpackages are listed in root `[tool.uv.workspace] members`:
110
 
111
- - `research/ensemble` → import name `ensemble`
112
- - `research/evals` → import name `slm_evals`, CLI `slm-benchmark`
113
 
114
- Run with `uv run --package <name>` from the repo root so uv resolves workspace paths and shared lockfile versions.
 
13
  └── research/ ← experiments (this tree)
14
  ├── finetune.py
15
  ├── data/
 
16
  └── evals/ ← uv workspace package
17
  ```
18
 
19
+ Research code is a **uv workspace sibling** of `apps/*` and `libs/*`. Root `pyproject.toml` declares optional dependency groups (`finetune`, `evals`, `lm-eval`) so the Docker Space image does not need to install torch-heavy extras unless you opt in locally.
20
 
21
+ ## Two tracks
22
 
23
  ### Fine-tuning
24
 
 
26
 
27
  Outputs land in `models/finetuned/` — you can register a new preset in `models.yaml` pointing at merged weights for the **Well-Tuned** hackathon badge.
28
 
29
+ ### Agentic and academic evals
30
 
31
+ `research/evals/` (`slm-evals` package) scores **whole models** on:
32
 
33
+ - **Agentic benchmarks** — BFCL, τ-bench, GAIA, SWE-bench (`slm-benchmark`)
34
+ - **Academic benchmarks**: GSM8K, ARC, HellaSwag, etc. via lm-evaluation-harness (`slm-lm-eval`)
 
35
 
36
  ## Data flow
37
 
 
52
  tau[tau-bench]
53
  gaia[GAIA]
54
  swe[SWE-bench]
55
+ lmeval[lm-eval tasks]
 
 
 
 
56
  end
57
 
58
  lesson --> train
59
  train --> ckpt
60
  ckpt --> evals
 
 
 
 
61
  ```
62
 
63
  ## When to use which tool
 
66
  | ---- | ---- |
67
  | Improve lesson slide quality on your data | `finetune.py` + optional eval before/after |
68
  | Compare base vs LoRA on public agent tasks | `slm-benchmark` |
69
+ | Compare base vs LoRA on academic tasks | `slm-lm-eval` |
70
  | Ship in Gradio Space | `apps/gradio-space` only — wire new weights via `models.yaml` |
71
 
72
+ ## Workspace package
 
 
73
 
74
+ `research/evals` is listed in root `[tool.uv.workspace] members`. Its import name is `slm_evals`, and it provides two CLIs: `slm-benchmark` and `slm-lm-eval`.
 
75
 
76
+ Run with `uv run --package slm-evals ...` from the repo root so uv resolves workspace paths and shared lockfile versions.
research/ensemble/README.md DELETED
@@ -1,113 +0,0 @@
1
- # Ensemble research package
2
-
3
- JEPA and world-model ensemble experiments. Stays under `research/` — not wired into the Gradio agent.
4
-
5
- See also: [../USAGE.md](../USAGE.md) · [../docs/overview.md](../docs/overview.md)
6
-
7
- ## Install
8
-
9
- ```bash
10
- uv sync --group ensemble
11
- ```
12
-
13
- ## Tier 1 — Smoke (CPU, no HF download)
14
-
15
- ```bash
16
- uv run --package ensemble python -m ensemble.jepa_ensemble tiny
17
- uv run --package ensemble python -m ensemble.world_ensemble tiny
18
- bash research/ensemble/scripts/smoke.sh
19
- ```
20
-
21
- ## Tier 2 — Micro demo (real small model)
22
-
23
- ```bash
24
- uv run --package ensemble python -m ensemble.jepa_ensemble Qwen/Qwen2.5-0.5B-Instruct
25
- uv run --package ensemble python -m ensemble.world_ensemble Qwen/Qwen2.5-0.5B-Instruct
26
- ```
27
-
28
- ## Pretrain + save (LLM + emb + JEPA)
29
-
30
- Joint training writes a full checkpoint to `models/ensemble/<name>/`:
31
-
32
- ```bash
33
- # CPU smoke (tiny backend, no HF download)
34
- uv run --package ensemble ensemble-pretrain \
35
- --llm tiny --steps 50 --no-kb \
36
- --out models/ensemble/jepa-smoke
37
-
38
- # Uses ACTIVE_MODEL / BASE / LLM_PATH from .env + models.yaml by default
39
- uv run --package ensemble ensemble-pretrain \
40
- --data research/data/education-lesson-chat.jsonl \
41
- --kb research/data/benchmark-kb.jsonl \
42
- --steps 200
43
-
44
- # Override base LLM explicitly
45
- uv run --package ensemble ensemble-pretrain \
46
- --llm Qwen/Qwen2.5-0.5B-Instruct --steps 200
47
- ```
48
-
49
- Checkpoint layout: `manifest.json`, `aux.pt` (emb/jepa/bridge/router), `llm/` (PEFT adapters).
50
-
51
- Benchmark the saved ensemble with **slm-evals** (auto-detects `manifest.json`):
52
-
53
- ```bash
54
- uv run --package slm-evals slm-benchmark \
55
- --model ./models/ensemble/jepa-lesson-pretrain \
56
- --model-type ensemble \
57
- --benchmarks bfcl tau_bench --max-samples 20
58
-
59
- # Or use the template config
60
- uv run --package slm-evals slm-benchmark \
61
- --config research/evals/configs/ensemble_jepa_lesson.yaml
62
- ```
63
-
64
- Compare against a base HF model by running the same config with `model_type: hf` and `model_path: openbmb/MiniCPM5-1B`.
65
-
66
- ## Tier 3 — Benchmark
67
-
68
- ### JEPA ablation ladder
69
-
70
- ```bash
71
- # Toy (no download)
72
- uv run --package ensemble python -m ensemble.eval.jepa_harness \
73
- --llm tiny --toy --limit 20 --n_drafts 8
74
-
75
- # Education QA set
76
- uv run --package ensemble python -m ensemble.eval.jepa_harness \
77
- --llm Qwen/Qwen2.5-0.5B-Instruct \
78
- --qa research/data/benchmark-qa.jsonl \
79
- --kb research/data/benchmark-kb.jsonl \
80
- --limit 50 --n_drafts 8
81
- ```
82
-
83
- ### World-model energy selector
84
-
85
- ```bash
86
- uv run --package ensemble python -m ensemble.eval.world_harness \
87
- --llm tiny --toy --limit 20 --n_drafts 8
88
-
89
- uv run --package ensemble python -m ensemble.eval.world_harness \
90
- --llm Qwen/Qwen2.5-0.5B-Instruct \
91
- --qa research/data/benchmark-qa.jsonl \
92
- --kb research/data/benchmark-kb.jsonl \
93
- --limit 50 --n_drafts 8
94
- ```
95
-
96
- ## Layout
97
-
98
- ```
99
- research/ensemble/
100
- src/ensemble/
101
- backends.py # TinyBackend, HFBackend, TinyLLM, HFLLM
102
- memory.py # Embedder, VectorStore, Router
103
- jepa.py # JEPA latent predictor
104
- bridge.py # LLM hidden -> latent alignment
105
- world_model.py # Latent dynamics + rollout
106
- energy.py # Energy-based critic
107
- jepa_ensemble.py # Ensemble (JEPA track)
108
- world_ensemble.py # WorldEnsemble
109
- eval/
110
- metrics.py
111
- jepa_harness.py
112
- world_harness.py
113
- ```
 
research/ensemble/pyproject.toml DELETED
@@ -1,16 +0,0 @@
1
- [project]
2
- name = "ensemble"
3
- version = "0.1.0"
4
- description = "JEPA and world-model ensemble research package"
5
- readme = "README.md"
6
- requires-python = ">=3.12"
7
- dependencies = [
8
- "torch>=2.5.0",
9
- ]
10
-
11
- [project.scripts]
12
- ensemble-pretrain = "ensemble.pretrain:main"
13
-
14
- [build-system]
15
- requires = ["uv_build>=0.8.13,<0.9.0"]
16
- build-backend = "uv_build"
 
research/ensemble/scripts/smoke.sh DELETED
@@ -1,35 +0,0 @@
1
- #!/usr/bin/env bash
2
- set -euo pipefail
3
-
4
- ROOT="$(cd "$(dirname "$0")/../../.." && pwd)"
5
- cd "$ROOT"
6
-
7
- echo "== JEPA ensemble demo (tiny) =="
8
- uv run --package ensemble python -m ensemble.jepa_ensemble tiny
9
-
10
- echo ""
11
- echo "== World ensemble demo (tiny) =="
12
- uv run --package ensemble python -m ensemble.world_ensemble tiny
13
-
14
- echo ""
15
- echo "== JEPA harness (toy) =="
16
- uv run --package ensemble python -m ensemble.eval.jepa_harness \
17
- --llm tiny --toy --limit 10 --n_drafts 4
18
-
19
- echo "== Pretrain smoke + checkpoint roundtrip =="
20
- uv run --package ensemble ensemble-pretrain \
21
- --llm tiny --steps 20 --no-kb \
22
- --out models/ensemble/jepa-smoke
23
- uv run --package ensemble python -c "
24
- from ensemble.checkpoint import load_checkpoint
25
- ens = load_checkpoint('models/ensemble/jepa-smoke')
26
- print('loaded ensemble, adapters:', ens.adapter_names)
27
- "
28
-
29
- echo ""
30
- echo "== World harness (toy) =="
31
- uv run --package ensemble python -m ensemble.eval.world_harness \
32
- --llm tiny --toy --limit 10 --n_drafts 4
33
-
34
- echo ""
35
- echo "All smoke checks passed."
 
research/ensemble/src/ensemble/__init__.py DELETED
@@ -1,15 +0,0 @@
1
- """Research ensemble package: JEPA and world-model tracks."""
2
-
3
- __all__ = ["Ensemble", "WorldEnsemble"]
4
-
5
-
6
- def __getattr__(name: str):
7
- if name == "Ensemble":
8
- from ensemble.jepa_ensemble import Ensemble
9
-
10
- return Ensemble
11
- if name == "WorldEnsemble":
12
- from ensemble.world_ensemble import WorldEnsemble
13
-
14
- return WorldEnsemble
15
- raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
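Note on the removed `__init__.py`: it used a module-level `__getattr__` (PEP 562) so that `Ensemble` and `WorldEnsemble` were imported only on first access. The sketch below shows the same pattern in isolation. The module name `lazy_demo`, the helper `make_lazy_module`, and the stdlib targets (`math.sqrt`, `json.dumps`) are stand-ins chosen for illustration and are not part of the repository.

```python
import importlib
import types


def make_lazy_module(name, exports):
    """Build a module whose attributes are imported on first access.

    `exports` maps an attribute name to the module that defines it.
    This is a hypothetical helper, not code from the removed package.
    """
    module = types.ModuleType(name)

    def __getattr__(attr):
        # Called only when normal attribute lookup on the module fails.
        if attr in exports:
            target = importlib.import_module(exports[attr])
            return getattr(target, attr)
        raise AttributeError(f"module {name!r} has no attribute {attr!r}")

    module.__getattr__ = __getattr__
    module.__all__ = list(exports)
    return module


lazy = make_lazy_module("lazy_demo", {"sqrt": "math", "dumps": "json"})
print(lazy.sqrt(16.0))
print(lazy.dumps([1]))
```

The benefit in the original package was presumably that importing `ensemble` stayed cheap until one of the two tracks was actually used.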
 
research/ensemble/src/ensemble/backends.py DELETED
@@ -1,418 +0,0 @@
1
- """LLM backends: toy fallbacks and HuggingFace + LoRA loaders."""
2
-
3
- from __future__ import annotations
4
-
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
-
9
-
10
- class LLMBackend(nn.Module):
11
- """Contract for JEPA ensemble backends."""
12
-
13
- vocab_size: int
14
- hidden_size: int
15
-
16
-
17
- class HFBackend(LLMBackend):
18
- """HuggingFace causal LM with PEFT LoRA adapter bank."""
19
-
20
- def __init__(
21
- self,
22
- model_path: str,
23
- *,
24
- load_in_4bit: bool = False,
25
- lora_r: int = 16,
26
- lora_alpha: int = 32,
27
- target_modules=("q_proj", "v_proj"),
28
- device: str | None = None,
29
- torch_dtype=None,
30
- ):
31
- super().__init__()
32
- from peft import LoraConfig, get_peft_model
33
- from transformers import AutoModelForCausalLM, AutoTokenizer
34
-
35
- self.device_ = torch.device(
36
- device or ("cuda" if torch.cuda.is_available() else "cpu")
37
- )
38
-
39
- kwargs = {}
40
- if load_in_4bit:
41
- from transformers import BitsAndBytesConfig
42
-
43
- kwargs["quantization_config"] = BitsAndBytesConfig(
44
- load_in_4bit=True,
45
- bnb_4bit_compute_dtype=torch.bfloat16,
46
- bnb_4bit_quant_type="nf4",
47
- )
48
- if torch_dtype is not None:
49
- kwargs["torch_dtype"] = torch_dtype
50
-
51
- self.tokenizer = AutoTokenizer.from_pretrained(model_path)
52
- if self.tokenizer.pad_token is None:
53
- self.tokenizer.pad_token = self.tokenizer.eos_token
54
- base = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
55
- if not load_in_4bit:
56
- base.to(self.device_)
57
-
58
- for p in base.parameters():
59
- p.requires_grad_(False)
60
-
61
- self._lora_cfg = LoraConfig(
62
- r=lora_r,
63
- lora_alpha=lora_alpha,
64
- lora_dropout=0.05,
65
- target_modules=list(target_modules),
66
- task_type="CAUSAL_LM",
67
- )
68
- self.model = get_peft_model(base, self._lora_cfg, adapter_name="general")
69
- self._adapters = {"general"}
70
-
71
- self.vocab_size = self.model.config.vocab_size
72
- self.hidden_size = self.model.config.hidden_size
73
-
74
- def add_adapter(self, name: str):
75
- if name not in self._adapters:
76
- self.model.add_adapter(name, self._lora_cfg)
77
- self._adapters.add(name)
78
-
79
- def set_adapter(self, name: str):
80
- self.model.set_adapter(name)
81
-
82
- def trainable_parameters(self):
83
- return (p for p in self.model.parameters() if p.requires_grad)
84
-
85
- def forward(self, ids):
86
- out = self.model(
87
- input_ids=ids.to(self.device_), output_hidden_states=True
88
- )
89
- return out.logits, out.hidden_states[-1]
90
-
91
- @torch.no_grad()
92
- def generate(self, ids, n_new=64, temperature=0.8):
93
- gen_kwargs: dict = dict(
94
- input_ids=ids.to(self.device_),
95
- max_new_tokens=n_new,
96
- pad_token_id=self.tokenizer.pad_token_id,
97
- )
98
- if temperature <= 0:
99
- gen_kwargs["do_sample"] = False
100
- else:
101
- gen_kwargs.update(do_sample=True, temperature=temperature)
102
- out = self.model.generate(**gen_kwargs)
103
- return out
104
-
105
- def encode_text(self, text: str):
106
- return self.tokenizer(text, return_tensors="pt").input_ids.to(self.device_)
107
-
108
- def decode(self, ids):
109
- return self.tokenizer.decode(ids[0], skip_special_tokens=True)
110
-
111
- @property
112
- def device(self):
113
- return self.device_
114
-
115
-
116
- class TinyBackend(LLMBackend):
117
- """Toy transformer with LoRA adapters (no transformers dependency)."""
118
-
119
- VOCAB, D_MODEL, N_LAYERS, N_HEADS, SEQ_LEN, LORA_R = 1000, 128, 2, 4, 32, 8
120
-
121
- class _LoRALinear(nn.Module):
122
- def __init__(self, d_in, d_out, r):
123
- super().__init__()
124
- self.base = nn.Linear(d_in, d_out)
125
- self.base.weight.requires_grad_(False)
126
- self.base.bias.requires_grad_(False)
127
- self.adapters, self.active, self.r = nn.ModuleDict(), None, r
128
-
129
- def add_adapter(self, name):
130
- A = nn.Linear(self.base.in_features, self.r, bias=False)
131
- B = nn.Linear(self.r, self.base.out_features, bias=False)
132
- nn.init.zeros_(B.weight)
133
- self.adapters[name] = nn.Sequential(A, B)
134
-
135
- def forward(self, x):
136
- y = self.base(x)
137
- if self.active and self.active in self.adapters:
138
- y = y + self.adapters[self.active](x)
139
- return y
140
-
141
- class _Block(nn.Module):
142
- def __init__(self, D, H, R):
143
- super().__init__()
144
- L = TinyBackend._LoRALinear
145
- self.ln1 = nn.LayerNorm(D)
146
- self.attn = nn.MultiheadAttention(D, H, batch_first=True)
147
- self.ln2 = nn.LayerNorm(D)
148
- self.up, self.down = L(D, 4 * D, R), L(4 * D, D, R)
149
-
150
- def forward(self, x, mask):
151
- h = self.ln1(x)
152
- a, _ = self.attn(h, h, h, attn_mask=mask, need_weights=False)
153
- x = x + a
154
- return x + self.down(F.gelu(self.up(self.ln2(x))))
155
-
156
- def __init__(self):
157
- super().__init__()
158
- D, V = self.D_MODEL, self.VOCAB
159
- self.tok = nn.Embedding(V, D)
160
- self.pos = nn.Embedding(self.SEQ_LEN * 4, D)
161
- self.blocks = nn.ModuleList(
162
- [self._Block(D, self.N_HEADS, self.LORA_R) for _ in range(self.N_LAYERS)]
163
- )
164
- self.ln_f, self.head = nn.LayerNorm(D), nn.Linear(D, V, bias=False)
165
- self.vocab_size, self.hidden_size = V, D
166
- self.add_adapter("general")
167
- self.set_adapter("general")
168
-
169
- def add_adapter(self, name):
170
- for b in self.blocks:
171
- b.up.add_adapter(name)
172
- b.down.add_adapter(name)
173
-
174
- def set_adapter(self, name):
175
- for b in self.blocks:
176
- b.up.active = name
177
- b.down.active = name
178
-
179
- def trainable_parameters(self):
180
- return (p for p in self.parameters() if p.requires_grad)
181
-
182
- def forward(self, ids):
183
- B, T = ids.shape
184
- x = self.tok(ids) + self.pos(torch.arange(T, device=ids.device))
185
- mask = torch.triu(
186
- torch.full((T, T), float("-inf"), device=ids.device), 1
187
- )
188
- for b in self.blocks:
189
- x = b(x, mask)
190
- h = self.ln_f(x)
191
- return self.head(h), h
192
-
193
- @torch.no_grad()
194
- def generate(self, ids, n_new=16, temperature=1.0):
195
- for _ in range(n_new):
196
- logits, _ = self(ids[:, -self.SEQ_LEN :])
197
- if temperature <= 0:
198
- nxt = logits[:, -1].argmax(dim=-1, keepdim=True)
199
- else:
200
- nxt = torch.multinomial(
201
- F.softmax(logits[:, -1] / temperature, -1), 1
202
- )
203
- ids = torch.cat([ids, nxt], dim=1)
204
- return ids
205
-
206
- def encode_text(self, text: str):
207
- vals = [ord(c) % self.vocab_size for c in text[: self.SEQ_LEN]]
208
- if not vals:
209
- vals = [0]
210
- return torch.tensor([vals], dtype=torch.long)
211
-
212
- def decode(self, ids):
213
- return " ".join(str(int(t)) for t in ids[0].tolist())
214
-
215
- @property
216
- def device(self):
217
- return next(self.parameters()).device
218
-
219
-
220
- def make_backend(llm: str, **kw) -> LLMBackend:
221
- """'tiny' -> toy model; anything else -> HF hub id or local path."""
222
- return TinyBackend() if llm == "tiny" else HFBackend(llm, **kw)
223
-
224
-
225
- def load_hf_backend_from_checkpoint(
226
- base_llm: str,
227
- adapter_dir: str | None,
228
- *,
229
- adapter_names: tuple[str, ...] = ("general",),
230
- device: str | None = None,
231
- load_in_4bit: bool = False,
232
- lora_r: int = 16,
233
- lora_alpha: int = 32,
234
- ) -> HFBackend:
235
- """Load a frozen base LM + saved PEFT adapters (ensemble checkpoint llm/)."""
236
- from pathlib import Path
237
-
238
- from peft import LoraConfig, PeftModel, get_peft_model
239
- from transformers import AutoModelForCausalLM, AutoTokenizer
240
-
241
- def _discover_adapter_dirs(root: Path) -> dict[str, Path]:
242
- if (root / "adapter_config.json").is_file():
243
- return {"general": root}
244
- discovered: dict[str, Path] = {}
245
- for child in sorted(root.iterdir()):
246
- if child.is_dir() and (child / "adapter_config.json").is_file():
247
- discovered[child.name] = child
248
- return discovered
249
-
250
- resolved_device = device or ("cuda" if torch.cuda.is_available() else "cpu")
251
- tokenizer = AutoTokenizer.from_pretrained(adapter_dir or base_llm)
252
- if tokenizer.pad_token is None:
253
- tokenizer.pad_token = tokenizer.eos_token
254
-
255
- kwargs: dict = {}
256
- if load_in_4bit:
257
- from transformers import BitsAndBytesConfig
258
-
259
- kwargs["quantization_config"] = BitsAndBytesConfig(
260
- load_in_4bit=True,
261
- bnb_4bit_compute_dtype=torch.bfloat16,
262
- bnb_4bit_quant_type="nf4",
263
- )
264
- elif resolved_device != "cpu":
265
- kwargs["torch_dtype"] = torch.bfloat16
266
-
267
- base = AutoModelForCausalLM.from_pretrained(base_llm, **kwargs)
268
- if not load_in_4bit and resolved_device != "cpu":
269
- base.to(resolved_device)
270
- for p in base.parameters():
271
- p.requires_grad_(False)
272
-
273
- if adapter_dir:
274
- adapter_dirs = _discover_adapter_dirs(Path(adapter_dir))
275
- if not adapter_dirs:
276
- raise ValueError(
277
- f"No PEFT adapters found under {adapter_dir} "
278
- "(expected adapter_config.json or <name>/adapter_config.json)"
279
- )
280
- preferred = [name for name in adapter_names if name in adapter_dirs]
281
- load_order = preferred + [
282
- name for name in adapter_dirs if name not in preferred
283
- ]
284
- first_name = load_order[0]
285
- model = PeftModel.from_pretrained(
286
- base,
287
- str(adapter_dirs[first_name]),
288
- adapter_name=first_name,
289
- is_trainable=False,
290
- )
291
- for name in load_order[1:]:
292
- model.load_adapter(str(adapter_dirs[name]), adapter_name=name)
293
- adapters = set(load_order)
294
- else:
295
- lora_cfg = LoraConfig(
296
- r=lora_r,
297
- lora_alpha=lora_alpha,
298
- lora_dropout=0.05,
299
- target_modules=["q_proj", "v_proj"],
300
- task_type="CAUSAL_LM",
301
- )
302
- model = get_peft_model(base, lora_cfg, adapter_name="general")
303
- adapters = {"general"}
304
-
305
- backend = HFBackend.__new__(HFBackend)
306
- nn.Module.__init__(backend)
307
- backend.device_ = torch.device(resolved_device)
308
- backend.tokenizer = tokenizer
309
- backend.model = model
310
- backend._lora_cfg = None
311
- backend._adapters = adapters
312
- backend.vocab_size = model.config.vocab_size
313
- backend.hidden_size = model.config.hidden_size
314
- if adapter_names:
315
- backend.set_adapter(adapter_names[0])
316
- return backend
317
-
318
-
319
- class TinyLLM(nn.Module):
320
- """Simpler toy LLM for the world-model track (no adapter bank)."""
321
-
322
- VOCAB, D, L, H, T = 1000, 128, 2, 4, 32
323
-
324
- def __init__(self):
325
- super().__init__()
326
- self.tok = nn.Embedding(self.VOCAB, self.D)
327
- self.pos = nn.Embedding(self.T * 4, self.D)
328
- layer = nn.TransformerEncoderLayer(
329
- self.D, self.H, 4 * self.D, batch_first=True, norm_first=True
330
- )
331
- self.blocks = nn.TransformerEncoder(layer, self.L)
332
- self.head = nn.Linear(self.D, self.VOCAB, bias=False)
333
- self.vocab_size, self.hidden_size = self.VOCAB, self.D
334
-
335
- def forward(self, ids):
336
- Tn = ids.size(1)
337
- x = self.tok(ids) + self.pos(torch.arange(Tn, device=ids.device))
338
- mask = torch.triu(
339
- torch.full((Tn, Tn), float("-inf"), device=ids.device), 1
340
- )
341
- h = self.blocks(x, mask=mask)
342
- return self.head(h), h
343
-
344
- @torch.no_grad()
345
- def generate(self, ids, n_new=16, temperature=1.0):
346
- for _ in range(n_new):
347
- logits, _ = self(ids[:, -self.T :])
348
- nxt = torch.multinomial(
349
- F.softmax(logits[:, -1] / temperature, -1), 1
350
- )
351
- ids = torch.cat([ids, nxt], 1)
352
- return ids
353
-
354
- def trainable_parameters(self):
355
- return self.parameters()
356
-
357
- @property
358
- def device(self):
359
- return next(self.parameters()).device
360
-
361
-
362
- class HFLLM(nn.Module):
363
- """Small HF model with single LoRA stack (world-model track)."""
364
-
365
- def __init__(self, path, lora_r=16):
366
- super().__init__()
367
- from peft import LoraConfig, get_peft_model
368
- from transformers import AutoModelForCausalLM, AutoTokenizer
369
-
370
- self.tokenizer = AutoTokenizer.from_pretrained(path)
371
- if self.tokenizer.pad_token is None:
372
- self.tokenizer.pad_token = self.tokenizer.eos_token
373
- base = AutoModelForCausalLM.from_pretrained(
374
- path,
375
- torch_dtype=torch.bfloat16
376
- if torch.cuda.is_available()
377
- else torch.float32,
378
- device_map="auto" if torch.cuda.is_available() else None,
379
- )
380
- for p in base.parameters():
381
- p.requires_grad_(False)
382
- cfg = LoraConfig(
383
- r=lora_r,
384
- lora_alpha=2 * lora_r,
385
- lora_dropout=0.05,
386
- target_modules=["q_proj", "v_proj"],
387
- task_type="CAUSAL_LM",
388
- )
389
- self.model = get_peft_model(base, cfg)
390
- self.vocab_size = self.model.config.vocab_size
391
- self.hidden_size = self.model.config.hidden_size
392
-
393
- def forward(self, ids):
394
- out = self.model(
395
- input_ids=ids.to(self.device), output_hidden_states=True
396
- )
397
- return out.logits, out.hidden_states[-1]
398
-
399
- @torch.no_grad()
400
- def generate(self, ids, n_new=32, temperature=0.8):
401
- return self.model.generate(
402
- input_ids=ids.to(self.device),
403
- max_new_tokens=n_new,
404
- do_sample=True,
405
- temperature=temperature,
406
- pad_token_id=self.tokenizer.pad_token_id,
407
- )
408
-
409
- def trainable_parameters(self):
410
- return (p for p in self.model.parameters() if p.requires_grad)
411
-
412
- @property
413
- def device(self):
414
- return next(self.model.parameters()).device
415
-
416
-
417
- def load_llm(spec: str):
418
- return TinyLLM() if spec == "tiny" else HFLLM(spec)
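Note on the removed `TinyBackend`: its `encode_text` and `decode` were deliberately not inverses of each other. Text was mapped to ids by taking each character's code point modulo the vocabulary size, and ids were rendered back as space-separated integers. A minimal pure-Python sketch of that behavior follows, using plain lists instead of tensors. The defaults mirror the `VOCAB = 1000` and `SEQ_LEN = 32` constants from the deleted file, while the standalone function form is an assumption made for illustration.

```python
def encode_text(text, vocab_size=1000, seq_len=32):
    """Map up to `seq_len` characters to ids via code point modulo vocab size."""
    ids = [ord(ch) % vocab_size for ch in text[:seq_len]]
    # An empty prompt still needs one token so the model has an input.
    return ids or [0]


def decode(ids):
    """Render ids as space-separated integers (not a text round trip)."""
    return " ".join(str(int(i)) for i in ids)


ids = encode_text("ab")
print(ids, "->", decode(ids))
```

Because distinct characters can collide after the modulo, this scheme is only suitable for smoke tests, which matches how the README described the tiny backend.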
 
research/ensemble/src/ensemble/bridge.py DELETED
@@ -1,28 +0,0 @@
1
- """Bridge: align LLM hidden states with JEPA latent space."""
2
-
3
- from __future__ import annotations
4
-
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
-
9
-
10
- class Bridge(nn.Module):
11
- def __init__(self, d_llm_hidden: int, d_latent: int):
12
- super().__init__()
13
- self.proj = nn.Sequential(
14
- nn.Linear(d_llm_hidden, d_latent),
15
- nn.GELU(),
16
- nn.Linear(d_latent, d_latent),
17
- )
18
-
19
- def forward(self, llm_hidden):
20
- return self.proj(llm_hidden.float().mean(dim=1))
21
-
22
- def info_nce(self, z1, z2, tau=0.07):
23
- z1, z2 = F.normalize(z1, dim=-1), F.normalize(z2, dim=-1)
24
- logits = z1 @ z2.t() / tau
25
- labels = torch.arange(z1.size(0), device=z1.device)
26
- return 0.5 * (
27
- F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)
28
- )
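Note on the removed `Bridge.info_nce`: it computed a symmetric InfoNCE loss, meaning cross-entropy over the cosine-similarity matrix in both directions, with matching rows treated as positives. The sketch below re-implements the same arithmetic with the standard library only, so it can be checked without `torch`. The helper names `_normalize` and `_cross_entropy` are introduced here for illustration, and mean reduction is assumed to match the default of `F.cross_entropy`.

```python
import math


def _normalize(vec):
    """Scale a vector to unit L2 norm (zero vectors are left unchanged)."""
    norm = math.sqrt(sum(x * x for x in vec)) or 1.0
    return [x / norm for x in vec]


def _cross_entropy(logits, labels):
    """Mean cross-entropy over rows, using a stable log-sum-exp."""
    total = 0.0
    for row, label in zip(logits, labels):
        peak = max(row)
        log_z = peak + math.log(sum(math.exp(x - peak) for x in row))
        total += log_z - row[label]
    return total / len(logits)


def info_nce(z1, z2, tau=0.07):
    """Symmetric InfoNCE: row i of z1 should match row i of z2."""
    z1 = [_normalize(v) for v in z1]
    z2 = [_normalize(v) for v in z2]
    logits = [[sum(a * b for a, b in zip(u, v)) / tau for v in z2] for u in z1]
    logits_t = [list(col) for col in zip(*logits)]
    labels = list(range(len(z1)))
    return 0.5 * (_cross_entropy(logits, labels) + _cross_entropy(logits_t, labels))


aligned = [[1.0, 0.0], [0.0, 1.0]]
swapped = [[0.0, 1.0], [1.0, 0.0]]
print(info_nce(aligned, aligned) < info_nce(aligned, swapped))
```

With a small temperature such as `tau=0.07`, aligned pairs drive the loss toward zero while mismatched pairs are penalized heavily, which is the behavior the bridge relied on to align LLM hidden states with the latent space.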
 
research/ensemble/src/ensemble/checkpoint.py DELETED
@@ -1,149 +0,0 @@
1
- """Save and load JEPA ensemble checkpoints under models/ensemble/."""
2
-
3
- from __future__ import annotations
4
-
5
- import json
6
- from pathlib import Path
7
- from typing import Any
8
-
9
- import torch
10
-
11
- from ensemble.backends import TinyBackend, load_hf_backend_from_checkpoint
12
- from ensemble.jepa_ensemble import Ensemble
13
-
14
- MANIFEST_FILE = "manifest.json"
15
- AUX_FILE = "aux.pt"
16
- STORE_FILE = "store.pt"
17
- LLM_DIR = "llm"
18
- TINY_LLM_FILE = "tiny_llm.pt"
19
-
20
- CHECKPOINT_VERSION = 1
21
-
22
-
23
- def _aux_state_dict(ens: Ensemble) -> dict[str, torch.Tensor]:
24
- return {
25
- "emb": ens.emb.state_dict(),
26
- "jepa": ens.jepa.state_dict(),
27
- "bridge": ens.bridge.state_dict(),
28
- "router": ens.router.state_dict(),
29
- }
30
-
31
-
32
- def _store_payload(ens: Ensemble) -> dict[str, Any]:
33
- return {
34
- "keys": [k for k in ens.store.keys],
35
- "values": [v for v in ens.store.values],
36
- }
37
-
38
-
39
- def save_checkpoint(
40
- ens: Ensemble,
41
- out_dir: str | Path,
42
- *,
43
- base_llm: str,
44
- training_meta: dict[str, Any] | None = None,
45
- ) -> Path:
46
- """Persist ensemble (LLM adapters + emb + JEPA + bridge + router + store)."""
47
- root = Path(out_dir).resolve()
48
- root.mkdir(parents=True, exist_ok=True)
49
-
50
- backend = "tiny" if isinstance(ens.llm, TinyBackend) else "hf"
51
- manifest: dict[str, Any] = {
52
- "version": CHECKPOINT_VERSION,
53
- "track": "jepa",
54
- "backend": backend,
55
- "base_llm": base_llm,
56
- "adapter_names": list(ens.adapter_names),
57
- "d_emb": ens.emb.d_emb,
58
- "d_jepa": ens.jepa.d_latent,
59
- "training": training_meta or {},
60
- }
61
-
62
- torch.save(_aux_state_dict(ens), root / AUX_FILE)
63
- store = _store_payload(ens)
64
- if store["keys"]:
65
- torch.save(store, root / STORE_FILE)
66
-
67
- if backend == "hf":
68
- llm_path = root / LLM_DIR
69
- llm_path.mkdir(exist_ok=True)
70
- ens.llm.model.save_pretrained(llm_path)
71
- ens.llm.tokenizer.save_pretrained(llm_path)
72
- else:
73
- torch.save(ens.llm.state_dict(), root / TINY_LLM_FILE)
74
-
75
- with open(root / MANIFEST_FILE, "w") as f:
76
- json.dump(manifest, f, indent=2)
77
-
78
- return root
79
-
80
-
81
- def is_ensemble_checkpoint(path: str | Path) -> bool:
82
- return (Path(path) / MANIFEST_FILE).is_file()
83
-
84
-
85
- def load_checkpoint(
86
- ckpt_dir: str | Path,
87
- *,
88
- device: str | None = None,
89
- load_in_4bit: bool = False,
90
- ) -> Ensemble:
91
- """Restore a saved JEPA ensemble from models/ensemble/<name>/."""
92
- root = Path(ckpt_dir).resolve()
93
- manifest_path = root / MANIFEST_FILE
94
- if not manifest_path.is_file():
95
- raise FileNotFoundError(
96
- f"Not an ensemble checkpoint (missing {MANIFEST_FILE}): {root}"
97
- )
98
-
99
- with open(manifest_path) as f:
100
- manifest = json.load(f)
101
-
102
- base_llm = manifest["base_llm"]
103
- backend = manifest.get("backend", "hf")
104
- adapter_names = tuple(manifest.get("adapter_names", ["general"]))
105
- d_emb = manifest.get("d_emb", 64)
106
- d_jepa = manifest.get("d_jepa", 64)
107
-
108
- if backend == "tiny":
109
- ens = Ensemble(
110
- llm="tiny",
111
- adapter_names=adapter_names,
112
- d_emb=d_emb,
113
- d_jepa=d_jepa,
114
- )
115
- tiny_state = torch.load(
116
- root / TINY_LLM_FILE, map_location="cpu", weights_only=True
117
- )
118
- ens.llm.load_state_dict(tiny_state)
119
- else:
120
- llm_dir = root / LLM_DIR
121
- llm_backend = load_hf_backend_from_checkpoint(
122
- base_llm,
123
- str(llm_dir) if llm_dir.is_dir() else None,
124
- adapter_names=adapter_names,
125
- device=device,
126
- load_in_4bit=load_in_4bit,
127
- )
128
- ens = Ensemble(
129
- llm=base_llm,
130
- adapter_names=adapter_names,
131
- d_emb=d_emb,
132
- d_jepa=d_jepa,
133
- llm_backend=llm_backend,
134
- )
135
-
136
- aux = torch.load(root / AUX_FILE, map_location="cpu", weights_only=True)
137
- ens.emb.load_state_dict(aux["emb"])
138
- ens.jepa.load_state_dict(aux["jepa"])
139
- ens.bridge.load_state_dict(aux["bridge"])
140
- ens.router.load_state_dict(aux["router"])
141
-
142
- store_path = root / STORE_FILE
143
- if store_path.is_file():
144
- store = torch.load(store_path, map_location="cpu", weights_only=True)
145
- ens.store.keys = list(store["keys"])
146
- ens.store.values = list(store["values"])
147
-
148
- ens.eval()
149
- return ens
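Note on the removed `checkpoint.py`: a directory counted as an ensemble checkpoint exactly when it contained `manifest.json`, and the loader fell back to defaults (`backend="hf"`, `adapter_names=["general"]`) for missing keys. The sketch below isolates that manifest convention without any model weights. The function names `write_manifest` and `read_manifest` and the example model id are hypothetical and do not appear in the deleted module.

```python
import json
import tempfile
from pathlib import Path

MANIFEST_FILE = "manifest.json"


def write_manifest(out_dir, *, base_llm, adapter_names, backend="hf"):
    """Create the checkpoint directory and write a minimal manifest."""
    root = Path(out_dir)
    root.mkdir(parents=True, exist_ok=True)
    manifest = {
        "version": 1,
        "backend": backend,
        "base_llm": base_llm,
        "adapter_names": list(adapter_names),
    }
    (root / MANIFEST_FILE).write_text(json.dumps(manifest, indent=2))
    return root


def is_ensemble_checkpoint(path):
    """A directory is a checkpoint if and only if it holds a manifest."""
    return (Path(path) / MANIFEST_FILE).is_file()


def read_manifest(path):
    """Load the manifest, applying the same defaults the loader used."""
    manifest_path = Path(path) / MANIFEST_FILE
    if not manifest_path.is_file():
        raise FileNotFoundError(f"missing {MANIFEST_FILE}: {path}")
    manifest = json.loads(manifest_path.read_text())
    manifest.setdefault("backend", "hf")
    manifest.setdefault("adapter_names", ["general"])
    return manifest


with tempfile.TemporaryDirectory() as tmp:
    ckpt = write_manifest(Path(tmp) / "demo", base_llm="example/base", adapter_names=["general"])
    print(is_ensemble_checkpoint(ckpt), read_manifest(ckpt)["backend"])
```

Keeping detection tied to a single small JSON file is what allowed `slm-benchmark` to auto-detect ensemble checkpoints, as the removed README described.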
 
research/ensemble/src/ensemble/config.py DELETED
@@ -1,163 +0,0 @@
1
- """Resolve base LLM for ensemble from .env and models.yaml (same order as finetune)."""
2
-
3
- from __future__ import annotations
4
-
5
- import os
6
- import sys
7
- from pathlib import Path
8
-
9
- _REPO_ROOT = Path(__file__).resolve().parents[4]
10
- _FALLBACK_PRESET = "minicpm5-1b"
11
-
12
- _ENV_LLM_KEYS = (
13
- "ENSEMBLE_LLM",
14
- "LLM_PATH",
15
- "BASE",
16
- "FINETUNE_MODEL",
17
- "MODEL_ID",
18
- )
19
-
20
-
21
- def repo_root() -> Path:
22
- return _REPO_ROOT
23
-
24
-
25
- def load_dotenv() -> None:
26
- """Load KEY=VALUE pairs from repo .env without overriding existing env vars."""
27
- path = _REPO_ROOT / ".env"
28
- if not path.is_file():
29
- return
30
- for line in path.read_text().splitlines():
31
- line = line.strip()
32
- if not line or line.startswith("#") or "=" not in line:
33
- continue
34
- key, _, value = line.partition("=")
35
- key = key.strip()
36
- value = value.strip().strip('"').strip("'")
37
- if key:
38
- os.environ.setdefault(key, value)
39
-
40
-
41
- def _ensure_inference_on_path() -> None:
42
- libs = _REPO_ROOT / "libs" / "inference" / "src"
43
- if str(libs) not in sys.path:
44
- sys.path.insert(0, str(libs))
45
-
46
-
47
- def _is_ensemble_llm_preset(model) -> bool:
48
- return model.backend == "transformers" and not model.multimodal and bool(
49
- model.model_id
50
- )
51
-
52
-
53
- def _llm_from_local_path(raw: str) -> str | None:
54
- path = Path(raw)
55
- if not path.is_absolute():
56
- path = (_REPO_ROOT / path).resolve()
57
- if path.suffix == ".gguf":
58
- return None
59
- if path.is_dir() and (path / "config.json").is_file():
60
- return str(path)
61
- if path.is_file():
62
- return None
63
- return None
-
-
- def _llm_from_env_paths() -> str | None:
-     for key in ("LLM_PATH", "MODEL_PATH"):
-         raw = os.environ.get(key)
-         if raw:
-             resolved = _llm_from_local_path(raw)
-             if resolved:
-                 return resolved
-     return None
-
-
- def resolve_llm(
-     *,
-     llm_arg: str | None = None,
-     preset_arg: str | None = None,
- ) -> tuple[str, str | None]:
-     """
-     Return (hub_id_or_local_path, preset_key) for ensemble HF backends.
-
-     Priority when llm_arg is None or ``auto``:
-       1. ENSEMBLE_LLM, LLM_PATH (local HF dir), BASE, FINETUNE_MODEL, MODEL_ID
-       2. MODEL_PATH if it points at a HuggingFace model directory (not .gguf)
-       3. ENSEMBLE_PRESET, FINETUNE_PRESET, or ACTIVE_MODEL from models.yaml
-       4. First fine-tunable transformers preset (default minicpm5-1b)
-     """
-     if llm_arg and llm_arg not in ("auto",):
-         return llm_arg, preset_arg
-
-     for env_name in _ENV_LLM_KEYS:
-         raw = os.environ.get(env_name)
-         if raw:
-             local = _llm_from_local_path(raw)
-             return local or raw, preset_arg
-
-     local = _llm_from_env_paths()
-     if local:
-         return local, preset_arg
-
-     _ensure_inference_on_path()
-     from inference.config import get_app_config, get_model_config
-
-     app_config = get_app_config(reload=True)
-     preset_key = (
-         preset_arg
-         or os.environ.get("ENSEMBLE_PRESET")
-         or os.environ.get("FINETUNE_PRESET")
-         or os.environ.get("ACTIVE_MODEL")
-     )
-
-     if preset_key and preset_key in app_config.models:
-         model = get_model_config(preset_key)
-         if not _is_ensemble_llm_preset(model):
-             preset_key = None
-
-     if preset_key is None:
-         for candidate in (_FALLBACK_PRESET, *app_config.models):
-             if candidate not in app_config.models:
-                 continue
-             model = get_model_config(candidate)
-             if _is_ensemble_llm_preset(model):
-                 preset_key = candidate
-                 break
-
-     if not preset_key:
-         raise SystemExit(
-             "No transformers LLM found for ensemble. Pass --llm, set LLM_PATH/BASE/"
-             "MODEL_ID in .env, or ACTIVE_MODEL in models.yaml."
-         )
-
-     model = get_model_config(preset_key)
-     if not _is_ensemble_llm_preset(model):
-         raise SystemExit(
-             f"Preset {preset_key!r} cannot back an ensemble "
-             f"(backend={model.backend}, multimodal={model.multimodal})."
-         )
-     return model.model_id, preset_key
-
-
- def default_ensemble_out(preset_key: str | None) -> str:
-     label = preset_key or "custom"
-     return str((_REPO_ROOT / "models" / "ensemble" / f"{label}-jepa-pretrain").resolve())
-
-
- def resolve_llm_cli(
-     llm: str | None,
-     *,
-     toy: bool = False,
-     preset: str | None = None,
- ) -> str:
-     """CLI helper: explicit tiny, else .env / models.yaml unless --toy without --llm."""
-     if llm == "tiny":
-         return "tiny"
-     if llm is None or llm == "auto":
-         if toy:
-             return "tiny"
-         load_dotenv()
-         resolved, _ = resolve_llm(preset_arg=preset)
-         return resolved
-     return llm
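
Reviewer note: the lookup order described in the removed `resolve_llm` docstring amounts to a first-match scan over environment variables. The sketch below is a simplified illustration, not the deleted implementation. `pick_llm`, its `env` argument, and the `preset:` prefix are hypothetical, the variable lists are taken from the docstring rather than from the real `_ENV_LLM_KEYS`, and the `models.yaml` lookup is reduced to a default string.

```python
from __future__ import annotations

# Direct-override variables, in the order listed in the removed docstring (assumed).
ENV_LLM_KEYS = ("ENSEMBLE_LLM", "LLM_PATH", "BASE", "FINETUNE_MODEL", "MODEL_ID")
# Preset-selection variables, consulted only when no direct override is set.
ENV_PRESET_KEYS = ("ENSEMBLE_PRESET", "FINETUNE_PRESET", "ACTIVE_MODEL")


def pick_llm(
    env: dict[str, str],
    llm_arg: str | None = None,
    fallback_preset: str = "minicpm5-1b",
) -> tuple[str, str | None]:
    """Return (model_reference, preset_key) using a first-match priority scan."""
    # 1. An explicit argument other than "auto" always wins.
    if llm_arg and llm_arg != "auto":
        return llm_arg, None
    # 2. The first non-empty direct override is used as-is.
    for key in ENV_LLM_KEYS:
        if env.get(key):
            return env[key], None
    # 3. Otherwise the first non-empty preset variable selects a preset.
    for key in ENV_PRESET_KEYS:
        if env.get(key):
            return f"preset:{env[key]}", env[key]
    # 4. Last resort: a fixed fallback preset.
    return f"preset:{fallback_preset}", fallback_preset
```

Passing the environment as a plain dict (for example `pick_llm(dict(os.environ))`) keeps the priority logic testable without mutating process state.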
 
research/ensemble/src/ensemble/energy.py DELETED
@@ -1,45 +0,0 @@
- """Energy model: score candidate latents against world state."""
-
- from __future__ import annotations
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
-
- class EnergyModel(nn.Module):
-     def __init__(self, d_latent: int):
-         super().__init__()
-         self.net = nn.Sequential(
-             nn.Linear(2 * d_latent, 2 * d_latent),
-             nn.GELU(),
-             nn.Linear(2 * d_latent, d_latent),
-             nn.GELU(),
-             nn.Linear(d_latent, 1),
-         )
-         self.d_latent = d_latent
-
-     def energy(self, s, z):
-         return self.net(torch.cat([s, z], -1)).squeeze(-1)
-
-     def contrastive_loss(self, s, z_pos, z_negs=None, tau=0.5):
-         B = s.size(0)
-         s_rep = s.unsqueeze(1).expand(B, B, self.d_latent).reshape(
-             B * B, self.d_latent
-         )
-         z_rep = z_pos.unsqueeze(0).expand(B, B, self.d_latent).reshape(
-             B * B, self.d_latent
-         )
-         E = self.energy(s_rep, z_rep).view(B, B)
-         if z_negs is not None:
-             En = self.energy(
-                 s.repeat_interleave(z_negs.size(1), 0),
-                 z_negs.reshape(-1, self.d_latent),
-             ).view(B, -1)
-             E = torch.cat([E, En], dim=1)
-         labels = torch.arange(B, device=s.device)
-         return F.cross_entropy(-E / tau, labels)
-
-     @torch.no_grad()
-     def rank(self, s, candidates):
-         return self.energy(s.expand(candidates.size(0), -1), candidates)
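
Reviewer note: the removed `contrastive_loss` builds a B×B energy matrix whose diagonal holds the matching (state, latent) pairs, then applies cross-entropy to `-E / tau`. A dependency-free sketch of that final step follows, assuming the energies have already been computed. `info_nce_from_energies` is a hypothetical name and the function uses plain lists rather than tensors.

```python
import math


def info_nce_from_energies(E, tau=0.5):
    """Mean cross-entropy over rows of -E/tau, with column i as the target of row i.

    E[i][j] is the energy of pairing state i with latent j; lower means more
    compatible. Rows may carry extra columns for explicit negatives.
    """
    total = 0.0
    for i, row in enumerate(E):
        logits = [-e / tau for e in row]
        m = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_z = m + math.log(sum(math.exp(x - m) for x in logits))
        total += log_z - logits[i]  # negative log-probability of the positive
    return total / len(E)
```

With all energies equal the loss is `log(number of columns)`; it falls as the diagonal energies drop relative to the off-diagonal ones.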
 
research/ensemble/src/ensemble/eval/__init__.py DELETED
@@ -1 +0,0 @@
- """Evaluation harnesses for JEPA and world-model ensembles."""
 
 
research/ensemble/src/ensemble/eval/jepa_harness.py DELETED
@@ -1,266 +0,0 @@
- """Ablation ladder + JEPA best-of-N benchmark for the ensemble."""
-
- from __future__ import annotations
-
- import argparse
- import json
- import random
- import time
- from collections import defaultdict
-
- import torch
- import torch.nn.functional as F
-
- from ensemble.eval.metrics import em_score, f1_score, paired_bootstrap
- from ensemble.backends import TinyBackend
- from ensemble.checkpoint import load_checkpoint
- from ensemble.config import load_dotenv, resolve_llm_cli
- from ensemble.jepa_ensemble import Ensemble
-
-
- @torch.no_grad()
- def generate_plain(ens, q_ids, n_new):
-     ens.llm.set_adapter(ens.adapter_names[0])
-     t0 = time.time()
-     out = ens.llm.generate(q_ids.to(ens.llm.device), n_new=n_new, temperature=0.7)
-     return out[:, q_ids.size(1) :], time.time() - t0
-
-
- @torch.no_grad()
- def generate_config(
-     ens, q_ids, n_new, *, use_rag, use_router, use_jepa, n_drafts=1, tau=0.0
- ):
-     q_emb = ens.emb(q_ids.cpu())
-
-     if use_router:
-         a_idx = ens.router(q_emb).item()
-         ens.llm.set_adapter(ens.adapter_names[a_idx])
-     else:
-         ens.llm.set_adapter(ens.adapter_names[0])
-
-     ctx = q_ids.cpu()
-     if use_rag:
-         mems = ens.store.search(q_emb, k=1)
-         if mems:
-             ctx = torch.cat([mems[0], ctx], dim=1)
-
-     t0 = time.time()
-     if not use_jepa:
-         out = ens.llm.generate(
-             ctx.to(ens.llm.device), n_new=n_new, temperature=0.7
-         )
-         return out[:, ctx.size(1) :], time.time() - t0, None
-
-     z_exp = ens.jepa.predict_next_latent(ctx)
-     drafts, scores = [], []
-     for _ in range(n_drafts):
-         out = ens.llm.generate(
-             ctx.to(ens.llm.device), n_new=n_new, temperature=0.9
-         )
-         new = out[:, ctx.size(1) :].cpu()
-         drafts.append(new)
-         scores.append(
-             F.cosine_similarity(z_exp, ens.jepa.encode(new)).item()
-         )
-     best = max(range(n_drafts), key=lambda i: scores[i])
-     return drafts[best], time.time() - t0, (drafts, scores)
-
-
- def selector_comparison(drafts_scores_gold, decode_fn, rng):
-     res = defaultdict(list)
-     for drafts, scores, gold in drafts_scores_gold:
-         texts = [decode_fn(d) for d in drafts]
-         ems = [em_score(t, gold) for t in texts]
-         res["first"].append(ems[0])
-         res["random"].append(ems[rng.randrange(len(ems))])
-         res["jepa"].append(ems[max(range(len(ems)), key=lambda i: scores[i])])
-         res["oracle"].append(max(ems))
-     return {k: sum(v) / len(v) for k, v in res.items()}, res
-
-
- def load_jsonl(path):
-     with open(path) as f:
-         return [json.loads(line) for line in f if line.strip()]
-
-
- def make_toy_data(ens, n_qa=20, vocab=None):
-     vocab = vocab or ens.llm.vocab_size
-     qa, kb = [], []
-     for _ in range(n_qa):
-         key = torch.randint(0, vocab, (1, 6))
-         ans = torch.randint(0, vocab, (1, 4))
-         kb.append(torch.cat([key, ans], dim=1))
-         qa.append({"q_ids": key, "answer_ids": ans})
-     return qa, kb
-
-
- def run(args):
-     torch.manual_seed(args.seed)
-     rng = random.Random(args.seed)
-
-     if args.ckpt:
-         ens = load_checkpoint(args.ckpt)
-         print(f"loaded ensemble checkpoint: {args.ckpt}")
-         is_text = not isinstance(ens.llm, TinyBackend)
-     else:
-         load_dotenv()
-         args.llm = resolve_llm_cli(
-             args.llm, toy=args.toy, preset=getattr(args, "preset", None)
-         )
-         print(f"Resolved LLM: {args.llm}")
-         ens = Ensemble(llm=args.llm)
-         is_text = args.llm != "tiny"
-
-     if args.toy or not is_text:
-         qa, kb = make_toy_data(ens)
-         for mem in kb:
-             ens.memorize_ids(mem)
-
-         def to_ids(item):
-             return item["q_ids"]
-
-         def gold_text(item):
-             return " ".join(map(str, item["answer_ids"][0].tolist()))
-
-         def decode(ids):
-             return " ".join(map(str, ids[0].tolist()))
-     else:
-         qa = load_jsonl(args.qa)
-         if args.kb:
-             for row in load_jsonl(args.kb):
-                 ens.memorize_text(row["text"])
-
-         def to_ids(item):
-             return ens.llm.encode_text(
-                 f"Answer briefly.\nQ: {item['question']}\nA:"
-             )
-
-         def gold_text(item):
-             return item["answer"]
-
-         def decode(ids):
-             return ens.llm.decode(ids)
-
-     qa = qa[: args.limit]
-     print(
-         f"eval set: {len(qa)} questions | store: {len(ens.store.keys)} memories\n"
-     )
-
-     configs = {
-         "C1_base": dict(use_rag=False, use_router=False, use_jepa=False),
-         "C2_rag": dict(use_rag=True, use_router=False, use_jepa=False),
-         "C3_rag_router": dict(use_rag=True, use_router=True, use_jepa=False),
-         "C4_full_jepa": dict(
-             use_rag=True,
-             use_router=True,
-             use_jepa=True,
-             n_drafts=args.n_drafts,
-         ),
-     }
-
-     per_q = {}
-     summary = {}
-     jepa_material = []
-
-     for name, cfg in configs.items():
-         ems, f1s, lats = [], [], []
-         for item in qa:
-             ids = to_ids(item)
-             if name == "C1_base":
-                 out, dt = generate_plain(ens, ids, args.n_new)
-                 extra = None
-             else:
-                 out, dt, extra = generate_config(ens, ids, args.n_new, **cfg)
-             pred, gold = decode(out), gold_text(item)
-             ems.append(em_score(pred, gold))
-             f1s.append(f1_score(pred, gold))
-             lats.append(dt)
-             if name == "C4_full_jepa" and extra is not None:
-                 jepa_material.append((extra[0], extra[1], gold))
-         per_q[name] = ems
-         summary[name] = (
-             sum(ems) / len(ems),
-             sum(f1s) / len(f1s),
-             sum(lats) / len(lats),
-         )
-
-     print(f"{'config':<16}{'EM':>8}{'F1':>8}{'lat(s)':>9}")
-     for k, (em, f1, lat) in summary.items():
-         print(f"{k:<16}{em:>8.3f}{f1:>8.3f}{lat:>9.3f}")
-
-     print("\ncomponent contributions (paired bootstrap, P(B>A)):")
-     ladder = list(configs.keys())
-     for a, b in zip(ladder, ladder[1:]):
-         d = summary[b][0] - summary[a][0]
-         p = paired_bootstrap(per_q[a], per_q[b])
-         print(f" {b} - {a}: ΔEM={d:+.3f} P(better)={p:.2f}")
-
-     if jepa_material:
-         sel, sel_per_q = selector_comparison(jepa_material, decode, rng)
-         print(
-             f"\nbest-of-N selector comparison (same drafts, N={args.n_drafts}):"
-         )
-         for k in ("first", "random", "jepa", "oracle"):
-             print(f" {k:<8}EM={sel[k]:.3f}")
-         p = paired_bootstrap(sel_per_q["random"], sel_per_q["jepa"])
-         verdict = (
-             "JEPA critic WORKS"
-             if p > 0.95
-             else "inconclusive — critic ~ random"
-         )
-         print(f" P(jepa > random) = {p:.2f} {verdict}")
-         print(f" headroom to oracle: {sel['oracle'] - sel['jepa']:.3f}")
-
-     if args.continual:
-         print(
-             "\ncontinual test: accuracy on task-A questions "
-             "before vs after adding adapters B and C"
-         )
-         ems_before = per_q["C3_rag_router"]
-         ens.new_task_adapter("task_B")
-         ens.new_task_adapter("task_C")
-         ems_after = []
-         for item in qa:
-             out, _, _ = generate_config(
-                 ens,
-                 to_ids(item),
-                 args.n_new,
-                 use_rag=True,
-                 use_router=True,
-                 use_jepa=False,
-             )
-             ems_after.append(em_score(decode(out), gold_text(item)))
-         bt = sum(ems_after) / len(ems_after) - sum(ems_before) / len(
-             ems_before
-         )
-         print(f" backward transfer (≈0 is ideal): {bt:+.3f}")
-
-     return summary
-
-
- def parse_args():
-     p = argparse.ArgumentParser()
-     p.add_argument(
-         "--llm",
-         default=None,
-         help="HF id / path, 'tiny', or omit for LLM_PATH / ACTIVE_MODEL from .env",
-     )
-     p.add_argument("--preset", default=None, help="models.yaml preset override")
-     p.add_argument("--qa", default=None, help="jsonl with question/answer")
-     p.add_argument("--kb", default=None, help="jsonl with text -> vector store")
-     p.add_argument(
-         "--ckpt",
-         default=None,
-         help="saved ensemble directory (models/ensemble/... with manifest.json)",
-     )
-     p.add_argument("--toy", action="store_true", help="synthetic data smoke test")
-     p.add_argument("--limit", type=int, default=100)
-     p.add_argument("--n_new", type=int, default=24)
-     p.add_argument("--n_drafts", type=int, default=8)
-     p.add_argument("--continual", action="store_true")
-     p.add_argument("--seed", type=int, default=0)
-     return p.parse_args()
-
-
- if __name__ == "__main__":
-     run(parse_args())
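
Reviewer note: the removed `selector_comparison` scores four ways of picking one draft out of the same N candidates, so any difference between them comes from selection alone. A minimal standalone sketch follows, assuming the per-draft exact-match values and critic scores are already available as plain lists. `compare_selectors` and the `"critic"` key are hypothetical names; the removed code used `"jepa"` and decoded token tensors first.

```python
import random


def compare_selectors(items, rng):
    """items: list of (em_per_draft, critic_scores); a higher score is preferred.

    Returns the mean exact match for four selection rules on the same drafts.
    """
    totals = {"first": 0.0, "random": 0.0, "critic": 0.0, "oracle": 0.0}
    for ems, scores in items:
        best = max(range(len(ems)), key=lambda i: scores[i])
        totals["first"] += ems[0]                          # always take draft 0
        totals["random"] += ems[rng.randrange(len(ems))]   # uniform random pick
        totals["critic"] += ems[best]                      # critic's top-scored draft
        totals["oracle"] += max(ems)                       # upper bound for any selector
    return {k: v / len(items) for k, v in totals.items()}
```

For an energy-based critic, where lower is better, the same sketch applies with the scores negated.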
 
research/ensemble/src/ensemble/eval/metrics.py DELETED
@@ -1,42 +0,0 @@
- """QA metrics and paired bootstrap significance."""
-
- from __future__ import annotations
-
- import random
- import re
- import string
- from collections import Counter
-
-
- def normalize(s: str) -> str:
-     s = s.lower()
-     s = "".join(c for c in s if c not in string.punctuation)
-     s = re.sub(r"\b(a|an|the)\b", " ", s)
-     return " ".join(s.split())
-
-
- def em_score(pred: str, gold: str) -> float:
-     return float(normalize(gold) in normalize(pred))
-
-
- def f1_score(pred: str, gold: str) -> float:
-     p, g = normalize(pred).split(), normalize(gold).split()
-     if not p or not g:
-         return float(p == g)
-     common = Counter(p) & Counter(g)
-     overlap = sum(common.values())
-     if overlap == 0:
-         return 0.0
-     prec, rec = overlap / len(p), overlap / len(g)
-     return 2 * prec * rec / (prec + rec)
-
-
- def paired_bootstrap(scores_a, scores_b, iters=2000, seed=0):
-     rng = random.Random(seed)
-     n, wins = len(scores_a), 0
-     for _ in range(iters):
-         idx = [rng.randrange(n) for _ in range(n)]
-         da = sum(scores_a[i] for i in idx) / n
-         db = sum(scores_b[i] for i in idx) / n
-         wins += db > da
-     return wins / iters
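
Reviewer note: the removed `paired_bootstrap` returns the fraction of resamples in which B's mean strictly exceeds A's, so ties count against B. The snippet below restates the function in condensed form so it runs standalone, then exercises the two boundary cases. The input lists are made-up illustrations, not results from this repository.

```python
import random


def paired_bootstrap(scores_a, scores_b, iters=2000, seed=0):
    """Estimate P(mean(B) > mean(A)) by resampling question indices with replacement."""
    rng = random.Random(seed)
    n, wins = len(scores_a), 0
    for _ in range(iters):
        idx = [rng.randrange(n) for _ in range(n)]
        # The same indices are used for both systems, which is what makes it paired.
        wins += sum(scores_b[i] for i in idx) > sum(scores_a[i] for i in idx)
    return wins / iters


same = [0.0, 1.0, 1.0, 0.0]
tie_rate = paired_bootstrap(same, same)            # identical systems never win strictly
clear_win = paired_bootstrap([0.0] * 10, [1.0] * 10)  # B is better on every question
```

Because of the strict inequality, two systems with identical per-question scores yield 0.0 rather than 0.5, which is worth keeping in mind when reading the `p > 0.95` verdict in the harnesses.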
 
research/ensemble/src/ensemble/eval/world_harness.py DELETED
@@ -1,174 +0,0 @@
- """Energy-based draft selector benchmark for the world-model ensemble."""
-
- from __future__ import annotations
-
- import argparse
- import json
- import random
- import time
- from collections import defaultdict
-
- import torch
-
- from ensemble.eval.metrics import em_score, f1_score, paired_bootstrap
- from ensemble.world_ensemble import WorldEnsemble
-
-
- @torch.no_grad()
- def generate_drafts(ens, q_ids, n_new, n_drafts, use_rag=True):
-     q_emb = ens.emb(q_ids.cpu())
-     mems = ens.store.search(q_emb, k=1) if use_rag else []
-     segments = (mems + [q_ids.cpu()]) if mems else [q_ids.cpu()]
-     ctx = torch.cat(segments, dim=1)
-
-     s = ens.world_state(segments)
-     ens.world.rollout(s, horizon=3)
-
-     drafts, energies = [], []
-     t0 = time.time()
-     for _ in range(n_drafts):
-         out = ens.llm.generate(
-             ctx.to(ens.llm.device), n_new=n_new, temperature=0.9
-         )
-         new = out[:, ctx.size(1) :].cpu()
-         drafts.append(new)
-         z = ens.jepa.encode(new)
-         energies.append(ens.energy.rank(s, z).item())
-     return drafts, energies, time.time() - t0
-
-
- def selector_comparison(drafts_energy_gold, decode_fn, rng):
-     res = defaultdict(list)
-     for drafts, energies, gold in drafts_energy_gold:
-         texts = [decode_fn(d) for d in drafts]
-         ems = [em_score(t, gold) for t in texts]
-         res["first"].append(ems[0])
-         res["random"].append(ems[rng.randrange(len(ems))])
-         res["energy"].append(
-             ems[min(range(len(ems)), key=lambda i: energies[i])]
-         )
-         res["oracle"].append(max(ems))
-     return {k: sum(v) / len(v) for k, v in res.items()}, res
-
-
- def load_jsonl(path):
-     with open(path) as f:
-         return [json.loads(line) for line in f if line.strip()]
-
-
- def make_toy_data(ens, n_qa=20, vocab=None):
-     vocab = vocab or ens.llm.vocab_size
-     qa, kb = [], []
-     for _ in range(n_qa):
-         key = torch.randint(0, vocab, (1, 6))
-         ans = torch.randint(0, vocab, (1, 4))
-         kb.append(torch.cat([key, ans], dim=1))
-         qa.append({"q_ids": key, "answer_ids": ans})
-     return qa, kb
-
-
- def run(args):
-     from ensemble.config import load_dotenv, resolve_llm_cli
-
-     torch.manual_seed(args.seed)
-     rng = random.Random(args.seed)
-
-     load_dotenv()
-     args.llm = resolve_llm_cli(
-         args.llm, toy=args.toy, preset=getattr(args, "preset", None)
-     )
-     print(f"Resolved LLM: {args.llm}")
-     ens = WorldEnsemble(args.llm)
-     if args.ckpt:
-         state = torch.load(args.ckpt, map_location="cpu")
-         ens.load_state_dict(state, strict=False)
-         print(f"loaded world ensemble checkpoint: {args.ckpt}")
-
-     is_text = args.llm != "tiny"
-
-     if args.toy or not is_text:
-         qa, kb = make_toy_data(ens)
-         for mem in kb:
-             ens.memorize(mem)
-
-         def to_ids(item):
-             return item["q_ids"]
-
-         def gold_text(item):
-             return " ".join(map(str, item["answer_ids"][0].tolist()))
-
-         def decode(ids):
-             return " ".join(map(str, ids[0].tolist()))
-     else:
-         qa = load_jsonl(args.qa)
-         if args.kb:
-             for row in load_jsonl(args.kb):
-                 ids = ens.llm.tokenizer(
-                     row["text"], return_tensors="pt"
-                 ).input_ids
-                 ens.memorize(ids)
-
-         def to_ids(item):
-             return ens.llm.tokenizer(
-                 f"Answer briefly.\nQ: {item['question']}\nA:",
-                 return_tensors="pt",
-             ).input_ids
-
-         def gold_text(item):
-             return item["answer"]
-
-         def decode(ids):
-             return ens.llm.tokenizer.decode(ids[0], skip_special_tokens=True)
-
-     qa = qa[: args.limit]
-     print(
-         f"eval set: {len(qa)} questions | store: {len(ens.store.keys)} memories\n"
-     )
-
-     material = []
-     lats = []
-     for item in qa:
-         drafts, energies, dt = generate_drafts(
-             ens, to_ids(item), args.n_new, args.n_drafts
-         )
-         material.append((drafts, energies, gold_text(item)))
-         lats.append(dt)
-
-     sel, sel_per_q = selector_comparison(material, decode, rng)
-     print(f"best-of-N selector comparison (same drafts, N={args.n_drafts}):")
-     for k in ("first", "random", "energy", "oracle"):
-         print(f" {k:<8}EM={sel[k]:.3f}")
-     p = paired_bootstrap(sel_per_q["random"], sel_per_q["energy"])
-     verdict = (
-         "Energy critic WORKS"
-         if p > 0.95
-         else "inconclusive — critic ~ random"
-     )
-     print(f" P(energy > random) = {p:.2f} {verdict}")
-     print(f" headroom to oracle: {sel['oracle'] - sel['energy']:.3f}")
-     print(f" mean latency: {sum(lats) / len(lats):.3f}s")
-
-     return sel
-
-
- def parse_args():
-     p = argparse.ArgumentParser()
-     p.add_argument(
-         "--llm",
-         default=None,
-         help="HF id / path, 'tiny', or omit for LLM_PATH / ACTIVE_MODEL from .env",
-     )
-     p.add_argument("--preset", default=None, help="models.yaml preset override")
-     p.add_argument("--qa", default=None, help="jsonl with question/answer")
-     p.add_argument("--kb", default=None, help="jsonl with text -> vector store")
-     p.add_argument("--ckpt", default=None, help="trained world ensemble .pt")
-     p.add_argument("--toy", action="store_true", help="synthetic data smoke test")
-     p.add_argument("--limit", type=int, default=100)
-     p.add_argument("--n_new", type=int, default=24)
-     p.add_argument("--n_drafts", type=int, default=8)
-     p.add_argument("--seed", type=int, default=0)
-     return p.parse_args()
-
-
- if __name__ == "__main__":
-     run(parse_args())
 
research/ensemble/src/ensemble/eval_harness.py DELETED
@@ -1,309 +0,0 @@
1
- """
2
- eval_harness.py — Ablation ladder + JEPA best-of-N test for the ensemble
3
- ========================================================================
4
- Companion to `llm_emb_jepa_ensemble_pluggable.py` (must be importable,
5
- i.e. in the same directory).
6
-
7
- What it runs
8
- ------------
9
- 1. ABLATION LADDER on a QA set:
10
- C1 base LLM alone
11
- C2 C1 + RAG (embedding retrieval)
12
- C3 C2 + router/adapters
13
- C4 C3 + JEPA best-of-N critic
14
- (C5 = C4 with a bridge-trained checkpoint — just pass --ckpt)
15
-
16
- 2. BEST-OF-N SELECTOR comparison (the decisive JEPA experiment):
17
- first-sample | random-pick | JEPA-score pick | oracle pick
18
- All on the SAME N drafts per question, so differences are pure selection.
19
-
20
- 3. CONTINUAL FORGETTING test (optional, --continual):
21
- accuracy on task A before vs after training adapters for B and C.
22
-
23
- 4. PAIRED BOOTSTRAP significance between any two configs.
24
-
25
- Usage
26
- -----
27
- # Smoke test, no GPU/deps beyond torch (toy backend, synthetic QA):
28
- python eval_harness.py --llm tiny --toy
29
-
30
- # Real model + your QA file (jsonl: {"question": ..., "answer": ..., "context": optional}):
31
- python eval_harness.py --llm Qwen/Qwen2.5-0.5B-Instruct \
32
- --qa ./domain_qa.jsonl --kb ./knowledge.jsonl --n_drafts 8
33
-
34
- # With a bridge-trained ensemble checkpoint (C5):
35
- python eval_harness.py --llm /models/llama-3.2-1b --qa ./qa.jsonl \
36
- --kb ./kb.jsonl --ckpt ./ensemble_bridge.pt
37
-
38
- QA file: {"question": str, "answer": str, "domain": optional str}
39
- KB file: {"text": str} (each line becomes one memory in the vector store)
40
- """
41
-
42
- import argparse
43
- import json
44
- import random
45
- import re
46
- import string
47
- import time
48
- from collections import Counter, defaultdict
49
-
50
- import torch
51
-
52
- from llm_emb_jepa_ensemble_pluggable import Ensemble # same directory
53
-
54
- # ----------------------------------------------------------------------------
55
- # Metrics: normalized exact match + token F1 (SQuAD-style)
56
- # ----------------------------------------------------------------------------
57
- def normalize(s: str) -> str:
58
- s = s.lower()
59
- s = "".join(c for c in s if c not in string.punctuation)
60
- s = re.sub(r"\b(a|an|the)\b", " ", s)
61
- return " ".join(s.split())
62
-
63
-
64
- def em_score(pred: str, gold: str) -> float:
65
- return float(normalize(gold) in normalize(pred)) # containment EM
66
-
67
-
68
- def f1_score(pred: str, gold: str) -> float:
69
- p, g = normalize(pred).split(), normalize(gold).split()
70
- if not p or not g:
71
- return float(p == g)
72
- common = Counter(p) & Counter(g)
73
- overlap = sum(common.values())
74
- if overlap == 0:
75
- return 0.0
76
- prec, rec = overlap / len(p), overlap / len(g)
77
- return 2 * prec * rec / (prec + rec)
78
-
79
-
80
- # ----------------------------------------------------------------------------
81
- # Paired bootstrap: P(config B beats config A)
82
- # ----------------------------------------------------------------------------
83
- def paired_bootstrap(scores_a, scores_b, iters=2000, seed=0):
84
- rng = random.Random(seed)
85
- n, wins = len(scores_a), 0
86
- for _ in range(iters):
87
- idx = [rng.randrange(n) for _ in range(n)]
88
- da = sum(scores_a[i] for i in idx) / n
89
- db = sum(scores_b[i] for i in idx) / n
90
- wins += db > da
91
- return wins / iters
92
-
93
-
94
- # ----------------------------------------------------------------------------
95
- # Config runners — each returns per-question dicts
96
- # ----------------------------------------------------------------------------
97
- @torch.no_grad()
98
- def generate_plain(ens, q_ids, n_new):
99
- """C1: base adapter, no retrieval, single sample."""
100
- ens.llm.set_adapter(ens.adapter_names[0])
101
- t0 = time.time()
102
- out = ens.llm.generate(q_ids.to(ens.llm.device), n_new=n_new, temperature=0.7)
103
- return out[:, q_ids.size(1):], time.time() - t0
104
-
105
-
106
- @torch.no_grad()
107
- def generate_config(ens, q_ids, n_new, *, use_rag, use_router, use_jepa,
108
- n_drafts=1, tau=0.0):
109
- """Unified runner for C2/C3/C4."""
110
- q_emb = ens.emb(q_ids.cpu())
111
-
112
- if use_router:
113
- a_idx = ens.router(q_emb).item()
114
- ens.llm.set_adapter(ens.adapter_names[a_idx])
115
- else:
116
- ens.llm.set_adapter(ens.adapter_names[0])
117
-
118
- ctx = q_ids.cpu()
119
- if use_rag:
120
- mems = ens.store.search(q_emb, k=1)
121
- if mems:
122
- ctx = torch.cat([mems[0], ctx], dim=1)
123
-
124
- t0 = time.time()
125
- if not use_jepa:
126
- out = ens.llm.generate(ctx.to(ens.llm.device), n_new=n_new, temperature=0.7)
127
- return out[:, ctx.size(1):], time.time() - t0, None
128
-
129
- # JEPA best-of-N: sample drafts, keep the one closest to predicted latent
130
- z_exp = ens.jepa.predict_next_latent(ctx)
131
- drafts, scores = [], []
132
- for _ in range(n_drafts):
133
- out = ens.llm.generate(ctx.to(ens.llm.device), n_new=n_new, temperature=0.9)
134
- new = out[:, ctx.size(1):].cpu()
135
- drafts.append(new)
136
- scores.append(torch.nn.functional.cosine_similarity(
137
- z_exp, ens.jepa.encode(new)).item())
138
- best = max(range(n_drafts), key=lambda i: scores[i])
139
- return drafts[best], time.time() - t0, (drafts, scores)
140
-
141
-
142
- # ----------------------------------------------------------------------------
143
- # Best-of-N selector comparison on shared drafts
144
- # ----------------------------------------------------------------------------
145
- def selector_comparison(drafts_scores_gold, decode_fn, rng):
146
- """drafts_scores_gold: list of (drafts, jepa_scores, gold_answer).
147
- Returns EM for: first | random | jepa | oracle — all on the SAME drafts."""
148
- res = defaultdict(list)
149
- for drafts, scores, gold in drafts_scores_gold:
150
- texts = [decode_fn(d) for d in drafts]
151
- ems = [em_score(t, gold) for t in texts]
152
- res["first"].append(ems[0])
153
- res["random"].append(ems[rng.randrange(len(ems))])
154
- res["jepa"].append(ems[max(range(len(ems)), key=lambda i: scores[i])])
155
- res["oracle"].append(max(ems)) # upper bound of selection
156
- return {k: sum(v) / len(v) for k, v in res.items()}, res
157
-
158
-
159
- # ----------------------------------------------------------------------------
160
- # Data loading
161
- # ----------------------------------------------------------------------------
162
- def load_jsonl(path):
163
- with open(path) as f:
164
- return [json.loads(l) for l in f if l.strip()]
165
-
166
-
167
- def make_toy_data(ens, n_qa=20, vocab=None):
168
- """Synthetic QA for the tiny backend: 'answer' token sequence is planted
169
- in the KB so RAG can genuinely help even with random weights."""
170
- vocab = vocab or ens.llm.vocab_size
171
- qa, kb = [], []
172
- for i in range(n_qa):
173
- key = torch.randint(0, vocab, (1, 6))
174
- ans = torch.randint(0, vocab, (1, 4))
175
- kb.append(torch.cat([key, ans], dim=1)) # memory = key+answer
176
- qa.append({"q_ids": key, "answer_ids": ans})
177
- return qa, kb
178
-
179
-
180
- # ----------------------------------------------------------------------------
181
- # Main evaluation
182
- # ----------------------------------------------------------------------------
183
- def run(args):
184
- torch.manual_seed(args.seed)
185
- rng = random.Random(args.seed)
186
-
187
- ens = Ensemble(llm=args.llm)
188
- if args.ckpt:
189
- state = torch.load(args.ckpt, map_location="cpu")
190
- ens.load_state_dict(state, strict=False)
191
- print(f"loaded ensemble checkpoint: {args.ckpt}")
192
-
193
- is_text = args.llm != "tiny"
194
-
195
- # ---- load data and fill the vector store -------------------------------
196
- if args.toy or not is_text:
197
- qa, kb = make_toy_data(ens)
198
- for mem in kb:
199
- ens.memorize_ids(mem)
200
- def to_ids(item): return item["q_ids"]
201
- def gold_of(item): return item["answer_ids"]
202
- def decode(ids): return " ".join(map(str, ids[0].tolist()))
203
- def gold_text(item): return decode(item["answer_ids"])
204
- else:
205
- qa = load_jsonl(args.qa)
206
- if args.kb:
207
- for row in load_jsonl(args.kb):
208
- ens.memorize_text(row["text"])
209
- def to_ids(item): return ens.llm.encode_text(
210
- f"Answer briefly.\nQ: {item['question']}\nA:")
211
- def gold_text(item): return item["answer"]
212
- def decode(ids): return ens.llm.decode(ids)
213
-
214
- qa = qa[: args.limit]
215
- print(f"eval set: {len(qa)} questions | store: {len(ens.store.keys)} memories\n")
216
-
217
- # ---- ablation ladder ----------------------------------------------------
218
- configs = {
219
- "C1_base": dict(use_rag=False, use_router=False, use_jepa=False),
220
- "C2_rag": dict(use_rag=True, use_router=False, use_jepa=False),
221
- "C3_rag_router": dict(use_rag=True, use_router=True, use_jepa=False),
222
- "C4_full_jepa": dict(use_rag=True, use_router=True, use_jepa=True,
223
- n_drafts=args.n_drafts),
224
- }
225
-
226
- per_q = {} # config -> list of EM scores (for bootstrap)
227
- summary = {}
228
- jepa_material = [] # (drafts, scores, gold) for selector comparison
229
-
230
- for name, cfg in configs.items():
231
- ems, f1s, lats = [], [], []
232
- for item in qa:
233
- ids = to_ids(item)
234
- if name == "C1_base":
235
- out, dt = generate_plain(ens, ids, args.n_new)
236
- extra = None
237
- else:
238
- out, dt, extra = generate_config(ens, ids, args.n_new, **cfg)
239
- pred, gold = decode(out), gold_text(item)
240
- ems.append(em_score(pred, gold))
241
- f1s.append(f1_score(pred, gold))
242
- lats.append(dt)
243
- if name == "C4_full_jepa" and extra is not None:
244
- jepa_material.append((extra[0], extra[1], gold))
245
- per_q[name] = ems
246
- summary[name] = (sum(ems) / len(ems), sum(f1s) / len(f1s),
247
- sum(lats) / len(lats))
248
-
249
- print(f"{'config':<16}{'EM':>8}{'F1':>8}{'lat(s)':>9}")
250
- for k, (em, f1, lat) in summary.items():
251
- print(f"{k:<16}{em:>8.3f}{f1:>8.3f}{lat:>9.3f}")
252
-
253
- # deltas + significance
254
- print("\ncomponent contributions (paired bootstrap, P(B>A)):")
255
- ladder = list(configs.keys())
256
- for a, b in zip(ladder, ladder[1:]):
257
- d = summary[b][0] - summary[a][0]
258
- p = paired_bootstrap(per_q[a], per_q[b])
259
- print(f" {b} - {a}: ΔEM={d:+.3f} P(better)={p:.2f}")
260
-
261
- # ---- decisive JEPA selector experiment ----------------------------------
-     if jepa_material:
-         sel, sel_per_q = selector_comparison(jepa_material, decode, rng)
-         print("\nbest-of-N selector comparison (same drafts, N="
-               f"{args.n_drafts}):")
-         for k in ("first", "random", "jepa", "oracle"):
-             print(f" {k:<8}EM={sel[k]:.3f}")
-         p = paired_bootstrap(sel_per_q["random"], sel_per_q["jepa"])
-         print(f" P(jepa > random) = {p:.2f} "
-               f"{'JEPA critic WORKS' if p > 0.95 else 'inconclusive — critic ~ random'}")
-         gap = sel["oracle"] - sel["jepa"]
-         print(f" headroom to oracle: {gap:.3f}")
-
-     # ---- continual forgetting (optional) ------------------------------------
-     if args.continual:
-         print("\ncontinual test: accuracy on task-A questions "
-               "before vs after adding adapters B and C")
-         ems_before = per_q["C3_rag_router"]
-         ens.new_task_adapter("task_B")
-         ens.new_task_adapter("task_C")
-         ems_after = []
-         for item in qa:
-             out, _, _ = generate_config(ens, to_ids(item), args.n_new,
-                                         use_rag=True, use_router=True,
-                                         use_jepa=False)
-             ems_after.append(em_score(decode(out), gold_text(item)))
-         bt = sum(ems_after) / len(ems_after) - sum(ems_before) / len(ems_before)
-         print(f" backward transfer (≈0 is ideal): {bt:+.3f}")
-
-     return summary
-
-
- def parse_args():
-     p = argparse.ArgumentParser()
-     p.add_argument("--llm", default="tiny", help="'tiny' | HF id | local path")
-     p.add_argument("--qa", default=None, help="jsonl with question/answer")
-     p.add_argument("--kb", default=None, help="jsonl with text -> vector store")
-     p.add_argument("--ckpt", default=None, help="bridge-trained ensemble .pt (C5)")
-     p.add_argument("--toy", action="store_true", help="synthetic data smoke test")
-     p.add_argument("--limit", type=int, default=100)
-     p.add_argument("--n_new", type=int, default=24)
-     p.add_argument("--n_drafts", type=int, default=8)
-     p.add_argument("--continual", action="store_true")
-     p.add_argument("--seed", type=int, default=0)
-     return p.parse_args()
-
-
- if __name__ == "__main__":
-     run(parse_args())
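
The deleted script above reports `P(jepa > random)` from a `paired_bootstrap` helper and a backward-transfer number, but neither computation is visible in this part of the diff. The following is a minimal sketch of how such statistics are commonly computed, not the removed implementation: `paired_bootstrap_sketch`, `backward_transfer`, `n_resamples` and `seed` are hypothetical names chosen for illustration.

```python
import random


def paired_bootstrap_sketch(scores_a, scores_b, n_resamples=2000, seed=0):
    """Estimate P(mean(b) > mean(a)) by resampling question indices.

    Hypothetical helper: the two lists hold per-question scores (for example
    exact-match 0/1) for two selectors evaluated on the same questions.
    """
    if len(scores_a) != len(scores_b) or not scores_a:
        raise ValueError("need two equal-length, non-empty score lists")
    rng = random.Random(seed)
    n = len(scores_a)
    wins = 0
    for _ in range(n_resamples):
        # Resample the same question indices for both systems (paired).
        idx = [rng.randrange(n) for _ in range(n)]
        diff = sum(scores_b[i] - scores_a[i] for i in idx)
        if diff > 0:
            wins += 1
    return wins / n_resamples


def backward_transfer(before, after):
    """Mean score after adding new adapters minus mean score before."""
    return sum(after) / len(after) - sum(before) / len(before)
```

With this sketch, a value near 1.0 would mean the second selector wins in almost every resample, and a backward transfer near zero would mean the task-A accuracy did not change after adapters B and C were added.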
research/ensemble/src/ensemble/jepa.py DELETED
@@ -1,75 +0,0 @@
- """JEPA latent predictor with EMA target encoder."""
-
- from __future__ import annotations
-
- import copy
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
-
- class _SegEncoder(nn.Module):
-     def __init__(self, vocab_size, d):
-         super().__init__()
-         self.tok = nn.Embedding(vocab_size, d)
-         self.enc = nn.GRU(d, d, batch_first=True)
-         self.out = nn.Linear(d, d)
-
-     def forward(self, ids):
-         h, _ = self.enc(self.tok(ids))
-         return self.out(h.mean(dim=1))
-
-
- class JEPA(nn.Module):
-     def __init__(self, vocab_size: int, d_latent: int = 64, ema_m: float = 0.996):
-         super().__init__()
-         self.ctx_enc = _SegEncoder(vocab_size, d_latent)
-         self.tgt_enc = copy.deepcopy(self.ctx_enc)
-         for p in self.tgt_enc.parameters():
-             p.requires_grad_(False)
-         self.predictor = nn.Sequential(
-             nn.Linear(d_latent, 2 * d_latent),
-             nn.GELU(),
-             nn.Linear(2 * d_latent, d_latent),
-         )
-         self.m = ema_m
-         self.d_latent = d_latent
-
-     @property
-     def enc(self):
-         """Alias used by world-model track."""
-         return self.ctx_enc
-
-     @property
-     def tgt(self):
-         return self.tgt_enc
-
-     @property
-     def pred(self):
-         return self.predictor
-
-     @torch.no_grad()
-     def ema_update(self):
-         for p_t, p_c in zip(self.tgt_enc.parameters(), self.ctx_enc.parameters()):
-             p_t.mul_(self.m).add_(p_c.detach(), alpha=1 - self.m)
-
-     def ema(self):
-         """Alias used by world-model track."""
-         self.ema_update()
-
-     def loss(self, seg_ctx, seg_tgt):
-         z_hat = self.predictor(self.ctx_enc(seg_ctx))
-         with torch.no_grad():
-             z_tgt = self.tgt_enc(seg_tgt)
-         pred = F.mse_loss(z_hat, z_tgt)
-         var_reg = F.relu(1.0 - z_hat.std(dim=0)).mean()
-         return pred + 0.5 * var_reg
-
-     @torch.no_grad()
-     def predict_next_latent(self, seg_ctx):
-         return self.predictor(self.ctx_enc(seg_ctx))
-
-     @torch.no_grad()
-     def encode(self, seg):
-         return self.tgt_enc(seg)
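
The arithmetic behind `ema_update` and the `var_reg` term in `loss` is easier to see on plain numbers. The sketch below mirrors those two formulas in pure Python (no torch), treating weights and latents as lists of floats; `ema_update` here is a standalone function and `variance_penalty` is a hypothetical name, both written only for illustration.

```python
import statistics


def ema_update(target, online, m=0.996):
    """Return m * target + (1 - m) * online, element by element.

    Mirrors p_t.mul_(m).add_(p_c, alpha=1 - m) from the deleted module,
    but on plain lists instead of tensors.
    """
    return [m * t + (1.0 - m) * o for t, o in zip(target, online)]


def variance_penalty(batch):
    """Mean over latent dimensions of relu(1 - std across the batch).

    `batch` is a list of latent vectors and needs at least two rows.
    statistics.stdev is the sample standard deviation, which matches the
    default (unbiased) behaviour of torch.Tensor.std.
    """
    columns = list(zip(*batch))
    return sum(max(0.0, 1.0 - statistics.stdev(c)) for c in columns) / len(columns)
```

A batch whose latents are all identical has zero standard deviation in every dimension, so the penalty reaches its maximum of 1.0; this is the collapsed solution the regularizer is meant to discourage.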
research/ensemble/src/ensemble/jepa_ensemble.py DELETED
@@ -1,232 +0,0 @@
- """JEPA ensemble: route -> retrieve -> generate -> JEPA-verify."""
-
- from __future__ import annotations
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- from ensemble.backends import HFBackend, make_backend
- from ensemble.bridge import Bridge
- from ensemble.jepa import JEPA
- from ensemble.memory import Embedder, Router, VectorStore
-
- torch.manual_seed(0)
-
-
- class Ensemble(nn.Module):
-     def __init__(
-         self,
-         llm: str = "tiny",
-         adapter_names=("general",),
-         d_emb: int = 64,
-         d_jepa: int = 64,
-         llm_backend: HFBackend | None = None,
-         **backend_kw,
-     ):
-         super().__init__()
-         self.llm = llm_backend if llm_backend is not None else make_backend(llm, **backend_kw)
-         V, H = self.llm.vocab_size, self.llm.hidden_size
-
-         self.emb = Embedder(V, d_emb)
-         self.jepa = JEPA(V, d_jepa)
-         self.bridge = Bridge(H, d_jepa)
-         self.store = VectorStore()
-
-         self.adapter_names = list(adapter_names)
-         for n in self.adapter_names:
-             self.llm.add_adapter(n)
-         self.llm.set_adapter(self.adapter_names[0])
-         self.router = Router(d_emb, len(self.adapter_names))
-
-     @torch.no_grad()
-     def answer_ids(
-         self,
-         query_ids,
-         n_new=32,
-         tau_consistency=0.0,
-         max_retries=2,
-         temperature: float = 0.7,
-     ):
-         q_emb = self.emb(query_ids.cpu())
-         a_idx = self.router(q_emb).item()
-         self.llm.set_adapter(self.adapter_names[a_idx])
-
-         mems = self.store.search(q_emb, k=1)
-         ctx = (
-             torch.cat([mems[0], query_ids.cpu()], dim=1)
-             if mems
-             else query_ids.cpu()
-         )
-
-         z_expected = self.jepa.predict_next_latent(ctx)
-
-         best = None
-         for attempt in range(max_retries + 1):
-             temp = temperature if attempt == 0 else max(temperature, 0.8 + 0.3 * attempt)
-             draft = self.llm.generate(
-                 ctx.to(self.llm.device),
-                 n_new=n_new,
-                 temperature=temp,
-             )
-             new_part = draft[:, ctx.size(1) :].cpu()
-             score = F.cosine_similarity(
-                 z_expected, self.jepa.encode(new_part)
-             ).item()
-             if best is None or score > best[1]:
-                 best = (draft, score, attempt)
-             if score >= tau_consistency:
-                 break
-         draft, score, attempt = best
-         return draft, score, self.adapter_names[a_idx], attempt
-
-     def answer_text(self, prompt: str, **kw):
-         ids = self.llm.encode_text(prompt)
-         out, score, adapter, retries = self.answer_ids(ids, **kw)
-         return self.llm.decode(out), score, adapter, retries
-
-     def generate_text(
-         self,
-         prompt: str,
-         *,
-         max_new_tokens: int = 512,
-         temperature: float = 0.0,
-     ) -> str:
-         """Greedy or sampled generation through the full ensemble stack."""
-         ids = self.llm.encode_text(prompt)
-         out, _, _, _ = self.answer_ids(
-             ids,
-             n_new=max_new_tokens,
-             tau_consistency=-1.0,
-             max_retries=0 if temperature <= 0 else 1,
-             temperature=temperature,
-         )
-         return self.llm.decode(out)
-
-     def memorize_ids(self, ids):
-         self.store.add(self.emb(ids.cpu()), ids.cpu())
-
-     def memorize_text(self, text: str):
-         self.memorize_ids(self.llm.encode_text(text))
-
-     def new_task_adapter(self, name: str):
-         self.adapter_names.append(name)
-         self.llm.add_adapter(name)
-         old = self.router
-         self.router = Router(self.emb.d_emb, len(self.adapter_names))
-         with torch.no_grad():
-             self.router.fc.weight[: old.fc.out_features] = old.fc.weight
-             self.router.fc.bias[: old.fc.out_features] = old.fc.bias
-
-     def train_step(self, seg_a, seg_b, opt, w_bridge=0.1):
-         logits, hidden = self.llm(seg_a.to(self.llm.device))
-         lm_loss = F.cross_entropy(
-             logits[:, :-1].reshape(-1, self.llm.vocab_size).float(),
-             seg_a[:, 1:].reshape(-1).to(logits.device),
-         )
-
-         jepa_loss = self.jepa.loss(seg_a.cpu(), seg_b.cpu())
-
-         z_llm = self.bridge(
-             hidden.cpu() if hidden.device.type != "cpu" else hidden
-         )
-         z_jepa = self.jepa.ctx_enc(seg_a.cpu()).detach()
-         bridge_loss = self.bridge.info_nce(z_llm, z_jepa.to(z_llm.device))
-
-         loss = lm_loss.cpu() + jepa_loss + w_bridge * bridge_loss
-         opt.zero_grad()
-         loss.backward()
-         opt.step()
-         self.jepa.ema_update()
-         return {
-             "lm": lm_loss.item(),
-             "jepa": jepa_loss.item(),
-             "bridge": bridge_loss.item(),
-         }
-
-     def make_optimizer(self, lr_lora=2e-4, lr_aux=1e-3):
-         return torch.optim.AdamW(
-             [
-                 {"params": list(self.llm.trainable_parameters()), "lr": lr_lora},
-                 {
-                     "params": list(self.jepa.ctx_enc.parameters())
-                     + list(self.jepa.predictor.parameters()),
-                     "lr": lr_aux,
-                 },
-                 {
-                     "params": list(self.bridge.parameters())
-                     + list(self.emb.parameters())
-                     + list(self.router.parameters()),
-                     "lr": lr_aux,
-                 },
-             ]
-         )
-
-
- def segment_pairs_from_texts(backend: HFBackend, texts, seg_len=64):
-     a_list, b_list = [], []
-     for t in texts:
-         ids = backend.tokenizer(t, return_tensors="pt").input_ids[0]
-         for i in range(0, len(ids) - 2 * seg_len, seg_len):
-             a_list.append(ids[i : i + seg_len])
-             b_list.append(ids[i + seg_len : i + 2 * seg_len])
-     if not a_list:
-         raise ValueError("texts too short for the chosen seg_len")
-     return torch.stack(a_list), torch.stack(b_list)
-
-
- def demo_tiny(steps=50):
-     ens = Ensemble(llm="tiny")
-     opt = ens.make_optimizer()
-     for s in range(steps):
-         seg_a = torch.randint(0, ens.llm.vocab_size, (8, 32))
-         seg_b = torch.randint(0, ens.llm.vocab_size, (8, 32))
-         logs = ens.train_step(seg_a, seg_b, opt)
-         if s % 10 == 0:
-             print(
-                 f"step {s:3d} | "
-                 + " | ".join(f"{k} {v:.3f}" for k, v in logs.items())
-             )
-
-     for _ in range(5):
-         ens.memorize_ids(torch.randint(0, ens.llm.vocab_size, (1, 32)))
-     ens.new_task_adapter("medical")
-
-     q = torch.randint(0, ens.llm.vocab_size, (1, 8))
-     out, score, adapter, retries = ens.answer_ids(q, tau_consistency=-1.0)
-     print(f"\nadapter={adapter} jepa_consistency={score:.3f} retries={retries}")
-
-
- def demo_hf(model_path="Qwen/Qwen2.5-0.5B-Instruct"):
-     ens = Ensemble(llm=model_path, load_in_4bit=False)
-     opt = ens.make_optimizer()
-
-     texts = ["Replace this with your real corpus. " * 50]
-     seg_a, seg_b = segment_pairs_from_texts(ens.llm, texts, seg_len=32)
-     for s in range(10):
-         logs = ens.train_step(seg_a[:4], seg_b[:4], opt)
-         print(f"step {s} | " + " | ".join(f"{k} {v:.3f}" for k, v in logs.items()))
-
-     ens.memorize_text("The project codename is AURORA and it ships in Q3.")
-     ens.new_task_adapter("project_aurora")
-
-     text, score, adapter, retries = ens.answer_text(
-         "What is the project codename?", n_new=24, tau_consistency=-1.0
-     )
-     print(f"\n[{adapter} | jepa={score:.3f} | retries={retries}]\n{text}")
-
-
- if __name__ == "__main__":
-     import sys
-
-     from ensemble.config import load_dotenv, resolve_llm
-
-     load_dotenv()
-     arg = sys.argv[1] if len(sys.argv) > 1 else None
-     if arg is None or arg == "auto":
-         arg, preset = resolve_llm()
-         print(f"Resolved LLM: {arg} (preset {preset})")
-     if arg == "tiny":
-         demo_tiny()
-     else:
-         demo_hf(arg)
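
The generate-then-verify loop in `answer_ids` can be isolated from the model code. The sketch below reproduces only its control flow in pure Python, under the assumption that generation and scoring are passed in as callables; `retry_temperature`, `best_of_retries`, `generate` and `score` are hypothetical names, not part of the deleted module.

```python
def retry_temperature(base, attempt):
    """Temperature schedule used by the loop: the first try uses `base`,
    later tries use at least 0.8 + 0.3 * attempt."""
    return base if attempt == 0 else max(base, 0.8 + 0.3 * attempt)


def best_of_retries(generate, score, tau=0.0, max_retries=2, temperature=0.7):
    """Keep the highest-scoring draft; stop early once a draft reaches `tau`.

    `generate(temp)` returns a draft and `score(draft)` returns a float
    (in the deleted code, a cosine similarity between JEPA latents).
    """
    best = None
    for attempt in range(max_retries + 1):
        draft = generate(retry_temperature(temperature, attempt))
        s = score(draft)
        if best is None or s > best[1]:
            best = (draft, s, attempt)
        if s >= tau:
            break
    return best  # (draft, score, index of the attempt that produced it)
```

One detail this makes visible: the third returned value is the index of the best attempt, not the number of attempts that were run, so a caller that prints it as `retries` may under-report when a later draft scored worse than an earlier one.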
research/ensemble/src/ensemble/llm_emb_jepa_ensemble_pluggable.py DELETED
@@ -1,507 +0,0 @@
- """
- LLM + Embedding + JEPA Ensemble — pluggable base-model edition
- ==============================================================
- Now the LLM is a swappable BACKEND. Three ways to load it:
-
-     # 1. HuggingFace Hub id
-     ens = Ensemble(llm="Qwen/Qwen2.5-0.5B-Instruct")
-
-     # 2. Local path (e.g. downloaded Llama / converted checkpoint)
-     ens = Ensemble(llm="/models/llama-3.2-1b")
-
-     # 3. Toy fallback (no transformers needed, runs on CPU in seconds)
-     ens = Ensemble(llm="tiny")
-
- Requirements for real models:
-     pip install torch transformers peft accelerate
-     (optional 4-bit: pip install bitsandbytes -> load_in_4bit=True)
-
- Everything else (Embedder, JEPA, Bridge, VectorStore, Router, the
- JEPA-critic inference loop, continual-learning hooks) only touches
- token ids / hidden states / latents, so it works with ANY backend.
- """
-
- from __future__ import annotations
- import copy
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- torch.manual_seed(0)
-
- # ----------------------------------------------------------------------------
- # 0. Backend interface — everything the ensemble needs from "an LLM"
- # ----------------------------------------------------------------------------
- class LLMBackend(nn.Module):
-     """Contract:
-         vocab_size : int
-         hidden_size: int
-         device : torch.device
-         forward(ids) -> (logits [B,T,V], hidden [B,T,H])
-         generate(ids, n_new) -> ids [B, T+n_new]
-         add_adapter(name) / set_adapter(name)
-         trainable_parameters() -> iterable of params to optimize
-         encode_text(str) / decode(ids) (real backends only)
-     """
-     vocab_size: int
-     hidden_size: int
-
-
- # ----------------------------------------------------------------------------
- # 0a. HuggingFace backend (local path OR hub id) with PEFT LoRA adapters
- # ----------------------------------------------------------------------------
- class HFBackend(LLMBackend):
-     def __init__(self, model_path: str, *, load_in_4bit: bool = False,
-                  lora_r: int = 16, lora_alpha: int = 32,
-                  target_modules=("q_proj", "v_proj"),
-                  device: str | None = None, torch_dtype=None):
-         super().__init__()
-         from transformers import AutoModelForCausalLM, AutoTokenizer
-         from peft import LoraConfig, get_peft_model
-
-         self.device_ = torch.device(
-             device or ("cuda" if torch.cuda.is_available() else "cpu"))
-
-         kwargs = {}
-         if load_in_4bit:
-             from transformers import BitsAndBytesConfig
-             kwargs["quantization_config"] = BitsAndBytesConfig(
-                 load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16,
-                 bnb_4bit_quant_type="nf4")
-         if torch_dtype is not None:
-             kwargs["torch_dtype"] = torch_dtype
-
-         # `model_path` may be "Qwen/Qwen2.5-0.5B-Instruct", "meta-llama/...",
-         # or a local directory like "/models/llama-3.2-1b".
-         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
-         if self.tokenizer.pad_token is None:
-             self.tokenizer.pad_token = self.tokenizer.eos_token
-         base = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
-         if not load_in_4bit:
-             base.to(self.device_)
-
-         # Freeze the base; all learning happens in LoRA adapters.
-         for p in base.parameters():
-             p.requires_grad_(False)
-
-         self._lora_cfg = LoraConfig(
-             r=lora_r, lora_alpha=lora_alpha, lora_dropout=0.05,
-             target_modules=list(target_modules), task_type="CAUSAL_LM")
-         self.model = get_peft_model(base, self._lora_cfg, adapter_name="general")
-         self._adapters = {"general"}
-
-         self.vocab_size = self.model.config.vocab_size
-         self.hidden_size = self.model.config.hidden_size
-
-     # ---- adapters -----------------------------------------------------------
-     def add_adapter(self, name: str):
-         if name not in self._adapters:
-             self.model.add_adapter(name, self._lora_cfg)
-             self._adapters.add(name)
-
-     def set_adapter(self, name: str):
-         self.model.set_adapter(name)
-
-     def trainable_parameters(self):
-         return (p for p in self.model.parameters() if p.requires_grad)
-
-     # ---- core ops -----------------------------------------------------------
-     def forward(self, ids):
-         out = self.model(input_ids=ids.to(self.device_),
-                          output_hidden_states=True)
-         return out.logits, out.hidden_states[-1]  # last layer hidden
-
-     @torch.no_grad()
-     def generate(self, ids, n_new=64, temperature=0.8):
-         out = self.model.generate(
-             input_ids=ids.to(self.device_),
-             max_new_tokens=n_new, do_sample=True, temperature=temperature,
-             pad_token_id=self.tokenizer.pad_token_id)
-         return out
-
-     # ---- text helpers -------------------------------------------------------
-     def encode_text(self, text: str):
-         return self.tokenizer(text, return_tensors="pt").input_ids.to(self.device_)
-
-     def decode(self, ids):
-         return self.tokenizer.decode(ids[0], skip_special_tokens=True)
-
-     @property
-     def device(self):
-         return self.device_
-
-
- # ----------------------------------------------------------------------------
- # 0b. Tiny fallback backend (no transformers; same toy model as before)
- # ----------------------------------------------------------------------------
- class TinyBackend(LLMBackend):
-     VOCAB, D_MODEL, N_LAYERS, N_HEADS, SEQ_LEN, LORA_R = 1000, 128, 2, 4, 32, 8
-
-     class _LoRALinear(nn.Module):
-         def __init__(self, d_in, d_out, r):
-             super().__init__()
-             self.base = nn.Linear(d_in, d_out)
-             self.base.weight.requires_grad_(False)
-             self.base.bias.requires_grad_(False)
-             self.adapters, self.active, self.r = nn.ModuleDict(), None, r
-
-         def add_adapter(self, name):
-             A = nn.Linear(self.base.in_features, self.r, bias=False)
-             B = nn.Linear(self.r, self.base.out_features, bias=False)
-             nn.init.zeros_(B.weight)
-             self.adapters[name] = nn.Sequential(A, B)
-
-         def forward(self, x):
-             y = self.base(x)
-             if self.active and self.active in self.adapters:
-                 y = y + self.adapters[self.active](x)
-             return y
-
-     class _Block(nn.Module):
-         def __init__(self, D, H, R):
-             super().__init__()
-             L = TinyBackend._LoRALinear
-             self.ln1 = nn.LayerNorm(D)
-             self.attn = nn.MultiheadAttention(D, H, batch_first=True)
-             self.ln2 = nn.LayerNorm(D)
-             self.up, self.down = L(D, 4 * D, R), L(4 * D, D, R)
-
-         def forward(self, x, mask):
-             h = self.ln1(x)
-             a, _ = self.attn(h, h, h, attn_mask=mask, need_weights=False)
-             x = x + a
-             return x + self.down(F.gelu(self.up(self.ln2(x))))
-
-     def __init__(self):
-         super().__init__()
-         D, V = self.D_MODEL, self.VOCAB
-         self.tok = nn.Embedding(V, D)
-         self.pos = nn.Embedding(self.SEQ_LEN * 4, D)
-         self.blocks = nn.ModuleList(
-             [self._Block(D, self.N_HEADS, self.LORA_R) for _ in range(self.N_LAYERS)])
-         self.ln_f, self.head = nn.LayerNorm(D), nn.Linear(D, V, bias=False)
-         self.vocab_size, self.hidden_size = V, D
-         self.add_adapter("general")
-         self.set_adapter("general")
-
-     def add_adapter(self, name):
-         for b in self.blocks:
-             b.up.add_adapter(name); b.down.add_adapter(name)
-
-     def set_adapter(self, name):
-         for b in self.blocks:
-             b.up.active = name; b.down.active = name
-
-     def trainable_parameters(self):
-         return (p for p in self.parameters() if p.requires_grad)
-
-     def forward(self, ids):
-         B, T = ids.shape
-         x = self.tok(ids) + self.pos(torch.arange(T, device=ids.device))
-         mask = torch.triu(torch.full((T, T), float("-inf"), device=ids.device), 1)
-         for b in self.blocks:
-             x = b(x, mask)
-         h = self.ln_f(x)
-         return self.head(h), h
-
-     @torch.no_grad()
-     def generate(self, ids, n_new=16, temperature=1.0):
-         for _ in range(n_new):
-             logits, _ = self(ids[:, -self.SEQ_LEN:])
-             nxt = torch.multinomial(F.softmax(logits[:, -1] / temperature, -1), 1)
-             ids = torch.cat([ids, nxt], dim=1)
-         return ids
-
-     @property
-     def device(self):
-         return next(self.parameters()).device
-
-
- def make_backend(llm: str, **kw) -> LLMBackend:
-     """'tiny' -> toy model; anything else -> HF hub id or local path."""
-     return TinyBackend() if llm == "tiny" else HFBackend(llm, **kw)
-
-
- # ----------------------------------------------------------------------------
- # 1. Embedder — vocab-agnostic (sized from the backend's tokenizer)
- # Swap for a real model: pass embed_fn=lambda txt: sbert.encode(...)
- # ----------------------------------------------------------------------------
- class Embedder(nn.Module):
-     def __init__(self, vocab_size: int, d_emb: int = 64):
-         super().__init__()
-         self.tok = nn.Embedding(vocab_size, d_emb)
-         self.enc = nn.GRU(d_emb, d_emb, batch_first=True, bidirectional=True)
-         self.proj = nn.Linear(2 * d_emb, d_emb)
-         self.d_emb = d_emb
-
-     def forward(self, ids):
-         h, _ = self.enc(self.tok(ids))
-         return F.normalize(self.proj(h.mean(dim=1)), dim=-1)
-
-
- # ----------------------------------------------------------------------------
- # 2. JEPA — vocab-agnostic latent predictor with EMA target encoder
- # ----------------------------------------------------------------------------
- class _JEPAEncoder(nn.Module):
-     def __init__(self, vocab_size, d):
-         super().__init__()
-         self.tok = nn.Embedding(vocab_size, d)
-         self.enc = nn.GRU(d, d, batch_first=True)
-         self.out = nn.Linear(d, d)
-
-     def forward(self, ids):
-         h, _ = self.enc(self.tok(ids))
-         return self.out(h.mean(dim=1))
-
-
- class JEPA(nn.Module):
-     def __init__(self, vocab_size: int, d_jepa: int = 64, ema_m: float = 0.996):
-         super().__init__()
-         self.ctx_enc = _JEPAEncoder(vocab_size, d_jepa)
-         self.tgt_enc = copy.deepcopy(self.ctx_enc)
-         for p in self.tgt_enc.parameters():
-             p.requires_grad_(False)
-         self.predictor = nn.Sequential(
-             nn.Linear(d_jepa, 2 * d_jepa), nn.GELU(), nn.Linear(2 * d_jepa, d_jepa))
-         self.m, self.d_jepa = ema_m, d_jepa
-
-     @torch.no_grad()
-     def ema_update(self):
-         for p_t, p_c in zip(self.tgt_enc.parameters(), self.ctx_enc.parameters()):
-             p_t.mul_(self.m).add_(p_c.detach(), alpha=1 - self.m)
-
-     def loss(self, seg_ctx, seg_tgt):
-         z_hat = self.predictor(self.ctx_enc(seg_ctx))
-         with torch.no_grad():
-             z_tgt = self.tgt_enc(seg_tgt)
-         pred = F.mse_loss(z_hat, z_tgt)
-         var_reg = F.relu(1.0 - z_hat.std(dim=0)).mean()  # anti-collapse
-         return pred + 0.5 * var_reg
-
-     @torch.no_grad()
-     def predict_next_latent(self, seg_ctx):
-         return self.predictor(self.ctx_enc(seg_ctx))
-
-     @torch.no_grad()
-     def encode(self, seg):
-         return self.tgt_enc(seg)
-
-
- # ----------------------------------------------------------------------------
- # 3. Bridge — sized from backend.hidden_size at construction
- # ----------------------------------------------------------------------------
- class Bridge(nn.Module):
-     def __init__(self, d_llm_hidden: int, d_jepa: int):
-         super().__init__()
-         self.proj = nn.Sequential(
-             nn.Linear(d_llm_hidden, d_jepa), nn.GELU(), nn.Linear(d_jepa, d_jepa))
-
-     def forward(self, llm_hidden):  # [B,T,H] -> [B,d_jepa]
-         return self.proj(llm_hidden.float().mean(dim=1))
-
-     def info_nce(self, z1, z2, tau=0.07):
-         z1, z2 = F.normalize(z1, dim=-1), F.normalize(z2, dim=-1)
-         logits = z1 @ z2.t() / tau
-         labels = torch.arange(z1.size(0), device=z1.device)
-         return 0.5 * (F.cross_entropy(logits, labels) +
-                       F.cross_entropy(logits.t(), labels))
-
-
-
- # ----------------------------------------------------------------------------
- # 4. Memory + Router
- # ----------------------------------------------------------------------------
- class VectorStore:
-     def __init__(self):
-         self.keys, self.values = [], []
-
-     def add(self, emb, payload):
-         self.keys.append(emb.squeeze(0).detach().cpu())
-         self.values.append(payload)
-
-     def search(self, q, k=2):
-         if not self.keys:
-             return []
-         K = torch.stack(self.keys)
-         sims = (q.detach().cpu() @ K.t()).squeeze(0)
-         top = sims.topk(min(k, len(self.keys))).indices
-         return [self.values[i] for i in top]
-
-
- class Router(nn.Module):
-     def __init__(self, d_emb, n_adapters):
-         super().__init__()
-         self.fc = nn.Linear(d_emb, n_adapters)
-
-     def forward(self, emb):
-         return self.fc(emb).argmax(dim=-1)
-
-
- # ----------------------------------------------------------------------------
- # 5. Ensemble — backend-agnostic
- # ----------------------------------------------------------------------------
- class Ensemble(nn.Module):
-     def __init__(self, llm: str = "tiny", adapter_names=("general",),
-                  d_emb: int = 64, d_jepa: int = 64, **backend_kw):
-         super().__init__()
-         self.llm = make_backend(llm, **backend_kw)
-         V, H = self.llm.vocab_size, self.llm.hidden_size
-
-         self.emb = Embedder(V, d_emb)
-         self.jepa = JEPA(V, d_jepa)
-         self.bridge = Bridge(H, d_jepa)
-         self.store = VectorStore()
-
-         self.adapter_names = list(adapter_names)
-         for n in self.adapter_names:
-             self.llm.add_adapter(n)
-         self.llm.set_adapter(self.adapter_names[0])
-         self.router = Router(d_emb, len(self.adapter_names))
-
-     # -------- inference: route -> retrieve -> generate -> JEPA-verify -------
-     @torch.no_grad()
-     def answer_ids(self, query_ids, n_new=32, tau_consistency=0.0, max_retries=2):
-         q_emb = self.emb(query_ids.cpu())
-         a_idx = self.router(q_emb).item()
-         self.llm.set_adapter(self.adapter_names[a_idx])
-
-         mems = self.store.search(q_emb, k=1)
-         ctx = (torch.cat([mems[0], query_ids.cpu()], dim=1)
-                if mems else query_ids.cpu())
-
-         z_expected = self.jepa.predict_next_latent(ctx)
-
-         best = None
-         for attempt in range(max_retries + 1):
-             draft = self.llm.generate(ctx.to(self.llm.device), n_new=n_new,
-                                       temperature=0.8 + 0.3 * attempt)
-             new_part = draft[:, ctx.size(1):].cpu()
-             score = F.cosine_similarity(
-                 z_expected, self.jepa.encode(new_part)).item()
-             if best is None or score > best[1]:
-                 best = (draft, score, attempt)
-             if score >= tau_consistency:
-                 break
-         draft, score, attempt = best
-         return draft, score, self.adapter_names[a_idx], attempt
-
-     def answer_text(self, prompt: str, **kw):
-         """Convenience wrapper for HF backends (uses the real tokenizer)."""
-         ids = self.llm.encode_text(prompt)
-         out, score, adapter, retries = self.answer_ids(ids, **kw)
-         return self.llm.decode(out), score, adapter, retries
-
-     # -------- continual learning hooks ---------------------------------------
-     def memorize_ids(self, ids):
-         self.store.add(self.emb(ids.cpu()), ids.cpu())
-
-     def memorize_text(self, text: str):
-         self.memorize_ids(self.llm.encode_text(text))
-
-     def new_task_adapter(self, name: str):
-         self.adapter_names.append(name)
-         self.llm.add_adapter(name)
-         old = self.router
-         self.router = Router(self.emb.d_emb, len(self.adapter_names))
-         with torch.no_grad():
-             self.router.fc.weight[: old.fc.out_features] = old.fc.weight
-             self.router.fc.bias[: old.fc.out_features] = old.fc.bias
-
-     # -------- one joint training step (LM + JEPA + Bridge) -------------------
-     def train_step(self, seg_a, seg_b, opt, w_bridge=0.1):
-         """seg_a, seg_b: consecutive token-id segments [B, T] (same tokenizer
-         as the backend!). For HF backends build them with backend.tokenizer."""
-         logits, hidden = self.llm(seg_a.to(self.llm.device))
-         lm_loss = F.cross_entropy(
-             logits[:, :-1].reshape(-1, self.llm.vocab_size).float(),
-             seg_a[:, 1:].reshape(-1).to(logits.device))
-
-         jepa_loss = self.jepa.loss(seg_a.cpu(), seg_b.cpu())
-
-         z_llm = self.bridge(hidden.cpu() if hidden.device.type != "cpu" else hidden)
-         z_jepa = self.jepa.ctx_enc(seg_a.cpu()).detach()
-         bridge_loss = self.bridge.info_nce(z_llm, z_jepa.to(z_llm.device))
-
-         loss = lm_loss.cpu() + jepa_loss + w_bridge * bridge_loss
-         opt.zero_grad(); loss.backward(); opt.step()
-         self.jepa.ema_update()
-         return {"lm": lm_loss.item(), "jepa": jepa_loss.item(),
-                 "bridge": bridge_loss.item()}
-
-     def make_optimizer(self, lr_lora=2e-4, lr_aux=1e-3):
-         return torch.optim.AdamW([
-             {"params": list(self.llm.trainable_parameters()), "lr": lr_lora},
-             {"params": list(self.jepa.ctx_enc.parameters())
-                        + list(self.jepa.predictor.parameters()), "lr": lr_aux},
-             {"params": list(self.bridge.parameters())
-                        + list(self.emb.parameters())
-                        + list(self.router.parameters()), "lr": lr_aux},
-         ])
-
-
- # ----------------------------------------------------------------------------
- # 6. Helpers: turn raw text into (seg_a, seg_b) pairs with the HF tokenizer
- # ----------------------------------------------------------------------------
- def segment_pairs_from_texts(backend: HFBackend, texts, seg_len=64):
-     """Yields consecutive-segment id pairs for the JEPA + LM losses."""
-     a_list, b_list = [], []
-     for t in texts:
-         ids = backend.tokenizer(t, return_tensors="pt").input_ids[0]
-         for i in range(0, len(ids) - 2 * seg_len, seg_len):
-             a_list.append(ids[i:i + seg_len])
-             b_list.append(ids[i + seg_len:i + 2 * seg_len])
-     if not a_list:
-         raise ValueError("texts too short for the chosen seg_len")
-     return torch.stack(a_list), torch.stack(b_list)
-
-
- # ----------------------------------------------------------------------------
- # 7. Demos
- # ----------------------------------------------------------------------------
- def demo_tiny(steps=50):
-     """No-dependency smoke test."""
-     ens = Ensemble(llm="tiny")
-     opt = ens.make_optimizer()
-     for s in range(steps):
-         seg_a = torch.randint(0, ens.llm.vocab_size, (8, 32))
-         seg_b = torch.randint(0, ens.llm.vocab_size, (8, 32))
-         logs = ens.train_step(seg_a, seg_b, opt)
-         if s % 10 == 0:
-             print(f"step {s:3d} | " + " | ".join(f"{k} {v:.3f}" for k, v in logs.items()))
-
-     for _ in range(5):
-         ens.memorize_ids(torch.randint(0, ens.llm.vocab_size, (1, 32)))
-     ens.new_task_adapter("medical")
-
-     q = torch.randint(0, ens.llm.vocab_size, (1, 8))
-     out, score, adapter, retries = ens.answer_ids(q, tau_consistency=-1.0)
-     print(f"\nadapter={adapter} jepa_consistency={score:.3f} retries={retries}")
-
-
- def demo_hf(model_path="Qwen/Qwen2.5-0.5B-Instruct"):
-     """Real model from hub id OR local path, e.g. '/models/llama-3.2-1b'.
-     For gated Llama repos: huggingface-cli login first."""
-     ens = Ensemble(llm=model_path, load_in_4bit=False)  # 4bit needs bitsandbytes
-     opt = ens.make_optimizer()
-
-     texts = ["Replace this with your real corpus. " * 50]
-     seg_a, seg_b = segment_pairs_from_texts(ens.llm, texts, seg_len=32)
-     for s in range(10):  # tiny demo run
-         logs = ens.train_step(seg_a[:4], seg_b[:4], opt)
-         print(f"step {s} | " + " | ".join(f"{k} {v:.3f}" for k, v in logs.items()))
-
-     ens.memorize_text("The project codename is AURORA and it ships in Q3.")
-     ens.new_task_adapter("project_aurora")
-
-     text, score, adapter, retries = ens.answer_text(
-         "What is the project codename?", n_new=24, tau_consistency=-1.0)
-     print(f"\n[{adapter} | jepa={score:.3f} | retries={retries}]\n{text}")
-
-
- if __name__ == "__main__":
-     import sys
-     arg = sys.argv[1] if len(sys.argv) > 1 else "tiny"
-     if arg == "tiny":
-         demo_tiny()
-     else:
-         demo_hf(arg)  # python ensemble.py /models/llama-3.2-1b
-                       # python ensemble.py Qwen/Qwen2.5-0.5B-Instruct
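
`Bridge.info_nce` in the file above aligns LLM-side and JEPA-side latents with a symmetric InfoNCE loss: row `i` of one batch should match row `i` of the other and no other row. The sketch below restates that formula in pure Python on lists of floats so the arithmetic can be checked by hand. It is an illustration under stated assumptions, not the deleted code: the helper names `_normalize` and `_cross_entropy` are hypothetical, and a zero vector is left unscaled here, whereas `F.normalize` clamps the norm with a small epsilon.

```python
import math


def _normalize(v):
    """Scale a vector to unit L2 norm (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]


def _cross_entropy(logit_rows, labels):
    """Mean of -log softmax(row)[label], with a stable log-sum-exp."""
    total = 0.0
    for row, y in zip(logit_rows, labels):
        m = max(row)
        lse = m + math.log(sum(math.exp(x - m) for x in row))
        total += lse - row[y]
    return total / len(logit_rows)


def info_nce(z1, z2, tau=0.07):
    """Symmetric InfoNCE: average of the z1->z2 and z2->z1 cross-entropies."""
    z1 = [_normalize(v) for v in z1]
    z2 = [_normalize(v) for v in z2]
    # Cosine similarity matrix divided by the temperature tau.
    logits = [[sum(a * b for a, b in zip(u, v)) / tau for v in z2] for u in z1]
    labels = list(range(len(z1)))  # positives sit on the diagonal
    logits_t = [list(col) for col in zip(*logits)]
    return 0.5 * (_cross_entropy(logits, labels) + _cross_entropy(logits_t, labels))
```

With two orthogonal unit vectors, matching the rows gives a loss close to zero, while swapping the rows of one batch pushes the loss to roughly `1 / tau`.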
research/ensemble/src/ensemble/memory.py DELETED
@@ -1,46 +0,0 @@
- """Retrieval memory: embedder, vector store, and adapter router."""
-
- from __future__ import annotations
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
-
- class Embedder(nn.Module):
-     def __init__(self, vocab_size: int, d_emb: int = 64):
-         super().__init__()
-         self.tok = nn.Embedding(vocab_size, d_emb)
-         self.enc = nn.GRU(d_emb, d_emb, batch_first=True, bidirectional=True)
-         self.proj = nn.Linear(2 * d_emb, d_emb)
-         self.d_emb = d_emb
-
-     def forward(self, ids):
-         h, _ = self.enc(self.tok(ids))
-         return F.normalize(self.proj(h.mean(dim=1)), dim=-1)
-
-
- class VectorStore:
-     def __init__(self):
-         self.keys, self.values = [], []
-
-     def add(self, emb, payload):
-         self.keys.append(emb.squeeze(0).detach().cpu())
-         self.values.append(payload)
-
-     def search(self, q, k=2):
-         if not self.keys:
-             return []
-         K = torch.stack(self.keys)
-         sims = (q.detach().cpu() @ K.t()).squeeze(0)
-         top = sims.topk(min(k, len(self.keys))).indices
-         return [self.values[i] for i in top]
-
-
- class Router(nn.Module):
-     def __init__(self, d_emb, n_adapters):
-         super().__init__()
-         self.fc = nn.Linear(d_emb, n_adapters)
-
-     def forward(self, emb):
-         return self.fc(emb).argmax(dim=-1)

research/ensemble/src/ensemble/pretrain.py DELETED
@@ -1,198 +0,0 @@
- """Joint pretrain: LLM (LoRA) + embedder + JEPA + bridge, saved to models/ensemble/."""
-
- from __future__ import annotations
-
- import argparse
- import json
- import os
- import random
- import time
- from pathlib import Path
-
- import torch
-
- from ensemble.checkpoint import save_checkpoint
- from ensemble.config import default_ensemble_out, load_dotenv, resolve_llm
- from ensemble.jepa_ensemble import Ensemble, segment_pairs_from_texts
-
- _REPO_ROOT = Path(__file__).resolve().parents[4]
- _DEFAULT_DATA = _REPO_ROOT / "research/data/education-lesson-chat.jsonl"
- _DEFAULT_KB = _REPO_ROOT / "research/data/benchmark-kb.jsonl"
-
-
- def _load_jsonl(path: Path) -> list[dict]:
-     rows = []
-     with open(path) as f:
-         for line in f:
-             line = line.strip()
-             if line:
-                 rows.append(json.loads(line))
-     return rows
-
-
- def _chat_to_text(row: dict) -> str:
-     messages = row.get("messages", [])
-     parts = [f"{m.get('role', 'user')}: {m.get('content', '')}" for m in messages]
-     return "\n".join(parts)
-
-
- def _collect_texts(data_path: Path, max_samples: int | None) -> list[str]:
-     rows = _load_jsonl(data_path)
-     if max_samples is not None:
-         rows = rows[:max_samples]
-     return [_chat_to_text(r) for r in rows if _chat_to_text(r).strip()]
-
-
- def _seed_memory(ens: Ensemble, kb_path: Path | None) -> int:
-     if kb_path is None or not kb_path.is_file():
-         return 0
-     count = 0
-     for row in _load_jsonl(kb_path):
-         text = row.get("text", "").strip()
-         if text:
-             ens.memorize_text(text)
-             count += 1
-     return count
-
-
- def pretrain(args) -> Path:
-     torch.manual_seed(args.seed)
-     random.seed(args.seed)
-
-     data_path = Path(args.data).resolve()
-     out_dir = Path(args.out).resolve()
-     kb_path = Path(args.kb).resolve() if args.kb else None
-
-     print(f"Loading ensemble backend: {args.llm}")
-     ens = Ensemble(llm=args.llm, load_in_4bit=args.load_in_4bit)
-     opt = ens.make_optimizer(lr_lora=args.lr_lora, lr_aux=args.lr_aux)
-
-     texts = _collect_texts(data_path, args.max_samples)
-     if not texts and args.llm != "tiny":
-         raise SystemExit(f"No training texts found in {data_path}")
-
-     mem_count = _seed_memory(ens, kb_path)
-     print(f"Training texts: {len(texts)} | memory snippets: {mem_count}")
-
-     if args.llm == "tiny":
-         n_pairs = max(args.steps * args.batch_size, args.batch_size)
-         v = ens.llm.vocab_size
-         seg_a = torch.randint(0, v, (n_pairs, args.seg_len))
-         seg_b = torch.randint(0, v, (n_pairs, args.seg_len))
-     else:
-         seg_a, seg_b = segment_pairs_from_texts(
-             ens.llm, texts, seg_len=args.seg_len
-         )
-     n_pairs = seg_a.size(0)
-     batch = min(args.batch_size, n_pairs)
-     print(f"Segment pairs: {n_pairs} | batch={batch} | steps={args.steps}")
-
-     t0 = time.time()
-     for step in range(args.steps):
-         idx = torch.randint(0, n_pairs, (batch,))
-         logs = ens.train_step(seg_a[idx], seg_b[idx], opt, w_bridge=args.w_bridge)
-         if step % max(1, args.log_every) == 0 or step == args.steps - 1:
-             parts = " | ".join(f"{k} {v:.4f}" for k, v in logs.items())
-             print(f"step {step:4d}/{args.steps} | {parts}")
-
-     elapsed = time.time() - t0
-     meta = {
-         "steps": args.steps,
-         "batch_size": batch,
-         "seg_len": args.seg_len,
-         "data": str(data_path),
-         "kb": str(kb_path) if kb_path else None,
-         "memory_count": mem_count,
-         "text_count": len(texts),
-         "elapsed_s": round(elapsed, 1),
-         "lr_lora": args.lr_lora,
-         "lr_aux": args.lr_aux,
-         "w_bridge": args.w_bridge,
-         "seed": args.seed,
-         "preset": getattr(args, "preset", None),
-     }
-
-     saved = save_checkpoint(
-         ens,
-         out_dir,
-         base_llm=args.llm,
-         training_meta=meta,
-     )
-     print(f"\nSaved ensemble checkpoint → {saved}")
-     print("Benchmark with slm-evals:")
-     print(
-         f" uv run --package slm-evals slm-benchmark "
-         f"--model {saved} --model-type ensemble "
-         f"--benchmarks bfcl --max-samples 5"
-     )
-     return saved
-
-
- def parse_args():
-     p = argparse.ArgumentParser(
-         description="Pretrain JEPA ensemble (LLM+emb+JEPA) and save to models/ensemble/"
-     )
-     p.add_argument(
-         "--llm",
-         default=None,
-         help=(
-             "HF hub id / local path, 'tiny' for CPU smoke, or omit to use "
-             "LLM_PATH / BASE / MODEL_ID / ACTIVE_MODEL from .env + models.yaml"
-         ),
-     )
-     p.add_argument(
-         "--preset",
-         default=None,
-         help="models.yaml preset key (default: ENSEMBLE_PRESET or ACTIVE_MODEL)",
-     )
-     p.add_argument(
-         "--data",
-         default=str(_DEFAULT_DATA),
-         help="Chat JSONL (messages[]) for segment-pair training",
-     )
-     p.add_argument(
-         "--kb",
-         default=str(_DEFAULT_KB),
-         help="Optional KB JSONL (text field) loaded into vector store",
-     )
-     p.add_argument(
-         "--out",
-         default=None,
-         help="Output dir (default: ENSEMBLE_OUT or models/ensemble/<preset>-jepa-pretrain)",
-     )
-     p.add_argument("--steps", type=int, default=100)
-     p.add_argument("--batch-size", type=int, default=4)
-     p.add_argument("--seg-len", type=int, default=32)
-     p.add_argument("--max-samples", type=int, default=None)
-     p.add_argument("--lr-lora", type=float, default=2e-4)
-     p.add_argument("--lr-aux", type=float, default=1e-3)
-     p.add_argument("--w-bridge", type=float, default=0.1)
-     p.add_argument("--log-every", type=int, default=10)
-     p.add_argument("--seed", type=int, default=0)
-     p.add_argument("--load-in-4bit", action="store_true")
-     p.add_argument("--no-kb", action="store_true", help="Skip loading KB into memory")
-     return p.parse_args()
-
-
- def main():
-     load_dotenv()
-     args = parse_args()
-     if args.no_kb:
-         args.kb = None
-
-     preset_key = args.preset
-     if args.llm is None or args.llm == "auto":
-         args.llm, preset_key = resolve_llm(preset_arg=args.preset)
-     elif args.llm != "tiny" and not args.preset:
-         _, preset_key = resolve_llm(llm_arg=args.llm)
-
-     if not args.out:
-         args.out = os.environ.get("ENSEMBLE_OUT") or default_ensemble_out(preset_key)
-
-     args.preset = preset_key
-     print(f"Resolved LLM: {args.llm}" + (f" (preset {preset_key})" if preset_key else ""))
-     pretrain(args)
-
-
- if __name__ == "__main__":
-     main()

research/ensemble/src/ensemble/world_ensemble.py DELETED
@@ -1,228 +0,0 @@
- """World-model ensemble: plan -> generate -> energy-rank."""
-
- from __future__ import annotations
-
- import sys
- import time
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- from ensemble.backends import HFLLM, load_llm
- from ensemble.bridge import Bridge
- from ensemble.energy import EnergyModel
- from ensemble.jepa import JEPA
- from ensemble.memory import Embedder, VectorStore
- from ensemble.world_model import WorldModel
-
- torch.manual_seed(0)
-
- D_LAT = 96
- D_EMB = 64
-
-
- class WorldEnsemble(nn.Module):
-     def __init__(self, llm_spec="tiny"):
-         super().__init__()
-         self.llm = load_llm(llm_spec)
-         V, H = self.llm.vocab_size, self.llm.hidden_size
-         self.emb = Embedder(V, D_EMB)
-         self.jepa = JEPA(V, D_LAT)
-         self.world = WorldModel(D_LAT)
-         self.energy = EnergyModel(D_LAT)
-         self.bridge = Bridge(H, D_LAT)
-         self.store = VectorStore()
-
-     @torch.no_grad()
-     def world_state(self, segments):
-         s = self.world.init_state(1, "cpu")
-         for seg in segments:
-             z = self.jepa.encode(seg.cpu())
-             s, _ = self.world.step(s, z)
-         return s
-
-     @torch.no_grad()
-     def answer(self, query_ids, n_new=24, n_drafts=6, horizon=3):
-         q_emb = self.emb(query_ids.cpu())
-         mems = self.store.search(q_emb, k=1)
-         segments = (mems + [query_ids.cpu()]) if mems else [query_ids.cpu()]
-         ctx = torch.cat(segments, dim=1)
-
-         s = self.world_state(segments)
-         plan, _ = self.world.rollout(s, horizon)
-
-         drafts, lat = [], []
-         for _ in range(n_drafts):
-             out = self.llm.generate(
-                 ctx.to(self.llm.device), n_new=n_new, temperature=0.9
-             )
-             new = out[:, ctx.size(1) :].cpu()
-             drafts.append(new)
-             lat.append(self.jepa.encode(new))
-         Z = torch.cat(lat, 0)
-         E = self.energy.rank(s, Z)
-         best = E.argmin().item()
-         return {
-             "output": drafts[best],
-             "energy": E[best].item(),
-             "all_energies": E.tolist(),
-             "plan_alignment": F.cosine_similarity(
-                 plan[:, 0], Z[best : best + 1]
-             ).item(),
-         }
-
-     def memorize(self, ids):
-         self.store.add(self.emb(ids.cpu()), ids.cpu())
-
-     def train_step(
-         self,
-         seg_seq,
-         opt,
-         w=None,
-         hard_negs=True,
-     ):
-         if w is None:
-             w = dict(lm=1.0, jepa=1.0, world=1.0, ebm=1.0, bridge=0.1)
-
-         B, T, L = seg_seq.shape
-         dev = self.llm.device
-
-         flat = seg_seq[:, 0].to(dev)
-         logits, hidden = self.llm(flat)
-         lm = F.cross_entropy(
-             logits[:, :-1].reshape(-1, self.llm.vocab_size).float(),
-             flat[:, 1:].reshape(-1),
-         )
-
-         jepa = self.jepa.loss(seg_seq[:, 0], seg_seq[:, 1])
-
-         z_seq = torch.stack(
-             [self.jepa.enc(seg_seq[:, t]) for t in range(T)], 1
-         )
-         world = self.world.sequence_loss(z_seq)
-
-         s = self.world.init_state(B, z_seq.device)
-         s, _ = self.world.step(s, z_seq[:, 0].detach())
-         z_pos = z_seq[:, 1].detach()
-         z_negs = None
-         if hard_negs:
-             with torch.no_grad():
-                 gen = self.llm.generate(seg_seq[:, 0].to(dev), n_new=L)
-                 gen_new = gen[:, seg_seq.size(2) :].cpu()
-                 z_negs = self.jepa.encode(gen_new).unsqueeze(1)
-         ebm = self.energy.contrastive_loss(s, z_pos, z_negs)
-
-         bridge = self.bridge.info_nce(
-             self.bridge(
-                 hidden.cpu() if hidden.device.type != "cpu" else hidden
-             ),
-             self.jepa.enc(seg_seq[:, 0]).detach(),
-         )
-
-         loss = (
-             w["lm"] * lm.cpu()
-             + w["jepa"] * jepa
-             + w["world"] * world
-             + w["ebm"] * ebm
-             + w["bridge"] * bridge
-         )
-         opt.zero_grad()
-         loss.backward()
-         opt.step()
-         self.jepa.ema()
-         return dict(
-             lm=lm.item(),
-             jepa=jepa.item(),
-             world=world.item(),
-             ebm=ebm.item(),
-             bridge=bridge.item(),
-         )
-
-     def make_optimizer(self, lr_lora=2e-4, lr_aux=1e-3):
-         return torch.optim.AdamW(
-             [
-                 {"params": list(self.llm.trainable_parameters()), "lr": lr_lora},
-                 {
-                     "params": list(self.jepa.enc.parameters())
-                     + list(self.jepa.pred.parameters()),
-                     "lr": lr_aux,
-                 },
-                 {"params": list(self.world.parameters()), "lr": lr_aux},
-                 {"params": list(self.energy.parameters()), "lr": lr_aux},
-                 {
-                     "params": list(self.bridge.parameters())
-                     + list(self.emb.parameters()),
-                     "lr": lr_aux,
-                 },
-             ]
-         )
-
-
- def toy_segment_sequences(B=8, T=4, L=24, vocab=1000):
-     return torch.randint(0, vocab, (B, T, L))
-
-
- def hf_segment_sequences(llm: HFLLM, texts, T=4, L=64):
-     seqs = []
-     for t in texts:
-         ids = llm.tokenizer(t, return_tensors="pt").input_ids[0]
-         n = (len(ids) // (T * L)) * T * L
-         if n:
-             seqs.append(ids[:n].view(-1, T, L))
-     if not seqs:
-         raise ValueError("corpus too short for T*L window")
-     return torch.cat(seqs, 0)
-
-
- def demo(spec="tiny", steps=60):
-     ens = WorldEnsemble(spec)
-     opt = ens.make_optimizer()
-
-     if spec == "tiny":
-         get_batch = lambda: toy_segment_sequences(vocab=ens.llm.vocab_size)
-     else:
-         corpus = ["Replace with your real documents. " * 200]
-         data = hf_segment_sequences(ens.llm, corpus, T=4, L=32)
-         get_batch = lambda: data[torch.randperm(len(data))[:4]]
-         steps = min(steps, 10)
-
-     t0 = time.time()
-     for s in range(steps):
-         logs = ens.train_step(
-             get_batch(), opt, hard_negs=(s > steps // 2)
-         )
-         if s % 10 == 0:
-             print(
-                 f"step {s:3d} | "
-                 + " | ".join(f"{k} {v:.3f}" for k, v in logs.items())
-             )
-     print(f"trained {steps} steps in {time.time() - t0:.1f}s")
-
-     for _ in range(4):
-         if spec == "tiny":
-             ens.memorize(torch.randint(0, ens.llm.vocab_size, (1, 24)))
-     q = (
-         torch.randint(0, ens.llm.vocab_size, (1, 12))
-         if spec == "tiny"
-         else ens.llm.tokenizer(
-             "What is this document about?", return_tensors="pt"
-         ).input_ids
-     )
-     res = ens.answer(q, n_drafts=6, horizon=3)
-     print(
-         f"\nselected draft energy={res['energy']:.3f} "
-         f"(all: {[f'{e:.2f}' for e in res['all_energies']]})"
-     )
-     print(f"plan↔output alignment: {res['plan_alignment']:.3f}")
-
-
- if __name__ == "__main__":
-     from ensemble.config import load_dotenv, resolve_llm
-
-     load_dotenv()
-     spec = sys.argv[1] if len(sys.argv) > 1 else None
-     if spec is None or spec == "auto":
-         spec, preset = resolve_llm()
-         print(f"Resolved LLM: {spec} (preset {preset})")
-     demo(spec or "tiny")

research/ensemble/src/ensemble/world_model.py DELETED
@@ -1,40 +0,0 @@
- """Latent world model: multi-step rollout in JEPA space."""
-
- from __future__ import annotations
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
-
- class WorldModel(nn.Module):
-     def __init__(self, d_latent: int):
-         super().__init__()
-         self.cell = nn.GRUCell(d_latent, d_latent)
-         self.head = nn.Linear(d_latent, d_latent)
-         self.s0 = nn.Parameter(torch.zeros(d_latent))
-         self.d_latent = d_latent
-
-     def init_state(self, B, device):
-         return self.s0.unsqueeze(0).expand(B, -1).contiguous().to(device)
-
-     def step(self, s, z):
-         s = self.cell(z, s)
-         return s, self.head(s)
-
-     def rollout(self, s, horizon):
-         preds = []
-         for _ in range(horizon):
-             z_hat = self.head(s)
-             preds.append(z_hat)
-             s = self.cell(z_hat, s)
-         return torch.stack(preds, 1), s
-
-     def sequence_loss(self, z_seq):
-         B, T, _ = z_seq.shape
-         s = self.init_state(B, z_seq.device)
-         loss = 0.0
-         for t in range(T - 1):
-             s, z_hat = self.step(s, z_seq[:, t])
-             loss = loss + F.mse_loss(z_hat, z_seq[:, t + 1])
-         return loss / (T - 1)

research/ensemble/src/ensemble/world_model_ensemble.py DELETED
@@ -1,499 +0,0 @@
- """
- World-Model Ensemble: EMB + EBM + JEPA + World Model + small LLM (from path)
- =============================================================================
- A LeCun-style modular agent built around a small language model.
-
- ARCHITECTURE
- ------------
-                          ┌────────────────────────────┐
-   input tokens ──► EMB ──┤ VectorStore (retrieval/CL) │──► context
-         │                └────────────────────────────┘       │
-         │                                                     │
-         ▼                                                     ▼
-   JEPA encoder ──► latent state s_t ──► WORLD MODEL ──► ŝ_{t+1..t+H}
-         │            (GRU dynamics, multi-step rollout)       │
-         │                                                     │
-         │            ┌────────────────────────────────────┐   │
-         └──────────► │ ENERGY MODEL  E(s_ctx, z_candidate)│ ◄─┘
-                      │ low energy = compatible/plausible  │
-                      └────────────────┬───────────────────┘
-                                       │ scores drafts / plans
-
-   LLM (small, loaded from path, LoRA bank) ──► N drafts ──► pick argmin E
-
- ROLES
- -----
-   EMB          perception for retrieval + routing (non-parametric memory)
-   JEPA         learns the latent space: predict z(next segment) from z(context)
-                (EMA target encoder + variance reg, no token reconstruction)
-   WORLD MODEL  deterministic latent dynamics s_{t+1} = f(s_t, z_t):
-                rolls the conversation/document state forward H steps in
-                LATENT space — cheap lookahead without decoding tokens
-   ENERGY       E(s, z) ∈ R, trained so true continuations have LOW energy and
-                negatives (shuffled / model-generated) have HIGH energy.
-                At inference it is the critic: rank LLM drafts, reject bad plans.
-   LLM          the only token-level generator. Loaded from a local path or HF id;
-                frozen base + LoRA adapters (continual learning by isolation).
-
- WHY EBM *and* JEPA? JEPA gives a point prediction ẑ of the future latent;
- the EBM gives a *compatibility landscape* E(s, z) — it can say "both A and B
- are plausible" where a point predictor must average them. JEPA trains the
- representation; the EBM scores hypotheses in it. World model chains JEPA
- one-step predictions into multi-step rollouts that the EBM can evaluate.
-
- USAGE
- -----
-   pip install torch                          # toy mode
-   pip install transformers peft accelerate   # real LLM mode
-
-   python world_model_ensemble.py tiny                        # smoke test
-   python world_model_ensemble.py /models/llama-3.2-1b        # local weights
-   python world_model_ensemble.py Qwen/Qwen2.5-0.5B-Instruct
- """
-
- from __future__ import annotations
- import copy
- import math
- import sys
- import time
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
- torch.manual_seed(0)
-
- D_LAT = 96  # shared latent dimension (JEPA / world / energy)
- D_EMB = 64  # retrieval embedding dim
-
-
- # ============================================================================
- # 1. LLM backend — load small model from path / hub, or toy fallback
- #    (same contract as before: forward -> (logits, hidden), generate, adapters)
- # ============================================================================
- class TinyLLM(nn.Module):
-     VOCAB, D, L, H, T = 1000, 128, 2, 4, 32
-
-     def __init__(self):
-         super().__init__()
-         self.tok = nn.Embedding(self.VOCAB, self.D)
-         self.pos = nn.Embedding(self.T * 4, self.D)
-         layer = nn.TransformerEncoderLayer(self.D, self.H, 4 * self.D,
-                                            batch_first=True, norm_first=True)
-         self.blocks = nn.TransformerEncoder(layer, self.L)
-         self.head = nn.Linear(self.D, self.VOCAB, bias=False)
-         self.vocab_size, self.hidden_size = self.VOCAB, self.D
-
-     def forward(self, ids):
-         Tn = ids.size(1)
-         x = self.tok(ids) + self.pos(torch.arange(Tn, device=ids.device))
-         mask = torch.triu(torch.full((Tn, Tn), float("-inf"),
-                                      device=ids.device), 1)
-         h = self.blocks(x, mask=mask)
-         return self.head(h), h
-
-     @torch.no_grad()
-     def generate(self, ids, n_new=16, temperature=1.0):
-         for _ in range(n_new):
-             logits, _ = self(ids[:, -self.T:])
-             nxt = torch.multinomial(
-                 F.softmax(logits[:, -1] / temperature, -1), 1)
-             ids = torch.cat([ids, nxt], 1)
-         return ids
-
-     def trainable_parameters(self):
-         return self.parameters()
-
-     @property
-     def device(self):
-         return next(self.parameters()).device
-
-
- class HFLLM(nn.Module):
-     """Small model from a local path or HF id, frozen base + LoRA."""
-     def __init__(self, path, lora_r=16):
-         super().__init__()
-         from transformers import AutoModelForCausalLM, AutoTokenizer
-         from peft import LoraConfig, get_peft_model
-         self.tokenizer = AutoTokenizer.from_pretrained(path)
-         if self.tokenizer.pad_token is None:
-             self.tokenizer.pad_token = self.tokenizer.eos_token
-         base = AutoModelForCausalLM.from_pretrained(
-             path, torch_dtype=torch.bfloat16
-             if torch.cuda.is_available() else torch.float32,
-             device_map="auto" if torch.cuda.is_available() else None)
-         for p in base.parameters():
-             p.requires_grad_(False)
-         cfg = LoraConfig(r=lora_r, lora_alpha=2 * lora_r, lora_dropout=0.05,
-                          target_modules=["q_proj", "v_proj"],
-                          task_type="CAUSAL_LM")
-         self.model = get_peft_model(base, cfg)
-         self.vocab_size = self.model.config.vocab_size
-         self.hidden_size = self.model.config.hidden_size
-
-     def forward(self, ids):
-         out = self.model(input_ids=ids.to(self.device),
-                          output_hidden_states=True)
-         return out.logits, out.hidden_states[-1]
-
-     @torch.no_grad()
-     def generate(self, ids, n_new=32, temperature=0.8):
-         return self.model.generate(
-             input_ids=ids.to(self.device), max_new_tokens=n_new,
-             do_sample=True, temperature=temperature,
-             pad_token_id=self.tokenizer.pad_token_id)
-
-     def trainable_parameters(self):
-         return (p for p in self.model.parameters() if p.requires_grad)
-
-     @property
-     def device(self):
-         return next(self.model.parameters()).device
-
-
- def load_llm(spec: str):
-     return TinyLLM() if spec == "tiny" else HFLLM(spec)
-
-
- # ============================================================================
- # 2. Embedder (retrieval) — vocab-agnostic
- # ============================================================================
- class Embedder(nn.Module):
-     def __init__(self, vocab):
-         super().__init__()
-         self.tok = nn.Embedding(vocab, D_EMB)
-         self.gru = nn.GRU(D_EMB, D_EMB, batch_first=True, bidirectional=True)
-         self.out = nn.Linear(2 * D_EMB, D_EMB)
-
-     def forward(self, ids):
-         h, _ = self.gru(self.tok(ids))
-         return F.normalize(self.out(h.mean(1)), dim=-1)
-
-
- class VectorStore:
-     def __init__(self):
-         self.K, self.V = [], []
-
-     def add(self, k, v):
-         self.K.append(k.squeeze(0).detach().cpu()); self.V.append(v)
-
-     def search(self, q, k=1):
-         if not self.K:
-             return []
-         sims = (q.detach().cpu() @ torch.stack(self.K).t()).squeeze(0)
-         return [self.V[i] for i in sims.topk(min(k, len(self.K))).indices]
-
-
- # ============================================================================
- # 3. JEPA — owns the latent space (EMA target encoder, variance-regularized)
- # ============================================================================
- class SegEncoder(nn.Module):
-     def __init__(self, vocab):
-         super().__init__()
-         self.tok = nn.Embedding(vocab, D_LAT)
-         self.gru = nn.GRU(D_LAT, D_LAT, batch_first=True)
-         self.out = nn.Linear(D_LAT, D_LAT)
-
-     def forward(self, ids):
-         h, _ = self.gru(self.tok(ids))
-         return self.out(h.mean(1))  # [B, D_LAT]
-
-
- class JEPA(nn.Module):
-     def __init__(self, vocab, m=0.996):
-         super().__init__()
-         self.enc = SegEncoder(vocab)  # context/online enc
-         self.tgt = copy.deepcopy(self.enc)  # EMA target
-         for p in self.tgt.parameters():
-             p.requires_grad_(False)
-         self.pred = nn.Sequential(nn.Linear(D_LAT, 2 * D_LAT), nn.GELU(),
-                                   nn.Linear(2 * D_LAT, D_LAT))
-         self.m = m
-
-     @torch.no_grad()
-     def ema(self):
-         for pt, pc in zip(self.tgt.parameters(), self.enc.parameters()):
-             pt.mul_(self.m).add_(pc.detach(), alpha=1 - self.m)
-
-     def loss(self, seg_a, seg_b):
-         z_hat = self.pred(self.enc(seg_a))
-         with torch.no_grad():
-             z_tgt = self.tgt(seg_b)
-         var = F.relu(1.0 - z_hat.std(0)).mean()  # anti-collapse
-         return F.mse_loss(z_hat, z_tgt) + 0.5 * var
-
-     @torch.no_grad()
-     def encode(self, seg):  # target space
-         return self.tgt(seg)
-
-
- # ============================================================================
- # 4. WORLD MODEL — latent dynamics s_{t+1} = f(s_t, z_t), multi-step rollout
- #    Trained on SEQUENCES of segments: predict each next latent from state.
- # ============================================================================
- class WorldModel(nn.Module):
-     def __init__(self):
-         super().__init__()
-         self.cell = nn.GRUCell(D_LAT, D_LAT)  # state update
-         self.head = nn.Linear(D_LAT, D_LAT)  # state -> ẑ_{t+1}
-         self.s0 = nn.Parameter(torch.zeros(D_LAT))
-
-     def init_state(self, B, device):
-         return self.s0.unsqueeze(0).expand(B, -1).contiguous().to(device)
-
-     def step(self, s, z):
-         """Consume observed latent z_t, return (new state, prediction ẑ_{t+1})."""
-         s = self.cell(z, s)
-         return s, self.head(s)
-
-     def rollout(self, s, horizon):
-         """Imagine H future latents feeding its own predictions back in."""
-         preds = []
-         for _ in range(horizon):
-             z_hat = self.head(s)
-             preds.append(z_hat)
-             s = self.cell(z_hat, s)
-         return torch.stack(preds, 1), s  # [B, H, D_LAT]
-
-     def sequence_loss(self, z_seq):
-         """z_seq: [B, T, D_LAT] observed segment latents (teacher forcing)."""
-         B, T, _ = z_seq.shape
-         s = self.init_state(B, z_seq.device)
-         loss = 0.0
-         for t in range(T - 1):
-             s, z_hat = self.step(s, z_seq[:, t])
-             loss = loss + F.mse_loss(z_hat, z_seq[:, t + 1])
-         return loss / (T - 1)
-
-
- # ============================================================================
- # 5. ENERGY MODEL — E(state, candidate latent) ∈ R, low = plausible
- #    Trained with InfoNCE-style contrastive: positives = true next latent,
- #    negatives = (a) other batch items, (b) LLM-generated drafts (optional).
- # ============================================================================
- class EnergyModel(nn.Module):
-     def __init__(self):
-         super().__init__()
-         self.net = nn.Sequential(
-             nn.Linear(2 * D_LAT, 2 * D_LAT), nn.GELU(),
-             nn.Linear(2 * D_LAT, D_LAT), nn.GELU(),
-             nn.Linear(D_LAT, 1))
-
-     def energy(self, s, z):
-         """s: [B, D_LAT] context state; z: [B, D_LAT] candidate. -> [B]"""
-         return self.net(torch.cat([s, z], -1)).squeeze(-1)
-
-     def contrastive_loss(self, s, z_pos, z_negs=None, tau=0.5):
-         """Softmax over energies: true continuation must be the argmin.
-         In-batch negatives: every other item's z_pos is a negative for s_i."""
-         B = s.size(0)
-         # pairwise energies: E(s_i, z_j) for all i, j
-         s_rep = s.unsqueeze(1).expand(B, B, D_LAT).reshape(B * B, D_LAT)
-         z_rep = z_pos.unsqueeze(0).expand(B, B, D_LAT).reshape(B * B, D_LAT)
-         E = self.energy(s_rep, z_rep).view(B, B)  # [B, B]
-         if z_negs is not None:  # extra hard negatives
-             En = self.energy(
-                 s.repeat_interleave(z_negs.size(1), 0),
-                 z_negs.reshape(-1, D_LAT)).view(B, -1)
-             E = torch.cat([E, En], dim=1)
-         labels = torch.arange(B, device=s.device)
-         return F.cross_entropy(-E / tau, labels)  # low E ⇒ high logit
-
-     @torch.no_grad()
-     def rank(self, s, candidates):
-         """candidates: [N, D_LAT]; returns energies [N] (lower = better)."""
-         return self.energy(s.expand(candidates.size(0), -1), candidates)
-
-
- # ============================================================================
- # 6. Bridge — LLM hidden states -> shared latent space (alignment)
- # ============================================================================
- class Bridge(nn.Module):
-     def __init__(self, d_hidden):
-         super().__init__()
-         self.proj = nn.Sequential(nn.Linear(d_hidden, D_LAT), nn.GELU(),
-                                   nn.Linear(D_LAT, D_LAT))
-
-     def forward(self, h):  # [B,T,H] -> [B,D_LAT]
-         return self.proj(h.float().mean(1))
-
-     def info_nce(self, a, b, tau=0.07):
-         a, b = F.normalize(a, -1), F.normalize(b, -1)
-         logits = a @ b.t() / tau
-         y = torch.arange(a.size(0), device=a.device)
-         return 0.5 * (F.cross_entropy(logits, y) +
-                       F.cross_entropy(logits.t(), y))
-
-
- # ============================================================================
- # 7. THE ENSEMBLE — wiring + inference (plan -> generate -> energy-rank)
- # ============================================================================
- class WorldEnsemble(nn.Module):
-     def __init__(self, llm_spec="tiny"):
-         super().__init__()
-         self.llm = load_llm(llm_spec)
-         V, H = self.llm.vocab_size, self.llm.hidden_size
-         self.emb = Embedder(V)
-         self.jepa = JEPA(V)
-         self.world = WorldModel()
-         self.energy = EnergyModel()
-         self.bridge = Bridge(H)
-         self.store = VectorStore()
-
-     # ------------------------- inference ---------------------------------
-     @torch.no_grad()
-     def world_state(self, segments):
-         """Fold a list of [1,T] segment tensors into a latent state."""
-         s = self.world.init_state(1, "cpu")
-         for seg in segments:
-             z = self.jepa.encode(seg.cpu())
-             s, _ = self.world.step(s, z)
-         return s
-
-     @torch.no_grad()
-     def answer(self, query_ids, n_new=24, n_drafts=6, horizon=3):
-         """retrieve -> build world state -> imagine -> generate N -> argmin E."""
-         q_emb = self.emb(query_ids.cpu())
-         mems = self.store.search(q_emb, k=1)
-         segments = (mems + [query_ids.cpu()]) if mems else [query_ids.cpu()]
-         ctx = torch.cat(segments, dim=1)
-
-         s = self.world_state(segments)  # latent context state
-         plan, _ = self.world.rollout(s, horizon)  # imagined future
-         # (plan is available for planning losses / steering; logged here)
-
-         drafts, lat = [], []
-         for _ in range(n_drafts):
-             out = self.llm.generate(ctx.to(self.llm.device), n_new=n_new,
-                                     temperature=0.9)
-             new = out[:, ctx.size(1):].cpu()
-             drafts.append(new)
-             lat.append(self.jepa.encode(new))
-         Z = torch.cat(lat, 0)  # [N, D_LAT]
-         E = self.energy.rank(s, Z)  # lower = better
-         best = E.argmin().item()
-         return {"output": drafts[best], "energy": E[best].item(),
-                 "all_energies": E.tolist(),
-                 "plan_alignment": F.cosine_similarity(
-                     plan[:, 0], Z[best:best + 1]).item()}
-
-     def memorize(self, ids):
-         self.store.add(self.emb(ids.cpu()), ids.cpu())
-
383
- # ------------------------- training ----------------------------------
384
- def train_step(self, seg_seq, opt, w=dict(lm=1.0, jepa=1.0, world=1.0,
385
- ebm=1.0, bridge=0.1),
386
- hard_negs=True):
387
- """seg_seq: [B, T_seg, L] — B documents, each split into T_seg
388
- consecutive segments of length L (same tokenizer as the LLM)."""
389
- B, T, L = seg_seq.shape
390
- dev = self.llm.device
391
-
392
- # (1) LM loss on the first segment (or all, batched, if budget allows)
393
- flat = seg_seq[:, 0].to(dev)
394
- logits, hidden = self.llm(flat)
395
- lm = F.cross_entropy(
396
- logits[:, :-1].reshape(-1, self.llm.vocab_size).float(),
397
- flat[:, 1:].reshape(-1))
398
-
399
- # (2) JEPA: adjacent segment pairs
400
- jepa = self.jepa.loss(seg_seq[:, 0], seg_seq[:, 1])
401
-
402
- # (3) World model: sequence of latents (online encoder, grads flow)
403
- z_seq = torch.stack([self.jepa.enc(seg_seq[:, t])
404
- for t in range(T)], 1) # [B, T, D_LAT]
405
- world = self.world.sequence_loss(z_seq)
406
-
407
- # (4) Energy: state after t=0 must give low E to true z_1,
408
- # high E to in-batch + (optionally) LLM-generated negatives
409
- s = self.world.init_state(B, z_seq.device)
410
- s, _ = self.world.step(s, z_seq[:, 0].detach())
411
- z_pos = z_seq[:, 1].detach()
412
- z_negs = None
413
- if hard_negs:
414
- with torch.no_grad(): # model drafts as negs
415
- gen = self.llm.generate(seg_seq[:, 0].to(dev), n_new=L)
416
- gen_new = gen[:, seg_seq.size(2):].cpu()
417
- z_negs = self.jepa.encode(gen_new).unsqueeze(1) # [B,1,D]
418
- ebm = self.energy.contrastive_loss(s, z_pos, z_negs)
419
-
420
- # (5) Bridge: align LLM hidden(seg0) with JEPA latent(seg0)
421
- bridge = self.bridge.info_nce(
422
- self.bridge(hidden.cpu() if hidden.device.type != "cpu" else hidden),
423
- self.jepa.enc(seg_seq[:, 0]).detach())
424
-
425
- loss = (w["lm"] * lm.cpu() + w["jepa"] * jepa + w["world"] * world
426
- + w["ebm"] * ebm + w["bridge"] * bridge)
427
- opt.zero_grad(); loss.backward(); opt.step()
428
- self.jepa.ema()
429
- return dict(lm=lm.item(), jepa=jepa.item(), world=world.item(),
430
- ebm=ebm.item(), bridge=bridge.item())
431
-
432
- def make_optimizer(self, lr_lora=2e-4, lr_aux=1e-3):
433
- return torch.optim.AdamW([
434
- {"params": list(self.llm.trainable_parameters()), "lr": lr_lora},
435
- {"params": list(self.jepa.enc.parameters())
436
- + list(self.jepa.pred.parameters()), "lr": lr_aux},
437
- {"params": list(self.world.parameters()), "lr": lr_aux},
438
- {"params": list(self.energy.parameters()), "lr": lr_aux},
439
- {"params": list(self.bridge.parameters())
440
- + list(self.emb.parameters()), "lr": lr_aux}])
441
-
442
-
443
- # ============================================================================
444
- # 8. Data helpers + demo
445
- # ============================================================================
446
- def toy_segment_sequences(B=8, T=4, L=24, vocab=1000):
447
- """Random docs split into T consecutive segments. Replace with real
448
- corpus: tokenize each document, reshape into [T, L] windows."""
449
- return torch.randint(0, vocab, (B, T, L))
450
-
451
-
452
- def hf_segment_sequences(llm: HFLLM, texts, T=4, L=64):
453
- seqs = []
454
- for t in texts:
455
- ids = llm.tokenizer(t, return_tensors="pt").input_ids[0]
456
- n = (len(ids) // (T * L)) * T * L
457
- if n:
458
- seqs.append(ids[:n].view(-1, T, L))
459
- if not seqs:
460
- raise ValueError("corpus too short for T*L window")
461
- return torch.cat(seqs, 0)
462
-
463
-
464
- def demo(spec="tiny", steps=60):
465
- ens = WorldEnsemble(spec)
466
- opt = ens.make_optimizer()
467
-
468
- if spec == "tiny":
469
- get_batch = lambda: toy_segment_sequences(vocab=ens.llm.vocab_size)
470
- else:
471
- corpus = ["Replace with your real documents. " * 200]
472
- data = hf_segment_sequences(ens.llm, corpus, T=4, L=32)
473
- get_batch = lambda: data[torch.randperm(len(data))[:4]]
474
- steps = min(steps, 10)
475
-
476
- t0 = time.time()
477
- for s in range(steps):
478
- logs = ens.train_step(get_batch(), opt,
479
- hard_negs=(s > steps // 2)) # warmup w/o negs
480
- if s % 10 == 0:
481
- print(f"step {s:3d} | " +
482
- " | ".join(f"{k} {v:.3f}" for k, v in logs.items()))
483
- print(f"trained {steps} steps in {time.time()-t0:.1f}s")
484
-
485
- # memory + inference
486
- for _ in range(4):
487
- if spec == "tiny":
488
- ens.memorize(torch.randint(0, ens.llm.vocab_size, (1, 24)))
489
- q = (torch.randint(0, ens.llm.vocab_size, (1, 12)) if spec == "tiny"
490
- else ens.llm.tokenizer("What is this document about?",
491
- return_tensors="pt").input_ids)
492
- res = ens.answer(q, n_drafts=6, horizon=3)
493
- print(f"\nselected draft energy={res['energy']:.3f} "
494
- f"(all: {[f'{e:.2f}' for e in res['all_energies']]})")
495
- print(f"plan↔output alignment: {res['plan_alignment']:.3f}")
496
-
497
-
498
- if __name__ == "__main__":
499
- demo(sys.argv[1] if len(sys.argv) > 1 else "tiny")
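Note on the removed `Bridge.info_nce`: the symmetric InfoNCE objective it computes can be sketched without PyTorch, which makes the arithmetic easier to check by hand. This is a minimal, dependency-free re-derivation, not the deleted module itself; the helper names `l2_normalize` and `cross_entropy` are introduced here for illustration only, and inputs are assumed to be plain lists of equal-length float vectors where row `i` of `a` pairs with row `i` of `b`.

```python
import math


def l2_normalize(v, eps=1e-12):
    """Scale a vector to unit L2 norm (hypothetical helper)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / max(norm, eps) for x in v]


def cross_entropy(logits, targets):
    """Mean cross-entropy over rows, using a stable log-sum-exp."""
    total = 0.0
    for row, y in zip(logits, targets):
        m = max(row)
        lse = m + math.log(sum(math.exp(x - m) for x in row))
        total += lse - row[y]
    return total / len(logits)


def info_nce(a, b, tau=0.07):
    """Symmetric InfoNCE: row i of `a` should match row i of `b`."""
    a = [l2_normalize(v) for v in a]
    b = [l2_normalize(v) for v in b]
    # Cosine similarities scaled by the temperature tau.
    logits = [[sum(x * y for x, y in zip(u, v)) / tau for v in b] for u in a]
    logits_t = [list(col) for col in zip(*logits)]
    y = list(range(len(a)))
    # Average the a->b and b->a classification losses.
    return 0.5 * (cross_entropy(logits, y) + cross_entropy(logits_t, y))


pairs = [[1.0, 0.0], [0.0, 1.0]]
aligned = info_nce(pairs, pairs)
swapped = info_nce(pairs, [[0.0, 1.0], [1.0, 0.0]])
print(aligned < swapped)
```

Matched pairs give a loss near zero, while swapping the targets drives the loss up to roughly `1 / tau`, which is the behaviour the bridge term relies on to align the two embedding spaces.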
 
research/eval_harness.py DELETED
@@ -1,6 +0,0 @@
1
- """Deprecated shim — use `ensemble.eval.jepa_harness` instead."""
2
-
3
- from ensemble.eval.jepa_harness import run, parse_args
4
-
5
- if __name__ == "__main__":
6
- run(parse_args())
 
research/evals/USAGE.md CHANGED
@@ -197,12 +197,6 @@ uv run --package slm-evals slm-lm-eval \
197
  --model openbmb/MiniCPM5-1B \
198
  --adapter ./models/finetuned/minicpm5-1b-lora \
199
  --experiment-name minicpm5-1b-lora__manual
200
-
201
- # Ensemble checkpoint (manifest.json auto-detected)
202
- uv run --package slm-evals slm-lm-eval \
203
- --config research/evals/configs/lm_eval_smoke.yaml \
204
- --model ./models/ensemble/jepa-lesson-pretrain \
205
- --experiment-name ensemble-jepa__lm-eval
206
  ```
207
 
208
  ### Compare baseline vs candidate
@@ -259,8 +253,8 @@ slm-lm-eval [OPTIONS]
259
  --list-tasks-all Full lm-eval task list
260
  --profile NAME Shorthand for --config (reasoning, code, smoke, …)
261
  --config PATH YAML config (tasks, seed, limit, …)
262
- --preset KEY models.yaml preset (base, LoRA, merged, ensemble)
263
- --model PATH HF Hub id, merged dir, or ensemble checkpoint
264
  --adapter PATH LoRA adapter (alternative to preset adapter_path)
265
  --tasks NAMES Override task list
266
  --num-fewshot N
@@ -285,12 +279,6 @@ Each run writes to `<output_dir>/<experiment_name>/`:
285
  | `run_meta.json` | Preset, base model, adapter, tasks, seed |
286
  | `comparison.md` | Delta table (when `--compare-to` set) |
287
 
288
- ### Ensemble backend notes
289
-
290
- - **`ensemble-lm`** loads JEPA checkpoints via `manifest.json`.
291
- - **`generate_until`** tasks (e.g. `gsm8k`) use the full ensemble stack (`generate_text`).
292
- - **`loglikelihood`** tasks (e.g. `arc_easy`, `hellaswag`) score the underlying HF LLM head (adapter 0), not the JEPA selector. Use [`jepa_harness`](../ensemble/README.md) to measure selector value on domain QA.
293
-
294
  ### PEFT / LoRA
295
 
296
  lm-eval expects `pretrained=<base>,peft=<adapter>`. The preset resolver handles this for keys like `minicpm5-1b-lesson-lora`. Merged checkpoints use `--preset minicpm5-1b-lesson-merged` or `--model ./models/finetuned/...-merged`.
 
197
  --model openbmb/MiniCPM5-1B \
198
  --adapter ./models/finetuned/minicpm5-1b-lora \
199
  --experiment-name minicpm5-1b-lora__manual
200
  ```
201
 
202
  ### Compare baseline vs candidate
 
253
  --list-tasks-all Full lm-eval task list
254
  --profile NAME Shorthand for --config (reasoning, code, smoke, …)
255
  --config PATH YAML config (tasks, seed, limit, …)
256
+ --preset KEY models.yaml preset (base, LoRA, merged)
257
+ --model PATH HF Hub id or merged checkpoint dir
258
  --adapter PATH LoRA adapter (alternative to preset adapter_path)
259
  --tasks NAMES Override task list
260
  --num-fewshot N
 
279
  | `run_meta.json` | Preset, base model, adapter, tasks, seed |
280
  | `comparison.md` | Delta table (when `--compare-to` set) |
281
 
282
  ### PEFT / LoRA
283
 
284
  lm-eval expects `pretrained=<base>,peft=<adapter>`. The preset resolver handles this for keys like `minicpm5-1b-lesson-lora`. Merged checkpoints use `--preset minicpm5-1b-lesson-merged` or `--model ./models/finetuned/...-merged`.
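Note on the PEFT / LoRA paragraph: the `pretrained=<base>,peft=<adapter>` string can be assembled as below. This is a minimal sketch; `build_model_args` is a hypothetical helper introduced for illustration and is not the repository's preset resolver, whose implementation is not shown in this diff.

```python
def build_model_args(base, adapter=None):
    """Return an lm-eval style model_args string (hypothetical helper).

    base:    HF Hub id or local directory of the base (or merged) model.
    adapter: optional LoRA adapter directory; omitted for merged checkpoints.
    """
    parts = [f"pretrained={base}"]
    if adapter:
        parts.append(f"peft={adapter}")
    return ",".join(parts)


# Base model plus LoRA adapter, using the paths from the example above.
print(build_model_args("openbmb/MiniCPM5-1B",
                       "./models/finetuned/minicpm5-1b-lora"))
# Base model alone: no peft= entry is emitted.
print(build_model_args("openbmb/MiniCPM5-1B"))
```

A merged checkpoint already contains the adapter weights, so under this sketch it is passed as `base` with no `adapter`.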
research/evals/configs/ensemble_jepa_lesson.yaml DELETED
@@ -1,24 +0,0 @@
1
- # JEPA ensemble checkpoint (models/ensemble/jepa-lesson-pretrain)
2
- # Pretrain: uv run --package ensemble ensemble-pretrain --llm Qwen/Qwen2.5-0.5B-Instruct
3
- # Compare baseline: copy this file, set model_path to the base Hub id and model_type: hf
4
-
5
- model_path: "./models/ensemble/jepa-lesson-pretrain"
6
- model_type: "ensemble"
7
- device: "auto"
8
- dtype: "bfloat16"
9
-
10
- max_new_tokens: 512
11
- temperature: 0.0
12
-
13
- experiment_name: "jepa-ensemble-lesson__bfcl-tau__v1"
14
- output_dir: "results"
15
-
16
- benchmarks:
17
- - bfcl
18
- - tau_bench
19
-
20
- max_samples: 20
21
-
22
- benchmark_overrides:
23
- tau_bench:
24
- use_llm_user: false
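Note on the removed config: its header comment describes deriving a baseline by copying the file and changing `model_path` and `model_type`. A minimal sketch of that step on an in-memory copy of the settings follows. The function `baseline_from` and the baseline experiment name are hypothetical, and using `Qwen/Qwen2.5-0.5B-Instruct` as the base Hub id is an assumption taken from the pretrain command in the comment.

```python
import copy

# The removed YAML, transcribed as a plain dict for illustration.
ENSEMBLE_CFG = {
    "model_path": "./models/ensemble/jepa-lesson-pretrain",
    "model_type": "ensemble",
    "device": "auto",
    "dtype": "bfloat16",
    "max_new_tokens": 512,
    "temperature": 0.0,
    "experiment_name": "jepa-ensemble-lesson__bfcl-tau__v1",
    "output_dir": "results",
    "benchmarks": ["bfcl", "tau_bench"],
    "max_samples": 20,
    "benchmark_overrides": {"tau_bench": {"use_llm_user": False}},
}


def baseline_from(cfg, base_hub_id, experiment_name):
    """Copy cfg and point it at a plain HF model (hypothetical helper)."""
    out = copy.deepcopy(cfg)  # deep copy so nested overrides are not shared
    out["model_path"] = base_hub_id
    out["model_type"] = "hf"
    out["experiment_name"] = experiment_name
    return out


baseline = baseline_from(
    ENSEMBLE_CFG,
    "Qwen/Qwen2.5-0.5B-Instruct",   # assumed base Hub id
    "baseline__bfcl-tau__v1",       # hypothetical experiment name
)
print(baseline["model_type"], baseline["benchmarks"])
```

Keeping every other key identical (benchmarks, `max_samples`, temperature) is what makes the two runs comparable.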