#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "sentence-transformers[train]==5.5.0",
#     "datasets>=2.19.0",
#     "accelerate>=0.26.0",
# ]
# ///
"""Financial filings SparseEncoder V2.

Design choices:
- stable Apache/OpenSearch base: doc-v2-distill, not brittle custom v3-gte;
- clean dataset shape: query, positive, negative only;
- explicit Router mapping: query -> query route, positive/negative -> document route;
- stronger document FLOPS regularization than the old run to avoid dense docs;
- custom CPU evaluator because sparse materialization is not safe on MPS here.
"""

from __future__ import annotations

import argparse
import ast
import json
import logging
import os
from pathlib import Path

import torch
from datasets import Dataset, load_dataset
from sentence_transformers import (
    SparseEncoder,
    SparseEncoderModelCardData,
    SparseEncoderTrainer,
    SparseEncoderTrainingArguments,
)
from sentence_transformers.base.sampler import BatchSamplers
from sentence_transformers.sparse_encoder.data_collator import SparseEncoderDataCollator
from sentence_transformers.sparse_encoder.losses import SparseMultipleNegativesRankingLoss, SpladeLoss

DATASET = "oneryalcin/financial-filings-sparse-retrieval-training"
CONFIG = "combined"
BASE_MODEL = "opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill"
RUN_NAME = "fin-sparse-encoder-doc-v2-clean-router"
OUTPUT_DIR = Path("models") / RUN_NAME
LOG_DIR = Path("logs")


def setup_logging(run_name: str) -> None:
    """Configure root logging to the console and to ``LOG_DIR/<run_name>.log``.

    Any previously installed root handlers are replaced (``force=True``), and
    a handful of chatty HTTP/filesystem loggers are raised to WARNING.
    """
    # parents=True so a nested LOG_DIR does not fail on a fresh checkout.
    LOG_DIR.mkdir(parents=True, exist_ok=True)
    logging.basicConfig(
        format="%(asctime)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
        handlers=[logging.StreamHandler(), logging.FileHandler(LOG_DIR / f"{run_name}.log")],
        force=True,
    )
    for noisy in ("httpx", "httpcore", "huggingface_hub", "urllib3", "filelock", "fsspec"):
        logging.getLogger(noisy).setLevel(logging.WARNING)


def first_negative(value) -> str:
    """Return the first negative from a raw ``negatives`` cell, or ``""`` if none.

    Accepted inputs:
    - ``None`` or a blank string: no negative, returns ``""``;
    - a list/tuple: its first element (``""`` when empty);
    - a string holding a Python list/tuple literal (e.g. ``"['a', 'b']"``):
      the first element of the parsed sequence;
    - any other string: treated as the single negative and returned unchanged.

    Raises:
        TypeError: if ``value`` is none of the above.
    """
    if value is None:
        return ""

    if isinstance(value, str):
        if not value.strip():
            return ""
        try:
            parsed = ast.literal_eval(value)
        except (ValueError, SyntaxError, MemoryError, RecursionError):
            # Not a serialized literal: the cell is the negative text itself.
            return value
        if not isinstance(parsed, (list, tuple)):
            # The text merely happened to parse as a scalar literal
            # (number, quoted string, ...); keep the raw text.
            return value
        value = parsed

    if isinstance(value, (list, tuple)):
        # NOTE(review): elements are presumably strings; verify against the
        # dataset schema. Falsy elements cause the row to be dropped by
        # clean_triplets.
        return value[0] if value else ""

    raise TypeError(f"Unsupported negatives value: {type(value)}")


def clean_triplets(ds, limit: int | None = None) -> Dataset:
    """Build a ``query`` / ``positive`` / ``negative`` dataset from raw rows.

    Reads the ``query``, ``positive`` and ``negatives`` columns of ``ds``,
    keeps only the first negative of each row, and drops rows without one.

    Args:
        ds: Source dataset exposing ``select``, ``len`` and column access.
        limit: If given, only the first ``limit`` rows are considered. The cap
            is applied *before* filtering, so the result may hold fewer rows.
    """
    if limit is not None:
        ds = ds.select(range(min(limit, len(ds))))

    queries: list[str] = []
    positives: list[str] = []
    negatives: list[str] = []
    for query, positive, raw_negative in zip(ds["query"], ds["positive"], ds["negatives"], strict=True):
        negative = first_negative(raw_negative)
        if not negative:
            continue
        queries.append(query)
        positives.append(positive)
        negatives.append(negative)

    return Dataset.from_dict(
        {
            "query": queries,
            "positive": positives,
            "negative": negatives,
        }
    )


def dense(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` as a dense tensor, converting only when ``x.is_sparse``."""
    return x.to_dense() if x.is_sparse else x


def encode_query(model: SparseEncoder, texts: list[str], batch_size: int) -> torch.Tensor:
    """Encode ``texts`` with ``model.encode_query`` and return a dense tensor."""
    return dense(model.encode_query(texts, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=False))


def encode_document(model: SparseEncoder, texts: list[str], batch_size: int) -> torch.Tensor:
    """Encode ``texts`` with ``model.encode_document`` and return a dense tensor."""
    return dense(model.encode_document(texts, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=False))


def domain_triplet_eval(model: SparseEncoder, eval_ds: Dataset, batch_size: int) -> dict[str, float]:
    """Score ``model`` on triplets using dot-product margins, computed on CPU.

    The model is moved to CPU for the duration of the evaluation and moved
    back to its original device afterwards, even if encoding fails.

    Args:
        model: Encoder to evaluate.
        eval_ds: Dataset with ``query``, ``positive`` and ``negative`` columns.
        batch_size: Batch size passed to the encode calls.

    Returns:
        A dict with:
        - ``accuracy``: fraction of rows where the query scores the positive
          strictly higher than the negative;
        - ``mean_margin`` / ``median_margin``: statistics of
          ``score(q, pos) - score(q, neg)``;
        - ``query_active_dims`` / ``positive_active_dims`` /
          ``negative_active_dims``: mean count of non-zero dimensions.

    Raises:
        ValueError: if ``eval_ds`` has no rows.
    """
    if len(eval_ds) == 0:
        raise ValueError("eval_ds is empty; cannot compute triplet metrics")

    original_device = str(model.device)
    model.to("cpu")
    try:
        queries = list(eval_ds["query"])
        positives = list(eval_ds["positive"])
        negatives = list(eval_ds["negative"])

        with torch.no_grad():
            query_emb = encode_query(model, queries, batch_size)
            pos_emb = encode_document(model, positives, batch_size)
            neg_emb = encode_document(model, negatives, batch_size)

        # Row-wise dot products: positive score minus negative score.
        margins = (query_emb * pos_emb).sum(dim=1) - (query_emb * neg_emb).sum(dim=1)
        return {
            "accuracy": (margins > 0).float().mean().item(),
            "mean_margin": margins.float().mean().item(),
            # torch.median returns the lower of the two middle values for
            # even-sized inputs.
            "median_margin": margins.float().median().item(),
            "query_active_dims": (query_emb != 0).sum(dim=1).float().mean().item(),
            "positive_active_dims": (pos_emb != 0).sum(dim=1).float().mean().item(),
            "negative_active_dims": (neg_emb != 0).sum(dim=1).float().mean().item(),
        }
    finally:
        # Always put the model back so a failed eval cannot leave a
        # subsequent training run on the wrong device.
        if original_device != "cpu":
            model.to(original_device)


def main() -> None:
    """Run baseline eval, fine-tune the sparse encoder, re-evaluate, and save.

    Command-line flags control dataset sizes, step count, batch sizes,
    sequence length, regularizer weights and the run name. The final model is
    written to ``models/<run-name>/final``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--train-size", type=int, default=10_000)
    parser.add_argument("--eval-size", type=int, default=500)
    parser.add_argument("--max-steps", type=int, default=100)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--eval-batch-size", type=int, default=16)
    parser.add_argument("--max-seq-length", type=int, default=384)
    parser.add_argument("--query-reg", type=float, default=1e-4)
    parser.add_argument("--doc-reg", type=float, default=8e-5)
    parser.add_argument("--run-name", default=RUN_NAME)
    cli = parser.parse_args()

    output_dir = Path("models") / cli.run_name
    setup_logging(cli.run_name)
    logging.info("Torch: %s", torch.__version__)
    logging.info("MPS available: %s", torch.backends.mps.is_available())
    logging.info("Loading base model: %s", BASE_MODEL)

    model = SparseEncoder(
        BASE_MODEL,
        model_card_data=SparseEncoderModelCardData(
            language="en",
            license="apache-2.0",
            model_name=f"Financial filings sparse encoder V2 ({cli.run_name})",
        ),
    )
    model.max_seq_length = cli.max_seq_length
    logging.info("Training device: %s", model.device)
    logging.info("max_seq_length: %s", model.max_seq_length)

    raw_train = load_dataset(DATASET, CONFIG, split="train")
    raw_test = load_dataset(DATASET, CONFIG, split="test")
    train_dataset = clean_triplets(raw_train, cli.train_size)
    eval_dataset = clean_triplets(raw_test, cli.eval_size)
    logging.info("train rows: %s | eval rows: %s", len(train_dataset), len(eval_dataset))

    logging.info("Baseline domain eval:")
    baseline = domain_triplet_eval(model, eval_dataset, cli.eval_batch_size)
    logging.info("BASELINE: %s", json.dumps(baseline, sort_keys=True))

    loss = SpladeLoss(
        model=model,
        loss=SparseMultipleNegativesRankingLoss(model=model),
        query_regularizer_weight=cli.query_reg,
        document_regularizer_weight=cli.doc_reg,
    )
    # Queries go through the query route; positives and negatives are both
    # documents.
    data_collator = SparseEncoderDataCollator(
        preprocess_fn=model.preprocess,
        router_mapping={
            "query": "query",
            "positive": "document",
            "negative": "document",
        },
    )
    args = SparseEncoderTrainingArguments(
        output_dir=str(output_dir),
        max_steps=cli.max_steps,
        num_train_epochs=1,
        per_device_train_batch_size=cli.batch_size,
        learning_rate=2e-5,
        weight_decay=0.01,
        # NOTE(review): a fractional warmup_steps is presumably meant as a
        # ratio of total steps; confirm the installed transformers release
        # interprets values < 1 that way.
        warmup_steps=0.1,
        lr_scheduler_type="linear",
        bf16=False,
        fp16=False,
        batch_sampler=BatchSamplers.NO_DUPLICATES,
        eval_strategy="no",
        save_strategy="no",
        logging_steps=10,
        logging_first_step=True,
        dataloader_pin_memory=False,
        report_to="none",
        run_name=cli.run_name,
        seed=12,
    )
    trainer = SparseEncoderTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=loss,
        data_collator=data_collator,
    )
    trainer.train()

    logging.info("Post-training domain eval:")
    score = domain_triplet_eval(model, eval_dataset, cli.eval_batch_size)
    logging.info("SCORE: %s", json.dumps(score, sort_keys=True))

    delta = score["accuracy"] - baseline["accuracy"]
    if delta >= 0.005:
        verdict = "WIN"
    elif delta >= 0:
        verdict = "MARGINAL"
    else:
        verdict = "REGRESSION"
    logging.info(
        "VERDICT: %s | score=%.4f | baseline=%.4f | delta=%+.4f | query_active=%.1f doc_active=%.1f",
        verdict,
        score["accuracy"],
        baseline["accuracy"],
        delta,
        score["query_active_dims"],
        score["positive_active_dims"],
    )

    final_dir = output_dir / "final"
    model.save_pretrained(str(final_dir))
    logging.info("Saved final model to %s", final_dir)


if __name__ == "__main__":
    # Set before any tokenizer work; setdefault leaves an explicit user
    # override untouched.
    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
    main()