Food Desert committed on
Commit
e2ed0c1
·
1 Parent(s): 06a3c46

Add n30 caption-evident set and per-group display blend overrides

Browse files
data/analysis/per_group_weight_tuning_caption_evident_n30.csv ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ group_name,support_samples,best_w_tfidf,best_w_fasttext,best_ndcg,ndcg_fasttext_only,ndcg_tfidf_only,delta_best_minus_fasttext,delta_tfidf_minus_fasttext
2
+ clothing,1,1.0,0.0,1.0,0.386853,1.0,0.613147,0.613147
3
+ expression_detail,11,0.2,0.8,0.818522,0.797766,0.677285,0.020756,-0.120481
4
+ body_type,18,1.0,0.0,0.979496,0.958992,0.979496,0.020504,0.020504
5
+ pose_action_detail,7,0.8,0.2,0.973933,0.959446,0.971103,0.014486,0.011657
6
+ gender,17,0.7,0.3,0.93487,0.927168,0.905458,0.007702,-0.02171
7
+ gaze_detail,4,0.1,0.9,0.8125,0.811163,0.633429,0.001337,-0.177734
8
+ hair,11,0.0,1.0,1.0,1.0,0.966448,0.0,-0.033552
9
+ general_activity_if_any,3,1.0,0.0,1.0,1.0,1.0,0.0,0.0
10
+ gaze,3,0.5,0.5,1.0,1.0,0.833333,0.0,-0.166667
11
+ fur_style,1,0.9,0.1,1.0,1.0,0.63093,0.0,-0.36907
12
+ limbs,1,1.0,0.0,1.0,1.0,1.0,0.0,0.0
13
+ posture,3,1.0,0.0,1.0,1.0,1.0,0.0,0.0
14
+ body_decor,3,0.9,0.1,1.0,1.0,0.876977,0.0,-0.123023
15
+ franchise_series,2,1.0,0.0,1.0,1.0,1.0,0.0,0.0
16
+ count,27,1.0,0.0,0.986331,0.986331,0.986331,0.0,0.0
17
+ species,16,0.0,1.0,0.957772,0.957772,0.896342,0.0,-0.06143
18
+ expression,7,0.5,0.5,0.894551,0.894551,0.723087,0.0,-0.171465
19
+ clothing_detail,12,0.0,1.0,0.856749,0.856749,0.658325,0.0,-0.198423
20
+ color_markings,12,0.0,1.0,0.838306,0.838306,0.772345,0.0,-0.065961
21
+ anatomy_features,12,0.0,1.0,0.803319,0.803319,0.728669,0.0,-0.074649
22
+ objects_props,6,0.0,1.0,0.645258,0.645258,0.570952,0.0,-0.074306
23
+ text,1,0.6,0.4,0.63093,0.63093,0.5,0.0,-0.13093
24
+ background_composition,4,0.0,1.0,0.593497,0.593497,0.409213,0.0,-0.184284
data/eval_results/eval_caption_evident_n30_k1_seed42.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
data/eval_samples/e621_sfw_sample_1000_seed123_buffer10000_caption_evident_n30.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
psq_rag/ui/group_ranked_display.py CHANGED
@@ -1,6 +1,7 @@
1
  from __future__ import annotations
2
 
3
  import csv
 
4
  from dataclasses import dataclass
5
  from functools import lru_cache
6
  from pathlib import Path
@@ -9,7 +10,7 @@ from typing import Dict, List, Sequence, Tuple
9
  import numpy as np
10
 
11
  from psq_rag.retrieval.psq_retrieval import construct_pseudo_vector, _norm_tag_for_lookup
12
- from psq_rag.retrieval.state import get_tfidf_components, get_tfidf_tag_vectors
13
 
14
 
15
  @dataclass
@@ -19,6 +20,92 @@ class GroupRankingRow:
19
  tags: List[Tuple[str, float]]
20
 
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  @lru_cache(maxsize=1)
23
  def _load_enabled_groups() -> Dict[str, List[str]]:
24
  csv_path = Path("data/analysis/category_registry.csv")
@@ -83,6 +170,22 @@ def _calibrate_probabilities(scores: Dict[str, float]) -> Dict[str, float]:
83
  return probs
84
 
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  def rank_groups_from_tfidf(
87
  seed_terms: Sequence[str],
88
  *,
@@ -94,55 +197,87 @@ def rank_groups_from_tfidf(
94
  if not groups:
95
  return []
96
 
97
- components = get_tfidf_components()
98
- tag_vectors = get_tfidf_tag_vectors()
99
- idf = components["idf"]
100
- term_to_col = components["tag_to_column_index"]
101
- svd = components["svd_model"]
102
- tag_to_row = tag_vectors["tag_to_row_index"]
103
- mat_norm = tag_vectors["reduced_matrix_norm"]
104
-
105
- pseudo_doc: Dict[str, float] = {}
106
- for term in seed_terms:
107
- key = _norm_tag_for_lookup(str(term))
108
- if key in term_to_col:
109
- pseudo_doc[key] = pseudo_doc.get(key, 0.0) + 1.0
110
- if not pseudo_doc:
111
- return []
112
-
113
- pseudo_vec = construct_pseudo_vector(pseudo_doc, idf, term_to_col)
114
- q = svd.transform(pseudo_vec).reshape(-1).astype(np.float32)
115
- qn = float(np.linalg.norm(q))
116
- if qn <= 0.0:
117
- return []
118
- q = q / qn
119
 
120
  all_tags: List[str] = []
121
  for tags in groups.values():
122
  all_tags.extend(tags)
123
  all_tags = list(dict.fromkeys(all_tags))
124
 
125
- scored_tags: List[str] = []
126
- rows: List[int] = []
127
- for tag in all_tags:
128
- idx = tag_to_row.get(tag)
129
- if idx is None:
130
- continue
131
- scored_tags.append(tag)
132
- rows.append(int(idx))
133
- if not rows:
134
  return []
135
 
136
- sims = (mat_norm[np.asarray(rows, dtype=np.int32)] @ q).astype(np.float32)
137
- score_by_tag: Dict[str, float] = {t: float(s) for t, s in zip(scored_tags, sims)}
138
- prob_by_tag = _calibrate_probabilities(score_by_tag)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
  rows_out: List[GroupRankingRow] = []
141
  rank_k = max(1, int(group_rank_top_k))
142
  display_k = max(1, int(top_tags_per_group))
143
 
144
  for group_name, tags in groups.items():
145
- scored = [(t, prob_by_tag[t]) for t in tags if t in prob_by_tag]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  if not scored:
147
  continue
148
  scored.sort(key=lambda x: x[1], reverse=True)
 
1
  from __future__ import annotations
2
 
3
  import csv
4
+ import os
5
  from dataclasses import dataclass
6
  from functools import lru_cache
7
  from pathlib import Path
 
10
  import numpy as np
11
 
12
  from psq_rag.retrieval.psq_retrieval import construct_pseudo_vector, _norm_tag_for_lookup
13
+ from psq_rag.retrieval.state import get_fasttext_model, get_tfidf_components, get_tfidf_tag_vectors
14
 
15
 
16
  @dataclass
 
20
  tags: List[Tuple[str, float]]
21
 
22
 
23
def _resolve_display_weights() -> Tuple[float, float]:
    """Return the default (tfidf, fasttext) display blend weights.

    The raw weights come from the ``PSQ_DISPLAY_TFIDF_WEIGHT`` and
    ``PSQ_DISPLAY_FASTTEXT_WEIGHT`` environment variables, defaulting to
    ``"0.0"`` and ``"1.0"`` respectively, and are passed through
    ``_normalize_weights`` before being returned.

    Raises:
        ValueError: if either environment variable is set to text that
            ``float()`` cannot parse.
    """
    env = os.environ
    raw_tfidf = env.get("PSQ_DISPLAY_TFIDF_WEIGHT", "0.0")
    raw_fasttext = env.get("PSQ_DISPLAY_FASTTEXT_WEIGHT", "1.0")
    return _normalize_weights(float(raw_tfidf), float(raw_fasttext))
28
+
29
+
30
+ def _normalize_weights(tfidf_w: float, fasttext_w: float) -> Tuple[float, float]:
31
+ tfidf_w = max(0.0, float(tfidf_w))
32
+ fasttext_w = max(0.0, float(fasttext_w))
33
+ total = tfidf_w + fasttext_w
34
+ if total <= 1e-8:
35
+ return 1.0, 0.0
36
+ return tfidf_w / total, fasttext_w / total
37
+
38
+
39
+ @lru_cache(maxsize=1)
40
+ def _load_group_weight_overrides() -> Dict[str, Tuple[float, float]]:
41
+ csv_path = Path(
42
+ os.environ.get(
43
+ "PSQ_DISPLAY_GROUP_WEIGHT_PATH",
44
+ "data/analysis/per_group_weight_tuning_caption_evident_n30.csv",
45
+ )
46
+ )
47
+ if not csv_path.exists():
48
+ return {}
49
+
50
+ min_support = int(os.environ.get("PSQ_DISPLAY_GROUP_WEIGHT_MIN_SUPPORT", "5"))
51
+ min_delta = float(os.environ.get("PSQ_DISPLAY_GROUP_WEIGHT_MIN_DELTA", "0.005"))
52
+ out: Dict[str, Tuple[float, float]] = {}
53
+
54
+ with csv_path.open("r", encoding="utf-8", newline="") as f:
55
+ reader = csv.DictReader(f)
56
+ for row in reader:
57
+ group = (row.get("group_name") or "").strip()
58
+ if not group:
59
+ continue
60
+ try:
61
+ support = int(float(row.get("support_samples") or "0"))
62
+ delta = float(row.get("delta_best_minus_fasttext") or "0")
63
+ tfidf_w = float(row.get("best_w_tfidf") or "0")
64
+ fasttext_w = float(row.get("best_w_fasttext") or "0")
65
+ except Exception:
66
+ continue
67
+ if support < min_support or delta < min_delta:
68
+ continue
69
+ out[group] = _normalize_weights(tfidf_w, fasttext_w)
70
+ return out
71
+
72
+
73
+ def _safe_unit_vector(vec: np.ndarray) -> np.ndarray:
74
+ v = np.asarray(vec, dtype=np.float32).reshape(-1)
75
+ n = float(np.linalg.norm(v))
76
+ if n <= 1e-12:
77
+ return np.zeros_like(v, dtype=np.float32)
78
+ return (v / n).astype(np.float32)
79
+
80
+
81
@lru_cache(maxsize=2)
def _fasttext_tag_matrix(tags: Tuple[str, ...]) -> np.ndarray:
    """Stack the unit-normalized FastText vector of every tag into a matrix.

    Row ``i`` corresponds to ``tags[i]``. For an empty *tags* tuple a
    ``(0, 0)`` float32 array is returned. Results are memoized per distinct
    *tags* tuple (the two most recent are kept), which is why the argument
    must be a hashable tuple.
    """
    model = get_fasttext_model()
    unit_rows = []
    for tag in tags:
        unit_rows.append(_safe_unit_vector(model.get_vector(tag)))
    if unit_rows:
        return np.vstack(unit_rows).astype(np.float32)
    return np.zeros((0, 0), dtype=np.float32)
88
+
89
+
90
+ def _build_fasttext_score_by_tag(tags: Sequence[str], query_terms: Sequence[str]) -> Dict[str, float]:
91
+ if not tags or not query_terms:
92
+ return {}
93
+
94
+ ft = get_fasttext_model()
95
+ query_rows = [_safe_unit_vector(ft.get_vector(term)) for term in query_terms]
96
+ query_rows = [r for r in query_rows if np.any(r)]
97
+ if not query_rows:
98
+ return {}
99
+ query_matrix = np.vstack(query_rows).astype(np.float32)
100
+
101
+ tag_matrix = _fasttext_tag_matrix(tuple(tags))
102
+ if tag_matrix.size == 0:
103
+ return {}
104
+ sims = (tag_matrix @ query_matrix.T).astype(np.float32)
105
+ best = np.max(sims, axis=1)
106
+ return {tag: float(score) for tag, score in zip(tags, best)}
107
+
108
+
109
  @lru_cache(maxsize=1)
110
  def _load_enabled_groups() -> Dict[str, List[str]]:
111
  csv_path = Path("data/analysis/category_registry.csv")
 
170
  return probs
171
 
172
 
173
+ def _blend_prob(
174
+ tag: str,
175
+ tfidf_w: float,
176
+ fasttext_w: float,
177
+ prob_by_tag_tfidf: Dict[str, float],
178
+ prob_by_tag_fasttext: Dict[str, float],
179
+ ) -> float:
180
+ p_tfidf = prob_by_tag_tfidf.get(tag)
181
+ p_fasttext = prob_by_tag_fasttext.get(tag)
182
+ if tfidf_w <= 0.0 or p_tfidf is None:
183
+ return float(p_fasttext or 0.0)
184
+ if fasttext_w <= 0.0 or p_fasttext is None:
185
+ return float(p_tfidf or 0.0)
186
+ return float(tfidf_w * p_tfidf + fasttext_w * p_fasttext)
187
+
188
+
189
  def rank_groups_from_tfidf(
190
  seed_terms: Sequence[str],
191
  *,
 
197
  if not groups:
198
  return []
199
 
200
+ default_tfidf_w, default_fasttext_w = _resolve_display_weights()
201
+ per_group_weights = _load_group_weight_overrides()
202
+ use_tfidf = default_tfidf_w > 0.0 or any(w[0] > 0.0 for w in per_group_weights.values())
203
+ use_fasttext = default_fasttext_w > 0.0 or any(w[1] > 0.0 for w in per_group_weights.values())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
  all_tags: List[str] = []
206
  for tags in groups.values():
207
  all_tags.extend(tags)
208
  all_tags = list(dict.fromkeys(all_tags))
209
 
210
+ query_terms = [_norm_tag_for_lookup(str(t)) for t in seed_terms if str(t).strip()]
211
+ query_terms = list(dict.fromkeys(query_terms))
212
+ if not query_terms:
 
 
 
 
 
 
213
  return []
214
 
215
+ scored_tags = list(all_tags)
216
+
217
+ prob_by_tag_tfidf: Dict[str, float] = {}
218
+ if use_tfidf:
219
+ components = get_tfidf_components()
220
+ tag_vectors = get_tfidf_tag_vectors()
221
+ idf = components["idf"]
222
+ term_to_col = components["tag_to_column_index"]
223
+ svd = components["svd_model"]
224
+ tag_to_row = tag_vectors["tag_to_row_index"]
225
+ mat_norm = tag_vectors["reduced_matrix_norm"]
226
+
227
+ pseudo_doc: Dict[str, float] = {}
228
+ for term in query_terms:
229
+ if term in term_to_col:
230
+ pseudo_doc[term] = pseudo_doc.get(term, 0.0) + 1.0
231
+
232
+ if pseudo_doc:
233
+ tfidf_tags: List[str] = []
234
+ tfidf_rows: List[int] = []
235
+ for tag in all_tags:
236
+ idx = tag_to_row.get(tag)
237
+ if idx is None:
238
+ continue
239
+ tfidf_tags.append(tag)
240
+ tfidf_rows.append(int(idx))
241
+
242
+ if tfidf_rows:
243
+ pseudo_vec = construct_pseudo_vector(pseudo_doc, idf, term_to_col)
244
+ q = svd.transform(pseudo_vec).reshape(-1).astype(np.float32)
245
+ qn = float(np.linalg.norm(q))
246
+ if qn > 0.0:
247
+ q = q / qn
248
+ sims = (mat_norm[np.asarray(tfidf_rows, dtype=np.int32)] @ q).astype(np.float32)
249
+ tfidf_score_by_tag: Dict[str, float] = {t: float(s) for t, s in zip(tfidf_tags, sims)}
250
+ prob_by_tag_tfidf = _calibrate_probabilities(tfidf_score_by_tag)
251
+
252
+ prob_by_tag_fasttext: Dict[str, float] = {}
253
+ if use_fasttext:
254
+ fasttext_score_by_tag = _build_fasttext_score_by_tag(scored_tags, query_terms)
255
+ if fasttext_score_by_tag:
256
+ prob_by_tag_fasttext = _calibrate_probabilities(fasttext_score_by_tag)
257
+
258
+ if not prob_by_tag_tfidf and not prob_by_tag_fasttext:
259
+ return []
260
 
261
  rows_out: List[GroupRankingRow] = []
262
  rank_k = max(1, int(group_rank_top_k))
263
  display_k = max(1, int(top_tags_per_group))
264
 
265
  for group_name, tags in groups.items():
266
+ group_tfidf_w, group_fasttext_w = per_group_weights.get(
267
+ group_name,
268
+ (default_tfidf_w, default_fasttext_w),
269
+ )
270
+ scored: List[Tuple[str, float]] = []
271
+ for t in tags:
272
+ p = _blend_prob(
273
+ tag=t,
274
+ tfidf_w=group_tfidf_w,
275
+ fasttext_w=group_fasttext_w,
276
+ prob_by_tag_tfidf=prob_by_tag_tfidf,
277
+ prob_by_tag_fasttext=prob_by_tag_fasttext,
278
+ )
279
+ if p > 0.0:
280
+ scored.append((t, p))
281
  if not scored:
282
  continue
283
  scored.sort(key=lambda x: x[1], reverse=True)