Prompt_Squirrel_RAG / scripts /build_t5_rewrite_dataset.py
Food Desert
Roll out T5 rewrite updates, tooling, docs, and artifact ignore rules
34c53b5
Raw
History Blame Contribute Delete
6.58 kB
from __future__ import annotations
import argparse
import gzip
import hashlib
import json
import random
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple
# Repository root: the parent of the directory holding this script
# (presumably <repo>/scripts/ -- confirm if the script is ever moved).
REPO_ROOT = Path(__file__).resolve().parents[1]
# Common parent of the CaptionEmporium download and the generated splits.
_CAPTION_EMPORIUM_DIR = REPO_ROOT / "data" / "external" / "caption_emporium"
# Default gzipped JSONL dump read by main().
DEFAULT_INPUT = _CAPTION_EMPORIUM_DIR / "furry-e621-safe-llama3.2-11b" / "train.jsonl.gz"
# Default destination for the train/val/test JSONL files and meta.json.
DEFAULT_OUTPUT_DIR = _CAPTION_EMPORIUM_DIR / "t5_rewrite_splits"
# Caption columns that each become a separate (source_text, target_text) example, in this order.
CAPTION_FIELDS = ("caption_short", "caption_medium", "caption_long")
def _iter_jsonl(path: Path) -> Iterable[Dict[str, Any]]:
if path.suffix == ".gz":
with gzip.open(path, "rt", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
yield json.loads(line)
else:
with path.open("r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
yield json.loads(line)
def _canonicalize_tag(tag: str) -> str:
t = " ".join(str(tag or "").strip().split()).lower()
return t.replace(" ", "_").replace("\\(", "(").replace("\\)", ")")
def _flatten_tags(raw: Any) -> List[str]:
    """Collect the unique canonical tags from a category -> tag-list mapping.

    Args:
        raw: A dict whose values are lists of tags, or a JSON string encoding
            such a dict. Undecodable JSON or any non-dict value yields no tags;
            dict values that are not lists are ignored, as are ``None`` entries.

    Returns:
        Sorted, de-duplicated tags, each normalized by ``_canonicalize_tag``;
        tags that canonicalize to an empty string are dropped.
    """
    cats = raw
    if isinstance(raw, str):
        try:
            cats = json.loads(raw)
        except json.JSONDecodeError:
            # A malformed payload is treated as "no tags" rather than aborting the run.
            return []
    if not isinstance(cats, dict):
        return []
    out = set()
    for vals in cats.values():
        if not isinstance(vals, list):
            continue
        for tag in vals:
            # JSON nulls would otherwise turn into the bogus tag "none" via str(None).
            if tag is None:
                continue
            ct = _canonicalize_tag(str(tag))
            if ct:
                out.add(ct)
    return sorted(out)
def _split_name(sample_id: Any, val_frac: float, test_frac: float) -> str:
key = str(sample_id).encode("utf-8")
digest = hashlib.blake2b(key, digest_size=8).hexdigest()
bucket = int(digest, 16) % 10000
test_cut = int(round(test_frac * 10000))
val_cut = test_cut + int(round(val_frac * 10000))
if bucket < test_cut:
return "test"
if bucket < val_cut:
return "val"
return "train"
def _reservoir_add(
arr: List[Dict[str, Any]],
item: Dict[str, Any],
cap: Optional[int],
seen_count: int,
rng: random.Random,
) -> None:
if cap is None:
arr.append(item)
return
if cap <= 0:
return
if len(arr) < cap:
arr.append(item)
return
j = rng.randint(0, seen_count - 1)
if j < cap:
arr[j] = item
def _write_jsonl(path: Path, rows: List[Dict[str, Any]]) -> None:
with path.open("w", encoding="utf-8") as f:
for row in rows:
f.write(json.dumps(row, ensure_ascii=False) + "\n")
def _resolve_repo_path(path: Path) -> Path:
    """Return ``path`` unchanged if absolute, else resolve it against ``REPO_ROOT``.

    Relative CLI paths are anchored at the repository root, not the current
    working directory.
    """
    return path if path.is_absolute() else (REPO_ROOT / path).resolve()


def main() -> int:
    """Build T5 ``caption -> tags`` rewrite splits from a CaptionEmporium JSONL(.gz) dump.

    Every row with at least one ground-truth tag contributes one example per
    non-empty field in ``CAPTION_FIELDS``. All examples from one row land in the
    same split, chosen by hashing the row ``id`` (or its 1-based row number when
    ``id`` is missing), so captions of one row never straddle splits. Each split
    is reservoir-sampled down to its cap, shuffled, and written to
    ``<output-dir>/{train,val,test}.jsonl``. A ``meta.json`` summary is also
    written and printed to stdout.

    Returns:
        Process exit code (0 on success). Invalid CLI arguments exit through
        ``argparse`` with status 2.

    Raises:
        FileNotFoundError: if the input dataset file does not exist.
    """
    ap = argparse.ArgumentParser(description="Build T5 rewrite fine-tuning splits from CaptionEmporium JSONL(.gz)")
    ap.add_argument("--input", type=Path, default=DEFAULT_INPUT)
    ap.add_argument("--output-dir", type=Path, default=DEFAULT_OUTPUT_DIR)
    ap.add_argument("--val-frac", type=float, default=0.01, help="Fraction of rows (by id hash) routed to val")
    ap.add_argument("--test-frac", type=float, default=0.01, help="Fraction of rows (by id hash) routed to test")
    ap.add_argument("--max-train", type=int, default=60000, help="Reservoir cap for train split (0 = no cap)")
    ap.add_argument("--max-val", type=int, default=3000, help="Reservoir cap for val split (0 = no cap)")
    ap.add_argument("--max-test", type=int, default=3000, help="Reservoir cap for test split (0 = no cap)")
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--task-prefix", type=str, default="caption_to_tags:")
    args = ap.parse_args()

    # Out-of-range fractions or negative caps used to silently yield empty or
    # degenerate splits; reject them up front instead.
    for flag, frac in (("--val-frac", args.val_frac), ("--test-frac", args.test_frac)):
        if not 0.0 <= frac <= 1.0:  # also rejects NaN
            ap.error(f"{flag} must be within [0, 1], got {frac}")
    frac_sum = args.val_frac + args.test_frac
    if frac_sum > 1.0 + 1e-9:  # small tolerance for float rounding
        ap.error(f"--val-frac + --test-frac must not exceed 1, got {frac_sum}")
    for flag, cap in (("--max-train", args.max_train), ("--max-val", args.max_val), ("--max-test", args.max_test)):
        if cap < 0:
            ap.error(f"{flag} must be >= 0 (0 = no cap), got {cap}")

    input_path = _resolve_repo_path(args.input)
    if not input_path.is_file():
        raise FileNotFoundError(f"Input dataset not found: {input_path}")
    output_dir = _resolve_repo_path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # One seeded RNG drives both reservoir replacement and the final shuffle,
    # so for a given input --seed determines which examples are kept and their order.
    rng = random.Random(args.seed)
    split_rows: Dict[str, List[Dict[str, Any]]] = {"train": [], "val": [], "test": []}
    split_seen = {"train": 0, "val": 0, "test": 0}
    split_caps: Dict[str, Optional[int]] = {
        "train": None if args.max_train == 0 else args.max_train,
        "val": None if args.max_val == 0 else args.max_val,
        "test": None if args.max_test == 0 else args.max_test,
    }
    rows_total = 0
    rows_with_tags = 0
    examples_total = 0
    prefix = (args.task_prefix or "").strip()

    for obj in _iter_jsonl(input_path):
        rows_total += 1
        # Fall back to the 1-based row number when a row carries no id.
        sid = obj.get("id", rows_total)
        tags = _flatten_tags(obj.get("tags_ground_truth_categorized"))
        if not tags:
            continue
        rows_with_tags += 1
        target_text = ", ".join(tags)
        # Split per row (not per caption) so sibling captions share a split.
        split = _split_name(sid, args.val_frac, args.test_frac)
        for field in CAPTION_FIELDS:
            caption = str(obj.get(field, "") or "").strip()
            if not caption:
                continue
            source_text = f"{prefix} {caption}".strip() if prefix else caption
            rec = {
                "id": sid,
                "caption_field": field,
                "source_text": source_text,
                "target_text": target_text,
            }
            # _reservoir_add expects the 1-based count including this record.
            split_seen[split] += 1
            _reservoir_add(
                split_rows[split],
                rec,
                split_caps[split],
                split_seen[split],
                rng,
            )
            examples_total += 1

    for name in ("train", "val", "test"):
        rng.shuffle(split_rows[name])
        _write_jsonl(output_dir / f"{name}.jsonl", split_rows[name])

    meta = {
        "input_path": str(input_path),
        "output_dir": str(output_dir),
        "seed": args.seed,
        "val_frac": args.val_frac,
        "test_frac": args.test_frac,
        "max_train": args.max_train,
        "max_val": args.max_val,
        "max_test": args.max_test,
        "task_prefix": prefix,
        "rows_total": rows_total,
        "rows_with_tags": rows_with_tags,
        "examples_total_pre_cap": examples_total,
        "examples_written": {k: len(v) for k, v in split_rows.items()},
        "examples_seen_by_split_pre_cap": split_seen,
        "caption_fields": list(CAPTION_FIELDS),
    }
    with (output_dir / "meta.json").open("w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)
    print(json.dumps(meta, ensure_ascii=False, indent=2))
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)