From 221dfb076ecdd51b469a95993efa428127793a55 Mon Sep 17 00:00:00 2001
From: Max Jones <14077947+maxrjones@users.noreply.github.com>
Date: Tue, 3 Jan 2023 16:13:33 -0500
Subject: [PATCH] New BatchSchema class to generate patch selectors and
 combine into batch selectors (#132)

In v0.2.0, information about the batch selectors was stored as a dict in
the ._batch_selectors attribute of the BatchGenerator class. This PR
introduces BatchSchema, which contains all necessary information about the
batch selectors. BatchGenerator now creates an instance of BatchSchema,
which handles the creation of the selector objects. The purpose of
BatchGenerator is to create batches (xarray Datasets/DataArrays) from those
selector objects.

This opens up possibilities including serializing/deserializing the
BatchSchema (e.g., BatchSchema.to_json() and BatchSchema.from_json()),
caching BatchSchema objects separately from caching batches, and, relatedly,
applying one BatchSchema instance to multiple xarray datasets.

Co-authored-by: Anderson Banihirwe <13301940+andersy005@users.noreply.github.com>
---
 xbatcher/__init__.py              |   2 +-
 xbatcher/generators.py            | 347 ++++++++++++++++++++++++++----
 xbatcher/testing.py               |   7 +-
 xbatcher/tests/test_generators.py |  20 +-
 4 files changed, 322 insertions(+), 54 deletions(-)

diff --git a/xbatcher/__init__.py b/xbatcher/__init__.py
index 7282157..6fb8d75 100644
--- a/xbatcher/__init__.py
+++ b/xbatcher/__init__.py
@@ -3,7 +3,7 @@
 from . import testing  # noqa: F401
 from .accessors import BatchAccessor  # noqa: F401
-from .generators import BatchGenerator  # noqa: F401
+from .generators import BatchGenerator, BatchSchema  # noqa: F401
 from .util.print_versions import show_versions  # noqa: F401

 try:
diff --git a/xbatcher/generators.py b/xbatcher/generators.py
index 77c9858..3da074c 100644
--- a/xbatcher/generators.py
+++ b/xbatcher/generators.py
@@ -1,10 +1,266 @@
 """Classes for iterating through xarray datarrays / datasets in batches."""

 import itertools
+import warnings
+from operator import itemgetter
 from typing import Any, Dict, Hashable, Iterator, List, Sequence, Union

+import numpy as np
 import xarray as xr

+PatchGenerator = Iterator[Dict[Hashable, slice]]
+BatchSelector = List[Dict[Hashable, slice]]
+BatchSelectorSet = Dict[int, BatchSelector]
+
+
+class BatchSchema:
+    """
+    A representation of the indices and stacking/transposing parameters needed
+    to generate batches from Xarray Datasets and DataArrays using
+    xbatcher.BatchGenerator.
+
+    Parameters
+    ----------
+    ds : ``xarray.Dataset`` or ``xarray.DataArray``
+        The data to iterate over. Unlike for the BatchGenerator, the data is
+        not retained as a class attribute for the BatchSchema.
+    input_dims : dict
+        A dictionary specifying the size of the inputs in each dimension,
+        e.g. ``{'lat': 30, 'lon': 30}``.
+        These are the dimensions the ML library will see. All other dimensions
+        will be stacked into one dimension called ``sample``.
+    input_overlap : dict, optional
+        A dictionary specifying the overlap along each dimension,
+        e.g. ``{'lat': 3, 'lon': 3}``.
+    batch_dims : dict, optional
+        A dictionary specifying the size of the batch along each dimension,
+        e.g. ``{'time': 10}``. These will always be iterated over.
+    concat_input_dims : bool, optional
+        If ``True``, the dimension chunks specified in ``input_dims`` will be
+        concatenated and stacked into the ``sample`` dimension. The batch index
+        will be included as a new level ``input_batch`` in the ``sample``
+        coordinate.
+        If ``False``, the dimension chunks specified in ``input_dims`` will be
+        iterated over.
+    preload_batch : bool, optional
+        If ``True``, each batch will be loaded into memory before reshaping /
+        processing, triggering any dask arrays to be computed.
+
+    Notes
+    -----
+    The BatchSchema is experimental and subject to change without notice.
+    """
+
+    def __init__(
+        self,
+        ds: Union[xr.Dataset, xr.DataArray],
+        input_dims: Dict[Hashable, int],
+        input_overlap: Dict[Hashable, int] = None,
+        batch_dims: Dict[Hashable, int] = None,
+        concat_input_dims: bool = True,
+        preload_batch: bool = True,
+    ):
+        if input_overlap is None:
+            input_overlap = {}
+        if batch_dims is None:
+            batch_dims = {}
+        self.input_dims = dict(input_dims)
+        self.input_overlap = input_overlap
+        self.batch_dims = dict(batch_dims)
+        self.concat_input_dims = concat_input_dims
+        self.preload_batch = preload_batch
+        # Store helpful information based on arguments
+        self._duplicate_batch_dims: Dict[Hashable, int] = {
+            dim: length
+            for dim, length in self.batch_dims.items()
+            if self.input_dims.get(dim) is not None
+        }
+        self._unique_batch_dims: Dict[Hashable, int] = {
+            dim: length
+            for dim, length in self.batch_dims.items()
+            if self.input_dims.get(dim) is None
+        }
+        self._input_stride: Dict[Hashable, int] = {
+            dim: length - self.input_overlap.get(dim, 0)
+            for dim, length in self.input_dims.items()
+        }
+        self._all_sliced_dims: Dict[Hashable, int] = dict(
+            **self._unique_batch_dims, **self.input_dims
+        )
+        self.selectors: BatchSelectorSet = self._gen_batch_selectors(ds)
+
+    def _gen_batch_selectors(
+        self, ds: Union[xr.DataArray, xr.Dataset]
+    ) -> BatchSelectorSet:
+        """
+        Create the batch selectors dict, which can be used to create a batch
+        from an xarray data object.
+        """
+        # Create an iterator that returns an object usable for .isel in xarray
+        patch_selectors = self._gen_patch_selectors(ds)
+        # Create the dict containing batch selectors
+        if self.concat_input_dims:  # Combine the patches into batches
+            return self._combine_patches_into_batch(ds, patch_selectors)
+        else:  # Each patch gets its own batch
+            return {ind: [value] for ind, value in enumerate(patch_selectors)}
+
+    def _gen_patch_selectors(
+        self, ds: Union[xr.DataArray, xr.Dataset]
+    ) -> PatchGenerator:
+        """
+        Create an iterator that can be used to index an Xarray Dataset/DataArray.
+        """
+        if self._duplicate_batch_dims and not self.concat_input_dims:
+            warnings.warn(
+                "The following dimensions were included in both ``input_dims`` "
+                "and ``batch_dims``. Since ``concat_input_dims`` is ``False``, "
+                f"these dimensions will not impact batch generation: {self._duplicate_batch_dims}"
+            )
+        # Generate the slices by iterating over batch_dims and input_dims
+        all_slices = _iterate_through_dimensions(
+            ds,
+            dims=self._all_sliced_dims,
+            overlap=self.input_overlap,
+        )
+        return all_slices
+
+    def _combine_patches_into_batch(
+        self, ds: Union[xr.DataArray, xr.Dataset], patch_selectors: PatchGenerator
+    ) -> BatchSelectorSet:
+        """
+        Combine the patch selectors to form a batch
+        """
+        # Check that patches are only combined with concat_input_dims
+        if not self.concat_input_dims:
+            raise AssertionError(
+                "Patches should only be combined into batches when ``concat_input_dims`` is ``True``"
+            )
+        if not self.batch_dims:
+            return self._combine_patches_into_one_batch(patch_selectors)
+        elif self._duplicate_batch_dims:
+            return self._combine_patches_grouped_by_input_and_batch_dims(
+                ds=ds, patch_selectors=patch_selectors
+            )
+        else:
+            return self._combine_patches_grouped_by_batch_dims(patch_selectors)
+
+    def _combine_patches_into_one_batch(
+        self, patch_selectors: PatchGenerator
+    ) -> BatchSelectorSet:
+        """
+        Group all patches into a single batch
+        """
+        return dict(enumerate([list(patch_selectors)]))
+
+    def _combine_patches_grouped_by_batch_dims(
+        self, patch_selectors: PatchGenerator
+    ) -> BatchSelectorSet:
+        """
+        Group patches based on the unique slices for dimensions in ``batch_dims``
+        """
+        batch_selectors = [
+            list(value)
+            for _, value in itertools.groupby(
+                patch_selectors, key=itemgetter(*self.batch_dims)
+            )
+        ]
+        return dict(enumerate(batch_selectors))
+
+    def _combine_patches_grouped_by_input_and_batch_dims(
+        self, ds: Union[xr.DataArray, xr.Dataset], patch_selectors: PatchGenerator
+    ) -> BatchSelectorSet:
+        """
+        Combine patches with multiple slices along ``batch_dims`` grouped into
+        each patch. Required when a dimension is duplicated between ``batch_dims``
+        and ``input_dims``.
+        """
+        self._gen_patch_numbers(ds)
+        self._gen_batch_numbers(ds)
+        batch_id_per_patch = self._get_batch_multi_index_per_patch()
+        patch_in_range = self._get_batch_in_range_per_batch(
+            batch_multi_index=batch_id_per_patch
+        )
+        batch_id_per_patch = self._ravel_batch_multi_index(batch_id_per_patch)
+        batch_selectors = self._gen_empty_batch_selectors()
+        for i, patch in enumerate(patch_selectors):
+            if patch_in_range[i]:
+                batch_selectors[batch_id_per_patch[i]].append(patch)
+        return batch_selectors
+
+    def _gen_empty_batch_selectors(self) -> BatchSelectorSet:
+        """
+        Create an empty batch selector set that can be populated by appending
+        patches to each batch.
+        """
+        n_batches = np.prod(list(self._n_batches_per_dim.values()))
+        return {k: [] for k in range(n_batches)}
+
+    def _gen_patch_numbers(self, ds: Union[xr.DataArray, xr.Dataset]):
+        """
+        Calculate the number of patches per dimension and the number of patches
+        in each batch per dimension.
+        """
+        self._n_patches_per_batch: Dict[Hashable, int] = {
+            dim: int(np.ceil(length / self._input_stride.get(dim, length)))
+            for dim, length in self.batch_dims.items()
+        }
+        self._n_patches_per_dim: Dict[Hashable, int] = {
+            dim: int(
+                (ds.sizes[dim] - self.input_overlap.get(dim, 0))
+                // (length - self.input_overlap.get(dim, 0))
+            )
+            for dim, length in self._all_sliced_dims.items()
+        }
+
+    def _gen_batch_numbers(self, ds: Union[xr.DataArray, xr.Dataset]):
+        """
+        Calculate the number of batches per dimension
+        """
+        self._n_batches_per_dim: Dict[Hashable, int] = {
+            dim: int(ds.sizes[dim] // self.batch_dims.get(dim, ds.sizes[dim]))
+            for dim in self._all_sliced_dims.keys()
+        }
+
+    def _get_batch_multi_index_per_patch(self):
+        """
+        Calculate the batch multi-index for each patch
+        """
+        batch_id_per_dim: Dict[Hashable, Any] = {
+            dim: np.floor(
+                np.arange(0, n_patches)
+                / self._n_patches_per_batch.get(dim, n_patches + 1)
+            ).astype(np.int64)
+            for dim, n_patches in self._n_patches_per_dim.items()
+        }
+        batch_id_per_patch = np.array(
+            list(itertools.product(*batch_id_per_dim.values()))
+        ).transpose()
+        return batch_id_per_patch
+
+    def _ravel_batch_multi_index(self, batch_multi_index):
+        """
+        Convert the batch multi-index to a flat index for each patch
+        """
+        return np.ravel_multi_index(
+            multi_index=batch_multi_index,
+            dims=tuple(self._n_batches_per_dim.values()),
+            mode="clip",
+        )
+
+    def _get_batch_in_range_per_batch(self, batch_multi_index):
+        """
+        Determine whether each patch is contained within any of the batches.
+        """
+        batch_id_maximum = np.fromiter(self._n_batches_per_dim.values(), dtype=int)
+        batch_id_maximum = np.pad(
+            batch_id_maximum,
+            (0, (len(self._n_patches_per_dim) - len(self._n_batches_per_dim))),
+            constant_values=(1),
+        )
+        batch_id_maximum = batch_id_maximum[:, np.newaxis]
+        batch_in_range_per_patch = np.all(batch_multi_index < batch_id_maximum, axis=0)
+        return batch_in_range_per_patch
+

 def _gen_slices(*, dim_size: int, slice_size: int, overlap: int = 0) -> List[slice]:
     # return a list of slices to chop up a single dimension
@@ -126,21 +382,41 @@ def __init__(
     ):

         self.ds = ds
-        self.input_dims = dict(input_dims)
-        self.input_overlap = input_overlap
-        self.batch_dims = dict(batch_dims)
-        self.concat_input_dims = concat_input_dims
-        self.preload_batch = preload_batch
-        self._batch_selectors: Dict[
-            int, Any
-        ] = self._gen_batch_selectors()  # dict cache for batches
+        self._batch_selectors: BatchSchema = BatchSchema(
+            ds,
+            input_dims=input_dims,
+            input_overlap=input_overlap,
+            batch_dims=batch_dims,
+            concat_input_dims=concat_input_dims,
+            preload_batch=preload_batch,
+        )
+
+    @property
+    def input_dims(self):
+        return self._batch_selectors.input_dims
+
+    @property
+    def input_overlap(self):
+        return self._batch_selectors.input_overlap
+
+    @property
+    def batch_dims(self):
+        return self._batch_selectors.batch_dims
+
+    @property
+    def concat_input_dims(self):
+        return self._batch_selectors.concat_input_dims
+
+    @property
+    def preload_batch(self):
+        return self._batch_selectors.preload_batch

     def __iter__(self) -> Iterator[Union[xr.DataArray, xr.Dataset]]:
-        for idx in self._batch_selectors:
+        for idx in self._batch_selectors.selectors:
             yield self[idx]

     def __len__(self) -> int:
-        return len(self._batch_selectors)
+        return len(self._batch_selectors.selectors)

     def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]:

@@ -150,17 +426,28 @@ def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]:
             )

         if idx < 0:
-            idx = list(self._batch_selectors)[idx]
+            idx = list(self._batch_selectors.selectors)[idx]

-        if idx in self._batch_selectors:
+        if idx in self._batch_selectors.selectors:
             if self.concat_input_dims:
                 new_dim_suffix = "_input"
                 all_dsets: List = []
-                for ds_input_select in self._batch_selectors[idx]:
+                batch_selector = {}
+                for dim in self._batch_selectors.batch_dims.keys():
+                    starts = [
+                        x[dim].start for x in self._batch_selectors.selectors[idx]
+                    ]
+                    stops = [x[dim].stop for x in self._batch_selectors.selectors[idx]]
+                    batch_selector[dim] = slice(min(starts), max(stops))
+                batch_ds = self.ds.isel(batch_selector)
+                if self.preload_batch:
+                    batch_ds.load()
+                for selector in self._batch_selectors.selectors[idx]:
+                    patch_ds = self.ds.isel(selector)
                     all_dsets.append(
                         _drop_input_dims(
-                            self.ds.isel(**ds_input_select),
+                            patch_ds,
                             self.input_dims,
                             suffix=new_dim_suffix,
                         )
@@ -169,34 +456,12 @@ def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]:
                 new_input_dims = [str(dim) + new_dim_suffix for dim in self.input_dims]
                 return _maybe_stack_batch_dims(dsc, new_input_dims)
             else:
-
+                batch_ds = self.ds.isel(self._batch_selectors.selectors[idx][0])
+                if self.preload_batch:
+                    batch_ds.load()
                 return _maybe_stack_batch_dims(
-                    self.ds.isel(**self._batch_selectors[idx]), list(self.input_dims)
+                    batch_ds,
+                    list(self.input_dims),
                 )
         else:
             raise IndexError("list index out of range")
-
-    def _gen_batch_selectors(self) -> dict:
-        # in the future, we will want to do the batch generation lazily
-        # going the eager route for now is allowing me to fill out the loader api
-        # but it is likely to perform poorly.
-        batches = []
-        for ds_batch_selector in self._iterate_batch_dims():
-            ds_batch = self.ds.isel(**ds_batch_selector)
-            if self.preload_batch:
-                ds_batch.load()
-            input_generator = self._iterate_input_dims()
-            if self.concat_input_dims:
-                batches.append(list(input_generator))
-            else:
-                batches += list(input_generator)
-
-        return dict(enumerate(batches))
-
-    def _iterate_batch_dims(self) -> Any:
-        return _iterate_through_dimensions(self.ds, dims=self.batch_dims)
-
-    def _iterate_input_dims(self) -> Any:
-        return _iterate_through_dimensions(
-            self.ds, dims=self.input_dims, overlap=self.input_overlap
-        )
diff --git a/xbatcher/testing.py b/xbatcher/testing.py
index 4cb224d..66546ef 100644
--- a/xbatcher/testing.py
+++ b/xbatcher/testing.py
@@ -101,9 +101,10 @@ def _get_sample_length(
     """
     if generator.concat_input_dims:
         batch_concat_dims = [
-            generator.ds.sizes.get(k)
-            // np.nanmax([v, generator.batch_dims.get(k, np.nan)])
-            for k, v in generator.input_dims.items()
+            generator.batch_dims.get(dim) // length
+            if generator.batch_dims.get(dim)
+            else generator.ds.sizes.get(dim) // length
+            for dim, length in generator.input_dims.items()
         ]
     else:
         batch_concat_dims = []
diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py
index bba84dd..3a9f98f 100644
--- a/xbatcher/tests/test_generators.py
+++ b/xbatcher/tests/test_generators.py
@@ -115,9 +115,6 @@ def test_batch_1d_concat(sample_ds_1d, input_size):
         assert "x" in ds_batch.coords


-@pytest.mark.xfail(
-    reason="Bug described in https://github.com/xarray-contrib/xbatcher/issues/131"
-)
 def test_batch_1d_concat_duplicate_dim(sample_ds_1d):
     """
     Test batch generation for a 1D dataset using ``concat_input_dims`` when
@@ -219,10 +216,18 @@ def test_batch_3d_1d_input(sample_ds_3d, input_size):
         validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch)


-@pytest.mark.xfail(
-    reason="Bug described in https://github.com/xarray-contrib/xbatcher/issues/131"
+@pytest.mark.parametrize(
+    "concat",
+    [
+        True,
+        pytest.param(
+            False,
+            marks=pytest.mark.xfail(
+                reason="Bug described in https://github.com/xarray-contrib/xbatcher/issues/126"
+            ),
+        ),
+    ],
 )
-@pytest.mark.parametrize("concat", [True, False])
 def test_batch_3d_1d_input_batch_dims(sample_ds_3d, concat):
     """
     Test batch generation for a 3D dataset using ``input_dims`` and ``batch_dims``.
@@ -239,9 +244,6 @@ def test_batch_3d_1d_input_batch_dims(sample_ds_3d, concat):
         validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch)


-@pytest.mark.xfail(
-    reason="Bug described in https://github.com/xarray-contrib/xbatcher/issues/131"
-)
 def test_batch_3d_1d_input_batch_concat_duplicate_dim(sample_ds_3d):
    """
    Test batch generation for a 3D dataset using ``concat_input_dims`` when
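
For reference, a minimal usage sketch of the BatchSchema/BatchGenerator API
introduced by this patch; the toy dataset, variable name, and sizes below are
illustrative assumptions, not part of the change:

    import numpy as np
    import xarray as xr

    from xbatcher import BatchGenerator, BatchSchema

    # Toy 1D dataset; the variable name and sizes are made up for illustration.
    ds = xr.Dataset({"foo": (("x",), np.arange(100))})

    # The schema holds only indexing metadata; the dataset is not retained.
    schema = BatchSchema(ds, input_dims={"x": 20})

    # schema.selectors maps batch index -> list of dicts usable with .isel(),
    # e.g. {'x': slice(0, 20, None)} for the first patch.
    print(schema.selectors[0][0])

    # BatchGenerator builds a BatchSchema internally and materializes batches.
    gen = BatchGenerator(ds, input_dims={"x": 20})
    assert len(gen) == 5   # 100 // 20 patches, each its own batch by default
    first_batch = gen[0]   # an xr.Dataset with an "x" dimension of length 20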
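
The commit message mentions applying one BatchSchema instance to multiple
xarray datasets; a hedged sketch of how that could look using the public
selectors attribute (both datasets here are invented for illustration):

    import numpy as np
    import xarray as xr

    from xbatcher import BatchSchema

    x = np.linspace(0, 1, 100)
    inputs = xr.Dataset({"a": (("x",), np.sin(x))})
    targets = xr.Dataset({"b": (("x",), np.cos(x))})

    # Build the schema once from one dataset...
    schema = BatchSchema(inputs, input_dims={"x": 25})

    # ...then reuse its selectors on any dataset with compatible dimensions.
    for batch_id, patches in schema.selectors.items():
        for patch in patches:
            x_batch, y_batch = inputs.isel(patch), targets.isel(patch)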
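
When a dimension appears in both input_dims and batch_dims, the
_combine_patches_grouped_by_input_and_batch_dims path assigns each patch a
per-dimension batch id and flattens it with np.ravel_multi_index. A
standalone sketch of that mapping, with invented patch counts:

    import itertools

    import numpy as np

    # Suppose "x" yields 4 patches with 2 patches per batch, while "time"
    # yields 3 patches that all fall in one batch (numbers are illustrative).
    n_patches_per_dim = {"x": 4, "time": 3}
    n_patches_per_batch = {"x": 2, "time": 3}
    n_batches_per_dim = {"x": 2, "time": 1}

    # Per-dimension batch id for each patch position along that dimension.
    batch_id_per_dim = {
        dim: np.arange(n) // n_patches_per_batch[dim]
        for dim, n in n_patches_per_dim.items()
    }

    # One column per patch, in the order itertools.product emits patches.
    multi_index = np.array(list(itertools.product(*batch_id_per_dim.values()))).T

    # Flatten the per-dimension ids into one batch number per patch.
    flat = np.ravel_multi_index(
        multi_index, dims=tuple(n_batches_per_dim.values()), mode="clip"
    )
    print(flat)  # [0 0 0 0 0 0 1 1 1 1 1 1]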