# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pytype: skip-file
from collections import deque
import dataclasses
import functools
import typing as tp
from flax import struct
from flax import typing
from flax.core.frozen_dict import FrozenDict
from flax.nnx import extract, filterlib, graphlib, spmd, variablelib
from flax.nnx import statelib
from flax.nnx.module import Module
from flax.nnx.statelib import State
from flax.nnx.transforms.transforms import (
resolve_kwargs,
_resolve_bound_callable,
_raise_bound_method_error,
)
from flax.typing import Leaf, Missing, PytreeDeque
import jax
import jax.core
import jax.numpy as jnp
import jax.stages
import numpy as np
# Generic type variables used by the transform signatures in this module.
A = tp.TypeVar('A')
C = tp.TypeVar('C')
B = tp.TypeVar('B')
# F and G are bound to callables so decorators can preserve the wrapped
# function's type in annotations.
F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])
G = tp.TypeVar('G', bound=tp.Callable[..., tp.Any])
# M, MA and N are bound to Module for graph-node-typed APIs.
M = tp.TypeVar('M', bound=Module)
MA = tp.TypeVar('MA', bound=Module)
N = tp.TypeVar('N', bound=Module)
T = tp.TypeVar('T')
StrInt = tp.TypeVar('StrInt', str, int)
AxisName = tp.Hashable  # same contract as jax's axis_name arguments
Leaves = list[Leaf]
Index = int  # an integer axis index
# NOTE: removed a stray "[docs]" Sphinx-anchor artifact that had fused into
# the class statement and made it invalid Python.
class Carry:
  """Helper class for the :func:`flax.nnx.scan` function to mark an input and
  output axis as the carry.
  """

  pass
# -------------------------------
# transform_metadata
# -------------------------------
def _apply_axis_fn(
  tree: tp.Any,
  axes: tp.Any,
  metadata: tp.Mapping[str, tp.Any],
  axis_fn: tp.Callable[..., tp.Any],
) -> None:
  """Calls ``axis_fn(variable, axis, metadata)`` for each Variable leaf of
  ``tree`` whose broadcast axis spec is an integer. Mutates in place via
  ``axis_fn``; returns nothing."""

  def _is_leaf(node):
    return node is None or isinstance(node, variablelib.Variable)

  # Broadcast the (possibly prefix-shaped) axes spec against the tree so each
  # leaf gets its own axis value.
  _, per_leaf_axes = extract.broadcast_prefix2(axes, tree, is_leaf=_is_leaf)
  tree_leaves = jax.tree_util.tree_leaves(tree, is_leaf=_is_leaf)
  for leaf, axis in zip(tree_leaves, per_leaf_axes):
    if isinstance(leaf, variablelib.Variable) and isinstance(axis, int):
      axis_fn(leaf, axis, metadata)
# Overload: decorator-factory form, e.g. @transform_metadata(partition=...).
@tp.overload
def transform_metadata(
  *,
  in_axes: tp.Any = 0,
  out_axes: tp.Any = 0,
  partition: str,
  graph: bool | None = None,
) -> tp.Callable[[F], F]:
  ...
# Overload: direct form, transform_metadata(f, partition=...).
@tp.overload
def transform_metadata(
  f: F,
  *,
  in_axes: tp.Any = 0,
  out_axes: tp.Any = 0,
  graph: bool | None = None,
  partition: str,
) -> F:
  ...
def transform_metadata(
  f: F | type[Missing] = Missing,
  *,
  in_axes: tp.Any = 0,
  out_axes: tp.Any = 0,
  graph: bool | None = None,
  partition: str,
) -> F | tp.Callable[[F], F]:
  """Wraps ``f`` so that per-variable axis metadata named ``partition`` is
  removed from the inputs (per ``in_axes``) before calling ``f`` and re-added
  to inputs and outputs (per ``in_axes``/``out_axes``) afterwards, using
  ``spmd.remove_axis`` / ``spmd.add_axis``.

  Can be used directly (``transform_metadata(f, partition=...)``) or as a
  decorator factory (``@transform_metadata(partition=...)``).
  """
  if f is Missing:
    # Decorator-factory usage: return a partial that receives f later.
    return functools.partial(
      transform_metadata,
      in_axes=in_axes,
      out_axes=out_axes,
      partition=partition,
      graph=graph,
    )  # type: ignore[return-value]
  if graph is None:
    graph = graphlib.set_graph_mode.current_value()
  metadata: tp.Mapping[str, tp.Any] = {
    spmd.PARTITION_NAME: partition,
  }

  @functools.wraps(f)
  def wrapper(*in_args, **in_kwargs):
    in_args = resolve_kwargs(f, in_args, in_kwargs)
    if graph:
      in_args = extract.to_tree2(in_args, prefix=in_axes)
    extract.check_no_aliases('transform_metadata', args=in_args)
    # Clone so metadata edits don't mutate the caller's inputs directly.
    args = graphlib.clone(in_args, graph=graph)
    _apply_axis_fn(args, in_axes, metadata, spmd.remove_axis)
    # Snapshot before the call so only updates made inside f are propagated
    # back to in_args afterwards.
    updates, snapshot = extract.updates_and_snapshot(args)
    if graph:
      args = extract.from_tree2(args)
    out = f(*args)
    if graph:
      out = extract.to_tree2(out, prefix=out_axes)
    extract.check_no_aliases('transform_metadata', args=updates, out=out)
    _apply_axis_fn(args, in_axes, metadata, spmd.add_axis)
    _apply_axis_fn(out, out_axes, metadata, spmd.add_axis)
    updates = extract.mask_variable_updates(updates, snapshot)
    extract.apply_variable_updates(in_args, updates)
    if graph:
      out = extract.from_tree2(out)
    return out

  return wrapper  # type: ignore[return-value]
# -------------------------------
# vmap
# -------------------------------
class StateAxes(extract.PrefixMapping, tp.Mapping):
  """Maps substate filters to vmap/pmap axis specifications.

  Each entry pairs a Filter with an integer axis, ``Carry``, or ``None``
  (broadcast), controlling how the matching substate of a graph node is
  transformed.
  """

  def __init__(
    self,
    filter_axes: (
      statelib.State
      | tp.Mapping[filterlib.Filter, Index | type[Carry] | None]
      | tp.Iterable[tuple[filterlib.Filter, Index | type[Carry] | None]]
    ),
    /,
  ):
    if isinstance(filter_axes, statelib.State):
      filter_axes = statelib.create_path_filters(filter_axes)  # type: ignore
    if isinstance(filter_axes, tp.Mapping):
      pairs = tuple(filter_axes.items())
    else:
      pairs = tuple(filter_axes)
    self._filters = tuple(f for f, _ in pairs)
    self._axes = tuple(a for _, a in pairs)

  @property
  def filters(self) -> tuple[filterlib.Filter, ...]:
    return self._filters

  @property
  def axes(self) -> tuple[Index | type[Carry] | None, ...]:
    return self._axes

  def map_prefix(
    self, path: typing.PathParts, variable: variablelib.Variable
  ) -> tp.Any:
    # First matching filter wins.
    for flt, axis in zip(self._filters, self._axes):
      if filterlib.to_predicate(flt)(path, variable):
        return axis
    raise ValueError(f'No axis found for {path=}, {variable=}')

  def items(self):
    return zip(self._filters, self._axes)

  def __getitem__(self, key):
    return self._axes[self._filters.index(key)]

  def __iter__(self):
    return iter(self._filters)

  def __len__(self):
    return len(self._filters)

  def __repr__(self):
    return f'StateAxes({dict(self.items())})'

  def __eq__(self, other):
    if not isinstance(other, StateAxes):
      return False
    return self._filters == other.filters and self._axes == other.axes

  def __hash__(self):
    return hash((self._filters, self._axes))
# Signature of the metadata-update callbacks (spmd.add_axis /
# spmd.remove_axis): (state_or_variable, axis_index, transform_metadata) ->
# updated state_or_variable.
AxisFn = tp.Callable[
  [graphlib.GraphState | variablelib.Variable, int, tp.Mapping],
  graphlib.GraphState | variablelib.Variable,
]
def _update_variable_sharding_metadata(
  tree, transform_metadata, axis_fn: AxisFn
):
  """Applies ``axis_fn`` to the states of every ``extract.NodeStates`` leaf in
  ``tree`` whose metadata is an int axis or a ``StateAxes``; other leaves are
  returned untouched."""

  def _update_axes_fn(node_states):
    if not isinstance(node_states, extract.NodeStates):
      return node_states
    meta = node_states.metadata
    if isinstance(meta, int):
      # A single state vectorized along one axis.
      state = node_states.state
      assert isinstance(state, State | variablelib.Variable)
      updated = axis_fn(state, meta, transform_metadata)
      return node_states.replace(states=(updated,))
    if isinstance(meta, StateAxes):
      # One state per filter; only integer axes carry sharding metadata.
      new_states: list[graphlib.GraphState | variablelib.Variable] = []
      for state, axis in zip(node_states.states, meta.axes):
        assert isinstance(state, graphlib.State | variablelib.Variable)
        if isinstance(axis, int):
          state = axis_fn(state, axis, transform_metadata)
        new_states.append(state)
      return node_states.replace(states=tuple(new_states))
    return node_states

  return jax.tree.map(
    _update_axes_fn, tree, is_leaf=lambda x: isinstance(x, extract.NodeStates)
  )
def _vmap_split_fn(ctx: graphlib.SplitContext, path, prefix, x):
  """Split callback used by ``extract.to_tree`` for vmap/pmap graph nodes."""
  if isinstance(prefix, StateAxes):
    # Split per-filter so each substate can receive its own axis spec.
    graphdef, *states = ctx.split(x, *prefix.filters)
    return extract.NodeStates.from_split(graphdef, *states, metadata=prefix)
  graphdef, state = ctx.split(x)
  return extract.NodeStates.from_split(graphdef, state, metadata=prefix)
@dataclasses.dataclass(eq=False)
class SimpleVmapFn:
  """Callable handed to ``jax.vmap`` on the tree-mode / no-graph-updates path.

  Returns ``(out, updates)`` so the outer wrapper can apply variable updates
  back to the caller's inputs after the transform.
  """

  f: tp.Callable[..., tp.Any]
  graph: bool
  out_axes: tp.Any

  def __post_init__(self):
    functools.update_wrapper(self, self.f, updated=())

  @extract.treemap_copy_args
  def __call__(self, *args, **kwargs):
    # Snapshot inputs before calling f so that only changes made inside f are
    # reported as updates (masking below).
    updates, snapshot = extract.updates_and_snapshot((args, kwargs))
    if self.graph:
      args, kwargs = extract.from_tree2((args, kwargs))
    out = self.f(*args, **kwargs)
    if self.graph:
      out = extract.to_tree2(out, prefix=self.out_axes)
    extract.check_no_aliases('vmap', args=updates[0], kwargs=updates[1], out=out)
    updates = extract.mask_variable_updates(updates, snapshot)
    return out, updates
@dataclasses.dataclass(eq=False)
class SimplePmapFn:
  """Callable handed to ``jax.pmap`` on the tree-mode / no-graph-updates path.

  Mirrors ``SimpleVmapFn``: returns ``(out, updates)`` so the outer wrapper
  can apply variable updates back to the caller's inputs.
  """

  f: tp.Callable[..., tp.Any]
  graph: bool
  out_axes: tp.Any

  def __post_init__(self):
    functools.update_wrapper(self, self.f, updated=())

  @extract.treemap_copy_args
  def __call__(self, *args, **kwargs):
    # Snapshot inputs before calling f so that only changes made inside f are
    # reported as updates (masking below).
    updates, snapshot = extract.updates_and_snapshot((args, kwargs))
    if self.graph:
      args, kwargs = extract.from_tree2((args, kwargs))
    out = self.f(*args, **kwargs)
    if self.graph:
      out = extract.to_tree2(out, prefix=self.out_axes)
    extract.check_no_aliases('pmap', args=updates[0], kwargs=updates[1], out=out)
    updates = extract.mask_variable_updates(updates, snapshot)
    return out, updates
@dataclasses.dataclass(eq=False)
class VmapFn:
  """Graph-mode callable handed to ``jax.vmap``; consumes and produces pure
  pytrees built by ``extract.to_tree`` / ``extract.from_tree``."""

  f: tp.Callable[..., tp.Any]
  transform_metadata: tp.Mapping[str, tp.Any]
  in_axes: tp.Any
  out_axes: tp.Any

  def __post_init__(self):
    functools.update_wrapper(self, self.f)

  def __call__(self, *pure_args: tuple[tp.Any, ...]):
    if spmd.PARTITION_NAME in self.transform_metadata:
      # Inside vmap the mapped axis doesn't exist, so strip it from the
      # variables' sharding metadata before reconstructing the graph nodes.
      pure_args = _update_variable_sharding_metadata(
        pure_args, self.transform_metadata, spmd.remove_axis
      )
    args = extract.from_tree(pure_args, ctxtag='vmap', is_inner=True)
    out = self.f(*args)
    args_out = extract.clear_non_graph_nodes(args)
    pure_args_out, pure_out = extract.to_tree(
      (args_out, out),
      prefix=(self.in_axes, self.out_axes),
      split_fn=_vmap_split_fn,
      ctxtag='vmap',
    )
    if spmd.PARTITION_NAME in self.transform_metadata:
      # Re-insert the mapped axis into the sharding metadata of the outputs.
      pure_args_out, pure_out = _update_variable_sharding_metadata(
        (pure_args_out, pure_out), self.transform_metadata, spmd.add_axis
      )
    return pure_args_out, pure_out
# Overload: decorator-factory form, e.g. @nnx.vmap(in_axes=..., out_axes=...).
@tp.overload
def vmap(
  *,
  in_axes: int | None | tp.Sequence[tp.Any] = 0,
  out_axes: tp.Any = 0,
  axis_name: AxisName | None = None,
  axis_size: int | None = None,
  spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
  # nnx specific
  transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> tp.Callable[[F], F]:
  ...
# Overload: direct form, nnx.vmap(f, ...).
@tp.overload
def vmap(
  f: F,
  *,
  in_axes: int | None | tp.Sequence[tp.Any] = 0,
  out_axes: tp.Any = 0,
  axis_name: AxisName | None = None,
  axis_size: int | None = None,
  spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
  # nnx specific
  transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> F:
  ...
# NOTE: removed a stray "[docs]" Sphinx-anchor artifact that had fused into
# the def statement and made it invalid Python.
def vmap(
  f: F | type[Missing] = Missing,
  *,
  in_axes: int | None | tp.Sequence[tp.Any] = 0,
  out_axes: tp.Any = 0,
  axis_name: AxisName | None = None,
  axis_size: int | None = None,
  spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
  # nnx specific
  transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> F | tp.Callable[[F], F]:
  """Reference-aware version of `jax.vmap <https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html>`__.

  Args:
    f: Function to be mapped over additional axes.
    in_axes: An integer, None, or sequence of values specifying which input
      array axes to map over (see `jax.vmap
      <https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html>`__). In
      addition to integers and None, :class:`StateAxes` can be used to control
      how graph nodes like Modules are vectorized by specifying the axes to be
      applied to substates of the graph node given a `Filter
      <https://flax.readthedocs.io/en/latest/guides/filters_guide.html>`__.
    out_axes: An integer, None, or pytree indicating where the mapped axis
      should appear in the output (see `jax.vmap
      <https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html>`__).
    axis_name: Optional, a hashable Python object used to identify the mapped
      axis so that parallel collectives can be applied.
    axis_size: Optional, an integer indicating the size of the axis to be
      mapped. If not provided, the mapped axis size is inferred from arguments.
    spmd_axis_name: Optional, forwarded to ``jax.vmap``.
    transform_metadata: Optional mapping of metadata for the transform.
    graph: If ``True`` (default), uses graph-mode which supports the full
      NNX feature set including shared references and reference semantics.
      If ``False``, uses tree-mode which treats Modules as regular JAX
      pytrees, avoiding the overhead of the graph protocol. Tree-mode does
      not support ``StateAxes`` or shared ``Variable`` references.
    graph_updates: If ``True``, propagates updates on graph structure
      that happen inside the transform to the input graphs, has no
      effect when ``graph=False``. When ``False``, using ``StateAxes``
      is not supported.

  Returns:
    Batched/vectorized version of ``f`` with arguments that correspond to
    those of ``f``, but with extra array axes at positions indicated by
    ``in_axes``, and a return value that corresponds to that of ``f``, but
    with extra array axes at positions indicated by ``out_axes``.

  Example::

    >>> from flax import nnx
    >>> from jax import random, numpy as jnp
    ...
    >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
    >>> x = jnp.ones((5, 2))
    ...
    >>> @nnx.vmap(in_axes=(None, 0), out_axes=0)
    ... def forward(model, x):
    ...   return model(x)
    ...
    >>> y = forward(model, x)
    >>> y.shape
    (5, 3)

    >>> class LinearEnsemble(nnx.Module):
    ...   def __init__(self, num, rngs):
    ...     self.w = nnx.Param(random.uniform(rngs(), (num, 2, 3)))
    ...
    >>> model = LinearEnsemble(5, rngs=nnx.Rngs(0))
    >>> x = jnp.ones((2,))
    ...
    >>> @nnx.vmap(in_axes=(0, None), out_axes=0)
    ... def forward(model, x):
    ...   return x @ model.w
    ...
    >>> y = forward(model, x)
    >>> y.shape
    (5, 3)

  To control how graph node substates are vectorized, ``StateAxes``
  can be passed to ``in_axes`` and ``out_axes`` specifying the axes to be
  applied to each substate given a filter. The following example shows how to
  share the parameters between the ensemble members while keeping different
  batch statistics and dropout random state::

    >>> class Foo(nnx.Module):
    ...   def __init__(self):
    ...     self.a = nnx.Param(jnp.arange(4))
    ...     self.b = nnx.BatchStat(jnp.arange(4))
    ...
    >>> state_axes = nnx.StateAxes({nnx.Param: 0, nnx.BatchStat: None})
    >>> @nnx.vmap(in_axes=(state_axes,), out_axes=0)
    ... def mul(foo):
    ...   return foo.a * foo.b
    ...
    >>> foo = Foo()
    >>> y = mul(foo)
    >>> y
    Array([[0, 0, 0, 0],
           [0, 1, 2, 3],
           [0, 2, 4, 6],
           [0, 3, 6, 9]], dtype=int32)
  """
  if graph is None:
    graph = graphlib.set_graph_mode.current_value()
  if graph_updates is None:
    graph_updates = graphlib.set_graph_updates.current_value()
  if f is Missing:
    # Decorator-factory usage: return a partial that receives f later.
    return functools.partial(
      vmap,
      in_axes=in_axes,
      out_axes=out_axes,
      axis_name=axis_name,
      axis_size=axis_size,
      spmd_axis_name=spmd_axis_name,
      transform_metadata=transform_metadata,
      graph=graph,
      graph_updates=graph_updates,
    )  # type: ignore[return-value]
  f_unbound, _, was_bound = _resolve_bound_callable(f)
  if was_bound:
    _raise_bound_method_error('vmap')

  if not graph or not graph_updates:
    # Tree-mode (or graph-mode without structural updates): StateAxes relies
    # on the graph protocol and is not supported here.
    if any(isinstance(x, StateAxes) for x in jax.tree.leaves(in_axes)):
      raise ValueError(
        '`in_axes` cannot contain `StateAxes` objects '
        'when `graph=False`. '
        + graphlib._tree_mode_suggestion_transform('vmap')
      )
    if any(isinstance(x, StateAxes) for x in jax.tree.leaves(out_axes)):
      raise ValueError(
        '`out_axes` cannot contain `StateAxes` objects '
        'when `graph=False`. '
        + graphlib._tree_mode_suggestion_transform('vmap')
      )
    vmapped_fn = jax.vmap(
      SimpleVmapFn(f_unbound, graph=graph, out_axes=out_axes),
      in_axes=in_axes,
      # SimpleVmapFn returns (out, updates); updates mirror (args, kwargs).
      out_axes=(out_axes, (in_axes, 0)),
      axis_name=axis_name,
      axis_size=axis_size,
      spmd_axis_name=spmd_axis_name,
    )

    @functools.wraps(f_unbound)
    def simple_vmap_wrapper(*args, **kwargs):
      if graph:
        args, kwargs = extract.to_tree2(
          (args, kwargs),
          prefix=(in_axes, None) if in_axes is not None else None,
          check_aliasing=in_axes is not None,
        )
      extract.check_no_aliases('vmap', args=args, kwargs=kwargs)
      out, updates = vmapped_fn(*args, **kwargs)
      # Propagate variable updates made inside f back to the caller's inputs.
      extract.apply_variable_updates((args, kwargs), updates)
      if graph:
        out = extract.from_tree2(out)
      return out

    return simple_vmap_wrapper  # type: ignore[return-value]

  # Graph-mode path: StateAxes prefixes are expanded into NodeStates pytrees
  # so jax.vmap sees one axis per substate.
  jax_in_axes = jax.tree.map(
    lambda x: extract.NodeStates.from_prefixes(x.axes, metadata=x)
    if isinstance(x, StateAxes)
    else x,
    in_axes,
  )
  jax_out_axes = jax.tree.map(
    lambda x: extract.NodeStates.from_prefixes(x.axes, metadata=x)
    if isinstance(x, StateAxes)
    else x,
    out_axes,
  )
  vmapped_fn = jax.vmap(  # type: ignore[assignment]
    VmapFn(f_unbound, transform_metadata, in_axes, out_axes),
    in_axes=jax_in_axes,
    out_axes=(jax_in_axes, jax_out_axes),
    axis_name=axis_name,
    axis_size=axis_size,
    spmd_axis_name=spmd_axis_name,
  )

  @functools.wraps(f)
  @graphlib.update_context('vmap')
  def vmap_wrapper(*args, **kwargs):
    args = resolve_kwargs(f, args, kwargs)
    pure_args = extract.to_tree(
      args, prefix=in_axes, split_fn=_vmap_split_fn, ctxtag='vmap'
    )
    pure_args_out, pure_out = vmapped_fn(*pure_args)
    _args_out, out = extract.from_tree(
      (pure_args_out, pure_out), ctxtag='vmap', is_inner=False
    )
    return out

  return vmap_wrapper  # type: ignore
# -------------------------------
# pmap
# -------------------------------
@dataclasses.dataclass(eq=False)
class PmapFn:
  """Graph-mode callable handed to ``jax.pmap``; consumes and produces pure
  pytrees built by ``extract.to_tree`` / ``extract.from_tree``. Mirrors
  ``VmapFn`` but uses the 'pmap' update context tag."""

  f: tp.Callable[..., tp.Any]
  transform_metadata: tp.Mapping[str, tp.Any]
  in_axes: tp.Any
  out_axes: tp.Any

  def __post_init__(self):
    functools.update_wrapper(self, self.f)

  def __call__(self, *pure_args: tuple[tp.Any, ...]):
    if spmd.PARTITION_NAME in self.transform_metadata:
      # Inside pmap the mapped axis doesn't exist, so strip it from the
      # variables' sharding metadata before reconstructing the graph nodes.
      pure_args = _update_variable_sharding_metadata(
        pure_args, self.transform_metadata, spmd.remove_axis
      )
    args = extract.from_tree(pure_args, ctxtag='pmap', is_inner=True)
    out = self.f(*args)
    args_out = extract.clear_non_graph_nodes(args)
    pure_args_out, pure_out = extract.to_tree(
      (args_out, out),
      prefix=(self.in_axes, self.out_axes),
      split_fn=_vmap_split_fn,
      ctxtag='pmap',
    )
    if spmd.PARTITION_NAME in self.transform_metadata:
      # Re-insert the mapped axis into the sharding metadata of the outputs.
      pure_args_out, pure_out = _update_variable_sharding_metadata(
        (pure_args_out, pure_out), self.transform_metadata, spmd.add_axis
      )
    return pure_args_out, pure_out
# Overload: decorator-factory form, e.g. @nnx.pmap(axis_name=...).
@tp.overload
def pmap(
  *,
  axis_name: AxisName | None = None,
  in_axes: tp.Any = 0,
  out_axes: tp.Any = 0,
  static_broadcasted_argnums: int | tp.Iterable[int] = (),
  devices: tp.Sequence[jax.Device] | None = None,  # noqa: F811
  backend: str | None = None,
  axis_size: int | None = None,
  donate_argnums: int | tp.Iterable[int] = (),
  global_arg_shapes: tuple[tuple[int, ...], ...] | None = None,
  # nnx specific
  transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> tp.Callable[[F], F]:
  ...
# Overload: direct form, nnx.pmap(f, ...).
@tp.overload
def pmap(
  f: F,
  *,
  axis_name: AxisName | None = None,
  in_axes: tp.Any = 0,
  out_axes: tp.Any = 0,
  static_broadcasted_argnums: int | tp.Iterable[int] = (),
  devices: tp.Sequence[jax.Device] | None = None,  # noqa: F811
  backend: str | None = None,
  axis_size: int | None = None,
  donate_argnums: int | tp.Iterable[int] = (),
  global_arg_shapes: tuple[tuple[int, ...], ...] | None = None,
  # nnx specific
  transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> F:
  ...
def pmap(
  f: F | type[Missing] = Missing,
  *,
  axis_name: AxisName | None = None,
  in_axes: tp.Any = 0,
  out_axes: tp.Any = 0,
  static_broadcasted_argnums: int | tp.Iterable[int] = (),
  devices: tp.Sequence[jax.Device] | None = None,  # noqa: F811
  backend: str | None = None,
  axis_size: int | None = None,
  donate_argnums: int | tp.Iterable[int] = (),
  global_arg_shapes: tuple[tuple[int, ...], ...] | None = None,
  # nnx specific
  transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> F | tp.Callable[[F], F]:
  """Reference-aware version of `jax.pmap <https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html>`__.

  Args:
    f: Function to be mapped over argument axes. Its arguments and return
      value should be arrays, scalars, graph nodes, or (nested) standard
      Python containers (tuple/list/dict) thereof. Positional arguments
      indicated by ``static_broadcasted_argnums`` can be anything at all,
      provided they are hashable and have an equality operation defined.
    axis_name: Optional, a hashable Python object used to identify the mapped
      axis so that parallel collectives can be applied.
    in_axes: A non-negative integer, None, or nested Python container thereof
      that specifies which axes of positional arguments to map over. Arguments
      passed as keywords are always mapped over their leading axis (i.e. axis
      index 0). In addition to integers and None, :class:`StateAxes` can be
      used to control how graph nodes like Modules are vectorized by specifying
      the axes to be applied to substates of the graph node given a `Filter
      <https://flax.readthedocs.io/en/latest/guides/filters_guide.html>`__.
    out_axes: A non-negative integer, None, or nested Python container thereof
      indicating where the mapped axis should appear in the output. All outputs
      with a mapped axis must have a non-None ``out_axes`` specification.
    static_broadcasted_argnums: An int or collection of ints specifying which
      positional arguments to treat as static (compile-time constant).
      Operations that only depend on static arguments will be constant-folded.
      Calling the pmapped function with different values for these constants
      will trigger recompilation. If the pmapped function is called with fewer
      positional arguments than indicated by ``static_broadcasted_argnums`` then
      an error is raised. Each of the static arguments will be broadcasted to
      all devices. Arguments that are not arrays or containers thereof must be
      marked as static. Defaults to ().
      Static arguments must be hashable, meaning both ``__hash__`` and
      ``__eq__`` are implemented, and should be immutable.
    devices: This is an experimental feature and the API is likely to change.
      Optional, a sequence of Devices to map over. (Available devices can be
      retrieved via jax.devices()). Must be given identically for each process
      in multi-process settings (and will therefore include devices across
      processes). If specified, the size of the mapped axis must be equal to
      the number of devices in the sequence local to the given process. Nested
      ``pmap`` s with ``devices`` specified in either the inner or outer
      ``pmap`` are not yet supported.
    backend: This is an experimental feature and the API is likely to change.
      Optional, a string representing the XLA backend. 'cpu', 'gpu', or 'tpu'.
    axis_size: Optional; the size of the mapped axis.
    donate_argnums: Specify which positional argument buffers are "donated" to
      the computation. It is safe to donate argument buffers if you no longer
      need them once the computation has finished. In some cases XLA can make
      use of donated buffers to reduce the amount of memory needed to perform a
      computation, for example recycling one of your input buffers to store a
      result. You should not reuse buffers that you donate to a computation,
      JAX will raise an error if you try to. Note that donate_argnums only
      work for positional arguments, and keyword arguments will not be donated.
    global_arg_shapes: Optional, forwarded to ``jax.pmap``.
    transform_metadata: Optional mapping of metadata for the transform.
    graph: if True, use graph-mode (default). If False, use tree-mode.
      If None, uses the value of ``nnx_graph_mode`` config.
    graph_updates: If ``True``, propagates updates on graph structure
      that happen inside the transform to the input graphs, has no
      effect when ``graph=False``. When ``False``, using ``StateAxes``
      is not supported.

  Returns:
    A parallelized version of ``f`` with arguments that correspond to those of
    ``f`` but with extra array axes at positions indicated by ``in_axes`` and
    with output that has an additional leading array axis (with the same size).
  """
  if graph is None:
    graph = graphlib.set_graph_mode.current_value()
  if graph_updates is None:
    graph_updates = graphlib.set_graph_updates.current_value()
  if f is Missing:
    # Decorator-factory usage: return a partial that receives f later.
    return functools.partial(
      pmap,
      axis_name=axis_name,
      in_axes=in_axes,
      out_axes=out_axes,
      static_broadcasted_argnums=static_broadcasted_argnums,
      devices=devices,
      backend=backend,
      axis_size=axis_size,
      donate_argnums=donate_argnums,
      global_arg_shapes=global_arg_shapes,
      transform_metadata=transform_metadata,
      graph=graph,
      graph_updates=graph_updates,
    )  # type: ignore[return-value]
  f_unbound, _, was_bound = _resolve_bound_callable(f)
  if was_bound:
    _raise_bound_method_error('pmap')

  if not graph or not graph_updates:
    # Tree-mode (or graph-mode without structural updates): StateAxes relies
    # on the graph protocol and is not supported here.
    if any(isinstance(x, StateAxes) for x in jax.tree.leaves(in_axes)):
      raise ValueError(
        '`in_axes` cannot contain `StateAxes` objects '
        'when `graph=False`. '
        + graphlib._tree_mode_suggestion_transform('pmap')
      )
    if any(isinstance(x, StateAxes) for x in jax.tree.leaves(out_axes)):
      raise ValueError(
        '`out_axes` cannot contain `StateAxes` objects '
        'when `graph=False`. '
        + graphlib._tree_mode_suggestion_transform('pmap')
      )
    pmapped_fn = jax.pmap(
      SimplePmapFn(f_unbound, graph=graph, out_axes=out_axes),
      axis_name=axis_name,
      in_axes=in_axes,
      # SimplePmapFn returns (out, updates); updates mirror (args, kwargs).
      out_axes=(out_axes, (in_axes, 0)),
      static_broadcasted_argnums=static_broadcasted_argnums,
      devices=devices,
      backend=backend,
      axis_size=axis_size,
      donate_argnums=donate_argnums,
      global_arg_shapes=global_arg_shapes,
    )

    @functools.wraps(f_unbound)
    def simple_pmap_wrapper(*args, **kwargs):
      if graph:
        args, kwargs = extract.to_tree2(
          (args, kwargs),
          prefix=(in_axes, None) if in_axes is not None else None,
          check_aliasing=in_axes is not None,
        )
      extract.check_no_aliases('pmap', args=args, kwargs=kwargs)
      out, updates = pmapped_fn(*args, **kwargs)
      # Propagate variable updates made inside f back to the caller's inputs.
      extract.apply_variable_updates((args, kwargs), updates)
      if graph:
        out = extract.from_tree2(out)
      return out

    return simple_pmap_wrapper  # type: ignore[return-value]

  # Graph-mode path: StateAxes prefixes are expanded into NodeStates pytrees
  # so jax.pmap sees one axis per substate.
  jax_in_axes = jax.tree.map(
    lambda x: extract.NodeStates.from_prefixes(x.axes, metadata=x)
    if isinstance(x, StateAxes)
    else x,
    in_axes,
  )
  jax_out_axes = jax.tree.map(
    lambda x: extract.NodeStates.from_prefixes(x.axes, metadata=x)
    if isinstance(x, StateAxes)
    else x,
    out_axes,
  )
  pmapped_fn = jax.pmap(
    PmapFn(f_unbound, transform_metadata, in_axes, out_axes),
    axis_name=axis_name,
    in_axes=jax_in_axes,
    out_axes=(jax_in_axes, jax_out_axes),
    static_broadcasted_argnums=static_broadcasted_argnums,
    devices=devices,
    backend=backend,
    axis_size=axis_size,
    donate_argnums=donate_argnums,
    global_arg_shapes=global_arg_shapes,
  )

  @functools.wraps(f)
  @graphlib.update_context('pmap')
  def pmap_wrapper(*args, **kwargs):
    # Renamed from `vmap_wrapper` (copy-paste leftover) and extended to
    # resolve keyword arguments like nnx.vmap's graph-mode wrapper does;
    # with no kwargs this is behaviorally identical to the old wrapper.
    args = resolve_kwargs(f, args, kwargs)
    pure_args = extract.to_tree(
      args, prefix=in_axes, split_fn=_vmap_split_fn, ctxtag='pmap'
    )
    pure_args_out, pure_out = pmapped_fn(*pure_args)
    _args_out, out = extract.from_tree(
      (pure_args_out, pure_out), ctxtag='pmap', is_inner=False
    )
    return out

  return pmap_wrapper  # type: ignore
# -------------------------------
# scan
# -------------------------------
class Broadcasted(struct.PyTreeNode):
  # Pytree wrapper used by scan to move a broadcast (axis=None) input out of
  # the scanned values; inside the scan body the placeholder carries
  # data=None and the real value is restored by _scan_merge_in.
  data: tp.Any
def _get_carry_argnum(axes, is_in_axes: bool):
  """Locates the ``Carry`` marker in an axes specification.

  Returns:
    ``'all'`` if ``axes`` itself is ``Carry``, ``None`` if no ``Carry`` is
    present, otherwise the top-level argument index where it was found.

  Raises:
    ValueError: if ``Carry`` is nested inside a container or appears more
      than once.
  """
  if axes is Carry:
    return 'all'
  elif isinstance(axes, int) or axes is None:
    return None
  obj_repr = 'in_axes' if is_in_axes else 'out_axes'
  carry_argnum: int | None = None
  prev_key: tp.Any = None
  for key, x in jax.tree_util.tree_leaves_with_path(axes):
    if x is not Carry:
      continue
    # axes is expected to be a top-level sequence, so the first path entry
    # is the positional index of the argument.
    assert isinstance(key[0], jax.tree_util.SequenceKey)
    i = key[0].idx
    if len(key) >= 2:
      # Fixed grammar of the error message ("must at" -> "must be at").
      raise ValueError(
        f'Carry must be at the top-level, it cannot be nested. Found {axes=}'
      )
    if carry_argnum is not None:
      raise ValueError(
        f'Found multiple Carry definitions at '
        f'{obj_repr}{jax.tree_util.keystr(prev_key)} and '
        f'{obj_repr}{jax.tree_util.keystr(key)}'
      )
    carry_argnum = i
    prev_key = key
  return carry_argnum
def _check_out_axes(out_axes):
  """Rejects out_axes entries that would broadcast (None) or carry outputs."""
  leaves = jax.tree_util.tree_leaves_with_path(
    out_axes, is_leaf=lambda x: x is None
  )
  for key, x in leaves:
    if x is None:
      raise ValueError(
        f'Cannot broadcast output state. '
        f'Got out_axes=None at: out_axes{jax.tree_util.keystr(key)}'
      )
    if not isinstance(x, StateAxes):
      continue
    # A StateAxes entry: validate each per-filter axis spec.
    for filter, value in x.items():
      if value is None:
        raise ValueError(
          f'Cannot broadcast output state. '
          f'Got StateAxes({{{filter}: None}}) at: out_axes'
          f'{jax.tree_util.keystr(key)}'
        )
      if value is Carry:
        raise ValueError(
          f'Cannot carry output state. '
          f'Got StateAxes({{{filter}: Carry}}) at: out_axes'
          f'{jax.tree_util.keystr(key)}'
        )
def _check_carry_same_references(carry_arg, carry_arg_out):
  """Verifies that every non-array carry leaf keeps the exact same object
  identity between a scan iteration's input and output."""

  def _assert_same(key_path, arg, out):
    both_arrays = isinstance(arg, jax.Array) and isinstance(out, jax.Array)
    if not both_arrays and arg is not out:
      raise ValueError(
        'Carry references must be the same between iterations. '
        f'Got {arg=} with id={id(arg)} and {out=} with id={id(out)} '
        f'at carry{jax.tree_util.keystr(key_path)}'
      )

  jax.tree_util.tree_map_with_path(
    _assert_same,
    carry_arg,
    carry_arg_out,
    # Treat non-Variable graph nodes as leaves so they are compared whole.
    is_leaf=lambda x: graphlib.is_graph_node(x)
    and not isinstance(x, variablelib.Variable),
  )
def _scan_split_in(
  carry_deque: PytreeDeque[list[State | variablelib.Variable]],
  graphdefs_deque: PytreeDeque[graphlib.GraphDef],
  broadcast_deque: PytreeDeque[list[State | variablelib.Variable]],
  broadcast_arrays: PytreeDeque[Broadcasted],
  /,
  ctx: graphlib.SplitContext,
  path,
  prefix,
  x,
):
  """Split callback for scan inputs.

  Routes each leaf (graph node substate or array) into a vectorized, carry
  or broadcast channel according to its axis spec ``prefix``. Carry and
  broadcast states, graphdefs, and broadcast arrays are appended to the given
  deques as a side effect; only the vectorized part is returned for scanning.
  """
  if graphlib.is_graph_node(x) or isinstance(x, variablelib.Variable):
    vectorized_states: list[State | variablelib.Variable] = []
    carry_states: list[State | variablelib.Variable] = []
    broadcast_states: list[State | variablelib.Variable] = []
    if isinstance(prefix, StateAxes):
      # Per-filter axis specs: split into one substate per filter.
      graphdef, *states = ctx.split(x, *prefix.filters)
      for state, axis in zip(states, prefix.axes):
        if axis is None:
          broadcast_states.append(state)
        elif isinstance(axis, int):
          if axis != 0:
            # scan always consumes the leading axis internally.
            state = jax.tree.map(lambda x: jnp.moveaxis(x, axis, 0), state)
          vectorized_states.append(state)
        else:  # axis is Carry
          carry_states.append(state)
      if not vectorized_states:
        # Keep at least one (empty) state so NodeStates is well-formed.
        vectorized_states.append(State({}))
      carry_deque.append(carry_states)
      graphdefs_deque.append(graphdef)
      broadcast_deque.append(broadcast_states)
      return extract.NodeStates.from_split(
        None, *vectorized_states, metadata=prefix
      )
    elif isinstance(prefix, int):
      graphdef, state = ctx.split(x)
      if prefix != 0:
        state = jax.tree.map(lambda x: jnp.moveaxis(x, prefix, 0), state)
      vectorized_states.append(state)
    elif prefix is None:
      graphdef, state = ctx.split(x)
      broadcast_states.append(state)
      vectorized_states.append(State({}))
    elif prefix is Carry:
      graphdef, state = ctx.split(x)
      carry_states.append(state)
      vectorized_states.append(State({}))
    else:
      raise ValueError(
        f'Invalid axes {prefix} args{jax.tree_util.keystr(path)}'
      )
    # Shared tail for the int/None/Carry cases above.
    if not vectorized_states:
      vectorized_states.append(State({}))
    carry_deque.append(carry_states)
    graphdefs_deque.append(graphdef)
    broadcast_deque.append(broadcast_states)
    return extract.NodeStates.from_split(
      None, *vectorized_states, metadata=prefix
    )
  else:
    # Non-graph-node leaf (plain arrays etc.).
    if isinstance(prefix, StateAxes):
      raise ValueError(
        'Cannot use StateAxes on non-graph nodes, '
        f'found {prefix} args{jax.tree_util.keystr(path)}'
      )
    elif prefix is Carry:
      return x
    elif prefix is None:
      # Stash the real value and leave a placeholder in the scanned tree.
      broadcast_arrays.append(Broadcasted(x))
      return Broadcasted(None)
    elif isinstance(prefix, int):
      if not isinstance(x, (jax.Array, np.ndarray)):
        raise ValueError(
          f'Expected an array, got {type(x).__name__} args'
          f'{jax.tree_util.keystr(path)}'
        )
      if prefix != 0:
        x = jnp.moveaxis(x, prefix, 0)
      return x
    else:
      raise ValueError(
        f'Invalid axes {prefix} args{jax.tree_util.keystr(path)}'
      )
def _scan_split_out(
  carry_deque: PytreeDeque[list[State | variablelib.Variable]],
  graphdefs_deque: PytreeDeque[graphlib.GraphDef],
  broadcast_deque: PytreeDeque[list[State | variablelib.Variable]],
  /,
  ctx: graphlib.SplitContext,
  path: extract.KeyPath,
  prefix,
  x,
):
  """Split callback for the ``(args_out, out)`` pair produced inside scan.

  Like ``_scan_split_in`` but for outputs: carry/broadcast channels only
  exist on the input-args side (None/Carry in out_axes was already rejected
  by ``_check_out_axes``). Appends to the deques as a side effect.
  """
  # path[0] indexes into the (args_out, out) 2-tuple: idx 0 means this leaf
  # belongs to the input arguments, otherwise to the function output.
  assert isinstance(path[0], jax.tree_util.SequenceKey)
  is_input_arg = path[0].idx == 0
  if graphlib.is_graph_node(x) or isinstance(x, variablelib.Variable):
    vectorized_states: list[State | variablelib.Variable] = []
    carry_states: list[State | variablelib.Variable] = []
    broadcast_states: list[State | variablelib.Variable] = []
    if isinstance(prefix, StateAxes):
      graphdef, *states = ctx.split(x, *prefix.filters)
      for state, filter, axis in zip(states, prefix.filters, prefix.axes):
        if axis is None:
          assert is_input_arg  # validated by _check_out_axes
          broadcast_states.append(state)
        elif isinstance(axis, int):
          vectorized_states.append(state)
        elif axis is Carry:
          assert is_input_arg  # validated by _check_out_axes
          carry_states.append(state)
        else:
          obj_repr = 'args' if is_input_arg else 'out'
          raise ValueError(
            f'Invalid axes {axis} for filter {filter} at '
            f'{obj_repr}{jax.tree_util.keystr(path)}'
          )
      if not vectorized_states:
        # Keep at least one (empty) state so NodeStates is well-formed.
        vectorized_states.append(State({}))
      if is_input_arg:
        carry_deque.append(carry_states)
        broadcast_deque.append(broadcast_states)
      graphdefs_deque.append(graphdef)
      return extract.NodeStates.from_split(
        None, *vectorized_states, metadata=prefix
      )
    elif isinstance(prefix, int):
      graphdef, state = ctx.split(x)
      vectorized_states.append(state)
    elif prefix is None:
      assert is_input_arg  # validated by _check_out_axes
      graphdef, state = ctx.split(x)
      broadcast_states.append(state)
      vectorized_states.append(State({}))
    elif prefix is Carry:
      assert is_input_arg  # validated by _check_out_axes
      graphdef, state = ctx.split(x)
      carry_states.append(state)
      vectorized_states.append(State({}))
    else:
      obj_repr = 'args' if is_input_arg else 'out'
      raise ValueError(
        f'Invalid axes {prefix} at {obj_repr}{jax.tree_util.keystr(path)}'
      )
    # Shared tail for the int/None/Carry cases above.
    if not vectorized_states:
      vectorized_states.append(State({}))
    if is_input_arg:
      carry_deque.append(carry_states)
      broadcast_deque.append(broadcast_states)
    graphdefs_deque.append(graphdef)
    return extract.NodeStates.from_split(
      None, *vectorized_states, metadata=prefix
    )
  else:
    # Non-graph-node leaf (plain arrays etc.).
    if isinstance(prefix, StateAxes):
      obj_repr = 'args' if is_input_arg else 'out'
      raise ValueError(
        'Cannot use StateAxes on non-graph nodes, '
        f'found {prefix} at {obj_repr}{jax.tree_util.keystr(path)}'
      )
    elif prefix is Carry:
      return x
    elif prefix is None:
      assert not is_input_arg  # validated by _check_out_axes
      return Broadcasted(None)
    elif isinstance(prefix, int):
      return x
    else:
      obj_repr = 'args' if is_input_arg else 'out'
      raise ValueError(
        f'Invalid axes {prefix} at {obj_repr}{jax.tree_util.keystr(path)}'
      )
def _scan_merge_in(
  carry_deque: PytreeDeque[list[State]],
  graphdefs_deque: PytreeDeque[graphlib.GraphDef],
  broadcast_deque: PytreeDeque[list[State]],
  broadcast_arrays: PytreeDeque[Broadcasted],
  /,
  ctx: graphlib.MergeContext,
  path,
  prefix,
  x,
):
  """Merge a single pure leaf back into an object inside the scan body.

  ``Broadcasted`` placeholders are replaced with the next broadcast array;
  ``NodeStates`` leaves are recombined with their carry/broadcast states and
  graphdef popped from the matching deques; everything else passes through.
  """
  if isinstance(x, Broadcasted):
    # Placeholders carry no data; the real array lives in broadcast_arrays.
    assert x.data is None
    return broadcast_arrays.popleft().data
  if not isinstance(x, extract.NodeStates):
    return x
  # Each deque was filled in the same traversal order by the split step,
  # so popping from the left restores the original pairing.
  node_carry_states = carry_deque.popleft()
  node_broadcast_states = broadcast_deque.popleft()
  node_graphdef = graphdefs_deque.popleft()
  return ctx.merge(
    node_graphdef, *x.states, *node_carry_states, *node_broadcast_states
  )
def _scan_merge_out(
  carry_deque: PytreeDeque[list[State]],
  graphdefs_deque: PytreeDeque[graphlib.GraphDef],
  broadcast_deque: PytreeDeque[list[State]],
  /,
  ctx: graphlib.MergeContext,
  path,
  prefix,
  x,
):
  """Merge a pure scan-output leaf back into an object after `jax.lax.scan`.

  For ``NodeStates`` leaves, the vectorized states stacked by scan are moved
  to the axis requested by ``prefix`` and recombined (via ``ctx.merge``) with
  the carry/broadcast states popped from the deques. Non-graph leaves are
  returned as-is, moving the scan axis when ``prefix`` is an int.
  """
  # path[0] indexes the (args_out, out) pair: index 0 means this leaf came
  # from the input args, index 1 from the function output.
  assert isinstance(path[0], jax.tree_util.SequenceKey)
  is_input_arg = path[0].idx == 0
  if isinstance(x, extract.NodeStates):
    states: list[State] = []
    graphdef = graphdefs_deque.popleft()
    if is_input_arg:
      # Only input args had carry/broadcast states recorded by the split.
      carry_states = deque(carry_deque.popleft())
      broadcast_states = deque(broadcast_deque.popleft())
    else:
      carry_states = deque[State]()
      broadcast_states = deque[State]()
    if isinstance(prefix, StateAxes):
      vectorized_states = deque(x.states)
      # Rebuild the states in the order declared by the StateAxes filters so
      # ctx.merge sees them in the same order they were split.
      for axis in prefix.axes:
        if isinstance(axis, int):
          state = vectorized_states.popleft()
          # scan stacks outputs on axis 0; move to the user-requested axis.
          state = jax.tree.map(
            lambda x: jnp.moveaxis(x, 0, axis) if axis != 0 else x,
            state,
          )
          states.append(state)
        elif axis is None:
          states.append(broadcast_states.popleft())
        else:  # axis is Carry
          states.append(carry_states.popleft())
      assert not carry_states and not broadcast_states
      # At most one empty placeholder state may remain (added when there were
      # no vectorized states at split time).
      assert not vectorized_states or (
        len(vectorized_states) == 1 and not vectorized_states[0]
      )
    elif isinstance(prefix, int):
      state = jax.tree.map(
        lambda x: jnp.moveaxis(x, 0, prefix) if prefix != 0 else x, x.state
      )
      states.extend((state, *carry_states, *broadcast_states))
    elif prefix is None:
      assert is_input_arg
      states.extend(broadcast_states)
    elif prefix is Carry:
      assert is_input_arg
      states.extend(carry_states)
    else:
      obj_repr = 'args' if is_input_arg else 'out'
      raise ValueError(
        f'Invalid axes {prefix} at {obj_repr}{jax.tree_util.keystr(path)}'
      )
    return ctx.merge(graphdef, *states)
  else:
    # Non-graph leaf: validate the axis spec and undo scan's axis-0 stacking.
    if isinstance(prefix, StateAxes):
      obj_repr = 'args' if is_input_arg else 'out'
      raise ValueError(
        'Cannot use StateAxes on non-graph nodes, '
        f'found {prefix} at {obj_repr}{jax.tree_util.keystr(path)}'
      )
    elif prefix is Carry:
      return x
    elif prefix is None:
      return x
    elif isinstance(prefix, int):
      if not isinstance(x, (jax.Array, np.ndarray)):
        obj_repr = 'args' if is_input_arg else 'out'
        raise ValueError(
          f'Expected an array, got {type(x).__name__} at '
          f'{obj_repr}{jax.tree_util.keystr(path)}'
        )
      if prefix != 0:
        x = jnp.moveaxis(x, 0, prefix)
      return x
    else:
      obj_repr = 'args' if is_input_arg else 'out'
      raise ValueError(
        f'Invalid axes {prefix} at {obj_repr}{jax.tree_util.keystr(path)}'
      )
@dataclasses.dataclass(eq=False)
class ScanFn:
  """Body function wrapper passed to `jax.lax.scan` in graph-updates mode.

  Converts pure carry/scan inputs back into NNX objects, calls the user
  function ``f``, and re-splits the results into the pure carry and stacked
  outputs that `jax.lax.scan` expects.

  Attributes:
    f: the user scan body function.
    input_carry_argnum: position of the carry in the inputs; ``'all'`` means
      the single argument is the carry, ``None`` means there is no carry.
    output_carry_argnum: position of the carry in the outputs, with the same
      ``'all'``/``None`` conventions.
    in_axes: axis spec prefix for the inputs.
    out_axes: axis spec prefix for the outputs.
    transform_metadata: metadata mapping; when it contains
      ``spmd.PARTITION_NAME`` the scanned axis is removed from / added to the
      variables' sharding metadata around the call.
  """
  f: tp.Callable[..., tp.Any]
  input_carry_argnum: int | None | tp.Literal['all']
  output_carry_argnum: int | None | tp.Literal['all']
  in_axes: tp.Any
  out_axes: tp.Any
  transform_metadata: tp.Mapping[str, tp.Any]

  def __post_init__(self):
    functools.update_wrapper(self, self.f)

  def __call__(
    self,
    carry: tuple[
      tp.Any,  # carry_arg
      PytreeDeque[list[State]],  # carry_deque
      PytreeDeque[list[State]],  # broadcast_deque
      PytreeDeque[Broadcasted],  # broadcast_arrays
    ],
    scan_in: tuple[tp.Any, ...],
  ):
    pure_carry_arg, carry_deque, broadcast_deque, broadcast_arrays = carry
    # Copies are forwarded to the output carry; the originals are consumed
    # (popped empty) by _scan_merge_in below.
    broadcast_deque_out = PytreeDeque(broadcast_deque)
    broadcast_arrays_out = PytreeDeque(broadcast_arrays)
    graphdefs_deque, pure_args = scan_in
    # Re-insert the pure carry into the args at its original position.
    if self.input_carry_argnum == 'all':
      assert pure_args == ()
      pure_args = (pure_carry_arg,)
    elif isinstance(self.input_carry_argnum, int):
      assert pure_args[self.input_carry_argnum] is None
      _pure_args = list(pure_args)
      _pure_args[self.input_carry_argnum] = pure_carry_arg
      pure_args = tuple(_pure_args)
    else:
      assert self.input_carry_argnum is None
      assert pure_carry_arg is None
    if spmd.PARTITION_NAME in self.transform_metadata:
      # Inside the scan body the scanned axis no longer exists on the arrays,
      # so drop it from the sharding metadata as well.
      pure_args = _update_variable_sharding_metadata(
        pure_args, self.transform_metadata, spmd.remove_axis
      )
    args: tuple = extract.from_tree(
      pure_args,
      prefix=self.in_axes,
      merge_fn=functools.partial(
        _scan_merge_in, carry_deque, graphdefs_deque, broadcast_deque, broadcast_arrays
      ),
      is_leaf=lambda x: isinstance(x, (extract.NodeStates, Broadcasted)),
      map_non_graph_nodes=True,
      ctxtag='scan',
      is_inner=True,
    )
    # All recorded states must have been consumed by the merge.
    assert not carry_deque and not broadcast_deque and not broadcast_arrays

    out = self.f(*args)

    # extract the carry from the args
    if self.input_carry_argnum == 'all':
      carry_arg = args[0]
    elif isinstance(self.input_carry_argnum, int):
      carry_arg = args[self.input_carry_argnum]
    else:
      assert self.input_carry_argnum is None
      carry_arg = None

    # extract the carry from the output
    if self.output_carry_argnum == 'all':
      carry_arg_out = out
      out = None
    elif isinstance(self.output_carry_argnum, int):
      assert isinstance(out, tuple)
      carry_arg_out = out[self.output_carry_argnum]
      _out = list(out)
      _out[self.output_carry_argnum] = None
      out = tuple(_out)
    else:
      assert self.output_carry_argnum is None
      carry_arg_out = None

    # TODO(cgarciae): allowing new references might lead to inconsistencies with
    # scan's looping semantics and we would also need to propagate the input
    _check_carry_same_references(carry_arg, carry_arg_out)
    args_out: tuple = extract.clear_non_graph_nodes(args)
    # replace the carry from the input args with the carry from the output
    if self.input_carry_argnum == 'all':
      args_out = (carry_arg_out,)
    elif isinstance(self.input_carry_argnum, int):
      _args_out = list(args_out)
      _args_out[self.input_carry_argnum] = carry_arg_out
      args_out = tuple(_args_out)
    else:
      assert self.input_carry_argnum is None
      assert carry_arg_out is None

    carry_deque_out = PytreeDeque[list[State | variablelib.Variable]]()
    graphdefs_out = PytreeDeque[graphlib.GraphDef]()
    _broadcast_deque_out_tmp = PytreeDeque[
      list[State | variablelib.Variable]
    ]()  # discarded
    pure_args_out: tuple
    pure_args_out, pure_out = extract.to_tree(
      (args_out, out),
      prefix=(self.in_axes, self.out_axes),
      split_fn=functools.partial(
        _scan_split_out, carry_deque_out, graphdefs_out, _broadcast_deque_out_tmp
      ),
      map_non_graph_nodes=True,
      ctxtag='scan',
    )
    if spmd.PARTITION_NAME in self.transform_metadata:
      # Re-add the scanned axis to sharding metadata for the stacked outputs.
      pure_args_out, pure_out = _update_variable_sharding_metadata(
        (pure_args_out, pure_out),
        self.transform_metadata,
        spmd.add_axis,
      )

    # extract the pure carry from the pure args
    if self.input_carry_argnum == 'all':
      pure_carry_arg_out = pure_args_out[0]
      pure_args_out = ()
    elif isinstance(self.input_carry_argnum, int):
      pure_carry_arg_out = pure_args_out[self.input_carry_argnum]
      _pure_args_out = list(pure_args_out)
      _pure_args_out[self.input_carry_argnum] = None
      pure_args_out = tuple(_pure_args_out)
    else:
      assert self.input_carry_argnum is None
      pure_carry_arg_out = None

    # NOTE: `carry_arg_out` is rebound here to the scan-level carry tuple,
    # shadowing the user-level carry extracted above.
    carry_arg_out = (
      pure_carry_arg_out,
      carry_deque_out,
      broadcast_deque_out,
      broadcast_arrays_out,
    )
    scan_out = (
      graphdefs_out,
      pure_args_out,
      pure_out,
    )
    return carry_arg_out, scan_out
@dataclasses.dataclass(eq=False)
class SimpleScanFn:
  """Body function wrapper for scan in simple (tree / no-graph-updates) mode.

  Runs the user function and additionally returns the masked variable updates
  of the non-carry args so the outer wrapper can apply them after the scan.

  Attributes:
    f: the user scan body function.
    graph: whether inputs/outputs are converted with from_tree2/to_tree2.
    in_axes: axis spec prefix for the inputs.
    out_axes: axis spec prefix for the outputs.
    out_is_tuple: whether ``out_axes`` (and hence the output) is a tuple.
    carry_arg_index: index of the carry in the args, or None.
    carry_out_index: index of the carry in the output tuple, or None.
  """
  f: tp.Callable[..., tp.Any]
  graph: bool
  in_axes: tp.Any
  out_axes: tp.Any
  out_is_tuple: bool
  carry_arg_index: int | None
  carry_out_index: int | None

  def __post_init__(self):
    functools.update_wrapper(self, self.f, updated=())

  @extract.treemap_copy_args
  def __call__(self, *args):
    # Snapshot the Variables before running f so updates can be detected.
    updates, snapshot = extract.updates_and_snapshot(args)
    if self.carry_arg_index is not None:
      carry_in = args[self.carry_arg_index]
    else:
      carry_in = None
    if self.graph:
      args = extract.from_tree2(args)
    out = self.f(*args)
    if self.graph:
      out = extract.to_tree2(out, prefix=self.out_axes)
    if self.carry_out_index is not None:
      carry_out = out[self.carry_out_index] if self.out_is_tuple else out
      extract.check_same_variables(carry_in, carry_out, 'scan')
    # Mask variable updates for non-carry args: identify broadcast (None axis)
    # Variables and drop their updates since they should not change across
    # scan iterations.
    masked_carry_updates = extract.mask_at(updates, self.carry_arg_index)
    masked_carry_snapshot = extract.mask_at(snapshot, self.carry_arg_index)
    if isinstance(self.in_axes, tuple):
      masked_carry_in_axes = extract.mask_at(self.in_axes, self.carry_arg_index)
    else:
      masked_carry_in_axes = self.in_axes
    is_leaf = lambda x: isinstance(x, variablelib.Variable) or x is None
    _, per_leaf_axes = extract.broadcast_prefix2(
      masked_carry_in_axes, masked_carry_updates, is_leaf=is_leaf,
    )
    # Collect ids of Variables whose axis is None (broadcast inputs).
    broadcast_var_ids = set()
    carry_leaves = jax.tree.leaves(masked_carry_updates, is_leaf=is_leaf)
    for axis, leaf in zip(per_leaf_axes, carry_leaves, strict=True):
      if axis is None and isinstance(leaf, variablelib.Variable):
        broadcast_var_ids.add(id(leaf))

    def keep_fn(path, cur, snap):
      # Keep only Variables that actually changed; mutating a broadcast
      # Variable is an error since its value cannot vary across iterations.
      changed = extract.variable_changed(cur, snap)
      if id(cur) in broadcast_var_ids and changed:
        raise ValueError(
          f'Broadcast (None axis) Variable at {jax.tree_util.keystr(path)} '
          'was mutated during scan. Only Carry and scanned Variables can be '
          'updated.'
        )
      return changed

    extract.check_no_aliases('scan', args=masked_carry_updates, out=out)
    masked_carry_updates = extract.mask_variable_updates(
      masked_carry_updates, masked_carry_snapshot, keep_fn=keep_fn,
    )
    # Append the masked updates as an extra output so the wrapper can apply
    # them to the original args after the scan completes.
    if self.out_is_tuple:
      return (*out, masked_carry_updates)
    return (out, masked_carry_updates)
# Overload 1: keyword-only usage, e.g. `@nnx.scan(in_axes=..., out_axes=...)`,
# returns a decorator to be applied to the scan body function.
@tp.overload
def scan(
  *,
  length: int | None = None,
  reverse: bool = False,
  unroll: int | bool = 1,
  _split_transpose: bool = False,
  # extended api
  in_axes: int | None | type[Carry] | tuple[tp.Any, ...] = (Carry, 0),
  out_axes: tp.Any = (Carry, 0),
  # nnx specific
  transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> tp.Callable[[F], F]:
  ...
# Overload 2: direct usage, e.g. `nnx.scan(fn, in_axes=...)`, returns the
# transformed function itself.
@tp.overload
def scan(
  f: F,
  *,
  length: int | None = None,
  reverse: bool = False,
  unroll: int | bool = 1,
  _split_transpose: bool = False,
  # extended api
  in_axes: int | None | type[Carry] | tuple[tp.Any, ...] = (Carry, 0),
  out_axes: tp.Any = (Carry, 0),
  # nnx specific
  transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> F:
  ...
def scan(
  f: F | type[Missing] = Missing,
  *,
  length: int | None = None,
  reverse: bool = False,
  unroll: int | bool = 1,
  _split_transpose: bool = False,
  # extended api
  in_axes: int | None | type[Carry] | tuple[tp.Any, ...] = (Carry, 0),
  out_axes: tp.Any = (Carry, 0),
  # nnx specific
  transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> F | tp.Callable[[F], F]:
  """A Flax NNX transformation of `jax.lax.scan`_.

  Example::

    import jax
    from flax import nnx

    class Block(nnx.Module):
      def __init__(self, input_dim, features, *, rngs):
        self.linear = nnx.Linear(input_dim, features, rngs=rngs)
        self.dropout = nnx.Dropout(0.1, rngs=rngs)

      def __call__(self, x: jax.Array):
        x = self.linear(x)
        x = self.dropout(x)
        x = jax.nn.relu(x)
        return x

    class Model(nnx.Module):
      def __init__(self, num_layers, features, *, rngs):
        # In this model implementation we create
        # multiple blocks using vmap
        # As Block contains dropout op, we prefer
        # to split RNG into num_layers of RNGs
        # using @nnx.split_rngs decorator.
        # Next, nnx.vmap creates a vectorized version of Block.
        # in_axes and out_axes define vectorization axis
        # of the input splitted rngs and the output Block instance.
        # Both axes should be 0.
        @nnx.split_rngs(splits=num_layers)
        @nnx.vmap(in_axes=(0,), out_axes=0)
        def create_block(rngs: nnx.Rngs):
          return Block(features, features, rngs=rngs)

        self.blocks = create_block(rngs)
        self.num_layers = num_layers

      def __call__(self, x):
        # Forward pass method implementation
        # We use nnx.scan to apply sequentially the blocks
        # on the input, for example with num_layers=3
        # output = block[0](x)
        # output = block[1](output)
        # output = block[2](output)
        #
        # In `forward` function defined below:
        # - x represents the loop carry value
        # - model is the data to scan along the leading axis
        # nnx.scan args:
        # - in_axes marks the inputs: x is marked as carry
        #   and the model is to scan along the axis 0
        # - out_axes marks the output as carry
        @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
        def forward(x, model):
          x = model(x)
          return x

        return forward(x, self.blocks)

        # Alternatively, we can also decorate `self.__call__` method
        # @nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)
        # def __call__(self, x):
        #   return self.blocks(x)

    model = Model(2, 4, rngs=nnx.Rngs(0))
    _, params, _ = nnx.split(model, nnx.Param, ...)
    print(params)  # kernel of shape: (2, 4, 4)

    x = jnp.arange(5 * 4, dtype="float32").reshape((5, 4))
    y = model(x)
    print(y.shape)  # shape: (5, 4)

  Args:
    f: a Python function to be scanned
    length: optional integer specifying the number of loop iterations
    reverse: optional boolean specifying whether to run the scan iteration
      forward (the default) or in reverse
    unroll: optional positive int or bool specifying, in the underlying
      operation of the scan primitive, how many scan iterations to unroll
      within a single iteration of a loop.
    in_axes: integer, None, :class:`flax.nnx.Carry` or sequence of values specifying
      the kind of input args. Integer value would specify the axis of corresponding
      input data to scan along. :class:`flax.nnx.Carry` marks the input data as
      loop carry value. None marks the input data as auxiliary input.
    out_axes: integer, None, :class:`flax.nnx.Carry` or sequence of values specifying
      the kind of output args. See ``in_axes`` for details. Note that If ``in_axes``
      contains :class:`flax.nnx.Carry` then ``out_axes`` must also contain :class:`flax.nnx.Carry`.
    transform_metadata: metadata mapping; when it contains ``nnx.PARTITION_NAME``
      the scanned axis is removed from / re-added to the variables' sharding
      metadata around the scan body. Only used in graph-updates mode.
    graph: if True, use graph-mode. If False, use tree-mode. If None, uses the
      value of the ``nnx_graph_mode`` config.
    graph_updates: If ``True``, propagates updates on graph structure
      that happen inside the transform to the input graphs, has no
      effect when ``graph=False``. When ``False``, using ``StateAxes``
      is not supported.

  .. _jax.lax.scan: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html
  """
  # Decorator usage: called without `f`, return a partial that waits for it.
  if f is Missing:
    return functools.partial(
      scan,
      length=length,
      reverse=reverse,
      unroll=unroll,
      _split_transpose=_split_transpose,
      in_axes=in_axes,
      out_axes=out_axes,
      transform_metadata=transform_metadata,
      graph=graph,
      graph_updates=graph_updates,
    )  # type: ignore[return-value]
  # Bound methods are not supported; unwrap and raise a dedicated error.
  f_unbound, _, was_bound = _resolve_bound_callable(f)
  if was_bound:
    _raise_bound_method_error('scan')
  if graph is None:
    graph = graphlib.set_graph_mode.current_value()
  if graph_updates is None:
    graph_updates = graphlib.set_graph_updates.current_value()
  # Tree-mode or no-graph-updates takes the simpler implementation; full
  # graph-updates mode goes through the split/merge machinery.
  if not graph or not graph_updates:
    return _simple_scan(
      f, f_unbound, graph=graph,
      in_axes=in_axes, out_axes=out_axes,
      length=length, reverse=reverse, unroll=unroll,
      _split_transpose=_split_transpose,
    )
  return _graph_updates_scan(
    f, f_unbound,
    in_axes=in_axes, out_axes=out_axes,
    length=length, reverse=reverse, unroll=unroll,
    _split_transpose=_split_transpose,
    transform_metadata=transform_metadata,
  )
def _simple_scan(
  f, f_unbound, *,
  graph, in_axes, out_axes,
  length, reverse, unroll, _split_transpose,
):
  """Build the scan wrapper used in tree-mode / no-graph-updates mode.

  Returns a function with the same call signature as ``f`` that runs
  ``pure_jax_fancy_scan`` over it, applying Variable updates of non-carry
  args back to the originals after the scan.
  """
  # StateAxes requires the graph-updates machinery; reject it here.
  if any(isinstance(x, StateAxes) for x in jax.tree.leaves(in_axes)):
    raise ValueError(
      '`in_axes` cannot contain `StateAxes` objects '
      'when `graph=False`. '
      + graphlib._tree_mode_suggestion_transform('scan')
    )
  if any(isinstance(x, StateAxes) for x in jax.tree.leaves(out_axes)):
    raise ValueError(
      '`out_axes` cannot contain `StateAxes` objects '
      'when `graph=False`. '
      + graphlib._tree_mode_suggestion_transform('scan')
    )
  out_is_tuple = isinstance(out_axes, tuple)
  # Normalize the Carry-only case to a one-element tuple.
  if in_axes is Carry:
    in_axes = (Carry,)
  if isinstance(in_axes, tuple):
    # Position of the carry among the args, None if there is no carry.
    carry_arg_index = next(
      (i for i, ax in enumerate(in_axes) if ax is Carry), None
    )
    # Axes for the extra "updates" output: same as in_axes with the carry
    # position masked out.
    updates_out_axes = extract.mask_at(in_axes, carry_arg_index)
  else:
    carry_arg_index = None
    updates_out_axes = in_axes
  if isinstance(out_axes, tuple):
    carry_out_index = next(
      (i for i, ax in enumerate(out_axes) if ax is Carry), None
    )
  else:
    carry_out_index = None
  simple_scan_fn = SimpleScanFn(
    f_unbound, graph=graph,
    in_axes=in_axes, out_axes=out_axes,
    out_is_tuple=out_is_tuple,
    carry_arg_index=carry_arg_index,
    carry_out_index=carry_out_index,
  )
  # SimpleScanFn appends the masked variable updates as an extra output.
  if out_is_tuple:
    augmented_out_axes = (*out_axes, updates_out_axes)
  else:
    augmented_out_axes = (out_axes, updates_out_axes)

  @functools.wraps(f)
  def simple_scan_wrapper(*args):
    args = resolve_kwargs(f, args, {})
    if graph:
      args = extract.to_tree2(args, prefix=in_axes)
    extract.check_no_aliases('scan', args=args)
    result = pure_jax_fancy_scan(
      simple_scan_fn,
      *args,
      length=length,
      reverse=reverse,
      unroll=unroll,
      _split_transpose=_split_transpose,
      in_axes=in_axes,
      out_axes=augmented_out_axes,
    )
    # Separate the user outputs from the appended variable updates.
    if out_is_tuple:
      n = len(out_axes)
      out = result[:n]
      updates = result[n]
    else:
      out, updates = result
    # Apply updates to the non-carry args (carry position is masked out).
    masked_args = extract.mask_at(args, carry_arg_index)
    extract.apply_variable_updates(masked_args, updates)
    if carry_arg_index is not None:
      # Propagate the final carry values back into the input carry object and
      # return that object in the carry output slot, preserving identity.
      carry_in = args[carry_arg_index]
      carry_out = (
        out[carry_out_index] if out_is_tuple else out
      )
      extract.update_carry_variables(carry_in, carry_out)
      if out_is_tuple:
        out_list = list(out)
        out_list[carry_out_index] = carry_in
        out = tuple(out_list)
      else:
        out = carry_in
    if graph:
      out = extract.from_tree2(out)
    return out

  return simple_scan_wrapper
def _graph_updates_scan(
  f, f_unbound, *,
  in_axes, out_axes,
  length, reverse, unroll, _split_transpose,
  transform_metadata,
):
  """Build the scan wrapper used in full graph-updates mode.

  Splits all graph nodes into pure states, runs `jax.lax.scan` over a
  ``ScanFn`` wrapper, and merges the results back, propagating graph
  structure updates to the inputs.
  """
  _check_out_axes(out_axes)
  input_carry_argnum = _get_carry_argnum(in_axes, is_in_axes=True)
  output_carry_argnum = _get_carry_argnum(out_axes, is_in_axes=False)
  # Carry must appear in both or neither of in_axes/out_axes.
  if (input_carry_argnum is None and output_carry_argnum is not None) or (
    input_carry_argnum is not None and output_carry_argnum is None
  ):
    raise ValueError(
      'If one of in_axes or out_axes has Carry, the other must also have Carry. '
      f'Got {in_axes=!r} and {out_axes=!r}'
    )
  scan_fn = ScanFn(
    f_unbound,
    input_carry_argnum,
    output_carry_argnum,
    in_axes,
    out_axes,
    transform_metadata,
  )

  @functools.wraps(f)
  @graphlib.update_context('scan')
  def scan_wrapper(*args, **kwargs):
    args = resolve_kwargs(f, args, kwargs)
    if in_axes is Carry and len(args) != 1:
      raise ValueError(
        f'When in_axes=Carry, the function must take exactly one argument, '
        f'got {len(args)} arguments.'
      )
    # Per-node side channels populated during the split and consumed by
    # _scan_merge_in inside the scan body.
    graphdefs_deque = PytreeDeque()
    carry_deque = PytreeDeque()
    broadcast_deque = PytreeDeque()
    broadcast_arrays = PytreeDeque()
    pure_args: tuple = extract.to_tree(
      args,
      prefix=in_axes,
      split_fn=functools.partial(
        _scan_split_in, carry_deque, graphdefs_deque, broadcast_deque, broadcast_arrays
      ),
      map_non_graph_nodes=True,
      ctxtag='scan',
    )
    # Pull the carry out of the pure args so it can be threaded through scan.
    if isinstance(input_carry_argnum, int):
      pure_carry_arg = pure_args[input_carry_argnum]
      _pure_args = list(pure_args)
      _pure_args[input_carry_argnum] = None
      pure_args = tuple(_pure_args)
    elif input_carry_argnum == 'all':
      pure_carry_arg = pure_args[0]
      pure_args = ()
    else:
      assert input_carry_argnum is None
      pure_carry_arg = None

    carry = (pure_carry_arg, carry_deque, broadcast_deque, broadcast_arrays)
    scan_in = (graphdefs_deque, pure_args)

    carry_out, scan_out = jax.lax.scan(
      scan_fn,
      carry,
      scan_in,
      length=length,
      reverse=reverse,
      unroll=unroll,
      _split_transpose=_split_transpose,
    )
    (
      pure_carry_arg_out,
      carry_deque_out,
      broadcast_deque_out,
      broadcast_arrays_out,
    ) = carry_out
    (
      graphdefs_out,
      pure_args_out,
      pure_out,
    ) = scan_out

    # Re-insert the final pure carry into the pure args before merging.
    if input_carry_argnum == 'all':
      pure_args_out = (pure_carry_arg_out,)
    elif isinstance(input_carry_argnum, int):
      _pure_args_out = list(pure_args_out)
      _pure_args_out[input_carry_argnum] = pure_carry_arg_out
      pure_args_out = tuple(_pure_args_out)
    else:
      assert input_carry_argnum is None
      assert pure_carry_arg_out is None

    args_out, out = extract.from_tree(
      (pure_args_out, pure_out),
      prefix=(in_axes, out_axes),
      merge_fn=functools.partial(
        _scan_merge_out, carry_deque_out, graphdefs_out, broadcast_deque_out
      ),
      is_leaf=lambda x: isinstance(x, (extract.NodeStates, Broadcasted)),
      map_non_graph_nodes=True,
      ctxtag='scan',
      is_inner=False,
    )
    # Extract the merged carry object and place it into the output slot.
    if input_carry_argnum == 'all':
      carry_arg = args_out[0]
    elif isinstance(input_carry_argnum, int):
      carry_arg = args_out[input_carry_argnum]
    else:
      assert input_carry_argnum is None
      carry_arg = None

    if output_carry_argnum == 'all':
      out = carry_arg
    elif isinstance(output_carry_argnum, int):
      _out = list(out)
      _out[output_carry_argnum] = carry_arg
      out = tuple(_out)
    else:
      assert output_carry_argnum is None
      assert carry_arg is None

    return out

  return scan_wrapper
def pure_jax_fancy_scan(
  f,
  *args,
  length: int | None = None,
  reverse: bool = False,
  unroll: int | bool = 1,
  _split_transpose: bool = False,
  in_axes: tp.Any = (Carry, 0),
  out_axes: tp.Any = (Carry, 0),
):
  """`jax.lax.scan` with an nnx.vmap-style axes API over plain pytrees.

  ``in_axes``/``out_axes`` leaves may be ``Carry`` (threaded through the
  loop), ``None`` (broadcast, inputs only) or an int (scanned along that
  axis). Inputs are flattened, partitioned by axis kind, scanned, and the
  outputs reassembled with scanned leaves moved back to their requested axis.

  Raises:
    ValueError: on nested ``Carry``, mismatched Carry between in/out axes,
      invalid axis leaves, or broadcast (None) output leaves.
  """
  if in_axes is Carry:
    in_axes = (Carry,)
  is_axis_leaf = lambda x: x is None or x is Carry
  # Carry is only allowed as a top-level (per-argument) axis spec.
  if isinstance(in_axes, tuple):
    for i, ax in enumerate(in_axes):
      if ax is Carry or ax is None or isinstance(ax, int):
        continue
      for leaf in jax.tree.leaves(ax, is_leaf=is_axis_leaf):
        if leaf is Carry:
          raise ValueError(
            'Carry must be a top-level argument, it cannot be nested. '
            f'Found Carry inside in_axes[{i}]={ax}'
          )
  if isinstance(out_axes, tuple):
    for i, ax in enumerate(out_axes):
      if ax is Carry or ax is None or isinstance(ax, int):
        continue
      for path, leaf in jax.tree_util.tree_leaves_with_path(
        ax, is_leaf=is_axis_leaf,
      ):
        if leaf is Carry:
          raise ValueError(
            'Carry must be a top-level argument, it cannot be nested. '
            f'Found Carry at out_axes[{i}]{jax.tree_util.keystr(path)}'
          )
  in_has_carry = in_axes is Carry or (
    isinstance(in_axes, tuple) and Carry in in_axes
  )
  out_has_carry = out_axes is Carry or (
    isinstance(out_axes, tuple) and Carry in out_axes
  )
  if in_has_carry != out_has_carry:
    raise ValueError(
      'If one of in_axes or out_axes has Carry, the other must also '
      f'have Carry. Got {in_axes=}, {out_axes=}'
    )
  args_flat, args_treedef = jax.tree.flatten(args)
  # Expand the axes prefix to one axis spec per flattened leaf.
  _, in_axes_flat = extract.broadcast_prefix2(
    in_axes, args, is_leaf=is_axis_leaf,
  )
  # Partition the flat leaves into carry / broadcast / scanned groups,
  # remembering each leaf's original flat index for reassembly.
  carry_indices: list[int] = []
  broadcast_indices: list[int] = []
  scan_indices: list[int] = []
  scan_in_axes: list[int] = []
  carry_leaves: list[tp.Any] = []
  broadcast_leaves: list[tp.Any] = []
  scan_leaves: list[tp.Any] = []
  for i, (leaf, ax) in enumerate(zip(args_flat, in_axes_flat, strict=True)):
    if ax is Carry:
      carry_indices.append(i)
      carry_leaves.append(leaf)
    elif ax is None:
      broadcast_indices.append(i)
      broadcast_leaves.append(leaf)
    elif isinstance(ax, int):
      scan_indices.append(i)
      scan_in_axes.append(ax)
      # jax.lax.scan consumes xs along axis 0; normalize to that axis.
      if ax != 0:
        leaf = jnp.moveaxis(leaf, ax, 0)
      scan_leaves.append(leaf)
    else:
      raise ValueError(f'Invalid in_axes leaf value: {ax}')
  n_in = len(args_flat)
  # Output layout discovered on the first trace of body_fn; a one-element
  # list is used as a mutable closure cell.
  out_info: list[tuple[
    jax.tree_util.PyTreeDef, list[int], list[int], list[int],
  ]] = []
  # Fresh list (same leaves) so the closure holds its own copy of the
  # broadcast leaves — presumably to avoid aliasing the list above; confirm.
  in_broadcast = jax.tree.map(lambda x: x, broadcast_leaves)

  def body_fn(carry_state, scan_x):
    # Reassemble the original arguments from the partitioned leaves.
    flat = [None] * n_in
    for idx, j in enumerate(carry_indices):
      flat[j] = carry_state[idx]
    for idx, j in enumerate(broadcast_indices):
      flat[j] = in_broadcast[idx]
    if scan_x is not None:
      for idx, j in enumerate(scan_indices):
        flat[j] = scan_x[idx]
    reconstructed = args_treedef.unflatten(flat)
    out = f(*reconstructed)
    out_flat, out_treedef = jax.tree.flatten(out)
    out_axes_paths, out_axes_flat = extract.broadcast_prefix2(
      out_axes, out, is_leaf=is_axis_leaf,
    )
    if not out_info:
      # First trace: record which output leaves are carry vs scanned.
      out_carry_idx = []
      out_scan_idx = []
      out_scan_axes = []
      out_broadcast_idx = []
      for j, oax in enumerate(out_axes_flat):
        if oax is Carry:
          out_carry_idx.append(j)
        elif oax is None:
          out_broadcast_idx.append(j)
        elif isinstance(oax, int):
          out_scan_idx.append(j)
          out_scan_axes.append(oax)
        else:
          raise ValueError(f'Invalid out_axes leaf value: {oax}')
      if out_broadcast_idx:
        broadcast_paths = [
          jax.tree_util.keystr(out_axes_paths[j]) for j in out_broadcast_idx
        ]
        broadcast_str = "\n\n  ".join(broadcast_paths)
        raise ValueError(
          'Scan does not support broadcast outputs (None axis). The following '
          f'output leaves are broadcast:\n\n  {broadcast_str}\n'
        )
      out_info.append(
        (out_treedef, out_carry_idx, out_scan_idx, out_scan_axes),
      )
    oci = out_info[0][1]
    osi = out_info[0][2]
    new_carry = [out_flat[j] for j in oci]
    new_ys = [out_flat[j] for j in osi]
    return new_carry, new_ys

  final_carry, stacked_ys = jax.lax.scan(
    body_fn,
    carry_leaves,
    scan_leaves if scan_leaves else None,
    length=length,
    reverse=reverse,
    unroll=unroll,
    _split_transpose=_split_transpose,
  )
  # Rebuild the output tree: carry leaves from the final carry, scanned
  # leaves from the stacked ys (moved from axis 0 to their requested axis).
  out_treedef, out_carry_idx, out_scan_idx, out_scan_axes = (
    out_info[0]
  )
  n_out = out_treedef.num_leaves
  out_flat: list[tp.Any] = [None] * n_out
  for idx, j in enumerate(out_carry_idx):
    out_flat[j] = final_carry[idx]
  for idx, j in enumerate(out_scan_idx):
    y = stacked_ys[idx]
    ax = out_scan_axes[idx]
    if ax != 0:
      y = jnp.moveaxis(y, 0, ax)
    out_flat[j] = y
  return out_treedef.unflatten(out_flat)
# -------------------------------
# while_loop
# -------------------------------
@dataclasses.dataclass(eq=False)
class SimpleWhileLoopBodyFn:
  """Body wrapper for while_loop in tree-mode / no-graph-updates mode.

  Converts the pure carry to/from objects around the user body function and
  verifies the output carries the same Variables as the input.

  Attributes:
    f: the user body function.
    graph: whether to convert with from_tree2/to_tree2 around the call.
  """
  f: tp.Callable[..., tp.Any]
  graph: bool

  def __post_init__(self):
    functools.update_wrapper(self, self.f, updated=())

  @extract.treemap_copy_args
  def __call__(self, val):
    # Snapshot the input Variables so the output can be checked against them.
    val_variables, _ = extract.updates_and_snapshot(val)
    if self.graph:
      val = extract.from_tree2(val)
    out = self.f(val)
    if self.graph:
      out = extract.to_tree2(out)
    extract.check_same_variables(val_variables, out, 'while_loop')
    return out
@dataclasses.dataclass(eq=False)
class SimpleWhileLoopCondFn:
  """Condition wrapper for while_loop in tree-mode / no-graph-updates mode.

  Optionally reconstructs objects from the pure carry before evaluating the
  user condition function.
  """
  f: tp.Callable[..., tp.Any]
  graph: bool

  def __post_init__(self):
    functools.update_wrapper(self, self.f)

  def __call__(self, val):
    # In graph mode the carry arrives as a pure tree; rebuild objects first.
    return self.f(extract.from_tree2(val) if self.graph else val)
@dataclasses.dataclass(eq=False)
class WhileLoopCondFn:
  """Condition wrapper for while_loop in graph-updates mode.

  Rebuilds objects from the pure carry and evaluates the user condition.
  """
  f: tp.Callable[..., tp.Any]

  def __post_init__(self):
    functools.update_wrapper(self, self.f)

  def __call__(self, pure_val):
    return self.f(extract.from_tree(pure_val))
def _reconsile_index_mapping(tree_to_fix, example_tree):
  """Align the outer indices of NodeStates graphdefs in one tree with another.

  Leaves of ``tree_to_fix`` that are ``NodeStates`` holding a ``GraphDef``
  get their graphdef's outer index matched against the corresponding leaf in
  ``example_tree``; all other leaves are returned unchanged.
  """
  def align(fix_leaf, example_leaf):
    if isinstance(fix_leaf, extract.NodeStates) and isinstance(
      fix_leaf._graphdef, graphlib.GraphDef
    ):
      return dataclasses.replace(
        fix_leaf,
        _graphdef=fix_leaf._graphdef.with_matching_outer_index(
          example_leaf._graphdef
        ),
      )
    return fix_leaf

  return jax.tree.map(
    align,
    tree_to_fix,
    example_tree,
    is_leaf=lambda x: isinstance(x, extract.NodeStates),
  )
def _add_fake_index_mapping(tree: tp.Any):
  """Give every NodeStates graphdef in ``tree`` a self-referential outer index.

  Leaves that are not ``NodeStates`` holding a ``GraphDef`` pass through
  untouched.
  """
  def with_fake_index(leaf: extract.NodeStates | tp.Any):
    if isinstance(leaf, extract.NodeStates) and isinstance(
      leaf._graphdef, graphlib.GraphDef
    ):
      return dataclasses.replace(
        leaf, _graphdef=leaf._graphdef.with_same_outer_index()
      )
    return leaf

  return jax.tree.map(
    with_fake_index, tree, is_leaf=lambda x: isinstance(x, extract.NodeStates)
  )
def _remove_index_mapping(tree: tp.Any):
  """Remove a fake outer_index for the input to match that of the output."""
  def strip_index(leaf: extract.NodeStates | tp.Any):
    if not isinstance(leaf, extract.NodeStates):
      return leaf
    if not isinstance(leaf._graphdef, graphlib.GraphDef):
      return leaf
    return dataclasses.replace(
      leaf, _graphdef=leaf._graphdef.with_no_outer_index()
    )

  return jax.tree.map(
    strip_index, tree, is_leaf=lambda x: isinstance(x, extract.NodeStates)
  )
@dataclasses.dataclass(eq=False)
class WhileLoopBodyFn:
  """Body wrapper for while_loop in graph-updates mode.

  Rebuilds objects from the pure carry, runs the user body, and re-splits the
  result, enforcing that input and output share the same pytree/reference
  structure (required by `jax.lax.while_loop`'s fixed carry type).
  """
  f: tp.Callable[..., tp.Any]

  def __post_init__(self):
    functools.update_wrapper(self, self.f)

  @graphlib.update_context('while_loop_body')
  def __call__(self, pure_val):
    # Removing the dummy index mapping being added outside of body function.
    pure_val_in = _remove_index_mapping(pure_val)

    val = extract.from_tree(
      pure_val_in, ctxtag='while_loop_body', is_inner=True
    )
    out = self.f(val)
    pure_out = extract.to_tree(out, ctxtag='while_loop_body')

    try:
      jax.tree.map(lambda a, b: None, pure_val, pure_out)
    except ValueError as e:
      # FIX: the original message concatenated '...not allowed.' directly
      # onto 'Detail...' with no separator; a space was added. The re-raise
      # now chains the structural-mismatch error for a clearer traceback.
      msg = (
        "nnx.while_loop requires body function's input and output to "
        'have the same reference and pytree structure, but they differ. '
        'If the mismatch comes from `outer_index` field, you might '
        'have modified reference structure within the body function, '
        'which is not allowed. '
        f'Detail of the mismatch: \n {str(e)}'
      )
      raise ValueError(msg) from e

    return pure_out
@graphlib.update_context('while_loop')
def while_loop(cond_fun: tp.Callable[[T], tp.Any],
               body_fun: tp.Callable[[T], T],
               init_val: T,
               *,
               graph: bool | None = None,
               graph_updates: bool | None = None) -> T:
  """A Flax NNX transformation of `jax.lax.while_loop <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html>`_.

  Caution: for the NNX internal reference tracing mechanism to work, you cannot
  change the variable reference structure of ``init_val`` inside ``body_fun``.

  Example::

    >>> import jax
    >>> from flax import nnx
    >>> def fwd_fn(input):
    ...   module, x, count = input
    ...   return module, module(x), count - 1.0

    >>> module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
    >>> x = jax.random.normal(jax.random.key(0), (10,))
    >>> # `module` will be called three times
    >>> _, y, _ = nnx.while_loop(
    ...   lambda input: input[-1] > 0, fwd_fn, (module, x, 3.0))

  Args:
    cond_fun: A function for the continue condition of the while loop, taking a
      single input of type ``T`` and outputting a boolean.
    body_fun: A function that takes an input of type ``T`` and outputs an ``T``.
      Note that both data and modules of ``T`` must have the same reference
      structure between inputs and outputs.
    init_val: The initial input for ``cond_fun`` and ``body_fun``. Must be of type ``T``.
    graph: if True, use graph-mode (default). If False, use tree-mode.
      If None, uses the value of ``nnx_graph_mode`` config.
    graph_updates: If ``True``, propagates updates on graph structure
      that happen inside the transform to the input graphs, has no
      effect when ``graph=False``.
  """
  if graph is None:
    graph = graphlib.set_graph_mode.current_value()
  if graph_updates is None:
    graph_updates = graphlib.set_graph_updates.current_value()
  # Simple path: no graph-structure updates are propagated; only Variable
  # values are threaded through the loop.
  if not graph or not graph_updates:
    simple_body_fn = SimpleWhileLoopBodyFn(body_fun, graph=graph)
    simple_cond_fn = SimpleWhileLoopCondFn(cond_fun, graph=graph)
    if graph:
      init_val = extract.to_tree2(init_val)
    val_out = jax.lax.while_loop(simple_cond_fn, simple_body_fn, init_val)
    # Write the final carry values back into the input objects/tree.
    val_out = extract.update_carry_variables(init_val, val_out)
    if graph:
      val_out = extract.from_tree2(val_out)
    return val_out
  # Graph-updates path: split into pure trees and use the full machinery.
  pure_init_val = extract.to_tree(init_val, ctxtag='while_loop')
  # The dummy self-referential outer index makes the input carry's structure
  # match the body output's (the body strips it again before merging).
  pure_init_val = _add_fake_index_mapping(pure_init_val)
  pure_out = jax.lax.while_loop(
    WhileLoopCondFn(cond_fun),
    WhileLoopBodyFn(body_fun),
    pure_init_val,
  )
  out = extract.from_tree(pure_out, ctxtag='while_loop', is_inner=False)
  return out
@dataclasses.dataclass(eq=False)
class SimpleForiLoopBodyFn:
  """Body wrapper for fori_loop in tree-mode / no-graph-updates mode.

  Same as ``SimpleWhileLoopBodyFn`` but for the ``(index, val)`` signature of
  `jax.lax.fori_loop` body functions.

  Attributes:
    f: the user body function, called as ``f(i, val)``.
    graph: whether to convert with from_tree2/to_tree2 around the call.
  """
  f: tp.Callable[..., tp.Any]
  graph: bool

  def __post_init__(self):
    functools.update_wrapper(self, self.f, updated=())

  @extract.treemap_copy_args
  def __call__(self, i, val):
    # Snapshot the input Variables so the output can be checked against them.
    val_variables, _ = extract.updates_and_snapshot(val)
    if self.graph:
      val = extract.from_tree2(val)
    out = self.f(i, val)
    if self.graph:
      out = extract.to_tree2(out)
    extract.check_same_variables(val_variables, out, 'fori_loop')
    return out
@dataclasses.dataclass(eq=False)
class ForiLoopBodyFn:
  """Body wrapper for fori_loop in graph-updates mode.

  Rebuilds objects from the pure carry, runs the user body with the loop
  index, and splits the result back into a pure tree.
  """
  f: tp.Callable[..., tp.Any]

  def __post_init__(self):
    functools.update_wrapper(self, self.f)

  @graphlib.update_context('fori_loop_body')
  def __call__(self, i, pure_val_in):
    carry = extract.from_tree(
      pure_val_in, ctxtag='fori_loop_body', is_inner=True
    )
    result = self.f(i, carry)
    return extract.to_tree(result, ctxtag='fori_loop_body')
@graphlib.update_context('fori_loop')
def fori_loop(lower: int, upper: int,
              body_fun: tp.Callable[[int, T], T],
              init_val: T,
              *,
              unroll: int | bool | None = None,
              graph: bool | None = None,
              graph_updates: bool | None = None) -> T:
  """A Flax NNX transformation of `jax.lax.fori_loop <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html>`_.

  Caution: for the NNX internal reference tracing mechanism to work, you cannot
  change the variable reference structure of `init_val` inside `body_fun`.

  Example::

    >>> import jax
    >>> import jax.numpy as jnp
    >>> import numpy as np
    >>> from flax import nnx

    >>> def fwd_fn(i, input):
    ...   m, x = input
    ...   m.kernel[...] = jnp.identity(10) * i
    ...   return m, m(x)

    >>> module = nnx.Linear(10, 10, rngs=nnx.Rngs(0))
    >>> x = jax.random.normal(jax.random.key(0), (10,))
    >>> _, y = nnx.fori_loop(2, 4, fwd_fn, (module, x))
    >>> np.testing.assert_array_equal(y, x * 2 * 3)

  Args:
    lower: An integer representing the loop index lower bound (inclusive).
    upper: An integer representing the loop index upper bound (exclusive).
    body_fun: a function that takes an input of type ``T`` and outputs an ``T``.
      Note that both data and modules of ``T`` must have the same reference
      structure between inputs and outputs.
    init_val: the initial input for body_fun. Must be of type ``T``.
    unroll: An optional integer or boolean that determines how much to unroll
      the loop. If an integer is provided, it determines how many unrolled
      loop iterations to run within a single rolled iteration of the loop. If a
      boolean is provided, it will determine if the loop is completely unrolled
      (i.e. ``unroll=True``) or left completely rolled (i.e. ``unroll=False``).
      This argument is only applicable if the loop bounds are statically known.
    graph: if True, use graph-mode (default). If False, use tree-mode.
      If None, uses the value of ``nnx_graph_mode`` config.
    graph_updates: If ``True``, propagates updates on graph structure
      that happen inside the transform to the input graphs, has no
      effect when ``graph=False``.

  Returns:
    A loop value from the final iteration, of type ``T``.
  """
  # Resolve unset flags from the global NNX configuration.
  if graph is None:
    graph = graphlib.set_graph_mode.current_value()
  if graph_updates is None:
    graph_updates = graphlib.set_graph_updates.current_value()

  if not graph or not graph_updates:
    # Simple path: no update context. In graph mode the carry is converted
    # to/from tree form around the loop; variable updates are copied back
    # onto the input carry afterwards.
    simple_body_fn = SimpleForiLoopBodyFn(body_fun, graph=graph)
    if graph:
      init_val = extract.to_tree2(init_val)
    val_out = jax.lax.fori_loop(
      lower, upper,
      simple_body_fn,
      init_val,
      unroll=unroll,
    )
    val_out = extract.update_carry_variables(init_val, val_out)
    if graph:
      val_out = extract.from_tree2(val_out)
    return val_out

  # Update-context path: trace the body once with eval_shape so the output
  # index mapping can be reconciled with the input before running the loop.
  pure_init_val = extract.to_tree(init_val, ctxtag='fori_loop')
  body = ForiLoopBodyFn(body_fun)
  pure_out = jax.eval_shape(body, lower, pure_init_val)
  pure_init_val = _reconsile_index_mapping(pure_init_val, pure_out)
  pure_out = jax.lax.fori_loop(lower, upper,
                               body, pure_init_val,
                               unroll=unroll)
  out = extract.from_tree(pure_out, ctxtag='fori_loop', is_inner=False)
  return out