OsamaHamad2023 commited on
Commit
b461543
ยท
verified ยท
1 Parent(s): bfd5fc5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -0
app.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from huggingface_hub import hf_hub_download
4
+ from tiatoolbox.models.arch import resnet50
5
+ from tiatoolbox.models.models_abc import ModelABC
6
+ from torchvision import transforms
7
+ from PIL import Image
8
+ import gradio as gr
9
+
10
+
11
+ # -------------------------------------------------
12
+ # ุชุญู…ูŠู„ ุงู„ู…ูˆุฏูŠู„ ู…ู† Hugging Face
13
+ # -------------------------------------------------
14
+
15
+ MODEL_REPO = "kaczmarj/lymphnodes-tiatoolbox-resnet50.patchcamelyon"
16
+ MODEL_FILE = "resnet50-pcam.pth" # ุบูŠู‘ุฑู‡ ู„ูˆ ุงุณู… ุงู„ู…ู„ู ู…ุฎุชู„ู
17
+
18
+ model_path = hf_hub_download(
19
+ repo_id=MODEL_REPO,
20
+ filename=MODEL_FILE
21
+ )
22
+
23
+ # -------------------------------------------------
24
+ # ุฅุนุฏุงุฏ ู…ูˆุฏูŠู„ ResNet50 (TIAToolbox)
25
+ # -------------------------------------------------
26
+
27
+ class PCamModel(ModelABC):
28
+ def __init__(self):
29
+ super().__init__()
30
+ self.model = resnet50(pretrained=False, num_classes=2)
31
+ self.model.load_state_dict(torch.load(model_path, map_location="cpu"))
32
+ self.model.eval()
33
+
34
+ def forward(self, imgs):
35
+ return self.model(imgs)
36
+
37
+
38
+ model = PCamModel()
39
+
40
+ # -------------------------------------------------
41
+ # ุงู„ุชุญูˆูŠู„ุงุช ุงู„ู…ุทู„ูˆุจุฉ ู„ู„ุตูˆุฑุฉ
42
+ # -------------------------------------------------
43
+
44
+ transform = transforms.Compose([
45
+ transforms.Resize((96, 96)),
46
+ transforms.ToTensor(),
47
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
48
+ ])
49
+
50
+ labels = ["No Tumor", "Tumor"]
51
+
52
+
53
+ # -------------------------------------------------
54
+ # ุฏุงู„ุฉ ุงู„ุชู†ุจุค
55
+ # -------------------------------------------------
56
+
57
+ def predict(image):
58
+ img = transform(image).unsqueeze(0)
59
+
60
+ with torch.no_grad():
61
+ logits = model(img)[0]
62
+ probs = torch.softmax(logits, dim=0).numpy()
63
+
64
+ return {labels[0]: float(probs[0]), labels[1]: float(probs[1])}
65
+
66
+
67
+ # -------------------------------------------------
68
+ # ูˆุงุฌู‡ุฉ Gradio
69
+ # -------------------------------------------------
70
+
71
+ demo = gr.Interface(
72
+ fn=predict,
73
+ inputs=gr.Image(type="pil"),
74
+ outputs=gr.Label(),
75
+ title="Lymph Node Tumor Detection (PatchCamelyon โ€“ ResNet50)",
76
+ description="Model: kaczmarj/lymphnodes-tiatoolbox-resnet50.patchcamelyon"
77
+ )
78
+
79
+ demo.launch()