diff --git a/xbatcher/__init__.py b/xbatcher/__init__.py
index 7282157..6fb8d75 100644
--- a/xbatcher/__init__.py
+++ b/xbatcher/__init__.py
@@ -3,7 +3,7 @@
 from . import testing  # noqa: F401
 from .accessors import BatchAccessor  # noqa: F401
-from .generators import BatchGenerator  # noqa: F401
+from .generators import BatchGenerator, BatchSchema  # noqa: F401
 from .util.print_versions import show_versions  # noqa: F401
 
 try:
diff --git a/xbatcher/generators.py b/xbatcher/generators.py
index 77c9858..3da074c 100644
--- a/xbatcher/generators.py
+++ b/xbatcher/generators.py
@@ -1,10 +1,266 @@
 """Classes for iterating through xarray dataarrays / datasets in batches."""
 
 import itertools
+import warnings
+from operator import itemgetter
 from typing import Any, Dict, Hashable, Iterator, List, Sequence, Union
 
+import numpy as np
 import xarray as xr
 
+PatchGenerator = Iterator[Dict[Hashable, slice]]
+BatchSelector = List[Dict[Hashable, slice]]
+BatchSelectorSet = Dict[int, BatchSelector]
+
+
+class BatchSchema:
+    """
+    A representation of the indices and stacking/transposing parameters needed
+    to generate batches from Xarray Datasets and DataArrays using
+    xbatcher.BatchGenerator.
+
+    Parameters
+    ----------
+    ds : ``xarray.Dataset`` or ``xarray.DataArray``
+        The data to iterate over. Unlike for the BatchGenerator, the data is
+        not retained as a class attribute for the BatchSchema.
+    input_dims : dict
+        A dictionary specifying the size of the inputs in each dimension,
+        e.g. ``{'lat': 30, 'lon': 30}``.
+        These are the dimensions the ML library will see. All other dimensions
+        will be stacked into one dimension called ``sample``.
+    input_overlap : dict, optional
+        A dictionary specifying the overlap along each dimension,
+        e.g. ``{'lat': 3, 'lon': 3}``.
+    batch_dims : dict, optional
+        A dictionary specifying the size of the batch along each dimension,
+        e.g. ``{'time': 10}``. These will always be iterated over.
+    concat_input_dims : bool, optional
+        If ``True``, the dimension chunks specified in ``input_dims`` will be
+        concatenated and stacked into the ``sample`` dimension. The batch index
+        will be included as a new level ``input_batch`` in the ``sample``
+        coordinate.
+        If ``False``, the dimension chunks specified in ``input_dims`` will be
+        iterated over.
+    preload_batch : bool, optional
+        If ``True``, each batch will be loaded into memory before reshaping /
+        processing, triggering any dask arrays to be computed.
+
+    Notes
+    -----
+    The BatchSchema is experimental and subject to change without notice.
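+
+    Examples
+    --------
+    A minimal sketch of the intended usage (hypothetical 10x10 data; the
+    selector values shown follow from the slicing logic in this class and
+    are illustrative):
+
+    >>> import numpy as np
+    >>> import xarray as xr
+    >>> from xbatcher import BatchSchema
+    >>> ds = xr.Dataset({"foo": (("x", "y"), np.zeros((10, 10)))})
+    >>> schema = BatchSchema(ds, input_dims={"x": 5, "y": 5})
+    >>> len(schema.selectors)  # all four 5x5 patches grouped into one batch
+    1
+    >>> schema.selectors[0][0]  # each patch selector can be passed to ds.isel
+    {'x': slice(0, 5, None), 'y': slice(0, 5, None)}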
+ """ + + def __init__( + self, + ds: Union[xr.Dataset, xr.DataArray], + input_dims: Dict[Hashable, int], + input_overlap: Dict[Hashable, int] = None, + batch_dims: Dict[Hashable, int] = None, + concat_input_bins: bool = True, + preload_batch: bool = True, + ): + if input_overlap is None: + input_overlap = {} + if batch_dims is None: + batch_dims = {} + self.input_dims = dict(input_dims) + self.input_overlap = input_overlap + self.batch_dims = dict(batch_dims) + self.concat_input_dims = concat_input_bins + self.preload_batch = preload_batch + # Store helpful information based on arguments + self._duplicate_batch_dims: Dict[Hashable, int] = { + dim: length + for dim, length in self.batch_dims.items() + if self.input_dims.get(dim) is not None + } + self._unique_batch_dims: Dict[Hashable, int] = { + dim: length + for dim, length in self.batch_dims.items() + if self.input_dims.get(dim) is None + } + self._input_stride: Dict[Hashable, int] = { + dim: length - self.input_overlap.get(dim, 0) + for dim, length in self.input_dims.items() + } + self._all_sliced_dims: Dict[Hashable, int] = dict( + **self._unique_batch_dims, **self.input_dims + ) + self.selectors: BatchSelectorSet = self._gen_batch_selectors(ds) + + def _gen_batch_selectors( + self, ds: Union[xr.DataArray, xr.Dataset] + ) -> BatchSelectorSet: + """ + Create batch selectors dict, which can be used to create a batch + from an xarray data object. + """ + # Create an iterator that returns an object usable for .isel in xarray + patch_selectors = self._gen_patch_selectors(ds) + # Create the Dict containing batch selectors + if self.concat_input_dims: # Combine the patches into batches + return self._combine_patches_into_batch(ds, patch_selectors) + else: # Each patch gets its own batch + return {ind: [value] for ind, value in enumerate(patch_selectors)} + + def _gen_patch_selectors( + self, ds: Union[xr.DataArray, xr.Dataset] + ) -> PatchGenerator: + """ + Create an iterator that can be used to index an Xarray Dataset/DataArray. + """ + if self._duplicate_batch_dims and not self.concat_input_dims: + warnings.warn( + "The following dimensions were included in both ``input_dims`` " + "and ``batch_dims``. 
+        """
+        # Create an iterator that returns an object usable for .isel in xarray
+        patch_selectors = self._gen_patch_selectors(ds)
+        # Create the Dict containing batch selectors
+        if self.concat_input_dims:  # Combine the patches into batches
+            return self._combine_patches_into_batch(ds, patch_selectors)
+        else:  # Each patch gets its own batch
+            return {ind: [value] for ind, value in enumerate(patch_selectors)}
+
+    def _gen_patch_selectors(
+        self, ds: Union[xr.DataArray, xr.Dataset]
+    ) -> PatchGenerator:
+        """
+        Create an iterator that can be used to index an Xarray Dataset/DataArray.
+        """
+        if self._duplicate_batch_dims and not self.concat_input_dims:
+            warnings.warn(
+                "The following dimensions were included in both ``input_dims`` "
+                "and ``batch_dims``. Since ``concat_input_dims`` is ``False``, "
+                f"these dimensions will not impact batch generation: {self._duplicate_batch_dims}"
+            )
+        # Generate the slices by iterating over batch_dims and input_dims
+        all_slices = _iterate_through_dimensions(
+            ds,
+            dims=self._all_sliced_dims,
+            overlap=self.input_overlap,
+        )
+        return all_slices
+
+    def _combine_patches_into_batch(
+        self, ds: Union[xr.DataArray, xr.Dataset], patch_selectors: PatchGenerator
+    ) -> BatchSelectorSet:
+        """
+        Combine the patch selectors to form a batch
+        """
+        # Check that patches are only combined with concat_input_dims
+        if not self.concat_input_dims:
+            raise AssertionError(
+                "Patches should only be combined into batches when ``concat_input_dims`` is ``True``"
+            )
+        if not self.batch_dims:
+            return self._combine_patches_into_one_batch(patch_selectors)
+        elif self._duplicate_batch_dims:
+            return self._combine_patches_grouped_by_input_and_batch_dims(
+                ds=ds, patch_selectors=patch_selectors
+            )
+        else:
+            return self._combine_patches_grouped_by_batch_dims(patch_selectors)
+
+    def _combine_patches_into_one_batch(
+        self, patch_selectors: PatchGenerator
+    ) -> BatchSelectorSet:
+        """
+        Group all patches into a single batch
+        """
+        return dict(enumerate([list(patch_selectors)]))
+
+    def _combine_patches_grouped_by_batch_dims(
+        self, patch_selectors: PatchGenerator
+    ) -> BatchSelectorSet:
+        """
+        Group patches based on the unique slices for dimensions in ``batch_dims``
+        """
+        batch_selectors = [
+            list(value)
+            for _, value in itertools.groupby(
+                patch_selectors, key=itemgetter(*self.batch_dims)
+            )
+        ]
+        return dict(enumerate(batch_selectors))
+
+    def _combine_patches_grouped_by_input_and_batch_dims(
+        self, ds: Union[xr.DataArray, xr.Dataset], patch_selectors: PatchGenerator
+    ) -> BatchSelectorSet:
+        """
+        Combine patches with multiple slices along ``batch_dims`` grouped into
+        each batch. Required when a dimension is duplicated between
+        ``batch_dims`` and ``input_dims``.
+        """
+        self._gen_patch_numbers(ds)
+        self._gen_batch_numbers(ds)
+        batch_id_per_patch = self._get_batch_multi_index_per_patch()
+        patch_in_range = self._get_batch_in_range_per_batch(
+            batch_multi_index=batch_id_per_patch
+        )
+        batch_id_per_patch = self._ravel_batch_multi_index(batch_id_per_patch)
+        batch_selectors = self._gen_empty_batch_selectors()
+        for i, patch in enumerate(patch_selectors):
+            if patch_in_range[i]:
+                batch_selectors[batch_id_per_patch[i]].append(patch)
+        return batch_selectors
+
+    def _gen_empty_batch_selectors(self) -> BatchSelectorSet:
+        """
+        Create an empty batch selector set that can be populated by appending
+        patches to each batch.
+        """
+        n_batches = np.prod(list(self._n_batches_per_dim.values()))
+        return {k: [] for k in range(n_batches)}
+
+    def _gen_patch_numbers(self, ds: Union[xr.DataArray, xr.Dataset]):
+        """
+        Calculate the number of patches per dimension and the number of
+        patches in each batch per dimension.
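+
+        For example, with ``batch_dims={'time': 10}``, ``input_dims={'time': 5}``,
+        and no overlap, the stride along ``time`` is 5, so each batch holds
+        ``ceil(10 / 5) = 2`` patches along ``time``.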
+ """ + self._n_patches_per_batch: Dict[Hashable, int] = { + dim: int(np.ceil(length / self._input_stride.get(dim, length))) + for dim, length in self.batch_dims.items() + } + self._n_patches_per_dim: Dict[Hashable, int] = { + dim: int( + (ds.sizes[dim] - self.input_overlap.get(dim, 0)) + // (length - self.input_overlap.get(dim, 0)) + ) + for dim, length in self._all_sliced_dims.items() + } + + def _gen_batch_numbers(self, ds: Union[xr.DataArray, xr.Dataset]): + """ + Calculate the number of batches per dimension + """ + self._n_batches_per_dim: Dict[Hashable, int] = { + dim: int(ds.sizes[dim] // self.batch_dims.get(dim, ds.sizes[dim])) + for dim in self._all_sliced_dims.keys() + } + + def _get_batch_multi_index_per_patch(self): + """ + Calculate the batch multi-index for each patch + """ + batch_id_per_dim: Dict[Hashable, Any] = { + dim: np.floor( + np.arange(0, n_patches) + / self._n_patches_per_batch.get(dim, n_patches + 1) + ).astype(np.int64) + for dim, n_patches in self._n_patches_per_dim.items() + } + batch_id_per_patch = np.array( + list(itertools.product(*batch_id_per_dim.values())) + ).transpose() + return batch_id_per_patch + + def _ravel_batch_multi_index(self, batch_multi_index): + """ + Convert the batch multi-index to a flat index for each patch + """ + return np.ravel_multi_index( + multi_index=batch_multi_index, + dims=tuple(self._n_batches_per_dim.values()), + mode="clip", + ) + + def _get_batch_in_range_per_batch(self, batch_multi_index): + """ + Determine whether each patch is contained within any of the batches. + """ + batch_id_maximum = np.fromiter(self._n_batches_per_dim.values(), dtype=int) + batch_id_maximum = np.pad( + batch_id_maximum, + (0, (len(self._n_patches_per_dim) - len(self._n_batches_per_dim))), + constant_values=(1), + ) + batch_id_maximum = batch_id_maximum[:, np.newaxis] + batch_in_range_per_patch = np.all(batch_multi_index < batch_id_maximum, axis=0) + return batch_in_range_per_patch + def _gen_slices(*, dim_size: int, slice_size: int, overlap: int = 0) -> List[slice]: # return a list of slices to chop up a single dimension @@ -126,21 +382,41 @@ def __init__( ): self.ds = ds - self.input_dims = dict(input_dims) - self.input_overlap = input_overlap - self.batch_dims = dict(batch_dims) - self.concat_input_dims = concat_input_dims - self.preload_batch = preload_batch - self._batch_selectors: Dict[ - int, Any - ] = self._gen_batch_selectors() # dict cache for batches + self._batch_selectors: BatchSchema = BatchSchema( + ds, + input_dims=input_dims, + input_overlap=input_overlap, + batch_dims=batch_dims, + concat_input_bins=concat_input_dims, + preload_batch=preload_batch, + ) + + @property + def input_dims(self): + return self._batch_selectors.input_dims + + @property + def input_overlap(self): + return self._batch_selectors.input_overlap + + @property + def batch_dims(self): + return self._batch_selectors.batch_dims + + @property + def concat_input_dims(self): + return self._batch_selectors.concat_input_dims + + @property + def preload_batch(self): + return self._batch_selectors.preload_batch def __iter__(self) -> Iterator[Union[xr.DataArray, xr.Dataset]]: - for idx in self._batch_selectors: + for idx in self._batch_selectors.selectors: yield self[idx] def __len__(self) -> int: - return len(self._batch_selectors) + return len(self._batch_selectors.selectors) def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]: @@ -150,17 +426,28 @@ def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]: ) if idx < 0: - idx = 
+                batch_selector = {}
+                for dim in self._batch_selectors.batch_dims.keys():
+                    starts = [
+                        x[dim].start for x in self._batch_selectors.selectors[idx]
+                    ]
+                    stops = [x[dim].stop for x in self._batch_selectors.selectors[idx]]
+                    batch_selector[dim] = slice(min(starts), max(stops))
+                batch_ds = self.ds.isel(batch_selector)
+                if self.preload_batch:
+                    batch_ds.load()
+                for selector in self._batch_selectors.selectors[idx]:
+                    patch_ds = self.ds.isel(selector)
                     all_dsets.append(
                         _drop_input_dims(
-                            self.ds.isel(**ds_input_select),
+                            patch_ds,
                             self.input_dims,
                             suffix=new_dim_suffix,
                         )
@@ -169,34 +456,12 @@ def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]:
             new_input_dims = [str(dim) + new_dim_suffix for dim in self.input_dims]
             return _maybe_stack_batch_dims(dsc, new_input_dims)
         else:
-
+            batch_ds = self.ds.isel(self._batch_selectors.selectors[idx][0])
+            if self.preload_batch:
+                batch_ds.load()
             return _maybe_stack_batch_dims(
-                self.ds.isel(**self._batch_selectors[idx]), list(self.input_dims)
+                batch_ds,
+                list(self.input_dims),
             )
         else:
             raise IndexError("list index out of range")
-
-    def _gen_batch_selectors(self) -> dict:
-        # in the future, we will want to do the batch generation lazily
-        # going the eager route for now is allowing me to fill out the loader api
-        # but it is likely to perform poorly.
-        batches = []
-        for ds_batch_selector in self._iterate_batch_dims():
-            ds_batch = self.ds.isel(**ds_batch_selector)
-            if self.preload_batch:
-                ds_batch.load()
-            input_generator = self._iterate_input_dims()
-            if self.concat_input_dims:
-                batches.append(list(input_generator))
-            else:
-                batches += list(input_generator)
-
-        return dict(enumerate(batches))
-
-    def _iterate_batch_dims(self) -> Any:
-        return _iterate_through_dimensions(self.ds, dims=self.batch_dims)
-
-    def _iterate_input_dims(self) -> Any:
-        return _iterate_through_dimensions(
-            self.ds, dims=self.input_dims, overlap=self.input_overlap
-        )
diff --git a/xbatcher/testing.py b/xbatcher/testing.py
index 4cb224d..66546ef 100644
--- a/xbatcher/testing.py
+++ b/xbatcher/testing.py
@@ -101,9 +101,10 @@ def _get_sample_length(
     """
     if generator.concat_input_dims:
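+        # Expected patches per batch along each input dim: the batch size
+        # divided by the patch size when the dim is batched, otherwise the
+        # full dimension size divided by the patch size.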
         batch_concat_dims = [
-            generator.ds.sizes.get(k)
-            // np.nanmax([v, generator.batch_dims.get(k, np.nan)])
-            for k, v in generator.input_dims.items()
+            generator.batch_dims.get(dim) // length
+            if generator.batch_dims.get(dim)
+            else generator.ds.sizes.get(dim) // length
+            for dim, length in generator.input_dims.items()
         ]
     else:
         batch_concat_dims = []
diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py
index bba84dd..3a9f98f 100644
--- a/xbatcher/tests/test_generators.py
+++ b/xbatcher/tests/test_generators.py
@@ -115,9 +115,6 @@ def test_batch_1d_concat(sample_ds_1d, input_size):
         assert "x" in ds_batch.coords
 
 
-@pytest.mark.xfail(
-    reason="Bug described in https://github.com/xarray-contrib/xbatcher/issues/131"
-)
 def test_batch_1d_concat_duplicate_dim(sample_ds_1d):
     """
     Test batch generation for a 1D dataset using ``concat_input_dims`` when
@@ -219,10 +216,18 @@ def test_batch_3d_1d_input(sample_ds_3d, input_size):
         validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch)
 
 
-@pytest.mark.xfail(
-    reason="Bug described in https://github.com/xarray-contrib/xbatcher/issues/131"
+@pytest.mark.parametrize(
+    "concat",
+    [
+        True,
+        pytest.param(
+            False,
+            marks=pytest.mark.xfail(
+                reason="Bug described in https://github.com/xarray-contrib/xbatcher/issues/126"
+            ),
+        ),
+    ],
 )
-@pytest.mark.parametrize("concat", [True, False])
 def test_batch_3d_1d_input_batch_dims(sample_ds_3d, concat):
     """
     Test batch generation for a 3D dataset using ``input_dims`` and ``batch_dims``.
@@ -239,9 +244,6 @@
         validate_batch_dimensions(expected_dims=expected_dims, batch=ds_batch)
 
 
-@pytest.mark.xfail(
-    reason="Bug described in https://github.com/xarray-contrib/xbatcher/issues/131"
-)
 def test_batch_3d_1d_input_batch_concat_duplicate_dim(sample_ds_3d):
     """
     Test batch generation for a 3D dataset using ``concat_input_dims`` when