diff --git a/xbatcher/accessors.py b/xbatcher/accessors.py
index a9d19be..af7fed1 100644
--- a/xbatcher/accessors.py
+++ b/xbatcher/accessors.py
@@ -1,4 +1,4 @@
-from typing import Union
+from typing import Any, Union
 
 import xarray as xr
 
@@ -19,13 +19,13 @@ def _as_xarray_dataarray(xr_obj: Union[xr.Dataset, xr.DataArray]) -> xr.DataArra
 @xr.register_dataarray_accessor("batch")
 @xr.register_dataset_accessor("batch")
 class BatchAccessor:
-    def __init__(self, xarray_obj):
+    def __init__(self, xarray_obj: Union[xr.Dataset, xr.DataArray]):
         """
         Batch accessor returning a BatchGenerator object via the `generator method`
         """
         self._obj = xarray_obj
 
-    def generator(self, *args, **kwargs):
+    def generator(self, *args, **kwargs) -> BatchGenerator:
         """
         Return a BatchGenerator via the batch accessor
 
@@ -42,10 +42,10 @@ def generator(self, *args, **kwargs):
 @xr.register_dataarray_accessor("tf")
 @xr.register_dataset_accessor("tf")
 class TFAccessor:
-    def __init__(self, xarray_obj):
+    def __init__(self, xarray_obj: Union[xr.Dataset, xr.DataArray]):
         self._obj = xarray_obj
 
-    def to_tensor(self):
+    def to_tensor(self) -> Any:
         """Convert this DataArray to a tensorflow.Tensor"""
         import tensorflow as tf
 
@@ -57,10 +57,10 @@ def to_tensor(self):
 @xr.register_dataarray_accessor("torch")
 @xr.register_dataset_accessor("torch")
 class TorchAccessor:
-    def __init__(self, xarray_obj):
+    def __init__(self, xarray_obj: Union[xr.Dataset, xr.DataArray]):
         self._obj = xarray_obj
 
-    def to_tensor(self):
+    def to_tensor(self) -> Any:
         """Convert this DataArray to a torch.Tensor"""
         import torch
 
@@ -68,7 +68,7 @@ def to_tensor(self):
         return torch.tensor(data=dataarray.data)
 
-    def to_named_tensor(self):
+    def to_named_tensor(self) -> Any:
         """
         Convert this DataArray to a torch.Tensor with named dimensions.
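
A minimal usage sketch of the typed accessors (illustrative only, not part of the diff; the example array and variable names are hypothetical, and torch is only needed for the .torch accessor):

import numpy as np
import xarray as xr

import xbatcher  # noqa: F401 -- importing registers the .batch, .tf, and .torch accessors

da = xr.DataArray(np.random.rand(100, 8), dims=("time", "x"), name="foo")

# BatchAccessor.generator forwards *args/**kwargs straight to BatchGenerator,
# so the new "-> BatchGenerator" annotation is exact.
bgen = da.batch.generator(input_dims={"time": 10})

for batch in bgen:
    # to_tensor is annotated "-> Any" because torch/tensorflow are optional,
    # lazily imported dependencies; at runtime this is a torch.Tensor.
    tensor = batch.torch.to_tensor()
    print(tensor.shape)  # torch.Size([10, 8])
    break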
diff --git a/xbatcher/generators.py b/xbatcher/generators.py
index ec3b2b6..c2c92d9 100644
--- a/xbatcher/generators.py
+++ b/xbatcher/generators.py
@@ -1,13 +1,12 @@
 """Classes for iterating through xarray datarrays / datasets in batches."""
 
 import itertools
-from collections import OrderedDict
-from typing import Any, Dict, Hashable, Iterator
+from typing import Any, Dict, Hashable, Iterator, List, OrderedDict, Sequence, Union
 
 import xarray as xr
 
 
-def _slices(dimsize, size, overlap=0):
+def _slices(dimsize: int, size: int, overlap: int = 0) -> Any:
     # return a list of slices to chop up a single dimension
     if overlap >= size:
         raise ValueError(
@@ -23,7 +22,11 @@ def _slices(dimsize, size, overlap=0):
     return slices
 
 
-def _iterate_through_dataset(ds, dims, overlap={}):
+def _iterate_through_dataset(
+    ds: Union[xr.Dataset, xr.DataArray],
+    dims: OrderedDict[Hashable, int],
+    overlap: Dict[Hashable, int] = {},
+) -> Any:
     dim_slices = []
     for dim in dims:
         dimsize = ds.sizes[dim]
@@ -39,16 +42,20 @@ def _iterate_through_dataset(ds, dims, overlap={}):
         dim_slices.append(_slices(dimsize, size, olap))
 
     for slices in itertools.product(*dim_slices):
-        selector = {key: slice for key, slice in zip(dims, slices)}
+        selector = dict(zip(dims, slices))
         yield selector
 
 
-def _drop_input_dims(ds, input_dims, suffix="_input"):
+def _drop_input_dims(
+    ds: Union[xr.Dataset, xr.DataArray],
+    input_dims: OrderedDict[Hashable, int],
+    suffix: str = "_input",
+) -> Union[xr.Dataset, xr.DataArray]:
     # remove input_dims coordinates from datasets, rename the dimensions
     # then put intput_dims back in as coordinates
     out = ds.copy()
-    for dim in input_dims:
-        newdim = dim + suffix
+    for dim in input_dims.keys():
+        newdim = f"{dim}{suffix}"
         out = out.rename({dim: newdim})
         # extra steps needed if there is a coordinate
         if newdim in out:
@@ -57,13 +64,16 @@ def _drop_input_dims(ds, input_dims, suffix="_input"):
     return out
 
 
-def _maybe_stack_batch_dims(ds, input_dims, stacked_dim_name="sample"):
+def _maybe_stack_batch_dims(
+    ds: Union[xr.Dataset, xr.DataArray],
+    input_dims: Sequence[Hashable],
+) -> Union[xr.Dataset, xr.DataArray]:
     batch_dims = [d for d in ds.sizes if d not in input_dims]
     if len(batch_dims) < 2:
         return ds
-    ds_stack = ds.stack(**{stacked_dim_name: batch_dims})
+    ds_stack = ds.stack(sample=batch_dims)
     # ensure correct order
-    dim_order = (stacked_dim_name,) + tuple(input_dims)
+    dim_order = ("sample",) + tuple(input_dims)
     return ds_stack.transpose(*dim_order)
 
 
@@ -105,7 +115,7 @@ class BatchGenerator:
 
     def __init__(
         self,
-        ds: xr.Dataset,
+        ds: Union[xr.Dataset, xr.DataArray],
         input_dims: Dict[Hashable, int],
         input_overlap: Dict[Hashable, int] = {},
         batch_dims: Dict[Hashable, int] = {},
@@ -122,14 +132,14 @@ def __init__(
         self.preload_batch = preload_batch
         self._batches: Dict[int, Any] = self._gen_batches()  # dict cache for batches
 
-    def __iter__(self) -> Iterator[xr.Dataset]:
+    def __iter__(self) -> Iterator[Union[xr.DataArray, xr.Dataset]]:
         for idx in self._batches:
             yield self[idx]
 
     def __len__(self) -> int:
         return len(self._batches)
 
-    def __getitem__(self, idx: int) -> xr.Dataset:
+    def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]:
 
         if not isinstance(idx, int):
             raise NotImplementedError(
@@ -143,14 +153,15 @@ def __getitem__(self, idx: int) -> xr.Dataset:
 
         if self.concat_input_dims:
             new_dim_suffix = "_input"
-            all_dsets = [
-                _drop_input_dims(
-                    self.ds.isel(**ds_input_select),
-                    list(self.input_dims),
-                    suffix=new_dim_suffix,
+            all_dsets: List = []
+            for ds_input_select in self._batches[idx]:
+                all_dsets.append(
+                    _drop_input_dims(
+                        self.ds.isel(**ds_input_select),
+                        self.input_dims,
+                        suffix=new_dim_suffix,
+                    )
                 )
-                for ds_input_select in self._batches[idx]
-            ]
             dsc = xr.concat(all_dsets, dim="input_batch")
             new_input_dims = [str(dim) + new_dim_suffix for dim in self.input_dims]
             return _maybe_stack_batch_dims(dsc, new_input_dims)
@@ -167,22 +178,20 @@ def _gen_batches(self) -> dict:
         # going the eager route for now is allowing me to fill out the loader api
         # but it is likely to perform poorly.
         batches = []
-        for ds_batch_selector in self._iterate_batch_dims(self.ds):
+        for ds_batch_selector in self._iterate_batch_dims():
             ds_batch = self.ds.isel(**ds_batch_selector)
             if self.preload_batch:
                 ds_batch.load()
-
-            input_generator = self._iterate_input_dims(ds_batch)
-
+            input_generator = self._iterate_input_dims()
             if self.concat_input_dims:
                 batches.append(list(input_generator))
             else:
                 batches += list(input_generator)
-        return dict(zip(range(len(batches)), batches))
+        return dict(enumerate(batches))
 
-    def _iterate_batch_dims(self, ds):
-        return _iterate_through_dataset(ds, self.batch_dims)
+    def _iterate_batch_dims(self) -> Any:
+        return _iterate_through_dataset(self.ds, self.batch_dims)
 
-    def _iterate_input_dims(self, ds):
-        return _iterate_through_dataset(ds, self.input_dims, self.input_overlap)
+    def _iterate_input_dims(self) -> Any:
+        return _iterate_through_dataset(self.ds, self.input_dims, self.input_overlap)
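
A sketch of the behavior these annotations formalize (illustrative, not part of the diff; the dataset and variable names are hypothetical). _slices steps through a dimension with stride size - overlap and drops any short trailing window, and __getitem__ accepts only integer indices:

import numpy as np
import xarray as xr

from xbatcher import BatchGenerator

ds = xr.Dataset({"foo": (("time", "x"), np.random.rand(100, 4))})

# 19 overlapping windows along time: stride = 10 - 5 = 5, and a window is
# kept only while start + size <= dimsize, i.e. starts 0, 5, ..., 90.
bgen = BatchGenerator(ds, input_dims={"time": 10}, input_overlap={"time": 5})
print(len(bgen))            # 19
print(dict(bgen[0].sizes))  # {'time': 10, 'x': 4}

# __getitem__ is typed to take int only; other index types raise NotImplementedError.
try:
    bgen[0:2]
except NotImplementedError as err:
    print(err)

# With concat_input_dims=True, each window is renamed time -> time_input by
# _drop_input_dims, the windows are concatenated along "input_batch", and the
# remaining non-input dims are stacked into "sample" by _maybe_stack_batch_dims.
bgen_c = BatchGenerator(ds, input_dims={"time": 10}, concat_input_dims=True)
print(dict(bgen_c[0].sizes))  # {'sample': 40, 'time_input': 10}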