state

class flax.nnx.State(mapping, /, *, _copy=True)

A pytree-like Mapping with hashable and comparable keys.

class flax.nnx.FlatState(items, /, *, sort)

flax.nnx.filter_state(state, first, /, *filters)

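Filter a State into one or more States. The user must pass at least one filter (e.g. a Variable type such as nnx.Param). This function is similar to split_state(), except that the filters need not be exhaustive: Variables that match none of the filters are left out of the result.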

Example usage:

>>> from flax import nnx

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.batchnorm = nnx.BatchNorm(2, rngs=rngs)
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...   def __call__(self, x):
...     return self.linear(self.batchnorm(x))

>>> model = Model(rngs=nnx.Rngs(0))
>>> state = nnx.state(model)
>>> param = nnx.filter_state(state, nnx.Param)
>>> batch_stats = nnx.filter_state(state, nnx.BatchStat)
>>> param, batch_stats = nnx.filter_state(state, nnx.Param, nnx.BatchStat)

Parameters:
  • first – The first filter.

  • *filters – Optional additional filters. Together with first, they group the state into mutually exclusive substates.

Returns:

One State per filter passed: a single State if only first is given, otherwise a tuple of States in the order of the filters.
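To make the non-exhaustive filtering concrete, here is a plain-Python model of it. This is an illustrative sketch, not Flax's implementation: a State is stood in for by a flat dict mapping path tuples to hypothetical (kind, value) pairs, filters are plain predicates, and the names filter_flat and of_kind are hypothetical. In this sketch each leaf goes to the first filter it matches, and leaves that match no filter are simply dropped.

```python
from typing import Any, Callable

# Hypothetical stand-in for a State: path tuple -> (kind, value).
FlatMapping = dict[tuple, tuple[str, Any]]
Predicate = Callable[[tuple, tuple[str, Any]], bool]


def filter_flat(flat: FlatMapping, first: Predicate, *filters: Predicate):
    """Hypothetical helper: group leaves by the first predicate they match."""
    predicates = (first, *filters)
    groups: list[FlatMapping] = [{} for _ in predicates]
    for path, leaf in flat.items():
        for group, predicate in zip(groups, predicates):
            if predicate(path, leaf):
                group[path] = leaf
                break  # mutually exclusive: a leaf lands in one group only
        # A leaf that matched no predicate is silently discarded.
    return groups[0] if len(groups) == 1 else tuple(groups)


def of_kind(kind: str) -> Predicate:
    """Hypothetical filter matching leaves tagged with `kind`."""
    return lambda path, leaf: leaf[0] == kind


state = {
    ("batchnorm", "mean"): ("batch_stat", 0.0),
    ("batchnorm", "scale"): ("param", 1.0),
    ("linear", "bias"): ("param", 0.0),
    ("rng", "count"): ("rng_state", 3),
}

only_params = filter_flat(state, of_kind("param"))
params, stats = filter_flat(state, of_kind("param"), of_kind("batch_stat"))
```

Here the ("rng", "count") leaf matches neither filter and is dropped without error, which is the behavior that distinguishes filtering from an exhaustive split.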

flax.nnx.from_flat_state(flat_state, *, cls=<class 'flax.nnx.statelib.State'>)

Convert a FlatState object into a State object.

Parameters:

flat_state – A FlatState object.

Returns:

A State object.
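Conceptually, a flat state is a sequence of (path, value) pairs, where each path is the tuple of keys leading to a leaf, and from_flat_state rebuilds the nesting from those pairs. The sketch below models that step with plain nested dicts; unflatten is a hypothetical helper, and real State and FlatState objects hold Variables rather than raw numbers.

```python
from typing import Any, Iterable


def unflatten(pairs: Iterable[tuple[tuple, Any]]) -> dict:
    """Hypothetical helper: rebuild a nested dict from (path, value) pairs."""
    nested: dict = {}
    for path, value in pairs:
        node = nested
        # Walk (and create) the intermediate dicts for all but the last key.
        for key in path[:-1]:
            node = node.setdefault(key, {})
        node[path[-1]] = value
    return nested


flat = [
    (("linear", "bias"), 0.0),
    (("linear", "kernel"), 1.0),
    (("batchnorm", "scale"), 1.0),
]
nested = unflatten(flat)
```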

flax.nnx.map_state(f, state)

Map f over a State object and return a new State.

Parameters:
  • f – The function to apply to each leaf of the state.

  • state – A State object.

Returns:

A new State.
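The pattern of applying a function to every leaf while keeping the structure can be modeled in plain Python on nested dicts. This is a sketch, not the Flax code: map_nested is a hypothetical name, and passing each leaf's path together with its value to f is an assumption made here for illustration; check the Flax source for the exact calling convention of map_state.

```python
from typing import Any, Callable


def map_nested(f: Callable[[tuple, Any], Any], tree: dict, prefix: tuple = ()) -> dict:
    """Hypothetical helper: new nested dict with f(path, leaf) applied to each leaf."""
    out = {}
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            out[key] = map_nested(f, value, path)  # recurse into subtrees
        else:
            out[key] = f(path, value)
    return out


params = {"linear": {"bias": 1.0, "kernel": 2.0}}
# Zero out every bias and double everything else; `params` itself is not modified.
updated = map_nested(lambda path, x: 0.0 if path[-1] == "bias" else 2 * x, params)
```

Passing the path lets f treat leaves differently by name, as the bias rule above does.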

flax.nnx.merge_state(state, /, *states, cls=<class 'flax.nnx.statelib.State'>)

The inverse of split_state().

merge_state takes one or more States and combines them into a new State.

Example usage:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.batchnorm = nnx.BatchNorm(2, rngs=rngs)
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...   def __call__(self, x):
...     return self.linear(self.batchnorm(x))

>>> model = Model(rngs=nnx.Rngs(0))
>>> params, batch_stats = nnx.state(model, nnx.Param, nnx.BatchStat)
>>> params['linear']['bias'][...] += 1

>>> state = nnx.merge_state(params, batch_stats)
>>> nnx.update(model, state)
>>> assert (model.linear.bias[...] == jnp.array([1, 1, 1])).all()

Parameters:
  • state – A State object.

  • *states – Additional State objects.

Returns:

The merged State.
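Because the substates produced by a split share no paths, merging them amounts to taking the union of their (path, value) pairs. The sketch below models this with flat dicts keyed by path tuples; merge_flat is a hypothetical name, and letting a later input win when two inputs share a path is a choice of this sketch, not documented behavior.

```python
from typing import Any

Flat = dict[tuple, Any]


def merge_flat(first: Flat, *others: Flat) -> Flat:
    """Hypothetical helper: union of flat mappings, later ones override earlier ones."""
    merged: Flat = dict(first)  # copy so the inputs are not mutated
    for other in others:
        merged.update(other)
    return merged


params = {("linear", "bias"): 1.0, ("batchnorm", "scale"): 1.0}
batch_stats = {("batchnorm", "mean"): 0.0, ("batchnorm", "var"): 1.0}
state = merge_flat(params, batch_stats)
```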

flax.nnx.replace_by_pure_dict(state, pure_dict, replace_fn=None)

Replace the values in state with the corresponding values from pure_dict.

Parameters:
  • state – A State object.

  • pure_dict – A pure dictionary containing the values to use for replacement.

  • replace_fn – Optional function that performs the replacement of each value.
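One use is writing raw values loaded from a checkpoint back into existing leaves while keeping their metadata. The plain-Python sketch below models this with a hypothetical Box class standing in for a Variable. The helper name replace_leaves, the default replace_fn, raising KeyError for paths missing from the state, and returning a new dict instead of modifying the state are all assumptions of the sketch.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable, Optional


@dataclass(frozen=True)
class Box:
    """Hypothetical stand-in for a Variable: a kind tag plus a raw value."""
    kind: str
    value: Any


def replace_leaves(state: dict, pure: dict,
                   replace_fn: Optional[Callable[[Box, Any], Box]] = None) -> dict:
    """Hypothetical helper: copy of `state` whose leaves take values from `pure`."""
    if replace_fn is None:
        # Default: keep the leaf's metadata and swap in the new raw value.
        replace_fn = lambda leaf, new: replace(leaf, value=new)
    out = dict(state)
    for key, new in pure.items():
        if key not in state:
            raise KeyError(f"key {key!r} from pure dict not found in state")
        if isinstance(new, dict):
            out[key] = replace_leaves(state[key], new, replace_fn)
        else:
            out[key] = replace_fn(state[key], new)
    return out


state = {"linear": {"bias": Box("param", 0.0), "kernel": Box("param", 1.0)}}
restored = replace_leaves(state, {"linear": {"bias": 0.5}})
```

Leaves absent from the pure dictionary (kernel above) keep their current values.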

flax.nnx.restore_int_paths(pure_dict)

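Restore integer path keys that were stored as strings in a pure dictionary. This is helpful when restoring state from a checkpoint saved as a pure dictionary, because the checkpointer may turn integer keys into strings, as in this Orbax example: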

Example:

>>> from flax import nnx
>>> import orbax.checkpoint as ocp
>>> import tempfile
...
>>> model = nnx.List([nnx.Linear(10, 10, rngs=nnx.Rngs(0)) for _ in range(2)])
>>> pure_dict_state = nnx.to_pure_dict(nnx.state(model))
>>> list(pure_dict_state.keys())
[0, 1]
>>> checkpointer = ocp.StandardCheckpointer()
>>> with tempfile.TemporaryDirectory() as tmpdir:
...   checkpointer.save(f'{tmpdir}/ckpt', pure_dict_state)
...   restored_pure_dict = checkpointer.restore(f'{tmpdir}/ckpt')
...   list(restored_pure_dict.keys())
['0', '1']
>>> restored_pure_dict = nnx.restore_int_paths(restored_pure_dict)
>>> list(restored_pure_dict.keys())
[0, 1]

Parameters:

pure_dict – The state as a pure dictionary.

Returns:

The state as a pure dictionary, with integer path keys restored.
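The conversion itself amounts to walking the nested dictionary and turning every key that is a string of digits back into an int. The sketch below models that in plain Python; restore_int_keys is a hypothetical name, and treating only ASCII digit strings as integers is this sketch's rule, which may differ from the exact rule Flax applies.

```python
from typing import Any


def restore_int_keys(tree: Any) -> Any:
    """Hypothetical helper: recursively turn keys such as '0' back into ints."""
    if not isinstance(tree, dict):
        return tree  # leaves are returned unchanged
    restored = {}
    for key, value in tree.items():
        if isinstance(key, str) and key.isascii() and key.isdigit():
            key = int(key)
        restored[key] = restore_int_keys(value)
    return restored


checkpointed = {"0": {"kernel": 1.0, "bias": 0.0}, "1": {"kernel": 2.0, "bias": 0.0}}
restored = restore_int_keys(checkpointed)
```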

flax.nnx.to_flat_state(state)

Convert a State into a FlatState.

Parameters:

state – A State object.

Returns:

A FlatState object.
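to_flat_state is the counterpart of from_flat_state: a nested state becomes a sequence of (path, value) pairs, each path being the tuple of keys that leads to a leaf. The sketch below shows the idea on plain nested dicts; flatten is a hypothetical helper, and in this sketch the pairs simply come out in the dictionaries' insertion order.

```python
from typing import Any


def flatten(tree: dict, prefix: tuple = ()) -> list[tuple[tuple, Any]]:
    """Hypothetical helper: turn a nested dict into a list of (path, leaf) pairs."""
    pairs: list[tuple[tuple, Any]] = []
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            pairs.extend(flatten(value, path))  # descend into subtrees
        else:
            pairs.append((path, value))
    return pairs


state = {"linear": {"bias": 0.0, "kernel": 1.0}, "layers": {0: {"bias": 0.5}}}
flat = flatten(state)
```

Integer keys, such as the layer index 0 above, stay integers inside the path tuples.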

flax.nnx.to_pure_dict(state, extract_fn=None)

Convert a State object into a pure dictionary.

Parameters:
  • state – A State object.

  • extract_fn – Optional function used to extract the value stored for each leaf.

Returns:

A pure dictionary.
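A pure dictionary keeps the nesting of the state but replaces each leaf wrapper with the raw value it holds. The sketch below models that in plain Python with a hypothetical Box class standing in for a Variable; the helper name to_pure and its default extractor (read .value from a Box, pass anything else through) are assumptions of the sketch.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class Box:
    """Hypothetical stand-in for a Variable holding a raw value."""
    value: Any


def to_pure(tree: dict, extract_fn: Optional[Callable[[Any], Any]] = None) -> dict:
    """Hypothetical helper: copy a nested dict, replacing each leaf by extract_fn(leaf)."""
    if extract_fn is None:
        # Default: unwrap Box leaves and pass anything else through unchanged.
        extract_fn = lambda leaf: leaf.value if isinstance(leaf, Box) else leaf
    return {
        key: to_pure(value, extract_fn) if isinstance(value, dict) else extract_fn(value)
        for key, value in tree.items()
    }


state = {"linear": {"bias": Box(0.0), "kernel": Box(1.0)}, "step": 3}
pure = to_pure(state)
```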

flax.nnx.split_state(state, first, /, *filters)

Split a State into one or more States. The user must pass at least one filter (e.g. a Variable type such as nnx.Param), and the filters must be exhaustive, i.e. every Variable in the State must match at least one of them.

Example usage:

>>> from flax import nnx

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.batchnorm = nnx.BatchNorm(2, rngs=rngs)
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...   def __call__(self, x):
...     return self.linear(self.batchnorm(x))

>>> model = Model(rngs=nnx.Rngs(0))
>>> state = nnx.state(model)
>>> param, batch_stats = nnx.split_state(state, nnx.Param, nnx.BatchStat)

Parameters:
  • first – The first filter.

  • *filters – Optional additional filters. Together with first, they group the state into mutually exclusive substates.

Returns:

One State per filter passed: a single State if only first is given, otherwise a tuple of States in the order of the filters.
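What separates this from filter_state is the exhaustiveness requirement: every leaf must match one of the filters. The plain-Python sketch below models this with a flat dict of path to hypothetical (kind, value) pairs and predicate filters; the names split_flat and of_kind, and the use of ValueError for unmatched leaves, are assumptions of the sketch rather than a description of what Flax raises.

```python
from typing import Any, Callable

Flat = dict[tuple, tuple[str, Any]]
Predicate = Callable[[tuple, tuple[str, Any]], bool]


def split_flat(flat: Flat, first: Predicate, *filters: Predicate):
    """Hypothetical helper: partition leaves by first matching predicate; reject leftovers."""
    predicates = (first, *filters)
    groups: list[Flat] = [{} for _ in predicates]
    leftover = []
    for path, leaf in flat.items():
        for group, predicate in zip(groups, predicates):
            if predicate(path, leaf):
                group[path] = leaf
                break
        else:  # no predicate matched this leaf
            leftover.append(path)
    if leftover:
        raise ValueError(f"filters are not exhaustive; unmatched paths: {leftover}")
    return groups[0] if len(groups) == 1 else tuple(groups)


def of_kind(kind: str) -> Predicate:
    """Hypothetical filter matching leaves tagged with `kind`."""
    return lambda path, leaf: leaf[0] == kind


state = {
    ("batchnorm", "mean"): ("batch_stat", 0.0),
    ("linear", "bias"): ("param", 0.0),
}
params, batch_stats = split_flat(state, of_kind("param"), of_kind("batch_stat"))
```

Calling split_flat(state, of_kind("param")) alone would raise, because the batch statistic matches no filter.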