Food Desert committed on
Commit
827e786
·
1 Parent(s): 6e50f4d

Fix UI tag-button desync and add regression smoke coverage

Browse files
Files changed (2) hide show
  1. app.py +321 -185
  2. scripts/smoke_ui_state.py +196 -0
app.py CHANGED
@@ -4,7 +4,6 @@ import logging
4
  import time
5
  import json
6
  import csv
7
- import base64
8
  from datetime import datetime
9
  from functools import lru_cache
10
  from PIL import Image
@@ -68,53 +67,9 @@ def _normalize_selection_origin(origin: str) -> str:
68
  return "selection"
69
 
70
 
71
- @lru_cache(maxsize=1)
72
- def _load_tag_wiki_defs() -> Dict[str, str]:
73
- p = Path("data/tag_wiki_defs.json")
74
- if not p.exists():
75
- return {}
76
- try:
77
- with p.open("r", encoding="utf-8") as f:
78
- data = json.load(f)
79
- out: Dict[str, str] = {}
80
- if isinstance(data, dict):
81
- for k, v in data.items():
82
- tag = _norm_tag_for_lookup(str(k))
83
- text = " ".join(str(v or "").split())
84
- if tag and text:
85
- out[tag] = text
86
- return out
87
- except Exception:
88
- return {}
89
-
90
-
91
- def _tooltip_text_for_tag(tag: str) -> str:
92
- t = _norm_tag_for_lookup(tag)
93
- parts: List[str] = []
94
-
95
- try:
96
- count = get_tag_counts().get(t)
97
- except Exception:
98
- count = None
99
- if isinstance(count, int):
100
- parts.append(f"Count: {count:,}")
101
-
102
- d = _load_tag_wiki_defs().get(t, "")
103
- if d:
104
- parts.append(d)
105
-
106
- return "\n".join(parts).strip()
107
-
108
-
109
  def _choice_label_with_source_meta(tag: str, *, origin: str, preselected: bool) -> str:
110
- # Marker is stripped client-side and converted into data attributes for CSS-driven colors/tooltips.
111
- origin_norm = _normalize_selection_origin(origin)
112
- pre = "1" if preselected else "0"
113
- tooltip = _tooltip_text_for_tag(tag)
114
- tip_b64 = ""
115
- if tooltip:
116
- tip_b64 = base64.urlsafe_b64encode(tooltip.encode("utf-8")).decode("ascii")
117
- return f"{_display_tag_text(tag)} [[psq:{origin_norm}:{pre}:{tip_b64}]]"
118
 
119
 
120
  def _selection_source_rank(origin: str) -> int:
@@ -219,20 +174,15 @@ def _escape_prompt_tag(tag: str) -> str:
219
  )
220
 
221
 
222
- def _ordered_selected_for_prompt(selected: Set[str], row_defs: List[Dict[str, Any]]) -> List[str]:
223
- out: List[str] = []
224
- seen: Set[str] = set()
225
- for row in row_defs:
226
- for tag in row.get("tags", []):
227
- if tag in selected and tag not in seen:
228
- out.append(tag)
229
- seen.add(tag)
230
- # Fallback for any selected tags not present in current rows.
231
- for tag in sorted(selected):
232
- if tag not in seen:
233
- out.append(tag)
234
- seen.add(tag)
235
- return out
236
 
237
 
238
  def _compose_toggle_prompt_text(selected_tags: List[str], row_defs: List[Dict[str, Any]]) -> str:
@@ -308,7 +258,7 @@ def _filter_min_count_tags(tags: List[str], min_count: int) -> Tuple[List[str],
308
  return keep, sorted(set(removed))
309
 
310
 
311
- def _filter_excluded_recommendation_tags(tags: List[str]) -> Tuple[List[str], List[str]]:
312
  excluded = _load_excluded_recommendation_tags()
313
  if not excluded:
314
  return list(dict.fromkeys(_norm_tag_for_lookup(t) for t in (tags or []) if t)), []
@@ -327,9 +277,79 @@ def _filter_excluded_recommendation_tags(tags: List[str]) -> Tuple[List[str], Li
327
  continue
328
  seen.add(t)
329
  keep.append(t)
330
- return keep, sorted(set(removed))
331
-
332
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
  def _build_toggle_rows(
334
  *,
335
  seed_terms: List[str],
@@ -400,6 +420,7 @@ def _build_toggle_rows(
400
  merged_retrieved_other = selected_in_retrieved_other + [
401
  t for t in retrieved_uncategorized_ranked if t not in selected_in_retrieved_other
402
  ]
 
403
  keep_n = max(max(1, int(top_tags_per_group)), len(selected_in_retrieved_other))
404
  merged_retrieved_other = merged_retrieved_other[:keep_n]
405
  retrieved_other_meta = {
@@ -428,6 +449,7 @@ def _build_toggle_rows(
428
  tag_selection_origins=tag_selection_origins,
429
  implied_parent_map=implied_parent_map,
430
  )
 
431
  selected_other_meta = {
432
  t: {
433
  "origin": _normalize_selection_origin(tag_selection_origins.get(t, "selection")),
@@ -454,12 +476,14 @@ def _build_toggle_rows(
454
  tag_selection_origins=tag_selection_origins,
455
  implied_parent_map=implied_parent_map,
456
  )
457
- ranked_tags = [
458
- t
459
- for t, _ in row.tags
460
- if not _is_artist_tag(t) and not _is_excluded_recommendation_tag(t)
461
- ]
 
462
  merged = selected_in_group + [t for t in ranked_tags if t not in selected_in_group]
 
463
  keep_n = max(max(1, int(top_tags_per_group)), len(selected_in_group))
464
  merged = merged[:keep_n]
465
  tag_meta = {
@@ -545,17 +569,18 @@ def _build_display_audit_line(
545
  def _build_row_component_updates(
546
  row_defs: List[Dict[str, Any]],
547
  selected_tags: List[str],
548
- max_rows: int,
549
- ):
550
- selected = {t for t in (selected_tags or []) if t}
551
- row_values_state: List[List[str]] = []
552
- header_updates = []
553
- checkbox_updates = []
554
-
555
- for idx in range(max_rows):
556
- if idx < len(row_defs):
557
- row = row_defs[idx]
558
- tags = list(dict.fromkeys(row.get("tags", [])))
 
559
  values = [t for t in tags if t in selected]
560
  row_values_state.append(values)
561
  visible = bool(tags)
@@ -574,33 +599,51 @@ def _build_row_component_updates(
574
  visible=visible,
575
  )
576
  )
577
- else:
578
- header_updates.append(gr.update(value="", visible=False))
579
- checkbox_updates.append(gr.update(choices=[], value=[], visible=False))
580
-
581
- prompt_text = _compose_toggle_prompt_text(list(selected), row_defs)
582
- return prompt_text, row_values_state, header_updates, checkbox_updates
583
 
584
 
585
  def _on_toggle_row(
586
  row_idx: int,
587
  changed_values: List[str],
588
  selected_tags_state: List[str],
589
- row_defs_state: List[Dict[str, Any]],
590
- row_values_state: List[List[str]],
591
- max_rows: int,
 
592
  ):
593
  row_defs = row_defs_state or []
594
  row_defs_ui = row_defs[: max(0, int(max_rows))]
595
- selected = set(selected_tags_state or [])
 
 
 
 
 
596
  row = row_defs_ui[row_idx] if 0 <= row_idx < len(row_defs_ui) else {}
597
- row_tags = list(dict.fromkeys(row.get("tags", [])))
 
598
  row_tag_set = set(row_tags)
599
  row_tag_by_norm = {_norm_tag_for_lookup(t): t for t in row_tags}
600
 
 
 
 
 
 
 
 
 
 
 
 
601
  # Be tolerant to UI payload forms: canonical tag values, display labels, or normalized variants.
602
  new_set: Set[str] = set()
603
- for raw in (changed_values or []):
604
  if raw in row_tag_set:
605
  new_set.add(raw)
606
  continue
@@ -609,23 +652,13 @@ def _on_toggle_row(
609
  if mapped:
610
  new_set.add(mapped)
611
 
612
- prev_values = list(row_values_state or [])
613
- prev_row_values = prev_values[row_idx] if 0 <= row_idx < len(prev_values) else []
614
- prev_row_selected = set()
615
- for raw in (prev_row_values or []):
616
- if raw in row_tag_set:
617
- prev_row_selected.add(raw)
618
- continue
619
- raw_norm = _norm_tag_for_lookup(str(raw))
620
- mapped = row_tag_by_norm.get(raw_norm)
621
- if mapped:
622
- prev_row_selected.add(mapped)
623
 
624
  # Ignore non-user/no-op events (e.g., programmatic value re-sets) deterministically.
625
  if new_set == prev_row_selected:
626
  prompt_text = _compose_toggle_prompt_text(sorted(selected), row_defs_ui)
627
  checkbox_updates = [gr.skip() for _ in range(max_rows)]
628
- return [sorted(selected), prev_values, prompt_text, *checkbox_updates]
629
 
630
  selected.difference_update(row_tag_set)
631
  selected.update(new_set)
@@ -634,7 +667,7 @@ def _on_toggle_row(
634
  new_row_values_state: List[List[str]] = []
635
  affected_rows: Set[int] = {row_idx}
636
  for idx, row_item in enumerate(row_defs_ui):
637
- tags = list(dict.fromkeys(row_item.get("tags", [])))
638
  values = [t for t in tags if t in selected]
639
  new_row_values_state.append(values)
640
  if toggled_tags and any(t in toggled_tags for t in tags):
@@ -651,7 +684,14 @@ def _on_toggle_row(
651
  checkbox_updates.append(gr.skip())
652
 
653
  prompt_text = _compose_toggle_prompt_text(sorted(selected), row_defs_ui)
654
- return [sorted(selected), new_row_values_state, prompt_text, *checkbox_updates]
 
 
 
 
 
 
 
655
 
656
 
657
  def _build_ui_payload(
@@ -660,20 +700,30 @@ def _build_ui_payload(
660
  row_defs: List[Dict[str, Any]],
661
  selected_tags: List[str],
662
  ):
663
- prompt_text, row_values_state, header_updates, checkbox_updates = _build_row_component_updates(
664
- row_defs=row_defs,
665
- selected_tags=selected_tags,
666
- max_rows=display_max_rows_default,
667
- )
 
 
 
 
 
 
 
 
668
  return [
669
  console_text,
670
  gr.update(visible=bool(row_defs)),
671
  prompt_text,
672
- sorted(set(selected_tags or [])),
 
 
673
  row_defs,
674
  row_values_state,
675
  *header_updates,
676
- *checkbox_updates,
677
  ]
678
 
679
 
@@ -688,6 +738,8 @@ def _prepare_run_ui() -> List[Any]:
688
  gr.skip(),
689
  "Running... usually completes in about 20 seconds.",
690
  [],
 
 
691
  [],
692
  [],
693
  *header_updates,
@@ -695,6 +747,93 @@ def _prepare_run_ui() -> List[Any]:
695
  ]
696
 
697
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
698
  def _build_selection_query(
699
  prompt_in: str,
700
  rewritten: str,
@@ -1116,55 +1255,18 @@ css = """
1116
  """
1117
 
1118
  client_js = """
1119
- () => {
1120
- const markerRe = /\\s*\\[\\[psq:([a-z_]+):(0|1):([A-Za-z0-9_\\-=]*)\\]\\]\\s*$/;
1121
- const decodeTip = (b64) => {
1122
- if (!b64) return "";
1123
- try {
1124
- const binary = atob((b64 || "").replace(/-/g, "+").replace(/_/g, "/"));
1125
- const bytes = Uint8Array.from(binary, (c) => c.charCodeAt(0));
1126
- return new TextDecoder("utf-8").decode(bytes);
1127
- } catch (_) {
1128
- return "";
1129
- }
1130
- };
1131
- const applyTagMeta = () => {
1132
- const labels = document.querySelectorAll(".lego-tags label");
1133
- labels.forEach((label) => {
1134
- const span = label.querySelector("span");
1135
- if (!span) return;
1136
- const text = span.textContent || "";
1137
- const match = text.match(markerRe);
1138
- if (!match) return;
1139
- label.dataset.psqOrigin = match[1];
1140
- label.dataset.psqPreselected = match[2];
1141
- const tip = decodeTip(match[3] || "");
1142
- if (tip) {
1143
- label.title = tip;
1144
- span.title = tip;
1145
- } else {
1146
- label.removeAttribute("title");
1147
- span.removeAttribute("title");
1148
- }
1149
- span.textContent = text.replace(markerRe, "");
1150
- });
1151
- };
1152
-
1153
- applyTagMeta();
1154
- const observer = new MutationObserver(() => applyTagMeta());
1155
- observer.observe(document.body, { childList: true, subtree: true, characterData: true });
1156
- }
1157
  """
1158
 
1159
 
1160
  def rag_pipeline_ui(
1161
- user_prompt: str,
1162
- display_top_groups: float,
1163
- display_top_tags_per_group: float,
1164
- display_rank_top_k: float,
1165
- ):
1166
- logs = []
1167
- def log(s): logs.append(s)
1168
 
1169
  try:
1170
  stage_timings = {}
@@ -1618,9 +1720,9 @@ def rag_pipeline_ui(
1618
  top_tags_per_group=max(1, int(display_top_tags_per_group)),
1619
  group_rank_top_k=max(1, int(display_rank_top_k)),
1620
  )
1621
- dt = time.perf_counter()-t0
1622
- _record_timing("group_display", dt)
1623
- log(f"Ranked group display: {dt:.2f}s ({len(toggle_rows)} rows)")
1624
  log(
1625
  _build_display_audit_line(
1626
  toggle_rows,
@@ -1629,6 +1731,9 @@ def rag_pipeline_ui(
1629
  implied_selected_tags=implied_selected_tags,
1630
  )
1631
  )
 
 
 
1632
 
1633
  total_dt = time.perf_counter()-t_total0
1634
  _emit_timing_summary(total_dt)
@@ -1690,6 +1795,7 @@ with gr.Blocks(css=css, js=client_js) as app:
1690
  gr.Markdown("Typical runtime: up to ~20 seconds.", elem_classes=["run-hint"])
1691
 
1692
  selected_tags_state = gr.State([])
 
1693
  row_defs_state = gr.State([])
1694
  row_values_state = gr.State([])
1695
 
@@ -1716,20 +1822,29 @@ with gr.Blocks(css=css, js=client_js) as app:
1716
  )
1717
  )
1718
 
1719
- gr.HTML(
1720
- """
1721
- <div class="source-legend">
1722
- <span class="legend-title">Legend:</span>
1723
- <span class="chip rewrite">Rewrite phrase</span>
1724
- <span class="chip selection">General selection</span>
1725
- <span class="chip probe">Probe query</span>
1726
- <span class="chip structural">Structural query</span>
1727
- <span class="chip implied">Implied</span>
1728
- <span class="chip user">User-toggled</span>
1729
- <span class="chip unselected">Unselected</span>
1730
- </div>
1731
- """
1732
- )
 
 
 
 
 
 
 
 
 
1733
 
1734
  with gr.Accordion("Display Settings", open=False):
1735
  with gr.Row():
@@ -1765,11 +1880,13 @@ with gr.Blocks(css=css, js=client_js) as app:
1765
  toggle_instruction,
1766
  suggested_prompt,
1767
  selected_tags_state,
 
 
1768
  row_defs_state,
1769
  row_values_state,
1770
- *row_headers,
1771
- *row_checkboxes,
1772
- ]
1773
 
1774
  submit_button.click(
1775
  _prepare_run_ui,
@@ -1796,17 +1913,36 @@ with gr.Blocks(css=css, js=client_js) as app:
1796
  )
1797
 
1798
  for idx, row_cb in enumerate(row_checkboxes):
1799
- row_cb.select(
1800
- fn=lambda changed_values, selected_state, row_defs, row_values, i=idx: _on_toggle_row(
1801
  i,
1802
  changed_values,
1803
  selected_state,
 
1804
  row_defs,
1805
- row_values,
1806
- display_max_rows_default,
1807
  ),
1808
- inputs=[row_cb, selected_tags_state, row_defs_state, row_values_state],
1809
- outputs=[selected_tags_state, row_values_state, suggested_prompt, *row_checkboxes],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1810
  queue=False,
1811
  show_progress="hidden",
1812
  )
 
4
  import time
5
  import json
6
  import csv
 
7
  from datetime import datetime
8
  from functools import lru_cache
9
  from PIL import Image
 
67
  return "selection"
68
 
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  def _choice_label_with_source_meta(tag: str, *, origin: str, preselected: bool) -> str:
71
+ # Keep labels plain to avoid frontend text/value desynchronization.
72
+ return _display_tag_text(tag)
 
 
 
 
 
 
73
 
74
 
75
  def _selection_source_rank(origin: str) -> int:
 
174
  )
175
 
176
 
177
def _ordered_selected_for_prompt(selected: Set[str], row_defs: List[Dict[str, Any]]) -> List[str]:
    """Return the selected tags in the order they appear across the rows.

    Only tags listed in some row's ``"tags"`` entry are returned; a selected
    tag that no row contains is omitted. Each tag is emitted once, at its
    first occurrence.

    Args:
        selected: Tags currently selected (any container supporting ``in``).
        row_defs: Row definitions; each is expected to be a dict with a
            ``"tags"`` list. ``None``, non-dict rows and missing/``None`` tag
            lists are tolerated and contribute nothing.

    Returns:
        The selected tags, de-duplicated, in row order.
    """
    chosen = selected or ()
    # dict preserves insertion order, so it doubles as an ordered "seen" set.
    ordered: Dict[str, None] = {}
    for row in row_defs or []:
        if not isinstance(row, dict):
            continue
        for tag in row.get("tags") or []:
            if tag in chosen:
                ordered.setdefault(tag, None)
    return list(ordered)
 
 
 
 
 
186
 
187
 
188
  def _compose_toggle_prompt_text(selected_tags: List[str], row_defs: List[Dict[str, Any]]) -> str:
 
258
  return keep, sorted(set(removed))
259
 
260
 
261
+ def _filter_excluded_recommendation_tags(tags: List[str]) -> Tuple[List[str], List[str]]:
262
  excluded = _load_excluded_recommendation_tags()
263
  if not excluded:
264
  return list(dict.fromkeys(_norm_tag_for_lookup(t) for t in (tags or []) if t)), []
 
277
  continue
278
  seen.add(t)
279
  keep.append(t)
280
+ return keep, sorted(set(removed))
281
+
282
+
283
def _dedupe_norm_tags(tags: List[str]) -> List[str]:
    """Normalize every entry and drop duplicates, keeping first-seen order.

    Each entry is converted with ``str()`` and passed through
    ``_norm_tag_for_lookup``. Entries whose normalized form is falsy are
    skipped. ``None`` or an empty input yields an empty list.
    """
    unique: Dict[str, None] = {}
    for entry in tags or []:
        normalized = _norm_tag_for_lookup(str(entry))
        if normalized:
            # setdefault keeps the position of the first occurrence.
            unique.setdefault(normalized, None)
    return list(unique)
293
+
294
+
295
def _collect_visible_tags(row_defs: List[Dict[str, Any]]) -> Set[str]:
    """Gather the normalized tags of every row into one set.

    Rows that are not dicts contribute nothing; ``None`` is treated as no rows.
    """
    visible: Set[str] = set()
    for row in row_defs or []:
        raw_tags = row.get("tags", []) if isinstance(row, dict) else []
        visible.update(_dedupe_norm_tags(raw_tags))
    return visible
301
+
302
+
303
def _collect_selected_from_state(
    selected_tags_state: List[str],
    row_defs: List[Dict[str, Any]],
) -> List[str]:
    """Filter the stored selection down to tags present in the given rows.

    Each stored entry is normalized with ``_norm_tag_for_lookup`` and kept only
    if it resolves to a tag returned by ``_collect_visible_tags(row_defs)``.
    The result is de-duplicated and follows the order of
    ``selected_tags_state``. When the rows expose no tags at all, the result
    is empty.
    """
    visible = _collect_visible_tags(row_defs)
    if not visible:
        return []

    canonical_by_norm = {_norm_tag_for_lookup(tag): tag for tag in visible}
    resolved: Dict[str, None] = {}
    for entry in selected_tags_state or []:
        key = _norm_tag_for_lookup(str(entry))
        if not key:
            continue
        # Direct hit first; otherwise go through the normalized lookup table.
        canonical = key if key in visible else canonical_by_norm.get(key)
        if canonical:
            resolved.setdefault(canonical, None)
    return list(resolved)
323
+
324
+
325
def _collect_selected_from_row_values(
    row_defs: List[Dict[str, Any]],
    row_values_state: List[List[str]],
) -> List[str]:
    """Derive the selection from the per-row checkbox values.

    ``row_values_state[i]`` holds the values recorded for ``row_defs[i]``.
    A value is accepted when it equals one of the row's normalized tags, or
    when its normalized form (via ``_norm_tag_for_lookup``) maps to one.
    Rows without tags, and rows with no recorded values, contribute nothing.

    Returns:
        Accepted tags, de-duplicated, in row order then value order.
    """
    picked: Dict[str, None] = {}
    per_row_values = list(row_values_state or [])

    for position, row in enumerate(row_defs or []):
        tags = _dedupe_norm_tags(row.get("tags", []) if isinstance(row, dict) else [])
        if not tags:
            continue
        known = set(tags)
        by_norm = {_norm_tag_for_lookup(tag): tag for tag in tags}
        current = per_row_values[position] if position < len(per_row_values) else []
        for value in current or []:
            if value in known:
                canonical = value
            else:
                canonical = by_norm.get(_norm_tag_for_lookup(str(value)))
            if canonical:
                picked.setdefault(canonical, None)
    return list(picked)
351
+
352
+
353
  def _build_toggle_rows(
354
  *,
355
  seed_terms: List[str],
 
420
  merged_retrieved_other = selected_in_retrieved_other + [
421
  t for t in retrieved_uncategorized_ranked if t not in selected_in_retrieved_other
422
  ]
423
+ merged_retrieved_other = _dedupe_norm_tags(merged_retrieved_other)
424
  keep_n = max(max(1, int(top_tags_per_group)), len(selected_in_retrieved_other))
425
  merged_retrieved_other = merged_retrieved_other[:keep_n]
426
  retrieved_other_meta = {
 
449
  tag_selection_origins=tag_selection_origins,
450
  implied_parent_map=implied_parent_map,
451
  )
452
+ selected_other = _dedupe_norm_tags(selected_other)
453
  selected_other_meta = {
454
  t: {
455
  "origin": _normalize_selection_origin(tag_selection_origins.get(t, "selection")),
 
476
  tag_selection_origins=tag_selection_origins,
477
  implied_parent_map=implied_parent_map,
478
  )
479
+ ranked_tags = [
480
+ _norm_tag_for_lookup(t)
481
+ for t, _ in row.tags
482
+ if t and not _is_artist_tag(t) and not _is_excluded_recommendation_tag(t)
483
+ ]
484
+ ranked_tags = _dedupe_norm_tags(ranked_tags)
485
  merged = selected_in_group + [t for t in ranked_tags if t not in selected_in_group]
486
+ merged = _dedupe_norm_tags(merged)
487
  keep_n = max(max(1, int(top_tags_per_group)), len(selected_in_group))
488
  merged = merged[:keep_n]
489
  tag_meta = {
 
569
  def _build_row_component_updates(
570
  row_defs: List[Dict[str, Any]],
571
  selected_tags: List[str],
572
+ max_rows: int,
573
+ ):
574
+ selected = {t for t in (selected_tags or []) if t}
575
+ row_defs_ui = (row_defs or [])[: max(0, int(max_rows))]
576
+ row_values_state: List[List[str]] = []
577
+ header_updates = []
578
+ checkbox_updates = []
579
+
580
+ for idx in range(max_rows):
581
+ if idx < len(row_defs_ui):
582
+ row = row_defs_ui[idx]
583
+ tags = _dedupe_norm_tags(row.get("tags", []))
584
  values = [t for t in tags if t in selected]
585
  row_values_state.append(values)
586
  visible = bool(tags)
 
599
  visible=visible,
600
  )
601
  )
602
+ else:
603
+ header_updates.append(gr.update(value="", visible=False))
604
+ checkbox_updates.append(gr.update(choices=[], value=[], visible=False))
605
+
606
+ prompt_text = _compose_toggle_prompt_text(list(selected), row_defs_ui)
607
+ return prompt_text, row_values_state, header_updates, checkbox_updates
608
 
609
 
610
  def _on_toggle_row(
611
  row_idx: int,
612
  changed_values: List[str],
613
  selected_tags_state: List[str],
614
+ rows_dirty_state: bool,
615
+ row_defs_state: List[Dict[str, Any]],
616
+ row_values_state: List[List[str]],
617
+ max_rows: int,
618
  ):
619
  row_defs = row_defs_state or []
620
  row_defs_ui = row_defs[: max(0, int(max_rows))]
621
+ prev_values = list(row_values_state or [])
622
+ selected_from_state = _collect_selected_from_state(selected_tags_state, row_defs_ui)
623
+ selected_from_rows = _collect_selected_from_row_values(row_defs_ui, prev_values)
624
+ # Prefer row-value state as source-of-truth (closest to visible UI), with selected-state as fallback.
625
+ selected: Set[str] = set(selected_from_rows or selected_from_state)
626
+
627
  row = row_defs_ui[row_idx] if 0 <= row_idx < len(row_defs_ui) else {}
628
+ row_tags = _dedupe_norm_tags(row.get("tags", []))
629
+ row_label = str(row.get("label", ""))
630
  row_tag_set = set(row_tags)
631
  row_tag_by_norm = {_norm_tag_for_lookup(t): t for t in row_tags}
632
 
633
+ # Be tolerant to UI payload forms: canonical tag values, display labels, normalized variants,
634
+ # and occasional single-string payloads from frontend events.
635
+ if changed_values is None:
636
+ changed_iter: List[Any] = []
637
+ elif isinstance(changed_values, str):
638
+ changed_iter = [changed_values]
639
+ elif isinstance(changed_values, (list, tuple, set)):
640
+ changed_iter = list(changed_values)
641
+ else:
642
+ changed_iter = [changed_values]
643
+
644
  # Be tolerant to UI payload forms: canonical tag values, display labels, or normalized variants.
645
  new_set: Set[str] = set()
646
+ for raw in changed_iter:
647
  if raw in row_tag_set:
648
  new_set.add(raw)
649
  continue
 
652
  if mapped:
653
  new_set.add(mapped)
654
 
655
+ prev_row_selected = {t for t in row_tags if t in selected}
 
 
 
 
 
 
 
 
 
 
656
 
657
  # Ignore non-user/no-op events (e.g., programmatic value re-sets) deterministically.
658
  if new_set == prev_row_selected:
659
  prompt_text = _compose_toggle_prompt_text(sorted(selected), row_defs_ui)
660
  checkbox_updates = [gr.skip() for _ in range(max_rows)]
661
+ return [sorted(selected), rows_dirty_state, gr.skip(), prev_values, prompt_text, *checkbox_updates]
662
 
663
  selected.difference_update(row_tag_set)
664
  selected.update(new_set)
 
667
  new_row_values_state: List[List[str]] = []
668
  affected_rows: Set[int] = {row_idx}
669
  for idx, row_item in enumerate(row_defs_ui):
670
+ tags = _dedupe_norm_tags(row_item.get("tags", []))
671
  values = [t for t in tags if t in selected]
672
  new_row_values_state.append(values)
673
  if toggled_tags and any(t in toggled_tags for t in tags):
 
684
  checkbox_updates.append(gr.skip())
685
 
686
  prompt_text = _compose_toggle_prompt_text(sorted(selected), row_defs_ui)
687
+ return [
688
+ sorted(selected),
689
+ True,
690
+ gr.update(visible=True, interactive=True),
691
+ new_row_values_state,
692
+ prompt_text,
693
+ *checkbox_updates,
694
+ ]
695
 
696
 
697
  def _build_ui_payload(
 
700
  row_defs: List[Dict[str, Any]],
701
  selected_tags: List[str],
702
  ):
703
+ prompt_text, row_values_state, header_updates, checkbox_updates = _build_row_component_updates(
704
+ row_defs=row_defs,
705
+ selected_tags=selected_tags,
706
+ max_rows=display_max_rows_default,
707
+ )
708
+ selected_ui: List[str] = []
709
+ selected_ui_seen: Set[str] = set()
710
+ for vals in row_values_state:
711
+ for t in vals:
712
+ if t in selected_ui_seen:
713
+ continue
714
+ selected_ui_seen.add(t)
715
+ selected_ui.append(t)
716
  return [
717
  console_text,
718
  gr.update(visible=bool(row_defs)),
719
  prompt_text,
720
+ selected_ui,
721
+ False,
722
+ gr.update(visible=False, interactive=False),
723
  row_defs,
724
  row_values_state,
725
  *header_updates,
726
+ *checkbox_updates,
727
  ]
728
 
729
 
 
738
  gr.skip(),
739
  "Running... usually completes in about 20 seconds.",
740
  [],
741
+ False,
742
+ gr.update(visible=False, interactive=False),
743
  [],
744
  [],
745
  *header_updates,
 
747
  ]
748
 
749
 
750
def _rebuild_rows_from_selected(
    selected_tags_state: List[str],
    row_defs_state: List[Dict[str, Any]],
    row_values_state: List[List[str]],
    display_top_groups: float,
    display_top_tags_per_group: float,
    display_rank_top_k: float,
):
    """Regenerate the toggle rows around the tags that are currently selected.

    The selection is taken from the per-row checkbox values when any are
    recorded, otherwise from ``selected_tags_state``. Artist tags and excluded
    recommendation tags are dropped. Candidate tags and their origins are
    harvested from the existing rows, then ``_build_toggle_rows`` and
    ``_build_row_component_updates`` produce the new rows and widget updates.

    Returns:
        A list: visibility update, prompt text, sorted selected tags, the
        dirty flag reset to ``False``, an update hiding the rebuild button,
        the new row definitions, the new per-row values, then the header
        updates followed by the checkbox updates.
    """
    current_rows = row_defs_state or []
    current_values = list(row_values_state or [])

    from_state = _collect_selected_from_state(selected_tags_state, current_rows)
    from_rows = _collect_selected_from_row_values(current_rows, current_values)
    # Row checkbox values are the source of truth; stored state is only a fallback.
    seed = from_rows if current_values else from_state

    active_keys: Dict[str, None] = {}
    for tag in seed:
        if not tag or _is_artist_tag(tag) or _is_excluded_recommendation_tag(tag):
            continue
        active_keys.setdefault(_norm_tag_for_lookup(tag), None)
    selected_active = list(active_keys)

    # Harvest candidates and the first-seen origin of each tag from the existing rows.
    candidates: List[str] = []
    origins: Dict[str, str] = {}
    for row in current_rows:
        is_mapping = isinstance(row, dict)
        tags_in_row = row.get("tags", []) if is_mapping else []
        meta_by_tag = row.get("tag_meta", {}) if is_mapping else {}
        if not isinstance(meta_by_tag, dict):
            meta_by_tag = {}
        for raw_tag in tags_in_row:
            norm_tag = _norm_tag_for_lookup(raw_tag)
            if not norm_tag or _is_artist_tag(norm_tag) or _is_excluded_recommendation_tag(norm_tag):
                continue
            candidates.append(norm_tag)
            if norm_tag in origins:
                continue
            entry = meta_by_tag.get(raw_tag, {})
            if not isinstance(entry, dict):
                entry = {}
            origins[norm_tag] = _normalize_selection_origin(str(entry.get("origin", "selection")))

    # Selected tags with no recorded origin are attributed to the user.
    for tag in selected_active:
        origins.setdefault(tag, "user")
        candidates.append(tag)

    implied_tags: List[str] = []
    direct_tags: List[str] = []
    for tag in selected_active:
        if origins.get(tag) == "implied":
            implied_tags.append(tag)
        else:
            direct_tags.append(tag)

    # Order direct tags by source rank, ties broken by their original position.
    ranked_direct = sorted(
        enumerate(direct_tags),
        key=lambda item: (
            _selection_source_rank(origins.get(item[1], "selection")),
            item[0],
        ),
    )
    direct_tags = [tag for _, tag in ranked_direct]

    parent_map = _build_implied_parent_map(
        direct_tags_ordered=direct_tags,
        implied_tags=implied_tags,
    )

    rebuilt_rows = _build_toggle_rows(
        seed_terms=list(selected_active),
        selected_tags=selected_active,
        retrieved_candidate_tags=list(dict.fromkeys(candidates)),
        tag_selection_origins=origins,
        implied_parent_map=parent_map,
        top_groups=max(1, int(display_top_groups)),
        top_tags_per_group=max(1, int(display_top_tags_per_group)),
        group_rank_top_k=max(1, int(display_rank_top_k)),
    )

    prompt_text, rebuilt_values, header_updates, checkbox_updates = _build_row_component_updates(
        row_defs=rebuilt_rows,
        selected_tags=selected_active,
        max_rows=display_max_rows_default,
    )

    result: List[Any] = [
        gr.update(visible=bool(rebuilt_rows)),
        prompt_text,
        sorted(selected_active),
        False,
        gr.update(visible=False, interactive=False),
        rebuilt_rows,
        rebuilt_values,
    ]
    result.extend(header_updates)
    result.extend(checkbox_updates)
    return result
835
+
836
+
837
  def _build_selection_query(
838
  prompt_in: str,
839
  rewritten: str,
 
1255
  """
1256
 
1257
  client_js = """
1258
+ () => {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1259
  """
1260
 
1261
 
1262
  def rag_pipeline_ui(
1263
+ user_prompt: str,
1264
+ display_top_groups: float,
1265
+ display_top_tags_per_group: float,
1266
+ display_rank_top_k: float,
1267
+ ):
1268
+ logs = []
1269
+ def log(s): logs.append(s)
1270
 
1271
  try:
1272
  stage_timings = {}
 
1720
  top_tags_per_group=max(1, int(display_top_tags_per_group)),
1721
  group_rank_top_k=max(1, int(display_rank_top_k)),
1722
  )
1723
+ dt = time.perf_counter()-t0
1724
+ _record_timing("group_display", dt)
1725
+ log(f"Ranked group display: {dt:.2f}s ({len(toggle_rows)} rows)")
1726
  log(
1727
  _build_display_audit_line(
1728
  toggle_rows,
 
1731
  implied_selected_tags=implied_selected_tags,
1732
  )
1733
  )
1734
+ for idx, row in enumerate(toggle_rows[: max(0, int(display_max_rows_default))]):
1735
+ tags_preview = ", ".join(row.get("tags", []))
1736
+ log(f"UI Row {idx}: {row.get('label', '')} :: {tags_preview}")
1737
 
1738
  total_dt = time.perf_counter()-t_total0
1739
  _emit_timing_summary(total_dt)
 
1795
  gr.Markdown("Typical runtime: up to ~20 seconds.", elem_classes=["run-hint"])
1796
 
1797
  selected_tags_state = gr.State([])
1798
+ rows_dirty_state = gr.State(False)
1799
  row_defs_state = gr.State([])
1800
  row_values_state = gr.State([])
1801
 
 
1822
  )
1823
  )
1824
 
1825
+ with gr.Row():
1826
+ with gr.Column(scale=10):
1827
+ gr.HTML(
1828
+ """
1829
+ <div class="source-legend">
1830
+ <span class="legend-title">Legend:</span>
1831
+ <span class="chip rewrite">Rewrite phrase</span>
1832
+ <span class="chip selection">General selection</span>
1833
+ <span class="chip probe">Probe query</span>
1834
+ <span class="chip structural">Structural query</span>
1835
+ <span class="chip implied">Implied</span>
1836
+ <span class="chip user">User-toggled</span>
1837
+ <span class="chip unselected">Unselected</span>
1838
+ </div>
1839
+ """
1840
+ )
1841
+ with gr.Column(scale=2, min_width=180):
1842
+ rebuild_rows_button = gr.Button(
1843
+ "Rebuild Rows",
1844
+ variant="primary",
1845
+ visible=False,
1846
+ interactive=False,
1847
+ )
1848
 
1849
  with gr.Accordion("Display Settings", open=False):
1850
  with gr.Row():
 
1880
  toggle_instruction,
1881
  suggested_prompt,
1882
  selected_tags_state,
1883
+ rows_dirty_state,
1884
+ rebuild_rows_button,
1885
  row_defs_state,
1886
  row_values_state,
1887
+ *row_headers,
1888
+ *row_checkboxes,
1889
+ ]
1890
 
1891
  submit_button.click(
1892
  _prepare_run_ui,
 
1913
  )
1914
 
1915
  for idx, row_cb in enumerate(row_checkboxes):
1916
+ row_cb.change(
1917
+ fn=lambda changed_values, selected_state, rows_dirty, row_defs, row_values, i=idx: _on_toggle_row(
1918
  i,
1919
  changed_values,
1920
  selected_state,
1921
+ rows_dirty,
1922
  row_defs,
1923
+ row_values,
1924
+ display_max_rows_default,
1925
  ),
1926
+ inputs=[row_cb, selected_tags_state, rows_dirty_state, row_defs_state, row_values_state],
1927
+ outputs=[selected_tags_state, rows_dirty_state, rebuild_rows_button, row_values_state, suggested_prompt, *row_checkboxes],
1928
+ queue=False,
1929
+ show_progress="hidden",
1930
+ )
1931
+
1932
+ rebuild_rows_button.click(
1933
+ _rebuild_rows_from_selected,
1934
+ inputs=[selected_tags_state, row_defs_state, row_values_state, display_top_groups, display_top_tags_per_group, display_rank_top_k],
1935
+ outputs=[
1936
+ toggle_instruction,
1937
+ suggested_prompt,
1938
+ selected_tags_state,
1939
+ rows_dirty_state,
1940
+ rebuild_rows_button,
1941
+ row_defs_state,
1942
+ row_values_state,
1943
+ *row_headers,
1944
+ *row_checkboxes,
1945
+ ],
1946
  queue=False,
1947
  show_progress="hidden",
1948
  )
scripts/smoke_ui_state.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from pathlib import Path
3
+
4
+ sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
5
+ import app
6
+
7
+
8
def _assert(cond: bool, msg: str) -> None:
    """Raise ``AssertionError(msg)`` unless *cond* is truthy.

    This is an explicit ``raise`` rather than an ``assert`` statement, so it is
    not stripped when Python runs with ``-O``.

    Args:
        cond: Condition that is expected to hold.
        msg: Message carried by the ``AssertionError`` when it does not.

    Raises:
        AssertionError: If *cond* is falsy.
    """
    if cond:
        return
    raise AssertionError(msg)
11
+
12
+
13
def test_prompt_uses_visible_rows_only() -> None:
    """Check that a selected tag absent from every row is dropped from the payload.

    The selected state passed in contains ``rosalina_(mario)``, which none of
    the row definitions list; only ``solo`` is both selected and visible.
    """
    visible_rows = [
        {"name": "r1", "label": "R1", "tags": ["solo", "female"], "tag_meta": {}},
        {"name": "r2", "label": "R2", "tags": ["cub"], "tag_meta": {}},
    ]
    ui_payload = app._build_ui_payload(
        console_text="x",
        row_defs=visible_rows,
        selected_tags=["solo", "rosalina_(mario)"],
    )
    # This test reads position 2 as the prompt text and position 3 as the
    # selected-tag state of the returned payload.
    prompt = ui_payload[2]
    selected_after = ui_payload[3]

    # The prompt is probed for the display form (space + escaped parentheses).
    _assert("rosalina \\(mario\\)" not in prompt, "stale hidden tag leaked into prompt")
    _assert("solo" in prompt, "visible selected tag missing from prompt")
    # The selected state is probed for the raw tag form.
    _assert("rosalina_(mario)" not in selected_after, "stale hidden tag leaked into selected state")
29
+
30
+
31
def test_row_deduping() -> None:
    """Check that duplicate tags within one row collapse in prompt, values and choices."""
    duplicated_row = {
        "name": "other_retrieved",
        "label": "Other (Retrieved)",
        "tags": ["cub", "expressions", "invalid_tag", "cub", "expressions"],
        "tag_meta": {},
    }
    prompt, selected_per_row, _, cb_updates = app._build_row_component_updates(
        row_defs=[duplicated_row],
        selected_tags=["cub", "expressions"],
        max_rows=app.display_max_rows_default,
    )

    _assert(prompt == "cub, expressions", "prompt should be deduped and ordered from row")
    _assert(selected_per_row[0] == ["cub", "expressions"], "row selected values should be deduped")

    # Each choice is unpacked as a pair; only the second item is compared
    # against the expected tag names.
    choice_values = [value for _label, value in cb_updates[0]["choices"]]
    _assert(choice_values == ["cub", "expressions", "invalid_tag"], "row choices should be deduped")
50
+
51
+
52
def test_rebuild_ignores_stale_selected_state() -> None:
    """Check that a rebuild follows the per-row values, not a stale selected list.

    ``anthro`` is still present in the selected state but missing from the
    row values, simulating a user who has already deselected it in the UI.
    """
    rows = [
        {"name": "selected_other", "label": "Selected (Other)", "tags": ["solo", "female", "anthro"], "tag_meta": {}},
        {"name": "other_retrieved", "label": "Other (Retrieved)", "tags": ["cub", "expressions"], "tag_meta": {}},
    ]
    stale_selected = ["solo", "female", "anthro", "cub"]
    current_row_values = [["solo", "female"], ["cub"]]

    rebuilt = app._rebuild_rows_from_selected(
        stale_selected,
        rows,
        current_row_values,
        app.display_top_groups_default,
        app.display_top_tags_per_group_default,
        app.display_rank_top_k_default,
    )
    # Position 1 is read as the suggested prompt, position 2 as the selected tags.
    prompt_after = rebuilt[1]
    selected_after = rebuilt[2]

    _assert("anthro" not in selected_after, "rebuild should not resurrect stale deselected tags")
    _assert("anthro" not in prompt_after, "prompt should not include stale deselected tags")
    _assert(
        all(tag in prompt_after for tag in ("solo", "female", "cub")),
        "rebuild should retain current row selections",
    )
73
+
74
+
75
def test_toggle_then_rebuild_does_not_resurrect_removed_tag() -> None:
    """Check that a tag unchecked via a row toggle stays removed after a rebuild."""
    rows = [
        {"name": "selected_other", "label": "Selected (Other)", "tags": ["solo", "anthro", "female"], "tag_meta": {}},
        {"name": "other_retrieved", "label": "Other (Retrieved)", "tags": ["cub", "expressions"], "tag_meta": {}},
    ]
    initial_selected = ["solo", "anthro", "female", "cub"]
    initial_row_values = [["solo", "anthro", "female"], ["cub"]]

    # Step 1: row 0 now reports only solo/female, i.e. the user unchecked anthro.
    toggled = app._on_toggle_row(
        0,
        ["solo", "female"],
        initial_selected,
        False,
        rows,
        initial_row_values,
        app.display_max_rows_default,
    )
    # Position 0 is read as the selected tags, position 3 as the per-row values.
    selected_post_toggle = toggled[0]
    row_values_post_toggle = toggled[3]
    _assert("anthro" not in selected_post_toggle, "toggle should remove anthro from selected state")

    # Step 2: rebuilding from the toggled state must keep the user's choice.
    rebuilt = app._rebuild_rows_from_selected(
        selected_post_toggle,
        rows,
        row_values_post_toggle,
        app.display_top_groups_default,
        app.display_top_tags_per_group_default,
        app.display_rank_top_k_default,
    )
    prompt_post_rebuild = rebuilt[1]
    selected_post_rebuild = rebuilt[2]

    _assert("anthro" not in selected_post_rebuild, "rebuild should not resurrect deselected anthro")
    _assert("anthro" not in prompt_post_rebuild, "prompt should not contain deselected anthro after rebuild")
    _assert(
        all(tag in prompt_post_rebuild for tag in ("solo", "female")),
        "kept selections should remain",
    )
    _assert("cub" in prompt_post_rebuild, "other retrieved selection should remain")
112
+
113
+
114
def test_toggle_does_not_cross_activate_unrelated_row_tag() -> None:
    """Check that selecting a tag in one row leaves a row with different tags untouched."""
    rows = [
        {"name": "organization", "label": "Organization", "tags": ["pinup", "close-up"], "tag_meta": {}},
        {"name": "color_markings", "label": "Color Markings", "tags": ["shoulder_markings", "black_markings"], "tag_meta": {}},
    ]
    nothing_selected = []
    empty_row_values = [[], []]

    # Row 0 (organization) reports close-up as its only checked value.
    toggled = app._on_toggle_row(
        0,
        ["close-up"],
        nothing_selected,
        False,
        rows,
        empty_row_values,
        app.display_max_rows_default,
    )
    # Position 0 is read as the selected tags, position 3 as the per-row values.
    selected_now = toggled[0]
    row_values_now = toggled[3]

    _assert("close-up" in selected_now, "close-up should be selected")
    _assert("shoulder_markings" not in selected_now, "unrelated row tag should not be auto-selected")
    _assert(row_values_now[0] == ["close-up"], "organization row values should include close-up only")
    _assert(row_values_now[1] == [], "color markings row should remain unselected")
138
+
139
+
140
def test_shared_tag_mirrors_without_unrelated_cross_toggle() -> None:
    """Check that a tag listed in two rows mirrors between them and nowhere else.

    ``holding_face`` appears in rows 0 and 2; row 1 has unrelated tags.
    """
    rows = [
        {"name": "objects_props", "label": "Objects Props", "tags": ["holding_face", "holding_clothing"], "tag_meta": {}},
        {"name": "expression_detail", "label": "Expression Detail", "tags": ["open_mouth", "closed_smile"], "tag_meta": {}},
        {"name": "pose_action_detail", "label": "Pose Action Detail", "tags": ["holding_face", "walking"], "tag_meta": {}},
    ]
    nothing_selected = []
    empty_row_values = [[], [], []]

    # First toggle: open_mouth in the expression row; holding_face rows must not react.
    first = app._on_toggle_row(
        1,
        ["open_mouth"],
        nothing_selected,
        False,
        rows,
        empty_row_values,
        app.display_max_rows_default,
    )
    # Position 0 is read as the selected tags, position 3 as the per-row values.
    selected_first = first[0]
    values_first = first[3]
    _assert("open_mouth" in selected_first, "open_mouth should be selected")
    _assert("holding_face" not in selected_first, "holding_face must remain unselected")
    _assert(values_first[0] == [], "objects props row should remain unselected")
    _assert(values_first[1] == ["open_mouth"], "expression row should select open_mouth")
    _assert(values_first[2] == [], "pose row should remain unselected")

    # Second toggle: holding_face in the objects row, fed with the first
    # toggle's outputs and the dirty flag set to True.
    second = app._on_toggle_row(
        0,
        ["holding_face"],
        selected_first,
        True,
        rows,
        values_first,
        app.display_max_rows_default,
    )
    selected_second = second[0]
    values_second = second[3]
    _assert(
        all(tag in selected_second for tag in ("holding_face", "open_mouth")),
        "both explicitly selected tags should be present",
    )
    _assert(values_second[0] == ["holding_face"], "objects row should select holding_face")
    _assert(values_second[1] == ["open_mouth"], "expression row should keep open_mouth only")
    _assert(values_second[2] == ["holding_face"], "pose row should mirror holding_face")
183
+
184
+
185
def main() -> None:
    """Run every UI-state smoke check in order, then print a success line.

    The first failing check raises and stops the run, so the success line is
    printed only when all checks pass.
    """
    smoke_checks = (
        test_prompt_uses_visible_rows_only,
        test_row_deduping,
        test_rebuild_ignores_stale_selected_state,
        test_toggle_then_rebuild_does_not_resurrect_removed_tag,
        test_toggle_does_not_cross_activate_unrelated_row_tag,
        test_shared_tag_mirrors_without_unrelated_cross_toggle,
    )
    for run_check in smoke_checks:
        run_check()
    print("ui state smoke: ok")
193
+
194
+
195
# Run the smoke checks only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()