-
Notifications
You must be signed in to change notification settings - Fork 27
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Let torch accessor and dataloader handle either xarray.DataArray or xarray.Dataset inputs #85
Let torch accessor and dataloader handle either xarray.DataArray or xarray.Dataset inputs #85
Conversation
Convert xarray.Dataset to xarray.DataArray first, so that the `.data` method work to get the underlying array which can be converted to a torch.Tensor.
Codecov Report
@@ Coverage Diff @@
## main #85 +/- ##
=========================================
Coverage 100.00% 100.00%
=========================================
Files 5 5
Lines 179 182 +3
Branches 40 37 -3
=========================================
+ Hits 179 182 +3
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. |
Need to squeeze the extra first dimension in order to preserve the same output shape for xarray.DataArray and xarray.Dataset.
Resolve the strange extra dimension of 1, which is because torch.utils.data.DataLoader adds a batch dimension by default. Setting to `batch_size=None` means no extra batch dimension is prepended.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks @weiji14!
It might be worthwhile to move _as_xarray_dataarray
outside of the TorchAccessor class and make the check between the last item in the generator and batch a general testing utility function, but we could do that when they're needed elsewhere in the package.
Yep, could probably reuse those utility functions for the keras/tensorflow side too! Btw, I don't have merge permissions, but happy to be added as a maintainer if extra hands are welcome 😃 |
fantastic! I'll add you as soon as we finish the transfer discussed at #86. |
Description of proposed changes
Convert xarray.Dataset to xarray.DataArray first, so that the
.data
method works to get the underlying array which can be converted to a torch.Tensor. Also added parametrized unit tests, only for the 2D ('x', 'y') case for bothxarray.DataArray
andxarray.Dataset
.TODO:
.torch
accessor to convertxarray.Dataset
directly totorch.Tensor
xarray.Dataset
andxarray.DataArray
inputsxbatcher/loaders/torch.py
xbatcher/accessor.py
Fixes #84, continuation of #71