import re
from typing import Mapping, Sequence

# Lowercase alphanumeric runs, optionally joined by one internal apostrophe
# (e.g. "it's"); every other character acts as a separator.
_TOKEN_RE = re.compile(r"[a-z0-9]+(?:'[a-z0-9]+)?")


def extract_user_provided_tags_upto_3_words(prompt_in: str) -> list[str]:
    """Return short comma/period-delimited phrases from *prompt_in*.

    Heuristic:
    - split on runs of '.' and ','
    - strip leading/trailing whitespace from each piece, dropping empty pieces
    - keep pieces with at most 3 whitespace-separated words

    Duplicates are removed case-insensitively and ignoring differences in
    internal whitespace. The first occurrence is kept with its original
    spelling, and input order is preserved.

    >>> extract_user_provided_tags_upto_3_words("Red dress, a very long skirt. red  dress")
    ['Red dress']
    """
    if not prompt_in:
        return []

    out: list[str] = []
    seen: set[str] = set()
    for raw in re.split(r"[.,]+", prompt_in):
        item = raw.strip()
        if not item:
            continue
        words = item.split()
        if len(words) > 3:
            continue
        # Normalise the dedup key so "Long hair" and "long  hair" collapse.
        key = " ".join(words).casefold()
        if key in seen:
            continue
        seen.add(key)
        out.append(item)
    return out


def extract_exact_tag_query_phrases(
    prompt_in: str,
    tag_counts: Mapping[str, int],
    alias2tags: Mapping[str, Sequence[str]],
    *,
    min_tag_count: int = 0,
    max_ngram: int = 2,
) -> list[str]:
    """Extract exact canonical/alias n-gram matches as retrieval query phrases.

    The prompt (with an optional case-insensitive ``caption_to_tags:`` prefix
    removed) is lowercased and tokenized. Contiguous runs of 1..max_ngram
    tokens are joined with ``_`` and looked up.

    The output is conservative. A phrase is emitted only if it meets one of
    two conditions:
    - it is a key of *tag_counts* whose count clears *min_tag_count*, or
    - it is a key of *alias2tags* mapping to at least one tag whose count
      clears *min_tag_count*. This is checked even when the phrase is also a
      canonical tag that failed the floor.

    Longest matches win, so a matched 2-gram suppresses its own component
    1-grams. Each phrase is emitted at most once, in prompt order.

    Args:
        prompt_in: Free-text prompt.
        tag_counts: Canonical tag -> occurrence count. Falsy counts are
            treated as 0.
        alias2tags: Alias -> canonical tags. A bare string value is treated
            as a single tag.
        min_tag_count: Count floor. Values <= 0 disable the count check.
        max_ngram: Longest n-gram to try. Values <= 0 yield no matches.

    Returns:
        Underscore-joined matched phrases, ordered by position in the prompt.
    """
    if not prompt_in or max_ngram <= 0:
        return []

    text = prompt_in.strip()
    prefix = "caption_to_tags:"
    # Optional task prefix, matched case-insensitively.
    if text.lower().startswith(prefix):
        text = text[len(prefix):].strip()

    tokens = _TOKEN_RE.findall(text.lower())
    if not tokens:
        return []

    def _count_ok(tag: str) -> bool:
        # A non-positive floor disables the count check entirely.
        if min_tag_count <= 0:
            return True
        return int(tag_counts.get(tag, 0) or 0) >= min_tag_count

    def _resolves(lookup: str) -> bool:
        # A canonical tag that clears the floor resolves directly ...
        if lookup in tag_counts and _count_ok(lookup):
            return True
        # ... otherwise consult the alias map. This also applies when
        # *lookup* is itself a canonical tag below the floor, so a rare
        # canonical spelling cannot shadow a popular alias target.
        targets = alias2tags.get(lookup, ())
        if isinstance(targets, str):
            # A bare string is one tag, not a sequence of 1-char tags.
            targets = (targets,)
        return any(_count_ok(tag) for tag in targets)

    # Collect every resolvable n-gram: longest first, and leftmost first
    # within a length. This order is what makes the greedy pass below
    # prefer longer (then earlier) matches.
    matches: list[tuple[int, int, str]] = []
    max_n = min(max(1, int(max_ngram)), len(tokens))
    for n in range(max_n, 0, -1):
        for start in range(len(tokens) - n + 1):
            lookup = "_".join(tokens[start:start + n])
            if _resolves(lookup):
                matches.append((start, start + n, lookup))

    # Greedy non-overlapping selection. A token already claimed by an
    # earlier (longer) match cannot be reused, and each phrase is emitted
    # at most once.
    used: set[int] = set()
    selected: list[tuple[int, str]] = []
    seen: set[str] = set()
    for start, end, lookup in matches:
        span = range(start, end)
        if lookup in seen or any(i in used for i in span):
            continue
        used.update(span)
        seen.add(lookup)
        selected.append((start, lookup))

    # Report in prompt order rather than match-length order.
    selected.sort(key=lambda row: row[0])
    return [lookup for _, lookup in selected]


if __name__ == "__main__":
    print("preproc.py imports ok")