Skip to content

Commit

Permalink
Use util.cache instead of lru_cache for create_mesh_pspec_sharding
Browse files Browse the repository at this point in the history
Its return value depends on `jax.config.enable_memories` due to the memory kind canonicalization, so we should use `util.cache` that uses the trace_context as an additional key.

PiperOrigin-RevId: 634192701
  • Loading branch information
junwhanahn authored and jax authors committed May 16, 2024
1 parent 6fe313c commit cd41b4f
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -3135,7 +3135,7 @@ def _get_metadata_jit_pmap(local_devices, num_in_shardings, num_out_shardings):
return in_shardings, out_shardings, committed, tuple(local_devices)


@lru_cache
@util.cache()
def create_mesh_pspec_sharding(
mesh: Mesh, pspec: PartitionSpec | None, parsed_pspec=None,
memory_kind: str | None = None) -> sharding_impls.NamedSharding:
Expand Down

0 comments on commit cd41b4f

Please sign in to comment.