Spaces:
Running
Running
| import gradio as gr | |
| import os | |
| import logging | |
| from PIL import Image | |
| from pathlib import Path | |
| from typing import List | |
| from psq_rag.pipeline.preproc import extract_user_provided_tags_upto_3_words | |
| from psq_rag.llm.rewrite import llm_rewrite_prompt | |
| from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases, _norm_tag_for_lookup | |
| from psq_rag.llm.select import llm_select_indices, llm_infer_structural_tags | |
| from psq_rag.retrieval.state import expand_tags_via_implications | |
def _split_prompt_commas(s: str) -> List[str]:
    """Break a comma-separated prompt into trimmed, non-empty pieces.

    ``None`` or an empty/whitespace-only string yields an empty list.
    """
    pieces: List[str] = []
    for raw in (s or "").split(","):
        trimmed = raw.strip()
        if trimmed:
            pieces.append(trimmed)
    return pieces
def _norm_for_dedupe(tag: str) -> str:
    """Return the canonical key used to detect duplicate tags.

    The tag is lower-cased and then passed through the retrieval layer's
    ``_norm_tag_for_lookup``. Presumably this makes spellings that retrieval
    treats as the same tag collapse to one key; confirm against that helper.
    """
    lowered = tag.lower()
    return _norm_tag_for_lookup(lowered)
def compose_final_prompt(rewritten_prompt: str, selected_tags: List[str]) -> str:
    """Merge the rewritten prompt with the selected tags into one prompt string.

    The comma-separated pieces of *rewritten_prompt* come first, followed by
    *selected_tags*. An entry is dropped when its ``_norm_for_dedupe`` key was
    already seen, so the first spelling of each tag wins. The survivors are
    joined with ", ". *selected_tags* itself is not modified.
    """
    merged: List[str] = []
    seen_keys = set()
    for piece in [*_split_prompt_commas(rewritten_prompt), *selected_tags]:
        dedupe_key = _norm_for_dedupe(piece)
        if dedupe_key not in seen_keys:
            seen_keys.add(dedupe_key)
            merged.append(piece)
    return ", ".join(merged)
# Set up logging
# Minimal prod logging: warnings+ to stderr, no file by default.
# (os and logging are already imported at the top of the file.)


def _resolve_log_level(name, default=logging.WARNING) -> int:
    """Translate a level name or numeric string into a ``logging`` level int.

    Names are case- and whitespace-insensitive ("info", " DEBUG ").
    Decimal strings ("15") are used as-is.

    Anything else falls back to *default*. This includes names of non-level
    ``logging`` attributes such as ``BASIC_FORMAT``, which a bare
    ``getattr(logging, ...)`` would return as a string and which
    ``basicConfig`` would then reject with ValueError.
    """
    text = str(name or "").strip().upper()
    if text.isdecimal():
        return int(text)
    level = getattr(logging, text, None) if text else None
    return level if isinstance(level, int) else default


LOG_LEVEL = os.environ.get("PSQ_LOG_LEVEL", "WARNING").upper()
logging.basicConfig(
    level=_resolve_log_level(LOG_LEVEL),
    format="%(asctime)s %(levelname)s:%(message)s",
    handlers=[logging.StreamHandler()],  # no file -> avoids huge logs on Spaces
)

# Quiet down common noisy libs (optional)
for _name in ("gensim", "gradio", "hnswlib", "httpx", "uvicorn"):
    logging.getLogger(_name).setLevel(logging.ERROR)

# Turn off Gradio analytics phone-home to avoid those background thread errors (optional).
# NOTE(review): presumably read when gr.Blocks(...) is constructed, so keep this above the UI.
os.environ["GRADIO_ANALYTICS_ENABLED"] = "0"
# Mascot artwork lives in a folder next to this script.
MASCOT_DIR = Path(__file__).parent.joinpath("mascotimages")
MASCOT_FILE = MASCOT_DIR.joinpath("transparentsquirrel.png")
def _install_gradio_client_schema_hotfix(gc_utils) -> None:
    """Patch gradio_client's schema-to-type-hint helpers to tolerate non-dict schemas.

    JSON Schema allows a bare ``True``/``False`` where a schema object is
    expected (the "boolean form"). Each of ``get_type``,
    ``_json_schema_to_python_type`` and ``json_schema_to_python_type`` on
    *gc_utils* is replaced in place by a wrapper. The wrapper returns ``"any"``
    for any non-dict schema and otherwise defers to the original.

    Args:
        gc_utils: the ``gradio_client.utils`` module, or any object exposing
            those three attributes.
    """
    orig_get_type = gc_utils.get_type
    orig_j2p = gc_utils._json_schema_to_python_type
    orig_pub = gc_utils.json_schema_to_python_type

    def _get_type_safe(schema):
        # Sometimes schema is a bare True/False (JSON Schema boolean form)
        if not isinstance(schema, dict):
            return "any"
        return orig_get_type(schema)

    def _j2p_safe(schema, defs=None):
        # Accept non-dict schemas (True/False/None) and treat as "any"
        if not isinstance(schema, dict):
            return "any"
        return orig_j2p(schema, defs or schema.get("$defs"))

    def _pub_safe(schema):
        # Public wrapper used by Gradio. For dict schemas, delegate to the
        # ORIGINAL public function so that any post-processing it applies to
        # the type string is kept. The previous version skipped it.
        # Assumes the original resolves the private helpers as module globals,
        # so it picks up the safe wrappers installed below. Confirm this if the
        # pinned gradio_client version changes.
        if not isinstance(schema, dict):
            return "any"
        return orig_pub(schema)

    gc_utils.get_type = _get_type_safe
    gc_utils._json_schema_to_python_type = _j2p_safe
    gc_utils.json_schema_to_python_type = _pub_safe


try:
    from gradio_client import utils as _gc_utils

    _install_gradio_client_schema_hotfix(_gc_utils)
except Exception as e:
    # Best-effort: the app still runs without the patch (e.g. gradio_client API changed).
    print("gradio_client hotfix not applied:", e)
# -------------------------------------------------------------------------------
# Pipeline switches, read by rag_pipeline_ui on every request.
allow_nsfw_tags: bool = False  # forwarded to retrieval as allow_nsfw_tags=
verbose_retrieval: bool = True  # ask retrieval for verbose output and log per-phrase reports
verbose_retrieval_all: bool = False  # True -> log every candidate row for each phrase
verbose_retrieval_limit: int = 20  # rows logged per phrase when verbose_retrieval_all is False

# Custom stylesheet passed to gr.Blocks(css=...).
# NOTE(review): no component in this file uses the scrollable-content /
# pane-left / pane-right classes, so these rules may currently be inert.
css = """
.scrollable-content{
max-height: 420px;
overflow-y: scroll; /* always show scrollbar */
overflow-x: hidden;
padding-right: 8px;
padding-bottom: 14px; /* <— add this */
scrollbar-gutter: stable; /* prevent layout shift as it fills */
/* Firefox */
scrollbar-width: auto;
scrollbar-color: rgba(180,180,180,.9) rgba(0,0,0,.15);
}
/* WebKit/Chromium (Chrome/Edge/Safari) */
.scrollable-content::-webkit-scrollbar{ width: 10px; }
.scrollable-content::-webkit-scrollbar-thumb{ background: rgba(180,180,180,.9); border-radius: 8px; }
.scrollable-content::-webkit-scrollbar-track{ background: rgba(0,0,0,.15); }
/* (Optional) make both scroll panes taller so they fill more of the column */
.pane-left .scrollable-content,
.pane-right .scrollable-content {
max-height: 610px; /* was 420px; tweak to taste */
}
"""
def _fmt_score(value):
    """Render a numeric score with 3 decimals; return any other value unchanged.

    Non-numeric values (e.g. ``None``) pass through, so an enclosing f-string
    prints their ``str()`` form ("None").
    """
    return f"{value:.3f}" if isinstance(value, (int, float)) else value


def _log_retrieval_reports(phrase_reports, log, limit):
    """Write per-phrase retrieval diagnostics to *log*, one line per call.

    Args:
        phrase_reports: iterable of report dicts from the retriever. Keys read:
            ``normalized``/``phrase``, ``lookup``, ``tfidf_vocab`` and
            ``candidates``. ``candidates`` is a list of row dicts with ``tag``,
            ``alias_token``, ``score_fasttext``, ``score_context``,
            ``score_combined`` and ``count``.
        log: callable taking one line of text.
        limit: maximum rows shown per phrase, or ``None`` to show them all.
    """
    for report in phrase_reports:
        phrase = report.get("normalized") or report.get("phrase") or ""
        lookup = report.get("lookup") or ""
        log(f"Phrase: {phrase} (lookup={lookup}) tfidf_vocab={report.get('tfidf_vocab')}")
        rows = report.get("candidates", [])
        for row in rows if limit is None else rows[:limit]:
            tag = row.get("tag")
            alias_token = row.get("alias_token")
            # Only mention the alias when it differs from the canonical tag.
            alias_part = f" [alias_token={alias_token}]" if alias_token and alias_token != tag else ""
            log(
                f" {tag}{alias_part} | fasttext={_fmt_score(row.get('score_fasttext'))} "
                f"context={_fmt_score(row.get('score_context'))} "
                f"combined={_fmt_score(row.get('score_combined'))} count={row.get('count')}"
            )
        if limit is not None and len(rows) > limit:
            log(f" ... ({len(rows) - limit} more)")


def rag_pipeline_ui(user_prompt: str):
    """Run the whole Prompt Squirrel pipeline for one UI request.

    The steps are:
      1. Heuristic user-tag extraction.
      2. LLM rewrite.
      3. Candidate-tag retrieval.
      4. LLM index selection.
      5. Structural tag inference.
      6. Implication expansion.
      7. Final prompt composition.

    Args:
        user_prompt: raw text from the prompt textbox (may be ``None`` or blank).

    Returns:
        ``(console_text, final_prompt)``.
        - Blank input gives ``("Error: empty prompt", "")``.
        - On any unexpected exception, the console holds the log collected so
          far plus an ``Error:`` line, and the prompt is ``""``. The traceback
          also goes to the server log.
    """
    logs = []

    def log(s):
        logs.append(s)

    try:
        log("Start: received prompt")
        prompt_in = (user_prompt or "").strip()
        if not prompt_in:
            return "Error: empty prompt", ""
        log("Input:")
        log(prompt_in)
        log("")

        user_tags = extract_user_provided_tags_upto_3_words(prompt_in)
        log("Heuristically extracted user tags:")
        log(", ".join(user_tags) if user_tags else "(none)")
        log("")

        log("Step 1: LLM rewrite")
        rewritten = llm_rewrite_prompt(prompt_in, log)
        log("Rewrite:")
        log(rewritten if rewritten else "(empty)")
        log("")

        log("Step 2: Prompt Squirrel retrieval (hidden)")
        # User tags stay separate in the log above but also feed retrieval.
        # `rewritten or ""` guards a None rewrite. Plain string concatenation
        # would raise TypeError there and abort the whole request.
        rewrite_phrases = _split_prompt_commas(", ".join([rewritten or "", *(user_tags or [])]))
        try:
            retrieval_result = psq_candidates_from_rewrite_phrases(
                rewrite_phrases=rewrite_phrases,
                allow_nsfw_tags=allow_nsfw_tags,
                global_k=300,
                verbose=verbose_retrieval,
            )
            # The retriever may return (candidates, phrase_reports) or just candidates.
            if isinstance(retrieval_result, tuple):
                candidates, phrase_reports = retrieval_result
            else:
                candidates, phrase_reports = retrieval_result, []
            log(f"Retrieved {len(candidates)} candidate tags")
            if verbose_retrieval:
                log(f"Total unique candidates: {len(candidates)}")
                limit = None if verbose_retrieval_all else max(1, int(verbose_retrieval_limit))
                _log_retrieval_reports(phrase_reports, log, limit)
        except Exception as e:
            # Best-effort: continue without candidates, but keep the traceback server-side.
            logging.getLogger(__name__).warning("Retrieval failed; continuing without candidates", exc_info=True)
            log(f"Retrieval fallback: {type(e).__name__}: {e}")
            candidates = []

        log("Step 3: LLM index selection")
        # We pass the original 'prompt_in' as the description for the LLM to match against
        picked_indices = llm_select_indices(
            query_text=prompt_in,
            candidates=candidates,
            max_pick=0,  # NOTE(review): presumably 0 means "no cap"; confirm in llm_select_indices
            log=log,
        )
        # Indices originate from an LLM. Drop any that do not address a candidate,
        # instead of raising IndexError (or, for negatives, silently picking from the end).
        picked = list(picked_indices or [])
        valid_indices = [i for i in picked if 0 <= i < len(candidates)]
        if len(valid_indices) != len(picked):
            log(f" Ignored {len(picked) - len(valid_indices)} out-of-range selection indices")
        selected_tags = [candidates[i].tag for i in valid_indices]

        log("Step 3b: Structural tag inference (solo/duo/gender/body plan)")
        structural_tags = llm_infer_structural_tags(prompt_in, log=log)
        if structural_tags:
            # Add structural tags that aren't already selected
            existing = set(selected_tags)
            new_structural = [t for t in structural_tags if t not in existing]
            selected_tags.extend(new_structural)
            log(f" Added {len(new_structural)} structural tags: {', '.join(new_structural)}")
        else:
            log(" No structural tags inferred")

        log("Step 3c: Expand via tag implications")
        _expanded, implied_only = expand_tags_via_implications(set(selected_tags))
        if implied_only:
            implied_sorted = sorted(implied_only)
            selected_tags.extend(implied_sorted)
            log(f" Added {len(implied_only)} implied tags: {', '.join(implied_sorted)}")
        else:
            log(" No additional implied tags")

        log("Step 4: Compose final prompt")
        final_prompt = compose_final_prompt(rewritten, selected_tags)
        log("Done: final prompt ready")
        return "\n".join(logs), final_prompt
    except Exception as e:
        # Surface the failure in the UI console and, with traceback, in the server log.
        logging.getLogger(__name__).exception("rag_pipeline_ui failed")
        log(f"Error: {type(e).__name__}: {e}")
        return "\n".join(logs), ""
# Gradio UI: prompt box + mascot, a Run button, a console for pipeline logs and the final prompt.
# NOTE(review): component nesting was reconstructed from a copy that lost its
# indentation; confirm it matches the deployed layout.
with gr.Blocks(css=css) as app:
    with gr.Row():
        with gr.Column(scale=3, elem_classes=["prompt-col"]):
            image_tags = gr.Textbox(
                label="Enter Prompt",
                placeholder="e.g. fox, outside, detailed background, .",
                lines=1,
            )
        with gr.Column(scale=1):
            # The mascot is decorative. A missing or unreadable file should not
            # prevent the app from starting, so fall back to an empty image.
            try:
                # Image.open is lazy. `with` closes the file once convert() has
                # produced an in-memory copy.
                with Image.open(MASCOT_FILE) as _mascot_src:
                    _mascot_pil = _mascot_src.convert("RGBA")
            except OSError as exc:  # FileNotFoundError, PIL.UnidentifiedImageError, ...
                logging.getLogger(__name__).warning(
                    "Mascot image %s could not be loaded (%s); continuing without it",
                    MASCOT_FILE,
                    exc,
                )
                _mascot_pil = None
            mascot_img = gr.Image(
                value=_mascot_pil,
                show_label=False,
                interactive=False,
                height=220,
                elem_id="mascot",
            )
    submit_button = gr.Button("Run", variant="primary")
    gr.Markdown(
        """
### Prompt Squirrel RAG (pipeline version)
Type a rough prompt. This tool rewrites it and aligns it to an e621-style tag vocabulary using Prompt Squirrel internally,
then returns a cleaned, model-friendly prompt.
""".strip()
    )
    console = gr.Textbox(
        label="Console",
        lines=10,
        interactive=False,
        placeholder="Progress logs will appear here.",
    )
    final_prompt = gr.Textbox(
        label="Final Prompt",
        lines=3,
        interactive=False,
        placeholder="Your optimized prompt will appear here.",
    )
    # Both the button and pressing Enter in the prompt box run the same pipeline.
    submit_button.click(
        rag_pipeline_ui,
        inputs=[image_tags],
        outputs=[console, final_prompt],
    )
    image_tags.submit(
        rag_pipeline_ui,
        inputs=[image_tags],
        outputs=[console, final_prompt],
    )
if __name__ == "__main__":
    # Enable request queuing, then serve. allowed_paths lets Gradio serve files
    # from the mascot folder.
    served_app = app.queue()
    served_app.launch(allowed_paths=[str(MASCOT_DIR)])