blumenstiel commited on
Commit
e54b97f
·
1 Parent(s): 70c9404

Fix rename weight keys and fix inference

Browse files
Prithvi-EO-V2-300M-TL-Sen1Floods11.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:eb7db01b161d8919c7e61fed209d8633dabd49f5f81c785600790e43832344ba
3
- size 1276881720
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3675e9c2b52547de8ff8a19f4881c28573e6d4d2f0805d866f4fc48c1e517d60
3
+ size 1276843350
config.yaml CHANGED
@@ -101,13 +101,6 @@ data:
101
  - SWIR_1
102
  - SWIR_2
103
  train_transform:
104
- - class_path: albumentations.Resize
105
- init_args:
106
- height: 448
107
- width: 448
108
- interpolation: 1
109
- always_apply: false
110
- p: 1
111
  - class_path: albumentations.RandomCrop
112
  init_args:
113
  height: 224
@@ -128,26 +121,12 @@ data:
128
  always_apply: true
129
  p: 1.0
130
  val_transform:
131
- - class_path: albumentations.Resize
132
- init_args:
133
- height: 448
134
- width: 448
135
- interpolation: 1
136
- always_apply: false
137
- p: 1
138
  - class_path: albumentations.pytorch.ToTensorV2
139
  init_args:
140
  transpose_mask: false
141
  always_apply: true
142
  p: 1.0
143
  test_transform:
144
- - class_path: albumentations.Resize
145
- init_args:
146
- height: 448
147
- width: 448
148
- interpolation: 1
149
- always_apply: false
150
- p: 1
151
  - class_path: albumentations.pytorch.ToTensorV2
152
  init_args:
153
  transpose_mask: false
 
101
  - SWIR_1
102
  - SWIR_2
103
  train_transform:
 
 
 
 
 
 
 
104
  - class_path: albumentations.RandomCrop
105
  init_args:
106
  height: 224
 
121
  always_apply: true
122
  p: 1.0
123
  val_transform:
 
 
 
 
 
 
 
124
  - class_path: albumentations.pytorch.ToTensorV2
125
  init_args:
126
  transpose_mask: false
127
  always_apply: true
128
  p: 1.0
129
  test_transform:
 
 
 
 
 
 
 
130
  - class_path: albumentations.pytorch.ToTensorV2
131
  init_args:
132
  transpose_mask: false
inference.py CHANGED
@@ -185,7 +185,7 @@ def run_model(input_data, temporal_coords, location_coords, model, datamodule, i
185
  for x in windows:
186
  # Apply standardization
187
  x = datamodule.test_transform(image=x.squeeze().numpy().transpose(1,2,0))
188
- x = datamodule.aug(x['image'])
189
 
190
  with torch.no_grad():
191
  x = x.to(model.device)
@@ -237,7 +237,7 @@ def main(
237
  # Load model ---------------------------------------------------------------------------------
238
 
239
  lightning_model = LightningInferenceModel.from_config(config, checkpoint)
240
- img_size = 512 # Size of Sen1Floods11
241
 
242
  # Loading data ---------------------------------------------------------------------------------
243
 
 
185
  for x in windows:
186
  # Apply standardization
187
  x = datamodule.test_transform(image=x.squeeze().numpy().transpose(1,2,0))
188
+ x = datamodule.aug(x)['image']
189
 
190
  with torch.no_grad():
191
  x = x.to(model.device)
 
237
  # Load model ---------------------------------------------------------------------------------
238
 
239
  lightning_model = LightningInferenceModel.from_config(config, checkpoint)
240
+ img_size = 512 # Size of Sen1Floods11 training
241
 
242
  # Loading data ---------------------------------------------------------------------------------
243
 
requirements.txt CHANGED
@@ -3,4 +3,4 @@ torchvision
3
  timm
4
  einops
5
  rasterio
6
- git+https://github.com/IBM/terratorch.git
 
3
  timm
4
  einops
5
  rasterio
6
+ terratorch==0.99.8