import os import logging import time import json import csv import re import base64 import atexit import html from datetime import datetime from functools import lru_cache from PIL import Image from pathlib import Path from typing import Any, Dict, List, Optional, Set, Tuple from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError def _startup_profile_enabled() -> bool: raw = (os.environ.get("PSQ_STARTUP_PROFILE", "") or "").strip().lower() return raw in {"1", "true", "yes", "on"} _STARTUP_PROFILE_ON = _startup_profile_enabled() _STARTUP_PROFILE_T0 = time.perf_counter() _STARTUP_PROFILE_PATH: Path | None = None _STARTUP_PROFILE_FILE = None _STARTUP_PROFILE_CLIENT_LOAD_MARKED = False if _STARTUP_PROFILE_ON: _startup_profile_path_raw = ( os.environ.get("PSQ_STARTUP_PROFILE_PATH", "") or "" ).strip() if _startup_profile_path_raw: _STARTUP_PROFILE_PATH = Path(_startup_profile_path_raw) else: _STARTUP_PROFILE_PATH = Path("data/runtime_metrics") / ( f"startup_profile_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl" ) try: _STARTUP_PROFILE_PATH.parent.mkdir(parents=True, exist_ok=True) _STARTUP_PROFILE_FILE = _STARTUP_PROFILE_PATH.open("a", encoding="utf-8") except Exception: _STARTUP_PROFILE_FILE = None def _startup_profile_mark(event: str, **fields: Any) -> None: if not _STARTUP_PROFILE_ON: return rec: Dict[str, Any] = { "event": str(event), "t_s": round(time.perf_counter() - _STARTUP_PROFILE_T0, 6), } if fields: rec.update(fields) line = json.dumps(rec, ensure_ascii=False) if _STARTUP_PROFILE_FILE is not None: try: _STARTUP_PROFILE_FILE.write(line + "\n") _STARTUP_PROFILE_FILE.flush() except Exception: pass print(f"STARTUP_PROFILE {line}") def _startup_profile_close() -> None: if _STARTUP_PROFILE_FILE is None: return try: _STARTUP_PROFILE_FILE.close() except Exception: pass _startup_profile_mark("module_bootstrap_begin") if _STARTUP_PROFILE_ON and _STARTUP_PROFILE_PATH is not None: _startup_profile_mark("startup_profile_enabled", path=str(_STARTUP_PROFILE_PATH)) atexit.register(_startup_profile_close) import gradio as gr _startup_profile_mark("import.gradio.done") from psq_rag.pipeline.preproc import ( extract_exact_tag_query_phrases, extract_user_provided_tags_upto_3_words, ) _startup_profile_mark("import.psq_rag.pipeline.preproc.done") from psq_rag.llm.rewrite import llm_rewrite_prompt _startup_profile_mark("import.psq_rag.llm.rewrite.done") from psq_rag.llm.rewrite_local_t5 import local_t5_rewrite_prompt _startup_profile_mark("import.psq_rag.llm.rewrite_local_t5.done") from psq_rag.retrieval.psq_retrieval import Candidate, psq_candidates_from_rewrite_phrases, _norm_tag_for_lookup _startup_profile_mark("import.psq_rag.retrieval.psq_retrieval.done") from psq_rag.llm.select import llm_select_indices, llm_infer_structural_tags _startup_profile_mark("import.psq_rag.llm.select.done") from psq_rag.retrieval.state import ( expand_tags_via_implications, get_tag_type_name, get_tag_implications, get_tag_counts, get_alias2tags, get_nsfw_tags, ) _startup_profile_mark("import.psq_rag.retrieval.state.done") from psq_rag.ui.group_ranked_display import rank_groups_from_tfidf, _load_enabled_groups _startup_profile_mark("import.psq_rag.ui.group_ranked_display.done") APP_DIR = Path(__file__).parent DOCS_DIR = APP_DIR / "docs" ARCH_DIAGRAM_FILE = DOCS_DIR / "assets" / "architecture_overview.png" ARCH_DIAGRAM_MARKER = "{{ARCHITECTURE_DIAGRAM}}" ARCH_DIAGRAM_INSERT_BEFORE_HEADING = "## What Each Step Does" DEFAULT_STAGE_OPENROUTER_MODEL = "mistralai/mistral-small-24b-instruct-2501" DEFAULT_STAGE_OPENROUTER_FALLBACK_MODEL = "meta-llama/llama-3.1-8b-instruct" DEFAULT_PROMPT_EXAMPLE = ( "A feral fox with blue eyes and orange fur standing and looking at the viewer " "with a calm expression in a dense forest" ) T5_REWRITE_PROMPT_PREFIX = "The image showcases " SUGGESTED_PROMPT_LABEL_READY = "Suggested Prompt (Read-only)" SUGGESTED_PROMPT_LABEL_PHASE_START = "Thinking... Starting" SUGGESTED_PROMPT_LABEL_PHASE_REWRITE = "Thinking... Rewriting" SUGGESTED_PROMPT_LABEL_PHASE_SELECT = "Thinking... Retrieving + Selecting" def _redact_console_error_text(err: Any) -> str: """Redact provider/user identifiers and token-like substrings from console logs.""" text = str(err or "") text = re.sub(r"\buser_(?!id\b)[A-Za-z0-9]{6,}\b", "user_", text) text = re.sub(r"(?i)\b(bearer)\s+[A-Za-z0-9._:-]+\b", r"\1 ", text) text = re.sub(r"\b(sk|or)-[A-Za-z0-9._-]+\b", r"\1-", text) return text _CORPORATE_HARDBLOCK_PATTERNS = [ # Rating-like explicitness markers. re.compile(r"(^|_)(nsfw|explicit|questionable)(_|$)", re.IGNORECASE), # Unambiguous sexual anatomy. re.compile( r"(^|_)(breast|breasts|boob|boobs|nipple|nipples|penis|vagina|pussy|clit|testicle|scrotum|genital|crotch|anus|anal|areola)(_|$)", re.IGNORECASE, ), # Unambiguous sexual activity. re.compile( r"(^|_)(sex|sexual|fucking|fuck|blowjob|handjob|masturbat|penetrat|thrust|orgasm|cum|ejaculat|creampie|nude|naked|topless|bottomless|moan|sexy)(_|$)", re.IGNORECASE, ), # Common kink/fetish markers. re.compile(r"(^|_)(fetish|bdsm|bondage|dominatrix|submission|vore|inflation|watersports)(_|$)", re.IGNORECASE), ] def _split_prompt_commas(s: str) -> List[str]: return [p.strip() for p in (s or "").split(",") if p.strip()] def _norm_for_dedupe(tag: str) -> str: # your canonical form for lookup/dedupe return _norm_tag_for_lookup(tag.lower()) def _env_flag(name: str, default: bool) -> bool: raw = (os.environ.get(name, "") or "").strip().lower() if not raw: return default return raw not in {"0", "false", "no", "off"} def _env_int(name: str, default: int, minimum: int = 1) -> int: raw = (os.environ.get(name, "") or "").strip() if not raw: return max(minimum, int(default)) try: return max(minimum, int(raw)) except Exception: return max(minimum, int(default)) def _get_rewrite_source() -> str: source = (os.environ.get("PSQ_REWRITE_SOURCE", "t5") or "t5").strip().lower() if source in {"llm", "t5"}: return source return "t5" def _rewrite_prompt(prompt_in: str, log) -> str: source = _get_rewrite_source() if source == "llm": return llm_rewrite_prompt(prompt_in, log) t5_model_dir = ( os.environ.get( "PSQ_T5_REWRITE_MODEL_DIR", "models/finetune/t5-rewrite-n30best-20260508/checkpoint-18000", ) or "models/finetune/t5-rewrite-n30best-20260508/checkpoint-18000" ).strip() t5_num_beams = _env_int("PSQ_T5_REWRITE_NUM_BEAMS", 1, minimum=1) t5_max_new_tokens = _env_int("PSQ_T5_REWRITE_MAX_NEW_TOKENS", 128, minimum=8) # Keep T5 input style aligned with caption training distribution. t5_prompt_in = f"{T5_REWRITE_PROMPT_PREFIX}{prompt_in}" out = local_t5_rewrite_prompt( t5_prompt_in, log=log, model_dir=t5_model_dir, num_beams=t5_num_beams, max_new_tokens=t5_max_new_tokens, ) if out: return out # Keep wiring available, but default to no LLM fallback for T5 rewrite. if _env_flag("PSQ_T5_REWRITE_FALLBACK_TO_LLM", False): log("Rewrite source=t5 returned empty; fallback to llm") return llm_rewrite_prompt(prompt_in, log) return "" def compose_final_prompt(rewritten_prompt: str, selected_tags: List[str]) -> str: parts = _split_prompt_commas(rewritten_prompt) parts.extend(selected_tags) seen = set() out = [] for p in parts: key = _norm_for_dedupe(p) if key in seen: continue seen.add(key) out.append(p) return ", ".join(out) def _display_tag_text(tag: str) -> str: return tag.replace("_", " ") def _display_row_label(name: str) -> str: n = (name or "").strip() if not n: return "" if n == "selected_other": return "Selected (Other)" return n.replace("_", " ").title() def _normalize_selection_origin(origin: str) -> str: o = (origin or "").strip().lower() if o in {"rewrite", "selection", "probe", "structural", "classifier", "user", "candidate"}: return o return "selection" def _choice_label_with_source_meta(tag: str, *, origin: str, preselected: bool) -> str: # Keep labels plain to avoid frontend text/value desynchronization. return _display_tag_text(tag) @lru_cache(maxsize=1) def _load_tag_wiki_defs() -> Dict[str, str]: p = Path("data/tag_wiki_defs.json") if not p.exists(): return {} try: with p.open("r", encoding="utf-8") as f: data = json.load(f) out: Dict[str, str] = {} if isinstance(data, dict): for k, v in data.items(): tag = _norm_tag_for_lookup(str(k)) text = " ".join(str(v or "").split()) if tag and text: out[tag] = text return out except Exception: return {} @lru_cache(maxsize=1) def _load_tag_tooltip_overrides() -> Dict[str, str]: """Load optional per-tag tooltip text overrides. File format: data/tag_tooltip_overrides.csv columns: tag, tooltip_override """ p = Path("data/tag_tooltip_overrides.csv") if not p.exists(): return {} try: out: Dict[str, str] = {} with p.open("r", encoding="utf-8", newline="") as f: reader = csv.DictReader(f) for row in reader: tag = _norm_tag_for_lookup(str(row.get("tag", ""))) text = " ".join(str(row.get("tooltip_override", "")).split()) if tag and text: out[tag] = text return out except Exception: return {} @lru_cache(maxsize=1) def _load_about_docs_markdown() -> str: candidates = [ DOCS_DIR / "space_overview.md", APP_DIR / "PROJECT_SUMMARY.md", ] for p in candidates: if not p.exists(): continue try: raw = p.read_text(encoding="utf-8") except Exception: continue text = raw.strip() if not text: continue # Strip YAML front matter if present. if text.startswith("---"): parts = text.split("---", 2) if len(parts) >= 3: text = parts[2].strip() if text: return text return ( "Documentation is unavailable.\n\n" "Expected file: `docs/space_overview.md`" ) def _tooltip_text_for_tag(tag: str) -> str: t = _norm_tag_for_lookup(tag) parts: List[str] = [] try: count = get_tag_counts().get(t) except Exception: count = None if isinstance(count, int): parts.append(f"Count: {count:,}") d = _load_tag_tooltip_overrides().get(t, "") if not d: d = _load_tag_wiki_defs().get(t, "") if d: parts.append(d) return "\n".join(parts).strip() @lru_cache(maxsize=1) def _load_arch_diagram_data_uri() -> str: if not ARCH_DIAGRAM_FILE.exists(): return "" try: raw = ARCH_DIAGRAM_FILE.read_bytes() except Exception: return "" if not raw: return "" b64 = base64.b64encode(raw).decode("ascii") return f"data:image/png;base64,{b64}" def _split_about_docs_for_diagram(md: str) -> Tuple[str, str, bool]: text = (md or "").strip() if ARCH_DIAGRAM_MARKER in text: before, after = text.rsplit(ARCH_DIAGRAM_MARKER, 1) return before.strip(), after.strip(), True # Backward compatibility if an explicit architecture heading exists in docs. m_arch = re.search(r"(?m)^##\s+Architecture At A Glance\s*$", text) if m_arch: before = text[: m_arch.start()].strip() after = text[m_arch.end() :].strip() return before, after, True # Preferred insertion point: inject diagram right before "What Each Step Does". m_steps = re.search(r"(?m)^##\s+What Each Step Does\s*$", text) if m_steps: before = text[: m_steps.start()].strip() after = text[m_steps.start() :].strip() return before, after, True return text, "", False def _build_arch_diagram_html() -> str: return """

Architecture At A Glance

User prompt Orchestrator: app.py Concurrent evidence producers Query Reformulation local T5 search phrases Lexical Matching canonical/alias 1- and 2-grams Scene Composition Mistral 24B high-level tags Tag Classifier ModernBERT train-time tags Semantic Retrieval FastText/HNSW candidates Context Rescoring TF-IDF/SVD context Candidate Ranking Mistral 24B indices Final Tag Merge scene, classifier auto, implications Editable output ranked rows and suggested prompt
""" def _selection_source_rank(origin: str) -> int: o = _normalize_selection_origin(origin) if o == "structural": return 0 if o == "classifier": return 1 if o == "probe": return 2 # Keep rewrite/user in the same priority band as general selection for row ordering. return 3 def _build_implied_parent_map( direct_tags_ordered: List[str], implied_tags: List[str], ) -> Dict[str, str]: implied_set = {_norm_tag_for_lookup(t) for t in (implied_tags or []) if t} if not implied_set or not direct_tags_ordered: return {} impl = get_tag_implications() parent_by_implied: Dict[str, str] = {} for direct in direct_tags_ordered: d = _norm_tag_for_lookup(direct) if not d: continue queue = [d] seen = {d} while queue: t = queue.pop() for parent in impl.get(t, ()): p = _norm_tag_for_lookup(parent) if not p or p in seen: continue seen.add(p) if p in implied_set and p not in parent_by_implied: parent_by_implied[p] = d queue.append(p) return parent_by_implied def _order_selected_tags_for_row( *, row_selected_tags: List[str], selected_index: Dict[str, int], tag_selection_origins: Dict[str, str], implied_parent_map: Dict[str, str], ) -> List[str]: row_selected_norm = [_norm_tag_for_lookup(t) for t in (row_selected_tags or []) if t] implied_in_row = {t for t in row_selected_norm if t in implied_parent_map} base_tags = [t for t in row_selected_norm if t not in implied_in_row] base_tags.sort( key=lambda t: ( _selection_source_rank(tag_selection_origins.get(t, "selection")), selected_index.get(t, 10**9), t, ) ) children_by_parent: Dict[str, List[str]] = {} for implied in implied_in_row: parent = implied_parent_map.get(implied) if parent: children_by_parent.setdefault(parent, []).append(implied) for parent, children in children_by_parent.items(): children.sort(key=lambda t: (selected_index.get(t, 10**9), t)) ordered: List[str] = [] emitted: Set[str] = set() for tag in base_tags: if tag in emitted: continue ordered.append(tag) emitted.add(tag) for child in children_by_parent.get(tag, []): if child not in emitted: ordered.append(child) emitted.add(child) remaining_implied = [t for t in row_selected_norm if t not in emitted] remaining_implied.sort( key=lambda t: ( _selection_source_rank(tag_selection_origins.get(implied_parent_map.get(t, ""), "selection")), selected_index.get(implied_parent_map.get(t, ""), 10**9), selected_index.get(t, 10**9), t, ) ) for t in remaining_implied: if t not in emitted: ordered.append(t) emitted.add(t) return ordered def _escape_prompt_tag(tag: str) -> str: return ( tag.replace("_", " ") .replace("(", "\\(") .replace(")", "\\)") ) def _ordered_selected_for_prompt(selected: Set[str], row_defs: List[Dict[str, Any]]) -> List[str]: out: List[str] = [] seen: Set[str] = set() for row in row_defs: for tag in row.get("tags", []): if tag in selected and tag not in seen: out.append(tag) seen.add(tag) return out def _compose_toggle_prompt_text(selected_tags: List[str], row_defs: List[Dict[str, Any]]) -> str: selected = {t for t in (selected_tags or []) if t} ordered = _ordered_selected_for_prompt(selected, row_defs or []) return ", ".join(_escape_prompt_tag(t) for t in ordered) def _is_artist_tag(tag: str) -> bool: t = _norm_tag_for_lookup(str(tag)) if not t: return False # Keep a resilient fallback for malformed/missing tag typing metadata. return get_tag_type_name(t) == "artist" or t.startswith("by_") @lru_cache(maxsize=1) def _load_excluded_recommendation_tags() -> Set[str]: out: Set[str] = set() # Existing category-registry driven exclusions. csv_path = Path("data/category_registry.csv") if not csv_path.exists(): csv_path = Path("data/analysis/category_registry.csv") if csv_path.exists(): try: with csv_path.open("r", encoding="utf-8", newline="") as f: reader = csv.DictReader(f) for row in reader: tag = _norm_tag_for_lookup(str(row.get("tag") or "")) if not tag: continue status = str(row.get("category_status") or "").strip().lower() if status == "excluded": out.add(tag) except Exception: pass # Corporate-safety exclusions (editable runtime list). corp_path = Path("data/corporate_excluded_tags.csv") if corp_path.exists(): try: with corp_path.open("r", encoding="utf-8", newline="") as f: reader = csv.DictReader(f) for row in reader: tag = _norm_tag_for_lookup(str(row.get("tag") or "")) if not tag: continue enabled_raw = str(row.get("enabled", "1")).strip().lower() enabled = enabled_raw not in {"0", "false", "no", "off"} if enabled: out.add(tag) except Exception: pass return out def _is_hardblocked_corporate_tag(tag: str) -> bool: t = _norm_tag_for_lookup(str(tag)) if not t: return False return any(rx.search(t) for rx in _CORPORATE_HARDBLOCK_PATTERNS) def _is_excluded_recommendation_tag(tag: str) -> bool: t = _norm_tag_for_lookup(str(tag)) if not t: return False if _is_hardblocked_corporate_tag(t): return True return t in _load_excluded_recommendation_tags() def _get_min_tag_count() -> int: try: return max(0, int(os.environ.get("PSQ_MIN_TAG_COUNT", "100"))) except Exception: return 100 def _filter_min_count_tags(tags: List[str], min_count: int) -> Tuple[List[str], List[str]]: if min_count <= 0: return list(dict.fromkeys(_norm_tag_for_lookup(t) for t in (tags or []) if t)), [] tag_counts = get_tag_counts() keep: List[str] = [] removed: List[str] = [] seen: Set[str] = set() for raw in (tags or []): t = _norm_tag_for_lookup(str(raw)) if not t: continue c = int(tag_counts.get(t, 0) or 0) if c < min_count: removed.append(t) continue if t in seen: continue seen.add(t) keep.append(t) return keep, sorted(set(removed)) def _filter_excluded_recommendation_tags(tags: List[str]) -> Tuple[List[str], List[str]]: excluded = _load_excluded_recommendation_tags() if not excluded: return list(dict.fromkeys(_norm_tag_for_lookup(t) for t in (tags or []) if t)), [] keep: List[str] = [] removed: List[str] = [] seen: Set[str] = set() for raw in (tags or []): t = _norm_tag_for_lookup(str(raw)) if not t: continue if t in excluded: removed.append(t) continue if t in seen: continue seen.add(t) keep.append(t) return keep, sorted(set(removed)) def _filter_excluded_candidates(candidates: List[Any]) -> Tuple[List[Any], List[str]]: excluded = _load_excluded_recommendation_tags() if not excluded: return list(candidates or []), [] keep: List[Any] = [] removed: List[str] = [] for c in (candidates or []): tag = _norm_tag_for_lookup(str(getattr(c, "tag", "") or "")) if tag and tag in excluded: removed.append(tag) continue keep.append(c) return keep, sorted(set(removed)) def _resolve_tag_classifier_threshold_path(raw: str) -> Optional[Path]: candidates: List[Path] = [] if raw: p = Path(raw) candidates.append(p if p.is_absolute() else APP_DIR / p) model_dir = classifier_model_dir if classifier_model_dir.is_absolute() else APP_DIR / classifier_model_dir candidates.extend( [ model_dir / "classifier_precision_thresholds_calibrated.csv", model_dir / "precision_thresholds_calibrated.csv", ] ) analysis_dir = APP_DIR / "data" / "analysis" try: latest = sorted( analysis_dir.glob("tag_classifier_precision_thresholds_*_calibrated.csv"), key=lambda p: p.stat().st_mtime, reverse=True, ) candidates.extend(latest) except Exception: pass for p in candidates: try: if p.is_file(): return p except Exception: continue return None def _load_tag_classifier_thresholds(path: Optional[Path], target_precision: float) -> Dict[str, float]: if not path or not path.is_file(): return {} out: Dict[str, float] = {} try: with path.open("r", encoding="utf-8", newline="") as f: for row in csv.DictReader(f): try: row_target = float(row.get("target_precision", "")) threshold = float(row.get("threshold", "")) except (TypeError, ValueError): continue if abs(row_target - float(target_precision)) > 1e-6: continue tag = _norm_tag_for_lookup(str(row.get("tag") or "")) if tag: out[tag] = threshold except Exception: return {} return out @lru_cache(maxsize=1) def _load_tag_classifier_bundle() -> Optional[Dict[str, Any]]: model_dir = classifier_model_dir if classifier_model_dir.is_absolute() else APP_DIR / classifier_model_dir labels_path = model_dir / "labels.json" if not model_dir.is_dir() or not labels_path.is_file(): return None threshold_path = _resolve_tag_classifier_threshold_path(classifier_threshold_path) thresholds = _load_tag_classifier_thresholds( threshold_path, target_precision=classifier_auto_precision_target, ) try: import torch from transformers import AutoModelForSequenceClassification, AutoTokenizer except Exception: return None try: labels_raw = json.loads(labels_path.read_text(encoding="utf-8")) labels = [_norm_tag_for_lookup(str(x)) for x in labels_raw] tokenizer = AutoTokenizer.from_pretrained(model_dir, local_files_only=True) device_raw = (os.environ.get("PSQ_TAG_CLASSIFIER_DEVICE", "cpu") or "cpu").strip().lower() if device_raw == "auto": device = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: device = torch.device(device_raw) load_kwargs: Dict[str, Any] = {"local_files_only": True} cpu_dtype = (os.environ.get("PSQ_TAG_CLASSIFIER_CPU_DTYPE", "float16") or "").strip().lower() if device.type == "cpu" and cpu_dtype in {"float16", "fp16", "half"}: # Keep the deployed classifier resident in fp16 on CPU by default. # This reduces RAM pressure on the Space; set CPU_DTYPE=float32 if needed. load_kwargs["dtype"] = torch.float16 model = AutoModelForSequenceClassification.from_pretrained(model_dir, **load_kwargs) if device.type == "cpu" and cpu_dtype in {"float32", "fp32", "full"}: model.float() model.to(device) model.eval() return { "torch": torch, "tokenizer": tokenizer, "model": model, "device": device, "labels": labels, "thresholds": thresholds, "threshold_path": str(threshold_path) if threshold_path else "", "model_dir": str(model_dir), } except Exception: return None def _classifier_tag_allowed(tag: str, min_tag_count: int, *, allow_nsfw: bool) -> bool: t = _norm_tag_for_lookup(tag) if not t or _is_artist_tag(t) or _is_excluded_recommendation_tag(t): return False if min_tag_count > 0: try: if int(get_tag_counts().get(t, 0) or 0) < min_tag_count: return False except Exception: return False if not allow_nsfw: try: if t in get_nsfw_tags(): return False except Exception: pass return True def _run_tag_classifier( prompt_in: str, *, min_tag_count: int, allow_nsfw: bool, log=None, ) -> Dict[str, Any]: t_func0 = time.perf_counter() empty = { "auto_tags": [], "candidate_tags": [], "score_by_tag": {}, "threshold_by_tag": {}, "enabled": False, "elapsed_s": 0.0, } if not classifier_enabled: return empty bundle = _load_tag_classifier_bundle() if not bundle: if log: log("Classifier: unavailable; skipping classifier auto/candidate tags") return empty torch = bundle["torch"] tokenizer = bundle["tokenizer"] model = bundle["model"] device = bundle["device"] labels: List[str] = bundle["labels"] thresholds: Dict[str, float] = bundle["thresholds"] max_len = max(16, int(os.environ.get("PSQ_TAG_CLASSIFIER_MAX_LEN", "160"))) text = " ".join(str(prompt_in or "").strip().split()) if not text: return empty try: with torch.no_grad(): enc = tokenizer( [text], padding=True, truncation=True, max_length=max_len, return_tensors="pt", ) enc = {k: v.to(device) for k, v in enc.items()} probs = torch.sigmoid(model(**enc).logits)[0].detach().cpu().tolist() except Exception as e: if log: log(f"Classifier: failed during inference; skipping ({type(e).__name__}: {_redact_console_error_text(e)})") return empty ranked = sorted( ((labels[i], float(score)) for i, score in enumerate(probs) if i < len(labels)), key=lambda kv: kv[1], reverse=True, ) auto_tags: List[str] = [] candidate_tags: List[str] = [] score_by_tag: Dict[str, float] = {} threshold_by_tag: Dict[str, float] = {} auto_set: Set[str] = set() for tag, score in ranked: if not _classifier_tag_allowed(tag, min_tag_count, allow_nsfw=allow_nsfw): continue threshold = thresholds.get(tag) if threshold is None or score < threshold: continue auto_tags.append(tag) auto_set.add(tag) score_by_tag[tag] = score threshold_by_tag[tag] = float(threshold) for tag, score in ranked: if len(candidate_tags) >= classifier_candidate_top_k: break if tag in auto_set: continue if not _classifier_tag_allowed(tag, min_tag_count, allow_nsfw=allow_nsfw): continue candidate_tags.append(tag) score_by_tag[tag] = score if log: elapsed_s = time.perf_counter() - t_func0 log( "Classifier: " f"auto={len(auto_tags)} target_precision={classifier_auto_precision_target:.2f} " f"candidate_top_k={len(candidate_tags)} " f"device={device} elapsed={elapsed_s:.2f}s" ) if auto_tags: shown = ", ".join(f"{t}:{score_by_tag.get(t, 0.0):.3f}" for t in auto_tags[:20]) log("Classifier auto tags: " + shown + (" ..." if len(auto_tags) > 20 else "")) if candidate_tags: shown = ", ".join(f"{t}:{score_by_tag.get(t, 0.0):.3f}" for t in candidate_tags[:20]) log("Classifier candidate tags: " + shown + (" ..." if len(candidate_tags) > 20 else "")) return { "auto_tags": auto_tags, "candidate_tags": candidate_tags, "score_by_tag": score_by_tag, "threshold_by_tag": threshold_by_tag, "enabled": True, "threshold_path": bundle.get("threshold_path", ""), "elapsed_s": time.perf_counter() - t_func0, } def _dedupe_norm_tags(tags: List[str]) -> List[str]: out: List[str] = [] seen: Set[str] = set() for raw in (tags or []): t = _norm_tag_for_lookup(str(raw)) if not t or t in seen: continue seen.add(t) out.append(t) return out def _collect_visible_tags(row_defs: List[Dict[str, Any]]) -> Set[str]: out: Set[str] = set() for row in (row_defs or []): for t in _dedupe_norm_tags(row.get("tags", []) if isinstance(row, dict) else []): out.add(t) return out def _collect_selected_from_state( selected_tags_state: List[str], row_defs: List[Dict[str, Any]], ) -> List[str]: visible_tags = _collect_visible_tags(row_defs) if not visible_tags: return [] selected: List[str] = [] seen: Set[str] = set() visible_by_norm = {_norm_tag_for_lookup(t): t for t in visible_tags} for raw in (selected_tags_state or []): t = _norm_tag_for_lookup(str(raw)) if not t: continue mapped = t if t in visible_tags else visible_by_norm.get(t) if not mapped or mapped in seen: continue seen.add(mapped) selected.append(mapped) return selected def _collect_selected_from_row_values( row_defs: List[Dict[str, Any]], row_values_state: List[List[str]], ) -> List[str]: selected: List[str] = [] seen: Set[str] = set() values = list(row_values_state or []) for idx, row in enumerate(row_defs or []): row_tags = _dedupe_norm_tags(row.get("tags", []) if isinstance(row, dict) else []) if not row_tags: continue row_tag_set = set(row_tags) row_tag_by_norm = {_norm_tag_for_lookup(t): t for t in row_tags} raw_vals = values[idx] if 0 <= idx < len(values) else [] for raw in (raw_vals or []): if raw in row_tag_set: if raw not in seen: seen.add(raw) selected.append(raw) continue raw_norm = _norm_tag_for_lookup(str(raw)) mapped = row_tag_by_norm.get(raw_norm) if mapped and mapped not in seen: seen.add(mapped) selected.append(mapped) return selected def _build_toggle_rows( *, seed_terms: List[str], selected_tags: List[str], retrieved_candidate_tags: List[str], tag_selection_origins: Dict[str, str], implied_parent_map: Dict[str, str], top_groups: int, top_tags_per_group: int, group_rank_top_k: int, ) -> List[Dict[str, Any]]: ranked_rows = rank_groups_from_tfidf( seed_terms=seed_terms, top_groups=max(1, int(top_groups)), top_tags_per_group=max(1, int(top_tags_per_group)), group_rank_top_k=max(1, int(group_rank_top_k)), ) groups_map = _load_enabled_groups() selected_active = list( dict.fromkeys( _norm_tag_for_lookup(t) for t in selected_tags if t and not _is_artist_tag(t) and not _is_excluded_recommendation_tag(t) ) ) selected_index: Dict[str, int] = {t: i for i, t in enumerate(selected_active)} selected_set: Set[str] = set(selected_active) def _is_pipeline_preselected(tag: str) -> bool: # Tags explicitly user-toggled should keep user coloring after row rebuilds. # Only non-user selected tags should render as pipeline-preselected. if tag not in selected_set: return False return _normalize_selection_origin(tag_selection_origins.get(tag, "selection")) != "user" row_defs: List[Dict[str, Any]] = [] enabled_group_tag_sets: Dict[str, Set[str]] = { name: {t for t in tags if not _is_artist_tag(t)} for name, tags in groups_map.items() } tags_in_any_enabled_group: Set[str] = set() for tag_set in enabled_group_tag_sets.values(): tags_in_any_enabled_group.update(tag_set) displayed_group_names = [r.group_name for r in ranked_rows] displayed_group_tag_sets: Dict[str, Set[str]] = { name: enabled_group_tag_sets.get(name, set()) for name in displayed_group_names } tags_in_any_displayed_group: Set[str] = set() for tag_set in displayed_group_tag_sets.values(): tags_in_any_displayed_group.update(tag_set) retrieved_uncategorized_ranked = list( dict.fromkeys( _norm_tag_for_lookup(t) for t in (retrieved_candidate_tags or []) if t and not _is_artist_tag(t) and not _is_excluded_recommendation_tag(t) and _norm_tag_for_lookup(t) not in tags_in_any_enabled_group ) ) retrieved_other_row: Dict[str, Any] | None = None if retrieved_uncategorized_ranked: retrieved_uncategorized_set = set(retrieved_uncategorized_ranked) selected_in_retrieved_other_raw = [ t for t in selected_active if t in retrieved_uncategorized_set ] selected_in_retrieved_other = _order_selected_tags_for_row( row_selected_tags=selected_in_retrieved_other_raw, selected_index=selected_index, tag_selection_origins=tag_selection_origins, implied_parent_map=implied_parent_map, ) merged_retrieved_other = selected_in_retrieved_other + [ t for t in retrieved_uncategorized_ranked if t not in selected_in_retrieved_other ] merged_retrieved_other = _dedupe_norm_tags(merged_retrieved_other) keep_n = max(max(1, int(top_tags_per_group)), len(selected_in_retrieved_other)) merged_retrieved_other = merged_retrieved_other[:keep_n] retrieved_other_meta = { t: { "origin": _normalize_selection_origin(tag_selection_origins.get(t, "selection")), "preselected": _is_pipeline_preselected(t), } for t in merged_retrieved_other } retrieved_other_row = { "name": "other_retrieved", "label": "Other (Retrieved)", "tags": merged_retrieved_other, "tag_meta": retrieved_other_meta, } # "Selected (Other)" should contain selected tags not already shown in any displayed row. # Include "Other (Retrieved)" in that displayed-row set to avoid duplicates across those rows. tags_in_displayed_rows = set(tags_in_any_displayed_group) if retrieved_other_row: tags_in_displayed_rows.update(retrieved_other_row.get("tags", [])) selected_other_raw = [t for t in selected_active if t not in tags_in_displayed_rows] selected_other = _order_selected_tags_for_row( row_selected_tags=selected_other_raw, selected_index=selected_index, tag_selection_origins=tag_selection_origins, implied_parent_map=implied_parent_map, ) selected_other = _dedupe_norm_tags(selected_other) selected_other_meta = { t: { "origin": _normalize_selection_origin(tag_selection_origins.get(t, "selection")), "preselected": _is_pipeline_preselected(t), } for t in selected_other } row_defs.append( { "name": "selected_other", "label": _display_row_label("selected_other"), "tags": selected_other, "tag_meta": selected_other_meta, } ) for row in ranked_rows: group_name = row.group_name group_tag_set = displayed_group_tag_sets.get(group_name, set()) selected_in_group_raw = [t for t in selected_active if t in group_tag_set] selected_in_group = _order_selected_tags_for_row( row_selected_tags=selected_in_group_raw, selected_index=selected_index, tag_selection_origins=tag_selection_origins, implied_parent_map=implied_parent_map, ) ranked_tags = [ _norm_tag_for_lookup(t) for t, _ in row.tags if t and not _is_artist_tag(t) and not _is_excluded_recommendation_tag(t) ] ranked_tags = _dedupe_norm_tags(ranked_tags) merged = selected_in_group + [t for t in ranked_tags if t not in selected_in_group] merged = _dedupe_norm_tags(merged) keep_n = max(max(1, int(top_tags_per_group)), len(selected_in_group)) merged = merged[:keep_n] tag_meta = { t: { "origin": _normalize_selection_origin(tag_selection_origins.get(t, "selection")), "preselected": _is_pipeline_preselected(t), } for t in merged } row_defs.append( { "name": group_name, "label": _display_row_label(group_name), "tags": merged, "tag_meta": tag_meta, } ) # Keep this row at the bottom so category/group rows remain contiguous. if retrieved_other_row: row_defs.append(retrieved_other_row) return row_defs def _build_display_audit_line( row_defs: List[Dict[str, Any]], *, active_selected_tags: List[str], direct_selected_tags: List[str], implied_selected_tags: List[str], ) -> str: active_set = { _norm_tag_for_lookup(t) for t in (active_selected_tags or []) if t and not _is_artist_tag(t) } direct_set = { _norm_tag_for_lookup(t) for t in (direct_selected_tags or []) if t and not _is_artist_tag(t) } implied_set = { _norm_tag_for_lookup(t) for t in (implied_selected_tags or []) if t and not _is_artist_tag(t) } info_by_tag: Dict[str, Dict[str, Any]] = {} for row in row_defs or []: row_name = row.get("name", "") row_label = row.get("label", row_name) for tag in row.get("tags", []): rec = info_by_tag.setdefault(tag, {"rows": [], "sources": set()}) rec["rows"].append(row_label) if row_name == "selected_other": rec["sources"].add("selected_other_row") elif row_name == "other_retrieved": rec["sources"].add("other_retrieved_row") else: rec["sources"].add("ranked_group_row") if tag in active_set: rec["sources"].add("selected_active") if tag in direct_set: rec["sources"].add("selected_direct") if tag in implied_set: rec["sources"].add("selected_implied") payload = { "n_tags": len(info_by_tag), "tags": [ { "tag": tag, "rows": rec["rows"], "sources": sorted(rec["sources"]), } for tag, rec in sorted(info_by_tag.items()) ], } return "Display Tag Audit: " + json.dumps(payload, ensure_ascii=True) def _build_tooltip_payload(row_defs: List[Dict[str, Any]], max_rows: int) -> str: row_defs_ui = (row_defs or [])[: max(0, int(max_rows))] tips: Dict[str, str] = {} rows: List[List[str]] = [] meta_rows: List[List[Dict[str, Any]]] = [] for row in row_defs_ui: tags = _dedupe_norm_tags(row.get("tags", []) if isinstance(row, dict) else []) rows.append(tags) row_meta = row.get("tag_meta", {}) if isinstance(row, dict) and isinstance(row.get("tag_meta", {}), dict) else {} meta_row: List[Dict[str, Any]] = [] for t in tags: raw_meta = row_meta.get(t, {}) if isinstance(row_meta.get(t, {}), dict) else {} meta_row.append( { "origin": _normalize_selection_origin(str(raw_meta.get("origin", "selection"))), "preselected": bool(raw_meta.get("preselected", False)), } ) meta_rows.append(meta_row) for t in tags: if t not in tips: tips[t] = _tooltip_text_for_tag(t) return json.dumps({"rows": rows, "meta_rows": meta_rows, "tips": tips}, ensure_ascii=True) def _build_row_component_updates( row_defs: List[Dict[str, Any]], selected_tags: List[str], max_rows: int, ): selected = {t for t in (selected_tags or []) if t} row_defs_ui = (row_defs or [])[: max(0, int(max_rows))] row_values_state: List[List[str]] = [] header_updates = [] checkbox_updates = [] for idx in range(max_rows): if idx < len(row_defs_ui): row = row_defs_ui[idx] tags = _dedupe_norm_tags(row.get("tags", [])) values = [t for t in tags if t in selected] row_values_state.append(values) visible = bool(tags) header_updates.append(gr.update(value=row.get("label", ""), visible=visible)) tag_meta = row.get("tag_meta", {}) if isinstance(row.get("tag_meta", {}), dict) else {} choices = [] for t in tags: meta = tag_meta.get(t, {}) if isinstance(tag_meta.get(t, {}), dict) else {} origin = _normalize_selection_origin(str(meta.get("origin", "selection"))) preselected = bool(meta.get("preselected", False)) choices.append((_choice_label_with_source_meta(t, origin=origin, preselected=preselected), t)) checkbox_updates.append( gr.update( choices=choices, value=values, visible=visible, ) ) else: header_updates.append(gr.update(value="", visible=False)) checkbox_updates.append(gr.update(choices=[], value=[], visible=False)) prompt_text = _compose_toggle_prompt_text(list(selected), row_defs_ui) return prompt_text, row_values_state, header_updates, checkbox_updates def _on_toggle_row( row_idx: int, changed_values: List[str], selected_tags_state: List[str], rows_dirty_state: bool, row_defs_state: List[Dict[str, Any]], row_values_state: List[List[str]], max_rows: int, ): row_defs = row_defs_state or [] row_defs_ui = row_defs[: max(0, int(max_rows))] prev_values = list(row_values_state or []) selected_from_state = _collect_selected_from_state(selected_tags_state, row_defs_ui) selected_from_rows = _collect_selected_from_row_values(row_defs_ui, prev_values) # Prefer row-value state as source-of-truth (closest to visible UI), with selected-state as fallback. selected: Set[str] = set(selected_from_rows or selected_from_state) row = row_defs_ui[row_idx] if 0 <= row_idx < len(row_defs_ui) else {} row_tags = _dedupe_norm_tags(row.get("tags", [])) row_label = str(row.get("label", "")) row_tag_set = set(row_tags) row_tag_by_norm = {_norm_tag_for_lookup(t): t for t in row_tags} # Be tolerant to UI payload forms: canonical tag values, display labels, normalized variants, # and occasional single-string payloads from frontend events. if changed_values is None: changed_iter: List[Any] = [] elif isinstance(changed_values, str): changed_iter = [changed_values] elif isinstance(changed_values, (list, tuple, set)): changed_iter = list(changed_values) else: changed_iter = [changed_values] # Be tolerant to UI payload forms: canonical tag values, display labels, or normalized variants. new_set: Set[str] = set() for raw in changed_iter: if raw in row_tag_set: new_set.add(raw) continue raw_norm = _norm_tag_for_lookup(str(raw)) mapped = row_tag_by_norm.get(raw_norm) if mapped: new_set.add(mapped) prev_row_selected = {t for t in row_tags if t in selected} # Ignore non-user/no-op events (e.g., programmatic value re-sets) deterministically. if new_set == prev_row_selected: prompt_text = _compose_toggle_prompt_text(sorted(selected), row_defs_ui) checkbox_updates = [gr.skip() for _ in range(max_rows)] return [sorted(selected), rows_dirty_state, gr.skip(), prev_values, prompt_text, *checkbox_updates] selected.difference_update(row_tag_set) selected.update(new_set) toggled_tags = prev_row_selected ^ new_set new_row_values_state: List[List[str]] = [] affected_rows: Set[int] = {row_idx} for idx, row_item in enumerate(row_defs_ui): tags = _dedupe_norm_tags(row_item.get("tags", [])) values = [t for t in tags if t in selected] new_row_values_state.append(values) if toggled_tags and any(t in toggled_tags for t in tags): affected_rows.add(idx) checkbox_updates = [] for idx in range(max_rows): if idx >= len(row_defs_ui): checkbox_updates.append(gr.skip()) continue if idx in affected_rows: checkbox_updates.append(gr.update(value=new_row_values_state[idx])) else: checkbox_updates.append(gr.skip()) prompt_text = _compose_toggle_prompt_text(sorted(selected), row_defs_ui) return [ sorted(selected), True, gr.update(visible=True, interactive=True), new_row_values_state, prompt_text, *checkbox_updates, ] def _build_ui_payload( *, console_text: str, row_defs: List[Dict[str, Any]], selected_tags: List[str], pipeline_status_html: str | None = None, suggested_prompt_text: str | None = None, suggested_prompt_label: str | None = None, ): prompt_text, row_values_state, header_updates, checkbox_updates = _build_row_component_updates( row_defs=row_defs, selected_tags=selected_tags, max_rows=display_max_rows_default, ) if suggested_prompt_text is not None: prompt_text = str(suggested_prompt_text) effective_label = suggested_prompt_label or SUGGESTED_PROMPT_LABEL_READY selected_ui: List[str] = [] selected_ui_seen: Set[str] = set() for vals in row_values_state: for t in vals: if t in selected_ui_seen: continue selected_ui_seen.add(t) selected_ui.append(t) tooltip_payload = _build_tooltip_payload(row_defs, display_max_rows_default) return [ console_text, gr.update(visible=bool(row_defs)), tooltip_payload, gr.update(value=pipeline_status_html) if pipeline_status_html is not None else gr.skip(), gr.update(value=prompt_text, label=effective_label), selected_ui, False, gr.update(visible=False, interactive=False), row_defs, row_values_state, *header_updates, *checkbox_updates, ] _PIPELINE_STAGE_DEFS: Tuple[Tuple[str, str, str], ...] = ( ("rewrite", "Query Reformulation", "local T5"), ("exact", "Lexical Matching", "lexical"), ("structural", "Scene Composition", "Mistral 24B"), ("classifier", "Tag Classifier", "ModernBERT"), ("retrieval", "Semantic Retrieval", "FastText/HNSW"), ("reranker", "Candidate Ranking", "Mistral 24B"), ) def _pipeline_status_html(states: Optional[Dict[str, Dict[str, Any]]] = None) -> str: states = states or {} symbol_by_status = { "waiting": "○", "running": "⟳", "done": "✓", "failed": "!", "skipped": "-", } parts = ['
'] for key, name, implementation in _PIPELINE_STAGE_DEFS: raw = states.get(key, {}) or {} status = str(raw.get("status") or "waiting").lower() if status not in symbol_by_status: status = "waiting" outcome_parts: List[str] = [] if isinstance(raw.get("time_s"), (int, float)): outcome_parts.append(f"{float(raw['time_s']):.1f}s") outcome_html = "" if outcome_parts: outcome_html = f'{" · ".join(html.escape(p) for p in outcome_parts)}' parts.append( ''.format(html.escape(status)) + f'{symbol_by_status[status]}' + f'{html.escape(name)}' + f'{html.escape(implementation)}' + outcome_html + "" ) parts.append("
") return "".join(parts) def _parse_console_stage_seconds(console_text: str, stage_name: str) -> Optional[float]: pattern = re.compile(rf"^{re.escape(stage_name)}:\s+([0-9]+(?:\.[0-9]+)?)s\s*$") for line in reversed(str(console_text or "").splitlines()): m = pattern.match(line.strip()) if not m: continue try: return float(m.group(1)) except ValueError: return None return None def _format_user_facing_error(exc: Exception) -> str: msg = str(exc or "").strip() msg_l = msg.lower() if "rewrite: empty output" in msg_l: return ( "Could not rewrite that prompt. Try simpler, neutral wording and remove sensitive phrasing, " "then click Run again." ) if "openrouter_api_key" in msg_l: return "Service configuration is missing. Please contact the app owner." if "timed out" in msg_l: return "The model request timed out. Please try again with a shorter or simpler prompt." if "index selection failed" in msg_l: return "Tag selection failed for this request. Please try again." if "startup preflight failed" in msg_l: return "App startup checks failed. Please contact the app owner." return "Something went wrong while processing the prompt. Please try again." def _prepare_run_ui() -> List[Any]: header_updates = [gr.update(value="", visible=False) for _ in range(display_max_rows_default)] checkbox_updates = [ gr.update(choices=[], value=[], visible=False) for _ in range(display_max_rows_default) ] return [ "Running...", gr.skip(), "{}", _pipeline_status_html({"rewrite": {"status": "running"}}), gr.update( value="", label=SUGGESTED_PROMPT_LABEL_READY, ), [], False, gr.update(visible=False, interactive=False), [], [], *header_updates, *checkbox_updates, ] def _prepare_run_ui_with_rewrite_state() -> List[Any]: return [*_prepare_run_ui(), ""] def _recognized_rewrite_tags(rewritten: str, min_tag_count: int) -> List[str]: pieces = [p.strip() for p in (rewritten or "").split(",") if p.strip()] norm = list(dict.fromkeys(_norm_tag_for_lookup(p) for p in pieces if p)) norm = [t for t in norm if t and not _is_artist_tag(t) and not _is_excluded_recommendation_tag(t)] if not norm: return [] tag_counts = get_tag_counts() out: List[str] = [] for t in norm: c = int(tag_counts.get(t, 0) or 0) if c <= 0: continue if min_tag_count > 0 and c < min_tag_count: continue out.append(t) return out def _rewrite_preview_ui( user_prompt: str, ) -> List[Any]: logs: List[str] = [] def log(s: Any) -> None: logs.append(str(s)) prompt_in = (user_prompt or "").strip() if not prompt_in: return [*_build_ui_payload( console_text="Error: empty prompt", row_defs=[], selected_tags=[], pipeline_status_html=_pipeline_status_html(), suggested_prompt_text='Enter a prompt and click "Run".', ), ""] log("Start: received prompt") log("Input:") log(prompt_in) log("") log("Step 1/2: query reformulation preview") try: t0 = time.perf_counter() rewritten = _rewrite_prompt(prompt_in, log) dt = time.perf_counter() - t0 if not rewritten: raise RuntimeError("Rewrite: empty output") log(f"Rewrite: {dt:.2f}s") log("Rewrite:") log(rewritten) min_tag_count = _get_min_tag_count() recognized = _recognized_rewrite_tags(rewritten, min_tag_count=min_tag_count) t_exact0 = time.perf_counter() exact_query_phrases = extract_exact_tag_query_phrases( prompt_in, get_tag_counts(), get_alias2tags(), min_tag_count=min_tag_count, max_ngram=max(0, retrieval_exact_ngram_max), ) exact_query_phrases, _removed_exact_excluded = _filter_excluded_recommendation_tags(exact_query_phrases) exact_dt = time.perf_counter() - t_exact0 if recognized: log(f"Preview: recognized {len(recognized)} rewrite tags") else: log("Preview: no recognized rewrite tags; continuing") status_html = _pipeline_status_html( { "rewrite": { "status": "done", "time_s": dt, "detail": f"{len(recognized)} tags" if recognized else "no recognized tags", }, "exact": { "status": "done", "time_s": exact_dt, "detail": f"{len(exact_query_phrases)} matches" if exact_query_phrases else "no matches", }, } ) log("Step 2/2: retrieval + selection in progress...") return [*_build_ui_payload( console_text="\n".join(logs), row_defs=[], selected_tags=[], pipeline_status_html=status_html, suggested_prompt_text="", suggested_prompt_label=SUGGESTED_PROMPT_LABEL_READY, ), rewritten] except Exception as e: log(f"Rewrite preview failed: {type(e).__name__}: {_redact_console_error_text(e)}") return [*_build_ui_payload( console_text="\n".join(logs), row_defs=[], selected_tags=[], pipeline_status_html=_pipeline_status_html({"rewrite": {"status": "failed", "detail": "see console"}}), suggested_prompt_text="Rewrite failed; see console details.", suggested_prompt_label=SUGGESTED_PROMPT_LABEL_PHASE_REWRITE, ), ""] def _update_run_button_visibility(prompt_text: str, last_run_prompt: str): curr = (prompt_text or "").strip() last = (last_run_prompt or "").strip() can_run = bool(curr) and curr != last return gr.update(visible=can_run, interactive=can_run) def _mark_run_triggered(prompt_text: str): curr = (prompt_text or "").strip() return gr.update(visible=False, interactive=False), curr def _rebuild_rows_from_selected( selected_tags_state: List[str], row_defs_state: List[Dict[str, Any]], row_values_state: List[List[str]], display_top_groups: float, display_top_tags_per_group: float, display_rank_top_k: float, ): existing_rows = row_defs_state or [] existing_values = list(row_values_state or []) selected_from_state = _collect_selected_from_state(selected_tags_state, existing_rows) selected_from_rows = _collect_selected_from_row_values(existing_rows, existing_values) # Rebuild source-of-truth is current row checkbox values; fall back only when unavailable. selected_seed = selected_from_rows if existing_values else selected_from_state selected_active = list( dict.fromkeys( _norm_tag_for_lookup(t) for t in selected_seed if t and not _is_artist_tag(t) and not _is_excluded_recommendation_tag(t) ) ) selected_active_set: Set[str] = set(selected_active) retrieved_candidate_tags: List[str] = [] tag_selection_origins: Dict[str, str] = {} for row in existing_rows: row_tags = row.get("tags", []) if isinstance(row, dict) else [] row_meta = row.get("tag_meta", {}) if isinstance(row, dict) else {} if not isinstance(row_meta, dict): row_meta = {} for t in row_tags: tn = _norm_tag_for_lookup(t) if not tn or _is_artist_tag(tn) or _is_excluded_recommendation_tag(tn): continue retrieved_candidate_tags.append(tn) if tn not in tag_selection_origins: meta = row_meta.get(t, {}) if isinstance(row_meta.get(t, {}), dict) else {} meta_origin = _normalize_selection_origin(str(meta.get("origin", "selection"))) meta_preselected = bool(meta.get("preselected", False)) # Preserve explicit user toggles across rebuild: # if a currently selected tag was not pipeline-preselected, keep user provenance. if tn in selected_active_set and not meta_preselected: meta_origin = "user" tag_selection_origins[tn] = meta_origin for t in selected_active: tag_selection_origins.setdefault(t, "user") retrieved_candidate_tags.append(t) implied_selected_tags = [t for t in selected_active if tag_selection_origins.get(t) == "implied"] implied_set = set(implied_selected_tags) direct_selected_tags = [t for t in selected_active if t not in implied_set] direct_idx = {t: i for i, t in enumerate(direct_selected_tags)} direct_selected_tags.sort( key=lambda t: ( _selection_source_rank(tag_selection_origins.get(t, "selection")), direct_idx.get(t, 10**9), ) ) implied_parent_map = _build_implied_parent_map( direct_tags_ordered=direct_selected_tags, implied_tags=implied_selected_tags, ) toggle_rows = _build_toggle_rows( seed_terms=list(selected_active), selected_tags=selected_active, retrieved_candidate_tags=list(dict.fromkeys(retrieved_candidate_tags)), tag_selection_origins=tag_selection_origins, implied_parent_map=implied_parent_map, top_groups=max(1, int(display_top_groups)), top_tags_per_group=max(1, int(display_top_tags_per_group)), group_rank_top_k=max(1, int(display_rank_top_k)), ) prompt_text, row_values_state, header_updates, checkbox_updates = _build_row_component_updates( row_defs=toggle_rows, selected_tags=selected_active, max_rows=display_max_rows_default, ) tooltip_payload = _build_tooltip_payload(toggle_rows, display_max_rows_default) return [ gr.update(visible=bool(toggle_rows)), tooltip_payload, prompt_text, sorted(selected_active), False, gr.update(visible=False, interactive=False), toggle_rows, row_values_state, *header_updates, *checkbox_updates, ] def _build_selection_query( prompt_in: str, rewritten: str, structural_tags: List[str], probe_tags: List[str], classifier_auto_tags: Optional[List[str]] = None, ) -> str: lines = [f"IMAGE DESCRIPTION: {prompt_in.strip()}"] if rewritten and rewritten.strip(): lines.append(f"REWRITE PHRASES: {rewritten.strip()}") hint_tags = [] if structural_tags: hint_tags.extend(structural_tags) if probe_tags: hint_tags.extend(probe_tags) if classifier_auto_tags: hint_tags.extend(classifier_auto_tags) if hint_tags: # Keep hints as context only; selection still must choose by candidate indices. lines.append( "INFERRED TAG HINTS (context only): " + ", ".join(sorted(set(hint_tags))) ) return "\n".join(lines) # Set up logging # Minimal prod logging: warnings+ to stderr, no file by default import os, logging LOG_LEVEL = os.environ.get("PSQ_LOG_LEVEL", "WARNING").upper() logging.basicConfig( level=getattr(logging, LOG_LEVEL, logging.WARNING), format="%(asctime)s %(levelname)s:%(message)s", handlers=[logging.StreamHandler()] # no file -> avoids huge logs on Spaces ) # Quiet down common noisy libs (optional) for _name in ("gensim", "gradio", "hnswlib", "httpx", "uvicorn"): logging.getLogger(_name).setLevel(logging.ERROR) # Turn off Gradio analytics phone-home to avoid those background thread errors (optional) os.environ["GRADIO_ANALYTICS_ENABLED"] = "0" MASCOT_DIR = Path(__file__).parent / "mascotimages" MASCOT_FILE = MASCOT_DIR / "transparentsquirrel.png" def _load_mascot_image(): """Load mascot image if available; return None when missing/unreadable.""" if not MASCOT_FILE.exists(): logging.warning("Mascot image missing: %s", MASCOT_FILE) return None try: return Image.open(MASCOT_FILE).convert("RGBA") except Exception as e: logging.warning("Failed to load mascot image (%s): %s", MASCOT_FILE, e) return None try: from gradio_client import utils as _gc_utils _orig_get_type = _gc_utils.get_type _orig_j2p = _gc_utils._json_schema_to_python_type _orig_pub = _gc_utils.json_schema_to_python_type def _get_type_safe(schema): # Sometimes schema is a bare True/False (JSON Schema boolean form) if not isinstance(schema, dict): return "any" return _orig_get_type(schema) def _j2p_safe(schema, defs=None): # Accept non-dict schemas (True/False/None) and treat as "any" if not isinstance(schema, dict): return "any" return _orig_j2p(schema, defs or schema.get("$defs")) def _pub_safe(schema): # Public wrapper used by Gradio; keep it resilient too if not isinstance(schema, dict): return "any" return _j2p_safe(schema, schema.get("$defs")) _gc_utils.get_type = _get_type_safe _gc_utils._json_schema_to_python_type = _j2p_safe _gc_utils.json_schema_to_python_type = _pub_safe except Exception as e: print("gradio_client hotfix not applied:", e) # ------------------------------------------------------------------------------- allow_nsfw_tags = False def _is_production_runtime() -> bool: """Best-effort detection for deployed runtime (HF Spaces or explicit env).""" if os.environ.get("PSQ_PRODUCTION", "").strip().lower() in {"1", "true", "yes"}: return True if os.environ.get("SPACE_ID"): return True if os.environ.get("HF_SPACE_ID"): return True if os.environ.get("SYSTEM") == "spaces": return True return False verbose_retrieval_default = "0" if _is_production_runtime() else "1" verbose_retrieval = os.environ.get("PSQ_VERBOSE_RETRIEVAL", verbose_retrieval_default).strip().lower() in {"1", "true", "yes"} verbose_retrieval_all = False verbose_retrieval_limit = 20 display_top_groups_default = int(os.environ.get("PSQ_DISPLAY_TOP_GROUPS", "10")) display_top_tags_per_group_default = int(os.environ.get("PSQ_DISPLAY_TOP_TAGS_PER_GROUP", "7")) display_rank_top_k_default = int(os.environ.get("PSQ_DISPLAY_GROUP_RANK_TOP_K", "7")) display_max_rows_default = int(os.environ.get("PSQ_DISPLAY_MAX_ROWS", "14")) retrieval_global_k = int(os.environ.get("PSQ_RETRIEVAL_GLOBAL_K", "300")) retrieval_per_phrase_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_K", "10")) retrieval_per_phrase_final_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_FINAL_K", "1")) retrieval_exact_ngram_max = int(os.environ.get("PSQ_RETRIEVAL_EXACT_NGRAM_MAX", "2")) selection_mode = os.environ.get("PSQ_SELECTION_MODE", "chunked_map_union").strip() selection_chunk_size = int(os.environ.get("PSQ_SELECTION_CHUNK_SIZE", "60")) selection_per_phrase_k = int(os.environ.get("PSQ_SELECTION_PER_PHRASE_K", "2")) selection_candidate_cap = int(os.environ.get("PSQ_SELECTION_CANDIDATE_CAP", "0")) stage1_rewrite_timeout_s = float(os.environ.get("PSQ_TIMEOUT_REWRITE_S", "45")) stage1_struct_timeout_s = float(os.environ.get("PSQ_TIMEOUT_STRUCT_S", "45")) stage1_classifier_timeout_s = float(os.environ.get("PSQ_TIMEOUT_CLASSIFIER_S", "60")) stage3_select_timeout_s = float(os.environ.get("PSQ_TIMEOUT_SELECT_S", "50")) stage3_select_retry_timeout_s = float(os.environ.get("PSQ_TIMEOUT_SELECT_RETRY_S", "20")) stage3_fast_retry_count = max(0, int(os.environ.get("PSQ_STAGE3_FAST_RETRY_COUNT", "1"))) classifier_enabled = os.environ.get("PSQ_TAG_CLASSIFIER_ENABLED", "1").strip().lower() not in {"0", "false", "no", "off"} classifier_model_dir = Path(os.environ.get("PSQ_TAG_CLASSIFIER_MODEL_DIR", "models/finetune/tag-classifier-modernbert-full")) classifier_threshold_path = os.environ.get("PSQ_TAG_CLASSIFIER_THRESHOLDS", "").strip() classifier_auto_precision_target = float(os.environ.get("PSQ_TAG_CLASSIFIER_AUTO_PRECISION", "0.95")) classifier_candidate_top_k = max(0, int(os.environ.get("PSQ_TAG_CLASSIFIER_CANDIDATE_TOP_K", "20"))) timing_log_path = Path(os.environ.get("PSQ_TIMING_LOG_PATH", "data/runtime_metrics/ui_pipeline_timings.jsonl")) def _startup_preflight_errors() -> List[str]: errs: List[str] = [] if not os.getenv("OPENROUTER_API_KEY"): errs.append("OPENROUTER_API_KEY is missing. Set it in Space Secrets or environment variables.") return errs STARTUP_PREFLIGHT_ERRORS = _startup_preflight_errors() if STARTUP_PREFLIGHT_ERRORS: for _err in STARTUP_PREFLIGHT_ERRORS: logging.error("Startup preflight error: %s", _err) _startup_profile_mark( "startup_preflight.done", error_count=len(STARTUP_PREFLIGHT_ERRORS), ) css = """ .scrollable-content{ max-height: 420px; overflow-y: scroll; /* always show scrollbar */ overflow-x: hidden; padding-right: 8px; padding-bottom: 14px; /* <— add this */ scrollbar-gutter: stable; /* prevent layout shift as it fills */ /* Firefox */ scrollbar-width: auto; scrollbar-color: rgba(180,180,180,.9) rgba(0,0,0,.15); } /* WebKit/Chromium (Chrome/Edge/Safari) */ .scrollable-content::-webkit-scrollbar{ width: 10px; } .scrollable-content::-webkit-scrollbar-thumb{ background: rgba(180,180,180,.9); border-radius: 8px; } .scrollable-content::-webkit-scrollbar-track{ background: rgba(0,0,0,.15); } /* (Optional) make both scroll panes taller so they fill more of the column */ .pane-left .scrollable-content, .pane-right .scrollable-content { max-height: 610px; /* was 420px; tweak to taste */ } /* Console: force internal scrolling so full logs are always reachable */ #psq-console, #psq-console .scroll-hide, #psq-console textarea { overflow-y: auto !important; overflow-x: hidden !important; max-height: 420px !important; min-height: 240px !important; } #psq-console textarea { white-space: pre-wrap !important; word-break: break-word !important; font-family: ui-monospace, SFMono-Regular, Menlo, Consolas, "Liberation Mono", monospace !important; } .lego-tags .gr-checkboxgroup, .lego-tags .wrap { display: flex !important; flex-wrap: wrap !important; gap: 10px !important; } .lego-tags label { margin: 0 !important; padding: 0 !important; position: relative !important; } /* Hide native checkbox visuals completely */ .lego-tags input[type="checkbox"] { appearance: none !important; -webkit-appearance: none !important; -moz-appearance: none !important; position: absolute !important; width: 1px !important; height: 1px !important; opacity: 0 !important; pointer-events: none !important; display: none !important; } /* Brick button skin (works for both +span and ~span structures) */ .lego-tags input[type="checkbox"] + span, .lego-tags input[type="checkbox"] ~ span { --on-bg1: #ffd166; --on-bg2: #f39c4a; --on-border: #b86e21; --on-text: #2e1706; position: relative !important; display: inline-flex !important; align-items: center !important; min-height: 40px !important; padding: 10px 15px 9px 22px !important; border: 1px solid #9aa6b8 !important; border-radius: 10px !important; background: linear-gradient(180deg, #dfe5ee 0%, #bec8d6 100%) !important; color: #364254 !important; font-size: 0.97rem !important; font-weight: 800 !important; line-height: 1.15 !important; cursor: pointer !important; user-select: none !important; letter-spacing: 0.01em !important; box-shadow: 0 3px 0 rgba(0,0,0,0.16), inset 0 1px 0 rgba(255,255,255,0.55) !important; transition: transform 0.08s ease, box-shadow 0.08s ease, filter 0.08s ease !important; } .lego-tags input[type="checkbox"] + span::before, .lego-tags input[type="checkbox"] ~ span::before { content: "" !important; position: absolute !important; top: 5px !important; left: 8px !important; width: 8px !important; height: 8px !important; border-radius: 50% !important; background: rgba(255,255,255,0.58) !important; box-shadow: 22px 0 0 rgba(255,255,255,0.58) !important; pointer-events: none !important; } /* Unselected cue: show "+" on the left. */ .lego-tags input[type="checkbox"] + span::after, .lego-tags input[type="checkbox"] ~ span::after { content: "+" !important; position: absolute !important; left: 6px !important; top: 50% !important; transform: translateY(-52%) !important; font-size: 1rem !important; font-weight: 900 !important; color: #4b5563 !important; opacity: 0.95 !important; pointer-events: none !important; } /* Bright color cycle used only when selected */ .lego-tags label:nth-child(8n+1) span { --on-bg1: #ffd166; --on-bg2: #f39c4a; --on-border: #b86e21; --on-text: #2e1706; } .lego-tags label:nth-child(8n+2) span { --on-bg1: #6ee7ff; --on-bg2: #1fb7ff; --on-border: #157cb3; --on-text: #07263c; } .lego-tags label:nth-child(8n+3) span { --on-bg1: #9dff8f; --on-bg2: #45c96f; --on-border: #2a8b4b; --on-text: #0d2917; } .lego-tags label:nth-child(8n+4) span { --on-bg1: #ff8fab; --on-bg2: #ff5c7a; --on-border: #b83956; --on-text: #3f0f1d; } .lego-tags label:nth-child(8n+5) span { --on-bg1: #d0a8ff; --on-bg2: #a46cff; --on-border: #7147b3; --on-text: #25143f; } .lego-tags label:nth-child(8n+6) span { --on-bg1: #ffe27a; --on-bg2: #f7bf39; --on-border: #ad7f1f; --on-text: #332407; } .lego-tags label:nth-child(8n+7) span { --on-bg1: #8effd5; --on-bg2: #2ed6b5; --on-border: #1e947d; --on-text: #0d2a25; } .lego-tags label:nth-child(8n+8) span { --on-bg1: #ffb47e; --on-bg2: #ff8753; --on-border: #b95b2d; --on-text: #391a0a; } /* Source-driven selected colors (applies when tags are preselected by the pipeline). */ .lego-tags label[data-psq-preselected="1"][data-psq-origin="rewrite"] span { --on-bg1: #77f0d7; --on-bg2: #26b9a3; --on-border: #187869; --on-text: #062923; } .lego-tags label[data-psq-preselected="1"][data-psq-origin="selection"] span { --on-bg1: #ffd98a; --on-bg2: #f0a93c; --on-border: #a66f1f; --on-text: #382206; } .lego-tags label[data-psq-preselected="1"][data-psq-origin="probe"] span { --on-bg1: #d8b4ff; --on-bg2: #9a6cff; --on-border: #6745b0; --on-text: #24143b; } .lego-tags label[data-psq-preselected="1"][data-psq-origin="structural"] span { --on-bg1: #a6f79a; --on-bg2: #53c368; --on-border: #2f8442; --on-text: #102d17; } .lego-tags label[data-psq-preselected="1"][data-psq-origin="classifier"] span { --on-bg1: #bed5ff; --on-bg2: #79a7ff; --on-border: #426ad3; --on-text: #10204f; } .lego-tags label[data-psq-preselected="1"][data-psq-origin="implied"] span { --on-bg1: #d7dde8; --on-bg2: #a8b3c4; --on-border: #6f7e95; --on-text: #1d2633; } /* User-selected tags (not initially selected by the pipeline). */ .lego-tags label[data-psq-preselected="0"] span { --on-bg1: #9ec5ff; --on-bg2: #4f86ff; --on-border: #2f5fbf; --on-text: #0b1f42; } .lego-tags label:hover span { filter: brightness(1.02) !important; transform: translateY(1px) !important; } /* ON state: brighter + visibly recessed */ .lego-tags input[type="checkbox"]:checked + span, .lego-tags input[type="checkbox"]:checked ~ span, .lego-tags label:has(input[type="checkbox"]:checked) span { background: linear-gradient(180deg, var(--on-bg1) 0%, var(--on-bg2) 100%) !important; color: var(--on-text) !important; border-color: var(--on-border) !important; filter: saturate(1.2) brightness(1.12) !important; transform: translateY(-2px) !important; box-shadow: inset 0 3px 6px rgba(0,0,0,0.20), inset 0 -1px 0 rgba(255,255,255,0.36), 0 6px 0 rgba(0,0,0,0.32) !important; } .lego-tags input[type="checkbox"]:checked + span::after, .lego-tags input[type="checkbox"]:checked ~ span::after, .lego-tags label:has(input[type="checkbox"]:checked) span::after { content: "" !important; } .source-legend { display: flex; flex-wrap: wrap; align-items: center; gap: 8px; margin: 4px 0 10px 0; } .source-legend .legend-title { font-size: 0.92rem; font-weight: 900; color: #334155; margin-right: 4px; } .source-legend .chip { display: inline-flex; align-items: center; border-radius: 10px; border: 1px solid #6c7788; padding: 6px 12px; font-size: 0.85rem; font-weight: 800; color: #111827; background: #f3f6fb; } .source-legend .chip.rewrite { background: #26b9a3; color: #062923; border-color: #187869; } .source-legend .chip.selection { background: #f0a93c; color: #382206; border-color: #a66f1f; } .source-legend .chip.probe { background: #9a6cff; color: #ffffff; border-color: #6745b0; } .source-legend .chip.structural { background: #53c368; color: #102d17; border-color: #2f8442; } .source-legend .chip.classifier { background: #79a7ff; color: #10204f; border-color: #426ad3; } .source-legend .chip.implied { background: #a8b3c4; color: #1d2633; border-color: #6f7e95; } .source-legend .chip.user { background: #4f86ff; color: #ffffff; border-color: #2f5fbf; } .source-legend .chip.unselected { background: #c7ced8; color: #2d3440; border-color: #7d8897; } .row-heading p { margin: 8px 0 0 0 !important; font-size: 1.18rem !important; font-weight: 850 !important; line-height: 1.2 !important; } .row-instruction { text-align: center; margin: 8px 0 12px 0; } .row-instruction p { margin: 0 !important; font-size: 1.02rem !important; font-style: italic !important; font-weight: 800 !important; color: #1d4ed8 !important; } .about-docs { margin-top: 4px; } .about-docs > p { line-height: 1.42 !important; } .about-docs img { max-width: 100% !important; height: auto !important; border: 1px solid #d2d7e0; border-radius: 10px; background: #ffffff; } .arch-diagram-wrap { margin: 6px 0 10px 0; } .arch-diagram-wrap h2 { margin: 0 0 8px 0 !important; } .arch-flow { display: block; width: 100%; max-width: 1100px; height: auto; margin: 0 auto; border: 1px solid #d2d7e0; border-radius: 8px; background: #ffffff; } .top-instruction { text-align: center; margin: 2px 0 6px 0; } .top-instruction p { margin: 0 !important; font-size: 1.02rem !important; font-style: italic !important; font-weight: 800 !important; color: #1d4ed8 !important; } .run-hint { margin-top: 6px; text-align: center; } .run-hint p { margin: 0 !important; font-size: 0.9rem !important; font-style: italic !important; color: #475569 !important; } .mascot-runtime-hint { margin-top: 6px; text-align: right; } .mascot-runtime-hint p { margin: 0 !important; font-size: 0.9rem !important; font-style: italic !important; color: #475569 !important; } .system-description { margin: 2px 0 8px 0 !important; text-align: center; } .system-description p { margin: 0 !important; font-size: 1.08rem !important; font-weight: 500 !important; font-style: italic !important; color: var(--body-text-color, #111827) !important; opacity: 1.0 !important; } .prompt-card { background: transparent !important; border: none !important; box-shadow: none !important; padding: 0 !important; } /* Mascot image: keep container transparent so PNG alpha shows app background. */ #mascot, #mascot .image-container, #mascot .image-frame, #mascot .wrap { background: transparent !important; border: none !important; box-shadow: none !important; } #mascot .image-container { background-image: none !important; } /* Mascot is decorative: hide image action controls (fullscreen/download/share). */ #mascot button[aria-label*="full" i], #mascot button[aria-label*="download" i], #mascot button[aria-label*="share" i], #mascot a[aria-label*="full" i], #mascot a[aria-label*="download" i], #mascot a[aria-label*="share" i], #mascot .tools, #mascot .tool-icons { display: none !important; } .suggested-prompt-box { margin-top: 2px !important; } .pipeline-status { margin: 0 0 6px 0 !important; min-height: 68px; } .pipeline-status-strip { display: flex; flex-wrap: wrap; align-items: center; gap: 6px; overflow: visible; white-space: normal; min-height: 60px; max-height: 76px; padding: 6px 4px 4px 4px; } .pipeline-stage { display: inline-flex; align-items: center; gap: 5px; box-sizing: border-box; flex: 1 1 calc(33.333% - 6px); min-width: 220px; padding: 4px 8px; border: 1px solid #94a3b8; border-radius: 999px; background: #ffffff; color: #0f172a; font-size: 0.82rem; line-height: 1.1; } @media (max-width: 820px) { .pipeline-stage { flex-basis: calc(50% - 6px); } } @media (max-width: 560px) { .pipeline-stage { flex-basis: 100%; } } .pipeline-name { font-weight: 700; color: #020617 !important; } .pipeline-impl { color: #1f2937 !important; font-weight: 600; } .pipeline-impl::before, .pipeline-outcome::before { content: "/"; color: #475569; margin-right: 5px; font-weight: 700; } .pipeline-outcome { color: #334155 !important; font-weight: 600; } .pipeline-icon { min-width: 0.9em; text-align: center; font-weight: 800; } .pipeline-running { border-color: #1d4ed8; background: #dbeafe; color: #0f172a; } .pipeline-done { border-color: #15803d; background: #dcfce7; color: #0f172a; } .pipeline-failed { border-color: #b91c1c; background: #fee2e2; color: #0f172a; } .pipeline-skipped { border-color: #64748b; background: #f1f5f9; color: #0f172a; } .pipeline-running .pipeline-icon { color: #1d4ed8; } .pipeline-done .pipeline-icon { color: #15803d; } .pipeline-failed .pipeline-icon { color: #b91c1c; } .pipeline-skipped .pipeline-icon { color: #475569; } .suggested-prompt-card { margin-top: 10px !important; } .psq-hidden { display: none !important; } """ client_js = """ () => { const PROMPT_DRAFT_KEY = `psq_prompt_draft_v3:${window.location.hostname}:${window.location.pathname}`; const getPromptInput = () => document.querySelector(".enter-prompt-box textarea, .enter-prompt-box input"); const readPromptDraft = () => { try { return sessionStorage.getItem(PROMPT_DRAFT_KEY) || ""; } catch (_) { return ""; } }; const writePromptDraft = (value) => { const text = (value || "").trim(); try { if (text) { sessionStorage.setItem(PROMPT_DRAFT_KEY, value || ""); } else { sessionStorage.removeItem(PROMPT_DRAFT_KEY); } } catch (_) { // Ignore storage failures. } }; const setNativeInputValue = (el, value) => { const proto = Object.getPrototypeOf(el); const desc = proto ? Object.getOwnPropertyDescriptor(proto, "value") : null; if (desc && typeof desc.set === "function") { desc.set.call(el, value); } else { el.value = value; } }; const restorePromptDraft = () => { const el = getPromptInput(); if (!el) return; const current = el.value || ""; if (current.trim().length > 0) { writePromptDraft(current); return; } const saved = readPromptDraft(); if (!saved) return; setNativeInputValue(el, saved); el.dispatchEvent(new Event("input", { bubbles: true, composed: true })); el.dispatchEvent(new Event("change", { bubbles: true, composed: true })); }; const bindPromptDraftHandlers = () => { if (document.body && document.body.dataset.psqDraftHandlersBound === "1") return; if (!document.body) return; document.body.dataset.psqDraftHandlersBound = "1"; const onFieldEvent = (evt) => { const t = evt && evt.target; if (!(t instanceof HTMLTextAreaElement || t instanceof HTMLInputElement)) return; if (!t.closest(".enter-prompt-box")) return; writePromptDraft(t.value || ""); }; // Capture phase so we persist even during early remount races. document.addEventListener("input", onFieldEvent, true); document.addEventListener("change", onFieldEvent, true); }; const schedulePromptRestore = () => { const delaysMs = [0, 120, 350, 900]; delaysMs.forEach((ms) => { window.setTimeout(() => restorePromptDraft(), ms); }); }; const readTooltipMapRaw = () => { const el = document.querySelector("#psq-tooltip-map textarea, #psq-tooltip-map input"); if (!el) return ""; return (el.value || "").trim(); }; const readTooltipMap = () => { const raw = readTooltipMapRaw(); if (!raw) return { rows: [], meta_rows: [], tips: {} }; try { const obj = JSON.parse(raw); if (!obj || typeof obj !== "object") return { rows: [], meta_rows: [], tips: {} }; const rows = Array.isArray(obj.rows) ? obj.rows : []; const meta_rows = Array.isArray(obj.meta_rows) ? obj.meta_rows : []; const tips = (obj.tips && typeof obj.tips === "object") ? obj.tips : {}; return { rows, meta_rows, tips }; } catch (_) { return { rows: [], meta_rows: [], tips: {} }; } }; const applyTooltips = () => { const payload = readTooltipMap(); const rowTags = Array.isArray(payload.rows) ? payload.rows : []; const rowMeta = Array.isArray(payload.meta_rows) ? payload.meta_rows : []; const tipMap = (payload.tips && typeof payload.tips === "object") ? payload.tips : {}; const validOrigins = new Set(["rewrite", "selection", "probe", "structural", "classifier", "implied", "user"]); const rowEls = document.querySelectorAll(".lego-tags"); rowEls.forEach((rowEl, rowIdx) => { const tags = Array.isArray(rowTags[rowIdx]) ? rowTags[rowIdx] : []; const metas = Array.isArray(rowMeta[rowIdx]) ? rowMeta[rowIdx] : []; const labels = rowEl.querySelectorAll("label"); labels.forEach((label, tagIdx) => { const span = label.querySelector("span"); const tag = (tagIdx < tags.length) ? tags[tagIdx] : ""; const tip = tag && Object.prototype.hasOwnProperty.call(tipMap, tag) ? (tipMap[tag] || "") : ""; const meta = (tagIdx < metas.length && metas[tagIdx] && typeof metas[tagIdx] === "object") ? metas[tagIdx] : {}; const originRaw = String(meta.origin || "selection").trim().toLowerCase(); const origin = validOrigins.has(originRaw) ? originRaw : "selection"; const preselected = !!meta.preselected; label.setAttribute("data-psq-origin", origin); label.setAttribute("data-psq-preselected", preselected ? "1" : "0"); if (span) { span.setAttribute("data-psq-origin", origin); span.setAttribute("data-psq-preselected", preselected ? "1" : "0"); } if (tip) { label.title = tip; if (span) span.title = tip; } else { label.removeAttribute("title"); if (span) span.removeAttribute("title"); } }); }); }; bindPromptDraftHandlers(); schedulePromptRestore(); applyTooltips(); window.addEventListener("pageshow", () => { schedulePromptRestore(); applyTooltips(); }); let lastTooltipPayload = readTooltipMapRaw(); let lastTooltipLabelCount = document.querySelectorAll(".lego-tags label").length; window.setInterval(() => { const current = readTooltipMapRaw(); const labelCount = document.querySelectorAll(".lego-tags label").length; if (current === lastTooltipPayload && labelCount === lastTooltipLabelCount) return; lastTooltipPayload = current; lastTooltipLabelCount = labelCount; applyTooltips(); }, 250); } """ client_startup_profile_js = """ () => { try { const nav = performance.getEntriesByType("navigation")[0] || null; const paints = performance.getEntriesByType("paint") || []; let fpMs = null; let fcpMs = null; for (const p of paints) { if (p && p.name === "first-paint" && fpMs === null) fpMs = p.startTime; if (p && p.name === "first-contentful-paint" && fcpMs === null) fcpMs = p.startTime; } const round = (v) => (typeof v === "number" && Number.isFinite(v) ? Math.round(v * 1000) / 1000 : null); const payload = { href_path: window.location.pathname || "", nav_type: nav ? String(nav.type || "") : "", ready_state: document.readyState || "", now_ms: round(performance.now()), response_end_ms: nav ? round(nav.responseEnd) : null, dom_interactive_ms: nav ? round(nav.domInteractive) : null, dom_content_loaded_end_ms: nav ? round(nav.domContentLoadedEventEnd) : null, load_event_end_ms: nav ? round(nav.loadEventEnd) : null, fp_ms: round(fpMs), fcp_ms: round(fcpMs), transfer_size: nav && typeof nav.transferSize === "number" ? nav.transferSize : null, encoded_body_size: nav && typeof nav.encodedBodySize === "number" ? nav.encodedBodySize : null, decoded_body_size: nav && typeof nav.decodedBodySize === "number" ? nav.decodedBodySize : null }; return [JSON.stringify(payload)]; } catch (_) { return [""]; } } """ client_enable_input_profile_js = """ () => { try { const nav = performance.getEntriesByType("navigation")[0] || null; const round = (v) => (typeof v === "number" && Number.isFinite(v) ? Math.round(v * 1000) / 1000 : null); const payload = { now_ms: round(performance.now()), ready_state: document.readyState || "", nav_type: nav ? String(nav.type || "") : "", response_end_ms: nav ? round(nav.responseEnd) : null, dom_content_loaded_end_ms: nav ? round(nav.domContentLoadedEventEnd) : null, load_event_end_ms: nav ? round(nav.loadEventEnd) : null }; return [JSON.stringify(payload)]; } catch (_) { return [""]; } } """ def _log_client_startup_profile(payload_raw: str) -> None: payload = (payload_raw or "").strip() if not payload: return rec: Dict[str, Any] try: obj = json.loads(payload) if not isinstance(obj, dict): return keep = { "href_path", "nav_type", "ready_state", "now_ms", "response_end_ms", "dom_interactive_ms", "dom_content_loaded_end_ms", "load_event_end_ms", "fp_ms", "fcp_ms", "transfer_size", "encoded_body_size", "decoded_body_size", } rec = {k: obj.get(k) for k in keep} except Exception: rec = {"parse_error": True} rec["event"] = "client_startup_profile" rec["server_t_s"] = round(time.perf_counter() - _STARTUP_PROFILE_T0, 6) print(f"STARTUP_CLIENT_PROFILE {json.dumps(rec, ensure_ascii=False)}") def _enable_prompt_input(client_payload_raw: str = ""): global _STARTUP_PROFILE_CLIENT_LOAD_MARKED if _STARTUP_PROFILE_ON and not _STARTUP_PROFILE_CLIENT_LOAD_MARKED: # First successful app.load callback from browser connection. _startup_profile_mark("ui.client_load_event") _STARTUP_PROFILE_CLIENT_LOAD_MARKED = True payload = (client_payload_raw or "").strip() if payload: rec: Dict[str, Any] try: obj = json.loads(payload) if isinstance(obj, dict): rec = { "now_ms": obj.get("now_ms"), "ready_state": obj.get("ready_state"), "nav_type": obj.get("nav_type"), "response_end_ms": obj.get("response_end_ms"), "dom_content_loaded_end_ms": obj.get("dom_content_loaded_end_ms"), "load_event_end_ms": obj.get("load_event_end_ms"), } else: rec = {"parse_error": True} except Exception: rec = {"parse_error": True} rec["event"] = "enable_prompt_input" rec["server_t_s"] = round(time.perf_counter() - _STARTUP_PROFILE_T0, 6) print(f"STARTUP_CLIENT_ENABLE {json.dumps(rec, ensure_ascii=False)}") return gr.update( interactive=True, placeholder=DEFAULT_PROMPT_EXAMPLE, ) def rag_pipeline_ui( user_prompt: str, display_top_groups: float, display_top_tags_per_group: float, display_rank_top_k: float, rewrite_override: str = "", console_seed: str = "", ): logs = [] if console_seed: logs.extend(str(console_seed).splitlines()) upstream_throttle_hits = 0 credit_or_quota_hits = 0 def _classify_runtime_line(s: str) -> None: nonlocal upstream_throttle_hits, credit_or_quota_hits t = str(s or "") t_l = t.lower() if ( "upstream throttle detected" in t_l or "rate-limited upstream" in t_l or "too many requests" in t_l or "error code: 429" in t_l ): upstream_throttle_hits += 1 if ( "credit/quota error detected" in t_l or "insufficient credits" in t_l or "payment required" in t_l or "quota exceeded" in t_l or "error code: 402" in t_l ): credit_or_quota_hits += 1 def log(s): text = str(s) logs.append(text) _classify_runtime_line(text) try: stage_timings = {} status_states: Dict[str, Dict[str, Any]] = {} preview_rewrite_s = _parse_console_stage_seconds(console_seed, "Rewrite") if preview_rewrite_s is not None: status_states["rewrite"] = {"status": "done", "time_s": preview_rewrite_s} def _progress_payload() -> List[Any]: return _build_ui_payload( console_text="\n".join(logs), row_defs=[], selected_tags=[], pipeline_status_html=_pipeline_status_html(status_states), suggested_prompt_text="", suggested_prompt_label=SUGGESTED_PROMPT_LABEL_READY, ) def _record_timing(stage: str, dt_s: float): stage_timings[stage] = float(dt_s) def _emit_timing_summary(total_s: float): summary_order = [ "preprocess", "rewrite", "structural", "classifier", "retrieval", "selection", "implication_expansion", "prompt_composition", "group_display", ] lines = [] for k in summary_order: if k in stage_timings: lines.append(f"{k}={stage_timings[k]:.2f}s") slowest = max(stage_timings.items(), key=lambda kv: kv[1])[0] if stage_timings else "n/a" log("Timing Summary: " + ", ".join(lines)) log(f"Timing Slowest Stage: {slowest}") log(f"Timing Total: {total_s:.2f}s") def _append_timing_jsonl(total_s: float): try: timing_log_path.parent.mkdir(parents=True, exist_ok=True) rec = { "timestamp_utc": datetime.utcnow().isoformat(timespec="seconds") + "Z", "stages_s": stage_timings, "total_s": float(total_s), "provider_diagnostics": { "upstream_throttle_hits": int(upstream_throttle_hits), "credit_or_quota_hits": int(credit_or_quota_hits), }, "config": { "timeout_rewrite_s": stage1_rewrite_timeout_s, "timeout_struct_s": stage1_struct_timeout_s, "timeout_classifier_s": stage1_classifier_timeout_s, "timeout_select_s": stage3_select_timeout_s, }, } with timing_log_path.open("a", encoding="utf-8") as f: f.write(json.dumps(rec, ensure_ascii=True) + "\n") log(f"Timing Log: wrote {timing_log_path}") except Exception as e: log(f"Timing Log: failed ({type(e).__name__}: {_redact_console_error_text(e)})") def _future_with_timeout( fut, timeout_s: float, stage_name: str, fallback, *, strict: bool = False, ): t0 = time.perf_counter() try: out = fut.result(timeout=max(1.0, float(timeout_s))) dt = time.perf_counter() - t0 dt_for_stage = dt if ( stage_name == "Classifier inference" and isinstance(out, dict) and isinstance(out.get("elapsed_s"), (int, float)) and float(out.get("elapsed_s") or 0.0) > 0 ): dt_for_stage = float(out.get("elapsed_s") or 0.0) log(f"{stage_name}: {dt_for_stage:.2f}s (parallel wait {dt:.2f}s)") else: log(f"{stage_name}: {dt:.2f}s") stage_key = { "Rewrite": "rewrite", "Structural inference": "structural", "Classifier inference": "classifier", "Index selection": "selection", }.get(stage_name) if stage_key: _record_timing(stage_key, dt_for_stage) return out except FutureTimeoutError: fut.cancel() msg = f"{stage_name}: timed out after {timeout_s:.0f}s" if strict: raise RuntimeError(msg) log(f"{msg}; using fallback") return fallback except Exception as e: msg = f"{stage_name}: failed ({type(e).__name__}: {_redact_console_error_text(e)})" if strict: raise RuntimeError(msg) log(f"{msg}; using fallback") return fallback t_total0 = time.perf_counter() log("Start: received prompt") if STARTUP_PREFLIGHT_ERRORS: log("Startup preflight failed:") for e in STARTUP_PREFLIGHT_ERRORS: log(f"- {_redact_console_error_text(e)}") status_states["rewrite"] = {"status": "failed", "detail": "startup check"} yield _build_ui_payload( console_text="\n".join(logs), row_defs=[], selected_tags=[], pipeline_status_html=_pipeline_status_html(status_states), suggested_prompt_text="Error: startup preflight failed. Check console details.", ) return prompt_in = (user_prompt or "").strip() if not prompt_in: yield _build_ui_payload( console_text="Error: empty prompt", row_defs=[], selected_tags=[], pipeline_status_html=_pipeline_status_html(status_states), suggested_prompt_text='Enter a prompt and click "Run".', ) return log("Input:") log(prompt_in) log("") default_or_model = ( os.environ.get("PSQ_DEFAULT_OPENROUTER_MODEL", "") or DEFAULT_STAGE_OPENROUTER_MODEL ).strip() or DEFAULT_STAGE_OPENROUTER_MODEL default_fallback_model = ( os.environ.get("PSQ_DEFAULT_OPENROUTER_FALLBACK_MODEL", "") or DEFAULT_STAGE_OPENROUTER_FALLBACK_MODEL ).strip() or DEFAULT_STAGE_OPENROUTER_FALLBACK_MODEL struct_model_cfg = ( os.environ.get("PSQ_STRUCT_OPENROUTER_MODEL", "") or "" ).strip() or default_or_model select_model_cfg = ( os.environ.get("PSQ_SELECT_OPENROUTER_MODEL", "") or "" ).strip() or default_or_model struct_fallback_model_cfg = ( os.environ.get("PSQ_STRUCT_OPENROUTER_FALLBACK_MODEL", "") or "" ).strip() or default_fallback_model select_fallback_model_cfg = ( os.environ.get("PSQ_SELECT_OPENROUTER_FALLBACK_MODEL", "") or "" ).strip() or default_fallback_model if struct_fallback_model_cfg == struct_model_cfg: struct_fallback_model_cfg = "" if select_fallback_model_cfg == select_model_cfg: select_fallback_model_cfg = "" rewrite_source_cfg = _get_rewrite_source() t5_model_dir_cfg = ( os.environ.get( "PSQ_T5_REWRITE_MODEL_DIR", "models/finetune/t5-rewrite-n30best-20260508/checkpoint-18000", ) or "models/finetune/t5-rewrite-n30best-20260508/checkpoint-18000" ).strip() t5_num_beams_cfg = _env_int("PSQ_T5_REWRITE_NUM_BEAMS", 1, minimum=1) t5_max_new_tokens_cfg = _env_int("PSQ_T5_REWRITE_MAX_NEW_TOKENS", 128, minimum=8) log( "Runtime config: " f"retrieval_global_k={retrieval_global_k} " f"retrieval_per_phrase_k={retrieval_per_phrase_k} " f"retrieval_per_phrase_final_k={retrieval_per_phrase_final_k} " f"retrieval_exact_ngram_max={retrieval_exact_ngram_max} " f"selection_mode={selection_mode} " f"selection_chunk_size={selection_chunk_size} " f"selection_per_phrase_k={selection_per_phrase_k} " f"rewrite_source={rewrite_source_cfg} " f"min_tag_count={_get_min_tag_count()} " f"select_timeout_s={stage3_select_timeout_s:.0f} " f"select_retry_timeout_s={stage3_select_retry_timeout_s:.0f} " f"select_fast_retries=disabled(configured={stage3_fast_retry_count}) " f"struct_model={struct_model_cfg} " f"classifier_enabled={classifier_enabled} " f"classifier_auto_precision={classifier_auto_precision_target:.2f} " f"classifier_candidate_top_k={classifier_candidate_top_k} " f"select_model={select_model_cfg}" ) if rewrite_source_cfg == "t5": log( "Rewrite config: " f"t5_model_dir={t5_model_dir_cfg} " f"t5_num_beams={t5_num_beams_cfg} " f"t5_max_new_tokens={t5_max_new_tokens_cfg}" ) log("") t0 = time.perf_counter() min_tag_count = _get_min_tag_count() user_tags_raw = extract_user_provided_tags_upto_3_words(prompt_in) user_tags, removed_user_low = _filter_min_count_tags(user_tags_raw, min_tag_count) user_tags, removed_user_excluded = _filter_excluded_recommendation_tags(user_tags) exact_query_phrases = extract_exact_tag_query_phrases( prompt_in, get_tag_counts(), get_alias2tags(), min_tag_count=min_tag_count, max_ngram=max(0, retrieval_exact_ngram_max), ) exact_query_phrases, removed_exact_excluded = _filter_excluded_recommendation_tags(exact_query_phrases) dt = time.perf_counter()-t0 _record_timing("preprocess", dt) log(f"Preprocess (user tag extraction): {dt:.2f}s") log("Heuristically extracted user tags:") if user_tags: log(", ".join(user_tags)) else: log("(none)") if removed_user_low: log( f"Filtered {len(removed_user_low)} low-frequency user tags " f"(<{min_tag_count}): {', '.join(removed_user_low)}" ) if removed_user_excluded: log( f"Filtered {len(removed_user_excluded)} excluded user tags: " f"{', '.join(removed_user_excluded)}" ) if retrieval_exact_ngram_max > 0: log(f"Exact caption tag query phrases (1-{retrieval_exact_ngram_max} grams):") else: log("Exact caption tag query phrases: disabled") if exact_query_phrases: shown = ", ".join(exact_query_phrases[:40]) log(shown + (" ..." if len(exact_query_phrases) > 40 else "")) else: log("(none)") if removed_exact_excluded: log( f"Filtered {len(removed_exact_excluded)} excluded exact query phrases: " f"{', '.join(removed_exact_excluded)}" ) log("") rewrite_prefilled = (rewrite_override or "").strip() if rewrite_prefilled: log("Step 1: structural + classifier inference (rewrite already prepared)") status_states["rewrite"] = status_states.get("rewrite") or {"status": "done", "detail": "prepared"} else: log("Step 1: rewrite + structural + classifier inference (concurrent)") status_states["rewrite"] = {"status": "running"} status_states["exact"] = { "status": "done", "time_s": stage_timings.get("preprocess"), "detail": f"{len(exact_query_phrases)} matches" if exact_query_phrases else "no matches", } status_states["structural"] = {"status": "running"} status_states["classifier"] = {"status": "running"} status_states["retrieval"] = {"status": "waiting"} status_states["reranker"] = {"status": "waiting"} status_states["rows"] = {"status": "waiting"} yield _progress_payload() ex = ThreadPoolExecutor(max_workers=3) try: fut_rewrite = ( ex.submit(_rewrite_prompt, prompt_in, log) if not rewrite_prefilled else None ) fut_struct = ex.submit( llm_infer_structural_tags, prompt_in, log=log, model_override=struct_model_cfg, fallback_model_override=(struct_fallback_model_cfg or None), ) fut_classifier = ex.submit( _run_tag_classifier, prompt_in, min_tag_count=min_tag_count, allow_nsfw=allow_nsfw_tags, log=log, ) if rewrite_prefilled: rewritten = rewrite_prefilled _record_timing("rewrite", 0.0) else: rewritten = _future_with_timeout( fut_rewrite, stage1_rewrite_timeout_s, "Rewrite", "", strict=True, ) structural_tags = _future_with_timeout( fut_struct, stage1_struct_timeout_s, "Structural inference", [] ) classifier_result = _future_with_timeout( fut_classifier, stage1_classifier_timeout_s, "Classifier inference", { "auto_tags": [], "candidate_tags": [], "score_by_tag": {}, "threshold_by_tag": {}, "enabled": False, }, ) probe_tags = [] finally: ex.shutdown(wait=False, cancel_futures=True) classifier_auto_tags = list(classifier_result.get("auto_tags") or []) classifier_candidate_tags = list(classifier_result.get("candidate_tags") or []) classifier_score_by_tag = dict(classifier_result.get("score_by_tag") or {}) structural_tags, removed_struct_low = _filter_min_count_tags(structural_tags, min_tag_count) structural_tags, removed_struct_excluded = _filter_excluded_recommendation_tags(structural_tags) if removed_struct_low: log( f"Filtered {len(removed_struct_low)} low-frequency structural tags " f"(<{min_tag_count}): {', '.join(removed_struct_low)}" ) if removed_struct_excluded: log( f"Filtered {len(removed_struct_excluded)} excluded structural tags: " f"{', '.join(removed_struct_excluded)}" ) if not rewritten: raise RuntimeError("Rewrite: empty output") log("Rewrite:") log(rewritten if rewritten else "(empty)") log("") if "rewrite" not in status_states or status_states.get("rewrite", {}).get("status") == "running": status_states["rewrite"] = { "status": "done", "time_s": stage_timings.get("rewrite"), } status_states["structural"] = { "status": "done", "time_s": stage_timings.get("structural"), "detail": f"{len(structural_tags)} tags" if structural_tags else "no tags", } if classifier_result.get("enabled"): status_states["classifier"] = { "status": "done", "time_s": stage_timings.get("classifier"), "detail": f"{len(classifier_auto_tags)} auto, {len(classifier_candidate_tags)} candidates", } else: status_states["classifier"] = {"status": "skipped", "detail": "not loaded"} status_states["retrieval"] = {"status": "running"} yield _progress_payload() rewrite_for_retrieval = rewritten retrieval_query_hints = list(dict.fromkeys((user_tags or []) + (exact_query_phrases or []))) if retrieval_query_hints: # keep them separate in logs, but allow them to help retrieval rewrite_for_retrieval = (rewrite_for_retrieval + ", " + ", ".join(retrieval_query_hints)).strip(", ").strip() log("Step 2: Prompt Squirrel retrieval (hidden)") try: t0 = time.perf_counter() retrieval_context_tags = list( dict.fromkeys( (structural_tags or []) + (probe_tags or []) + (classifier_auto_tags or []) ) ) rewrite_phrases = [p.strip() for p in (rewrite_for_retrieval or "").split(",") if p.strip()] retrieval_result = psq_candidates_from_rewrite_phrases( rewrite_phrases=rewrite_phrases, allow_nsfw_tags=allow_nsfw_tags, context_tags=retrieval_context_tags, global_k=max(1, retrieval_global_k), per_phrase_k=max(1, retrieval_per_phrase_k), per_phrase_final_k=max(1, retrieval_per_phrase_final_k), min_tag_count=max(0, min_tag_count), verbose=verbose_retrieval, ) if isinstance(retrieval_result, tuple): candidates, phrase_reports = retrieval_result else: candidates, phrase_reports = retrieval_result, [] candidates, removed_candidate_excluded = _filter_excluded_candidates(candidates) if removed_candidate_excluded: log( f"Filtered {len(removed_candidate_excluded)} excluded retrieved tags: " f"{', '.join(removed_candidate_excluded[:25])}" + (" ..." if len(removed_candidate_excluded) > 25 else "") ) if selection_candidate_cap > 0 and len(candidates) > selection_candidate_cap: candidates = candidates[:selection_candidate_cap] log(f"Selection candidate cap applied: {selection_candidate_cap}") if classifier_candidate_tags: existing_candidate_tags = { _norm_tag_for_lookup(str(getattr(c, "tag", "") or "")) for c in (candidates or []) } tag_counts = get_tag_counts() added_classifier_candidates = [] for tag in classifier_candidate_tags: t = _norm_tag_for_lookup(tag) if not t or t in existing_candidate_tags: continue score = float(classifier_score_by_tag.get(t, 0.0) or 0.0) candidates.append( Candidate( tag=t, score_combined=score, score_fasttext=None, score_context=None, count=tag_counts.get(t), sources=["classifier_candidate"], ) ) existing_candidate_tags.add(t) added_classifier_candidates.append(t) if added_classifier_candidates: log( f"Classifier injected {len(added_classifier_candidates)} reranker candidates: " + ", ".join(added_classifier_candidates[:25]) ) dt = time.perf_counter()-t0 _record_timing("retrieval", dt) log(f"Retrieval: {dt:.2f}s") log(f"Retrieved {len(candidates)} candidate tags") if verbose_retrieval: log(f"Total unique candidates: {len(candidates)}") limit = None if verbose_retrieval_all else max(1, int(verbose_retrieval_limit)) for report in phrase_reports: phrase = report.get("normalized") or report.get("phrase") or "" lookup = report.get("lookup") or "" tfidf_vocab = report.get("tfidf_vocab") log(f"Phrase: {phrase} (lookup={lookup}) tfidf_vocab={tfidf_vocab}") rows = report.get("candidates", []) shown = rows if limit is None else rows[:limit] for row in shown: tag = row.get("tag") alias_token = row.get("alias_token") score_fasttext = row.get("score_fasttext") score_context = row.get("score_context") score_combined = row.get("score_combined") count = row.get("count") alias_part = "" if alias_token and alias_token != tag: alias_part = f" [alias_token={alias_token}]" fasttext_str = ( f"{score_fasttext:.3f}" if isinstance(score_fasttext, (int, float)) else score_fasttext ) if score_context is None: context_str = "None" else: context_str = ( f"{score_context:.3f}" if isinstance(score_context, (int, float)) else score_context ) combined_str = ( f"{score_combined:.3f}" if isinstance(score_combined, (int, float)) else score_combined ) log( f" {tag}{alias_part} | fasttext={fasttext_str} context={context_str} " f"combined={combined_str} count={count}" ) if limit is not None and len(rows) > limit: log(f" ... ({len(rows) - limit} more)") except Exception as e: log(f"Retrieval fallback: {type(e).__name__}: {_redact_console_error_text(e)}") candidates = [] retrieved_candidate_tags = list( dict.fromkeys( _norm_tag_for_lookup(c.tag) for c in (candidates or []) if getattr(c, "tag", None) ) ) status_states["retrieval"] = { "status": "done", "time_s": stage_timings.get("retrieval"), "detail": f"{len(candidates or [])} candidates", } status_states["reranker"] = {"status": "running"} yield _progress_payload() log("Step 3: LLM index selection (uses rewrite + structural context)") selection_query = _build_selection_query( prompt_in=prompt_in, rewritten=rewritten, structural_tags=structural_tags, probe_tags=probe_tags, classifier_auto_tags=classifier_auto_tags, ) select_model_override = select_model_cfg.strip() or None select_fallback_model_override = select_fallback_model_cfg.strip() or None picked_indices = None last_stage3_error: Exception | None = None # Same-model retries are disabled; failover happens inside Stage3 calls. stage3_attempts = 1 for attempt_i in range(stage3_attempts): timeout_s = stage3_select_timeout_s if attempt_i == 0 else stage3_select_retry_timeout_s if attempt_i > 0: log( f"Index selection: fast retry {attempt_i}/{stage3_fast_retry_count} " f"(timeout={timeout_s:.0f}s)" ) ex = ThreadPoolExecutor(max_workers=1) try: fut_sel = ex.submit( llm_select_indices, query_text=selection_query, candidates=candidates, max_pick=0, log=log, mode=selection_mode, chunk_size=max(1, selection_chunk_size), per_phrase_k=max(1, selection_per_phrase_k), model_override=select_model_override, fallback_model_override=select_fallback_model_override, ) picked_indices = _future_with_timeout( fut_sel, timeout_s, "Index selection", [], strict=True, ) last_stage3_error = None break except Exception as e: last_stage3_error = e log(f"Index selection attempt {attempt_i + 1} failed: {_redact_console_error_text(e)}") finally: ex.shutdown(wait=False, cancel_futures=True) if picked_indices is None: raise RuntimeError( f"Index selection failed after {stage3_attempts} attempt(s): {last_stage3_error}" ) selection_selected_tags = [candidates[i].tag for i in picked_indices] if picked_indices else [] selection_selected_tags, removed_stage3_low = _filter_min_count_tags(selection_selected_tags, min_tag_count) if removed_stage3_low: log( f" Filtered {len(removed_stage3_low)} low-frequency stage3 tags " f"(<{min_tag_count}): {', '.join(removed_stage3_low)}" ) selected_tags = list(selection_selected_tags) if structural_tags: # Add structural tags that aren't already selected existing = {t for t in selected_tags} new_structural = [t for t in structural_tags if t not in existing] selected_tags.extend(new_structural) log(f" Added {len(new_structural)} structural tags: {', '.join(new_structural)}") else: log(" No structural tags inferred") if probe_tags: existing = {t for t in selected_tags} new_probe = [t for t in probe_tags if t not in existing] selected_tags.extend(new_probe) log(f" Added {len(new_probe)} probe tags: {', '.join(new_probe)}") if classifier_auto_tags: existing = {t for t in selected_tags} new_classifier = [t for t in classifier_auto_tags if t not in existing] selected_tags.extend(new_classifier) if new_classifier: log( f" Added {len(new_classifier)} high-confidence classifier tags: " + ", ".join( f"{t}:{classifier_score_by_tag.get(t, 0.0):.3f}" for t in new_classifier[:25] ) ) else: log(" High-confidence classifier tags were already selected") else: log(" No high-confidence classifier tags") selected_tags, removed_excluded_direct = _filter_excluded_recommendation_tags(selected_tags) if removed_excluded_direct: log(f" Removed {len(removed_excluded_direct)} excluded tags: {', '.join(removed_excluded_direct)}") direct_selected_tags = list(dict.fromkeys(selected_tags)) status_states["reranker"] = { "status": "done", "time_s": stage_timings.get("selection"), "detail": f"{len(selection_selected_tags)} selected", } status_states["rows"] = {"status": "running"} yield _progress_payload() log("Step 3c: Expand via tag implications") t0 = time.perf_counter() tag_set = set(selected_tags) expanded, implied_only = expand_tags_via_implications(tag_set) dt = time.perf_counter()-t0 _record_timing("implication_expansion", dt) log(f"Implication expansion: {dt:.2f}s") implied_selected_tags = sorted(implied_only) if implied_only else [] if implied_only: implied_added = sorted(implied_only) implied_added, removed_implied_low = _filter_min_count_tags(implied_added, min_tag_count) implied_selected_tags = list(implied_added) if implied_added: selected_tags.extend(implied_added) log(f" Added {len(implied_added)} implied tags: {', '.join(implied_added)}") if removed_implied_low: log( f" Filtered {len(removed_implied_low)} low-frequency implied tags " f"(<{min_tag_count}): {', '.join(removed_implied_low)}" ) else: log(" No additional implied tags") selected_tags, removed_excluded_implied = _filter_excluded_recommendation_tags(selected_tags) implied_selected_tags = [ t for t in implied_selected_tags if not _is_excluded_recommendation_tag(t) ] if removed_excluded_implied: log( f" Removed {len(removed_excluded_implied)} excluded tags after implications: " f"{', '.join(removed_excluded_implied)}" ) log("Step 4: Compose final prompt") t0 = time.perf_counter() final_prompt = compose_final_prompt(rewritten, selected_tags) dt = time.perf_counter()-t0 _record_timing("prompt_composition", dt) log(f"Prompt composition: {dt:.2f}s") log("Step 5: Build ranked group/category display") t0 = time.perf_counter() seed_terms = [] seed_terms.extend(user_tags) seed_terms.extend([p.strip() for p in (rewritten or "").split(",") if p.strip()]) seed_terms.extend(structural_tags or []) seed_terms.extend(probe_tags or []) seed_terms.extend(classifier_auto_tags or []) seed_terms.extend(classifier_candidate_tags or []) seed_terms.extend(selected_tags) seed_terms = list(dict.fromkeys(seed_terms)) active_selected_tags = list(dict.fromkeys(selected_tags)) structural_set = {_norm_tag_for_lookup(t) for t in (structural_tags or []) if t} probe_set = {_norm_tag_for_lookup(t) for t in (probe_tags or []) if t} classifier_auto_set = {_norm_tag_for_lookup(t) for t in (classifier_auto_tags or []) if t} implied_set = {_norm_tag_for_lookup(t) for t in (implied_selected_tags or []) if t} rewrite_set = { _norm_tag_for_lookup(t) for t in (list(user_tags or []) + [p.strip() for p in (rewritten or "").split(",") if p.strip()]) if t } selection_set = {_norm_tag_for_lookup(t) for t in (selection_selected_tags or []) if t} tag_selection_origins: Dict[str, str] = {} for tag in active_selected_tags: tag_norm = _norm_tag_for_lookup(tag) if tag_norm in structural_set: origin = "structural" elif tag_norm in classifier_auto_set: origin = "classifier" elif tag_norm in probe_set: origin = "probe" elif tag_norm in rewrite_set: origin = "rewrite" elif tag_norm in selection_set: origin = "selection" elif tag_norm in implied_set: origin = "implied" else: # Unknown/fallback tags use selection color. origin = "selection" tag_selection_origins[tag] = origin if tag_norm and tag_norm != tag: tag_selection_origins[tag_norm] = origin direct_tags_for_implied = list( dict.fromkeys(_norm_tag_for_lookup(t) for t in (direct_selected_tags or []) if t) ) direct_tags_for_implied_idx = {t: i for i, t in enumerate(direct_tags_for_implied)} direct_tags_for_implied.sort( key=lambda t: ( _selection_source_rank(tag_selection_origins.get(t, "selection")), direct_tags_for_implied_idx.get(t, 10**9), ) ) implied_parent_map = _build_implied_parent_map( direct_tags_ordered=direct_tags_for_implied, implied_tags=implied_selected_tags, ) toggle_rows = _build_toggle_rows( seed_terms=seed_terms, selected_tags=active_selected_tags, retrieved_candidate_tags=retrieved_candidate_tags, tag_selection_origins=tag_selection_origins, implied_parent_map=implied_parent_map, top_groups=max(1, int(display_top_groups)), top_tags_per_group=max(1, int(display_top_tags_per_group)), group_rank_top_k=max(1, int(display_rank_top_k)), ) dt = time.perf_counter()-t0 _record_timing("group_display", dt) log(f"Ranked group display: {dt:.2f}s ({len(toggle_rows)} rows)") log( _build_display_audit_line( toggle_rows, active_selected_tags=active_selected_tags, direct_selected_tags=direct_selected_tags, implied_selected_tags=implied_selected_tags, ) ) for idx, row in enumerate(toggle_rows[: max(0, int(display_max_rows_default))]): tags_preview = ", ".join(row.get("tags", [])) log(f"UI Row {idx}: {row.get('label', '')} :: {tags_preview}") total_dt = time.perf_counter()-t_total0 _emit_timing_summary(total_dt) if upstream_throttle_hits or credit_or_quota_hits: log( "Provider Diagnostics: " f"upstream_throttle_hits={upstream_throttle_hits} " f"credit_or_quota_hits={credit_or_quota_hits}" ) if not active_selected_tags and upstream_throttle_hits > 0: log( "Provider Diagnostics: output likely degraded by upstream throttling " "(selection context may be incomplete)." ) _append_timing_jsonl(total_dt) log("Done: final prompt ready") status_states["rows"] = { "status": "done", "time_s": stage_timings.get("group_display"), "detail": f"{len(toggle_rows)} rows", } yield _build_ui_payload( console_text="\n".join(logs), row_defs=toggle_rows, selected_tags=active_selected_tags, pipeline_status_html=_pipeline_status_html(status_states), ) return except Exception as e: log(f"Error: {type(e).__name__}: {_redact_console_error_text(e)}") for key, state in list(status_states.items()): if state.get("status") == "running": status_states[key] = {"status": "failed", "detail": "see console"} yield _build_ui_payload( console_text="\n".join(logs), row_defs=[], selected_tags=[], pipeline_status_html=_pipeline_status_html(status_states), suggested_prompt_text=_format_user_facing_error(e), ) return _startup_profile_mark("ui.blocks_build_begin") with gr.Blocks(css=css, js=client_js) as app: with gr.Row(): with gr.Column(scale=3, elem_classes=["prompt-col"]): gr.Markdown( "Retrieval-augmented system for mapping unstructured input to a controlled " "vocabulary (demo: image tags).
" "See 'How Prompt Squirrel Works'.", elem_classes=["run-hint", "system-description"], ) with gr.Group(elem_classes=["prompt-card"]): image_tags = gr.Textbox( label="Describe your image here, then click 'Run'.", placeholder=DEFAULT_PROMPT_EXAMPLE, lines=1, interactive=False, elem_classes=["enter-prompt-box"], ) with gr.Group(elem_classes=["prompt-card", "suggested-prompt-card"]): pipeline_status = gr.HTML( value=_pipeline_status_html(), elem_classes=["pipeline-status"], ) suggested_prompt = gr.Textbox( label=SUGGESTED_PROMPT_LABEL_READY, lines=2, interactive=False, show_copy_button=True, placeholder='Suggested prompt will appear here after you click "Run".', elem_classes=["suggested-prompt-box"], ) with gr.Column(scale=1): _mascot_pil = _load_mascot_image() if _mascot_pil is not None: mascot_img = gr.Image( value=_mascot_pil, show_label=False, interactive=False, height=240, elem_id="mascot" ) else: mascot_img = gr.Markdown("`(mascot image unavailable)`") submit_button = gr.Button("Run", variant="primary", visible=False, interactive=False) gr.Markdown("Typical runtime: up to ~20 seconds.", elem_classes=["mascot-runtime-hint"]) last_run_prompt_state = gr.State("") selected_tags_state = gr.State([]) rows_dirty_state = gr.State(False) row_defs_state = gr.State([]) row_values_state = gr.State([]) rewrite_state = gr.State("") toggle_instruction = gr.Markdown( "Click tag buttons to add or remove tags from the suggested prompt.", elem_classes=["row-instruction"], visible=False, ) row_headers: List[gr.Markdown] = [] row_checkboxes: List[gr.CheckboxGroup] = [] for _ in range(display_max_rows_default): with gr.Row(): with gr.Column(scale=2, min_width=170): row_headers.append(gr.Markdown(value="", visible=False, elem_classes=["row-heading"])) with gr.Column(scale=10): row_checkboxes.append( gr.CheckboxGroup( choices=[], value=[], visible=False, interactive=True, container=False, elem_classes=["lego-tags"], ) ) with gr.Row(): with gr.Column(scale=10): gr.HTML( """
Legend: Query Reformulation Candidate Ranking Scene Composition Tag Classifier Implied User-toggled Unselected
""" ) with gr.Column(scale=2, min_width=180): rebuild_rows_button = gr.Button( "Rebuild Rows", variant="primary", visible=False, interactive=False, ) with gr.Accordion("Display Settings", open=False): with gr.Row(): display_top_groups = gr.Number( value=display_top_groups_default, precision=0, label="Rows (Top Groups/Categories)", minimum=1, ) display_top_tags_per_group = gr.Number( value=display_top_tags_per_group_default, precision=0, label="Top Tags Shown Per Row", minimum=1, ) display_rank_top_k = gr.Number( value=display_rank_top_k_default, precision=0, label="Top Tags Used for Row Ranking", minimum=1, ) with gr.Accordion("Console", open=False): console = gr.Textbox( label="Console", lines=10, interactive=False, placeholder="Progress logs will appear here.", elem_id="psq-console", ) with gr.Accordion("How Prompt Squirrel Works", open=False): _about_md = _load_about_docs_markdown() _about_before, _about_after, _has_arch_slot = _split_about_docs_for_diagram(_about_md) if _has_arch_slot: if _about_before: gr.Markdown( _about_before, elem_id="about-docs", elem_classes=["about-docs"], ) gr.HTML( _build_arch_diagram_html(), elem_classes=["about-docs"], ) if _about_after: gr.Markdown( _about_after, elem_classes=["about-docs"], ) else: gr.Markdown( _about_md, elem_id="about-docs", elem_classes=["about-docs"], ) tooltip_map_payload = gr.Textbox( value="{}", visible=True, interactive=False, container=False, elem_id="psq-tooltip-map", elem_classes=["psq-hidden"], ) client_startup_profile_payload = gr.Textbox( value="", visible=True, interactive=False, container=False, elem_id="psq-client-startup-profile", elem_classes=["psq-hidden"], ) client_enable_input_payload = gr.Textbox( value="", visible=True, interactive=False, container=False, elem_id="psq-client-enable-input-profile", elem_classes=["psq-hidden"], ) run_outputs = [ console, toggle_instruction, tooltip_map_payload, pipeline_status, suggested_prompt, selected_tags_state, rows_dirty_state, rebuild_rows_button, row_defs_state, row_values_state, *row_headers, *row_checkboxes, ] run_outputs_with_rewrite = [*run_outputs, rewrite_state] image_tags.change( _update_run_button_visibility, inputs=[image_tags, last_run_prompt_state], outputs=[submit_button], queue=False, show_progress="hidden", ) app.load( _enable_prompt_input, inputs=[client_enable_input_payload], outputs=[image_tags], js=client_enable_input_profile_js, queue=False, show_progress="hidden", ) app.load( _log_client_startup_profile, inputs=[client_startup_profile_payload], outputs=[], js=client_startup_profile_js, queue=False, show_progress="hidden", ) submit_button.click( _mark_run_triggered, inputs=[image_tags], outputs=[submit_button, last_run_prompt_state], queue=False, show_progress="hidden", ).then( _prepare_run_ui_with_rewrite_state, inputs=[], outputs=run_outputs_with_rewrite, queue=False, show_progress="hidden", ).then( _rewrite_preview_ui, inputs=[image_tags], outputs=run_outputs_with_rewrite, show_progress="hidden", ).then( rag_pipeline_ui, inputs=[image_tags, display_top_groups, display_top_tags_per_group, display_rank_top_k, rewrite_state, console], outputs=run_outputs, show_progress="minimal", show_progress_on=[mascot_img], ) image_tags.submit( _mark_run_triggered, inputs=[image_tags], outputs=[submit_button, last_run_prompt_state], queue=False, show_progress="hidden", ).then( _prepare_run_ui_with_rewrite_state, inputs=[], outputs=run_outputs_with_rewrite, queue=False, show_progress="hidden", ).then( _rewrite_preview_ui, inputs=[image_tags], outputs=run_outputs_with_rewrite, show_progress="hidden", ).then( rag_pipeline_ui, inputs=[image_tags, display_top_groups, display_top_tags_per_group, display_rank_top_k, rewrite_state, console], outputs=run_outputs, show_progress="minimal", show_progress_on=[mascot_img], ) for idx, row_cb in enumerate(row_checkboxes): row_cb.change( fn=lambda changed_values, selected_state, rows_dirty, row_defs, row_values, i=idx: _on_toggle_row( i, changed_values, selected_state, rows_dirty, row_defs, row_values, display_max_rows_default, ), inputs=[row_cb, selected_tags_state, rows_dirty_state, row_defs_state, row_values_state], outputs=[selected_tags_state, rows_dirty_state, rebuild_rows_button, row_values_state, suggested_prompt, *row_checkboxes], queue=False, show_progress="hidden", ) rebuild_rows_button.click( _rebuild_rows_from_selected, inputs=[selected_tags_state, row_defs_state, row_values_state, display_top_groups, display_top_tags_per_group, display_rank_top_k], outputs=[ toggle_instruction, tooltip_map_payload, suggested_prompt, selected_tags_state, rows_dirty_state, rebuild_rows_button, row_defs_state, row_values_state, *row_headers, *row_checkboxes, ], queue=False, show_progress="hidden", ) _startup_profile_mark("ui.blocks_build_end") if __name__ == "__main__": _startup_profile_mark("launch.begin", ssr_mode=False) app.queue().launch( allowed_paths=[str(MASCOT_DIR), str(DOCS_DIR)], ssr_mode=False, pwa=False, )