Food Desert
Add alias-based character tag filtering for Stage 3
c6be992
Raw
History Blame
10.4 kB
import gradio as gr
import os
import logging
from PIL import Image
from pathlib import Path
from typing import List
from psq_rag.pipeline.preproc import extract_user_provided_tags_upto_3_words
from psq_rag.llm.rewrite import llm_rewrite_prompt
from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases, _norm_tag_for_lookup
from psq_rag.llm.select import llm_select_indices
def _split_prompt_commas(s: str) -> List[str]:
    """Split a comma-separated prompt into trimmed, non-empty pieces.

    ``None`` and ``""`` both yield an empty list; blank pieces (e.g. from
    ``"a,,b"`` or trailing commas) are dropped.
    """
    pieces = (s or "").split(",")
    trimmed = (piece.strip() for piece in pieces)
    return [piece for piece in trimmed if piece]
def _norm_for_dedupe(tag: str) -> str:
    """Return the key used to decide whether two prompt entries are duplicates.

    The tag is lower-cased first and then passed through the retrieval
    module's ``_norm_tag_for_lookup`` so entries that normalize to the same
    lookup form compare equal.
    """
    lowered = tag.lower()
    return _norm_tag_for_lookup(lowered)
def compose_final_prompt(rewritten_prompt: str, selected_tags: List[str]) -> str:
    """Merge the rewritten prompt and the selected tags into one prompt string.

    The rewritten prompt is split on commas (blank pieces dropped), the
    selected tags are appended after it, and any entry whose normalized key
    (``_norm_for_dedupe``) has already appeared is skipped. The first
    occurrence keeps its original spelling and position; the result is joined
    with ", ".
    """
    seen_keys = set()
    merged = []
    for entry in [*_split_prompt_commas(rewritten_prompt), *selected_tags]:
        key = _norm_for_dedupe(entry)
        if key not in seen_keys:
            seen_keys.add(key)
            merged.append(entry)
    return ", ".join(merged)
# Set up logging
# Minimal prod logging: warnings+ to stderr, no file by default
import os, logging
# Resolve PSQ_LOG_LEVEL (default WARNING, case-insensitive) to a numeric level.
# A bare getattr() on the logging module would also accept non-level attributes
# such as "BASIC_FORMAT" (a str), which can make basicConfig() raise at import
# time. Those, and any unknown name, fall back to WARNING. Plain numeric levels
# (e.g. "15") are accepted too.
LOG_LEVEL = os.environ.get("PSQ_LOG_LEVEL", "WARNING").upper()
_log_level_value = getattr(logging, LOG_LEVEL, None)
if not isinstance(_log_level_value, int):
    _log_level_value = int(LOG_LEVEL) if LOG_LEVEL.isdecimal() else logging.WARNING

logging.basicConfig(
    level=_log_level_value,
    format="%(asctime)s %(levelname)s:%(message)s",
    handlers=[logging.StreamHandler()],  # stderr only; no file -> avoids huge logs on Spaces
)

# Quiet down common noisy libraries.
for _name in ("gensim", "gradio", "hnswlib", "httpx", "uvicorn"):
    logging.getLogger(_name).setLevel(logging.ERROR)

# Turn off Gradio's analytics phone-home to avoid its background-thread errors.
os.environ["GRADIO_ANALYTICS_ENABLED"] = "0"
# Mascot artwork shown next to the prompt box. The directory is also handed to
# launch(allowed_paths=...) when the app is run as a script.
MASCOT_DIR = Path(__file__).parent.joinpath("mascotimages")
MASCOT_FILE = MASCOT_DIR.joinpath("transparentsquirrel.png")
# Best-effort monkeypatch of gradio_client's JSON-Schema -> Python-type helpers.
# JSON Schema allows a bare boolean (or missing) schema where an object is
# expected (e.g. "additionalProperties": true); the helpers assume a dict and
# fail on such input. The patched versions return "any" for non-dict schemas
# and otherwise defer to the originals. If gradio_client's internals differ,
# the patch is skipped with a warning instead of blocking startup.
try:
    from gradio_client import utils as _gc_utils

    _orig_get_type = _gc_utils.get_type
    _orig_j2p = _gc_utils._json_schema_to_python_type
    _orig_pub = _gc_utils.json_schema_to_python_type

    def _get_type_safe(schema):
        """Return "any" for non-dict schemas; otherwise defer to get_type."""
        if not isinstance(schema, dict):
            return "any"
        return _orig_get_type(schema)

    def _j2p_safe(schema, defs=None):
        """Return "any" for non-dict schemas; otherwise defer to the private helper."""
        if not isinstance(schema, dict):
            return "any"
        return _orig_j2p(schema, defs or schema.get("$defs"))

    def _pub_safe(schema):
        """Return "any" for non-dict schemas; otherwise call the original public wrapper.

        Delegating (rather than calling _j2p_safe directly) keeps whatever
        post-processing the original wrapper applies to its result. The
        wrapper resolves the private helpers through module globals at call
        time, so nested schemas still go through the patched versions.
        """
        if not isinstance(schema, dict):
            return "any"
        return _orig_pub(schema)

    _gc_utils.get_type = _get_type_safe
    _gc_utils._json_schema_to_python_type = _j2p_safe
    _gc_utils.json_schema_to_python_type = _pub_safe
except Exception as e:
    logging.getLogger(__name__).warning("gradio_client hotfix not applied: %s", e)
# --- Pipeline settings --------------------------------------------------------
# Forwarded to retrieval as allow_nsfw_tags.
allow_nsfw_tags: bool = False
# Forwarded to retrieval as verbose=, and enables the per-phrase candidate dump
# in the UI console log.
verbose_retrieval: bool = True
# When True, the verbose dump lists every candidate row for each phrase.
verbose_retrieval_all: bool = False
# Otherwise, at most this many rows per phrase are shown (clamped to >= 1).
verbose_retrieval_limit: int = 20
# Custom CSS passed to gr.Blocks(css=...). It gives .scrollable-content an
# always-present vertical scrollbar with custom colours (Firefox and WebKit
# variants) and raises its max-height to 610px when nested in .pane-left or
# .pane-right.
# NOTE(review): no component in this file appears to use the scrollable-content,
# pane-left or pane-right classes (the layout only sets the "prompt-col" class
# and the "mascot" id) - confirm these rules are still needed.
css = """
.scrollable-content{
max-height: 420px;
overflow-y: scroll; /* always show scrollbar */
overflow-x: hidden;
padding-right: 8px;
padding-bottom: 14px; /* <— add this */
scrollbar-gutter: stable; /* prevent layout shift as it fills */
/* Firefox */
scrollbar-width: auto;
scrollbar-color: rgba(180,180,180,.9) rgba(0,0,0,.15);
}
/* WebKit/Chromium (Chrome/Edge/Safari) */
.scrollable-content::-webkit-scrollbar{ width: 10px; }
.scrollable-content::-webkit-scrollbar-thumb{ background: rgba(180,180,180,.9); border-radius: 8px; }
.scrollable-content::-webkit-scrollbar-track{ background: rgba(0,0,0,.15); }
/* (Optional) make both scroll panes taller so they fill more of the column */
.pane-left .scrollable-content,
.pane-right .scrollable-content {
max-height: 610px; /* was 420px; tweak to taste */
}
"""
def _fmt_score(value):
    """Render a numeric score with 3 decimals; return anything else (e.g. None) unchanged."""
    return f"{value:.3f}" if isinstance(value, (int, float)) else value


def _log_retrieval_details(candidates, phrase_reports, log):
    """Write the verbose per-phrase retrieval dump to ``log``.

    Reports and candidate rows are read with ``.get`` (report keys:
    normalized/phrase, lookup, tfidf_vocab, candidates; row keys: tag,
    alias_token, score_fasttext, score_context, score_combined, count).
    Rows per phrase are capped at ``verbose_retrieval_limit`` (at least 1)
    unless ``verbose_retrieval_all`` is set.
    """
    log(f"Total unique candidates: {len(candidates)}")
    limit = None if verbose_retrieval_all else max(1, int(verbose_retrieval_limit))
    for report in phrase_reports:
        phrase = report.get("normalized") or report.get("phrase") or ""
        lookup = report.get("lookup") or ""
        tfidf_vocab = report.get("tfidf_vocab")
        log(f"Phrase: {phrase} (lookup={lookup}) tfidf_vocab={tfidf_vocab}")
        rows = report.get("candidates", [])
        shown = rows if limit is None else rows[:limit]
        for row in shown:
            tag = row.get("tag")
            alias_token = row.get("alias_token")
            alias_part = (
                f" [alias_token={alias_token}]" if alias_token and alias_token != tag else ""
            )
            log(
                f" {tag}{alias_part} | fasttext={_fmt_score(row.get('score_fasttext'))} "
                f"context={_fmt_score(row.get('score_context'))} "
                f"combined={_fmt_score(row.get('score_combined'))} count={row.get('count')}"
            )
        if limit is not None and len(rows) > limit:
            log(f" ... ({len(rows) - limit} more)")


def rag_pipeline_ui(user_prompt: str):
    """Run the rewrite -> retrieval -> selection pipeline for the Gradio UI.

    Returns a ``(console_text, final_prompt)`` pair of strings. Errors never
    propagate: they are appended to the console text (and logged with a
    traceback server-side) and ``final_prompt`` is then ``""``. An empty or
    whitespace-only prompt returns ``("Error: empty prompt", "")``.
    """
    logs = []

    def log(s):
        logs.append(s)

    try:
        log("Start: received prompt")
        prompt_in = (user_prompt or "").strip()
        if not prompt_in:
            return "Error: empty prompt", ""
        log("Input:")
        log(prompt_in)
        log("")

        user_tags = extract_user_provided_tags_upto_3_words(prompt_in)
        log("Heuristically extracted user tags:")
        log(", ".join(user_tags) if user_tags else "(none)")
        log("")

        log("Step 1: LLM rewrite")
        # Normalize a missing rewrite to "" so the string handling below
        # cannot fail on None.
        rewritten = llm_rewrite_prompt(prompt_in, log) or ""
        log("Rewrite:")
        log(rewritten or "(empty)")
        log("")

        # User tags are logged separately above, but also help retrieval.
        rewrite_phrases = _split_prompt_commas(", ".join([rewritten, *(user_tags or [])]))

        log("Step 2: Prompt Squirrel retrieval (hidden)")
        try:
            retrieval_result = psq_candidates_from_rewrite_phrases(
                rewrite_phrases=rewrite_phrases,
                allow_nsfw_tags=allow_nsfw_tags,
                global_k=300,
                verbose=verbose_retrieval,
            )
            if isinstance(retrieval_result, tuple):
                candidates, phrase_reports = retrieval_result
            else:
                candidates, phrase_reports = retrieval_result, []
            log(f"Retrieved {len(candidates)} candidate tags")
        except Exception as e:
            log(f"Retrieval fallback: {type(e).__name__}: {e}")
            candidates, phrase_reports = [], []
        else:
            if verbose_retrieval:
                # Diagnostics only: a malformed report must not throw away
                # candidates that were retrieved successfully.
                try:
                    _log_retrieval_details(candidates, phrase_reports, log)
                except Exception as e:
                    log(f"Retrieval report logging failed: {type(e).__name__}: {e}")

        log("Step 3: LLM index selection")
        if candidates:
            # The original prompt is what the LLM matches candidates against.
            picked_indices = llm_select_indices(
                query_text=prompt_in,
                candidates=candidates,
                max_pick=0,
                log=log,
            )
        else:
            log("No candidates; skipping LLM selection")
            picked_indices = []
        # Indices come from an LLM: drop anything out of range (negative
        # values would otherwise silently index from the end).
        picked = list(picked_indices or [])
        n_candidates = len(candidates)
        valid = [i for i in picked if 0 <= i < n_candidates]
        if len(valid) != len(picked):
            log(f"Ignored {len(picked) - len(valid)} out-of-range selection index(es)")
        selected_tags = [candidates[i].tag for i in valid]

        log("Step 4: Compose final prompt")
        final_prompt = compose_final_prompt(rewritten, selected_tags)
        log("Done: final prompt ready")
        return "\n".join(logs), final_prompt
    except Exception as e:
        logging.getLogger(__name__).exception("rag_pipeline_ui failed")
        log(f"Error: {type(e).__name__}: {e}")
        return "\n".join(logs), ""
# Gradio layout: prompt box next to the mascot and Run button, followed by a
# short description, the console log and the final prompt. Clicking Run or
# pressing Enter in the prompt box both call rag_pipeline_ui.
with gr.Blocks(css=css) as app:
    with gr.Row():
        with gr.Column(scale=3, elem_classes=["prompt-col"]):
            image_tags = gr.Textbox(
                label="Enter Prompt",
                placeholder="e.g. fox, outside, detailed background, .",
                lines=1,
            )
        with gr.Column(scale=1):
            # Load inside a context manager so the file handle is released
            # deterministically instead of relying on PIL's lazy close.
            with Image.open(MASCOT_FILE) as _mascot_src:
                _mascot_pil = _mascot_src.convert("RGBA")
            mascot_img = gr.Image(
                value=_mascot_pil,
                show_label=False,
                interactive=False,
                height=220,
                elem_id="mascot",
            )
            # NOTE(review): original indentation was lost; placing the Run
            # button under the mascot is inferred - confirm intended layout.
            submit_button = gr.Button("Run", variant="primary")
    gr.Markdown(
        """
### Prompt Squirrel RAG (pipeline version)
Type a rough prompt. This tool rewrites it and aligns it to an e621-style tag vocabulary using Prompt Squirrel internally,
then returns a cleaned, model-friendly prompt.
""".strip()
    )
    console = gr.Textbox(
        label="Console",
        lines=10,
        interactive=False,
        placeholder="Progress logs will appear here.",
    )
    final_prompt = gr.Textbox(
        label="Final Prompt",
        lines=3,
        interactive=False,
        placeholder="Your optimized prompt will appear here.",
    )
    submit_button.click(
        rag_pipeline_ui,
        inputs=[image_tags],
        outputs=[console, final_prompt],
    )
    image_tags.submit(
        rag_pipeline_ui,
        inputs=[image_tags],
        outputs=[console, final_prompt],
    )
if __name__ == "__main__":
    # Enable Gradio's request queue, then serve. allowed_paths whitelists the
    # mascot directory for Gradio's file serving.
    queued_app = app.queue()
    queued_app.launch(allowed_paths=[str(MASCOT_DIR)])