Recurrent
RNN modules for Flax.
- class flax.nnx.nn.recurrent.LSTMCell(in_features, hidden_features, *, gate_fn=sigmoid, activation_fn=tanh, kernel_init=lecun_normal(), recurrent_kernel_init=modified_orthogonal, bias_init=zeros, dtype=None, param_dtype=jnp.float32, carry_init=None, promote_dtype=promote_dtype, keep_rngs=False, rngs, kernel_metadata=mappingproxy({}), recurrent_kernel_metadata=mappingproxy({}), bias_metadata=mappingproxy({}))
LSTM cell.
The mathematical definition of the cell is as follows:
\[\begin{array}{ll} i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \end{array}\]
where x is the input, h is the output of the previous time step, and c is the memory.
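The equations above can be transcribed almost line for line into jax.numpy. This is a minimal sketch for reading the math, not the Flax implementation: the function name lstm_step and the per-gate params layout are hypothetical, and the bias is placed on the hidden-state path to match the b_{h?} terms.

```python
import jax
import jax.numpy as jnp


def lstm_step(params, carry, x):
    """One LSTM step written from the equations (hypothetical helper).

    Assumed params layout: params["W_i"][gate] is (in_features, hidden),
    params["W_h"][gate] is (hidden, hidden), params["b_h"][gate] is (hidden,),
    for gate in "i", "f", "g", "o".
    """
    c, h = carry

    def pre(gate):
        # W_{i?} x + W_{h?} h + b_{h?}
        return x @ params["W_i"][gate] + h @ params["W_h"][gate] + params["b_h"][gate]

    i = jax.nn.sigmoid(pre("i"))  # input gate
    f = jax.nn.sigmoid(pre("f"))  # forget gate
    g = jnp.tanh(pre("g"))        # candidate memory
    o = jax.nn.sigmoid(pre("o"))  # output gate
    new_c = f * c + i * g         # c' = f * c + i * g
    new_h = o * jnp.tanh(new_c)   # h' = o * tanh(c')
    return (new_c, new_h), new_h  # (new carry, output)


# With all-zero weights every gate is sigmoid(0) = 0.5 and g = tanh(0) = 0,
# so c' = 0.5 * c and h' = 0.5 * tanh(c').
in_features, hidden = 3, 4
gates = "ifgo"
params = {
    "W_i": {k: jnp.zeros((in_features, hidden)) for k in gates},
    "W_h": {k: jnp.zeros((hidden, hidden)) for k in gates},
    "b_h": {k: jnp.zeros((hidden,)) for k in gates},
}
c0 = jnp.ones((2, hidden))
h0 = jnp.zeros((2, hidden))
(c1, h1), y = lstm_step(params, (c0, h0), jnp.ones((2, in_features)))
```

The output returned next to the carry is the new hidden state h', which is why the step returns new_h twice.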
- __call__(carry, inputs)
A long short-term memory (LSTM) cell.
- Parameters:
carry – the hidden state of the LSTM cell, initialized using LSTMCell.initialize_carry.
inputs – an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions.
- Returns:
A tuple with the new carry and the output.
- initialize_carry(input_shape, rngs=None, carry_init=None)
Initialize the RNN cell carry.
- Parameters:
input_shape – a tuple providing the shape of the input to the cell.
rngs – random number generator passed to the init_fn.
- Returns:
An initialized carry for the given RNN cell.
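Because input_shape describes the cell input, the carry keeps its batch dimensions and replaces the feature size with the hidden size. A minimal sketch, assuming an all-zeros initializer and a carry laid out as the tuple (c, h); zero_lstm_carry is a hypothetical helper, and the real method uses the cell's configured carry_init and rngs instead.

```python
import jax.numpy as jnp


def zero_lstm_carry(input_shape, hidden_features, dtype=jnp.float32):
    """Hypothetical helper: an all-zeros LSTM carry (c, h).

    Every dimension of input_shape except the last is treated as a batch
    dimension; the last one is swapped for hidden_features.
    """
    batch_dims = tuple(input_shape[:-1])
    mem_shape = batch_dims + (hidden_features,)
    c = jnp.zeros(mem_shape, dtype)  # memory cell
    h = jnp.zeros(mem_shape, dtype)  # hidden state (previous output)
    return c, h


# Inputs of shape (batch=8, in_features=3) give two carry arrays of shape (8, 16).
c, h = zero_lstm_carry((8, 3), hidden_features=16)
```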
Methods
initialize_carry(input_shape[, rngs, carry_init]) – Initialize the RNN cell carry.
- class flax.nnx.nn.recurrent.OptimizedLSTMCell(in_features, hidden_features, *, gate_fn=sigmoid, activation_fn=tanh, kernel_init=lecun_normal(), recurrent_kernel_init=orthogonal(), bias_init=zeros, dtype=None, param_dtype=jnp.float32, carry_init=None, promote_dtype=promote_dtype, keep_rngs=False, rngs, kernel_metadata=mappingproxy({}), recurrent_kernel_metadata=mappingproxy({}), bias_metadata=mappingproxy({}))
More efficient LSTM cell that concatenates state components before the matmul.
The parameters are compatible with LSTMCell. Note that this cell is often faster than LSTMCell as long as the hidden size is roughly <= 2048 units.
The mathematical definition of the cell is the same as LSTMCell and is as follows:
\[\begin{array}{ll} i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \end{array}\]
where x is the input, h is the output of the previous time step, and c is the memory.
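The idea behind "concatenates state components before matmul" is that x @ W_i + h @ W_h equals [x, h] @ [W_i; W_h], so the two matmuls can be fused into one larger one and the result split into the four gates. A minimal sketch of that identity with random arrays, assuming all four gates are stacked along the output axis; it illustrates the algebra only and is not the Flax implementation, and whether the fused form is faster depends on hardware and sizes.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
k_x, k_h, k_wi, k_wh, k_b = jax.random.split(key, 5)
batch, in_features, hidden = 2, 3, 4

x = jax.random.normal(k_x, (batch, in_features))
h = jax.random.normal(k_h, (batch, hidden))
# Weights for the gates i, f, g, o stacked along the output axis.
W_i = jax.random.normal(k_wi, (in_features, 4 * hidden))
W_h = jax.random.normal(k_wh, (hidden, 4 * hidden))
b = jax.random.normal(k_b, (4 * hidden,))

# Separate path: one matmul per input source.
separate = x @ W_i + h @ W_h + b

# Fused path: concatenate [x, h] once and do a single larger matmul.
xh = jnp.concatenate([x, h], axis=-1)    # (batch, in_features + hidden)
W = jnp.concatenate([W_i, W_h], axis=0)  # (in_features + hidden, 4 * hidden)
fused = xh @ W + b

# Split the fused pre-activations back into the four gates.
i, f, g, o = jnp.split(fused, 4, axis=-1)
```

The two results agree up to float32 rounding, which is why a fused cell can keep parameters compatible with the unfused one.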
- Parameters:
in_features – number of input features.
hidden_features – number of output features.
gate_fn – activation function used for gates (default: sigmoid).
activation_fn – activation function used for output and memory update (default: tanh).
kernel_init – initializer function for the kernels that transform the input (default: lecun_normal).
recurrent_kernel_init – initializer function for the kernels that transform the hidden state (default: initializers.orthogonal()).
bias_init – initializer for the bias parameters (default: initializers.zeros_init()).
dtype – the dtype of the computation (default: infer from inputs and params).
param_dtype – the dtype passed to parameter initializers (default: float32).
keep_rngs – whether to store the input rngs as an attribute (i.e. self.rngs = rngs) (default: False). If rngs is stored, split the module as graphdef, params, nondiff = nnx.split(module, nnx.Param, ...), where nondiff contains the RNG object associated with the stored self.rngs.
rngs – an nnx.Rngs object used to initialize the parameters.
kernel_metadata – Optional metadata dictionary to set when initializing the kernels that transform the input.
recurrent_kernel_metadata – Optional metadata dictionary to set when initializing the kernels that transform the hidden state.
bias_metadata – Optional metadata dictionary to set when initializing the bias of layers that transform the hidden state.
- __call__(carry, inputs)
An optimized long short-term memory (LSTM) cell.
- Parameters:
carry – the hidden state of the LSTM cell, initialized using OptimizedLSTMCell.initialize_carry.
inputs – an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions.
- Returns:
A tuple with the new carry and the output.
- initialize_carry(input_shape, rngs=None, carry_init=None)
Initialize the RNN cell carry.
- Parameters:
rngs – random number generator passed to the init_fn.
input_shape – a tuple providing the shape of the input to the cell.
- Returns:
An initialized carry for the given RNN cell.
Methods
initialize_carry(input_shape[, rngs, carry_init]) – Initialize the RNN cell carry.
- class flax.nnx.nn.recurrent.SimpleCell(in_features, hidden_features, *, dtype=jnp.float32, param_dtype=jnp.float32, carry_init=None, residual=False, activation_fn=tanh, kernel_init=lecun_normal(), recurrent_kernel_init=orthogonal(), bias_init=zeros, promote_dtype=promote_dtype, keep_rngs=False, rngs, kernel_metadata=mappingproxy({}), recurrent_kernel_metadata=mappingproxy({}), bias_metadata=mappingproxy({}))
Simple cell.
The mathematical definition of the cell is as follows:
\[\begin{array}{ll} h' = \tanh(W_i x + b_i + W_h h) \end{array}\]
where x is the input and h is the output of the previous time step.
If residual is True:
\[\begin{array}{ll} h' = \tanh(W_i x + b_i + W_h h + h) \end{array}\]
- __call__(carry, inputs)
Run the RNN cell.
- Parameters:
carry – the hidden state of the RNN cell.
inputs – an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions.
- Returns:
A tuple with the new carry and the output.
- initialize_carry(input_shape, rngs=None, carry_init=None)
Initialize the RNN cell carry.
- Parameters:
input_shape – a tuple providing the shape of the input to the cell.
rngs – random number generator passed to the init_fn.
- Returns:
An initialized carry for the given RNN cell.
Methods
initialize_carry(input_shape[, rngs, carry_init]) – Initialize the RNN cell carry.
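The SimpleCell equations, including the residual variant, fit in a few lines of jax.numpy. A minimal sketch following the formulas as written above (the extra h term sits inside tanh); simple_step and its argument order are hypothetical, and the activation is fixed to tanh rather than configurable.

```python
import jax.numpy as jnp


def simple_step(W_i, b_i, W_h, carry, x, residual=False):
    """One SimpleCell-style step (hypothetical helper)."""
    pre = x @ W_i + b_i + carry @ W_h  # W_i x + b_i + W_h h
    if residual:
        pre = pre + carry              # extra + h term when residual is True
    new_h = jnp.tanh(pre)
    return new_h, new_h                # carry and output are the same array


in_features, hidden = 3, 4
W_i = jnp.zeros((in_features, hidden))
b_i = jnp.zeros((hidden,))
W_h = jnp.zeros((hidden, hidden))
h0 = jnp.full((2, hidden), 0.5)
x = jnp.ones((2, in_features))

# Zero weights: the plain cell gives tanh(0) = 0, the residual one tanh(h) = tanh(0.5).
h_plain, y_plain = simple_step(W_i, b_i, W_h, h0, x)
h_res, _ = simple_step(W_i, b_i, W_h, h0, x, residual=True)
```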
- class flax.nnx.nn.recurrent.GRUCell(in_features, hidden_features, *, gate_fn=sigmoid, activation_fn=tanh, kernel_init=lecun_normal(), recurrent_kernel_init=orthogonal(), bias_init=zeros, dtype=None, param_dtype=jnp.float32, carry_init=None, promote_dtype=promote_dtype, keep_rngs=False, rngs, kernel_metadata=mappingproxy({}), recurrent_kernel_metadata=mappingproxy({}), bias_metadata=mappingproxy({}))
GRU cell.
The mathematical definition of the cell is as follows:
\[\begin{array}{ll} r = \sigma(W_{ir} x + b_{ir} + W_{hr} h) \\ z = \sigma(W_{iz} x + b_{iz} + W_{hz} h) \\ n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\ h' = (1 - z) * n + z * h \end{array}\]
where x is the input and h is the output of the previous time step.
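A direct transcription of the GRU equations in jax.numpy makes the role of each term concrete: r gates how much of the previous state enters the candidate n, and z interpolates between n and h. This is a minimal sketch, not the Flax implementation; gru_step and the parameter names in p are hypothetical, chosen to mirror the symbols above.

```python
import jax
import jax.numpy as jnp


def gru_step(p, h, x):
    """One GRU step written from the equations (hypothetical helper)."""
    r = jax.nn.sigmoid(x @ p["W_ir"] + p["b_ir"] + h @ p["W_hr"])  # reset gate
    z = jax.nn.sigmoid(x @ p["W_iz"] + p["b_iz"] + h @ p["W_hz"])  # update gate
    n = jnp.tanh(x @ p["W_in"] + p["b_in"] + r * (h @ p["W_hn"] + p["b_hn"]))  # candidate
    new_h = (1.0 - z) * n + z * h  # h' = (1 - z) * n + z * h
    return new_h, new_h            # (new carry, output)


in_features, hidden = 3, 4
p = {}
for gate in ("r", "z", "n"):
    p[f"W_i{gate}"] = jnp.zeros((in_features, hidden))
    p[f"W_h{gate}"] = jnp.zeros((hidden, hidden))
    p[f"b_i{gate}"] = jnp.zeros((hidden,))
p["b_hn"] = jnp.zeros((hidden,))

# Zero weights: r = z = 0.5 and n = tanh(0) = 0, so h' = 0.5 * h.
h1, y1 = gru_step(p, jnp.ones((2, hidden)), jnp.ones((2, in_features)))
```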
- Parameters:
in_features – number of input features.
hidden_features – number of output features.
gate_fn – activation function used for gates (default: sigmoid).
activation_fn – activation function used for output and memory update (default: tanh).
kernel_init – initializer function for the kernels that transform the input (default: lecun_normal).
recurrent_kernel_init – initializer function for the kernels that transform the hidden state (default: initializers.orthogonal()).
bias_init – initializer for the bias parameters (default: initializers.zeros_init()).
dtype – the dtype of the computation (default: None).
param_dtype – the dtype passed to parameter initializers (default: float32).
keep_rngs – whether to store the input rngs as an attribute (i.e. self.rngs = rngs) (default: False). If rngs is stored, split the module as graphdef, params, nondiff = nnx.split(module, nnx.Param, ...), where nondiff contains the RNG object associated with the stored self.rngs.
rngs – an nnx.Rngs object used to initialize the parameters.
kernel_metadata – Optional metadata dictionary to set when initializing the kernels that transform the input.
recurrent_kernel_metadata – Optional metadata dictionary to set when initializing the kernels that transform the hidden state.
bias_metadata – Optional metadata dictionary to set when initializing the bias of layers that transform the input.
- __call__(carry, inputs)
Gated recurrent unit (GRU) cell.
- Parameters:
carry – the hidden state of the GRU cell, initialized using GRUCell.initialize_carry.
inputs – an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions.
- Returns:
A tuple with the new carry and the output.
- initialize_carry(input_shape, rngs=None, carry_init=None)
Initialize the RNN cell carry.
- Parameters:
rngs – random number generator passed to the init_fn.
input_shape – a tuple providing the shape of the input to the cell.
- Returns:
An initialized carry for the given RNN cell.
Methods
initialize_carry(input_shape[, rngs, carry_init]) – Initialize the RNN cell carry.
- class flax.nnx.nn.recurrent.RNN(cell, *, time_major=False, return_carry=False, reverse=False, keep_order=False, unroll=1, state_axes=None, broadcast_rngs=None, rngs=True)
The RNN module takes any RNNCellBase instance and applies it over a sequence using flax.nnx.scan().
- __call__(inputs, *, initial_carry=None, seq_lengths=None, return_carry=None, time_major=None, reverse=None, keep_order=None, rngs=None)
Applies the wrapped cell over the time axis of inputs. When given, return_carry, time_major, reverse and keep_order override the values set at construction.
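Applying a cell over a sequence amounts to threading the carry through a scan over the time axis. RNN does this with flax.nnx.scan(); the sketch below swaps in plain jax.lax.scan over a stateless step function to show the same pattern. It assumes a single batch dimension and ignores seq_lengths, reverse and keep_order; run_rnn and running_sum are hypothetical names.

```python
import jax
import jax.numpy as jnp


def run_rnn(step, carry, inputs, time_major=False):
    """Apply step(carry, x_t) -> (carry, y_t) along the time axis (sketch)."""
    if not time_major:
        # (batch, time, features) -> (time, batch, features); scan walks axis 0.
        inputs = jnp.swapaxes(inputs, 0, 1)
    carry, outputs = jax.lax.scan(step, carry, inputs)
    if not time_major:
        outputs = jnp.swapaxes(outputs, 0, 1)  # back to (batch, time, features)
    return carry, outputs


def running_sum(c, x):
    """Toy stand-in for a cell: the carry accumulates the inputs."""
    c = c + x
    return c, c


inputs = jnp.ones((2, 5, 3))  # batch=2, time=5, features=3
carry, outputs = run_rnn(running_sum, jnp.zeros((2, 3)), inputs)
```

Any of the cells above fits the same (carry, x) -> (carry, y) contract, which is what lets one scan-based wrapper serve all of them.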
- class flax.nnx.nn.recurrent.Bidirectional(forward_rnn, backward_rnn, *, merge_fn=_concatenate, time_major=False, return_carry=False, rngs=True)
Processes the input in both directions and merges the results.
Example usage:
>>> from flax import nnx
>>> import jax.numpy as jnp
>>> from flax.nnx.nn.recurrent import RNN, GRUCell, Bidirectional
>>> # Define forward and backward RNNs
>>> forward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0)))
>>> backward_rnn = RNN(GRUCell(in_features=3, hidden_features=4, rngs=nnx.Rngs(0)))
>>> # Create Bidirectional layer
>>> layer = Bidirectional(forward_rnn=forward_rnn, backward_rnn=backward_rnn)
>>> # Input data
>>> x = jnp.ones((2, 3, 3))
>>> # Apply the layer
>>> out = layer(x)
>>> print(out.shape)
(2, 3, 8)
- __call__(inputs, *, initial_carry=None, rngs=None, seq_lengths=None, return_carry=None, time_major=None, reverse=None, keep_order=None)
Runs forward_rnn and backward_rnn over inputs and combines their outputs with merge_fn.
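A bidirectional layer runs one pass in time order, one pass on the time-reversed sequence, flips the second result back so both outputs line up per time step, and merges them; with two 4-unit GRUs concatenated along the feature axis this gives the (2, 3, 8) shape in the example. A minimal sketch of that data flow with jax.lax.scan and a toy running-sum step, assuming equal-length (unpadded) sequences so plain jnp.flip suffices; padded batches would need a padding-aware flip such as flip_sequences. scan_time, bidirectional and running_sum are hypothetical names.

```python
import jax
import jax.numpy as jnp


def scan_time(step, carry, inputs):
    """Run step over axis 1 of (batch, time, features) inputs."""
    xs = jnp.swapaxes(inputs, 0, 1)
    carry, ys = jax.lax.scan(step, carry, xs)
    return carry, jnp.swapaxes(ys, 0, 1)


def bidirectional(step_fwd, step_bwd, carry_fwd, carry_bwd, inputs):
    _, y_fwd = scan_time(step_fwd, carry_fwd, inputs)
    # Backward pass on reversed time, then flipped back so y_bwd[:, t] aligns with y_fwd[:, t].
    _, y_bwd = scan_time(step_bwd, carry_bwd, jnp.flip(inputs, axis=1))
    y_bwd = jnp.flip(y_bwd, axis=1)
    # Merge by concatenating along the feature axis.
    return jnp.concatenate([y_fwd, y_bwd], axis=-1)


def running_sum(c, x):
    c = c + x
    return c, c


x = jnp.arange(1.0, 4.0).reshape(1, 3, 1)  # one sequence: 1, 2, 3
out = bidirectional(running_sum, running_sum, jnp.zeros((1, 1)), jnp.zeros((1, 1)), x)
```

Here the forward half holds prefix sums (1, 3, 6) and the backward half holds suffix sums (6, 5, 3), so each position sees context from both directions.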
- flax.nnx.nn.recurrent.flip_sequences(inputs, seq_lengths, num_batch_dims, time_major)
Flips a sequence of inputs along the time axis.
This function can be used to prepare inputs for the reverse direction of a bidirectional LSTM. It solves the issue that, when naively flipping multiple padded sequences stored in a matrix, the first elements would be padding values for those sequences that were padded. This function keeps the padding at the end, while flipping the rest of the elements.
Example:
>>> from flax.nnx.nn.recurrent import flip_sequences
>>> from jax import numpy as jnp
>>> inputs = jnp.array([[1, 0, 0], [2, 3, 0], [4, 5, 6]])
>>> lengths = jnp.array([1, 2, 3])
>>> flip_sequences(inputs, lengths, 1, False)
Array([[1, 0, 0],
       [3, 2, 0],
       [6, 5, 4]], dtype=int32)
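- Parameters:
inputs – An array of input IDs <int>[batch_size, seq_length].
seq_lengths – The length of each sequence <int>[batch_size].
num_batch_dims – The number of batch dimensions in inputs.
time_major – Whether the time axis comes before the batch dimensions in inputs.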
- Returns:
An ndarray with the flipped inputs.
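The padding-aware flip can be expressed as a gather: inside the valid prefix of row b, position t reads from position lengths[b] - 1 - t, and inside the padding it reads from itself. A minimal sketch of that index trick for the batch-major 2-D case only (one batch dimension, time_major=False); flip_padded is a hypothetical name and this is not the library's implementation.

```python
import jax.numpy as jnp


def flip_padded(inputs, lengths):
    """Reverse the first lengths[b] steps of each row, leaving padding in place.

    Only the <int>[batch_size, seq_length] case is handled.
    """
    seq_len = inputs.shape[1]
    t = jnp.arange(seq_len)[None, :]  # (1, seq_length) time indices
    n = lengths[:, None]              # (batch_size, 1) valid lengths
    # Mirrored source index inside the valid prefix, identity in the padding.
    src = jnp.where(t < n, n - 1 - t, t)
    return jnp.take_along_axis(inputs, src, axis=1)


inputs = jnp.array([[1, 0, 0], [2, 3, 0], [4, 5, 6]])
lengths = jnp.array([1, 2, 3])
flipped = flip_padded(inputs, lengths)  # same values as the flip_sequences example
```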