diff --git a/seaduck/smart_read.py b/seaduck/smart_read.py
index afb68829..ad61dc8f 100644
--- a/seaduck/smart_read.py
+++ b/seaduck/smart_read.py
@@ -3,7 +3,7 @@ import numpy as np
 
 
 
-def smart_read(da, indexes_tuple, dask_more_efficient=100):
+def smart_read(da, indexes_tuple, dask_more_efficient=100, dense=1e7):
     """Read from a xarray.DataArray given a tuple indexes.
 
     Try to do it fast and smartly.
@@ -49,18 +49,33 @@ def smart_read(da, indexes_tuple, dask_more_efficient=100):
     reach_first_entry = False
     block_dict = {}
 
-    largest_indexes_of_chunk = []
-    for chunks in data.chunks:
-        largest_indexes_of_chunk.append(list(accumulate(chunks)))
-
-    smallest_indexes_of_dimension = []
+    min_indexes_of_dim = []
     for indexes in indexes_tuple:
-        smallest_indexes_of_dimension.append(np.min(indexes))
+        min_indexes_of_dim.append(np.min(indexes))
+
+    if dense:
+        max_indexes_of_dim = []
+        shifted_indexes = []
+        for idim, indexes in enumerate(indexes_tuple):
+            max_indexes_of_dim.append(np.max(indexes))
+            shifted_indexes.append(indexes - min_indexes_of_dim[idim])
+        # zip() is single-use; materialize it because it is iterated twice below.
+        minmax = list(zip(min_indexes_of_dim, max_indexes_of_dim))
+        slice_tuple = tuple(slice(mn, mx + 1) for mn, mx in minmax)
+        dense_block_size = np.prod([mx - mn + 1 for mn, mx in minmax])
+        if dense_block_size <= dense:
+            dense_block = data[slice_tuple].compute()
+            values = dense_block[tuple(shifted_indexes)]
+            return values.reshape(shape)
+
+    max_indexes_of_chunk = []
+    for chunks in data.chunks:
+        max_indexes_of_chunk.append(list(accumulate(chunks)))
 
     for block_ids in np.ndindex(*data.numblocks):
         if not reach_first_entry:
             for block_id, large_index, small_ind in zip(
-                block_ids, largest_indexes_of_chunk, smallest_indexes_of_dimension
+                block_ids, max_indexes_of_chunk, min_indexes_of_dim
             ):
                 if large_index[block_id] < small_ind:
                     break