module#

flax.nnx.iter_children(node, /, *, graph=None)[source]#

Iterates over all immediate child nodes of a given node. This function is similar to iter_graph(), except it only iterates over the immediate children, and does not recurse further down.

Specifically, this function creates a generator that yields the key and the child node instance, where the key is a string representing the attribute name to access the corresponding child.

Example:

>>> from flax import nnx
...
>>> class SubModule(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.linear1 = nnx.Linear(din, dout, rngs=rngs)
...     self.linear2 = nnx.Linear(din, dout, rngs=rngs)
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.submodule = SubModule(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
>>> for path, module in nnx.iter_children(model):
...   print(path, type(module).__name__)
...
batch_norm BatchNorm
dropout Dropout
linear Linear
submodule SubModule
Parameters:
  • node – A graph node object.

  • graph – If True (default), uses graph-mode which supports the full NNX feature set including shared references. If False, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol.

flax.nnx.iter_modules(module, /, *, graph=None)[source]#

Recursively iterates over all nested Modules of the given Module, including the given Module itself.

Specifically, this function creates a generator that yields the path and the Module instance, where the path is a tuple of strings or integers representing the path to the Module from the root Module.

Example:

>>> from flax import nnx
...
>>> class SubModule(nnx.Module):
...   def __init__(self, din, dout, rngs):
...     self.linear1 = nnx.Linear(din, dout, rngs=rngs)
...     self.linear2 = nnx.Linear(din, dout, rngs=rngs)
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.submodule = SubModule(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> model = Block(2, 5, rngs=nnx.Rngs(0))
>>> for path, module in nnx.iter_modules(model):
...   print(path, type(module).__name__)
...
('batch_norm',) BatchNorm
('dropout',) Dropout
('linear',) Linear
('submodule', 'linear1') Linear
('submodule', 'linear2') Linear
('submodule',) SubModule
() Block
Parameters:
  • module – A Module object.

  • graph – If True (default), uses graph-mode which supports the full NNX feature set including shared references. If False, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol.

class flax.nnx.Module(self, /, *args, **kwargs)[source]#

Base class for all neural network modules.

Layers and models should subclass this class.

Modules can contain submodules, and in this way can be nested in a tree structure. Submodules can be assigned as regular attributes inside the __init__ method.

You can define arbitrary “forward pass” methods on your Module subclass. While no methods are special-cased, __call__ is a popular choice since you can call the Module directly:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     x = nnx.relu(x)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> y = model(x)
eval(**attributes)[source]#

Sets the Module to evaluation mode.

eval uses set_attributes to recursively set the attributes deterministic=True and use_running_average=True on all nested Modules that have these attributes. It is primarily used to control the runtime behavior of the Dropout and BatchNorm Modules.

Example:

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5)
...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.eval()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
Parameters:

**attributes – additional attributes passed to set_attributes.

iter_children()[source]#

Warning: this method is deprecated; use nnx.iter_children() instead.

Iterates over all child Modules of the current Module. This method is similar to iter_modules(), except it only iterates over the immediate children and does not recurse further down. Alias of nnx.iter_children().

iter_modules()[source]#

Warning: this method is deprecated; use nnx.iter_modules() instead.

Recursively iterates over all nested Modules of the current Module, including the current Module itself. Alias of nnx.iter_modules().

perturb(name, value, variable_type=<class 'flax.nnx.variablelib.Perturbation'>)[source]#

Adds a zero-valued variable (“perturbation”) to the intermediate value.

The gradient of value will be the same as the gradient of this perturbation variable. Therefore, if you define your loss function with both the params and the perturbations as standalone arguments, you can obtain the intermediate gradients of value by running jax.grad on the perturbation variable.

Since the shape of the perturbation value depends on the shape of the input, a perturbation variable is only created after you run a sample input through the model once.

Note

This creates an extra dummy variable of the same size as value, and thus occupies more memory. Use it only to debug gradients during training.

Example usage:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     x = self.perturb('xgrad', x)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 4))
>>> model = Model(rngs=nnx.Rngs(0))
>>> assert not hasattr(model, 'xgrad')  # perturbation requires a sample input run
>>> _ = model(x)
>>> assert model.xgrad.shape == (1, 3)   # same as the intermediate value
>>> graphdef, params, perturbations = nnx.split(model, nnx.Param, nnx.Perturbation)

>>> # Take gradients on the Param and Perturbation variables
>>> @nnx.grad(argnums=(0, 1))
... def grad_loss(params, perturbations, inputs, targets):
...   model = nnx.merge(graphdef, params, perturbations)
...   return jnp.mean((model(inputs) - targets) ** 2)

>>> (grads, perturbations) = grad_loss(params, perturbations, x, y)
>>> # `perturbations.xgrad[...]` is the intermediate gradient
>>> assert not jnp.array_equal(perturbations.xgrad[...], jnp.zeros((1, 3)))
Parameters:
  • name – A string denoting the Module attribute name for the perturbation value.

  • value – The value to take the intermediate gradient of.

  • variable_type – The Variable type for the stored perturbation. Defaults to nnx.Perturbation.

set_attributes(*filters, raise_if_not_found=True, graph=None, **attributes)[source]#

Sets the attributes of nested Modules, including the current Module. If an attribute is not found in a Module, it is ignored.

Example:

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, deterministic=False)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.set_attributes(deterministic=True, use_running_average=True)
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)

Filters can be used to set the attributes of specific Modules:

>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.set_attributes(nnx.Dropout, deterministic=True)
>>> # Only the dropout will be modified
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, False)
Parameters:
  • *filters – Filters to select the Modules to set the attributes of.

  • raise_if_not_found – If True (default), raises a ValueError if at least one attribute instance is not found in one of the selected Modules.

  • **attributes – The attributes to set.

sow(variable_type, name, value, reduce_fn=<function <lambda>>, init_fn=<function <lambda>>)[source]#

sow() can be used to collect intermediate values without the overhead of explicitly passing a container through each Module call. sow() stores a value in a new Module attribute, denoted by name. The value will be wrapped by a Variable of type variable_type, which can be useful to filter for in split(), state() and pop().

By default the values are stored in a tuple and each stored value is appended at the end. This way all intermediates can be tracked when the same module is called multiple times.

Example usage:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x, add=0):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'i', x+add)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> assert not hasattr(model, 'i')

>>> y = model(x)
>>> assert hasattr(model, 'i')
>>> assert len(model.i) == 1 # tuple of length 1
>>> assert model.i[0].shape == (1, 3)

>>> y = model(x, add=1)
>>> assert len(model.i) == 2 # tuple of length 2
>>> assert (model.i[0] + 1 == model.i[1]).all()

Alternatively, a custom init/reduce function can be passed:

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'sum', x,
...              init_fn=lambda: 0,
...              reduce_fn=lambda prev, curr: prev+curr)
...     self.sow(nnx.Intermediate, 'product', x,
...              init_fn=lambda: 1,
...              reduce_fn=lambda prev, curr: prev*curr)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))

>>> y = model(x)
>>> assert (model.sum[...] == model.product[...]).all()
>>> intermediate = model.sum[...]

>>> y = model(x)
>>> assert (model.sum[...] == intermediate*2).all()
>>> assert (model.product[...] == intermediate**2).all()
Parameters:
  • variable_type – The Variable type for the stored value. Typically Intermediate is used to indicate an intermediate value.

  • name – A string denoting the Module attribute name, where the sowed value is stored.

  • value – The value to be stored.

  • reduce_fn – The function used to combine the existing value with the new value. The default is to append the value to a tuple.

  • init_fn – For the first value stored, reduce_fn will be passed the result of init_fn together with the value to be stored. The default is an empty tuple.

train(**attributes)[source]#

Sets the Module to training mode.

train uses set_attributes to recursively set the attributes deterministic=False and use_running_average=False on all nested Modules that have these attributes. It is primarily used to control the runtime behavior of the Dropout and BatchNorm Modules.

Example:

>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     # initialize Dropout and BatchNorm in eval mode
...     self.dropout = nnx.Dropout(0.5, deterministic=True)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=True, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)
>>> block.train()
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
Parameters:

**attributes – additional attributes passed to set_attributes.

flax.nnx.view(node, /, *, only=Ellipsis, raise_if_not_found=True, graph=None, **kwargs)[source]#

Creates a new node with static attributes updated according to **kwargs.

The new node contains references to the JAX arrays in the original node. If a kwarg is not found in any module, this function raises a ValueError. It uses the set_view class method on nnx.Modules; a set_view class method should return any unused kwargs.

Example:
>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, deterministic=False)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> new_block = nnx.view(block, deterministic=True, use_running_average=True)
>>> new_block.dropout.deterministic, new_block.batch_norm.use_running_average
(True, True)
Filters can be used to set the attributes of specific Modules:
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> new_block = nnx.view(block, only=nnx.Dropout, deterministic=True)
>>> # Only the dropout will be modified
>>> new_block.dropout.deterministic, new_block.batch_norm.use_running_average
(True, False)
Parameters:
  • node – the object to create a copy of.

  • only – Filters to select the Modules to set the attributes of.

  • graph – If True (default), uses graph-mode which supports the full NNX feature set including shared references. If False, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol.

  • **kwargs – The attributes to set.

flax.nnx.view_info(node, /, *, only=Ellipsis, graph=None)[source]#

Provides information about the view arguments for a module and all of its submodules. If no docstring is provided for a module’s set_view, this function displays the set_view signature for it.

Example:
>>> from flax import nnx
...
>>> class CustomModel(nnx.Module):
...   def __init__(self, *, rngs):
...       self.mha = nnx.MultiHeadAttention(4, 8, 32, rngs=rngs)
...       self.drop = nnx.Dropout(0.5, rngs=rngs)
...       self.bn = nnx.BatchNorm(32, rngs=rngs)
...
>>> model = CustomModel(rngs=nnx.Rngs(0))
>>> print(nnx.view_info(model))
BatchNorm:
  use_running_average: bool | None = None
    if True, the stored batch statistics will be
    used instead of computing the batch statistics on the input.
Dropout:
  deterministic: bool | None = None
    if True, disables dropout masking.
MultiHeadAttention:
  deterministic: bool | None = None
    if True, the module is set to deterministic mode.
  decode: bool | None = None
    if True, the module is set to decode mode.
  batch_size: int | Shape | None = None
    the batch size to use for the cache.
  max_length: int | None = None
    the max length to use for the cache.
Parameters:
  • node – the object to display view information for.

  • only – Filters to select the Modules to display information for.

  • graph – If True (default), uses graph-mode which supports the full NNX feature set including shared references. If False, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol.

flax.nnx.with_attributes(node, /, *, only=Ellipsis, raise_if_not_found=True, graph=None, **attributes)[source]#

Creates a new node with attributes updated according to **attributes.

The new node contains references to the JAX arrays in the original node. Unlike set_attributes, this function does not modify the original node.

Example:
>>> from flax import nnx
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.5, deterministic=False)
...     self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> new_block = nnx.with_attributes(block, deterministic=True, use_running_average=True)
>>> new_block.dropout.deterministic, new_block.batch_norm.use_running_average
(True, True)
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
Filters can be used to set the attributes of specific Modules:
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> new_block = nnx.with_attributes(block, only=nnx.Dropout, deterministic=True)
>>> # Only the dropout will be modified
>>> new_block.dropout.deterministic, new_block.batch_norm.use_running_average
(True, False)
Parameters:
  • node – the object to create a copy of.

  • only – Filters to select the Modules to set the attributes of.

  • raise_if_not_found – If True (default), raises a ValueError if at least one attribute instance is not found in one of the selected Modules.

  • graph – If True (default), uses graph-mode which supports the full NNX feature set including shared references. If False, uses tree-mode which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol.

  • **attributes – The attributes to set.