# Source code for flax.nnx.transforms.autodiff

# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import deque
import dataclasses
import functools
import typing as tp


from flax import struct
from flax.nnx import (
  extract,
  filterlib,
  graphlib,
  variablelib,
)
from flax.nnx.statelib import State
import jax

from flax.nnx.transforms import general
from flax.nnx.transforms.transforms import (
  resolve_kwargs,
  _resolve_bound_callable,
  _raise_bound_method_error,
)
from flax.typing import MISSING, Missing


# Generic type variable; used in this module to tie an input type to the
# matching output type (e.g. ``process_out`` inside ``_grad_general``).
A = tp.TypeVar('A')
# C = tp.TypeVar('C')
# B = tp.TypeVar('B')
# Type variable restricted to callables.
F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])
# G = tp.TypeVar('G', bound=tp.Callable[..., tp.Any])
# M = tp.TypeVar('M', bound=Module)
# MA = tp.TypeVar('MA', bound=Module)
# N = tp.TypeVar('N', bound=Module)
# StrInt = tp.TypeVar('StrInt', str, int)
# Alias used to annotate the (deprecated) ``reduce_axes`` arguments below.
AxisName = tp.Hashable
# Leaves = tp.List[Leaf]
# Index = int


# -------------------------------
# grad
# -------------------------------


@dataclasses.dataclass(frozen=True)
class DiffState:
  """Selects the differentiable substate of one positional argument.

  Instances may be passed through ``argnums`` to ``grad`` /
  ``value_and_grad`` in place of a plain integer index.
  """

  # Index of the positional argument to differentiate.
  argnum: int
  # Filter used to split that argument's state into a differentiable part
  # and a non-differentiable remainder.
  filter: filterlib.Filter


@dataclasses.dataclass(eq=False)
class SimpleGradFn:
  """Callable given to ``jax.grad`` / ``jax.value_and_grad`` on the simple path.

  Calls ``f`` and always returns a ``(value, aux)`` pair in which the
  auxiliary slot carries the masked Variable updates (plus the user's own
  aux when ``has_aux`` is set).
  """

  f: tp.Callable[..., tp.Any]
  has_aux: bool
  graph: bool

  def __post_init__(self):
    # Copy the metadata of ``f`` onto the wrapper; ``updated=()`` leaves the
    # instance ``__dict__`` untouched.
    functools.update_wrapper(self, self.f, updated=())

  @extract.treemap_copy_args
  def __call__(self, *args, **kwargs):
    inputs = (args, kwargs)
    # Taken from the inputs as received, before any graph conversion.
    updates, snapshot = extract.updates_and_snapshot(inputs)

    call_args, call_kwargs = (
      extract.from_tree2(inputs) if self.graph else inputs
    )
    result = self.f(*call_args, **call_kwargs)
    if self.graph:
      result = extract.to_tree2(result)

    extract.check_no_aliases(
      'grad', args=updates[0], kwargs=updates[1], out=result
    )
    masked_updates = extract.mask_variable_updates(updates, snapshot)

    if not self.has_aux:
      return result, masked_updates
    loss, aux = result
    return loss, (masked_updates, aux)


@dataclasses.dataclass(eq=False)
class GradFn:
  """Callable given to ``jax.grad`` / ``jax.value_and_grad`` on the graph path.

  Rebuilds the graph-node arguments from their pure representation, calls
  ``f`` and returns ``(loss, aux)`` where the auxiliary slot carries the pure
  output arguments (and the user's aux when ``has_aux`` is set).

  Attributes:
    f: the user function being differentiated.
    has_aux: whether ``f`` returns a ``(loss, aux)`` pair.
    nondiff_states: one entry per ``extract.NodeStates`` leaf, in merge
      order; ``None`` when the whole state travels in the pure arguments,
      otherwise the non-differentiable state to merge back in.
  """

  f: tp.Callable[..., tp.Any]
  has_aux: bool
  nondiff_states: deque[State | None]

  def __post_init__(self):
    functools.update_wrapper(self, self.f)

  def __call__(self, *pure_args):
    # Consume a private copy so that calling this object does not drain
    # ``self.nondiff_states``; the wrapper then stays valid if it is invoked
    # more than once (same approach as CustomVjpFnWrapper / FwdFn).
    nondiff_states = deque(self.nondiff_states)

    # rebuild diff_state from substates in args
    def _grad_merge_fn(
      ctx: graphlib.MergeContext, path, prefix, value: extract.NodeStates
    ):
      nondiff = nondiff_states.popleft()
      if nondiff is None:
        return ctx.merge(value.graphdef, value.state)
      else:
        return ctx.merge(value.graphdef, value.state, nondiff)

    args = extract.from_tree(
      pure_args, merge_fn=_grad_merge_fn, ctxtag='grad', is_inner=True
    )

    out = self.f(*args)

    args_out = extract.clear_non_graph_nodes(args)
    pure_args_out, pure_out = extract.to_tree((args_out, out), ctxtag='grad')

    if self.has_aux:
      loss, pure_aux = pure_out
      fn_out = (loss, (pure_args_out, pure_aux))
    else:
      loss = pure_out
      fn_out = (loss, pure_args_out)

    return fn_out


def _grad_general(
    f: tp.Callable[..., tp.Any],
    argnums: int | DiffState | tp.Sequence[int | DiffState],
    has_aux: bool,
    holomorphic: bool,
    allow_int: bool,
    return_value: bool,
    graph: bool,
    graph_updates: bool,
) -> tp.Callable[..., tp.Any]:
  """Build the wrapper shared by ``grad`` and ``value_and_grad``.

  Two code paths exist:

  * ``graph`` or ``graph_updates`` is false: ``f`` is wrapped in
    ``SimpleGradFn`` and Variable updates are applied after the call.
    ``DiffState`` entries in ``argnums`` are rejected on this path.
  * otherwise: arguments are split with ``extract.to_tree`` using one
    ``DiffState`` filter per differentiated argument (``variablelib.Param``
    when only an integer was given) and ``f`` is wrapped in ``GradFn``.

  Args:
    f: function to differentiate.
    argnums: positional index/indices, optionally wrapped in ``DiffState``.
    has_aux: whether ``f`` returns ``(loss, aux)``.
    holomorphic: forwarded to the JAX transform.
    allow_int: forwarded to the JAX transform.
    return_value: use ``jax.value_and_grad`` instead of ``jax.grad``.
    graph: selects graph-mode handling of the arguments.
    graph_updates: together with ``graph`` selects the ``GradFn`` path.

  Returns:
    A wrapper returning ``grads``, ``(grads, aux)``, ``(loss, grads)`` or
    ``((loss, aux), grads)`` depending on ``return_value`` / ``has_aux``.

  Raises:
    ValueError: if ``argnums`` contains a ``DiffState`` on the simple path,
      or if an index is repeated in ``argnums`` on the graph path.
  """

  transform = jax.value_and_grad if return_value else jax.grad

  # ---- simple path -------------------------------------------------------
  if not graph or not graph_updates:
    if any(isinstance(x, DiffState) for x in jax.tree.leaves(argnums)):
      raise ValueError(
        '`argnums` cannot contain `DiffState` objects '
        'when `graph=False`. '
        + graphlib._tree_mode_suggestion_transform('grad')
      )

    # SimpleGradFn always returns an aux slot (the Variable updates), hence
    # ``has_aux=True`` regardless of the user's ``has_aux``.
    gradded_fn = transform(
        SimpleGradFn(f, has_aux, graph=graph),
        argnums=argnums,  # type: ignore[arg-type]
        has_aux=True,
        holomorphic=holomorphic,
        allow_int=allow_int,
    )

    def tree_grad_wrapper(*args, **kwargs):
      if graph:
        # Boolean prefix: True for differentiated positional arguments,
        # False for the rest and for all keyword arguments.
        diff_argnums = (argnums,) if isinstance(argnums, int) else argnums
        args_prefix = tuple(
          i in diff_argnums for i in range(len(args))
        )
        args, kwargs = extract.to_tree2(
          (args, kwargs), prefix=(args_prefix, False),
        )

      extract.check_no_aliases('grad', args=args, kwargs=kwargs)

      fn_out = gradded_fn(*args, **kwargs)

      # Unpack according to the transform used and to ``has_aux``; in
      # graph mode grads (and aux) are converted back with ``from_tree2``.
      if return_value:
        if has_aux:
          (loss, (updates, aux)), grads = fn_out
          if graph: grads, aux = extract.from_tree2((grads, aux))
          result = (loss, aux), grads
        else:
          (loss, updates), grads = fn_out
          if graph: grads = extract.from_tree2(grads)
          result = loss, grads
      else:
        if has_aux:
          grads, (updates, aux) = fn_out
          if graph: grads, aux = extract.from_tree2((grads, aux))
          result = grads, aux
        else:
          grads, updates = fn_out
          if graph: grads = extract.from_tree2(grads)
          result = grads

      # Write the Variable updates produced inside ``f`` back to the inputs.
      extract.apply_variable_updates((args, kwargs), updates)
      return result

    return tree_grad_wrapper

  # ---- graph path --------------------------------------------------------
  # Plain integer indices for the JAX transform (DiffState wrappers removed).
  jax_argnums: int | tuple[int, ...]
  if isinstance(argnums, (int, DiffState)):
    jax_argnums = argnums.argnum if isinstance(argnums, DiffState) else argnums
  else:
    jax_argnums = tuple(
      x.argnum if isinstance(x, DiffState) else x for x in argnums
    )

  # Map positional index -> DiffState. The stored DiffState always has
  # ``argnum=-1``; integer entries get the ``variablelib.Param`` filter.
  _argnums = (argnums,) if isinstance(argnums, (int, DiffState)) else argnums
  index_filter: dict[int, DiffState] = {}
  for argnum in _argnums:
    index = argnum.argnum if isinstance(argnum, DiffState) else argnum
    if index in index_filter:
      raise ValueError(f'argnum {index} is repeated in argnums')
    index_filter[index] = (
      dataclasses.replace(argnum, argnum=-1)
      if isinstance(argnum, DiffState)
      else DiffState(-1, variablelib.Param)
    )

  @graphlib.update_context('grad')
  def grad_wrapper(*args, **kwargs):
    # Keyword arguments are folded into positional ones.
    args = resolve_kwargs(f, args, kwargs)
    del kwargs
    # Filled by ``_grad_split_fn`` and consumed in the same order by GradFn.
    nondiff_states: deque[State | variablelib.Variable | None] = deque()

    def _grad_split_fn(
      ctx: graphlib.SplitContext, path, prefix: DiffState | None, value
    ):
      # No filter for this argument, or the value is a bare Variable under a
      # filter entry: split without a filter and record no nondiff state.
      if prefix is None or (prefix.argnum == -1 and isinstance(value, variablelib.Variable)):
        nondiff_states.append(None)
        return extract.NodeStates.from_split(*ctx.split(value))
      else:
        # Only the state matching the filter is kept in the pure argument.
        graphdef, diff, nondiff = ctx.split(value, prefix.filter, ...)  # type: ignore[misc]
        nondiff_states.append(nondiff)  # type: ignore[container-type-mismatch]
        return extract.NodeStates.from_split(graphdef, diff)

    # ``None`` for arguments that are not differentiated.
    arg_filters = tuple(index_filter.get(i) for i in range(len(args)))
    pure_args = extract.to_tree(
      args, prefix=arg_filters, split_fn=_grad_split_fn, ctxtag='grad'
    )

    # GradFn always returns an aux slot (the pure output arguments).
    gradded_fn = transform(
      GradFn(f, has_aux, nondiff_states),
      argnums=jax_argnums,
      has_aux=True,
      holomorphic=holomorphic,
      allow_int=allow_int,
    )

    fn_out = gradded_fn(*pure_args)

    def process_grads(grads):
      # Replace every NodeStates leaf by its ``state`` attribute.
      return jax.tree.map(
        lambda x: x.state if isinstance(x, extract.NodeStates) else x,
        grads,
        is_leaf=lambda x: isinstance(x, extract.NodeStates),
      )

    def process_out(pure_out: A, /) -> A:
      return extract.from_tree(pure_out, ctxtag='grad', is_inner=False)

    # ``_args_out`` is not used afterwards; ``process_out`` is still called
    # on the output arguments in every branch.
    if return_value:
      # unpack value_and_grad output
      if has_aux:
        (loss, (pure_args_out, pure_aux)), grads = fn_out
        grads = process_grads(grads)
        _args_out, aux = process_out((pure_args_out, pure_aux))
        return (loss, aux), grads
      else:
        (loss, pure_args_out), grads = fn_out
        grads = process_grads(grads)
        _args_out = process_out(pure_args_out)
        return loss, grads
    else:
      # unpack grad output
      if has_aux:
        grads, (pure_args_out, pure_aux) = fn_out
        grads = process_grads(grads)
        _args_out, aux = process_out((pure_args_out, pure_aux))
        return grads, aux
      else:
        grads, pure_args_out = fn_out
        grads = process_grads(grads)
        _args_out = process_out(pure_args_out)
        return grads

  return grad_wrapper


@tp.overload
def grad(
  f: tp.Callable[..., tp.Any],
  *,
  argnums: int | DiffState | tp.Sequence[int | DiffState] = 0,
  has_aux: bool = False,
  holomorphic: bool = False,
  allow_int: bool = False,
  reduce_axes: tp.Sequence[AxisName] = (),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> tp.Callable[..., tp.Any]: ...
@tp.overload
def grad(
  *,
  argnums: int | DiffState | tp.Sequence[int | DiffState] = 0,
  has_aux: bool = False,
  holomorphic: bool = False,
  allow_int: bool = False,
  reduce_axes: tp.Sequence[AxisName] = (),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]]: ...
def grad(
  f: tp.Callable[..., tp.Any] | Missing = MISSING,
  *,
  argnums: int | DiffState | tp.Sequence[int | DiffState] = 0,
  has_aux: bool = False,
  holomorphic: bool = False,
  allow_int: bool = False,
  reduce_axes: tp.Sequence[AxisName] = (),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> (
  tp.Callable[..., tp.Any]
  | tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]]
):
  """Object-aware version of ``jax.grad`` that can handle Modules / graph nodes
  as arguments.

  Example::

    >>> from flax import nnx
    >>> import jax.numpy as jnp
    ...
    >>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
    >>> x = jnp.ones((1, 2))
    >>> y = jnp.ones((1, 3))
    ...
    >>> loss_fn = lambda m, x, y: jnp.mean((m(x) - y) ** 2)
    >>> grad_fn = nnx.grad(loss_fn)
    ...
    >>> grads = grad_fn(m, x, y)
    >>> jax.tree.map(jnp.shape, grads)
    State({
      'bias': Param(
        value=(3,)
      ),
      'kernel': Param(
        value=(2, 3)
      )
    })

  By default, NNX objects are differentiated with respect to all their
  ``nnx.Param`` Variables. You can specify which substates are
  differentiable by passing a ``DiffState`` object to the ``argnums``
  argument. For example, if you want to differentiate only the ``kernel``
  attribute of the ``Linear`` class, you can use the ``PathContains``
  filter::

    >>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
    ...
    >>> kernel_attribute = nnx.PathContains('kernel')
    >>> diff_state = nnx.DiffState(0, kernel_attribute)
    ...
    >>> loss_fn = lambda m, x, y: jnp.mean((m(x) - y) ** 2)
    >>> grad_fn = nnx.grad(loss_fn, argnums=diff_state)
    ...
    >>> grads = grad_fn(m, x, y)
    >>> jax.tree.map(jnp.shape, grads)
    State({
      'kernel': Param(
        value=(2, 3)
      )
    })

  For more information on how to create custom filters, see
  `Using Filters <https://flax.readthedocs.io/en/latest/guides/filters_guide.html>`__
  guide.

  Args:
    f: Function to be differentiated. Its arguments at positions specified by
      ``argnums`` should be arrays, scalars, graph nodes or standard Python
      containers. Argument arrays in the positions specified by ``argnums``
      must be of inexact (i.e., floating-point or complex) type. It should
      return a scalar (which includes arrays with shape ``()`` but not arrays
      with shape ``(1,)`` etc.) When omitted, a decorator is returned.
    argnums: Optional, integer, ``DiffState`` or sequence of integers /
      ``DiffState`` objects. Specifies which positional argument(s) to
      differentiate with respect to (default 0).
    has_aux: Optional, bool. Indicates whether ``f`` returns a pair where the
      first element is considered the output of the mathematical function to
      be differentiated and the second element is auxiliary data. Default
      False.
    holomorphic: Optional, bool. Indicates whether ``f`` is promised to be
      holomorphic. If True, inputs and outputs must be complex. Default False.
    allow_int: Optional, bool. Whether to allow differentiating with respect
      to integer valued inputs. The gradient of an integer input will have a
      trivial vector-space dtype (float0). Default False.
    reduce_axes: Deprecated, do not use. Any non-empty value raises
      ``NotImplementedError``.
    graph: If ``True`` (default), uses graph-mode which supports the full NNX
      feature set including shared references and reference semantics. If
      ``False``, uses tree-mode which treats Modules as regular JAX pytrees,
      avoiding the overhead of the graph protocol. Tree-mode does not support
      ``DiffState`` or shared ``Variable`` references.
    graph_updates: If ``True``, propagates updates on graph structure that
      happen inside the transform to the input graphs, has no effect when
      ``graph=False``. When ``False``, using ``DiffState`` is not supported.

  Returns:
    A function with the same arguments as ``f`` that returns the gradients,
    or a ``(gradients, auxiliary_data)`` pair when ``has_aux`` is True. When
    ``f`` is omitted, a decorator with the given options bound is returned.

  Raises:
    NotImplementedError: if ``reduce_axes`` is non-empty.
  """
  # ``None`` means "use the currently configured default".
  if graph is None:
    graph = graphlib.set_graph_mode.current_value()
  if graph_updates is None:
    graph_updates = graphlib.set_graph_updates.current_value()

  if reduce_axes:
    raise NotImplementedError('reduce_axes argument to grad is deprecated')
  del reduce_axes

  # Decorator form: no function yet, bind the options and wait for it.
  if isinstance(f, Missing):
    return functools.partial(
      grad,
      argnums=argnums,
      has_aux=has_aux,
      holomorphic=holomorphic,
      allow_int=allow_int,
      graph=graph,
      graph_updates=graph_updates,
    )

  # Detect bound nnx.Module methods and raise error.
  f_unbound, _, was_bound = _resolve_bound_callable(f)
  if was_bound:
    _raise_bound_method_error('grad')

  return _grad_general(
    f_unbound,
    argnums,
    has_aux,
    holomorphic,
    allow_int,
    return_value=False,
    graph=graph,
    graph_updates=graph_updates,
  )
@tp.overload
def value_and_grad(
  f: tp.Callable[..., tp.Any],
  *,
  argnums: int | DiffState | tp.Sequence[int | DiffState] = 0,
  has_aux: bool = False,
  holomorphic: bool = False,
  allow_int: bool = False,
  reduce_axes: tp.Sequence[AxisName] = (),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> tp.Callable[..., tp.Any]: ...
@tp.overload
def value_and_grad(
  *,
  argnums: int | DiffState | tp.Sequence[int | DiffState] = 0,
  has_aux: bool = False,
  holomorphic: bool = False,
  allow_int: bool = False,
  reduce_axes: tp.Sequence[AxisName] = (),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]]: ...
def value_and_grad(
  f: tp.Callable[..., tp.Any] | type[Missing] | Missing = Missing,
  *,
  argnums: int | DiffState | tp.Sequence[int | DiffState] = 0,
  has_aux: bool = False,
  holomorphic: bool = False,
  allow_int: bool = False,
  reduce_axes: tp.Sequence[AxisName] = (),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> (
  tp.Callable[..., tp.Any]
  | tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]]
):
  """Object-aware version of ``jax.value_and_grad``.

  Like :func:`grad`, but returns both the value and the gradient of ``f``.

  Args:
    f: Function to be differentiated. Its arguments at positions specified by
      ``argnums`` should be arrays, scalars, graph nodes or standard Python
      containers. Argument arrays in the positions specified by ``argnums``
      must be of inexact (i.e., floating-point or complex) type. It should
      return a scalar (which includes arrays with shape ``()`` but not arrays
      with shape ``(1,)`` etc.) When omitted, a decorator is returned.
    argnums: Optional, integer or sequence of integers. Specifies which
      positional argument(s) to differentiate with respect to (default 0).
    has_aux: Optional, bool. Indicates whether ``f`` returns a pair where the
      first element is considered the output of the mathematical function to
      be differentiated and the second element is auxiliary data. Default
      False.
    holomorphic: Optional, bool. Indicates whether ``f`` is promised to be
      holomorphic. If True, inputs and outputs must be complex. Default False.
    allow_int: Optional, bool. Whether to allow differentiating with respect
      to integer valued inputs. The gradient of an integer input will have a
      trivial vector-space dtype (float0). Default False.
    reduce_axes: Deprecated, do not use. Any non-empty value raises
      ``NotImplementedError``.
    graph: If ``True`` (default), uses graph-mode which supports the full NNX
      feature set including shared references and reference semantics. If
      ``False``, uses tree-mode which treats Modules as regular JAX pytrees,
      avoiding the overhead of the graph protocol. Tree-mode does not support
      ``DiffState`` or shared ``Variable`` references.
    graph_updates: If ``True``, propagates updates on graph structure that
      happen inside the transform to the input graphs, has no effect when
      ``graph=False``. When ``False``, using ``DiffState`` is not supported.

  Returns:
    A function with the same arguments as ``f`` that evaluates both ``f`` and
    the gradient of ``f`` and returns them as a pair (a two-element tuple). If
    ``argnums`` is an integer then the gradient has the same shape and type as
    the positional argument indicated by that integer. If argnums is a
    sequence of integers, the gradient is a tuple of values with the same
    shapes and types as the corresponding arguments. If ``has_aux`` is True
    then a tuple of ((value, auxiliary_data), gradient) is returned.

  Raises:
    NotImplementedError: if ``reduce_axes`` is non-empty.
  """
  # ``None`` means "use the currently configured default".
  if graph is None:
    graph = graphlib.set_graph_mode.current_value()
  if graph_updates is None:
    graph_updates = graphlib.set_graph_updates.current_value()

  if reduce_axes:
    raise NotImplementedError(
      'reduce_axes argument to value_and_grad is deprecated')
  del reduce_axes

  # Decorator form. The default sentinel here is the ``Missing`` class, while
  # ``grad`` / ``vjp`` / ``jvp`` use the ``MISSING`` instance; accept both so
  # the four transforms treat "no function given" the same way.
  if f is Missing or isinstance(f, Missing):
    return functools.partial(
      value_and_grad,
      argnums=argnums,
      has_aux=has_aux,
      holomorphic=holomorphic,
      allow_int=allow_int,
      graph=graph,
      graph_updates=graph_updates,
    )

  # Detect bound nnx.Module methods and raise error.
  f_unbound, _, was_bound = _resolve_bound_callable(f)
  if was_bound:
    _raise_bound_method_error('value_and_grad')

  return _grad_general(
    f_unbound,
    argnums,
    has_aux,
    holomorphic,
    allow_int,
    return_value=True,
    graph=graph,
    graph_updates=graph_updates,
  )
# -----------------------------------------------
# vjp
# -----------------------------------------------


@dataclasses.dataclass(eq=False)
class SimpleVjpFn:
  """Callable given to ``jax.vjp`` by :func:`vjp`.

  Calls ``f`` and always returns a ``(primals_out, aux)`` pair whose
  auxiliary slot carries the masked Variable updates (plus the user's aux
  when ``has_aux`` is set).
  """

  f: tp.Callable[..., tp.Any]
  has_aux: bool
  graph: bool

  def __post_init__(self):
    functools.update_wrapper(self, self.f, updated=())

  @extract.treemap_copy_args
  def __call__(self, *args):
    # Taken from the arguments as received, before any graph conversion.
    updates, snapshot = extract.updates_and_snapshot(args)

    call_args = extract.from_tree2(args) if self.graph else args
    result = self.f(*call_args)
    if self.graph:
      result = extract.to_tree2(result)

    extract.check_no_aliases('vjp', args=updates, out=result)
    masked_updates = extract.mask_variable_updates(updates, snapshot)

    if not self.has_aux:
      return result, masked_updates
    primals_out, aux = result
    return primals_out, (masked_updates, aux)


@tp.overload
def vjp(
  f: tp.Callable[..., tp.Any],
  *primals: tp.Any,
  has_aux: bool = False,
  reduce_axes: tp.Sequence[AxisName] = (),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> tuple[tp.Any, tp.Callable] | tuple[tp.Any, tp.Callable, tp.Any]: ...
@tp.overload
def vjp(
  *,
  has_aux: bool = False,
  reduce_axes: tp.Sequence[AxisName] = (),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]]: ...
def vjp(
  f: tp.Callable[..., tp.Any] | Missing = MISSING,
  *primals: tp.Any,
  has_aux: bool = False,
  reduce_axes: tp.Sequence[AxisName] = (),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> (
  tuple[tp.Any, tp.Callable]
  | tuple[tp.Any, tp.Callable, tp.Any]
  | tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]]
):
  """Stateful version of ``jax.vjp`` that propagates NNX Variable updates.

  Example::

    >>> from flax import nnx
    >>> import jax.numpy as jnp
    ...
    >>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
    >>> x = jnp.ones((1, 2))
    ...
    >>> def loss_fn(m, x):
    ...   return jnp.sum(m(x))
    ...
    >>> primals_out, vjp_fn = nnx.vjp(loss_fn, m, x, graph=False)
    >>> m_grad, x_grad = vjp_fn(jnp.ones_like(primals_out))

  Can also be used as a decorator::

    >>> @nnx.vjp(graph=False)
    ... def f(m, x):
    ...   return jnp.sum(m(x))
    ...
    >>> primals_out, vjp_fn = f(m, x)

  Args:
    f: Function to be differentiated. Its arguments can be arrays, scalars,
      or pytrees containing arrays and NNX Variables.
    *primals: A sequence of primal values at which the Jacobian of ``f``
      should be evaluated.
    has_aux: Optional, bool. Indicates whether ``f`` returns a pair where the
      first element is considered the output of the mathematical function to
      be differentiated and the second element is auxiliary data. Default
      False.
    reduce_axes: Deprecated, do not use.
    graph: If ``True`` (default), uses graph-mode which supports the full NNX
      feature set including shared references and reference semantics. If
      ``False``, uses tree-mode which treats Modules as regular JAX pytrees,
      avoiding the overhead of the graph protocol.
    graph_updates: If ``True``, propagates updates on graph structure that
      happen inside the transform to the input graphs, has no effect when
      ``graph=False``.

  Returns:
    If ``has_aux`` is False, returns a ``(primals_out, vjp_fn)`` pair.
    ``vjp_fn`` takes a cotangent with the same structure as ``primals_out``
    and returns gradients for each primal argument. If ``has_aux`` is True,
    returns ``(primals_out, vjp_fn, aux)``.
  """
  # ``None`` means "use the currently configured default".
  graph = graphlib.set_graph_mode.current_value() if graph is None else graph
  graph_updates = (
    graphlib.set_graph_updates.current_value()
    if graph_updates is None
    else graph_updates
  )

  if graph and graph_updates:
    raise NotImplementedError(
      'graph-mode with graph_updates is not supported for nnx.vjp. '
      'Set graph=False or graph_updates=False.'
    )
  if reduce_axes:
    raise NotImplementedError('reduce_axes argument to vjp is deprecated')
  del reduce_axes

  # Options re-bound whenever a partially applied ``vjp`` is returned.
  bound_options = dict(
    has_aux=has_aux, graph=graph, graph_updates=graph_updates
  )

  # Decorator form: no function yet.
  if isinstance(f, Missing):
    return functools.partial(vjp, **bound_options)  # type: ignore[return-value]

  f_unbound, _, was_bound = _resolve_bound_callable(f)
  if was_bound:
    _raise_bound_method_error('vjp')

  # Function given without primals: wait for the primals.
  if not primals:
    return functools.partial(vjp, f, **bound_options)  # type: ignore[return-value]

  if graph:
    primals = extract.to_tree2(primals)
  extract.check_no_aliases('vjp', primals=primals)

  # SimpleVjpFn always produces an aux slot, hence ``has_aux=True``.
  wrapped_fn = SimpleVjpFn(f_unbound, has_aux=has_aux, graph=graph)
  primals_out, vjp_fn, aux = jax.vjp(wrapped_fn, *primals, has_aux=True)
  updates, user_aux = aux if has_aux else (aux, None)

  if graph:
    primals_out = extract.from_tree2(primals_out)
    tree_vjp_fn = vjp_fn

    def vjp_fn(g):
      return extract.from_tree2(tree_vjp_fn(g))

  extract.apply_variable_updates(primals, updates)

  if has_aux:
    # NOTE(review): ``user_aux`` is returned without ``from_tree2`` in graph
    # mode, unlike ``tree_grad_wrapper`` in ``_grad_general`` -- confirm this
    # is intended.
    return primals_out, vjp_fn, user_aux
  return primals_out, vjp_fn
# -----------------------------------------------
# jvp
# -----------------------------------------------


@dataclasses.dataclass(eq=False)
class SimpleJvpFn:
  """Callable given to ``jax.jvp`` by :func:`jvp`.

  Calls ``f`` and pairs its primal output with the masked Variable updates.
  With ``has_aux`` the result is ``((primals_out, updates), aux)``, otherwise
  ``(out, updates)``.
  """

  f: tp.Callable[..., tp.Any]
  has_aux: bool
  graph: bool

  def __post_init__(self):
    functools.update_wrapper(self, self.f, updated=())

  @extract.treemap_copy_args
  def __call__(self, *args):
    # Taken from the arguments as received, before any graph conversion.
    updates, snapshot = extract.updates_and_snapshot(args)

    call_args = extract.from_tree2(args) if self.graph else args
    result = self.f(*call_args)
    if self.graph:
      result = extract.to_tree2(result)

    extract.check_no_aliases('jvp', args=updates, out=result)
    masked_updates = extract.mask_variable_updates(updates, snapshot)

    if not self.has_aux:
      return result, masked_updates
    primals_out, aux = result
    return (primals_out, masked_updates), aux


@tp.overload
def jvp(
  f: tp.Callable[..., tp.Any],
  primals: tuple[tp.Any, ...],
  tangents: tuple[tp.Any, ...],
  *,
  has_aux: bool = False,
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> tuple[tp.Any, ...]: ...
@tp.overload
def jvp(
  *,
  has_aux: bool = False,
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]]: ...
@tp.overload
def jvp(
  f: tp.Callable[..., tp.Any],
  *,
  has_aux: bool = False,
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> tp.Callable[..., tp.Any]: ...
def jvp(
  f: tp.Callable[..., tp.Any] | Missing = MISSING,
  primals: tuple[tp.Any, ...] | Missing = MISSING,
  tangents: tuple[tp.Any, ...] | Missing = MISSING,
  *,
  has_aux: bool = False,
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> (
  tuple[tp.Any, ...]
  | tp.Callable[..., tp.Any]
  | tp.Callable[[tp.Callable[..., tp.Any]], tp.Callable[..., tp.Any]]
):
  """Stateful version of ``jax.jvp`` that propagates NNX Variable updates.

  Example::

    >>> from flax import nnx
    >>> import jax.numpy as jnp
    ...
    >>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
    >>> x = jnp.ones((1, 2))
    ...
    >>> def f(m, x):
    ...   return jnp.sum(m(x))
    ...
    >>> m_tangent = jax.tree.map(jnp.zeros_like, m)
    >>> x_tangent = jnp.ones_like(x)
    >>> primals_out, tangent_out = nnx.jvp(
    ...   f, (m, x), (m_tangent, x_tangent), graph=False
    ... )

  Can also be used as a decorator::

    >>> @nnx.jvp(graph=False)
    ... def f(m, x):
    ...   return jnp.sum(m(x))
    ...
    >>> primals_out, tangent_out = f((m, x), (m_tangent, x_tangent))

  Args:
    f: Function to be differentiated. Its arguments can be arrays, scalars,
      or pytrees containing arrays and NNX Variables.
    primals: A tuple of primal values at which the Jacobian of ``f`` should
      be evaluated.
    tangents: A tuple of tangent vectors, with the same structure as
      ``primals``.
    has_aux: Optional, bool. Indicates whether ``f`` returns a pair where the
      first element is considered the output of the mathematical function to
      be differentiated and the second element is auxiliary data. Default
      False.
    graph: If ``True`` (default), uses graph-mode which supports the full NNX
      feature set including shared references and reference semantics. If
      ``False``, uses tree-mode which treats Modules as regular JAX pytrees,
      avoiding the overhead of the graph protocol.
    graph_updates: If ``True``, propagates updates on graph structure that
      happen inside the transform to the input graphs, has no effect when
      ``graph=False``.

  Returns:
    If ``has_aux`` is False, returns ``(primals_out, tangent_out)``. If
    ``has_aux`` is True, returns ``(primals_out, tangent_out, aux)``.
  """
  # ``None`` means "use the currently configured default".
  graph = graphlib.set_graph_mode.current_value() if graph is None else graph
  graph_updates = (
    graphlib.set_graph_updates.current_value()
    if graph_updates is None
    else graph_updates
  )

  if graph and graph_updates:
    raise NotImplementedError(
      'graph-mode with graph_updates is not supported for nnx.jvp. '
      'Set graph=False or graph_updates=False.'
    )

  # Options re-bound whenever a partially applied ``jvp`` is returned.
  bound_options = dict(
    has_aux=has_aux, graph=graph, graph_updates=graph_updates
  )

  # Decorator form: no function yet.
  if isinstance(f, Missing):
    return functools.partial(jvp, **bound_options)

  f_unbound, _, was_bound = _resolve_bound_callable(f)
  if was_bound:
    _raise_bound_method_error('jvp')

  # Function given without both primals and tangents: wait for them.
  if isinstance(primals, Missing) or isinstance(tangents, Missing):
    return functools.partial(jvp, f, **bound_options)

  if graph:
    primals = extract.to_tree2(primals)
    tangents = extract.to_tree2(tangents)
  extract.check_no_aliases('jvp', primals=primals)
  extract.check_no_aliases('jvp', tangents=tangents)

  # The tangent of the Variable updates is computed by ``jax.jvp`` but not
  # used afterwards.
  if has_aux:
    wrapped_fn = SimpleJvpFn(f_unbound, has_aux=True, graph=graph)
    (primals_out, updates), (tangent_out, _updates_tangent), aux = jax.jvp(
      wrapped_fn, primals, tangents, has_aux=True
    )
  else:
    wrapped_fn = SimpleJvpFn(f_unbound, has_aux=False, graph=graph)
    (primals_out, updates), (tangent_out, _updates_tangent) = jax.jvp(
      wrapped_fn, primals, tangents
    )

  if graph:
    primals_out = extract.from_tree2(primals_out)
    tangent_out = extract.from_tree2(tangent_out)

  extract.apply_variable_updates(primals, updates)

  if has_aux:
    # NOTE(review): ``aux`` is returned without ``from_tree2`` in graph mode,
    # unlike ``tree_grad_wrapper`` in ``_grad_general`` -- confirm this is
    # intended.
    return primals_out, tangent_out, aux
  return primals_out, tangent_out
# ----------------------------------------------- # custom_vjp # ----------------------------------------------- @dataclasses.dataclass(eq=False) class SimpleCustomVjpFn: f: tp.Callable[..., tp.Any] graph: bool def __post_init__(self): functools.update_wrapper(self, self.f, updated=()) @extract.treemap_copy_args def __call__(self, *args): updates, snapshot = extract.updates_and_snapshot(args) if self.graph: args = extract.from_tree2(args) out = self.f(*args) if self.graph: out = extract.to_tree2(out) extract.check_no_aliases('custom_vjp', args=updates, out=out) updates = extract.mask_variable_updates(updates, snapshot) return out, updates @dataclasses.dataclass(eq=False) class SimpleFwdFn: fwd: tp.Callable[..., tp.Any] graph: bool def __post_init__(self): functools.update_wrapper(self, self.fwd, updated=()) @extract.treemap_copy_args def __call__(self, *args): updates, snapshot = extract.updates_and_snapshot(args) if self.graph: args = extract.from_tree2(args) out, residual = self.fwd(*args) if self.graph: out = extract.to_tree2(out) residual = extract.to_tree2(residual) extract.check_no_aliases('custom_vjp', args=updates, out=out) updates = extract.mask_variable_updates(updates, snapshot) return (out, updates), residual @dataclasses.dataclass(eq=False) class SimpleBwdFn: bwd: tp.Callable[..., tp.Any] graph: bool def __post_init__(self): functools.update_wrapper(self, self.bwd, updated=()) @extract.treemap_copy_args def __call__(self, *args): *nondiff, residual, (out_g, _updates_g) = args if self.graph: nondiff = extract.from_tree2(nondiff) residual = extract.from_tree2(residual) result = self.bwd(*nondiff, residual, out_g) if self.graph: result = extract.to_tree2(result) return result class SimpleCustomVjp(tp.Generic[A]): def __init__( self, fun: tp.Callable[..., A], nondiff_argnums: tuple[int, ...], graph: bool, ): functools.update_wrapper(self, fun) self.fun = fun self.nondiff_argnums = nondiff_argnums self.graph = graph self.custom_vjp_fn = jax.custom_vjp( 
fun=SimpleCustomVjpFn(fun, graph=graph), nondiff_argnums=nondiff_argnums, ) def __call__( self, *args: tp.Any, **kwargs: tp.Any ) -> A: args = resolve_kwargs(self.fun, args, kwargs) del kwargs if self.graph: prefix = tuple( i not in self.nondiff_argnums for i in range(len(args)) ) args = extract.to_tree2(args, prefix=prefix) extract.check_no_aliases('custom_vjp', args=args) (out, updates) = self.custom_vjp_fn(*args) # check that differentiable arguments were not mutated diff_argnums = tuple( i for i in range(len(args)) if i not in self.nondiff_argnums ) is_var = lambda x: isinstance(x, variablelib.Variable) for i in diff_argnums: changed = [ jax.tree_util.keystr(path) for path, leaf in jax.tree.leaves_with_path( updates[i], is_leaf=is_var ) if leaf is not None ] if changed: paths_str = '\n '.join(changed) raise ValueError( f'Variables in differentiable argument {i} were mutated inside ' f'custom_vjp at:\n\n {paths_str}\n\nThis is not supported when ' f'graph_updates=False because the gradient for the Variable ' f'updates would be silently dropped. Move the Variable mutation ' f'to a non-differentiable argument, or use graph_updates=True.' ) if self.graph: out = extract.from_tree2(out) extract.apply_variable_updates(args, updates) return out def defvjp( self, fwd: tp.Callable[..., tuple[A, tp.Any]], bwd: tp.Callable[..., tuple[tp.Any, ...]], symbolic_zeros: bool = False, ) -> None: self.fwd = fwd self.bwd = bwd self.symbolic_zeros = symbolic_zeros self.custom_vjp_fn.defvjp( fwd=SimpleFwdFn(fwd, graph=self.graph), bwd=SimpleBwdFn(bwd, graph=self.graph), symbolic_zeros=symbolic_zeros, ) # custom_vjp is one of the most complicated transforms as it requires # to handle 4 different functions: # 1. CustomVJP: the main object that runs the outer logic, converts input graph nodes # to pytrees and output pytrees to graph nodes. # 2. CustomVjpFnWrapper: function that wraps the user's function, it converts # its input pytrees to graph nodes and output graph nodes to pytrees. 
# 3. FwdFn: wraps the user's fwd function, it converts its input pytrees to graph nodes
#    and output graph nodes to pytrees. Since it might run by itself in a separate context,
#    it needs to be aware if the update_context is active or not in order to update the outer
#    references.
# 4. BwdFn: wraps the user's bwd function, it converts its input pytrees to graph nodes
#    and output graph nodes to pytrees. It doesn't need to be aware of the outer context
#    since it will never update the outer references as it runs during the backward pass.


def _custom_vjp_merge_fn(
  ctx: graphlib.MergeContext,
  path,
  prefix: bool | DiffState,
  value: extract.NodeStates,
  *,
  nondiff_states: deque[extract.GraphDefState],
):
  """Merge function used with ``extract.from_tree`` inside the transform.

  Pops the next ``GraphDefState`` from ``nondiff_states`` (FIFO, i.e. the same
  order in which ``_custom_vjp_split_fn`` appended them) and merges its
  graphdef and non-differentiable state with the state carried by ``value``.

  Args:
    ctx: merge context used to rebuild the graph node.
    path: unused.
    prefix: unused.
    value: node states that were passed through the transform.
    nondiff_states: queue of graphdefs / non-differentiable states; one entry
      is consumed per call.

  Returns:
    The result of ``ctx.merge``.
  """
  nondiff = nondiff_states.popleft()
  return ctx.merge(nondiff.graphdef, value.state, nondiff.state)


def _custom_vjp_split_fn(
  ctx: graphlib.SplitContext,
  path,
  prefix: bool | DiffState,
  value,
  *,
  nondiff_states: list[extract.GraphDefState],
):
  """Split function used with ``extract.to_tree`` in ``CustomVjp.__call__``.

  Splits the graph node ``value`` into the state that is passed through the
  transform and a ``GraphDefState`` (graphdef + non-differentiable state) that
  is appended to ``nondiff_states``.

  Args:
    ctx: split context used to split the graph node.
    path: path of ``value``, only used in the error message.
    prefix: ``True`` for a fully differentiable argument, a ``DiffState`` whose
      filter selects the differentiable substate, or ``False`` for an argument
      marked with an integer in ``nondiff_argnums``.
    value: the graph node being split.
    nondiff_states: output list, one ``GraphDefState`` is appended per call.

  Returns:
    ``extract.NodeStates`` holding only the passed-through state.

  Raises:
    TypeError: if ``prefix`` is ``False``.
  """
  broadcast: graphlib.GraphState
  if prefix is False:
    # pure non-differentiable arg, not supported
    raise TypeError(
      'Passing integers to nondiff_argnums for graph nodes arguments in custom_vjp is not supported. '
      f'Got {prefix} at path {jax.tree_util.keystr(path)} for value {value}'
    )
  elif prefix is True:
    # pure differentiable arg, we pass all the state through but return a
    # NodeStates built with from_states, which doesn't carry the graphdef,
    # in order to keep the gradients clean from any metadata
    graphdef, passed = ctx.split(value)
    broadcast = State({})
    nondiff_states.append(extract.GraphDefState(graphdef, broadcast))
    return extract.NodeStates.from_states(passed)
  else:
    # differentiable arg with DiffState filter: the filter selects the state
    # that is passed through; as above only the states are returned. The
    # remaining (non-differentiable) state is stored next to the graphdef in
    # nondiff_states and merged back by _custom_vjp_merge_fn.
    graphdef, passed, broadcast = ctx.split(value, prefix.filter, ...)  # type: ignore[misc]
    nondiff_states.append(extract.GraphDefState(graphdef, broadcast))
    return extract.NodeStates.from_states(passed)


# NOTE(review): these two module-level annotated assignments look like leftover
# fields of a removed struct class; nothing in this part of the file reads
# them. Confirm there are no external users before deleting.
nondiff_argnums: tuple[int, ...] = struct.field(pytree_node=False)
tangent_tree_node_args: tuple[tp.Any, ...] = struct.field(pytree_node=False)


def _extract_nodedefs(x, *, nodedefs: deque[graphlib.GraphDef]):
  """Record ``GraphDef`` leaves in ``nodedefs`` and strip their outer index.

  Args:
    x: any leaf; only ``graphlib.GraphDef`` instances are processed.
    nodedefs: queue that receives the original (unstripped) ``GraphDef``.

  Returns:
    ``x.with_no_outer_index()`` for a ``GraphDef``, otherwise ``x`` unchanged.
  """
  if isinstance(x, graphlib.GraphDef):
    nodedefs.append(x)
    return x.with_no_outer_index()
  return x


@dataclasses.dataclass(eq=False)
class CustomVjpFnWrapper:
  """Wraps the user's primal function so ``jax.custom_vjp`` sees pytrees only."""

  # user function, called with merged graph nodes
  f: tp.Callable[..., tp.Any]
  # integer argnums that are also given to jax.custom_vjp as nondiff_argnums
  jax_nondiff_argnums: tuple[int, ...]
  # update-context tag shared with CustomVjp.__call__
  ctxtag: str
  # graphdefs / non-differentiable states recorded by _custom_vjp_split_fn
  nondiff_states: list[extract.GraphDefState]
  # receives the GraphDefs removed from the outputs by _extract_nodedefs
  nodedefs: deque[graphlib.GraphDef]

  def __post_init__(self):
    functools.update_wrapper(self, self.f)

  def __call__(self, *pure_args):
    """Runs ``f`` on merged inputs, returns ``(pure_args_out, pure_out)``."""
    # fresh copy per call: _custom_vjp_merge_fn consumes the queue
    nondiff_states = deque(self.nondiff_states)
    args = extract.from_tree(
      pure_args,
      merge_fn=functools.partial(
        _custom_vjp_merge_fn, nondiff_states=nondiff_states
      ),
      ctxtag=self.ctxtag,
      is_inner=True,
    )

    out = self.f(*args)

    # remove nondiff from pure_args_out_g
    args_out = tuple(
      x for i, x in enumerate(args) if i not in self.jax_nondiff_argnums
    )
    args_out = extract.clear_non_graph_nodes(args_out)
    pure_args_out, pure_out = extract.to_tree(
      (args_out, out), ctxtag=self.ctxtag
    )
    # remove outer_index from GraphDef's but store them in global context
    pure_args_out, pure_out = jax.tree.map(
      functools.partial(_extract_nodedefs, nodedefs=self.nodedefs),
      (pure_args_out, pure_out),
      is_leaf=lambda x: isinstance(x, graphlib.GraphDef),
    )

    return pure_args_out, pure_out


@dataclasses.dataclass(eq=False)
class FwdFn:
  """Wraps the user's ``fwd`` function so ``jax.custom_vjp`` sees pytrees only."""

  # user fwd function, must return ``(out, residual)``
  fwd: tp.Callable[..., tp.Any]
  # integer argnums excluded from the returned input updates
  nondiff_argnums: tuple[int, ...]
  # update-context tag shared with CustomVjp.__call__
  ctxtag: str
  # graphdefs / non-differentiable states recorded by _custom_vjp_split_fn
  nondiff_states: list[extract.GraphDefState]
  # receives the GraphDefs removed from the outputs by _extract_nodedefs
  nodedefs: deque[graphlib.GraphDef]

  def __post_init__(self):
    functools.update_wrapper(self, self.fwd)

  def __call__(self, *pure_args):
    """Runs ``fwd``, returns ``((pure_args_out, pure_out), pure_residual)``."""
    # here we need to be aware if the update_context is active or not:
    # when it is not active no ctxtag is passed to from_tree / to_tree,
    # when it is active the GraphDef's of the outputs go through
    # _extract_nodedefs and are stored in the nodedefs deque created by
    # CustomVjp
    update_context_active = (
      self.ctxtag in graphlib.GRAPH_CONTEXT.update_context_stacks
    )
    # fresh copy per call: _custom_vjp_merge_fn consumes the queue
    nondiff_states = deque(self.nondiff_states)
    args = extract.from_tree(
      pure_args,
      merge_fn=functools.partial(
        _custom_vjp_merge_fn, nondiff_states=nondiff_states
      ),
      ctxtag=self.ctxtag if update_context_active else None,
      is_inner=True,
    )

    out, residual = self.fwd(*args)

    # remove nondiff from pure_args_out_g
    args_out = tuple(
      x for i, x in enumerate(args) if i not in self.nondiff_argnums
    )
    args_out = extract.clear_non_graph_nodes(args_out)
    pure_args_out, pure_out = extract.to_tree(
      (args_out, out),
      ctxtag=self.ctxtag if update_context_active else None,
    )
    pure_residual = extract.to_tree(residual)

    if update_context_active:
      # remove outer_index from GraphDef's but store them in global context
      pure_args_out, pure_out = jax.tree.map(
        functools.partial(_extract_nodedefs, nodedefs=self.nodedefs),
        (pure_args_out, pure_out),
        is_leaf=lambda x: isinstance(x, graphlib.GraphDef),
      )

    return (pure_args_out, pure_out), pure_residual


@dataclasses.dataclass(eq=False)
class BwdFn:
  """Wraps the user's ``bwd`` function so ``jax.custom_vjp`` sees pytrees only."""

  # user bwd function, called as ``bwd(*nondiff, residual, (args_out_g, out_g))``
  bwd: tp.Callable[..., tp.Any]
  # per differentiable argument: pytree of bools, True where the input leaf
  # was an extract.NodeStates (computed in CustomVjp.__call__)
  tree_node_args: tuple[tp.Any, ...]

  def __post_init__(self):
    functools.update_wrapper(self, self.bwd)

  def __call__(self, *args):
    """Runs ``bwd`` and converts the returned tangent back to ``NodeStates``.

    Raises:
      ValueError: if the tangent for a graph-node input is neither a
        ``jax.Array``, a ``State`` nor a ``Variable``.
    """
    *nondiff, pure_residual, (pure_args_out_g, pure_out_g) = args
    residual = extract.from_tree(pure_residual, is_inner=True)
    # the user's bwd receives plain states instead of NodeStates wrappers
    (pure_args_out_g, pure_out_g) = jax.tree.map(
      lambda x: x.state if isinstance(x, extract.NodeStates) else x,
      (pure_args_out_g, pure_out_g),
      is_leaf=lambda x: isinstance(x, extract.NodeStates),
    )

    tangent = self.bwd(*nondiff, residual, (pure_args_out_g, pure_out_g))

    def state_to_node_states(is_differentiable: bool, x):
      if is_differentiable:
        if isinstance(x, jax.Array):
          return x
        elif not isinstance(x, State | variablelib.Variable):
          raise ValueError(f'Expected State or Variable, got {type(x)}')
        return extract.NodeStates.from_states(x)
      return x

    # tree_node_args is the structural prefix: wherever the input was a
    # NodeStates the corresponding tangent is wrapped again
    pure_tangent = jax.tree.map(
      state_to_node_states,
      self.tree_node_args,
      tangent,
      is_leaf=lambda x: isinstance(x, State | variablelib.Variable),
    )
    return pure_tangent


class CustomVjp(tp.Generic[A]):
  """Graph-aware ``custom_vjp`` callable returned by ``custom_vjp``.

  Holds the user's primal function plus the ``fwd`` / ``bwd`` rules registered
  through ``defvjp`` and builds a ``jax.custom_vjp`` on every call.
  """

  def __init__(
    self,
    fun: tp.Callable[..., A],
    nondiff_argnums: tuple[int | DiffState, ...],
  ):
    """Stores ``fun`` and pre-computes the per-argument filters.

    Args:
      fun: the primal function.
      nondiff_argnums: integers and/or ``DiffState`` objects.

    Raises:
      ValueError: if the same argument index appears more than once in
        ``nondiff_argnums``.
    """
    functools.update_wrapper(self, fun)
    # only plain integers are forwarded to jax.custom_vjp as nondiff_argnums,
    # DiffState entries are handled by _custom_vjp_split_fn instead
    self.jax_nondiff_argnums = tuple(
      x for x in nondiff_argnums if isinstance(x, int)
    )
    # callables such as functools.partial have no __name__, fall back to the
    # type name so that building the tag never fails
    fun_name = getattr(fun, '__name__', type(fun).__name__)
    self.ctxtag = f'custom_vjp_{fun_name}_{id(fun)}'
    self.fun = fun
    self.fwd: tp.Callable | None = None
    self.bwd: tp.Callable | None = None
    self.symbolic_zeros: bool | None = None
    self.nondiff_argnums = nondiff_argnums
    # argument index -> False (integer entry) or the DiffState (argnum reset
    # to -1); indices that are absent are treated as True in __call__
    self.diff_filter: dict[int, tp.Literal[False] | DiffState] = {}
    for argnum in self.nondiff_argnums:
      index = argnum.argnum if isinstance(argnum, DiffState) else argnum
      if index in self.diff_filter:
        raise ValueError(f'argnum {index} is repeated in nondiff_argnums')
      self.diff_filter[index] = (
        dataclasses.replace(argnum, argnum=-1)
        if isinstance(argnum, DiffState)
        else False
      )

  # def __getattr__(self, name: str) -> tp.Any:
  #   if not hasattr(self.custom_vjp_fn, name):
  #     raise AttributeError(f'{type(self).__name__} has no attribute {name}')
  #   return getattr(self.custom_vjp_fn, name)

  def __call__(
    self, *args: tp.Any, **kwargs: tp.Any
  ) -> A:  # pytype: disable=invalid-annotation
    """Calls the function through ``jax.custom_vjp``.

    Raises:
      ValueError: if ``defvjp`` has not been called yet.
    """
    fwd, bwd, symbolic_zeros = self.fwd, self.bwd, self.symbolic_zeros
    # fail fast, before any graph node is split, and say what is missing
    if fwd is None or bwd is None or symbolic_zeros is None:
      fun_name = getattr(self.fun, '__name__', repr(self.fun))
      raise ValueError(
        f'No VJP defined for custom_vjp function {fun_name}. '
        'Call `defvjp(fwd, bwd)` before calling the function.'
      )

    with graphlib.update_context(self.ctxtag):
      args = resolve_kwargs(self.fun, args, kwargs)
      del kwargs
      nondiff_states: list[extract.GraphDefState] = []
      arg_filters = tuple(
        self.diff_filter.get(i, True) for i in range(len(args))
      )
      pure_args = extract.to_tree(
        args,
        prefix=arg_filters,
        split_fn=functools.partial(
          _custom_vjp_split_fn, nondiff_states=nondiff_states
        ),
        ctxtag=self.ctxtag,
      )
      # True wherever an input leaf is a NodeStates, used by BwdFn to know
      # which tangents have to be wrapped again
      tree_node_args = jax.tree.map(
        lambda x: isinstance(x, extract.NodeStates),
        pure_args,
        is_leaf=lambda x: isinstance(x, extract.NodeStates),
      )
      tree_node_args = tuple(
        x
        for i, x in enumerate(tree_node_args)
        if i not in self.jax_nondiff_argnums
      )
      nodedefs: deque[graphlib.GraphDef] = deque()

      custom_vjp_fn = jax.custom_vjp(
        fun=CustomVjpFnWrapper(
          f=self.fun,
          jax_nondiff_argnums=self.jax_nondiff_argnums,
          ctxtag=self.ctxtag,
          nondiff_states=nondiff_states,
          nodedefs=nodedefs,
        ),
        nondiff_argnums=self.jax_nondiff_argnums,
      )
      custom_vjp_fn.defvjp(
        fwd=FwdFn(
          fwd=fwd,
          nondiff_argnums=self.jax_nondiff_argnums,
          ctxtag=self.ctxtag,
          nondiff_states=nondiff_states,
          nodedefs=nodedefs,
        ),
        bwd=BwdFn(
          bwd=bwd,
          tree_node_args=tree_node_args,
        ),
        symbolic_zeros=symbolic_zeros,
      )
      pure_args_out, pure_out = custom_vjp_fn(*pure_args)

      # put back, in FIFO order, the original GraphDef's that
      # _extract_nodedefs recorded while stripping their outer index
      def _insert_index_mappings(x):
        if isinstance(x, graphlib.GraphDef):
          nodedef = nodedefs.popleft()
          return nodedef
        return x

      pure_args_out, pure_out = jax.tree_util.tree_map(
        _insert_index_mappings,
        (pure_args_out, pure_out),
        is_leaf=lambda x: isinstance(x, graphlib.GraphDef),
      )

      # args_out is not returned; from_tree is still called on both parts
      args_out, out = extract.from_tree(
        (pure_args_out, pure_out), ctxtag=self.ctxtag, is_inner=False
      )

      return out

  def defvjp(
    self,
    fwd: tp.Callable[..., tuple[A, tp.Any]],
    bwd: tp.Callable[..., tuple[tp.Any, ...]],
    symbolic_zeros: bool = False,
  ) -> None:
    """Registers the forward and backward rules used by ``__call__``.

    Args:
      fwd: forward rule returning ``(out, residual)``.
      bwd: backward rule returning the tangents.
      symbolic_zeros: forwarded to ``jax.custom_vjp(...).defvjp``.
    """
    self.fwd = fwd
    self.bwd = bwd
    self.symbolic_zeros = symbolic_zeros


@tp.overload
def custom_vjp(
  fun: tp.Callable[..., A],
  *,
  nondiff_argnums: tuple[int | DiffState, ...] = (),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> CustomVjp[A] | SimpleCustomVjp[A]: ...
@tp.overload
def custom_vjp(
  *,
  nondiff_argnums: tuple[int | DiffState, ...] = (),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> tp.Callable[[tp.Callable[..., A]], CustomVjp[A] | SimpleCustomVjp[A]]: ...
def custom_vjp(
  fun: tp.Callable[..., A] | Missing = MISSING,
  *,
  nondiff_argnums: tuple[int | DiffState, ...] = (),
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> (
  CustomVjp[A]
  | SimpleCustomVjp[A]
  | tp.Callable[[tp.Callable[..., A]], CustomVjp[A] | SimpleCustomVjp[A]]
):
  """Reference aware version of
  `jax.custom_vjp <https://jax.readthedocs.io/en/latest/_autosummary/jax.custom_vjp.html>`__.

  ``nnx.custom_vjp`` accepts Modules and other Flax NNX objects as arguments. The main difference
  with the JAX version is that, because Modules follow reference semantics, they propagate the State
  updates for the inputs as auxiliary outputs. This means that the incoming gradients in the ``bwd`` function
  will have the form ``(input_updates_g, out_g)`` where ``input_updates_g`` is the gradient updated state of
  the inputs w.r.t. the inputs. All Module terms on the inputs will have an associated ``State`` term in
  ``input_updates_g``, while all non-Module terms will appear as None. The shape of the tangent will be
  expected to have the same shape as the input, with ``State`` terms in place of the corresponding Module terms.

  Example::

    >>> import jax
    >>> import jax.numpy as jnp
    >>> from flax import nnx
    ...
    >>> class Foo(nnx.Module):
    ...   def __init__(self, x, y):
    ...     self.x = nnx.Param(x)
    ...     self.y = nnx.Param(y)
    ...
    >>> @nnx.custom_vjp
    ... def f(m: Foo):
    ...   return jnp.sin(m.x) * m.y
    ...
    >>> def f_fwd(m: Foo):
    ...   return f(m), (jnp.cos(m.x), jnp.sin(m.x), m)
    ...
    >>> def f_bwd(res, g):
    ...   input_updates_g, out_g = g
    ...   cos_x, sin_x, m = res
    ...   (m_updates_g,) = input_updates_g
    ...   m_g = jax.tree.map(lambda x: x, m_updates_g) # create copy
    ...
    ...   m_g['x'][...] = cos_x * out_g * m.y
    ...   m_g['y'][...] = sin_x * out_g
    ...   return (m_g,)
    ...
    >>> f.defvjp(f_fwd, f_bwd)
    ...
    >>> m = Foo(x=jnp.array(1.), y=jnp.array(2.))
    >>> grads = nnx.grad(f)(m)
    ...
    >>> jax.tree.map(jnp.shape, grads)
    State({
      'x': Param(
        value=()
      ),
      'y': Param(
        value=()
      )
    })

  Note that the State objects that represent Module terms on ``input_updates_g`` have the
  same shape as the State objects expected in the output tangent. This means that you can
  usually just copy them from ``input_updates_g`` and update them with their corresponding
  gradient values.

  You can select which substates are differentiable (have a tangent) for Modules and other
  graph nodes by passing a ``DiffState`` to ``nondiff_argnums``. For example, if you want to
  differentiate only the ``x`` attribute of the ``Foo`` class, you can do the following::

    >>> x_attribute = nnx.PathContains('x')
    >>> diff_state = nnx.DiffState(0, x_attribute)
    ...
    >>> @nnx.custom_vjp(nondiff_argnums=(diff_state,))
    ... def f(m: Foo):
    ...   return jnp.sin(m.x) * m.y  # type: ignore

    >>> def f_fwd(m: Foo):
    ...   y = f(m)
    ...   res = (jnp.cos(m.x), m)  # type: ignore
    ...   return y, res
    ...
    >>> def f_bwd(res, g):
    ...   input_updates_g, out_g = g
    ...   cos_x, m = res
    ...   (m_updates_g,) = input_updates_g
    ...   m_g = jax.tree.map(lambda x: x, m_updates_g) # create copy
    ...
    ...   m_g.x[...] = cos_x * out_g * m.y
    ...   del m_g['y'] # y is not differentiable
    ...   return (m_g,)

    >>> f.defvjp(f_fwd, f_bwd)
    ...
    >>> m = Foo(x=jnp.array(1.), y=jnp.array(2.))
    >>> grad = nnx.grad(f, argnums=nnx.DiffState(0, x_attribute))(m)
    ...
    >>> jax.tree.map(jnp.shape, grad)
    State({
      'x': Param(
        value=()
      )
    })

  Note that ``grad`` cannot calculate gradients for states that don't have a tangent
  defined by ``custom_vjp``, in the example above we reuse the same ``x_attribute``
  filter to keep ``custom_vjp`` and ``grad`` in sync.

  **graph_updates=False**

  When ``graph_updates=False`` or ``graph=False``, the behavior is closer to
  ``jax.custom_vjp``: the ``bwd`` function receives ``out_g`` directly, and
  tangent types are the same as the input types, this means the tangent for a
  Module is a Module instance with gradient values set on its attributes.
  This mode does not support ``DiffState`` in ``nondiff_argnums``. Additionally,
  Variables in differentiable arguments cannot be mutated inside ``f``. If
  mutations are needed, pass the relevant Variables through a
  non-differentiable argument instead.

  Example::

    >>> @nnx.custom_vjp(graph_updates=False)
    ... def f(m: Foo):
    ...   return jnp.sin(m.x) * m.y
    ...
    >>> def f_fwd(m: Foo):
    ...   return f(m), (jnp.cos(m.x), jnp.sin(m.x), m)
    ...
    >>> def f_bwd(res, g):
    ...   cos_x, sin_x, m = res
    ...   m_g = nnx.clone(m)
    ...   m_g.x[...] = cos_x * g * m.y
    ...   m_g.y[...] = sin_x * g
    ...   return (m_g,)
    ...
    >>> f.defvjp(f_fwd, f_bwd)

  Args:
    fun: Callable base function.
    nondiff_argnums: Tuple of integers or DiffState objects specifying the
      argument indices that are not differentiated. By default all arguments are
      differentiated. Integers cannot be used to mark graph nodes such as Modules
      as non-differentiable, in this case use a DiffState object. DiffState objects
      define the set of differentiable substates, contrary to what the name of this
      argument suggests, this is done for compatibility with ``grad``.
    graph: If ``True`` (default), uses graph-mode which supports the full
      NNX feature set including shared references and reference semantics.
      If ``False``, uses tree-mode which treats Modules as regular JAX
      pytrees, avoiding the overhead of the graph protocol. Tree-mode does
      not support ``DiffState`` in ``nondiff_argnums``.
    graph_updates: If ``True``, propagates updates on graph structure
      that happen inside the transform to the input graphs, has no
      effect when ``graph=False``. When ``False``, using ``DiffState``
      is not supported.

  Returns:
    A ``CustomVjp`` when both ``graph`` and ``graph_updates`` are enabled,
    otherwise a ``SimpleCustomVjp``. When ``fun`` is omitted, a decorator that
    accepts ``fun`` and returns one of the above.

  Raises:
    ValueError: if ``nondiff_argnums`` contains a ``DiffState`` while
      ``graph`` or ``graph_updates`` is disabled.
  """
  # ``None`` means "use the ambient setting"
  if graph is None:
    graph = graphlib.set_graph_mode.current_value()
  if graph_updates is None:
    graph_updates = graphlib.set_graph_updates.current_value()
  if isinstance(fun, Missing):
    # used as ``@custom_vjp(...)``: return a decorator that carries the
    # already resolved options
    return functools.partial(
      custom_vjp,
      nondiff_argnums=nondiff_argnums,
      graph=graph,
      graph_updates=graph_updates,
    )

  # Detect bound nnx.Module methods and raise error.
  fun_unbound, _, was_bound = _resolve_bound_callable(fun)
  if was_bound:
    _raise_bound_method_error('custom_vjp')

  if not graph or not graph_updates:
    if any(isinstance(x, DiffState) for x in nondiff_argnums):
      # this branch is reached for graph=False as well as for
      # graph_updates=False, so the message names both
      raise ValueError(
        '`nondiff_argnums` cannot contain `DiffState` objects '
        'when `graph=False` or `graph_updates=False`. '
        + graphlib._tree_mode_suggestion_transform('custom_vjp')
      )
    return SimpleCustomVjp(fun_unbound, nondiff_argnums, graph=graph)  # type: ignore[arg-type]
  return CustomVjp(fun_unbound, nondiff_argnums)
# ------------------------------- # remat # ------------------------------- @dataclasses.dataclass(eq=False) class SimpleRematFn: f: tp.Callable[..., tp.Any] graph: bool def __post_init__(self): functools.update_wrapper(self, self.f, updated=()) @extract.treemap_copy_args def __call__(self, *args, **kwargs): updates, snapshot = extract.updates_and_snapshot((args, kwargs)) if self.graph: args, kwargs = extract.from_tree2((args, kwargs)) out = self.f(*args, **kwargs) if self.graph: out = extract.to_tree2(out) extract.check_no_aliases('remat', args=updates[0], kwargs=updates[1], out=out) updates = extract.mask_variable_updates(updates, snapshot) return out, updates @tp.overload def remat( *, prevent_cse: bool = True, static_argnums: int | tuple[int, ...] = (), policy: tp.Callable[..., bool] | None = None, graph: bool | None = None, graph_updates: bool | None = None, ) -> tp.Callable[[F], F]: ... @tp.overload def remat( f: F, *, prevent_cse: bool = True, static_argnums: int | tuple[int, ...] = (), policy: tp.Callable[..., bool] | None = None, graph: bool | None = None, graph_updates: bool | None = None, ) -> F: ...
def remat(
  f: F | Missing = MISSING,
  *,
  prevent_cse: bool = True,
  static_argnums: int | tuple[int, ...] = (),
  policy: tp.Callable[..., bool] | None = None,
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> F | tp.Callable[[F], F]:
  """A 'lifted' version of the
  `jax.checkpoint <https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html>`__
  (a.k.a. ``jax.remat``).

  ``flax.nnx.remat``, similar to ``jax.checkpoint`` can provide control over, for
  example, how ``flax.nnx.grad`` values are computed and saved during the forward pass versus
  how they are recomputed during the backward pass, trading off memory and FLOPs.

  Learn more in `Flax NNX vs JAX Transformations <https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html>`_.

  To learn about ``jax.remat``, go to JAX's
  `fundamentals of jax.checkpoint <https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#fundamentals-of-jax-checkpoint>`_
  and `practical notes <https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#practical-notes>`_.

  Args:
    f: Function to be rematerialized.
    prevent_cse: Optional, bool. If True, prevents common subexpression
      elimination. Default True.
    static_argnums: Optional, int or tuple of ints. Specifies which
      positional arguments to treat as static.
    policy: Optional, callable. A policy for which intermediates to save
      during the forward pass.
    graph: If ``True`` (default), uses graph-mode which supports the full
      NNX feature set including shared references and reference semantics.
      If ``False``, uses tree-mode which treats Modules as regular JAX
      pytrees, avoiding the overhead of the graph protocol. Tree-mode does
      not support shared ``Variable`` references.
    graph_updates: If ``True``, propagates updates on graph structure
      that happen inside the transform to the input graphs, has no
      effect when ``graph=False``.

  Returns:
    The rematerialized version of ``f``. When ``f`` is omitted, a decorator
    that accepts ``f`` and returns its rematerialized version.
  """
  # ``None`` means "use the ambient setting"
  if graph is None:
    graph = graphlib.set_graph_mode.current_value()
  if graph_updates is None:
    graph_updates = graphlib.set_graph_updates.current_value()
  if isinstance(f, Missing):
    # used as ``@remat(...)``: return a decorator that carries the already
    # resolved options
    return functools.partial(
      remat,
      prevent_cse=prevent_cse,
      static_argnums=static_argnums,
      policy=policy,
      graph=graph,
      graph_updates=graph_updates,
    )  # type: ignore[return-value]

  # bound Module methods are rejected
  f_unbound, _, was_bound = _resolve_bound_callable(f)
  if was_bound:
    _raise_bound_method_error('remat')

  if not graph or not graph_updates:
    # simple path: the checkpointed function returns ``(out, updates)`` and
    # the updates are applied to the caller's arguments afterwards
    checkpointed_fn = jax.checkpoint(
      SimpleRematFn(f_unbound, graph=graph),
      prevent_cse=prevent_cse,
      static_argnums=static_argnums,
      policy=policy,
    )

    @functools.wraps(f_unbound)
    def simple_remat_wrapper(*args, **kwargs):
      if graph:
        args, kwargs = extract.to_tree2((args, kwargs))
      extract.check_no_aliases('remat', args=args, kwargs=kwargs)
      out, updates = checkpointed_fn(*args, **kwargs)
      if graph:
        out = extract.from_tree2(out)
      extract.apply_variable_updates((args, kwargs), updates)
      return out

    return simple_remat_wrapper  # type: ignore[return-value]

  # Unbound function path: preserve the concise composition used in NNX.
  # Innermost to outermost: merge_inputs -> jax.checkpoint -> split_inputs ->
  # update_context -> resolve_kwargs.
  return resolve_kwargs()(  # type: ignore[return-value]
    graphlib.update_context('remat')(
      general.split_inputs(
        jax.checkpoint(
          general.merge_inputs(f_unbound, ctxtag='remat'),
          prevent_cse=prevent_cse,
          static_argnums=static_argnums,
          policy=policy,
        ),
        ctxtag='remat',
      ),
    )
  )