financial-filings-sparse-encoder-v1 / scripts /eval_fin_sparse_retrieval_proxy.py
oneryalcin's picture
Add financial filings sparse encoder v1
bf3d3f8 verified
#!/usr/bin/env python
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "datasets",
# "sentence-transformers==5.5.0",
# "torch",
# ]
# ///
from __future__ import annotations
import argparse
import json
import logging
from collections import OrderedDict
import torch
from datasets import load_dataset
from sentence_transformers.sparse_encoder import SparseEncoder
# Configure the root logger once at import: INFO and above, each line prefixed with a timestamp.
logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)
def parse_topks(value: str) -> list[int | None]:
out: list[int | None] = []
for item in value.split(","):
item = item.strip().lower()
if item in {"none", "all", "null"}:
out.append(None)
else:
out.append(int(item))
return out
def first_negative(row: dict) -> str | None:
negatives = row.get("negatives")
if isinstance(negatives, list):
for negative in negatives:
if isinstance(negative, str) and negative.strip():
return negative
if isinstance(negatives, str) and negatives.strip():
return negatives
return None
def load_rows(
    limit: int,
    *,
    dataset_name: str = "oneryalcin/financial-filings-sparse-retrieval-training",
    config: str = "combined",
    split: str = "test",
) -> list[dict]:
    """Collect up to ``limit`` (query, positive, negative) evaluation rows.

    Rows lacking a truthy query, a truthy positive, or a usable negative (as
    picked by ``first_negative``) are skipped. Iteration stops as soon as
    ``limit`` rows have been collected.

    Args:
        limit: Maximum number of rows to return; values below 1 return an
            empty list without loading the dataset.
        dataset_name: Hugging Face dataset id passed to ``load_dataset``.
        config: Dataset configuration name.
        split: Dataset split to read.

    Returns:
        A list of dicts with keys ``"query"``, ``"positive"`` and ``"negative"``.
    """
    # Checked up front: the loop below only checks the limit after appending,
    # so a non-positive limit previously still returned one row.
    if limit <= 0:
        return []
    dataset = load_dataset(dataset_name, config, split=split)
    rows: list[dict] = []
    for row in dataset:
        query = row.get("query")
        positive = row.get("positive")
        negative = first_negative(row)
        if query and positive and negative:
            rows.append({"query": query, "positive": positive, "negative": negative})
            if len(rows) >= limit:
                break
    return rows
def dedupe_texts(rows: list[dict]) -> tuple[list[str], list[int], list[int]]:
    """Build a duplicate-free document pool from the positives and negatives.

    Texts are numbered in first-seen order, visiting each row's positive
    before its negative.

    Returns:
        ``(corpus, positive_ids, negative_ids)`` where ``corpus`` lists each
        distinct text once and ``positive_ids[i]`` / ``negative_ids[i]`` are
        the corpus indices of row ``i``'s positive and negative.
    """
    index_of: dict[str, int] = {}
    positive_ids: list[int] = []
    negative_ids: list[int] = []
    for entry in rows:
        pos_text, neg_text = entry["positive"], entry["negative"]
        # setdefault evaluates len() before inserting, so a new text gets the
        # next free index while a repeated text keeps its original one.
        positive_ids.append(index_of.setdefault(pos_text, len(index_of)))
        negative_ids.append(index_of.setdefault(neg_text, len(index_of)))
    return list(index_of), positive_ids, negative_ids
def densify(encoded) -> torch.Tensor:
    """Return encoder output as a dense (strided) float32 tensor on the CPU.

    Accepts a torch tensor in any layout (non-strided layouts such as sparse
    COO are converted with ``to_dense``), any other object exposing
    ``to_dense()``, or anything ``torch.as_tensor`` understands.
    """
    if not isinstance(encoded, torch.Tensor):
        dense = encoded.to_dense() if hasattr(encoded, "to_dense") else torch.as_tensor(encoded)
    elif encoded.layout == torch.strided:
        dense = encoded
    else:
        dense = encoded.to_dense()
    return dense.float().cpu()
def active_dims(tensor: torch.Tensor) -> float:
    """Mean number of strictly positive entries per row of a 2-D tensor."""
    positive_mask = tensor.gt(0)
    per_row = positive_mask.sum(dim=1).float()
    return per_row.mean().item()
def apply_doc_topk(docs: torch.Tensor, topk: int | None) -> torch.Tensor:
if topk is None or topk >= docs.shape[1]:
return docs
values, indices = torch.topk(docs, k=topk, dim=1)
pruned = torch.zeros_like(docs)
pruned.scatter_(1, indices, values)
return pruned
def retrieval_metrics(
    scores: torch.Tensor,
    positive_ids: list[int],
    ties: str = "pessimistic",
) -> dict:
    """Compute single-positive retrieval metrics from a query-by-document score matrix.

    Row ``i`` of ``scores`` holds query ``i``'s score against every corpus
    document, and ``positive_ids[i]`` is the column of its one relevant
    document. Because there is exactly one relevant document, the ideal DCG
    is 1, so nDCG@10 reduces to ``1 / log2(rank + 1)`` within the cutoff.

    Args:
        scores: 2-D tensor of shape (num_queries, num_docs).
        positive_ids: Relevant document column for each query.
        ties: How documents scoring exactly the same as the positive are
            ranked. ``"pessimistic"`` (default) places the positive after all
            tied documents. ``"optimistic"`` places it before them, which was
            the previous behaviour. Exact ties are common with sparse vectors:
            a query sharing no active dimension with its positive scores 0,
            just like every other non-overlapping document. The optimistic rule
            would then report such a query as a rank-1 hit.

    Returns:
        Dict with ``recall_at_{1,5,10,20}``, ``mrr_at_10``, ``ndcg_at_10``,
        ``mean_rank`` and ``median_rank`` (torch's median, i.e. the lower of
        the two middle values for an even count).

    Raises:
        ValueError: On an unknown ``ties`` value, an empty score matrix, or a
            ``positive_ids`` length that does not match the number of rows.
    """
    if ties not in {"optimistic", "pessimistic"}:
        raise ValueError(f"ties must be 'optimistic' or 'pessimistic', got {ties!r}")
    num_queries = scores.shape[0]
    if num_queries == 0:
        raise ValueError("cannot compute retrieval metrics for zero queries")
    positive = torch.tensor(positive_ids, dtype=torch.long)
    if positive.shape[0] != num_queries:
        raise ValueError(
            f"got {positive.shape[0]} positive ids for {num_queries} score rows"
        )

    positive_scores = scores[torch.arange(num_queries), positive].unsqueeze(1)
    if ties == "pessimistic":
        # Counts the positive itself plus every document scoring >= it.
        ranks = (scores >= positive_scores).sum(dim=1)
    else:
        ranks = (scores > positive_scores).sum(dim=1) + 1
    ranks_f = ranks.float()

    in_top10 = ranks <= 10
    zeros = torch.zeros_like(ranks_f)
    mrr10 = torch.where(in_top10, 1.0 / ranks_f, zeros).mean().item()
    ndcg10 = torch.where(in_top10, 1.0 / torch.log2(ranks_f + 1.0), zeros).mean().item()
    return {
        "recall_at_1": (ranks <= 1).float().mean().item(),
        "recall_at_5": (ranks <= 5).float().mean().item(),
        "recall_at_10": in_top10.float().mean().item(),
        "recall_at_20": (ranks <= 20).float().mean().item(),
        "mrr_at_10": mrr10,
        "ndcg_at_10": ndcg10,
        "mean_rank": ranks_f.mean().item(),
        "median_rank": ranks_f.median().item(),
    }
def main() -> None:
    """Command-line entry point for the sparse retrieval proxy evaluation.

    Loads up to ``--eval-size`` (query, positive, negative) rows and encodes
    the queries and the deduplicated positive/negative document pool with
    ``--model``. Then, for each ``--topks`` value, it keeps only that many
    dimensions per document, scores every query against every document by
    dot product, and prints one JSON result line. The same line is also
    logged with a ``RETRIEVAL_RESULT`` prefix.
    """
    parser = argparse.ArgumentParser(
        description="Proxy retrieval evaluation for a SparseEncoder model.",
    )
    parser.add_argument("--model", required=True, help="Model name or path passed to SparseEncoder.")
    parser.add_argument("--eval-size", type=int, default=1000, help="Maximum number of evaluation rows.")
    parser.add_argument("--batch-size", type=int, default=16, help="Encoding batch size.")
    parser.add_argument(
        "--topks",
        default="none,128,64",
        help="Comma-separated document top-k cutoffs; 'none' keeps every dimension.",
    )
    parser.add_argument("--device", default="cpu", help="Torch device used for encoding (default: cpu).")
    args = parser.parse_args()

    if args.eval_size < 1:
        parser.error("--eval-size must be a positive integer")
    if args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")
    # Validate --topks before the expensive model load and encoding passes so
    # a typo fails immediately rather than after every vector is computed.
    try:
        topks = parse_topks(args.topks)
    except ValueError as err:
        parser.error(f"invalid --topks {args.topks!r}: {err}")

    rows = load_rows(args.eval_size)
    if not rows:
        raise SystemExit("no evaluation rows with a query, positive and negative were found")
    queries = [row["query"] for row in rows]
    # Negatives only enlarge the candidate pool; their ids are not scored.
    corpus, positive_ids, _negative_ids = dedupe_texts(rows)
    logging.info("rows=%d corpus=%d topks=%s", len(rows), len(corpus), args.topks)

    model = SparseEncoder(args.model, device=args.device)
    query_vectors = densify(
        model.encode_query(queries, batch_size=args.batch_size, convert_to_tensor=True)
    )
    doc_vectors = densify(
        model.encode_document(corpus, batch_size=args.batch_size, convert_to_tensor=True)
    )
    # Query vectors are never pruned, so their density is the same for every cutoff.
    query_active = active_dims(query_vectors)

    for topk in topks:
        pruned_docs = apply_doc_topk(doc_vectors, topk)
        scores = query_vectors @ pruned_docs.T
        metrics = retrieval_metrics(scores, positive_ids)
        result = {
            "model": args.model,
            "rows": len(rows),
            "corpus_size": len(corpus),
            "doc_topk": topk,
            "query_active_dims": query_active,
            "doc_active_dims": active_dims(pruned_docs),
            **metrics,
        }
        line = json.dumps(result, sort_keys=True)
        logging.info("RETRIEVAL_RESULT %s", line)
        print(line)


if __name__ == "__main__":
    main()