from __future__ import annotations import json import logging import math import os import pathlib import re from collections import Counter, OrderedDict from dataclasses import dataclass from itertools import islice from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union import numpy as np import joblib from scipy.sparse import csr_matrix from .state import ( get_fasttext_model, get_tag_counts, get_hnsw_artist_index, get_hnsw_tag_index, get_nsfw_tags, get_tfidf_components, get_tfidf_tag_vectors, get_alias2tags, ) @dataclass(frozen=True) class Candidate: tag: str score_combined: float score_fasttext: Optional[float] score_context: Optional[float] count: Optional[int] sources: List[str] def _norm_tag_for_lookup(s: str) -> str: # convert "name with spaces" -> "name_with_spaces" and unescape parens return s.replace(' ', '_').replace('\\(', '(').replace('\\)', ')') special_tags = ["score:0", "score:1", "score:2", "score:3", "score:4", "score:5", "score:6", "score:7", "score:8", "score:9", "rating:s", "rating:q", "rating:e"] def remove_special_tags(original_string): tags = [tag.strip() for tag in original_string.split(",")] remaining_tags = [tag for tag in tags if tag not in special_tags] removed_tags = [tag for tag in tags if tag in special_tags] return ", ".join(remaining_tags), removed_tags def construct_pseudo_vector(pseudo_doc_terms, idf, term_to_column_index): cols, data = [], [] for term, w in pseudo_doc_terms.items(): j = term_to_column_index.get(term) if j is None: continue cols.append(j) data.append(w * idf[j]) n_cols = len(idf) indptr = [0, len(cols)] return csr_matrix((data, cols, indptr), shape=(1, n_cols), dtype=np.float32) def _ensure_dual_hnsw_indexes(): """ Build/load two HNSW indexes over the SVD-reduced TF-IDF matrix. """ get_hnsw_tag_index() get_hnsw_artist_index() return def _hnsw_query(idx, vec: np.ndarray, k: int): """ Query a given HNSW index with a (1, D) or (D,) vector in SVD space. Returns (indices, sims) with cosine similarity scores. """ q = np.asarray(vec, dtype=np.float32).reshape(-1) q_norm = np.linalg.norm(q) if q_norm > 0: q = q / q_norm labels, dists = idx.knn_query(q, k=k) inds = labels[0] sims = 1.0 - dists[0] # cosine distance -> similarity return inds, sims def _ann_tags_topk(vec: np.ndarray, k: int): idx, n_items = get_hnsw_tag_index() if idx is None: return (np.array([], dtype=int), np.array([], dtype=float)) k = min(k, n_items if n_items else 0) return _hnsw_query(idx, vec, k) if k else (np.array([], dtype=int), np.array([], dtype=float)) def _ann_artists_topk(vec: np.ndarray, k: int): idx, n_items = get_hnsw_artist_index() if idx is None: return (np.array([], dtype=int), np.array([], dtype=float)) k = min(k, n_items if n_items else 0) return _hnsw_query(idx, vec, k) if k else (np.array([], dtype=int), np.array([], dtype=float)) def get_tfidf_reduced_similar_tags(pseudo_doc_terms, allow_nsfw_tags): tf_idf_components = get_tfidf_components() idf = tf_idf_components["idf"] term_to_column_index = tf_idf_components["tag_to_column_index"] row_to_tag = tf_idf_components["row_to_tag"] svd = tf_idf_components["svd_model"] # 1) Build the pseudo TF-IDF, reduce to SVD space (unchanged) pseudo_tfidf_vector = construct_pseudo_vector(pseudo_doc_terms, idf, term_to_column_index) reduced_pseudo_vector = svd.transform(pseudo_tfidf_vector) # shape (1, D) # 2) ANN: only fetch nearest non-artist candidates (no full-matrix cosine) K = 2000 # tune for speed/recall top_inds, top_sims = _ann_tags_topk(reduced_pseudo_vector, k=K) # 3) Build similarity dict from those candidates tag_similarity_dict = {} for i, sim in zip(top_inds, top_sims): tag = row_to_tag.get(int(i)) if tag is not None: tag_similarity_dict[tag] = float(sim) if not allow_nsfw_tags: nsfw_tags = get_nsfw_tags() tag_similarity_dict = {t: s for t, s in tag_similarity_dict.items() if t not in nsfw_tags} # 4) Sort & escape like before sorted_tag_similarity_dict = OrderedDict(sorted(tag_similarity_dict.items(), key=lambda x: x[1], reverse=True)) transformed_sorted_tag_similarity_dict = OrderedDict( (key.replace('_', ' ').replace('(', '\\(').replace(')', '\\)'), val) for key, val in sorted_tag_similarity_dict.items() ) return transformed_sorted_tag_similarity_dict def psq_candidates_from_terms(terms: Sequence[str], *, allow_nsfw_tags: bool, k: int = 300): cand_dict = get_tfidf_reduced_similar_tags(dict(Counter(terms)), allow_nsfw_tags) candidates = list(islice(cand_dict.items(), k)) tag_counts = get_tag_counts() return [ Candidate( tag=tag, score_combined=float(score), score_fasttext=None, score_context=None, count=tag_counts.get(tag), sources=[], ) for tag, score in candidates ] def psq_candidates_from_rewrite_phrases( rewrite_phrases: Sequence[str], *, allow_nsfw_tags: bool, context_weight: float = 0.5, per_phrase_k: int = 50, per_phrase_final_k: int = 10, global_k: int = 300, verbose: bool = False, ) -> Union[List[Candidate], Tuple[List[Candidate], List[Dict[str, Any]]]]: head_stopwords = { "and", "or", "the", "a", "an", "of", "to", "in", "on", "at", "with", "for", "from", "by", "as", "is", "are", "was", "were", "be", "been", "being", "down", "up", "over", "under", } def _normalize_phrase(phrase: str) -> str: lowered = (phrase or "").lower().strip().replace("_", " ") return " ".join(lowered.split()) norm_phrases = [_normalize_phrase(p) for p in rewrite_phrases] deduped_phrases = list(dict.fromkeys(p for p in norm_phrases if p)) if not deduped_phrases: return ([], []) if verbose else [] head_phrases: List[str] = [] for phrase in deduped_phrases: parts = phrase.split() if len(parts) >= 2: head = parts[-1] if len(head) >= 3 and head.lower() not in head_stopwords: head_phrases.append(head) final_phrases = list(dict.fromkeys(deduped_phrases + head_phrases)) fasttext_model = get_fasttext_model() tag_counts = get_tag_counts() nsfw_tags = get_nsfw_tags() if not allow_nsfw_tags else set() alias2tags = get_alias2tags() tfidf_components = get_tfidf_components() tfidf_vocab = tfidf_components.get("tag_to_column_index", {}) idf = tfidf_components["idf"] term_to_column_index = tfidf_components["tag_to_column_index"] svd = tfidf_components["svd_model"] pseudo_doc_terms = Counter() oov_terms: List[str] = [] for phrase in final_phrases: lookup = phrase.replace(" ", "_") if lookup in term_to_column_index: pseudo_doc_terms[lookup] += 1 elif verbose: oov_terms.append(lookup) pseudo_tfidf_vector = construct_pseudo_vector(pseudo_doc_terms, idf, term_to_column_index) reduced_query_vector = svd.transform(pseudo_tfidf_vector).reshape(-1) query_norm = np.linalg.norm(reduced_query_vector) if query_norm > 0: reduced_query_vector = reduced_query_vector / query_norm query_has_context = True else: query_has_context = False tag_vectors = get_tfidf_tag_vectors() if query_has_context else None tag_to_row_index = tag_vectors["tag_to_row_index"] if tag_vectors else {} phrase_candidate_maps: List[Tuple[str, Dict[str, float]]] = [] phrase_required_tags: Dict[str, Set[str]] = {} phrase_best_tokens: Dict[str, Dict[str, str]] = {} phrase_context_imputed: Dict[str, Dict[str, bool]] = {} phrase_reports: List[Dict[str, Any]] = [] for phrase in final_phrases: lookup = phrase.replace(" ", "_") def _project_to_canonicals(token: str) -> List[str]: if token in tag_counts or token in tag_to_row_index: return [token] if token in alias2tags: return alias2tags[token] return [] try: neighbors = fasttext_model.most_similar(lookup, topn=per_phrase_k) except KeyError: neighbors = [] per_phrase_candidates: Dict[str, float] = {} per_phrase_best_token: Dict[str, str] = {} for token, sim in neighbors: for canonical_tag in _project_to_canonicals(token): if not allow_nsfw_tags and canonical_tag in nsfw_tags: continue prev = per_phrase_candidates.get(canonical_tag) if prev is None or sim > prev: per_phrase_candidates[canonical_tag] = float(sim) per_phrase_best_token[canonical_tag] = token projected_lookup = _project_to_canonicals(lookup) required_tags = set(projected_lookup) if not allow_nsfw_tags: required_tags = {tag for tag in required_tags if tag not in nsfw_tags} for canonical_tag in projected_lookup: if not allow_nsfw_tags and canonical_tag in nsfw_tags: continue prev = per_phrase_candidates.get(canonical_tag) if prev is None or 1.0 > prev: per_phrase_candidates[canonical_tag] = 1.0 per_phrase_best_token[canonical_tag] = lookup phrase_candidate_maps.append((phrase, per_phrase_candidates)) phrase_required_tags[phrase] = required_tags phrase_best_tokens[phrase] = per_phrase_best_token if verbose: in_vocab = bool(tfidf_vocab and lookup in tfidf_vocab) rows = [] for canonical_tag, sim in sorted(per_phrase_candidates.items(), key=lambda x: x[1], reverse=True): if not allow_nsfw_tags and canonical_tag in nsfw_tags: continue alias_token = per_phrase_best_token.get(canonical_tag, canonical_tag) rows.append( { "tag": canonical_tag, "alias_token": alias_token, "score_fasttext": float(sim), "score_context": None, "score_combined": float(sim), "context_imputed": False, "count": tag_counts.get(canonical_tag), } ) phrase_reports.append( { "phrase": phrase, "normalized": phrase, "lookup": lookup, "tfidf_vocab": in_vocab, "oov_terms": oov_terms, "candidates": rows, } ) all_candidate_tags: Set[str] = set() for _, per_phrase_candidates in phrase_candidate_maps: all_candidate_tags.update(per_phrase_candidates.keys()) score_context_by_tag: Dict[str, Optional[float]] = {} if query_has_context: reduced_matrix_norm = tag_vectors["reduced_matrix_norm"] for tag in all_candidate_tags: row = tag_to_row_index.get(tag) if row is None: score_context_by_tag[tag] = None continue score_context_by_tag[tag] = float(np.dot(reduced_query_vector, reduced_matrix_norm[row])) else: for tag in all_candidate_tags: score_context_by_tag[tag] = None merged_by_tag: Dict[str, Candidate] = {} per_phrase_scored: Dict[str, List[Tuple[str, float, Optional[float], float]]] = {} for phrase, per_phrase_candidates in phrase_candidate_maps: context_imputed_by_tag: Dict[str, bool] = {} default_context_for_phrase = None if query_has_context: context_scores = [ score_context_by_tag.get(tag) for tag in per_phrase_candidates.keys() ] context_scores = [score for score in context_scores if score is not None] if context_scores: context_scores.sort() index = int(math.floor(0.10 * (len(context_scores) - 1))) default_context_for_phrase = float(context_scores[index]) else: default_context_for_phrase = 0.0 scored_rows: List[Tuple[str, float, Optional[float], float]] = [] for tag, score_fasttext in per_phrase_candidates.items(): if not allow_nsfw_tags and tag in nsfw_tags: continue score_context = score_context_by_tag.get(tag) context_imputed = False if score_context is None and query_has_context: # Impute missing context with the per-phrase 10th percentile. score_context = default_context_for_phrase context_imputed = True if score_context is None: score_combined = float(score_fasttext) else: score_combined = (1.0 - context_weight) * float(score_fasttext) + context_weight * score_context scored_rows.append((tag, float(score_fasttext), score_context, score_combined)) context_imputed_by_tag[tag] = context_imputed scored_rows.sort(key=lambda x: x[3], reverse=True) required_tags = phrase_required_tags.get(phrase, set()) if required_tags: scored_by_tag = {row[0]: row for row in scored_rows} top_rows = scored_rows[:per_phrase_final_k] top_tags = {row[0] for row in top_rows} for required_tag in required_tags: if required_tag in top_tags: continue required_row = scored_by_tag.get(required_tag) if required_row is None: score_fasttext = per_phrase_candidates.get(required_tag) score_context = score_context_by_tag.get(required_tag) if score_fasttext is None: score_fasttext = 1.0 context_imputed = False if score_context is None and query_has_context: score_context = default_context_for_phrase context_imputed = True if score_context is None: score_combined = float(score_fasttext) else: score_combined = (1.0 - context_weight) * float(score_fasttext) + context_weight * score_context required_row = (required_tag, float(score_fasttext), score_context, score_combined) context_imputed_by_tag[required_tag] = context_imputed if len(top_rows) >= per_phrase_final_k: drop_index = None for idx in range(len(top_rows) - 1, -1, -1): if top_rows[idx][0] not in required_tags: drop_index = idx break if drop_index is None: drop_index = -1 top_rows.pop(drop_index) top_rows.append(required_row) top_tags.add(required_tag) # Deterministic must-include for exact phrase matches; re-sort top-N by combined score. top_rows.sort(key=lambda x: x[3], reverse=True) scored_rows = top_rows else: scored_rows = scored_rows[:per_phrase_final_k] per_phrase_scored[phrase] = scored_rows phrase_context_imputed[phrase] = context_imputed_by_tag for tag, score_fasttext, score_context, score_combined in scored_rows: existing = merged_by_tag.get(tag) if existing is None: merged_by_tag[tag] = Candidate( tag=tag, score_combined=score_combined, score_fasttext=score_fasttext, score_context=score_context, count=tag_counts.get(tag), sources=[phrase], ) else: if phrase not in existing.sources: existing.sources.append(phrase) existing_fasttext = ( existing.score_fasttext if existing.score_fasttext is not None else float("-inf") ) incoming_fasttext = score_fasttext if score_fasttext is not None else float("-inf") max_fasttext = max(existing_fasttext, incoming_fasttext) existing_context = existing.score_context if existing_context is None: max_context = score_context elif score_context is None: max_context = existing_context else: max_context = max(existing_context, score_context) max_combined = max(existing.score_combined, score_combined) merged_by_tag[tag] = Candidate( tag=tag, score_combined=max_combined, score_fasttext=max_fasttext if max_fasttext != float("-inf") else None, score_context=max_context, count=existing.count, sources=existing.sources, ) if verbose: for report in phrase_reports: phrase = report["phrase"] rows = [] for tag, score_fasttext, score_context, score_combined in per_phrase_scored.get(phrase, []): alias_token = phrase_best_tokens.get(phrase, {}).get(tag, tag) context_imputed = phrase_context_imputed.get(phrase, {}).get(tag, False) rows.append( { "tag": tag, "alias_token": alias_token, "score_fasttext": score_fasttext, "score_context": score_context, "score_combined": score_combined, "context_imputed": context_imputed, "count": tag_counts.get(tag), } ) report["candidates"] = rows merged_candidates = list(merged_by_tag.values()) merged_candidates.sort(key=lambda c: c.score_combined, reverse=True) merged_candidates = merged_candidates[:global_k] return (merged_candidates, phrase_reports) if verbose else merged_candidates def psq_candidates_from_prompt(prompt: str, *, allow_nsfw_tags: bool, k: int = 300): """Return Stage 2 candidates from a raw prompt.""" from ..parsing.prompt_grammar import build_tag_offsets_dicts, extract_tags, parser p = (prompt or "").lower() p, removed_special = remove_special_tags(p) parsed = parser.parse(p) tags_with_pos = extract_tags(parsed) tag_data = build_tag_offsets_dicts(tags_with_pos) # These are TF-IDF terms as your pipeline already expects terms = [item["tf_idf_matrix_tag"] for item in tag_data] + removed_special return psq_candidates_from_terms(terms, allow_nsfw_tags=allow_nsfw_tags, k=k) if __name__ == "__main__": print("psq_retrieval.py imports ok")