andreribeiro87 commited on
Commit
9c57d57
·
verified ·
1 Parent(s): e664e1c

Add AutoModel.from_pretrained support (unet3plus+efficientnet)

Browse files
Files changed (2) hide show
  1. config.json +1 -1
  2. modeling_unet3plus.py +21 -2
config.json CHANGED
@@ -6,7 +6,7 @@
6
  "out_channels": 1,
7
  "inter_ch": 64,
8
  "auto_map": {
9
- "AutoConfig": "configuration_unet3plus.UNet3PlusConfig",
10
  "AutoModel": "modeling_unet3plus.UNet3PlusForSegmentation"
11
  },
12
  "training": {
 
6
  "out_channels": 1,
7
  "inter_ch": 64,
8
  "auto_map": {
9
+ "AutoConfig": "modeling_unet3plus.UNet3PlusConfig",
10
  "AutoModel": "modeling_unet3plus.UNet3PlusForSegmentation"
11
  },
12
  "training": {
modeling_unet3plus.py CHANGED
@@ -7,12 +7,31 @@ if _here not in sys.path:
7
 
8
  import torch
9
  import torch.nn.functional as F
10
- from transformers import PreTrainedModel
11
 
12
- from configuration_unet3plus import UNet3PlusConfig
13
  from models import create_model
14
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  class UNet3PlusForSegmentation(PreTrainedModel):
17
  """UNet 3+ segmentation model with EfficientNet-B0 backbone.
18
 
 
7
 
8
  import torch
9
  import torch.nn.functional as F
10
+ from transformers import PreTrainedModel, PretrainedConfig
11
 
 
12
  from models import create_model
13
 
14
 
15
+ class UNet3PlusConfig(PretrainedConfig):
16
+ model_type = "unet3plus"
17
+
18
+ def __init__(
19
+ self,
20
+ backbone: str = "efficientnet",
21
+ in_channels: int = 3,
22
+ out_channels: int = 1,
23
+ inter_ch: int = 64,
24
+ img_size: int = 256,
25
+ **kwargs,
26
+ ):
27
+ super().__init__(**kwargs)
28
+ self.backbone = backbone
29
+ self.in_channels = in_channels
30
+ self.out_channels = out_channels
31
+ self.inter_ch = inter_ch
32
+ self.img_size = img_size
33
+
34
+
35
  class UNet3PlusForSegmentation(PreTrainedModel):
36
  """UNet 3+ segmentation model with EfficientNet-B0 backbone.
37