"""Build T5 rewrite fine-tuning splits (caption -> tag list) from CaptionEmporium JSONL(.gz).

Each input row with a non-empty ``tags_ground_truth_categorized`` field yields up to one
example per caption field in ``CAPTION_FIELDS``. Rows are assigned to train/val/test by
hashing their ``id``, so every caption of a row lands in the same split. Each split is
optionally down-sampled with reservoir sampling, shuffled, and written to
``<output-dir>/<split>.jsonl``; a ``meta.json`` summary is written alongside.
"""

from __future__ import annotations

import argparse
import gzip
import hashlib
import json
import random
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple

REPO_ROOT = Path(__file__).resolve().parents[1]
DEFAULT_INPUT = (
    REPO_ROOT / "data" / "external" / "caption_emporium" / "furry-e621-safe-llama3.2-11b" / "train.jsonl.gz"
)
DEFAULT_OUTPUT_DIR = REPO_ROOT / "data" / "external" / "caption_emporium" / "t5_rewrite_splits"
CAPTION_FIELDS = ("caption_short", "caption_medium", "caption_long")

# Number of hash buckets used by _split_name; split fractions resolve to 1/10000 granularity.
_SPLIT_BUCKETS = 10000


def _iter_jsonl(path: Path) -> Iterable[Dict[str, Any]]:
    """Yield one decoded JSON value per non-blank line of *path*.

    Files whose suffix is ``.gz`` (case-insensitive) are decompressed with gzip; all other
    files are read as plain UTF-8 text.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    opener = gzip.open if path.suffix.lower() == ".gz" else open
    with opener(path, "rt", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if line:
                yield json.loads(line)


def _canonicalize_tag(tag: str) -> str:
    """Normalize a tag.

    Whitespace runs collapse to ``_``, the text is lowercased, and the escaped parentheses
    ``\\(``/``\\)`` are unescaped. A falsy *tag* yields ``""``.
    """
    t = " ".join(str(tag or "").split()).lower()
    return t.replace(" ", "_").replace("\\(", "(").replace("\\)", ")")


def _flatten_tags(raw: Any) -> List[str]:
    """Flatten a ``{category: [tag, ...]}`` mapping into a sorted list of unique canonical tags.

    *raw* may be the mapping itself or a JSON string encoding it. Anything that does not
    decode to a dict yields ``[]``. Category values that are not lists are ignored, as are
    ``None`` tags and tags that canonicalize to the empty string.
    """
    cats = raw
    if isinstance(raw, str):
        try:
            cats = json.loads(raw)
        except json.JSONDecodeError:
            return []
    if not isinstance(cats, dict):
        return []
    out = set()
    for vals in cats.values():
        if not isinstance(vals, list):
            continue
        for tag in vals:
            if tag is None:
                # JSON null: str(None) would otherwise become the bogus tag "none".
                continue
            ct = _canonicalize_tag(str(tag))
            if ct:
                out.add(ct)
    return sorted(out)


def _split_name(sample_id: Any, val_frac: float, test_frac: float) -> str:
    """Deterministically assign *sample_id* to ``"train"``, ``"val"`` or ``"test"``.

    The id's string form is hashed with BLAKE2b (8-byte digest) into one of 10000 buckets.
    The lowest ``test_frac`` share of buckets is test, the next ``val_frac`` share is val,
    and the remainder is train. The result depends only on the id and the fractions, never
    on row order or the RNG seed.
    """
    digest = hashlib.blake2b(str(sample_id).encode("utf-8"), digest_size=8).hexdigest()
    bucket = int(digest, 16) % _SPLIT_BUCKETS
    test_cut = int(round(test_frac * _SPLIT_BUCKETS))
    val_cut = test_cut + int(round(val_frac * _SPLIT_BUCKETS))
    if bucket < test_cut:
        return "test"
    if bucket < val_cut:
        return "val"
    return "train"


def _reservoir_add(
    arr: List[Dict[str, Any]],
    item: Dict[str, Any],
    cap: Optional[int],
    seen_count: int,
    rng: random.Random,
) -> None:
    """Offer *item* to the reservoir *arr* (mutated in place), Algorithm R style.

    Args:
        arr: Current reservoir contents.
        item: Candidate to insert.
        cap: Maximum reservoir size. ``None`` keeps every item; ``<= 0`` keeps nothing.
        seen_count: 1-based number of items offered so far, *including* this one.
        rng: Source of randomness.
    """
    if cap is None:
        arr.append(item)
        return
    if cap <= 0:
        return
    if len(arr) < cap:
        arr.append(item)
        return
    # Replace a random slot with probability cap / seen_count, which keeps every item
    # offered so far in the reservoir with equal probability.
    j = rng.randint(0, seen_count - 1)
    if j < cap:
        arr[j] = item


def _write_jsonl(path: Path, rows: List[Dict[str, Any]]) -> None:
    """Write *rows* to *path* as UTF-8 JSON Lines (one object per line), overwriting it."""
    with path.open("w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")


def _resolve_repo_path(path: Path) -> Path:
    """Return *path* unchanged if absolute; otherwise resolve it against REPO_ROOT (not the CWD)."""
    return path if path.is_absolute() else (REPO_ROOT / path).resolve()


def main() -> int:
    """CLI entry point.

    Builds the splits, writes ``{train,val,test}.jsonl`` and ``meta.json``, and prints the
    meta summary.

    Returns:
        0 on success.

    Raises:
        FileNotFoundError: if the input dataset does not exist.
        SystemExit: (via argparse) on invalid command-line arguments.
    """
    ap = argparse.ArgumentParser(description="Build T5 rewrite fine-tuning splits from CaptionEmporium JSONL(.gz)")
    ap.add_argument("--input", type=Path, default=DEFAULT_INPUT)
    ap.add_argument("--output-dir", type=Path, default=DEFAULT_OUTPUT_DIR)
    ap.add_argument("--val-frac", type=float, default=0.01)
    ap.add_argument("--test-frac", type=float, default=0.01)
    ap.add_argument("--max-train", type=int, default=60000, help="Reservoir cap for train split (0 disables)")
    ap.add_argument("--max-val", type=int, default=3000, help="Reservoir cap for val split (0 disables)")
    ap.add_argument("--max-test", type=int, default=3000, help="Reservoir cap for test split (0 disables)")
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--task-prefix", type=str, default="caption_to_tags:")
    args = ap.parse_args()

    # Reject settings that would silently yield nonsense splits: more than 100% held out,
    # or a negative cap that empties a split. The comparisons are written so that NaN
    # fractions also fail.
    if not (0.0 <= args.val_frac and 0.0 <= args.test_frac and args.val_frac + args.test_frac <= 1.0):
        ap.error("--val-frac and --test-frac must be non-negative and sum to at most 1")
    for cap_flag in ("max_train", "max_val", "max_test"):
        if getattr(args, cap_flag) < 0:
            ap.error(f"--{cap_flag.replace('_', '-')} must be >= 0 (0 disables the cap)")

    input_path = _resolve_repo_path(args.input)
    if not input_path.is_file():
        raise FileNotFoundError(f"Input dataset not found: {input_path}")
    output_dir = _resolve_repo_path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    rng = random.Random(args.seed)
    split_rows: Dict[str, List[Dict[str, Any]]] = {"train": [], "val": [], "test": []}
    split_seen = {"train": 0, "val": 0, "test": 0}
    # A CLI value of 0 means "no cap", which _reservoir_add expresses as None.
    split_caps: Dict[str, Optional[int]] = {
        "train": None if args.max_train == 0 else args.max_train,
        "val": None if args.max_val == 0 else args.max_val,
        "test": None if args.max_test == 0 else args.max_test,
    }

    rows_total = 0
    rows_with_tags = 0
    examples_total = 0
    prefix = (args.task_prefix or "").strip()

    for obj in _iter_jsonl(input_path):
        rows_total += 1
        if not isinstance(obj, dict):
            # A bare JSON array/scalar line has no fields to read; skip it like a tagless row.
            continue
        sid = obj.get("id")
        if sid is None:
            # Missing or null id: fall back to the 1-based row number. A literal null would
            # otherwise hash as "None" and send every such row to the same split.
            sid = rows_total
        tags = _flatten_tags(obj.get("tags_ground_truth_categorized"))
        if not tags:
            continue
        rows_with_tags += 1
        target_text = ", ".join(tags)
        # Split per row (not per caption) so captions of one image never straddle splits.
        split = _split_name(sid, args.val_frac, args.test_frac)

        for field in CAPTION_FIELDS:
            caption = str(obj.get(field, "") or "").strip()
            if not caption:
                continue
            source_text = f"{prefix} {caption}".strip() if prefix else caption
            rec = {
                "id": sid,
                "caption_field": field,
                "source_text": source_text,
                "target_text": target_text,
            }
            split_seen[split] += 1
            _reservoir_add(split_rows[split], rec, split_caps[split], split_seen[split], rng)
            examples_total += 1

    for name in ("train", "val", "test"):
        # Reservoir order is biased toward early rows; shuffle before writing.
        rng.shuffle(split_rows[name])
        _write_jsonl(output_dir / f"{name}.jsonl", split_rows[name])

    meta = {
        "input_path": str(input_path),
        "output_dir": str(output_dir),
        "seed": args.seed,
        "val_frac": args.val_frac,
        "test_frac": args.test_frac,
        "max_train": args.max_train,
        "max_val": args.max_val,
        "max_test": args.max_test,
        "task_prefix": prefix,
        "rows_total": rows_total,
        "rows_with_tags": rows_with_tags,
        "examples_total_pre_cap": examples_total,
        "examples_written": {k: len(v) for k, v in split_rows.items()},
        "examples_seen_by_split_pre_cap": split_seen,
        "caption_fields": list(CAPTION_FIELDS),
    }
    with (output_dir / "meta.json").open("w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)
    print(json.dumps(meta, ensure_ascii=False, indent=2))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())