Skip to content

Commit

Permalink
Sort xla_flags so that the order of flags (whether added in ENV or co…
Browse files Browse the repository at this point in the history
…mmand line) does not affect the generated cache key.

PiperOrigin-RevId: 609773236
  • Loading branch information
Jieying Luo authored and jax authors committed Feb 23, 2024
1 parent f5c0021 commit bb5997b
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion jax/_src/cache_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def _hash_xla_flags(hash_obj, extra_flag_prefixes: list[str]):

# N.B. all XLA flags that take an argument must use '=' and not a space
# (e.g. --xla_force_host_platform_device_count=8) (I think).
for flag in xla_flags:
for flag in sorted(xla_flags):
if flag.split("=")[0] in xla_flags_to_exclude_from_cache_key:
logger.debug("Not including XLA flag in cache key: %s", flag)
continue
Expand Down

0 comments on commit bb5997b

Please sign in to comment.