"""Text cleaning, synthetic corruption, dataset loading, and tokenization.""" from __future__ import annotations import csv import json import random import re from pathlib import Path from typing import Callable, Iterable import torch from .config import TrainConfig from .constants import * from .io_utils import * from .runtime import require_package def normalise_text(text: str) -> str: text = str(text or "") text = text.replace("\u2018", "'").replace("\u2019", "'") text = text.replace("\u201c", '"').replace("\u201d", '"') text = re.sub(r"\s+", " ", text).strip() return text def strip_instruction(text: str) -> str: text = normalise_text(text) lowered = text.lower() prefixes = [ "fix grammar:", "fix the grammar:", "correct grammar:", "correct spelling:", "fix spelling:", "rewrite:", "rewrite this:", "paraphrase:", "improve:", "improve this:", "make this sound better:", "clarify:", "punctuate:", "capitalize:", ] for prefix in prefixes: if lowered.startswith(prefix): return text[len(prefix):].strip(" -:") return text def valid_pair(noisy: str, clean: str) -> bool: noisy = normalise_text(noisy) clean = normalise_text(clean) if not noisy or not clean: return False if noisy == clean and len(clean) < 15: return False if len(noisy) < 4 or len(clean) < 4: return False if len(noisy) > 1400 or len(clean) > 1400: return False return True def add_pair(rows: list[tuple[str, str]], noisy: str, clean: str) -> None: noisy = strip_instruction(noisy) clean = strip_instruction(clean) if valid_pair(noisy, clean): rows.append((noisy, clean)) def dedupe_pairs(rows: Iterable[tuple[str, str]], limit: int | None = None) -> list[tuple[str, str]]: seen = set() out = [] for noisy, clean in rows: noisy = normalise_text(noisy) clean = normalise_text(clean) if not valid_pair(noisy, clean): continue key = (noisy.lower(), clean.lower()) if key in seen: continue seen.add(key) out.append((noisy, clean)) if limit and len(out) >= limit: break return out def read_pairs_jsonl(path: Path) -> 
list[tuple[str, str]]: rows = [] with path.open("r", encoding="utf-8") as fh: for line in fh: line = line.strip() if not line: continue item = json.loads(line) src = item.get("input") or item.get("noisy") or item.get("src") or item.get("bad") tgt = item.get("target") or item.get("clean") or item.get("tgt") or item.get("good") add_pair(rows, src, tgt) return dedupe_pairs(rows) def read_pairs_file(path: Path) -> list[tuple[str, str]]: path = Path(path) if not path.exists(): raise FileNotFoundError(path) if path.suffix.lower() in {".jsonl", ".ndjson"}: return read_pairs_jsonl(path) if path.suffix.lower() == ".json": raw = json.loads(path.read_text(encoding="utf-8")) rows = [] for item in raw if isinstance(raw, list) else []: if isinstance(item, dict): src = item.get("input") or item.get("noisy") or item.get("src") or item.get("bad") tgt = item.get("target") or item.get("clean") or item.get("tgt") or item.get("good") add_pair(rows, src, tgt) elif isinstance(item, (list, tuple)) and len(item) == 2: add_pair(rows, item[0], item[1]) return dedupe_pairs(rows) if path.suffix.lower() in {".csv", ".tsv"}: rows = [] delimiter = "\t" if path.suffix.lower() == ".tsv" else "," with path.open("r", encoding="utf-8", newline="") as fh: reader = csv.DictReader(fh, delimiter=delimiter) for item in reader: src = item.get("input") or item.get("noisy") or item.get("src") or item.get("bad") tgt = item.get("target") or item.get("clean") or item.get("tgt") or item.get("good") add_pair(rows, src, tgt) return dedupe_pairs(rows) raise ValueError(f"Unsupported data file: {path}") def keyboard_typo(word: str) -> str: if len(word) < 2: return word i = random.randrange(len(word)) ch = word[i].lower() if ch not in KEYBOARD_ADJ: return word repl = random.choice(KEYBOARD_ADJ[ch]) if word[i].isupper(): repl = repl.upper() return word[:i] + repl + word[i + 1:] def corrupt_word(word: str) -> str: if len(word) <= 2: return word lower = word.lower() if lower in COMMON_TYPOS and random.random() < 0.65: 
repl = random.choice(COMMON_TYPOS[lower]) return repl.capitalize() if word[0].isupper() else repl mode = random.choice(["swap", "drop", "double", "keyboard", "dyslexia"]) if mode == "swap" and len(word) > 3: i = random.randint(0, len(word) - 2) chars = list(word) chars[i], chars[i + 1] = chars[i + 1], chars[i] return "".join(chars) if mode == "drop" and len(word) > 4: i = random.randint(1, len(word) - 2) return word[:i] + word[i + 1:] if mode == "double": i = random.randrange(len(word)) return word[:i] + word[i] + word[i:] if mode == "keyboard": return keyboard_typo(word) if mode == "dyslexia": chars = list(word) for i, ch in enumerate(chars): repl = LETTER_SWAPS.get(ch.lower()) if repl and random.random() < 0.55: chars[i] = repl.upper() if ch.isupper() else repl return "".join(chars) return word def maybe_homophone(token: str) -> str: lower = re.sub(r"[^a-z']", "", token.lower()) choices = [b for a, b in HOMOPHONES if a == lower] if not choices: return token repl = random.choice(choices) return repl.capitalize() if token[:1].isupper() else repl def corrupt_sentence(sentence: str, intensity: float = 0.35) -> str: pieces = [] for raw in sentence.split(): prefix = "" suffix = "" word = raw while word and not word[0].isalnum(): prefix += word[0] word = word[1:] while word and not word[-1].isalnum(): suffix = word[-1] + suffix word = word[:-1] if word: if random.random() < 0.10: word = maybe_homophone(word) if random.random() < intensity: word = corrupt_word(word) if random.random() < 0.08: word = word.lower() if random.random() < 0.75 else word.upper() if random.random() < 0.22: suffix = suffix.replace(",", "").replace(".", "") pieces.append(prefix + word + suffix) text = " ".join(pieces) if random.random() < 0.30: text = text.lower() if random.random() < 0.10: text = text.rstrip(".!?") if random.random() < 0.12: text = re.sub(r"([.!?])\s+", " ", text, count=1) if random.random() < 0.08: text = re.sub(r"\s+", " ", text.replace(" ", "", 1)) return normalise_text(text) 
def sentence_split(text: str) -> list[str]:
    """Split normalised *text* into sentences after ``.``, ``!`` or ``?``."""
    text = normalise_text(text)
    return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]


def synthetic_pairs(clean_texts: Iterable[str], max_pairs: int) -> list[tuple[str, str]]:
    """Build up to *max_pairs* synthetic (noisy, clean) pairs from clean text.

    Only texts whose normalised length is between 11 and 899 characters are
    used. Each text yields three typo-corrupted variants plus, where
    applicable, punctuation-stripped, lower-cased, sentence-merged and
    identity pairs. The result is deduplicated, so it may contain fewer than
    *max_pairs* rows.
    """
    rows: list[tuple[str, str]] = []
    clean = [normalise_text(t) for t in clean_texts if 10 < len(normalise_text(t)) < 900]
    random.shuffle(clean)
    for text in clean:
        for intensity in (0.18, 0.35, 0.55):
            add_pair(rows, corrupt_sentence(text, intensity), text)
            if len(rows) >= max_pairs:
                return dedupe_pairs(rows, max_pairs)

        # Punctuation and capitalization restoration.
        if any(c in text for c in ".,!?;:"):
            no_punct = re.sub(r"[.,!?;:]", "", text)
            add_pair(rows, no_punct, text)
        if any(c.isupper() for c in text):
            add_pair(rows, text.lower(), text)

        # Paragraph flow: merge sentence boundaries and ask the model to
        # restore a smoother version.
        sents = sentence_split(text)
        if len(sents) >= 2:
            merged = " ".join(s.rstrip(".!?") for s in sents)
            add_pair(rows, merged, text)

        # Identity preservation teaches the model not to rewrite good text
        # unnecessarily.
        if random.random() < 0.30:
            add_pair(rows, text, text)
        if len(rows) >= max_pairs:
            break
    return dedupe_pairs(rows, max_pairs)


def try_load_dataset(log_fn: Callable[[str], None], *args, **kwargs):
    """Call ``datasets.load_dataset(*args, **kwargs)``, returning None on failure.

    Loading is best-effort: any exception raised by ``load_dataset`` is
    reported through *log_fn* and swallowed so one unavailable dataset does
    not abort the run. ``require_package`` runs outside the ``try``, so
    whatever it raises propagates.
    """
    datasets = require_package("datasets")
    try:
        return datasets.load_dataset(*args, **kwargs)
    except Exception as exc:
        log_fn(f" skipped {args}: {exc}")
        return None


def load_builtin_pairs(max_pairs: int, include_c4: bool, log_fn: Callable[[str], None]) -> list[tuple[str, str]]:
    """Assemble up to *max_pairs* training pairs from seeds, datasets and synthesis.

    Seed sentences are added first, then each dataset that can be loaded
    contributes up to its quota, and finally synthetic pairs are generated
    from the collected clean texts. The combined rows are shuffled and
    deduplicated.

    Args:
        max_pairs: upper bound on the number of returned pairs.
        include_c4: whether to also stream the optional C4-200M dataset.
        log_fn: callback receiving progress and skip messages.
    """
    rows: list[tuple[str, str]] = []

    def remaining() -> int:
        return max(0, max_pairs - len(rows))

    def quota(frac: float, floor: int = 500) -> int:
        # Share of the budget for one source: at least *floor* rows, but never
        # more than what is still missing overall.
        return min(max(floor, int(max_pairs * frac)), remaining())

    # Seed examples make smoke/offline setup useful and strengthen the exact
    # everyday writing style this app is for.
    for clean in SEED_CLEAN_SENTENCES:
        add_pair(rows, corrupt_sentence(clean, 0.45), clean)
        add_pair(rows, clean.lower(), clean)
        add_pair(rows, clean, clean)

    log_fn("Loading JFLEG grammar correction...")
    for split in ("validation", "test"):
        ds = try_load_dataset(log_fn, "jfleg", split=split)
        if ds is None:
            continue
        target = len(rows) + quota(0.08)
        for item in ds:
            src = item.get("sentence", "")
            for correction in item.get("corrections", []) or []:
                add_pair(rows, src, correction)
            if len(rows) >= target:
                break

    log_fn("Loading Grammarly CoEdIT correction/rewrite tasks...")
    ds = try_load_dataset(log_fn, "grammarly/coedit", split="train")
    if ds is not None:
        target = len(rows) + quota(0.32)
        for item in ds:
            src = item.get("src") or item.get("input") or item.get("source") or ""
            tgt = item.get("tgt") or item.get("target") or item.get("output") or ""
            add_pair(rows, src, tgt)
            if len(rows) >= target:
                break

    log_fn("Loading W&I/LOCNESS learner-English correction if available...")
    ds = try_load_dataset(log_fn, "wi_locness", "wi", split="train")
    if ds is not None:
        target = len(rows) + quota(0.12)
        for item in ds:
            for edit in item.get("edits", []) or []:
                # NOTE(review): this loop expects each edit to be a mapping
                # with "orig"/"cor" keys; confirm against the dataset schema.
                # Entries of any other shape are skipped instead of raising
                # AttributeError, in keeping with the best-effort loading.
                if not isinstance(edit, dict):
                    continue
                orig = edit.get("orig") or ""
                corrections = edit.get("cor") or []
                if corrections:
                    add_pair(rows, orig, corrections[0])
            if len(rows) >= target:
                break

    log_fn("Loading ASSET simplification/rewrite examples...")
    ds = try_load_dataset(log_fn, "asset", "simplification", split="validation")
    if ds is not None:
        target = len(rows) + quota(0.08)
        for item in ds:
            src = item.get("original") or ""
            for tgt in item.get("simplifications", []) or []:
                add_pair(rows, src, tgt)
            if len(rows) >= target:
                break

    log_fn("Loading WikiSplit sentence-flow examples...")
    ds = try_load_dataset(log_fn, "wiki_split", split="train")
    if ds is not None:
        target = len(rows) + quota(0.08)
        for item in ds:
            src = item.get("complex_sentence") or ""
            s1 = item.get("simple_sentence_1") or ""
            s2 = item.get("simple_sentence_2") or ""
            tgt = normalise_text(f"{s1} {s2}") if s2 else s1
            add_pair(rows, src, tgt)
            if len(rows) >= target:
                break

    log_fn("Loading MRPC paraphrase pairs...")
    ds = try_load_dataset(log_fn, "glue", "mrpc", split="train")
    if ds is not None:
        target = len(rows) + quota(0.06)
        for item in ds:
            # Only rows labelled 1 are used, in both directions.
            if int(item.get("label", 0)) == 1:
                s1 = item.get("sentence1") or ""
                s2 = item.get("sentence2") or ""
                add_pair(rows, s1, s2)
                add_pair(rows, s2, s1)
            if len(rows) >= target:
                break

    if include_c4:
        log_fn("Loading optional C4-200M GEC stream...")
        ds = try_load_dataset(log_fn, "liweili/c4_200m", split="train", streaming=True)
        if ds is not None:
            target = len(rows) + quota(0.15)
            for item in ds:
                src = item.get("input") or item.get("src") or ""
                tgt = item.get("output") or item.get("tgt") or ""
                add_pair(rows, src, tgt)
                if len(rows) >= target:
                    break

    base = dedupe_pairs(rows)
    clean_pool = [clean for _, clean in base]
    log_fn("Generating synthetic typo, dyslexia-like, punctuation, and preservation pairs...")
    # At least 1000 synthetic pairs are requested even when the datasets
    # already filled the budget.
    synth_target = max(1000, max_pairs - len(base))
    rows = base + synthetic_pairs(clean_pool + SEED_CLEAN_SENTENCES, synth_target)
    random.shuffle(rows)
    return dedupe_pairs(rows, max_pairs)


def save_pairs(rows: list[tuple[str, str]], path: Path) -> None:
    """Write *rows* to *path* as JSON lines with ``input``/``target`` keys.

    The data is written to a sibling ``.tmp`` file first and then moved over
    *path*, so a partially written file never appears under the final name.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(".tmp")
    with tmp.open("w", encoding="utf-8") as fh:
        for noisy, clean in rows:
            fh.write(json.dumps({"input": noisy, "target": clean}, ensure_ascii=False) + "\n")
    tmp.replace(path)


def load_prepared_pairs(path: Path = PAIRS_PATH) -> list[tuple[str, str]]:
    """Return the pairs stored at *path*, or an empty list if it does not exist."""
    if not path.exists():
        return []
    return read_pairs_jsonl(path)


def train_tokenizer(rows: list[tuple[str, str]], vocab_size: int, path: Path = TOKENIZER_PATH):
    """Train a byte-level BPE tokenizer on both sides of *rows* and save it.

    Args:
        rows: (noisy, clean) pairs; both strings of every pair are used.
        vocab_size: vocabulary size passed to the BPE trainer.
        path: destination file; parent directories are created.

    Returns:
        The trained ``tokenizers.Tokenizer``.
    """
    require_package("tokenizers")
    from tokenizers import Tokenizer
    from tokenizers.decoders import ByteLevel as ByteLevelDecoder
    from tokenizers.models import BPE
    from tokenizers.pre_tokenizers import ByteLevel
    from tokenizers.trainers import BpeTrainer

    tok = Tokenizer(BPE(unk_token="[UNK]"))
    tok.pre_tokenizer = ByteLevel(add_prefix_space=False)
    tok.decoder = ByteLevelDecoder()
    trainer = BpeTrainer(
        vocab_size=vocab_size,
        min_frequency=2,
        special_tokens=SPECIAL_TOKENS,
        show_progress=True,
    )
    # Flatten pairs into one corpus: noisy, clean, noisy, clean, ...
    texts = [text for noisy, clean in rows for text in (noisy, clean)]
    tok.train_from_iterator(texts, trainer=trainer)
    path.parent.mkdir(parents=True, exist_ok=True)
    tok.save(str(path))
    return tok


def load_tokenizer(path: Path = TOKENIZER_PATH):
    """Load a tokenizer from *path*, attaching a byte-level decoder if none is set."""
    require_package("tokenizers")
    from tokenizers import Tokenizer
    from tokenizers.decoders import ByteLevel as ByteLevelDecoder

    tok = Tokenizer.from_file(str(path))
    if tok.decoder is None:
        tok.decoder = ByteLevelDecoder()
    return tok


def encode_pair(tok, noisy: str, clean: str, max_len: int) -> tuple[list[int], list[int], list[int]]:
    """Encode one pair into ``(src, tgt_in, tgt_out)`` token-id lists.

    ``src`` is ``BOS + tokens + EOS``, ``tgt_in`` is ``BOS + tokens`` and
    ``tgt_out`` is ``tokens + EOS``. Token sequences are truncated to
    ``max_len - 2`` ids before the markers are added.
    """
    # Clamp at zero: a negative bound would slice from the end and keep
    # almost every token instead of none.
    budget = max(0, max_len - 2)
    src_tokens = tok.encode(normalise_text(noisy)).ids[:budget]
    tgt_tokens = tok.encode(normalise_text(clean)).ids[:budget]
    src = [BOS_ID] + src_tokens + [EOS_ID]
    tgt_in = [BOS_ID] + tgt_tokens
    tgt_out = tgt_tokens + [EOS_ID]
    return src, tgt_in, tgt_out


def pad_batch(seqs: list[list[int]], pad_value: int) -> torch.Tensor:
    """Right-pad *seqs* with *pad_value* into a ``(len(seqs), max_len)`` long tensor.

    An empty *seqs* yields a tensor of shape ``(0, 0)``.
    """
    max_len = max((len(s) for s in seqs), default=0)
    out = torch.full((len(seqs), max_len), pad_value, dtype=torch.long)
    for i, seq in enumerate(seqs):
        out[i, : len(seq)] = torch.tensor(seq, dtype=torch.long)
    return out


def collate_pairs(rows: list[tuple[str, str]], tok, cfg: TrainConfig) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Encode and pad *rows* into ``(src, tgt_in, tgt_out)`` tensors.

    ``src`` and ``tgt_in`` are padded with ``PAD_ID``; ``tgt_out`` is padded
    with -100.
    """
    srcs, tins, touts = [], [], []
    for noisy, clean in rows:
        src, tin, tout = encode_pair(tok, noisy, clean, cfg.max_len)
        srcs.append(src)
        tins.append(tin)
        touts.append(tout)
    return pad_batch(srcs, PAD_ID), pad_batch(tins, PAD_ID), pad_batch(touts, -100)


def make_batches(rows: list[tuple[str, str]], batch_size: int, shuffle_batches: bool) -> list[list[int]]:
    """Group row indices into batches of similar combined character length.

    Returns lists of indices into *rows*, not the rows themselves. Indices are
    sorted by ``len(noisy) + len(clean)`` before batching; with
    *shuffle_batches* the order of the batches is randomised, while the
    members of each batch stay as formed.
    """
    idx = list(range(len(rows)))
    idx.sort(key=lambda i: len(rows[i][0]) + len(rows[i][1]))
    batches = [idx[i : i + batch_size] for i in range(0, len(idx), batch_size)]
    if shuffle_batches:
        # Shuffle locally similar lengths to keep padding low without losing
        # all stochasticity.
        chunks = [batches[i : i + 256] for i in range(0, len(batches), 256)]
        for chunk in chunks:
            random.shuffle(chunk)
        batches = [b for chunk in chunks for b in chunk]
        random.shuffle(batches)
    return batches


def split_train_val(rows: list[tuple[str, str]], val_split: float, seed: int):
    """Return ``(train, val)`` after a seeded shuffle of a copy of *rows*.

    The validation part holds ``int(len(rows) * val_split)`` rows, but at
    least one. The caller's list is left untouched.
    """
    # Shuffle a copy so the caller's list keeps its order; the permutation for
    # a given seed matches an in-place shuffle of the same list.
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)
    val_n = max(1, int(len(shuffled) * val_split))
    return shuffled[val_n:], shuffled[:val_n]