# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import deque
import dataclasses
import functools
import typing as tp
from flax import struct
from flax.nnx import (
extract,
filterlib,
graphlib,
variablelib,
)
from flax.nnx.statelib import State
import jax
from flax.nnx.transforms import general
from flax.nnx.transforms.transforms import (
resolve_kwargs,
_resolve_bound_callable,
_raise_bound_method_error,
)
from flax.typing import MISSING, Missing
# Type variables and aliases used by the transforms in this module.
A = tp.TypeVar('A')  # generic payload threaded through to_tree/from_tree
# C = tp.TypeVar('C')
# B = tp.TypeVar('B')
F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])  # wrapped callable type
# G = tp.TypeVar('G', bound=tp.Callable[..., tp.Any])
# M = tp.TypeVar('M', bound=Module)
# MA = tp.TypeVar('MA', bound=Module)
# N = tp.TypeVar('N', bound=Module)
# StrInt = tp.TypeVar('StrInt', str, int)
AxisName = tp.Hashable  # mirrors jax's axis-name convention
# Leaves = tp.List[Leaf]
# Index = int
# -------------------------------
# grad
# -------------------------------
@dataclasses.dataclass(frozen=True)
class DiffState:
  """Marks one positional argument as differentiable under a Filter.

  Passed in ``argnums`` of :func:`grad` / :func:`value_and_grad` to select
  which Variable substates of the argument gradients are taken with
  respect to.
  """

  # positional index of the argument this DiffState applies to
  argnum: int
  # filterlib Filter selecting the differentiable substate of the argument
  filter: filterlib.Filter
@dataclasses.dataclass(eq=False)
class SimpleGradFn:
  """Adapter handed to ``jax.grad`` on the tree-mode path.

  Captures Variable updates made by ``f`` and smuggles them out through the
  auxiliary output so the outer wrapper can apply them to the real inputs.
  """

  f: tp.Callable[..., tp.Any]
  has_aux: bool
  graph: bool

  def __post_init__(self):
    functools.update_wrapper(self, self.f, updated=())

  @extract.treemap_copy_args
  def __call__(self, *args, **kwargs):
    updates, snapshot = extract.updates_and_snapshot((args, kwargs))
    if self.graph:
      args, kwargs = extract.from_tree2((args, kwargs))
    result = self.f(*args, **kwargs)
    if self.graph:
      result = extract.to_tree2(result)
    extract.check_no_aliases(
      'grad', args=updates[0], kwargs=updates[1], out=result
    )
    masked_updates = extract.mask_variable_updates(updates, snapshot)
    if not self.has_aux:
      return result, masked_updates
    loss, aux = result
    return loss, (masked_updates, aux)
@dataclasses.dataclass(eq=False)
class GradFn:
  """Adapter handed to ``jax.grad`` on the graph-mode path.

  Rebuilds graph nodes from the pure differentiable substates, merging back
  the non-differentiable substates captured by the outer ``grad_wrapper``.
  """

  f: tp.Callable[..., tp.Any]
  has_aux: bool
  # One entry per NodeStates leaf, appended by ``_grad_split_fn`` in
  # traversal order; ``None`` marks a leaf with no held-out substate.
  nondiff_states: deque[State | None]

  def __post_init__(self):
    functools.update_wrapper(self, self.f)

  def __call__(self, *pure_args):
    # rebuild diff_state from substates in args
    def _grad_merge_fn(
      ctx: graphlib.MergeContext, path, prefix, value: extract.NodeStates
    ):
      # consume entries in the same traversal order they were appended
      nondiff = self.nondiff_states.popleft()
      if nondiff is None:
        return ctx.merge(value.graphdef, value.state)
      else:
        return ctx.merge(value.graphdef, value.state, nondiff)

    args = extract.from_tree(
      pure_args, merge_fn=_grad_merge_fn, ctxtag='grad', is_inner=True
    )
    out = self.f(*args)
    # carry argument mutations back out alongside the loss so the outer
    # wrapper can propagate them to the caller's objects
    args_out = extract.clear_non_graph_nodes(args)
    pure_args_out, pure_out = extract.to_tree((args_out, out), ctxtag='grad')
    if self.has_aux:
      loss, pure_aux = pure_out
      fn_out = (loss, (pure_args_out, pure_aux))
    else:
      loss = pure_out
      fn_out = (loss, pure_args_out)
    return fn_out
def _grad_general(
  f: tp.Callable[..., tp.Any],
  argnums: int | DiffState | tp.Sequence[int | DiffState],
  has_aux: bool,
  holomorphic: bool,
  allow_int: bool,
  return_value: bool,
  graph: bool,
  graph_updates: bool,
) -> tp.Callable[..., tp.Any]:
  """Shared implementation behind :func:`grad` and :func:`value_and_grad`.

  Dispatches between a lightweight tree-mode path (taken when ``graph`` or
  ``graph_updates`` is False) and the full graph-mode path that tracks
  shared references and propagates structural updates.
  """
  transform = jax.value_and_grad if return_value else jax.grad
  if not graph or not graph_updates:
    # Tree-mode path: Modules are handled as plain pytrees, so per-substate
    # differentiation via DiffState is not available here.
    if any(isinstance(x, DiffState) for x in jax.tree.leaves(argnums)):
      raise ValueError(
        '`argnums` cannot contain `DiffState` objects '
        'when `graph=False`. '
        + graphlib._tree_mode_suggestion_transform('grad')
      )
    gradded_fn = transform(
      SimpleGradFn(f, has_aux, graph=graph),
      argnums=argnums,  # type: ignore[arg-type]
      has_aux=True,
      holomorphic=holomorphic,
      allow_int=allow_int,
    )

    def tree_grad_wrapper(*args, **kwargs):
      if graph:
        diff_argnums = (argnums,) if isinstance(argnums, int) else argnums
        # prefix flags which positional args are differentiated
        args_prefix = tuple(
          i in diff_argnums for i in range(len(args))
        )
        args, kwargs = extract.to_tree2(
          (args, kwargs), prefix=(args_prefix, False),
        )
        extract.check_no_aliases('grad', args=args, kwargs=kwargs)
      fn_out = gradded_fn(*args, **kwargs)
      # Unpack according to value_and_grad-vs-grad and has_aux; the
      # Variable updates ride along in the auxiliary slot (see SimpleGradFn).
      if return_value:
        if has_aux:
          (loss, (updates, aux)), grads = fn_out
          if graph: grads, aux = extract.from_tree2((grads, aux))
          result = (loss, aux), grads
        else:
          (loss, updates), grads = fn_out
          if graph: grads = extract.from_tree2(grads)
          result = loss, grads
      else:
        if has_aux:
          grads, (updates, aux) = fn_out
          if graph: grads, aux = extract.from_tree2((grads, aux))
          result = grads, aux
        else:
          grads, updates = fn_out
          if graph: grads = extract.from_tree2(grads)
          result = grads
      # write Variable mutations made inside f back into the inputs
      extract.apply_variable_updates((args, kwargs), updates)
      return result

    return tree_grad_wrapper
  # Graph-mode path: normalize argnums for jax and build a map from
  # positional index to the DiffState filter used to split that argument.
  jax_argnums: int | tuple[int, ...]
  if isinstance(argnums, (int, DiffState)):
    jax_argnums = argnums.argnum if isinstance(argnums, DiffState) else argnums
  else:
    jax_argnums = tuple(
      x.argnum if isinstance(x, DiffState) else x for x in argnums
    )
  _argnums = (argnums,) if isinstance(argnums, (int, DiffState)) else argnums
  index_filter: dict[int, DiffState] = {}
  for argnum in _argnums:
    index = argnum.argnum if isinstance(argnum, DiffState) else argnum
    if index in index_filter:
      raise ValueError(f'argnum {index} is repeated in argnums')
    # argnum=-1 marks the DiffState as already positionally resolved;
    # plain ints default to differentiating Param variables
    index_filter[index] = (
      dataclasses.replace(argnum, argnum=-1)
      if isinstance(argnum, DiffState)
      else DiffState(-1, variablelib.Param)
    )

  @graphlib.update_context('grad')
  def grad_wrapper(*args, **kwargs):
    args = resolve_kwargs(f, args, kwargs)
    del kwargs
    # filled by _grad_split_fn in traversal order, consumed by GradFn
    nondiff_states: deque[State | variablelib.Variable | None] = deque()

    def _grad_split_fn(
      ctx: graphlib.SplitContext, path, prefix: DiffState | None, value
    ):
      if prefix is None or (prefix.argnum == -1 and isinstance(value, variablelib.Variable)):
        # non-differentiable arg (or bare Variable): pass all state through
        nondiff_states.append(None)
        return extract.NodeStates.from_split(*ctx.split(value))
      else:
        # split the differentiable substate out per the filter; hold back
        # the rest so jax only sees the differentiable part
        graphdef, diff, nondiff = ctx.split(value, prefix.filter, ...)  # type: ignore[misc]
        nondiff_states.append(nondiff)  # type: ignore[container-type-mismatch]
        return extract.NodeStates.from_split(graphdef, diff)

    arg_filters = tuple(index_filter.get(i) for i in range(len(args)))
    pure_args = extract.to_tree(
      args, prefix=arg_filters, split_fn=_grad_split_fn, ctxtag='grad'
    )
    gradded_fn = transform(
      GradFn(f, has_aux, nondiff_states),
      argnums=jax_argnums,
      has_aux=True,
      holomorphic=holomorphic,
      allow_int=allow_int,
    )
    fn_out = gradded_fn(*pure_args)

    def process_grads(grads):
      # strip NodeStates wrappers so gradients are plain State objects
      return jax.tree.map(
        lambda x: x.state if isinstance(x, extract.NodeStates) else x,
        grads,
        is_leaf=lambda x: isinstance(x, extract.NodeStates),
      )

    def process_out(pure_out: A, /) -> A:
      # merging back into the outer context also propagates argument
      # mutations to the caller's objects (hence _args_out is discarded)
      return extract.from_tree(pure_out, ctxtag='grad', is_inner=False)

    if return_value:
      # unpack value_and_grad output
      if has_aux:
        (loss, (pure_args_out, pure_aux)), grads = fn_out
        grads = process_grads(grads)
        _args_out, aux = process_out((pure_args_out, pure_aux))
        return (loss, aux), grads
      else:
        (loss, pure_args_out), grads = fn_out
        grads = process_grads(grads)
        _args_out = process_out(pure_args_out)
        return loss, grads
    else:
      # unpack grad output
      if has_aux:
        grads, (pure_args_out, pure_aux) = fn_out
        grads = process_grads(grads)
        _args_out, aux = process_out((pure_args_out, pure_aux))
        return grads, aux
      else:
        grads, pure_args_out = fn_out
        grads = process_grads(grads)
        _args_out = process_out(pure_args_out)
        return grads

  return grad_wrapper
# Overloads: ``grad`` called with a function returns the gradient function;
# called with only keyword arguments it returns a decorator.
@tp.overload
def grad(
  f: tp.Callable[..., tp.Any],
  *,
  argnums: int | DiffState | tp.Sequence[int | DiffState] = 0,
  has_aux: bool = False,
  holomorphic: bool = False,
  allow_int: bool = False,
  reduce_axes: tp.Sequence[AxisName] = (),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> tp.Callable[..., tp.Any]: ...
@tp.overload
def grad(
  *,
  argnums: int | DiffState | tp.Sequence[int | DiffState] = 0,
  has_aux: bool = False,
  holomorphic: bool = False,
  allow_int: bool = False,
  reduce_axes: tp.Sequence[AxisName] = (),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]]: ...
def grad(
  f: tp.Callable[..., tp.Any] | Missing = MISSING,
  *,
  argnums: int | DiffState | tp.Sequence[int | DiffState] = 0,
  has_aux: bool = False,
  holomorphic: bool = False,
  allow_int: bool = False,
  reduce_axes: tp.Sequence[AxisName] = (),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> (
  tp.Callable[..., tp.Any]
  | tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]]
):
  """Object-aware version of ``jax.grad`` that can handle Modules / graph
  nodes as arguments.

  Example::

    >>> from flax import nnx
    >>> import jax.numpy as jnp
    ...
    >>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
    >>> x = jnp.ones((1, 2))
    >>> y = jnp.ones((1, 3))
    ...
    >>> loss_fn = lambda m, x, y: jnp.mean((m(x) - y) ** 2)
    >>> grad_fn = nnx.grad(loss_fn)
    ...
    >>> grads = grad_fn(m, x, y)
    >>> jax.tree.map(jnp.shape, grads)
    State({
      'bias': Param(
        value=(3,)
      ),
      'kernel': Param(
        value=(2, 3)
      )
    })

  By default, NNX objects are differentiated with respect to all their
  ``nnx.Param`` Variables. You can specify which substates are differentiable
  by passing a ``DiffState`` object to the ``argnums`` argument. For example,
  if you want to differentiate only the ``kernel`` attribute of the
  ``Linear`` class, you can use the ``PathContains`` filter::

    >>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
    ...
    >>> kernel_attribute = nnx.PathContains('kernel')
    >>> diff_state = nnx.DiffState(0, kernel_attribute)
    ...
    >>> loss_fn = lambda m, x, y: jnp.mean((m(x) - y) ** 2)
    >>> grad_fn = nnx.grad(loss_fn, argnums=diff_state)
    ...
    >>> grads = grad_fn(m, x, y)
    >>> jax.tree.map(jnp.shape, grads)
    State({
      'kernel': Param(
        value=(2, 3)
      )
    })

  For more information on how to create custom filters, see
  `Using Filters <https://flax.readthedocs.io/en/latest/guides/filters_guide.html>`__
  guide.

  Args:
    f: Function to be differentiated. Its arguments at positions specified by
      ``argnums`` should be arrays, scalars, graph nodes or standard Python
      containers. Argument arrays in the positions specified by ``argnums``
      must be of inexact (i.e., floating-point or complex) type. It should
      return a scalar (which includes arrays with shape ``()`` but not arrays
      with shape ``(1,)`` etc.)
    argnums: Optional, integer or sequence of integers. Specifies which
      positional argument(s) to differentiate with respect to (default 0).
    has_aux: Optional, bool. Indicates whether ``f`` returns a pair where the
      first element is considered the output of the mathematical function to
      be differentiated and the second element is auxiliary data. Default
      False.
    holomorphic: Optional, bool. Indicates whether ``f`` is promised to be
      holomorphic. If True, inputs and outputs must be complex. Default False.
    allow_int: Optional, bool. Whether to allow differentiating with
      respect to integer valued inputs. The gradient of an integer input will
      have a trivial vector-space dtype (float0). Default False.
    reduce_axes: Deprecated, do not use. Passing a non-empty value raises
      ``NotImplementedError``.
    graph: If ``True`` (default), uses graph-mode which supports the full
      NNX feature set including shared references and reference semantics.
      If ``False``, uses tree-mode which treats Modules as regular JAX
      pytrees, avoiding the overhead of the graph protocol. Tree-mode does
      not support ``DiffState`` or shared ``Variable`` references.
    graph_updates: If ``True``, propagates updates on graph structure
      that happen inside the transform to the input graphs, has no
      effect when ``graph=False``. When ``False``, using ``DiffState``
      is not supported.
  """
  # Resolve mode defaults from the ambient context settings.
  if graph is None:
    graph = graphlib.set_graph_mode.current_value()
  if graph_updates is None:
    graph_updates = graphlib.set_graph_updates.current_value()
  if reduce_axes:
    raise NotImplementedError('reduce_axes argument to grad is deprecated')
  del reduce_axes
  # Decorator usage: called without a function, return a partial.
  if isinstance(f, Missing):
    return functools.partial(
      grad,
      argnums=argnums,
      has_aux=has_aux,
      holomorphic=holomorphic,
      allow_int=allow_int,
      graph=graph,
      graph_updates=graph_updates,
    )
  # Detect bound nnx.Module methods and raise error.
  f_unbound, _, was_bound = _resolve_bound_callable(f)
  if was_bound:
    _raise_bound_method_error('grad')
  return _grad_general(
    f_unbound,
    argnums,
    has_aux,
    holomorphic,
    allow_int,
    return_value=False,
    graph=graph,
    graph_updates=graph_updates,
  )
# Overloads: ``value_and_grad`` called with a function returns the
# transformed function; called with only keyword arguments it returns a
# decorator.
@tp.overload
def value_and_grad(
  f: tp.Callable[..., tp.Any],
  *,
  argnums: int | DiffState | tp.Sequence[int | DiffState] = 0,
  has_aux: bool = False,
  holomorphic: bool = False,
  allow_int: bool = False,
  reduce_axes: tp.Sequence[AxisName] = (),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> tp.Callable[..., tp.Any]: ...
@tp.overload
def value_and_grad(
  *,
  argnums: int | DiffState | tp.Sequence[int | DiffState] = 0,
  has_aux: bool = False,
  holomorphic: bool = False,
  allow_int: bool = False,
  reduce_axes: tp.Sequence[AxisName] = (),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]]: ...
def value_and_grad(
  f: tp.Callable[..., tp.Any] | Missing = MISSING,
  *,
  argnums: int | DiffState | tp.Sequence[int | DiffState] = 0,
  has_aux: bool = False,
  holomorphic: bool = False,
  allow_int: bool = False,
  reduce_axes: tp.Sequence[AxisName] = (),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> (
  tp.Callable[..., tp.Any]
  | tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]]
):
  """Object-aware version of ``jax.value_and_grad``.

  Like :func:`grad`, but returns both the value and the gradient of ``f``.

  Args:
    f: Function to be differentiated. Its arguments at positions specified by
      ``argnums`` should be arrays, scalars, graph nodes or standard Python
      containers. Argument arrays in the positions specified by ``argnums``
      must be of inexact (i.e., floating-point or complex) type. It should
      return a scalar (which includes arrays with shape ``()`` but not arrays
      with shape ``(1,)`` etc.)
    argnums: Optional, integer or sequence of integers. Specifies which
      positional argument(s) to differentiate with respect to (default 0).
    has_aux: Optional, bool. Indicates whether ``f`` returns a pair where the
      first element is considered the output of the mathematical function to
      be differentiated and the second element is auxiliary data. Default
      False.
    holomorphic: Optional, bool. Indicates whether ``f`` is promised to be
      holomorphic. If True, inputs and outputs must be complex. Default False.
    allow_int: Optional, bool. Whether to allow differentiating with
      respect to integer valued inputs. The gradient of an integer input will
      have a trivial vector-space dtype (float0). Default False.
    reduce_axes: Deprecated, do not use. Passing a non-empty value raises
      ``NotImplementedError``.
    graph: If ``True`` (default), uses graph-mode which supports the full
      NNX feature set including shared references and reference semantics.
      If ``False``, uses tree-mode which treats Modules as regular JAX
      pytrees, avoiding the overhead of the graph protocol. Tree-mode does
      not support ``DiffState`` or shared ``Variable`` references.
    graph_updates: If ``True``, propagates updates on graph structure
      that happen inside the transform to the input graphs, has no
      effect when ``graph=False``. When ``False``, using ``DiffState``
      is not supported.

  Returns:
    A function with the same arguments as ``f`` that evaluates both ``f``
    and the gradient of ``f`` and returns them as a pair (a two-element
    tuple). If ``argnums`` is an integer then the gradient has the same shape
    and type as the positional argument indicated by that integer. If argnums
    is a sequence of integers, the gradient is a tuple of values with the
    same shapes and types as the corresponding arguments. If ``has_aux`` is
    True then a tuple of ((value, auxiliary_data), gradient) is returned.
  """
  # Resolve mode defaults from the ambient context settings.
  if graph is None:
    graph = graphlib.set_graph_mode.current_value()
  if graph_updates is None:
    graph_updates = graphlib.set_graph_updates.current_value()
  if reduce_axes:
    raise NotImplementedError(
      'reduce_axes argument to value_and_grad is deprecated')
  del reduce_axes
  # Decorator usage: called without a function, return a partial.
  # Uses the MISSING sentinel + isinstance check for consistency with
  # grad/vjp/jvp (previously compared against the Missing class itself).
  if isinstance(f, Missing):
    return functools.partial(
      value_and_grad,
      argnums=argnums,
      has_aux=has_aux,
      holomorphic=holomorphic,
      allow_int=allow_int,
      graph=graph,
      graph_updates=graph_updates,
    )
  # Detect bound nnx.Module methods and raise error.
  f_unbound, _, was_bound = _resolve_bound_callable(f)
  if was_bound:
    _raise_bound_method_error('value_and_grad')
  return _grad_general(
    f_unbound,
    argnums,
    has_aux,
    holomorphic,
    allow_int,
    return_value=True,
    graph=graph,
    graph_updates=graph_updates,
  )
# -----------------------------------------------
# vjp
# -----------------------------------------------
@dataclasses.dataclass(eq=False)
class SimpleVjpFn:
  """Adapter handed to ``jax.vjp`` that carries Variable updates as aux."""

  f: tp.Callable[..., tp.Any]
  has_aux: bool
  graph: bool

  def __post_init__(self):
    functools.update_wrapper(self, self.f, updated=())

  @extract.treemap_copy_args
  def __call__(self, *args):
    updates, snapshot = extract.updates_and_snapshot(args)
    if self.graph:
      args = extract.from_tree2(args)
    result = self.f(*args)
    if self.graph:
      result = extract.to_tree2(result)
    extract.check_no_aliases('vjp', args=updates, out=result)
    masked_updates = extract.mask_variable_updates(updates, snapshot)
    if not self.has_aux:
      return result, masked_updates
    primals_out, aux = result
    return primals_out, (masked_updates, aux)
# Overloads: ``vjp`` called with a function and primals performs the
# transform immediately; called with only keyword arguments it returns a
# decorator.
@tp.overload
def vjp(
  f: tp.Callable[..., tp.Any],
  *primals: tp.Any,
  has_aux: bool = False,
  reduce_axes: tp.Sequence[AxisName] = (),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> tuple[tp.Any, tp.Callable] | tuple[tp.Any, tp.Callable, tp.Any]: ...
@tp.overload
def vjp(
  *,
  has_aux: bool = False,
  reduce_axes: tp.Sequence[AxisName] = (),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]]: ...
def vjp(
  f: tp.Callable[..., tp.Any] | Missing = MISSING,
  *primals: tp.Any,
  has_aux: bool = False,
  reduce_axes: tp.Sequence[AxisName] = (),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> (
  tuple[tp.Any, tp.Callable]
  | tuple[tp.Any, tp.Callable, tp.Any]
  | tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]]
):
  """Stateful version of ``jax.vjp`` that propagates NNX Variable updates.

  Example::

    >>> from flax import nnx
    >>> import jax.numpy as jnp
    ...
    >>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
    >>> x = jnp.ones((1, 2))
    ...
    >>> def loss_fn(m, x):
    ...   return jnp.sum(m(x))
    ...
    >>> primals_out, vjp_fn = nnx.vjp(loss_fn, m, x, graph=False)
    >>> m_grad, x_grad = vjp_fn(jnp.ones_like(primals_out))

  Can also be used as a decorator::

    >>> @nnx.vjp(graph=False)
    ... def f(m, x):
    ...   return jnp.sum(m(x))
    ...
    >>> primals_out, vjp_fn = f(m, x)

  Args:
    f: Function to be differentiated. Its arguments can be arrays, scalars,
      or pytrees containing arrays and NNX Variables.
    *primals: A sequence of primal values at which the Jacobian of ``f``
      should be evaluated.
    has_aux: Optional, bool. Indicates whether ``f`` returns a pair where the
      first element is considered the output of the mathematical function to
      be differentiated and the second element is auxiliary data. Default
      False.
    reduce_axes: Deprecated, do not use.
    graph: If ``True`` (default), uses graph-mode which supports the full
      NNX feature set including shared references and reference semantics.
      If ``False``, uses tree-mode which treats Modules as regular JAX
      pytrees, avoiding the overhead of the graph protocol.
    graph_updates: If ``True``, propagates updates on graph structure
      that happen inside the transform to the input graphs, has no
      effect when ``graph=False``.

  Returns:
    If ``has_aux`` is False, returns a ``(primals_out, vjp_fn)`` pair.
    ``vjp_fn`` takes a cotangent with the same structure as ``primals_out``
    and returns gradients for each primal argument.
    If ``has_aux`` is True, returns ``(primals_out, vjp_fn, aux)``.
  """
  # Resolve mode defaults from the ambient context settings.
  if graph is None:
    graph = graphlib.set_graph_mode.current_value()
  if graph_updates is None:
    graph_updates = graphlib.set_graph_updates.current_value()
  if graph and graph_updates:
    raise NotImplementedError(
      'graph-mode with graph_updates is not supported for nnx.vjp. '
      'Set graph=False or graph_updates=False.'
    )
  if reduce_axes:
    raise NotImplementedError('reduce_axes argument to vjp is deprecated')
  del reduce_axes
  # Decorator usage: called without a function, return a partial.
  if isinstance(f, Missing):
    return functools.partial(  # type: ignore[return-value]
      vjp,
      has_aux=has_aux,
      graph=graph,
      graph_updates=graph_updates,
    )
  # Detect bound nnx.Module methods and raise error.
  f_unbound, _, was_bound = _resolve_bound_callable(f)
  if was_bound:
    _raise_bound_method_error('vjp')
  # Called with a function but no primals: also decorator usage.
  if not primals:
    return functools.partial(  # type: ignore[return-value]
      vjp,
      f,
      has_aux=has_aux,
      graph=graph,
      graph_updates=graph_updates,
    )
  if graph:
    primals = extract.to_tree2(primals)
  extract.check_no_aliases('vjp', primals=primals)
  # SimpleVjpFn always returns (out, aux) where aux carries the Variable
  # updates (and the user's aux when has_aux=True).
  primals_out, vjp_fn, aux = jax.vjp(
    SimpleVjpFn(f_unbound, has_aux=has_aux, graph=graph),
    *primals,
    has_aux=True,
  )
  if has_aux:
    updates, user_aux = aux
  else:
    updates = aux
    user_aux = None
  if graph:
    primals_out = extract.from_tree2(primals_out)
    raw_vjp_fn = vjp_fn

    def vjp_fn(g):
      # convert pure cotangent outputs back to graph nodes
      return extract.from_tree2(raw_vjp_fn(g))

  # write Variable mutations made inside f back into the primals
  extract.apply_variable_updates(primals, updates)
  if has_aux:
    return primals_out, vjp_fn, user_aux
  else:
    return primals_out, vjp_fn
# -----------------------------------------------
# jvp
# -----------------------------------------------
@dataclasses.dataclass(eq=False)
class SimpleJvpFn:
  """Adapter handed to ``jax.jvp`` that carries Variable updates along."""

  f: tp.Callable[..., tp.Any]
  has_aux: bool
  graph: bool

  def __post_init__(self):
    functools.update_wrapper(self, self.f, updated=())

  @extract.treemap_copy_args
  def __call__(self, *args):
    updates, snapshot = extract.updates_and_snapshot(args)
    if self.graph:
      args = extract.from_tree2(args)
    result = self.f(*args)
    if self.graph:
      result = extract.to_tree2(result)
    extract.check_no_aliases('jvp', args=updates, out=result)
    masked_updates = extract.mask_variable_updates(updates, snapshot)
    if not self.has_aux:
      return result, masked_updates
    primals_out, aux = result
    return (primals_out, masked_updates), aux
# Overloads: ``jvp`` called with function + primals + tangents transforms
# immediately; with only keyword arguments (or only a function) it returns
# a decorator / partially-applied form.
@tp.overload
def jvp(
  f: tp.Callable[..., tp.Any],
  primals: tuple[tp.Any, ...],
  tangents: tuple[tp.Any, ...],
  *,
  has_aux: bool = False,
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> tuple[tp.Any, ...]: ...
@tp.overload
def jvp(
  *,
  has_aux: bool = False,
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]]: ...
@tp.overload
def jvp(
  f: tp.Callable[..., tp.Any],
  *,
  has_aux: bool = False,
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> tp.Callable[..., tp.Any]: ...
def jvp(
  f: tp.Callable[..., tp.Any] | Missing = MISSING,
  primals: tuple[tp.Any, ...] | Missing = MISSING,
  tangents: tuple[tp.Any, ...] | Missing = MISSING,
  *,
  has_aux: bool = False,
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> (
  tuple[tp.Any, ...]
  | tp.Callable[..., tp.Any]
  | tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]]
):
  """Stateful version of ``jax.jvp`` that propagates NNX Variable updates.

  Example::

    >>> from flax import nnx
    >>> import jax.numpy as jnp
    ...
    >>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
    >>> x = jnp.ones((1, 2))
    ...
    >>> def f(m, x):
    ...   return jnp.sum(m(x))
    ...
    >>> m_tangent = jax.tree.map(jnp.zeros_like, m)
    >>> x_tangent = jnp.ones_like(x)
    >>> primals_out, tangent_out = nnx.jvp(
    ...   f, (m, x), (m_tangent, x_tangent), graph=False
    ... )

  Can also be used as a decorator::

    >>> @nnx.jvp(graph=False)
    ... def f(m, x):
    ...   return jnp.sum(m(x))
    ...
    >>> primals_out, tangent_out = f((m, x), (m_tangent, x_tangent))

  Args:
    f: Function to be differentiated. Its arguments can be arrays, scalars,
      or pytrees containing arrays and NNX Variables.
    primals: A tuple of primal values at which the Jacobian of ``f``
      should be evaluated.
    tangents: A tuple of tangent vectors, with the same structure as
      ``primals``.
    has_aux: Optional, bool. Indicates whether ``f`` returns a pair where the
      first element is considered the output of the mathematical function to
      be differentiated and the second element is auxiliary data. Default
      False.
    graph: If ``True`` (default), uses graph-mode which supports the full
      NNX feature set including shared references and reference semantics.
      If ``False``, uses tree-mode which treats Modules as regular JAX
      pytrees, avoiding the overhead of the graph protocol.
    graph_updates: If ``True``, propagates updates on graph structure
      that happen inside the transform to the input graphs, has no
      effect when ``graph=False``.

  Returns:
    If ``has_aux`` is False, returns ``(primals_out, tangent_out)``.
    If ``has_aux`` is True, returns ``(primals_out, tangent_out, aux)``.
  """
  # Resolve mode defaults from the ambient context settings.
  if graph is None:
    graph = graphlib.set_graph_mode.current_value()
  if graph_updates is None:
    graph_updates = graphlib.set_graph_updates.current_value()
  if graph and graph_updates:
    raise NotImplementedError(
      'graph-mode with graph_updates is not supported for nnx.jvp. '
      'Set graph=False or graph_updates=False.'
    )
  # Decorator usage: called without a function, return a partial.
  if isinstance(f, Missing):
    return functools.partial(
      jvp,
      has_aux=has_aux,
      graph=graph,
      graph_updates=graph_updates,
    )
  # Detect bound nnx.Module methods and raise error.
  f_unbound, _, was_bound = _resolve_bound_callable(f)
  if was_bound:
    _raise_bound_method_error('jvp')
  # Called with a function but missing primals/tangents: also decorator
  # usage, return a partial bound to f.
  if isinstance(primals, Missing) or isinstance(tangents, Missing):
    return functools.partial(
      jvp,
      f,
      has_aux=has_aux,
      graph=graph,
      graph_updates=graph_updates,
    )
  if graph:
    primals = extract.to_tree2(primals)
    tangents = extract.to_tree2(tangents)
  extract.check_no_aliases('jvp', primals=primals)
  extract.check_no_aliases('jvp', tangents=tangents)
  # SimpleJvpFn pairs each primal output with the Variable updates; the
  # tangent of the updates is discarded.
  if has_aux:
    (primals_out, updates), (tangent_out, _updates_tangent), aux = jax.jvp(
      SimpleJvpFn(f_unbound, has_aux=True, graph=graph),
      primals,
      tangents,
      has_aux=True,
    )
  else:
    (primals_out, updates), (tangent_out, _updates_tangent) = jax.jvp(
      SimpleJvpFn(f_unbound, has_aux=False, graph=graph),
      primals,
      tangents,
    )
  if graph:
    primals_out = extract.from_tree2(primals_out)
    tangent_out = extract.from_tree2(tangent_out)
  # write Variable mutations made inside f back into the primals
  extract.apply_variable_updates(primals, updates)
  if has_aux:
    return primals_out, tangent_out, aux
  else:
    return primals_out, tangent_out
# -----------------------------------------------
# custom_vjp
# -----------------------------------------------
@dataclasses.dataclass(eq=False)
class SimpleCustomVjpFn:
  """Primal-function adapter used by ``SimpleCustomVjp``."""

  f: tp.Callable[..., tp.Any]
  graph: bool

  def __post_init__(self):
    functools.update_wrapper(self, self.f, updated=())

  @extract.treemap_copy_args
  def __call__(self, *args):
    updates, snapshot = extract.updates_and_snapshot(args)
    if self.graph:
      args = extract.from_tree2(args)
    result = self.f(*args)
    if self.graph:
      result = extract.to_tree2(result)
    extract.check_no_aliases('custom_vjp', args=updates, out=result)
    masked_updates = extract.mask_variable_updates(updates, snapshot)
    return result, masked_updates
@dataclasses.dataclass(eq=False)
class SimpleFwdFn:
  """Forward-pass adapter used by ``SimpleCustomVjp.defvjp``."""

  fwd: tp.Callable[..., tp.Any]
  graph: bool

  def __post_init__(self):
    functools.update_wrapper(self, self.fwd, updated=())

  @extract.treemap_copy_args
  def __call__(self, *args):
    updates, snapshot = extract.updates_and_snapshot(args)
    if self.graph:
      args = extract.from_tree2(args)
    primal_out, residual = self.fwd(*args)
    if self.graph:
      primal_out = extract.to_tree2(primal_out)
      residual = extract.to_tree2(residual)
    extract.check_no_aliases('custom_vjp', args=updates, out=primal_out)
    masked_updates = extract.mask_variable_updates(updates, snapshot)
    return (primal_out, masked_updates), residual
@dataclasses.dataclass(eq=False)
class SimpleBwdFn:
  """Backward-pass adapter used by ``SimpleCustomVjp.defvjp``."""

  bwd: tp.Callable[..., tp.Any]
  graph: bool

  @extract.treemap_copy_args
  def __call__(self, *args):
    # last arg is (cotangent, update-cotangent); updates' cotangent is
    # ignored, the one before last is the residual, the rest are nondiff
    out_g, _updates_g = args[-1]
    residual = args[-2]
    nondiff = list(args[:-2])
    if self.graph:
      nondiff = extract.from_tree2(nondiff)
      residual = extract.from_tree2(residual)
    grads = self.bwd(*nondiff, residual, out_g)
    if self.graph:
      grads = extract.to_tree2(grads)
    return grads

  def __post_init__(self):
    functools.update_wrapper(self, self.bwd, updated=())
class SimpleCustomVjp(tp.Generic[A]):
  """Tree-mode implementation of custom VJP on top of ``jax.custom_vjp``.

  Converts graph-node arguments to pure trees before calling the wrapped
  ``jax.custom_vjp`` function, applies Variable updates afterwards, and
  rejects mutations of differentiable arguments (their gradient would be
  silently dropped).
  """

  def __init__(
    self,
    fun: tp.Callable[..., A],
    nondiff_argnums: tuple[int, ...],
    graph: bool,
  ):
    functools.update_wrapper(self, fun)
    self.fun = fun
    self.nondiff_argnums = nondiff_argnums
    self.graph = graph
    # jax-level custom_vjp over the pure-tree adapter
    self.custom_vjp_fn = jax.custom_vjp(
      fun=SimpleCustomVjpFn(fun, graph=graph),
      nondiff_argnums=nondiff_argnums,
    )

  def __call__(
    self, *args: tp.Any, **kwargs: tp.Any
  ) -> A:
    args = resolve_kwargs(self.fun, args, kwargs)
    del kwargs
    if self.graph:
      # only differentiable positions get the graph-prefix conversion
      prefix = tuple(
        i not in self.nondiff_argnums for i in range(len(args))
      )
      args = extract.to_tree2(args, prefix=prefix)
      extract.check_no_aliases('custom_vjp', args=args)
    (out, updates) = self.custom_vjp_fn(*args)
    # check that differentiable arguments were not mutated
    diff_argnums = tuple(
      i for i in range(len(args)) if i not in self.nondiff_argnums
    )
    is_var = lambda x: isinstance(x, variablelib.Variable)
    for i in diff_argnums:
      # non-None leaves in updates[i] mark Variables that were written to
      changed = [
        jax.tree_util.keystr(path)
        for path, leaf in jax.tree.leaves_with_path(
          updates[i], is_leaf=is_var
        )
        if leaf is not None
      ]
      if changed:
        paths_str = '\n '.join(changed)
        raise ValueError(
          f'Variables in differentiable argument {i} were mutated inside '
          f'custom_vjp at:\n\n {paths_str}\n\nThis is not supported when '
          f'graph_updates=False because the gradient for the Variable '
          f'updates would be silently dropped. Move the Variable mutation '
          f'to a non-differentiable argument, or use graph_updates=True.'
        )
    if self.graph:
      out = extract.from_tree2(out)
    # propagate allowed (non-differentiable) Variable mutations
    extract.apply_variable_updates(args, updates)
    return out

  def defvjp(
    self,
    fwd: tp.Callable[..., tuple[A, tp.Any]],
    bwd: tp.Callable[..., tuple[tp.Any, ...]],
    symbolic_zeros: bool = False,
  ) -> None:
    """Register forward/backward rules (mirrors ``jax.custom_vjp.defvjp``)."""
    self.fwd = fwd
    self.bwd = bwd
    self.symbolic_zeros = symbolic_zeros
    self.custom_vjp_fn.defvjp(
      fwd=SimpleFwdFn(fwd, graph=self.graph),
      bwd=SimpleBwdFn(bwd, graph=self.graph),
      symbolic_zeros=symbolic_zeros,
    )
# custom_vjp is one of the most complicated transforms as it requires
# to handle 4 different functions:
# 1. CustomVJP: the main object that runs the outer logic, converts input graph nodes
# to pytrees and output pytrees to graph nodes.
# 2. CustomVjpFnWrapper: function that wraps the user's function, it converts
# its input pytrees to graph nodes and output graph nodes to pytrees.
# 3. FwdFn: wraps the user's fwd function, it converts its input pytrees to graph nodes
# and output graph nodes to pytrees. Since it might run by itself in a separate context,
# it needs to be aware if the update_context is active or not in order to update the outer
# references.
# 4. BwdFn: wraps the user's bwd function, it converts its input pytrees to graph nodes
# and output graph nodes to pytrees. It doesn't need to be aware of the outer context
# since it will never update the outer references as it runs during the backward pass.
def _custom_vjp_merge_fn(
  ctx: graphlib.MergeContext,
  path,
  prefix: bool | DiffState,
  value: extract.NodeStates,
  *,
  nondiff_states: deque[extract.GraphDefState],
):
  """Recombines a differentiable state leaf with its held-out counterpart.

  Entries are consumed from ``nondiff_states`` in the same order
  ``_custom_vjp_split_fn`` appended them.
  """
  held_out = nondiff_states.popleft()
  return ctx.merge(held_out.graphdef, value.state, held_out.state)
def _custom_vjp_split_fn(
  ctx: graphlib.SplitContext,
  path,
  prefix: bool | DiffState,
  value,
  *,
  nondiff_states: list[extract.GraphDefState],
):
  """Splits a graph-node argument for custom_vjp.

  The differentiable substate is returned as a plain NodeStates (no
  graphdef, so gradients stay free of metadata); the graphdef plus any
  non-differentiable substate is appended to ``nondiff_states`` for
  broadcast into the forward/backward wrappers.
  """
  broadcast: graphlib.GraphState
  if prefix is False:
    # pure non-differentiable arg, not supported
    raise TypeError(
      'Passing integers to nondiff_argnums for graph nodes arguments in custom_vjp is not supported. '
      f'Got {prefix} at path {jax.tree_util.keystr(path)} for value {value}'
    )
  elif prefix is True:
    # pure differentiable arg, we pass all the state through
    # but we return a TreeNode.from_states which doesn't have a graphdef
    # in order to keep the gradients clean from any metadata
    graphdef, passed = ctx.split(value)
    broadcast = State({})
    nondiff_states.append(extract.GraphDefState(graphdef, broadcast))
    return extract.NodeStates.from_states(passed)
  else:
    # differentiable arg with DiffState filter, we use the filter to split the state
    # as before we return a TreeNode.from_states to keep the gradients clean
    # from any metadata, the non-differentiable state is stored in a deque
    # which is broadcasted during the forward pass
    graphdef, passed, broadcast = ctx.split(value, prefix.filter, ...)  # type: ignore[misc]
    nondiff_states.append(extract.GraphDefState(graphdef, broadcast))
    return extract.NodeStates.from_states(passed)
# NOTE(review): these two struct.field declarations appear at module level in
# this chunk; they read like attributes of a ``struct.PyTreeNode`` class whose
# ``class`` header is not visible here — confirm against the full source.
nondiff_argnums: tuple[int, ...] = struct.field(pytree_node=False)
tangent_tree_node_args: tuple[tp.Any, ...] = struct.field(pytree_node=False)
def _extract_nodedefs(x, *, nodedefs: deque[graphlib.GraphDef]):
  """Collect ``GraphDef`` leaves into ``nodedefs`` and strip their outer index.

  Non-``GraphDef`` leaves pass through unchanged.
  """
  if not isinstance(x, graphlib.GraphDef):
    return x
  nodedefs.append(x)
  return x.with_no_outer_index()
@dataclasses.dataclass(eq=False)
class CustomVjpFnWrapper:
  """Wraps the user's primal function ``f`` for ``jax.custom_vjp``.

  Converts incoming pytrees back into graph nodes before calling ``f`` and
  converts the result — together with the updated input graph nodes — back
  into pytrees, so state updates made inside ``f`` propagate to the caller.
  """

  # user's primal function
  f: tp.Callable[..., tp.Any]
  # integer-only nondiff argnums forwarded to jax.custom_vjp
  jax_nondiff_argnums: tuple[int, ...]
  # tag identifying the enclosing update_context
  ctxtag: str
  # per-argument (graphdef, nondiff state) entries queued by the split fn
  nondiff_states: list[extract.GraphDefState]
  # GraphDefs whose outer index was stripped, for reinsertion by CustomVjp
  nodedefs: deque[graphlib.GraphDef]

  def __post_init__(self):
    functools.update_wrapper(self, self.f)

  def __call__(self, *pure_args):
    # local deque copy: popleft() consumption must not mutate the shared list
    nondiff_states = deque(self.nondiff_states)
    args = extract.from_tree(
      pure_args,
      merge_fn=functools.partial(
        _custom_vjp_merge_fn, nondiff_states=nondiff_states
      ),
      ctxtag=self.ctxtag,
      is_inner=True,
    )
    out = self.f(*args)
    # remove nondiff from pure_args_out_g
    args_out = tuple(
      x for i, x in enumerate(args) if i not in self.jax_nondiff_argnums
    )
    args_out = extract.clear_non_graph_nodes(args_out)
    pure_args_out, pure_out = extract.to_tree(
      (args_out, out), ctxtag=self.ctxtag
    )
    # remove outer_index from GraphDef's but store them in global context
    pure_args_out, pure_out = jax.tree.map(
      functools.partial(_extract_nodedefs, nodedefs=self.nodedefs),
      (pure_args_out, pure_out),
      is_leaf=lambda x: isinstance(x, graphlib.GraphDef),
    )
    return pure_args_out, pure_out
@dataclasses.dataclass(eq=False)
class FwdFn:
  """Wraps the user's ``fwd`` function for ``jax.custom_vjp.defvjp``.

  Converts input pytrees to graph nodes and outputs back to pytrees. Unlike
  ``BwdFn`` it must know whether the enclosing update_context is active so
  updates can reach the outer references.
  """

  # user's forward function, returns (out, residual)
  fwd: tp.Callable[..., tp.Any]
  # integer nondiff argnums (graph-node args are never plain ints here)
  nondiff_argnums: tuple[int, ...]
  # tag identifying the enclosing update_context
  ctxtag: str
  # per-argument (graphdef, nondiff state) entries queued by the split fn
  nondiff_states: list[extract.GraphDefState]
  # GraphDefs whose outer index was stripped, for reinsertion by CustomVjp
  nodedefs: deque[graphlib.GraphDef]

  def __post_init__(self):
    functools.update_wrapper(self, self.fwd)

  def __call__(self, *pure_args):
    # here we need to be aware if the update_context is active or not
    # when its not active, index_mappings will be None
    # when its active, we will remove the index_mappings from the GraphDef's and store them
    # in the index_mappings deque created by CustomVjp
    update_context_active = (
      self.ctxtag in graphlib.GRAPH_CONTEXT.update_context_stacks
    )
    # local deque copy: popleft() consumption must not mutate the shared list
    nondiff_states = deque(self.nondiff_states)
    args = extract.from_tree(
      pure_args,
      merge_fn=functools.partial(
        _custom_vjp_merge_fn, nondiff_states=nondiff_states
      ),
      ctxtag=self.ctxtag if update_context_active else None,
      is_inner=True,
    )
    out, residual = self.fwd(*args)
    # remove nondiff from pure_args_out_g
    args_out = tuple(
      x for i, x in enumerate(args) if i not in self.nondiff_argnums
    )
    args_out = extract.clear_non_graph_nodes(args_out)
    pure_args_out, pure_out = extract.to_tree(
      (args_out, out),
      ctxtag=self.ctxtag if update_context_active else None,
    )
    pure_residual = extract.to_tree(residual)
    if update_context_active:
      # remove outer_index from GraphDef's but store them in global context
      pure_args_out, pure_out = jax.tree.map(
        functools.partial(_extract_nodedefs, nodedefs=self.nodedefs),
        (pure_args_out, pure_out),
        is_leaf=lambda x: isinstance(x, graphlib.GraphDef),
      )
    return (pure_args_out, pure_out), pure_residual
@dataclasses.dataclass(eq=False)
class BwdFn:
  """Wraps the user's ``bwd`` function for ``jax.custom_vjp.defvjp``.

  Runs during the backward pass, so it never updates outer references and
  does not need to be aware of the enclosing update_context.
  """

  # user's backward function
  bwd: tp.Callable[..., tp.Any]
  # per-argument pytree of bools: True where the primal leaf was a graph
  # node (NodeStates), i.e. where the tangent must be re-wrapped
  tree_node_args: tuple[tp.Any, ...]

  def __post_init__(self):
    functools.update_wrapper(self, self.bwd)

  def __call__(self, *args):
    # jax passes the nondiff args first, then the residual, then the
    # cotangents for ((input updates), output)
    *nondiff, pure_residual, (pure_args_out_g, pure_out_g) = args
    residual = extract.from_tree(pure_residual, is_inner=True)
    # unwrap NodeStates so the user's bwd sees plain State gradients
    (pure_args_out_g, pure_out_g) = jax.tree.map(
      lambda x: x.state if isinstance(x, extract.NodeStates) else x,
      (pure_args_out_g, pure_out_g),
      is_leaf=lambda x: isinstance(x, extract.NodeStates),
    )
    tangent = self.bwd(*nondiff, residual, (pure_args_out_g, pure_out_g))

    def state_to_node_states(is_differentiable: bool, x):
      # re-wrap State/Variable tangents for differentiable graph-node args
      # so the tangent pytree structure matches the primal inputs
      if is_differentiable:
        if isinstance(x, jax.Array):
          return x
        elif not isinstance(x, State | variablelib.Variable):
          raise ValueError(f'Expected State or Variable, got {type(x)}')
        return extract.NodeStates.from_states(x)
      return x

    pure_tangent = jax.tree.map(
      state_to_node_states,
      self.tree_node_args,
      tangent,
      is_leaf=lambda x: isinstance(x, State | variablelib.Variable),
    )
    return pure_tangent
class CustomVjp(tp.Generic[A]):
  """Reference-aware wrapper around ``jax.custom_vjp`` (graph mode).

  Splits graph-node arguments into pytrees, defers to ``jax.custom_vjp``
  with wrapped primal/fwd/bwd functions, and propagates state updates made
  inside the transform back to the caller's references.
  """

  def __init__(
    self,
    fun: tp.Callable[..., A],
    nondiff_argnums: tuple[int | DiffState, ...],
  ):
    functools.update_wrapper(self, fun)
    # JAX itself only accepts plain integers for nondiff_argnums; DiffState
    # entries are handled on the NNX side via diff_filter below.
    self.jax_nondiff_argnums = tuple(
      x for x in nondiff_argnums if isinstance(x, int)
    )
    self.ctxtag = f'custom_vjp_{fun.__name__}_{id(fun)}'
    self.fun = fun
    # fwd/bwd/symbolic_zeros are populated by defvjp(); calling the wrapper
    # before defvjp() is an error.
    self.fwd: tp.Callable | None = None
    self.bwd: tp.Callable | None = None
    self.symbolic_zeros: bool | None = None
    self.nondiff_argnums = nondiff_argnums
    # maps argument position -> False (plain nondiff) or a positionless
    # DiffState whose filter selects the differentiable substate
    self.diff_filter: dict[int, tp.Literal[False] | DiffState] = {}
    for argnum in self.nondiff_argnums:
      index = argnum.argnum if isinstance(argnum, DiffState) else argnum
      if index in self.diff_filter:
        raise ValueError(f'argnum {index} is repeated in nondiff_argnums')
      self.diff_filter[index] = (
        dataclasses.replace(argnum, argnum=-1)
        if isinstance(argnum, DiffState)
        else False
      )

  def __call__(
    self, *args: tp.Any, **kwargs: tp.Any
  ) -> A:  # pytype: disable=invalid-annotation
    with graphlib.update_context(self.ctxtag):
      args = resolve_kwargs(self.fun, args, kwargs)
      del kwargs
      nondiff_states: list[extract.GraphDefState] = []
      # per-argument prefix: True (fully differentiable), False (nondiff
      # int argnum) or a DiffState filter
      arg_filters = tuple(
        self.diff_filter.get(i, True) for i in range(len(args))
      )
      pure_args = extract.to_tree(
        args,
        prefix=arg_filters,
        split_fn=functools.partial(
          _custom_vjp_split_fn, nondiff_states=nondiff_states
        ),
        ctxtag=self.ctxtag,
      )
      # boolean mask marking which leaves are graph nodes; BwdFn uses it to
      # re-wrap the corresponding tangents
      tree_node_args = jax.tree.map(
        lambda x: isinstance(x, extract.NodeStates),
        pure_args,
        is_leaf=lambda x: isinstance(x, extract.NodeStates),
      )
      tree_node_args = tuple(
        x
        for i, x in enumerate(tree_node_args)
        if i not in self.jax_nondiff_argnums
      )
      nodedefs: deque[graphlib.GraphDef] = deque()
      if self.fwd is None or self.bwd is None or self.symbolic_zeros is None:
        # previously raised a bare ValueError() with no message
        raise ValueError(
          'defvjp must be called before a custom_vjp-decorated function '
          'can be used.'
        )
      custom_vjp_fn = jax.custom_vjp(
        fun=CustomVjpFnWrapper(
          f=self.fun,
          jax_nondiff_argnums=self.jax_nondiff_argnums,
          ctxtag=self.ctxtag,
          nondiff_states=nondiff_states,
          nodedefs=nodedefs,
        ),
        nondiff_argnums=self.jax_nondiff_argnums,
      )
      custom_vjp_fn.defvjp(
        fwd=FwdFn(
          fwd=self.fwd,
          nondiff_argnums=self.jax_nondiff_argnums,
          ctxtag=self.ctxtag,
          nondiff_states=nondiff_states,
          nodedefs=nodedefs,
        ),
        bwd=BwdFn(
          bwd=self.bwd,
          tree_node_args=tree_node_args,
        ),
        symbolic_zeros=self.symbolic_zeros,
      )
      pure_args_out, pure_out = custom_vjp_fn(*pure_args)

      # reinsert the outer-index mappings that FwdFn/CustomVjpFnWrapper
      # stripped from the GraphDefs, in traversal order
      def _insert_index_mappings(x):
        if isinstance(x, graphlib.GraphDef):
          nodedef = nodedefs.popleft()
          return nodedef
        return x

      pure_args_out, pure_out = jax.tree_util.tree_map(
        _insert_index_mappings,
        (pure_args_out, pure_out),
        is_leaf=lambda x: isinstance(x, graphlib.GraphDef),
      )
      # merging back through the update_context applies the state updates
      # to the caller's references as a side effect
      args_out, out = extract.from_tree(
        (pure_args_out, pure_out), ctxtag=self.ctxtag, is_inner=False
      )
      return out

  def defvjp(
    self,
    fwd: tp.Callable[..., tuple[A, tp.Any]],
    bwd: tp.Callable[..., tuple[tp.Any, ...]],
    symbolic_zeros: bool = False,
  ) -> None:
    """Register the forward and backward functions (see ``jax.custom_vjp.defvjp``)."""
    self.fwd = fwd
    self.bwd = bwd
    self.symbolic_zeros = symbolic_zeros
# Typing-only overloads: custom_vjp may be applied directly to a function or
# used as a decorator factory with keyword options only.
@tp.overload
def custom_vjp(
  fun: tp.Callable[..., A],
  *,
  nondiff_argnums: tuple[int | DiffState, ...] = (),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> CustomVjp[A] | SimpleCustomVjp[A]: ...
@tp.overload
def custom_vjp(
  *,
  nondiff_argnums: tuple[int | DiffState, ...] = (),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> tp.Callable[[tp.Callable[..., A]], CustomVjp[A] | SimpleCustomVjp[A]]: ...
def custom_vjp(
  fun: tp.Callable[..., A] | Missing = MISSING,
  *,
  nondiff_argnums: tuple[int | DiffState, ...] = (),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> (
  CustomVjp[A]
  | SimpleCustomVjp[A]
  | tp.Callable[[tp.Callable[..., A]], CustomVjp[A] | SimpleCustomVjp[A]]
):
  """Reference aware version of
  `jax.custom_vjp <https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_vjp.html>`__.

  ``nnx.custom_vjp`` accepts Modules and other Flax NNX objects as arguments. The main difference
  with the JAX version is that, because Modules follow reference semantics, they propagate the State
  updates for the inputs as auxiliary outputs. This means that the incoming gradients in the ``bwd`` function
  will have the form ``(input_updates_g, out_g)`` where ``input_updates_g`` is the gradient updated state of
  the inputs w.r.t. to the inputs. All Module terms on the inputs will have an associated ``State`` term in
  ``input_updates_g``, while all non-Module terms will appear as None. The shape of the tangent will be
  expected to have the same shape as the input, with ``State`` terms in place of the corresponding Module terms.

  Example::

    >>> import jax
    >>> import jax.numpy as jnp
    >>> from flax import nnx
    ...
    >>> class Foo(nnx.Module):
    ...   def __init__(self, x, y):
    ...     self.x = nnx.Param(x)
    ...     self.y = nnx.Param(y)
    ...
    >>> @nnx.custom_vjp
    ... def f(m: Foo):
    ...   return jnp.sin(m.x) * m.y
    ...
    >>> def f_fwd(m: Foo):
    ...   return f(m), (jnp.cos(m.x), jnp.sin(m.x), m)
    ...
    >>> def f_bwd(res, g):
    ...   input_updates_g, out_g = g
    ...   cos_x, sin_x, m = res
    ...   (m_updates_g,) = input_updates_g
    ...   m_g = jax.tree.map(lambda x: x, m_updates_g)  # create copy
    ...
    ...   m_g['x'][...] = cos_x * out_g * m.y
    ...   m_g['y'][...] = sin_x * out_g
    ...   return (m_g,)
    ...
    >>> f.defvjp(f_fwd, f_bwd)
    ...
    >>> m = Foo(x=jnp.array(1.), y=jnp.array(2.))
    >>> grads = nnx.grad(f)(m)
    ...
    >>> jax.tree.map(jnp.shape, grads)
    State({
      'x': Param(
        value=()
      ),
      'y': Param(
        value=()
      )
    })

  Note that the State objects that represent Module terms on ``input_updates_g`` have the
  same shape as the State objects expected in the output tangent. This means that you can
  usually just copy them from ``input_updates_g`` and update them with their corresponding
  gradient values.

  You can select which substates are differentiable (have a tangent) for Modules and other
  graph nodes by passing a ``DiffState`` to ``nondiff_argnums``. For example, if you want to
  differentiate only the ``x`` attribute of the ``Foo`` class, you can do the following::

    >>> x_attribute = nnx.PathContains('x')
    >>> diff_state = nnx.DiffState(0, x_attribute)
    ...
    >>> @nnx.custom_vjp(nondiff_argnums=(diff_state,))
    ... def f(m: Foo):
    ...   return jnp.sin(m.x) * m.y  # type: ignore
    >>> def f_fwd(m: Foo):
    ...   y = f(m)
    ...   res = (jnp.cos(m.x), m)  # type: ignore
    ...   return y, res
    ...
    >>> def f_bwd(res, g):
    ...   input_updates_g, out_g = g
    ...   cos_x, m = res
    ...   (m_updates_g,) = input_updates_g
    ...   m_g = jax.tree.map(lambda x: x, m_updates_g)  # create copy
    ...
    ...   m_g.x[...] = cos_x * out_g * m.y
    ...   del m_g['y']  # y is not differentiable
    ...   return (m_g,)
    >>> f.defvjp(f_fwd, f_bwd)
    ...
    >>> m = Foo(x=jnp.array(1.), y=jnp.array(2.))
    >>> grad = nnx.grad(f, argnums=nnx.DiffState(0, x_attribute))(m)
    ...
    >>> jax.tree.map(jnp.shape, grad)
    State({
      'x': Param(
        value=()
      )
    })

  Note that ``grad`` cannot calculate gradients for states that don't have a tangent
  defined by ``custom_vjp``, in the example above we reuse the same ``x_attribute``
  filter to keep ``custom_vjp`` and ``grad`` in sync.

  **graph_updates=False**

  When ``graph_updates=False`` or ``graph=False``, the behavior is closer to
  ``jax.custom_vjp``: the ``bwd`` function receives ``out_g`` directly, and
  tangent types are the same as the input types, this means the tangent for a
  Module is a Module instance with gradient values set on its attributes.
  This mode does not support ``DiffState`` in ``nondiff_argnums``. Additionally,
  Variables in differentiable arguments cannot be mutated inside ``f``. If
  mutations are needed, pass the relevant Variables through a non-differentiable
  argument instead.

  Example::

    >>> @nnx.custom_vjp(graph_updates=False)
    ... def f(m: Foo):
    ...   return jnp.sin(m.x) * m.y
    ...
    >>> def f_fwd(m: Foo):
    ...   return f(m), (jnp.cos(m.x), jnp.sin(m.x), m)
    ...
    >>> def f_bwd(res, g):
    ...   cos_x, sin_x, m = res
    ...   m_g = nnx.clone(m)
    ...   m_g.x[...] = cos_x * g * m.y
    ...   m_g.y[...] = sin_x * g
    ...   return (m_g,)
    ...
    >>> f.defvjp(f_fwd, f_bwd)

  Args:
    fun: Callable base function.
    nondiff_argnums: Tuple of integers or DiffState objects specifying the
      argument indices that are not differentiated. By default all arguments are
      differentiated. Integers cannot be used to mark graph nodes such as Modules
      as non-differentiable, in this case use a DiffState object. DiffState objects
      define the set of differentiable substates, contrary to what the name of this
      argument suggests, this is done for compatibility with ``grad``.
    graph: If ``True`` (default), uses graph-mode which supports the full
      NNX feature set including shared references and reference semantics.
      If ``False``, uses tree-mode which treats Modules as regular JAX
      pytrees, avoiding the overhead of the graph protocol. Tree-mode does
      not support ``DiffState`` in ``nondiff_argnums``.
    graph_updates: If ``True``, propagates updates on graph structure
      that happen inside the transform to the input graphs, has no
      effect when ``graph=False``. When ``False``, using ``DiffState``
      is not supported.
  """
  if graph is None:
    graph = graphlib.set_graph_mode.current_value()
  if graph_updates is None:
    graph_updates = graphlib.set_graph_updates.current_value()
  if isinstance(fun, Missing):
    # Used as a decorator factory: return a partial awaiting `fun`.
    return functools.partial(
      custom_vjp,
      nondiff_argnums=nondiff_argnums,
      graph=graph,
      graph_updates=graph_updates,
    )
  # Detect bound nnx.Module methods and raise error.
  fun_unbound, _, was_bound = _resolve_bound_callable(fun)
  if was_bound:
    _raise_bound_method_error('custom_vjp')
  if not graph or not graph_updates:
    # tree-mode / no-updates path: DiffState filters require graph mode
    if any(isinstance(x, DiffState) for x in nondiff_argnums):
      raise ValueError(
        '`nondiff_argnums` cannot contain `DiffState` objects '
        'when `graph=False`. '
        + graphlib._tree_mode_suggestion_transform('custom_vjp')
      )
    return SimpleCustomVjp(fun_unbound, nondiff_argnums, graph=graph)  # type: ignore[arg-type]
  return CustomVjp(fun_unbound, nondiff_argnums)
# -------------------------------
# remat
# -------------------------------
@dataclasses.dataclass(eq=False)
class SimpleRematFn:
  """Tree-mode wrapper around the user function for ``jax.checkpoint``.

  Returns ``(out, updates)`` so that Variable mutations made inside ``f``
  can be re-applied to the caller's inputs outside the checkpointed region.
  """

  # user function to be rematerialized
  f: tp.Callable[..., tp.Any]
  # True: route args/outputs through the graph protocol (from/to_tree2)
  graph: bool

  def __post_init__(self):
    functools.update_wrapper(self, self.f, updated=())

  @extract.treemap_copy_args
  def __call__(self, *args, **kwargs):
    # snapshot input Variables so only real changes are propagated below
    updates, snapshot = extract.updates_and_snapshot((args, kwargs))
    if self.graph:
      args, kwargs = extract.from_tree2((args, kwargs))
    out = self.f(*args, **kwargs)
    if self.graph:
      out = extract.to_tree2(out)
    extract.check_no_aliases('remat', args=updates[0], kwargs=updates[1], out=out)
    # mask out Variables whose value did not change w.r.t. the snapshot
    updates = extract.mask_variable_updates(updates, snapshot)
    return out, updates
# Typing-only overloads: remat may be used as a decorator factory (options
# only) or applied directly to a function.
@tp.overload
def remat(
  *,
  prevent_cse: bool = True,
  static_argnums: int | tuple[int, ...] = (),
  policy: tp.Callable[..., bool] | None = None,
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> tp.Callable[[F], F]: ...
@tp.overload
def remat(
  f: F,
  *,
  prevent_cse: bool = True,
  static_argnums: int | tuple[int, ...] = (),
  policy: tp.Callable[..., bool] | None = None,
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> F: ...
def remat(
  f: F | Missing = MISSING,
  *,
  prevent_cse: bool = True,
  static_argnums: int | tuple[int, ...] = (),
  policy: tp.Callable[..., bool] | None = None,
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> F | tp.Callable[[F], F]:
  """A 'lifted' version of the
  `jax.checkpoint <https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html>`__
  (a.k.a. ``jax.remat``).

  ``flax.nnx.remat``, similar to ``jax.checkpoint`` can provide control over, for
  example, how ``flax.nnx.grad`` values are computed and saved during the forward pass versus
  how they are recomputed during the backward pass, trading off memory and FLOPs.

  Learn more in `Flax NNX vs JAX Transformations <https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html>`_.

  To learn about ``jax.remat``, go to JAX's
  `fundamentals of jax.checkpoint <https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#fundamentals-of-jax-checkpoint>`_
  and `practical notes <https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#practical-notes>`_.

  Args:
    f: Function to be rematerialized.
    prevent_cse: Optional, bool. If True, prevents common subexpression
      elimination. Default True.
    static_argnums: Optional, int or tuple of ints. Specifies which
      positional arguments to treat as static.
    policy: Optional, callable. A policy for which intermediates to save
      during the forward pass.
    graph: If ``True`` (default), uses graph-mode which supports the full
      NNX feature set including shared references and reference semantics.
      If ``False``, uses tree-mode which treats Modules as regular JAX
      pytrees, avoiding the overhead of the graph protocol. Tree-mode does
      not support shared ``Variable`` references.
    graph_updates: If ``True``, propagates updates on graph structure
      that happen inside the transform to the input graphs, has no
      effect when ``graph=False``.
  """
  if graph is None:
    graph = graphlib.set_graph_mode.current_value()
  if graph_updates is None:
    graph_updates = graphlib.set_graph_updates.current_value()
  if isinstance(f, Missing):
    # Used as a decorator factory: return a partial awaiting `f`.
    return functools.partial(
      remat,
      prevent_cse=prevent_cse,
      static_argnums=static_argnums,
      policy=policy,
      graph=graph,
      graph_updates=graph_updates,
    )  # type: ignore[return-value]
  # Detect bound nnx.Module methods and raise error.
  f_unbound, _, was_bound = _resolve_bound_callable(f)
  if was_bound:
    _raise_bound_method_error('remat')
  if not graph or not graph_updates:
    # tree-mode / no-updates path: checkpoint a simple (out, updates) wrapper
    # and replay the Variable updates on the caller's inputs afterwards
    checkpointed_fn = jax.checkpoint(
      SimpleRematFn(f_unbound, graph=graph),
      prevent_cse=prevent_cse,
      static_argnums=static_argnums,
      policy=policy,
    )

    @functools.wraps(f_unbound)
    def simple_remat_wrapper(*args, **kwargs):
      if graph:
        args, kwargs = extract.to_tree2((args, kwargs))
        extract.check_no_aliases('remat', args=args, kwargs=kwargs)
      out, updates = checkpointed_fn(*args, **kwargs)
      if graph:
        out = extract.from_tree2(out)
      extract.apply_variable_updates((args, kwargs), updates)
      return out

    return simple_remat_wrapper  # type: ignore[return-value]
  # Unbound function path: preserve the concise composition used in NNX.
  return resolve_kwargs()(  # type: ignore[return-value]
    graphlib.update_context('remat')(
      general.split_inputs(
        jax.checkpoint(
          general.merge_inputs(f_unbound, ctxtag='remat'),
          prevent_cse=prevent_cse,
          static_argnums=static_argnums,
          policy=policy,
        ),
        ctxtag='remat',
      ),
    )
  )