module#

flax.nnx.iter_children(node, /, *, graph=None)[source]#

Iterates over all immediate child nodes of a given node. This function is similar to iter_graph(), except that it only iterates over the immediate children and does not recurse further down.

Specifically, this function creates a generator that yields the key and the child node instance, where the key is the attribute name (a string) used to access the corresponding child.

Example:

>>> from flax import nnx
...
>>> class SubModule(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.linear1 = nnx.Linear(din, dout, rngs=rngs)
...     self.linear2 = nnx.Linear(din, dout, rngs=rngs)
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.submodule = SubModule(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
>>> for path, module in nnx.iter_children(model):
...   print(path, type(module).__name__)
...
batch_norm BatchNorm
dropout Dropout
linear Linear
submodule SubModule
Parameters:
  • node – A graph node object.

  • graph – If True (default), uses graph-mode which supports the full NNX feature set including shared references. If False, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol (see the sketch below).
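
A minimal sketch of tree-mode iteration (the output is assumed to be identical to graph-mode for this simple model, which has no shared references):

>>> for key, child in nnx.iter_children(model, graph=False):
...   print(key, type(child).__name__)
...
batch_norm BatchNorm
dropout Dropout
linear Linear
submodule SubModule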

flax.nnx.iter_modules(module, /, *, graph=None)[source]#

Recursively iterates over all nested Modules of the given Module, including the given Module itself.

Specifically, this function creates a generator that yields the path and the Module instance, where the path is a tuple of strings or integers representing the path to the Module from the root Module.

Example:

>>> from flax import nnx
...
>>> class SubModule(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.linear1 = nnx.Linear(din, dout, rngs=rngs)
...     self.linear2 = nnx.Linear(din, dout, rngs=rngs)
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.submodule = SubModule(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
>>> for path, module in nnx.iter_modules(model):
...   print(path, type(module).__name__)
...
('batch_norm',) BatchNorm
('dropout',) Dropout
('linear',) Linear
('submodule', 'linear1') Linear
('submodule', 'linear2') Linear
('submodule',) SubModule
() Block
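
The yielded path can be used to address a Module programmatically. A small sketch (handling only string attribute names, though paths may in general also contain integer indices):

>>> import functools
>>> def get_module(root, path):
...   return functools.reduce(getattr, path, root)
...
>>> get_module(model, ('submodule', 'linear1')) is model.submodule.linear1
True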
Parameters:
  • module – A Module object.

  • graph – If True (default), uses graph-mode which supports the full NNX feature set including shared references. If False, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol.

class flax.nnx.Module(self, /, *args, **kwargs)[source]#

Base class for all neural network modules.

Layers and models should subclass this class.

Modules can contain submodules, and in this way can be nested in a tree structure. Submodules can be assigned as regular attributes inside the __init__ method.

You can define arbitrary “forward pass” methods on your Module subclass. While no methods are special-cased, __call__ is a popular choice since you can call the Module directly:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     x = nnx.relu(x)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> y = model(x)

eval(**attributes)[source]#

Sets the Module to evaluation mode.

eval uses set_attributes to recursively set the attributes deterministic=True and use_running_average=True on all nested Modules that have them. It is primarily used to control the runtime behavior of the Dropout and BatchNorm Modules.

Example:

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.eval()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
Parameters:

**attributes – additional attributes passed to set_attributes along with the eval defaults.
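
A sketch of forwarding an extra attribute (assuming rate is a plain attribute on Dropout that set_attributes can find and set):

>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.eval(rate=0.0)  # eval defaults plus dropout.rate = 0.0
>>> block.dropout.deterministic, block.dropout.rate
(True, 0.0)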

iter_children()[source]#

Warning: this method is deprecated; use the nnx.iter_children() function instead.

Iterates over all immediate child Modules of the current Module. This method is similar to iter_modules(), except that it only iterates over the immediate children and does not recurse further down. Alias of nnx.iter_children().

iter_modules()[source]#

Warning: this method is deprecated; use the nnx.iter_modules() function instead.

Recursively iterates over all nested Modules of the current Module, including the current Module itself. Alias of nnx.iter_modules().

perturb(name, value, variable_type=<class 'flax.nnx.variablelib.Perturbation'>)[source]#

Extract gradients of intermediate values during training.

Used with nnx.capture() to record intermediate values in the forward pass and their gradients in the backward pass. Returns the value plus whatever perturbation is stored under name in the current capture context, allowing gradient computation via nnx.grad.

The workflow has four steps:

1. Initialize perturbations with nnx.capture(model, nnx.Perturbation).
2. Run the model with nnx.capture(model, nnx.Intermediate, init=perturbations).
3. Take gradients with respect to the perturbations using nnx.grad.
4. Combine the results with nnx.merge_state(perturb_grads, intermediates).

Note

This creates extra variables of the same size as value and thus occupies more memory. Use it only to debug gradients during training.

Example usage:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __call__(self, x):
...     x2 = self.perturb('grad_of_x', x)
...     return 3 * x2

>>> model = Model()
>>> x = 1.0

>>> # Step 1: Initialize perturbations
>>> forward = nnx.capture(model, nnx.Perturbation)
>>> _, perturbations = forward(x)

>>> # Steps 2-4: Capture gradients
>>> def train_step(model, perturbations, x):
...   def loss(model, perturbations, x):
...     return nnx.capture(model, nnx.Intermediate, init=perturbations)(x)
...   (grads, perturb_grads), sowed = nnx.grad(loss, argnums=(0, 1), has_aux=True)(model, perturbations, x)
...   return nnx.merge_state(perturb_grads, sowed)

>>> metrics = train_step(model, perturbations, x)
>>> # metrics contains gradients of intermediate values
Parameters:
  • name – A string key for storing the perturbation value.

  • value – The intermediate value to capture gradients for. You must use the returned value (not the original) for gradient capturing to work.

  • variable_type – The Variable type for the stored perturbation. Default is nnx.Perturbation.

set_attributes(*filters, raise_if_not_found=True, graph=None, **attributes)[source]#

Sets the attributes of nested Modules, including the current Module. Modules that do not have a given attribute are skipped.

Example:

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, deterministic=False)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.set_attributes(deterministic=True, use_running_average=True)
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)

Filters can be used to set the attributes of specific Modules:

>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.set_attributes(nnx.Dropout, deterministic=True)
>>> # Only the dropout will be modified
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, False)
Parameters:
  • *filters – Filters to select the Modules to set the attributes of.

  • raise_if_not_found – If True (default), raises a ValueError if any of the given attributes is not found in at least one of the selected Modules (see the sketch after this list).

  • **attributes – The attributes to set.
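
A sketch of the raise_if_not_found behavior (no_such_flag is a hypothetical attribute that no Module in Block defines):

>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> try:
...   block.set_attributes(no_such_flag=True)
... except ValueError:
...   print('not found')
...
not found
>>> block.set_attributes(no_such_flag=True, raise_if_not_found=False)  # silently ignored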

sow(variable_type, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)[source]#

Store intermediate values during module execution for later extraction.

Used with nnx.capture() decorator to collect intermediate values without explicitly passing containers through module calls. Values are stored under the specified name in a collection associated with variable_type.

By default, values are appended to a tuple, allowing multiple values to be tracked when the same module is called multiple times.

Example usage:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'features', x)
...     x = self.linear2(x)
...     return x

>>> # With the capture decorator, sow returns intermediates
>>> model = Model(rngs=nnx.Rngs(0))
>>> @nnx.capture(nnx.Intermediate)
... def forward(model, x):
...   return model(x)
>>> result, intermediates = forward(model, jnp.ones(2))
>>> assert 'features' in intermediates
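
When the module is called more than once under the same capture, the default reduce_fn appends each value. A sketch (assuming 'features' accumulates as a tuple, one entry per call):

>>> @nnx.capture(nnx.Intermediate)
... def forward_twice(model, x):
...   model(x)
...   return model(x)
>>> _, intermediates = forward_twice(model, jnp.ones(2))
>>> # 'features' now holds one entry per call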

Custom init/reduce functions can be passed to control accumulation:

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear(x)
...     self.sow(nnx.Intermediate, 'sum', x,
...              init_fn=lambda: 0,
...              reduce_fn=lambda prev, curr: prev+curr)
...     return x
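
A usage sketch for the running-sum variant above (assuming, as in the previous example, that capture returns the collected intermediates):

>>> model = Model(rngs=nnx.Rngs(0))
>>> @nnx.capture(nnx.Intermediate)
... def forward(model, x):
...   model(x)           # first call: 0 + x
...   return model(x)    # second call: previous sum + x
>>> _, intermediates = forward(model, jnp.ones(2))
>>> # 'sum' holds a single accumulated value rather than a tuple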
Parameters:
  • variable_type – The Variable type for the stored value. Typically Intermediate or a subclass is used.

  • name – A string key for storing the value in the collection.

  • value – The value to be stored.

  • reduce_fn – Function to combine existing and new values. Default appends to a tuple.

  • init_fn – Function providing initial value for first reduce_fn call. Default is an empty tuple.

train(**attributes)[source]#

Sets the Module to training mode.

train uses set_attributes to recursively set the attributes deterministic=False and use_running_average=False on all nested Modules that have them. It is primarily used to control the runtime behavior of the Dropout and BatchNorm Modules.

Example:

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     # initialize Dropout and BatchNorm in eval mode
...     self.dropout = nnx.Dropout(0.5, deterministic=True)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=True, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
>>> block.train()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
Parameters:

**attributes – additional attributes passed to set_attributes.

flax.nnx.view(node, /, *, only=Ellipsis, raise_if_not_found=True, graph=None, **kwargs)[source]#

Creates a new node with static attributes updated according to **kwargs.

The new node contains references to the jax arrays in the original node. If a kwarg is not found in any module, this method raises a ValueError. view relies on the set_view class method of nnx.Modules; set_view implementations should consume the kwargs they recognize and return any unused kwargs.
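
As a rough illustration of this protocol (the exact set_view signature is an assumption, sketched here as a plain method for brevity, and my_flag is a hypothetical attribute), a custom Module might consume the kwargs it recognizes and return the rest:

>>> class MyLayer(nnx.Module):
...   def __init__(self):
...     self.my_flag = False
...   def set_view(self, *, my_flag=None, **kwargs):
...     """my_flag: bool | None = None
...       if True, enables this layer's hypothetical flag."""
...     if my_flag is not None:
...       self.my_flag = my_flag
...     return kwargs  # unused kwargs are returned for other Modules to consume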

Example:

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, deterministic=False)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> new_block = nnx.view(block, deterministic=True, use_running_average=True)
>>> new_block.dropout.deterministic, new_block.batch_norm.use_running_average
(True, True)

Filters can be used to set the attributes of specific Modules:

>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> new_block = nnx.view(block, only=nnx.Dropout, deterministic=True)
>>> # Only the dropout will be modified
>>> new_block.dropout.deterministic, new_block.batch_norm.use_running_average
(True, False)
Parameters:
  • node – the object to create a copy of.

  • only – Filters to select the Modules to set the attributes of.

  • graph – If True (default), uses graph-mode which supports the full NNX feature set including shared references. If False, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol.

  • **kwargs – The attributes to set.

flax.nnx.view_info(node, /, *, only=Ellipsis, graph=None)[source]#

Provides information about the view arguments for a Module and all of its submodules. If no docstring is provided for a Module's set_view, this function displays the set_view signature in its place.

Example:

>>> from flax import nnx
...
>>> class CustomModel(nnx.Module):
...   def __init__(self, *, rngs):
...       self.mha = nnx.MultiHeadAttention(4, 8, 32, rngs=rngs)
...       self.drop = nnx.Dropout(0.5, rngs=rngs)
...       self.bn = nnx.BatchNorm(32, rngs=rngs)
...
>>> model = CustomModel(rngs=nnx.Rngs(0))
>>> print(nnx.view_info(model))
BatchNorm:
  use_running_average: bool | None = None
    if True, the stored batch statistics will be
    used instead of computing the batch statistics on the input.
Dropout:
  deterministic: bool | None = None
    if True, disables dropout masking.
MultiHeadAttention:
  deterministic: bool | None = None
    if True, the module is set to deterministic mode.
  decode: bool | None = None
    if True, the module is set to decode mode.
  batch_size: int | Shape | None = None
    the batch size to use for the cache.
  max_length: int | None = None
    the max length to use for the cache.
Parameters:
  • node – the object to display view information for.

  • only – Filters to select the Modules to display information for.

  • graph – If True (default), uses graph-mode which supports the full NNX feature set including shared references. If False, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol.

flax.nnx.with_attributes(node, /, *, only=Ellipsis, raise_if_not_found=True, graph=None, **attributes)[source]#

Creates a new node with attributes updated according to **attributes.

The new node contains references to jax arrays in the original node. Unlike set_attributes, this function does not modify the original node.

Example:

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, deterministic=False)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> new_block = nnx.with_attributes(block, deterministic=True, use_running_average=True)
>>> new_block.dropout.deterministic, new_block.batch_norm.use_running_average
(True, True)
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)

Filters can be used to set the attributes of specific Modules:

>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> new_block = nnx.with_attributes(block, only=nnx.Dropout, deterministic=True)
>>> # Only the dropout will be modified
>>> new_block.dropout.deterministic, new_block.batch_norm.use_running_average
(True, False)
Parameters:
  • node – the object to create a copy of.

  • only – Filters to select the Modules to set the attributes of.

  • raise_if_not_found – If True (default), raises a ValueError if any of the given attributes is not found in at least one of the selected Modules.

  • graph – If True (default), uses graph-mode which supports the full NNX feature set including shared references. If False, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol.

  • **attributes – The attributes to set.