From 3f5f3e1c47c230cc5d44841b08c3db9598442d13 Mon Sep 17 00:00:00 2001
From: George Necula
Date: Thu, 5 Dec 2024 08:39:48 -0800
Subject: [PATCH] [export] Removed __gpu$xla.gpu.triton (Pallas GPU) from the
 list of custom calls with guaranteed compatibility.

This is because the underlying Triton IR does not guarantee compatibility.

PiperOrigin-RevId: 703127711
---
 CHANGELOG.md                                   | 5 +++++
 jax/_src/export/_export.py                     | 3 ++-
 tests/pallas/export_back_compat_pallas_test.py | 5 +++--
 tests/pallas/export_pallas_test.py             | 4 ++++
 4 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 258fad49b5f4..b6d0f97f439d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -70,6 +70,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
     return NaN for negative integer inputs, to match the behavior of SciPy from
     https://github.com/scipy/scipy/pull/21827.
   * `jax.clear_backends` was removed after being deprecated in v0.4.26.
+  * We removed the custom call "__gpu$xla.gpu.triton" from the list of custom
+    calls for which we guarantee export stability. This is because this custom
+    call relies on Triton IR, which is not guaranteed to be stable. If you need
+    to export code that uses this custom call, you can use the `disabled_checks`
+    parameter. See more details in the [documentation](https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees-for-custom-calls).
 
 * New Features
   * {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for
diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py
index ad2c7fdac2dc..e3508639fe15 100644
--- a/jax/_src/export/_export.py
+++ b/jax/_src/export/_export.py
@@ -1005,7 +1005,8 @@ def _check_lowering(lowering) -> None:
       *_CPU_FFI_KERNELS,
       "Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape",
       "cu_threefry2x32", "cu_threefry2x32_ffi",
-      "__gpu$xla.gpu.triton",  # Pallas call on GPU
+      # Triton IR does not guarantee stability.
+      # "__gpu$xla.gpu.triton",
       # cholesky on CPU
       "lapack_spotrf", "lapack_dpotrf", "lapack_cpotrf", "lapack_zpotrf",
       # eigh on TPU
diff --git a/tests/pallas/export_back_compat_pallas_test.py b/tests/pallas/export_back_compat_pallas_test.py
index 1b810bcb6f26..462597e567f2 100644
--- a/tests/pallas/export_back_compat_pallas_test.py
+++ b/tests/pallas/export_back_compat_pallas_test.py
@@ -48,9 +48,10 @@ def setUp(self):
       self.skipTest("Only works on GPUs with capability >= sm80")
     super().setUp()
 
-  @unittest.skip("TODO(necula): This test is checking backwards compatibility "
-                 "of Triton IR, but Triton doesn't promise backwards "
-                 "compatibility for its IR.")
+  @unittest.skip("This test is checking backwards compatibility "
+                 "of Triton IR, but Triton doesn't promise backwards "
+                 "compatibility for its IR, and we have since removed "
+                 "the corresponding custom call from the guaranteed stable list.")
   def test_triton_add_one(self):
     def func(x):
       def add_one(x_ref, o_ref):
diff --git a/tests/pallas/export_pallas_test.py b/tests/pallas/export_pallas_test.py
index 70e40e1f2801..8b18f706a1d0 100644
--- a/tests/pallas/export_pallas_test.py
+++ b/tests/pallas/export_pallas_test.py
@@ -50,6 +50,10 @@ def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
     exp = export.export(
         add_vectors,
         platforms=["tpu", "cuda"],
+        # The Pallas GPU custom call is not enabled for export by default.
+        disabled_checks=[
+            export.DisabledSafetyCheck.custom_call("__gpu$xla.gpu.triton")
+        ]
     )(a, a)
 
     if (jtu.device_under_test() == "tpu" or