Add 'crop' and 'save_datasets' to MultiScene #613

Merged
merged 14 commits on Feb 13, 2019
Changes from 10 commits
5 changes: 3 additions & 2 deletions doc/source/dev_guide/xarray_migration.rst
@@ -297,13 +297,14 @@ than creating a delayed function. Similar to delayed functions, the inputs to
the function are fully computed DataArrays or numpy arrays, but the function
receives only individual chunks of the dask array at a time. Note that
``map_blocks`` must be provided dask arrays and won't function properly on
XArray DataArrays. It is recommended that the function object passed to
``map_blocks`` **not** be an internal function (a function defined inside
another function) or it may be unserializable and cause issues in some
environments.

.. code-block:: python

my_new_arr = da.map_blocks(_complex_operation, my_dask_arr1, my_dask_arr2, dtype=my_dask_arr1.dtype)

http://dask.pydata.org/en/latest/array-api.html#dask.array.core.map_blocks
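
For example, a module-level helper serializes cleanly while a nested one may
not (a minimal sketch; ``_scale_block`` is an invented name, not part of
Satpy):

.. code-block:: python

    import dask.array as da

    def _scale_block(block, factor):
        # module-level: picklable, so safe for distributed schedulers
        return block * factor

    my_arr = da.ones((4, 4), chunks=2)
    scaled = da.map_blocks(_scale_block, my_arr, 2.0, dtype=my_arr.dtype)
    scaled.compute()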

Helpful functions
*****************

52 changes: 26 additions & 26 deletions satpy/composites/__init__.py
@@ -1269,6 +1269,32 @@ def __call__(self, datasets, optional_datasets=None, **info):
return super(RatioSharpenedRGB, self).__call__((r, g, b), **info)


def _mean4(data, offset=(0, 0), block_id=None):
    """Average every 2x2 window of *data*, aligned to *offset*, and repeat the result back to full resolution."""
    rows, cols = data.shape
    # assume all chunks except the first along each axis are aligned to the 2x2 grid
if block_id[0] == 0:
row_offset = offset[0] % 2
else:
row_offset = 0
if block_id[1] == 0:
col_offset = offset[1] % 2
else:
col_offset = 0
row_after = (row_offset + rows) % 2
col_after = (col_offset + cols) % 2
pad = ((row_offset, row_after), (col_offset, col_after))

rows2 = rows + row_offset + row_after
cols2 = cols + col_offset + col_after

av_data = np.pad(data, pad, 'edge')
new_shape = (int(rows2 / 2.), 2, int(cols2 / 2.), 2)
data_mean = np.nanmean(av_data.reshape(new_shape), axis=(1, 3))
data_mean = np.repeat(np.repeat(data_mean, 2, axis=0), 2, axis=1)
data_mean = data_mean[row_offset:row_offset + rows, col_offset:col_offset + cols]
return data_mean
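
A quick illustration of how this helper behaves when mapped over a dask array
(a hedged sketch, not part of the diff; the array values and chunking are
invented, and dask supplies ``block_id`` automatically):

import dask.array as da
import numpy as np

arr = da.from_array(np.arange(16.).reshape(4, 4), chunks=2)
# each aligned 2x2 window is replaced by its mean, repeated back to full size
res = da.map_blocks(_mean4, arr, offset=(0, 0), dtype=arr.dtype)
res.compute()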


class SelfSharpenedRGB(RatioSharpenedRGB):
"""Sharpen RGB with ratio of a band with a strided-version of itself.

@@ -1289,32 +1315,6 @@ class SelfSharpenedRGB(RatioSharpenedRGB):
@staticmethod
def four_element_average_dask(d):
"""Average every 4 elements (2x2) in a 2D array"""
def _mean4(data, offset=(0, 0), block_id=None):
rows, cols = data.shape
# we assume that the chunks except the first ones are aligned
if block_id[0] == 0:
row_offset = offset[0] % 2
else:
row_offset = 0
if block_id[1] == 0:
col_offset = offset[1] % 2
else:
col_offset = 0
row_after = (row_offset + rows) % 2
col_after = (col_offset + cols) % 2
pad = ((row_offset, row_after), (col_offset, col_after))

rows2 = rows + row_offset + row_after
cols2 = cols + col_offset + col_after

av_data = np.pad(data, pad, 'edge')
new_shape = (int(rows2 / 2.), 2, int(cols2 / 2.), 2)
data_mean = np.nanmean(av_data.reshape(new_shape), axis=(1, 3))
data_mean = np.repeat(np.repeat(data_mean, 2, axis=0), 2, axis=1)
data_mean = data_mean[row_offset:row_offset + rows,
col_offset:col_offset + cols]
return data_mean

try:
offset = d.attrs['area'].crop_offset
except (KeyError, AttributeError):
14 changes: 9 additions & 5 deletions satpy/composites/crefl_utils.py
@@ -327,16 +327,18 @@ def chand(phi, muv, mus, taur):
return rhoray, trdown, trup


def _sphalb_index(index_arr, sphalb0):
# FIXME: if/when dask can support lazy index arrays then remove this
return sphalb0[index_arr]


def atm_variables_finder(mus, muv, phi, height, tau, tO3, tH2O, taustep4sphalb, tO2=1.0):
tau_step = da.linspace(taustep4sphalb, MAXNUMSPHALBVALUES * taustep4sphalb, MAXNUMSPHALBVALUES,
chunks=int(MAXNUMSPHALBVALUES / 2))
sphalb0 = csalbr(tau_step)
taur = tau * da.exp(-height / SCALEHEIGHT)
rhoray, trdown, trup = chand(phi, muv, mus, taur)
if isinstance(height, xr.DataArray):
def _sphalb_index(index_arr, sphalb0):
# FIXME: if/when dask can support lazy index arrays then remove this
return sphalb0[index_arr]
sphalb = da.map_blocks(_sphalb_index, (taur / taustep4sphalb + 0.5).astype(np.int32).data, sphalb0.compute(),
dtype=sphalb0.dtype)
else:
@@ -380,6 +382,10 @@ def G_calc(zenith, a_coeff):
return (da.cos(da.deg2rad(zenith))+(a_coeff[0]*(zenith**a_coeff[1])*(a_coeff[2]-zenith)**a_coeff[3]))**-1


def _avg_elevation_index(avg_elevation, row, col):
return avg_elevation[row, col]


def run_crefl(refl, coeffs,
lon,
lat,
@@ -423,8 +429,6 @@ def run_crefl(refl, coeffs,
row[space_mask] = 0
col[space_mask] = 0

def _avg_elevation_index(avg_elevation, row, col):
return avg_elevation[row, col]
height = da.map_blocks(_avg_elevation_index, avg_elevation, row, col, dtype=avg_elevation.dtype)
height = xr.DataArray(height, dims=['y', 'x'])
# negative heights aren't allowed, clip to 0
139 changes: 110 additions & 29 deletions satpy/multiscene.py
@@ -33,6 +33,14 @@
from satpy.writers import get_enhanced_image
from satpy.dataset import combine_metadata, DatasetID
from itertools import chain
from threading import Thread

try:
# python 3
from queue import Queue
except ImportError:
# python 2
from Queue import Queue

try:
# new API
@@ -249,6 +257,10 @@ def load(self, *args, **kwargs):
"""Load the required datasets from the multiple scenes."""
self._generate_scene_func(self._scenes, 'load', False, *args, **kwargs)

def crop(self, *args, **kwargs):
"""Crop the multiscene and return a new cropped multiscene."""
return self._generate_scene_func(self._scenes, 'crop', True, *args, **kwargs)
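
A hedged usage sketch (the bounding box values are invented; positional and
keyword arguments are simply forwarded to each Scene's crop method):

mscn = MultiScene(scenes)  # 'scenes' is a hypothetical list of Scene objects
cropped = mscn.crop(ll_bbox=(-105.0, 40.0, -95.0, 50.0))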

def resample(self, destination=None, **kwargs):
"""Resample the multiscene."""
return self._generate_scene_func(self._scenes, 'resample', True, destination=destination, **kwargs)
@@ -270,6 +282,85 @@ def blend(self, blend_function=stack):

return new_scn

def _distribute_save_datasets(self, scenes_iter, client, batch_size=1, **kwargs):
"""Distribute save_datasets across a cluster."""
def load_data(q):
idx = 0
while True:
future_list = q.get()
if future_list is None:
break

# save_datasets shouldn't be returning anything
for future in future_list:
future.result()
log.info("Finished saving %d scenes", idx)
idx += 1
q.task_done()

input_q = Queue(batch_size if batch_size is not None else 1)
load_thread = Thread(target=load_data, args=(input_q,))
load_thread.start()

for scene in scenes_iter:
delayed = scene.save_datasets(compute=False, **kwargs)
if isinstance(delayed, (list, tuple)) and len(delayed) == 2:
# TODO Make this work for (source, target) datasets
# given a target, source combination
raise NotImplementedError("Distributed save_datasets does not support writers "
"that return (source, target) combinations at this time. Use "
"the non-distributed save_datasets instead.")
future = client.compute(delayed)
input_q.put(future)
input_q.put(None)

log.debug("Waiting for child thread to get saved results...")
load_thread.join()
log.debug("Child thread died successfully")

def _simple_save_datasets(self, scenes_iter, **kwargs):
"""Helper to simple run save_datasets on each Scene."""
for scn in scenes_iter:
scn.save_datasets(**kwargs)

def save_datasets(self, client=True, batch_size=1, **kwargs):
"""Run save_datasets on each Scene.

Note that some writers may not be multi-process friendly and may
produce unexpected results or fail by raising an exception. In
these cases ``client`` should be set to ``False``.
This is currently a known issue for basic 'geotiff' writer workloads.

Args:
batch_size (int): Number of scenes to compute at the same time.
This only has effect if the `dask.distributed` package is
installed. This will default to 1. Setting this to 0 or less
will attempt to process all scenes at once. This option should
be used with care to avoid memory issues when trying to
improve performance.
client (bool or dask.distributed.Client): Dask distributed client
to use for computation. If this is ``True`` (default) then
any existing clients will be used.
If this is ``False`` or ``None`` then a client will not be
created and ``dask.distributed`` will not be used. If this
is a dask ``Client`` object then it will be used for
distributed computation.
kwargs: Additional keyword arguments to pass to
:meth:`~satpy.scene.Scene.save_datasets`.
Note that ``compute`` cannot be provided.

"""
if 'compute' in kwargs:
raise ValueError("The 'compute' keyword argument can not be provided.")

client = self._get_client(client=client)

scenes = iter(self._scenes)
if client is not None:
self._distribute_save_datasets(scenes, client, batch_size=batch_size, **kwargs)
else:
self._simple_save_datasets(scenes, **kwargs)
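
A hedged usage sketch (the dataset name, writer, and output directory are all
invented; assumes dask.distributed is installed and a local cluster is wanted):

from dask.distributed import Client
from satpy import MultiScene

client = Client()  # local cluster; rediscovered automatically when client=True
mscn = MultiScene(scenes)  # 'scenes' is a hypothetical list of Scene objects
mscn.load(['overview'])
mscn.save_datasets(base_dir='/tmp/frames', writer='simple_image', batch_size=2)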

def _get_animation_info(self, all_datasets, filename, fill_value=None):
"""Determine filename and shape of animation to be created."""
valid_datasets = [ds for ds in all_datasets if ds is not None]
@@ -310,16 +401,24 @@ def _get_animation_frames(self, all_datasets, shape, fill_value=None,
data = data.transpose('y', 'x', 'bands')
yield data.data

def _get_client(self, client=True):
"""Determine what dask distributed client to use."""
client = client or None # convert False/None to None
if client is True:
try:
# get existing client
from dask.distributed import get_client
client = get_client()
except ImportError:
log.debug("'dask.distributed' library was not found, will "
"use simple serial processing.")
except ValueError:
log.warning("No dask distributed client was provided or found, "
"but distributed features were requested. Will use simple serial processing.")
return client

def _distribute_frame_compute(self, writers, frame_keys, frames_to_write, client, batch_size=1):
"""Use ``dask.distributed`` to compute multiple frames at a time."""
try:
# python 3
from queue import Queue
except ImportError:
# python 2
from Queue import Queue
from threading import Thread

def load_data(frame_gen, q):
for frame_arrays in frame_gen:
future_list = client.compute(frame_arrays)
@@ -391,7 +490,7 @@ def save_animation(self, filename, datasets=None, fps=10, fill_value=None,
fps (int): Frames per second for produced animation
fill_value (int): Value to use instead of creating an alpha band.
batch_size (int): Number of frames to compute at the same time.
This only has affect if the `dask.distributed` package is
This only has effect if the `dask.distributed` package is
installed. This will default to 1. Setting this to 0 or less
will attempt to process all frames at once. This option should
be used with care to avoid memory issues when trying to
@@ -403,7 +502,7 @@
is missing from a child scene.
client (bool or dask.distributed.Client): Dask distributed client
to use for computation. If this is ``True`` (default) then
any existing clients will be used or a new one created.
any existing clients will be used.
If this is ``False`` or ``None`` then a client will not be
created and ``dask.distributed`` will not be used. If this
is a dask ``Client`` object then it will be used for
@@ -449,22 +548,7 @@ def save_animation(self, filename, datasets=None, fps=10, fill_value=None,
frames[dataset_id] = data_to_write
writers[dataset_id] = writer

client = client or None # convert False/None to None
close_client = False
if client is True:
try:
# get existing client
from dask.distributed import get_client
client = get_client()
except ImportError:
log.debug("'dask.distributed' library was not found, will "
"use simple frame processing.")
except ValueError:
# create new client
from dask.distributed import Client
client = Client()
close_client = True

client = self._get_client(client=client)
# get an ordered list of frames
frame_keys, frames_to_write = list(zip(*frames.items()))
frames_to_write = zip(*frames_to_write)
@@ -475,6 +559,3 @@ def save_animation(self, filename, datasets=None, fps=10, fill_value=None,

for writer in writers.values():
writer.close()
if close_client:
log.debug("Closing dask client...")
client.close()
6 changes: 5 additions & 1 deletion satpy/readers/electrol_hrit.py
@@ -305,6 +305,10 @@ def calibrate(self, data, calibration):
logger.debug("Calibration time " + str(datetime.now() - tic))
return res

@staticmethod
def _getitem(block, lut):
return lut[block]

def _calibrate(self, data):
"""Visible/IR channel calibration."""
lut = self.prologue['ImageCalibration'][self.chid]
@@ -315,7 +319,7 @@ def _calibrate(self, data):
lut /= 1000
lut[0] = np.nan
# Dask/XArray don't support indexing in 2D (yet).
res = data.data.map_blocks(lambda block: lut[block], dtype=lut.dtype)
res = data.data.map_blocks(self._getitem, dtype=lut.dtype)
res = xr.DataArray(res, dims=data.dims,
attrs=data.attrs, coords=data.coords)
res = res.where(data > 0)
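
The motivation for replacing the old lambda with this named helper: internal
functions and lambdas may not serialize (the standard library pickle, for
instance, rejects lambdas), which can break map_blocks when data is shipped to
dask.distributed workers. A minimal demonstration, separate from the diff:

import pickle

def _getitem(block, lut):
    return lut[block]

pickle.dumps(_getitem)         # fine: module-level functions pickle by reference
pickle.dumps(lambda b: b + 1)  # raises pickle.PicklingError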
2 changes: 1 addition & 1 deletion satpy/readers/goes_imager_hrit.py
@@ -432,7 +432,7 @@ def _calibrate(self, data):
idx = self.mda['calibration_parameters']['indices']
val = self.mda['calibration_parameters']['values']
data.data = da.where(data.data == 0, np.nan, data.data)
ddata = data.data.map_blocks(lambda block: np.interp(block, idx, val), dtype=val.dtype)
ddata = data.data.map_blocks(np.interp, idx, val, dtype=val.dtype)
res = xr.DataArray(ddata,
dims=data.dims, attrs=data.attrs,
coords=data.coords)
12 changes: 5 additions & 7 deletions satpy/readers/hrit_jma.py
@@ -300,6 +300,10 @@ def _mask_space(self, data):
geomask = get_geostationary_mask(area=self.area)
return data.where(geomask)

@staticmethod
def _interp(arr, cal):
return np.interp(arr.ravel(), cal[:, 0], cal[:, 1]).reshape(arr.shape)

def calibrate(self, data, calibration):
"""Calibrate the data."""
tic = datetime.now()
@@ -310,13 +314,7 @@ def calibrate(self, data, calibration):
raise NotImplementedError("Can't calibrate to radiance.")
else:
cal = self.calibration_table

def interp(arr):
return np.interp(arr.ravel(),
cal[:, 0], cal[:, 1]).reshape(arr.shape)

res = data.data.map_blocks(interp, dtype=cal[:, 0].dtype)

res = data.data.map_blocks(self._interp, cal, dtype=cal[:, 0].dtype)
res = xr.DataArray(res,
dims=data.dims, attrs=data.attrs,
coords=data.coords)