File size: 544 Bytes
e664e1c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 | from transformers import PretrainedConfig
class UNet3PlusConfig(PretrainedConfig):
    """Hyper-parameter container for a UNet3+ model.

    Each named constructor argument is stored, unchanged and unvalidated, as an
    instance attribute of the same name. Any remaining keyword arguments are
    handed to ``PretrainedConfig.__init__``.

    Args:
        backbone: Backbone identifier, stored as ``self.backbone``
            (default ``"efficientnet"``). NOTE(review): which names are
            accepted is decided by the model code, not checked here.
        in_channels: Stored as ``self.in_channels`` (default 3). Presumably
            the number of input image channels; confirm against the model.
        out_channels: Stored as ``self.out_channels`` (default 1). Presumably
            the number of predicted output channels; confirm.
        inter_ch: Stored as ``self.inter_ch`` (default 64). Presumably the
            intermediate/decoder channel width; confirm against the model.
        img_size: Stored as ``self.img_size`` (default 256). Presumably the
            expected input spatial size; confirm.
        **kwargs: Forwarded verbatim to ``PretrainedConfig.__init__``.
    """

    # Type tag for this config class; presumably used by transformers' Auto*
    # registration and written into saved configs — confirm.
    model_type = "unet3plus"

    def __init__(
        self,
        backbone: str = "efficientnet",
        in_channels: int = 3,
        out_channels: int = 1,
        inter_ch: int = 64,
        img_size: int = 256,
        **kwargs,
    ) -> None:
        # Initialise the base config first so generic options from kwargs are
        # applied before the model-specific fields are attached.
        super().__init__(**kwargs)

        # Model-specific fields, kept exactly as given (no validation here).
        self.backbone = backbone
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.inter_ch = inter_ch
        self.img_size = img_size
|