Spaces:
Running
Running
Food Desert commited on
Commit ·
827e786
1
Parent(s): 6e50f4d
Fix UI tag-button desync and add regression smoke coverage
Browse files- app.py +321 -185
- scripts/smoke_ui_state.py +196 -0
app.py
CHANGED
|
@@ -4,7 +4,6 @@ import logging
|
|
| 4 |
import time
|
| 5 |
import json
|
| 6 |
import csv
|
| 7 |
-
import base64
|
| 8 |
from datetime import datetime
|
| 9 |
from functools import lru_cache
|
| 10 |
from PIL import Image
|
|
@@ -68,53 +67,9 @@ def _normalize_selection_origin(origin: str) -> str:
|
|
| 68 |
return "selection"
|
| 69 |
|
| 70 |
|
| 71 |
-
@lru_cache(maxsize=1)
|
| 72 |
-
def _load_tag_wiki_defs() -> Dict[str, str]:
|
| 73 |
-
p = Path("data/tag_wiki_defs.json")
|
| 74 |
-
if not p.exists():
|
| 75 |
-
return {}
|
| 76 |
-
try:
|
| 77 |
-
with p.open("r", encoding="utf-8") as f:
|
| 78 |
-
data = json.load(f)
|
| 79 |
-
out: Dict[str, str] = {}
|
| 80 |
-
if isinstance(data, dict):
|
| 81 |
-
for k, v in data.items():
|
| 82 |
-
tag = _norm_tag_for_lookup(str(k))
|
| 83 |
-
text = " ".join(str(v or "").split())
|
| 84 |
-
if tag and text:
|
| 85 |
-
out[tag] = text
|
| 86 |
-
return out
|
| 87 |
-
except Exception:
|
| 88 |
-
return {}
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
def _tooltip_text_for_tag(tag: str) -> str:
|
| 92 |
-
t = _norm_tag_for_lookup(tag)
|
| 93 |
-
parts: List[str] = []
|
| 94 |
-
|
| 95 |
-
try:
|
| 96 |
-
count = get_tag_counts().get(t)
|
| 97 |
-
except Exception:
|
| 98 |
-
count = None
|
| 99 |
-
if isinstance(count, int):
|
| 100 |
-
parts.append(f"Count: {count:,}")
|
| 101 |
-
|
| 102 |
-
d = _load_tag_wiki_defs().get(t, "")
|
| 103 |
-
if d:
|
| 104 |
-
parts.append(d)
|
| 105 |
-
|
| 106 |
-
return "\n".join(parts).strip()
|
| 107 |
-
|
| 108 |
-
|
| 109 |
def _choice_label_with_source_meta(tag: str, *, origin: str, preselected: bool) -> str:
|
| 110 |
-
#
|
| 111 |
-
|
| 112 |
-
pre = "1" if preselected else "0"
|
| 113 |
-
tooltip = _tooltip_text_for_tag(tag)
|
| 114 |
-
tip_b64 = ""
|
| 115 |
-
if tooltip:
|
| 116 |
-
tip_b64 = base64.urlsafe_b64encode(tooltip.encode("utf-8")).decode("ascii")
|
| 117 |
-
return f"{_display_tag_text(tag)} [[psq:{origin_norm}:{pre}:{tip_b64}]]"
|
| 118 |
|
| 119 |
|
| 120 |
def _selection_source_rank(origin: str) -> int:
|
|
@@ -219,20 +174,15 @@ def _escape_prompt_tag(tag: str) -> str:
|
|
| 219 |
)
|
| 220 |
|
| 221 |
|
| 222 |
-
def _ordered_selected_for_prompt(selected: Set[str], row_defs: List[Dict[str, Any]]) -> List[str]:
|
| 223 |
-
out: List[str] = []
|
| 224 |
-
seen: Set[str] = set()
|
| 225 |
-
for row in row_defs:
|
| 226 |
-
for tag in row.get("tags", []):
|
| 227 |
-
if tag in selected and tag not in seen:
|
| 228 |
-
out.append(tag)
|
| 229 |
-
seen.add(tag)
|
| 230 |
-
|
| 231 |
-
for tag in sorted(selected):
|
| 232 |
-
if tag not in seen:
|
| 233 |
-
out.append(tag)
|
| 234 |
-
seen.add(tag)
|
| 235 |
-
return out
|
| 236 |
|
| 237 |
|
| 238 |
def _compose_toggle_prompt_text(selected_tags: List[str], row_defs: List[Dict[str, Any]]) -> str:
|
|
@@ -308,7 +258,7 @@ def _filter_min_count_tags(tags: List[str], min_count: int) -> Tuple[List[str],
|
|
| 308 |
return keep, sorted(set(removed))
|
| 309 |
|
| 310 |
|
| 311 |
-
def _filter_excluded_recommendation_tags(tags: List[str]) -> Tuple[List[str], List[str]]:
|
| 312 |
excluded = _load_excluded_recommendation_tags()
|
| 313 |
if not excluded:
|
| 314 |
return list(dict.fromkeys(_norm_tag_for_lookup(t) for t in (tags or []) if t)), []
|
|
@@ -327,9 +277,79 @@ def _filter_excluded_recommendation_tags(tags: List[str]) -> Tuple[List[str], Li
|
|
| 327 |
continue
|
| 328 |
seen.add(t)
|
| 329 |
keep.append(t)
|
| 330 |
-
return keep, sorted(set(removed))
|
| 331 |
-
|
| 332 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 333 |
def _build_toggle_rows(
|
| 334 |
*,
|
| 335 |
seed_terms: List[str],
|
|
@@ -400,6 +420,7 @@ def _build_toggle_rows(
|
|
| 400 |
merged_retrieved_other = selected_in_retrieved_other + [
|
| 401 |
t for t in retrieved_uncategorized_ranked if t not in selected_in_retrieved_other
|
| 402 |
]
|
|
|
|
| 403 |
keep_n = max(max(1, int(top_tags_per_group)), len(selected_in_retrieved_other))
|
| 404 |
merged_retrieved_other = merged_retrieved_other[:keep_n]
|
| 405 |
retrieved_other_meta = {
|
|
@@ -428,6 +449,7 @@ def _build_toggle_rows(
|
|
| 428 |
tag_selection_origins=tag_selection_origins,
|
| 429 |
implied_parent_map=implied_parent_map,
|
| 430 |
)
|
|
|
|
| 431 |
selected_other_meta = {
|
| 432 |
t: {
|
| 433 |
"origin": _normalize_selection_origin(tag_selection_origins.get(t, "selection")),
|
|
@@ -454,12 +476,14 @@ def _build_toggle_rows(
|
|
| 454 |
tag_selection_origins=tag_selection_origins,
|
| 455 |
implied_parent_map=implied_parent_map,
|
| 456 |
)
|
| 457 |
-
ranked_tags = [
|
| 458 |
-
t
|
| 459 |
-
for t, _ in row.tags
|
| 460 |
-
if not _is_artist_tag(t) and not _is_excluded_recommendation_tag(t)
|
| 461 |
-
]
|
|
|
|
| 462 |
merged = selected_in_group + [t for t in ranked_tags if t not in selected_in_group]
|
|
|
|
| 463 |
keep_n = max(max(1, int(top_tags_per_group)), len(selected_in_group))
|
| 464 |
merged = merged[:keep_n]
|
| 465 |
tag_meta = {
|
|
@@ -545,17 +569,18 @@ def _build_display_audit_line(
|
|
| 545 |
def _build_row_component_updates(
|
| 546 |
row_defs: List[Dict[str, Any]],
|
| 547 |
selected_tags: List[str],
|
| 548 |
-
max_rows: int,
|
| 549 |
-
):
|
| 550 |
-
selected = {t for t in (selected_tags or []) if t}
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
|
| 558 |
-
|
|
|
|
| 559 |
values = [t for t in tags if t in selected]
|
| 560 |
row_values_state.append(values)
|
| 561 |
visible = bool(tags)
|
|
@@ -574,33 +599,51 @@ def _build_row_component_updates(
|
|
| 574 |
visible=visible,
|
| 575 |
)
|
| 576 |
)
|
| 577 |
-
else:
|
| 578 |
-
header_updates.append(gr.update(value="", visible=False))
|
| 579 |
-
checkbox_updates.append(gr.update(choices=[], value=[], visible=False))
|
| 580 |
-
|
| 581 |
-
prompt_text = _compose_toggle_prompt_text(list(selected),
|
| 582 |
-
return prompt_text, row_values_state, header_updates, checkbox_updates
|
| 583 |
|
| 584 |
|
| 585 |
def _on_toggle_row(
|
| 586 |
row_idx: int,
|
| 587 |
changed_values: List[str],
|
| 588 |
selected_tags_state: List[str],
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
|
|
|
| 592 |
):
|
| 593 |
row_defs = row_defs_state or []
|
| 594 |
row_defs_ui = row_defs[: max(0, int(max_rows))]
|
| 595 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 596 |
row = row_defs_ui[row_idx] if 0 <= row_idx < len(row_defs_ui) else {}
|
| 597 |
-
row_tags =
|
|
|
|
| 598 |
row_tag_set = set(row_tags)
|
| 599 |
row_tag_by_norm = {_norm_tag_for_lookup(t): t for t in row_tags}
|
| 600 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 601 |
# Be tolerant to UI payload forms: canonical tag values, display labels, or normalized variants.
|
| 602 |
new_set: Set[str] = set()
|
| 603 |
-
for raw in
|
| 604 |
if raw in row_tag_set:
|
| 605 |
new_set.add(raw)
|
| 606 |
continue
|
|
@@ -609,23 +652,13 @@ def _on_toggle_row(
|
|
| 609 |
if mapped:
|
| 610 |
new_set.add(mapped)
|
| 611 |
|
| 612 |
-
|
| 613 |
-
prev_row_values = prev_values[row_idx] if 0 <= row_idx < len(prev_values) else []
|
| 614 |
-
prev_row_selected = set()
|
| 615 |
-
for raw in (prev_row_values or []):
|
| 616 |
-
if raw in row_tag_set:
|
| 617 |
-
prev_row_selected.add(raw)
|
| 618 |
-
continue
|
| 619 |
-
raw_norm = _norm_tag_for_lookup(str(raw))
|
| 620 |
-
mapped = row_tag_by_norm.get(raw_norm)
|
| 621 |
-
if mapped:
|
| 622 |
-
prev_row_selected.add(mapped)
|
| 623 |
|
| 624 |
# Ignore non-user/no-op events (e.g., programmatic value re-sets) deterministically.
|
| 625 |
if new_set == prev_row_selected:
|
| 626 |
prompt_text = _compose_toggle_prompt_text(sorted(selected), row_defs_ui)
|
| 627 |
checkbox_updates = [gr.skip() for _ in range(max_rows)]
|
| 628 |
-
return [sorted(selected), prev_values, prompt_text, *checkbox_updates]
|
| 629 |
|
| 630 |
selected.difference_update(row_tag_set)
|
| 631 |
selected.update(new_set)
|
|
@@ -634,7 +667,7 @@ def _on_toggle_row(
|
|
| 634 |
new_row_values_state: List[List[str]] = []
|
| 635 |
affected_rows: Set[int] = {row_idx}
|
| 636 |
for idx, row_item in enumerate(row_defs_ui):
|
| 637 |
-
tags =
|
| 638 |
values = [t for t in tags if t in selected]
|
| 639 |
new_row_values_state.append(values)
|
| 640 |
if toggled_tags and any(t in toggled_tags for t in tags):
|
|
@@ -651,7 +684,14 @@ def _on_toggle_row(
|
|
| 651 |
checkbox_updates.append(gr.skip())
|
| 652 |
|
| 653 |
prompt_text = _compose_toggle_prompt_text(sorted(selected), row_defs_ui)
|
| 654 |
-
return [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 655 |
|
| 656 |
|
| 657 |
def _build_ui_payload(
|
|
@@ -660,20 +700,30 @@ def _build_ui_payload(
|
|
| 660 |
row_defs: List[Dict[str, Any]],
|
| 661 |
selected_tags: List[str],
|
| 662 |
):
|
| 663 |
-
prompt_text, row_values_state, header_updates, checkbox_updates = _build_row_component_updates(
|
| 664 |
-
row_defs=row_defs,
|
| 665 |
-
selected_tags=selected_tags,
|
| 666 |
-
max_rows=display_max_rows_default,
|
| 667 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 668 |
return [
|
| 669 |
console_text,
|
| 670 |
gr.update(visible=bool(row_defs)),
|
| 671 |
prompt_text,
|
| 672 |
-
|
|
|
|
|
|
|
| 673 |
row_defs,
|
| 674 |
row_values_state,
|
| 675 |
*header_updates,
|
| 676 |
-
*checkbox_updates,
|
| 677 |
]
|
| 678 |
|
| 679 |
|
|
@@ -688,6 +738,8 @@ def _prepare_run_ui() -> List[Any]:
|
|
| 688 |
gr.skip(),
|
| 689 |
"Running... usually completes in about 20 seconds.",
|
| 690 |
[],
|
|
|
|
|
|
|
| 691 |
[],
|
| 692 |
[],
|
| 693 |
*header_updates,
|
|
@@ -695,6 +747,93 @@ def _prepare_run_ui() -> List[Any]:
|
|
| 695 |
]
|
| 696 |
|
| 697 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 698 |
def _build_selection_query(
|
| 699 |
prompt_in: str,
|
| 700 |
rewritten: str,
|
|
@@ -1116,55 +1255,18 @@ css = """
|
|
| 1116 |
"""
|
| 1117 |
|
| 1118 |
client_js = """
|
| 1119 |
-
() => {
|
| 1120 |
-
const markerRe = /\\s*\\[\\[psq:([a-z_]+):(0|1):([A-Za-z0-9_\\-=]*)\\]\\]\\s*$/;
|
| 1121 |
-
const decodeTip = (b64) => {
|
| 1122 |
-
if (!b64) return "";
|
| 1123 |
-
try {
|
| 1124 |
-
const binary = atob((b64 || "").replace(/-/g, "+").replace(/_/g, "/"));
|
| 1125 |
-
const bytes = Uint8Array.from(binary, (c) => c.charCodeAt(0));
|
| 1126 |
-
return new TextDecoder("utf-8").decode(bytes);
|
| 1127 |
-
} catch (_) {
|
| 1128 |
-
return "";
|
| 1129 |
-
}
|
| 1130 |
-
};
|
| 1131 |
-
const applyTagMeta = () => {
|
| 1132 |
-
const labels = document.querySelectorAll(".lego-tags label");
|
| 1133 |
-
labels.forEach((label) => {
|
| 1134 |
-
const span = label.querySelector("span");
|
| 1135 |
-
if (!span) return;
|
| 1136 |
-
const text = span.textContent || "";
|
| 1137 |
-
const match = text.match(markerRe);
|
| 1138 |
-
if (!match) return;
|
| 1139 |
-
label.dataset.psqOrigin = match[1];
|
| 1140 |
-
label.dataset.psqPreselected = match[2];
|
| 1141 |
-
const tip = decodeTip(match[3] || "");
|
| 1142 |
-
if (tip) {
|
| 1143 |
-
label.title = tip;
|
| 1144 |
-
span.title = tip;
|
| 1145 |
-
} else {
|
| 1146 |
-
label.removeAttribute("title");
|
| 1147 |
-
span.removeAttribute("title");
|
| 1148 |
-
}
|
| 1149 |
-
span.textContent = text.replace(markerRe, "");
|
| 1150 |
-
});
|
| 1151 |
-
};
|
| 1152 |
-
|
| 1153 |
-
applyTagMeta();
|
| 1154 |
-
const observer = new MutationObserver(() => applyTagMeta());
|
| 1155 |
-
observer.observe(document.body, { childList: true, subtree: true, characterData: true });
|
| 1156 |
-
}
|
| 1157 |
"""
|
| 1158 |
|
| 1159 |
|
| 1160 |
def rag_pipeline_ui(
|
| 1161 |
-
user_prompt: str,
|
| 1162 |
-
display_top_groups: float,
|
| 1163 |
-
display_top_tags_per_group: float,
|
| 1164 |
-
display_rank_top_k: float,
|
| 1165 |
-
):
|
| 1166 |
-
logs = []
|
| 1167 |
-
def log(s): logs.append(s)
|
| 1168 |
|
| 1169 |
try:
|
| 1170 |
stage_timings = {}
|
|
@@ -1618,9 +1720,9 @@ def rag_pipeline_ui(
|
|
| 1618 |
top_tags_per_group=max(1, int(display_top_tags_per_group)),
|
| 1619 |
group_rank_top_k=max(1, int(display_rank_top_k)),
|
| 1620 |
)
|
| 1621 |
-
dt = time.perf_counter()-t0
|
| 1622 |
-
_record_timing("group_display", dt)
|
| 1623 |
-
log(f"Ranked group display: {dt:.2f}s ({len(toggle_rows)} rows)")
|
| 1624 |
log(
|
| 1625 |
_build_display_audit_line(
|
| 1626 |
toggle_rows,
|
|
@@ -1629,6 +1731,9 @@ def rag_pipeline_ui(
|
|
| 1629 |
implied_selected_tags=implied_selected_tags,
|
| 1630 |
)
|
| 1631 |
)
|
|
|
|
|
|
|
|
|
|
| 1632 |
|
| 1633 |
total_dt = time.perf_counter()-t_total0
|
| 1634 |
_emit_timing_summary(total_dt)
|
|
@@ -1690,6 +1795,7 @@ with gr.Blocks(css=css, js=client_js) as app:
|
|
| 1690 |
gr.Markdown("Typical runtime: up to ~20 seconds.", elem_classes=["run-hint"])
|
| 1691 |
|
| 1692 |
selected_tags_state = gr.State([])
|
|
|
|
| 1693 |
row_defs_state = gr.State([])
|
| 1694 |
row_values_state = gr.State([])
|
| 1695 |
|
|
@@ -1716,20 +1822,29 @@ with gr.Blocks(css=css, js=client_js) as app:
|
|
| 1716 |
)
|
| 1717 |
)
|
| 1718 |
|
| 1719 |
-
gr.
|
| 1720 |
-
|
| 1721 |
-
|
| 1722 |
-
|
| 1723 |
-
|
| 1724 |
-
|
| 1725 |
-
|
| 1726 |
-
|
| 1727 |
-
|
| 1728 |
-
|
| 1729 |
-
|
| 1730 |
-
|
| 1731 |
-
|
| 1732 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1733 |
|
| 1734 |
with gr.Accordion("Display Settings", open=False):
|
| 1735 |
with gr.Row():
|
|
@@ -1765,11 +1880,13 @@ with gr.Blocks(css=css, js=client_js) as app:
|
|
| 1765 |
toggle_instruction,
|
| 1766 |
suggested_prompt,
|
| 1767 |
selected_tags_state,
|
|
|
|
|
|
|
| 1768 |
row_defs_state,
|
| 1769 |
row_values_state,
|
| 1770 |
-
*row_headers,
|
| 1771 |
-
*row_checkboxes,
|
| 1772 |
-
]
|
| 1773 |
|
| 1774 |
submit_button.click(
|
| 1775 |
_prepare_run_ui,
|
|
@@ -1796,17 +1913,36 @@ with gr.Blocks(css=css, js=client_js) as app:
|
|
| 1796 |
)
|
| 1797 |
|
| 1798 |
for idx, row_cb in enumerate(row_checkboxes):
|
| 1799 |
-
row_cb.
|
| 1800 |
-
fn=lambda changed_values, selected_state, row_defs, row_values, i=idx: _on_toggle_row(
|
| 1801 |
i,
|
| 1802 |
changed_values,
|
| 1803 |
selected_state,
|
|
|
|
| 1804 |
row_defs,
|
| 1805 |
-
row_values,
|
| 1806 |
-
display_max_rows_default,
|
| 1807 |
),
|
| 1808 |
-
inputs=[row_cb, selected_tags_state, row_defs_state, row_values_state],
|
| 1809 |
-
outputs=[selected_tags_state, row_values_state, suggested_prompt, *row_checkboxes],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1810 |
queue=False,
|
| 1811 |
show_progress="hidden",
|
| 1812 |
)
|
|
|
|
| 4 |
import time
|
| 5 |
import json
|
| 6 |
import csv
|
|
|
|
| 7 |
from datetime import datetime
|
| 8 |
from functools import lru_cache
|
| 9 |
from PIL import Image
|
|
|
|
| 67 |
return "selection"
|
| 68 |
|
| 69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
def _choice_label_with_source_meta(tag: str, *, origin: str, preselected: bool) -> str:
|
| 71 |
+
# Keep labels plain to avoid frontend text/value desynchronization.
|
| 72 |
+
return _display_tag_text(tag)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
|
| 75 |
def _selection_source_rank(origin: str) -> int:
|
|
|
|
| 174 |
)
|
| 175 |
|
| 176 |
|
| 177 |
+
def _ordered_selected_for_prompt(selected: Set[str], row_defs: List[Dict[str, Any]]) -> List[str]:
|
| 178 |
+
out: List[str] = []
|
| 179 |
+
seen: Set[str] = set()
|
| 180 |
+
for row in row_defs:
|
| 181 |
+
for tag in row.get("tags", []):
|
| 182 |
+
if tag in selected and tag not in seen:
|
| 183 |
+
out.append(tag)
|
| 184 |
+
seen.add(tag)
|
| 185 |
+
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
|
| 187 |
|
| 188 |
def _compose_toggle_prompt_text(selected_tags: List[str], row_defs: List[Dict[str, Any]]) -> str:
|
|
|
|
| 258 |
return keep, sorted(set(removed))
|
| 259 |
|
| 260 |
|
| 261 |
+
def _filter_excluded_recommendation_tags(tags: List[str]) -> Tuple[List[str], List[str]]:
|
| 262 |
excluded = _load_excluded_recommendation_tags()
|
| 263 |
if not excluded:
|
| 264 |
return list(dict.fromkeys(_norm_tag_for_lookup(t) for t in (tags or []) if t)), []
|
|
|
|
| 277 |
continue
|
| 278 |
seen.add(t)
|
| 279 |
keep.append(t)
|
| 280 |
+
return keep, sorted(set(removed))
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def _dedupe_norm_tags(tags: List[str]) -> List[str]:
|
| 284 |
+
out: List[str] = []
|
| 285 |
+
seen: Set[str] = set()
|
| 286 |
+
for raw in (tags or []):
|
| 287 |
+
t = _norm_tag_for_lookup(str(raw))
|
| 288 |
+
if not t or t in seen:
|
| 289 |
+
continue
|
| 290 |
+
seen.add(t)
|
| 291 |
+
out.append(t)
|
| 292 |
+
return out
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def _collect_visible_tags(row_defs: List[Dict[str, Any]]) -> Set[str]:
|
| 296 |
+
out: Set[str] = set()
|
| 297 |
+
for row in (row_defs or []):
|
| 298 |
+
for t in _dedupe_norm_tags(row.get("tags", []) if isinstance(row, dict) else []):
|
| 299 |
+
out.add(t)
|
| 300 |
+
return out
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def _collect_selected_from_state(
|
| 304 |
+
selected_tags_state: List[str],
|
| 305 |
+
row_defs: List[Dict[str, Any]],
|
| 306 |
+
) -> List[str]:
|
| 307 |
+
visible_tags = _collect_visible_tags(row_defs)
|
| 308 |
+
if not visible_tags:
|
| 309 |
+
return []
|
| 310 |
+
selected: List[str] = []
|
| 311 |
+
seen: Set[str] = set()
|
| 312 |
+
visible_by_norm = {_norm_tag_for_lookup(t): t for t in visible_tags}
|
| 313 |
+
for raw in (selected_tags_state or []):
|
| 314 |
+
t = _norm_tag_for_lookup(str(raw))
|
| 315 |
+
if not t:
|
| 316 |
+
continue
|
| 317 |
+
mapped = t if t in visible_tags else visible_by_norm.get(t)
|
| 318 |
+
if not mapped or mapped in seen:
|
| 319 |
+
continue
|
| 320 |
+
seen.add(mapped)
|
| 321 |
+
selected.append(mapped)
|
| 322 |
+
return selected
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def _collect_selected_from_row_values(
|
| 326 |
+
row_defs: List[Dict[str, Any]],
|
| 327 |
+
row_values_state: List[List[str]],
|
| 328 |
+
) -> List[str]:
|
| 329 |
+
selected: List[str] = []
|
| 330 |
+
seen: Set[str] = set()
|
| 331 |
+
values = list(row_values_state or [])
|
| 332 |
+
for idx, row in enumerate(row_defs or []):
|
| 333 |
+
row_tags = _dedupe_norm_tags(row.get("tags", []) if isinstance(row, dict) else [])
|
| 334 |
+
if not row_tags:
|
| 335 |
+
continue
|
| 336 |
+
row_tag_set = set(row_tags)
|
| 337 |
+
row_tag_by_norm = {_norm_tag_for_lookup(t): t for t in row_tags}
|
| 338 |
+
raw_vals = values[idx] if 0 <= idx < len(values) else []
|
| 339 |
+
for raw in (raw_vals or []):
|
| 340 |
+
if raw in row_tag_set:
|
| 341 |
+
if raw not in seen:
|
| 342 |
+
seen.add(raw)
|
| 343 |
+
selected.append(raw)
|
| 344 |
+
continue
|
| 345 |
+
raw_norm = _norm_tag_for_lookup(str(raw))
|
| 346 |
+
mapped = row_tag_by_norm.get(raw_norm)
|
| 347 |
+
if mapped and mapped not in seen:
|
| 348 |
+
seen.add(mapped)
|
| 349 |
+
selected.append(mapped)
|
| 350 |
+
return selected
|
| 351 |
+
|
| 352 |
+
|
| 353 |
def _build_toggle_rows(
|
| 354 |
*,
|
| 355 |
seed_terms: List[str],
|
|
|
|
| 420 |
merged_retrieved_other = selected_in_retrieved_other + [
|
| 421 |
t for t in retrieved_uncategorized_ranked if t not in selected_in_retrieved_other
|
| 422 |
]
|
| 423 |
+
merged_retrieved_other = _dedupe_norm_tags(merged_retrieved_other)
|
| 424 |
keep_n = max(max(1, int(top_tags_per_group)), len(selected_in_retrieved_other))
|
| 425 |
merged_retrieved_other = merged_retrieved_other[:keep_n]
|
| 426 |
retrieved_other_meta = {
|
|
|
|
| 449 |
tag_selection_origins=tag_selection_origins,
|
| 450 |
implied_parent_map=implied_parent_map,
|
| 451 |
)
|
| 452 |
+
selected_other = _dedupe_norm_tags(selected_other)
|
| 453 |
selected_other_meta = {
|
| 454 |
t: {
|
| 455 |
"origin": _normalize_selection_origin(tag_selection_origins.get(t, "selection")),
|
|
|
|
| 476 |
tag_selection_origins=tag_selection_origins,
|
| 477 |
implied_parent_map=implied_parent_map,
|
| 478 |
)
|
| 479 |
+
ranked_tags = [
|
| 480 |
+
_norm_tag_for_lookup(t)
|
| 481 |
+
for t, _ in row.tags
|
| 482 |
+
if t and not _is_artist_tag(t) and not _is_excluded_recommendation_tag(t)
|
| 483 |
+
]
|
| 484 |
+
ranked_tags = _dedupe_norm_tags(ranked_tags)
|
| 485 |
merged = selected_in_group + [t for t in ranked_tags if t not in selected_in_group]
|
| 486 |
+
merged = _dedupe_norm_tags(merged)
|
| 487 |
keep_n = max(max(1, int(top_tags_per_group)), len(selected_in_group))
|
| 488 |
merged = merged[:keep_n]
|
| 489 |
tag_meta = {
|
|
|
|
| 569 |
def _build_row_component_updates(
|
| 570 |
row_defs: List[Dict[str, Any]],
|
| 571 |
selected_tags: List[str],
|
| 572 |
+
max_rows: int,
|
| 573 |
+
):
|
| 574 |
+
selected = {t for t in (selected_tags or []) if t}
|
| 575 |
+
row_defs_ui = (row_defs or [])[: max(0, int(max_rows))]
|
| 576 |
+
row_values_state: List[List[str]] = []
|
| 577 |
+
header_updates = []
|
| 578 |
+
checkbox_updates = []
|
| 579 |
+
|
| 580 |
+
for idx in range(max_rows):
|
| 581 |
+
if idx < len(row_defs_ui):
|
| 582 |
+
row = row_defs_ui[idx]
|
| 583 |
+
tags = _dedupe_norm_tags(row.get("tags", []))
|
| 584 |
values = [t for t in tags if t in selected]
|
| 585 |
row_values_state.append(values)
|
| 586 |
visible = bool(tags)
|
|
|
|
| 599 |
visible=visible,
|
| 600 |
)
|
| 601 |
)
|
| 602 |
+
else:
|
| 603 |
+
header_updates.append(gr.update(value="", visible=False))
|
| 604 |
+
checkbox_updates.append(gr.update(choices=[], value=[], visible=False))
|
| 605 |
+
|
| 606 |
+
prompt_text = _compose_toggle_prompt_text(list(selected), row_defs_ui)
|
| 607 |
+
return prompt_text, row_values_state, header_updates, checkbox_updates
|
| 608 |
|
| 609 |
|
| 610 |
def _on_toggle_row(
|
| 611 |
row_idx: int,
|
| 612 |
changed_values: List[str],
|
| 613 |
selected_tags_state: List[str],
|
| 614 |
+
rows_dirty_state: bool,
|
| 615 |
+
row_defs_state: List[Dict[str, Any]],
|
| 616 |
+
row_values_state: List[List[str]],
|
| 617 |
+
max_rows: int,
|
| 618 |
):
|
| 619 |
row_defs = row_defs_state or []
|
| 620 |
row_defs_ui = row_defs[: max(0, int(max_rows))]
|
| 621 |
+
prev_values = list(row_values_state or [])
|
| 622 |
+
selected_from_state = _collect_selected_from_state(selected_tags_state, row_defs_ui)
|
| 623 |
+
selected_from_rows = _collect_selected_from_row_values(row_defs_ui, prev_values)
|
| 624 |
+
# Prefer row-value state as source-of-truth (closest to visible UI), with selected-state as fallback.
|
| 625 |
+
selected: Set[str] = set(selected_from_rows or selected_from_state)
|
| 626 |
+
|
| 627 |
row = row_defs_ui[row_idx] if 0 <= row_idx < len(row_defs_ui) else {}
|
| 628 |
+
row_tags = _dedupe_norm_tags(row.get("tags", []))
|
| 629 |
+
row_label = str(row.get("label", ""))
|
| 630 |
row_tag_set = set(row_tags)
|
| 631 |
row_tag_by_norm = {_norm_tag_for_lookup(t): t for t in row_tags}
|
| 632 |
|
| 633 |
+
# Be tolerant to UI payload forms: canonical tag values, display labels, normalized variants,
|
| 634 |
+
# and occasional single-string payloads from frontend events.
|
| 635 |
+
if changed_values is None:
|
| 636 |
+
changed_iter: List[Any] = []
|
| 637 |
+
elif isinstance(changed_values, str):
|
| 638 |
+
changed_iter = [changed_values]
|
| 639 |
+
elif isinstance(changed_values, (list, tuple, set)):
|
| 640 |
+
changed_iter = list(changed_values)
|
| 641 |
+
else:
|
| 642 |
+
changed_iter = [changed_values]
|
| 643 |
+
|
| 644 |
# Be tolerant to UI payload forms: canonical tag values, display labels, or normalized variants.
|
| 645 |
new_set: Set[str] = set()
|
| 646 |
+
for raw in changed_iter:
|
| 647 |
if raw in row_tag_set:
|
| 648 |
new_set.add(raw)
|
| 649 |
continue
|
|
|
|
| 652 |
if mapped:
|
| 653 |
new_set.add(mapped)
|
| 654 |
|
| 655 |
+
prev_row_selected = {t for t in row_tags if t in selected}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 656 |
|
| 657 |
# Ignore non-user/no-op events (e.g., programmatic value re-sets) deterministically.
|
| 658 |
if new_set == prev_row_selected:
|
| 659 |
prompt_text = _compose_toggle_prompt_text(sorted(selected), row_defs_ui)
|
| 660 |
checkbox_updates = [gr.skip() for _ in range(max_rows)]
|
| 661 |
+
return [sorted(selected), rows_dirty_state, gr.skip(), prev_values, prompt_text, *checkbox_updates]
|
| 662 |
|
| 663 |
selected.difference_update(row_tag_set)
|
| 664 |
selected.update(new_set)
|
|
|
|
| 667 |
new_row_values_state: List[List[str]] = []
|
| 668 |
affected_rows: Set[int] = {row_idx}
|
| 669 |
for idx, row_item in enumerate(row_defs_ui):
|
| 670 |
+
tags = _dedupe_norm_tags(row_item.get("tags", []))
|
| 671 |
values = [t for t in tags if t in selected]
|
| 672 |
new_row_values_state.append(values)
|
| 673 |
if toggled_tags and any(t in toggled_tags for t in tags):
|
|
|
|
| 684 |
checkbox_updates.append(gr.skip())
|
| 685 |
|
| 686 |
prompt_text = _compose_toggle_prompt_text(sorted(selected), row_defs_ui)
|
| 687 |
+
return [
|
| 688 |
+
sorted(selected),
|
| 689 |
+
True,
|
| 690 |
+
gr.update(visible=True, interactive=True),
|
| 691 |
+
new_row_values_state,
|
| 692 |
+
prompt_text,
|
| 693 |
+
*checkbox_updates,
|
| 694 |
+
]
|
| 695 |
|
| 696 |
|
| 697 |
def _build_ui_payload(
|
|
|
|
| 700 |
row_defs: List[Dict[str, Any]],
|
| 701 |
selected_tags: List[str],
|
| 702 |
):
|
| 703 |
+
prompt_text, row_values_state, header_updates, checkbox_updates = _build_row_component_updates(
|
| 704 |
+
row_defs=row_defs,
|
| 705 |
+
selected_tags=selected_tags,
|
| 706 |
+
max_rows=display_max_rows_default,
|
| 707 |
+
)
|
| 708 |
+
selected_ui: List[str] = []
|
| 709 |
+
selected_ui_seen: Set[str] = set()
|
| 710 |
+
for vals in row_values_state:
|
| 711 |
+
for t in vals:
|
| 712 |
+
if t in selected_ui_seen:
|
| 713 |
+
continue
|
| 714 |
+
selected_ui_seen.add(t)
|
| 715 |
+
selected_ui.append(t)
|
| 716 |
return [
|
| 717 |
console_text,
|
| 718 |
gr.update(visible=bool(row_defs)),
|
| 719 |
prompt_text,
|
| 720 |
+
selected_ui,
|
| 721 |
+
False,
|
| 722 |
+
gr.update(visible=False, interactive=False),
|
| 723 |
row_defs,
|
| 724 |
row_values_state,
|
| 725 |
*header_updates,
|
| 726 |
+
*checkbox_updates,
|
| 727 |
]
|
| 728 |
|
| 729 |
|
|
|
|
| 738 |
gr.skip(),
|
| 739 |
"Running... usually completes in about 20 seconds.",
|
| 740 |
[],
|
| 741 |
+
False,
|
| 742 |
+
gr.update(visible=False, interactive=False),
|
| 743 |
[],
|
| 744 |
[],
|
| 745 |
*header_updates,
|
|
|
|
| 747 |
]
|
| 748 |
|
| 749 |
|
| 750 |
+
def _rebuild_rows_from_selected(
|
| 751 |
+
selected_tags_state: List[str],
|
| 752 |
+
row_defs_state: List[Dict[str, Any]],
|
| 753 |
+
row_values_state: List[List[str]],
|
| 754 |
+
display_top_groups: float,
|
| 755 |
+
display_top_tags_per_group: float,
|
| 756 |
+
display_rank_top_k: float,
|
| 757 |
+
):
|
| 758 |
+
existing_rows = row_defs_state or []
|
| 759 |
+
existing_values = list(row_values_state or [])
|
| 760 |
+
selected_from_state = _collect_selected_from_state(selected_tags_state, existing_rows)
|
| 761 |
+
selected_from_rows = _collect_selected_from_row_values(existing_rows, existing_values)
|
| 762 |
+
# Rebuild source-of-truth is current row checkbox values; fall back only when unavailable.
|
| 763 |
+
selected_seed = selected_from_rows if existing_values else selected_from_state
|
| 764 |
+
selected_active = list(
|
| 765 |
+
dict.fromkeys(
|
| 766 |
+
_norm_tag_for_lookup(t)
|
| 767 |
+
for t in selected_seed
|
| 768 |
+
if t and not _is_artist_tag(t) and not _is_excluded_recommendation_tag(t)
|
| 769 |
+
)
|
| 770 |
+
)
|
| 771 |
+
|
| 772 |
+
retrieved_candidate_tags: List[str] = []
|
| 773 |
+
tag_selection_origins: Dict[str, str] = {}
|
| 774 |
+
for row in existing_rows:
|
| 775 |
+
row_tags = row.get("tags", []) if isinstance(row, dict) else []
|
| 776 |
+
row_meta = row.get("tag_meta", {}) if isinstance(row, dict) else {}
|
| 777 |
+
if not isinstance(row_meta, dict):
|
| 778 |
+
row_meta = {}
|
| 779 |
+
for t in row_tags:
|
| 780 |
+
tn = _norm_tag_for_lookup(t)
|
| 781 |
+
if not tn or _is_artist_tag(tn) or _is_excluded_recommendation_tag(tn):
|
| 782 |
+
continue
|
| 783 |
+
retrieved_candidate_tags.append(tn)
|
| 784 |
+
if tn not in tag_selection_origins:
|
| 785 |
+
meta = row_meta.get(t, {}) if isinstance(row_meta.get(t, {}), dict) else {}
|
| 786 |
+
tag_selection_origins[tn] = _normalize_selection_origin(str(meta.get("origin", "selection")))
|
| 787 |
+
|
| 788 |
+
for t in selected_active:
|
| 789 |
+
tag_selection_origins.setdefault(t, "user")
|
| 790 |
+
retrieved_candidate_tags.append(t)
|
| 791 |
+
|
| 792 |
+
implied_selected_tags = [t for t in selected_active if tag_selection_origins.get(t) == "implied"]
|
| 793 |
+
implied_set = set(implied_selected_tags)
|
| 794 |
+
direct_selected_tags = [t for t in selected_active if t not in implied_set]
|
| 795 |
+
direct_idx = {t: i for i, t in enumerate(direct_selected_tags)}
|
| 796 |
+
direct_selected_tags.sort(
|
| 797 |
+
key=lambda t: (
|
| 798 |
+
_selection_source_rank(tag_selection_origins.get(t, "selection")),
|
| 799 |
+
direct_idx.get(t, 10**9),
|
| 800 |
+
)
|
| 801 |
+
)
|
| 802 |
+
implied_parent_map = _build_implied_parent_map(
|
| 803 |
+
direct_tags_ordered=direct_selected_tags,
|
| 804 |
+
implied_tags=implied_selected_tags,
|
| 805 |
+
)
|
| 806 |
+
|
| 807 |
+
toggle_rows = _build_toggle_rows(
|
| 808 |
+
seed_terms=list(selected_active),
|
| 809 |
+
selected_tags=selected_active,
|
| 810 |
+
retrieved_candidate_tags=list(dict.fromkeys(retrieved_candidate_tags)),
|
| 811 |
+
tag_selection_origins=tag_selection_origins,
|
| 812 |
+
implied_parent_map=implied_parent_map,
|
| 813 |
+
top_groups=max(1, int(display_top_groups)),
|
| 814 |
+
top_tags_per_group=max(1, int(display_top_tags_per_group)),
|
| 815 |
+
group_rank_top_k=max(1, int(display_rank_top_k)),
|
| 816 |
+
)
|
| 817 |
+
|
| 818 |
+
prompt_text, row_values_state, header_updates, checkbox_updates = _build_row_component_updates(
|
| 819 |
+
row_defs=toggle_rows,
|
| 820 |
+
selected_tags=selected_active,
|
| 821 |
+
max_rows=display_max_rows_default,
|
| 822 |
+
)
|
| 823 |
+
|
| 824 |
+
return [
|
| 825 |
+
gr.update(visible=bool(toggle_rows)),
|
| 826 |
+
prompt_text,
|
| 827 |
+
sorted(selected_active),
|
| 828 |
+
False,
|
| 829 |
+
gr.update(visible=False, interactive=False),
|
| 830 |
+
toggle_rows,
|
| 831 |
+
row_values_state,
|
| 832 |
+
*header_updates,
|
| 833 |
+
*checkbox_updates,
|
| 834 |
+
]
|
| 835 |
+
|
| 836 |
+
|
| 837 |
def _build_selection_query(
|
| 838 |
prompt_in: str,
|
| 839 |
rewritten: str,
|
|
|
|
| 1255 |
"""
|
| 1256 |
|
| 1257 |
client_js = """
|
| 1258 |
+
() => {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1259 |
"""
|
| 1260 |
|
| 1261 |
|
| 1262 |
def rag_pipeline_ui(
|
| 1263 |
+
user_prompt: str,
|
| 1264 |
+
display_top_groups: float,
|
| 1265 |
+
display_top_tags_per_group: float,
|
| 1266 |
+
display_rank_top_k: float,
|
| 1267 |
+
):
|
| 1268 |
+
logs = []
|
| 1269 |
+
def log(s): logs.append(s)
|
| 1270 |
|
| 1271 |
try:
|
| 1272 |
stage_timings = {}
|
|
|
|
| 1720 |
top_tags_per_group=max(1, int(display_top_tags_per_group)),
|
| 1721 |
group_rank_top_k=max(1, int(display_rank_top_k)),
|
| 1722 |
)
|
| 1723 |
+
dt = time.perf_counter()-t0
|
| 1724 |
+
_record_timing("group_display", dt)
|
| 1725 |
+
log(f"Ranked group display: {dt:.2f}s ({len(toggle_rows)} rows)")
|
| 1726 |
log(
|
| 1727 |
_build_display_audit_line(
|
| 1728 |
toggle_rows,
|
|
|
|
| 1731 |
implied_selected_tags=implied_selected_tags,
|
| 1732 |
)
|
| 1733 |
)
|
| 1734 |
+
for idx, row in enumerate(toggle_rows[: max(0, int(display_max_rows_default))]):
|
| 1735 |
+
tags_preview = ", ".join(row.get("tags", []))
|
| 1736 |
+
log(f"UI Row {idx}: {row.get('label', '')} :: {tags_preview}")
|
| 1737 |
|
| 1738 |
total_dt = time.perf_counter()-t_total0
|
| 1739 |
_emit_timing_summary(total_dt)
|
|
|
|
| 1795 |
gr.Markdown("Typical runtime: up to ~20 seconds.", elem_classes=["run-hint"])
|
| 1796 |
|
| 1797 |
selected_tags_state = gr.State([])
|
| 1798 |
+
rows_dirty_state = gr.State(False)
|
| 1799 |
row_defs_state = gr.State([])
|
| 1800 |
row_values_state = gr.State([])
|
| 1801 |
|
|
|
|
| 1822 |
)
|
| 1823 |
)
|
| 1824 |
|
| 1825 |
+
with gr.Row():
|
| 1826 |
+
with gr.Column(scale=10):
|
| 1827 |
+
gr.HTML(
|
| 1828 |
+
"""
|
| 1829 |
+
<div class="source-legend">
|
| 1830 |
+
<span class="legend-title">Legend:</span>
|
| 1831 |
+
<span class="chip rewrite">Rewrite phrase</span>
|
| 1832 |
+
<span class="chip selection">General selection</span>
|
| 1833 |
+
<span class="chip probe">Probe query</span>
|
| 1834 |
+
<span class="chip structural">Structural query</span>
|
| 1835 |
+
<span class="chip implied">Implied</span>
|
| 1836 |
+
<span class="chip user">User-toggled</span>
|
| 1837 |
+
<span class="chip unselected">Unselected</span>
|
| 1838 |
+
</div>
|
| 1839 |
+
"""
|
| 1840 |
+
)
|
| 1841 |
+
with gr.Column(scale=2, min_width=180):
|
| 1842 |
+
rebuild_rows_button = gr.Button(
|
| 1843 |
+
"Rebuild Rows",
|
| 1844 |
+
variant="primary",
|
| 1845 |
+
visible=False,
|
| 1846 |
+
interactive=False,
|
| 1847 |
+
)
|
| 1848 |
|
| 1849 |
with gr.Accordion("Display Settings", open=False):
|
| 1850 |
with gr.Row():
|
|
|
|
| 1880 |
toggle_instruction,
|
| 1881 |
suggested_prompt,
|
| 1882 |
selected_tags_state,
|
| 1883 |
+
rows_dirty_state,
|
| 1884 |
+
rebuild_rows_button,
|
| 1885 |
row_defs_state,
|
| 1886 |
row_values_state,
|
| 1887 |
+
*row_headers,
|
| 1888 |
+
*row_checkboxes,
|
| 1889 |
+
]
|
| 1890 |
|
| 1891 |
submit_button.click(
|
| 1892 |
_prepare_run_ui,
|
|
|
|
| 1913 |
)
|
| 1914 |
|
| 1915 |
for idx, row_cb in enumerate(row_checkboxes):
|
| 1916 |
+
row_cb.change(
|
| 1917 |
+
fn=lambda changed_values, selected_state, rows_dirty, row_defs, row_values, i=idx: _on_toggle_row(
|
| 1918 |
i,
|
| 1919 |
changed_values,
|
| 1920 |
selected_state,
|
| 1921 |
+
rows_dirty,
|
| 1922 |
row_defs,
|
| 1923 |
+
row_values,
|
| 1924 |
+
display_max_rows_default,
|
| 1925 |
),
|
| 1926 |
+
inputs=[row_cb, selected_tags_state, rows_dirty_state, row_defs_state, row_values_state],
|
| 1927 |
+
outputs=[selected_tags_state, rows_dirty_state, rebuild_rows_button, row_values_state, suggested_prompt, *row_checkboxes],
|
| 1928 |
+
queue=False,
|
| 1929 |
+
show_progress="hidden",
|
| 1930 |
+
)
|
| 1931 |
+
|
| 1932 |
+
rebuild_rows_button.click(
|
| 1933 |
+
_rebuild_rows_from_selected,
|
| 1934 |
+
inputs=[selected_tags_state, row_defs_state, row_values_state, display_top_groups, display_top_tags_per_group, display_rank_top_k],
|
| 1935 |
+
outputs=[
|
| 1936 |
+
toggle_instruction,
|
| 1937 |
+
suggested_prompt,
|
| 1938 |
+
selected_tags_state,
|
| 1939 |
+
rows_dirty_state,
|
| 1940 |
+
rebuild_rows_button,
|
| 1941 |
+
row_defs_state,
|
| 1942 |
+
row_values_state,
|
| 1943 |
+
*row_headers,
|
| 1944 |
+
*row_checkboxes,
|
| 1945 |
+
],
|
| 1946 |
queue=False,
|
| 1947 |
show_progress="hidden",
|
| 1948 |
)
|
scripts/smoke_ui_state.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
| 5 |
+
import app
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _assert(cond: bool, msg: str) -> None:
|
| 9 |
+
if not cond:
|
| 10 |
+
raise AssertionError(msg)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def test_prompt_uses_visible_rows_only() -> None:
|
| 14 |
+
# If selected state contains stale hidden tags, prompt should still reflect visible-row selections only.
|
| 15 |
+
row_defs = [
|
| 16 |
+
{"name": "r1", "label": "R1", "tags": ["solo", "female"], "tag_meta": {}},
|
| 17 |
+
{"name": "r2", "label": "R2", "tags": ["cub"], "tag_meta": {}},
|
| 18 |
+
]
|
| 19 |
+
payload = app._build_ui_payload(
|
| 20 |
+
console_text="x",
|
| 21 |
+
row_defs=row_defs,
|
| 22 |
+
selected_tags=["solo", "rosalina_(mario)"],
|
| 23 |
+
)
|
| 24 |
+
prompt_text = payload[2]
|
| 25 |
+
selected_state = payload[3]
|
| 26 |
+
_assert("rosalina \\(mario\\)" not in prompt_text, "stale hidden tag leaked into prompt")
|
| 27 |
+
_assert("solo" in prompt_text, "visible selected tag missing from prompt")
|
| 28 |
+
_assert("rosalina_(mario)" not in selected_state, "stale hidden tag leaked into selected state")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def test_row_deduping() -> None:
|
| 32 |
+
row_defs = [
|
| 33 |
+
{
|
| 34 |
+
"name": "other_retrieved",
|
| 35 |
+
"label": "Other (Retrieved)",
|
| 36 |
+
"tags": ["cub", "expressions", "invalid_tag", "cub", "expressions"],
|
| 37 |
+
"tag_meta": {},
|
| 38 |
+
}
|
| 39 |
+
]
|
| 40 |
+
prompt_text, row_values_state, _, checkbox_updates = app._build_row_component_updates(
|
| 41 |
+
row_defs=row_defs,
|
| 42 |
+
selected_tags=["cub", "expressions"],
|
| 43 |
+
max_rows=app.display_max_rows_default,
|
| 44 |
+
)
|
| 45 |
+
_assert(prompt_text == "cub, expressions", "prompt should be deduped and ordered from row")
|
| 46 |
+
_assert(row_values_state[0] == ["cub", "expressions"], "row selected values should be deduped")
|
| 47 |
+
first_choices = checkbox_updates[0]["choices"]
|
| 48 |
+
first_values = [v for _, v in first_choices]
|
| 49 |
+
_assert(first_values == ["cub", "expressions", "invalid_tag"], "row choices should be deduped")
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def test_rebuild_ignores_stale_selected_state() -> None:
|
| 53 |
+
row_defs = [
|
| 54 |
+
{"name": "selected_other", "label": "Selected (Other)", "tags": ["solo", "female", "anthro"], "tag_meta": {}},
|
| 55 |
+
{"name": "other_retrieved", "label": "Other (Retrieved)", "tags": ["cub", "expressions"], "tag_meta": {}},
|
| 56 |
+
]
|
| 57 |
+
# Simulate UI state where user has deselected anthro, but stale selected state still contains it.
|
| 58 |
+
selected_state = ["solo", "female", "anthro", "cub"]
|
| 59 |
+
row_values_state = [["solo", "female"], ["cub"]]
|
| 60 |
+
out = app._rebuild_rows_from_selected(
|
| 61 |
+
selected_state,
|
| 62 |
+
row_defs,
|
| 63 |
+
row_values_state,
|
| 64 |
+
app.display_top_groups_default,
|
| 65 |
+
app.display_top_tags_per_group_default,
|
| 66 |
+
app.display_rank_top_k_default,
|
| 67 |
+
)
|
| 68 |
+
prompt = out[1]
|
| 69 |
+
selected_after = out[2]
|
| 70 |
+
_assert("anthro" not in selected_after, "rebuild should not resurrect stale deselected tags")
|
| 71 |
+
_assert("anthro" not in prompt, "prompt should not include stale deselected tags")
|
| 72 |
+
_assert("solo" in prompt and "female" in prompt and "cub" in prompt, "rebuild should retain current row selections")
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def test_toggle_then_rebuild_does_not_resurrect_removed_tag() -> None:
|
| 76 |
+
row_defs = [
|
| 77 |
+
{"name": "selected_other", "label": "Selected (Other)", "tags": ["solo", "anthro", "female"], "tag_meta": {}},
|
| 78 |
+
{"name": "other_retrieved", "label": "Other (Retrieved)", "tags": ["cub", "expressions"], "tag_meta": {}},
|
| 79 |
+
]
|
| 80 |
+
selected_state = ["solo", "anthro", "female", "cub"]
|
| 81 |
+
row_values_state = [["solo", "anthro", "female"], ["cub"]]
|
| 82 |
+
|
| 83 |
+
# User unchecks anthro in row 0.
|
| 84 |
+
toggle_out = app._on_toggle_row(
|
| 85 |
+
0,
|
| 86 |
+
["solo", "female"],
|
| 87 |
+
selected_state,
|
| 88 |
+
False,
|
| 89 |
+
row_defs,
|
| 90 |
+
row_values_state,
|
| 91 |
+
app.display_max_rows_default,
|
| 92 |
+
)
|
| 93 |
+
selected_after_toggle = toggle_out[0]
|
| 94 |
+
row_values_after_toggle = toggle_out[3]
|
| 95 |
+
_assert("anthro" not in selected_after_toggle, "toggle should remove anthro from selected state")
|
| 96 |
+
|
| 97 |
+
# Rebuild from current row values must preserve the user-toggle result.
|
| 98 |
+
rebuild_out = app._rebuild_rows_from_selected(
|
| 99 |
+
selected_after_toggle,
|
| 100 |
+
row_defs,
|
| 101 |
+
row_values_after_toggle,
|
| 102 |
+
app.display_top_groups_default,
|
| 103 |
+
app.display_top_tags_per_group_default,
|
| 104 |
+
app.display_rank_top_k_default,
|
| 105 |
+
)
|
| 106 |
+
prompt_after_rebuild = rebuild_out[1]
|
| 107 |
+
selected_after_rebuild = rebuild_out[2]
|
| 108 |
+
_assert("anthro" not in selected_after_rebuild, "rebuild should not resurrect deselected anthro")
|
| 109 |
+
_assert("anthro" not in prompt_after_rebuild, "prompt should not contain deselected anthro after rebuild")
|
| 110 |
+
_assert("solo" in prompt_after_rebuild and "female" in prompt_after_rebuild, "kept selections should remain")
|
| 111 |
+
_assert("cub" in prompt_after_rebuild, "other retrieved selection should remain")
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def test_toggle_does_not_cross_activate_unrelated_row_tag() -> None:
|
| 115 |
+
row_defs = [
|
| 116 |
+
{"name": "organization", "label": "Organization", "tags": ["pinup", "close-up"], "tag_meta": {}},
|
| 117 |
+
{"name": "color_markings", "label": "Color Markings", "tags": ["shoulder_markings", "black_markings"], "tag_meta": {}},
|
| 118 |
+
]
|
| 119 |
+
selected_state = []
|
| 120 |
+
row_values_state = [[], []]
|
| 121 |
+
|
| 122 |
+
# User enables close-up in organization row.
|
| 123 |
+
out = app._on_toggle_row(
|
| 124 |
+
0,
|
| 125 |
+
["close-up"],
|
| 126 |
+
selected_state,
|
| 127 |
+
False,
|
| 128 |
+
row_defs,
|
| 129 |
+
row_values_state,
|
| 130 |
+
app.display_max_rows_default,
|
| 131 |
+
)
|
| 132 |
+
selected_after = out[0]
|
| 133 |
+
row_values_after = out[3]
|
| 134 |
+
_assert("close-up" in selected_after, "close-up should be selected")
|
| 135 |
+
_assert("shoulder_markings" not in selected_after, "unrelated row tag should not be auto-selected")
|
| 136 |
+
_assert(row_values_after[0] == ["close-up"], "organization row values should include close-up only")
|
| 137 |
+
_assert(row_values_after[1] == [], "color markings row should remain unselected")
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def test_shared_tag_mirrors_without_unrelated_cross_toggle() -> None:
|
| 141 |
+
row_defs = [
|
| 142 |
+
{"name": "objects_props", "label": "Objects Props", "tags": ["holding_face", "holding_clothing"], "tag_meta": {}},
|
| 143 |
+
{"name": "expression_detail", "label": "Expression Detail", "tags": ["open_mouth", "closed_smile"], "tag_meta": {}},
|
| 144 |
+
{"name": "pose_action_detail", "label": "Pose Action Detail", "tags": ["holding_face", "walking"], "tag_meta": {}},
|
| 145 |
+
]
|
| 146 |
+
selected_state = []
|
| 147 |
+
row_values_state = [[], [], []]
|
| 148 |
+
|
| 149 |
+
# Enable open_mouth; should not affect holding_face rows.
|
| 150 |
+
out1 = app._on_toggle_row(
|
| 151 |
+
1,
|
| 152 |
+
["open_mouth"],
|
| 153 |
+
selected_state,
|
| 154 |
+
False,
|
| 155 |
+
row_defs,
|
| 156 |
+
row_values_state,
|
| 157 |
+
app.display_max_rows_default,
|
| 158 |
+
)
|
| 159 |
+
sel1 = out1[0]
|
| 160 |
+
vals1 = out1[3]
|
| 161 |
+
_assert("open_mouth" in sel1, "open_mouth should be selected")
|
| 162 |
+
_assert("holding_face" not in sel1, "holding_face must remain unselected")
|
| 163 |
+
_assert(vals1[0] == [], "objects props row should remain unselected")
|
| 164 |
+
_assert(vals1[1] == ["open_mouth"], "expression row should select open_mouth")
|
| 165 |
+
_assert(vals1[2] == [], "pose row should remain unselected")
|
| 166 |
+
|
| 167 |
+
# Enable holding_face in objects row; should mirror only to pose row, not expression row.
|
| 168 |
+
out2 = app._on_toggle_row(
|
| 169 |
+
0,
|
| 170 |
+
["holding_face"],
|
| 171 |
+
sel1,
|
| 172 |
+
True,
|
| 173 |
+
row_defs,
|
| 174 |
+
vals1,
|
| 175 |
+
app.display_max_rows_default,
|
| 176 |
+
)
|
| 177 |
+
sel2 = out2[0]
|
| 178 |
+
vals2 = out2[3]
|
| 179 |
+
_assert("holding_face" in sel2 and "open_mouth" in sel2, "both explicitly selected tags should be present")
|
| 180 |
+
_assert(vals2[0] == ["holding_face"], "objects row should select holding_face")
|
| 181 |
+
_assert(vals2[1] == ["open_mouth"], "expression row should keep open_mouth only")
|
| 182 |
+
_assert(vals2[2] == ["holding_face"], "pose row should mirror holding_face")
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def main() -> None:
|
| 186 |
+
test_prompt_uses_visible_rows_only()
|
| 187 |
+
test_row_deduping()
|
| 188 |
+
test_rebuild_ignores_stale_selected_state()
|
| 189 |
+
test_toggle_then_rebuild_does_not_resurrect_removed_tag()
|
| 190 |
+
test_toggle_does_not_cross_activate_unrelated_row_tag()
|
| 191 |
+
test_shared_tag_mirrors_without_unrelated_cross_toggle()
|
| 192 |
+
print("ui state smoke: ok")
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
if __name__ == "__main__":
|
| 196 |
+
main()
|