Normalization#
- class flax.nnx.BatchNorm(self, num_features, *, use_running_average=False, axis=-1, momentum=0.99, epsilon=1e-05, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, axis_name=None, axis_index_groups=None, use_fast_variance=True, promote_dtype=<function promote_dtype>, rngs, bias_metadata=mappingproxy({}), scale_metadata=mappingproxy({}))[source]#
BatchNorm Module.
To calculate the batch norm on the input and update the batch statistics, call the train() method (or pass in use_running_average=False in the constructor or during call time).
To use the stored batch statistics' running average, call the eval() method (or pass in use_running_average=True in the constructor or during call time).
Example usage:
>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> x = jax.random.normal(jax.random.key(0), (5, 6))
>>> layer = nnx.BatchNorm(num_features=6, momentum=0.9, epsilon=1e-5,
...                       dtype=jnp.float32, rngs=nnx.Rngs(0))
>>> jax.tree.map(jnp.shape, nnx.state(layer))
State({
  'bias': Param(
    value=(6,)
  ),
  'mean': BatchStat(
    value=(6,)
  ),
  'scale': Param(
    value=(6,)
  ),
  'var': BatchStat(
    value=(6,)
  )
})
>>> # calculate batch norm on input and update batch statistics
>>> layer.train()
>>> y = layer(x)
>>> batch_stats1 = nnx.clone(nnx.state(layer, nnx.BatchStat))  # keep a copy
>>> y = layer(x)
>>> batch_stats2 = nnx.state(layer, nnx.BatchStat)
>>> assert (batch_stats1['mean'][...] != batch_stats2['mean'][...]).all()
>>> assert (batch_stats1['var'][...] != batch_stats2['var'][...]).all()
>>> # use stored batch statistics' running average
>>> layer.eval()
>>> y = layer(x)
>>> batch_stats3 = nnx.state(layer, nnx.BatchStat)
>>> assert (batch_stats2['mean'][...] == batch_stats3['mean'][...]).all()
>>> assert (batch_stats2['var'][...] == batch_stats3['var'][...]).all()
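The train/eval behavior above can be sketched in plain NumPy (an illustration of the underlying math, not the Flax implementation; the helper name is hypothetical):

```python
import numpy as np

def batch_norm_sketch(x, mean, var, momentum=0.9, epsilon=1e-5, training=True):
    """Plain-NumPy sketch of the batch-norm math.

    In training mode, normalize with the current batch statistics and update
    the running (mean, var) via an exponential moving average. In eval mode,
    normalize with the stored running statistics and leave them unchanged.
    """
    if training:
        batch_mean = x.mean(axis=0)
        batch_var = x.var(axis=0)
        # exponential moving average of the batch statistics
        mean = momentum * mean + (1 - momentum) * batch_mean
        var = momentum * var + (1 - momentum) * batch_var
        y = (x - batch_mean) / np.sqrt(batch_var + epsilon)
    else:
        y = (x - mean) / np.sqrt(var + epsilon)
    return y, mean, var

x = np.random.default_rng(0).normal(size=(5, 6))
# training-mode call updates the running statistics
y, mean, var = batch_norm_sketch(x, np.zeros(6), np.ones(6), training=True)
# eval-mode call reuses the stored statistics unchanged
y_eval, mean2, var2 = batch_norm_sketch(x, mean, var, training=False)
assert np.allclose(mean, mean2) and np.allclose(var, var2)
```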
- Parameters:
num_features – the number of input features.
use_running_average – if True, the stored batch statistics will be used instead of computing the batch statistics on the input.
axis – the feature or non-batch axis of the input.
momentum – decay rate for the exponential moving average of the batch statistics.
epsilon – a small float added to variance to avoid dividing by zero.
dtype – the dtype of the result (default: infer from input and params).
param_dtype – the dtype passed to parameter initializers (default: float32).
use_bias – if True, bias (beta) is added.
use_scale – if True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.
bias_init – initializer for bias, by default, zero.
scale_init – initializer for scale, by default, one.
axis_name – the axis name used to combine batch statistics from multiple devices. See jax.pmap for a description of axis names (default: None).
axis_index_groups – groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently batch-normalize over the examples on the first two and last two devices. See jax.lax.psum for more details. This argument is currently not supported for SPMD jit.
use_fast_variance – If true, use a faster, but less numerically stable, calculation for the variance.
promote_dtype – function to promote the dtype of all input array arguments (including Variables accessed through self) to the desired dtype. The function should accept a tuple of (inputs, mean, var, scale, bias) and a dtype keyword argument, and return a tuple of arrays with the promoted dtype.
rngs – rng key.
bias_metadata – Optional metadata dictionary to set when initializing the bias.
scale_metadata – Optional metadata dictionary to set when initializing the scale.
- __call__(x, use_running_average=None, *, mask=None)[source]#
Normalizes the input using batch statistics.
- Parameters:
x – the input to be normalized.
use_running_average – if True, the stored batch statistics will be used instead of computing the batch statistics on the input. The use_running_average flag passed into the call method will take precedence over the use_running_average flag passed into the constructor.
- Returns:
Normalized inputs (the same shape as inputs).
- set_view(use_running_average=None)[source]#
Class method used by nnx.view.
- Parameters:
use_running_average – if True, the stored batch statistics will be used instead of computing the batch statistics on the input.
Methods
set_view([use_running_average]) – Class method used by nnx.view.
- class flax.nnx.LayerNorm(self, num_features, *, epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, reduction_axes=-1, feature_axes=-1, axis_name=None, axis_index_groups=None, use_fast_variance=True, promote_dtype=<function promote_dtype>, rngs, bias_metadata=mappingproxy({}), scale_metadata=mappingproxy({}))[source]#
Layer normalization (https://arxiv.org/abs/1607.06450).
LayerNorm normalizes the activations of the layer for each given example in a batch independently, rather than across a batch like Batch Normalization. That is, it applies a transformation that maintains the mean activation within each example close to 0 and the activation standard deviation close to 1.
Example usage:
>>> from flax import nnx
>>> import jax
>>> x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6))
>>> layer = nnx.LayerNorm(num_features=6, rngs=nnx.Rngs(0))
>>> nnx.state(layer)
State({
  'bias': Param( # 6 (24 B)
    value=Array([0., 0., 0., 0., 0., 0.], dtype=float32)
  ),
  'scale': Param( # 6 (24 B)
    value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
  )
})
>>> y = layer(x)
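The per-example normalization can be sketched in plain NumPy (an illustration of the math only, not the Flax implementation; the helper name is hypothetical):

```python
import numpy as np

def layer_norm_sketch(x, scale, bias, epsilon=1e-6):
    """Plain-NumPy sketch of the layer-norm math.

    Statistics are computed per example over the last (feature) axis, so each
    example is normalized independently of the rest of the batch.
    """
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon) * scale + bias

x = np.random.default_rng(0).normal(size=(3, 4, 5, 6))
y = layer_norm_sketch(x, np.ones(6), np.zeros(6))
# with scale=1 and bias=0, each feature vector has ~zero mean and ~unit std
assert np.allclose(y.mean(axis=-1), 0.0, atol=1e-6)
```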
- Parameters:
num_features – the number of input features.
epsilon – A small float added to variance to avoid dividing by zero.
dtype – the dtype of the result (default: infer from input and params).
param_dtype – the dtype passed to parameter initializers (default: float32).
use_bias – If True, bias (beta) is added.
use_scale – If True, multiply by scale (gamma). When the next layer is linear (also e.g. nnx.relu), this can be disabled since the scaling will be done by the next layer.
bias_init – Initializer for bias, by default, zero.
scale_init – Initializer for scale, by default, one.
reduction_axes – Axes for computing normalization statistics.
feature_axes – Feature axes for learned bias and scaling.
axis_name – the axis name used to combine batch statistics from multiple devices. See jax.pmap for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the array being normalized is sharded across devices within a pmap.
axis_index_groups – groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently batch-normalize over the examples on the first two and last two devices. See jax.lax.psum for more details. This argument is currently not supported for SPMD jit.
use_fast_variance – If true, use a faster, but less numerically stable, calculation for the variance.
promote_dtype – function to promote the dtype of all input array arguments (including Variables accessed through self) to the desired dtype. The function should accept a tuple of (inputs, scale, bias) and a dtype keyword argument, and return a tuple of arrays with the promoted dtype.
rngs – rng key.
bias_metadata – Optional metadata dictionary to set when initializing the bias.
scale_metadata – Optional metadata dictionary to set when initializing the scale.
- __call__(x, *, mask=None)[source]#
Applies layer normalization on the input.
- Parameters:
x – the inputs
- Returns:
Normalized inputs (the same shape as inputs).
Methods
- class flax.nnx.RMSNorm(self, num_features, *, epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_scale=True, scale_init=<function ones>, reduction_axes=-1, feature_axes=-1, axis_name=None, axis_index_groups=None, use_fast_variance=True, promote_dtype=<function promote_dtype>, rngs, scale_metadata=mappingproxy({}))[source]#
RMS Layer normalization (https://arxiv.org/abs/1910.07467).
RMSNorm normalizes the activations of the layer for each given example in a batch independently, rather than across a batch like Batch Normalization. Unlike LayerNorm which re-centers the mean to be 0 and normalizes by the standard deviation of the activations, RMSNorm does not re-center at all and instead normalizes by the root mean square of the activations.
Example usage:
>>> from flax import nnx
>>> import jax
>>> x = jax.random.normal(jax.random.key(0), (5, 6))
>>> layer = nnx.RMSNorm(num_features=6, rngs=nnx.Rngs(0))
>>> nnx.state(layer)
State({
  'scale': Param( # 6 (24 B)
    value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
  )
})
>>> y = layer(x)
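The RMS normalization described above can be sketched in plain NumPy (an illustration of the math only, not the Flax implementation; the helper name is hypothetical):

```python
import numpy as np

def rms_norm_sketch(x, scale, epsilon=1e-6):
    """Plain-NumPy sketch of the RMSNorm math.

    Unlike LayerNorm there is no mean subtraction and no bias: each example
    is divided by the root mean square of its activations along the last axis.
    """
    ms = np.mean(np.square(x), axis=-1, keepdims=True)  # mean square
    return x / np.sqrt(ms + epsilon) * scale

x = np.random.default_rng(0).normal(size=(5, 6))
y = rms_norm_sketch(x, np.ones(6))
# with scale=1, the root mean square of each output row is ~1
assert np.allclose(np.sqrt(np.mean(y**2, axis=-1)), 1.0, atol=1e-3)
```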
- Parameters:
num_features – the number of input features.
epsilon – A small float added to variance to avoid dividing by zero.
dtype – the dtype of the result (default: infer from input and params).
param_dtype – the dtype passed to parameter initializers (default: float32).
use_scale – If True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.
scale_init – Initializer for scale, by default, one.
reduction_axes – Axes for computing normalization statistics.
feature_axes – Feature axes for the learned scaling.
axis_name – the axis name used to combine batch statistics from multiple devices. See jax.pmap for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the array being normalized is sharded across devices within a pmap.
axis_index_groups – groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently batch-normalize over the examples on the first two and last two devices. See jax.lax.psum for more details. This argument is currently not supported for SPMD jit.
use_fast_variance – If true, use a faster, but less numerically stable, calculation for the variance.
promote_dtype – function to promote the dtype of all input array arguments (including Variables accessed through self) to the desired dtype. The function should accept a tuple of (inputs, scale) and a dtype keyword argument, and return a tuple of arrays with the promoted dtype.
rngs – rng key.
scale_metadata – Optional metadata dictionary to set when initializing the scale.
- __call__(x, mask=None)[source]#
Applies RMS layer normalization on the input.
- Parameters:
x – the inputs
- Returns:
Normalized inputs (the same shape as inputs).
Methods
- class flax.nnx.GroupNorm(self, num_features, num_groups=32, group_size=None, *, epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, reduction_axes=None, axis_name=None, axis_index_groups=None, use_fast_variance=True, promote_dtype=<function promote_dtype>, rngs, bias_metadata=mappingproxy({}), scale_metadata=mappingproxy({}))[source]#
Group normalization (arxiv.org/abs/1803.08494).
This op is similar to batch normalization, but statistics are shared across equally-sized groups of channels and not shared across batch dimension. Thus, group normalization does not depend on the batch composition and does not require maintaining internal state for storing statistics. The user should either specify the total number of channel groups or the number of channels per group.
Note
LayerNorm is a special case of GroupNorm where num_groups=1.
Example usage:
>>> from flax import nnx
>>> import jax
>>> import numpy as np
>>> x = jax.random.normal(jax.random.key(0), (3, 4, 5, 6))
>>> layer = nnx.GroupNorm(num_features=6, num_groups=3, rngs=nnx.Rngs(0))
>>> nnx.state(layer)
State({
  'bias': Param( # 6 (24 B)
    value=Array([0., 0., 0., 0., 0., 0.], dtype=float32)
  ),
  'scale': Param( # 6 (24 B)
    value=Array([1., 1., 1., 1., 1., 1.], dtype=float32)
  )
})
>>> y = layer(x)
>>> y = nnx.GroupNorm(num_features=6, num_groups=1, rngs=nnx.Rngs(0))(x)
>>> y2 = nnx.LayerNorm(num_features=6, reduction_axes=(1, 2, 3), rngs=nnx.Rngs(0))(x)
>>> np.testing.assert_allclose(y, y2)
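The channel-grouping idea can be sketched in plain NumPy (an illustration of the math only, not the Flax implementation; the helper name and the exact grouping of consecutive channels are assumptions):

```python
import numpy as np

def group_norm_sketch(x, num_groups, scale, bias, epsilon=1e-6):
    """Plain-NumPy sketch of the group-norm math.

    Channels (the last axis) are split into equally-sized groups of
    consecutive channels; statistics are shared within each group and across
    all non-batch axes, but never across the batch dimension.
    """
    batch = x.shape[0]
    channels = x.shape[-1]
    g = x.reshape(batch, -1, num_groups, channels // num_groups)
    mean = g.mean(axis=(1, 3), keepdims=True)
    var = g.var(axis=(1, 3), keepdims=True)
    g = (g - mean) / np.sqrt(var + epsilon)
    return g.reshape(x.shape) * scale + bias

x = np.random.default_rng(0).normal(size=(3, 4, 5, 6))
y = group_norm_sketch(x, num_groups=3, scale=np.ones(6), bias=np.zeros(6))
assert y.shape == x.shape
```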
- Parameters:
num_features – the number of input features/channels.
num_groups – the total number of channel groups. The default value of 32 is proposed by the original group normalization paper.
group_size – the number of channels in a group.
epsilon – A small float added to variance to avoid dividing by zero.
dtype – the dtype of the result (default: infer from input and params).
param_dtype – the dtype passed to parameter initializers (default: float32).
use_bias – If True, bias (beta) is added.
use_scale – If True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.
bias_init – Initializer for bias, by default, zero.
scale_init – Initializer for scale, by default, one.
reduction_axes – List of axes used for computing normalization statistics. This list must include the final dimension, which is assumed to be the feature axis. Furthermore, if the input used at call time has additional leading axes compared to the data used for initialisation, for example due to batching, then the reduction axes need to be defined explicitly.
axis_name – the axis name used to combine batch statistics from multiple devices. See jax.pmap for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the array being normalized is sharded across devices within a pmap or shard map. For SPMD jit, you do not need to manually synchronize: just make sure that the axes are correctly annotated and XLA:SPMD will insert the necessary collectives.
axis_index_groups – groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently batch-normalize over the examples on the first two and last two devices. See jax.lax.psum for more details. This argument is currently not supported for SPMD jit.
use_fast_variance – If true, use a faster, but less numerically stable, calculation for the variance.
promote_dtype – function to promote the dtype of all input array arguments (including Variables accessed through self) to the desired dtype. The function should accept a tuple of (inputs, scale, bias) and a dtype keyword argument, and return a tuple of arrays with the promoted dtype.
rngs – rng key.
bias_metadata – Optional metadata dictionary to set when initializing the bias.
scale_metadata – Optional metadata dictionary to set when initializing the scale.
- __call__(x, *, mask=None)[source]#
Applies group normalization to the input (arxiv.org/abs/1803.08494).
- Parameters:
x – the input of shape ... self.num_features, where self.num_features is a channels dimension and ... represents an arbitrary number of extra dimensions that can be used to accumulate statistics over. If no reduction axes have been specified then all additional dimensions ... will be used to accumulate statistics apart from the leading dimension, which is assumed to represent the batch.
mask – Binary array of shape broadcastable to the inputs tensor, indicating the positions for which the mean and variance should be computed.
- Returns:
Normalized inputs (the same shape as inputs).
Methods
- class flax.nnx.InstanceNorm(self, num_features, *, epsilon=1e-06, dtype=None, param_dtype=<class 'jax.numpy.float32'>, use_bias=True, use_scale=True, bias_init=<function zeros>, scale_init=<function ones>, feature_axes=-1, axis_name=None, axis_index_groups=None, use_fast_variance=True, promote_dtype=<function promote_dtype>, rngs, bias_metadata=mappingproxy({}), scale_metadata=mappingproxy({}))[source]#
Instance normalization (https://arxiv.org/abs/1607.08022v3). InstanceNorm normalizes the activations of the layer for each channel (rather than across all channels like Layer Normalization), and for each given example in a batch independently (rather than across an entire batch like Batch Normalization). That is, it applies a transformation that maintains the mean activation within each channel within each example close to 0 and the activation standard deviation close to 1.
Note
This normalization operation is identical to LayerNorm and GroupNorm; the difference is simply which axes are reduced and the shape of the feature axes (i.e. the shape of the learnable scale and bias parameters).
Example usage:
>>> from flax import nnx
>>> import jax
>>> import numpy as np
>>> # dimensions: (batch, height, width, channel)
>>> x = jax.random.normal(jax.random.key(0), (2, 3, 4, 5))
>>> layer = nnx.InstanceNorm(5, rngs=nnx.Rngs(0))
>>> nnx.state(layer, nnx.Param)
State({
  'bias': Param( # 5 (20 B)
    value=Array([0., 0., 0., 0., 0.], dtype=float32)
  ),
  'scale': Param( # 5 (20 B)
    value=Array([1., 1., 1., 1., 1.], dtype=float32)
  )
})
>>> y = layer(x)
>>> # having a channel_axis of -1 in InstanceNorm is identical to reducing all non-batch,
>>> # non-channel axes and using the feature_axes as the feature_axes in LayerNorm
>>> y2 = nnx.LayerNorm(5, reduction_axes=[1, 2], feature_axes=-1, rngs=nnx.Rngs(0))(x)
>>> np.testing.assert_allclose(y, y2, atol=1e-7)
>>> y3 = nnx.GroupNorm(5, num_groups=x.shape[-1], rngs=nnx.Rngs(0))(x)
>>> np.testing.assert_allclose(y, y3, atol=1e-7)
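The per-channel, per-example reduction can be sketched in plain NumPy (an illustration of the math only, not the Flax implementation; the helper name and the NHWC layout are assumptions):

```python
import numpy as np

def instance_norm_sketch(x, scale, bias, epsilon=1e-6):
    """Plain-NumPy sketch of the instance-norm math.

    For input of shape (batch, height, width, channel), statistics are
    reduced over the spatial axes only: each channel of each example is
    normalized independently.
    """
    mean = x.mean(axis=(1, 2), keepdims=True)
    var = x.var(axis=(1, 2), keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon) * scale + bias

x = np.random.default_rng(0).normal(size=(2, 3, 4, 5))
y = instance_norm_sketch(x, np.ones(5), np.zeros(5))
# each (example, channel) slice now has ~zero mean
assert np.allclose(y.mean(axis=(1, 2)), 0.0, atol=1e-6)
```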
- Parameters:
num_features – the number of input features/channels.
epsilon – A small float added to variance to avoid dividing by zero.
dtype – the dtype of the result (default: infer from input and params).
param_dtype – the dtype passed to parameter initializers (default: float32).
use_bias – If True, bias (beta) is added.
use_scale – If True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.
bias_init – Initializer for bias, by default, zero.
scale_init – Initializer for scale, by default, one.
feature_axes – Axes for features. The learned bias and scaling parameters will be in the shape defined by the feature axes. All other axes except the batch axes (which is assumed to be the leading axis) will be reduced.
axis_name – the axis name used to combine batch statistics from multiple devices. See jax.pmap for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the array being normalized is sharded across devices within a pmap or shard map. For SPMD jit, you do not need to manually synchronize: just make sure that the axes are correctly annotated and XLA:SPMD will insert the necessary collectives.
axis_index_groups – groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, [[0, 1], [2, 3]] would independently batch-normalize over the examples on the first two and last two devices. See jax.lax.psum for more details. This argument is currently not supported for SPMD jit.
use_fast_variance – If true, use a faster, but less numerically stable, calculation for the variance.
promote_dtype – function to promote the dtype of all input array arguments (including Variables accessed through self) to the desired dtype. The function should accept a tuple of (inputs, scale, bias) and a dtype keyword argument, and return a tuple of arrays with the promoted dtype.
rngs – The rng key.
bias_metadata – Optional metadata dictionary to set when initializing the bias.
scale_metadata – Optional metadata dictionary to set when initializing the scale.
- __call__(x, *, mask=None)[source]#
Applies instance normalization on the input.
- Parameters:
x – the inputs
mask – Binary array of shape broadcastable to the inputs array, indicating the positions for which the mean and variance should be computed.
- Returns:
Normalized inputs (the same shape as inputs).
Methods
- class flax.nnx.SpectralNorm(self, layer_instance, *, n_steps=1, epsilon=1e-12, dtype=None, param_dtype=<class 'jax.numpy.float32'>, error_on_non_matrix=False, update_stats=True, rngs)[source]#
Spectral normalization.
Spectral normalization normalizes the weight params so that the spectral norm of the matrix is equal to 1. This is implemented as a layer wrapper where each wrapped layer will have its params spectral normalized before computing its __call__ output.
Note
The initialized variables dict will contain, in addition to a 'params' collection, a separate 'batch_stats' collection that will contain a u vector and sigma value, which are intermediate values used when performing spectral normalization. During training, we pass in update_stats=True so that u and sigma are updated with the most recently computed values using power iteration. This will help the power iteration method approximate the true singular value more accurately over time. During eval, we pass in update_stats=False to ensure we get deterministic behavior from the model.
Example usage:
>>> from flax import nnx
>>> import jax
>>> rngs = nnx.Rngs(0)
>>> x = jax.random.normal(jax.random.key(0), (3, 4))
>>> layer = nnx.SpectralNorm(nnx.Linear(4, 5, rngs=rngs), rngs=rngs)
>>> jax.tree.map(jax.numpy.shape, nnx.state(layer, nnx.Param))
State({
  'layer_instance': {
    'bias': Param(
      value=(5,)
    ),
    'kernel': Param(
      value=(4, 5)
    )
  }
})
>>> y = layer(x, update_stats=True)
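The power-iteration idea behind spectral normalization can be sketched in plain NumPy (an illustration of the math only, not the Flax implementation; the helper name is hypothetical):

```python
import numpy as np

def spectral_norm_sketch(w, n_steps=1, epsilon=1e-12, u=None):
    """Plain-NumPy sketch of spectral normalization via power iteration.

    Repeatedly applying w and w.T to a vector pulls it toward the leading
    singular vectors; sigma then estimates the largest singular value, and
    w / sigma has spectral norm close to 1.
    """
    if u is None:
        u = np.random.default_rng(0).normal(size=w.shape[1])
    for _ in range(n_steps):
        v = w @ u
        v = v / (np.linalg.norm(v) + epsilon)
        u = w.T @ v
        u = u / (np.linalg.norm(u) + epsilon)
    sigma = v @ w @ u  # estimate of the largest singular value
    return w / sigma, sigma, u

w = np.random.default_rng(1).normal(size=(4, 5))
w_sn, sigma, u = spectral_norm_sketch(w, n_steps=200)
# w_sn now has spectral norm close to 1; carrying u between calls is what
# lets a single power-iteration step per call converge over time
```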
- Parameters:
layer_instance – Module instance that is wrapped with SpectralNorm
n_steps – How many steps of power iteration to perform to approximate the singular value of the weight params.
epsilon – A small float added to l2-normalization to avoid dividing by zero.
dtype – the dtype of the result (default: infer from input and params).
param_dtype – the dtype passed to parameter initializers (default: float32).
error_on_non_matrix – Spectral normalization is only defined on matrices. By default, this module will return scalars unchanged and flatten higher-order tensors in their leading dimensions. Setting this flag to True will instead throw an error if a weight tensor with dimension greater than 2 is used by the layer.
update_stats – if True, update the internal u vector and sigma value after computing their updated values using power iteration.
rngs – rng key.
- __call__(x, update_stats=None)[source]#
Compute the largest singular value of the weights in self.layer_instance using power iteration and normalize the weights using this value before computing the __call__ output.
- Parameters:
x – the input array of the nested layer
update_stats – if True, update the internal u vector and sigma value after computing their updated values using power iteration. This will help the power iteration method approximate the true singular value more accurately over time.
- Returns:
Output of the layer using spectral normalized weights.
Methods
- class flax.nnx.WeightNorm(self, layer_instance, *, feature_axes=-1, use_scale=True, scale_init=<function ones>, epsilon=1e-12, dtype=None, param_dtype=<class 'jax.numpy.float32'>, variable_filter=PathContains('kernel', exact=True), promote_dtype=<function promote_dtype>, rngs)[source]#
L2 weight normalization (https://arxiv.org/abs/1602.07868).
Weight normalization normalizes the weight params so that the l2-norm of the matrix is equal to 1. This is implemented as a layer wrapper where each wrapped layer will have its params l2-normalized before computing its __call__ output.
Example usage:
>>> import jax
>>> import numpy as np
>>> from flax import nnx
>>> class Foo(nnx.Module):
...   def __init__(self, rngs: nnx.Rngs):
...     self.normed_linear = nnx.WeightNorm(
...       nnx.Linear(8, 4, rngs=rngs),
...       variable_filter=nnx.PathContains('kernel'),
...       rngs=rngs,
...     )
...   def __call__(self, x: jax.Array) -> jax.Array:
...     return self.normed_linear(x)
>>> rng = jax.random.key(42)
>>> model = Foo(rngs=nnx.Rngs(rng))
>>> x = jax.random.normal(rng, (5, 8))
>>> y = model(x)
>>> y.shape
(5, 4)
>>> w = model.normed_linear.layer_instance.kernel[...]
>>> col_norms = np.linalg.norm(np.array(w), axis=0)
>>> np.testing.assert_allclose(col_norms, np.ones(4))
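The column-norm property checked at the end of the example can be sketched in plain NumPy (an illustration of the math only, not the Flax implementation; the helper name and the column-per-output-feature layout are assumptions):

```python
import numpy as np

def weight_norm_sketch(kernel, scale=None, epsilon=1e-12):
    """Plain-NumPy sketch of l2 weight normalization.

    Each column of the kernel (one per output feature) is divided by its
    l2-norm, so every column of the result has unit norm; an optional
    per-feature scale then restores a learnable magnitude.
    """
    norms = np.linalg.norm(kernel, axis=0, keepdims=True)
    normed = kernel / (norms + epsilon)
    return normed if scale is None else normed * scale

kernel = np.random.default_rng(0).normal(size=(8, 4))
w = weight_norm_sketch(kernel)
# every column of the normalized kernel has unit l2-norm
assert np.allclose(np.linalg.norm(w, axis=0), 1.0, atol=1e-6)
```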
- Parameters:
layer_instance – The layer instance to wrap.
feature_axes – The axes to normalize.
use_scale – Whether to use a scale parameter.
scale_init – The initializer for the scale parameter, by default ones.
epsilon – The epsilon value for the normalization, by default 1e-12.
dtype – The dtype of the result, by default infer from input and params.
param_dtype – The dtype of the parameters, by default float32.
variable_filter – The variable filter, by default nnx.PathContains('kernel').
promote_dtype – function to promote the dtype of all input array arguments (including Variables accessed through self) to the desired dtype. This is used internally by WeightNorm when normalizing weights.
rngs – The rng key.
- __call__(x, *args, **kwargs)[source]#
Compute the l2-norm of the weights in self.layer_instance and normalize the weights using this value before computing the __call__ output.
- Parameters:
*args – positional arguments to be passed into the call method of the underlying layer instance in self.layer_instance.
**kwargs – keyword arguments to be passed into the call method of the underlying layer instance in self.layer_instance.
- Returns:
Output of the layer using l2-normalized weights.
Methods