animetimm/danbooru-wdtagger-v4-w640-ws-full
Updated • 281 • 7
How to use Mooshie/caformer_b36.dbv4-full with timm:
import timm
model = timm.create_model("hf_hub:Mooshie/caformer_b36.dbv4-full", pretrained=True)

How to use Mooshie/caformer_b36.dbv4-full with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline
pipe = pipeline("image-classification", model="Mooshie/caformer_b36.dbv4-full")
pipe("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/hub/parrots.png")

# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Mooshie/caformer_b36.dbv4-full", dtype="auto")

| # | Macro@0.40 (F1/MCC/P/R) | Micro@0.40 (F1/MCC/P/R) | Macro@Best (F1/P/R) |
|---|---|---|---|
| Validation | 0.545 / 0.551 / 0.601 / 0.519 | 0.688 / 0.687 / 0.689 / 0.687 | --- |
| Test | 0.546 / 0.552 / 0.602 / 0.519 | 0.689 / 0.688 / 0.690 / 0.688 | 0.581 / 0.585 / 0.599 |
Macro/Micro@0.40 means the metrics at a threshold of 0.40. Macro@Best means the mean metrics obtained with the tag-level threshold of each tag, which should give the best F1 scores.

| Category | Name | Alpha | Threshold | Micro@Thr (F1/P/R) | Macro@0.40 (F1/P/R) | Macro@Best (F1/P/R) |
|---|---|---|---|---|---|---|
| 0 | general | 1 | 0.39 | 0.676 / 0.672 / 0.680 | 0.418 / 0.485 / 0.388 | 0.459 / 0.457 / 0.490 |
| 4 | character | 1 | 0.47 | 0.933 / 0.953 / 0.915 | 0.911 / 0.934 / 0.893 | 0.926 / 0.948 / 0.908 |
| 9 | rating | 1 | 0.39 | 0.830 / 0.792 / 0.871 | 0.837 / 0.811 / 0.864 | 0.837 / 0.813 / 0.864 |
Micro@Thr means the metrics at the category-level suggested thresholds, which are listed in the table above. Macro@0.40 means the metrics at a threshold of 0.40. Macro@Best means the metrics obtained with the tag-level threshold of each tag, which should give the best F1 scores. The tag-level thresholds can be found in selected_tags.csv.
We provide a sample image for our code samples; you can find it here.
Install dghs-imgutils, timm, and the other necessary requirements with the following command:
pip install 'dghs-imgutils>=0.17.0' torch huggingface_hub timm pillow pandas
After that, you can load this model with the timm library and use it for training, validation, and testing with the following code:
import json
import pandas as pd
import torch
from huggingface_hub import hf_hub_download
from imgutils.data import load_image
from imgutils.preprocess import create_torchvision_transforms
from timm import create_model
repo_id = 'animetimm/caformer_b36.dbv4-full'
model = create_model(f'hf-hub:{repo_id}', pretrained=True)
model.eval()
with open(hf_hub_download(repo_id=repo_id, repo_type='model', filename='preprocess.json'), 'r') as f:
    preprocessor = create_torchvision_transforms(json.load(f)['test'])
# Compose(
# PadToSize(size=(512, 512), interpolation=bilinear, background_color=white)
# Resize(size=384, interpolation=bicubic, max_size=None, antialias=True)
# CenterCrop(size=[384, 384])
# MaybeToTensor()
# Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
# )
image = load_image('https://huggingface.co/animetimm/caformer_b36.dbv4-full/resolve/main/sample.webp')
input_ = preprocessor(image).unsqueeze(0)
# input_, shape: torch.Size([1, 3, 384, 384]), dtype: torch.float32
with torch.no_grad():
    output = model(input_)
    prediction = torch.sigmoid(output)[0]
# output, shape: torch.Size([1, 12476]), dtype: torch.float32
# prediction, shape: torch.Size([12476]), dtype: torch.float32
df_tags = pd.read_csv(
hf_hub_download(repo_id=repo_id, repo_type='model', filename='selected_tags.csv'),
keep_default_na=False
)
tags = df_tags['name']
mask = prediction.numpy() >= df_tags['best_threshold']
print(dict(zip(tags[mask].tolist(), prediction[mask].tolist())))
# {'sensitive': 0.6932118535041809,
# '1girl': 0.9990721940994263,
# 'solo': 0.9785084128379822,
# 'looking_at_viewer': 0.7411327958106995,
# 'blush': 0.8228459358215332,
# 'smile': 0.9370849132537842,
# 'short_hair': 0.8239911198616028,
# 'long_sleeves': 0.5299726724624634,
# 'brown_hair': 0.6389132738113403,
# 'holding': 0.6104577779769897,
# 'dress': 0.6728140115737915,
# 'closed_mouth': 0.33915525674819946,
# 'sitting': 0.7986266016960144,
# 'purple_eyes': 0.7082042694091797,
# 'flower': 0.8504390120506287,
# 'braid': 0.812047004699707,
# 'blunt_bangs': 0.27516067028045654,
# 'tears': 0.8593371510505676,
# 'floral_print': 0.28373879194259644,
# 'crying': 0.31545740365982056,
# 'plant': 0.7968168258666992,
# 'blue_flower': 0.47092026472091675,
# 'tareme': 0.1419680416584015,
# 'crying_with_eyes_open': 0.2293853610754013,
# 'crown_braid': 0.6291146874427795,
# 'potted_plant': 0.7235668897628784,
# 'flower_pot': 0.7853846549987793,
# 'happy_tears': 0.18256734311580658,
# 'pavement': 0.3113744258880615,
# 'wiping_tears': 0.7474808096885681,
# 'morning_glory': 0.865814208984375}
Install dghs-imgutils with the following command:
pip install 'dghs-imgutils>=0.17.0'
Use the multilabel_timm_predict function with the following code:
from imgutils.generic import multilabel_timm_predict
general, character, rating = multilabel_timm_predict(
'https://huggingface.co/animetimm/caformer_b36.dbv4-full/resolve/main/sample.webp',
repo_id='animetimm/caformer_b36.dbv4-full',
fmt=('general', 'character', 'rating'),
)
print(general)
# {'1girl': 0.9990721940994263,
# 'solo': 0.9785083532333374,
# 'smile': 0.9370849132537842,
# 'morning_glory': 0.8658077716827393,
# 'tears': 0.8593354225158691,
# 'flower': 0.8504382371902466,
# 'short_hair': 0.8239905834197998,
# 'blush': 0.8228461742401123,
# 'braid': 0.8120447397232056,
# 'sitting': 0.798625111579895,
# 'plant': 0.7968136072158813,
# 'flower_pot': 0.7853772640228271,
# 'wiping_tears': 0.7474707365036011,
# 'looking_at_viewer': 0.7411322593688965,
# 'potted_plant': 0.7235615253448486,
# 'purple_eyes': 0.7082012295722961,
# 'dress': 0.6728127598762512,
# 'brown_hair': 0.6389119029045105,
# 'crown_braid': 0.6291083693504333,
# 'holding': 0.6104577779769897,
# 'long_sleeves': 0.5299732089042664,
# 'blue_flower': 0.470914363861084,
# 'closed_mouth': 0.33915433287620544,
# 'crying': 0.31545311212539673,
# 'pavement': 0.3113621473312378,
# 'floral_print': 0.2837352156639099,
# 'blunt_bangs': 0.2751583755016327,
# 'crying_with_eyes_open': 0.22938179969787598,
# 'happy_tears': 0.1825598180294037,
# 'tareme': 0.1419658362865448}
print(character)
# {}
print(rating)
# {'sensitive': 0.6932107210159302}
For further information, see documentation of function multilabel_timm_predict.
Base model
timm/caformer_b36.sail_in22k_ft_in1k_384