Add dtype field to RasterDataset #1149
Technically this is an API change (we introduce a new class attribute) so it should wait until 0.5.0, but it also fixes an unintended consequence of #992, so it could also be argued that it should go in 0.4.1. Not sure which to choose.
We'll need to take this PR into account when thinking about #985.
Just to complicate things, different bands of the same image may have different dtypes: #1182. Not sure how to help with this one. We can't stack all bands of the image into one tensor unless the tensor has a single shared dtype.
That needs to be handled by casting in a pre-processing step. It doesn't make sense for RasterDataset to have different dtypes for different bands, because the input to a neural network needs to be a single float type (float32, float16, and so on), not an int/float mashup.
In the case of #1182 it may make more sense for the PIXEL_QA band to be read as a mask instead of an image.
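A minimal sketch of the casting step mentioned above, assuming the bands have already been read into tensors. The helper name `stack_bands` and the two example bands are hypothetical and are not part of torchgeo.

```python
import torch


def stack_bands(bands, dtype=torch.float32):
    """Cast every band to one shared dtype, then stack along a new channel axis.

    Hypothetical helper for illustration only; not a torchgeo function.
    """
    return torch.stack([band.to(dtype) for band in bands], dim=0)


# Made-up bands: a float32 reflectance band and a uint8 quality band.
reflectance = torch.rand(4, 4, dtype=torch.float32)
pixel_qa = torch.randint(0, 255, (4, 4), dtype=torch.uint8)

image = stack_bands([reflectance, pixel_qa])
print(image.shape, image.dtype)
```

Casting explicitly keeps the choice of the shared dtype visible in the pre-processing code instead of leaving it to implicit promotion rules.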
Related to this, it may be useful to add a similar setting for the resampling algorithm. By default, all resampling is done by nearest neighbor, which is perfect for classification masks, but less useful (albeit fast) for images.
This is supported -- you would make one dataset for the imagery and one dataset for the PIXEL_QA mask, then take their intersection.
Agree, I'll open an issue.
Is there anything that is holding this PR back?
The only thing holding this PR back is that |
Often the masks will be longs while the inputs will be float32s. You can try running this:
You should get:
Kinda surprised float64 doesn't work. I would be fine with calling it |
Actually, calling it |
Let's call it dtype then. I would either let it apply to both images and masks, or raise an error when someone uses a dtype other than float32 for an image.
How about a UserWarning? Raising an all-out error seems extreme to me when everything should otherwise work.
This should help with #849, btw.
Needs tests, see https://docs.pytest.org/en/7.1.x/how-to/capture-warnings.html
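A rough sketch of the UserWarning idea together with a test that captures it using `pytest.warns`. The function `check_image_dtype` and its message are assumptions made for illustration; the PR's actual check may look different.

```python
import warnings

import pytest
import torch


def check_image_dtype(dtype, is_image=True):
    """Warn, rather than raise, when an image uses a dtype other than float32.

    Hypothetical helper; the name and message are not taken from torchgeo.
    """
    if is_image and dtype != torch.float32:
        warnings.warn(
            f"Images are expected to be float32, got {dtype}",
            UserWarning,
            stacklevel=2,
        )
    return dtype


def test_non_float32_image_warns():
    # pytest.warns fails the test if no matching UserWarning is emitted.
    with pytest.warns(UserWarning, match="float32"):
        check_image_dtype(torch.float64)
```

A warning leaves the dataset usable for anyone who knowingly picks another dtype, which matches the concern that an error would be too strict.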
I'm fine with this approach. We could also use something like:
    @property
    def dtype(self) -> torch.dtype:
        if self.is_image:
            return torch.float32
        else:
            return torch.long
Then we would only need to override this for the small number of datasets that have a different dtype.
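To illustrate how the property-based variant could be overridden, here is a small sketch. `BaseRaster`, `MaskRaster`, and `RegressionTarget` are hypothetical stand-ins and not real torchgeo classes.

```python
import torch


class BaseRaster:
    """Hypothetical stand-in for RasterDataset, reduced to the dtype logic."""

    is_image = True

    @property
    def dtype(self) -> torch.dtype:
        # Default: float32 for images, long for masks.
        if self.is_image:
            return torch.float32
        return torch.long


class MaskRaster(BaseRaster):
    """A classification mask keeps the default long dtype."""

    is_image = False


class RegressionTarget(BaseRaster):
    """A real-valued target layer overrides the default mask dtype."""

    is_image = False

    @property
    def dtype(self) -> torch.dtype:
        return torch.float32


print(BaseRaster().dtype, MaskRaster().dtype, RegressionTarget().dtype)
```

Only the dataset with an unusual dtype needs to define anything; every other subclass inherits the default.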
In the great DataModule overhaul of winter 2022/2023 (#992) we made sure that all the dtypes returned by our datasets played nicely with Kornia. As part of this, we made the assumption in RasterDataset that all "mask" layers should be torch.long and all images should be torch.float. As a result, our RasterDatasets essentially cannot be used for real-valued regression tasks.
This PR aims to fix this by introducing a new field to RasterDataset called dtype that, if set, forces a cast to that type right before the dataset-level transform is called. With this we can fix the dtype expectations on a per-dataset level.
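The behavior described above might look roughly like the following. This is a simplified sketch, not the PR's code: `TinyRaster` is a hypothetical class that indexes an in-memory tensor instead of reading raster files, and the fallback defaults are an assumption based on the #992 behavior described in this PR.

```python
from typing import Callable, Optional

import torch


class TinyRaster:
    """Hypothetical, simplified stand-in for RasterDataset."""

    is_image = False
    # Class attribute: if set, samples are cast to this dtype.
    dtype: Optional[torch.dtype] = None

    def __init__(self, data: torch.Tensor, transforms: Optional[Callable] = None):
        self.data = data
        self.transforms = transforms

    def __getitem__(self, index: int) -> dict:
        tensor = self.data[index]
        if self.dtype is not None:
            # An explicitly set dtype wins.
            tensor = tensor.to(self.dtype)
        else:
            # Assumed defaults: float32 images, long masks.
            tensor = tensor.to(torch.float32 if self.is_image else torch.long)
        sample = {"image" if self.is_image else "mask": tensor}
        # The cast happens right before the dataset-level transform.
        if self.transforms is not None:
            sample = self.transforms(sample)
        return sample


class RegressionRaster(TinyRaster):
    """Real-valued target layer: keep floats instead of casting to long."""

    dtype = torch.float32


data = torch.tensor([[0.25, 1.75]])
print(TinyRaster(data)[0]["mask"])        # values truncated by the long cast
print(RegressionRaster(data)[0]["mask"])  # real values preserved
```

With the default mask dtype the fractional values are truncated, which is the limitation for regression tasks that the new field is meant to remove.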