Commit 2821a3f
slice and dask
malmans2 committed Jun 16, 2023
1 parent 3cd8c92 commit 2821a3f
Showing 1 changed file with 36 additions and 45 deletions.
seaduck/smart_read.py (36 additions, 45 deletions)
@@ -1,9 +1,22 @@
-from itertools import accumulate
 
 import dask.array
 import numpy as np
 
 
-def smart_read(da, indexes_tuple, dask_more_efficient=100, dense=1e7):
+def slice_data_and_shift_indexes(da, indexes_tuple):
+    """Slice data using min/max indexes, and shift indexes."""
+    slicers = ()
+    for indexes, size in zip(indexes_tuple, da.shape):
+        start = indexes.min() or None
+        stop = stop if (stop := indexes.max() + 1) < size else None
+        slicers += (slice(start, stop),)
+    indexes_tuple = tuple(
+        indexes.ravel() - slicer.start if slicer.start else indexes.ravel()
+        for indexes, slicer in zip(indexes_tuple, slicers)
+    )
+    return da.data[slicers], indexes_tuple
+
+
+def smart_read(da, indexes_tuple, dask_more_efficient=100, chunks="auto"):
     """Read from a xarray.DataArray given a tuple indexes.
     Try to do it fast and smartly.
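
The new helper computes the bounding box of the requested indexes along each axis, slices the underlying array down to that box, and shifts the (flattened) indexes into the box's coordinates. Note the `indexes.min() or None` idiom: a minimum of 0 becomes a `None` start, so no shift is applied on that axis. A minimal sketch of its behavior (illustrative, not part of the commit; assumes seaduck at this commit is importable):

    import numpy as np
    import xarray as xr

    from seaduck.smart_read import slice_data_and_shift_indexes

    # An un-chunked 6x6 DataArray; the requested points only touch
    # rows 2..4 and columns 1..3.
    da = xr.DataArray(np.arange(36).reshape(6, 6))
    indexes_tuple = (np.array([[2, 4], [3, 2]]), np.array([[1, 3], [2, 1]]))

    data, shifted = slice_data_and_shift_indexes(da, indexes_tuple)
    print(data.shape)  # (3, 3): only the bounding box da.data[2:5, 1:4] is kept
    print(shifted)     # flattened indexes shifted by the box origin (2, 1)
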
@@ -25,6 +38,8 @@ def smart_read(da, indexes_tuple, dask_more_efficient=100, dense=1e7):
     dask_more_efficient: int, default 100
         When the number of chunks is larger than this, and the data points are few,
         it may make sense to directly use dask's vectorized read.
+    chunks: int, str, default: "auto"
+        Chunks for indexes
 
     Returns
     -------
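
For context (not part of the diff), a call with the new signature might look like this; the array and index shapes are made up for illustration:

    import dask.array
    import numpy as np
    import xarray as xr

    from seaduck.smart_read import smart_read

    # A chunked 2-D DataArray and 2-D fancy indexes; the result has the
    # same shape as the index arrays.
    da = xr.DataArray(dask.array.random.random((100, 100), chunks=(25, 25)))
    iy = np.random.randint(0, 100, size=(5, 7))
    ix = np.random.randint(0, 100, size=(5, 7))

    values = smart_read(da, (iy, ix), dask_more_efficient=100, chunks="auto")
    print(values.shape)  # (5, 7)
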
@@ -39,47 +54,26 @@ def smart_read(da, indexes_tuple, dask_more_efficient=100, dense=1e7):
 
     shape = indexes_tuple[0].shape
     size = indexes_tuple[0].size
-    indexes_tuple = tuple(indexes.ravel() for indexes in indexes_tuple)
 
-    if not da.chunks:
-        return da.values[indexes_tuple].reshape(shape)
-    data = da.data
-
-    found_count = 0
-    reach_first_entry = False
-    block_dict = {}
+    if not size:
+        return np.empty(shape)  # TODO: Why shape (0, ...) is allowed and tested?
 
-    min_indexes_of_dim = []
-    for indexes in indexes_tuple:
-        min_indexes_of_dim.append(np.min(indexes))
+    data, indexes_tuple = slice_data_and_shift_indexes(da, indexes_tuple)
+    if isinstance(data, np.ndarray):
+        return data[indexes_tuple].reshape(shape)
-
-    if dense:
-        max_indexes_of_dim = []
-        shifted_indexes = []
-        for idim, indexes in enumerate(indexes_tuple):
-            max_indexes_of_dim.append(np.max(indexes))
-            shifted_indexes.append(indexes - min_indexes_of_dim[idim])
-        minmax = zip(min_indexes_of_dim, max_indexes_of_dim)
-        slice_tuple = tuple(slice(mn, mx + 1) for mn, mx in minmax)
-        dense_block_size = np.prod([mx - mn + 1 for mn, mx in minmax])
-        if dense_block_size <= dense:
-            dense_block = data[slice_tuple].compute()
-            values = dense_block[tuple(shifted_indexes)]
-            return values.reshape(shape)
-
-    max_indexes_of_chunk = []
-    for chunks in data.chunks:
-        max_indexes_of_chunk.append(list(accumulate(chunks)))
+    if dask.array.empty(size, chunks=chunks).numblocks[0] > 1:
+        indexes_tuple = tuple(
+            dask.array.from_array(indexes, chunks=chunks) for indexes in indexes_tuple
+        )
 
+    block_dict = {}
     for block_ids in np.ndindex(*data.numblocks):
-        if not reach_first_entry:
-            for block_id, large_index, small_ind in zip(
-                block_ids, max_indexes_of_chunk, min_indexes_of_dim
-            ):
-                if large_index[block_id] < small_ind:
-                    break
-            else:
-                reach_first_entry = True
+        if len(block_dict) >= dask_more_efficient:
+            return (
+                data.vindex[tuple(map(dask.array.compute, indexes_tuple))]
+                .compute()
+                .reshape(shape)
+            )
 
         shifted_indexes = []
         mask = True
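
The early return added inside the loop falls back to dask's `vindex`, dask's native pointwise fancy indexing: once `dask_more_efficient` or more blocks are involved, letting dask gather the points is cheaper than visiting blocks one by one. A standalone illustration of `vindex` itself (plain dask, not seaduck code):

    import dask.array
    import numpy as np

    # Element (r, c) holds 10 * r + c; four 5x5 blocks.
    arr = dask.array.arange(100, chunks=10).reshape(10, 10).rechunk(5)
    rows = np.array([0, 7, 3])
    cols = np.array([9, 2, 4])
    print(arr.vindex[rows, cols].compute())  # [ 9 72 34]
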
@@ -91,13 +85,10 @@ def smart_read(da, indexes_tuple, dask_more_efficient=100, dense=1e7):
             shifted_indexes.append(shifted)
         else:
             block_dict[block_ids] = (
-                np.argwhere(mask).squeeze(),
+                np.nonzero(mask),
                 tuple(indexes[mask] for indexes in shifted_indexes),
             )
-            if len(block_dict) >= dask_more_efficient:
-                return data.vindex[indexes_tuple].compute().reshape(shape)
-
-        if (found_count := found_count + mask.sum()) == size:
+        if sum([len(v[0]) for v in block_dict.values()]) == size:
             break  # all blocks found
 
     values = np.empty(size)
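
The rest of the loop (truncated above) masks the points that fall inside each block, records them in `block_dict`, and finally reads each touched block once to fill `values`. A rough standalone sketch of the core idea, mapping each requested point to its dask block with `searchsorted` instead of the module's mask arithmetic (illustrative assumptions throughout):

    from itertools import accumulate

    import numpy as np

    # Dask-style chunk sizes per axis: 4 row-blocks of 25, 2 column-blocks of 50.
    chunks = ((25, 25, 25, 25), (50, 50))
    bounds = [np.fromiter(accumulate(c), dtype=int) for c in chunks]

    iy = np.array([3, 30, 97, 60])
    ix = np.array([10, 75, 20, 99])

    # Block id of each point along each axis.
    block_ids = tuple(
        np.searchsorted(b, idx, side="right") for b, idx in zip(bounds, (iy, ix))
    )
    for k, (by, bx) in enumerate(zip(*block_ids)):
        print(f"point {k} lives in block ({by}, {bx})")
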
