Spaces:
Running on Zero
Running on Zero
MSG commited on
Commit ·
871f869
1
Parent(s): 8c6b423
Feat/last sprint (#12)
Browse files* hf docker build
* gradio sdk deploy for zerogpu mode
* gradio structure wip
* server gpu and tasks
* fix sdk mode
* wip studio fix gradio
* fix studio
* index fix ui
* ui coach
* ui voice
* index ui
* cleaning stuff
* clean
* clean stuff experiment
* ui css app
* fix css
This view is limited to 50 files because it contains too many changes. See raw diff
- .cursor/plans/gradio_sdk_deploy_58daaf6e.plan.md +268 -0
- .cursor/plans/hf_space_publish_e8a57bab.plan.md +208 -0
- .env.example +0 -10
- .gitignore +1 -0
- Dockerfile +2 -0
- README.md +13 -14
- USAGE.md +86 -68
- app.py +6 -0
- apps/gradio-space/src/gradio_space/model_loading.py +3 -0
- apps/gradio-space/src/gradio_space/research_helpers.py +2 -0
- apps/gradio-space/src/gradio_space/server.py +3 -1
- apps/gradio-space/src/gradio_space/spaces_runtime.py +37 -0
- apps/gradio-space/src/gradio_space/tabs/echo_coach.py +2 -0
- apps/gradio-space/src/gradio_space/tabs/education_pptx.py +3 -0
- apps/gradio-space/src/gradio_space/tabs/research_mind.py +4 -0
- apps/gradio-space/src/gradio_space/tabs/teacher_voice.py +3 -0
- apps/gradio-space/static/studio/index.html +111 -82
- apps/gradio-space/static/studio/studio.css +283 -3
- models.yaml +0 -6
- packages.txt +2 -0
- pyproject.toml +0 -8
- requirements.txt +32 -0
- research/README.md +6 -10
- research/USAGE.md +8 -99
- research/docs/overview.md +11 -49
- research/ensemble/README.md +0 -113
- research/ensemble/pyproject.toml +0 -16
- research/ensemble/scripts/smoke.sh +0 -35
- research/ensemble/src/ensemble/__init__.py +0 -15
- research/ensemble/src/ensemble/backends.py +0 -418
- research/ensemble/src/ensemble/bridge.py +0 -28
- research/ensemble/src/ensemble/checkpoint.py +0 -149
- research/ensemble/src/ensemble/config.py +0 -163
- research/ensemble/src/ensemble/energy.py +0 -45
- research/ensemble/src/ensemble/eval/__init__.py +0 -1
- research/ensemble/src/ensemble/eval/jepa_harness.py +0 -266
- research/ensemble/src/ensemble/eval/metrics.py +0 -42
- research/ensemble/src/ensemble/eval/world_harness.py +0 -174
- research/ensemble/src/ensemble/eval_harness.py +0 -309
- research/ensemble/src/ensemble/jepa.py +0 -75
- research/ensemble/src/ensemble/jepa_ensemble.py +0 -232
- research/ensemble/src/ensemble/llm_emb_jepa_ensemble_pluggable.py +0 -507
- research/ensemble/src/ensemble/memory.py +0 -46
- research/ensemble/src/ensemble/pretrain.py +0 -198
- research/ensemble/src/ensemble/world_ensemble.py +0 -228
- research/ensemble/src/ensemble/world_model.py +0 -40
- research/ensemble/src/ensemble/world_model_ensemble.py +0 -499
- research/eval_harness.py +0 -6
- research/evals/USAGE.md +2 -14
- research/evals/configs/ensemble_jepa_lesson.yaml +0 -24
.cursor/plans/gradio_sdk_deploy_58daaf6e.plan.md
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
name: Gradio SDK Deploy
|
| 3 |
+
overview: "Add Gradio SDK deployment files on main alongside the existing Dockerfile, switch README to `sdk: gradio` for ZeroGPU Spaces, and add `@spaces.GPU` wrappers on LLM entry points so the full Studio + Classic app runs on HF without Docker."
|
| 4 |
+
todos:
|
| 5 |
+
- id: root-gradio-files
|
| 6 |
+
content: Add root app.py, requirements.txt, packages.txt with editable workspace installs and Debian deps
|
| 7 |
+
status: completed
|
| 8 |
+
- id: readme-gradio-sdk
|
| 9 |
+
content: "Fix README YAML frontmatter and switch to sdk: gradio (sdk_version 6.16.0, app_file: app.py)"
|
| 10 |
+
status: completed
|
| 11 |
+
- id: spaces-runtime
|
| 12 |
+
content: Add gradio_space/spaces_runtime.py with gpu_task decorator and is_hf_gradio_runtime()
|
| 13 |
+
status: completed
|
| 14 |
+
- id: zerogpu-decorators
|
| 15 |
+
content: Apply @gpu_task to LLM entry points in model_loading, research_helpers, and tab handlers; skip preload on HF Gradio runtime in server.py
|
| 16 |
+
status: completed
|
| 17 |
+
- id: usage-docs
|
| 18 |
+
content: Update USAGE.md with Gradio SDK + ZeroGPU deploy steps; demote Docker section to later phase
|
| 19 |
+
status: completed
|
| 20 |
+
- id: local-smoke
|
| 21 |
+
content: Validate pip install + python app.py locally before pushing to HF Space
|
| 22 |
+
status: completed
|
| 23 |
+
- id: hf-space-create
|
| 24 |
+
content: Create Gradio Space under build-small-hackathon with ZeroGPU hardware and env vars; verify Studio + Classic smoke tests
|
| 25 |
+
status: cancelled
|
| 26 |
+
isProject: false
|
| 27 |
+
---
|
| 28 |
+
|
| 29 |
+
# Gradio SDK + ZeroGPU deployment (same branch as Docker)
|
| 30 |
+
|
| 31 |
+
## Goal
|
| 32 |
+
|
| 33 |
+
Ship the **full app** (Studio at `/`, Classic at `/classic`, all tabs) via **Gradio SDK** on Hugging Face with **ZeroGPU**, while keeping [`Dockerfile`](Dockerfile) on `main` untouched for a later Docker Space phase.
|
| 34 |
+
|
| 35 |
+
## Same-branch constraint (important)
|
| 36 |
+
|
| 37 |
+
HF reads **one** `sdk:` value from root [`README.md`](README.md). Both deploy paths can live on the same branch as files, but **only one SDK is active per branch at a time**:
|
| 38 |
+
|
| 39 |
+
| Files on `main` | Active when README says |
|
| 40 |
+
|-----------------|-------------------------|
|
| 41 |
+
| `app.py`, `requirements.txt`, `packages.txt` | `sdk: gradio` |
|
| 42 |
+
| [`Dockerfile`](Dockerfile) | `sdk: docker` + `app_port: 7860` |
|
| 43 |
+
|
| 44 |
+
**Phase 1 (now):** set `sdk: gradio` — Gradio Space builds from `app.py`.
|
| 45 |
+
**Phase 2 (later):** flip README to `sdk: docker` for Docker Space, or use a **second HF Space on a second branch** if you need both live at once.
|
| 46 |
+
|
| 47 |
+
```mermaid
|
| 48 |
+
flowchart TB
|
| 49 |
+
subgraph repo [main branch]
|
| 50 |
+
AppPy[app.py]
|
| 51 |
+
ReqTxt[requirements.txt]
|
| 52 |
+
DockerFile[Dockerfile]
|
| 53 |
+
Shared[apps/gradio-space + libs + skills]
|
| 54 |
+
end
|
| 55 |
+
subgraph phase1 [Phase 1 active]
|
| 56 |
+
ReadmeG[sdk: gradio in README]
|
| 57 |
+
HFGradio[HF Gradio SDK build]
|
| 58 |
+
ZeroGPU[ZeroGPU hardware]
|
| 59 |
+
end
|
| 60 |
+
subgraph phase2 [Phase 2 later]
|
| 61 |
+
ReadmeD[sdk: docker in README]
|
| 62 |
+
HFDocker[HF Docker build]
|
| 63 |
+
GPUBasic[GPU Basic hardware]
|
| 64 |
+
end
|
| 65 |
+
Shared --> AppPy
|
| 66 |
+
Shared --> DockerFile
|
| 67 |
+
ReadmeG --> HFGradio --> ZeroGPU
|
| 68 |
+
ReadmeD --> HFDocker --> GPUBasic
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
---
|
| 72 |
+
|
| 73 |
+
## Phase 1 — Add Gradio SDK root files
|
| 74 |
+
|
| 75 |
+
### 1. Root [`app.py`](app.py)
|
| 76 |
+
|
| 77 |
+
Thin entry point that reuses the existing server (no UI rewrite):
|
| 78 |
+
|
| 79 |
+
```python
|
| 80 |
+
from gradio_space.server import main
|
| 81 |
+
|
| 82 |
+
if __name__ == "__main__":
|
| 83 |
+
main()
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
HF Gradio SDK executes `app.py`; [`server.py`](apps/gradio-space/src/gradio_space/server.py) already calls `server.launch()` on port 7860 with Studio + `/classic`.
|
| 87 |
+
|
| 88 |
+
### 2. Root [`requirements.txt`](requirements.txt)
|
| 89 |
+
|
| 90 |
+
Pip-install workspace packages via editable paths (HF clones the full repo):
|
| 91 |
+
|
| 92 |
+
```text
|
| 93 |
+
-e ./libs/inference
|
| 94 |
+
-e ./libs/researchmind
|
| 95 |
+
-e ./libs/agent
|
| 96 |
+
-e ./libs/echocoach[piper,whisper]
|
| 97 |
+
-e ./apps/gradio-space
|
| 98 |
+
# plus transitive deps from libs/*/pyproject.toml (torch, transformers, sentence-transformers, python-pptx, etc.)
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
Rules (per [HF Spaces dependencies](https://huggingface.co/docs/hub/spaces-dependencies)):
|
| 102 |
+
|
| 103 |
+
- **Do not pin** `gradio`, `spaces`, or `huggingface_hub` — HF preinstalls them.
|
| 104 |
+
- Pin heavy libs that matter for reproducibility: `torch`, `transformers`, `accelerate`, `sentence-transformers`, etc.
|
| 105 |
+
- Keep `llama-cpp-python` for preset parity (HF image has `cmake`; build may be slow).
|
| 106 |
+
|
| 107 |
+
Optional: add [`scripts/sync-requirements.sh`](scripts/sync-requirements.sh) later to regenerate from `pyproject.toml` files — not required for v1.
|
| 108 |
+
|
| 109 |
+
### 3. Root [`packages.txt`](packages.txt)
|
| 110 |
+
|
| 111 |
+
Debian deps beyond HF defaults (mirror [`Dockerfile`](Dockerfile) apt lines):
|
| 112 |
+
|
| 113 |
+
```text
|
| 114 |
+
ffmpeg
|
| 115 |
+
libsndfile1
|
| 116 |
+
```
|
| 117 |
+
|
| 118 |
+
### 4. Fix + switch README frontmatter
|
| 119 |
+
|
| 120 |
+
Current [`README.md`](README.md) has a blank line after `---` and still declares Docker. Update to:
|
| 121 |
+
|
| 122 |
+
```yaml
|
| 123 |
+
---
|
| 124 |
+
title: Lesson Agent
|
| 125 |
+
emoji: 📚
|
| 126 |
+
colorFrom: blue
|
| 127 |
+
colorTo: green
|
| 128 |
+
sdk: gradio
|
| 129 |
+
sdk_version: "6.16.0"
|
| 130 |
+
app_file: app.py
|
| 131 |
+
python_version: "3.12"
|
| 132 |
+
pinned: false
|
| 133 |
+
license: apache-2.0
|
| 134 |
+
---
|
| 135 |
+
```
|
| 136 |
+
|
| 137 |
+
Remove `app_port` (Docker-only). Keep [`Dockerfile`](Dockerfile) in repo for phase 2.
|
| 138 |
+
|
| 139 |
+
---
|
| 140 |
+
|
| 141 |
+
## Phase 2 — ZeroGPU runtime hooks
|
| 142 |
+
|
| 143 |
+
ZeroGPU requires all CUDA work inside `@spaces.GPU`. The decorator is a **no-op** locally and on dedicated GPU Spaces, so it is safe to apply everywhere.
|
| 144 |
+
|
| 145 |
+
### New module: [`apps/gradio-space/src/gradio_space/spaces_runtime.py`](apps/gradio-space/src/gradio_space/spaces_runtime.py)
|
| 146 |
+
|
| 147 |
+
```python
|
| 148 |
+
def gpu_task(*, duration: int = 180, size: str = "large"):
|
| 149 |
+
"""Apply @spaces.GPU when the HF spaces runtime is present."""
|
| 150 |
+
...
|
| 151 |
+
|
| 152 |
+
def is_hf_gradio_runtime() -> bool:
|
| 153 |
+
"""True on HF Gradio SDK Spaces (skip startup model preload)."""
|
| 154 |
+
...
|
| 155 |
+
```
|
| 156 |
+
|
| 157 |
+
Use `duration=180`–`300` for agent/slide flows; `duration=60` for simple chat.
|
| 158 |
+
|
| 159 |
+
### Skip startup preload on HF Gradio runtime
|
| 160 |
+
|
| 161 |
+
[`server.py`](apps/gradio-space/src/gradio_space/server.py) currently calls `preload_active_model()` before launch — this fails on ZeroGPU (no GPU at process start):
|
| 162 |
+
|
| 163 |
+
```69:69:apps/gradio-space/src/gradio_space/server.py
|
| 164 |
+
preload_active_model()
|
| 165 |
+
```
|
| 166 |
+
|
| 167 |
+
Change to:
|
| 168 |
+
|
| 169 |
+
```python
|
| 170 |
+
if not is_hf_gradio_runtime():
|
| 171 |
+
preload_active_model()
|
| 172 |
+
```
|
| 173 |
+
|
| 174 |
+
First user request lazy-loads inside a `@spaces.GPU`-decorated handler.
|
| 175 |
+
|
| 176 |
+
### Decorate LLM entry points (not every `backend.chat` call)
|
| 177 |
+
|
| 178 |
+
Wrap **top-level handlers** so multi-step agent loops run inside one GPU allocation:
|
| 179 |
+
|
| 180 |
+
| Module | Functions to decorate |
|
| 181 |
+
|--------|----------------------|
|
| 182 |
+
| [`model_loading.py`](apps/gradio-space/src/gradio_space/model_loading.py) | `chat`, `reload_model` |
|
| 183 |
+
| [`research_helpers.py`](apps/gradio-space/src/gradio_space/research_helpers.py) | `run_research_question`, `rag_aware_chat` |
|
| 184 |
+
| [`tabs/education_pptx.py`](apps/gradio-space/src/gradio_space/tabs/education_pptx.py) | `generate_lesson_slides`, `discover_lesson_sources` |
|
| 185 |
+
| [`tabs/research_mind.py`](apps/gradio-space/src/gradio_space/tabs/research_mind.py) | `discover_sources`, `ask_question`, `auto_search_ingest` |
|
| 186 |
+
| [`tabs/echo_coach.py`](apps/gradio-space/src/gradio_space/tabs/echo_coach.py) | `analyze_pitch` |
|
| 187 |
+
| [`tabs/teacher_voice.py`](apps/gradio-space/src/gradio_space/tabs/teacher_voice.py) | text/audio turn handlers |
|
| 188 |
+
|
| 189 |
+
Studio APIs in [`api/studio.py`](apps/gradio-space/src/gradio_space/api/studio.py) call these helpers — decorating the tab/helper layer avoids duplicating decorators on ~20 API wrappers.
|
| 190 |
+
|
| 191 |
+
**Generator caveat:** `generate_lesson_slides` uses `yield` for progress. If ZeroGPU rejects generator handlers, extract GPU work into a plain `@gpu_task` function and keep the outer generator for UI progress only (test on Space Logs after first deploy).
|
| 192 |
+
|
| 193 |
+
**Embeddings (ResearchMind ingest):** sentence-transformers can stay on CPU for v1; only LLM paths need `@spaces.GPU` initially.
|
| 194 |
+
|
| 195 |
+
---
|
| 196 |
+
|
| 197 |
+
## Phase 3 — Space configuration
|
| 198 |
+
|
| 199 |
+
Create Space under [build-small-hackathon](https://huggingface.co/build-small-hackathon):
|
| 200 |
+
|
| 201 |
+
| Setting | Value |
|
| 202 |
+
|---------|-------|
|
| 203 |
+
| SDK | **Gradio** (Blank template) |
|
| 204 |
+
| Hardware | **ZeroGPU** (creator needs PRO/Team) |
|
| 205 |
+
| Repo | GitHub `main` (or push to Space git) |
|
| 206 |
+
|
| 207 |
+
**Environment variables** (Settings → Variables):
|
| 208 |
+
|
| 209 |
+
| Variable | Value |
|
| 210 |
+
|----------|-------|
|
| 211 |
+
| `ACTIVE_MODEL` | `minicpm5-1b` |
|
| 212 |
+
| `ALLOW_MODEL_SWITCH` | `false` |
|
| 213 |
+
| `RESEARCHMIND_DATA_DIR` | `/tmp/researchmind` |
|
| 214 |
+
|
| 215 |
+
Default preset in [`models.yaml`](models.yaml) is already `minicpm5-1b` (transformers) — good fit for ZeroGPU.
|
| 216 |
+
|
| 217 |
+
---
|
| 218 |
+
|
| 219 |
+
## Phase 4 — Docs and local smoke test
|
| 220 |
+
|
| 221 |
+
Update [`USAGE.md`](USAGE.md):
|
| 222 |
+
|
| 223 |
+
- New **Gradio SDK deployment** section (primary path): `app.py`, `requirements.txt`, ZeroGPU, env vars.
|
| 224 |
+
- Move existing Docker section to **"Docker SDK (later)"** — note README must switch to `sdk: docker` + `app_port: 7860`.
|
| 225 |
+
- Local Gradio SDK smoke test:
|
| 226 |
+
|
| 227 |
+
```bash
|
| 228 |
+
python -m venv .venv && source .venv/bin/activate
|
| 229 |
+
pip install -r requirements.txt
|
| 230 |
+
ACTIVE_MODEL=minicpm5-1b ALLOW_MODEL_SWITCH=false python app.py
|
| 231 |
+
```
|
| 232 |
+
|
| 233 |
+
Keep existing `uv run` workflow for day-to-day dev unchanged.
|
| 234 |
+
|
| 235 |
+
Update [`.cursor/plans/hf_space_publish_e8a57bab.plan.md`](.cursor/plans/hf_space_publish_e8a57bab.plan.md) todos to reflect Gradio-first ordering.
|
| 236 |
+
|
| 237 |
+
---
|
| 238 |
+
|
| 239 |
+
## Phase 5 — Verify on Space
|
| 240 |
+
|
| 241 |
+
1. **Logs** — pip install succeeds; app starts on `0.0.0.0:7860`.
|
| 242 |
+
2. **`/` Studio** — loads static UI.
|
| 243 |
+
3. **`/classic`** — all tabs render.
|
| 244 |
+
4. **Smoke flows** — slides generation, research chat, EchoCoach sample clip, teacher voice text turn.
|
| 245 |
+
5. **ZeroGPU** — first LLM request allocates GPU (may be slow on cold start); watch for "No CUDA GPUs" (means handler is outside `@spaces.GPU`).
|
| 246 |
+
|
| 247 |
+
---
|
| 248 |
+
|
| 249 |
+
## Phase 6 — Docker later (no code removal)
|
| 250 |
+
|
| 251 |
+
When ready for Docker Space:
|
| 252 |
+
|
| 253 |
+
1. Change README to `sdk: docker`, `app_port: 7860` (remove `sdk_version` / `app_file`).
|
| 254 |
+
2. Create a **second Space** (or reuse after README flip) with **GPU Basic** hardware.
|
| 255 |
+
3. Existing [`Dockerfile`](Dockerfile) + `uv sync` path unchanged; no `@spaces.GPU` needed on dedicated GPU.
|
| 256 |
+
|
| 257 |
+
Both file sets remain on `main`; only README `sdk:` toggles which build HF runs.
|
| 258 |
+
|
| 259 |
+
---
|
| 260 |
+
|
| 261 |
+
## Risk notes
|
| 262 |
+
|
| 263 |
+
| Risk | Mitigation |
|
| 264 |
+
|------|------------|
|
| 265 |
+
| `pip install llama-cpp-python` slow/fails on HF | Accept slow build; default `minicpm5-1b` avoids GGUF at runtime |
|
| 266 |
+
| EchoCoach deps (piper, whisper) heavy | Full scope requested; pin versions; fix from Space Logs if needed |
|
| 267 |
+
| ZeroGPU + generator slide progress | Refactor GPU block to non-generator helper if build succeeds but inference fails |
|
| 268 |
+
| Two live Spaces same branch | Not supported with different SDKs — use README flip or second branch for concurrent Docker + Gradio |
|
.cursor/plans/hf_space_publish_e8a57bab.plan.md
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
name: HF Space Publish
|
| 3 |
+
overview: Fix two repo blockers (README Space card YAML and missing `researchmind` in Dockerfile), validate locally with Docker, push to GitHub, then create a Docker Space under build-small-hackathon linked to GitHub with GPU hardware and MiniCPM5-1B env vars.
|
| 4 |
+
todos:
|
| 5 |
+
- id: fix-readme-yaml
|
| 6 |
+
content: "Fix root README.md frontmatter: change `## title:` to `title:`"
|
| 7 |
+
status: completed
|
| 8 |
+
- id: fix-dockerfile-researchmind
|
| 9 |
+
content: Add libs/researchmind COPY lines to Dockerfile
|
| 10 |
+
status: completed
|
| 11 |
+
- id: local-docker-smoke
|
| 12 |
+
content: Run docker build + docker run locally on port 7860 with ACTIVE_MODEL=minicpm5-1b
|
| 13 |
+
status: in_progress
|
| 14 |
+
- id: push-github
|
| 15 |
+
content: Push fixed branch to GitHub repo
|
| 16 |
+
status: pending
|
| 17 |
+
- id: create-space
|
| 18 |
+
content: Create Docker Space under build-small-hackathon, link GitHub, set GPU basic
|
| 19 |
+
status: pending
|
| 20 |
+
- id: configure-env
|
| 21 |
+
content: "Set Space secrets: ACTIVE_MODEL=minicpm5-1b, ALLOW_MODEL_SWITCH=false, RESEARCHMIND_DATA_DIR=/tmp/researchmind"
|
| 22 |
+
status: pending
|
| 23 |
+
- id: verify-live
|
| 24 |
+
content: Check Space Logs, test / and /classic, confirm slide generation works
|
| 25 |
+
status: pending
|
| 26 |
+
isProject: false
|
| 27 |
+
---
|
| 28 |
+
|
| 29 |
+
# Publish Gradio app to Hugging Face Space
|
| 30 |
+
|
| 31 |
+
## Current state
|
| 32 |
+
|
| 33 |
+
Your repo is **mostly ready** for a Docker Space:
|
| 34 |
+
|
| 35 |
+
- Root [`Dockerfile`](Dockerfile) exposes port **7860** and runs `python -m gradio_space.app`
|
| 36 |
+
- Root [`README.md`](README.md) has Space metadata (`sdk: docker`, `app_port: 7860`)
|
| 37 |
+
- Default model in [`models.yaml`](models.yaml) is **`minicpm5-1b`** (transformers, `openbmb/MiniCPM5-1B`)
|
| 38 |
+
|
| 39 |
+
Two issues will likely **break the Space build or card** until fixed:
|
| 40 |
+
|
| 41 |
+
### Blocker 1 — README YAML is malformed
|
| 42 |
+
|
| 43 |
+
The Space card frontmatter must use `title:`, not a markdown heading:
|
| 44 |
+
|
| 45 |
+
```yaml
|
| 46 |
+
# Current (wrong)
|
| 47 |
+
## title: Lesson Agent
|
| 48 |
+
|
| 49 |
+
# Required (correct)
|
| 50 |
+
title: Lesson Agent
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
HF reads YAML from the **root** [`README.md`](README.md) only. Keep [`apps/gradio-space/README.md`](apps/gradio-space/README.md) as dev docs.
|
| 54 |
+
|
| 55 |
+
### Blocker 2 — Dockerfile missing `researchmind`
|
| 56 |
+
|
| 57 |
+
[`libs/agent`](libs/agent/pyproject.toml) depends on `researchmind`, but the Dockerfile only copies `inference`, `agent`, and `echocoach`. `uv sync` inside the image will fail without:
|
| 58 |
+
|
| 59 |
+
```dockerfile
|
| 60 |
+
COPY libs/researchmind/pyproject.toml libs/researchmind/README.md libs/researchmind/
|
| 61 |
+
COPY libs/researchmind/src libs/researchmind/src
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
Add these lines alongside the other `libs/*` COPY blocks in [`Dockerfile`](Dockerfile).
|
| 65 |
+
|
| 66 |
+
---
|
| 67 |
+
|
| 68 |
+
## Architecture (what gets deployed)
|
| 69 |
+
|
| 70 |
+
```mermaid
|
| 71 |
+
flowchart LR
|
| 72 |
+
subgraph hf [HuggingFaceSpace]
|
| 73 |
+
DockerBuild[DockerBuild]
|
| 74 |
+
Container[Container_port7860]
|
| 75 |
+
end
|
| 76 |
+
GitHub[GitHub_repo] --> DockerBuild
|
| 77 |
+
DockerBuild --> Container
|
| 78 |
+
Container --> StudioUI["/ Studio UI"]
|
| 79 |
+
Container --> ClassicUI["/classic Gradio tabs"]
|
| 80 |
+
Container --> HubModel["Hub: openbmb/MiniCPM5-1B"]
|
| 81 |
+
HubModel --> Container
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
Entrypoint (unchanged):
|
| 85 |
+
|
| 86 |
+
```44:44:Dockerfile
|
| 87 |
+
CMD ["uv", "run", "--package", "gradio-space", "python", "-m", "gradio_space.app"]
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
This launches [`gradio_space.server`](apps/gradio-space/src/gradio_space/server.py): Studio at `/`, Classic tabs at `/classic`.
|
| 91 |
+
|
| 92 |
+
---
|
| 93 |
+
|
| 94 |
+
## Phase 1 — Fix repo files (before push)
|
| 95 |
+
|
| 96 |
+
| File | Change |
|
| 97 |
+
|------|--------|
|
| 98 |
+
| [`README.md`](README.md) | Fix frontmatter: `title: Lesson Agent` (remove `##`); keep `sdk: docker`, `app_port: 7860` |
|
| 99 |
+
| [`Dockerfile`](Dockerfile) | Add `libs/researchmind` pyproject + src COPY lines |
|
| 100 |
+
|
| 101 |
+
Optional but recommended in README frontmatter (already present except title):
|
| 102 |
+
|
| 103 |
+
```yaml
|
| 104 |
+
---
|
| 105 |
+
title: Lesson Agent
|
| 106 |
+
emoji: 📚
|
| 107 |
+
colorFrom: blue
|
| 108 |
+
colorTo: green
|
| 109 |
+
sdk: docker
|
| 110 |
+
app_port: 7860
|
| 111 |
+
pinned: false
|
| 112 |
+
license: apache-2.0
|
| 113 |
+
---
|
| 114 |
+
```
|
| 115 |
+
|
| 116 |
+
---
|
| 117 |
+
|
| 118 |
+
## Phase 2 — Validate locally with Docker
|
| 119 |
+
|
| 120 |
+
From repo root:
|
| 121 |
+
|
| 122 |
+
```bash
|
| 123 |
+
docker build -t hackathon-space .
|
| 124 |
+
docker run --rm -p 7860:7860 \
|
| 125 |
+
-e ACTIVE_MODEL=minicpm5-1b \
|
| 126 |
+
-e ALLOW_MODEL_SWITCH=false \
|
| 127 |
+
hackathon-space
|
| 128 |
+
```
|
| 129 |
+
|
| 130 |
+
Open [http://localhost:7860](http://localhost:7860) (`/` Studio, `/classic` tabs). First model load downloads weights from Hub — expect several minutes on first run.
|
| 131 |
+
|
| 132 |
+
If build fails, check Logs for `researchmind` or `uv sync` errors (confirms Blocker 2 fix).
|
| 133 |
+
|
| 134 |
+
---
|
| 135 |
+
|
| 136 |
+
## Phase 3 — Push to GitHub
|
| 137 |
+
|
| 138 |
+
1. Create a GitHub repo (if not already linked)
|
| 139 |
+
2. Push `main` with at minimum:
|
| 140 |
+
- `Dockerfile`, `README.md`, `pyproject.toml`, `uv.lock`
|
| 141 |
+
- `apps/gradio-space/`, `libs/`, `skills/`, `models.yaml`, `voice_models.yaml`
|
| 142 |
+
|
| 143 |
+
Do **not** commit `.env`, local `models/*.gguf`, or large artifacts (`.dockerignore` already excludes these).
|
| 144 |
+
|
| 145 |
+
---
|
| 146 |
+
|
| 147 |
+
## Phase 4 — Create and link the Space
|
| 148 |
+
|
| 149 |
+
1. Go to [build-small-hackathon](https://huggingface.co/build-small-hackathon) → **New Space**
|
| 150 |
+
2. Settings:
|
| 151 |
+
- **Name:** e.g. `lesson-agent` or `small-model-hackathon`
|
| 152 |
+
- **SDK:** **Docker** (not Gradio SDK — monorepo needs root Dockerfile)
|
| 153 |
+
- **Hardware:** **GPU basic** (required for transformers `minicpm5-1b`)
|
| 154 |
+
3. Under **Repository** → connect your GitHub repo and branch (`main`)
|
| 155 |
+
4. HF will auto-build from root `Dockerfile` on each push
|
| 156 |
+
|
| 157 |
+
---
|
| 158 |
+
|
| 159 |
+
## Phase 5 — Space environment variables
|
| 160 |
+
|
| 161 |
+
In Space **Settings → Variables and secrets** (Repository secrets, not `.env` in git):
|
| 162 |
+
|
| 163 |
+
| Variable | Value | Why |
|
| 164 |
+
|----------|-------|-----|
|
| 165 |
+
| `ACTIVE_MODEL` | `minicpm5-1b` | Pins model for visitors |
|
| 166 |
+
| `ALLOW_MODEL_SWITCH` | `false` | Hides dev model dropdown |
|
| 167 |
+
| `AGENT_OUTPUTS_DIR` | `/tmp/agent_outputs` | Already set in Dockerfile; optional override |
|
| 168 |
+
| `RESEARCHMIND_DATA_DIR` | `/tmp/researchmind` | Ephemeral RAG store on Space (recommended) |
|
| 169 |
+
|
| 170 |
+
No secrets required for the default MiniCPM5 preset unless you switch to a gated model.
|
| 171 |
+
|
| 172 |
+
---
|
| 173 |
+
|
| 174 |
+
## Phase 6 — Verify publish
|
| 175 |
+
|
| 176 |
+
1. Open Space **Logs** — wait for `Running on local URL: 0.0.0.0:7860`
|
| 177 |
+
2. Open the Space URL
|
| 178 |
+
3. Smoke test:
|
| 179 |
+
- `/` — Studio loads
|
| 180 |
+
- Generate slides with a simple topic (e.g. "Photosynthesis, grade 8, 5 slides")
|
| 181 |
+
- `/classic` — tabs render
|
| 182 |
+
4. First inference may be slow while `openbmb/MiniCPM5-1B` downloads
|
| 183 |
+
|
| 184 |
+
### Optional: faster restarts
|
| 185 |
+
|
| 186 |
+
If cold starts are painful, add a **Storage Bucket** in Space settings so Hub model cache persists across restarts.
|
| 187 |
+
|
| 188 |
+
---
|
| 189 |
+
|
| 190 |
+
## Troubleshooting
|
| 191 |
+
|
| 192 |
+
| Symptom | Fix |
|
| 193 |
+
|---------|-----|
|
| 194 |
+
| Space card shows wrong title / no Docker | Fix README YAML (`title:` not `## title:`) |
|
| 195 |
+
| Docker build fails at `uv sync` | Add `researchmind` to Dockerfile |
|
| 196 |
+
| Build OK but app crashes on Research tab | Confirm `researchmind` src is copied |
|
| 197 |
+
| First request very slow | Normal — model download; use Storage Bucket |
|
| 198 |
+
| OOM on GPU | Try smaller batch or switch preset to GGUF on CPU |
|
| 199 |
+
|
| 200 |
+
Full reference: [`USAGE.md`](USAGE.md) sections "Docker smoke test" and "Hugging Face Space deployment".
|
| 201 |
+
|
| 202 |
+
---
|
| 203 |
+
|
| 204 |
+
## What you do NOT need
|
| 205 |
+
|
| 206 |
+
- Plain Gradio SDK (`app.py` + `requirements.txt` at root) — wrong fit for this monorepo
|
| 207 |
+
- Committing GGUF files — models download from Hub at runtime via `ACTIVE_MODEL` / `models.yaml`
|
| 208 |
+
- Changing the CMD — current entrypoint already serves Studio + Classic
|
.env.example
CHANGED
|
@@ -66,14 +66,4 @@ ALLOW_MODEL_SWITCH=false
|
|
| 66 |
# For Cohere Transcribe ASR: huggingface-cli login + accept model terms, then:
|
| 67 |
# ECHOCOACH_ASR_PRESET=cohere-transcribe
|
| 68 |
|
| 69 |
-
# --- Ensemble research (research/ensemble/) ---
|
| 70 |
-
# Base LLM resolution (first match wins): ENSEMBLE_LLM, LLM_PATH, BASE, MODEL_ID, ACTIVE_MODEL
|
| 71 |
-
# LLM_PATH=./models/finetuned/minicpm5-1b-lora-merged
|
| 72 |
-
# ENSEMBLE_LLM=Qwen/Qwen2.5-0.5B-Instruct
|
| 73 |
-
# ENSEMBLE_PRESET=minicpm5-1b
|
| 74 |
-
# ENSEMBLE_OUT=./models/ensemble/minicpm5-1b-jepa-pretrain
|
| 75 |
-
# ENSEMBLE_QA=./research/data/benchmark-qa.jsonl
|
| 76 |
-
# ENSEMBLE_KB=./research/data/benchmark-kb.jsonl
|
| 77 |
-
# ENSEMBLE_CKPT=./models/ensemble/jepa-lesson-pretrain
|
| 78 |
-
|
| 79 |
BASE=openbmb/MiniCPM5-1B
|
|
|
|
| 66 |
# For Cohere Transcribe ASR: huggingface-cli login + accept model terms, then:
|
| 67 |
# ECHOCOACH_ASR_PRESET=cohere-transcribe
|
| 68 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
BASE=openbmb/MiniCPM5-1B
|
.gitignore
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
.venv/
|
|
|
|
| 2 |
__pycache__/
|
| 3 |
*.py[cod]
|
| 4 |
.env
|
|
|
|
| 1 |
.venv/
|
| 2 |
+
.venv-gradio/
|
| 3 |
__pycache__/
|
| 4 |
*.py[cod]
|
| 5 |
.env
|
Dockerfile
CHANGED
|
@@ -20,11 +20,13 @@ COPY apps/gradio-space/pyproject.toml apps/gradio-space/README.md apps/gradio-sp
|
|
| 20 |
COPY libs/inference/pyproject.toml libs/inference/README.md libs/inference/
|
| 21 |
COPY libs/agent/pyproject.toml libs/agent/README.md libs/agent/
|
| 22 |
COPY libs/echocoach/pyproject.toml libs/echocoach/README.md libs/echocoach/
|
|
|
|
| 23 |
COPY apps/gradio-space/src apps/gradio-space/src
|
| 24 |
COPY apps/gradio-space/static apps/gradio-space/static
|
| 25 |
COPY libs/inference/src libs/inference/src
|
| 26 |
COPY libs/agent/src libs/agent/src
|
| 27 |
COPY libs/echocoach/src libs/echocoach/src
|
|
|
|
| 28 |
COPY skills skills
|
| 29 |
|
| 30 |
RUN useradd -m -u 1000 user && \
|
|
|
|
| 20 |
COPY libs/inference/pyproject.toml libs/inference/README.md libs/inference/
|
| 21 |
COPY libs/agent/pyproject.toml libs/agent/README.md libs/agent/
|
| 22 |
COPY libs/echocoach/pyproject.toml libs/echocoach/README.md libs/echocoach/
|
| 23 |
+
COPY libs/researchmind/pyproject.toml libs/researchmind/README.md libs/researchmind/
|
| 24 |
COPY apps/gradio-space/src apps/gradio-space/src
|
| 25 |
COPY apps/gradio-space/static apps/gradio-space/static
|
| 26 |
COPY libs/inference/src libs/inference/src
|
| 27 |
COPY libs/agent/src libs/agent/src
|
| 28 |
COPY libs/echocoach/src libs/echocoach/src
|
| 29 |
+
COPY libs/researchmind/src libs/researchmind/src
|
| 30 |
COPY skills skills
|
| 31 |
|
| 32 |
RUN useradd -m -u 1000 user && \
|
README.md
CHANGED
|
@@ -1,13 +1,15 @@
|
|
| 1 |
---
|
| 2 |
-
|
| 3 |
-
## title: Lesson Agent
|
| 4 |
emoji: 📚
|
| 5 |
colorFrom: blue
|
| 6 |
colorTo: green
|
| 7 |
-
sdk:
|
| 8 |
-
|
|
|
|
|
|
|
| 9 |
pinned: false
|
| 10 |
license: apache-2.0
|
|
|
|
| 11 |
|
| 12 |
# Lesson Agent
|
| 13 |
|
|
@@ -15,7 +17,7 @@ license: apache-2.0
|
|
| 15 |
|
| 16 |
A local skill-based agent helps a teacher you know turn a **topic + grade level** into a downloadable **PowerPoint** — powered by a small transformers model (`MiniCPM5-1B` by default), no cloud LLM API.
|
| 17 |
|
| 18 |
-
See **[USAGE.md](USAGE.md)** for local run,
|
| 19 |
|
| 20 |
## Prerequisites
|
| 21 |
|
|
@@ -59,7 +61,7 @@ libs/agent/ # Skill agent runner, tools, trace recorder
|
|
| 59 |
libs/researchmind/ # Scraper, chunk/embed, MemRAG SQLite store, retrieval
|
| 60 |
libs/inference/ # Transformers + llama.cpp backends
|
| 61 |
skills/ # SKILL.md + references/ + scripts/ per task
|
| 62 |
-
research/ # Fine-tune
|
| 63 |
```
|
| 64 |
|
| 65 |
### ResearchMind (offline after ingest)
|
|
@@ -87,15 +89,12 @@ See [`.env.example`](.env.example) and [`models.yaml`](models.yaml) for model pr
|
|
| 87 |
|
| 88 |
## Hugging Face Space deployment
|
| 89 |
|
| 90 |
-
1. Create a Space under [build-small-hackathon](https://huggingface.co/build-small-hackathon) with **
|
| 91 |
-
2. Link this repository
|
| 92 |
-
3. Hardware: **GPU basic**
|
| 93 |
-
4.
|
| 94 |
|
| 95 |
-
```
|
| 96 |
-
docker build -t hackathon-space .
|
| 97 |
-
docker run --rm -p 7860:7860 -e ACTIVE_MODEL=minicpm5-1b hackathon-space
|
| 98 |
-
```
|
| 99 |
|
| 100 |
## Hackathon checklist
|
| 101 |
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Lesson Agent
|
|
|
|
| 3 |
emoji: 📚
|
| 4 |
colorFrom: blue
|
| 5 |
colorTo: green
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: "6.16.0"
|
| 8 |
+
app_file: app.py
|
| 9 |
+
python_version: "3.12"
|
| 10 |
pinned: false
|
| 11 |
license: apache-2.0
|
| 12 |
+
---
|
| 13 |
|
| 14 |
# Lesson Agent
|
| 15 |
|
|
|
|
| 17 |
|
| 18 |
A local skill-based agent helps a teacher you know turn a **topic + grade level** into a downloadable **PowerPoint** — powered by a small transformers model (`MiniCPM5-1B` by default), no cloud LLM API.
|
| 19 |
|
| 20 |
+
See **[USAGE.md](USAGE.md)** for local run, Gradio SDK / ZeroGPU Space deployment, and Docker (later).
|
| 21 |
|
| 22 |
## Prerequisites
|
| 23 |
|
|
|
|
| 61 |
libs/researchmind/ # Scraper, chunk/embed, MemRAG SQLite store, retrieval
|
| 62 |
libs/inference/ # Transformers + llama.cpp backends
|
| 63 |
skills/ # SKILL.md + references/ + scripts/ per task
|
| 64 |
+
research/ # Fine-tune and agentic evals (optional)
|
| 65 |
```
|
| 66 |
|
| 67 |
### ResearchMind (offline after ingest)
|
|
|
|
| 89 |
|
| 90 |
## Hugging Face Space deployment
|
| 91 |
|
| 92 |
+
1. Create a Space under [build-small-hackathon](https://huggingface.co/build-small-hackathon) with **Gradio** SDK (Blank template).
|
| 93 |
+
2. Link this repository — HF builds from root `app.py` + `requirements.txt` (README YAML above).
|
| 94 |
+
3. Hardware: **ZeroGPU** for burst GPU inference, or **GPU basic** for always-on GPU.
|
| 95 |
+
4. Set `ACTIVE_MODEL=minicpm5-1b`, `ALLOW_MODEL_SWITCH=false`, `RESEARCHMIND_DATA_DIR=/tmp/researchmind`.
|
| 96 |
|
| 97 |
+
A root `Dockerfile` is kept for a later **Docker SDK** deploy (flip README to `sdk: docker`). See [USAGE.md](USAGE.md).
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
## Hackathon checklist
|
| 100 |
|
USAGE.md
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
# Usage
|
| 2 |
|
| 3 |
-
How to run the **Lesson Agent** Gradio app locally,
|
| 4 |
|
| 5 |
The primary UI is the **Lesson slides** tab (topic → local model outline → downloadable `.pptx`). Use **ResearchMind** for corpus Q&A, **TeacherVoice** for spoken back-and-forth tutoring, **EchoCoach** for one-shot pitch analysis, or ground lessons directly from the Lesson tab. The **Chat (debug)** tab tests the underlying model.
|
| 6 |
|
|
@@ -223,98 +223,121 @@ INFERENCE_BACKEND=transformers MODEL_ID=Qwen/Qwen2.5-3B-Instruct \
|
|
| 223 |
|
| 224 |
---
|
| 225 |
|
| 226 |
-
##
|
| 227 |
|
| 228 |
-
|
| 229 |
|
| 230 |
```bash
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
-e MODEL_FILE=qwen2.5-3b-instruct-q4_k_m.gguf \
|
| 235 |
-
-e N_CTX=4096 \
|
| 236 |
-
-e N_GPU_LAYERS=0 \
|
| 237 |
-
hackathon-space
|
| 238 |
```
|
| 239 |
|
| 240 |
-
Open [http://localhost:7860](http://localhost:7860) — Studio at `/`, Classic
|
| 241 |
-
|
| 242 |
-
To use a pre-downloaded local model inside Docker, mount it and set `MODEL_PATH`:
|
| 243 |
|
| 244 |
-
```
|
| 245 |
-
docker run --rm -p 7860:7860 \
|
| 246 |
-
-v "$(pwd)/models:/app/models:ro" \
|
| 247 |
-
-e MODEL_PATH=/app/models/qwen2.5-3b-instruct-q4_k_m.gguf \
|
| 248 |
-
hackathon-space
|
| 249 |
-
```
|
| 250 |
|
| 251 |
---
|
| 252 |
|
| 253 |
-
## Hugging Face Space deployment
|
| 254 |
|
| 255 |
-
|
| 256 |
|
| 257 |
### 1. Push code to GitHub
|
| 258 |
|
| 259 |
-
Make sure `main`
|
| 260 |
|
| 261 |
-
- `
|
| 262 |
-
- `README.md` (with `sdk:
|
| 263 |
-
- `
|
| 264 |
-
- `apps/gradio-space/` and `libs/
|
|
|
|
|
|
|
| 265 |
|
| 266 |
### 2. Create the Space
|
| 267 |
|
| 268 |
1. Go to [build-small-hackathon](https://huggingface.co/build-small-hackathon)
|
| 269 |
2. **New Space**
|
| 270 |
-
3. Name: e.g. `small-model-hackathon`
|
| 271 |
-
4. SDK: **
|
| 272 |
-
5.
|
|
|
|
| 273 |
|
| 274 |
CLI alternative (if you have `hf` installed and org access):
|
| 275 |
|
| 276 |
```bash
|
| 277 |
hf repo create build-small-hackathon/<your-space-name> \
|
| 278 |
--repo-type space \
|
| 279 |
-
--space_sdk
|
| 280 |
```
|
| 281 |
|
| 282 |
-
### 3.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
|
|
|
|
| 284 |
|
| 285 |
-
|
| 286 |
-
| -------- | ------------------------------------------------------------ |
|
| 287 |
-
| Hardware | **CPU basic** to start (llama.cpp with `N_GPU_LAYERS=0`) |
|
| 288 |
-
| Upgrade | GPU Space if you set `N_GPU_LAYERS > 0` for faster inference |
|
| 289 |
|
|
|
|
| 290 |
|
| 291 |
-
|
|
|
|
| 292 |
|
| 293 |
-
|
| 294 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 295 |
|
| 296 |
-
|
| 297 |
-
| ------------------- | --------------------------------- |
|
| 298 |
-
| `INFERENCE_BACKEND` | `llama_cpp` |
|
| 299 |
-
| `MODEL_REPO` | `Qwen/Qwen2.5-3B-Instruct-GGUF` |
|
| 300 |
-
| `MODEL_FILE` | `qwen2.5-3b-instruct-q4_k_m.gguf` |
|
| 301 |
-
| `N_CTX` | `4096` |
|
| 302 |
-
| `N_GPU_LAYERS` | `0` (or higher on GPU hardware) |
|
| 303 |
|
|
|
|
| 304 |
|
| 305 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
|
| 307 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 308 |
|
| 309 |
```bash
|
| 310 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
```
|
| 312 |
|
| 313 |
-
|
| 314 |
|
| 315 |
-
|
| 316 |
|
| 317 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 318 |
|
| 319 |
---
|
| 320 |
|
|
@@ -323,29 +346,24 @@ If cold starts are too slow, attach a **Storage Bucket** in Space settings so do
|
|
| 323 |
|
| 324 |
| Symptom | Likely cause | Fix |
|
| 325 |
| ---------------------------------------- | --------------------------------- | -------------------------------------------------------------------- |
|
| 326 |
-
| First chat hangs / slow |
|
| 327 |
-
| `Failed to load model` in chat | Wrong `
|
| 328 |
-
|
|
| 329 |
-
| Space build fails |
|
| 330 |
-
|
|
| 331 |
-
|
|
|
|
|
| 332 |
|
| 333 |
|
| 334 |
---
|
| 335 |
|
| 336 |
## Entrypoint summary
|
| 337 |
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
``
|
| 341 |
-
|
| 342 |
-
``
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
| Environment | How to run |
|
| 346 |
-
| ----------- | ---------------------------------------------------------- |
|
| 347 |
-
| Local dev | `uv run --package gradio-space python -m gradio_space.app` |
|
| 348 |
-
| Docker | `docker run -p 7860:7860 hackathon-space` |
|
| 349 |
-
| HF Space | Built and started automatically from `Dockerfile` `CMD` |
|
| 350 |
|
| 351 |
|
|
|
|
| 1 |
# Usage
|
| 2 |
|
| 3 |
+
How to run the **Lesson Agent** Gradio app locally, deploy to a Hugging Face Space (Gradio SDK + ZeroGPU), and optionally test with Docker later for the [Build Small Hackathon](https://huggingface.co/build-small-hackathon).
|
| 4 |
|
| 5 |
The primary UI is the **Lesson slides** tab (topic → local model outline → downloadable `.pptx`). Use **ResearchMind** for corpus Q&A, **TeacherVoice** for spoken back-and-forth tutoring, **EchoCoach** for one-shot pitch analysis, or ground lessons directly from the Lesson tab. The **Chat (debug)** tab tests the underlying model.
|
| 6 |
|
|
|
|
| 223 |
|
| 224 |
---
|
| 225 |
|
| 226 |
+
## Gradio SDK local smoke test (matches HF Space build)
|
| 227 |
|
| 228 |
+
Before pushing to Hugging Face, verify the Gradio SDK entry point:
|
| 229 |
|
| 230 |
```bash
|
| 231 |
+
python -m venv .venv-gradio && source .venv-gradio/bin/activate
|
| 232 |
+
pip install -r requirements.txt
|
| 233 |
+
ACTIVE_MODEL=minicpm5-1b ALLOW_MODEL_SWITCH=false python app.py
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
```
|
| 235 |
|
| 236 |
+
Open [http://localhost:7860](http://localhost:7860) — Studio at `/`, Classic at `/classic`.
|
|
|
|
|
|
|
| 237 |
|
| 238 |
+
Day-to-day development can still use `uv run` (see above); this path mirrors what HF installs from `requirements.txt`.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
|
| 240 |
---
|
| 241 |
|
| 242 |
+
## Hugging Face Space deployment (Gradio SDK + ZeroGPU)
|
| 243 |
|
| 244 |
+
The Space card metadata lives in the YAML frontmatter at the top of [README.md](README.md) (`sdk: gradio`, `app_file: app.py`).
|
| 245 |
|
| 246 |
### 1. Push code to GitHub
|
| 247 |
|
| 248 |
+
Make sure `main` contains at minimum:
|
| 249 |
|
| 250 |
+
- `app.py`, `requirements.txt`, `packages.txt`
|
| 251 |
+
- `README.md` (with `sdk: gradio`, `sdk_version`, `app_file: app.py`)
|
| 252 |
+
- `models.yaml`, `skills/`
|
| 253 |
+
- `apps/gradio-space/` and all `libs/*` packages
|
| 254 |
+
|
| 255 |
+
The root `Dockerfile` stays in the repo for a later Docker SDK deploy (see below).
|
| 256 |
|
| 257 |
### 2. Create the Space
|
| 258 |
|
| 259 |
1. Go to [build-small-hackathon](https://huggingface.co/build-small-hackathon)
|
| 260 |
2. **New Space**
|
| 261 |
+
3. Name: e.g. `lesson-agent` or `small-model-hackathon`
|
| 262 |
+
4. SDK: **Gradio** (Blank template)
|
| 263 |
+
5. Hardware: **ZeroGPU** (creator needs PRO/Team) or **GPU basic**
|
| 264 |
+
6. Link your GitHub repo, or push directly to the Space git remote
|
| 265 |
|
| 266 |
CLI alternative (if you have `hf` installed and org access):
|
| 267 |
|
| 268 |
```bash
|
| 269 |
hf repo create build-small-hackathon/<your-space-name> \
|
| 270 |
--repo-type space \
|
| 271 |
+
--space_sdk gradio
|
| 272 |
```
|
| 273 |
|
| 274 |
+
### 3. Set Space environment variables
|
| 275 |
+
|
| 276 |
+
In the Space **Settings → Variables and secrets**:
|
| 277 |
+
|
| 278 |
+
| Variable | Value |
|
| 279 |
+
| -------- | ----- |
|
| 280 |
+
| `ACTIVE_MODEL` | `minicpm5-1b` |
|
| 281 |
+
| `ALLOW_MODEL_SWITCH` | `false` |
|
| 282 |
+
| `RESEARCHMIND_DATA_DIR` | `/tmp/researchmind` |
|
| 283 |
|
| 284 |
+
Default preset in [`models.yaml`](models.yaml) is `minicpm5-1b` (transformers) — suitable for ZeroGPU.
|
| 285 |
|
| 286 |
+
### 4. Build and verify
|
|
|
|
|
|
|
|
|
|
| 287 |
|
| 288 |
+
HF installs from `requirements.txt` and runs root `app.py`. Check the **Logs** tab for:
|
| 289 |
|
| 290 |
+
- Successful pip install (first build may take several minutes — `llama-cpp-python` compiles)
|
| 291 |
+
- `Running on local URL: 0.0.0.0:7860`
|
| 292 |
|
| 293 |
+
Smoke test on the live Space:
|
| 294 |
|
| 295 |
+
1. **`/`** — Studio UI loads
|
| 296 |
+
2. **`/classic`** — all tabs render
|
| 297 |
+
3. Generate slides with a simple topic (e.g. "Photosynthesis, grade 8, 5 slides")
|
| 298 |
+
4. First LLM request may be slow (model download + ZeroGPU queue)
|
| 299 |
|
| 300 |
+
### 5. ZeroGPU notes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
|
| 302 |
+
LLM handlers use `@spaces.GPU` via [`gradio_space/spaces_runtime.py`](apps/gradio-space/src/gradio_space/spaces_runtime.py). If you see **No CUDA GPUs are available**, an inference path is running outside a decorated handler.
|
| 303 |
|
| 304 |
+
Startup model preload is skipped on HF Gradio runtime; the first user request loads the model inside a GPU task.
|
| 305 |
+
|
| 306 |
+
### 6. Optional: persistent model cache
|
| 307 |
+
|
| 308 |
+
Attach a **Storage Bucket** in Space settings so Hub model weights survive restarts.
|
| 309 |
+
|
| 310 |
+
---
|
| 311 |
|
| 312 |
+
## Docker SDK deployment (later)
|
| 313 |
+
|
| 314 |
+
Both deploy paths live on the same branch. HF reads **one** `sdk:` from README — switch to Docker when you are ready for a dedicated-GPU Space.
|
| 315 |
+
|
| 316 |
+
1. Change [README.md](README.md) frontmatter to `sdk: docker`, `app_port: 7860` (remove `sdk_version` / `app_file`)
|
| 317 |
+
2. Create or reconfigure a Space with **Docker** SDK and **GPU basic** hardware
|
| 318 |
+
3. Set the same env vars (`ACTIVE_MODEL=minicpm5-1b`, etc.)
|
| 319 |
+
|
| 320 |
+
### Local Docker smoke test
|
| 321 |
|
| 322 |
```bash
|
| 323 |
+
docker build -t hackathon-space .
|
| 324 |
+
docker run --rm -p 7860:7860 \
|
| 325 |
+
-e ACTIVE_MODEL=minicpm5-1b \
|
| 326 |
+
-e ALLOW_MODEL_SWITCH=false \
|
| 327 |
+
-e RESEARCHMIND_DATA_DIR=/tmp/researchmind \
|
| 328 |
+
hackathon-space
|
| 329 |
```
|
| 330 |
|
| 331 |
+
Open [http://localhost:7860](http://localhost:7860) — Studio at `/`, Classic tabs at `/classic`. Stop with `Ctrl+C`.
|
| 332 |
|
| 333 |
+
To use a pre-downloaded local GGUF model inside Docker, mount it and set `MODEL_PATH`:
|
| 334 |
|
| 335 |
+
```bash
|
| 336 |
+
docker run --rm -p 7860:7860 \
|
| 337 |
+
-v "$(pwd)/models:/app/models:ro" \
|
| 338 |
+
-e MODEL_PATH=/app/models/qwen2.5-3b-instruct-q4_k_m.gguf \
|
| 339 |
+
hackathon-space
|
| 340 |
+
```
|
| 341 |
|
| 342 |
---
|
| 343 |
|
|
|
|
| 346 |
|
| 347 |
| Symptom | Likely cause | Fix |
|
| 348 |
| ---------------------------------------- | --------------------------------- | -------------------------------------------------------------------- |
|
| 349 |
+
| First chat hangs / slow | Model downloading from Hub | Wait on Space; use Storage Bucket for cache |
|
| 350 |
+
| `Failed to load model` in chat | Wrong `ACTIVE_MODEL` preset | Use `minicpm5-1b` or valid key from `models.yaml` |
|
| 351 |
+
| Space build fails on pip install | `llama-cpp-python` compile | Check Logs; default preset avoids GGUF at runtime |
|
| 352 |
+
| Space build fails | Malformed README YAML | Ensure `sdk: gradio` and `app_file: app.py` in README frontmatter |
|
| 353 |
+
| No CUDA GPUs on ZeroGPU | Handler outside `@spaces.GPU` | LLM entry points must use `gpu_task` in `spaces_runtime.py` |
|
| 354 |
+
| Docker build fails on `llama-cpp-python` | Missing build tools | Dockerfile installs `build-essential` and `cmake` |
|
| 355 |
+
| Port already in use locally | Another process on 7860 | `PORT=7861 python app.py` or `uv run ...` |
|
| 356 |
|
| 357 |
|
| 358 |
---
|
| 359 |
|
| 360 |
## Entrypoint summary
|
| 361 |
|
| 362 |
+
| Environment | How to run |
|
| 363 |
+
| ----------- | ---------- |
|
| 364 |
+
| Local dev (uv) | `uv run --package gradio-space python -m gradio_space.app` |
|
| 365 |
+
| Local Gradio SDK smoke | `pip install -r requirements.txt && python app.py` |
|
| 366 |
+
| HF Gradio Space | HF runs root `app.py` automatically |
|
| 367 |
+
| Docker (later) | `docker run -p 7860:7860 hackathon-space` (after README `sdk: docker`) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 368 |
|
| 369 |
|
app.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hugging Face Gradio SDK entry point (ZeroGPU / Gradio Spaces)."""
|
| 2 |
+
|
| 3 |
+
from gradio_space.server import main
|
| 4 |
+
|
| 5 |
+
if __name__ == "__main__":
|
| 6 |
+
main()
|
apps/gradio-space/src/gradio_space/model_loading.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
from inference.config import get_app_config, get_model_config
|
| 2 |
from inference.factory import get_backend, reset_backend
|
| 3 |
from inference.response_clean import strip_reasoning_output
|
|
@@ -74,6 +75,7 @@ def warmup(model_key: str | None = None) -> str:
|
|
| 74 |
)
|
| 75 |
|
| 76 |
|
|
|
|
| 77 |
def reload_model(model_key: str) -> str:
|
| 78 |
"""Clear cached backend and reload weights for settings panel."""
|
| 79 |
global _current_model_key
|
|
@@ -120,6 +122,7 @@ def _history_to_messages(history: list) -> list[dict[str, str]]:
|
|
| 120 |
return messages
|
| 121 |
|
| 122 |
|
|
|
|
| 123 |
def chat(message: str, history: list, model_key: str) -> str:
|
| 124 |
load_error = ensure_model_loaded(model_key)
|
| 125 |
if load_error:
|
|
|
|
| 1 |
+
from gradio_space.spaces_runtime import gpu_task
|
| 2 |
from inference.config import get_app_config, get_model_config
|
| 3 |
from inference.factory import get_backend, reset_backend
|
| 4 |
from inference.response_clean import strip_reasoning_output
|
|
|
|
| 75 |
)
|
| 76 |
|
| 77 |
|
| 78 |
+
@gpu_task(duration=120)
|
| 79 |
def reload_model(model_key: str) -> str:
|
| 80 |
"""Clear cached backend and reload weights for settings panel."""
|
| 81 |
global _current_model_key
|
|
|
|
| 122 |
return messages
|
| 123 |
|
| 124 |
|
| 125 |
+
@gpu_task(duration=60)
|
| 126 |
def chat(message: str, history: list, model_key: str) -> str:
|
| 127 |
load_error = ensure_model_loaded(model_key)
|
| 128 |
if load_error:
|
apps/gradio-space/src/gradio_space/research_helpers.py
CHANGED
|
@@ -8,6 +8,7 @@ import gradio as gr
|
|
| 8 |
from agent.models import ResearchIngestResult
|
| 9 |
from agent.runner import AgentRunner
|
| 10 |
from gradio_space.model_loading import chat, ensure_model_loaded, get_active_model_key
|
|
|
|
| 11 |
from inference.factory import get_backend
|
| 12 |
from researchmind.ingest import IngestPipeline
|
| 13 |
|
|
@@ -209,6 +210,7 @@ def rag_scope_hint(session_id: str, doc_ids: list[str] | None) -> str:
|
|
| 209 |
return "RAG scope: **entire** indexed corpus (all sessions)."
|
| 210 |
|
| 211 |
|
|
|
|
| 212 |
def run_research_question(
|
| 213 |
question: str,
|
| 214 |
*,
|
|
|
|
| 8 |
from agent.models import ResearchIngestResult
|
| 9 |
from agent.runner import AgentRunner
|
| 10 |
from gradio_space.model_loading import chat, ensure_model_loaded, get_active_model_key
|
| 11 |
+
from gradio_space.spaces_runtime import gpu_task
|
| 12 |
from inference.factory import get_backend
|
| 13 |
from researchmind.ingest import IngestPipeline
|
| 14 |
|
|
|
|
| 210 |
return "RAG scope: **entire** indexed corpus (all sessions)."
|
| 211 |
|
| 212 |
|
| 213 |
+
@gpu_task(duration=180)
|
| 214 |
def run_research_question(
|
| 215 |
question: str,
|
| 216 |
*,
|
apps/gradio-space/src/gradio_space/server.py
CHANGED
|
@@ -12,6 +12,7 @@ from gradio import mount_gradio_app
|
|
| 12 |
from gradio_space.api.studio import register_studio_apis
|
| 13 |
from gradio_space.app import build_demo
|
| 14 |
from gradio_space.model_loading import preload_active_model
|
|
|
|
| 15 |
from gradio_space.tabs.education_pptx import gradio_allowed_paths
|
| 16 |
from gradio_space.tabs.echo_coach import echo_coach_allowed_paths
|
| 17 |
from gradio_space.tabs.research_mind import researchmind_allowed_paths
|
|
@@ -66,7 +67,8 @@ def create_server() -> gr.Server:
|
|
| 66 |
|
| 67 |
|
| 68 |
def main() -> None:
|
| 69 |
-
|
|
|
|
| 70 |
server = create_server()
|
| 71 |
port = int(os.environ.get("PORT", "7860"))
|
| 72 |
server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
|
|
|
|
| 12 |
from gradio_space.api.studio import register_studio_apis
|
| 13 |
from gradio_space.app import build_demo
|
| 14 |
from gradio_space.model_loading import preload_active_model
|
| 15 |
+
from gradio_space.spaces_runtime import is_hf_gradio_runtime
|
| 16 |
from gradio_space.tabs.education_pptx import gradio_allowed_paths
|
| 17 |
from gradio_space.tabs.echo_coach import echo_coach_allowed_paths
|
| 18 |
from gradio_space.tabs.research_mind import researchmind_allowed_paths
|
|
|
|
| 67 |
|
| 68 |
|
| 69 |
def main() -> None:
|
| 70 |
+
if not is_hf_gradio_runtime():
|
| 71 |
+
preload_active_model()
|
| 72 |
server = create_server()
|
| 73 |
port = int(os.environ.get("PORT", "7860"))
|
| 74 |
server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
|
apps/gradio-space/src/gradio_space/spaces_runtime.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hugging Face Spaces ZeroGPU helpers."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
from collections.abc import Callable
|
| 7 |
+
from typing import ParamSpec, TypeVar
|
| 8 |
+
|
| 9 |
+
P = ParamSpec("P")
|
| 10 |
+
R = TypeVar("R")
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def is_hf_gradio_runtime() -> bool:
|
| 14 |
+
"""True on Hugging Face Gradio SDK Spaces (skip startup model preload)."""
|
| 15 |
+
try:
|
| 16 |
+
import spaces # noqa: F401
|
| 17 |
+
except ImportError:
|
| 18 |
+
return False
|
| 19 |
+
return bool(os.environ.get("SPACE_ID"))
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def gpu_task(
|
| 23 |
+
*,
|
| 24 |
+
duration: int = 180,
|
| 25 |
+
size: str = "large",
|
| 26 |
+
) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
| 27 |
+
"""Apply @spaces.GPU when the HF spaces runtime is present (no-op elsewhere)."""
|
| 28 |
+
|
| 29 |
+
def decorator(fn: Callable[P, R]) -> Callable[P, R]:
|
| 30 |
+
try:
|
| 31 |
+
import spaces
|
| 32 |
+
|
| 33 |
+
return spaces.GPU(duration=duration, size=size)(fn)
|
| 34 |
+
except ImportError:
|
| 35 |
+
return fn
|
| 36 |
+
|
| 37 |
+
return decorator
|
apps/gradio-space/src/gradio_space/tabs/echo_coach.py
CHANGED
|
@@ -7,6 +7,7 @@ import gradio as gr
|
|
| 7 |
from echocoach.config import get_echo_coach_config
|
| 8 |
from echocoach.pipeline import run_echo_coach
|
| 9 |
from gradio_space.model_loading import ensure_model_loaded, get_active_model_key
|
|
|
|
| 10 |
from gradio_space.ui.components import (
|
| 11 |
build_advanced_panel,
|
| 12 |
build_recording_block,
|
|
@@ -64,6 +65,7 @@ def load_sample_pitch() -> tuple[str | None, str]:
|
|
| 64 |
)
|
| 65 |
|
| 66 |
|
|
|
|
| 67 |
def analyze_pitch(
|
| 68 |
audio_path: str | None,
|
| 69 |
language: str,
|
|
|
|
| 7 |
from echocoach.config import get_echo_coach_config
|
| 8 |
from echocoach.pipeline import run_echo_coach
|
| 9 |
from gradio_space.model_loading import ensure_model_loaded, get_active_model_key
|
| 10 |
+
from gradio_space.spaces_runtime import gpu_task
|
| 11 |
from gradio_space.ui.components import (
|
| 12 |
build_advanced_panel,
|
| 13 |
build_recording_block,
|
|
|
|
| 65 |
)
|
| 66 |
|
| 67 |
|
| 68 |
+
@gpu_task(duration=180)
|
| 69 |
def analyze_pitch(
|
| 70 |
audio_path: str | None,
|
| 71 |
language: str,
|
apps/gradio-space/src/gradio_space/tabs/education_pptx.py
CHANGED
|
@@ -16,6 +16,7 @@ from gradio_space.research_helpers import (
|
|
| 16 |
resolve_session,
|
| 17 |
resolve_topic,
|
| 18 |
)
|
|
|
|
| 19 |
from gradio_space.ui.components import build_advanced_panel, DOC_CHOICE_LIST_CLASSES, WorkspaceWidgets
|
| 20 |
from inference.factory import get_backend
|
| 21 |
from researchmind.config import get_config
|
|
@@ -158,6 +159,7 @@ def update_source_visibility(source_mode_label: str, search_workflow_label: str)
|
|
| 158 |
)
|
| 159 |
|
| 160 |
|
|
|
|
| 161 |
def discover_lesson_sources(
|
| 162 |
topic: str,
|
| 163 |
session_id: str,
|
|
@@ -208,6 +210,7 @@ def discover_lesson_sources(
|
|
| 208 |
return msg, gr.update(choices=[], value=[]), refresh_sessions(session_id)
|
| 209 |
|
| 210 |
|
|
|
|
| 211 |
def generate_lesson_slides(
|
| 212 |
topic: str,
|
| 213 |
grade: str,
|
|
|
|
| 16 |
resolve_session,
|
| 17 |
resolve_topic,
|
| 18 |
)
|
| 19 |
+
from gradio_space.spaces_runtime import gpu_task
|
| 20 |
from gradio_space.ui.components import build_advanced_panel, DOC_CHOICE_LIST_CLASSES, WorkspaceWidgets
|
| 21 |
from inference.factory import get_backend
|
| 22 |
from researchmind.config import get_config
|
|
|
|
| 159 |
)
|
| 160 |
|
| 161 |
|
| 162 |
+
@gpu_task(duration=120)
|
| 163 |
def discover_lesson_sources(
|
| 164 |
topic: str,
|
| 165 |
session_id: str,
|
|
|
|
| 210 |
return msg, gr.update(choices=[], value=[]), refresh_sessions(session_id)
|
| 211 |
|
| 212 |
|
| 213 |
+
@gpu_task(duration=300)
|
| 214 |
def generate_lesson_slides(
|
| 215 |
topic: str,
|
| 216 |
grade: str,
|
apps/gradio-space/src/gradio_space/tabs/research_mind.py
CHANGED
|
@@ -23,6 +23,7 @@ from gradio_space.research_helpers import (
|
|
| 23 |
run_research_question,
|
| 24 |
trace_summary_markdown,
|
| 25 |
)
|
|
|
|
| 26 |
from gradio_space.ui.components import build_advanced_panel, DOC_CHOICE_LIST_CLASSES, WorkspaceWidgets
|
| 27 |
from inference.factory import get_backend
|
| 28 |
|
|
@@ -35,6 +36,7 @@ def _require_topic(topic: str | None) -> str | None:
|
|
| 35 |
return None
|
| 36 |
|
| 37 |
|
|
|
|
| 38 |
def discover_sources(
|
| 39 |
topic: str,
|
| 40 |
session_id: str,
|
|
@@ -118,6 +120,7 @@ def discover_sources(
|
|
| 118 |
)
|
| 119 |
|
| 120 |
|
|
|
|
| 121 |
def auto_search_ingest(
|
| 122 |
topic: str,
|
| 123 |
session_id: str,
|
|
@@ -279,6 +282,7 @@ def ingest_selected(
|
|
| 279 |
)
|
| 280 |
|
| 281 |
|
|
|
|
| 282 |
def ask_question(
|
| 283 |
question: str,
|
| 284 |
session_id: str,
|
|
|
|
| 23 |
run_research_question,
|
| 24 |
trace_summary_markdown,
|
| 25 |
)
|
| 26 |
+
from gradio_space.spaces_runtime import gpu_task
|
| 27 |
from gradio_space.ui.components import build_advanced_panel, DOC_CHOICE_LIST_CLASSES, WorkspaceWidgets
|
| 28 |
from inference.factory import get_backend
|
| 29 |
|
|
|
|
| 36 |
return None
|
| 37 |
|
| 38 |
|
| 39 |
+
@gpu_task(duration=120)
|
| 40 |
def discover_sources(
|
| 41 |
topic: str,
|
| 42 |
session_id: str,
|
|
|
|
| 120 |
)
|
| 121 |
|
| 122 |
|
| 123 |
+
@gpu_task(duration=180)
|
| 124 |
def auto_search_ingest(
|
| 125 |
topic: str,
|
| 126 |
session_id: str,
|
|
|
|
| 282 |
)
|
| 283 |
|
| 284 |
|
| 285 |
+
@gpu_task(duration=180)
|
| 286 |
def ask_question(
|
| 287 |
question: str,
|
| 288 |
session_id: str,
|
apps/gradio-space/src/gradio_space/tabs/teacher_voice.py
CHANGED
|
@@ -18,6 +18,7 @@ from gradio_space.research_helpers import (
|
|
| 18 |
resolve_topic,
|
| 19 |
trace_as_dict,
|
| 20 |
)
|
|
|
|
| 21 |
from gradio_space.tabs.research_mind import (
|
| 22 |
auto_search_ingest,
|
| 23 |
discover_sources,
|
|
@@ -87,6 +88,7 @@ def _turn_error(history: list | None, message: str) -> tuple:
|
|
| 87 |
)
|
| 88 |
|
| 89 |
|
|
|
|
| 90 |
def send_turn(
|
| 91 |
audio_path: str | None,
|
| 92 |
history: list,
|
|
@@ -142,6 +144,7 @@ def send_turn(
|
|
| 142 |
return _turn_result(result)
|
| 143 |
|
| 144 |
|
|
|
|
| 145 |
def send_text_turn(
|
| 146 |
message: str,
|
| 147 |
history: list,
|
|
|
|
| 18 |
resolve_topic,
|
| 19 |
trace_as_dict,
|
| 20 |
)
|
| 21 |
+
from gradio_space.spaces_runtime import gpu_task
|
| 22 |
from gradio_space.tabs.research_mind import (
|
| 23 |
auto_search_ingest,
|
| 24 |
discover_sources,
|
|
|
|
| 88 |
)
|
| 89 |
|
| 90 |
|
| 91 |
+
@gpu_task(duration=180)
|
| 92 |
def send_turn(
|
| 93 |
audio_path: str | None,
|
| 94 |
history: list,
|
|
|
|
| 144 |
return _turn_result(result)
|
| 145 |
|
| 146 |
|
| 147 |
+
@gpu_task(duration=180)
|
| 148 |
def send_text_turn(
|
| 149 |
message: str,
|
| 150 |
history: list,
|
apps/gradio-space/static/studio/index.html
CHANGED
|
@@ -291,83 +291,107 @@
|
|
| 291 |
</section>
|
| 292 |
|
| 293 |
<section class="col col-studio">
|
| 294 |
-
<
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
<div class="card">
|
| 304 |
-
<p class="card-title">Teacher Voice Mode</p>
|
| 305 |
-
<div class="mode-cards" id="voice-modes">
|
| 306 |
-
<button type="button" class="mode-card" data-mode="explain">Explain</button>
|
| 307 |
-
<button type="button" class="mode-card active" data-mode="lesson">Coach</button>
|
| 308 |
-
<button type="button" class="mode-card" data-mode="pitch">Practice</button>
|
| 309 |
-
</div>
|
| 310 |
-
<label class="field voice-topic-wrap" id="voice-topic-wrap">
|
| 311 |
-
<span>Focus topic</span>
|
| 312 |
-
<input id="voice-topic" type="text" class="input" placeholder="Uses workspace topic when empty" />
|
| 313 |
-
</label>
|
| 314 |
-
<details class="voice-rag-sources" id="voice-rag-sources">
|
| 315 |
-
<summary>ResearchMind sources (optional)</summary>
|
| 316 |
-
<p class="status-text">Set focus topic, then discover or ingest sources. Enable RAG above to ground answers in your library.</p>
|
| 317 |
-
<div class="ingest-action-row">
|
| 318 |
-
<button type="button" id="btn-voice-discover" class="btn btn-secondary">Discover on web</button>
|
| 319 |
-
<button type="button" id="btn-voice-auto-ingest" class="btn btn-secondary">Auto-ingest from web</button>
|
| 320 |
</div>
|
| 321 |
-
<div
|
| 322 |
-
<
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 323 |
</div>
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
<
|
| 331 |
-
<
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
</div>
|
| 347 |
-
<p id="voice-record-status" class="status-text"></p>
|
| 348 |
-
<button type="button" id="btn-voice-send" class="btn btn-secondary btn-block">Send text</button>
|
| 349 |
-
<button type="button" id="btn-voice-audio-send" class="btn btn-primary btn-block">Send voice turn</button>
|
| 350 |
-
</label>
|
| 351 |
-
<p id="voice-turn-status" class="status-text"></p>
|
| 352 |
-
<div class="voice-replay-row">
|
| 353 |
-
<button type="button" id="btn-voice-speak-full" class="btn btn-secondary">Speak full reply</button>
|
| 354 |
-
<button type="button" id="btn-voice-speak-quick" class="btn btn-secondary">Speak first sentence</button>
|
| 355 |
-
<button type="button" id="btn-voice-clear" class="btn btn-ghost">Clear conversation</button>
|
| 356 |
</div>
|
| 357 |
-
<div id="voice-audio-out" class="voice-audio-out"></div>
|
| 358 |
</div>
|
| 359 |
-
<div class="card coach-panel-wrap">
|
| 360 |
-
<
|
| 361 |
-
|
| 362 |
-
<
|
| 363 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 364 |
</div>
|
| 365 |
-
<p id="coach-record-status" class="status-text"></p>
|
| 366 |
-
<button type="button" id="btn-coach-sample" class="btn btn-ghost btn-block">Load sample clip</button>
|
| 367 |
-
<label class="field">
|
| 368 |
-
<span>Or upload pitch (WAV)</span>
|
| 369 |
-
<input id="coach-audio" type="file" accept="audio/*" />
|
| 370 |
-
</label>
|
| 371 |
<div class="controls-grid coach-presets">
|
| 372 |
<label class="field">
|
| 373 |
<span>Language</span>
|
|
@@ -378,19 +402,22 @@
|
|
| 378 |
<select id="coach-asr" class="input"></select>
|
| 379 |
</label>
|
| 380 |
</div>
|
| 381 |
-
<label class="toggle-row">
|
| 382 |
<span>Speak full rewrite (VoiceOut)</span>
|
| 383 |
<input id="coach-speak-rewrite" type="checkbox" />
|
| 384 |
</label>
|
| 385 |
-
<button type="button" id="btn-analyze" class="btn btn-
|
| 386 |
-
<div id="coach-panel"></div>
|
| 387 |
</div>
|
| 388 |
</section>
|
| 389 |
|
| 390 |
<section class="col col-debug">
|
| 391 |
-
<div class="card card-tall">
|
| 392 |
-
<
|
| 393 |
-
|
|
|
|
|
|
|
|
|
|
| 394 |
<label class="toggle-row">
|
| 395 |
<span>Use ResearchMind RAG</span>
|
| 396 |
<input id="debug-use-rag" type="checkbox" />
|
|
@@ -418,11 +445,13 @@
|
|
| 418 |
<div id="debug-chat-messages" class="research-chat-messages debug-chat-messages">
|
| 419 |
<p class="research-chat-empty">Send a message to test the active local model.</p>
|
| 420 |
</div>
|
| 421 |
-
<
|
| 422 |
-
<
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
|
|
|
|
|
|
| 426 |
<details class="studio-debug-trace" id="debug-trace-details">
|
| 427 |
<summary>Debug trace</summary>
|
| 428 |
<div id="debug-trace-panel"></div>
|
|
|
|
| 291 |
</section>
|
| 292 |
|
| 293 |
<section class="col col-studio">
|
| 294 |
+
<div class="voice-layout view-voice-only">
|
| 295 |
+
<aside class="voice-rail">
|
| 296 |
+
<div class="card voice-rag-card">
|
| 297 |
+
<p class="card-title">RAG Scope</p>
|
| 298 |
+
<label class="toggle-row">
|
| 299 |
+
<span>Cross-Reference Sources</span>
|
| 300 |
+
<input id="use-rag" type="checkbox" checked />
|
| 301 |
+
</label>
|
| 302 |
+
<p class="status-text">Uses workspace session and documents unless overridden below.</p>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
</div>
|
| 304 |
+
<div class="card voice-rail-controls">
|
| 305 |
+
<p class="card-title">Mode</p>
|
| 306 |
+
<div class="mode-cards voice-mode-cards" id="voice-modes">
|
| 307 |
+
<button type="button" class="mode-card" data-mode="explain">Explain</button>
|
| 308 |
+
<button type="button" class="mode-card active" data-mode="lesson">Coach</button>
|
| 309 |
+
<button type="button" class="mode-card" data-mode="pitch">Practice</button>
|
| 310 |
+
</div>
|
| 311 |
+
<label class="field voice-topic-wrap" id="voice-topic-wrap">
|
| 312 |
+
<span>Focus topic</span>
|
| 313 |
+
<input id="voice-topic" type="text" class="input" placeholder="Uses workspace topic when empty" />
|
| 314 |
+
</label>
|
| 315 |
+
<details class="voice-rag-sources" id="voice-rag-sources">
|
| 316 |
+
<summary>Add sources (optional)</summary>
|
| 317 |
+
<p class="status-text">Discover or ingest sources to ground answers in your library.</p>
|
| 318 |
+
<div class="ingest-action-row">
|
| 319 |
+
<button type="button" id="btn-voice-discover" class="btn btn-secondary">Discover on web</button>
|
| 320 |
+
<button type="button" id="btn-voice-auto-ingest" class="btn btn-secondary">Auto-ingest</button>
|
| 321 |
+
</div>
|
| 322 |
+
<div id="voice-url-choices-panel" class="url-choices-panel hidden">
|
| 323 |
+
<div id="voice-url-choices-list" class="url-choices-list"></div>
|
| 324 |
+
</div>
|
| 325 |
+
<label class="field">
|
| 326 |
+
<span>Paste URLs (one per line)</span>
|
| 327 |
+
<textarea id="voice-urls-text" class="input" rows="2" placeholder="https://…"></textarea>
|
| 328 |
+
</label>
|
| 329 |
+
<label class="upload-zone upload-zone-compact">
|
| 330 |
+
<input id="voice-ingest-file" type="file" accept=".pdf,.docx" multiple hidden />
|
| 331 |
+
<span class="material-symbols-outlined">upload_file</span>
|
| 332 |
+
<span>Upload PDF or Doc</span>
|
| 333 |
+
</label>
|
| 334 |
+
<button type="button" id="btn-voice-ingest" class="btn btn-secondary btn-block">Ingest sources</button>
|
| 335 |
+
<p id="voice-ingest-status" class="status-text"></p>
|
| 336 |
+
</details>
|
| 337 |
</div>
|
| 338 |
+
</aside>
|
| 339 |
+
<div class="voice-main">
|
| 340 |
+
<div class="card voice-main-card">
|
| 341 |
+
<div class="voice-card-head">
|
| 342 |
+
<h2 class="section-label">Teacher Voice</h2>
|
| 343 |
+
<p class="voice-card-desc">Talk with the teacher using text or voice — grounded in your sources when RAG is on.</p>
|
| 344 |
+
</div>
|
| 345 |
+
<div id="voice-chat-messages" class="research-chat-messages voice-chat-messages">
|
| 346 |
+
<p class="research-chat-empty">Type a message or record audio, then send.</p>
|
| 347 |
+
</div>
|
| 348 |
+
<div class="voice-compose" id="voice-panel">
|
| 349 |
+
<label class="field">
|
| 350 |
+
<span>Ask the teacher</span>
|
| 351 |
+
<textarea id="voice-message" class="input" rows="2" placeholder="What is the difference between pretraining and finetuning a small model?"></textarea>
|
| 352 |
+
</label>
|
| 353 |
+
<div class="voice-input-toolbar">
|
| 354 |
+
<div class="recording-row voice-recording-row">
|
| 355 |
+
<button type="button" id="btn-voice-record-start" class="btn btn-secondary">Start mic</button>
|
| 356 |
+
<button type="button" id="btn-voice-record-stop" class="btn btn-secondary" disabled>Stop mic</button>
|
| 357 |
+
<input id="voice-audio-upload" type="file" accept="audio/*" class="input input-compact" />
|
| 358 |
+
</div>
|
| 359 |
+
<p id="voice-record-status" class="status-text voice-record-status"></p>
|
| 360 |
+
</div>
|
| 361 |
+
<div class="voice-send-row">
|
| 362 |
+
<button type="button" id="btn-voice-send" class="btn btn-secondary">Send text</button>
|
| 363 |
+
<button type="button" id="btn-voice-audio-send" class="btn btn-primary">Send voice turn</button>
|
| 364 |
+
</div>
|
| 365 |
+
<p id="voice-turn-status" class="status-text"></p>
|
| 366 |
+
<div class="voice-replay-row">
|
| 367 |
+
<button type="button" id="btn-voice-speak-full" class="btn btn-secondary">Speak full reply</button>
|
| 368 |
+
<button type="button" id="btn-voice-speak-quick" class="btn btn-secondary">Speak first sentence</button>
|
| 369 |
+
<button type="button" id="btn-voice-clear" class="btn btn-ghost">Clear conversation</button>
|
| 370 |
+
</div>
|
| 371 |
+
<div id="voice-audio-out" class="voice-audio-out"></div>
|
| 372 |
+
</div>
|
| 373 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 374 |
</div>
|
|
|
|
| 375 |
</div>
|
| 376 |
+
<div class="card coach-panel-wrap view-coach-only">
|
| 377 |
+
<div class="coach-card-head">
|
| 378 |
+
<h2 class="section-label">EchoCoach · Pitch analysis</h2>
|
| 379 |
+
<p class="coach-card-desc">Record or upload a short pitch for pace, filler highlights, and spoken feedback.</p>
|
| 380 |
+
</div>
|
| 381 |
+
<div class="coach-capture-row">
|
| 382 |
+
<div class="coach-capture-controls">
|
| 383 |
+
<div class="recording-row coach-recording-row">
|
| 384 |
+
<button type="button" id="btn-coach-record-start" class="btn btn-secondary">Start mic</button>
|
| 385 |
+
<button type="button" id="btn-coach-record-stop" class="btn btn-secondary" disabled>Stop mic</button>
|
| 386 |
+
<button type="button" id="btn-coach-sample" class="btn btn-ghost">Load sample</button>
|
| 387 |
+
</div>
|
| 388 |
+
<p id="coach-record-status" class="status-text coach-record-status"></p>
|
| 389 |
+
</div>
|
| 390 |
+
<label class="field coach-upload-field">
|
| 391 |
+
<span>Upload pitch (WAV)</span>
|
| 392 |
+
<input id="coach-audio" type="file" accept="audio/*" />
|
| 393 |
+
</label>
|
| 394 |
</div>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 395 |
<div class="controls-grid coach-presets">
|
| 396 |
<label class="field">
|
| 397 |
<span>Language</span>
|
|
|
|
| 402 |
<select id="coach-asr" class="input"></select>
|
| 403 |
</label>
|
| 404 |
</div>
|
| 405 |
+
<label class="toggle-row coach-voiceout-toggle">
|
| 406 |
<span>Speak full rewrite (VoiceOut)</span>
|
| 407 |
<input id="coach-speak-rewrite" type="checkbox" />
|
| 408 |
</label>
|
| 409 |
+
<button type="button" id="btn-analyze" class="btn btn-primary btn-block coach-analyze-btn">Analyze pitch</button>
|
| 410 |
+
<div id="coach-panel" class="coach-results-panel"></div>
|
| 411 |
</div>
|
| 412 |
</section>
|
| 413 |
|
| 414 |
<section class="col col-debug">
|
| 415 |
+
<div class="card card-tall coach-debug-card">
|
| 416 |
+
<div class="coach-card-head">
|
| 417 |
+
<h2 class="section-label">Chat (debug)</h2>
|
| 418 |
+
<p class="coach-card-desc view-coach-only">Plain chat or corpus-grounded answers — traces appear below when RAG is on.</p>
|
| 419 |
+
<p class="status-text view-debug-only">Plain chat or corpus-grounded answers — traces appear below when RAG is on.</p>
|
| 420 |
+
</div>
|
| 421 |
<label class="toggle-row">
|
| 422 |
<span>Use ResearchMind RAG</span>
|
| 423 |
<input id="debug-use-rag" type="checkbox" />
|
|
|
|
| 445 |
<div id="debug-chat-messages" class="research-chat-messages debug-chat-messages">
|
| 446 |
<p class="research-chat-empty">Send a message to test the active local model.</p>
|
| 447 |
</div>
|
| 448 |
+
<div class="coach-debug-compose">
|
| 449 |
+
<label class="field">
|
| 450 |
+
<span>Message</span>
|
| 451 |
+
<textarea id="debug-message" class="input" rows="2" placeholder="Hello, model…"></textarea>
|
| 452 |
+
</label>
|
| 453 |
+
<button type="button" id="btn-debug-send" class="btn btn-primary btn-block">Send</button>
|
| 454 |
+
</div>
|
| 455 |
<details class="studio-debug-trace" id="debug-trace-details">
|
| 456 |
<summary>Debug trace</summary>
|
| 457 |
<div id="debug-trace-panel"></div>
|
apps/gradio-space/static/studio/studio.css
CHANGED
|
@@ -960,11 +960,288 @@ body {
|
|
| 960 |
|
| 961 |
.workspace[data-view="voice"] .col-research,
|
| 962 |
.workspace[data-view="voice"] .col-slides { display: none; }
|
| 963 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 964 |
|
| 965 |
.workspace[data-view="coach"] .col-research,
|
| 966 |
.workspace[data-view="coach"] .col-slides { display: none; }
|
| 967 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 968 |
|
| 969 |
@media (max-width: 768px) {
|
| 970 |
:root { --sidebar-w: 0px; }
|
|
@@ -1229,9 +1506,12 @@ body {
|
|
| 1229 |
}
|
| 1230 |
|
| 1231 |
.coach-presets {
|
| 1232 |
-
margin-top: 0
|
| 1233 |
}
|
| 1234 |
|
|
|
|
|
|
|
|
|
|
| 1235 |
.workspace[data-view="debug"] .col-research,
|
| 1236 |
.workspace[data-view="debug"] .col-slides,
|
| 1237 |
.workspace[data-view="debug"] .col-studio { display: none; }
|
|
|
|
| 960 |
|
| 961 |
.workspace[data-view="voice"] .col-research,
|
| 962 |
.workspace[data-view="voice"] .col-slides { display: none; }
|
| 963 |
+
|
| 964 |
+
.workspace[data-view="voice"] .col-debug,
|
| 965 |
+
.workspace[data-view="voice"] .view-coach-only { display: none; }
|
| 966 |
+
|
| 967 |
+
.view-voice-only { display: none; }
|
| 968 |
+
|
| 969 |
+
.workspace[data-view="voice"] {
|
| 970 |
+
grid-template-columns: minmax(0, 1fr);
|
| 971 |
+
max-width: 1280px;
|
| 972 |
+
gap: 1.25rem;
|
| 973 |
+
}
|
| 974 |
+
|
| 975 |
+
.workspace[data-view="voice"] .col-studio {
|
| 976 |
+
grid-column: 1 / -1;
|
| 977 |
+
width: 100%;
|
| 978 |
+
min-width: 0;
|
| 979 |
+
}
|
| 980 |
+
|
| 981 |
+
.workspace[data-view="voice"] .voice-layout {
|
| 982 |
+
display: grid;
|
| 983 |
+
grid-template-columns: minmax(260px, 0.78fr) minmax(0, 1.22fr);
|
| 984 |
+
gap: 1.25rem;
|
| 985 |
+
align-items: start;
|
| 986 |
+
width: 100%;
|
| 987 |
+
}
|
| 988 |
+
|
| 989 |
+
.workspace[data-view="voice"] .voice-rail {
|
| 990 |
+
display: flex;
|
| 991 |
+
flex-direction: column;
|
| 992 |
+
gap: 1rem;
|
| 993 |
+
min-width: 0;
|
| 994 |
+
}
|
| 995 |
+
|
| 996 |
+
.workspace[data-view="voice"] .voice-main {
|
| 997 |
+
min-width: 0;
|
| 998 |
+
}
|
| 999 |
+
|
| 1000 |
+
.workspace[data-view="voice"] .voice-main-card {
|
| 1001 |
+
display: flex;
|
| 1002 |
+
flex-direction: column;
|
| 1003 |
+
}
|
| 1004 |
+
|
| 1005 |
+
.workspace[data-view="voice"] .voice-compose {
|
| 1006 |
+
display: flex;
|
| 1007 |
+
flex-direction: column;
|
| 1008 |
+
gap: 0.5rem;
|
| 1009 |
+
}
|
| 1010 |
+
|
| 1011 |
+
.workspace[data-view="voice"] .voice-compose .field {
|
| 1012 |
+
margin: 0;
|
| 1013 |
+
}
|
| 1014 |
+
|
| 1015 |
+
.workspace[data-view="voice"] .voice-compose textarea {
|
| 1016 |
+
min-height: 3.25rem;
|
| 1017 |
+
resize: vertical;
|
| 1018 |
+
}
|
| 1019 |
+
|
| 1020 |
+
.workspace[data-view="voice"] .voice-rail .voice-mode-cards {
|
| 1021 |
+
flex-direction: row;
|
| 1022 |
+
flex-wrap: wrap;
|
| 1023 |
+
gap: 0.35rem;
|
| 1024 |
+
margin-bottom: 0.75rem;
|
| 1025 |
+
}
|
| 1026 |
+
|
| 1027 |
+
.workspace[data-view="voice"] .voice-rail .voice-mode-cards .mode-card {
|
| 1028 |
+
flex: 1 1 calc(33.333% - 0.35rem);
|
| 1029 |
+
text-align: center;
|
| 1030 |
+
justify-content: center;
|
| 1031 |
+
min-width: 0;
|
| 1032 |
+
padding-left: 0.5rem;
|
| 1033 |
+
padding-right: 0.5rem;
|
| 1034 |
+
}
|
| 1035 |
+
|
| 1036 |
+
.workspace[data-view="voice"] .voice-rail-controls .voice-topic-wrap {
|
| 1037 |
+
margin: 0 0 0.75rem;
|
| 1038 |
+
}
|
| 1039 |
+
|
| 1040 |
+
.workspace[data-view="voice"] .voice-rag-sources {
|
| 1041 |
+
margin: 0;
|
| 1042 |
+
}
|
| 1043 |
+
|
| 1044 |
+
.workspace[data-view="voice"] .voice-rag-sources summary {
|
| 1045 |
+
cursor: pointer;
|
| 1046 |
+
font-weight: 600;
|
| 1047 |
+
font-size: 0.82rem;
|
| 1048 |
+
}
|
| 1049 |
+
|
| 1050 |
+
.workspace[data-view="voice"] .voice-chat-messages {
|
| 1051 |
+
min-height: 160px;
|
| 1052 |
+
max-height: min(260px, 32vh);
|
| 1053 |
+
margin: 0 0 0.75rem;
|
| 1054 |
+
}
|
| 1055 |
+
|
| 1056 |
+
.workspace[data-view="voice"] .voice-input-toolbar {
|
| 1057 |
+
padding: 0.65rem 0.75rem;
|
| 1058 |
+
border: 1px solid var(--outline-variant);
|
| 1059 |
+
border-radius: var(--radius-lg);
|
| 1060 |
+
background: var(--surface-container-low);
|
| 1061 |
+
margin-bottom: 0.65rem;
|
| 1062 |
+
}
|
| 1063 |
+
|
| 1064 |
+
.workspace[data-view="voice"] .voice-recording-row {
|
| 1065 |
+
margin: 0;
|
| 1066 |
+
}
|
| 1067 |
+
|
| 1068 |
+
.workspace[data-view="voice"] .voice-record-status {
|
| 1069 |
+
margin: 0.35rem 0 0;
|
| 1070 |
+
min-height: 1.1rem;
|
| 1071 |
+
}
|
| 1072 |
+
|
| 1073 |
+
.workspace[data-view="voice"] .voice-send-row {
|
| 1074 |
+
display: grid;
|
| 1075 |
+
grid-template-columns: 1fr 1fr;
|
| 1076 |
+
gap: 0.5rem;
|
| 1077 |
+
margin-bottom: 0.35rem;
|
| 1078 |
+
}
|
| 1079 |
+
|
| 1080 |
+
.workspace[data-view="voice"] .voice-card-head {
|
| 1081 |
+
margin-bottom: 0.85rem;
|
| 1082 |
+
}
|
| 1083 |
+
|
| 1084 |
+
.workspace[data-view="voice"] .voice-card-head .section-label {
|
| 1085 |
+
margin-bottom: 0.35rem;
|
| 1086 |
+
}
|
| 1087 |
+
|
| 1088 |
+
.voice-card-desc {
|
| 1089 |
+
margin: 0;
|
| 1090 |
+
font-size: 0.84rem;
|
| 1091 |
+
line-height: 1.45;
|
| 1092 |
+
color: var(--secondary);
|
| 1093 |
+
}
|
| 1094 |
+
|
| 1095 |
+
@media (max-width: 960px) {
|
| 1096 |
+
.workspace[data-view="voice"] .voice-layout {
|
| 1097 |
+
grid-template-columns: 1fr;
|
| 1098 |
+
max-width: 640px;
|
| 1099 |
+
margin-left: auto;
|
| 1100 |
+
margin-right: auto;
|
| 1101 |
+
}
|
| 1102 |
+
|
| 1103 |
+
.workspace[data-view="voice"] .voice-rail .voice-mode-cards {
|
| 1104 |
+
flex-direction: column;
|
| 1105 |
+
}
|
| 1106 |
+
|
| 1107 |
+
.workspace[data-view="voice"] .voice-rail .voice-mode-cards .mode-card {
|
| 1108 |
+
flex: 1 1 auto;
|
| 1109 |
+
text-align: left;
|
| 1110 |
+
justify-content: space-between;
|
| 1111 |
+
}
|
| 1112 |
+
|
| 1113 |
+
.workspace[data-view="voice"] .voice-send-row {
|
| 1114 |
+
grid-template-columns: 1fr;
|
| 1115 |
+
}
|
| 1116 |
+
}
|
| 1117 |
|
| 1118 |
.workspace[data-view="coach"] .col-research,
|
| 1119 |
.workspace[data-view="coach"] .col-slides { display: none; }
|
| 1120 |
+
|
| 1121 |
+
.workspace[data-view="coach"] .view-voice-only { display: none; }
|
| 1122 |
+
|
| 1123 |
+
.workspace[data-view="slides"] .col-studio,
|
| 1124 |
+
.workspace[data-view="research"] .col-debug { display: none; }
|
| 1125 |
+
|
| 1126 |
+
.workspace[data-view="coach"] {
|
| 1127 |
+
grid-template-columns: minmax(0, 1.05fr) minmax(0, 0.95fr);
|
| 1128 |
+
max-width: 1280px;
|
| 1129 |
+
gap: 1.25rem;
|
| 1130 |
+
align-items: start;
|
| 1131 |
+
}
|
| 1132 |
+
|
| 1133 |
+
.workspace[data-view="coach"] .coach-panel-wrap,
|
| 1134 |
+
.workspace[data-view="coach"] .coach-debug-card {
|
| 1135 |
+
display: flex;
|
| 1136 |
+
flex-direction: column;
|
| 1137 |
+
}
|
| 1138 |
+
|
| 1139 |
+
.workspace[data-view="coach"] .coach-results-panel {
|
| 1140 |
+
flex: 1;
|
| 1141 |
+
min-height: 120px;
|
| 1142 |
+
margin-top: 0.75rem;
|
| 1143 |
+
overflow-y: auto;
|
| 1144 |
+
}
|
| 1145 |
+
|
| 1146 |
+
.workspace[data-view="coach"] .coach-results-panel:not(:empty) {
|
| 1147 |
+
border-top: 1px solid var(--outline-variant);
|
| 1148 |
+
padding-top: 0.75rem;
|
| 1149 |
+
}
|
| 1150 |
+
|
| 1151 |
+
.workspace[data-view="coach"] .debug-chat-messages {
|
| 1152 |
+
min-height: 140px;
|
| 1153 |
+
max-height: min(240px, 30vh);
|
| 1154 |
+
margin-bottom: 0.5rem;
|
| 1155 |
+
}
|
| 1156 |
+
|
| 1157 |
+
.workspace[data-view="coach"] .coach-debug-compose {
|
| 1158 |
+
padding-top: 0;
|
| 1159 |
+
border-top: none;
|
| 1160 |
+
}
|
| 1161 |
+
|
| 1162 |
+
.workspace[data-view="coach"] .coach-debug-compose textarea {
|
| 1163 |
+
min-height: 3.5rem;
|
| 1164 |
+
resize: vertical;
|
| 1165 |
+
}
|
| 1166 |
+
|
| 1167 |
+
.workspace[data-view="coach"] .coach-debug-card .studio-debug-trace {
|
| 1168 |
+
flex-shrink: 0;
|
| 1169 |
+
margin-top: 0.5rem;
|
| 1170 |
+
}
|
| 1171 |
+
|
| 1172 |
+
.workspace[data-view="coach"] .coach-debug-card .toggle-row,
|
| 1173 |
+
.workspace[data-view="coach"] .coach-debug-card .debug-rag-scope {
|
| 1174 |
+
flex-shrink: 0;
|
| 1175 |
+
}
|
| 1176 |
+
|
| 1177 |
+
.view-coach-only { display: none; }
|
| 1178 |
+
.workspace[data-view="coach"] .view-coach-only:not(.coach-panel-wrap) { display: block; }
|
| 1179 |
+
.workspace[data-view="coach"] .view-debug-only { display: none; }
|
| 1180 |
+
|
| 1181 |
+
.coach-card-head {
|
| 1182 |
+
margin-bottom: 0.85rem;
|
| 1183 |
+
}
|
| 1184 |
+
|
| 1185 |
+
.coach-card-head .section-label {
|
| 1186 |
+
margin-bottom: 0.35rem;
|
| 1187 |
+
}
|
| 1188 |
+
|
| 1189 |
+
.coach-card-desc {
|
| 1190 |
+
margin: 0;
|
| 1191 |
+
font-size: 0.84rem;
|
| 1192 |
+
line-height: 1.45;
|
| 1193 |
+
color: var(--secondary);
|
| 1194 |
+
}
|
| 1195 |
+
|
| 1196 |
+
.coach-capture-row {
|
| 1197 |
+
display: grid;
|
| 1198 |
+
grid-template-columns: minmax(0, 1.2fr) minmax(0, 1fr);
|
| 1199 |
+
gap: 0.75rem;
|
| 1200 |
+
align-items: start;
|
| 1201 |
+
margin-bottom: 0.75rem;
|
| 1202 |
+
padding: 0.75rem;
|
| 1203 |
+
border: 1px solid var(--outline-variant);
|
| 1204 |
+
border-radius: var(--radius-lg);
|
| 1205 |
+
background: var(--surface-container-low);
|
| 1206 |
+
}
|
| 1207 |
+
|
| 1208 |
+
.coach-recording-row {
|
| 1209 |
+
margin: 0;
|
| 1210 |
+
}
|
| 1211 |
+
|
| 1212 |
+
.coach-record-status {
|
| 1213 |
+
margin: 0.35rem 0 0;
|
| 1214 |
+
min-height: 1.25rem;
|
| 1215 |
+
}
|
| 1216 |
+
|
| 1217 |
+
.coach-upload-field {
|
| 1218 |
+
margin: 0;
|
| 1219 |
+
}
|
| 1220 |
+
|
| 1221 |
+
.coach-upload-field input[type="file"] {
|
| 1222 |
+
font-size: 0.78rem;
|
| 1223 |
+
}
|
| 1224 |
+
|
| 1225 |
+
.coach-voiceout-toggle {
|
| 1226 |
+
margin: 0.75rem 0;
|
| 1227 |
+
}
|
| 1228 |
+
|
| 1229 |
+
.coach-analyze-btn {
|
| 1230 |
+
margin-top: 0.25rem;
|
| 1231 |
+
}
|
| 1232 |
+
|
| 1233 |
+
@media (max-width: 960px) {
|
| 1234 |
+
.workspace[data-view="coach"] {
|
| 1235 |
+
grid-template-columns: 1fr;
|
| 1236 |
+
max-width: 640px;
|
| 1237 |
+
margin-left: auto;
|
| 1238 |
+
margin-right: auto;
|
| 1239 |
+
}
|
| 1240 |
+
|
| 1241 |
+
.coach-capture-row {
|
| 1242 |
+
grid-template-columns: 1fr;
|
| 1243 |
+
}
|
| 1244 |
+
}
|
| 1245 |
|
| 1246 |
@media (max-width: 768px) {
|
| 1247 |
:root { --sidebar-w: 0px; }
|
|
|
|
| 1506 |
}
|
| 1507 |
|
| 1508 |
.coach-presets {
|
| 1509 |
+
margin-top: 0;
|
| 1510 |
}
|
| 1511 |
|
| 1512 |
+
.workspace[data-view="debug"] .view-coach-only { display: none; }
|
| 1513 |
+
.workspace[data-view="debug"] .view-debug-only { display: block; }
|
| 1514 |
+
|
| 1515 |
.workspace[data-view="debug"] .col-research,
|
| 1516 |
.workspace[data-view="debug"] .col-slides,
|
| 1517 |
.workspace[data-view="debug"] .col-studio { display: none; }
|
models.yaml
CHANGED
|
@@ -67,9 +67,3 @@ models:
|
|
| 67 |
backend: transformers
|
| 68 |
model_id: ./models/finetuned/minicpm5-1b-lora-merged
|
| 69 |
trust_remote_code: true
|
| 70 |
-
|
| 71 |
-
jepa-ensemble-lesson:
|
| 72 |
-
label: JEPA ensemble (LLM + emb + JEPA) lesson pretrain
|
| 73 |
-
backend: transformers
|
| 74 |
-
model_id: ./models/ensemble/jepa-lesson-pretrain
|
| 75 |
-
trust_remote_code: true
|
|
|
|
| 67 |
backend: transformers
|
| 68 |
model_id: ./models/finetuned/minicpm5-1b-lora-merged
|
| 69 |
trust_remote_code: true
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
packages.txt
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ffmpeg
|
| 2 |
+
libsndfile1
|
pyproject.toml
CHANGED
|
@@ -7,7 +7,6 @@ requires-python = ">=3.12"
|
|
| 7 |
dependencies = [
|
| 8 |
"agent",
|
| 9 |
"echocoach",
|
| 10 |
-
"ensemble",
|
| 11 |
"gradio-space",
|
| 12 |
"inference",
|
| 13 |
"researchmind",
|
|
@@ -23,11 +22,6 @@ finetune = [
|
|
| 23 |
"datasets>=3.0.0",
|
| 24 |
"peft>=0.14.0",
|
| 25 |
]
|
| 26 |
-
ensemble = [
|
| 27 |
-
"accelerate>=1.2.0",
|
| 28 |
-
"peft>=0.14.0",
|
| 29 |
-
"transformers>=5.7.0",
|
| 30 |
-
]
|
| 31 |
evals = [
|
| 32 |
"slm-evals",
|
| 33 |
]
|
|
@@ -39,14 +33,12 @@ lm-eval = [
|
|
| 39 |
members = [
|
| 40 |
"apps/*",
|
| 41 |
"libs/*",
|
| 42 |
-
"research/ensemble",
|
| 43 |
"research/evals",
|
| 44 |
]
|
| 45 |
|
| 46 |
[tool.uv.sources]
|
| 47 |
agent = { workspace = true }
|
| 48 |
echocoach = { workspace = true }
|
| 49 |
-
ensemble = { workspace = true }
|
| 50 |
gradio-space = { workspace = true }
|
| 51 |
inference = { workspace = true }
|
| 52 |
researchmind = { workspace = true }
|
|
|
|
| 7 |
dependencies = [
|
| 8 |
"agent",
|
| 9 |
"echocoach",
|
|
|
|
| 10 |
"gradio-space",
|
| 11 |
"inference",
|
| 12 |
"researchmind",
|
|
|
|
| 22 |
"datasets>=3.0.0",
|
| 23 |
"peft>=0.14.0",
|
| 24 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
evals = [
|
| 26 |
"slm-evals",
|
| 27 |
]
|
|
|
|
| 33 |
members = [
|
| 34 |
"apps/*",
|
| 35 |
"libs/*",
|
|
|
|
| 36 |
"research/evals",
|
| 37 |
]
|
| 38 |
|
| 39 |
[tool.uv.sources]
|
| 40 |
agent = { workspace = true }
|
| 41 |
echocoach = { workspace = true }
|
|
|
|
| 42 |
gradio-space = { workspace = true }
|
| 43 |
inference = { workspace = true }
|
| 44 |
researchmind = { workspace = true }
|
requirements.txt
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Workspace packages (HF clones full repo)
|
| 2 |
+
-e ./libs/inference
|
| 3 |
+
-e ./libs/researchmind
|
| 4 |
+
-e ./libs/agent
|
| 5 |
+
-e ./libs/echocoach[piper,whisper]
|
| 6 |
+
-e ./apps/gradio-space
|
| 7 |
+
|
| 8 |
+
# Pinned runtime deps (do not pin gradio, spaces, or huggingface_hub — HF preinstalls them)
|
| 9 |
+
accelerate==1.13.0
|
| 10 |
+
torch==2.12.0
|
| 11 |
+
torchvision==0.27.0
|
| 12 |
+
transformers==5.10.2
|
| 13 |
+
peft==0.19.1
|
| 14 |
+
llama-cpp-python==0.3.26
|
| 15 |
+
sentence-transformers==5.5.1
|
| 16 |
+
pydantic>=2.0.0
|
| 17 |
+
pyyaml>=6.0.2
|
| 18 |
+
pillow>=10.0.0
|
| 19 |
+
python-pptx>=1.0.0
|
| 20 |
+
python-docx>=1.1.0
|
| 21 |
+
httpx>=0.28.0
|
| 22 |
+
numpy>=2.0.0
|
| 23 |
+
ddgs==9.14.4
|
| 24 |
+
googlesearch-python>=1.3.0
|
| 25 |
+
pypdf>=5.0.0
|
| 26 |
+
trafilatura==2.1.0
|
| 27 |
+
matplotlib==3.11.0
|
| 28 |
+
soundfile>=0.12.0
|
| 29 |
+
sounddevice>=0.5.0
|
| 30 |
+
librosa==0.11.0
|
| 31 |
+
piper-tts==1.4.2
|
| 32 |
+
pywhispercpp==1.5.0
|
research/README.md
CHANGED
|
@@ -1,27 +1,25 @@
|
|
| 1 |
# Research
|
| 2 |
|
| 3 |
-
Experimental code for **fine-tuning**
|
| 4 |
|
| 5 |
| Path | Purpose |
|
| 6 |
| ---- | ------- |
|
| 7 |
| [`finetune.py`](finetune.py) | LoRA / QLoRA / full fine-tune on chat or instruction data |
|
| 8 |
-
| [`ensemble/`](ensemble/) | JEPA + world-model ensemble experiments (uv package `ensemble`) |
|
| 9 |
| [`evals/`](evals/) | SLM agentic benchmark suite — BFCL, τ-bench, GAIA, SWE-bench (uv package `slm-evals`) |
|
| 10 |
-
| [`data/`](data/) | Shared JSONL datasets for finetune and
|
| 11 |
|
| 12 |
## Quick links
|
| 13 |
|
| 14 |
- **[USAGE.md](USAGE.md)** — install groups, commands, and typical workflows
|
| 15 |
- **[docs/overview.md](docs/overview.md)** — how the pieces fit together
|
| 16 |
-
- **[ensemble/README.md](ensemble/README.md)** — ensemble smoke tests and harnesses
|
| 17 |
- **[evals/USAGE.md](evals/USAGE.md)** — benchmark CLI, configs, and results
|
| 18 |
- **[evals/docs/benchmarks.md](evals/docs/benchmarks.md)** — what each benchmark measures
|
| 19 |
|
| 20 |
## Install (from repo root)
|
| 21 |
|
| 22 |
```bash
|
| 23 |
-
#
|
| 24 |
-
uv sync --group finetune --group
|
| 25 |
```
|
| 26 |
|
| 27 |
Individual groups:
|
|
@@ -29,8 +27,8 @@ Individual groups:
|
|
| 29 |
| Group | Command | Enables |
|
| 30 |
| ----- | ------- | ------- |
|
| 31 |
| `finetune` | `uv sync --group finetune` | `research/finetune.py` (LoRA, QLoRA, merge) |
|
| 32 |
-
| `ensemble` | `uv sync --group ensemble` | `research/ensemble/` package |
|
| 33 |
| `evals` | `uv sync --group evals` | `research/evals/` package (`slm-benchmark`) |
|
|
|
|
| 34 |
|
| 35 |
## Typical workflow
|
| 36 |
|
|
@@ -40,9 +38,7 @@ research/data/education-lesson-chat.jsonl
|
|
| 40 |
▼
|
| 41 |
research/finetune.py ──► models/finetuned/<preset>-lora/
|
| 42 |
│
|
| 43 |
-
|
| 44 |
-
│
|
| 45 |
-
└──► research/ensemble/ (JEPA / world-model ablations)
|
| 46 |
```
|
| 47 |
|
| 48 |
See [USAGE.md](USAGE.md) for copy-paste commands.
|
|
|
|
| 1 |
# Research
|
| 2 |
|
| 3 |
+
Experimental code for **fine-tuning** and **agentic benchmarks**. Nothing here is wired into the Gradio Lesson Agent by default — use it to train models and score checkpoints against public benchmarks.
|
| 4 |
|
| 5 |
| Path | Purpose |
|
| 6 |
| ---- | ------- |
|
| 7 |
| [`finetune.py`](finetune.py) | LoRA / QLoRA / full fine-tune on chat or instruction data |
|
|
|
|
| 8 |
| [`evals/`](evals/) | SLM agentic benchmark suite — BFCL, τ-bench, GAIA, SWE-bench (uv package `slm-evals`) |
|
| 9 |
+
| [`data/`](data/) | Shared JSONL datasets for finetune and evals |
|
| 10 |
|
| 11 |
## Quick links
|
| 12 |
|
| 13 |
- **[USAGE.md](USAGE.md)** — install groups, commands, and typical workflows
|
| 14 |
- **[docs/overview.md](docs/overview.md)** — how the pieces fit together
|
|
|
|
| 15 |
- **[evals/USAGE.md](evals/USAGE.md)** — benchmark CLI, configs, and results
|
| 16 |
- **[evals/docs/benchmarks.md](evals/docs/benchmarks.md)** — what each benchmark measures
|
| 17 |
|
| 18 |
## Install (from repo root)
|
| 19 |
|
| 20 |
```bash
|
| 21 |
+
# All research tooling
|
| 22 |
+
uv sync --group finetune --group evals --group lm-eval
|
| 23 |
```
|
| 24 |
|
| 25 |
Individual groups:
|
|
|
|
| 27 |
| Group | Command | Enables |
|
| 28 |
| ----- | ------- | ------- |
|
| 29 |
| `finetune` | `uv sync --group finetune` | `research/finetune.py` (LoRA, QLoRA, merge) |
|
|
|
|
| 30 |
| `evals` | `uv sync --group evals` | `research/evals/` package (`slm-benchmark`) |
|
| 31 |
+
| `lm-eval` | `uv sync --group lm-eval` | `slm-lm-eval` CLI (GSM8K, ARC, HellaSwag, …) |
|
| 32 |
|
| 33 |
## Typical workflow
|
| 34 |
|
|
|
|
| 38 |
▼
|
| 39 |
research/finetune.py ──► models/finetuned/<preset>-lora/
|
| 40 |
│
|
| 41 |
+
└──► research/evals/ (BFCL, τ-bench, GAIA, SWE-bench, lm-eval)
|
|
|
|
|
|
|
| 42 |
```
|
| 43 |
|
| 44 |
See [USAGE.md](USAGE.md) for copy-paste commands.
|
research/USAGE.md
CHANGED
|
@@ -1,24 +1,23 @@
|
|
| 1 |
# Research usage
|
| 2 |
|
| 3 |
-
How to run fine-tuning
|
| 4 |
|
| 5 |
The Lesson Agent app lives in `apps/gradio-space/` — see root [USAGE.md](../USAGE.md). Research code is optional and isolated here.
|
| 6 |
|
| 7 |
## Prerequisites
|
| 8 |
|
| 9 |
- [uv](https://docs.astral.sh/uv/) and Python 3.12
|
| 10 |
-
- GPU recommended for real-model runs (CPU works for smoke tests
|
| 11 |
- Hugging Face Hub access for model downloads and some benchmark datasets
|
| 12 |
|
| 13 |
## Install dependency groups
|
| 14 |
|
| 15 |
```bash
|
| 16 |
# All research tooling
|
| 17 |
-
uv sync --group finetune --group
|
| 18 |
|
| 19 |
# Or one at a time
|
| 20 |
uv sync --group finetune
|
| 21 |
-
uv sync --group ensemble
|
| 22 |
uv sync --group evals
|
| 23 |
uv sync --group lm-eval
|
| 24 |
```
|
|
@@ -26,7 +25,6 @@ uv sync --group lm-eval
|
|
| 26 |
| Group | Package / script | What it adds |
|
| 27 |
| ----- | ---------------- | ------------ |
|
| 28 |
| `finetune` | `research/finetune.py` | `peft`, `datasets`, `bitsandbytes` (QLoRA) |
|
| 29 |
-
| `ensemble` | `ensemble` workspace member | JEPA / world-model ensemble + harnesses |
|
| 30 |
| `evals` | `slm-evals` workspace member | `slm-benchmark` CLI |
|
| 31 |
| `lm-eval` | `slm-evals[lm-eval]` | `slm-lm-eval` CLI (GSM8K, ARC, HellaSwag, …) |
|
| 32 |
|
|
@@ -94,83 +92,7 @@ Training writes to `<out>/` (default `./models/finetuned/<preset>-<mode>/`):
|
|
| 94 |
|
| 95 |
---
|
| 96 |
|
| 97 |
-
## 2.
|
| 98 |
-
|
| 99 |
-
JEPA and world-model ensemble prototypes: small LLM + embedding memory + latent predictors + energy-based draft selection. **Not connected to the Gradio app.**
|
| 100 |
-
|
| 101 |
-
Install: `uv sync --group ensemble`
|
| 102 |
-
|
| 103 |
-
### Tier 1 — CPU smoke (no Hub download)
|
| 104 |
-
|
| 105 |
-
```bash
|
| 106 |
-
uv run --package ensemble python -m ensemble.jepa_ensemble tiny
|
| 107 |
-
uv run --package ensemble python -m ensemble.world_ensemble tiny
|
| 108 |
-
bash research/ensemble/scripts/smoke.sh
|
| 109 |
-
```
|
| 110 |
-
|
| 111 |
-
### Tier 2 — Real small model
|
| 112 |
-
|
| 113 |
-
```bash
|
| 114 |
-
uv run --package ensemble python -m ensemble.jepa_ensemble Qwen/Qwen2.5-0.5B-Instruct
|
| 115 |
-
uv run --package ensemble python -m ensemble.world_ensemble Qwen/Qwen2.5-0.5B-Instruct
|
| 116 |
-
```
|
| 117 |
-
|
| 118 |
-
### Pretrain + save (LLM + emb + JEPA)
|
| 119 |
-
|
| 120 |
-
```bash
|
| 121 |
-
# Default LLM: ENSEMBLE_LLM → LLM_PATH → BASE → MODEL_ID → ACTIVE_MODEL (models.yaml)
|
| 122 |
-
uv run --package ensemble ensemble-pretrain --steps 200
|
| 123 |
-
|
| 124 |
-
# Or override
|
| 125 |
-
uv run --package ensemble ensemble-pretrain \
|
| 126 |
-
--llm Qwen/Qwen2.5-0.5B-Instruct \
|
| 127 |
-
--steps 200
|
| 128 |
-
|
| 129 |
-
# Benchmark saved ensemble with slm-evals (compare to base HF model)
|
| 130 |
-
uv run --package slm-evals slm-benchmark \
|
| 131 |
-
--model ./models/ensemble/jepa-lesson-pretrain \
|
| 132 |
-
--model-type ensemble \
|
| 133 |
-
--benchmarks bfcl tau_bench --max-samples 20
|
| 134 |
-
```
|
| 135 |
-
|
| 136 |
-
Checkpoint files: `manifest.json`, `aux.pt`, `llm/` (PEFT adapters), optional `store.pt`.
|
| 137 |
-
|
| 138 |
-
### Tier 3 — Benchmark harnesses
|
| 139 |
-
|
| 140 |
-
Uses `research/data/benchmark-qa.jsonl` (questions) and `benchmark-kb.jsonl` (retrieval snippets).
|
| 141 |
-
|
| 142 |
-
```bash
|
| 143 |
-
# JEPA track — toy
|
| 144 |
-
uv run --package ensemble python -m ensemble.eval.jepa_harness \
|
| 145 |
-
--llm tiny --toy --limit 20 --n_drafts 8
|
| 146 |
-
|
| 147 |
-
# JEPA track — education QA
|
| 148 |
-
uv run --package ensemble python -m ensemble.eval.jepa_harness \
|
| 149 |
-
--llm Qwen/Qwen2.5-0.5B-Instruct \
|
| 150 |
-
--qa research/data/benchmark-qa.jsonl \
|
| 151 |
-
--kb research/data/benchmark-kb.jsonl \
|
| 152 |
-
--limit 50 --n_drafts 8
|
| 153 |
-
|
| 154 |
-
# World-model track
|
| 155 |
-
uv run --package ensemble python -m ensemble.eval.world_harness \
|
| 156 |
-
--llm tiny --toy --limit 20 --n_drafts 8
|
| 157 |
-
```
|
| 158 |
-
|
| 159 |
-
More detail: [ensemble/README.md](ensemble/README.md), [docs/overview.md](docs/overview.md).
|
| 160 |
-
|
| 161 |
-
### Legacy shims
|
| 162 |
-
|
| 163 |
-
Top-level files re-export the package for old scripts:
|
| 164 |
-
|
| 165 |
-
- `research/llm_emb_jepa_ensemble_pluggable.py` → `ensemble.jepa_ensemble`
|
| 166 |
-
- `research/world_model_ensemble.py` → `ensemble.world_ensemble`
|
| 167 |
-
- `research/eval_harness.py` → `ensemble.eval.jepa_harness`
|
| 168 |
-
|
| 169 |
-
Prefer `uv run --package ensemble python -m ensemble.<module>`.
|
| 170 |
-
|
| 171 |
-
---
|
| 172 |
-
|
| 173 |
-
## 3. Agentic benchmarks (`research/evals/`)
|
| 174 |
|
| 175 |
Evaluate local HuggingFace checkpoints on BFCL, τ-bench, GAIA, and SWE-bench Verified.
|
| 176 |
|
|
@@ -192,9 +114,9 @@ Full reference: [evals/USAGE.md](evals/USAGE.md).
|
|
| 192 |
|
| 193 |
---
|
| 194 |
|
| 195 |
-
##
|
| 196 |
|
| 197 |
-
Standard lm-evaluation-harness tasks (ARC, HellaSwag, GSM8K, …) for base presets, LoRA adapters,
|
| 198 |
|
| 199 |
Install: `uv sync --group lm-eval`
|
| 200 |
|
|
@@ -222,12 +144,6 @@ uv run --package slm-evals slm-lm-eval \
|
|
| 222 |
--preset minicpm5-1b-lesson-lora \
|
| 223 |
--experiment-name minicpm5-1b-lora__v1 \
|
| 224 |
--compare-to results/lm_eval/minicpm5-1b__baseline/results.json
|
| 225 |
-
|
| 226 |
-
# Ensemble checkpoint
|
| 227 |
-
uv run --package slm-evals slm-lm-eval \
|
| 228 |
-
--config research/evals/configs/lm_eval_smoke.yaml \
|
| 229 |
-
--model ./models/ensemble/jepa-lesson-pretrain \
|
| 230 |
-
--experiment-name ensemble-jepa__lm-eval
|
| 231 |
```
|
| 232 |
|
| 233 |
Post-training hook:
|
|
@@ -248,8 +164,8 @@ Full reference: [evals/USAGE.md](evals/USAGE.md#lm-evaluation-harness-slm-lm-eva
|
|
| 248 |
| File | Used by | Format |
|
| 249 |
| ---- | ------- | ------ |
|
| 250 |
| `education-lesson-chat.jsonl` | `finetune.py` default | Chat messages for lesson agent |
|
| 251 |
-
| `benchmark-qa.jsonl` |
|
| 252 |
-
| `benchmark-kb.jsonl` |
|
| 253 |
|
| 254 |
---
|
| 255 |
|
|
@@ -283,18 +199,12 @@ Full reference: [evals/USAGE.md](evals/USAGE.md#lm-evaluation-harness-slm-lm-eva
|
|
| 283 |
--compare-to results/lm_eval/minicpm5-1b__baseline/results.json
|
| 284 |
```
|
| 285 |
|
| 286 |
-
5. **Optional** — probe ensemble ideas on the same QA/KB files:
|
| 287 |
-
```bash
|
| 288 |
-
bash research/ensemble/scripts/smoke.sh
|
| 289 |
-
```
|
| 290 |
-
|
| 291 |
### Verification checklist
|
| 292 |
|
| 293 |
- Use the **same** lm-eval YAML (`tasks`, `num_fewshot`, `limit`, `seed`) for baseline and candidate runs.
|
| 294 |
- Compare lm-eval `results.json` files with `--compare-to`; do not compare `training_results.json` `result_score` to lm-eval accuracy.
|
| 295 |
- For LoRA checkpoints, prefer `--preset minicpm5-1b-lesson-lora` (base + adapter) over passing the adapter dir alone to `--model`.
|
| 296 |
- Report mean ± std only after multiple training seeds; single-seed deltas are indicative, not conclusive.
|
| 297 |
-
- Ensemble `loglikelihood` tasks score the underlying LLM head; generative tasks (`gsm8k`) use the full JEPA+RAG stack.
|
| 298 |
|
| 299 |
---
|
| 300 |
|
|
@@ -302,7 +212,6 @@ Full reference: [evals/USAGE.md](evals/USAGE.md#lm-evaluation-harness-slm-lm-eva
|
|
| 302 |
|
| 303 |
| Symptom | Fix |
|
| 304 |
| ------- | --- |
|
| 305 |
-
| `No module named 'ensemble'` | `uv sync --group ensemble` |
|
| 306 |
| `slm-benchmark: command not found` | `uv sync --group evals` |
|
| 307 |
| `slm-lm-eval: command not found` | `uv sync --group lm-eval` |
|
| 308 |
| CUDA OOM during finetune | Use `--mode qlora` or reduce batch size in script args |
|
|
|
|
| 1 |
# Research usage
|
| 2 |
|
| 3 |
+
How to run fine-tuning and agentic benchmarks under `research/`. All commands assume the **repo root** as the working directory unless noted.
|
| 4 |
|
| 5 |
The Lesson Agent app lives in `apps/gradio-space/` — see root [USAGE.md](../USAGE.md). Research code is optional and isolated here.
|
| 6 |
|
| 7 |
## Prerequisites
|
| 8 |
|
| 9 |
- [uv](https://docs.astral.sh/uv/) and Python 3.12
|
| 10 |
+
- GPU recommended for real-model runs (CPU works for smoke tests)
|
| 11 |
- Hugging Face Hub access for model downloads and some benchmark datasets
|
| 12 |
|
| 13 |
## Install dependency groups
|
| 14 |
|
| 15 |
```bash
|
| 16 |
# All research tooling
|
| 17 |
+
uv sync --group finetune --group evals --group lm-eval
|
| 18 |
|
| 19 |
# Or one at a time
|
| 20 |
uv sync --group finetune
|
|
|
|
| 21 |
uv sync --group evals
|
| 22 |
uv sync --group lm-eval
|
| 23 |
```
|
|
|
|
| 25 |
| Group | Package / script | What it adds |
|
| 26 |
| ----- | ---------------- | ------------ |
|
| 27 |
| `finetune` | `research/finetune.py` | `peft`, `datasets`, `bitsandbytes` (QLoRA) |
|
|
|
|
| 28 |
| `evals` | `slm-evals` workspace member | `slm-benchmark` CLI |
|
| 29 |
| `lm-eval` | `slm-evals[lm-eval]` | `slm-lm-eval` CLI (GSM8K, ARC, HellaSwag, …) |
|
| 30 |
|
|
|
|
| 92 |
|
| 93 |
---
|
| 94 |
|
| 95 |
+
## 2. Agentic benchmarks (`research/evals/`)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
|
| 97 |
Evaluate local HuggingFace checkpoints on BFCL, τ-bench, GAIA, and SWE-bench Verified.
|
| 98 |
|
|
|
|
| 114 |
|
| 115 |
---
|
| 116 |
|
| 117 |
+
## 3. Academic benchmarks (`slm-lm-eval`)
|
| 118 |
|
| 119 |
+
Standard lm-evaluation-harness tasks (ARC, HellaSwag, GSM8K, …) for base presets, LoRA adapters, and merged checkpoints.
|
| 120 |
|
| 121 |
Install: `uv sync --group lm-eval`
|
| 122 |
|
|
|
|
| 144 |
--preset minicpm5-1b-lesson-lora \
|
| 145 |
--experiment-name minicpm5-1b-lora__v1 \
|
| 146 |
--compare-to results/lm_eval/minicpm5-1b__baseline/results.json
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
```
|
| 148 |
|
| 149 |
Post-training hook:
|
|
|
|
| 164 |
| File | Used by | Format |
|
| 165 |
| ---- | ------- | ------ |
|
| 166 |
| `education-lesson-chat.jsonl` | `finetune.py` default | Chat messages for lesson agent |
|
| 167 |
+
| `benchmark-qa.jsonl` | Optional domain QA evals | `question`, `answer`, `domain` |
|
| 168 |
+
| `benchmark-kb.jsonl` | Optional retrieval snippets | KB entries for domain QA |
|
| 169 |
|
| 170 |
---
|
| 171 |
|
|
|
|
| 199 |
--compare-to results/lm_eval/minicpm5-1b__baseline/results.json
|
| 200 |
```
|
| 201 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
### Verification checklist
|
| 203 |
|
| 204 |
- Use the **same** lm-eval YAML (`tasks`, `num_fewshot`, `limit`, `seed`) for baseline and candidate runs.
|
| 205 |
- Compare lm-eval `results.json` files with `--compare-to`; do not compare `training_results.json` `result_score` to lm-eval accuracy.
|
| 206 |
- For LoRA checkpoints, prefer `--preset minicpm5-1b-lesson-lora` (base + adapter) over passing the adapter dir alone to `--model`.
|
| 207 |
- Report mean ± std only after multiple training seeds; single-seed deltas are indicative, not conclusive.
|
|
|
|
| 208 |
|
| 209 |
---
|
| 210 |
|
|
|
|
| 212 |
|
| 213 |
| Symptom | Fix |
|
| 214 |
| ------- | --- |
|
|
|
|
| 215 |
| `slm-benchmark: command not found` | `uv sync --group evals` |
|
| 216 |
| `slm-lm-eval: command not found` | `uv sync --group lm-eval` |
|
| 217 |
| CUDA OOM during finetune | Use `--mode qlora` or reduce batch size in script args |
|
research/docs/overview.md
CHANGED
|
@@ -13,13 +13,12 @@ small-model-hackathon/
|
|
| 13 |
└── research/ ← experiments (this tree)
|
| 14 |
├── finetune.py
|
| 15 |
├── data/
|
| 16 |
-
├── ensemble/ ← uv workspace package
|
| 17 |
└── evals/ ← uv workspace package
|
| 18 |
```
|
| 19 |
|
| 20 |
-
Research code is a **uv workspace sibling** of `apps/*` and `libs/*`. Root `pyproject.toml` declares optional dependency groups (`finetune`, `
|
| 21 |
|
| 22 |
-
##
|
| 23 |
|
| 24 |
### Fine-tuning
|
| 25 |
|
|
@@ -27,38 +26,12 @@ Research code is a **uv workspace sibling** of `apps/*` and `libs/*`. Root `pypr
|
|
| 27 |
|
| 28 |
Outputs land in `models/finetuned/` — you can register a new preset in `models.yaml` pointing at merged weights for the **Well-Tuned** hackathon badge.
|
| 29 |
|
| 30 |
-
###
|
| 31 |
|
| 32 |
-
`research/
|
| 33 |
|
| 34 |
-
``
|
| 35 |
-
|
| 36 |
-
│
|
| 37 |
-
▼
|
| 38 |
-
JEPA encoder ──► latent state
|
| 39 |
-
│
|
| 40 |
-
├──► World model (multi-step latent rollout)
|
| 41 |
-
│
|
| 42 |
-
└──► Energy model (scores LLM draft continuations)
|
| 43 |
-
│
|
| 44 |
-
▼
|
| 45 |
-
Small LLM generates N drafts → pick lowest energy
|
| 46 |
-
```
|
| 47 |
-
|
| 48 |
-
Two entry ensembles:
|
| 49 |
-
|
| 50 |
-
| Module | File | Critic |
|
| 51 |
-
| ------ | ---- | ------ |
|
| 52 |
-
| JEPA track | `ensemble.jepa_ensemble` | JEPA latent prediction |
|
| 53 |
-
| World track | `ensemble.world_ensemble` | Energy model over world-model rollouts |
|
| 54 |
-
|
| 55 |
-
`TinyBackend` runs on CPU with random weights for smoke tests. `HFBackend` loads real Hub models via `transformers` + optional `peft` LoRA banks.
|
| 56 |
-
|
| 57 |
-
Eval harnesses (`ensemble.eval.jepa_harness`, `ensemble.eval.world_harness`) measure draft-selection accuracy on `research/data/benchmark-qa.jsonl` with optional KB retrieval from `benchmark-kb.jsonl`.
|
| 58 |
-
|
| 59 |
-
### Agentic evals
|
| 60 |
-
|
| 61 |
-
`research/evals/` (`slm-evals` package) scores **whole models** on public agent benchmarks — function calling, multi-turn tool use, GAIA tasks, and SWE-bench patches. This complements ensemble harnesses: evals test end-to-end model behavior; ensemble harnesses test internal selection mechanisms on a small custom QA set.
|
| 62 |
|
| 63 |
## Data flow
|
| 64 |
|
|
@@ -79,20 +52,12 @@ flowchart LR
|
|
| 79 |
tau[tau-bench]
|
| 80 |
gaia[GAIA]
|
| 81 |
swe[SWE-bench]
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
subgraph ens [ensemble]
|
| 85 |
-
jepa[JEPA harness]
|
| 86 |
-
world[World harness]
|
| 87 |
end
|
| 88 |
|
| 89 |
lesson --> train
|
| 90 |
train --> ckpt
|
| 91 |
ckpt --> evals
|
| 92 |
-
qa --> jepa
|
| 93 |
-
kb --> jepa
|
| 94 |
-
qa --> world
|
| 95 |
-
kb --> world
|
| 96 |
```
|
| 97 |
|
| 98 |
## When to use which tool
|
|
@@ -101,14 +66,11 @@ flowchart LR
|
|
| 101 |
| ---- | ---- |
|
| 102 |
| Improve lesson slide quality on your data | `finetune.py` + optional eval before/after |
|
| 103 |
| Compare base vs LoRA on public agent tasks | `slm-benchmark` |
|
| 104 |
-
|
|
| 105 |
| Ship in Gradio Space | `apps/gradio-space` only — wire new weights via `models.yaml` |
|
| 106 |
|
| 107 |
-
## Workspace
|
| 108 |
-
|
| 109 |
-
Both subpackages are listed in root `[tool.uv.workspace] members`:
|
| 110 |
|
| 111 |
-
|
| 112 |
-
- `research/evals` → import name `slm_evals`, CLI `slm-benchmark`
|
| 113 |
|
| 114 |
-
Run with `uv run --package
|
|
|
|
| 13 |
└── research/ ← experiments (this tree)
|
| 14 |
├── finetune.py
|
| 15 |
├── data/
|
|
|
|
| 16 |
└── evals/ ← uv workspace package
|
| 17 |
```
|
| 18 |
|
| 19 |
+
Research code is a **uv workspace sibling** of `apps/*` and `libs/*`. Root `pyproject.toml` declares optional dependency groups (`finetune`, `evals`, `lm-eval`) so the Docker Space image does not need to install torch-heavy extras unless you opt in locally.
|
| 20 |
|
| 21 |
+
## Two tracks
|
| 22 |
|
| 23 |
### Fine-tuning
|
| 24 |
|
|
|
|
| 26 |
|
| 27 |
Outputs land in `models/finetuned/` — you can register a new preset in `models.yaml` pointing at merged weights for the **Well-Tuned** hackathon badge.
|
| 28 |
|
| 29 |
+
### Agentic and academic evals
|
| 30 |
|
| 31 |
+
`research/evals/` (`slm-evals` package) scores **whole models** on:
|
| 32 |
|
| 33 |
+
- **Agentic benchmarks** — BFCL, τ-bench, GAIA, SWE-bench (`slm-benchmark`)
|
| 34 |
+
- **Academic benchmarks** — GSM8K, ARC, HellaSwag, etc. via lm-evaluation-harness (`slm-lm-eval`)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
## Data flow
|
| 37 |
|
|
|
|
| 52 |
tau[tau-bench]
|
| 53 |
gaia[GAIA]
|
| 54 |
swe[SWE-bench]
|
| 55 |
+
lmeval[lm-eval tasks]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
end
|
| 57 |
|
| 58 |
lesson --> train
|
| 59 |
train --> ckpt
|
| 60 |
ckpt --> evals
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
```
|
| 62 |
|
| 63 |
## When to use which tool
|
|
|
|
| 66 |
| ---- | ---- |
|
| 67 |
| Improve lesson slide quality on your data | `finetune.py` + optional eval before/after |
|
| 68 |
| Compare base vs LoRA on public agent tasks | `slm-benchmark` |
|
| 69 |
+
| Compare base vs LoRA on academic tasks | `slm-lm-eval` |
|
| 70 |
| Ship in Gradio Space | `apps/gradio-space` only — wire new weights via `models.yaml` |
|
| 71 |
|
| 72 |
+
## Workspace package
|
|
|
|
|
|
|
| 73 |
|
| 74 |
+
`research/evals` is listed in root `[tool.uv.workspace] members` as import name `slm_evals`, CLI `slm-benchmark` and `slm-lm-eval`.
|
|
|
|
| 75 |
|
| 76 |
+
Run with `uv run --package slm-evals ...` from the repo root so uv resolves workspace paths and shared lockfile versions.
|
research/ensemble/README.md
DELETED
|
@@ -1,113 +0,0 @@
|
|
| 1 |
-
# Ensemble research package
|
| 2 |
-
|
| 3 |
-
JEPA and world-model ensemble experiments. Stays under `research/` — not wired into the Gradio agent.
|
| 4 |
-
|
| 5 |
-
See also: [../USAGE.md](../USAGE.md) · [../docs/overview.md](../docs/overview.md)
|
| 6 |
-
|
| 7 |
-
## Install
|
| 8 |
-
|
| 9 |
-
```bash
|
| 10 |
-
uv sync --group ensemble
|
| 11 |
-
```
|
| 12 |
-
|
| 13 |
-
## Tier 1 — Smoke (CPU, no HF download)
|
| 14 |
-
|
| 15 |
-
```bash
|
| 16 |
-
uv run --package ensemble python -m ensemble.jepa_ensemble tiny
|
| 17 |
-
uv run --package ensemble python -m ensemble.world_ensemble tiny
|
| 18 |
-
bash research/ensemble/scripts/smoke.sh
|
| 19 |
-
```
|
| 20 |
-
|
| 21 |
-
## Tier 2 — Micro demo (real small model)
|
| 22 |
-
|
| 23 |
-
```bash
|
| 24 |
-
uv run --package ensemble python -m ensemble.jepa_ensemble Qwen/Qwen2.5-0.5B-Instruct
|
| 25 |
-
uv run --package ensemble python -m ensemble.world_ensemble Qwen/Qwen2.5-0.5B-Instruct
|
| 26 |
-
```
|
| 27 |
-
|
| 28 |
-
## Pretrain + save (LLM + emb + JEPA)
|
| 29 |
-
|
| 30 |
-
Joint training writes a full checkpoint to `models/ensemble/<name>/`:
|
| 31 |
-
|
| 32 |
-
```bash
|
| 33 |
-
# CPU smoke (tiny backend, no HF download)
|
| 34 |
-
uv run --package ensemble ensemble-pretrain \
|
| 35 |
-
--llm tiny --steps 50 --no-kb \
|
| 36 |
-
--out models/ensemble/jepa-smoke
|
| 37 |
-
|
| 38 |
-
# Uses ACTIVE_MODEL / BASE / LLM_PATH from .env + models.yaml by default
|
| 39 |
-
uv run --package ensemble ensemble-pretrain \
|
| 40 |
-
--data research/data/education-lesson-chat.jsonl \
|
| 41 |
-
--kb research/data/benchmark-kb.jsonl \
|
| 42 |
-
--steps 200
|
| 43 |
-
|
| 44 |
-
# Override base LLM explicitly
|
| 45 |
-
uv run --package ensemble ensemble-pretrain \
|
| 46 |
-
--llm Qwen/Qwen2.5-0.5B-Instruct --steps 200
|
| 47 |
-
```
|
| 48 |
-
|
| 49 |
-
Checkpoint layout: `manifest.json`, `aux.pt` (emb/jepa/bridge/router), `llm/` (PEFT adapters).
|
| 50 |
-
|
| 51 |
-
Benchmark the saved ensemble with **slm-evals** (auto-detects `manifest.json`):
|
| 52 |
-
|
| 53 |
-
```bash
|
| 54 |
-
uv run --package slm-evals slm-benchmark \
|
| 55 |
-
--model ./models/ensemble/jepa-lesson-pretrain \
|
| 56 |
-
--model-type ensemble \
|
| 57 |
-
--benchmarks bfcl tau_bench --max-samples 20
|
| 58 |
-
|
| 59 |
-
# Or use the template config
|
| 60 |
-
uv run --package slm-evals slm-benchmark \
|
| 61 |
-
--config research/evals/configs/ensemble_jepa_lesson.yaml
|
| 62 |
-
```
|
| 63 |
-
|
| 64 |
-
Compare against a base HF model by running the same config with `model_type: hf` and `model_path: openbmb/MiniCPM5-1B`.
|
| 65 |
-
|
| 66 |
-
## Tier 3 — Benchmark
|
| 67 |
-
|
| 68 |
-
### JEPA ablation ladder
|
| 69 |
-
|
| 70 |
-
```bash
|
| 71 |
-
# Toy (no download)
|
| 72 |
-
uv run --package ensemble python -m ensemble.eval.jepa_harness \
|
| 73 |
-
--llm tiny --toy --limit 20 --n_drafts 8
|
| 74 |
-
|
| 75 |
-
# Education QA set
|
| 76 |
-
uv run --package ensemble python -m ensemble.eval.jepa_harness \
|
| 77 |
-
--llm Qwen/Qwen2.5-0.5B-Instruct \
|
| 78 |
-
--qa research/data/benchmark-qa.jsonl \
|
| 79 |
-
--kb research/data/benchmark-kb.jsonl \
|
| 80 |
-
--limit 50 --n_drafts 8
|
| 81 |
-
```
|
| 82 |
-
|
| 83 |
-
### World-model energy selector
|
| 84 |
-
|
| 85 |
-
```bash
|
| 86 |
-
uv run --package ensemble python -m ensemble.eval.world_harness \
|
| 87 |
-
--llm tiny --toy --limit 20 --n_drafts 8
|
| 88 |
-
|
| 89 |
-
uv run --package ensemble python -m ensemble.eval.world_harness \
|
| 90 |
-
--llm Qwen/Qwen2.5-0.5B-Instruct \
|
| 91 |
-
--qa research/data/benchmark-qa.jsonl \
|
| 92 |
-
--kb research/data/benchmark-kb.jsonl \
|
| 93 |
-
--limit 50 --n_drafts 8
|
| 94 |
-
```
|
| 95 |
-
|
| 96 |
-
## Layout
|
| 97 |
-
|
| 98 |
-
```
|
| 99 |
-
research/ensemble/
|
| 100 |
-
src/ensemble/
|
| 101 |
-
backends.py # TinyBackend, HFBackend, TinyLLM, HFLLM
|
| 102 |
-
memory.py # Embedder, VectorStore, Router
|
| 103 |
-
jepa.py # JEPA latent predictor
|
| 104 |
-
bridge.py # LLM hidden -> latent alignment
|
| 105 |
-
world_model.py # Latent dynamics + rollout
|
| 106 |
-
energy.py # Energy-based critic
|
| 107 |
-
jepa_ensemble.py # Ensemble (JEPA track)
|
| 108 |
-
world_ensemble.py # WorldEnsemble
|
| 109 |
-
eval/
|
| 110 |
-
metrics.py
|
| 111 |
-
jepa_harness.py
|
| 112 |
-
world_harness.py
|
| 113 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
research/ensemble/pyproject.toml
DELETED
|
@@ -1,16 +0,0 @@
|
|
| 1 |
-
[project]
|
| 2 |
-
name = "ensemble"
|
| 3 |
-
version = "0.1.0"
|
| 4 |
-
description = "JEPA and world-model ensemble research package"
|
| 5 |
-
readme = "README.md"
|
| 6 |
-
requires-python = ">=3.12"
|
| 7 |
-
dependencies = [
|
| 8 |
-
"torch>=2.5.0",
|
| 9 |
-
]
|
| 10 |
-
|
| 11 |
-
[project.scripts]
|
| 12 |
-
ensemble-pretrain = "ensemble.pretrain:main"
|
| 13 |
-
|
| 14 |
-
[build-system]
|
| 15 |
-
requires = ["uv_build>=0.8.13,<0.9.0"]
|
| 16 |
-
build-backend = "uv_build"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
research/ensemble/scripts/smoke.sh
DELETED
|
@@ -1,35 +0,0 @@
|
|
| 1 |
-
#!/usr/bin/env bash
|
| 2 |
-
set -euo pipefail
|
| 3 |
-
|
| 4 |
-
ROOT="$(cd "$(dirname "$0")/../../.." && pwd)"
|
| 5 |
-
cd "$ROOT"
|
| 6 |
-
|
| 7 |
-
echo "== JEPA ensemble demo (tiny) =="
|
| 8 |
-
uv run --package ensemble python -m ensemble.jepa_ensemble tiny
|
| 9 |
-
|
| 10 |
-
echo ""
|
| 11 |
-
echo "== World ensemble demo (tiny) =="
|
| 12 |
-
uv run --package ensemble python -m ensemble.world_ensemble tiny
|
| 13 |
-
|
| 14 |
-
echo ""
|
| 15 |
-
echo "== JEPA harness (toy) =="
|
| 16 |
-
uv run --package ensemble python -m ensemble.eval.jepa_harness \
|
| 17 |
-
--llm tiny --toy --limit 10 --n_drafts 4
|
| 18 |
-
|
| 19 |
-
echo "== Pretrain smoke + checkpoint roundtrip =="
|
| 20 |
-
uv run --package ensemble ensemble-pretrain \
|
| 21 |
-
--llm tiny --steps 20 --no-kb \
|
| 22 |
-
--out models/ensemble/jepa-smoke
|
| 23 |
-
uv run --package ensemble python -c "
|
| 24 |
-
from ensemble.checkpoint import load_checkpoint
|
| 25 |
-
ens = load_checkpoint('models/ensemble/jepa-smoke')
|
| 26 |
-
print('loaded ensemble, adapters:', ens.adapter_names)
|
| 27 |
-
"
|
| 28 |
-
|
| 29 |
-
echo ""
|
| 30 |
-
echo "== World harness (toy) =="
|
| 31 |
-
uv run --package ensemble python -m ensemble.eval.world_harness \
|
| 32 |
-
--llm tiny --toy --limit 10 --n_drafts 4
|
| 33 |
-
|
| 34 |
-
echo ""
|
| 35 |
-
echo "All smoke checks passed."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
research/ensemble/src/ensemble/__init__.py
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
"""Research ensemble package: JEPA and world-model tracks."""
|
| 2 |
-
|
| 3 |
-
__all__ = ["Ensemble", "WorldEnsemble"]
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
def __getattr__(name: str):
|
| 7 |
-
if name == "Ensemble":
|
| 8 |
-
from ensemble.jepa_ensemble import Ensemble
|
| 9 |
-
|
| 10 |
-
return Ensemble
|
| 11 |
-
if name == "WorldEnsemble":
|
| 12 |
-
from ensemble.world_ensemble import WorldEnsemble
|
| 13 |
-
|
| 14 |
-
return WorldEnsemble
|
| 15 |
-
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
research/ensemble/src/ensemble/backends.py
DELETED
|
@@ -1,418 +0,0 @@
|
|
| 1 |
-
"""LLM backends: toy fallbacks and HuggingFace + LoRA loaders."""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
import torch.nn as nn
|
| 7 |
-
import torch.nn.functional as F
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
class LLMBackend(nn.Module):
|
| 11 |
-
"""Contract for JEPA ensemble backends."""
|
| 12 |
-
|
| 13 |
-
vocab_size: int
|
| 14 |
-
hidden_size: int
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
class HFBackend(LLMBackend):
|
| 18 |
-
"""HuggingFace causal LM with PEFT LoRA adapter bank."""
|
| 19 |
-
|
| 20 |
-
def __init__(
|
| 21 |
-
self,
|
| 22 |
-
model_path: str,
|
| 23 |
-
*,
|
| 24 |
-
load_in_4bit: bool = False,
|
| 25 |
-
lora_r: int = 16,
|
| 26 |
-
lora_alpha: int = 32,
|
| 27 |
-
target_modules=("q_proj", "v_proj"),
|
| 28 |
-
device: str | None = None,
|
| 29 |
-
torch_dtype=None,
|
| 30 |
-
):
|
| 31 |
-
super().__init__()
|
| 32 |
-
from peft import LoraConfig, get_peft_model
|
| 33 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 34 |
-
|
| 35 |
-
self.device_ = torch.device(
|
| 36 |
-
device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 37 |
-
)
|
| 38 |
-
|
| 39 |
-
kwargs = {}
|
| 40 |
-
if load_in_4bit:
|
| 41 |
-
from transformers import BitsAndBytesConfig
|
| 42 |
-
|
| 43 |
-
kwargs["quantization_config"] = BitsAndBytesConfig(
|
| 44 |
-
load_in_4bit=True,
|
| 45 |
-
bnb_4bit_compute_dtype=torch.bfloat16,
|
| 46 |
-
bnb_4bit_quant_type="nf4",
|
| 47 |
-
)
|
| 48 |
-
if torch_dtype is not None:
|
| 49 |
-
kwargs["torch_dtype"] = torch_dtype
|
| 50 |
-
|
| 51 |
-
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 52 |
-
if self.tokenizer.pad_token is None:
|
| 53 |
-
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 54 |
-
base = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
|
| 55 |
-
if not load_in_4bit:
|
| 56 |
-
base.to(self.device_)
|
| 57 |
-
|
| 58 |
-
for p in base.parameters():
|
| 59 |
-
p.requires_grad_(False)
|
| 60 |
-
|
| 61 |
-
self._lora_cfg = LoraConfig(
|
| 62 |
-
r=lora_r,
|
| 63 |
-
lora_alpha=lora_alpha,
|
| 64 |
-
lora_dropout=0.05,
|
| 65 |
-
target_modules=list(target_modules),
|
| 66 |
-
task_type="CAUSAL_LM",
|
| 67 |
-
)
|
| 68 |
-
self.model = get_peft_model(base, self._lora_cfg, adapter_name="general")
|
| 69 |
-
self._adapters = {"general"}
|
| 70 |
-
|
| 71 |
-
self.vocab_size = self.model.config.vocab_size
|
| 72 |
-
self.hidden_size = self.model.config.hidden_size
|
| 73 |
-
|
| 74 |
-
def add_adapter(self, name: str):
|
| 75 |
-
if name not in self._adapters:
|
| 76 |
-
self.model.add_adapter(name, self._lora_cfg)
|
| 77 |
-
self._adapters.add(name)
|
| 78 |
-
|
| 79 |
-
def set_adapter(self, name: str):
|
| 80 |
-
self.model.set_adapter(name)
|
| 81 |
-
|
| 82 |
-
def trainable_parameters(self):
|
| 83 |
-
return (p for p in self.model.parameters() if p.requires_grad)
|
| 84 |
-
|
| 85 |
-
def forward(self, ids):
|
| 86 |
-
out = self.model(
|
| 87 |
-
input_ids=ids.to(self.device_), output_hidden_states=True
|
| 88 |
-
)
|
| 89 |
-
return out.logits, out.hidden_states[-1]
|
| 90 |
-
|
| 91 |
-
@torch.no_grad()
|
| 92 |
-
def generate(self, ids, n_new=64, temperature=0.8):
|
| 93 |
-
gen_kwargs: dict = dict(
|
| 94 |
-
input_ids=ids.to(self.device_),
|
| 95 |
-
max_new_tokens=n_new,
|
| 96 |
-
pad_token_id=self.tokenizer.pad_token_id,
|
| 97 |
-
)
|
| 98 |
-
if temperature <= 0:
|
| 99 |
-
gen_kwargs["do_sample"] = False
|
| 100 |
-
else:
|
| 101 |
-
gen_kwargs.update(do_sample=True, temperature=temperature)
|
| 102 |
-
out = self.model.generate(**gen_kwargs)
|
| 103 |
-
return out
|
| 104 |
-
|
| 105 |
-
def encode_text(self, text: str):
|
| 106 |
-
return self.tokenizer(text, return_tensors="pt").input_ids.to(self.device_)
|
| 107 |
-
|
| 108 |
-
def decode(self, ids):
|
| 109 |
-
return self.tokenizer.decode(ids[0], skip_special_tokens=True)
|
| 110 |
-
|
| 111 |
-
@property
|
| 112 |
-
def device(self):
|
| 113 |
-
return self.device_
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
class TinyBackend(LLMBackend):
|
| 117 |
-
"""Toy transformer with LoRA adapters (no transformers dependency)."""
|
| 118 |
-
|
| 119 |
-
VOCAB, D_MODEL, N_LAYERS, N_HEADS, SEQ_LEN, LORA_R = 1000, 128, 2, 4, 32, 8
|
| 120 |
-
|
| 121 |
-
class _LoRALinear(nn.Module):
|
| 122 |
-
def __init__(self, d_in, d_out, r):
|
| 123 |
-
super().__init__()
|
| 124 |
-
self.base = nn.Linear(d_in, d_out)
|
| 125 |
-
self.base.weight.requires_grad_(False)
|
| 126 |
-
self.base.bias.requires_grad_(False)
|
| 127 |
-
self.adapters, self.active, self.r = nn.ModuleDict(), None, r
|
| 128 |
-
|
| 129 |
-
def add_adapter(self, name):
|
| 130 |
-
A = nn.Linear(self.base.in_features, self.r, bias=False)
|
| 131 |
-
B = nn.Linear(self.r, self.base.out_features, bias=False)
|
| 132 |
-
nn.init.zeros_(B.weight)
|
| 133 |
-
self.adapters[name] = nn.Sequential(A, B)
|
| 134 |
-
|
| 135 |
-
def forward(self, x):
|
| 136 |
-
y = self.base(x)
|
| 137 |
-
if self.active and self.active in self.adapters:
|
| 138 |
-
y = y + self.adapters[self.active](x)
|
| 139 |
-
return y
|
| 140 |
-
|
| 141 |
-
class _Block(nn.Module):
|
| 142 |
-
def __init__(self, D, H, R):
|
| 143 |
-
super().__init__()
|
| 144 |
-
L = TinyBackend._LoRALinear
|
| 145 |
-
self.ln1 = nn.LayerNorm(D)
|
| 146 |
-
self.attn = nn.MultiheadAttention(D, H, batch_first=True)
|
| 147 |
-
self.ln2 = nn.LayerNorm(D)
|
| 148 |
-
self.up, self.down = L(D, 4 * D, R), L(4 * D, D, R)
|
| 149 |
-
|
| 150 |
-
def forward(self, x, mask):
|
| 151 |
-
h = self.ln1(x)
|
| 152 |
-
a, _ = self.attn(h, h, h, attn_mask=mask, need_weights=False)
|
| 153 |
-
x = x + a
|
| 154 |
-
return x + self.down(F.gelu(self.up(self.ln2(x))))
|
| 155 |
-
|
| 156 |
-
def __init__(self):
|
| 157 |
-
super().__init__()
|
| 158 |
-
D, V = self.D_MODEL, self.VOCAB
|
| 159 |
-
self.tok = nn.Embedding(V, D)
|
| 160 |
-
self.pos = nn.Embedding(self.SEQ_LEN * 4, D)
|
| 161 |
-
self.blocks = nn.ModuleList(
|
| 162 |
-
[self._Block(D, self.N_HEADS, self.LORA_R) for _ in range(self.N_LAYERS)]
|
| 163 |
-
)
|
| 164 |
-
self.ln_f, self.head = nn.LayerNorm(D), nn.Linear(D, V, bias=False)
|
| 165 |
-
self.vocab_size, self.hidden_size = V, D
|
| 166 |
-
self.add_adapter("general")
|
| 167 |
-
self.set_adapter("general")
|
| 168 |
-
|
| 169 |
-
def add_adapter(self, name):
|
| 170 |
-
for b in self.blocks:
|
| 171 |
-
b.up.add_adapter(name)
|
| 172 |
-
b.down.add_adapter(name)
|
| 173 |
-
|
| 174 |
-
def set_adapter(self, name):
|
| 175 |
-
for b in self.blocks:
|
| 176 |
-
b.up.active = name
|
| 177 |
-
b.down.active = name
|
| 178 |
-
|
| 179 |
-
def trainable_parameters(self):
|
| 180 |
-
return (p for p in self.parameters() if p.requires_grad)
|
| 181 |
-
|
| 182 |
-
def forward(self, ids):
|
| 183 |
-
B, T = ids.shape
|
| 184 |
-
x = self.tok(ids) + self.pos(torch.arange(T, device=ids.device))
|
| 185 |
-
mask = torch.triu(
|
| 186 |
-
torch.full((T, T), float("-inf"), device=ids.device), 1
|
| 187 |
-
)
|
| 188 |
-
for b in self.blocks:
|
| 189 |
-
x = b(x, mask)
|
| 190 |
-
h = self.ln_f(x)
|
| 191 |
-
return self.head(h), h
|
| 192 |
-
|
| 193 |
-
@torch.no_grad()
|
| 194 |
-
def generate(self, ids, n_new=16, temperature=1.0):
|
| 195 |
-
for _ in range(n_new):
|
| 196 |
-
logits, _ = self(ids[:, -self.SEQ_LEN :])
|
| 197 |
-
if temperature <= 0:
|
| 198 |
-
nxt = logits[:, -1].argmax(dim=-1, keepdim=True)
|
| 199 |
-
else:
|
| 200 |
-
nxt = torch.multinomial(
|
| 201 |
-
F.softmax(logits[:, -1] / temperature, -1), 1
|
| 202 |
-
)
|
| 203 |
-
ids = torch.cat([ids, nxt], dim=1)
|
| 204 |
-
return ids
|
| 205 |
-
|
| 206 |
-
def encode_text(self, text: str):
|
| 207 |
-
vals = [ord(c) % self.vocab_size for c in text[: self.SEQ_LEN]]
|
| 208 |
-
if not vals:
|
| 209 |
-
vals = [0]
|
| 210 |
-
return torch.tensor([vals], dtype=torch.long)
|
| 211 |
-
|
| 212 |
-
def decode(self, ids):
|
| 213 |
-
return " ".join(str(int(t)) for t in ids[0].tolist())
|
| 214 |
-
|
| 215 |
-
@property
|
| 216 |
-
def device(self):
|
| 217 |
-
return next(self.parameters()).device
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
def make_backend(llm: str, **kw) -> LLMBackend:
|
| 221 |
-
"""'tiny' -> toy model; anything else -> HF hub id or local path."""
|
| 222 |
-
return TinyBackend() if llm == "tiny" else HFBackend(llm, **kw)
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
def load_hf_backend_from_checkpoint(
|
| 226 |
-
base_llm: str,
|
| 227 |
-
adapter_dir: str | None,
|
| 228 |
-
*,
|
| 229 |
-
adapter_names: tuple[str, ...] = ("general",),
|
| 230 |
-
device: str | None = None,
|
| 231 |
-
load_in_4bit: bool = False,
|
| 232 |
-
lora_r: int = 16,
|
| 233 |
-
lora_alpha: int = 32,
|
| 234 |
-
) -> HFBackend:
|
| 235 |
-
"""Load a frozen base LM + saved PEFT adapters (ensemble checkpoint llm/)."""
|
| 236 |
-
from pathlib import Path
|
| 237 |
-
|
| 238 |
-
from peft import LoraConfig, PeftModel, get_peft_model
|
| 239 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 240 |
-
|
| 241 |
-
def _discover_adapter_dirs(root: Path) -> dict[str, Path]:
|
| 242 |
-
if (root / "adapter_config.json").is_file():
|
| 243 |
-
return {"general": root}
|
| 244 |
-
discovered: dict[str, Path] = {}
|
| 245 |
-
for child in sorted(root.iterdir()):
|
| 246 |
-
if child.is_dir() and (child / "adapter_config.json").is_file():
|
| 247 |
-
discovered[child.name] = child
|
| 248 |
-
return discovered
|
| 249 |
-
|
| 250 |
-
resolved_device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
| 251 |
-
tokenizer = AutoTokenizer.from_pretrained(adapter_dir or base_llm)
|
| 252 |
-
if tokenizer.pad_token is None:
|
| 253 |
-
tokenizer.pad_token = tokenizer.eos_token
|
| 254 |
-
|
| 255 |
-
kwargs: dict = {}
|
| 256 |
-
if load_in_4bit:
|
| 257 |
-
from transformers import BitsAndBytesConfig
|
| 258 |
-
|
| 259 |
-
kwargs["quantization_config"] = BitsAndBytesConfig(
|
| 260 |
-
load_in_4bit=True,
|
| 261 |
-
bnb_4bit_compute_dtype=torch.bfloat16,
|
| 262 |
-
bnb_4bit_quant_type="nf4",
|
| 263 |
-
)
|
| 264 |
-
elif resolved_device != "cpu":
|
| 265 |
-
kwargs["torch_dtype"] = torch.bfloat16
|
| 266 |
-
|
| 267 |
-
base = AutoModelForCausalLM.from_pretrained(base_llm, **kwargs)
|
| 268 |
-
if not load_in_4bit and resolved_device != "cpu":
|
| 269 |
-
base.to(resolved_device)
|
| 270 |
-
for p in base.parameters():
|
| 271 |
-
p.requires_grad_(False)
|
| 272 |
-
|
| 273 |
-
if adapter_dir:
|
| 274 |
-
adapter_dirs = _discover_adapter_dirs(Path(adapter_dir))
|
| 275 |
-
if not adapter_dirs:
|
| 276 |
-
raise ValueError(
|
| 277 |
-
f"No PEFT adapters found under {adapter_dir} "
|
| 278 |
-
"(expected adapter_config.json or <name>/adapter_config.json)"
|
| 279 |
-
)
|
| 280 |
-
preferred = [name for name in adapter_names if name in adapter_dirs]
|
| 281 |
-
load_order = preferred + [
|
| 282 |
-
name for name in adapter_dirs if name not in preferred
|
| 283 |
-
]
|
| 284 |
-
first_name = load_order[0]
|
| 285 |
-
model = PeftModel.from_pretrained(
|
| 286 |
-
base,
|
| 287 |
-
str(adapter_dirs[first_name]),
|
| 288 |
-
adapter_name=first_name,
|
| 289 |
-
is_trainable=False,
|
| 290 |
-
)
|
| 291 |
-
for name in load_order[1:]:
|
| 292 |
-
model.load_adapter(str(adapter_dirs[name]), adapter_name=name)
|
| 293 |
-
adapters = set(load_order)
|
| 294 |
-
else:
|
| 295 |
-
lora_cfg = LoraConfig(
|
| 296 |
-
r=lora_r,
|
| 297 |
-
lora_alpha=lora_alpha,
|
| 298 |
-
lora_dropout=0.05,
|
| 299 |
-
target_modules=["q_proj", "v_proj"],
|
| 300 |
-
task_type="CAUSAL_LM",
|
| 301 |
-
)
|
| 302 |
-
model = get_peft_model(base, lora_cfg, adapter_name="general")
|
| 303 |
-
adapters = {"general"}
|
| 304 |
-
|
| 305 |
-
backend = HFBackend.__new__(HFBackend)
|
| 306 |
-
nn.Module.__init__(backend)
|
| 307 |
-
backend.device_ = torch.device(resolved_device)
|
| 308 |
-
backend.tokenizer = tokenizer
|
| 309 |
-
backend.model = model
|
| 310 |
-
backend._lora_cfg = None
|
| 311 |
-
backend._adapters = adapters
|
| 312 |
-
backend.vocab_size = model.config.vocab_size
|
| 313 |
-
backend.hidden_size = model.config.hidden_size
|
| 314 |
-
if adapter_names:
|
| 315 |
-
backend.set_adapter(adapter_names[0])
|
| 316 |
-
return backend
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
class TinyLLM(nn.Module):
|
| 320 |
-
"""Simpler toy LLM for the world-model track (no adapter bank)."""
|
| 321 |
-
|
| 322 |
-
VOCAB, D, L, H, T = 1000, 128, 2, 4, 32
|
| 323 |
-
|
| 324 |
-
def __init__(self):
|
| 325 |
-
super().__init__()
|
| 326 |
-
self.tok = nn.Embedding(self.VOCAB, self.D)
|
| 327 |
-
self.pos = nn.Embedding(self.T * 4, self.D)
|
| 328 |
-
layer = nn.TransformerEncoderLayer(
|
| 329 |
-
self.D, self.H, 4 * self.D, batch_first=True, norm_first=True
|
| 330 |
-
)
|
| 331 |
-
self.blocks = nn.TransformerEncoder(layer, self.L)
|
| 332 |
-
self.head = nn.Linear(self.D, self.VOCAB, bias=False)
|
| 333 |
-
self.vocab_size, self.hidden_size = self.VOCAB, self.D
|
| 334 |
-
|
| 335 |
-
def forward(self, ids):
|
| 336 |
-
Tn = ids.size(1)
|
| 337 |
-
x = self.tok(ids) + self.pos(torch.arange(Tn, device=ids.device))
|
| 338 |
-
mask = torch.triu(
|
| 339 |
-
torch.full((Tn, Tn), float("-inf"), device=ids.device), 1
|
| 340 |
-
)
|
| 341 |
-
h = self.blocks(x, mask=mask)
|
| 342 |
-
return self.head(h), h
|
| 343 |
-
|
| 344 |
-
@torch.no_grad()
|
| 345 |
-
def generate(self, ids, n_new=16, temperature=1.0):
|
| 346 |
-
for _ in range(n_new):
|
| 347 |
-
logits, _ = self(ids[:, -self.T :])
|
| 348 |
-
nxt = torch.multinomial(
|
| 349 |
-
F.softmax(logits[:, -1] / temperature, -1), 1
|
| 350 |
-
)
|
| 351 |
-
ids = torch.cat([ids, nxt], 1)
|
| 352 |
-
return ids
|
| 353 |
-
|
| 354 |
-
def trainable_parameters(self):
|
| 355 |
-
return self.parameters()
|
| 356 |
-
|
| 357 |
-
@property
|
| 358 |
-
def device(self):
|
| 359 |
-
return next(self.parameters()).device
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
class HFLLM(nn.Module):
|
| 363 |
-
"""Small HF model with single LoRA stack (world-model track)."""
|
| 364 |
-
|
| 365 |
-
def __init__(self, path, lora_r=16):
|
| 366 |
-
super().__init__()
|
| 367 |
-
from peft import LoraConfig, get_peft_model
|
| 368 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 369 |
-
|
| 370 |
-
self.tokenizer = AutoTokenizer.from_pretrained(path)
|
| 371 |
-
if self.tokenizer.pad_token is None:
|
| 372 |
-
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 373 |
-
base = AutoModelForCausalLM.from_pretrained(
|
| 374 |
-
path,
|
| 375 |
-
torch_dtype=torch.bfloat16
|
| 376 |
-
if torch.cuda.is_available()
|
| 377 |
-
else torch.float32,
|
| 378 |
-
device_map="auto" if torch.cuda.is_available() else None,
|
| 379 |
-
)
|
| 380 |
-
for p in base.parameters():
|
| 381 |
-
p.requires_grad_(False)
|
| 382 |
-
cfg = LoraConfig(
|
| 383 |
-
r=lora_r,
|
| 384 |
-
lora_alpha=2 * lora_r,
|
| 385 |
-
lora_dropout=0.05,
|
| 386 |
-
target_modules=["q_proj", "v_proj"],
|
| 387 |
-
task_type="CAUSAL_LM",
|
| 388 |
-
)
|
| 389 |
-
self.model = get_peft_model(base, cfg)
|
| 390 |
-
self.vocab_size = self.model.config.vocab_size
|
| 391 |
-
self.hidden_size = self.model.config.hidden_size
|
| 392 |
-
|
| 393 |
-
def forward(self, ids):
|
| 394 |
-
out = self.model(
|
| 395 |
-
input_ids=ids.to(self.device), output_hidden_states=True
|
| 396 |
-
)
|
| 397 |
-
return out.logits, out.hidden_states[-1]
|
| 398 |
-
|
| 399 |
-
@torch.no_grad()
|
| 400 |
-
def generate(self, ids, n_new=32, temperature=0.8):
|
| 401 |
-
return self.model.generate(
|
| 402 |
-
input_ids=ids.to(self.device),
|
| 403 |
-
max_new_tokens=n_new,
|
| 404 |
-
do_sample=True,
|
| 405 |
-
temperature=temperature,
|
| 406 |
-
pad_token_id=self.tokenizer.pad_token_id,
|
| 407 |
-
)
|
| 408 |
-
|
| 409 |
-
def trainable_parameters(self):
|
| 410 |
-
return (p for p in self.model.parameters() if p.requires_grad)
|
| 411 |
-
|
| 412 |
-
@property
|
| 413 |
-
def device(self):
|
| 414 |
-
return next(self.model.parameters()).device
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
def load_llm(spec: str):
|
| 418 |
-
return TinyLLM() if spec == "tiny" else HFLLM(spec)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
research/ensemble/src/ensemble/bridge.py
DELETED
|
@@ -1,28 +0,0 @@
|
|
| 1 |
-
"""Bridge: align LLM hidden states with JEPA latent space."""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
import torch.nn as nn
|
| 7 |
-
import torch.nn.functional as F
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
class Bridge(nn.Module):
|
| 11 |
-
def __init__(self, d_llm_hidden: int, d_latent: int):
|
| 12 |
-
super().__init__()
|
| 13 |
-
self.proj = nn.Sequential(
|
| 14 |
-
nn.Linear(d_llm_hidden, d_latent),
|
| 15 |
-
nn.GELU(),
|
| 16 |
-
nn.Linear(d_latent, d_latent),
|
| 17 |
-
)
|
| 18 |
-
|
| 19 |
-
def forward(self, llm_hidden):
|
| 20 |
-
return self.proj(llm_hidden.float().mean(dim=1))
|
| 21 |
-
|
| 22 |
-
def info_nce(self, z1, z2, tau=0.07):
|
| 23 |
-
z1, z2 = F.normalize(z1, dim=-1), F.normalize(z2, dim=-1)
|
| 24 |
-
logits = z1 @ z2.t() / tau
|
| 25 |
-
labels = torch.arange(z1.size(0), device=z1.device)
|
| 26 |
-
return 0.5 * (
|
| 27 |
-
F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)
|
| 28 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
research/ensemble/src/ensemble/checkpoint.py
DELETED
|
@@ -1,149 +0,0 @@
|
|
| 1 |
-
"""Save and load JEPA ensemble checkpoints under models/ensemble/."""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
import json
|
| 6 |
-
from pathlib import Path
|
| 7 |
-
from typing import Any
|
| 8 |
-
|
| 9 |
-
import torch
|
| 10 |
-
|
| 11 |
-
from ensemble.backends import TinyBackend, load_hf_backend_from_checkpoint
|
| 12 |
-
from ensemble.jepa_ensemble import Ensemble
|
| 13 |
-
|
| 14 |
-
MANIFEST_FILE = "manifest.json"
|
| 15 |
-
AUX_FILE = "aux.pt"
|
| 16 |
-
STORE_FILE = "store.pt"
|
| 17 |
-
LLM_DIR = "llm"
|
| 18 |
-
TINY_LLM_FILE = "tiny_llm.pt"
|
| 19 |
-
|
| 20 |
-
CHECKPOINT_VERSION = 1
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
def _aux_state_dict(ens: Ensemble) -> dict[str, torch.Tensor]:
|
| 24 |
-
return {
|
| 25 |
-
"emb": ens.emb.state_dict(),
|
| 26 |
-
"jepa": ens.jepa.state_dict(),
|
| 27 |
-
"bridge": ens.bridge.state_dict(),
|
| 28 |
-
"router": ens.router.state_dict(),
|
| 29 |
-
}
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
def _store_payload(ens: Ensemble) -> dict[str, Any]:
|
| 33 |
-
return {
|
| 34 |
-
"keys": [k for k in ens.store.keys],
|
| 35 |
-
"values": [v for v in ens.store.values],
|
| 36 |
-
}
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
def save_checkpoint(
|
| 40 |
-
ens: Ensemble,
|
| 41 |
-
out_dir: str | Path,
|
| 42 |
-
*,
|
| 43 |
-
base_llm: str,
|
| 44 |
-
training_meta: dict[str, Any] | None = None,
|
| 45 |
-
) -> Path:
|
| 46 |
-
"""Persist ensemble (LLM adapters + emb + JEPA + bridge + router + store)."""
|
| 47 |
-
root = Path(out_dir).resolve()
|
| 48 |
-
root.mkdir(parents=True, exist_ok=True)
|
| 49 |
-
|
| 50 |
-
backend = "tiny" if isinstance(ens.llm, TinyBackend) else "hf"
|
| 51 |
-
manifest: dict[str, Any] = {
|
| 52 |
-
"version": CHECKPOINT_VERSION,
|
| 53 |
-
"track": "jepa",
|
| 54 |
-
"backend": backend,
|
| 55 |
-
"base_llm": base_llm,
|
| 56 |
-
"adapter_names": list(ens.adapter_names),
|
| 57 |
-
"d_emb": ens.emb.d_emb,
|
| 58 |
-
"d_jepa": ens.jepa.d_latent,
|
| 59 |
-
"training": training_meta or {},
|
| 60 |
-
}
|
| 61 |
-
|
| 62 |
-
torch.save(_aux_state_dict(ens), root / AUX_FILE)
|
| 63 |
-
store = _store_payload(ens)
|
| 64 |
-
if store["keys"]:
|
| 65 |
-
torch.save(store, root / STORE_FILE)
|
| 66 |
-
|
| 67 |
-
if backend == "hf":
|
| 68 |
-
llm_path = root / LLM_DIR
|
| 69 |
-
llm_path.mkdir(exist_ok=True)
|
| 70 |
-
ens.llm.model.save_pretrained(llm_path)
|
| 71 |
-
ens.llm.tokenizer.save_pretrained(llm_path)
|
| 72 |
-
else:
|
| 73 |
-
torch.save(ens.llm.state_dict(), root / TINY_LLM_FILE)
|
| 74 |
-
|
| 75 |
-
with open(root / MANIFEST_FILE, "w") as f:
|
| 76 |
-
json.dump(manifest, f, indent=2)
|
| 77 |
-
|
| 78 |
-
return root
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
def is_ensemble_checkpoint(path: str | Path) -> bool:
|
| 82 |
-
return (Path(path) / MANIFEST_FILE).is_file()
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
def load_checkpoint(
|
| 86 |
-
ckpt_dir: str | Path,
|
| 87 |
-
*,
|
| 88 |
-
device: str | None = None,
|
| 89 |
-
load_in_4bit: bool = False,
|
| 90 |
-
) -> Ensemble:
|
| 91 |
-
"""Restore a saved JEPA ensemble from models/ensemble/<name>/."""
|
| 92 |
-
root = Path(ckpt_dir).resolve()
|
| 93 |
-
manifest_path = root / MANIFEST_FILE
|
| 94 |
-
if not manifest_path.is_file():
|
| 95 |
-
raise FileNotFoundError(
|
| 96 |
-
f"Not an ensemble checkpoint (missing {MANIFEST_FILE}): {root}"
|
| 97 |
-
)
|
| 98 |
-
|
| 99 |
-
with open(manifest_path) as f:
|
| 100 |
-
manifest = json.load(f)
|
| 101 |
-
|
| 102 |
-
base_llm = manifest["base_llm"]
|
| 103 |
-
backend = manifest.get("backend", "hf")
|
| 104 |
-
adapter_names = tuple(manifest.get("adapter_names", ["general"]))
|
| 105 |
-
d_emb = manifest.get("d_emb", 64)
|
| 106 |
-
d_jepa = manifest.get("d_jepa", 64)
|
| 107 |
-
|
| 108 |
-
if backend == "tiny":
|
| 109 |
-
ens = Ensemble(
|
| 110 |
-
llm="tiny",
|
| 111 |
-
adapter_names=adapter_names,
|
| 112 |
-
d_emb=d_emb,
|
| 113 |
-
d_jepa=d_jepa,
|
| 114 |
-
)
|
| 115 |
-
tiny_state = torch.load(
|
| 116 |
-
root / TINY_LLM_FILE, map_location="cpu", weights_only=True
|
| 117 |
-
)
|
| 118 |
-
ens.llm.load_state_dict(tiny_state)
|
| 119 |
-
else:
|
| 120 |
-
llm_dir = root / LLM_DIR
|
| 121 |
-
llm_backend = load_hf_backend_from_checkpoint(
|
| 122 |
-
base_llm,
|
| 123 |
-
str(llm_dir) if llm_dir.is_dir() else None,
|
| 124 |
-
adapter_names=adapter_names,
|
| 125 |
-
device=device,
|
| 126 |
-
load_in_4bit=load_in_4bit,
|
| 127 |
-
)
|
| 128 |
-
ens = Ensemble(
|
| 129 |
-
llm=base_llm,
|
| 130 |
-
adapter_names=adapter_names,
|
| 131 |
-
d_emb=d_emb,
|
| 132 |
-
d_jepa=d_jepa,
|
| 133 |
-
llm_backend=llm_backend,
|
| 134 |
-
)
|
| 135 |
-
|
| 136 |
-
aux = torch.load(root / AUX_FILE, map_location="cpu", weights_only=True)
|
| 137 |
-
ens.emb.load_state_dict(aux["emb"])
|
| 138 |
-
ens.jepa.load_state_dict(aux["jepa"])
|
| 139 |
-
ens.bridge.load_state_dict(aux["bridge"])
|
| 140 |
-
ens.router.load_state_dict(aux["router"])
|
| 141 |
-
|
| 142 |
-
store_path = root / STORE_FILE
|
| 143 |
-
if store_path.is_file():
|
| 144 |
-
store = torch.load(store_path, map_location="cpu", weights_only=True)
|
| 145 |
-
ens.store.keys = list(store["keys"])
|
| 146 |
-
ens.store.values = list(store["values"])
|
| 147 |
-
|
| 148 |
-
ens.eval()
|
| 149 |
-
return ens
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
research/ensemble/src/ensemble/config.py
DELETED
|
@@ -1,163 +0,0 @@
|
|
| 1 |
-
"""Resolve base LLM for ensemble from .env and models.yaml (same order as finetune)."""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
import os
|
| 6 |
-
import sys
|
| 7 |
-
from pathlib import Path
|
| 8 |
-
|
| 9 |
-
_REPO_ROOT = Path(__file__).resolve().parents[4]
|
| 10 |
-
_FALLBACK_PRESET = "minicpm5-1b"
|
| 11 |
-
|
| 12 |
-
_ENV_LLM_KEYS = (
|
| 13 |
-
"ENSEMBLE_LLM",
|
| 14 |
-
"LLM_PATH",
|
| 15 |
-
"BASE",
|
| 16 |
-
"FINETUNE_MODEL",
|
| 17 |
-
"MODEL_ID",
|
| 18 |
-
)
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
def repo_root() -> Path:
|
| 22 |
-
return _REPO_ROOT
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
def load_dotenv() -> None:
|
| 26 |
-
"""Load KEY=VALUE pairs from repo .env without overriding existing env vars."""
|
| 27 |
-
path = _REPO_ROOT / ".env"
|
| 28 |
-
if not path.is_file():
|
| 29 |
-
return
|
| 30 |
-
for line in path.read_text().splitlines():
|
| 31 |
-
line = line.strip()
|
| 32 |
-
if not line or line.startswith("#") or "=" not in line:
|
| 33 |
-
continue
|
| 34 |
-
key, _, value = line.partition("=")
|
| 35 |
-
key = key.strip()
|
| 36 |
-
value = value.strip().strip('"').strip("'")
|
| 37 |
-
if key:
|
| 38 |
-
os.environ.setdefault(key, value)
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
def _ensure_inference_on_path() -> None:
|
| 42 |
-
libs = _REPO_ROOT / "libs" / "inference" / "src"
|
| 43 |
-
if str(libs) not in sys.path:
|
| 44 |
-
sys.path.insert(0, str(libs))
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
def _is_ensemble_llm_preset(model) -> bool:
|
| 48 |
-
return model.backend == "transformers" and not model.multimodal and bool(
|
| 49 |
-
model.model_id
|
| 50 |
-
)
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
def _llm_from_local_path(raw: str) -> str | None:
|
| 54 |
-
path = Path(raw)
|
| 55 |
-
if not path.is_absolute():
|
| 56 |
-
path = (_REPO_ROOT / path).resolve()
|
| 57 |
-
if path.suffix == ".gguf":
|
| 58 |
-
return None
|
| 59 |
-
if path.is_dir() and (path / "config.json").is_file():
|
| 60 |
-
return str(path)
|
| 61 |
-
if path.is_file():
|
| 62 |
-
return None
|
| 63 |
-
return None
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
def _llm_from_env_paths() -> str | None:
|
| 67 |
-
for key in ("LLM_PATH", "MODEL_PATH"):
|
| 68 |
-
raw = os.environ.get(key)
|
| 69 |
-
if raw:
|
| 70 |
-
resolved = _llm_from_local_path(raw)
|
| 71 |
-
if resolved:
|
| 72 |
-
return resolved
|
| 73 |
-
return None
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
def resolve_llm(
|
| 77 |
-
*,
|
| 78 |
-
llm_arg: str | None = None,
|
| 79 |
-
preset_arg: str | None = None,
|
| 80 |
-
) -> tuple[str, str | None]:
|
| 81 |
-
"""
|
| 82 |
-
Return (hub_id_or_local_path, preset_key) for ensemble HF backends.
|
| 83 |
-
|
| 84 |
-
Priority when llm_arg is None or ``auto``:
|
| 85 |
-
1. ENSEMBLE_LLM, LLM_PATH (local HF dir), BASE, FINETUNE_MODEL, MODEL_ID
|
| 86 |
-
2. MODEL_PATH if it points at a HuggingFace model directory (not .gguf)
|
| 87 |
-
3. ENSEMBLE_PRESET, FINETUNE_PRESET, or ACTIVE_MODEL from models.yaml
|
| 88 |
-
4. First fine-tunable transformers preset (default minicpm5-1b)
|
| 89 |
-
"""
|
| 90 |
-
if llm_arg and llm_arg not in ("auto",):
|
| 91 |
-
return llm_arg, preset_arg
|
| 92 |
-
|
| 93 |
-
for env_name in _ENV_LLM_KEYS:
|
| 94 |
-
raw = os.environ.get(env_name)
|
| 95 |
-
if raw:
|
| 96 |
-
local = _llm_from_local_path(raw)
|
| 97 |
-
return local or raw, preset_arg
|
| 98 |
-
|
| 99 |
-
local = _llm_from_env_paths()
|
| 100 |
-
if local:
|
| 101 |
-
return local, preset_arg
|
| 102 |
-
|
| 103 |
-
_ensure_inference_on_path()
|
| 104 |
-
from inference.config import get_app_config, get_model_config
|
| 105 |
-
|
| 106 |
-
app_config = get_app_config(reload=True)
|
| 107 |
-
preset_key = (
|
| 108 |
-
preset_arg
|
| 109 |
-
or os.environ.get("ENSEMBLE_PRESET")
|
| 110 |
-
or os.environ.get("FINETUNE_PRESET")
|
| 111 |
-
or os.environ.get("ACTIVE_MODEL")
|
| 112 |
-
)
|
| 113 |
-
|
| 114 |
-
if preset_key and preset_key in app_config.models:
|
| 115 |
-
model = get_model_config(preset_key)
|
| 116 |
-
if not _is_ensemble_llm_preset(model):
|
| 117 |
-
preset_key = None
|
| 118 |
-
|
| 119 |
-
if preset_key is None:
|
| 120 |
-
for candidate in (_FALLBACK_PRESET, *app_config.models):
|
| 121 |
-
if candidate not in app_config.models:
|
| 122 |
-
continue
|
| 123 |
-
model = get_model_config(candidate)
|
| 124 |
-
if _is_ensemble_llm_preset(model):
|
| 125 |
-
preset_key = candidate
|
| 126 |
-
break
|
| 127 |
-
|
| 128 |
-
if not preset_key:
|
| 129 |
-
raise SystemExit(
|
| 130 |
-
"No transformers LLM found for ensemble. Pass --llm, set LLM_PATH/BASE/"
|
| 131 |
-
"MODEL_ID in .env, or ACTIVE_MODEL in models.yaml."
|
| 132 |
-
)
|
| 133 |
-
|
| 134 |
-
model = get_model_config(preset_key)
|
| 135 |
-
if not _is_ensemble_llm_preset(model):
|
| 136 |
-
raise SystemExit(
|
| 137 |
-
f"Preset {preset_key!r} cannot back an ensemble "
|
| 138 |
-
f"(backend={model.backend}, multimodal={model.multimodal})."
|
| 139 |
-
)
|
| 140 |
-
return model.model_id, preset_key
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
def default_ensemble_out(preset_key: str | None) -> str:
|
| 144 |
-
label = preset_key or "custom"
|
| 145 |
-
return str((_REPO_ROOT / "models" / "ensemble" / f"{label}-jepa-pretrain").resolve())
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
def resolve_llm_cli(
|
| 149 |
-
llm: str | None,
|
| 150 |
-
*,
|
| 151 |
-
toy: bool = False,
|
| 152 |
-
preset: str | None = None,
|
| 153 |
-
) -> str:
|
| 154 |
-
"""CLI helper: explicit tiny, else .env / models.yaml unless --toy without --llm."""
|
| 155 |
-
if llm == "tiny":
|
| 156 |
-
return "tiny"
|
| 157 |
-
if llm is None or llm == "auto":
|
| 158 |
-
if toy:
|
| 159 |
-
return "tiny"
|
| 160 |
-
load_dotenv()
|
| 161 |
-
resolved, _ = resolve_llm(preset_arg=preset)
|
| 162 |
-
return resolved
|
| 163 |
-
return llm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
research/ensemble/src/ensemble/energy.py
DELETED
|
@@ -1,45 +0,0 @@
|
|
| 1 |
-
"""Energy model: score candidate latents against world state."""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
import torch.nn as nn
|
| 7 |
-
import torch.nn.functional as F
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
class EnergyModel(nn.Module):
|
| 11 |
-
def __init__(self, d_latent: int):
|
| 12 |
-
super().__init__()
|
| 13 |
-
self.net = nn.Sequential(
|
| 14 |
-
nn.Linear(2 * d_latent, 2 * d_latent),
|
| 15 |
-
nn.GELU(),
|
| 16 |
-
nn.Linear(2 * d_latent, d_latent),
|
| 17 |
-
nn.GELU(),
|
| 18 |
-
nn.Linear(d_latent, 1),
|
| 19 |
-
)
|
| 20 |
-
self.d_latent = d_latent
|
| 21 |
-
|
| 22 |
-
def energy(self, s, z):
|
| 23 |
-
return self.net(torch.cat([s, z], -1)).squeeze(-1)
|
| 24 |
-
|
| 25 |
-
def contrastive_loss(self, s, z_pos, z_negs=None, tau=0.5):
|
| 26 |
-
B = s.size(0)
|
| 27 |
-
s_rep = s.unsqueeze(1).expand(B, B, self.d_latent).reshape(
|
| 28 |
-
B * B, self.d_latent
|
| 29 |
-
)
|
| 30 |
-
z_rep = z_pos.unsqueeze(0).expand(B, B, self.d_latent).reshape(
|
| 31 |
-
B * B, self.d_latent
|
| 32 |
-
)
|
| 33 |
-
E = self.energy(s_rep, z_rep).view(B, B)
|
| 34 |
-
if z_negs is not None:
|
| 35 |
-
En = self.energy(
|
| 36 |
-
s.repeat_interleave(z_negs.size(1), 0),
|
| 37 |
-
z_negs.reshape(-1, self.d_latent),
|
| 38 |
-
).view(B, -1)
|
| 39 |
-
E = torch.cat([E, En], dim=1)
|
| 40 |
-
labels = torch.arange(B, device=s.device)
|
| 41 |
-
return F.cross_entropy(-E / tau, labels)
|
| 42 |
-
|
| 43 |
-
@torch.no_grad()
|
| 44 |
-
def rank(self, s, candidates):
|
| 45 |
-
return self.energy(s.expand(candidates.size(0), -1), candidates)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
research/ensemble/src/ensemble/eval/__init__.py
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
"""Evaluation harnesses for JEPA and world-model ensembles."""
|
|
|
|
|
|
research/ensemble/src/ensemble/eval/jepa_harness.py
DELETED
|
@@ -1,266 +0,0 @@
|
|
| 1 |
-
"""Ablation ladder + JEPA best-of-N benchmark for the ensemble."""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
import argparse
|
| 6 |
-
import json
|
| 7 |
-
import random
|
| 8 |
-
import time
|
| 9 |
-
from collections import defaultdict
|
| 10 |
-
|
| 11 |
-
import torch
|
| 12 |
-
import torch.nn.functional as F
|
| 13 |
-
|
| 14 |
-
from ensemble.eval.metrics import em_score, f1_score, paired_bootstrap
|
| 15 |
-
from ensemble.backends import TinyBackend
|
| 16 |
-
from ensemble.checkpoint import load_checkpoint
|
| 17 |
-
from ensemble.config import load_dotenv, resolve_llm_cli
|
| 18 |
-
from ensemble.jepa_ensemble import Ensemble
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
@torch.no_grad()
|
| 22 |
-
def generate_plain(ens, q_ids, n_new):
|
| 23 |
-
ens.llm.set_adapter(ens.adapter_names[0])
|
| 24 |
-
t0 = time.time()
|
| 25 |
-
out = ens.llm.generate(q_ids.to(ens.llm.device), n_new=n_new, temperature=0.7)
|
| 26 |
-
return out[:, q_ids.size(1) :], time.time() - t0
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
@torch.no_grad()
|
| 30 |
-
def generate_config(
|
| 31 |
-
ens, q_ids, n_new, *, use_rag, use_router, use_jepa, n_drafts=1, tau=0.0
|
| 32 |
-
):
|
| 33 |
-
q_emb = ens.emb(q_ids.cpu())
|
| 34 |
-
|
| 35 |
-
if use_router:
|
| 36 |
-
a_idx = ens.router(q_emb).item()
|
| 37 |
-
ens.llm.set_adapter(ens.adapter_names[a_idx])
|
| 38 |
-
else:
|
| 39 |
-
ens.llm.set_adapter(ens.adapter_names[0])
|
| 40 |
-
|
| 41 |
-
ctx = q_ids.cpu()
|
| 42 |
-
if use_rag:
|
| 43 |
-
mems = ens.store.search(q_emb, k=1)
|
| 44 |
-
if mems:
|
| 45 |
-
ctx = torch.cat([mems[0], ctx], dim=1)
|
| 46 |
-
|
| 47 |
-
t0 = time.time()
|
| 48 |
-
if not use_jepa:
|
| 49 |
-
out = ens.llm.generate(
|
| 50 |
-
ctx.to(ens.llm.device), n_new=n_new, temperature=0.7
|
| 51 |
-
)
|
| 52 |
-
return out[:, ctx.size(1) :], time.time() - t0, None
|
| 53 |
-
|
| 54 |
-
z_exp = ens.jepa.predict_next_latent(ctx)
|
| 55 |
-
drafts, scores = [], []
|
| 56 |
-
for _ in range(n_drafts):
|
| 57 |
-
out = ens.llm.generate(
|
| 58 |
-
ctx.to(ens.llm.device), n_new=n_new, temperature=0.9
|
| 59 |
-
)
|
| 60 |
-
new = out[:, ctx.size(1) :].cpu()
|
| 61 |
-
drafts.append(new)
|
| 62 |
-
scores.append(
|
| 63 |
-
F.cosine_similarity(z_exp, ens.jepa.encode(new)).item()
|
| 64 |
-
)
|
| 65 |
-
best = max(range(n_drafts), key=lambda i: scores[i])
|
| 66 |
-
return drafts[best], time.time() - t0, (drafts, scores)
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
def selector_comparison(drafts_scores_gold, decode_fn, rng):
|
| 70 |
-
res = defaultdict(list)
|
| 71 |
-
for drafts, scores, gold in drafts_scores_gold:
|
| 72 |
-
texts = [decode_fn(d) for d in drafts]
|
| 73 |
-
ems = [em_score(t, gold) for t in texts]
|
| 74 |
-
res["first"].append(ems[0])
|
| 75 |
-
res["random"].append(ems[rng.randrange(len(ems))])
|
| 76 |
-
res["jepa"].append(ems[max(range(len(ems)), key=lambda i: scores[i])])
|
| 77 |
-
res["oracle"].append(max(ems))
|
| 78 |
-
return {k: sum(v) / len(v) for k, v in res.items()}, res
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
def load_jsonl(path):
|
| 82 |
-
with open(path) as f:
|
| 83 |
-
return [json.loads(line) for line in f if line.strip()]
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
def make_toy_data(ens, n_qa=20, vocab=None):
|
| 87 |
-
vocab = vocab or ens.llm.vocab_size
|
| 88 |
-
qa, kb = [], []
|
| 89 |
-
for _ in range(n_qa):
|
| 90 |
-
key = torch.randint(0, vocab, (1, 6))
|
| 91 |
-
ans = torch.randint(0, vocab, (1, 4))
|
| 92 |
-
kb.append(torch.cat([key, ans], dim=1))
|
| 93 |
-
qa.append({"q_ids": key, "answer_ids": ans})
|
| 94 |
-
return qa, kb
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
def run(args):
|
| 98 |
-
torch.manual_seed(args.seed)
|
| 99 |
-
rng = random.Random(args.seed)
|
| 100 |
-
|
| 101 |
-
if args.ckpt:
|
| 102 |
-
ens = load_checkpoint(args.ckpt)
|
| 103 |
-
print(f"loaded ensemble checkpoint: {args.ckpt}")
|
| 104 |
-
is_text = not isinstance(ens.llm, TinyBackend)
|
| 105 |
-
else:
|
| 106 |
-
load_dotenv()
|
| 107 |
-
args.llm = resolve_llm_cli(
|
| 108 |
-
args.llm, toy=args.toy, preset=getattr(args, "preset", None)
|
| 109 |
-
)
|
| 110 |
-
print(f"Resolved LLM: {args.llm}")
|
| 111 |
-
ens = Ensemble(llm=args.llm)
|
| 112 |
-
is_text = args.llm != "tiny"
|
| 113 |
-
|
| 114 |
-
if args.toy or not is_text:
|
| 115 |
-
qa, kb = make_toy_data(ens)
|
| 116 |
-
for mem in kb:
|
| 117 |
-
ens.memorize_ids(mem)
|
| 118 |
-
|
| 119 |
-
def to_ids(item):
|
| 120 |
-
return item["q_ids"]
|
| 121 |
-
|
| 122 |
-
def gold_text(item):
|
| 123 |
-
return " ".join(map(str, item["answer_ids"][0].tolist()))
|
| 124 |
-
|
| 125 |
-
def decode(ids):
|
| 126 |
-
return " ".join(map(str, ids[0].tolist()))
|
| 127 |
-
else:
|
| 128 |
-
qa = load_jsonl(args.qa)
|
| 129 |
-
if args.kb:
|
| 130 |
-
for row in load_jsonl(args.kb):
|
| 131 |
-
ens.memorize_text(row["text"])
|
| 132 |
-
|
| 133 |
-
def to_ids(item):
|
| 134 |
-
return ens.llm.encode_text(
|
| 135 |
-
f"Answer briefly.\nQ: {item['question']}\nA:"
|
| 136 |
-
)
|
| 137 |
-
|
| 138 |
-
def gold_text(item):
|
| 139 |
-
return item["answer"]
|
| 140 |
-
|
| 141 |
-
def decode(ids):
|
| 142 |
-
return ens.llm.decode(ids)
|
| 143 |
-
|
| 144 |
-
qa = qa[: args.limit]
|
| 145 |
-
print(
|
| 146 |
-
f"eval set: {len(qa)} questions | store: {len(ens.store.keys)} memories\n"
|
| 147 |
-
)
|
| 148 |
-
|
| 149 |
-
configs = {
|
| 150 |
-
"C1_base": dict(use_rag=False, use_router=False, use_jepa=False),
|
| 151 |
-
"C2_rag": dict(use_rag=True, use_router=False, use_jepa=False),
|
| 152 |
-
"C3_rag_router": dict(use_rag=True, use_router=True, use_jepa=False),
|
| 153 |
-
"C4_full_jepa": dict(
|
| 154 |
-
use_rag=True,
|
| 155 |
-
use_router=True,
|
| 156 |
-
use_jepa=True,
|
| 157 |
-
n_drafts=args.n_drafts,
|
| 158 |
-
),
|
| 159 |
-
}
|
| 160 |
-
|
| 161 |
-
per_q = {}
|
| 162 |
-
summary = {}
|
| 163 |
-
jepa_material = []
|
| 164 |
-
|
| 165 |
-
for name, cfg in configs.items():
|
| 166 |
-
ems, f1s, lats = [], [], []
|
| 167 |
-
for item in qa:
|
| 168 |
-
ids = to_ids(item)
|
| 169 |
-
if name == "C1_base":
|
| 170 |
-
out, dt = generate_plain(ens, ids, args.n_new)
|
| 171 |
-
extra = None
|
| 172 |
-
else:
|
| 173 |
-
out, dt, extra = generate_config(ens, ids, args.n_new, **cfg)
|
| 174 |
-
pred, gold = decode(out), gold_text(item)
|
| 175 |
-
ems.append(em_score(pred, gold))
|
| 176 |
-
f1s.append(f1_score(pred, gold))
|
| 177 |
-
lats.append(dt)
|
| 178 |
-
if name == "C4_full_jepa" and extra is not None:
|
| 179 |
-
jepa_material.append((extra[0], extra[1], gold))
|
| 180 |
-
per_q[name] = ems
|
| 181 |
-
summary[name] = (
|
| 182 |
-
sum(ems) / len(ems),
|
| 183 |
-
sum(f1s) / len(f1s),
|
| 184 |
-
sum(lats) / len(lats),
|
| 185 |
-
)
|
| 186 |
-
|
| 187 |
-
print(f"{'config':<16}{'EM':>8}{'F1':>8}{'lat(s)':>9}")
|
| 188 |
-
for k, (em, f1, lat) in summary.items():
|
| 189 |
-
print(f"{k:<16}{em:>8.3f}{f1:>8.3f}{lat:>9.3f}")
|
| 190 |
-
|
| 191 |
-
print("\ncomponent contributions (paired bootstrap, P(B>A)):")
|
| 192 |
-
ladder = list(configs.keys())
|
| 193 |
-
for a, b in zip(ladder, ladder[1:]):
|
| 194 |
-
d = summary[b][0] - summary[a][0]
|
| 195 |
-
p = paired_bootstrap(per_q[a], per_q[b])
|
| 196 |
-
print(f" {b} - {a}: ΔEM={d:+.3f} P(better)={p:.2f}")
|
| 197 |
-
|
| 198 |
-
if jepa_material:
|
| 199 |
-
sel, sel_per_q = selector_comparison(jepa_material, decode, rng)
|
| 200 |
-
print(
|
| 201 |
-
f"\nbest-of-N selector comparison (same drafts, N={args.n_drafts}):"
|
| 202 |
-
)
|
| 203 |
-
for k in ("first", "random", "jepa", "oracle"):
|
| 204 |
-
print(f" {k:<8}EM={sel[k]:.3f}")
|
| 205 |
-
p = paired_bootstrap(sel_per_q["random"], sel_per_q["jepa"])
|
| 206 |
-
verdict = (
|
| 207 |
-
"JEPA critic WORKS"
|
| 208 |
-
if p > 0.95
|
| 209 |
-
else "inconclusive — critic ~ random"
|
| 210 |
-
)
|
| 211 |
-
print(f" P(jepa > random) = {p:.2f} {verdict}")
|
| 212 |
-
print(f" headroom to oracle: {sel['oracle'] - sel['jepa']:.3f}")
|
| 213 |
-
|
| 214 |
-
if args.continual:
|
| 215 |
-
print(
|
| 216 |
-
"\ncontinual test: accuracy on task-A questions "
|
| 217 |
-
"before vs after adding adapters B and C"
|
| 218 |
-
)
|
| 219 |
-
ems_before = per_q["C3_rag_router"]
|
| 220 |
-
ens.new_task_adapter("task_B")
|
| 221 |
-
ens.new_task_adapter("task_C")
|
| 222 |
-
ems_after = []
|
| 223 |
-
for item in qa:
|
| 224 |
-
out, _, _ = generate_config(
|
| 225 |
-
ens,
|
| 226 |
-
to_ids(item),
|
| 227 |
-
args.n_new,
|
| 228 |
-
use_rag=True,
|
| 229 |
-
use_router=True,
|
| 230 |
-
use_jepa=False,
|
| 231 |
-
)
|
| 232 |
-
ems_after.append(em_score(decode(out), gold_text(item)))
|
| 233 |
-
bt = sum(ems_after) / len(ems_after) - sum(ems_before) / len(
|
| 234 |
-
ems_before
|
| 235 |
-
)
|
| 236 |
-
print(f" backward transfer (≈0 is ideal): {bt:+.3f}")
|
| 237 |
-
|
| 238 |
-
return summary
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
def parse_args():
|
| 242 |
-
p = argparse.ArgumentParser()
|
| 243 |
-
p.add_argument(
|
| 244 |
-
"--llm",
|
| 245 |
-
default=None,
|
| 246 |
-
help="HF id / path, 'tiny', or omit for LLM_PATH / ACTIVE_MODEL from .env",
|
| 247 |
-
)
|
| 248 |
-
p.add_argument("--preset", default=None, help="models.yaml preset override")
|
| 249 |
-
p.add_argument("--qa", default=None, help="jsonl with question/answer")
|
| 250 |
-
p.add_argument("--kb", default=None, help="jsonl with text -> vector store")
|
| 251 |
-
p.add_argument(
|
| 252 |
-
"--ckpt",
|
| 253 |
-
default=None,
|
| 254 |
-
help="saved ensemble directory (models/ensemble/... with manifest.json)",
|
| 255 |
-
)
|
| 256 |
-
p.add_argument("--toy", action="store_true", help="synthetic data smoke test")
|
| 257 |
-
p.add_argument("--limit", type=int, default=100)
|
| 258 |
-
p.add_argument("--n_new", type=int, default=24)
|
| 259 |
-
p.add_argument("--n_drafts", type=int, default=8)
|
| 260 |
-
p.add_argument("--continual", action="store_true")
|
| 261 |
-
p.add_argument("--seed", type=int, default=0)
|
| 262 |
-
return p.parse_args()
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
if __name__ == "__main__":
|
| 266 |
-
run(parse_args())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
research/ensemble/src/ensemble/eval/metrics.py
DELETED
|
@@ -1,42 +0,0 @@
|
|
| 1 |
-
"""QA metrics and paired bootstrap significance."""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
import random
|
| 6 |
-
import re
|
| 7 |
-
import string
|
| 8 |
-
from collections import Counter
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def normalize(s: str) -> str:
|
| 12 |
-
s = s.lower()
|
| 13 |
-
s = "".join(c for c in s if c not in string.punctuation)
|
| 14 |
-
s = re.sub(r"\b(a|an|the)\b", " ", s)
|
| 15 |
-
return " ".join(s.split())
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
def em_score(pred: str, gold: str) -> float:
|
| 19 |
-
return float(normalize(gold) in normalize(pred))
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
def f1_score(pred: str, gold: str) -> float:
|
| 23 |
-
p, g = normalize(pred).split(), normalize(gold).split()
|
| 24 |
-
if not p or not g:
|
| 25 |
-
return float(p == g)
|
| 26 |
-
common = Counter(p) & Counter(g)
|
| 27 |
-
overlap = sum(common.values())
|
| 28 |
-
if overlap == 0:
|
| 29 |
-
return 0.0
|
| 30 |
-
prec, rec = overlap / len(p), overlap / len(g)
|
| 31 |
-
return 2 * prec * rec / (prec + rec)
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
def paired_bootstrap(scores_a, scores_b, iters=2000, seed=0):
|
| 35 |
-
rng = random.Random(seed)
|
| 36 |
-
n, wins = len(scores_a), 0
|
| 37 |
-
for _ in range(iters):
|
| 38 |
-
idx = [rng.randrange(n) for _ in range(n)]
|
| 39 |
-
da = sum(scores_a[i] for i in idx) / n
|
| 40 |
-
db = sum(scores_b[i] for i in idx) / n
|
| 41 |
-
wins += db > da
|
| 42 |
-
return wins / iters
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
research/ensemble/src/ensemble/eval/world_harness.py
DELETED
|
@@ -1,174 +0,0 @@
|
|
| 1 |
-
"""Energy-based draft selector benchmark for the world-model ensemble."""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
import argparse
|
| 6 |
-
import json
|
| 7 |
-
import random
|
| 8 |
-
import time
|
| 9 |
-
from collections import defaultdict
|
| 10 |
-
|
| 11 |
-
import torch
|
| 12 |
-
|
| 13 |
-
from ensemble.eval.metrics import em_score, f1_score, paired_bootstrap
|
| 14 |
-
from ensemble.world_ensemble import WorldEnsemble
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
@torch.no_grad()
|
| 18 |
-
def generate_drafts(ens, q_ids, n_new, n_drafts, use_rag=True):
|
| 19 |
-
q_emb = ens.emb(q_ids.cpu())
|
| 20 |
-
mems = ens.store.search(q_emb, k=1) if use_rag else []
|
| 21 |
-
segments = (mems + [q_ids.cpu()]) if mems else [q_ids.cpu()]
|
| 22 |
-
ctx = torch.cat(segments, dim=1)
|
| 23 |
-
|
| 24 |
-
s = ens.world_state(segments)
|
| 25 |
-
ens.world.rollout(s, horizon=3)
|
| 26 |
-
|
| 27 |
-
drafts, energies = [], []
|
| 28 |
-
t0 = time.time()
|
| 29 |
-
for _ in range(n_drafts):
|
| 30 |
-
out = ens.llm.generate(
|
| 31 |
-
ctx.to(ens.llm.device), n_new=n_new, temperature=0.9
|
| 32 |
-
)
|
| 33 |
-
new = out[:, ctx.size(1) :].cpu()
|
| 34 |
-
drafts.append(new)
|
| 35 |
-
z = ens.jepa.encode(new)
|
| 36 |
-
energies.append(ens.energy.rank(s, z).item())
|
| 37 |
-
return drafts, energies, time.time() - t0
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
def selector_comparison(drafts_energy_gold, decode_fn, rng):
|
| 41 |
-
res = defaultdict(list)
|
| 42 |
-
for drafts, energies, gold in drafts_energy_gold:
|
| 43 |
-
texts = [decode_fn(d) for d in drafts]
|
| 44 |
-
ems = [em_score(t, gold) for t in texts]
|
| 45 |
-
res["first"].append(ems[0])
|
| 46 |
-
res["random"].append(ems[rng.randrange(len(ems))])
|
| 47 |
-
res["energy"].append(
|
| 48 |
-
ems[min(range(len(ems)), key=lambda i: energies[i])]
|
| 49 |
-
)
|
| 50 |
-
res["oracle"].append(max(ems))
|
| 51 |
-
return {k: sum(v) / len(v) for k, v in res.items()}, res
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
def load_jsonl(path):
|
| 55 |
-
with open(path) as f:
|
| 56 |
-
return [json.loads(line) for line in f if line.strip()]
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
def make_toy_data(ens, n_qa=20, vocab=None):
|
| 60 |
-
vocab = vocab or ens.llm.vocab_size
|
| 61 |
-
qa, kb = [], []
|
| 62 |
-
for _ in range(n_qa):
|
| 63 |
-
key = torch.randint(0, vocab, (1, 6))
|
| 64 |
-
ans = torch.randint(0, vocab, (1, 4))
|
| 65 |
-
kb.append(torch.cat([key, ans], dim=1))
|
| 66 |
-
qa.append({"q_ids": key, "answer_ids": ans})
|
| 67 |
-
return qa, kb
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
def run(args):
|
| 71 |
-
from ensemble.config import load_dotenv, resolve_llm_cli
|
| 72 |
-
|
| 73 |
-
torch.manual_seed(args.seed)
|
| 74 |
-
rng = random.Random(args.seed)
|
| 75 |
-
|
| 76 |
-
load_dotenv()
|
| 77 |
-
args.llm = resolve_llm_cli(
|
| 78 |
-
args.llm, toy=args.toy, preset=getattr(args, "preset", None)
|
| 79 |
-
)
|
| 80 |
-
print(f"Resolved LLM: {args.llm}")
|
| 81 |
-
ens = WorldEnsemble(args.llm)
|
| 82 |
-
if args.ckpt:
|
| 83 |
-
state = torch.load(args.ckpt, map_location="cpu")
|
| 84 |
-
ens.load_state_dict(state, strict=False)
|
| 85 |
-
print(f"loaded world ensemble checkpoint: {args.ckpt}")
|
| 86 |
-
|
| 87 |
-
is_text = args.llm != "tiny"
|
| 88 |
-
|
| 89 |
-
if args.toy or not is_text:
|
| 90 |
-
qa, kb = make_toy_data(ens)
|
| 91 |
-
for mem in kb:
|
| 92 |
-
ens.memorize(mem)
|
| 93 |
-
|
| 94 |
-
def to_ids(item):
|
| 95 |
-
return item["q_ids"]
|
| 96 |
-
|
| 97 |
-
def gold_text(item):
|
| 98 |
-
return " ".join(map(str, item["answer_ids"][0].tolist()))
|
| 99 |
-
|
| 100 |
-
def decode(ids):
|
| 101 |
-
return " ".join(map(str, ids[0].tolist()))
|
| 102 |
-
else:
|
| 103 |
-
qa = load_jsonl(args.qa)
|
| 104 |
-
if args.kb:
|
| 105 |
-
for row in load_jsonl(args.kb):
|
| 106 |
-
ids = ens.llm.tokenizer(
|
| 107 |
-
row["text"], return_tensors="pt"
|
| 108 |
-
).input_ids
|
| 109 |
-
ens.memorize(ids)
|
| 110 |
-
|
| 111 |
-
def to_ids(item):
|
| 112 |
-
return ens.llm.tokenizer(
|
| 113 |
-
f"Answer briefly.\nQ: {item['question']}\nA:",
|
| 114 |
-
return_tensors="pt",
|
| 115 |
-
).input_ids
|
| 116 |
-
|
| 117 |
-
def gold_text(item):
|
| 118 |
-
return item["answer"]
|
| 119 |
-
|
| 120 |
-
def decode(ids):
|
| 121 |
-
return ens.llm.tokenizer.decode(ids[0], skip_special_tokens=True)
|
| 122 |
-
|
| 123 |
-
qa = qa[: args.limit]
|
| 124 |
-
print(
|
| 125 |
-
f"eval set: {len(qa)} questions | store: {len(ens.store.keys)} memories\n"
|
| 126 |
-
)
|
| 127 |
-
|
| 128 |
-
material = []
|
| 129 |
-
lats = []
|
| 130 |
-
for item in qa:
|
| 131 |
-
drafts, energies, dt = generate_drafts(
|
| 132 |
-
ens, to_ids(item), args.n_new, args.n_drafts
|
| 133 |
-
)
|
| 134 |
-
material.append((drafts, energies, gold_text(item)))
|
| 135 |
-
lats.append(dt)
|
| 136 |
-
|
| 137 |
-
sel, sel_per_q = selector_comparison(material, decode, rng)
|
| 138 |
-
print(f"best-of-N selector comparison (same drafts, N={args.n_drafts}):")
|
| 139 |
-
for k in ("first", "random", "energy", "oracle"):
|
| 140 |
-
print(f" {k:<8}EM={sel[k]:.3f}")
|
| 141 |
-
p = paired_bootstrap(sel_per_q["random"], sel_per_q["energy"])
|
| 142 |
-
verdict = (
|
| 143 |
-
"Energy critic WORKS"
|
| 144 |
-
if p > 0.95
|
| 145 |
-
else "inconclusive — critic ~ random"
|
| 146 |
-
)
|
| 147 |
-
print(f" P(energy > random) = {p:.2f} {verdict}")
|
| 148 |
-
print(f" headroom to oracle: {sel['oracle'] - sel['energy']:.3f}")
|
| 149 |
-
print(f" mean latency: {sum(lats) / len(lats):.3f}s")
|
| 150 |
-
|
| 151 |
-
return sel
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
def parse_args():
|
| 155 |
-
p = argparse.ArgumentParser()
|
| 156 |
-
p.add_argument(
|
| 157 |
-
"--llm",
|
| 158 |
-
default=None,
|
| 159 |
-
help="HF id / path, 'tiny', or omit for LLM_PATH / ACTIVE_MODEL from .env",
|
| 160 |
-
)
|
| 161 |
-
p.add_argument("--preset", default=None, help="models.yaml preset override")
|
| 162 |
-
p.add_argument("--qa", default=None, help="jsonl with question/answer")
|
| 163 |
-
p.add_argument("--kb", default=None, help="jsonl with text -> vector store")
|
| 164 |
-
p.add_argument("--ckpt", default=None, help="trained world ensemble .pt")
|
| 165 |
-
p.add_argument("--toy", action="store_true", help="synthetic data smoke test")
|
| 166 |
-
p.add_argument("--limit", type=int, default=100)
|
| 167 |
-
p.add_argument("--n_new", type=int, default=24)
|
| 168 |
-
p.add_argument("--n_drafts", type=int, default=8)
|
| 169 |
-
p.add_argument("--seed", type=int, default=0)
|
| 170 |
-
return p.parse_args()
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
if __name__ == "__main__":
|
| 174 |
-
run(parse_args())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
research/ensemble/src/ensemble/eval_harness.py
DELETED
|
@@ -1,309 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
eval_harness.py — Ablation ladder + JEPA best-of-N test for the ensemble
|
| 3 |
-
========================================================================
|
| 4 |
-
Companion to `llm_emb_jepa_ensemble_pluggable.py` (must be importable,
|
| 5 |
-
i.e. in the same directory).
|
| 6 |
-
|
| 7 |
-
What it runs
|
| 8 |
-
------------
|
| 9 |
-
1. ABLATION LADDER on a QA set:
|
| 10 |
-
C1 base LLM alone
|
| 11 |
-
C2 C1 + RAG (embedding retrieval)
|
| 12 |
-
C3 C2 + router/adapters
|
| 13 |
-
C4 C3 + JEPA best-of-N critic
|
| 14 |
-
(C5 = C4 with a bridge-trained checkpoint — just pass --ckpt)
|
| 15 |
-
|
| 16 |
-
2. BEST-OF-N SELECTOR comparison (the decisive JEPA experiment):
|
| 17 |
-
first-sample | random-pick | JEPA-score pick | oracle pick
|
| 18 |
-
All on the SAME N drafts per question, so differences are pure selection.
|
| 19 |
-
|
| 20 |
-
3. CONTINUAL FORGETTING test (optional, --continual):
|
| 21 |
-
accuracy on task A before vs after training adapters for B and C.
|
| 22 |
-
|
| 23 |
-
4. PAIRED BOOTSTRAP significance between any two configs.
|
| 24 |
-
|
| 25 |
-
Usage
|
| 26 |
-
-----
|
| 27 |
-
# Smoke test, no GPU/deps beyond torch (toy backend, synthetic QA):
|
| 28 |
-
python eval_harness.py --llm tiny --toy
|
| 29 |
-
|
| 30 |
-
# Real model + your QA file (jsonl: {"question": ..., "answer": ..., "context": optional}):
|
| 31 |
-
python eval_harness.py --llm Qwen/Qwen2.5-0.5B-Instruct \
|
| 32 |
-
--qa ./domain_qa.jsonl --kb ./knowledge.jsonl --n_drafts 8
|
| 33 |
-
|
| 34 |
-
# With a bridge-trained ensemble checkpoint (C5):
|
| 35 |
-
python eval_harness.py --llm /models/llama-3.2-1b --qa ./qa.jsonl \
|
| 36 |
-
--kb ./kb.jsonl --ckpt ./ensemble_bridge.pt
|
| 37 |
-
|
| 38 |
-
QA file: {"question": str, "answer": str, "domain": optional str}
|
| 39 |
-
KB file: {"text": str} (each line becomes one memory in the vector store)
|
| 40 |
-
"""
|
| 41 |
-
|
| 42 |
-
import argparse
|
| 43 |
-
import json
|
| 44 |
-
import random
|
| 45 |
-
import re
|
| 46 |
-
import string
|
| 47 |
-
import time
|
| 48 |
-
from collections import Counter, defaultdict
|
| 49 |
-
|
| 50 |
-
import torch
|
| 51 |
-
|
| 52 |
-
from llm_emb_jepa_ensemble_pluggable import Ensemble # same directory
|
| 53 |
-
|
| 54 |
-
# ----------------------------------------------------------------------------
|
| 55 |
-
# Metrics: normalized exact match + token F1 (SQuAD-style)
|
| 56 |
-
# ----------------------------------------------------------------------------
|
| 57 |
-
def normalize(s: str) -> str:
|
| 58 |
-
s = s.lower()
|
| 59 |
-
s = "".join(c for c in s if c not in string.punctuation)
|
| 60 |
-
s = re.sub(r"\b(a|an|the)\b", " ", s)
|
| 61 |
-
return " ".join(s.split())
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
def em_score(pred: str, gold: str) -> float:
|
| 65 |
-
return float(normalize(gold) in normalize(pred)) # containment EM
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
def f1_score(pred: str, gold: str) -> float:
|
| 69 |
-
p, g = normalize(pred).split(), normalize(gold).split()
|
| 70 |
-
if not p or not g:
|
| 71 |
-
return float(p == g)
|
| 72 |
-
common = Counter(p) & Counter(g)
|
| 73 |
-
overlap = sum(common.values())
|
| 74 |
-
if overlap == 0:
|
| 75 |
-
return 0.0
|
| 76 |
-
prec, rec = overlap / len(p), overlap / len(g)
|
| 77 |
-
return 2 * prec * rec / (prec + rec)
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
# ----------------------------------------------------------------------------
|
| 81 |
-
# Paired bootstrap: P(config B beats config A)
|
| 82 |
-
# ----------------------------------------------------------------------------
|
| 83 |
-
def paired_bootstrap(scores_a, scores_b, iters=2000, seed=0):
|
| 84 |
-
rng = random.Random(seed)
|
| 85 |
-
n, wins = len(scores_a), 0
|
| 86 |
-
for _ in range(iters):
|
| 87 |
-
idx = [rng.randrange(n) for _ in range(n)]
|
| 88 |
-
da = sum(scores_a[i] for i in idx) / n
|
| 89 |
-
db = sum(scores_b[i] for i in idx) / n
|
| 90 |
-
wins += db > da
|
| 91 |
-
return wins / iters
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
# ----------------------------------------------------------------------------
|
| 95 |
-
# Config runners — each returns per-question dicts
|
| 96 |
-
# ----------------------------------------------------------------------------
|
| 97 |
-
@torch.no_grad()
|
| 98 |
-
def generate_plain(ens, q_ids, n_new):
|
| 99 |
-
"""C1: base adapter, no retrieval, single sample."""
|
| 100 |
-
ens.llm.set_adapter(ens.adapter_names[0])
|
| 101 |
-
t0 = time.time()
|
| 102 |
-
out = ens.llm.generate(q_ids.to(ens.llm.device), n_new=n_new, temperature=0.7)
|
| 103 |
-
return out[:, q_ids.size(1):], time.time() - t0
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
@torch.no_grad()
|
| 107 |
-
def generate_config(ens, q_ids, n_new, *, use_rag, use_router, use_jepa,
|
| 108 |
-
n_drafts=1, tau=0.0):
|
| 109 |
-
"""Unified runner for C2/C3/C4."""
|
| 110 |
-
q_emb = ens.emb(q_ids.cpu())
|
| 111 |
-
|
| 112 |
-
if use_router:
|
| 113 |
-
a_idx = ens.router(q_emb).item()
|
| 114 |
-
ens.llm.set_adapter(ens.adapter_names[a_idx])
|
| 115 |
-
else:
|
| 116 |
-
ens.llm.set_adapter(ens.adapter_names[0])
|
| 117 |
-
|
| 118 |
-
ctx = q_ids.cpu()
|
| 119 |
-
if use_rag:
|
| 120 |
-
mems = ens.store.search(q_emb, k=1)
|
| 121 |
-
if mems:
|
| 122 |
-
ctx = torch.cat([mems[0], ctx], dim=1)
|
| 123 |
-
|
| 124 |
-
t0 = time.time()
|
| 125 |
-
if not use_jepa:
|
| 126 |
-
out = ens.llm.generate(ctx.to(ens.llm.device), n_new=n_new, temperature=0.7)
|
| 127 |
-
return out[:, ctx.size(1):], time.time() - t0, None
|
| 128 |
-
|
| 129 |
-
# JEPA best-of-N: sample drafts, keep the one closest to predicted latent
|
| 130 |
-
z_exp = ens.jepa.predict_next_latent(ctx)
|
| 131 |
-
drafts, scores = [], []
|
| 132 |
-
for _ in range(n_drafts):
|
| 133 |
-
out = ens.llm.generate(ctx.to(ens.llm.device), n_new=n_new, temperature=0.9)
|
| 134 |
-
new = out[:, ctx.size(1):].cpu()
|
| 135 |
-
drafts.append(new)
|
| 136 |
-
scores.append(torch.nn.functional.cosine_similarity(
|
| 137 |
-
z_exp, ens.jepa.encode(new)).item())
|
| 138 |
-
best = max(range(n_drafts), key=lambda i: scores[i])
|
| 139 |
-
return drafts[best], time.time() - t0, (drafts, scores)
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
# ----------------------------------------------------------------------------
|
| 143 |
-
# Best-of-N selector comparison on shared drafts
|
| 144 |
-
# ----------------------------------------------------------------------------
|
| 145 |
-
def selector_comparison(drafts_scores_gold, decode_fn, rng):
|
| 146 |
-
"""drafts_scores_gold: list of (drafts, jepa_scores, gold_answer).
|
| 147 |
-
Returns EM for: first | random | jepa | oracle — all on the SAME drafts."""
|
| 148 |
-
res = defaultdict(list)
|
| 149 |
-
for drafts, scores, gold in drafts_scores_gold:
|
| 150 |
-
texts = [decode_fn(d) for d in drafts]
|
| 151 |
-
ems = [em_score(t, gold) for t in texts]
|
| 152 |
-
res["first"].append(ems[0])
|
| 153 |
-
res["random"].append(ems[rng.randrange(len(ems))])
|
| 154 |
-
res["jepa"].append(ems[max(range(len(ems)), key=lambda i: scores[i])])
|
| 155 |
-
res["oracle"].append(max(ems)) # upper bound of selection
|
| 156 |
-
return {k: sum(v) / len(v) for k, v in res.items()}, res
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
# ----------------------------------------------------------------------------
|
| 160 |
-
# Data loading
|
| 161 |
-
# ----------------------------------------------------------------------------
|
| 162 |
-
def load_jsonl(path):
|
| 163 |
-
with open(path) as f:
|
| 164 |
-
return [json.loads(l) for l in f if l.strip()]
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
def make_toy_data(ens, n_qa=20, vocab=None):
|
| 168 |
-
"""Synthetic QA for the tiny backend: 'answer' token sequence is planted
|
| 169 |
-
in the KB so RAG can genuinely help even with random weights."""
|
| 170 |
-
vocab = vocab or ens.llm.vocab_size
|
| 171 |
-
qa, kb = [], []
|
| 172 |
-
for i in range(n_qa):
|
| 173 |
-
key = torch.randint(0, vocab, (1, 6))
|
| 174 |
-
ans = torch.randint(0, vocab, (1, 4))
|
| 175 |
-
kb.append(torch.cat([key, ans], dim=1)) # memory = key+answer
|
| 176 |
-
qa.append({"q_ids": key, "answer_ids": ans})
|
| 177 |
-
return qa, kb
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
# ----------------------------------------------------------------------------
|
| 181 |
-
# Main evaluation
|
| 182 |
-
# ----------------------------------------------------------------------------
|
| 183 |
-
def run(args):
|
| 184 |
-
torch.manual_seed(args.seed)
|
| 185 |
-
rng = random.Random(args.seed)
|
| 186 |
-
|
| 187 |
-
ens = Ensemble(llm=args.llm)
|
| 188 |
-
if args.ckpt:
|
| 189 |
-
state = torch.load(args.ckpt, map_location="cpu")
|
| 190 |
-
ens.load_state_dict(state, strict=False)
|
| 191 |
-
print(f"loaded ensemble checkpoint: {args.ckpt}")
|
| 192 |
-
|
| 193 |
-
is_text = args.llm != "tiny"
|
| 194 |
-
|
| 195 |
-
# ---- load data and fill the vector store -------------------------------
|
| 196 |
-
if args.toy or not is_text:
|
| 197 |
-
qa, kb = make_toy_data(ens)
|
| 198 |
-
for mem in kb:
|
| 199 |
-
ens.memorize_ids(mem)
|
| 200 |
-
def to_ids(item): return item["q_ids"]
|
| 201 |
-
def gold_of(item): return item["answer_ids"]
|
| 202 |
-
def decode(ids): return " ".join(map(str, ids[0].tolist()))
|
| 203 |
-
def gold_text(item): return decode(item["answer_ids"])
|
| 204 |
-
else:
|
| 205 |
-
qa = load_jsonl(args.qa)
|
| 206 |
-
if args.kb:
|
| 207 |
-
for row in load_jsonl(args.kb):
|
| 208 |
-
ens.memorize_text(row["text"])
|
| 209 |
-
def to_ids(item): return ens.llm.encode_text(
|
| 210 |
-
f"Answer briefly.\nQ: {item['question']}\nA:")
|
| 211 |
-
def gold_text(item): return item["answer"]
|
| 212 |
-
def decode(ids): return ens.llm.decode(ids)
|
| 213 |
-
|
| 214 |
-
qa = qa[: args.limit]
|
| 215 |
-
print(f"eval set: {len(qa)} questions | store: {len(ens.store.keys)} memories\n")
|
| 216 |
-
|
| 217 |
-
# ---- ablation ladder ----------------------------------------------------
|
| 218 |
-
configs = {
|
| 219 |
-
"C1_base": dict(use_rag=False, use_router=False, use_jepa=False),
|
| 220 |
-
"C2_rag": dict(use_rag=True, use_router=False, use_jepa=False),
|
| 221 |
-
"C3_rag_router": dict(use_rag=True, use_router=True, use_jepa=False),
|
| 222 |
-
"C4_full_jepa": dict(use_rag=True, use_router=True, use_jepa=True,
|
| 223 |
-
n_drafts=args.n_drafts),
|
| 224 |
-
}
|
| 225 |
-
|
| 226 |
-
per_q = {} # config -> list of EM scores (for bootstrap)
|
| 227 |
-
summary = {}
|
| 228 |
-
jepa_material = [] # (drafts, scores, gold) for selector comparison
|
| 229 |
-
|
| 230 |
-
for name, cfg in configs.items():
|
| 231 |
-
ems, f1s, lats = [], [], []
|
| 232 |
-
for item in qa:
|
| 233 |
-
ids = to_ids(item)
|
| 234 |
-
if name == "C1_base":
|
| 235 |
-
out, dt = generate_plain(ens, ids, args.n_new)
|
| 236 |
-
extra = None
|
| 237 |
-
else:
|
| 238 |
-
out, dt, extra = generate_config(ens, ids, args.n_new, **cfg)
|
| 239 |
-
pred, gold = decode(out), gold_text(item)
|
| 240 |
-
ems.append(em_score(pred, gold))
|
| 241 |
-
f1s.append(f1_score(pred, gold))
|
| 242 |
-
lats.append(dt)
|
| 243 |
-
if name == "C4_full_jepa" and extra is not None:
|
| 244 |
-
jepa_material.append((extra[0], extra[1], gold))
|
| 245 |
-
per_q[name] = ems
|
| 246 |
-
summary[name] = (sum(ems) / len(ems), sum(f1s) / len(f1s),
|
| 247 |
-
sum(lats) / len(lats))
|
| 248 |
-
|
| 249 |
-
print(f"{'config':<16}{'EM':>8}{'F1':>8}{'lat(s)':>9}")
|
| 250 |
-
for k, (em, f1, lat) in summary.items():
|
| 251 |
-
print(f"{k:<16}{em:>8.3f}{f1:>8.3f}{lat:>9.3f}")
|
| 252 |
-
|
| 253 |
-
# deltas + significance
|
| 254 |
-
print("\ncomponent contributions (paired bootstrap, P(B>A)):")
|
| 255 |
-
ladder = list(configs.keys())
|
| 256 |
-
for a, b in zip(ladder, ladder[1:]):
|
| 257 |
-
d = summary[b][0] - summary[a][0]
|
| 258 |
-
p = paired_bootstrap(per_q[a], per_q[b])
|
| 259 |
-
print(f" {b} - {a}: ΔEM={d:+.3f} P(better)={p:.2f}")
|
| 260 |
-
|
| 261 |
-
# ---- decisive JEPA selector experiment ----------------------------------
|
| 262 |
-
if jepa_material:
|
| 263 |
-
sel, sel_per_q = selector_comparison(jepa_material, decode, rng)
|
| 264 |
-
print("\nbest-of-N selector comparison (same drafts, N="
|
| 265 |
-
f"{args.n_drafts}):")
|
| 266 |
-
for k in ("first", "random", "jepa", "oracle"):
|
| 267 |
-
print(f" {k:<8}EM={sel[k]:.3f}")
|
| 268 |
-
p = paired_bootstrap(sel_per_q["random"], sel_per_q["jepa"])
|
| 269 |
-
print(f" P(jepa > random) = {p:.2f} "
|
| 270 |
-
f"{'JEPA critic WORKS' if p > 0.95 else 'inconclusive — critic ~ random'}")
|
| 271 |
-
gap = sel["oracle"] - sel["jepa"]
|
| 272 |
-
print(f" headroom to oracle: {gap:.3f}")
|
| 273 |
-
|
| 274 |
-
# ---- continual forgetting (optional) ------------------------------------
|
| 275 |
-
if args.continual:
|
| 276 |
-
print("\ncontinual test: accuracy on task-A questions "
|
| 277 |
-
"before vs after adding adapters B and C")
|
| 278 |
-
ems_before = per_q["C3_rag_router"]
|
| 279 |
-
ens.new_task_adapter("task_B")
|
| 280 |
-
ens.new_task_adapter("task_C")
|
| 281 |
-
ems_after = []
|
| 282 |
-
for item in qa:
|
| 283 |
-
out, _, _ = generate_config(ens, to_ids(item), args.n_new,
|
| 284 |
-
use_rag=True, use_router=True,
|
| 285 |
-
use_jepa=False)
|
| 286 |
-
ems_after.append(em_score(decode(out), gold_text(item)))
|
| 287 |
-
bt = sum(ems_after) / len(ems_after) - sum(ems_before) / len(ems_before)
|
| 288 |
-
print(f" backward transfer (≈0 is ideal): {bt:+.3f}")
|
| 289 |
-
|
| 290 |
-
return summary
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
def parse_args():
|
| 294 |
-
p = argparse.ArgumentParser()
|
| 295 |
-
p.add_argument("--llm", default="tiny", help="'tiny' | HF id | local path")
|
| 296 |
-
p.add_argument("--qa", default=None, help="jsonl with question/answer")
|
| 297 |
-
p.add_argument("--kb", default=None, help="jsonl with text -> vector store")
|
| 298 |
-
p.add_argument("--ckpt", default=None, help="bridge-trained ensemble .pt (C5)")
|
| 299 |
-
p.add_argument("--toy", action="store_true", help="synthetic data smoke test")
|
| 300 |
-
p.add_argument("--limit", type=int, default=100)
|
| 301 |
-
p.add_argument("--n_new", type=int, default=24)
|
| 302 |
-
p.add_argument("--n_drafts", type=int, default=8)
|
| 303 |
-
p.add_argument("--continual", action="store_true")
|
| 304 |
-
p.add_argument("--seed", type=int, default=0)
|
| 305 |
-
return p.parse_args()
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
if __name__ == "__main__":
|
| 309 |
-
run(parse_args())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
research/ensemble/src/ensemble/jepa.py
DELETED
|
@@ -1,75 +0,0 @@
|
|
| 1 |
-
"""JEPA latent predictor with EMA target encoder."""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
import copy
|
| 6 |
-
|
| 7 |
-
import torch
|
| 8 |
-
import torch.nn as nn
|
| 9 |
-
import torch.nn.functional as F
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
class _SegEncoder(nn.Module):
|
| 13 |
-
def __init__(self, vocab_size, d):
|
| 14 |
-
super().__init__()
|
| 15 |
-
self.tok = nn.Embedding(vocab_size, d)
|
| 16 |
-
self.enc = nn.GRU(d, d, batch_first=True)
|
| 17 |
-
self.out = nn.Linear(d, d)
|
| 18 |
-
|
| 19 |
-
def forward(self, ids):
|
| 20 |
-
h, _ = self.enc(self.tok(ids))
|
| 21 |
-
return self.out(h.mean(dim=1))
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
class JEPA(nn.Module):
|
| 25 |
-
def __init__(self, vocab_size: int, d_latent: int = 64, ema_m: float = 0.996):
|
| 26 |
-
super().__init__()
|
| 27 |
-
self.ctx_enc = _SegEncoder(vocab_size, d_latent)
|
| 28 |
-
self.tgt_enc = copy.deepcopy(self.ctx_enc)
|
| 29 |
-
for p in self.tgt_enc.parameters():
|
| 30 |
-
p.requires_grad_(False)
|
| 31 |
-
self.predictor = nn.Sequential(
|
| 32 |
-
nn.Linear(d_latent, 2 * d_latent),
|
| 33 |
-
nn.GELU(),
|
| 34 |
-
nn.Linear(2 * d_latent, d_latent),
|
| 35 |
-
)
|
| 36 |
-
self.m = ema_m
|
| 37 |
-
self.d_latent = d_latent
|
| 38 |
-
|
| 39 |
-
@property
|
| 40 |
-
def enc(self):
|
| 41 |
-
"""Alias used by world-model track."""
|
| 42 |
-
return self.ctx_enc
|
| 43 |
-
|
| 44 |
-
@property
|
| 45 |
-
def tgt(self):
|
| 46 |
-
return self.tgt_enc
|
| 47 |
-
|
| 48 |
-
@property
|
| 49 |
-
def pred(self):
|
| 50 |
-
return self.predictor
|
| 51 |
-
|
| 52 |
-
@torch.no_grad()
|
| 53 |
-
def ema_update(self):
|
| 54 |
-
for p_t, p_c in zip(self.tgt_enc.parameters(), self.ctx_enc.parameters()):
|
| 55 |
-
p_t.mul_(self.m).add_(p_c.detach(), alpha=1 - self.m)
|
| 56 |
-
|
| 57 |
-
def ema(self):
|
| 58 |
-
"""Alias used by world-model track."""
|
| 59 |
-
self.ema_update()
|
| 60 |
-
|
| 61 |
-
def loss(self, seg_ctx, seg_tgt):
|
| 62 |
-
z_hat = self.predictor(self.ctx_enc(seg_ctx))
|
| 63 |
-
with torch.no_grad():
|
| 64 |
-
z_tgt = self.tgt_enc(seg_tgt)
|
| 65 |
-
pred = F.mse_loss(z_hat, z_tgt)
|
| 66 |
-
var_reg = F.relu(1.0 - z_hat.std(dim=0)).mean()
|
| 67 |
-
return pred + 0.5 * var_reg
|
| 68 |
-
|
| 69 |
-
@torch.no_grad()
|
| 70 |
-
def predict_next_latent(self, seg_ctx):
|
| 71 |
-
return self.predictor(self.ctx_enc(seg_ctx))
|
| 72 |
-
|
| 73 |
-
@torch.no_grad()
|
| 74 |
-
def encode(self, seg):
|
| 75 |
-
return self.tgt_enc(seg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
research/ensemble/src/ensemble/jepa_ensemble.py
DELETED
|
@@ -1,232 +0,0 @@
|
|
| 1 |
-
"""JEPA ensemble: route -> retrieve -> generate -> JEPA-verify."""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
import torch.nn as nn
|
| 7 |
-
import torch.nn.functional as F
|
| 8 |
-
|
| 9 |
-
from ensemble.backends import HFBackend, make_backend
|
| 10 |
-
from ensemble.bridge import Bridge
|
| 11 |
-
from ensemble.jepa import JEPA
|
| 12 |
-
from ensemble.memory import Embedder, Router, VectorStore
|
| 13 |
-
|
| 14 |
-
torch.manual_seed(0)
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
class Ensemble(nn.Module):
|
| 18 |
-
def __init__(
|
| 19 |
-
self,
|
| 20 |
-
llm: str = "tiny",
|
| 21 |
-
adapter_names=("general",),
|
| 22 |
-
d_emb: int = 64,
|
| 23 |
-
d_jepa: int = 64,
|
| 24 |
-
llm_backend: HFBackend | None = None,
|
| 25 |
-
**backend_kw,
|
| 26 |
-
):
|
| 27 |
-
super().__init__()
|
| 28 |
-
self.llm = llm_backend if llm_backend is not None else make_backend(llm, **backend_kw)
|
| 29 |
-
V, H = self.llm.vocab_size, self.llm.hidden_size
|
| 30 |
-
|
| 31 |
-
self.emb = Embedder(V, d_emb)
|
| 32 |
-
self.jepa = JEPA(V, d_jepa)
|
| 33 |
-
self.bridge = Bridge(H, d_jepa)
|
| 34 |
-
self.store = VectorStore()
|
| 35 |
-
|
| 36 |
-
self.adapter_names = list(adapter_names)
|
| 37 |
-
for n in self.adapter_names:
|
| 38 |
-
self.llm.add_adapter(n)
|
| 39 |
-
self.llm.set_adapter(self.adapter_names[0])
|
| 40 |
-
self.router = Router(d_emb, len(self.adapter_names))
|
| 41 |
-
|
| 42 |
-
@torch.no_grad()
|
| 43 |
-
def answer_ids(
|
| 44 |
-
self,
|
| 45 |
-
query_ids,
|
| 46 |
-
n_new=32,
|
| 47 |
-
tau_consistency=0.0,
|
| 48 |
-
max_retries=2,
|
| 49 |
-
temperature: float = 0.7,
|
| 50 |
-
):
|
| 51 |
-
q_emb = self.emb(query_ids.cpu())
|
| 52 |
-
a_idx = self.router(q_emb).item()
|
| 53 |
-
self.llm.set_adapter(self.adapter_names[a_idx])
|
| 54 |
-
|
| 55 |
-
mems = self.store.search(q_emb, k=1)
|
| 56 |
-
ctx = (
|
| 57 |
-
torch.cat([mems[0], query_ids.cpu()], dim=1)
|
| 58 |
-
if mems
|
| 59 |
-
else query_ids.cpu()
|
| 60 |
-
)
|
| 61 |
-
|
| 62 |
-
z_expected = self.jepa.predict_next_latent(ctx)
|
| 63 |
-
|
| 64 |
-
best = None
|
| 65 |
-
for attempt in range(max_retries + 1):
|
| 66 |
-
temp = temperature if attempt == 0 else max(temperature, 0.8 + 0.3 * attempt)
|
| 67 |
-
draft = self.llm.generate(
|
| 68 |
-
ctx.to(self.llm.device),
|
| 69 |
-
n_new=n_new,
|
| 70 |
-
temperature=temp,
|
| 71 |
-
)
|
| 72 |
-
new_part = draft[:, ctx.size(1) :].cpu()
|
| 73 |
-
score = F.cosine_similarity(
|
| 74 |
-
z_expected, self.jepa.encode(new_part)
|
| 75 |
-
).item()
|
| 76 |
-
if best is None or score > best[1]:
|
| 77 |
-
best = (draft, score, attempt)
|
| 78 |
-
if score >= tau_consistency:
|
| 79 |
-
break
|
| 80 |
-
draft, score, attempt = best
|
| 81 |
-
return draft, score, self.adapter_names[a_idx], attempt
|
| 82 |
-
|
| 83 |
-
def answer_text(self, prompt: str, **kw):
|
| 84 |
-
ids = self.llm.encode_text(prompt)
|
| 85 |
-
out, score, adapter, retries = self.answer_ids(ids, **kw)
|
| 86 |
-
return self.llm.decode(out), score, adapter, retries
|
| 87 |
-
|
| 88 |
-
def generate_text(
|
| 89 |
-
self,
|
| 90 |
-
prompt: str,
|
| 91 |
-
*,
|
| 92 |
-
max_new_tokens: int = 512,
|
| 93 |
-
temperature: float = 0.0,
|
| 94 |
-
) -> str:
|
| 95 |
-
"""Greedy or sampled generation through the full ensemble stack."""
|
| 96 |
-
ids = self.llm.encode_text(prompt)
|
| 97 |
-
out, _, _, _ = self.answer_ids(
|
| 98 |
-
ids,
|
| 99 |
-
n_new=max_new_tokens,
|
| 100 |
-
tau_consistency=-1.0,
|
| 101 |
-
max_retries=0 if temperature <= 0 else 1,
|
| 102 |
-
temperature=temperature,
|
| 103 |
-
)
|
| 104 |
-
return self.llm.decode(out)
|
| 105 |
-
|
| 106 |
-
def memorize_ids(self, ids):
|
| 107 |
-
self.store.add(self.emb(ids.cpu()), ids.cpu())
|
| 108 |
-
|
| 109 |
-
def memorize_text(self, text: str):
|
| 110 |
-
self.memorize_ids(self.llm.encode_text(text))
|
| 111 |
-
|
| 112 |
-
def new_task_adapter(self, name: str):
|
| 113 |
-
self.adapter_names.append(name)
|
| 114 |
-
self.llm.add_adapter(name)
|
| 115 |
-
old = self.router
|
| 116 |
-
self.router = Router(self.emb.d_emb, len(self.adapter_names))
|
| 117 |
-
with torch.no_grad():
|
| 118 |
-
self.router.fc.weight[: old.fc.out_features] = old.fc.weight
|
| 119 |
-
self.router.fc.bias[: old.fc.out_features] = old.fc.bias
|
| 120 |
-
|
| 121 |
-
def train_step(self, seg_a, seg_b, opt, w_bridge=0.1):
|
| 122 |
-
logits, hidden = self.llm(seg_a.to(self.llm.device))
|
| 123 |
-
lm_loss = F.cross_entropy(
|
| 124 |
-
logits[:, :-1].reshape(-1, self.llm.vocab_size).float(),
|
| 125 |
-
seg_a[:, 1:].reshape(-1).to(logits.device),
|
| 126 |
-
)
|
| 127 |
-
|
| 128 |
-
jepa_loss = self.jepa.loss(seg_a.cpu(), seg_b.cpu())
|
| 129 |
-
|
| 130 |
-
z_llm = self.bridge(
|
| 131 |
-
hidden.cpu() if hidden.device.type != "cpu" else hidden
|
| 132 |
-
)
|
| 133 |
-
z_jepa = self.jepa.ctx_enc(seg_a.cpu()).detach()
|
| 134 |
-
bridge_loss = self.bridge.info_nce(z_llm, z_jepa.to(z_llm.device))
|
| 135 |
-
|
| 136 |
-
loss = lm_loss.cpu() + jepa_loss + w_bridge * bridge_loss
|
| 137 |
-
opt.zero_grad()
|
| 138 |
-
loss.backward()
|
| 139 |
-
opt.step()
|
| 140 |
-
self.jepa.ema_update()
|
| 141 |
-
return {
|
| 142 |
-
"lm": lm_loss.item(),
|
| 143 |
-
"jepa": jepa_loss.item(),
|
| 144 |
-
"bridge": bridge_loss.item(),
|
| 145 |
-
}
|
| 146 |
-
|
| 147 |
-
def make_optimizer(self, lr_lora=2e-4, lr_aux=1e-3):
|
| 148 |
-
return torch.optim.AdamW(
|
| 149 |
-
[
|
| 150 |
-
{"params": list(self.llm.trainable_parameters()), "lr": lr_lora},
|
| 151 |
-
{
|
| 152 |
-
"params": list(self.jepa.ctx_enc.parameters())
|
| 153 |
-
+ list(self.jepa.predictor.parameters()),
|
| 154 |
-
"lr": lr_aux,
|
| 155 |
-
},
|
| 156 |
-
{
|
| 157 |
-
"params": list(self.bridge.parameters())
|
| 158 |
-
+ list(self.emb.parameters())
|
| 159 |
-
+ list(self.router.parameters()),
|
| 160 |
-
"lr": lr_aux,
|
| 161 |
-
},
|
| 162 |
-
]
|
| 163 |
-
)
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
def segment_pairs_from_texts(backend: HFBackend, texts, seg_len=64):
|
| 167 |
-
a_list, b_list = [], []
|
| 168 |
-
for t in texts:
|
| 169 |
-
ids = backend.tokenizer(t, return_tensors="pt").input_ids[0]
|
| 170 |
-
for i in range(0, len(ids) - 2 * seg_len, seg_len):
|
| 171 |
-
a_list.append(ids[i : i + seg_len])
|
| 172 |
-
b_list.append(ids[i + seg_len : i + 2 * seg_len])
|
| 173 |
-
if not a_list:
|
| 174 |
-
raise ValueError("texts too short for the chosen seg_len")
|
| 175 |
-
return torch.stack(a_list), torch.stack(b_list)
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
def demo_tiny(steps=50):
|
| 179 |
-
ens = Ensemble(llm="tiny")
|
| 180 |
-
opt = ens.make_optimizer()
|
| 181 |
-
for s in range(steps):
|
| 182 |
-
seg_a = torch.randint(0, ens.llm.vocab_size, (8, 32))
|
| 183 |
-
seg_b = torch.randint(0, ens.llm.vocab_size, (8, 32))
|
| 184 |
-
logs = ens.train_step(seg_a, seg_b, opt)
|
| 185 |
-
if s % 10 == 0:
|
| 186 |
-
print(
|
| 187 |
-
f"step {s:3d} | "
|
| 188 |
-
+ " | ".join(f"{k} {v:.3f}" for k, v in logs.items())
|
| 189 |
-
)
|
| 190 |
-
|
| 191 |
-
for _ in range(5):
|
| 192 |
-
ens.memorize_ids(torch.randint(0, ens.llm.vocab_size, (1, 32)))
|
| 193 |
-
ens.new_task_adapter("medical")
|
| 194 |
-
|
| 195 |
-
q = torch.randint(0, ens.llm.vocab_size, (1, 8))
|
| 196 |
-
out, score, adapter, retries = ens.answer_ids(q, tau_consistency=-1.0)
|
| 197 |
-
print(f"\nadapter={adapter} jepa_consistency={score:.3f} retries={retries}")
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
def demo_hf(model_path="Qwen/Qwen2.5-0.5B-Instruct"):
|
| 201 |
-
ens = Ensemble(llm=model_path, load_in_4bit=False)
|
| 202 |
-
opt = ens.make_optimizer()
|
| 203 |
-
|
| 204 |
-
texts = ["Replace this with your real corpus. " * 50]
|
| 205 |
-
seg_a, seg_b = segment_pairs_from_texts(ens.llm, texts, seg_len=32)
|
| 206 |
-
for s in range(10):
|
| 207 |
-
logs = ens.train_step(seg_a[:4], seg_b[:4], opt)
|
| 208 |
-
print(f"step {s} | " + " | ".join(f"{k} {v:.3f}" for k, v in logs.items()))
|
| 209 |
-
|
| 210 |
-
ens.memorize_text("The project codename is AURORA and it ships in Q3.")
|
| 211 |
-
ens.new_task_adapter("project_aurora")
|
| 212 |
-
|
| 213 |
-
text, score, adapter, retries = ens.answer_text(
|
| 214 |
-
"What is the project codename?", n_new=24, tau_consistency=-1.0
|
| 215 |
-
)
|
| 216 |
-
print(f"\n[{adapter} | jepa={score:.3f} | retries={retries}]\n{text}")
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
if __name__ == "__main__":
|
| 220 |
-
import sys
|
| 221 |
-
|
| 222 |
-
from ensemble.config import load_dotenv, resolve_llm
|
| 223 |
-
|
| 224 |
-
load_dotenv()
|
| 225 |
-
arg = sys.argv[1] if len(sys.argv) > 1 else None
|
| 226 |
-
if arg is None or arg == "auto":
|
| 227 |
-
arg, preset = resolve_llm()
|
| 228 |
-
print(f"Resolved LLM: {arg} (preset {preset})")
|
| 229 |
-
if arg == "tiny":
|
| 230 |
-
demo_tiny()
|
| 231 |
-
else:
|
| 232 |
-
demo_hf(arg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
research/ensemble/src/ensemble/llm_emb_jepa_ensemble_pluggable.py
DELETED
|
@@ -1,507 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
LLM + Embedding + JEPA Ensemble — pluggable base-model edition
|
| 3 |
-
==============================================================
|
| 4 |
-
Now the LLM is a swappable BACKEND. Three ways to load it:
|
| 5 |
-
|
| 6 |
-
# 1. HuggingFace Hub id
|
| 7 |
-
ens = Ensemble(llm="Qwen/Qwen2.5-0.5B-Instruct")
|
| 8 |
-
|
| 9 |
-
# 2. Local path (e.g. downloaded Llama / converted checkpoint)
|
| 10 |
-
ens = Ensemble(llm="/models/llama-3.2-1b")
|
| 11 |
-
|
| 12 |
-
# 3. Toy fallback (no transformers needed, runs on CPU in seconds)
|
| 13 |
-
ens = Ensemble(llm="tiny")
|
| 14 |
-
|
| 15 |
-
Requirements for real models:
|
| 16 |
-
pip install torch transformers peft accelerate
|
| 17 |
-
(optional 4-bit: pip install bitsandbytes -> load_in_4bit=True)
|
| 18 |
-
|
| 19 |
-
Everything else (Embedder, JEPA, Bridge, VectorStore, Router, the
|
| 20 |
-
JEPA-critic inference loop, continual-learning hooks) only touches
|
| 21 |
-
token ids / hidden states / latents, so it works with ANY backend.
|
| 22 |
-
"""
|
| 23 |
-
|
| 24 |
-
from __future__ import annotations
|
| 25 |
-
import copy
|
| 26 |
-
import torch
|
| 27 |
-
import torch.nn as nn
|
| 28 |
-
import torch.nn.functional as F
|
| 29 |
-
|
| 30 |
-
torch.manual_seed(0)
|
| 31 |
-
|
| 32 |
-
# ----------------------------------------------------------------------------
|
| 33 |
-
# 0. Backend interface — everything the ensemble needs from "an LLM"
|
| 34 |
-
# ----------------------------------------------------------------------------
|
| 35 |
-
class LLMBackend(nn.Module):
|
| 36 |
-
"""Contract:
|
| 37 |
-
vocab_size : int
|
| 38 |
-
hidden_size: int
|
| 39 |
-
device : torch.device
|
| 40 |
-
forward(ids) -> (logits [B,T,V], hidden [B,T,H])
|
| 41 |
-
generate(ids, n_new) -> ids [B, T+n_new]
|
| 42 |
-
add_adapter(name) / set_adapter(name)
|
| 43 |
-
trainable_parameters() -> iterable of params to optimize
|
| 44 |
-
encode_text(str) / decode(ids) (real backends only)
|
| 45 |
-
"""
|
| 46 |
-
vocab_size: int
|
| 47 |
-
hidden_size: int
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
# ----------------------------------------------------------------------------
|
| 51 |
-
# 0a. HuggingFace backend (local path OR hub id) with PEFT LoRA adapters
|
| 52 |
-
# ----------------------------------------------------------------------------
|
| 53 |
-
class HFBackend(LLMBackend):
|
| 54 |
-
def __init__(self, model_path: str, *, load_in_4bit: bool = False,
|
| 55 |
-
lora_r: int = 16, lora_alpha: int = 32,
|
| 56 |
-
target_modules=("q_proj", "v_proj"),
|
| 57 |
-
device: str | None = None, torch_dtype=None):
|
| 58 |
-
super().__init__()
|
| 59 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 60 |
-
from peft import LoraConfig, get_peft_model
|
| 61 |
-
|
| 62 |
-
self.device_ = torch.device(
|
| 63 |
-
device or ("cuda" if torch.cuda.is_available() else "cpu"))
|
| 64 |
-
|
| 65 |
-
kwargs = {}
|
| 66 |
-
if load_in_4bit:
|
| 67 |
-
from transformers import BitsAndBytesConfig
|
| 68 |
-
kwargs["quantization_config"] = BitsAndBytesConfig(
|
| 69 |
-
load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16,
|
| 70 |
-
bnb_4bit_quant_type="nf4")
|
| 71 |
-
if torch_dtype is not None:
|
| 72 |
-
kwargs["torch_dtype"] = torch_dtype
|
| 73 |
-
|
| 74 |
-
# `model_path` may be "Qwen/Qwen2.5-0.5B-Instruct", "meta-llama/...",
|
| 75 |
-
# or a local directory like "/models/llama-3.2-1b".
|
| 76 |
-
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 77 |
-
if self.tokenizer.pad_token is None:
|
| 78 |
-
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 79 |
-
base = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
|
| 80 |
-
if not load_in_4bit:
|
| 81 |
-
base.to(self.device_)
|
| 82 |
-
|
| 83 |
-
# Freeze the base; all learning happens in LoRA adapters.
|
| 84 |
-
for p in base.parameters():
|
| 85 |
-
p.requires_grad_(False)
|
| 86 |
-
|
| 87 |
-
self._lora_cfg = LoraConfig(
|
| 88 |
-
r=lora_r, lora_alpha=lora_alpha, lora_dropout=0.05,
|
| 89 |
-
target_modules=list(target_modules), task_type="CAUSAL_LM")
|
| 90 |
-
self.model = get_peft_model(base, self._lora_cfg, adapter_name="general")
|
| 91 |
-
self._adapters = {"general"}
|
| 92 |
-
|
| 93 |
-
self.vocab_size = self.model.config.vocab_size
|
| 94 |
-
self.hidden_size = self.model.config.hidden_size
|
| 95 |
-
|
| 96 |
-
# ---- adapters -----------------------------------------------------------
|
| 97 |
-
def add_adapter(self, name: str):
|
| 98 |
-
if name not in self._adapters:
|
| 99 |
-
self.model.add_adapter(name, self._lora_cfg)
|
| 100 |
-
self._adapters.add(name)
|
| 101 |
-
|
| 102 |
-
def set_adapter(self, name: str):
|
| 103 |
-
self.model.set_adapter(name)
|
| 104 |
-
|
| 105 |
-
def trainable_parameters(self):
|
| 106 |
-
return (p for p in self.model.parameters() if p.requires_grad)
|
| 107 |
-
|
| 108 |
-
# ---- core ops -----------------------------------------------------------
|
| 109 |
-
def forward(self, ids):
|
| 110 |
-
out = self.model(input_ids=ids.to(self.device_),
|
| 111 |
-
output_hidden_states=True)
|
| 112 |
-
return out.logits, out.hidden_states[-1] # last layer hidden
|
| 113 |
-
|
| 114 |
-
@torch.no_grad()
|
| 115 |
-
def generate(self, ids, n_new=64, temperature=0.8):
|
| 116 |
-
out = self.model.generate(
|
| 117 |
-
input_ids=ids.to(self.device_),
|
| 118 |
-
max_new_tokens=n_new, do_sample=True, temperature=temperature,
|
| 119 |
-
pad_token_id=self.tokenizer.pad_token_id)
|
| 120 |
-
return out
|
| 121 |
-
|
| 122 |
-
# ---- text helpers -------------------------------------------------------
|
| 123 |
-
def encode_text(self, text: str):
|
| 124 |
-
return self.tokenizer(text, return_tensors="pt").input_ids.to(self.device_)
|
| 125 |
-
|
| 126 |
-
def decode(self, ids):
|
| 127 |
-
return self.tokenizer.decode(ids[0], skip_special_tokens=True)
|
| 128 |
-
|
| 129 |
-
@property
|
| 130 |
-
def device(self):
|
| 131 |
-
return self.device_
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
# ----------------------------------------------------------------------------
|
| 135 |
-
# 0b. Tiny fallback backend (no transformers; same toy model as before)
|
| 136 |
-
# ----------------------------------------------------------------------------
|
| 137 |
-
class TinyBackend(LLMBackend):
|
| 138 |
-
VOCAB, D_MODEL, N_LAYERS, N_HEADS, SEQ_LEN, LORA_R = 1000, 128, 2, 4, 32, 8
|
| 139 |
-
|
| 140 |
-
class _LoRALinear(nn.Module):
|
| 141 |
-
def __init__(self, d_in, d_out, r):
|
| 142 |
-
super().__init__()
|
| 143 |
-
self.base = nn.Linear(d_in, d_out)
|
| 144 |
-
self.base.weight.requires_grad_(False)
|
| 145 |
-
self.base.bias.requires_grad_(False)
|
| 146 |
-
self.adapters, self.active, self.r = nn.ModuleDict(), None, r
|
| 147 |
-
|
| 148 |
-
def add_adapter(self, name):
|
| 149 |
-
A = nn.Linear(self.base.in_features, self.r, bias=False)
|
| 150 |
-
B = nn.Linear(self.r, self.base.out_features, bias=False)
|
| 151 |
-
nn.init.zeros_(B.weight)
|
| 152 |
-
self.adapters[name] = nn.Sequential(A, B)
|
| 153 |
-
|
| 154 |
-
def forward(self, x):
|
| 155 |
-
y = self.base(x)
|
| 156 |
-
if self.active and self.active in self.adapters:
|
| 157 |
-
y = y + self.adapters[self.active](x)
|
| 158 |
-
return y
|
| 159 |
-
|
| 160 |
-
class _Block(nn.Module):
|
| 161 |
-
def __init__(self, D, H, R):
|
| 162 |
-
super().__init__()
|
| 163 |
-
L = TinyBackend._LoRALinear
|
| 164 |
-
self.ln1 = nn.LayerNorm(D)
|
| 165 |
-
self.attn = nn.MultiheadAttention(D, H, batch_first=True)
|
| 166 |
-
self.ln2 = nn.LayerNorm(D)
|
| 167 |
-
self.up, self.down = L(D, 4 * D, R), L(4 * D, D, R)
|
| 168 |
-
|
| 169 |
-
def forward(self, x, mask):
|
| 170 |
-
h = self.ln1(x)
|
| 171 |
-
a, _ = self.attn(h, h, h, attn_mask=mask, need_weights=False)
|
| 172 |
-
x = x + a
|
| 173 |
-
return x + self.down(F.gelu(self.up(self.ln2(x))))
|
| 174 |
-
|
| 175 |
-
def __init__(self):
|
| 176 |
-
super().__init__()
|
| 177 |
-
D, V = self.D_MODEL, self.VOCAB
|
| 178 |
-
self.tok = nn.Embedding(V, D)
|
| 179 |
-
self.pos = nn.Embedding(self.SEQ_LEN * 4, D)
|
| 180 |
-
self.blocks = nn.ModuleList(
|
| 181 |
-
[self._Block(D, self.N_HEADS, self.LORA_R) for _ in range(self.N_LAYERS)])
|
| 182 |
-
self.ln_f, self.head = nn.LayerNorm(D), nn.Linear(D, V, bias=False)
|
| 183 |
-
self.vocab_size, self.hidden_size = V, D
|
| 184 |
-
self.add_adapter("general")
|
| 185 |
-
self.set_adapter("general")
|
| 186 |
-
|
| 187 |
-
def add_adapter(self, name):
|
| 188 |
-
for b in self.blocks:
|
| 189 |
-
b.up.add_adapter(name); b.down.add_adapter(name)
|
| 190 |
-
|
| 191 |
-
def set_adapter(self, name):
|
| 192 |
-
for b in self.blocks:
|
| 193 |
-
b.up.active = name; b.down.active = name
|
| 194 |
-
|
| 195 |
-
def trainable_parameters(self):
|
| 196 |
-
return (p for p in self.parameters() if p.requires_grad)
|
| 197 |
-
|
| 198 |
-
def forward(self, ids):
|
| 199 |
-
B, T = ids.shape
|
| 200 |
-
x = self.tok(ids) + self.pos(torch.arange(T, device=ids.device))
|
| 201 |
-
mask = torch.triu(torch.full((T, T), float("-inf"), device=ids.device), 1)
|
| 202 |
-
for b in self.blocks:
|
| 203 |
-
x = b(x, mask)
|
| 204 |
-
h = self.ln_f(x)
|
| 205 |
-
return self.head(h), h
|
| 206 |
-
|
| 207 |
-
@torch.no_grad()
|
| 208 |
-
def generate(self, ids, n_new=16, temperature=1.0):
|
| 209 |
-
for _ in range(n_new):
|
| 210 |
-
logits, _ = self(ids[:, -self.SEQ_LEN:])
|
| 211 |
-
nxt = torch.multinomial(F.softmax(logits[:, -1] / temperature, -1), 1)
|
| 212 |
-
ids = torch.cat([ids, nxt], dim=1)
|
| 213 |
-
return ids
|
| 214 |
-
|
| 215 |
-
@property
|
| 216 |
-
def device(self):
|
| 217 |
-
return next(self.parameters()).device
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
def make_backend(llm: str, **kw) -> LLMBackend:
|
| 221 |
-
"""'tiny' -> toy model; anything else -> HF hub id or local path."""
|
| 222 |
-
return TinyBackend() if llm == "tiny" else HFBackend(llm, **kw)
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
# ----------------------------------------------------------------------------
|
| 226 |
-
# 1. Embedder — vocab-agnostic (sized from the backend's tokenizer)
|
| 227 |
-
# Swap for a real model: pass embed_fn=lambda txt: sbert.encode(...)
|
| 228 |
-
# ----------------------------------------------------------------------------
|
| 229 |
-
class Embedder(nn.Module):
|
| 230 |
-
def __init__(self, vocab_size: int, d_emb: int = 64):
|
| 231 |
-
super().__init__()
|
| 232 |
-
self.tok = nn.Embedding(vocab_size, d_emb)
|
| 233 |
-
self.enc = nn.GRU(d_emb, d_emb, batch_first=True, bidirectional=True)
|
| 234 |
-
self.proj = nn.Linear(2 * d_emb, d_emb)
|
| 235 |
-
self.d_emb = d_emb
|
| 236 |
-
|
| 237 |
-
def forward(self, ids):
|
| 238 |
-
h, _ = self.enc(self.tok(ids))
|
| 239 |
-
return F.normalize(self.proj(h.mean(dim=1)), dim=-1)
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
# ----------------------------------------------------------------------------
|
| 243 |
-
# 2. JEPA — vocab-agnostic latent predictor with EMA target encoder
|
| 244 |
-
# ----------------------------------------------------------------------------
|
| 245 |
-
class _JEPAEncoder(nn.Module):
|
| 246 |
-
def __init__(self, vocab_size, d):
|
| 247 |
-
super().__init__()
|
| 248 |
-
self.tok = nn.Embedding(vocab_size, d)
|
| 249 |
-
self.enc = nn.GRU(d, d, batch_first=True)
|
| 250 |
-
self.out = nn.Linear(d, d)
|
| 251 |
-
|
| 252 |
-
def forward(self, ids):
|
| 253 |
-
h, _ = self.enc(self.tok(ids))
|
| 254 |
-
return self.out(h.mean(dim=1))
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
class JEPA(nn.Module):
|
| 258 |
-
def __init__(self, vocab_size: int, d_jepa: int = 64, ema_m: float = 0.996):
|
| 259 |
-
super().__init__()
|
| 260 |
-
self.ctx_enc = _JEPAEncoder(vocab_size, d_jepa)
|
| 261 |
-
self.tgt_enc = copy.deepcopy(self.ctx_enc)
|
| 262 |
-
for p in self.tgt_enc.parameters():
|
| 263 |
-
p.requires_grad_(False)
|
| 264 |
-
self.predictor = nn.Sequential(
|
| 265 |
-
nn.Linear(d_jepa, 2 * d_jepa), nn.GELU(), nn.Linear(2 * d_jepa, d_jepa))
|
| 266 |
-
self.m, self.d_jepa = ema_m, d_jepa
|
| 267 |
-
|
| 268 |
-
@torch.no_grad()
|
| 269 |
-
def ema_update(self):
|
| 270 |
-
for p_t, p_c in zip(self.tgt_enc.parameters(), self.ctx_enc.parameters()):
|
| 271 |
-
p_t.mul_(self.m).add_(p_c.detach(), alpha=1 - self.m)
|
| 272 |
-
|
| 273 |
-
def loss(self, seg_ctx, seg_tgt):
|
| 274 |
-
z_hat = self.predictor(self.ctx_enc(seg_ctx))
|
| 275 |
-
with torch.no_grad():
|
| 276 |
-
z_tgt = self.tgt_enc(seg_tgt)
|
| 277 |
-
pred = F.mse_loss(z_hat, z_tgt)
|
| 278 |
-
var_reg = F.relu(1.0 - z_hat.std(dim=0)).mean() # anti-collapse
|
| 279 |
-
return pred + 0.5 * var_reg
|
| 280 |
-
|
| 281 |
-
@torch.no_grad()
|
| 282 |
-
def predict_next_latent(self, seg_ctx):
|
| 283 |
-
return self.predictor(self.ctx_enc(seg_ctx))
|
| 284 |
-
|
| 285 |
-
@torch.no_grad()
|
| 286 |
-
def encode(self, seg):
|
| 287 |
-
return self.tgt_enc(seg)
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
# ----------------------------------------------------------------------------
|
| 291 |
-
# 3. Bridge — sized from backend.hidden_size at construction
|
| 292 |
-
# ----------------------------------------------------------------------------
|
| 293 |
-
class Bridge(nn.Module):
|
| 294 |
-
def __init__(self, d_llm_hidden: int, d_jepa: int):
|
| 295 |
-
super().__init__()
|
| 296 |
-
self.proj = nn.Sequential(
|
| 297 |
-
nn.Linear(d_llm_hidden, d_jepa), nn.GELU(), nn.Linear(d_jepa, d_jepa))
|
| 298 |
-
|
| 299 |
-
def forward(self, llm_hidden): # [B,T,H] -> [B,d_jepa]
|
| 300 |
-
return self.proj(llm_hidden.float().mean(dim=1))
|
| 301 |
-
|
| 302 |
-
def info_nce(self, z1, z2, tau=0.07):
|
| 303 |
-
z1, z2 = F.normalize(z1, dim=-1), F.normalize(z2, dim=-1)
|
| 304 |
-
logits = z1 @ z2.t() / tau
|
| 305 |
-
labels = torch.arange(z1.size(0), device=z1.device)
|
| 306 |
-
return 0.5 * (F.cross_entropy(logits, labels) +
|
| 307 |
-
F.cross_entropy(logits.t(), labels))
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
# ----------------------------------------------------------------------------
|
| 311 |
-
# 4. Memory + Router
|
| 312 |
-
# ----------------------------------------------------------------------------
|
| 313 |
-
class VectorStore:
|
| 314 |
-
def __init__(self):
|
| 315 |
-
self.keys, self.values = [], []
|
| 316 |
-
|
| 317 |
-
def add(self, emb, payload):
|
| 318 |
-
self.keys.append(emb.squeeze(0).detach().cpu())
|
| 319 |
-
self.values.append(payload)
|
| 320 |
-
|
| 321 |
-
def search(self, q, k=2):
|
| 322 |
-
if not self.keys:
|
| 323 |
-
return []
|
| 324 |
-
K = torch.stack(self.keys)
|
| 325 |
-
sims = (q.detach().cpu() @ K.t()).squeeze(0)
|
| 326 |
-
top = sims.topk(min(k, len(self.keys))).indices
|
| 327 |
-
return [self.values[i] for i in top]
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
class Router(nn.Module):
|
| 331 |
-
def __init__(self, d_emb, n_adapters):
|
| 332 |
-
super().__init__()
|
| 333 |
-
self.fc = nn.Linear(d_emb, n_adapters)
|
| 334 |
-
|
| 335 |
-
def forward(self, emb):
|
| 336 |
-
return self.fc(emb).argmax(dim=-1)
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
# ----------------------------------------------------------------------------
|
| 340 |
-
# 5. Ensemble — backend-agnostic
|
| 341 |
-
# ----------------------------------------------------------------------------
|
| 342 |
-
class Ensemble(nn.Module):
|
| 343 |
-
def __init__(self, llm: str = "tiny", adapter_names=("general",),
|
| 344 |
-
d_emb: int = 64, d_jepa: int = 64, **backend_kw):
|
| 345 |
-
super().__init__()
|
| 346 |
-
self.llm = make_backend(llm, **backend_kw)
|
| 347 |
-
V, H = self.llm.vocab_size, self.llm.hidden_size
|
| 348 |
-
|
| 349 |
-
self.emb = Embedder(V, d_emb)
|
| 350 |
-
self.jepa = JEPA(V, d_jepa)
|
| 351 |
-
self.bridge = Bridge(H, d_jepa)
|
| 352 |
-
self.store = VectorStore()
|
| 353 |
-
|
| 354 |
-
self.adapter_names = list(adapter_names)
|
| 355 |
-
for n in self.adapter_names:
|
| 356 |
-
self.llm.add_adapter(n)
|
| 357 |
-
self.llm.set_adapter(self.adapter_names[0])
|
| 358 |
-
self.router = Router(d_emb, len(self.adapter_names))
|
| 359 |
-
|
| 360 |
-
# -------- inference: route -> retrieve -> generate -> JEPA-verify -------
|
| 361 |
-
@torch.no_grad()
|
| 362 |
-
def answer_ids(self, query_ids, n_new=32, tau_consistency=0.0, max_retries=2):
|
| 363 |
-
q_emb = self.emb(query_ids.cpu())
|
| 364 |
-
a_idx = self.router(q_emb).item()
|
| 365 |
-
self.llm.set_adapter(self.adapter_names[a_idx])
|
| 366 |
-
|
| 367 |
-
mems = self.store.search(q_emb, k=1)
|
| 368 |
-
ctx = (torch.cat([mems[0], query_ids.cpu()], dim=1)
|
| 369 |
-
if mems else query_ids.cpu())
|
| 370 |
-
|
| 371 |
-
z_expected = self.jepa.predict_next_latent(ctx)
|
| 372 |
-
|
| 373 |
-
best = None
|
| 374 |
-
for attempt in range(max_retries + 1):
|
| 375 |
-
draft = self.llm.generate(ctx.to(self.llm.device), n_new=n_new,
|
| 376 |
-
temperature=0.8 + 0.3 * attempt)
|
| 377 |
-
new_part = draft[:, ctx.size(1):].cpu()
|
| 378 |
-
score = F.cosine_similarity(
|
| 379 |
-
z_expected, self.jepa.encode(new_part)).item()
|
| 380 |
-
if best is None or score > best[1]:
|
| 381 |
-
best = (draft, score, attempt)
|
| 382 |
-
if score >= tau_consistency:
|
| 383 |
-
break
|
| 384 |
-
draft, score, attempt = best
|
| 385 |
-
return draft, score, self.adapter_names[a_idx], attempt
|
| 386 |
-
|
| 387 |
-
def answer_text(self, prompt: str, **kw):
|
| 388 |
-
"""Convenience wrapper for HF backends (uses the real tokenizer)."""
|
| 389 |
-
ids = self.llm.encode_text(prompt)
|
| 390 |
-
out, score, adapter, retries = self.answer_ids(ids, **kw)
|
| 391 |
-
return self.llm.decode(out), score, adapter, retries
|
| 392 |
-
|
| 393 |
-
# -------- continual learning hooks ---------------------------------------
|
| 394 |
-
def memorize_ids(self, ids):
|
| 395 |
-
self.store.add(self.emb(ids.cpu()), ids.cpu())
|
| 396 |
-
|
| 397 |
-
def memorize_text(self, text: str):
|
| 398 |
-
self.memorize_ids(self.llm.encode_text(text))
|
| 399 |
-
|
| 400 |
-
def new_task_adapter(self, name: str):
|
| 401 |
-
self.adapter_names.append(name)
|
| 402 |
-
self.llm.add_adapter(name)
|
| 403 |
-
old = self.router
|
| 404 |
-
self.router = Router(self.emb.d_emb, len(self.adapter_names))
|
| 405 |
-
with torch.no_grad():
|
| 406 |
-
self.router.fc.weight[: old.fc.out_features] = old.fc.weight
|
| 407 |
-
self.router.fc.bias[: old.fc.out_features] = old.fc.bias
|
| 408 |
-
|
| 409 |
-
# -------- one joint training step (LM + JEPA + Bridge) -------------------
|
| 410 |
-
def train_step(self, seg_a, seg_b, opt, w_bridge=0.1):
|
| 411 |
-
"""seg_a, seg_b: consecutive token-id segments [B, T] (same tokenizer
|
| 412 |
-
as the backend!). For HF backends build them with backend.tokenizer."""
|
| 413 |
-
logits, hidden = self.llm(seg_a.to(self.llm.device))
|
| 414 |
-
lm_loss = F.cross_entropy(
|
| 415 |
-
logits[:, :-1].reshape(-1, self.llm.vocab_size).float(),
|
| 416 |
-
seg_a[:, 1:].reshape(-1).to(logits.device))
|
| 417 |
-
|
| 418 |
-
jepa_loss = self.jepa.loss(seg_a.cpu(), seg_b.cpu())
|
| 419 |
-
|
| 420 |
-
z_llm = self.bridge(hidden.cpu() if hidden.device.type != "cpu" else hidden)
|
| 421 |
-
z_jepa = self.jepa.ctx_enc(seg_a.cpu()).detach()
|
| 422 |
-
bridge_loss = self.bridge.info_nce(z_llm, z_jepa.to(z_llm.device))
|
| 423 |
-
|
| 424 |
-
loss = lm_loss.cpu() + jepa_loss + w_bridge * bridge_loss
|
| 425 |
-
opt.zero_grad(); loss.backward(); opt.step()
|
| 426 |
-
self.jepa.ema_update()
|
| 427 |
-
return {"lm": lm_loss.item(), "jepa": jepa_loss.item(),
|
| 428 |
-
"bridge": bridge_loss.item()}
|
| 429 |
-
|
| 430 |
-
def make_optimizer(self, lr_lora=2e-4, lr_aux=1e-3):
|
| 431 |
-
return torch.optim.AdamW([
|
| 432 |
-
{"params": list(self.llm.trainable_parameters()), "lr": lr_lora},
|
| 433 |
-
{"params": list(self.jepa.ctx_enc.parameters())
|
| 434 |
-
+ list(self.jepa.predictor.parameters()), "lr": lr_aux},
|
| 435 |
-
{"params": list(self.bridge.parameters())
|
| 436 |
-
+ list(self.emb.parameters())
|
| 437 |
-
+ list(self.router.parameters()), "lr": lr_aux},
|
| 438 |
-
])
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
# ----------------------------------------------------------------------------
|
| 442 |
-
# 6. Helpers: turn raw text into (seg_a, seg_b) pairs with the HF tokenizer
|
| 443 |
-
# ----------------------------------------------------------------------------
|
| 444 |
-
def segment_pairs_from_texts(backend: HFBackend, texts, seg_len=64):
|
| 445 |
-
"""Yields consecutive-segment id pairs for the JEPA + LM losses."""
|
| 446 |
-
a_list, b_list = [], []
|
| 447 |
-
for t in texts:
|
| 448 |
-
ids = backend.tokenizer(t, return_tensors="pt").input_ids[0]
|
| 449 |
-
for i in range(0, len(ids) - 2 * seg_len, seg_len):
|
| 450 |
-
a_list.append(ids[i:i + seg_len])
|
| 451 |
-
b_list.append(ids[i + seg_len:i + 2 * seg_len])
|
| 452 |
-
if not a_list:
|
| 453 |
-
raise ValueError("texts too short for the chosen seg_len")
|
| 454 |
-
return torch.stack(a_list), torch.stack(b_list)
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
# ----------------------------------------------------------------------------
|
| 458 |
-
# 7. Demos
|
| 459 |
-
# ----------------------------------------------------------------------------
|
| 460 |
-
def demo_tiny(steps=50):
|
| 461 |
-
"""No-dependency smoke test."""
|
| 462 |
-
ens = Ensemble(llm="tiny")
|
| 463 |
-
opt = ens.make_optimizer()
|
| 464 |
-
for s in range(steps):
|
| 465 |
-
seg_a = torch.randint(0, ens.llm.vocab_size, (8, 32))
|
| 466 |
-
seg_b = torch.randint(0, ens.llm.vocab_size, (8, 32))
|
| 467 |
-
logs = ens.train_step(seg_a, seg_b, opt)
|
| 468 |
-
if s % 10 == 0:
|
| 469 |
-
print(f"step {s:3d} | " + " | ".join(f"{k} {v:.3f}" for k, v in logs.items()))
|
| 470 |
-
|
| 471 |
-
for _ in range(5):
|
| 472 |
-
ens.memorize_ids(torch.randint(0, ens.llm.vocab_size, (1, 32)))
|
| 473 |
-
ens.new_task_adapter("medical")
|
| 474 |
-
|
| 475 |
-
q = torch.randint(0, ens.llm.vocab_size, (1, 8))
|
| 476 |
-
out, score, adapter, retries = ens.answer_ids(q, tau_consistency=-1.0)
|
| 477 |
-
print(f"\nadapter={adapter} jepa_consistency={score:.3f} retries={retries}")
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
def demo_hf(model_path="Qwen/Qwen2.5-0.5B-Instruct"):
|
| 481 |
-
"""Real model from hub id OR local path, e.g. '/models/llama-3.2-1b'.
|
| 482 |
-
For gated Llama repos: huggingface-cli login first."""
|
| 483 |
-
ens = Ensemble(llm=model_path, load_in_4bit=False) # 4bit needs bitsandbytes
|
| 484 |
-
opt = ens.make_optimizer()
|
| 485 |
-
|
| 486 |
-
texts = ["Replace this with your real corpus. " * 50]
|
| 487 |
-
seg_a, seg_b = segment_pairs_from_texts(ens.llm, texts, seg_len=32)
|
| 488 |
-
for s in range(10): # tiny demo run
|
| 489 |
-
logs = ens.train_step(seg_a[:4], seg_b[:4], opt)
|
| 490 |
-
print(f"step {s} | " + " | ".join(f"{k} {v:.3f}" for k, v in logs.items()))
|
| 491 |
-
|
| 492 |
-
ens.memorize_text("The project codename is AURORA and it ships in Q3.")
|
| 493 |
-
ens.new_task_adapter("project_aurora")
|
| 494 |
-
|
| 495 |
-
text, score, adapter, retries = ens.answer_text(
|
| 496 |
-
"What is the project codename?", n_new=24, tau_consistency=-1.0)
|
| 497 |
-
print(f"\n[{adapter} | jepa={score:.3f} | retries={retries}]\n{text}")
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
if __name__ == "__main__":
|
| 501 |
-
import sys
|
| 502 |
-
arg = sys.argv[1] if len(sys.argv) > 1 else "tiny"
|
| 503 |
-
if arg == "tiny":
|
| 504 |
-
demo_tiny()
|
| 505 |
-
else:
|
| 506 |
-
demo_hf(arg) # python ensemble.py /models/llama-3.2-1b
|
| 507 |
-
# python ensemble.py Qwen/Qwen2.5-0.5B-Instruct
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
research/ensemble/src/ensemble/memory.py
DELETED
|
@@ -1,46 +0,0 @@
|
|
| 1 |
-
"""Retrieval memory: embedder, vector store, and adapter router."""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
import torch.nn as nn
|
| 7 |
-
import torch.nn.functional as F
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
class Embedder(nn.Module):
|
| 11 |
-
def __init__(self, vocab_size: int, d_emb: int = 64):
|
| 12 |
-
super().__init__()
|
| 13 |
-
self.tok = nn.Embedding(vocab_size, d_emb)
|
| 14 |
-
self.enc = nn.GRU(d_emb, d_emb, batch_first=True, bidirectional=True)
|
| 15 |
-
self.proj = nn.Linear(2 * d_emb, d_emb)
|
| 16 |
-
self.d_emb = d_emb
|
| 17 |
-
|
| 18 |
-
def forward(self, ids):
|
| 19 |
-
h, _ = self.enc(self.tok(ids))
|
| 20 |
-
return F.normalize(self.proj(h.mean(dim=1)), dim=-1)
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
class VectorStore:
|
| 24 |
-
def __init__(self):
|
| 25 |
-
self.keys, self.values = [], []
|
| 26 |
-
|
| 27 |
-
def add(self, emb, payload):
|
| 28 |
-
self.keys.append(emb.squeeze(0).detach().cpu())
|
| 29 |
-
self.values.append(payload)
|
| 30 |
-
|
| 31 |
-
def search(self, q, k=2):
|
| 32 |
-
if not self.keys:
|
| 33 |
-
return []
|
| 34 |
-
K = torch.stack(self.keys)
|
| 35 |
-
sims = (q.detach().cpu() @ K.t()).squeeze(0)
|
| 36 |
-
top = sims.topk(min(k, len(self.keys))).indices
|
| 37 |
-
return [self.values[i] for i in top]
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
class Router(nn.Module):
|
| 41 |
-
def __init__(self, d_emb, n_adapters):
|
| 42 |
-
super().__init__()
|
| 43 |
-
self.fc = nn.Linear(d_emb, n_adapters)
|
| 44 |
-
|
| 45 |
-
def forward(self, emb):
|
| 46 |
-
return self.fc(emb).argmax(dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
research/ensemble/src/ensemble/pretrain.py
DELETED
|
@@ -1,198 +0,0 @@
|
|
| 1 |
-
"""Joint pretrain: LLM (LoRA) + embedder + JEPA + bridge, saved to models/ensemble/."""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
import argparse
|
| 6 |
-
import json
|
| 7 |
-
import os
|
| 8 |
-
import random
|
| 9 |
-
import time
|
| 10 |
-
from pathlib import Path
|
| 11 |
-
|
| 12 |
-
import torch
|
| 13 |
-
|
| 14 |
-
from ensemble.checkpoint import save_checkpoint
|
| 15 |
-
from ensemble.config import default_ensemble_out, load_dotenv, resolve_llm
|
| 16 |
-
from ensemble.jepa_ensemble import Ensemble, segment_pairs_from_texts
|
| 17 |
-
|
| 18 |
-
_REPO_ROOT = Path(__file__).resolve().parents[4]
|
| 19 |
-
_DEFAULT_DATA = _REPO_ROOT / "research/data/education-lesson-chat.jsonl"
|
| 20 |
-
_DEFAULT_KB = _REPO_ROOT / "research/data/benchmark-kb.jsonl"
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
def _load_jsonl(path: Path) -> list[dict]:
|
| 24 |
-
rows = []
|
| 25 |
-
with open(path) as f:
|
| 26 |
-
for line in f:
|
| 27 |
-
line = line.strip()
|
| 28 |
-
if line:
|
| 29 |
-
rows.append(json.loads(line))
|
| 30 |
-
return rows
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
def _chat_to_text(row: dict) -> str:
|
| 34 |
-
messages = row.get("messages", [])
|
| 35 |
-
parts = [f"{m.get('role', 'user')}: {m.get('content', '')}" for m in messages]
|
| 36 |
-
return "\n".join(parts)
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
def _collect_texts(data_path: Path, max_samples: int | None) -> list[str]:
|
| 40 |
-
rows = _load_jsonl(data_path)
|
| 41 |
-
if max_samples is not None:
|
| 42 |
-
rows = rows[:max_samples]
|
| 43 |
-
return [_chat_to_text(r) for r in rows if _chat_to_text(r).strip()]
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
def _seed_memory(ens: Ensemble, kb_path: Path | None) -> int:
|
| 47 |
-
if kb_path is None or not kb_path.is_file():
|
| 48 |
-
return 0
|
| 49 |
-
count = 0
|
| 50 |
-
for row in _load_jsonl(kb_path):
|
| 51 |
-
text = row.get("text", "").strip()
|
| 52 |
-
if text:
|
| 53 |
-
ens.memorize_text(text)
|
| 54 |
-
count += 1
|
| 55 |
-
return count
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
def pretrain(args) -> Path:
|
| 59 |
-
torch.manual_seed(args.seed)
|
| 60 |
-
random.seed(args.seed)
|
| 61 |
-
|
| 62 |
-
data_path = Path(args.data).resolve()
|
| 63 |
-
out_dir = Path(args.out).resolve()
|
| 64 |
-
kb_path = Path(args.kb).resolve() if args.kb else None
|
| 65 |
-
|
| 66 |
-
print(f"Loading ensemble backend: {args.llm}")
|
| 67 |
-
ens = Ensemble(llm=args.llm, load_in_4bit=args.load_in_4bit)
|
| 68 |
-
opt = ens.make_optimizer(lr_lora=args.lr_lora, lr_aux=args.lr_aux)
|
| 69 |
-
|
| 70 |
-
texts = _collect_texts(data_path, args.max_samples)
|
| 71 |
-
if not texts and args.llm != "tiny":
|
| 72 |
-
raise SystemExit(f"No training texts found in {data_path}")
|
| 73 |
-
|
| 74 |
-
mem_count = _seed_memory(ens, kb_path)
|
| 75 |
-
print(f"Training texts: {len(texts)} | memory snippets: {mem_count}")
|
| 76 |
-
|
| 77 |
-
if args.llm == "tiny":
|
| 78 |
-
n_pairs = max(args.steps * args.batch_size, args.batch_size)
|
| 79 |
-
v = ens.llm.vocab_size
|
| 80 |
-
seg_a = torch.randint(0, v, (n_pairs, args.seg_len))
|
| 81 |
-
seg_b = torch.randint(0, v, (n_pairs, args.seg_len))
|
| 82 |
-
else:
|
| 83 |
-
seg_a, seg_b = segment_pairs_from_texts(
|
| 84 |
-
ens.llm, texts, seg_len=args.seg_len
|
| 85 |
-
)
|
| 86 |
-
n_pairs = seg_a.size(0)
|
| 87 |
-
batch = min(args.batch_size, n_pairs)
|
| 88 |
-
print(f"Segment pairs: {n_pairs} | batch={batch} | steps={args.steps}")
|
| 89 |
-
|
| 90 |
-
t0 = time.time()
|
| 91 |
-
for step in range(args.steps):
|
| 92 |
-
idx = torch.randint(0, n_pairs, (batch,))
|
| 93 |
-
logs = ens.train_step(seg_a[idx], seg_b[idx], opt, w_bridge=args.w_bridge)
|
| 94 |
-
if step % max(1, args.log_every) == 0 or step == args.steps - 1:
|
| 95 |
-
parts = " | ".join(f"{k} {v:.4f}" for k, v in logs.items())
|
| 96 |
-
print(f"step {step:4d}/{args.steps} | {parts}")
|
| 97 |
-
|
| 98 |
-
elapsed = time.time() - t0
|
| 99 |
-
meta = {
|
| 100 |
-
"steps": args.steps,
|
| 101 |
-
"batch_size": batch,
|
| 102 |
-
"seg_len": args.seg_len,
|
| 103 |
-
"data": str(data_path),
|
| 104 |
-
"kb": str(kb_path) if kb_path else None,
|
| 105 |
-
"memory_count": mem_count,
|
| 106 |
-
"text_count": len(texts),
|
| 107 |
-
"elapsed_s": round(elapsed, 1),
|
| 108 |
-
"lr_lora": args.lr_lora,
|
| 109 |
-
"lr_aux": args.lr_aux,
|
| 110 |
-
"w_bridge": args.w_bridge,
|
| 111 |
-
"seed": args.seed,
|
| 112 |
-
"preset": getattr(args, "preset", None),
|
| 113 |
-
}
|
| 114 |
-
|
| 115 |
-
saved = save_checkpoint(
|
| 116 |
-
ens,
|
| 117 |
-
out_dir,
|
| 118 |
-
base_llm=args.llm,
|
| 119 |
-
training_meta=meta,
|
| 120 |
-
)
|
| 121 |
-
print(f"\nSaved ensemble checkpoint → {saved}")
|
| 122 |
-
print("Benchmark with slm-evals:")
|
| 123 |
-
print(
|
| 124 |
-
f" uv run --package slm-evals slm-benchmark "
|
| 125 |
-
f"--model {saved} --model-type ensemble "
|
| 126 |
-
f"--benchmarks bfcl --max-samples 5"
|
| 127 |
-
)
|
| 128 |
-
return saved
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
def parse_args():
|
| 132 |
-
p = argparse.ArgumentParser(
|
| 133 |
-
description="Pretrain JEPA ensemble (LLM+emb+JEPA) and save to models/ensemble/"
|
| 134 |
-
)
|
| 135 |
-
p.add_argument(
|
| 136 |
-
"--llm",
|
| 137 |
-
default=None,
|
| 138 |
-
help=(
|
| 139 |
-
"HF hub id / local path, 'tiny' for CPU smoke, or omit to use "
|
| 140 |
-
"LLM_PATH / BASE / MODEL_ID / ACTIVE_MODEL from .env + models.yaml"
|
| 141 |
-
),
|
| 142 |
-
)
|
| 143 |
-
p.add_argument(
|
| 144 |
-
"--preset",
|
| 145 |
-
default=None,
|
| 146 |
-
help="models.yaml preset key (default: ENSEMBLE_PRESET or ACTIVE_MODEL)",
|
| 147 |
-
)
|
| 148 |
-
p.add_argument(
|
| 149 |
-
"--data",
|
| 150 |
-
default=str(_DEFAULT_DATA),
|
| 151 |
-
help="Chat JSONL (messages[]) for segment-pair training",
|
| 152 |
-
)
|
| 153 |
-
p.add_argument(
|
| 154 |
-
"--kb",
|
| 155 |
-
default=str(_DEFAULT_KB),
|
| 156 |
-
help="Optional KB JSONL (text field) loaded into vector store",
|
| 157 |
-
)
|
| 158 |
-
p.add_argument(
|
| 159 |
-
"--out",
|
| 160 |
-
default=None,
|
| 161 |
-
help="Output dir (default: ENSEMBLE_OUT or models/ensemble/<preset>-jepa-pretrain)",
|
| 162 |
-
)
|
| 163 |
-
p.add_argument("--steps", type=int, default=100)
|
| 164 |
-
p.add_argument("--batch-size", type=int, default=4)
|
| 165 |
-
p.add_argument("--seg-len", type=int, default=32)
|
| 166 |
-
p.add_argument("--max-samples", type=int, default=None)
|
| 167 |
-
p.add_argument("--lr-lora", type=float, default=2e-4)
|
| 168 |
-
p.add_argument("--lr-aux", type=float, default=1e-3)
|
| 169 |
-
p.add_argument("--w-bridge", type=float, default=0.1)
|
| 170 |
-
p.add_argument("--log-every", type=int, default=10)
|
| 171 |
-
p.add_argument("--seed", type=int, default=0)
|
| 172 |
-
p.add_argument("--load-in-4bit", action="store_true")
|
| 173 |
-
p.add_argument("--no-kb", action="store_true", help="Skip loading KB into memory")
|
| 174 |
-
return p.parse_args()
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
def main():
|
| 178 |
-
load_dotenv()
|
| 179 |
-
args = parse_args()
|
| 180 |
-
if args.no_kb:
|
| 181 |
-
args.kb = None
|
| 182 |
-
|
| 183 |
-
preset_key = args.preset
|
| 184 |
-
if args.llm is None or args.llm == "auto":
|
| 185 |
-
args.llm, preset_key = resolve_llm(preset_arg=args.preset)
|
| 186 |
-
elif args.llm != "tiny" and not args.preset:
|
| 187 |
-
_, preset_key = resolve_llm(llm_arg=args.llm)
|
| 188 |
-
|
| 189 |
-
if not args.out:
|
| 190 |
-
args.out = os.environ.get("ENSEMBLE_OUT") or default_ensemble_out(preset_key)
|
| 191 |
-
|
| 192 |
-
args.preset = preset_key
|
| 193 |
-
print(f"Resolved LLM: {args.llm}" + (f" (preset {preset_key})" if preset_key else ""))
|
| 194 |
-
pretrain(args)
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
if __name__ == "__main__":
|
| 198 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
research/ensemble/src/ensemble/world_ensemble.py
DELETED
|
@@ -1,228 +0,0 @@
|
|
| 1 |
-
"""World-model ensemble: plan -> generate -> energy-rank."""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
import sys
|
| 6 |
-
import time
|
| 7 |
-
|
| 8 |
-
import torch
|
| 9 |
-
import torch.nn as nn
|
| 10 |
-
import torch.nn.functional as F
|
| 11 |
-
|
| 12 |
-
from ensemble.backends import HFLLM, load_llm
|
| 13 |
-
from ensemble.bridge import Bridge
|
| 14 |
-
from ensemble.energy import EnergyModel
|
| 15 |
-
from ensemble.jepa import JEPA
|
| 16 |
-
from ensemble.memory import Embedder, VectorStore
|
| 17 |
-
from ensemble.world_model import WorldModel
|
| 18 |
-
|
| 19 |
-
torch.manual_seed(0)
|
| 20 |
-
|
| 21 |
-
D_LAT = 96
|
| 22 |
-
D_EMB = 64
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
class WorldEnsemble(nn.Module):
|
| 26 |
-
def __init__(self, llm_spec="tiny"):
|
| 27 |
-
super().__init__()
|
| 28 |
-
self.llm = load_llm(llm_spec)
|
| 29 |
-
V, H = self.llm.vocab_size, self.llm.hidden_size
|
| 30 |
-
self.emb = Embedder(V, D_EMB)
|
| 31 |
-
self.jepa = JEPA(V, D_LAT)
|
| 32 |
-
self.world = WorldModel(D_LAT)
|
| 33 |
-
self.energy = EnergyModel(D_LAT)
|
| 34 |
-
self.bridge = Bridge(H, D_LAT)
|
| 35 |
-
self.store = VectorStore()
|
| 36 |
-
|
| 37 |
-
@torch.no_grad()
|
| 38 |
-
def world_state(self, segments):
|
| 39 |
-
s = self.world.init_state(1, "cpu")
|
| 40 |
-
for seg in segments:
|
| 41 |
-
z = self.jepa.encode(seg.cpu())
|
| 42 |
-
s, _ = self.world.step(s, z)
|
| 43 |
-
return s
|
| 44 |
-
|
| 45 |
-
@torch.no_grad()
|
| 46 |
-
def answer(self, query_ids, n_new=24, n_drafts=6, horizon=3):
|
| 47 |
-
q_emb = self.emb(query_ids.cpu())
|
| 48 |
-
mems = self.store.search(q_emb, k=1)
|
| 49 |
-
segments = (mems + [query_ids.cpu()]) if mems else [query_ids.cpu()]
|
| 50 |
-
ctx = torch.cat(segments, dim=1)
|
| 51 |
-
|
| 52 |
-
s = self.world_state(segments)
|
| 53 |
-
plan, _ = self.world.rollout(s, horizon)
|
| 54 |
-
|
| 55 |
-
drafts, lat = [], []
|
| 56 |
-
for _ in range(n_drafts):
|
| 57 |
-
out = self.llm.generate(
|
| 58 |
-
ctx.to(self.llm.device), n_new=n_new, temperature=0.9
|
| 59 |
-
)
|
| 60 |
-
new = out[:, ctx.size(1) :].cpu()
|
| 61 |
-
drafts.append(new)
|
| 62 |
-
lat.append(self.jepa.encode(new))
|
| 63 |
-
Z = torch.cat(lat, 0)
|
| 64 |
-
E = self.energy.rank(s, Z)
|
| 65 |
-
best = E.argmin().item()
|
| 66 |
-
return {
|
| 67 |
-
"output": drafts[best],
|
| 68 |
-
"energy": E[best].item(),
|
| 69 |
-
"all_energies": E.tolist(),
|
| 70 |
-
"plan_alignment": F.cosine_similarity(
|
| 71 |
-
plan[:, 0], Z[best : best + 1]
|
| 72 |
-
).item(),
|
| 73 |
-
}
|
| 74 |
-
|
| 75 |
-
def memorize(self, ids):
|
| 76 |
-
self.store.add(self.emb(ids.cpu()), ids.cpu())
|
| 77 |
-
|
| 78 |
-
def train_step(
|
| 79 |
-
self,
|
| 80 |
-
seg_seq,
|
| 81 |
-
opt,
|
| 82 |
-
w=None,
|
| 83 |
-
hard_negs=True,
|
| 84 |
-
):
|
| 85 |
-
if w is None:
|
| 86 |
-
w = dict(lm=1.0, jepa=1.0, world=1.0, ebm=1.0, bridge=0.1)
|
| 87 |
-
|
| 88 |
-
B, T, L = seg_seq.shape
|
| 89 |
-
dev = self.llm.device
|
| 90 |
-
|
| 91 |
-
flat = seg_seq[:, 0].to(dev)
|
| 92 |
-
logits, hidden = self.llm(flat)
|
| 93 |
-
lm = F.cross_entropy(
|
| 94 |
-
logits[:, :-1].reshape(-1, self.llm.vocab_size).float(),
|
| 95 |
-
flat[:, 1:].reshape(-1),
|
| 96 |
-
)
|
| 97 |
-
|
| 98 |
-
jepa = self.jepa.loss(seg_seq[:, 0], seg_seq[:, 1])
|
| 99 |
-
|
| 100 |
-
z_seq = torch.stack(
|
| 101 |
-
[self.jepa.enc(seg_seq[:, t]) for t in range(T)], 1
|
| 102 |
-
)
|
| 103 |
-
world = self.world.sequence_loss(z_seq)
|
| 104 |
-
|
| 105 |
-
s = self.world.init_state(B, z_seq.device)
|
| 106 |
-
s, _ = self.world.step(s, z_seq[:, 0].detach())
|
| 107 |
-
z_pos = z_seq[:, 1].detach()
|
| 108 |
-
z_negs = None
|
| 109 |
-
if hard_negs:
|
| 110 |
-
with torch.no_grad():
|
| 111 |
-
gen = self.llm.generate(seg_seq[:, 0].to(dev), n_new=L)
|
| 112 |
-
gen_new = gen[:, seg_seq.size(2) :].cpu()
|
| 113 |
-
z_negs = self.jepa.encode(gen_new).unsqueeze(1)
|
| 114 |
-
ebm = self.energy.contrastive_loss(s, z_pos, z_negs)
|
| 115 |
-
|
| 116 |
-
bridge = self.bridge.info_nce(
|
| 117 |
-
self.bridge(
|
| 118 |
-
hidden.cpu() if hidden.device.type != "cpu" else hidden
|
| 119 |
-
),
|
| 120 |
-
self.jepa.enc(seg_seq[:, 0]).detach(),
|
| 121 |
-
)
|
| 122 |
-
|
| 123 |
-
loss = (
|
| 124 |
-
w["lm"] * lm.cpu()
|
| 125 |
-
+ w["jepa"] * jepa
|
| 126 |
-
+ w["world"] * world
|
| 127 |
-
+ w["ebm"] * ebm
|
| 128 |
-
+ w["bridge"] * bridge
|
| 129 |
-
)
|
| 130 |
-
opt.zero_grad()
|
| 131 |
-
loss.backward()
|
| 132 |
-
opt.step()
|
| 133 |
-
self.jepa.ema()
|
| 134 |
-
return dict(
|
| 135 |
-
lm=lm.item(),
|
| 136 |
-
jepa=jepa.item(),
|
| 137 |
-
world=world.item(),
|
| 138 |
-
ebm=ebm.item(),
|
| 139 |
-
bridge=bridge.item(),
|
| 140 |
-
)
|
| 141 |
-
|
| 142 |
-
def make_optimizer(self, lr_lora=2e-4, lr_aux=1e-3):
|
| 143 |
-
return torch.optim.AdamW(
|
| 144 |
-
[
|
| 145 |
-
{"params": list(self.llm.trainable_parameters()), "lr": lr_lora},
|
| 146 |
-
{
|
| 147 |
-
"params": list(self.jepa.enc.parameters())
|
| 148 |
-
+ list(self.jepa.pred.parameters()),
|
| 149 |
-
"lr": lr_aux,
|
| 150 |
-
},
|
| 151 |
-
{"params": list(self.world.parameters()), "lr": lr_aux},
|
| 152 |
-
{"params": list(self.energy.parameters()), "lr": lr_aux},
|
| 153 |
-
{
|
| 154 |
-
"params": list(self.bridge.parameters())
|
| 155 |
-
+ list(self.emb.parameters()),
|
| 156 |
-
"lr": lr_aux,
|
| 157 |
-
},
|
| 158 |
-
]
|
| 159 |
-
)
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
def toy_segment_sequences(B=8, T=4, L=24, vocab=1000):
|
| 163 |
-
return torch.randint(0, vocab, (B, T, L))
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
def hf_segment_sequences(llm: HFLLM, texts, T=4, L=64):
|
| 167 |
-
seqs = []
|
| 168 |
-
for t in texts:
|
| 169 |
-
ids = llm.tokenizer(t, return_tensors="pt").input_ids[0]
|
| 170 |
-
n = (len(ids) // (T * L)) * T * L
|
| 171 |
-
if n:
|
| 172 |
-
seqs.append(ids[:n].view(-1, T, L))
|
| 173 |
-
if not seqs:
|
| 174 |
-
raise ValueError("corpus too short for T*L window")
|
| 175 |
-
return torch.cat(seqs, 0)
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
def demo(spec="tiny", steps=60):
|
| 179 |
-
ens = WorldEnsemble(spec)
|
| 180 |
-
opt = ens.make_optimizer()
|
| 181 |
-
|
| 182 |
-
if spec == "tiny":
|
| 183 |
-
get_batch = lambda: toy_segment_sequences(vocab=ens.llm.vocab_size)
|
| 184 |
-
else:
|
| 185 |
-
corpus = ["Replace with your real documents. " * 200]
|
| 186 |
-
data = hf_segment_sequences(ens.llm, corpus, T=4, L=32)
|
| 187 |
-
get_batch = lambda: data[torch.randperm(len(data))[:4]]
|
| 188 |
-
steps = min(steps, 10)
|
| 189 |
-
|
| 190 |
-
t0 = time.time()
|
| 191 |
-
for s in range(steps):
|
| 192 |
-
logs = ens.train_step(
|
| 193 |
-
get_batch(), opt, hard_negs=(s > steps // 2)
|
| 194 |
-
)
|
| 195 |
-
if s % 10 == 0:
|
| 196 |
-
print(
|
| 197 |
-
f"step {s:3d} | "
|
| 198 |
-
+ " | ".join(f"{k} {v:.3f}" for k, v in logs.items())
|
| 199 |
-
)
|
| 200 |
-
print(f"trained {steps} steps in {time.time() - t0:.1f}s")
|
| 201 |
-
|
| 202 |
-
for _ in range(4):
|
| 203 |
-
if spec == "tiny":
|
| 204 |
-
ens.memorize(torch.randint(0, ens.llm.vocab_size, (1, 24)))
|
| 205 |
-
q = (
|
| 206 |
-
torch.randint(0, ens.llm.vocab_size, (1, 12))
|
| 207 |
-
if spec == "tiny"
|
| 208 |
-
else ens.llm.tokenizer(
|
| 209 |
-
"What is this document about?", return_tensors="pt"
|
| 210 |
-
).input_ids
|
| 211 |
-
)
|
| 212 |
-
res = ens.answer(q, n_drafts=6, horizon=3)
|
| 213 |
-
print(
|
| 214 |
-
f"\nselected draft energy={res['energy']:.3f} "
|
| 215 |
-
f"(all: {[f'{e:.2f}' for e in res['all_energies']]})"
|
| 216 |
-
)
|
| 217 |
-
print(f"plan↔output alignment: {res['plan_alignment']:.3f}")
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
if __name__ == "__main__":
|
| 221 |
-
from ensemble.config import load_dotenv, resolve_llm
|
| 222 |
-
|
| 223 |
-
load_dotenv()
|
| 224 |
-
spec = sys.argv[1] if len(sys.argv) > 1 else None
|
| 225 |
-
if spec is None or spec == "auto":
|
| 226 |
-
spec, preset = resolve_llm()
|
| 227 |
-
print(f"Resolved LLM: {spec} (preset {preset})")
|
| 228 |
-
demo(spec or "tiny")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
research/ensemble/src/ensemble/world_model.py
DELETED
|
@@ -1,40 +0,0 @@
|
|
| 1 |
-
"""Latent world model: multi-step rollout in JEPA space."""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
import torch.nn as nn
|
| 7 |
-
import torch.nn.functional as F
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
class WorldModel(nn.Module):
|
| 11 |
-
def __init__(self, d_latent: int):
|
| 12 |
-
super().__init__()
|
| 13 |
-
self.cell = nn.GRUCell(d_latent, d_latent)
|
| 14 |
-
self.head = nn.Linear(d_latent, d_latent)
|
| 15 |
-
self.s0 = nn.Parameter(torch.zeros(d_latent))
|
| 16 |
-
self.d_latent = d_latent
|
| 17 |
-
|
| 18 |
-
def init_state(self, B, device):
|
| 19 |
-
return self.s0.unsqueeze(0).expand(B, -1).contiguous().to(device)
|
| 20 |
-
|
| 21 |
-
def step(self, s, z):
|
| 22 |
-
s = self.cell(z, s)
|
| 23 |
-
return s, self.head(s)
|
| 24 |
-
|
| 25 |
-
def rollout(self, s, horizon):
|
| 26 |
-
preds = []
|
| 27 |
-
for _ in range(horizon):
|
| 28 |
-
z_hat = self.head(s)
|
| 29 |
-
preds.append(z_hat)
|
| 30 |
-
s = self.cell(z_hat, s)
|
| 31 |
-
return torch.stack(preds, 1), s
|
| 32 |
-
|
| 33 |
-
def sequence_loss(self, z_seq):
|
| 34 |
-
B, T, _ = z_seq.shape
|
| 35 |
-
s = self.init_state(B, z_seq.device)
|
| 36 |
-
loss = 0.0
|
| 37 |
-
for t in range(T - 1):
|
| 38 |
-
s, z_hat = self.step(s, z_seq[:, t])
|
| 39 |
-
loss = loss + F.mse_loss(z_hat, z_seq[:, t + 1])
|
| 40 |
-
return loss / (T - 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
research/ensemble/src/ensemble/world_model_ensemble.py
DELETED
|
@@ -1,499 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
World-Model Ensemble: EMB + EBM + JEPA + World Model + small LLM (from path)
|
| 3 |
-
=============================================================================
|
| 4 |
-
A LeCun-style modular agent built around a small language model.
|
| 5 |
-
|
| 6 |
-
ARCHITECTURE
|
| 7 |
-
------------
|
| 8 |
-
┌────────────────────────────┐
|
| 9 |
-
input tokens ──► EMB ──┤ VectorStore (retrieval/CL) │──► context
|
| 10 |
-
│ └────────────────────────────┘ │
|
| 11 |
-
│ │
|
| 12 |
-
▼ ▼
|
| 13 |
-
JEPA encoder ──► latent state s_t ──► WORLD MODEL ──► ŝ_{t+1..t+H}
|
| 14 |
-
│ (GRU dynamics, multi-step rollout) │
|
| 15 |
-
│ │
|
| 16 |
-
│ ┌────────────────────────────────────┐ │
|
| 17 |
-
└──────────► │ ENERGY MODEL E(s_ctx, z_candidate)│ ◄─┘
|
| 18 |
-
│ low energy = compatible/plausible │
|
| 19 |
-
└────────────────┬───────────────────┘
|
| 20 |
-
│ scores drafts / plans
|
| 21 |
-
▼
|
| 22 |
-
LLM (small, loaded from path, LoRA bank) ──► N drafts ──► pick argmin E
|
| 23 |
-
|
| 24 |
-
ROLES
|
| 25 |
-
-----
|
| 26 |
-
EMB perception for retrieval + routing (non-parametric memory)
|
| 27 |
-
JEPA learns the latent space: predict z(next segment) from z(context)
|
| 28 |
-
(EMA target encoder + variance reg, no token reconstruction)
|
| 29 |
-
WORLD MODEL deterministic latent dynamics s_{t+1} = f(s_t, z_t):
|
| 30 |
-
rolls the conversation/document state forward H steps in
|
| 31 |
-
LATENT space — cheap lookahead without decoding tokens
|
| 32 |
-
ENERGY E(s, z) ∈ R, trained so true continuations have LOW energy and
|
| 33 |
-
negatives (shuffled / model-generated) have HIGH energy.
|
| 34 |
-
At inference it is the critic: rank LLM drafts, reject bad plans.
|
| 35 |
-
LLM the only token-level generator. Loaded from a local path or HF id;
|
| 36 |
-
frozen base + LoRA adapters (continual learning by isolation).
|
| 37 |
-
|
| 38 |
-
WHY EBM *and* JEPA? JEPA gives a point prediction ẑ of the future latent;
|
| 39 |
-
the EBM gives a *compatibility landscape* E(s, z) — it can say "both A and B
|
| 40 |
-
are plausible" where a point predictor must average them. JEPA trains the
|
| 41 |
-
representation; the EBM scores hypotheses in it. World model chains JEPA
|
| 42 |
-
one-step predictions into multi-step rollouts that the EBM can evaluate.
|
| 43 |
-
|
| 44 |
-
USAGE
|
| 45 |
-
-----
|
| 46 |
-
pip install torch # toy mode
|
| 47 |
-
pip install transformers peft accelerate # real LLM mode
|
| 48 |
-
|
| 49 |
-
python world_model_ensemble.py tiny # smoke test
|
| 50 |
-
python world_model_ensemble.py /models/llama-3.2-1b # local weights
|
| 51 |
-
python world_model_ensemble.py Qwen/Qwen2.5-0.5B-Instruct
|
| 52 |
-
"""
|
| 53 |
-
|
| 54 |
-
from __future__ import annotations
|
| 55 |
-
import copy
|
| 56 |
-
import math
|
| 57 |
-
import sys
|
| 58 |
-
import time
|
| 59 |
-
|
| 60 |
-
import torch
|
| 61 |
-
import torch.nn as nn
|
| 62 |
-
import torch.nn.functional as F
|
| 63 |
-
|
| 64 |
-
torch.manual_seed(0)
|
| 65 |
-
|
| 66 |
-
D_LAT = 96 # shared latent dimension (JEPA / world / energy)
|
| 67 |
-
D_EMB = 64 # retrieval embedding dim
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
# ============================================================================
|
| 71 |
-
# 1. LLM backend — load small model from path / hub, or toy fallback
|
| 72 |
-
# (same contract as before: forward -> (logits, hidden), generate, adapters)
|
| 73 |
-
# ============================================================================
|
| 74 |
-
class TinyLLM(nn.Module):
|
| 75 |
-
VOCAB, D, L, H, T = 1000, 128, 2, 4, 32
|
| 76 |
-
|
| 77 |
-
def __init__(self):
|
| 78 |
-
super().__init__()
|
| 79 |
-
self.tok = nn.Embedding(self.VOCAB, self.D)
|
| 80 |
-
self.pos = nn.Embedding(self.T * 4, self.D)
|
| 81 |
-
layer = nn.TransformerEncoderLayer(self.D, self.H, 4 * self.D,
|
| 82 |
-
batch_first=True, norm_first=True)
|
| 83 |
-
self.blocks = nn.TransformerEncoder(layer, self.L)
|
| 84 |
-
self.head = nn.Linear(self.D, self.VOCAB, bias=False)
|
| 85 |
-
self.vocab_size, self.hidden_size = self.VOCAB, self.D
|
| 86 |
-
|
| 87 |
-
def forward(self, ids):
|
| 88 |
-
Tn = ids.size(1)
|
| 89 |
-
x = self.tok(ids) + self.pos(torch.arange(Tn, device=ids.device))
|
| 90 |
-
mask = torch.triu(torch.full((Tn, Tn), float("-inf"),
|
| 91 |
-
device=ids.device), 1)
|
| 92 |
-
h = self.blocks(x, mask=mask)
|
| 93 |
-
return self.head(h), h
|
| 94 |
-
|
| 95 |
-
@torch.no_grad()
|
| 96 |
-
def generate(self, ids, n_new=16, temperature=1.0):
|
| 97 |
-
for _ in range(n_new):
|
| 98 |
-
logits, _ = self(ids[:, -self.T:])
|
| 99 |
-
nxt = torch.multinomial(
|
| 100 |
-
F.softmax(logits[:, -1] / temperature, -1), 1)
|
| 101 |
-
ids = torch.cat([ids, nxt], 1)
|
| 102 |
-
return ids
|
| 103 |
-
|
| 104 |
-
def trainable_parameters(self):
|
| 105 |
-
return self.parameters()
|
| 106 |
-
|
| 107 |
-
@property
|
| 108 |
-
def device(self):
|
| 109 |
-
return next(self.parameters()).device
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
class HFLLM(nn.Module):
|
| 113 |
-
"""Small model from a local path or HF id, frozen base + LoRA."""
|
| 114 |
-
def __init__(self, path, lora_r=16):
|
| 115 |
-
super().__init__()
|
| 116 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 117 |
-
from peft import LoraConfig, get_peft_model
|
| 118 |
-
self.tokenizer = AutoTokenizer.from_pretrained(path)
|
| 119 |
-
if self.tokenizer.pad_token is None:
|
| 120 |
-
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 121 |
-
base = AutoModelForCausalLM.from_pretrained(
|
| 122 |
-
path, torch_dtype=torch.bfloat16
|
| 123 |
-
if torch.cuda.is_available() else torch.float32,
|
| 124 |
-
device_map="auto" if torch.cuda.is_available() else None)
|
| 125 |
-
for p in base.parameters():
|
| 126 |
-
p.requires_grad_(False)
|
| 127 |
-
cfg = LoraConfig(r=lora_r, lora_alpha=2 * lora_r, lora_dropout=0.05,
|
| 128 |
-
target_modules=["q_proj", "v_proj"],
|
| 129 |
-
task_type="CAUSAL_LM")
|
| 130 |
-
self.model = get_peft_model(base, cfg)
|
| 131 |
-
self.vocab_size = self.model.config.vocab_size
|
| 132 |
-
self.hidden_size = self.model.config.hidden_size
|
| 133 |
-
|
| 134 |
-
def forward(self, ids):
|
| 135 |
-
out = self.model(input_ids=ids.to(self.device),
|
| 136 |
-
output_hidden_states=True)
|
| 137 |
-
return out.logits, out.hidden_states[-1]
|
| 138 |
-
|
| 139 |
-
@torch.no_grad()
|
| 140 |
-
def generate(self, ids, n_new=32, temperature=0.8):
|
| 141 |
-
return self.model.generate(
|
| 142 |
-
input_ids=ids.to(self.device), max_new_tokens=n_new,
|
| 143 |
-
do_sample=True, temperature=temperature,
|
| 144 |
-
pad_token_id=self.tokenizer.pad_token_id)
|
| 145 |
-
|
| 146 |
-
def trainable_parameters(self):
|
| 147 |
-
return (p for p in self.model.parameters() if p.requires_grad)
|
| 148 |
-
|
| 149 |
-
@property
|
| 150 |
-
def device(self):
|
| 151 |
-
return next(self.model.parameters()).device
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
def load_llm(spec: str):
|
| 155 |
-
return TinyLLM() if spec == "tiny" else HFLLM(spec)
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
# ============================================================================
|
| 159 |
-
# 2. Embedder (retrieval) — vocab-agnostic
|
| 160 |
-
# ============================================================================
|
| 161 |
-
class Embedder(nn.Module):
|
| 162 |
-
def __init__(self, vocab):
|
| 163 |
-
super().__init__()
|
| 164 |
-
self.tok = nn.Embedding(vocab, D_EMB)
|
| 165 |
-
self.gru = nn.GRU(D_EMB, D_EMB, batch_first=True, bidirectional=True)
|
| 166 |
-
self.out = nn.Linear(2 * D_EMB, D_EMB)
|
| 167 |
-
|
| 168 |
-
def forward(self, ids):
|
| 169 |
-
h, _ = self.gru(self.tok(ids))
|
| 170 |
-
return F.normalize(self.out(h.mean(1)), dim=-1)
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
class VectorStore:
|
| 174 |
-
def __init__(self):
|
| 175 |
-
self.K, self.V = [], []
|
| 176 |
-
|
| 177 |
-
def add(self, k, v):
|
| 178 |
-
self.K.append(k.squeeze(0).detach().cpu()); self.V.append(v)
|
| 179 |
-
|
| 180 |
-
def search(self, q, k=1):
|
| 181 |
-
if not self.K:
|
| 182 |
-
return []
|
| 183 |
-
sims = (q.detach().cpu() @ torch.stack(self.K).t()).squeeze(0)
|
| 184 |
-
return [self.V[i] for i in sims.topk(min(k, len(self.K))).indices]
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
# ============================================================================
|
| 188 |
-
# 3. JEPA — owns the latent space (EMA target encoder, variance-regularized)
|
| 189 |
-
# ============================================================================
|
| 190 |
-
class SegEncoder(nn.Module):
|
| 191 |
-
def __init__(self, vocab):
|
| 192 |
-
super().__init__()
|
| 193 |
-
self.tok = nn.Embedding(vocab, D_LAT)
|
| 194 |
-
self.gru = nn.GRU(D_LAT, D_LAT, batch_first=True)
|
| 195 |
-
self.out = nn.Linear(D_LAT, D_LAT)
|
| 196 |
-
|
| 197 |
-
def forward(self, ids):
|
| 198 |
-
h, _ = self.gru(self.tok(ids))
|
| 199 |
-
return self.out(h.mean(1)) # [B, D_LAT]
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
class JEPA(nn.Module):
|
| 203 |
-
def __init__(self, vocab, m=0.996):
|
| 204 |
-
super().__init__()
|
| 205 |
-
self.enc = SegEncoder(vocab) # context/online enc
|
| 206 |
-
self.tgt = copy.deepcopy(self.enc) # EMA target
|
| 207 |
-
for p in self.tgt.parameters():
|
| 208 |
-
p.requires_grad_(False)
|
| 209 |
-
self.pred = nn.Sequential(nn.Linear(D_LAT, 2 * D_LAT), nn.GELU(),
|
| 210 |
-
nn.Linear(2 * D_LAT, D_LAT))
|
| 211 |
-
self.m = m
|
| 212 |
-
|
| 213 |
-
@torch.no_grad()
|
| 214 |
-
def ema(self):
|
| 215 |
-
for pt, pc in zip(self.tgt.parameters(), self.enc.parameters()):
|
| 216 |
-
pt.mul_(self.m).add_(pc.detach(), alpha=1 - self.m)
|
| 217 |
-
|
| 218 |
-
def loss(self, seg_a, seg_b):
|
| 219 |
-
z_hat = self.pred(self.enc(seg_a))
|
| 220 |
-
with torch.no_grad():
|
| 221 |
-
z_tgt = self.tgt(seg_b)
|
| 222 |
-
var = F.relu(1.0 - z_hat.std(0)).mean() # anti-collapse
|
| 223 |
-
return F.mse_loss(z_hat, z_tgt) + 0.5 * var
|
| 224 |
-
|
| 225 |
-
@torch.no_grad()
|
| 226 |
-
def encode(self, seg): # target space
|
| 227 |
-
return self.tgt(seg)
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
# ============================================================================
|
| 231 |
-
# 4. WORLD MODEL — latent dynamics s_{t+1} = f(s_t, z_t), multi-step rollout
|
| 232 |
-
# Trained on SEQUENCES of segments: predict each next latent from state.
|
| 233 |
-
# ============================================================================
|
| 234 |
-
class WorldModel(nn.Module):
|
| 235 |
-
def __init__(self):
|
| 236 |
-
super().__init__()
|
| 237 |
-
self.cell = nn.GRUCell(D_LAT, D_LAT) # state update
|
| 238 |
-
self.head = nn.Linear(D_LAT, D_LAT) # state -> ẑ_{t+1}
|
| 239 |
-
self.s0 = nn.Parameter(torch.zeros(D_LAT))
|
| 240 |
-
|
| 241 |
-
def init_state(self, B, device):
|
| 242 |
-
return self.s0.unsqueeze(0).expand(B, -1).contiguous().to(device)
|
| 243 |
-
|
| 244 |
-
def step(self, s, z):
|
| 245 |
-
"""Consume observed latent z_t, return (new state, prediction ẑ_{t+1})."""
|
| 246 |
-
s = self.cell(z, s)
|
| 247 |
-
return s, self.head(s)
|
| 248 |
-
|
| 249 |
-
def rollout(self, s, horizon):
|
| 250 |
-
"""Imagine H future latents feeding its own predictions back in."""
|
| 251 |
-
preds = []
|
| 252 |
-
for _ in range(horizon):
|
| 253 |
-
z_hat = self.head(s)
|
| 254 |
-
preds.append(z_hat)
|
| 255 |
-
s = self.cell(z_hat, s)
|
| 256 |
-
return torch.stack(preds, 1), s # [B, H, D_LAT]
|
| 257 |
-
|
| 258 |
-
def sequence_loss(self, z_seq):
|
| 259 |
-
"""z_seq: [B, T, D_LAT] observed segment latents (teacher forcing)."""
|
| 260 |
-
B, T, _ = z_seq.shape
|
| 261 |
-
s = self.init_state(B, z_seq.device)
|
| 262 |
-
loss = 0.0
|
| 263 |
-
for t in range(T - 1):
|
| 264 |
-
s, z_hat = self.step(s, z_seq[:, t])
|
| 265 |
-
loss = loss + F.mse_loss(z_hat, z_seq[:, t + 1])
|
| 266 |
-
return loss / (T - 1)
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
# ============================================================================
|
| 270 |
-
# 5. ENERGY MODEL — E(state, candidate latent) ∈ R, low = plausible
|
| 271 |
-
# Trained with InfoNCE-style contrastive: positives = true next latent,
|
| 272 |
-
# negatives = (a) other batch items, (b) LLM-generated drafts (optional).
|
| 273 |
-
# ============================================================================
|
| 274 |
-
class EnergyModel(nn.Module):
|
| 275 |
-
def __init__(self):
|
| 276 |
-
super().__init__()
|
| 277 |
-
self.net = nn.Sequential(
|
| 278 |
-
nn.Linear(2 * D_LAT, 2 * D_LAT), nn.GELU(),
|
| 279 |
-
nn.Linear(2 * D_LAT, D_LAT), nn.GELU(),
|
| 280 |
-
nn.Linear(D_LAT, 1))
|
| 281 |
-
|
| 282 |
-
def energy(self, s, z):
|
| 283 |
-
"""s: [B, D_LAT] context state; z: [B, D_LAT] candidate. -> [B]"""
|
| 284 |
-
return self.net(torch.cat([s, z], -1)).squeeze(-1)
|
| 285 |
-
|
| 286 |
-
def contrastive_loss(self, s, z_pos, z_negs=None, tau=0.5):
|
| 287 |
-
"""Softmax over energies: true continuation must be the argmin.
|
| 288 |
-
In-batch negatives: every other item's z_pos is a negative for s_i."""
|
| 289 |
-
B = s.size(0)
|
| 290 |
-
# pairwise energies: E(s_i, z_j) for all i, j
|
| 291 |
-
s_rep = s.unsqueeze(1).expand(B, B, D_LAT).reshape(B * B, D_LAT)
|
| 292 |
-
z_rep = z_pos.unsqueeze(0).expand(B, B, D_LAT).reshape(B * B, D_LAT)
|
| 293 |
-
E = self.energy(s_rep, z_rep).view(B, B) # [B, B]
|
| 294 |
-
if z_negs is not None: # extra hard negatives
|
| 295 |
-
En = self.energy(
|
| 296 |
-
s.repeat_interleave(z_negs.size(1), 0),
|
| 297 |
-
z_negs.reshape(-1, D_LAT)).view(B, -1)
|
| 298 |
-
E = torch.cat([E, En], dim=1)
|
| 299 |
-
labels = torch.arange(B, device=s.device)
|
| 300 |
-
return F.cross_entropy(-E / tau, labels) # low E ⇒ high logit
|
| 301 |
-
|
| 302 |
-
@torch.no_grad()
|
| 303 |
-
def rank(self, s, candidates):
|
| 304 |
-
"""candidates: [N, D_LAT]; returns energies [N] (lower = better)."""
|
| 305 |
-
return self.energy(s.expand(candidates.size(0), -1), candidates)
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
# ============================================================================
|
| 309 |
-
# 6. Bridge — LLM hidden states -> shared latent space (alignment)
|
| 310 |
-
# ============================================================================
|
| 311 |
-
class Bridge(nn.Module):
|
| 312 |
-
def __init__(self, d_hidden):
|
| 313 |
-
super().__init__()
|
| 314 |
-
self.proj = nn.Sequential(nn.Linear(d_hidden, D_LAT), nn.GELU(),
|
| 315 |
-
nn.Linear(D_LAT, D_LAT))
|
| 316 |
-
|
| 317 |
-
def forward(self, h): # [B,T,H] -> [B,D_LAT]
|
| 318 |
-
return self.proj(h.float().mean(1))
|
| 319 |
-
|
| 320 |
-
def info_nce(self, a, b, tau=0.07):
|
| 321 |
-
a, b = F.normalize(a, -1), F.normalize(b, -1)
|
| 322 |
-
logits = a @ b.t() / tau
|
| 323 |
-
y = torch.arange(a.size(0), device=a.device)
|
| 324 |
-
return 0.5 * (F.cross_entropy(logits, y) +
|
| 325 |
-
F.cross_entropy(logits.t(), y))
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
# ============================================================================
|
| 329 |
-
# 7. THE ENSEMBLE — wiring + inference (plan -> generate -> energy-rank)
|
| 330 |
-
# ============================================================================
|
| 331 |
-
class WorldEnsemble(nn.Module):
|
| 332 |
-
def __init__(self, llm_spec="tiny"):
|
| 333 |
-
super().__init__()
|
| 334 |
-
self.llm = load_llm(llm_spec)
|
| 335 |
-
V, H = self.llm.vocab_size, self.llm.hidden_size
|
| 336 |
-
self.emb = Embedder(V)
|
| 337 |
-
self.jepa = JEPA(V)
|
| 338 |
-
self.world = WorldModel()
|
| 339 |
-
self.energy = EnergyModel()
|
| 340 |
-
self.bridge = Bridge(H)
|
| 341 |
-
self.store = VectorStore()
|
| 342 |
-
|
| 343 |
-
# ------------------------- inference ---------------------------------
|
| 344 |
-
@torch.no_grad()
|
| 345 |
-
def world_state(self, segments):
|
| 346 |
-
"""Fold a list of [1,T] segment tensors into a latent state."""
|
| 347 |
-
s = self.world.init_state(1, "cpu")
|
| 348 |
-
for seg in segments:
|
| 349 |
-
z = self.jepa.encode(seg.cpu())
|
| 350 |
-
s, _ = self.world.step(s, z)
|
| 351 |
-
return s
|
| 352 |
-
|
| 353 |
-
@torch.no_grad()
|
| 354 |
-
def answer(self, query_ids, n_new=24, n_drafts=6, horizon=3):
|
| 355 |
-
"""retrieve -> build world state -> imagine -> generate N -> argmin E."""
|
| 356 |
-
q_emb = self.emb(query_ids.cpu())
|
| 357 |
-
mems = self.store.search(q_emb, k=1)
|
| 358 |
-
segments = (mems + [query_ids.cpu()]) if mems else [query_ids.cpu()]
|
| 359 |
-
ctx = torch.cat(segments, dim=1)
|
| 360 |
-
|
| 361 |
-
s = self.world_state(segments) # latent context state
|
| 362 |
-
plan, _ = self.world.rollout(s, horizon) # imagined future
|
| 363 |
-
# (plan is available for planning losses / steering; logged here)
|
| 364 |
-
|
| 365 |
-
drafts, lat = [], []
|
| 366 |
-
for _ in range(n_drafts):
|
| 367 |
-
out = self.llm.generate(ctx.to(self.llm.device), n_new=n_new,
|
| 368 |
-
temperature=0.9)
|
| 369 |
-
new = out[:, ctx.size(1):].cpu()
|
| 370 |
-
drafts.append(new)
|
| 371 |
-
lat.append(self.jepa.encode(new))
|
| 372 |
-
Z = torch.cat(lat, 0) # [N, D_LAT]
|
| 373 |
-
E = self.energy.rank(s, Z) # lower = better
|
| 374 |
-
best = E.argmin().item()
|
| 375 |
-
return {"output": drafts[best], "energy": E[best].item(),
|
| 376 |
-
"all_energies": E.tolist(),
|
| 377 |
-
"plan_alignment": F.cosine_similarity(
|
| 378 |
-
plan[:, 0], Z[best:best + 1]).item()}
|
| 379 |
-
|
| 380 |
-
def memorize(self, ids):
|
| 381 |
-
self.store.add(self.emb(ids.cpu()), ids.cpu())
|
| 382 |
-
|
| 383 |
-
# ------------------------- training ----------------------------------
|
| 384 |
-
def train_step(self, seg_seq, opt, w=dict(lm=1.0, jepa=1.0, world=1.0,
|
| 385 |
-
ebm=1.0, bridge=0.1),
|
| 386 |
-
hard_negs=True):
|
| 387 |
-
"""seg_seq: [B, T_seg, L] — B documents, each split into T_seg
|
| 388 |
-
consecutive segments of length L (same tokenizer as the LLM)."""
|
| 389 |
-
B, T, L = seg_seq.shape
|
| 390 |
-
dev = self.llm.device
|
| 391 |
-
|
| 392 |
-
# (1) LM loss on the first segment (or all, batched, if budget allows)
|
| 393 |
-
flat = seg_seq[:, 0].to(dev)
|
| 394 |
-
logits, hidden = self.llm(flat)
|
| 395 |
-
lm = F.cross_entropy(
|
| 396 |
-
logits[:, :-1].reshape(-1, self.llm.vocab_size).float(),
|
| 397 |
-
flat[:, 1:].reshape(-1))
|
| 398 |
-
|
| 399 |
-
# (2) JEPA: adjacent segment pairs
|
| 400 |
-
jepa = self.jepa.loss(seg_seq[:, 0], seg_seq[:, 1])
|
| 401 |
-
|
| 402 |
-
# (3) World model: sequence of latents (online encoder, grads flow)
|
| 403 |
-
z_seq = torch.stack([self.jepa.enc(seg_seq[:, t])
|
| 404 |
-
for t in range(T)], 1) # [B, T, D_LAT]
|
| 405 |
-
world = self.world.sequence_loss(z_seq)
|
| 406 |
-
|
| 407 |
-
# (4) Energy: state after t=0 must give low E to true z_1,
|
| 408 |
-
# high E to in-batch + (optionally) LLM-generated negatives
|
| 409 |
-
s = self.world.init_state(B, z_seq.device)
|
| 410 |
-
s, _ = self.world.step(s, z_seq[:, 0].detach())
|
| 411 |
-
z_pos = z_seq[:, 1].detach()
|
| 412 |
-
z_negs = None
|
| 413 |
-
if hard_negs:
|
| 414 |
-
with torch.no_grad(): # model drafts as negs
|
| 415 |
-
gen = self.llm.generate(seg_seq[:, 0].to(dev), n_new=L)
|
| 416 |
-
gen_new = gen[:, seg_seq.size(2):].cpu()
|
| 417 |
-
z_negs = self.jepa.encode(gen_new).unsqueeze(1) # [B,1,D]
|
| 418 |
-
ebm = self.energy.contrastive_loss(s, z_pos, z_negs)
|
| 419 |
-
|
| 420 |
-
# (5) Bridge: align LLM hidden(seg0) with JEPA latent(seg0)
|
| 421 |
-
bridge = self.bridge.info_nce(
|
| 422 |
-
self.bridge(hidden.cpu() if hidden.device.type != "cpu" else hidden),
|
| 423 |
-
self.jepa.enc(seg_seq[:, 0]).detach())
|
| 424 |
-
|
| 425 |
-
loss = (w["lm"] * lm.cpu() + w["jepa"] * jepa + w["world"] * world
|
| 426 |
-
+ w["ebm"] * ebm + w["bridge"] * bridge)
|
| 427 |
-
opt.zero_grad(); loss.backward(); opt.step()
|
| 428 |
-
self.jepa.ema()
|
| 429 |
-
return dict(lm=lm.item(), jepa=jepa.item(), world=world.item(),
|
| 430 |
-
ebm=ebm.item(), bridge=bridge.item())
|
| 431 |
-
|
| 432 |
-
def make_optimizer(self, lr_lora=2e-4, lr_aux=1e-3):
|
| 433 |
-
return torch.optim.AdamW([
|
| 434 |
-
{"params": list(self.llm.trainable_parameters()), "lr": lr_lora},
|
| 435 |
-
{"params": list(self.jepa.enc.parameters())
|
| 436 |
-
+ list(self.jepa.pred.parameters()), "lr": lr_aux},
|
| 437 |
-
{"params": list(self.world.parameters()), "lr": lr_aux},
|
| 438 |
-
{"params": list(self.energy.parameters()), "lr": lr_aux},
|
| 439 |
-
{"params": list(self.bridge.parameters())
|
| 440 |
-
+ list(self.emb.parameters()), "lr": lr_aux}])
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
# ============================================================================
|
| 444 |
-
# 8. Data helpers + demo
|
| 445 |
-
# ============================================================================
|
| 446 |
-
def toy_segment_sequences(B=8, T=4, L=24, vocab=1000):
|
| 447 |
-
"""Random docs split into T consecutive segments. Replace with real
|
| 448 |
-
corpus: tokenize each document, reshape into [T, L] windows."""
|
| 449 |
-
return torch.randint(0, vocab, (B, T, L))
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
def hf_segment_sequences(llm: HFLLM, texts, T=4, L=64):
|
| 453 |
-
seqs = []
|
| 454 |
-
for t in texts:
|
| 455 |
-
ids = llm.tokenizer(t, return_tensors="pt").input_ids[0]
|
| 456 |
-
n = (len(ids) // (T * L)) * T * L
|
| 457 |
-
if n:
|
| 458 |
-
seqs.append(ids[:n].view(-1, T, L))
|
| 459 |
-
if not seqs:
|
| 460 |
-
raise ValueError("corpus too short for T*L window")
|
| 461 |
-
return torch.cat(seqs, 0)
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
def demo(spec="tiny", steps=60):
|
| 465 |
-
ens = WorldEnsemble(spec)
|
| 466 |
-
opt = ens.make_optimizer()
|
| 467 |
-
|
| 468 |
-
if spec == "tiny":
|
| 469 |
-
get_batch = lambda: toy_segment_sequences(vocab=ens.llm.vocab_size)
|
| 470 |
-
else:
|
| 471 |
-
corpus = ["Replace with your real documents. " * 200]
|
| 472 |
-
data = hf_segment_sequences(ens.llm, corpus, T=4, L=32)
|
| 473 |
-
get_batch = lambda: data[torch.randperm(len(data))[:4]]
|
| 474 |
-
steps = min(steps, 10)
|
| 475 |
-
|
| 476 |
-
t0 = time.time()
|
| 477 |
-
for s in range(steps):
|
| 478 |
-
logs = ens.train_step(get_batch(), opt,
|
| 479 |
-
hard_negs=(s > steps // 2)) # warmup w/o negs
|
| 480 |
-
if s % 10 == 0:
|
| 481 |
-
print(f"step {s:3d} | " +
|
| 482 |
-
" | ".join(f"{k} {v:.3f}" for k, v in logs.items()))
|
| 483 |
-
print(f"trained {steps} steps in {time.time()-t0:.1f}s")
|
| 484 |
-
|
| 485 |
-
# memory + inference
|
| 486 |
-
for _ in range(4):
|
| 487 |
-
if spec == "tiny":
|
| 488 |
-
ens.memorize(torch.randint(0, ens.llm.vocab_size, (1, 24)))
|
| 489 |
-
q = (torch.randint(0, ens.llm.vocab_size, (1, 12)) if spec == "tiny"
|
| 490 |
-
else ens.llm.tokenizer("What is this document about?",
|
| 491 |
-
return_tensors="pt").input_ids)
|
| 492 |
-
res = ens.answer(q, n_drafts=6, horizon=3)
|
| 493 |
-
print(f"\nselected draft energy={res['energy']:.3f} "
|
| 494 |
-
f"(all: {[f'{e:.2f}' for e in res['all_energies']]})")
|
| 495 |
-
print(f"plan↔output alignment: {res['plan_alignment']:.3f}")
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
if __name__ == "__main__":
|
| 499 |
-
demo(sys.argv[1] if len(sys.argv) > 1 else "tiny")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
research/eval_harness.py
DELETED
|
@@ -1,6 +0,0 @@
|
|
| 1 |
-
"""Deprecated shim — use `ensemble.eval.jepa_harness` instead."""
|
| 2 |
-
|
| 3 |
-
from ensemble.eval.jepa_harness import run, parse_args
|
| 4 |
-
|
| 5 |
-
if __name__ == "__main__":
|
| 6 |
-
run(parse_args())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
research/evals/USAGE.md
CHANGED
|
@@ -197,12 +197,6 @@ uv run --package slm-evals slm-lm-eval \
|
|
| 197 |
--model openbmb/MiniCPM5-1B \
|
| 198 |
--adapter ./models/finetuned/minicpm5-1b-lora \
|
| 199 |
--experiment-name minicpm5-1b-lora__manual
|
| 200 |
-
|
| 201 |
-
# Ensemble checkpoint (manifest.json auto-detected)
|
| 202 |
-
uv run --package slm-evals slm-lm-eval \
|
| 203 |
-
--config research/evals/configs/lm_eval_smoke.yaml \
|
| 204 |
-
--model ./models/ensemble/jepa-lesson-pretrain \
|
| 205 |
-
--experiment-name ensemble-jepa__lm-eval
|
| 206 |
```
|
| 207 |
|
| 208 |
### Compare baseline vs candidate
|
|
@@ -259,8 +253,8 @@ slm-lm-eval [OPTIONS]
|
|
| 259 |
--list-tasks-all Full lm-eval task list
|
| 260 |
--profile NAME Shorthand for --config (reasoning, code, smoke, …)
|
| 261 |
--config PATH YAML config (tasks, seed, limit, …)
|
| 262 |
-
--preset KEY models.yaml preset (base, LoRA, merged
|
| 263 |
-
--model PATH HF Hub id
|
| 264 |
--adapter PATH LoRA adapter (alternative to preset adapter_path)
|
| 265 |
--tasks NAMES Override task list
|
| 266 |
--num-fewshot N
|
|
@@ -285,12 +279,6 @@ Each run writes to `<output_dir>/<experiment_name>/`:
|
|
| 285 |
| `run_meta.json` | Preset, base model, adapter, tasks, seed |
|
| 286 |
| `comparison.md` | Delta table (when `--compare-to` set) |
|
| 287 |
|
| 288 |
-
### Ensemble backend notes
|
| 289 |
-
|
| 290 |
-
- **`ensemble-lm`** loads JEPA checkpoints via `manifest.json`.
|
| 291 |
-
- **`generate_until`** tasks (e.g. `gsm8k`) use the full ensemble stack (`generate_text`).
|
| 292 |
-
- **`loglikelihood`** tasks (e.g. `arc_easy`, `hellaswag`) score the underlying HF LLM head (adapter 0), not the JEPA selector. Use [`jepa_harness`](../ensemble/README.md) to measure selector value on domain QA.
|
| 293 |
-
|
| 294 |
### PEFT / LoRA
|
| 295 |
|
| 296 |
lm-eval expects `pretrained=<base>,peft=<adapter>`. The preset resolver handles this for keys like `minicpm5-1b-lesson-lora`. Merged checkpoints use `--preset minicpm5-1b-lesson-merged` or `--model ./models/finetuned/...-merged`.
|
|
|
|
| 197 |
--model openbmb/MiniCPM5-1B \
|
| 198 |
--adapter ./models/finetuned/minicpm5-1b-lora \
|
| 199 |
--experiment-name minicpm5-1b-lora__manual
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
```
|
| 201 |
|
| 202 |
### Compare baseline vs candidate
|
|
|
|
| 253 |
--list-tasks-all Full lm-eval task list
|
| 254 |
--profile NAME Shorthand for --config (reasoning, code, smoke, …)
|
| 255 |
--config PATH YAML config (tasks, seed, limit, …)
|
| 256 |
+
--preset KEY models.yaml preset (base, LoRA, merged)
|
| 257 |
+
--model PATH HF Hub id or merged checkpoint dir
|
| 258 |
--adapter PATH LoRA adapter (alternative to preset adapter_path)
|
| 259 |
--tasks NAMES Override task list
|
| 260 |
--num-fewshot N
|
|
|
|
| 279 |
| `run_meta.json` | Preset, base model, adapter, tasks, seed |
|
| 280 |
| `comparison.md` | Delta table (when `--compare-to` set) |
|
| 281 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 282 |
### PEFT / LoRA
|
| 283 |
|
| 284 |
lm-eval expects `pretrained=<base>,peft=<adapter>`. The preset resolver handles this for keys like `minicpm5-1b-lesson-lora`. Merged checkpoints use `--preset minicpm5-1b-lesson-merged` or `--model ./models/finetuned/...-merged`.
|
research/evals/configs/ensemble_jepa_lesson.yaml
DELETED
|
@@ -1,24 +0,0 @@
|
|
| 1 |
-
# JEPA ensemble checkpoint (models/ensemble/jepa-lesson-pretrain)
|
| 2 |
-
# Pretrain: uv run --package ensemble ensemble-pretrain --llm Qwen/Qwen2.5-0.5B-Instruct
|
| 3 |
-
# Compare baseline: copy this file, set model_path to the base Hub id and model_type: hf
|
| 4 |
-
|
| 5 |
-
model_path: "./models/ensemble/jepa-lesson-pretrain"
|
| 6 |
-
model_type: "ensemble"
|
| 7 |
-
device: "auto"
|
| 8 |
-
dtype: "bfloat16"
|
| 9 |
-
|
| 10 |
-
max_new_tokens: 512
|
| 11 |
-
temperature: 0.0
|
| 12 |
-
|
| 13 |
-
experiment_name: "jepa-ensemble-lesson__bfcl-tau__v1"
|
| 14 |
-
output_dir: "results"
|
| 15 |
-
|
| 16 |
-
benchmarks:
|
| 17 |
-
- bfcl
|
| 18 |
-
- tau_bench
|
| 19 |
-
|
| 20 |
-
max_samples: 20
|
| 21 |
-
|
| 22 |
-
benchmark_overrides:
|
| 23 |
-
tau_bench:
|
| 24 |
-
use_llm_user: false
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|