# Source code for flax.nnx.nn.attention

# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Attention core modules for Flax."""

from __future__ import annotations

import functools
from typing import Any
from collections.abc import Mapping
from types import MappingProxyType
from collections.abc import Callable
import math

import jax
import jax.numpy as jnp
from jax import lax, random

from flax import nnx
from flax.nnx import rnglib
from flax.nnx.module import Module, first_from
from flax.nnx.nn import initializers
from flax.nnx.nn import dtypes
from flax.nnx.nn.linear import (
  LinearGeneral,
  default_kernel_init,
)
from flax.nnx.nn.normalization import LayerNorm
from flax.typing import (
  Dtype,
  PromoteDtypeFn,
  Shape,
  Initializer,
  PrecisionLike,
  DotGeneralT,
)

Array = jax.Array


def dot_product_attention_weights(
  query: Array,
  key: Array,
  bias: Array | None = None,
  mask: Array | None = None,
  broadcast_dropout: bool = True,
  dropout_rng: Array | None = None,
  dropout_rate: float = 0.0,
  deterministic: bool = False,
  dtype: Dtype | None = None,
  precision: PrecisionLike = None,
  module: Module | None = None,
  promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
  is_causal: bool = False,
):
  """Computes dot-product attention weights given query and key.

  Used by :func:`dot_product_attention`, which is what you'll most likely use.
  But if you want access to the attention weights for introspection, then
  you can directly call this function and call einsum yourself.

  Grouped Query Attention (GQA) is supported: if ``query`` has more heads than
  ``key``, the query head count must be a multiple of the key head count, and
  query head ``i`` attends with key head ``i // (q_heads // k_heads)``.

  Args:
    query: queries for calculating attention with shape of `[batch..., q_length,
      num_heads, qk_depth_per_head]`.
    key: keys for calculating attention with shape of `[batch..., kv_length,
      num_kv_heads, qk_depth_per_head]`, where ``num_heads`` is a multiple of
      ``num_kv_heads``.
    bias: bias for the attention weights. This should be broadcastable to the
      shape `[batch..., num_heads, q_length, kv_length]`. This can be used for
      incorporating causal masks, padding masks, proximity bias, etc.
    mask: mask for the attention weights. This should be broadcastable to the
      shape `[batch..., num_heads, q_length, kv_length]`. This can be used for
      incorporating causal masks. Attention weights are masked out if their
      corresponding mask value is `False`.
    broadcast_dropout: bool: use a broadcasted dropout along batch and head
      dims.
    dropout_rng: JAX PRNGKey: to be used for dropout. Required when dropout is
      active (``deterministic=False`` and ``dropout_rate > 0``).
    dropout_rate: dropout rate
    deterministic: bool, deterministic or not (to apply dropout)
    dtype: the dtype of the computation (default: infer from inputs and params)
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.
    module: the Module that will sow the attention weights into the
      ``nnx.Intermediate`` collection. If ``module`` is None, the attention
      weights will not be sowed.
    promote_dtype: function to promote the dtype of the arrays to the desired
      dtype. The function should accept a tuple of ``(query, key)`` and a ``dtype``
      keyword argument, and return a tuple of arrays with the promoted dtype.
    is_causal: If true, a lower-triangular causal mask of shape
      ``[q_length, kv_length]`` is combined with ``mask`` (if any) and applied
      to the logits.

  Returns:
    Output of shape `[batch..., num_heads, q_length, kv_length]`.

  Raises:
    ValueError: if the number of query heads is not a multiple of the number
      of key heads, or if dropout is active but ``dropout_rng`` is None.
  """
  query, key = promote_dtype((query, key), dtype=dtype)  # type: ignore[bad-unpacking]
  dtype = query.dtype

  assert query.ndim == key.ndim, 'q, k must have same rank.'
  assert query.shape[:-3] == key.shape[:-3], 'q, k batch dims must match.'
  assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'

  q_heads = query.shape[-2]
  k_heads = key.shape[-2]
  # Grouped Query Attention: several query heads share a single key head.
  is_gqa = q_heads != k_heads
  if is_gqa:
    if q_heads % k_heads != 0:
      raise ValueError(
        f"Query heads ({q_heads}) must be multiple of "
        f"Key heads ({k_heads}) for Grouped Query Attention."
      )
    n_rep = q_heads // k_heads
    # [..., Q, H_k * n_rep, D] -> [..., Q, H_k, n_rep, D]
    query = query.reshape(query.shape[:-2] + (k_heads, n_rep, query.shape[-1]))
    # h = key heads (shared), g = query heads per key head (kept in output).
    einsum_str = '...qhgd,...khd->...hgqk'
  else:
    einsum_str = '...qhd,...khd->...hqk'

  # scale queries by 1/sqrt(depth) before the dot product
  depth = query.shape[-1]
  query = query / jnp.sqrt(depth).astype(dtype)

  attn_weights = jnp.einsum(einsum_str, query, key, precision=precision)
  if is_gqa:
    # merge (H_k, n_rep) back into H_q -> [batch..., num_heads, Q, K]
    attn_weights = attn_weights.reshape(
      attn_weights.shape[:-4] + (q_heads,) + attn_weights.shape[-2:]
    )

  # apply attention bias: masking, dropout, proximity bias, etc.
  if bias is not None:
    attn_weights = attn_weights + bias

  # apply attention mask
  if mask is not None or is_causal:
    big_neg = jnp.finfo(dtype).min
    masks = [] if mask is None else [mask]
    if is_causal:
      q_len, kv_len = attn_weights.shape[-2:]
      causal_mask = jnp.tril(jnp.ones((q_len, kv_len), dtype=dtype))
      if mask is None:
        target_shape = attn_weights.shape
      else:
        # Broadcast to the common shape so that masks with singleton axes
        # (e.g. a padding mask of shape [batch..., 1, 1, kv_length]) combine
        # with the [q_length, kv_length] causal mask instead of failing.
        target_shape = jnp.broadcast_shapes(mask.shape, causal_mask.shape)
        masks = [jnp.broadcast_to(mask, target_shape)]
      masks.append(jnp.broadcast_to(causal_mask, target_shape))
    combined_mask = combine_masks(*masks, dtype=dtype)
    assert combined_mask is not None
    attn_weights = jnp.where(combined_mask, attn_weights, big_neg)

  # normalize the attention weights
  attn_weights = jax.nn.softmax(attn_weights).astype(dtype)

  if module is not None:
    module.sow(nnx.Intermediate, 'attention_weights', attn_weights)

  # apply attention dropout
  if not deterministic and dropout_rate > 0.0:
    if dropout_rng is None:
      raise ValueError(
        '`dropout_rng` must be provided when dropout is active '
        '(deterministic=False and dropout_rate > 0).'
      )
    keep_prob = 1.0 - dropout_rate
    if broadcast_dropout:
      # one dropout pattern shared across the batch + head dimensions
      dropout_shape = (1,) * (attn_weights.ndim - 2) + attn_weights.shape[-2:]
    else:
      dropout_shape = attn_weights.shape
    keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)  # type: ignore
    multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
    attn_weights = attn_weights * multiplier

  return attn_weights


def dot_product_attention(
  query: Array,
  key: Array,
  value: Array,
  bias: Array | None = None,
  mask: Array | None = None,
  broadcast_dropout: bool = True,
  dropout_rng: Array | None = None,
  dropout_rate: float = 0.0,
  deterministic: bool = False,
  dtype: Dtype | None = None,
  precision: PrecisionLike = None,
  module: Module | None = None,
  promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
  is_causal: bool = False,
):
  """Computes dot-product attention given query, key, and value.

  This is the core function for applying attention based on
  https://arxiv.org/abs/1706.03762. It calculates the attention weights given
  query and key and combines the values using the attention weights.

  Will use the more optimized `jax.nn.dot_product_attention` if
  ``dropout_rate == 0.0`` and ``module`` is None. On that path ``precision``
  is not forwarded.

  .. note::
    ``query``, ``key``, ``value`` needn't have any batch dimensions.

  Args:
    query: queries for calculating attention with shape of ``[batch...,
      q_length, num_heads, qk_depth_per_head]``.
    key: keys for calculating attention with shape of ``[batch..., kv_length,
      num_kv_heads, qk_depth_per_head]``.
    value: values to be used in attention with shape of ``[batch...,
      kv_length, num_kv_heads, v_depth_per_head]``.
    bias: bias for the attention weights. This should be broadcastable to the
      shape `[batch..., num_heads, q_length, kv_length]`. This can be used for
      incorporating causal masks, padding masks, proximity bias, etc.
    mask: mask for the attention weights. This should be broadcastable to the
      shape `[batch..., num_heads, q_length, kv_length]`. This can be used for
      incorporating causal masks. Attention weights are masked out if their
      corresponding mask value is `False`.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey: to be used for dropout
    dropout_rate: dropout rate
    deterministic: bool, deterministic or not (to apply dropout)
    dtype: the dtype of the computation (default: infer from inputs)
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.
    module: the Module that will sow the attention weights into the
      ``nnx.Intermediate`` collection. If ``module`` is None, the attention
      weights will not be sowed.
    promote_dtype: function to promote the dtype of the arrays to the desired
      dtype. The function should accept a tuple of ``(query, key, value)`` and a
      ``dtype`` keyword argument, and return a tuple of arrays with the promoted
      dtype.
    is_causal: If true, causal attention will be applied. Note, some
      implementations like xla will generate a mask tensor and apply it to the
      logits to mask out the non-causal parts of the attention matrix, but other
      implementations like cudnn will avoid computing the non-causal regions,
      providing speedups.

  Returns:
    Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`.

  Raises:
    ValueError: if the number of query heads is not a multiple of the number
      of key/value heads.
  """
  query, key, value = promote_dtype((query, key, value), dtype=dtype)  # type: ignore[bad-unpacking]
  dtype = query.dtype
  assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
  assert (
    query.shape[:-3] == key.shape[:-3] == value.shape[:-3]
  ), 'q, k, v batch dims must match.'
  assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.'

  # Criteria that invoke the more optimized dot product attention
  if dropout_rate == 0.0 and module is None:
    # make sure qkv batch are compressed to one dim
    query_shape = query.shape
    batch_shape = query_shape[:-3]
    if len(query_shape) > 4:

      def reshape_4d(x):
        lead = x.shape[:-3]
        if lead != batch_shape and math.prod(lead) != 1:
          # A partially broadcast batch shape such as (1, B2) cannot be
          # flattened directly; materialize the full batch dims first.
          x = jnp.broadcast_to(
            x, jnp.broadcast_shapes(x.shape, batch_shape + (1, 1, 1))
          )
        return jnp.reshape(x, (math.prod(x.shape[:-3]), *x.shape[-3:]))

      query, key, value, bias, mask = jax.tree.map(
        reshape_4d, (query, key, value, bias, mask)
      )
    if mask is not None:
      mask = mask.astype(jnp.bool)
    out = jax.nn.dot_product_attention(
      query, key, value, bias, mask, is_causal=is_causal
    )
    if len(query_shape) > 4:
      # restore the batch dims; keep the output's own trailing dims.
      out = jnp.reshape(out, batch_shape + out.shape[-3:])
    return out

  # compute attention weights
  attn_weights = dot_product_attention_weights(
    query,
    key,
    bias,
    mask,
    broadcast_dropout,
    dropout_rng,
    dropout_rate,
    deterministic,
    dtype,
    precision,
    module,
    promote_dtype,
    is_causal,
  )

  # return weighted sum over values for each query position
  q_heads = attn_weights.shape[-3]
  v_heads = value.shape[-2]
  if q_heads == v_heads:
    return jnp.einsum(
      '...hqk,...khd->...qhd', attn_weights, value, precision=precision
    )

  # Grouped Query Attention: broadcast value heads over their query groups.
  if q_heads % v_heads != 0:
    raise ValueError(
      f"Query heads ({q_heads}) must be multiple of Value heads ({v_heads})"
    )
  n_rep = q_heads // v_heads
  # [..., H_q, Q, K] -> [..., H_v, n_rep, Q, K]
  attn_weights = attn_weights.reshape(
    attn_weights.shape[:-3] + (v_heads, n_rep) + attn_weights.shape[-2:]
  )
  # h = value heads (shared), g = query heads per value head
  out = jnp.einsum(
    '...hgqk,...khd->...qhgd', attn_weights, value, precision=precision
  )
  # flatten (H_v, n_rep) -> H_q: [..., Q, H_q, D]
  return out.reshape(out.shape[:-3] + (q_heads, out.shape[-1]))
class MultiHeadAttention(Module):
  """Multi-head attention.

  Example usage::

    >>> from flax import nnx
    >>> import jax

    >>> layer = nnx.MultiHeadAttention(num_heads=8, in_features=5, qkv_features=16,
    ...                                decode=False, rngs=nnx.Rngs(0))
    >>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
    >>> shape = (4, 3, 2, 5)
    >>> q, k, v = (
    ...   jax.random.uniform(key1, shape),
    ...   jax.random.uniform(key2, shape),
    ...   jax.random.uniform(key3, shape),
    ... )

    >>> # different inputs for inputs_q, inputs_k and inputs_v
    >>> out = layer(q, k, v)
    >>> # equivalent output when inferring v
    >>> assert (layer(q, k) == layer(q, k, k)).all()
    >>> # equivalent output when inferring k and v
    >>> assert (layer(q) == layer(q, q)).all()
    >>> assert (layer(q) == layer(q, q, q)).all()

  Args:
    num_heads: number of attention heads. ``qkv_features`` (which defaults to
      ``in_features``) must be divisible by the number of heads.
    in_features: int or tuple with number of input features.
    qkv_features: dimension of the key, query, and value (default:
      ``in_features``).
    out_features: dimension of the last projection (default: ``in_features``).
    num_kv_heads: number of key and value heads. If None, it defaults to
      ``num_heads``. If set to a value smaller than ``num_heads``, Grouped Query
      Attention (GQA) is used. ``num_heads`` must be divisible by
      ``num_kv_heads``.
    in_kv_features: number of input features for computing key and value
      (default: ``in_features``).
    dtype: the dtype of the computation (default: infer from inputs and params)
    param_dtype: the dtype passed to parameter initializers (default: float32)
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rate: dropout rate
    deterministic: if false, the attention weight is masked randomly using
      dropout, whereas if true, the attention weights are deterministic.
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.
    kernel_init: initializer for the kernel of the Dense layers.
    out_kernel_init: optional initializer for the kernel of the output Dense
      layer, if None, the kernel_init is used.
    bias_init: initializer for the bias of the Dense layers.
    out_bias_init: optional initializer for the bias of the output Dense layer,
      if None, the bias_init is used.
    use_bias: bool: whether pointwise QKVO dense transforms use bias.
    attention_fn: dot_product_attention or compatible function. Accepts query,
      key, value, and returns output of shape
      ``[bs, dim1, dim2, ..., dimN, num_heads, value_channels]``.
    decode: whether to prepare and use an autoregressive cache.
    normalize_qk: should QK normalization be applied (arxiv.org/abs/2302.05442).
    qkv_promote_dtype: function to promote the dtype of all input array
      arguments (including Variables accessed through ``self``) to the desired
      dtype for the query, key, and value LinearGeneral submodules.
    out_promote_dtype: function to promote the dtype of all input array
      arguments (including Variables accessed through ``self``) to the desired
      dtype for the output LinearGeneral submodule.
    ln_promote_dtype: function to promote the dtype of all input array
      arguments (including Variables accessed through ``self``) to the desired
      dtype for the LayerNorm submodules (query_ln and key_ln) when
      normalize_qk=True.
    qkv_dot_general: deprecated, will be removed.
    out_dot_general: deprecated, will be removed.
    qkv_dot_general_cls: deprecated, will be removed.
    out_dot_general_cls: deprecated, will be removed.
    rngs: rng key.
    keep_rngs: whether to store the input rngs as attribute (i.e. `self.rngs =
      rngs`) (default: True). The dropout stream is only stored when
      ``dropout_rate > 0``. If rngs is stored, we should split the module as
      `graphdef, params, nondiff = nnx.split(module, nnx.Param, ...)` where
      `nondiff` contains RNG object associated with stored `self.rngs`.
    kernel_metadata: Optional metadata dictionary to set when initializing
      the Dense layers.
    out_kernel_metadata: Optional metadata dictionary to set when initializing
      the output Dense layers. If None, the kernel_metadata is used.
    bias_metadata: Optional metadata dictionary to set when initializing
      the bias of the Dense layers.
    out_bias_metadata: Optional metadata dictionary to set when initializing
      the bias of the output Dense layers. If None, the bias_metadata is used.
    query_ln_scale_metadata: Optional metadata dictionary to set when
      initializing the scale of the query layer norm layer.
    key_ln_scale_metadata: Optional metadata dictionary to set when
      initializing the scale of the key layer norm layer.
  """

  def __init__(
    self,
    num_heads: int,
    in_features: int,
    qkv_features: int | None = None,
    out_features: int | None = None,
    num_kv_heads: int | None = None,
    in_kv_features: int | None = None,
    *,
    dtype: Dtype | None = None,
    param_dtype: Dtype = jnp.float32,
    broadcast_dropout: bool = True,
    dropout_rate: float = 0.0,
    deterministic: bool | None = None,
    precision: PrecisionLike = None,
    kernel_init: Initializer = default_kernel_init,
    out_kernel_init: Initializer | None = None,
    bias_init: Initializer = initializers.zeros_init(),
    out_bias_init: Initializer | None = None,
    use_bias: bool = True,
    attention_fn: Callable[..., Array] = dot_product_attention,
    decode: bool | None = None,
    normalize_qk: bool = False,
    qkv_promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
    out_promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
    ln_promote_dtype: PromoteDtypeFn = dtypes.promote_dtype,
    # Deprecated, will be removed.
    qkv_dot_general: DotGeneralT | None = None,
    out_dot_general: DotGeneralT | None = None,
    qkv_dot_general_cls: Any = None,
    out_dot_general_cls: Any = None,
    rngs: rnglib.Rngs,
    keep_rngs: bool = True,
    kernel_metadata: Mapping[str, Any] = MappingProxyType({}),
    out_kernel_metadata: Mapping[str, Any] = MappingProxyType({}),
    bias_metadata: Mapping[str, Any] = MappingProxyType({}),
    out_bias_metadata: Mapping[str, Any] = MappingProxyType({}),
    query_ln_scale_metadata: Mapping[str, Any] = MappingProxyType({}),
    key_ln_scale_metadata: Mapping[str, Any] = MappingProxyType({}),
  ):
    self.num_heads = num_heads
    self.in_features = in_features
    self.qkv_features = (
      qkv_features if qkv_features is not None else in_features
    )
    self.out_features = (
      out_features if out_features is not None else in_features
    )
    self.in_kv_features = (
      in_kv_features if in_kv_features is not None else in_features
    )
    self.num_kv_heads = (
      num_kv_heads if num_kv_heads is not None else num_heads
    )
    if self.num_heads % self.num_kv_heads != 0:
      raise ValueError(
        f"num_heads ({self.num_heads}) must be divisible by "
        f"num_kv_heads ({self.num_kv_heads})."
      )
    self.dtype = dtype
    self.param_dtype = param_dtype
    self.broadcast_dropout = broadcast_dropout
    self.dropout_rate = dropout_rate
    self.deterministic = deterministic
    self.precision = precision
    self.use_bias = use_bias
    self.attention_fn = attention_fn
    self.decode = decode
    self.normalize_qk = normalize_qk
    self.qkv_promote_dtype = qkv_promote_dtype
    self.out_promote_dtype = out_promote_dtype
    self.ln_promote_dtype = ln_promote_dtype
    self.qkv_dot_general = qkv_dot_general
    self.out_dot_general = out_dot_general
    self.qkv_dot_general_cls = qkv_dot_general_cls
    self.out_dot_general_cls = out_dot_general_cls

    if self.qkv_features % self.num_heads != 0:
      raise ValueError(
        f'Memory dimension ({self.qkv_features}) must be divisible by '
        f"'num_heads' heads ({self.num_heads})."
      )

    self.head_dim = self.qkv_features // self.num_heads

    linear_general = functools.partial(
      LinearGeneral,
      dtype=self.dtype,
      param_dtype=self.param_dtype,
      kernel_init=kernel_init,
      bias_init=bias_init,
      use_bias=self.use_bias,
      precision=self.precision,
      promote_dtype=self.qkv_promote_dtype,
      dot_general=self.qkv_dot_general,
      dot_general_cls=self.qkv_dot_general_cls,
      kernel_metadata=kernel_metadata,
      bias_metadata=bias_metadata,
    )
    # project inputs_q to multi-headed q/k/v
    # dimensions are then [batch..., length, n_heads, n_features_per_head]
    self.query = linear_general(
      self.in_features, out_features=(self.num_heads, self.head_dim), rngs=rngs
    )
    self.key = linear_general(
      self.in_kv_features,
      out_features=(self.num_kv_heads, self.head_dim),
      rngs=rngs,
    )
    self.value = linear_general(
      self.in_kv_features,
      out_features=(self.num_kv_heads, self.head_dim),
      rngs=rngs,
    )

    self.query_ln: LayerNorm | None
    self.key_ln: LayerNorm | None
    if self.normalize_qk:
      # Normalizing query and key projections stabilizes training with higher
      # LR. See ViT-22B paper http://arxiv.org/abs/2302.05442 for analysis.
      self.query_ln = LayerNorm(
        self.head_dim,
        use_bias=False,
        dtype=self.dtype,
        param_dtype=self.param_dtype,
        promote_dtype=self.ln_promote_dtype,
        rngs=rngs,
        scale_metadata=query_ln_scale_metadata,
      )
      self.key_ln = LayerNorm(
        self.head_dim,
        use_bias=False,
        dtype=self.dtype,
        param_dtype=self.param_dtype,
        promote_dtype=self.ln_promote_dtype,
        rngs=rngs,
        scale_metadata=key_ln_scale_metadata,
      )
    else:
      self.query_ln = nnx.data(None)
      self.key_ln = nnx.data(None)

    self.out = LinearGeneral(
      in_features=(self.num_heads, self.head_dim),
      out_features=self.out_features,
      axis=(-2, -1),
      kernel_init=out_kernel_init or kernel_init,
      bias_init=out_bias_init or bias_init,
      use_bias=self.use_bias,
      dtype=self.dtype,
      param_dtype=self.param_dtype,
      precision=self.precision,
      promote_dtype=self.out_promote_dtype,
      dot_general=self.out_dot_general,
      dot_general_cls=self.out_dot_general_cls,
      rngs=rngs,
      kernel_metadata=out_kernel_metadata or kernel_metadata,
      bias_metadata=out_bias_metadata or bias_metadata,
    )
    # Only keep a dropout stream when dropout can actually be applied.
    self.rngs = rngs.dropout.fork() if keep_rngs and dropout_rate > 0 else None

    # Autoregressive cache; allocated by ``init_cache`` (or ``set_view``).
    self.cached_key: nnx.Cache[Array] | None = nnx.data(None)
    self.cached_value: nnx.Cache[Array] | None = nnx.data(None)
    self.cache_index: nnx.Cache[Array] | None = nnx.data(None)

  def __call__(
    self,
    inputs_q: Array,
    inputs_k: Array | None = None,
    inputs_v: Array | None = None,
    *,
    mask: Array | None = None,
    deterministic: bool | None = None,
    rngs: rnglib.Rngs | rnglib.RngStream | None = None,
    sow_weights: bool = False,
    decode: bool | None = None,
  ):
    """Applies multi-head dot product attention on the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention and project the results to an output vector.

    If both inputs_k and inputs_v are None, they will both copy the value of
    inputs_q (self attention).
    If only inputs_v is None, it will copy the value of inputs_k.

    Args:
      inputs_q: input queries of shape `[batch_sizes..., length, features]`.
      inputs_k: key of shape `[batch_sizes..., length, features]`. If None,
        inputs_k will copy the value of inputs_q.
      inputs_v: values of shape `[batch_sizes..., length, features]`. If None,
        inputs_v will copy the value of inputs_k.
      mask: attention mask of shape `[batch_sizes..., num_heads, query_length,
        key/value_length]`. Attention weights are masked out if their
        corresponding mask value is `False`.
      deterministic: if false, the attention weight is masked randomly using
        dropout, whereas if true, the attention weights are deterministic. The
        ``deterministic`` flag passed into the call method will take precedence
        over the ``deterministic`` flag passed into the constructor.
      rngs: rng key. The rng key passed into the call method will take
        precedence over the rng key passed into the constructor.
      sow_weights: if ``True``, ``module=self`` is passed to ``attention_fn``
        so the attention weights can be sowed (the default
        ``dot_product_attention`` sows them as ``nnx.Intermediate``).
      decode: whether to prepare and use an autoregressive cache. The ``decode``
        flag passed into the call method will take precedence over the
        ``decode`` flag passed into the constructor.

    Returns:
      output of shape `[batch_sizes..., length, features]`.

    Raises:
      ValueError: on inconsistent ``inputs_k``/``inputs_v``, a wrong input
        feature size, a missing or mismatched autoregressive cache, or when
        dropout is active but no rngs are available.
    """
    if rngs is None:
      rngs = self.rngs
    elif isinstance(rngs, rnglib.Rngs):
      rngs = rngs.dropout

    if inputs_k is None:
      if inputs_v is not None:
        raise ValueError(
          '`inputs_k` cannot be None if `inputs_v` is not None. '
          'To have both `inputs_k` and `inputs_v` be the same value, pass in the '
          'value to `inputs_k` and leave `inputs_v` as None.'
        )
      inputs_k = inputs_q
    if inputs_v is None:
      inputs_v = inputs_k

    if inputs_q.shape[-1] != self.in_features:
      raise ValueError(
        f'Incompatible input dimension, got {inputs_q.shape[-1]} '
        f'but module expects {self.in_features}.'
      )

    query = self.query(inputs_q)
    key = self.key(inputs_k)
    value = self.value(inputs_v)

    if self.normalize_qk:
      assert self.query_ln is not None and self.key_ln is not None
      # Normalizing query and key projections stabilizes training with higher
      # LR. See ViT-22B paper http://arxiv.org/abs/2302.05442 for analysis.
      query = self.query_ln(query)
      key = self.key_ln(key)

    # During fast autoregressive decoding, we feed one position at a time,
    # and cache the keys and values step by step.
    decode = first_from(
      decode,
      self.decode,
      error_msg="""No `decode` argument was provided to MultiHeadAttention as either a __call__ argument, class attribute, or nnx.flag.""",
    )

    if decode:
      if (
        self.cached_key is None
        or self.cached_value is None
        or self.cache_index is None
      ):
        raise ValueError(
          'Autoregressive cache not initialized, call ``init_cache`` first.'
        )
      (
        *batch_dims,
        max_length,
        num_kv_heads,
        depth_per_head,
      ) = self.cached_key.shape
      # shape check of cached keys against key input
      expected_shape = tuple(batch_dims) + (1, num_kv_heads, depth_per_head)
      if expected_shape != key.shape:
        raise ValueError(
          'Autoregressive cache shape error, '
          f'expected key shape {expected_shape} instead got {key.shape}.'
        )
      # update key, value caches with our new 1d spatial slices
      cur_index = self.cache_index[...]
      zero = jnp.array(0, dtype=lax.dtype(cur_index.dtype))
      indices = (zero,) * len(batch_dims) + (cur_index, zero, zero)
      key = lax.dynamic_update_slice(self.cached_key[...], key, indices)
      value = lax.dynamic_update_slice(self.cached_value[...], value, indices)
      self.cached_key[...] = key
      self.cached_value[...] = value
      self.cache_index[...] += 1
      # causal mask for cached decoder self-attention:
      # our single query position should only attend to those key
      # positions that have already been generated and cached,
      # not the remaining zero elements.
      mask = combine_masks(
        mask,
        jnp.broadcast_to(
          jnp.arange(max_length) <= cur_index,
          tuple(batch_dims) + (1, 1, max_length),
        ),
      )

    if self.dropout_rate > 0.0:  # Require `deterministic` only if using dropout.
      deterministic = first_from(
        deterministic,
        self.deterministic,
        error_msg="""No `deterministic` argument was provided to MultiHeadAttention as either a __call__ argument, class attribute, or nnx.flag.""",
      )
      if not deterministic:
        if rngs is None:
          raise ValueError(
            "'rngs' must be provided to __call__ method if "
            "MultiHeadAttention instance is defined with keep_rngs=False."
          )
        dropout_rng = rngs()
      else:
        dropout_rng = None
    else:
      deterministic = True
      dropout_rng = None

    # apply attention
    x = self.attention_fn(
      query,
      key,
      value,
      mask=mask,
      dropout_rng=dropout_rng,
      dropout_rate=self.dropout_rate,
      broadcast_dropout=self.broadcast_dropout,
      deterministic=deterministic,
      dtype=self.dtype,
      precision=self.precision,
      module=self if sow_weights else None,
    )
    # back to the original inputs dimensions
    out = self.out(x)
    return out

  def init_cache(self, input_shape: Shape, dtype: Dtype = jnp.float32):
    """Initializes cache for fast autoregressive decoding.

    When ``decode=True``, this method must be called first before performing
    forward inference. When in decode mode, only one token must be passed
    at a time.

    The cache has shape ``(*input_shape[:-1], num_kv_heads, head_dim)``, so
    ``input_shape[-2]`` acts as the maximum decode length.

    Example usage::

      >>> from flax import nnx
      >>> import jax.numpy as jnp
      ...
      >>> batch_size = 5
      >>> embed_dim = 3
      >>> x = jnp.ones((batch_size, 1, embed_dim)) # single token
      ...
      >>> model_nnx = nnx.MultiHeadAttention(
      ...   num_heads=2,
      ...   in_features=3,
      ...   qkv_features=6,
      ...   out_features=6,
      ...   decode=True,
      ...   rngs=nnx.Rngs(42),
      ... )
      ...
      >>> # out_nnx = model_nnx(x)  <-- throws an error because cache isn't initialized
      ...
      >>> model_nnx.init_cache(x.shape)
      >>> out_nnx = model_nnx(x)
    """
    cache_shape = (*input_shape[:-1], self.num_kv_heads, self.head_dim)
    self.cached_key = nnx.Cache(jnp.zeros(cache_shape, dtype))
    self.cached_value = nnx.Cache(jnp.zeros(cache_shape, dtype))
    self.cache_index = nnx.Cache(jnp.array(0, dtype=jnp.int32))

  def set_view(
    self,
    deterministic: bool | None = None,
    decode: bool | None = None,
    batch_size: int | Shape | None = None,
    max_length: int | None = None,
  ):
    """Class method used by ``nnx.view``.

    Args:
      deterministic: if True, the module is set to deterministic mode.
      decode: if True, the module is set to decode mode.
      batch_size: the batch size to use for the cache. An ``int`` is treated
        as a single batch dimension.
      max_length: the max length to use for the cache.

    If the autoregressive cache has not been initialized yet and either
    ``batch_size`` or ``max_length`` is given, a zero-filled cache of shape
    ``(*batch_size, max_length, num_kv_heads, head_dim)`` is allocated.

    Raises:
      TypeError: if the cache is being initialized but one of ``batch_size``
        or ``max_length`` is missing.
    """
    if deterministic is not None:
      self.deterministic = deterministic
    if decode is not None:
      self.decode = decode

    # ``__init__`` always defines the cache attributes (as None), so an
    # existence check would never fire; test the stored value instead. Only
    # allocate when the caller supplies sizing information, so plain mode
    # switches keep working without it.
    cache_missing = any(
      getattr(self, name, None) is None
      for name in ('cached_key', 'cached_value', 'cache_index')
    )
    if not cache_missing or (batch_size is None and max_length is None):
      return
    if batch_size is None:
      raise TypeError(
        "'batch_size' must be provided when initializing cache."
      )
    if max_length is None:
      raise TypeError(
        "'max_length' must be provided when initializing cache."
      )
    if isinstance(batch_size, int):
      batch_size = (batch_size,)

    # initialize cache
    cache_shape = (*batch_size, max_length, self.num_kv_heads, self.head_dim)
    self.cached_key = nnx.Cache(jnp.zeros(cache_shape, self.dtype))
    self.cached_value = nnx.Cache(jnp.zeros(cache_shape, self.dtype))
    self.cache_index = nnx.Cache(jnp.array(0, dtype=jnp.int32))
# mask-making utility functions
def make_attention_mask(
  query_input: Array,
  key_input: Array,
  pairwise_fn: Callable[..., Any] = jnp.multiply,
  extra_batch_dims: int = 0,
  dtype: Dtype = jnp.float32,
):
  """Builds an attention mask by combining every query/key position pair.

  For 1d inputs of shape ``[batch..., len_q]`` and ``[batch..., len_kv]`` the
  attention weights have shape ``[batch..., heads, len_q, len_kv]``. This
  helper returns a mask of shape ``[batch..., 1, len_q, len_kv]``, whose
  singleton head axis broadcasts over all heads.

  Args:
    query_input: a batched, flat input of query_length size.
    key_input: a batched, flat input of key_length size.
    pairwise_fn: broadcasting element-wise function applied to every
      (query, key) pair; ``jnp.multiply`` by default.
    extra_batch_dims: number of leading singleton axes to prepend (none by
      default).
    dtype: dtype of the returned mask.

  Returns:
    A ``[batch..., 1, len_q, len_kv]`` shaped mask for 1d attention.
  """
  # [batch..., len_q, 1] (op) [batch..., 1, len_kv] -> [batch..., len_q, len_kv]
  query_column = jnp.expand_dims(query_input, axis=-1)
  key_row = jnp.expand_dims(key_input, axis=-2)
  pairwise = pairwise_fn(query_column, key_row)

  # insert the singleton head axis, then the requested leading batch axes
  with_head_axis = jnp.expand_dims(pairwise, axis=-3)
  leading_axes = tuple(range(extra_batch_dims))
  result = jnp.expand_dims(with_head_axis, axis=leading_axes)
  return result.astype(dtype)
def make_causal_mask(
  x: Array, extra_batch_dims: int = 0, dtype: Dtype = jnp.float32
) -> Array:
  """Builds a causal (lower-triangular) mask for self-attention.

  For 1d inputs of shape ``[batch..., len]`` the self-attention weights have
  shape ``[batch..., heads, len, len]``. The returned mask has shape
  ``[batch..., 1, len, len]`` and lets query position ``i`` attend to every
  key position ``j <= i``.

  Args:
    x: input array of shape ``[batch..., len]``; only its shape is used.
    extra_batch_dims: number of leading singleton batch axes to prepend, none
      by default.
    dtype: dtype of the returned mask.

  Returns:
    A ``[batch..., 1, len, len]`` shaped causal mask for 1d attention.
  """
  seq_len = x.shape[-1]
  # every batch element gets its own copy of the positions 0..len-1
  positions = jnp.broadcast_to(jnp.arange(seq_len, dtype=jnp.int32), x.shape)
  # query position i may see key position j iff i >= j
  return make_attention_mask(
    query_input=positions,
    key_input=positions,
    pairwise_fn=jnp.greater_equal,
    extra_batch_dims=extra_batch_dims,
    dtype=dtype,
  )
def combine_masks(
  *masks: Array | None, dtype: Dtype = jnp.float32
) -> Array | None:
  """Combines attention masks with an element-wise logical AND.

  Args:
    *masks: attention masks to combine; ``None`` entries are skipped. All
      remaining masks must share the same rank.
    dtype: dtype of the returned mask.

  Returns:
    The logical AND of all non-None masks cast to ``dtype``, or ``None`` when
    no non-None mask was given.
  """
  present = [m for m in masks if m is not None]
  if not present:
    return None
  ranks = tuple(m.ndim for m in present)
  assert len(set(ranks)) == 1, f'masks must have same rank: {ranks}'
  # reduce() returns a lone mask unchanged, matching a single-element input.
  return functools.reduce(jnp.logical_and, present).astype(dtype)