Optimizer#
- class flax.nnx.optimizer.Optimizer(model, tx, *, wrt)#
Simple train state for the common case with a single Optax optimizer.
Example usage:
>>> import jax, jax.numpy as jnp
>>> from flax import nnx
>>> import optax
...
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     return self.linear2(self.linear1(x))
...
>>> x = jax.random.normal(jax.random.key(0), (1, 2))
>>> y = jnp.ones((1, 4))
...
>>> model = Model(nnx.Rngs(0))
>>> tx = optax.adam(1e-3)
>>> optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)
...
>>> loss_fn = lambda model: ((model(x) - y) ** 2).mean()
>>> loss_fn(model)
Array(2.3359997, dtype=float32)
>>> grads = nnx.grad(loss_fn)(model)
>>> optimizer.update(model, grads)
>>> loss_fn(model)
Array(2.310461, dtype=float32)
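A common follow-on pattern, not part of this class's API, is to wrap the gradient computation and update in nnx.jit. A minimal sketch, reusing model, optimizer, x, and y from the example above:

>>> @nnx.jit
... def train_step(model, optimizer, x, y):
...   def loss_fn(model):
...     return ((model(x) - y) ** 2).mean()
...   # Differentiate w.r.t. nnx.Param (the default), matching wrt above.
...   loss, grads = nnx.value_and_grad(loss_fn)(model)
...   optimizer.update(model, grads)
...   return loss
...
>>> loss = train_step(model, optimizer, x, y)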
- step#
An OptStateVariable that tracks the step count.
- tx#
An Optax gradient transformation.
- opt_state#
The Optax optimizer state.
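For illustration, these attributes can be read directly. A small sketch, reusing model and loss_fn from the example above and assuming standard Variable access via .value:

>>> optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
>>> int(optimizer.step.value)  # a fresh optimizer starts at step 0
0
>>> optimizer.update(model, nnx.grad(loss_fn)(model))
>>> int(optimizer.step.value)  # each update() increments the step
1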
- __init__(model, tx, *, wrt)#
Instantiate the class and wrap the Module and Optax gradient transformation. Instantiate the optimizer state to keep track of the Variable types specified in wrt. Set the step count to 0.
- Parameters:
model – An NNX Module.
tx – An Optax gradient transformation.
wrt – filter specifying which Variables to track in the optimizer state. These should be the Variables that you plan to update; i.e. this argument's value should match the wrt argument passed to the nnx.grad call that generates the gradients passed into the grads argument of the update() method (see the sketch below).
- update(model, grads, /, **kwargs)#
Updates the optimizer state and model parameters given the gradients.
Example:
>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> import optax
...
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear = nnx.Linear(2, 3, rngs=rngs)
...     self.count = nnx.Variable(jnp.array(0))
...
...   def __call__(self, x):
...     self.count[...] += 1
...     return self.linear(x)
...
>>> model = Model(rngs=nnx.Rngs(0))
...
>>> loss_fn = lambda model, x, y: ((model(x) - y) ** 2).mean()
>>> optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
>>> grads = nnx.grad(loss_fn, argnums=nnx.DiffState(0, nnx.Param))(
...   model, jnp.ones((1, 2)), jnp.ones((1, 3))
... )
>>> optimizer.update(model, grads)
Note that internally this function calls .tx.update() followed by a call to optax.apply_updates() to update params and opt_state (see the sketch after the parameter list below).
- Parameters:
grads – the gradients derived from nnx.grad.
**kwargs – additional keyword arguments passed to tx.update to support GradientTransformationExtraArgs, such as optax.scale_by_backtracking_linesearch (see the sketch below).
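Roughly, and glossing over how the optimizer state is stored internally, the update is equivalent to the following functional steps (a sketch, not the exact implementation), reusing model and grads from the example above:

>>> tx = optax.adam(1e-3)
>>> params = nnx.state(model, nnx.Param)       # extract the tracked Variables
>>> opt_state = tx.init(params)                # what __init__ sets up, roughly
>>> updates, opt_state = tx.update(grads, opt_state, params)
>>> params = optax.apply_updates(params, updates)
>>> nnx.update(model, params)                  # write the new values back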
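As a sketch of the **kwargs forwarding, optax.polyak_sgd is a GradientTransformationExtraArgs whose update rule consumes the current loss through a value keyword. Reusing Model and loss_fn from the example above:

>>> model = Model(rngs=nnx.Rngs(0))
>>> optimizer = nnx.Optimizer(model, optax.polyak_sgd(), wrt=nnx.Param)
>>> x, y = jnp.ones((1, 2)), jnp.ones((1, 3))
>>> loss, grads = nnx.value_and_grad(
...   lambda model: loss_fn(model, x, y)
... )(model)
>>> optimizer.update(model, grads, value=loss)  # `value` is forwarded to tx.update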