# Source code for flax.nnx.transforms.transforms

# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pytype: skip-file
from __future__ import annotations

from abc import abstractmethod
import dataclasses
import functools
import inspect
import typing as tp

from jax._src import checkify as checkify_lib

from flax.nnx import (
  extract,
  graphlib,
  variablelib,
)
from flax.nnx.module import Module
from flax.nnx.proxy_caller import (
  CallableProxy,
  DelayedAccessor,
)
from flax.nnx.transforms import general
from flax.typing import MISSING, Leaf, Missing
import jax
import jax.core
import jax.stages

# Type variables and aliases shared by the transforms in this module.
A = tp.TypeVar('A')
C = tp.TypeVar('C')
B = tp.TypeVar('B')
F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])
G = tp.TypeVar('G', bound=tp.Callable[..., tp.Any])
M = tp.TypeVar('M', bound=Module)
MA = tp.TypeVar('MA', bound=Module)
N = tp.TypeVar('N', bound=Module)
StrInt = tp.TypeVar('StrInt', str, int)
AxisName = tp.Hashable  # name of a mapped axis, as accepted by JAX transforms
Leaves = list[Leaf]
Index = int


@tp.overload
def resolve_kwargs(
  fun: tp.Callable[..., tp.Any],
  args: tuple,
  kwargs: dict[str, tp.Any],
) -> tuple: ...
@tp.overload
def resolve_kwargs() -> tp.Callable[[F], F]: ...
def resolve_kwargs(
  fun: tp.Callable[..., tp.Any] | Missing = MISSING,
  args: tuple | Missing = MISSING,
  kwargs: dict[str, tp.Any] | Missing = MISSING,
) -> tuple | tp.Callable[[F], F]:
  """Resolve keyword arguments into their positional slots.

  Two calling conventions are supported:

  * ``resolve_kwargs(fun, args, kwargs)`` returns the fully positional
    argument tuple obtained by binding ``args``/``kwargs`` against ``fun``'s
    signature (with defaults applied).
  * ``resolve_kwargs()`` returns a decorator that performs the same
    resolution on every call before invoking the wrapped function.

  Raises:
    ValueError: if ``fun`` is given but ``args`` or ``kwargs`` is missing.
    TypeError: if some keyword arguments cannot be mapped to positions.
  """
  # No-argument form: act as a decorator factory.
  if isinstance(fun, Missing):

    def resolve_kwargs_decorator(f):
      @functools.wraps(f)
      def resolve_kwargs_wrapper(*call_args, **call_kwargs):
        return f(*resolve_kwargs(f, call_args, call_kwargs))

      return resolve_kwargs_wrapper

    return resolve_kwargs_decorator  # type: ignore

  if isinstance(args, Missing):
    raise ValueError('args must be provided')
  if isinstance(kwargs, Missing):
    raise ValueError('kwargs must be provided')

  if isinstance(fun, functools.partial):
    # functools.partial should have an opaque signature.
    fun = lambda *args, **kwargs: None
  bound = inspect.signature(fun).bind(*args, **kwargs)
  bound.apply_defaults()
  if bound.kwargs:
    raise TypeError('keyword arguments could not be resolved to positions')
  return bound.args



# -------------------------------
# helper utilities for bound methods & indices
# -------------------------------

def _resolve_bound_callable(
  f: tp.Callable[..., tp.Any],
) -> tuple[tp.Callable[..., tp.Any], tp.Any | None, bool]:
  """Detects and extracts bound methods from NNX Module callables.

  This function unwraps functools.partial layers to reach the underlying
  callable before checking if it's a bound method of an NNX Module.

  Args:
    f: A callable that may be a bound method of an NNX Module, potentially
       wrapped in functools.partial.

  Returns:
    A tuple of (unbound_fn, bound_self, was_bound) where:
    - unbound_fn: The unbound function (or original if not bound)
    - bound_self: The Module instance if f was bound, None otherwise
    - was_bound: True if f was a bound method, False otherwise

  Note:
    Preserves functools.partial wrappers around the callable and follows
    the same detection pattern as _get_unbound_fn in bridge/module.py.
    Detection occurs before any argnum shifting or index normalization.
  """
  # Unwrap functools.partial layers to reach the underlying callable.
  partials: list[tuple[tuple[tp.Any, ...], dict[str, tp.Any] | None]] = []
  g = f
  while isinstance(g, functools.partial):  # type: ignore[arg-type]
    partials.append((g.args or (), g.keywords))  # type: ignore[attr-defined]
    g = g.func  # type: ignore[attr-defined]

  bound_self = getattr(g, "__self__", None)
  was_bound = bool(inspect.ismethod(g) and isinstance(bound_self, Module))
  if was_bound:
    g = g.__func__  # type: ignore[attr-defined]

  # Reapply partials in reverse unwrap order.
  for args, kwargs in reversed(partials):
    kwargs = {} if kwargs is None else kwargs
    g = functools.partial(g, *args, **kwargs)

  return g, (bound_self if was_bound else None), was_bound


def _raise_bound_method_error(transform_name: str):
  """Raises a standardized error for bound method usage with NNX transforms.

  Args:
    transform_name: Name of the transform (e.g., 'grad', 'jit', 'remat').
  """
  raise ValueError(
    f"nnx.{transform_name} does not support bound methods. "
    f"Use the decorator form @nnx.{transform_name} or call "
    f"nnx.{transform_name}(MyClass.method)(instance, ...) with the unbound method."
  )


class LiftedModule(tp.Generic[M], Module):  # type: ignore[ignored-abstractmethod]
  """Base class for Modules that wrap ("lift") an inner Module through a transform.

  Subclasses implement ``_call`` (how to invoke the transformed computation)
  and ``_submodule`` (the wrapped Module). Calling the lifted module routes
  through a ``CallableProxy`` so that attribute/method access on nested
  lifted modules is recorded as a ``DelayedAccessor`` path before invocation.
  """

  @abstractmethod
  def _call(self, accessor: DelayedAccessor, *args, **kwargs) -> tp.Any:
    # Invoke the lifted computation; `accessor` selects which method or
    # attribute of the submodule the call targets.
    pass

  @property
  @abstractmethod
  def _submodule(self) -> M:
    # The wrapped (inner) Module instance.
    pass  # type: ignore[bad-return-type] # why pytype?

  def __call__(self, *args, **kwargs) -> tp.Any:
    return self.call(*args, **kwargs)  # type: ignore

  @property
  def call(self) -> tp.Any:
    module = self

    def check_and_call(accessor: DelayedAccessor, *args, **kwargs):
      return self._call(accessor, *args, **kwargs)

    proxy = CallableProxy(check_and_call)  # type: ignore[arg-type]

    # For each level of lifted-module nesting, add one `.call` hop so the
    # proxy's access path mirrors the nesting depth.
    while isinstance(module._submodule, LiftedModule):
      module = module._submodule
      proxy = proxy.call

    return proxy  # type: ignore


# -------------------------------
# simple transforms
# -------------------------------
@dataclasses.dataclass(frozen=True)
class ValueMetadata:
  """Immutable record of a Variable's type, raw value, and metadata.

  Registered below as a JAX pytree node so it can flow through tree-based
  transforms (e.g. ``jax.eval_shape``) and be rebuilt into a Variable by
  ``_to_variable``.
  """

  var_type: type[variablelib.Variable]  # concrete Variable subclass to rebuild
  value: tp.Any  # the raw leaf value extracted from the Variable
  metadata: dict[str, tp.Any]  # the Variable's metadata dict


def _flatten_value_metadata(
  value_metadata: tp.Union[tp.Any, ValueMetadata],
):
  metadata = tuple(sorted(value_metadata.metadata.items()))
  return (value_metadata.value,), (value_metadata.var_type, metadata)


def _unflatten_value_metadata(aux_data, children):
  """Pytree unflatten rule: rebuild a ``ValueMetadata`` from flatten output."""
  var_type, metadata_items = aux_data
  (value,) = children
  return ValueMetadata(
    var_type=var_type, value=value, metadata=dict(metadata_items)
  )


# Register ValueMetadata as a JAX pytree node so instances flow transparently
# through jax.tree.map / jax.eval_shape and can be reconstructed afterwards.
jax.tree_util.register_pytree_node(
  ValueMetadata,
  _flatten_value_metadata,
  _unflatten_value_metadata,
)


def _to_value_metadata(node):
  """Replace every ``Variable`` leaf in ``node`` with a ``ValueMetadata``.

  Non-Variable leaves pass through unchanged. Array refs are materialized
  with ``value[...]`` so the record holds a plain value.
  """

  def _convert(leaf):
    if not isinstance(leaf, variablelib.Variable):
      return leaf
    raw = leaf.get_raw_value()
    if variablelib.is_array_ref(raw):
      raw = raw[...]
    return ValueMetadata(
      var_type=leaf.var_type, value=raw, metadata=leaf.get_metadata()
    )

  return jax.tree.map(
    _convert,
    node,
    is_leaf=lambda leaf: isinstance(leaf, variablelib.Variable),
  )


def _to_variable(node):
  """Inverse of ``_to_value_metadata``: rebuild Variables from ValueMetadata leaves.

  For each rebuilt Variable without an explicit sharding, derives a
  ``NamedSharding`` from the Variable's partition spec and the active mesh,
  and stores it on an abstract ``jax.ShapeDtypeStruct`` value.
  """
  # import here to avoid circular imports
  from flax.nnx.spmd import get_var_pspec

  def to_variable(x):
    if isinstance(x, ValueMetadata):
      var = x.var_type._new(x.value, x.metadata)

      # Use the currently active abstract mesh as a fallback; an empty
      # axis_sizes tuple means no mesh is active.
      global_mesh = jax.sharding.get_abstract_mesh()
      if global_mesh.axis_sizes == ():
        global_mesh = None
      # Variable-level "mesh" metadata wins over the global mesh.
      # NOTE(review): a falsy per-variable mesh would also fall through to
      # the global mesh here — presumably intentional; confirm.
      mesh = var.get_metadata("mesh", None) or global_mesh
      if mesh is not None and (not hasattr(var, 'sharding') or var.sharding is None):
        # Attach the derived sharding via an abstract value; no real array
        # is allocated here (this runs under eval_shape).
        pspec = get_var_pspec(var)
        sharding = jax.sharding.NamedSharding(mesh=mesh, spec=pspec)
        var.set_value(jax.ShapeDtypeStruct(shape=var.shape, dtype=var.dtype, sharding=sharding))
      return var
    return x

  return jax.tree.map(
    to_variable, node, is_leaf=lambda x: isinstance(x, ValueMetadata)
  )


@dataclasses.dataclass(eq=False)
class SimpleEvalShapeFn:
  """Callable wrapper used by ``eval_shape`` when graph-updates are disabled.

  Bridges between the outer tree representation and the user function:
  converts inputs/outputs through the graph protocol when ``graph`` is True,
  and rejects aliased references in the result.
  """

  f: tp.Callable[..., tp.Any]  # the user function being shape-evaluated
  graph: bool  # whether to route inputs/outputs through the graph protocol

  def __post_init__(self):
    # Mirror the wrapped function's metadata (__name__, __doc__, ...).
    functools.update_wrapper(self, self.f, updated=())

  @extract.treemap_copy_args
  def __call__(self, *args, **kwargs):
    if self.graph:
      args, kwargs = extract.from_tree2((args, kwargs))
    out = self.f(*args, **kwargs)
    if self.graph:
      out = extract.to_tree2(out)
    extract.check_no_aliases('eval_shape', args=args, kwargs=kwargs, out=out)
    return out


def eval_shape(
  f: tp.Callable[..., A],
  *args: tp.Any,
  graph: bool | None = None,
  graph_updates: bool | None = None,
  **kwargs: tp.Any,
) -> A:
  """A "lifted" version of `jax.eval_shape
  <https://jax.readthedocs.io/en/latest/_autosummary/jax.eval_shape.html#jax.eval_shape>`_
  that can handle `flax.nnx.Module
  <https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module>`_
  / graph nodes as arguments.

  Similar to ``jax.eval_shape``, it computes the shape/dtype of a function
  `f` without performing any floating point operations (FLOPs) which can be
  expensive. This can be useful for performing shape inference, for example.
  Unlike `jax.eval_shape`, `nnx.eval_shape` will automatically compute the
  expected sharding based on Flax sharding metadata for all Variables not
  using explicit sharding.

  Args:
    f: the function to evaluate.
    *args: positional arguments to ``f``.
    graph: If ``True`` (default), uses graph-mode which supports the full NNX
      feature set including shared references and reference semantics. If
      ``False``, uses tree-mode which treats Modules as regular JAX pytrees,
      avoiding the overhead of the graph protocol.
    graph_updates: If ``True``, propagates updates on graph structure that
      happen inside the transform to the input graphs, has no effect when
      ``graph=False``.
    **kwargs: keyword arguments to ``f``.
  """
  f_call, _, was_bound = _resolve_bound_callable(f)
  if was_bound:
    _raise_bound_method_error('eval_shape')
  # Fall back to the ambient graph-mode / graph-updates settings.
  if graph is None:
    graph = graphlib.set_graph_mode.current_value()
  if graph_updates is None:
    graph_updates = graphlib.set_graph_updates.current_value()

  if not graph or not graph_updates:
    # Simple path: no structural update propagation.
    if graph:
      args, kwargs = extract.to_tree2((args, kwargs))
    extract.check_no_aliases('eval_shape', args=args, kwargs=kwargs)
    out = jax.eval_shape(
      SimpleEvalShapeFn(f_call, graph=graph), *args, **kwargs
    )
    if graph:
      out = extract.from_tree2(out)
    return out

  # Full graph path: convert inputs to trees, run f under jax.eval_shape,
  # then rebuild Variables (with inferred shardings) from the output.
  args, kwargs = extract.to_tree((args, kwargs))

  @functools.wraps(f)
  def _eval_shape_fn(*args, **kwargs):
    args, kwargs = extract.from_tree((args, kwargs))
    out = f_call(*args, **kwargs)
    return _to_value_metadata(extract.to_tree(out))

  out = jax.eval_shape(_eval_shape_fn, *args, **kwargs)
  return extract.from_tree(_to_variable(out))
@dataclasses.dataclass(eq=False)
class CheckifyFn:
  """Callable wrapper used by ``checkify`` on the full graph-update path.

  Converts pure tree inputs back to graph objects, runs the user function,
  and returns the (possibly updated) inputs together with the output as
  pure trees so graph updates can be propagated to the caller.
  """

  f: tp.Callable[..., tp.Any]  # the user function being checkified

  def __post_init__(self):
    # Mirror the wrapped function's metadata (__name__, __doc__, ...).
    functools.update_wrapper(self, self.f)

  def __call__(self, *pure_args, **pure_kwargs):
    args, kwargs = extract.from_tree(
      (pure_args, pure_kwargs), ctxtag='checkify', is_inner=True
    )
    out = self.f(*args, **kwargs)
    # NOTE(review): args_out/kwargs_out are computed but (args, kwargs) —
    # not the cleared versions — are re-exported below; preserved as-is.
    args_out, kwargs_out = extract.clear_non_graph_nodes((args, kwargs))
    pure_args_out, pure_kwargs_out, pure_out = extract.to_tree(
      (args, kwargs, out), ctxtag='checkify'
    )
    return pure_args_out, pure_kwargs_out, pure_out


@dataclasses.dataclass(eq=False)
class SimpleCheckifyFn:
  """Callable wrapper used by ``checkify`` when graph-updates are disabled.

  Tracks variable updates made by the user function and returns them next
  to the output so the wrapper can apply them in place after the transform.
  """

  f: tp.Callable[..., tp.Any]  # the user function being checkified
  graph: bool  # whether to route inputs/outputs through the graph protocol

  def __post_init__(self):
    functools.update_wrapper(self, self.f, updated=())

  @extract.treemap_copy_args
  def __call__(self, *args):
    updates, snapshot = extract.updates_and_snapshot(args)
    if self.graph:
      args = extract.from_tree2(args)
    out = self.f(*args)
    if self.graph:
      out = extract.to_tree2(out)
    extract.check_no_aliases('checkify', args=updates, out=out)
    # Only keep updates that actually changed relative to the snapshot.
    updates = extract.mask_variable_updates(updates, snapshot)
    return out, updates


def checkify(
  f: tp.Callable[..., checkify_lib.Out],
  errors: frozenset[type[checkify_lib.JaxException]] = checkify_lib.user_checks,  # type: ignore
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> tp.Callable[..., tuple[checkify_lib.Error, checkify_lib.Out]]:
  """Reference-aware version of `jax.experimental.checkify
  <https://flax.readthedocs.io/en/latest/nnx_basics.html#the-flax-functional-api>`_.

  Example::

    >>> import jax
    >>> import jax.numpy as jnp
    >>> from jax.experimental import checkify
    >>> import dataclasses
    >>> from flax import nnx
    ...
    >>> class Foo(nnx.Module):
    ...   def __init__(self, a):
    ...     self.a = nnx.Param(a)
    ...
    >>> @nnx.jit
    ... def f(m):
    ...   y = jnp.sin(m.a)  # error
    ...   return m.a + y
    ...
    >>> m = Foo(a=jnp.inf)
    >>> err, out = nnx.checkify(f, errors=checkify.float_checks)(m)
    >>> # err.throw()
    >>> print(err)
    Error(nan generated by primitive: sin.)

  Args:
    f: the function to checkify.
    errors: the set of error checks to enable.
    graph: If ``True`` (default), uses graph-mode which supports the full NNX
      feature set including shared references and reference semantics. If
      ``False``, uses tree-mode which treats Modules as regular JAX pytrees,
      avoiding the overhead of the graph protocol.
    graph_updates: If ``True``, propagates updates on graph structure that
      happen inside the transform to the input graphs, has no effect when
      ``graph=False``.
  """
  f_call, _, was_bound = _resolve_bound_callable(f)
  if was_bound:
    _raise_bound_method_error('checkify')
  # Fall back to the ambient graph-mode / graph-updates settings.
  if graph is None:
    graph = graphlib.set_graph_mode.current_value()
  if graph_updates is None:
    graph_updates = graphlib.set_graph_updates.current_value()

  if not graph or not graph_updates:
    # Simple path: apply variable updates in place, no structural updates.
    checkify_fn = checkify_lib.checkify(
      SimpleCheckifyFn(f_call, graph=graph), errors
    )

    @functools.wraps(f)
    def simple_checkify_wrapper(*args):
      if graph:
        args = extract.to_tree2(args)
      extract.check_no_aliases('checkify', args=args)
      error, (out, updates) = checkify_fn(*args)
      if graph:
        out = extract.from_tree2(out)
      extract.apply_variable_updates(args, updates)
      return error, out

    return simple_checkify_wrapper  # type: ignore

  # Full graph path: round-trip inputs/outputs through the update context so
  # structural changes made inside f propagate back to the caller's graphs.
  checkify_fn = checkify_lib.checkify(CheckifyFn(f_call), errors)

  @functools.wraps(f)
  @graphlib.update_context('checkify')
  def checkify_wrapper(*args, **kwargs):
    pure_args, pure_kwargs = extract.to_tree(
      (args, kwargs), ctxtag='checkify'
    )
    error, (pure_args_out, pure_kwargs_out, pure_out) = checkify_fn(
      *pure_args, **pure_kwargs
    )
    args_out, kwargs_out, out = extract.from_tree(
      (pure_args_out, pure_kwargs_out, pure_out),
      ctxtag='checkify',
      is_inner=False,
    )
    return error, out

  return checkify_wrapper  # type: ignore


@dataclasses.dataclass(eq=False)
class SimpleCondFn:
  """Branch wrapper shared by ``cond`` and ``switch`` in the simple path.

  Returns the branch output together with the variable updates it produced
  so the caller can apply them to the original operands.
  """

  f: tp.Callable[..., tp.Any]  # the branch function
  graph: bool  # whether to route operands/outputs through the graph protocol

  def __post_init__(self):
    functools.update_wrapper(self, self.f, updated=())

  @extract.treemap_copy_args
  def __call__(self, *args):
    updates, _snapshot = extract.updates_and_snapshot(args)
    if self.graph:
      args = extract.from_tree2(args)
    out = self.f(*args)
    if self.graph:
      out = extract.to_tree2(out)
    # NOTE(review): the alias-check tag is 'switch' even when invoked from
    # cond; preserved as-is from the original.
    extract.check_no_aliases('switch', args=updates, out=out)
    return out, updates
def cond(
  pred,
  true_fun: tp.Callable[..., A],
  false_fun: tp.Callable[..., A],
  *operands,
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> A:
  """Conditionally apply ``true_fun`` or ``false_fun``.

  Wraps `jax.lax.cond
  <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.cond.html>`__
  to support Flax NNX modules and variables.

  Args:
    pred: boolean scalar. If True, ``true_fun`` is applied, otherwise
      ``false_fun``.
    true_fun: function to apply if ``pred`` is True.
    false_fun: function to apply if ``pred`` is False.
    *operands: operands passed to whichever branch is selected.
    graph: If ``True`` (default), uses graph-mode which supports the full NNX
      feature set including shared references and reference semantics. If
      ``False``, uses tree-mode which treats Modules as regular JAX pytrees,
      avoiding the overhead of the graph protocol.
    graph_updates: If ``True``, propagates updates on graph structure that
      happen inside the transform to the input graphs, has no effect when
      ``graph=False``.
  """
  # Fall back to the ambient graph-mode / graph-updates settings.
  if graph is None:
    graph = graphlib.set_graph_mode.current_value()
  if graph_updates is None:
    graph_updates = graphlib.set_graph_updates.current_value()

  if not graph or not graph_updates:
    # Simple path: collect per-branch variable updates and apply them to the
    # original operands after the cond.
    if graph:
      operands = extract.to_tree2(operands)
    extract.check_no_aliases('cond', operands=operands)
    out, updates = jax.lax.cond(
      pred,
      SimpleCondFn(true_fun, graph=graph),
      SimpleCondFn(false_fun, graph=graph),
      *operands,
    )
    if graph:
      out = extract.from_tree2(out)
    extract.apply_variable_updates(operands, updates)
    return out

  # Full graph path: split/merge inputs through the 'cond' update context.
  @general.split_inputs(ctxtag='cond')
  def _cond(pred, true_fun, false_fun, *operands):
    return jax.lax.cond(
      pred,
      general.merge_inputs(true_fun, ctxtag='cond'),
      general.merge_inputs(false_fun, ctxtag='cond'),
      *operands,
    )

  return _cond(pred, true_fun, false_fun, *operands)
def switch(
  index,
  branches: tp.Sequence[tp.Callable[..., A]],
  *operands,
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> A:
  """Select and apply one of ``branches`` based on ``index``.

  Wraps `jax.lax.switch
  <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.switch.html>`__
  to support Flax NNX modules and variables.

  Args:
    index: integer scalar indicating which branch to apply.
    branches: sequence of functions to select from.
    *operands: operands passed to the selected branch.
    graph: If ``True`` (default), uses graph-mode which supports the full NNX
      feature set including shared references and reference semantics. If
      ``False``, uses tree-mode which treats Modules as regular JAX pytrees,
      avoiding the overhead of the graph protocol.
    graph_updates: If ``True``, propagates updates on graph structure that
      happen inside the transform to the input graphs, has no effect when
      ``graph=False``.
  """
  # Fall back to the ambient graph-mode / graph-updates settings.
  if graph is None:
    graph = graphlib.set_graph_mode.current_value()
  if graph_updates is None:
    graph_updates = graphlib.set_graph_updates.current_value()

  if not graph or not graph_updates:
    # Simple path: collect per-branch variable updates and apply them to the
    # original operands after the switch.
    if graph:
      operands = extract.to_tree2(operands)
    extract.check_no_aliases('switch', operands=operands)
    out, updates = jax.lax.switch(
      index,
      [SimpleCondFn(f, graph=graph) for f in branches],
      *operands,
    )
    if graph:
      out = extract.from_tree2(out)
    extract.apply_variable_updates(operands, updates)
    return out

  # Full graph path: split/merge inputs through the 'switch' update context.
  @general.split_inputs(ctxtag='switch')
  def _switch(index, branches, *operands):
    return jax.lax.switch(
      index,
      [general.merge_inputs(f, ctxtag='switch') for f in branches],
      *operands,
    )

  return _switch(index, branches, *operands)