# Hugging Face Spaces page status at capture time: Runtime error
| import torch | |
| import numpy as np | |
| from huggingface_hub import hf_hub_download | |
| from tiatoolbox.models.arch import resnet50 | |
| from tiatoolbox.models.models_abc import ModelABC | |
| from torchvision import transforms | |
| from PIL import Image | |
| import gradio as gr | |
# -------------------------------------------------
# Fetch the model checkpoint from the Hugging Face Hub
# -------------------------------------------------
MODEL_REPO = "kaczmarj/lymphnodes-tiatoolbox-resnet50.patchcamelyon"
# Adjust if the checkpoint in the repo uses a different file name.
# NOTE(review): confirm this file exists in MODEL_REPO; if the repo only ships
# a TorchScript export, torch.load + load_state_dict will not accept it.
MODEL_FILE = "resnet50-pcam.pth"

# Local filesystem path of the downloaded checkpoint (consumed by torch.load).
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
# -------------------------------------------------
# ResNet50 model setup (TIAToolbox)
# -------------------------------------------------
class PCamModel(ModelABC):
    """Two-class ResNet50 wrapper loaded with the downloaded PCam checkpoint.

    ``forward`` delegates to a ``resnet50(num_classes=2)`` backbone whose
    weights are read from ``weights_path`` (default: module-level
    ``model_path``).

    NOTE(review): ``resnet50`` is imported from ``tiatoolbox.models.arch``;
    confirm that module exists in the installed TIAToolbox version (its
    architectures are normally under ``tiatoolbox.models.architecture``).
    """

    def __init__(self, weights_path=None):
        super().__init__()
        self.model = resnet50(pretrained=False, num_classes=2)

        state = torch.load(
            model_path if weights_path is None else weights_path,
            map_location="cpu",
        )
        # Tolerate checkpoints wrapped as {"state_dict": ...} and/or saved from
        # a DataParallel model (parameter names prefixed with "module.").
        if isinstance(state, dict) and "state_dict" in state:
            state = state["state_dict"]
        if isinstance(state, dict):
            prefix = "module."
            state = {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in state.items()
            }
        self.model.load_state_dict(state)

        # Put the whole wrapper (not just the backbone) into inference mode.
        self.eval()

    def forward(self, imgs):
        """Return the backbone's raw logits for a batch of images."""
        return self.model(imgs)

    @staticmethod
    def infer_batch(model, batch_data, *args, **kwargs):
        """Run ``model`` on one batch and return its logits as a NumPy array.

        TIAToolbox's ``ModelABC`` declares ``infer_batch`` abstract, so without
        this override ``PCamModel()`` raises TypeError on construction (verify
        against the installed TIAToolbox version). Extra arguments such as
        ``on_gpu`` / ``device`` (the name varies across TIAToolbox versions)
        are accepted and ignored.

        NOTE(review): assumes ``batch_data`` is already an NCHW batch normalized
        like the Gradio path; confirm before using with TIAToolbox engines.
        """
        batch = torch.as_tensor(batch_data, dtype=torch.float32)
        model.eval()
        with torch.no_grad():
            return model(batch).cpu().numpy()


model = PCamModel()
# -------------------------------------------------
# Preprocessing applied to each input image
# -------------------------------------------------
transform = transforms.Compose(
    [
        # Resize every input to exactly 96x96.
        transforms.Resize((96, 96)),
        # PIL image -> float tensor, channels first.
        transforms.ToTensor(),
        # Per-channel (x - 0.5) / 0.5 on the three channels.
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# Class names by output index.
# NOTE(review): assumes the checkpoint's logit order is [no tumor, tumor].
labels = ["No Tumor", "Tumor"]
# -------------------------------------------------
# Prediction function
# -------------------------------------------------
def predict(image):
    """Classify one uploaded image and return per-class probabilities.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        dict mapping each name in ``labels`` to its softmax probability (the
        label -> confidence mapping rendered by ``gr.Label``).

    Raises:
        gr.Error: if no image was provided, so Gradio shows a readable message
            instead of an internal traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    # Uploads may be RGBA (PNG with alpha) or grayscale; the Normalize step is
    # configured for exactly 3 channels, so coerce to RGB first.
    img = transform(image.convert("RGB")).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        logits = model(img)[0]
    probs = torch.softmax(logits, dim=0).tolist()
    return {label: float(p) for label, p in zip(labels, probs)}
# -------------------------------------------------
# Gradio interface
# -------------------------------------------------
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(),
    # NOTE(review): the separator was a mis-encoded character in the captured
    # source; restored as an em dash (original may have been an arrow).
    title="Lymph Node Tumor Detection (PatchCamelyon — ResNet50)",
    # Reuse MODEL_REPO so the displayed model id cannot drift from the one
    # actually downloaded.
    description=f"Model: {MODEL_REPO}",
)

# Only start the server when run as a script (e.g. `python app.py`), not when
# the module is imported.
if __name__ == "__main__":
    demo.launch()