diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index f501de9..03ffea8 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -9,13 +9,6 @@ on: - cron: "0 0 * * *" jobs: - lint: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2.4.0 - - uses: actions/setup-python@v2.3.0 - - uses: pre-commit/action@v2.0.3 - test: name: ${{ matrix.python-version }}-build runs-on: ubuntu-latest @@ -41,7 +34,13 @@ jobs: python -m pip list - name: Running Tests run: | - python -m pytest --verbose + py.test --verbose --cov=. --cov-report=xml + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v2.1.0 + if: ${{ matrix.python-version == 3.9 }} + with: + file: ./coverage.xml + fail_ci_if_error: false test-upstream: name: ${{ matrix.python-version }}-dev-build @@ -71,4 +70,4 @@ jobs: python -m pip list - name: Running Tests run: | - python -m pytest --verbose + py.test --verbose --cov=. diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4a50297..4786ba5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,3 +1,7 @@ +ci: + autoupdate_schedule: quarterly + autofix_prs: false + repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.0.1 diff --git a/dev-requirements.txt b/dev-requirements.txt index 21fab1d..0aeec8b 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,4 +1,4 @@ pytest -coverage +pytest-cov adlfs -r requirements.txt diff --git a/xbatcher/features.py b/xbatcher/features.py deleted file mode 100644 index 41c3de0..0000000 --- a/xbatcher/features.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Functions for transforming xarray datasets into features that can -be input to machine learning libraries.""" - - -def dataset_to_feature_dataframe(ds, coords_as_features=False): - df = ds.to_dataframe() - return df diff --git a/xbatcher/tests/test_generators.py b/xbatcher/tests/test_generators.py index 54984b3..38acae9 100644 --- a/xbatcher/tests/test_generators.py 
+++ b/xbatcher/tests/test_generators.py @@ -18,6 +18,29 @@ def sample_ds_1d(): return ds +@pytest.fixture(scope='module') +def sample_ds_3d(): + shape = (10, 50, 100) + ds = xr.Dataset( + { + 'foo': (['time', 'y', 'x'], np.random.rand(*shape)), + 'bar': (['time', 'y', 'x'], np.random.randint(0, 10, shape)), + }, + { + 'x': (['x'], np.arange(shape[-1])), + 'y': (['y'], np.arange(shape[-2])), + }, + ) + return ds + + +def test_constructor_coerces_to_dataset(): + da = xr.DataArray(np.random.rand(10), dims='x', name='foo') + bg = BatchGenerator(da, input_dims={'x': 2}) + assert isinstance(bg.ds, xr.Dataset) + assert bg.ds.equals(da.to_dataset()) + + # TODO: decide how to handle bsizes like 15 that don't evenly divide the dimension # Should we enforce that each batch size always has to be the same @pytest.mark.parametrize('bsize', [5, 10]) @@ -86,22 +109,6 @@ def test_batch_1d_overlap(sample_ds_1d, olap): assert ds_batch.equals(ds_batch_expected) -@pytest.fixture(scope='module') -def sample_ds_3d(): - shape = (10, 50, 100) - ds = xr.Dataset( - { - 'foo': (['time', 'y', 'x'], np.random.rand(*shape)), - 'bar': (['time', 'y', 'x'], np.random.randint(0, 10, shape)), - }, - { - 'x': (['x'], np.arange(shape[-1])), - 'y': (['y'], np.arange(shape[-2])), - }, - ) - return ds - - @pytest.mark.parametrize('bsize', [5, 10]) def test_batch_3d_1d_input(sample_ds_3d, bsize): @@ -160,3 +167,25 @@ def test_batch_3d_2d_input_concat(sample_ds_3d, bsize): * (sample_ds_3d.dims['y'] // bsize) * sample_ds_3d.dims['time'] ) + + +def test_preload_batch_false(sample_ds_1d): + sample_ds_1d_dask = sample_ds_1d.chunk({'x': 2}) + bg = BatchGenerator( + sample_ds_1d_dask, input_dims={'x': 2}, preload_batch=False + ) + assert bg.preload_batch is False + for ds_batch in bg: + assert isinstance(ds_batch, xr.Dataset) + assert ds_batch.chunks + + +def test_preload_batch_true(sample_ds_1d): + sample_ds_1d_dask = sample_ds_1d.chunk({'x': 2}) + bg = BatchGenerator( + sample_ds_1d_dask, input_dims={'x': 2}, 
preload_batch=True + ) + assert bg.preload_batch is True + for ds_batch in bg: + assert isinstance(ds_batch, xr.Dataset) + assert not ds_batch.chunks