From bbe34990fb7f6b564912e6074d2b2b4ff1d40e61 Mon Sep 17 00:00:00 2001 From: Anderson Banihirwe <13301940+andersy005@users.noreply.github.com> Date: Thu, 15 Aug 2024 16:12:27 -0700 Subject: [PATCH] Drop support for Python 3.9 (#232) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/workflows/main.yaml | 4 ++-- pyproject.toml | 5 ++-- xbatcher/accessors.py | 10 ++++---- xbatcher/generators.py | 48 +++++++++++++++++-------------------- xbatcher/loaders/keras.py | 7 +++--- xbatcher/loaders/torch.py | 7 +++--- xbatcher/testing.py | 3 +-- 7 files changed, 40 insertions(+), 44 deletions(-) diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index a6d51ee..3aacd5e 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -22,7 +22,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12"] fail-fast: false steps: - uses: actions/checkout@v4 @@ -55,7 +55,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11", "3.12"] fail-fast: false steps: - uses: actions/checkout@v4 diff --git a/pyproject.toml b/pyproject.toml index fa8b2af..4927be2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ description = "Batch generation from Xarray objects" readme = "README.rst" license = {text = "Apache"} authors = [{name = "xbatcher Developers", email = "rpa@ldeo.columbia.edu"}] -requires-python = ">=3.9" +requires-python = ">=3.10" classifiers = [ "Development Status :: 4 - Beta", "License :: OSI Approved :: Apache Software License", @@ -19,7 +19,6 @@ classifiers = [ "Intended Audience :: Science/Research", "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", @@ -62,7 +61,7 @@ fallback_version = "999" [tool.ruff] -target-version = "py39" +target-version = "py310" extend-include = ["*.ipynb"] diff --git a/xbatcher/accessors.py b/xbatcher/accessors.py index d43d606..f05cac1 100644 --- a/xbatcher/accessors.py +++ b/xbatcher/accessors.py @@ -1,11 +1,11 @@ -from typing import Any, Union +from typing import Any import xarray as xr from .generators import BatchGenerator -def _as_xarray_dataarray(xr_obj: Union[xr.Dataset, xr.DataArray]) -> xr.DataArray: +def _as_xarray_dataarray(xr_obj: xr.Dataset | xr.DataArray) -> xr.DataArray: """ Convert xarray.Dataset to xarray.DataArray if needed, so that it can be converted into a Tensor object. 
@@ -19,7 +19,7 @@ def _as_xarray_dataarray(xr_obj: Union[xr.Dataset, xr.DataArray]) -> xr.DataArra @xr.register_dataarray_accessor('batch') @xr.register_dataset_accessor('batch') class BatchAccessor: - def __init__(self, xarray_obj: Union[xr.Dataset, xr.DataArray]): + def __init__(self, xarray_obj: xr.Dataset | xr.DataArray): """ Batch accessor returning a BatchGenerator object via the `generator method` """ @@ -42,7 +42,7 @@ def generator(self, *args, **kwargs) -> BatchGenerator: @xr.register_dataarray_accessor('tf') @xr.register_dataset_accessor('tf') class TFAccessor: - def __init__(self, xarray_obj: Union[xr.Dataset, xr.DataArray]): + def __init__(self, xarray_obj: xr.Dataset | xr.DataArray): self._obj = xarray_obj def to_tensor(self) -> Any: @@ -57,7 +57,7 @@ def to_tensor(self) -> Any: @xr.register_dataarray_accessor('torch') @xr.register_dataset_accessor('torch') class TorchAccessor: - def __init__(self, xarray_obj: Union[xr.Dataset, xr.DataArray]): + def __init__(self, xarray_obj: xr.Dataset | xr.DataArray): self._obj = xarray_obj def to_tensor(self) -> Any: diff --git a/xbatcher/generators.py b/xbatcher/generators.py index 5640f32..412d09a 100644 --- a/xbatcher/generators.py +++ b/xbatcher/generators.py @@ -3,9 +3,9 @@ import itertools import json import warnings -from collections.abc import Hashable, Iterator, Sequence +from collections.abc import Callable, Hashable, Iterator, Sequence from operator import itemgetter -from typing import Any, Callable, Optional, Union +from typing import Any import numpy as np import xarray as xr @@ -55,10 +55,10 @@ class BatchSchema: def __init__( self, - ds: Union[xr.Dataset, xr.DataArray], + ds: xr.Dataset | xr.DataArray, input_dims: dict[Hashable, int], - input_overlap: Optional[dict[Hashable, int]] = None, - batch_dims: Optional[dict[Hashable, int]] = None, + input_overlap: dict[Hashable, int] | None = None, + batch_dims: dict[Hashable, int] | None = None, concat_input_bins: bool = True, preload_batch: bool = True, ): @@ -91,9 +91,7 @@ def __init__( ) self.selectors: BatchSelectorSet = self._gen_batch_selectors(ds) - def _gen_batch_selectors( - self, ds: Union[xr.DataArray, xr.Dataset] - ) -> BatchSelectorSet: + def _gen_batch_selectors(self, ds: xr.DataArray | xr.Dataset) -> BatchSelectorSet: """ Create batch selectors dict, which can be used to create a batch from an Xarray data object. @@ -106,9 +104,7 @@ def _gen_batch_selectors( else: # Each patch gets its own batch return {ind: [value] for ind, value in enumerate(patch_selectors)} - def _gen_patch_selectors( - self, ds: Union[xr.DataArray, xr.Dataset] - ) -> PatchGenerator: + def _gen_patch_selectors(self, ds: xr.DataArray | xr.Dataset) -> PatchGenerator: """ Create an iterator that can be used to index an Xarray Dataset/DataArray. 
""" @@ -127,7 +123,7 @@ def _gen_patch_selectors( return all_slices def _combine_patches_into_batch( - self, ds: Union[xr.DataArray, xr.Dataset], patch_selectors: PatchGenerator + self, ds: xr.DataArray | xr.Dataset, patch_selectors: PatchGenerator ) -> BatchSelectorSet: """ Combine the patch selectors to form a batch @@ -169,7 +165,7 @@ def _combine_patches_grouped_by_batch_dims( return dict(enumerate(batch_selectors)) def _combine_patches_grouped_by_input_and_batch_dims( - self, ds: Union[xr.DataArray, xr.Dataset], patch_selectors: PatchGenerator + self, ds: xr.DataArray | xr.Dataset, patch_selectors: PatchGenerator ) -> BatchSelectorSet: """ Combine patches with multiple slices along ``batch_dims`` grouped into @@ -197,7 +193,7 @@ def _gen_empty_batch_selectors(self) -> BatchSelectorSet: n_batches = np.prod(list(self._n_batches_per_dim.values())) return {k: [] for k in range(n_batches)} - def _gen_patch_numbers(self, ds: Union[xr.DataArray, xr.Dataset]): + def _gen_patch_numbers(self, ds: xr.DataArray | xr.Dataset): """ Calculate the number of patches per dimension and the number of patches in each batch per dimension. @@ -214,7 +210,7 @@ def _gen_patch_numbers(self, ds: Union[xr.DataArray, xr.Dataset]): for dim, length in self._all_sliced_dims.items() } - def _gen_batch_numbers(self, ds: Union[xr.DataArray, xr.Dataset]): + def _gen_batch_numbers(self, ds: xr.DataArray | xr.Dataset): """ Calculate the number of batches per dimension """ @@ -324,7 +320,7 @@ def _gen_slices(*, dim_size: int, slice_size: int, overlap: int = 0) -> list[sli def _iterate_through_dimensions( - ds: Union[xr.Dataset, xr.DataArray], + ds: xr.Dataset | xr.DataArray, *, dims: dict[Hashable, int], overlap: dict[Hashable, int] = {}, @@ -350,10 +346,10 @@ def _iterate_through_dimensions( def _drop_input_dims( - ds: Union[xr.Dataset, xr.DataArray], + ds: xr.Dataset | xr.DataArray, input_dims: dict[Hashable, int], suffix: str = '_input', -) -> Union[xr.Dataset, xr.DataArray]: +) -> xr.Dataset | xr.DataArray: # remove input_dims coordinates from datasets, rename the dimensions # then put intput_dims back in as coordinates out = ds.copy() @@ -368,9 +364,9 @@ def _drop_input_dims( def _maybe_stack_batch_dims( - ds: Union[xr.Dataset, xr.DataArray], + ds: xr.Dataset | xr.DataArray, input_dims: Sequence[Hashable], -) -> Union[xr.Dataset, xr.DataArray]: +) -> xr.Dataset | xr.DataArray: batch_dims = [d for d in ds.sizes if d not in input_dims] if len(batch_dims) < 2: return ds @@ -424,14 +420,14 @@ class BatchGenerator: def __init__( self, - ds: Union[xr.Dataset, xr.DataArray], + ds: xr.Dataset | xr.DataArray, input_dims: dict[Hashable, int], input_overlap: dict[Hashable, int] = {}, batch_dims: dict[Hashable, int] = {}, concat_input_dims: bool = False, preload_batch: bool = True, - cache: Optional[dict[str, Any]] = None, - cache_preprocess: Optional[Callable] = None, + cache: dict[str, Any] | None = None, + cache_preprocess: Callable | None = None, ): self.ds = ds self.cache = cache @@ -466,14 +462,14 @@ def concat_input_dims(self): def preload_batch(self): return self._batch_selectors.preload_batch - def __iter__(self) -> Iterator[Union[xr.DataArray, xr.Dataset]]: + def __iter__(self) -> Iterator[xr.DataArray | xr.Dataset]: for idx in self._batch_selectors.selectors: yield self[idx] def __len__(self) -> int: return len(self._batch_selectors.selectors) - def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]: + def __getitem__(self, idx: int) -> xr.Dataset | xr.DataArray: if not isinstance(idx, int): raise 
NotImplementedError( f'{type(self).__name__}.__getitem__ currently requires a single integer key' @@ -532,7 +528,7 @@ def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]: def _batch_in_cache(self, idx: int) -> bool: return self.cache is not None and f'{idx}/.zgroup' in self.cache - def _cache_batch(self, idx: int, batch: Union[xr.Dataset, xr.DataArray]) -> None: + def _cache_batch(self, idx: int, batch: xr.Dataset | xr.DataArray) -> None: batch.to_zarr(self.cache, group=str(idx), mode='a') def _get_cached_batch(self, idx: int) -> xr.Dataset: diff --git a/xbatcher/loaders/keras.py b/xbatcher/loaders/keras.py index b3ad023..86dfb40 100644 --- a/xbatcher/loaders/keras.py +++ b/xbatcher/loaders/keras.py @@ -1,4 +1,5 @@ -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any try: import tensorflow as tf @@ -21,8 +22,8 @@ def __init__( X_generator, y_generator, *, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, + transform: Callable | None = None, + target_transform: Callable | None = None, ) -> None: """ Keras Dataset adapter for Xbatcher diff --git a/xbatcher/loaders/torch.py b/xbatcher/loaders/torch.py index 9bfe9f1..77ebcb8 100644 --- a/xbatcher/loaders/torch.py +++ b/xbatcher/loaders/torch.py @@ -1,4 +1,5 @@ -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any try: import torch @@ -24,8 +25,8 @@ def __init__( self, X_generator, y_generator, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, + transform: Callable | None = None, + target_transform: Callable | None = None, ) -> None: """ PyTorch Dataset adapter for Xbatcher diff --git a/xbatcher/testing.py b/xbatcher/testing.py index 354a458..219d592 100644 --- a/xbatcher/testing.py +++ b/xbatcher/testing.py @@ -1,5 +1,4 @@ from collections.abc import Hashable -from typing import Union from unittest import TestCase import numpy as np @@ -170,7 +169,7 @@ def get_batch_dimensions(generator: BatchGenerator) -> dict[Hashable, int]: def validate_batch_dimensions( - *, expected_dims: dict[Hashable, int], batch: Union[xr.Dataset, xr.DataArray] + *, expected_dims: dict[Hashable, int], batch: xr.Dataset | xr.DataArray ) -> None: """ Raises an AssertionError if the shape and dimensions of a batch do not
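
Every source change in this patch is the same mechanical migration, enabled by raising the floor to `requires-python = ">=3.10"`: PEP 604 union syntax (`X | Y`, `X | None`) replaces `typing.Union` and `typing.Optional`, and `Callable` is imported from `collections.abc` instead of `typing` (the `typing` aliases were deprecated by PEP 585 in Python 3.9). A minimal before/after sketch of the pattern, using a hypothetical `process` function that is not part of this diff:

    # Before: Python 3.9-compatible spellings (what this patch removes)
    from typing import Callable, Optional, Union

    import xarray as xr

    def process(
        ds: Union[xr.Dataset, xr.DataArray],
        transform: Optional[Callable] = None,
    ) -> Union[xr.Dataset, xr.DataArray]:
        return transform(ds) if transform else ds

    # After: PEP 604 unions. `X | Y` in an annotation is evaluated at runtime,
    # so this spelling requires Python >= 3.10 unless
    # `from __future__ import annotations` is in effect.
    from collections.abc import Callable

    import xarray as xr

    def process(
        ds: xr.Dataset | xr.DataArray,
        transform: Callable | None = None,
    ) -> xr.Dataset | xr.DataArray:
        return transform(ds) if transform else ds

The accompanying `target-version = "py310"` bump in pyproject.toml lets Ruff's pyupgrade rules flag any old-style annotations that reappear.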