Food Desert commited on
Commit
29b12cd
·
1 Parent(s): c3191c6

Add exact n-gram retrieval query hints

Browse files
.gitignore CHANGED
@@ -29,6 +29,10 @@ data/analysis/openrouter_concurrency_*.json
29
  data/analysis/pipeline_call_count_probe*.json
30
  data/analysis/rewrite_only_compare_*.json
31
  data/analysis/rewrite_ablation_*.json
 
 
 
 
32
  data/analysis/t5_sweep_two_stage_*.json
33
  data/analysis/t5_sweep_two_stage_*.csv
34
  data/analysis/tmp_ckpt_compare_*.json
@@ -46,3 +50,7 @@ data/eval_results/tmp_llm_rewrite_diag*.jsonl
46
  data/eval_results/eval_caption_cogvlm_n30_llm_heur_*_20260509.jsonl
47
  data/eval_results/eval_caption_cogvlm_n30_t5_heur_*_20260509.jsonl
48
  data/eval_results/eval_caption_cogvlm_n1_seed42_20260509_005007.jsonl
 
 
 
 
 
29
  data/analysis/pipeline_call_count_probe*.json
30
  data/analysis/rewrite_only_compare_*.json
31
  data/analysis/rewrite_ablation_*.json
32
+ data/analysis/retrieval_ngram_recovery_*.json
33
+ data/analysis/retrieval_ngram_recovery_*.csv
34
+ data/analysis/t5_tag_frequency_profile_*.json
35
+ data/analysis/t5_tag_frequency_profile_*.csv
36
  data/analysis/t5_sweep_two_stage_*.json
37
  data/analysis/t5_sweep_two_stage_*.csv
38
  data/analysis/tmp_ckpt_compare_*.json
 
50
  data/eval_results/eval_caption_cogvlm_n30_llm_heur_*_20260509.jsonl
51
  data/eval_results/eval_caption_cogvlm_n30_t5_heur_*_20260509.jsonl
52
  data/eval_results/eval_caption_cogvlm_n1_seed42_20260509_005007.jsonl
53
+
54
+ # Temporary local profiling helpers
55
+ scripts/profile_retrieval_ngram_recovery.py
56
+ scripts/profile_t5_tag_frequency.py
app.py CHANGED
@@ -78,7 +78,10 @@ if _STARTUP_PROFILE_ON and _STARTUP_PROFILE_PATH is not None:
78
  import gradio as gr
79
  _startup_profile_mark("import.gradio.done")
80
 
81
- from psq_rag.pipeline.preproc import extract_user_provided_tags_upto_3_words
 
 
 
82
  _startup_profile_mark("import.psq_rag.pipeline.preproc.done")
83
  from psq_rag.llm.rewrite import llm_rewrite_prompt
84
  _startup_profile_mark("import.psq_rag.llm.rewrite.done")
@@ -93,6 +96,7 @@ from psq_rag.retrieval.state import (
93
  get_tag_type_name,
94
  get_tag_implications,
95
  get_tag_counts,
 
96
  )
97
  _startup_profile_mark("import.psq_rag.retrieval.state.done")
98
  from psq_rag.ui.group_ranked_display import rank_groups_from_tfidf, _load_enabled_groups
@@ -1474,9 +1478,10 @@ display_top_groups_default = int(os.environ.get("PSQ_DISPLAY_TOP_GROUPS", "10"))
1474
  display_top_tags_per_group_default = int(os.environ.get("PSQ_DISPLAY_TOP_TAGS_PER_GROUP", "7"))
1475
  display_rank_top_k_default = int(os.environ.get("PSQ_DISPLAY_GROUP_RANK_TOP_K", "7"))
1476
  display_max_rows_default = int(os.environ.get("PSQ_DISPLAY_MAX_ROWS", "14"))
1477
- retrieval_global_k = int(os.environ.get("PSQ_RETRIEVAL_GLOBAL_K", "300"))
1478
- retrieval_per_phrase_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_K", "10"))
1479
  retrieval_per_phrase_final_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_FINAL_K", "1"))
 
1480
  selection_mode = os.environ.get("PSQ_SELECTION_MODE", "chunked_map_union").strip()
1481
  selection_chunk_size = int(os.environ.get("PSQ_SELECTION_CHUNK_SIZE", "60"))
1482
  selection_per_phrase_k = int(os.environ.get("PSQ_SELECTION_PER_PHRASE_K", "2"))
@@ -2360,6 +2365,7 @@ def rag_pipeline_ui(
2360
  f"retrieval_global_k={retrieval_global_k} "
2361
  f"retrieval_per_phrase_k={retrieval_per_phrase_k} "
2362
  f"retrieval_per_phrase_final_k={retrieval_per_phrase_final_k} "
 
2363
  f"selection_mode={selection_mode} "
2364
  f"selection_chunk_size={selection_chunk_size} "
2365
  f"selection_per_phrase_k={selection_per_phrase_k} "
@@ -2386,6 +2392,14 @@ def rag_pipeline_ui(
2386
  user_tags_raw = extract_user_provided_tags_upto_3_words(prompt_in)
2387
  user_tags, removed_user_low = _filter_min_count_tags(user_tags_raw, min_tag_count)
2388
  user_tags, removed_user_excluded = _filter_excluded_recommendation_tags(user_tags)
 
 
 
 
 
 
 
 
2389
  dt = time.perf_counter()-t0
2390
  _record_timing("preprocess", dt)
2391
  log(f"Preprocess (user tag extraction): {dt:.2f}s")
@@ -2404,6 +2418,20 @@ def rag_pipeline_ui(
2404
  f"Filtered {len(removed_user_excluded)} excluded user tags: "
2405
  f"{', '.join(removed_user_excluded)}"
2406
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2407
  log("")
2408
 
2409
  rewrite_prefilled = (rewrite_override or "").strip()
@@ -2489,11 +2517,12 @@ def rag_pipeline_ui(
2489
  log("Rewrite:")
2490
  log(rewritten if rewritten else "(empty)")
2491
  log("")
2492
-
2493
- rewrite_for_retrieval = rewritten
2494
- if user_tags:
2495
- # keep them separate in logs, but allow them to help retrieval
2496
- rewrite_for_retrieval = (rewrite_for_retrieval + ", " + ", ".join(user_tags)).strip(", ").strip()
 
2497
 
2498
 
2499
  log("Step 2: Prompt Squirrel retrieval (hidden)")
 
78
  import gradio as gr
79
  _startup_profile_mark("import.gradio.done")
80
 
81
+ from psq_rag.pipeline.preproc import (
82
+ extract_exact_tag_query_phrases,
83
+ extract_user_provided_tags_upto_3_words,
84
+ )
85
  _startup_profile_mark("import.psq_rag.pipeline.preproc.done")
86
  from psq_rag.llm.rewrite import llm_rewrite_prompt
87
  _startup_profile_mark("import.psq_rag.llm.rewrite.done")
 
96
  get_tag_type_name,
97
  get_tag_implications,
98
  get_tag_counts,
99
+ get_alias2tags,
100
  )
101
  _startup_profile_mark("import.psq_rag.retrieval.state.done")
102
  from psq_rag.ui.group_ranked_display import rank_groups_from_tfidf, _load_enabled_groups
 
1478
  display_top_tags_per_group_default = int(os.environ.get("PSQ_DISPLAY_TOP_TAGS_PER_GROUP", "7"))
1479
  display_rank_top_k_default = int(os.environ.get("PSQ_DISPLAY_GROUP_RANK_TOP_K", "7"))
1480
  display_max_rows_default = int(os.environ.get("PSQ_DISPLAY_MAX_ROWS", "14"))
1481
+ retrieval_global_k = int(os.environ.get("PSQ_RETRIEVAL_GLOBAL_K", "300"))
1482
+ retrieval_per_phrase_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_K", "10"))
1483
  retrieval_per_phrase_final_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_FINAL_K", "1"))
1484
+ retrieval_exact_ngram_max = int(os.environ.get("PSQ_RETRIEVAL_EXACT_NGRAM_MAX", "2"))
1485
  selection_mode = os.environ.get("PSQ_SELECTION_MODE", "chunked_map_union").strip()
1486
  selection_chunk_size = int(os.environ.get("PSQ_SELECTION_CHUNK_SIZE", "60"))
1487
  selection_per_phrase_k = int(os.environ.get("PSQ_SELECTION_PER_PHRASE_K", "2"))
 
2365
  f"retrieval_global_k={retrieval_global_k} "
2366
  f"retrieval_per_phrase_k={retrieval_per_phrase_k} "
2367
  f"retrieval_per_phrase_final_k={retrieval_per_phrase_final_k} "
2368
+ f"retrieval_exact_ngram_max={retrieval_exact_ngram_max} "
2369
  f"selection_mode={selection_mode} "
2370
  f"selection_chunk_size={selection_chunk_size} "
2371
  f"selection_per_phrase_k={selection_per_phrase_k} "
 
2392
  user_tags_raw = extract_user_provided_tags_upto_3_words(prompt_in)
2393
  user_tags, removed_user_low = _filter_min_count_tags(user_tags_raw, min_tag_count)
2394
  user_tags, removed_user_excluded = _filter_excluded_recommendation_tags(user_tags)
2395
+ exact_query_phrases = extract_exact_tag_query_phrases(
2396
+ prompt_in,
2397
+ get_tag_counts(),
2398
+ get_alias2tags(),
2399
+ min_tag_count=min_tag_count,
2400
+ max_ngram=max(0, retrieval_exact_ngram_max),
2401
+ )
2402
+ exact_query_phrases, removed_exact_excluded = _filter_excluded_recommendation_tags(exact_query_phrases)
2403
  dt = time.perf_counter()-t0
2404
  _record_timing("preprocess", dt)
2405
  log(f"Preprocess (user tag extraction): {dt:.2f}s")
 
2418
  f"Filtered {len(removed_user_excluded)} excluded user tags: "
2419
  f"{', '.join(removed_user_excluded)}"
2420
  )
2421
+ if retrieval_exact_ngram_max > 0:
2422
+ log(f"Exact caption tag query phrases (1-{retrieval_exact_ngram_max} grams):")
2423
+ else:
2424
+ log("Exact caption tag query phrases: disabled")
2425
+ if exact_query_phrases:
2426
+ shown = ", ".join(exact_query_phrases[:40])
2427
+ log(shown + (" ..." if len(exact_query_phrases) > 40 else ""))
2428
+ else:
2429
+ log("(none)")
2430
+ if removed_exact_excluded:
2431
+ log(
2432
+ f"Filtered {len(removed_exact_excluded)} excluded exact query phrases: "
2433
+ f"{', '.join(removed_exact_excluded)}"
2434
+ )
2435
  log("")
2436
 
2437
  rewrite_prefilled = (rewrite_override or "").strip()
 
2517
  log("Rewrite:")
2518
  log(rewritten if rewritten else "(empty)")
2519
  log("")
2520
+
2521
+ rewrite_for_retrieval = rewritten
2522
+ retrieval_query_hints = list(dict.fromkeys((user_tags or []) + (exact_query_phrases or [])))
2523
+ if retrieval_query_hints:
2524
+ # keep them separate in logs, but allow them to help retrieval
2525
+ rewrite_for_retrieval = (rewrite_for_retrieval + ", " + ", ".join(retrieval_query_hints)).strip(", ").strip()
2526
 
2527
 
2528
  log("Step 2: Prompt Squirrel retrieval (hidden)")
docs/rewrite_contract.md CHANGED
@@ -76,6 +76,11 @@ Outside Stage 1 itself, `app.py` also computes heuristic short phrases via:
76
  - split on `.` and `,`
77
  - keep segments with <= 3 tokens
78
  - case-insensitive dedupe
 
 
 
 
 
79
 
80
  These heuristic terms are later appended to retrieval input only if rewrite succeeds.
81
 
 
76
  - split on `.` and `,`
77
  - keep segments with <= 3 tokens
78
  - case-insensitive dedupe
79
+ - `extract_exact_tag_query_phrases()`
80
+ - scan prompt text for exact 1- to N-gram canonical tag or alias matches
81
+ - app default N is 2 (`PSQ_RETRIEVAL_EXACT_NGRAM_MAX`)
82
+ - matches must resolve to at least one canonical tag that clears `PSQ_MIN_TAG_COUNT`
83
+ - longest matches suppress their own component unigrams
84
 
85
  These heuristic terms are later appended to retrieval input only if rewrite succeeds.
86
 
psq_rag/pipeline/preproc.py CHANGED
@@ -1,6 +1,11 @@
1
- import re
2
-
3
- def extract_user_provided_tags_upto_3_words(prompt_in: str) -> list[str]:
 
 
 
 
 
4
  """
5
  Heuristic:
6
  - split on '.' and ','
@@ -27,10 +32,69 @@ def extract_user_provided_tags_upto_3_words(prompt_in: str) -> list[str]:
27
  if key not in seen:
28
  seen.add(key)
29
  out.append(item)
30
-
31
- return out
32
-
33
-
34
- if __name__ == "__main__":
35
- print("preproc.py imports ok")
36
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import Mapping, Sequence
3
+
4
+
5
+ _TOKEN_RE = re.compile(r"[a-z0-9]+(?:'[a-z0-9]+)?")
6
+
7
+
8
+ def extract_user_provided_tags_upto_3_words(prompt_in: str) -> list[str]:
9
  """
10
  Heuristic:
11
  - split on '.' and ','
 
32
  if key not in seen:
33
  seen.add(key)
34
  out.append(item)
35
+
36
+ return out
37
+
38
+
39
+ def extract_exact_tag_query_phrases(
40
+ prompt_in: str,
41
+ tag_counts: Mapping[str, int],
42
+ alias2tags: Mapping[str, Sequence[str]],
43
+ *,
44
+ min_tag_count: int = 0,
45
+ max_ngram: int = 2,
46
+ ) -> list[str]:
47
+ """Extract exact canonical/alias n-gram matches as retrieval query phrases.
48
+
49
+ The output is conservative: every emitted phrase either is a canonical tag or
50
+ resolves through the alias map to at least one canonical tag that clears the
51
+ count floor. Longest matches win, so a matched 2-gram suppresses its own
52
+ component 1-grams.
53
+ """
54
+ if not prompt_in or max_ngram <= 0:
55
+ return []
56
+
57
+ text = prompt_in.strip()
58
+ prefix = "caption_to_tags:"
59
+ if text.lower().startswith(prefix):
60
+ text = text[len(prefix):].strip()
61
+
62
+ tokens = _TOKEN_RE.findall(text.lower())
63
+ if not tokens:
64
+ return []
65
+
66
+ def _count_ok(tag: str) -> bool:
67
+ if min_tag_count <= 0:
68
+ return True
69
+ return int(tag_counts.get(tag, 0) or 0) >= min_tag_count
70
+
71
+ def _resolves(lookup: str) -> bool:
72
+ if lookup in tag_counts:
73
+ return _count_ok(lookup)
74
+ return any(_count_ok(tag) for tag in alias2tags.get(lookup, ()))
75
+
76
+ matches: list[tuple[int, int, str]] = []
77
+ max_n = min(max(1, int(max_ngram)), len(tokens))
78
+ for n in range(max_n, 0, -1):
79
+ for start in range(0, len(tokens) - n + 1):
80
+ lookup = "_".join(tokens[start:start + n])
81
+ if _resolves(lookup):
82
+ matches.append((start, start + n, lookup))
83
+
84
+ used: set[int] = set()
85
+ selected: list[tuple[int, str]] = []
86
+ seen: set[str] = set()
87
+ for start, end, lookup in matches:
88
+ span = set(range(start, end))
89
+ if span & used or lookup in seen:
90
+ continue
91
+ used.update(span)
92
+ seen.add(lookup)
93
+ selected.append((start, lookup))
94
+
95
+ selected.sort(key=lambda row: row[0])
96
+ return [lookup for _, lookup in selected]
97
+
98
+
99
+ if __name__ == "__main__":
100
+ print("preproc.py imports ok")
scripts/test_exact_tag_query_phrases.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import sys
3
+
4
+ repo_root = Path(__file__).resolve().parents[1]
5
+ sys.path.insert(0, str(repo_root))
6
+
7
+ from psq_rag.pipeline.preproc import extract_exact_tag_query_phrases
8
+
9
+
10
+ def assert_equal(actual, expected, message):
11
+ if actual != expected:
12
+ raise AssertionError(f"{message}: expected {expected!r}, got {actual!r}")
13
+
14
+
15
+ def assert_in(item, values, message):
16
+ if item not in values:
17
+ raise AssertionError(f"{message}: {item!r} not in {values!r}")
18
+
19
+
20
+ def test_longest_match_suppresses_component_unigrams():
21
+ tag_counts = {
22
+ "red": 1000,
23
+ "fox": 1000,
24
+ "red_fox": 300,
25
+ "burrito": 164,
26
+ }
27
+ phrases = extract_exact_tag_query_phrases(
28
+ "A red fox eating a giant burrito",
29
+ tag_counts,
30
+ {},
31
+ min_tag_count=100,
32
+ max_ngram=2,
33
+ )
34
+ assert_equal(phrases, ["red_fox", "burrito"], "2-gram should suppress its component 1-grams")
35
+
36
+
37
+ def test_alias_resolution_uses_target_count_floor():
38
+ tag_counts = {
39
+ "hotdog": 150,
40
+ "low_count_tag": 99,
41
+ }
42
+ alias2tags = {
43
+ "hot_dog": ["hotdog"],
44
+ "rare_alias": ["low_count_tag"],
45
+ }
46
+ phrases = extract_exact_tag_query_phrases(
47
+ "A hot dog and rare alias",
48
+ tag_counts,
49
+ alias2tags,
50
+ min_tag_count=100,
51
+ max_ngram=2,
52
+ )
53
+ assert_equal(phrases, ["hot_dog"], "alias phrase should emit only when a target clears min count")
54
+
55
+
56
+ def test_caption_prefix_is_ignored():
57
+ tag_counts = {"caption": 1000, "red_fox": 300}
58
+ phrases = extract_exact_tag_query_phrases(
59
+ "caption_to_tags: red fox",
60
+ tag_counts,
61
+ {},
62
+ min_tag_count=100,
63
+ max_ngram=2,
64
+ )
65
+ assert_equal(phrases, ["red_fox"], "task prefix should not contribute tag query phrases")
66
+
67
+
68
+ def test_real_assets_find_burrito_and_retrieve_it():
69
+ from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases
70
+ from psq_rag.retrieval.state import get_alias2tags, get_tag_counts
71
+
72
+ tag_counts = get_tag_counts()
73
+ phrases = extract_exact_tag_query_phrases(
74
+ "A red fox eating a giant burrito",
75
+ tag_counts,
76
+ get_alias2tags(),
77
+ min_tag_count=100,
78
+ max_ngram=2,
79
+ )
80
+ assert_in("red_fox", phrases, "real asset extraction should find red_fox")
81
+ assert_in("burrito", phrases, "real asset extraction should find burrito")
82
+
83
+ candidates = psq_candidates_from_rewrite_phrases(
84
+ rewrite_phrases=phrases,
85
+ allow_nsfw_tags=False,
86
+ min_tag_count=100,
87
+ per_phrase_k=10,
88
+ per_phrase_final_k=1,
89
+ global_k=300,
90
+ )
91
+ tags = {candidate.tag for candidate in candidates}
92
+ assert_in("burrito", tags, "exact burrito query phrase should retrieve burrito")
93
+
94
+
95
+ def main():
96
+ test_longest_match_suppresses_component_unigrams()
97
+ test_alias_resolution_uses_target_count_floor()
98
+ test_caption_prefix_is_ignored()
99
+ test_real_assets_find_burrito_and_retrieve_it()
100
+ print("exact tag query phrase tests: PASS")
101
+
102
+
103
+ if __name__ == "__main__":
104
+ main()