Food Desert
Refine UI guidance, tag toggle cues, and hover tooltip metadata
3aa1163
Raw
History Blame
61.5 kB
import gradio as gr
import os
import logging
import time
import json
import csv
import base64
from datetime import datetime
from functools import lru_cache
from PIL import Image
from pathlib import Path
from typing import Any, Dict, List, Set, Tuple
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
from psq_rag.pipeline.preproc import extract_user_provided_tags_upto_3_words
from psq_rag.llm.rewrite import llm_rewrite_prompt
from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases, _norm_tag_for_lookup
from psq_rag.llm.select import llm_select_indices, llm_infer_structural_tags, llm_infer_probe_tags
from psq_rag.retrieval.state import (
expand_tags_via_implications,
get_tag_type_name,
get_tag_implications,
get_tag_counts,
)
from psq_rag.ui.group_ranked_display import rank_groups_from_tfidf, _load_enabled_groups
def _split_prompt_commas(s: str) -> List[str]:
return [p.strip() for p in (s or "").split(",") if p.strip()]
def _norm_for_dedupe(tag: str) -> str:
    """Return the key used to detect duplicate tags.

    The tag is lower-cased and then passed through ``_norm_tag_for_lookup``.
    """
    lowered = tag.lower()
    return _norm_tag_for_lookup(lowered)
def compose_final_prompt(rewritten_prompt: str, selected_tags: List[str]) -> str:
    """Join the rewritten prompt phrases and the selected tags into one prompt.

    Phrases from ``rewritten_prompt`` come first, followed by ``selected_tags``.
    Entries whose dedupe key (``_norm_for_dedupe``) was already seen are
    dropped, so the first spelling of each entry wins.

    Returns:
        The surviving entries joined with ``", "``.
    """
    # dict preserves insertion order; setdefault keeps the first spelling per key.
    first_by_key: Dict[str, str] = {}
    for piece in [*_split_prompt_commas(rewritten_prompt), *selected_tags]:
        first_by_key.setdefault(_norm_for_dedupe(piece), piece)
    return ", ".join(first_by_key.values())
def _display_tag_text(tag: str) -> str:
return tag.replace("_", " ")
def _display_row_label(name: str) -> str:
n = (name or "").strip()
if not n:
return ""
if n == "selected_other":
return "Selected (Other)"
return n.replace("_", " ").title()
def _normalize_selection_origin(origin: str) -> str:
o = (origin or "").strip().lower()
if o in {"rewrite", "selection", "probe", "structural", "user", "candidate"}:
return o
return "selection"
@lru_cache(maxsize=1)
def _load_tag_wiki_defs() -> Dict[str, str]:
p = Path("data/tag_wiki_defs.json")
if not p.exists():
return {}
try:
with p.open("r", encoding="utf-8") as f:
data = json.load(f)
out: Dict[str, str] = {}
if isinstance(data, dict):
for k, v in data.items():
tag = _norm_tag_for_lookup(str(k))
text = " ".join(str(v or "").split())
if tag and text:
out[tag] = text
return out
except Exception:
return {}
def _tooltip_text_for_tag(tag: str) -> str:
    """Build the hover tooltip text for a tag.

    The tooltip has up to two lines: ``Count: N`` (with thousands separators)
    when ``get_tag_counts()`` yields an int for the tag, and the tag's wiki
    definition when one is loaded. Returns ``""`` when neither is available.
    """
    key = _norm_tag_for_lookup(tag)
    try:
        usage = get_tag_counts().get(key)
    except Exception:
        # Counts are optional decoration; a lookup failure just omits the line.
        usage = None
    count_lines = [f"Count: {usage:,}"] if isinstance(usage, int) else []
    definition = _load_tag_wiki_defs().get(key, "")
    definition_lines = [definition] if definition else []
    return "\n".join(count_lines + definition_lines).strip()
def _choice_label_with_source_meta(tag: str, *, origin: str, preselected: bool) -> str:
    """Return the checkbox label for ``tag`` with a trailing metadata marker.

    The label is ``"<display text> [[psq:<origin>:<0|1>:<tooltip>]]"`` where
    the tooltip is URL-safe base64 of the UTF-8 tooltip text (empty when there
    is no tooltip).
    """
    # Marker is stripped client-side and converted into data attributes for CSS-driven colors/tooltips.
    tooltip = _tooltip_text_for_tag(tag)
    if tooltip:
        encoded_tip = base64.urlsafe_b64encode(tooltip.encode("utf-8")).decode("ascii")
    else:
        encoded_tip = ""
    flag = "1" if preselected else "0"
    source = _normalize_selection_origin(origin)
    return f"{_display_tag_text(tag)} [[psq:{source}:{flag}:{encoded_tip}]]"
def _selection_source_rank(origin: str) -> int:
    """Return the row-ordering priority of a selection origin (lower sorts first).

    ``structural`` -> 0, ``probe`` -> 1, everything else -> 2.
    """
    # Keep rewrite/user in the same priority band as general selection for row ordering.
    priority = {"structural": 0, "probe": 1}
    return priority.get(_normalize_selection_origin(origin), 2)
def _build_implied_parent_map(
    direct_tags_ordered: List[str],
    implied_tags: List[str],
) -> Dict[str, str]:
    """Attribute each implied tag to the direct tag that leads to it.

    For every direct tag, in the given order, the implication graph from
    ``get_tag_implications()`` is walked transitively. The first direct tag
    that reaches an implied tag claims it.

    Returns:
        Mapping of normalized implied tag -> normalized direct tag. Empty when
        either input is empty.
    """
    wanted = {_norm_tag_for_lookup(t) for t in (implied_tags or []) if t}
    if not (wanted and direct_tags_ordered):
        return {}

    implications = get_tag_implications()
    owner_by_implied: Dict[str, str] = {}
    for direct in direct_tags_ordered:
        root = _norm_tag_for_lookup(direct)
        if not root:
            continue
        # Depth-first walk; `visited` is per root so every root explores fully.
        pending = [root]
        visited = {root}
        while pending:
            current = pending.pop()
            for raw_target in implications.get(current, ()):
                target = _norm_tag_for_lookup(raw_target)
                if not target or target in visited:
                    continue
                visited.add(target)
                if target in wanted:
                    # setdefault: an earlier direct tag keeps its claim.
                    owner_by_implied.setdefault(target, root)
                pending.append(target)
    return owner_by_implied
def _order_selected_tags_for_row(
    *,
    row_selected_tags: List[str],
    selected_index: Dict[str, int],
    tag_selection_origins: Dict[str, str],
    implied_parent_map: Dict[str, str],
) -> List[str]:
    """Order a row's selected tags so implied tags follow the tag that implies them.

    Non-implied tags are sorted by (origin rank, original selection position,
    tag). Directly after each one come its implied children that are present
    in this row, sorted by (selection position, tag). Implied tags whose
    parent is not in the row are appended last, ordered by their parent's
    rank and position.

    Returns:
        The normalized tags, each appearing once.
    """
    unranked = 10**9  # sorts tags without a recorded position after all others
    normalized = [_norm_tag_for_lookup(t) for t in (row_selected_tags or []) if t]
    implied_here = {t for t in normalized if t in implied_parent_map}

    def _direct_key(tag: str):
        return (
            _selection_source_rank(tag_selection_origins.get(tag, "selection")),
            selected_index.get(tag, unranked),
            tag,
        )

    def _child_key(tag: str):
        return (selected_index.get(tag, unranked), tag)

    def _orphan_key(tag: str):
        parent = implied_parent_map.get(tag, "")
        return (
            _selection_source_rank(tag_selection_origins.get(parent, "selection")),
            selected_index.get(parent, unranked),
            selected_index.get(tag, unranked),
            tag,
        )

    direct_sorted = sorted((t for t in normalized if t not in implied_here), key=_direct_key)

    children_of: Dict[str, List[str]] = {}
    for implied in implied_here:
        parent = implied_parent_map.get(implied)
        if parent:
            children_of.setdefault(parent, []).append(implied)
    for siblings in children_of.values():
        siblings.sort(key=_child_key)

    ordered: List[str] = []
    emitted: Set[str] = set()

    def _emit(tag: str) -> None:
        if tag not in emitted:
            emitted.add(tag)
            ordered.append(tag)

    for tag in direct_sorted:
        if tag in emitted:
            continue
        _emit(tag)
        for child in children_of.get(tag, []):
            _emit(child)

    # Whatever is left is implied by a tag that is not shown in this row.
    leftovers = sorted((t for t in normalized if t not in emitted), key=_orphan_key)
    for tag in leftovers:
        _emit(tag)
    return ordered
def _escape_prompt_tag(tag: str) -> str:
return (
tag.replace("_", " ")
.replace("(", "\\(")
.replace(")", "\\)")
)
def _ordered_selected_for_prompt(selected: Set[str], row_defs: List[Dict[str, Any]]) -> List[str]:
out: List[str] = []
seen: Set[str] = set()
for row in row_defs:
for tag in row.get("tags", []):
if tag in selected and tag not in seen:
out.append(tag)
seen.add(tag)
# Fallback for any selected tags not present in current rows.
for tag in sorted(selected):
if tag not in seen:
out.append(tag)
seen.add(tag)
return out
def _compose_toggle_prompt_text(selected_tags: List[str], row_defs: List[Dict[str, Any]]) -> str:
    """Render the currently selected tags as the comma-separated prompt text.

    Falsy entries are ignored, tags are ordered by ``_ordered_selected_for_prompt``
    and each one is formatted with ``_escape_prompt_tag``.
    """
    chosen = set(filter(None, selected_tags or []))
    in_display_order = _ordered_selected_for_prompt(chosen, row_defs or [])
    escaped = [_escape_prompt_tag(tag) for tag in in_display_order]
    return ", ".join(escaped)
def _is_artist_tag(tag: str) -> bool:
    """Return True when ``tag`` is an artist tag.

    A tag counts as an artist tag when ``get_tag_type_name`` reports
    ``"artist"`` or when its normalized form starts with ``by_``. A tag that
    normalizes to an empty string is never an artist tag.
    """
    normalized = _norm_tag_for_lookup(str(tag))
    if normalized and get_tag_type_name(normalized) == "artist":
        return True
    # Keep a resilient fallback for malformed/missing tag typing metadata.
    return bool(normalized) and normalized.startswith("by_")
@lru_cache(maxsize=1)
def _load_excluded_recommendation_tags() -> Set[str]:
csv_path = Path("data/analysis/category_registry.csv")
out: Set[str] = set()
if not csv_path.exists():
return out
try:
with csv_path.open("r", encoding="utf-8", newline="") as f:
reader = csv.DictReader(f)
for row in reader:
tag = _norm_tag_for_lookup(str(row.get("tag") or ""))
if not tag:
continue
status = str(row.get("category_status") or "").strip().lower()
if status == "excluded":
out.add(tag)
except Exception:
return set()
return out
def _is_excluded_recommendation_tag(tag: str) -> bool:
    """Return True when the normalized ``tag`` is in the excluded-tag registry.

    A tag that normalizes to an empty string is never excluded.
    """
    normalized = _norm_tag_for_lookup(str(tag))
    return bool(normalized) and normalized in _load_excluded_recommendation_tags()
def _filter_excluded_recommendation_tags(tags: List[str]) -> Tuple[List[str], List[str]]:
    """Split ``tags`` into those to keep and those removed by the exclusion registry.

    Every entry is normalized with ``_norm_tag_for_lookup``. Falsy entries and
    entries that normalize to an empty string are skipped.

    Returns:
        ``(keep, removed)``: ``keep`` holds the normalized, de-duplicated
        non-excluded tags in first-seen order; ``removed`` holds the excluded
        tags that were present, sorted and de-duplicated.
    """
    excluded = _load_excluded_recommendation_tags()
    keep: List[str] = []
    removed: Set[str] = set()
    seen: Set[str] = set()
    # One code path whether or not the registry is empty, so both cases apply
    # the same guards (skip falsy input, coerce to str, skip empty normalized).
    for raw in tags or []:
        if not raw:
            continue
        t = _norm_tag_for_lookup(str(raw))
        if not t:
            continue
        if t in excluded:
            removed.add(t)
        elif t not in seen:
            seen.add(t)
            keep.append(t)
    return keep, sorted(removed)
def _build_toggle_rows(
    *,
    seed_terms: List[str],
    selected_tags: List[str],
    retrieved_candidate_tags: List[str],
    tag_selection_origins: Dict[str, str],
    implied_parent_map: Dict[str, str],
    top_groups: int,
    top_tags_per_group: int,
    group_rank_top_k: int,
) -> List[Dict[str, Any]]:
    """Build the row definitions behind the toggleable tag rows.

    Rows are returned in this order: ``selected_other`` (always present, may
    be empty), one row per group returned by ``rank_groups_from_tfidf``, and
    finally ``other_retrieved`` when at least one retrieved tag belongs to no
    enabled group. Artist tags and registry-excluded tags are never offered.

    Within a row the already-selected tags come first (ordered by
    ``_order_selected_tags_for_row``), followed by ranked suggestions. A row
    is cut to ``top_tags_per_group`` entries, but never below the number of
    selected tags it holds.

    Returns:
        A list of dicts with keys ``name``, ``label``, ``tags`` and
        ``tag_meta`` (per tag: normalized ``origin`` and ``preselected``).
    """
    ranked_rows = rank_groups_from_tfidf(
        seed_terms=seed_terms,
        top_groups=max(1, int(top_groups)),
        top_tags_per_group=max(1, int(top_tags_per_group)),
        group_rank_top_k=max(1, int(group_rank_top_k)),
    )
    groups_map = _load_enabled_groups()
    per_row_cap = max(1, int(top_tags_per_group))

    selected_active = list(
        dict.fromkeys(
            _norm_tag_for_lookup(t)
            for t in selected_tags
            if t and not _is_artist_tag(t) and not _is_excluded_recommendation_tag(t)
        )
    )
    # Position of each active selection; doubles as the O(1) "is selected" test.
    selected_index: Dict[str, int] = {t: i for i, t in enumerate(selected_active)}

    def _order_selected(raw_selected: List[str]) -> List[str]:
        # Shared ordering of a row's selected tags (implied tags follow their parent).
        return _order_selected_tags_for_row(
            row_selected_tags=raw_selected,
            selected_index=selected_index,
            tag_selection_origins=tag_selection_origins,
            implied_parent_map=implied_parent_map,
        )

    def _merge_and_trim(selected_first: List[str], ranked: List[str]) -> List[str]:
        # Selected tags first, then new suggestions; never trim away a selected tag.
        already = set(selected_first)
        merged = selected_first + [t for t in ranked if t not in already]
        return merged[: max(per_row_cap, len(selected_first))]

    def _meta_for(tags: List[str], *, always_preselected: bool = False) -> Dict[str, Dict[str, Any]]:
        # Per-tag display metadata consumed when the checkbox labels are built.
        return {
            t: {
                "origin": _normalize_selection_origin(tag_selection_origins.get(t, "selection")),
                "preselected": True if always_preselected else t in selected_index,
            }
            for t in tags
        }

    enabled_group_tag_sets: Dict[str, Set[str]] = {
        name: {t for t in tags if not _is_artist_tag(t)}
        for name, tags in groups_map.items()
    }
    tags_in_any_enabled_group: Set[str] = set().union(*enabled_group_tag_sets.values())

    displayed_group_tag_sets: Dict[str, Set[str]] = {
        r.group_name: enabled_group_tag_sets.get(r.group_name, set())
        for r in ranked_rows
    }
    tags_in_any_displayed_group: Set[str] = set().union(*displayed_group_tag_sets.values())

    # Retrieved candidates that no enabled group claims, in retrieval order.
    retrieved_uncategorized_ranked = list(
        dict.fromkeys(
            _norm_tag_for_lookup(t)
            for t in (retrieved_candidate_tags or [])
            if t
            and not _is_artist_tag(t)
            and not _is_excluded_recommendation_tag(t)
            and _norm_tag_for_lookup(t) not in tags_in_any_enabled_group
        )
    )

    retrieved_other_row: Dict[str, Any] | None = None
    if retrieved_uncategorized_ranked:
        retrieved_uncategorized_set = set(retrieved_uncategorized_ranked)
        selected_in_retrieved_other = _order_selected(
            [t for t in selected_active if t in retrieved_uncategorized_set]
        )
        merged_retrieved_other = _merge_and_trim(
            selected_in_retrieved_other, retrieved_uncategorized_ranked
        )
        retrieved_other_row = {
            "name": "other_retrieved",
            "label": "Other (Retrieved)",
            "tags": merged_retrieved_other,
            "tag_meta": _meta_for(merged_retrieved_other),
        }

    # "Selected (Other)" should contain selected tags not already shown in any displayed row.
    # Include "Other (Retrieved)" in that displayed-row set to avoid duplicates across those rows.
    tags_in_displayed_rows = set(tags_in_any_displayed_group)
    if retrieved_other_row:
        tags_in_displayed_rows.update(retrieved_other_row.get("tags", []))
    selected_other = _order_selected(
        [t for t in selected_active if t not in tags_in_displayed_rows]
    )

    row_defs: List[Dict[str, Any]] = [
        {
            "name": "selected_other",
            "label": _display_row_label("selected_other"),
            "tags": selected_other,
            "tag_meta": _meta_for(selected_other, always_preselected=True),
        }
    ]

    for row in ranked_rows:
        group_name = row.group_name
        group_tag_set = displayed_group_tag_sets.get(group_name, set())
        selected_in_group = _order_selected(
            [t for t in selected_active if t in group_tag_set]
        )
        ranked_tags = [
            t
            for t, _ in row.tags
            if not _is_artist_tag(t) and not _is_excluded_recommendation_tag(t)
        ]
        merged = _merge_and_trim(selected_in_group, ranked_tags)
        row_defs.append(
            {
                "name": group_name,
                "label": _display_row_label(group_name),
                "tags": merged,
                "tag_meta": _meta_for(merged),
            }
        )

    # Keep this row at the bottom so category/group rows remain contiguous.
    if retrieved_other_row:
        row_defs.append(retrieved_other_row)
    return row_defs
def _build_display_audit_line(
    row_defs: List[Dict[str, Any]],
    *,
    active_selected_tags: List[str],
    direct_selected_tags: List[str],
    implied_selected_tags: List[str],
) -> str:
    """Summarize, per displayed tag, which rows show it and where it came from.

    For every tag in ``row_defs`` the audit records the labels of the rows
    containing it plus source markers: the row kind (``selected_other_row``,
    ``other_retrieved_row`` or ``ranked_group_row``) and membership in the
    active/direct/implied selections (artist tags are ignored in those).

    Returns:
        ``"Display Tag Audit: "`` followed by an ASCII-only JSON payload with
        ``n_tags`` and a tag-sorted ``tags`` list.
    """

    def _normalized_non_artist(tags: List[str]) -> Set[str]:
        return {
            _norm_tag_for_lookup(t)
            for t in (tags or [])
            if t and not _is_artist_tag(t)
        }

    selection_markers = (
        (_normalized_non_artist(active_selected_tags), "selected_active"),
        (_normalized_non_artist(direct_selected_tags), "selected_direct"),
        (_normalized_non_artist(implied_selected_tags), "selected_implied"),
    )
    special_row_markers = {
        "selected_other": "selected_other_row",
        "other_retrieved": "other_retrieved_row",
    }

    info_by_tag: Dict[str, Dict[str, Any]] = {}
    for row in row_defs or []:
        row_name = row.get("name", "")
        row_label = row.get("label", row_name)
        row_marker = special_row_markers.get(row_name, "ranked_group_row")
        for tag in row.get("tags", []):
            rec = info_by_tag.setdefault(tag, {"rows": [], "sources": set()})
            rec["rows"].append(row_label)
            rec["sources"].add(row_marker)
            rec["sources"].update(
                marker for members, marker in selection_markers if tag in members
            )

    entries = []
    for tag in sorted(info_by_tag):
        rec = info_by_tag[tag]
        entries.append({"tag": tag, "rows": rec["rows"], "sources": sorted(rec["sources"])})
    payload = {"n_tags": len(info_by_tag), "tags": entries}
    return "Display Tag Audit: " + json.dumps(payload, ensure_ascii=True)
def _build_row_component_updates(
    row_defs: List[Dict[str, Any]],
    selected_tags: List[str],
    max_rows: int,
):
    """Produce the Gradio updates that render ``row_defs`` into the fixed row slots.

    Exactly ``max_rows`` header updates and ``max_rows`` checkbox updates are
    produced; slots beyond ``len(row_defs)`` are cleared and hidden, and a row
    without tags is hidden too.

    Returns:
        ``(prompt_text, row_values_state, header_updates, checkbox_updates)``
        where ``row_values_state`` holds the selected tags of each rendered row.
    """
    chosen = {t for t in (selected_tags or []) if t}
    row_values_state: List[List[str]] = []
    header_updates = []
    checkbox_updates = []

    for slot in range(max_rows):
        if slot >= len(row_defs):
            # Unused slot: blank it out.
            header_updates.append(gr.update(value="", visible=False))
            checkbox_updates.append(gr.update(choices=[], value=[], visible=False))
            continue

        row = row_defs[slot]
        tags = list(dict.fromkeys(row.get("tags", [])))
        values = [t for t in tags if t in chosen]
        row_values_state.append(values)
        visible = bool(tags)
        header_updates.append(gr.update(value=row.get("label", ""), visible=visible))

        raw_meta = row.get("tag_meta", {})
        tag_meta = raw_meta if isinstance(raw_meta, dict) else {}
        choices = []
        for t in tags:
            raw_entry = tag_meta.get(t, {})
            meta = raw_entry if isinstance(raw_entry, dict) else {}
            label = _choice_label_with_source_meta(
                t,
                origin=_normalize_selection_origin(str(meta.get("origin", "selection"))),
                preselected=bool(meta.get("preselected", False)),
            )
            # (label, value): the label carries the display marker, the value is the tag.
            choices.append((label, t))
        checkbox_updates.append(gr.update(choices=choices, value=values, visible=visible))

    prompt_text = _compose_toggle_prompt_text(list(chosen), row_defs)
    return prompt_text, row_values_state, header_updates, checkbox_updates
def _on_toggle_row(
    row_idx: int,
    changed_values: List[str],
    selected_tags_state: List[str],
    row_defs_state: List[Dict[str, Any]],
    row_values_state: List[List[str]],
    max_rows: int,
):
    """Apply a checkbox change in one row to the global selection.

    The row's new values replace that row's share of the selection. Every row
    that shows a toggled tag gets a value update so duplicates stay in sync;
    all other checkbox slots are skipped.

    Returns:
        ``[selected_tags, row_values_state, prompt_text, *checkbox_updates]``
        with exactly ``max_rows`` checkbox entries.
    """
    row_defs_ui = (row_defs_state or [])[: max(0, int(max_rows))]
    selected = set(selected_tags_state or [])
    row = row_defs_ui[row_idx] if 0 <= row_idx < len(row_defs_ui) else {}
    row_tags = list(dict.fromkeys(row.get("tags", [])))
    row_tag_set = set(row_tags)
    row_tag_by_norm = {_norm_tag_for_lookup(t): t for t in row_tags}

    def _resolve_to_row_tags(values) -> Set[str]:
        # Be tolerant to UI payload forms: canonical tag values, display labels, or normalized variants.
        resolved: Set[str] = set()
        for raw in values or []:
            if raw in row_tag_set:
                resolved.add(raw)
                continue
            mapped = row_tag_by_norm.get(_norm_tag_for_lookup(str(raw)))
            if mapped:
                resolved.add(mapped)
        return resolved

    new_set = _resolve_to_row_tags(changed_values)
    prev_values = list(row_values_state or [])
    prev_row_values = prev_values[row_idx] if 0 <= row_idx < len(prev_values) else []
    prev_row_selected = _resolve_to_row_tags(prev_row_values)

    # Ignore non-user/no-op events (e.g., programmatic value re-sets) deterministically.
    if new_set == prev_row_selected:
        prompt_text = _compose_toggle_prompt_text(sorted(selected), row_defs_ui)
        untouched = [gr.skip() for _ in range(max_rows)]
        return [sorted(selected), prev_values, prompt_text, *untouched]

    # Replace this row's contribution to the selection with its new values.
    selected.difference_update(row_tag_set)
    selected.update(new_set)
    toggled_tags = prev_row_selected ^ new_set

    new_row_values_state: List[List[str]] = []
    affected_rows: Set[int] = {row_idx}
    for idx, row_item in enumerate(row_defs_ui):
        tags = list(dict.fromkeys(row_item.get("tags", [])))
        new_row_values_state.append([t for t in tags if t in selected])
        if toggled_tags and any(t in toggled_tags for t in tags):
            affected_rows.add(idx)

    checkbox_updates = [
        gr.update(value=new_row_values_state[idx])
        if idx < len(row_defs_ui) and idx in affected_rows
        else gr.skip()
        for idx in range(max_rows)
    ]
    prompt_text = _compose_toggle_prompt_text(sorted(selected), row_defs_ui)
    return [sorted(selected), new_row_values_state, prompt_text, *checkbox_updates]
def _build_ui_payload(
    *,
    console_text: str,
    row_defs: List[Dict[str, Any]],
    selected_tags: List[str],
):
    """Assemble the full list of output values for a finished pipeline run.

    Order: console text, a visibility update (visible when there are rows),
    prompt text, sorted unique selected tags, ``row_defs``, per-row selected
    values, then the header updates and checkbox updates for
    ``display_max_rows_default`` row slots.
    """
    updates = _build_row_component_updates(
        row_defs=row_defs,
        selected_tags=selected_tags,
        max_rows=display_max_rows_default,
    )
    prompt_text, row_values_state, header_updates, checkbox_updates = updates
    payload = [
        console_text,
        gr.update(visible=bool(row_defs)),
        prompt_text,
        sorted(set(selected_tags or [])),
        row_defs,
        row_values_state,
    ]
    payload.extend(header_updates)
    payload.extend(checkbox_updates)
    return payload
def _prepare_run_ui() -> List[Any]:
    """Return the output values that put the UI into its "running" state.

    Shows progress messages, leaves the second output untouched
    (``gr.skip()``), empties the three state values and hides every header
    and checkbox row slot. The list has the same layout as the one built by
    ``_build_ui_payload``.
    """
    slots = range(display_max_rows_default)
    payload: List[Any] = [
        "Running...",
        gr.skip(),
        "Running... usually completes in about 20 seconds.",
        [],
        [],
        [],
    ]
    for _ in slots:
        payload.append(gr.update(value="", visible=False))
    for _ in slots:
        payload.append(gr.update(choices=[], value=[], visible=False))
    return payload
def _build_selection_query(
prompt_in: str,
rewritten: str,
structural_tags: List[str],
probe_tags: List[str],
) -> str:
lines = [f"IMAGE DESCRIPTION: {prompt_in.strip()}"]
if rewritten and rewritten.strip():
lines.append(f"REWRITE PHRASES: {rewritten.strip()}")
hint_tags = []
if structural_tags:
hint_tags.extend(structural_tags)
if probe_tags:
hint_tags.extend(probe_tags)
if hint_tags:
# Keep hints as context only; selection still must choose by candidate indices.
lines.append(
"INFERRED TAG HINTS (context only): " + ", ".join(sorted(set(hint_tags)))
)
return "\n".join(lines)
# Logging setup.
# Minimal production logging: WARNING and above to a stream handler, no log file by default.
# The level can be overridden through the PSQ_LOG_LEVEL environment variable.
import os, logging
LOG_LEVEL = os.environ.get("PSQ_LOG_LEVEL", "WARNING").upper()
logging.basicConfig(
    # A name that is not an attribute of the logging module falls back to WARNING.
    level=getattr(logging, LOG_LEVEL, logging.WARNING),
    format="%(asctime)s %(levelname)s:%(message)s",
    handlers=[logging.StreamHandler()] # no file -> avoids huge logs on Spaces
)
# Quiet down common noisy libraries: only their ERROR-and-above records get through (optional).
for _name in ("gensim", "gradio", "hnswlib", "httpx", "uvicorn"):
    logging.getLogger(_name).setLevel(logging.ERROR)
# Turn off Gradio analytics phone-home to avoid those background thread errors (optional)
os.environ["GRADIO_ANALYTICS_ENABLED"] = "0"
MASCOT_DIR = Path(__file__).parent / "mascotimages"
MASCOT_FILE = MASCOT_DIR / "transparentsquirrel.png"
def _load_mascot_image():
"""Load mascot image if available; return None when missing/unreadable."""
if not MASCOT_FILE.exists():
logging.warning("Mascot image missing: %s", MASCOT_FILE)
return None
try:
return Image.open(MASCOT_FILE).convert("RGBA")
except Exception as e:
logging.warning("Failed to load mascot image (%s): %s", MASCOT_FILE, e)
return None
# Hotfix: wrap gradio_client's JSON-schema helpers so that a schema which is not a dict
# (e.g. a bare boolean) is reported as "any" instead of being passed to the original helpers.
# The patch is best-effort: if anything fails while applying it, a message is printed and
# startup continues unpatched.
try:
    from gradio_client import utils as _gc_utils
    # Keep references to the original helpers so the wrappers can delegate to them.
    _orig_get_type = _gc_utils.get_type
    _orig_j2p = _gc_utils._json_schema_to_python_type
    # NOTE(review): saved but never called below; the public wrapper delegates to _j2p_safe instead.
    _orig_pub = _gc_utils.json_schema_to_python_type
    def _get_type_safe(schema):
        # Sometimes schema is a bare True/False (JSON Schema boolean form)
        if not isinstance(schema, dict):
            return "any"
        return _orig_get_type(schema)
    def _j2p_safe(schema, defs=None):
        # Accept non-dict schemas (True/False/None) and treat as "any"
        if not isinstance(schema, dict):
            return "any"
        # Use the caller's defs when given; otherwise the schema's own "$defs".
        return _orig_j2p(schema, defs or schema.get("$defs"))
    def _pub_safe(schema):
        # Public wrapper used by Gradio; keep it resilient too
        if not isinstance(schema, dict):
            return "any"
        return _j2p_safe(schema, schema.get("$defs"))
    # Install the wrappers in place of the originals.
    _gc_utils.get_type = _get_type_safe
    _gc_utils._json_schema_to_python_type = _j2p_safe
    _gc_utils.json_schema_to_python_type = _pub_safe
except Exception as e:
    print("gradio_client hotfix not applied:", e)
# -------------------------------------------------------------------------------
# Module-level switch, off by default. It is not referenced in the visible part of this
# file; presumably it gates NSFW tag suggestions (TODO confirm against its users).
allow_nsfw_tags = False
def _is_production_runtime() -> bool:
"""Best-effort detection for deployed runtime (HF Spaces or explicit env)."""
if os.environ.get("PSQ_PRODUCTION", "").strip().lower() in {"1", "true", "yes"}:
return True
if os.environ.get("SPACE_ID"):
return True
if os.environ.get("HF_SPACE_ID"):
return True
if os.environ.get("SYSTEM") == "spaces":
return True
return False
# --- Runtime configuration, read once at import time from PSQ_* environment variables. ---
# Verbose retrieval output defaults to off in production and on elsewhere;
# PSQ_VERBOSE_RETRIEVAL overrides the default.
verbose_retrieval_default = "0" if _is_production_runtime() else "1"
verbose_retrieval = os.environ.get("PSQ_VERBOSE_RETRIEVAL", verbose_retrieval_default).strip().lower() in {"1", "true", "yes"}
verbose_retrieval_all = False
verbose_retrieval_limit = 20
# Probe tags stay enabled unless explicitly switched off. The value is compared
# case-insensitively and "no" is accepted, mirroring the "1"/"true"/"yes" parsing used
# by the other boolean flags (so "FALSE" or "no" no longer leaves probing enabled).
enable_probe_tags = os.environ.get("PSQ_ENABLE_PROBE", "1").strip().lower() not in {"0", "false", "no"}
# Display sizing for the toggle rows.
display_top_groups_default = int(os.environ.get("PSQ_DISPLAY_TOP_GROUPS", "10"))
display_top_tags_per_group_default = int(os.environ.get("PSQ_DISPLAY_TOP_TAGS_PER_GROUP", "5"))
display_rank_top_k_default = int(os.environ.get("PSQ_DISPLAY_GROUP_RANK_TOP_K", "5"))
display_max_rows_default = int(os.environ.get("PSQ_DISPLAY_MAX_ROWS", "14"))
# Retrieval and selection tuning.
retrieval_global_k = int(os.environ.get("PSQ_RETRIEVAL_GLOBAL_K", "300"))
retrieval_per_phrase_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_K", "10"))
retrieval_per_phrase_final_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_FINAL_K", "1"))
selection_mode = os.environ.get("PSQ_SELECTION_MODE", "chunked_map_union").strip()
selection_chunk_size = int(os.environ.get("PSQ_SELECTION_CHUNK_SIZE", "60"))
selection_per_phrase_k = int(os.environ.get("PSQ_SELECTION_PER_PHRASE_K", "2"))
selection_candidate_cap = int(os.environ.get("PSQ_SELECTION_CANDIDATE_CAP", "0"))
# Per-stage timeouts, in seconds.
stage1_rewrite_timeout_s = float(os.environ.get("PSQ_TIMEOUT_REWRITE_S", "45"))
stage1_struct_timeout_s = float(os.environ.get("PSQ_TIMEOUT_STRUCT_S", "45"))
stage1_probe_timeout_s = float(os.environ.get("PSQ_TIMEOUT_PROBE_S", "45"))
stage3_select_timeout_s = float(os.environ.get("PSQ_TIMEOUT_SELECT_S", "45"))
# JSONL file that per-run stage timings are appended to.
timing_log_path = Path(os.environ.get("PSQ_TIMING_LOG_PATH", "data/runtime_metrics/ui_pipeline_timings.jsonl"))
# Stylesheet text for the UI. It styles the scroll panes, the ".lego-tags" checkbox rows
# (rendered as brick-like toggle buttons whose selected colours are chosen through the
# data-psq-origin / data-psq-preselected label attributes), the source legend chips and
# the instruction/prompt text blocks. The string is runtime data: keep it byte-for-byte.
css = """
.scrollable-content{
max-height: 420px;
overflow-y: scroll; /* always show scrollbar */
overflow-x: hidden;
padding-right: 8px;
padding-bottom: 14px; /* <— add this */
scrollbar-gutter: stable; /* prevent layout shift as it fills */
/* Firefox */
scrollbar-width: auto;
scrollbar-color: rgba(180,180,180,.9) rgba(0,0,0,.15);
}
/* WebKit/Chromium (Chrome/Edge/Safari) */
.scrollable-content::-webkit-scrollbar{ width: 10px; }
.scrollable-content::-webkit-scrollbar-thumb{ background: rgba(180,180,180,.9); border-radius: 8px; }
.scrollable-content::-webkit-scrollbar-track{ background: rgba(0,0,0,.15); }
/* (Optional) make both scroll panes taller so they fill more of the column */
.pane-left .scrollable-content,
.pane-right .scrollable-content {
max-height: 610px; /* was 420px; tweak to taste */
}
.lego-tags .gr-checkboxgroup,
.lego-tags .wrap {
display: flex !important;
flex-wrap: wrap !important;
gap: 10px !important;
}
.lego-tags label {
margin: 0 !important;
padding: 0 !important;
position: relative !important;
}
/* Hide native checkbox visuals completely */
.lego-tags input[type="checkbox"] {
appearance: none !important;
-webkit-appearance: none !important;
-moz-appearance: none !important;
position: absolute !important;
width: 1px !important;
height: 1px !important;
opacity: 0 !important;
pointer-events: none !important;
display: none !important;
}
/* Brick button skin (works for both +span and ~span structures) */
.lego-tags input[type="checkbox"] + span,
.lego-tags input[type="checkbox"] ~ span {
--on-bg1: #ffd166;
--on-bg2: #f39c4a;
--on-border: #b86e21;
--on-text: #2e1706;
position: relative !important;
display: inline-flex !important;
align-items: center !important;
min-height: 40px !important;
padding: 10px 15px 9px 22px !important;
border: 1px solid #9aa6b8 !important;
border-radius: 10px !important;
background: linear-gradient(180deg, #dfe5ee 0%, #bec8d6 100%) !important;
color: #364254 !important;
font-size: 0.97rem !important;
font-weight: 800 !important;
line-height: 1.15 !important;
cursor: pointer !important;
user-select: none !important;
letter-spacing: 0.01em !important;
box-shadow: 0 3px 0 rgba(0,0,0,0.16), inset 0 1px 0 rgba(255,255,255,0.55) !important;
transition: transform 0.08s ease, box-shadow 0.08s ease, filter 0.08s ease !important;
}
.lego-tags input[type="checkbox"] + span::before,
.lego-tags input[type="checkbox"] ~ span::before {
content: "" !important;
position: absolute !important;
top: 5px !important;
left: 8px !important;
width: 8px !important;
height: 8px !important;
border-radius: 50% !important;
background: rgba(255,255,255,0.58) !important;
box-shadow: 22px 0 0 rgba(255,255,255,0.58) !important;
pointer-events: none !important;
}
/* Unselected cue: show "+" on the left. */
.lego-tags input[type="checkbox"] + span::after,
.lego-tags input[type="checkbox"] ~ span::after {
content: "+" !important;
position: absolute !important;
left: 6px !important;
top: 50% !important;
transform: translateY(-52%) !important;
font-size: 1rem !important;
font-weight: 900 !important;
color: #4b5563 !important;
opacity: 0.95 !important;
pointer-events: none !important;
}
/* Bright color cycle used only when selected */
.lego-tags label:nth-child(8n+1) span { --on-bg1: #ffd166; --on-bg2: #f39c4a; --on-border: #b86e21; --on-text: #2e1706; }
.lego-tags label:nth-child(8n+2) span { --on-bg1: #6ee7ff; --on-bg2: #1fb7ff; --on-border: #157cb3; --on-text: #07263c; }
.lego-tags label:nth-child(8n+3) span { --on-bg1: #9dff8f; --on-bg2: #45c96f; --on-border: #2a8b4b; --on-text: #0d2917; }
.lego-tags label:nth-child(8n+4) span { --on-bg1: #ff8fab; --on-bg2: #ff5c7a; --on-border: #b83956; --on-text: #3f0f1d; }
.lego-tags label:nth-child(8n+5) span { --on-bg1: #d0a8ff; --on-bg2: #a46cff; --on-border: #7147b3; --on-text: #25143f; }
.lego-tags label:nth-child(8n+6) span { --on-bg1: #ffe27a; --on-bg2: #f7bf39; --on-border: #ad7f1f; --on-text: #332407; }
.lego-tags label:nth-child(8n+7) span { --on-bg1: #8effd5; --on-bg2: #2ed6b5; --on-border: #1e947d; --on-text: #0d2a25; }
.lego-tags label:nth-child(8n+8) span { --on-bg1: #ffb47e; --on-bg2: #ff8753; --on-border: #b95b2d; --on-text: #391a0a; }
/* Source-driven selected colors (applies when tags are preselected by the pipeline). */
.lego-tags label[data-psq-preselected="1"][data-psq-origin="rewrite"] span {
--on-bg1: #77f0d7;
--on-bg2: #26b9a3;
--on-border: #187869;
--on-text: #062923;
}
.lego-tags label[data-psq-preselected="1"][data-psq-origin="selection"] span {
--on-bg1: #ffd98a;
--on-bg2: #f0a93c;
--on-border: #a66f1f;
--on-text: #382206;
}
.lego-tags label[data-psq-preselected="1"][data-psq-origin="probe"] span {
--on-bg1: #d8b4ff;
--on-bg2: #9a6cff;
--on-border: #6745b0;
--on-text: #24143b;
}
.lego-tags label[data-psq-preselected="1"][data-psq-origin="structural"] span {
--on-bg1: #a6f79a;
--on-bg2: #53c368;
--on-border: #2f8442;
--on-text: #102d17;
}
.lego-tags label[data-psq-preselected="1"][data-psq-origin="implied"] span {
--on-bg1: #d7dde8;
--on-bg2: #a8b3c4;
--on-border: #6f7e95;
--on-text: #1d2633;
}
/* User-selected tags (not initially selected by the pipeline). */
.lego-tags label[data-psq-preselected="0"] span {
--on-bg1: #9ec5ff;
--on-bg2: #4f86ff;
--on-border: #2f5fbf;
--on-text: #0b1f42;
}
.lego-tags label:hover span {
filter: brightness(1.02) !important;
transform: translateY(1px) !important;
}
/* ON state: brighter + visibly recessed */
.lego-tags input[type="checkbox"]:checked + span,
.lego-tags input[type="checkbox"]:checked ~ span,
.lego-tags label:has(input[type="checkbox"]:checked) span {
background: linear-gradient(180deg, var(--on-bg1) 0%, var(--on-bg2) 100%) !important;
color: var(--on-text) !important;
border-color: var(--on-border) !important;
filter: saturate(1.2) brightness(1.12) !important;
transform: translateY(-2px) !important;
box-shadow:
inset 0 3px 6px rgba(0,0,0,0.20),
inset 0 -1px 0 rgba(255,255,255,0.36),
0 6px 0 rgba(0,0,0,0.32) !important;
}
.lego-tags input[type="checkbox"]:checked + span::after,
.lego-tags input[type="checkbox"]:checked ~ span::after,
.lego-tags label:has(input[type="checkbox"]:checked) span::after {
content: "" !important;
}
.source-legend {
display: flex;
flex-wrap: wrap;
align-items: center;
gap: 8px;
margin: 4px 0 10px 0;
}
.source-legend .legend-title {
font-size: 0.92rem;
font-weight: 900;
color: #334155;
margin-right: 4px;
}
.source-legend .chip {
display: inline-flex;
align-items: center;
border-radius: 10px;
border: 1px solid #6c7788;
padding: 6px 12px;
font-size: 0.85rem;
font-weight: 800;
color: #111827;
background: #f3f6fb;
}
.source-legend .chip.rewrite { background: #26b9a3; color: #062923; border-color: #187869; }
.source-legend .chip.selection { background: #f0a93c; color: #382206; border-color: #a66f1f; }
.source-legend .chip.probe { background: #9a6cff; color: #ffffff; border-color: #6745b0; }
.source-legend .chip.structural { background: #53c368; color: #102d17; border-color: #2f8442; }
.source-legend .chip.implied { background: #a8b3c4; color: #1d2633; border-color: #6f7e95; }
.source-legend .chip.user { background: #4f86ff; color: #ffffff; border-color: #2f5fbf; }
.source-legend .chip.unselected { background: #c7ced8; color: #2d3440; border-color: #7d8897; }
.row-heading p {
margin: 8px 0 0 0 !important;
font-size: 1.18rem !important;
font-weight: 850 !important;
line-height: 1.2 !important;
}
.row-instruction {
text-align: center;
margin: 8px 0 12px 0;
}
.row-instruction p {
margin: 0 !important;
font-size: 1.02rem !important;
font-style: italic !important;
font-weight: 800 !important;
color: #1d4ed8 !important;
}
.top-instruction {
text-align: center;
margin: 2px 0 6px 0;
}
.top-instruction p {
margin: 0 !important;
font-size: 1.02rem !important;
font-style: italic !important;
font-weight: 800 !important;
color: #1d4ed8 !important;
}
.run-hint {
margin-top: 6px;
text-align: center;
}
.run-hint p {
margin: 0 !important;
font-size: 0.9rem !important;
font-style: italic !important;
color: #475569 !important;
}
.prompt-card {
background: transparent !important;
border: none !important;
box-shadow: none !important;
padding: 0 !important;
}
.suggested-prompt-box {
margin-top: 2px !important;
}
.suggested-prompt-card {
margin-top: 10px !important;
}
"""
# Client-side script (a JavaScript arrow function kept as text). For every ".lego-tags"
# label whose text ends with a "[[psq:<origin>:<0|1>:<base64 tooltip>]]" marker (the format
# produced by _choice_label_with_source_meta) it copies origin/preselected into data
# attributes on the label, decodes the tooltip into the title attribute and strips the
# marker from the visible text. A MutationObserver on document.body re-applies this
# whenever the DOM changes. The string is runtime data: keep it byte-for-byte.
client_js = """
() => {
const markerRe = /\\s*\\[\\[psq:([a-z_]+):(0|1):([A-Za-z0-9_\\-=]*)\\]\\]\\s*$/;
const decodeTip = (b64) => {
if (!b64) return "";
try {
const binary = atob((b64 || "").replace(/-/g, "+").replace(/_/g, "/"));
const bytes = Uint8Array.from(binary, (c) => c.charCodeAt(0));
return new TextDecoder("utf-8").decode(bytes);
} catch (_) {
return "";
}
};
const applyTagMeta = () => {
const labels = document.querySelectorAll(".lego-tags label");
labels.forEach((label) => {
const span = label.querySelector("span");
if (!span) return;
const text = span.textContent || "";
const match = text.match(markerRe);
if (!match) return;
label.dataset.psqOrigin = match[1];
label.dataset.psqPreselected = match[2];
const tip = decodeTip(match[3] || "");
if (tip) {
label.title = tip;
span.title = tip;
} else {
label.removeAttribute("title");
span.removeAttribute("title");
}
span.textContent = text.replace(markerRe, "");
});
};
applyTagMeta();
const observer = new MutationObserver(() => applyTagMeta());
observer.observe(document.body, { childList: true, subtree: true, characterData: true });
}
"""
def rag_pipeline_ui(
    user_prompt: str,
    display_top_groups: float,
    display_top_tags_per_group: float,
    display_rank_top_k: float,
):
    """Run the prompt-to-tags pipeline for one prompt and build the UI payload.

    Stages, in order (each one is logged and timed):
      1. Heuristic extraction of user-provided tags from the raw prompt.
      2. Concurrent LLM calls: prompt rewrite, structural-tag inference and,
         when ``enable_probe_tags`` is set, probe-tag inference.
      3. Candidate retrieval from the rewrite phrases (plus user tags).
      4. LLM index selection over the retrieved candidates.
      5. Merge of structural/probe tags, exclusion filtering and expansion
         via tag implications.
      6. Prompt composition and construction of the ranked toggle rows.

    Args:
        user_prompt: Free-text prompt typed by the user. Empty or
            whitespace-only input short-circuits with an error payload.
        display_top_groups: Number of group rows to show. Coerced to an int
            of at least 1; falls back to ``display_top_groups_default`` when
            the value cannot be converted (e.g. a cleared number field).
        display_top_tags_per_group: Tags shown per row; same coercion, with
            ``display_top_tags_per_group_default`` as the fallback.
        display_rank_top_k: Tags used for row ranking; same coercion, with
            ``display_rank_top_k_default`` as the fallback.

    Returns:
        Whatever ``_build_ui_payload`` returns; the UI wiring maps it onto
        ``run_outputs``. On any unexpected exception the payload carries the
        accumulated log text with empty rows and no selected tags.
    """
    logs = []

    def log(s):
        logs.append(s)

    try:
        stage_timings = {}
        # Human-readable stage names used in log lines -> keys of stage_timings.
        stage_keys = {
            "Rewrite": "rewrite",
            "Structural inference": "structural",
            "Probe inference": "probe",
            "Index selection": "selection",
        }

        def _record_timing(stage: str, dt_s: float):
            stage_timings[stage] = float(dt_s)

        def _emit_timing_summary(total_s: float):
            """Log per-stage durations in pipeline order plus the slowest stage."""
            summary_order = [
                "preprocess",
                "rewrite",
                "structural",
                "probe",
                "retrieval",
                "selection",
                "implication_expansion",
                "prompt_composition",
                "group_display",
            ]
            lines = [
                f"{k}={stage_timings[k]:.2f}s"
                for k in summary_order
                if k in stage_timings
            ]
            slowest = max(stage_timings.items(), key=lambda kv: kv[1])[0] if stage_timings else "n/a"
            log("Timing Summary: " + ", ".join(lines))
            log(f"Timing Slowest Stage: {slowest}")
            log(f"Timing Total: {total_s:.2f}s")

        def _append_timing_jsonl(total_s: float):
            """Append one timing record to ``timing_log_path`` (best effort)."""
            # Local import: the module only imports `datetime` from `datetime`.
            from datetime import timezone

            try:
                timing_log_path.parent.mkdir(parents=True, exist_ok=True)
                rec = {
                    # Same "YYYY-MM-DDTHH:MM:SSZ" text as before, but built
                    # from an aware UTC time instead of deprecated utcnow().
                    "timestamp_utc": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"),
                    "stages_s": stage_timings,
                    "total_s": float(total_s),
                    "config": {
                        "timeout_rewrite_s": stage1_rewrite_timeout_s,
                        "timeout_struct_s": stage1_struct_timeout_s,
                        "timeout_probe_s": stage1_probe_timeout_s,
                        "timeout_select_s": stage3_select_timeout_s,
                    },
                }
                with timing_log_path.open("a", encoding="utf-8") as f:
                    f.write(json.dumps(rec, ensure_ascii=True) + "\n")
                log(f"Timing Log: wrote {timing_log_path}")
            except Exception as e:
                # Timing persistence must never break the user-facing run.
                log(f"Timing Log: failed ({type(e).__name__}: {e})")

        def _future_with_timeout(fut, timeout_s: float, stage_name: str, fallback):
            """Wait for ``fut`` (at least 1s) and return its result or ``fallback``.

            The logged/recorded duration is the time spent waiting in this
            call, not the time since the task was submitted. Nothing is
            recorded for a stage that times out or raises.
            """
            t0 = time.perf_counter()
            try:
                out = fut.result(timeout=max(1.0, float(timeout_s)))
                dt = time.perf_counter() - t0
                log(f"{stage_name}: {dt:.2f}s")
                stage_key = stage_keys.get(stage_name)
                if stage_key:
                    _record_timing(stage_key, dt)
                return out
            except FutureTimeoutError:
                # cancel() only helps for tasks that have not started yet.
                fut.cancel()
                log(f"{stage_name}: timed out after {timeout_s:.0f}s; using fallback")
                return fallback
            except Exception as e:
                log(f"{stage_name}: failed ({type(e).__name__}: {e}); using fallback")
                return fallback

        def _display_count(value, default) -> int:
            """Coerce a display setting to an int >= 1, else use ``default``."""
            try:
                return max(1, int(value))
            except (TypeError, ValueError, OverflowError):
                return max(1, int(default))

        t_total0 = time.perf_counter()
        log("Start: received prompt")
        prompt_in = (user_prompt or "").strip()
        if not prompt_in:
            return _build_ui_payload(
                console_text="Error: empty prompt",
                row_defs=[],
                selected_tags=[],
            )

        # Resolve display settings up front so a cleared number field cannot
        # fail the run after all the slow LLM stages have already completed.
        top_groups = _display_count(display_top_groups, display_top_groups_default)
        top_tags_per_group = _display_count(
            display_top_tags_per_group, display_top_tags_per_group_default
        )
        group_rank_top_k = _display_count(display_rank_top_k, display_rank_top_k_default)

        log("Input:")
        log(prompt_in)
        log("")
        log(
            "Runtime config: "
            f"retrieval_global_k={retrieval_global_k} "
            f"retrieval_per_phrase_k={retrieval_per_phrase_k} "
            f"retrieval_per_phrase_final_k={retrieval_per_phrase_final_k} "
            f"selection_mode={selection_mode} "
            f"selection_chunk_size={selection_chunk_size} "
            f"selection_per_phrase_k={selection_per_phrase_k}"
        )
        log("")

        t0 = time.perf_counter()
        user_tags = extract_user_provided_tags_upto_3_words(prompt_in)
        dt = time.perf_counter() - t0
        _record_timing("preprocess", dt)
        log(f"Preprocess (user tag extraction): {dt:.2f}s")
        log("Heuristically extracted user tags:")
        if user_tags:
            log(", ".join(user_tags))
        else:
            log("(none)")
        log("")

        log("Step 1: LLM rewrite + structural inference + probe (concurrent)")
        max_workers = 3 if enable_probe_tags else 2
        # Not used as a context manager: leaving a `with` block calls
        # shutdown(wait=True), which blocks until every worker finishes and
        # would therefore defeat the per-stage timeouts below.
        stage1_pool = ThreadPoolExecutor(max_workers=max_workers)
        try:
            fut_rewrite = stage1_pool.submit(llm_rewrite_prompt, prompt_in, log)
            fut_struct = stage1_pool.submit(llm_infer_structural_tags, prompt_in, log=log)
            fut_probe = (
                stage1_pool.submit(llm_infer_probe_tags, prompt_in, log=log)
                if enable_probe_tags else None
            )
            rewritten = _future_with_timeout(
                fut_rewrite, stage1_rewrite_timeout_s, "Rewrite", prompt_in
            )
            structural_tags = _future_with_timeout(
                fut_struct, stage1_struct_timeout_s, "Structural inference", []
            )
            probe_tags = (
                _future_with_timeout(fut_probe, stage1_probe_timeout_s, "Probe inference", [])
                if fut_probe else []
            )
        finally:
            # Timed-out workers keep running in the background; do not wait.
            stage1_pool.shutdown(wait=False)

        log("Rewrite:")
        log(rewritten if rewritten else "(empty)")
        log("")
        rewrite_for_retrieval = rewritten
        if user_tags:
            # keep them separate in logs, but allow them to help retrieval
            # (`or ""` guards against a rewrite step that returned None)
            rewrite_for_retrieval = (
                (rewritten or "") + ", " + ", ".join(user_tags)
            ).strip(", ").strip()

        log("Step 2: Prompt Squirrel retrieval (hidden)")
        try:
            t0 = time.perf_counter()
            retrieval_context_tags = list(dict.fromkeys((structural_tags or []) + (probe_tags or [])))
            rewrite_phrases = [p.strip() for p in (rewrite_for_retrieval or "").split(",") if p.strip()]
            retrieval_result = psq_candidates_from_rewrite_phrases(
                rewrite_phrases=rewrite_phrases,
                allow_nsfw_tags=allow_nsfw_tags,
                context_tags=retrieval_context_tags,
                global_k=max(1, retrieval_global_k),
                per_phrase_k=max(1, retrieval_per_phrase_k),
                per_phrase_final_k=max(1, retrieval_per_phrase_final_k),
                verbose=verbose_retrieval,
            )
            # The retrieval call may return either (candidates, reports) or
            # just the candidates; both shapes are accepted here.
            if isinstance(retrieval_result, tuple):
                candidates, phrase_reports = retrieval_result
            else:
                candidates, phrase_reports = retrieval_result, []
            if selection_candidate_cap > 0 and len(candidates) > selection_candidate_cap:
                candidates = candidates[:selection_candidate_cap]
                log(f"Selection candidate cap applied: {selection_candidate_cap}")
            dt = time.perf_counter() - t0
            _record_timing("retrieval", dt)
            log(f"Retrieval: {dt:.2f}s")
            log(f"Retrieved {len(candidates)} candidate tags")
            if verbose_retrieval:
                log(f"Total unique candidates: {len(candidates)}")
                limit = None if verbose_retrieval_all else max(1, int(verbose_retrieval_limit))
                for report in phrase_reports:
                    phrase = report.get("normalized") or report.get("phrase") or ""
                    lookup = report.get("lookup") or ""
                    tfidf_vocab = report.get("tfidf_vocab")
                    log(f"Phrase: {phrase} (lookup={lookup}) tfidf_vocab={tfidf_vocab}")
                    rows = report.get("candidates", [])
                    shown = rows if limit is None else rows[:limit]
                    for row in shown:
                        tag = row.get("tag")
                        alias_token = row.get("alias_token")
                        score_fasttext = row.get("score_fasttext")
                        score_context = row.get("score_context")
                        score_combined = row.get("score_combined")
                        count = row.get("count")
                        alias_part = ""
                        if alias_token and alias_token != tag:
                            alias_part = f" [alias_token={alias_token}]"
                        fasttext_str = (
                            f"{score_fasttext:.3f}" if isinstance(score_fasttext, (int, float)) else score_fasttext
                        )
                        if score_context is None:
                            context_str = "None"
                        else:
                            context_str = (
                                f"{score_context:.3f}" if isinstance(score_context, (int, float)) else score_context
                            )
                        combined_str = (
                            f"{score_combined:.3f}" if isinstance(score_combined, (int, float)) else score_combined
                        )
                        log(
                            f" {tag}{alias_part} | fasttext={fasttext_str} context={context_str} "
                            f"combined={combined_str} count={count}"
                        )
                    if limit is not None and len(rows) > limit:
                        log(f" ... ({len(rows) - limit} more)")
        except Exception as e:
            # Retrieval is best effort: continue with no candidates.
            log(f"Retrieval fallback: {type(e).__name__}: {e}")
            candidates = []

        retrieved_candidate_tags = list(
            dict.fromkeys(
                _norm_tag_for_lookup(c.tag)
                for c in (candidates or [])
                if getattr(c, "tag", None)
            )
        )

        log("Step 3: LLM index selection (uses rewrite + structural/probe context)")
        selection_query = _build_selection_query(
            prompt_in=prompt_in,
            rewritten=rewritten,
            structural_tags=structural_tags,
            probe_tags=probe_tags,
        )
        # Same reasoning as in step 1: avoid the blocking context-manager exit.
        select_pool = ThreadPoolExecutor(max_workers=1)
        try:
            fut_sel = select_pool.submit(
                llm_select_indices,
                query_text=selection_query,
                candidates=candidates,
                max_pick=0,
                log=log,
                mode=selection_mode,
                chunk_size=max(1, selection_chunk_size),
                per_phrase_k=max(1, selection_per_phrase_k),
            )
            picked_indices = _future_with_timeout(
                fut_sel, stage3_select_timeout_s, "Index selection", []
            )
        finally:
            select_pool.shutdown(wait=False)

        # Skip indices outside the candidate list: an out-of-range value would
        # abort the whole run and a negative one would silently wrap around.
        picked_indices = list(picked_indices or [])
        n_candidates = len(candidates)
        valid_indices = [i for i in picked_indices if 0 <= i < n_candidates]
        if len(valid_indices) != len(picked_indices):
            log(f" Ignored {len(picked_indices) - len(valid_indices)} out-of-range selection indices")
        selection_selected_tags = [candidates[i].tag for i in valid_indices]
        selected_tags = list(selection_selected_tags)

        if structural_tags:
            # Add structural tags that aren't already selected
            existing = set(selected_tags)
            new_structural = [t for t in structural_tags if t not in existing]
            selected_tags.extend(new_structural)
            log(f" Added {len(new_structural)} structural tags: {', '.join(new_structural)}")
        else:
            log(" No structural tags inferred")
        if probe_tags:
            existing = set(selected_tags)
            new_probe = [t for t in probe_tags if t not in existing]
            selected_tags.extend(new_probe)
            log(f" Added {len(new_probe)} probe tags: {', '.join(new_probe)}")
        elif enable_probe_tags:
            log(" No probe tags inferred")

        selected_tags, removed_excluded_direct = _filter_excluded_recommendation_tags(selected_tags)
        if removed_excluded_direct:
            log(f" Removed {len(removed_excluded_direct)} excluded tags: {', '.join(removed_excluded_direct)}")
        direct_selected_tags = list(dict.fromkeys(selected_tags))

        log("Step 3c: Expand via tag implications")
        t0 = time.perf_counter()
        tag_set = set(selected_tags)
        expanded, implied_only = expand_tags_via_implications(tag_set)
        dt = time.perf_counter() - t0
        _record_timing("implication_expansion", dt)
        log(f"Implication expansion: {dt:.2f}s")
        implied_selected_tags = sorted(implied_only) if implied_only else []
        if implied_only:
            selected_tags.extend(sorted(implied_only))
            log(f" Added {len(implied_only)} implied tags: {', '.join(sorted(implied_only))}")
        else:
            log(" No additional implied tags")
        # Implications can re-introduce excluded tags, so filter again.
        selected_tags, removed_excluded_implied = _filter_excluded_recommendation_tags(selected_tags)
        implied_selected_tags = [
            t for t in implied_selected_tags if not _is_excluded_recommendation_tag(t)
        ]
        if removed_excluded_implied:
            log(
                f" Removed {len(removed_excluded_implied)} excluded tags after implications: "
                f"{', '.join(removed_excluded_implied)}"
            )

        log("Step 4: Compose final prompt")
        t0 = time.perf_counter()
        # NOTE: the composed text is not part of the payload built below; the
        # call is kept so the "prompt_composition" timing stays populated.
        final_prompt = compose_final_prompt(rewritten, selected_tags)
        dt = time.perf_counter() - t0
        _record_timing("prompt_composition", dt)
        log(f"Prompt composition: {dt:.2f}s")

        log("Step 5: Build ranked group/category display")
        t0 = time.perf_counter()
        rewrite_phrase_list = [p.strip() for p in (rewritten or "").split(",") if p.strip()]
        seed_terms = []
        seed_terms.extend(user_tags)
        seed_terms.extend(rewrite_phrase_list)
        seed_terms.extend(structural_tags or [])
        seed_terms.extend(probe_tags or [])
        seed_terms.extend(selected_tags)
        seed_terms = list(dict.fromkeys(seed_terms))
        active_selected_tags = list(dict.fromkeys(selected_tags))

        structural_set = {_norm_tag_for_lookup(t) for t in (structural_tags or []) if t}
        probe_set = {_norm_tag_for_lookup(t) for t in (probe_tags or []) if t}
        implied_set = {_norm_tag_for_lookup(t) for t in (implied_selected_tags or []) if t}
        rewrite_set = {
            _norm_tag_for_lookup(t)
            for t in (list(user_tags or []) + rewrite_phrase_list)
            if t
        }
        selection_set = {_norm_tag_for_lookup(t) for t in (selection_selected_tags or []) if t}

        # Attribute each selected tag to one origin; the first match in this
        # precedence order wins.
        tag_selection_origins: Dict[str, str] = {}
        for tag in active_selected_tags:
            tag_norm = _norm_tag_for_lookup(tag)
            if tag_norm in structural_set:
                origin = "structural"
            elif tag_norm in probe_set:
                origin = "probe"
            elif tag_norm in rewrite_set:
                origin = "rewrite"
            elif tag_norm in selection_set:
                origin = "selection"
            elif tag_norm in implied_set:
                origin = "implied"
            else:
                # Unknown/fallback tags use selection color.
                origin = "selection"
            tag_selection_origins[tag] = origin
            if tag_norm and tag_norm != tag:
                tag_selection_origins[tag_norm] = origin

        direct_tags_for_implied = list(
            dict.fromkeys(_norm_tag_for_lookup(t) for t in (direct_selected_tags or []) if t)
        )
        direct_tags_for_implied_idx = {t: i for i, t in enumerate(direct_tags_for_implied)}
        # Order by origin rank first, then by original position.
        direct_tags_for_implied.sort(
            key=lambda t: (
                _selection_source_rank(tag_selection_origins.get(t, "selection")),
                direct_tags_for_implied_idx.get(t, 10**9),
            )
        )
        implied_parent_map = _build_implied_parent_map(
            direct_tags_ordered=direct_tags_for_implied,
            implied_tags=implied_selected_tags,
        )
        toggle_rows = _build_toggle_rows(
            seed_terms=seed_terms,
            selected_tags=active_selected_tags,
            retrieved_candidate_tags=retrieved_candidate_tags,
            tag_selection_origins=tag_selection_origins,
            implied_parent_map=implied_parent_map,
            top_groups=top_groups,
            top_tags_per_group=top_tags_per_group,
            group_rank_top_k=group_rank_top_k,
        )
        dt = time.perf_counter() - t0
        _record_timing("group_display", dt)
        log(f"Ranked group display: {dt:.2f}s ({len(toggle_rows)} rows)")
        log(
            _build_display_audit_line(
                toggle_rows,
                active_selected_tags=active_selected_tags,
                direct_selected_tags=direct_selected_tags,
                implied_selected_tags=implied_selected_tags,
            )
        )

        total_dt = time.perf_counter() - t_total0
        _emit_timing_summary(total_dt)
        _append_timing_jsonl(total_dt)
        log("Done: final prompt ready")
        return _build_ui_payload(
            console_text="\n".join(logs),
            row_defs=toggle_rows,
            selected_tags=active_selected_tags,
        )
    except Exception as e:
        # Top-level boundary for the UI callback: surface the error in the
        # console text instead of letting the event handler raise.
        log(f"Error: {type(e).__name__}: {e}")
        return _build_ui_payload(
            console_text="\n".join(logs),
            row_defs=[],
            selected_tags=[],
        )
# ---------------------------------------------------------------------------
# Gradio UI: layout, state holders and event wiring.
# Component creation order defines the rendered layout, so the statements
# below intentionally keep the same order of construction.
# ---------------------------------------------------------------------------
with gr.Blocks(css=css, js=client_js) as app:
    # Top area: prompt entry + suggested prompt on the left, mascot and the
    # Run button on the right.
    with gr.Row():
        with gr.Column(scale=3, elem_classes=["prompt-col"]):
            gr.Markdown(
                'Describe your image under "Enter Prompt" and click "Run". '
                'Prompt Squirrel will translate it into image board tags.',
                elem_classes=["top-instruction"],
            )
            with gr.Group(elem_classes=["prompt-card"]):
                image_tags = gr.Textbox(
                    label="Enter Prompt",
                    placeholder="e.g. fox, outside, detailed background, .",
                    lines=1,
                    elem_classes=["enter-prompt-box"],
                )
            with gr.Group(elem_classes=["prompt-card", "suggested-prompt-card"]):
                suggested_prompt = gr.Textbox(
                    label="Suggested Prompt (Read-only)",
                    lines=2,
                    interactive=False,
                    show_copy_button=True,
                    placeholder='Suggested prompt will appear here after you click "Run".',
                    elem_classes=["suggested-prompt-box"],
                )
        with gr.Column(scale=1):
            _mascot_pil = _load_mascot_image()
            # Show the mascot when it could be loaded, else a text stand-in.
            mascot_img = (
                gr.Image(
                    value=_mascot_pil,
                    show_label=False,
                    interactive=False,
                    height=240,
                    elem_id="mascot"
                )
                if _mascot_pil is not None
                else gr.Markdown("`(mascot image unavailable)`")
            )
            submit_button = gr.Button("Run", variant="primary")
            gr.Markdown("Typical runtime: up to ~20 seconds.", elem_classes=["run-hint"])

    # Per-session state shared between the run callback and the row toggles.
    selected_tags_state = gr.State([])
    row_defs_state = gr.State([])
    row_values_state = gr.State([])

    toggle_instruction = gr.Markdown(
        "Click tag buttons to add or remove tags from the suggested prompt.",
        elem_classes=["row-instruction"],
        visible=False,
    )

    # A fixed pool of hidden (heading, checkbox group) rows is created once;
    # they start invisible and are listed among the callback outputs below.
    row_headers: List[gr.Markdown] = []
    row_checkboxes: List[gr.CheckboxGroup] = []
    for _ in range(display_max_rows_default):
        with gr.Row():
            with gr.Column(scale=2, min_width=170):
                row_headers.append(gr.Markdown(value="", visible=False, elem_classes=["row-heading"]))
            with gr.Column(scale=10):
                row_checkboxes.append(
                    gr.CheckboxGroup(
                        choices=[],
                        value=[],
                        visible=False,
                        interactive=True,
                        container=False,
                        elem_classes=["lego-tags"],
                    )
                )

    # Colour legend for the tag chips (styled by the ".source-legend" rules).
    gr.HTML(
        """
<div class="source-legend">
<span class="legend-title">Legend:</span>
<span class="chip rewrite">Rewrite phrase</span>
<span class="chip selection">General selection</span>
<span class="chip probe">Probe query</span>
<span class="chip structural">Structural query</span>
<span class="chip implied">Implied</span>
<span class="chip user">User-toggled</span>
<span class="chip unselected">Unselected</span>
</div>
"""
    )

    with gr.Accordion("Display Settings", open=False):
        with gr.Row():
            display_top_groups = gr.Number(
                value=display_top_groups_default,
                precision=0,
                label="Rows (Top Groups/Categories)",
                minimum=1,
            )
            display_top_tags_per_group = gr.Number(
                value=display_top_tags_per_group_default,
                precision=0,
                label="Top Tags Shown Per Row",
                minimum=1,
            )
            display_rank_top_k = gr.Number(
                value=display_rank_top_k_default,
                precision=0,
                label="Top Tags Used for Row Ranking",
                minimum=1,
            )

    with gr.Accordion("Console", open=False):
        console = gr.Textbox(
            label="Console",
            lines=10,
            interactive=False,
            placeholder="Progress logs will appear here."
        )

    # Every component a run may update, in the order the callbacks return them.
    run_outputs = [
        console,
        toggle_instruction,
        suggested_prompt,
        selected_tags_state,
        row_defs_state,
        row_values_state,
        *row_headers,
        *row_checkboxes,
    ]

    # Clicking "Run" and pressing Enter in the prompt box start the same
    # two-step chain: an unqueued UI preparation step, then the pipeline.
    for _run_trigger in (submit_button.click, image_tags.submit):
        _run_trigger(
            _prepare_run_ui,
            inputs=[],
            outputs=run_outputs,
            queue=False,
            show_progress="hidden",
        ).then(
            rag_pipeline_ui,
            inputs=[image_tags, display_top_groups, display_top_tags_per_group, display_rank_top_k],
            outputs=run_outputs,
        )

    # One select handler per row; `i=idx` freezes the row index at definition
    # time so each lambda reports its own row instead of the last one.
    for idx, row_cb in enumerate(row_checkboxes):
        row_cb.select(
            fn=lambda changed_values, selected_state, row_defs, row_values, i=idx: _on_toggle_row(
                i,
                changed_values,
                selected_state,
                row_defs,
                row_values,
                display_max_rows_default,
            ),
            inputs=[row_cb, selected_tags_state, row_defs_state, row_values_state],
            outputs=[selected_tags_state, row_values_state, suggested_prompt, *row_checkboxes],
            queue=False,
            show_progress="hidden",
        )

# Script entry point: enable queuing and serve, allowing files under the
# mascot directory to be read by the app.
if __name__ == "__main__":
    app.queue().launch(allowed_paths=[str(MASCOT_DIR)])