From eb13c053e9de9ce3338db31ca5edbce3cc47c014 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Tue, 7 Feb 2023 15:14:53 -0800 Subject: [PATCH] Add option to run tests with persistent compilation cache enabled. This can help us get a lot more coverage of the compilation cache, since all compiles will trigger it, instead of having to write explicit compilation cache tests. PiperOrigin-RevId: 507898535 --- jax/_src/test_util.py | 36 +++++++++++++++++-- .../compilation_cache/compilation_cache.py | 6 ++++ tests/compilation_cache_test.py | 9 +++-- 3 files changed, 46 insertions(+), 5 deletions(-) diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 445ffff9b48b..447539c8ed4a 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -12,15 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import contextmanager +from contextlib import contextmanager, ExitStack import inspect import io import functools from functools import partial import re import os +import tempfile import textwrap -from typing import Callable, List, Generator, Sequence, Tuple, Union +from typing import Callable, List, Generator, Optional, Sequence, Tuple, Union import unittest import warnings import zlib @@ -33,6 +34,7 @@ import jax from jax import lax +from jax.experimental.compilation_cache import compilation_cache from jax.interpreters import mlir from jax.tree_util import tree_map, tree_all, tree_flatten, tree_unflatten from jax._src import api @@ -40,7 +42,9 @@ from jax._src import core from jax._src import dispatch from jax._src import dtypes as _dtypes -from jax._src.config import flags, bool_env, config +from jax._src.config import (flags, bool_env, config, + raise_persistent_cache_errors, + persistent_cache_min_compile_time_secs) from jax._src.numpy.lax_numpy import _promote_dtypes, _promote_dtypes_inexact from jax._src.util import prod, unzip2 from jax._src.lib import xla_bridge @@ -90,6 +94,12 @@ 'on the test name. If empty or unspecified, run all tests.' ) +flags.DEFINE_bool( + 'jax_test_with_persistent_compilation_cache', + bool_env('JAX_TEST_WITH_PERSISTENT_COMPILATION_CACHE', False), + help='If enabled, the persistent compilation cache will be enabled for all ' + 'test cases. This can be used to increase compilation cache coverage.') + def num_float_bits(dtype): return _dtypes.finfo(_dtypes.canonicalize_dtype(dtype)).bits @@ -822,6 +832,8 @@ class JaxTestCase(parameterized.TestCase): 'jax_traceback_filtering': 'off', } + _compilation_cache_exit_stack: Optional[ExitStack] = None + # TODO(mattjj): this obscures the error messages from failures, figure out how # to re-enable it # def tearDown(self) -> None: @@ -844,6 +856,24 @@ def tearDown(self): config.update(key, value) super().tearDown() + @classmethod + def setUpClass(cls): + if FLAGS.jax_test_with_persistent_compilation_cache: + cls._compilation_cache_exit_stack = ExitStack() + stack = cls._compilation_cache_exit_stack + stack.enter_context(raise_persistent_cache_errors(True)) + stack.enter_context(persistent_cache_min_compile_time_secs(0)) + + tmp_dir = stack.enter_context(tempfile.TemporaryDirectory()) + compilation_cache.initialize_cache(tmp_dir) + stack.callback(lambda: compilation_cache.reset_cache() + if compilation_cache.is_initialized() else None) + + @classmethod + def tearDownClass(cls): + if FLAGS.jax_test_with_persistent_compilation_cache: + cls._compilation_cache_exit_stack.close() + def rng(self): return self._rng diff --git a/jax/experimental/compilation_cache/compilation_cache.py b/jax/experimental/compilation_cache/compilation_cache.py index 5e82a2397c33..f9f30bb6dd00 100644 --- a/jax/experimental/compilation_cache/compilation_cache.py +++ b/jax/experimental/compilation_cache/compilation_cache.py @@ -254,3 +254,9 @@ def _hash_int_list(hash_obj, int_list): def is_initialized(): return _cache is not None + +def reset_cache(): + global _cache + assert is_initialized() + logger.info("Resetting cache at %s.", _cache._path) + _cache = None diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index b347d0cd0f74..5beba41e7633 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -61,9 +61,14 @@ def setUp(self): raise SkipTest("serialize executable only works on " + ",".join(supported_platforms)) + # Reset cache if already initialized by JaxTestCase + if cc.is_initialized(): + cc.reset_cache() + def tearDown(self): - super().tearDown() - cc._cache = None + if cc.is_initialized(): + cc.reset_cache() + super().tearDown() def test_compile_options(self): compile_options_not_filled = xla_bridge.get_compile_options(