from __future__ import annotations

import csv
import logging
import pathlib
from typing import Any, Dict, List, Optional, Set, Tuple

import numpy as np

try:
    import joblib
except ImportError:
    # Same policy as hnswlib below: the CSV-only helpers in this module do not
    # need joblib, so a missing install must not break importing the module.
    joblib = None

try:
    import hnswlib
except Exception:
    hnswlib = None  # allow import on environments without hnswlib during partial tests

TFIDF_PATH = pathlib.Path("tf_idf_files_420.joblib")
NSFW_CSV_PATH = pathlib.Path("word_rating_probabilities.csv")
NSFW_THRESHOLD = 0.95
HNSW_ART_PATH = pathlib.Path("tfidf_hnsw_artists.bin")
HNSW_TAG_PATH = pathlib.Path("tfidf_hnsw_tags.bin")
FASTTEXT_MODEL_PATH = pathlib.Path("e621FastTextModel010Replacement_small.bin")
TAG_ALIASES_PATH = pathlib.Path("fluffyrock_3m.csv")

# Process-wide lazy caches; each getter below fills its cache on first use.
_tfidf_components: Optional[Dict[str, Any]] = None
_nsfw_tags: Optional[Set[str]] = None
_artist_set: Optional[Set[str]] = None
_fasttext_model: Optional[Any] = None
_tag_counts: Optional[Dict[str, int]] = None
_tfidf_tag_vectors: Optional[Dict[str, Any]] = None
_alias_to_tags: Optional[Dict[str, List[str]]] = None
_tag_to_aliases: Optional[Dict[str, List[str]]] = None
_tag_type_id: Optional[Dict[str, int]] = None
_hnsw_tag_index: Optional["hnswlib.Index"] = None
_hnsw_artist_index: Optional["hnswlib.Index"] = None
_hnsw_tag_count: int = 0
_hnsw_artist_count: int = 0

# Tag type names inferred from e621 wiki documentation.
# Numeric IDs come from fluffyrock_3m.csv column 1; mapping is heuristic but
# matches observed usage on e621.
TAG_TYPE_ID_TO_NAME: Dict[int, str] = {
    0: "general",      # Default tag type: visible attributes, actions, objects, etc.
    1: "artist",       # Artist tags (e.g. by_name, artist_name)
    2: "contributor",  # Contributor tags (rare / possibly unused in this dataset)
    3: "copyright",    # Series, franchise, or IP (e.g. pokemon, winnie_the_pooh)
    4: "character",    # Named characters (e.g. pikachu, pinkie_pie_(mlp))
    5: "species",      # Species tags (e.g. canine, domestic_cat)
    6: "invalid",      # Invalid / disallowed / disambiguation-only tags
    7: "meta",         # Meta / presentation / file / style-related tags
}


def _l2_normalize_rows(mat: np.ndarray) -> np.ndarray:
    """Return *mat* as float32 with every row scaled to unit L2 norm.

    The input must be 2-D (norms are taken along axis 1). Rows whose norm is
    zero are divided by 1.0 instead, so they stay all-zero rather than
    becoming NaN.
    """
    mat = np.asarray(mat, dtype=np.float32)
    norms = np.linalg.norm(mat, axis=1, keepdims=True)
    norms[norms == 0.0] = 1.0
    return mat / norms


def _clean_tag_ascii(tag: str) -> str:
    """Drop every character of *tag* whose code point is 128 or above."""
    return "".join(char for char in tag if ord(char) < 128)


def clean_tag(tag: str) -> str:
    """Normalize tags consistently with legacy alias parsing."""
    return _clean_tag_ascii(tag)


def build_aliases_dict(csv_path: str, reverse: bool = False) -> Dict[str, List[str]]:
    """Build tag/alias mappings from the aliases CSV.

    Column 0 is read as the tag and column 3 as a comma-separated alias list,
    where the literal string ``"null"`` means "no aliases". Both tag and
    aliases are passed through :func:`clean_tag`.

    Args:
        csv_path: Path of the CSV file to read (UTF-8).
        reverse: When false, return ``tag -> [aliases]``. When true, return
            ``alias -> [tags]``; tags without aliases do not appear.

    Returns:
        The requested mapping. Blank rows are skipped, and a row that has no
        alias column is treated as a tag without aliases.
    """
    aliases_dict: Dict[str, List[str]] = {}
    with open(csv_path, "r", newline="", encoding="utf-8") as csvfile:
        for row in csv.reader(csvfile):
            if not row:
                # csv.reader yields [] for blank lines; there is no tag to record.
                continue
            tag = clean_tag(row[0])
            raw_aliases = row[3] if len(row) > 3 else "null"
            if raw_aliases == "null":
                alias_list: List[str] = []
            else:
                alias_list = [clean_tag(alias) for alias in raw_aliases.split(",")]
            if reverse:
                for alias in alias_list:
                    aliases_dict.setdefault(alias, []).append(tag)
            else:
                aliases_dict[tag] = alias_list
    return aliases_dict


def get_tfidf_components() -> Dict[str, Any]:
    """Load (once) and return the TF-IDF component dict stored at ``TFIDF_PATH``.

    Two fix-ups are applied to the loaded object before it is cached:

    * ``row_to_tag`` is derived by inverting ``tag_to_row_index`` when only the
      latter is present.
    * An ``idf`` stored as a ``{term: value}`` dict is converted to a float32
      array indexed by column (via ``tag_to_column_index``); terms without an
      entry get 1.0.

    Raises:
        FileNotFoundError: ``TFIDF_PATH`` does not exist.
        ImportError: joblib is not installed.
    """
    global _tfidf_components
    if _tfidf_components is not None:
        return _tfidf_components
    if not TFIDF_PATH.is_file():
        raise FileNotFoundError(f"TF-IDF joblib not found: {TFIDF_PATH}")
    if joblib is None:
        raise ImportError("joblib is required to load TF-IDF components")

    # NOTE: joblib.load unpickles its input; only point TFIDF_PATH at trusted files.
    model_components = joblib.load(TFIDF_PATH)

    if "tag_to_row_index" in model_components and "row_to_tag" not in model_components:
        model_components["row_to_tag"] = {
            idx: tag for tag, idx in model_components["tag_to_row_index"].items()
        }

    idf = model_components.get("idf")
    if isinstance(idf, dict):
        t2c = model_components["tag_to_column_index"]
        # default=-1 keeps an empty vocabulary from raising ValueError in max().
        n_cols = max(t2c.values(), default=-1) + 1
        idf_by_col = np.ones(n_cols, dtype=np.float32)
        for term, col in t2c.items():
            idf_by_col[col] = float(idf.get(term, 1.0))
        model_components["idf"] = idf_by_col

    _tfidf_components = model_components
    return model_components


def get_nsfw_tags() -> Set[str]:
    """Return (cached) the words whose column-1 value is >= ``NSFW_THRESHOLD``.

    The first row of ``NSFW_CSV_PATH`` is skipped as a header. Blank rows and
    rows whose second column is missing or not a float are ignored.

    Raises:
        FileNotFoundError: ``NSFW_CSV_PATH`` does not exist.
    """
    global _nsfw_tags
    if _nsfw_tags is not None:
        return _nsfw_tags
    if not NSFW_CSV_PATH.is_file():
        raise FileNotFoundError(f"NSFW tag CSV not found: {NSFW_CSV_PATH}")

    tags: Set[str] = set()
    with NSFW_CSV_PATH.open("r", newline="", encoding="utf-8") as csvfile:
        reader = csv.reader(csvfile)
        next(reader, None)  # skip the header row
        for row in reader:
            if not row:
                continue
            try:
                probability_sum = float(row[1])
            except (IndexError, ValueError):
                continue
            if probability_sum >= NSFW_THRESHOLD:
                tags.add(row[0])
    _nsfw_tags = tags
    return _nsfw_tags


def get_artist_set() -> Set[str]:
    """Return (cached) artist names taken from ``by_``-prefixed tags in the tag CSV.

    The ``by_`` prefix is stripped from the stored names. Unlike the other
    loaders, a missing CSV is not an error here: an empty set is cached and
    returned.
    """
    global _artist_set
    if _artist_set is not None:
        return _artist_set
    # Use the shared constant so this loader cannot drift from its siblings.
    if not TAG_ALIASES_PATH.is_file():
        _artist_set = set()
        return _artist_set

    artists: Set[str] = set()
    with TAG_ALIASES_PATH.open("r", newline="", encoding="utf-8") as csvfile:
        for row in csv.reader(csvfile):
            if not row:
                continue
            tag_name = row[0]
            if tag_name.startswith("by_"):
                artists.add(tag_name[3:])
    _artist_set = artists
    return _artist_set


def is_artist(name: str) -> bool:
    """Return True when *name* (without the ``by_`` prefix) is a known artist."""
    return name in get_artist_set()


def get_fasttext_model() -> Any:
    """Load (once) and return the compressed FastText keyed vectors.

    ``compress_fasttext`` is imported lazily so the rest of the module works
    without it.

    Raises:
        FileNotFoundError: ``FASTTEXT_MODEL_PATH`` does not exist.
    """
    global _fasttext_model
    if _fasttext_model is not None:
        return _fasttext_model
    if not FASTTEXT_MODEL_PATH.is_file():
        raise FileNotFoundError(f"FastText model not found: {FASTTEXT_MODEL_PATH}")

    import compress_fasttext

    _fasttext_model = compress_fasttext.models.CompressedFastTextKeyedVectors.load(
        str(FASTTEXT_MODEL_PATH)
    )
    return _fasttext_model


def get_tag_type_ids() -> Dict[str, int]:
    """Return canonical tag -> type_id (int) from fluffyrock_3m.csv.

    Reads row[1] as int when possible. Missing/invalid values are skipped.

    Raises:
        FileNotFoundError: ``TAG_ALIASES_PATH`` does not exist.
    """
    global _tag_type_id
    if _tag_type_id is not None:
        return _tag_type_id
    if not TAG_ALIASES_PATH.is_file():
        raise FileNotFoundError(f"Tag CSV not found: {TAG_ALIASES_PATH}")

    m: Dict[str, int] = {}
    with TAG_ALIASES_PATH.open("r", newline="", encoding="utf-8") as csvfile:
        for row in csv.reader(csvfile):
            if len(row) < 2:
                continue
            try:
                type_id = int(row[1])
            except ValueError:
                continue
            m[clean_tag(row[0])] = type_id
    _tag_type_id = m
    return _tag_type_id


def get_tag_type_name(tag: str) -> Optional[str]:
    """Return heuristic type name for a tag (e.g. 'artist', 'character'), or None.

    IDs that are not in ``TAG_TYPE_ID_TO_NAME`` are rendered as ``type_<id>``.
    """
    tid = get_tag_type_ids().get(clean_tag(tag))
    if tid is None:
        return None
    return TAG_TYPE_ID_TO_NAME.get(tid, f"type_{tid}")


def get_tag_counts() -> Dict[str, int]:
    """Return (cached) tag -> count read from column 2 of the tag CSV.

    Keys are the raw column-0 strings (not passed through ``clean_tag``).
    Rows that are blank, have no count column, or whose count is not made of
    digits are skipped.

    Raises:
        FileNotFoundError: ``TAG_ALIASES_PATH`` does not exist.
    """
    global _tag_counts
    if _tag_counts is not None:
        return _tag_counts
    if not TAG_ALIASES_PATH.is_file():
        raise FileNotFoundError(f"Tag count CSV not found: {TAG_ALIASES_PATH}")

    tag_counts: Dict[str, int] = {}
    with TAG_ALIASES_PATH.open("r", newline="", encoding="utf-8") as csvfile:
        for row in csv.reader(csvfile):
            if len(row) < 3:
                # Blank or truncated row: there is no count column to read.
                continue
            raw_count = row[2]
            if not raw_count.isdigit():
                continue
            try:
                tag_counts[row[0]] = int(raw_count)
            except ValueError:
                # str.isdigit() accepts characters such as superscript digits
                # that int() cannot parse.
                continue
    _tag_counts = tag_counts
    return _tag_counts


def get_alias2tags() -> Dict[str, List[str]]:
    """Return alias -> [canonical tags] mapping.

    Raises:
        FileNotFoundError: ``TAG_ALIASES_PATH`` does not exist.
    """
    global _alias_to_tags
    if _alias_to_tags is not None:
        return _alias_to_tags
    if not TAG_ALIASES_PATH.is_file():
        raise FileNotFoundError(f"Tag alias CSV not found: {TAG_ALIASES_PATH}")
    _alias_to_tags = build_aliases_dict(str(TAG_ALIASES_PATH), reverse=True)
    return _alias_to_tags


def get_tag2aliases() -> Dict[str, List[str]]:
    """Return canonical tag -> [aliases] mapping.

    Raises:
        FileNotFoundError: ``TAG_ALIASES_PATH`` does not exist.
    """
    global _tag_to_aliases
    if _tag_to_aliases is not None:
        return _tag_to_aliases
    if not TAG_ALIASES_PATH.is_file():
        raise FileNotFoundError(f"Tag alias CSV not found: {TAG_ALIASES_PATH}")
    _tag_to_aliases = build_aliases_dict(str(TAG_ALIASES_PATH), reverse=False)
    return _tag_to_aliases


def get_tfidf_tag_vectors() -> Dict[str, Any]:
    """Return (cached) the reduced TF-IDF matrix plus row/tag lookup tables.

    The result has the keys ``reduced_matrix`` (as stored), 
    ``reduced_matrix_norm`` (float32, rows L2-normalized), ``row_to_tag`` and
    ``tag_to_row_index``; whichever lookup table is missing from the stored
    components is derived by inverting the other.

    Raises:
        KeyError: The components lack ``reduced_matrix`` or any row/tag mapping.
    """
    global _tfidf_tag_vectors
    if _tfidf_tag_vectors is not None:
        return _tfidf_tag_vectors

    components = get_tfidf_components()
    reduced_matrix = components.get("reduced_matrix")
    if reduced_matrix is None:
        raise KeyError("TF-IDF components missing reduced_matrix")

    row_to_tag = components.get("row_to_tag")
    if row_to_tag is None and "tag_to_row_index" in components:
        row_to_tag = {idx: tag for tag, idx in components["tag_to_row_index"].items()}
    if row_to_tag is None:
        raise KeyError("TF-IDF components missing row_to_tag mapping")

    tag_to_row_index = components.get("tag_to_row_index")
    if tag_to_row_index is None:
        tag_to_row_index = {tag: idx for idx, tag in row_to_tag.items()}

    reduced_matrix_norm = _l2_normalize_rows(reduced_matrix).astype(np.float32, copy=False)
    _tfidf_tag_vectors = {
        "reduced_matrix": reduced_matrix,
        "reduced_matrix_norm": reduced_matrix_norm,
        "row_to_tag": row_to_tag,
        "tag_to_row_index": tag_to_row_index,
    }
    return _tfidf_tag_vectors


def retrieval_assets_status() -> Dict[str, bool]:
    """Report, per asset name, whether its file currently exists on disk."""
    return {
        "tfidf": TFIDF_PATH.is_file(),
        "nsfw_csv": NSFW_CSV_PATH.is_file(),
        "fasttext_model": FASTTEXT_MODEL_PATH.is_file(),
        "tag_aliases_csv": TAG_ALIASES_PATH.is_file(),
        "hnsw_tags": HNSW_TAG_PATH.is_file(),
        "hnsw_artists": HNSW_ART_PATH.is_file(),
    }


def _build_or_load_index(path: pathlib.Path, rows: list[int], rm: np.ndarray, dim: int) -> "hnswlib.Index":
    """Return a cosine HNSW index over ``rm[rows]``, reusing *path* when valid.

    A saved index is reused only if it loads and holds exactly ``len(rows)``
    (non-zero) elements; the element count is the only staleness check.
    Otherwise the file is removed and a new index is built, labelled with the
    row numbers, and saved back to *path*.

    Args:
        path: Location of the persisted index.
        rows: Row numbers of *rm* to index; also used as item labels.
        rm: Matrix the rows are taken from.
        dim: Vector dimensionality passed to hnswlib.

    Raises:
        RuntimeError: hnswlib could not be imported.
    """
    if hnswlib is None:
        raise RuntimeError("hnswlib is not available")

    capacity = max(1, len(rows))

    if path.exists():
        loaded = hnswlib.Index(space="cosine", dim=dim)
        try:
            loaded.load_index(str(path), max_elements=capacity)
            saved_count = loaded.get_current_count()
        except Exception as e:
            # Native bindings may raise various exception types; any failure
            # simply means the file cannot be reused.
            logging.debug("Reload %s failed, rebuilding: %s", path.name, e)
        else:
            if rows and saved_count == len(rows):
                loaded.set_ef(200)
                return loaded
            logging.debug(
                "Rebuilding %s: saved_count!=rows_len (%s vs %s)",
                path.name,
                saved_count,
                len(rows),
            )
        try:
            path.unlink()
        except OSError as e:
            # Best effort: save_index below overwrites the file anyway.
            logging.debug("Could not remove stale index %s: %s", path.name, e)

    # Always build on a fresh object: hnswlib refuses init_index on an Index
    # that has already been populated by a successful load_index.
    idx = hnswlib.Index(space="cosine", dim=dim)
    idx.init_index(max_elements=capacity, ef_construction=200, M=16)
    if rows:
        idx.add_items(rm[rows], ids=np.asarray(rows, dtype=np.int32))
    idx.save_index(str(path))
    idx.set_ef(200)
    return idx


def _ensure_hnsw_indexes(need_artists: bool) -> None:
    """Populate the module-level HNSW indexes (and their counts) if needed.

    Does nothing when hnswlib is unavailable, or when the tag index (and, if
    *need_artists*, the artist index) is already loaded. Rows are split into
    an artist group and a tag group; ``by_unknown_artist`` and
    ``by_conditional_dnp`` always go to the tag group.

    NOTE(review): the split depends on *need_artists*. With
    ``need_artists=False`` the artist set is empty, so artist rows land in the
    tag index; a later call with ``need_artists=True`` rebuilds the tag index
    without them and overwrites ``HNSW_TAG_PATH``. Confirm whether this
    call-order dependence is intended.
    """
    global _hnsw_tag_index, _hnsw_artist_index, _hnsw_tag_count, _hnsw_artist_count
    if hnswlib is None:
        return
    if _hnsw_tag_index is not None and (not need_artists or _hnsw_artist_index is not None):
        return

    components = get_tfidf_components()
    reduced_matrix = components["reduced_matrix"]
    row_to_tag = components["row_to_tag"]
    rm = _l2_normalize_rows(reduced_matrix).astype(np.float32, copy=False)
    n_items, dim = rm.shape

    artist_set = get_artist_set() if need_artists else set()
    never_artists = {"by_unknown_artist", "by_conditional_dnp"}
    artist_rows: list[int] = []
    tag_rows: list[int] = []
    for i in range(n_items):
        tag = row_to_tag.get(i, "")
        base = tag[3:] if tag.startswith("by_") else tag
        if tag not in never_artists and base in artist_set:
            artist_rows.append(i)
        else:
            tag_rows.append(i)

    _hnsw_tag_index = _build_or_load_index(HNSW_TAG_PATH, tag_rows, rm, dim)
    _hnsw_tag_count = len(tag_rows)
    if need_artists:
        _hnsw_artist_index = _build_or_load_index(HNSW_ART_PATH, artist_rows, rm, dim)
        _hnsw_artist_count = len(artist_rows)


def get_hnsw_tag_index() -> Tuple[Optional["hnswlib.Index"], int]:
    """Return ``(tag index, number of indexed rows)``; the index is None without hnswlib."""
    _ensure_hnsw_indexes(need_artists=False)
    return _hnsw_tag_index, _hnsw_tag_count


def get_hnsw_artist_index() -> Tuple[Optional["hnswlib.Index"], int]:
    """Return ``(artist index, number of indexed rows)``; the index is None without hnswlib."""
    _ensure_hnsw_indexes(need_artists=True)
    return _hnsw_artist_index, _hnsw_artist_count