From cd41b4fd06a0e685768a86159579c3cb8c555f21 Mon Sep 17 00:00:00 2001 From: Junwhan Ahn Date: Wed, 15 May 2024 21:13:59 -0700 Subject: [PATCH] Use `util.cache` instead of `lru_cache` for `create_mesh_pspec_sharding` Its return value depends on `jax.config.enable_memories` due to the memory kind canonicalization, so we should use `util.cache` that uses the trace_context as an additional key. PiperOrigin-RevId: 634192701 --- jax/_src/interpreters/pxla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 9c50908b615b..e72cd07a8323 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -3135,7 +3135,7 @@ def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings): return in_shardings, out_shardings, committed, tuple(local_devices) -@lru_cache +@util.cache() def create_mesh_pspec_sharding( mesh: Mesh, pspec: PartitionSpec | None, parsed_pspec=None, memory_kind: str | None = None) -> sharding_impls.NamedSharding: