File size: 12,994 Bytes
c6be992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
from __future__ import annotations

import csv
import logging
import pathlib
from typing import Any, Dict, List, Optional, Set, Tuple

import joblib
import numpy as np

try:
    import hnswlib
except Exception:
    hnswlib = None  # allow import on environments without hnswlib during partial tests


# On-disk retrieval assets. These are bare relative paths, so they resolve
# against the process's current working directory.
TFIDF_PATH = pathlib.Path("tf_idf_files_420.joblib")
NSFW_CSV_PATH = pathlib.Path("word_rating_probabilities.csv")
# get_nsfw_tags() keeps words whose column-1 probability is >= this value.
NSFW_THRESHOLD = 0.95

# Cached HNSW index files; _build_or_load_index() rebuilds them when they are
# missing, fail to load, or hold a different number of elements.
HNSW_ART_PATH = pathlib.Path("tfidf_hnsw_artists.bin")
HNSW_TAG_PATH = pathlib.Path("tfidf_hnsw_tags.bin")
FASTTEXT_MODEL_PATH = pathlib.Path("e621FastTextModel010Replacement_small.bin")
# Tag CSV read by several loaders: column 0 = tag, column 1 = type id,
# column 2 = count, column 3 = comma-separated aliases (or "null").
TAG_ALIASES_PATH = pathlib.Path("fluffyrock_3m.csv")

# Module-level caches. Each stays None until its getter loads it for the
# first time; afterwards the getter returns the cached object.
_tfidf_components: Optional[Dict[str, Any]] = None
_nsfw_tags: Optional[Set[str]] = None
_artist_set: Optional[Set[str]] = None
_fasttext_model: Optional[Any] = None
_tag_counts: Optional[Dict[str, int]] = None
_tfidf_tag_vectors: Optional[Dict[str, Any]] = None
_alias_to_tags: Optional[Dict[str, List[str]]] = None
_tag_to_aliases: Optional[Dict[str, List[str]]] = None
_tag_type_id: Optional[Dict[str, int]] = None


# HNSW indexes and their element counts, filled in by _ensure_hnsw_indexes().
# They remain None / 0 when hnswlib could not be imported.
_hnsw_tag_index: Optional["hnswlib.Index"] = None
_hnsw_artist_index: Optional["hnswlib.Index"] = None
_hnsw_tag_count: int = 0
_hnsw_artist_count: int = 0

# Tag type names inferred from e621 wiki documentation.
# Numeric IDs come from fluffyrock_3m.csv column 1; mapping is heuristic but
# matches observed usage on e621.
TAG_TYPE_ID_TO_NAME: Dict[int, str] = {
    0: "general",        # Default tag type: visible attributes, actions, objects, etc.
    1: "artist",         # Artist tags (e.g. by_name, artist_name)
    2: "contributor",    # Contributor tags (rare / possibly unused in this dataset)
    3: "copyright",      # Series, franchise, or IP (e.g. pokemon, winnie_the_pooh)
    4: "character",      # Named characters (e.g. pikachu, pinkie_pie_(mlp))
    5: "species",        # Species tags (e.g. canine, domestic_cat)
    6: "invalid",        # Invalid / disallowed / disambiguation-only tags
    7: "meta",           # Meta / presentation / file / style-related tags
}


def _l2_normalize_rows(mat: np.ndarray) -> np.ndarray:
    """Return a float32 copy of *mat* with every row scaled to unit L2 norm.

    Rows whose norm is exactly zero are divided by 1 instead, so they stay
    all-zero rather than becoming NaN.
    """
    vectors = np.asarray(mat, dtype=np.float32)
    lengths = np.linalg.norm(vectors, axis=1, keepdims=True)
    # Guard against division by zero for all-zero rows.
    np.copyto(lengths, 1.0, where=(lengths == 0.0))
    return np.divide(vectors, lengths)


def _clean_tag_ascii(tag: str) -> str:
    """Return *tag* with every non-ASCII character (code point >= 128) dropped."""
    # The "ignore" error handler silently skips anything ASCII cannot encode.
    return tag.encode("ascii", "ignore").decode("ascii")


def clean_tag(tag: str) -> str:
    """Normalize *tag* the same way legacy alias parsing does.

    Currently this only strips non-ASCII characters. It is kept as the single
    public entry point so CSV parsing and lookups stay consistent.
    """
    normalized = _clean_tag_ascii(tag)
    return normalized


def build_aliases_dict(csv_path: str, reverse: bool = False) -> Dict[str, List[str]]:
    """Build tag/alias mappings from the aliases CSV.

    Column 0 holds the canonical tag. Column 3 holds a comma-separated alias
    list, or the literal ``"null"`` when there are none. Tags and aliases are
    both normalized with ``clean_tag``.

    Args:
        csv_path: Path of the CSV file to read.
        reverse: When False, map ``tag -> [aliases]``. When True, map
            ``alias -> [tags listing that alias]``.

    Returns:
        The requested mapping. Rows with fewer than four columns (including
        blank lines) are skipped. Aliases that are empty after cleaning are
        dropped.
    """
    aliases_dict: Dict[str, List[str]] = {}
    with open(csv_path, "r", newline="", encoding="utf-8") as csvfile:
        for row in csv.reader(csvfile):
            # Blank/short rows used to raise IndexError; skip them like the
            # other loaders of this CSV do.
            if len(row) < 4:
                continue
            tag = clean_tag(row[0])
            raw_aliases = row[3]
            if raw_aliases == "null":
                alias_list: List[str] = []
            else:
                # An empty field, or an alias made only of non-ASCII chars,
                # would otherwise yield a bogus "" alias.
                cleaned = (clean_tag(alias) for alias in raw_aliases.split(","))
                alias_list = [alias for alias in cleaned if alias]
            if reverse:
                for alias in alias_list:
                    aliases_dict.setdefault(alias, []).append(tag)
            else:
                aliases_dict[tag] = alias_list
    return aliases_dict


def get_tfidf_components() -> Dict[str, Any]:
    """Load (once) and return the TF-IDF model components stored at TFIDF_PATH.

    After loading, two normalizations are applied in place:
      * if only ``tag_to_row_index`` is present, its inverse is stored as
        ``row_to_tag``;
      * if ``idf`` is a dict keyed by term, it is replaced by a dense float32
        array indexed by column (via ``tag_to_column_index``), with 1.0 for
        columns whose term has no idf entry.

    Raises:
        FileNotFoundError: if TFIDF_PATH does not exist.
        KeyError: if ``idf`` is a dict but ``tag_to_column_index`` is absent.
    """
    global _tfidf_components
    if _tfidf_components is None:
        if not TFIDF_PATH.is_file():
            raise FileNotFoundError(f"TF-IDF joblib not found: {TFIDF_PATH}")

        # NOTE: joblib.load unpickles the file; only load trusted artifacts.
        loaded = joblib.load(TFIDF_PATH)

        if "row_to_tag" not in loaded and "tag_to_row_index" in loaded:
            loaded["row_to_tag"] = {
                row: name for name, row in loaded["tag_to_row_index"].items()
            }

        raw_idf = loaded.get("idf")
        if isinstance(raw_idf, dict):
            column_of = loaded["tag_to_column_index"]
            dense_idf = np.ones(max(column_of.values()) + 1, dtype=np.float32)
            for term, column in column_of.items():
                dense_idf[column] = float(raw_idf.get(term, 1.0))
            loaded["idf"] = dense_idf

        _tfidf_components = loaded
    return _tfidf_components


def get_nsfw_tags() -> Set[str]:
    """Return the (cached) set of words whose NSFW probability meets NSFW_THRESHOLD.

    NSFW_CSV_PATH is read with its first (header) row skipped. Column 0 is the
    word and column 1 its probability. Blank rows, rows without column 1, and
    rows whose column 1 is not a float are ignored.

    Raises:
        FileNotFoundError: if NSFW_CSV_PATH does not exist.
    """
    global _nsfw_tags
    if _nsfw_tags is None:
        if not NSFW_CSV_PATH.is_file():
            raise FileNotFoundError(f"NSFW tag CSV not found: {NSFW_CSV_PATH}")

        flagged: Set[str] = set()
        with NSFW_CSV_PATH.open("r", newline="", encoding="utf-8") as handle:
            rows = csv.reader(handle)
            next(rows, None)  # skip header
            for fields in rows:
                if not fields:
                    continue
                try:
                    score = float(fields[1])
                except (IndexError, ValueError):
                    continue
                if score >= NSFW_THRESHOLD:
                    flagged.add(fields[0])

        _nsfw_tags = flagged
    return _nsfw_tags


def get_artist_set() -> Set[str]:
    """Return the (cached) set of artist names from the tag CSV.

    Artists are rows whose column-0 tag starts with ``by_``. The prefix is
    stripped, so ``by_foo`` contributes ``"foo"``. Unlike the other loaders, a
    missing CSV is not an error: an empty set is cached and returned.
    """
    global _artist_set
    if _artist_set is not None:
        return _artist_set

    # Use the shared constant instead of re-hard-coding the filename, so every
    # loader of the tag CSV follows TAG_ALIASES_PATH.
    path = TAG_ALIASES_PATH
    if not path.is_file():
        logging.debug("Tag CSV %s not found; artist set will be empty", path)
        _artist_set = set()
        return _artist_set

    artists: Set[str] = set()
    with path.open("r", newline="", encoding="utf-8") as csvfile:
        for row in csv.reader(csvfile):
            if row and row[0].startswith("by_"):
                artists.add(row[0][3:])

    _artist_set = artists
    return _artist_set


def is_artist(name: str) -> bool:
    """Return True if *name* (with any ``by_`` prefix already removed) is a known artist."""
    known_artists = get_artist_set()
    return name in known_artists


def get_fasttext_model() -> Any:
    """Load (once) and return the compressed FastText model at FASTTEXT_MODEL_PATH.

    ``compress_fasttext`` is imported only when the model is first loaded.

    Raises:
        FileNotFoundError: if the model file does not exist.
    """
    global _fasttext_model
    if _fasttext_model is None:
        if not FASTTEXT_MODEL_PATH.is_file():
            raise FileNotFoundError(f"FastText model not found: {FASTTEXT_MODEL_PATH}")

        import compress_fasttext

        vectors_cls = compress_fasttext.models.CompressedFastTextKeyedVectors
        _fasttext_model = vectors_cls.load(str(FASTTEXT_MODEL_PATH))
    return _fasttext_model


def get_tag_type_ids() -> Dict[str, int]:
    """Return canonical tag -> type_id (int) from fluffyrock_3m.csv.

    Tags (column 0) are normalized with ``clean_tag``. Column 1 is parsed as an
    int. Rows that are blank, have no column 1, or whose column 1 is not an
    integer are skipped. The result is cached after the first call.

    Raises:
        FileNotFoundError: if TAG_ALIASES_PATH does not exist.
    """
    global _tag_type_id
    if _tag_type_id is None:
        if not TAG_ALIASES_PATH.is_file():
            raise FileNotFoundError(f"Tag CSV not found: {TAG_ALIASES_PATH}")

        type_by_tag: Dict[str, int] = {}
        with TAG_ALIASES_PATH.open("r", newline="", encoding="utf-8") as handle:
            for fields in csv.reader(handle):
                if len(fields) < 2:
                    continue  # blank line or missing type column
                try:
                    type_id = int(fields[1])
                except ValueError:
                    continue
                type_by_tag[clean_tag(fields[0])] = type_id

        _tag_type_id = type_by_tag
    return _tag_type_id


def get_tag_type_name(tag: str) -> Optional[str]:
    """Return heuristic type name for a tag (e.g. 'artist', 'character'), or None.

    Unknown tags yield None. Known tags with an id missing from
    TAG_TYPE_ID_TO_NAME yield ``"type_<id>"``.
    """
    type_id = get_tag_type_ids().get(clean_tag(tag))
    if type_id is not None:
        return TAG_TYPE_ID_TO_NAME.get(type_id, f"type_{type_id}")
    return None


def get_tag_counts() -> Dict[str, int]:
    """Return tag -> count from column 2 of TAG_ALIASES_PATH (cached).

    Tags are taken from column 0 as-is (not passed through ``clean_tag``).
    Rows that are blank, have fewer than three columns, or whose count is not
    a plain non-negative integer are skipped.

    Raises:
        FileNotFoundError: if TAG_ALIASES_PATH does not exist.
    """
    global _tag_counts
    if _tag_counts is not None:
        return _tag_counts

    if not TAG_ALIASES_PATH.is_file():
        raise FileNotFoundError(f"Tag count CSV not found: {TAG_ALIASES_PATH}")

    tag_counts: Dict[str, int] = {}
    with TAG_ALIASES_PATH.open("r", newline="", encoding="utf-8") as csvfile:
        for row in csv.reader(csvfile):
            # Short rows previously raised IndexError on row[2].
            if len(row) < 3:
                continue
            raw_count = row[2]
            # isdecimal(), not isdigit(): isdigit() accepts characters such as
            # "²" that int() rejects with ValueError.
            if raw_count.isdecimal():
                tag_counts[row[0]] = int(raw_count)

    _tag_counts = tag_counts
    return _tag_counts


def get_alias2tags() -> Dict[str, List[str]]:
    """Return alias -> [canonical tags] mapping.

    Built once from TAG_ALIASES_PATH with ``build_aliases_dict(reverse=True)``
    and cached in a module global.

    Raises:
        FileNotFoundError: if the alias CSV does not exist.
    """
    global _alias_to_tags
    if _alias_to_tags is None:
        if not TAG_ALIASES_PATH.is_file():
            raise FileNotFoundError(f"Tag alias CSV not found: {TAG_ALIASES_PATH}")
        _alias_to_tags = build_aliases_dict(str(TAG_ALIASES_PATH), reverse=True)
    return _alias_to_tags


def get_tag2aliases() -> Dict[str, List[str]]:
    """Return canonical tag -> [aliases] mapping.

    Built once from TAG_ALIASES_PATH with ``build_aliases_dict(reverse=False)``
    and cached in a module global.

    Raises:
        FileNotFoundError: if the alias CSV does not exist.
    """
    global _tag_to_aliases
    if _tag_to_aliases is None:
        if not TAG_ALIASES_PATH.is_file():
            raise FileNotFoundError(f"Tag alias CSV not found: {TAG_ALIASES_PATH}")
        _tag_to_aliases = build_aliases_dict(str(TAG_ALIASES_PATH), reverse=False)
    return _tag_to_aliases


def get_tfidf_tag_vectors() -> Dict[str, Any]:
    """Return (cached) TF-IDF tag vectors with a row-normalized copy and index maps.

    Returns:
        Dict with keys:
          - ``"reduced_matrix"``: the matrix from the TF-IDF components, as loaded.
          - ``"reduced_matrix_norm"``: float32 copy with each row scaled to unit
            L2 norm (all-zero rows stay zero).
          - ``"row_to_tag"``: row index -> tag.
          - ``"tag_to_row_index"``: tag -> row index.

    Raises:
        KeyError: if the components lack ``reduced_matrix`` or any way to
            derive the row -> tag mapping.
    """
    global _tfidf_tag_vectors
    if _tfidf_tag_vectors is not None:
        return _tfidf_tag_vectors

    components = get_tfidf_components()
    reduced_matrix = components.get("reduced_matrix")
    if reduced_matrix is None:
        raise KeyError("TF-IDF components missing reduced_matrix")

    # Defensive fallback; get_tfidf_components() normally derives this already.
    row_to_tag = components.get("row_to_tag")
    if row_to_tag is None and "tag_to_row_index" in components:
        row_to_tag = {idx: tag for tag, idx in components["tag_to_row_index"].items()}
    if row_to_tag is None:
        raise KeyError("TF-IDF components missing row_to_tag mapping")

    tag_to_row_index = components.get("tag_to_row_index")
    if tag_to_row_index is None:
        tag_to_row_index = {tag: idx for idx, tag in row_to_tag.items()}

    # _l2_normalize_rows already returns float32, so an extra .astype() would
    # only make a second full-size copy of a potentially large matrix.
    reduced_matrix_norm = _l2_normalize_rows(reduced_matrix)

    _tfidf_tag_vectors = {
        "reduced_matrix": reduced_matrix,
        "reduced_matrix_norm": reduced_matrix_norm,
        "row_to_tag": row_to_tag,
        "tag_to_row_index": tag_to_row_index,
    }
    return _tfidf_tag_vectors


def retrieval_assets_status() -> Dict[str, bool]:
    """Report which retrieval asset files currently exist on disk.

    Returns:
        Mapping of asset name -> True when its path is an existing regular file.
    """
    asset_paths = (
        ("tfidf", TFIDF_PATH),
        ("nsfw_csv", NSFW_CSV_PATH),
        ("fasttext_model", FASTTEXT_MODEL_PATH),
        ("tag_aliases_csv", TAG_ALIASES_PATH),
        ("hnsw_tags", HNSW_TAG_PATH),
        ("hnsw_artists", HNSW_ART_PATH),
    )
    return {name: asset_path.is_file() for name, asset_path in asset_paths}


def _build_or_load_index(path: pathlib.Path, rows: list[int], rm: np.ndarray, dim: int) -> "hnswlib.Index":
    """Return a cosine-space HNSW index over ``rm[rows]``, cached at *path*.

    An existing file at *path* is reused when it loads cleanly, holds exactly
    ``len(rows)`` elements, and ``rows`` is non-empty. Otherwise the file is
    removed, and a new index is built with item ids equal to the row numbers
    and saved to *path*. ``ef`` is set to 200 in both cases.

    Args:
        path: Location of the cached index file.
        rows: Row indices of ``rm`` to include; also used as the item ids.
        rm: Matrix whose rows are the vectors (callers pass L2-normalized float32).
        dim: Vector dimensionality (number of columns of ``rm``).
    """
    capacity = max(1, len(rows))

    if path.exists():
        loaded = hnswlib.Index(space="cosine", dim=dim)
        try:
            loaded.load_index(str(path), max_elements=capacity)
            saved_count = loaded.get_current_count()
        except Exception as e:
            logging.debug("Reload %s failed, rebuilding: %s", path.name, e)
        else:
            # NOTE(review): only the element count is compared; a stale file
            # with the same count but different rows would still be reused.
            if rows and saved_count == len(rows):
                loaded.set_ef(200)
                return loaded
            logging.debug(
                "Rebuilding %s: saved_count!=rows_len (%s vs %s)",
                path.name,
                saved_count,
                len(rows),
            )
        try:
            path.unlink()
        except OSError as e:
            logging.debug("Could not remove stale index %s: %s", path.name, e)

    # Always build into a fresh Index: hnswlib raises "The index is already
    # initiated." if init_index() is called after a successful load_index().
    idx = hnswlib.Index(space="cosine", dim=dim)
    idx.init_index(max_elements=capacity, ef_construction=200, M=16)
    if rows:
        idx.add_items(rm[rows], ids=np.asarray(rows, dtype=np.int32))
    idx.save_index(str(path))
    idx.set_ef(200)
    return idx


def _ensure_hnsw_indexes(need_artists: bool) -> None:
    """Build or load whichever module-level HNSW indexes are still missing.

    TF-IDF rows are split into artist rows and all other rows. The tag index
    (other rows) is always ensured. The artist index is ensured only when
    *need_artists* is True. Does nothing if hnswlib could not be imported.
    Indexes that already exist are left untouched.
    """
    global _hnsw_tag_index, _hnsw_artist_index, _hnsw_tag_count, _hnsw_artist_count

    if hnswlib is None:
        return

    build_tags = _hnsw_tag_index is None
    build_artists = need_artists and _hnsw_artist_index is None
    if not build_tags and not build_artists:
        return

    components = get_tfidf_components()
    row_to_tag = components["row_to_tag"]
    # _l2_normalize_rows already returns float32; no extra copy needed.
    rm = _l2_normalize_rows(components["reduced_matrix"])
    n_items, dim = rm.shape

    # Always partition with the real artist set. Previously a tag-only request
    # skipped it, so artists landed in the tag index, and a later artist
    # request rebuilt the tag index without them. That made HNSW_TAG_PATH's
    # contents depend on call order and rebuilt the file repeatedly.
    artist_set = get_artist_set()
    non_artist_by_tags = {"by_unknown_artist", "by_conditional_dnp"}
    artist_rows: list[int] = []
    tag_rows: list[int] = []

    for i in range(n_items):
        tag = row_to_tag.get(i, "")
        base = tag[3:] if tag.startswith("by_") else tag
        # NOTE(review): an unprefixed tag spelled exactly like an artist name
        # is also routed to the artist index; confirm this is intended.
        if tag not in non_artist_by_tags and base in artist_set:
            artist_rows.append(i)
        else:
            tag_rows.append(i)

    if build_tags:
        _hnsw_tag_index = _build_or_load_index(HNSW_TAG_PATH, tag_rows, rm, dim)
        _hnsw_tag_count = len(tag_rows)

    if build_artists:
        _hnsw_artist_index = _build_or_load_index(HNSW_ART_PATH, artist_rows, rm, dim)
        _hnsw_artist_count = len(artist_rows)


def get_hnsw_tag_index() -> Tuple[Optional["hnswlib.Index"], int]:
    """Return ``(tag_index, element_count)``, building or loading it on first use.

    The index is None (and the count 0) when hnswlib is unavailable.
    """
    _ensure_hnsw_indexes(need_artists=False)
    index, count = _hnsw_tag_index, _hnsw_tag_count
    return index, count


def get_hnsw_artist_index() -> Tuple[Optional["hnswlib.Index"], int]:
    """Return ``(artist_index, element_count)``, building or loading it on first use.

    The index is None (and the count 0) when hnswlib is unavailable.
    """
    _ensure_hnsw_indexes(need_artists=True)
    index, count = _hnsw_artist_index, _hnsw_artist_count
    return index, count