from __future__ import annotations

import csv
import logging
import pathlib
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple

import joblib
import numpy as np

try:
    import hnswlib
except Exception:
    # Deliberately broad: allow import on environments without hnswlib during
    # partial tests. HNSW features are disabled when this is None.
    hnswlib = None

# Data files, resolved relative to the current working directory.
TFIDF_PATH = pathlib.Path("tf_idf_files_420.joblib")
NSFW_CSV_PATH = pathlib.Path("word_rating_probabilities.csv")
NSFW_THRESHOLD = 0.95
HNSW_TAG_PATH = pathlib.Path("tfidf_hnsw_tags.bin")
FASTTEXT_MODEL_PATH = pathlib.Path("e621FastTextModel010Replacement_small.bin")
TAG_ALIASES_PATH = pathlib.Path("fluffyrock_3m.csv")
TAG_IMPLICATIONS_PATH = pathlib.Path("tag_implications-2023-07-20.csv")

# Lazily-populated module-level caches; each is filled by its getter on first use.
_tfidf_components: Optional[Dict[str, Any]] = None
_nsfw_tags: Optional[Set[str]] = None
_artist_set: Optional[Set[str]] = None
_fasttext_model: Optional[Any] = None
_tag_counts: Optional[Dict[str, int]] = None
_tfidf_tag_vectors: Optional[Dict[str, Any]] = None
_alias_to_tags: Optional[Dict[str, List[str]]] = None
_tag_to_aliases: Optional[Dict[str, List[str]]] = None
_tag_type_id: Optional[Dict[str, int]] = None
_tag_implications: Optional[Dict[str, List[str]]] = None
_hnsw_tag_index: Optional["hnswlib.Index"] = None
_hnsw_tag_count: int = 0

# Tag type names inferred from e621 wiki documentation.
# Numeric IDs come from fluffyrock_3m.csv column 1; mapping is heuristic but
# matches observed usage on e621.
TAG_TYPE_ID_TO_NAME: Dict[int, str] = {
    0: "general",      # Default tag type: visible attributes, actions, objects, etc.
    1: "artist",       # Artist tags (e.g. by_name, artist_name)
    2: "contributor",  # Contributor tags (rare / possibly unused in this dataset)
    3: "copyright",    # Series, franchise, or IP (e.g. pokemon, winnie_the_pooh)
    4: "character",    # Named characters (e.g. pikachu, pinkie_pie_(mlp))
    5: "species",      # Species tags (e.g. canine, domestic_cat)
    6: "invalid",      # Invalid / disallowed / disambiguation-only tags
    7: "meta",         # Meta / presentation / file / style-related tags
}


def _l2_normalize_rows(mat: np.ndarray) -> np.ndarray:
    """Return a float32 array with each row of *mat* scaled to unit L2 norm.

    Rows whose norm is zero are left as all-zero rows (their divisor is
    replaced by 1 to avoid division by zero).
    """
    mat = np.asarray(mat, dtype=np.float32)
    norms = np.linalg.norm(mat, axis=1, keepdims=True)
    norms[norms == 0.0] = 1.0
    return mat / norms


def _clean_tag_ascii(tag: str) -> str:
    """Drop every non-ASCII character from *tag*."""
    return "".join(char for char in tag if ord(char) < 128)


def clean_tag(tag: str) -> str:
    """Normalize tags consistently with legacy alias parsing."""
    return _clean_tag_ascii(tag)


def build_aliases_dict(csv_path: str, reverse: bool = False) -> Dict[str, List[str]]:
    """Build tag/alias mappings from the aliases CSV.

    Column 0 holds the canonical tag and column 3 a comma-separated alias
    list, or the literal ``"null"`` when there are no aliases. Tags and
    aliases are passed through :func:`clean_tag`.

    Args:
        csv_path: Path of the CSV file to read.
        reverse: If False, return ``tag -> [aliases]``. If True, return
            ``alias -> [canonical tags listing that alias]``.

    Returns:
        The requested mapping.

    Blank rows are skipped. A row without an alias column, or with an empty
    one, is treated as having no aliases. Empty alias entries (e.g. from a
    trailing comma) are dropped.
    """
    aliases_dict: Dict[str, List[str]] = {}
    with open(csv_path, "r", newline="", encoding="utf-8") as csvfile:
        for row in csv.reader(csvfile):
            if not row:
                continue
            tag = clean_tag(row[0])
            raw_aliases = row[3] if len(row) > 3 else "null"
            if raw_aliases in ("null", ""):
                alias_list: List[str] = []
            else:
                cleaned = (clean_tag(alias) for alias in raw_aliases.split(","))
                alias_list = [alias for alias in cleaned if alias]
            if reverse:
                for alias in alias_list:
                    aliases_dict.setdefault(alias, []).append(tag)
            else:
                aliases_dict[tag] = alias_list
    return aliases_dict


def get_tfidf_components() -> Dict[str, Any]:
    """Load (once) and return the TF-IDF model components from ``TFIDF_PATH``.

    Post-processing applied to the loaded dict:

    * If ``tag_to_row_index`` is present but ``row_to_tag`` is not, the
      inverse mapping is added as ``row_to_tag``.
    * If ``idf`` is a ``term -> weight`` dict, it is replaced by a float32
      vector indexed by column (via ``tag_to_column_index``). Columns with
      no IDF entry get 1.0.

    Raises:
        FileNotFoundError: If the joblib file does not exist.
    """
    global _tfidf_components
    if _tfidf_components is not None:
        return _tfidf_components
    if not TFIDF_PATH.is_file():
        raise FileNotFoundError(f"TF-IDF joblib not found: {TFIDF_PATH}")
    # NOTE(security): joblib.load unpickles; only load trusted files here.
    model_components = joblib.load(TFIDF_PATH)
    if "tag_to_row_index" in model_components and "row_to_tag" not in model_components:
        model_components["row_to_tag"] = {
            idx: tag for tag, idx in model_components["tag_to_row_index"].items()
        }
    idf = model_components.get("idf")
    if isinstance(idf, dict):
        t2c = model_components["tag_to_column_index"]
        n_cols = max(t2c.values(), default=-1) + 1
        idf_by_col = np.ones(n_cols, dtype=np.float32)
        for term, col in t2c.items():
            idf_by_col[col] = float(idf.get(term, 1.0))
        model_components["idf"] = idf_by_col
    _tfidf_components = model_components
    return model_components


def get_nsfw_tags() -> Set[str]:
    """Return (and cache) words whose NSFW probability is >= ``NSFW_THRESHOLD``.

    Reads ``NSFW_CSV_PATH``, skipping its header row. Rows that are empty or
    whose second column is missing or non-numeric are ignored.

    Raises:
        FileNotFoundError: If the CSV does not exist.
    """
    global _nsfw_tags
    if _nsfw_tags is not None:
        return _nsfw_tags
    if not NSFW_CSV_PATH.is_file():
        raise FileNotFoundError(f"NSFW tag CSV not found: {NSFW_CSV_PATH}")
    tags: Set[str] = set()
    with NSFW_CSV_PATH.open("r", newline="", encoding="utf-8") as csvfile:
        reader = csv.reader(csvfile)
        next(reader, None)  # skip header
        for row in reader:
            if not row:
                continue
            word = row[0]
            try:
                probability_sum = float(row[1])
            except (IndexError, ValueError):
                continue
            if probability_sum >= NSFW_THRESHOLD:
                tags.add(word)
    _nsfw_tags = tags
    return _nsfw_tags


def get_artist_set() -> Set[str]:
    """Return (and cache) artist names derived from ``by_<name>`` tags.

    Reads column 0 of ``TAG_ALIASES_PATH`` and strips the ``by_`` prefix.
    Unlike the other tag loaders, a missing file yields an empty set rather
    than an exception.
    """
    global _artist_set
    if _artist_set is not None:
        return _artist_set
    if not TAG_ALIASES_PATH.is_file():
        _artist_set = set()
        return _artist_set
    artists: Set[str] = set()
    with TAG_ALIASES_PATH.open("r", newline="", encoding="utf-8") as csvfile:
        for row in csv.reader(csvfile):
            if not row:
                continue
            tag_name = row[0]
            if tag_name.startswith("by_") and len(tag_name) > 3:
                artists.add(tag_name[3:])
    _artist_set = artists
    return _artist_set


def get_fasttext_model() -> Any:
    """Load (once) and return the compressed FastText keyed vectors.

    ``compress_fasttext`` is imported lazily, so the dependency is only
    required when this function is called.

    Raises:
        FileNotFoundError: If ``FASTTEXT_MODEL_PATH`` does not exist.
    """
    global _fasttext_model
    if _fasttext_model is not None:
        return _fasttext_model
    if not FASTTEXT_MODEL_PATH.is_file():
        raise FileNotFoundError(f"FastText model not found: {FASTTEXT_MODEL_PATH}")
    import compress_fasttext

    _fasttext_model = compress_fasttext.models.CompressedFastTextKeyedVectors.load(
        str(FASTTEXT_MODEL_PATH)
    )
    return _fasttext_model


def get_tag_type_ids() -> Dict[str, int]:
    """Return canonical tag -> type_id (int) from fluffyrock_3m.csv.

    Reads row[1] as int when possible. Missing/invalid values are skipped.
    Tags are normalized with :func:`clean_tag`.

    Raises:
        FileNotFoundError: If ``TAG_ALIASES_PATH`` does not exist.
    """
    global _tag_type_id
    if _tag_type_id is not None:
        return _tag_type_id
    if not TAG_ALIASES_PATH.is_file():
        raise FileNotFoundError(f"Tag CSV not found: {TAG_ALIASES_PATH}")
    m: Dict[str, int] = {}
    with TAG_ALIASES_PATH.open("r", newline="", encoding="utf-8") as csvfile:
        for row in csv.reader(csvfile):
            if len(row) < 2:
                continue
            try:
                type_id = int(row[1])
            except ValueError:
                continue
            m[clean_tag(row[0])] = type_id
    _tag_type_id = m
    return _tag_type_id


def get_tag_type_name(tag: str) -> Optional[str]:
    """Return heuristic type name for a tag (e.g. 'artist', 'character'), or None.

    Unknown numeric IDs are rendered as ``"type_<id>"``.
    """
    tid = get_tag_type_ids().get(clean_tag(tag))
    if tid is None:
        return None
    return TAG_TYPE_ID_TO_NAME.get(tid, f"type_{tid}")


def get_tag_counts() -> Dict[str, int]:
    """Return (and cache) tag -> post count from column 2 of ``TAG_ALIASES_PATH``.

    Rows that are too short, or whose count is not a plain non-negative
    decimal integer, are skipped.

    NOTE(review): keys are the raw column-0 strings (not passed through
    ``clean_tag``), unlike ``get_tag_type_ids``; confirm callers expect this.

    Raises:
        FileNotFoundError: If the CSV does not exist.
    """
    global _tag_counts
    if _tag_counts is not None:
        return _tag_counts
    if not TAG_ALIASES_PATH.is_file():
        raise FileNotFoundError(f"Tag count CSV not found: {TAG_ALIASES_PATH}")
    tag_counts: Dict[str, int] = {}
    with TAG_ALIASES_PATH.open("r", newline="", encoding="utf-8") as csvfile:
        for row in csv.reader(csvfile):
            if len(row) < 3:
                continue
            count_str = row[2]
            # isdecimal (unlike isdigit) only accepts characters int() can parse.
            if count_str.isdecimal():
                tag_counts[row[0]] = int(count_str)
    _tag_counts = tag_counts
    return _tag_counts


def get_alias2tags() -> Dict[str, List[str]]:
    """Return alias -> [canonical tags] mapping.

    Raises:
        FileNotFoundError: If ``TAG_ALIASES_PATH`` does not exist.
    """
    global _alias_to_tags
    if _alias_to_tags is not None:
        return _alias_to_tags
    if not TAG_ALIASES_PATH.is_file():
        raise FileNotFoundError(f"Tag alias CSV not found: {TAG_ALIASES_PATH}")
    _alias_to_tags = build_aliases_dict(str(TAG_ALIASES_PATH), reverse=True)
    return _alias_to_tags


def get_tag2aliases() -> Dict[str, List[str]]:
    """Return canonical tag -> [aliases] mapping.

    Raises:
        FileNotFoundError: If ``TAG_ALIASES_PATH`` does not exist.
    """
    global _tag_to_aliases
    if _tag_to_aliases is not None:
        return _tag_to_aliases
    if not TAG_ALIASES_PATH.is_file():
        raise FileNotFoundError(f"Tag alias CSV not found: {TAG_ALIASES_PATH}")
    _tag_to_aliases = build_aliases_dict(str(TAG_ALIASES_PATH), reverse=False)
    return _tag_to_aliases


def get_tag_implications() -> Dict[str, List[str]]:
    """Return antecedent_tag -> [consequent_tags] from the implications CSV.

    Only active implications where both tags exist in the tag database are kept.
    Columns read: 1 = antecedent, 2 = consequent, 4 = status (must equal
    ``"active"``). The header row is skipped. If the CSV is missing, a
    warning is logged and an empty mapping is cached and returned.
    """
    global _tag_implications
    if _tag_implications is not None:
        return _tag_implications
    if not TAG_IMPLICATIONS_PATH.is_file():
        logging.warning("Tag implications CSV not found: %s", TAG_IMPLICATIONS_PATH)
        _tag_implications = {}
        return _tag_implications

    known_tags = set(get_tag_type_ids().keys())
    impl: Dict[str, List[str]] = {}
    with TAG_IMPLICATIONS_PATH.open("r", newline="", encoding="utf-8") as csvfile:
        reader = csv.reader(csvfile)
        next(reader, None)  # skip header
        for row in reader:
            if len(row) < 5 or row[4] != "active":
                continue
            antecedent = clean_tag(row[1])
            consequent = clean_tag(row[2])
            if antecedent in known_tags and consequent in known_tags:
                impl.setdefault(antecedent, []).append(consequent)
    _tag_implications = impl
    logging.info("Loaded %d tag implications", sum(len(v) for v in impl.values()))
    return _tag_implications


def expand_tags_via_implications(tags: Iterable[str]) -> Tuple[Set[str], Set[str]]:
    """Walk the implication graph upward from each tag, collecting ancestors.

    Args:
        tags: Starting tags. Any iterable is accepted and copied to a set, so
            the caller's collection is never mutated.

    Returns (all_tags, implied_only) where:
      - all_tags = original tags + implied ancestors
      - implied_only = tags that were added (not in the original set)
    """
    impl = get_tag_implications()
    original = set(tags)
    expanded = set(original)
    stack = list(original)
    while stack:
        tag = stack.pop()
        for parent in impl.get(tag, ()):
            if parent not in expanded:
                expanded.add(parent)
                stack.append(parent)
    return expanded, expanded - original


def get_leaf_tags(tags: Iterable[str]) -> Set[str]:
    """Return only leaf tags — those not implied by any other tag in the set.

    For example, given {fox, canine, canid, mammal}, returns {fox} because
    canine/canid/mammal are all reachable from fox via implications.

    Any iterable is accepted; a new set is returned. If the implication
    graph contains a cycle, every member of that cycle present in *tags* is
    treated as implied and therefore dropped.
    """
    impl = get_tag_implications()
    tag_set = set(tags)
    # For each tag, walk everything it implies; those in the set are non-leaves.
    non_leaves: Set[str] = set()
    for tag in tag_set:
        visited: Set[str] = set()
        stack = [tag]
        while stack:
            current = stack.pop()
            for parent in impl.get(current, ()):
                if parent not in visited:
                    visited.add(parent)
                    if parent in tag_set:
                        non_leaves.add(parent)
                    stack.append(parent)
    return tag_set - non_leaves


def get_tfidf_tag_vectors() -> Dict[str, Any]:
    """Return (and cache) the TF-IDF reduced matrix plus row/tag lookups.

    Keys of the returned dict:
      - ``reduced_matrix``: the matrix as stored in the components.
      - ``reduced_matrix_norm``: float32 copy with unit-L2 rows.
      - ``row_to_tag`` / ``tag_to_row_index``: the two row/tag mappings;
        whichever one is missing is derived from the other.

    Raises:
        KeyError: If ``reduced_matrix`` or both row/tag mappings are missing.
    """
    global _tfidf_tag_vectors
    if _tfidf_tag_vectors is not None:
        return _tfidf_tag_vectors
    components = get_tfidf_components()
    reduced_matrix = components.get("reduced_matrix")
    if reduced_matrix is None:
        raise KeyError("TF-IDF components missing reduced_matrix")
    row_to_tag = components.get("row_to_tag")
    if row_to_tag is None and "tag_to_row_index" in components:
        row_to_tag = {idx: tag for tag, idx in components["tag_to_row_index"].items()}
    if row_to_tag is None:
        raise KeyError("TF-IDF components missing row_to_tag mapping")
    tag_to_row_index = components.get("tag_to_row_index")
    if tag_to_row_index is None:
        tag_to_row_index = {tag: idx for idx, tag in row_to_tag.items()}
    # _l2_normalize_rows already returns float32; no extra astype copy needed.
    reduced_matrix_norm = _l2_normalize_rows(reduced_matrix)
    _tfidf_tag_vectors = {
        "reduced_matrix": reduced_matrix,
        "reduced_matrix_norm": reduced_matrix_norm,
        "row_to_tag": row_to_tag,
        "tag_to_row_index": tag_to_row_index,
    }
    return _tfidf_tag_vectors


def _build_or_load_index(path: pathlib.Path, rows: list[int], rm: np.ndarray, dim: int) -> "hnswlib.Index":
    """Load a cosine HNSW index from *path*, or (re)build it and save it there.

    A saved index is reused only if it loads without error and holds exactly
    ``len(rows)`` items (and ``rows`` is non-empty). Otherwise, the stale
    file is removed and a new index is built from ``rm[rows]``, labelled
    with the row numbers in ``rows``, then saved to *path*. Either way,
    ``ef`` is set to 200.

    NOTE(review): staleness is detected by item count only; a changed matrix
    with the same row count will reuse the old file.
    """
    max_elements = max(1, len(rows))

    if path.exists():
        loaded = hnswlib.Index(space="cosine", dim=dim)
        try:
            loaded.load_index(str(path), max_elements=max_elements)
            saved_count = loaded.get_current_count()
        except Exception as e:
            logging.debug("Reload %s failed, rebuilding: %s", path.name, e)
        else:
            if rows and saved_count == len(rows):
                loaded.set_ef(200)
                return loaded
            logging.debug(
                "Rebuilding %s: saved_count!=rows_len (%s vs %s)",
                path.name,
                saved_count,
                len(rows),
            )
        # Best-effort removal; save_index below overwrites the path anyway.
        try:
            path.unlink()
        except OSError as e:
            logging.debug("Could not remove stale index %s: %s", path.name, e)

    # Always build into a fresh Index object. hnswlib refuses init_index()
    # on an instance that load_index() already initialised.
    idx = hnswlib.Index(space="cosine", dim=dim)
    idx.init_index(max_elements=max_elements, ef_construction=200, M=16)
    if rows:
        idx.add_items(rm[rows], ids=np.asarray(rows, dtype=np.int32))
    idx.save_index(str(path))
    idx.set_ef(200)
    return idx


def _ensure_hnsw_indexes() -> None:
    """Populate the module-level HNSW tag index once, if hnswlib is available.

    Every row of the TF-IDF reduced matrix is indexed, labelled by its row
    number. Does nothing if hnswlib failed to import or the index is already
    loaded.
    """
    global _hnsw_tag_index, _hnsw_tag_count
    if hnswlib is None:
        return
    if _hnsw_tag_index is not None:
        return
    components = get_tfidf_components()
    reduced_matrix = components["reduced_matrix"]
    # _l2_normalize_rows already returns float32; no extra astype copy needed.
    rm = _l2_normalize_rows(reduced_matrix)
    n_items, dim = rm.shape
    tag_rows = list(range(n_items))
    _hnsw_tag_index = _build_or_load_index(HNSW_TAG_PATH, tag_rows, rm, dim)
    _hnsw_tag_count = len(tag_rows)


def get_hnsw_tag_index() -> Tuple[Optional["hnswlib.Index"], int]:
    """Return ``(index, item_count)``, building/loading the index on first call.

    Returns ``(None, 0)`` when hnswlib is unavailable.
    """
    _ensure_hnsw_indexes()
    return _hnsw_tag_index, _hnsw_tag_count