import gradio as gr
import os
import logging
import time
import json
import csv
from datetime import datetime
from functools import lru_cache
from PIL import Image
from pathlib import Path
from typing import Any, Dict, List, Set, Tuple
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError

from psq_rag.pipeline.preproc import extract_user_provided_tags_upto_3_words
from psq_rag.llm.rewrite import llm_rewrite_prompt
from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases, _norm_tag_for_lookup
from psq_rag.llm.select import llm_select_indices, llm_infer_structural_tags, llm_infer_probe_tags
from psq_rag.retrieval.state import (
    expand_tags_via_implications,
    get_tag_type_name,
    get_tag_implications,
    get_tag_counts,
)
from psq_rag.ui.group_ranked_display import rank_groups_from_tfidf, _load_enabled_groups


def _split_prompt_commas(s: str) -> List[str]:
    """Split a comma-separated prompt into stripped, non-empty parts."""
    return [p.strip() for p in (s or "").split(",") if p.strip()]


def _norm_for_dedupe(tag: str) -> str:
    """Return the key used to deduplicate prompt parts (lower-cased, then lookup-normalized)."""
    # your canonical form for lookup/dedupe
    return _norm_tag_for_lookup(tag.lower())


def compose_final_prompt(rewritten_prompt: str, selected_tags: List[str]) -> str:
    """Merge the rewritten prompt's comma-separated parts with *selected_tags*.

    Parts keep first-seen order and are deduplicated by ``_norm_for_dedupe``;
    the first spelling of each key wins. ``None`` inputs are treated as empty.
    Returns the parts joined with ", ".
    """
    parts = _split_prompt_commas(rewritten_prompt)
    parts.extend(selected_tags or [])
    seen: Set[str] = set()
    out: List[str] = []
    for p in parts:
        key = _norm_for_dedupe(p)
        if key in seen:
            continue
        seen.add(key)
        out.append(p)
    return ", ".join(out)


def _display_tag_text(tag: str) -> str:
    """Render a tag for display by replacing underscores with spaces."""
    return tag.replace("_", " ")


def _display_row_label(name: str) -> str:
    """Turn an internal row/group name into a title-cased human-readable label."""
    n = (name or "").strip()
    if not n:
        return ""
    if n == "selected_other":
        return "Selected (Other)"
    return n.replace("_", " ").title()


# Origins the UI understands. "implied" must be listed: the stylesheet has a
# dedicated `data-psq-origin="implied"` rule and `_rebuild_rows_from_selected`
# looks for "implied" in row metadata to rebuild parent/child ordering.
# Without it, implied origins would collapse to "selection" and that lookup
# could never match.
_KNOWN_SELECTION_ORIGINS = frozenset(
    {"rewrite", "selection", "probe", "structural", "implied", "user", "candidate"}
)


def _normalize_selection_origin(origin: str) -> str:
    """Strip/lower-case *origin*; anything unrecognized becomes "selection"."""
    o = (origin or "").strip().lower()
    if o in _KNOWN_SELECTION_ORIGINS:
        return o
    return "selection"


def _choice_label_with_source_meta(tag: str, *, origin: str, preselected: bool) -> str:
    """Return the checkbox label for *tag*.

    ``origin`` and ``preselected`` are accepted but intentionally unused.
    """
    # Keep labels plain to avoid frontend text/value desynchronization.
    return _display_tag_text(tag)


def _selection_source_rank(origin: str) -> int:
    """Sort rank for an origin: structural (0) < probe (1) < everything else (2)."""
    o = _normalize_selection_origin(origin)
    if o == "structural":
        return 0
    if o == "probe":
        return 1
    # Keep rewrite/user in the same priority band as general selection for row ordering.
    return 2


def _build_implied_parent_map(
    direct_tags_ordered: List[str],
    implied_tags: List[str],
) -> Dict[str, str]:
    """Map each implied tag to the first direct tag it is reachable from.

    Reachability follows ``get_tag_implications()`` edges (presumably
    tag -> tags it implies; TODO confirm). Direct tags are walked in the given
    order, so an earlier direct tag wins when several reach the same implied
    tag. Keys and values are normalized tags.
    """
    implied_set = {_norm_tag_for_lookup(t) for t in (implied_tags or []) if t}
    if not implied_set or not direct_tags_ordered:
        return {}
    impl = get_tag_implications()
    parent_by_implied: Dict[str, str] = {}
    for direct in direct_tags_ordered:
        d = _norm_tag_for_lookup(direct)
        if not d:
            continue
        # Depth-first walk; `seen` guards against cycles in the graph.
        stack = [d]
        seen = {d}
        while stack:
            t = stack.pop()
            for parent in impl.get(t, ()):
                p = _norm_tag_for_lookup(parent)
                if not p or p in seen:
                    continue
                seen.add(p)
                if p in implied_set and p not in parent_by_implied:
                    parent_by_implied[p] = d
                stack.append(p)
    return parent_by_implied


def _order_selected_tags_for_row(
    *,
    row_selected_tags: List[str],
    selected_index: Dict[str, int],
    tag_selection_origins: Dict[str, str],
    implied_parent_map: Dict[str, str],
) -> List[str]:
    """Order a row's selected tags so implied tags follow their parents.

    Non-implied ("base") tags are sorted by origin rank, selection index, then
    name. Each base tag is immediately followed by its implied children (sorted
    by selection index, then name). Implied tags whose parent is not in the row
    are appended last, sorted by their parent's rank/index, then their own.
    Returns normalized tags.
    """
    row_selected_norm = [_norm_tag_for_lookup(t) for t in (row_selected_tags or []) if t]
    implied_in_row = {t for t in row_selected_norm if t in implied_parent_map}
    base_tags = [t for t in row_selected_norm if t not in implied_in_row]
    base_tags.sort(
        key=lambda t: (
            _selection_source_rank(tag_selection_origins.get(t, "selection")),
            selected_index.get(t, 10**9),
            t,
        )
    )

    children_by_parent: Dict[str, List[str]] = {}
    for implied in implied_in_row:
        parent = implied_parent_map.get(implied)
        if parent:
            children_by_parent.setdefault(parent, []).append(implied)
    for children in children_by_parent.values():
        children.sort(key=lambda t: (selected_index.get(t, 10**9), t))

    ordered: List[str] = []
    emitted: Set[str] = set()
    for tag in base_tags:
        if tag in emitted:
            continue
        ordered.append(tag)
        emitted.add(tag)
        for child in children_by_parent.get(tag, []):
            if child not in emitted:
                ordered.append(child)
                emitted.add(child)

    # Implied tags whose parent is not shown in this row.
    remaining_implied = [t for t in row_selected_norm if t not in emitted]
    remaining_implied.sort(
        key=lambda t: (
            _selection_source_rank(tag_selection_origins.get(implied_parent_map.get(t, ""), "selection")),
            selected_index.get(implied_parent_map.get(t, ""), 10**9),
            selected_index.get(t, 10**9),
            t,
        )
    )
    for t in remaining_implied:
        if t not in emitted:
            ordered.append(t)
            emitted.add(t)
    return ordered


def _escape_prompt_tag(tag: str) -> str:
    """Format a tag for the prompt: underscores become spaces, parentheses are backslash-escaped."""
    return (
        tag.replace("_", " ")
        .replace("(", "\\(")
        .replace(")", "\\)")
    )


def _ordered_selected_for_prompt(selected: Set[str], row_defs: List[Dict[str, Any]]) -> List[str]:
    """Return the tags in *selected* in the order they appear across *row_defs*, deduplicated."""
    out: List[str] = []
    seen: Set[str] = set()
    for row in row_defs:
        for tag in row.get("tags", []):
            if tag in selected and tag not in seen:
                out.append(tag)
                seen.add(tag)
    return out


def _compose_toggle_prompt_text(selected_tags: List[str], row_defs: List[Dict[str, Any]]) -> str:
    """Build the prompt text from selected tags, ordered by row position and escaped."""
    selected = {t for t in (selected_tags or []) if t}
    ordered = _ordered_selected_for_prompt(selected, row_defs or [])
    return ", ".join(_escape_prompt_tag(t) for t in ordered)


def _is_artist_tag(tag: str) -> bool:
    """Return True if *tag* is typed "artist" or starts with ``by_`` after normalization."""
    t = _norm_tag_for_lookup(str(tag))
    if not t:
        return False
    # Keep a resilient fallback for malformed/missing tag typing metadata.
    return get_tag_type_name(t) == "artist" or t.startswith("by_")


@lru_cache(maxsize=1)
def _load_excluded_recommendation_tags() -> Set[str]:
    """Load tags marked ``category_status == "excluded"`` from the category registry CSV.

    Tries ``data/category_registry.csv``, then ``data/analysis/category_registry.csv``
    (relative to the working directory). The result is cached for the process
    lifetime, including the empty set returned when no file exists or it cannot
    be read. Callers must not mutate the returned set.
    """
    csv_path = Path("data/category_registry.csv")
    if not csv_path.exists():
        csv_path = Path("data/analysis/category_registry.csv")
    out: Set[str] = set()
    if not csv_path.exists():
        return out
    try:
        with csv_path.open("r", encoding="utf-8", newline="") as f:
            for row in csv.DictReader(f):
                tag = _norm_tag_for_lookup(str(row.get("tag") or ""))
                if not tag:
                    continue
                status = str(row.get("category_status") or "").strip().lower()
                if status == "excluded":
                    out.add(tag)
    except Exception:
        # Best-effort: fall back to "no exclusions", but surface the cause since
        # lru_cache will keep this empty result for the rest of the process.
        logging.warning(
            "Failed to load excluded recommendation tags from %s", csv_path, exc_info=True
        )
        return set()
    return out


def _is_excluded_recommendation_tag(tag: str) -> bool:
    """Return True if the normalized *tag* is in the excluded-recommendation registry."""
    t = _norm_tag_for_lookup(str(tag))
    if not t:
        return False
    return t in _load_excluded_recommendation_tags()


def _get_min_tag_count() -> int:
    """Read ``PSQ_MIN_TAG_COUNT`` (default 100), clamped to >= 0; 100 if unparsable."""
    try:
        return max(0, int(os.environ.get("PSQ_MIN_TAG_COUNT", "100")))
    except ValueError:
        return 100


def _filter_min_count_tags(tags: List[str], min_count: int) -> Tuple[List[str], List[str]]:
    """Drop tags whose ``get_tag_counts()`` count is below *min_count*.

    Returns ``(kept, removed)``: ``kept`` holds normalized, deduplicated tags in
    input order; ``removed`` is the sorted unique list of dropped tags. A
    non-positive *min_count* disables the count check, but tags are still
    normalized and deduplicated the same way.
    """
    if min_count <= 0:
        return _dedupe_norm_tags(tags), []
    tag_counts = get_tag_counts()
    keep: List[str] = []
    removed: List[str] = []
    seen: Set[str] = set()
    for raw in (tags or []):
        t = _norm_tag_for_lookup(str(raw))
        if not t:
            continue
        c = int(tag_counts.get(t, 0) or 0)
        if c < min_count:
            removed.append(t)
            continue
        if t in seen:
            continue
        seen.add(t)
        keep.append(t)
    return keep, sorted(set(removed))


def _filter_excluded_recommendation_tags(tags: List[str]) -> Tuple[List[str], List[str]]:
    """Drop tags listed as excluded in the category registry.

    Returns ``(kept, removed)`` with the same normalization/dedup contract as
    ``_filter_min_count_tags``. If no exclusions are loaded, tags are only
    normalized and deduplicated.
    """
    excluded = _load_excluded_recommendation_tags()
    if not excluded:
        return _dedupe_norm_tags(tags), []
    keep: List[str] = []
    removed: List[str] = []
    seen: Set[str] = set()
    for raw in (tags or []):
        t = _norm_tag_for_lookup(str(raw))
        if not t:
            continue
        if t in excluded:
            removed.append(t)
            continue
        if t in seen:
            continue
        seen.add(t)
        keep.append(t)
    return keep, sorted(set(removed))


def _dedupe_norm_tags(tags: List[str]) -> List[str]:
    """Normalize tags, dropping empties and duplicates while preserving first-seen order."""
    out: List[str] = []
    seen: Set[str] = set()
    for raw in (tags or []):
        t = _norm_tag_for_lookup(str(raw))
        if not t or t in seen:
            continue
        seen.add(t)
        out.append(t)
    return out


def _collect_visible_tags(row_defs: List[Dict[str, Any]]) -> Set[str]:
    """Return the set of normalized tags shown across all (dict) rows."""
    out: Set[str] = set()
    for row in (row_defs or []):
        for t in _dedupe_norm_tags(row.get("tags", []) if isinstance(row, dict) else []):
            out.add(t)
    return out


def _collect_selected_from_state(
    selected_tags_state: List[str],
    row_defs: List[Dict[str, Any]],
) -> List[str]:
    """Map the stored selected-tag list onto tags currently visible in *row_defs*.

    Returns normalized, deduplicated tags in state order; tags that are not
    visible in any row are dropped.
    """
    visible_tags = _collect_visible_tags(row_defs)
    if not visible_tags:
        return []
    selected: List[str] = []
    seen: Set[str] = set()
    visible_by_norm = {_norm_tag_for_lookup(t): t for t in visible_tags}
    for raw in (selected_tags_state or []):
        t = _norm_tag_for_lookup(str(raw))
        if not t:
            continue
        mapped = t if t in visible_tags else visible_by_norm.get(t)
        if not mapped or mapped in seen:
            continue
        seen.add(mapped)
        selected.append(mapped)
    return selected


def _collect_selected_from_row_values(
    row_defs: List[Dict[str, Any]],
    row_values_state: List[List[str]],
) -> List[str]:
    """Rebuild the selection from per-row checkbox values (``row_values_state[i]`` ↔ ``row_defs[i]``).

    Each value is matched to a row tag exactly, or via its normalized form.
    Returns deduplicated tags in row order, then value order.
    """
    selected: List[str] = []
    seen: Set[str] = set()
    values = list(row_values_state or [])
    for idx, row in enumerate(row_defs or []):
        row_tags = _dedupe_norm_tags(row.get("tags", []) if isinstance(row, dict) else [])
        if not row_tags:
            continue
        row_tag_set = set(row_tags)
        row_tag_by_norm = {_norm_tag_for_lookup(t): t for t in row_tags}
        raw_vals = values[idx] if 0 <= idx < len(values) else []
        for raw in (raw_vals or []):
            if raw in row_tag_set:
                if raw not in seen:
                    seen.add(raw)
                    selected.append(raw)
                continue
            raw_norm = _norm_tag_for_lookup(str(raw))
            mapped = row_tag_by_norm.get(raw_norm)
            if mapped and mapped not in seen:
                seen.add(mapped)
                selected.append(mapped)
    return selected


def _build_toggle_rows(
    *,
    seed_terms: List[str],
    selected_tags: List[str],
    retrieved_candidate_tags: List[str],
    tag_selection_origins: Dict[str, str],
    implied_parent_map: Dict[str, str],
    top_groups: int,
    top_tags_per_group: int,
    group_rank_top_k: int,
) -> List[Dict[str, Any]]:
    """Build row definitions for the tag-toggle UI.

    Row order:

    1. "Selected (Other)" — always emitted, possibly empty.
    2. One row per group returned by ``rank_groups_from_tfidf``.
    3. "Other (Retrieved)" — only when some retrieved tags belong to no enabled group.

    Each row is a dict with ``name``, ``label``, ``tags`` (normalized) and
    ``tag_meta`` (``{tag: {"origin", "preselected"}}``). Artist and excluded
    tags are never shown. A row keeps all of its selected tags even if that
    exceeds *top_tags_per_group*.
    """
    ranked_rows = rank_groups_from_tfidf(
        seed_terms=seed_terms,
        top_groups=max(1, int(top_groups)),
        top_tags_per_group=max(1, int(top_tags_per_group)),
        group_rank_top_k=max(1, int(group_rank_top_k)),
    )
    groups_map = _load_enabled_groups()
    selected_active = list(
        dict.fromkeys(
            _norm_tag_for_lookup(t)
            for t in (selected_tags or [])
            if t and not _is_artist_tag(t) and not _is_excluded_recommendation_tag(t)
        )
    )
    selected_active_set = set(selected_active)
    selected_index: Dict[str, int] = {t: i for i, t in enumerate(selected_active)}
    row_defs: List[Dict[str, Any]] = []

    enabled_group_tag_sets: Dict[str, Set[str]] = {
        name: {t for t in tags if not _is_artist_tag(t)} for name, tags in groups_map.items()
    }
    tags_in_any_enabled_group: Set[str] = set()
    for tag_set in enabled_group_tag_sets.values():
        tags_in_any_enabled_group.update(tag_set)

    displayed_group_names = [r.group_name for r in ranked_rows]
    displayed_group_tag_sets: Dict[str, Set[str]] = {
        name: enabled_group_tag_sets.get(name, set()) for name in displayed_group_names
    }
    tags_in_any_displayed_group: Set[str] = set()
    for tag_set in displayed_group_tag_sets.values():
        tags_in_any_displayed_group.update(tag_set)

    # Retrieved tags that belong to no enabled group feed the "Other (Retrieved)" row.
    retrieved_uncategorized_ranked = list(
        dict.fromkeys(
            _norm_tag_for_lookup(t)
            for t in (retrieved_candidate_tags or [])
            if t
            and not _is_artist_tag(t)
            and not _is_excluded_recommendation_tag(t)
            and _norm_tag_for_lookup(t) not in tags_in_any_enabled_group
        )
    )
    retrieved_other_row: Dict[str, Any] | None = None
    if retrieved_uncategorized_ranked:
        retrieved_uncategorized_set = set(retrieved_uncategorized_ranked)
        selected_in_retrieved_other_raw = [
            t for t in selected_active if t in retrieved_uncategorized_set
        ]
        selected_in_retrieved_other = _order_selected_tags_for_row(
            row_selected_tags=selected_in_retrieved_other_raw,
            selected_index=selected_index,
            tag_selection_origins=tag_selection_origins,
            implied_parent_map=implied_parent_map,
        )
        merged_retrieved_other = selected_in_retrieved_other + [
            t for t in retrieved_uncategorized_ranked if t not in selected_in_retrieved_other
        ]
        merged_retrieved_other = _dedupe_norm_tags(merged_retrieved_other)
        keep_n = max(max(1, int(top_tags_per_group)), len(selected_in_retrieved_other))
        merged_retrieved_other = merged_retrieved_other[:keep_n]
        retrieved_other_meta = {
            t: {
                "origin": _normalize_selection_origin(tag_selection_origins.get(t, "selection")),
                "preselected": t in selected_active_set,
            }
            for t in merged_retrieved_other
        }
        retrieved_other_row = {
            "name": "other_retrieved",
            "label": "Other (Retrieved)",
            "tags": merged_retrieved_other,
            "tag_meta": retrieved_other_meta,
        }

    # "Selected (Other)" should contain selected tags not already shown in any displayed row.
    # Include "Other (Retrieved)" in that displayed-row set to avoid duplicates across those rows.
    tags_in_displayed_rows = set(tags_in_any_displayed_group)
    if retrieved_other_row:
        tags_in_displayed_rows.update(retrieved_other_row.get("tags", []))
    selected_other_raw = [t for t in selected_active if t not in tags_in_displayed_rows]
    selected_other = _order_selected_tags_for_row(
        row_selected_tags=selected_other_raw,
        selected_index=selected_index,
        tag_selection_origins=tag_selection_origins,
        implied_parent_map=implied_parent_map,
    )
    selected_other = _dedupe_norm_tags(selected_other)
    selected_other_meta = {
        t: {
            "origin": _normalize_selection_origin(tag_selection_origins.get(t, "selection")),
            "preselected": True,
        }
        for t in selected_other
    }
    row_defs.append(
        {
            "name": "selected_other",
            "label": _display_row_label("selected_other"),
            "tags": selected_other,
            "tag_meta": selected_other_meta,
        }
    )

    for row in ranked_rows:
        group_name = row.group_name
        group_tag_set = displayed_group_tag_sets.get(group_name, set())
        selected_in_group_raw = [t for t in selected_active if t in group_tag_set]
        selected_in_group = _order_selected_tags_for_row(
            row_selected_tags=selected_in_group_raw,
            selected_index=selected_index,
            tag_selection_origins=tag_selection_origins,
            implied_parent_map=implied_parent_map,
        )
        ranked_tags = [
            _norm_tag_for_lookup(t)
            for t, _ in row.tags
            if t and not _is_artist_tag(t) and not _is_excluded_recommendation_tag(t)
        ]
        ranked_tags = _dedupe_norm_tags(ranked_tags)
        # Selected tags lead the row; ranked suggestions fill the remaining slots.
        merged = selected_in_group + [t for t in ranked_tags if t not in selected_in_group]
        merged = _dedupe_norm_tags(merged)
        keep_n = max(max(1, int(top_tags_per_group)), len(selected_in_group))
        merged = merged[:keep_n]
        tag_meta = {
            t: {
                "origin": _normalize_selection_origin(tag_selection_origins.get(t, "selection")),
                "preselected": t in selected_active_set,
            }
            for t in merged
        }
        row_defs.append(
            {
                "name": group_name,
                "label": _display_row_label(group_name),
                "tags": merged,
                "tag_meta": tag_meta,
            }
        )

    # Keep this row at the bottom so category/group rows remain contiguous.
    if retrieved_other_row:
        row_defs.append(retrieved_other_row)
    return row_defs


def _build_display_audit_line(
    row_defs: List[Dict[str, Any]],
    *,
    active_selected_tags: List[str],
    direct_selected_tags: List[str],
    implied_selected_tags: List[str],
) -> str:
    """Return a one-line JSON audit of where each displayed tag appears and why.

    For every tag in *row_defs*, it records the row labels it appears in and a
    sorted list of source markers: row kind plus membership in the
    active/direct/implied selections (artist tags excluded from those sets).
    """
    active_set = {
        _norm_tag_for_lookup(t) for t in (active_selected_tags or []) if t and not _is_artist_tag(t)
    }
    direct_set = {
        _norm_tag_for_lookup(t) for t in (direct_selected_tags or []) if t and not _is_artist_tag(t)
    }
    implied_set = {
        _norm_tag_for_lookup(t) for t in (implied_selected_tags or []) if t and not _is_artist_tag(t)
    }
    info_by_tag: Dict[str, Dict[str, Any]] = {}
    for row in row_defs or []:
        row_name = row.get("name", "")
        row_label = row.get("label", row_name)
        for tag in row.get("tags", []):
            rec = info_by_tag.setdefault(tag, {"rows": [], "sources": set()})
            rec["rows"].append(row_label)
            if row_name == "selected_other":
                rec["sources"].add("selected_other_row")
            elif row_name == "other_retrieved":
                rec["sources"].add("other_retrieved_row")
            else:
                rec["sources"].add("ranked_group_row")
            if tag in active_set:
                rec["sources"].add("selected_active")
            if tag in direct_set:
                rec["sources"].add("selected_direct")
            if tag in implied_set:
                rec["sources"].add("selected_implied")
    payload = {
        "n_tags": len(info_by_tag),
        "tags": [
            {
                "tag": tag,
                "rows": rec["rows"],
                "sources": sorted(rec["sources"]),
            }
            for tag, rec in sorted(info_by_tag.items())
        ],
    }
    return "Display Tag Audit: " + json.dumps(payload, ensure_ascii=True)


def _build_row_component_updates(
    row_defs: List[Dict[str, Any]],
    selected_tags: List[str],
    max_rows: int,
):
    """Produce Gradio updates for a fixed pool of *max_rows* header/checkbox components.

    Returns ``(prompt_text, row_values_state, header_updates, checkbox_updates)``.
    The two update lists always have exactly ``max(0, int(max_rows))`` entries;
    slots beyond the available rows are hidden and cleared.
    """
    # One normalized count for both slicing and iteration (tolerates float input).
    n_rows = max(0, int(max_rows))
    selected = {t for t in (selected_tags or []) if t}
    row_defs_ui = (row_defs or [])[:n_rows]
    row_values_state: List[List[str]] = []
    header_updates = []
    checkbox_updates = []
    for idx in range(n_rows):
        if idx >= len(row_defs_ui):
            header_updates.append(gr.update(value="", visible=False))
            checkbox_updates.append(gr.update(choices=[], value=[], visible=False))
            continue
        row = row_defs_ui[idx]
        tags = _dedupe_norm_tags(row.get("tags", []))
        values = [t for t in tags if t in selected]
        row_values_state.append(values)
        visible = bool(tags)
        header_updates.append(gr.update(value=row.get("label", ""), visible=visible))
        raw_tag_meta = row.get("tag_meta", {})
        tag_meta = raw_tag_meta if isinstance(raw_tag_meta, dict) else {}
        choices = []
        for t in tags:
            meta = tag_meta.get(t, {})
            if not isinstance(meta, dict):
                meta = {}
            origin = _normalize_selection_origin(str(meta.get("origin", "selection")))
            preselected = bool(meta.get("preselected", False))
            choices.append((_choice_label_with_source_meta(t, origin=origin, preselected=preselected), t))
        checkbox_updates.append(
            gr.update(
                choices=choices,
                value=values,
                visible=visible,
            )
        )
    prompt_text = _compose_toggle_prompt_text(list(selected), row_defs_ui)
    return prompt_text, row_values_state, header_updates, checkbox_updates


def _on_toggle_row(
    row_idx: int,
    changed_values: List[str],
    selected_tags_state: List[str],
    rows_dirty_state: bool,
    row_defs_state: List[Dict[str, Any]],
    row_values_state: List[List[str]],
    max_rows: int,
):
    """Handle a checkbox change in row *row_idx*.

    Returns ``[selected_tags, rows_dirty, <control update>, row_values_state,
    prompt_text, *checkbox_updates]`` with ``max(0, int(max_rows))`` checkbox
    updates. A no-op change (the row's selection is unchanged) returns
    ``gr.skip()`` for the control and all checkboxes. Otherwise it sets
    rows_dirty to True, makes the control visible and interactive (presumably
    a "rebuild rows" button; confirm against the UI wiring), and refreshes
    every row that shows a toggled tag.
    """
    n_rows = max(0, int(max_rows))
    row_defs_ui = (row_defs_state or [])[:n_rows]
    prev_values = list(row_values_state or [])
    selected_from_state = _collect_selected_from_state(selected_tags_state, row_defs_ui)
    selected_from_rows = _collect_selected_from_row_values(row_defs_ui, prev_values)
    # Prefer row-value state as source-of-truth (closest to visible UI), with selected-state as fallback.
    selected: Set[str] = set(selected_from_rows or selected_from_state)

    row = row_defs_ui[row_idx] if 0 <= row_idx < len(row_defs_ui) else {}
    row_tags = _dedupe_norm_tags(row.get("tags", []))
    row_tag_set = set(row_tags)
    row_tag_by_norm = {_norm_tag_for_lookup(t): t for t in row_tags}

    # Be tolerant to UI payload forms: canonical tag values, display labels, normalized variants,
    # and occasional single-string payloads from frontend events.
    if changed_values is None:
        changed_iter: List[Any] = []
    elif isinstance(changed_values, str):
        changed_iter = [changed_values]
    elif isinstance(changed_values, (list, tuple, set)):
        changed_iter = list(changed_values)
    else:
        changed_iter = [changed_values]

    new_set: Set[str] = set()
    for raw in changed_iter:
        if raw in row_tag_set:
            new_set.add(raw)
            continue
        raw_norm = _norm_tag_for_lookup(str(raw))
        mapped = row_tag_by_norm.get(raw_norm)
        if mapped:
            new_set.add(mapped)

    prev_row_selected = {t for t in row_tags if t in selected}
    # Ignore non-user/no-op events (e.g., programmatic value re-sets) deterministically.
    if new_set == prev_row_selected:
        prompt_text = _compose_toggle_prompt_text(sorted(selected), row_defs_ui)
        checkbox_updates = [gr.skip() for _ in range(n_rows)]
        return [sorted(selected), rows_dirty_state, gr.skip(), prev_values, prompt_text, *checkbox_updates]

    selected.difference_update(row_tag_set)
    selected.update(new_set)
    toggled_tags = prev_row_selected ^ new_set

    # A tag may appear in several rows; refresh every row that shows a toggled tag.
    new_row_values_state: List[List[str]] = []
    affected_rows: Set[int] = {row_idx}
    for idx, row_item in enumerate(row_defs_ui):
        tags = _dedupe_norm_tags(row_item.get("tags", []))
        values = [t for t in tags if t in selected]
        new_row_values_state.append(values)
        if toggled_tags and any(t in toggled_tags for t in tags):
            affected_rows.add(idx)

    checkbox_updates = []
    for idx in range(n_rows):
        if idx < len(row_defs_ui) and idx in affected_rows:
            checkbox_updates.append(gr.update(value=new_row_values_state[idx]))
        else:
            checkbox_updates.append(gr.skip())

    prompt_text = _compose_toggle_prompt_text(sorted(selected), row_defs_ui)
    return [
        sorted(selected),
        True,
        gr.update(visible=True, interactive=True),
        new_row_values_state,
        prompt_text,
        *checkbox_updates,
    ]


def _build_ui_payload(
    *,
    console_text: str,
    row_defs: List[Dict[str, Any]],
    selected_tags: List[str],
):
    """Assemble the output list for a completed run.

    Order: ``[console_text, rows-container visibility update, prompt_text,
    selected tags (as shown in rows), rows_dirty=False, hidden control update,
    row_defs, row_values_state, *header_updates, *checkbox_updates]``.
    Uses the module-level ``display_max_rows_default``.
    """
    prompt_text, row_values_state, header_updates, checkbox_updates = _build_row_component_updates(
        row_defs=row_defs,
        selected_tags=selected_tags,
        max_rows=display_max_rows_default,
    )
    # Selected tags that actually landed in a visible row, in row order.
    selected_ui: List[str] = []
    selected_ui_seen: Set[str] = set()
    for vals in row_values_state:
        for t in vals:
            if t in selected_ui_seen:
                continue
            selected_ui_seen.add(t)
            selected_ui.append(t)
    return [
        console_text,
        gr.update(visible=bool(row_defs)),
        prompt_text,
        selected_ui,
        False,
        gr.update(visible=False, interactive=False),
        row_defs,
        row_values_state,
        *header_updates,
        *checkbox_updates,
    ]


def _prepare_run_ui() -> List[Any]:
    """Output list shown while a run is in progress: placeholders, empty state, hidden rows.

    Mirrors the output order of ``_build_ui_payload``.
    """
    header_updates = [gr.update(value="", visible=False) for _ in range(display_max_rows_default)]
    checkbox_updates = [
        gr.update(choices=[], value=[], visible=False) for _ in range(display_max_rows_default)
    ]
    return [
        "Running...",
        gr.skip(),
        "Running... usually completes in about 20 seconds.",
        [],
        False,
        gr.update(visible=False, interactive=False),
        [],
        [],
        *header_updates,
        *checkbox_updates,
    ]


def _rebuild_rows_from_selected(
    selected_tags_state: List[str],
    row_defs_state: List[Dict[str, Any]],
    row_values_state: List[List[str]],
    display_top_groups: float,
    display_top_tags_per_group: float,
    display_rank_top_k: float,
):
    """Re-rank and rebuild the toggle rows using the current selection as seed terms.

    Origins are recovered from the existing rows' ``tag_meta``; selected tags
    with no recorded origin are treated as "user". Tags whose origin is
    "implied" are re-attached to the direct tags they are reachable from.
    Returns ``[rows-container update, prompt_text, selected tags (sorted),
    rows_dirty=False, hidden control update, row_defs, row_values_state,
    *header_updates, *checkbox_updates]``.
    """
    existing_rows = row_defs_state or []
    existing_values = list(row_values_state or [])
    selected_from_state = _collect_selected_from_state(selected_tags_state, existing_rows)
    selected_from_rows = _collect_selected_from_row_values(existing_rows, existing_values)
    # Rebuild source-of-truth is current row checkbox values; fall back only when unavailable.
    selected_seed = selected_from_rows if existing_values else selected_from_state
    selected_active = list(
        dict.fromkeys(
            _norm_tag_for_lookup(t)
            for t in selected_seed
            if t and not _is_artist_tag(t) and not _is_excluded_recommendation_tag(t)
        )
    )

    retrieved_candidate_tags: List[str] = []
    tag_selection_origins: Dict[str, str] = {}
    for row in existing_rows:
        row_tags = row.get("tags", []) if isinstance(row, dict) else []
        row_meta = row.get("tag_meta", {}) if isinstance(row, dict) else {}
        if not isinstance(row_meta, dict):
            row_meta = {}
        for t in row_tags:
            tn = _norm_tag_for_lookup(t)
            if not tn or _is_artist_tag(tn) or _is_excluded_recommendation_tag(tn):
                continue
            retrieved_candidate_tags.append(tn)
            if tn not in tag_selection_origins:
                meta = row_meta.get(t, {})
                if not isinstance(meta, dict):
                    meta = {}
                tag_selection_origins[tn] = _normalize_selection_origin(str(meta.get("origin", "selection")))
    for t in selected_active:
        tag_selection_origins.setdefault(t, "user")
        retrieved_candidate_tags.append(t)

    implied_selected_tags = [t for t in selected_active if tag_selection_origins.get(t) == "implied"]
    implied_set = set(implied_selected_tags)
    direct_selected_tags = [t for t in selected_active if t not in implied_set]
    direct_idx = {t: i for i, t in enumerate(direct_selected_tags)}
    direct_selected_tags.sort(
        key=lambda t: (
            _selection_source_rank(tag_selection_origins.get(t, "selection")),
            direct_idx.get(t, 10**9),
        )
    )
    implied_parent_map = _build_implied_parent_map(
        direct_tags_ordered=direct_selected_tags,
        implied_tags=implied_selected_tags,
    )

    toggle_rows = _build_toggle_rows(
        seed_terms=list(selected_active),
        selected_tags=selected_active,
        retrieved_candidate_tags=list(dict.fromkeys(retrieved_candidate_tags)),
        tag_selection_origins=tag_selection_origins,
        implied_parent_map=implied_parent_map,
        top_groups=max(1, int(display_top_groups)),
        top_tags_per_group=max(1, int(display_top_tags_per_group)),
        group_rank_top_k=max(1, int(display_rank_top_k)),
    )
    prompt_text, new_row_values_state, header_updates, checkbox_updates = _build_row_component_updates(
        row_defs=toggle_rows,
        selected_tags=selected_active,
        max_rows=display_max_rows_default,
    )
    return [
        gr.update(visible=bool(toggle_rows)),
        prompt_text,
        sorted(selected_active),
        False,
        gr.update(visible=False, interactive=False),
        toggle_rows,
        new_row_values_state,
        *header_updates,
        *checkbox_updates,
    ]


def _build_selection_query(
    prompt_in: str,
    rewritten: str,
    structural_tags: List[str],
    probe_tags: List[str],
) -> str:
    """Compose the multi-line context text for the selection stage.

    Includes the image description, the rewrite phrases if non-blank, and the
    sorted, deduplicated union of structural and probe tags as hints.
    """
    description = (prompt_in or "").strip()
    lines = [f"IMAGE DESCRIPTION: {description}"]
    if rewritten and rewritten.strip():
        lines.append(f"REWRITE PHRASES: {rewritten.strip()}")
    hint_tags: List[str] = []
    if structural_tags:
        hint_tags.extend(structural_tags)
    if probe_tags:
        hint_tags.extend(probe_tags)
    if hint_tags:
        # Keep hints as context only; selection still must choose by candidate indices.
        lines.append(
            "INFERRED TAG HINTS (context only): " + ", ".join(sorted(set(hint_tags)))
        )
    return "\n".join(lines)


# Set up logging
# Minimal prod logging: warnings+ to stderr, no file by default
LOG_LEVEL = os.environ.get("PSQ_LOG_LEVEL", "WARNING").upper()
_log_level = getattr(logging, LOG_LEVEL, logging.WARNING)
if not isinstance(_log_level, int):
    # e.g. PSQ_LOG_LEVEL names a logging attribute that is not a level.
    _log_level = logging.WARNING
logging.basicConfig(
    level=_log_level,
    format="%(asctime)s %(levelname)s:%(message)s",
    handlers=[logging.StreamHandler()],  # no file -> avoids huge logs on Spaces
)

# Quiet down common noisy libs (optional)
for _name in ("gensim", "gradio", "hnswlib", "httpx", "uvicorn"):
    logging.getLogger(_name).setLevel(logging.ERROR)

# Turn off Gradio analytics phone-home to avoid those background thread errors (optional)
os.environ["GRADIO_ANALYTICS_ENABLED"] = "0"

MASCOT_DIR = Path(__file__).parent / "mascotimages"
MASCOT_FILE = MASCOT_DIR / "transparentsquirrel.png"


def _load_mascot_image():
    """Load the mascot image as RGBA if available; return None when missing/unreadable."""
    if not MASCOT_FILE.exists():
        logging.warning("Mascot image missing: %s", MASCOT_FILE)
        return None
    try:
        # The context manager closes the source file; convert() returns an
        # independent, fully loaded image.
        with Image.open(MASCOT_FILE) as img:
            return img.convert("RGBA")
    except Exception as e:
        logging.warning("Failed to load mascot image (%s): %s", MASCOT_FILE, e)
        return None


# Hotfix: make gradio_client's JSON-schema -> python-type helpers tolerate
# non-dict schemas (e.g. the boolean `true`/`false` schema form).
try:
    from gradio_client import utils as _gc_utils

    _orig_get_type = _gc_utils.get_type
    _orig_j2p = _gc_utils._json_schema_to_python_type
    _orig_pub = _gc_utils.json_schema_to_python_type

    def _get_type_safe(schema):
        # Sometimes schema is a bare True/False (JSON Schema boolean form)
        if not isinstance(schema, dict):
            return "any"
        return _orig_get_type(schema)

    def _j2p_safe(schema, defs=None):
        # Accept non-dict schemas (True/False/None) and treat as "any"
        if not isinstance(schema, dict):
            return "any"
        return _orig_j2p(schema, defs or schema.get("$defs"))

    def _pub_safe(schema):
        # Public wrapper used by Gradio; keep it resilient too
        if not isinstance(schema, dict):
            return "any"
        return _j2p_safe(schema, schema.get("$defs"))

    _gc_utils.get_type = _get_type_safe
    _gc_utils._json_schema_to_python_type = _j2p_safe
    _gc_utils.json_schema_to_python_type = _pub_safe
except Exception as e:
    logging.warning("gradio_client hotfix not applied: %s", e)

# -------------------------------------------------------------------------------

allow_nsfw_tags = False


def _is_production_runtime() -> bool:
    """Best-effort detection for deployed runtime (HF Spaces or explicit env)."""
    if os.environ.get("PSQ_PRODUCTION", "").strip().lower() in {"1", "true", "yes"}:
        return True
    if os.environ.get("SPACE_ID"):
        return True
    if os.environ.get("HF_SPACE_ID"):
        return True
    if os.environ.get("SYSTEM") == "spaces":
        return True
    return False


verbose_retrieval_default = "0" if _is_production_runtime() else "1"
verbose_retrieval = os.environ.get("PSQ_VERBOSE_RETRIEVAL", verbose_retrieval_default).strip().lower() in {"1", "true", "yes"}
verbose_retrieval_all = False
verbose_retrieval_limit = 20
# Case-insensitive, mirroring the truthy vocabulary used for verbose_retrieval.
enable_probe_tags = os.environ.get("PSQ_ENABLE_PROBE", "1").strip().lower() not in {"0", "false", "no", "off"}
display_top_groups_default = int(os.environ.get("PSQ_DISPLAY_TOP_GROUPS", "10"))
display_top_tags_per_group_default = int(os.environ.get("PSQ_DISPLAY_TOP_TAGS_PER_GROUP", "5"))
display_rank_top_k_default = int(os.environ.get("PSQ_DISPLAY_GROUP_RANK_TOP_K", "5"))
display_max_rows_default = int(os.environ.get("PSQ_DISPLAY_MAX_ROWS", "14"))
retrieval_global_k = int(os.environ.get("PSQ_RETRIEVAL_GLOBAL_K", "300"))
retrieval_per_phrase_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_K", "10"))
retrieval_per_phrase_final_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_FINAL_K", "1"))
selection_mode = os.environ.get("PSQ_SELECTION_MODE", "chunked_map_union").strip()
selection_chunk_size = int(os.environ.get("PSQ_SELECTION_CHUNK_SIZE", "60"))
selection_per_phrase_k = int(os.environ.get("PSQ_SELECTION_PER_PHRASE_K", "2"))
selection_candidate_cap = int(os.environ.get("PSQ_SELECTION_CANDIDATE_CAP", "0"))
stage1_rewrite_timeout_s = float(os.environ.get("PSQ_TIMEOUT_REWRITE_S", "45"))
stage1_struct_timeout_s = float(os.environ.get("PSQ_TIMEOUT_STRUCT_S", "45")) stage1_probe_timeout_s = float(os.environ.get("PSQ_TIMEOUT_PROBE_S", "45")) stage3_select_timeout_s = float(os.environ.get("PSQ_TIMEOUT_SELECT_S", "50")) stage3_select_retry_timeout_s = float(os.environ.get("PSQ_TIMEOUT_SELECT_RETRY_S", "20")) stage3_fast_retry_count = max(0, int(os.environ.get("PSQ_STAGE3_FAST_RETRY_COUNT", "1"))) timing_log_path = Path(os.environ.get("PSQ_TIMING_LOG_PATH", "data/runtime_metrics/ui_pipeline_timings.jsonl")) def _startup_preflight_errors() -> List[str]: errs: List[str] = [] if not os.getenv("OPENROUTER_API_KEY"): errs.append("OPENROUTER_API_KEY is missing. Set it in Space Secrets or environment variables.") return errs STARTUP_PREFLIGHT_ERRORS = _startup_preflight_errors() if STARTUP_PREFLIGHT_ERRORS: for _err in STARTUP_PREFLIGHT_ERRORS: logging.error("Startup preflight error: %s", _err) css = """ .scrollable-content{ max-height: 420px; overflow-y: scroll; /* always show scrollbar */ overflow-x: hidden; padding-right: 8px; padding-bottom: 14px; /* <— add this */ scrollbar-gutter: stable; /* prevent layout shift as it fills */ /* Firefox */ scrollbar-width: auto; scrollbar-color: rgba(180,180,180,.9) rgba(0,0,0,.15); } /* WebKit/Chromium (Chrome/Edge/Safari) */ .scrollable-content::-webkit-scrollbar{ width: 10px; } .scrollable-content::-webkit-scrollbar-thumb{ background: rgba(180,180,180,.9); border-radius: 8px; } .scrollable-content::-webkit-scrollbar-track{ background: rgba(0,0,0,.15); } /* (Optional) make both scroll panes taller so they fill more of the column */ .pane-left .scrollable-content, .pane-right .scrollable-content { max-height: 610px; /* was 420px; tweak to taste */ } .lego-tags .gr-checkboxgroup, .lego-tags .wrap { display: flex !important; flex-wrap: wrap !important; gap: 10px !important; } .lego-tags label { margin: 0 !important; padding: 0 !important; position: relative !important; } /* Hide native checkbox visuals completely */ 
.lego-tags input[type="checkbox"] { appearance: none !important; -webkit-appearance: none !important; -moz-appearance: none !important; position: absolute !important; width: 1px !important; height: 1px !important; opacity: 0 !important; pointer-events: none !important; display: none !important; } /* Brick button skin (works for both +span and ~span structures) */ .lego-tags input[type="checkbox"] + span, .lego-tags input[type="checkbox"] ~ span { --on-bg1: #ffd166; --on-bg2: #f39c4a; --on-border: #b86e21; --on-text: #2e1706; position: relative !important; display: inline-flex !important; align-items: center !important; min-height: 40px !important; padding: 10px 15px 9px 22px !important; border: 1px solid #9aa6b8 !important; border-radius: 10px !important; background: linear-gradient(180deg, #dfe5ee 0%, #bec8d6 100%) !important; color: #364254 !important; font-size: 0.97rem !important; font-weight: 800 !important; line-height: 1.15 !important; cursor: pointer !important; user-select: none !important; letter-spacing: 0.01em !important; box-shadow: 0 3px 0 rgba(0,0,0,0.16), inset 0 1px 0 rgba(255,255,255,0.55) !important; transition: transform 0.08s ease, box-shadow 0.08s ease, filter 0.08s ease !important; } .lego-tags input[type="checkbox"] + span::before, .lego-tags input[type="checkbox"] ~ span::before { content: "" !important; position: absolute !important; top: 5px !important; left: 8px !important; width: 8px !important; height: 8px !important; border-radius: 50% !important; background: rgba(255,255,255,0.58) !important; box-shadow: 22px 0 0 rgba(255,255,255,0.58) !important; pointer-events: none !important; } /* Unselected cue: show "+" on the left. 
*/ .lego-tags input[type="checkbox"] + span::after, .lego-tags input[type="checkbox"] ~ span::after { content: "+" !important; position: absolute !important; left: 6px !important; top: 50% !important; transform: translateY(-52%) !important; font-size: 1rem !important; font-weight: 900 !important; color: #4b5563 !important; opacity: 0.95 !important; pointer-events: none !important; } /* Bright color cycle used only when selected */ .lego-tags label:nth-child(8n+1) span { --on-bg1: #ffd166; --on-bg2: #f39c4a; --on-border: #b86e21; --on-text: #2e1706; } .lego-tags label:nth-child(8n+2) span { --on-bg1: #6ee7ff; --on-bg2: #1fb7ff; --on-border: #157cb3; --on-text: #07263c; } .lego-tags label:nth-child(8n+3) span { --on-bg1: #9dff8f; --on-bg2: #45c96f; --on-border: #2a8b4b; --on-text: #0d2917; } .lego-tags label:nth-child(8n+4) span { --on-bg1: #ff8fab; --on-bg2: #ff5c7a; --on-border: #b83956; --on-text: #3f0f1d; } .lego-tags label:nth-child(8n+5) span { --on-bg1: #d0a8ff; --on-bg2: #a46cff; --on-border: #7147b3; --on-text: #25143f; } .lego-tags label:nth-child(8n+6) span { --on-bg1: #ffe27a; --on-bg2: #f7bf39; --on-border: #ad7f1f; --on-text: #332407; } .lego-tags label:nth-child(8n+7) span { --on-bg1: #8effd5; --on-bg2: #2ed6b5; --on-border: #1e947d; --on-text: #0d2a25; } .lego-tags label:nth-child(8n+8) span { --on-bg1: #ffb47e; --on-bg2: #ff8753; --on-border: #b95b2d; --on-text: #391a0a; } /* Source-driven selected colors (applies when tags are preselected by the pipeline). 
*/ .lego-tags label[data-psq-preselected="1"][data-psq-origin="rewrite"] span { --on-bg1: #77f0d7; --on-bg2: #26b9a3; --on-border: #187869; --on-text: #062923; } .lego-tags label[data-psq-preselected="1"][data-psq-origin="selection"] span { --on-bg1: #ffd98a; --on-bg2: #f0a93c; --on-border: #a66f1f; --on-text: #382206; } .lego-tags label[data-psq-preselected="1"][data-psq-origin="probe"] span { --on-bg1: #d8b4ff; --on-bg2: #9a6cff; --on-border: #6745b0; --on-text: #24143b; } .lego-tags label[data-psq-preselected="1"][data-psq-origin="structural"] span { --on-bg1: #a6f79a; --on-bg2: #53c368; --on-border: #2f8442; --on-text: #102d17; } .lego-tags label[data-psq-preselected="1"][data-psq-origin="implied"] span { --on-bg1: #d7dde8; --on-bg2: #a8b3c4; --on-border: #6f7e95; --on-text: #1d2633; } /* User-selected tags (not initially selected by the pipeline). */ .lego-tags label[data-psq-preselected="0"] span { --on-bg1: #9ec5ff; --on-bg2: #4f86ff; --on-border: #2f5fbf; --on-text: #0b1f42; } .lego-tags label:hover span { filter: brightness(1.02) !important; transform: translateY(1px) !important; } /* ON state: brighter + visibly recessed */ .lego-tags input[type="checkbox"]:checked + span, .lego-tags input[type="checkbox"]:checked ~ span, .lego-tags label:has(input[type="checkbox"]:checked) span { background: linear-gradient(180deg, var(--on-bg1) 0%, var(--on-bg2) 100%) !important; color: var(--on-text) !important; border-color: var(--on-border) !important; filter: saturate(1.2) brightness(1.12) !important; transform: translateY(-2px) !important; box-shadow: inset 0 3px 6px rgba(0,0,0,0.20), inset 0 -1px 0 rgba(255,255,255,0.36), 0 6px 0 rgba(0,0,0,0.32) !important; } .lego-tags input[type="checkbox"]:checked + span::after, .lego-tags input[type="checkbox"]:checked ~ span::after, .lego-tags label:has(input[type="checkbox"]:checked) span::after { content: "" !important; } .source-legend { display: flex; flex-wrap: wrap; align-items: center; gap: 8px; margin: 4px 0 10px 0; 
} .source-legend .legend-title { font-size: 0.92rem; font-weight: 900; color: #334155; margin-right: 4px; } .source-legend .chip { display: inline-flex; align-items: center; border-radius: 10px; border: 1px solid #6c7788; padding: 6px 12px; font-size: 0.85rem; font-weight: 800; color: #111827; background: #f3f6fb; } .source-legend .chip.rewrite { background: #26b9a3; color: #062923; border-color: #187869; } .source-legend .chip.selection { background: #f0a93c; color: #382206; border-color: #a66f1f; } .source-legend .chip.probe { background: #9a6cff; color: #ffffff; border-color: #6745b0; } .source-legend .chip.structural { background: #53c368; color: #102d17; border-color: #2f8442; } .source-legend .chip.implied { background: #a8b3c4; color: #1d2633; border-color: #6f7e95; } .source-legend .chip.user { background: #4f86ff; color: #ffffff; border-color: #2f5fbf; } .source-legend .chip.unselected { background: #c7ced8; color: #2d3440; border-color: #7d8897; } .row-heading p { margin: 8px 0 0 0 !important; font-size: 1.18rem !important; font-weight: 850 !important; line-height: 1.2 !important; } .row-instruction { text-align: center; margin: 8px 0 12px 0; } .row-instruction p { margin: 0 !important; font-size: 1.02rem !important; font-style: italic !important; font-weight: 800 !important; color: #1d4ed8 !important; } .top-instruction { text-align: center; margin: 2px 0 6px 0; } .top-instruction p { margin: 0 !important; font-size: 1.02rem !important; font-style: italic !important; font-weight: 800 !important; color: #1d4ed8 !important; } .run-hint { margin-top: 6px; text-align: center; } .run-hint p { margin: 0 !important; font-size: 0.9rem !important; font-style: italic !important; color: #475569 !important; } .prompt-card { background: transparent !important; border: none !important; box-shadow: none !important; padding: 0 !important; } .suggested-prompt-box { margin-top: 2px !important; } .suggested-prompt-card { margin-top: 10px !important; } """ client_js = """ 
() => {} """ def rag_pipeline_ui( user_prompt: str, display_top_groups: float, display_top_tags_per_group: float, display_rank_top_k: float, ): logs = [] def log(s): logs.append(s) try: stage_timings = {} def _record_timing(stage: str, dt_s: float): stage_timings[stage] = float(dt_s) def _emit_timing_summary(total_s: float): summary_order = [ "preprocess", "rewrite", "structural", "probe", "retrieval", "selection", "implication_expansion", "prompt_composition", "group_display", ] lines = [] for k in summary_order: if k in stage_timings: lines.append(f"{k}={stage_timings[k]:.2f}s") slowest = max(stage_timings.items(), key=lambda kv: kv[1])[0] if stage_timings else "n/a" log("Timing Summary: " + ", ".join(lines)) log(f"Timing Slowest Stage: {slowest}") log(f"Timing Total: {total_s:.2f}s") def _append_timing_jsonl(total_s: float): try: timing_log_path.parent.mkdir(parents=True, exist_ok=True) rec = { "timestamp_utc": datetime.utcnow().isoformat(timespec="seconds") + "Z", "stages_s": stage_timings, "total_s": float(total_s), "config": { "timeout_rewrite_s": stage1_rewrite_timeout_s, "timeout_struct_s": stage1_struct_timeout_s, "timeout_probe_s": stage1_probe_timeout_s, "timeout_select_s": stage3_select_timeout_s, }, } with timing_log_path.open("a", encoding="utf-8") as f: f.write(json.dumps(rec, ensure_ascii=True) + "\n") log(f"Timing Log: wrote {timing_log_path}") except Exception as e: log(f"Timing Log: failed ({type(e).__name__}: {e})") def _future_with_timeout( fut, timeout_s: float, stage_name: str, fallback, *, strict: bool = False, ): t0 = time.perf_counter() try: out = fut.result(timeout=max(1.0, float(timeout_s))) dt = time.perf_counter() - t0 log(f"{stage_name}: {dt:.2f}s") stage_key = { "Rewrite": "rewrite", "Structural inference": "structural", "Probe inference": "probe", "Index selection": "selection", }.get(stage_name) if stage_key: _record_timing(stage_key, dt) return out except FutureTimeoutError: fut.cancel() msg = f"{stage_name}: timed out after 
{timeout_s:.0f}s" if strict: raise RuntimeError(msg) log(f"{msg}; using fallback") return fallback except Exception as e: msg = f"{stage_name}: failed ({type(e).__name__}: {e})" if strict: raise RuntimeError(msg) log(f"{msg}; using fallback") return fallback t_total0 = time.perf_counter() log("Start: received prompt") if STARTUP_PREFLIGHT_ERRORS: log("Startup preflight failed:") for e in STARTUP_PREFLIGHT_ERRORS: log(f"- {e}") return _build_ui_payload( console_text="\n".join(logs), row_defs=[], selected_tags=[], ) prompt_in = (user_prompt or "").strip() if not prompt_in: return _build_ui_payload( console_text="Error: empty prompt", row_defs=[], selected_tags=[], ) log("Input:") log(prompt_in) log("") log( "Runtime config: " f"retrieval_global_k={retrieval_global_k} " f"retrieval_per_phrase_k={retrieval_per_phrase_k} " f"retrieval_per_phrase_final_k={retrieval_per_phrase_final_k} " f"selection_mode={selection_mode} " f"selection_chunk_size={selection_chunk_size} " f"selection_per_phrase_k={selection_per_phrase_k} " f"min_tag_count={_get_min_tag_count()} " f"select_timeout_s={stage3_select_timeout_s:.0f} " f"select_retry_timeout_s={stage3_select_retry_timeout_s:.0f} " f"select_fast_retries={stage3_fast_retry_count}" ) log("") t0 = time.perf_counter() min_tag_count = _get_min_tag_count() user_tags_raw = extract_user_provided_tags_upto_3_words(prompt_in) user_tags, removed_user_low = _filter_min_count_tags(user_tags_raw, min_tag_count) dt = time.perf_counter()-t0 _record_timing("preprocess", dt) log(f"Preprocess (user tag extraction): {dt:.2f}s") log("Heuristically extracted user tags:") if user_tags: log(", ".join(user_tags)) else: log("(none)") if removed_user_low: log( f"Filtered {len(removed_user_low)} low-frequency user tags " f"(<{min_tag_count}): {', '.join(removed_user_low)}" ) log("") log("Step 1: LLM rewrite + structural inference + probe (concurrent)") max_workers = 3 if enable_probe_tags else 2 ex = ThreadPoolExecutor(max_workers=max_workers) try: 
fut_rewrite = ex.submit(llm_rewrite_prompt, prompt_in, log) fut_struct = ex.submit(llm_infer_structural_tags, prompt_in, log=log) fut_probe = ex.submit(llm_infer_probe_tags, prompt_in, log=log) if enable_probe_tags else None rewritten = _future_with_timeout( fut_rewrite, stage1_rewrite_timeout_s, "Rewrite", "", strict=True, ) structural_tags = _future_with_timeout( fut_struct, stage1_struct_timeout_s, "Structural inference", [] ) probe_tags = ( _future_with_timeout(fut_probe, stage1_probe_timeout_s, "Probe inference", []) if fut_probe else [] ) finally: ex.shutdown(wait=False, cancel_futures=True) structural_tags, removed_struct_low = _filter_min_count_tags(structural_tags, min_tag_count) probe_tags, removed_probe_low = _filter_min_count_tags(probe_tags, min_tag_count) if removed_struct_low: log( f"Filtered {len(removed_struct_low)} low-frequency structural tags " f"(<{min_tag_count}): {', '.join(removed_struct_low)}" ) if removed_probe_low: log( f"Filtered {len(removed_probe_low)} low-frequency probe tags " f"(<{min_tag_count}): {', '.join(removed_probe_low)}" ) if not rewritten: raise RuntimeError("Rewrite: empty output") log("Rewrite:") log(rewritten if rewritten else "(empty)") log("") rewrite_for_retrieval = rewritten if user_tags: # keep them separate in logs, but allow them to help retrieval rewrite_for_retrieval = (rewrite_for_retrieval + ", " + ", ".join(user_tags)).strip(", ").strip() log("Step 2: Prompt Squirrel retrieval (hidden)") try: t0 = time.perf_counter() retrieval_context_tags = list(dict.fromkeys((structural_tags or []) + (probe_tags or []))) rewrite_phrases = [p.strip() for p in (rewrite_for_retrieval or "").split(",") if p.strip()] retrieval_result = psq_candidates_from_rewrite_phrases( rewrite_phrases=rewrite_phrases, allow_nsfw_tags=allow_nsfw_tags, context_tags=retrieval_context_tags, global_k=max(1, retrieval_global_k), per_phrase_k=max(1, retrieval_per_phrase_k), per_phrase_final_k=max(1, retrieval_per_phrase_final_k), 
min_tag_count=max(0, min_tag_count), verbose=verbose_retrieval, ) if isinstance(retrieval_result, tuple): candidates, phrase_reports = retrieval_result else: candidates, phrase_reports = retrieval_result, [] if selection_candidate_cap > 0 and len(candidates) > selection_candidate_cap: candidates = candidates[:selection_candidate_cap] log(f"Selection candidate cap applied: {selection_candidate_cap}") dt = time.perf_counter()-t0 _record_timing("retrieval", dt) log(f"Retrieval: {dt:.2f}s") log(f"Retrieved {len(candidates)} candidate tags") if verbose_retrieval: log(f"Total unique candidates: {len(candidates)}") limit = None if verbose_retrieval_all else max(1, int(verbose_retrieval_limit)) for report in phrase_reports: phrase = report.get("normalized") or report.get("phrase") or "" lookup = report.get("lookup") or "" tfidf_vocab = report.get("tfidf_vocab") log(f"Phrase: {phrase} (lookup={lookup}) tfidf_vocab={tfidf_vocab}") rows = report.get("candidates", []) shown = rows if limit is None else rows[:limit] for row in shown: tag = row.get("tag") alias_token = row.get("alias_token") score_fasttext = row.get("score_fasttext") score_context = row.get("score_context") score_combined = row.get("score_combined") count = row.get("count") alias_part = "" if alias_token and alias_token != tag: alias_part = f" [alias_token={alias_token}]" fasttext_str = ( f"{score_fasttext:.3f}" if isinstance(score_fasttext, (int, float)) else score_fasttext ) if score_context is None: context_str = "None" else: context_str = ( f"{score_context:.3f}" if isinstance(score_context, (int, float)) else score_context ) combined_str = ( f"{score_combined:.3f}" if isinstance(score_combined, (int, float)) else score_combined ) log( f" {tag}{alias_part} | fasttext={fasttext_str} context={context_str} " f"combined={combined_str} count={count}" ) if limit is not None and len(rows) > limit: log(f" ... 
({len(rows) - limit} more)") except Exception as e: log(f"Retrieval fallback: {type(e).__name__}: {e}") candidates = [] retrieved_candidate_tags = list( dict.fromkeys( _norm_tag_for_lookup(c.tag) for c in (candidates or []) if getattr(c, "tag", None) ) ) log("Step 3: LLM index selection (uses rewrite + structural/probe context)") selection_query = _build_selection_query( prompt_in=prompt_in, rewritten=rewritten, structural_tags=structural_tags, probe_tags=probe_tags, ) picked_indices = None last_stage3_error: Exception | None = None stage3_attempts = 1 + int(stage3_fast_retry_count) for attempt_i in range(stage3_attempts): timeout_s = stage3_select_timeout_s if attempt_i == 0 else stage3_select_retry_timeout_s if attempt_i > 0: log( f"Index selection: fast retry {attempt_i}/{stage3_fast_retry_count} " f"(timeout={timeout_s:.0f}s)" ) ex = ThreadPoolExecutor(max_workers=1) try: fut_sel = ex.submit( llm_select_indices, query_text=selection_query, candidates=candidates, max_pick=0, log=log, mode=selection_mode, chunk_size=max(1, selection_chunk_size), per_phrase_k=max(1, selection_per_phrase_k), ) picked_indices = _future_with_timeout( fut_sel, timeout_s, "Index selection", [], strict=True, ) last_stage3_error = None break except Exception as e: last_stage3_error = e log(f"Index selection attempt {attempt_i + 1} failed: {e}") finally: ex.shutdown(wait=False, cancel_futures=True) if picked_indices is None: raise RuntimeError( f"Index selection failed after {stage3_attempts} attempt(s): {last_stage3_error}" ) selection_selected_tags = [candidates[i].tag for i in picked_indices] if picked_indices else [] selection_selected_tags, removed_stage3_low = _filter_min_count_tags(selection_selected_tags, min_tag_count) if removed_stage3_low: log( f" Filtered {len(removed_stage3_low)} low-frequency stage3 tags " f"(<{min_tag_count}): {', '.join(removed_stage3_low)}" ) selected_tags = list(selection_selected_tags) if structural_tags: # Add structural tags that aren't already 
selected existing = {t for t in selected_tags} new_structural = [t for t in structural_tags if t not in existing] selected_tags.extend(new_structural) log(f" Added {len(new_structural)} structural tags: {', '.join(new_structural)}") else: log(" No structural tags inferred") if probe_tags: existing = {t for t in selected_tags} new_probe = [t for t in probe_tags if t not in existing] selected_tags.extend(new_probe) log(f" Added {len(new_probe)} probe tags: {', '.join(new_probe)}") elif enable_probe_tags: log(" No probe tags inferred") selected_tags, removed_excluded_direct = _filter_excluded_recommendation_tags(selected_tags) if removed_excluded_direct: log(f" Removed {len(removed_excluded_direct)} excluded tags: {', '.join(removed_excluded_direct)}") direct_selected_tags = list(dict.fromkeys(selected_tags)) log("Step 3c: Expand via tag implications") t0 = time.perf_counter() tag_set = set(selected_tags) expanded, implied_only = expand_tags_via_implications(tag_set) dt = time.perf_counter()-t0 _record_timing("implication_expansion", dt) log(f"Implication expansion: {dt:.2f}s") implied_selected_tags = sorted(implied_only) if implied_only else [] if implied_only: implied_added = sorted(implied_only) implied_added, removed_implied_low = _filter_min_count_tags(implied_added, min_tag_count) implied_selected_tags = list(implied_added) if implied_added: selected_tags.extend(implied_added) log(f" Added {len(implied_added)} implied tags: {', '.join(implied_added)}") if removed_implied_low: log( f" Filtered {len(removed_implied_low)} low-frequency implied tags " f"(<{min_tag_count}): {', '.join(removed_implied_low)}" ) else: log(" No additional implied tags") selected_tags, removed_excluded_implied = _filter_excluded_recommendation_tags(selected_tags) implied_selected_tags = [ t for t in implied_selected_tags if not _is_excluded_recommendation_tag(t) ] if removed_excluded_implied: log( f" Removed {len(removed_excluded_implied)} excluded tags after implications: " f"{', 
'.join(removed_excluded_implied)}" ) log("Step 4: Compose final prompt") t0 = time.perf_counter() final_prompt = compose_final_prompt(rewritten, selected_tags) dt = time.perf_counter()-t0 _record_timing("prompt_composition", dt) log(f"Prompt composition: {dt:.2f}s") log("Step 5: Build ranked group/category display") t0 = time.perf_counter() seed_terms = [] seed_terms.extend(user_tags) seed_terms.extend([p.strip() for p in (rewritten or "").split(",") if p.strip()]) seed_terms.extend(structural_tags or []) seed_terms.extend(probe_tags or []) seed_terms.extend(selected_tags) seed_terms = list(dict.fromkeys(seed_terms)) active_selected_tags = list(dict.fromkeys(selected_tags)) structural_set = {_norm_tag_for_lookup(t) for t in (structural_tags or []) if t} probe_set = {_norm_tag_for_lookup(t) for t in (probe_tags or []) if t} implied_set = {_norm_tag_for_lookup(t) for t in (implied_selected_tags or []) if t} rewrite_set = { _norm_tag_for_lookup(t) for t in (list(user_tags or []) + [p.strip() for p in (rewritten or "").split(",") if p.strip()]) if t } selection_set = {_norm_tag_for_lookup(t) for t in (selection_selected_tags or []) if t} tag_selection_origins: Dict[str, str] = {} for tag in active_selected_tags: tag_norm = _norm_tag_for_lookup(tag) if tag_norm in structural_set: origin = "structural" elif tag_norm in probe_set: origin = "probe" elif tag_norm in rewrite_set: origin = "rewrite" elif tag_norm in selection_set: origin = "selection" elif tag_norm in implied_set: origin = "implied" else: # Unknown/fallback tags use selection color. 
origin = "selection" tag_selection_origins[tag] = origin if tag_norm and tag_norm != tag: tag_selection_origins[tag_norm] = origin direct_tags_for_implied = list( dict.fromkeys(_norm_tag_for_lookup(t) for t in (direct_selected_tags or []) if t) ) direct_tags_for_implied_idx = {t: i for i, t in enumerate(direct_tags_for_implied)} direct_tags_for_implied.sort( key=lambda t: ( _selection_source_rank(tag_selection_origins.get(t, "selection")), direct_tags_for_implied_idx.get(t, 10**9), ) ) implied_parent_map = _build_implied_parent_map( direct_tags_ordered=direct_tags_for_implied, implied_tags=implied_selected_tags, ) toggle_rows = _build_toggle_rows( seed_terms=seed_terms, selected_tags=active_selected_tags, retrieved_candidate_tags=retrieved_candidate_tags, tag_selection_origins=tag_selection_origins, implied_parent_map=implied_parent_map, top_groups=max(1, int(display_top_groups)), top_tags_per_group=max(1, int(display_top_tags_per_group)), group_rank_top_k=max(1, int(display_rank_top_k)), ) dt = time.perf_counter()-t0 _record_timing("group_display", dt) log(f"Ranked group display: {dt:.2f}s ({len(toggle_rows)} rows)") log( _build_display_audit_line( toggle_rows, active_selected_tags=active_selected_tags, direct_selected_tags=direct_selected_tags, implied_selected_tags=implied_selected_tags, ) ) for idx, row in enumerate(toggle_rows[: max(0, int(display_max_rows_default))]): tags_preview = ", ".join(row.get("tags", [])) log(f"UI Row {idx}: {row.get('label', '')} :: {tags_preview}") total_dt = time.perf_counter()-t_total0 _emit_timing_summary(total_dt) _append_timing_jsonl(total_dt) log("Done: final prompt ready") return _build_ui_payload( console_text="\n".join(logs), row_defs=toggle_rows, selected_tags=active_selected_tags, ) except Exception as e: log(f"Error: {type(e).__name__}: {e}") return _build_ui_payload( console_text="\n".join(logs), row_defs=[], selected_tags=[], ) with gr.Blocks(css=css, js=client_js) as app: with gr.Row(): with gr.Column(scale=3, 
elem_classes=["prompt-col"]): gr.Markdown( 'Describe your image under "Enter Prompt" and click "Run". ' 'Prompt Squirrel will translate it into image board tags.', elem_classes=["top-instruction"], ) with gr.Group(elem_classes=["prompt-card"]): image_tags = gr.Textbox( label="Enter Prompt", placeholder="e.g. fox, outside, detailed background, .", lines=1, elem_classes=["enter-prompt-box"], ) with gr.Group(elem_classes=["prompt-card", "suggested-prompt-card"]): suggested_prompt = gr.Textbox( label="Suggested Prompt (Read-only)", lines=2, interactive=False, show_copy_button=True, placeholder='Suggested prompt will appear here after you click "Run".', elem_classes=["suggested-prompt-box"], ) with gr.Column(scale=1): _mascot_pil = _load_mascot_image() if _mascot_pil is not None: mascot_img = gr.Image( value=_mascot_pil, show_label=False, interactive=False, height=240, elem_id="mascot" ) else: mascot_img = gr.Markdown("`(mascot image unavailable)`") submit_button = gr.Button("Run", variant="primary") gr.Markdown("Typical runtime: up to ~20 seconds.", elem_classes=["run-hint"]) selected_tags_state = gr.State([]) rows_dirty_state = gr.State(False) row_defs_state = gr.State([]) row_values_state = gr.State([]) toggle_instruction = gr.Markdown( "Click tag buttons to add or remove tags from the suggested prompt.", elem_classes=["row-instruction"], visible=False, ) row_headers: List[gr.Markdown] = [] row_checkboxes: List[gr.CheckboxGroup] = [] for _ in range(display_max_rows_default): with gr.Row(): with gr.Column(scale=2, min_width=170): row_headers.append(gr.Markdown(value="", visible=False, elem_classes=["row-heading"])) with gr.Column(scale=10): row_checkboxes.append( gr.CheckboxGroup( choices=[], value=[], visible=False, interactive=True, container=False, elem_classes=["lego-tags"], ) ) with gr.Row(): with gr.Column(scale=10): gr.HTML( """