Skip to content

Commit

Permalink
Forked a subset of JAX configuration APIs
Browse files Browse the repository at this point in the history
These APIs are internal to JAX and should not be used by other projects for
managing their configuration.
  • Loading branch information
superbobry committed Oct 30, 2023
1 parent c008753 commit 927cf87
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 31 deletions.
3 changes: 2 additions & 1 deletion docs/api_reference/flax.config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ flax.config package

.. automodule:: flax.configurations
:members:
:exclude-members: temp_flip_flag
:undoc-members:
:exclude-members: FlagHolder, bool_flag, temp_flip_flag, static_bool_env
134 changes: 104 additions & 30 deletions flax/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,40 +12,114 @@
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Global configuration options for Flax.
r"""Global configuration flags for Flax."""

Now a wrapper over jax.config, in which all config vars have a 'flax\_' prefix.
import os
from contextlib import contextmanager
from typing import Generic, NoReturn, TypeVar, overload, Any

To modify a config value on run time, call:
``flax.config.update('flax_<config_name>', <value>)``
"""
_T = TypeVar("_T")

import os
from jax import config as jax_config

from contextlib import contextmanager
class Config:
# See https://google.github.io/pytype/faq.html.
_HAS_DYNAMIC_ATTRIBUTES = True

def __init__(self):
self._values = {}

def _add_option(self, name, default):
if name in self._values:
raise RuntimeError(f"Config option {name} already defined")
self._values[name] = default

def _read(self, name):
try:
return self._values[name]
except KeyError:
raise LookupError(f"Unrecognized config option: {name}")

@overload
def update(self, name: str, value: Any, /) -> None:
...
@overload
def update(self, holder: "FlagHolder[_T]", value: _T, /) -> None:
...
def update(self, name_or_holder, value, /):
"""Modify the value of a given flag.
Args:
name_or_holder: the name of the flag to modify or the corresponding
flag holder object.
value: new value to set.
"""
name = name_or_holder
if isinstance(name_or_holder, FlagHolder):
name = name_or_holder.name
if name not in self._values:
raise LookupError(f"Unrecognized config option: {name}")
self._values[name] = value


config = Config()

# Keep a wrapper at the flax namespace, in case we make our implementation
# in the future.
config = jax_config

# Config parsing utils


def define_bool_state(name, default, help):
"""Set up a boolean flag using JAX's config system.
class FlagHolder(Generic[_T]):
def __init__(self, name, help):
self.name = name
self.__name__ = name[4:] if name.startswith("flax_") else name
self.__doc__ = f"Flag holder for `{name}`.\n\n{help}"

The flag will actually be stored as an environment variable of
'FLAX_<UPPERCASE_NAME>'. JAX config ensures that the flag can be overwritten
on runtime with `flax.config.update('flax_<config_name>', <value>)`.
def __bool__(self) -> NoReturn:
raise TypeError(
"bool() not supported for instances of type '{0}' "
"(did you mean to use '{0}.value' instead?)".format(
type(self).__name__))

@property
def value(self) -> _T:
return config._read(self.name)


def bool_flag(name: str, default: bool, help: str) -> FlagHolder[bool]:
"""Set up a boolean flag.
Example::
enable_foo = bool_flag(
name='flax_enable_foo',
default=False,
help='Enable foo.',
)
Now the ``FLAX_ENABLE_FOO`` shell environment variable can be used to
control the process-level value of the flag, in addition to using e.g.
``config.update("flax_enable_foo", True)`` directly.
Args:
name: converted to lowercase to define the name of the flag. It is
converted to uppercase to define the corresponding shell environment
variable.
default: a default value for the flag.
help: used to populate the docstring of the returned flag holder object.
Returns:
A flag holder object for accessing the value of the flag.
"""
return jax_config.define_bool_state('flax_' + name, default, help)
name = name.lower()
config._add_option(name, static_bool_env(name.upper(), default))
fh = FlagHolder[bool](name, help)
setattr(Config, name, property(lambda _: fh.value, doc=help))
return fh


def static_bool_env(varname: str, default: bool) -> bool:
"""Read an environment variable and interpret it as a boolean.
This is deprecated. Please use define_bool_state() unless your flag
This is deprecated. Please use bool_flag() unless your flag
will be used in a static method and does not require runtime updates.
True values are (case insensitive): 'y', 'yes', 't', 'true', 'on', and '1';
Expand Down Expand Up @@ -90,39 +164,39 @@ def temp_flip_flag(var_name: str, var_value: bool):
# Whether to use the lazy rng implementation.
flax_lazy_rng = static_bool_env('FLAX_LAZY_RNG', True)

flax_filter_frames = define_bool_state(
name='filter_frames',
flax_filter_frames = bool_flag(
name='flax_filter_frames',
default=True,
help='Whether to hide flax-internal stack frames from tracebacks.',
)

flax_profile = define_bool_state(
name='profile',
flax_profile = bool_flag(
name='flax_profile',
default=True,
help='Whether to run Module methods under jax.named_scope for profiles.',
)

flax_use_orbax_checkpointing = define_bool_state(
name='use_orbax_checkpointing',
flax_use_orbax_checkpointing = bool_flag(
name='flax_use_orbax_checkpointing',
default=True,
help='Whether to use Orbax to save checkpoints.',
)

flax_preserve_adopted_names = define_bool_state(
name='preserve_adopted_names',
flax_preserve_adopted_names = bool_flag(
name='flax_preserve_adopted_names',
default=False,
help="When adopting outside modules, don't clobber existing names.",
)

# TODO(marcuschiam): remove this feature flag once regular dict migration is complete
flax_return_frozendict = define_bool_state(
name='return_frozendict',
flax_return_frozendict = bool_flag(
name='flax_return_frozendict',
default=False,
help='Whether to return FrozenDicts when calling init or apply.',
)

flax_fix_rng = define_bool_state(
name='fix_rng_separator',
flax_fix_rng = bool_flag(
name='flax_fix_rng_separator',
default=False,
help=(
'Whether to add separator characters when folding in static data into'
Expand Down

0 comments on commit 927cf87

Please sign in to comment.