
Improve the flexibility of standardize_dtype and fix pad in torch backend #828

Merged

Conversation

@james77777778 (Contributor) commented Sep 1, 2023

One of the advantages of Keras Core is that we can mix workflows across different backends. For example, we can train a TensorFlow model using a torch dataloader.

However, operations that call standardize_dtype might fail when the dtype is a torch.Tensor.dtype and the backend is NOT torch.

import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import torch

from keras_core import ops

x = torch.randn(4, 16, 16, 3)
y = ops.convert_to_tensor(x, dtype=x.dtype)  # failed w/o this PR
print(y.dtype)

This PR addresses the issue by implementing a better check for torch.Tensor.dtype.
A unit test for this behavior has been included.
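
Roughly, the idea is to recognize a torch dtype by its string form and keep only the part after the "torch." prefix. A minimal standalone sketch of that check (the helper name is hypothetical, not the PR's actual code):

import torch

def _standardize_torch_dtype(dtype):
    # str(torch.float32) == "torch.float32"; keep only the part after the dot.
    if hasattr(dtype, "__str__") and "torch" in str(dtype):
        return str(dtype).split(".")[-1]
    return dtype

print(_standardize_torch_dtype(torch.float32))  # "float32"
print(_standardize_torch_dtype("int32"))        # unchanged: "int32"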

@fchollet (Member) left a comment
Thanks for the PR

@@ -408,7 +408,7 @@ def standardize_dtype(dtype):
         dtype = "int32"
     if hasattr(dtype, "name"):
         dtype = dtype.name
-    elif config.backend() == "torch":
+    if hasattr(dtype, "__str__") and "torch" in str(dtype):
@fchollet (Member):

You can use elif for better performance (and above too)
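
A toy illustration of the suggestion; the conditions come from the diff above, but the branch bodies here are illustrative assumptions rather than the PR's exact code:

# Two independent `if` statements always evaluate both conditions:
if hasattr(dtype, "name"):
    dtype = dtype.name
if hasattr(dtype, "__str__") and "torch" in str(dtype):
    dtype = str(dtype).split(".")[-1]

# With `elif`, the second check is skipped once the first branch matches:
if hasattr(dtype, "name"):
    dtype = dtype.name
elif hasattr(dtype, "__str__") and "torch" in str(dtype):
    dtype = str(dtype).split(".")[-1]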

@james77777778 (Contributor, Author):

Done.

@james77777778 (Contributor, Author) commented:

Hi @fchollet,
I have updated standardize_dtype to give the best performance as far as I know.

This change caught subtle bugs in:

  • digitize (numpy backend)
  • pad and isclose (torch backend)

It is surprising that x.dtype == "int" is True when the dtype is np.int64. This results in strange behavior in standardize_dtype.

>>> import numpy as np
>>> x = np.array([0.0, 1.0, 3.0, 1.6])
>>> bins = np.array([0.0, 3.0, 4.5, 7.0])
>>> np.digitize(x, bins).dtype
dtype('int64')
>>> np.digitize(x, bins).dtype == "int"
True
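
The comparison is True because NumPy coerces the string "int" to the platform's default integer dtype before comparing. A quick check, assuming a 64-bit platform where that default is int64:

>>> import numpy as np
>>> np.dtype("int")
dtype('int64')
>>> np.dtype("int") == np.int64
True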

In the torch backend:

@@ -676,7 +677,8 @@ def pad(x, pad_width, mode="constant"):
         mode = "replicate"
     if mode != "constant" and x.ndim < 3:
         new_dims = [1] * (3 - x.ndim)
-        x = cast(x, torch.float32) if x.dtype == torch.int else x
+        if x.dtype not in (torch.float32, torch.float64):
@fchollet (Member):

What about float16?

@james77777778 (Contributor, Author):

I have tried float16 with mode="reflect" and it is not supported by torch:

>>> x = torch.randn(3, 10, 10, dtype=torch.float16)
>>> torch.nn.functional.pad(x, [1, 1], mode="reflect")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: "reflection_pad1d" not implemented for 'Half'

I believe the solution in this PR is the same as the official one in torchvision:

https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py#L418-L423
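
In short, the torchvision code casts to a supported float dtype, pads, and then casts back. A minimal sketch of that pattern (not the PR's exact code):

import torch

x = torch.randn(3, 10, 10, dtype=torch.float16)
# reflection padding is not implemented for float16, so cast, pad, and cast back
y = torch.nn.functional.pad(x.to(torch.float32), [1, 1], mode="reflect").to(x.dtype)
print(y.shape, y.dtype)  # torch.Size([3, 10, 12]) torch.float16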

@james77777778 (Contributor, Author):

@fchollet
Should we restore the original dtype after the op? The current code casts to float32 for non-constant padding modes.

@fchollet (Member):

> Should we restore the original dtype after the op?

In fact, yes -- that is the behavior that other backends follow. Does the unit test only check float32?
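
For reference, np.pad itself preserves the input dtype, which is the behavior the other backends follow:

>>> import numpy as np
>>> x = np.ones((2, 3), dtype="int32")
>>> np.pad(x, ((1, 1), (1, 1)), mode="reflect").dtype
dtype('int32')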

@james77777778 (Contributor, Author):

> In fact, yes -- that is the behavior that other backends follow. Does the unit test only check float32?

Actually, the unit test only checked int64.

def test_pad(self):
    x = np.array([[1, 2], [3, 4]])
    self.assertAllClose(
        knp.pad(x, ((1, 1), (1, 1))),
        np.pad(x, ((1, 1), (1, 1))),
    )
    self.assertAllClose(
        knp.pad(x, ((1, 1), (1, 1))),
        np.pad(x, ((1, 1), (1, 1))),
    )
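
A broader test could parameterize over dtypes and modes. A hypothetical sketch (not the PR's actual test), assuming the same knp alias for the keras_core ops and the same test-case base class:

def test_pad_various_dtypes_and_modes(self):
    # hypothetical coverage sketch: several dtypes and padding modes
    for dtype in ("int32", "float16", "float32"):
        for mode in ("constant", "reflect", "symmetric"):
            x = np.arange(24, dtype=dtype).reshape((2, 3, 4))
            pad_width = ((0, 0), (1, 1), (2, 2))
            self.assertAllClose(
                knp.pad(x, pad_width, mode=mode),
                np.pad(x, pad_width, mode=mode),
            )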

Please see the new comment below

@james77777778 (Contributor, Author) commented Sep 2, 2023

@fchollet

I have refactored pad in the torch backend to accommodate the restrictions of torch.nn.functional.pad:
https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html

> Replicate and reflection padding are implemented for padding the last 3 dimensions of a 4D or 5D input tensor, the last 2 dimensions of a 3D or 4D input tensor, or the last dimension of a 2D or 3D input tensor.

In the example below, reflect padding fails when pad_width lists all 5 dimensions, even though only the last 3 are actually padded. However, it works if we remove the redundant zeros.

>>> x = torch.ones((2, 3, 4, 5, 6))
>>> torch.nn.functional.pad(x, [2, 3, 1, 1, 1, 1, 0, 0, 0, 0], mode="reflect").shape
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NotImplementedError: Only 2D, 3D, 4D, 5D padding with non-constant padding are supported for now
>>> torch.nn.functional.pad(x, [2, 3, 1, 1, 1, 1], mode="reflect").shape
torch.Size([2, 3, 6, 7, 11])

I have also updated the unit test to improve the coverage for various shapes, dtypes and modes.
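
Putting the two adjustments together, here is a standalone sketch (not the PR's exact code; the pad_reflect helper name is hypothetical) of trimming leading zero pairs before a non-constant pad and casting unsupported dtypes to float32 and back:

import torch

def pad_reflect(x, pad_width):
    # pad_width is NumPy-style: ((before, after), ...) per dimension.
    # Non-constant modes in torch.nn.functional.pad only accept entries
    # for the trailing dimensions, so drop leading all-zero pairs.
    pad_width = [tuple(p) for p in pad_width]
    while pad_width and pad_width[0] == (0, 0):
        pad_width = pad_width[1:]
    # torch expects a flat list ordered from the last dimension inward.
    flat = []
    for before, after in reversed(pad_width):
        flat.extend([before, after])
    original_dtype = x.dtype
    needs_cast = original_dtype not in (torch.float32, torch.float64)
    if needs_cast:
        x = x.to(torch.float32)  # e.g. float16 and int dtypes are rejected by reflect
    out = torch.nn.functional.pad(x, flat, mode="reflect")
    return out.to(original_dtype) if needs_cast else out

x = torch.ones((2, 3, 4, 5, 6), dtype=torch.int32)
print(pad_reflect(x, ((0, 0), (0, 0), (1, 1), (1, 1), (2, 3))).shape)
# torch.Size([2, 3, 6, 7, 11]), with the original int32 dtype restored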

@james77777778 changed the title from "Improve the flexibility in dtype check for torch.Tensor.dtype" to "Improve the flexibility of standardize_dtype and fix pad in torch backend" on Sep 2, 2023
@fchollet (Member) left a comment

LGTM, thank you!

@fchollet merged commit de510e9 into keras-team:main Sep 2, 2023
6 checks passed
@james77777778 deleted the improve-flexibility-in-dtype branch September 4, 2023 01:10