Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forked a subset of JAX configuration APIs #3448

Merged
merged 1 commit into from
Nov 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/api_reference/flax.config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ flax.config package

.. automodule:: flax.configurations
:members:
:exclude-members: temp_flip_flag
:undoc-members:
:exclude-members: FlagHolder, bool_flag, temp_flip_flag, static_bool_env
135 changes: 105 additions & 30 deletions flax/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,40 +12,115 @@
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Global configuration options for Flax.

Now a wrapper over jax.config, in which all config vars have a 'flax\_' prefix.

To modify a config value on run time, call:
``flax.config.update('flax_<config_name>', <value>)``
"""
"""Global configuration flags for Flax."""

import os
from contextlib import contextmanager
from typing import Any, Generic, NoReturn, TypeVar, overload

_T = TypeVar('_T')


class Config:
# See https://google.github.io/pytype/faq.html.
_HAS_DYNAMIC_ATTRIBUTES = True

def __init__(self):
self._values = {}

def _add_option(self, name, default):
if name in self._values:
raise RuntimeError(f'Config option {name} already defined')
self._values[name] = default

def _read(self, name):
try:
return self._values[name]
except KeyError:
raise LookupError(f'Unrecognized config option: {name}')

@overload
def update(self, name: str, value: Any, /) -> None:
...

@overload
def update(self, holder: 'FlagHolder[_T]', value: _T, /) -> None:
...

def update(self, name_or_holder, value, /):
"""Modify the value of a given flag.

Args:
name_or_holder: the name of the flag to modify or the corresponding
flag holder object.
value: new value to set.
"""
name = name_or_holder
if isinstance(name_or_holder, FlagHolder):
name = name_or_holder.name
if name not in self._values:
raise LookupError(f'Unrecognized config option: {name}')
self._values[name] = value

from jax import config as jax_config

# Keep a wrapper at the flax namespace, in case we make our implementation
# in the future.
config = jax_config
config = Config()

# Config parsing utils


def define_bool_state(name, default, help):
"""Set up a boolean flag using JAX's config system.
class FlagHolder(Generic[_T]):
def __init__(self, name, help):
self.name = name
self.__name__ = name[4:] if name.startswith('flax_') else name
self.__doc__ = f'Flag holder for `{name}`.\n\n{help}'

The flag will actually be stored as an environment variable of
'FLAX_<UPPERCASE_NAME>'. JAX config ensures that the flag can be overwritten
on runtime with `flax.config.update('flax_<config_name>', <value>)`.
def __bool__(self) -> NoReturn:
raise TypeError(
"bool() not supported for instances of type '{0}' "
"(did you mean to use '{0}.value' instead?)".format(type(self).__name__)
)

@property
def value(self) -> _T:
return config._read(self.name)


def bool_flag(name: str, *, default: bool, help: str) -> FlagHolder[bool]:
"""Set up a boolean flag.

Example::

enable_foo = bool_flag(
name='flax_enable_foo',
default=False,
help='Enable foo.',
)

Now the ``FLAX_ENABLE_FOO`` shell environment variable can be used to
control the process-level value of the flag, in addition to using e.g.
``config.update("flax_enable_foo", True)`` directly.

Args:
name: converted to lowercase to define the name of the flag. It is
converted to uppercase to define the corresponding shell environment
variable.
default: a default value for the flag.
help: used to populate the docstring of the returned flag holder object.

Returns:
A flag holder object for accessing the value of the flag.
"""
return jax_config.define_bool_state('flax_' + name, default, help)
name = name.lower()
config._add_option(name, static_bool_env(name.upper(), default))
fh = FlagHolder[bool](name, help)
setattr(Config, name, property(lambda _: fh.value, doc=help))
return fh


def static_bool_env(varname: str, default: bool) -> bool:
"""Read an environment variable and interpret it as a boolean.

This is deprecated. Please use define_bool_state() unless your flag
This is deprecated. Please use bool_flag() unless your flag
will be used in a static method and does not require runtime updates.

True values are (case insensitive): 'y', 'yes', 't', 'true', 'on', and '1';
Expand Down Expand Up @@ -90,39 +165,39 @@ def temp_flip_flag(var_name: str, var_value: bool):
# Whether to use the lazy rng implementation.
flax_lazy_rng = static_bool_env('FLAX_LAZY_RNG', True)

flax_filter_frames = define_bool_state(
name='filter_frames',
flax_filter_frames = bool_flag(
name='flax_filter_frames',
default=True,
help='Whether to hide flax-internal stack frames from tracebacks.',
)

flax_profile = define_bool_state(
name='profile',
flax_profile = bool_flag(
name='flax_profile',
default=True,
help='Whether to run Module methods under jax.named_scope for profiles.',
)

flax_use_orbax_checkpointing = define_bool_state(
name='use_orbax_checkpointing',
flax_use_orbax_checkpointing = bool_flag(
name='flax_use_orbax_checkpointing',
default=True,
help='Whether to use Orbax to save checkpoints.',
)

flax_preserve_adopted_names = define_bool_state(
name='preserve_adopted_names',
flax_preserve_adopted_names = bool_flag(
name='flax_preserve_adopted_names',
default=False,
help="When adopting outside modules, don't clobber existing names.",
)

# TODO(marcuschiam): remove this feature flag once regular dict migration is complete
flax_return_frozendict = define_bool_state(
name='return_frozendict',
flax_return_frozendict = bool_flag(
name='flax_return_frozendict',
default=False,
help='Whether to return FrozenDicts when calling init or apply.',
)

flax_fix_rng = define_bool_state(
name='fix_rng_separator',
flax_fix_rng = bool_flag(
name='flax_fix_rng_separator',
default=False,
help=(
'Whether to add separator characters when folding in static data into'
Expand Down
52 changes: 52 additions & 0 deletions tests/configurations_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright 2023 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock

from absl.testing import absltest

from flax.configurations import bool_flag, config


class MyTestCase(absltest.TestCase):
def setUp(self):
super().setUp()
self.enter_context(mock.patch.object(config, '_values', {}))
self._flag = bool_flag('test', default=False, help='Just a test flag.')

def test_duplicate_flag(self):
with self.assertRaisesRegex(RuntimeError, 'already defined'):
bool_flag(self._flag.name, default=False, help='Another test flag.')

def test_default(self):
self.assertFalse(self._flag.value)
self.assertFalse(config.test)

def test_typed_update(self):
config.update(self._flag, True)
self.assertTrue(self._flag.value)
self.assertTrue(config.test)

def test_untyped_update(self):
config.update(self._flag.name, True)
self.assertTrue(self._flag.value)
self.assertTrue(config.test)

def test_update_unknown_flag(self):
with self.assertRaisesRegex(LookupError, 'Unrecognized config option'):
config.update('unknown', True)


if __name__ == '__main__':
absltest.main()
Loading