[export] Removed __gpu$xla.gpu.triton (Pallas GPU) from the list of custom calls with guaranteed compatibility.

This is because the underlying Triton IR does not guarantee compatibility.

PiperOrigin-RevId: 703127711
gnecula authored and Google-ML-Automation committed Dec 5, 2024
1 parent 4a41aa0 commit 3f5f3e1
Showing 4 changed files with 14 additions and 3 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -70,6 +70,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
     return NaN for negative integer inputs, to match the behavior of SciPy from
     https://github.com/scipy/scipy/pull/21827.
   * `jax.clear_backends` was removed after being deprecated in v0.4.26.
+  * We removed the custom call "__gpu$xla.gpu.triton" from the list of custom
+    calls for which we guarantee export stability, because this custom call
+    relies on Triton IR, which is not guaranteed to be stable. If you need
+    to export code that uses this custom call, you can use the `disabled_checks`
+    parameter. See the [documentation](https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees-for-custom-calls) for details.
 
 * New Features
   * {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for
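
Note: a minimal sketch of the opt-out described in the changelog entry above (the `add_one` kernel, input shape, and CUDA-only platform list are illustrative; `export.export`, `export.DisabledSafetyCheck.custom_call`, and `pl.pallas_call` are the APIs the entry points at). The export_pallas_test.py diff below exercises the same parameter:

import jax
import jax.numpy as jnp
from jax import export
from jax.experimental import pallas as pl

# Illustrative kernel: a Pallas GPU kernel lowers to the
# "__gpu$xla.gpu.triton" custom call and is affected the same way.
def add_one(x_ref, o_ref):
  o_ref[...] = x_ref[...] + 1.0

@jax.jit
def func(x):
  return pl.pallas_call(
      add_one, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))(x)

x = jnp.arange(8, dtype=jnp.float32)
exp = export.export(
    func,
    platforms=["cuda"],
    # Acknowledge that the serialized Triton IR carries no
    # compatibility guarantee across JAX/XLA releases.
    disabled_checks=[
        export.DisabledSafetyCheck.custom_call("__gpu$xla.gpu.triton")
    ],
)(x)
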
3 changes: 2 additions & 1 deletion jax/_src/export/_export.py
@@ -1005,7 +1005,8 @@ def _check_lowering(lowering) -> None:
     *_CPU_FFI_KERNELS,
     "Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape",
     "cu_threefry2x32", "cu_threefry2x32_ffi",
-    "__gpu$xla.gpu.triton", # Pallas call on GPU
+    # Triton IR does not guarantee stability.
+    # "__gpu$xla.gpu.triton",
     # cholesky on CPU
     "lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf",
     # eigh on TPU
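
In effect, the export-time check now treats "__gpu$xla.gpu.triton" like any other non-allowlisted custom call target. A simplified paraphrase of that check (illustrative names and message, not the verbatim _check_lowering body):

def check_custom_call_targets(targets: set[str],
                              allowlist: set[str],
                              disabled: set[str]) -> None:
  # Simplified paraphrase: every custom call in the lowered module must
  # either be on the guaranteed-stable allowlist or be explicitly opted
  # out of the safety check by the caller via disabled_checks.
  for target in targets:
    if target in allowlist or target in disabled:
      continue
    raise ValueError(
        "Cannot serialize code with custom calls whose targets have no "
        f"compatibility guarantees: {target}")
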
5 changes: 3 additions & 2 deletions tests/pallas/export_back_compat_pallas_test.py
@@ -48,9 +48,10 @@ def setUp(self):
       self.skipTest("Only works on GPUs with capability >= sm80")
     super().setUp()
 
-  @unittest.skip("TODO(necula): This test is checking backwards compatibility "
+  @unittest.skip("This test is checking backwards compatibility "
                  "of Triton IR, but Triton doesn't promise backwards "
-                 "compatibility for its IR.")
+                 "compatibility for its IR, and we have since removed "
+                 "the corresponding custom call from the guaranteed stable list.")
   def test_triton_add_one(self):
     def func(x):
       def add_one(x_ref, o_ref):
4 changes: 4 additions & 0 deletions tests/pallas/export_pallas_test.py
@@ -50,6 +50,10 @@ def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
     exp = export.export(
         add_vectors,
         platforms=["tpu", "cuda"],
+        # The Pallas GPU custom call is not enabled for export by default.
+        disabled_checks=[
+            export.DisabledSafetyCheck.custom_call("__gpu$xla.gpu.triton")
+        ]
     )(a, a)
 
     if (jtu.device_under_test() == "tpu" or
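
For context, the stability guarantee (or the explicit opt-out above) matters once the exported artifact is serialized and reloaded later, possibly under a newer JAX/XLA. A short sketch of that round trip, reusing exp and the input a from the test above (serialize, deserialize, and call are the documented jax.export API):

# Serialize the exported artifact, reload it, and invoke it again.
serialized = exp.serialize()               # -> bytearray
restored = export.deserialize(serialized)  # -> export.Exported
result = restored.call(a, a)               # same semantics as add_vectors(a, a)
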
