Spaces:
Running
Running
| from __future__ import annotations | |
| import csv | |
| import logging | |
| import pathlib | |
| from typing import Any, Dict, List, Optional, Set, Tuple | |
| import joblib | |
| import numpy as np | |
| try: | |
| import hnswlib | |
| except Exception: | |
| hnswlib = None # allow import on environments without hnswlib during partial tests | |
# --- Asset locations -------------------------------------------------------
# All paths are relative, so they resolve against the process's current
# working directory at the time they are used.
TFIDF_PATH = pathlib.Path("tf_idf_files_420.joblib")  # loaded by get_tfidf_components()
NSFW_CSV_PATH = pathlib.Path("word_rating_probabilities.csv")  # read by get_nsfw_tags()
NSFW_THRESHOLD = 0.95  # get_nsfw_tags() keeps rows whose column 1 is >= this value
HNSW_ART_PATH = pathlib.Path("tfidf_hnsw_artists.bin")  # persisted HNSW index for artist rows
HNSW_TAG_PATH = pathlib.Path("tfidf_hnsw_tags.bin")  # persisted HNSW index for the remaining rows
FASTTEXT_MODEL_PATH = pathlib.Path("e621FastTextModel010Replacement_small.bin")  # loaded by get_fasttext_model()
# Tag CSV; the readers in this module use column 0 as the tag name, column 1 as
# the numeric type id, column 2 as the count and column 3 as the alias list.
TAG_ALIASES_PATH = pathlib.Path("fluffyrock_3m.csv")

# --- Lazily populated caches -----------------------------------------------
# Each starts as None and is filled by the matching get_*() function on first
# use; later calls return the cached object.
_tfidf_components: Optional[Dict[str, Any]] = None
_nsfw_tags: Optional[Set[str]] = None
_artist_set: Optional[Set[str]] = None
_fasttext_model: Optional[Any] = None
_tag_counts: Optional[Dict[str, int]] = None
_tfidf_tag_vectors: Optional[Dict[str, Any]] = None
_alias_to_tags: Optional[Dict[str, List[str]]] = None
_tag_to_aliases: Optional[Dict[str, List[str]]] = None
_tag_type_id: Optional[Dict[str, int]] = None
# HNSW indexes and the number of rows assigned to each; set by
# _ensure_hnsw_indexes(). They keep their initial values (None / 0) when
# hnswlib could not be imported.
_hnsw_tag_index: Optional["hnswlib.Index"] = None
_hnsw_artist_index: Optional["hnswlib.Index"] = None
_hnsw_tag_count: int = 0
_hnsw_artist_count: int = 0
# Tag type names inferred from e621 wiki documentation.
# Numeric IDs come from fluffyrock_3m.csv column 1; the mapping is heuristic
# but matches observed usage on e621.
# get_tag_type_name() falls back to "type_<id>" for IDs that are not listed here.
TAG_TYPE_ID_TO_NAME: Dict[int, str] = {
    0: "general",  # Default tag type: visible attributes, actions, objects, etc.
    1: "artist",  # Artist tags (e.g. by_name, artist_name)
    2: "contributor",  # Contributor tags (rare / possibly unused in this dataset)
    3: "copyright",  # Series, franchise, or IP (e.g. pokemon, winnie_the_pooh)
    4: "character",  # Named characters (e.g. pikachu, pinkie_pie_(mlp))
    5: "species",  # Species tags (e.g. canine, domestic_cat)
    6: "invalid",  # Invalid / disallowed / disambiguation-only tags
    7: "meta",  # Meta / presentation / file / style-related tags
}
| def _l2_normalize_rows(mat: np.ndarray) -> np.ndarray: | |
| mat = np.asarray(mat, dtype=np.float32) | |
| norms = np.linalg.norm(mat, axis=1, keepdims=True) | |
| norms[norms == 0.0] = 1.0 | |
| return mat / norms | |
| def _clean_tag_ascii(tag: str) -> str: | |
| return "".join(char for char in tag if ord(char) < 128) | |
def clean_tag(tag: str) -> str:
    """Return the normalized form of ``tag``.

    Normalization is delegated to ``_clean_tag_ascii`` so tags are treated
    consistently with the legacy alias parsing.
    """
    normalized = _clean_tag_ascii(tag)
    return normalized
def build_aliases_dict(csv_path: str, reverse: bool = False) -> Dict[str, List[str]]:
    """Build tag/alias mappings from the aliases CSV.

    Column 0 of each row is the tag and column 3 is either the literal
    ``null`` or a comma-separated alias list; other columns are ignored.
    Tags and aliases are passed through ``clean_tag``.

    Args:
        csv_path: Path of the UTF-8 CSV file to read.
        reverse: When false (default) return ``tag -> [aliases]``; when true
            return ``alias -> [tags]``.

    Returns:
        The requested mapping. Blank rows are skipped, a row without an alias
        column counts as having no aliases, and aliases that are empty after
        cleaning are dropped.

    Raises:
        OSError: if ``csv_path`` cannot be opened.
    """
    aliases_dict: Dict[str, List[str]] = {}
    with open(csv_path, "r", newline="", encoding="utf-8") as csvfile:
        for row in csv.reader(csvfile):
            if not row:
                # Blank line; the other CSV readers in this module skip these too.
                continue
            tag = clean_tag(row[0])
            # Short rows used to raise IndexError; treat them as "no aliases".
            raw_aliases = row[3] if len(row) > 3 else "null"
            if raw_aliases == "null":
                alias_list: List[str] = []
            else:
                cleaned = (clean_tag(alias) for alias in raw_aliases.split(","))
                # An empty field or a trailing comma would otherwise register
                # "" as an alias.
                alias_list = [alias for alias in cleaned if alias]
            if reverse:
                for alias in alias_list:
                    aliases_dict.setdefault(alias, []).append(tag)
            else:
                aliases_dict[tag] = alias_list
    return aliases_dict
def get_tfidf_components() -> Dict[str, Any]:
    """Load the TF-IDF component dict from ``TFIDF_PATH`` once and cache it.

    Two adjustments are made to the loaded dict, in place:

    * ``row_to_tag`` is derived by inverting ``tag_to_row_index`` when only
      the latter is present.
    * If ``idf`` is a dict keyed by term, it is replaced by a float32 array
      indexed by column (via ``tag_to_column_index``); terms without an IDF
      entry get 1.0.

    Raises:
        FileNotFoundError: if ``TFIDF_PATH`` is not a file.
    """
    global _tfidf_components
    if _tfidf_components is None:
        if not TFIDF_PATH.is_file():
            raise FileNotFoundError(f"TF-IDF joblib not found: {TFIDF_PATH}")
        # NOTE: joblib.load unpickles the file, so it must only be pointed at
        # trusted data.
        loaded = joblib.load(TFIDF_PATH)

        has_row_index = "tag_to_row_index" in loaded
        if has_row_index and "row_to_tag" not in loaded:
            inverse = {}
            for name, row in loaded["tag_to_row_index"].items():
                inverse[row] = name
            loaded["row_to_tag"] = inverse

        idf_table = loaded.get("idf")
        if isinstance(idf_table, dict):
            column_of = loaded["tag_to_column_index"]
            width = max(column_of.values()) + 1
            dense_idf = np.ones(width, dtype=np.float32)
            for term, column in column_of.items():
                dense_idf[column] = float(idf_table.get(term, 1.0))
            loaded["idf"] = dense_idf

        _tfidf_components = loaded
    return _tfidf_components
def get_nsfw_tags() -> Set[str]:
    """Return the cached set of words rated at or above ``NSFW_THRESHOLD``.

    ``NSFW_CSV_PATH`` is read on first use: the first row is skipped as a
    header, column 0 is the word and column 1 its score. Rows that are
    empty, lack a score, or have a non-numeric score are ignored.

    Raises:
        FileNotFoundError: if ``NSFW_CSV_PATH`` is not a file.
    """
    global _nsfw_tags
    if _nsfw_tags is None:
        if not NSFW_CSV_PATH.is_file():
            raise FileNotFoundError(f"NSFW tag CSV not found: {NSFW_CSV_PATH}")
        flagged: Set[str] = set()
        with NSFW_CSV_PATH.open("r", newline="", encoding="utf-8") as handle:
            records = csv.reader(handle)
            next(records, None)  # discard the header row
            for record in records:
                if len(record) < 2:
                    continue
                try:
                    score = float(record[1])
                except ValueError:
                    continue
                if score >= NSFW_THRESHOLD:
                    flagged.add(record[0])
        _nsfw_tags = flagged
    return _nsfw_tags
def get_artist_set() -> Set[str]:
    """Return the cached set of artist names found in ``TAG_ALIASES_PATH``.

    A row counts as an artist when its first column starts with ``by_``; the
    prefix is stripped before the name is stored.

    Unlike the other loaders in this module, a missing CSV is not an error:
    an empty set is cached and returned, and a warning is logged so the
    degraded state is visible.
    """
    global _artist_set
    if _artist_set is not None:
        return _artist_set
    # Use the shared constant rather than repeating the file name, so every
    # reader of the tag CSV looks at the same file.
    if not TAG_ALIASES_PATH.is_file():
        logging.warning(
            "Tag CSV not found (%s); treating the artist set as empty",
            TAG_ALIASES_PATH,
        )
        _artist_set = set()
        return _artist_set
    prefix = "by_"
    with TAG_ALIASES_PATH.open("r", newline="", encoding="utf-8") as csvfile:
        artists: Set[str] = {
            row[0][len(prefix):]
            for row in csv.reader(csvfile)
            if row and row[0].startswith(prefix)
        }
    _artist_set = artists
    return _artist_set
def is_artist(name: str) -> bool:
    """Return True when ``name`` is a member of the set from ``get_artist_set()``."""
    known_artists = get_artist_set()
    return name in known_artists
def get_fasttext_model() -> Any:
    """Load the compressed FastText vectors from ``FASTTEXT_MODEL_PATH`` once.

    ``compress_fasttext`` is imported inside the function, so it is only
    needed when the model is actually requested.

    Raises:
        FileNotFoundError: if ``FASTTEXT_MODEL_PATH`` is not a file.
    """
    global _fasttext_model
    if _fasttext_model is None:
        if not FASTTEXT_MODEL_PATH.is_file():
            raise FileNotFoundError(f"FastText model not found: {FASTTEXT_MODEL_PATH}")
        import compress_fasttext

        vectors_cls = compress_fasttext.models.CompressedFastTextKeyedVectors
        _fasttext_model = vectors_cls.load(str(FASTTEXT_MODEL_PATH))
    return _fasttext_model
def get_tag_type_ids() -> Dict[str, int]:
    """Return the cached ``tag -> type_id`` mapping read from ``TAG_ALIASES_PATH``.

    Column 0 (passed through ``clean_tag``) is the key and column 1, parsed
    as an int, is the value. Rows that are empty, have no second column, or
    whose second column is not an integer are skipped.

    Raises:
        FileNotFoundError: if ``TAG_ALIASES_PATH`` is not a file.
    """
    global _tag_type_id
    if _tag_type_id is None:
        if not TAG_ALIASES_PATH.is_file():
            raise FileNotFoundError(f"Tag CSV not found: {TAG_ALIASES_PATH}")
        type_by_tag: Dict[str, int] = {}
        with TAG_ALIASES_PATH.open("r", newline="", encoding="utf-8") as handle:
            for record in csv.reader(handle):
                if len(record) < 2:
                    continue
                try:
                    parsed_id = int(record[1])
                except ValueError:
                    continue
                type_by_tag[clean_tag(record[0])] = parsed_id
        _tag_type_id = type_by_tag
    return _tag_type_id
def get_tag_type_name(tag: str) -> Optional[str]:
    """Return the heuristic type name of ``tag`` (e.g. 'artist'), or None if unknown.

    The tag is cleaned with ``clean_tag`` before lookup. A type id with no
    entry in ``TAG_TYPE_ID_TO_NAME`` is reported as ``type_<id>``.
    """
    type_id = get_tag_type_ids().get(clean_tag(tag))
    if type_id is not None:
        fallback = f"type_{type_id}"
        return TAG_TYPE_ID_TO_NAME.get(type_id, fallback)
    return None
def get_tag_counts() -> Dict[str, int]:
    """Return the cached ``tag -> count`` mapping read from ``TAG_ALIASES_PATH``.

    Column 0 is the key and column 2 the count. Rows that are empty, have
    fewer than three columns, or whose count is not a plain non-negative
    decimal number are skipped.

    Raises:
        FileNotFoundError: if ``TAG_ALIASES_PATH`` is not a file.
    """
    global _tag_counts
    if _tag_counts is not None:
        return _tag_counts
    if not TAG_ALIASES_PATH.is_file():
        raise FileNotFoundError(f"Tag count CSV not found: {TAG_ALIASES_PATH}")
    tag_counts: Dict[str, int] = {}
    with TAG_ALIASES_PATH.open("r", newline="", encoding="utf-8") as csvfile:
        for row in csv.reader(csvfile):
            # Short rows previously raised IndexError on row[2]; skip them the
            # same way get_tag_type_ids() skips rows it cannot use.
            if len(row) < 3:
                continue
            raw_count = row[2]
            # isdecimal() rather than isdigit(): isdigit() also accepts
            # characters such as superscript digits that int() rejects.
            if raw_count.isdecimal():
                # NOTE(review): keys are the raw column-0 text, whereas
                # get_tag_type_ids() keys by clean_tag(); left unchanged so
                # existing lookups keep working.
                tag_counts[row[0]] = int(raw_count)
    _tag_counts = tag_counts
    return _tag_counts
def get_alias2tags() -> Dict[str, List[str]]:
    """Return the cached ``alias -> [canonical tags]`` mapping.

    Built on first use by ``build_aliases_dict(..., reverse=True)`` from
    ``TAG_ALIASES_PATH``.

    Raises:
        FileNotFoundError: if ``TAG_ALIASES_PATH`` is not a file.
    """
    global _alias_to_tags
    if _alias_to_tags is None:
        if not TAG_ALIASES_PATH.is_file():
            raise FileNotFoundError(f"Tag alias CSV not found: {TAG_ALIASES_PATH}")
        source = str(TAG_ALIASES_PATH)
        _alias_to_tags = build_aliases_dict(source, reverse=True)
    return _alias_to_tags
def get_tag2aliases() -> Dict[str, List[str]]:
    """Return the cached ``canonical tag -> [aliases]`` mapping.

    Built on first use by ``build_aliases_dict(..., reverse=False)`` from
    ``TAG_ALIASES_PATH``.

    Raises:
        FileNotFoundError: if ``TAG_ALIASES_PATH`` is not a file.
    """
    global _tag_to_aliases
    if _tag_to_aliases is None:
        if not TAG_ALIASES_PATH.is_file():
            raise FileNotFoundError(f"Tag alias CSV not found: {TAG_ALIASES_PATH}")
        source = str(TAG_ALIASES_PATH)
        _tag_to_aliases = build_aliases_dict(source, reverse=False)
    return _tag_to_aliases
def get_tfidf_tag_vectors() -> Dict[str, Any]:
    """Return the cached per-tag vector bundle derived from the TF-IDF components.

    The returned dict has four entries:

    * ``reduced_matrix``: the matrix exactly as stored in the components.
    * ``reduced_matrix_norm``: a float32 copy with every row L2-normalized.
    * ``row_to_tag`` / ``tag_to_row_index``: the two directions of the
      row/tag mapping; whichever is missing is derived from the other.

    Raises:
        KeyError: if ``reduced_matrix`` is missing, or if neither mapping
            is available.
    """
    global _tfidf_tag_vectors
    if _tfidf_tag_vectors is None:
        parts = get_tfidf_components()

        matrix = parts.get("reduced_matrix")
        if matrix is None:
            raise KeyError("TF-IDF components missing reduced_matrix")

        tag_of_row = parts.get("row_to_tag")
        if tag_of_row is None:
            if "tag_to_row_index" not in parts:
                raise KeyError("TF-IDF components missing row_to_tag mapping")
            tag_of_row = {row: name for name, row in parts["tag_to_row_index"].items()}

        row_of_tag = parts.get("tag_to_row_index")
        if row_of_tag is None:
            row_of_tag = {name: row for row, name in tag_of_row.items()}

        unit_rows = _l2_normalize_rows(matrix).astype(np.float32)
        _tfidf_tag_vectors = {
            "reduced_matrix": matrix,
            "reduced_matrix_norm": unit_rows,
            "row_to_tag": tag_of_row,
            "tag_to_row_index": row_of_tag,
        }
    return _tfidf_tag_vectors
def retrieval_assets_status() -> Dict[str, bool]:
    """Report, per retrieval asset, whether its file currently exists on disk."""
    asset_paths = (
        ("tfidf", TFIDF_PATH),
        ("nsfw_csv", NSFW_CSV_PATH),
        ("fasttext_model", FASTTEXT_MODEL_PATH),
        ("tag_aliases_csv", TAG_ALIASES_PATH),
        ("hnsw_tags", HNSW_TAG_PATH),
        ("hnsw_artists", HNSW_ART_PATH),
    )
    return {label: location.is_file() for label, location in asset_paths}
def _build_or_load_index(path: pathlib.Path, rows: list[int], rm: np.ndarray, dim: int) -> "hnswlib.Index":
    """Return a cosine-space HNSW index over ``rm[rows]``, reusing ``path`` if valid.

    The file at ``path`` is reused only when it loads successfully and its
    element count equals ``len(rows)`` (and ``rows`` is non-empty). Otherwise
    the index is rebuilt from ``rm`` and written back to ``path``.

    Args:
        path: Location of the persisted index file.
        rows: Row numbers of ``rm`` to index; they are also used as item ids.
        rm: Matrix whose rows are the vectors to index.
        dim: Vector dimensionality passed to hnswlib.

    Returns:
        An index with its query-time ``ef`` set to 200.

    NOTE(review): the cache check compares element counts only. A saved index
    built from a different set of rows of the same size would be reused.
    """
    capacity = max(1, len(rows))
    idx = hnswlib.Index(space="cosine", dim=dim)

    if path.exists():
        try:
            idx.load_index(str(path), max_elements=capacity)
            saved_count = idx.get_current_count()
        except Exception as e:
            # Corrupt or incompatible file: fall through to a rebuild.
            logging.debug("Reload %s failed, rebuilding: %s", path.name, e)
        else:
            if rows and saved_count == len(rows):
                idx.set_ef(200)
                return idx
            logging.debug(
                "Rebuilding %s: saved_count!=rows_len (%s vs %s)",
                path.name,
                saved_count,
                len(rows),
            )

    # Start from a fresh object: hnswlib refuses init_index() on an Index that
    # already holds a loaded graph, which is the state after a stale load.
    idx = hnswlib.Index(space="cosine", dim=dim)
    try:
        if path.exists():
            path.unlink()
    except OSError as e:
        # Best effort only; save_index() below overwrites the file anyway.
        logging.debug("Could not remove stale index %s: %s", path.name, e)

    idx.init_index(max_elements=capacity, ef_construction=200, M=16)
    if rows:
        idx.add_items(rm[rows], ids=np.asarray(rows, dtype=np.int32))
    try:
        idx.save_index(str(path))
    except (OSError, RuntimeError) as e:
        # The file is only a cache; the in-memory index is still usable.
        logging.warning("Could not persist HNSW index %s: %s", path.name, e)
    idx.set_ef(200)
    return idx
def _ensure_hnsw_indexes(need_artists: bool) -> None:
    """Build or load whichever module-level HNSW indexes are still missing.

    Rows of the TF-IDF reduced matrix are L2-normalized and split into artist
    rows and tag rows. A row is an artist row when its tag, with any ``by_``
    prefix removed, is in ``get_artist_set()``; ``by_unknown_artist`` and
    ``by_conditional_dnp`` always count as tag rows.

    The split no longer depends on ``need_artists``. Previously a call with
    ``need_artists=False`` put every row in the tag index, and a later call
    with ``need_artists=True`` rebuilt it without the artist rows, so the tag
    index and its on-disk cache changed with call order.

    Args:
        need_artists: When true, the artist index is built as well. The tag
            index is built on the first call either way.

    Does nothing when hnswlib could not be imported.
    """
    global _hnsw_tag_index, _hnsw_artist_index, _hnsw_tag_count, _hnsw_artist_count
    if hnswlib is None:
        return
    build_tags = _hnsw_tag_index is None
    build_artists = need_artists and _hnsw_artist_index is None
    if not (build_tags or build_artists):
        return

    components = get_tfidf_components()
    row_to_tag = components["row_to_tag"]
    rm = _l2_normalize_rows(components["reduced_matrix"]).astype(np.float32)
    n_items, dim = rm.shape

    # Fetch the set once instead of going through is_artist() for every row.
    artist_set = get_artist_set()
    never_artists = {"by_unknown_artist", "by_conditional_dnp"}
    artist_rows: List[int] = []
    tag_rows: List[int] = []
    for i in range(n_items):
        tag = row_to_tag.get(i, "")
        base = tag[3:] if tag.startswith("by_") else tag
        if tag not in never_artists and base in artist_set:
            artist_rows.append(i)
        else:
            tag_rows.append(i)

    # Only construct what is missing, so an existing tag index is not rebuilt
    # when the artist index is requested later.
    if build_tags:
        _hnsw_tag_index = _build_or_load_index(HNSW_TAG_PATH, tag_rows, rm, dim)
        _hnsw_tag_count = len(tag_rows)
    if build_artists:
        _hnsw_artist_index = _build_or_load_index(HNSW_ART_PATH, artist_rows, rm, dim)
        _hnsw_artist_count = len(artist_rows)
def get_hnsw_tag_index() -> Tuple[Optional["hnswlib.Index"], int]:
    """Return ``(tag_index, tag_row_count)``, preparing the tag index first.

    Preparation is delegated to ``_ensure_hnsw_indexes(need_artists=False)``.
    The module-level values are returned as they stand afterwards, so the
    index may still be ``None``.
    """
    _ensure_hnsw_indexes(need_artists=False)
    index_and_count = (_hnsw_tag_index, _hnsw_tag_count)
    return index_and_count
def get_hnsw_artist_index() -> Tuple[Optional["hnswlib.Index"], int]:
    """Return ``(artist_index, artist_row_count)``, preparing the artist index first.

    Preparation is delegated to ``_ensure_hnsw_indexes(need_artists=True)``.
    The module-level values are returned as they stand afterwards, so the
    index may still be ``None``.
    """
    _ensure_hnsw_indexes(need_artists=True)
    index_and_count = (_hnsw_artist_index, _hnsw_artist_count)
    return index_and_count