File size: 3,452 Bytes
5738620
3471015
a786b7a
 
3471015
 
5738620
 
 
 
 
 
 
3471015
5738620
 
 
3471015
5738620
3471015
5738620
 
 
 
3471015
5738620
 
 
 
 
 
 
 
3471015
5738620
 
 
 
3471015
5738620
 
 
 
 
3471015
5738620
 
 
 
 
 
3471015
5738620
 
 
 
 
 
3471015
5738620
 
 
 
3471015
5738620
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# 📦 RADIOCAP13 — HuggingFace Space

# Below is the application file for deploying the image-captioning model as a HuggingFace Space.
# Copy it into your Space repository alongside the model weights and vocab.json.


## **app.py**
import gradio as gr
import torch
from transformers import ViTModel
from PIL import Image
from torchvision import transforms
import json

# Side length (pixels) every input image is resized to before encoding.
# Presumably chosen to match the 224 in the ViT checkpoint name used by this
# app — TODO confirm.
IMG_SIZE = 224
# Length the decoder input is padded to; also bounds the number of
# generation steps (SEQ_LEN - 1) and the decoder's positional table.
SEQ_LEN = 32
# Default vocabulary size of the decoder's embedding and output layers.
# NOTE(review): must agree with the saved weights and vocab.json — verify.
VOCAB_SIZE = 75460

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Image preprocessing: resize to IMG_SIZE x IMG_SIZE, then convert to a tensor.
# NOTE(review): no Normalize step is applied here; confirm this matches the
# preprocessing the decoder was trained with.
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
])

def preprocess_image(img):
    """Turn an uploaded image into the model's input representation.

    Args:
        img: A ``PIL.Image.Image``, or any object accepted by
            ``Image.fromarray`` (presumably a NumPy array — TODO confirm
            against the caller).

    Returns:
        The result of applying the module-level ``transform`` pipeline to
        the RGB version of the image.

    Raises:
        ValueError: If ``img`` is ``None``.
    """
    if img is None:
        raise ValueError("Image is None")

    # Anything that is not already a PIL image is wrapped into one first.
    pil_img = img if isinstance(img, Image.Image) else Image.fromarray(img)
    # Grayscale / RGBA / palette inputs are converted to three-channel RGB.
    rgb_img = pil_img if pil_img.mode == "RGB" else pil_img.convert("RGB")
    return transform(rgb_img)

class SimpleTokenizer:
    """Word-level vocabulary with forward and reverse lookup tables."""

    def __init__(self, word2idx=None):
        """Store the vocabulary and build its inverse mapping.

        Args:
            word2idx: Mapping from token string to integer id. ``None`` (or
                an empty mapping) yields an empty vocabulary.
        """
        # token -> id
        self.word2idx = word2idx or {}
        # id -> token, derived once so decoding is a plain dict lookup.
        self.idx2word = {v: k for k, v in self.word2idx.items()}

    @classmethod
    def load(cls, path):
        """Build a tokenizer from ``<path>/vocab.json``.

        Args:
            path: Directory containing ``vocab.json``.

        Returns:
            A new tokenizer whose ``word2idx`` is the parsed JSON object.

        Raises:
            OSError: If the vocabulary file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        # Read explicitly as UTF-8: relying on the platform default encoding
        # can mis-decode (or fail on) non-ASCII tokens on some systems.
        with open(f"{path}/vocab.json", "r", encoding="utf-8") as f:
            word2idx = json.load(f)
        return cls(word2idx)

class BiasDecoder(torch.nn.Module):
    """Caption decoder that adds the image feature to every token position.

    Each position's logits are computed from the sum of its token embedding,
    a learned positional embedding and the image feature vector, followed by
    a single linear projection onto the vocabulary.
    """

    def __init__(self, feature_dim=768, vocab_size=VOCAB_SIZE):
        """Create the embedding tables and the output projection.

        Args:
            feature_dim: Width of the embeddings; ``img_feat`` must be
                broadcastable against it.
            vocab_size: Number of token ids (embedding rows and logits).
        """
        super().__init__()
        nn = torch.nn
        self.token_emb = nn.Embedding(vocab_size, feature_dim)
        # The positional table holds SEQ_LEN - 1 entries; forward() clamps
        # indices so longer inputs reuse the last entry.
        self.pos_emb = nn.Embedding(SEQ_LEN - 1, feature_dim)
        self.final_layer = nn.Linear(feature_dim, vocab_size)

    def forward(self, img_feat, target_seq):
        """Return vocabulary logits for every position of ``target_seq``.

        Args:
            img_feat: Image feature tensor; assumed (batch, feature_dim) —
                TODO confirm against the encoder output.
            target_seq: Integer token ids; assumed (batch, seq) — TODO
                confirm.

        Returns:
            The output of the final linear layer applied at each position.
        """
        token_vecs = self.token_emb(target_seq)
        # Positions past the end of the table fall back to its last row.
        last_pos = self.pos_emb.num_embeddings - 1
        positions = torch.arange(
            token_vecs.size(1), device=token_vecs.device
        ).clamp(max=last_pos)
        # Same left-to-right summation as token + position, then + image.
        hidden = token_vecs + self.pos_emb(positions) + img_feat.unsqueeze(1)
        return self.final_layer(hidden)

# Load models
# Caption decoder: instantiate with default sizes, restore the trained
# weights from the local checkpoint, and switch to inference mode.
decoder = BiasDecoder().to(device)
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source, and consider weights_only=True if the installed torch
# version supports it.
decoder.load_state_dict(torch.load("pytorch_model.bin", map_location=device))
decoder.eval()

# Image encoder: pretrained ViT fetched by name, placed on the same device
# and put in inference mode.
vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k").to(device)
vit.eval()

# Vocabulary is read from vocab.json in the current working directory.
tokenizer = SimpleTokenizer.load("./")
# Id of the padding token; raises KeyError at import time if "<PAD>" is
# missing from the vocabulary.
pad_idx = tokenizer.word2idx["<PAD>"]

@torch.no_grad()
def generate_caption(img):
    """Generate a caption for ``img`` using beam search (beam width 3).

    The image is preprocessed, encoded with the ViT model (pooled output),
    and the decoder is queried one position at a time for at most
    ``SEQ_LEN - 1`` steps.

    Args:
        img: The uploaded image, as accepted by ``preprocess_image``.

    Returns:
        The words of the best-scoring hypothesis joined by single spaces.
        The leading ``<SOS>``, any padding, and the terminating ``<EOS>``
        (with anything after it) are excluded; ids missing from the
        vocabulary are rendered as ``"<UNK>"``.

    Raises:
        ValueError: Propagated from ``preprocess_image`` when ``img`` is
            ``None``.
    """
    # Look the special tokens up once rather than on every loop iteration.
    sos_idx = tokenizer.word2idx["<SOS>"]
    eos_idx = tokenizer.word2idx["<EOS>"]
    beam_size = 3

    img_tensor = preprocess_image(img).unsqueeze(0).to(device)
    img_feat = vit(pixel_values=img_tensor).pooler_output

    # Each beam is (token id list, cumulative log-probability).
    beams = [([sos_idx], 0.0)]

    for _ in range(SEQ_LEN - 1):
        candidates = []
        for seq, score in beams:
            # A hypothesis that already emitted <EOS> is complete: carry it
            # forward unchanged instead of appending tokens after the end.
            if seq[-1] == eos_idx:
                candidates.append((seq, score))
                continue

            padded = seq + [pad_idx] * (SEQ_LEN - len(seq))
            inp = torch.tensor(padded, device=device).unsqueeze(0)
            logits = decoder(img_feat, inp)
            # Logits at the last real token predict the next token.
            log_probs = torch.nn.functional.log_softmax(
                logits[0, len(seq) - 1], dim=-1
            )
            top_p, top_i = torch.topk(log_probs, beam_size)
            for log_p, token in zip(top_p.tolist(), top_i.tolist()):
                candidates.append((seq + [token], score + log_p))

        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_size]
        if all(seq[-1] == eos_idx for seq, _ in beams):
            break

    # Drop the leading <SOS>, then cut at the first <EOS> so the marker
    # itself never appears in the user-facing caption.
    best = beams[0][0][1:]
    if eos_idx in best:
        best = best[:best.index(eos_idx)]
    words = [tokenizer.idx2word.get(i, "<UNK>") for i in best if i != pad_idx]
    return " ".join(words)

# Gradio UI: components are created in display order inside the Blocks
# context — title, image upload, caption textbox, and a trigger button.
with gr.Blocks() as demo:
    gr.Markdown("# RADIOCAP13 — Image Captioning Demo")
    # type="pil" asks Gradio to hand the upload to the callback as a PIL image.
    img_in = gr.Image(type="pil", label="Upload an Image")
    out = gr.Textbox(label="Generated Caption")
    btn = gr.Button("Generate Caption")
    # Clicking the button runs generate_caption on the uploaded image and
    # writes the returned string into the textbox.
    btn.click(generate_caption, inputs=img_in, outputs=out)

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()