File size: 2,223 Bytes
b461543
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
import numpy as np
from huggingface_hub import hf_hub_download
from tiatoolbox.models.arch import resnet50
from tiatoolbox.models.models_abc import ModelABC
from torchvision import transforms
from PIL import Image
import gradio as gr


# -------------------------------------------------
# Download the model from Hugging Face
# -------------------------------------------------

MODEL_REPO = "kaczmarj/lymphnodes-tiatoolbox-resnet50.patchcamelyon"
# Weights file inside MODEL_REPO -- change this if the file is named differently.
MODEL_FILE = "resnet50-pcam.pth"

# Fetch the weights file from the Hub; the returned local path is what the
# model class hands to torch.load.
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)

# -------------------------------------------------
# Set up the ResNet50 model (TIAToolbox)
# -------------------------------------------------

class PCamModel(ModelABC):
    """ResNet50 two-class classifier wrapped in TIAToolbox's ``ModelABC`` interface."""

    def __init__(self, weights_path=None):
        """Build the network and load its weights.

        Args:
            weights_path: Path to a ``state_dict`` checkpoint. When ``None``
                (the default), the module-level ``model_path`` is used.
        """
        super().__init__()
        self.model = resnet50(pretrained=False, num_classes=2)
        if weights_path is None:
            weights_path = model_path
        # NOTE(security): torch.load unpickles the file. Only load checkpoints
        # from sources you trust (on torch >= 1.13, weights_only=True restricts it).
        state_dict = torch.load(weights_path, map_location="cpu")
        self.model.load_state_dict(state_dict)
        # Put the whole wrapper (and therefore the inner network) in eval mode.
        self.eval()

    def forward(self, imgs):
        """Return the network's raw (pre-softmax) outputs for a batch of images."""
        return self.model(imgs)

    @staticmethod
    def infer_batch(model, batch_data, *, device="cpu", **kwargs):
        """Run ``model`` on one batch and return per-class probabilities.

        ``ModelABC`` declares ``infer_batch`` abstract, so without this method
        ``PCamModel()`` cannot be instantiated.

        Args:
            model: A ``PCamModel`` instance, already on ``device``.
            batch_data: Batch preprocessed like ``transform`` output. Assumed to
                be NCHW float data -- TODO confirm before using this with a
                tiatoolbox engine, which may supply NHWC patches.
            device: Torch device the batch is moved to.
            **kwargs: Ignored; absorbs version-specific engine arguments
                (e.g. ``on_gpu`` in older tiatoolbox releases).

        Returns:
            NumPy array of softmax probabilities, one row per batch item.
        """
        batch = torch.as_tensor(batch_data, dtype=torch.float32, device=device)
        model.eval()
        with torch.no_grad():
            logits = model(batch)
        return torch.softmax(logits, dim=1).cpu().numpy()


model = PCamModel()

# -------------------------------------------------
# Transforms required for the input image
# -------------------------------------------------

# Preprocessing for every uploaded image: resize to 96x96, convert to a
# tensor, then normalize each of the three channels with mean 0.5 / std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((96, 96)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# Display names for the model outputs, in output-index order (0, 1).
labels = ["No Tumor", "Tumor"]


# -------------------------------------------------
# Prediction function
# -------------------------------------------------

def predict(image):
    """Classify one image patch as tumor / no tumor.

    Args:
        image: ``PIL.Image`` from the Gradio ``Image`` input; may be ``None``
            when the user submits without uploading anything.

    Returns:
        Mapping from each name in ``labels`` to its softmax probability, the
        dict form consumed by the ``gr.Label`` output.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Show a readable message in the UI instead of a stack trace from
        # the transforms.
        raise gr.Error("Please upload an image.")

    # RGBA (PNG with alpha) or grayscale uploads would produce a tensor with
    # the wrong channel count for the 3-channel Normalize / ResNet input.
    img = transform(image.convert("RGB")).unsqueeze(0)  # add batch dimension

    with torch.no_grad():
        logits = model(img)[0]
        probs = torch.softmax(logits, dim=0).tolist()

    return {label: float(p) for label, p in zip(labels, probs)}


# -------------------------------------------------
# Gradio interface
# -------------------------------------------------

# Web UI: an image upload wired to `predict`, with the returned class
# probabilities rendered by a Label component.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(),
    title="Lymph Node Tumor Detection (PatchCamelyon – ResNet50)",
    # Derived from MODEL_REPO so the text always names the checkpoint loaded.
    description=f"Model: {MODEL_REPO}",
)

# Start the server only when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch()