Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

No GDAL lib (tifffile replaces) #10

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 171 additions & 58 deletions safe_s1/metadata.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,136 @@
import os
import pdb
import re

import dask
import fsspec
import numpy as np
import rasterio
# import rasterio
import yaml
from affine import Affine
from rioxarray import rioxarray

import tifffile
import zarr
from . import sentinel1_xml_mappings
from .xml_parser import XmlParser
import xarray as xr
import dask.array as da
import datatree
import pandas as pd
import warnings


def compute_low_res_tiles(tile, spacing, posting, tile_width, resolution=None, window='GAUSSIAN'):
    """
    Compute a low resolution tile on a defined ground posting from a full resolution SLC tile.

    The tile is low-pass filtered (Gaussian or rectangular window), normalized by the
    filtered validity mask so that NaN/invalid pixels do not bias the result, then
    decimated by interpolation onto the requested posting.

    Code example:
        tile = dn.sel(pol='vv')
        mytile = tiles[0]
        spacing = {'sample': mytile['sampleSpacing'] / np.sin(np.radians(mytile['incidence'])),
                   'line': mytile['lineSpacing']}
        posting = {'sample': 400, 'line': 400}
        tile_width = {'sample': 17600., 'line': 17600.}
        low_res_tile = compute_low_res_tiles(mytile, spacing=spacing, posting=posting, tile_width=tile_width)

    Args:
        tile (xarray.DataArray): digital numbers (2D matrix) for a given channel (ie polarization)
        spacing (dict): GROUND spacing of provided tile. {name of dimension (str): spacing in [m] (float)}
        posting (dict): desired output posting. {name of dimension (str): spacing in [m] (float)}
        tile_width (dict): form {name of dimension (str): width in [m] (float)}. Desired width of the
            output tile (should be smaller or equal than provided data)
        resolution (dict, optional): resolution for filter. Default is twice the posting (Nyquist)
        window (str, optional): name of window used to smooth out the data. 'GAUSSIAN' and 'RECT'
            are valid entries

    Returns:
        tuple: (decimated digital_number (xarray.DataArray),
                ground range spacing (xarray.Dataset),
                azimuth spacing (xarray.Dataset))

    Raises:
        ValueError: if `window` is neither 'GAUSSIAN' nor 'RECT'
    """
    from scipy.signal import fftconvolve

    if resolution is None:
        # Nyquist criterion: filter at twice the output posting
        resolution = {d: 2 * v for d, v in posting.items()}

    mask = np.isfinite(tile)
    if window.upper() == 'GAUSSIAN':
        kernel_filter = gaussian_kernel(width=resolution, spacing=spacing)
    elif window.upper() == 'RECT':
        kernel_filter = rect_kernel(width=resolution, spacing=spacing)
    else:
        raise ValueError('Unknown window: {}'.format(window))

    # rename kernel dims so apply_ufunc treats tile dims and kernel dims as independent core dims
    swap_dims = {d: d + '_' for d in resolution.keys()}
    kernel_filter = kernel_filter.rename(swap_dims)
    # materialize dict views as lists: apply_ufunc may iterate core dims more than once
    tile_dims = list(resolution.keys())
    kernel_dims = list(swap_dims.values())

    # low-pass filter the tile with invalid pixels zeroed out ...
    low_pass = xr.apply_ufunc(fftconvolve, tile.where(mask, 0.), kernel_filter,
                              input_core_dims=[tile_dims, kernel_dims], vectorize=True,
                              output_core_dims=[tile_dims], kwargs={'mode': 'same'}, dask='allowed')
    # ... then normalize by the filtered validity mask so edges/NaN areas are unbiased
    normal = xr.apply_ufunc(fftconvolve, mask, kernel_filter,
                            input_core_dims=[tile_dims, kernel_dims],
                            vectorize=True, output_core_dims=[tile_dims], kwargs={'mode': 'same'},
                            dask='allowed')

    low_pass = low_pass / normal

    # ------- decimate -------
    # number of output pixels per dimension
    Np = {d: np.rint(tile_width[d] / posting[d]).astype(int) for d in tile_width.keys()}
    # output coordinates, centered on the tile middle, stepped by posting expressed in input pixel units
    new_line = xr.DataArray(
        int(low_pass['line'].isel(line=low_pass.sizes['line'] // 2)) + np.arange(-Np['line'] // 2,
                                                                                 Np['line'] // 2) * float(
            posting['line'] / spacing['line']), dims='line')  # previous azimuth
    new_sample = xr.DataArray(
        int(low_pass['sample'].isel(sample=low_pass.sizes['sample'] // 2)) + np.arange(-Np['sample'] // 2,
                                                                                       Np['sample'] // 2) * float(
            posting['sample'] / spacing['sample']), dims='sample')  # previous range
    decimated = low_pass.interp(sample=new_sample, line=new_line, assume_sorted=True).rename('digital_number')

    decimated.attrs.update(tile.attrs)

    range_spacing = xr.DataArray(posting['sample'], attrs={'units': 'm', 'long_name': 'ground range spacing'},
                                 name='range_spacing')
    azimuth_spacing = xr.DataArray(posting['line'], attrs={'units': 'm', 'long_name': 'azimuth spacing'},
                                   name='azimuth_spacing')
    decimated = decimated.transpose('line', 'sample', ...)
    return decimated, range_spacing.to_dataset(), azimuth_spacing.to_dataset()
def gaussian_kernel(width, spacing, truncate=3.):
    """
    Build a separable Gaussian kernel for low-pass filtering.

    `width` is the cutoff wavelength to keep; the Gaussian standard deviation is
    width / (2 * pi). Each 1D factor is truncated at +/- (truncate * sigma) and
    the final kernel is normalized so that its coefficients sum to 1.

    Args:
        width (dict): form {name of dimension (str): width in [m] (float)}
        spacing (dict): form {name of dimension (str): spacing in [m] (float)}
        truncate (float): gaussian shape is truncated at +/- (truncate x width) value

    Returns:
        (xarray.DataArray): normalized N-dimensional Gaussian kernel
    """
    # frequency cut-off carries a 2*pi factor
    sigmas = {d: w / (2 * np.pi) for d, w in width.items()}
    kernel = 1.
    for dim, sigma in sigmas.items():
        axis = np.arange(-truncate * sigma, truncate * sigma, spacing[dim])
        axis = xr.DataArray(axis, dims=dim, coords={dim: axis})
        kernel = kernel * np.exp(-axis ** 2 / (2 * sigma ** 2))
    kernel /= kernel.sum()
    return kernel


def rect_kernel(width, spacing):
    """
    Build a separable rectangular (boxcar) window kernel for filtering.

    The kernel is normalized so that its coefficients sum to 1.

    Args:
        width (dict): form {name of dimension (str): width in [m] (float)}
        spacing (dict): form {name of dimension (str): spacing in [m] (float)}

    Returns:
        (xarray.DataArray): normalized N-dimensional rectangular kernel
    """
    kernel = 1.
    for dim in width:
        axis = np.arange(-width[dim] / 2, width[dim] / 2, spacing[dim])
        factor = xr.DataArray(np.ones_like(axis), dims=dim, coords={dim: axis})
        kernel = kernel * factor
    kernel /= kernel.sum()
    return kernel

class Sentinel1Reader:

def __init__(self, name, backend_kwargs=None):
Expand Down Expand Up @@ -104,14 +218,37 @@ def __init__(self, name, backend_kwargs=None):
self.dt = datatree.DataTree.from_dict(self._dict)
assert self.dt==self.datatree

def load_digital_number(self, resolution=None, chunks=None, resampling=rasterio.enums.Resampling.rms):
def basic_open_tiff(self, files_measurement, map_dims):
    """
    Lazily open the measurement GeoTIFFs with tifffile/zarr (no GDAL/rasterio).

    Each file (one per polarization) is exposed as a zarr store, wrapped in a
    dask array with a leading length-1 'pol' axis, and all channels are then
    concatenated along 'pol'.

    Args:
        files_measurement (list of str): paths to the measurement tiff files,
            ordered like self.manifest_attrs['polarizations']
        map_dims (dict): its keys give the output dimension names
            (e.g. ('pol', 'line', 'sample'))

    Returns:
        (xarray.DataArray): digital numbers with dims taken from map_dims keys
    """
    out_dims = tuple(map_dims.keys())
    per_pol = []
    for path, pol in zip(files_measurement, self.manifest_attrs['polarizations']):
        store = tifffile.imread(path, aszarr=True)  # lazy zarr store over the tiff
        lazy = da.from_array(zarr.open(store, mode='r'))
        # prepend the 'pol' axis so channels can be concatenated
        with_pol = lazy.reshape((1, lazy.shape[0], lazy.shape[1]))
        per_pol.append(xr.DataArray(with_pol, dims=out_dims, coords={'pol': [pol]}))

    dn = xr.combine_nested(per_pol, concat_dim='pol')
    # create coordinates from dimension index
    dn = dn.assign_coords({'line': dn.line, 'sample': dn.sample})
    dn = dn.drop_vars('spatial_ref', errors='ignore')
    return dn


def load_digital_number(self, resolution=None, chunks=None):
"""
load digital_number from self.sar_meta.files['measurement'], as an `xarray.Dataset`.

Parameters
----------
resolution: None, numbers.Number, str or dict
resampling: rasterio.enums.Resampling

Returns
-------
Expand Down Expand Up @@ -164,8 +301,8 @@ def _get_glob(st):
}

if resolution is not None:
comment = 'resampled at "%s" with %s.%s.%s' % (
resolution, resampling.__module__, resampling.__class__.__name__, resampling.name)
comment = 'resampled at "%s" with %s' % (
resolution,'compute_low_res_tiles')
else:
comment = 'read at full resolution'

Expand All @@ -174,31 +311,17 @@ def _get_glob(st):
files_measurement = [os.path.join(self.path, f) for f in files_measurement]

# arbitrary rio object, to get shape, etc ... (will not be used to read data)
rio = rasterio.open(files_measurement[0])

#rio = rasterio.open(files_measurement[0])
metaobjzarr = zarr.open(tifffile.imread(files_measurement[0], aszarr=True), mode='r') # lazy load to get shape
chunks['pol'] = 1
# sort chunks keys like map_dims
chunks = dict(sorted(chunks.items(), key=lambda pair: list(map_dims.keys()).index(pair[0])))
chunks_rio = {map_dims[d]: chunks[d] for d in map_dims.keys()}
# chunks_rio = {map_dims[d]: chunks[d] for d in map_dims.keys()}
res = None
if resolution is None:
# using tiff driver: need to read individual tiff and concat them
# riofiles['rio'] is ordered like self.sar_meta.manifest_attrs['polarizations']

dn = xr.concat(
[
rioxarray.open_rasterio(
f, chunks=chunks_rio, parse_coordinates=False
) for f in files_measurement
], 'band'
).assign_coords(band=np.arange(len(self.manifest_attrs['polarizations'])) + 1)

# set dimensions names
dn = dn.rename(dict(zip(map_dims.values(), map_dims.keys())))

# create coordinates from dimension index (because of parse_coordinates=False)
dn = dn.assign_coords({'line': dn.line, 'sample': dn.sample})
dn = dn.drop_vars('spatial_ref', errors='ignore')
dn = self.basic_open_tiff(files_measurement,map_dims).chunk(chunks)
else:
if not isinstance(resolution, dict):
if isinstance(resolution, str) and resolution.endswith('m'):
Expand All @@ -211,47 +334,37 @@ def _get_glob(st):

# resample the DN at gdal level, before feeding it to the dataset
out_shape = (
int(rio.height / resolution['line']),
int(rio.width / resolution['sample'])
int(metaobjzarr.shape[0] / resolution['line']),
int(metaobjzarr.shape[1] / resolution['sample'])
)
out_shape_pol = (1,) + out_shape
# read resampled array in one chunk, and rechunk
# this doesn't optimize memory, but total size remain quite small

if isinstance(resolution['line'], int):
# legacy behaviour: winsize is the maximum full image size that can be divided by resolution (int)
winsize = (0, 0, rio.width // resolution['sample'] * resolution['sample'],
rio.height // resolution['line'] * resolution['line'])
window = rasterio.windows.Window(*winsize)
winsize = (0, 0, metaobjzarr.shape[1] // resolution['sample'] * resolution['sample'],
metaobjzarr.shape[0] // resolution['line'] * resolution['line'])
#window = rasterio.windows.Window(*winsize)
window_az = slice(winsize[1],winsize[3])
window_ra = slice(winsize[0],winsize[2])
else:
window = None

dn = xr.concat(
[
xr.DataArray(
dask.array.from_array(
rasterio.open(f).read(
out_shape=out_shape_pol,
resampling=resampling,
window=window
),
chunks=chunks_rio
),
dims=tuple(map_dims.keys()), coords={'pol': [pol]}
) for f, pol in
zip(files_measurement, self.manifest_attrs['polarizations'])
],
'pol'
).chunk(chunks)

# create coordinates at box center
translate = Affine.translation((resolution['sample'] - 1) / 2, (resolution['line'] - 1) / 2)
scale = Affine.scale(
rio.width // resolution['sample'] * resolution['sample'] / out_shape[1],
rio.height // resolution['line'] * resolution['line'] / out_shape[0])
sample, _ = translate * scale * (dn.sample, 0)
_, line = translate * scale * (0, dn.line)
dn = dn.assign_coords({'line': line, 'sample': sample})
window_az = None
window_ra = None
dn = self.basic_open_tiff(files_measurement, map_dims).chunk(chunks)
spacing = {'line':float(self.datatree['image'].ds['azimuthPixelSpacing']),
'sample':float(self.datatree['image'].ds['groundRangePixelSpacing'])}
posting = {'sample':resolution['sample'],
'line':resolution['line']} # here we considere that posting and resolution are equal, to keep interfaces, it can be revised in the future
# info: in xsarslc resolution = 2*posting, may be future evolution will introduce posting as optinal argument
tile_width = {'sample':metaobjzarr.shape[1],'line':metaobjzarr.shape[0]}

channels = []
for pol in dn.pol:
dn_channel_sep,range_spacing_ds,az_spacing_ds = compute_low_res_tiles(dn.sel(pol=pol), spacing,
posting, tile_width, resolution=None, window='GAUSSIAN')
channels.append(dn_channel_sep)
dn = xr.combine_nested(channels, concat_dim='pol')

# for GTiff driver, pols are already ordered. just rename them
dn = dn.assign_coords(pol=self.manifest_attrs['polarizations'])
Expand Down