Add type hints for generators and accessors #128

Merged Dec 2, 2022 (4 commits). The diff below shows changes from 2 commits.
xbatcher/accessors.py (16 changes: 8 additions & 8 deletions)
@@ -1,4 +1,4 @@
-from typing import Union
+from typing import Any, Union

 import xarray as xr

@@ -19,13 +19,13 @@ def _as_xarray_dataarray(xr_obj: Union[xr.Dataset, xr.DataArray]) -> xr.DataArray:
@xr.register_dataarray_accessor("batch")
@xr.register_dataset_accessor("batch")
class BatchAccessor:
def __init__(self, xarray_obj):
def __init__(self, xarray_obj: Union[xr.Dataset, xr.DataArray]):
"""
Batch accessor returning a BatchGenerator object via the `generator method`
"""
self._obj = xarray_obj

def generator(self, *args, **kwargs):
def generator(self, *args, **kwargs) -> BatchGenerator:
"""
Return a BatchGenerator via the batch accessor

@@ -42,10 +42,10 @@ def generator(self, *args, **kwargs):
@xr.register_dataarray_accessor("tf")
@xr.register_dataset_accessor("tf")
class TFAccessor:
def __init__(self, xarray_obj):
def __init__(self, xarray_obj: Union[xr.Dataset, xr.DataArray]):
self._obj = xarray_obj

def to_tensor(self):
def to_tensor(self) -> Any:
"""Convert this DataArray to a tensorflow.Tensor"""
import tensorflow as tf

@@ -57,18 +57,18 @@ def to_tensor(self):
@xr.register_dataarray_accessor("torch")
@xr.register_dataset_accessor("torch")
class TorchAccessor:
def __init__(self, xarray_obj):
def __init__(self, xarray_obj: Union[xr.Dataset, xr.DataArray]):
self._obj = xarray_obj

def to_tensor(self):
def to_tensor(self) -> Any:
"""Convert this DataArray to a torch.Tensor"""
import torch

dataarray = _as_xarray_dataarray(xr_obj=self._obj)

return torch.tensor(data=dataarray.data)

def to_named_tensor(self):
def to_named_tensor(self) -> Any:
"""
Convert this DataArray to a torch.Tensor with named dimensions.

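For reference, this is how the annotated accessors are exercised from user code; a minimal sketch assuming that importing xbatcher runs the register_*_accessor decorators shown above, with illustrative dimension names and shapes:

import numpy as np
import xarray as xr
import xbatcher  # importing the package registers the .batch, .tf, and .torch accessors

da = xr.DataArray(np.random.rand(100, 10), dims=("time", "x"))

bgen = da.batch.generator(input_dims={"time": 10})  # now inferred as BatchGenerator
tensor = da.torch.to_tensor()  # annotated -> Any, since torch is an optional runtime import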
xbatcher/generators.py (69 changes: 39 additions & 30 deletions)
@@ -1,13 +1,12 @@
 """Classes for iterating through xarray dataarrays / datasets in batches."""

 import itertools
-from collections import OrderedDict
-from typing import Any, Dict, Hashable, Iterator
+from typing import Any, Dict, Hashable, Iterator, List, OrderedDict, Sequence, Union

 import xarray as xr


-def _slices(dimsize, size, overlap=0):
+def _slices(dimsize: int, size: int, overlap: int = 0) -> Any:
     # return a list of slices to chop up a single dimension
     if overlap >= size:
         raise ValueError(
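The body of _slices is truncated in the diff above. As a rough guide to what a helper like this computes, here is an assumption-laden sketch of an equivalent function (not the PR's code): it steps by size - overlap and keeps only full-size windows.

def sliding_slices(dimsize: int, size: int, overlap: int = 0) -> list:
    # hypothetical stand-in for _slices: step by (size - overlap),
    # dropping any trailing window that would run past the end
    stride = size - overlap
    return [
        slice(start, start + size)
        for start in range(0, dimsize, stride)
        if start + size <= dimsize
    ]

sliding_slices(10, 4, overlap=2)
# [slice(0, 4), slice(2, 6), slice(4, 8), slice(6, 10)]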
@@ -23,7 +22,11 @@ def _slices(dimsize, size, overlap=0):
     return slices


-def _iterate_through_dataset(ds, dims, overlap={}):
+def _iterate_through_dataset(
+    ds: Union[xr.Dataset, xr.DataArray],
+    dims: OrderedDict[Hashable, int],
+    overlap: Dict[Hashable, int] = {},
+) -> Any:
     dim_slices = []
     for dim in dims:
         dimsize = ds.sizes[dim]
@@ -39,16 +42,20 @@ def _iterate_through_dataset(ds, dims, overlap={}):
         dim_slices.append(_slices(dimsize, size, olap))

     for slices in itertools.product(*dim_slices):
-        selector = {key: slice for key, slice in zip(dims, slices)}
+        selector = dict(zip(dims, slices))
         yield selector


-def _drop_input_dims(ds, input_dims, suffix="_input"):
+def _drop_input_dims(
+    ds: Union[xr.Dataset, xr.DataArray],
+    input_dims: OrderedDict[Hashable, int],
+    suffix: str = "_input",
+) -> Union[xr.Dataset, xr.DataArray]:
     # remove input_dims coordinates from datasets, rename the dimensions,
     # then put input_dims back in as coordinates
     out = ds.copy()
-    for dim in input_dims:
-        newdim = dim + suffix
+    for dim in input_dims.keys():
+        newdim = f"{dim}{suffix}"
         out = out.rename({dim: newdim})
         # extra steps needed if there is a coordinate
         if newdim in out:
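The dict(zip(...)) change above is behavior-preserving, and it also stops the removed comprehension from shadowing the built-in slice name. A quick check, with illustrative dimension names:

dims = {"time": 4, "x": 5}
slices = (slice(0, 4), slice(0, 5))
assert dict(zip(dims, slices)) == {key: sl for key, sl in zip(dims, slices)}
# selector == {"time": slice(0, 4), "x": slice(0, 5)}, ready for ds.isel(**selector)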
@@ -57,13 +64,16 @@ def _drop_input_dims(ds, input_dims, suffix="_input"):
     return out


-def _maybe_stack_batch_dims(ds, input_dims, stacked_dim_name="sample"):
+def _maybe_stack_batch_dims(
+    ds: Union[xr.Dataset, xr.DataArray],
+    input_dims: Sequence[Hashable],
+) -> Union[xr.Dataset, xr.DataArray]:
     batch_dims = [d for d in ds.sizes if d not in input_dims]
     if len(batch_dims) < 2:
         return ds
-    ds_stack = ds.stack(**{stacked_dim_name: batch_dims})
+    ds_stack = ds.stack(sample=batch_dims)
     # ensure correct order
-    dim_order = (stacked_dim_name,) + tuple(input_dims)
+    dim_order = ("sample",) + tuple(input_dims)
     return ds_stack.transpose(*dim_order)
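Hard-coding the "sample" name lets the stack call use a literal keyword instead of a **-unpacked dict, which is easier on the type checker; behavior is unchanged for the default name. A small demonstration of the stack-then-transpose step, with illustrative names and shapes:

import numpy as np
import xarray as xr

da = xr.DataArray(np.zeros((3, 4, 5)), dims=("a", "b", "x"))
stacked = da.stack(sample=["a", "b"]).transpose("sample", "x")
print(dict(stacked.sizes))  # {'sample': 12, 'x': 5}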

@@ -105,7 +115,7 @@ class BatchGenerator:
     def __init__(
         self,
-        ds: xr.Dataset,
+        ds: Union[xr.Dataset, xr.DataArray],
         input_dims: Dict[Hashable, int],
         input_overlap: Dict[Hashable, int] = {},
         batch_dims: Dict[Hashable, int] = {},
@@ -122,14 +132,14 @@ def __init__(
         self.preload_batch = preload_batch
         self._batches: Dict[int, Any] = self._gen_batches()  # dict cache for batches

-    def __iter__(self) -> Iterator[xr.Dataset]:
+    def __iter__(self) -> Iterator[Union[xr.DataArray, xr.Dataset]]:
         for idx in self._batches:
             yield self[idx]

     def __len__(self) -> int:
         return len(self._batches)

-    def __getitem__(self, idx: int) -> xr.Dataset:
+    def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]:

         if not isinstance(idx, int):
             raise NotImplementedError(
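Together, __iter__, __len__, and __getitem__ give BatchGenerator a sequence-like interface, and the widened return types say each batch may be a Dataset or a DataArray, matching whatever was passed in. A usage sketch, with illustrative data and dimension names:

import numpy as np
import xarray as xr
from xbatcher import BatchGenerator

ds = xr.Dataset({"v": (("time", "x"), np.random.rand(100, 10))})
bgen = BatchGenerator(ds, input_dims={"time": 10})

print(len(bgen))  # 10 windows along time
first = bgen[0]   # integer indexing; non-int indices raise NotImplementedError
for batch in bgen:
    ...           # each batch is a Dataset here, since a Dataset went in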
@@ -143,14 +153,15 @@ def __getitem__(self, idx: int) -> xr.Dataset:

         if self.concat_input_dims:
             new_dim_suffix = "_input"
-            all_dsets = [
-                _drop_input_dims(
-                    self.ds.isel(**ds_input_select),
-                    list(self.input_dims),
-                    suffix=new_dim_suffix,
+            all_dsets: List = []
+            for ds_input_select in self._batches[idx]:
+                all_dsets.append(
+                    _drop_input_dims(
+                        self.ds.isel(**ds_input_select),
+                        self.input_dims,
+                        suffix=new_dim_suffix,
+                    )
                 )
-                for ds_input_select in self._batches[idx]
-            ]
             dsc = xr.concat(all_dsets, dim="input_batch")
             new_input_dims = [str(dim) + new_dim_suffix for dim in self.input_dims]
             return _maybe_stack_batch_dims(dsc, new_input_dims)
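The list comprehension is unrolled into an explicit loop here largely so the accumulator can be annotated (all_dsets: List). The concat step itself is unchanged; in miniature it does this, with toy data and illustrative names:

import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(8), dims="time")
windows = [
    da.isel(time=slice(0, 4)).rename({"time": "time_input"}),
    da.isel(time=slice(4, 8)).rename({"time": "time_input"}),
]
dsc = xr.concat(windows, dim="input_batch")  # stack windows along a new batch dim
print(dict(dsc.sizes))  # {'input_batch': 2, 'time_input': 4}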
Expand All @@ -167,22 +178,20 @@ def _gen_batches(self) -> dict:
# going the eager route for now is allowing me to fill out the loader api
# but it is likely to perform poorly.
batches = []
for ds_batch_selector in self._iterate_batch_dims(self.ds):
for ds_batch_selector in self._iterate_batch_dims():
ds_batch = self.ds.isel(**ds_batch_selector)
if self.preload_batch:
ds_batch.load()

input_generator = self._iterate_input_dims(ds_batch)

input_generator = self._iterate_input_dims()
if self.concat_input_dims:
batches.append(list(input_generator))
else:
batches += list(input_generator)

return dict(zip(range(len(batches)), batches))
return dict(enumerate(batches))

def _iterate_batch_dims(self, ds):
return _iterate_through_dataset(ds, self.batch_dims)
def _iterate_batch_dims(self) -> Any:
Contributor:

Suggested change:
-    def _iterate_batch_dims(self) -> Any:
+    def _iterate_batch_dims(self) -> Generator[Dict[str, int], None, None]:

I didn't test this but I think it is basically right. If it works, you can propagate this to the _iterate_through_dataset function as well.

Contributor:

Alternatively, you can use the Iterator type.

Member (author):

Lots of errors using def _iterate_batch_dims(self) -> Iterator[Dict[Hashable, slice]]:

xbatcher/generators.py:182: error: Keywords must be strings
xbatcher/generators.py:182: error: Argument 1 to "isel" of "Dataset" has incompatible type "**Dict[Hashable, slice]"; expected "Optional[Mapping[Any, Any]]"
xbatcher/generators.py:182: error: Argument 1 to "isel" of "Dataset" has incompatible type "**Dict[Hashable, slice]"; expected "bool"
xbatcher/generators.py:182: error: Argument 1 to "isel" of "Dataset" has incompatible type "**Dict[Hashable, slice]"; expected "Literal['raise', 'warn', 'ignore']"
xbatcher/generators.py:182: error: Argument 1 to "isel" of "DataArray" has incompatible type "**Dict[Hashable, slice]"; expected "Optional[Mapping[Any, Any]]"
xbatcher/generators.py:182: error: Argument 1 to "isel" of "DataArray" has incompatible type "**Dict[Hashable, slice]"; expected "bool"
xbatcher/generators.py:182: error: Argument 1 to "isel" of "DataArray" has incompatible type "**Dict[Hashable, slice]"; expected "Literal['raise', 'warn', 'ignore']"
xbatcher/generators.py:189: error: Argument 1 to "list" has incompatible type "Iterator[Dict[Hashable, slice]]"; expected "Iterable[List[Dict[Hashable, slice]]]"
Found 8 errors in 1 file (checked 20 source files)

I have a separate refactor almost finished that would fix #131 and make this function obsolete. Do you view it as necessary to type all functions before merging in the existing type hints? My preference would be to not spend more time on this.
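As context for the errors quoted above: the "Keywords must be strings" failures come from **-unpacking a Dict[Hashable, slice], since keyword expansion requires str keys. One possible workaround (a sketch, not what this PR does) is to pass the mapping to isel positionally:

from typing import Dict, Hashable

import xarray as xr

def select(ds: xr.Dataset, selector: Dict[Hashable, slice]) -> xr.Dataset:
    # xarray's isel accepts the indexers as a single mapping argument,
    # which sidesteps the str-keys requirement of **selector
    return ds.isel(selector)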

+        return _iterate_through_dataset(self.ds, self.batch_dims)

-    def _iterate_input_dims(self, ds):
-        return _iterate_through_dataset(ds, self.input_dims, self.input_overlap)
+    def _iterate_input_dims(self) -> Any:
+        return _iterate_through_dataset(self.ds, self.input_dims, self.input_overlap)