flax.config package

Global configuration flags for Flax.

class flax.configurations.Config
property flax_always_shard_variable

Whether an nnx.Variable should always be sharded automatically if it contains sharding annotations.

property flax_array_ref

Whether to use array refs.

property flax_filter_frames

Whether to hide Flax-internal stack frames from tracebacks.

property flax_fix_rng_separator

Whether to add separator characters when folding static data into PRNG keys.

property flax_hijax_variable

Whether to enable HiJAX support for nnx.Variable.

property flax_max_repr_depth

Maximum depth of reprs for nested Flax objects. The default is None (no limit).

property flax_preserve_adopted_names

Whether to preserve existing names, rather than clobbering them, when adopting outside modules.

property flax_profile

Whether to run Module methods under jax.named_scope for profiling.

property flax_pytree_module

Whether Modules are pytrees by default.

property flax_return_frozendict

Whether to return FrozenDicts when calling init or apply.

property flax_use_flaxlib

Whether to use flaxlib for C++ acceleration.

property flax_use_orbax_checkpointing

Whether to use Orbax to save checkpoints.

property nnx_graph_mode

Whether NNX APIs default to graph-mode (True) or tree-mode (False).

property nnx_graph_updates

Whether graph-mode uses dynamic (True) or simple (False) graph traversal.

temp_flip_flag(var_name, var_value)

Context manager to temporarily flip feature flags for test functions.

Parameters:
  • var_name – the config variable name (without the ‘flax_’ prefix)

  • var_value – the boolean value to set var_name to temporarily
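The set-then-restore pattern described above can be sketched without Flax installed. SketchConfig below is a hypothetical stand-in, not the Flax class; it only illustrates that the caller passes the name without the ‘flax_’ prefix and that the previous value comes back when the block exits.

```python
from contextlib import contextmanager


class SketchConfig:
    """Hypothetical stand-in for the Flax config object (illustration only)."""

    def __init__(self):
        self.flax_profile = True  # assumed starting value for the demo

    def update(self, name, value):
        setattr(self, name, value)

    @contextmanager
    def temp_flip_flag(self, var_name, var_value):
        full_name = f"flax_{var_name}"  # the caller omits the 'flax_' prefix
        old_value = getattr(self, full_name)
        self.update(full_name, var_value)
        try:
            yield
        finally:
            # Restore the previous value even if the body raises.
            self.update(full_name, old_value)


config = SketchConfig()
with config.temp_flip_flag("profile", False):
    inside = config.flax_profile  # flag is flipped inside the block
after = config.flax_profile  # original value is back afterwards
```

With the real library, the equivalent call would presumably be made on the Flax config object itself, inside a test function.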

update(name_or_holder, value, /)

Modify the value of a given flag.

Parameters:
  • name_or_holder – the name of the flag to modify or the corresponding flag holder object.

  • value – new value to set.
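As a rough illustration of what accepting either a name or a flag holder can look like, here is a small stdlib-only sketch. SketchConfig, FlagHolder, and define are hypothetical names invented for this example, and the LookupError for unknown names is an assumption rather than documented behaviour.

```python
class FlagHolder:
    """Hypothetical holder: remembers a flag name and reads its current value."""

    def __init__(self, config, name):
        self._config = config
        self.name = name

    @property
    def value(self):
        return self._config.values[self.name]


class SketchConfig:
    """Hypothetical stand-in for the Flax config object (illustration only)."""

    def __init__(self):
        self.values = {}

    def define(self, name, default):
        self.values[name] = default
        return FlagHolder(self, name)

    def update(self, name_or_holder, value, /):
        # Accept either the flag's name or its holder object.
        if isinstance(name_or_holder, FlagHolder):
            name = name_or_holder.name
        else:
            name = name_or_holder
        if name not in self.values:
            # Assumption: unknown flags are rejected rather than created.
            raise LookupError(f"Unrecognized config option: {name}")
        self.values[name] = value


config = SketchConfig()
depth = config.define("flax_max_repr_depth", None)

config.update("flax_max_repr_depth", 3)  # update by name
by_name = depth.value

config.update(depth, 5)  # update through the holder
by_holder = depth.value
```

The `/` in the signature makes both arguments positional-only, so callers cannot pass them as keywords.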

flax.configurations.int_flag(name, *, default, help)

Set up an integer flag.

Example:

num_foo = int_flag(
    name='flax_num_foo',
    default=42,
    help='Number of foo.',
)

Now the FLAX_NUM_FOO shell environment variable can be used to control the process-level value of the flag, in addition to using e.g. config.update("flax_num_foo", 42) directly.

Parameters:
  • name – the flag name. It is converted to lowercase to define the name of the flag and to uppercase to define the corresponding shell environment variable.

  • default – a default value for the flag.

  • help – used to populate the docstring of the returned flag holder object.

Returns:

A flag holder object for accessing the value of the flag.
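The naming rule above (lowercase for the flag, uppercase for the environment variable) can be sketched with the standard library alone. sketch_int_flag is a hypothetical helper, not the Flax function; the environ parameter and the fallback to default when the variable is unset are assumptions made so the example runs in isolation.

```python
import os


def sketch_int_flag(name, *, default, environ=os.environ):
    """Hypothetical illustration of the flag-name / env-var-name mapping."""
    flag_name = name.lower()  # name used with config.update(...)
    env_name = name.upper()  # name of the shell environment variable
    raw = environ.get(env_name)
    # Assumption: an unset variable leaves the flag at its default.
    value = default if raw is None else int(raw)
    return flag_name, env_name, value


# No environment variable set: the default applies.
unset = sketch_int_flag("flax_num_foo", default=42, environ={})

# FLAX_NUM_FOO=7 in the environment overrides the default.
overridden = sketch_int_flag(
    "flax_num_foo", default=42, environ={"FLAX_NUM_FOO": "7"}
)
```

Here `unset` is `('flax_num_foo', 'FLAX_NUM_FOO', 42)` and `overridden` is `('flax_num_foo', 'FLAX_NUM_FOO', 7)`.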

flax.configurations.static_int_env(varname, default)

Read an environment variable and interpret it as an integer.

Parameters:
  • varname – the name of the variable

  • default – the default integer value

Returns:

The integer value derived from the default and the environment.

Raises:

ValueError – if the value of the environment variable is not an integer.
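A minimal stdlib sketch of the behaviour described above, for readers who want to see the control flow. sketch_static_int_env is a hypothetical re-implementation, not the Flax source; returning default when the variable is unset and the environ parameter are assumptions.

```python
import os


def sketch_static_int_env(varname, default, environ=os.environ):
    """Hypothetical illustration: read an env var and interpret it as an int."""
    raw = environ.get(varname)
    if raw is None:
        # Assumption: an unset variable yields the default.
        return default
    try:
        return int(raw)
    except ValueError:
        raise ValueError(
            f"Invalid integer value for environment variable {varname!r}: {raw!r}"
        ) from None


# SKETCH_DEPTH is a made-up variable name used only for this demo.
from_default = sketch_static_int_env("SKETCH_DEPTH", 3, environ={})
from_env = sketch_static_int_env("SKETCH_DEPTH", 3, environ={"SKETCH_DEPTH": "8"})
```

A non-numeric value such as `"abc"` makes the sketch raise ValueError, matching the documented failure mode.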