diff --git a/safe_s1/metadata.py b/safe_s1/metadata.py index be567a3..684ccbb 100644 --- a/safe_s1/metadata.py +++ b/safe_s1/metadata.py @@ -1,22 +1,136 @@ import os +import pdb import re import dask import fsspec import numpy as np -import rasterio +# import rasterio import yaml from affine import Affine -from rioxarray import rioxarray - +import tifffile +import zarr from . import sentinel1_xml_mappings from .xml_parser import XmlParser import xarray as xr +import dask.array as da import datatree import pandas as pd import warnings +def compute_low_res_tiles(tile, spacing, posting, tile_width, resolution=None, window='GAUSSIAN'): + """ + Compute low resolution tiles on defined ground spacing based on full resolution SLC tile. + Code example: + tile = dn.sel(pol='vv') + mytile = tiles[0] + spacing = {'sample':mytile['sampleSpacing']/np.sin(np.radians(mytile['incidence'])), 'line':mytile['lineSpacing']} + posting = {'sample':400,'line':400} + tile_width = {'sample':17600.,'line':17600.} + low_res_tile = compute_low_res_tiles(mytile, spacing = spacing, posting = posting, tile_width = tile_width) + + Args: + tile (xarray.DataArray) : dataArray containing digital number (2D matrix) for a given channel (ie polarization) + spacing (dict): GROUND spacing of provided tile. {name of dimension (str): spacing in [m] (float)}. + posting (dict): Desired output posting. {name of dimension (str): spacing in [m] (float)}. + tile_width (dict): form {name of dimension (str): width in [m] (float)}. Desired width of the output tile (should be smaller or equal than provided data) + resolution (dict, optional): resolution for filter. default is twice the posting (Nyquist) + window (str, optional): Name of window used to smooth out the data. 'GAUSSIAN' and 'RECT' are valid entries + Returns: + (xarray.Dataset) : dataset of filtered/resampled sigma0 + """ + from scipy.signal import fftconvolve + # from xsarslc.tools import gaussian_kernel, rect_kernel + + if resolution is None: + resolution = {d: 2 * v for d, v in posting.items()} + + # sigma0 = tile['sigma0'] + mask = np.isfinite(tile) + if window.upper() == 'GAUSSIAN': + kernel_filter = gaussian_kernel(width=resolution, spacing=spacing) + elif window.upper() == 'RECT': + kernel_filter = rect_kernel(width=resolution, spacing=spacing) + else: + raise ValueError('Unknown window: {}'.format(window)) + swap_dims = {d: d + '_' for d in resolution.keys()} + kernel_filter = kernel_filter.rename(swap_dims) + low_pass = xr.apply_ufunc(fftconvolve, tile.where(mask, 0.), kernel_filter, + input_core_dims=[resolution.keys(), swap_dims.values()], vectorize=True, + output_core_dims=[resolution.keys()], kwargs={'mode': 'same'}, dask='allowed') + + normal = xr.apply_ufunc(fftconvolve, mask, kernel_filter, + input_core_dims=[resolution.keys(), swap_dims.values()], + vectorize=True, output_core_dims=[resolution.keys()], kwargs={'mode': 'same'}, + dask='allowed') + + low_pass = low_pass / normal + + # ------- decimate ------- + Np = {d: np.rint(tile_width[d] / posting[d]).astype(int) for d in tile_width.keys()} + new_line = xr.DataArray( + int(low_pass['line'].isel(line=low_pass.sizes['line'] // 2)) + np.arange(-Np['line'] // 2, + Np['line'] // 2) * float( + posting['line'] / spacing['line']), dims='line') # previous azimuth + new_sample = xr.DataArray( + int(low_pass['sample'].isel(sample=low_pass.sizes['sample'] // 2)) + np.arange(-Np['sample'] // 2, + Np['sample'] // 2) * float( + posting['sample'] / spacing['sample']), dims='sample') # previous range + decimated = low_pass.interp(sample=new_sample, line=new_line, assume_sorted=True).rename('digital_number') + + # decimated = decimated.drop_vars(['line', 'sample']) + # decimated = decimated.swap_dims({'azimuth':'line','range':'sample'}) + decimated.attrs.update(tile.attrs) + + range_spacing = xr.DataArray(posting['sample'], attrs={'units': 'm', 'long_name': 'ground range spacing'}, + name='range_spacing') + azimuth_spacing = xr.DataArray(posting['line'], attrs={'units': 'm', 'long_name': 'azimuth spacing'}, + name='azimuth_spacing') + # added_variables = [tile[v].to_dataset() for v in + # ['incidence', 'ground_heading', 'land_flag']] # add variables from L1B to output + # decimated = xr.merge( + # [decimated.to_dataset(), range_spacing.to_dataset(), + # azimuth_spacing.to_dataset()]) + # decimated = decimated.transpose('azimuth', 'range', ...) + decimated = decimated.transpose('line', 'sample', ...) + print('decimated',decimated) + return decimated,range_spacing.to_dataset(),azimuth_spacing.to_dataset() +def gaussian_kernel(width, spacing, truncate=3.): + """ + Compute a Gaussian kernel for filtering. The width correspond to the wavelength that is needed to be kept. The standard deviation of the gaussian has to be width/(2 pi) + + Args: + width (dict): form {name of dimension (str): width in [m] (float)} + spacing (dict): form {name of dimension (str): spacing in [m] (float)} + truncate (float): gaussian shape is truncate at +/- (truncate x width) value + """ + gk = 1. + width = {d: w / (2 * np.pi) for d, w in width.items()} # frequency cut off has a 2 pi factor + for d in width.keys(): + coord = np.arange(-truncate * width[d], truncate * width[d], spacing[d]) + coord = xr.DataArray(coord, dims=d, coords={d: coord}) + gk = gk * np.exp(-coord ** 2 / (2 * width[d] ** 2)) + gk /= gk.sum() + return gk + + +def rect_kernel(width, spacing): + """ + Compute a rectangular window kernel for filtering + + Args: + width (dict): form {name of dimension (str): width in [m] (float)} + spacing (dict): form {name of dimension (str): spacing in [m] (float)} + """ + wk = 1. + for d in width.keys(): + coord = np.arange(-width[d] / 2, width[d] / 2, spacing[d]) + win = xr.DataArray(np.ones_like(coord), dims=d, coords={d: coord}) + wk = wk * win + wk /= wk.sum() + return wk + class Sentinel1Reader: def __init__(self, name, backend_kwargs=None): @@ -104,14 +218,37 @@ def __init__(self, name, backend_kwargs=None): self.dt = datatree.DataTree.from_dict(self._dict) assert self.dt==self.datatree - def load_digital_number(self, resolution=None, chunks=None, resampling=rasterio.enums.Resampling.rms): + def basic_open_tiff(self,files_measurement,map_dims): + tmplist = [] + for f, pol in zip(files_measurement, self.manifest_attrs['polarizations']): + aa = tifffile.imread(f, aszarr=True) + bb = zarr.open(aa, mode='r') + cc = da.from_array(bb) + cc2 = cc.reshape((1, cc.shape[0], cc.shape[1])) + dimss = tuple(map_dims.keys()) + dd = xr.DataArray(cc2, dims=dimss, coords={'pol': [pol]}) + tmplist.append(dd) + + # dn = xr.concat(tmplist,dim='pol').assign_coords(band=np.arange(len(self.manifest_attrs['polarizations'])) + 1) + # dn = xr.merge(tmplist) + # dn = xr.combine_by_coords(tmplist) + dn = xr.combine_nested(tmplist, concat_dim='pol') + # set dimensions names + # dn = dn.rename(dict(zip(map_dims.values(), map_dims.keys()))) + + # create coordinates from dimension index (because of parse_coordinates=False) + dn = dn.assign_coords({'line': dn.line, 'sample': dn.sample}) + dn = dn.drop_vars('spatial_ref', errors='ignore') + return dn + + + def load_digital_number(self, resolution=None, chunks=None): """ load digital_number from self.sar_meta.files['measurement'], as an `xarray.Dataset`. Parameters ---------- resolution: None, numbers.Number, str or dict - resampling: rasterio.enums.Resampling Returns ------- @@ -164,8 +301,8 @@ def _get_glob(st): } if resolution is not None: - comment = 'resampled at "%s" with %s.%s.%s' % ( - resolution, resampling.__module__, resampling.__class__.__name__, resampling.name) + comment = 'resampled at "%s" with %s' % ( + resolution,'compute_low_res_tiles') else: comment = 'read at full resolution' @@ -174,31 +311,17 @@ def _get_glob(st): files_measurement = [os.path.join(self.path, f) for f in files_measurement] # arbitrary rio object, to get shape, etc ... (will not be used to read data) - rio = rasterio.open(files_measurement[0]) - + #rio = rasterio.open(files_measurement[0]) + metaobjzarr = zarr.open(tifffile.imread(files_measurement[0], aszarr=True), mode='r') # lazy load to get shape chunks['pol'] = 1 # sort chunks keys like map_dims chunks = dict(sorted(chunks.items(), key=lambda pair: list(map_dims.keys()).index(pair[0]))) - chunks_rio = {map_dims[d]: chunks[d] for d in map_dims.keys()} + # chunks_rio = {map_dims[d]: chunks[d] for d in map_dims.keys()} res = None if resolution is None: # using tiff driver: need to read individual tiff and concat them # riofiles['rio'] is ordered like self.sar_meta.manifest_attrs['polarizations'] - - dn = xr.concat( - [ - rioxarray.open_rasterio( - f, chunks=chunks_rio, parse_coordinates=False - ) for f in files_measurement - ], 'band' - ).assign_coords(band=np.arange(len(self.manifest_attrs['polarizations'])) + 1) - - # set dimensions names - dn = dn.rename(dict(zip(map_dims.values(), map_dims.keys()))) - - # create coordinates from dimension index (because of parse_coordinates=False) - dn = dn.assign_coords({'line': dn.line, 'sample': dn.sample}) - dn = dn.drop_vars('spatial_ref', errors='ignore') + dn = self.basic_open_tiff(files_measurement,map_dims).chunk(chunks) else: if not isinstance(resolution, dict): if isinstance(resolution, str) and resolution.endswith('m'): @@ -211,8 +334,8 @@ def _get_glob(st): # resample the DN at gdal level, before feeding it to the dataset out_shape = ( - int(rio.height / resolution['line']), - int(rio.width / resolution['sample']) + int(metaobjzarr.shape[0] / resolution['line']), + int(metaobjzarr.shape[1] / resolution['sample']) ) out_shape_pol = (1,) + out_shape # read resampled array in one chunk, and rechunk @@ -220,38 +343,28 @@ def _get_glob(st): if isinstance(resolution['line'], int): # legacy behaviour: winsize is the maximum full image size that can be divided by resolution (int) - winsize = (0, 0, rio.width // resolution['sample'] * resolution['sample'], - rio.height // resolution['line'] * resolution['line']) - window = rasterio.windows.Window(*winsize) + winsize = (0, 0, metaobjzarr.shape[1] // resolution['sample'] * resolution['sample'], + metaobjzarr.shape[0] // resolution['line'] * resolution['line']) + #window = rasterio.windows.Window(*winsize) + window_az = slice(winsize[1],winsize[3]) + window_ra = slice(winsize[0],winsize[2]) else: - window = None - - dn = xr.concat( - [ - xr.DataArray( - dask.array.from_array( - rasterio.open(f).read( - out_shape=out_shape_pol, - resampling=resampling, - window=window - ), - chunks=chunks_rio - ), - dims=tuple(map_dims.keys()), coords={'pol': [pol]} - ) for f, pol in - zip(files_measurement, self.manifest_attrs['polarizations']) - ], - 'pol' - ).chunk(chunks) - - # create coordinates at box center - translate = Affine.translation((resolution['sample'] - 1) / 2, (resolution['line'] - 1) / 2) - scale = Affine.scale( - rio.width // resolution['sample'] * resolution['sample'] / out_shape[1], - rio.height // resolution['line'] * resolution['line'] / out_shape[0]) - sample, _ = translate * scale * (dn.sample, 0) - _, line = translate * scale * (0, dn.line) - dn = dn.assign_coords({'line': line, 'sample': sample}) + window_az = None + window_ra = None + dn = self.basic_open_tiff(files_measurement, map_dims).chunk(chunks) + spacing = {'line':float(self.datatree['image'].ds['azimuthPixelSpacing']), + 'sample':float(self.datatree['image'].ds['groundRangePixelSpacing'])} + posting = {'sample':resolution['sample'], + 'line':resolution['line']} # here we considere that posting and resolution are equal, to keep interfaces, it can be revised in the future + # info: in xsarslc resolution = 2*posting, may be future evolution will introduce posting as optinal argument + tile_width = {'sample':metaobjzarr.shape[1],'line':metaobjzarr.shape[0]} + + channels = [] + for pol in dn.pol: + dn_channel_sep,range_spacing_ds,az_spacing_ds = compute_low_res_tiles(dn.sel(pol=pol), spacing, + posting, tile_width, resolution=None, window='GAUSSIAN') + channels.append(dn_channel_sep) + dn = xr.combine_nested(channels, concat_dim='pol') # for GTiff driver, pols are already ordered. just rename them dn = dn.assign_coords(pol=self.manifest_attrs['polarizations'])