From 44c2a88674eb6e1e6a52ed1ef7db6cd297810f93 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 30 Oct 2023 10:08:45 +0000 Subject: [PATCH] Forked a subset of JAX configuration APIs These APIs are internal to JAX and should not be used by other projects for managing their configuration. --- docs/api_reference/flax.config.rst | 3 +- flax/configurations.py | 135 ++++++++++++++++++++++------- tests/configurations_test.py | 52 +++++++++++ 3 files changed, 159 insertions(+), 31 deletions(-) create mode 100644 tests/configurations_test.py diff --git a/docs/api_reference/flax.config.rst b/docs/api_reference/flax.config.rst index 86d949ffc2..817f22ff28 100644 --- a/docs/api_reference/flax.config.rst +++ b/docs/api_reference/flax.config.rst @@ -4,4 +4,5 @@ flax.config package .. automodule:: flax.configurations :members: - :exclude-members: temp_flip_flag + :undoc-members: + :exclude-members: FlagHolder, bool_flag, temp_flip_flag, static_bool_env diff --git a/flax/configurations.py b/flax/configurations.py index f11a22b10c..d605e78e00 100644 --- a/flax/configurations.py +++ b/flax/configurations.py @@ -12,40 +12,115 @@ # See the License for the specific language governing permissions and # limitations under the License. -r"""Global configuration options for Flax. - -Now a wrapper over jax.config, in which all config vars have a 'flax\_' prefix. - -To modify a config value on run time, call: -``flax.config.update('flax_', )`` -""" +"""Global configuration flags for Flax.""" import os from contextlib import contextmanager +from typing import Any, Generic, NoReturn, TypeVar, overload + +_T = TypeVar('_T') + + +class Config: + # See https://google.github.io/pytype/faq.html. + _HAS_DYNAMIC_ATTRIBUTES = True + + def __init__(self): + self._values = {} + + def _add_option(self, name, default): + if name in self._values: + raise RuntimeError(f'Config option {name} already defined') + self._values[name] = default + + def _read(self, name): + try: + return self._values[name] + except KeyError: + raise LookupError(f'Unrecognized config option: {name}') + + @overload + def update(self, name: str, value: Any, /) -> None: + ... + + @overload + def update(self, holder: 'FlagHolder[_T]', value: _T, /) -> None: + ... + + def update(self, name_or_holder, value, /): + """Modify the value of a given flag. + + Args: + name_or_holder: the name of the flag to modify or the corresponding + flag holder object. + value: new value to set. + """ + name = name_or_holder + if isinstance(name_or_holder, FlagHolder): + name = name_or_holder.name + if name not in self._values: + raise LookupError(f'Unrecognized config option: {name}') + self._values[name] = value -from jax import config as jax_config -# Keep a wrapper at the flax namespace, in case we make our implementation -# in the future. -config = jax_config +config = Config() # Config parsing utils -def define_bool_state(name, default, help): - """Set up a boolean flag using JAX's config system. +class FlagHolder(Generic[_T]): + def __init__(self, name, help): + self.name = name + self.__name__ = name[4:] if name.startswith('flax_') else name + self.__doc__ = f'Flag holder for `{name}`.\n\n{help}' - The flag will actually be stored as an environment variable of - 'FLAX_'. JAX config ensures that the flag can be overwritten - on runtime with `flax.config.update('flax_', )`. + def __bool__(self) -> NoReturn: + raise TypeError( + "bool() not supported for instances of type '{0}' " + "(did you mean to use '{0}.value' instead?)".format(type(self).__name__) + ) + + @property + def value(self) -> _T: + return config._read(self.name) + + +def bool_flag(name: str, *, default: bool, help: str) -> FlagHolder[bool]: + """Set up a boolean flag. + + Example:: + + enable_foo = bool_flag( + name='flax_enable_foo', + default=False, + help='Enable foo.', + ) + + Now the ``FLAX_ENABLE_FOO`` shell environment variable can be used to + control the process-level value of the flag, in addition to using e.g. + ``config.update("flax_enable_foo", True)`` directly. + + Args: + name: converted to lowercase to define the name of the flag. It is + converted to uppercase to define the corresponding shell environment + variable. + default: a default value for the flag. + help: used to populate the docstring of the returned flag holder object. + + Returns: + A flag holder object for accessing the value of the flag. """ - return jax_config.define_bool_state('flax_' + name, default, help) + name = name.lower() + config._add_option(name, static_bool_env(name.upper(), default)) + fh = FlagHolder[bool](name, help) + setattr(Config, name, property(lambda _: fh.value, doc=help)) + return fh def static_bool_env(varname: str, default: bool) -> bool: """Read an environment variable and interpret it as a boolean. - This is deprecated. Please use define_bool_state() unless your flag + This is deprecated. Please use bool_flag() unless your flag will be used in a static method and does not require runtime updates. True values are (case insensitive): 'y', 'yes', 't', 'true', 'on', and '1'; @@ -90,39 +165,39 @@ def temp_flip_flag(var_name: str, var_value: bool): # Whether to use the lazy rng implementation. flax_lazy_rng = static_bool_env('FLAX_LAZY_RNG', True) -flax_filter_frames = define_bool_state( - name='filter_frames', +flax_filter_frames = bool_flag( + name='flax_filter_frames', default=True, help='Whether to hide flax-internal stack frames from tracebacks.', ) -flax_profile = define_bool_state( - name='profile', +flax_profile = bool_flag( + name='flax_profile', default=True, help='Whether to run Module methods under jax.named_scope for profiles.', ) -flax_use_orbax_checkpointing = define_bool_state( - name='use_orbax_checkpointing', +flax_use_orbax_checkpointing = bool_flag( + name='flax_use_orbax_checkpointing', default=True, help='Whether to use Orbax to save checkpoints.', ) -flax_preserve_adopted_names = define_bool_state( - name='preserve_adopted_names', +flax_preserve_adopted_names = bool_flag( + name='flax_preserve_adopted_names', default=False, help="When adopting outside modules, don't clobber existing names.", ) # TODO(marcuschiam): remove this feature flag once regular dict migration is complete -flax_return_frozendict = define_bool_state( - name='return_frozendict', +flax_return_frozendict = bool_flag( + name='flax_return_frozendict', default=False, help='Whether to return FrozenDicts when calling init or apply.', ) -flax_fix_rng = define_bool_state( - name='fix_rng_separator', +flax_fix_rng = bool_flag( + name='flax_fix_rng_separator', default=False, help=( 'Whether to add separator characters when folding in static data into' diff --git a/tests/configurations_test.py b/tests/configurations_test.py new file mode 100644 index 0000000000..31a93e2d41 --- /dev/null +++ b/tests/configurations_test.py @@ -0,0 +1,52 @@ +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +from absl.testing import absltest + +from flax.configurations import bool_flag, config + + +class MyTestCase(absltest.TestCase): + def setUp(self): + super().setUp() + self.enter_context(mock.patch.object(config, '_values', {})) + self._flag = bool_flag('test', default=False, help='Just a test flag.') + + def test_duplicate_flag(self): + with self.assertRaisesRegex(RuntimeError, 'already defined'): + bool_flag(self._flag.name, default=False, help='Another test flag.') + + def test_default(self): + self.assertFalse(self._flag.value) + self.assertFalse(config.test) + + def test_typed_update(self): + config.update(self._flag, True) + self.assertTrue(self._flag.value) + self.assertTrue(config.test) + + def test_untyped_update(self): + config.update(self._flag.name, True) + self.assertTrue(self._flag.value) + self.assertTrue(config.test) + + def test_update_unknown_flag(self): + with self.assertRaisesRegex(LookupError, 'Unrecognized config option'): + config.update('unknown', True) + + +if __name__ == '__main__': + absltest.main()