# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Global configuration flags for Flax."""
import os
from contextlib import contextmanager
from typing import Any, Generic, NoReturn, TypeVar, overload
_T = TypeVar('_T')
[docs]class Config:
flax_use_flaxlib: bool
flax_array_ref: bool
flax_pytree_module: bool
flax_max_repr_depth: int | None
flax_always_shard_variable: bool
flax_hijax_variable: bool
nnx_graph_mode: bool
nnx_graph_updates: bool
# See https://google.github.io/pytype/faq.html.
_HAS_DYNAMIC_ATTRIBUTES = True
def __init__(self):
self._values = {}
def _add_option(self, name, default):
if name in self._values:
raise RuntimeError(f'Config option {name} already defined')
self._values[name] = default
def _read(self, name):
try:
return self._values[name]
except KeyError:
raise LookupError(f'Unrecognized config option: {name}')
@overload
def update(self, name: str, value: Any, /) -> None:
...
@overload
def update(self, holder: 'FlagHolder[_T]', value: _T, /) -> None:
...
[docs] def update(self, name_or_holder, value, /):
"""Modify the value of a given flag.
Args:
name_or_holder: the name of the flag to modify or the corresponding
flag holder object.
value: new value to set.
"""
name = name_or_holder
if isinstance(name_or_holder, FlagHolder):
name = name_or_holder.name
if name not in self._values:
raise LookupError(f'Unrecognized config option: {name}')
self._values[name] = value
def __repr__(self):
values_repr = ', '.join(f'\n {k}={v!r}' for k, v in self._values.items())
return f'Config({values_repr}\n)'
[docs] @contextmanager
def temp_flip_flag(self, var_name: str, var_value: bool):
"""Context manager to temporarily flip feature flags for test functions.
Args:
var_name: the config variable name (without the 'flax_' prefix)
var_value: the boolean value to set var_name to temporarily
"""
old_value = getattr(self, f'flax_{var_name}')
try:
self.update(f'flax_{var_name}', var_value)
yield
finally:
self.update(f'flax_{var_name}', old_value)
config = Config()
# Config parsing utils
class FlagHolder(Generic[_T]):
def __init__(self, name, help):
self.name = name
self.__name__ = name[4:] if name.startswith('flax_') else name
self.__doc__ = f'Flag holder for `{name}`.\n\n{help}'
def __bool__(self) -> NoReturn:
raise TypeError(
"bool() not supported for instances of type '{0}' "
"(did you mean to use '{0}.value' instead?)".format(type(self).__name__)
)
@property
def value(self) -> _T:
return config._read(self.name)
def bool_flag(name: str, *, default: bool, help: str) -> FlagHolder[bool]:
"""Set up a boolean flag.
Example::
enable_foo = bool_flag(
name='flax_enable_foo',
default=False,
help='Enable foo.',
)
Now the ``FLAX_ENABLE_FOO`` shell environment variable can be used to
control the process-level value of the flag, in addition to using e.g.
``config.update("flax_enable_foo", True)`` directly.
Args:
name: converted to lowercase to define the name of the flag. It is
converted to uppercase to define the corresponding shell environment
variable.
default: a default value for the flag.
help: used to populate the docstring of the returned flag holder object.
Returns:
A flag holder object for accessing the value of the flag.
"""
name = name.lower()
config._add_option(name, static_bool_env(name.upper(), default))
fh = FlagHolder[bool](name, help)
setattr(Config, name, property(lambda _: fh.value, doc=help))
return fh
[docs]def int_flag(name: str, *, default: int | None, help: str) -> FlagHolder[int]:
"""Set up an integer flag.
Example::
num_foo = int_flag(
name='flax_num_foo',
default=42,
help='Number of foo.',
)
Now the ``FLAX_NUM_FOO`` shell environment variable can be used to
control the process-level value of the flag, in addition to using e.g.
``config.update("flax_num_foo", 42)`` directly.
Args:
name: converted to lowercase to define the name of the flag. It is
converted to uppercase to define the corresponding shell environment
variable.
default: a default value for the flag.
help: used to populate the docstring of the returned flag holder object.
Returns:
A flag holder object for accessing the value of the flag.
"""
name = name.lower()
config._add_option(name, static_int_env(name.upper(), default))
fh = FlagHolder[int](name, help)
setattr(Config, name, property(lambda _: fh.value, doc=help))
return fh
def static_bool_env(varname: str, default: bool) -> bool:
"""Read an environment variable and interpret it as a boolean.
This is deprecated. Please use bool_flag() unless your flag
will be used in a static method and does not require runtime updates.
True values are (case insensitive): 'y', 'yes', 't', 'true', 'on', and '1';
false values are 'n', 'no', 'f', 'false', 'off', and '0'.
Args:
varname: the name of the variable
default: the default boolean value
Returns:
boolean return value derived from defaults and environment.
Raises: ValueError if the environment variable is anything else.
"""
val = os.getenv(varname, str(default))
val = val.lower()
if val in ('y', 'yes', 't', 'true', 'on', '1'):
return True
elif val in ('n', 'no', 'f', 'false', 'off', '0'):
return False
else:
raise ValueError(
f'invalid truth value {val!r} for environment {varname!r}'
)
[docs]def static_int_env(varname: str, default: int | None) -> int | None:
"""Read an environment variable and interpret it as an integer.
Args:
varname: the name of the variable
default: the default integer value
Returns:
integer return value derived from defaults and environment.
Raises: ValueError if the environment variable is not an integer.
"""
val = os.getenv(varname)
if val is None:
return default
try:
return int(val)
except ValueError:
raise ValueError(
f'invalid integer value {val!r} for environment {varname!r}'
) from None
# Flax Global Configuration Variables:
flax_filter_frames = bool_flag(
name='flax_filter_frames',
default=True,
help='Whether to hide flax-internal stack frames from tracebacks.',
)
flax_profile = bool_flag(
name='flax_profile',
default=True,
help='Whether to run Module methods under jax.named_scope for profiles.',
)
flax_use_orbax_checkpointing = bool_flag(
name='flax_use_orbax_checkpointing',
default=True,
help='Whether to use Orbax to save checkpoints.',
)
flax_preserve_adopted_names = bool_flag(
name='flax_preserve_adopted_names',
default=False,
help="When adopting outside modules, don't clobber existing names.",
)
# TODO(marcuschiam): remove this feature flag once regular dict migration is complete
flax_return_frozendict = bool_flag(
name='flax_return_frozendict',
default=False,
help='Whether to return FrozenDicts when calling init or apply.',
)
flax_fix_rng = bool_flag(
name='flax_fix_rng_separator',
default=False,
help=(
'Whether to add separator characters when folding in static data into'
' PRNG keys.'
),
)
flax_use_flaxlib = bool_flag(
name='flax_use_flaxlib',
default=False,
help='Whether to use flaxlib for C++ acceleration.',
)
flax_array_ref = bool_flag(
name='flax_array_ref',
default=False,
help='Whether to use array refs.',
)
flax_pytree_module = bool_flag(
name='flax_pytree_module',
default=True,
help='Whether Modules are pytrees by default or not.',
)
flax_max_repr_depth = int_flag(
name='flax_max_repr_depth',
default=None,
help='Maximum depth of reprs for nested flax objects. Default is None (no limit).',
)
flax_always_shard_variable = bool_flag(
name='flax_always_shard_variable',
default=True,
help='Whether a `nnx.Variable` should always automatically be sharded if it contains sharding annotations.',
)
flax_hijax_variable = bool_flag(
name='flax_hijax_variable',
default=False,
help='Whether to enable HiJAX support for `nnx.Variable`.',
)
nnx_graph_mode = bool_flag(
name='nnx_graph_mode',
default=True,
help='Whether NNX APIs default to graph-mode (True) or tree-mode (False).',
)
nnx_graph_updates = bool_flag(
name='nnx_graph_updates',
default=True,
help='Whether graph-mode uses dynamic (True) or simple (False) graph traversal.',
)