diff --git a/ci/environment-ci.yml b/ci/environment-ci.yml index f7f1dde3..6ebd985c 100644 --- a/ci/environment-ci.yml +++ b/ci/environment-ci.yml @@ -3,7 +3,6 @@ channels: - conda-forge - nodefaults dependencies: -- dask - jupyter-book - make - matplotlib diff --git a/environment.yml b/environment.yml index 7263e241..9c6da8d5 100644 --- a/environment.yml +++ b/environment.yml @@ -7,3 +7,4 @@ dependencies: - xarray - numpy - pandas +- dask diff --git a/pyproject.toml b/pyproject.toml index b27bc28a..3fb299ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "numpy", "pandas", "scipy", + "dask", "xarray" ] description = "A python package that interpolates data from ocean dataset from both Eulerian and Lagrangian perspective. " diff --git a/seaduck/eulerian.py b/seaduck/eulerian.py index efb060f3..3848f9f3 100644 --- a/seaduck/eulerian.py +++ b/seaduck/eulerian.py @@ -7,7 +7,7 @@ from seaduck.ocedata import HRel, OceData, RelCoord, TRel, VlRel, VRel # from OceInterp.kernel_and_weight import _translate_to_tendency,find_pk_4d -from seaduck.smart_read import smart_read as sread +from seaduck.smart_read import smart_read from seaduck.utils import ( _general_len, find_px_py, @@ -495,7 +495,7 @@ def _get_needed( "have all the dimensions needed" ) if prefetched is None: - return sread(self.ocedata[varName], ind) + return smart_read(self.ocedata[varName], ind) else: return prefetched[ind] @@ -994,7 +994,7 @@ def _read_data_and_register( temp_ind = _subtract_i_min(ind, i_min) needed = np.nan_to_num(prefetched[temp_ind]) else: - needed = np.nan_to_num(sread(self.ocedata[varName], ind)) + needed = np.nan_to_num(smart_read(self.ocedata[varName], ind)) data_lookup[hs] = needed elif isinstance(varName, tuple): uname, vname = varName @@ -1013,10 +1013,10 @@ def _read_data_and_register( vfromu = np.nan_to_num(upre[_subtract_i_min(indvfromu, i_min)]) vfromv = np.nan_to_num(vpre[_subtract_i_min(indvfromv, i_min)]) else: - ufromu = 
np.nan_to_num(sread(self.ocedata[uname], indufromu)) - ufromv = np.nan_to_num(sread(self.ocedata[vname], indufromv)) - vfromu = np.nan_to_num(sread(self.ocedata[uname], indvfromu)) - vfromv = np.nan_to_num(sread(self.ocedata[vname], indvfromv)) + ufromu = np.nan_to_num(smart_read(self.ocedata[uname], indufromu)) + ufromv = np.nan_to_num(smart_read(self.ocedata[vname], indufromv)) + vfromu = np.nan_to_num(smart_read(self.ocedata[uname], indvfromu)) + vfromv = np.nan_to_num(smart_read(self.ocedata[vname], indvfromv)) temp_n_u[bool_ufromu] = ufromu # 0# temp_n_u[bool_ufromv] = ufromv # 1# temp_n_v[bool_vfromu] = vfromu # 0# diff --git a/seaduck/smart_read.py b/seaduck/smart_read.py index f626b0d7..4d79652d 100644 --- a/seaduck/smart_read.py +++ b/seaduck/smart_read.py @@ -1,84 +1,98 @@ -from collections import OrderedDict as orderdic - +import dask.array import numpy as np -import xarray as xr -def smart_read(da, ind, memory_chunk=3, xarray_more_efficient=100): - """Read from a xarray.DataArray given tuple indexes. +def slice_data_and_shift_indexes(da, indexes_tuple): + """Slice data using min/max indexes, and shift indexes.""" + slicers = () + for indexes, size in zip(indexes_tuple, da.shape): + start = indexes.min() or None + stop = stop if (stop := indexes.max() + 1) < size else None + slicers += (slice(start, stop),) + indexes_tuple = tuple( + indexes.ravel() - slicer.start if slicer.start else indexes.ravel() + for indexes, slicer in zip(indexes_tuple, slicers) + ) + return da.data[slicers], indexes_tuple + + +def smart_read(da, indexes_tuple, dask_more_efficient=10, chunks="auto"): + """Read from a xarray.DataArray given a tuple indexes. Try to do it fast and smartly. + There is a lot of improvement to be made here, + but this is how it is currently done. + + The data we read is going to be unstructured but they tend to be + rather localized. For example, the lagrangian particles read data + from the same time step. 
+ This function figures out which chunks store the data, converts them + into numpy arrays, and then reads the data from the converted ones. Parameters ---------- da: xarray.DataArray DataArray to read from - ind: tuple of numpy.ndarray + indexes_tuple: tuple of numpy.ndarray The indexes of points of interest, each element does not need to be 1D - memory_chunk: int, default 3 - If the number of chunks needed is smaller than this, read all of them at once. - xarray_more_efficient: int, default 100 + dask_more_efficient: int, default 10 When the number of chunks is larger than this, and the data points are few, - it may make sense to directly use xarray's vectorized read. + it may make sense to directly use dask's vectorized read. + chunks: int, str, default: "auto" + Chunks for indexes Returns ------- + values: numpy.ndarray - The values of the points of interest. Has the same shape as the elements in ind. + The values of the points of interest. Has the same shape as the elements in indexes_tuple. """ - the_shape = ind[0].shape - ind = tuple(i.ravel() for i in ind) - if len(da.dims) != len(ind): - raise ValueError("index does not match the number of dimensions") - if da.chunks is None or da.chunks == {}: - npck = np.array(da) - return npck[ind].reshape(the_shape) - if ( - np.prod([len(i) for i in da.chunks]) <= memory_chunk - ): # if the number of chunks is small don't bother - npck = np.array(da) - return npck[ind].reshape(the_shape) - cksz = orderdic(zip(da.dims, da.chunks)) - keys = list(cksz.keys()) - n = len(ind[0]) - result = np.zeros(n) + if len(indexes_tuple) != da.ndim: + raise ValueError( + "indexes_tuple does not match the number of dimensions: " + f"{len(indexes_tuple)} vs {da.ndim}" + ) + + shape = indexes_tuple[0].shape + size = indexes_tuple[0].size + if not size: + return np.empty(shape) # TODO: Why shape (0, ...) is allowed and tested? 
+ + data, indexes_tuple = slice_data_and_shift_indexes(da, indexes_tuple) + if isinstance(data, np.ndarray): + return data[indexes_tuple].reshape(shape) + + if dask.array.empty(size, chunks=chunks).numblocks[0] > 1: + indexes_tuple = tuple( + dask.array.from_array(indexes, chunks=chunks) for indexes in indexes_tuple + ) + + block_dict = {} + for block_ids in np.ndindex(*data.numblocks): + if len(block_dict) >= dask_more_efficient: + return ( + data.vindex[tuple(map(dask.array.compute, indexes_tuple))] + .compute() + .reshape(shape) + ) - new_dic = {} - # typically what happens is that the first a few indexes are chunked - # here we figure out what is the last dimension chunked. - for i in range(len(cksz) - 1, -1, -1): - if len(cksz[keys[i]]) > 1: - last = i - break + shifted_indexes = [] + mask = None + for block_id, indexes, chunks in zip(block_ids, indexes_tuple, data.chunks): + shifted = indexes - sum(chunks[:block_id]) + block_mask = (shifted >= 0) & (shifted < chunks[block_id]) + if not (mask := block_mask if mask is None else mask & block_mask).any(): + break # empty block + shifted_indexes.append(shifted) + else: + block_dict[block_ids] = ( + np.nonzero(mask), + tuple(indexes[mask] for indexes in shifted_indexes), + ) + if sum([len(v[0]) for v in block_dict.values()]) == size: + break # all blocks found - ckbl = np.zeros((n, i + 1)).astype(int) - # register each each dimension and the chunk they are in - for i, k in enumerate(keys[: i + 1]): - ix = ind[i] - suffix = np.cumsum(cksz[k]) - new_dic[i] = suffix - ckbl[:, i] = np.searchsorted(suffix, ix, side="right") - # this is the time limiting step for localized long query. 
- ckus, inverse = np.unique(ckbl, axis=0, return_inverse=True) - # ckus is the individual chunks used - if len(ckus) <= xarray_more_efficient: - # logging.debug('use smart') - for i, k in enumerate(ckus): - ind_str = [] - pre = [] - which = inverse == i - for j, p in enumerate(k): - sf = new_dic[j][p] # the upperbound of index - pr = sf - cksz[keys[j]][p] # the lower bound of index - ind_str.append(slice(pr, sf)) - pre.append(pr) - prs = np.zeros(len(keys)).astype(int) - prs[: last + 1] = pre - npck = np.array(da[tuple(ind_str)]) - subind = tuple(ind[dim][which] - prs[dim] for dim in range(len(ind))) - result[which] = npck[subind] - return result.reshape(the_shape) - else: - # logging.debug('use xarray') - xrind = tuple(xr.DataArray(dim, dims=["x"]) for dim in ind) - return np.array(da[xrind]).reshape(the_shape) + values = np.empty(size) + for block_ids, (values_indexes, block_indexes) in block_dict.items(): + block_values = data.blocks[block_ids].compute() + values[values_indexes] = block_values[block_indexes] + return values.reshape(shape) diff --git a/tests/test_smart_read.py b/tests/test_smart_read.py index 2658743d..4481194c 100644 --- a/tests/test_smart_read.py +++ b/tests/test_smart_read.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from seaduck.smart_read import smart_read as srd +from seaduck.smart_read import smart_read @pytest.fixture @@ -19,16 +19,16 @@ def ind(): @pytest.mark.parametrize("ds", ["ecco"], indirect=True) def test_just_read(ind, ds, chunk): ds["SALT"] = ds["SALT"].chunk(chunk) - srd(ds["SALT"], ind) + smart_read(ds["SALT"], ind) @pytest.mark.parametrize("ds", ["ecco"], indirect=True) def test_read_xarray(ind, ds): ds["SALT"] = ds["SALT"].chunk({"time": 1}) - srd(ds["SALT"], ind, xarray_more_efficient=1) + smart_read(ds["SALT"], ind, dask_more_efficient=1) @pytest.mark.parametrize("ds", ["ecco"], indirect=True) def test_mismatch_read(ind, ds): with pytest.raises(ValueError): - srd(ds["XC"], ind) + smart_read(ds["XC"], ind)