---
license: apache-2.0
library_name: pytorch
tags:
- image-segmentation
- medical
- polyp
- colonoscopy
- unet
- unet3plus
- efficientnet
datasets:
- andreribeiro87/kvasir-seg-augmented
metrics:
- dice
- iou
pipeline_tag: image-segmentation
---

# UNet3+ EfficientNet — Polyp Segmentation (HPO-Optimised)

Binary polyp segmentation model trained on [Kvasir-SEG](https://huggingface.co/datasets/Angelou0516/kvasir-seg). **Selected as the best configuration** (lowest evaluation loss) after a systematic sweep over 24 architecture × backbone combinations, followed by Optuna hyperparameter optimisation (60 trials).

## Model Description

| Property | Value |
|---|---|
| Architecture | **UNet 3+** (full-scale skip connections) |
| Backbone | **EfficientNet-B0** (ImageNet pre-trained via `timm`) |
| Input size | 256 × 256 × 3 |
| Output | 256 × 256 × 1 logit map (sigmoid → binary mask) |
| Checkpoint size | ~50 MB |
| Loss | Dice-Focal (γ = 1.12, dice weight = 0.30) |

### Architecture Details

**UNet 3+** (Huang et al., ICASSP 2020) extends the standard U-Net with *full-scale skip connections*: every decoder node aggregates feature maps from **all scales simultaneously** (in the original design, encoder maps at the same and finer resolutions plus decoder maps from coarser levels), giving each node access to both fine-grained spatial detail and deep semantic context. Each incoming stream is projected to a fixed number of intermediate channels (64) before concatenation, so the total channel count is the same at every decoder level.

The **EfficientNet-B0** backbone (pre-trained on ImageNet-1k) replaces the standard U-Net encoder and provides multi-scale feature maps at five resolution levels.
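As a rough illustration of the full-scale aggregation, the sketch below lists, for each decoder level, where each of the five incoming streams comes from and how it is resampled, and checks the concatenated width (5 streams × 64 channels). It assumes the five encoder levels sit at output strides 2, 4, 8, 16 and 32 (the usual `timm` `features_only` layout for EfficientNet-B0) and the routing of the original UNet 3+ paper; the helper `full_scale_plan` is hypothetical and is not part of the repository code.

```python
"""Hypothetical sketch of UNet 3+ full-scale skip routing (not the repository code)."""

INPUT_SIZE = 256
# Assumed output strides of the five EfficientNet-B0 feature levels (levels 1..5).
STRIDES = [2, 4, 8, 16, 32]
INTER_CHANNELS = 64  # per-stream projection width stated in this card


def full_scale_plan(input_size=INPUT_SIZE, strides=STRIDES, inter=INTER_CHANNELS):
    """Map each decoder level to (resolution, [(source, resampling)], concat channels).

    Follows the routing of the original UNet 3+ paper: decoder level i takes
    encoder levels k <= i (finer ones max-pooled down to level i) and decoder
    levels k > i (upsampled to level i); the deepest level is the last encoder map.
    """
    n = len(strides)
    plan = {}
    for i in range(n - 1, 0, -1):  # decoder levels n-1 .. 1, deepest first
        res = input_size // strides[i - 1]
        streams = []
        for k in range(1, n + 1):
            if k < i:
                factor = strides[i - 1] // strides[k - 1]
                streams.append((f"enc{k}", f"maxpool x{factor}"))
            elif k == i:
                streams.append((f"enc{k}", "identity"))
            else:
                factor = strides[k - 1] // strides[i - 1]
                source = f"enc{k}" if k == n else f"dec{k}"
                streams.append((source, f"upsample x{factor}"))
        # every stream is projected to `inter` channels before concatenation
        plan[i] = (res, streams, inter * len(streams))
    return plan


if __name__ == "__main__":
    for level, (res, streams, channels) in full_scale_plan().items():
        print(f"dec{level}: {res}x{res}, {channels} channels")
        for source, op in streams:
            print(f"    {source} -> {op}")
```

Under these assumptions every decoder node concatenates 5 × 64 = 320 channels, and the finest node (`dec1`) sits at 128 × 128, so a final upsampling step is needed to reach the 256 × 256 logit map; the sketch does not model how the repository's output head handles that step.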
## Test Set Results

Evaluated on the fixed 53-image test partition of Kvasir-SEG (50 % of the original validation split, seed 42):

| Metric | Value |
|---|---|
| **Dice** | **0.9234** |
| **IoU** | **0.8577** |
| F1 | 0.9234 |
| Precision | 0.9474 |
| Recall | 0.9005 |
| Accuracy | 0.9745 |
| Loss | 0.0914 |

For binary masks the Dice coefficient equals the pixel-level F1 score, which is why those two rows are identical.

## Comparison with Sweep Models

This model was selected from an initial sweep of **24 architecture × backbone** combinations:

| Rank | Model | Test Dice | Test IoU |
|---|---|---|---|
| 1 | attention_unet_convnext | 0.9411 | 0.8888 |
| 2 | unet3plus_convnext | 0.9395 | 0.8859 |
| 3 | unet_convnext | 0.9383 | 0.8838 |
| **—** | **unet3plus_efficientnet (this, post-HPO)** | **0.9234** | **0.8577** |

> **Note:** Sweep models were trained for 50 epochs with default hyperparameters (lr = 1e-3, BCEDice loss).
> This model was then retrained for 5 further epochs with the Optuna-optimised configuration. That run reaches
> an eval loss of 0.0537 (target: < 0.10), but its test Dice and IoU are below those of the top three sweep models.
> Because the runs use different loss functions (Dice-Focal vs. BCEDice), their loss values are not directly comparable.

## Training Procedure

### Hyperparameter Optimisation

Optuna with a `MedianPruner` ran **60 trials** (28 completed, 32 pruned) over the top-3 sweep models. The best trial (#32) achieved `eval_loss = 0.0537` (target: < 0.10 ✓).

| Hyperparameter | Value |
|---|---|
| Learning rate | 0.001794 |
| Weight decay | 2.51e-06 |
| Warmup ratio | 0.121 |
| LR scheduler | cosine_with_restarts |
| Batch size | 64 |
| Loss type | dice_focal |
| Focal gamma | 1.1217 |
| Dice weight | 0.3012 |

### Training Configuration

- **Optimiser:** AdamW
- **Epochs:** 50 (sweep) + 5 (HPO final retrain)
- **FP16:** enabled
- **Dataset:** Kvasir-SEG augmented (4,800 train / 100 val / 100 test)
- **Augmentation:** random horizontal/vertical flips, ±30° rotation, brightness/contrast/saturation ±20 %

## How to Use

This model uses a custom PyTorch architecture. The model code is included in the repository and is loaded with `trust_remote_code=True`.
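To score the model's thresholded output against your own ground-truth masks in the same terms as the Test Set Results table, a minimal NumPy sketch follows. It assumes pixel counts are pooled over all images before the ratios are taken (micro-averaging); that is an assumption, not the documented evaluation protocol. The helper `segmentation_metrics` is hypothetical and not part of the repository.

```python
import numpy as np


def segmentation_metrics(pred, target, eps=1e-7):
    """Binary segmentation metrics from pooled pixel counts (hypothetical helper).

    pred, target: boolean (or 0/1) arrays of identical shape, e.g. (N, H, W).
    Counts are summed over every pixel of every image (micro-averaging).
    """
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    tp = np.sum(pred & target)    # polyp pixels predicted as polyp
    fp = np.sum(pred & ~target)   # background predicted as polyp
    fn = np.sum(~pred & target)   # polyp pixels missed
    tn = np.sum(~pred & ~target)  # background predicted as background
    return {
        "dice": 2 * tp / (2 * tp + fp + fn + eps),  # identical to pixel-level F1
        "iou": tp / (tp + fp + fn + eps),
        "precision": tp / (tp + fp + eps),
        "recall": tp / (tp + fn + eps),
        "accuracy": (tp + tn) / (tp + tn + fp + fn),
    }


if __name__ == "__main__":
    target = np.zeros((4, 4), dtype=bool)
    target[1:3, 1:3] = True  # 4 ground-truth polyp pixels
    pred = np.zeros((4, 4), dtype=bool)
    pred[1:3, 1:4] = True    # 6 predicted pixels, 4 of them correct
    print(segmentation_metrics(pred, target))  # dice ≈ 0.80, iou ≈ 0.67
```

With pooled counts, IoU and Dice are linked by IoU = Dice / (2 − Dice); the reported 0.9234 and 0.8577 satisfy this relation to four decimals.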
### Installation

```bash
pip install torch torchvision timm transformers
```

### Inference

```python
import torch
from transformers import AutoModel
from torchvision.transforms import functional as TF
from PIL import Image

# Load model: downloads the weights and the custom model code from the Hub
model = AutoModel.from_pretrained(
    "andreribeiro87/unet3plus-efficientnet-kvasir-seg",
    trust_remote_code=True,
)
model.eval()

# Preprocess: resize to the fixed 256 x 256 input and scale pixels to [0, 1]
image = Image.open("your_colonoscopy_image.jpg").convert("RGB")
x = TF.to_tensor(TF.resize(image, [256, 256])).unsqueeze(0)  # (1, 3, 256, 256)

# Predict: sigmoid over the logit map, then threshold at 0.5
with torch.no_grad():
    outputs = model(pixel_values=x)
mask = (outputs["logits"].sigmoid() > 0.5).squeeze()  # bool (256, 256)

# 8-bit PIL mask (0 = background, 255 = polyp), resized back to the original image size
pred_mask = TF.to_pil_image(mask.float())
pred_mask = pred_mask.resize(image.size, Image.NEAREST)
```

## Citation

If you use this model or dataset, please cite the Kvasir-SEG dataset paper and the UNet 3+ paper:

```bibtex
@inproceedings{jha2020kvasir,
  title     = {Kvasir-SEG: A Segmented Polyp Dataset},
  author    = {Jha, Debesh and Smedsrud, Pia H and Riegler, Michael A and Halvorsen, P{\aa}l and de Lange, Thomas and Johansen, Dag and Johansen, H{\aa}vard D},
  booktitle = {MultiMedia Modeling (MMM)},
  year      = {2020}
}
```

```bibtex
@inproceedings{huang2020unet3plus,
  title     = {UNet 3+: A Full-Scale Connected UNet for Medical Image Segmentation},
  author    = {Huang, Huimin and Lin, Lanfen and Tong, Ruofeng and Hu, Hongjie and Zhang, Qiaowei and Iwamoto, Yutaro and Han, Xianhua and Chen, Yen-Wei and Wu, Jian},
  booktitle = {IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
  year      = {2020}
}
```

## Limitations

- Trained and evaluated exclusively on **Kvasir-SEG** (single-centre, single-modality). Performance may degrade on other colonoscopy datasets or imaging conditions.
- Binary segmentation only; does not distinguish between polyp types or severity.
- Input resolution is fixed at **256 × 256**. Because every image is resized to this size, very small polyps can shrink to a few pixels and may be missed or only partially segmented.
- **Not validated for clinical use.** This is a research model.
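To make the resolution limitation concrete: the inference code resizes each frame to exactly 256 × 256 without preserving aspect ratio, so a polyp's footprint shrinks by the resize factor along each axis independently. The sketch below computes the post-resize size of a hypothetical 30 × 30 px polyp in a hypothetical 1920 × 1080 frame; both sizes are illustrative and not taken from the dataset, and `resized_footprint` is a hypothetical helper.

```python
def resized_footprint(box_w, box_h, frame_w, frame_h, target=256):
    """Width and height (px) of a box after resizing its frame to target x target.

    Each axis is scaled independently, matching a resize that does not
    preserve aspect ratio (as in the inference example).
    """
    scale_x = target / frame_w
    scale_y = target / frame_h
    return box_w * scale_x, box_h * scale_y


if __name__ == "__main__":
    # Hypothetical example: a 30 x 30 px polyp in a 1920 x 1080 frame
    w, h = resized_footprint(30, 30, 1920, 1080)
    print(f"{w:.1f} x {h:.1f} px")  # 4.0 x 7.1 px
```

A lesion that small occupies only a handful of input pixels, which is why fine detail may be lost at the fixed input size.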