Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve tests #124

Merged
merged 21 commits into from
Nov 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xbatcher/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from importlib.metadata import PackageNotFoundError as _PackageNotFoundError
from importlib.metadata import version as _version

from . import testing # noqa: F401
from .accessors import BatchAccessor # noqa: F401
from .generators import BatchGenerator # noqa: F401
from .util.print_versions import show_versions # noqa: F401
Expand Down
235 changes: 235 additions & 0 deletions xbatcher/testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
from typing import Dict, Hashable, Union
from unittest import TestCase

import numpy as np
import xarray as xr

from .generators import BatchGenerator


def _get_non_specified_dims(generator: BatchGenerator) -> Dict[Hashable, int]:
    """
    Collect the dataset dimensions that the generator leaves unspecified.

    Parameters
    ----------
    generator : xbatcher.BatchGenerator
        The batch generator object.

    Returns
    -------
    d : dict
        Mapping of dimension name to length for every dimension of
        ``generator.ds`` for which neither ``generator.input_dims`` nor
        ``generator.batch_dims`` holds a (non-``None``) entry.
    """
    input_dims = generator.input_dims
    batch_dims = generator.batch_dims
    unspecified: Dict[Hashable, int] = {}
    for name, length in generator.ds.sizes.items():
        # Skip anything the user configured through input_dims ...
        if input_dims.get(name) is not None:
            continue
        # ... or through batch_dims.
        if batch_dims.get(name) is not None:
            continue
        unspecified[name] = length
    return unspecified


def _get_non_input_batch_dims(generator: BatchGenerator) -> Dict[Hashable, int]:
    """
    Collect the batch dimensions that are not also input dimensions.

    Parameters
    ----------
    generator : xbatcher.BatchGenerator
        The batch generator object.

    Returns
    -------
    d : dict
        Mapping of dimension name to batch length for every entry of
        ``generator.batch_dims`` for which ``generator.input_dims`` holds
        no (non-``None``) entry.
    """
    batch_only: Dict[Hashable, int] = {}
    for name, length in generator.batch_dims.items():
        if generator.input_dims.get(name) is not None:
            # Also an input dimension, so it is not a pure batch dimension.
            continue
        batch_only[name] = length
    return batch_only


def _get_sample_length(
    *,
    generator: BatchGenerator,
    non_specified_ds_dims: Dict[Hashable, int],
    non_input_batch_dims: Dict[Hashable, int],
) -> int:
    """
    Return the expected length of the sample dimension.

    Parameters
    ----------
    generator : xbatcher.BatchGenerator
        The batch generator object.
    non_specified_ds_dims : dict
        Dict containing all dimensions in the input dataset that are not
        in the input_dims or batch_dims attributes of the batch generator.
    non_input_batch_dims : dict
        Dict containing all dimensions specified in batch_dims that are
        not also in input_dims

    Returns
    -------
    s : int
        Expected length of the sample dimension
    """
    if generator.concat_input_dims:
        # For each input dimension, count how many windows fit along it.
        # The window is the larger of the input length and the batch length;
        # nanmax ignores the NaN placeholder when no batch length is given.
        batch_concat_dims = [
            generator.ds.sizes.get(k)
            // np.nanmax([v, generator.batch_dims.get(k, np.nan)])
            for k, v in generator.input_dims.items()
        ]
    else:
        batch_concat_dims = []
    # np.prod of an empty list is 1.0, so empty factors leave the product
    # unchanged. (np.product was a deprecated alias removed in NumPy 2.0.)
    return int(
        np.prod(list(non_specified_ds_dims.values()))
        * np.prod(list(non_input_batch_dims.values()))
        * np.prod(batch_concat_dims)
    )


def get_batch_dimensions(generator: BatchGenerator) -> Dict[Hashable, int]:
    """
    Return the expected batch dimensions based on the ``input_dims``,
    ``batch_dims``, and ``concat_input_dims`` attributes of the batch
    generator.

    Parameters
    ----------
    generator : xbatcher.BatchGenerator
        The batch generator object.

    Returns
    -------
    d : dict
        Dict containing the expected dimensions for batches returned by the
        batch generator. The insertion order of the dict is the expected
        dimension order.
    """
    # dimensions that are in the input dataset but not input_dims or batch_dims
    non_specified_ds_dims = _get_non_specified_dims(generator)
    # dimensions that are in batch_dims but not input_dims
    non_input_batch_dims = _get_non_input_batch_dims(generator)
    expected_sample_length = _get_sample_length(
        generator=generator,
        non_specified_ds_dims=non_specified_ds_dims,
        non_input_batch_dims=non_input_batch_dims,
    )
    # input_dims stay the same, possibly with a new suffix
    expected_dims: Dict[Hashable, int] = {
        f"{k}_input" if generator.concat_input_dims else k: v
        for k, v in generator.input_dims.items()
    }
    # number of dataset dimensions that are not input dimensions
    n_non_input_dims = len(generator.ds.sizes) - len(generator.input_dims)
    # Add a leading stacked dimension if there's anything to get stacked
    if generator.concat_input_dims and n_non_input_dims == 0:
        return {"input_batch": expected_sample_length, **expected_dims}
    if generator.concat_input_dims or n_non_input_dims > 1:
        return {"sample": expected_sample_length, **expected_dims}
    # Nothing is stacked: remaining dimensions keep their own names.
    # A dict-literal merge is used instead of ``dict(**a, **b, **c)`` because
    # keyword unpacking raises TypeError for non-string dimension names,
    # while dimension names are only required to be hashable. The three key
    # sets are disjoint, so no entry is overwritten.
    return {
        **non_specified_ds_dims,
        **non_input_batch_dims,
        **expected_dims,
    }


def validate_batch_dimensions(
    *, expected_dims: Dict[Hashable, int], batch: Union[xr.Dataset, xr.DataArray]
) -> None:
    """
    Raises an AssertionError if the shape and dimensions of a batch do not
    match expected_dims.

    Parameters
    ----------
    expected_dims : Dict
        Dict containing the expected dimensions for batches. Its insertion
        order is taken as the expected dimension order.
    batch : xarray.Dataset or xarray.DataArray
        The xarray data object returned by the batch generator.

    Raises
    ------
    AssertionError
        If the dimension names, lengths, or order differ from expected_dims.
    """
    checker = TestCase()

    # Check the names and lengths of the dimensions are equal.
    # ``sizes`` is copied into a plain dict rather than reaching into the
    # wrapper's ``.mapping`` attribute, so any Mapping returned by xarray works.
    checker.assertDictEqual(
        expected_dims, dict(batch.sizes), msg="Dimension names and/or lengths differ"
    )

    # A DataArray has no ``data_vars``; it is itself the only array to check.
    if isinstance(batch, xr.DataArray):
        arrays = [batch]
    else:
        arrays = [batch[var] for var in batch.data_vars]

    # Check the dimension order is equal
    expected_shape = tuple(expected_dims.values())
    for array in arrays:
        checker.assertEqual(
            expected_shape,
            array.shape,
            msg=f"Order differs for dimensions of: {expected_dims}",
        )


def _get_nbatches_from_input_dims(generator: BatchGenerator) -> int:
    """
    Calculate the number of batches expected based on ``input_dims`` and
    ``input_overlap``.

    Parameters
    ----------
    generator : xbatcher.BatchGenerator
        The batch generator object.

    Returns
    -------
    s : int
        Number of batches expected given ``input_dims`` and ``input_overlap``.
    """
    sizes = generator.ds.sizes
    input_dims = generator.input_dims
    input_overlap = generator.input_overlap

    # Input dimensions without overlap: windows tile the dimension end to end.
    # (np.prod replaces np.product, a deprecated alias removed in NumPy 2.0.)
    nbatches_from_input_dims = np.prod(
        [sizes[k] // input_dims[k] for k in input_dims if input_overlap.get(k) is None]
    )
    if not input_overlap:
        return int(nbatches_from_input_dims)

    # Overlapping dimensions: each window advances by (length - overlap).
    nbatches_from_input_overlap = np.prod(
        [
            (sizes[k] - input_overlap[k]) // (input_dims[k] - input_overlap[k])
            for k in input_overlap
        ]
    )
    return int(nbatches_from_input_overlap * nbatches_from_input_dims)


def validate_generator_length(generator: BatchGenerator) -> None:
    """
    Raises an AssertionError if the generator length does not match
    expectations based on the input Dataset and ``input_dims``.

    Parameters
    ----------
    generator : xbatcher.BatchGenerator
        The batch generator object.

    Raises
    ------
    AssertionError
        If ``len(generator)`` differs from the expected number of batches.
    """
    non_input_batch_dims = _get_non_input_batch_dims(generator)
    # Number of batches contributed by dimensions that are only batched over.
    # (np.prod replaces np.product, a deprecated alias removed in NumPy 2.0;
    # the product of an empty list is 1.0.)
    nbatches_from_batch_dims = np.prod(
        [
            generator.ds.sizes[k] // batch_length
            for k, batch_length in non_input_batch_dims.items()
        ]
    )
    if generator.concat_input_dims:
        # With concat_input_dims, only the batch dimensions are counted here.
        expected_length = int(nbatches_from_batch_dims)
    else:
        nbatches_from_input_dims = _get_nbatches_from_input_dims(generator)
        expected_length = int(nbatches_from_batch_dims * nbatches_from_input_dims)
    TestCase().assertEqual(
        expected_length,
        len(generator),
        msg="Batch generator length differs",
    )
Loading