# Source code for flax.nnx.transforms.iteration

# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pytype: skip-file

from collections import deque
import dataclasses
import functools
import typing as tp


from flax import struct
from flax import typing
from flax.core.frozen_dict import FrozenDict
from flax.nnx import extract, filterlib, graphlib, spmd, variablelib
from flax.nnx import statelib
from flax.nnx.module import Module
from flax.nnx.statelib import State
from flax.nnx.transforms.transforms import (
  resolve_kwargs,
  _resolve_bound_callable,
  _raise_bound_method_error,
)
from flax.typing import Leaf, Missing, PytreeDeque
import jax
import jax.core
import jax.numpy as jnp
import jax.stages
import numpy as np

# Generic type variables used throughout this module's transform signatures.
A = tp.TypeVar('A')
C = tp.TypeVar('C')
B = tp.TypeVar('B')
# F/G: callables being wrapped by the transforms (preserve the wrapped type).
F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])
G = tp.TypeVar('G', bound=tp.Callable[..., tp.Any])
# M/MA/N: Module-bound type variables for graph-node arguments.
M = tp.TypeVar('M', bound=Module)
MA = tp.TypeVar('MA', bound=Module)
N = tp.TypeVar('N', bound=Module)
T = tp.TypeVar('T')
# Constrained to exactly str or int (e.g. mapping keys).
StrInt = tp.TypeVar('StrInt', str, int)
# Axis names for collectives are any hashable, mirroring jax's convention.
AxisName = tp.Hashable
Leaves = list[Leaf]
# An axis position within an array.
Index = int


class Carry:
  """Sentinel type for :func:`flax.nnx.scan`.

  Used inside ``in_axes`` / ``out_axes`` to mark an input and output
  position as the scan carry (threaded between iterations) rather than a
  scanned or broadcast axis. The class itself is used as a marker value;
  instances carry no state.
  """

  pass
# -------------------------------
# transform_metadata
# -------------------------------
# NOTE(review): this region was reconstructed from a whitespace-mangled
# source; statement grouping inside conditional branches was inferred from
# data flow — verify against the upstream flax repository.


def _apply_axis_fn(
  tree: tp.Any,
  axes: tp.Any,
  metadata: tp.Mapping[str, tp.Any],
  axis_fn: tp.Callable[..., tp.Any],
) -> None:
  """Apply ``axis_fn(variable, axis, metadata)`` to every Variable leaf of
  ``tree`` whose broadcast ``axes`` prefix entry is an integer axis."""
  is_leaf = lambda x: x is None or isinstance(x, variablelib.Variable)
  _, per_leaf_axes = extract.broadcast_prefix2(axes, tree, is_leaf=is_leaf)
  leaves = jax.tree_util.tree_leaves(tree, is_leaf=is_leaf)
  for leaf, axis in zip(leaves, per_leaf_axes):
    if isinstance(axis, int) and isinstance(leaf, variablelib.Variable):
      axis_fn(leaf, axis, metadata)


@tp.overload
def transform_metadata(
  *,
  in_axes: tp.Any = 0,
  out_axes: tp.Any = 0,
  partition: str,
  graph: bool | None = None,
) -> tp.Callable[[F], F]: ...
@tp.overload
def transform_metadata(
  f: F,
  *,
  in_axes: tp.Any = 0,
  out_axes: tp.Any = 0,
  graph: bool | None = None,
  partition: str,
) -> F: ...
def transform_metadata(
  f: F | type[Missing] = Missing,
  *,
  in_axes: tp.Any = 0,
  out_axes: tp.Any = 0,
  graph: bool | None = None,
  partition: str,
) -> F | tp.Callable[[F], F]:
  """Wrap ``f`` so that sharding metadata (``partition``) is removed from
  Variable axes before the call and re-added afterwards.

  Args:
    f: The function to wrap; when omitted, returns a decorator.
    in_axes: Axis prefix applied to the inputs.
    out_axes: Axis prefix applied to the outputs.
    graph: Whether to use graph-mode; defaults to the global setting.
    partition: Name recorded under ``spmd.PARTITION_NAME`` in the metadata.
  """
  if f is Missing:
    # Decorator usage: transform_metadata(partition=...)(f)
    return functools.partial(
      transform_metadata,
      in_axes=in_axes,
      out_axes=out_axes,
      partition=partition,
      graph=graph,
    )  # type: ignore[return-value]
  if graph is None:
    graph = graphlib.set_graph_mode.current_value()
  metadata: tp.Mapping[str, tp.Any] = {
    spmd.PARTITION_NAME: partition,
  }

  @functools.wraps(f)
  def wrapper(*in_args, **in_kwargs):
    in_args = resolve_kwargs(f, in_args, in_kwargs)
    if graph:
      in_args = extract.to_tree2(in_args, prefix=in_axes)
      extract.check_no_aliases('transform_metadata', args=in_args)
    # Work on a clone so axis-metadata mutations don't leak into the inputs.
    args = graphlib.clone(in_args, graph=graph)
    _apply_axis_fn(args, in_axes, metadata, spmd.remove_axis)
    updates, snapshot = extract.updates_and_snapshot(args)
    if graph:
      args = extract.from_tree2(args)
    out = f(*args)
    if graph:
      out = extract.to_tree2(out, prefix=out_axes)
      extract.check_no_aliases('transform_metadata', args=updates, out=out)
    _apply_axis_fn(args, in_axes, metadata, spmd.add_axis)
    _apply_axis_fn(out, out_axes, metadata, spmd.add_axis)
    updates = extract.mask_variable_updates(updates, snapshot)
    extract.apply_variable_updates(in_args, updates)
    if graph:
      out = extract.from_tree2(out)
    return out

  return wrapper  # type: ignore[return-value]


# -------------------------------
# vmap
# -------------------------------
class StateAxes(extract.PrefixMapping, tp.Mapping):
  """Maps state Filters to vmap/scan axis specs (int, None, or Carry).

  Lets a single graph node be split into substates that are vectorized,
  broadcast, or carried independently.
  """

  def __init__(
    self,
    filter_axes: (
      statelib.State
      | tp.Mapping[filterlib.Filter, Index | type[Carry] | None]
      | tp.Iterable[tuple[filterlib.Filter, Index | type[Carry] | None]]
    ),
    /,
  ):
    if isinstance(filter_axes, statelib.State):
      filter_axes = statelib.create_path_filters(filter_axes)  # type: ignore

    iterable = tuple(
      filter_axes.items()
      if isinstance(filter_axes, tp.Mapping)
      else filter_axes
    )
    self._filters = tuple(filter for filter, _ in iterable)
    self._axes = tuple(axis for _, axis in iterable)

  @property
  def filters(self) -> tuple[filterlib.Filter, ...]:
    return self._filters

  @property
  def axes(self) -> tuple[Index | type[Carry] | None, ...]:
    return self._axes

  def map_prefix(
    self, path: typing.PathParts, variable: variablelib.Variable
  ) -> tp.Any:
    # First matching filter wins; raising on no match surfaces filter gaps.
    for filter, axis in zip(self.filters, self.axes):
      predicate = filterlib.to_predicate(filter)
      if predicate(path, variable):
        return axis
    raise ValueError(f'No axis found for {path=}, {variable=}')

  def __repr__(self):
    return f'StateAxes({dict(self.items())})'

  def items(self):
    return zip(self.filters, self.axes)

  def __getitem__(self, key):
    return self.axes[self.filters.index(key)]

  def __iter__(self):
    return iter(self.filters)

  def __len__(self):
    return len(self.filters)

  def __eq__(self, other):
    return (
      isinstance(other, StateAxes)
      and self.filters == other.filters
      and self.axes == other.axes
    )

  def __hash__(self):
    return hash((self.filters, self.axes))


# Signature of spmd.add_axis / spmd.remove_axis style callbacks.
AxisFn = tp.Callable[
  [graphlib.GraphState | variablelib.Variable, int, tp.Mapping],
  graphlib.GraphState | variablelib.Variable,
]


def _update_variable_sharding_metadata(
  tree, transform_metadata, axis_fn: AxisFn
):
  """Apply ``axis_fn`` to every vectorized state inside the NodeStates
  leaves of ``tree`` (used to add/remove sharding axes around transforms)."""

  def _update_axes_fn(node_states):
    if isinstance(node_states, extract.NodeStates) and isinstance(
      node_states.metadata, (StateAxes, int)
    ):
      if isinstance(node_states.metadata, int):
        state = node_states.state
        assert isinstance(state, State | variablelib.Variable)
        state = axis_fn(state, node_states.metadata, transform_metadata)
        return node_states.replace(states=(state,))
      else:
        states_out: list[graphlib.GraphState | variablelib.Variable] = []
        for state, axis in zip(node_states.states, node_states.metadata.axes):
          assert isinstance(state, graphlib.State | variablelib.Variable)
          if isinstance(axis, int):
            state = axis_fn(state, axis, transform_metadata)
          states_out.append(state)
        return node_states.replace(states=tuple(states_out))
    return node_states

  return jax.tree.map(
    _update_axes_fn, tree, is_leaf=lambda x: isinstance(x, extract.NodeStates)
  )


def _vmap_split_fn(ctx: graphlib.SplitContext, path, prefix, x):
  """Split a graph node into NodeStates, per-filter when given StateAxes."""
  if isinstance(prefix, StateAxes):
    return extract.NodeStates.from_split(
      *ctx.split(x, *prefix.filters), metadata=prefix
    )
  return extract.NodeStates.from_split(*ctx.split(x), metadata=prefix)


@dataclasses.dataclass(eq=False)
class SimpleVmapFn:
  """Inner callable for tree-mode (non graph-protocol) vmap."""

  f: tp.Callable[..., tp.Any]
  graph: bool
  out_axes: tp.Any

  def __post_init__(self):
    functools.update_wrapper(self, self.f, updated=())

  @extract.treemap_copy_args
  def __call__(self, *args, **kwargs):
    updates, snapshot = extract.updates_and_snapshot((args, kwargs))
    if self.graph:
      args, kwargs = extract.from_tree2((args, kwargs))
    out = self.f(*args, **kwargs)
    if self.graph:
      out = extract.to_tree2(out, prefix=self.out_axes)
      extract.check_no_aliases(
        'vmap', args=updates[0], kwargs=updates[1], out=out
      )
    # Only propagate Variable updates that actually changed vs the snapshot.
    updates = extract.mask_variable_updates(updates, snapshot)
    return out, updates


@dataclasses.dataclass(eq=False)
class SimplePmapFn:
  """Inner callable for tree-mode (non graph-protocol) pmap."""

  f: tp.Callable[..., tp.Any]
  graph: bool
  out_axes: tp.Any

  def __post_init__(self):
    functools.update_wrapper(self, self.f, updated=())

  @extract.treemap_copy_args
  def __call__(self, *args, **kwargs):
    updates, snapshot = extract.updates_and_snapshot((args, kwargs))
    if self.graph:
      args, kwargs = extract.from_tree2((args, kwargs))
    out = self.f(*args, **kwargs)
    if self.graph:
      out = extract.to_tree2(out, prefix=self.out_axes)
      extract.check_no_aliases(
        'pmap', args=updates[0], kwargs=updates[1], out=out
      )
    updates = extract.mask_variable_updates(updates, snapshot)
    return out, updates


@dataclasses.dataclass(eq=False)
class VmapFn:
  """Inner callable for graph-mode vmap: merges pure args, runs ``f``, and
  splits the results (propagating sharding metadata if configured)."""

  f: tp.Callable[..., tp.Any]
  transform_metadata: tp.Mapping[str, tp.Any]
  in_axes: tp.Any
  out_axes: tp.Any

  def __post_init__(self):
    functools.update_wrapper(self, self.f)

  def __call__(self, *pure_args: tuple[tp.Any, ...]):
    if spmd.PARTITION_NAME in self.transform_metadata:
      pure_args = _update_variable_sharding_metadata(
        pure_args, self.transform_metadata, spmd.remove_axis
      )
    args = extract.from_tree(pure_args, ctxtag='vmap', is_inner=True)

    out = self.f(*args)

    args_out = extract.clear_non_graph_nodes(args)
    pure_args_out, pure_out = extract.to_tree(
      (args_out, out),
      prefix=(self.in_axes, self.out_axes),
      split_fn=_vmap_split_fn,
      ctxtag='vmap',
    )
    if spmd.PARTITION_NAME in self.transform_metadata:
      pure_args_out, pure_out = _update_variable_sharding_metadata(
        (pure_args_out, pure_out), self.transform_metadata, spmd.add_axis
      )
    return pure_args_out, pure_out


@tp.overload
def vmap(
  *,
  in_axes: int | None | tp.Sequence[tp.Any] = 0,
  out_axes: tp.Any = 0,
  axis_name: AxisName | None = None,
  axis_size: int | None = None,
  spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
  # nnx specific
  transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> tp.Callable[[F], F]: ...
@tp.overload
def vmap(
  f: F,
  *,
  in_axes: int | None | tp.Sequence[tp.Any] = 0,
  out_axes: tp.Any = 0,
  axis_name: AxisName | None = None,
  axis_size: int | None = None,
  spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
  # nnx specific
  transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> F: ...
def vmap(
  f: F | type[Missing] = Missing,
  *,
  in_axes: int | None | tp.Sequence[tp.Any] = 0,
  out_axes: tp.Any = 0,
  axis_name: AxisName | None = None,
  axis_size: int | None = None,
  spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
  # nnx specific
  transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> F | tp.Callable[[F], F]:
  """Reference-aware version of `jax.vmap <https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html>`__.

  Args:
    f: Function to be mapped over additional axes.
    in_axes: An integer, None, or sequence of values specifying which input
      array axes to map over (see `jax.vmap
      <https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html>`__). In
      addition to integers and None, :class:`StateAxes` can be used to control
      how graph nodes like Modules are vectorized by specifying the axes to be
      applied to substates of the graph node given a `Filter
      <https://flax.readthedocs.io/en/latest/guides/filters_guide.html>`__.
    out_axes: An integer, None, or pytree indicating where the mapped axis
      should appear in the output (see `jax.vmap
      <https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html>`__).
    axis_name: Optional, a hashable Python object used to identify the mapped
      axis so that parallel collectives can be applied.
    axis_size: Optional, an integer indicating the size of the axis to be
      mapped. If not provided, the mapped axis size is inferred from
      arguments.
    graph: If ``True`` (default), uses graph-mode which supports the full NNX
      feature set including shared references and reference semantics. If
      ``False``, uses tree-mode which treats Modules as regular JAX pytrees,
      avoiding the overhead of the graph protocol. Tree-mode does not support
      ``StateAxes`` or shared ``Variable`` references.
    graph_updates: If ``True``, propagates updates on graph structure that
      happen inside the transform to the input graphs, has no effect when
      ``graph=False``. When ``False``, using ``StateAxes`` is not supported.

  Returns:
    Batched/vectorized version of ``f`` with arguments that correspond to
    those of ``f``, but with extra array axes at positions indicated by
    ``in_axes``, and a return value that corresponds to that of ``f``, but
    with extra array axes at positions indicated by ``out_axes``.

  Example::

    >>> from flax import nnx
    >>> from jax import random, numpy as jnp
    ...
    >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
    >>> x = jnp.ones((5, 2))
    ...
    >>> @nnx.vmap(in_axes=(None, 0), out_axes=0)
    ... def forward(model, x):
    ...   return model(x)
    ...
    >>> y = forward(model, x)
    >>> y.shape
    (5, 3)

    >>> class LinearEnsemble(nnx.Module):
    ...   def __init__(self, num, rngs):
    ...     self.w = nnx.Param(jax.random.uniform(rngs(), (num, 2, 3)))
    ...
    >>> model = LinearEnsemble(5, rngs=nnx.Rngs(0))
    >>> x = jnp.ones((2,))
    ...
    >>> @nnx.vmap(in_axes=(0, None), out_axes=0)
    ... def forward(model, x):
    ...   return x @ model.w
    ...
    >>> y = forward(model, x)
    >>> y.shape
    (5, 3)

  To control how graph node substates are vectorized, ``StateAxes``
  can be passed to ``in_axes`` and ``out_axes`` specifying the axes to be
  applied to each substate given a filter. The following example shows how
  to share the parameters between the ensemble members while keeping
  different batch statistics and dropout random state::

    >>> class Foo(nnx.Module):
    ...   def __init__(self):
    ...     self.a = nnx.Param(jnp.arange(4))
    ...     self.b = nnx.BatchStat(jnp.arange(4))
    ...
    >>> state_axes = nnx.StateAxes({nnx.Param: 0, nnx.BatchStat: None})
    >>> @nnx.vmap(in_axes=(state_axes,), out_axes=0)
    ... def mul(foo):
    ...   return foo.a * foo.b
    ...
    >>> foo = Foo()
    >>> y = mul(foo)
    >>> y
    Array([[0, 0, 0, 0],
           [0, 1, 2, 3],
           [0, 2, 4, 6],
           [0, 3, 6, 9]], dtype=int32)
  """
  if graph is None:
    graph = graphlib.set_graph_mode.current_value()
  if graph_updates is None:
    graph_updates = graphlib.set_graph_updates.current_value()
  if f is Missing:
    # Decorator usage: nnx.vmap(in_axes=..., ...)(f)
    return functools.partial(
      vmap,
      in_axes=in_axes,
      out_axes=out_axes,
      axis_name=axis_name,
      axis_size=axis_size,
      spmd_axis_name=spmd_axis_name,
      transform_metadata=transform_metadata,
      graph=graph,
      graph_updates=graph_updates,
    )  # type: ignore[return-value]
  f_unbound, _, was_bound = _resolve_bound_callable(f)
  if was_bound:
    _raise_bound_method_error('vmap')

  if not graph or not graph_updates:
    # Tree-mode / no-graph-updates path: StateAxes is unsupported here.
    if any(isinstance(x, StateAxes) for x in jax.tree.leaves(in_axes)):
      raise ValueError(
        '`in_axes` cannot contain `StateAxes` objects '
        'when `graph=False`. '
        + graphlib._tree_mode_suggestion_transform('vmap')
      )
    if any(isinstance(x, StateAxes) for x in jax.tree.leaves(out_axes)):
      raise ValueError(
        '`out_axes` cannot contain `StateAxes` objects '
        'when `graph=False`. '
        + graphlib._tree_mode_suggestion_transform('vmap')
      )
    vmapped_fn = jax.vmap(
      SimpleVmapFn(f_unbound, graph=graph, out_axes=out_axes),
      in_axes=in_axes,
      out_axes=(out_axes, (in_axes, 0)),
      axis_name=axis_name,
      axis_size=axis_size,
      spmd_axis_name=spmd_axis_name,
    )

    @functools.wraps(f_unbound)
    def simple_vmap_wrapper(*args, **kwargs):
      if graph:
        args, kwargs = extract.to_tree2(
          (args, kwargs),
          prefix=(in_axes, None) if in_axes is not None else None,
          check_aliasing=in_axes is not None,
        )
        extract.check_no_aliases('vmap', args=args, kwargs=kwargs)
      out, updates = vmapped_fn(*args, **kwargs)
      extract.apply_variable_updates((args, kwargs), updates)
      if graph:
        out = extract.from_tree2(out)
      return out

    return simple_vmap_wrapper  # type: ignore[return-value]

  # Graph-mode path: lift StateAxes specs into NodeStates prefixes that
  # jax.vmap understands.
  jax_in_axes = jax.tree.map(
    lambda x: extract.NodeStates.from_prefixes(x.axes, metadata=x)
    if isinstance(x, StateAxes)
    else x,
    in_axes,
  )
  jax_out_axes = jax.tree.map(
    lambda x: extract.NodeStates.from_prefixes(x.axes, metadata=x)
    if isinstance(x, StateAxes)
    else x,
    out_axes,
  )
  vmapped_fn = jax.vmap(  # type: ignore[assignment]
    VmapFn(f_unbound, transform_metadata, in_axes, out_axes),
    in_axes=jax_in_axes,
    out_axes=(jax_in_axes, jax_out_axes),
    axis_name=axis_name,
    axis_size=axis_size,
    spmd_axis_name=spmd_axis_name,
  )

  @functools.wraps(f)
  @graphlib.update_context('vmap')
  def vmap_wrapper(*args, **kwargs):
    args = resolve_kwargs(f, args, kwargs)
    pure_args = extract.to_tree(
      args, prefix=in_axes, split_fn=_vmap_split_fn, ctxtag='vmap'
    )
    pure_args_out, pure_out = vmapped_fn(*pure_args)
    _args_out, out = extract.from_tree(
      (pure_args_out, pure_out), ctxtag='vmap', is_inner=False
    )
    return out

  return vmap_wrapper  # type: ignore
# ------------------------------- # pmap # ------------------------------- @dataclasses.dataclass(eq=False) class PmapFn: f: tp.Callable[..., tp.Any] transform_metadata: tp.Mapping[str, tp.Any] in_axes: tp.Any out_axes: tp.Any def __post_init__(self): functools.update_wrapper(self, self.f) def __call__(self, *pure_args: tuple[tp.Any, ...]): if spmd.PARTITION_NAME in self.transform_metadata: pure_args = _update_variable_sharding_metadata( pure_args, self.transform_metadata, spmd.remove_axis ) args = extract.from_tree(pure_args, ctxtag='pmap', is_inner=True) out = self.f(*args) args_out = extract.clear_non_graph_nodes(args) pure_args_out, pure_out = extract.to_tree( (args_out, out), prefix=(self.in_axes, self.out_axes), split_fn=_vmap_split_fn, ctxtag='pmap', ) if spmd.PARTITION_NAME in self.transform_metadata: pure_args_out, pure_out = _update_variable_sharding_metadata( (pure_args_out, pure_out), self.transform_metadata, spmd.add_axis ) return pure_args_out, pure_out @tp.overload def pmap( *, axis_name: AxisName | None = None, in_axes: tp.Any = 0, out_axes: tp.Any = 0, static_broadcasted_argnums: int | tp.Iterable[int] = (), devices: tp.Sequence[jax.Device] | None = None, # noqa: F811 backend: str | None = None, axis_size: int | None = None, donate_argnums: int | tp.Iterable[int] = (), global_arg_shapes: tuple[tuple[int, ...], ...] | None = None, # nnx specific transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), graph: bool | None = None, graph_updates: bool | None = None, ) -> tp.Callable[[F], F]: ... @tp.overload def pmap( f: F, *, axis_name: AxisName | None = None, in_axes: tp.Any = 0, out_axes: tp.Any = 0, static_broadcasted_argnums: int | tp.Iterable[int] = (), devices: tp.Sequence[jax.Device] | None = None, # noqa: F811 backend: str | None = None, axis_size: int | None = None, donate_argnums: int | tp.Iterable[int] = (), global_arg_shapes: tuple[tuple[int, ...], ...] 
| None = None, # nnx specific transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), graph: bool | None = None, graph_updates: bool | None = None, ) -> F: ... def pmap( f: F | type[Missing] = Missing, *, axis_name: AxisName | None = None, in_axes: tp.Any = 0, out_axes: tp.Any = 0, static_broadcasted_argnums: int | tp.Iterable[int] = (), devices: tp.Sequence[jax.Device] | None = None, # noqa: F811 backend: str | None = None, axis_size: int | None = None, donate_argnums: int | tp.Iterable[int] = (), global_arg_shapes: tuple[tuple[int, ...], ...] | None = None, # nnx specific transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), graph: bool | None = None, graph_updates: bool | None = None, ) -> F | tp.Callable[[F], F]: """Reference-aware version of `jax.pmap <https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html>`__. Args: f: Function to be mapped over argument axes. Its arguments and return value should be arrays, scalars, graph nodes, or (nested) standard Python containers (tuple/list/dict) thereof. Positional arguments indicated by ``static_broadcasted_argnums`` can be anything at all, provided they are hashable and have an equality operation defined. axis_name: Optional, a hashable Python object used to identify the mapped axis so that parallel collectives can be applied. in_axes: A non-negative integer, None, or nested Python container thereof that specifies which axes of positional arguments to map over. Arguments passed as keywords are always mapped over their leading axis (i.e. axis index 0). In addition to integers and None, :class:`StateAxes` can be used to control how graph nodes like Modules are vectorized by specifying the axes to be applied to substates of the graph node given a `Filter <https://flax.readthedocs.io/en/latest/guides/filters_guide.html>`__. out_axes: A non-negative integer, None, or nested Python container thereof indicating where the mapped axis should appear in the output. 
All outputs with a mapped axis must have a non-None ``out_axes`` specification. static_broadcasted_argnums: An int or collection of ints specifying which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded. Calling the pmapped function with different values for these constants will trigger recompilation. If the pmapped function is called with fewer positional arguments than indicated by ``static_broadcasted_argnums`` then an error is raised. Each of the static arguments will be broadcasted to all devices. Arguments that are not arrays or containers thereof must be marked as static. Defaults to (). Static arguments must be hashable, meaning both ``__hash__`` and ``__eq__`` are implemented, and should be immutable. devices: This is an experimental feature and the API is likely to change. Optional, a sequence of Devices to map over. (Available devices can be retrieved via jax.devices()). Must be given identically for each process in multi-process settings (and will therefore include devices across processes). If specified, the size of the mapped axis must be equal to the number of devices in the sequence local to the given process. Nested ``pmap`` s with ``devices`` specified in either the inner or outer ``pmap`` are not yet supported. backend: This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend. 'cpu', 'gpu', or 'tpu'. axis_size: Optional; the size of the mapped axis. donate_argnums: Specify which positional argument buffers are "donated" to the computation. It is safe to donate argument buffers if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example recycling one of your input buffers to store a result. 
You should not reuse buffers that you donate to a computation, JAX will raise an error if you try to. Note that donate_argnums only work for positional arguments, and keyword arguments will not be donated. transform_metadata: Optional mapping of metadata for the transform. graph: if True, use graph-mode (default). If False, use tree-mode. If None, uses the value of ``nnx_graph_mode`` config. graph_updates: If ``True``, propagates updates on graph structure that happen inside the transform to the input graphs, has no effect when ``graph=False``. When ``False``, using ``StateAxes`` is not supported. Returns: A parallelized version of ``f`` with arguments that correspond to those of ``f`` but with extra array axes at positions indicated by ``in_axes`` and with output that has an additional leading array axis (with the same size). """ if graph is None: graph = graphlib.set_graph_mode.current_value() if graph_updates is None: graph_updates = graphlib.set_graph_updates.current_value() if f is Missing: return functools.partial( pmap, axis_name=axis_name, in_axes=in_axes, out_axes=out_axes, static_broadcasted_argnums=static_broadcasted_argnums, devices=devices, backend=backend, axis_size=axis_size, donate_argnums=donate_argnums, global_arg_shapes=global_arg_shapes, transform_metadata=transform_metadata, graph=graph, graph_updates=graph_updates, ) # type: ignore[return-value] f_unbound, _, was_bound = _resolve_bound_callable(f) if was_bound: _raise_bound_method_error('pmap') if not graph or not graph_updates: if any(isinstance(x, StateAxes) for x in jax.tree.leaves(in_axes)): raise ValueError( '`in_axes` cannot contain `StateAxes` objects ' 'when `graph=False`. ' + graphlib._tree_mode_suggestion_transform('pmap') ) if any(isinstance(x, StateAxes) for x in jax.tree.leaves(out_axes)): raise ValueError( '`out_axes` cannot contain `StateAxes` objects ' 'when `graph=False`. 
' + graphlib._tree_mode_suggestion_transform('pmap') ) pmapped_fn = jax.pmap( SimplePmapFn(f_unbound, graph=graph, out_axes=out_axes), axis_name=axis_name, in_axes=in_axes, out_axes=(out_axes, (in_axes, 0)), static_broadcasted_argnums=static_broadcasted_argnums, devices=devices, backend=backend, axis_size=axis_size, donate_argnums=donate_argnums, global_arg_shapes=global_arg_shapes, ) @functools.wraps(f_unbound) def simple_pmap_wrapper(*args, **kwargs): if graph: args, kwargs = extract.to_tree2( (args, kwargs), prefix=(in_axes, None) if in_axes is not None else None, check_aliasing=in_axes is not None, ) extract.check_no_aliases('pmap', args=args, kwargs=kwargs) out, updates = pmapped_fn(*args, **kwargs) extract.apply_variable_updates((args, kwargs), updates) if graph: out = extract.from_tree2(out) return out return simple_pmap_wrapper # type: ignore[return-value] jax_in_axes = jax.tree.map( lambda x: extract.NodeStates.from_prefixes(x.axes, metadata=x) if isinstance(x, StateAxes) else x, in_axes, ) jax_out_axes = jax.tree.map( lambda x: extract.NodeStates.from_prefixes(x.axes, metadata=x) if isinstance(x, StateAxes) else x, out_axes, ) pmapped_fn = jax.pmap( PmapFn(f_unbound, transform_metadata, in_axes, out_axes), axis_name=axis_name, in_axes=jax_in_axes, out_axes=(jax_in_axes, jax_out_axes), static_broadcasted_argnums=static_broadcasted_argnums, devices=devices, backend=backend, axis_size=axis_size, donate_argnums=donate_argnums, global_arg_shapes=global_arg_shapes, ) @functools.wraps(f) @graphlib.update_context('pmap') def vmap_wrapper(*args): pure_args = extract.to_tree( args, prefix=in_axes, split_fn=_vmap_split_fn, ctxtag='pmap' ) pure_args_out, pure_out = pmapped_fn(*pure_args) _args_out, out = extract.from_tree( (pure_args_out, pure_out), ctxtag='pmap', is_inner=False ) return out return vmap_wrapper # type: ignore # ------------------------------- # scan # ------------------------------- class Broadcasted(struct.PyTreeNode): data: tp.Any def 
_get_carry_argnum(axes, is_in_axes: bool): if axes is Carry: return 'all' elif isinstance(axes, int) or axes is None: return None obj_repr = 'in_axes' if is_in_axes else 'out_axes' carry_argnum: int | None = None prev_key: tp.Any = None for key, x in jax.tree_util.tree_leaves_with_path(axes): if x is not Carry: continue assert isinstance(key[0], jax.tree_util.SequenceKey) i = key[0].idx if len(key) >= 2: raise ValueError( f'Carry must at the top-level, it cannot be nested. Found {axes=}' ) if carry_argnum is not None: raise ValueError( f'Found multiple Carry definitions at ' f'{obj_repr}{jax.tree_util.keystr(prev_key)} and ' f'{obj_repr}{jax.tree_util.keystr(key)}' ) carry_argnum = i prev_key = key return carry_argnum def _check_out_axes(out_axes): for key, x in jax.tree_util.tree_leaves_with_path( out_axes, is_leaf=lambda x: x is None ): if x is None: raise ValueError( f'Cannot broadcast output state. ' f'Got out_axes=None at: out_axes{jax.tree_util.keystr(key)}' ) elif isinstance(x, StateAxes): for filter, value in x.items(): if value is None: raise ValueError( f'Cannot broadcast output state. ' f'Got StateAxes({{{filter}: None}}) at: out_axes' f'{jax.tree_util.keystr(key)}' ) elif value is Carry: raise ValueError( f'Cannot carry output state. ' f'Got StateAxes({{{filter}: Carry}}) at: out_axes' f'{jax.tree_util.keystr(key)}' ) def _check_carry_same_references(carry_arg, carry_arg_out): def check_carry_same_references(key_path, arg, out): if ( not isinstance(arg, jax.Array) or not isinstance(out, jax.Array) ) and arg is not out: raise ValueError( 'Carry references must be the same between iterations. 
' f'Got {arg=} with id={id(arg)} and {out=} with id={id(out)} ' f'at carry{jax.tree_util.keystr(key_path)}' ) jax.tree_util.tree_map_with_path( check_carry_same_references, carry_arg, carry_arg_out, is_leaf=lambda x: graphlib.is_graph_node(x) and not isinstance(x, variablelib.Variable), ) def _scan_split_in( carry_deque: PytreeDeque[list[State | variablelib.Variable]], graphdefs_deque: PytreeDeque[graphlib.GraphDef], broadcast_deque: PytreeDeque[list[State | variablelib.Variable]], broadcast_arrays: PytreeDeque[Broadcasted], /, ctx: graphlib.SplitContext, path, prefix, x, ): if graphlib.is_graph_node(x) or isinstance(x, variablelib.Variable): vectorized_states: list[State | variablelib.Variable] = [] carry_states: list[State | variablelib.Variable] = [] broadcast_states: list[State | variablelib.Variable] = [] if isinstance(prefix, StateAxes): graphdef, *states = ctx.split(x, *prefix.filters) for state, axis in zip(states, prefix.axes): if axis is None: broadcast_states.append(state) elif isinstance(axis, int): if axis != 0: state = jax.tree.map(lambda x: jnp.moveaxis(x, axis, 0), state) vectorized_states.append(state) else: # axis is Carry carry_states.append(state) if not vectorized_states: vectorized_states.append(State({})) carry_deque.append(carry_states) graphdefs_deque.append(graphdef) broadcast_deque.append(broadcast_states) return extract.NodeStates.from_split( None, *vectorized_states, metadata=prefix ) elif isinstance(prefix, int): graphdef, state = ctx.split(x) if prefix != 0: state = jax.tree.map(lambda x: jnp.moveaxis(x, prefix, 0), state) vectorized_states.append(state) elif prefix is None: graphdef, state = ctx.split(x) broadcast_states.append(state) vectorized_states.append(State({})) elif prefix is Carry: graphdef, state = ctx.split(x) carry_states.append(state) vectorized_states.append(State({})) else: raise ValueError( f'Invalid axes {prefix} args{jax.tree_util.keystr(path)}' ) if not vectorized_states: vectorized_states.append(State({})) 
carry_deque.append(carry_states) graphdefs_deque.append(graphdef) broadcast_deque.append(broadcast_states) return extract.NodeStates.from_split( None, *vectorized_states, metadata=prefix ) else: if isinstance(prefix, StateAxes): raise ValueError( 'Cannot use StateAxes on non-graph nodes, ' f'found {prefix} args{jax.tree_util.keystr(path)}' ) elif prefix is Carry: return x elif prefix is None: broadcast_arrays.append(Broadcasted(x)) return Broadcasted(None) elif isinstance(prefix, int): if not isinstance(x, (jax.Array, np.ndarray)): raise ValueError( f'Expected an array, got {type(x).__name__} args' f'{jax.tree_util.keystr(path)}' ) if prefix != 0: x = jnp.moveaxis(x, prefix, 0) return x else: raise ValueError( f'Invalid axes {prefix} args{jax.tree_util.keystr(path)}' ) def _scan_split_out( carry_deque: PytreeDeque[list[State | variablelib.Variable]], graphdefs_deque: PytreeDeque[graphlib.GraphDef], broadcast_deque: PytreeDeque[list[State | variablelib.Variable]], /, ctx: graphlib.SplitContext, path: extract.KeyPath, prefix, x, ): assert isinstance(path[0], jax.tree_util.SequenceKey) is_input_arg = path[0].idx == 0 if graphlib.is_graph_node(x) or isinstance(x, variablelib.Variable): vectorized_states: list[State | variablelib.Variable] = [] carry_states: list[State | variablelib.Variable] = [] broadcast_states: list[State | variablelib.Variable] = [] if isinstance(prefix, StateAxes): graphdef, *states = ctx.split(x, *prefix.filters) for state, filter, axis in zip(states, prefix.filters, prefix.axes): if axis is None: assert is_input_arg # validated by _check_out_axes broadcast_states.append(state) elif isinstance(axis, int): vectorized_states.append(state) elif axis is Carry: assert is_input_arg # validated by _check_out_axes carry_states.append(state) else: obj_repr = 'args' if is_input_arg else 'out' raise ValueError( f'Invalid axes {axis} for filter {filter} at ' f'{obj_repr}{jax.tree_util.keystr(path)}' ) if not vectorized_states: 
vectorized_states.append(State({})) if is_input_arg: carry_deque.append(carry_states) broadcast_deque.append(broadcast_states) graphdefs_deque.append(graphdef) return extract.NodeStates.from_split( None, *vectorized_states, metadata=prefix ) elif isinstance(prefix, int): graphdef, state = ctx.split(x) vectorized_states.append(state) elif prefix is None: assert is_input_arg # validated by _check_out_axes graphdef, state = ctx.split(x) broadcast_states.append(state) vectorized_states.append(State({})) elif prefix is Carry: assert is_input_arg # validated by _check_out_axes graphdef, state = ctx.split(x) carry_states.append(state) vectorized_states.append(State({})) else: obj_repr = 'args' if is_input_arg else 'out' raise ValueError( f'Invalid axes {prefix} at {obj_repr}{jax.tree_util.keystr(path)}' ) if not vectorized_states: vectorized_states.append(State({})) if is_input_arg: carry_deque.append(carry_states) broadcast_deque.append(broadcast_states) graphdefs_deque.append(graphdef) return extract.NodeStates.from_split( None, *vectorized_states, metadata=prefix ) else: if isinstance(prefix, StateAxes): obj_repr = 'args' if is_input_arg else 'out' raise ValueError( 'Cannot use StateAxes on non-graph nodes, ' f'found {prefix} at {obj_repr}{jax.tree_util.keystr(path)}' ) elif prefix is Carry: return x elif prefix is None: assert not is_input_arg # validated by _check_out_axes return Broadcasted(None) elif isinstance(prefix, int): return x else: obj_repr = 'args' if is_input_arg else 'out' raise ValueError( f'Invalid axes {prefix} at {obj_repr}{jax.tree_util.keystr(path)}' ) def _scan_merge_in( carry_deque: PytreeDeque[list[State]], graphdefs_deque: PytreeDeque[graphlib.GraphDef], broadcast_deque: PytreeDeque[list[State]], broadcast_arrays: PytreeDeque[Broadcasted], /, ctx: graphlib.MergeContext, path, prefix, x, ): if isinstance(x, extract.NodeStates): carry_states = carry_deque.popleft() broadcast_states = broadcast_deque.popleft() graphdef = graphdefs_deque.popleft() 
return ctx.merge(graphdef, *x.states, *carry_states, *broadcast_states) elif isinstance(x, Broadcasted): assert x.data is None return broadcast_arrays.popleft().data else: return x def _scan_merge_out( carry_deque: PytreeDeque[list[State]], graphdefs_deque: PytreeDeque[graphlib.GraphDef], broadcast_deque: PytreeDeque[list[State]], /, ctx: graphlib.MergeContext, path, prefix, x, ): assert isinstance(path[0], jax.tree_util.SequenceKey) is_input_arg = path[0].idx == 0 if isinstance(x, extract.NodeStates): states: list[State] = [] graphdef = graphdefs_deque.popleft() if is_input_arg: carry_states = deque(carry_deque.popleft()) broadcast_states = deque(broadcast_deque.popleft()) else: carry_states = deque[State]() broadcast_states = deque[State]() if isinstance(prefix, StateAxes): vectorized_states = deque(x.states) for axis in prefix.axes: if isinstance(axis, int): state = vectorized_states.popleft() state = jax.tree.map( lambda x: jnp.moveaxis(x, 0, axis) if axis != 0 else x, state, ) states.append(state) elif axis is None: states.append(broadcast_states.popleft()) else: # axis is Carry states.append(carry_states.popleft()) assert not carry_states and not broadcast_states assert not vectorized_states or ( len(vectorized_states) == 1 and not vectorized_states[0] ) elif isinstance(prefix, int): state = jax.tree.map( lambda x: jnp.moveaxis(x, 0, prefix) if prefix != 0 else x, x.state ) states.extend((state, *carry_states, *broadcast_states)) elif prefix is None: assert is_input_arg states.extend(broadcast_states) elif prefix is Carry: assert is_input_arg states.extend(carry_states) else: obj_repr = 'args' if is_input_arg else 'out' raise ValueError( f'Invalid axes {prefix} at {obj_repr}{jax.tree_util.keystr(path)}' ) return ctx.merge(graphdef, *states) else: if isinstance(prefix, StateAxes): obj_repr = 'args' if is_input_arg else 'out' raise ValueError( 'Cannot use StateAxes on non-graph nodes, ' f'found {prefix} at {obj_repr}{jax.tree_util.keystr(path)}' ) elif prefix 
is Carry: return x elif prefix is None: return x elif isinstance(prefix, int): if not isinstance(x, (jax.Array, np.ndarray)): obj_repr = 'args' if is_input_arg else 'out' raise ValueError( f'Expected an array, got {type(x).__name__} at ' f'{obj_repr}{jax.tree_util.keystr(path)}' ) if prefix != 0: x = jnp.moveaxis(x, 0, prefix) return x else: obj_repr = 'args' if is_input_arg else 'out' raise ValueError( f'Invalid axes {prefix} at {obj_repr}{jax.tree_util.keystr(path)}' ) @dataclasses.dataclass(eq=False) class ScanFn: f: tp.Callable[..., tp.Any] input_carry_argnum: int | None | tp.Literal['all'] output_carry_argnum: int | None | tp.Literal['all'] in_axes: tp.Any out_axes: tp.Any transform_metadata: tp.Mapping[str, tp.Any] def __post_init__(self): functools.update_wrapper(self, self.f) def __call__( self, carry: tuple[ tp.Any, # carry_arg PytreeDeque[list[State]], # carry_deque PytreeDeque[list[State]], # broadcast_deque PytreeDeque[Broadcasted], # broadcast_arrays ], scan_in: tuple[tp.Any, ...], ): pure_carry_arg, carry_deque, broadcast_deque, broadcast_arrays = carry broadcast_deque_out = PytreeDeque(broadcast_deque) broadcast_arrays_out = PytreeDeque(broadcast_arrays) graphdefs_deque, pure_args = scan_in if self.input_carry_argnum == 'all': assert pure_args == () pure_args = (pure_carry_arg,) elif isinstance(self.input_carry_argnum, int): assert pure_args[self.input_carry_argnum] is None _pure_args = list(pure_args) _pure_args[self.input_carry_argnum] = pure_carry_arg pure_args = tuple(_pure_args) else: assert self.input_carry_argnum is None assert pure_carry_arg is None if spmd.PARTITION_NAME in self.transform_metadata: pure_args = _update_variable_sharding_metadata( pure_args, self.transform_metadata, spmd.remove_axis ) args: tuple = extract.from_tree( pure_args, prefix=self.in_axes, merge_fn=functools.partial( _scan_merge_in, carry_deque, graphdefs_deque, broadcast_deque, broadcast_arrays ), is_leaf=lambda x: isinstance(x, (extract.NodeStates, Broadcasted)), 
map_non_graph_nodes=True, ctxtag='scan', is_inner=True, ) assert not carry_deque and not broadcast_deque and not broadcast_arrays out = self.f(*args) # extract the carry from the args if self.input_carry_argnum == 'all': carry_arg = args[0] elif isinstance(self.input_carry_argnum, int): carry_arg = args[self.input_carry_argnum] else: assert self.input_carry_argnum is None carry_arg = None # extract the carry from the output if self.output_carry_argnum == 'all': carry_arg_out = out out = None elif isinstance(self.output_carry_argnum, int): assert isinstance(out, tuple) carry_arg_out = out[self.output_carry_argnum] _out = list(out) _out[self.output_carry_argnum] = None out = tuple(_out) else: assert self.output_carry_argnum is None carry_arg_out = None # TODO(cgarciae): allowing new references might lead to inconsistencies with # scan's looping semantics and we would also need to propagate the input _check_carry_same_references(carry_arg, carry_arg_out) args_out: tuple = extract.clear_non_graph_nodes(args) # replace the carry from the input args with the carry from the output if self.input_carry_argnum == 'all': args_out = (carry_arg_out,) elif isinstance(self.input_carry_argnum, int): _args_out = list(args_out) _args_out[self.input_carry_argnum] = carry_arg_out args_out = tuple(_args_out) else: assert self.input_carry_argnum is None assert carry_arg_out is None carry_deque_out = PytreeDeque[list[State | variablelib.Variable]]() graphdefs_out = PytreeDeque[graphlib.GraphDef]() _broadcast_deque_out_tmp = PytreeDeque[ list[State | variablelib.Variable] ]() # discarded pure_args_out: tuple pure_args_out, pure_out = extract.to_tree( (args_out, out), prefix=(self.in_axes, self.out_axes), split_fn=functools.partial( _scan_split_out, carry_deque_out, graphdefs_out, _broadcast_deque_out_tmp ), map_non_graph_nodes=True, ctxtag='scan', ) if spmd.PARTITION_NAME in self.transform_metadata: pure_args_out, pure_out = _update_variable_sharding_metadata( (pure_args_out, pure_out), 
self.transform_metadata, spmd.add_axis, ) # extract the pure carry from the pure args if self.input_carry_argnum == 'all': pure_carry_arg_out = pure_args_out[0] pure_args_out = () elif isinstance(self.input_carry_argnum, int): pure_carry_arg_out = pure_args_out[self.input_carry_argnum] _pure_args_out = list(pure_args_out) _pure_args_out[self.input_carry_argnum] = None pure_args_out = tuple(_pure_args_out) else: assert self.input_carry_argnum is None pure_carry_arg_out = None carry_arg_out = ( pure_carry_arg_out, carry_deque_out, broadcast_deque_out, broadcast_arrays_out, ) scan_out = ( graphdefs_out, pure_args_out, pure_out, ) return carry_arg_out, scan_out @dataclasses.dataclass(eq=False) class SimpleScanFn: f: tp.Callable[..., tp.Any] graph: bool in_axes: tp.Any out_axes: tp.Any out_is_tuple: bool carry_arg_index: int | None carry_out_index: int | None def __post_init__(self): functools.update_wrapper(self, self.f, updated=()) @extract.treemap_copy_args def __call__(self, *args): updates, snapshot = extract.updates_and_snapshot(args) if self.carry_arg_index is not None: carry_in = args[self.carry_arg_index] else: carry_in = None if self.graph: args = extract.from_tree2(args) out = self.f(*args) if self.graph: out = extract.to_tree2(out, prefix=self.out_axes) if self.carry_out_index is not None: carry_out = out[self.carry_out_index] if self.out_is_tuple else out extract.check_same_variables(carry_in, carry_out, 'scan') # Mask variable updates for non-carry args: identify broadcast (None axis) # Variables and drop their updates since they should not change across # scan iterations. 
masked_carry_updates = extract.mask_at(updates, self.carry_arg_index) masked_carry_snapshot = extract.mask_at(snapshot, self.carry_arg_index) if isinstance(self.in_axes, tuple): masked_carry_in_axes = extract.mask_at(self.in_axes, self.carry_arg_index) else: masked_carry_in_axes = self.in_axes is_leaf = lambda x: isinstance(x, variablelib.Variable) or x is None _, per_leaf_axes = extract.broadcast_prefix2( masked_carry_in_axes, masked_carry_updates, is_leaf=is_leaf, ) broadcast_var_ids = set() carry_leaves = jax.tree.leaves(masked_carry_updates, is_leaf=is_leaf) for axis, leaf in zip(per_leaf_axes, carry_leaves, strict=True): if axis is None and isinstance(leaf, variablelib.Variable): broadcast_var_ids.add(id(leaf)) def keep_fn(path, cur, snap): changed = extract.variable_changed(cur, snap) if id(cur) in broadcast_var_ids and changed: raise ValueError( f'Broadcast (None axis) Variable at {jax.tree_util.keystr(path)} ' 'was mutated during scan. Only Carry and scanned Variables can be ' 'updated.' ) return changed extract.check_no_aliases('scan', args=masked_carry_updates, out=out) masked_carry_updates = extract.mask_variable_updates( masked_carry_updates, masked_carry_snapshot, keep_fn=keep_fn, ) if self.out_is_tuple: return (*out, masked_carry_updates) return (out, masked_carry_updates) @tp.overload def scan( *, length: int | None = None, reverse: bool = False, unroll: int | bool = 1, _split_transpose: bool = False, # extended api in_axes: int | None | type[Carry] | tuple[tp.Any, ...] = (Carry, 0), out_axes: tp.Any = (Carry, 0), # nnx specific transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), graph: bool | None = None, graph_updates: bool | None = None, ) -> tp.Callable[[F], F]: ... @tp.overload def scan( f: F, *, length: int | None = None, reverse: bool = False, unroll: int | bool = 1, _split_transpose: bool = False, # extended api in_axes: int | None | type[Carry] | tuple[tp.Any, ...] 
= (Carry, 0), out_axes: tp.Any = (Carry, 0), # nnx specific transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), graph: bool | None = None, graph_updates: bool | None = None, ) -> F: ...
def scan(
  f: F | type[Missing] = Missing,
  *,
  length: int | None = None,
  reverse: bool = False,
  unroll: int | bool = 1,
  _split_transpose: bool = False,
  # extended api
  in_axes: int | None | type[Carry] | tuple[tp.Any, ...] = (Carry, 0),
  out_axes: tp.Any = (Carry, 0),
  # nnx specific
  transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> F | tp.Callable[[F], F]:
  """A Flax NNX transformation of `jax.lax.scan`_.

  Example::

    import jax
    import jax.numpy as jnp
    from flax import nnx

    class Block(nnx.Module):
      def __init__(self, input_dim, features, *, rngs):
        self.linear = nnx.Linear(input_dim, features, rngs=rngs)
        self.dropout = nnx.Dropout(0.1, rngs=rngs)

      def __call__(self, x: jax.Array):
        x = self.linear(x)
        x = self.dropout(x)
        x = jax.nn.relu(x)
        return x

    class Model(nnx.Module):
      def __init__(self, num_layers, features, *, rngs):
        # In this model implementation we create
        # multiple blocks using vmap
        # As Block contains dropout op, we prefer
        # to split RNG into num_layers of RNGs
        # using @nnx.split_rngs decorator.
        # Next, nnx.vmap creates a vectorized version of Block.
        # in_axes and out_axes define vectorization axis
        # of the input split rngs and the output Block instance.
        # Both axes should be 0.
        @nnx.split_rngs(splits=num_layers)
        @nnx.vmap(in_axes=(0,), out_axes=0)
        def create_block(rngs: nnx.Rngs):
          return Block(features, features, rngs=rngs)

        self.blocks = create_block(rngs)
        self.num_layers = num_layers

      def __call__(self, x):
        # Forward pass method implementation
        # We use nnx.scan to apply sequentially the blocks
        # on the input, for example with num_layers=3
        # output = block[0](x)
        # output = block[1](output)
        # output = block[2](output)
        #
        # In `forward` function defined below:
        # - x represents the loop carry value
        # - model is the data to scan along the leading axis
        # nnx.scan args:
        # - in_axes marks the inputs: x is marked as carry
        #   and the model is to scan along the axis 0
        # - out_axes marks the output as carry
        @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
        def forward(x, model):
          x = model(x)
          return x

        return forward(x, self.blocks)

        # Alternatively, we can also decorate `self.__call__` method
        # @nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)
        # def __call__(self, x):
        #   return self.blocks(x)

    model = Model(2, 4, rngs=nnx.Rngs(0))
    _, params, _ = nnx.split(model, nnx.Param, ...)
    print(params)  # kernel of shape: (2, 4, 4)

    x = jnp.arange(5 * 4, dtype="float32").reshape((5, 4))
    y = model(x)
    print(y.shape)  # shape: (5, 4)

  Args:
    f: a Python function to be scanned. When not given, ``scan`` acts as a
      decorator factory and returns a partially-applied transform.
    length: optional integer specifying the number of loop iterations
    reverse: optional boolean specifying whether to run the scan iteration
      forward (the default) or in reverse
    unroll: optional positive int or bool specifying, in the underlying
      operation of the scan primitive, how many scan iterations to unroll
      within a single iteration of a loop.
    _split_transpose: forwarded to ``jax.lax.scan``; experimental flag that
      splits the transpose of the scan (see the JAX documentation).
    in_axes: integer, None, :class:`flax.nnx.Carry` or sequence of values
      specifying the kind of input args. Integer value would specify
      the axis of corresponding input data to scan along.
      :class:`flax.nnx.Carry` marks the input data as loop carry value.
      None marks the input data as auxiliary input.
    out_axes: integer, None, :class:`flax.nnx.Carry` or sequence of values
      specifying the kind of output args. See ``in_axes`` for details.
      Note that If ``in_axes`` contains :class:`flax.nnx.Carry` then
      ``out_axes`` must also contain :class:`flax.nnx.Carry`.
    transform_metadata: mapping of metadata (e.g. the partition name under
      ``nnx.PARTITION_NAME``) used to update Variable sharding metadata when
      adding/removing the scanned axis.
    graph: if True, use graph-mode. If False, use tree-mode. If None, uses
      the value of the ``nnx_graph_mode`` config.
    graph_updates: If ``True``, propagates updates on graph structure
      that happen inside the transform to the input graphs, has no effect
      when ``graph=False``. When ``False``, using ``StateAxes`` is not
      supported.

  Returns:
    The scanned version of ``f``, or a decorator when ``f`` is not provided.

  .. _jax.lax.scan: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html
  """
  # Decorator-factory form: defer until the function is supplied.
  if f is Missing:
    return functools.partial(
      scan,
      length=length,
      reverse=reverse,
      unroll=unroll,
      _split_transpose=_split_transpose,
      in_axes=in_axes,
      out_axes=out_axes,
      transform_metadata=transform_metadata,
      graph=graph,
      graph_updates=graph_updates,
    )  # type: ignore[return-value]
  # Bound methods are rejected up front; scan must wrap the unbound function.
  f_unbound, _, was_bound = _resolve_bound_callable(f)
  if was_bound:
    _raise_bound_method_error('scan')
  # Resolve unspecified flags from the active global configuration.
  if graph is None:
    graph = graphlib.set_graph_mode.current_value()
  if graph_updates is None:
    graph_updates = graphlib.set_graph_updates.current_value()
  # Tree-mode (or graph-mode without structural update propagation) takes the
  # lighter-weight path; full graph-updates mode uses the split/merge machinery.
  if not graph or not graph_updates:
    return _simple_scan(
      f,
      f_unbound,
      graph=graph,
      in_axes=in_axes,
      out_axes=out_axes,
      length=length,
      reverse=reverse,
      unroll=unroll,
      _split_transpose=_split_transpose,
    )
  return _graph_updates_scan(
    f,
    f_unbound,
    in_axes=in_axes,
    out_axes=out_axes,
    length=length,
    reverse=reverse,
    unroll=unroll,
    _split_transpose=_split_transpose,
    transform_metadata=transform_metadata,
  )
def _simple_scan(
  f,
  f_unbound,
  *,
  graph,
  in_axes,
  out_axes,
  length,
  reverse,
  unroll,
  _split_transpose,
):
  """Scan implementation used when ``graph=False`` or ``graph_updates=False``.

  Wraps ``f_unbound`` in :class:`SimpleScanFn` and drives it through
  :func:`pure_jax_fancy_scan`, then applies collected Variable updates back
  onto the inputs. ``StateAxes`` is not supported on this path.
  """
  # StateAxes requires the full graph-updates machinery; reject early.
  if any(isinstance(x, StateAxes) for x in jax.tree.leaves(in_axes)):
    raise ValueError(
      '`in_axes` cannot contain `StateAxes` objects '
      'when `graph=False`. ' + graphlib._tree_mode_suggestion_transform('scan')
    )
  if any(isinstance(x, StateAxes) for x in jax.tree.leaves(out_axes)):
    raise ValueError(
      '`out_axes` cannot contain `StateAxes` objects '
      'when `graph=False`. ' + graphlib._tree_mode_suggestion_transform('scan')
    )
  out_is_tuple = isinstance(out_axes, tuple)
  if in_axes is Carry:
    in_axes = (Carry,)
  if isinstance(in_axes, tuple):
    # Position of the (single, top-level) Carry argument, if any.
    carry_arg_index = next(
      (i for i, ax in enumerate(in_axes) if ax is Carry), None
    )
    # Axes for the auxiliary Variable-updates output: same as in_axes with
    # the carry slot masked out (carry updates are handled separately).
    updates_out_axes = extract.mask_at(in_axes, carry_arg_index)
  else:
    carry_arg_index = None
    updates_out_axes = in_axes
  if isinstance(out_axes, tuple):
    carry_out_index = next(
      (i for i, ax in enumerate(out_axes) if ax is Carry), None
    )
  else:
    carry_out_index = None
  simple_scan_fn = SimpleScanFn(
    f_unbound,
    graph=graph,
    in_axes=in_axes,
    out_axes=out_axes,
    out_is_tuple=out_is_tuple,
    carry_arg_index=carry_arg_index,
    carry_out_index=carry_out_index,
  )
  # SimpleScanFn appends the masked Variable updates as an extra output, so
  # the effective out_axes gain one trailing entry.
  if out_is_tuple:
    augmented_out_axes = (*out_axes, updates_out_axes)
  else:
    augmented_out_axes = (out_axes, updates_out_axes)

  @functools.wraps(f)
  def simple_scan_wrapper(*args):
    args = resolve_kwargs(f, args, {})
    if graph:
      args = extract.to_tree2(args, prefix=in_axes)
    extract.check_no_aliases('scan', args=args)
    result = pure_jax_fancy_scan(
      simple_scan_fn,
      *args,
      length=length,
      reverse=reverse,
      unroll=unroll,
      _split_transpose=_split_transpose,
      in_axes=in_axes,
      out_axes=augmented_out_axes,
    )
    # Split the user outputs from the trailing Variable-updates pytree.
    if out_is_tuple:
      n = len(out_axes)
      out = result[:n]
      updates = result[n]
    else:
      out, updates = result
    # Write the collected updates back onto the non-carry input Variables.
    masked_args = extract.mask_at(args, carry_arg_index)
    extract.apply_variable_updates(masked_args, updates)
    if carry_arg_index is not None:
      carry_in = args[carry_arg_index]
      carry_out = (
        out[carry_out_index] if out_is_tuple else out
      )
      # Propagate the final carry values onto the original carry input, then
      # return that (updated) input object in the carry output slot.
      extract.update_carry_variables(carry_in, carry_out)
      if out_is_tuple:
        out_list = list(out)
        out_list[carry_out_index] = carry_in
        out = tuple(out_list)
      else:
        out = carry_in
    if graph:
      out = extract.from_tree2(out)
    return out

  return simple_scan_wrapper


def _graph_updates_scan(
  f,
  f_unbound,
  *,
  in_axes,
  out_axes,
  length,
  reverse,
  unroll,
  _split_transpose,
  transform_metadata,
):
  """Scan implementation for graph-mode with structural update propagation.

  Splits graph nodes into carry/broadcast/vectorized states via
  ``extract.to_tree``, runs :class:`ScanFn` under ``jax.lax.scan``, and merges
  the results back, updating the original input references.
  """
  _check_out_axes(out_axes)
  input_carry_argnum = _get_carry_argnum(in_axes, is_in_axes=True)
  output_carry_argnum = _get_carry_argnum(out_axes, is_in_axes=False)
  # Carry must appear on both sides or on neither.
  if (input_carry_argnum is None and output_carry_argnum is not None) or (
    input_carry_argnum is not None and output_carry_argnum is None
  ):
    raise ValueError(
      'If one of in_axes or out_axes has Carry, the other must also have Carry. '
      f'Got {in_axes=!r} and {out_axes=!r}'
    )
  scan_fn = ScanFn(
    f_unbound,
    input_carry_argnum,
    output_carry_argnum,
    in_axes,
    out_axes,
    transform_metadata,
  )

  @functools.wraps(f)
  @graphlib.update_context('scan')
  def scan_wrapper(*args, **kwargs):
    args = resolve_kwargs(f, args, kwargs)
    if in_axes is Carry and len(args) != 1:
      raise ValueError(
        f'When in_axes=Carry, the function must take exactly one argument, '
        f'got {len(args)} arguments.'
      )
    # Side-channel deques populated by _scan_split_in during to_tree; they
    # carry per-node carry/broadcast states and graphdefs around lax.scan.
    graphdefs_deque = PytreeDeque()
    carry_deque = PytreeDeque()
    broadcast_deque = PytreeDeque()
    broadcast_arrays = PytreeDeque()
    pure_args: tuple = extract.to_tree(
      args,
      prefix=in_axes,
      split_fn=functools.partial(
        _scan_split_in,
        carry_deque,
        graphdefs_deque,
        broadcast_deque,
        broadcast_arrays,
      ),
      map_non_graph_nodes=True,
      ctxtag='scan',
    )
    # Pull the carry argument out of pure_args; it travels in the lax.scan
    # carry slot instead of the scanned inputs.
    if isinstance(input_carry_argnum, int):
      pure_carry_arg = pure_args[input_carry_argnum]
      _pure_args = list(pure_args)
      _pure_args[input_carry_argnum] = None
      pure_args = tuple(_pure_args)
    elif input_carry_argnum == 'all':
      pure_carry_arg = pure_args[0]
      pure_args = ()
    else:
      assert input_carry_argnum is None
      pure_carry_arg = None
    carry = (pure_carry_arg, carry_deque, broadcast_deque, broadcast_arrays)
    scan_in = (graphdefs_deque, pure_args)
    carry_out, scan_out = jax.lax.scan(
      scan_fn,
      carry,
      scan_in,
      length=length,
      reverse=reverse,
      unroll=unroll,
      _split_transpose=_split_transpose,
    )
    (
      pure_carry_arg_out,
      carry_deque_out,
      broadcast_deque_out,
      broadcast_arrays_out,
    ) = carry_out
    (
      graphdefs_out,
      pure_args_out,
      pure_out,
    ) = scan_out
    # Re-insert the output carry into the pure args before merging.
    if input_carry_argnum == 'all':
      pure_args_out = (pure_carry_arg_out,)
    elif isinstance(input_carry_argnum, int):
      _pure_args_out = list(pure_args_out)
      _pure_args_out[input_carry_argnum] = pure_carry_arg_out
      pure_args_out = tuple(_pure_args_out)
    else:
      assert input_carry_argnum is None
      assert pure_carry_arg_out is None
    args_out, out = extract.from_tree(
      (pure_args_out, pure_out),
      prefix=(in_axes, out_axes),
      merge_fn=functools.partial(
        _scan_merge_out,
        carry_deque_out,
        graphdefs_out,
        broadcast_deque_out,
      ),
      is_leaf=lambda x: isinstance(x, (extract.NodeStates, Broadcasted)),
      map_non_graph_nodes=True,
      ctxtag='scan',
      is_inner=False,
    )
    # Recover the merged carry object from the merged args so the output
    # carry slot aliases the (updated) input reference.
    if input_carry_argnum == 'all':
      carry_arg = args_out[0]
    elif isinstance(input_carry_argnum, int):
      carry_arg = args_out[input_carry_argnum]
    else:
      assert input_carry_argnum is None
      carry_arg = None
    if output_carry_argnum == 'all':
      out = carry_arg
    elif isinstance(output_carry_argnum, int):
      _out = list(out)
      _out[output_carry_argnum] = carry_arg
      out = tuple(_out)
    else:
      assert output_carry_argnum is None
      assert carry_arg is None
    return out

  return scan_wrapper


def pure_jax_fancy_scan(
  f,
  *args,
  length: int | None = None,
  reverse: bool = False,
  unroll: int | bool = 1,
  _split_transpose: bool = False,
  in_axes: tp.Any = (Carry, 0),
  out_axes: tp.Any = (Carry, 0),
):
  """``jax.lax.scan`` wrapper allowing Carry/None/int axes at any position.

  Flattens ``args``, classifies each leaf as carry (``Carry``), broadcast
  (``None``), or scanned (``int`` axis, moved to axis 0), runs ``lax.scan``
  over only the scanned leaves, and reassembles the output pytree with the
  requested ``out_axes``. Broadcast outputs (``None`` in out_axes) are
  rejected.
  """
  if in_axes is Carry:
    in_axes = (Carry,)
  is_axis_leaf = lambda x: x is None or x is Carry
  # Carry may only appear as a top-level in_axes/out_axes entry.
  if isinstance(in_axes, tuple):
    for i, ax in enumerate(in_axes):
      if ax is Carry or ax is None or isinstance(ax, int):
        continue
      for leaf in jax.tree.leaves(ax, is_leaf=is_axis_leaf):
        if leaf is Carry:
          raise ValueError(
            'Carry must be a top-level argument, it cannot be nested. '
            f'Found Carry inside in_axes[{i}]={ax}'
          )
  if isinstance(out_axes, tuple):
    for i, ax in enumerate(out_axes):
      if ax is Carry or ax is None or isinstance(ax, int):
        continue
      for path, leaf in jax.tree_util.tree_leaves_with_path(
        ax,
        is_leaf=is_axis_leaf,
      ):
        if leaf is Carry:
          raise ValueError(
            'Carry must be a top-level argument, it cannot be nested. '
            f'Found Carry at out_axes[{i}]{jax.tree_util.keystr(path)}'
          )
  in_has_carry = in_axes is Carry or (
    isinstance(in_axes, tuple) and Carry in in_axes
  )
  out_has_carry = out_axes is Carry or (
    isinstance(out_axes, tuple) and Carry in out_axes
  )
  if in_has_carry != out_has_carry:
    raise ValueError(
      'If one of in_axes or out_axes has Carry, the other must also '
      f'have Carry. Got {in_axes=}, {out_axes=}'
    )
  args_flat, args_treedef = jax.tree.flatten(args)
  _, in_axes_flat = extract.broadcast_prefix2(
    in_axes,
    args,
    is_leaf=is_axis_leaf,
  )
  # Classify every flattened input leaf by its axis spec.
  carry_indices: list[int] = []
  broadcast_indices: list[int] = []
  scan_indices: list[int] = []
  scan_in_axes: list[int] = []
  carry_leaves: list[tp.Any] = []
  broadcast_leaves: list[tp.Any] = []
  scan_leaves: list[tp.Any] = []
  for i, (leaf, ax) in enumerate(zip(args_flat, in_axes_flat, strict=True)):
    if ax is Carry:
      carry_indices.append(i)
      carry_leaves.append(leaf)
    elif ax is None:
      broadcast_indices.append(i)
      broadcast_leaves.append(leaf)
    elif isinstance(ax, int):
      scan_indices.append(i)
      scan_in_axes.append(ax)
      # lax.scan always scans axis 0; move the requested axis there.
      if ax != 0:
        leaf = jnp.moveaxis(leaf, ax, 0)
      scan_leaves.append(leaf)
    else:
      raise ValueError(f'Invalid in_axes leaf value: {ax}')
  n_in = len(args_flat)
  # out_info is populated exactly once while tracing body_fn, then read after
  # lax.scan to reconstruct the output pytree: (treedef, carry idx, scan idx,
  # scan axes).
  out_info: list[tuple[
    jax.tree_util.PyTreeDef,
    list[int],
    list[int],
    list[int],
  ]] = []
  in_broadcast = jax.tree.map(lambda x: x, broadcast_leaves)

  def body_fn(carry_state, scan_x):
    # Rebuild the original flat argument list from the three leaf groups.
    flat = [None] * n_in
    for idx, j in enumerate(carry_indices):
      flat[j] = carry_state[idx]
    for idx, j in enumerate(broadcast_indices):
      flat[j] = in_broadcast[idx]
    if scan_x is not None:
      for idx, j in enumerate(scan_indices):
        flat[j] = scan_x[idx]
    reconstructed = args_treedef.unflatten(flat)
    out = f(*reconstructed)
    out_flat, out_treedef = jax.tree.flatten(out)
    out_axes_paths, out_axes_flat = extract.broadcast_prefix2(
      out_axes,
      out,
      is_leaf=is_axis_leaf,
    )
    if not out_info:
      # First (tracing) execution: classify the output leaves.
      out_carry_idx = []
      out_scan_idx = []
      out_scan_axes = []
      out_broadcast_idx = []
      for j, oax in enumerate(out_axes_flat):
        if oax is Carry:
          out_carry_idx.append(j)
        elif oax is None:
          out_broadcast_idx.append(j)
        elif isinstance(oax, int):
          out_scan_idx.append(j)
          out_scan_axes.append(oax)
        else:
          raise ValueError(f'Invalid out_axes leaf value: {oax}')
      if out_broadcast_idx:
        broadcast_paths = [
          jax.tree_util.keystr(out_axes_paths[j]) for j in out_broadcast_idx
        ]
        broadcast_str = "\n\n ".join(broadcast_paths)
        raise ValueError(
          'Scan does not support broadcast outputs (None axis). The following '
          f'output leaves are broadcast:\n\n {broadcast_str}\n'
        )
      out_info.append(
        (out_treedef, out_carry_idx, out_scan_idx, out_scan_axes),
      )
    oci = out_info[0][1]
    osi = out_info[0][2]
    new_carry = [out_flat[j] for j in oci]
    new_ys = [out_flat[j] for j in osi]
    return new_carry, new_ys

  final_carry, stacked_ys = jax.lax.scan(
    body_fn,
    carry_leaves,
    scan_leaves if scan_leaves else None,
    length=length,
    reverse=reverse,
    unroll=unroll,
    _split_transpose=_split_transpose,
  )
  # Tracing body_fn above populated out_info; use it to reassemble outputs.
  out_treedef, out_carry_idx, out_scan_idx, out_scan_axes = (
    out_info[0]
  )
  n_out = out_treedef.num_leaves
  out_flat: list[tp.Any] = [None] * n_out
  for idx, j in enumerate(out_carry_idx):
    out_flat[j] = final_carry[idx]
  for idx, j in enumerate(out_scan_idx):
    y = stacked_ys[idx]
    ax = out_scan_axes[idx]
    # Stacked outputs come out on axis 0; move to the requested axis.
    if ax != 0:
      y = jnp.moveaxis(y, 0, ax)
    out_flat[j] = y
  return out_treedef.unflatten(out_flat)


# -------------------------------
# while_loop
# -------------------------------


@dataclasses.dataclass(eq=False)
class SimpleWhileLoopBodyFn:
  """Body wrapper for the simple (tree-mode / no graph-updates) while_loop.

  Converts between tree- and graph-form around the user function and checks
  that the output keeps the same Variables as the input carry.
  """

  f: tp.Callable[..., tp.Any]
  graph: bool

  def __post_init__(self):
    functools.update_wrapper(self, self.f, updated=())

  @extract.treemap_copy_args
  def __call__(self, val):
    val_variables, _ = extract.updates_and_snapshot(val)
    if self.graph:
      val = extract.from_tree2(val)
    out = self.f(val)
    if self.graph:
      out = extract.to_tree2(out)
    # while_loop requires a structurally stable carry.
    extract.check_same_variables(val_variables, out, 'while_loop')
    return out


@dataclasses.dataclass(eq=False)
class SimpleWhileLoopCondFn:
  """Condition wrapper for the simple while_loop path."""

  f: tp.Callable[..., tp.Any]
  graph: bool

  def __post_init__(self):
    functools.update_wrapper(self, self.f)

  def __call__(self, val):
    if self.graph:
      val = extract.from_tree2(val)
    return self.f(val)


@dataclasses.dataclass(eq=False)
class WhileLoopCondFn:
  """Condition wrapper for the graph-updates while_loop path."""

  f: tp.Callable[..., tp.Any]

  def __post_init__(self):
    functools.update_wrapper(self, self.f)

  def __call__(self, pure_val):
    val = extract.from_tree(pure_val)
    out = self.f(val)
    return out


def _reconsile_index_mapping(tree_to_fix, example_tree):
  """Copy the outer-index mapping of ``example_tree``'s graphdefs onto
  the matching NodeStates leaves of ``tree_to_fix``."""

  def f(a, b):
    if not isinstance(a, extract.NodeStates) or not isinstance(
      a._graphdef, graphlib.GraphDef
    ):
      return a
    return dataclasses.replace(
      a, _graphdef=a._graphdef.with_matching_outer_index(b._graphdef)
    )

  return jax.tree.map(
    f,
    tree_to_fix,
    example_tree,
    is_leaf=lambda x: isinstance(x, extract.NodeStates),
  )


def _add_fake_index_mapping(tree: tp.Any):
  """Give every NodeStates leaf a self-referential outer index so the loop
  input's pytree structure matches the body output's."""

  def per_node_state(node_state: extract.NodeStates | tp.Any):
    if not isinstance(node_state, extract.NodeStates) or not isinstance(
      node_state._graphdef, graphlib.GraphDef
    ):
      return node_state
    return dataclasses.replace(
      node_state, _graphdef=node_state._graphdef.with_same_outer_index()
    )

  return jax.tree.map(
    per_node_state, tree, is_leaf=lambda x: isinstance(x, extract.NodeStates)
  )


def _remove_index_mapping(tree: tp.Any):
  """Remove a fake outer_index for the input to match that of the output."""

  def per_node_state(node_state: extract.NodeStates | tp.Any):
    if not isinstance(node_state, extract.NodeStates) or not isinstance(
      node_state._graphdef, graphlib.GraphDef
    ):
      return node_state
    assert isinstance(node_state._graphdef, graphlib.GraphDef)
    node_state = dataclasses.replace(
      node_state, _graphdef=node_state._graphdef.with_no_outer_index()
    )
    return node_state

  return jax.tree.map(
    per_node_state, tree, is_leaf=lambda x: isinstance(x, extract.NodeStates)
  )


@dataclasses.dataclass(eq=False)
class WhileLoopBodyFn:
  """Body wrapper for the graph-updates while_loop path.

  Merges the pure carry into graph objects, runs the user body, splits the
  result back, and verifies the carry's pytree/reference structure did not
  change inside the body.
  """

  f: tp.Callable[..., tp.Any]

  def __post_init__(self):
    functools.update_wrapper(self, self.f)

  @graphlib.update_context('while_loop_body')
  def __call__(self, pure_val):
    # Removing the dummy index mapping being added outside of body function.
    pure_val_in = _remove_index_mapping(pure_val)
    val = extract.from_tree(
      pure_val_in, ctxtag='while_loop_body', is_inner=True
    )
    out = self.f(val)
    pure_out = extract.to_tree(out, ctxtag='while_loop_body')
    # Structure check: input and output carries must be the same pytree.
    try:
      jax.tree.map(lambda a, b: None, pure_val, pure_out)
    except ValueError as e:
      msg = (
        "nnx.while_loop requires body function's input and output to "
        'have the same reference and pytree structure, but they differ. '
        'If the mismatch comes from `outer_index` field, you might '
        'have modified reference structure within the body function, '
        'which is not allowed.'
        f'Detail of the mismatch: \n {str(e)}'
      )
      raise ValueError(msg)
    return pure_out
@graphlib.update_context('while_loop')
def while_loop(cond_fun: tp.Callable[[T], tp.Any],
               body_fun: tp.Callable[[T], T],
               init_val: T,
               *,
               graph: bool | None = None,
               graph_updates: bool | None = None) -> T:
  """A Flax NNX transformation of `jax.lax.while_loop
  <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html>`_.

  Caution: for the NNX internal reference tracing mechanism to work, you
  cannot change the variable reference structure of ``init_val`` inside
  ``body_fun``.

  Example::

    >>> import jax
    >>> from flax import nnx
    >>> def fwd_fn(input):
    ...   module, x, count = input
    ...   return module, module(x), count - 1.0

    >>> module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
    >>> x = jax.random.normal(jax.random.key(0), (10,))
    >>> # `module` will be called three times
    >>> _, y, _ = nnx.while_loop(
    ...   lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0))

  Args:
    cond_fun: A function for the continue condition of the while loop,
      taking a single input of type ``T`` and outputting a boolean.
    body_fun: A function that takes an input of type ``T`` and outputs an
      ``T``. Note that both data and modules of ``T`` must have the same
      reference structure between inputs and outputs.
    init_val: The initial input for ``cond_fun`` and ``body_fun``. Must be
      of type ``T``.
    graph: if True, use graph-mode (default). If False, use tree-mode.
      If None, uses the value of ``nnx_graph_mode`` config.
    graph_updates: If ``True``, propagates updates on graph structure that
      happen inside the transform to the input graphs, has no effect when
      ``graph=False``.
  """
  # Fall back to the global configuration for any unspecified flag.
  use_graph = (
    graphlib.set_graph_mode.current_value() if graph is None else graph
  )
  use_graph_updates = (
    graphlib.set_graph_updates.current_value()
    if graph_updates is None
    else graph_updates
  )

  if use_graph and use_graph_updates:
    # Graph-updates path: split the carry into a pure pytree, stamp a dummy
    # outer-index mapping so input/output structures line up, loop, and merge
    # the final pure carry back into graph objects.
    pure_init_val = extract.to_tree(init_val, ctxtag='while_loop')
    pure_init_val = _add_fake_index_mapping(pure_init_val)
    pure_out = jax.lax.while_loop(
      WhileLoopCondFn(cond_fun),
      WhileLoopBodyFn(body_fun),
      pure_init_val,
    )
    return extract.from_tree(pure_out, ctxtag='while_loop', is_inner=False)

  # Simple path (tree-mode, or graph-mode without structural updates).
  loop_body = SimpleWhileLoopBodyFn(body_fun, graph=use_graph)
  loop_cond = SimpleWhileLoopCondFn(cond_fun, graph=use_graph)
  carry = extract.to_tree2(init_val) if use_graph else init_val
  result = jax.lax.while_loop(loop_cond, loop_body, carry)
  # Write the final carry values back onto the input Variables.
  result = extract.update_carry_variables(carry, result)
  if use_graph:
    result = extract.from_tree2(result)
  return result
@dataclasses.dataclass(eq=False)
class SimpleForiLoopBodyFn:
  """Tree-mode wrapper for a user ``fori_loop`` body function.

  Converts the carry between its tree representation and graph objects around
  each call to ``f`` (only when ``graph=True``), and checks that the body does
  not change the carry's variable set.
  """

  f: tp.Callable[..., tp.Any]  # the user body function: (i, val) -> val
  graph: bool  # whether values cross the boundary as graph objects

  def __post_init__(self):
    # updated=() prevents f.__dict__ from being merged into this instance's
    # __dict__, which could otherwise clobber the dataclass fields.
    functools.update_wrapper(self, self.f, updated=())

  @extract.treemap_copy_args
  def __call__(self, i, val):
    # Snapshot the carry's variables before the call so we can verify the
    # body returned the same variable set.
    val_variables, _ = extract.updates_and_snapshot(val)
    if self.graph:
      val = extract.from_tree2(val)
    out = self.f(i, val)
    if self.graph:
      out = extract.to_tree2(out)
    extract.check_same_variables(val_variables, out, 'fori_loop')
    return out


@dataclasses.dataclass(eq=False)
class ForiLoopBodyFn:
  """Graph-updates wrapper for a user ``fori_loop`` body function.

  Runs the body under the ``'fori_loop_body'`` update context, converting the
  pure carry to graph objects on entry and back to a pure tree on exit.
  """

  f: tp.Callable[..., tp.Any]  # the user body function: (i, val) -> val

  def __post_init__(self):
    # Fix: pass updated=() (as SimpleForiLoopBodyFn does) so that
    # functools.update_wrapper does not merge f.__dict__ into this instance's
    # __dict__ — without it, a wrapped callable carrying an attribute named
    # ``f`` (e.g. another wrapper object) would silently overwrite this field.
    functools.update_wrapper(self, self.f, updated=())

  @graphlib.update_context('fori_loop_body')
  def __call__(self, i, pure_val_in):
    val = extract.from_tree(pure_val_in, ctxtag='fori_loop_body', is_inner=True)
    out = self.f(i, val)
    pure_out = extract.to_tree(out, ctxtag='fori_loop_body')
    return pure_out
@graphlib.update_context('fori_loop')
def fori_loop(
  lower: int,
  upper: int,
  body_fun: tp.Callable[[int, T], T],
  init_val: T,
  *,
  unroll: int | bool | None = None,
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> T:
  """A Flax NNX transformation of `jax.lax.fori_loop
  <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html>`_.

  Caution: for the NNX internal reference tracing mechanism to work, you cannot
  change the variable reference structure of `init_val` inside `body_fun`.

  Example::

    >>> import jax
    >>> from flax import nnx
    >>> def fwd_fn(i, input):
    ...   m, x = input
    ...   m.kernel[...] = jnp.identity(10) * i
    ...   return m, m(x)

    >>> module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
    >>> x = jax.random.normal(jax.random.key(0), (10,))
    >>> _, y = nnx.fori_loop(2, 4, fwd_fn, (module, x))
    >>> np.testing.assert_array_equal(y, x * 2 * 3)

  Args:
    lower: An integer representing the loop index lower bound (inclusive).
    upper: An integer representing the loop index upper bound (exclusive).
    body_fun: a function that takes an input of type ``T`` and outputs an ``T``.
      Note that both data and modules of ``T`` must have the same reference
      structure between inputs and outputs.
    init_val: the initial input for body_fun. Must be of type ``T``.
    unroll: An optional integer or boolean that determines how much to unroll
      the loop. If an integer is provided, it determines how many unrolled loop
      iterations to run within a single rolled iteration of the loop. If a
      boolean is provided, it will determine if the loop is completely unrolled
      (i.e. ``unroll=True``) or left completely rolled (i.e. ``unroll=False``).
      This argument is only applicable if the loop bounds are statically known.
    graph: if True, use graph-mode (default). If False, use tree-mode. If None,
      uses the value of ``nnx_graph_mode`` config.
    graph_updates: If ``True``, propagates updates on graph structure that
      happen inside the transform to the input graphs, has no effect when
      ``graph=False``.

  Returns:
    A loop value from the final iteration, of type ``T``.
  """
  # Resolve None flags from the active NNX config contexts.
  if graph is None:
    graph = graphlib.set_graph_mode.current_value()
  if graph_updates is None:
    graph_updates = graphlib.set_graph_updates.current_value()
  if not graph or not graph_updates:
    # Simple path: no graph-structure update propagation. The body is wrapped
    # so it converts between tree and graph representations per iteration
    # (only when graph=True).
    simple_body_fn = SimpleForiLoopBodyFn(body_fun, graph=graph)
    if graph:
      init_val = extract.to_tree2(init_val)
    val_out = jax.lax.fori_loop(
      lower, upper, simple_body_fn, init_val, unroll=unroll,
    )
    # Copy the loop's output values back onto the carry's Variables so the
    # caller-visible objects reflect the final iteration.
    val_out = extract.update_carry_variables(init_val, val_out)
    if graph:
      val_out = extract.from_tree2(val_out)
    return val_out
  # Graph-updates path: trace the body once with eval_shape (no FLOPs) to
  # learn the output tree, then reconcile the input's index mapping against
  # it before running the real loop.
  pure_init_val = extract.to_tree(init_val, ctxtag='fori_loop')
  body = ForiLoopBodyFn(body_fun)
  pure_out = jax.eval_shape(body, lower, pure_init_val)
  pure_init_val = _reconsile_index_mapping(pure_init_val, pure_out)
  pure_out = jax.lax.fori_loop(lower, upper, body, pure_init_val, unroll=unroll)
  out = extract.from_tree(pure_out, ctxtag='fori_loop', is_inner=False)
  return out