from __future__ import annotations import json import logging import math import os import pathlib import re from collections import Counter, OrderedDict from dataclasses import dataclass from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union import numpy as np import joblib from scipy.sparse import csr_matrix from .state import ( get_fasttext_model, get_tag_counts, get_hnsw_tag_index, get_nsfw_tags, get_tfidf_components, get_tfidf_tag_vectors, get_alias2tags, ) @dataclass(frozen=True) class Candidate: tag: str score_combined: float score_fasttext: Optional[float] score_context: Optional[float] count: Optional[int] sources: List[str] def _norm_tag_for_lookup(s: str) -> str: # convert "name with spaces" -> "name_with_spaces" and unescape parens return s.replace(' ', '_').replace('\\(', '(').replace('\\)', ')') def construct_pseudo_vector(pseudo_doc_terms, idf, term_to_column_index): cols, data = [], [] for term, w in pseudo_doc_terms.items(): j = term_to_column_index.get(term) if j is None: continue cols.append(j) data.append(w * idf[j]) n_cols = len(idf) indptr = [0, len(cols)] return csr_matrix((data, cols, indptr), shape=(1, n_cols), dtype=np.float32) def _hnsw_query(idx, vec: np.ndarray, k: int): """ Query a given HNSW index with a (1, D) or (D,) vector in SVD space. Returns (indices, sims) with cosine similarity scores. """ q = np.asarray(vec, dtype=np.float32).reshape(-1) q_norm = np.linalg.norm(q) if q_norm > 0: q = q / q_norm labels, dists = idx.knn_query(q, k=k) inds = labels[0] sims = 1.0 - dists[0] # cosine distance -> similarity return inds, sims def _ann_tags_topk(vec: np.ndarray, k: int): idx, n_items = get_hnsw_tag_index() if idx is None: return (np.array([], dtype=int), np.array([], dtype=float)) k = min(k, n_items if n_items else 0) return _hnsw_query(idx, vec, k) if k else (np.array([], dtype=int), np.array([], dtype=float)) def get_tfidf_reduced_similar_tags(pseudo_doc_terms, allow_nsfw_tags): tf_idf_components = get_tfidf_components() idf = tf_idf_components["idf"] term_to_column_index = tf_idf_components["tag_to_column_index"] row_to_tag = tf_idf_components["row_to_tag"] svd = tf_idf_components["svd_model"] # 1) Build the pseudo TF-IDF, reduce to SVD space (unchanged) pseudo_tfidf_vector = construct_pseudo_vector(pseudo_doc_terms, idf, term_to_column_index) reduced_pseudo_vector = svd.transform(pseudo_tfidf_vector) # shape (1, D) # 2) ANN: only fetch nearest non-artist candidates (no full-matrix cosine) K = 2000 # tune for speed/recall top_inds, top_sims = _ann_tags_topk(reduced_pseudo_vector, k=K) # 3) Build similarity dict from those candidates tag_similarity_dict = {} for i, sim in zip(top_inds, top_sims): tag = row_to_tag.get(int(i)) if tag is not None: tag_similarity_dict[tag] = float(sim) if not allow_nsfw_tags: nsfw_tags = get_nsfw_tags() tag_similarity_dict = {t: s for t, s in tag_similarity_dict.items() if t not in nsfw_tags} # 4) Sort & escape like before sorted_tag_similarity_dict = OrderedDict(sorted(tag_similarity_dict.items(), key=lambda x: x[1], reverse=True)) transformed_sorted_tag_similarity_dict = OrderedDict( (key.replace('_', ' ').replace('(', '\\(').replace(')', '\\)'), val) for key, val in sorted_tag_similarity_dict.items() ) return transformed_sorted_tag_similarity_dict def psq_candidates_from_rewrite_phrases( rewrite_phrases: Sequence[str], *, allow_nsfw_tags: bool, context_tags: Optional[Sequence[str]] = None, context_tag_weight: float = 1.0, context_weight: float = 0.5, per_phrase_k: int = 50, per_phrase_final_k: int = 1, global_k: int = 300, min_tag_count: int = 0, return_phrase_ranks: bool = False, verbose: bool = False, ) -> Union[List[Candidate], Tuple[List[Candidate], List[Dict[str, Any]]]]: head_stopwords = { "and", "or", "the", "a", "an", "of", "to", "in", "on", "at", "with", "for", "from", "by", "as", "is", "are", "was", "were", "be", "been", "being", "down", "up", "over", "under", } def _normalize_phrase(phrase: str) -> str: lowered = (phrase or "").lower().strip().replace("_", " ") return " ".join(lowered.split()) norm_phrases = [_normalize_phrase(p) for p in rewrite_phrases] deduped_phrases = list(dict.fromkeys(p for p in norm_phrases if p)) if not deduped_phrases: return ([], []) if verbose else [] head_phrases: List[str] = [] for phrase in deduped_phrases: parts = phrase.split() if len(parts) >= 2: head = parts[-1] if len(head) >= 3 and head.lower() not in head_stopwords: head_phrases.append(head) final_phrases = list(dict.fromkeys(deduped_phrases + head_phrases)) fasttext_model = get_fasttext_model() skip_fasttext_for_exact_alias = os.environ.get( "PSQ_SKIP_FASTTEXT_FOR_EXACT_ALIAS", "1", ).strip().lower() in {"1", "true", "yes"} tag_counts = get_tag_counts() def _count_ok(tag: str) -> bool: if min_tag_count <= 0: return True return int(tag_counts.get(tag, 0) or 0) >= min_tag_count nsfw_tags = get_nsfw_tags() if not allow_nsfw_tags else set() alias2tags = get_alias2tags() tfidf_components = get_tfidf_components() tfidf_vocab = tfidf_components.get("tag_to_column_index", {}) idf = tfidf_components["idf"] term_to_column_index = tfidf_components["tag_to_column_index"] svd = tfidf_components["svd_model"] pseudo_doc_terms = Counter() oov_terms: List[str] = [] for phrase in final_phrases: lookup = phrase.replace(" ", "_") if lookup in term_to_column_index: pseudo_doc_terms[lookup] += 1 elif verbose: oov_terms.append(lookup) # Optional auxiliary retrieval context from reliable side-channel tags # (e.g., structural/probe outputs). These only affect the TF-IDF/SVD # context vector, not FastText neighbor generation. if context_tags and context_tag_weight > 0: for t in context_tags: if not t: continue lookup = str(t).strip().lower().replace(" ", "_") if not lookup: continue if lookup in term_to_column_index: pseudo_doc_terms[lookup] += float(context_tag_weight) elif verbose: oov_terms.append(lookup) pseudo_tfidf_vector = construct_pseudo_vector(pseudo_doc_terms, idf, term_to_column_index) reduced_query_vector = svd.transform(pseudo_tfidf_vector).reshape(-1) query_norm = np.linalg.norm(reduced_query_vector) if query_norm > 0: reduced_query_vector = reduced_query_vector / query_norm query_has_context = True else: query_has_context = False tag_vectors = get_tfidf_tag_vectors() if query_has_context else None tag_to_row_index = tag_vectors["tag_to_row_index"] if tag_vectors else {} phrase_candidate_maps: List[Tuple[str, Dict[str, float]]] = [] phrase_required_tags: Dict[str, Set[str]] = {} phrase_best_tokens: Dict[str, Dict[str, str]] = {} phrase_context_imputed: Dict[str, Dict[str, bool]] = {} phrase_reports: List[Dict[str, Any]] = [] phrase_rank_by_tag: Dict[str, int] = {} for phrase in final_phrases: lookup = phrase.replace(" ", "_") def _project_to_canonicals(token: str) -> List[str]: if (token in tag_counts or token in tag_to_row_index) and _count_ok(token): return [token] if token in alias2tags: return [t for t in alias2tags[token] if _count_ok(t)] return [] projected_lookup = _project_to_canonicals(lookup) if skip_fasttext_for_exact_alias and projected_lookup: neighbors = [] else: try: neighbors = fasttext_model.most_similar(lookup, topn=per_phrase_k) except KeyError: neighbors = [] per_phrase_candidates: Dict[str, float] = {} per_phrase_best_token: Dict[str, str] = {} for token, sim in neighbors: for canonical_tag in _project_to_canonicals(token): if not allow_nsfw_tags and canonical_tag in nsfw_tags: continue if not _count_ok(canonical_tag): continue prev = per_phrase_candidates.get(canonical_tag) if prev is None or sim > prev: per_phrase_candidates[canonical_tag] = float(sim) per_phrase_best_token[canonical_tag] = token required_tags = set(projected_lookup) if not allow_nsfw_tags: required_tags = {tag for tag in required_tags if tag not in nsfw_tags} for canonical_tag in projected_lookup: if not allow_nsfw_tags and canonical_tag in nsfw_tags: continue if not _count_ok(canonical_tag): continue prev = per_phrase_candidates.get(canonical_tag) if prev is None or 1.0 > prev: per_phrase_candidates[canonical_tag] = 1.0 per_phrase_best_token[canonical_tag] = lookup phrase_candidate_maps.append((phrase, per_phrase_candidates)) phrase_required_tags[phrase] = required_tags phrase_best_tokens[phrase] = per_phrase_best_token if verbose: in_vocab = bool(tfidf_vocab and lookup in tfidf_vocab) rows = [] for canonical_tag, sim in sorted(per_phrase_candidates.items(), key=lambda x: x[1], reverse=True): if not allow_nsfw_tags and canonical_tag in nsfw_tags: continue alias_token = per_phrase_best_token.get(canonical_tag, canonical_tag) rows.append( { "tag": canonical_tag, "alias_token": alias_token, "score_fasttext": float(sim), "score_context": None, "score_combined": float(sim), "context_imputed": False, "count": tag_counts.get(canonical_tag), } ) phrase_reports.append( { "phrase": phrase, "normalized": phrase, "lookup": lookup, "tfidf_vocab": in_vocab, "oov_terms": oov_terms, "candidates": rows, } ) all_candidate_tags: Set[str] = set() for _, per_phrase_candidates in phrase_candidate_maps: all_candidate_tags.update(per_phrase_candidates.keys()) score_context_by_tag: Dict[str, Optional[float]] = {} if query_has_context: reduced_matrix_norm = tag_vectors["reduced_matrix_norm"] for tag in all_candidate_tags: row = tag_to_row_index.get(tag) if row is None: score_context_by_tag[tag] = None continue score_context_by_tag[tag] = float(np.dot(reduced_query_vector, reduced_matrix_norm[row])) else: for tag in all_candidate_tags: score_context_by_tag[tag] = None merged_by_tag: Dict[str, Candidate] = {} per_phrase_scored: Dict[str, List[Tuple[str, float, Optional[float], float]]] = {} for phrase, per_phrase_candidates in phrase_candidate_maps: context_imputed_by_tag: Dict[str, bool] = {} default_context_for_phrase = None if query_has_context: context_scores = [ score_context_by_tag.get(tag) for tag in per_phrase_candidates.keys() ] context_scores = [score for score in context_scores if score is not None] if context_scores: context_scores.sort() index = int(math.floor(0.10 * (len(context_scores) - 1))) default_context_for_phrase = float(context_scores[index]) else: default_context_for_phrase = 0.0 scored_rows: List[Tuple[str, float, Optional[float], float]] = [] for tag, score_fasttext in per_phrase_candidates.items(): if not allow_nsfw_tags and tag in nsfw_tags: continue score_context = score_context_by_tag.get(tag) context_imputed = False if score_context is None and query_has_context: # Impute missing context with the per-phrase 10th percentile. score_context = default_context_for_phrase context_imputed = True if score_context is None: score_combined = float(score_fasttext) else: score_combined = (1.0 - context_weight) * float(score_fasttext) + context_weight * score_context scored_rows.append((tag, float(score_fasttext), score_context, score_combined)) context_imputed_by_tag[tag] = context_imputed scored_rows.sort(key=lambda x: x[3], reverse=True) required_tags = phrase_required_tags.get(phrase, set()) if required_tags: scored_by_tag = {row[0]: row for row in scored_rows} top_rows = scored_rows[:per_phrase_final_k] top_tags = {row[0] for row in top_rows} for required_tag in required_tags: if required_tag in top_tags: continue required_row = scored_by_tag.get(required_tag) if required_row is None: score_fasttext = per_phrase_candidates.get(required_tag) score_context = score_context_by_tag.get(required_tag) if score_fasttext is None: score_fasttext = 1.0 context_imputed = False if score_context is None and query_has_context: score_context = default_context_for_phrase context_imputed = True if score_context is None: score_combined = float(score_fasttext) else: score_combined = (1.0 - context_weight) * float(score_fasttext) + context_weight * score_context required_row = (required_tag, float(score_fasttext), score_context, score_combined) context_imputed_by_tag[required_tag] = context_imputed if len(top_rows) >= per_phrase_final_k: drop_index = None for idx in range(len(top_rows) - 1, -1, -1): if top_rows[idx][0] not in required_tags: drop_index = idx break if drop_index is None: drop_index = -1 top_rows.pop(drop_index) top_rows.append(required_row) top_tags.add(required_tag) # Deterministic must-include for exact phrase matches; re-sort top-N by combined score. top_rows.sort(key=lambda x: x[3], reverse=True) scored_rows = top_rows else: scored_rows = scored_rows[:per_phrase_final_k] per_phrase_scored[phrase] = scored_rows phrase_context_imputed[phrase] = context_imputed_by_tag if return_phrase_ranks: for rank, (tag, _score_fasttext, _score_context, _score_combined) in enumerate(scored_rows, start=1): prev = phrase_rank_by_tag.get(tag) if prev is None or rank < prev: phrase_rank_by_tag[tag] = rank for tag, score_fasttext, score_context, score_combined in scored_rows: existing = merged_by_tag.get(tag) if existing is None: merged_by_tag[tag] = Candidate( tag=tag, score_combined=score_combined, score_fasttext=score_fasttext, score_context=score_context, count=tag_counts.get(tag), sources=[phrase], ) else: if phrase not in existing.sources: existing.sources.append(phrase) existing_fasttext = ( existing.score_fasttext if existing.score_fasttext is not None else float("-inf") ) incoming_fasttext = score_fasttext if score_fasttext is not None else float("-inf") max_fasttext = max(existing_fasttext, incoming_fasttext) existing_context = existing.score_context if existing_context is None: max_context = score_context elif score_context is None: max_context = existing_context else: max_context = max(existing_context, score_context) max_combined = max(existing.score_combined, score_combined) merged_by_tag[tag] = Candidate( tag=tag, score_combined=max_combined, score_fasttext=max_fasttext if max_fasttext != float("-inf") else None, score_context=max_context, count=existing.count, sources=existing.sources, ) if verbose: for report in phrase_reports: phrase = report["phrase"] rows = [] for tag, score_fasttext, score_context, score_combined in per_phrase_scored.get(phrase, []): alias_token = phrase_best_tokens.get(phrase, {}).get(tag, tag) context_imputed = phrase_context_imputed.get(phrase, {}).get(tag, False) rows.append( { "tag": tag, "alias_token": alias_token, "score_fasttext": score_fasttext, "score_context": score_context, "score_combined": score_combined, "context_imputed": context_imputed, "count": tag_counts.get(tag), } ) report["candidates"] = rows merged_candidates = list(merged_by_tag.values()) merged_candidates.sort(key=lambda c: c.score_combined, reverse=True) merged_candidates = merged_candidates[:global_k] if return_phrase_ranks: if verbose: return (merged_candidates, phrase_reports, phrase_rank_by_tag) return (merged_candidates, phrase_rank_by_tag) return (merged_candidates, phrase_reports) if verbose else merged_candidates if __name__ == "__main__": print("psq_retrieval.py imports ok")