#!/usr/bin/env python
# /// script
# requires-python = ">=3.11"
# dependencies = [
#     "datasets",
# ]
# ///
"""BM25 baseline for the financial-filings sparse-retrieval test split.

Each usable test row contributes its query, its positive passage and its first
non-blank negative passage. All positives and negatives are pooled into one
deduplicated corpus, every query is scored against that whole corpus with
BM25, and rank metrics for each query's positive are logged and printed as a
single JSON object.
"""

from __future__ import annotations

import argparse
import json
import logging
import math
import re
from collections import Counter, defaultdict

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(message)s")

# Tokens are runs of ASCII letters/digits after lower-casing; every other
# character (punctuation, whitespace, non-ASCII letters) acts as a separator.
TOKEN_RE = re.compile(r"[a-z0-9]+")


def tokenize(text: str) -> list[str]:
    """Return the lower-cased ASCII alphanumeric tokens of *text*, in order."""
    return TOKEN_RE.findall(text.lower())


def first_negative(row: dict) -> str | None:
    """Return the first usable negative passage of a dataset row, or ``None``.

    ``row["negatives"]`` may be a list (its first non-blank string element is
    used) or a single string. Missing or blank values yield ``None``. The
    returned text is not stripped.
    """
    negatives = row.get("negatives")
    if isinstance(negatives, list):
        return next(
            (neg for neg in negatives if isinstance(neg, str) and neg.strip()),
            None,
        )
    if isinstance(negatives, str) and negatives.strip():
        return negatives
    return None


def load_rows(limit: int) -> list[dict]:
    """Load up to *limit* complete (query, positive, negative) rows from the test split.

    Rows lacking a query, a positive, or a usable negative are skipped. A
    non-positive *limit* returns ``[]`` without touching the dataset.
    """
    if limit <= 0:
        return []
    # Imported lazily so the scoring/metric helpers in this module can be
    # imported (and ``--help`` run) without the heavy ``datasets`` dependency.
    from datasets import load_dataset

    dataset = load_dataset(
        "oneryalcin/financial-filings-sparse-retrieval-training",
        "combined",
        split="test",
    )
    rows: list[dict] = []
    for row in dataset:
        query = row.get("query")
        positive = row.get("positive")
        negative = first_negative(row)
        if query and positive and negative:
            rows.append({"query": query, "positive": positive, "negative": negative})
            if len(rows) >= limit:
                break
    return rows


def dedupe_texts(rows: list[dict]) -> tuple[list[str], list[int]]:
    """Pool the positive and negative passages of *rows* into one deduplicated corpus.

    Returns ``(corpus, positive_ids)``. ``corpus`` lists each distinct passage
    once, in first-seen order (a row's positive before its negative).
    ``positive_ids[i]`` is the corpus index of ``rows[i]["positive"]``.
    """
    index: dict[str, int] = {}
    positive_ids: list[int] = []
    for row in rows:
        for key in ("positive", "negative"):
            # len(index) is evaluated before insertion, so new texts get the next id.
            index.setdefault(row[key], len(index))
        positive_ids.append(index[row["positive"]])
    return list(index), positive_ids


def build_bm25(
    corpus: list[str], k1: float, b: float
) -> tuple[dict[str, list[tuple[int, float]]], list[int], float]:
    """Build a BM25 inverted index with precomputed per-(term, document) weights.

    IDF is ``log(1 + (N - df + 0.5) / (df + 0.5))`` (never negative). Length
    normalisation is ``k1 * (1 - b + b * doc_len / avg_doc_len)``.

    Returns ``(inverted, doc_lens, avg_doc_len)``:

    - ``inverted[term]`` lists ``(doc_id, weight)`` pairs in increasing
      ``doc_id`` order.
    - ``doc_lens`` holds each document's token count.
    - ``avg_doc_len`` is their mean (0.0 for an empty or token-free corpus).
    """
    doc_counts: list[Counter[str]] = [Counter(tokenize(text)) for text in corpus]
    doc_lens: list[int] = [sum(counts.values()) for counts in doc_counts]
    document_frequency: Counter[str] = Counter()
    for counts in doc_counts:
        document_frequency.update(counts.keys())

    corpus_size = len(corpus)
    avg_doc_len = sum(doc_lens) / max(corpus_size, 1)
    # IDF depends only on the term, so compute it once per term rather than
    # once per posting.
    idf = {
        term: math.log(1 + (corpus_size - df + 0.5) / (df + 0.5))
        for term, df in document_frequency.items()
    }

    inverted: dict[str, list[tuple[int, float]]] = defaultdict(list)
    for doc_id, counts in enumerate(doc_counts):
        if not counts:
            # A token-free document has no postings. Skipping it also avoids
            # dividing by zero when *every* document is token-free
            # (avg_doc_len == 0).
            continue
        norm = k1 * (1 - b + b * doc_lens[doc_id] / avg_doc_len)
        for term, tf in counts.items():
            inverted[term].append((doc_id, idf[term] * (tf * (k1 + 1)) / (tf + norm)))
    return dict(inverted), doc_lens, avg_doc_len


def rank_query(
    query: str, inverted: dict[str, list[tuple[int, float]]], corpus_size: int
) -> list[float]:
    """Return the BM25 score of *query* for every document ``0 .. corpus_size - 1``.

    Repeated query terms count once per occurrence (posting weight times query
    term frequency). Terms absent from the index contribute nothing.
    """
    scores = [0.0] * corpus_size
    for term, qtf in Counter(tokenize(query)).items():
        for doc_id, weight in inverted.get(term, ()):
            scores[doc_id] += weight * qtf
    return scores


def retrieval_metrics(
    all_scores: list[list[float]],
    positive_ids: list[int],
    *,
    optimistic_ties: bool = False,
) -> dict:
    """Compute rank metrics for the single relevant document of each query.

    The positive's rank is 1 plus the number of *other* documents scoring
    strictly higher. Unless ``optimistic_ties`` is set, other documents with an
    equal score also count as ranked above it. This matters for lexical
    retrievers: a query sharing no token with the corpus scores 0.0
    everywhere, and optimistic tie-breaking would credit it with a rank-1 hit.

    With one relevant document, nDCG@10 reduces to ``1 / log2(rank + 1)``.
    ``median_rank`` is the upper median (element ``count // 2`` of the sorted
    ranks).

    Raises:
        ValueError: if there are no queries, or the two inputs differ in length.
    """
    if len(all_scores) != len(positive_ids):
        raise ValueError(
            f"got {len(all_scores)} score lists but {len(positive_ids)} positive ids"
        )
    if not all_scores:
        raise ValueError("retrieval_metrics needs at least one query")

    ranks: list[int] = []
    for scores, positive_id in zip(all_scores, positive_ids):
        positive_score = scores[positive_id]
        if optimistic_ties:
            rank = 1 + sum(score > positive_score for score in scores)
        else:
            # Counts the positive itself plus every document scoring >= it.
            rank = sum(score >= positive_score for score in scores)
        ranks.append(rank)

    count = len(ranks)

    def recall_at(k: int) -> float:
        return sum(rank <= k for rank in ranks) / count

    top10 = [rank for rank in ranks if rank <= 10]
    return {
        "recall_at_1": recall_at(1),
        "recall_at_5": recall_at(5),
        "recall_at_10": recall_at(10),
        "recall_at_20": recall_at(20),
        "mrr_at_10": sum(1.0 / rank for rank in top10) / count,
        "ndcg_at_10": sum(1.0 / math.log2(rank + 1) for rank in top10) / count,
        "mean_rank": sum(ranks) / count,
        "median_rank": sorted(ranks)[count // 2],
    }


def main() -> None:
    """Parse CLI options, run the BM25 evaluation, and print the metrics as JSON."""
    parser = argparse.ArgumentParser(
        description="Evaluate a BM25 baseline on the financial-filings retrieval test split."
    )
    parser.add_argument("--eval-size", type=int, default=1000,
                        help="maximum number of usable test rows to evaluate")
    parser.add_argument("--k1", type=float, default=1.2,
                        help="BM25 term-frequency saturation parameter")
    parser.add_argument("--b", type=float, default=0.75,
                        help="BM25 length-normalisation parameter")
    parser.add_argument("--optimistic-ties", action="store_true",
                        help="rank the positive first among equally scored documents "
                             "(legacy behaviour; inflates metrics)")
    args = parser.parse_args()
    if args.eval_size <= 0:
        parser.error("--eval-size must be a positive integer")

    rows = load_rows(args.eval_size)
    if not rows:
        raise SystemExit("no usable (query, positive, negative) rows found in the test split")
    queries = [row["query"] for row in rows]
    corpus, positive_ids = dedupe_texts(rows)
    logging.info("rows=%d corpus=%d k1=%s b=%s", len(rows), len(corpus), args.k1, args.b)

    inverted, doc_lens, avg_doc_len = build_bm25(corpus, args.k1, args.b)
    scores = [rank_query(query, inverted, len(corpus)) for query in queries]
    result = {
        "model": "bm25",
        "rows": len(rows),
        "corpus_size": len(corpus),
        "avg_doc_len": avg_doc_len,
        "median_doc_len": sorted(doc_lens)[len(doc_lens) // 2],
        "k1": args.k1,
        "b": args.b,
        "tie_breaking": "optimistic" if args.optimistic_ties else "pessimistic",
        **retrieval_metrics(scores, positive_ids, optimistic_ties=args.optimistic_ties),
    }
    logging.info("BM25_RESULT %s", json.dumps(result, sort_keys=True))
    print(json.dumps(result, sort_keys=True))


if __name__ == "__main__":
    main()