Skip to content

Commit

Permalink
Add option to run tests with persistent compilation cache enabled.
Browse files Browse the repository at this point in the history
This can help us get a lot more coverage of the compilation cache, since all compiles will trigger it, instead of having to write explicit compilation cache tests.

PiperOrigin-RevId: 507898535
  • Loading branch information
skye authored and jax authors committed Feb 7, 2023
1 parent 6860cb8 commit eb13c05
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 5 deletions.
36 changes: 33 additions & 3 deletions jax/_src/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager
from contextlib import contextmanager, ExitStack
import inspect
import io
import functools
from functools import partial
import re
import os
import tempfile
import textwrap
from typing import Callable, List, Generator, Sequence, Tuple, Union
from typing import Callable, List, Generator, Optional, Sequence, Tuple, Union
import unittest
import warnings
import zlib
Expand All @@ -33,14 +34,17 @@

import jax
from jax import lax
from jax.experimental.compilation_cache import compilation_cache
from jax.interpreters import mlir
from jax.tree_util import tree_map, tree_all, tree_flatten, tree_unflatten
from jax._src import api
from jax._src import pjit as pjit_lib
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes as _dtypes
from jax._src.config import flags, bool_env, config
from jax._src.config import (flags, bool_env, config,
raise_persistent_cache_errors,
persistent_cache_min_compile_time_secs)
from jax._src.numpy.lax_numpy import _promote_dtypes, _promote_dtypes_inexact
from jax._src.util import prod, unzip2
from jax._src.lib import xla_bridge
Expand Down Expand Up @@ -90,6 +94,12 @@
'on the test name. If empty or unspecified, run all tests.'
)

flags.DEFINE_bool(
'jax_test_with_persistent_compilation_cache',
bool_env('JAX_TEST_WITH_PERSISTENT_COMPILATION_CACHE', False),
help='If enabled, the persistent compilation cache will be enabled for all '
'test cases. This can be used to increase compilation cache coverage.')

def num_float_bits(dtype):
return _dtypes.finfo(_dtypes.canonicalize_dtype(dtype)).bits

Expand Down Expand Up @@ -822,6 +832,8 @@ class JaxTestCase(parameterized.TestCase):
'jax_traceback_filtering': 'off',
}

_compilation_cache_exit_stack: Optional[ExitStack] = None

# TODO(mattjj): this obscures the error messages from failures, figure out how
# to re-enable it
# def tearDown(self) -> None:
Expand All @@ -844,6 +856,24 @@ def tearDown(self):
config.update(key, value)
super().tearDown()

@classmethod
def setUpClass(cls):
if FLAGS.jax_test_with_persistent_compilation_cache:
cls._compilation_cache_exit_stack = ExitStack()
stack = cls._compilation_cache_exit_stack
stack.enter_context(raise_persistent_cache_errors(True))
stack.enter_context(persistent_cache_min_compile_time_secs(0))

tmp_dir = stack.enter_context(tempfile.TemporaryDirectory())
compilation_cache.initialize_cache(tmp_dir)
stack.callback(lambda: compilation_cache.reset_cache()
if compilation_cache.is_initialized() else None)

@classmethod
def tearDownClass(cls):
if FLAGS.jax_test_with_persistent_compilation_cache:
cls._compilation_cache_exit_stack.close()

def rng(self):
return self._rng

Expand Down
6 changes: 6 additions & 0 deletions jax/experimental/compilation_cache/compilation_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,9 @@ def _hash_int_list(hash_obj, int_list):

def is_initialized():
return _cache is not None

def reset_cache():
global _cache
assert is_initialized()
logger.info("Resetting cache at %s.", _cache._path)
_cache = None
9 changes: 7 additions & 2 deletions tests/compilation_cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,14 @@ def setUp(self):
raise SkipTest("serialize executable only works on " +
",".join(supported_platforms))

# Reset cache if already initialized by JaxTestCase
if cc.is_initialized():
cc.reset_cache()

def tearDown(self):
super().tearDown()
cc._cache = None
if cc.is_initialized():
cc.reset_cache()
super().tearDown()

def test_compile_options(self):
compile_options_not_filled = xla_bridge.get_compile_options(
Expand Down

0 comments on commit eb13c05

Please sign in to comment.