Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[callback] Fix io_callback for callbacks that return Python literals. #20433

Merged
merged 1 commit into from
Mar 26, 2024

Conversation

copybara-service[bot]
Copy link

@copybara-service copybara-service bot commented Mar 26, 2024

[callback] Fix io_callback for callbacks that return Python literals.

The internal implementation of io_callback and friends currently use .shape and .dtype on the result of the callback. This fails if the callback returns a Python literal.

Fixed the checks that the callback returns values of expected shape and dtype,
and added tests.

Reverts 19e6156

@copybara-service copybara-service bot force-pushed the test_619100631 branch 4 times, most recently from e07671d to 5187651 Compare March 26, 2024 12:18
The internal implementation of io_callback and friends currently use .shape and .dtype on the result of the callback. This fails if the callback returns a Python literal.

Fixed the checks that the callback returns values of expected shape and dtype,
and added tests.

Reverts 19e6156

PiperOrigin-RevId: 619156176
@copybara-service copybara-service bot merged commit 75db481 into main Mar 26, 2024
@copybara-service copybara-service bot deleted the test_619100631 branch March 26, 2024 12:32
gnecula added a commit to gnecula/jax that referenced this pull request Apr 2, 2024
… mode

Previously, prior to jax-ml#20433, if the Python callback returned a Python literal
(which is natively a 64-bit value), and the `result_shape_dtypes` specified
a 32-bit expected returned value, we would just get garbage results. In jax-ml#20433, I introduced
an error in this situation. However, when trying to port the internal code that
uses host_callback to `io_callback`, I am getting many instances of this error.
The common scenario is a Python callback function that returns a Python scalar:

```
def f_host():
  return 42.

io_callback(f_host, jax.ShapeDtypeStruct((), np.float32))
```

However, if the `f_host` were called directly JAX would canonicalize
the value `42.` to a float32 (when `jax_enable_x64` is not set). I do not
think that it makes sense for `io_callback` to have stricter behaviour
that a direct call.

In this PR we add a canonicalization step on the returned values of
Python callbacks, which would cast the values to 32-bits.

In some sense this is replacing the change in  jax-ml#20433 to add a canonicalization
step instead of an error.
gnecula added a commit to gnecula/jax that referenced this pull request Apr 2, 2024
… mode

Previously, prior to jax-ml#20433, if the Python callback returned a Python literal
(which is natively a 64-bit value), and the `result_shape_dtypes` specified
a 32-bit expected returned value, we would just get garbage results. In jax-ml#20433, I introduced
an error in this situation. However, when trying to port the internal code that
uses host_callback to `io_callback`, I am getting many instances of this error.
The common scenario is a Python callback function that returns a Python scalar:

```
def f_host():
  return 42.

io_callback(f_host, jax.ShapeDtypeStruct((), np.float32))
```

However, if the `f_host` were called directly JAX would canonicalize
the value `42.` to a float32 (when `jax_enable_x64` is not set). I do not
think that it makes sense for `io_callback` to have stricter behaviour
that a direct call.

In this PR we add a canonicalization step on the returned values of
Python callbacks, which would cast the values to 32-bits.

In some sense this is replacing the change in  jax-ml#20433 to add a canonicalization
step instead of an error.
gnecula added a commit to gnecula/jax that referenced this pull request Apr 3, 2024
… mode

Previously, prior to jax-ml#20433, if the Python callback returned a Python literal
(which is natively a 64-bit value), and the `result_shape_dtypes` specified
a 32-bit expected returned value, we would just get garbage results. In jax-ml#20433, I introduced
an error in this situation. However, when trying to port the internal code that
uses host_callback to `io_callback`, I am getting many instances of this error.
The common scenario is a Python callback function that returns a Python scalar:

```
def f_host():
  return 42.

io_callback(f_host, jax.ShapeDtypeStruct((), np.float32))
```

However, if the `f_host` were called directly JAX would canonicalize
the value `42.` to a float32 (when `jax_enable_x64` is not set). I do not
think that it makes sense for `io_callback` to have stricter behaviour
that a direct call.

In this PR we add a canonicalization step on the returned values of
Python callbacks, which would cast the values to 32-bits.

In some sense this is replacing the change in  jax-ml#20433 to add a canonicalization
step instead of an error.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant