From 2a85f8506ce5110a9c1ba69efa4dbfd9d2a12544 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 19 Mar 2021 13:49:38 -0700 Subject: [PATCH] unify configuration state handling --- jax/__init__.py | 3 +- jax/_src/lax/control_flow.py | 4 +- jax/_src/util.py | 3 +- jax/api.py | 6 +- jax/config.py | 182 ++++++++++++++++--- jax/core.py | 45 +---- jax/custom_derivatives.py | 2 +- jax/dtypes.py | 8 - jax/experimental/jax2tf/jax2tf.py | 14 +- jax/experimental/jax2tf/tests/jax2tf_test.py | 2 +- jax/experimental/x64_context.py | 37 +--- jax/interpreters/ad.py | 4 +- jax/interpreters/batching.py | 2 +- jax/interpreters/partial_eval.py | 12 +- jax/interpreters/pxla.py | 6 +- jax/interpreters/sharded_jit.py | 2 +- jax/interpreters/xla.py | 29 +-- jax/linear_util.py | 2 +- jax/test_util.py | 4 +- mypy.ini | 3 +- tests/api_test.py | 66 +++---- tests/debug_nans_test.py | 4 +- tests/djax_test.py | 2 +- tests/lax_autodiff_test.py | 3 +- tests/lax_control_flow_test.py | 1 - tests/lax_scipy_sparse_test.py | 2 +- tests/lax_test.py | 2 +- tests/nn_test.py | 4 +- tests/random_test.py | 2 +- tests/x64_context_test.py | 2 +- 30 files changed, 265 insertions(+), 193 deletions(-) diff --git a/jax/__init__.py b/jax/__init__.py index cb9644791257..1cf4eb412cd8 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -30,7 +30,8 @@ del _cloud_tpu_init # flake8: noqa: F401 -from .config import config +from .config import (config, enable_checks, check_tracer_leaks, checking_leaks, + debug_nans, debug_infs, log_compiles) from .api import ( ad, # TODO(phawkins): update users to avoid this. argnums_partial, # TODO(phawkins): update Haiku to not use this. diff --git a/jax/_src/lax/control_flow.py b/jax/_src/lax/control_flow.py index 4ec2ab8498ad..472663c4516f 100644 --- a/jax/_src/lax/control_flow.py +++ b/jax/_src/lax/control_flow.py @@ -1113,7 +1113,7 @@ def _cond_typecheck(*avals, branches, linear): f'called with operands of type {_avals_short(op_avals)}') def cond_bind(*args, branches, linear): - if not core.skip_checks: + if config.jax_enable_checks: avals = _map(core.get_aval, args) _cond_typecheck(*avals, branches=branches, linear=linear) for jaxpr in branches: @@ -1876,7 +1876,7 @@ def _scan_typecheck(bind_time, *avals, reverse, length, num_consts, num_carry, f'called with sequence of type\n{_avals_short(x_avals)}') def scan_bind(*args, **params): - if not core.skip_checks: + if config.jax_enable_checks: avals = _map(core.get_aval, args) _scan_typecheck(True, *avals, **params) core.check_jaxpr(params['jaxpr'].jaxpr) diff --git a/jax/_src/util.py b/jax/_src/util.py index 4467ee15dee3..3a681e03e824 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -21,7 +21,6 @@ import numpy as np -import jax from jax.config import config partial = functools.partial @@ -192,7 +191,7 @@ def cached(_, *args, **kwargs): @functools.wraps(f) def wrapper(*args, **kwargs): - if jax.core.debug_state.check_leaks: + if config.jax_check_tracer_leaks: return f(*args, **kwargs) else: return cached(bool(config.x64_enabled), *args, **kwargs) diff --git a/jax/api.py b/jax/api.py index 0beeb1fbcb00..6b6dfc465cef 100644 --- a/jax/api.py +++ b/jax/api.py @@ -41,7 +41,7 @@ from . import linear_util as lu from . import ad_util from . import dtypes -from .core import eval_jaxpr, checking_leaks +from .core import eval_jaxpr from .api_util import (flatten_fun, apply_flat_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial, argnums_partial_except, flatten_axes, donation_vector, @@ -362,7 +362,7 @@ def f_jitted(*args, **kwargs): context = (getattr(core.thread_local_state.trace_state.trace_stack, "dynamic", None), config.x64_enabled) # TODO(jblespiau): Move this to C++. - if (FLAGS.jax_debug_nans or FLAGS.jax_debug_infs) and not _jit_is_disabled(): + if (config.jax_debug_nans or config.jax_debug_infs) and not _jit_is_disabled(): device_arrays = cpp_jitted_f(context, *args, **kwargs) try: xla.check_special(xla.xla_call_p, [ @@ -372,7 +372,7 @@ def f_jitted(*args, **kwargs): ]) return device_arrays except FloatingPointError: - assert FLAGS.jax_debug_nans or FLAGS.jax_debug_infs # compiled_fun can only raise in this case + assert config.jax_debug_nans or config.jax_debug_infs # compiled_fun can only raise in this case print("Invalid nan value encountered in the output of a C++-jit " "function. Calling the de-optimized version.") return cache_miss(*args, **kwargs)[0] # probably won't return diff --git a/jax/config.py b/jax/config.py index a64acf34fbdc..352b96900fb8 100644 --- a/jax/config.py +++ b/jax/config.py @@ -12,8 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib +import functools import os import sys +import threading + from jax import lib def bool_env(varname: str, default: bool) -> bool: @@ -42,11 +46,16 @@ def int_env(varname: str, default: int) -> int: class Config: + _HAS_DYNAMIC_ATTRIBUTES = True + def __init__(self): self.values = {} self.meta = {} self.FLAGS = NameSpace(self.read) self.use_absl = False + self._contextmanager_flags = set() + + # TODO(mattjj): delete these when only omnistaging is available self.omnistaging_enabled = bool_env('JAX_OMNISTAGING', True) self._omnistaging_disablers = [] @@ -65,6 +74,13 @@ def update(self, name, val): lib.jax_jit.global_state().enable_x64 = val def read(self, name): + if name in self._contextmanager_flags: + raise AttributeError( + "For flags with a corresponding contextmanager, read their value " + f"via e.g. `config.{name}` rather than `config.FLAGS.{name}`.") + return self._read(name) + + def _read(self, name): if self.use_absl: return getattr(self.absl_flags.FLAGS, name) else: @@ -143,14 +159,82 @@ def disable_omnistaging(self): disabler() self.omnistaging_enabled = False - @property - def x64_enabled(self): - return lib.jax_jit.get_enable_x64() - - # TODO(jakevdp): make this public when thread-local x64 is fully implemented. - def _set_x64_enabled(self, state): - lib.jax_jit.thread_local_state().enable_x64 = bool(state) - +# # TODO(jakevdp, mattjj): unify this with `define_bool_state` stuff below +# @property +# def x64_enabled(self): +# return lib.jax_jit.get_enable_x64() + +# def _set_x64_enabled(self, state): +# lib.jax_jit.thread_local_state().enable_x64 = bool(state) + + def define_bool_state(self, name: str, default: bool, help: str): + """Set up thread-local state and return a contextmanager for managing it. + + This function is a convenience wrapper. It defines a flag and corresponding + thread-local state, which can be managed via the contextmanager it returns. + + The thread-local state value can be read via the ``config.`` + attribute, where ``config`` is the singleton ``Config`` instance. + + Args: + name: string, converted to lowercase to define the name of the config + option (and absl flag). It is converted to uppercase to define the + corresponding shell environment variable. + default: boolean, a default value for the option. + help: string, used to populate the flag help information as well as the + docstring of the returned context manager. + + Returns: + A contextmanager to control the thread-local state value. + + Example: + + enable_foo = config.define_bool_state( + name='jax_enable_foo', + default=False, + help='Enable foo.') + + # Now the JAX_ENABLE_FOO shell environment variable and --jax_enable_foo + # command-line flag can be used to control the process-level value of + # the configuration option, in addition to using e.g. + # ``config.update("jax_enable_foo", True)`` directly. We can also use a + # context manager: + + with enable_foo(True): + ... + + The value of the thread-local state or flag can be accessed via + ``config.jax_enable_foo``. Reading it via ``config.FLAGS.jax_enable_foo`` is + an error. + """ + name = name.lower() + self.DEFINE_bool(name, bool_env(name.upper(), default), help) + self._contextmanager_flags.add(name) + + def get_state(self): + val = getattr(_thread_local_state, name, unset) + return val if val is not unset else self._read(name) + setattr(Config, name, property(get_state)) + + @contextlib.contextmanager + def set_state(new_val: bool): + prev_val = getattr(_thread_local_state, name, unset) + setattr(_thread_local_state, name, new_val) + try: + yield + finally: + if prev_val is unset: + delattr(_thread_local_state, name) + else: + setattr(_thread_local_state, name, prev_val) + set_state.__name__ = name[4:] if name.startswith('jax_') else name + set_state.__doc__ = f"Context manager for `{name}` config option.\n\n{help}" + return set_state + +_thread_local_state = threading.local() + +class Unset: pass +unset = Unset() class NameSpace(object): def __init__(self, getter): @@ -166,11 +250,6 @@ def __getattr__(self, name): already_configured_with_absl = False -flags.DEFINE_bool( - 'jax_enable_checks', - bool_env('JAX_ENABLE_CHECKS', False), - help='Turn on invariant checking (core.skip_checks = False)' -) flags.DEFINE_bool( 'jax_omnistaging', @@ -184,14 +263,6 @@ def __getattr__(self, name): help='Set the number of stack frames in JAX tracer error messages.' ) -flags.DEFINE_bool( - 'jax_check_tracer_leaks', - bool_env('JAX_CHECK_TRACER_LEAKS', False), - help=('Turn on checking for leaked tracers as soon as a trace completes. ' - 'Enabling leak checking may have performance impacts: some caching ' - 'is disabled, and other overheads may be added.'), -) - flags.DEFINE_bool( 'jax_host_callback_inline', bool_env('JAX_HOST_CALLBACK_INLINE', False), @@ -206,3 +277,72 @@ def __getattr__(self, name): 'until the Python callback consume more outfeeds.'), lower_bound=int(16 * 1e6) ) + + +enable_checks = config.define_bool_state( + name='jax_enable_checks', + default=False, + help='Turn on invariant checking for JAX internals. Makes things slower.') + +check_tracer_leaks = config.define_bool_state( + name='jax_check_tracer_leaks', + default=False, + help=('Turn on checking for leaked tracers as soon as a trace completes. ' + 'Enabling leak checking may have performance impacts: some caching ' + 'is disabled, and other overheads may be added.')) +checking_leaks = functools.partial(check_tracer_leaks, True) + +debug_nans = config.define_bool_state( + name='jax_debug_nans', + default=False, + help=('Add nan checks to every operation. When a nan is detected on the ' + 'output of a jit-compiled computation, call into the un-compiled ' + 'version in an attempt to more precisely identify the operation ' + 'which produced the nan.')) + +debug_infs = config.define_bool_state( + name='jax_debug_infs', + default=False, + help=('Add inf checks to every operation. When an inf is detected on the ' + 'output of a jit-compiled computation, call into the un-compiled ' + 'version in an attempt to more precisely identify the operation ' + 'which produced the inf.')) + +log_compiles = config.define_bool_state( + name='jax_log_compiles', + default=False, + help=('Log a message each time every time `jit` or `pmap` compiles an XLA ' + 'computation. Logging is performed with `absl.logging`. When this ' + 'option is set, the log level is WARNING; otherwise the level is ' + 'DEBUG.')) + +# Because jax_enable_x64 is managed by C++ code, we don't reuse the +# config.define_bool_state mechanism, though conceptually it is the same. +config.DEFINE_bool('jax_enable_x64', bool_env('JAX_ENABLE_X64', False), + help='Enable 64-bit types to be used') +lib.jax_jit.global_state().enable_x64 = bool_env('JAX_ENABLE_X64', False) + +@contextlib.contextmanager +def enable_x64(new_val: bool = True): + """Experimental context manager to temporarily enable X64 mode. + + Usage:: + + >>> import jax.numpy as jnp + >>> with enable_x64(True): + ... print(jnp.arange(10.0).dtype) + ... + float64 + """ + prev_val = config.jax_enable_x64 + lib.jax_jit.thread_local_state().enable_x64 = bool(new_val) + try: + yield + finally: + lib.jax_jit.thread_local_state().enable_x64 = prev_val +Config.jax_enable_x64 = property(lambda self: lib.jax_jit.get_enable_x64()) +config._contextmanager_flags.add('jax_enable_x64') + +# The `x64_enabled` property doesn't fit the naming scheme, but we use it for +# backward compatibility. +Config.x64_enabled = Config.jax_enable_x64 diff --git a/jax/core.py b/jax/core.py index 99f17a0b681e..14e0d42e8403 100644 --- a/jax/core.py +++ b/jax/core.py @@ -45,33 +45,6 @@ from ._src import traceback_util traceback_util.register_exclusion(__file__) -# TODO(mattjj): move this into debug_state -skip_checks = not FLAGS.jax_enable_checks - -@contextmanager -def skipping_checks(): - """Context manager for temporarily disabling internal checks.""" - global skip_checks - old_value, skip_checks = skip_checks, True - try: - yield - finally: - skip_checks = old_value - -@contextmanager -def checking_leaks(): - """Context manager for temporarily enabling tracer leak checks.""" - old_value, debug_state.check_leaks = debug_state.check_leaks, True - try: - yield - finally: - debug_state.check_leaks = old_value - -class DebugState(threading.local): - def __init__(self): - self.check_leaks = FLAGS.jax_check_tracer_leaks -debug_state = DebugState() - zip = safe_zip map = safe_map @@ -279,8 +252,8 @@ def __repr__(self): def bind(self, *args, **params): - assert skip_checks or all(isinstance(arg, Tracer) - or valid_jaxtype(arg) for arg in args), args + assert (not config.jax_enable_checks or + all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args top_trace = find_top_trace(args) tracers = map(top_trace.full_raise, args) out = top_trace.process_primitive(self, tracers, params) @@ -569,7 +542,7 @@ def __array_module__(self, types): return self.aval._array_module(self, types) def __getattr__(self, name): # if the aval property raises an AttributeError, gets caught here - assert skip_checks or name != "aval" + assert not config.jax_enable_checks or name != "aval" try: attr = getattr(self.aval, name) @@ -783,7 +756,7 @@ def new_main(trace_type: Type[Trace], if lib._xla_extension_version >= 11: jit_tls.extra_jit_context = extra_jit_context(stack) - if debug_state.check_leaks: + if config.jax_check_tracer_leaks: t = ref(main) del main if t() is not None: @@ -807,7 +780,7 @@ def new_base_main(trace_type: Type[Trace]) -> Generator[MainTrace, None, None]: if lib._xla_extension_version >= 11: jit_tls.extra_jit_context = extra_jit_context(stack) - if debug_state.check_leaks: + if config.jax_check_tracer_leaks: t = ref(main) del main if t() is not None: @@ -827,7 +800,7 @@ def new_sublevel() -> Generator[None, None, None]: finally: thread_local_state.trace_state.substack.pop() - if debug_state.check_leaks: + if config.jax_check_tracer_leaks: t = ref(sublevel) del sublevel if t() is not None: @@ -899,7 +872,7 @@ class AbstractUnit(AbstractValue): # _num_buffers = 0 def at_least_vspace(self): return self def join(self, other): - if not skip_checks: + if config.jax_enable_checks: assert other is abstract_unit, other return self def _eq(self, self_traced, other): return get_aval(other) is self @@ -1932,7 +1905,7 @@ def new_main(trace_type: Type[Trace], bottom=False, **payload) -> Generator[Main finally: thread_local_state.trace_state.trace_stack.pop(bottom) - if debug_state.check_leaks: + if config.jax_check_tracer_leaks: t = ref(main) del main if t() is not None: @@ -1949,7 +1922,7 @@ def eval_context(): yield # dummy implementation for forward compatibility def bind(self, *args, **kwargs): - assert skip_checks or all(isinstance(arg, Tracer) + assert not config.jax_enable_checks or all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args), args top_trace = find_top_trace(args) if top_trace is None: diff --git a/jax/custom_derivatives.py b/jax/custom_derivatives.py index fe6c9a9d6390..955992293ed7 100644 --- a/jax/custom_derivatives.py +++ b/jax/custom_derivatives.py @@ -914,7 +914,7 @@ def rev(objective_fn, res, g): """ flat_args, in_tree = tree_flatten(example_args) in_avals = tuple(map(abstractify, flat_args)) - if core.debug_state.check_leaks: + if config.jax_check_tracer_leaks: return _closure_convert_for_avals.__wrapped__(fun, in_tree, in_avals) else: return _closure_convert_for_avals(fun, in_tree, in_avals) diff --git a/jax/dtypes.py b/jax/dtypes.py index ff93da1bfae8..71e28331eecb 100644 --- a/jax/dtypes.py +++ b/jax/dtypes.py @@ -20,27 +20,19 @@ # so we need our own implementation that deviates from NumPy in places. -from distutils.util import strtobool import functools -import os from typing import Dict import numpy as np from ._src import util from .config import flags, config -from . import lib from .lib import xla_client from ._src import traceback_util traceback_util.register_exclusion(__file__) FLAGS = flags.FLAGS -flags.DEFINE_bool('jax_enable_x64', - strtobool(os.getenv('JAX_ENABLE_X64', 'False')), - 'Enable 64-bit types to be used.') -lib.jax_jit.global_state().enable_x64 = strtobool( - os.getenv('JAX_ENABLE_X64', 'False')) # bfloat16 support bfloat16: type = xla_client.bfloat16 diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 0ccc9938f1f0..259b96b09ce3 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -70,9 +70,9 @@ def _is_tfval(v: TfVal) -> bool: if isinstance(v, (tf.Tensor, tf.Variable)): return True try: - # Note: this conversion is overkill and just intended as a type check; this code - # is in principle only run if core.skip_checks is False. - # TODO: it is not true that this code is run only without skip_checks + # Note: this conversion is overkill and just intended as a type check; this + # code is in principle only run if config.jax_enable_checks is True. + # TODO: it is not true that this code is run only with jax_enable_checks. _safe_convert_to_tensor(v) return True except ValueError: @@ -353,7 +353,7 @@ def _tfval_shape_dtype(val: TfVal) -> Tuple[Sequence[Optional[int]], DType]: # May be partially known return tuple(val.shape), to_jax_dtype(val.dtype) else: # Must be a numeric value - assert core.skip_checks or _is_tfval(val), f"Non TfVal: {val}" + assert not config.jax_enable_checks or _is_tfval(val), f"Non TfVal: {val}" raw_aval = xla.abstractify(val) return raw_aval.shape, raw_aval.dtype # type: ignore[attr-defined] @@ -605,7 +605,7 @@ def __init__(self, trace: 'TensorFlowTrace', val: TfVal, val = tf.cast(val, dtype=aval_dtype) val_dtype = aval_dtype - if not core.skip_checks: + if config.jax_enable_checks: assert aval_dtype == val_dtype, f"expected {aval_dtype} == {val_dtype}" for aval_dim, val_dim in util.safe_zip(self._aval.shape, val_shape): # type: ignore[attr-defined] if val_dim is None: @@ -703,7 +703,7 @@ def process_primitive(self, primitive: core.Primitive, # Check that the impl rule returned a value of expected shape and dtype # TODO: adapt this to match polymorphic shapes - if not core.skip_checks: + if config.jax_enable_checks: if primitive.multiple_results: for o, expected_aval in zip(out, out_aval): # type: ignore assert o.aval.strip_weak_type() == expected_aval.strip_weak_type(), ( @@ -1530,7 +1530,7 @@ def _common_reduce_window(operand, init_val, reducer, window_dimensions, reducer_fn = tf.function(reducer, autograph=False).get_concrete_function(o_spec, o_spec) if not isinstance(init_val, tf.Tensor): - assert core.skip_checks or _is_tfval(init_val), f"Non TfVal: {init_val}" + assert not config.jax_enable_checks or _is_tfval(init_val), f"Non TfVal: {init_val}" init_val = tf.constant(init_val, operand.dtype) out = tfxla.reduce_window(operand, init_val, reducer_fn, window_dimensions, diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index af2d76fbdc54..080119d89d19 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -135,7 +135,7 @@ def test_bfloat16_returned_by_jax(self): dtype=dtype) for dtype in [np.int64, np.float64])) def test_converts_64bit(self, dtype=np.int64, with_function=False): - if not config.FLAGS.jax_enable_x64: + if not config.jax_enable_x64: self.skipTest("requires x64 mode") big_const = np.full((5,), 2 ** 33, dtype=dtype) self.ConvertAndCompare(jnp.sin, big_const) diff --git a/jax/experimental/x64_context.py b/jax/experimental/x64_context.py index 089f3a3690dc..1938adf1afe1 100644 --- a/jax/experimental/x64_context.py +++ b/jax/experimental/x64_context.py @@ -17,31 +17,14 @@ **Experimental: please give feedback, and expect changes.** """ -from contextlib import contextmanager -from jax import config - -@contextmanager -def enable_x64(): - """Experimental context manager to temporarily enable X64 mode. - - Usage:: - - >>> import jax.numpy as jnp - >>> with enable_x64(): - ... print(jnp.arange(10.0).dtype) - ... - float64 +# This file provides +# 1. a jax.experimental API endpoint; +# 2. the `disable_x64` wrapper. +# TODO(jakevdp): remove this file, and consider removing `disable_x64` for +# uniformity - See Also - -------- - jax.experimental.disable_x64 : temporarily disable X64 mode. - """ - _x64_state = config.x64_enabled - config._set_x64_enabled(True) - try: - yield - finally: - config._set_x64_enabled(_x64_state) +from contextlib import contextmanager +from jax.config import enable_x64 @contextmanager def disable_x64(): @@ -59,9 +42,5 @@ def disable_x64(): -------- jax.experimental.enable_x64 : temporarily enable X64 mode. """ - _x64_state = config.x64_enabled - config._set_x64_enabled(False) - try: + with enable_x64(False): yield - finally: - config._set_x64_enabled(_x64_state) diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py index 2dbf30e00006..0f139b391365 100644 --- a/jax/interpreters/ad.py +++ b/jax/interpreters/ad.py @@ -174,7 +174,7 @@ def write_cotangent(prim, v, ct): # assert v.aval == ct.aval, (prim, v.aval, ct.aval) return ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct - if not core.skip_checks: + if config.jax_enable_checks: ct_aval = core.get_aval(ct_env[v]) joined_aval = core.lattice_join(v.aval, ct_aval).strip_weak_type() assert v.aval.strip_weak_type() == joined_aval, (prim, v.aval, ct_aval) @@ -389,7 +389,7 @@ class JVPTracer(Tracer): __slots__ = ['primal', 'tangent'] def __init__(self, trace, primal, tangent): - if not core.skip_checks: + if config.jax_enable_checks: _primal_tangent_shapes_match(primal, tangent) self._trace = trace self.primal = primal diff --git a/jax/interpreters/batching.py b/jax/interpreters/batching.py index c81edef7948f..5925067baa8f 100644 --- a/jax/interpreters/batching.py +++ b/jax/interpreters/batching.py @@ -108,7 +108,7 @@ class BatchTracer(Tracer): __slots__ = ['val', 'batch_dim'] def __init__(self, trace, val, batch_dim: Optional[int]): - assert core.skip_checks or type(batch_dim) in (int, NotMapped) # type: ignore + assert not config.jax_enable_checks or type(batch_dim) in (int, NotMapped) # type: ignore self._trace = trace self.val = val self.batch_dim = batch_dim diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 4aa73c6c4bb0..fb4bd9127307 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -49,7 +49,7 @@ class PartialVal(tuple): """ def __new__(cls, xs: Tuple[Optional[AbstractValue], core.Value]): pv, const = xs - if not core.skip_checks: + if config.jax_enable_checks: # type checks assert isinstance(pv, (AbstractValue, type(None))), xs assert isinstance(const, core.Tracer) or type(const) is Zero or core.valid_jaxtype(const), xs @@ -648,25 +648,25 @@ def getconstvar(c): const_vars, const_vals = unzip2(consts.items()) # The env_vars are pre-pended to the invars jaxpr = Jaxpr(const_vars, [*env_vars, *invars], map(getvar, out_tracers), eqns) - core.skip_checks or core.check_jaxpr(jaxpr) + config.jax_enable_checks and core.check_jaxpr(jaxpr) return jaxpr, const_vals, env_vals @cache() def convert_constvars_jaxpr(jaxpr: Jaxpr): """Moves the constvars to the start of invars.""" - core.skip_checks or core.check_jaxpr(jaxpr) + config.jax_enable_checks and core.check_jaxpr(jaxpr) lifted_jaxpr = Jaxpr(constvars=(), invars=jaxpr.constvars + jaxpr.invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns) - core.skip_checks or core.check_jaxpr(lifted_jaxpr) + config.jax_enable_checks and core.check_jaxpr(lifted_jaxpr) return lifted_jaxpr def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int): - core.skip_checks or core.check_jaxpr(jaxpr) + config.jax_enable_checks and core.check_jaxpr(jaxpr) env_vars, invars = split_list(jaxpr.invars, [num_env_vars]) converted_jaxpr = Jaxpr(constvars=jaxpr.constvars + env_vars, invars=invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns) - core.skip_checks or core.check_jaxpr(converted_jaxpr) + config.jax_enable_checks and core.check_jaxpr(converted_jaxpr) return converted_jaxpr diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 70fec7dab933..99b58cd47526 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -486,7 +486,7 @@ def __init__(self, self.indices = indices self._npy_value = None self._one_replica_buffer_indices = None - if not core.skip_checks: + if config.jax_enable_checks: assert type(aval) is ShapedArray @property @@ -792,7 +792,7 @@ def dynamic_fun(dummy, *args): f"`axis_size` (or remove the `devices` argument). Got nested_replicas=" f"{jaxpr_replicas} and nested_partitions={num_partitions}") - log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG + log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG logging.log(log_priority, f"Compiling {fun.__name__} ({id(fun)}) for {num_global_shards} " f"devices with args {avals}. (num_replicas={num_global_replicas}" @@ -1387,7 +1387,7 @@ def mesh_callable(fun: lu.WrappedFun, global_axis_sizes = mesh.shape local_axis_sizes = local_mesh.shape - log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG + log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG logging.log(log_priority, f"Compiling {fun.__name__} ({id(fun)}) for {tuple(global_axis_sizes.items())} " f"mesh with args {local_in_untiled_avals}. Argument mapping: {in_axes}.") diff --git a/jax/interpreters/sharded_jit.py b/jax/interpreters/sharded_jit.py index 01448c05eeef..c6108a36b316 100644 --- a/jax/interpreters/sharded_jit.py +++ b/jax/interpreters/sharded_jit.py @@ -144,7 +144,7 @@ def _sharded_callable( for out, parts, lparts in safe_zip(global_out_avals, out_parts, local_out_parts)] - log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG + log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG logging.log(log_priority, f"Compiling {fun.__name__} for {nparts} devices with " f"args {global_abstract_args}.") diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index 68ff07459de1..2ffa4c3db6c8 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -23,7 +23,7 @@ from absl import logging import numpy as np -from ..config import flags, bool_env, config +from ..config import config from .. import core from .. import ad_util from .. import dtypes @@ -58,17 +58,6 @@ XlaComputationBuilder = Any # xla_bridge._JaxComputationBuilder XlaExecutable = Any # xla_extension.LocalExecutable -FLAGS = flags.FLAGS -flags.DEFINE_bool('jax_debug_nans', - bool_env('JAX_DEBUG_NANS', False), - 'Add nan checks to every operation.') -flags.DEFINE_bool('jax_debug_infs', - bool_env('JAX_DEBUG_INFS', False), - 'Add inf checks to every operation.') -flags.DEFINE_bool('jax_log_compiles', - bool_env('JAX_LOG_COMPILES', False), - 'Print a message each time a `jit` computation is compiled.') - # This flag is set on exit; no logging should be attempted _on_exit = False @@ -244,7 +233,7 @@ def apply_primitive(prim, *args, **params): def _partition_outputs(avals, outs): nouts = [aval._num_buffers for aval in avals] - if not core.skip_checks: + if config.jax_enable_checks: assert sum(nouts) == len(outs), f"Internal error: sum(nouts)={sum(nouts)} should equal len(outs)={len(outs)}." outs = iter(outs) return [[next(outs) for _ in range(nout)] for nout in nouts] @@ -372,7 +361,7 @@ def _execute_replicated_primitive(prim, compiled, result_handler, *args): return result_handler(*out_bufs) def needs_check_special(): - return FLAGS.jax_debug_infs or FLAGS.jax_debug_nans + return config.jax_debug_infs or config.jax_debug_nans def check_special(name, bufs): if needs_check_special(): @@ -382,9 +371,9 @@ def check_special(name, bufs): def _check_special(name, xla_shape, buf): assert not xla_shape.is_tuple() if dtypes.issubdtype(xla_shape.element_type(), np.inexact): - if FLAGS.jax_debug_nans and np.any(np.isnan(buf.to_py())): + if config.jax_debug_nans and np.any(np.isnan(buf.to_py())): raise FloatingPointError(f"invalid value (nan) encountered in {name}") - if FLAGS.jax_debug_infs and np.any(np.isinf(buf.to_py())): + if config.jax_debug_infs and np.any(np.isinf(buf.to_py())): raise FloatingPointError(f"invalid value (inf) encountered in {name}") ### compiling jaxprs @@ -590,13 +579,13 @@ def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_inv try: return compiled_fun(*args) except FloatingPointError: - assert FLAGS.jax_debug_nans or FLAGS.jax_debug_infs # compiled_fun can only raise in this case + assert config.jax_debug_nans or config.jax_debug_infs # compiled_fun can only raise in this case print("Invalid value encountered in the output of a jit function. " "Calling the de-optimized version.") # We want to run the wrapped function again (after _xla_callable already ran # it), but linear_util.WrappedFun instances are meant to be run only once. # In addition to re-executing the Python code, which is usually undesirable - # but which FLAGS.jax_debug_nans is meant to opt into, we'll be re-executing + # but which config.jax_debug_nans is meant to opt into, we'll be re-executing # any linear_util.py-style side effects, i.e. re-populating Stores created # by any transformation_with_aux's applied to fun. Since this is # intentional here, to avoid "Store occupied" errors we reset the stores to @@ -688,7 +677,7 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar return partial(_execute_trivial, jaxpr, device, consts, out_avals, result_handlers) if not _on_exit: - log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG + log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG logging.log(log_priority, "Compiling %s (%s) for args %s.", fun.__name__, id(fun), abstract_args) @@ -1096,7 +1085,7 @@ def __init__(self, aval: core.ShapedArray, device: Optional[Device], self._device = device self._npy_value = None - if not core.skip_checks: + if config.jax_enable_checks: assert type(aval) is ShapedArray npy_value = self._value assert npy_value.dtype == aval.dtype and npy_value.shape == aval.shape diff --git a/jax/linear_util.py b/jax/linear_util.py index 3a8f156a0deb..5a2caed887e3 100644 --- a/jax/linear_util.py +++ b/jax/linear_util.py @@ -248,7 +248,7 @@ def cache(call: Callable): def memoized_fun(fun: WrappedFun, *args): cache = fun_caches.setdefault(fun.f, {}) - if core.debug_state.check_leaks: + if config.jax_check_tracer_leaks: key = (_copy_main_traces(fun.transforms), fun.params, args, config.x64_enabled) else: key = (fun.transforms, fun.params, args, config.x64_enabled) diff --git a/jax/test_util.py b/jax/test_util.py index fba09e7c4051..732435323e41 100644 --- a/jax/test_util.py +++ b/jax/test_util.py @@ -445,7 +445,7 @@ def skip_on_flag(flag_name, skip_value): def skip(test_method): # pylint: disable=missing-docstring @functools.wraps(test_method) def test_method_wrapper(self, *args, **kwargs): - flag_value = getattr(FLAGS, flag_name) + flag_value = config._read(flag_name) if flag_value == skip_value: test_name = getattr(test_method, '__name__', '[unknown test]') raise unittest.SkipTest( @@ -819,7 +819,7 @@ class JaxTestCase(parameterized.TestCase): def setUp(self): super(JaxTestCase, self).setUp() - core.skip_checks = False + config.update('jax_enable_checks', True) # We use the adler32 hash for two reasons. # a) it is deterministic run to run, unlike hash() which is randomized. # b) it returns values in int32 range, which RandomState requires. diff --git a/mypy.ini b/mypy.ini index a7a5225c689b..f55bb6e50e91 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,5 +1,6 @@ [mypy] -show_error_codes=True +show_error_codes = True +disable_error_code = attr-defined [mypy-absl.*] ignore_missing_imports = True diff --git a/tests/api_test.py b/tests/api_test.py index 3bf5a9c31852..cad8776aaf4a 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -2217,7 +2217,7 @@ def test_leak_checker_catches_a_jit_leak(self): if not config.omnistaging_enabled: raise unittest.SkipTest("test only works with omnistaging") - with core.checking_leaks(): + with jax.checking_leaks(): lst = [] @jit @@ -2232,7 +2232,7 @@ def test_leak_checker_catches_a_pmap_leak(self): if not config.omnistaging_enabled: raise unittest.SkipTest("test only works with omnistaging") - with core.checking_leaks(): + with jax.checking_leaks(): lst = [] @api.pmap @@ -2247,7 +2247,7 @@ def test_leak_checker_catches_a_grad_leak(self): if not config.omnistaging_enabled: raise unittest.SkipTest("test only works with omnistaging") - with core.checking_leaks(): + with jax.checking_leaks(): lst = [] def f(x): @@ -2261,7 +2261,7 @@ def test_leak_checker_avoids_false_positives(self): if not config.omnistaging_enabled: raise unittest.SkipTest("test only works with omnistaging") - with core.checking_leaks(): + with jax.checking_leaks(): @jit def f(x): return x @@ -2279,7 +2279,7 @@ def test_leak_checker_catches_a_scan_leak(self): if not config.omnistaging_enabled: raise unittest.SkipTest("test only works with omnistaging") - with core.checking_leaks(): + with jax.checking_leaks(): lst = [] to_scan = lambda c, x: (lst.append(c) or jnp.sin(c), None) @@ -2291,7 +2291,7 @@ def test_leak_checker_avoids_false_positives_scan(self): if not config.omnistaging_enabled: raise unittest.SkipTest("test only works with omnistaging") - with core.checking_leaks(): + with jax.checking_leaks(): to_scan = lambda c, x: (jnp.sin(c), None) lax.scan(to_scan, 1., np.arange(3.)) # doesn't crash @@ -2299,7 +2299,7 @@ def test_leak_checker_avoids_false_positives_scan_jvp(self): if not config.omnistaging_enabled: raise unittest.SkipTest("test only works with omnistaging") - with core.checking_leaks(): + with jax.checking_leaks(): to_scan = lambda c, x: (c, None) def f(x): @@ -2310,7 +2310,7 @@ def test_leak_checker_avoids_false_positives_scan_vmap(self): if not config.omnistaging_enabled: raise unittest.SkipTest("test only works with omnistaging") - with core.checking_leaks(): + with jax.checking_leaks(): to_scan = lambda c, _: (1., None) @api.vmap @@ -2322,7 +2322,7 @@ def test_leak_checker_avoids_false_positives_scan_vmap_2(self): if not config.omnistaging_enabled: raise unittest.SkipTest("test only works with omnistaging") - with core.checking_leaks(): + with jax.checking_leaks(): to_scan = lambda c, _: (c, None) @api.vmap @@ -2334,7 +2334,7 @@ def test_leak_checker_catches_a_sublevel_leak(self): if not config.omnistaging_enabled: raise unittest.SkipTest("test only works with omnistaging") - with core.checking_leaks(): + with jax.checking_leaks(): @jit def f(x): lst = [] @@ -4957,29 +4957,29 @@ def f_vjp(x, y): api.grad(lambda x, y: f(x, y)[0])(1., 2.) # doesn't crash def test_custom_transforms_vjp_nones(self): - core.skip_checks = True # Fails with checks - # issue raised by jsnoek@ and jumper@ - @jax.custom_transforms - def solve(a, b): - return jnp.dot(jnp.linalg.inv(a), b) - # print(solve(a, b)) - - def solve_vjp(a, b): - x = solve(a, b) - def vjp(x_tangent): - dx = jnp.dot(solve(a, x_tangent), x.T) - out = (dx, b * 0.) - return out - return x, vjp - jax.defvjp_all(solve, solve_vjp) - gf = grad(lambda a,b: jnp.sum(solve(a, b))) - - n = 3 - a_in = jnp.linspace(0, 1, n)[:, None] - a = jnp.dot(a_in, a_in.T) + jnp.eye(n) * 0.1 - real_x = np.random.RandomState(0).randn(n) - b = jnp.dot(a + jnp.eye(a.shape[0]), real_x) - print(gf(a, b)) # doesn't crash + with jax.enable_checks(False): # fails with checks + # issue raised by jsnoek@ and jumper@ + @jax.custom_transforms + def solve(a, b): + return jnp.dot(jnp.linalg.inv(a), b) + # print(solve(a, b)) + + def solve_vjp(a, b): + x = solve(a, b) + def vjp(x_tangent): + dx = jnp.dot(solve(a, x_tangent), x.T) + out = (dx, b * 0.) + return out + return x, vjp + jax.defvjp_all(solve, solve_vjp) + gf = grad(lambda a,b: jnp.sum(solve(a, b))) + + n = 3 + a_in = jnp.linspace(0, 1, n)[:, None] + a = jnp.dot(a_in, a_in.T) + jnp.eye(n) * 0.1 + real_x = np.random.RandomState(0).randn(n) + b = jnp.dot(a + jnp.eye(a.shape[0]), real_x) + print(gf(a, b)) # doesn't crash class BufferDonationTest(jtu.BufferDonationTestCase): diff --git a/tests/debug_nans_test.py b/tests/debug_nans_test.py index 99785ff0f869..dba67072aa88 100644 --- a/tests/debug_nans_test.py +++ b/tests/debug_nans_test.py @@ -30,7 +30,7 @@ class DebugNaNsTest(jtu.JaxTestCase): def setUp(self): - self.cfg = config.read("jax_debug_nans") + self.cfg = config._read("jax_debug_nans") config.update("jax_debug_nans", True) def tearDown(self): @@ -144,7 +144,7 @@ def testPjit(self): class DebugInfsTest(jtu.JaxTestCase): def setUp(self): - self.cfg = config.read("jax_debug_infs") + self.cfg = config._read("jax_debug_infs") config.update("jax_debug_infs", True) def tearDown(self): diff --git a/tests/djax_test.py b/tests/djax_test.py index 42629b5e12be..bd8fc8b40c0e 100644 --- a/tests/djax_test.py +++ b/tests/djax_test.py @@ -148,7 +148,7 @@ def f(x): y = sin(x) return reduce_sum(y, axes=(0,)) x = bbarray((5,), jnp.arange(2.)) - with jax.core.skipping_checks(): # TODO implement dxla_call abs eval rule + with jax.enable_checks(False): # TODO implement dxla_call abs eval rule z, f_lin = jax.linearize(f, x) z_dot = f_lin(ones_like(x)) diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py index 2f97b0b28430..92d496b123c4 100644 --- a/tests/lax_autodiff_test.py +++ b/tests/lax_autodiff_test.py @@ -25,7 +25,6 @@ import jax from jax import api -from jax import core from jax import dtypes from jax import lax from jax import test_util as jtu @@ -992,7 +991,7 @@ def f2(x, y): expected = np.array(0.0) self.assertAllClose(ans, expected, check_dtypes=False) - with core.skipping_checks(): + with jax.enable_checks(False): with self.assertRaises(TypeError): lax.stop_gradient(lambda x: x) diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index 2a8f23ef6cbd..f23409ba8325 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -1709,7 +1709,6 @@ def scan_body(c, x): self.assertAllClose(carry_out[1], carry_init, check_dtypes=False) self.assertAllClose(carry_out[0], jnp.array([2., 2., 2.]), check_dtypes = False) - # TODO(mattjj, dougalm): fix this test when skip_checks is False def testIssue757(self): # code from https://github.com/google/jax/issues/757 def fn(a): diff --git a/tests/lax_scipy_sparse_test.py b/tests/lax_scipy_sparse_test.py index 18c8fd960667..0216be48f6d4 100644 --- a/tests/lax_scipy_sparse_test.py +++ b/tests/lax_scipy_sparse_test.py @@ -208,7 +208,7 @@ def tree_unflatten(cls, aux_data, children): )) def test_bicgstab_against_scipy( self, shape, dtype, preconditioner): - if not config.FLAGS.jax_enable_x64: + if not config.jax_enable_x64: raise unittest.SkipTest("requires x64 mode") rng = jtu.rand_default(self.rng()) diff --git a/tests/lax_test.py b/tests/lax_test.py index 2eaf6db35a3b..3733f07c954f 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -2190,7 +2190,7 @@ def test_tie_in_error(self): # api.make_jaxpr(lambda x: lax.tie_in((x, x), 1))(1.) def test_primitive_jaxtype_error(self): - with core.skipping_checks(): + with jax.enable_checks(False): with self.assertRaisesRegex( TypeError, "Argument .* of type .* is not a valid JAX type"): lax.add(1, 'hi') diff --git a/tests/nn_test.py b/tests/nn_test.py index 42cca92fca3f..edcb20035aba 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -117,12 +117,12 @@ def testDtypeMatchesInput(self, dtype, fn): def testEluMemory(self): # see https://github.com/google/jax/pull/1640 - with core.skipping_checks(): # With checks we materialize the array + with jax.enable_checks(False): # With checks we materialize the array jax.make_jaxpr(lambda: nn.elu(jnp.ones((10 ** 12,)))) # don't oom def testHardTanhMemory(self): # see https://github.com/google/jax/pull/1640 - with core.skipping_checks(): # With checks we materialize the array + with jax.enable_checks(False): # With checks we materialize the array jax.make_jaxpr(lambda: nn.hard_tanh(jnp.ones((10 ** 12,)))) # don't oom def testOneHot(self): diff --git a/tests/random_test.py b/tests/random_test.py index 3eb3e6b54bf5..f370cd4a5535 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -914,7 +914,7 @@ def test_eval_shape_big_random_array(self): raise SkipTest("after deleting lazy constants, requires omnistaging") def f(x): return random.normal(random.PRNGKey(x), (int(1e12),)) - with core.skipping_checks(): # check_jaxpr will materialize array + with jax.enable_checks(False): # check_jaxpr will materialize array api.eval_shape(f, 0) # doesn't error @parameterized.named_parameters(jtu.cases_from_list( diff --git a/tests/x64_context_test.py b/tests/x64_context_test.py index 84b37013929d..9e809d9a805b 100644 --- a/tests/x64_context_test.py +++ b/tests/x64_context_test.py @@ -73,7 +73,7 @@ def test_correctly_capture_default(self, jit, enable_or_disable): func = _maybe_jit(jit, lambda: jnp.arange(10.0)) func() - expected_dtype = "float64" if config.read("jax_enable_x64") else "float32" + expected_dtype = "float64" if config._read("jax_enable_x64") else "float32" self.assertEqual(func().dtype, expected_dtype) with enable_x64():