Improve smart read #51

Merged · 13 commits · Jun 16, 2023
111 changes: 44 additions & 67 deletions seaduck/smart_read.py
@@ -1,84 +1,61 @@
-from collections import OrderedDict as orderdic
-
 import numpy as np
-import xarray as xr
 
 
-def smart_read(da, ind, memory_chunk=3, xarray_more_efficient=100):
-    """Read from a xarray.DataArray given tuple indexes.
+def smart_read(da, indexes_tuple, dask_more_efficient=100):
+    """Read from a xarray.DataArray given a tuple of indexes.
 
     Try to do it fast and smartly.
 
     Parameters
     ----------
     da: xarray.DataArray
         DataArray to read from
-    ind: tuple of numpy.ndarray
+    indexes_tuple: tuple of numpy.ndarray
         The indexes of points of interest, each element does not need to be 1D
-    memory_chunk: int, default 3
-        If the number of chunks needed is smaller than this, read all of them at once.
-    xarray_more_efficient: int, default 100
+    dask_more_efficient: int, default 100
         When the number of chunks is larger than this, and the data points are few,
-        it may make sense to directly use xarray's vectorized read.
+        it may make sense to directly use dask's vectorized read.
 
     Returns
     -------
-        The values of the points of interest. Has the same shape as the elements in ind.
+    values: numpy.ndarray
+        The values of the points of interest. Has the same shape as the elements in indexes_tuple.
     """
-    the_shape = ind[0].shape
-    ind = tuple(i.ravel() for i in ind)
-    if len(da.dims) != len(ind):
-        raise ValueError("index does not match the number of dimensions")
-    if da.chunks is None or da.chunks == {}:
-        npck = np.array(da)
-        return npck[ind].reshape(the_shape)
-    if (
-        np.prod([len(i) for i in da.chunks]) <= memory_chunk
-    ):  # if the number of chunks is small, don't bother
-        npck = np.array(da)
-        return npck[ind].reshape(the_shape)
-    cksz = orderdic(zip(da.dims, da.chunks))
-    keys = list(cksz.keys())
-    n = len(ind[0])
-    result = np.zeros(n)
-
-    new_dic = {}
-    # typically the first few dimensions are chunked;
-    # here we figure out which is the last chunked dimension.
-    for i in range(len(cksz) - 1, -1, -1):
-        if len(cksz[keys[i]]) > 1:
-            last = i
-            break
-
-    ckbl = np.zeros((n, i + 1)).astype(int)
-    # register, for each dimension, the chunk each point falls in
-    for i, k in enumerate(keys[: i + 1]):
-        ix = ind[i]
-        suffix = np.cumsum(cksz[k])
-        new_dic[i] = suffix
-        ckbl[:, i] = np.searchsorted(suffix, ix, side="right")
-    # this is the time-limiting step for a localized long query.
-    ckus, inverse = np.unique(ckbl, axis=0, return_inverse=True)
-    # ckus is the individual chunks used
-    if len(ckus) <= xarray_more_efficient:
-        # logging.debug('use smart')
-        for i, k in enumerate(ckus):
-            ind_str = []
-            pre = []
-            which = inverse == i
-            for j, p in enumerate(k):
-                sf = new_dic[j][p]  # the upper bound of the index
-                pr = sf - cksz[keys[j]][p]  # the lower bound of the index
-                ind_str.append(slice(pr, sf))
-                pre.append(pr)
-            prs = np.zeros(len(keys)).astype(int)
-            prs[: last + 1] = pre
-            npck = np.array(da[tuple(ind_str)])
-            subind = tuple(ind[dim][which] - prs[dim] for dim in range(len(ind)))
-            result[which] = npck[subind]
-        return result.reshape(the_shape)
-    else:
-        # logging.debug('use xarray')
-        xrind = tuple(xr.DataArray(dim, dims=["x"]) for dim in ind)
-        return np.array(da[xrind]).reshape(the_shape)
+    if len(indexes_tuple) != da.ndim:
+        raise ValueError(
+            "indexes_tuple does not match the number of dimensions: "
+            f"{len(indexes_tuple)} vs {da.ndim}"
+        )
+
+    shape = indexes_tuple[0].shape
+    size = indexes_tuple[0].size
+    indexes_tuple = tuple(indexes.ravel() for indexes in indexes_tuple)
+
+    if not da.chunks:
+        return da.values[indexes_tuple].reshape(shape)
+
+    data = da.data
+
+    found_count = 0
+    block_dict = {}
+    for block_ids in np.ndindex(*data.numblocks):
+        shifted_indexes = []
+        mask = True
+        for block_id, indexes, chunks in zip(block_ids, indexes_tuple, data.chunks):
+            shifted = indexes - sum(chunks[:block_id])
+            block_mask = (shifted >= 0) & (shifted < chunks[block_id])
+            if not block_mask.any() or not (mask := mask & block_mask).any():
+                break  # no point of interest falls in this block
+            shifted_indexes.append(shifted)
+        else:
+            block_dict[block_ids] = (mask, shifted_indexes)
+            if len(block_dict) >= dask_more_efficient:
+                return data.vindex[indexes_tuple].compute().reshape(shape)
+
+            if (found_count := found_count + mask.sum()) == size:
+                break  # every point has been assigned to a block
+
+    values = np.empty(size)
+    for block_ids, (mask, shifted_indexes) in block_dict.items():
+        block_values = data.blocks[block_ids].compute()
+        values[mask] = block_values[tuple(indexes[mask] for indexes in shifted_indexes)]
+    return values.reshape(shape)
2 changes: 1 addition & 1 deletion tests/test_smart_read.py
@@ -25,7 +25,7 @@ def test_just_read(ind, ds, chunk):
 @pytest.mark.parametrize("ds", ["ecco"], indirect=True)
 def test_read_xarray(ind, ds):
     ds["SALT"] = ds["SALT"].chunk({"time": 1})
-    srd(ds["SALT"], ind, xarray_more_efficient=1)
+    srd(ds["SALT"], ind, dask_more_efficient=1)
 
 
 @pytest.mark.parametrize("ds", ["ecco"], indirect=True)
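
The updated test pins dask_more_efficient=1, which forces the new fallback: the first populated block already reaches the threshold, so smart_read hands the whole query to dask's vindex instead of reading block by block. A self-contained sketch of the same behavior, with synthetic data standing in for the ecco fixture:

import numpy as np
import xarray as xr

from seaduck.smart_read import smart_read

da = xr.DataArray(np.random.rand(8, 8), dims=["time", "x"]).chunk({"time": 1})
ind = (np.arange(8), np.arange(8))

fallback = smart_read(da, ind, dask_more_efficient=1)  # dask vindex path
blockwise = smart_read(da, ind)  # default threshold of 100: per-block path
np.testing.assert_array_equal(fallback, blockwise)

Both paths must return identical values; only the access pattern differs.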