variables#

class flax.nnx.BatchStat(*args, **kwargs)[source]#

The mean and variance batch statistics stored in the BatchNorm layer. Note that these are not the learnable scale and bias parameters, but rather the running-average statistics that are typically used for inference after training:

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> layer = nnx.BatchNorm(3, rngs=nnx.Rngs(0))
>>> jax.tree.map(jnp.shape, nnx.state(layer))
State({
  'bias': Param(
    value=(3,)
  ),
  'mean': BatchStat(
    value=(3,)
  ),
  'scale': Param(
    value=(3,)
  ),
  'var': BatchStat(
    value=(3,)
  )
})
class flax.nnx.Cache(*args, **kwargs)[source]#

Autoregressive cache in MultiHeadAttention:

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> layer = nnx.MultiHeadAttention(
...   num_heads=2,
...   in_features=3,
...   qkv_features=6,
...   out_features=6,
...   decode=True,
...   rngs=nnx.Rngs(0),
... )
>>> layer.init_cache((1, 3))
>>> jax.tree.map(jnp.shape, nnx.state(layer, nnx.Cache))
State({
  'cache_index': Cache(
    value=()
  ),
  'cached_key': Cache(
    value=(1, 2, 3)
  ),
  'cached_value': Cache(
    value=(1, 2, 3)
  )
})
class flax.nnx.Intermediate(*args, **kwargs)[source]#

Variable type that is typically used for Module.sow():

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     self.sow(nnx.Intermediate, 'i', x)
...     x = self.linear2(x)
...     return x
>>> model = Model(rngs=nnx.Rngs(0))

>>> x = jnp.ones((1, 2))
>>> y = model(x)
>>> jax.tree.map(jnp.shape, nnx.state(model, nnx.Intermediate))
State({
  'i': Intermediate(
    value=((1, 3),)
  )
})
class flax.nnx.Param(*args, **kwargs)[source]#

The canonical learnable parameter. All learnable parameters in NNX layer modules will have the Param Variable type:

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> layer = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> jax.tree.map(jnp.shape, nnx.state(layer))
State({
  'bias': Param(
    value=(3,)
  ),
  'kernel': Param(
    value=(2, 3)
  )
})
class flax.nnx.Variable(*args, **kwargs)[source]#

The base class for all Variable types. Create custom Variable types by subclassing this class. Numerous NNX graph functions can filter for specific Variable types, for example, split(), state(), pop(), and State.filter().

Example usage:

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> class CustomVariable(nnx.Variable):
...   pass

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     self.custom_variable = CustomVariable(jnp.ones((1, 3)))
...   def __call__(self, x):
...     return self.linear(x) + self.custom_variable
>>> model = Model(rngs=nnx.Rngs(0))

>>> linear_variables = nnx.state(model, nnx.Param)
>>> jax.tree.map(jnp.shape, linear_variables)
State({
  'linear': {
    'bias': Param(
      value=(3,)
    ),
    'kernel': Param(
      value=(2, 3)
    )
  }
})

>>> custom_variable = nnx.state(model, CustomVariable)
>>> jax.tree.map(jnp.shape, custom_variable)
State({
  'custom_variable': CustomVariable(
    value=(1, 3)
  )
})

>>> variables = nnx.state(model)
>>> jax.tree.map(jnp.shape, variables)
State({
  'custom_variable': CustomVariable(
    value=(1, 3)
  ),
  'linear': {
    'bias': Param(
      value=(3,)
    ),
    'kernel': Param(
      value=(2, 3)
    )
  }
})
del_metadata(name)[source]#

Delete a metadata entry for the Variable.

Parameters:

name – The key of the metadata element to delete.

get_metadata(name=None, default=<flax.typing.Missing object>, *, exclude_required=None)[source]#

Get metadata for the Variable.

Parameters:
  • name – The key of the metadata element to get. If not provided, returns the full metadata dictionary.

  • default – The default value to return if the metadata key is not found. If not provided and the key is not found, raises a KeyError.

has_metadata(name)[source]#

Check if the Variable has a metadata entry for the given name.

Parameters:

name – The key of the metadata element to check.

Returns:

True if the metadata entry exists, False otherwise.

set_metadata(*args, **kwargs)[source]#

Set metadata for the Variable.

set_metadata can be called in 3 ways:

  1. By passing a dictionary of metadata as the first argument, which replaces the entire Variable’s metadata.

  2. By passing a name and value as the first two arguments, which sets the metadata entry for the given name to the given value.

  3. By using keyword arguments, which update the Variable’s metadata with the provided key-value pairs.

property type#

The type of the variable.

class flax.nnx.VariableMetadata(raw_value: 'A', set_value_hooks: 'tuple[SetValueHook[A], ...]' = (), get_value_hooks: 'tuple[GetValueHook[A], ...]' = (), create_value_hooks: 'tuple[CreateValueHook[A], ...]' = (), add_axis_hooks: 'tuple[AddAxisHook[Variable[A]], ...]' = (), remove_axis_hooks: 'tuple[RemoveAxisHook[Variable[A]], ...]' = (), metadata: 'tp.Mapping[str, tp.Any]' = <factory>)[source]#
flax.nnx.with_metadata(initializer, set_value_hooks=(), get_value_hooks=(), create_value_hooks=(), add_axis_hooks=(), remove_axis_hooks=(), **metadata)[source]#
flax.nnx.variable_name_from_type(typ, /, *, allow_register=False)[source]#

Given an NNX Variable type, get its Linen-style collection name.

Its output should be the exact inverse of variable_type_from_name().

flax.nnx.variable_type_from_name(name, /, *, base=<class 'flax.nnx.variablelib.Variable'>, allow_register=False)[source]#

Given a Linen-style collection name, get or create its NNX Variable class.

flax.nnx.register_variable_name(name, typ=<flax.typing.Missing object>, *, overwrite=False)[source]#

Register a Linen collection name together with its corresponding NNX Variable type.