Let torch accessor and dataloader handle either xarray.DataArray or xarray.Dataset inputs #85

Merged
merged 4 commits into from
Aug 25, 2022


weiji14
Member

@weiji14 weiji14 commented Aug 19, 2022

Description of proposed changes

Convert xarray.Dataset to xarray.DataArray first, so that the .data attribute can be used to get the underlying array, which can then be converted to a torch.Tensor. Also added parametrized unit tests covering both xarray.DataArray and xarray.Dataset inputs, for the 2D ('x', 'y') case only.

TODO:

  • Preliminary support in .torch accessor to convert xarray.Dataset directly to torch.Tensor
  • Add parametrized unit tests for xarray.Dataset and xarray.DataArray inputs
    • In the DataLoader xbatcher/loaders/torch.py
    • In the xarray accessors xbatcher/accessors.py

Fixes #84, continuation of #71
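
The conversion described above can be sketched roughly as follows. This is a minimal illustration and not necessarily the code in this PR: the helper name `as_dataarray` is hypothetical, and it assumes `Dataset.to_array()` with its default new dimension name "variable".

```python
import numpy as np
import torch
import xarray as xr


def as_dataarray(obj):
    """Hypothetical helper: return an xarray.DataArray for either input type."""
    if isinstance(obj, xr.Dataset):
        # Dataset.to_array() stacks the data variables along a new
        # leading dimension, named "variable" by default.
        return obj.to_array()
    return obj


# A 2D ('x', 'y') DataArray and the equivalent single-variable Dataset.
da = xr.DataArray(
    np.arange(12, dtype="float32").reshape(3, 4), dims=("x", "y"), name="foo"
)
ds = da.to_dataset()

# .data exposes the underlying array, which torch can wrap as a tensor.
tensor_from_da = torch.as_tensor(as_dataarray(da).data)
tensor_from_ds = torch.as_tensor(as_dataarray(ds).data)

print(tuple(tensor_from_da.shape))  # (3, 4)
print(tuple(tensor_from_ds.shape))  # (1, 3, 4), with "variable" first

# For a single-variable Dataset, squeezing "variable" restores the 2D shape.
squeezed = torch.as_tensor(as_dataarray(ds).squeeze(dim="variable").data)
print(tuple(squeezed.shape))  # (3, 4)
```

Note that squeezing only works when the Dataset holds exactly one data variable; with several variables the "variable" dimension has to be kept.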

Convert xarray.Dataset to xarray.DataArray first, so that the `.data` attribute works to get the underlying array, which can be converted to a torch.Tensor.

codecov bot commented Aug 19, 2022

Codecov Report

Merging #85 (106fff9) into main (0ded974) will not change coverage.
The diff coverage is 100.00%.

@@            Coverage Diff            @@
##              main       #85   +/-   ##
=========================================
  Coverage   100.00%   100.00%           
=========================================
  Files            5         5           
  Lines          179       182    +3     
  Branches        40        37    -3     
=========================================
+ Hits           179       182    +3     
Impacted Files               Coverage Δ
xbatcher/accessors.py        100.00% <100.00%> (ø)
xbatcher/generators.py       100.00% <100.00%> (ø)
xbatcher/loaders/keras.py    100.00% <100.00%> (ø)
xbatcher/loaders/torch.py    100.00% <100.00%> (ø)


Need to squeeze the extra first dimension in order to preserve the same output shape for xarray.DataArray and xarray.Dataset.
Resolve the strange extra dimension of size 1, which appears because torch.utils.data.DataLoader defaults to `batch_size=1` and prepends a batch dimension. Setting `batch_size=None` disables automatic batching, so no extra batch dimension is prepended.
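
The batch dimension behaviour can be reproduced with plain PyTorch. The sketch below uses a hypothetical `ToyBatches` dataset as a stand-in for a map-style dataset that already yields whole batches; it is not the xbatcher loader itself.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyBatches(Dataset):
    """Hypothetical stand-in: every item is already a full (3, 4) batch."""

    def __len__(self):
        return 2

    def __getitem__(self, idx):
        return torch.full((3, 4), float(idx))


# Default automatic batching (batch_size=1) collates each item into a
# batch of one, prepending a dimension of size 1.
with_batch_dim = next(iter(DataLoader(ToyBatches())))
print(tuple(with_batch_dim.shape))  # (1, 3, 4)

# batch_size=None disables automatic batching, so items pass through as-is.
without_batch_dim = next(iter(DataLoader(ToyBatches(), batch_size=None)))
print(tuple(without_batch_dim.shape))  # (3, 4)
```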
@weiji14 weiji14 marked this pull request as ready for review August 19, 2022 18:21
@weiji14 weiji14 marked this pull request as draft August 19, 2022 18:22
@weiji14 weiji14 marked this pull request as ready for review August 19, 2022 18:36
@weiji14 weiji14 changed the title from "Let torch dataloader handle either xarray.DataArray or xarray.Dataset objects" to "Let torch accessor and dataloader handle either xarray.DataArray or xarray.Dataset objects" Aug 19, 2022
@weiji14 weiji14 changed the title from "Let torch accessor and dataloader handle either xarray.DataArray or xarray.Dataset objects" to "Let torch accessor and dataloader handle either xarray.DataArray or xarray.Dataset inputs" Aug 19, 2022
Member

@maxrjones maxrjones left a comment

LGTM, thanks @weiji14!

It might be worthwhile to move _as_xarray_dataarray outside of the TorchAccessor class and make the check between the last item in the generator and batch a general testing utility function, but we could do that when they're needed elsewhere in the package.
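
For illustration only, a general testing utility along those lines might look like the sketch below. The name `assert_batch_matches` and its signature are hypothetical and not part of xbatcher; it assumes a Dataset with a single data variable, so that the "variable" dimension created by `Dataset.to_array()` can be squeezed away.

```python
import numpy as np
import torch
import xarray as xr


def assert_batch_matches(expected, batch):
    """Hypothetical test helper: compare an xarray object with a torch batch.

    `expected` may be an xarray.DataArray or a single-variable
    xarray.Dataset; `batch` is the torch.Tensor produced from it.
    """
    if isinstance(expected, xr.Dataset):
        # Assumption: exactly one data variable, so "variable" has size 1.
        expected = expected.to_array().squeeze(dim="variable")
    np.testing.assert_array_equal(expected.data, batch.numpy())


# Usage: the same check works for both input types.
da = xr.DataArray(np.ones((2, 3)), dims=("x", "y"), name="foo")
assert_batch_matches(da, torch.ones(2, 3, dtype=torch.float64))
assert_batch_matches(da.to_dataset(), torch.ones(2, 3, dtype=torch.float64))
```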

@weiji14
Member Author

weiji14 commented Aug 24, 2022

> It might be worthwhile to move _as_xarray_dataarray outside of the TorchAccessor class and make the check between the last item in the generator and batch a general testing utility function, but we could do that when they're needed elsewhere in the package.

Yep, could probably reuse those utility functions for the keras/tensorflow side too!

Btw, I don't have merge permissions, but happy to be added as a maintainer if extra hands are welcome 😃

@maxrjones
Member

> Btw, I don't have merge permissions, but happy to be added as a maintainer if extra hands are welcome 😃

Fantastic! I'll add you as soon as we finish the transfer discussed at #86.

@maxrjones maxrjones merged commit ed45a99 into xarray-contrib:main Aug 25, 2022
@weiji14 weiji14 deleted the accessor/da_ds_support branch August 25, 2022 15:00
@weiji14 weiji14 mentioned this pull request Aug 31, 2022
7 tasks
@maxrjones maxrjones added the enhancement New feature or request label Sep 23, 2022
@maxrjones maxrjones added feature and removed enhancement New feature or request labels Oct 17, 2022
Successfully merging this pull request may close these issues.

Support xarray.DataArray and xarray.Dataset batches in PyTorch dataloader
2 participants