import json
from pathlib import Path
from typing import List, Tuple, Any, Dict

import numpy as np
import onnxruntime as ort
from skimage import measure

from croplands.io import read_zarr, read_zarr_profile
from croplands.polygonize import polygonize_raster
from croplands.utils import impute_nan, normalize_s2


class CroplandHandler():
    """Run an ONNX cropland model on Zarr patches and optionally export polygons.

    The handler loads an ONNX Runtime session and a JSON mapping from patch
    name (file stem) to the per-patch date positions fed to the model's
    ``batch_positions`` input. ``predict`` chains ``preprocess``, the ONNX
    session and ``postprocess`` for a single patch.
    """

    def __init__(
        self,
        input_dir: str,
        output_dir: str,
        device: str = "cpu",
        model_path: str = "model_repository/utae.onnx",
        dates_path: str = "months_per_patch.json",
    ) -> None:
        """Validate the directories/device and load the model and dates table.

        Args:
            input_dir: Existing directory containing the input patches.
            output_dir: Existing directory where Parquet exports are written.
            device: ``"cpu"`` or any string starting with ``"cuda"``.
            model_path: Path to the ONNX model. A relative path is resolved
                against the current working directory.
            dates_path: Path to the JSON file keyed by patch stem. A relative
                path is resolved against the current working directory.

        Raises:
            AssertionError: If a directory is missing or ``device`` is invalid.
        """
        self.input_dir = Path(input_dir)
        self.output_dir = Path(output_dir)

        # Explicit checks instead of `assert` statements so the validation is
        # not stripped under `python -O`. AssertionError is kept so existing
        # callers observe the same exception type and message.
        if not self.input_dir.exists():
            raise AssertionError("Input directory doesn't exist")
        if not self.output_dir.exists():
            raise AssertionError("Output directory doesn't exist")
        if not (device == "cpu" or device.startswith("cuda")):
            raise AssertionError(f"{device} is not a valid device.")

        provider = "CUDAExecutionProvider" if device.startswith("cuda") else "CPUExecutionProvider"
        self.session = ort.InferenceSession(str(model_path), providers=[provider])

        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(dates_path, encoding="utf-8") as dates_file:
            self.dates = json.load(dates_file)

    def preprocess(self, file: str) -> Tuple[np.ndarray, Dict, np.ndarray]:
        """Load one patch and build the model inputs for it.

        Args:
            file: Patch file name, relative to ``input_dir``.

        Returns:
            ``(batch, profile, dates)``: the imputed and normalized patch data
            with a leading batch axis of size 1, the profile returned by
            ``read_zarr_profile``, and the dates entry for this patch as an
            array with a leading batch axis of size 1.

        Raises:
            AssertionError: If ``file`` is ``None``.
            KeyError: If the dates table has no entry for the patch stem.
        """
        if file is None:
            raise AssertionError("Missing input file for inference")

        file_path = self.input_dir / file
        data = normalize_s2(impute_nan(read_zarr(file_path)))
        profile = read_zarr_profile(file_path)

        patch_name = file_path.stem
        try:
            patch_dates = self.dates[patch_name]
        except KeyError:
            raise KeyError(f"No dates entry for patch '{patch_name}'") from None

        batch = np.expand_dims(data, axis=0)
        dates = np.expand_dims(np.array(patch_dates), axis=0)
        return batch, profile, dates

    def postprocess(self, outputs: Any, file: str, profile: Dict, save_raster: bool = False) -> np.ndarray:
        """Convert the raw session outputs to an array and optionally export them.

        Args:
            outputs: Value returned by ``session.run``.
            file: Patch file name; its stem names the exported Parquet file.
            profile: Mapping providing the ``"transform"`` and ``"crs"`` entries
                passed to ``polygonize_raster``.
            save_raster: When true, polygonize the prediction and write it to
                ``<output_dir>/<stem>.parquet``. Despite the flag's name the
                export is the polygonized result, not a raster file.

        Returns:
            ``outputs`` converted with ``np.array``.
        """
        outputs = np.array(outputs)

        if save_raster:
            # First model output, first batch element; argmax over axis 0
            # (presumably the class axis -- confirm against the model).
            out_class = np.argmax(outputs[0][0], axis=0)
            # Every non-zero class counts as foreground for component labelling.
            out_bin = (out_class != 0).astype(np.uint8)
            components = measure.label(out_bin, connectivity=1)
            gdf = polygonize_raster(
                out_class,
                components,
                tolerance=0.0001,
                transform=profile["transform"],
                crs=profile["crs"],
            )
            data_path = self.input_dir / file
            save_path = self.output_dir / (data_path.stem + ".parquet")
            gdf.to_parquet(save_path)

        return outputs

    def predict(self, files: str, save_raster: bool = False) -> np.ndarray:
        """Run preprocessing, inference and postprocessing for one patch.

        Args:
            files: A single patch file name relative to ``input_dir``. Despite
                the plural name, the value is joined onto ``input_dir`` as one
                path, so passing a list raises ``TypeError``.
            save_raster: Forwarded to ``postprocess``.

        Returns:
            The session outputs as returned by ``postprocess``.
        """
        batch, profile, dates = self.preprocess(files)
        raw_outputs = self.session.run(None, {"input": batch, "batch_positions": dates})
        return self.postprocess(raw_outputs, files, profile, save_raster)
|
|
|
|
|
|
|
|
| |