import logging
import os
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import List

import gradio as gr
from PIL import Image

from psq_rag.pipeline.preproc import extract_user_provided_tags_upto_3_words
from psq_rag.llm.rewrite import llm_rewrite_prompt
from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases, _norm_tag_for_lookup
from psq_rag.llm.select import llm_select_indices, llm_infer_structural_tags, llm_infer_probe_tags
from psq_rag.retrieval.state import expand_tags_via_implications


def _split_prompt_commas(s: str) -> List[str]:
    """Split a comma-separated prompt into stripped, non-empty parts.

    A falsy ``s`` (``None`` or ``""``) yields an empty list.
    """
    stripped = (piece.strip() for piece in (s or "").split(","))
    return [piece for piece in stripped if piece]


def _norm_for_dedupe(tag: str) -> str:
    """Return the key used to detect duplicate tags.

    The tag is lower-cased and then passed through ``_norm_tag_for_lookup``.
    """
    return _norm_tag_for_lookup(tag.lower())


def compose_final_prompt(rewritten_prompt: str, selected_tags: List[str]) -> str:
    """Join the rewritten prompt's phrases and the selected tags into one prompt.

    The rewritten prompt is split on commas, ``selected_tags`` are appended, and
    entries whose ``_norm_for_dedupe`` key was already seen are dropped.  Order
    is preserved and the first spelling of each entry wins.

    Returns:
        The surviving entries joined with ``", "``.
    """
    parts = _split_prompt_commas(rewritten_prompt)
    parts.extend(selected_tags)

    seen = set()
    out = []
    for part in parts:
        key = _norm_for_dedupe(part)
        if key in seen:
            continue
        seen.add(key)
        out.append(part)
    return ", ".join(out)


def _build_selection_query(
    prompt_in: str,
    rewritten: str,
    structural_tags: List[str],
    probe_tags: List[str],
) -> str:
    """Build the multi-line query text handed to the index-selection step.

    The result always starts with an ``IMAGE DESCRIPTION`` line.  A
    ``REWRITE PHRASES`` line is added when ``rewritten`` is non-blank, and an
    ``INFERRED TAG HINTS`` line (sorted, de-duplicated) is added when any
    structural or probe tags were supplied.

    Returns:
        The lines joined with newlines.
    """
    lines = [f"IMAGE DESCRIPTION: {prompt_in.strip()}"]
    if rewritten and rewritten.strip():
        lines.append(f"REWRITE PHRASES: {rewritten.strip()}")

    hint_tags = []
    if structural_tags:
        hint_tags.extend(structural_tags)
    if probe_tags:
        hint_tags.extend(probe_tags)
    if hint_tags:
        # Keep hints as context only; selection still must choose by candidate indices.
        lines.append(
            "INFERRED TAG HINTS (context only): "
            + ", ".join(sorted(set(hint_tags)))
        )
    return "\n".join(lines)


# Set up logging
# Minimal prod logging: warnings+ to stderr, no file by default
LOG_LEVEL = os.environ.get("PSQ_LOG_LEVEL", "WARNING").strip().upper()
# getattr() alone would also accept non-level attributes of the logging module
# (e.g. "BASIC_FORMAT" is a str), which basicConfig() rejects; only accept ints.
_log_level_value = getattr(logging, LOG_LEVEL, None)
if not isinstance(_log_level_value, int):
    _log_level_value = logging.WARNING
logging.basicConfig(
    level=_log_level_value,
    format="%(asctime)s %(levelname)s:%(message)s",
    handlers=[logging.StreamHandler()],  # no file -> avoids huge logs on Spaces
)

# Quiet down common noisy libs (optional)
for _name in ("gensim", "gradio", "hnswlib", "httpx", "uvicorn"):
    logging.getLogger(_name).setLevel(logging.ERROR)

# Turn off Gradio analytics phone-home to avoid those background thread errors (optional)
os.environ["GRADIO_ANALYTICS_ENABLED"] = "0"

MASCOT_DIR = Path(__file__).parent / "mascotimages"
MASCOT_FILE = MASCOT_DIR / "transparentsquirrel.png"

# Hotfix: make gradio_client's JSON-schema helpers tolerate non-dict schemas.
# Best effort only; any failure while patching is reported and otherwise ignored.
try:
    from gradio_client import utils as _gc_utils

    _orig_get_type = _gc_utils.get_type
    _orig_j2p = _gc_utils._json_schema_to_python_type
    _orig_pub = _gc_utils.json_schema_to_python_type

    def _get_type_safe(schema):
        """Return ``"any"`` for non-dict schemas, else defer to the original."""
        # Sometimes schema is a bare True/False (JSON Schema boolean form)
        if not isinstance(schema, dict):
            return "any"
        return _orig_get_type(schema)

    def _j2p_safe(schema, defs=None):
        """Return ``"any"`` for non-dict schemas, else defer to the original.

        When ``defs`` is falsy, the schema's own ``"$defs"`` entry is used.
        """
        # Accept non-dict schemas (True/False/None) and treat as "any"
        if not isinstance(schema, dict):
            return "any"
        return _orig_j2p(schema, defs or schema.get("$defs"))

    def _pub_safe(schema):
        """Public wrapper: same non-dict guard, then delegate to ``_j2p_safe``."""
        # Public wrapper used by Gradio; keep it resilient too
        if not isinstance(schema, dict):
            return "any"
        return _j2p_safe(schema, schema.get("$defs"))

    _gc_utils.get_type = _get_type_safe
    _gc_utils._json_schema_to_python_type = _j2p_safe
    _gc_utils.json_schema_to_python_type = _pub_safe
except Exception as e:
    print("gradio_client hotfix not applied:", e)

# -------------------------------------------------------------------------------

# Retrieval configuration flags read by the pipeline below.
allow_nsfw_tags = False
verbose_retrieval = True
verbose_retrieval_all = False
verbose_retrieval_limit = 20
# Probe-tag inference is enabled unless PSQ_ENABLE_PROBE is set to a
# false-like value; the comparison is case-insensitive.
enable_probe_tags = os.environ.get("PSQ_ENABLE_PROBE", "1").strip().lower() not in {
    "0",
    "false",
    "no",
    "off",
}

css = """
.scrollable-content{
  max-height: 420px;
  overflow-y: scroll; /* always show scrollbar */
  overflow-x: hidden;
  padding-right: 8px;
  padding-bottom: 14px; /* <— add this */
  scrollbar-gutter: stable; /* prevent layout shift as it fills */
  /* Firefox */
  scrollbar-width: auto;
  scrollbar-color: rgba(180,180,180,.9) rgba(0,0,0,.15);
}
/* WebKit/Chromium (Chrome/Edge/Safari) */
.scrollable-content::-webkit-scrollbar{ width: 10px; }
.scrollable-content::-webkit-scrollbar-thumb{
  background: rgba(180,180,180,.9);
  border-radius: 8px;
}
.scrollable-content::-webkit-scrollbar-track{
  background: rgba(0,0,0,.15);
}
/* (Optional) make both scroll panes taller so they fill more of the column */
.pane-left .scrollable-content,
.pane-right .scrollable-content {
  max-height: 610px; /* was 420px; tweak to taste */
}
"""


def _format_score(value) -> str:
    """Render an int/float score with three decimals; anything else via ``str``."""
    if isinstance(value, (int, float)):
        return f"{value:.3f}"
    return str(value)


def _log_phrase_reports(phrase_reports, candidate_count: int, log) -> None:
    """Write the verbose per-phrase retrieval report through ``log``.

    Each report is read as a mapping (keys: ``normalized``/``phrase``,
    ``lookup``, ``tfidf_vocab``, ``candidates``).  At most
    ``verbose_retrieval_limit`` rows are shown per phrase unless
    ``verbose_retrieval_all`` is set.
    """
    log(f"Total unique candidates: {candidate_count}")
    limit = None if verbose_retrieval_all else max(1, int(verbose_retrieval_limit))
    for report in phrase_reports:
        phrase = report.get("normalized") or report.get("phrase") or ""
        lookup = report.get("lookup") or ""
        tfidf_vocab = report.get("tfidf_vocab")
        log(f"Phrase: {phrase} (lookup={lookup}) tfidf_vocab={tfidf_vocab}")

        rows = report.get("candidates", [])
        shown = rows if limit is None else rows[:limit]
        for row in shown:
            tag = row.get("tag")
            alias_token = row.get("alias_token")
            count = row.get("count")
            alias_part = ""
            if alias_token and alias_token != tag:
                alias_part = f" [alias_token={alias_token}]"
            fasttext_str = _format_score(row.get("score_fasttext"))
            context_str = _format_score(row.get("score_context"))
            combined_str = _format_score(row.get("score_combined"))
            log(
                f" {tag}{alias_part} | fasttext={fasttext_str} context={context_str} "
                f"combined={combined_str} count={count}"
            )
        if limit is not None and len(rows) > limit:
            log(f" ... ({len(rows) - limit} more)")


def _extend_with_new_tags(selected_tags: List[str], extra_tags: List[str]) -> List[str]:
    """Append to ``selected_tags`` (in place) the ``extra_tags`` not yet present.

    Returns:
        The tags that were actually appended, in order, each at most once.
    """
    seen = set(selected_tags)
    added = []
    for tag in extra_tags:
        if tag not in seen:
            seen.add(tag)
            added.append(tag)
    selected_tags.extend(added)
    return added


def rag_pipeline_ui(user_prompt: str):
    """Run the rewrite -> retrieval -> selection -> compose pipeline for the UI.

    Args:
        user_prompt: Free-form prompt typed by the user.

    Returns:
        A ``(console_text, final_prompt)`` pair of strings.  An empty prompt
        yields ``("Error: empty prompt", "")``.  If any step raises, the error
        is appended to the log and ``final_prompt`` is ``""``.
    """
    logs = []

    def log(s):
        logs.append(s)

    try:
        log("Start: received prompt")
        prompt_in = (user_prompt or "").strip()
        if not prompt_in:
            return "Error: empty prompt", ""

        log("Input:")
        log(prompt_in)
        log("")

        user_tags = extract_user_provided_tags_upto_3_words(prompt_in)
        log("Heuristically extracted user tags:")
        log(", ".join(user_tags) if user_tags else "(none)")
        log("")

        log("Step 1: LLM rewrite + structural inference + probe (concurrent)")
        max_workers = 3 if enable_probe_tags else 2
        with ThreadPoolExecutor(max_workers=max_workers) as ex:
            fut_rewrite = ex.submit(llm_rewrite_prompt, prompt_in, log)
            fut_struct = ex.submit(llm_infer_structural_tags, prompt_in, log=log)
            fut_probe = (
                ex.submit(llm_infer_probe_tags, prompt_in, log=log)
                if enable_probe_tags
                else None
            )
            rewritten = fut_rewrite.result()
            structural_tags = fut_struct.result()
            probe_tags = fut_probe.result() if fut_probe else []

        log("Rewrite:")
        log(rewritten if rewritten else "(empty)")
        log("")

        # The rewrite may come back empty/None; coerce before concatenating so
        # user tags alone can still drive retrieval.
        rewrite_for_retrieval = rewritten or ""
        if user_tags:
            # keep them separate in logs, but allow them to help retrieval
            rewrite_for_retrieval = (
                rewrite_for_retrieval + ", " + ", ".join(user_tags)
            ).strip(", ").strip()

        log("Step 2: Prompt Squirrel retrieval (hidden)")
        candidates = []
        phrase_reports = []
        retrieval_ok = False
        try:
            retrieval_context_tags = list(
                dict.fromkeys((structural_tags or []) + (probe_tags or []))
            )
            rewrite_phrases = [
                p.strip() for p in (rewrite_for_retrieval or "").split(",") if p.strip()
            ]
            retrieval_result = psq_candidates_from_rewrite_phrases(
                rewrite_phrases=rewrite_phrases,
                allow_nsfw_tags=allow_nsfw_tags,
                context_tags=retrieval_context_tags,
                global_k=300,
                verbose=verbose_retrieval,
            )
            if isinstance(retrieval_result, tuple):
                candidates, phrase_reports = retrieval_result
            else:
                candidates, phrase_reports = retrieval_result, []
            log(f"Retrieved {len(candidates)} candidate tags")
            retrieval_ok = True
        except Exception as e:
            log(f"Retrieval fallback: {type(e).__name__}: {e}")
            candidates = []
            phrase_reports = []

        if retrieval_ok and verbose_retrieval:
            # Reporting is diagnostic only: a malformed report must not throw
            # away candidates that were retrieved successfully.
            try:
                _log_phrase_reports(phrase_reports, len(candidates), log)
            except Exception as e:
                log(f"Verbose retrieval report skipped: {type(e).__name__}: {e}")

        log("Step 3: LLM index selection (uses rewrite + structural/probe context)")
        selection_query = _build_selection_query(
            prompt_in=prompt_in,
            rewritten=rewritten,
            structural_tags=structural_tags,
            probe_tags=probe_tags,
        )
        picked_indices = llm_select_indices(
            query_text=selection_query,
            candidates=candidates,
            max_pick=0,
            log=log,
        )
        selected_tags = [candidates[i].tag for i in picked_indices] if picked_indices else []

        if structural_tags:
            # Add structural tags that aren't already selected
            new_structural = _extend_with_new_tags(selected_tags, structural_tags)
            log(f" Added {len(new_structural)} structural tags: {', '.join(new_structural)}")
        else:
            log(" No structural tags inferred")

        if probe_tags:
            new_probe = _extend_with_new_tags(selected_tags, probe_tags)
            log(f" Added {len(new_probe)} probe tags: {', '.join(new_probe)}")
        elif enable_probe_tags:
            log(" No probe tags inferred")

        log("Step 3c: Expand via tag implications")
        _expanded, implied_only = expand_tags_via_implications(set(selected_tags))
        if implied_only:
            implied_sorted = sorted(implied_only)
            selected_tags.extend(implied_sorted)
            log(f" Added {len(implied_only)} implied tags: {', '.join(implied_sorted)}")
        else:
            log(" No additional implied tags")

        log("Step 4: Compose final prompt")
        final_prompt = compose_final_prompt(rewritten, selected_tags)
        log("Done: final prompt ready")
        return "\n".join(logs), final_prompt
    except Exception as e:
        # Top-level UI boundary: surface the failure in the console pane.
        log(f"Error: {type(e).__name__}: {e}")
        return "\n".join(logs), ""


# NOTE(review): the original indentation was lost; the layout below assumes the
# button, markdown and output boxes sit at Blocks level under the row — confirm.
with gr.Blocks(css=css) as app:
    with gr.Row():
        with gr.Column(scale=3, elem_classes=["prompt-col"]):
            image_tags = gr.Textbox(
                label="Enter Prompt",
                placeholder="e.g. fox, outside, detailed background, .",
                lines=1,
            )
        with gr.Column(scale=1):
            # The mascot is decorative: if the asset is missing or unreadable,
            # warn and render an empty image instead of failing at import time.
            try:
                with Image.open(MASCOT_FILE) as _mascot_src:
                    _mascot_pil = _mascot_src.convert("RGBA")
            except OSError as e:
                logging.warning("Mascot image not loaded (%s): %s", MASCOT_FILE, e)
                _mascot_pil = None
            mascot_img = gr.Image(
                value=_mascot_pil,
                show_label=False,
                interactive=False,
                height=220,
                elem_id="mascot",
            )

    submit_button = gr.Button("Run", variant="primary")

    gr.Markdown(
        """
        ### Prompt Squirrel RAG (pipeline version)
        Type a rough prompt. This tool rewrites it and aligns it to an e621-style tag vocabulary
        using Prompt Squirrel internally, then returns a cleaned, model-friendly prompt.
        """.strip()
    )

    console = gr.Textbox(
        label="Console",
        lines=10,
        interactive=False,
        placeholder="Progress logs will appear here.",
    )
    final_prompt = gr.Textbox(
        label="Final Prompt",
        lines=3,
        interactive=False,
        placeholder="Your optimized prompt will appear here.",
    )

    # Both the button and pressing Enter in the prompt box run the pipeline.
    submit_button.click(
        rag_pipeline_ui,
        inputs=[image_tags],
        outputs=[console, final_prompt],
    )
    image_tags.submit(
        rag_pipeline_ui,
        inputs=[image_tags],
        outputs=[console, final_prompt],
    )

if __name__ == "__main__":
    app.queue().launch(allowed_paths=[str(MASCOT_DIR)])