Spaces:
Running
Running
Food Desert commited on
Commit ·
a48a025
1
Parent(s): 334af6b
Add synchronized lego-style tag toggles and prompt builder UI
Browse files
app.py
CHANGED
|
@@ -6,7 +6,7 @@ import json
|
|
| 6 |
from datetime import datetime
|
| 7 |
from PIL import Image
|
| 8 |
from pathlib import Path
|
| 9 |
-
from typing import List
|
| 10 |
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
|
| 11 |
|
| 12 |
from psq_rag.pipeline.preproc import extract_user_provided_tags_upto_3_words
|
|
@@ -14,7 +14,7 @@ from psq_rag.llm.rewrite import llm_rewrite_prompt
|
|
| 14 |
from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases, _norm_tag_for_lookup
|
| 15 |
from psq_rag.llm.select import llm_select_indices, llm_infer_structural_tags, llm_infer_probe_tags
|
| 16 |
from psq_rag.retrieval.state import expand_tags_via_implications
|
| 17 |
-
from psq_rag.ui.group_ranked_display import
|
| 18 |
|
| 19 |
|
| 20 |
def _split_prompt_commas(s: str) -> List[str]:
|
|
@@ -40,6 +40,181 @@ def compose_final_prompt(rewritten_prompt: str, selected_tags: List[str]) -> str
|
|
| 40 |
return ", ".join(out)
|
| 41 |
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
def _build_selection_query(
|
| 44 |
prompt_in: str,
|
| 45 |
rewritten: str,
|
|
@@ -152,6 +327,7 @@ enable_probe_tags = os.environ.get("PSQ_ENABLE_PROBE", "1").strip() not in {"0",
|
|
| 152 |
display_top_groups_default = int(os.environ.get("PSQ_DISPLAY_TOP_GROUPS", "10"))
|
| 153 |
display_top_tags_per_group_default = int(os.environ.get("PSQ_DISPLAY_TOP_TAGS_PER_GROUP", "5"))
|
| 154 |
display_rank_top_k_default = int(os.environ.get("PSQ_DISPLAY_GROUP_RANK_TOP_K", "5"))
|
|
|
|
| 155 |
retrieval_global_k = int(os.environ.get("PSQ_RETRIEVAL_GLOBAL_K", "300"))
|
| 156 |
retrieval_per_phrase_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_K", "10"))
|
| 157 |
retrieval_per_phrase_final_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_FINAL_K", "1"))
|
|
@@ -189,6 +365,42 @@ css = """
|
|
| 189 |
.pane-right .scrollable-content {
|
| 190 |
max-height: 610px; /* was 420px; tweak to taste */
|
| 191 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
"""
|
| 193 |
|
| 194 |
|
|
@@ -275,7 +487,12 @@ def rag_pipeline_ui(
|
|
| 275 |
log("Start: received prompt")
|
| 276 |
prompt_in = (user_prompt or "").strip()
|
| 277 |
if not prompt_in:
|
| 278 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
|
| 280 |
log("Input:")
|
| 281 |
log(prompt_in)
|
|
@@ -439,6 +656,8 @@ def rag_pipeline_ui(
|
|
| 439 |
elif enable_probe_tags:
|
| 440 |
log(" No probe tags inferred")
|
| 441 |
|
|
|
|
|
|
|
| 442 |
log("Step 3c: Expand via tag implications")
|
| 443 |
t0 = time.perf_counter()
|
| 444 |
tag_set = set(selected_tags)
|
|
@@ -469,25 +688,36 @@ def rag_pipeline_ui(
|
|
| 469 |
seed_terms.extend(selected_tags)
|
| 470 |
seed_terms = list(dict.fromkeys(seed_terms))
|
| 471 |
|
| 472 |
-
|
| 473 |
seed_terms=seed_terms,
|
|
|
|
| 474 |
top_groups=max(1, int(display_top_groups)),
|
| 475 |
top_tags_per_group=max(1, int(display_top_tags_per_group)),
|
| 476 |
group_rank_top_k=max(1, int(display_rank_top_k)),
|
| 477 |
)
|
| 478 |
dt = time.perf_counter()-t0
|
| 479 |
_record_timing("group_display", dt)
|
| 480 |
-
log(f"Ranked group display: {dt:.2f}s")
|
| 481 |
|
| 482 |
total_dt = time.perf_counter()-t_total0
|
| 483 |
_emit_timing_summary(total_dt)
|
| 484 |
_append_timing_jsonl(total_dt)
|
| 485 |
log("Done: final prompt ready")
|
| 486 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 487 |
|
| 488 |
except Exception as e:
|
| 489 |
log(f"Error: {type(e).__name__}: {e}")
|
| 490 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 491 |
|
| 492 |
|
| 493 |
|
|
@@ -529,11 +759,44 @@ then returns a cleaned, model-friendly prompt.
|
|
| 529 |
placeholder="Progress logs will appear here."
|
| 530 |
)
|
| 531 |
|
| 532 |
-
|
| 533 |
-
label="
|
| 534 |
lines=3,
|
| 535 |
interactive=False,
|
| 536 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 537 |
)
|
| 538 |
|
| 539 |
with gr.Accordion("Display Settings", open=False):
|
|
@@ -557,22 +820,42 @@ then returns a cleaned, model-friendly prompt.
|
|
| 557 |
minimum=1,
|
| 558 |
)
|
| 559 |
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 564 |
|
| 565 |
submit_button.click(
|
| 566 |
rag_pipeline_ui,
|
| 567 |
inputs=[image_tags, display_top_groups, display_top_tags_per_group, display_rank_top_k],
|
| 568 |
-
outputs=
|
| 569 |
)
|
| 570 |
|
| 571 |
image_tags.submit(
|
| 572 |
rag_pipeline_ui,
|
| 573 |
inputs=[image_tags, display_top_groups, display_top_tags_per_group, display_rank_top_k],
|
| 574 |
-
outputs=
|
| 575 |
)
|
| 576 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 577 |
if __name__ == "__main__":
|
| 578 |
app.queue().launch(allowed_paths=[str(MASCOT_DIR)])
|
|
|
|
| 6 |
from datetime import datetime
|
| 7 |
from PIL import Image
|
| 8 |
from pathlib import Path
|
| 9 |
+
from typing import Any, Dict, List, Set
|
| 10 |
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
|
| 11 |
|
| 12 |
from psq_rag.pipeline.preproc import extract_user_provided_tags_upto_3_words
|
|
|
|
| 14 |
from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases, _norm_tag_for_lookup
|
| 15 |
from psq_rag.llm.select import llm_select_indices, llm_infer_structural_tags, llm_infer_probe_tags
|
| 16 |
from psq_rag.retrieval.state import expand_tags_via_implications
|
| 17 |
+
from psq_rag.ui.group_ranked_display import rank_groups_from_tfidf, _load_enabled_groups
|
| 18 |
|
| 19 |
|
| 20 |
def _split_prompt_commas(s: str) -> List[str]:
|
|
|
|
| 40 |
return ", ".join(out)
|
| 41 |
|
| 42 |
|
| 43 |
+
def _display_tag_text(tag: str) -> str:
|
| 44 |
+
return tag.replace("_", " ")
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _escape_prompt_tag(tag: str) -> str:
|
| 48 |
+
return (
|
| 49 |
+
tag.replace("_", " ")
|
| 50 |
+
.replace("(", "\\(")
|
| 51 |
+
.replace(")", "\\)")
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def _ordered_selected_for_prompt(selected: Set[str], row_defs: List[Dict[str, Any]]) -> List[str]:
|
| 56 |
+
out: List[str] = []
|
| 57 |
+
seen: Set[str] = set()
|
| 58 |
+
for row in row_defs:
|
| 59 |
+
for tag in row.get("tags", []):
|
| 60 |
+
if tag in selected and tag not in seen:
|
| 61 |
+
out.append(tag)
|
| 62 |
+
seen.add(tag)
|
| 63 |
+
# Fallback for any selected tags not present in current rows.
|
| 64 |
+
for tag in sorted(selected):
|
| 65 |
+
if tag not in seen:
|
| 66 |
+
out.append(tag)
|
| 67 |
+
seen.add(tag)
|
| 68 |
+
return out
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _compose_toggle_prompt_text(selected_tags: List[str], row_defs: List[Dict[str, Any]]) -> str:
|
| 72 |
+
selected = {t for t in (selected_tags or []) if t}
|
| 73 |
+
ordered = _ordered_selected_for_prompt(selected, row_defs or [])
|
| 74 |
+
return ", ".join(_escape_prompt_tag(t) for t in ordered)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def _build_toggle_rows(
|
| 78 |
+
*,
|
| 79 |
+
seed_terms: List[str],
|
| 80 |
+
llm_selected_tags: List[str],
|
| 81 |
+
top_groups: int,
|
| 82 |
+
top_tags_per_group: int,
|
| 83 |
+
group_rank_top_k: int,
|
| 84 |
+
) -> List[Dict[str, Any]]:
|
| 85 |
+
ranked_rows = rank_groups_from_tfidf(
|
| 86 |
+
seed_terms=seed_terms,
|
| 87 |
+
top_groups=max(1, int(top_groups)),
|
| 88 |
+
top_tags_per_group=max(1, int(top_tags_per_group)),
|
| 89 |
+
group_rank_top_k=max(1, int(group_rank_top_k)),
|
| 90 |
+
)
|
| 91 |
+
groups_map = _load_enabled_groups()
|
| 92 |
+
llm_selected = list(dict.fromkeys(_norm_tag_for_lookup(t) for t in llm_selected_tags if t))
|
| 93 |
+
|
| 94 |
+
row_defs: List[Dict[str, Any]] = []
|
| 95 |
+
displayed_group_names = [r.group_name for r in ranked_rows]
|
| 96 |
+
displayed_group_tag_sets: Dict[str, Set[str]] = {
|
| 97 |
+
name: set(groups_map.get(name, [])) for name in displayed_group_names
|
| 98 |
+
}
|
| 99 |
+
tags_in_any_displayed_group: Set[str] = set()
|
| 100 |
+
for tag_set in displayed_group_tag_sets.values():
|
| 101 |
+
tags_in_any_displayed_group.update(tag_set)
|
| 102 |
+
|
| 103 |
+
llm_other = [t for t in llm_selected if t not in tags_in_any_displayed_group]
|
| 104 |
+
row_defs.append(
|
| 105 |
+
{
|
| 106 |
+
"name": "llm_selected_other",
|
| 107 |
+
"label": "LLM Selected (Other)",
|
| 108 |
+
"tags": llm_other,
|
| 109 |
+
}
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
for row in ranked_rows:
|
| 113 |
+
group_name = row.group_name
|
| 114 |
+
group_tag_set = displayed_group_tag_sets.get(group_name, set())
|
| 115 |
+
selected_in_group = [t for t in llm_selected if t in group_tag_set]
|
| 116 |
+
ranked_tags = [t for t, _ in row.tags]
|
| 117 |
+
merged = selected_in_group + [t for t in ranked_tags if t not in selected_in_group]
|
| 118 |
+
keep_n = max(max(1, int(top_tags_per_group)), len(selected_in_group))
|
| 119 |
+
merged = merged[:keep_n]
|
| 120 |
+
row_defs.append(
|
| 121 |
+
{
|
| 122 |
+
"name": group_name,
|
| 123 |
+
"label": f"{group_name} (E={row.expected_count:.2f})",
|
| 124 |
+
"tags": merged,
|
| 125 |
+
}
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
return row_defs
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def _build_row_component_updates(
|
| 132 |
+
row_defs: List[Dict[str, Any]],
|
| 133 |
+
selected_tags: List[str],
|
| 134 |
+
max_rows: int,
|
| 135 |
+
):
|
| 136 |
+
selected = {t for t in (selected_tags or []) if t}
|
| 137 |
+
row_values_state: List[List[str]] = []
|
| 138 |
+
header_updates = []
|
| 139 |
+
checkbox_updates = []
|
| 140 |
+
|
| 141 |
+
for idx in range(max_rows):
|
| 142 |
+
if idx < len(row_defs):
|
| 143 |
+
row = row_defs[idx]
|
| 144 |
+
tags = list(dict.fromkeys(row.get("tags", [])))
|
| 145 |
+
values = [t for t in tags if t in selected]
|
| 146 |
+
row_values_state.append(values)
|
| 147 |
+
visible = bool(tags)
|
| 148 |
+
header_updates.append(gr.update(value=f"**{row.get('label', '')}**", visible=visible))
|
| 149 |
+
choices = [(_display_tag_text(t), t) for t in tags]
|
| 150 |
+
checkbox_updates.append(
|
| 151 |
+
gr.update(
|
| 152 |
+
choices=choices,
|
| 153 |
+
value=values,
|
| 154 |
+
visible=visible,
|
| 155 |
+
)
|
| 156 |
+
)
|
| 157 |
+
else:
|
| 158 |
+
header_updates.append(gr.update(value="", visible=False))
|
| 159 |
+
checkbox_updates.append(gr.update(choices=[], value=[], visible=False))
|
| 160 |
+
|
| 161 |
+
prompt_text = _compose_toggle_prompt_text(list(selected), row_defs)
|
| 162 |
+
return prompt_text, row_values_state, header_updates, checkbox_updates
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def _on_toggle_row(
|
| 166 |
+
row_idx: int,
|
| 167 |
+
changed_values: List[str],
|
| 168 |
+
selected_tags_state: List[str],
|
| 169 |
+
row_defs_state: List[Dict[str, Any]],
|
| 170 |
+
row_values_state: List[List[str]],
|
| 171 |
+
max_rows: int,
|
| 172 |
+
):
|
| 173 |
+
row_defs = row_defs_state or []
|
| 174 |
+
selected = set(selected_tags_state or [])
|
| 175 |
+
prev_values = list(row_values_state or [])
|
| 176 |
+
|
| 177 |
+
while len(prev_values) < len(row_defs):
|
| 178 |
+
prev_values.append([])
|
| 179 |
+
|
| 180 |
+
prev_set = set(prev_values[row_idx]) if row_idx < len(prev_values) else set()
|
| 181 |
+
new_set = set(changed_values or [])
|
| 182 |
+
selected.update(new_set - prev_set)
|
| 183 |
+
selected.difference_update(prev_set - new_set)
|
| 184 |
+
|
| 185 |
+
prompt_text, new_row_values_state, _header_updates, checkbox_updates = _build_row_component_updates(
|
| 186 |
+
row_defs=row_defs,
|
| 187 |
+
selected_tags=list(selected),
|
| 188 |
+
max_rows=max_rows,
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
return [sorted(selected), new_row_values_state, prompt_text, *checkbox_updates]
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def _build_ui_payload(
|
| 195 |
+
*,
|
| 196 |
+
console_text: str,
|
| 197 |
+
legacy_prompt_text: str,
|
| 198 |
+
row_defs: List[Dict[str, Any]],
|
| 199 |
+
selected_tags: List[str],
|
| 200 |
+
):
|
| 201 |
+
prompt_text, row_values_state, header_updates, checkbox_updates = _build_row_component_updates(
|
| 202 |
+
row_defs=row_defs,
|
| 203 |
+
selected_tags=selected_tags,
|
| 204 |
+
max_rows=display_max_rows_default,
|
| 205 |
+
)
|
| 206 |
+
return [
|
| 207 |
+
console_text,
|
| 208 |
+
legacy_prompt_text,
|
| 209 |
+
prompt_text,
|
| 210 |
+
sorted(set(selected_tags or [])),
|
| 211 |
+
row_defs,
|
| 212 |
+
row_values_state,
|
| 213 |
+
*header_updates,
|
| 214 |
+
*checkbox_updates,
|
| 215 |
+
]
|
| 216 |
+
|
| 217 |
+
|
| 218 |
def _build_selection_query(
|
| 219 |
prompt_in: str,
|
| 220 |
rewritten: str,
|
|
|
|
| 327 |
display_top_groups_default = int(os.environ.get("PSQ_DISPLAY_TOP_GROUPS", "10"))
|
| 328 |
display_top_tags_per_group_default = int(os.environ.get("PSQ_DISPLAY_TOP_TAGS_PER_GROUP", "5"))
|
| 329 |
display_rank_top_k_default = int(os.environ.get("PSQ_DISPLAY_GROUP_RANK_TOP_K", "5"))
|
| 330 |
+
display_max_rows_default = int(os.environ.get("PSQ_DISPLAY_MAX_ROWS", "14"))
|
| 331 |
retrieval_global_k = int(os.environ.get("PSQ_RETRIEVAL_GLOBAL_K", "300"))
|
| 332 |
retrieval_per_phrase_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_K", "10"))
|
| 333 |
retrieval_per_phrase_final_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_FINAL_K", "1"))
|
|
|
|
| 365 |
.pane-right .scrollable-content {
|
| 366 |
max-height: 610px; /* was 420px; tweak to taste */
|
| 367 |
}
|
| 368 |
+
|
| 369 |
+
.lego-tags .gr-checkboxgroup {
|
| 370 |
+
display: flex;
|
| 371 |
+
flex-wrap: wrap;
|
| 372 |
+
gap: 8px;
|
| 373 |
+
}
|
| 374 |
+
|
| 375 |
+
.lego-tags .gr-checkboxgroup label {
|
| 376 |
+
margin: 0;
|
| 377 |
+
padding: 0;
|
| 378 |
+
}
|
| 379 |
+
|
| 380 |
+
.lego-tags .gr-checkboxgroup input[type="checkbox"] {
|
| 381 |
+
display: none;
|
| 382 |
+
}
|
| 383 |
+
|
| 384 |
+
.lego-tags .gr-checkboxgroup span {
|
| 385 |
+
display: inline-block;
|
| 386 |
+
padding: 7px 12px;
|
| 387 |
+
border: 1px solid #8a8a8a;
|
| 388 |
+
border-radius: 10px;
|
| 389 |
+
background: #f4f4f4;
|
| 390 |
+
color: #222;
|
| 391 |
+
font-size: 0.95rem;
|
| 392 |
+
line-height: 1.2;
|
| 393 |
+
cursor: pointer;
|
| 394 |
+
user-select: none;
|
| 395 |
+
box-shadow: 0 1px 0 rgba(0,0,0,0.12), inset 0 1px 0 rgba(255,255,255,0.7);
|
| 396 |
+
}
|
| 397 |
+
|
| 398 |
+
.lego-tags .gr-checkboxgroup input[type="checkbox"]:checked + span {
|
| 399 |
+
background: #ffd86a;
|
| 400 |
+
border-color: #c49a00;
|
| 401 |
+
box-shadow: 0 2px 0 #a98000, inset 0 1px 0 rgba(255,255,255,0.65);
|
| 402 |
+
transform: translateY(1px);
|
| 403 |
+
}
|
| 404 |
"""
|
| 405 |
|
| 406 |
|
|
|
|
| 487 |
log("Start: received prompt")
|
| 488 |
prompt_in = (user_prompt or "").strip()
|
| 489 |
if not prompt_in:
|
| 490 |
+
return _build_ui_payload(
|
| 491 |
+
console_text="Error: empty prompt",
|
| 492 |
+
legacy_prompt_text="",
|
| 493 |
+
row_defs=[],
|
| 494 |
+
selected_tags=[],
|
| 495 |
+
)
|
| 496 |
|
| 497 |
log("Input:")
|
| 498 |
log(prompt_in)
|
|
|
|
| 656 |
elif enable_probe_tags:
|
| 657 |
log(" No probe tags inferred")
|
| 658 |
|
| 659 |
+
llm_selected_tags = list(dict.fromkeys(selected_tags))
|
| 660 |
+
|
| 661 |
log("Step 3c: Expand via tag implications")
|
| 662 |
t0 = time.perf_counter()
|
| 663 |
tag_set = set(selected_tags)
|
|
|
|
| 688 |
seed_terms.extend(selected_tags)
|
| 689 |
seed_terms = list(dict.fromkeys(seed_terms))
|
| 690 |
|
| 691 |
+
toggle_rows = _build_toggle_rows(
|
| 692 |
seed_terms=seed_terms,
|
| 693 |
+
llm_selected_tags=llm_selected_tags,
|
| 694 |
top_groups=max(1, int(display_top_groups)),
|
| 695 |
top_tags_per_group=max(1, int(display_top_tags_per_group)),
|
| 696 |
group_rank_top_k=max(1, int(display_rank_top_k)),
|
| 697 |
)
|
| 698 |
dt = time.perf_counter()-t0
|
| 699 |
_record_timing("group_display", dt)
|
| 700 |
+
log(f"Ranked group display: {dt:.2f}s ({len(toggle_rows)} rows)")
|
| 701 |
|
| 702 |
total_dt = time.perf_counter()-t_total0
|
| 703 |
_emit_timing_summary(total_dt)
|
| 704 |
_append_timing_jsonl(total_dt)
|
| 705 |
log("Done: final prompt ready")
|
| 706 |
+
return _build_ui_payload(
|
| 707 |
+
console_text="\n".join(logs),
|
| 708 |
+
legacy_prompt_text=final_prompt,
|
| 709 |
+
row_defs=toggle_rows,
|
| 710 |
+
selected_tags=llm_selected_tags,
|
| 711 |
+
)
|
| 712 |
|
| 713 |
except Exception as e:
|
| 714 |
log(f"Error: {type(e).__name__}: {e}")
|
| 715 |
+
return _build_ui_payload(
|
| 716 |
+
console_text="\n".join(logs),
|
| 717 |
+
legacy_prompt_text="",
|
| 718 |
+
row_defs=[],
|
| 719 |
+
selected_tags=[],
|
| 720 |
+
)
|
| 721 |
|
| 722 |
|
| 723 |
|
|
|
|
| 759 |
placeholder="Progress logs will appear here."
|
| 760 |
)
|
| 761 |
|
| 762 |
+
suggested_prompt = gr.Textbox(
|
| 763 |
+
label="Suggested Prompt (From Toggled Tags)",
|
| 764 |
lines=3,
|
| 765 |
interactive=False,
|
| 766 |
+
show_copy_button=True,
|
| 767 |
+
placeholder="Comma-separated tags selected in the rows below."
|
| 768 |
+
)
|
| 769 |
+
|
| 770 |
+
with gr.Accordion("Legacy Pipeline Prompt (for reference)", open=False):
|
| 771 |
+
legacy_final_prompt = gr.Textbox(
|
| 772 |
+
label="Legacy Final Prompt",
|
| 773 |
+
lines=3,
|
| 774 |
+
interactive=False,
|
| 775 |
+
show_copy_button=True,
|
| 776 |
+
)
|
| 777 |
+
|
| 778 |
+
selected_tags_state = gr.State([])
|
| 779 |
+
row_defs_state = gr.State([])
|
| 780 |
+
row_values_state = gr.State([])
|
| 781 |
+
|
| 782 |
+
gr.Markdown("### Toggle Tag Rows")
|
| 783 |
+
row_headers: List[gr.Markdown] = []
|
| 784 |
+
row_checkboxes: List[gr.CheckboxGroup] = []
|
| 785 |
+
for _ in range(display_max_rows_default):
|
| 786 |
+
row_headers.append(gr.Markdown(value="", visible=False))
|
| 787 |
+
row_checkboxes.append(
|
| 788 |
+
gr.CheckboxGroup(
|
| 789 |
+
choices=[],
|
| 790 |
+
value=[],
|
| 791 |
+
visible=False,
|
| 792 |
+
interactive=True,
|
| 793 |
+
container=False,
|
| 794 |
+
elem_classes=["lego-tags"],
|
| 795 |
+
)
|
| 796 |
+
)
|
| 797 |
+
|
| 798 |
+
gr.Markdown(
|
| 799 |
+
"Toggling a tag in any row toggles it everywhere else that tag appears."
|
| 800 |
)
|
| 801 |
|
| 802 |
with gr.Accordion("Display Settings", open=False):
|
|
|
|
| 820 |
minimum=1,
|
| 821 |
)
|
| 822 |
|
| 823 |
+
run_outputs = [
|
| 824 |
+
console,
|
| 825 |
+
legacy_final_prompt,
|
| 826 |
+
suggested_prompt,
|
| 827 |
+
selected_tags_state,
|
| 828 |
+
row_defs_state,
|
| 829 |
+
row_values_state,
|
| 830 |
+
*row_headers,
|
| 831 |
+
*row_checkboxes,
|
| 832 |
+
]
|
| 833 |
|
| 834 |
submit_button.click(
|
| 835 |
rag_pipeline_ui,
|
| 836 |
inputs=[image_tags, display_top_groups, display_top_tags_per_group, display_rank_top_k],
|
| 837 |
+
outputs=run_outputs
|
| 838 |
)
|
| 839 |
|
| 840 |
image_tags.submit(
|
| 841 |
rag_pipeline_ui,
|
| 842 |
inputs=[image_tags, display_top_groups, display_top_tags_per_group, display_rank_top_k],
|
| 843 |
+
outputs=run_outputs
|
| 844 |
)
|
| 845 |
|
| 846 |
+
for idx, row_cb in enumerate(row_checkboxes):
|
| 847 |
+
row_cb.change(
|
| 848 |
+
fn=lambda changed_values, selected_state, row_defs, row_values, i=idx: _on_toggle_row(
|
| 849 |
+
i,
|
| 850 |
+
changed_values,
|
| 851 |
+
selected_state,
|
| 852 |
+
row_defs,
|
| 853 |
+
row_values,
|
| 854 |
+
display_max_rows_default,
|
| 855 |
+
),
|
| 856 |
+
inputs=[row_cb, selected_tags_state, row_defs_state, row_values_state],
|
| 857 |
+
outputs=[selected_tags_state, row_values_state, suggested_prompt, *row_checkboxes],
|
| 858 |
+
)
|
| 859 |
+
|
| 860 |
if __name__ == "__main__":
|
| 861 |
app.queue().launch(allowed_paths=[str(MASCOT_DIR)])
|