Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use
util.cache
instead of lru_cache
for create_mesh_pspec_sharding
Its return value depends on `jax.config.enable_memories` due to the memory kind canonicalization, so we should use `util.cache` that uses the trace_context as an additional key. PiperOrigin-RevId: 634192701
- Loading branch information