Spaces:
Running
Running
| import gradio as gr | |
| import os | |
| import logging | |
| import time | |
| import json | |
| import csv | |
| import base64 | |
| from datetime import datetime | |
| from functools import lru_cache | |
| from PIL import Image | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Set, Tuple | |
| from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError | |
| from psq_rag.pipeline.preproc import extract_user_provided_tags_upto_3_words | |
| from psq_rag.llm.rewrite import llm_rewrite_prompt | |
| from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases, _norm_tag_for_lookup | |
| from psq_rag.llm.select import llm_select_indices, llm_infer_structural_tags, llm_infer_probe_tags | |
| from psq_rag.retrieval.state import ( | |
| expand_tags_via_implications, | |
| get_tag_type_name, | |
| get_tag_implications, | |
| get_tag_counts, | |
| ) | |
| from psq_rag.ui.group_ranked_display import rank_groups_from_tfidf, _load_enabled_groups | |
| def _split_prompt_commas(s: str) -> List[str]: | |
| return [p.strip() for p in (s or "").split(",") if p.strip()] | |
def _norm_for_dedupe(tag: str) -> str:
    """Return the key used to decide whether two tags are duplicates.

    The tag is lower-cased first and then passed through
    ``_norm_tag_for_lookup`` (the canonical lookup form).
    """
    lowered = tag.lower()
    return _norm_tag_for_lookup(lowered)
def compose_final_prompt(rewritten_prompt: str, selected_tags: List[str]) -> str:
    """Join the rewritten prompt phrases and the selected tags into one prompt.

    Phrases from ``rewritten_prompt`` come first, followed by
    ``selected_tags``. Entries whose dedupe key (``_norm_for_dedupe``) was
    already seen are dropped, so the first spelling of each entry wins.
    The result is a ``", "``-joined string.
    """
    candidates = [*_split_prompt_commas(rewritten_prompt), *selected_tags]
    unique: List[str] = []
    seen_keys = set()
    for piece in candidates:
        dedupe_key = _norm_for_dedupe(piece)
        if dedupe_key not in seen_keys:
            seen_keys.add(dedupe_key)
            unique.append(piece)
    return ", ".join(unique)
| def _display_tag_text(tag: str) -> str: | |
| return tag.replace("_", " ") | |
| def _display_row_label(name: str) -> str: | |
| n = (name or "").strip() | |
| if not n: | |
| return "" | |
| if n == "selected_other": | |
| return "Selected (Other)" | |
| return n.replace("_", " ").title() | |
| def _normalize_selection_origin(origin: str) -> str: | |
| o = (origin or "").strip().lower() | |
| if o in {"rewrite", "selection", "probe", "structural", "user", "candidate"}: | |
| return o | |
| return "selection" | |
| def _load_tag_wiki_defs() -> Dict[str, str]: | |
| p = Path("data/tag_wiki_defs.json") | |
| if not p.exists(): | |
| return {} | |
| try: | |
| with p.open("r", encoding="utf-8") as f: | |
| data = json.load(f) | |
| out: Dict[str, str] = {} | |
| if isinstance(data, dict): | |
| for k, v in data.items(): | |
| tag = _norm_tag_for_lookup(str(k)) | |
| text = " ".join(str(v or "").split()) | |
| if tag and text: | |
| out[tag] = text | |
| return out | |
| except Exception: | |
| return {} | |
def _tooltip_text_for_tag(tag: str) -> str:
    """Build the tooltip text for a tag.

    The tooltip has up to two lines: a ``Count: N`` line when
    ``get_tag_counts()`` holds an integer for the normalized tag, and the
    wiki definition when one is available. Returns ``""`` if neither exists.
    """
    key = _norm_tag_for_lookup(tag)
    try:
        usage = get_tag_counts().get(key)
    except Exception:
        # Count data is optional; a lookup failure just omits the count line.
        usage = None
    lines: List[str] = [f"Count: {usage:,}"] if isinstance(usage, int) else []
    definition = _load_tag_wiki_defs().get(key, "")
    if definition:
        lines.append(definition)
    return "\n".join(lines).strip()
def _choice_label_with_source_meta(tag: str, *, origin: str, preselected: bool) -> str:
    """Return the checkbox label for ``tag`` with a trailing metadata marker.

    The label is the display text followed by
    ``[[psq:<origin>:<0|1>:<tooltip>]]`` where the tooltip is URL-safe
    base64 of its UTF-8 bytes (empty when there is no tooltip).
    """
    # Marker is stripped client-side and converted into data attributes for
    # CSS-driven colors/tooltips.
    origin_norm = _normalize_selection_origin(origin)
    pre_flag = "1" if preselected else "0"
    tooltip = _tooltip_text_for_tag(tag)
    encoded_tip = (
        base64.urlsafe_b64encode(tooltip.encode("utf-8")).decode("ascii")
        if tooltip
        else ""
    )
    return f"{_display_tag_text(tag)} [[psq:{origin_norm}:{pre_flag}:{encoded_tip}]]"
def _selection_source_rank(origin: str) -> int:
    """Return the sort band for a selection origin (lower sorts first).

    ``structural`` is 0, ``probe`` is 1, and every other origin is 2.
    """
    # Keep rewrite/user in the same priority band as general selection for
    # row ordering.
    priority_by_origin = {"structural": 0, "probe": 1}
    return priority_by_origin.get(_normalize_selection_origin(origin), 2)
def _build_implied_parent_map(
    direct_tags_ordered: List[str],
    implied_tags: List[str],
) -> Dict[str, str]:
    """Attribute each implied tag to the first direct tag that reaches it.

    For every direct tag, in the given order, the implication graph from
    ``get_tag_implications()`` is walked transitively. Each tag from
    ``implied_tags`` that is reached and not yet attributed is mapped to
    that direct tag. Returns ``{implied_tag: direct_tag}`` with normalized
    tags; empty when either input is empty.
    """
    wanted = {_norm_tag_for_lookup(t) for t in (implied_tags or []) if t}
    if not (wanted and direct_tags_ordered):
        return {}

    implications = get_tag_implications()
    owner_of: Dict[str, str] = {}
    for direct in direct_tags_ordered:
        root = _norm_tag_for_lookup(direct)
        if not root:
            continue
        visited = {root}
        pending = [root]
        while pending:
            current = pending.pop()
            for raw_parent in implications.get(current, ()):
                parent = _norm_tag_for_lookup(raw_parent)
                if not parent or parent in visited:
                    continue
                visited.add(parent)
                # First direct tag to reach an implied tag keeps ownership.
                if parent in wanted and parent not in owner_of:
                    owner_of[parent] = root
                pending.append(parent)
    return owner_of
def _order_selected_tags_for_row(
    *,
    row_selected_tags: List[str],
    selected_index: Dict[str, int],
    tag_selection_origins: Dict[str, str],
    implied_parent_map: Dict[str, str],
) -> List[str]:
    """Order a row's selected tags so implied tags follow their parent.

    Tags that are not keys of ``implied_parent_map`` are sorted by origin
    band, then selection index, then name. Each is followed by the implied
    tags attributed to it (sorted by selection index, then name). Implied
    tags whose parent is not in this row are appended last, sorted by the
    parent's band and index, then their own index and name. Each tag is
    emitted once; all tags are normalized.
    """
    unranked = 10**9  # sort position for tags that have no selection index

    def _position(tag: str) -> int:
        return selected_index.get(tag, unranked)

    def _origin_band(tag: str) -> int:
        return _selection_source_rank(tag_selection_origins.get(tag, "selection"))

    normalized = [_norm_tag_for_lookup(t) for t in (row_selected_tags or []) if t]
    implied_here = {t for t in normalized if t in implied_parent_map}

    direct_tags = sorted(
        (t for t in normalized if t not in implied_here),
        key=lambda t: (_origin_band(t), _position(t), t),
    )

    children_by_parent: Dict[str, List[str]] = {}
    for implied in implied_here:
        owner = implied_parent_map.get(implied)
        if owner:
            children_by_parent.setdefault(owner, []).append(implied)
    for siblings in children_by_parent.values():
        siblings.sort(key=lambda t: (_position(t), t))

    ordered: List[str] = []
    emitted: Set[str] = set()

    def _emit(tag: str) -> None:
        if tag not in emitted:
            emitted.add(tag)
            ordered.append(tag)

    for tag in direct_tags:
        if tag in emitted:
            continue
        _emit(tag)
        for child in children_by_parent.get(tag, []):
            _emit(child)

    # Implied tags whose parent is not shown in this row.
    leftovers = sorted(
        (t for t in normalized if t not in emitted),
        key=lambda t: (
            _origin_band(implied_parent_map.get(t, "")),
            _position(implied_parent_map.get(t, "")),
            _position(t),
            t,
        ),
    )
    for tag in leftovers:
        _emit(tag)
    return ordered
| def _escape_prompt_tag(tag: str) -> str: | |
| return ( | |
| tag.replace("_", " ") | |
| .replace("(", "\\(") | |
| .replace(")", "\\)") | |
| ) | |
| def _ordered_selected_for_prompt(selected: Set[str], row_defs: List[Dict[str, Any]]) -> List[str]: | |
| out: List[str] = [] | |
| seen: Set[str] = set() | |
| for row in row_defs: | |
| for tag in row.get("tags", []): | |
| if tag in selected and tag not in seen: | |
| out.append(tag) | |
| seen.add(tag) | |
| # Fallback for any selected tags not present in current rows. | |
| for tag in sorted(selected): | |
| if tag not in seen: | |
| out.append(tag) | |
| seen.add(tag) | |
| return out | |
def _compose_toggle_prompt_text(selected_tags: List[str], row_defs: List[Dict[str, Any]]) -> str:
    """Render the selected tags as the ``", "``-joined prompt text.

    Falsy entries are dropped, the rest are ordered by
    ``_ordered_selected_for_prompt`` and escaped by ``_escape_prompt_tag``.
    """
    chosen = set(filter(None, selected_tags or []))
    in_order = _ordered_selected_for_prompt(chosen, row_defs or [])
    return ", ".join(map(_escape_prompt_tag, in_order))
def _is_artist_tag(tag: str) -> bool:
    """Report whether ``tag`` is an artist tag.

    True when the tag's type name is ``"artist"`` or the normalized tag
    starts with ``by_``. A tag that normalizes to empty is never an artist.
    """
    normalized = _norm_tag_for_lookup(str(tag))
    if normalized:
        # Keep a resilient fallback for malformed/missing tag typing metadata.
        return get_tag_type_name(normalized) == "artist" or normalized.startswith("by_")
    return False
| def _load_excluded_recommendation_tags() -> Set[str]: | |
| csv_path = Path("data/analysis/category_registry.csv") | |
| out: Set[str] = set() | |
| if not csv_path.exists(): | |
| return out | |
| try: | |
| with csv_path.open("r", encoding="utf-8", newline="") as f: | |
| reader = csv.DictReader(f) | |
| for row in reader: | |
| tag = _norm_tag_for_lookup(str(row.get("tag") or "")) | |
| if not tag: | |
| continue | |
| status = str(row.get("category_status") or "").strip().lower() | |
| if status == "excluded": | |
| out.add(tag) | |
| except Exception: | |
| return set() | |
| return out | |
def _is_excluded_recommendation_tag(tag: str) -> bool:
    """Report whether the normalized ``tag`` is in the excluded-tag set.

    A tag that normalizes to empty is never excluded, and the exclusion
    list is not loaded for it.
    """
    normalized = _norm_tag_for_lookup(str(tag))
    return bool(normalized) and normalized in _load_excluded_recommendation_tags()
def _filter_excluded_recommendation_tags(tags: List[str]) -> Tuple[List[str], List[str]]:
    """Split ``tags`` into kept tags and removed (excluded) tags.

    Returns ``(keep, removed)``. ``keep`` holds the normalized, de-duplicated
    tags in first-seen order; ``removed`` holds the sorted, de-duplicated
    normalized tags found in the exclusion list.

    Every input goes through the same steps whether or not the exclusion
    list is empty: falsy raw values are skipped, the value is normalized via
    ``str()``, and tags that normalize to empty are skipped. Previously the
    empty-list fast path and the main path disagreed on these cases (for
    example, ``None`` became the tag ``"none"`` only when exclusions existed).
    """
    excluded = _load_excluded_recommendation_tags()
    keep: List[str] = []
    removed: Set[str] = set()
    seen: Set[str] = set()
    for raw in tags or []:
        if not raw:
            continue
        tag = _norm_tag_for_lookup(str(raw))
        if not tag:
            continue
        if tag in excluded:
            removed.add(tag)
            continue
        if tag not in seen:
            seen.add(tag)
            keep.append(tag)
    return keep, sorted(removed)
def _build_toggle_rows(
    *,
    seed_terms: List[str],
    selected_tags: List[str],
    retrieved_candidate_tags: List[str],
    tag_selection_origins: Dict[str, str],
    implied_parent_map: Dict[str, str],
    top_groups: int,
    top_tags_per_group: int,
    group_rank_top_k: int,
) -> List[Dict[str, Any]]:
    """Build the row definitions for the tag-toggle UI.

    Each row is a dict with ``name``, ``label``, ``tags`` and ``tag_meta``
    (per tag: normalized ``origin`` and a ``preselected`` flag). Rows are
    returned in this order:

    1. ``selected_other``: selected tags not shown in any other row
       (always emitted, possibly with no tags).
    2. One row per group returned by ``rank_groups_from_tfidf``: selected
       tags of that group first, then the group's ranked tags.
    3. ``other_retrieved``: retrieved candidates that belong to no enabled
       group (only when there are any).

    Artist tags and excluded recommendation tags never appear. Rows 2 and 3
    are cut to ``max(1, top_tags_per_group)`` tags, but never below the
    number of selected tags in the row.
    """
    per_group_limit = max(1, int(top_tags_per_group))
    ranked_rows = rank_groups_from_tfidf(
        seed_terms=seed_terms,
        top_groups=max(1, int(top_groups)),
        top_tags_per_group=per_group_limit,
        group_rank_top_k=max(1, int(group_rank_top_k)),
    )
    groups_map = _load_enabled_groups()

    # Load the exclusion list once; checking it per tag through
    # _is_excluded_recommendation_tag would reload it for every tag.
    excluded_tags = _load_excluded_recommendation_tags()

    def _recommendable(tag: Any) -> bool:
        """True when the tag is neither an artist tag nor excluded."""
        return not _is_artist_tag(tag) and _norm_tag_for_lookup(str(tag)) not in excluded_tags

    selected_active = list(
        dict.fromkeys(
            _norm_tag_for_lookup(t)
            for t in (selected_tags or [])
            if t and _recommendable(t)
        )
    )
    selected_index: Dict[str, int] = {t: i for i, t in enumerate(selected_active)}

    def _meta_for(tags: List[str], *, force_preselected: bool = False) -> Dict[str, Dict[str, Any]]:
        """Build the per-tag metadata dict stored under ``tag_meta``."""
        return {
            t: {
                "origin": _normalize_selection_origin(tag_selection_origins.get(t, "selection")),
                "preselected": True if force_preselected else t in selected_index,
            }
            for t in tags
        }

    def _ordered(row_selected: List[str]) -> List[str]:
        return _order_selected_tags_for_row(
            row_selected_tags=row_selected,
            selected_index=selected_index,
            tag_selection_origins=tag_selection_origins,
            implied_parent_map=implied_parent_map,
        )

    def _selected_first(selected_in_row: List[str], ranked: List[str]) -> List[str]:
        """Put selected tags first, then ranked tags, and apply the row cap."""
        already = set(selected_in_row)
        merged = selected_in_row + [t for t in ranked if t not in already]
        # Never cut a selected tag out of its row.
        keep_n = max(per_group_limit, len(selected_in_row))
        return merged[:keep_n]

    enabled_group_tag_sets: Dict[str, Set[str]] = {
        name: {t for t in tags if not _is_artist_tag(t)}
        for name, tags in groups_map.items()
    }
    tags_in_any_enabled_group: Set[str] = set()
    for tag_set in enabled_group_tag_sets.values():
        tags_in_any_enabled_group.update(tag_set)

    displayed_group_names = [r.group_name for r in ranked_rows]
    displayed_group_tag_sets: Dict[str, Set[str]] = {
        name: enabled_group_tag_sets.get(name, set())
        for name in displayed_group_names
    }
    tags_in_any_displayed_group: Set[str] = set()
    for tag_set in displayed_group_tag_sets.values():
        tags_in_any_displayed_group.update(tag_set)

    # Retrieved candidates that belong to no enabled group, in retrieval order.
    uncategorized: Dict[str, None] = {}
    for raw in retrieved_candidate_tags or []:
        if not raw or not _recommendable(raw):
            continue
        norm = _norm_tag_for_lookup(raw)
        if norm not in tags_in_any_enabled_group:
            uncategorized.setdefault(norm, None)
    retrieved_uncategorized_ranked = list(uncategorized)

    retrieved_other_row: Optional[Dict[str, Any]] = None
    if retrieved_uncategorized_ranked:
        selected_in_retrieved_other = _ordered(
            [t for t in selected_active if t in uncategorized]
        )
        merged_retrieved_other = _selected_first(
            selected_in_retrieved_other, retrieved_uncategorized_ranked
        )
        retrieved_other_row = {
            "name": "other_retrieved",
            "label": "Other (Retrieved)",
            "tags": merged_retrieved_other,
            "tag_meta": _meta_for(merged_retrieved_other),
        }

    # "Selected (Other)" should contain selected tags not already shown in any displayed row.
    # Include "Other (Retrieved)" in that displayed-row set to avoid duplicates across those rows.
    tags_in_displayed_rows = set(tags_in_any_displayed_group)
    if retrieved_other_row:
        tags_in_displayed_rows.update(retrieved_other_row.get("tags", []))
    selected_other = _ordered(
        [t for t in selected_active if t not in tags_in_displayed_rows]
    )

    row_defs: List[Dict[str, Any]] = [
        {
            "name": "selected_other",
            "label": _display_row_label("selected_other"),
            "tags": selected_other,
            "tag_meta": _meta_for(selected_other, force_preselected=True),
        }
    ]

    for row in ranked_rows:
        group_name = row.group_name
        group_tag_set = displayed_group_tag_sets.get(group_name, set())
        selected_in_group = _ordered(
            [t for t in selected_active if t in group_tag_set]
        )
        ranked_tags = [t for t, _ in row.tags if _recommendable(t)]
        merged = _selected_first(selected_in_group, ranked_tags)
        row_defs.append(
            {
                "name": group_name,
                "label": _display_row_label(group_name),
                "tags": merged,
                "tag_meta": _meta_for(merged),
            }
        )

    # Keep this row at the bottom so category/group rows remain contiguous.
    if retrieved_other_row:
        row_defs.append(retrieved_other_row)
    return row_defs
| def _build_display_audit_line( | |
| row_defs: List[Dict[str, Any]], | |
| *, | |
| active_selected_tags: List[str], | |
| direct_selected_tags: List[str], | |
| implied_selected_tags: List[str], | |
| ) -> str: | |
| active_set = { | |
| _norm_tag_for_lookup(t) | |
| for t in (active_selected_tags or []) | |
| if t and not _is_artist_tag(t) | |
| } | |
| direct_set = { | |
| _norm_tag_for_lookup(t) | |
| for t in (direct_selected_tags or []) | |
| if t and not _is_artist_tag(t) | |
| } | |
| implied_set = { | |
| _norm_tag_for_lookup(t) | |
| for t in (implied_selected_tags or []) | |
| if t and not _is_artist_tag(t) | |
| } | |
| info_by_tag: Dict[str, Dict[str, Any]] = {} | |
| for row in row_defs or []: | |
| row_name = row.get("name", "") | |
| row_label = row.get("label", row_name) | |
| for tag in row.get("tags", []): | |
| rec = info_by_tag.setdefault(tag, {"rows": [], "sources": set()}) | |
| rec["rows"].append(row_label) | |
| if row_name == "selected_other": | |
| rec["sources"].add("selected_other_row") | |
| elif row_name == "other_retrieved": | |
| rec["sources"].add("other_retrieved_row") | |
| else: | |
| rec["sources"].add("ranked_group_row") | |
| if tag in active_set: | |
| rec["sources"].add("selected_active") | |
| if tag in direct_set: | |
| rec["sources"].add("selected_direct") | |
| if tag in implied_set: | |
| rec["sources"].add("selected_implied") | |
| payload = { | |
| "n_tags": len(info_by_tag), | |
| "tags": [ | |
| { | |
| "tag": tag, | |
| "rows": rec["rows"], | |
| "sources": sorted(rec["sources"]), | |
| } | |
| for tag, rec in sorted(info_by_tag.items()) | |
| ], | |
| } | |
| return "Display Tag Audit: " + json.dumps(payload, ensure_ascii=True) | |
def _build_row_component_updates(
    row_defs: List[Dict[str, Any]],
    selected_tags: List[str],
    max_rows: int,
):
    """Build the Gradio updates for the row headers and checkbox groups.

    Produces one header update and one checkbox update for each of the
    ``max_rows`` UI slots; slots beyond ``row_defs`` are hidden and cleared.
    A row with no tags is hidden too.

    Returns ``(prompt_text, row_values_state, header_updates,
    checkbox_updates)`` where ``row_values_state`` holds, per populated
    slot, the row's tags that are currently selected.
    """
    chosen = {t for t in (selected_tags or []) if t}
    row_values_state: List[List[str]] = []
    header_updates = []
    checkbox_updates = []

    for slot in range(max_rows):
        if slot >= len(row_defs):
            # Unused slot: blank and hide both components.
            header_updates.append(gr.update(value="", visible=False))
            checkbox_updates.append(gr.update(choices=[], value=[], visible=False))
            continue

        row = row_defs[slot]
        unique_tags = list(dict.fromkeys(row.get("tags", [])))
        checked = [t for t in unique_tags if t in chosen]
        row_values_state.append(checked)
        has_tags = bool(unique_tags)
        header_updates.append(gr.update(value=row.get("label", ""), visible=has_tags))

        raw_meta = row.get("tag_meta", {})
        meta_by_tag = raw_meta if isinstance(raw_meta, dict) else {}
        choices = []
        for tag in unique_tags:
            entry = meta_by_tag.get(tag, {})
            if not isinstance(entry, dict):
                entry = {}
            origin = _normalize_selection_origin(str(entry.get("origin", "selection")))
            preselected = bool(entry.get("preselected", False))
            label = _choice_label_with_source_meta(tag, origin=origin, preselected=preselected)
            choices.append((label, tag))
        checkbox_updates.append(
            gr.update(choices=choices, value=checked, visible=has_tags)
        )

    prompt_text = _compose_toggle_prompt_text(list(chosen), row_defs)
    return prompt_text, row_values_state, header_updates, checkbox_updates
def _on_toggle_row(
    row_idx: int,
    changed_values: List[str],
    selected_tags_state: List[str],
    row_defs_state: List[Dict[str, Any]],
    row_values_state: List[List[str]],
    max_rows: int,
):
    """Handle a change event from the checkbox group of row ``row_idx``.

    Returns ``[selected_tags, row_values_state, prompt_text,
    *checkbox_updates]`` with one checkbox entry per UI slot. When the
    row's new values equal its previously recorded values, nothing is
    changed and every checkbox entry is ``gr.skip()``. Otherwise the row's
    tags in the global selection are replaced by the new values, and every
    row containing a toggled tag (plus the triggering row) gets a value
    update.
    """
    row_defs_ui = (row_defs_state or [])[: max(0, int(max_rows))]
    selected = set(selected_tags_state or [])

    in_range = 0 <= row_idx < len(row_defs_ui)
    row = row_defs_ui[row_idx] if in_range else {}
    row_tags = list(dict.fromkeys(row.get("tags", [])))
    row_tag_set = set(row_tags)
    row_tag_by_norm = {_norm_tag_for_lookup(t): t for t in row_tags}

    def _resolve(values) -> Set[str]:
        # Be tolerant to UI payload forms: canonical tag values, display
        # labels, or normalized variants.
        resolved: Set[str] = set()
        for raw in (values or []):
            if raw in row_tag_set:
                resolved.add(raw)
                continue
            mapped = row_tag_by_norm.get(_norm_tag_for_lookup(str(raw)))
            if mapped:
                resolved.add(mapped)
        return resolved

    new_set = _resolve(changed_values)
    prev_values = list(row_values_state or [])
    prev_row_values = prev_values[row_idx] if 0 <= row_idx < len(prev_values) else []
    prev_row_selected = _resolve(prev_row_values)

    # Ignore non-user/no-op events (e.g., programmatic value re-sets) deterministically.
    if new_set == prev_row_selected:
        prompt_text = _compose_toggle_prompt_text(sorted(selected), row_defs_ui)
        untouched = [gr.skip() for _ in range(max_rows)]
        return [sorted(selected), prev_values, prompt_text, *untouched]

    # Replace this row's contribution to the global selection.
    selected.difference_update(row_tag_set)
    selected.update(new_set)
    toggled_tags = prev_row_selected ^ new_set

    new_row_values_state: List[List[str]] = []
    affected_rows: Set[int] = {row_idx}
    for idx, row_item in enumerate(row_defs_ui):
        tags = list(dict.fromkeys(row_item.get("tags", [])))
        new_row_values_state.append([t for t in tags if t in selected])
        # A toggled tag shown in another row must be re-synced there too.
        if toggled_tags and any(t in toggled_tags for t in tags):
            affected_rows.add(idx)

    visible_rows = len(row_defs_ui)
    checkbox_updates = [
        gr.update(value=new_row_values_state[idx])
        if idx < visible_rows and idx in affected_rows
        else gr.skip()
        for idx in range(max_rows)
    ]
    prompt_text = _compose_toggle_prompt_text(sorted(selected), row_defs_ui)
    return [sorted(selected), new_row_values_state, prompt_text, *checkbox_updates]
def _build_ui_payload(
    *,
    console_text: str,
    row_defs: List[Dict[str, Any]],
    selected_tags: List[str],
):
    """Assemble the full list of output values for a finished run.

    Order: console text, a visibility update (visible when ``row_defs`` is
    non-empty), prompt text, the sorted unique selected tags, ``row_defs``,
    the per-row selected values, then all header updates followed by all
    checkbox updates for ``display_max_rows_default`` slots.
    """
    updates = _build_row_component_updates(
        row_defs=row_defs,
        selected_tags=selected_tags,
        max_rows=display_max_rows_default,
    )
    prompt_text, row_values_state, header_updates, checkbox_updates = updates
    payload = [
        console_text,
        gr.update(visible=bool(row_defs)),
        prompt_text,
        sorted(set(selected_tags or [])),
        row_defs,
        row_values_state,
    ]
    payload.extend(header_updates)
    payload.extend(checkbox_updates)
    return payload
def _prepare_run_ui() -> List[Any]:
    """Return the output values shown while a run is in progress.

    Same layout as ``_build_ui_payload``: placeholder console and prompt
    text, the second output left untouched (``gr.skip()``), three empty
    lists for the state values, and every header/checkbox slot hidden and
    cleared.
    """
    slots = range(display_max_rows_default)
    hidden_headers = [gr.update(value="", visible=False) for _ in slots]
    hidden_checkboxes = [gr.update(choices=[], value=[], visible=False) for _ in slots]
    return [
        "Running...",
        gr.skip(),
        "Running... usually completes in about 20 seconds.",
        [],
        [],
        [],
        *hidden_headers,
        *hidden_checkboxes,
    ]
| def _build_selection_query( | |
| prompt_in: str, | |
| rewritten: str, | |
| structural_tags: List[str], | |
| probe_tags: List[str], | |
| ) -> str: | |
| lines = [f"IMAGE DESCRIPTION: {prompt_in.strip()}"] | |
| if rewritten and rewritten.strip(): | |
| lines.append(f"REWRITE PHRASES: {rewritten.strip()}") | |
| hint_tags = [] | |
| if structural_tags: | |
| hint_tags.extend(structural_tags) | |
| if probe_tags: | |
| hint_tags.extend(probe_tags) | |
| if hint_tags: | |
| # Keep hints as context only; selection still must choose by candidate indices. | |
| lines.append( | |
| "INFERRED TAG HINTS (context only): " + ", ".join(sorted(set(hint_tags))) | |
| ) | |
| return "\n".join(lines) | |
# Set up logging
# Minimal prod logging: warnings+ to stderr, no file by default.
# `os` and `logging` come from the imports at the top of this file.
LOG_LEVEL = os.environ.get("PSQ_LOG_LEVEL", "WARNING").upper()

# Only accept names that resolve to a numeric level constant. Any other
# value (unknown name, or a logging attribute that is not a level) falls
# back to WARNING instead of being handed to basicConfig.
_log_level_value = getattr(logging, LOG_LEVEL, None)
if not isinstance(_log_level_value, int):
    _log_level_value = logging.WARNING

logging.basicConfig(
    level=_log_level_value,
    format="%(asctime)s %(levelname)s:%(message)s",
    handlers=[logging.StreamHandler()]  # no file -> avoids huge logs on Spaces
)

# Quiet down common noisy libs (optional)
for _name in ("gensim", "gradio", "hnswlib", "httpx", "uvicorn"):
    logging.getLogger(_name).setLevel(logging.ERROR)

# Turn off Gradio analytics phone-home to avoid those background thread errors (optional)
os.environ["GRADIO_ANALYTICS_ENABLED"] = "0"
| MASCOT_DIR = Path(__file__).parent / "mascotimages" | |
| MASCOT_FILE = MASCOT_DIR / "transparentsquirrel.png" | |
| def _load_mascot_image(): | |
| """Load mascot image if available; return None when missing/unreadable.""" | |
| if not MASCOT_FILE.exists(): | |
| logging.warning("Mascot image missing: %s", MASCOT_FILE) | |
| return None | |
| try: | |
| return Image.open(MASCOT_FILE).convert("RGBA") | |
| except Exception as e: | |
| logging.warning("Failed to load mascot image (%s): %s", MASCOT_FILE, e) | |
| return None | |
# Hotfix: make gradio_client's JSON-schema helpers accept non-dict schemas.
# The three helpers below replace the originals in gradio_client.utils; if the
# import or patching fails, the app continues without the hotfix.
try:
    from gradio_client import utils as _gc_utils

    # Keep references to the originals so the wrappers can delegate to them.
    _orig_get_type = _gc_utils.get_type
    _orig_j2p = _gc_utils._json_schema_to_python_type
    _orig_pub = _gc_utils.json_schema_to_python_type

    def _get_type_safe(schema):
        """Delegate to the original ``get_type`` for dict schemas only."""
        # Sometimes schema is a bare True/False (JSON Schema boolean form)
        if isinstance(schema, dict):
            return _orig_get_type(schema)
        return "any"

    def _j2p_safe(schema, defs=None):
        """Delegate to the original private converter for dict schemas only."""
        # Accept non-dict schemas (True/False/None) and treat as "any"
        if isinstance(schema, dict):
            return _orig_j2p(schema, defs or schema.get("$defs"))
        return "any"

    def _pub_safe(schema):
        """Public entry point; routes dict schemas through ``_j2p_safe``."""
        # Public wrapper used by Gradio; keep it resilient too
        if isinstance(schema, dict):
            return _j2p_safe(schema, schema.get("$defs"))
        return "any"

    _gc_utils.get_type = _get_type_safe
    _gc_utils._json_schema_to_python_type = _j2p_safe
    _gc_utils.json_schema_to_python_type = _pub_safe
except Exception as e:
    print("gradio_client hotfix not applied:", e)
| # ------------------------------------------------------------------------------- | |
| allow_nsfw_tags = False | |
| def _is_production_runtime() -> bool: | |
| """Best-effort detection for deployed runtime (HF Spaces or explicit env).""" | |
| if os.environ.get("PSQ_PRODUCTION", "").strip().lower() in {"1", "true", "yes"}: | |
| return True | |
| if os.environ.get("SPACE_ID"): | |
| return True | |
| if os.environ.get("HF_SPACE_ID"): | |
| return True | |
| if os.environ.get("SYSTEM") == "spaces": | |
| return True | |
| return False | |
# --- Runtime configuration, read once from environment variables at import ---

# Verbose retrieval output defaults to off in production and on elsewhere.
verbose_retrieval_default = "0" if _is_production_runtime() else "1"
verbose_retrieval = os.environ.get("PSQ_VERBOSE_RETRIEVAL", verbose_retrieval_default).strip().lower() in {"1", "true", "yes"}
verbose_retrieval_all = False
verbose_retrieval_limit = 20

# Probe tags are enabled unless explicitly switched off. The value is compared
# case-insensitively (like PSQ_VERBOSE_RETRIEVAL above), so "FALSE" and "no"
# disable the probe as well as "0" and "false".
enable_probe_tags = os.environ.get("PSQ_ENABLE_PROBE", "1").strip().lower() not in {"0", "false", "no"}

# Display sizing for the tag-toggle rows.
display_top_groups_default = int(os.environ.get("PSQ_DISPLAY_TOP_GROUPS", "10"))
display_top_tags_per_group_default = int(os.environ.get("PSQ_DISPLAY_TOP_TAGS_PER_GROUP", "5"))
display_rank_top_k_default = int(os.environ.get("PSQ_DISPLAY_GROUP_RANK_TOP_K", "5"))
display_max_rows_default = int(os.environ.get("PSQ_DISPLAY_MAX_ROWS", "14"))

# Retrieval breadth settings.
retrieval_global_k = int(os.environ.get("PSQ_RETRIEVAL_GLOBAL_K", "300"))
retrieval_per_phrase_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_K", "10"))
retrieval_per_phrase_final_k = int(os.environ.get("PSQ_RETRIEVAL_PER_PHRASE_FINAL_K", "1"))

# Selection-stage settings.
selection_mode = os.environ.get("PSQ_SELECTION_MODE", "chunked_map_union").strip()
selection_chunk_size = int(os.environ.get("PSQ_SELECTION_CHUNK_SIZE", "60"))
selection_per_phrase_k = int(os.environ.get("PSQ_SELECTION_PER_PHRASE_K", "2"))
# NOTE(review): a cap of 0 presumably means "no cap" - confirm at the use site.
selection_candidate_cap = int(os.environ.get("PSQ_SELECTION_CANDIDATE_CAP", "0"))

# Per-stage timeouts, in seconds.
stage1_rewrite_timeout_s = float(os.environ.get("PSQ_TIMEOUT_REWRITE_S", "45"))
stage1_struct_timeout_s = float(os.environ.get("PSQ_TIMEOUT_STRUCT_S", "45"))
stage1_probe_timeout_s = float(os.environ.get("PSQ_TIMEOUT_PROBE_S", "45"))
stage3_select_timeout_s = float(os.environ.get("PSQ_TIMEOUT_SELECT_S", "45"))

# Destination file for pipeline timing records.
timing_log_path = Path(os.environ.get("PSQ_TIMING_LOG_PATH", "data/runtime_metrics/ui_pipeline_timings.jsonl"))
| css = """ | |
| .scrollable-content{ | |
| max-height: 420px; | |
| overflow-y: scroll; /* always show scrollbar */ | |
| overflow-x: hidden; | |
| padding-right: 8px; | |
| padding-bottom: 14px; /* <— add this */ | |
| scrollbar-gutter: stable; /* prevent layout shift as it fills */ | |
| /* Firefox */ | |
| scrollbar-width: auto; | |
| scrollbar-color: rgba(180,180,180,.9) rgba(0,0,0,.15); | |
| } | |
| /* WebKit/Chromium (Chrome/Edge/Safari) */ | |
| .scrollable-content::-webkit-scrollbar{ width: 10px; } | |
| .scrollable-content::-webkit-scrollbar-thumb{ background: rgba(180,180,180,.9); border-radius: 8px; } | |
| .scrollable-content::-webkit-scrollbar-track{ background: rgba(0,0,0,.15); } | |
| /* (Optional) make both scroll panes taller so they fill more of the column */ | |
| .pane-left .scrollable-content, | |
| .pane-right .scrollable-content { | |
| max-height: 610px; /* was 420px; tweak to taste */ | |
| } | |
| .lego-tags .gr-checkboxgroup, | |
| .lego-tags .wrap { | |
| display: flex !important; | |
| flex-wrap: wrap !important; | |
| gap: 10px !important; | |
| } | |
| .lego-tags label { | |
| margin: 0 !important; | |
| padding: 0 !important; | |
| position: relative !important; | |
| } | |
| /* Hide native checkbox visuals completely */ | |
| .lego-tags input[type="checkbox"] { | |
| appearance: none !important; | |
| -webkit-appearance: none !important; | |
| -moz-appearance: none !important; | |
| position: absolute !important; | |
| width: 1px !important; | |
| height: 1px !important; | |
| opacity: 0 !important; | |
| pointer-events: none !important; | |
| display: none !important; | |
| } | |
| /* Brick button skin (works for both +span and ~span structures) */ | |
| .lego-tags input[type="checkbox"] + span, | |
| .lego-tags input[type="checkbox"] ~ span { | |
| --on-bg1: #ffd166; | |
| --on-bg2: #f39c4a; | |
| --on-border: #b86e21; | |
| --on-text: #2e1706; | |
| position: relative !important; | |
| display: inline-flex !important; | |
| align-items: center !important; | |
| min-height: 40px !important; | |
| padding: 10px 15px 9px 22px !important; | |
| border: 1px solid #9aa6b8 !important; | |
| border-radius: 10px !important; | |
| background: linear-gradient(180deg, #dfe5ee 0%, #bec8d6 100%) !important; | |
| color: #364254 !important; | |
| font-size: 0.97rem !important; | |
| font-weight: 800 !important; | |
| line-height: 1.15 !important; | |
| cursor: pointer !important; | |
| user-select: none !important; | |
| letter-spacing: 0.01em !important; | |
| box-shadow: 0 3px 0 rgba(0,0,0,0.16), inset 0 1px 0 rgba(255,255,255,0.55) !important; | |
| transition: transform 0.08s ease, box-shadow 0.08s ease, filter 0.08s ease !important; | |
| } | |
| .lego-tags input[type="checkbox"] + span::before, | |
| .lego-tags input[type="checkbox"] ~ span::before { | |
| content: "" !important; | |
| position: absolute !important; | |
| top: 5px !important; | |
| left: 8px !important; | |
| width: 8px !important; | |
| height: 8px !important; | |
| border-radius: 50% !important; | |
| background: rgba(255,255,255,0.58) !important; | |
| box-shadow: 22px 0 0 rgba(255,255,255,0.58) !important; | |
| pointer-events: none !important; | |
| } | |
| /* Unselected cue: show "+" on the left. */ | |
| .lego-tags input[type="checkbox"] + span::after, | |
| .lego-tags input[type="checkbox"] ~ span::after { | |
| content: "+" !important; | |
| position: absolute !important; | |
| left: 6px !important; | |
| top: 50% !important; | |
| transform: translateY(-52%) !important; | |
| font-size: 1rem !important; | |
| font-weight: 900 !important; | |
| color: #4b5563 !important; | |
| opacity: 0.95 !important; | |
| pointer-events: none !important; | |
| } | |
| /* Bright color cycle used only when selected */ | |
| .lego-tags label:nth-child(8n+1) span { --on-bg1: #ffd166; --on-bg2: #f39c4a; --on-border: #b86e21; --on-text: #2e1706; } | |
| .lego-tags label:nth-child(8n+2) span { --on-bg1: #6ee7ff; --on-bg2: #1fb7ff; --on-border: #157cb3; --on-text: #07263c; } | |
| .lego-tags label:nth-child(8n+3) span { --on-bg1: #9dff8f; --on-bg2: #45c96f; --on-border: #2a8b4b; --on-text: #0d2917; } | |
| .lego-tags label:nth-child(8n+4) span { --on-bg1: #ff8fab; --on-bg2: #ff5c7a; --on-border: #b83956; --on-text: #3f0f1d; } | |
| .lego-tags label:nth-child(8n+5) span { --on-bg1: #d0a8ff; --on-bg2: #a46cff; --on-border: #7147b3; --on-text: #25143f; } | |
| .lego-tags label:nth-child(8n+6) span { --on-bg1: #ffe27a; --on-bg2: #f7bf39; --on-border: #ad7f1f; --on-text: #332407; } | |
| .lego-tags label:nth-child(8n+7) span { --on-bg1: #8effd5; --on-bg2: #2ed6b5; --on-border: #1e947d; --on-text: #0d2a25; } | |
| .lego-tags label:nth-child(8n+8) span { --on-bg1: #ffb47e; --on-bg2: #ff8753; --on-border: #b95b2d; --on-text: #391a0a; } | |
| /* Source-driven selected colors (applies when tags are preselected by the pipeline). */ | |
| .lego-tags label[data-psq-preselected="1"][data-psq-origin="rewrite"] span { | |
| --on-bg1: #77f0d7; | |
| --on-bg2: #26b9a3; | |
| --on-border: #187869; | |
| --on-text: #062923; | |
| } | |
| .lego-tags label[data-psq-preselected="1"][data-psq-origin="selection"] span { | |
| --on-bg1: #ffd98a; | |
| --on-bg2: #f0a93c; | |
| --on-border: #a66f1f; | |
| --on-text: #382206; | |
| } | |
| .lego-tags label[data-psq-preselected="1"][data-psq-origin="probe"] span { | |
| --on-bg1: #d8b4ff; | |
| --on-bg2: #9a6cff; | |
| --on-border: #6745b0; | |
| --on-text: #24143b; | |
| } | |
| .lego-tags label[data-psq-preselected="1"][data-psq-origin="structural"] span { | |
| --on-bg1: #a6f79a; | |
| --on-bg2: #53c368; | |
| --on-border: #2f8442; | |
| --on-text: #102d17; | |
| } | |
| .lego-tags label[data-psq-preselected="1"][data-psq-origin="implied"] span { | |
| --on-bg1: #d7dde8; | |
| --on-bg2: #a8b3c4; | |
| --on-border: #6f7e95; | |
| --on-text: #1d2633; | |
| } | |
| /* User-selected tags (not initially selected by the pipeline). */ | |
| .lego-tags label[data-psq-preselected="0"] span { | |
| --on-bg1: #9ec5ff; | |
| --on-bg2: #4f86ff; | |
| --on-border: #2f5fbf; | |
| --on-text: #0b1f42; | |
| } | |
| .lego-tags label:hover span { | |
| filter: brightness(1.02) !important; | |
| transform: translateY(1px) !important; | |
| } | |
| /* ON state: brighter + visibly recessed */ | |
| .lego-tags input[type="checkbox"]:checked + span, | |
| .lego-tags input[type="checkbox"]:checked ~ span, | |
| .lego-tags label:has(input[type="checkbox"]:checked) span { | |
| background: linear-gradient(180deg, var(--on-bg1) 0%, var(--on-bg2) 100%) !important; | |
| color: var(--on-text) !important; | |
| border-color: var(--on-border) !important; | |
| filter: saturate(1.2) brightness(1.12) !important; | |
| transform: translateY(-2px) !important; | |
| box-shadow: | |
| inset 0 3px 6px rgba(0,0,0,0.20), | |
| inset 0 -1px 0 rgba(255,255,255,0.36), | |
| 0 6px 0 rgba(0,0,0,0.32) !important; | |
| } | |
| .lego-tags input[type="checkbox"]:checked + span::after, | |
| .lego-tags input[type="checkbox"]:checked ~ span::after, | |
| .lego-tags label:has(input[type="checkbox"]:checked) span::after { | |
| content: "" !important; | |
| } | |
| .source-legend { | |
| display: flex; | |
| flex-wrap: wrap; | |
| align-items: center; | |
| gap: 8px; | |
| margin: 4px 0 10px 0; | |
| } | |
| .source-legend .legend-title { | |
| font-size: 0.92rem; | |
| font-weight: 900; | |
| color: #334155; | |
| margin-right: 4px; | |
| } | |
| .source-legend .chip { | |
| display: inline-flex; | |
| align-items: center; | |
| border-radius: 10px; | |
| border: 1px solid #6c7788; | |
| padding: 6px 12px; | |
| font-size: 0.85rem; | |
| font-weight: 800; | |
| color: #111827; | |
| background: #f3f6fb; | |
| } | |
| .source-legend .chip.rewrite { background: #26b9a3; color: #062923; border-color: #187869; } | |
| .source-legend .chip.selection { background: #f0a93c; color: #382206; border-color: #a66f1f; } | |
| .source-legend .chip.probe { background: #9a6cff; color: #ffffff; border-color: #6745b0; } | |
| .source-legend .chip.structural { background: #53c368; color: #102d17; border-color: #2f8442; } | |
| .source-legend .chip.implied { background: #a8b3c4; color: #1d2633; border-color: #6f7e95; } | |
| .source-legend .chip.user { background: #4f86ff; color: #ffffff; border-color: #2f5fbf; } | |
| .source-legend .chip.unselected { background: #c7ced8; color: #2d3440; border-color: #7d8897; } | |
| .row-heading p { | |
| margin: 8px 0 0 0 !important; | |
| font-size: 1.18rem !important; | |
| font-weight: 850 !important; | |
| line-height: 1.2 !important; | |
| } | |
| .row-instruction { | |
| text-align: center; | |
| margin: 8px 0 12px 0; | |
| } | |
| .row-instruction p { | |
| margin: 0 !important; | |
| font-size: 1.02rem !important; | |
| font-style: italic !important; | |
| font-weight: 800 !important; | |
| color: #1d4ed8 !important; | |
| } | |
| .top-instruction { | |
| text-align: center; | |
| margin: 2px 0 6px 0; | |
| } | |
| .top-instruction p { | |
| margin: 0 !important; | |
| font-size: 1.02rem !important; | |
| font-style: italic !important; | |
| font-weight: 800 !important; | |
| color: #1d4ed8 !important; | |
| } | |
| .run-hint { | |
| margin-top: 6px; | |
| text-align: center; | |
| } | |
| .run-hint p { | |
| margin: 0 !important; | |
| font-size: 0.9rem !important; | |
| font-style: italic !important; | |
| color: #475569 !important; | |
| } | |
| .prompt-card { | |
| background: transparent !important; | |
| border: none !important; | |
| box-shadow: none !important; | |
| padding: 0 !important; | |
| } | |
| .suggested-prompt-box { | |
| margin-top: 2px !important; | |
| } | |
| .suggested-prompt-card { | |
| margin-top: 10px !important; | |
| } | |
| """ | |
# Browser-side script passed to ``gr.Blocks(js=client_js)`` below.
#
# What the script does, in order:
#   * ``markerRe`` matches a trailing ``[[psq:<origin>:<0|1>:<base64url>]]``
#     marker at the end of a tag label's text.
#   * ``decodeTip`` converts the base64url payload to standard base64, decodes
#     it with ``atob`` and interprets the bytes as UTF-8; any failure yields "".
#   * ``applyTagMeta`` walks every ``.lego-tags label``, copies the marker's
#     origin and preselected flag onto ``data-psq-origin`` /
#     ``data-psq-preselected`` (the attributes the CSS above selects on), sets
#     or clears the ``title`` tooltip, and strips the marker from the visible
#     text.
#   * A ``MutationObserver`` on ``document.body`` re-runs ``applyTagMeta`` on
#     child-list and character-data changes. Once a marker is stripped the
#     label no longer matches ``markerRe``, so the re-run leaves it untouched.
#
# NOTE(review): the marker is presumably appended to the checkbox labels by the
# Python side when the rows are built; that code is outside this section.
client_js = """
() => {
const markerRe = /\\s*\\[\\[psq:([a-z_]+):(0|1):([A-Za-z0-9_\\-=]*)\\]\\]\\s*$/;
const decodeTip = (b64) => {
if (!b64) return "";
try {
const binary = atob((b64 || "").replace(/-/g, "+").replace(/_/g, "/"));
const bytes = Uint8Array.from(binary, (c) => c.charCodeAt(0));
return new TextDecoder("utf-8").decode(bytes);
} catch (_) {
return "";
}
};
const applyTagMeta = () => {
const labels = document.querySelectorAll(".lego-tags label");
labels.forEach((label) => {
const span = label.querySelector("span");
if (!span) return;
const text = span.textContent || "";
const match = text.match(markerRe);
if (!match) return;
label.dataset.psqOrigin = match[1];
label.dataset.psqPreselected = match[2];
const tip = decodeTip(match[3] || "");
if (tip) {
label.title = tip;
span.title = tip;
} else {
label.removeAttribute("title");
span.removeAttribute("title");
}
span.textContent = text.replace(markerRe, "");
});
};
applyTagMeta();
const observer = new MutationObserver(() => applyTagMeta());
observer.observe(document.body, { childList: true, subtree: true, characterData: true });
}
"""
def rag_pipeline_ui(
    user_prompt: str,
    display_top_groups: float,
    display_top_tags_per_group: float,
    display_rank_top_k: float,
):
    """Run the prompt-to-tags pipeline for one request and build the UI payload.

    Stages, in order: heuristic user-tag extraction, concurrent LLM rewrite /
    structural inference / optional probe inference, candidate retrieval, LLM
    index selection, implication expansion, prompt composition, and the ranked
    group display. Progress and per-stage timings are accumulated as log lines
    and returned as the console text.

    Args:
        user_prompt: Free-text prompt typed by the user; blank input returns an
            error payload immediately.
        display_top_groups: Number of group rows to show (coerced to int, min 1).
        display_top_tags_per_group: Tags shown per row (coerced to int, min 1).
        display_rank_top_k: Tags used for row ranking (coerced to int, min 1).

    Returns:
        Whatever ``_build_ui_payload`` returns for the collected console text,
        row definitions and selected tags. Any exception escaping the pipeline
        is logged and yields a payload with empty rows and no selected tags.

    Runtime configuration (timeouts, retrieval/selection sizes, feature flags,
    ``timing_log_path``) is read from names defined outside this function.
    """
    logs = []

    def log(s):
        logs.append(s)

    try:
        stage_timings = {}
        # Maps the human-readable stage label used in log lines to the key
        # used in the timing summary / JSONL record.
        stage_keys = {
            "Rewrite": "rewrite",
            "Structural inference": "structural",
            "Probe inference": "probe",
            "Index selection": "selection",
        }

        def _record_timing(stage: str, dt_s: float):
            stage_timings[stage] = float(dt_s)

        def _emit_timing_summary(total_s: float):
            summary_order = [
                "preprocess",
                "rewrite",
                "structural",
                "probe",
                "retrieval",
                "selection",
                "implication_expansion",
                "prompt_composition",
                "group_display",
            ]
            lines = [
                f"{k}={stage_timings[k]:.2f}s" for k in summary_order if k in stage_timings
            ]
            slowest = max(stage_timings.items(), key=lambda kv: kv[1])[0] if stage_timings else "n/a"
            log("Timing Summary: " + ", ".join(lines))
            log(f"Timing Slowest Stage: {slowest}")
            log(f"Timing Total: {total_s:.2f}s")

        def _append_timing_jsonl(total_s: float):
            # Best-effort: a failure to write the timing log must not fail the request.
            try:
                timing_log_path.parent.mkdir(parents=True, exist_ok=True)
                rec = {
                    # Same "YYYY-MM-DDTHH:MM:SSZ" text the deprecated
                    # datetime.utcnow().isoformat(timespec="seconds") + "Z" produced.
                    "timestamp_utc": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
                    "stages_s": stage_timings,
                    "total_s": float(total_s),
                    "config": {
                        "timeout_rewrite_s": stage1_rewrite_timeout_s,
                        "timeout_struct_s": stage1_struct_timeout_s,
                        "timeout_probe_s": stage1_probe_timeout_s,
                        "timeout_select_s": stage3_select_timeout_s,
                    },
                }
                with timing_log_path.open("a", encoding="utf-8") as f:
                    f.write(json.dumps(rec, ensure_ascii=True) + "\n")
                log(f"Timing Log: wrote {timing_log_path}")
            except Exception as e:
                log(f"Timing Log: failed ({type(e).__name__}: {e})")

        def _timed_call(fn, *args, **kwargs):
            # Runs in the worker thread so the duration is that of the stage
            # itself, not of however long the caller happened to wait for it.
            t_call0 = time.perf_counter()
            result = fn(*args, **kwargs)
            return result, time.perf_counter() - t_call0

        def _future_with_timeout(fut, timeout_s: float, stage_name: str, fallback):
            # ``fut`` must have been submitted through ``_timed_call``.
            try:
                out, dt = fut.result(timeout=max(1.0, float(timeout_s)))
            except FutureTimeoutError:
                # cancel() only helps if the task has not started yet; a running
                # task keeps going in its worker thread and is simply abandoned.
                fut.cancel()
                log(f"{stage_name}: timed out after {timeout_s:.0f}s; using fallback")
                return fallback
            except Exception as e:
                log(f"{stage_name}: failed ({type(e).__name__}: {e}); using fallback")
                return fallback
            log(f"{stage_name}: {dt:.2f}s")
            stage_key = stage_keys.get(stage_name)
            if stage_key:
                _record_timing(stage_key, dt)
            return out

        def _fmt_score(value):
            # Numbers get three decimals; anything else (including None) is
            # passed through and rendered by the surrounding f-string.
            return f"{value:.3f}" if isinstance(value, (int, float)) else value

        t_total0 = time.perf_counter()
        log("Start: received prompt")
        prompt_in = (user_prompt or "").strip()
        if not prompt_in:
            return _build_ui_payload(
                console_text="Error: empty prompt",
                row_defs=[],
                selected_tags=[],
            )

        log("Input:")
        log(prompt_in)
        log("")
        log(
            "Runtime config: "
            f"retrieval_global_k={retrieval_global_k} "
            f"retrieval_per_phrase_k={retrieval_per_phrase_k} "
            f"retrieval_per_phrase_final_k={retrieval_per_phrase_final_k} "
            f"selection_mode={selection_mode} "
            f"selection_chunk_size={selection_chunk_size} "
            f"selection_per_phrase_k={selection_per_phrase_k}"
        )
        log("")

        t0 = time.perf_counter()
        user_tags = extract_user_provided_tags_upto_3_words(prompt_in)
        dt = time.perf_counter() - t0
        _record_timing("preprocess", dt)
        log(f"Preprocess (user tag extraction): {dt:.2f}s")
        log("Heuristically extracted user tags:")
        if user_tags:
            log(", ".join(user_tags))
        else:
            log("(none)")
        log("")

        log("Step 1: LLM rewrite + structural inference + probe (concurrent)")
        max_workers = 3 if enable_probe_tags else 2
        # Not used as a context manager: leaving a ``with`` block calls
        # shutdown(wait=True), which would block on a timed-out stage and make
        # the timeouts below ineffective.
        ex = ThreadPoolExecutor(max_workers=max_workers)
        try:
            fut_rewrite = ex.submit(_timed_call, llm_rewrite_prompt, prompt_in, log)
            fut_struct = ex.submit(_timed_call, llm_infer_structural_tags, prompt_in, log=log)
            fut_probe = (
                ex.submit(_timed_call, llm_infer_probe_tags, prompt_in, log=log)
                if enable_probe_tags else None
            )
            rewritten = _future_with_timeout(
                fut_rewrite, stage1_rewrite_timeout_s, "Rewrite", prompt_in
            )
            structural_tags = _future_with_timeout(
                fut_struct, stage1_struct_timeout_s, "Structural inference", []
            )
            probe_tags = (
                _future_with_timeout(fut_probe, stage1_probe_timeout_s, "Probe inference", [])
                if fut_probe else []
            )
        finally:
            ex.shutdown(wait=False)

        log("Rewrite:")
        log(rewritten if rewritten else "(empty)")
        log("")
        rewritten_phrases = _split_prompt_commas(rewritten)

        # ``or ""`` keeps the concatenation below from raising if the rewrite is None.
        rewrite_for_retrieval = rewritten or ""
        if user_tags:
            # keep them separate in logs, but allow them to help retrieval
            rewrite_for_retrieval = (rewrite_for_retrieval + ", " + ", ".join(user_tags)).strip(", ").strip()

        log("Step 2: Prompt Squirrel retrieval (hidden)")
        try:
            t0 = time.perf_counter()
            retrieval_context_tags = list(dict.fromkeys((structural_tags or []) + (probe_tags or [])))
            rewrite_phrases = _split_prompt_commas(rewrite_for_retrieval)
            retrieval_result = psq_candidates_from_rewrite_phrases(
                rewrite_phrases=rewrite_phrases,
                allow_nsfw_tags=allow_nsfw_tags,
                context_tags=retrieval_context_tags,
                global_k=max(1, retrieval_global_k),
                per_phrase_k=max(1, retrieval_per_phrase_k),
                per_phrase_final_k=max(1, retrieval_per_phrase_final_k),
                verbose=verbose_retrieval,
            )
            # The retrieval call may return either candidates alone or a
            # (candidates, per-phrase reports) pair.
            if isinstance(retrieval_result, tuple):
                candidates, phrase_reports = retrieval_result
            else:
                candidates, phrase_reports = retrieval_result, []
            if selection_candidate_cap > 0 and len(candidates) > selection_candidate_cap:
                candidates = candidates[:selection_candidate_cap]
                log(f"Selection candidate cap applied: {selection_candidate_cap}")
            dt = time.perf_counter() - t0
            _record_timing("retrieval", dt)
            log(f"Retrieval: {dt:.2f}s")
            log(f"Retrieved {len(candidates)} candidate tags")
            if verbose_retrieval:
                log(f"Total unique candidates: {len(candidates)}")
                limit = None if verbose_retrieval_all else max(1, int(verbose_retrieval_limit))
                for report in phrase_reports:
                    phrase = report.get("normalized") or report.get("phrase") or ""
                    lookup = report.get("lookup") or ""
                    tfidf_vocab = report.get("tfidf_vocab")
                    log(f"Phrase: {phrase} (lookup={lookup}) tfidf_vocab={tfidf_vocab}")
                    rows = report.get("candidates", [])
                    shown = rows if limit is None else rows[:limit]
                    for row in shown:
                        tag = row.get("tag")
                        alias_token = row.get("alias_token")
                        count = row.get("count")
                        alias_part = ""
                        if alias_token and alias_token != tag:
                            alias_part = f" [alias_token={alias_token}]"
                        fasttext_str = _fmt_score(row.get("score_fasttext"))
                        context_str = _fmt_score(row.get("score_context"))
                        combined_str = _fmt_score(row.get("score_combined"))
                        log(
                            f" {tag}{alias_part} | fasttext={fasttext_str} context={context_str} "
                            f"combined={combined_str} count={count}"
                        )
                    if limit is not None and len(rows) > limit:
                        log(f" ... ({len(rows) - limit} more)")
        except Exception as e:
            # Retrieval is best-effort: continue with no candidates.
            log(f"Retrieval fallback: {type(e).__name__}: {e}")
            candidates = []

        retrieved_candidate_tags = list(
            dict.fromkeys(
                _norm_tag_for_lookup(c.tag)
                for c in (candidates or [])
                if getattr(c, "tag", None)
            )
        )

        log("Step 3: LLM index selection (uses rewrite + structural/probe context)")
        selection_query = _build_selection_query(
            prompt_in=prompt_in,
            rewritten=rewritten,
            structural_tags=structural_tags,
            probe_tags=probe_tags,
        )
        # Same reasoning as in step 1: never block on a timed-out worker.
        ex = ThreadPoolExecutor(max_workers=1)
        try:
            fut_sel = ex.submit(
                _timed_call,
                llm_select_indices,
                query_text=selection_query,
                candidates=candidates,
                max_pick=0,
                log=log,
                mode=selection_mode,
                chunk_size=max(1, selection_chunk_size),
                per_phrase_k=max(1, selection_per_phrase_k),
            )
            picked_indices = _future_with_timeout(
                fut_sel, stage3_select_timeout_s, "Index selection", []
            )
        finally:
            ex.shutdown(wait=False)

        selection_selected_tags = [candidates[i].tag for i in picked_indices] if picked_indices else []
        selected_tags = list(selection_selected_tags)

        if structural_tags:
            # Add structural tags that aren't already selected
            existing = set(selected_tags)
            new_structural = [t for t in structural_tags if t not in existing]
            selected_tags.extend(new_structural)
            log(f" Added {len(new_structural)} structural tags: {', '.join(new_structural)}")
        else:
            log(" No structural tags inferred")

        if probe_tags:
            existing = set(selected_tags)
            new_probe = [t for t in probe_tags if t not in existing]
            selected_tags.extend(new_probe)
            log(f" Added {len(new_probe)} probe tags: {', '.join(new_probe)}")
        elif enable_probe_tags:
            log(" No probe tags inferred")

        selected_tags, removed_excluded_direct = _filter_excluded_recommendation_tags(selected_tags)
        if removed_excluded_direct:
            log(f" Removed {len(removed_excluded_direct)} excluded tags: {', '.join(removed_excluded_direct)}")
        # Snapshot of the tags chosen directly, before implications are added.
        direct_selected_tags = list(dict.fromkeys(selected_tags))

        log("Step 3c: Expand via tag implications")
        t0 = time.perf_counter()
        tag_set = set(selected_tags)
        _expanded, implied_only = expand_tags_via_implications(tag_set)
        dt = time.perf_counter() - t0
        _record_timing("implication_expansion", dt)
        log(f"Implication expansion: {dt:.2f}s")
        implied_selected_tags = sorted(implied_only) if implied_only else []
        if implied_only:
            selected_tags.extend(sorted(implied_only))
            log(f" Added {len(implied_only)} implied tags: {', '.join(sorted(implied_only))}")
        else:
            log(" No additional implied tags")

        selected_tags, removed_excluded_implied = _filter_excluded_recommendation_tags(selected_tags)
        implied_selected_tags = [
            t for t in implied_selected_tags if not _is_excluded_recommendation_tag(t)
        ]
        if removed_excluded_implied:
            log(
                f" Removed {len(removed_excluded_implied)} excluded tags after implications: "
                f"{', '.join(removed_excluded_implied)}"
            )

        log("Step 4: Compose final prompt")
        t0 = time.perf_counter()
        # The composed string is not used further in this function; the call is
        # kept so the prompt_composition stage still appears in the timings.
        compose_final_prompt(rewritten, selected_tags)
        dt = time.perf_counter() - t0
        _record_timing("prompt_composition", dt)
        log(f"Prompt composition: {dt:.2f}s")

        log("Step 5: Build ranked group/category display")
        t0 = time.perf_counter()
        seed_terms = []
        seed_terms.extend(user_tags)
        seed_terms.extend(rewritten_phrases)
        seed_terms.extend(structural_tags or [])
        seed_terms.extend(probe_tags or [])
        seed_terms.extend(selected_tags)
        seed_terms = list(dict.fromkeys(seed_terms))
        active_selected_tags = list(dict.fromkeys(selected_tags))

        structural_set = {_norm_tag_for_lookup(t) for t in (structural_tags or []) if t}
        probe_set = {_norm_tag_for_lookup(t) for t in (probe_tags or []) if t}
        implied_set = {_norm_tag_for_lookup(t) for t in (implied_selected_tags or []) if t}
        rewrite_set = {
            _norm_tag_for_lookup(t)
            for t in (list(user_tags or []) + rewritten_phrases)
            if t
        }
        selection_set = {_norm_tag_for_lookup(t) for t in (selection_selected_tags or []) if t}

        # First matching source wins, in this order.
        origin_priority = [
            ("structural", structural_set),
            ("probe", probe_set),
            ("rewrite", rewrite_set),
            ("selection", selection_set),
            ("implied", implied_set),
        ]
        tag_selection_origins: Dict[str, str] = {}
        for tag in active_selected_tags:
            tag_norm = _norm_tag_for_lookup(tag)
            # Unknown/fallback tags use selection color.
            origin = next(
                (name for name, members in origin_priority if tag_norm in members),
                "selection",
            )
            tag_selection_origins[tag] = origin
            if tag_norm and tag_norm != tag:
                tag_selection_origins[tag_norm] = origin

        direct_tags_for_implied = list(
            dict.fromkeys(_norm_tag_for_lookup(t) for t in (direct_selected_tags or []) if t)
        )
        direct_tags_for_implied_idx = {t: i for i, t in enumerate(direct_tags_for_implied)}
        # Order by source rank first, original position second.
        direct_tags_for_implied.sort(
            key=lambda t: (
                _selection_source_rank(tag_selection_origins.get(t, "selection")),
                direct_tags_for_implied_idx.get(t, 10**9),
            )
        )
        implied_parent_map = _build_implied_parent_map(
            direct_tags_ordered=direct_tags_for_implied,
            implied_tags=implied_selected_tags,
        )
        toggle_rows = _build_toggle_rows(
            seed_terms=seed_terms,
            selected_tags=active_selected_tags,
            retrieved_candidate_tags=retrieved_candidate_tags,
            tag_selection_origins=tag_selection_origins,
            implied_parent_map=implied_parent_map,
            top_groups=max(1, int(display_top_groups)),
            top_tags_per_group=max(1, int(display_top_tags_per_group)),
            group_rank_top_k=max(1, int(display_rank_top_k)),
        )
        dt = time.perf_counter() - t0
        _record_timing("group_display", dt)
        log(f"Ranked group display: {dt:.2f}s ({len(toggle_rows)} rows)")
        log(
            _build_display_audit_line(
                toggle_rows,
                active_selected_tags=active_selected_tags,
                direct_selected_tags=direct_selected_tags,
                implied_selected_tags=implied_selected_tags,
            )
        )

        total_dt = time.perf_counter() - t_total0
        _emit_timing_summary(total_dt)
        _append_timing_jsonl(total_dt)
        log("Done: final prompt ready")
        return _build_ui_payload(
            console_text="\n".join(logs),
            row_defs=toggle_rows,
            selected_tags=active_selected_tags,
        )
    except Exception as e:
        # Top-level boundary for the UI callback: report the error in the
        # console instead of surfacing a traceback to the user.
        log(f"Error: {type(e).__name__}: {e}")
        return _build_ui_payload(
            console_text="\n".join(logs),
            row_defs=[],
            selected_tags=[],
        )
# Gradio UI: prompt entry and suggested prompt on the left, mascot and Run
# button on the right, then the toggleable tag rows, legend, settings and
# console.
# NOTE(review): container nesting below was reconstructed from a listing whose
# indentation was lost; confirm against the rendered layout (in particular
# whether the Run button lives in the mascot column).
with gr.Blocks(css=css, js=client_js) as app:
    with gr.Row():
        with gr.Column(scale=3, elem_classes=["prompt-col"]):
            gr.Markdown(
                'Describe your image under "Enter Prompt" and click "Run". '
                'Prompt Squirrel will translate it into image board tags.',
                elem_classes=["top-instruction"],
            )
            with gr.Group(elem_classes=["prompt-card"]):
                image_tags = gr.Textbox(
                    label="Enter Prompt",
                    placeholder="e.g. fox, outside, detailed background, .",
                    lines=1,
                    elem_classes=["enter-prompt-box"],
                )
            with gr.Group(elem_classes=["prompt-card", "suggested-prompt-card"]):
                suggested_prompt = gr.Textbox(
                    label="Suggested Prompt (Read-only)",
                    lines=2,
                    interactive=False,
                    show_copy_button=True,
                    placeholder='Suggested prompt will appear here after you click "Run".',
                    elem_classes=["suggested-prompt-box"],
                )
        with gr.Column(scale=1):
            _mascot_pil = _load_mascot_image()
            if _mascot_pil is None:
                # Text placeholder when the mascot could not be loaded.
                mascot_img = gr.Markdown("`(mascot image unavailable)`")
            else:
                mascot_img = gr.Image(
                    value=_mascot_pil,
                    show_label=False,
                    interactive=False,
                    height=240,
                    elem_id="mascot"
                )
            submit_button = gr.Button("Run", variant="primary")
            gr.Markdown("Typical runtime: up to ~20 seconds.", elem_classes=["run-hint"])

    # Per-session state shared between the pipeline run and the toggle handlers.
    selected_tags_state = gr.State([])
    row_defs_state = gr.State([])
    row_values_state = gr.State([])

    toggle_instruction = gr.Markdown(
        "Click tag buttons to add or remove tags from the suggested prompt.",
        elem_classes=["row-instruction"],
        visible=False,
    )

    # A fixed pool of hidden heading/checkbox rows is created up front; they
    # are listed in ``run_outputs`` so the run callbacks can update them.
    row_headers: List[gr.Markdown] = []
    row_checkboxes: List[gr.CheckboxGroup] = []
    for _ in range(display_max_rows_default):
        with gr.Row():
            with gr.Column(scale=2, min_width=170):
                _row_header = gr.Markdown(value="", visible=False, elem_classes=["row-heading"])
            with gr.Column(scale=10):
                _row_checkbox = gr.CheckboxGroup(
                    choices=[],
                    value=[],
                    visible=False,
                    interactive=True,
                    container=False,
                    elem_classes=["lego-tags"],
                )
        row_headers.append(_row_header)
        row_checkboxes.append(_row_checkbox)

    gr.HTML(
        """
<div class="source-legend">
<span class="legend-title">Legend:</span>
<span class="chip rewrite">Rewrite phrase</span>
<span class="chip selection">General selection</span>
<span class="chip probe">Probe query</span>
<span class="chip structural">Structural query</span>
<span class="chip implied">Implied</span>
<span class="chip user">User-toggled</span>
<span class="chip unselected">Unselected</span>
</div>
"""
    )

    with gr.Accordion("Display Settings", open=False):
        with gr.Row():
            display_top_groups = gr.Number(
                value=display_top_groups_default,
                precision=0,
                label="Rows (Top Groups/Categories)",
                minimum=1,
            )
            display_top_tags_per_group = gr.Number(
                value=display_top_tags_per_group_default,
                precision=0,
                label="Top Tags Shown Per Row",
                minimum=1,
            )
            display_rank_top_k = gr.Number(
                value=display_rank_top_k_default,
                precision=0,
                label="Top Tags Used for Row Ranking",
                minimum=1,
            )

    with gr.Accordion("Console", open=False):
        console = gr.Textbox(
            label="Console",
            lines=10,
            interactive=False,
            placeholder="Progress logs will appear here."
        )

    # Every component a pipeline run may update, in the order the callbacks
    # return their values.
    run_outputs = [
        console,
        toggle_instruction,
        suggested_prompt,
        selected_tags_state,
        row_defs_state,
        row_values_state,
        *row_headers,
        *row_checkboxes,
    ]
    pipeline_inputs = [
        image_tags,
        display_top_groups,
        display_top_tags_per_group,
        display_rank_top_k,
    ]

    # Clicking "Run" and pressing Enter in the prompt box trigger the same
    # chain: an unqueued prepare step first, then the pipeline itself.
    for _trigger in (submit_button.click, image_tags.submit):
        _trigger(
            _prepare_run_ui,
            inputs=[],
            outputs=run_outputs,
            queue=False,
            show_progress="hidden",
        ).then(
            rag_pipeline_ui,
            inputs=pipeline_inputs,
            outputs=run_outputs,
        )

    toggle_inputs_tail = [selected_tags_state, row_defs_state, row_values_state]
    toggle_outputs = [selected_tags_state, row_values_state, suggested_prompt, *row_checkboxes]
    for idx, row_cb in enumerate(row_checkboxes):
        # ``i=idx`` binds the row index at definition time; a plain closure
        # would see only the final value of ``idx``.
        row_cb.select(
            fn=lambda changed_values, selected_state, row_defs, row_values, i=idx: _on_toggle_row(
                i,
                changed_values,
                selected_state,
                row_defs,
                row_values,
                display_max_rows_default,
            ),
            inputs=[row_cb, *toggle_inputs_tail],
            outputs=toggle_outputs,
            queue=False,
            show_progress="hidden",
        )

if __name__ == "__main__":
    app.queue().launch(allowed_paths=[str(MASCOT_DIR)])