Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[callback] Allow external callbacks to return 64-bit values in 32-bit mode #20534

Merged
merged 1 commit into from
Apr 3, 2024

Conversation

gnecula
Copy link
Collaborator

@gnecula gnecula commented Apr 2, 2024

Previously, prior to #20433, if the Python callback returned a Python literal (which is natively a 64-bit value), and the result_shape_dtypes specified a 32-bit expected returned value, we would just get garbage results. In #20433, I introduced an error in this situation. However, when trying to port the internal code that uses host_callback to io_callback, I am getting many instances of this error. The common scenario is a Python callback function that returns a Python scalar:

def f_host():
  return 42.

io_callback(f_host, jax.ShapeDtypeStruct((), np.float32))

However, if the f_host were called directly JAX would canonicalize the value 42. to a float32 (when jax_enable_x64 is not set). I do not think that it makes sense for io_callback to have stricter behavior that a direct call.

In this PR we add a canonicalization step on the returned values of Python callbacks, which casts the values to 32-bits if JAX is running in 32-bit mode.

Note that the above example should return an error in 64-bit mode, because the actual returned value is a 64-bit value but the declared expected value is np.float32. To avoid the error in both 64-bit and 32-bit mode, the python callback should return np.float32(42.).

In some sense this is replacing the change in #20433 to add a canonicalization step instead of an error.

@gnecula gnecula self-assigned this Apr 2, 2024
@gnecula gnecula requested review from sharadmv and superbobry and removed request for superbobry and sharadmv April 2, 2024 14:12
@gnecula gnecula requested review from superbobry and sharadmv April 2, 2024 14:23
@@ -225,7 +225,7 @@ def _check_shape_dtype(shape_dtype):
dt = np.dtype(shape_dtype.dtype)
if dtypes.canonicalize_dtype(dt) != dt:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we check that jax_enable_x64 is disabled as well or does it follow from canonicalize_dtype(dt) != dt?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are using the same pattern elsewhere in the code, to check that x64 is not enabler and the type is a 64-bit type.

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Apr 3, 2024
… mode

Previously, prior to jax-ml#20433, if the Python callback returned a Python literal
(which is natively a 64-bit value), and the `result_shape_dtypes` specified
a 32-bit expected returned value, we would just get garbage results. In jax-ml#20433, I introduced
an error in this situation. However, when trying to port the internal code that
uses host_callback to `io_callback`, I am getting many instances of this error.
The common scenario is a Python callback function that returns a Python scalar:

```
def f_host():
  return 42.

io_callback(f_host, jax.ShapeDtypeStruct((), np.float32))
```

However, if the `f_host` were called directly JAX would canonicalize
the value `42.` to a float32 (when `jax_enable_x64` is not set). I do not
think that it makes sense for `io_callback` to have stricter behaviour
that a direct call.

In this PR we add a canonicalization step on the returned values of
Python callbacks, which would cast the values to 32-bits.

In some sense this is replacing the change in  jax-ml#20433 to add a canonicalization
step instead of an error.
@copybara-service copybara-service bot merged commit d89f0d6 into jax-ml:main Apr 3, 2024
13 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants