#!/usr/bin/env python
# /// script
# requires-python = ">=3.11"
# dependencies = [
#     "datasets",
#     "sentence-transformers==5.5.0",
#     "torch",
# ]
# ///
"""Evaluate a SparseEncoder on financial-filings retrieval under document top-k pruning.

Queries and (deduplicated) positive/negative documents are encoded once. For
each requested ``--topks`` budget, every document vector keeps only its ``k``
largest-magnitude dimensions, queries are scored against the pruned corpus by
dot product, and one JSON line of retrieval metrics is logged and printed.
"""

from __future__ import annotations

import argparse
import json
import logging
from collections import OrderedDict

import torch


def parse_topks(value: str) -> list[int | None]:
    """Parse a comma-separated list of document top-k budgets.

    ``none``/``all``/``null`` (case-insensitive) map to ``None`` meaning "no
    pruning"; every other non-empty item must be a non-negative integer.
    Empty items (e.g. from a trailing comma) are ignored.

    Raises:
        ValueError: if an item is neither a keyword nor a non-negative integer.
    """
    topks: list[int | None] = []
    for raw in value.split(","):
        item = raw.strip().lower()
        if not item:
            continue
        if item in {"none", "all", "null"}:
            topks.append(None)
            continue
        k = int(item)
        if k < 0:
            raise ValueError(f"top-k must be non-negative, got {k}")
        topks.append(k)
    return topks


def first_negative(row: dict) -> str | None:
    """Return the first non-blank negative passage of ``row``, or ``None``.

    ``row["negatives"]`` may be a list/tuple of strings or a single string;
    any other type (or a missing key) yields ``None``.
    """
    negatives = row.get("negatives")
    if isinstance(negatives, (list, tuple)):
        for negative in negatives:
            if isinstance(negative, str) and negative.strip():
                return negative
    if isinstance(negatives, str) and negatives.strip():
        return negatives
    return None


def load_rows(limit: int) -> list[dict]:
    """Load up to ``limit`` (query, positive, negative) triplets from the test split.

    Rows missing a query, a positive, or a usable negative are skipped.
    A non-positive ``limit`` returns an empty list without touching the dataset.
    """
    if limit <= 0:
        return []

    # Deferred import: keeps module import (and ``--help``) free of `datasets`.
    from datasets import load_dataset

    dataset = load_dataset(
        "oneryalcin/financial-filings-sparse-retrieval-training",
        "combined",
        split="test",
    )
    rows: list[dict] = []
    for row in dataset:
        query = row.get("query")
        positive = row.get("positive")
        negative = first_negative(row)
        if query and positive and negative:
            rows.append({"query": query, "positive": positive, "negative": negative})
            if len(rows) >= limit:
                break
    return rows


def dedupe_texts(rows: list[dict]) -> tuple[list[str], list[int], list[int]]:
    """Build a deduplicated corpus from the rows' positives and negatives.

    Returns:
        ``(corpus, positive_ids, negative_ids)`` where ``corpus`` lists each
        distinct text once in first-seen order and ``positive_ids[i]`` /
        ``negative_ids[i]`` index row ``i``'s passages in ``corpus``.
    """
    corpus: OrderedDict[str, int] = OrderedDict()
    positive_ids: list[int] = []
    negative_ids: list[int] = []
    for row in rows:
        for key in ("positive", "negative"):
            text = row[key]
            if text not in corpus:
                corpus[text] = len(corpus)
        positive_ids.append(corpus[row["positive"]])
        negative_ids.append(corpus[row["negative"]])
    return list(corpus.keys()), positive_ids, negative_ids


def densify(encoded) -> torch.Tensor:
    """Convert encoder output to a dense float32 tensor on the CPU.

    Accepts a strided or sparse ``torch.Tensor``, any object exposing
    ``to_dense()``, or anything ``torch.as_tensor`` understands.
    """
    if isinstance(encoded, torch.Tensor):
        tensor = encoded if encoded.layout == torch.strided else encoded.to_dense()
    elif hasattr(encoded, "to_dense"):
        tensor = encoded.to_dense()
    else:
        tensor = torch.as_tensor(encoded)
    return tensor.float().cpu()


def active_dims(tensor: torch.Tensor) -> float:
    """Return the mean number of non-zero entries per row of ``tensor``."""
    return (tensor != 0).sum(dim=1).float().mean().item()


def apply_doc_topk(docs: torch.Tensor, topk: int | None) -> torch.Tensor:
    """Keep only each row's ``topk`` largest-magnitude entries, zeroing the rest.

    ``None`` or a ``topk`` at least as wide as the row returns ``docs``
    unchanged (same object). ``topk == 0`` returns an all-zero tensor.
    """
    if topk is None or topk >= docs.shape[1]:
        return docs
    # Select by magnitude so signed vectors prune sensibly; for non-negative
    # sparse activations this is identical to selecting the largest values.
    _, indices = torch.topk(docs.abs(), k=topk, dim=1)
    pruned = torch.zeros_like(docs)
    pruned.scatter_(1, indices, docs.gather(1, indices))
    return pruned


def retrieval_metrics(scores: torch.Tensor, positive_ids: list[int]) -> dict:
    """Compute single-relevant-document retrieval metrics.

    Args:
        scores: ``(num_queries, corpus_size)`` query-document scores.
        positive_ids: corpus index of the relevant document for each query.

    Returns:
        Recall@{1,5,10,20}, MRR@10, nDCG@10, mean and median rank of the
        positive. (``torch.median`` returns the lower middle for even counts.)
    """
    positive = torch.tensor(positive_ids, dtype=torch.long)
    positive_scores = scores[torch.arange(scores.shape[0]), positive].unsqueeze(1)
    # Pessimistic tie-breaking: documents scoring equal to the positive rank
    # ahead of it. Otherwise a query with no term overlap (every score 0)
    # would be credited with rank 1.
    ranks = (scores >= positive_scores).sum(dim=1)
    ranks_f = ranks.float()
    in_top10 = ranks <= 10
    zeros = torch.zeros_like(ranks_f)
    mrr10 = torch.where(in_top10, 1.0 / ranks_f, zeros).mean().item()
    ndcg10 = torch.where(in_top10, 1.0 / torch.log2(ranks_f + 1.0), zeros).mean().item()
    return {
        "recall_at_1": (ranks <= 1).float().mean().item(),
        "recall_at_5": (ranks <= 5).float().mean().item(),
        "recall_at_10": (ranks <= 10).float().mean().item(),
        "recall_at_20": (ranks <= 20).float().mean().item(),
        "mrr_at_10": mrr10,
        "ndcg_at_10": ndcg10,
        "mean_rank": ranks_f.mean().item(),
        "median_rank": ranks_f.median().item(),
    }


def main() -> None:
    """Parse CLI args, encode queries and corpus once, and report metrics per top-k."""
    logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")

    parser = argparse.ArgumentParser(
        description="Evaluate a SparseEncoder with document-side top-k pruning."
    )
    parser.add_argument("--model", required=True, help="SparseEncoder model name or path.")
    parser.add_argument(
        "--eval-size", type=int, default=1000, help="Maximum number of evaluation rows."
    )
    parser.add_argument("--batch-size", type=int, default=16, help="Encoding batch size.")
    parser.add_argument(
        "--topks",
        default="none,128,64",
        help="Comma-separated document top-k budgets; 'none' disables pruning.",
    )
    args = parser.parse_args()

    # Validate --topks before the expensive dataset load and encoding passes.
    try:
        topks = parse_topks(args.topks)
    except ValueError as err:
        parser.error(f"invalid --topks {args.topks!r}: {err}")
    if not topks:
        parser.error("--topks must contain at least one value")

    rows = load_rows(args.eval_size)
    if not rows:
        raise SystemExit("No usable evaluation rows (check --eval-size and the dataset).")
    queries = [row["query"] for row in rows]
    corpus, positive_ids, _negative_ids = dedupe_texts(rows)
    logging.info("rows=%d corpus=%d topks=%s", len(rows), len(corpus), args.topks)

    # Deferred import: the heavy dependency is only needed once work begins.
    from sentence_transformers.sparse_encoder import SparseEncoder

    model = SparseEncoder(args.model, device="cpu")
    query_vectors = densify(
        model.encode_query(queries, batch_size=args.batch_size, convert_to_tensor=True)
    )
    doc_vectors = densify(
        model.encode_document(corpus, batch_size=args.batch_size, convert_to_tensor=True)
    )
    query_active = active_dims(query_vectors)

    for topk in topks:
        pruned_docs = apply_doc_topk(doc_vectors, topk)
        scores = query_vectors @ pruned_docs.T
        metrics = retrieval_metrics(scores, positive_ids)
        result = {
            "model": args.model,
            "rows": len(rows),
            "corpus_size": len(corpus),
            "doc_topk": topk,
            "query_active_dims": query_active,
            "doc_active_dims": active_dims(pruned_docs),
            **metrics,
        }
        logging.info("RETRIEVAL_RESULT %s", json.dumps(result, sort_keys=True))
        print(json.dumps(result, sort_keys=True))


if __name__ == "__main__":
    main()