Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New BatchSchema class to generate patch selectors and combine into batch selectors #132

Merged
merged 32 commits into from
Jan 3, 2023
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
0a5d0cc
Typing for generators and accessors
maxrjones Nov 18, 2022
022ebf7
Add recommendations from code review
maxrjones Nov 18, 2022
60501e5
Merge branch 'main' into generators-typing
maxrjones Nov 30, 2022
d199ee0
Merge branch 'main' into generators-typing
maxrjones Dec 1, 2022
4275828
More informative function name
maxrjones Dec 1, 2022
8787b7b
Use keyword only args for internal functions
maxrjones Dec 1, 2022
7ea67b6
Use dict over OrderedDict
maxrjones Dec 1, 2022
d15d6b1
Add type hint for _gen_slices output
maxrjones Dec 1, 2022
ec399bc
More informative function name
maxrjones Dec 1, 2022
1770d36
More informative variable name
maxrjones Dec 1, 2022
bcdcd72
More informative name for batch selectors
maxrjones Dec 1, 2022
cdd92e2
Type hint for _iterate_over_dimensions output
maxrjones Dec 1, 2022
cc9f04c
Fix batch_dims bug
maxrjones Dec 2, 2022
5ca52da
Merge branch 'main' into batch-dims-bug
maxrjones Dec 2, 2022
463546e
Support preload_batch parameter
maxrjones Dec 2, 2022
8b6e4aa
Merge branch 'main' into batch-dims-bug
maxrjones Dec 5, 2022
b320ba7
Merge branch 'main' into batch-dims-bug
maxrjones Dec 12, 2022
191a644
Try out dataclass for batch selectors
maxrjones Dec 14, 2022
6fdd8f0
Update comment
maxrjones Dec 15, 2022
3b33988
Merge branch 'main' into batch-dims-bug
maxrjones Dec 16, 2022
a0a5614
Remove xfail markers
maxrjones Dec 18, 2022
32f736f
Account for duplicate batch and input dims in testing utils
maxrjones Dec 18, 2022
d7ea564
Support duplicate dims in batch and input
maxrjones Dec 18, 2022
1eca395
Split _combine_patches_grouped_by_input_and_batch_dims()
maxrjones Dec 19, 2022
8a43ed6
Mark test with xfail
maxrjones Dec 19, 2022
98d5d19
Remove dataclass decorator
maxrjones Dec 19, 2022
b781280
Fix warning
maxrjones Dec 19, 2022
29f871b
Compute dask arrays before selecting on patches
maxrjones Dec 20, 2022
0b6c495
Remove NotImplementedError
maxrjones Dec 20, 2022
9d0fcb0
Fix case with more batch_dims than input_dims
maxrjones Dec 20, 2022
95e81b6
Merge branch 'main' into batch-dims-bug
maxrjones Dec 31, 2022
288f0dc
Update xbatcher/generators.py
maxrjones Jan 3, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion xbatcher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from . import testing # noqa: F401
from .accessors import BatchAccessor # noqa: F401
from .generators import BatchGenerator # noqa: F401
from .generators import BatchGenerator, BatchSchema # noqa: F401
from .util.print_versions import show_versions # noqa: F401

try:
Expand Down
343 changes: 302 additions & 41 deletions xbatcher/generators.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,262 @@
"""Classes for iterating through xarray datarrays / datasets in batches."""

import itertools
import warnings
from operator import itemgetter
from typing import Any, Dict, Hashable, Iterator, List, Sequence, Union

import numpy as np
import xarray as xr

# Type aliases used throughout this module:
# - PatchGenerator yields one ``.isel``-compatible selector (dim -> slice) per patch.
# - BatchSelector is the list of patch selectors that together form one batch.
# - BatchSelectorSet maps a batch index to its BatchSelector.
PatchGenerator = Iterator[Dict[Hashable, slice]]
BatchSelector = List[Dict[Hashable, slice]]
BatchSelectorSet = Dict[int, BatchSelector]


class BatchSchema:
    """
    A representation of the indices and stacking/transposing parameters needed
    to generate batches from Xarray Datasets and DataArrays using
    xbatcher.BatchGenerator.

    Parameters
    ----------
    ds : ``xarray.Dataset`` or ``xarray.DataArray``
        The data to iterate over. Unlike for the BatchGenerator, the data is
        not retained as a class attribute for the BatchSchema.
    input_dims : dict
        A dictionary specifying the size of the inputs in each dimension,
        e.g. ``{'lat': 30, 'lon': 30}``
        These are the dimensions the ML library will see. All other dimensions
        will be stacked into one dimension called ``sample``.
    input_overlap : dict, optional
        A dictionary specifying the overlap along each dimension
        e.g. ``{'lat': 3, 'lon': 3}``
    batch_dims : dict, optional
        A dictionary specifying the size of the batch along each dimension
        e.g. ``{'time': 10}``. These will always be iterated over.
    concat_input_bins : bool, optional
        If ``True``, the dimension chunks specified in ``input_dims`` will be
        concatenated and stacked into the ``sample`` dimension. The batch index
        will be included as a new level ``input_batch`` in the ``sample``
        coordinate.
        If ``False``, the dimension chunks specified in ``input_dims`` will be
        iterated over.
        Exposed on the instance as the ``concat_input_dims`` attribute.
    preload_batch : bool, optional
        If ``True``, each batch will be loaded into memory before reshaping /
        processing, triggering any dask arrays to be computed.

    Notes
    -----
    The BatchSchema is experimental and subject to change without notice.
    """

    def __init__(
        self,
        ds: Union[xr.Dataset, xr.DataArray],
        input_dims: Dict[Hashable, int],
        input_overlap: Union[Dict[Hashable, int], None] = None,
        batch_dims: Union[Dict[Hashable, int], None] = None,
        concat_input_bins: bool = True,
        preload_batch: bool = True,
    ):
        # Normalize ``None`` to fresh dicts rather than using mutable default
        # arguments, which would be shared across all instances.
        self.input_dims = dict(input_dims)
        self.input_overlap = {} if input_overlap is None else input_overlap
        self.batch_dims = dict(batch_dims) if batch_dims is not None else {}
        # NOTE(review): the parameter is spelled ``concat_input_bins`` (kept
        # for backward compatibility) but stored as ``concat_input_dims``.
        self.concat_input_dims = concat_input_bins
        self.preload_batch = preload_batch
        # Store helpful information based on arguments:
        # dimensions that appear in both ``batch_dims`` and ``input_dims`` ...
        self._duplicate_batch_dims: Dict[Hashable, int] = {
            dim: length
            for dim, length in self.batch_dims.items()
            if self.input_dims.get(dim) is not None
        }
        # ... dimensions that appear only in ``batch_dims`` ...
        self._unique_batch_dims: Dict[Hashable, int] = {
            dim: length
            for dim, length in self.batch_dims.items()
            if self.input_dims.get(dim) is None
        }
        # ... the stride (spacing between patch starts) per input dimension ...
        self._input_stride: Dict[Hashable, int] = {
            dim: length - self.input_overlap.get(dim, 0)
            for dim, length in self.input_dims.items()
        }
        # ... and every dimension that patches are sliced along. Order matters:
        # batch-only dims come first so they vary slowest during iteration.
        self._all_sliced_dims: Dict[Hashable, int] = dict(
            **self._unique_batch_dims, **self.input_dims
        )
        self.selectors: BatchSelectorSet = self._gen_batch_selectors(ds)

    def _gen_batch_selectors(
        self, ds: Union[xr.DataArray, xr.Dataset]
    ) -> BatchSelectorSet:
        """
        Create batch selectors dict, which can be used to create a batch
        from an xarray data object.
        """
        # Create an iterator that returns an object usable for .isel in xarray
        patch_selectors = self._gen_patch_selectors(ds)
        # Create the Dict containing batch selectors
        if self.concat_input_dims:  # Combine the patches into batches
            return self._combine_patches_into_batch(ds, patch_selectors)
        else:  # Each patch gets its own batch
            return {ind: [value] for ind, value in enumerate(patch_selectors)}

    def _gen_patch_selectors(
        self, ds: Union[xr.DataArray, xr.Dataset]
    ) -> PatchGenerator:
        """
        Create an iterator that can be used to index an Xarray Dataset/DataArray.
        """
        # Duplicated dims only influence batching when patches are combined;
        # warn that they are otherwise ignored.
        if self._duplicate_batch_dims and not self.concat_input_dims:
            warnings.warn(
                "The following dimensions were included in both ``input_dims`` "
                "and ``batch_dims``. Since ``concat_input_dims`` is ``False``, "
                f"these dimensions will not impact batch generation: {self._duplicate_batch_dims}"
            )
        # Generate the slices by iterating over batch_dims and input_dims
        return _iterate_through_dimensions(
            ds,
            dims=self._all_sliced_dims,
            overlap=self.input_overlap,
        )

    def _combine_patches_into_batch(
        self, ds: Union[xr.DataArray, xr.Dataset], patch_selectors: PatchGenerator
    ) -> BatchSelectorSet:
        """
        Combine the patch selectors to form a batch.

        Dispatches to the appropriate grouping strategy depending on whether
        ``batch_dims`` is set and whether any dims are duplicated between
        ``batch_dims`` and ``input_dims``.
        """
        # Check that patches are only combined with concat_input_dims
        if not self.concat_input_dims:
            raise AssertionError(
                "Patches should only be combined into batches when ``concat_input_dims`` is ``True``"
            )
        if not self.batch_dims:
            return self._combine_patches_into_one_batch(patch_selectors)
        elif self._duplicate_batch_dims:
            return self._combine_patches_grouped_by_input_and_batch_dims(
                ds=ds, patch_selectors=patch_selectors
            )
        else:
            return self._combine_patches_grouped_by_batch_dims(patch_selectors)

    def _combine_patches_into_one_batch(
        self, patch_selectors: PatchGenerator
    ) -> BatchSelectorSet:
        """
        Group all patches into a single batch (index ``0``).
        """
        return {0: list(patch_selectors)}

    def _combine_patches_grouped_by_batch_dims(
        self, patch_selectors: PatchGenerator
    ) -> BatchSelectorSet:
        """
        Group patches based on the unique slices for dimensions in ``batch_dims``.

        Relies on ``itertools.groupby`` semantics: patches sharing the same
        ``batch_dims`` slices must be contiguous in ``patch_selectors``, which
        holds because batch-only dims vary slowest in ``_all_sliced_dims``.
        """
        batch_selectors = [
            list(group)
            for _, group in itertools.groupby(
                patch_selectors, key=itemgetter(*self.batch_dims)
            )
        ]
        return dict(enumerate(batch_selectors))

    def _combine_patches_grouped_by_input_and_batch_dims(
        self, ds: Union[xr.DataArray, xr.Dataset], patch_selectors: PatchGenerator
    ) -> BatchSelectorSet:
        """
        Combine patches with multiple slices along ``batch_dims`` grouped into
        each patch. Required when a dimension is duplicated between ``batch_dims``
        and ``input_dims``.
        """
        self._gen_patch_numbers(ds)
        self._gen_batch_numbers(ds)
        batch_id_per_patch = self._get_batch_multi_index_per_patch()
        patch_in_range = self._get_batch_in_range_per_batch(
            batch_multi_index=batch_id_per_patch
        )
        batch_id_per_patch = self._ravel_batch_multi_index(batch_id_per_patch)
        batch_selectors = self._gen_empty_batch_selectors()
        # Assign each in-range patch to its (flattened) batch index.
        for i, patch in enumerate(patch_selectors):
            if patch_in_range[i]:
                batch_selectors[batch_id_per_patch[i]].append(patch)
        return batch_selectors

    def _gen_empty_batch_selectors(self) -> BatchSelectorSet:
        """
        Create an empty batch selector set that can be populated by appending
        patches to each batch.
        """
        # np.prod (np.product was removed in NumPy 2.0); int() guards against
        # the float64 returned for an empty value list.
        n_batches = int(np.prod(list(self._n_batches_per_dim.values())))
        return {k: [] for k in range(n_batches)}

    def _gen_patch_numbers(self, ds: Union[xr.DataArray, xr.Dataset]):
        """
        Calculate the number of patches per dimension and the number of patches
        in each batch per dimension.
        """
        self._n_patches_per_batch: Dict[Hashable, int] = {
            dim: int(np.ceil(length / self._input_stride.get(dim, length)))
            for dim, length in self.batch_dims.items()
        }
        self._n_patches_per_dim: Dict[Hashable, int] = {
            dim: int(
                (ds.sizes[dim] - self.input_overlap.get(dim, 0))
                // (length - self.input_overlap.get(dim, 0))
            )
            for dim, length in self._all_sliced_dims.items()
        }

    def _gen_batch_numbers(self, ds: Union[xr.DataArray, xr.Dataset]):
        """
        Calculate the number of batches per dimension.
        """
        # Dims absent from batch_dims contribute exactly one batch each.
        self._n_batches_per_dim: Dict[Hashable, int] = {
            dim: int(ds.sizes[dim] // self.batch_dims.get(dim, ds.sizes[dim]))
            for dim in self._all_sliced_dims.keys()
        }

    def _get_batch_multi_index_per_patch(self):
        """
        Calculate the batch multi-index for each patch.

        Returns an array of shape ``(n_dims, n_patches)`` giving, per
        dimension, the batch index that each patch falls into.
        """
        batch_id_per_dim: Dict[Hashable, Any] = {
            dim: np.floor(
                np.arange(0, n_patches)
                # Using n_patches + 1 as the divisor maps every patch to
                # batch 0 for dims that are not in batch_dims.
                / self._n_patches_per_batch.get(dim, n_patches + 1)
            ).astype(np.int64)
            for dim, n_patches in self._n_patches_per_dim.items()
        }
        batch_id_per_patch = np.array(
            list(itertools.product(*batch_id_per_dim.values()))
        ).transpose()
        return batch_id_per_patch

    def _ravel_batch_multi_index(self, batch_multi_index):
        """
        Convert the batch multi-index to a flat index for each patch.
        """
        return np.ravel_multi_index(
            multi_index=batch_multi_index,
            dims=tuple(self._n_batches_per_dim.values()),
            # Out-of-range patches are clipped here; they are excluded
            # separately by _get_batch_in_range_per_batch.
            mode="clip",
        )

    def _get_batch_in_range_per_batch(self, batch_multi_index):
        """
        Determine whether each patch is contained within any of the batches.
        """
        batch_id_maximum = np.fromiter(self._n_batches_per_dim.values(), dtype=int)
        # Pad with 1s so dims that have patches but no batches still
        # participate in the per-dimension comparison below.
        batch_id_maximum = np.pad(
            batch_id_maximum,
            (0, (len(self._n_patches_per_dim) - len(self._n_batches_per_dim))),
            constant_values=1,
        )
        batch_id_maximum = batch_id_maximum[:, np.newaxis]
        batch_in_range_per_patch = np.all(batch_multi_index < batch_id_maximum, axis=0)
        return batch_in_range_per_patch


def _gen_slices(*, dim_size: int, slice_size: int, overlap: int = 0) -> List[slice]:
# return a list of slices to chop up a single dimension
Expand Down Expand Up @@ -126,21 +378,41 @@ def __init__(
):

self.ds = ds
self.input_dims = dict(input_dims)
self.input_overlap = input_overlap
self.batch_dims = dict(batch_dims)
self.concat_input_dims = concat_input_dims
self.preload_batch = preload_batch
self._batch_selectors: Dict[
int, Any
] = self._gen_batch_selectors() # dict cache for batches
self._batch_selectors: BatchSchema = BatchSchema(
ds,
input_dims=input_dims,
input_overlap=input_overlap,
batch_dims=batch_dims,
concat_input_bins=concat_input_dims,
preload_batch=preload_batch,
)

@property
def input_dims(self):
    """Mapping of input dimension name to patch size, delegated to the BatchSchema."""
    return self._batch_selectors.input_dims

@property
def input_overlap(self):
    """Mapping of dimension name to overlap between patches, delegated to the BatchSchema."""
    return self._batch_selectors.input_overlap

@property
def batch_dims(self):
    """Mapping of batch dimension name to length, delegated to the BatchSchema."""
    return self._batch_selectors.batch_dims

@property
def concat_input_dims(self):
    """Whether input-dim patches are combined into batches, delegated to the BatchSchema."""
    return self._batch_selectors.concat_input_dims

@property
def preload_batch(self):
    """Whether each batch is loaded into memory before processing, delegated to the BatchSchema."""
    return self._batch_selectors.preload_batch

def __iter__(self) -> Iterator[Union[xr.DataArray, xr.Dataset]]:
    """Yield every batch, in batch-index order."""
    yield from (self[key] for key in self._batch_selectors.selectors)

def __len__(self) -> int:
    """Return the number of batches this generator produces."""
    selectors = self._batch_selectors.selectors
    return len(selectors)

def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]:

Expand All @@ -150,17 +422,28 @@ def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]:
)

if idx < 0:
idx = list(self._batch_selectors)[idx]
idx = list(self._batch_selectors.selectors)[idx]

if idx in self._batch_selectors:
if idx in self._batch_selectors.selectors:

if self.concat_input_dims:
new_dim_suffix = "_input"
all_dsets: List = []
for ds_input_select in self._batch_selectors[idx]:
batch_selector = {}
for dim in self._batch_selectors.batch_dims.keys():
starts = [
x[dim].start for x in self._batch_selectors.selectors[idx]
]
stops = [x[dim].stop for x in self._batch_selectors.selectors[idx]]
batch_selector[dim] = slice(min(starts), max(stops))
batch_ds = self.ds.isel(batch_selector)
if self.preload_batch:
batch_ds.load()
for selector in self._batch_selectors.selectors[idx]:
patch_ds = self.ds.isel(selector)
all_dsets.append(
_drop_input_dims(
self.ds.isel(**ds_input_select),
patch_ds,
self.input_dims,
suffix=new_dim_suffix,
)
Expand All @@ -169,34 +452,12 @@ def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]:
new_input_dims = [str(dim) + new_dim_suffix for dim in self.input_dims]
return _maybe_stack_batch_dims(dsc, new_input_dims)
else:

batch_ds = self.ds.isel(self._batch_selectors.selectors[idx][0])
if self.preload_batch:
batch_ds.load()
return _maybe_stack_batch_dims(
self.ds.isel(**self._batch_selectors[idx]), list(self.input_dims)
batch_ds,
list(self.input_dims),
)
else:
raise IndexError("list index out of range")

def _gen_batch_selectors(self) -> dict:
    """
    Eagerly build the dict mapping batch index to batch selector(s).

    NOTE(review): ``ds_batch`` is computed (and optionally loaded) for every
    batch selector but is never used when building ``batches`` — the patch
    selectors come from ``self._iterate_input_dims()`` over the full dataset,
    so ``batch_dims`` do not actually constrain the generated patches.
    Presumably the batch_dims bug that the BatchSchema rewrite addresses;
    confirm against the new implementation.
    """
    # in the future, we will want to do the batch generation lazily
    # going the eager route for now is allowing me to fill out the loader api
    # but it is likely to perform poorly.
    batches = []
    for ds_batch_selector in self._iterate_batch_dims():
        ds_batch = self.ds.isel(**ds_batch_selector)
        if self.preload_batch:
            ds_batch.load()
        input_generator = self._iterate_input_dims()
        if self.concat_input_dims:
            # one batch per outer iteration, containing all input patches
            batches.append(list(input_generator))
        else:
            # each input patch becomes its own batch
            batches += list(input_generator)

    return dict(enumerate(batches))

def _iterate_batch_dims(self) -> Any:
    """Return an iterator of ``.isel`` selectors over ``batch_dims`` (no overlap)."""
    return _iterate_through_dimensions(self.ds, dims=self.batch_dims)

def _iterate_input_dims(self) -> Any:
    """Return an iterator of ``.isel`` selectors over ``input_dims`` with ``input_overlap``."""
    return _iterate_through_dimensions(
        self.ds, dims=self.input_dims, overlap=self.input_overlap
    )
Loading