Food Desert
Add exact n-gram retrieval query hints
29b12cd
Raw
History Blame
116 kB
import os
import logging
import time
import json
import csv
import re
import base64
import atexit
from datetime import datetime
from functools import lru_cache
from PIL import Image
from pathlib import Path
from typing import Any, Dict, List, Set, Tuple
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
def _startup_profile_enabled() -> bool:
    """Report whether startup profiling was requested via ``PSQ_STARTUP_PROFILE``.

    The variable is compared case-insensitively with surrounding whitespace
    ignored; only ``1``, ``true``, ``yes`` and ``on`` switch profiling on.
    """
    setting = os.environ.get("PSQ_STARTUP_PROFILE", "") or ""
    return setting.strip().lower() in ("1", "true", "yes", "on")
# Startup-profiling state, evaluated once when the module is imported.
_STARTUP_PROFILE_ON = _startup_profile_enabled()
# Reference point for the relative "t_s" timestamps written by _startup_profile_mark.
_STARTUP_PROFILE_T0 = time.perf_counter()
# Destination of the JSONL profile log; stays None while profiling is off.
_STARTUP_PROFILE_PATH: Path | None = None
# Open text handle for the profile log, or None if profiling is off / the open failed.
_STARTUP_PROFILE_FILE = None
# NOTE(review): not read in this part of the file; presumably a one-shot guard
# for a "client load" mark emitted elsewhere — confirm against its users.
_STARTUP_PROFILE_CLIENT_LOAD_MARKED = False
if _STARTUP_PROFILE_ON:
    _startup_profile_path_raw = (
        os.environ.get("PSQ_STARTUP_PROFILE_PATH", "") or ""
    ).strip()
    if _startup_profile_path_raw:
        _STARTUP_PROFILE_PATH = Path(_startup_profile_path_raw)
    else:
        # No explicit path configured: use a timestamped file under data/runtime_metrics.
        _STARTUP_PROFILE_PATH = Path("data/runtime_metrics") / (
            f"startup_profile_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl"
        )
    try:
        _STARTUP_PROFILE_PATH.parent.mkdir(parents=True, exist_ok=True)
        _STARTUP_PROFILE_FILE = _STARTUP_PROFILE_PATH.open("a", encoding="utf-8")
    except Exception:
        # Best effort: if the log cannot be created, marks are still printed to stdout.
        _STARTUP_PROFILE_FILE = None
def _startup_profile_mark(event: str, **fields: Any) -> None:
    """Emit one startup-profile record; a no-op unless profiling is enabled.

    The record carries the event name, the seconds elapsed since
    ``_STARTUP_PROFILE_T0`` (rounded to microseconds) and any extra keyword
    fields. It is appended to the profile log when one is open and is always
    echoed to stdout with a ``STARTUP_PROFILE`` prefix.
    """
    if not _STARTUP_PROFILE_ON:
        return

    elapsed_s = round(time.perf_counter() - _STARTUP_PROFILE_T0, 6)
    record: Dict[str, Any] = {"event": str(event), "t_s": elapsed_s}
    record.update(fields)
    payload = json.dumps(record, ensure_ascii=False)

    sink = _STARTUP_PROFILE_FILE
    if sink is not None:
        try:
            sink.write(payload + "\n")
            sink.flush()
        except Exception:
            # Profiling must never break startup; a failed write is ignored.
            pass

    print(f"STARTUP_PROFILE {payload}")
def _startup_profile_close() -> None:
    """Close the startup-profile log if one was opened, ignoring close errors."""
    handle = _STARTUP_PROFILE_FILE
    if handle is not None:
        try:
            handle.close()
        except Exception:
            pass
# First profile record: marks the point where the heavy imports below begin.
_startup_profile_mark("module_bootstrap_begin")
if _STARTUP_PROFILE_ON and _STARTUP_PROFILE_PATH is not None:
    # Record where the profile log is being written.
    _startup_profile_mark("startup_profile_enabled", path=str(_STARTUP_PROFILE_PATH))
# _startup_profile_close returns immediately when no log file was opened.
atexit.register(_startup_profile_close)
import gradio as gr
_startup_profile_mark("import.gradio.done")
from psq_rag.pipeline.preproc import (
extract_exact_tag_query_phrases,
extract_user_provided_tags_upto_3_words,
)
_startup_profile_mark("import.psq_rag.pipeline.preproc.done")
from psq_rag.llm.rewrite import llm_rewrite_prompt
_startup_profile_mark("import.psq_rag.llm.rewrite.done")
from psq_rag.llm.rewrite_local_t5 import local_t5_rewrite_prompt
_startup_profile_mark("import.psq_rag.llm.rewrite_local_t5.done")
from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases, _norm_tag_for_lookup
_startup_profile_mark("import.psq_rag.retrieval.psq_retrieval.done")
from psq_rag.llm.select import llm_select_indices, llm_infer_structural_tags, llm_infer_probe_tags
_startup_profile_mark("import.psq_rag.llm.select.done")
from psq_rag.retrieval.state import (
expand_tags_via_implications,
get_tag_type_name,
get_tag_implications,
get_tag_counts,
get_alias2tags,
)
_startup_profile_mark("import.psq_rag.retrieval.state.done")
from psq_rag.ui.group_ranked_display import rank_groups_from_tfidf, _load_enabled_groups
_startup_profile_mark("import.psq_rag.ui.group_ranked_display.done")
# Filesystem locations, resolved relative to this module's directory.
APP_DIR = Path(__file__).parent
DOCS_DIR = APP_DIR / "docs"
ARCH_DIAGRAM_FILE = DOCS_DIR / "assets" / "architecture_overview.png"
# Placeholder in the about-docs markdown that marks where the diagram is injected.
ARCH_DIAGRAM_MARKER = "{{ARCHITECTURE_DIAGRAM}}"
# NOTE(review): _split_about_docs_for_diagram matches this heading with its own
# regex rather than reading this constant — confirm whether it is used elsewhere.
ARCH_DIAGRAM_INSERT_BEFORE_HEADING = "## What Each Step Does"
# Model identifiers; presumably the OpenRouter defaults for the LLM stages — verify at the call sites.
DEFAULT_STAGE_OPENROUTER_MODEL = "mistralai/mistral-small-24b-instruct-2501"
DEFAULT_STAGE_OPENROUTER_FALLBACK_MODEL = "meta-llama/llama-3.1-8b-instruct"
DEFAULT_PROMPT_EXAMPLE = (
    "A feral fox with blue eyes and orange fur standing and looking at the viewer "
    "with a calm expression in a dense forest"
)
# Prepended to the user prompt before it is handed to the local T5 rewriter.
T5_REWRITE_PROMPT_PREFIX = "The image showcases "
# Labels for the suggested-prompt textbox; the "Thinking..." variants name pipeline phases.
SUGGESTED_PROMPT_LABEL_READY = "Suggested Prompt (Read-only)"
SUGGESTED_PROMPT_LABEL_PHASE_START = "Thinking... Starting"
SUGGESTED_PROMPT_LABEL_PHASE_REWRITE = "Thinking... Rewriting"
SUGGESTED_PROMPT_LABEL_PHASE_SELECT = "Thinking... Retrieving + Selecting"
def _redact_console_error_text(err: Any) -> str:
    """Redact provider/user identifiers and token-like substrings from console logs.

    Three substitutions are applied in order: ``user_<id>`` identifiers
    (``user_id`` itself is left alone), ``Bearer <token>`` credentials, and
    ``sk-``/``or-`` prefixed keys. A falsy ``err`` yields an empty string.
    """
    redactions = (
        (r"\buser_(?!id\b)[A-Za-z0-9]{6,}\b", "user_<redacted>"),
        (r"(?i)\b(bearer)\s+[A-Za-z0-9._:-]+\b", r"\1 <redacted>"),
        (r"\b(sk|or)-[A-Za-z0-9._-]+\b", r"\1-<redacted>"),
    )
    scrubbed = str(err or "")
    for pattern, replacement in redactions:
        scrubbed = re.sub(pattern, replacement, scrubbed)
    return scrubbed
# Hard-block patterns applied to normalized tags. Every alternative must match a
# whole underscore-delimited token: `(^|_)` before it and `(_|$)` after it.
_CORPORATE_HARDBLOCK_PATTERNS = [
    # Rating-like explicitness markers.
    re.compile(r"(^|_)(nsfw|explicit|questionable)(_|$)", re.IGNORECASE),
    # Unambiguous sexual anatomy.
    re.compile(
        r"(^|_)(breast|breasts|boob|boobs|nipple|nipples|penis|vagina|pussy|clit|testicle|scrotum|genital|crotch|anus|anal|areola)(_|$)",
        re.IGNORECASE,
    ),
    # Unambiguous sexual activity.
    # `masturbat`, `penetrat` and `ejaculat` are word stems, not words: with the
    # token anchors they could never match a real tag, so each stem is allowed
    # to absorb the rest of its token via `[a-z]*` (e.g. "penetration").
    re.compile(
        r"(^|_)(sex|sexual|fucking|fuck|blowjob|handjob|masturbat[a-z]*|penetrat[a-z]*|thrust|orgasm|cum|ejaculat[a-z]*|creampie|nude|naked|topless|bottomless|moan|sexy)(_|$)",
        re.IGNORECASE,
    ),
    # Common kink/fetish markers.
    re.compile(r"(^|_)(fetish|bdsm|bondage|dominatrix|submission|vore|inflation|watersports)(_|$)", re.IGNORECASE),
]
def _split_prompt_commas(s: str) -> List[str]:
    """Split a comma-separated prompt into trimmed, non-empty pieces.

    ``None`` or an empty string produces an empty list.
    """
    pieces: List[str] = []
    for chunk in (s or "").split(","):
        cleaned = chunk.strip()
        if cleaned:
            pieces.append(cleaned)
    return pieces
def _norm_for_dedupe(tag: str) -> str:
    """Return the key used to detect duplicate tags.

    The tag is lower-cased and then passed through ``_norm_tag_for_lookup``,
    the same canonical form used for lookups.
    """
    lowered = tag.lower()
    return _norm_tag_for_lookup(lowered)
def _env_flag(name: str, default: bool) -> bool:
    """Read a boolean environment flag.

    An unset or blank variable returns ``default``. Any other value is true
    unless it is one of ``0``, ``false``, ``no`` or ``off`` (case-insensitive).
    """
    value = (os.environ.get(name, "") or "").strip().lower()
    if value:
        return value not in ("0", "false", "no", "off")
    return default
def _env_int(name: str, default: int, minimum: int = 1) -> int:
    """Read an integer environment variable, clamped to at least ``minimum``.

    An unset, blank or unparsable value falls back to ``default``, which is
    clamped the same way.
    """
    text = (os.environ.get(name, "") or "").strip()
    if text:
        try:
            return max(minimum, int(text))
        except Exception:
            # Unparsable value: fall through to the default below.
            pass
    return max(minimum, int(default))
def _get_rewrite_source() -> str:
    """Return the configured rewrite backend: ``"llm"`` or ``"t5"``.

    Read from ``PSQ_REWRITE_SOURCE``; anything other than ``llm`` (including
    unset or blank) resolves to ``"t5"``.
    """
    requested = (os.environ.get("PSQ_REWRITE_SOURCE", "t5") or "t5").strip().lower()
    return "llm" if requested == "llm" else "t5"
def _rewrite_prompt(prompt_in: str, log) -> str:
    """Rewrite the user prompt with the configured backend.

    With source ``llm`` the prompt goes straight to ``llm_rewrite_prompt``.
    Otherwise the local T5 rewriter is used, configured through the
    ``PSQ_T5_REWRITE_*`` environment variables. If T5 returns nothing, the LLM
    is tried only when ``PSQ_T5_REWRITE_FALLBACK_TO_LLM`` is enabled (off by
    default); otherwise an empty string is returned.

    ``log`` is a callable taking one message string.
    """
    if _get_rewrite_source() == "llm":
        return llm_rewrite_prompt(prompt_in, log)

    default_model_dir = "models/finetune/t5-rewrite-n30best-20260508/checkpoint-18000"
    model_dir = (
        os.environ.get("PSQ_T5_REWRITE_MODEL_DIR", default_model_dir) or default_model_dir
    ).strip()
    beams = _env_int("PSQ_T5_REWRITE_NUM_BEAMS", 4, minimum=1)
    token_budget = _env_int("PSQ_T5_REWRITE_MAX_NEW_TOKENS", 128, minimum=8)

    # Keep T5 input style aligned with caption training distribution.
    rewritten = local_t5_rewrite_prompt(
        f"{T5_REWRITE_PROMPT_PREFIX}{prompt_in}",
        log=log,
        model_dir=model_dir,
        num_beams=beams,
        max_new_tokens=token_budget,
    )
    if rewritten:
        return rewritten

    # Keep wiring available, but default to no LLM fallback for T5 rewrite.
    if not _env_flag("PSQ_T5_REWRITE_FALLBACK_TO_LLM", False):
        return ""
    log("Rewrite source=t5 returned empty; fallback to llm")
    return llm_rewrite_prompt(prompt_in, log)
def compose_final_prompt(rewritten_prompt: str, selected_tags: List[str]) -> str:
    """Join the rewritten prompt's comma-separated parts with the selected tags.

    Entries are de-duplicated by their ``_norm_for_dedupe`` key; the first
    spelling seen wins and the original order is preserved. The result is a
    single ``", "``-joined string.
    """
    first_by_key: Dict[str, str] = {}
    for piece in [*_split_prompt_commas(rewritten_prompt), *selected_tags]:
        first_by_key.setdefault(_norm_for_dedupe(piece), piece)
    return ", ".join(first_by_key.values())
def _display_tag_text(tag: str) -> str:
    """Return the tag with every underscore shown as a space."""
    return " ".join(tag.split("_"))
def _display_row_label(name: str) -> str:
    """Turn an internal row name into its display label.

    ``selected_other`` has a fixed label; any other name is title-cased with
    underscores shown as spaces. A blank or missing name gives ``""``.
    """
    key = (name or "").strip()
    if key == "selected_other":
        return "Selected (Other)"
    return key.replace("_", " ").title() if key else ""
def _normalize_selection_origin(origin: str) -> str:
    """Canonicalize a selection-origin label.

    Known origins are returned trimmed and lower-cased; anything else,
    including ``None`` or blank, becomes ``"selection"``.
    """
    known = ("rewrite", "selection", "probe", "structural", "user", "candidate")
    cleaned = (origin or "").strip().lower()
    return cleaned if cleaned in known else "selection"
def _choice_label_with_source_meta(tag: str, *, origin: str, preselected: bool) -> str:
    """Return the checkbox label for ``tag``.

    ``origin`` and ``preselected`` are accepted but intentionally unused:
    labels stay plain to avoid frontend text/value desynchronization.
    """
    label = _display_tag_text(tag)
    return label
@lru_cache(maxsize=1)
def _load_tag_wiki_defs() -> Dict[str, str]:
    """Load tag definitions from ``data/tag_wiki_defs.json`` (cached).

    Keys are normalized with ``_norm_tag_for_lookup`` and values have their
    whitespace collapsed; entries with an empty key or text are dropped.
    A missing file, a non-object payload or any read/parse error yields ``{}``.
    """
    defs_path = Path("data/tag_wiki_defs.json")
    if not defs_path.exists():
        return {}
    try:
        with defs_path.open("r", encoding="utf-8") as handle:
            payload = json.load(handle)
        if not isinstance(payload, dict):
            return {}
        defs: Dict[str, str] = {}
        for raw_tag, raw_text in payload.items():
            tag = _norm_tag_for_lookup(str(raw_tag))
            text = " ".join(str(raw_text or "").split())
            if tag and text:
                defs[tag] = text
        return defs
    except Exception:
        return {}
@lru_cache(maxsize=1)
def _load_tag_tooltip_overrides() -> Dict[str, str]:
    """Load optional per-tag tooltip text overrides (cached).

    File format:
        data/tag_tooltip_overrides.csv
        columns: tag, tooltip_override

    Tags are normalized with ``_norm_tag_for_lookup`` and tooltip whitespace
    is collapsed; rows with an empty tag or empty text are skipped. A missing
    file or any read/parse error yields ``{}``.
    """
    p = Path("data/tag_tooltip_overrides.csv")
    if not p.exists():
        return {}
    try:
        out: Dict[str, str] = {}
        with p.open("r", encoding="utf-8", newline="") as f:
            for row in csv.DictReader(f):
                # DictReader fills cells missing from a short row with None;
                # `or ""` keeps that from becoming the literal text "None".
                tag = _norm_tag_for_lookup(str(row.get("tag") or ""))
                text = " ".join(str(row.get("tooltip_override") or "").split())
                if tag and text:
                    out[tag] = text
        return out
    except Exception:
        return {}
@lru_cache(maxsize=1)
def _load_about_docs_markdown() -> str:
    """Return the markdown for the about/docs view (cached).

    ``docs/space_overview.md`` is tried first, then ``PROJECT_SUMMARY.md``
    next to this module. The first file that exists, is readable and is
    non-empty after leading ``---`` front matter is stripped wins. If none
    qualifies, a short "documentation unavailable" notice is returned.
    """
    for doc_path in (DOCS_DIR / "space_overview.md", APP_DIR / "PROJECT_SUMMARY.md"):
        if not doc_path.exists():
            continue
        try:
            body = doc_path.read_text(encoding="utf-8").strip()
        except Exception:
            continue
        # Strip YAML front matter if present.
        if body.startswith("---"):
            pieces = body.split("---", 2)
            if len(pieces) >= 3:
                body = pieces[2].strip()
        if body:
            return body
    return (
        "Documentation is unavailable.\n\n"
        "Expected file: `docs/space_overview.md`"
    )
def _tooltip_text_for_tag(tag: str) -> str:
    """Build the tooltip text for a tag.

    The tooltip has up to two lines: ``Count: N`` when an integer count is
    available, then a description taken from the tooltip overrides or,
    failing that, the wiki definitions. Returns ``""`` when neither exists.
    """
    key = _norm_tag_for_lookup(tag)
    try:
        count = get_tag_counts().get(key)
    except Exception:
        count = None

    lines: List[str] = []
    if isinstance(count, int):
        lines.append(f"Count: {count:,}")

    # The wiki definitions are only consulted when no override text exists.
    description = _load_tag_tooltip_overrides().get(key, "") or _load_tag_wiki_defs().get(key, "")
    if description:
        lines.append(description)
    return "\n".join(lines).strip()
@lru_cache(maxsize=1)
def _load_arch_diagram_data_uri() -> str:
    """Return the architecture diagram as a base64 PNG ``data:`` URI (cached).

    Returns ``""`` when the file is missing, unreadable or empty.
    """
    if not ARCH_DIAGRAM_FILE.exists():
        return ""
    try:
        png_bytes = ARCH_DIAGRAM_FILE.read_bytes()
    except Exception:
        return ""
    if not png_bytes:
        return ""
    encoded = base64.b64encode(png_bytes).decode("ascii")
    return "data:image/png;base64," + encoded
def _split_about_docs_for_diagram(md: str) -> Tuple[str, str, bool]:
    """Split the about-docs markdown around the diagram insertion point.

    Returns ``(before, after, found)``. Insertion points are tried in order:

    1. the last ``ARCH_DIAGRAM_MARKER`` (the marker itself is removed);
    2. a ``## Architecture At A Glance`` heading (the heading is removed);
    3. a ``## What Each Step Does`` heading (the heading stays in ``after``).

    When none is present the whole text is returned as ``before`` with
    ``found`` set to ``False``.
    """
    doc = (md or "").strip()

    if ARCH_DIAGRAM_MARKER in doc:
        head, _, tail = doc.rpartition(ARCH_DIAGRAM_MARKER)
        return head.strip(), tail.strip(), True

    # Backward compatibility if an explicit architecture heading exists in docs.
    arch_heading = re.search(r"(?m)^##\s+Architecture At A Glance\s*$", doc)
    if arch_heading:
        return doc[: arch_heading.start()].strip(), doc[arch_heading.end() :].strip(), True

    # Preferred insertion point: inject diagram right before "What Each Step Does".
    steps_heading = re.search(r"(?m)^##\s+What Each Step Does\s*$", doc)
    if steps_heading:
        cut = steps_heading.start()
        return doc[:cut].strip(), doc[cut:].strip(), True

    return doc, "", False
def _build_arch_diagram_html() -> str:
    """Return the HTML snippet that embeds the architecture diagram.

    The image is inlined as the data URI from ``_load_arch_diagram_data_uri``.
    When that URI is empty a short "unavailable" placeholder is returned
    instead.
    """
    uri = _load_arch_diagram_data_uri()
    if not uri:
        return "<p><code>(architecture diagram unavailable)</code></p>"
    return f"""
<div class="arch-diagram-wrap">
<h2>Architecture At A Glance</h2>
<img src="{uri}" alt="Architecture diagram" />
</div>
"""
def _selection_source_rank(origin: str) -> int:
    """Return the row-ordering priority for a selection origin (lower sorts first).

    ``structural`` is 0 and ``probe`` is 1. Every other origin, including
    rewrite and user, shares band 2 with general selection.
    """
    priority = {"structural": 0, "probe": 1}
    return priority.get(_normalize_selection_origin(origin), 2)
def _build_implied_parent_map(
    direct_tags_ordered: List[str],
    implied_tags: List[str],
) -> Dict[str, str]:
    """Map each implied tag to the direct tag that reaches it.

    For every direct tag, in the given order, the implication graph from
    ``get_tag_implications`` is walked transitively. An implied tag is
    attributed to the first direct tag whose walk reaches it. Implied tags
    that no direct tag reaches are absent from the result.
    """
    wanted = {_norm_tag_for_lookup(t) for t in (implied_tags or []) if t}
    if not (wanted and direct_tags_ordered):
        return {}

    implications = get_tag_implications()
    owner_of: Dict[str, str] = {}
    for direct in direct_tags_ordered:
        root = _norm_tag_for_lookup(direct)
        if not root:
            continue
        visited = {root}
        pending = [root]
        while pending:
            current = pending.pop()
            for raw_parent in implications.get(current, ()):
                ancestor = _norm_tag_for_lookup(raw_parent)
                if not ancestor or ancestor in visited:
                    continue
                visited.add(ancestor)
                if ancestor in wanted:
                    # Earlier direct tags keep ownership of an implied tag.
                    owner_of.setdefault(ancestor, root)
                pending.append(ancestor)
    return owner_of
def _order_selected_tags_for_row(
    *,
    row_selected_tags: List[str],
    selected_index: Dict[str, int],
    tag_selection_origins: Dict[str, str],
    implied_parent_map: Dict[str, str],
) -> List[str]:
    """Order a row's selected tags so implied tags follow their parent.

    Tags that are not keys of ``implied_parent_map`` are sorted by origin
    rank, then position in ``selected_index``, then name. Each is followed by
    the implied tags it owns, sorted by position then name. Implied tags whose
    parent is not in this row are appended last, ordered by the parent's rank
    and position, then their own position and name. The result holds
    normalized tags without duplicates.
    """
    unranked = 10**9
    normalized = [_norm_tag_for_lookup(t) for t in (row_selected_tags or []) if t]

    def _origin_rank(tag: str) -> int:
        return _selection_source_rank(tag_selection_origins.get(tag, "selection"))

    def _position(tag: str) -> int:
        return selected_index.get(tag, unranked)

    implied_here = {t for t in normalized if t in implied_parent_map}
    direct_here = sorted(
        (t for t in normalized if t not in implied_here),
        key=lambda t: (_origin_rank(t), _position(t), t),
    )

    children_of: Dict[str, List[str]] = {}
    for child in implied_here:
        owner = implied_parent_map.get(child)
        if owner:
            children_of.setdefault(owner, []).append(child)
    for siblings in children_of.values():
        siblings.sort(key=lambda t: (_position(t), t))

    ordered: List[str] = []
    emitted: Set[str] = set()

    def _emit(tag: str) -> None:
        if tag not in emitted:
            ordered.append(tag)
            emitted.add(tag)

    for tag in direct_here:
        if tag in emitted:
            continue
        _emit(tag)
        for child in children_of.get(tag, []):
            _emit(child)

    # Whatever is left is an implied tag whose parent is not shown in this row.
    leftovers = [t for t in normalized if t not in emitted]
    leftovers.sort(
        key=lambda t: (
            _origin_rank(implied_parent_map.get(t, "")),
            _position(implied_parent_map.get(t, "")),
            _position(t),
            t,
        )
    )
    for tag in leftovers:
        _emit(tag)
    return ordered
def _escape_prompt_tag(tag: str) -> str:
    """Format a tag for the prompt text.

    Underscores become spaces and parentheses are backslash-escaped.
    """
    table = str.maketrans({"_": " ", "(": "\\(", ")": "\\)"})
    return tag.translate(table)
def _ordered_selected_for_prompt(selected: Set[str], row_defs: List[Dict[str, Any]]) -> List[str]:
    """Return the selected tags in on-screen order.

    Rows are scanned top to bottom and tags left to right; each selected tag
    is emitted once, at its first appearance. Selected tags that appear in no
    row are omitted.
    """
    in_order: Dict[str, None] = {}
    for row in row_defs:
        for tag in row.get("tags", []):
            if tag in selected:
                in_order.setdefault(tag, None)
    return list(in_order)
def _compose_toggle_prompt_text(selected_tags: List[str], row_defs: List[Dict[str, Any]]) -> str:
    """Render the selected tags as prompt text.

    Tags are taken in on-screen row order, escaped for the prompt and joined
    with ``", "``. Falsy entries in ``selected_tags`` are ignored.
    """
    chosen = set(filter(None, selected_tags or []))
    in_display_order = _ordered_selected_for_prompt(chosen, row_defs or [])
    return ", ".join(map(_escape_prompt_tag, in_display_order))
def _is_artist_tag(tag: str) -> bool:
    """Return whether ``tag`` is an artist tag.

    True when the tag's type name is ``artist`` or the normalized tag starts
    with ``by_``; a tag that normalizes to empty is never an artist tag.
    """
    name = _norm_tag_for_lookup(str(tag))
    if not name:
        return False
    if get_tag_type_name(name) == "artist":
        return True
    # Keep a resilient fallback for malformed/missing tag typing metadata.
    return name.startswith("by_")
@lru_cache(maxsize=1)
def _load_excluded_recommendation_tags() -> Set[str]:
    """Load the set of normalized tags that must not be recommended (cached).

    Two optional CSV sources are merged:

    * the category registry (``data/category_registry.csv``, falling back to
      ``data/analysis/category_registry.csv``): rows whose ``category_status``
      is ``excluded``;
    * ``data/corporate_excluded_tags.csv``: rows whose ``enabled`` column is
      not one of ``0``/``false``/``no``/``off`` (a missing column counts as
      enabled).

    Loading stays best-effort: a failure in one file keeps whatever was read
    before the error and does not stop the other file from loading. The
    failure is now logged as a warning, because a silently missing exclusion
    list would let excluded tags back into recommendations unnoticed.
    """
    out: Set[str] = set()
    logger = logging.getLogger(__name__)

    # Existing category-registry driven exclusions.
    csv_path = Path("data/category_registry.csv")
    if not csv_path.exists():
        csv_path = Path("data/analysis/category_registry.csv")
    if csv_path.exists():
        try:
            with csv_path.open("r", encoding="utf-8", newline="") as f:
                for row in csv.DictReader(f):
                    tag = _norm_tag_for_lookup(str(row.get("tag") or ""))
                    if not tag:
                        continue
                    status = str(row.get("category_status") or "").strip().lower()
                    if status == "excluded":
                        out.add(tag)
        except Exception:
            logger.warning(
                "Could not fully read category exclusions from %s; continuing with entries read so far",
                csv_path,
                exc_info=True,
            )

    # Corporate-safety exclusions (editable runtime list).
    corp_path = Path("data/corporate_excluded_tags.csv")
    if corp_path.exists():
        try:
            with corp_path.open("r", encoding="utf-8", newline="") as f:
                for row in csv.DictReader(f):
                    tag = _norm_tag_for_lookup(str(row.get("tag") or ""))
                    if not tag:
                        continue
                    enabled_raw = str(row.get("enabled", "1")).strip().lower()
                    if enabled_raw not in {"0", "false", "no", "off"}:
                        out.add(tag)
        except Exception:
            logger.warning(
                "Could not fully read corporate exclusions from %s; continuing with entries read so far",
                corp_path,
                exc_info=True,
            )
    return out
def _is_hardblocked_corporate_tag(tag: str) -> bool:
    """Return whether the normalized tag matches any hard-block pattern."""
    name = _norm_tag_for_lookup(str(tag))
    if not name:
        return False
    for pattern in _CORPORATE_HARDBLOCK_PATTERNS:
        if pattern.search(name):
            return True
    return False
def _is_excluded_recommendation_tag(tag: str) -> bool:
    """Return whether a tag must be kept out of recommendations.

    A tag is excluded when it matches a hard-block pattern or is listed in
    the loaded exclusion set. A tag that normalizes to empty is not excluded.
    """
    name = _norm_tag_for_lookup(str(tag))
    if not name:
        return False
    return _is_hardblocked_corporate_tag(name) or name in _load_excluded_recommendation_tags()
def _get_min_tag_count() -> int:
    """Return the minimum tag count from ``PSQ_MIN_TAG_COUNT``.

    Defaults to 100 when the variable is unset or not an integer; negative
    values are clamped to 0.
    """
    configured = os.environ.get("PSQ_MIN_TAG_COUNT", "100")
    try:
        parsed = int(configured)
    except Exception:
        return 100
    return max(0, parsed)
def _filter_min_count_tags(tags: List[str], min_count: int) -> Tuple[List[str], List[str]]:
    """Drop tags whose corpus count is below ``min_count``.

    Returns ``(keep, removed)``. ``keep`` holds normalized, de-duplicated tags
    in first-seen order; ``removed`` is the sorted list of normalized tags
    that fell below the threshold. A tag with no recorded count is treated as
    count 0. With ``min_count <= 0`` nothing is removed.

    Both paths normalize through ``_dedupe_norm_tags``. The threshold-free
    path used to check only the raw value's truthiness, so a tag that
    normalized to empty could be returned there but not on the filtering path.
    """
    normalized = _dedupe_norm_tags(tags)
    if min_count <= 0:
        return normalized, []

    tag_counts = get_tag_counts()
    keep: List[str] = []
    removed: List[str] = []
    for t in normalized:
        if int(tag_counts.get(t, 0) or 0) < min_count:
            removed.append(t)
        else:
            keep.append(t)
    return keep, sorted(removed)
def _filter_excluded_recommendation_tags(tags: List[str]) -> Tuple[List[str], List[str]]:
    """Drop tags listed in the loaded exclusion set.

    Returns ``(keep, removed)``. ``keep`` holds normalized, de-duplicated tags
    in first-seen order; ``removed`` is the sorted list of normalized tags
    found in the exclusion set. Only the loaded set is consulted here, not
    the hard-block patterns.

    Both paths normalize through ``_dedupe_norm_tags``. The empty-exclusion
    path used to check only the raw value's truthiness, so a tag that
    normalized to empty could be returned there but not on the filtering path.
    """
    excluded = _load_excluded_recommendation_tags()
    normalized = _dedupe_norm_tags(tags)
    if not excluded:
        return normalized, []

    keep = [t for t in normalized if t not in excluded]
    removed = sorted(t for t in normalized if t in excluded)
    return keep, removed
def _filter_excluded_candidates(candidates: List[Any]) -> Tuple[List[Any], List[str]]:
    """Drop candidate objects whose ``tag`` is in the loaded exclusion set.

    Returns ``(keep, removed)``: the surviving candidates in their original
    order and the sorted, unique normalized tags that were dropped.
    Candidates without a usable ``tag`` attribute are kept.
    """
    excluded = _load_excluded_recommendation_tags()
    if not excluded:
        return list(candidates or []), []

    survivors: List[Any] = []
    dropped: Set[str] = set()
    for candidate in candidates or []:
        tag = _norm_tag_for_lookup(str(getattr(candidate, "tag", "") or ""))
        if tag and tag in excluded:
            dropped.add(tag)
        else:
            survivors.append(candidate)
    return survivors, sorted(dropped)
def _dedupe_norm_tags(tags: List[str]) -> List[str]:
    """Normalize tags and drop empties and repeats, keeping first-seen order."""
    unique: Dict[str, None] = {}
    for raw in tags or []:
        name = _norm_tag_for_lookup(str(raw))
        if name:
            unique.setdefault(name, None)
    return list(unique)
def _collect_visible_tags(row_defs: List[Dict[str, Any]]) -> Set[str]:
    """Return the set of normalized tags across all rows.

    Entries of ``row_defs`` that are not dicts contribute nothing.
    """
    visible: Set[str] = set()
    for row in row_defs or []:
        row_tags = row.get("tags", []) if isinstance(row, dict) else []
        visible.update(_dedupe_norm_tags(row_tags))
    return visible
def _collect_selected_from_state(
    selected_tags_state: List[str],
    row_defs: List[Dict[str, Any]],
) -> List[str]:
    """Return the selected-state tags that are visible in ``row_defs``.

    Each state entry is normalized and kept only if it maps to a visible tag.
    The result preserves state order and has no duplicates. With no visible
    tags the result is empty.
    """
    visible = _collect_visible_tags(row_defs)
    if not visible:
        return []

    canonical_by_norm = {_norm_tag_for_lookup(t): t for t in visible}
    picked: Dict[str, None] = {}
    for raw in selected_tags_state or []:
        name = _norm_tag_for_lookup(str(raw))
        if not name:
            continue
        canonical = name if name in visible else canonical_by_norm.get(name)
        if canonical:
            picked.setdefault(canonical, None)
    return list(picked)
def _collect_selected_from_row_values(
    row_defs: List[Dict[str, Any]],
    row_values_state: List[List[str]],
) -> List[str]:
    """Return the selected tags implied by the per-row checkbox values.

    ``row_values_state[i]`` holds the checked values of row ``i``. A value
    counts only if it is one of that row's tags, either as given or after
    normalization. The result follows row order, then value order, without
    duplicates.
    """
    values = list(row_values_state or [])
    picked: Dict[str, None] = {}
    for idx, row in enumerate(row_defs or []):
        row_tags = _dedupe_norm_tags(row.get("tags", []) if isinstance(row, dict) else [])
        if not row_tags:
            continue
        known = set(row_tags)
        by_norm = {_norm_tag_for_lookup(t): t for t in row_tags}
        checked = values[idx] if idx < len(values) else []
        for raw in checked or []:
            if raw in known:
                canonical = raw
            else:
                canonical = by_norm.get(_norm_tag_for_lookup(str(raw)))
            if canonical:
                picked.setdefault(canonical, None)
    return list(picked)
def _build_toggle_rows(
    *,
    seed_terms: List[str],
    selected_tags: List[str],
    retrieved_candidate_tags: List[str],
    tag_selection_origins: Dict[str, str],
    implied_parent_map: Dict[str, str],
    top_groups: int,
    top_tags_per_group: int,
    group_rank_top_k: int,
) -> List[Dict[str, Any]]:
    """Assemble the row definitions for the tag-toggle UI.

    Each returned row is a dict with ``name``, ``label``, ``tags`` (ordered,
    normalized) and ``tag_meta`` (per-tag ``origin`` / ``preselected``).
    Rows are emitted in this order:

    1. ``selected_other``: selected tags not shown in any other row
       (always appended, even when empty);
    2. one row per group returned by ``rank_groups_from_tfidf``;
    3. ``other_retrieved``: retrieved candidates that belong to no enabled
       group (only when there are any).

    Artist tags and excluded recommendation tags are filtered out everywhere.
    Within a row, selected tags come first and are never truncated; the row
    is otherwise capped at ``top_tags_per_group`` entries.
    """
    ranked_rows = rank_groups_from_tfidf(
        seed_terms=seed_terms,
        top_groups=max(1, int(top_groups)),
        top_tags_per_group=max(1, int(top_tags_per_group)),
        group_rank_top_k=max(1, int(group_rank_top_k)),
    )
    groups_map = _load_enabled_groups()
    # Selected tags that may be displayed: normalized, de-duplicated, order kept.
    selected_active = list(
        dict.fromkeys(
            _norm_tag_for_lookup(t)
            for t in selected_tags
            if t and not _is_artist_tag(t) and not _is_excluded_recommendation_tag(t)
        )
    )
    # Position of each selected tag, used as a sort key when ordering rows.
    selected_index: Dict[str, int] = {t: i for i, t in enumerate(selected_active)}
    selected_set: Set[str] = set(selected_active)

    def _is_pipeline_preselected(tag: str) -> bool:
        # Tags explicitly user-toggled should keep user coloring after row rebuilds.
        # Only non-user selected tags should render as pipeline-preselected.
        if tag not in selected_set:
            return False
        return _normalize_selection_origin(tag_selection_origins.get(tag, "selection")) != "user"

    row_defs: List[Dict[str, Any]] = []
    # Non-artist members of every enabled group, keyed by group name.
    enabled_group_tag_sets: Dict[str, Set[str]] = {
        name: {t for t in tags if not _is_artist_tag(t)}
        for name, tags in groups_map.items()
    }
    tags_in_any_enabled_group: Set[str] = set()
    for tag_set in enabled_group_tag_sets.values():
        tags_in_any_enabled_group.update(tag_set)
    # Restrict to the groups that were actually ranked for display.
    displayed_group_names = [r.group_name for r in ranked_rows]
    displayed_group_tag_sets: Dict[str, Set[str]] = {
        name: enabled_group_tag_sets.get(name, set())
        for name in displayed_group_names
    }
    tags_in_any_displayed_group: Set[str] = set()
    for tag_set in displayed_group_tag_sets.values():
        tags_in_any_displayed_group.update(tag_set)
    # Retrieved candidates that no enabled group claims, in retrieval order.
    retrieved_uncategorized_ranked = list(
        dict.fromkeys(
            _norm_tag_for_lookup(t)
            for t in (retrieved_candidate_tags or [])
            if t
            and not _is_artist_tag(t)
            and not _is_excluded_recommendation_tag(t)
            and _norm_tag_for_lookup(t) not in tags_in_any_enabled_group
        )
    )
    retrieved_other_row: Dict[str, Any] | None = None
    if retrieved_uncategorized_ranked:
        retrieved_uncategorized_set = set(retrieved_uncategorized_ranked)
        selected_in_retrieved_other_raw = [
            t for t in selected_active if t in retrieved_uncategorized_set
        ]
        selected_in_retrieved_other = _order_selected_tags_for_row(
            row_selected_tags=selected_in_retrieved_other_raw,
            selected_index=selected_index,
            tag_selection_origins=tag_selection_origins,
            implied_parent_map=implied_parent_map,
        )
        # Selected tags lead the row; the remaining candidates keep retrieval order.
        merged_retrieved_other = selected_in_retrieved_other + [
            t for t in retrieved_uncategorized_ranked if t not in selected_in_retrieved_other
        ]
        merged_retrieved_other = _dedupe_norm_tags(merged_retrieved_other)
        # Cap the row, but never cut off a selected tag.
        keep_n = max(max(1, int(top_tags_per_group)), len(selected_in_retrieved_other))
        merged_retrieved_other = merged_retrieved_other[:keep_n]
        retrieved_other_meta = {
            t: {
                "origin": _normalize_selection_origin(tag_selection_origins.get(t, "selection")),
                "preselected": _is_pipeline_preselected(t),
            }
            for t in merged_retrieved_other
        }
        retrieved_other_row = {
            "name": "other_retrieved",
            "label": "Other (Retrieved)",
            "tags": merged_retrieved_other,
            "tag_meta": retrieved_other_meta,
        }
    # "Selected (Other)" should contain selected tags not already shown in any displayed row.
    # Include "Other (Retrieved)" in that displayed-row set to avoid duplicates across those rows.
    tags_in_displayed_rows = set(tags_in_any_displayed_group)
    if retrieved_other_row:
        tags_in_displayed_rows.update(retrieved_other_row.get("tags", []))
    selected_other_raw = [t for t in selected_active if t not in tags_in_displayed_rows]
    selected_other = _order_selected_tags_for_row(
        row_selected_tags=selected_other_raw,
        selected_index=selected_index,
        tag_selection_origins=tag_selection_origins,
        implied_parent_map=implied_parent_map,
    )
    selected_other = _dedupe_norm_tags(selected_other)
    selected_other_meta = {
        t: {
            "origin": _normalize_selection_origin(tag_selection_origins.get(t, "selection")),
            "preselected": _is_pipeline_preselected(t),
        }
        for t in selected_other
    }
    row_defs.append(
        {
            "name": "selected_other",
            "label": _display_row_label("selected_other"),
            "tags": selected_other,
            "tag_meta": selected_other_meta,
        }
    )
    for row in ranked_rows:
        group_name = row.group_name
        group_tag_set = displayed_group_tag_sets.get(group_name, set())
        selected_in_group_raw = [t for t in selected_active if t in group_tag_set]
        selected_in_group = _order_selected_tags_for_row(
            row_selected_tags=selected_in_group_raw,
            selected_index=selected_index,
            tag_selection_origins=tag_selection_origins,
            implied_parent_map=implied_parent_map,
        )
        # row.tags is iterated as (tag, <second element>) pairs; only the tag is used here.
        ranked_tags = [
            _norm_tag_for_lookup(t)
            for t, _ in row.tags
            if t and not _is_artist_tag(t) and not _is_excluded_recommendation_tag(t)
        ]
        ranked_tags = _dedupe_norm_tags(ranked_tags)
        merged = selected_in_group + [t for t in ranked_tags if t not in selected_in_group]
        merged = _dedupe_norm_tags(merged)
        # Cap the row, but never cut off a selected tag.
        keep_n = max(max(1, int(top_tags_per_group)), len(selected_in_group))
        merged = merged[:keep_n]
        tag_meta = {
            t: {
                "origin": _normalize_selection_origin(tag_selection_origins.get(t, "selection")),
                "preselected": _is_pipeline_preselected(t),
            }
            for t in merged
        }
        row_defs.append(
            {
                "name": group_name,
                "label": _display_row_label(group_name),
                "tags": merged,
                "tag_meta": tag_meta,
            }
        )
    # Keep this row at the bottom so category/group rows remain contiguous.
    if retrieved_other_row:
        row_defs.append(retrieved_other_row)
    return row_defs
def _build_display_audit_line(
    row_defs: List[Dict[str, Any]],
    *,
    active_selected_tags: List[str],
    direct_selected_tags: List[str],
    implied_selected_tags: List[str],
) -> str:
    """Build the one-line JSON audit of every displayed tag.

    For each tag in ``row_defs`` the audit lists the labels of the rows it
    appears in and a sorted set of source markers: the kind of row
    (``selected_other_row``, ``other_retrieved_row`` or ``ranked_group_row``)
    and which of the active / direct / implied selections contain it. Artist
    tags are ignored when building those three selections. The result is
    ``"Display Tag Audit: "`` followed by ASCII-only JSON, tags sorted by name.
    """

    def _non_artist_norm_set(tags: List[str]) -> Set[str]:
        return {
            _norm_tag_for_lookup(t)
            for t in (tags or [])
            if t and not _is_artist_tag(t)
        }

    memberships = (
        ("selected_active", _non_artist_norm_set(active_selected_tags)),
        ("selected_direct", _non_artist_norm_set(direct_selected_tags)),
        ("selected_implied", _non_artist_norm_set(implied_selected_tags)),
    )
    special_row_sources = {
        "selected_other": "selected_other_row",
        "other_retrieved": "other_retrieved_row",
    }

    audit: Dict[str, Dict[str, Any]] = {}
    for row in row_defs or []:
        row_name = row.get("name", "")
        row_label = row.get("label", row_name)
        row_source = special_row_sources.get(row_name, "ranked_group_row")
        for tag in row.get("tags", []):
            entry = audit.setdefault(tag, {"rows": [], "sources": set()})
            entry["rows"].append(row_label)
            entry["sources"].add(row_source)
            entry["sources"].update(
                marker for marker, members in memberships if tag in members
            )

    payload = {
        "n_tags": len(audit),
        "tags": [
            {"tag": tag, "rows": entry["rows"], "sources": sorted(entry["sources"])}
            for tag, entry in sorted(audit.items())
        ],
    }
    return "Display Tag Audit: " + json.dumps(payload, ensure_ascii=True)
def _build_tooltip_payload(row_defs: List[Dict[str, Any]], max_rows: int) -> str:
    """Serialize tooltip data for the first ``max_rows`` rows as JSON.

    The payload has three keys: ``rows`` (normalized tags per row),
    ``meta_rows`` (per-tag ``origin`` / ``preselected``, parallel to ``rows``)
    and ``tips`` (tooltip text per tag, computed once per distinct tag).
    Rows or metadata that are not dicts are treated as empty.
    """
    visible_rows = (row_defs or [])[: max(0, int(max_rows))]
    tips: Dict[str, str] = {}
    rows: List[List[str]] = []
    meta_rows: List[List[Dict[str, Any]]] = []

    for row in visible_rows:
        is_mapping = isinstance(row, dict)
        tags = _dedupe_norm_tags(row.get("tags", []) if is_mapping else [])
        rows.append(tags)

        row_meta = row.get("tag_meta", {}) if is_mapping else {}
        if not isinstance(row_meta, dict):
            row_meta = {}

        per_tag: List[Dict[str, Any]] = []
        for tag in tags:
            raw_meta = row_meta.get(tag, {})
            if not isinstance(raw_meta, dict):
                raw_meta = {}
            per_tag.append(
                {
                    "origin": _normalize_selection_origin(str(raw_meta.get("origin", "selection"))),
                    "preselected": bool(raw_meta.get("preselected", False)),
                }
            )
        meta_rows.append(per_tag)

        for tag in tags:
            if tag not in tips:
                tips[tag] = _tooltip_text_for_tag(tag)

    return json.dumps({"rows": rows, "meta_rows": meta_rows, "tips": tips}, ensure_ascii=True)
def _build_row_component_updates(
    row_defs: List[Dict[str, Any]],
    selected_tags: List[str],
    max_rows: int,
):
    """Compute the Gradio updates for every row slot.

    Returns ``(prompt_text, row_values_state, header_updates,
    checkbox_updates)``. The two update lists each have ``max_rows`` entries;
    slots beyond the available rows are hidden and emptied, and a row with no
    tags is hidden. ``row_values_state`` holds the checked tags for each
    populated row, and ``prompt_text`` is composed from the selected tags in
    row order.
    """
    chosen = {t for t in (selected_tags or []) if t}
    visible_rows = (row_defs or [])[: max(0, int(max_rows))]
    row_values_state: List[List[str]] = []
    header_updates = []
    checkbox_updates = []

    for slot in range(max_rows):
        if slot >= len(visible_rows):
            # Unused slot: hide both the header and the checkbox group.
            header_updates.append(gr.update(value="", visible=False))
            checkbox_updates.append(gr.update(choices=[], value=[], visible=False))
            continue

        row = visible_rows[slot]
        tags = _dedupe_norm_tags(row.get("tags", []))
        checked = [t for t in tags if t in chosen]
        row_values_state.append(checked)
        show = bool(tags)
        header_updates.append(gr.update(value=row.get("label", ""), visible=show))

        meta_by_tag = row.get("tag_meta", {})
        if not isinstance(meta_by_tag, dict):
            meta_by_tag = {}
        choices = []
        for tag in tags:
            meta = meta_by_tag.get(tag, {})
            if not isinstance(meta, dict):
                meta = {}
            label = _choice_label_with_source_meta(
                tag,
                origin=_normalize_selection_origin(str(meta.get("origin", "selection"))),
                preselected=bool(meta.get("preselected", False)),
            )
            choices.append((label, tag))
        checkbox_updates.append(gr.update(choices=choices, value=checked, visible=show))

    prompt_text = _compose_toggle_prompt_text(list(chosen), visible_rows)
    return prompt_text, row_values_state, header_updates, checkbox_updates
def _on_toggle_row(
    row_idx: int,
    changed_values: List[str],
    selected_tags_state: List[str],
    rows_dirty_state: bool,
    row_defs_state: List[Dict[str, Any]],
    row_values_state: List[List[str]],
    max_rows: int,
):
    """Handle a change event from the checkbox group of row ``row_idx``.

    Returns a list in this order: the sorted selected tags, the rows-dirty
    flag, an update for one further component (``gr.skip()`` on a no-op,
    otherwise made visible and interactive — presumably a refresh control,
    confirm at the wiring site), the per-row checked values, the recomposed
    prompt text, and then one checkbox update per row slot (``max_rows``
    entries).

    When the incoming values equal what is already selected for the row, the
    event is treated as a no-op: state is passed through and every checkbox
    update is ``gr.skip()``. Otherwise the row's selection is replaced by the
    new values and every row containing a toggled tag gets a value update, so
    a tag shown in several rows stays in sync.
    """
    row_defs = row_defs_state or []
    row_defs_ui = row_defs[: max(0, int(max_rows))]
    prev_values = list(row_values_state or [])
    selected_from_state = _collect_selected_from_state(selected_tags_state, row_defs_ui)
    selected_from_rows = _collect_selected_from_row_values(row_defs_ui, prev_values)
    # Prefer row-value state as source-of-truth (closest to visible UI), with selected-state as fallback.
    selected: Set[str] = set(selected_from_rows or selected_from_state)
    # An out-of-range row index degrades to an empty row rather than raising.
    row = row_defs_ui[row_idx] if 0 <= row_idx < len(row_defs_ui) else {}
    row_tags = _dedupe_norm_tags(row.get("tags", []))
    row_label = str(row.get("label", ""))
    row_tag_set = set(row_tags)
    row_tag_by_norm = {_norm_tag_for_lookup(t): t for t in row_tags}
    # Be tolerant to UI payload forms: canonical tag values, display labels, normalized variants,
    # and occasional single-string payloads from frontend events.
    if changed_values is None:
        changed_iter: List[Any] = []
    elif isinstance(changed_values, str):
        changed_iter = [changed_values]
    elif isinstance(changed_values, (list, tuple, set)):
        changed_iter = list(changed_values)
    else:
        changed_iter = [changed_values]
    # Map each incoming value onto one of this row's tags; values that match
    # neither directly nor after normalization are dropped.
    new_set: Set[str] = set()
    for raw in changed_iter:
        if raw in row_tag_set:
            new_set.add(raw)
            continue
        raw_norm = _norm_tag_for_lookup(str(raw))
        mapped = row_tag_by_norm.get(raw_norm)
        if mapped:
            new_set.add(mapped)
    prev_row_selected = {t for t in row_tags if t in selected}
    # Ignore non-user/no-op events (e.g., programmatic value re-sets) deterministically.
    if new_set == prev_row_selected:
        prompt_text = _compose_toggle_prompt_text(sorted(selected), row_defs_ui)
        checkbox_updates = [gr.skip() for _ in range(max_rows)]
        return [sorted(selected), rows_dirty_state, gr.skip(), prev_values, prompt_text, *checkbox_updates]
    # Replace this row's contribution to the selection with the new values.
    selected.difference_update(row_tag_set)
    selected.update(new_set)
    # Tags whose checked state flipped in this event.
    toggled_tags = prev_row_selected ^ new_set
    new_row_values_state: List[List[str]] = []
    affected_rows: Set[int] = {row_idx}
    for idx, row_item in enumerate(row_defs_ui):
        tags = _dedupe_norm_tags(row_item.get("tags", []))
        values = [t for t in tags if t in selected]
        new_row_values_state.append(values)
        # A toggled tag may also be shown in other rows; those need a refresh too.
        if toggled_tags and any(t in toggled_tags for t in tags):
            affected_rows.add(idx)
    checkbox_updates = []
    for idx in range(max_rows):
        if idx >= len(row_defs_ui):
            checkbox_updates.append(gr.skip())
            continue
        if idx in affected_rows:
            checkbox_updates.append(gr.update(value=new_row_values_state[idx]))
        else:
            checkbox_updates.append(gr.skip())
    prompt_text = _compose_toggle_prompt_text(sorted(selected), row_defs_ui)
    return [
        sorted(selected),
        True,
        gr.update(visible=True, interactive=True),
        new_row_values_state,
        prompt_text,
        *checkbox_updates,
    ]
def _build_ui_payload(
    *,
    console_text: str,
    row_defs: List[Dict[str, Any]],
    selected_tags: List[str],
    suggested_prompt_text: str | None = None,
    suggested_prompt_label: str | None = None,
):
    """Assemble the complete list of UI output values for one result.

    Row-level updates come from ``_build_row_component_updates``.  The prompt
    text it produces is replaced by ``suggested_prompt_text`` whenever that
    argument is not ``None``; the prompt label falls back to
    ``SUGGESTED_PROMPT_LABEL_READY`` when ``suggested_prompt_label`` is empty.

    Returns a list laid out as: console text, a visibility update that is
    true only when ``row_defs`` is non-empty, the tooltip payload, the prompt
    value/label update, the selected tags de-duplicated in row order,
    ``False``, a hidden/non-interactive update, ``row_defs``, the per-row
    values, then all header updates followed by all checkbox updates.
    """
    built = _build_row_component_updates(
        row_defs=row_defs,
        selected_tags=selected_tags,
        max_rows=display_max_rows_default,
    )
    prompt_text, row_values_state, header_updates, checkbox_updates = built
    if suggested_prompt_text is not None:
        prompt_text = str(suggested_prompt_text)
    label = suggested_prompt_label or SUGGESTED_PROMPT_LABEL_READY

    # dict.fromkeys keeps first-seen order while dropping tags that appear in
    # more than one row.
    selected_ui: List[str] = list(
        dict.fromkeys(tag for row_values in row_values_state for tag in row_values)
    )

    tooltip_payload = _build_tooltip_payload(row_defs, display_max_rows_default)
    outputs: List[Any] = [
        console_text,
        gr.update(visible=bool(row_defs)),
        tooltip_payload,
        gr.update(value=prompt_text, label=label),
        selected_ui,
        # NOTE(review): these two slots line up with the "rows dirty" flag and
        # its button in the toggle handler's return list — confirm wiring.
        False,
        gr.update(visible=False, interactive=False),
        row_defs,
        row_values_state,
    ]
    outputs.extend(header_updates)
    outputs.extend(checkbox_updates)
    return outputs
def _format_user_facing_error(exc: Exception) -> str:
msg = str(exc or "").strip()
msg_l = msg.lower()
if "rewrite: empty output" in msg_l:
return (
"Could not rewrite that prompt. Try simpler, neutral wording and remove sensitive phrasing, "
"then click Run again."
)
if "openrouter_api_key" in msg_l:
return "Service configuration is missing. Please contact the app owner."
if "timed out" in msg_l:
return "The model request timed out. Please try again with a shorter or simpler prompt."
if "index selection failed" in msg_l:
return "Tag selection failed for this request. Please try again."
if "startup preflight failed" in msg_l:
return "App startup checks failed. Please contact the app owner."
return "Something went wrong while processing the prompt. Please try again."
def _prepare_run_ui() -> List[Any]:
    """Return the output values that reset the UI when a run starts.

    The list has the same layout as the one built by ``_build_ui_payload``:
    nine leading values followed by one header update and one checkbox update
    per display row (``display_max_rows_default`` of each).  Every row header
    and checkbox group is blanked and hidden, and the prompt box shows a
    "Starting..." placeholder under ``SUGGESTED_PROMPT_LABEL_PHASE_START``.
    """
    row_slots = range(display_max_rows_default)
    hidden_headers = [gr.update(value="", visible=False) for _ in row_slots]
    hidden_checkboxes = [
        gr.update(choices=[], value=[], visible=False) for _ in row_slots
    ]
    outputs: List[Any] = [
        "Running...",
        gr.skip(),
        "{}",
        gr.update(
            value="Starting... preparing query processing.",
            label=SUGGESTED_PROMPT_LABEL_PHASE_START,
        ),
        [],
        False,
        gr.update(visible=False, interactive=False),
        [],
        [],
    ]
    outputs += hidden_headers
    outputs += hidden_checkboxes
    return outputs
def _prepare_run_ui_with_rewrite_state() -> List[Any]:
    """Return the run-start reset outputs plus one trailing empty string.

    NOTE(review): the extra ``""`` presumably clears a stored rewrite value;
    the component it is wired to is not visible here.
    """
    reset_outputs = list(_prepare_run_ui())
    reset_outputs.append("")
    return reset_outputs
def _recognized_rewrite_tags(rewritten: str, min_tag_count: int) -> List[str]:
    """Return the phrases of a comma-separated rewrite that are known tags.

    Each non-blank phrase is normalised with ``_norm_tag_for_lookup`` and
    de-duplicated in first-seen order.  Empty results, artist tags and
    excluded recommendation tags are dropped.  A remaining tag is kept only
    if ``get_tag_counts()`` reports a positive count for it and, when
    ``min_tag_count`` is positive, that count is at least ``min_tag_count``.
    """
    ordered_unique: Dict[str, None] = {}
    for chunk in (rewritten or "").split(","):
        phrase = chunk.strip()
        if phrase:
            ordered_unique.setdefault(_norm_tag_for_lookup(phrase), None)

    usable = [
        tag
        for tag in ordered_unique
        if tag and not _is_artist_tag(tag) and not _is_excluded_recommendation_tag(tag)
    ]
    if not usable:
        # Nothing to look up, so the tag-count table is never requested.
        return []

    counts = get_tag_counts()

    def _is_frequent_enough(tag: str) -> bool:
        occurrences = int(counts.get(tag, 0) or 0)
        if occurrences <= 0:
            return False
        return not (min_tag_count > 0 and occurrences < min_tag_count)

    return [tag for tag in usable if _is_frequent_enough(tag)]
def _rewrite_preview_ui(
    user_prompt: str,
) -> List[Any]:
    """Run the rewrite step and return a preview UI payload for it.

    The result is the list from ``_build_ui_payload`` (with no rows and no
    selected tags) plus one trailing element: the rewritten text on success,
    or ``""`` when the prompt is blank or the rewrite fails.  Progress is
    collected into the console text.  Any exception raised while rewriting or
    building the success payload is logged (redacted) and turned into a
    "Rewrite failed" payload rather than propagated.
    """
    logs: List[str] = []

    def log(s: Any) -> None:
        logs.append(str(s))

    def _respond(console_text: str, prompt_text: str, label: Any, rewrite_state: str) -> List[Any]:
        # Every exit path shows an empty row set; only the texts differ.
        payload = _build_ui_payload(
            console_text=console_text,
            row_defs=[],
            selected_tags=[],
            suggested_prompt_text=prompt_text,
            suggested_prompt_label=label,
        )
        return [*payload, rewrite_state]

    prompt_in = (user_prompt or "").strip()
    if not prompt_in:
        return _respond("Error: empty prompt", 'Enter a prompt and click "Run".', None, "")

    for entry in (
        "Start: received prompt",
        "Input:",
        prompt_in,
        "",
        "Step 1/2: query reformulation preview",
    ):
        log(entry)

    try:
        started = time.perf_counter()
        rewritten = _rewrite_prompt(prompt_in, log)
        elapsed = time.perf_counter() - started
        if not rewritten:
            raise RuntimeError("Rewrite: empty output")
        log(f"Rewrite: {elapsed:.2f}s")
        log("Rewrite:")
        log(rewritten)

        recognized = _recognized_rewrite_tags(rewritten, min_tag_count=_get_min_tag_count())
        if not recognized:
            preview_text = "(No recognized rewrite tags yet; continuing...)"
            log("Preview: no recognized rewrite tags; continuing")
        else:
            escaped = ", ".join(_escape_prompt_tag(t) for t in recognized)
            preview_text = "Rewrite preview tags:\n" + escaped
            log(f"Preview: recognized {len(recognized)} rewrite tags")
        log("Step 2/2: retrieval + selection in progress...")

        # Built inside the try so a payload failure is reported like any other.
        return _respond(
            "\n".join(logs),
            preview_text,
            SUGGESTED_PROMPT_LABEL_PHASE_SELECT,
            rewritten,
        )
    except Exception as e:
        log(f"Rewrite preview failed: {type(e).__name__}: {_redact_console_error_text(e)}")
        return _respond(
            "\n".join(logs),
            "Rewrite failed; see console details.",
            SUGGESTED_PROMPT_LABEL_PHASE_REWRITE,
            "",
        )
def _update_run_button_visibility(prompt_text: str, last_run_prompt: str):
    """Return an update that shows and enables a component only when a run makes sense.

    That is the case when the trimmed prompt is non-empty and differs from
    the trimmed prompt of the last run; otherwise the component is hidden and
    disabled.
    """
    current = (prompt_text or "").strip()
    previous = (last_run_prompt or "").strip()
    runnable = bool(current and current != previous)
    return gr.update(visible=runnable, interactive=runnable)
def _mark_run_triggered(prompt_text: str):
    """Return a hide-and-disable update together with the trimmed prompt.

    NOTE(review): the second value presumably becomes the "last run prompt"
    compared in ``_update_run_button_visibility`` — confirm the wiring.
    """
    submitted = (prompt_text or "").strip()
    hidden = gr.update(visible=False, interactive=False)
    return hidden, submitted
def _rebuild_rows_from_selected(
    selected_tags_state: List[str],
    row_defs_state: List[Dict[str, Any]],
    row_values_state: List[List[str]],
    display_top_groups: float,
    display_top_tags_per_group: float,
    display_rank_top_k: float,
):
    """Regenerate the toggle rows from the tags that are currently selected.

    The selection is taken from the row checkbox values when any exist and
    from ``selected_tags_state`` otherwise.  It is normalised, de-duplicated,
    and stripped of artist and excluded tags.  Every tag in the existing rows
    (plus the selection itself) is offered to ``_build_toggle_rows`` as a
    candidate, together with a per-tag origin map.

    Returns a list: a visibility update (true when rows were produced), the
    tooltip payload, the prompt text, the sorted selection, ``False``, a
    hidden/non-interactive update, the new row definitions, the per-row
    values, then all header updates followed by all checkbox updates.
    """
    current_rows = row_defs_state or []
    current_values = list(row_values_state or [])
    from_state = _collect_selected_from_state(selected_tags_state, current_rows)
    from_rows = _collect_selected_from_row_values(current_rows, current_values)
    # Row checkbox values are the source of truth; the state list is only a
    # fallback for when no row values are available.
    seed_tags = from_rows if current_values else from_state

    ordered_unique: Dict[str, None] = {}
    for raw_tag in seed_tags:
        if not raw_tag or _is_artist_tag(raw_tag) or _is_excluded_recommendation_tag(raw_tag):
            continue
        ordered_unique.setdefault(_norm_tag_for_lookup(raw_tag), None)
    selected_active = list(ordered_unique)
    active_lookup: Set[str] = set(selected_active)

    candidates: List[str] = []
    origins: Dict[str, str] = {}
    for row in current_rows:
        row_is_dict = isinstance(row, dict)
        row_tags = row.get("tags", []) if row_is_dict else []
        row_meta = row.get("tag_meta", {}) if row_is_dict else {}
        if not isinstance(row_meta, dict):
            row_meta = {}
        for raw_tag in row_tags:
            tag = _norm_tag_for_lookup(raw_tag)
            if not tag or _is_artist_tag(tag) or _is_excluded_recommendation_tag(tag):
                continue
            candidates.append(tag)
            if tag in origins:
                # The first row that mentions a tag decides its origin.
                continue
            entry = row_meta.get(raw_tag, {})
            meta = entry if isinstance(entry, dict) else {}
            origin = _normalize_selection_origin(str(meta.get("origin", "selection")))
            # Preserve explicit user toggles across a rebuild: a selected tag
            # that the pipeline did not preselect keeps "user" provenance.
            if tag in active_lookup and not bool(meta.get("preselected", False)):
                origin = "user"
            origins[tag] = origin

    # Selected tags absent from every row still count as user-chosen
    # candidates.
    for tag in selected_active:
        origins.setdefault(tag, "user")
        candidates.append(tag)

    implied_tags = [tag for tag in selected_active if origins.get(tag) == "implied"]
    implied_lookup = set(implied_tags)
    direct_unsorted = [tag for tag in selected_active if tag not in implied_lookup]
    position = {tag: idx for idx, tag in enumerate(direct_unsorted)}

    def _direct_order(tag: str):
        # Primary key: rank of the tag's origin; ties keep selection order.
        return (
            _selection_source_rank(origins.get(tag, "selection")),
            position.get(tag, 10**9),
        )

    direct_tags = sorted(direct_unsorted, key=_direct_order)

    implied_parent_map = _build_implied_parent_map(
        direct_tags_ordered=direct_tags,
        implied_tags=implied_tags,
    )
    toggle_rows = _build_toggle_rows(
        seed_terms=list(selected_active),
        selected_tags=selected_active,
        retrieved_candidate_tags=list(dict.fromkeys(candidates)),
        tag_selection_origins=origins,
        implied_parent_map=implied_parent_map,
        top_groups=max(1, int(display_top_groups)),
        top_tags_per_group=max(1, int(display_top_tags_per_group)),
        group_rank_top_k=max(1, int(display_rank_top_k)),
    )
    prompt_text, new_row_values, header_updates, checkbox_updates = _build_row_component_updates(
        row_defs=toggle_rows,
        selected_tags=selected_active,
        max_rows=display_max_rows_default,
    )
    tooltip_payload = _build_tooltip_payload(toggle_rows, display_max_rows_default)

    outputs: List[Any] = [
        gr.update(visible=bool(toggle_rows)),
        tooltip_payload,
        prompt_text,
        sorted(selected_active),
        False,
        gr.update(visible=False, interactive=False),
        toggle_rows,
        new_row_values,
    ]
    outputs.extend(header_updates)
    outputs.extend(checkbox_updates)
    return outputs
def _build_selection_query(
prompt_in: str,
rewritten: str,
structural_tags: List[str],
probe_tags: List[str],
) -> str:
lines = [f"IMAGE DESCRIPTION: {prompt_in.strip()}"]
if rewritten and rewritten.strip():
lines.append(f"REWRITE PHRASES: {rewritten.strip()}")
hint_tags = []
if structural_tags:
hint_tags.extend(structural_tags)
if probe_tags:
hint_tags.extend(probe_tags)
if hint_tags:
# Keep hints as context only; selection still must choose by candidate indices.
lines.append(
"INFERRED TAG HINTS (context only): " + ", ".join(sorted(set(hint_tags)))
)
return "\n".join(lines)
# Logging setup.
# Production default is minimal: WARNING and above go to stderr through a
# single stream handler; no log file is created.
import logging
import os

# PSQ_LOG_LEVEL names a logging level; unknown names fall back to WARNING.
LOG_LEVEL = os.environ.get("PSQ_LOG_LEVEL", "WARNING").upper()
logging.basicConfig(
    format="%(asctime)s %(levelname)s:%(message)s",
    level=getattr(logging, LOG_LEVEL, logging.WARNING),
    handlers=[logging.StreamHandler()],  # stream only -> avoids huge logs on Spaces
)

# Third-party loggers that are noisy at lower levels are capped at ERROR.
for _name in ("gensim", "gradio", "hnswlib", "httpx", "uvicorn"):
    logging.getLogger(_name).setLevel(logging.ERROR)

# Disable Gradio's analytics calls to avoid their background-thread errors.
os.environ["GRADIO_ANALYTICS_ENABLED"] = "0"

# The mascot artwork lives in a folder next to this module.
MASCOT_DIR = Path(__file__).parent / "mascotimages"
MASCOT_FILE = MASCOT_DIR / "transparentsquirrel.png"
def _load_mascot_image():
    """Return the mascot as an RGBA image, or ``None`` if it cannot be used.

    A missing file and a file that fails to open or convert are both reported
    with a warning instead of raising.
    """
    if MASCOT_FILE.exists():
        try:
            return Image.open(MASCOT_FILE).convert("RGBA")
        except Exception as exc:
            logging.warning("Failed to load mascot image (%s): %s", MASCOT_FILE, exc)
    else:
        logging.warning("Mascot image missing: %s", MASCOT_FILE)
    return None
# Compatibility shim for gradio_client's JSON-schema helpers: JSON Schema
# permits a bare boolean (or None) where a schema object is expected, so each
# helper is wrapped to answer "any" for every non-dict schema.
try:
    from gradio_client import utils as _gc_utils

    # Handles on the library's own implementations, taken before patching.
    _orig_get_type = _gc_utils.get_type
    _orig_j2p = _gc_utils._json_schema_to_python_type
    # NOTE(review): _orig_pub is saved but not called by the wrappers below.
    _orig_pub = _gc_utils.json_schema_to_python_type

    def _get_type_safe(schema):
        # Only real schema objects reach the original implementation.
        return _orig_get_type(schema) if isinstance(schema, dict) else "any"

    def _j2p_safe(schema, defs=None):
        # Fall back to the schema's own "$defs" when no defs were passed.
        if isinstance(schema, dict):
            return _orig_j2p(schema, defs or schema.get("$defs"))
        return "any"

    def _pub_safe(schema):
        # Public entry point: routes through the guarded private converter.
        if isinstance(schema, dict):
            return _j2p_safe(schema, schema.get("$defs"))
        return "any"

    _gc_utils.get_type = _get_type_safe
    _gc_utils._json_schema_to_python_type = _j2p_safe
    _gc_utils.json_schema_to_python_type = _pub_safe
except Exception as exc:
    # Best effort: the app still starts when the patch cannot be applied.
    print("gradio_client hotfix not applied:", exc)
# -------------------------------------------------------------------------------
# Module-level switch, initialised to False here.
# NOTE(review): presumably gates NSFW tags in the recommendations; no reader
# of this flag is visible in this part of the file, so confirm before relying
# on it.
allow_nsfw_tags = False
def _is_production_runtime() -> bool:
"""Best-effort detection for deployed runtime (HF Spaces or explicit env)."""
if os.environ.get("PSQ_PRODUCTION", "").strip().lower() in {"1", "true", "yes"}:
return True
if os.environ.get("SPACE_ID"):
return True
if os.environ.get("HF_SPACE_ID"):
return True
if os.environ.get("SYSTEM") == "spaces":
return True
return False
# Runtime configuration, read once at import time from PSQ_* environment
# variables.  Numeric values are parsed eagerly, so a malformed value raises
# ValueError during startup instead of failing later in a request.

# Retrieval console verbosity: on by default outside production, off in
# production, unless PSQ_VERBOSE_RETRIEVAL overrides it.
verbose_retrieval_default = "0" if _is_production_runtime() else "1"
verbose_retrieval = os.environ.get("PSQ_VERBOSE_RETRIEVAL", verbose_retrieval_default).strip().lower() in {"1", "true", "yes"}
verbose_retrieval_all = False
verbose_retrieval_limit = 20

# Probe tags stay enabled unless PSQ_ENABLE_PROBE is an explicit "off" value.
# The value is lower-cased before comparing, like the other boolean flags, so
# "FALSE", "No" or "OFF" disable the feature just as "false" does.
enable_probe_tags = os.environ.get("PSQ_ENABLE_PROBE", "1").strip().lower() not in {"0", "false", "no", "off"}

# Display defaults for the tag rows.
display_top_groups_default = int(os.environ.get("PSQ_DISPLAY_TOP_GROUPS", "10"))
display_top_tags_per_group_default = int(os.environ.get("PSQ_DISPLAY_TOP_TAGS_PER_GROUP", "7"))
display_rank_top_k_default = int(os.environ.get("PSQ_DISPLAY_GROUP_RANK_TOP_K", "7"))
display_max_rows_default = int(os.environ.get("PSQ_DISPLAY_MAX_ROWS", "14"))

# Retrieval sizing.
retrieval_global_k = int(os.environ.get("PSQ_RETRIEVAL_GLOBAL_K", "300"))
retrieval_per_phrase_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_K", "10"))
retrieval_per_phrase_final_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_FINAL_K", "1"))
retrieval_exact_ngram_max = int(os.environ.get("PSQ_RETRIEVAL_EXACT_NGRAM_MAX", "2"))

# Selection strategy and sizing.
selection_mode = os.environ.get("PSQ_SELECTION_MODE", "chunked_map_union").strip()
selection_chunk_size = int(os.environ.get("PSQ_SELECTION_CHUNK_SIZE", "60"))
selection_per_phrase_k = int(os.environ.get("PSQ_SELECTION_PER_PHRASE_K", "2"))
selection_candidate_cap = int(os.environ.get("PSQ_SELECTION_CANDIDATE_CAP", "0"))

# Per-stage timeouts in seconds, and the stage-3 fast retry count (never
# negative).
stage1_rewrite_timeout_s = float(os.environ.get("PSQ_TIMEOUT_REWRITE_S", "45"))
stage1_struct_timeout_s = float(os.environ.get("PSQ_TIMEOUT_STRUCT_S", "45"))
stage1_probe_timeout_s = float(os.environ.get("PSQ_TIMEOUT_PROBE_S", "45"))
stage3_select_timeout_s = float(os.environ.get("PSQ_TIMEOUT_SELECT_S", "50"))
stage3_select_retry_timeout_s = float(os.environ.get("PSQ_TIMEOUT_SELECT_RETRY_S", "20"))
stage3_fast_retry_count = max(0, int(os.environ.get("PSQ_STAGE3_FAST_RETRY_COUNT", "1")))

# Path of the JSONL timing log.
timing_log_path = Path(os.environ.get("PSQ_TIMING_LOG_PATH", "data/runtime_metrics/ui_pipeline_timings.jsonl"))
def _startup_preflight_errors() -> List[str]:
errs: List[str] = []
if not os.getenv("OPENROUTER_API_KEY"):
errs.append("OPENROUTER_API_KEY is missing. Set it in Space Secrets or environment variables.")
return errs
# Run the preflight checks once at import time.  Each problem is logged at
# ERROR level, and the profile mark records how many were found.
STARTUP_PREFLIGHT_ERRORS = _startup_preflight_errors()
for _err in STARTUP_PREFLIGHT_ERRORS:
    logging.error("Startup preflight error: %s", _err)
_startup_profile_mark("startup_preflight.done", error_count=len(STARTUP_PREFLIGHT_ERRORS))
# Stylesheet text for the UI.  Sections, in order: scroll panes, the console
# box (#psq-console), the ".lego-tags" checkbox skin (brick-style buttons with
# a per-position colour cycle and colours keyed on the data-psq-origin /
# data-psq-preselected attributes), the source legend chips, row and top
# instructions, the about/docs area, run hints, the prompt cards, the mascot
# image (#mascot), and the ".psq-hidden" utility class.
# NOTE(review): presumably handed to Gradio as custom CSS — the call site is
# not visible in this part of the file.
css = """
.scrollable-content{
max-height: 420px;
overflow-y: scroll; /* always show scrollbar */
overflow-x: hidden;
padding-right: 8px;
padding-bottom: 14px; /* <— add this */
scrollbar-gutter: stable; /* prevent layout shift as it fills */
/* Firefox */
scrollbar-width: auto;
scrollbar-color: rgba(180,180,180,.9) rgba(0,0,0,.15);
}
/* WebKit/Chromium (Chrome/Edge/Safari) */
.scrollable-content::-webkit-scrollbar{ width: 10px; }
.scrollable-content::-webkit-scrollbar-thumb{ background: rgba(180,180,180,.9); border-radius: 8px; }
.scrollable-content::-webkit-scrollbar-track{ background: rgba(0,0,0,.15); }
/* (Optional) make both scroll panes taller so they fill more of the column */
.pane-left .scrollable-content,
.pane-right .scrollable-content {
max-height: 610px; /* was 420px; tweak to taste */
}
/* Console: force internal scrolling so full logs are always reachable */
#psq-console,
#psq-console .scroll-hide,
#psq-console textarea {
overflow-y: auto !important;
overflow-x: hidden !important;
max-height: 420px !important;
min-height: 240px !important;
}
#psq-console textarea {
white-space: pre-wrap !important;
word-break: break-word !important;
font-family: ui-monospace, SFMono-Regular, Menlo, Consolas, "Liberation Mono", monospace !important;
}
.lego-tags .gr-checkboxgroup,
.lego-tags .wrap {
display: flex !important;
flex-wrap: wrap !important;
gap: 10px !important;
}
.lego-tags label {
margin: 0 !important;
padding: 0 !important;
position: relative !important;
}
/* Hide native checkbox visuals completely */
.lego-tags input[type="checkbox"] {
appearance: none !important;
-webkit-appearance: none !important;
-moz-appearance: none !important;
position: absolute !important;
width: 1px !important;
height: 1px !important;
opacity: 0 !important;
pointer-events: none !important;
display: none !important;
}
/* Brick button skin (works for both +span and ~span structures) */
.lego-tags input[type="checkbox"] + span,
.lego-tags input[type="checkbox"] ~ span {
--on-bg1: #ffd166;
--on-bg2: #f39c4a;
--on-border: #b86e21;
--on-text: #2e1706;
position: relative !important;
display: inline-flex !important;
align-items: center !important;
min-height: 40px !important;
padding: 10px 15px 9px 22px !important;
border: 1px solid #9aa6b8 !important;
border-radius: 10px !important;
background: linear-gradient(180deg, #dfe5ee 0%, #bec8d6 100%) !important;
color: #364254 !important;
font-size: 0.97rem !important;
font-weight: 800 !important;
line-height: 1.15 !important;
cursor: pointer !important;
user-select: none !important;
letter-spacing: 0.01em !important;
box-shadow: 0 3px 0 rgba(0,0,0,0.16), inset 0 1px 0 rgba(255,255,255,0.55) !important;
transition: transform 0.08s ease, box-shadow 0.08s ease, filter 0.08s ease !important;
}
.lego-tags input[type="checkbox"] + span::before,
.lego-tags input[type="checkbox"] ~ span::before {
content: "" !important;
position: absolute !important;
top: 5px !important;
left: 8px !important;
width: 8px !important;
height: 8px !important;
border-radius: 50% !important;
background: rgba(255,255,255,0.58) !important;
box-shadow: 22px 0 0 rgba(255,255,255,0.58) !important;
pointer-events: none !important;
}
/* Unselected cue: show "+" on the left. */
.lego-tags input[type="checkbox"] + span::after,
.lego-tags input[type="checkbox"] ~ span::after {
content: "+" !important;
position: absolute !important;
left: 6px !important;
top: 50% !important;
transform: translateY(-52%) !important;
font-size: 1rem !important;
font-weight: 900 !important;
color: #4b5563 !important;
opacity: 0.95 !important;
pointer-events: none !important;
}
/* Bright color cycle used only when selected */
.lego-tags label:nth-child(8n+1) span { --on-bg1: #ffd166; --on-bg2: #f39c4a; --on-border: #b86e21; --on-text: #2e1706; }
.lego-tags label:nth-child(8n+2) span { --on-bg1: #6ee7ff; --on-bg2: #1fb7ff; --on-border: #157cb3; --on-text: #07263c; }
.lego-tags label:nth-child(8n+3) span { --on-bg1: #9dff8f; --on-bg2: #45c96f; --on-border: #2a8b4b; --on-text: #0d2917; }
.lego-tags label:nth-child(8n+4) span { --on-bg1: #ff8fab; --on-bg2: #ff5c7a; --on-border: #b83956; --on-text: #3f0f1d; }
.lego-tags label:nth-child(8n+5) span { --on-bg1: #d0a8ff; --on-bg2: #a46cff; --on-border: #7147b3; --on-text: #25143f; }
.lego-tags label:nth-child(8n+6) span { --on-bg1: #ffe27a; --on-bg2: #f7bf39; --on-border: #ad7f1f; --on-text: #332407; }
.lego-tags label:nth-child(8n+7) span { --on-bg1: #8effd5; --on-bg2: #2ed6b5; --on-border: #1e947d; --on-text: #0d2a25; }
.lego-tags label:nth-child(8n+8) span { --on-bg1: #ffb47e; --on-bg2: #ff8753; --on-border: #b95b2d; --on-text: #391a0a; }
/* Source-driven selected colors (applies when tags are preselected by the pipeline). */
.lego-tags label[data-psq-preselected="1"][data-psq-origin="rewrite"] span {
--on-bg1: #77f0d7;
--on-bg2: #26b9a3;
--on-border: #187869;
--on-text: #062923;
}
.lego-tags label[data-psq-preselected="1"][data-psq-origin="selection"] span {
--on-bg1: #ffd98a;
--on-bg2: #f0a93c;
--on-border: #a66f1f;
--on-text: #382206;
}
.lego-tags label[data-psq-preselected="1"][data-psq-origin="probe"] span {
--on-bg1: #d8b4ff;
--on-bg2: #9a6cff;
--on-border: #6745b0;
--on-text: #24143b;
}
.lego-tags label[data-psq-preselected="1"][data-psq-origin="structural"] span {
--on-bg1: #a6f79a;
--on-bg2: #53c368;
--on-border: #2f8442;
--on-text: #102d17;
}
.lego-tags label[data-psq-preselected="1"][data-psq-origin="implied"] span {
--on-bg1: #d7dde8;
--on-bg2: #a8b3c4;
--on-border: #6f7e95;
--on-text: #1d2633;
}
/* User-selected tags (not initially selected by the pipeline). */
.lego-tags label[data-psq-preselected="0"] span {
--on-bg1: #9ec5ff;
--on-bg2: #4f86ff;
--on-border: #2f5fbf;
--on-text: #0b1f42;
}
.lego-tags label:hover span {
filter: brightness(1.02) !important;
transform: translateY(1px) !important;
}
/* ON state: brighter + visibly recessed */
.lego-tags input[type="checkbox"]:checked + span,
.lego-tags input[type="checkbox"]:checked ~ span,
.lego-tags label:has(input[type="checkbox"]:checked) span {
background: linear-gradient(180deg, var(--on-bg1) 0%, var(--on-bg2) 100%) !important;
color: var(--on-text) !important;
border-color: var(--on-border) !important;
filter: saturate(1.2) brightness(1.12) !important;
transform: translateY(-2px) !important;
box-shadow:
inset 0 3px 6px rgba(0,0,0,0.20),
inset 0 -1px 0 rgba(255,255,255,0.36),
0 6px 0 rgba(0,0,0,0.32) !important;
}
.lego-tags input[type="checkbox"]:checked + span::after,
.lego-tags input[type="checkbox"]:checked ~ span::after,
.lego-tags label:has(input[type="checkbox"]:checked) span::after {
content: "" !important;
}
.source-legend {
display: flex;
flex-wrap: wrap;
align-items: center;
gap: 8px;
margin: 4px 0 10px 0;
}
.source-legend .legend-title {
font-size: 0.92rem;
font-weight: 900;
color: #334155;
margin-right: 4px;
}
.source-legend .chip {
display: inline-flex;
align-items: center;
border-radius: 10px;
border: 1px solid #6c7788;
padding: 6px 12px;
font-size: 0.85rem;
font-weight: 800;
color: #111827;
background: #f3f6fb;
}
.source-legend .chip.rewrite { background: #26b9a3; color: #062923; border-color: #187869; }
.source-legend .chip.selection { background: #f0a93c; color: #382206; border-color: #a66f1f; }
.source-legend .chip.probe { background: #9a6cff; color: #ffffff; border-color: #6745b0; }
.source-legend .chip.structural { background: #53c368; color: #102d17; border-color: #2f8442; }
.source-legend .chip.implied { background: #a8b3c4; color: #1d2633; border-color: #6f7e95; }
.source-legend .chip.user { background: #4f86ff; color: #ffffff; border-color: #2f5fbf; }
.source-legend .chip.unselected { background: #c7ced8; color: #2d3440; border-color: #7d8897; }
.row-heading p {
margin: 8px 0 0 0 !important;
font-size: 1.18rem !important;
font-weight: 850 !important;
line-height: 1.2 !important;
}
.row-instruction {
text-align: center;
margin: 8px 0 12px 0;
}
.row-instruction p {
margin: 0 !important;
font-size: 1.02rem !important;
font-style: italic !important;
font-weight: 800 !important;
color: #1d4ed8 !important;
}
.about-docs {
margin-top: 4px;
}
.about-docs > p {
line-height: 1.42 !important;
}
.about-docs img {
max-width: 100% !important;
height: auto !important;
border: 1px solid #d2d7e0;
border-radius: 10px;
background: #ffffff;
}
.arch-diagram-wrap {
margin: 6px 0 10px 0;
}
.arch-diagram-wrap h2 {
margin: 0 0 8px 0 !important;
}
.top-instruction {
text-align: center;
margin: 2px 0 6px 0;
}
.top-instruction p {
margin: 0 !important;
font-size: 1.02rem !important;
font-style: italic !important;
font-weight: 800 !important;
color: #1d4ed8 !important;
}
.run-hint {
margin-top: 6px;
text-align: center;
}
.run-hint p {
margin: 0 !important;
font-size: 0.9rem !important;
font-style: italic !important;
color: #475569 !important;
}
.mascot-runtime-hint {
margin-top: 6px;
text-align: right;
}
.mascot-runtime-hint p {
margin: 0 !important;
font-size: 0.9rem !important;
font-style: italic !important;
color: #475569 !important;
}
.system-description {
margin: 2px 0 8px 0 !important;
text-align: center;
}
.system-description p {
margin: 0 !important;
font-size: 1.08rem !important;
font-weight: 500 !important;
font-style: italic !important;
color: var(--body-text-color, #111827) !important;
opacity: 1.0 !important;
}
.prompt-card {
background: transparent !important;
border: none !important;
box-shadow: none !important;
padding: 0 !important;
}
/* Mascot image: keep container transparent so PNG alpha shows app background. */
#mascot,
#mascot .image-container,
#mascot .image-frame,
#mascot .wrap {
background: transparent !important;
border: none !important;
box-shadow: none !important;
}
#mascot .image-container {
background-image: none !important;
}
/* Mascot is decorative: hide image action controls (fullscreen/download/share). */
#mascot button[aria-label*="full" i],
#mascot button[aria-label*="download" i],
#mascot button[aria-label*="share" i],
#mascot a[aria-label*="full" i],
#mascot a[aria-label*="download" i],
#mascot a[aria-label*="share" i],
#mascot .tools,
#mascot .tool-icons {
display: none !important;
}
.suggested-prompt-box {
margin-top: 2px !important;
}
.suggested-prompt-card {
margin-top: 10px !important;
}
.psq-hidden {
display: none !important;
}
"""
# Browser-side script (a JavaScript arrow function kept as text).  It does two
# things:
#   * Prompt draft persistence: the text of the ".enter-prompt-box" field is
#     saved to sessionStorage on input/change events.  When the field is empty
#     the saved draft is written back and input/change events are dispatched;
#     restore attempts run at 0/120/350/900 ms and again on "pageshow".
#   * Tag decoration: the JSON held in "#psq-tooltip-map" (keys "rows",
#     "meta_rows", "tips") is applied to every ".lego-tags" label and its span
#     as data-psq-origin / data-psq-preselected attributes and a title.  A
#     250 ms interval re-applies it whenever the raw payload or the number of
#     labels changes.
client_js = """
() => {
const PROMPT_DRAFT_KEY = `psq_prompt_draft_v3:${window.location.hostname}:${window.location.pathname}`;
const getPromptInput = () =>
document.querySelector(".enter-prompt-box textarea, .enter-prompt-box input");
const readPromptDraft = () => {
try {
return sessionStorage.getItem(PROMPT_DRAFT_KEY) || "";
} catch (_) {
return "";
}
};
const writePromptDraft = (value) => {
const text = (value || "").trim();
try {
if (text) {
sessionStorage.setItem(PROMPT_DRAFT_KEY, value || "");
} else {
sessionStorage.removeItem(PROMPT_DRAFT_KEY);
}
} catch (_) {
// Ignore storage failures.
}
};
const setNativeInputValue = (el, value) => {
const proto = Object.getPrototypeOf(el);
const desc = proto ? Object.getOwnPropertyDescriptor(proto, "value") : null;
if (desc && typeof desc.set === "function") {
desc.set.call(el, value);
} else {
el.value = value;
}
};
const restorePromptDraft = () => {
const el = getPromptInput();
if (!el) return;
const current = el.value || "";
if (current.trim().length > 0) {
writePromptDraft(current);
return;
}
const saved = readPromptDraft();
if (!saved) return;
setNativeInputValue(el, saved);
el.dispatchEvent(new Event("input", { bubbles: true, composed: true }));
el.dispatchEvent(new Event("change", { bubbles: true, composed: true }));
};
const bindPromptDraftHandlers = () => {
if (document.body && document.body.dataset.psqDraftHandlersBound === "1") return;
if (!document.body) return;
document.body.dataset.psqDraftHandlersBound = "1";
const onFieldEvent = (evt) => {
const t = evt && evt.target;
if (!(t instanceof HTMLTextAreaElement || t instanceof HTMLInputElement)) return;
if (!t.closest(".enter-prompt-box")) return;
writePromptDraft(t.value || "");
};
// Capture phase so we persist even during early remount races.
document.addEventListener("input", onFieldEvent, true);
document.addEventListener("change", onFieldEvent, true);
};
const schedulePromptRestore = () => {
const delaysMs = [0, 120, 350, 900];
delaysMs.forEach((ms) => {
window.setTimeout(() => restorePromptDraft(), ms);
});
};
const readTooltipMapRaw = () => {
const el = document.querySelector("#psq-tooltip-map textarea, #psq-tooltip-map input");
if (!el) return "";
return (el.value || "").trim();
};
const readTooltipMap = () => {
const raw = readTooltipMapRaw();
if (!raw) return { rows: [], meta_rows: [], tips: {} };
try {
const obj = JSON.parse(raw);
if (!obj || typeof obj !== "object") return { rows: [], meta_rows: [], tips: {} };
const rows = Array.isArray(obj.rows) ? obj.rows : [];
const meta_rows = Array.isArray(obj.meta_rows) ? obj.meta_rows : [];
const tips = (obj.tips && typeof obj.tips === "object") ? obj.tips : {};
return { rows, meta_rows, tips };
} catch (_) {
return { rows: [], meta_rows: [], tips: {} };
}
};
const applyTooltips = () => {
const payload = readTooltipMap();
const rowTags = Array.isArray(payload.rows) ? payload.rows : [];
const rowMeta = Array.isArray(payload.meta_rows) ? payload.meta_rows : [];
const tipMap = (payload.tips && typeof payload.tips === "object") ? payload.tips : {};
const validOrigins = new Set(["rewrite", "selection", "probe", "structural", "implied", "user"]);
const rowEls = document.querySelectorAll(".lego-tags");
rowEls.forEach((rowEl, rowIdx) => {
const tags = Array.isArray(rowTags[rowIdx]) ? rowTags[rowIdx] : [];
const metas = Array.isArray(rowMeta[rowIdx]) ? rowMeta[rowIdx] : [];
const labels = rowEl.querySelectorAll("label");
labels.forEach((label, tagIdx) => {
const span = label.querySelector("span");
const tag = (tagIdx < tags.length) ? tags[tagIdx] : "";
const tip = tag && Object.prototype.hasOwnProperty.call(tipMap, tag) ? (tipMap[tag] || "") : "";
const meta = (tagIdx < metas.length && metas[tagIdx] && typeof metas[tagIdx] === "object")
? metas[tagIdx]
: {};
const originRaw = String(meta.origin || "selection").trim().toLowerCase();
const origin = validOrigins.has(originRaw) ? originRaw : "selection";
const preselected = !!meta.preselected;
label.setAttribute("data-psq-origin", origin);
label.setAttribute("data-psq-preselected", preselected ? "1" : "0");
if (span) {
span.setAttribute("data-psq-origin", origin);
span.setAttribute("data-psq-preselected", preselected ? "1" : "0");
}
if (tip) {
label.title = tip;
if (span) span.title = tip;
} else {
label.removeAttribute("title");
if (span) span.removeAttribute("title");
}
});
});
};
bindPromptDraftHandlers();
schedulePromptRestore();
applyTooltips();
window.addEventListener("pageshow", () => {
schedulePromptRestore();
applyTooltips();
});
let lastTooltipPayload = readTooltipMapRaw();
let lastTooltipLabelCount = document.querySelectorAll(".lego-tags label").length;
window.setInterval(() => {
const current = readTooltipMapRaw();
const labelCount = document.querySelectorAll(".lego-tags label").length;
if (current === lastTooltipPayload && labelCount === lastTooltipLabelCount) return;
lastTooltipPayload = current;
lastTooltipLabelCount = labelCount;
applyTooltips();
}, 250);
}
"""
# Browser-side script (JavaScript kept as text) that returns a one-element
# array holding a JSON string of page-load metrics: navigation timing fields,
# first-paint / first-contentful-paint times, and transfer/body sizes.  Times
# are rounded to three decimals; the array holds "" if anything throws.  The
# payload keys match the fields copied by _log_client_startup_profile.
client_startup_profile_js = """
() => {
try {
const nav = performance.getEntriesByType("navigation")[0] || null;
const paints = performance.getEntriesByType("paint") || [];
let fpMs = null;
let fcpMs = null;
for (const p of paints) {
if (p && p.name === "first-paint" && fpMs === null) fpMs = p.startTime;
if (p && p.name === "first-contentful-paint" && fcpMs === null) fcpMs = p.startTime;
}
const round = (v) => (typeof v === "number" && Number.isFinite(v) ? Math.round(v * 1000) / 1000 : null);
const payload = {
href_path: window.location.pathname || "",
nav_type: nav ? String(nav.type || "") : "",
ready_state: document.readyState || "",
now_ms: round(performance.now()),
response_end_ms: nav ? round(nav.responseEnd) : null,
dom_interactive_ms: nav ? round(nav.domInteractive) : null,
dom_content_loaded_end_ms: nav ? round(nav.domContentLoadedEventEnd) : null,
load_event_end_ms: nav ? round(nav.loadEventEnd) : null,
fp_ms: round(fpMs),
fcp_ms: round(fcpMs),
transfer_size: nav && typeof nav.transferSize === "number" ? nav.transferSize : null,
encoded_body_size: nav && typeof nav.encodedBodySize === "number" ? nav.encodedBodySize : null,
decoded_body_size: nav && typeof nav.decodedBodySize === "number" ? nav.decodedBodySize : null
};
return [JSON.stringify(payload)];
} catch (_) {
return [""];
}
}
"""
# Browser-side script (JavaScript kept as text) that returns a one-element
# array holding a JSON string with a small set of navigation timings: the
# current time, document ready state, navigation type, and the response-end,
# DOMContentLoaded-end and load-event-end times.  The array holds "" if
# anything throws.  The payload keys match those read by _enable_prompt_input.
client_enable_input_profile_js = """
() => {
try {
const nav = performance.getEntriesByType("navigation")[0] || null;
const round = (v) => (typeof v === "number" && Number.isFinite(v) ? Math.round(v * 1000) / 1000 : null);
const payload = {
now_ms: round(performance.now()),
ready_state: document.readyState || "",
nav_type: nav ? String(nav.type || "") : "",
response_end_ms: nav ? round(nav.responseEnd) : null,
dom_content_loaded_end_ms: nav ? round(nav.domContentLoadedEventEnd) : null,
load_event_end_ms: nav ? round(nav.loadEventEnd) : null
};
return [JSON.stringify(payload)];
} catch (_) {
return [""];
}
}
"""
def _log_client_startup_profile(payload_raw: str) -> None:
    """Print one ``STARTUP_CLIENT_PROFILE`` line for a browser timing payload.

    ``payload_raw`` is the JSON text sent by the client.  Blank input and
    JSON that is not an object are ignored silently.  Text that fails to
    parse is still reported, as ``{"parse_error": true}``.  Only the
    whitelisted fields are copied from the payload; fields that are absent
    appear as ``null``.  The record also carries the event name and the
    seconds elapsed since module start.
    """
    payload = (payload_raw or "").strip()
    if not payload:
        return

    # Whitelisted payload fields.  A tuple rather than a set keeps the key
    # order of the printed JSON identical on every run; set iteration order
    # for strings changes between processes under hash randomisation.
    keep = (
        "href_path",
        "nav_type",
        "ready_state",
        "now_ms",
        "response_end_ms",
        "dom_interactive_ms",
        "dom_content_loaded_end_ms",
        "load_event_end_ms",
        "fp_ms",
        "fcp_ms",
        "transfer_size",
        "encoded_body_size",
        "decoded_body_size",
    )

    rec: Dict[str, Any]
    try:
        obj = json.loads(payload)
    except Exception:
        # Best-effort logging of client-supplied text: any parse failure is
        # recorded rather than raised into the UI callback.
        rec = {"parse_error": True}
    else:
        if not isinstance(obj, dict):
            return
        rec = {k: obj.get(k) for k in keep}

    rec["event"] = "client_startup_profile"
    rec["server_t_s"] = round(time.perf_counter() - _STARTUP_PROFILE_T0, 6)
    print(f"STARTUP_CLIENT_PROFILE {json.dumps(rec, ensure_ascii=False)}")
def _enable_prompt_input(client_payload_raw: str = ""):
    """Unlock the prompt textbox once the browser has loaded the app.

    On the first call while startup profiling is on, a
    ``ui.client_load_event`` profile mark is emitted. When the client sent a
    non-blank payload, one ``STARTUP_CLIENT_ENABLE`` line is printed with the
    selected timing fields (or ``parse_error`` if the payload is not a JSON
    object).

    Returns:
        A ``gr.update`` that makes the textbox interactive and restores the
        default example placeholder.
    """
    global _STARTUP_PROFILE_CLIENT_LOAD_MARKED

    first_load_pending = _STARTUP_PROFILE_ON and not _STARTUP_PROFILE_CLIENT_LOAD_MARKED
    if first_load_pending:
        # First successful app.load callback from browser connection.
        _startup_profile_mark("ui.client_load_event")
        _STARTUP_PROFILE_CLIENT_LOAD_MARKED = True

    raw_text = (client_payload_raw or "").strip()
    if raw_text:
        report_fields = (
            "now_ms",
            "ready_state",
            "nav_type",
            "response_end_ms",
            "dom_content_loaded_end_ms",
            "load_event_end_ms",
        )
        entry: Dict[str, Any]
        try:
            decoded = json.loads(raw_text)
            entry = (
                {name: decoded.get(name) for name in report_fields}
                if isinstance(decoded, dict)
                else {"parse_error": True}
            )
        except Exception:
            entry = {"parse_error": True}
        entry["event"] = "enable_prompt_input"
        entry["server_t_s"] = round(time.perf_counter() - _STARTUP_PROFILE_T0, 6)
        print(f"STARTUP_CLIENT_ENABLE {json.dumps(entry, ensure_ascii=False)}")

    return gr.update(interactive=True, placeholder=DEFAULT_PROMPT_EXAMPLE)
def rag_pipeline_ui(
    user_prompt: str,
    display_top_groups: float,
    display_top_tags_per_group: float,
    display_rank_top_k: float,
    rewrite_override: str = "",
    console_seed: str = "",
):
    """Run the prompt-to-tags pipeline for one UI request.

    Stages, in order: heuristic user-tag and exact-phrase extraction, rewrite +
    structural + probe inference (submitted together to a thread pool),
    candidate retrieval, LLM index selection, implication expansion, final
    prompt composition, and construction of the ranked toggle rows. Each
    stage writes to an in-memory console log and records a wall-clock timing.

    Args:
        user_prompt: Raw text typed by the user; stripped before use.
        display_top_groups: Row count for the display; coerced with ``int`` and
            clamped to at least 1.
        display_top_tags_per_group: Tags shown per row; same coercion.
        display_rank_top_k: Tags used for row ranking; same coercion.
        rewrite_override: Already-prepared rewrite text. When non-blank, the
            rewrite call is skipped and this text is used.
        console_seed: Existing console text; its lines are placed at the top
            of the log for this run.

    Returns:
        The value produced by ``_build_ui_payload``. Exceptions raised inside
        the pipeline are caught, logged, and converted into an error payload
        rather than propagated.
    """
    # Local import: the file's top-level import only brings in `datetime`.
    from datetime import timezone

    logs = []
    if console_seed:
        logs.extend(str(console_seed).splitlines())
    upstream_throttle_hits = 0
    credit_or_quota_hits = 0

    def _classify_runtime_line(s: str) -> None:
        """Bump the provider-diagnostic counters when a log line matches."""
        nonlocal upstream_throttle_hits, credit_or_quota_hits
        t = str(s or "")
        t_l = t.lower()
        if (
            "upstream throttle detected" in t_l
            or "rate-limited upstream" in t_l
            or "too many requests" in t_l
            or "error code: 429" in t_l
        ):
            upstream_throttle_hits += 1
        if (
            "credit/quota error detected" in t_l
            or "insufficient credits" in t_l
            or "payment required" in t_l
            or "quota exceeded" in t_l
            or "error code: 402" in t_l
        ):
            credit_or_quota_hits += 1

    def log(s):
        """Append one line to the console log and classify it."""
        text = str(s)
        logs.append(text)
        _classify_runtime_line(text)

    try:
        stage_timings = {}

        def _record_timing(stage: str, dt_s: float):
            """Store the duration (seconds) of a named stage."""
            stage_timings[stage] = float(dt_s)

        def _emit_timing_summary(total_s: float):
            """Log per-stage timings in a fixed order, plus slowest stage and total."""
            summary_order = [
                "preprocess",
                "rewrite",
                "structural",
                "probe",
                "retrieval",
                "selection",
                "implication_expansion",
                "prompt_composition",
                "group_display",
            ]
            lines = []
            for k in summary_order:
                if k in stage_timings:
                    lines.append(f"{k}={stage_timings[k]:.2f}s")
            slowest = max(stage_timings.items(), key=lambda kv: kv[1])[0] if stage_timings else "n/a"
            log("Timing Summary: " + ", ".join(lines))
            log(f"Timing Slowest Stage: {slowest}")
            log(f"Timing Total: {total_s:.2f}s")

        def _append_timing_jsonl(total_s: float):
            """Append one timing record to ``timing_log_path``; failures are only logged."""
            try:
                timing_log_path.parent.mkdir(parents=True, exist_ok=True)
                rec = {
                    # `datetime.utcnow()` is deprecated since Python 3.12; an
                    # aware UTC "now" formatted this way yields the same
                    # "YYYY-MM-DDTHH:MM:SSZ" text as before.
                    "timestamp_utc": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"),
                    "stages_s": stage_timings,
                    "total_s": float(total_s),
                    "provider_diagnostics": {
                        "upstream_throttle_hits": int(upstream_throttle_hits),
                        "credit_or_quota_hits": int(credit_or_quota_hits),
                    },
                    "config": {
                        "timeout_rewrite_s": stage1_rewrite_timeout_s,
                        "timeout_struct_s": stage1_struct_timeout_s,
                        "timeout_probe_s": stage1_probe_timeout_s,
                        "timeout_select_s": stage3_select_timeout_s,
                    },
                }
                with timing_log_path.open("a", encoding="utf-8") as f:
                    f.write(json.dumps(rec, ensure_ascii=True) + "\n")
                log(f"Timing Log: wrote {timing_log_path}")
            except Exception as e:
                log(f"Timing Log: failed ({type(e).__name__}: {_redact_console_error_text(e)})")

        def _future_with_timeout(
            fut,
            timeout_s: float,
            stage_name: str,
            fallback,
            *,
            strict: bool = False,
        ):
            """Wait for ``fut`` and return its result, or ``fallback`` on failure.

            The wait is at least one second. With ``strict=True`` a timeout or
            error raises ``RuntimeError`` instead of returning ``fallback``.
            Successful waits for known stage names are recorded as timings.
            """
            t0 = time.perf_counter()
            try:
                # The wait actually applied; also used in the timeout message so
                # the log reports the real limit rather than the requested one.
                wait_s = max(1.0, float(timeout_s))
                out = fut.result(timeout=wait_s)
                dt = time.perf_counter() - t0
                log(f"{stage_name}: {dt:.2f}s")
                stage_key = {
                    "Rewrite": "rewrite",
                    "Structural inference": "structural",
                    "Probe inference": "probe",
                    "Index selection": "selection",
                }.get(stage_name)
                if stage_key:
                    _record_timing(stage_key, dt)
                return out
            except FutureTimeoutError:
                fut.cancel()
                msg = f"{stage_name}: timed out after {wait_s:.0f}s"
                if strict:
                    raise RuntimeError(msg)
                log(f"{msg}; using fallback")
                return fallback
            except Exception as e:
                msg = f"{stage_name}: failed ({type(e).__name__}: {_redact_console_error_text(e)})"
                if strict:
                    raise RuntimeError(msg)
                log(f"{msg}; using fallback")
                return fallback

        t_total0 = time.perf_counter()
        log("Start: received prompt")

        # Refuse to run when startup checks recorded errors.
        if STARTUP_PREFLIGHT_ERRORS:
            log("Startup preflight failed:")
            for e in STARTUP_PREFLIGHT_ERRORS:
                log(f"- {_redact_console_error_text(e)}")
            return _build_ui_payload(
                console_text="\n".join(logs),
                row_defs=[],
                selected_tags=[],
                suggested_prompt_text="Error: startup preflight failed. Check console details.",
            )

        prompt_in = (user_prompt or "").strip()
        if not prompt_in:
            return _build_ui_payload(
                console_text="Error: empty prompt",
                row_defs=[],
                selected_tags=[],
                suggested_prompt_text='Enter a prompt and click "Run".',
            )
        log("Input:")
        log(prompt_in)
        log("")

        # Resolve per-stage model names: stage-specific env var, else the
        # default env var, else the built-in default constant.
        default_or_model = (
            os.environ.get("PSQ_DEFAULT_OPENROUTER_MODEL", "")
            or DEFAULT_STAGE_OPENROUTER_MODEL
        ).strip() or DEFAULT_STAGE_OPENROUTER_MODEL
        default_fallback_model = (
            os.environ.get("PSQ_DEFAULT_OPENROUTER_FALLBACK_MODEL", "")
            or DEFAULT_STAGE_OPENROUTER_FALLBACK_MODEL
        ).strip() or DEFAULT_STAGE_OPENROUTER_FALLBACK_MODEL
        struct_model_cfg = (
            os.environ.get("PSQ_STRUCT_OPENROUTER_MODEL", "") or ""
        ).strip() or default_or_model
        probe_model_cfg = (
            os.environ.get("PSQ_PROBE_OPENROUTER_MODEL", "") or ""
        ).strip() or default_or_model
        select_model_cfg = (
            os.environ.get("PSQ_SELECT_OPENROUTER_MODEL", "") or ""
        ).strip() or default_or_model
        struct_fallback_model_cfg = (
            os.environ.get("PSQ_STRUCT_OPENROUTER_FALLBACK_MODEL", "") or ""
        ).strip() or default_fallback_model
        probe_fallback_model_cfg = (
            os.environ.get("PSQ_PROBE_OPENROUTER_FALLBACK_MODEL", "") or ""
        ).strip() or default_fallback_model
        select_fallback_model_cfg = (
            os.environ.get("PSQ_SELECT_OPENROUTER_FALLBACK_MODEL", "") or ""
        ).strip() or default_fallback_model
        # A fallback identical to the primary model is cleared.
        if struct_fallback_model_cfg == struct_model_cfg:
            struct_fallback_model_cfg = ""
        if probe_fallback_model_cfg == probe_model_cfg:
            probe_fallback_model_cfg = ""
        if select_fallback_model_cfg == select_model_cfg:
            select_fallback_model_cfg = ""

        rewrite_source_cfg = _get_rewrite_source()
        t5_model_dir_cfg = (
            os.environ.get(
                "PSQ_T5_REWRITE_MODEL_DIR",
                "models/finetune/t5-rewrite-n30best-20260508/checkpoint-18000",
            )
            or "models/finetune/t5-rewrite-n30best-20260508/checkpoint-18000"
        ).strip()
        t5_num_beams_cfg = _env_int("PSQ_T5_REWRITE_NUM_BEAMS", 4, minimum=1)
        t5_max_new_tokens_cfg = _env_int("PSQ_T5_REWRITE_MAX_NEW_TOKENS", 128, minimum=8)
        log(
            "Runtime config: "
            f"retrieval_global_k={retrieval_global_k} "
            f"retrieval_per_phrase_k={retrieval_per_phrase_k} "
            f"retrieval_per_phrase_final_k={retrieval_per_phrase_final_k} "
            f"retrieval_exact_ngram_max={retrieval_exact_ngram_max} "
            f"selection_mode={selection_mode} "
            f"selection_chunk_size={selection_chunk_size} "
            f"selection_per_phrase_k={selection_per_phrase_k} "
            f"rewrite_source={rewrite_source_cfg} "
            f"min_tag_count={_get_min_tag_count()} "
            f"select_timeout_s={stage3_select_timeout_s:.0f} "
            f"select_retry_timeout_s={stage3_select_retry_timeout_s:.0f} "
            f"select_fast_retries=disabled(configured={stage3_fast_retry_count}) "
            f"struct_model={struct_model_cfg} "
            f"probe_model={probe_model_cfg} "
            f"select_model={select_model_cfg}"
        )
        if rewrite_source_cfg == "t5":
            log(
                "Rewrite config: "
                f"t5_model_dir={t5_model_dir_cfg} "
                f"t5_num_beams={t5_num_beams_cfg} "
                f"t5_max_new_tokens={t5_max_new_tokens_cfg}"
            )
        log("")

        # ---- Preprocess: heuristic tags and exact query phrases from the prompt.
        t0 = time.perf_counter()
        min_tag_count = _get_min_tag_count()
        user_tags_raw = extract_user_provided_tags_upto_3_words(prompt_in)
        user_tags, removed_user_low = _filter_min_count_tags(user_tags_raw, min_tag_count)
        user_tags, removed_user_excluded = _filter_excluded_recommendation_tags(user_tags)
        exact_query_phrases = extract_exact_tag_query_phrases(
            prompt_in,
            get_tag_counts(),
            get_alias2tags(),
            min_tag_count=min_tag_count,
            max_ngram=max(0, retrieval_exact_ngram_max),
        )
        exact_query_phrases, removed_exact_excluded = _filter_excluded_recommendation_tags(exact_query_phrases)
        dt = time.perf_counter()-t0
        _record_timing("preprocess", dt)
        log(f"Preprocess (user tag extraction): {dt:.2f}s")
        log("Heuristically extracted user tags:")
        if user_tags:
            log(", ".join(user_tags))
        else:
            log("(none)")
        if removed_user_low:
            log(
                f"Filtered {len(removed_user_low)} low-frequency user tags "
                f"(<{min_tag_count}): {', '.join(removed_user_low)}"
            )
        if removed_user_excluded:
            log(
                f"Filtered {len(removed_user_excluded)} excluded user tags: "
                f"{', '.join(removed_user_excluded)}"
            )
        if retrieval_exact_ngram_max > 0:
            log(f"Exact caption tag query phrases (1-{retrieval_exact_ngram_max} grams):")
        else:
            log("Exact caption tag query phrases: disabled")
        if exact_query_phrases:
            # Only the first 40 phrases are printed.
            shown = ", ".join(exact_query_phrases[:40])
            log(shown + (" ..." if len(exact_query_phrases) > 40 else ""))
        else:
            log("(none)")
        if removed_exact_excluded:
            log(
                f"Filtered {len(removed_exact_excluded)} excluded exact query phrases: "
                f"{', '.join(removed_exact_excluded)}"
            )
        log("")

        # ---- Step 1: rewrite / structural / probe, submitted to one pool.
        rewrite_prefilled = (rewrite_override or "").strip()
        if rewrite_prefilled:
            log("Step 1: structural inference + probe (rewrite already prepared)")
        else:
            log("Step 1: rewrite + structural inference + probe (concurrent)")
        max_workers = 3 if enable_probe_tags else 2
        ex = ThreadPoolExecutor(max_workers=max_workers)
        try:
            fut_rewrite = (
                ex.submit(_rewrite_prompt, prompt_in, log)
                if not rewrite_prefilled else None
            )
            fut_struct = ex.submit(
                llm_infer_structural_tags,
                prompt_in,
                log=log,
                model_override=struct_model_cfg,
                fallback_model_override=(struct_fallback_model_cfg or None),
            )
            fut_probe = (
                ex.submit(
                    llm_infer_probe_tags,
                    prompt_in,
                    log=log,
                    model_override=probe_model_cfg,
                    fallback_model_override=(probe_fallback_model_cfg or None),
                )
                if enable_probe_tags
                else None
            )
            if rewrite_prefilled:
                rewritten = rewrite_prefilled
                _record_timing("rewrite", 0.0)
            else:
                # Rewrite is mandatory: strict=True turns failure into an error.
                rewritten = _future_with_timeout(
                    fut_rewrite,
                    stage1_rewrite_timeout_s,
                    "Rewrite",
                    "",
                    strict=True,
                )
            structural_tags = _future_with_timeout(
                fut_struct, stage1_struct_timeout_s, "Structural inference", []
            )
            probe_tags = (
                _future_with_timeout(fut_probe, stage1_probe_timeout_s, "Probe inference", [])
                if fut_probe else []
            )
        finally:
            # Do not block the request on still-running workers.
            ex.shutdown(wait=False, cancel_futures=True)
        structural_tags, removed_struct_low = _filter_min_count_tags(structural_tags, min_tag_count)
        probe_tags, removed_probe_low = _filter_min_count_tags(probe_tags, min_tag_count)
        structural_tags, removed_struct_excluded = _filter_excluded_recommendation_tags(structural_tags)
        probe_tags, removed_probe_excluded = _filter_excluded_recommendation_tags(probe_tags)
        if removed_struct_low:
            log(
                f"Filtered {len(removed_struct_low)} low-frequency structural tags "
                f"(<{min_tag_count}): {', '.join(removed_struct_low)}"
            )
        if removed_probe_low:
            log(
                f"Filtered {len(removed_probe_low)} low-frequency probe tags "
                f"(<{min_tag_count}): {', '.join(removed_probe_low)}"
            )
        if removed_struct_excluded:
            log(
                f"Filtered {len(removed_struct_excluded)} excluded structural tags: "
                f"{', '.join(removed_struct_excluded)}"
            )
        if removed_probe_excluded:
            log(
                f"Filtered {len(removed_probe_excluded)} excluded probe tags: "
                f"{', '.join(removed_probe_excluded)}"
            )
        if not rewritten:
            raise RuntimeError("Rewrite: empty output")
        log("Rewrite:")
        log(rewritten if rewritten else "(empty)")
        log("")
        rewrite_for_retrieval = rewritten
        # Order-preserving de-duplication of user tags + exact phrases.
        retrieval_query_hints = list(dict.fromkeys((user_tags or []) + (exact_query_phrases or [])))
        if retrieval_query_hints:
            # keep them separate in logs, but allow them to help retrieval
            rewrite_for_retrieval = (rewrite_for_retrieval + ", " + ", ".join(retrieval_query_hints)).strip(", ").strip()

        # ---- Step 2: candidate retrieval; any failure degrades to no candidates.
        log("Step 2: Prompt Squirrel retrieval (hidden)")
        try:
            t0 = time.perf_counter()
            retrieval_context_tags = list(dict.fromkeys((structural_tags or []) + (probe_tags or [])))
            rewrite_phrases = [p.strip() for p in (rewrite_for_retrieval or "").split(",") if p.strip()]
            retrieval_result = psq_candidates_from_rewrite_phrases(
                rewrite_phrases=rewrite_phrases,
                allow_nsfw_tags=allow_nsfw_tags,
                context_tags=retrieval_context_tags,
                global_k=max(1, retrieval_global_k),
                per_phrase_k=max(1, retrieval_per_phrase_k),
                per_phrase_final_k=max(1, retrieval_per_phrase_final_k),
                min_tag_count=max(0, min_tag_count),
                verbose=verbose_retrieval,
            )
            # The retrieval call may return either (candidates, reports) or
            # just the candidates.
            if isinstance(retrieval_result, tuple):
                candidates, phrase_reports = retrieval_result
            else:
                candidates, phrase_reports = retrieval_result, []
            candidates, removed_candidate_excluded = _filter_excluded_candidates(candidates)
            if removed_candidate_excluded:
                log(
                    f"Filtered {len(removed_candidate_excluded)} excluded retrieved tags: "
                    f"{', '.join(removed_candidate_excluded[:25])}"
                    + (" ..." if len(removed_candidate_excluded) > 25 else "")
                )
            if selection_candidate_cap > 0 and len(candidates) > selection_candidate_cap:
                candidates = candidates[:selection_candidate_cap]
                log(f"Selection candidate cap applied: {selection_candidate_cap}")
            dt = time.perf_counter()-t0
            _record_timing("retrieval", dt)
            log(f"Retrieval: {dt:.2f}s")
            log(f"Retrieved {len(candidates)} candidate tags")
            if verbose_retrieval:
                log(f"Total unique candidates: {len(candidates)}")
                limit = None if verbose_retrieval_all else max(1, int(verbose_retrieval_limit))
                for report in phrase_reports:
                    phrase = report.get("normalized") or report.get("phrase") or ""
                    lookup = report.get("lookup") or ""
                    tfidf_vocab = report.get("tfidf_vocab")
                    log(f"Phrase: {phrase} (lookup={lookup}) tfidf_vocab={tfidf_vocab}")
                    rows = report.get("candidates", [])
                    shown = rows if limit is None else rows[:limit]
                    for row in shown:
                        tag = row.get("tag")
                        alias_token = row.get("alias_token")
                        score_fasttext = row.get("score_fasttext")
                        score_context = row.get("score_context")
                        score_combined = row.get("score_combined")
                        count = row.get("count")
                        alias_part = ""
                        if alias_token and alias_token != tag:
                            alias_part = f" [alias_token={alias_token}]"
                        # Numeric scores print with 3 decimals; anything else as-is.
                        fasttext_str = (
                            f"{score_fasttext:.3f}" if isinstance(score_fasttext, (int, float)) else score_fasttext
                        )
                        if score_context is None:
                            context_str = "None"
                        else:
                            context_str = (
                                f"{score_context:.3f}" if isinstance(score_context, (int, float)) else score_context
                            )
                        combined_str = (
                            f"{score_combined:.3f}" if isinstance(score_combined, (int, float)) else score_combined
                        )
                        log(
                            f" {tag}{alias_part} | fasttext={fasttext_str} context={context_str} "
                            f"combined={combined_str} count={count}"
                        )
                    if limit is not None and len(rows) > limit:
                        log(f" ... ({len(rows) - limit} more)")
        except Exception as e:
            log(f"Retrieval fallback: {type(e).__name__}: {_redact_console_error_text(e)}")
            candidates = []
        retrieved_candidate_tags = list(
            dict.fromkeys(
                _norm_tag_for_lookup(c.tag)
                for c in (candidates or [])
                if getattr(c, "tag", None)
            )
        )

        # ---- Step 3: LLM picks candidate indices.
        log("Step 3: LLM index selection (uses rewrite + structural/probe context)")
        selection_query = _build_selection_query(
            prompt_in=prompt_in,
            rewritten=rewritten,
            structural_tags=structural_tags,
            probe_tags=probe_tags,
        )
        select_model_override = select_model_cfg.strip() or None
        select_fallback_model_override = select_fallback_model_cfg.strip() or None
        picked_indices = None
        last_stage3_error: Exception | None = None
        # Same-model retries are disabled; failover happens inside Stage3 calls.
        stage3_attempts = 1
        for attempt_i in range(stage3_attempts):
            timeout_s = stage3_select_timeout_s if attempt_i == 0 else stage3_select_retry_timeout_s
            if attempt_i > 0:
                log(
                    f"Index selection: fast retry {attempt_i}/{stage3_fast_retry_count} "
                    f"(timeout={timeout_s:.0f}s)"
                )
            ex = ThreadPoolExecutor(max_workers=1)
            try:
                fut_sel = ex.submit(
                    llm_select_indices,
                    query_text=selection_query,
                    candidates=candidates,
                    max_pick=0,
                    log=log,
                    mode=selection_mode,
                    chunk_size=max(1, selection_chunk_size),
                    per_phrase_k=max(1, selection_per_phrase_k),
                    model_override=select_model_override,
                    fallback_model_override=select_fallback_model_override,
                )
                picked_indices = _future_with_timeout(
                    fut_sel,
                    timeout_s,
                    "Index selection",
                    [],
                    strict=True,
                )
                last_stage3_error = None
                break
            except Exception as e:
                last_stage3_error = e
                log(f"Index selection attempt {attempt_i + 1} failed: {_redact_console_error_text(e)}")
            finally:
                ex.shutdown(wait=False, cancel_futures=True)
        if picked_indices is None:
            raise RuntimeError(
                f"Index selection failed after {stage3_attempts} attempt(s): {last_stage3_error}"
            )
        selection_selected_tags = [candidates[i].tag for i in picked_indices] if picked_indices else []
        selection_selected_tags, removed_stage3_low = _filter_min_count_tags(selection_selected_tags, min_tag_count)
        if removed_stage3_low:
            log(
                f" Filtered {len(removed_stage3_low)} low-frequency stage3 tags "
                f"(<{min_tag_count}): {', '.join(removed_stage3_low)}"
            )
        selected_tags = list(selection_selected_tags)
        if structural_tags:
            # Add structural tags that aren't already selected
            existing = {t for t in selected_tags}
            new_structural = [t for t in structural_tags if t not in existing]
            selected_tags.extend(new_structural)
            log(f" Added {len(new_structural)} structural tags: {', '.join(new_structural)}")
        else:
            log(" No structural tags inferred")
        if probe_tags:
            existing = {t for t in selected_tags}
            new_probe = [t for t in probe_tags if t not in existing]
            selected_tags.extend(new_probe)
            log(f" Added {len(new_probe)} probe tags: {', '.join(new_probe)}")
        elif enable_probe_tags:
            log(" No probe tags inferred")
        selected_tags, removed_excluded_direct = _filter_excluded_recommendation_tags(selected_tags)
        if removed_excluded_direct:
            log(f" Removed {len(removed_excluded_direct)} excluded tags: {', '.join(removed_excluded_direct)}")
        # Snapshot of the directly chosen tags, taken before implications are added.
        direct_selected_tags = list(dict.fromkeys(selected_tags))

        # ---- Step 3c: add tags implied by the selected ones.
        log("Step 3c: Expand via tag implications")
        t0 = time.perf_counter()
        tag_set = set(selected_tags)
        # NOTE(review): `expanded` is not read again in this function.
        expanded, implied_only = expand_tags_via_implications(tag_set)
        dt = time.perf_counter()-t0
        _record_timing("implication_expansion", dt)
        log(f"Implication expansion: {dt:.2f}s")
        implied_selected_tags = sorted(implied_only) if implied_only else []
        if implied_only:
            implied_added = sorted(implied_only)
            implied_added, removed_implied_low = _filter_min_count_tags(implied_added, min_tag_count)
            implied_selected_tags = list(implied_added)
            if implied_added:
                selected_tags.extend(implied_added)
                log(f" Added {len(implied_added)} implied tags: {', '.join(implied_added)}")
            if removed_implied_low:
                log(
                    f" Filtered {len(removed_implied_low)} low-frequency implied tags "
                    f"(<{min_tag_count}): {', '.join(removed_implied_low)}"
                )
        else:
            log(" No additional implied tags")
        selected_tags, removed_excluded_implied = _filter_excluded_recommendation_tags(selected_tags)
        implied_selected_tags = [
            t for t in implied_selected_tags if not _is_excluded_recommendation_tag(t)
        ]
        if removed_excluded_implied:
            log(
                f" Removed {len(removed_excluded_implied)} excluded tags after implications: "
                f"{', '.join(removed_excluded_implied)}"
            )

        # ---- Step 4: compose the final prompt (timed).
        log("Step 4: Compose final prompt")
        t0 = time.perf_counter()
        # NOTE(review): `final_prompt` is not read again in this function; the
        # call is kept so the timing and any effects of the helper are unchanged.
        final_prompt = compose_final_prompt(rewritten, selected_tags)
        dt = time.perf_counter()-t0
        _record_timing("prompt_composition", dt)
        log(f"Prompt composition: {dt:.2f}s")

        # ---- Step 5: build the ranked rows shown in the UI.
        log("Step 5: Build ranked group/category display")
        t0 = time.perf_counter()
        seed_terms = []
        seed_terms.extend(user_tags)
        seed_terms.extend([p.strip() for p in (rewritten or "").split(",") if p.strip()])
        seed_terms.extend(structural_tags or [])
        seed_terms.extend(probe_tags or [])
        seed_terms.extend(selected_tags)
        seed_terms = list(dict.fromkeys(seed_terms))
        active_selected_tags = list(dict.fromkeys(selected_tags))
        structural_set = {_norm_tag_for_lookup(t) for t in (structural_tags or []) if t}
        probe_set = {_norm_tag_for_lookup(t) for t in (probe_tags or []) if t}
        implied_set = {_norm_tag_for_lookup(t) for t in (implied_selected_tags or []) if t}
        rewrite_set = {
            _norm_tag_for_lookup(t)
            for t in (list(user_tags or []) + [p.strip() for p in (rewritten or "").split(",") if p.strip()])
            if t
        }
        selection_set = {_norm_tag_for_lookup(t) for t in (selection_selected_tags or []) if t}
        # Assign each active tag one origin; the first matching source in the
        # chain below wins (structural, probe, rewrite, selection, implied).
        tag_selection_origins: Dict[str, str] = {}
        for tag in active_selected_tags:
            tag_norm = _norm_tag_for_lookup(tag)
            if tag_norm in structural_set:
                origin = "structural"
            elif tag_norm in probe_set:
                origin = "probe"
            elif tag_norm in rewrite_set:
                origin = "rewrite"
            elif tag_norm in selection_set:
                origin = "selection"
            elif tag_norm in implied_set:
                origin = "implied"
            else:
                # Unknown/fallback tags use selection color.
                origin = "selection"
            tag_selection_origins[tag] = origin
            # Also index by normalized form so later lookups succeed either way.
            if tag_norm and tag_norm != tag:
                tag_selection_origins[tag_norm] = origin
        direct_tags_for_implied = list(
            dict.fromkeys(_norm_tag_for_lookup(t) for t in (direct_selected_tags or []) if t)
        )
        # Sort direct tags by source rank, breaking ties by original position.
        direct_tags_for_implied_idx = {t: i for i, t in enumerate(direct_tags_for_implied)}
        direct_tags_for_implied.sort(
            key=lambda t: (
                _selection_source_rank(tag_selection_origins.get(t, "selection")),
                direct_tags_for_implied_idx.get(t, 10**9),
            )
        )
        implied_parent_map = _build_implied_parent_map(
            direct_tags_ordered=direct_tags_for_implied,
            implied_tags=implied_selected_tags,
        )
        toggle_rows = _build_toggle_rows(
            seed_terms=seed_terms,
            selected_tags=active_selected_tags,
            retrieved_candidate_tags=retrieved_candidate_tags,
            tag_selection_origins=tag_selection_origins,
            implied_parent_map=implied_parent_map,
            top_groups=max(1, int(display_top_groups)),
            top_tags_per_group=max(1, int(display_top_tags_per_group)),
            group_rank_top_k=max(1, int(display_rank_top_k)),
        )
        dt = time.perf_counter()-t0
        _record_timing("group_display", dt)
        log(f"Ranked group display: {dt:.2f}s ({len(toggle_rows)} rows)")
        log(
            _build_display_audit_line(
                toggle_rows,
                active_selected_tags=active_selected_tags,
                direct_selected_tags=direct_selected_tags,
                implied_selected_tags=implied_selected_tags,
            )
        )
        for idx, row in enumerate(toggle_rows[: max(0, int(display_max_rows_default))]):
            tags_preview = ", ".join(row.get("tags", []))
            log(f"UI Row {idx}: {row.get('label', '')} :: {tags_preview}")

        # ---- Wrap-up: timing summary, provider diagnostics, timing log.
        total_dt = time.perf_counter()-t_total0
        _emit_timing_summary(total_dt)
        if upstream_throttle_hits or credit_or_quota_hits:
            log(
                "Provider Diagnostics: "
                f"upstream_throttle_hits={upstream_throttle_hits} "
                f"credit_or_quota_hits={credit_or_quota_hits}"
            )
        if not active_selected_tags and upstream_throttle_hits > 0:
            log(
                "Provider Diagnostics: output likely degraded by upstream throttling "
                "(selection context may be incomplete)."
            )
        _append_timing_jsonl(total_dt)
        log("Done: final prompt ready")
        return _build_ui_payload(
            console_text="\n".join(logs),
            row_defs=toggle_rows,
            selected_tags=active_selected_tags,
        )
    except Exception as e:
        # Top-level boundary for the UI callback: report instead of raising.
        log(f"Error: {type(e).__name__}: {_redact_console_error_text(e)}")
        return _build_ui_payload(
            console_text="\n".join(logs),
            row_defs=[],
            selected_tags=[],
            suggested_prompt_text=_format_user_facing_error(e),
        )
# Build the Gradio UI. Profile marks bracket the construction so its duration
# shows up in the startup profile.
_startup_profile_mark("ui.blocks_build_begin")
with gr.Blocks(css=css, js=client_js) as app:
    # --- Top row: prompt input / suggested prompt on the left, mascot on the right.
    with gr.Row():
        with gr.Column(scale=3, elem_classes=["prompt-col"]):
            gr.Markdown(
                "Retrieval-augmented system for mapping unstructured input to a controlled "
                "vocabulary (demo: image tags).<br>"
                "See 'How Prompt Squirrel Works'.",
                elem_classes=["run-hint", "system-description"],
            )
            with gr.Group(elem_classes=["prompt-card"]):
                # Starts non-interactive; the app.load handler below enables it.
                image_tags = gr.Textbox(
                    label="Describe your image here, then click 'Run'.",
                    placeholder=DEFAULT_PROMPT_EXAMPLE,
                    lines=1,
                    interactive=False,
                    elem_classes=["enter-prompt-box"],
                )
            with gr.Group(elem_classes=["prompt-card", "suggested-prompt-card"]):
                suggested_prompt = gr.Textbox(
                    label=SUGGESTED_PROMPT_LABEL_READY,
                    lines=2,
                    interactive=False,
                    show_copy_button=True,
                    placeholder='Suggested prompt will appear here after you click "Run".',
                    elem_classes=["suggested-prompt-box"],
                )
        with gr.Column(scale=1):
            _mascot_pil = _load_mascot_image()
            # Fall back to a text placeholder when the mascot image is missing.
            if _mascot_pil is not None:
                mascot_img = gr.Image(
                    value=_mascot_pil,
                    show_label=False,
                    interactive=False,
                    height=240,
                    elem_id="mascot"
                )
            else:
                mascot_img = gr.Markdown("`(mascot image unavailable)`")
            # Hidden and disabled at start; `image_tags.change` updates it via
            # `_update_run_button_visibility`.
            submit_button = gr.Button("Run", variant="primary", visible=False, interactive=False)
            gr.Markdown("Typical runtime: up to ~20 seconds.", elem_classes=["mascot-runtime-hint"])

    # --- Session state shared between callbacks.
    last_run_prompt_state = gr.State("")
    selected_tags_state = gr.State([])
    rows_dirty_state = gr.State(False)
    row_defs_state = gr.State([])
    row_values_state = gr.State([])
    rewrite_state = gr.State("")
    toggle_instruction = gr.Markdown(
        "Click tag buttons to add or remove tags from the suggested prompt.",
        elem_classes=["row-instruction"],
        visible=False,
    )

    # --- A fixed pool of initially hidden rows (heading + checkbox group each);
    # both lists are part of the run outputs so callbacks can update them.
    row_headers: List[gr.Markdown] = []
    row_checkboxes: List[gr.CheckboxGroup] = []
    for _ in range(display_max_rows_default):
        with gr.Row():
            with gr.Column(scale=2, min_width=170):
                row_headers.append(gr.Markdown(value="", visible=False, elem_classes=["row-heading"]))
            with gr.Column(scale=10):
                row_checkboxes.append(
                    gr.CheckboxGroup(
                        choices=[],
                        value=[],
                        visible=False,
                        interactive=True,
                        container=False,
                        elem_classes=["lego-tags"],
                    )
                )

    # --- Colour legend and the "Rebuild Rows" button.
    with gr.Row():
        with gr.Column(scale=10):
            gr.HTML(
                """
<div class="source-legend">
<span class="legend-title">Legend:</span>
<span class="chip rewrite">Rewrite phrase</span>
<span class="chip selection">General selection</span>
<span class="chip probe">Probe query</span>
<span class="chip structural">Structural query</span>
<span class="chip implied">Implied</span>
<span class="chip user">User-toggled</span>
<span class="chip unselected">Unselected</span>
</div>
"""
            )
        with gr.Column(scale=2, min_width=180):
            rebuild_rows_button = gr.Button(
                "Rebuild Rows",
                variant="primary",
                visible=False,
                interactive=False,
            )

    # --- Display settings passed to the pipeline and to row rebuilding.
    with gr.Accordion("Display Settings", open=False):
        with gr.Row():
            display_top_groups = gr.Number(
                value=display_top_groups_default,
                precision=0,
                label="Rows (Top Groups/Categories)",
                minimum=1,
            )
            display_top_tags_per_group = gr.Number(
                value=display_top_tags_per_group_default,
                precision=0,
                label="Top Tags Shown Per Row",
                minimum=1,
            )
            display_rank_top_k = gr.Number(
                value=display_rank_top_k_default,
                precision=0,
                label="Top Tags Used for Row Ranking",
                minimum=1,
            )

    # --- Console: receives the pipeline log text.
    with gr.Accordion("Console", open=False):
        console = gr.Textbox(
            label="Console",
            lines=10,
            interactive=False,
            placeholder="Progress logs will appear here.",
            elem_id="psq-console",
        )

    # --- About section. When the docs contain an architecture slot, the
    # diagram HTML is placed between the two markdown halves.
    with gr.Accordion("How Prompt Squirrel Works", open=False):
        _about_md = _load_about_docs_markdown()
        _about_before, _about_after, _has_arch_slot = _split_about_docs_for_diagram(_about_md)
        if _has_arch_slot:
            if _about_before:
                gr.Markdown(
                    _about_before,
                    elem_id="about-docs",
                    elem_classes=["about-docs"],
                )
            gr.HTML(
                _build_arch_diagram_html(),
                elem_classes=["about-docs"],
            )
            if _about_after:
                gr.Markdown(
                    _about_after,
                    elem_classes=["about-docs"],
                )
        else:
            gr.Markdown(
                _about_md,
                elem_id="about-docs",
                elem_classes=["about-docs"],
            )

    # --- Payload textboxes: created with visible=True and the "psq-hidden" class.
    # NOTE(review): presumably hidden through CSS so the elements stay in the
    # DOM for client-side JS - confirm against the stylesheet.
    tooltip_map_payload = gr.Textbox(
        value="{}",
        visible=True,
        interactive=False,
        container=False,
        elem_id="psq-tooltip-map",
        elem_classes=["psq-hidden"],
    )
    client_startup_profile_payload = gr.Textbox(
        value="",
        visible=True,
        interactive=False,
        container=False,
        elem_id="psq-client-startup-profile",
        elem_classes=["psq-hidden"],
    )
    client_enable_input_payload = gr.Textbox(
        value="",
        visible=True,
        interactive=False,
        container=False,
        elem_id="psq-client-enable-input-profile",
        elem_classes=["psq-hidden"],
    )

    # Output components updated by a pipeline run, in the order the callbacks
    # must return them.
    run_outputs = [
        console,
        toggle_instruction,
        tooltip_map_payload,
        suggested_prompt,
        selected_tags_state,
        rows_dirty_state,
        rebuild_rows_button,
        row_defs_state,
        row_values_state,
        *row_headers,
        *row_checkboxes,
    ]
    run_outputs_with_rewrite = [*run_outputs, rewrite_state]

    # --- Event wiring.
    image_tags.change(
        _update_run_button_visibility,
        inputs=[image_tags, last_run_prompt_state],
        outputs=[submit_button],
        queue=False,
        show_progress="hidden",
    )
    # On page load: enable the prompt box; the JS snippet supplies timing data.
    app.load(
        _enable_prompt_input,
        inputs=[client_enable_input_payload],
        outputs=[image_tags],
        js=client_enable_input_profile_js,
        queue=False,
        show_progress="hidden",
    )
    # On page load: log the client-side startup profile.
    app.load(
        _log_client_startup_profile,
        inputs=[client_startup_profile_payload],
        outputs=[],
        js=client_startup_profile_js,
        queue=False,
        show_progress="hidden",
    )
    # Run chain (button click): mark run -> prepare UI -> rewrite preview ->
    # full pipeline. The pipeline receives the stored rewrite and the current
    # console text as its last two inputs.
    submit_button.click(
        _mark_run_triggered,
        inputs=[image_tags],
        outputs=[submit_button, last_run_prompt_state],
        queue=False,
        show_progress="hidden",
    ).then(
        _prepare_run_ui_with_rewrite_state,
        inputs=[],
        outputs=run_outputs_with_rewrite,
        queue=False,
        show_progress="hidden",
    ).then(
        _rewrite_preview_ui,
        inputs=[image_tags],
        outputs=run_outputs_with_rewrite,
        show_progress="hidden",
    ).then(
        rag_pipeline_ui,
        inputs=[image_tags, display_top_groups, display_top_tags_per_group, display_rank_top_k, rewrite_state, console],
        outputs=run_outputs,
        show_progress="minimal",
        show_progress_on=[mascot_img],
    )
    # Same run chain, triggered by submitting the textbox.
    image_tags.submit(
        _mark_run_triggered,
        inputs=[image_tags],
        outputs=[submit_button, last_run_prompt_state],
        queue=False,
        show_progress="hidden",
    ).then(
        _prepare_run_ui_with_rewrite_state,
        inputs=[],
        outputs=run_outputs_with_rewrite,
        queue=False,
        show_progress="hidden",
    ).then(
        _rewrite_preview_ui,
        inputs=[image_tags],
        outputs=run_outputs_with_rewrite,
        show_progress="hidden",
    ).then(
        rag_pipeline_ui,
        inputs=[image_tags, display_top_groups, display_top_tags_per_group, display_rank_top_k, rewrite_state, console],
        outputs=run_outputs,
        show_progress="minimal",
        show_progress_on=[mascot_img],
    )
    # One change handler per row. `i=idx` binds the row index at definition
    # time so each lambda keeps its own index.
    for idx, row_cb in enumerate(row_checkboxes):
        row_cb.change(
            fn=lambda changed_values, selected_state, rows_dirty, row_defs, row_values, i=idx: _on_toggle_row(
                i,
                changed_values,
                selected_state,
                rows_dirty,
                row_defs,
                row_values,
                display_max_rows_default,
            ),
            inputs=[row_cb, selected_tags_state, rows_dirty_state, row_defs_state, row_values_state],
            outputs=[selected_tags_state, rows_dirty_state, rebuild_rows_button, row_values_state, suggested_prompt, *row_checkboxes],
            queue=False,
            show_progress="hidden",
        )
    # Rebuild the rows from the current selection and display settings.
    rebuild_rows_button.click(
        _rebuild_rows_from_selected,
        inputs=[selected_tags_state, row_defs_state, row_values_state, display_top_groups, display_top_tags_per_group, display_rank_top_k],
        outputs=[
            toggle_instruction,
            tooltip_map_payload,
            suggested_prompt,
            selected_tags_state,
            rows_dirty_state,
            rebuild_rows_button,
            row_defs_state,
            row_values_state,
            *row_headers,
            *row_checkboxes,
        ],
        queue=False,
        show_progress="hidden",
    )
_startup_profile_mark("ui.blocks_build_end")
if __name__ == "__main__":
    # Script entry point: record the launch mark, then serve the queued app
    # with the mascot and docs directories exposed as allowed paths.
    _startup_profile_mark("launch.begin", ssr_mode=False)
    queued_app = app.queue()
    served_paths = [str(MASCOT_DIR), str(DOCS_DIR)]
    queued_app.launch(allowed_paths=served_paths, ssr_mode=False, pwa=False)