# Source: financial-filings-sparse-encoder-v1 / scripts / train_fin_sparse_encoder_v2.py
# Author: oneryalcin
# Commit: bf3d3f8 "Add financial filings sparse encoder v1" (verified)
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "sentence-transformers[train]==5.5.0",
# "datasets>=2.19.0",
# "accelerate>=0.26.0",
# ]
# ///
"""Financial filings SparseEncoder V2.
Design choices:
- stable Apache/OpenSearch base: doc-v2-distill, not brittle custom v3-gte;
- clean dataset shape: query, positive, negative only;
- explicit Router mapping: query -> query route, positive/negative -> document route;
- stronger document FLOPS regularization than the old run to avoid dense docs;
- custom CPU evaluator because sparse materialization is not safe on MPS here.
"""
from __future__ import annotations
import argparse
import ast
import json
import logging
import os
from pathlib import Path
import torch
from datasets import Dataset, load_dataset
from sentence_transformers import (
SparseEncoder,
SparseEncoderModelCardData,
SparseEncoderTrainer,
SparseEncoderTrainingArguments,
)
from sentence_transformers.base.sampler import BatchSamplers
from sentence_transformers.sparse_encoder.data_collator import SparseEncoderDataCollator
from sentence_transformers.sparse_encoder.losses import SparseMultipleNegativesRankingLoss, SpladeLoss
# Hugging Face dataset repository and config that provide the training/test splits.
DATASET = "oneryalcin/financial-filings-sparse-retrieval-training"
CONFIG = "combined"
# Checkpoint that fine-tuning starts from.
BASE_MODEL = "opensearch-project/opensearch-neural-sparse-encoding-doc-v2-distill"
# Default run name; also used as the default output sub-directory name.
RUN_NAME = "fin-sparse-encoder-doc-v2-clean-router"
OUTPUT_DIR = Path("models", RUN_NAME)
# Directory receiving one ``<run_name>.log`` file per run.
LOG_DIR = Path("logs")
def setup_logging(run_name: str) -> None:
    """Configure root logging at INFO to the console and to ``LOG_DIR/<run_name>.log``.

    ``force=True`` removes any handlers already attached to the root logger, so a
    second call reconfigures logging instead of stacking duplicate handlers.
    Chatty HTTP / hub / filesystem loggers are capped at WARNING.
    """
    LOG_DIR.mkdir(exist_ok=True)
    log_path = LOG_DIR / f"{run_name}.log"
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(), logging.FileHandler(log_path)],
        force=True,
    )
    quiet_loggers = ("httpx", "httpcore", "huggingface_hub", "urllib3", "filelock", "fsspec")
    for logger_name in quiet_loggers:
        logging.getLogger(logger_name).setLevel(logging.WARNING)
def first_negative(value) -> str:
    """Return the first hard negative from a ``negatives`` cell, or ``""`` if there is none.

    Accepted cell shapes:
      * a list or tuple of strings -> its first element (``""`` when empty);
      * the string repr of such a list, e.g. ``"['a', 'b']"`` -> parsed with
        ``ast.literal_eval`` (literals only, no code execution) and handled as above;
      * a quoted scalar string such as ``"'text'"`` -> that whole string (not its
        first character);
      * ``None`` or an empty/whitespace-only string -> ``""``, so callers can drop the row.

    Raises:
        ValueError: if a string cell is not a valid Python literal.
        TypeError: for any other cell type (after parsing, for string cells).
    """
    if isinstance(value, str):
        if not value.strip():
            return ""
        try:
            parsed = ast.literal_eval(value)
        except (ValueError, SyntaxError) as err:
            raise ValueError(f"Could not parse negatives string: {value[:80]!r}") from err
        if isinstance(parsed, str):
            # A single quoted negative, not a sequence of characters.
            return parsed
        value = parsed
    if value is None:
        # Missing cell: treat like an empty list so the row is skipped, not fatal.
        return ""
    if isinstance(value, (list, tuple)):
        return value[0] if value else ""
    raise TypeError(f"Unsupported negatives value: {type(value)}")
def clean_triplets(ds, limit: int | None = None) -> Dataset:
    """Flatten raw rows into a Dataset with ``query``, ``positive`` and ``negative`` columns.

    Only the first entry of each row's ``negatives`` cell is kept (via
    ``first_negative``), and rows whose first negative is empty are dropped.
    When *limit* is given, the first ``limit`` raw rows are taken *before*
    filtering, so the result may hold fewer than ``limit`` rows.
    """
    if limit is not None:
        ds = ds.select(range(min(limit, len(ds))))
    columns: dict[str, list[str]] = {"query": [], "positive": [], "negative": []}
    rows = zip(ds["query"], ds["positive"], ds["negatives"], strict=True)
    for query_text, positive_text, negatives_cell in rows:
        negative_text = first_negative(negatives_cell)
        if negative_text:
            columns["query"].append(query_text)
            columns["positive"].append(positive_text)
            columns["negative"].append(negative_text)
    return Dataset.from_dict(columns)
def dense(x: torch.Tensor) -> torch.Tensor:
    """Return *x* as a strided (dense) tensor; already-dense inputs are returned as-is.

    Tests ``layout`` rather than ``is_sparse``: ``is_sparse`` is only True for the
    COO layout, so CSR/CSC/BSR tensors would otherwise be returned still sparse.
    """
    return x if x.layout == torch.strided else x.to_dense()
def encode_query(model: SparseEncoder, texts: list[str], batch_size: int) -> torch.Tensor:
    """Encode *texts* with ``model.encode_query`` and return them as a dense tensor."""
    embeddings = model.encode_query(
        texts,
        batch_size=batch_size,
        convert_to_tensor=True,
        show_progress_bar=False,
    )
    return dense(embeddings)
def encode_document(model: SparseEncoder, texts: list[str], batch_size: int) -> torch.Tensor:
    """Encode *texts* with ``model.encode_document`` and return them as a dense tensor."""
    embeddings = model.encode_document(
        texts,
        batch_size=batch_size,
        convert_to_tensor=True,
        show_progress_bar=False,
    )
    return dense(embeddings)
def domain_triplet_eval(model: SparseEncoder, eval_ds: Dataset, batch_size: int) -> dict[str, float]:
    """Score *model* on held-out (query, positive, negative) triplets, encoding on CPU.

    For each row the margin is ``dot(q, pos) - dot(q, neg)`` over densified
    embeddings. The model is moved to CPU for encoding and is always moved back
    to its original device afterwards, even if encoding raises.

    Args:
        model: encoder exposing ``encode_query`` / ``encode_document`` and ``.device``.
        eval_ds: dataset with ``query``, ``positive`` and ``negative`` columns.
        batch_size: encoding batch size.

    Returns:
        ``accuracy`` (fraction of rows with margin > 0), ``mean_margin``,
        ``median_margin`` (``torch.median`` picks the lower middle value for an
        even row count), and the mean count of non-zero dimensions per
        query / positive / negative embedding.

    Raises:
        ValueError: if *eval_ds* has no rows (the metrics would be undefined).
    """
    if len(eval_ds) == 0:
        raise ValueError("domain_triplet_eval: eval_ds is empty; need at least one triplet")
    queries = list(eval_ds["query"])
    positives = list(eval_ds["positive"])
    negatives = list(eval_ds["negative"])

    original_device = str(model.device)
    model.to("cpu")
    try:
        with torch.no_grad():
            query_emb = encode_query(model, queries, batch_size)
            pos_emb = encode_document(model, positives, batch_size)
            neg_emb = encode_document(model, negatives, batch_size)
    finally:
        # Restore even on failure so a crashed eval does not leave the model on CPU.
        if original_device != "cpu":
            model.to(original_device)

    # Per-row dot products; a positive margin means the positive outranks the negative.
    margins = (query_emb * pos_emb).sum(dim=1) - (query_emb * neg_emb).sum(dim=1)
    return {
        "accuracy": (margins > 0).float().mean().item(),
        "mean_margin": margins.float().mean().item(),
        "median_margin": margins.float().median().item(),
        "query_active_dims": (query_emb != 0).sum(dim=1).float().mean().item(),
        "positive_active_dims": (pos_emb != 0).sum(dim=1).float().mean().item(),
        "negative_active_dims": (neg_emb != 0).sum(dim=1).float().mean().item(),
    }
def main() -> None:
    """CLI entry point: baseline eval, SPLADE fine-tune, post-training eval, save.

    1. Parse CLI flags and configure logging for ``--run-name``.
    2. Load ``BASE_MODEL`` and the train/test splits of ``DATASET``/``CONFIG``,
       reduced to (query, positive, negative) triplets by ``clean_triplets``.
    3. Score the untouched model with ``domain_triplet_eval`` (baseline).
    4. Train with ``SpladeLoss`` wrapping ``SparseMultipleNegativesRankingLoss``;
       the collator routes ``query`` to the query route and ``positive``/``negative``
       to the document route.
    5. Re-score, log a WIN / MARGINAL / REGRESSION verdict from the accuracy delta,
       and save the model under ``models/<run-name>/final``.
    """
    parser = argparse.ArgumentParser(description="Fine-tune a financial-filings SparseEncoder.")
    parser.add_argument("--train-size", type=int, default=10_000)
    parser.add_argument("--eval-size", type=int, default=500)
    parser.add_argument("--max-steps", type=int, default=100)
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--eval-batch-size", type=int, default=16)
    parser.add_argument("--max-seq-length", type=int, default=384)
    parser.add_argument("--query-reg", type=float, default=1e-4)
    parser.add_argument("--doc-reg", type=float, default=8e-5)
    # Previously hard-coded; exposed with the same defaults so existing runs are unchanged.
    parser.add_argument("--learning-rate", type=float, default=2e-5)
    parser.add_argument("--seed", type=int, default=12)
    parser.add_argument("--run-name", default=RUN_NAME)
    cli = parser.parse_args()
    # Fail fast on sizes that would otherwise crash deep inside datasets/transformers.
    # (--max-steps is excluded: non-positive values have their own meaning in HF Trainer.)
    for field in ("train_size", "eval_size", "batch_size", "eval_batch_size", "max_seq_length"):
        if getattr(cli, field) <= 0:
            parser.error(f"--{field.replace('_', '-')} must be a positive integer")

    output_dir = Path("models") / cli.run_name
    setup_logging(cli.run_name)
    logging.info("Torch: %s", torch.__version__)
    logging.info("MPS available: %s", torch.backends.mps.is_available())
    logging.info("Loading base model: %s", BASE_MODEL)
    model = SparseEncoder(
        BASE_MODEL,
        model_card_data=SparseEncoderModelCardData(
            language="en",
            license="apache-2.0",
            model_name=f"Financial filings sparse encoder V2 ({cli.run_name})",
        ),
    )
    model.max_seq_length = cli.max_seq_length
    logging.info("Training device: %s", model.device)
    logging.info("max_seq_length: %s", model.max_seq_length)

    raw_train = load_dataset(DATASET, CONFIG, split="train")
    raw_test = load_dataset(DATASET, CONFIG, split="test")
    train_dataset = clean_triplets(raw_train, cli.train_size)
    eval_dataset = clean_triplets(raw_test, cli.eval_size)
    logging.info("train rows: %s | eval rows: %s", len(train_dataset), len(eval_dataset))

    logging.info("Baseline domain eval:")
    baseline = domain_triplet_eval(model, eval_dataset, cli.eval_batch_size)
    logging.info("BASELINE: %s", json.dumps(baseline, sort_keys=True))

    loss = SpladeLoss(
        model=model,
        loss=SparseMultipleNegativesRankingLoss(model=model),
        query_regularizer_weight=cli.query_reg,
        document_regularizer_weight=cli.doc_reg,
    )
    # Column name -> Router route: queries use the query route, both passages the document route.
    data_collator = SparseEncoderDataCollator(
        preprocess_fn=model.preprocess,
        router_mapping={
            "query": "query",
            "positive": "document",
            "negative": "document",
        },
    )
    args = SparseEncoderTrainingArguments(
        output_dir=str(output_dir),
        # When max_steps > 0 it takes precedence over num_train_epochs.
        max_steps=cli.max_steps,
        num_train_epochs=1,
        per_device_train_batch_size=cli.batch_size,
        learning_rate=cli.learning_rate,
        weight_decay=0.01,
        # NOTE(review): a float < 1 is meant as a warmup *ratio*; confirm the pinned
        # transformers version accepts float warmup_steps (older ones use warmup_ratio).
        warmup_steps=0.1,
        lr_scheduler_type="linear",
        bf16=False,
        fp16=False,
        # Avoid duplicate texts within a batch, which would act as false in-batch negatives for MNRL.
        batch_sampler=BatchSamplers.NO_DUPLICATES,
        eval_strategy="no",
        save_strategy="no",
        logging_steps=10,
        logging_first_step=True,
        dataloader_pin_memory=False,
        report_to="none",
        run_name=cli.run_name,
        seed=cli.seed,
    )
    trainer = SparseEncoderTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=loss,
        data_collator=data_collator,
    )
    trainer.train()

    logging.info("Post-training domain eval:")
    score = domain_triplet_eval(model, eval_dataset, cli.eval_batch_size)
    logging.info("SCORE: %s", json.dumps(score, sort_keys=True))
    # Absolute accuracy gain required to call the run a WIN rather than MARGINAL.
    win_threshold = 0.005
    delta = score["accuracy"] - baseline["accuracy"]
    if delta >= win_threshold:
        verdict = "WIN"
    elif delta >= 0:
        verdict = "MARGINAL"
    else:
        verdict = "REGRESSION"
    logging.info(
        "VERDICT: %s | score=%.4f | baseline=%.4f | delta=%+.4f | query_active=%.1f doc_active=%.1f",
        verdict,
        score["accuracy"],
        baseline["accuracy"],
        delta,
        score["query_active_dims"],
        score["positive_active_dims"],
    )

    final_dir = output_dir / "final"
    model.save_pretrained(str(final_dir))
    logging.info("Saved final model to %s", final_dir)
if __name__ == "__main__":
    # Respect an explicit user setting; otherwise turn off HF tokenizers parallelism
    # before any training work starts.
    if "TOKENIZERS_PARALLELISM" not in os.environ:
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
    main()