Food Desert commited on
Commit
a48a025
·
1 Parent(s): 334af6b

Add synchronized lego-style tag toggles and prompt builder UI

Browse files
Files changed (1) hide show
  1. app.py +299 -16
app.py CHANGED
@@ -6,7 +6,7 @@ import json
6
  from datetime import datetime
7
  from PIL import Image
8
  from pathlib import Path
9
- from typing import List
10
  from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
11
 
12
  from psq_rag.pipeline.preproc import extract_user_provided_tags_upto_3_words
@@ -14,7 +14,7 @@ from psq_rag.llm.rewrite import llm_rewrite_prompt
14
  from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases, _norm_tag_for_lookup
15
  from psq_rag.llm.select import llm_select_indices, llm_infer_structural_tags, llm_infer_probe_tags
16
  from psq_rag.retrieval.state import expand_tags_via_implications
17
- from psq_rag.ui.group_ranked_display import render_group_rankings_markdown
18
 
19
 
20
  def _split_prompt_commas(s: str) -> List[str]:
@@ -40,6 +40,181 @@ def compose_final_prompt(rewritten_prompt: str, selected_tags: List[str]) -> str
40
  return ", ".join(out)
41
 
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  def _build_selection_query(
44
  prompt_in: str,
45
  rewritten: str,
@@ -152,6 +327,7 @@ enable_probe_tags = os.environ.get("PSQ_ENABLE_PROBE", "1").strip() not in {"0",
152
  display_top_groups_default = int(os.environ.get("PSQ_DISPLAY_TOP_GROUPS", "10"))
153
  display_top_tags_per_group_default = int(os.environ.get("PSQ_DISPLAY_TOP_TAGS_PER_GROUP", "5"))
154
  display_rank_top_k_default = int(os.environ.get("PSQ_DISPLAY_GROUP_RANK_TOP_K", "5"))
 
155
  retrieval_global_k = int(os.environ.get("PSQ_RETRIEVAL_GLOBAL_K", "300"))
156
  retrieval_per_phrase_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_K", "10"))
157
  retrieval_per_phrase_final_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_FINAL_K", "1"))
@@ -189,6 +365,42 @@ css = """
189
  .pane-right .scrollable-content {
190
  max-height: 610px; /* was 420px; tweak to taste */
191
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  """
193
 
194
 
@@ -275,7 +487,12 @@ def rag_pipeline_ui(
275
  log("Start: received prompt")
276
  prompt_in = (user_prompt or "").strip()
277
  if not prompt_in:
278
- return "Error: empty prompt", "", ""
 
 
 
 
 
279
 
280
  log("Input:")
281
  log(prompt_in)
@@ -439,6 +656,8 @@ def rag_pipeline_ui(
439
  elif enable_probe_tags:
440
  log(" No probe tags inferred")
441
 
 
 
442
  log("Step 3c: Expand via tag implications")
443
  t0 = time.perf_counter()
444
  tag_set = set(selected_tags)
@@ -469,25 +688,36 @@ def rag_pipeline_ui(
469
  seed_terms.extend(selected_tags)
470
  seed_terms = list(dict.fromkeys(seed_terms))
471
 
472
- groups_md = render_group_rankings_markdown(
473
  seed_terms=seed_terms,
 
474
  top_groups=max(1, int(display_top_groups)),
475
  top_tags_per_group=max(1, int(display_top_tags_per_group)),
476
  group_rank_top_k=max(1, int(display_rank_top_k)),
477
  )
478
  dt = time.perf_counter()-t0
479
  _record_timing("group_display", dt)
480
- log(f"Ranked group display: {dt:.2f}s")
481
 
482
  total_dt = time.perf_counter()-t_total0
483
  _emit_timing_summary(total_dt)
484
  _append_timing_jsonl(total_dt)
485
  log("Done: final prompt ready")
486
- return "\n".join(logs), final_prompt, groups_md
 
 
 
 
 
487
 
488
  except Exception as e:
489
  log(f"Error: {type(e).__name__}: {e}")
490
- return "\n".join(logs), "", ""
 
 
 
 
 
491
 
492
 
493
 
@@ -529,11 +759,44 @@ then returns a cleaned, model-friendly prompt.
529
  placeholder="Progress logs will appear here."
530
  )
531
 
532
- final_prompt = gr.Textbox(
533
- label="Final Prompt",
534
  lines=3,
535
  interactive=False,
536
- placeholder="Your optimized prompt will appear here."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
537
  )
538
 
539
  with gr.Accordion("Display Settings", open=False):
@@ -557,22 +820,42 @@ then returns a cleaned, model-friendly prompt.
557
  minimum=1,
558
  )
559
 
560
- group_rankings_md = gr.Markdown(
561
- label="Ranked Group/Category Tag Suggestions",
562
- value="",
563
- )
 
 
 
 
 
 
564
 
565
  submit_button.click(
566
  rag_pipeline_ui,
567
  inputs=[image_tags, display_top_groups, display_top_tags_per_group, display_rank_top_k],
568
- outputs=[console, final_prompt, group_rankings_md]
569
  )
570
 
571
  image_tags.submit(
572
  rag_pipeline_ui,
573
  inputs=[image_tags, display_top_groups, display_top_tags_per_group, display_rank_top_k],
574
- outputs=[console, final_prompt, group_rankings_md]
575
  )
576
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
577
  if __name__ == "__main__":
578
  app.queue().launch(allowed_paths=[str(MASCOT_DIR)])
 
6
  from datetime import datetime
7
  from PIL import Image
8
  from pathlib import Path
9
+ from typing import Any, Dict, List, Set
10
  from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
11
 
12
  from psq_rag.pipeline.preproc import extract_user_provided_tags_upto_3_words
 
14
  from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases, _norm_tag_for_lookup
15
  from psq_rag.llm.select import llm_select_indices, llm_infer_structural_tags, llm_infer_probe_tags
16
  from psq_rag.retrieval.state import expand_tags_via_implications
17
+ from psq_rag.ui.group_ranked_display import rank_groups_from_tfidf, _load_enabled_groups
18
 
19
 
20
  def _split_prompt_commas(s: str) -> List[str]:
 
40
  return ", ".join(out)
41
 
42
 
43
+ def _display_tag_text(tag: str) -> str:
44
+ return tag.replace("_", " ")
45
+
46
+
47
+ def _escape_prompt_tag(tag: str) -> str:
48
+ return (
49
+ tag.replace("_", " ")
50
+ .replace("(", "\\(")
51
+ .replace(")", "\\)")
52
+ )
53
+
54
+
55
+ def _ordered_selected_for_prompt(selected: Set[str], row_defs: List[Dict[str, Any]]) -> List[str]:
56
+ out: List[str] = []
57
+ seen: Set[str] = set()
58
+ for row in row_defs:
59
+ for tag in row.get("tags", []):
60
+ if tag in selected and tag not in seen:
61
+ out.append(tag)
62
+ seen.add(tag)
63
+ # Fallback for any selected tags not present in current rows.
64
+ for tag in sorted(selected):
65
+ if tag not in seen:
66
+ out.append(tag)
67
+ seen.add(tag)
68
+ return out
69
+
70
+
71
+ def _compose_toggle_prompt_text(selected_tags: List[str], row_defs: List[Dict[str, Any]]) -> str:
72
+ selected = {t for t in (selected_tags or []) if t}
73
+ ordered = _ordered_selected_for_prompt(selected, row_defs or [])
74
+ return ", ".join(_escape_prompt_tag(t) for t in ordered)
75
+
76
+
77
+ def _build_toggle_rows(
78
+ *,
79
+ seed_terms: List[str],
80
+ llm_selected_tags: List[str],
81
+ top_groups: int,
82
+ top_tags_per_group: int,
83
+ group_rank_top_k: int,
84
+ ) -> List[Dict[str, Any]]:
85
+ ranked_rows = rank_groups_from_tfidf(
86
+ seed_terms=seed_terms,
87
+ top_groups=max(1, int(top_groups)),
88
+ top_tags_per_group=max(1, int(top_tags_per_group)),
89
+ group_rank_top_k=max(1, int(group_rank_top_k)),
90
+ )
91
+ groups_map = _load_enabled_groups()
92
+ llm_selected = list(dict.fromkeys(_norm_tag_for_lookup(t) for t in llm_selected_tags if t))
93
+
94
+ row_defs: List[Dict[str, Any]] = []
95
+ displayed_group_names = [r.group_name for r in ranked_rows]
96
+ displayed_group_tag_sets: Dict[str, Set[str]] = {
97
+ name: set(groups_map.get(name, [])) for name in displayed_group_names
98
+ }
99
+ tags_in_any_displayed_group: Set[str] = set()
100
+ for tag_set in displayed_group_tag_sets.values():
101
+ tags_in_any_displayed_group.update(tag_set)
102
+
103
+ llm_other = [t for t in llm_selected if t not in tags_in_any_displayed_group]
104
+ row_defs.append(
105
+ {
106
+ "name": "llm_selected_other",
107
+ "label": "LLM Selected (Other)",
108
+ "tags": llm_other,
109
+ }
110
+ )
111
+
112
+ for row in ranked_rows:
113
+ group_name = row.group_name
114
+ group_tag_set = displayed_group_tag_sets.get(group_name, set())
115
+ selected_in_group = [t for t in llm_selected if t in group_tag_set]
116
+ ranked_tags = [t for t, _ in row.tags]
117
+ merged = selected_in_group + [t for t in ranked_tags if t not in selected_in_group]
118
+ keep_n = max(max(1, int(top_tags_per_group)), len(selected_in_group))
119
+ merged = merged[:keep_n]
120
+ row_defs.append(
121
+ {
122
+ "name": group_name,
123
+ "label": f"{group_name} (E={row.expected_count:.2f})",
124
+ "tags": merged,
125
+ }
126
+ )
127
+
128
+ return row_defs
129
+
130
+
131
+ def _build_row_component_updates(
132
+ row_defs: List[Dict[str, Any]],
133
+ selected_tags: List[str],
134
+ max_rows: int,
135
+ ):
136
+ selected = {t for t in (selected_tags or []) if t}
137
+ row_values_state: List[List[str]] = []
138
+ header_updates = []
139
+ checkbox_updates = []
140
+
141
+ for idx in range(max_rows):
142
+ if idx < len(row_defs):
143
+ row = row_defs[idx]
144
+ tags = list(dict.fromkeys(row.get("tags", [])))
145
+ values = [t for t in tags if t in selected]
146
+ row_values_state.append(values)
147
+ visible = bool(tags)
148
+ header_updates.append(gr.update(value=f"**{row.get('label', '')}**", visible=visible))
149
+ choices = [(_display_tag_text(t), t) for t in tags]
150
+ checkbox_updates.append(
151
+ gr.update(
152
+ choices=choices,
153
+ value=values,
154
+ visible=visible,
155
+ )
156
+ )
157
+ else:
158
+ header_updates.append(gr.update(value="", visible=False))
159
+ checkbox_updates.append(gr.update(choices=[], value=[], visible=False))
160
+
161
+ prompt_text = _compose_toggle_prompt_text(list(selected), row_defs)
162
+ return prompt_text, row_values_state, header_updates, checkbox_updates
163
+
164
+
165
+ def _on_toggle_row(
166
+ row_idx: int,
167
+ changed_values: List[str],
168
+ selected_tags_state: List[str],
169
+ row_defs_state: List[Dict[str, Any]],
170
+ row_values_state: List[List[str]],
171
+ max_rows: int,
172
+ ):
173
+ row_defs = row_defs_state or []
174
+ selected = set(selected_tags_state or [])
175
+ prev_values = list(row_values_state or [])
176
+
177
+ while len(prev_values) < len(row_defs):
178
+ prev_values.append([])
179
+
180
+ prev_set = set(prev_values[row_idx]) if row_idx < len(prev_values) else set()
181
+ new_set = set(changed_values or [])
182
+ selected.update(new_set - prev_set)
183
+ selected.difference_update(prev_set - new_set)
184
+
185
+ prompt_text, new_row_values_state, _header_updates, checkbox_updates = _build_row_component_updates(
186
+ row_defs=row_defs,
187
+ selected_tags=list(selected),
188
+ max_rows=max_rows,
189
+ )
190
+
191
+ return [sorted(selected), new_row_values_state, prompt_text, *checkbox_updates]
192
+
193
+
194
+ def _build_ui_payload(
195
+ *,
196
+ console_text: str,
197
+ legacy_prompt_text: str,
198
+ row_defs: List[Dict[str, Any]],
199
+ selected_tags: List[str],
200
+ ):
201
+ prompt_text, row_values_state, header_updates, checkbox_updates = _build_row_component_updates(
202
+ row_defs=row_defs,
203
+ selected_tags=selected_tags,
204
+ max_rows=display_max_rows_default,
205
+ )
206
+ return [
207
+ console_text,
208
+ legacy_prompt_text,
209
+ prompt_text,
210
+ sorted(set(selected_tags or [])),
211
+ row_defs,
212
+ row_values_state,
213
+ *header_updates,
214
+ *checkbox_updates,
215
+ ]
216
+
217
+
218
  def _build_selection_query(
219
  prompt_in: str,
220
  rewritten: str,
 
327
  display_top_groups_default = int(os.environ.get("PSQ_DISPLAY_TOP_GROUPS", "10"))
328
  display_top_tags_per_group_default = int(os.environ.get("PSQ_DISPLAY_TOP_TAGS_PER_GROUP", "5"))
329
  display_rank_top_k_default = int(os.environ.get("PSQ_DISPLAY_GROUP_RANK_TOP_K", "5"))
330
+ display_max_rows_default = int(os.environ.get("PSQ_DISPLAY_MAX_ROWS", "14"))
331
  retrieval_global_k = int(os.environ.get("PSQ_RETRIEVAL_GLOBAL_K", "300"))
332
  retrieval_per_phrase_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_K", "10"))
333
  retrieval_per_phrase_final_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_FINAL_K", "1"))
 
365
  .pane-right .scrollable-content {
366
  max-height: 610px; /* was 420px; tweak to taste */
367
  }
368
+
369
+ .lego-tags .gr-checkboxgroup {
370
+ display: flex;
371
+ flex-wrap: wrap;
372
+ gap: 8px;
373
+ }
374
+
375
+ .lego-tags .gr-checkboxgroup label {
376
+ margin: 0;
377
+ padding: 0;
378
+ }
379
+
380
+ .lego-tags .gr-checkboxgroup input[type="checkbox"] {
381
+ display: none;
382
+ }
383
+
384
+ .lego-tags .gr-checkboxgroup span {
385
+ display: inline-block;
386
+ padding: 7px 12px;
387
+ border: 1px solid #8a8a8a;
388
+ border-radius: 10px;
389
+ background: #f4f4f4;
390
+ color: #222;
391
+ font-size: 0.95rem;
392
+ line-height: 1.2;
393
+ cursor: pointer;
394
+ user-select: none;
395
+ box-shadow: 0 1px 0 rgba(0,0,0,0.12), inset 0 1px 0 rgba(255,255,255,0.7);
396
+ }
397
+
398
+ .lego-tags .gr-checkboxgroup input[type="checkbox"]:checked + span {
399
+ background: #ffd86a;
400
+ border-color: #c49a00;
401
+ box-shadow: 0 2px 0 #a98000, inset 0 1px 0 rgba(255,255,255,0.65);
402
+ transform: translateY(1px);
403
+ }
404
  """
405
 
406
 
 
487
  log("Start: received prompt")
488
  prompt_in = (user_prompt or "").strip()
489
  if not prompt_in:
490
+ return _build_ui_payload(
491
+ console_text="Error: empty prompt",
492
+ legacy_prompt_text="",
493
+ row_defs=[],
494
+ selected_tags=[],
495
+ )
496
 
497
  log("Input:")
498
  log(prompt_in)
 
656
  elif enable_probe_tags:
657
  log(" No probe tags inferred")
658
 
659
+ llm_selected_tags = list(dict.fromkeys(selected_tags))
660
+
661
  log("Step 3c: Expand via tag implications")
662
  t0 = time.perf_counter()
663
  tag_set = set(selected_tags)
 
688
  seed_terms.extend(selected_tags)
689
  seed_terms = list(dict.fromkeys(seed_terms))
690
 
691
+ toggle_rows = _build_toggle_rows(
692
  seed_terms=seed_terms,
693
+ llm_selected_tags=llm_selected_tags,
694
  top_groups=max(1, int(display_top_groups)),
695
  top_tags_per_group=max(1, int(display_top_tags_per_group)),
696
  group_rank_top_k=max(1, int(display_rank_top_k)),
697
  )
698
  dt = time.perf_counter()-t0
699
  _record_timing("group_display", dt)
700
+ log(f"Ranked group display: {dt:.2f}s ({len(toggle_rows)} rows)")
701
 
702
  total_dt = time.perf_counter()-t_total0
703
  _emit_timing_summary(total_dt)
704
  _append_timing_jsonl(total_dt)
705
  log("Done: final prompt ready")
706
+ return _build_ui_payload(
707
+ console_text="\n".join(logs),
708
+ legacy_prompt_text=final_prompt,
709
+ row_defs=toggle_rows,
710
+ selected_tags=llm_selected_tags,
711
+ )
712
 
713
  except Exception as e:
714
  log(f"Error: {type(e).__name__}: {e}")
715
+ return _build_ui_payload(
716
+ console_text="\n".join(logs),
717
+ legacy_prompt_text="",
718
+ row_defs=[],
719
+ selected_tags=[],
720
+ )
721
 
722
 
723
 
 
759
  placeholder="Progress logs will appear here."
760
  )
761
 
762
+ suggested_prompt = gr.Textbox(
763
+ label="Suggested Prompt (From Toggled Tags)",
764
  lines=3,
765
  interactive=False,
766
+ show_copy_button=True,
767
+ placeholder="Comma-separated tags selected in the rows below."
768
+ )
769
+
770
+ with gr.Accordion("Legacy Pipeline Prompt (for reference)", open=False):
771
+ legacy_final_prompt = gr.Textbox(
772
+ label="Legacy Final Prompt",
773
+ lines=3,
774
+ interactive=False,
775
+ show_copy_button=True,
776
+ )
777
+
778
+ selected_tags_state = gr.State([])
779
+ row_defs_state = gr.State([])
780
+ row_values_state = gr.State([])
781
+
782
+ gr.Markdown("### Toggle Tag Rows")
783
+ row_headers: List[gr.Markdown] = []
784
+ row_checkboxes: List[gr.CheckboxGroup] = []
785
+ for _ in range(display_max_rows_default):
786
+ row_headers.append(gr.Markdown(value="", visible=False))
787
+ row_checkboxes.append(
788
+ gr.CheckboxGroup(
789
+ choices=[],
790
+ value=[],
791
+ visible=False,
792
+ interactive=True,
793
+ container=False,
794
+ elem_classes=["lego-tags"],
795
+ )
796
+ )
797
+
798
+ gr.Markdown(
799
+ "Toggling a tag in any row toggles it everywhere else that tag appears."
800
  )
801
 
802
  with gr.Accordion("Display Settings", open=False):
 
820
  minimum=1,
821
  )
822
 
823
+ run_outputs = [
824
+ console,
825
+ legacy_final_prompt,
826
+ suggested_prompt,
827
+ selected_tags_state,
828
+ row_defs_state,
829
+ row_values_state,
830
+ *row_headers,
831
+ *row_checkboxes,
832
+ ]
833
 
834
  submit_button.click(
835
  rag_pipeline_ui,
836
  inputs=[image_tags, display_top_groups, display_top_tags_per_group, display_rank_top_k],
837
+ outputs=run_outputs
838
  )
839
 
840
  image_tags.submit(
841
  rag_pipeline_ui,
842
  inputs=[image_tags, display_top_groups, display_top_tags_per_group, display_rank_top_k],
843
+ outputs=run_outputs
844
  )
845
 
846
+ for idx, row_cb in enumerate(row_checkboxes):
847
+ row_cb.change(
848
+ fn=lambda changed_values, selected_state, row_defs, row_values, i=idx: _on_toggle_row(
849
+ i,
850
+ changed_values,
851
+ selected_state,
852
+ row_defs,
853
+ row_values,
854
+ display_max_rows_default,
855
+ ),
856
+ inputs=[row_cb, selected_tags_state, row_defs_state, row_values_state],
857
+ outputs=[selected_tags_state, row_values_state, suggested_prompt, *row_checkboxes],
858
+ )
859
+
860
  if __name__ == "__main__":
861
  app.queue().launch(allowed_paths=[str(MASCOT_DIR)])