diff --git a/README.md b/README.md index 68a1c04..3eac62d 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ import earthnet_minicuber as emc specs = { "lon_lat": (43.598946, 3.087414), # center pixel "xy_shape": (256, 256), # width, height of cutout around center pixel - "resolution": 20, # in meters.. will use this on a local UTM grid.. + "resolution": 10, # in meters.. will use this on a local UTM grid.. "time_interval": "2021-07-01/2021-07-31", "providers": [ { @@ -94,6 +94,7 @@ Kwargs: - `five_daily_filter`: If `True` returns a regular 5-daily cycle starting with the first date in `full_time_interval`. It has no effect, if `best_orbit_filter` is used. - `brdf_correction`: If `True`, does BRDF correction based on the Sentinel 2 Metadata (illumination angles). - `cloud_mask`: If `True`, creates a cloud and cloud shadow mask based on deep learning. It automatically finds the best available cloud mask for the requested `bands`. +- `cloud_mask_rescale_factor`: If using cloud mask and a lower resolution than 10m, set this rescaling factor to the multiple of 10m that you are requesting. E.g. if `resolution = 20`, set `cloud_mask_rescale_factor = 2`. - `correct_processing_baseline`: If `True` (default): corrects the shift of +1000 that exists in Sentinel 2 data with processing baseline >= 4.0 diff --git a/earthnet_minicuber/__init__.py b/earthnet_minicuber/__init__.py index 2fdf639..edaae8c 100644 --- a/earthnet_minicuber/__init__.py +++ b/earthnet_minicuber/__init__.py @@ -1,6 +1,6 @@ """EarthNet Minicuber""" -__version__ = "0.1.2" +__version__ = "0.1.3" __author__ = "Vitus Benson" diff --git a/earthnet_minicuber/provider/s2/cloudmask.py b/earthnet_minicuber/provider/s2/cloudmask.py index 7552141..5dffde0 100644 --- a/earthnet_minicuber/provider/s2/cloudmask.py +++ b/earthnet_minicuber/provider/s2/cloudmask.py @@ -26,8 +26,9 @@ def get_checkpoint(bands_avail): class CloudMask: - def __init__(self, bands = ["B02", "B03", "B04", "B8A"]): + def __init__(self, bands = ["B02", "B03", "B04", "B8A"], cloud_mask_rescale_factor = None): + self.cloud_mask_rescale_factor = cloud_mask_rescale_factor self.bands = bands ckpt, self.ckpt_bands = get_checkpoint(bands) @@ -63,12 +64,21 @@ def __call__(self, stack): x = torch.nn.functional.pad(x, (w_pad_left, w_pad_right, h_pad_left, h_pad_right), mode = "reflect") + if self.cloud_mask_rescale_factor: + #orig_size = (x.shape[-2], x.shape[-1]) + x = torch.nn.functional.interpolate(x, scale_factor = self.cloud_mask_rescale_factor, mode = 'bilinear') + with torch.no_grad(): y_hat = self.model(x) - y_hat = y_hat[:, :, h_pad_left:-h_pad_right, w_pad_left:-w_pad_right] + y_hat = torch.argmax(y_hat, dim = 1).float() + + if self.cloud_mask_rescale_factor: + y_hat = torch.nn.functional.max_pool2d(y_hat[:,None,...], kernel_size = self.cloud_mask_rescale_factor)[:,0,...]#torch.nn.functional.interpolate(y_hat, size = orig_size, mode = "bilinear") + + y_hat = y_hat[:, h_pad_left:-h_pad_right, w_pad_left:-w_pad_right] - ds["mask"] = (("time", "y", "x"),torch.argmax(y_hat, dim = 1).float().cpu().numpy()) + ds["mask"] = (("time", "y", "x"),y_hat.cpu().numpy()) return ds.to_array("band") diff --git a/earthnet_minicuber/provider/s2/nbar.py b/earthnet_minicuber/provider/s2/nbar.py index 437b204..491bd60 100644 --- a/earthnet_minicuber/provider/s2/nbar.py +++ b/earthnet_minicuber/provider/s2/nbar.py @@ -53,13 +53,16 @@ def call_sen2nbar(stack, items, epsg): # Compute the c-factor per item and extract the processing baseline c_array = [] for item in ordered_items:#tqdm(ordered_items, disable=quiet): - c = c_factor_from_item(item, f"epsg:{epsg}") - c = c.interp( - y=stack.y.values, - x=stack.x.values, - method="linear", - kwargs={"fill_value": "extrapolate"}, - ) + try: + c = c_factor_from_item(item, f"epsg:{epsg}") + c = c.interp( + y=stack.y.values, + x=stack.x.values, + method="linear", + kwargs={"fill_value": "extrapolate"}, + ) + except ValueError: + c = xr.DataArray(np.full((9,len(stack.y), len(stack.x)), np.NaN), coords = {"band": ['B02','B03','B04','B05','B06','B07','B08','B11','B12'], "y": stack.y, "x": stack.x}, dims = ("band", "y", "x")) c_array.append(c) orig_bands = stack.band.values.tolist() diff --git a/earthnet_minicuber/provider/s2/sentinel2.py b/earthnet_minicuber/provider/s2/sentinel2.py index 4df7898..6521d1d 100644 --- a/earthnet_minicuber/provider/s2/sentinel2.py +++ b/earthnet_minicuber/provider/s2/sentinel2.py @@ -39,11 +39,11 @@ class Sentinel2(provider_base.Provider): - def __init__(self, bands = ["AOT", "B01", "B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B09", "B11", "B12", "WVP"], best_orbit_filter = True, five_daily_filter = False, brdf_correction = True, cloud_mask = True, aws_bucket = "planetary_computer", s2_avail_var = True, correct_processing_baseline = True): + def __init__(self, bands = ["AOT", "B01", "B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B09", "B11", "B12", "WVP"], best_orbit_filter = True, five_daily_filter = False, brdf_correction = True, cloud_mask = True, cloud_mask_rescale_factor = None, aws_bucket = "planetary_computer", s2_avail_var = True, correct_processing_baseline = True): self.is_temporal = True - self.cloud_mask = CloudMask(bands=bands) if cloud_mask else None + self.cloud_mask = CloudMask(bands=bands, cloud_mask_rescale_factor = cloud_mask_rescale_factor) if cloud_mask else None if self.cloud_mask and "SCL" not in bands: bands += ["SCL"] @@ -201,6 +201,9 @@ def load_data(self, bbox, time_interval, **kwargs): stack = stack.sel(time = stack.time.dt.date.isin(dates)) + if len(stack.time) == 0: + return None + if self.correct_processing_baseline: stack = correct_processing_baseline(stack, items_s2) diff --git a/setup.py b/setup.py index 6b43870..7d45e42 100644 --- a/setup.py +++ b/setup.py @@ -31,7 +31,7 @@ setup(name='earthnet-minicuber', - version='0.1.2', + version='0.1.3', description="EarthNet Minicuber", author="Vitus Benson, Christian Requena-Mesa", author_email="vbenson@bgc-jena.mpg.de",