Feature Extraction
sentence-transformers
Safetensors
English
sparse-encoder
splade
sparse-retrieval
opensearch
elasticsearch
financial-filings
Instructions to use oneryalcin/financial-filings-sparse-encoder-v1 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- sentence-transformers
How to use oneryalcin/financial-filings-sparse-encoder-v1 with sentence-transformers:
from sentence_transformers import SentenceTransformer model = SentenceTransformer("oneryalcin/financial-filings-sparse-encoder-v1") sentences = [ "The weather is lovely today.", "It's so sunny outside!", "He drove to the stadium." ] embeddings = model.encode(sentences) similarities = model.similarity(embeddings, embeddings) print(similarities.shape) # [3, 3] - Notebooks
- Google Colab
- Kaggle
Add financial filings sparse encoder v1
Browse files- README.md +333 -0
- benchmarks/regularization_sweep_500_steps.json +25 -0
- benchmarks/retrieval_proxy_1000.json +109 -0
- benchmarks/triplet_topk_1000.json +48 -0
- config_sentence_transformers.json +14 -0
- document_0_MLMTransformer/config.json +27 -0
- document_0_MLMTransformer/model.safetensors +3 -0
- document_0_MLMTransformer/sentence_bert_config.json +10 -0
- document_0_MLMTransformer/tokenizer.json +0 -0
- document_0_MLMTransformer/tokenizer_config.json +16 -0
- document_1_SpladePooling/config.json +5 -0
- modules.json +8 -0
- query_0_SparseStaticEmbedding/config.json +3 -0
- query_0_SparseStaticEmbedding/model.safetensors +3 -0
- query_0_SparseStaticEmbedding/tokenizer.json +0 -0
- query_0_SparseStaticEmbedding/tokenizer_config.json +16 -0
- router_config.json +21 -0
- scripts/eval_fin_bm25_retrieval_proxy.py +155 -0
- scripts/eval_fin_sparse_retrieval_proxy.py +170 -0
- scripts/eval_fin_sparse_topk.py +144 -0
- scripts/train_fin_sparse_encoder_v2.py +238 -0
README.md
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
language:
|
| 3 |
+
- en
|
| 4 |
+
library_name: sentence-transformers
|
| 5 |
+
pipeline_tag: feature-extraction
|
| 6 |
+
tags:
|
| 7 |
+
- sentence-transformers
|
| 8 |
+
- sparse-encoder
|
| 9 |
+
- splade
|
| 10 |
+
- sparse-retrieval
|
| 11 |
+
- opensearch
|
| 12 |
+
- elasticsearch
|
| 13 |
+
- financial-filings
|
| 14 |
+
base_model: opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill
|
| 15 |
+
datasets:
|
| 16 |
+
- oneryalcin/financial-filings-sparse-retrieval-training
|
| 17 |
+
---
|
| 18 |
+
|
| 19 |
+
# Financial Filings Sparse Encoder v1
|
| 20 |
+
|
| 21 |
+
This is a Sentence Transformers `SparseEncoder` / SPLADE-style model fine-tuned for financial filing retrieval.
|
| 22 |
+
|
| 23 |
+
The practical recommendation from the experiments below is to index document vectors after **top-128 pruning**. In the current proxy retrieval benchmark, top-128 keeps almost all unpruned quality while reducing each document to about 126 active sparse terms.
|
| 24 |
+
|
| 25 |
+
## TL;DR
|
| 26 |
+
|
| 27 |
+
| setting | value |
|
| 28 |
+
|---|---:|
|
| 29 |
+
| Base model | `opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill` |
|
| 30 |
+
| Dataset | `oneryalcin/financial-filings-sparse-retrieval-training`, `combined` config |
|
| 31 |
+
| Training recipe | `SpladeLoss(SparseMultipleNegativesRankingLoss)` |
|
| 32 |
+
| Final steps | 1500 |
|
| 33 |
+
| Recommended serving | document `top_k=128` |
|
| 34 |
+
| Triplet accuracy, top-128 | 78.0% |
|
| 35 |
+
| Retrieval proxy Recall@10, top-128 | 67.2% |
|
| 36 |
+
| Retrieval proxy nDCG@10, top-128 | 0.521 |
|
| 37 |
+
|
| 38 |
+
This is an experiment report as much as a model card. It describes what was tried, why decisions were made, what worked, and what remains unproven.
|
| 39 |
+
|
| 40 |
+
## Why this exists
|
| 41 |
+
|
| 42 |
+
Financial filing search is not generic semantic similarity. Queries often refer to company events, accounting concepts, risk factors, segment details, and filing-specific language. A learned sparse model is attractive because it can improve ranking while preserving an Elasticsearch/OpenSearch-style sparse retrieval path.
|
| 43 |
+
|
| 44 |
+
The goal here was to train a domain-adapted sparse encoder that can be deployed as weighted sparse terms, then test whether the learned sparse signal is meaningfully better than both the base sparse model and a lexical BM25 baseline.
|
| 45 |
+
|
| 46 |
+
## Starting point
|
| 47 |
+
|
| 48 |
+
This experiment started after the Sentence Transformers `v5.5.0` release and its new `train-sentence-transformers` agent skill. The release made it easier to create a complete sparse-encoder training workflow: base model choice, loss selection, sparse regularization, top-k pruning checks, and model-card packaging.
|
| 49 |
+
|
| 50 |
+
Before training, the important constraints were:
|
| 51 |
+
|
| 52 |
+
- Use a sparse encoder suitable for Elasticsearch/OpenSearch-style retrieval.
|
| 53 |
+
- Keep vectors sparse enough to index in practice.
|
| 54 |
+
- Use the user's existing Hugging Face dataset and local Apple Silicon machine.
|
| 55 |
+
- Avoid overclaiming from pairwise training accuracy alone.
|
| 56 |
+
|
| 57 |
+
## Base model decision
|
| 58 |
+
|
| 59 |
+
Chosen base model:
|
| 60 |
+
|
| 61 |
+
```text
|
| 62 |
+
opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
Reason: it already has separate query/document sparse encoding behavior and is aligned with OpenSearch neural sparse retrieval. Starting here means fine-tuning adapts a serving-compatible sparse model rather than building a new retrieval stack from scratch.
|
| 66 |
+
|
| 67 |
+
## Dataset
|
| 68 |
+
|
| 69 |
+
Dataset:
|
| 70 |
+
|
| 71 |
+
```text
|
| 72 |
+
oneryalcin/financial-filings-sparse-retrieval-training
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
Config:
|
| 76 |
+
|
| 77 |
+
```text
|
| 78 |
+
combined
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
Each usable row is treated as:
|
| 82 |
+
|
| 83 |
+
```text
|
| 84 |
+
query
|
| 85 |
+
positive filing chunk
|
| 86 |
+
first non-empty hard negative filing chunk
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
Training used up to 20,000 requested rows. After dropping rows without a non-empty negative, there were 18,247 usable training triplets.
|
| 90 |
+
|
| 91 |
+
Evaluation used the first 1,000 usable rows from the held-out `test` split.
|
| 92 |
+
|
| 93 |
+
Known dataset fields include `query`, `positive`, `negatives`, `query_type`, `company`, and `doc_type`.
|
| 94 |
+
|
| 95 |
+
## Training recipe
|
| 96 |
+
|
| 97 |
+
Model type:
|
| 98 |
+
|
| 99 |
+
```text
|
| 100 |
+
sentence_transformers.sparse_encoder.SparseEncoder
|
| 101 |
+
```
|
| 102 |
+
|
| 103 |
+
Loss:
|
| 104 |
+
|
| 105 |
+
```text
|
| 106 |
+
SpladeLoss(
|
| 107 |
+
SparseMultipleNegativesRankingLoss,
|
| 108 |
+
query_regularizer_weight=1e-4,
|
| 109 |
+
document_regularizer_weight=1e-2,
|
| 110 |
+
)
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
Important settings:
|
| 114 |
+
|
| 115 |
+
| setting | value |
|
| 116 |
+
|---|---:|
|
| 117 |
+
| max steps | 1500 |
|
| 118 |
+
| batch size | 8 |
|
| 119 |
+
| train rows requested | 20,000 |
|
| 120 |
+
| usable train rows | 18,247 |
|
| 121 |
+
| query regularization | `1e-4` |
|
| 122 |
+
| document regularization | `1e-2` |
|
| 123 |
+
| sampler | no-duplicates batch sampler |
|
| 124 |
+
| platform | local Apple Silicon, MPS where supported |
|
| 125 |
+
| package cutoff | `uv run --exclude-newer=2026-05-13T00:00:00Z` |
|
| 126 |
+
|
| 127 |
+
Why this recipe:
|
| 128 |
+
|
| 129 |
+
- `SparseMultipleNegativesRankingLoss` teaches the model to rank the paired positive above in-batch negatives.
|
| 130 |
+
- `SpladeLoss` adds FLOPS-style sparse regularization so the model does not emit unindexably dense document vectors.
|
| 131 |
+
- A no-duplicates sampler avoids false negatives inside contrastive batches.
|
| 132 |
+
- Document regularization matters more than query regularization for index footprint.
|
| 133 |
+
|
| 134 |
+
## Decision log
|
| 135 |
+
|
| 136 |
+
| decision point | choice | why |
|
| 137 |
+
|---|---|---|
|
| 138 |
+
| Retrieval family | `SparseEncoder` / SPLADE-style sparse retrieval | Best fit for Elasticsearch/OpenSearch sparse retrieval. |
|
| 139 |
+
| Base model | OpenSearch sparse doc v2 distill | Already query/document routed and serving-aligned. |
|
| 140 |
+
| Vocab | Keep base tokenizer/vocab | Changing vocabulary would be a larger pretraining-style project, not a fine-tune. |
|
| 141 |
+
| Loss | `SpladeLoss + SparseMultipleNegativesRankingLoss` | Combines ranking pressure with sparse regularization. |
|
| 142 |
+
| Sampler | no-duplicates | Reduces accidental false negatives in batch. |
|
| 143 |
+
| Doc regularization | increased to `1e-2` | Lower regularization gave high scores but thousands of active doc terms. |
|
| 144 |
+
| Serving vector size | `top_k=128` | Preserved almost all quality with much smaller document vectors. |
|
| 145 |
+
| Final run length | 1500 steps | 500-step runs were promising; longer training improved both triplet and retrieval proxy scores. |
|
| 146 |
+
| Evaluation strategy | triplet + retrieval proxy + base sparse + BM25 | Avoids relying on a single easy metric. |
|
| 147 |
+
|
| 148 |
+
## Regularization sweep
|
| 149 |
+
|
| 150 |
+
The first 500-step sweep showed the main tradeoff. Low document regularization produced good triplet accuracy but extremely dense document vectors.
|
| 151 |
+
|
| 152 |
+
| doc reg | triplet accuracy | positive doc dims | interpretation |
|
| 153 |
+
|---:|---:|---:|---|
|
| 154 |
+
| `8e-5` | 73.8% | 4204.9 | Too dense. |
|
| 155 |
+
| `1.5e-4` | 74.8% | 3214.8 | Best raw 500-step accuracy, still too dense. |
|
| 156 |
+
| `3e-4` | 73.4% | 2101.4 | Still too dense. |
|
| 157 |
+
| `1e-3` | 72.7% | 1340.5 | Better, still high. |
|
| 158 |
+
| `3e-3` | 73.4% | 657.1 | Practical direction. |
|
| 159 |
+
| `1e-2` | 73.2% | 296.7 | Chosen for longer run. |
|
| 160 |
+
| `3e-2` | 72.3% | 160.9 | Compact, lower quality. |
|
| 161 |
+
|
| 162 |
+
The final model uses `document_regularizer_weight=1e-2` and 1500 steps.
|
| 163 |
+
|
| 164 |
+
## Evaluation 1: triplet accuracy
|
| 165 |
+
|
| 166 |
+
For each held-out row:
|
| 167 |
+
|
| 168 |
+
```text
|
| 169 |
+
score(query, positive) > score(query, hard_negative)
|
| 170 |
+
```
|
| 171 |
+
|
| 172 |
+
Metric:
|
| 173 |
+
|
| 174 |
+
```text
|
| 175 |
+
accuracy = fraction of rows where positive scores higher than the paired negative
|
| 176 |
+
```
|
| 177 |
+
|
| 178 |
+
| model | doc pruning | accuracy | mean margin | query dims | positive doc dims |
|
| 179 |
+
|---|---|---:|---:|---:|---:|
|
| 180 |
+
| base sparse | unpruned | 49.6% | 0.451 | 15.2 | 371.1 |
|
| 181 |
+
| fine-tuned sparse | unpruned | 78.1% | 3.004 | 15.2 | 340.9 |
|
| 182 |
+
| fine-tuned sparse | top-128 | 78.0% | 2.992 | 15.2 | 126.7 |
|
| 183 |
+
| fine-tuned sparse | top-64 | 75.9% | 2.988 | 15.2 | 64.0 |
|
| 184 |
+
|
| 185 |
+
Interpretation: the fine-tuned model learned the domain signal strongly. Top-128 preserved almost all pairwise quality.
|
| 186 |
+
|
| 187 |
+
## Evaluation 2: in-memory retrieval proxy
|
| 188 |
+
|
| 189 |
+
A retrieval candidate pool was built from the held-out test split:
|
| 190 |
+
|
| 191 |
+
- 1,000 held-out queries
|
| 192 |
+
- all unique positives from those rows
|
| 193 |
+
- all unique first hard negatives from those rows
|
| 194 |
+
- 1,912 unique candidate chunks total
|
| 195 |
+
|
| 196 |
+
For each query, the known positive chunk is the only labeled relevant document. The model ranks all 1,912 candidate chunks by sparse dot product. BM25 ranks the same candidate corpus with a local lexical implementation.
|
| 197 |
+
|
| 198 |
+
Metrics:
|
| 199 |
+
|
| 200 |
+
- `Recall@1`
|
| 201 |
+
- `Recall@5`
|
| 202 |
+
- `Recall@10`
|
| 203 |
+
- `Recall@20`
|
| 204 |
+
- `MRR@10`
|
| 205 |
+
- `nDCG@10`
|
| 206 |
+
- median rank
|
| 207 |
+
|
| 208 |
+
| model | pruning | Recall@1 | Recall@5 | Recall@10 | Recall@20 | MRR@10 | nDCG@10 | median rank |
|
| 209 |
+
|---|---|---:|---:|---:|---:|---:|---:|---:|
|
| 210 |
+
| fine-tuned sparse | unpruned | 39.0% | 58.9% | 67.5% | 75.0% | 0.479 | 0.526 | 3 |
|
| 211 |
+
| fine-tuned sparse | top-128 | 38.6% | 57.8% | 67.2% | 73.9% | 0.473 | 0.521 | 3 |
|
| 212 |
+
| fine-tuned sparse | top-64 | 35.0% | 55.7% | 64.8% | 72.4% | 0.442 | 0.491 | 4 |
|
| 213 |
+
| base sparse | unpruned | 32.1% | 56.5% | 63.7% | 69.9% | 0.431 | 0.481 | 3 |
|
| 214 |
+
| base sparse | top-128 | 31.4% | 54.4% | 62.7% | 69.7% | 0.422 | 0.472 | 3 |
|
| 215 |
+
| base sparse | top-64 | 29.5% | 52.6% | 59.5% | 66.5% | 0.396 | 0.444 | 4 |
|
| 216 |
+
| BM25 | lexical | 24.0% | 58.2% | 64.1% | 68.6% | 0.397 | 0.457 | 3 |
|
| 217 |
+
|
| 218 |
+
Interpretation:
|
| 219 |
+
|
| 220 |
+
- Top-128 fine-tuned sparse is the best current deployment candidate.
|
| 221 |
+
- BM25 is competitive at Recall@10, but worse at early precision and ranking quality.
|
| 222 |
+
- Fine-tuning mostly improved early ranking: Recall@1 rose from 32.1% for the base sparse model to 39.0% unpruned, and 38.6% at top-128.
|
| 223 |
+
- Top-64 is usable only if index size or latency dominates quality.
|
| 224 |
+
|
| 225 |
+
## Recommended serving configuration
|
| 226 |
+
|
| 227 |
+
Use document top-k pruning:
|
| 228 |
+
|
| 229 |
+
```text
|
| 230 |
+
document_top_k = 128
|
| 231 |
+
```
|
| 232 |
+
|
| 233 |
+
Rationale:
|
| 234 |
+
|
| 235 |
+
| setting | Recall@10 | nDCG@10 | doc active dims |
|
| 236 |
+
|---|---:|---:|---:|
|
| 237 |
+
| unpruned | 67.5% | 0.526 | 319.8 |
|
| 238 |
+
| top-128 | 67.2% | 0.521 | 126.5 |
|
| 239 |
+
| top-64 | 64.8% | 0.491 | 64.0 |
|
| 240 |
+
|
| 241 |
+
Top-128 gives almost the same retrieval quality as unpruned with a much smaller sparse index footprint.
|
| 242 |
+
|
| 243 |
+
## Usage
|
| 244 |
+
|
| 245 |
+
```python
|
| 246 |
+
from sentence_transformers.sparse_encoder import SparseEncoder
|
| 247 |
+
|
| 248 |
+
model = SparseEncoder("oneryalcin/financial-filings-sparse-encoder-v1")
|
| 249 |
+
|
| 250 |
+
query_vectors = model.encode_query([
|
| 251 |
+
"What does the company say about liquidity risk?"
|
| 252 |
+
])
|
| 253 |
+
|
| 254 |
+
document_vectors = model.encode_document([
|
| 255 |
+
"The company discusses liquidity and capital resources in the MD&A section..."
|
| 256 |
+
])
|
| 257 |
+
```
|
| 258 |
+
|
| 259 |
+
For production-style sparse indexing, keep the highest-weighted 128 dimensions per document vector before indexing.
|
| 260 |
+
|
| 261 |
+
## Reproduction
|
| 262 |
+
|
| 263 |
+
The local runs used `uv` with a package-date cutoff:
|
| 264 |
+
|
| 265 |
+
```bash
|
| 266 |
+
uv run --exclude-newer=2026-05-13T00:00:00Z scripts/train_fin_sparse_encoder_v2.py \
|
| 267 |
+
--train-size 20000 \
|
| 268 |
+
--eval-size 1000 \
|
| 269 |
+
--max-steps 1500 \
|
| 270 |
+
--batch-size 8 \
|
| 271 |
+
--eval-batch-size 16 \
|
| 272 |
+
--query-reg 1e-4 \
|
| 273 |
+
--doc-reg 1e-2 \
|
| 274 |
+
--run-name fin-sparse-encoder-doc-v2-reg1e-2-1500
|
| 275 |
+
```
|
| 276 |
+
|
| 277 |
+
Triplet top-k evaluation:
|
| 278 |
+
|
| 279 |
+
```bash
|
| 280 |
+
uv run --exclude-newer=2026-05-13T00:00:00Z scripts/eval_fin_sparse_topk.py \
|
| 281 |
+
--eval-size 1000 \
|
| 282 |
+
--batch-size 16 \
|
| 283 |
+
--topks none,128,64 \
|
| 284 |
+
--models ./
|
| 285 |
+
```
|
| 286 |
+
|
| 287 |
+
Retrieval proxy evaluation:
|
| 288 |
+
|
| 289 |
+
```bash
|
| 290 |
+
uv run --exclude-newer=2026-05-13T00:00:00Z scripts/eval_fin_sparse_retrieval_proxy.py \
|
| 291 |
+
--model ./ \
|
| 292 |
+
--eval-size 1000 \
|
| 293 |
+
--batch-size 16 \
|
| 294 |
+
--topks none,128,64
|
| 295 |
+
```
|
| 296 |
+
|
| 297 |
+
BM25 baseline:
|
| 298 |
+
|
| 299 |
+
```bash
|
| 300 |
+
uv run --exclude-newer=2026-05-13T00:00:00Z scripts/eval_fin_bm25_retrieval_proxy.py \
|
| 301 |
+
--eval-size 1000
|
| 302 |
+
```
|
| 303 |
+
|
| 304 |
+
## Files in this repo
|
| 305 |
+
|
| 306 |
+
| path | purpose |
|
| 307 |
+
|---|---|
|
| 308 |
+
| `scripts/train_fin_sparse_encoder_v2.py` | Training script used for the final run. |
|
| 309 |
+
| `scripts/eval_fin_sparse_topk.py` | Pairwise triplet accuracy with document top-k pruning. |
|
| 310 |
+
| `scripts/eval_fin_sparse_retrieval_proxy.py` | In-memory retrieval proxy over held-out positives and negatives. |
|
| 311 |
+
| `scripts/eval_fin_bm25_retrieval_proxy.py` | Local BM25 baseline over the same held-out candidate pool. |
|
| 312 |
+
| `benchmarks/regularization_sweep_500_steps.json` | 500-step document-regularization sweep. |
|
| 313 |
+
| `benchmarks/triplet_topk_1000.json` | Final triplet/top-k results. |
|
| 314 |
+
| `benchmarks/retrieval_proxy_1000.json` | Final retrieval proxy and baseline results. |
|
| 315 |
+
|
| 316 |
+
## Limitations
|
| 317 |
+
|
| 318 |
+
These results are promising but not final production proof.
|
| 319 |
+
|
| 320 |
+
- The retrieval benchmark is a proxy, not a full production OpenSearch/Elasticsearch benchmark.
|
| 321 |
+
- Each query has only one labeled positive, so other genuinely relevant chunks may be counted as false competitors.
|
| 322 |
+
- The candidate pool has 1,912 chunks, not millions.
|
| 323 |
+
- Evaluation used the first 1,000 usable held-out examples, not the full test split.
|
| 324 |
+
- BM25 is a local baseline, not a tuned OpenSearch BM25 setup.
|
| 325 |
+
- End-to-end index size, latency, shard behavior, and hybrid retrieval have not been measured yet.
|
| 326 |
+
|
| 327 |
+
## Next steps
|
| 328 |
+
|
| 329 |
+
- Run the retrieval proxy on the full held-out test split.
|
| 330 |
+
- Run a real OpenSearch/Elasticsearch index benchmark with top-128 sparse vectors.
|
| 331 |
+
- Test hybrid BM25 + learned sparse retrieval.
|
| 332 |
+
- Add query-type, company, and filing-type slice analysis.
|
| 333 |
+
- Add human-labeled financial filing retrieval judgments.
|
benchmarks/regularization_sweep_500_steps.json
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"dataset": "oneryalcin/financial-filings-sparse-retrieval-training",
|
| 3 |
+
"config": "combined",
|
| 4 |
+
"split": "test",
|
| 5 |
+
"rows": 1000,
|
| 6 |
+
"base_model": "opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill",
|
| 7 |
+
"constant_settings": {
|
| 8 |
+
"steps": 500,
|
| 9 |
+
"train_size_requested": 20000,
|
| 10 |
+
"usable_train_rows": 18247,
|
| 11 |
+
"batch_size": 8,
|
| 12 |
+
"query_regularizer_weight": 0.0001,
|
| 13 |
+
"loss": "SpladeLoss(SparseMultipleNegativesRankingLoss)",
|
| 14 |
+
"sampler": "BatchSamplers.NO_DUPLICATES"
|
| 15 |
+
},
|
| 16 |
+
"results": [
|
| 17 |
+
{"document_regularizer_weight": 0.00008, "accuracy": 0.738, "positive_doc_active_dims": 4204.911, "negative_doc_active_dims": 3784.729, "note": "High score but too dense for practical sparse indexing."},
|
| 18 |
+
{"document_regularizer_weight": 0.00015, "accuracy": 0.748, "positive_doc_active_dims": 3214.797, "negative_doc_active_dims": 2956.645, "note": "Best raw 500-step triplet accuracy but still too dense."},
|
| 19 |
+
{"document_regularizer_weight": 0.0003, "accuracy": 0.734, "positive_doc_active_dims": 2101.448, "negative_doc_active_dims": 1875.179},
|
| 20 |
+
{"document_regularizer_weight": 0.001, "accuracy": 0.727, "positive_doc_active_dims": 1340.540, "negative_doc_active_dims": 1136.954},
|
| 21 |
+
{"document_regularizer_weight": 0.003, "accuracy": 0.734, "positive_doc_active_dims": 657.067, "negative_doc_active_dims": 563.964},
|
| 22 |
+
{"document_regularizer_weight": 0.01, "accuracy": 0.732, "positive_doc_active_dims": 296.654, "negative_doc_active_dims": 261.160, "note": "Chosen for longer run because quality stayed strong while native vectors became much more indexable."},
|
| 23 |
+
{"document_regularizer_weight": 0.03, "accuracy": 0.723, "positive_doc_active_dims": 160.887, "negative_doc_active_dims": 143.986, "note": "Compact but lower quality than 0.01."}
|
| 24 |
+
]
|
| 25 |
+
}
|
benchmarks/retrieval_proxy_1000.json
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"dataset": "oneryalcin/financial-filings-sparse-retrieval-training",
|
| 3 |
+
"config": "combined",
|
| 4 |
+
"split": "test",
|
| 5 |
+
"rows": 1000,
|
| 6 |
+
"corpus_size": 1912,
|
| 7 |
+
"method": "Build an in-memory candidate corpus from all unique positives and first hard negatives in the 1000 held-out rows. For each query, rank all 1912 candidate chunks by sparse dot product or BM25 score. Only the known positive is labeled relevant.",
|
| 8 |
+
"results": [
|
| 9 |
+
{
|
| 10 |
+
"model": "oneryalcin/financial-filings-sparse-encoder-v1",
|
| 11 |
+
"doc_topk": null,
|
| 12 |
+
"doc_active_dims": 319.79864501953125,
|
| 13 |
+
"query_active_dims": 15.227999687194824,
|
| 14 |
+
"recall_at_1": 0.38999998569488525,
|
| 15 |
+
"recall_at_5": 0.5889999866485596,
|
| 16 |
+
"recall_at_10": 0.675000011920929,
|
| 17 |
+
"recall_at_20": 0.75,
|
| 18 |
+
"mrr_at_10": 0.479160338640213,
|
| 19 |
+
"ndcg_at_10": 0.5258780717849731,
|
| 20 |
+
"mean_rank": 51.290000915527344,
|
| 21 |
+
"median_rank": 3.0
|
| 22 |
+
},
|
| 23 |
+
{
|
| 24 |
+
"model": "oneryalcin/financial-filings-sparse-encoder-v1",
|
| 25 |
+
"doc_topk": 128,
|
| 26 |
+
"doc_active_dims": 126.48535919189453,
|
| 27 |
+
"query_active_dims": 15.227999687194824,
|
| 28 |
+
"recall_at_1": 0.38600000739097595,
|
| 29 |
+
"recall_at_5": 0.578000009059906,
|
| 30 |
+
"recall_at_10": 0.671999990940094,
|
| 31 |
+
"recall_at_20": 0.7390000224113464,
|
| 32 |
+
"mrr_at_10": 0.4734043478965759,
|
| 33 |
+
"ndcg_at_10": 0.5206189155578613,
|
| 34 |
+
"mean_rank": 52.779998779296875,
|
| 35 |
+
"median_rank": 3.0
|
| 36 |
+
},
|
| 37 |
+
{
|
| 38 |
+
"model": "oneryalcin/financial-filings-sparse-encoder-v1",
|
| 39 |
+
"doc_topk": 64,
|
| 40 |
+
"doc_active_dims": 63.98692321777344,
|
| 41 |
+
"query_active_dims": 15.227999687194824,
|
| 42 |
+
"recall_at_1": 0.3499999940395355,
|
| 43 |
+
"recall_at_5": 0.5569999814033508,
|
| 44 |
+
"recall_at_10": 0.6480000019073486,
|
| 45 |
+
"recall_at_20": 0.7239999771118164,
|
| 46 |
+
"mrr_at_10": 0.4424420893192291,
|
| 47 |
+
"ndcg_at_10": 0.4913859963417053,
|
| 48 |
+
"mean_rank": 56.25199890136719,
|
| 49 |
+
"median_rank": 4.0
|
| 50 |
+
},
|
| 51 |
+
{
|
| 52 |
+
"model": "opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill",
|
| 53 |
+
"doc_topk": null,
|
| 54 |
+
"doc_active_dims": 377.7003173828125,
|
| 55 |
+
"query_active_dims": 15.227999687194824,
|
| 56 |
+
"recall_at_1": 0.32100000977516174,
|
| 57 |
+
"recall_at_5": 0.5649999976158142,
|
| 58 |
+
"recall_at_10": 0.6370000243186951,
|
| 59 |
+
"recall_at_20": 0.6990000009536743,
|
| 60 |
+
"mrr_at_10": 0.4311789870262146,
|
| 61 |
+
"ndcg_at_10": 0.4811413586139679,
|
| 62 |
+
"mean_rank": 98.5719985961914,
|
| 63 |
+
"median_rank": 3.0
|
| 64 |
+
},
|
| 65 |
+
{
|
| 66 |
+
"model": "opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill",
|
| 67 |
+
"doc_topk": 128,
|
| 68 |
+
"doc_active_dims": 127.96809387207031,
|
| 69 |
+
"query_active_dims": 15.227999687194824,
|
| 70 |
+
"recall_at_1": 0.3140000104904175,
|
| 71 |
+
"recall_at_5": 0.5440000295639038,
|
| 72 |
+
"recall_at_10": 0.6269999742507935,
|
| 73 |
+
"recall_at_20": 0.6970000267028809,
|
| 74 |
+
"mrr_at_10": 0.4224868714809418,
|
| 75 |
+
"ndcg_at_10": 0.4720706045627594,
|
| 76 |
+
"mean_rank": 95.6719970703125,
|
| 77 |
+
"median_rank": 3.0
|
| 78 |
+
},
|
| 79 |
+
{
|
| 80 |
+
"model": "opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill",
|
| 81 |
+
"doc_topk": 64,
|
| 82 |
+
"doc_active_dims": 64.0,
|
| 83 |
+
"query_active_dims": 15.227999687194824,
|
| 84 |
+
"recall_at_1": 0.29499998688697815,
|
| 85 |
+
"recall_at_5": 0.5260000228881836,
|
| 86 |
+
"recall_at_10": 0.5950000286102295,
|
| 87 |
+
"recall_at_20": 0.6650000214576721,
|
| 88 |
+
"mrr_at_10": 0.3963717818260193,
|
| 89 |
+
"ndcg_at_10": 0.44444921612739563,
|
| 90 |
+
"mean_rank": 96.0739974975586,
|
| 91 |
+
"median_rank": 4.0
|
| 92 |
+
},
|
| 93 |
+
{
|
| 94 |
+
"model": "bm25",
|
| 95 |
+
"recall_at_1": 0.24,
|
| 96 |
+
"recall_at_5": 0.582,
|
| 97 |
+
"recall_at_10": 0.641,
|
| 98 |
+
"recall_at_20": 0.686,
|
| 99 |
+
"mrr_at_10": 0.3968968253968254,
|
| 100 |
+
"ndcg_at_10": 0.45737284741632095,
|
| 101 |
+
"mean_rank": 158.692,
|
| 102 |
+
"median_rank": 3,
|
| 103 |
+
"k1": 1.2,
|
| 104 |
+
"b": 0.75,
|
| 105 |
+
"avg_doc_len": 213.8655857740586,
|
| 106 |
+
"median_doc_len": 222
|
| 107 |
+
}
|
| 108 |
+
]
|
| 109 |
+
}
|
benchmarks/triplet_topk_1000.json
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"dataset": "oneryalcin/financial-filings-sparse-retrieval-training",
|
| 3 |
+
"config": "combined",
|
| 4 |
+
"split": "test",
|
| 5 |
+
"rows": 1000,
|
| 6 |
+
"method": "For each held-out row, score(query, positive) and score(query, first hard negative) with sparse dot product. Accuracy is the fraction where positive scores higher.",
|
| 7 |
+
"results": [
|
| 8 |
+
{
|
| 9 |
+
"model": "opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill",
|
| 10 |
+
"doc_topk": null,
|
| 11 |
+
"accuracy": 0.4960,
|
| 12 |
+
"mean_margin": 0.4508,
|
| 13 |
+
"query_active_dims": 15.228,
|
| 14 |
+
"positive_doc_active_dims": 371.084,
|
| 15 |
+
"negative_doc_active_dims": 354.758
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"model": "oneryalcin/financial-filings-sparse-encoder-v1",
|
| 19 |
+
"doc_topk": null,
|
| 20 |
+
"accuracy": 0.781000018119812,
|
| 21 |
+
"mean_margin": 3.0043423175811768,
|
| 22 |
+
"median_margin": 2.4458980560302734,
|
| 23 |
+
"query_active_dims": 15.227999687194824,
|
| 24 |
+
"positive_doc_active_dims": 340.94500732421875,
|
| 25 |
+
"negative_doc_active_dims": 298.7919921875
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"model": "oneryalcin/financial-filings-sparse-encoder-v1",
|
| 29 |
+
"doc_topk": 128,
|
| 30 |
+
"accuracy": 0.7799999713897705,
|
| 31 |
+
"mean_margin": 2.9922006130218506,
|
| 32 |
+
"median_margin": 2.4021129608154297,
|
| 33 |
+
"query_active_dims": 15.227999687194824,
|
| 34 |
+
"positive_doc_active_dims": 126.6719970703125,
|
| 35 |
+
"negative_doc_active_dims": 126.36499786376953
|
| 36 |
+
},
|
| 37 |
+
{
|
| 38 |
+
"model": "oneryalcin/financial-filings-sparse-encoder-v1",
|
| 39 |
+
"doc_topk": 64,
|
| 40 |
+
"accuracy": 0.7590000033378601,
|
| 41 |
+
"mean_margin": 2.988410234451294,
|
| 42 |
+
"median_margin": 2.3834714889526367,
|
| 43 |
+
"query_active_dims": 15.227999687194824,
|
| 44 |
+
"positive_doc_active_dims": 63.986000061035156,
|
| 45 |
+
"negative_doc_active_dims": 63.98899841308594
|
| 46 |
+
}
|
| 47 |
+
]
|
| 48 |
+
}
|
config_sentence_transformers.json
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"__version__": {
|
| 3 |
+
"pytorch": "2.11.0",
|
| 4 |
+
"sentence_transformers": "5.5.0",
|
| 5 |
+
"transformers": "5.8.0"
|
| 6 |
+
},
|
| 7 |
+
"default_prompt_name": null,
|
| 8 |
+
"model_type": "SparseEncoder",
|
| 9 |
+
"prompts": {
|
| 10 |
+
"document": "",
|
| 11 |
+
"query": ""
|
| 12 |
+
},
|
| 13 |
+
"similarity_fn_name": "dot"
|
| 14 |
+
}
|
document_0_MLMTransformer/config.json
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"activation": "gelu",
|
| 3 |
+
"architectures": [
|
| 4 |
+
"DistilBertForMaskedLM"
|
| 5 |
+
],
|
| 6 |
+
"attention_dropout": 0.1,
|
| 7 |
+
"bos_token_id": null,
|
| 8 |
+
"dim": 768,
|
| 9 |
+
"dropout": 0.1,
|
| 10 |
+
"dtype": "float32",
|
| 11 |
+
"eos_token_id": null,
|
| 12 |
+
"hidden_dim": 3072,
|
| 13 |
+
"initializer_range": 0.02,
|
| 14 |
+
"max_position_embeddings": 512,
|
| 15 |
+
"model_type": "distilbert",
|
| 16 |
+
"n_heads": 12,
|
| 17 |
+
"n_layers": 6,
|
| 18 |
+
"pad_token_id": 0,
|
| 19 |
+
"qa_dropout": 0.1,
|
| 20 |
+
"seq_classif_dropout": 0.2,
|
| 21 |
+
"sinusoidal_pos_embds": false,
|
| 22 |
+
"tie_weights_": true,
|
| 23 |
+
"tie_word_embeddings": true,
|
| 24 |
+
"transformers_version": "5.8.0",
|
| 25 |
+
"use_cache": false,
|
| 26 |
+
"vocab_size": 30522
|
| 27 |
+
}
|
document_0_MLMTransformer/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:74310f44e4ea4067d85d4ab39161a1eed74b96ead822dea488bdbda9d034a304
|
| 3 |
+
size 267954768
|
document_0_MLMTransformer/sentence_bert_config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"transformer_task": "fill-mask",
|
| 3 |
+
"modality_config": {
|
| 4 |
+
"text": {
|
| 5 |
+
"method": "forward",
|
| 6 |
+
"method_output_name": "logits"
|
| 7 |
+
}
|
| 8 |
+
},
|
| 9 |
+
"module_output_name": "token_embeddings"
|
| 10 |
+
}
|
document_0_MLMTransformer/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
document_0_MLMTransformer/tokenizer_config.json
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"backend": "tokenizers",
|
| 3 |
+
"clean_up_tokenization_spaces": true,
|
| 4 |
+
"cls_token": "[CLS]",
|
| 5 |
+
"do_lower_case": true,
|
| 6 |
+
"is_local": false,
|
| 7 |
+
"local_files_only": false,
|
| 8 |
+
"mask_token": "[MASK]",
|
| 9 |
+
"model_max_length": 384,
|
| 10 |
+
"pad_token": "[PAD]",
|
| 11 |
+
"sep_token": "[SEP]",
|
| 12 |
+
"strip_accents": null,
|
| 13 |
+
"tokenize_chinese_chars": true,
|
| 14 |
+
"tokenizer_class": "DistilBertTokenizer",
|
| 15 |
+
"unk_token": "[UNK]"
|
| 16 |
+
}
|
document_1_SpladePooling/config.json
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"pooling_strategy": "max",
|
| 3 |
+
"activation_function": "relu",
|
| 4 |
+
"embedding_dimension": 30522
|
| 5 |
+
}
|
modules.json
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"idx": 0,
|
| 4 |
+
"name": "0",
|
| 5 |
+
"path": "",
|
| 6 |
+
"type": "sentence_transformers.base.modules.router.Router"
|
| 7 |
+
}
|
| 8 |
+
]
|
query_0_SparseStaticEmbedding/config.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"frozen": true
|
| 3 |
+
}
|
query_0_SparseStaticEmbedding/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:711ec64837a7962d2ae106996079782b7ee87860089a0b2348bf7cb840f252d3
|
| 3 |
+
size 122168
|
query_0_SparseStaticEmbedding/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
query_0_SparseStaticEmbedding/tokenizer_config.json
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"backend": "tokenizers",
|
| 3 |
+
"clean_up_tokenization_spaces": true,
|
| 4 |
+
"cls_token": "[CLS]",
|
| 5 |
+
"do_lower_case": true,
|
| 6 |
+
"is_local": false,
|
| 7 |
+
"local_files_only": false,
|
| 8 |
+
"mask_token": "[MASK]",
|
| 9 |
+
"model_max_length": 512,
|
| 10 |
+
"pad_token": "[PAD]",
|
| 11 |
+
"sep_token": "[SEP]",
|
| 12 |
+
"strip_accents": null,
|
| 13 |
+
"tokenize_chinese_chars": true,
|
| 14 |
+
"tokenizer_class": "DistilBertTokenizer",
|
| 15 |
+
"unk_token": "[UNK]"
|
| 16 |
+
}
|
router_config.json
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"types": {
|
| 3 |
+
"query_0_SparseStaticEmbedding": "sentence_transformers.sparse_encoder.modules.sparse_static_embedding.SparseStaticEmbedding",
|
| 4 |
+
"document_0_MLMTransformer": "sentence_transformers.sparse_encoder.modules.mlm_transformer.MLMTransformer",
|
| 5 |
+
"document_1_SpladePooling": "sentence_transformers.sparse_encoder.modules.splade_pooling.SpladePooling"
|
| 6 |
+
},
|
| 7 |
+
"structure": {
|
| 8 |
+
"query": [
|
| 9 |
+
"query_0_SparseStaticEmbedding"
|
| 10 |
+
],
|
| 11 |
+
"document": [
|
| 12 |
+
"document_0_MLMTransformer",
|
| 13 |
+
"document_1_SpladePooling"
|
| 14 |
+
]
|
| 15 |
+
},
|
| 16 |
+
"parameters": {
|
| 17 |
+
"default_route": "document",
|
| 18 |
+
"allow_empty_key": true,
|
| 19 |
+
"route_mappings": {}
|
| 20 |
+
}
|
| 21 |
+
}
|
scripts/eval_fin_bm25_retrieval_proxy.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# /// script
|
| 3 |
+
# requires-python = ">=3.11"
|
| 4 |
+
# dependencies = [
|
| 5 |
+
# "datasets",
|
| 6 |
+
# ]
|
| 7 |
+
# ///
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import argparse
|
| 11 |
+
import json
|
| 12 |
+
import logging
|
| 13 |
+
import math
|
| 14 |
+
import re
|
| 15 |
+
from collections import Counter, OrderedDict, defaultdict
|
| 16 |
+
|
| 17 |
+
from datasets import load_dataset
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
|
| 21 |
+
TOKEN_RE = re.compile(r"[a-z0-9]+")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def tokenize(text: str) -> list[str]:
|
| 25 |
+
return TOKEN_RE.findall(text.lower())
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def first_negative(row: dict) -> str | None:
|
| 29 |
+
negatives = row.get("negatives")
|
| 30 |
+
if isinstance(negatives, list):
|
| 31 |
+
for negative in negatives:
|
| 32 |
+
if isinstance(negative, str) and negative.strip():
|
| 33 |
+
return negative
|
| 34 |
+
if isinstance(negatives, str) and negatives.strip():
|
| 35 |
+
return negatives
|
| 36 |
+
return None
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def load_rows(limit: int) -> list[dict]:
|
| 40 |
+
dataset = load_dataset(
|
| 41 |
+
"oneryalcin/financial-filings-sparse-retrieval-training",
|
| 42 |
+
"combined",
|
| 43 |
+
split="test",
|
| 44 |
+
)
|
| 45 |
+
rows: list[dict] = []
|
| 46 |
+
for row in dataset:
|
| 47 |
+
query = row.get("query")
|
| 48 |
+
positive = row.get("positive")
|
| 49 |
+
negative = first_negative(row)
|
| 50 |
+
if query and positive and negative:
|
| 51 |
+
rows.append({"query": query, "positive": positive, "negative": negative})
|
| 52 |
+
if len(rows) >= limit:
|
| 53 |
+
break
|
| 54 |
+
return rows
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def dedupe_texts(rows: list[dict]) -> tuple[list[str], list[int]]:
|
| 58 |
+
corpus = OrderedDict()
|
| 59 |
+
positive_ids: list[int] = []
|
| 60 |
+
for row in rows:
|
| 61 |
+
for key in ("positive", "negative"):
|
| 62 |
+
text = row[key]
|
| 63 |
+
if text not in corpus:
|
| 64 |
+
corpus[text] = len(corpus)
|
| 65 |
+
positive_ids.append(corpus[row["positive"]])
|
| 66 |
+
return list(corpus.keys()), positive_ids
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def build_bm25(corpus: list[str], k1: float, b: float) -> tuple[dict[str, list[tuple[int, float]]], list[int], float]:
|
| 70 |
+
doc_lens: list[int] = []
|
| 71 |
+
doc_counts: list[Counter[str]] = []
|
| 72 |
+
document_frequency: Counter[str] = Counter()
|
| 73 |
+
|
| 74 |
+
for text in corpus:
|
| 75 |
+
counts = Counter(tokenize(text))
|
| 76 |
+
doc_counts.append(counts)
|
| 77 |
+
doc_lens.append(sum(counts.values()))
|
| 78 |
+
document_frequency.update(counts.keys())
|
| 79 |
+
|
| 80 |
+
corpus_size = len(corpus)
|
| 81 |
+
avg_doc_len = sum(doc_lens) / max(corpus_size, 1)
|
| 82 |
+
inverted: dict[str, list[tuple[int, float]]] = defaultdict(list)
|
| 83 |
+
|
| 84 |
+
for doc_id, counts in enumerate(doc_counts):
|
| 85 |
+
norm = k1 * (1 - b + b * doc_lens[doc_id] / avg_doc_len)
|
| 86 |
+
for term, tf in counts.items():
|
| 87 |
+
df = document_frequency[term]
|
| 88 |
+
idf = math.log(1 + (corpus_size - df + 0.5) / (df + 0.5))
|
| 89 |
+
score = idf * (tf * (k1 + 1)) / (tf + norm)
|
| 90 |
+
inverted[term].append((doc_id, score))
|
| 91 |
+
|
| 92 |
+
return inverted, doc_lens, avg_doc_len
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def rank_query(query: str, inverted: dict[str, list[tuple[int, float]]], corpus_size: int) -> list[float]:
|
| 96 |
+
scores = [0.0] * corpus_size
|
| 97 |
+
for term, qtf in Counter(tokenize(query)).items():
|
| 98 |
+
for doc_id, score in inverted.get(term, ()):
|
| 99 |
+
scores[doc_id] += score * qtf
|
| 100 |
+
return scores
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def retrieval_metrics(all_scores: list[list[float]], positive_ids: list[int]) -> dict:
|
| 104 |
+
ranks: list[int] = []
|
| 105 |
+
for scores, positive_id in zip(all_scores, positive_ids):
|
| 106 |
+
positive_score = scores[positive_id]
|
| 107 |
+
rank = 1 + sum(score > positive_score for score in scores)
|
| 108 |
+
ranks.append(rank)
|
| 109 |
+
|
| 110 |
+
ranks_sorted = sorted(ranks)
|
| 111 |
+
count = len(ranks)
|
| 112 |
+
mrr10 = sum((1.0 / rank) if rank <= 10 else 0.0 for rank in ranks) / count
|
| 113 |
+
ndcg10 = sum((1.0 / math.log2(rank + 1)) if rank <= 10 else 0.0 for rank in ranks) / count
|
| 114 |
+
return {
|
| 115 |
+
"recall_at_1": sum(rank <= 1 for rank in ranks) / count,
|
| 116 |
+
"recall_at_5": sum(rank <= 5 for rank in ranks) / count,
|
| 117 |
+
"recall_at_10": sum(rank <= 10 for rank in ranks) / count,
|
| 118 |
+
"recall_at_20": sum(rank <= 20 for rank in ranks) / count,
|
| 119 |
+
"mrr_at_10": mrr10,
|
| 120 |
+
"ndcg_at_10": ndcg10,
|
| 121 |
+
"mean_rank": sum(ranks) / count,
|
| 122 |
+
"median_rank": ranks_sorted[count // 2],
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def main() -> None:
|
| 127 |
+
parser = argparse.ArgumentParser()
|
| 128 |
+
parser.add_argument("--eval-size", type=int, default=1000)
|
| 129 |
+
parser.add_argument("--k1", type=float, default=1.2)
|
| 130 |
+
parser.add_argument("--b", type=float, default=0.75)
|
| 131 |
+
args = parser.parse_args()
|
| 132 |
+
|
| 133 |
+
rows = load_rows(args.eval_size)
|
| 134 |
+
queries = [row["query"] for row in rows]
|
| 135 |
+
corpus, positive_ids = dedupe_texts(rows)
|
| 136 |
+
logging.info("rows=%d corpus=%d k1=%s b=%s", len(rows), len(corpus), args.k1, args.b)
|
| 137 |
+
|
| 138 |
+
inverted, doc_lens, avg_doc_len = build_bm25(corpus, args.k1, args.b)
|
| 139 |
+
scores = [rank_query(query, inverted, len(corpus)) for query in queries]
|
| 140 |
+
result = {
|
| 141 |
+
"model": "bm25",
|
| 142 |
+
"rows": len(rows),
|
| 143 |
+
"corpus_size": len(corpus),
|
| 144 |
+
"avg_doc_len": avg_doc_len,
|
| 145 |
+
"median_doc_len": sorted(doc_lens)[len(doc_lens) // 2],
|
| 146 |
+
"k1": args.k1,
|
| 147 |
+
"b": args.b,
|
| 148 |
+
**retrieval_metrics(scores, positive_ids),
|
| 149 |
+
}
|
| 150 |
+
logging.info("BM25_RESULT %s", json.dumps(result, sort_keys=True))
|
| 151 |
+
print(json.dumps(result, sort_keys=True))
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
if __name__ == "__main__":
|
| 155 |
+
main()
|
scripts/eval_fin_sparse_retrieval_proxy.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# /// script
|
| 3 |
+
# requires-python = ">=3.11"
|
| 4 |
+
# dependencies = [
|
| 5 |
+
# "datasets",
|
| 6 |
+
# "sentence-transformers==5.5.0",
|
| 7 |
+
# "torch",
|
| 8 |
+
# ]
|
| 9 |
+
# ///
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import argparse
|
| 13 |
+
import json
|
| 14 |
+
import logging
|
| 15 |
+
from collections import OrderedDict
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
from datasets import load_dataset
|
| 19 |
+
from sentence_transformers.sparse_encoder import SparseEncoder
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def parse_topks(value: str) -> list[int | None]:
|
| 26 |
+
out: list[int | None] = []
|
| 27 |
+
for item in value.split(","):
|
| 28 |
+
item = item.strip().lower()
|
| 29 |
+
if item in {"none", "all", "null"}:
|
| 30 |
+
out.append(None)
|
| 31 |
+
else:
|
| 32 |
+
out.append(int(item))
|
| 33 |
+
return out
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def first_negative(row: dict) -> str | None:
|
| 37 |
+
negatives = row.get("negatives")
|
| 38 |
+
if isinstance(negatives, list):
|
| 39 |
+
for negative in negatives:
|
| 40 |
+
if isinstance(negative, str) and negative.strip():
|
| 41 |
+
return negative
|
| 42 |
+
if isinstance(negatives, str) and negatives.strip():
|
| 43 |
+
return negatives
|
| 44 |
+
return None
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def load_rows(limit: int) -> list[dict]:
|
| 48 |
+
dataset = load_dataset(
|
| 49 |
+
"oneryalcin/financial-filings-sparse-retrieval-training",
|
| 50 |
+
"combined",
|
| 51 |
+
split="test",
|
| 52 |
+
)
|
| 53 |
+
rows: list[dict] = []
|
| 54 |
+
for row in dataset:
|
| 55 |
+
query = row.get("query")
|
| 56 |
+
positive = row.get("positive")
|
| 57 |
+
negative = first_negative(row)
|
| 58 |
+
if query and positive and negative:
|
| 59 |
+
rows.append({"query": query, "positive": positive, "negative": negative})
|
| 60 |
+
if len(rows) >= limit:
|
| 61 |
+
break
|
| 62 |
+
return rows
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def dedupe_texts(rows: list[dict]) -> tuple[list[str], list[int], list[int]]:
|
| 66 |
+
corpus = OrderedDict()
|
| 67 |
+
positive_ids: list[int] = []
|
| 68 |
+
negative_ids: list[int] = []
|
| 69 |
+
for row in rows:
|
| 70 |
+
for key in ("positive", "negative"):
|
| 71 |
+
text = row[key]
|
| 72 |
+
if text not in corpus:
|
| 73 |
+
corpus[text] = len(corpus)
|
| 74 |
+
positive_ids.append(corpus[row["positive"]])
|
| 75 |
+
negative_ids.append(corpus[row["negative"]])
|
| 76 |
+
return list(corpus.keys()), positive_ids, negative_ids
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def densify(encoded) -> torch.Tensor:
|
| 80 |
+
if isinstance(encoded, torch.Tensor):
|
| 81 |
+
tensor = encoded
|
| 82 |
+
if tensor.layout != torch.strided:
|
| 83 |
+
tensor = tensor.to_dense()
|
| 84 |
+
elif hasattr(encoded, "to_dense"):
|
| 85 |
+
tensor = encoded.to_dense()
|
| 86 |
+
else:
|
| 87 |
+
tensor = torch.as_tensor(encoded)
|
| 88 |
+
return tensor.float().cpu()
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def active_dims(tensor: torch.Tensor) -> float:
|
| 92 |
+
return (tensor > 0).sum(dim=1).float().mean().item()
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def apply_doc_topk(docs: torch.Tensor, topk: int | None) -> torch.Tensor:
|
| 96 |
+
if topk is None or topk >= docs.shape[1]:
|
| 97 |
+
return docs
|
| 98 |
+
values, indices = torch.topk(docs, k=topk, dim=1)
|
| 99 |
+
pruned = torch.zeros_like(docs)
|
| 100 |
+
pruned.scatter_(1, indices, values)
|
| 101 |
+
return pruned
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def retrieval_metrics(scores: torch.Tensor, positive_ids: list[int]) -> dict:
|
| 105 |
+
positive = torch.tensor(positive_ids, dtype=torch.long)
|
| 106 |
+
positive_scores = scores[torch.arange(scores.shape[0]), positive]
|
| 107 |
+
ranks = (scores > positive_scores.unsqueeze(1)).sum(dim=1) + 1
|
| 108 |
+
ranks_f = ranks.float()
|
| 109 |
+
|
| 110 |
+
mrr10 = torch.where(ranks <= 10, 1.0 / ranks_f, torch.zeros_like(ranks_f)).mean().item()
|
| 111 |
+
ndcg10 = torch.where(
|
| 112 |
+
ranks <= 10,
|
| 113 |
+
1.0 / torch.log2(ranks_f + 1.0),
|
| 114 |
+
torch.zeros_like(ranks_f),
|
| 115 |
+
).mean().item()
|
| 116 |
+
return {
|
| 117 |
+
"recall_at_1": (ranks <= 1).float().mean().item(),
|
| 118 |
+
"recall_at_5": (ranks <= 5).float().mean().item(),
|
| 119 |
+
"recall_at_10": (ranks <= 10).float().mean().item(),
|
| 120 |
+
"recall_at_20": (ranks <= 20).float().mean().item(),
|
| 121 |
+
"mrr_at_10": mrr10,
|
| 122 |
+
"ndcg_at_10": ndcg10,
|
| 123 |
+
"mean_rank": ranks_f.mean().item(),
|
| 124 |
+
"median_rank": ranks_f.median().item(),
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def main() -> None:
|
| 129 |
+
parser = argparse.ArgumentParser()
|
| 130 |
+
parser.add_argument("--model", required=True)
|
| 131 |
+
parser.add_argument("--eval-size", type=int, default=1000)
|
| 132 |
+
parser.add_argument("--batch-size", type=int, default=16)
|
| 133 |
+
parser.add_argument("--topks", default="none,128,64")
|
| 134 |
+
args = parser.parse_args()
|
| 135 |
+
|
| 136 |
+
rows = load_rows(args.eval_size)
|
| 137 |
+
queries = [row["query"] for row in rows]
|
| 138 |
+
corpus, positive_ids, negative_ids = dedupe_texts(rows)
|
| 139 |
+
logging.info("rows=%d corpus=%d topks=%s", len(rows), len(corpus), args.topks)
|
| 140 |
+
|
| 141 |
+
model = SparseEncoder(args.model, device="cpu")
|
| 142 |
+
query_vectors = densify(
|
| 143 |
+
model.encode_query(queries, batch_size=args.batch_size, convert_to_tensor=True)
|
| 144 |
+
)
|
| 145 |
+
doc_vectors = densify(
|
| 146 |
+
model.encode_document(corpus, batch_size=args.batch_size, convert_to_tensor=True)
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
query_active = active_dims(query_vectors)
|
| 150 |
+
doc_active = active_dims(doc_vectors)
|
| 151 |
+
|
| 152 |
+
for topk in parse_topks(args.topks):
|
| 153 |
+
pruned_docs = apply_doc_topk(doc_vectors, topk)
|
| 154 |
+
scores = query_vectors @ pruned_docs.T
|
| 155 |
+
metrics = retrieval_metrics(scores, positive_ids)
|
| 156 |
+
result = {
|
| 157 |
+
"model": args.model,
|
| 158 |
+
"rows": len(rows),
|
| 159 |
+
"corpus_size": len(corpus),
|
| 160 |
+
"doc_topk": topk,
|
| 161 |
+
"query_active_dims": query_active,
|
| 162 |
+
"doc_active_dims": active_dims(pruned_docs),
|
| 163 |
+
**metrics,
|
| 164 |
+
}
|
| 165 |
+
logging.info("RETRIEVAL_RESULT %s", json.dumps(result, sort_keys=True))
|
| 166 |
+
print(json.dumps(result, sort_keys=True))
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
if __name__ == "__main__":
|
| 170 |
+
main()
|
scripts/eval_fin_sparse_topk.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# /// script
|
| 3 |
+
# dependencies = [
|
| 4 |
+
# "sentence-transformers[train]==5.5.0",
|
| 5 |
+
# "datasets",
|
| 6 |
+
# "torch",
|
| 7 |
+
# ]
|
| 8 |
+
# ///
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import json
|
| 13 |
+
import logging
|
| 14 |
+
from dataclasses import dataclass
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
from datasets import load_dataset
|
| 18 |
+
from sentence_transformers.sparse_encoder import SparseEncoder
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
DATASET = "oneryalcin/financial-filings-sparse-retrieval-training"
|
| 22 |
+
CONFIG = "combined"
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")
|
| 26 |
+
for noisy in ("httpx", "httpcore", "huggingface_hub", "urllib3", "filelock", "fsspec"):
|
| 27 |
+
logging.getLogger(noisy).setLevel(logging.WARNING)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass(frozen=True)
|
| 31 |
+
class EvalBatch:
|
| 32 |
+
queries: list[str]
|
| 33 |
+
positives: list[str]
|
| 34 |
+
negatives: list[str]
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def dense(tensor: torch.Tensor) -> torch.Tensor:
|
| 38 |
+
if tensor.is_sparse:
|
| 39 |
+
return tensor.to_dense()
|
| 40 |
+
return tensor
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def doc_topk(tensor: torch.Tensor, k: int | None) -> torch.Tensor:
|
| 44 |
+
if k is None or k <= 0 or k >= tensor.shape[1]:
|
| 45 |
+
return tensor
|
| 46 |
+
values, indices = torch.topk(tensor, k=k, dim=1)
|
| 47 |
+
out = torch.zeros_like(tensor)
|
| 48 |
+
return out.scatter(1, indices, values)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def active_dims(tensor: torch.Tensor) -> float:
|
| 52 |
+
return (tensor != 0).sum(dim=1).float().mean().item()
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def margins(queries: torch.Tensor, positives: torch.Tensor, negatives: torch.Tensor) -> torch.Tensor:
|
| 56 |
+
positive_scores = (queries * positives).sum(dim=1)
|
| 57 |
+
negative_scores = (queries * negatives).sum(dim=1)
|
| 58 |
+
return positive_scores - negative_scores
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def load_eval_rows(limit: int) -> EvalBatch:
|
| 62 |
+
rows = load_dataset(DATASET, CONFIG, split="test", streaming=False)
|
| 63 |
+
queries: list[str] = []
|
| 64 |
+
positives: list[str] = []
|
| 65 |
+
negatives: list[str] = []
|
| 66 |
+
|
| 67 |
+
for row in rows:
|
| 68 |
+
negatives_list = row.get("negatives") or []
|
| 69 |
+
if not negatives_list:
|
| 70 |
+
continue
|
| 71 |
+
queries.append(row["query"])
|
| 72 |
+
positives.append(row["positive"])
|
| 73 |
+
negatives.append(negatives_list[0])
|
| 74 |
+
if len(queries) >= limit:
|
| 75 |
+
break
|
| 76 |
+
|
| 77 |
+
return EvalBatch(queries=queries, positives=positives, negatives=negatives)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def eval_model(model_name: str, batch: EvalBatch, topks: list[int | None], batch_size: int) -> list[dict[str, float | int | str | None]]:
|
| 81 |
+
logging.info("Loading %s", model_name)
|
| 82 |
+
model = SparseEncoder(model_name, trust_remote_code=True, device="cpu")
|
| 83 |
+
|
| 84 |
+
logging.info("Encoding %d triplets", len(batch.queries))
|
| 85 |
+
queries = dense(model.encode_query(batch.queries, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=True)).cpu()
|
| 86 |
+
positives_raw = dense(model.encode_document(batch.positives, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=True)).cpu()
|
| 87 |
+
negatives_raw = dense(model.encode_document(batch.negatives, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=True)).cpu()
|
| 88 |
+
|
| 89 |
+
results: list[dict[str, float | int | str | None]] = []
|
| 90 |
+
for k in topks:
|
| 91 |
+
positives = doc_topk(positives_raw, k)
|
| 92 |
+
negatives = doc_topk(negatives_raw, k)
|
| 93 |
+
margin = margins(queries, positives, negatives)
|
| 94 |
+
results.append(
|
| 95 |
+
{
|
| 96 |
+
"model": model_name,
|
| 97 |
+
"doc_topk": k,
|
| 98 |
+
"accuracy": (margin > 0).float().mean().item(),
|
| 99 |
+
"mean_margin": margin.mean().item(),
|
| 100 |
+
"median_margin": margin.median().item(),
|
| 101 |
+
"query_active_dims": active_dims(queries),
|
| 102 |
+
"positive_doc_active_dims": active_dims(positives),
|
| 103 |
+
"negative_doc_active_dims": active_dims(negatives),
|
| 104 |
+
}
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
return results
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def parse_topks(raw: str) -> list[int | None]:
|
| 111 |
+
out: list[int | None] = []
|
| 112 |
+
for item in raw.split(","):
|
| 113 |
+
item = item.strip().lower()
|
| 114 |
+
if item in {"none", "all", "0"}:
|
| 115 |
+
out.append(None)
|
| 116 |
+
else:
|
| 117 |
+
out.append(int(item))
|
| 118 |
+
return out
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def main() -> None:
|
| 122 |
+
parser = argparse.ArgumentParser()
|
| 123 |
+
parser.add_argument("--models", nargs="+", required=True)
|
| 124 |
+
parser.add_argument("--eval-size", type=int, default=1000)
|
| 125 |
+
parser.add_argument("--batch-size", type=int, default=16)
|
| 126 |
+
parser.add_argument("--topks", default="none,1024,512,256,128,64")
|
| 127 |
+
args = parser.parse_args()
|
| 128 |
+
|
| 129 |
+
eval_batch = load_eval_rows(args.eval_size)
|
| 130 |
+
topks = parse_topks(args.topks)
|
| 131 |
+
logging.info("rows=%d topks=%s", len(eval_batch.queries), topks)
|
| 132 |
+
|
| 133 |
+
all_results: list[dict[str, float | int | str | None]] = []
|
| 134 |
+
for model_name in args.models:
|
| 135 |
+
all_results.extend(eval_model(model_name, eval_batch, topks, args.batch_size))
|
| 136 |
+
|
| 137 |
+
for result in all_results:
|
| 138 |
+
logging.info("TOPK_RESULT %s", json.dumps(result, sort_keys=True))
|
| 139 |
+
|
| 140 |
+
print(json.dumps(all_results, indent=2, sort_keys=True))
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
if __name__ == "__main__":
|
| 144 |
+
main()
|
scripts/train_fin_sparse_encoder_v2.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# /// script
|
| 3 |
+
# requires-python = ">=3.10"
|
| 4 |
+
# dependencies = [
|
| 5 |
+
# "sentence-transformers[train]==5.5.0",
|
| 6 |
+
# "datasets>=2.19.0",
|
| 7 |
+
# "accelerate>=0.26.0",
|
| 8 |
+
# ]
|
| 9 |
+
# ///
|
| 10 |
+
"""Financial filings SparseEncoder V2.
|
| 11 |
+
|
| 12 |
+
Design choices:
|
| 13 |
+
- stable Apache/OpenSearch base: doc-v2-distill, not brittle custom v3-gte;
|
| 14 |
+
- clean dataset shape: query, positive, negative only;
|
| 15 |
+
- explicit Router mapping: query -> query route, positive/negative -> document route;
|
| 16 |
+
- stronger document FLOPS regularization than the old run to avoid dense docs;
|
| 17 |
+
- custom CPU evaluator because sparse materialization is not safe on MPS here.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import argparse
|
| 23 |
+
import ast
|
| 24 |
+
import json
|
| 25 |
+
import logging
|
| 26 |
+
import os
|
| 27 |
+
from pathlib import Path
|
| 28 |
+
|
| 29 |
+
import torch
|
| 30 |
+
from datasets import Dataset, load_dataset
|
| 31 |
+
|
| 32 |
+
from sentence_transformers import (
|
| 33 |
+
SparseEncoder,
|
| 34 |
+
SparseEncoderModelCardData,
|
| 35 |
+
SparseEncoderTrainer,
|
| 36 |
+
SparseEncoderTrainingArguments,
|
| 37 |
+
)
|
| 38 |
+
from sentence_transformers.base.sampler import BatchSamplers
|
| 39 |
+
from sentence_transformers.sparse_encoder.data_collator import SparseEncoderDataCollator
|
| 40 |
+
from sentence_transformers.sparse_encoder.losses import SparseMultipleNegativesRankingLoss, SpladeLoss
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
DATASET = "oneryalcin/financial-filings-sparse-retrieval-training"
|
| 44 |
+
CONFIG = "combined"
|
| 45 |
+
BASE_MODEL = "opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill"
|
| 46 |
+
RUN_NAME = "fin-sparse-encoder-doc-v2-clean-router"
|
| 47 |
+
OUTPUT_DIR = Path("models") / RUN_NAME
|
| 48 |
+
LOG_DIR = Path("logs")
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def setup_logging(run_name: str) -> None:
|
| 52 |
+
LOG_DIR.mkdir(exist_ok=True)
|
| 53 |
+
logging.basicConfig(
|
| 54 |
+
format="%(asctime)s - %(message)s",
|
| 55 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
| 56 |
+
level=logging.INFO,
|
| 57 |
+
handlers=[logging.StreamHandler(), logging.FileHandler(LOG_DIR / f"{run_name}.log")],
|
| 58 |
+
force=True,
|
| 59 |
+
)
|
| 60 |
+
for noisy in ("httpx", "httpcore", "huggingface_hub", "urllib3", "filelock", "fsspec"):
|
| 61 |
+
logging.getLogger(noisy).setLevel(logging.WARNING)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def first_negative(value) -> str:
|
| 65 |
+
if isinstance(value, list):
|
| 66 |
+
return value[0] if value else ""
|
| 67 |
+
if isinstance(value, str):
|
| 68 |
+
parsed = ast.literal_eval(value)
|
| 69 |
+
return parsed[0] if parsed else ""
|
| 70 |
+
raise TypeError(f"Unsupported negatives value: {type(value)}")
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def clean_triplets(ds, limit: int | None = None) -> Dataset:
|
| 74 |
+
if limit is not None:
|
| 75 |
+
ds = ds.select(range(min(limit, len(ds))))
|
| 76 |
+
queries: list[str] = []
|
| 77 |
+
positives: list[str] = []
|
| 78 |
+
negatives: list[str] = []
|
| 79 |
+
for query, positive, raw_negative in zip(ds["query"], ds["positive"], ds["negatives"], strict=True):
|
| 80 |
+
negative = first_negative(raw_negative)
|
| 81 |
+
if not negative:
|
| 82 |
+
continue
|
| 83 |
+
queries.append(query)
|
| 84 |
+
positives.append(positive)
|
| 85 |
+
negatives.append(negative)
|
| 86 |
+
return Dataset.from_dict(
|
| 87 |
+
{
|
| 88 |
+
"query": queries,
|
| 89 |
+
"positive": positives,
|
| 90 |
+
"negative": negatives,
|
| 91 |
+
}
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def dense(x: torch.Tensor) -> torch.Tensor:
|
| 96 |
+
return x.to_dense() if x.is_sparse else x
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def encode_query(model: SparseEncoder, texts: list[str], batch_size: int) -> torch.Tensor:
|
| 100 |
+
return dense(model.encode_query(texts, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=False))
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def encode_document(model: SparseEncoder, texts: list[str], batch_size: int) -> torch.Tensor:
|
| 104 |
+
return dense(model.encode_document(texts, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=False))
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def domain_triplet_eval(model: SparseEncoder, eval_ds: Dataset, batch_size: int) -> dict[str, float]:
|
| 108 |
+
original_device = str(model.device)
|
| 109 |
+
model.to("cpu")
|
| 110 |
+
queries = list(eval_ds["query"])
|
| 111 |
+
positives = list(eval_ds["positive"])
|
| 112 |
+
negatives = list(eval_ds["negative"])
|
| 113 |
+
with torch.no_grad():
|
| 114 |
+
query_emb = encode_query(model, queries, batch_size)
|
| 115 |
+
pos_emb = encode_document(model, positives, batch_size)
|
| 116 |
+
neg_emb = encode_document(model, negatives, batch_size)
|
| 117 |
+
margins = (query_emb * pos_emb).sum(dim=1) - (query_emb * neg_emb).sum(dim=1)
|
| 118 |
+
result = {
|
| 119 |
+
"accuracy": (margins > 0).float().mean().item(),
|
| 120 |
+
"mean_margin": margins.float().mean().item(),
|
| 121 |
+
"median_margin": margins.float().median().item(),
|
| 122 |
+
"query_active_dims": (query_emb != 0).sum(dim=1).float().mean().item(),
|
| 123 |
+
"positive_active_dims": (pos_emb != 0).sum(dim=1).float().mean().item(),
|
| 124 |
+
"negative_active_dims": (neg_emb != 0).sum(dim=1).float().mean().item(),
|
| 125 |
+
}
|
| 126 |
+
if original_device != "cpu":
|
| 127 |
+
model.to(original_device)
|
| 128 |
+
return result
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def main() -> None:
|
| 132 |
+
parser = argparse.ArgumentParser()
|
| 133 |
+
parser.add_argument("--train-size", type=int, default=10_000)
|
| 134 |
+
parser.add_argument("--eval-size", type=int, default=500)
|
| 135 |
+
parser.add_argument("--max-steps", type=int, default=100)
|
| 136 |
+
parser.add_argument("--batch-size", type=int, default=8)
|
| 137 |
+
parser.add_argument("--eval-batch-size", type=int, default=16)
|
| 138 |
+
parser.add_argument("--max-seq-length", type=int, default=384)
|
| 139 |
+
parser.add_argument("--query-reg", type=float, default=1e-4)
|
| 140 |
+
parser.add_argument("--doc-reg", type=float, default=8e-5)
|
| 141 |
+
parser.add_argument("--run-name", default=RUN_NAME)
|
| 142 |
+
cli = parser.parse_args()
|
| 143 |
+
|
| 144 |
+
output_dir = Path("models") / cli.run_name
|
| 145 |
+
setup_logging(cli.run_name)
|
| 146 |
+
logging.info("Torch: %s", torch.__version__)
|
| 147 |
+
logging.info("MPS available: %s", torch.backends.mps.is_available())
|
| 148 |
+
logging.info("Loading base model: %s", BASE_MODEL)
|
| 149 |
+
model = SparseEncoder(
|
| 150 |
+
BASE_MODEL,
|
| 151 |
+
model_card_data=SparseEncoderModelCardData(
|
| 152 |
+
language="en",
|
| 153 |
+
license="apache-2.0",
|
| 154 |
+
model_name=f"Financial filings sparse encoder V2 ({cli.run_name})",
|
| 155 |
+
),
|
| 156 |
+
)
|
| 157 |
+
model.max_seq_length = cli.max_seq_length
|
| 158 |
+
logging.info("Training device: %s", model.device)
|
| 159 |
+
logging.info("max_seq_length: %s", model.max_seq_length)
|
| 160 |
+
|
| 161 |
+
raw_train = load_dataset(DATASET, CONFIG, split="train")
|
| 162 |
+
raw_test = load_dataset(DATASET, CONFIG, split="test")
|
| 163 |
+
train_dataset = clean_triplets(raw_train, cli.train_size)
|
| 164 |
+
eval_dataset = clean_triplets(raw_test, cli.eval_size)
|
| 165 |
+
logging.info("train rows: %s | eval rows: %s", len(train_dataset), len(eval_dataset))
|
| 166 |
+
|
| 167 |
+
logging.info("Baseline domain eval:")
|
| 168 |
+
baseline = domain_triplet_eval(model, eval_dataset, cli.eval_batch_size)
|
| 169 |
+
logging.info("BASELINE: %s", json.dumps(baseline, sort_keys=True))
|
| 170 |
+
|
| 171 |
+
loss = SpladeLoss(
|
| 172 |
+
model=model,
|
| 173 |
+
loss=SparseMultipleNegativesRankingLoss(model=model),
|
| 174 |
+
query_regularizer_weight=cli.query_reg,
|
| 175 |
+
document_regularizer_weight=cli.doc_reg,
|
| 176 |
+
)
|
| 177 |
+
data_collator = SparseEncoderDataCollator(
|
| 178 |
+
preprocess_fn=model.preprocess,
|
| 179 |
+
router_mapping={
|
| 180 |
+
"query": "query",
|
| 181 |
+
"positive": "document",
|
| 182 |
+
"negative": "document",
|
| 183 |
+
},
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
args = SparseEncoderTrainingArguments(
|
| 187 |
+
output_dir=str(output_dir),
|
| 188 |
+
max_steps=cli.max_steps,
|
| 189 |
+
num_train_epochs=1,
|
| 190 |
+
per_device_train_batch_size=cli.batch_size,
|
| 191 |
+
learning_rate=2e-5,
|
| 192 |
+
weight_decay=0.01,
|
| 193 |
+
warmup_steps=0.1,
|
| 194 |
+
lr_scheduler_type="linear",
|
| 195 |
+
bf16=False,
|
| 196 |
+
fp16=False,
|
| 197 |
+
batch_sampler=BatchSamplers.NO_DUPLICATES,
|
| 198 |
+
eval_strategy="no",
|
| 199 |
+
save_strategy="no",
|
| 200 |
+
logging_steps=10,
|
| 201 |
+
logging_first_step=True,
|
| 202 |
+
dataloader_pin_memory=False,
|
| 203 |
+
report_to="none",
|
| 204 |
+
run_name=cli.run_name,
|
| 205 |
+
seed=12,
|
| 206 |
+
)
|
| 207 |
+
trainer = SparseEncoderTrainer(
|
| 208 |
+
model=model,
|
| 209 |
+
args=args,
|
| 210 |
+
train_dataset=train_dataset,
|
| 211 |
+
loss=loss,
|
| 212 |
+
data_collator=data_collator,
|
| 213 |
+
)
|
| 214 |
+
trainer.train()
|
| 215 |
+
|
| 216 |
+
logging.info("Post-training domain eval:")
|
| 217 |
+
score = domain_triplet_eval(model, eval_dataset, cli.eval_batch_size)
|
| 218 |
+
logging.info("SCORE: %s", json.dumps(score, sort_keys=True))
|
| 219 |
+
delta = score["accuracy"] - baseline["accuracy"]
|
| 220 |
+
verdict = "WIN" if delta >= 0.005 else "MARGINAL" if delta >= 0 else "REGRESSION"
|
| 221 |
+
logging.info(
|
| 222 |
+
"VERDICT: %s | score=%.4f | baseline=%.4f | delta=%+.4f | query_active=%.1f doc_active=%.1f",
|
| 223 |
+
verdict,
|
| 224 |
+
score["accuracy"],
|
| 225 |
+
baseline["accuracy"],
|
| 226 |
+
delta,
|
| 227 |
+
score["query_active_dims"],
|
| 228 |
+
score["positive_active_dims"],
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
final_dir = output_dir / "final"
|
| 232 |
+
model.save_pretrained(str(final_dir))
|
| 233 |
+
logging.info("Saved final model to %s", final_dir)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
if __name__ == "__main__":
|
| 237 |
+
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
|
| 238 |
+
main()
|