Food Desert committed on
Commit
c6be992
·
1 Parent(s): 4fdda86

Add alias-based character tag filtering for Stage 3

Browse files
.gitignore ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .venv/
2
+ __pycache__/
3
+ *.pyc
4
+ *.log
5
+ *.tmp
6
+ .DS_Store
7
+ .env
8
+ zout.txt
9
+ tf_idf_files_420.joblib
10
+ e621FastTextModel010Replacement_small.bin
11
+ tfidf_hnsw_artists.bin
12
+ tfidf_hnsw_tags.bin
.vscode/launch.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "version": "0.2.0",
3
+ "configurations": [
4
+ {
5
+ "name": "Run app.py",
6
+ "type": "python",
7
+ "request": "launch",
8
+ "program": "${workspaceFolder}/app.py",
9
+ "console": "integratedTerminal",
10
+ "envFile": "${workspaceFolder}/.env"
11
+ }
12
+ ]
13
+ }
.vscode/settings.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "python.defaultInterpreterPath": ".venv/Scripts/python.exe",
3
+ "python.analysis.typeCheckingMode": "basic",
4
+ "python.analysis.autoImportCompletions": true,
5
+ "python.analysis.extraPaths": ["."]
6
+ }
AGENTS.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Codex Instructions (Prompt_Squirrel_RAG)
2
+
3
+ ## Environment (Windows / PowerShell)
4
+ - Always run from the repo root.
5
+ - Never run `python` or `pip` directly.
6
+ - Always use the venv interpreter:
7
+ - `.venv\Scripts\python.exe`
8
+ - Install deps with:
9
+ - `.venv\Scripts\python.exe -m pip install -r requirements.txt`
10
+
11
+ ## Change discipline
12
+ - Keep diffs small: fix one issue or implement one focused step per patch.
13
+ - Do not rewrite large files.
14
+ - Do not move logic across modules unless the contract requires it.
15
+ - Preserve stage boundaries: rewriting (LLM) vs retrieval (candidate generation) vs selection (index-only).
16
+
17
+ ## Project contracts
18
+ - Follow the retrieval grounding / candidate generation contract:
19
+ - `docs/retrieval_contract.md`
20
+ - If behavior conflicts with existing code, update code to match the contract (not the other way around).
ConvertSampleImagesToJpeg.ipynb ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 7,
6
+ "id": "4aa04654",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": []
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 1,
14
+ "id": "098e115f",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "import glob\n",
19
+ "import os\n",
20
+ "import json\n",
21
+ "from PIL import Image\n",
22
+ "from sd_parsers import ParserManager\n",
23
+ "\n",
24
+ "# Directory with PNG images\n",
25
+ "image_directory = 'E:/image/holder/Tagset_Completer/sampleimages/02landscape'\n",
26
+ "\n",
27
+ "# Initialize the ParserManager\n",
28
+ "parser_manager = ParserManager()\n",
29
+ "\n",
30
+ "# Dictionary for artist names to corresponding JPG file names\n",
31
+ "artist_to_file_map = {}\n",
32
+ "\n",
33
+ "# Iterate through PNG files in the directory\n",
34
+ "for png_file in glob.glob(os.path.join(image_directory, '*.png')):\n",
35
+ " with Image.open(png_file) as img:\n",
36
+ " # Extract metadata using ParserManager\n",
37
+ " prompt_info = parser_manager.parse(img)\n",
38
+ " if prompt_info and prompt_info.prompts:\n",
39
+ " first_prompt_text = list(prompt_info.prompts)[0].value.split(',')[0].strip()\n",
40
+ " if first_prompt_text.startswith(\"by \"):\n",
41
+ " first_prompt_text = first_prompt_text[3:] # Remove \"by \" prefix\n",
42
+ " artist_to_file_map[first_prompt_text] = os.path.basename(png_file).replace('.png', '.jpg')\n",
43
+ " else:\n",
44
+ " artist_to_file_map[\"\"] = os.path.basename(png_file).replace('.png', '.jpg')\n",
45
+ "\n",
46
+ "# Save the mapping to a JSON file in the same directory\n",
47
+ "json_path = os.path.join(image_directory, 'artist_to_file_map.json')\n",
48
+ "with open(json_path, 'w') as json_file:\n",
49
+ " json.dump(artist_to_file_map, json_file, indent=4)\n"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": 2,
55
+ "id": "ac5cba7f",
56
+ "metadata": {},
57
+ "outputs": [],
58
+ "source": [
59
+ "# Iterate through PNG files in the directory\n",
60
+ "for png_file in glob.glob(os.path.join(image_directory, '*.png')):\n",
61
+ " # Open the image\n",
62
+ " with Image.open(png_file) as img:\n",
63
+ " # Convert the image to RGB mode in case it's RGBA or P mode\n",
64
+ " img = img.convert('RGB')\n",
65
+ " # Define the output filename replacing .png with .jpg\n",
66
+ " jpg_file = png_file.rsplit('.', 1)[0] + '.jpg'\n",
67
+ " # Save the image in JPG format\n",
68
+ " img.save(jpg_file, 'JPEG')\n",
69
+ " # Optionally, remove the original PNG file\n",
70
+ " os.remove(png_file)\n"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": null,
76
+ "id": "32bfb9cc",
77
+ "metadata": {},
78
+ "outputs": [],
79
+ "source": []
80
+ },
81
+ {
82
+ "cell_type": "code",
83
+ "execution_count": null,
84
+ "id": "3648a9fc",
85
+ "metadata": {},
86
+ "outputs": [],
87
+ "source": []
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "execution_count": null,
92
+ "id": "09f74cbd",
93
+ "metadata": {},
94
+ "outputs": [],
95
+ "source": [
96
+ "\n"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": 4,
102
+ "id": "d2e18c17",
103
+ "metadata": {},
104
+ "outputs": [],
105
+ "source": [
106
+ "\n"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "code",
111
+ "execution_count": null,
112
+ "id": "354fda37",
113
+ "metadata": {},
114
+ "outputs": [],
115
+ "source": []
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": null,
120
+ "id": "ac4e5911",
121
+ "metadata": {},
122
+ "outputs": [],
123
+ "source": []
124
+ }
125
+ ],
126
+ "metadata": {
127
+ "kernelspec": {
128
+ "display_name": "Python 3 (ipykernel)",
129
+ "language": "python",
130
+ "name": "python3"
131
+ },
132
+ "language_info": {
133
+ "codemirror_mode": {
134
+ "name": "ipython",
135
+ "version": 3
136
+ },
137
+ "file_extension": ".py",
138
+ "mimetype": "text/x-python",
139
+ "name": "python",
140
+ "nbconvert_exporter": "python",
141
+ "pygments_lexer": "ipython3",
142
+ "version": "3.10.9"
143
+ }
144
+ },
145
+ "nbformat": 4,
146
+ "nbformat_minor": 5
147
+ }
Prompt_Squirrel_RAG.code-workspace ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "folders": [
3
+ {
4
+ "path": "."
5
+ }
6
+ ],
7
+ "settings": {}
8
+ }
README.md CHANGED
@@ -1,14 +1,14 @@
1
- ---
2
- title: Prompt Squirrel RAG
3
- emoji: 📚
4
- colorFrom: pink
5
- colorTo: indigo
6
- sdk: gradio
7
- sdk_version: 6.5.1
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: RAG interface for Prompt Squirrel
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ---
2
+ title: Prompt Squirrel
3
+ emoji: 🐿️
4
+ colorFrom: gray
5
+ colorTo: gray
6
+ sdk: gradio
7
+ sdk_version: 5.43.1
8
+ python_version: 3.10.12
9
+ app_file: app.py
10
+ pinned: false
11
+ license: apache-2.0
12
+ tags:
13
+ - not-for-all-audiences
14
+ ---
SamplePrompts.csv ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ Prompts:,,,,
2
+ name,source,description,prompt,negative
3
+ soyjak,drhead,simple prompt shows styles more intensely,"by artist, soyjak, anthro, male, bust portrait, meme, grin",
4
+ landscape,https://e621.net/posts/320878,tags from a landscape featuring no characters,"by artist, amazing background, cliff, cloud, crystal, detailed background, fantasy, forest, grass, high-angle view, horizon, landscape, monument, mountain, nature, not furry, outside, plant, plateau, river, rock, scenery, scenery porn, sculpture, sky, spikes, statue, tower, tree, water, waterfall, wood, zero pictured",nsfw
5
+ goat,https://e621.net/posts/2741820,tags from a high-scoring image featuring a male furry,"by artist, bovid, caprine, goat, mammal, angry, anthro, bar emanata, bell, bell collar, blush, child, collar, cowbell, daww, emanata, fur, hair, horizontal pupils, horn, male, nude, open mouth, orange eyes, pupils, red collar, simple background, solo, square pupils, tongue, unusual pupils, white body, white fur, white hair, young, young anthro",nsfw
6
+ ,,,,
7
+ ,,,,
8
+ ,,,,
9
+ Artists,,,,
10
+ {by grypwolf|by evilymasterful|by domovoi lazaroth|by krazyelf|by bzeh|by harmarist|by doxy|by cervina7 \(artist\)|by nastycalamari|by tyroo|by secretly saucy|by siroc|by jrjresq|by stylusknight|by raccoondouglas|by furlana|by slimefur|by aycee|by ncs|by areye \(artist\)|by devo87|by youjomodoki|by qupostuv35|by seraziel|by juiceps|by dezz|by sligarthetiger|by scafen \(artist\)|by brolaren|by ro|by 0r0ch1|by zeta-haru|by glacierclear|by kluclew|by feretta|by the gentle giant|by pata|by raikissu|by f-r95|by wolfy-nail|by darkenstardragon|by tokifuji|by flamespitter|by twinkle-sez|by aennor|by dangpa|by twistedscarlett60|by neelix|by scruffythedeer|by frenky hw|by hladilnik|by quotefox|by w4g4|by ancesra|by tzarvolver|by wolflong|by katahane3|by saurian \(artist\)|by ittybittykittytittys|by km-15|by nawka|by utopianvee|by anchee|None|by darkgem|by joaoppereiraus|by kittydee|by monkeyspirit|by tailzkim|by sidnithefox|by killioma|by cyancapsule|by asnnonaka|by skidoo|by iwbitu|by shadman|by luccatoasty|by re-sublimity-kun|by hyilpi|by sepulte|by cumbread|by sususuigi|by r3drunner|by jailbird|by agitype01|by chikaretsu|by lonbluewolf|by rick griffin|by euyoshi89|by cold-blooded-twilight|by domasarts|by katarhein|by fivel|by nextel|by negger|by mcfli|by gekasso|by anglo|by securipun|by zeiro|by cocoline \(artist\)|by lizardlars|by sabrotiger|by dripponi|by krokobyaka|by type|by bastionshadowpaw|by amberpendant|by chromapan|by buta99|by demicoeur|by alfa995|by spuydjeks|by spirale|by shaolin bones|by seth-iova|by complextree|by freckles \(artist\)|by angiewolf|by glopossum|by aoizuri|by inuzu|by zourik|by manmosu marimo|by sijimmy456|by zummeng|by mleonheart|by macaronneko|by pache riggs|by kanashiipanda|by smileeeeeee|by sicklyhypnos|by diacordst|by haychel|by zawmg|by orionsmaniac \(artist\)|by vhkansfweer|by tsampikos|by johnfoxart|by zp92|by gammainks|by gerrkk|by aomori|by kionant|by kanel|by tattoorexy|by mcfan|by sepiruth|by clockhands|by carpetwurm|by 
capaoculta|by miles df|by sana!rpg|by carrot \(artist\)|by inno-sjoa|by raptoral|by thericegoat|by iriedono|by acstlu|by rov|by glitter trap boy|by redrusker|by ldr|by frumples|by nikraccoom|by mystikfox61|by haaru|by ketei|by somik|by zinfyu|by jinu|by zoyler|by rotten robbie|by nurinaki|by sincrescent|by bonnie bovine|by cooliehigh|by s1m|by dash ravo|by jakethegoat|by claweddrip|by 007delta|by jizoku|by personalami|by marblesoda|by dagasi|by chrysalisdraws|by marik azemus34|by nnecgrau|by atrolux|by slugbox|by imgonnaloveyou|by snowskau|by drmax|by lazysnout|by xennos|by oro97|by dark violet|by eternity-zinogre|by nepentz|by rysonanthrodog|by sigma x|by omega56|by letodoesart|by skully|by delki|by ratatooey|by codyblue-731|by honeycalamari|by saltyxodium|by fleet-foot|by ashraely|by cobaltsynapse|by edjit|by twang|by etheross|by chelodoy|by shinodage|by dlw|by twiren|by ssssnowy|by nikkibunn|by backsash|by syuro|by zaush|by skeleion|by chunie|by butterchalk|by loimu|by seibear|by r-mk|by cobalt snow|by braeburned|by eldiman|by einshelm|by trigaroo|by eto ya|by gewitter|by wizzikt|by hyattlen|by coffeesoda|by photonoko|by woolrool|by jarnqk|by nuzzo|by inu-sama|by ruaidri|by jishinu|by merrunz|by hioshiru|by thousandfoldfeathers|by desertkaiju|by kakhao|by xeono|by b-epon|by nexivian|by smiju|by captainzepto|by meesh|by catcouch|by sorc|by ajin|by rajii|by tofu froth|by sagaris uwu|by burgerkiss|by black-kitten|by kawfee|by lizet|by berseepon09|by sssonic2|by backlash91|by doomthewolf|by arbuzbudesh|by k 98|by picturd|by rayka|by soulcentinel|by adelaherz|by babywife|by stargazer|by elicitie|by rakisha|by kuroodod|by discordthege|by the-minuscule-task|by rainbowscreen|by skygracer|by lynncore|by itsunknownanon|by goonie-san|by kekitopu|by ultrabondagefairy|by mawmain|by hoodie \(artist\)|by truegrave9|by modca|by stoopix|by fumiko|by patto|by iskra|by the crab mage|by narse|by zero-sum|by digitoxici|by abesdrawings|by yuio|by zhanbow|by avante92|by hinar miler|by 
kikurage|by raaz|by romarom|by iztli|by unknown artist|by foxovh|by dimwitdog|by miso souperstar|by totesfleisch8|by keadonger|by piporete|by valkoinen|by jay-r|by thesecretcave|by smitty g|by pixelsketcher|by youwannaslap|by seff|by sicmop|by dragonfu|by magnetus|by chloe-dog|by alibi-cami|by bonifasko|by dankflank|by pakwan008|by deymos|by viejillox|by lysergide|by metal \(artist\)|by vader-san|by lockworkorange|by prsmrti|by halbean|by naive tabby|by shoutingisfun|by kiyosan|by daftpatriot|by gothbunnyboy|by anonymous artist|by hark|by phenyanyanya|by tsudamaku|by koorinezumi|by natoli|by jackaloo|by boo3|by tfancred|by nana gel|by reddragonkan|by flinters|by amegared|by markie|by nishi oxnard|by chrisandcompany|by triadfox|by dlrowdog|by hentai boy|by lizheru|by buzzer \(artist\)|by satsumalord|by pasaran|by foxfoxplz|by blpanda|by babystar|by yantaro keno|by renee-moonveil|by 9x9|by tombola1993|by raptor007|by chaostone|by cooner|by mt tg|by ficficponyfic|by sarcolopter|by azumaril|by dreadwolfclaw1990|by bigshow|by fierglief|by bobert|by zeriara|by mac-daddy|by dragmon|by jbond|by trevor-fox|by parclytaxel|by kusosensei|by gyrotech|by itoruna|by a.b. 
lust|by superbunnygt|by doneru|by box xod|by lefthighkick|by uniparasite|by malicekira|by mizzyam|by vrabo|by sacrificabominat|by zer0rebel4|by rikitoka|by karabiner|by fredryk phox|by mot|by rairai-no26-chu|by citrinelle|by jrvanesbroek|by makarimorph|by torakuta|by 1boshi|by skyelegs|by kanada|by darkdoomer|by smudge proof|by riorix|by kitchiki|by bristol|by fuze|by dirtyscoundrel|by foxball|by badumsquish|by ken sugimori|by lovelesskiax|by ricky hoffman|by buta5kawa|by roobin|by grumpy griffin creations|by mastergodai|by imperatorcaesar|by lagotrope|by ichthy0stega|by dark-moltres|by smutbooru|by deanwolfwood|by kamui shirow|by koraru-san|by foxenawolf|by caramelcraze|by date natsuku|by cotton \(artist\)|by catmonkshiro|by julius zimmerman|by hitec|by snow utamaru|by ottahz|by ryuko rose|by takagi kyou|by ka-samy|by ittybittyshark|by dynoex|by hatake|by kraken \(artist\)|by ruthredmane|by cybercat|by honesty \(artist\)|by freeze-pop88|by kinoshita-jiroh|by sobieniak|by viroveteruscy|by kelly hamilton|by pembrokewkorgi|by hinami|by kick \(artist\)|by train \(artist\)|by mind drive|by ayaka|by harpseal|by ukisudori|by inunoshippo|by sikai|by jamminbison|by artsy-theo|by marco fanjul|by wolfmalro|by positive wishes \(artist\)|by schwartzgeist|by utsuki maito|by bunnie love|by mulefoot|by chris goodwin|by poge jirushi|by thegreatmatsutzu|by sachiel 666|by inkyfrog|by dtalvi|by rorr|by fab3716|by rex equinox|by navitaserussirus|by rousemouse|by bitterplaguerat|by dannyg|by sbshouseofpancakes|by slb|by edgar rice burroughs|by doug winger|by maxime-jeanne|by rocket grunt \(artist\)|by usuario2 \(artist\)|by mauroz|by sailoranna|by tatwuyan|by tkc2021|by misterdonn|by tanutronik753 k|by namagakiokami|by emufu|by suishou0602|by macop|by bakukurara|by oogamikennta|by tigerlilylucky|by mike sherman|by snowfyre|by mylafox|by kitfox-crimson|by arania|by selinatc|by toshi \(artist\)|by mofuaki|by pokefound|by delirost|by galacticmichi|by doost|by trixythespiderfox|by 
darkmirage|by aogami|by meraence|by isolatedartest|by nottrevbe|by nsfwzhenya|by fourball|by manene|by trinity-fate62|by kilinah|by ingi|by latchk3y|by pochincoff|by welost|by skipsy|by bunnybits|by lunalei|by yousan|by kaynine|by honovy|by dream and nightmare|by wugi|by viskasunya|by faejunkie|by v-tal|by sabuky|by faeki|by kammi-lu|by foxes in love|by nightfaux|by virtyalfobo|by peculiart|by rika|by marsminer|by discreet user|by marshmallow-ears|by aeonspassed|by dreiker|by lyme-slyme|by punkypanda|by ponporio \(artist\)|by sonsasu|by kame 3|by pururing|by wbnsfwfactory|by bikupan|by bigdon1992|by lichfang|by bakemonoy|by b-ern|by merunyaa|by redishdragie|by lightsource|by enigi09|by hanuvo|by justmegabenewell|by thefuckingdevil|by minnosimmins|by qwertydragon|by fakeryway|by cotora|by ark warrior|by danomil|by avoid posting|by kostos art|by ratcha|by atryl|by fuf|by lvlirror|by theboogie|by nitani|by roly|by aer0 zer0|by hardyboy|by nozomyarts|by sinsquared|by cherrikissu|by asaneman|by tfzn|by hooves-art|by catsudon|by bigcozyorca|by mr.smile|by sinensian|by nukochi|by felino|by toto draw|by mytigertail|by arrwulf|by oselotti|by gorsha pendragon|by laser \(artist\)|by doesnotexist|by nekowuwu|by alanscampos|by el-loko|by compfive|by komdog|by magenta7|by milachu92|by serex|by bigdad|by aaron \(artist\)|by diadorin|by pig \(artist\)|by slickerwolf|by angstrom|by kihu|by ike marshall|by chalo|by furball \(artist\)|by lavenderpandy|by hunterramirez|by kloudmutt|by jerseydevil|by zi ran|by moreuselesssource|by ocaritna|by rukifox|by tggeko|by kiseff|by e254e|by princelykaden|by artdecade|by inuki|by prrrrrrmine|by chewycuticle|by haps|by senz|by argento|by daigaijin|by falcrus|by omari|by risenpaw|by satsukii|by lollipopcon|by ralek|by kyrosh|by tush|by reccand|by sindoll|by zerofox1000|by kaboozey|by somescrub|by yurusa|by limebreaker|by keffotin|by matemi|by uromatsu|by roadiesky|by saku1saya|by knightmoonlight98|by zerolativity|by winick-lim|by harnny|by 
girlsay|by sukebepanda|by sparrow \(artist\)|by amazinggwen|by slug \(artist\)|by smoothlabs|by eleacat|by replica \(artist\)|by thewill|by kevinsano|by feliscede|by james howard|by moki|by skylardoodles|by hyucaze|by lumineko|by conditional dnp},,,,
TagDocumentation.txt ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ tag what you see (locked)
3
+
4
+ Regardless of what you know from outside sources, only tag what you can see in the image.
5
+
6
+ Also, make sure you check out e621:Tag What You See (Explained) for the reasoning behind the TWYS policy.
7
+
8
+ Unlike many other art sites, e621 has a tagging policy called "Tag What You See", or TWYS for short.
9
+ TWYS states that all General category tags on a post must be directly evident from within the post itself. TWYS applies only to visual elements within a post, such as objects, characters, and the actions taken by characters that are visible. Audio content is not tagged, except in the Meta category.
10
+
11
+ For example, a solo picture of a character who appears male must be tagged male.
12
+ That remains true even if the artist or the character owner themselves state that the character is not male, or if text within the image states that the character is not male. These tags refer strictly to a character's outward appearance and nothing more.
13
+
14
+ This policy exists to make search results more predictable and objective.
15
+ Note that you can use Lore tags to describe the stated genders of the characters, rather than the visible ones.
16
+
17
+ Tags in other categories are not entirely subject to the Tag What You See principle:
18
+
19
+ Tags in the Lore category are meant to convey the artist's intentions or other background information that cannot be reliably determined via TWYS, such as gender identity or familial relations.
20
+ Tags in the Character and Species categories are partially dependent upon TWYS: that is, external information can be used to help identify what character or species is supposed to be depicted in the post in cases where it isn't obvious, but it cannot actively conflict with what is seen in the post. For example, you can tag character a if the artist claims that a disembodied hand in the post belongs to character a, unless the hand looks nothing like character a and instead looks exactly like it belongs to character b. In that case, TWYS overrides the artist's word.
21
+ Tags in other categories are valid if the information that they convey is objectively true, such as the artist's name, the image's aspect ratio, or the IP holder of the characters in the post. For MP4, WebM, and Flash posts, audio-related tags may be included in the Meta category, but only to the extent of describing the presence and type of audio in the post (see the sound article for more information).
22
+
23
+ There will be times when it's still not clear what tags should be applied to an image. An administrator should be contacted to help resolve such cases.
24
+
25
+ Leeway may be given to hybrid characters, as the components of the species by which they are comprised are not always obvious.
26
+
27
+ Note: tag_what_you_see is not a tag to be used. If a post contains this tag, please remove it.
28
+ See also
29
+
30
+ Help with tags
31
+ How to tag genders
32
+ Overly Specific
33
+ Tag What You See (Explained)
34
+ Tagging Checklist
35
+
36
+ Posts (view all)
37
+
38
+ Nobody here but us chickens!
39
+
40
+
41
+ ###
42
+
43
+
44
+ e621:tag what you see (explained) (locked)
45
+
46
+ [Back: e621:index]
47
+
48
+ The text below is intended to be a sort of "introduction" to e621's Tag What You See policy. The text below is NOT the policy itself, which you can view here: Tag What You See
49
+
50
+ Reading and understanding the TWYS policy is extremely important if you intend on editing tags on posts at all, so please make sure you read the policy itself as well as this introduction.
51
+ The Policy
52
+
53
+ A brief summary of what the TWYS policy is:
54
+
55
+ Unlike many other art sites, e621.net has a tagging policy called "Tag What You See" (aka: "TWYS"). With very few exceptions, TWYS says that all tags on a post must be directly verifiable within the post itself. Example: a solo picture of what APPEARS to be a male character will be tagged "male". Even if the character was defined as "female" on other sites by the artist or character owner themselves, the picture would still need to be tagged "male" on e621, because of the TWYS policy.
56
+
57
+ This may seem unusual and even insensitive, but please read on to understand why the site functions this way.
58
+ The Debate
59
+
60
+ The dispute between "Tag What You See" and "Tag What You Know"
61
+
62
+ The Reasons
63
+
64
+ There are several reasons for the necessity of the TWYS policy.
65
+
66
+ The Problems
67
+
68
+ Of course, no method of tagging is perfect, and there are a few problems that tend to arise as a result of using TWYS:
69
+
70
+ Sometimes users are just going to disagree over what is "seen" in a post or not. This is simply an expected consequence of having a TWYS policy. These situations will often need intervention from an administrator in order to resolve.
71
+
72
+ Gender tags (male, female, herm, etc) are typically at the heart of most TWYS debates. The reasons for this are numerous, but it boils down to A) artists drawing characters in ways that make it difficult to determine gender, and B) characters designed in such a way that they can easily appear to be either one gender or another (e.g. a herm wearing clothes typically looks just female). Again, there's nothing "wrong" with doing this, but it undoubtedly leads to confusion and people getting the wrong ideas if the artwork is ever viewed by itself. Again, e621 currently is interested only in a character's APPARENT gender, not their DEFINED gender. But sometimes even the apparent gender isn't obvious; in these cases, an administrator will need to make the final decision.
73
+
74
+ Tip:
75
+
76
+ To quickly link other e621 users to this page, simply type [[twys]] in your message. Example: "Check out twys for an explanation of the TWYS rule."
77
+
78
+ [Back: e621:index]
79
+
80
+
81
+ ###
82
+
83
+
84
+ e621:tagging checklist (locked)
85
+
86
+ [Back: e621:index]
87
+
88
+ This is an informal and unofficial supplement to the tagging rules and guidelines, meant to encourage better and more complete tagging.
89
+
90
+ Make sure you're also familiar with our Tag What You See policy before editing tags: tag_what_you_see for the policy itself, and e621:Tag What You See (Explained) for a more in-depth explanation why we use TWYS.
91
+
92
+ Each entry below poses a general question about a post, with some example tags that answer it. A good post will probably have most of these answered (but not necessarily all).
93
+ Basics
94
+
95
+ Tags that all posts should have, to maintain minimal searchability.
96
+
97
+ Artist(s)? Use their best known alias. If a picture has more than one artist, tag them all, along with collaboration. If you're not sure who the artist is, tag unknown_artist. If the artist wishes to remain anonymous, use anonymous_artist instead.
98
+ Rating?
99
+ Explicit for fully or partially exposed genitalia (penis, pussy, cloaca, sheath, balls, or anus), various sex acts even if no genitalia are visible, high amounts of violence/gore, sexual fluids such as cum or pussy_juice, and extreme sexual fetishes such as scat, watersports, or BDSM.
100
+ Safe for anything that can be viewed in public without much uproar: no genitals, no sexual overtones or poses, no realistic violence, or any questionable activity.
101
+ Questionable for everything in between, such as topless females and suggestive poses.
102
+ For more help on ratings please see e621: Ratings
103
+ Copyright? The original series or company a character or game is owned by.
104
+ Character? Tag the character's best known name. If not that, their full name.
105
+ Body type? anthro, feral, humanoid, taur, anthrofied (pokemorph, digimorph), ponified, feralized
106
+ Species? human, canine, feline, bovine, cervine, equine, lagomorph, rodent, avian, insect, marine (cetacean, shark), scalie (click for detailed lists)
107
+ Sex/gender? male, female, intersex (herm, maleherm, gynomorph, andromorph), ambiguous_gender
108
+ See How To: Tag Genders for a detailed guide
109
+ How many? solo, duo, trio, group, zero_pictured
110
+ Clothing? fully_clothed, partially_clothed, skimpy, nude, bottomless, topless, underwear, open_shirt
111
+ Location? inside, outside, bedroom, kitchen, forest
112
+ Perspective? front_view, rear_view, side_view, three-quarter_view, low-angle_view, high-angle_view, worm's-eye_view, bird's-eye_view, first_person_view
113
+
114
+ Sexually explicit
115
+
116
+ Male bits? penis, balls, sheath, knot, erection, half-erect, flaccid, humanoid_penis, equine_penis, tapering_penis, veiny_penis, uncut, circumcised
117
+ Female bits? pussy, clitoris, plump_labia, equine_pussy, canine_pussy
118
+ Other? butt, anus, puffy_anus, gaping_anus, urethra, genital_slit
119
+ Sex act? sex (male/female, female/female, male/male, bisexual), masturbation, handjob, footjob, fellatio, cunnilingus, vaginal_penetration, anal_penetration, threesome, foursome, orgy, gangbang, frottage, tribadism, orgasm, cum_inside
120
+ Position? Common ones: missionary_position, cowgirl_position, reverse_cowgirl_position, from_behind, 69_position, stand_and_carry_position.
121
+ See also: tag group:sex positions
122
+ Sexual themes? bondage, domination, rape, rough_sex, happy_sex, presenting, internal, impregnation, bestiality, interspecies, public, exhibitionism
123
+ Fluids? cum, cumshot, precum, pussy_juice, pussy_ejaculation, saliva
124
+ Toys? dildo, vibrator, buttplug, egg_vibrator, strapon, feeldoe
125
+
126
+ Pose / Activity / Appearance
127
+
128
+ General activity (if any)? walking, running, fighting, sleeping, dancing, eating, kissing, licking
129
+ Posture? standing, bent_over, sitting, crouching, kneeling, all_fours, on_front, on_side, on_back, ass_up (see tag group:pose for full list)
130
+ Body decor? glasses, ring, necklace, bracelet, anklet, tattoo, piercing, collar, hat
131
+ Fur style? mane, chest_tuft, pubes
132
+ Hair? hair, long hair, short hair
133
+ Breasts? breasts (small_breasts, big_breasts, huge_breasts), nipples, under_boob, side_boob, teats
134
+ Limbs? crossed_arms, raised_arms, arms_behind_head, spread_legs, crossed_legs, raised_leg, legs_up, raised_tail, tailwag
135
+ Gaze? looking_at_viewer, looking_back, eye_contact, eyes_closed
136
+ Expression? blush, wink, smile, grin, tongue_out, naughty_face, embarrassed, happy, sad
137
+
138
+ Information and Requests
139
+
140
+ Quality/medium? sketch, line_art, monochrome, shaded, pencil_(artwork), watercolor, 3D, digital_media_(artwork)
141
+ Picture organization? comic, multiple_scenes, sequence, close-up, portrait, pinup, solo_focus, wallpaper
142
+ Style? toony, detailed, realistic
143
+ Text and languages? english_text, japanese_text, spanish_text, runes, dialogue, speech_bubble, symbol
144
+ Information? translated, partially_translated, unknown_artist_signature, not_furry, bigger version at the source
145
+ Requests? translation_request, source_request, tagme
146
+ Image size? low_res, hi_res, absurd_res, superabsurd_res
147
+ Year of creation? 2016, 2015, and so on
148
+
149
+ Heavily vetted tags.
150
+
151
+ Tags that can be found on our global blacklist, and heavily vetted tags MUST be added upon upload.
152
+
153
+ young, gore, scat, watersports, diaper, my little pony, vore, not furry, rape, hyper, feral, nazi, politics, zoophile iconography.
154
+ Everything pedophilia
155
+
156
+ Do NOT tag
157
+
158
+ Subjective tags that express opinions. Common examples include beautiful, sexy, hot, good, crappy and most other adjectives. Subjective themes can be collected into a set instead. (See https://e621.net/help/sets )
159
+ Generic tags such as legs, eyes, big, image and organism.
160
+
161
+
162
+ ###
163
+
164
+
165
+ Help: Tags
166
+
167
+
168
+ ← E621 Wiki – Tags
169
+ Table of Contents
170
+
171
+ Guidelines
172
+ Categories
173
+ Artist
174
+ Contributor
175
+ Character
176
+ Copyright
177
+ Species
178
+ General
179
+ Meta
180
+ Lore
181
+ Invalid
182
+ Changing Tag Category
183
+
184
+ Read More: Aliases | Implications | Bulk Update Requests
185
+       Search Cheatsheet
186
+ Tags
187
+
188
+ Tags are keywords that you can use to describe posts.
189
+ They serve a dual purpose: they allow you to both find the content that you like, and to filter out stuff that you dislike.
190
+
191
+ Tags may belong to various categories, and may interact with each other via relationships.
192
+
193
+ See the cheatsheet for examples of the search syntax.
194
+
195
+
196
+ ↑ Guidelines
197
+
198
+ When tagging a post, you must follow the following guidelines.
199
+ Tag What You See
200
+
201
+ Full article: Tag What You See.
202
+
203
+ Unlike many other art sites, e621 has a tagging policy called "Tag What You See", or TWYS for short.
204
+ TWYS states that all General category tags on a post must be directly evident from within the post itself.
205
+
206
+ For example, a solo picture of a character who appears male must be tagged male.
207
+ That remains true even if the artist or the character owner themselves state that the character is not male, or if text within the image states that the character is not male. These tags refer strictly to a character's outward appearance and nothing more.
208
+
209
+ This policy exists to make search results more predictable and objective.
210
+ Note that you can use Lore tags to describe the stated genders of the characters, rather than the visible ones.
211
+
212
+ Tags in other categories are not entirely subject to the Tag What You See principle:
213
+
214
+ Tags in the Lore category are meant to convey the artist's intentions or other background information that cannot be reliably determined via TWYS, such as gender identity or familial relations.
215
+ Tags in the Character and Species categories are partially dependent upon TWYS: that is, external information can be used to help identify what character or species is supposed to be depicted in the post in cases where it isn't obvious, but it cannot actively conflict with what is seen in the post. For example, you can tag character a if the artist claims that a disembodied hand in the post belongs to character a, unless the hand looks nothing like character a and instead looks exactly like it belongs to character b. In that case, TWYS overrides the artist's word.
216
+ Tags in other categories are valid if the information that they convey is objectively true, such as the artist's name, the name of a voice actor, the image's aspect ratio, or the IP holder of the characters in the post.
217
+
218
+ Minimum tag requirements
219
+
220
+ Code of Conduct 2.2 - Tagging, Rating, and Sourcing Abuse
221
+ All posts are expected to have at least ten general, non-implied tags upon upload. This refers to tags in the General tag category: Artist, Character, Species, Copyright, Lore, Meta, and Invalid tags do not count towards this requirement. "Non-implied" means that a tag which is added by implication from another tag does not count. For example, forest implies tree which implies plant. If you add the forest tag, both tree and plant will be added automatically. However, only the first tag counts towards the minimum tag requirement.
222
+
223
+ This restriction will be eased if the post does not have ten distinct tags that are reasonably applicable to it. For example, extremely simplistic posts such as some zero pictured images may not depict enough to create ten tags.
224
+
225
+ Contentious or objectionable content must always be tagged upon upload. This includes any strange, unusual, or extreme fetishes depicted within the post.
226
+ Forbidden characters
227
+
228
+ Tags may only contain English letters, numbers, and some symbols.
229
+ No unicode characters, or characters belonging to languages other than English, may be used.
230
+
231
+ The following characters are reserved for potential future uses.
232
+ No new tags containing them can be created.
233
+
234
+ %,#\\*: anywhere in the tag
235
+ -~: as the first character
236
+
237
+ Note that some existing tags already contain such characters.
238
+ These tags predate the rule change, and will likely be phased out at some point in the future.
239
+ ↑ Categories
240
+
241
+ There are nine categories (or "types") of tags on e621. They help to organize the many tags listed on this site and its many, many posts.
242
+ This page will provide a quick rundown of what they are for and how to change the categories of tags from one to another.
243
+ artist
244
+
245
+ Arguably the most important tag on any post is the one that identifies the person who made the post itself.
246
+ This (usually) isn't the e621 member who uploaded the post, a person who edited the post, and certainly not anyone who merely commissioned or requested the post.
247
+
248
+ Artist tags are essential, as we maintain and respect an Avoid Posting List.
249
+ If you are unable to identify the artist, then unknown_artist should be used. If the artist does not want to be identified, then anonymous_artist should be used instead.
250
+
251
+ There are a few non-artist tags that are deliberately typed as "artists" in order to bring attention to them.
252
+
253
+ avoid_posting and its variant conditional_dnp tags identify artists with DNP or conditional DNP status
254
+ epilepsy_warning is used for flashing lights in animated, Flash, and video posts that could trigger epileptic seizures
255
+ sound_warning is for any loud sound playing in Flash and video posts
256
+ jumpscare_warning is for posts featuring loud sounds (typically screams) accompanied by unsettling or scary visuals.
257
+ unknown_artist_signature is for posts where there is an artist's signature on it, but the artist who made it could not be immediately identified
258
+
259
+ contributor
260
+
261
+ People who did not create the specific artwork in the post but who did provide creative contributions that are considered significant and essential to the artwork itself. (See topic #54179 for the discussion thread about this new category.)
262
+
263
+ Currently, only two types of contributors are recognized for this category.
264
+
265
+ Voice actors, whose tags are suffixed with the disambiguation _(va).
266
+ Character modelers, whose tags are suffixed with the disambiguation _(modeler).
267
+
268
+ Note that the primary artist(s) of a post are not to be tagged as contributor; they are still tagged as artists as normal. If the artist is also the modeler, they are to be tagged as just an artist; modeler tags are to be used if they created or provided a character model but did not provide the composition of the post. Likewise, if the artist of a video post voice acted for their own video, they still don't get a separate contributor tag.
269
+ character
270
+
271
+ Any identifiable fictional or real world individual who can be seen in a post, even if they're not actually "there".
272
+ A statue or a kigurumi modeled after a character, or the cover of a solo music artist's album, would still be tagged as their corresponding characters
273
+ Characters can range from mere fursonas to globally famous copyrighted characters like Mickey Mouse, Bugs Bunny, and Mario. Fan characters are also covered here.
274
+
275
+ If you cannot identify a character, but you do know that they either are owned by someone or come from the real world, then unknown_character should be used.
276
+ copyright
277
+
278
+ Any recognizable brands and franchises (as well as the companies who own them) that can be identified through the use of their characters, settings, or other recognizable elements.
279
+ Parodies of copyrights are also tagged with the copyrights that a post is parodying. Specific holidays like Christmas, Easter, and Halloween are also given copyright status.
280
+
281
+ The real world is also a copyright tag, for what it's worth.
282
+ species
283
+
284
+ The bread and butter tags of this curated furry image archive, covering many real and fictional creatures.
285
+ Cats, dogs, horses, fish, scalies, aliens, robots, spirits, Pocket Monsters, Digital Monsters, regular monsters, and the dreaded but mostly harmless humans are among the many kinds of creatures that you can find here.
286
+
287
+ If you can't properly identify a species, then there are two tags you can use: unknown_species for creatures with identifying features, and ambiguous_species for creatures that cannot be determined at all.
288
+ general
289
+
290
+ These plain-colored tags are for anything else that doesn't fit with any of the aforementioned five categories. Genders, objects, distinguishing features, locations, fetishes, sexual positions, sexual acts, and so on.
291
+ New tags are automatically categorized as general tags. Artists, contributors, characters, copyrights, and species that haven't been properly re-typed to such yet are most likely typed as general tags as well.
292
+ meta
293
+
294
+ Tags that describe facts about the image itself, rather than what's in it, are placed in the meta category.
295
+ Some of these tags are added automatically, like hi_res. Others, like 16:9 or 1:1 are added by dedicated bots.
296
+ Tags describing what year the image was made also belong in this category, from 2025 all the way back to 6th_century_bc.
297
+ lore
298
+
299
+ Unlike other categories, lore tags are entirely outside the realm of TWYS. Instead, lore tags provide information that is either incorrect when following TWYS, or simply cannot be confirmed visually in the image itself, yet still relevant to the post.
300
+ Keep in mind that standard TWYS tags should still be used where applicable. Lore tags do not replace them.
301
+
302
+ Whenever a submission must be tagged as something that is "wrong", a lore tag should be added to provide the correct information.
303
+ The most common use for lore tags is to correct gender tags – for example, a post that is tagged gynomorph might also need a herm_(lore) tag if that's what the character is, despite there not being any evidence of that in the image itself.
304
+
305
+ Conversely, some fetish tags (like incest) cannot always be definitively confirmed through the image itself, and thus belong in the lore category.
306
+
307
+ New lore tags can be requested on the forums.
308
+ invalid
309
+
310
+ Some tags are too ambiguous or broad to be useful, so they are placed in the invalid category.
311
+ They should be replaced with better-fitting or more specific tags.
312
+
313
+ Please, do not simply remove invalid tags without fixing the issue.
314
+
315
+
316
+ ###
317
+
318
+
319
+
app.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import logging
4
+ from PIL import Image
5
+ from pathlib import Path
6
+ from typing import List
7
+
8
+ from psq_rag.pipeline.preproc import extract_user_provided_tags_upto_3_words
9
+ from psq_rag.llm.rewrite import llm_rewrite_prompt
10
+ from psq_rag.retrieval.psq_retrieval import psq_candidates_from_rewrite_phrases, _norm_tag_for_lookup
11
+ from psq_rag.llm.select import llm_select_indices
12
+
13
+
14
+ def _split_prompt_commas(s: str) -> List[str]:
15
+ return [p.strip() for p in (s or "").split(",") if p.strip()]
16
+
17
def _norm_for_dedupe(tag: str) -> str:
    """Return the key used to detect duplicate tags.

    The tag is lowercased and then passed through ``_norm_tag_for_lookup``.
    """
    lowered = tag.lower()
    return _norm_tag_for_lookup(lowered)
20
+
21
def compose_final_prompt(rewritten_prompt: str, selected_tags: List[str]) -> str:
    """Join the rewritten phrases and the selected tags into one prompt string.

    Phrases from ``rewritten_prompt`` come first, followed by ``selected_tags``.
    Entries whose ``_norm_for_dedupe`` key has already been seen are dropped,
    so the first spelling of each tag wins and the original order is kept.
    """
    first_by_key = {}
    for piece in [*_split_prompt_commas(rewritten_prompt), *selected_tags]:
        # setdefault keeps the earliest piece per key; dicts preserve insertion order.
        first_by_key.setdefault(_norm_for_dedupe(piece), piece)
    return ", ".join(first_by_key.values())
35
+
36
+
37
+ # Set up logging
38
+ # Minimal prod logging: warnings+ to stderr, no file by default
39
+ import os, logging
40
+
41
+ LOG_LEVEL = os.environ.get("PSQ_LOG_LEVEL", "WARNING").upper()
42
+ logging.basicConfig(
43
+ level=getattr(logging, LOG_LEVEL, logging.WARNING),
44
+ format="%(asctime)s %(levelname)s:%(message)s",
45
+ handlers=[logging.StreamHandler()] # no file -> avoids huge logs on Spaces
46
+ )
47
+
48
+ # Quiet down common noisy libs (optional)
49
+ for _name in ("gensim", "gradio", "hnswlib", "httpx", "uvicorn"):
50
+ logging.getLogger(_name).setLevel(logging.ERROR)
51
+
52
+ # Turn off Gradio analytics phone-home to avoid those background thread errors (optional)
53
+ os.environ["GRADIO_ANALYTICS_ENABLED"] = "0"
54
+
55
+
56
+ MASCOT_DIR = Path(__file__).parent / "mascotimages"
57
+ MASCOT_FILE = MASCOT_DIR / "transparentsquirrel.png"
58
+
59
# Hotfix for gradio_client: its JSON-schema helpers are wrapped so that a
# schema which is not a dict (bare True/False in JSON Schema boolean form, or
# None) is reported as "any" instead of being handed to the original helper.
# If patching fails for any reason the app keeps running unpatched.
try:
    from gradio_client import utils as _gc_utils

    _orig_get_type = _gc_utils.get_type
    _orig_j2p = _gc_utils._json_schema_to_python_type
    _orig_pub = _gc_utils.json_schema_to_python_type

    def _get_type_safe(schema):
        """``get_type`` wrapper that tolerates non-dict schemas."""
        return _orig_get_type(schema) if isinstance(schema, dict) else "any"

    def _j2p_safe(schema, defs=None):
        """``_json_schema_to_python_type`` wrapper that tolerates non-dict schemas."""
        if isinstance(schema, dict):
            # Fall back to the schema's own "$defs" when none are supplied.
            return _orig_j2p(schema, defs or schema.get("$defs"))
        return "any"

    def _pub_safe(schema):
        """Public wrapper used by Gradio; routed through ``_j2p_safe``."""
        if isinstance(schema, dict):
            return _j2p_safe(schema, schema.get("$defs"))
        return "any"

    _gc_utils.get_type = _get_type_safe
    _gc_utils._json_schema_to_python_type = _j2p_safe
    _gc_utils.json_schema_to_python_type = _pub_safe

except Exception as e:
    print("gradio_client hotfix not applied:", e)
90
+ # -------------------------------------------------------------------------------
91
+
92
+
93
+ allow_nsfw_tags = False
94
+ verbose_retrieval = True
95
+ verbose_retrieval_all = False
96
+ verbose_retrieval_limit = 20
97
+
98
+ css = """
99
+ .scrollable-content{
100
+ max-height: 420px;
101
+ overflow-y: scroll; /* always show scrollbar */
102
+ overflow-x: hidden;
103
+ padding-right: 8px;
104
+ padding-bottom: 14px; /* <— add this */
105
+ scrollbar-gutter: stable; /* prevent layout shift as it fills */
106
+
107
+ /* Firefox */
108
+ scrollbar-width: auto;
109
+ scrollbar-color: rgba(180,180,180,.9) rgba(0,0,0,.15);
110
+ }
111
+
112
+ /* WebKit/Chromium (Chrome/Edge/Safari) */
113
+ .scrollable-content::-webkit-scrollbar{ width: 10px; }
114
+ .scrollable-content::-webkit-scrollbar-thumb{ background: rgba(180,180,180,.9); border-radius: 8px; }
115
+ .scrollable-content::-webkit-scrollbar-track{ background: rgba(0,0,0,.15); }
116
+
117
+ /* (Optional) make both scroll panes taller so they fill more of the column */
118
+ .pane-left .scrollable-content,
119
+ .pane-right .scrollable-content {
120
+ max-height: 610px; /* was 420px; tweak to taste */
121
+ }
122
+ """
123
+
124
+
125
def _format_score(value) -> str:
    """Render a score with three decimals; non-numeric values (e.g. None) via str()."""
    if isinstance(value, (int, float)):
        return f"{value:.3f}"
    return str(value)


def _log_phrase_reports(phrase_reports, log, limit) -> None:
    """Write the per-phrase retrieval debug reports through ``log``.

    ``limit`` caps the number of candidate rows shown per phrase; ``None``
    shows every row. Rows beyond the cap are summarised in a "... more" line.
    """
    for report in phrase_reports or []:
        phrase = report.get("normalized") or report.get("phrase") or ""
        lookup = report.get("lookup") or ""
        tfidf_vocab = report.get("tfidf_vocab")
        log(f"Phrase: {phrase} (lookup={lookup}) tfidf_vocab={tfidf_vocab}")

        rows = report.get("candidates", [])
        shown = rows if limit is None else rows[:limit]
        for row in shown:
            tag = row.get("tag")
            alias_token = row.get("alias_token")
            # Only mention the alias token when it differs from the canonical tag.
            alias_part = ""
            if alias_token and alias_token != tag:
                alias_part = f" [alias_token={alias_token}]"
            fasttext_str = _format_score(row.get("score_fasttext"))
            context_str = _format_score(row.get("score_context"))
            combined_str = _format_score(row.get("score_combined"))
            count = row.get("count")
            log(
                f" {tag}{alias_part} | fasttext={fasttext_str} context={context_str} "
                f"combined={combined_str} count={count}"
            )
        if limit is not None and len(rows) > limit:
            log(f" ... ({len(rows) - limit} more)")


def rag_pipeline_ui(user_prompt: str):
    """Run rewrite -> retrieval -> selection -> composition for one prompt.

    Returns a ``(console_text, final_prompt)`` pair. An empty or
    whitespace-only prompt returns ``("Error: empty prompt", "")``. Any
    unexpected failure is appended to the log and yields an empty final prompt.
    A retrieval failure is logged and the pipeline continues with no candidates.
    """
    logs = []
    log = logs.append

    try:
        log("Start: received prompt")
        prompt_in = (user_prompt or "").strip()
        if not prompt_in:
            return "Error: empty prompt", ""

        log("Input:")
        log(prompt_in)
        log("")

        user_tags = extract_user_provided_tags_upto_3_words(prompt_in)
        log("Heuristically extracted user tags:")
        log(", ".join(user_tags) if user_tags else "(none)")
        log("")

        log("Step 1: LLM rewrite")
        # Coerce a None rewrite to "" so the user tags can still be appended
        # below instead of raising TypeError and aborting the whole run.
        rewritten = llm_rewrite_prompt(prompt_in, log) or ""
        log("Rewrite:")
        log(rewritten or "(empty)")
        log("")

        rewrite_for_retrieval = rewritten
        if user_tags:
            # keep them separate in logs, but allow them to help retrieval
            rewrite_for_retrieval = (rewritten + ", " + ", ".join(user_tags)).strip(", ").strip()

        log("Step 2: Prompt Squirrel retrieval (hidden)")
        candidates, phrase_reports = [], []
        try:
            rewrite_phrases = [p.strip() for p in (rewrite_for_retrieval or "").split(",") if p.strip()]
            retrieval_result = psq_candidates_from_rewrite_phrases(
                rewrite_phrases=rewrite_phrases,
                allow_nsfw_tags=allow_nsfw_tags,
                global_k=300,
                verbose=verbose_retrieval,
            )
            if isinstance(retrieval_result, tuple):
                candidates, phrase_reports = retrieval_result
            else:
                candidates, phrase_reports = retrieval_result, []
            log(f"Retrieved {len(candidates)} candidate tags")
        except Exception as e:
            log(f"Retrieval fallback: {type(e).__name__}: {e}")
            candidates, phrase_reports = [], []
        else:
            if verbose_retrieval:
                # Debug reporting is guarded separately: a malformed report must
                # not throw away candidates that were retrieved successfully.
                try:
                    log(f"Total unique candidates: {len(candidates)}")
                    limit = None if verbose_retrieval_all else max(1, int(verbose_retrieval_limit))
                    _log_phrase_reports(phrase_reports, log, limit)
                except Exception as e:
                    log(f"Verbose retrieval report skipped: {type(e).__name__}: {e}")

        log("Step 3: LLM index selection")
        # We pass the original 'prompt_in' as the description for the LLM to match against
        picked_indices = llm_select_indices(
            query_text=prompt_in,
            candidates=candidates,
            max_pick=0,
            log=log,
        )

        selected_tags = [candidates[i].tag for i in picked_indices] if picked_indices else []

        log("Step 4: Compose final prompt")
        final_prompt = compose_final_prompt(rewritten, selected_tags)

        log("Done: final prompt ready")
        return "\n".join(logs), final_prompt

    except Exception as e:
        log(f"Error: {type(e).__name__}: {e}")
        return "\n".join(logs), ""
235
+
236
+
237
+
238
+ with gr.Blocks(css=css) as app:
239
+ with gr.Row():
240
+ with gr.Column(scale=3, elem_classes=["prompt-col"]):
241
+ image_tags = gr.Textbox(
242
+ label="Enter Prompt",
243
+ placeholder="e.g. fox, outside, detailed background, .",
244
+ lines=1
245
+ )
246
+ with gr.Column(scale=1):
247
+ _mascot_pil = Image.open(MASCOT_FILE).convert("RGBA")
248
+ mascot_img = gr.Image(
249
+ value=_mascot_pil,
250
+ show_label=False,
251
+ interactive=False,
252
+ height=220,
253
+ elem_id="mascot"
254
+ )
255
+ submit_button = gr.Button("Run", variant="primary")
256
+
257
+ gr.Markdown(
258
+ """
259
+ ### Prompt Squirrel RAG (pipeline version)
260
+
261
+ Type a rough prompt. This tool rewrites it and aligns it to an e621-style tag vocabulary using Prompt Squirrel internally,
262
+ then returns a cleaned, model-friendly prompt.
263
+ """.strip()
264
+ )
265
+
266
+ console = gr.Textbox(
267
+ label="Console",
268
+ lines=10,
269
+ interactive=False,
270
+ placeholder="Progress logs will appear here."
271
+ )
272
+
273
+ final_prompt = gr.Textbox(
274
+ label="Final Prompt",
275
+ lines=3,
276
+ interactive=False,
277
+ placeholder="Your optimized prompt will appear here."
278
+ )
279
+
280
+ submit_button.click(
281
+ rag_pipeline_ui,
282
+ inputs=[image_tags],
283
+ outputs=[console, final_prompt]
284
+ )
285
+
286
+ image_tags.submit(
287
+ rag_pipeline_ui,
288
+ inputs=[image_tags],
289
+ outputs=[console, final_prompt]
290
+ )
291
+
292
+ if __name__ == "__main__":
293
+ app.queue().launch(allowed_paths=[str(MASCOT_DIR)])
data/eval_samples/e621_sfw_sample_1000_seed123_buffer10000.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
docs/retrieval_contract.md ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Retrieval Contract -- Stage 2 (Retrieval Grounding / Candidate Generation)
2
+
3
+ Stage 2 performs **retrieval grounding** over a **closed vocabulary** of canonical e621-style tags.
4
+ It does not "tag images", and it does not do free-form generation. Its job is to produce a high-recall,
5
+ inspectable candidate pool for downstream **closed-set selection**.
6
+
7
+ ---
8
+
9
+ ## Inputs
10
+
11
+ - `rewrite_phrases: list[str]`
12
+ - Output of Stage 1 query rewriting (comma-separated "tag-shaped" phrases).
13
+ - Not canonical tags. Not underscored. High recall is preferred.
14
+
15
+ - `allow_nsfw_tags: bool`
16
+ - If false, filter out tags in the project's `nsfw_tags` set.
17
+
18
+ - `verbose: bool`
19
+ - If true, return per-phrase debug reports.
20
+
21
+ ---
22
+
23
+ ## Normalization and phrase expansion
24
+
25
+ 1) Normalize rewrite phrases for internal processing:
26
+ - lowercase
27
+ - strip leading/trailing whitespace
28
+ - collapse internal whitespace to a single space
29
+
30
+ 2) Treat the phrase list as a **set** (dedupe after normalization).
31
+
32
+ 3) **Head-noun expansion**:
33
+ - For each multi-token phrase, add its head noun (last token) as an additional phrase.
34
+ - Apply the same set semantics so duplicates are processed once.
35
+
36
+ Example:
37
+ - Input phrases: `["big shirt", "grey shirt"]`
38
+ - Final phrase set: `{"big shirt", "grey shirt", "shirt"}`
39
+
40
+ ---
41
+
42
+ ## Candidate generation per phrase (FastText neighbors + canonicalization)
43
+
44
+ For each phrase `p` in the final phrase set:
45
+
46
+ 1) Convert to lookup form:
47
+ - `lookup = p.replace(" ", "_")`
48
+
49
+ 2) Retrieve neighbors using FastText:
50
+ - `neighbors = fasttext.most_similar(lookup, topn=per_phrase_k)`
51
+ - Note: FastText neighbors may include alias tokens and other non-canonical strings.
52
+
53
+ 3) **Project neighbor tokens to canonical tags** (alias -> canonical):
54
+ - If a neighbor token is already a canonical tag (token is in `tag_counts` OR token has a TF-IDF row in `tag_to_row_index`), it maps to itself.
55
+ - Else if it is an alias, map it via `alias2tags[token]` (may map to multiple canonical tags).
56
+ - Else, drop it (not in closed vocabulary).
57
+
58
+ 4) **Deduplicate by canonical tag** within this phrase:
59
+ - Keep the canonical tag with the highest FastText similarity among all tokens that mapped to it.
60
+ - Record the token that achieved that max similarity as `alias_token` for verbose reporting ("best token wins").
61
+
62
+ 5) **Exact-match injection**:
63
+ - Project the phrase's own `lookup` through the same projection logic.
64
+ - For each canonical tag produced by that projection, inject it into the candidate set with:
65
+ - `score_fasttext = 1.0`
66
+ - `alias_token = lookup`
67
+ - This ensures the phrase canonical appears even though `most_similar()` often does not return the query token itself.
68
+
69
+ 6) Apply NSFW filtering (if `allow_nsfw_tags=False`):
70
+ - Drop candidate canonical tags that are present in `nsfw_tags`.
71
+
72
+ Result: for each phrase, we have a set of canonical candidate tags with:
73
+ - `score_fasttext`
74
+ - `alias_token` (token that produced the best FastText score for that canonical tag)
75
+
76
+ ---
77
+
78
+ ## Context similarity (TF-IDF -> SVD cosine)
79
+
80
+ Stage 2 computes one **query context vector** for the entire request:
81
+
82
+ 1) Build a pseudo TF-IDF vector from the **final phrase set** (deduped + head nouns):
83
+ - Convert each phrase to underscore form (same `lookup` rule).
84
+ - Terms that exist in the TF-IDF vocabulary (underscore lookups) contribute `(term_count * idf(term))`.
85
+ - OOV terms contribute nothing (but may be reported in verbose mode).
86
+
87
+ 2) Project to SVD space and L2-normalize:
88
+ - `query_vec = normalize(svd.transform(tfidf_vec))`
89
+
90
+ If the query vector has zero norm (no recognized TF-IDF terms), then `query_has_context = False` and:
91
+ - `score_context = None` for all candidates
92
+ - `score_combined = score_fasttext` (FastText-only)
93
+
94
+ If `query_has_context = True`, compute per-candidate cosine similarity when possible:
95
+ - For tags that have a TF-IDF/SVD row: `score_context_by_tag[tag] = dot(query_vec, reduced_matrix_norm[row])`
96
+ - For tags that lack a TF-IDF/SVD row: initial `score_context = None` (may be imputed per-phrase)
97
+
98
+ ### Missing context policy (per-phrase, q=0.10)
99
+
100
+ If `query_has_context = True` and a candidate tag has `score_context = None`:
101
+ - For that phrase, compute `default_context_for_phrase` as the 10th percentile (q=0.10) of the available (non-None) context scores among that phrase's candidates.
102
+ - If there are no available context scores for that phrase, `default_context_for_phrase = 0.0`.
103
+ - Impute missing context scores using `default_context_for_phrase` and mark:
104
+ - `context_imputed = True`
105
+ Otherwise:
106
+ - `context_imputed = False`
107
+
108
+ ---
109
+
110
+ ## Score fusion (FastText + Context)
111
+
112
+ Compute a fused score per phrase candidate:
113
+
114
+ - If `query_has_context = False`:
115
+ - `score_combined = score_fasttext`
116
+
117
+ - Else:
118
+ - `score_combined = (1 - context_weight) * score_fasttext + context_weight * score_context`
119
+ - (`score_context` may be imputed as described above)
120
+
121
+ ---
122
+
123
+ ## Per-phrase truncation and must-include rule
124
+
125
+ After scoring candidates for a phrase:
126
+ - Sort by `score_combined` descending.
127
+ - Keep top `per_phrase_final_k` (typically 10).
128
+
129
+ **Must-include rule (pinned exact phrase tags)**:
130
+ - Let `required_tags` be the canonical tag(s) produced by projecting the phrase's own `lookup` (`projected_lookup`).
131
+ - Each required tag must appear in that phrase's final top `per_phrase_final_k` list, even if its fused score would otherwise place it below the cutoff.
132
+ - If the list is full, evict the lowest-ranked tag that is *not* required.
133
+ - Note: `required_tags` may contain multiple canonicals if `alias2tags` maps a token to multiple tags.
134
+
135
+ This rule applies **only to the phrase's own required tags**. It does not inject tags into other phrases' lists.
136
+
137
+ ---
138
+
139
+ ## Merge across phrases (global candidate pool)
140
+
141
+ A canonical tag may appear in multiple per-phrase top-K lists. Stage 2 deduplicates tags into a single global record.
142
+
143
+ - `sources` is the union of phrases whose per-phrase lists contained the tag.
144
+ - `score_fasttext` is the maximum FastText score observed for the tag across those phrases.
145
+ - `score_context` is the maximum context cosine observed for the tag across those phrases (with `None` treated as missing).
146
+ - `score_combined` is the maximum fused score observed for the tag across those phrases.
147
+
148
+ Note:
149
+ - These maxima may come from different phrases; the global candidate row does not necessarily correspond to any single phrase's row.
150
+ - For tags with a TF-IDF row, `score_context` is phrase-invariant. Differences across phrases only arise for tags whose context score was imputed.
151
+
152
+ Finally:
153
+ - Sort global candidates by `score_combined` descending.
154
+ - Return top `global_k` candidates (and optionally all candidates if the app needs them).
155
+
156
+ ---
157
+
158
+ ## Output schema
159
+
160
+ ### Stage 2 return (non-verbose)
161
+ - `candidates: list[Candidate]` (ordered)
162
+ - `tag: str` (canonical)
163
+ - `score_combined: float`
164
+ - `score_fasttext: float | None`
165
+ - `score_context: float | None` (None only when `query_has_context=False` or when missing)
166
+ - `count: int | None`
167
+ - `sources: list[str]`
168
+
169
+ ### Optional per-phrase debug report (verbose)
170
+ For each phrase:
171
+ - `phrase: str`
172
+ - `normalized: str`
173
+ - `lookup: str`
174
+ - `tfidf_vocab: bool` (lookup is in TF-IDF vocabulary)
175
+ - `oov_terms: list[str]`
176
+ - `candidates: list[CandidateRow]` (top per-phrase list)
177
+ - `tag: str`
178
+ - `alias_token: str`
179
+ - `score_fasttext: float`
180
+ - `score_context: float | None`
181
+ - `score_combined: float`
182
+ - `context_imputed: bool`
183
+ - `count: int | None`
184
+
185
+ ---
186
+
187
+ ## Determinism and performance constraints
188
+
189
+ - Artifact loading is **lazy** (load-on-first-use, cached thereafter).
190
+ - No feature flags for old/new behavior: delete old code paths.
191
+ - Logging must be read-only and must not affect results.
192
+
193
+ ---
194
+
195
+ ## NSFW tag source
196
+
197
+ - `nsfw_tags` is sourced from `word_rating_probabilities.csv` with `NSFW_THRESHOLD=0.95` as implemented in `psq_rag.retrieval.state`.
docs/rewrite_contract.md ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Stage 1 — Query Rewriting Contract
2
+
3
+ ## Purpose
4
+
5
+ Stage 1 (“Query Rewriting”) converts a free-form natural-language prompt into a
6
+ comma-separated list of short, tag-shaped phrases suitable for downstream
7
+ retrieval over a closed image-tag vocabulary.
8
+
9
+ This stage is not tagging, not normalization, and not validation.
10
+ Its sole role is to rewrite user intent into a retrieval-friendly surface form
11
+ with high recall.
12
+
13
+ ---
14
+
15
+ ## Inputs
16
+
17
+ - User prompt: an arbitrary string entered by the user.
18
+ - The input may include:
19
+ - natural language
20
+ - comma-separated phrases
21
+ - Stable-Diffusion-style parentheses and weights
22
+ - punctuation and spacing artifacts
23
+
24
+ No structural guarantees are assumed about the input.
25
+
26
+ ---
27
+
28
+ ## Pre-Rewrite Heuristics (Non-LLM)
29
+
30
+ Before the LLM rewrite is invoked, the system performs a lightweight heuristic
31
+ extraction:
32
+
33
+ - The prompt is split on "." and ","
34
+ - Segments with three or fewer whitespace-separated tokens are retained
35
+ - Case-insensitive deduplication is applied
36
+
37
+ This produces a small list of user-provided phrases that may later be appended
38
+ to the rewrite output for retrieval support.
39
+
40
+ This heuristic:
41
+ - is lossy
42
+ - is not authoritative
43
+ - exists only to preserve short explicit phrases if the rewrite fails or omits them
44
+
45
+ ---
46
+
47
+ ## Rewrite Mechanism
48
+
49
+ Stage 1 uses a single deterministic LLM call with:
50
+
51
+ - temperature = 0.0
52
+ - no retries
53
+ - no streaming
54
+ - no structured output enforcement
55
+
56
+ The system prompt instructs the model to:
57
+
58
+ - output a comma-separated list
59
+ - use short, literal, tag-shaped phrases
60
+ - preserve coherent multi-word visual concepts
61
+ - avoid inventing details
62
+ - avoid demographic inference
63
+ - avoid guessing identities
64
+
65
+ The LLM output is treated as plain text.
66
+
67
+ ---
68
+
69
+ ## Output Format
70
+
71
+ On success, Stage 1 returns:
72
+
73
+ - a single string
74
+ - containing comma-separated phrases
75
+ - with arbitrary spacing normalized
76
+ - truncated to a maximum of approximately 800 characters
77
+
78
+ No further parsing, validation, or canonicalization is applied at this stage.
79
+
80
+ The rewrite may:
81
+ - reorder concepts
82
+ - merge or split phrasing
83
+ - introduce additional generic visual concepts (e.g. "white background")
84
+
85
+ ---
86
+
87
+ ## Failure and Fallback Behavior
88
+
89
+ If the LLM call:
90
+
91
+ - errors
92
+ - produces a refusal-like response
93
+ - returns empty output
94
+
95
+ then Stage 1 returns an empty string.
96
+
97
+ In downstream stages, this empty rewrite may be supplemented by the heuristic
98
+ phrases extracted earlier, but Stage 1 itself does not attempt recovery.
99
+
100
+ ---
101
+
102
+ ## Explicit Non-Guarantees
103
+
104
+ Stage 1 does not guarantee that:
105
+
106
+ - output phrases correspond to known vocabulary tags
107
+ - phrases are unique
108
+ - phrases are canonicalized
109
+ - phrases are mutually exclusive
110
+ - all user concepts are preserved
111
+ - added concepts reflect ground truth
112
+
113
+ Stage 2 must not assume any of the above.
114
+
115
+ ---
116
+
117
+ ## Contract Boundary with Stage 2
118
+
119
+ Stage 1 guarantees only that:
120
+
121
+ - output is a comma-separated list of short phrases
122
+ - phrases are intended to be retrieval queries, not canonical tags
123
+ - output is deterministic for a given input
124
+
125
+ Stage 2 is responsible for:
126
+
127
+ - normalization
128
+ - deduplication
129
+ - head-noun expansion
130
+ - vocabulary grounding
131
+ - alias handling
132
+ - scoring and ranking
133
+
134
+ ---
135
+
136
+ ## Summary (Interview-Safe)
137
+
138
+ Stage 1 is a deterministic query-rewriting step that reshapes free-form text into
139
+ retrieval-friendly phrase queries. It intentionally favors recall and
140
+ surface-form alignment over correctness or canonicalization, delegating all
141
+ grounding and validation to later stages.
docs/stage3_contract.md ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ STAGE 3 CONTRACT: CLOSED-SET SELECTION
2
+
3
+ Purpose
4
+ -------
5
+ Stage 3 performs closed-set selection over the candidate set produced by Stage 2.
6
+ It must output only canonical tags drawn from the provided candidates.
7
+ No hallucinated or novel tags are permitted.
8
+
9
+ Stage 3 is not retrieval. Stage 2 already performs candidate generation and
10
+ retrieval grounding. Stage 3 is selection / reranking only.
11
+
12
+ Inputs
13
+ ------
14
+
15
+ 1) User prompt
16
+ - original_prompt: str
17
+ The user's original text prompt. This is the primary semantic signal used by
18
+ Stage 3.
19
+
20
+ 2) Candidate set (from Stage 2)
21
+ - candidates: List[Candidate]
22
+
23
+ Each Candidate corresponds to one canonical tag.
24
+
25
+ Required fields:
26
+ - tag: str
27
+ Canonical tag name (e621-style snake_case). Unique within this list.
28
+
29
+ - count: Optional[int]
30
+ Frequency/count from the tag corpus. Used only as a hint or ordering signal.
31
+
32
+ Optional fields (may be present but must not be required by Stage 3):
33
+ - score_fasttext: Optional[float]
34
+ - score_context: Optional[float]
35
+ - score_combined: Optional[float]
36
+ - alias_token: Optional[str] (debug / evidence only)
37
+ - sources: Optional[List[str]] (debug / evidence only)
38
+
39
+ Contract note:
40
+ Stage 3 must not rely on optional fields to function correctly.
41
+
42
+ 3) Selection mode parameters (system-controlled)
43
+ These are not user-facing.
44
+
45
+ - mode: "single_shot" | "chunked_map_union"
46
+
47
+ If mode == "chunked_map_union":
48
+ - chunk_size: int (e.g., 50–80)
49
+ - per_chunk_budget: int (soft cap, e.g., 10–20)
50
+
51
+ Optional:
52
+ - debug_rationale: bool (default false in production)
53
+
54
+ LLM-Facing Representation
55
+ -------------------------
56
+ Candidates are presented to the LLM as an indexed list per call.
57
+
58
+ For each call:
59
+ - Indices are local to that call: 1..N_local
60
+ - A mapping idx -> canonical tag is maintained by the system
61
+
62
+ Each candidate line should include:
63
+ - local index
64
+ - canonical tag
65
+ - optionally count
66
+
67
+ Example:
68
+ 27. blue_fur (count=12034)
69
+
70
+ Indices are not required to be stable across calls and must be mapped back
71
+ immediately after parsing.
72
+
73
+ Outputs
74
+ -------
75
+
76
+ Primary output:
77
+ - selected_tags: List[str]
78
+ Canonical tag names. Must be a subset of the provided candidate tags.
79
+
80
+ Optional outputs (recommended for development and smoke tests):
81
+ - why_by_tag: Dict[str, str]
82
+ Compact rationale code per selected tag (only if debug_rationale == true).
83
+
84
+ - stage3_diagnostics: Dict[str, Any]
85
+ Parse and validation statistics (for testing and analysis).
86
+
87
+ Per-Call LLM Output Schema
88
+ -------------------------
89
+ Each LLM call must return valid JSON of the following form:
90
+
91
+ {
92
+ "selections": [
93
+ { "i": 27, "why": "explicit" },
94
+ { "i": 6, "why": "strong_implied" }
95
+ ]
96
+ }
97
+
98
+ Fields:
99
+ - i: int
100
+ Local index within that call.
101
+
102
+ - why: str
103
+ Rationale code (required only if debug_rationale == true).
104
+
105
+ Allowed rationale codes:
106
+ - explicit
107
+ - strong_implied
108
+ - weak_implied
109
+ - style_or_meta
110
+ - other
111
+
112
+ Validation Rules
113
+ ----------------
114
+
115
+ Per-call validation:
116
+ - selections is a list
117
+ - i is an integer
118
+ - 1 <= i <= N_local
119
+ - indices are unique within the call
120
+ - if debug_rationale == true, why must be one of the allowed codes
121
+
122
+ Global validation (after mapping indices to tags):
123
+ - every selected tag must exist in the Stage 2 candidate set
124
+ - duplicates removed by canonical tag identity
125
+ - final selected_tags is the deterministic result of mapping and union
126
+
127
+ Policy note:
128
+ If NSFW tags are disallowed, Stage 2 must remove them. Stage 3 does not require
129
+ policy flags as input. Defense-in-depth checks are allowed but not required.
130
+
131
+ Chunking and Aggregation Behavior
132
+ --------------------------------
133
+
134
+ Single-shot mode:
135
+ - One LLM call over all candidates
136
+ - Output parsed, validated, and mapped
137
+
138
+ Chunked Map + Union mode (no LLM reduce):
139
+ - Split candidate list into chunks of size chunk_size
140
+ - For each chunk:
141
+ - enumerate locally 1..N_local
142
+ - run one LLM call
143
+ - parse and validate
144
+ - map indices to canonical tags immediately
145
+
146
+ - Aggregate across chunks by union on canonical tag:
147
+ - why_by_tag[tag] chosen by majority vote or first occurrence
148
+
149
+ No second LLM consolidation or pruning call is implied or required.
150
+
151
+ Ordering of Final Output
152
+ ------------------------
153
+ The final selected_tags list must be ordered deterministically using:
154
+ 1) descending why score (as defined by the system)
155
+ 2) tie-break by descending count
156
+
157
+ Smoke Test Requirements
158
+ -----------------------
159
+ Stage 3 smoke tests should report:
160
+ - JSON parse success rate
161
+ - invalid index rate
162
+ - duplicate index rate
163
+ - selection size distribution
164
+ - union size distribution (chunked mode)
165
+ - stability across repeated runs on identical input
166
+ - quality metrics vs ground truth where available:
167
+ precision / recall / F1 over tag sets
168
+
169
+ Smoke test results are used to empirically choose between single_shot and
170
+ chunked_map_union for typical candidate set sizes.
e621naturallanguagedataset.txt ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Dataset Card for furry-e621-sfw-7m-hq
2
+ Dataset Summary
3
+
4
+ This is 6.92 M captions of the images from the safe-for-work (SFW) split of e621 ("e926"). It extends to January 2023, before the widespread advent of machine learning images. It includes captions created by LLMs and a custom multilabel classifier along with CogVLM. There are 8 LLM (mistralai/Mistral-7B-v0.1) and 1 CogVLM (THUDM/CogVLM) captions per image.
5
+
6
+ Most captions are substantially larger than 77 tokens and are unsuitable for discrimination using current CLIP-based approaches.
7
+ Languages
8
+
9
+ The captions are in English.
10
+ Original Categorized Tags For LLM Captions
11
+
12
+ The tags were selected for safe-for-work attributes and were filtered down to approximately 7,000 tags. A multilabel classifier was created using DINOv2 giant (facebook/dinov2-giant) with the pooled output of the visual encoder. The classifier was trained with APL loss (gamma -4, -6, and -8) for 1000 epochs and the best model achieved an AP of 0.342 and F1 of 0.5576.
13
+
14
+ These tags were then manually categorized into the following labels:
15
+
16
+ animals_and_anthropomorphic_features
17
+ clothing_and_accessories
18
+ characters_and_gender
19
+ hairstyle
20
+ background_and_setting
21
+ number_of_characters
22
+ miscellaneous
23
+ actions_and_poses
24
+ colors
25
+ furniture_and_objects
26
+ body_and_body_parts
27
+ emotions_and_expressions
28
+
29
+ Data Instances
30
+
31
+ An example of a row:
32
+
33
+ {
34
+ "id": 3556547,
35
+ "md5": "1ae8668745b8fefb83e79d0c77e31a4e",
36
+ "caption_cogvlm": "The image depicts an anthropomorphic creature, possibly a possum, sitting at a desk in front of a computer. The creature has a somewhat disgruntled expression, with fur that is a mix of black and white. The background is muted, with a grayish tone, and the desk has a yellowish hue. The creature is wearing a black shirt and is seen typing on the keyboard. The overall mood of the image is somber and introspective.",
37
+ "caption_llm_0": "a solo female marsupial, specifically a possum, sitting on a chair in front of a simple background. She has black hair and is wearing clothing accessories like a laptop and keyboard. The possum has anthropomorphic features such as bipedalism, snout, and whiskers. She is holding an object while painting or typing at the computer desk with her fingers. The background consists of sky and clouds with countershading present.",
38
+ "caption_llm_1": "A solo female marsupial, specifically a possum, sitting at a desk with her black-furred body and white-furred face. she is wearing black clothing and a shirt, as well as having pink markings on her nose. the possum is holding an object while painting or typing on the computer. her hair is black, and she has grey fur on her body with brown fur accents. she has breasts and tufts of hair on her head. the background consists of furniture such as a chair, table, laptop, monitor, keyboard and container in various shades of white or grey colors.",
39
+ "caption_llm_2": "a female possum, sitting on a chair and looking at the viewer. She has black hair and is wearing a black shirt over her white body with two-tone fur. The background is simple, with a sky visible through the window. She's holding an object while painting at her desk, which also contains a laptop, monitor, keyboard, and computer mouse. Her nose is pink and she has narrowed eyes as she reacts to something in the scene.",
40
+ "caption_llm_3": "A solo female marsupial, specifically a possum, sitting on a chair in front of a computer. the possum has white fur on its body and black fur on its face, creating two-tone fur. it is wearing black clothing and has pink markings on its nose. the background shows furniture such as a desk, table, and container with various objects like laptop, monitor, computer mouse and keyboard. the possum's hair is black while it looks at the viewer with narrowed eyes and half-closed eyes. it also holds an object in one hand while painting with the other hand or typing at the computer.",
41
+ "caption_llm_4": "a solo, female marsupial possum sitting at a desk, typing on a laptop while wearing a black shirt and white fur. The background is detailed and set indoors. The animal has humanoid hands, bipedal stance, chest tufts, and pink nose. It's clothed in black clothing with black ears and topwear. The possum's body has two-tone fur - one color being white and the other being black. A computer or table can be seen in the background as part of the furniture setting.",
42
+ "caption_llm_5": "A solo, female marsupial, specifically a possum, sitting and looking at an object while typing. the animal has humanoid hands and is clothed in black clothing with black ears and a black shirt. its fur is two-toned with white body and white fur, as well as pink nose. the background is detailed and set indoors.",
43
+ "caption_llm_6": "a solo, female marsupial possum sitting at a desk, typing on a laptop while wearing a shirt and displaying humanoid hands. The animal has bags under its eyes, eye bags, and tufts of fur on its chest. It also shows teeth when it frowns or displays anger with clenched teeth. The possum's hair is visible in the artwork.",
44
+ "caption_llm_7": "A solo, female marsupial possum sitting at a desk, typing on a laptop while wearing black clothing and black topwear. the possum has bipedal humanoid features, chest tufts, bags under its eyes, and breasts. its fur is two-toned with black ears and white body/fur. it also has visible teeth and pink nose. the background includes furniture such as a computer and table.",
45
+ "tags_synthetic_categorized": "{\"animals_and_anthropomorphic_features\":[\"anthro\",\"biped\",\"feral\",\"snout\",\"whiskers\"],\"number_of_characters\":[\"solo\"],\"clothing_and_accessories\":[\"clothing\",\"fur\",\"topwear\",\"clothed\",\"shirt\"],\"characters_and_gender\":[\"female\"],\"furniture_and_objects\":[\"computer\",\"furniture\",\"table\",\"laptop\",\"chair\",\"desk\",\"container\",\"computer_mouse\",\"keyboard\",\"monitor\"],\"colors\":[\"white_body\",\"white_fur\",\"two_tone_fur\",\"black_body\",\"black_fur\",\"two_tone_body\",\"black_clothing\",\"black_topwear\",\"black_shirt\",\"pink_nose\",\"grey_body\",\"grey_fur\",\"brown_body\",\"brown_fur\",\"black_nose\"],\"actions_and_poses\":[\"sitting\",\"looking_at_viewer\",\"holding_object\",\"painting\",\"typing\",\"standing\",\"on_chair\",\"looking_at_object\"],\"hairstyle\":[\"hair\",\"black_hair\"],\"background_and_setting\":[\"inside\",\"simple_background\",\"outside\",\"detailed_background\",\"sky\",\"cloud\",\"countershading\"],\"body_and_body_parts\":[\"breasts\",\"tuft\",\"teeth\",\"fingers\",\"markings\",\"5_fingers\",\"eyebrows\",\"arm_support\",\"eye_bags\"],\"miscellaneous\":[\"text\"],\"emotions_and_expressions\":[\"open_mouth\",\"smile\",\"narrowed_eyes\",\"reaction_image\",\"half-closed_eyes\"],\"species_or_animal_type\":[\"didelphid\",\"mammal\",\"virginia_opossum\",\"marsupial\"]}\r\n",
46
+ "tags_ground_truth_categorized": "{\"emotions_and_expressions\":[\"angry\",\"clenched_teeth\",\"frown\",\"reaction_image\",\"teeth_showing\"],\"animals_and_anthropomorphic_features\":[\"anthro\",\"biped\",\"chest_tuft\",\"humanoid_hands\"],\"body_and_body_parts\":[\"bags_under_eyes\",\"breasts\",\"eye_bags\",\"teeth\",\"teeth_visible\",\"tuft\"],\"colors\":[\"black_clothing\",\"black_ears\",\"black_shirt\",\"black_topwear\",\"pink_nose\",\"two_tone_body\",\"two_tone_fur\",\"white_body\",\"white_fur\"],\"clothing_and_accessories\":[\"clothed\",\"clothing\",\"fur\",\"shirt\",\"t-shirt\",\"topwear\"],\"furniture_and_objects\":[\"computer\",\"desk\",\"furniture\",\"laptop\",\"table\"],\"background_and_setting\":[\"detailed_background\",\"inside\"],\"characters_and_gender\":[\"female\"],\"hairstyle\":[\"hair\"],\"actions_and_poses\":[\"looking_at_object\",\"sitting\",\"typing\"],\"number_of_characters\":[\"solo\"],\"species_or_animal_type\":[\"mammal\",\"marsupial\",\"possum\"]}\r\n",
47
+ }
48
+
49
+ LLM-derived Captions
50
+
51
+ The caption_llm_x field was produced with the following prompt using the mistralai/Mistral-7B-v0.1 weights:
52
+
53
+ Please make a detailed description, one paragraph long, of the image using this JSON of categorized tags:
54
+
55
+ {{ tags }}
56
+
57
+ For every nth image where n was odd, the text "artwork of {{ characters }}." was appended for all characters with >= 10 images.
58
+
59
+ For the first 1-4 captions, synthetic tags were used. For the last 5-8 captions, ground truth tags were used.
60
+
61
+ For every caption, two categories were dropped out of the categorized tags each time (excluding species) to force the LLM to focus on different aspects of the image.
62
+
63
+ For a small number of images, LLM captions were not computed. These are left as empty strings for these images.
64
+ CogVLM-derived Captions
65
+
66
+ The caption_cogvlm field was produced with the following prompt using the THUDM/CogVLM weights:
67
+
68
+ Please make a detailed description, one paragraph long, of the image using this JSON of categorized tags:
69
+
70
+ {{ tags }}
71
+
72
+ The tags provided were the ground truth, categorized tags.
73
+
74
+ CogVLM captions often display repetitive prefixes. You can remove them with:
75
+
76
+ REPEATED_OPENINGS = [
77
+ ('The image showcases ', ''),
78
+ ('The image portrays ', ''),
79
+ ('The image appears to be ', ''),
80
+ ('The image is ', ''),
81
+ ('The image depicts ', ''),
82
+ ('The image features ', ''),
83
+ ('This image showcases ', ''),
84
+ ('This image portrays ', ''),
85
+ ('This image appears to be ', ''),
86
+ ('This image is ', ''),
87
+ ('This image depicts ', ''),
88
+ ('This image features ', ''),
89
+ ('In this picture, ', ''),
90
+ ('In this artwork, ', 'Artwork of '),
91
+ ('In this illustration, ', 'Illustration of '),
92
+ ('In this depiction, ', ''),
93
+ ('In this piece, ', ''),
94
+ ('In this image, ', ''),
95
+ ('In this art piece, ', 'Art of '),
96
+ ('In this scene, ', ''),
97
+ ]
98
+ def postprocess_caption(caption: str):
99
+ for often_repeated, replacer in REPEATED_OPENINGS:
100
+ if often_repeated in caption:
101
+ caption = caption.replace(often_repeated, replacer, 1).capitalize()
102
+ return caption
103
+
104
+ Data Splits
105
+ train
106
+ furry-e621-sfw-7m-hq 768859
107
+ Dataset Creation
108
+ Source Data
109
+
110
+ Collected from e621 according to their rate-limiting instructions on archiving content.
111
+ Discussion of Biases
112
+
113
+ The captions are biased to the results of the multilabel classifier and the CogVLM model.
114
+ Known Limitations
115
+
116
+ The LLM-derived captions commonly hallucinate text and may contain a small number of captions that are corrupted by repeating tokens or tag lists. The CogVLM-derived captions have more correct OCR but may also occasionally hallucinate text or small details.
117
+
118
+ For a small number of images, LLM captions were not computed. These are left as empty strings for these images.
119
+
120
+ While the images are labeled as "safe", they were not inspected for safety and may contain inappropriate subject matter.
121
+ Additional Information
122
+ Dataset Curators
123
+
124
+ Caption Emporium
125
+ Downloading the Images
126
+
127
+ Please refer to this issue.
128
+ Licensing Information
129
+
130
+ The dataset is available under the Creative Commons Attribution-ShareAlike 4.0 license (CC BY-SA 4.0).
131
+ Citation Information
132
+
133
+ @misc{furry-e621-sfw-7m-hq,
134
+ author = { Caption Emporium },
135
+ title = {furry-e621-sfw-7m-hq},
136
+ year = {2024},
137
+ publisher = {Huggingface},
138
+ journal = {Huggingface repository},
139
+ howpublished = {\url{https://huggingface.co/datasets/CaptionEmporium/furry-e621-sfw-7m-hq}},
140
+ }
fluffyrock_3m.csv ADDED
The diff for this file is too large to render. See raw diff
 
mascotimages/transparentsquirrel.png ADDED

Git LFS Details

  • SHA256: 8e18321c9051b82ab18932ef9ed4052915659b83ef2065050600d0c06bddb9e7
  • Pointer size: 131 Bytes
  • Size of remote file: 257 kB
predict_all_tags_from_dump.ipynb ADDED
@@ -0,0 +1,721 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "id": "55c95870",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import csv\n",
11
+ "import gzip\n",
12
+ "from math import log\n",
13
+ "from collections import Counter\n",
14
+ "from sys import maxsize\n",
15
+ "import numpy as np\n",
16
+ "import joblib\n",
17
+ "from collections import OrderedDict\n",
18
+ "from sklearn.metrics.pairwise import cosine_similarity\n",
19
+ "from collections import defaultdict\n",
20
+ "import sys\n",
21
+ "from scipy.sparse import dok_matrix\n",
22
+ "from sklearn.preprocessing import normalize\n",
23
+ "from sklearn.decomposition import TruncatedSVD\n",
24
+ "\n",
25
+ "\n",
26
+ "\n",
27
+ "posts_file = 'posts-2024-04-14.csv.gz'\n",
28
+ "fluffyrock_tags_list_file = 'fluffyrock_3m.csv'\n",
29
+ "\n",
30
+ "\n",
31
+ "def extract_artist_names(file_path):\n",
32
+ " \"\"\"\n",
33
+ " Extract artist names from a CSV file where each row contains tag information,\n",
34
+ " and the first column contains the tag's name. Artist tags start with 'by_'.\n",
35
+ "\n",
36
+ " :param file_path: Path to the CSV file\n",
37
+ " :return: A set containing artist names without the 'by_' prefix\n",
38
+ " \"\"\"\n",
39
+ " artists = set()\n",
40
+ "\n",
41
+ " # Open the CSV file and read it\n",
42
+ " with open(file_path, newline='', encoding='utf-8') as csvfile:\n",
43
+ " reader = csv.reader(csvfile)\n",
44
+ " \n",
45
+ " # Iterate over each row in the CSV file\n",
46
+ " for row in reader:\n",
47
+ " tag_name = row[0] # Assuming the first column contains the tag names\n",
48
+ " if tag_name.startswith('by_'):\n",
49
+ " # Strip 'by_' from the start of the tag name and add it to the set\n",
50
+ " artist_name = tag_name[3:] # Remove the first three characters 'by_'\n",
51
+ " artists.add(tag_name)\n",
52
+ "\n",
53
+ " return artists\n",
54
+ "\n",
55
+ "\n",
56
+ "def build_tag_list(tags, e621_rating_character, fav_count, artist_names):\n",
57
+ " results = []\n",
58
+ " \n",
59
+ " #score\n",
60
+ " score_value = min(1.0, (log(int(fav_count)+1) / 10))\n",
61
+ " rounded_score_value = round(score_value * 10)\n",
62
+ " results.append(f\"score: {rounded_score_value}\")\n",
63
+ " \n",
64
+ " #rating\n",
65
+ " results.append(\"rating:\" + e621_rating_character)\n",
66
+ " \n",
67
+ " #regular tags and artists\n",
68
+ " for tag in tags:\n",
69
+ " if tag in artist_names:\n",
70
+ " results.append(\"by_\" + tag)\n",
71
+ " else:\n",
72
+ " results.append(tag)\n",
73
+ " return results\n",
74
+ "\n",
75
+ "\n",
76
+ "def read_csv_as_dict(file_path):\n",
77
+ " \"\"\"\n",
78
+ " Generator function to read a gzipped CSV file and yield each row as a dictionary\n",
79
+ " where keys are the column names and values are the data in each column.\n",
80
+ "\n",
81
+ " :param file_path: Path to the .csv.gz file\n",
82
+ " \"\"\"\n",
83
+ " \n",
84
+ " #counter=0\n",
85
+ " with gzip.open(file_path, 'rt', newline='', encoding='utf-8') as gz_file:\n",
86
+ " csv.field_size_limit(1000000)\n",
87
+ " reader = csv.DictReader(gz_file)\n",
88
+ " for row in reader:\n",
89
+ " #counter += 1\n",
90
+ " #if counter % 100 == 0:\n",
91
+ " yield row\n",
92
+ " \n",
93
+ " \n",
94
+ "def process_tags_from_csv(file_path, artist_names):\n",
95
+ " \"\"\"\n",
96
+ " Generator function that reads rows from a CSV file, processes each row to extract and\n",
97
+ " build tag lists, and yields these lists one at a time.\n",
98
+ "\n",
99
+ " :param file_path: The path to the gzipped CSV file.\n",
100
+ " :param artist_names: A set containing all artist names for tag processing.\n",
101
+ " :return: Yields lists of tags for each row.\n",
102
+ " \"\"\"\n",
103
+ " for row in read_csv_as_dict(file_path):\n",
104
+ " base_tags = row['tag_string'].split(' ')\n",
105
+ " rating_character = row['rating']\n",
106
+ " fav_count = row['fav_count']\n",
107
+ " all_tags = build_tag_list(base_tags, rating_character, fav_count, artist_names)\n",
108
+ " yield all_tags\n",
109
+ " \n",
110
+ " \n",
111
+ "def construct_pseudo_vector(pseudo_doc_terms, idf_loaded, tag_to_column_loaded):\n",
112
+ " # Initialize a vector of zeros with one entry per column in tag_to_column_loaded\n",
113
+ " pseudo_vector = np.zeros(len(tag_to_column_loaded))\n",
114
+ " \n",
115
+ " # Fill in the vector for terms in the pseudo document\n",
116
+ " for term in pseudo_doc_terms:\n",
117
+ " if term in tag_to_column_loaded:\n",
118
+ " index = tag_to_column_loaded[term]\n",
119
+ " pseudo_vector[index] = idf_loaded.get(term, 0)\n",
120
+ " \n",
121
+ " # Return the vector as a 2D array for compatibility with SVD transform\n",
122
+ " return pseudo_vector.reshape(1, -1)"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "code",
127
+ "execution_count": null,
128
+ "id": "0a9becfd",
129
+ "metadata": {},
130
+ "outputs": [],
131
+ "source": [
132
+ "all_artist_names = extract_artist_names(fluffyrock_tags_list_file)\n",
133
+ "\n",
134
+ "tag_count = Counter()\n",
135
+ "min_occurrences = 200\n",
136
+ " \n",
137
+ "for all_tags in process_tags_from_csv(posts_file, all_artist_names):\n",
138
+ " tag_count.update(all_tags)\n",
139
+ " \n",
140
+ "\n",
141
+ "# Apply the counting logic from the first code snippet\n",
142
+ "sorted_tags = tag_count.most_common()\n",
143
+ "filtered_tags = [tag for tag, count in sorted_tags if count >= min_occurrences]\n",
144
+ "\n",
145
+ "# Print tag counts before and after filtering\n",
146
+ "print(\"Tag count before filtering: \", len(tag_count))\n",
147
+ "print(\"Tag count after filtering: \", len(filtered_tags))"
148
+ ]
149
+ },
150
+ {
151
+ "cell_type": "code",
152
+ "execution_count": null,
153
+ "id": "56f8d7cd",
154
+ "metadata": {},
155
+ "outputs": [],
156
+ "source": [
157
+ "# Initialize a dictionary to hold the co-occurrences for each tag in filtered_tags\n",
158
+ "# Using a nested defaultdict for automatic handling of missing keys\n",
159
+ "pseudo_docs = defaultdict(lambda: defaultdict(int))\n",
160
+ "\n",
161
+ "# Number of rows processed\n",
162
+ "total_rows_processed = 0\n",
163
+ "\n",
164
+ "# Read each row and process the tags\n",
165
+ "for all_tags in process_tags_from_csv(posts_file, all_artist_names):\n",
166
+ " # Filter the tags in the current list to include only those in filtered_tags\n",
167
+ " filtered_tag_list = [tag for tag in all_tags if tag in filtered_tags]\n",
168
+ " \n",
169
+ " # For each tag in the filtered list\n",
170
+ " for tag in filtered_tag_list:\n",
171
+ " # For each co-occurring tag in the same list\n",
172
+ " for co_occur_tag in filtered_tag_list:\n",
173
+ " if co_occur_tag != tag:\n",
174
+ " pseudo_docs[tag][co_occur_tag] += 1\n",
175
+ "\n",
176
+ " # Count rows processed so far for progress monitoring\n",
177
+ " total_rows_processed += 1\n",
178
+ " if total_rows_processed % 10000 == 0:\n",
179
+ " print(f\"Processed {total_rows_processed} rows\", file=sys.stderr)\n",
180
+ "\n",
181
+ "print(\"Processing complete.\")\n"
182
+ ]
183
+ },
184
+ {
185
+ "cell_type": "code",
186
+ "execution_count": null,
187
+ "id": "b1d011a5",
188
+ "metadata": {},
189
+ "outputs": [],
190
+ "source": [
191
+ "# Number of pseudo-documents\n",
192
+ "N = len(pseudo_docs)\n",
193
+ "\n",
194
+ "# Calculate TF and DF\n",
195
+ "tf = {}\n",
196
+ "df = {}\n",
197
+ "for doc, terms in pseudo_docs.items():\n",
198
+ " tf[doc] = {}\n",
199
+ " total_terms = sum(terms.values())\n",
200
+ " for term, count in terms.items():\n",
201
+ " tf[doc][term] = count / total_terms # Term Frequency\n",
202
+ " df[term] = df.get(term, 0) + 1 # Document Frequency\n",
203
+ " \n",
204
+ "# Ensure all terms are indexed\n",
205
+ "all_terms = set(df.keys())\n",
206
+ "term_to_column_index = {term: idx for idx, term in enumerate(all_terms)}\n",
207
+ "\n",
208
+ "# Calculate IDF\n",
209
+ "idf = {term: log((N + 1) / (df_val + 1)) for term, df_val in df.items()} # Adding 1 to prevent division by zero\n",
210
+ "\n",
211
+ "# Initialize the TF-IDF matrix\n",
212
+ "tfidf_matrix = dok_matrix((N, len(df)), dtype=float)\n",
213
+ "\n",
214
+ "# Mapping of tags to matrix rows\n",
215
+ "tag_to_row = {tag: idx for idx, tag in enumerate(pseudo_docs)}\n",
216
+ "\n",
217
+ "# Compute TF-IDF and fill the matrix\n",
218
+ "for doc, terms in tf.items():\n",
219
+ " row_idx = tag_to_row[doc]\n",
220
+ " for term, tf_val in terms.items():\n",
221
+ " col_idx = term_to_column_index[term] # Use term_to_column_index for column indexing\n",
222
+ " tfidf_matrix[row_idx, col_idx] = tf_val * idf[term]\n",
223
+ "\n",
224
+ "# Convert to CSR format for efficient row slicing\n",
225
+ "tfidf_matrix = tfidf_matrix.tocsr()\n",
226
+ "\n",
227
+ "print(\"TF-IDF matrix shape:\", tfidf_matrix.shape)\n"
228
+ ]
229
+ },
230
+ {
231
+ "cell_type": "code",
232
+ "execution_count": null,
233
+ "id": "b098a5fb",
234
+ "metadata": {},
235
+ "outputs": [],
236
+ "source": [
237
+ "# Choose the number of components for the reduced dimensionality\n",
238
+ "n_components = 300 # For example, reducing to 300 dimensions\n",
239
+ "\n",
240
+ "# Initialize the TruncatedSVD object\n",
241
+ "svd = TruncatedSVD(n_components=n_components, random_state=42)\n",
242
+ "\n",
243
+ "# Fit and transform the TF-IDF matrix\n",
244
+ "reduced_matrix = svd.fit_transform(tfidf_matrix)\n",
245
+ "\n",
246
+ "# 'reduced_matrix' now has a shape of (N, n_components), e.g., (N, 300), where N is the number of pseudo-documents"
247
+ ]
248
+ },
249
+ {
250
+ "cell_type": "code",
251
+ "execution_count": null,
252
+ "id": "023ae26f",
253
+ "metadata": {},
254
+ "outputs": [],
255
+ "source": []
256
+ },
257
+ {
258
+ "cell_type": "code",
259
+ "execution_count": null,
260
+ "id": "06ec21c4",
261
+ "metadata": {},
262
+ "outputs": [],
263
+ "source": [
264
+ "# Step 1: Construct TF vector for the pseudo-document\n",
265
+ "pseudo_doc_terms = [\"female\"]\n",
266
+ "pseudo_tfidf_vector = construct_pseudo_vector(pseudo_doc_terms, idf, term_to_column_index)\n",
267
+ "\n",
268
+ "# Assuming 'tfidf_matrix' is your original TF-IDF matrix and 'reduced_matrix' is obtained from Truncated SVD\n",
269
+ "# 'pseudo_tfidf_vector' is the TF-IDF vector for your pseudo-document, constructed as previously discussed\n",
270
+ "\n",
271
+ "# For the original TF-IDF matrix\n",
272
+ "# Compute cosine similarities\n",
273
+ "cosine_similarities_full = cosine_similarity(pseudo_tfidf_vector, tfidf_matrix).flatten()\n",
274
+ "print(\"Cosine similarities (full matrix):\", cosine_similarities_full)\n",
275
+ "# Identify the indices of the top 10 most similar tags\n",
276
+ "top_indices_full = np.argsort(cosine_similarities_full)[-10:][::-1]\n",
277
+ "\n",
278
+ "# For the reduced matrix\n",
279
+ "# Reduce the dimensionality of the pseudo-document vector\n",
280
+ "# Before calculating similarities, print the TF-IDF vectors\n",
281
+ "print(\"Pseudo TF-IDF vector:\", pseudo_tfidf_vector)\n",
282
+ "reduced_pseudo_vector = svd.transform(pseudo_tfidf_vector)\n",
283
+ "print(\"Reduced pseudo-document vector:\", reduced_pseudo_vector)\n",
284
+ "\n",
285
+ "# Compute cosine similarities in the reduced space\n",
286
+ "cosine_similarities_reduced = cosine_similarity(reduced_pseudo_vector, reduced_matrix).flatten()\n",
287
+ "print(\"Cosine similarities (reduced matrix):\", cosine_similarities_reduced)\n",
288
+ "\n",
289
+ "\n",
290
+ "# Identify the indices of the top 10 most similar tags in the reduced space, sorted from most to least similar\n",
291
+ "top_indices_reduced = np.argsort(cosine_similarities_reduced)[-10:][::-1]\n",
292
+ "\n",
293
+ "\n",
294
+ "# Convert indices to tag names using the inverse of your 'tag_to_row' mapping\n",
295
+ "# Printing the tag to index and index to tag mappings\n",
296
+ "print(\"tag_to_row mapping (partial):\", dict(list(tag_to_row.items())[:12])) # Print only first 12 for brevity\n",
297
+ "row_to_tag = {idx: tag for tag, idx in tag_to_row.items()}\n",
298
+ "print(\"row_to_tag mapping (partial):\", dict(list(row_to_tag.items())[:12]))\n",
299
+ "\n",
300
+ "# Generate lists of tags with their corresponding similarity scores\n",
301
+ "top_tags_full = [(row_to_tag[idx], cosine_similarities_full[idx]) for idx in top_indices_full]\n",
302
+ "top_tags_reduced = [(row_to_tag[idx], cosine_similarities_reduced[idx]) for idx in top_indices_reduced]\n",
303
+ "\n",
304
+ "# Output the results with scores\n",
305
+ "print(\"Most similar tags (Full Matrix):\")\n",
306
+ "for tag, score in top_tags_full:\n",
307
+ " print(f\"{tag}: {score:.4f}\")\n",
308
+ "\n",
309
+ "print(\"Most similar tags (Reduced Matrix):\")\n",
310
+ "for tag, score in top_tags_reduced:\n",
311
+ " print(f\"{tag}: {score:.4f}\")\n"
312
+ ]
313
+ },
314
+ {
315
+ "cell_type": "code",
316
+ "execution_count": null,
317
+ "id": "91753fa3",
318
+ "metadata": {},
319
+ "outputs": [],
320
+ "source": [
321
+ "#Save the model to a file\n",
322
+ "\n",
323
+ "# Package necessary components\n",
324
+ "components_to_save = {\n",
325
+ " 'idf': idf,\n",
326
+ " 'tag_to_column_index': term_to_column_index,\n",
327
+ " 'row_to_tag': row_to_tag, \n",
328
+ " 'reduced_matrix': reduced_matrix,\n",
329
+ " 'svd_model': svd\n",
330
+ "}\n",
331
+ "\n",
332
+ "# Save the components into a file\n",
333
+ "joblib.dump(components_to_save, 'components_file418.joblib')"
334
+ ]
335
+ },
336
+ {
337
+ "cell_type": "code",
338
+ "execution_count": null,
339
+ "id": "2e08dc1a",
340
+ "metadata": {},
341
+ "outputs": [],
342
+ "source": []
343
+ },
344
+ {
345
+ "cell_type": "code",
346
+ "execution_count": 3,
347
+ "id": "d066db2f",
348
+ "metadata": {},
349
+ "outputs": [
350
+ {
351
+ "name": "stdout",
352
+ "output_type": "stream",
353
+ "text": [
354
+ "Most similar tags (Reduced Matrix):\n",
355
+ "nameless_(arbuzbudesh): 0.0000\n",
356
+ "knotted_dildo: 0.0000\n",
357
+ "black_legs: 0.0000\n",
358
+ "disguise: 0.0000\n",
359
+ "lineup: 0.0000\n",
360
+ "olympics: 0.0000\n",
361
+ "burping: 0.0000\n",
362
+ "pink_collar: 0.0000\n",
363
+ "team_rocket: 0.0000\n",
364
+ "studded_bracelet: 0.0000\n"
365
+ ]
366
+ }
367
+ ],
368
+ "source": [
369
+ "#Reload and test file\n",
370
+ "\n",
371
+ "# Load the saved components from the joblib file\n",
372
+ "components = joblib.load('tf_idf_files_418_updated.joblib')\n",
373
+ "\n",
374
+ "# Extract necessary components\n",
375
+ "idf = components['idf']\n",
376
+ "term_to_column_index = components['tag_to_column_index']\n",
377
+ "row_to_tag = components['row_to_tag']\n",
378
+ "reduced_matrix = components['reduced_matrix']\n",
379
+ "svd = components['svd_model']\n",
380
+ "\n",
381
+ "# Construct the TF-IDF vector for \"domestic_dog\"\n",
382
+ "pseudo_tfidf_vector = construct_pseudo_vector(\"blue_(jurassic_world)\", idf, term_to_column_index)\n",
383
+ "\n",
384
+ "# Reduce the dimensionality of the pseudo-document vector for the reduced matrix\n",
385
+ "reduced_pseudo_vector = svd.transform(pseudo_tfidf_vector)\n",
386
+ "\n",
387
+ "# Compute cosine similarities in the reduced space\n",
388
+ "cosine_similarities_reduced = cosine_similarity(reduced_pseudo_vector, reduced_matrix).flatten()\n",
389
+ "\n",
390
+ "# Sort the indices by descending cosine similarity\n",
391
+ "top_indices_reduced = np.argsort(cosine_similarities_reduced)[::-1][:10]\n",
392
+ "\n",
393
+ "# Display the most similar tags in the reduced matrix with their scores\n",
394
+ "print(\"Most similar tags (Reduced Matrix):\")\n",
395
+ "for idx in top_indices_reduced:\n",
396
+ " tag = row_to_tag[idx]\n",
397
+ " score = cosine_similarities_reduced[idx]\n",
398
+ " print(f\"{tag}: {score:.4f}\")\n"
399
+ ]
400
+ },
401
+ {
402
+ "cell_type": "code",
403
+ "execution_count": null,
404
+ "id": "ddea5f32",
405
+ "metadata": {},
406
+ "outputs": [],
407
+ "source": []
408
+ },
409
+ {
410
+ "cell_type": "code",
411
+ "execution_count": null,
412
+ "id": "74897a5c",
413
+ "metadata": {},
414
+ "outputs": [],
415
+ "source": []
416
+ },
417
+ {
418
+ "cell_type": "code",
419
+ "execution_count": null,
420
+ "id": "c0c5b32d",
421
+ "metadata": {},
422
+ "outputs": [],
423
+ "source": []
424
+ },
425
+ {
426
+ "cell_type": "code",
427
+ "execution_count": null,
428
+ "id": "9ff9a331",
429
+ "metadata": {},
430
+ "outputs": [],
431
+ "source": []
432
+ },
433
+ {
434
+ "cell_type": "code",
435
+ "execution_count": null,
436
+ "id": "91c66b57",
437
+ "metadata": {},
438
+ "outputs": [],
439
+ "source": []
440
+ },
441
+ {
442
+ "cell_type": "code",
443
+ "execution_count": null,
444
+ "id": "a830c6cf",
445
+ "metadata": {},
446
+ "outputs": [],
447
+ "source": []
448
+ },
449
+ {
450
+ "cell_type": "code",
451
+ "execution_count": null,
452
+ "id": "4cdc98f0",
453
+ "metadata": {},
454
+ "outputs": [],
455
+ "source": []
456
+ },
457
+ {
458
+ "cell_type": "code",
459
+ "execution_count": null,
460
+ "id": "150d66f3",
461
+ "metadata": {},
462
+ "outputs": [],
463
+ "source": []
464
+ },
465
+ {
466
+ "cell_type": "code",
467
+ "execution_count": null,
468
+ "id": "337b1f65",
469
+ "metadata": {},
470
+ "outputs": [],
471
+ "source": []
472
+ },
473
+ {
474
+ "cell_type": "code",
475
+ "execution_count": null,
476
+ "id": "34d2fde1",
477
+ "metadata": {},
478
+ "outputs": [],
479
+ "source": []
480
+ },
481
+ {
482
+ "cell_type": "code",
483
+ "execution_count": null,
484
+ "id": "9fc197d8",
485
+ "metadata": {},
486
+ "outputs": [],
487
+ "source": []
488
+ },
489
+ {
490
+ "cell_type": "code",
491
+ "execution_count": null,
492
+ "id": "bfa9c299",
493
+ "metadata": {},
494
+ "outputs": [],
495
+ "source": []
496
+ },
497
+ {
498
+ "cell_type": "code",
499
+ "execution_count": null,
500
+ "id": "551a8453",
501
+ "metadata": {},
502
+ "outputs": [],
503
+ "source": []
504
+ },
505
+ {
506
+ "cell_type": "code",
507
+ "execution_count": null,
508
+ "id": "0dcdeb9e",
509
+ "metadata": {},
510
+ "outputs": [],
511
+ "source": []
512
+ },
513
+ {
514
+ "cell_type": "code",
515
+ "execution_count": null,
516
+ "id": "537c9e26",
517
+ "metadata": {},
518
+ "outputs": [],
519
+ "source": []
520
+ },
521
+ {
522
+ "cell_type": "code",
523
+ "execution_count": null,
524
+ "id": "aa873abf",
525
+ "metadata": {},
526
+ "outputs": [],
527
+ "source": []
528
+ },
529
+ {
530
+ "cell_type": "code",
531
+ "execution_count": null,
532
+ "id": "41aca76f",
533
+ "metadata": {},
534
+ "outputs": [],
535
+ "source": []
536
+ },
537
+ {
538
+ "cell_type": "code",
539
+ "execution_count": null,
540
+ "id": "36a3ae96",
541
+ "metadata": {},
542
+ "outputs": [],
543
+ "source": []
544
+ },
545
+ {
546
+ "cell_type": "code",
547
+ "execution_count": null,
548
+ "id": "fb59bac3",
549
+ "metadata": {},
550
+ "outputs": [],
551
+ "source": []
552
+ },
553
+ {
554
+ "cell_type": "code",
555
+ "execution_count": null,
556
+ "id": "39c87db9",
557
+ "metadata": {},
558
+ "outputs": [],
559
+ "source": []
560
+ },
561
+ {
562
+ "cell_type": "code",
563
+ "execution_count": null,
564
+ "id": "1646e731",
565
+ "metadata": {},
566
+ "outputs": [],
567
+ "source": []
568
+ },
569
+ {
570
+ "cell_type": "code",
571
+ "execution_count": null,
572
+ "id": "99f95d09",
573
+ "metadata": {},
574
+ "outputs": [],
575
+ "source": []
576
+ },
577
+ {
578
+ "cell_type": "code",
579
+ "execution_count": null,
580
+ "id": "9d6a67c2",
581
+ "metadata": {},
582
+ "outputs": [],
583
+ "source": []
584
+ },
585
+ {
586
+ "cell_type": "code",
587
+ "execution_count": null,
588
+ "id": "32acbfd7",
589
+ "metadata": {},
590
+ "outputs": [],
591
+ "source": []
592
+ },
593
+ {
594
+ "cell_type": "code",
595
+ "execution_count": null,
596
+ "id": "3c17cd42",
597
+ "metadata": {},
598
+ "outputs": [],
599
+ "source": []
600
+ },
601
+ {
602
+ "cell_type": "code",
603
+ "execution_count": null,
604
+ "id": "d333776c",
605
+ "metadata": {},
606
+ "outputs": [],
607
+ "source": []
608
+ },
609
+ {
610
+ "cell_type": "code",
611
+ "execution_count": null,
612
+ "id": "1e8c7511",
613
+ "metadata": {},
614
+ "outputs": [],
615
+ "source": []
616
+ },
617
+ {
618
+ "cell_type": "code",
619
+ "execution_count": null,
620
+ "id": "acf35591",
621
+ "metadata": {},
622
+ "outputs": [],
623
+ "source": []
624
+ },
625
+ {
626
+ "cell_type": "code",
627
+ "execution_count": null,
628
+ "id": "101fb083",
629
+ "metadata": {},
630
+ "outputs": [],
631
+ "source": []
632
+ },
633
+ {
634
+ "cell_type": "code",
635
+ "execution_count": null,
636
+ "id": "f8bd8551",
637
+ "metadata": {},
638
+ "outputs": [],
639
+ "source": []
640
+ },
641
+ {
642
+ "cell_type": "code",
643
+ "execution_count": null,
644
+ "id": "271b9c12",
645
+ "metadata": {},
646
+ "outputs": [],
647
+ "source": []
648
+ },
649
+ {
650
+ "cell_type": "code",
651
+ "execution_count": null,
652
+ "id": "a232e088",
653
+ "metadata": {},
654
+ "outputs": [],
655
+ "source": []
656
+ },
657
+ {
658
+ "cell_type": "code",
659
+ "execution_count": null,
660
+ "id": "43df0240",
661
+ "metadata": {},
662
+ "outputs": [],
663
+ "source": []
664
+ },
665
+ {
666
+ "cell_type": "code",
667
+ "execution_count": null,
668
+ "id": "8dbb05e8",
669
+ "metadata": {},
670
+ "outputs": [],
671
+ "source": [
672
+ "\n"
673
+ ]
674
+ },
675
+ {
676
+ "cell_type": "code",
677
+ "execution_count": null,
678
+ "id": "9730cb16",
679
+ "metadata": {},
680
+ "outputs": [],
681
+ "source": []
682
+ },
683
+ {
684
+ "cell_type": "code",
685
+ "execution_count": null,
686
+ "id": "d38f92b2",
687
+ "metadata": {},
688
+ "outputs": [],
689
+ "source": []
690
+ },
691
+ {
692
+ "cell_type": "code",
693
+ "execution_count": null,
694
+ "id": "879f5463",
695
+ "metadata": {},
696
+ "outputs": [],
697
+ "source": []
698
+ }
699
+ ],
700
+ "metadata": {
701
+ "kernelspec": {
702
+ "display_name": "Python 3 (ipykernel)",
703
+ "language": "python",
704
+ "name": "python3"
705
+ },
706
+ "language_info": {
707
+ "codemirror_mode": {
708
+ "name": "ipython",
709
+ "version": 3
710
+ },
711
+ "file_extension": ".py",
712
+ "mimetype": "text/x-python",
713
+ "name": "python",
714
+ "nbconvert_exporter": "python",
715
+ "pygments_lexer": "ipython3",
716
+ "version": "3.10.9"
717
+ }
718
+ },
719
+ "nbformat": 4,
720
+ "nbformat_minor": 5
721
+ }
psq_rag/__init__.py ADDED
File without changes
psq_rag/llm/__init__.py ADDED
File without changes
psq_rag/llm/openrouter_client.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, json
2
+ from typing import Any, Dict, List, Optional, Tuple
3
+ import httpx
4
+
5
+
6
+ OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "")
7
+ OPENROUTER_MODEL = os.environ.get("OPENROUTER_MODEL", "meta-llama/llama-3.1-8b-instruct")
8
+
9
+ def _extract_json_object(text: str) -> Optional[dict]:
10
+ """
11
+ Best-effort: find the first top-level JSON object in a response.
12
+ Works even if the model wraps JSON with prose or code fences.
13
+ """
14
+ if not text:
15
+ return None
16
+
17
+ # Strip common fences
18
+ t = text.strip()
19
+ t = t.removeprefix("```json").removeprefix("```").removesuffix("```").strip()
20
+
21
+ # Find first {...} span
22
+ start = t.find("{")
23
+ if start == -1:
24
+ return None
25
+
26
+ depth = 0
27
+ for i in range(start, len(t)):
28
+ if t[i] == "{":
29
+ depth += 1
30
+ elif t[i] == "}":
31
+ depth -= 1
32
+ if depth == 0:
33
+ chunk = t[start:i+1]
34
+ try:
35
+ return json.loads(chunk)
36
+ except Exception:
37
+ return None
38
+ return None
39
+
40
+
41
+
42
def openrouter_chat(
    messages: List[Dict[str, str]],
    response_format: Optional[Dict[str, Any]] = None,
    temperature: float = 0.2,
    max_tokens: int = 512,
    timeout_s: float = 30.0,
) -> Tuple[Optional[str], Optional[dict], Optional[str]]:
    """Send one chat-completion request to OpenRouter.

    Returns ``(raw_text, parsed_json, error_str)``. Any ``Exception`` raised
    while sending or decoding the request is caught and reported through
    ``error_str`` instead of propagating.

    Instrumented to help diagnose moderation/routing variance:
    - HTTP errors and OpenRouter error payloads are reported as
      ``"HTTP <status>: <message> (code=<code>)"`` rather than as an opaque
      ``KeyError`` on the missing ``choices`` field
    - filtered / refusal-like completions are flagged, with model, finish
      reason and token usage appended when the response provides them

    Args:
        messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
        response_format: Optional ``response_format`` request field; omitted
            from the payload when ``None``.
        temperature: Sampling temperature sent in the payload.
        max_tokens: Completion token limit sent in the payload.
        timeout_s: Timeout passed to ``httpx.Client``.

    Returns:
        ``(content, parsed, None)`` on success, ``(content, parsed, reason)``
        for filtered or refusal-like completions, and ``(None, None, reason)``
        when the request itself failed.
    """
    if not OPENROUTER_API_KEY:
        return None, None, "OPENROUTER_API_KEY missing"

    payload: Dict[str, Any] = {
        "model": OPENROUTER_MODEL,
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
    }
    if response_format is not None:
        payload["response_format"] = response_format

    headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "Content-Type": "application/json",
        "HTTP-Referer": "https://huggingface.co/spaces",
        "X-Title": "Prompt_Squirrel_RAG",
    }

    try:
        with httpx.Client(timeout=timeout_s) as client:
            r = client.post(
                "https://openrouter.ai/api/v1/chat/completions",
                headers=headers,
                json=payload,
            )

        status = r.status_code
        try:
            data = r.json()
        except ValueError:
            return None, None, f"HTTP {status}: response body is not valid JSON"

        if not isinstance(data, dict):
            return None, None, f"HTTP {status}: unexpected response payload"

        # Surface API-level failures explicitly; an error payload can be
        # present even when the HTTP status itself looks successful.
        api_error = data.get("error")
        if status >= 400 or api_error:
            if isinstance(api_error, dict):
                detail = str(api_error.get("message") or "unknown error")
                code = api_error.get("code")
                if code is not None:
                    detail = f"{detail} (code={code})"
            elif api_error:
                detail = str(api_error)
            else:
                detail = "no error details in response"
            return None, None, f"HTTP {status}: {detail}"

        choices = data.get("choices")
        if not choices:
            return None, None, f"HTTP {status}: response contained no choices"

        choice0 = choices[0]
        content = (choice0["message"].get("content", "") or "").strip()

        finish_reason = choice0.get("finish_reason")
        native_finish_reason = choice0.get("native_finish_reason")

        # (optional) expose these as part of error_str for logging
        meta = []
        if data.get("model"):
            meta.append(f"model={data['model']}")
        if finish_reason:
            meta.append(f"finish={finish_reason}")
        if native_finish_reason:
            meta.append(f"native_finish={native_finish_reason}")
        if isinstance(data.get("usage"), dict):
            u = data["usage"]
            if "prompt_tokens" in u and "completion_tokens" in u:
                meta.append(f"tokens={u['prompt_tokens']}+{u['completion_tokens']}")

        parsed = _extract_json_object(content)

        # If it looks filtered, flag it
        if finish_reason == "content_filter":
            return content, parsed, f"Filtered (content_filter; {'; '.join(meta)})"

        # If it looks refusal-like but not content_filter, still flag it
        if content.lower().startswith(("i can't", "i can’t", "i cannot", "can't", "cannot")):
            return content, parsed, f"Refusal-like ({'; '.join(meta)})"

        return content, parsed, None

    except Exception as e:
        # Deliberate catch-all boundary: callers rely on the error tuple.
        return None, None, f"{type(e).__name__}: {e}"
118
+
119
+
120
+ if __name__ == "__main__":
121
+ print("openrouter_client.py imports ok")
psq_rag/llm/rewrite.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .openrouter_client import openrouter_chat
2
+
3
+
4
+ REWRITE_SYSTEM = """Rewrite the input into a concise, comma-separated list of short phrases
5
+ that resemble image tags.
6
+
7
+ Use short, literal phrases that reflect how visual concepts are commonly
8
+ written in image tag vocabularies.
9
+
10
+ Multi-word phrases are appropriate when they represent one coherent
11
+ visual idea.
12
+
13
+ Examples of tag-shaped phrases:
14
+ - wolf, angry
15
+ - blue jacket, striped tail
16
+ - long hair, raised ears
17
+ - holding object, hand on shoulder
18
+ - looking at viewer, looking down
19
+ - simple background, outdoor scene
20
+ - wooden table, plant
21
+ - running, sleeping
22
+ - smiling, angry expression
23
+ - bedroom, forest
24
+ - sonic the hedgehog, princess peach
25
+
26
+ Do not invent details or guess identities.
27
+ Do not infer demographic attributes (e.g., gender/age) unless explicitly stated.
28
+
29
+ Output ONLY the rewritten list.
30
+ """
31
+
32
+
33
def llm_rewrite_prompt(prompt_in: str, log) -> str:
    """Rewrite a free-form prompt into a comma-separated, tag-like phrase list.

    Sends ``REWRITE_SYSTEM`` plus the user's prompt to ``openrouter_chat`` at
    temperature 0.0 with a 256-token limit.

    Args:
        prompt_in: The user's original prompt text.
        log: Callable taking one message string; receives progress lines.

    Returns:
        The rewritten text with whitespace runs collapsed to single spaces and
        capped at 800 characters, or ``""`` when the call reported an error or
        produced no text.
    """
    raw, _ignored_json, err = openrouter_chat(
        [
            {"role": "system", "content": REWRITE_SYSTEM},
            {"role": "user", "content": prompt_in},
        ],
        response_format=None,
        temperature=0.0,
        max_tokens=256,
    )

    if err:
        log(f"LLM rewrite: fallback (error: {err})")
        # For refusal-like completions, also log the refusal text itself.
        if raw and err.lower().startswith("refusal-like"):
            log(f"LLM rewrite refusal text: {raw.strip()[:300]}")
        return ""

    rewritten = (raw or "").strip()
    if not rewritten:
        log("LLM rewrite: fallback (empty response)")
        return ""

    # Collapse whitespace, then cap the length; slicing is a no-op for short
    # text and the collapsed text has no trailing space to strip.
    rewritten = " ".join(rewritten.split())[:800].rstrip()

    log("LLM rewrite: ok")
    return rewritten
64
+
65
+
66
+ if __name__ == "__main__":
67
+ print("rewrite.py imports ok")
psq_rag/llm/select.py ADDED
@@ -0,0 +1,711 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # psq_rag/llm/select.py
2
+ # Stage 3: Closed-Set Selection (LangChain-only implementation)
3
+ #
4
+ # This module intentionally uses LangChain for:
5
+ # - prompt templating (including {N})
6
+ # - LLM call orchestration
7
+ # - JSON parsing
8
+ #
9
+ # There is NO fallback path. If LangChain dependencies are missing, this module
10
+ # should fail loudly so you install them.
11
+
12
+ import os
13
+ import re
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast, Literal
16
+
17
+ from langchain_openai import ChatOpenAI
18
+ from langchain_core.prompts import ChatPromptTemplate
19
+ from langchain_core.output_parsers import PydanticOutputParser
20
+ from pydantic import BaseModel, Field, SecretStr
21
+ from rapidfuzz import fuzz
22
+
23
+ from psq_rag.retrieval.psq_retrieval import Candidate # Candidate(tag, score_*, count, sources)
24
+ from psq_rag.retrieval.state import get_tag_type_name, get_tag2aliases
25
+
26
+
27
+ WHY_ENUM = ["explicit", "strong_implied", "weak_implied", "style_or_meta", "other"]
28
+
29
+ # Deterministic mapping: ordinal "why" -> numeric score for ordering/debug.
30
+ WHY_TO_SCORE: Dict[str, float] = {
31
+ "explicit": 0.90,
32
+ "strong_implied": 0.70,
33
+ "weak_implied": 0.45,
34
+ "style_or_meta": 0.35,
35
+ "other": 0.25,
36
+ }
37
+
38
+
39
+ # IMPORTANT ABOUT TEMPLATING:
40
+ # - This string is rendered by LangChain's f-string template engine.
41
+ # - Literal JSON braces must be escaped as {{ and }}.
42
+ # - {N} is a real template variable and MUST be provided.
43
+ SELECT_SYSTEM_TEMPLATE = """You are given a description of an image and a list of imageboard tags.
44
+
45
+ Select the tags that correspond to content that would be visible or depicted in the described image.
46
+
47
+ The list contains only valid tags; many of them are irrelevant to the image.
48
+
49
+ Return JSON ONLY matching this schema:
50
+
51
+ {{
52
+ \"selections\": [
53
+ {{\"i\": <int>, \"why\": \"<one of: explicit|strong_implied|weak_implied|style_or_meta|other>\"}},
54
+ ...
55
+ ]
56
+ }}
57
+
58
+ Rules:
59
+ - Choose ONLY from indices 1..{N}.
60
+ - Do NOT output tag text.
61
+ - Do NOT output any keys other than \"selections\", and inside each item only the item index \"i\" and \"why\".
62
+ - Do select both a general tag and a more specific tag when both apply (for example, \"shirt\" and \"grey shirt\").
63
+
64
+ Define \"why\" as:
65
+ - explicit: directly stated in the image description
66
+ - strong_implied: very likely given the description, even if not literally stated
67
+ - weak_implied: plausible but not strongly supported by the description
68
+ - style_or_meta: stylistic or presentation-related tags only if clearly indicated
69
+ - other: fallback category; use sparingly
70
+ """
71
+
72
+
73
+ ENTITY_SYSTEM_TEMPLATE = """You are given a description of an image and a list of CHARACTER tags.
74
+
75
+ These character tags have already been pre-filtered to only include characters whose names
76
+ (or known aliases) appear in the image description. Your job is to confirm which of these
77
+ pre-filtered candidates are the correct match for the character mentioned by the user.
78
+
79
+ Return JSON ONLY matching this schema:
80
+
81
+ {{
82
+ \"selections\": [
83
+ {{\"i\": <int>, \"why\": \"explicit\"}},
84
+ ...
85
+ ]
86
+ }}
87
+
88
+ Rules for character selection:
89
+ - Choose ONLY from indices 1..{N}.
90
+ - Do NOT output tag text.
91
+ - Always use \"why\": \"explicit\" for all selections.
92
+ - Select the tag that best represents the character as described.
93
+ - If the user described a specific variant (e.g. \"pikachu libre\", \"detective pikachu\"),
94
+ select that specific variant tag.
95
+ - If the user described only the base character (e.g. just \"pikachu\"), select only the
96
+ base/default tag, NOT costume or variant tags.
97
+ - When uncertain between variants, prefer the simplest/most general tag.
98
+ """
99
+
100
+
101
+ USER_TEMPLATE = """IMAGE DESCRIPTION:
102
+ {image_description}
103
+
104
+ CANDIDATES (choose by index only):
105
+ {candidate_lines}
106
+
107
+ Select up to {per_call_budget} indices. Output fewer if uncertain.
108
+ """
109
+
110
+
111
@dataclass(frozen=True)
class Selected:
    """One candidate accepted by Stage 3 selection (immutable record)."""

    i: int  # index the model picked; a key of the per-call index -> tag map
    tag: str  # canonical tag (underscore form)
    why: str  # rationale code; _parse_validate_map only stores WHY_ENUM values
    score: float  # _parse_validate_map fills this from WHY_TO_SCORE[why]
117
+
118
+
119
+ WhyLiteral = Literal["explicit", "strong_implied", "weak_implied", "style_or_meta", "other"]
120
+
121
+
122
# Schema for a single entry of the model's "selections" array.
# Both fields are declared with `...`, i.e. without a default value.
class Stage3SelectionItem(BaseModel):
    i: int = Field(..., description="1-based index into the candidate list.")
    why: WhyLiteral = Field(..., description="Rationale code from the allowed set.")
125
+
126
+
127
# Top-level shape of the Stage 3 reply: {"selections": [...]}.
# default_factory=list gives each instance its own empty list when the key is absent.
class Stage3SelectionResponse(BaseModel):
    selections: List[Stage3SelectionItem] = Field(default_factory=list)
129
+
130
+
131
def _build_response_format() -> Dict[str, Any]:
    """Build the ``response_format`` payload requesting strict JSON-schema output.

    The schema describes an object with a single required ``selections``
    array; each element is an object with exactly an integer ``i`` and a
    ``why`` string restricted to ``WHY_ENUM``.
    """
    selection_item = {
        "type": "object",
        "properties": {
            "i": {"type": "integer"},
            "why": {"type": "string", "enum": WHY_ENUM},
        },
        "required": ["i", "why"],
        "additionalProperties": False,
    }

    root_schema = {
        "type": "object",
        "properties": {
            "selections": {
                "type": "array",
                "items": selection_item,
            }
        },
        "required": ["selections"],
        "additionalProperties": False,
    }

    json_schema = {
        "name": "stage3_selection",
        "strict": True,
        "schema": root_schema,
    }
    return {"type": "json_schema", "json_schema": json_schema}
161
+
162
+
163
def _get_llm(*, temperature: float, max_tokens: int, response_format: Dict[str, Any]) -> ChatOpenAI:
    """Create a ``ChatOpenAI`` client pointed at OpenRouter's OpenAI-compatible API.

    Reads ``OPENROUTER_API_KEY`` (required), ``OPENROUTER_MODEL`` (optional,
    with a default) and the optional ``OPENROUTER_HTTP_REFERER`` /
    ``OPENROUTER_X_TITLE`` header values from the environment.

    Args:
        temperature: Sampling temperature for the model.
        max_tokens: Passed as ``max_completion_tokens``.
        response_format: Forwarded in the request body via ``extra_body``.

    Raises:
        RuntimeError: If ``OPENROUTER_API_KEY`` is unset or empty.
    """
    raw_key = os.getenv("OPENROUTER_API_KEY")
    if not raw_key:
        raise RuntimeError(
            "OPENROUTER_API_KEY is not set.\n"
            "Set it in your environment before running Stage 3."
        )

    model_name = os.getenv("OPENROUTER_MODEL", "meta-llama/llama-3.1-8b-instruct")

    # Optional headers: only included when the matching variable is non-empty.
    optional_headers = (
        ("HTTP-Referer", os.getenv("OPENROUTER_HTTP_REFERER")),
        ("X-Title", os.getenv("OPENROUTER_X_TITLE")),
    )
    headers: Dict[str, str] = {name: value for name, value in optional_headers if value}

    # Extra request-body fields: the structured-output schema plus the
    # Response Healing plugin, which reduces malformed-JSON failures (syntax only).
    request_extras = {
        "response_format": response_format,
        "plugins": [{"id": "response-healing"}],
    }

    # OpenRouter OpenAI-compatible endpoint.
    return ChatOpenAI(
        model=model_name,
        base_url="https://openrouter.ai/api/v1",
        api_key=SecretStr(cast(str, raw_key)),
        temperature=temperature,
        max_completion_tokens=max_tokens,
        default_headers=headers,
        extra_body=request_extras,
    )
194
+
195
+
196
def _phrase_key_for_candidate(c: Candidate) -> str:
    """Return the candidate's deterministic grouping phrase.

    This is the smallest entry of ``c.sources`` in sort order, or ``""`` when
    the candidate has no sources.
    """
    return min(c.sources) if c.sources else ""
201
+
202
+
203
def _interleave_round_robin(cands: Sequence[Candidate]) -> List[Candidate]:
    """Interleave candidates round-robin across their primary source phrases.

    Candidates are bucketed by ``_phrase_key_for_candidate``. Each bucket is
    ordered best-first by ``(score_combined, count)``, with a missing or zero
    count treated as -1. The output takes the top item of every bucket (in
    sorted phrase order), then every bucket's second item, and so on.

    NOTE: counts are used only for ordering; they are NOT shown to the LLM.
    """
    buckets: Dict[str, List[Candidate]] = {}
    for cand in cands:
        buckets.setdefault(_phrase_key_for_candidate(cand), []).append(cand)

    def _rank(cand: Candidate):
        return (cand.score_combined, (cand.count or -1))

    for bucket in buckets.values():
        bucket.sort(key=_rank, reverse=True)

    phrase_order = sorted(buckets)
    deepest = max((len(bucket) for bucket in buckets.values()), default=0)

    interleaved: List[Candidate] = []
    for rank in range(deepest):
        for phrase in phrase_order:
            bucket = buckets[phrase]
            if rank < len(bucket):
                interleaved.append(bucket[rank])
    return interleaved
231
+
232
+
233
+ def _display_tag(tag: str) -> str:
234
+ # Display tags with spaces for the LLM, but keep canonical underscores internally.
235
+ return tag.replace("_", " ")
236
+
237
+
238
def _format_candidates_local(
    cands: Sequence[Candidate],
) -> Tuple[str, Dict[int, str], Dict[int, Candidate]]:
    """Number the candidates from 1 and render them as prompt lines.

    Returns:
        ``(listing, idx_to_tag, idx_to_candidate)`` where ``listing`` holds one
        ``"<index>. <display tag>"`` line per candidate joined by newlines, and
        the two dicts map each 1-based index to the canonical tag and to the
        candidate object.
    """
    idx_to_candidate: Dict[int, Candidate] = dict(enumerate(cands, start=1))
    idx_to_tag: Dict[int, str] = {
        number: cand.tag for number, cand in idx_to_candidate.items()
    }
    listing = "\n".join(
        f"{number}. {_display_tag(cand.tag)}"
        for number, cand in idx_to_candidate.items()
    )
    return listing, idx_to_tag, idx_to_candidate
249
+
250
+
251
def _phrases_in_call(cands: Sequence[Candidate]) -> int:
    """Count the distinct source phrases across all given candidates."""
    distinct_sources = {src for cand in cands for src in cand.sources}
    return len(distinct_sources)
257
+
258
+
259
def _parse_validate_map(
    parsed: Any,
    idx_to_tag: Dict[int, str],
    per_call_budget: int,
) -> Tuple[List[Selected], Dict[str, Any]]:
    """Validate a parsed model reply and map its indices back to tags.

    Args:
        parsed: The parsed reply; a dict or a pydantic ``BaseModel`` (which is
            dumped to a dict first). Anything else yields no selections.
        idx_to_tag: Valid indices for this call and the tag each refers to.
        per_call_budget: Maximum number of selections to keep.

    Returns:
        ``(selected, diag)``. ``diag`` counts malformed items
        (``invalid_items``), indices missing from ``idx_to_tag``
        (``oob_indices``), repeats of an already kept index
        (``dupe_indices``), kept items (``kept``), and records whether the
        payload had the expected shape (``parse_ok``).
    """
    diag = {
        "parse_ok": isinstance(parsed, dict),
        "invalid_items": 0,
        "oob_indices": 0,
        "dupe_indices": 0,
        "kept": 0,
    }

    if isinstance(parsed, BaseModel):
        # Support both the newer `model_dump` and the older `dict` method.
        parsed = parsed.model_dump() if hasattr(parsed, "model_dump") else parsed.dict()
        diag["parse_ok"] = isinstance(parsed, dict)

    if not isinstance(parsed, dict):
        return [], diag

    raw_items = parsed.get("selections", [])
    if not isinstance(raw_items, list):
        diag["parse_ok"] = False
        return [], diag

    accepted: List[Selected] = []
    used_indices = set()

    for entry in raw_items:
        if len(accepted) >= per_call_budget:
            break

        if not isinstance(entry, dict):
            diag["invalid_items"] += 1
            continue

        idx = entry.get("i")
        reason = entry.get("why")

        # bool is a subclass of int, so it has to be rejected explicitly.
        if isinstance(idx, bool) or not isinstance(idx, int):
            diag["invalid_items"] += 1
        elif idx in used_indices:
            diag["dupe_indices"] += 1
        elif idx not in idx_to_tag:
            diag["oob_indices"] += 1
        elif not isinstance(reason, str) or reason not in WHY_ENUM:
            diag["invalid_items"] += 1
        else:
            used_indices.add(idx)
            accepted.append(
                Selected(i=idx, tag=idx_to_tag[idx], why=reason, score=WHY_TO_SCORE[reason])
            )

    diag["kept"] = len(accepted)
    return accepted, diag
315
+
316
+
317
def _split_candidates_by_type(
    candidates: List[Candidate],
    log,
) -> Tuple[List[Tuple[int, Candidate]], List[Tuple[int, Candidate]]]:
    """Partition candidates into general and entity (character) lists.

    The type name comes from ``get_tag_type_name(cand.tag)``:
    - ``"character"`` goes to the entity list
    - ``"copyright"`` is dropped (series tags are too broad for image generation)
    - everything else goes to the general list; names other than
      ``general`` / ``artist`` / ``species`` / ``meta`` are also counted as
      unknown

    Args:
        candidates: Candidates to split.
        log: Optional callable taking one message string; skipped when falsy.

    Returns:
        ``(general_list, entity_list)`` where each item is
        ``(original_index, candidate)``.
    """
    known_general_types = ("general", "artist", "species", "meta")

    general_with_idx: List[Tuple[int, Candidate]] = []
    entity_with_idx: List[Tuple[int, Candidate]] = []
    dropped_copyright = 0
    untyped = 0

    for position, cand in enumerate(candidates):
        kind = get_tag_type_name(cand.tag)
        entry = (position, cand)

        if kind == "character":
            entity_with_idx.append(entry)
        elif kind == "copyright":
            dropped_copyright += 1
        else:
            general_with_idx.append(entry)
            if kind not in known_general_types:
                # Unknown or None: kept as general, but counted for the log.
                untyped += 1

    if log:
        log(
            f"Stage3 split: "
            f"general={len(general_with_idx)} "
            f"entity={len(entity_with_idx)} "
            f"copyright_filtered={dropped_copyright} "
            f"unknown_type={untyped}"
        )

    return general_with_idx, entity_with_idx
362
+
363
+
364
+ # Regex to strip series/franchise suffixes from aliases, e.g. _(sonic), _(mlp), _(character)
365
+ _SERIES_SUFFIX_RE = re.compile(r"_\([^)]+\)$")
366
+
367
+
368
+ def _normalize_for_matching(text: str) -> str:
369
+ """Lowercase, replace underscores with spaces, strip series suffixes."""
370
+ text = text.lower().strip()
371
+ text = _SERIES_SUFFIX_RE.sub("", text)
372
+ text = text.replace("_", " ")
373
+ return text
374
+
375
+
376
def _query_words(query: str) -> Set[str]:
    """Return the set of whitespace-separated words of the normalized query."""
    normalized = _normalize_for_matching(query)
    words = normalized.split()
    return set(words)
379
+
380
+
381
+ def _alias_matches_query(alias_norm: str, query_words: Set[str], query_norm: str,
382
+ fuzzy_threshold: int = 85) -> bool:
383
+ """Check if an alias matches the user query.
384
+
385
+ Matching logic:
386
+ 1. Exact substring: alias appears as a substring of the query
387
+ 2. Word subset: all words in the alias appear in the query words
388
+ 3. Fuzzy: alias is close to a word in the query (handles typos)
389
+ """
390
+ # Exact substring match
391
+ if alias_norm in query_norm:
392
+ return True
393
+
394
+ alias_words = alias_norm.split()
395
+ if not alias_words:
396
+ return False
397
+
398
+ # Word subset match: all alias words must appear in query
399
+ if all(w in query_words for w in alias_words):
400
+ return True
401
+
402
+ # For single-word aliases, try fuzzy matching against each query word
403
+ if len(alias_words) == 1:
404
+ for qw in query_words:
405
+ if fuzz.ratio(alias_words[0], qw) >= fuzzy_threshold:
406
+ return True
407
+
408
+ # For multi-word aliases, try fuzzy partial ratio against whole query
409
+ if len(alias_words) > 1:
410
+ if fuzz.partial_ratio(alias_norm, query_norm) >= fuzzy_threshold:
411
+ return True
412
+
413
+ return False
414
+
415
+
416
def _character_matches_via_aliases(
    tag: str,
    query: str,
    tag2aliases: Dict[str, List[str]],
    query_words: Set[str],
    query_norm: str,
    fuzzy_threshold: int = 85,
) -> bool:
    """Decide whether a character tag is referred to by the user query.

    The tag matches when its own normalized name matches, or when at least one
    alias registered for it in ``tag2aliases`` matches. Aliases that normalize
    to an empty string are ignored. A tag with no registered aliases is still
    checked by name.

    Args:
        tag: Canonical character tag.
        query: Original query text (not read here; matching uses the
            pre-normalized ``query_words`` / ``query_norm``).
        tag2aliases: Mapping from tag to its alias strings.
        query_words: Words of the normalized query.
        query_norm: The whole normalized query.
        fuzzy_threshold: Forwarded to ``_alias_matches_query``.
    """

    def _hits(name_norm: str) -> bool:
        return _alias_matches_query(name_norm, query_words, query_norm, fuzzy_threshold)

    # The tag's own name comes first.
    if _hits(_normalize_for_matching(tag)):
        return True

    # Then each registered alias, lazily, stopping at the first hit.
    normalized_aliases = (
        _normalize_for_matching(alias) for alias in tag2aliases.get(tag, [])
    )
    return any(_hits(alias_norm) for alias_norm in normalized_aliases if alias_norm)
447
+
448
+
449
def llm_select_indices(
    query_text: str,  # kept for compatibility; treated as IMAGE DESCRIPTION
    candidates: Union[
        Sequence[Candidate],
        Sequence[str],
        Sequence[Tuple[str, float]],
    ],
    max_pick: int,  # legacy param; applied after union + ordering (optional)
    log,
    retries: int = 2,
    *,
    mode: str = "chunked_map_union",  # "single_shot" or "chunked_map_union"
    chunk_size: int = 60,
    per_phrase_k: int = 2,  # per-call budget = per_phrase_k * phrases_in_call
    temperature: float = 0.0,
    max_tokens: int = 512,
) -> List[int]:
    """Return indices into the ORIGINAL candidates list (legacy interface).

    This implementation uses LangChain ONLY.

    NOTE: query_text is treated as the image description (original prompt).

    Flow:
      1. Normalize ``candidates`` to ``Candidate`` objects, remembering the
         first index at which each tag appears.
      2. Split them into "general" and "entity" groups.
      3. Send the general group to the LLM (one call, or one call per chunk).
      4. Drop entity candidates whose tag/aliases do not match the image
         description, then send the survivors to the LLM the same way.
      5. Union all selections (best score per tag), order them, apply the
         ``max_pick`` cap and map tags back to original indices.

    Args:
        query_text: Image description handed to the LLM and to the alias filter.
        candidates: ``Candidate`` objects, bare tag strings, or ``(tag, score)``
            rows. The form is decided from the first element only.
        max_pick: Cap on the number of returned indices; ignored unless it is
            a positive int.
        log: Callable taking one message string, or a falsy value to disable
            logging.
        retries: Extra attempts per LLM call (``retries + 1`` attempts total).
        mode: ``"single_shot"`` (one call per group) or ``"chunked_map_union"``
            (one call per ``chunk_size`` slice of a group).
        chunk_size: Slice length used in chunked mode.
        per_phrase_k: Multiplier for the per-call selection budget.
        temperature: Forwarded to ``_get_llm``.
        max_tokens: Forwarded to ``_get_llm``.

    Returns:
        Indices into ``candidates``, ordered by selection score descending and
        then by tag count descending.

    Raises:
        ValueError: If a tuple-form row is not a list/tuple of length >= 2, or
            if ``mode`` is not one of the two supported values.
    """

    image_description = query_text

    # Normalize candidates:
    # - preferred: List[Candidate]
    # - legacy: List[(tag, sim)] (count/sources unavailable)
    norm: List[Candidate] = []
    tag_to_first_index: Dict[str, int] = {}

    # `branch` / `cand0_type` exist only for the diagnostic log line below.
    branch = "empty"
    cand0_type = type(candidates[0]).__name__ if candidates else "none"

    if candidates and isinstance(candidates[0], Candidate):
        branch = "candidate"
        typed_candidates = cast(Sequence[Candidate], candidates)
        for idx, c in enumerate(typed_candidates):
            # Only the first occurrence of a tag is kept.
            if c.tag not in tag_to_first_index:
                tag_to_first_index[c.tag] = idx
                norm.append(c)
    elif candidates and isinstance(candidates[0], str):
        branch = "string"
        typed_candidates = cast(Sequence[str], candidates)
        for idx, tag in enumerate(typed_candidates):
            if tag not in tag_to_first_index:
                tag_to_first_index[tag] = idx
                # Bare strings carry no scores, count or sources.
                norm.append(
                    Candidate(
                        tag=tag,
                        score_combined=0.0,
                        score_fasttext=None,
                        score_context=None,
                        count=None,
                        sources=[],
                    )
                )
    else:
        if candidates:
            branch = "tuple"
        typed_candidates = cast(Sequence[Tuple[str, float]], candidates)
        for idx, row in enumerate(typed_candidates):
            if not isinstance(row, (list, tuple)) or len(row) < 2:
                raise ValueError("Stage 3 candidates must be Candidate, tag strings, or (tag, score) tuples.")
            tag, sim = row[0], row[1]
            if tag not in tag_to_first_index:
                tag_to_first_index[tag] = idx
                norm.append(
                    Candidate(
                        tag=tag,
                        score_combined=float(sim),
                        score_fasttext=None,
                        score_context=None,
                        count=None,
                        sources=[],
                    )
                )

    if log:
        if norm:
            log(
                "Stage3 input: "
                f"type0={cand0_type} "
                f"branch={branch} "
                f"norm0_score={norm[0].score_combined!r} "
                f"norm0_sources_empty={not bool(norm[0].sources)}"
            )
        else:
            log(f"Stage3 input: type0={cand0_type} branch={branch} (no candidates)")

    if mode not in ("single_shot", "chunked_map_union"):
        raise ValueError(f"Invalid mode: {mode}")

    response_format = _build_response_format()
    llm = _get_llm(temperature=temperature, max_tokens=max_tokens, response_format=response_format)
    # Read here only so the model name can be included in log messages.
    model_name = os.getenv("OPENROUTER_MODEL", "meta-llama/llama-3.1-8b-instruct")

    parser = PydanticOutputParser(pydantic_object=Stage3SelectionResponse)

    # Global union: tag -> best (score, why)
    best: Dict[str, Tuple[float, str]] = {}

    def run_call(call_cands: Sequence[Candidate], label: str, system_template: str) -> None:
        """Run one LLM selection call over ``call_cands`` and merge hits into ``best``.

        Makes up to ``retries + 1`` attempts and stops at the first attempt
        that yields a non-empty selection. An exception raised during an
        attempt is logged (when ``log`` is set) and the next attempt is made.
        """
        # Create chain with the provided system template
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system_template),
                ("human", USER_TEMPLATE),
            ],
            template_format="f-string",
        )
        chain = prompt | llm | parser

        ordered = _interleave_round_robin(call_cands)
        candidate_lines, idx_to_tag, idx_to_candidate = _format_candidates_local(ordered)
        N_local = len(idx_to_tag)

        phrases = _phrases_in_call(call_cands)
        # Budget scales with the phrase count; without phrases it is per_phrase_k as given.
        per_call_budget = max(1, per_phrase_k * phrases) if phrases > 0 else per_phrase_k
        summary_logged = False

        if log:
            log(f"Stage3 {label}: candidates (local indices):\n{candidate_lines}")
            if phrases > 0:
                distinct_phrases = sorted({src for c in call_cands for src in c.sources})
                log(
                    f"Stage3 {label}: distinct_phrases={len(distinct_phrases)} "
                    f"phrases={', '.join(distinct_phrases)}"
                )

        # Invoke LangChain chain (templating fills {N} and other vars)
        for att in range(retries + 1):
            try:
                if log:
                    log(
                        f"Stage3 {label}: "
                        f"model={model_name} "
                        f"N={N_local} "
                        f"phrases={phrases} "
                        f"per_call_budget={per_call_budget} "
                        f"response_healing=on"
                    )

                parsed = chain.invoke(
                    {
                        "N": N_local,
                        "image_description": image_description,
                        "candidate_lines": candidate_lines,
                        "per_call_budget": per_call_budget,
                    }
                )
                selected, diag = _parse_validate_map(parsed, idx_to_tag, per_call_budget=per_call_budget)
                if log:
                    log(f"Stage3 {label}: attempt {att+1} diag={diag}")
                    # Emit the summary once: on the first successful attempt,
                    # or on the final attempt if none succeeded.
                    if not summary_logged and (selected or att == retries):
                        log(
                            f"Stage3 {label}: summary "
                            f"N={N_local} selected={len(selected)} per_call_budget={per_call_budget}"
                        )
                        summary_logged = True
                    if selected:
                        lines = [
                            f"Stage3 {label} selections:",
                            *[
                                (
                                    f' - i={s.i} tag="{s.tag}" '
                                    f"why={s.why} score={s.score:.2f} "
                                    f"sources={idx_to_candidate.get(s.i).sources if idx_to_candidate.get(s.i) else []}"
                                )
                                for s in selected
                            ],
                        ]
                        log("\n".join(lines))
                    else:
                        log(f"Stage3 {label} selections: (none)")

                if selected:
                    # Keep the highest score seen for each tag across all calls.
                    for s in selected:
                        prev = best.get(s.tag)
                        if prev is None or s.score > prev[0]:
                            best[s.tag] = (s.score, s.why)
                    return

            except Exception as e:
                # Best effort: a failed attempt is logged and retried, never raised.
                if log:
                    log(f"Stage3 {label}: attempt {att+1} error: {e}")

        if log:
            log(f"Stage3 {label}: gave up after {retries+1} attempts")

    # Split candidates by type (general vs entity)
    general_with_idx, entity_with_idx = _split_candidates_by_type(norm, log)

    # Extract just the candidates for LLM calls
    general_cands = [cand for _, cand in general_with_idx]
    entity_cands = [cand for _, cand in entity_with_idx]

    # Process general candidates (attributes, actions, species, etc.)
    if general_cands:
        if mode == "single_shot":
            run_call(general_cands, "general_single_shot", SELECT_SYSTEM_TEMPLATE)
        else:
            for start in range(0, len(general_cands), chunk_size):
                run_call(
                    general_cands[start:start + chunk_size],
                    f"general_chunk_{start//chunk_size}",
                    SELECT_SYSTEM_TEMPLATE
                )

    # Process entity candidates (characters only) with alias-based pre-filtering
    if entity_cands:
        tag2aliases = get_tag2aliases()
        # Computed once and reused for every candidate in the loop below.
        qwords = _query_words(image_description)
        qnorm = _normalize_for_matching(image_description)

        filtered_entity_cands: List[Candidate] = []
        filtered_out: List[str] = []

        for cand in entity_cands:
            if _character_matches_via_aliases(
                cand.tag, image_description, tag2aliases, qwords, qnorm
            ):
                filtered_entity_cands.append(cand)
            else:
                filtered_out.append(cand.tag)

        if log:
            log(
                f"Stage3 entity alias filter: "
                f"before={len(entity_cands)} "
                f"after={len(filtered_entity_cands)} "
                f"removed={len(filtered_out)}"
            )
            if filtered_out:
                # Only the first 20 removed tags are listed in the log.
                log(f"Stage3 entity alias filter removed: {filtered_out[:20]}")

        if filtered_entity_cands:
            if mode == "single_shot":
                run_call(filtered_entity_cands, "entity_single_shot", ENTITY_SYSTEM_TEMPLATE)
            else:
                for start in range(0, len(filtered_entity_cands), chunk_size):
                    run_call(
                        filtered_entity_cands[start:start + chunk_size],
                        f"entity_chunk_{start//chunk_size}",
                        ENTITY_SYSTEM_TEMPLATE
                    )

    # Deterministic ordering: derived score desc, tie-break by count desc (count not shown to LLM).
    # A missing count is treated as -1 so it sorts after any real count.
    count_by_tag = {c.tag: (c.count if c.count is not None else -1) for c in norm}
    ordered_tags = sorted(best.keys(), key=lambda t: (best[t][0], count_by_tag.get(t, -1)), reverse=True)

    # Legacy cap: apply AFTER union + ordering.
    if isinstance(max_pick, int) and max_pick > 0:
        ordered_tags = ordered_tags[:max_pick]

    # Map back to original indices
    out_idx: List[int] = []
    for t in ordered_tags:
        # Tags that are not among the input candidates are dropped.
        if t in tag_to_first_index:
            out_idx.append(tag_to_first_index[t])

    return out_idx
psq_rag/parsing/__init__.py ADDED
File without changes
psq_rag/parsing/prompt_grammar.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from lark import Lark, Token
3
+
4
+
5
# Parser
# Lark grammar for a prompt string:
#   - `plain` is a run of text that stops at an unescaped comma, bracket,
#     parenthesis, colon or pipe; a backslash escapes the following character.
#   - `emphasized` is a parenthesized prompt, optionally followed by
#     ":" and a signed number before the closing parenthesis.
#   - `start` additionally accepts stray bracket/paren/colon characters so
#     that unbalanced input does not fail to parse.
# The leading "!" on `start` and `emphasized` keeps their literal tokens
# in the resulting tree.
grammar=r"""
!start: (prompt | /[][():]/+)*
prompt: (emphasized | plain | comma | WHITESPACE)*
!emphasized: "(" prompt ")"
| "(" prompt ":" [WHITESPACE] NUMBER [WHITESPACE] ")"
comma: ","
WHITESPACE: /\s+/
plain: /([^,\\\[\]():|]|\\.)+/
%import common.SIGNED_NUMBER -> NUMBER
"""

# Initialize the parser (built once at import time and shared by callers)
parser = Lark(grammar, start='start')
19
+
20
+ # Function to extract tags
21
def extract_tags(tree):
    """Collect every tag token from a parse tree, in document order.

    Walks ``tree`` depth-first and records each token whose type is
    ``'__ANON_1'``.
    NOTE(review): ``'__ANON_1'`` is presumably the name Lark auto-assigns to
    the anonymous regex of the ``plain`` rule; it depends on the order of
    anonymous terminals in the grammar -- confirm when editing the grammar.

    Args:
        tree: Parse tree produced by the module-level ``parser``.

    Returns:
        A list of ``(tag_text, start_pos, "tag")`` tuples.
    """
    tags_with_positions = []

    def _traverse(node):
        if isinstance(node, Token):
            if node.type == '__ANON_1':
                tags_with_positions.append((node.value, node.start_pos, "tag"))
            return
        # Optional grammar items (`[WHITESPACE]`) that did not match can show
        # up as `None` children; anything without a `children` attribute is
        # skipped instead of raising AttributeError.
        for child in getattr(node, "children", ()):
            _traverse(child)

    _traverse(tree)
    return tags_with_positions
33
+
34
+
35
+
36
def build_tag_offsets_dicts(new_image_tags_with_positions):
    """Build one descriptive dict per extracted tag.

    Args:
        new_image_tags_with_positions: Iterable of
            ``(tag_text, start_pos, node_type)`` tuples.

    Returns:
        A list of dicts, one per input tuple and in the same order, with keys:
          - ``original_tag``: the tag text exactly as given.
          - ``start_pos`` / ``end_pos``: span of the original text
            (``end_pos = start_pos + len(tag_text)``).
          - ``modified_tag``: underscores turned into spaces, ``\\(`` / ``\\)``
            unescaped, surrounding whitespace stripped.
          - ``artist_matrix_tag``: underscores turned into spaces and stripped;
            escaped parentheses are left escaped.
          - ``tf_idf_matrix_tag``: stripped, a leading ``by `` / ``by_`` removed,
            spaces turned into underscores, parentheses unescaped.
          - ``node_type``: the node type passed in.
    """
    # Structure the data for HighlightedText
    tag_data = []
    for tag_text, start_pos, nodetype in new_image_tags_with_positions:
        spaced = tag_text.replace('_', ' ')
        modified_tag = spaced.replace('\\(', '(').replace('\\)', ')').strip()
        # Escaped parentheses stay as they are here. The previous code replaced
        # "\(" with "\(" (a no-op) using invalid escape sequences, which newer
        # Python versions warn about; the result is unchanged.
        artist_matrix_tag = spaced.strip()
        without_prefix = tag_text.strip().removeprefix('by ').removeprefix('by_')
        tf_idf_matrix_tag = re.sub(r'\\([()])', r'\1', without_prefix.replace(' ', '_'))
        # The end position is based on the original (unmodified) tag length.
        end_pos = start_pos + len(tag_text)
        tag_data.append({
            "original_tag": tag_text,
            "start_pos": start_pos,
            "end_pos": end_pos,
            "modified_tag": modified_tag,
            "artist_matrix_tag": artist_matrix_tag,
            "tf_idf_matrix_tag": tf_idf_matrix_tag,
            "node_type": nodetype
        })
    return tag_data
57
+
58
+
59
+ if __name__ == "__main__":
60
+ print("prompt_grammar.py imports ok")
psq_rag/pipeline/__init__.py ADDED
File without changes
psq_rag/pipeline/preproc.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
def extract_user_provided_tags_upto_3_words(prompt_in: str) -> list[str]:
    """
    Pull short, tag-like fragments out of a free-text prompt.

    The prompt is cut at every run of '.' and ',' characters. Each piece is
    stripped; pieces that are empty or longer than three whitespace-separated
    words are discarded. Duplicates are detected case-insensitively and the
    first spelling encountered is the one returned. Order of first appearance
    is preserved.
    """
    if not prompt_in:
        return []

    first_spelling: dict = {}
    for piece in re.split(r"[.,]+", prompt_in):
        phrase = piece.strip()
        if not phrase or len(phrase.split()) > 3:
            continue
        # setdefault keeps the earliest spelling for a given lowercase key.
        first_spelling.setdefault(phrase.lower(), phrase)

    return list(first_spelling.values())
32
+
33
+
34
+ if __name__ == "__main__":
35
+ print("preproc.py imports ok")
36
+
psq_rag/retrieval/__init__.py ADDED
File without changes
psq_rag/retrieval/psq_retrieval.py ADDED
@@ -0,0 +1,500 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import logging
5
+ import math
6
+ import os
7
+ import pathlib
8
+ import re
9
+ from collections import Counter, OrderedDict
10
+ from dataclasses import dataclass
11
+ from itertools import islice
12
+ from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union
13
+
14
+ import numpy as np
15
+ import joblib
16
+ from scipy.sparse import csr_matrix
17
+
18
+ from .state import (
19
+ get_fasttext_model,
20
+ get_tag_counts,
21
+ get_hnsw_artist_index,
22
+ get_hnsw_tag_index,
23
+ get_nsfw_tags,
24
+ get_tfidf_components,
25
+ get_tfidf_tag_vectors,
26
+ get_alias2tags,
27
+ )
28
+
29
@dataclass(frozen=True)
class Candidate:
    """One retrieved tag together with the scores that produced it.

    The dataclass is frozen, so fields cannot be rebound after construction.
    ``sources`` is nevertheless a plain list and is appended to in place by
    the phrase-merging code in this module.
    """

    # Tag text.
    tag: str
    # Score used for ranking. In this module it is either the ANN similarity
    # (term-based retrieval) or the fasttext/context blend (phrase-based).
    score_combined: float
    # FastText neighbour similarity, or None when not computed.
    score_fasttext: Optional[float]
    # Similarity to the query's TF-IDF context vector, or None when unavailable.
    score_context: Optional[float]
    # Value looked up in the tag-count table, or None when the tag is absent.
    count: Optional[int]
    # Rewrite phrases that yielded this tag; empty for term-based retrieval.
    sources: List[str]
37
+
38
+
39
+
40
+
41
+ def _norm_tag_for_lookup(s: str) -> str:
42
+ # convert "name with spaces" -> "name_with_spaces" and unescape parens
43
+ return s.replace(' ', '_').replace('\\(', '(').replace('\\)', ')')
44
+
45
+
46
# Meta tags that are separated from ordinary tags: score:0..score:9 and the three ratings.
special_tags = [f"score:{n}" for n in range(10)] + ["rating:s", "rating:q", "rating:e"]


def remove_special_tags(original_string):
    """Separate special meta tags from a comma-separated tag string.

    Every comma-separated piece is stripped of surrounding whitespace and
    then routed to one of two lists depending on whether it appears in
    ``special_tags``.

    Returns:
        A ``(remaining, removed)`` pair: ``remaining`` is the non-special
        pieces joined with ", "; ``removed`` is the list of special pieces
        in their original order.
    """
    kept, removed = [], []
    for piece in original_string.split(","):
        token = piece.strip()
        bucket = removed if token in special_tags else kept
        bucket.append(token)
    return ", ".join(kept), removed
52
+
53
+
54
def construct_pseudo_vector(pseudo_doc_terms, idf, term_to_column_index):
    """Build a single-row sparse TF-IDF vector for a pseudo document.

    Args:
        pseudo_doc_terms: Mapping of term -> weight (e.g. a term count).
        idf: Per-column IDF values; its length fixes the vector width.
        term_to_column_index: Mapping of term -> column index. Terms without
            an entry are ignored.

    Returns:
        A ``(1, len(idf))`` float32 ``csr_matrix`` whose entry for each known
        term is ``weight * idf[column]``.
    """
    column_ids, weighted_values = [], []
    column_for = term_to_column_index.get
    for term in pseudo_doc_terms:
        column = column_for(term)
        if column is not None:
            column_ids.append(column)
            weighted_values.append(pseudo_doc_terms[term] * idf[column])

    # One row only: its entries span the whole data array.
    row_pointer = [0, len(column_ids)]
    width = len(idf)
    return csr_matrix(
        (weighted_values, column_ids, row_pointer),
        shape=(1, width),
        dtype=np.float32,
    )
65
+
66
+
67
def _ensure_dual_hnsw_indexes():
    """
    Touch both HNSW index getters (tags first, then artists).

    The return values are discarded; the calls are made for their side
    effect, which per the original note is building/loading the two HNSW
    indexes over the SVD-reduced TF-IDF matrix.
    """
    for ensure_index in (get_hnsw_tag_index, get_hnsw_artist_index):
        ensure_index()
74
+
75
+
76
+ def _hnsw_query(idx, vec: np.ndarray, k: int):
77
+ """
78
+ Query a given HNSW index with a (1, D) or (D,) vector in SVD space.
79
+ Returns (indices, sims) with cosine similarity scores.
80
+ """
81
+ q = np.asarray(vec, dtype=np.float32).reshape(-1)
82
+ q_norm = np.linalg.norm(q)
83
+ if q_norm > 0:
84
+ q = q / q_norm
85
+ labels, dists = idx.knn_query(q, k=k)
86
+ inds = labels[0]
87
+ sims = 1.0 - dists[0] # cosine distance -> similarity
88
+ return inds, sims
89
+
90
+
91
def _ann_tags_topk(vec: np.ndarray, k: int):
    """Return ``(indices, sims)`` of the nearest entries in the tag HNSW index.

    ``k`` is clamped to the number of indexed items. A pair of empty arrays
    is returned when there is no index or the clamped ``k`` is zero.
    """
    index, item_count = get_hnsw_tag_index()
    if index is None:
        return (np.array([], dtype=int), np.array([], dtype=float))

    k = min(k, item_count if item_count else 0)
    if not k:
        return (np.array([], dtype=int), np.array([], dtype=float))
    return _hnsw_query(index, vec, k)
97
+
98
+
99
def _ann_artists_topk(vec: np.ndarray, k: int):
    """Return ``(indices, sims)`` of the nearest entries in the artist HNSW index.

    ``k`` is clamped to the number of indexed items. A pair of empty arrays
    is returned when there is no index or the clamped ``k`` is zero.
    """
    index, item_count = get_hnsw_artist_index()
    if index is None:
        return (np.array([], dtype=int), np.array([], dtype=float))

    k = min(k, item_count if item_count else 0)
    if not k:
        return (np.array([], dtype=int), np.array([], dtype=float))
    return _hnsw_query(index, vec, k)
105
+
106
+
107
def get_tfidf_reduced_similar_tags(pseudo_doc_terms, allow_nsfw_tags):
    """Rank tags by similarity to a pseudo document in SVD-reduced TF-IDF space.

    Args:
        pseudo_doc_terms: Mapping of term -> weight describing the query.
        allow_nsfw_tags: When false, tags contained in ``get_nsfw_tags()``
            are removed from the result.

    Returns:
        An ``OrderedDict`` of display-form tag -> similarity, sorted by
        similarity descending. Display form means underscores replaced by
        spaces and parentheses backslash-escaped.
    """
    components = get_tfidf_components()
    idf_weights = components["idf"]
    column_of_term = components["tag_to_column_index"]
    tag_of_row = components["row_to_tag"]
    reducer = components["svd_model"]

    # Pseudo TF-IDF row -> SVD space, shape (1, D).
    sparse_query = construct_pseudo_vector(pseudo_doc_terms, idf_weights, column_of_term)
    reduced_query = reducer.transform(sparse_query)

    # Approximate nearest neighbours among the tag index only
    # (no full-matrix cosine). The limit trades speed against recall.
    neighbour_limit = 2000
    row_ids, similarities = _ann_tags_topk(reduced_query, k=neighbour_limit)

    scores = {}
    for row_id, similarity in zip(row_ids, similarities):
        name = tag_of_row.get(int(row_id))
        if name is None:
            continue
        scores[name] = float(similarity)

    if not allow_nsfw_tags:
        blocked = get_nsfw_tags()
        scores = {name: value for name, value in scores.items() if name not in blocked}

    ranked = sorted(scores.items(), key=lambda pair: pair[1], reverse=True)

    display_scores = OrderedDict()
    for name, value in ranked:
        shown = name.replace('_', ' ').replace('(', '\\(').replace(')', '\\)')
        display_scores[shown] = value
    return display_scores
140
+
141
+
142
def psq_candidates_from_terms(terms: Sequence[str], *, allow_nsfw_tags: bool, k: int = 300):
    """Return up to ``k`` candidates for a list of TF-IDF terms.

    The terms are counted, ranked with ``get_tfidf_reduced_similar_tags`` and
    the best ``k`` results are wrapped in ``Candidate`` objects. Only
    ``score_combined`` (the similarity) and ``count`` are populated;
    ``score_fasttext`` / ``score_context`` are None and ``sources`` is empty.

    Args:
        terms: TF-IDF terms; repeated terms weigh more.
        allow_nsfw_tags: Forwarded to ``get_tfidf_reduced_similar_tags``.
        k: Maximum number of candidates returned.

    Returns:
        A list of ``Candidate`` in descending similarity order.
    """
    cand_dict = get_tfidf_reduced_similar_tags(dict(Counter(terms)), allow_nsfw_tags)
    candidates = list(islice(cand_dict.items(), k))
    tag_counts = get_tag_counts()

    def _count_for(tag: str) -> Optional[int]:
        # Ranked tags arrive in display form (spaces, escaped parentheses).
        # Try that first for backward compatibility, then fall back to the
        # lookup form so multi-word tags are not left without a count.
        # NOTE(review): assumes tag_counts is keyed by underscore-form tags,
        # as the membership tests elsewhere in this module suggest -- confirm.
        count = tag_counts.get(tag)
        if count is None:
            count = tag_counts.get(_norm_tag_for_lookup(tag))
        return count

    return [
        Candidate(
            tag=tag,
            score_combined=float(score),
            score_fasttext=None,
            score_context=None,
            count=_count_for(tag),
            sources=[],
        )
        for tag, score in candidates
    ]
157
+
158
+
159
def psq_candidates_from_rewrite_phrases(
    rewrite_phrases: Sequence[str],
    *,
    allow_nsfw_tags: bool,
    context_weight: float = 0.5,
    per_phrase_k: int = 50,
    per_phrase_final_k: int = 10,
    global_k: int = 300,
    verbose: bool = False,
) -> Union[List[Candidate], Tuple[List[Candidate], List[Dict[str, Any]]]]:
    """Generate tag candidates from rewritten phrases.

    Steps:
      1. Normalize and de-duplicate the phrases; for multi-word phrases also
         add the last word as its own phrase (unless short or a stopword).
      2. Build a TF-IDF pseudo document from the phrases found in the TF-IDF
         vocabulary and reduce it with the SVD model to get a context vector.
      3. For each phrase, take FastText neighbours, project them to canonical
         tags (directly or through the alias table) and keep the best
         similarity per tag. Tags the phrase itself projects to get
         similarity 1.0 and are marked as required.
      4. Score every tag against the context vector; blend it with the
         FastText score using ``context_weight``. A missing context score is
         imputed per phrase.
      5. Keep the top ``per_phrase_final_k`` rows per phrase (required tags
         are forced in), merge across phrases taking the maximum of each
         score, sort by combined score and keep ``global_k``.

    Args:
        rewrite_phrases: Phrases to retrieve candidates for.
        allow_nsfw_tags: When false, tags in ``get_nsfw_tags()`` are excluded.
        context_weight: Weight of the context score in the combined score;
            the FastText score gets ``1 - context_weight``.
        per_phrase_k: Number of FastText neighbours requested per phrase.
        per_phrase_final_k: Number of rows kept per phrase after scoring.
        global_k: Number of merged candidates returned.
        verbose: When true, also return per-phrase diagnostic reports.

    Returns:
        The merged candidate list, or ``(candidates, phrase_reports)`` when
        ``verbose`` is true.
    """
    # Words that are not useful as a stand-alone "head" phrase.
    head_stopwords = {
        "and",
        "or",
        "the",
        "a",
        "an",
        "of",
        "to",
        "in",
        "on",
        "at",
        "with",
        "for",
        "from",
        "by",
        "as",
        "is",
        "are",
        "was",
        "were",
        "be",
        "been",
        "being",
        "down",
        "up",
        "over",
        "under",
    }

    def _normalize_phrase(phrase: str) -> str:
        # Lowercase, underscores -> spaces, collapse runs of whitespace.
        lowered = (phrase or "").lower().strip().replace("_", " ")
        return " ".join(lowered.split())

    norm_phrases = [_normalize_phrase(p) for p in rewrite_phrases]
    # dict.fromkeys de-duplicates while keeping first-seen order.
    deduped_phrases = list(dict.fromkeys(p for p in norm_phrases if p))
    if not deduped_phrases:
        return ([], []) if verbose else []

    # The last word of each multi-word phrase is added as an extra phrase.
    head_phrases: List[str] = []
    for phrase in deduped_phrases:
        parts = phrase.split()
        if len(parts) >= 2:
            head = parts[-1]
            if len(head) >= 3 and head.lower() not in head_stopwords:
                head_phrases.append(head)

    final_phrases = list(dict.fromkeys(deduped_phrases + head_phrases))

    fasttext_model = get_fasttext_model()
    tag_counts = get_tag_counts()
    # With NSFW allowed the filter set is empty, so later membership tests never match.
    nsfw_tags = get_nsfw_tags() if not allow_nsfw_tags else set()
    alias2tags = get_alias2tags()

    tfidf_components = get_tfidf_components()
    tfidf_vocab = tfidf_components.get("tag_to_column_index", {})
    idf = tfidf_components["idf"]
    term_to_column_index = tfidf_components["tag_to_column_index"]
    svd = tfidf_components["svd_model"]

    # Context vector: only phrases that are TF-IDF vocabulary terms contribute.
    pseudo_doc_terms = Counter()
    oov_terms: List[str] = []
    for phrase in final_phrases:
        lookup = phrase.replace(" ", "_")
        if lookup in term_to_column_index:
            pseudo_doc_terms[lookup] += 1
        elif verbose:
            oov_terms.append(lookup)
    pseudo_tfidf_vector = construct_pseudo_vector(pseudo_doc_terms, idf, term_to_column_index)
    reduced_query_vector = svd.transform(pseudo_tfidf_vector).reshape(-1)
    query_norm = np.linalg.norm(reduced_query_vector)
    if query_norm > 0:
        reduced_query_vector = reduced_query_vector / query_norm
        query_has_context = True
    else:
        # No phrase contributed to the vector: context scoring is disabled.
        query_has_context = False
    tag_vectors = get_tfidf_tag_vectors() if query_has_context else None
    tag_to_row_index = tag_vectors["tag_to_row_index"] if tag_vectors else {}

    phrase_candidate_maps: List[Tuple[str, Dict[str, float]]] = []
    phrase_required_tags: Dict[str, Set[str]] = {}
    phrase_best_tokens: Dict[str, Dict[str, str]] = {}
    phrase_context_imputed: Dict[str, Dict[str, bool]] = {}
    phrase_reports: List[Dict[str, Any]] = []

    for phrase in final_phrases:
        lookup = phrase.replace(" ", "_")

        def _project_to_canonicals(token: str) -> List[str]:
            # A token that is itself a known tag maps to itself; otherwise
            # it maps to whatever the alias table lists for it (possibly nothing).
            if token in tag_counts or token in tag_to_row_index:
                return [token]
            if token in alias2tags:
                return alias2tags[token]
            return []

        try:
            neighbors = fasttext_model.most_similar(lookup, topn=per_phrase_k)
        except KeyError:
            # The model has no entry for this phrase.
            neighbors = []

        per_phrase_candidates: Dict[str, float] = {}
        per_phrase_best_token: Dict[str, str] = {}
        for token, sim in neighbors:
            for canonical_tag in _project_to_canonicals(token):
                if not allow_nsfw_tags and canonical_tag in nsfw_tags:
                    continue
                # Keep the best similarity (and the token that produced it) per tag.
                prev = per_phrase_candidates.get(canonical_tag)
                if prev is None or sim > prev:
                    per_phrase_candidates[canonical_tag] = float(sim)
                    per_phrase_best_token[canonical_tag] = token
        # Tags the phrase itself resolves to are "required" for this phrase
        # and get the maximum similarity of 1.0.
        projected_lookup = _project_to_canonicals(lookup)
        required_tags = set(projected_lookup)
        if not allow_nsfw_tags:
            required_tags = {tag for tag in required_tags if tag not in nsfw_tags}
        for canonical_tag in projected_lookup:
            if not allow_nsfw_tags and canonical_tag in nsfw_tags:
                continue
            prev = per_phrase_candidates.get(canonical_tag)
            if prev is None or 1.0 > prev:
                per_phrase_candidates[canonical_tag] = 1.0
                per_phrase_best_token[canonical_tag] = lookup

        phrase_candidate_maps.append((phrase, per_phrase_candidates))
        phrase_required_tags[phrase] = required_tags
        phrase_best_tokens[phrase] = per_phrase_best_token

        if verbose:
            # Preliminary report; its "candidates" list is replaced after scoring.
            in_vocab = bool(tfidf_vocab and lookup in tfidf_vocab)
            rows = []
            for canonical_tag, sim in sorted(per_phrase_candidates.items(), key=lambda x: x[1], reverse=True):
                if not allow_nsfw_tags and canonical_tag in nsfw_tags:
                    continue
                alias_token = per_phrase_best_token.get(canonical_tag, canonical_tag)
                rows.append(
                    {
                        "tag": canonical_tag,
                        "alias_token": alias_token,
                        "score_fasttext": float(sim),
                        "score_context": None,
                        "score_combined": float(sim),
                        "context_imputed": False,
                        "count": tag_counts.get(canonical_tag),
                    }
                )
            phrase_reports.append(
                {
                    "phrase": phrase,
                    "normalized": phrase,
                    "lookup": lookup,
                    "tfidf_vocab": in_vocab,
                    "oov_terms": oov_terms,
                    "candidates": rows,
                }
            )

    all_candidate_tags: Set[str] = set()
    for _, per_phrase_candidates in phrase_candidate_maps:
        all_candidate_tags.update(per_phrase_candidates.keys())

    # Context score per tag: dot product of the query vector with the tag's
    # row of reduced_matrix_norm; None when the tag has no row or there is
    # no context vector.
    score_context_by_tag: Dict[str, Optional[float]] = {}
    if query_has_context:
        reduced_matrix_norm = tag_vectors["reduced_matrix_norm"]
        for tag in all_candidate_tags:
            row = tag_to_row_index.get(tag)
            if row is None:
                score_context_by_tag[tag] = None
                continue
            score_context_by_tag[tag] = float(np.dot(reduced_query_vector, reduced_matrix_norm[row]))
    else:
        for tag in all_candidate_tags:
            score_context_by_tag[tag] = None

    merged_by_tag: Dict[str, Candidate] = {}
    # Rows are (tag, score_fasttext, score_context, score_combined).
    per_phrase_scored: Dict[str, List[Tuple[str, float, Optional[float], float]]] = {}
    for phrase, per_phrase_candidates in phrase_candidate_maps:
        context_imputed_by_tag: Dict[str, bool] = {}
        default_context_for_phrase = None
        if query_has_context:
            context_scores = [
                score_context_by_tag.get(tag)
                for tag in per_phrase_candidates.keys()
            ]
            context_scores = [score for score in context_scores if score is not None]
            if context_scores:
                # 10th percentile of this phrase's known context scores.
                context_scores.sort()
                index = int(math.floor(0.10 * (len(context_scores) - 1)))
                default_context_for_phrase = float(context_scores[index])
            else:
                default_context_for_phrase = 0.0
        scored_rows: List[Tuple[str, float, Optional[float], float]] = []
        for tag, score_fasttext in per_phrase_candidates.items():
            if not allow_nsfw_tags and tag in nsfw_tags:
                continue
            score_context = score_context_by_tag.get(tag)
            context_imputed = False
            if score_context is None and query_has_context:
                # Impute missing context with the per-phrase 10th percentile.
                score_context = default_context_for_phrase
                context_imputed = True
            if score_context is None:
                # No context at all: the FastText score stands alone.
                score_combined = float(score_fasttext)
            else:
                score_combined = (1.0 - context_weight) * float(score_fasttext) + context_weight * score_context
            scored_rows.append((tag, float(score_fasttext), score_context, score_combined))
            context_imputed_by_tag[tag] = context_imputed

        scored_rows.sort(key=lambda x: x[3], reverse=True)
        required_tags = phrase_required_tags.get(phrase, set())
        if required_tags:
            scored_by_tag = {row[0]: row for row in scored_rows}
            top_rows = scored_rows[:per_phrase_final_k]
            top_tags = {row[0] for row in top_rows}
            for required_tag in required_tags:
                if required_tag in top_tags:
                    continue
                required_row = scored_by_tag.get(required_tag)
                if required_row is None:
                    # The required tag was never scored: score it now,
                    # defaulting the FastText score to 1.0.
                    score_fasttext = per_phrase_candidates.get(required_tag)
                    score_context = score_context_by_tag.get(required_tag)
                    if score_fasttext is None:
                        score_fasttext = 1.0
                    context_imputed = False
                    if score_context is None and query_has_context:
                        score_context = default_context_for_phrase
                        context_imputed = True
                    if score_context is None:
                        score_combined = float(score_fasttext)
                    else:
                        score_combined = (1.0 - context_weight) * float(score_fasttext) + context_weight * score_context
                    required_row = (required_tag, float(score_fasttext), score_context, score_combined)
                    context_imputed_by_tag[required_tag] = context_imputed
                if len(top_rows) >= per_phrase_final_k:
                    # Make room by dropping the lowest-ranked non-required row;
                    # if every row is required, drop the last one.
                    drop_index = None
                    for idx in range(len(top_rows) - 1, -1, -1):
                        if top_rows[idx][0] not in required_tags:
                            drop_index = idx
                            break
                    if drop_index is None:
                        drop_index = -1
                    top_rows.pop(drop_index)
                top_rows.append(required_row)
                top_tags.add(required_tag)
            # Deterministic must-include for exact phrase matches; re-sort top-N by combined score.
            top_rows.sort(key=lambda x: x[3], reverse=True)
            scored_rows = top_rows
        else:
            scored_rows = scored_rows[:per_phrase_final_k]
        per_phrase_scored[phrase] = scored_rows
        phrase_context_imputed[phrase] = context_imputed_by_tag

        # Merge this phrase's rows into the global per-tag table.
        for tag, score_fasttext, score_context, score_combined in scored_rows:
            existing = merged_by_tag.get(tag)
            if existing is None:
                merged_by_tag[tag] = Candidate(
                    tag=tag,
                    score_combined=score_combined,
                    score_fasttext=score_fasttext,
                    score_context=score_context,
                    count=tag_counts.get(tag),
                    sources=[phrase],
                )
            else:
                # The sources list is mutated in place and reused by the
                # replacement Candidate built below.
                if phrase not in existing.sources:
                    existing.sources.append(phrase)
                # -inf stands in for a missing score so max() can be used.
                existing_fasttext = (
                    existing.score_fasttext if existing.score_fasttext is not None else float("-inf")
                )
                incoming_fasttext = score_fasttext if score_fasttext is not None else float("-inf")
                max_fasttext = max(existing_fasttext, incoming_fasttext)
                existing_context = existing.score_context
                if existing_context is None:
                    max_context = score_context
                elif score_context is None:
                    max_context = existing_context
                else:
                    max_context = max(existing_context, score_context)
                max_combined = max(existing.score_combined, score_combined)
                merged_by_tag[tag] = Candidate(
                    tag=tag,
                    score_combined=max_combined,
                    score_fasttext=max_fasttext if max_fasttext != float("-inf") else None,
                    score_context=max_context,
                    count=existing.count,
                    sources=existing.sources,
                )

    if verbose:
        # Replace each preliminary report's candidates with the final scored rows.
        for report in phrase_reports:
            phrase = report["phrase"]
            rows = []
            for tag, score_fasttext, score_context, score_combined in per_phrase_scored.get(phrase, []):
                alias_token = phrase_best_tokens.get(phrase, {}).get(tag, tag)
                context_imputed = phrase_context_imputed.get(phrase, {}).get(tag, False)
                rows.append(
                    {
                        "tag": tag,
                        "alias_token": alias_token,
                        "score_fasttext": score_fasttext,
                        "score_context": score_context,
                        "score_combined": score_combined,
                        "context_imputed": context_imputed,
                        "count": tag_counts.get(tag),
                    }
                )
            report["candidates"] = rows

    merged_candidates = list(merged_by_tag.values())
    merged_candidates.sort(key=lambda c: c.score_combined, reverse=True)
    merged_candidates = merged_candidates[:global_k]

    return (merged_candidates, phrase_reports) if verbose else merged_candidates
479
+
480
+
481
def psq_candidates_from_prompt(prompt: str, *, allow_nsfw_tags: bool, k: int = 300):
    """Produce Stage 2 candidates for a raw prompt string.

    The prompt is lowercased, special score/rating tags are split off, and
    the remainder is parsed with the prompt grammar. The TF-IDF form of every
    parsed tag, followed by the special tags, is then handed to
    ``psq_candidates_from_terms``.
    """
    # Imported here rather than at module level.
    from ..parsing.prompt_grammar import build_tag_offsets_dicts, extract_tags, parser

    lowered = (prompt or "").lower()
    cleaned, special = remove_special_tags(lowered)

    tag_records = build_tag_offsets_dicts(extract_tags(parser.parse(cleaned)))

    # TF-IDF terms in the form the downstream pipeline expects.
    tfidf_terms = [record["tf_idf_matrix_tag"] for record in tag_records]
    tfidf_terms.extend(special)

    return psq_candidates_from_terms(tfidf_terms, allow_nsfw_tags=allow_nsfw_tags, k=k)
496
+
497
+
498
+ if __name__ == "__main__":
499
+ print("psq_retrieval.py imports ok")
500
+
psq_rag/retrieval/state.py ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import csv
4
+ import logging
5
+ import pathlib
6
+ from typing import Any, Dict, List, Optional, Set, Tuple
7
+
8
+ import joblib
9
+ import numpy as np
10
+
11
+ try:
12
+ import hnswlib
13
+ except Exception:
14
+ hnswlib = None # allow import on environments without hnswlib during partial tests
15
+
16
+
17
+ TFIDF_PATH = pathlib.Path("tf_idf_files_420.joblib")
18
+ NSFW_CSV_PATH = pathlib.Path("word_rating_probabilities.csv")
19
+ NSFW_THRESHOLD = 0.95
20
+
21
+ HNSW_ART_PATH = pathlib.Path("tfidf_hnsw_artists.bin")
22
+ HNSW_TAG_PATH = pathlib.Path("tfidf_hnsw_tags.bin")
23
+ FASTTEXT_MODEL_PATH = pathlib.Path("e621FastTextModel010Replacement_small.bin")
24
+ TAG_ALIASES_PATH = pathlib.Path("fluffyrock_3m.csv")
25
+
26
+ _tfidf_components: Optional[Dict[str, Any]] = None
27
+ _nsfw_tags: Optional[Set[str]] = None
28
+ _artist_set: Optional[Set[str]] = None
29
+ _fasttext_model: Optional[Any] = None
30
+ _tag_counts: Optional[Dict[str, int]] = None
31
+ _tfidf_tag_vectors: Optional[Dict[str, Any]] = None
32
+ _alias_to_tags: Optional[Dict[str, List[str]]] = None
33
+ _tag_to_aliases: Optional[Dict[str, List[str]]] = None
34
+ _tag_type_id: Optional[Dict[str, int]] = None
35
+
36
+
37
+ _hnsw_tag_index: Optional["hnswlib.Index"] = None
38
+ _hnsw_artist_index: Optional["hnswlib.Index"] = None
39
+ _hnsw_tag_count: int = 0
40
+ _hnsw_artist_count: int = 0
41
+
42
+ # Tag type names inferred from e621 wiki documentation.
43
+ # Numeric IDs come from fluffyrock_3m.csv column 1; mapping is heuristic but
44
+ # matches observed usage on e621.
45
+ TAG_TYPE_ID_TO_NAME: Dict[int, str] = {
46
+ 0: "general", # Default tag type: visible attributes, actions, objects, etc.
47
+ 1: "artist", # Artist tags (e.g. by_name, artist_name)
48
+ 2: "contributor", # Contributor tags (rare / possibly unused in this dataset)
49
+ 3: "copyright", # Series, franchise, or IP (e.g. pokemon, winnie_the_pooh)
50
+ 4: "character", # Named characters (e.g. pikachu, pinkie_pie_(mlp))
51
+ 5: "species", # Species tags (e.g. canine, domestic_cat)
52
+ 6: "invalid", # Invalid / disallowed / disambiguation-only tags
53
+ 7: "meta", # Meta / presentation / file / style-related tags
54
+ }
55
+
56
+
57
+ def _l2_normalize_rows(mat: np.ndarray) -> np.ndarray:
58
+ mat = np.asarray(mat, dtype=np.float32)
59
+ norms = np.linalg.norm(mat, axis=1, keepdims=True)
60
+ norms[norms == 0.0] = 1.0
61
+ return mat / norms
62
+
63
+
64
+ def _clean_tag_ascii(tag: str) -> str:
65
+ return "".join(char for char in tag if ord(char) < 128)
66
+
67
+
68
def clean_tag(tag: str) -> str:
    """Normalize *tag* the same way the legacy alias parser did.

    Currently this only strips non-ASCII characters (see ``_clean_tag_ascii``).
    """
    cleaned = _clean_tag_ascii(tag)
    return cleaned
71
+
72
+
73
def build_aliases_dict(csv_path: str, reverse: bool = False) -> Dict[str, List[str]]:
    """Build tag/alias mappings from the aliases CSV.

    Column 0 of each row is the canonical tag and column 3 is either the
    literal string ``"null"`` or a comma-separated alias list. Both tags and
    aliases are passed through ``clean_tag``.

    Args:
        csv_path: Path of the CSV file to read (UTF-8).
        reverse: When False, return ``tag -> [aliases]``. When True, return
            ``alias -> [tags]`` (an alias may map to several tags).

    Returns:
        The requested mapping.

    Blank rows are skipped, and rows with fewer than four columns are treated
    as having no aliases, instead of raising ``IndexError``.
    """
    aliases_dict: Dict[str, List[str]] = {}
    with open(csv_path, "r", newline="", encoding="utf-8") as csvfile:
        for row in csv.reader(csvfile):
            if not row:
                # Blank line (e.g. a trailing newline): nothing to record.
                continue

            tag = clean_tag(row[0])
            # A missing alias column is equivalent to the explicit "null" marker.
            raw_aliases = row[3] if len(row) > 3 else "null"
            if raw_aliases == "null":
                alias_list: List[str] = []
            else:
                alias_list = [clean_tag(alias) for alias in raw_aliases.split(",")]

            if reverse:
                for alias in alias_list:
                    aliases_dict.setdefault(alias, []).append(tag)
            else:
                aliases_dict[tag] = alias_list
    return aliases_dict
87
+
88
+
89
def get_tfidf_components() -> Dict[str, Any]:
    """Load the TF-IDF component dict from ``TFIDF_PATH`` once and cache it.

    After loading, two normalizations are applied in place:

    * if only ``tag_to_row_index`` is present, the inverse ``row_to_tag``
      mapping is added;
    * if ``idf`` is a dict keyed by term, it is replaced by a float32 array
      indexed by column (terms without an entry default to 1.0).

    Raises:
        FileNotFoundError: If ``TFIDF_PATH`` does not exist.
    """
    global _tfidf_components
    if _tfidf_components is None:
        if not TFIDF_PATH.is_file():
            raise FileNotFoundError(f"TF-IDF joblib not found: {TFIDF_PATH}")

        # NOTE(review): joblib.load unpickles the file; only load trusted artifacts.
        loaded = joblib.load(TFIDF_PATH)

        needs_inverse = "tag_to_row_index" in loaded and "row_to_tag" not in loaded
        if needs_inverse:
            inverse = {}
            for tag, idx in loaded["tag_to_row_index"].items():
                inverse[idx] = tag
            loaded["row_to_tag"] = inverse

        idf = loaded.get("idf")
        if isinstance(idf, dict):
            column_of = loaded["tag_to_column_index"]
            dense_idf = np.ones(max(column_of.values()) + 1, dtype=np.float32)
            for term, col in column_of.items():
                dense_idf[col] = float(idf.get(term, 1.0))
            loaded["idf"] = dense_idf

        _tfidf_components = loaded
    return _tfidf_components
115
+
116
+
117
def get_nsfw_tags() -> Set[str]:
    """Return the cached set of words whose score is at least ``NSFW_THRESHOLD``.

    The CSV at ``NSFW_CSV_PATH`` is read once: the first row is skipped as a
    header, column 0 is the word and column 1 is parsed as a float. Rows that
    are empty, too short, or have a non-numeric score are ignored.

    Raises:
        FileNotFoundError: If ``NSFW_CSV_PATH`` does not exist.
    """
    global _nsfw_tags
    if _nsfw_tags is not None:
        return _nsfw_tags

    if not NSFW_CSV_PATH.is_file():
        raise FileNotFoundError(f"NSFW tag CSV not found: {NSFW_CSV_PATH}")

    flagged: Set[str] = set()
    with NSFW_CSV_PATH.open("r", newline="", encoding="utf-8") as handle:
        records = csv.reader(handle)
        next(records, None)  # discard the header row, if any
        for record in records:
            if not record:
                continue
            try:
                score = float(record[1])
            except (IndexError, ValueError):
                continue
            if score >= NSFW_THRESHOLD:
                flagged.add(record[0])

    _nsfw_tags = flagged
    return _nsfw_tags
142
+
143
+
144
def get_artist_set() -> Set[str]:
    """Return the cached set of artist names found in the tag CSV.

    Every row of ``TAG_ALIASES_PATH`` whose first column starts with ``by_``
    contributes the remainder of that tag (prefix removed) to the set.

    Unlike the other loaders in this module, a missing CSV is not an error:
    an empty set is cached and returned.
    """
    global _artist_set
    if _artist_set is not None:
        return _artist_set

    # Use the shared constant rather than a second hard-coded copy of the
    # filename, so all readers of the tag CSV stay in sync.
    path = TAG_ALIASES_PATH
    if not path.is_file():
        _artist_set = set()
        return _artist_set

    artists: Set[str] = set()
    with path.open("r", newline="", encoding="utf-8") as csvfile:
        for row in csv.reader(csvfile):
            if not row:
                continue
            tag_name = row[0]
            if tag_name.startswith("by_"):
                artists.add(tag_name[3:])

    _artist_set = artists
    return _artist_set
166
+
167
+
168
def is_artist(name: str) -> bool:
    """Return True when *name* (without the ``by_`` prefix) is a known artist."""
    known_artists = get_artist_set()
    return name in known_artists
170
+
171
+
172
def get_fasttext_model() -> Any:
    """Load the compressed FastText vectors from ``FASTTEXT_MODEL_PATH`` once.

    ``compress_fasttext`` is imported lazily so the module can be imported
    without it until the model is actually requested.

    Raises:
        FileNotFoundError: If ``FASTTEXT_MODEL_PATH`` does not exist.
    """
    global _fasttext_model
    if _fasttext_model is None:
        if not FASTTEXT_MODEL_PATH.is_file():
            raise FileNotFoundError(f"FastText model not found: {FASTTEXT_MODEL_PATH}")

        import compress_fasttext

        vectors_cls = compress_fasttext.models.CompressedFastTextKeyedVectors
        _fasttext_model = vectors_cls.load(str(FASTTEXT_MODEL_PATH))
    return _fasttext_model
186
+
187
+
188
def get_tag_type_ids() -> Dict[str, int]:
    """Return the cached ``canonical tag -> type_id`` mapping.

    Reads ``TAG_ALIASES_PATH``: column 0 (cleaned with ``clean_tag``) is the
    tag and column 1 is parsed as an int. Rows that are empty, have fewer than
    two columns, or carry a non-integer type id are skipped.

    Raises:
        FileNotFoundError: If ``TAG_ALIASES_PATH`` does not exist.
    """
    global _tag_type_id
    if _tag_type_id is not None:
        return _tag_type_id

    if not TAG_ALIASES_PATH.is_file():
        raise FileNotFoundError(f"Tag CSV not found: {TAG_ALIASES_PATH}")

    type_by_tag: Dict[str, int] = {}
    with TAG_ALIASES_PATH.open("r", newline="", encoding="utf-8") as handle:
        for record in csv.reader(handle):
            # Covers both blank rows and rows without a type column.
            if len(record) < 2:
                continue
            try:
                parsed_type = int(record[1])
            except ValueError:
                continue
            type_by_tag[clean_tag(record[0])] = parsed_type

    _tag_type_id = type_by_tag
    return _tag_type_id
217
+
218
+
219
def get_tag_type_name(tag: str) -> Optional[str]:
    """Map *tag* to its heuristic type name (e.g. 'artist', 'character').

    Returns None when the cleaned tag has no known type id; ids missing from
    ``TAG_TYPE_ID_TO_NAME`` are rendered as ``type_<id>``.
    """
    type_id = get_tag_type_ids().get(clean_tag(tag))
    if type_id is not None:
        return TAG_TYPE_ID_TO_NAME.get(type_id, f"type_{type_id}")
    return None
225
+
226
+
227
def get_tag_counts() -> Dict[str, int]:
    """Return the cached ``tag -> post count`` mapping from the tag CSV.

    Reads ``TAG_ALIASES_PATH``: column 0 is used verbatim as the key and
    column 2 is the count. Rows that are blank, have fewer than three columns,
    or whose count is not a plain decimal number are skipped.

    Raises:
        FileNotFoundError: If ``TAG_ALIASES_PATH`` does not exist.
    """
    global _tag_counts
    if _tag_counts is not None:
        return _tag_counts

    if not TAG_ALIASES_PATH.is_file():
        raise FileNotFoundError(f"Tag count CSV not found: {TAG_ALIASES_PATH}")

    tag_counts: Dict[str, int] = {}
    with TAG_ALIASES_PATH.open("r", newline="", encoding="utf-8") as csvfile:
        for row in csv.reader(csvfile):
            # Guard against short rows; indexing row[2] blindly raised IndexError.
            if len(row) < 3:
                continue
            raw_count = row[2]
            # isdecimal() (not isdigit()) so characters such as "²", which
            # int() cannot parse, are rejected instead of crashing the load.
            if raw_count.isdecimal():
                tag_counts[row[0]] = int(raw_count)

    _tag_counts = tag_counts
    return _tag_counts
248
+
249
+
250
def get_alias2tags() -> Dict[str, List[str]]:
    """Return the cached ``alias -> [canonical tags]`` mapping.

    Raises:
        FileNotFoundError: If ``TAG_ALIASES_PATH`` does not exist.
    """
    global _alias_to_tags
    if _alias_to_tags is None:
        if not TAG_ALIASES_PATH.is_file():
            raise FileNotFoundError(f"Tag alias CSV not found: {TAG_ALIASES_PATH}")
        _alias_to_tags = build_aliases_dict(str(TAG_ALIASES_PATH), reverse=True)
    return _alias_to_tags
261
+
262
+
263
def get_tag2aliases() -> Dict[str, List[str]]:
    """Return the cached ``canonical tag -> [aliases]`` mapping.

    Raises:
        FileNotFoundError: If ``TAG_ALIASES_PATH`` does not exist.
    """
    global _tag_to_aliases
    if _tag_to_aliases is None:
        if not TAG_ALIASES_PATH.is_file():
            raise FileNotFoundError(f"Tag alias CSV not found: {TAG_ALIASES_PATH}")
        _tag_to_aliases = build_aliases_dict(str(TAG_ALIASES_PATH), reverse=False)
    return _tag_to_aliases
274
+
275
+
276
def get_tfidf_tag_vectors() -> Dict[str, Any]:
    """Return cached per-tag vectors derived from the TF-IDF components.

    The result dict holds ``reduced_matrix`` (as loaded),
    ``reduced_matrix_norm`` (row-wise L2-normalized float32 copy), and the two
    index mappings ``row_to_tag`` / ``tag_to_row_index``. Whichever mapping is
    missing is derived by inverting the other.

    Raises:
        KeyError: If ``reduced_matrix`` or both index mappings are missing.
    """
    global _tfidf_tag_vectors
    if _tfidf_tag_vectors is not None:
        return _tfidf_tag_vectors

    components = get_tfidf_components()

    matrix = components.get("reduced_matrix")
    if matrix is None:
        raise KeyError("TF-IDF components missing reduced_matrix")

    tag_of_row = components.get("row_to_tag")
    if tag_of_row is None and "tag_to_row_index" in components:
        tag_of_row = {}
        for tag, idx in components["tag_to_row_index"].items():
            tag_of_row[idx] = tag
    if tag_of_row is None:
        raise KeyError("TF-IDF components missing row_to_tag mapping")

    row_of_tag = components.get("tag_to_row_index")
    if row_of_tag is None:
        row_of_tag = {}
        for idx, tag in tag_of_row.items():
            row_of_tag[tag] = idx

    normalized = _l2_normalize_rows(matrix).astype(np.float32)

    _tfidf_tag_vectors = dict(
        reduced_matrix=matrix,
        reduced_matrix_norm=normalized,
        row_to_tag=tag_of_row,
        tag_to_row_index=row_of_tag,
    )
    return _tfidf_tag_vectors
305
+
306
+
307
def retrieval_assets_status() -> Dict[str, bool]:
    """Report, per retrieval asset, whether its file currently exists on disk."""
    asset_paths = (
        ("tfidf", TFIDF_PATH),
        ("nsfw_csv", NSFW_CSV_PATH),
        ("fasttext_model", FASTTEXT_MODEL_PATH),
        ("tag_aliases_csv", TAG_ALIASES_PATH),
        ("hnsw_tags", HNSW_TAG_PATH),
        ("hnsw_artists", HNSW_ART_PATH),
    )
    return {label: asset.is_file() for label, asset in asset_paths}
316
+
317
+
318
def _build_or_load_index(path: pathlib.Path, rows: list[int], rm: np.ndarray, dim: int) -> "hnswlib.Index":
    """Return a cosine HNSW index over ``rm[rows]``, reusing ``path`` if valid.

    A saved index is reused only when it loads successfully, ``rows`` is
    non-empty, and its element count equals ``len(rows)``. Otherwise the stale
    file is removed (best effort) and a new index is built and saved to
    ``path``.

    Args:
        path: Location of the persisted index file.
        rows: Row numbers of ``rm`` to index; also used as item ids.
        rm: Matrix whose rows are the vectors to index.
        dim: Vector dimensionality passed to hnswlib.

    Returns:
        An index with ``ef`` set to 200.
    """
    capacity = max(1, len(rows))

    if path.exists():
        cached = hnswlib.Index(space="cosine", dim=dim)
        try:
            cached.load_index(str(path), max_elements=capacity)
            saved_count = cached.get_current_count()
        except Exception as e:
            logging.debug("Reload %s failed, rebuilding: %s", path.name, e)
        else:
            if rows and saved_count == len(rows):
                cached.set_ef(200)
                return cached
            logging.debug(
                "Rebuilding %s: saved_count!=rows_len (%s vs %s)",
                path.name,
                saved_count,
                len(rows),
            )

        # Best-effort removal of the stale file; save_index below overwrites it
        # anyway, so a failure here is only worth a debug line.
        try:
            if path.exists():
                path.unlink()
        except OSError as e:
            logging.debug("Could not remove stale %s: %s", path.name, e)

    # Always build on a fresh Index object: hnswlib refuses init_index() on an
    # index that load_index() has already initialised, so reusing the loaded
    # object made the rebuild path raise.
    idx = hnswlib.Index(space="cosine", dim=dim)
    idx.init_index(max_elements=capacity, ef_construction=200, M=16)
    if rows:
        idx.add_items(rm[rows], ids=np.asarray(rows, dtype=np.int32))
    idx.save_index(str(path))

    idx.set_ef(200)
    return idx
349
+
350
+
351
def _ensure_hnsw_indexes(need_artists: bool) -> None:
    """Populate the module-level HNSW indexes (and their counts) on demand.

    Does nothing when hnswlib is unavailable, or when the tag index exists and
    the artist index is either present or not requested. Otherwise the TF-IDF
    rows are split into "tag" and "artist" rows and the corresponding indexes
    are built or loaded from disk.

    NOTE(review): the artist set is only consulted when ``need_artists`` is
    True, so the contents of the tag index appear to depend on which getter is
    called first - confirm this is intended.
    NOTE(review): for tags without a ``by_`` prefix the whole tag is tested
    against the artist set - confirm that is the desired classification.
    """
    global _hnsw_tag_index, _hnsw_artist_index, _hnsw_tag_count, _hnsw_artist_count

    if hnswlib is None:
        return

    tags_ready = _hnsw_tag_index is not None
    artists_ready = _hnsw_artist_index is not None
    if tags_ready and (artists_ready or not need_artists):
        return

    components = get_tfidf_components()
    reduced_matrix = components["reduced_matrix"]
    row_to_tag = components["row_to_tag"]
    rm = _l2_normalize_rows(reduced_matrix).astype(np.float32)
    n_items, dim = rm.shape

    artist_set = get_artist_set() if need_artists else set()
    # These "by_" tags are placeholders, not real artists: always keep as tags.
    forced_tag_rows = {"by_unknown_artist", "by_conditional_dnp"}

    artist_rows: list[int] = []
    tag_rows: list[int] = []
    for row_idx in range(n_items):
        tag = row_to_tag.get(row_idx, "")
        if tag in forced_tag_rows:
            tag_rows.append(row_idx)
            continue

        base = tag[3:] if tag.startswith("by_") else tag
        bucket = artist_rows if (artist_set and is_artist(base)) else tag_rows
        bucket.append(row_idx)

    _hnsw_tag_index = _build_or_load_index(HNSW_TAG_PATH, tag_rows, rm, dim)
    _hnsw_tag_count = len(tag_rows)

    if need_artists:
        _hnsw_artist_index = _build_or_load_index(HNSW_ART_PATH, artist_rows, rm, dim)
        _hnsw_artist_count = len(artist_rows)
389
+
390
+
391
def get_hnsw_tag_index() -> Tuple[Optional["hnswlib.Index"], int]:
    """Return ``(tag index, number of indexed tag rows)``, building it if needed.

    The index is None when hnswlib could not be imported.
    """
    _ensure_hnsw_indexes(need_artists=False)
    result = (_hnsw_tag_index, _hnsw_tag_count)
    return result
394
+
395
+
396
def get_hnsw_artist_index() -> Tuple[Optional["hnswlib.Index"], int]:
    """Return ``(artist index, number of indexed artist rows)``, building it if needed.

    The index is None when hnswlib could not be imported.
    """
    _ensure_hnsw_indexes(need_artists=True)
    result = (_hnsw_artist_index, _hnsw_artist_count)
    return result
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio==4.44.1
2
+ gradio-client==1.3.0
3
+ hnswlib==0.8.0
4
+ numpy==1.25.1
5
+ scikit-learn==1.4.1.post1
6
+ h5py==3.8.0
7
+ joblib==1.2.0
8
+ compress-fasttext
9
+ lark-parser
10
+ scipy==1.12.0
11
+ gensim==4.3.2
12
+ huggingface_hub<1.0
13
+ rapidfuzz>=3.0
scripts/extract_tag_patterns.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Extract syntactic / compositional patterns from an e621-style tag CSV.
4
+
5
+ Assumptions:
6
+ - Input is a CSV (or TSV) where the FIRST column contains the tag string.
7
+ - Tags are typically underscore-delimited, e.g. "blue_shirt", "looking_at_viewer".
8
+ - We want PATTERN statistics, not just top tags.
9
+
10
+ Outputs:
11
+ - Top underscore-shape templates (e.g., "<w>_<w>", "<w>_<w>_<w>")
12
+ - Top "suffix patterns" (e.g., "<w>_shirt", "<w>_fur") that catch color+object style combos
13
+ - Top "prefix patterns" (e.g., "looking_<w>_<w>")
14
+ - Heuristic slot-typed templates (e.g., "<color>_<clothing>") based on small dictionaries
15
+
16
+ No third-party deps.
17
+ """
18
+
19
+ import argparse
20
+ import csv
21
+ import os
22
+ import random
23
+ import re
24
+ import sys
25
+ from collections import Counter, defaultdict
26
+ from typing import Dict, Iterable, List, Tuple
27
+
28
+
29
+ # ---- small heuristic lexicons (edit freely) ----
30
+ # Keep these small: they are for *pattern discovery*, not canonicalization.
31
+ COLORS = {
32
+ "black","white","grey","gray","red","blue","green","yellow","orange","purple","pink","brown","tan",
33
+ "silver","gold","blonde","blond","aqua","teal","cyan","magenta","violet","indigo","maroon","navy",
34
+ "beige","cream","ivory","turquoise","lavender",
35
+ "multicolored","two_tone","two_toned",
36
+ }
37
+
38
+ # Clothing nouns seen frequently in tag vocabularies (add as you notice them).
39
+ CLOTHING = {
40
+ "shirt","pants","shorts","dress","skirt","underwear","panties","bra","bikini","swimwear",
41
+ "topwear","bottomwear","legwear","handwear","armwear","footwear","stockings","socks","shoes","boots",
42
+ "gloves","hat","headwear","headgear","collar","armor","mask",
43
+ }
44
+
45
+ BODY = {
46
+ "fur","hair","eyes","tail","ears","horn","wings","paws","toes","fingers","nipples","breasts",
47
+ "belly","butt","penis","pussy","anus","clitoris","hooves","teeth","fangs","tongue","nose",
48
+ }
49
+
50
+ VIEW_WORDS = {
51
+ "front","rear","side","back","from","first","third"
52
+ }
53
+
54
+ # ---- helpers ----
55
+
56
def detect_dialect(path: str) -> csv.Dialect:
    """Guess the CSV dialect of *path* from its first 4096 characters.

    Only comma, tab, semicolon and pipe are considered as delimiters. If
    sniffing fails for any reason, the standard "excel" (comma) dialect is
    returned instead.
    """
    with open(path, "r", encoding="utf-8", newline="") as handle:
        head = handle.read(4096)

    candidate_delimiters = [",", "\t", ";", "|"]
    try:
        guessed = csv.Sniffer().sniff(head, delimiters=candidate_delimiters)
    except Exception:
        guessed = csv.get_dialect("excel")
    return guessed
64
+
65
+
66
def iter_tags_from_first_col(path: str, sample_n: int | None, seed: int) -> Iterable[str]:
    """Yield tag strings taken from the first column of a delimited file.

    Rows that are empty or whose first cell is blank are ignored. If the first
    usable row looks like a header (``tag``, ``tags`` or ``name``), it is
    skipped.

    Args:
        path: CSV/TSV file; the dialect is detected with ``detect_dialect``.
        sample_n: When None, every tag is yielded in file order. Otherwise a
            uniform reservoir sample of at most ``sample_n`` tags is collected
            and yielded after the whole file has been read.
        seed: Seed for the sampling RNG (unused when ``sample_n`` is None).

    Yields:
        Stripped tag strings.
    """
    dialect = detect_dialect(path)
    rng = random.Random(seed)

    # Reservoir sampling keeps memory bounded by sample_n instead of file size.
    reservoir: List[str] = []
    seen = 0  # number of real tags seen so far (header excluded)
    header_checked = False

    with open(path, "r", encoding="utf-8", newline="") as f:
        for row in csv.reader(f, dialect=dialect):
            if not row:
                continue
            tag = row[0].strip()
            if not tag:
                continue

            # Only the first usable row may be a header. It must not be counted
            # in `seen`, otherwise the replacement probability below is skewed.
            if not header_checked:
                header_checked = True
                if tag.lower() in {"tag", "tags", "name"}:
                    continue

            seen += 1
            if sample_n is None:
                yield tag
            elif len(reservoir) < sample_n:
                reservoir.append(tag)
            else:
                # Replace a random slot with probability sample_n / seen.
                j = rng.randrange(seen)
                if j < sample_n:
                    reservoir[j] = tag

    if sample_n is not None:
        yield from reservoir
103
+
104
+
105
+ _word_re = re.compile(r"^[a-z0-9]+(?:[/-][a-z0-9]+)*$")
106
+
107
def normalize_tag(tag: str) -> str:
    """Trim surrounding whitespace and lowercase *tag*; underscores are kept."""
    trimmed = tag.strip()
    return trimmed.lower()
110
+
111
def split_parts(tag: str) -> List[str]:
    """Split *tag* on underscores, discarding empty pieces."""
    return list(filter(None, tag.split("_")))
113
+
114
def underscore_shape(parts: List[str]) -> str:
    """Return the wildcard shape of a tag, e.g. two parts -> "<w>_<w>"."""
    return "_".join("<w>" for _ in parts)
117
+
118
+ def suffix_pattern(parts: List[str], k: int = 1) -> str | None:
119
+ # "<w>_<w>_shirt" style: wildcard prefix + suffix tokens.
120
+ if len(parts) <= k:
121
+ return None
122
+ suf = "_".join(parts[-k:])
123
+ return "<w>_" + suf if k == 1 else "<w>..._" + suf
124
+
125
+ def prefix_pattern(parts: List[str], k: int = 1) -> str | None:
126
+ # "looking_<w>_<w>" style: prefix token(s) + wildcard suffix
127
+ if len(parts) <= k:
128
+ return None
129
+ pre = "_".join(parts[:k])
130
+ return pre + "_<w>" if k == 1 else pre + "_<w>..."
131
+
132
def typed_token(tok: str) -> str:
    """Map a single tag part to a heuristic slot label.

    Checks are made in a fixed priority order: number, colour, clothing, body
    part, gender. Anything else becomes the generic "<w>" wildcard.
    """
    if tok.isdigit():
        return "<num>"

    # NOTE(review): "ambiguous_gender" contains an underscore, so it can only
    # match if callers pass unsplit tokens - confirm.
    gender_words = {"male", "female", "intersex", "gynomorph", "ambiguous_gender"}
    lexicons = (
        (COLORS, "<color>"),
        (CLOTHING, "<clothing>"),
        (BODY, "<body>"),
        (gender_words, "<gender>"),
    )
    for vocabulary, label in lexicons:
        if tok in vocabulary:
            return label
    return "<w>"
145
+
146
def typed_template(parts: List[str]) -> str:
    """Return the slot-typed template of a tag, e.g. "<color>_<clothing>"."""
    labels = map(typed_token, parts)
    return "_".join(labels)
148
+
149
def bigram_templates(parts: List[str]) -> List[str]:
    """Return the typed template of every adjacent pair of parts.

    Useful for spotting colour+thing style combinations even inside longer
    tags. Tags with fewer than two parts produce an empty list.
    """
    return [
        f"{typed_token(left)}_{typed_token(right)}"
        for left, right in zip(parts, parts[1:])
    ]
155
+
156
+
157
def print_counter(title: str, c: Counter, top: int, min_count: int) -> None:
    """Print an underlined *title* followed by the most common entries of *c*.

    Entries are listed in descending count order, stopping at the first entry
    below *min_count* or once *top* entries have been shown. A placeholder line
    is printed when nothing qualifies.
    """
    print("\n" + title)
    print("-" * len(title))

    lines = []
    for key, val in c.most_common():
        if val < min_count:
            break
        lines.append(f"{val:>8} {key}")
        if len(lines) >= top:
            break

    if not lines:
        print("(no entries above min_count)")
        return
    for line in lines:
        print(line)
170
+
171
def main() -> None:
    """Command-line entry point: read tags and print pattern statistics.

    Parses the CLI arguments, streams (or samples) tags from the first column
    of the given file, accumulates shape / typed / prefix / suffix counters,
    and prints each as a ranked table. Exits with status 1 if the input file
    does not exist.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("csv_path", help="Path to CSV/TSV; first column is tag")
    ap.add_argument("--top", type=int, default=100, help="Top N patterns to print per section")
    ap.add_argument("--min-count", type=int, default=25, help="Minimum count to show")
    ap.add_argument("--sample", type=int, default=None, help="Reservoir sample N tags instead of full file")
    ap.add_argument("--seed", type=int, default=0, help="RNG seed for sampling")
    ap.add_argument("--max-rows", type=int, default=None, help="Hard stop after reading this many rows (debug)")
    args = ap.parse_args()

    path = args.csv_path
    if not os.path.exists(path):
        print(f"ERROR: file not found: {path}", file=sys.stderr)
        sys.exit(1)

    shape_counts = Counter()
    typed_counts = Counter()
    suffix1_counts = Counter()
    suffix2_counts = Counter()
    prefix1_counts = Counter()
    prefix2_counts = Counter()
    bigram_typed_counts = Counter()
    token_counts = Counter()
    length_counts = Counter()
    prefix_head_counts = defaultdict(Counter)

    read = 0
    for raw_tag in iter_tags_from_first_col(path, args.sample, args.seed):
        # Enforce --max-rows before touching any counter, so exactly max_rows
        # tags are processed and reported (previously one extra tag leaked
        # into `read` and `prefix_head_counts`).
        if args.max_rows is not None and read >= args.max_rows:
            break

        tag = normalize_tag(raw_tag)
        parts = split_parts(tag)
        if not parts:
            continue
        read += 1

        if len(parts) >= 2:
            prefix = parts[0]
            head = parts[1]
            prefix_head_counts[prefix][head] += 1

        length_counts[len(parts)] += 1
        token_counts.update(parts)

        shape_counts[underscore_shape(parts)] += 1
        typed_counts[typed_template(parts)] += 1
        bigram_typed_counts.update(bigram_templates(parts))

        sp1 = suffix_pattern(parts, k=1)
        if sp1:
            suffix1_counts[sp1] += 1
        sp2 = suffix_pattern(parts, k=2)
        if sp2:
            suffix2_counts[sp2] += 1

        pp1 = prefix_pattern(parts, k=1)
        if pp1:
            prefix1_counts[pp1] += 1
        pp2 = prefix_pattern(parts, k=2)
        if pp2:
            prefix2_counts[pp2] += 1

    print(f"Read {read} tags from {path}")
    print_counter("Tag length distribution (#parts)", length_counts, top=50, min_count=1)
    print_counter("Top underscore-shapes", shape_counts, top=args.top, min_count=args.min_count)

    # Typed templates will be sparse if lexicons are small; still useful.
    print_counter("Top typed templates (heuristic)", typed_counts, top=args.top, min_count=args.min_count)

    # Bigrams are the best way to surface "collectively important" schemas.
    print_counter("Top typed bigrams (heuristic, adjacent parts)", bigram_typed_counts, top=args.top, min_count=args.min_count)

    # Suffix patterns show color+THING and modifier+THING tendencies.
    print_counter("Top suffix patterns (last token)", suffix1_counts, top=args.top, min_count=args.min_count)
    print_counter("Top suffix patterns (last 2 tokens)", suffix2_counts, top=args.top, min_count=args.min_count)

    # Prefix patterns show looking_* and similar families.
    print_counter("Top prefix patterns (first token)", prefix1_counts, top=args.top, min_count=args.min_count)
    print_counter("Top prefix patterns (first 2 tokens)", prefix2_counts, top=args.top, min_count=args.min_count)

    # Show the most common tokens too (useful for expanding lexicons).
    print_counter("Top tokens (raw parts)", token_counts, top=200, min_count=max(args.min_count, 100))

    def print_prefix_families(prefix_head_counts, top_prefixes=50, top_heads=15, min_prefix_count=100):
        """Print, per frequent first token, its most common second tokens."""
        # NOTE: top_prefixes is accepted but not used to cap the output.
        print("\nPrefix -> Head Families")
        print("----------------------")

        # Rank prefixes by total usage.
        prefix_totals = {
            p: sum(heads.values())
            for p, heads in prefix_head_counts.items()
        }

        for prefix, total in sorted(prefix_totals.items(), key=lambda x: -x[1]):
            # Sorted descending, so everything after this is below the cutoff.
            if total < min_prefix_count:
                break

            print(f"\nPREFIX: {prefix} (total={total})")
            for head, cnt in prefix_head_counts[prefix].most_common(top_heads):
                print(f" {head:<20} {cnt}")

    print_prefix_families(prefix_head_counts)
269
+
270
+
271
+ if __name__ == "__main__":
272
+ main()
scripts/rewrite_playground.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ from pathlib import Path
5
+
6
+ import requests
7
+
8
+ CAPTION_FIELDS = ["caption_llm_4", "caption_llm_6", "caption_cogvlm"]
9
+
10
+ # Start with something minimal. You will iterate this.
11
+ REWRITE_SYSTEM = """Rewrite the input into a concise, comma-separated list of short phrases
12
+ that resemble image tags.
13
+
14
+ Use short, literal phrases that reflect how visual concepts are commonly
15
+ written in image tag vocabularies.
16
+
17
+ Multi-word phrases are appropriate when they represent one coherent
18
+ visual idea.
19
+
20
+ Examples of tag-shaped phrases:
21
+ - wolf, angry
22
+ - blue jacket, striped tail
23
+ - long hair, raised ears
24
+ - holding object, hand on shoulder
25
+ - looking at viewer, looking down
26
+ - simple background, outdoor scene
27
+ - wooden table, plant
28
+ - running, sleeping
29
+ - smiling, angry expression
30
+ - bedroom, forest
31
+ - sonic the hedgehog, princess peach
32
+
33
+ Do not invent details or guess identities.
34
+ Do not infer demographic attributes (e.g., gender/age) unless explicitly stated.
35
+
36
+ Output ONLY the rewritten list."""
37
+
38
+
39
def load_jsonl(path: Path):
    """Yield one parsed JSON object per non-blank line of *path* (UTF-8).

    Blank or whitespace-only lines (for example a trailing newline at the end
    of the file) are skipped instead of raising ``json.JSONDecodeError``.
    """
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            yield json.loads(line)
43
+
44
+
45
def openrouter_chat(model: str, system: str, user: str, temperature: float = 0.0, max_tokens: int = 200) -> str:
    """Send one system+user exchange to OpenRouter and return the reply text.

    Args:
        model: OpenRouter model identifier.
        system: System prompt.
        user: User message.
        temperature: Sampling temperature sent in the request.
        max_tokens: Completion token limit sent in the request.

    Returns:
        The stripped content of the first choice's message.

    Raises:
        RuntimeError: If ``OPENROUTER_API_KEY`` is not set.
        requests.HTTPError: If the API responds with an error status.
    """
    api_key = os.environ.get("OPENROUTER_API_KEY")
    if not api_key:
        raise RuntimeError("Set OPENROUTER_API_KEY in your environment.")

    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": user},
    ]
    response = requests.post(
        "https://openrouter.ai/api/v1/chat/completions",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        },
        json={
            "model": model,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "messages": messages,
        },
        timeout=60,
    )
    response.raise_for_status()

    first_choice = response.json()["choices"][0]
    return first_choice["message"]["content"].strip()
69
+
70
+
71
def main() -> None:
    """Interactive loop: show a caption, its LLM rewrite, and ground-truth tags.

    Loads the JSONL sample, keeps rows whose selected caption field is
    non-empty, then steps through them from ``--start``. After each example
    the user can press Enter (next), ``r`` (rerun the same input) or ``q``
    (quit).

    Raises:
        RuntimeError: If no row has a non-empty value for the chosen field.
        ValueError: If ``--start`` is outside the loaded examples.
    """
    ap = argparse.ArgumentParser(description="Interactive prompt **query rewriting** playground.")
    ap.add_argument("--sample", type=str, required=True, help="Path to the trimmed JSONL sample.")
    ap.add_argument("--field", type=str, default="caption_llm_6", choices=CAPTION_FIELDS)
    ap.add_argument("--model", type=str, default="meta-llama/llama-3.1-8b-instruct")
    ap.add_argument("--temperature", type=float, default=0.0)
    ap.add_argument("--max-tokens", type=int, default=200)
    ap.add_argument("--start", type=int, default=0, help="Index to start from within the loaded examples.")
    args = ap.parse_args()

    rows = []
    for line_no, row in enumerate(load_jsonl(Path(args.sample))):
        text = (row.get(args.field) or "").strip()
        if not text:
            continue
        # The sampler script stores the identifier under "row_id"; older
        # samples may use "id". Accept either rather than raising KeyError.
        row_id = row.get("id")
        if row_id is None:
            row_id = row.get("row_id", f"line_{line_no}")
        gt = row.get("tags_ground_truth_categorized")
        rows.append((str(row_id), text, gt))

    if not rows:
        raise RuntimeError(f"No non-empty rows found for field={args.field}")

    print(f"Loaded {len(rows)} examples from {args.sample} using {args.field}.")
    print("Commands: [Enter]=next | r=rerun current (same input) | q=quit\n")

    if args.start < 0 or args.start >= len(rows):
        raise ValueError(f"--start must be in [0, {len(rows)-1}] but got {args.start}")
    idx = args.start
    while True:
        row_id, prompt, gt = rows[idx]
        print("=" * 80)
        print(f"row_id: {row_id}")
        print(f"ORIGINAL:\n{prompt}\n")

        rewritten = openrouter_chat(
            model=args.model,
            system=REWRITE_SYSTEM,
            user=prompt,
            temperature=args.temperature,
            max_tokens=args.max_tokens,
        )
        print(f"REWRITE:\n{rewritten}\n")

        if gt:
            # Ground truth may be stored as a JSON string or already decoded.
            gt_dict = json.loads(gt) if isinstance(gt, str) else gt
            flat_gt = sorted({tag for tags in gt_dict.values() for tag in tags})
            print(f"GROUND TRUTH TAGS:\n{', '.join(flat_gt)}\n")

        cmd = input("> ").strip().lower()
        if cmd == "q":
            break
        if cmd == "r":
            # Re-run the same example without advancing.
            continue
        idx += 1
        if idx >= len(rows):
            print("End of samples.")
            break
127
+
128
+
129
+ if __name__ == "__main__":
130
+ main()
scripts/sample_dataset_streaming.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ from pathlib import Path
4
+
5
+ from datasets import load_dataset
6
+
7
+ DATASET_ID = "CaptionEmporium/furry-e621-sfw-7m-hq"
8
+ SPLIT = "train"
9
+
10
+ # Adjust these names if your actual columns differ.
11
+ CAPTION_FIELDS = ["caption_llm_6", "caption_llm_8", "caption_cogvlm"]
12
+ KEEP_FIELDS = ["tags_ground_truth_categorized"] + CAPTION_FIELDS
13
+
14
+
15
def pick_id(row: dict) -> str:
    """Return a string identifier for *row*.

    The first non-empty value among the keys ``id``, ``post_id``, ``e621_id``
    and ``image_id`` is used. If none is present, a ``no_id:<digest>`` string
    is derived from the first non-empty caption field.

    The fallback uses a SHA-256 digest rather than the built-in ``hash()``,
    because ``hash()`` of a str is salted per interpreter process and would
    give a different id on every run.
    """
    import hashlib

    # Try a few common id keys first.
    for k in ("id", "post_id", "e621_id", "image_id"):
        if k in row and row[k] not in (None, ""):
            return str(row[k])

    # Fallback: derive a reproducible id from the caption text.
    base = (row.get("caption_llm_6") or row.get("caption_llm_8") or row.get("caption_cogvlm") or "")
    digest = hashlib.sha256(str(base).encode("utf-8")).hexdigest()[:16]
    return f"no_id:{digest}"
23
+
24
+
25
def main() -> None:
    """Stream-shuffle the dataset and write a trimmed JSONL sample.

    Each output line contains ``row_id`` plus the fields listed in
    ``KEEP_FIELDS`` (missing fields become ""). Writing stops after ``--n``
    rows. With ``--require-any-caption``, rows whose caption fields are all
    empty are skipped and do not count toward ``--n``.
    """
    ap = argparse.ArgumentParser(description="Stream+shuffle sample and save a trimmed JSONL for prompt experiments.")
    ap.add_argument("--n", type=int, default=1000)
    ap.add_argument("--seed", type=int, default=123)
    ap.add_argument("--buffer-size", type=int, default=10_000)
    ap.add_argument(
        "--out",
        type=str,
        default="data/eval_samples/e621_sfw_sample_1000_seed123_buffer10000_trimmed.jsonl",
    )
    ap.add_argument(
        "--require-any-caption",
        action="store_true",
        help="If set, only keep rows where at least one of the caption fields is non-empty.",
    )
    args = ap.parse_args()

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    stream = load_dataset(DATASET_ID, split=SPLIT, streaming=True)
    shuffled = stream.shuffle(seed=args.seed, buffer_size=args.buffer_size)

    wrote = 0
    with out_path.open("w", encoding="utf-8") as sink:
        for row in shuffled:
            record = {"row_id": pick_id(row)}
            for field in KEEP_FIELDS:
                record[field] = row.get(field, "")

            has_caption = any((record.get(c) or "").strip() for c in CAPTION_FIELDS)
            if args.require_any_caption and not has_caption:
                continue

            sink.write(json.dumps(record, ensure_ascii=False) + "\n")
            wrote += 1
            if wrote >= args.n:
                break

    print(f"Wrote {wrote} rows to: {out_path}")
65
+
66
+
67
+ if __name__ == "__main__":
68
+ main()
scripts/smoke_test.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import sys
3
+ import traceback
4
+ import os
5
+
6
+ # Add repo root (parent of /scripts) to sys.path
7
+ repo_root = Path(__file__).resolve().parents[1]
8
+ sys.path.insert(0, str(repo_root))
9
+ os.chdir(repo_root)
10
+
11
+
12
def main():
    """Smoke-test the rewrite (Stage 1) and retrieval (Stage 2) entry points.

    Default run: import sanity pass, one LLM rewrite of a fixed prompt, then
    the sizes of the artist set and NSFW tag list.  Failures in the import
    pass and the two size checks are logged rather than raised.

    With ``--stage2-only`` in ``sys.argv``: the rewrite is skipped, the two
    Stage 2 contract tests run, and the process exits with status 1 if one of
    their assertions fails.
    """
    from psq_rag.llm.rewrite import llm_rewrite_prompt
    from psq_rag.retrieval.psq_retrieval import (
        psq_candidates_from_rewrite_phrases,
    )
    from psq_rag.retrieval.state import (
        get_artist_set,
        get_nsfw_tags,
    )

    # Retrieval limits shared by every Stage 2 call below.
    limits = {"global_k": 300, "per_phrase_k": 50, "per_phrase_final_k": 10}

    def log(x=""):
        print(x)

    def assert_true(condition, message):
        if condition:
            return
        raise AssertionError(message)

    def print_failure(message, exc):
        log(f"FAIL: {message}")
        if exc is None:
            return
        for line in traceback.format_exception_only(type(exc), exc):
            log(line.rstrip())

    def import_sanity():
        # Best effort: an import failure is reported but does not abort the run.
        module_names = (
            "psq_rag.retrieval.state",
            "psq_rag.retrieval.psq_retrieval",
            "psq_rag.parsing.prompt_grammar",
            "psq_rag.llm.rewrite",
            "app",
        )
        try:
            for module_name in module_names:
                __import__(module_name)
            log("import sanity: ok")
        except Exception as e:
            log(f"import sanity: {type(e).__name__}: {e}")

    import_sanity()

    stage2_only = "--stage2-only" in sys.argv

    if not stage2_only:
        prompt = "ape, raised arms, looking at viewer"
        rewrite = llm_rewrite_prompt(prompt, log)
        if rewrite:
            print("rewrite:", rewrite)
        else:
            log("LLM rewrite: no result (continuing)")

    def run_stage2_test_a():
        phrases = ["big shirt", "grey shirt"]
        cands, per_phrase = psq_candidates_from_rewrite_phrases(
            rewrite_phrases=phrases,
            allow_nsfw_tags=True,
            verbose=True,
            **limits,
        )
        print("cands:", len(cands))

        assert_true(isinstance(per_phrase, list), "per_phrase must be a list")
        phrase_set = {report.get("phrase") for report in per_phrase}
        for phrase in phrases:
            assert_true(phrase in phrase_set, f"per_phrase missing entry for '{phrase}'")
        assert_true("shirt" in phrase_set, "per_phrase missing head-noun expansion for 'shirt'")

        required_report_keys = {"phrase", "normalized", "lookup", "tfidf_vocab", "oov_terms", "candidates"}
        required_row_keys = {
            "tag",
            "alias_token",
            "score_fasttext",
            "score_context",
            "score_combined",
            "context_imputed",
            "count",
        }
        for report in per_phrase:
            assert_true(required_report_keys.issubset(report.keys()), "per_phrase missing required keys")
            rows = report.get("candidates", [])
            assert_true(isinstance(rows, list), "per_phrase candidates must be a list")
            for row in rows:
                assert_true(required_row_keys.issubset(row.keys()), "candidate row missing required keys")

        big_report = next(
            (report for report in per_phrase if report.get("phrase") == "big shirt"),
            None,
        )
        assert_true(big_report is not None, "no per_phrase report found for 'big shirt'")
        big_tags = {row.get("tag") for row in big_report.get("candidates", [])}
        assert_true("big_shirt" in big_tags, "big_shirt missing from per_phrase_final_k for 'big shirt'")

        log("stage2-only test A: PASS")

    def run_stage2_test_b():
        phrases = ["anuss"]

        def retrieve(allow_nsfw):
            result = psq_candidates_from_rewrite_phrases(
                rewrite_phrases=phrases,
                allow_nsfw_tags=allow_nsfw,
                verbose=False,
                **limits,
            )
            # Accept either a bare candidate list or a (candidates, ...) tuple.
            return result[0] if isinstance(result, tuple) else result

        def extract_tag(row):
            if hasattr(row, "get"):
                return row.get("tag")
            return getattr(row, "tag", None)

        cands_unfiltered = retrieve(True)
        cands_filtered = retrieve(False)

        unfiltered_tags = {extract_tag(row) for row in cands_unfiltered}
        filtered_tags = {extract_tag(row) for row in cands_filtered}
        in_unfiltered = "anus" in unfiltered_tags
        in_filtered = "anus" in filtered_tags
        assert_true(in_unfiltered, "anus missing from unfiltered candidates")
        assert_true(not in_filtered, "anus unexpectedly present in filtered candidates")
        log(f"stage2-only test B: PASS (anus in unfiltered={in_unfiltered}, in filtered={in_filtered})")

    if stage2_only:
        try:
            run_stage2_test_a()
            run_stage2_test_b()
        except AssertionError as exc:
            print_failure("stage2 contract assertion failed", exc)
            sys.exit(1)
        return

    # Artist set check (optional in RAG mode), then the NSFW tag list.
    for getter, ok_label, err_label in (
        (get_artist_set, "artist set size", "artist set"),
        (get_nsfw_tags, "nsfw tag count", "nsfw tags"),
    ):
        try:
            loaded = getter()
            log(f"{ok_label}: {len(loaded)}")
        except Exception as e:
            log(f"{err_label}: {type(e).__name__}: {e}")
157
+
158
+ if __name__ == "__main__":
159
+ main()
scripts/stage3_debug.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Stage 3 debug harness.
2
+
3
+ Goal: Run a realistic Stage 3 selection loop without manually typing hundreds of candidates.
4
+
5
+ Typical usage (bypass Stage 1):
6
+ python scripts/stage3_debug.py --prompt "..." --phrases "a, b, c" \
7
+ --no-allow-nsfw --mode chunked_map_union --chunk-size 60 --per-phrase-k 2
8
+
9
+ If --phrases is omitted, Stage 1 (the LLM rewrite) is run on --prompt and its
10
+ comma-separated output is used as the Stage 2 phrases.
11
+
12
+ Outputs:
13
+ - Stage 2 candidate stats
14
+ - Stage 3 per-call config + validation diagnostics (via the selector's log hook)
15
+ - Final selected tags
16
+
17
+ NOTE: This script expects your project package imports to work (run from repo root).
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import argparse
23
+ import sys
24
+ from pathlib import Path
25
+ from typing import Any, List, Sequence, cast
26
+
27
+ # Ensure repo root is on sys.path when running as a script:
28
+ # python scripts/stage3_debug.py ...
29
+ # This makes `import psq_rag...` work without requiring editable installs.
30
+ _REPO_ROOT = Path(__file__).resolve().parents[1]
31
+ if str(_REPO_ROOT) not in sys.path:
32
+ sys.path.insert(0, str(_REPO_ROOT))
33
+
34
+
35
+ def _split_csv_phrases(s: str) -> List[str]:
36
+ # Very small helper: split on commas, trim, drop empties.
37
+ return [p.strip() for p in s.split(",") if p.strip()]
38
+
39
+
40
def _import_stage2_entrypoint():
    """Locate the Stage 2 entry points in ``psq_rag.retrieval.psq_retrieval``.

    Returns a pair ``(from_prompt, from_phrases)``:

    - ``psq_candidates_from_prompt`` / ``psq_candidates_from_rewrite_phrases``
      when at least one of them is callable (the other may be ``None``);
    - otherwise the legacy ``psq_candidates`` in both positions, if callable.

    Raises:
        RuntimeError: when none of the three names is callable; the message
            lists the module's public callables to help pick the right one.
    """
    import psq_rag.retrieval.psq_retrieval as retrieval

    from_prompt = getattr(retrieval, "psq_candidates_from_prompt", None)
    from_phrases = getattr(retrieval, "psq_candidates_from_rewrite_phrases", None)
    if callable(from_prompt) or callable(from_phrases):
        return from_prompt, from_phrases

    # Older naming (very old code paths): one function serves both paths.
    legacy = getattr(retrieval, "psq_candidates", None)
    if callable(legacy):
        return legacy, legacy

    # Nothing usable: fail loudly with guidance.
    public = []
    for name, obj in vars(retrieval).items():
        if name.startswith("_"):
            continue
        if callable(obj):
            public.append(name)
    listing = "\n - ".join(sorted(public))
    raise RuntimeError(
        "Expected Stage 2 function psq_candidates(...) in psq_rag.retrieval.psq_retrieval, "
        "but it was not found.\n"
        "Public callables in that module:\n - "
        + listing
        + "\n\nUpdate _import_stage2_entrypoint() in scripts/stage3_debug.py to use the correct function."
    )
81
+
82
+
83
def _import_stage3_selector():
    """Import the Stage 3 selector lazily.

    Returns the pair ``(llm_select_indices, _split_candidates_by_type)`` from
    ``psq_rag.llm.select``.
    """
    from psq_rag.llm.select import (
        _split_candidates_by_type as split_by_type,
        llm_select_indices as select_indices,
    )

    return select_indices, split_by_type
87
+
88
+
89
def _import_stage1_rewrite():
    """Import and return ``llm_rewrite_prompt`` from ``psq_rag.llm.rewrite`` lazily."""
    from psq_rag.llm.rewrite import llm_rewrite_prompt as rewrite_prompt

    return rewrite_prompt
93
+
94
+
95
+ def _as_list_candidates(stage2_result: Any):
96
+ """Normalize various plausible Stage 2 return shapes into (candidates, aux).
97
+
98
+ Common patterns we've used in this repo:
99
+ - candidates
100
+ - (candidates, verbose_rows)
101
+ - (candidates, anything_else, ...)
102
+ - {"candidates": [...], ...}
103
+ """
104
+
105
+ if isinstance(stage2_result, dict) and "candidates" in stage2_result:
106
+ return stage2_result["candidates"], stage2_result
107
+
108
+ if isinstance(stage2_result, tuple) and len(stage2_result) >= 1:
109
+ return stage2_result[0], stage2_result
110
+
111
+ return stage2_result, None
112
+
113
+
114
+ def _safe_tag_display(tag: str) -> str:
115
+ return tag.replace("_", " ")
116
+
117
+
118
def _print_top_candidates(cands: Sequence[Any], n: int) -> None:
    """Print the *n* highest-ranked candidates, one per line.

    Ranking is descending by ``score_combined`` and then ``count``; a missing
    score counts as ``0.0`` and a missing or ``None`` count as ``-1``.  Objects
    without a ``tag`` attribute are shown via ``str()``.  This is a quick
    glance only, so ``sources`` are not printed.
    """
    shown = min(n, len(cands))
    print(f"\nTop {shown} candidates (by score_combined, then count):")

    def rank_key(cand):
        count = getattr(cand, "count", None)
        return (getattr(cand, "score_combined", 0.0), -1 if count is None else count)

    ranked = sorted(cands, key=rank_key, reverse=True)
    for position, cand in enumerate(ranked[:n], start=1):
        tag = getattr(cand, "tag", str(cand))
        score = getattr(cand, "score_combined", None)
        count = getattr(cand, "count", None)
        score_text = f"{score:.4f}" if isinstance(score, (float, int)) else "?"
        count_text = "?" if count is None else str(count)
        print(f" {position:>2}. {_safe_tag_display(tag)} score={score_text} count={count_text}")
135
+
136
+
137
+ def _describe_candidate_sample(cands: Sequence[Any], n: int = 5) -> None:
138
+ print(f"\nCandidate contract sample (first {min(n, len(cands))}):")
139
+ for i, c in enumerate(cands[:n], start=1):
140
+ if hasattr(c, "tag"):
141
+ tag = getattr(c, "tag", None)
142
+ sc = getattr(c, "score_combined", None)
143
+ sf = getattr(c, "score_fasttext", None)
144
+ sx = getattr(c, "score_context", None)
145
+ ct = getattr(c, "count", None)
146
+ src = getattr(c, "sources", None)
147
+ print(
148
+ " "
149
+ f"{i}. type={type(c).__name__} "
150
+ f"tag={tag!r} "
151
+ f"score_combined={sc!r}({type(sc).__name__}) "
152
+ f"score_fasttext={sf!r}({type(sf).__name__}) "
153
+ f"score_context={sx!r}({type(sx).__name__}) "
154
+ f"count={ct!r}({type(ct).__name__}) "
155
+ f"sources={src!r}"
156
+ )
157
+ elif isinstance(c, (list, tuple)):
158
+ parts = list(c)[:3]
159
+ parts_t = [type(p).__name__ for p in parts]
160
+ print(f" {i}. type={type(c).__name__} head={parts!r} types={parts_t}")
161
+ else:
162
+ print(f" {i}. type={type(c).__name__} value={c!r}")
163
+
164
+
165
def main(argv: Sequence[str] | None = None) -> int:
    """Run Stage 2 retrieval and Stage 3 selection for one prompt, printing diagnostics.

    Flow:
      1. Parse arguments.  ``--phrases`` (comma-separated) bypasses Stage 1;
         without it the Stage 1 LLM rewrite runs on ``--prompt``.
      2. Call the Stage 2 phrase entry point and normalise its return value
         into a candidate list (optionally truncated by ``--max-cands``).
      3. Print candidate diagnostics (sample, source stats, top-N, type split).
      4. Call the Stage 3 selector and print the de-duplicated selected tags.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``0`` on success (including "no selections"), ``2`` when ``--prompt``
        is blank.

    Raises:
        RuntimeError: when the Stage 2 phrase entry point is unavailable.
    """
    ap = argparse.ArgumentParser(description="Stage 3 debug harness (Stage2 -> Stage3)")

    ap.add_argument(
        "--prompt",
        required=True,
        help="Image description (original prompt).",
    )
    ap.add_argument(
        "--phrases",
        default="",
        help=(
            "Comma-separated Stage 1 rewrite phrases (bypass Stage 1). "
            "If omitted, Stage 1 will run on --prompt via psq_candidates_from_prompt."
        ),
    )

    ap.add_argument(
        "--allow-nsfw",
        dest="allow_nsfw",
        action="store_true",
        help="Allow NSFW tags to appear in Stage 2 candidates.",
    )
    ap.add_argument(
        "--no-allow-nsfw",
        dest="allow_nsfw",
        action="store_false",
        help="Disallow NSFW tags in Stage 2 candidates.",
    )
    ap.set_defaults(allow_nsfw=True)

    ap.add_argument(
        "--max-cands",
        type=int,
        default=0,
        help="Optional: truncate Stage 2 candidate list to this many candidates (0 = no truncation).",
    )

    # Stage 3 knobs
    ap.add_argument("--mode", choices=["single_shot", "chunked_map_union"], default="chunked_map_union")
    ap.add_argument("--chunk-size", type=int, default=60)
    ap.add_argument("--per-phrase-k", type=int, default=2)
    ap.add_argument("--temperature", type=float, default=0.1)
    ap.add_argument("--max-tokens", type=int, default=512)

    ap.add_argument(
        "--show-top",
        type=int,
        default=25,
        help="Print the top-N Stage 2 candidates for a quick glance.",
    )

    args = ap.parse_args(list(argv) if argv is not None else None)

    prompt = args.prompt.strip()
    if not prompt:
        print("--prompt must be non-empty", file=sys.stderr)
        return 2

    # An empty or whitespace-only --phrases value yields [] here.
    phrases = _split_csv_phrases(args.phrases)

    print("Stage3 Debug")
    print("-----------")
    print(f"Prompt: {prompt}")
    print(f"Phrases ({len(phrases)}): {', '.join(phrases)}")
    print(f"allow_nsfw_tags: {args.allow_nsfw}")

    # Stage 2: only the phrase-based entry point is used by this harness.
    _, stage2_from_phrases = _import_stage2_entrypoint()

    print("Running Stage 2 (retrieval grounding / candidate generation)...")

    if phrases:
        print("Stage 2 path: rewrite_phrases (Stage 1 bypassed)")
    else:
        print("Stage 1 path: rewrite (LLM)")
        # Imported only on this path, so bypassing Stage 1 does not need it.
        stage1_rewrite = _import_stage1_rewrite()
        rewritten = stage1_rewrite(prompt, log=print)
        # The rewrite can come back empty (the smoke test treats a falsy
        # result as normal), so never call string methods on it unguarded.
        # Presumably a comma-separated string when present - confirm against
        # llm_rewrite_prompt.
        phrases = _split_csv_phrases(rewritten or "")
        if not phrases:
            print("Stage 1 returned no phrases; falling back to comma-splitting --prompt")
            phrases = _split_csv_phrases(prompt)
        print(f"Rewrite phrases ({len(phrases)}): {', '.join(phrases)}")
        print("Stage 2 path: rewrite_phrases (from Stage 1 output)")

    if stage2_from_phrases is None:
        raise RuntimeError("psq_candidates_from_rewrite_phrases is not available in psq_rag.retrieval.psq_retrieval")
    stage2_out = stage2_from_phrases(rewrite_phrases=phrases, allow_nsfw_tags=args.allow_nsfw)

    candidates, _ = _as_list_candidates(stage2_out)

    if not isinstance(candidates, list):
        candidates = list(candidates)

    if isinstance(stage2_out, tuple):
        print(f"Stage 2 return type: tuple len={len(stage2_out)}")
    elif isinstance(stage2_out, list):
        print(f"Stage 2 return type: list len={len(stage2_out)}")
    elif isinstance(stage2_out, dict):
        print(f"Stage 2 return type: dict keys={sorted(stage2_out.keys())}")
    else:
        print(f"Stage 2 return type: {type(stage2_out).__name__}")

    print(f"Stage 2 candidates type: {type(candidates).__name__} len={len(candidates)}")

    print(f"Stage 2 returned {len(candidates)} candidates")

    if args.max_cands and args.max_cands > 0:
        candidates = candidates[: args.max_cands]
        print(f"Truncated to {len(candidates)} candidates due to --max-cands")

    _describe_candidate_sample(candidates, n=5)

    num_candidates_with_sources = sum(
        1
        for c in candidates
        if hasattr(c, "sources") and bool(getattr(c, "sources", []))
    )
    distinct_sources = len(
        {
            src
            for c in candidates
            if hasattr(c, "sources")
            for src in getattr(c, "sources", [])
        }
    )
    print(
        "Stage 2 sources: "
        f"with_sources={num_candidates_with_sources} "
        f"distinct_sources={distinct_sources}"
    )

    if args.show_top and args.show_top > 0:
        _print_top_candidates(candidates, args.show_top)

    # Stage 3
    llm_select_indices, split_candidates_by_type = _import_stage3_selector()

    # Show candidate bucket assignments (general vs entity) if candidates are Candidate objects
    if candidates and hasattr(candidates[0], 'tag'):
        from psq_rag.retrieval.psq_retrieval import Candidate
        print("\nCandidate type split (general vs entity):")
        general_with_idx, entity_with_idx = split_candidates_by_type(cast(List[Candidate], candidates), log=None)
        print(f" General candidates (attributes, species, meta, artists): {len(general_with_idx)}")
        print(f" Entity candidates (characters only, copyrights filtered): {len(entity_with_idx)}")

        if entity_with_idx:
            print(f"\n Character candidates preview (first {min(10, len(entity_with_idx))}):")
            for _, cand in entity_with_idx[:10]:
                print(f" - {_safe_tag_display(cand.tag)}")
    else:
        print("\nSkipping candidate type split (candidates not in Candidate format)")

    def log(msg: str) -> None:
        print(msg)

    print("\nRunning Stage 3 (closed-set selection)...")

    # NOTE: llm_select_indices returns indices into the ORIGINAL candidates list you pass in.
    picked = llm_select_indices(
        query_text=prompt,  # treated as image description in Stage 3
        candidates=candidates,
        max_pick=0,
        log=log,
        mode=args.mode,
        chunk_size=args.chunk_size,
        per_phrase_k=args.per_phrase_k,
        temperature=args.temperature,
        max_tokens=args.max_tokens,
    )

    print("\nStage 3 selected:")
    if not picked:
        print(" (no selections)")
        return 0

    # Deduplicate while preserving order
    seen = set()
    tags = []
    for idx in picked:
        if idx in seen:
            continue
        seen.add(idx)
        c = candidates[idx]
        tags.append(getattr(c, "tag", str(c)))

    for t in tags:
        print(f" - {t} ({_safe_tag_display(t)})")

    print(f"\nTotal selected tags: {len(tags)}")
    return 0
356
+
357
+
358
+ if __name__ == "__main__":
359
+ raise SystemExit(main())
scripts/test_alias_filter.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Test harness for Stage 3 alias-based character tag filtering.
2
+
3
+ Tests _character_matches_via_aliases() and related helper functions to ensure:
4
+ - Character tags only match when the user mentions the character name (or alias)
5
+ - Variant tags (e.g. pikachu_libre) do NOT match when only the base name is used
6
+ - Aliases with series suffixes (e.g. tails_(sonic)) correctly match after normalization
7
+ - Fuzzy matching handles common typos
8
+ - Generic descriptions (e.g. "orange cat") do NOT match character tags
9
+
10
+ Usage:
11
+ python scripts/test_alias_filter.py
12
+
13
+ Requires: rapidfuzz (no CSV data files needed - uses mock alias data)
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import sys
19
+ from pathlib import Path
20
+
21
+ # Ensure repo root is on sys.path
22
+ _REPO_ROOT = Path(__file__).resolve().parents[1]
23
+ if str(_REPO_ROOT) not in sys.path:
24
+ sys.path.insert(0, str(_REPO_ROOT))
25
+
26
+ from psq_rag.llm.select import (
27
+ _normalize_for_matching,
28
+ _query_words,
29
+ _alias_matches_query,
30
+ _character_matches_via_aliases,
31
+ )
32
+
33
+ # ---------------------------------------------------------------------------
34
+ # Mock alias data matching real e621 patterns
35
+ # ---------------------------------------------------------------------------
36
+ MOCK_TAG2ALIASES = {
37
+ # Garfield: "garfield" is an alias for garfield_the_cat
38
+ "garfield_the_cat": ["garfield", "garfield_(character)", "garfield_cat"],
39
+ # Tails / Miles Prower: aliases include tails_(sonic)
40
+ "miles_prower": ["tails_(sonic)", "tails_the_fox", "tailsko", "miles_tails_prower"],
41
+ # Pikachu base (species tag type 5, but testing if it were type 4)
42
+ "pikachu": ["pikachu_(pokemon)"],
43
+ # Pikachu variants - distinct aliases that should NOT match base "pikachu"
44
+ "pikachu_libre": ["pikachu_libre_(pokemon)", "libre_pikachu"],
45
+ "detective_pikachu": ["detective_pikachu_(pokemon)", "detective_pikachu_(movie)"],
46
+ "cosplay_pikachu_(character)": ["cosplay_pikachu"],
47
+ # Sonic
48
+ "sonic_the_hedgehog": ["sonic", "sonic_(character)", "sonic_(sth)"],
49
+ # Character with no aliases
50
+ "cat_busters": [],
51
+ # Mickey Mouse
52
+ "mickey_mouse": ["mickey", "mickey_(disney)"],
53
+ # A character whose name is a common word
54
+ "shadow_the_hedgehog": ["shadow_(sonic)", "shadow"],
55
+ }
56
+
57
+
58
def log(msg: str) -> None:
    """Print *msg* to stdout behind a leading space."""
    print(" {}".format(msg))
60
+
61
+
62
def run_tests() -> int:
    """Run every alias-filter check and print a PASS/FAIL line for each.

    The checks are table-driven: first the ``_normalize_for_matching`` cases,
    then sections of ``_character_matches_via_aliases`` cases grouped by the
    query they share.  All matching cases run against ``MOCK_TAG2ALIASES``.

    Returns:
        ``0`` when every check passed, ``1`` otherwise (usable as exit code).
    """
    tally = {"passed": 0, "failed": 0}

    def check(description: str, result: bool, expected: bool) -> None:
        if result == expected:
            tally["passed"] += 1
            print(f" PASS: {description}")
        else:
            tally["failed"] += 1
            print(f" FAIL: {description} (got={result}, expected={expected})")

    # (description, raw tag, expected normalized form)
    normalize_cases = [
        ("strips series suffix _(sonic)", "tails_(sonic)", "tails"),
        ("strips _(character) suffix", "garfield_(character)", "garfield"),
        ("replaces underscores with spaces", "garfield_the_cat", "garfield the cat"),
        ("lowercases", "Pikachu_Libre", "pikachu libre"),
        ("no suffix stays intact", "pikachu_libre", "pikachu libre"),
    ]

    # (section title, [(query, [(description, tag, expected match)])])
    match_sections = [
        (
            "Core matching: tails vs miles_prower",
            [
                (
                    "tails flying",
                    [
                        ("'tails flying' matches miles_prower (via alias tails_(sonic))", "miles_prower", True),
                    ],
                ),
            ],
        ),
        (
            "Core matching: pikachu vs variants",
            [
                (
                    "pikachu with red cheeks",
                    [
                        ("'pikachu with red cheeks' matches pikachu (base tag)", "pikachu", True),
                        ("'pikachu with red cheeks' does NOT match pikachu_libre", "pikachu_libre", False),
                        ("'pikachu with red cheeks' does NOT match detective_pikachu", "detective_pikachu", False),
                        (
                            "'pikachu with red cheeks' does NOT match cosplay_pikachu_(character)",
                            "cosplay_pikachu_(character)",
                            False,
                        ),
                    ],
                ),
            ],
        ),
        (
            "Variant explicitly mentioned",
            [
                (
                    "pikachu libre wrestling",
                    [
                        ("'pikachu libre wrestling' matches pikachu_libre", "pikachu_libre", True),
                        ("'pikachu libre wrestling' also matches base pikachu (substring)", "pikachu", True),
                    ],
                ),
                (
                    "detective pikachu in the rain",
                    [
                        ("'detective pikachu in the rain' matches detective_pikachu", "detective_pikachu", True),
                    ],
                ),
            ],
        ),
        (
            "Garfield via alias",
            [
                (
                    "garfield sleeping on a table",
                    [
                        ("'garfield sleeping' matches garfield_the_cat (via alias 'garfield')", "garfield_the_cat", True),
                    ],
                ),
            ],
        ),
        (
            "Generic description should NOT match characters",
            [
                (
                    "orange cat sitting outside",
                    [
                        ("'orange cat sitting outside' does NOT match garfield_the_cat", "garfield_the_cat", False),
                        ("'orange cat sitting outside' does NOT match cat_busters", "cat_busters", False),
                    ],
                ),
                (
                    "mouse character running",
                    [
                        ("'mouse character running' does NOT match mickey_mouse", "mickey_mouse", False),
                    ],
                ),
            ],
        ),
        (
            "Sonic via alias",
            [
                (
                    "sonic running fast",
                    [
                        ("'sonic running fast' matches sonic_the_hedgehog (via alias 'sonic')", "sonic_the_hedgehog", True),
                    ],
                ),
            ],
        ),
        (
            "Fuzzy matching: typos",
            [
                (
                    "garfeild sleeping",
                    [
                        ("'garfeild' (typo) matches garfield_the_cat via fuzzy", "garfield_the_cat", True),
                    ],
                ),
                (
                    "pikachuu battling",
                    [
                        ("'pikachuu' (typo) matches pikachu via fuzzy", "pikachu", True),
                    ],
                ),
            ],
        ),
        (
            "Shadow: common word that is also a character alias",
            [
                (
                    "shadow the hedgehog posing",
                    [
                        ("'shadow the hedgehog posing' matches shadow_the_hedgehog", "shadow_the_hedgehog", True),
                    ],
                ),
                # "shadow" alone is an alias - this WILL match because the user said "shadow"
                (
                    "shadow lurking in darkness",
                    [
                        (
                            "'shadow lurking in darkness' matches shadow_the_hedgehog (alias 'shadow')",
                            "shadow_the_hedgehog",
                            True,
                        ),
                    ],
                ),
            ],
        ),
        (
            "Tag with no aliases and no name match",
            [
                (
                    "a dog playing fetch",
                    [
                        (
                            "'a dog playing fetch' does NOT match cat_busters (no aliases, no name match)",
                            "cat_busters",
                            False,
                        ),
                    ],
                ),
            ],
        ),
    ]

    print("\n=== _normalize_for_matching ===")
    for description, raw, normalized in normalize_cases:
        check(description, _normalize_for_matching(raw) == normalized, True)

    for title, queries in match_sections:
        print(f"\n=== {title} ===")
        for query, cases in queries:
            # Query-derived inputs are computed once per query, as the matcher expects.
            qwords = _query_words(query)
            qnorm = _normalize_for_matching(query)
            for description, tag, expected in cases:
                check(
                    description,
                    _character_matches_via_aliases(tag, query, MOCK_TAG2ALIASES, qwords, qnorm),
                    expected,
                )

    # Summary
    passed = tally["passed"]
    failed = tally["failed"]
    total = passed + failed
    rule = "=" * 50
    print(f"\n{rule}")
    print(f"Results: {passed}/{total} passed, {failed}/{total} failed")
    print("ALL TESTS PASSED" if failed == 0 else "SOME TESTS FAILED")
    print(f"{rule}")

    return 1 if failed > 0 else 0
301
+
302
+
303
+ if __name__ == "__main__":
304
+ sys.exit(run_tests())
transparentsquirrel.png ADDED

Git LFS Details

  • SHA256: 090a20f6afc0879333afb01ee491df994ed549c543aac861d76ab1fa05978a90
  • Pointer size: 131 Bytes
  • Size of remote file: 257 kB
wiki_pages-2023-08-08.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d453c0cc8ae09c548e554ceb77b1c1578c277eb2c5a6278a85f89c73566a7b27
3
+ size 30986436
word_rating_probabilities.csv ADDED
The diff for this file is too large to render. See raw diff