#!/usr/bin/env python3 # /// script # dependencies = [ # "sentence-transformers[train]==5.5.0", # "datasets", # "torch", # ] # /// from __future__ import annotations import argparse import json import logging from dataclasses import dataclass import torch from datasets import load_dataset from sentence_transformers.sparse_encoder import SparseEncoder DATASET = "oneryalcin/financial-filings-sparse-retrieval-training" CONFIG = "combined" logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s") for noisy in ("httpx", "httpcore", "huggingface_hub", "urllib3", "filelock", "fsspec"): logging.getLogger(noisy).setLevel(logging.WARNING) @dataclass(frozen=True) class EvalBatch: queries: list[str] positives: list[str] negatives: list[str] def dense(tensor: torch.Tensor) -> torch.Tensor: if tensor.is_sparse: return tensor.to_dense() return tensor def doc_topk(tensor: torch.Tensor, k: int | None) -> torch.Tensor: if k is None or k <= 0 or k >= tensor.shape[1]: return tensor values, indices = torch.topk(tensor, k=k, dim=1) out = torch.zeros_like(tensor) return out.scatter(1, indices, values) def active_dims(tensor: torch.Tensor) -> float: return (tensor != 0).sum(dim=1).float().mean().item() def margins(queries: torch.Tensor, positives: torch.Tensor, negatives: torch.Tensor) -> torch.Tensor: positive_scores = (queries * positives).sum(dim=1) negative_scores = (queries * negatives).sum(dim=1) return positive_scores - negative_scores def load_eval_rows(limit: int) -> EvalBatch: rows = load_dataset(DATASET, CONFIG, split="test", streaming=False) queries: list[str] = [] positives: list[str] = [] negatives: list[str] = [] for row in rows: negatives_list = row.get("negatives") or [] if not negatives_list: continue queries.append(row["query"]) positives.append(row["positive"]) negatives.append(negatives_list[0]) if len(queries) >= limit: break return EvalBatch(queries=queries, positives=positives, negatives=negatives) def eval_model(model_name: str, batch: EvalBatch, topks: list[int | None], batch_size: int) -> list[dict[str, float | int | str | None]]: logging.info("Loading %s", model_name) model = SparseEncoder(model_name, trust_remote_code=True, device="cpu") logging.info("Encoding %d triplets", len(batch.queries)) queries = dense(model.encode_query(batch.queries, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=True)).cpu() positives_raw = dense(model.encode_document(batch.positives, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=True)).cpu() negatives_raw = dense(model.encode_document(batch.negatives, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=True)).cpu() results: list[dict[str, float | int | str | None]] = [] for k in topks: positives = doc_topk(positives_raw, k) negatives = doc_topk(negatives_raw, k) margin = margins(queries, positives, negatives) results.append( { "model": model_name, "doc_topk": k, "accuracy": (margin > 0).float().mean().item(), "mean_margin": margin.mean().item(), "median_margin": margin.median().item(), "query_active_dims": active_dims(queries), "positive_doc_active_dims": active_dims(positives), "negative_doc_active_dims": active_dims(negatives), } ) return results def parse_topks(raw: str) -> list[int | None]: out: list[int | None] = [] for item in raw.split(","): item = item.strip().lower() if item in {"none", "all", "0"}: out.append(None) else: out.append(int(item)) return out def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--models", nargs="+", required=True) parser.add_argument("--eval-size", type=int, default=1000) parser.add_argument("--batch-size", type=int, default=16) parser.add_argument("--topks", default="none,1024,512,256,128,64") args = parser.parse_args() eval_batch = load_eval_rows(args.eval_size) topks = parse_topks(args.topks) logging.info("rows=%d topks=%s", len(eval_batch.queries), topks) all_results: list[dict[str, float | int | str | None]] = [] for model_name in args.models: all_results.extend(eval_model(model_name, eval_batch, topks, args.batch_size)) for result in all_results: logging.info("TOPK_RESULT %s", json.dumps(result, sort_keys=True)) print(json.dumps(all_results, indent=2, sort_keys=True)) if __name__ == "__main__": main()