Attention#
- class flax.nnx.MultiHeadAttention(self, num_heads, in_features, qkv_features=None, out_features=None, num_kv_heads=None, in_kv_features=None, *, dtype=None, param_dtype=<class 'jax.numpy.float32'>, broadcast_dropout=True, dropout_rate=0.0, deterministic=None, precision=None, kernel_init=<function variance_scaling.<locals>.init>, out_kernel_init=None, bias_init=<function zeros>, out_bias_init=None, use_bias=True, attention_fn=<function dot_product_attention>, decode=None, normalize_qk=False, qkv_promote_dtype=<function promote_dtype>, out_promote_dtype=<function promote_dtype>, ln_promote_dtype=<function promote_dtype>, qkv_dot_general=None, out_dot_general=None, qkv_dot_general_cls=None, out_dot_general_cls=None, rngs, keep_rngs=True, kernel_metadata=mappingproxy({}), out_kernel_metadata=mappingproxy({}), bias_metadata=mappingproxy({}), out_bias_metadata=mappingproxy({}), query_ln_scale_metadata=mappingproxy({}), key_ln_scale_metadata=mappingproxy({}))[source]#
Multi-head attention.
Example usage:
>>> from flax import nnx
>>> import jax
>>> layer = nnx.MultiHeadAttention(num_heads=8, in_features=5, qkv_features=16,
...                                decode=False, rngs=nnx.Rngs(0))
>>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
>>> shape = (4, 3, 2, 5)
>>> q, k, v = (
...   jax.random.uniform(key1, shape),
...   jax.random.uniform(key2, shape),
...   jax.random.uniform(key3, shape),
... )
>>> # different inputs for inputs_q, inputs_k and inputs_v
>>> out = layer(q, k, v)
>>> # equivalent output when inferring v
>>> assert (layer(q, k) == layer(q, k, k)).all()
>>> # equivalent output when inferring k and v
>>> assert (layer(q) == layer(q, q)).all()
>>> assert (layer(q) == layer(q, q, q)).all()
- Parameters:
num_heads – number of attention heads. Features (i.e. inputs_q.shape[-1]) should be divisible by the number of heads.
in_features – int or tuple with number of input features.
qkv_features – dimension of the key, query, and value.
out_features – dimension of the last projection.
in_kv_features – number of input features for computing key and value.
num_kv_heads – number of key and value heads. If None, it defaults to num_heads. If set to a value smaller than num_heads, Grouped Query Attention (GQA) is used (see the sketch after this parameter list). num_heads must be divisible by num_kv_heads.
dtype – the dtype of the computation (default: infer from inputs and params)
param_dtype – the dtype passed to parameter initializers (default: float32)
broadcast_dropout – bool: use a broadcasted dropout along batch dims.
dropout_rate – dropout rate
deterministic – if false, the attention weight is masked randomly using dropout, whereas if true, the attention weights are deterministic.
precision – numerical precision of the computation see jax.lax.Precision for details.
kernel_init – initializer for the kernel of the Dense layers.
out_kernel_init – optional initializer for the kernel of the output Dense layer, if None, the kernel_init is used.
bias_init – initializer for the bias of the Dense layers.
out_bias_init – optional initializer for the bias of the output Dense layer, if None, the bias_init is used.
use_bias – bool: whether pointwise QKVO dense transforms use bias.
attention_fn – dot_product_attention or compatible function. Accepts query, key, value, and returns output of shape [bs, dim1, dim2, …, dimN, num_heads, value_channels].
decode – whether to prepare and use an autoregressive cache.
normalize_qk – should QK normalization be applied (arxiv.org/abs/2302.05442).
qkv_promote_dtype – function to promote the dtype of all input array arguments (including Variables accessed through self) to the desired dtype for the query, key, and value LinearGeneral submodules.
out_promote_dtype – function to promote the dtype of all input array arguments (including Variables accessed through self) to the desired dtype for the output LinearGeneral submodule.
ln_promote_dtype – function to promote the dtype of all input array arguments (including Variables accessed through self) to the desired dtype for the LayerNorm submodules (query_ln and key_ln) when normalize_qk=True.
rngs – rng key.
keep_rngs – whether to store the input rngs as an attribute (i.e. self.rngs = rngs) (default: True). If rngs is stored, the module should be split as graphdef, params, nondiff = nnx.split(module, nnx.Param, …), where nondiff contains the RNG state associated with the stored self.rngs.
kernel_metadata – Optional metadata dictionary to set when initializing the Dense layers.
out_kernel_metadata – Optional metadata dictionary to set when initializing the output Dense layers. If None, the kernel_metadata is used.
bias_metadata – Optional metadata dictionary to set when initializing the bias of the Dense layers.
out_bias_metadata – Optional metadata dictionary to set when initializing the bias of the output Dense layers. If None, the bias_metadata is used.
query_ln_scale_metadata – Optional metadata dictionary to set when initializing the scale of the query layer norm layer.
key_ln_scale_metadata – Optional metadata dictionary to set when initializing the scale of the key layer norm layer.
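A minimal sketch of Grouped Query Attention (num_kv_heads < num_heads); the head counts, feature sizes, and input shapes below are illustrative assumptions rather than part of the official example set:
>>> from flax import nnx
>>> import jax
>>> # assumed sizes: 8 query heads share 2 key/value heads, feature dim 16
>>> layer = nnx.MultiHeadAttention(num_heads=8, num_kv_heads=2, in_features=16,
...                                qkv_features=16, decode=False, rngs=nnx.Rngs(0))
>>> x = jax.random.uniform(jax.random.key(1), (4, 3, 16))
>>> layer(x).shape  # self-attention; out_features defaults to in_features here
(4, 3, 16)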
- __call__(inputs_q, inputs_k=None, inputs_v=None, *, mask=None, deterministic=None, rngs=None, sow_weights=False, decode=None)[source]#
Applies multi-head dot product attention on the input data.
Projects the inputs into multi-headed query, key, and value vectors, applies dot-product attention, and projects the results to an output vector.
If both inputs_k and inputs_v are None, they will both copy the value of inputs_q (self attention). If only inputs_v is None, it will copy the value of inputs_k.
- Parameters:
inputs_q – input queries of shape [batch_sizes…, length, features].
inputs_k – key of shape [batch_sizes…, length, features]. If None, inputs_k will copy the value of inputs_q.
inputs_v – values of shape [batch_sizes…, length, features]. If None, inputs_v will copy the value of inputs_k.
mask – attention mask of shape [batch_sizes…, num_heads, query_length, key/value_length]. Attention weights are masked out if their corresponding mask value is False.
deterministic – if false, the attention weight is masked randomly using dropout, whereas if true, the attention weights are deterministic. The deterministic flag passed into the call method will take precedence over the deterministic flag passed into the constructor.
rngs – rng key. The rng key passed into the call method will take precedence over the rng key passed into the constructor.
sow_weights – if True, the attention weights are sowed into the ‘intermediates’ collection.
decode – whether to prepare and use an autoregressive cache. The decode flag passed into the call method will take precedence over the decode flag passed into the constructor.
- Returns:
output of shape [batch_sizes…, length, features].
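A minimal sketch of a masked call; the layer sizes and the causal mask below are illustrative assumptions:
>>> from flax import nnx
>>> import jax.numpy as jnp
>>> layer = nnx.MultiHeadAttention(num_heads=2, in_features=4, decode=False,
...                                rngs=nnx.Rngs(0))
>>> q = jnp.ones((1, 3, 4))
>>> # the mask must broadcast to [batch_sizes..., num_heads, query_length, key/value_length]
>>> mask = nnx.make_causal_mask(jnp.ones((1, 3)))
>>> mask.shape
(1, 1, 3, 3)
>>> layer(q, mask=mask).shape
(1, 3, 4)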
- init_cache(input_shape, dtype=<class 'jax.numpy.float32'>)[source]#
Initializes cache for fast autoregressive decoding. When decode=True, this method must be called before performing forward inference. When in decode mode, only one token must be passed at a time.
Example usage:
>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> batch_size = 5
>>> embed_dim = 3
>>> x = jnp.ones((batch_size, 1, embed_dim)) # single token
...
>>> model_nnx = nnx.MultiHeadAttention(
...   num_heads=2,
...   in_features=3,
...   qkv_features=6,
...   out_features=6,
...   decode=True,
...   rngs=nnx.Rngs(42),
... )
...
>>> # out_nnx = model_nnx(x) <-- throws an error because cache isn't initialized
...
>>> model_nnx.init_cache(x.shape)
>>> out_nnx = model_nnx(x)
- set_view(deterministic=None, decode=None, batch_size=None, max_length=None)[source]#
Class method used by nnx.view.
- Parameters:
deterministic – if True, the module is set to deterministic mode.
decode – if True, the module is set to decode mode.
batch_size – the batch size to use for the cache.
max_length – the max length to use for the cache.
Methods
init_cache(input_shape[, dtype]) – Initializes cache for fast autoregressive decoding.
set_view([deterministic, decode, ...]) – Class method used by nnx.view.
- flax.nnx.combine_masks(*masks, dtype=<class 'jax.numpy.float32'>)[source]#
Combine attention masks.
- Parameters:
*masks – set of attention mask arguments to combine, some can be None.
dtype – dtype for the returned mask.
- Returns:
Combined mask reduced by logical and; returns None if no masks are given.
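A minimal sketch combining a causal mask with a padding mask; the shapes and padding pattern are illustrative assumptions:
>>> from flax import nnx
>>> import jax.numpy as jnp
>>> causal = nnx.make_causal_mask(jnp.ones((1, 3)))
>>> padding = nnx.make_attention_mask(jnp.ones((1, 3)), jnp.array([[1, 1, 0]]))
>>> combined = nnx.combine_masks(causal, padding, None)  # None entries are ignored
>>> combined.shape
(1, 1, 3, 3)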
- flax.nnx.dot_product_attention(query, key, value, bias=None, mask=None, broadcast_dropout=True, dropout_rng=None, dropout_rate=0.0, deterministic=False, dtype=None, precision=None, module=None, promote_dtype=<function promote_dtype>, is_causal=False)[source]#
Computes dot-product attention given query, key, and value.
This is the core function for applying attention based on https://arxiv.org/abs/1706.03762. It calculates the attention weights given query and key and combines the values using the attention weights.
Will use the more optimized jax.nn.dot_product_attention if dropout is not activated and module=None.
Note
query, key, value needn't have any batch dimensions.
- Parameters:
query – queries for calculating attention with shape of [batch..., q_length, num_heads, qk_depth_per_head].
key – keys for calculating attention with shape of [batch..., kv_length, num_heads, qk_depth_per_head].
value – values to be used in attention with shape of [batch..., kv_length, num_heads, v_depth_per_head].
bias – bias for the attention weights. This should be broadcastable to the shape [batch…, num_heads, q_length, kv_length]. This can be used for incorporating causal masks, padding masks, proximity bias, etc.
mask – mask for the attention weights. This should be broadcastable to the shape [batch…, num_heads, q_length, kv_length]. This can be used for incorporating causal masks. Attention weights are masked out if their corresponding mask value is False.
broadcast_dropout – bool: use a broadcasted dropout along batch dims.
dropout_rng – JAX PRNGKey: to be used for dropout
dropout_rate – dropout rate
deterministic – bool, deterministic or not (to apply dropout)
dtype – the dtype of the computation (default: infer from inputs)
precision – numerical precision of the computation see jax.lax.Precision for details.
module – the Module that will sow the attention weights into the nnx.Intermediate collection. If module is None, the attention weights will not be sowed.
promote_dtype – function to promote the dtype of the arrays to the desired dtype. The function should accept a tuple of (query, key, value) and a dtype keyword argument, and return a tuple of arrays with the promoted dtype.
is_causal – If true, causal attention will be applied. Note, some implementations like xla will generate a mask tensor and apply it to the logits to mask out the non-causal parts of the attention matrix, but other implementations like cudnn will avoid computing the non-causal regions, providing speedups.
- Returns:
Output of shape [batch…, q_length, num_heads, v_depth_per_head].
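A minimal sketch of a direct call; the batch, length, and per-head depth below are illustrative assumptions:
>>> import jax
>>> from flax import nnx
>>> k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
>>> # assumed shapes: batch=2, q_length=3, kv_length=4, num_heads=2, depth_per_head=8
>>> query = jax.random.uniform(k1, (2, 3, 2, 8))
>>> key = jax.random.uniform(k2, (2, 4, 2, 8))
>>> value = jax.random.uniform(k3, (2, 4, 2, 8))
>>> nnx.dot_product_attention(query, key, value).shape
(2, 3, 2, 8)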
- flax.nnx.make_attention_mask(query_input, key_input, pairwise_fn=<jnp.ufunc 'multiply'>, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[source]#
Mask-making helper for attention weights.
In case of 1d inputs (i.e., [batch…, len_q], [batch…, len_kv]), the attention weights will be [batch…, heads, len_q, len_kv] and this function will produce [batch…, 1, len_q, len_kv].
- Parameters:
query_input – a batched, flat input of query_length size
key_input – a batched, flat input of key_length size
pairwise_fn – broadcasting elementwise comparison function
extra_batch_dims – number of extra batch dims to add singleton axes for, none by default
dtype – mask return dtype
- Returns:
A [batch…, 1, len_q, len_kv] shaped mask for 1d attention.
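A minimal sketch of a padding mask built from token validity flags; the inputs are illustrative assumptions:
>>> from flax import nnx
>>> import jax.numpy as jnp
>>> query_input = jnp.ones((2, 3))
>>> key_input = jnp.array([[1, 1, 0, 0], [1, 1, 1, 0]])  # 0 marks padding positions
>>> mask = nnx.make_attention_mask(query_input, key_input)
>>> mask.shape
(2, 1, 3, 4)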
- flax.nnx.make_causal_mask(x, extra_batch_dims=0, dtype=<class 'jax.numpy.float32'>)[source]#
Make a causal mask for self-attention.
In case of 1d inputs (i.e., [batch…, len]), the self-attention weights will be [batch…, heads, len, len] and this function will produce a causal mask of shape [batch…, 1, len, len].
- Parameters:
x – input array of shape [batch…, len]
extra_batch_dims – number of batch dims to add singleton axes for, none by default
dtype – mask return dtype
- Returns:
A [batch…, 1, len, len] shaped causal mask for 1d attention.
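A minimal sketch; the input length below is an illustrative assumption:
>>> from flax import nnx
>>> import jax.numpy as jnp
>>> mask = nnx.make_causal_mask(jnp.ones((1, 4)))
>>> mask.shape
(1, 1, 4, 4)
>>> # position 2 may attend to position 1, but not the other way around
>>> bool(mask[0, 0, 2, 1]), bool(mask[0, 0, 1, 2])
(True, False)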