Spaces:
Running
Running
Food Desert committed on
Commit ·
29b12cd
1
Parent(s): c3191c6
Add exact n-gram retrieval query hints
Browse files- .gitignore +8 -0
- app.py +37 -8
- docs/rewrite_contract.md +5 -0
- psq_rag/pipeline/preproc.py +74 -10
- scripts/test_exact_tag_query_phrases.py +104 -0
.gitignore
CHANGED
|
@@ -29,6 +29,10 @@ data/analysis/openrouter_concurrency_*.json
|
|
| 29 |
data/analysis/pipeline_call_count_probe*.json
|
| 30 |
data/analysis/rewrite_only_compare_*.json
|
| 31 |
data/analysis/rewrite_ablation_*.json
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
data/analysis/t5_sweep_two_stage_*.json
|
| 33 |
data/analysis/t5_sweep_two_stage_*.csv
|
| 34 |
data/analysis/tmp_ckpt_compare_*.json
|
|
@@ -46,3 +50,7 @@ data/eval_results/tmp_llm_rewrite_diag*.jsonl
|
|
| 46 |
data/eval_results/eval_caption_cogvlm_n30_llm_heur_*_20260509.jsonl
|
| 47 |
data/eval_results/eval_caption_cogvlm_n30_t5_heur_*_20260509.jsonl
|
| 48 |
data/eval_results/eval_caption_cogvlm_n1_seed42_20260509_005007.jsonl
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
data/analysis/pipeline_call_count_probe*.json
|
| 30 |
data/analysis/rewrite_only_compare_*.json
|
| 31 |
data/analysis/rewrite_ablation_*.json
|
| 32 |
+
data/analysis/retrieval_ngram_recovery_*.json
|
| 33 |
+
data/analysis/retrieval_ngram_recovery_*.csv
|
| 34 |
+
data/analysis/t5_tag_frequency_profile_*.json
|
| 35 |
+
data/analysis/t5_tag_frequency_profile_*.csv
|
| 36 |
data/analysis/t5_sweep_two_stage_*.json
|
| 37 |
data/analysis/t5_sweep_two_stage_*.csv
|
| 38 |
data/analysis/tmp_ckpt_compare_*.json
|
|
|
|
| 50 |
data/eval_results/eval_caption_cogvlm_n30_llm_heur_*_20260509.jsonl
|
| 51 |
data/eval_results/eval_caption_cogvlm_n30_t5_heur_*_20260509.jsonl
|
| 52 |
data/eval_results/eval_caption_cogvlm_n1_seed42_20260509_005007.jsonl
|
| 53 |
+
|
| 54 |
+
# Temporary local profiling helpers
|
| 55 |
+
scripts/profile_retrieval_ngram_recovery.py
|
| 56 |
+
scripts/profile_t5_tag_frequency.py
|
app.py
CHANGED
|
@@ -78,7 +78,10 @@ if _STARTUP_PROFILE_ON and _STARTUP_PROFILE_PATH is not None:
|
|
| 78 |
import gradio as gr
|
| 79 |
_startup_profile_mark("import.gradio.done")
|
| 80 |
|
| 81 |
-
from psq_rag.pipeline.preproc import
|
|
|
|
|
|
|
|
|
|
| 82 |
_startup_profile_mark("import.psq_rag.pipeline.preproc.done")
|
| 83 |
from psq_rag.llm.rewrite import llm_rewrite_prompt
|
| 84 |
_startup_profile_mark("import.psq_rag.llm.rewrite.done")
|
|
@@ -93,6 +96,7 @@ from psq_rag.retrieval.state import (
|
|
| 93 |
get_tag_type_name,
|
| 94 |
get_tag_implications,
|
| 95 |
get_tag_counts,
|
|
|
|
| 96 |
)
|
| 97 |
_startup_profile_mark("import.psq_rag.retrieval.state.done")
|
| 98 |
from psq_rag.ui.group_ranked_display import rank_groups_from_tfidf, _load_enabled_groups
|
|
@@ -1474,9 +1478,10 @@ display_top_groups_default = int(os.environ.get("PSQ_DISPLAY_TOP_GROUPS", "10"))
|
|
| 1474 |
display_top_tags_per_group_default = int(os.environ.get("PSQ_DISPLAY_TOP_TAGS_PER_GROUP", "7"))
|
| 1475 |
display_rank_top_k_default = int(os.environ.get("PSQ_DISPLAY_GROUP_RANK_TOP_K", "7"))
|
| 1476 |
display_max_rows_default = int(os.environ.get("PSQ_DISPLAY_MAX_ROWS", "14"))
|
| 1477 |
-
retrieval_global_k = int(os.environ.get("PSQ_RETRIEVAL_GLOBAL_K", "300"))
|
| 1478 |
-
retrieval_per_phrase_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_K", "10"))
|
| 1479 |
retrieval_per_phrase_final_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_FINAL_K", "1"))
|
|
|
|
| 1480 |
selection_mode = os.environ.get("PSQ_SELECTION_MODE", "chunked_map_union").strip()
|
| 1481 |
selection_chunk_size = int(os.environ.get("PSQ_SELECTION_CHUNK_SIZE", "60"))
|
| 1482 |
selection_per_phrase_k = int(os.environ.get("PSQ_SELECTION_PER_PHRASE_K", "2"))
|
|
@@ -2360,6 +2365,7 @@ def rag_pipeline_ui(
|
|
| 2360 |
f"retrieval_global_k={retrieval_global_k} "
|
| 2361 |
f"retrieval_per_phrase_k={retrieval_per_phrase_k} "
|
| 2362 |
f"retrieval_per_phrase_final_k={retrieval_per_phrase_final_k} "
|
|
|
|
| 2363 |
f"selection_mode={selection_mode} "
|
| 2364 |
f"selection_chunk_size={selection_chunk_size} "
|
| 2365 |
f"selection_per_phrase_k={selection_per_phrase_k} "
|
|
@@ -2386,6 +2392,14 @@ def rag_pipeline_ui(
|
|
| 2386 |
user_tags_raw = extract_user_provided_tags_upto_3_words(prompt_in)
|
| 2387 |
user_tags, removed_user_low = _filter_min_count_tags(user_tags_raw, min_tag_count)
|
| 2388 |
user_tags, removed_user_excluded = _filter_excluded_recommendation_tags(user_tags)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2389 |
dt = time.perf_counter()-t0
|
| 2390 |
_record_timing("preprocess", dt)
|
| 2391 |
log(f"Preprocess (user tag extraction): {dt:.2f}s")
|
|
@@ -2404,6 +2418,20 @@ def rag_pipeline_ui(
|
|
| 2404 |
f"Filtered {len(removed_user_excluded)} excluded user tags: "
|
| 2405 |
f"{', '.join(removed_user_excluded)}"
|
| 2406 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2407 |
log("")
|
| 2408 |
|
| 2409 |
rewrite_prefilled = (rewrite_override or "").strip()
|
|
@@ -2489,11 +2517,12 @@ def rag_pipeline_ui(
|
|
| 2489 |
log("Rewrite:")
|
| 2490 |
log(rewritten if rewritten else "(empty)")
|
| 2491 |
log("")
|
| 2492 |
-
|
| 2493 |
-
rewrite_for_retrieval = rewritten
|
| 2494 |
-
|
| 2495 |
-
|
| 2496 |
-
|
|
|
|
| 2497 |
|
| 2498 |
|
| 2499 |
log("Step 2: Prompt Squirrel retrieval (hidden)")
|
|
|
|
| 78 |
import gradio as gr
|
| 79 |
_startup_profile_mark("import.gradio.done")
|
| 80 |
|
| 81 |
+
from psq_rag.pipeline.preproc import (
|
| 82 |
+
extract_exact_tag_query_phrases,
|
| 83 |
+
extract_user_provided_tags_upto_3_words,
|
| 84 |
+
)
|
| 85 |
_startup_profile_mark("import.psq_rag.pipeline.preproc.done")
|
| 86 |
from psq_rag.llm.rewrite import llm_rewrite_prompt
|
| 87 |
_startup_profile_mark("import.psq_rag.llm.rewrite.done")
|
|
|
|
| 96 |
get_tag_type_name,
|
| 97 |
get_tag_implications,
|
| 98 |
get_tag_counts,
|
| 99 |
+
get_alias2tags,
|
| 100 |
)
|
| 101 |
_startup_profile_mark("import.psq_rag.retrieval.state.done")
|
| 102 |
from psq_rag.ui.group_ranked_display import rank_groups_from_tfidf, _load_enabled_groups
|
|
|
|
| 1478 |
display_top_tags_per_group_default = int(os.environ.get("PSQ_DISPLAY_TOP_TAGS_PER_GROUP", "7"))
|
| 1479 |
display_rank_top_k_default = int(os.environ.get("PSQ_DISPLAY_GROUP_RANK_TOP_K", "7"))
|
| 1480 |
display_max_rows_default = int(os.environ.get("PSQ_DISPLAY_MAX_ROWS", "14"))
|
| 1481 |
+
retrieval_global_k = int(os.environ.get("PSQ_RETRIEVAL_GLOBAL_K", "300"))
|
| 1482 |
+
retrieval_per_phrase_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_K", "10"))
|
| 1483 |
retrieval_per_phrase_final_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_FINAL_K", "1"))
|
| 1484 |
+
retrieval_exact_ngram_max = int(os.environ.get("PSQ_RETRIEVAL_EXACT_NGRAM_MAX", "2"))
|
| 1485 |
selection_mode = os.environ.get("PSQ_SELECTION_MODE", "chunked_map_union").strip()
|
| 1486 |
selection_chunk_size = int(os.environ.get("PSQ_SELECTION_CHUNK_SIZE", "60"))
|
| 1487 |
selection_per_phrase_k = int(os.environ.get("PSQ_SELECTION_PER_PHRASE_K", "2"))
|
|
|
|
| 2365 |
f"retrieval_global_k={retrieval_global_k} "
|
| 2366 |
f"retrieval_per_phrase_k={retrieval_per_phrase_k} "
|
| 2367 |
f"retrieval_per_phrase_final_k={retrieval_per_phrase_final_k} "
|
| 2368 |
+
f"retrieval_exact_ngram_max={retrieval_exact_ngram_max} "
|
| 2369 |
f"selection_mode={selection_mode} "
|
| 2370 |
f"selection_chunk_size={selection_chunk_size} "
|
| 2371 |
f"selection_per_phrase_k={selection_per_phrase_k} "
|
|
|
|
| 2392 |
user_tags_raw = extract_user_provided_tags_upto_3_words(prompt_in)
|
| 2393 |
user_tags, removed_user_low = _filter_min_count_tags(user_tags_raw, min_tag_count)
|
| 2394 |
user_tags, removed_user_excluded = _filter_excluded_recommendation_tags(user_tags)
|
| 2395 |
+
exact_query_phrases = extract_exact_tag_query_phrases(
|
| 2396 |
+
prompt_in,
|
| 2397 |
+
get_tag_counts(),
|
| 2398 |
+
get_alias2tags(),
|
| 2399 |
+
min_tag_count=min_tag_count,
|
| 2400 |
+
max_ngram=max(0, retrieval_exact_ngram_max),
|
| 2401 |
+
)
|
| 2402 |
+
exact_query_phrases, removed_exact_excluded = _filter_excluded_recommendation_tags(exact_query_phrases)
|
| 2403 |
dt = time.perf_counter()-t0
|
| 2404 |
_record_timing("preprocess", dt)
|
| 2405 |
log(f"Preprocess (user tag extraction): {dt:.2f}s")
|
|
|
|
| 2418 |
f"Filtered {len(removed_user_excluded)} excluded user tags: "
|
| 2419 |
f"{', '.join(removed_user_excluded)}"
|
| 2420 |
)
|
| 2421 |
+
if retrieval_exact_ngram_max > 0:
|
| 2422 |
+
log(f"Exact caption tag query phrases (1-{retrieval_exact_ngram_max} grams):")
|
| 2423 |
+
else:
|
| 2424 |
+
log("Exact caption tag query phrases: disabled")
|
| 2425 |
+
if exact_query_phrases:
|
| 2426 |
+
shown = ", ".join(exact_query_phrases[:40])
|
| 2427 |
+
log(shown + (" ..." if len(exact_query_phrases) > 40 else ""))
|
| 2428 |
+
else:
|
| 2429 |
+
log("(none)")
|
| 2430 |
+
if removed_exact_excluded:
|
| 2431 |
+
log(
|
| 2432 |
+
f"Filtered {len(removed_exact_excluded)} excluded exact query phrases: "
|
| 2433 |
+
f"{', '.join(removed_exact_excluded)}"
|
| 2434 |
+
)
|
| 2435 |
log("")
|
| 2436 |
|
| 2437 |
rewrite_prefilled = (rewrite_override or "").strip()
|
|
|
|
| 2517 |
log("Rewrite:")
|
| 2518 |
log(rewritten if rewritten else "(empty)")
|
| 2519 |
log("")
|
| 2520 |
+
|
| 2521 |
+
rewrite_for_retrieval = rewritten
|
| 2522 |
+
retrieval_query_hints = list(dict.fromkeys((user_tags or []) + (exact_query_phrases or [])))
|
| 2523 |
+
if retrieval_query_hints:
|
| 2524 |
+
# keep them separate in logs, but allow them to help retrieval
|
| 2525 |
+
rewrite_for_retrieval = (rewrite_for_retrieval + ", " + ", ".join(retrieval_query_hints)).strip(", ").strip()
|
| 2526 |
|
| 2527 |
|
| 2528 |
log("Step 2: Prompt Squirrel retrieval (hidden)")
|
docs/rewrite_contract.md
CHANGED
|
@@ -76,6 +76,11 @@ Outside Stage 1 itself, `app.py` also computes heuristic short phrases via:
|
|
| 76 |
- split on `.` and `,`
|
| 77 |
- keep segments with <= 3 tokens
|
| 78 |
- case-insensitive dedupe
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
|
| 80 |
These heuristic terms are later appended to retrieval input only if rewrite succeeds.
|
| 81 |
|
|
|
|
| 76 |
- split on `.` and `,`
|
| 77 |
- keep segments with <= 3 tokens
|
| 78 |
- case-insensitive dedupe
|
| 79 |
+
- `extract_exact_tag_query_phrases()`
|
| 80 |
+
- scan prompt text for exact 1- to N-gram canonical tag or alias matches
|
| 81 |
+
- app default N is 2 (`PSQ_RETRIEVAL_EXACT_NGRAM_MAX`)
|
| 82 |
+
- matches must resolve to at least one canonical tag that clears `PSQ_MIN_TAG_COUNT`
|
| 83 |
+
- longest matches suppress their own component unigrams
|
| 84 |
|
| 85 |
These heuristic terms are later appended to retrieval input only if rewrite succeeds.
|
| 86 |
|
psq_rag/pipeline/preproc.py
CHANGED
|
@@ -1,6 +1,11 @@
|
|
| 1 |
-
import re
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
"""
|
| 5 |
Heuristic:
|
| 6 |
- split on '.' and ','
|
|
@@ -27,10 +32,69 @@ def extract_user_provided_tags_upto_3_words(prompt_in: str) -> list[str]:
|
|
| 27 |
if key not in seen:
|
| 28 |
seen.add(key)
|
| 29 |
out.append(item)
|
| 30 |
-
|
| 31 |
-
return out
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from typing import Mapping, Sequence
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
_TOKEN_RE = re.compile(r"[a-z0-9]+(?:'[a-z0-9]+)?")
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def extract_user_provided_tags_upto_3_words(prompt_in: str) -> list[str]:
|
| 9 |
"""
|
| 10 |
Heuristic:
|
| 11 |
- split on '.' and ','
|
|
|
|
| 32 |
if key not in seen:
|
| 33 |
seen.add(key)
|
| 34 |
out.append(item)
|
| 35 |
+
|
| 36 |
+
return out
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def extract_exact_tag_query_phrases(
    prompt_in: str,
    tag_counts: Mapping[str, int],
    alias2tags: Mapping[str, Sequence[str]],
    *,
    min_tag_count: int = 0,
    max_ngram: int = 2,
) -> list[str]:
    """Extract exact canonical/alias n-gram matches as retrieval query phrases.

    The prompt, minus an optional leading ``caption_to_tags:`` task prefix, is
    lower-cased and tokenized with ``_TOKEN_RE``. Every run of 1 to
    ``max_ngram`` consecutive tokens is joined with ``_`` and looked up.

    A lookup matches when it is a key of ``tag_counts`` whose count clears
    ``min_tag_count``, or, when it is not a key of ``tag_counts``, when
    ``alias2tags`` maps it to at least one tag that clears the floor.

    Longest matches win: n-grams are considered from longest to shortest and a
    match is dropped if it overlaps tokens already claimed by an earlier one.
    Every non-overlapping occurrence claims its tokens, including repeats of a
    phrase that was already emitted, so a repeated 2-gram still suppresses its
    own component 1-grams.

    Args:
        prompt_in: Raw prompt text; an empty or falsy value yields ``[]``.
        tag_counts: Canonical tag -> occurrence count.
        alias2tags: Alias -> canonical tags it resolves to.
        min_tag_count: Minimum count a tag needs; values <= 0 disable the floor.
        max_ngram: Longest n-gram to try; values <= 0 disable extraction.

    Returns:
        De-duplicated, underscore-joined lower-case lookups, ordered by the
        position in the prompt of the occurrence that was selected.
    """
    if not prompt_in or max_ngram <= 0:
        return []

    text = prompt_in.strip()
    prefix = "caption_to_tags:"
    if text.lower().startswith(prefix):
        text = text[len(prefix):].strip()

    tokens = _TOKEN_RE.findall(text.lower())
    if not tokens:
        return []

    def _count_ok(tag: str) -> bool:
        if min_tag_count <= 0:
            return True
        # "or 0" guards against a None/falsy count stored in the mapping.
        return int(tag_counts.get(tag, 0) or 0) >= min_tag_count

    def _resolves(lookup: str) -> bool:
        # A canonical hit is judged on its own count; aliases are consulted
        # only when the lookup is not itself a canonical tag.
        if lookup in tag_counts:
            return _count_ok(lookup)
        return any(_count_ok(tag) for tag in alias2tags.get(lookup, ()))

    # Collect candidates longest-first so the greedy pass below prefers them.
    matches: list[tuple[int, int, str]] = []
    max_n = min(max(1, int(max_ngram)), len(tokens))
    for n in range(max_n, 0, -1):
        for start in range(len(tokens) - n + 1):
            lookup = "_".join(tokens[start:start + n])
            if _resolves(lookup):
                matches.append((start, start + n, lookup))

    claimed: set[int] = set()
    emitted: set[str] = set()
    selected: list[tuple[int, str]] = []
    for start, end, lookup in matches:
        span = range(start, end)
        if not claimed.isdisjoint(span):
            continue
        # Claim the tokens before the duplicate check: a repeated phrase must
        # still block the shorter n-grams it is made of.
        claimed.update(span)
        if lookup in emitted:
            continue
        emitted.add(lookup)
        selected.append((start, lookup))

    selected.sort(key=lambda row: row[0])
    return [lookup for _, lookup in selected]
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
if __name__ == "__main__":
    # Import smoke check: reaching this line means the module loaded cleanly.
    print("preproc.py imports ok")
|
scripts/test_exact_tag_query_phrases.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import sys
|
| 3 |
+
|
| 4 |
+
repo_root = Path(__file__).resolve().parents[1]
|
| 5 |
+
sys.path.insert(0, str(repo_root))
|
| 6 |
+
|
| 7 |
+
from psq_rag.pipeline.preproc import extract_exact_tag_query_phrases
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def assert_equal(actual, expected, message):
    """Raise AssertionError labelled with ``message`` when the values differ."""
    differs = actual != expected
    if not differs:
        return
    raise AssertionError(f"{message}: expected {expected!r}, got {actual!r}")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def assert_in(item, values, message):
    """Raise AssertionError labelled with ``message`` unless ``item`` is in ``values``."""
    present = item in values
    if not present:
        raise AssertionError(f"{message}: {item!r} not in {values!r}")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def test_longest_match_suppresses_component_unigrams():
    """A matched 2-gram must hide the 1-grams it is built from."""
    counts = dict(red=1000, fox=1000, red_fox=300, burrito=164)
    no_aliases = {}
    got = extract_exact_tag_query_phrases(
        "A red fox eating a giant burrito",
        counts,
        no_aliases,
        min_tag_count=100,
        max_ngram=2,
    )
    want = ["red_fox", "burrito"]
    assert_equal(got, want, "2-gram should suppress its component 1-grams")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def test_alias_resolution_uses_target_count_floor():
    """An alias is emitted only if one of its target tags clears the floor."""
    # "hotdog" clears the floor of 100; "low_count_tag" sits just below it.
    counts = dict(hotdog=150, low_count_tag=99)
    aliases = dict(hot_dog=["hotdog"], rare_alias=["low_count_tag"])
    got = extract_exact_tag_query_phrases(
        "A hot dog and rare alias",
        counts,
        aliases,
        min_tag_count=100,
        max_ngram=2,
    )
    want = ["hot_dog"]
    assert_equal(got, want, "alias phrase should emit only when a target clears min count")
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def test_caption_prefix_is_ignored():
    """The ``caption_to_tags:`` task prefix must not be mined for tags."""
    # "caption" is a valid tag here, so it would match if the prefix leaked in.
    counts = dict(caption=1000, red_fox=300)
    got = extract_exact_tag_query_phrases(
        "caption_to_tags: red fox",
        counts,
        {},
        min_tag_count=100,
        max_ngram=2,
    )
    want = ["red_fox"]
    assert_equal(got, want, "task prefix should not contribute tag query phrases")
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def test_real_assets_find_burrito_and_retrieve_it():
    """End-to-end check: extract phrases with the state-provided maps, then retrieve."""
    # Imported here so the other tests in this script run without these modules.
    from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases
    from psq_rag.retrieval.state import get_alias2tags, get_tag_counts

    counts = get_tag_counts()
    found_phrases = extract_exact_tag_query_phrases(
        "A red fox eating a giant burrito",
        counts,
        get_alias2tags(),
        min_tag_count=100,
        max_ngram=2,
    )
    assert_in("red_fox", found_phrases, "real asset extraction should find red_fox")
    assert_in("burrito", found_phrases, "real asset extraction should find burrito")

    retrieval_kwargs = dict(
        rewrite_phrases=found_phrases,
        allow_nsfw_tags=False,
        min_tag_count=100,
        per_phrase_k=10,
        per_phrase_final_k=1,
        global_k=300,
    )
    retrieved = psq_candidates_from_rewrite_phrases(**retrieval_kwargs)
    retrieved_tags = set()
    for cand in retrieved:
        retrieved_tags.add(cand.tag)
    assert_in("burrito", retrieved_tags, "exact burrito query phrase should retrieve burrito")
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def main():
    """Run every check in order, then report success."""
    checks = (
        test_longest_match_suppresses_component_unigrams,
        test_alias_resolution_uses_target_count_floor,
        test_caption_prefix_is_ignored,
        test_real_assets_find_burrito_and_retrieve_it,
    )
    for check in checks:
        # Any AssertionError propagates and stops the run before "PASS".
        check()
    print("exact tag query phrase tests: PASS")


# Run the checks only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|