#!/usr/bin/env python
# /// script
# requires-python = ">=3.11"
# dependencies = [
#     "datasets",
# ]
# ///
"""Prepare financial-filings sparse-retrieval triplets.

Loads one split of a Hugging Face dataset and keeps only rows that have:
- a non-empty query,
- a non-empty positive, and
- at least one non-empty negative.

Each kept row is reduced to a ``{query, positive, negative, metadata}`` triplet
that uses the first non-empty negative. The script prints a JSON summary of the
run and can optionally write the triplets to a JSONL file.
"""

from __future__ import annotations

import argparse
import json
from collections import Counter
from collections.abc import Iterable, Mapping
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

DEFAULT_DATASET = "oneryalcin/financial-filings-sparse-retrieval-training"
DEFAULT_CONFIG = "combined"


def non_empty_text(value: Any) -> str | None:
    """Return ``value`` with surrounding whitespace stripped, or ``None``.

    Only ``str`` values count as text. Non-strings, and strings that are empty
    after stripping, both yield ``None``.
    """
    if isinstance(value, str) and value.strip():
        return value.strip()
    return None


def first_non_empty_negative(row: dict[str, Any]) -> str | None:
    """Return the first usable negative passage of ``row``, or ``None``.

    ``row["negatives"]`` may be a list, in which case the first entry that
    ``non_empty_text`` accepts is returned. It may also be a single string.
    """
    negatives = row.get("negatives")
    if isinstance(negatives, list):
        for negative in negatives:
            text = non_empty_text(negative)
            if text:
                return text
    # Handles the bare-string case. A list with no usable entry also ends up
    # here and yields None, because non_empty_text rejects non-strings.
    return non_empty_text(negatives)


def metadata(row: dict[str, Any]) -> dict[str, str]:
    """Collect the known metadata fields of ``row`` that hold non-empty text."""
    out: dict[str, str] = {}
    for key in ("company", "ticker", "doc_type", "form", "filing_year", "section", "query_type"):
        # NOTE(review): non-string values (e.g. an integer filing_year) are
        # dropped by non_empty_text. Confirm the dataset stores these fields
        # as strings.
        value = non_empty_text(row.get(key))
        if value:
            out[key] = value
    return out


def metadata_prefix(meta: dict[str, str]) -> str:
    """Render ``meta`` as ``Label: value`` lines followed by a blank line.

    Returns ``""`` when none of the labelled keys are present. ``query_type``
    is intentionally absent from the labels, presumably because it describes
    the query rather than the document.
    """
    labels = {
        "company": "Company",
        "ticker": "Ticker",
        "doc_type": "Document type",
        "form": "Form",
        "filing_year": "Filing year",
        "section": "Section",
    }
    lines = [f"{label}: {meta[key]}" for key, label in labels.items() if key in meta]
    if not lines:
        return ""
    return "\n".join(lines) + "\n\n"


def clean_row(row: dict[str, Any], add_metadata_prefix: bool) -> tuple[dict[str, Any] | None, str | None]:
    """Reduce one raw row to a training triplet.

    Returns ``(triplet, None)`` on success. Otherwise returns ``(None, reason)``,
    where ``reason`` is one of ``"missing_query"``, ``"missing_positive"`` or
    ``"missing_negative"``, checked in that order. When ``add_metadata_prefix``
    is true, the rendered metadata is prepended to both the positive and the
    negative document.
    """
    query = non_empty_text(row.get("query"))
    positive = non_empty_text(row.get("positive"))
    negative = first_non_empty_negative(row)
    if not query:
        return None, "missing_query"
    if not positive:
        return None, "missing_positive"
    if not negative:
        return None, "missing_negative"
    meta = metadata(row)
    prefix = metadata_prefix(meta) if add_metadata_prefix else ""
    return {
        "query": query,
        "positive": prefix + positive,
        "negative": prefix + negative,
        "metadata": meta,
    }, None


def write_jsonl(path: Path, rows: list[dict[str, Any]]) -> None:
    """Write ``rows`` to ``path`` as UTF-8 JSON Lines, creating parent dirs."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as handle:
        for row in rows:
            # ensure_ascii=False keeps non-ASCII text readable in the output.
            handle.write(json.dumps(row, ensure_ascii=False) + "\n")


@dataclass
class PreparedRows:
    """Cleaned triplets plus the counts reported in the run summary."""

    rows: list[dict[str, Any]] = field(default_factory=list)
    # Skip reason from clean_row (or "unknown") -> number of rows dropped.
    skipped: Counter[str] = field(default_factory=Counter)
    # Facet value -> count over kept rows. "" counts rows lacking that field.
    query_types: Counter[str] = field(default_factory=Counter)
    doc_types: Counter[str] = field(default_factory=Counter)
    companies: Counter[str] = field(default_factory=Counter)


def collect_rows(
    rows: Iterable[Mapping[str, Any]],
    *,
    add_metadata_prefix: bool = False,
    limit: int | None = None,
) -> PreparedRows:
    """Clean ``rows`` in order, stopping once ``limit`` usable rows are kept.

    ``limit=None`` processes every row. ``limit <= 0`` keeps nothing and does
    not consume ``rows``.
    """
    result = PreparedRows()
    # Guard up front: the post-append check below would otherwise always keep
    # at least one row before it could stop.
    if limit is not None and limit <= 0:
        return result
    for row in rows:
        clean, reason = clean_row(dict(row), add_metadata_prefix)
        if clean is None:
            result.skipped[reason or "unknown"] += 1
            continue
        result.rows.append(clean)
        meta = clean["metadata"]
        result.query_types[meta.get("query_type", "")] += 1
        result.doc_types[meta.get("doc_type", "")] += 1
        result.companies[meta.get("company", "")] += 1
        if limit is not None and len(result.rows) >= limit:
            break
    return result


def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description=(
            "Prepare financial-filings sparse retrieval triplets. "
            "Default behavior matches the v1 fine-tune: drop rows missing query/positive/non-empty negative, "
            "then select the first non-empty negative."
        )
    )
    parser.add_argument("--dataset", default=DEFAULT_DATASET)
    parser.add_argument("--config", default=DEFAULT_CONFIG)
    parser.add_argument("--split", default="train")
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Stop after collecting this many usable rows (must be >= 0).",
    )
    parser.add_argument("--output", type=Path, default=None, help="Optional JSONL output path.")
    parser.add_argument(
        "--add-metadata-prefix",
        action="store_true",
        help=(
            "Opt-in experimental v2 transform: prepend available metadata to positive/negative documents. "
            "This was not used for the v1 model."
        ),
    )
    return parser


def main() -> None:
    """Parse arguments, clean the dataset split, optionally write JSONL, print a summary."""
    parser = _build_parser()
    args = parser.parse_args()
    if args.limit is not None and args.limit < 0:
        parser.error("--limit must be a non-negative integer")

    # Deferred import: `datasets` is a heavyweight dependency needed only to
    # load data. Keeping it out of module scope lets --help and argument errors
    # return quickly, and keeps the helpers above importable without it.
    from datasets import load_dataset

    dataset = load_dataset(args.dataset, args.config, split=args.split)
    prepared = collect_rows(
        dataset,
        add_metadata_prefix=args.add_metadata_prefix,
        limit=args.limit,
    )

    if args.output is not None:
        write_jsonl(args.output, prepared.rows)

    summary = {
        "dataset": args.dataset,
        "config": args.config,
        "split": args.split,
        "limit": args.limit,
        "add_metadata_prefix": args.add_metadata_prefix,
        "usable_rows": len(prepared.rows),
        "skipped_rows": dict(prepared.skipped),
        "top_query_types": prepared.query_types.most_common(20),
        "top_doc_types": prepared.doc_types.most_common(20),
        "top_companies": prepared.companies.most_common(20),
        "output": str(args.output) if args.output is not None else None,
    }
    print(json.dumps(summary, indent=2, sort_keys=True))


if __name__ == "__main__":
    main()