Add support for 16 bits png images #4657
Conversation
{int64_t(height), int64_t(width), channels},
bit_depth <= 8 ? torch::kU8 : torch::kI32);

if (bit_depth <= 8) {
This `if` block is unchanged and corresponds to the original code. I just renamed `ptr` into `t_ptr`, because the other block uses too many pointers for `ptr` to be explicit enough.
@@ -11,6 +11,11 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
}
#else

bool is_little_endian() {
// We're reading a 16bits png, but pytorch doesn't support uint16.
// So we read each row in a 16bits tmp_buffer which we then cast into
// a int32 tensor instead.
if (is_little_endian()) {
@fmassa I eventually realized that this was a much cleaner and simpler way to handle the endianness. The rest takes care of itself when we cast the uint16 value into an int32_t a few lines below.
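For reference, a minimal standalone sketch of such an endianness check (only the helper name `is_little_endian` comes from the diff; the rest is illustrative). PNG stores 16-bit samples in big-endian byte order, so on little-endian hosts the decoder presumably asks libpng to swap bytes (e.g. via `png_set_swap`) before the values are usable directly:

```cpp
#include <cstdint>
#include <cstring>

// Detect host byte order at runtime by inspecting the memory layout
// of a known 16-bit value. If the low-order byte is stored first,
// the host is little-endian.
bool is_little_endian() {
  uint16_t probe = 1;
  uint8_t first_byte;
  std::memcpy(&first_byte, &probe, 1);
  return first_byte == 1;
}
```

On a little-endian host this returns `true`, which is exactly the case where the big-endian PNG samples need swapping.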
Awesome!
Thanks a ton for adding support for 16-bit PNGs!
I have one concern about the current implementation; otherwise the rest LGTM!
uint16_t* tmp_buffer =
    (uint16_t*)malloc(num_pixels_per_row * sizeof(uint16_t));
This leads to a memory leak at the end of the function. If you `malloc`, you need to `free` after it's used, but then you'll need to handle a few corner cases on the freeing side (what if `png_read_row` fails?).
I think it would be easier to just allocate the buffer via PyTorch's `torch::empty` (or as raw data via `at::DataPtr` through `at::getCPUAllocator()->allocate(length);`, but I think `torch::empty` is easier to use; up to you).
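To illustrate the concern in plain C++ (no torch dependency): with `malloc`, every exit path must `free`, including error paths; with an RAII container the buffer is released automatically. In this sketch `std::vector` plays the role that `torch::empty` plays in the actual fix, and `read_row` is a hypothetical stand-in for `png_read_row`:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for png_read_row; the real routine can fail
// mid-decode, which is exactly when a raw malloc'd buffer would leak.
void read_row(uint16_t* dst, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i)
    dst[i] = static_cast<uint16_t>(i);
}

// RAII-style allocation: the vector frees its storage on every exit
// path (normal return or exception), so no free() bookkeeping is needed.
std::vector<uint16_t> read_one_row(std::size_t num_pixels_per_row) {
  std::vector<uint16_t> tmp_buffer(num_pixels_per_row);
  read_row(tmp_buffer.data(), tmp_buffer.size());
  return tmp_buffer;
}
```

A tensor allocated with `torch::empty` gives the same guarantee: its storage is reference-counted and freed when the tensor goes out of scope.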
torchvision/io/image.py
Outdated
@@ -61,7 +61,12 @@ def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGE
"""
Decodes a PNG image into a 3 dimensional RGB Tensor.
Optionally converts the image to the desired format.
The values of the output tensor are uint8 between 0 and 255.
The values of the output tensor are uint8 between 0 and 255, except for
16-bits pngs which are int32 tensors.
Can you also mention the range for 16-bit pngs, which is 0-65535?
// We're reading a 16bits png, but pytorch doesn't support uint16.
// So we read each row in a 16bits tmp_buffer which we then cast into
// a int32 tensor instead.
if (is_little_endian()) {
Awesome!
Thanks!
auto tmp_buffer_tensor = torch::empty(
    {int64_t(num_pixels_per_row * sizeof(uint16_t))}, torch::kU8);
uint16_t* tmp_buffer =
    (uint16_t*)tmp_buffer_tensor.accessor<uint8_t, 1>().data();
Nit, because it was already like this before: you can just do `tmp_buffer_tensor.data_ptr<uint8_t>()`.
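For context, here is a standalone sketch (plain C++, no torch) of what happens to this buffer afterwards: the raw bytes of a row are read back as uint16 samples and each is widened to `int32_t`, which the diff does a few lines below. `int32_t` covers the full 0..65535 range losslessly, which is why the PR can return int32 in place of the unsupported uint16 dtype. All names here are illustrative:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Widen a row of 16-bit samples stored in a raw byte buffer (the role
// played by the torch::kU8 tensor in the PR) into int32 values.
// memcpy is used instead of a pointer cast to avoid alignment issues.
std::vector<int32_t> widen_row(const std::vector<uint8_t>& bytes) {
  std::size_t n = bytes.size() / sizeof(uint16_t);
  std::vector<int32_t> out(n);
  for (std::size_t i = 0; i < n; ++i) {
    uint16_t v;
    std::memcpy(&v, bytes.data() + i * sizeof(uint16_t), sizeof(uint16_t));
    out[i] = static_cast<int32_t>(v);  // lossless: int32 holds 0..65535
  }
  return out;
}
```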
Summary:
* WIP
* cleaner code
* Add tests
* Add docs
* Assert dtype
* put back check
* Address comments

Reviewed By: NicolasHug
Differential Revision: D31916334
fbshipit-source-id: 8877266f6e533e8c45c5f202e535944a9a939376
Co-authored-by: Francisco Massa <[email protected]>
Closes #4107
Closes #2218
This PR adds support for 16-bit PNGs. Since PyTorch doesn't support the uint16 dtype, we return int32 tensors instead (we indicate in the doc that we will return uint16 tensors in the future, if PyTorch starts supporting it).
Among other things, this will enable training RAFT on the KITTI dataset, which currently can only be done by relying on OpenCV.
PIL support for 16-bit PNGs is a bit limited and buggy, especially for grayscale images (python-pillow/Pillow#3011). PIL also automatically converts the 16-bit values to uint8, losing tons of precision. This makes it hard to test; for this reason I only added tests for one RGB image and one RGBA image. According to a few ad-hoc tests, grayscale images are decoded properly (unlike with PIL).
Also, for all 200 KITTI-Flow ground-truth flow images, this code returns exactly the same values as the cv2 version.
This code takes about the same time as cv2 to decode a 1567 x 1965 RGBA image. PIL is a lot faster, but I assume that this is because it downcasts everything to uint8.
Note: I observe the same relative performance on 8-bit images: torchvision == cv2 >> PIL.