Food Desert
Add alias-based character tag filtering for Stage 3
c6be992
Raw
History Blame
13 kB
from __future__ import annotations
import csv
import logging
import pathlib
from typing import Any, Dict, List, Optional, Set, Tuple
import joblib
import numpy as np
try:
import hnswlib
except Exception:
hnswlib = None # allow import on environments without hnswlib during partial tests
# Asset locations. These are relative paths, so they resolve against the
# current working directory at the time each file is opened.
TFIDF_PATH = pathlib.Path("tf_idf_files_420.joblib")
# CSV read by get_nsfw_tags(): header row, then word (column 0) and a
# probability sum (column 1).
NSFW_CSV_PATH = pathlib.Path("word_rating_probabilities.csv")
# Words whose probability sum is >= this value are treated as NSFW.
NSFW_THRESHOLD = 0.95
# On-disk caches for the artist / tag HNSW indexes. _build_or_load_index()
# reuses a file only when its saved element count matches the current rows.
HNSW_ART_PATH = pathlib.Path("tfidf_hnsw_artists.bin")
HNSW_TAG_PATH = pathlib.Path("tfidf_hnsw_tags.bin")
FASTTEXT_MODEL_PATH = pathlib.Path("e621FastTextModel010Replacement_small.bin")
# Tag CSV shared by several loaders in this module: column 0 = tag,
# column 1 = type id, column 2 = count, column 3 = comma-separated aliases
# or the literal "null".
TAG_ALIASES_PATH = pathlib.Path("fluffyrock_3m.csv")
# Lazily-populated module caches. Each stays None until its getter first
# loads it; subsequent calls return the cached object.
_tfidf_components: Optional[Dict[str, Any]] = None  # get_tfidf_components()
_nsfw_tags: Optional[Set[str]] = None  # get_nsfw_tags()
_artist_set: Optional[Set[str]] = None  # get_artist_set()
_fasttext_model: Optional[Any] = None  # get_fasttext_model()
_tag_counts: Optional[Dict[str, int]] = None  # get_tag_counts()
_tfidf_tag_vectors: Optional[Dict[str, Any]] = None  # get_tfidf_tag_vectors()
_alias_to_tags: Optional[Dict[str, List[str]]] = None  # get_alias2tags()
_tag_to_aliases: Optional[Dict[str, List[str]]] = None  # get_tag2aliases()
_tag_type_id: Optional[Dict[str, int]] = None  # get_tag_type_ids()
# HNSW indexes, set by _ensure_hnsw_indexes(); they remain None if hnswlib
# could not be imported.
_hnsw_tag_index: Optional["hnswlib.Index"] = None
_hnsw_artist_index: Optional["hnswlib.Index"] = None
# Number of TF-IDF rows added to each HNSW index.
_hnsw_tag_count: int = 0
_hnsw_artist_count: int = 0
# Tag type names inferred from e621 wiki documentation.
# Numeric IDs come from fluffyrock_3m.csv column 1; mapping is heuristic but
# matches observed usage on e621. IDs missing from this table are reported as
# "type_<id>" by get_tag_type_name().
TAG_TYPE_ID_TO_NAME: Dict[int, str] = {
0: "general", # Default tag type: visible attributes, actions, objects, etc.
1: "artist", # Artist tags (e.g. by_name, artist_name)
2: "contributor", # Contributor tags (rare / possibly unused in this dataset)
3: "copyright", # Series, franchise, or IP (e.g. pokemon, winnie_the_pooh)
4: "character", # Named characters (e.g. pikachu, pinkie_pie_(mlp))
5: "species", # Species tags (e.g. canine, domestic_cat)
6: "invalid", # Invalid / disallowed / disambiguation-only tags
7: "meta", # Meta / presentation / file / style-related tags
}
def _l2_normalize_rows(mat: np.ndarray) -> np.ndarray:
    """Scale every row of ``mat`` to unit Euclidean length.

    The input is first converted to a float32 array. Rows whose norm is zero
    are divided by 1 instead, so they stay all-zero rather than becoming NaN.
    """
    as_f32 = np.asarray(mat, dtype=np.float32)
    row_lengths = np.linalg.norm(as_f32, axis=1, keepdims=True)
    safe_lengths = np.where(row_lengths == 0.0, np.float32(1.0), row_lengths)
    return as_f32 / safe_lengths
def _clean_tag_ascii(tag: str) -> str:
    """Return ``tag`` with every non-ASCII character (code point >= 128) removed."""
    return tag.encode("ascii", errors="ignore").decode("ascii")
def clean_tag(tag: str) -> str:
    """Normalize a tag the same way the legacy alias parser did.

    At present this only strips non-ASCII characters (via ``_clean_tag_ascii``).
    """
    normalized = _clean_tag_ascii(tag)
    return normalized
def build_aliases_dict(csv_path: str, reverse: bool = False) -> Dict[str, List[str]]:
    """Build tag/alias mappings from the aliases CSV.

    Rows are expected to hold the canonical tag in column 0 and either a
    comma-separated alias list or the literal ``"null"`` in column 3.

    Args:
        csv_path: Path of the CSV file (anything ``open()`` accepts).
        reverse: If False, return ``{canonical_tag: [aliases]}``; if True,
            return ``{alias: [canonical_tags]}``.

    Returns:
        The requested mapping, with tags and aliases normalized by
        ``clean_tag``. Blank rows are skipped, rows lacking column 3 are
        treated as having no aliases, and aliases that are empty after
        normalization are dropped.
    """
    aliases_dict: Dict[str, List[str]] = {}
    with open(csv_path, "r", newline="", encoding="utf-8") as csvfile:
        for row in csv.reader(csvfile):
            if not row:
                # Blank lines previously raised IndexError on row[0].
                continue
            tag = clean_tag(row[0])
            raw_aliases = row[3] if len(row) > 3 else "null"
            if raw_aliases == "null":
                alias_list: List[str] = []
            else:
                # "".split(",") == [""]: without this filter the empty string
                # (or an alias made only of non-ASCII characters) would become
                # a lookup key mapping to this tag.
                cleaned = (clean_tag(alias) for alias in raw_aliases.split(","))
                alias_list = [alias for alias in cleaned if alias]
            if reverse:
                for alias in alias_list:
                    aliases_dict.setdefault(alias, []).append(tag)
            else:
                aliases_dict[tag] = alias_list
    return aliases_dict
def get_tfidf_components() -> Dict[str, Any]:
    """Load ``TFIDF_PATH`` once and return the cached TF-IDF component dict.

    Before caching, the loaded dict is normalized in place:

    * if ``row_to_tag`` is absent but ``tag_to_row_index`` is present, the
      former is built by inverting the latter;
    * if ``idf`` is a dict (term -> weight), it is replaced by a dense float32
      array indexed via ``tag_to_column_index``; any column without a weight
      gets 1.0.

    Raises:
        FileNotFoundError: if ``TFIDF_PATH`` does not exist.
    """
    global _tfidf_components
    if _tfidf_components is None:
        if not TFIDF_PATH.is_file():
            raise FileNotFoundError(f"TF-IDF joblib not found: {TFIDF_PATH}")
        # NOTE: joblib.load unpickles the file; only point this at trusted assets.
        loaded = joblib.load(TFIDF_PATH)
        if "tag_to_row_index" in loaded and "row_to_tag" not in loaded:
            loaded["row_to_tag"] = {row: tag for tag, row in loaded["tag_to_row_index"].items()}
        idf_weights = loaded.get("idf")
        if isinstance(idf_weights, dict):
            column_of = loaded["tag_to_column_index"]
            dense_idf = np.ones(max(column_of.values()) + 1, dtype=np.float32)
            for term, column in column_of.items():
                dense_idf[column] = float(idf_weights.get(term, 1.0))
            loaded["idf"] = dense_idf
        _tfidf_components = loaded
    return _tfidf_components
def get_nsfw_tags() -> Set[str]:
    """Return the cached set of words considered NSFW.

    ``NSFW_CSV_PATH`` is read after skipping its header row. Column 0 is the
    word and column 1 a probability sum; words whose sum is at least
    ``NSFW_THRESHOLD`` are included. Blank rows and rows whose column 1 is
    missing or not a float are ignored.

    Raises:
        FileNotFoundError: if ``NSFW_CSV_PATH`` does not exist.
    """
    global _nsfw_tags
    if _nsfw_tags is None:
        if not NSFW_CSV_PATH.is_file():
            raise FileNotFoundError(f"NSFW tag CSV not found: {NSFW_CSV_PATH}")
        flagged: Set[str] = set()
        with NSFW_CSV_PATH.open("r", newline="", encoding="utf-8") as handle:
            rows = csv.reader(handle)
            next(rows, None)  # discard the header row
            for row in rows:
                if not row:
                    continue
                try:
                    score = float(row[1])
                except (IndexError, ValueError):
                    continue
                if score >= NSFW_THRESHOLD:
                    flagged.add(row[0])
        _nsfw_tags = flagged
    return _nsfw_tags
def get_artist_set() -> Set[str]:
    """Return the cached set of artist names derived from the tag CSV.

    Every column-0 tag in ``TAG_ALIASES_PATH`` that starts with ``by_``
    contributes its name with that prefix removed (``by_foo`` -> ``foo``).
    Unlike the other loaders, a missing CSV is not an error: an empty set is
    cached and returned.
    """
    global _artist_set
    if _artist_set is not None:
        return _artist_set
    # Use the shared constant instead of a second hard-coded copy of the
    # filename, so relocating the CSV only requires changing TAG_ALIASES_PATH.
    path = TAG_ALIASES_PATH
    if not path.is_file():
        _artist_set = set()
        return _artist_set
    prefix = "by_"
    with path.open("r", newline="", encoding="utf-8") as csvfile:
        _artist_set = {
            row[0][len(prefix):]
            for row in csv.reader(csvfile)
            if row and row[0].startswith(prefix)
        }
    return _artist_set
def is_artist(name: str) -> bool:
    """Return True if ``name`` (given without its ``by_`` prefix) is a known artist."""
    known_artists = get_artist_set()
    return name in known_artists
def get_fasttext_model() -> Any:
    """Load the compressed FastText model once and return the cached instance.

    ``compress_fasttext`` is imported inside this function, so importing the
    module does not require that package.

    Raises:
        FileNotFoundError: if ``FASTTEXT_MODEL_PATH`` does not exist.
    """
    global _fasttext_model
    if _fasttext_model is None:
        if not FASTTEXT_MODEL_PATH.is_file():
            raise FileNotFoundError(f"FastText model not found: {FASTTEXT_MODEL_PATH}")
        import compress_fasttext

        vectors_cls = compress_fasttext.models.CompressedFastTextKeyedVectors
        _fasttext_model = vectors_cls.load(str(FASTTEXT_MODEL_PATH))
    return _fasttext_model
def get_tag_type_ids() -> Dict[str, int]:
    """Return the cached canonical tag -> type_id mapping from ``TAG_ALIASES_PATH``.

    Tags come from column 0 (normalized with ``clean_tag``) and type ids from
    column 1. Blank rows, rows with fewer than two columns and rows whose
    column 1 is not an integer are skipped.

    Raises:
        FileNotFoundError: if ``TAG_ALIASES_PATH`` does not exist.
    """
    global _tag_type_id
    if _tag_type_id is None:
        if not TAG_ALIASES_PATH.is_file():
            raise FileNotFoundError(f"Tag CSV not found: {TAG_ALIASES_PATH}")
        type_by_tag: Dict[str, int] = {}
        with TAG_ALIASES_PATH.open("r", newline="", encoding="utf-8") as csvfile:
            for row in csv.reader(csvfile):
                if len(row) < 2:
                    continue
                try:
                    type_id = int(row[1])
                except ValueError:
                    continue
                type_by_tag[clean_tag(row[0])] = type_id
        _tag_type_id = type_by_tag
    return _tag_type_id
def get_tag_type_name(tag: str) -> Optional[str]:
    """Return the heuristic type name for ``tag`` (e.g. 'artist', 'character').

    ``tag`` is normalized with ``clean_tag`` before lookup. Unknown tags give
    None; type ids absent from ``TAG_TYPE_ID_TO_NAME`` give ``"type_<id>"``.
    """
    type_id = get_tag_type_ids().get(clean_tag(tag))
    return None if type_id is None else TAG_TYPE_ID_TO_NAME.get(type_id, f"type_{type_id}")
def get_tag_counts() -> Dict[str, int]:
    """Return the cached tag -> count mapping from ``TAG_ALIASES_PATH``.

    Keys are the raw column-0 tag strings (not passed through ``clean_tag``);
    values are column 2 parsed as a non-negative integer. Blank rows, rows
    with fewer than three columns, and rows whose column 2 is not a plain
    decimal number are skipped.

    Raises:
        FileNotFoundError: if ``TAG_ALIASES_PATH`` does not exist.
    """
    global _tag_counts
    if _tag_counts is not None:
        return _tag_counts
    if not TAG_ALIASES_PATH.is_file():
        raise FileNotFoundError(f"Tag count CSV not found: {TAG_ALIASES_PATH}")
    tag_counts: Dict[str, int] = {}
    with TAG_ALIASES_PATH.open("r", newline="", encoding="utf-8") as csvfile:
        for row in csv.reader(csvfile):
            # Short rows used to raise IndexError on row[2]; skip them the way
            # get_tag_type_ids() skips rows missing its column.
            if len(row) < 3:
                continue
            raw_count = row[2]
            # isdecimal() (unlike isdigit()) rejects characters such as "²"
            # that int() cannot parse, so this can no longer raise ValueError.
            if raw_count.isdecimal():
                tag_counts[row[0]] = int(raw_count)
    _tag_counts = tag_counts
    return _tag_counts
def get_alias2tags() -> Dict[str, List[str]]:
    """Return the cached alias -> [canonical tags] mapping.

    Built from ``TAG_ALIASES_PATH`` with ``build_aliases_dict(..., reverse=True)``.

    Raises:
        FileNotFoundError: if ``TAG_ALIASES_PATH`` does not exist.
    """
    global _alias_to_tags
    if _alias_to_tags is None:
        if not TAG_ALIASES_PATH.is_file():
            raise FileNotFoundError(f"Tag alias CSV not found: {TAG_ALIASES_PATH}")
        _alias_to_tags = build_aliases_dict(str(TAG_ALIASES_PATH), reverse=True)
    return _alias_to_tags
def get_tag2aliases() -> Dict[str, List[str]]:
    """Return the cached canonical tag -> [aliases] mapping.

    Built from ``TAG_ALIASES_PATH`` with ``build_aliases_dict(..., reverse=False)``.

    Raises:
        FileNotFoundError: if ``TAG_ALIASES_PATH`` does not exist.
    """
    global _tag_to_aliases
    if _tag_to_aliases is None:
        if not TAG_ALIASES_PATH.is_file():
            raise FileNotFoundError(f"Tag alias CSV not found: {TAG_ALIASES_PATH}")
        _tag_to_aliases = build_aliases_dict(str(TAG_ALIASES_PATH), reverse=False)
    return _tag_to_aliases
def get_tfidf_tag_vectors() -> Dict[str, Any]:
    """Return the cached bundle of TF-IDF tag vectors and index mappings.

    The dict contains:
      * ``reduced_matrix`` -- the matrix exactly as stored in the components;
      * ``reduced_matrix_norm`` -- a row-L2-normalized float32 copy of it;
      * ``row_to_tag`` / ``tag_to_row_index`` -- both mappings, each derived
        by inversion when only the other one is available.

    Raises:
        KeyError: if the components lack ``reduced_matrix`` or both mappings.
    """
    global _tfidf_tag_vectors
    if _tfidf_tag_vectors is not None:
        return _tfidf_tag_vectors
    components = get_tfidf_components()
    matrix = components.get("reduced_matrix")
    if matrix is None:
        raise KeyError("TF-IDF components missing reduced_matrix")
    row_map = components.get("row_to_tag")
    if row_map is None and "tag_to_row_index" in components:
        row_map = {row: tag for tag, row in components["tag_to_row_index"].items()}
    if row_map is None:
        raise KeyError("TF-IDF components missing row_to_tag mapping")
    tag_map = components.get("tag_to_row_index")
    if tag_map is None:
        tag_map = {tag: row for row, tag in row_map.items()}
    _tfidf_tag_vectors = {
        "reduced_matrix": matrix,
        "reduced_matrix_norm": _l2_normalize_rows(matrix).astype(np.float32),
        "row_to_tag": row_map,
        "tag_to_row_index": tag_map,
    }
    return _tfidf_tag_vectors
def retrieval_assets_status() -> Dict[str, bool]:
    """Report, per asset name, whether the corresponding file exists on disk."""
    assets = (
        ("tfidf", TFIDF_PATH),
        ("nsfw_csv", NSFW_CSV_PATH),
        ("fasttext_model", FASTTEXT_MODEL_PATH),
        ("tag_aliases_csv", TAG_ALIASES_PATH),
        ("hnsw_tags", HNSW_TAG_PATH),
        ("hnsw_artists", HNSW_ART_PATH),
    )
    return {name: path.is_file() for name, path in assets}
def _build_or_load_index(path: pathlib.Path, rows: list[int], rm: np.ndarray, dim: int) -> "hnswlib.Index":
    """Return a cosine HNSW index over ``rm[rows]``, reusing the file at ``path`` if valid.

    A saved index is reused only when it loads successfully, ``rows`` is
    non-empty, and its element count equals ``len(rows)``. Otherwise the stale
    file is removed (best effort) and a new index is built with ``rows`` as the
    item labels, then saved to ``path``.

    Args:
        path: Location of the cached index file.
        rows: Row numbers of ``rm`` to index; also used as the HNSW labels.
        rm: Matrix whose rows are the vectors to index.
        dim: Vector dimensionality of the index.

    Returns:
        The loaded or freshly built index, with ``ef`` set to 200.
    """
    capacity = max(1, len(rows))
    if path.exists():
        loaded = hnswlib.Index(space="cosine", dim=dim)
        try:
            loaded.load_index(str(path), max_elements=capacity)
            saved_count = loaded.get_current_count()
        except Exception as e:  # corrupt/incompatible file or old hnswlib: rebuild
            logging.debug("Reload %s failed, rebuilding: %s", path.name, e)
        else:
            # NOTE(review): only the element count is compared; a file built
            # from different row ids or a different dim is not detected here.
            if rows and saved_count == len(rows):
                loaded.set_ef(200)
                return loaded
            logging.debug(
                "Rebuilding %s: saved_count!=rows_len (%s vs %s)",
                path.name,
                saved_count,
                len(rows),
            )
        try:
            path.unlink()
        except OSError as e:
            # Best effort: save_index() below overwrites the file anyway.
            logging.debug("Could not remove stale index %s: %s", path.name, e)
    # Always build into a fresh Index object. hnswlib's init_index() raises
    # "The index is already initiated." on an object that already holds a
    # loaded index, which previously made the count-mismatch path crash.
    idx = hnswlib.Index(space="cosine", dim=dim)
    idx.init_index(max_elements=capacity, ef_construction=200, M=16)
    if rows:
        idx.add_items(rm[rows], ids=np.asarray(rows, dtype=np.int32))
    idx.save_index(str(path))
    idx.set_ef(200)
    return idx
def _ensure_hnsw_indexes(need_artists: bool) -> None:
    """Lazily build or load the HNSW tag index and, if requested, the artist index.

    TF-IDF rows are split into two disjoint groups:

    * artist rows: tags whose name, minus a ``by_`` prefix, is in
      ``get_artist_set()`` (except ``by_unknown_artist`` and
      ``by_conditional_dnp``, which stay with the tags);
    * tag rows: everything else.

    The split is computed the same way whatever ``need_artists`` is, so the
    tag index and its on-disk cache at ``HNSW_TAG_PATH`` have the same
    contents regardless of which getter runs first. Only indexes not yet built
    in this process are built. Does nothing when hnswlib is unavailable.

    Args:
        need_artists: Also ensure the artist index exists.
    """
    global _hnsw_tag_index, _hnsw_artist_index, _hnsw_tag_count, _hnsw_artist_count
    if hnswlib is None:
        return
    have_tags = _hnsw_tag_index is not None
    have_artists = _hnsw_artist_index is not None
    if have_tags and (have_artists or not need_artists):
        return
    components = get_tfidf_components()
    row_to_tag = components["row_to_tag"]
    # _l2_normalize_rows already returns float32, so no extra astype copy.
    rm = _l2_normalize_rows(components["reduced_matrix"])
    n_items, dim = rm.shape
    # Always partition (previously artists were only split out when
    # need_artists was True, so the shared tag-index file flip-flopped between
    # two different contents and was rebuilt whenever callers alternated).
    artist_set = get_artist_set()
    # Presumably placeholder "by_" tags rather than real artists; keep them as tags.
    non_artist_by_tags = {"by_unknown_artist", "by_conditional_dnp"}
    artist_rows: list[int] = []
    tag_rows: list[int] = []
    for i in range(n_items):
        tag = row_to_tag.get(i, "")
        base = tag[3:] if tag.startswith("by_") else tag
        if tag not in non_artist_by_tags and base in artist_set:
            artist_rows.append(i)
        else:
            tag_rows.append(i)
    if not have_tags:
        _hnsw_tag_index = _build_or_load_index(HNSW_TAG_PATH, tag_rows, rm, dim)
        _hnsw_tag_count = len(tag_rows)
    if need_artists and not have_artists:
        _hnsw_artist_index = _build_or_load_index(HNSW_ART_PATH, artist_rows, rm, dim)
        _hnsw_artist_count = len(artist_rows)
def get_hnsw_tag_index() -> Tuple[Optional["hnswlib.Index"], int]:
    """Return ``(tag_index, item_count)``, building or loading the index on first use.

    When hnswlib is not installed the index is None and the count is 0.
    """
    _ensure_hnsw_indexes(need_artists=False)
    index, count = _hnsw_tag_index, _hnsw_tag_count
    return index, count
def get_hnsw_artist_index() -> Tuple[Optional["hnswlib.Index"], int]:
    """Return ``(artist_index, item_count)``, building or loading the index on first use.

    When hnswlib is not installed the index is None and the count is 0.
    """
    _ensure_hnsw_indexes(need_artists=True)
    index, count = _hnsw_artist_index, _hnsw_artist_count
    return index, count