transforms#
- flax.nnx.grad(f=<flax.typing.Missing object>, *, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=(), graph=None, graph_updates=None)[source]#
Object-aware version of jax.grad that can handle Modules / graph nodes as arguments.

Example:

>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 3))
...
>>> loss_fn = lambda m, x, y: jnp.mean((m(x) - y) ** 2)
>>> grad_fn = nnx.grad(loss_fn)
...
>>> grads = grad_fn(m, x, y)
>>> jax.tree.map(jnp.shape, grads)
State({
  'bias': Param(
    value=(3,)
  ),
  'kernel': Param(
    value=(2, 3)
  )
})
By default, NNX objects are differentiated with respect to all their nnx.Param Variables. You can specify which substates are differentiable by passing a DiffState object to the argnums argument. For example, if you want to differentiate only the kernel attribute of the Linear class, you can use the PathContains filter:

>>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
...
>>> kernel_attribute = nnx.PathContains('kernel')
>>> diff_state = nnx.DiffState(0, kernel_attribute)
...
>>> loss_fn = lambda m, x, y: jnp.mean((m(x) - y) ** 2)
>>> grad_fn = nnx.grad(loss_fn, argnums=diff_state)
...
>>> grads = grad_fn(m, x, y)
>>> jax.tree.map(jnp.shape, grads)
State({
  'kernel': Param(
    value=(2, 3)
  )
})
For more information on how to create custom filters, see the Using Filters guide.
- Parameters:
  fun – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, graph nodes or standard Python containers. Argument arrays in the positions specified by argnums must be of inexact (i.e., floating-point or complex) type. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)
  argnums – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
  has_aux – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
  holomorphic – Optional, bool. Indicates whether fun is promised to be holomorphic. If True, inputs and outputs must be complex. Default False.
  allow_int – Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.
  graph – If True (default), uses graph-mode, which supports the full NNX feature set including shared references and reference semantics. If False, uses tree-mode, which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. Tree-mode does not support DiffState or shared Variable references.
  graph_updates – If True, propagates updates to the graph structure that happen inside the transform back to the input graphs; has no effect when graph=False. When False, using DiffState is not supported.
- flax.nnx.jit(fun=<flax.typing.Missing object>, *, in_shardings=None, out_shardings=None, static_argnums=None, static_argnames=None, donate_argnums=None, donate_argnames=None, keep_unused=False, device=None, backend=None, inline=False, graph=None, graph_updates=None)[source]#
Lifted version of jax.jit that can handle Modules / graph nodes as arguments.

Note

If the jitted function takes a model and an optimizer as inputs, accelerator memory usage can be reduced by specifying them in donate_argnums or donate_argnames:

>>> from flax import nnx
>>>
>>> @nnx.jit(donate_argnames=("model", "optimizer"))
... def func(model: nnx.Module, optimizer: nnx.Optimizer, other_args):
...   pass

For details please see this discussion.
- Parameters:
  fun – Function to be jitted. fun should be a pure function, as side-effects may only be executed once.
    The arguments and return value of fun should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof. Positional arguments indicated by static_argnums can be anything at all, provided they are hashable and have an equality operation defined. Static arguments are included as part of a compilation cache key, which is why hash and equality operators must be defined.
    JAX keeps a weak reference to fun for use as a compilation cache key, so the object fun must be weakly-referenceable. Most Callable objects will already satisfy this requirement.
    Note: Bound methods (e.g., module.method) are not supported. Use the decorator form @nnx.jit on the method definition, or call nnx.jit(MyClass.method)(instance, ...) with the unbound method.
  in_shardings – Pytree of structure matching that of arguments to fun, with all actual arguments replaced by resource assignment specifications. It is also valid to specify a pytree prefix (e.g. one value in place of a whole subtree), in which case the leaves get broadcast to all values in that subtree.
    The in_shardings argument is optional. JAX will infer the shardings from the input jax.Arrays, and defaults to replicating the input if the sharding cannot be inferred.
    The valid resource assignment specifications are:
    - Sharding, which will decide how the value will be partitioned. With this, using a mesh context manager is not required.
    - None, which gives JAX the freedom to choose whatever sharding it wants. For in_shardings, JAX will mark it as replicated, but this behavior can change in the future. For out_shardings, we will rely on the XLA GSPMD partitioner to determine the output shardings.
    The size of every dimension has to be a multiple of the total number of resources assigned to it. This is similar to pjit's in_shardings.
  out_shardings – Like in_shardings, but specifies resource assignment for function outputs. This is similar to pjit's out_shardings.
    The out_shardings argument is optional. If not specified, jax.jit() will use GSPMD's sharding propagation to figure out what the sharding of the output(s) should be.
  static_argnums – An optional int or collection of ints that specify which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded in Python (during tracing), so the corresponding argument values can be any Python object.
    Static arguments should be hashable, meaning both __hash__ and __eq__ are implemented, and immutable. Calling the jitted function with different values for these constants will trigger recompilation. Arguments that are not arrays or containers thereof must be marked as static.
    If neither static_argnums nor static_argnames is provided, no arguments are treated as static. If static_argnums is not provided but static_argnames is, or vice versa, JAX uses inspect.signature(fun) to find any positional arguments that correspond to static_argnames (or vice versa). If both static_argnums and static_argnames are provided, inspect.signature is not used, and only actual parameters listed in either static_argnums or static_argnames will be treated as static.
  static_argnames – An optional string or collection of strings specifying which named arguments to treat as static (compile-time constant). See the comment on static_argnums for details. If not provided but static_argnums is set, the default is based on calling inspect.signature(fun) to find corresponding named arguments.
  donate_argnums – Specify which positional argument buffers are "donated" to the computation. It is safe to donate argument buffers if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation; JAX will raise an error if you try to. By default, no argument buffers are donated.
    If neither donate_argnums nor donate_argnames is provided, no arguments are donated. If donate_argnums is not provided but donate_argnames is, or vice versa, JAX uses inspect.signature(fun) to find any positional arguments that correspond to donate_argnames (or vice versa). If both donate_argnums and donate_argnames are provided, inspect.signature is not used, and only actual parameters listed in either donate_argnums or donate_argnames will be donated.
    For more details on buffer donation see the FAQ.
  donate_argnames – An optional string or collection of strings specifying which named arguments are donated to the computation. See the comment on donate_argnums for details. If not provided but donate_argnums is set, the default is based on calling inspect.signature(fun) to find corresponding named arguments.
  keep_unused – If False (the default), arguments that JAX determines to be unused by fun may be dropped from resulting compiled XLA executables. Such arguments will not be transferred to the device nor provided to the underlying executable. If True, unused arguments will not be pruned.
  device – This is an experimental feature and the API is likely to change. Optional, the Device the jitted function will run on. (Available devices can be retrieved via jax.devices().) The default is inherited from XLA's DeviceAssignment logic and is usually to use jax.devices()[0].
  backend – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.
  inline – Specify whether this function should be inlined into enclosing jaxprs (rather than being represented as an application of the xla_call primitive with its own subjaxpr). Default False.
  graph – If True (default), uses graph-mode, which supports the full NNX feature set including shared references, reference semantics, and structural changes to Modules inside the jitted function. If False, uses tree-mode, which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. Tree-mode is faster but does not support shared Variable references or returning mutable array references from the jitted function.
- Returns:
  A wrapped version of fun, set up for just-in-time compilation.
- flax.nnx.shard_map(f=<class 'flax.typing.Missing'>, *, mesh, in_specs, out_specs, axis_names=frozenset({}), check_vma=True, graph=None, graph_updates=None)[source]#
Lifted version of jax.shard_map that can handle Modules / graph nodes as arguments.

Simple data parallel example:

import jax
import jax.numpy as jnp
from flax import nnx
from jax.sharding import PartitionSpec as P

mesh = jax.sharding.Mesh(jax.local_devices(), ('data',))

m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
x = jnp.ones((32, 2))

@nnx.shard_map(
  mesh=mesh, in_specs=(P(None), P('data')), out_specs=P('data')
)
def f(m, x):
  return m(x)

y = f(m, x)

jax.debug.visualize_array_sharding(y)
Notice that here we simply used a PartitionSpec to define the spec for the whole model and data. This works for simple cases, but if we need to assign different PartitionSpecs to different parts of the model we need to use StateSharding and create filters that allow us to target specific parts of the model. Here's an example of how to do tensor parallelism for a simple MLP block using StateSharding and filters:

mesh = jax.sharding.Mesh(jax.local_devices(), ('model',))

class MLP(nnx.Module):
  def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, dhidden, use_bias=False, rngs=rngs)
    self.linear2 = nnx.Linear(dhidden, dout, use_bias=False, rngs=rngs)

  def __call__(self, x):
    return self.linear2(jax.nn.relu(self.linear1(x)))

m = MLP(2, 64, 3, rngs=nnx.Rngs(0))
x = jnp.ones((32, 2))

def path_ends_with(*path_suffix):  # custom filter
  return lambda path, value: path[-len(path_suffix):] == path_suffix

model_spec = nnx.StateSharding({
  path_ends_with('linear1', 'kernel'): P(None, 'model'),
  path_ends_with('linear2', 'kernel'): P('model', None),
})

@nnx.shard_map(mesh=mesh, in_specs=(model_spec, P(None)), out_specs=P(None))
def f(m, x):
  y = m(x)
  return jax.lax.psum(y, 'model')

y = f(m, x)

jax.debug.visualize_array_sharding(m.linear1.kernel[...])
jax.debug.visualize_array_sharding(m.linear2.kernel[...])
Alternatively, a State object with the exact PartitionSpec for each state can be passed to StateSharding:

mesh = jax.sharding.Mesh(jax.local_devices(), ('model',))

class MLP(nnx.Module):
  def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, dhidden, use_bias=False, rngs=rngs)
    self.linear2 = nnx.Linear(dhidden, dout, use_bias=False, rngs=rngs)

  def __call__(self, x):
    return self.linear2(jax.nn.relu(self.linear1(x)))

m = MLP(2, 64, 3, rngs=nnx.Rngs(0))
x = jnp.ones((32, 2))

model_spec = nnx.State(
  {
    'linear1': {'kernel': P(None, 'model')},
    'linear2': {'kernel': P('model', None)},
  }
)

@nnx.shard_map(
  mesh=mesh,
  in_specs=(nnx.StateSharding(model_spec), P(None)),
  out_specs=P(None),
)
def f(m, x):
  y = m(x)
  return jax.lax.psum(y, 'model')

y = f(m, x)

jax.debug.visualize_array_sharding(m.linear1.kernel[...])
jax.debug.visualize_array_sharding(m.linear2.kernel[...])
Here model_spec was created manually, but you can also automate this process by using nnx.get_partition_spec to create it for you (see Scale up on multiple devices).

- Parameters:
  f – callable to be mapped. Each application of f, or "instance" of f, takes as input a shard of the mapped-over arguments and produces a shard of the output.
  mesh – a jax.sharding.Mesh representing the array of devices over which to shard the data and on which to execute instances of f. The names of the Mesh can be used in collective communication operations in f. This is typically created by a utility function like jax.experimental.mesh_utils.create_device_mesh().
  in_specs – a pytree with jax.sharding.PartitionSpec or nnx.StateSharding (mapping substates to PartitionSpecs) instances as leaves, with a tree structure that is a tree prefix of the args tuple to be mapped over. Similar to jax.sharding.NamedSharding, each PartitionSpec represents how the corresponding argument (or subtree of arguments) should be sharded along the named axes of mesh. In each PartitionSpec, mentioning a mesh axis name at a position expresses sharding the corresponding argument array axis along that positional axis; not mentioning an axis name expresses replication. If an argument, or argument subtree, has a corresponding spec of None, that argument is not sharded.
  out_specs – a pytree with jax.sharding.PartitionSpec or nnx.StateSharding (mapping substates to PartitionSpecs) instances as leaves, with a tree structure that is a tree prefix of the output of f. Each PartitionSpec represents how the corresponding output shards should be concatenated. In each PartitionSpec, mentioning a mesh axis name at a position expresses concatenation of that mesh axis's shards along the corresponding positional axis. Not mentioning a mesh axis name expresses a promise that the output values are equal along that mesh axis, and that rather than concatenating, only a single value should be produced.
  axis_names – optional set of axis names from mesh over which the function f is manual. If empty, f is manual over all mesh axes.
  check_vma – optional boolean representing whether to enable additional validity checks and automatic differentiation optimizations. The validity checks concern whether any mesh axis names not mentioned in out_specs are consistent with how the outputs of f are replicated.
  graph – If True (default), uses graph-mode, which supports the full NNX feature set including shared references and reference semantics. If False, uses tree-mode, which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. Tree-mode does not support StateSharding or shared Variable references.
  graph_updates – If True, propagates updates to the graph structure that happen inside the transform back to the input graphs; has no effect when graph=False. When False, using StateSharding is not supported.
- Returns:
  A callable that applies the input function f across data sharded according to the mesh and in_specs.
- flax.nnx.remat(f=<flax.typing.Missing object>, *, prevent_cse=True, static_argnums=(), policy=None, graph=None, graph_updates=None)[source]#
A "lifted" version of jax.checkpoint (a.k.a. jax.remat). flax.nnx.remat, similar to jax.checkpoint, can provide control over, for example, how flax.nnx.grad values are computed and saved during the forward pass versus how they are recomputed during the backward pass, trading off memory and FLOPs.

Learn more in Flax NNX vs JAX Transformations.

To learn more about jax.remat, refer to the JAX documentation.
- Parameters:
f – Function to be rematerialized.
prevent_cse – Optional, bool. If True, prevents common subexpression elimination. Default True.
static_argnums – Optional, int or tuple of ints. Specifies which positional arguments to treat as static.
policy – Optional, callable. A policy for which intermediates to save during the forward pass.
  graph – If True (default), uses graph-mode, which supports the full NNX feature set including shared references and reference semantics. If False, uses tree-mode, which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. Tree-mode does not support shared Variable references.
  graph_updates – If True, propagates updates to the graph structure that happen inside the transform back to the input graphs; has no effect when graph=False.
- flax.nnx.scan(f=<class 'flax.typing.Missing'>, *, length=None, reverse=False, unroll=1, _split_transpose=False, in_axes=(<class 'flax.nnx.transforms.iteration.Carry'>, 0), out_axes=(<class 'flax.nnx.transforms.iteration.Carry'>, 0), transform_metadata=FrozenDict({}), graph=None, graph_updates=None)[source]#
A Flax NNX transformation of jax.lax.scan.
Example:

import jax
import jax.numpy as jnp
from flax import nnx

class Block(nnx.Module):
  def __init__(self, input_dim, features, *, rngs):
    self.linear = nnx.Linear(input_dim, features, rngs=rngs)
    self.dropout = nnx.Dropout(0.1, rngs=rngs)

  def __call__(self, x: jax.Array):
    x = self.linear(x)
    x = self.dropout(x)
    x = jax.nn.relu(x)
    return x

class Model(nnx.Module):
  def __init__(self, num_layers, features, *, rngs):
    # In this model implementation we create multiple blocks using vmap.
    # As Block contains a dropout op, we prefer to split the RNG into
    # num_layers RNGs using the @nnx.split_rngs decorator.
    # Next, nnx.vmap creates a vectorized version of Block.
    # in_axes and out_axes define the vectorization axis of the input
    # split rngs and the output Block instance. Both axes should be 0.
    @nnx.split_rngs(splits=num_layers)
    @nnx.vmap(in_axes=(0,), out_axes=0)
    def create_block(rngs: nnx.Rngs):
      return Block(features, features, rngs=rngs)

    self.blocks = create_block(rngs)
    self.num_layers = num_layers

  def __call__(self, x):
    # Forward pass implementation. We use nnx.scan to apply the blocks
    # sequentially on the input, for example with num_layers=3:
    #   output = block[0](x)
    #   output = block[1](output)
    #   output = block[2](output)
    #
    # In the `forward` function defined below:
    # - x represents the loop carry value
    # - model is the data to scan along the leading axis
    # nnx.scan args:
    # - in_axes marks the inputs: x is marked as carry and the model
    #   is scanned along axis 0
    # - out_axes marks the output as carry
    @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
    def forward(x, model):
      x = model(x)
      return x

    return forward(x, self.blocks)

  # Alternatively, we can also decorate the `self.__call__` method:
  # @nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)
  # def __call__(self, x):
  #   return self.blocks(x)

model = Model(2, 4, rngs=nnx.Rngs(0))
_, params, _ = nnx.split(model, nnx.Param, ...)

print(params)  # kernel of shape: (2, 4, 4)

x = jnp.arange(5 * 4, dtype="float32").reshape((5, 4))
y = model(x)
print(y.shape)  # shape: (5, 4)
- Parameters:
  f – a Python function to be scanned.
  length – optional integer specifying the number of loop iterations.
  reverse – optional boolean specifying whether to run the scan iteration forward (the default) or in reverse.
  unroll – optional positive int or bool specifying, in the underlying operation of the scan primitive, how many scan iterations to unroll within a single iteration of a loop.
  in_axes – integer, None, flax.nnx.Carry, or sequence of values specifying the kind of input args. An integer value specifies the axis of the corresponding input data to scan along. flax.nnx.Carry marks the input data as a loop carry value. None marks the input data as auxiliary input.
  out_axes – integer, None, flax.nnx.Carry, or sequence of values specifying the kind of output args. See in_axes for details. Note that if in_axes contains flax.nnx.Carry then out_axes must also contain flax.nnx.Carry.
  graph_updates – If True, propagates updates to the graph structure that happen inside the transform back to the input graphs; has no effect when graph=False. When False, using StateAxes is not supported.
- class flax.nnx.Carry[source]#
Helper class for the flax.nnx.scan() function, used to mark an input or output axis as the carry.
- flax.nnx.value_and_grad(f=<class 'flax.typing.Missing'>, *, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=(), graph=None, graph_updates=None)[source]#
Object-aware version of jax.value_and_grad. Like grad(), but returns both the value and the gradient of f.

- Parameters:
  f – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, graph nodes or standard Python containers. Argument arrays in the positions specified by argnums must be of inexact (i.e., floating-point or complex) type. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)
  argnums – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
  has_aux – Optional, bool. Indicates whether f returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
  holomorphic – Optional, bool. Indicates whether f is promised to be holomorphic. If True, inputs and outputs must be complex. Default False.
  allow_int – Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.
  graph – If True (default), uses graph-mode, which supports the full NNX feature set including shared references and reference semantics. If False, uses tree-mode, which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. Tree-mode does not support DiffState or shared Variable references.
  graph_updates – If True, propagates updates to the graph structure that happen inside the transform back to the input graphs; has no effect when graph=False. When False, using DiffState is not supported.
- Returns:
  A function with the same arguments as f that evaluates both f and the gradient of f and returns them as a pair (a two-element tuple). If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a sequence of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments. If has_aux is True then a tuple of ((value, auxiliary_data), gradient) is returned.
- flax.nnx.vmap(f=<class 'flax.typing.Missing'>, *, in_axes=0, out_axes=0, axis_name=None, axis_size=None, spmd_axis_name=None, transform_metadata=FrozenDict({}), graph=None, graph_updates=None)[source]#
Reference-aware version of jax.vmap.
- Parameters:
  f – Function to be mapped over additional axes.
  in_axes – An integer, None, or sequence of values specifying which input array axes to map over (see jax.vmap). In addition to integers and None, StateAxes can be used to control how graph nodes like Modules are vectorized by specifying the axes to be applied to substates of the graph node given a Filter.
  out_axes – An integer, None, or pytree indicating where the mapped axis should appear in the output (see jax.vmap).
  axis_name – Optional, a hashable Python object used to identify the mapped axis so that parallel collectives can be applied.
  axis_size – Optional, an integer indicating the size of the axis to be mapped. If not provided, the mapped axis size is inferred from arguments.
  graph – If True (default), uses graph-mode, which supports the full NNX feature set including shared references and reference semantics. If False, uses tree-mode, which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. Tree-mode does not support StateAxes or shared Variable references.
  graph_updates – If True, propagates updates to the graph structure that happen inside the transform back to the input graphs; has no effect when graph=False. When False, using StateAxes is not supported.
- Returns:
  Batched/vectorized version of f with arguments that correspond to those of f, but with extra array axes at positions indicated by in_axes, and a return value that corresponds to that of f, but with extra array axes at positions indicated by out_axes.
Example:

>>> from flax import nnx
>>> from jax import random, numpy as jnp
...
>>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> x = jnp.ones((5, 2))
...
>>> @nnx.vmap(in_axes=(None, 0), out_axes=0)
... def forward(model, x):
...   return model(x)
...
>>> y = forward(model, x)
>>> y.shape
(5, 3)

>>> class LinearEnsemble(nnx.Module):
...   def __init__(self, num, rngs):
...     self.w = nnx.Param(jax.random.uniform(rngs(), (num, 2, 3)))
...
>>> model = LinearEnsemble(5, rngs=nnx.Rngs(0))
>>> x = jnp.ones((2,))
...
>>> @nnx.vmap(in_axes=(0, None), out_axes=0)
... def forward(model, x):
...   return x @ model.w
...
>>> y = forward(model, x)
>>> y.shape
(5, 3)
To control how graph node substates are vectorized, StateAxes can be passed to in_axes and out_axes, specifying the axes to be applied to each substate given a filter. The following example shows how to share the parameters between the ensemble members while keeping different batch statistics and dropout random state:

>>> class Foo(nnx.Module):
...   def __init__(self):
...     self.a = nnx.Param(jnp.arange(4))
...     self.b = nnx.BatchStat(jnp.arange(4))
...
>>> state_axes = nnx.StateAxes({nnx.Param: 0, nnx.BatchStat: None})
>>> @nnx.vmap(in_axes=(state_axes,), out_axes=0)
... def mul(foo):
...   return foo.a * foo.b
...
>>> foo = Foo()
>>> y = mul(foo)
>>> y
Array([[0, 0, 0, 0],
       [0, 1, 2, 3],
       [0, 2, 4, 6],
       [0, 3, 6, 9]], dtype=int32)
- flax.nnx.eval_shape(f, *args, graph=None, graph_updates=None, **kwargs)[source]#
A "lifted" version of jax.eval_shape that can handle flax.nnx.Module / graph nodes as arguments.

Similar to jax.eval_shape, it computes the shape/dtype of a function f without performing any floating point operations (FLOPs), which can be expensive. This can be useful for performing shape inference, for example. Unlike jax.eval_shape, nnx.eval_shape will automatically compute the expected sharding based on Flax sharding metadata for all Variables not using explicit sharding.
- Parameters:
  f – the function to evaluate.
  *args – positional arguments to f.
  graph – If True (default), uses graph-mode, which supports the full NNX feature set including shared references and reference semantics. If False, uses tree-mode, which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol.
  **kwargs – keyword arguments to f.
  graph_updates – If True, propagates updates to the graph structure that happen inside the transform back to the input graphs; has no effect when graph=False.
- flax.nnx.custom_vjp(fun=<flax.typing.Missing object>, *, nondiff_argnums=(), graph=None, graph_updates=None)[source]#
Reference-aware version of jax.custom_vjp.

nnx.custom_vjp accepts Modules and other Flax NNX objects as arguments. The main difference from the JAX version is that, because Modules follow reference semantics, they propagate the State updates for the inputs as auxiliary outputs. This means that the incoming gradients in the bwd function will have the form (input_updates_g, out_g), where input_updates_g is the gradient of the updated state of the inputs w.r.t. the inputs. All Module terms in the inputs will have an associated State term in input_updates_g, while all non-Module terms will appear as None. The tangent will be expected to have the same shape as the input, with State terms in place of the corresponding Module terms.

Example:

>>> import jax
>>> import jax.numpy as jnp
>>> from flax import nnx
...
>>> class Foo(nnx.Module):
...   def __init__(self, x, y):
...     self.x = nnx.Param(x)
...     self.y = nnx.Param(y)
...
>>> @nnx.custom_vjp
... def f(m: Foo):
...   return jnp.sin(m.x) * m.y
...
>>> def f_fwd(m: Foo):
...   return f(m), (jnp.cos(m.x), jnp.sin(m.x), m)
...
>>> def f_bwd(res, g):
...   input_updates_g, out_g = g
...   cos_x, sin_x, m = res
...   (m_updates_g,) = input_updates_g
...   m_g = jax.tree.map(lambda x: x, m_updates_g)  # create copy
...
...   m_g['x'][...] = cos_x * out_g * m.y
...   m_g['y'][...] = sin_x * out_g
...   return (m_g,)
...
>>> f.defvjp(f_fwd, f_bwd)
...
>>> m = Foo(x=jnp.array(1.), y=jnp.array(2.))
>>> grads = nnx.grad(f)(m)
...
>>> jax.tree.map(jnp.shape, grads)
State({
  'x': Param(
    value=()
  ),
  'y': Param(
    value=()
  )
})
Note that the State objects that represent Module terms in input_updates_g have the same shape as the State objects expected in the output tangent. This means that you can usually just copy them from input_updates_g and update them with their corresponding gradient values.

You can select which substates are differentiable (have a tangent) for Modules and other graph nodes by passing a DiffState to nondiff_argnums. For example, if you want to differentiate only the x attribute of the Foo class, you can do the following:

>>> x_attribute = nnx.PathContains('x')
>>> diff_state = nnx.DiffState(0, x_attribute)
...
>>> @nnx.custom_vjp(nondiff_argnums=(diff_state,))
... def f(m: Foo):
...   return jnp.sin(m.x) * m.y  # type: ignore
...
>>> def f_fwd(m: Foo):
...   y = f(m)
...   res = (jnp.cos(m.x), m)  # type: ignore
...   return y, res
...
>>> def f_bwd(res, g):
...   input_updates_g, out_g = g
...   cos_x, m = res
...   (m_updates_g,) = input_updates_g
...   m_g = jax.tree.map(lambda x: x, m_updates_g)  # create copy
...
...   m_g.x[...] = cos_x * out_g * m.y
...   del m_g['y']  # y is not differentiable
...   return (m_g,)
...
>>> f.defvjp(f_fwd, f_bwd)
...
>>> m = Foo(x=jnp.array(1.), y=jnp.array(2.))
>>> grad = nnx.grad(f, argnums=nnx.DiffState(0, x_attribute))(m)
...
>>> jax.tree.map(jnp.shape, grad)
State({
  'x': Param(
    value=()
  )
})
Note that grad cannot calculate gradients for states that don't have a tangent defined by custom_vjp; in the example above we reuse the same x_attribute filter to keep custom_vjp and grad in sync.

graph_updates=False
When graph_updates=False or graph=False, the behavior is closer to jax.custom_vjp: the bwd function receives out_g directly, and tangent types are the same as the input types. This means the tangent for a Module is a Module instance with gradient values set on its attributes. This mode does not support DiffState in nondiff_argnums. Additionally, Variables in differentiable arguments cannot be mutated inside f. If mutations are needed, pass the relevant Variables through a non-differentiable argument instead.

Example:

>>> @nnx.custom_vjp(graph_updates=False)
... def f(m: Foo):
...   return jnp.sin(m.x) * m.y
...
>>> def f_fwd(m: Foo):
...   return f(m), (jnp.cos(m.x), jnp.sin(m.x), m)
...
>>> def f_bwd(res, g):
...   cos_x, sin_x, m = res
...   m_g = nnx.clone(m)
...   m_g.x[...] = cos_x * g * m.y
...   m_g.y[...] = sin_x * g
...   return (m_g,)
...
>>> f.defvjp(f_fwd, f_bwd)
- Parameters:
  fun – Callable base function.
  nondiff_argnums – Tuple of integers or DiffState objects specifying the argument indices that are not differentiated. By default all arguments are differentiated. Integers cannot be used to mark graph nodes such as Modules as non-differentiable; in that case use a DiffState object. DiffState objects define the set of differentiable substates, contrary to what the name of this argument suggests; this is done for compatibility with grad.
  graph – If True (default), uses graph-mode, which supports the full NNX feature set including shared references and reference semantics. If False, uses tree-mode, which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol. Tree-mode does not support DiffState in nondiff_argnums.
  graph_updates – If True, propagates updates to the graph structure that happen inside the transform back to the input graphs; has no effect when graph=False. When False, using DiffState is not supported.
- flax.nnx.vjp(f=<flax.typing.Missing object>, *primals, has_aux=False, reduce_axes=(), graph=None, graph_updates=None)[source]#
Stateful version of jax.vjp that propagates NNX Variable updates.
Example:
>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> x = jnp.ones((1, 2))
...
>>> def loss_fn(m, x):
...   return jnp.sum(m(x))
...
>>> primals_out, vjp_fn = nnx.vjp(loss_fn, m, x, graph=False)
>>> m_grad, x_grad = vjp_fn(jnp.ones_like(primals_out))
Can also be used as a decorator:
>>> @nnx.vjp(graph=False)
... def f(m, x):
...   return jnp.sum(m(x))
...
>>> primals_out, vjp_fn = f(m, x)
- Parameters:
f – Function to be differentiated. Its arguments can be arrays, scalars, or pytrees containing arrays and NNX Variables.
*primals – A sequence of primal values at which the Jacobian of f should be evaluated.
has_aux – Optional, bool. Indicates whether f returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
reduce_axes – Deprecated, do not use.
graph – If True (default), uses graph-mode, which supports the full NNX feature set including shared references and reference semantics. If False, uses tree-mode, which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol.
graph_updates – If True, propagates updates to the graph structure that happen inside the transform back to the input graphs; has no effect when graph=False.
- Returns:
If has_aux is False, returns a (primals_out, vjp_fn) pair. vjp_fn takes a cotangent with the same structure as primals_out and returns gradients for each primal argument. If has_aux is True, returns (primals_out, vjp_fn, aux).
- flax.nnx.jvp(f=<flax.typing.Missing object>, primals=<flax.typing.Missing object>, tangents=<flax.typing.Missing object>, *, has_aux=False, graph=None, graph_updates=None)[source]#
Stateful version of jax.jvp that propagates NNX Variable updates.
Example:
>>> import jax
>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
>>> x = jnp.ones((1, 2))
...
>>> def f(m, x):
...   return jnp.sum(m(x))
...
>>> m_tangent = jax.tree.map(jnp.zeros_like, m)
>>> x_tangent = jnp.ones_like(x)
>>> primals_out, tangent_out = nnx.jvp(
...   f, (m, x), (m_tangent, x_tangent), graph=False
... )
Can also be used as a decorator:
>>> @nnx.jvp(graph=False)
... def f(m, x):
...   return jnp.sum(m(x))
...
>>> primals_out, tangent_out = f((m, x), (m_tangent, x_tangent))
- Parameters:
f – Function to be differentiated. Its arguments can be arrays, scalars, or pytrees containing arrays and NNX Variables.
primals – A tuple of primal values at which the Jacobian of f should be evaluated.
tangents – A tuple of tangent vectors, with the same structure as primals.
has_aux – Optional, bool. Indicates whether f returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
graph – If True (default), uses graph-mode, which supports the full NNX feature set including shared references and reference semantics. If False, uses tree-mode, which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol.
graph_updates – If True, propagates updates to the graph structure that happen inside the transform back to the input graphs; has no effect when graph=False.
- Returns:
If has_aux is False, returns (primals_out, tangent_out). If has_aux is True, returns (primals_out, tangent_out, aux).
- flax.nnx.cond(pred, true_fun, false_fun, *operands, graph=None, graph_updates=None)[source]#
Conditionally apply true_fun or false_fun.
Wraps jax.lax.cond to support Flax NNX modules and variables.
- Parameters:
pred – Boolean scalar. If True, true_fun is applied, otherwise false_fun.
true_fun – Function to apply if pred is True.
false_fun – Function to apply if pred is False.
*operands – Operands passed to whichever branch is selected.
graph – If True (default), uses graph-mode, which supports the full NNX feature set including shared references and reference semantics. If False, uses tree-mode, which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol.
graph_updates – If True, propagates updates to the graph structure that happen inside the transform back to the input graphs; has no effect when graph=False.
- flax.nnx.switch(index, branches, *operands, graph=None, graph_updates=None)[source]#
Select and apply one of branches based on index.
Wraps jax.lax.switch to support Flax NNX modules and variables.
- Parameters:
index – integer scalar indicating which branch to apply.
branches – sequence of functions to select from.
*operands – operands passed to the selected branch.
graph – If True (default), uses graph-mode, which supports the full NNX feature set including shared references and reference semantics. If False, uses tree-mode, which treats Modules as regular JAX pytrees, avoiding the overhead of the graph protocol.
graph_updates – If True, propagates updates to the graph structure that happen inside the transform back to the input graphs; has no effect when graph=False.
- flax.nnx.while_loop(cond_fun, body_fun, init_val, *, graph=None, graph_updates=None)[source]#
A Flax NNX transformation of jax.lax.while_loop.
Caution: for the NNX internal reference tracing mechanism to work, you cannot change the variable reference structure of init_val inside body_fun.
Example:
>>> import jax
>>> from flax import nnx
>>> def fwd_fn(input):
...   module, x, count = input
...   return module, module(x), count - 1.0
>>> module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
>>> x = jax.random.normal(jax.random.key(0), (10,))
>>> # `module` will be called three times
>>> _, y, _ = nnx.while_loop(
...   lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0))
- Parameters:
cond_fun – A function for the continue condition of the while loop, taking a single input of type T and outputting a boolean.
body_fun – A function that takes an input of type T and outputs a T. Note that both data and modules of T must have the same reference structure between inputs and outputs.
init_val – The initial input for cond_fun and body_fun. Must be of type T.
graph – If True, use graph-mode (default). If False, use tree-mode. If None, uses the value of the nnx_graph_mode config.
graph_updates – If True, propagates updates to the graph structure that happen inside the transform back to the input graphs; has no effect when graph=False.
- flax.nnx.fori_loop(lower, upper, body_fun, init_val, *, unroll=None, graph=None, graph_updates=None)[source]#
A Flax NNX transformation of jax.lax.fori_loop.
Caution: for the NNX internal reference tracing mechanism to work, you cannot change the variable reference structure of init_val inside body_fun.
Example:
>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np
>>> from flax import nnx
>>> def fwd_fn(i, input):
...   m, x = input
...   m.kernel[...] = jnp.identity(10) * i
...   return m, m(x)
>>> module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
>>> x = jax.random.normal(jax.random.key(0), (10,))
>>> _, y = nnx.fori_loop(2, 4, fwd_fn, (module, x))
>>> np.testing.assert_array_equal(y, x * 2 * 3)
- Parameters:
lower – An integer representing the loop index lower bound (inclusive).
upper – An integer representing the loop index upper bound (exclusive).
body_fun – A function that takes an input of type T and outputs a T. Note that both data and modules of T must have the same reference structure between inputs and outputs.
init_val – The initial input for body_fun. Must be of type T.
unroll – An optional integer or boolean that determines how much to unroll the loop. If an integer is provided, it determines how many unrolled loop iterations to run within a single rolled iteration of the loop. If a boolean is provided, it determines whether the loop is completely unrolled (unroll=True) or left completely rolled (unroll=False). This argument is only applicable if the loop bounds are statically known.
graph – If True, use graph-mode (default). If False, use tree-mode. If None, uses the value of the nnx_graph_mode config.
graph_updates – If True, propagates updates to the graph structure that happen inside the transform back to the input graphs; has no effect when graph=False.
- Returns:
The loop value from the final iteration, of type T.