state
- class flax.nnx.State(mapping, /, *, _copy=True)
A pytree-like Mapping with hashable and comparable keys.
- flax.nnx.filter_state(state, first, /, *filters)
Filter a State into one or more States. The user must pass at least one Filter (i.e. Variable). This method is similar to split(), except the filters can be non-exhaustive.
Example usage:
>>> from flax import nnx
>>> class Model(nnx.Module):
...     def __init__(self, rngs):
...         self.batchnorm = nnx.BatchNorm(2, rngs=rngs)
...         self.linear = nnx.Linear(2, 3, rngs=rngs)
...     def __call__(self, x):
...         return self.linear(self.batchnorm(x))
>>> model = Model(rngs=nnx.Rngs(0))
>>> state = nnx.state(model)
>>> param = nnx.filter_state(state, nnx.Param)
>>> batch_stats = nnx.filter_state(state, nnx.BatchStat)
>>> param, batch_stats = nnx.filter_state(state, nnx.Param, nnx.BatchStat)
- Parameters:
first – The first filter.
*filters – The optional, additional filters to group the state into mutually exclusive substates.
- Returns:
One or more States, equal in number to the filters passed.
- flax.nnx.from_flat_state(flat_state, *, cls=<class 'flax.nnx.statelib.State'>)
Convert a flat state object into a State object.
- flax.nnx.merge_state(state, /, *states, cls=<class 'flax.nnx.statelib.State'>)
The inverse of split(). merge_state takes one or more States and creates a new State.
Example usage:
>>> from flax import nnx
>>> import jax.numpy as jnp
>>> class Model(nnx.Module):
...     def __init__(self, rngs):
...         self.batchnorm = nnx.BatchNorm(2, rngs=rngs)
...         self.linear = nnx.Linear(2, 3, rngs=rngs)
...     def __call__(self, x):
...         return self.linear(self.batchnorm(x))
>>> model = Model(rngs=nnx.Rngs(0))
>>> params, batch_stats = nnx.state(model, nnx.Param, nnx.BatchStat)
>>> params['linear']['bias'][...] += 1
>>> state = nnx.merge_state(params, batch_stats)
>>> nnx.update(model, state)
>>> assert (model.linear.bias[...] == jnp.array([1, 1, 1])).all()
- flax.nnx.replace_by_pure_dict(state, pure_dict, replace_fn=None)
Replace input state values with pure_dict values.
- Parameters:
state – A State object.
pure_dict – Pure dictionary with values to be used for replacement.
replace_fn – Optional replace function.
- flax.nnx.restore_int_paths(pure_dict)
Restore integer paths from their string values in the dict. This method can be helpful when restoring the state from a checkpoint as a pure dictionary:
Example:
>>> from flax import nnx
>>> import orbax.checkpoint as ocp
>>> import tempfile
>>> model = nnx.List([nnx.Linear(10, 10, rngs=nnx.Rngs(0)) for _ in range(2)])
>>> pure_dict_state = nnx.to_pure_dict(nnx.state(model))
>>> list(pure_dict_state.keys())
[0, 1]
>>> checkpointer = ocp.StandardCheckpointer()
>>> with tempfile.TemporaryDirectory() as tmpdir:
...     checkpointer.save(f'{tmpdir}/ckpt', pure_dict_state)
...     restored_pure_dict = checkpointer.restore(f'{tmpdir}/ckpt')
...     list(restored_pure_dict.keys())
['0', '1']
>>> restored_pure_dict = nnx.restore_int_paths(restored_pure_dict)
>>> list(restored_pure_dict.keys())
[0, 1]
- Parameters:
pure_dict – State as a pure dictionary.
- Returns:
State as a pure dictionary with restored integer paths.
- flax.nnx.to_pure_dict(state, extract_fn=None)
Convert a State object into a pure dictionary.
- Parameters:
state – A State object.
extract_fn – Optional extraction function.
- Returns:
Pure dictionary.
- flax.nnx.split_state(state, first, /, *filters)
Split a State into one or more States. The user must pass at least one Filter (i.e. Variable), and the filters must be exhaustive (i.e. they must cover all Variable types in the State).
Example usage:
>>> from flax import nnx
>>> class Model(nnx.Module):
...     def __init__(self, rngs):
...         self.batchnorm = nnx.BatchNorm(2, rngs=rngs)
...         self.linear = nnx.Linear(2, 3, rngs=rngs)
...     def __call__(self, x):
...         return self.linear(self.batchnorm(x))
>>> model = Model(rngs=nnx.Rngs(0))
>>> state = nnx.state(model)
>>> param, batch_stats = nnx.split_state(state, nnx.Param, nnx.BatchStat)
- Parameters:
first – The first filter.
*filters – The optional, additional filters to group the state into mutually exclusive substates.
- Returns:
One or more States, equal in number to the filters passed.