Spaces:
Running
Running
| from __future__ import annotations | |
| import argparse | |
| import gzip | |
| import hashlib | |
| import json | |
| import random | |
| from pathlib import Path | |
| from typing import Any, Dict, Iterable, List, Optional, Tuple | |
# Repository root: two levels above this script file (the script's parent directory's parent).
REPO_ROOT = Path(__file__).resolve().parents[1]

# Default source dataset (gzip-compressed JSONL) inside the repository's data tree.
DEFAULT_INPUT = REPO_ROOT.joinpath(
    "data",
    "external",
    "caption_emporium",
    "furry-e621-safe-llama3.2-11b",
    "train.jsonl.gz",
)
# Default destination for the generated split files and meta.json.
DEFAULT_OUTPUT_DIR = REPO_ROOT.joinpath(
    "data",
    "external",
    "caption_emporium",
    "t5_rewrite_splits",
)

# Caption columns read from each input row, shortest to longest.
CAPTION_FIELDS = ("caption_short", "caption_medium", "caption_long")
| def _iter_jsonl(path: Path) -> Iterable[Dict[str, Any]]: | |
| if path.suffix == ".gz": | |
| with gzip.open(path, "rt", encoding="utf-8") as f: | |
| for line in f: | |
| line = line.strip() | |
| if not line: | |
| continue | |
| yield json.loads(line) | |
| else: | |
| with path.open("r", encoding="utf-8") as f: | |
| for line in f: | |
| line = line.strip() | |
| if not line: | |
| continue | |
| yield json.loads(line) | |
| def _canonicalize_tag(tag: str) -> str: | |
| t = " ".join(str(tag or "").strip().split()).lower() | |
| return t.replace(" ", "_").replace("\\(", "(").replace("\\)", ")") | |
def _flatten_tags(raw: Any) -> List[str]:
    """Collect the unique canonical tags from a categorized-tags value.

    ``raw`` may be a dict mapping category name -> list of tags, or a JSON
    string that decodes to such a dict. Category names are discarded; only
    the tags themselves are kept, canonicalized with ``_canonicalize_tag``.

    Args:
        raw: The categorized tags value of one input row.

    Returns:
        A sorted, de-duplicated list of canonical tags. The list is empty when
        ``raw`` is an undecodable string, does not decode to a dict, or holds
        no usable tags.
    """
    cats = raw
    if isinstance(raw, str):
        try:
            cats = json.loads(raw)
        except json.JSONDecodeError:
            return []
    if not isinstance(cats, dict):
        return []
    out = set()
    for vals in cats.values():
        # Only list-valued categories carry tags; anything else is ignored.
        if not isinstance(vals, list):
            continue
        for tag in vals:
            # A JSON null would otherwise be stringified to "None" and survive
            # canonicalization as a spurious "none" tag in the target text.
            if tag is None:
                continue
            ct = _canonicalize_tag(str(tag))
            if ct:
                out.add(ct)
    # Sorting makes the target text deterministic regardless of input order.
    return sorted(out)
| def _split_name(sample_id: Any, val_frac: float, test_frac: float) -> str: | |
| key = str(sample_id).encode("utf-8") | |
| digest = hashlib.blake2b(key, digest_size=8).hexdigest() | |
| bucket = int(digest, 16) % 10000 | |
| test_cut = int(round(test_frac * 10000)) | |
| val_cut = test_cut + int(round(val_frac * 10000)) | |
| if bucket < test_cut: | |
| return "test" | |
| if bucket < val_cut: | |
| return "val" | |
| return "train" | |
| def _reservoir_add( | |
| arr: List[Dict[str, Any]], | |
| item: Dict[str, Any], | |
| cap: Optional[int], | |
| seen_count: int, | |
| rng: random.Random, | |
| ) -> None: | |
| if cap is None: | |
| arr.append(item) | |
| return | |
| if cap <= 0: | |
| return | |
| if len(arr) < cap: | |
| arr.append(item) | |
| return | |
| j = rng.randint(0, seen_count - 1) | |
| if j < cap: | |
| arr[j] = item | |
| def _write_jsonl(path: Path, rows: List[Dict[str, Any]]) -> None: | |
| with path.open("w", encoding="utf-8") as f: | |
| for row in rows: | |
| f.write(json.dumps(row, ensure_ascii=False) + "\n") | |
def _resolve_repo_path(path: Path) -> Path:
    """Return ``path`` if absolute; otherwise resolve it against ``REPO_ROOT`` (not the CWD)."""
    return path if path.is_absolute() else (REPO_ROOT / path).resolve()


def main() -> int:
    """Build train/val/test JSONL splits for T5 caption-to-tags fine-tuning.

    For every input row with ground-truth tags, each non-empty caption field
    in ``CAPTION_FIELDS`` becomes one example. The example's ``source_text``
    is the optional task prefix followed by the caption. Its ``target_text``
    is the row's sorted canonical tags, joined with ", ".

    All examples of a row go to the same split, chosen by hashing the row id.
    Each split is capped with reservoir sampling and then shuffled.
    ``train.jsonl``, ``val.jsonl``, ``test.jsonl`` and ``meta.json`` are
    written to the output directory, and the metadata is printed to stdout.

    Relative ``--input`` / ``--output-dir`` values are resolved against the
    repository root, not the current working directory.

    Returns:
        Process exit status (always 0 on success).

    Raises:
        FileNotFoundError: If the input dataset does not exist.
        SystemExit: Raised by argparse for invalid command-line arguments.
    """
    ap = argparse.ArgumentParser(description="Build T5 rewrite fine-tuning splits from CaptionEmporium JSONL(.gz)")
    ap.add_argument("--input", type=Path, default=DEFAULT_INPUT)
    ap.add_argument("--output-dir", type=Path, default=DEFAULT_OUTPUT_DIR)
    ap.add_argument("--val-frac", type=float, default=0.01)
    ap.add_argument("--test-frac", type=float, default=0.01)
    ap.add_argument("--max-train", type=int, default=60000, help="Reservoir cap for train split (0 disables)")
    ap.add_argument("--max-val", type=int, default=3000, help="Reservoir cap for val split (0 disables)")
    ap.add_argument("--max-test", type=int, default=3000, help="Reservoir cap for test split (0 disables)")
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--task-prefix", type=str, default="caption_to_tags:")
    args = ap.parse_args()

    # Reject settings that would otherwise silently produce empty or truncated
    # splits. Examples: negative fractions, fractions summing past 1, or
    # negative caps, which the reservoir helper treats as "keep nothing".
    for flag, frac in (("--val-frac", args.val_frac), ("--test-frac", args.test_frac)):
        if not 0.0 <= frac <= 1.0:  # also rejects NaN
            ap.error(f"{flag} must be within [0, 1], got {frac}")
    if args.val_frac + args.test_frac > 1.0 + 1e-9:
        ap.error(f"--val-frac + --test-frac must not exceed 1, got {args.val_frac + args.test_frac}")
    for flag, cap in (("--max-train", args.max_train), ("--max-val", args.max_val), ("--max-test", args.max_test)):
        if cap < 0:
            ap.error(f"{flag} must be >= 0 (0 disables the cap), got {cap}")

    input_path = _resolve_repo_path(args.input)
    if not input_path.is_file():
        raise FileNotFoundError(f"Input dataset not found: {input_path}")
    output_dir = _resolve_repo_path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    rng = random.Random(args.seed)
    split_rows: Dict[str, List[Dict[str, Any]]] = {"train": [], "val": [], "test": []}
    split_seen = {"train": 0, "val": 0, "test": 0}
    # A CLI cap of 0 means "no cap", which the reservoir helper expresses as None.
    split_caps: Dict[str, Optional[int]] = {
        "train": None if args.max_train == 0 else args.max_train,
        "val": None if args.max_val == 0 else args.max_val,
        "test": None if args.max_test == 0 else args.max_test,
    }

    rows_total = 0
    rows_with_tags = 0
    examples_total = 0
    prefix = (args.task_prefix or "").strip()
    for obj in _iter_jsonl(input_path):
        rows_total += 1
        sid = obj.get("id")
        if sid is None:
            # Missing or null ids fall back to the 1-based row number.
            # Otherwise every null-id row would hash the same string "None"
            # and land in a single split.
            sid = rows_total
        tags = _flatten_tags(obj.get("tags_ground_truth_categorized"))
        if not tags:
            continue
        rows_with_tags += 1
        target_text = ", ".join(tags)
        # The split is chosen once per row, so one row's captions never
        # straddle train/val/test.
        split = _split_name(sid, args.val_frac, args.test_frac)
        for field in CAPTION_FIELDS:
            caption = str(obj.get(field, "") or "").strip()
            if not caption:
                continue
            # Both parts are already stripped and non-empty, so no outer strip is needed.
            source_text = f"{prefix} {caption}" if prefix else caption
            rec = {
                "id": sid,
                "caption_field": field,
                "source_text": source_text,
                "target_text": target_text,
            }
            split_seen[split] += 1
            _reservoir_add(
                split_rows[split],
                rec,
                split_caps[split],
                split_seen[split],
                rng,
            )
            examples_total += 1

    for name in ("train", "val", "test"):
        # Reservoir sampling keeps arbitrary slot order; shuffle before writing.
        rng.shuffle(split_rows[name])
        _write_jsonl(output_dir / f"{name}.jsonl", split_rows[name])

    meta = {
        "input_path": str(input_path),
        "output_dir": str(output_dir),
        "seed": args.seed,
        "val_frac": args.val_frac,
        "test_frac": args.test_frac,
        "max_train": args.max_train,
        "max_val": args.max_val,
        "max_test": args.max_test,
        "task_prefix": prefix,
        "rows_total": rows_total,
        "rows_with_tags": rows_with_tags,
        "examples_total_pre_cap": examples_total,
        "examples_written": {k: len(v) for k, v in split_rows.items()},
        "examples_seen_by_split_pre_cap": split_seen,
        "caption_fields": list(CAPTION_FIELDS),
    }
    with (output_dir / "meta.json").open("w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)
    print(json.dumps(meta, ensure_ascii=False, indent=2))
    return 0


if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status.
    raise SystemExit(main())