flax.config package
Global configuration flags for Flax.
- class flax.configurations.Config
- property flax_always_shard_variable
Whether an nnx.Variable should always be sharded automatically if it contains sharding annotations.
- property flax_array_ref
Whether to use array refs.
- property flax_filter_frames
Whether to hide flax-internal stack frames from tracebacks.
- property flax_fix_rng_separator
Whether to add separator characters when folding in static data into PRNG keys.
- property flax_hijax_variable
Whether to enable HiJAX support for nnx.Variable.
- property flax_max_repr_depth
Maximum depth of reprs for nested flax objects. Default is None (no limit).
- property flax_preserve_adopted_names
When adopting outside modules, don’t clobber existing names.
- property flax_profile
Whether to run Module methods under jax.named_scope for profiles.
- property flax_pytree_module
Whether Modules are pytrees by default or not.
- property flax_return_frozendict
Whether to return FrozenDicts when calling init or apply.
- property flax_use_flaxlib
Whether to use flaxlib for C++ acceleration.
- property flax_use_orbax_checkpointing
Whether to use Orbax to save checkpoints.
- property nnx_graph_mode
Whether NNX APIs default to graph-mode (True) or tree-mode (False).
- property nnx_graph_updates
Whether graph-mode uses dynamic (True) or simple (False) graph traversal.
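Each property above is a read-only view of a flag value; values are changed through an update call such as config.update("flax_num_foo", 42), which the int_flag entry mentions. The following is a minimal sketch of that pattern, assuming a simple dict-backed store. SketchConfig is hypothetical, the defaults chosen for its two flags are illustrative assumptions, and this is not Flax's actual implementation.

```python
class SketchConfig:
    """Hypothetical, trimmed-down analogue of a Config object with flag properties."""

    def __init__(self):
        # Illustrative defaults only (assumptions, not taken from Flax's source).
        self._values = {
            'flax_profile': True,
            'flax_max_repr_depth': None,  # None means "no limit"
        }

    @property
    def flax_profile(self):
        # Read-only: assigning config.flax_profile = ... raises AttributeError.
        return self._values['flax_profile']

    @property
    def flax_max_repr_depth(self):
        return self._values['flax_max_repr_depth']

    def update(self, name, value):
        # Mirrors the config.update("flag_name", value) call style and
        # rejects names that were never registered.
        if name not in self._values:
            raise LookupError(f'Unrecognized config option: {name}')
        self._values[name] = value


config = SketchConfig()
config.update('flax_max_repr_depth', 2)
print(config.flax_max_repr_depth)  # 2
```

Keeping the properties read-only funnels every change through a single update method, which is one place to validate flag names.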
- flax.configurations.int_flag(name, *, default, help)
Set up an integer flag.
Example:
    num_foo = int_flag(
        name='flax_num_foo',
        default=42,
        help='Number of foo.',
    )
Now the FLAX_NUM_FOO shell environment variable can be used to control the process-level value of the flag, in addition to using e.g. config.update("flax_num_foo", 42) directly.
- Parameters:
name – converted to lowercase to define the name of the flag. It is converted to uppercase to define the corresponding shell environment variable.
default – a default value for the flag.
help – used to populate the docstring of the returned flag holder object.
- Returns:
A flag holder object for accessing the value of the flag.
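To make the name handling concrete, here is a minimal sketch of an int_flag-style helper: the flag name is lowercased, the same name uppercased selects the environment variable that overrides the default, the help text becomes the holder's docstring, and a later update changes what the holder reports. The helpers _FlagHolder, _VALUES, and update are hypothetical stand-ins, not Flax's API; in Flax the flag is tied to the global config, which is what lets config.update work on it.

```python
import os


class _FlagHolder:
    """Hypothetical holder that reads one flag's current value from a shared dict."""

    def __init__(self, values, name, help):
        self._values = values
        self.name = name
        self.__doc__ = help  # the help text becomes the holder's docstring

    @property
    def value(self):
        return self._values[self.name]


_VALUES = {}  # hypothetical process-wide flag storage


def int_flag(name, *, default, help):
    """Sketch of an int_flag-style helper (not Flax's actual implementation)."""
    name = name.lower()                # flag name
    env = os.getenv(name.upper())      # matching shell environment variable
    # If the environment variable is set, it overrides the default.
    _VALUES[name] = default if env is None else int(env)
    return _FlagHolder(_VALUES, name, help)


def update(name, value):
    """Hypothetical runtime override, analogous to config.update(name, value)."""
    _VALUES[name.lower()] = value


os.environ['FLAX_NUM_FOO'] = '7'  # simulate `FLAX_NUM_FOO=7` set in the shell
num_foo = int_flag(name='flax_num_foo', default=42, help='Number of foo.')
print(num_foo.value)  # 7
update('flax_num_foo', 3)
print(num_foo.value)  # 3
```

Returning a holder rather than a plain integer means code that keeps a reference to num_foo always sees the current value, even after an update.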
- flax.configurations.static_int_env(varname, default)
Read an environment variable and interpret it as an integer.
- Parameters:
varname – the name of the environment variable.
default – the integer value to return if the variable is not set.
- Returns:
The integer value of the environment variable, or default if the variable is not set.
- Raises:
ValueError – if the environment variable is set but is not an integer.
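The behavior described above can be sketched in a few lines. This is a minimal illustration under the stated assumption that an unset variable yields default and any non-integer string raises ValueError; the error message wording and the FLAX_EXAMPLE_* variable names are hypothetical.

```python
import os


def static_int_env(varname, default):
    """Sketch: read `varname` from the environment and parse it as an int.

    Returns `default` when the variable is unset; raises ValueError when it
    is set to something that int() cannot parse.
    """
    val = os.getenv(varname)
    if val is None:
        return default
    try:
        return int(val)
    except ValueError:
        raise ValueError(
            f'invalid integer value {val!r} for environment variable {varname!r}'
        ) from None


os.environ['FLAX_EXAMPLE_DEPTH'] = '3'      # hypothetical variable name
os.environ.pop('FLAX_EXAMPLE_UNSET', None)  # make sure this one is unset
print(static_int_env('FLAX_EXAMPLE_DEPTH', 10))  # 3
print(static_int_env('FLAX_EXAMPLE_UNSET', 10))  # 10
```

Re-raising with `from None` replaces int()'s generic message with one that names the offending variable.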