"""Gradio front end for the Prompt Squirrel RAG pipeline.

The app takes a rough user prompt, rewrites it with an LLM, retrieves
candidate tags, lets an LLM pick from those candidates, adds structural and
implied tags, and shows the composed prompt next to a console log.

NOTE(review): this file was reconstructed from a whitespace-collapsed copy.
Indentation, the nesting of UI components, and whitespace inside string
literals (CSS, Markdown, indented log lines) were inferred -- confirm against
the original layout.
"""

import logging
import os
from pathlib import Path
from typing import List, Tuple

import gradio as gr
from PIL import Image

from psq_rag.pipeline.preproc import extract_user_provided_tags_upto_3_words
from psq_rag.llm.rewrite import llm_rewrite_prompt
from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases, _norm_tag_for_lookup
from psq_rag.llm.select import llm_select_indices, llm_infer_structural_tags
from psq_rag.retrieval.state import expand_tags_via_implications


def _split_prompt_commas(s: str) -> List[str]:
    """Split *s* on commas and return the non-empty, stripped pieces.

    A falsy *s* (``None`` or ``""``) yields an empty list.
    """
    return [p.strip() for p in (s or "").split(",") if p.strip()]


def _norm_for_dedupe(tag: str) -> str:
    """Return the key used to detect duplicate tags.

    The tag is lower-cased and passed through ``_norm_tag_for_lookup``.
    """
    return _norm_tag_for_lookup(tag.lower())


def compose_final_prompt(rewritten_prompt: str, selected_tags: List[str]) -> str:
    """Join the rewritten prompt's phrases and the selected tags into one prompt.

    Phrases from *rewritten_prompt* come first, then *selected_tags*. Entries
    whose dedupe key was already seen are dropped, so the first spelling of a
    tag wins and the original order is preserved.

    Returns:
        The surviving entries joined with ``", "``.
    """
    parts = _split_prompt_commas(rewritten_prompt)
    parts.extend(selected_tags)

    seen = set()
    out = []
    for p in parts:
        key = _norm_for_dedupe(p)
        if key in seen:
            continue
        seen.add(key)
        out.append(p)
    return ", ".join(out)


# Set up logging
# Minimal prod logging: warnings+ to stderr, no file by default
LOG_LEVEL = os.environ.get("PSQ_LOG_LEVEL", "WARNING").upper()
# getattr() can return any attribute of the logging module (e.g. "ROOT" is not
# a level), so only accept integer level constants and otherwise use WARNING.
_log_level = getattr(logging, LOG_LEVEL, None)
if not isinstance(_log_level, int):
    _log_level = logging.WARNING
logging.basicConfig(
    level=_log_level,
    format="%(asctime)s %(levelname)s:%(message)s",
    handlers=[logging.StreamHandler()],  # no file -> avoids huge logs on Spaces
)
# Quiet down common noisy libs (optional)
for _name in ("gensim", "gradio", "hnswlib", "httpx", "uvicorn"):
    logging.getLogger(_name).setLevel(logging.ERROR)

# Turn off Gradio analytics phone-home to avoid those background thread errors (optional)
os.environ["GRADIO_ANALYTICS_ENABLED"] = "0"

MASCOT_DIR = Path(__file__).parent / "mascotimages"
MASCOT_FILE = MASCOT_DIR / "transparentsquirrel.png"

# Hotfix: make gradio_client's JSON-schema helpers tolerate non-dict schemas.
# Best effort only -- any failure is reported and the app continues unpatched.
try:
    from gradio_client import utils as _gc_utils

    _orig_get_type = _gc_utils.get_type
    _orig_j2p = _gc_utils._json_schema_to_python_type
    _orig_pub = _gc_utils.json_schema_to_python_type

    def _get_type_safe(schema):
        # Sometimes schema is a bare True/False (JSON Schema boolean form)
        if not isinstance(schema, dict):
            return "any"
        return _orig_get_type(schema)

    def _j2p_safe(schema, defs=None):
        # Accept non-dict schemas (True/False/None) and treat as "any"
        if not isinstance(schema, dict):
            return "any"
        return _orig_j2p(schema, defs or schema.get("$defs"))

    def _pub_safe(schema):
        # Public wrapper used by Gradio; keep it resilient too
        if not isinstance(schema, dict):
            return "any"
        return _j2p_safe(schema, schema.get("$defs"))

    _gc_utils.get_type = _get_type_safe
    _gc_utils._json_schema_to_python_type = _j2p_safe
    _gc_utils.json_schema_to_python_type = _pub_safe
except Exception as e:
    print("gradio_client hotfix not applied:", e)

# -------------------------------------------------------------------------------

# Retrieval / console-verbosity switches read by rag_pipeline_ui.
allow_nsfw_tags = False
verbose_retrieval = True
verbose_retrieval_all = False
verbose_retrieval_limit = 20

css = """
.scrollable-content{
  max-height: 420px;
  overflow-y: scroll; /* always show scrollbar */
  overflow-x: hidden;
  padding-right: 8px;
  padding-bottom: 14px; /* <— add this */
  scrollbar-gutter: stable; /* prevent layout shift as it fills */
  /* Firefox */
  scrollbar-width: auto;
  scrollbar-color: rgba(180,180,180,.9) rgba(0,0,0,.15);
}
/* WebKit/Chromium (Chrome/Edge/Safari) */
.scrollable-content::-webkit-scrollbar{ width: 10px; }
.scrollable-content::-webkit-scrollbar-thumb{
  background: rgba(180,180,180,.9);
  border-radius: 8px;
}
.scrollable-content::-webkit-scrollbar-track{
  background: rgba(0,0,0,.15);
}
/* (Optional) make both scroll panes taller so they fill more of the column */
.pane-left .scrollable-content,
.pane-right .scrollable-content {
  max-height: 610px; /* was 420px; tweak to taste */
}
"""


def _fmt_score(value):
    """Format an int/float score to three decimals; return anything else unchanged.

    Non-numeric values (including ``None``) are returned as-is so that the
    caller's f-string renders them with their default ``str()`` form.
    """
    return f"{value:.3f}" if isinstance(value, (int, float)) else value


def _log_phrase_reports(phrase_reports, log) -> None:
    """Write the per-phrase retrieval reports to the console via *log*.

    Each report is read with ``.get`` for the keys ``normalized``/``phrase``,
    ``lookup``, ``tfidf_vocab`` and ``candidates``; each candidate row for
    ``tag``, ``alias_token``, ``score_fasttext``, ``score_context``,
    ``score_combined`` and ``count``. At most ``verbose_retrieval_limit`` rows
    are shown per phrase unless ``verbose_retrieval_all`` is set.
    """
    limit = None if verbose_retrieval_all else max(1, int(verbose_retrieval_limit))
    for report in phrase_reports:
        phrase = report.get("normalized") or report.get("phrase") or ""
        lookup = report.get("lookup") or ""
        tfidf_vocab = report.get("tfidf_vocab")
        log(f"Phrase: {phrase} (lookup={lookup}) tfidf_vocab={tfidf_vocab}")

        rows = report.get("candidates", [])
        shown = rows if limit is None else rows[:limit]
        for row in shown:
            tag = row.get("tag")
            alias_token = row.get("alias_token")
            # Only mention the alias when it differs from the tag itself.
            alias_part = ""
            if alias_token and alias_token != tag:
                alias_part = f" [alias_token={alias_token}]"
            fasttext_str = _fmt_score(row.get("score_fasttext"))
            context_str = _fmt_score(row.get("score_context"))
            combined_str = _fmt_score(row.get("score_combined"))
            count = row.get("count")
            # NOTE(review): leading whitespace of these indented log lines was
            # collapsed in the source this was rebuilt from -- confirm width.
            log(
                f" {tag}{alias_part} | fasttext={fasttext_str} context={context_str} "
                f"combined={combined_str} count={count}"
            )
        if limit is not None and len(rows) > limit:
            log(f" ... ({len(rows) - limit} more)")


def rag_pipeline_ui(user_prompt: str) -> Tuple[str, str]:
    """Run the full prompt pipeline for the UI.

    Steps: extract user-provided tags, LLM rewrite, candidate retrieval, LLM
    index selection, structural tag inference, implication expansion, and
    final prompt composition.

    Args:
        user_prompt: Raw text from the prompt box; ``None`` is treated as empty.

    Returns:
        ``(console_text, final_prompt)``. An empty prompt yields
        ``("Error: empty prompt", "")``. If any step raises, the error is
        appended to the console text and the final prompt is ``""``.
    """
    logs = []

    def log(s):
        logs.append(s)

    try:
        log("Start: received prompt")
        prompt_in = (user_prompt or "").strip()
        if not prompt_in:
            return "Error: empty prompt", ""

        log("Input:")
        log(prompt_in)
        log("")

        user_tags = extract_user_provided_tags_upto_3_words(prompt_in)
        log("Heuristically extracted user tags:")
        if user_tags:
            log(", ".join(user_tags))
        else:
            log("(none)")
        log("")

        log("Step 1: LLM rewrite")
        rewritten = llm_rewrite_prompt(prompt_in, log)
        log("Rewrite:")
        log(rewritten if rewritten else "(empty)")
        log("")

        # The rewrite may come back empty/None; fall back to "" so appending
        # the user tags cannot raise TypeError and abort the whole run.
        rewrite_for_retrieval = rewritten or ""
        if user_tags:
            # keep them separate in logs, but allow them to help retrieval
            rewrite_for_retrieval = ", ".join([rewrite_for_retrieval, *user_tags])

        log("Step 2: Prompt Squirrel retrieval (hidden)")
        try:
            rewrite_phrases = _split_prompt_commas(rewrite_for_retrieval)
            retrieval_result = psq_candidates_from_rewrite_phrases(
                rewrite_phrases=rewrite_phrases,
                allow_nsfw_tags=allow_nsfw_tags,
                global_k=300,
                verbose=verbose_retrieval,
            )
            # A tuple result carries (candidates, per-phrase reports);
            # anything else is taken to be the candidates alone.
            if isinstance(retrieval_result, tuple):
                candidates, phrase_reports = retrieval_result
            else:
                candidates, phrase_reports = retrieval_result, []

            log(f"Retrieved {len(candidates)} candidate tags")
            if verbose_retrieval:
                log(f"Total unique candidates: {len(candidates)}")
                _log_phrase_reports(phrase_reports, log)
        except Exception as e:
            # Best effort: continue with no candidates rather than failing.
            log(f"Retrieval fallback: {type(e).__name__}: {e}")
            candidates = []

        log("Step 3: LLM index selection")
        # We pass the original 'prompt_in' as the description for the LLM to match against
        picked_indices = llm_select_indices(
            query_text=prompt_in,
            candidates=candidates,
            max_pick=0,
            log=log,
        )
        selected_tags = [candidates[i].tag for i in picked_indices] if picked_indices else []

        log("Step 3b: Structural tag inference (solo/duo/gender/body plan)")
        structural_tags = llm_infer_structural_tags(prompt_in, log=log)
        if structural_tags:
            # Add structural tags that aren't already selected
            existing = set(selected_tags)
            new_structural = [t for t in structural_tags if t not in existing]
            selected_tags.extend(new_structural)
            log(f" Added {len(new_structural)} structural tags: {', '.join(new_structural)}")
        else:
            log(" No structural tags inferred")

        log("Step 3c: Expand via tag implications")
        # Only the implied-only tags are appended; the first return value
        # (the expanded set) is not needed here.
        _expanded, implied_only = expand_tags_via_implications(set(selected_tags))
        if implied_only:
            implied_sorted = sorted(implied_only)
            selected_tags.extend(implied_sorted)
            log(f" Added {len(implied_only)} implied tags: {', '.join(implied_sorted)}")
        else:
            log(" No additional implied tags")

        log("Step 4: Compose final prompt")
        final_prompt = compose_final_prompt(rewritten, selected_tags)
        log("Done: final prompt ready")

        return "\n".join(logs), final_prompt
    except Exception as e:
        # Top-level UI boundary: surface the failure in the console pane.
        log(f"Error: {type(e).__name__}: {e}")
        return "\n".join(logs), ""


# NOTE(review): component nesting below (what sits inside the Row versus at
# Blocks level) was inferred from a whitespace-collapsed source -- confirm.
with gr.Blocks(css=css) as app:
    with gr.Row():
        with gr.Column(scale=3, elem_classes=["prompt-col"]):
            image_tags = gr.Textbox(
                label="Enter Prompt",
                placeholder="e.g. fox, outside, detailed background, .",
                lines=1,
            )
        with gr.Column(scale=1):
            # convert() returns a separate in-memory image, so the source
            # file can be closed as soon as the conversion is done.
            with Image.open(MASCOT_FILE) as _mascot_src:
                _mascot_pil = _mascot_src.convert("RGBA")
            mascot_img = gr.Image(
                value=_mascot_pil,
                show_label=False,
                interactive=False,
                height=220,
                elem_id="mascot",
            )

    submit_button = gr.Button("Run", variant="primary")

    gr.Markdown(
        "### Prompt Squirrel RAG (pipeline version)\n"
        "Type a rough prompt. This tool rewrites it and aligns it to an e621-style tag vocabulary\n"
        "using Prompt Squirrel internally, then returns a cleaned, model-friendly prompt."
    )

    console = gr.Textbox(
        label="Console",
        lines=10,
        interactive=False,
        placeholder="Progress logs will appear here.",
    )
    final_prompt = gr.Textbox(
        label="Final Prompt",
        lines=3,
        interactive=False,
        placeholder="Your optimized prompt will appear here.",
    )

    # Both the Run button and pressing Enter in the prompt box run the pipeline.
    submit_button.click(
        rag_pipeline_ui,
        inputs=[image_tags],
        outputs=[console, final_prompt],
    )
    image_tags.submit(
        rag_pipeline_ui,
        inputs=[image_tags],
        outputs=[console, final_prompt],
    )

if __name__ == "__main__":
    app.queue().launch(allowed_paths=[str(MASCOT_DIR)])