
Improve smart read #51

Merged · 13 commits · Jun 16, 2023
14 changes: 7 additions & 7 deletions seaduck/eulerian.py
@@ -7,7 +7,7 @@
 from seaduck.ocedata import HRel, OceData, RelCoord, TRel, VlRel, VRel
 
 # from OceInterp.kernel_and_weight import _translate_to_tendency,find_pk_4d
-from seaduck.smart_read import smart_read as sread
+from seaduck.smart_read import smart_read
 from seaduck.utils import (
     _general_len,
     find_px_py,
@@ -495,7 +495,7 @@ def _get_needed(
                 "have all the dimensions needed"
             )
         if prefetched is None:
-            return sread(self.ocedata[varName], ind)
+            return smart_read(self.ocedata[varName], ind)
         else:
             return prefetched[ind]

@@ -994,7 +994,7 @@ def _read_data_and_register(
                 temp_ind = _subtract_i_min(ind, i_min)
                 needed = np.nan_to_num(prefetched[temp_ind])
             else:
-                needed = np.nan_to_num(sread(self.ocedata[varName], ind))
+                needed = np.nan_to_num(smart_read(self.ocedata[varName], ind))
             data_lookup[hs] = needed
         elif isinstance(varName, tuple):
             uname, vname = varName
@@ -1013,10 +1013,10 @@ def _read_data_and_register(
                 vfromu = np.nan_to_num(upre[_subtract_i_min(indvfromu, i_min)])
                 vfromv = np.nan_to_num(vpre[_subtract_i_min(indvfromv, i_min)])
             else:
-                ufromu = np.nan_to_num(sread(self.ocedata[uname], indufromu))
-                ufromv = np.nan_to_num(sread(self.ocedata[vname], indufromv))
-                vfromu = np.nan_to_num(sread(self.ocedata[uname], indvfromu))
-                vfromv = np.nan_to_num(sread(self.ocedata[vname], indvfromv))
+                ufromu = np.nan_to_num(smart_read(self.ocedata[uname], indufromu))
+                ufromv = np.nan_to_num(smart_read(self.ocedata[vname], indufromv))
+                vfromu = np.nan_to_num(smart_read(self.ocedata[uname], indvfromu))
+                vfromv = np.nan_to_num(smart_read(self.ocedata[vname], indvfromv))
             temp_n_u[bool_ufromu] = ufromu  # 0#
             temp_n_u[bool_ufromv] = ufromv  # 1#
             temp_n_v[bool_vfromu] = vfromu  # 0#
144 changes: 79 additions & 65 deletions seaduck/smart_read.py
@@ -1,84 +1,98 @@
-from collections import OrderedDict as orderdic
-
+import dask.array
 import numpy as np
-import xarray as xr
 
 
-def smart_read(da, ind, memory_chunk=3, xarray_more_efficient=100):
-    """Read from a xarray.DataArray given tuple indexes.
+def slice_data_and_shift_indexes(da, indexes_tuple):
+    """Slice data using min/max indexes, and shift indexes."""
+    slicers = ()
+    for indexes, size in zip(indexes_tuple, da.shape):
+        start = indexes.min() or None
+        stop = stop if (stop := indexes.max() + 1) < size else None
+        slicers += (slice(start, stop),)
+    indexes_tuple = tuple(
+        indexes.ravel() - slicer.start if slicer.start else indexes.ravel()
+        for indexes, slicer in zip(indexes_tuple, slicers)
+    )
+    return da.data[slicers], indexes_tuple
+
+
+def smart_read(da, indexes_tuple, dask_more_efficient=100, chunks="auto"):
+    """Read from a xarray.DataArray given a tuple of indexes.
 
     Try to do it fast and smartly.
-    There is a lot of improvement to be made here,
-    but this is how it is currently done.
+    The data we read is going to be unstructured, but it tends to be
+    rather localized. For example, the lagrangian particles read data
+    from the same time step.
+    This function figures out which chunks store the data, converts them
+    into numpy arrays, and then reads the data from the converted ones.
 
     Parameters
     ----------
     da: xarray.DataArray
         DataArray to read from
-    ind: tuple of numpy.ndarray
+    indexes_tuple: tuple of numpy.ndarray
         The indexes of points of interest, each element does not need to be 1D
-    memory_chunk: int, default 3
-        If the number of chunks needed is smaller than this, read all of them at once.
-    xarray_more_efficient: int, default 100
+    dask_more_efficient: int, default 100
         When the number of chunks is larger than this, and the data points are few,
-        it may make sense to directly use xarray's vectorized read.
+        it may make sense to directly use dask's vectorized read.
+    chunks: int, str, default: "auto"
+        Chunks for indexes
 
     Returns
     -------
-        The values of the points of interest. Has the same shape as the elements in ind.
+    values: numpy.ndarray
+        The values of the points of interest. Has the same shape as the elements in indexes_tuple.
     """
-    the_shape = ind[0].shape
-    ind = tuple(i.ravel() for i in ind)
-    if len(da.dims) != len(ind):
-        raise ValueError("index does not match the number of dimensions")
-    if da.chunks is None or da.chunks == {}:
-        npck = np.array(da)
-        return npck[ind].reshape(the_shape)
-    if (
-        np.prod([len(i) for i in da.chunks]) <= memory_chunk
-    ):  # if the number of chunks is small don't bother
-        npck = np.array(da)
-        return npck[ind].reshape(the_shape)
-    cksz = orderdic(zip(da.dims, da.chunks))
-    keys = list(cksz.keys())
-    n = len(ind[0])
-    result = np.zeros(n)
+    if len(indexes_tuple) != da.ndim:
+        raise ValueError(
+            "indexes_tuple does not match the number of dimensions: "
+            f"{len(indexes_tuple)} vs {da.ndim}"
+        )
+
+    shape = indexes_tuple[0].shape
+    size = indexes_tuple[0].size
+    if not size:
+        return np.empty(shape)  # TODO: Why shape (0, ...) is allowed and tested?

+    data, indexes_tuple = slice_data_and_shift_indexes(da, indexes_tuple)
+    if isinstance(data, np.ndarray):
+        return data[indexes_tuple].reshape(shape)
+
+    if dask.array.empty(size, chunks=chunks).numblocks[0] > 1:
+        indexes_tuple = tuple(
+            dask.array.from_array(indexes, chunks=chunks) for indexes in indexes_tuple
+        )
+
+    block_dict = {}
+    for block_ids in np.ndindex(*data.numblocks):
+        if len(block_dict) >= dask_more_efficient:
+            return (
+                data.vindex[tuple(map(dask.array.compute, indexes_tuple))]
+                .compute()
+                .reshape(shape)
+            )
 
-    new_dic = {}
-    # typically what happens is that the first few indexes are chunked;
-    # here we figure out what is the last dimension chunked.
-    for i in range(len(cksz) - 1, -1, -1):
-        if len(cksz[keys[i]]) > 1:
-            last = i
-            break
+        shifted_indexes = []
+        mask = True
+        for block_id, indexes, chunks in zip(block_ids, indexes_tuple, data.chunks):
+            shifted = indexes - sum(chunks[:block_id])
+            block_mask = (shifted >= 0) & (shifted < chunks[block_id])
+            if not block_mask.any() or not (mask := mask & block_mask).any():
+                break  # empty block
+            shifted_indexes.append(shifted)
+        else:
+            block_dict[block_ids] = (
+                np.nonzero(mask),
+                tuple(indexes[mask] for indexes in shifted_indexes),
+            )
+            if sum([len(v[0]) for v in block_dict.values()]) == size:
+                break  # all blocks found
 
-    ckbl = np.zeros((n, i + 1)).astype(int)
-    # register each dimension and the chunk they are in
-    for i, k in enumerate(keys[: i + 1]):
-        ix = ind[i]
-        suffix = np.cumsum(cksz[k])
-        new_dic[i] = suffix
-        ckbl[:, i] = np.searchsorted(suffix, ix, side="right")
-    # this is the time-limiting step for a localized long query.
-    ckus, inverse = np.unique(ckbl, axis=0, return_inverse=True)
-    # ckus is the individual chunks used
-    if len(ckus) <= xarray_more_efficient:
-        # logging.debug('use smart')
-        for i, k in enumerate(ckus):
-            ind_str = []
-            pre = []
-            which = inverse == i
-            for j, p in enumerate(k):
-                sf = new_dic[j][p]  # the upper bound of index
-                pr = sf - cksz[keys[j]][p]  # the lower bound of index
-                ind_str.append(slice(pr, sf))
-                pre.append(pr)
-            prs = np.zeros(len(keys)).astype(int)
-            prs[: last + 1] = pre
-            npck = np.array(da[tuple(ind_str)])
-            subind = tuple(ind[dim][which] - prs[dim] for dim in range(len(ind)))
-            result[which] = npck[subind]
-        return result.reshape(the_shape)
-    else:
-        # logging.debug('use xarray')
-        xrind = tuple(xr.DataArray(dim, dims=["x"]) for dim in ind)
-        return np.array(da[xrind]).reshape(the_shape)
+    values = np.empty(size)
+    for block_ids, (values_indexes, block_indexes) in block_dict.items():
+        block_values = data.blocks[block_ids].compute()
+        values[values_indexes] = block_values[block_indexes]
+    return values.reshape(shape)
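As a quick illustration of the new helper (not part of this diff; the array sizes and index values below are invented), `slice_data_and_shift_indexes` first narrows the read to the bounding box of the requested points, so the chunk bookkeeping only ever sees the slice that matters:

```python
import dask.array
import numpy as np
import xarray as xr

from seaduck.smart_read import slice_data_and_shift_indexes

# Hypothetical 2-D chunked field; shape and chunk sizes are made up.
da = xr.DataArray(dask.array.zeros((100, 100), chunks=(10, 10)), dims=("y", "x"))
iy = np.array([[30, 31], [35, 39]])
ix = np.array([[60, 61], [62, 69]])

data, (iy2, ix2) = slice_data_and_shift_indexes(da, (iy, ix))
print(data.shape)            # (10, 10): only rows 30..39 and columns 60..69 are kept
print(iy2.min(), ix2.min())  # 0 0: indexes are raveled and shifted into the slice
```

Because the indexes are shifted into the sliced frame, the block loop in `smart_read` can compare them directly against the cumulative chunk sizes of the sliced dask array.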
8 changes: 4 additions & 4 deletions tests/test_smart_read.py
@@ -1,7 +1,7 @@
 import numpy as np
 import pytest
 
-from seaduck.smart_read import smart_read as srd
+from seaduck.smart_read import smart_read
 
 
 @pytest.fixture
@@ -19,16 +19,16 @@ def ind():
 @pytest.mark.parametrize("ds", ["ecco"], indirect=True)
 def test_just_read(ind, ds, chunk):
     ds["SALT"] = ds["SALT"].chunk(chunk)
-    srd(ds["SALT"], ind)
+    smart_read(ds["SALT"], ind)
 
 
 @pytest.mark.parametrize("ds", ["ecco"], indirect=True)
 def test_read_xarray(ind, ds):
     ds["SALT"] = ds["SALT"].chunk({"time": 1})
-    srd(ds["SALT"], ind, xarray_more_efficient=1)
+    smart_read(ds["SALT"], ind, dask_more_efficient=1)
 
 
 @pytest.mark.parametrize("ds", ["ecco"], indirect=True)
 def test_mismatch_read(ind, ds):
     with pytest.raises(ValueError):
-        srd(ds["XC"], ind)
+        smart_read(ds["XC"], ind)
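For reference, a minimal end-to-end sketch of the renamed entry point on the default (block-by-block) path; the synthetic DataArray, dimension names, and index values are invented, not from this PR. The result should agree with dask's own vectorized indexing:

```python
import dask.array
import numpy as np
import xarray as xr

from seaduck.smart_read import smart_read

# Synthetic 3-D field chunked along every dimension; sizes are made up.
salt = xr.DataArray(
    dask.array.random.random((4, 90, 90), chunks=(1, 45, 45)),
    dims=("time", "lat", "lon"),
)

# Scattered points of interest; the three index arrays share one shape.
it = np.array([0, 0, 1, 3])
iy = np.array([5, 80, 42, 7])
ix = np.array([3, 12, 60, 88])

values = smart_read(salt, (it, iy, ix))
# Same answer as dask's vectorized read, but fetched chunk by chunk.
assert np.allclose(values, salt.data.vindex[it, iy, ix].compute())
```

With only a handful of touched chunks this stays on the `block_dict` path; raising the chunk count past `dask_more_efficient` (as `test_read_xarray` does with `dask_more_efficient=1`) hands the work to dask's vectorized `vindex` instead.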