Improve the flexibility of standardize_dtype and fix pad in torch backend #828
Conversation
Thanks for the PR
@@ -408,7 +408,7 @@ def standardize_dtype(dtype):
         dtype = "int32"
     if hasattr(dtype, "name"):
         dtype = dtype.name
-    elif config.backend() == "torch":
+    if hasattr(dtype, "__str__") and "torch" in str(dtype):
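(A quick illustration, not part of the diff, of why the string-based check identifies torch dtypes; it assumes torch is importable even when another backend is active:)
>>> import torch
>>> str(torch.float32)
'torch.float32'
>>> hasattr(torch.float32, "__str__") and "torch" in str(torch.float32)
True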
You can use elif for better performance (and above too)
Done.
Hi @fchollet This change caught subtle bugs in digitize:

It is surprising that the dtype returned by np.digitize compares equal to "int":
>>> import numpy as np
>>> x = np.array([0.0, 1.0, 3.0, 1.6])
>>> bins = np.array([0.0, 3.0, 4.5, 7.0])
>>> np.digitize(x, bins).dtype
dtype('int64')
>>> np.digitize(x, bins).dtype == "int"
True

In the torch backend:
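A hedged stand-in for the torch side (assuming torch.bucketize as the counterpart of np.digitize; not necessarily the snippet originally posted):
>>> import torch
>>> x = torch.tensor([0.0, 1.0, 3.0, 1.6])
>>> bins = torch.tensor([0.0, 3.0, 4.5, 7.0])
>>> torch.bucketize(x, bins).dtype
torch.int64
>>> str(torch.bucketize(x, bins).dtype)
'torch.int64'
Here the dtype is a torch.dtype object rather than a numpy dtype, which is why standardize_dtype needs the string-based branch shown above.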
keras_core/backend/torch/numpy.py
@@ -676,7 +677,8 @@ def pad(x, pad_width, mode="constant"):
         mode = "replicate"
     if mode != "constant" and x.ndim < 3:
         new_dims = [1] * (3 - x.ndim)
-        x = cast(x, torch.float32) if x.dtype == torch.int else x
+        if x.dtype not in (torch.float32, torch.float64):
What about float16?
I have tried float16 with mode="reflect" and it is not supported by torch:
>>> x = torch.randn(3, 10, 10, dtype=torch.float16)
>>> torch.nn.functional.pad(x, [1, 1], mode="reflect")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: "reflection_pad1d" not implemented for 'Half'
I believe the solution in this PR should be the same as the official one (torchvision)
@fchollet
Should we restore the original dtype after the op? The current code casts to float32 for non-constant padding modes.
> Should we restore the original dtype after the op?
In fact, yes -- that is the behavior that other backends follow. Does the unit test only check float32?
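A minimal sketch of that cast-and-restore behavior (a hypothetical helper named pad_and_restore_dtype, not the PR's actual implementation; it assumes torch.nn.functional.pad and that the non-constant padding kernels accept float32/float64):

import torch
import torch.nn.functional as F

def pad_and_restore_dtype(x, pad, mode="reflect"):
    # Hypothetical helper: cast unsupported dtypes (e.g. int64, float16) up to
    # float32 for non-constant padding, then cast the result back afterwards.
    original_dtype = x.dtype
    needs_cast = mode != "constant" and original_dtype not in (
        torch.float32,
        torch.float64,
    )
    if needs_cast:
        x = x.to(torch.float32)
    out = F.pad(x, pad, mode=mode)
    return out.to(original_dtype) if needs_cast else out

# Example: reflect-pad an int64 tensor without losing its dtype.
x = torch.arange(10, dtype=torch.int64).reshape(1, 1, 10)
print(pad_and_restore_dtype(x, [2, 2]).dtype)  # torch.int64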
> In fact, yes -- that is the behavior that other backends follow. Does the unit test only check float32?

Actually, the unit test only checked int64:
keras-core/keras_core/ops/numpy_test.py, lines 2936 to 2945 in fa547ec:

    def test_pad(self):
        x = np.array([[1, 2], [3, 4]])
        self.assertAllClose(
            knp.pad(x, ((1, 1), (1, 1))),
            np.pad(x, ((1, 1), (1, 1))),
        )
        self.assertAllClose(
            knp.pad(x, ((1, 1), (1, 1))),
            np.pad(x, ((1, 1), (1, 1))),
        )
Please see the new comment below
I have refactored pad in the torch backend.

In the example below, we can find that torch.nn.functional.pad with a non-constant mode only supports padding a limited number of trailing dimensions: padding all five dimensions of a 5D tensor fails, while padding the last three succeeds.
>>> x = torch.ones((2, 3, 4, 5, 6))
>>> torch.nn.functional.pad(x, [2, 3, 1, 1, 1, 1, 0, 0, 0, 0], mode="reflect").shape
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NotImplementedError: Only 2D, 3D, 4D, 5D padding with non-constant padding are supported for now
>>> torch.nn.functional.pad(x, [2, 3, 1, 1, 1, 1], mode="reflect").shape
torch.Size([2, 3, 6, 7, 11])
I have also updated the unit test to improve the coverage for various shapes, dtypes and modes.
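A rough sketch of what such coverage could look like (a hypothetical test, not the one added in the PR; it assumes the knp alias and self.assertAllClose from the test class excerpted above, and uses a pad width of 1 everywhere):

import numpy as np
from keras_core.ops import numpy as knp  # assumed import, matching the alias in the excerpt above

def test_pad_coverage(self):
    # Hypothetical sweep over shapes, dtypes and modes, comparing the backend's
    # pad against np.pad. Values are trivial (all ones); the sweep mainly
    # exercises dtype/mode support such as float16 and int32 with "reflect".
    for shape in [(3,), (2, 3), (3, 4)]:
        for dtype in ["int32", "float16", "float32"]:
            for mode in ["constant", "reflect", "symmetric"]:
                x = np.ones(shape, dtype=dtype)
                pad_width = ((1, 1),) * len(shape)
                self.assertAllClose(
                    knp.pad(x, pad_width, mode=mode),
                    np.pad(x, pad_width, mode=mode),
                )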
LGTM, thank you!
One of the advantages of Keras Core is that we can integrate the workflow with different backends. For example, we can train a TensorFlow model using a torch dataloader. However, operations containing standardize_dtype might fail when the dtype is torch.Tensor.dtype and the backend is NOT torch.

This PR has addressed the issue by implementing a better check for torch.Tensor.dtype. A unit test for this behavior has been included.
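As a small illustration of the scenario described (an assumed setup, not code from the PR): a batch produced by a torch DataLoader carries a torch.dtype, which standardize_dtype has to normalize even when torch is not the active backend.

import torch

# A torch DataLoader yields torch.Tensors, so dtypes arriving from the data
# pipeline are torch.dtype objects rather than strings or numpy dtypes.
batch = torch.ones((8, 4))
dtype = batch.dtype

print(dtype)                      # torch.float32
print(str(dtype).split(".")[-1])  # "float32" -- the standardized dtype name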