# Source code for flax.nnx.transforms.compilation

# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pytype: skip-file
from __future__ import annotations

import dataclasses
import functools
import inspect
import operator
import typing as tp

import jax
from jax.sharding import AbstractMesh, Mesh, PartitionSpec

from flax.nnx import (
  extract,
  filterlib,
  graphlib,
  statelib,
  variablelib,
)
from flax.nnx.transforms.transforms import (
  _resolve_bound_callable,
  _raise_bound_method_error,
)
from flax.typing import MISSING, Missing, PathParts

# Typing helpers shared by the transforms in this module.
F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])  # a generic callable
P = tp.ParamSpec('P')  # preserves the wrapped function's parameter types
R = tp.TypeVar('R')  # the wrapped function's return type
Specs = tp.Any  # pjit-style sharding specs; intentionally left loose
AxisName = tp.Hashable  # identifier for a mesh axis

# -------------------------------
# jit
# -------------------------------


class StateSharding(extract.PrefixMapping):
  """Maps NNX ``Filter``s to sharding specifications for ``nnx.jit``.

  Accepts a ``State`` (converted to per-path filters), a mapping from
  ``Filter`` to sharding, or an iterable of ``(filter, sharding)`` pairs.
  Filters are tried in order; the first one that matches a
  ``(path, variable)`` pair determines the sharding.
  """

  def __init__(
    self,
    filter_sharding: statelib.State
    | tp.Mapping[filterlib.Filter, tp.Any]
    | tp.Iterable[tuple[filterlib.Filter, tp.Any]],
    /,
  ):
    if isinstance(filter_sharding, statelib.State):
      filter_sharding = statelib.create_path_filters(filter_sharding)  # type: ignore

    iterable = tuple(
      filter_sharding.items()
      if isinstance(filter_sharding, tp.Mapping)
      else filter_sharding
    )
    self._filters = tuple(filter for filter, _ in iterable)
    self._shardings = tuple(axis for _, axis in iterable)
    # Hoisted out of `map_prefix`: the predicates depend only on the
    # (immutable) filter tuple, so build them once here instead of
    # recompiling every filter on every lookup.
    self._predicates = tuple(
      filterlib.to_predicate(filter) for filter in self._filters
    )

  @property
  def filters(self) -> tuple[filterlib.Filter, ...]:
    return self._filters

  @property
  def shardings(self) -> tuple[tp.Any, ...]:
    return self._shardings

  def map_prefix(
    self, path: PathParts, variable: variablelib.Variable
  ) -> tp.Any:
    """Return the sharding of the first filter matching ``(path, variable)``.

    Raises:
      ValueError: if no filter matches.
    """
    for predicate, sharding in zip(self._predicates, self._shardings):
      if predicate(path, variable):
        return sharding
    raise ValueError(f'No axis found for {path=}, {variable=}')

  def __repr__(self):
    return f'StateSharding({dict(zip(self.filters, self.shardings))})'

  def __eq__(self, other):
    # Equality (and hash below) intentionally ignore `_predicates`,
    # which is derived data.
    return (
      isinstance(other, StateSharding)
      and self.filters == other.filters
      and self.shardings == other.shardings
    )

  def __hash__(self):
    return hash((self.filters, self.shardings))


def _jit_split_fn(ctx: graphlib.SplitContext, path, prefix, x):
  # Split a graph node into NodeStates. Without a StateSharding prefix the
  # node is flattened as a whole (no paths); with one, the prefix's filters
  # partition the state and the prefix itself rides along as metadata.
  if not isinstance(prefix, StateSharding):
    return extract.NodeStates.from_split(*ctx.flatten(x, with_paths=False))
  graphdef, *states = ctx.flatten(x, *prefix.filters)
  return extract.NodeStates.from_split(graphdef, *states, metadata=prefix)


def _jit_merge_fn(ctx: graphlib.MergeContext, path, prefix, leaf) -> tp.Any:
  """Reassemble a graph node from the ``NodeStates`` produced by the split."""
  if not isinstance(leaf, extract.NodeStates):
    # Fix: the message previously referred to the expected type by an
    # outdated name ("TreeNode"); the check is against `extract.NodeStates`.
    raise ValueError(f'Expected NodeStates, got {type(leaf)} at path {path}')
  return ctx.unflatten(leaf.graphdef, *leaf.states)


@dataclasses.dataclass(eq=False)
class JitFn:
  """Inner callable handed to ``jax.jit``.

  On each call it merges the pure (pytree) inputs back into graph nodes,
  runs the user function, and re-splits the arguments and output into pure
  pytrees so they can cross the jit boundary.
  """

  f: tp.Callable[..., tp.Any]
  in_shardings: tp.Any
  out_shardings: tp.Any
  kwarg_shardings: tp.Any
  ctxtag: tp.Hashable

  def __post_init__(self):
    # `update_wrapper` copies metadata from `self.f` onto this instance;
    # make sure it cannot clobber the ctxtag this wrapper was built with.
    saved_tag = self.ctxtag
    functools.update_wrapper(self, self.f, updated=())
    self.ctxtag = saved_tag

  def __call__(self, *pure_args, **pure_kwargs):
    # Rebuild the original graph-node arguments from their pure form.
    real_args, real_kwargs = extract.from_tree(
      (pure_args, pure_kwargs),
      merge_fn=_jit_merge_fn,
      ctxtag=self.ctxtag,
      is_inner=True,
    )

    result = self.f(*real_args, **real_kwargs)

    # Split everything back into pure pytrees, threading the sharding
    # prefixes through so StateSharding metadata is preserved on the way out.
    graph_args, graph_kwargs = extract.clear_non_graph_nodes(
      (real_args, real_kwargs)
    )
    pure_args_out, pure_kwargs_out, pure_out = extract.to_tree(
      (graph_args, graph_kwargs, result),
      prefix=(self.in_shardings, self.kwarg_shardings, self.out_shardings),
      ctxtag=self.ctxtag,
      split_fn=_jit_split_fn,
    )
    return pure_args_out, pure_kwargs_out, pure_out


# Overload: decorator-factory form — `nnx.jit(**options)` returns a decorator
# that turns the target function into a `JitWrapped`.
@tp.overload
def jit(
  *,
  in_shardings: tp.Any = None,
  out_shardings: tp.Any = None,
  static_argnums: int | tp.Sequence[int] | None = None,
  static_argnames: str | tp.Iterable[str] | None = None,
  donate_argnums: int | tp.Sequence[int] | None = None,
  donate_argnames: str | tp.Iterable[str] | None = None,
  keep_unused: bool = False,
  device: tp.Optional[jax.Device] = None,
  backend: tp.Optional[str] = None,
  inline: bool = False,
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> tp.Callable[[tp.Callable[P, R]], JitWrapped[P, R]]: ...
# Overload: direct form — `nnx.jit(fun, **options)` wraps `fun` immediately.
@tp.overload
def jit(
  fun: tp.Callable[P, R],
  *,
  in_shardings: tp.Any = None,
  out_shardings: tp.Any = None,
  static_argnums: int | tp.Sequence[int] | None = None,
  static_argnames: str | tp.Iterable[str] | None = None,
  donate_argnums: int | tp.Sequence[int] | None = None,
  donate_argnames: str | tp.Iterable[str] | None = None,
  keep_unused: bool = False,
  device: tp.Optional[jax.Device] = None,
  backend: tp.Optional[str] = None,
  inline: bool = False,
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> JitWrapped[P, R]: ...
def jit(
  fun: tp.Callable[P, R] | Missing = MISSING,
  *,
  in_shardings: tp.Any = None,
  out_shardings: tp.Any = None,
  static_argnums: int | tp.Sequence[int] | None = None,
  static_argnames: str | tp.Iterable[str] | None = None,
  donate_argnums: int | tp.Sequence[int] | None = None,
  donate_argnames: str | tp.Iterable[str] | None = None,
  keep_unused: bool = False,
  device: tp.Optional[jax.Device] = None,
  backend: tp.Optional[str] = None,
  inline: bool = False,
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> JitWrapped[P, R] | tp.Callable[[tp.Callable[P, R]], JitWrapped[P, R]]:
  """Lifted version of ``jax.jit`` that can handle Modules / graph nodes as
  arguments.

  .. note::
    If the jitted function has a model and an optimizer as inputs, we can
    reduce accelerator memory usage if we specify them in ``donate_argnums``
    or ``donate_argnames``:

    >>> from flax import nnx
    >>>
    >>> @nnx.jit(donate_argnames=("model", "optimizer"))
    ... def func(model: nnx.Module, optimizer: nnx.Optimizer, other_args):
    ...   pass

    For details please see
    `this discussion <https://github.com/google/flax/issues/5026>`_.

  Args:
    fun: Function to be jitted. ``fun`` should be a pure function, as
      side-effects may only be executed once. Arguments and return value
      should be arrays, scalars, or (nested) standard Python containers
      thereof. Positional arguments indicated by ``static_argnums`` can be
      anything at all, provided they are hashable and have an equality
      operation defined; static arguments are part of the compilation cache
      key. JAX keeps a weak reference to ``fun``, so it must be
      weakly-referenceable. Bound methods (e.g. ``module.method``) are not
      supported; use the decorator form ``@nnx.jit`` on the method
      definition or call ``nnx.jit(MyClass.method)(instance, ...)`` with the
      unbound method.
    in_shardings: Pytree of structure matching that of arguments to ``fun``,
      with all actual arguments replaced by resource assignment
      specifications. A pytree prefix is also valid (leaves broadcast to the
      subtree). Optional — JAX infers shardings from the input
      ``jax.Array``'s and defaults to replicating when they cannot be
      inferred. Valid specifications are a ``Sharding`` (decides how the
      value is partitioned; no mesh context manager required) or ``None``
      (JAX chooses freely). Similar to pjit's ``in_shardings``.
    out_shardings: Like ``in_shardings``, but for function outputs. Optional;
      if not specified, GSPMD sharding propagation determines the output
      shardings. Similar to pjit's ``out_shardings``.
    static_argnums: Optional int or collection of ints selecting positional
      arguments to treat as static (compile-time constant). Static arguments
      must be hashable and immutable; calling with different values triggers
      recompilation. If neither ``static_argnums`` nor ``static_argnames``
      is provided, no arguments are static; if only one is provided, JAX
      uses ``inspect.signature(fun)`` to find the corresponding parameters
      for the other; if both are provided, only the listed parameters are
      static.
    static_argnames: Optional string or collection of strings selecting
      named arguments to treat as static. See ``static_argnums``.
    donate_argnums: Positional argument buffers to "donate" to the
      computation. Safe when you no longer need the buffers afterwards; XLA
      may recycle donated buffers to reduce memory use, and JAX raises an
      error if you reuse one. By default no buffers are donated. The
      interplay with ``donate_argnames`` mirrors that of
      ``static_argnums``/``static_argnames``. See the
      `FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
    donate_argnames: Optional string or collection of strings selecting
      named arguments to donate. See ``donate_argnums``.
    keep_unused: If ``False`` (default), arguments JAX determines to be
      unused *may* be dropped from the compiled executable and will not be
      transferred to the device. If ``True``, unused arguments are kept.
    device: Experimental. Optional device the jitted function will run on.
      Defaults to XLA's DeviceAssignment logic (usually
      ``jax.devices()[0]``).
    backend: Experimental. Optional XLA backend string: ``'cpu'``,
      ``'gpu'``, or ``'tpu'``.
    inline: Whether this function should be inlined into enclosing jaxprs.
      Default ``False``.
    graph: If ``True`` (default), graph-mode supports the full NNX feature
      set (shared references, reference semantics, structural changes).
      If ``False``, tree-mode treats Modules as regular JAX pytrees — faster,
      but without shared ``Variable`` references or mutable array references
      in the output.
    graph_updates: Whether to propagate graph updates; defaults to the
      current ``graphlib.set_graph_updates`` setting.

  Returns:
    A wrapped version of ``fun``, set up for just-in-time compilation.
  """
  # Resolve mode flags from the ambient context when not given explicitly.
  if graph is None:
    graph = graphlib.set_graph_mode.current_value()
  if graph_updates is None:
    graph_updates = graphlib.set_graph_updates.current_value()
  if isinstance(fun, Missing):
    # Decorator-factory form: return a partial that applies `jit` later.
    return functools.partial(
      jit,
      in_shardings=in_shardings,
      out_shardings=out_shardings,
      static_argnums=static_argnums,
      static_argnames=static_argnames,
      donate_argnums=donate_argnums,
      donate_argnames=donate_argnames,
      keep_unused=keep_unused,
      device=device,
      backend=backend,
      inline=inline,
      graph=graph,
      graph_updates=graph_updates,
    )  # type: ignore[return-value]
  fun_unbound, _, was_bound = _resolve_bound_callable(fun)
  if was_bound:
    _raise_bound_method_error('jit')
  if not graph:
    # Tree-mode cannot interpret StateSharding prefixes — reject them early.
    if any(
      isinstance(x, StateSharding) for x in jax.tree.leaves(in_shardings)
    ):
      raise ValueError(
        '`in_shardings` cannot contain `StateSharding` objects '
        'when `graph=False`. '
        + graphlib._tree_mode_suggestion_transform('jit')
      )
    if any(
      isinstance(x, StateSharding) for x in jax.tree.leaves(out_shardings)
    ):
      raise ValueError(
        '`out_shardings` cannot contain `StateSharding` objects '
        'when `graph=False`. '
        + graphlib._tree_mode_suggestion_transform('jit')
      )
  # Full graph-mode with updates uses JitWrapped; otherwise the simpler
  # tree/graph wrapper without graph-update propagation.
  wrapped_cls: tp.Any
  if graph and graph_updates:
    wrapped_cls = JitWrapped
  else:
    wrapped_cls = functools.partial(SimpleJitWrapped, graph=graph)
  return wrapped_cls(
    fun_unbound,
    in_shardings=in_shardings,
    out_shardings=out_shardings,
    static_argnums=static_argnums,
    static_argnames=static_argnames,
    donate_argnums=donate_argnums,
    donate_argnames=donate_argnames,
    keep_unused=keep_unused,
    device=device,
    backend=backend,
    inline=inline,
  )
@dataclasses.dataclass(frozen=True, slots=True)
class PartialState:
  """Container for a pre-flattened partial argument.

  Stores the pytree structure (``treedef``) as static metadata and the
  flattened leaves as dynamic data. Variables within the original argument
  are kept as leaves so their values can change between calls without
  triggering recompilation.
  """

  treedef: jax.tree_util.PyTreeDef
  leaves: list[tp.Any]


# Register so jax.jit treats `treedef` as static (hashable aux data) and
# only `leaves` as traced inputs.
jax.tree_util.register_dataclass(
  PartialState,
  data_fields=['leaves'],
  meta_fields=['treedef'],
)


def _flatten_to_partial_state(
  arg: tp.Any,
  ref_index: graphlib.RefMap | None,
) -> PartialState:
  """Flatten ``arg`` into a ``PartialState``.

  If ``ref_index`` is given (graph-mode) the graph protocol is used so shared
  references are tracked; otherwise a plain pytree flatten is done with
  ``Variable`` objects kept as leaves.
  """
  if ref_index is not None:
    graphdef, flat_state = graphlib.flatten(arg, ref_index=ref_index, graph=True)
    return PartialState(treedef=graphdef, leaves=flat_state.leaves)
  is_leaf = lambda x: isinstance(x, variablelib.Variable)
  leaves, treedef = jax.tree.flatten(arg, is_leaf=is_leaf)
  return PartialState(treedef=treedef, leaves=leaves)


def _unflatten_partial_state(
  state: PartialState,
  index_ref: graphlib.IndexMap | None,
) -> tp.Any:
  """Inverse of ``_flatten_to_partial_state``.

  ``index_ref`` must be non-None iff the state was flattened in graph-mode.
  """
  if index_ref is not None:
    return graphlib.unflatten(
      state.treedef, state.leaves, index_ref=index_ref, copy_variables=False)
  return jax.tree.unflatten(state.treedef, state.leaves)


@dataclasses.dataclass(eq=False)
class SimpleJitFn:
  """Inner callable passed to ``jax.jit`` by ``SimpleJitWrapped``.

  Rehydrates ``PartialState`` arguments, runs ``f``, and returns the output
  together with masked Variable updates so the caller can propagate
  mutations back to the originals.
  """

  f: tp.Callable[..., tp.Any]
  out_shardings: tp.Any
  donate_argnums: frozenset[int]
  donate_argnames: frozenset[str]
  graph: bool

  def __post_init__(self):
    functools.update_wrapper(self, self.f, updated=())

  @extract.treemap_copy_args
  def __call__(self, *args, **kwargs):
    updates, snapshot = extract.updates_and_snapshot((args, kwargs))
    args_updates, kwargs_updates = updates
    args_snapshot, kwargs_snapshot = snapshot
    index_ref = graphlib.IndexMap() if self.graph else None
    # Rehydrate pre-flattened partial arguments in place.
    args = tuple(
      _unflatten_partial_state(a, index_ref=index_ref)
      if isinstance(a, PartialState)
      else a
      for a in args
    )
    if self.graph:
      args, kwargs = extract.from_tree2((args, kwargs))
    out = self.f(*args, **kwargs)
    if self.graph:
      out = extract.to_tree2(out, prefix=self.out_shardings)
    extract.check_no_aliases(
      'jit', args=args_updates, kwargs=kwargs_updates, out=out)

    # Keep an update when its argument is donated or its Variable changed.
    # NOTE(review): path[0] is the top-level arg index (for args) or name
    # (for kwargs) produced by jax_to_nnx_path — confirm against extract.
    def donated_arg(jax_path, c, s):
      path = graphlib.jax_to_nnx_path(jax_path)
      return path[0] in self.donate_argnums or extract.variable_changed(c, s)

    args_updates = extract.mask_variable_updates(
      args_updates, args_snapshot, keep_fn=donated_arg)

    def donated_kwarg(jax_path, c, s):
      path = graphlib.jax_to_nnx_path(jax_path)
      return path[0] in self.donate_argnames or extract.variable_changed(c, s)

    kwargs_updates = extract.mask_variable_updates(
      kwargs_updates, kwargs_snapshot, keep_fn=donated_kwarg)
    return out, (args_updates, kwargs_updates)


class SimpleJitWrapped(tp.Generic[P, R]):
  """Tree-mode (and graph-mode without graph updates) jit wrapper.

  Treats Modules as regular JAX pytrees, avoiding the full graph update
  protocol of ``JitWrapped``. Variable mutations inside the jitted function
  are returned as updates and applied to the caller's objects after each
  call.
  """

  def __init__(
    self,
    fun: tp.Callable[P, R],
    in_shardings: tp.Any,
    out_shardings: tp.Any,
    static_argnums: int | tp.Sequence[int] | None = None,
    static_argnames: str | tp.Iterable[str] | None = None,
    donate_argnums: int | tp.Sequence[int] | None = None,
    donate_argnames: str | tp.Iterable[str] | None = None,
    keep_unused: bool = False,
    device: tp.Optional[jax.Device] = None,
    backend: tp.Optional[str] = None,
    inline: bool = False,
    partial_args: tuple[PartialState, ...] = (),
    graph: bool = True,
  ):
    functools.update_wrapper(self, fun)
    self.fun: tp.Callable[P, R] = fun
    self.out_shardings = out_shardings
    self.partial_args = partial_args
    self.graph = graph
    # Reinsert None entries for static arguments so the sharding prefix
    # lines up with the full (pre-jit) argument list used by to_tree2.
    if in_shardings is not None and isinstance(in_shardings, (tuple, list)) and (
      static_argnums or static_argnames
    ):
      resolved = _resolve_argnums(fun, static_argnums, static_argnames)
      expanded = list(in_shardings)
      for i in sorted(resolved):
        expanded.insert(i, None)
      self.in_shardings = tuple(expanded)
    else:
      self.in_shardings = in_shardings
    # Outputs of SimpleJitFn are `(out, (args_updates, kwargs_updates))`;
    # updates mirror the inputs, so reuse the (expanded) input shardings.
    jit_out_shardings: tp.Any
    if in_shardings is not None or out_shardings is not None:
      if isinstance(in_shardings, (tuple, list)) and (
        static_argnums or static_argnames
      ):
        resolved = _resolve_argnums(fun, static_argnums, static_argnames)
        expanded = list(in_shardings)
        for i in sorted(resolved):
          expanded.insert(i, None)
        out_in_shardings = tuple(expanded)
      else:
        out_in_shardings = in_shardings
      jit_out_shardings = (out_shardings, (out_in_shardings, None))
    else:
      jit_out_shardings = None
    donate_argnums_set = frozenset(
      (donate_argnums,)
      if isinstance(donate_argnums, int)
      else donate_argnums or ()
    )
    donate_argnames_set = frozenset(
      (donate_argnames,)
      if isinstance(donate_argnames, str)
      else donate_argnames or ()
    )
    self.jitted_fn = jax.jit(
      SimpleJitFn(
        fun, out_shardings, donate_argnums_set, donate_argnames_set, graph),
      in_shardings=in_shardings,
      out_shardings=jit_out_shardings,
      static_argnums=static_argnums,
      static_argnames=static_argnames,
      donate_argnums=donate_argnums,
      donate_argnames=donate_argnames,
      keep_unused=keep_unused,
      device=device,
      backend=backend,
      inline=inline,
    )

  def _maybe_to_tree(self, args, kwargs):
    # In graph-mode, convert graph nodes to pytrees before crossing the
    # jit boundary; in tree-mode arguments already are pytrees.
    if self.graph:
      args, kwargs = extract.to_tree2(
        (args, kwargs),
        prefix=(self.in_shardings, None)
        if self.in_shardings is not None
        else None,
        check_aliasing=self.in_shardings is not None,
      )
    return args, kwargs

  def _maybe_from_tree(self, out):
    if self.graph:
      out = extract.from_tree2(out)
    return out

  def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
    args, kwargs = self._maybe_to_tree(args, kwargs)
    extract.check_no_aliases('jit', args=args, kwargs=kwargs)
    out, updates = self.jitted_fn(*self.partial_args, *args, **kwargs)
    # Propagate Variable mutations back to the caller's objects (including
    # the pre-flattened partial arguments).
    extract.apply_variable_updates(
      ((*self.partial_args, *args), kwargs), updates)
    return self._maybe_from_tree(out)

  # Descriptor protocol so the wrapper can be used as a method.
  def __get__(self, obj, objtype=None):
    if obj is None:
      return self
    return functools.partial(self, obj)

  def eval_shape(self, *args, **kwargs):
    """See ``jax.eval_shape``. No updates are applied (nothing executes)."""
    args, kwargs = self._maybe_to_tree(args, kwargs)
    out, updates = self.jitted_fn.eval_shape(
      *self.partial_args, *args, **kwargs)
    return self._maybe_from_tree(out)

  def trace(self, *args, **kwargs):
    """Trace this function for the given arguments. See ``jax.jit(...).trace``."""
    args, kwargs = self._maybe_to_tree(args, kwargs)
    traced = self.jitted_fn.trace(*self.partial_args, *args, **kwargs)
    return SimpleTraced(traced, self)

  def lower(self, *args, **kwargs):
    """Lower this function for the given arguments. See ``jax.jit(...).lower``."""
    args, kwargs = self._maybe_to_tree(args, kwargs)
    lowered = self.jitted_fn.lower(*self.partial_args, *args, **kwargs)
    return SimpleLowered(lowered, self)


def jit_partial(
  fun: tp.Callable[..., R],
  *partial_args: tp.Any,
  in_shardings: tp.Any = None,
  out_shardings: tp.Any = None,
  donate_argnums: int | tp.Sequence[int] | None = None,
  donate_argnames: str | tp.Iterable[str] | None = None,
  keep_unused: bool = False,
  device: tp.Optional[jax.Device] = None,
  backend: tp.Optional[str] = None,
  inline: bool = False,
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> SimpleJitWrapped[..., R]:
  """JIT-compile ``fun`` with pre-flattened partial arguments.

  Similar to ``nnx.cached_partial`` but designed for tree-mode
  (``graph=False``). Each ``partial_arg`` is flattened into a
  ``PartialState`` whose pytree structure is fixed at construction time.
  Variable values inside partial arguments can still change between calls
  without triggering recompilation, and any mutations to Variables are
  propagated back to the originals after each call.

  Example usage::

    >>> from flax import nnx
    >>> import jax.numpy as jnp
    >>> import optax
    ...
    >>> x, y = jnp.ones((4, 2)), jnp.ones((4, 3))
    >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
    >>> optimizer = nnx.Optimizer(model, optax.adamw(1e-3), wrt=nnx.Param)
    ...
    >>> def train_step(model, optimizer, x, y):
    ...   def loss_fn(model):
    ...     return jnp.mean((model(x) - y) ** 2)
    ...   loss, grads = nnx.value_and_grad(loss_fn)(model)
    ...   optimizer.update(model, grads)
    ...   return loss
    ...
    >>> train_step_fn = nnx.jit_partial(train_step, model, optimizer, graph=False)
    ...
    >>> loss = train_step_fn(x, y)

  Args:
    fun: The function to JIT-compile.
    *partial_args: Arguments to be pre-flattened and bound. These must appear
      as the first positional arguments of ``fun``.
    in_shardings: Sharding specification for inputs. When a tuple/list, the
      first ``len(partial_args)`` entries correspond to partial arguments and
      are broadcast against their original pytree structure. A non-tuple
      value (e.g. a single ``PartitionSpec``) is passed through directly to
      ``jax.jit`` and broadcast across all arguments uniformly.
    out_shardings: Like ``in_shardings``, but for function outputs.
    donate_argnums: Positional argument indices whose buffers may be donated
      to the computation.
    donate_argnames: Named arguments whose buffers may be donated.
    keep_unused: If ``True``, unused arguments are not pruned.
    device: Optional device to run on.
    backend: Optional backend to use.
    inline: If ``True``, inline the function.
    graph: If ``None``, uses the ``nnx_graph_mode`` config value.
    graph_updates: If ``True``, propagates updates on graph structure that
      happen inside the transform to the input graphs, has no effect when
      ``graph=False``. When ``False``, using ``StateSharding`` is not
      supported.

  Returns:
    A callable expecting the remaining (runtime) arguments.
  """
  if graph is None:
    graph = graphlib.set_graph_mode.current_value()
  if graph_updates is None:
    graph_updates = graphlib.set_graph_updates.current_value()
  if graph_updates and graph:
    raise ValueError(
      '`graph_updates` not supported by `jit_partial`'
    )
  if any(isinstance(x, StateSharding) for x in jax.tree.leaves(in_shardings)):
    raise ValueError(
      '`in_shardings` cannot contain `StateSharding` objects '
      'in `jit_partial`'
    )
  if any(isinstance(x, StateSharding) for x in jax.tree.leaves(out_shardings)):
    raise ValueError(
      '`out_shardings` cannot contain `StateSharding` objects '
      'in `jit_partial`'
    )
  is_variable = lambda x: isinstance(x, variablelib.Variable)
  ref_index = graphlib.RefMap() if graph else None
  flat_partial_args = tuple(
    _flatten_to_partial_state(arg, ref_index=ref_index)
    for arg in partial_args
  )
  # In tree-mode, the leading in_shardings entries describe the original
  # (unflattened) partial args; broadcast each one over its pytree and wrap
  # in a PartialState so it lines up with the flattened argument.
  jit_in_shardings: tp.Any = None
  if (
    in_shardings is not None
    and isinstance(in_shardings, (tuple, list))
    and not graph
  ):
    num_partial = len(partial_args)
    partial_shardings = in_shardings[:num_partial]
    runtime_shardings = in_shardings[num_partial:]
    flat_partial_shardings = []
    for flat_arg, orig_arg, sharding in zip(
      flat_partial_args, partial_args, partial_shardings):
      broadcasted = extract.broadcast_prefix(
        sharding,
        orig_arg,
        prefix_is_leaf=lambda x: x is None
        or isinstance(x, variablelib.Variable),
        tree_is_leaf=is_variable,
      )
      flat_partial_shardings.append(
        PartialState(treedef=flat_arg.treedef, leaves=broadcasted)
      )
    jit_in_shardings = (*flat_partial_shardings, *runtime_shardings)
  else:
    jit_in_shardings = in_shardings
  return SimpleJitWrapped(
    fun,
    in_shardings=jit_in_shardings,
    out_shardings=out_shardings,
    donate_argnums=donate_argnums,
    donate_argnames=donate_argnames,
    keep_unused=keep_unused,
    device=device,
    backend=backend,
    inline=inline,
    partial_args=flat_partial_args,
    graph=graph,
  )


class JitWrapped(tp.Generic[P, R]):
  """A function ready to be traced, lowered, and compiled.

  This protocol reflects the output of functions such as ``jax.jit``.
  Calling it results in JIT (just-in-time) lowering, compilation, and
  execution. It can also be explicitly lowered prior to compilation, and
  the result compiled prior to execution.
  """

  def __init__(
    self,
    fun: tp.Callable[P, R],
    in_shardings: tp.Any,
    out_shardings: tp.Any,
    static_argnums: int | tp.Sequence[int] | None = None,
    static_argnames: str | tp.Iterable[str] | None = None,
    donate_argnums: int | tp.Sequence[int] | None = None,
    donate_argnames: str | tp.Iterable[str] | None = None,
    keep_unused: bool = False,
    device: tp.Optional[jax.Device] = None,
    backend: tp.Optional[str] = None,
    inline: bool = False,
  ):
    functools.update_wrapper(self, fun)
    self.fun: tp.Callable[P, R] = fun
    kwarg_shardings = None
    # Wrap StateSharding prefixes as NodeStates so they survive the
    # jax.jit sharding pytree.
    self.jax_in_shardings = jax.tree.map(
      lambda x: extract.NodeStates.from_prefixes(x.shardings, metadata=x)
      if isinstance(x, StateSharding)
      else x,
      in_shardings,
    )
    self.jax_out_shardings = jax.tree.map(
      lambda x: extract.NodeStates.from_prefixes(x.shardings, metadata=x)
      if isinstance(x, StateSharding)
      else x,
      out_shardings,
    )
    if isinstance(in_shardings, (tuple, list)) and (
      static_argnums or static_argnames
    ):
      # We should reintroduce None values into in_shardings corresponding
      # to static arguments
      static_argnums = _resolve_argnums(fun, static_argnums, static_argnames)
      in_shardings = list(in_shardings)
      for static_arg_index in sorted(static_argnums):
        in_shardings.insert(static_arg_index, None)
      in_shardings = tuple(in_shardings)
    jax_out_in_shardings = jax.tree.map(
      lambda x: extract.NodeStates.from_prefixes(x.shardings, metadata=x)
      if isinstance(x, StateSharding)
      else x,
      in_shardings,
    )
    # JitFn returns (pure_args_out, pure_kwargs_out, pure_out); out_shardings
    # mirrors that triple.
    self.jitted_fn = jax.jit(
      JitFn(fun, in_shardings, out_shardings, kwarg_shardings, self),
      in_shardings=self.jax_in_shardings,
      out_shardings=(
        jax_out_in_shardings,
        kwarg_shardings,
        self.jax_out_shardings,
      ),
      static_argnums=static_argnums,
      static_argnames=static_argnames,
      donate_argnums=donate_argnums,
      donate_argnames=donate_argnames,
      keep_unused=keep_unused,
      device=device,
      backend=backend,
      inline=inline,
    )
    self.in_shardings = in_shardings
    self.out_shardings = out_shardings
    self.kwarg_shardings = kwarg_shardings
    self.static_argnums = static_argnums

  # implement descriptor protocol so that we can use this as a method
  def __get__(self, obj, objtype=None):
    if obj is None:
      return self
    return functools.partial(self, obj)

  def _get_pure_args_kwargs(self, args, kwargs):
    pure_args, pure_kwargs = extract.to_tree(
      (args, kwargs),
      prefix=(self.in_shardings, self.kwarg_shardings)
      if self.in_shardings is not None or self.kwarg_shardings is not None
      else None,
      split_fn=_jit_split_fn,
      check_aliasing=self.in_shardings is not None
      or self.kwarg_shardings is not None,
      ctxtag=self,
    )
    return pure_args, pure_kwargs

  def _get_non_pure_out(self, pure_args_out, pure_kwargs_out, pure_out, /):
    _args_out, _kwargs_out, out = extract.from_tree(
      (pure_args_out, pure_kwargs_out, pure_out),
      merge_fn=_jit_merge_fn,
      is_inner=False,
      ctxtag=self,
    )
    return out

  def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
    # run dynamic_cache_context before update_context
    with graphlib.update_context(self):
      pure_args, pure_kwargs = self._get_pure_args_kwargs(args, kwargs)
      pure_args_out, pure_kwargs_out, pure_out = self.jitted_fn(
        *pure_args, **pure_kwargs
      )
      out = self._get_non_pure_out(pure_args_out, pure_kwargs_out, pure_out)
      return out

  def eval_shape(self, *args, **kwargs):
    """See ``jax.eval_shape``."""
    args, kwargs = graphlib.clone((args, kwargs))
    with graphlib.update_context(self):
      pure_args, pure_kwargs = self._get_pure_args_kwargs(args, kwargs)
      pure_args_out, pure_kwargs_out, pure_out = self.jitted_fn.eval_shape(
        *pure_args, **pure_kwargs
      )
      out = self._get_non_pure_out(pure_args_out, pure_kwargs_out, pure_out)
      return out

  def trace(self, *args, **kwargs) -> Traced:
    """Trace this function explicitly for the given arguments.

    A traced function is staged out of Python and translated to a jaxpr. It
    is ready for lowering but not yet lowered.

    Returns:
      A ``Traced`` instance representing the tracing.
    """
    with graphlib.update_context(self):
      pure_args, pure_kwargs = self._get_pure_args_kwargs(args, kwargs)
      traced = self.jitted_fn.trace(*pure_args, **pure_kwargs)
      return Traced(traced, self)

  def lower(self, *args, **kwargs) -> Lowered:
    """Lower this function explicitly for the given arguments.

    This is a shortcut for ``self.trace(*args, **kwargs).lower()``. A lowered
    function is staged out of Python and translated to a compiler's input
    language, possibly in a backend-dependent manner. It is ready for
    compilation but not yet compiled.

    Returns:
      A ``Lowered`` instance representing the lowering.
    """
    with graphlib.update_context(self):
      pure_args, pure_kwargs = self._get_pure_args_kwargs(args, kwargs)
      lowered = self.jitted_fn.lower(*pure_args, **pure_kwargs)
      return Lowered(lowered, self)


class Stage:
  """Common base for the Traced/Lowered/Compiled stage wrappers."""

  args_info: tp.Any  # PyTree of ArgInfo

  @property
  def _inner_obj(self) -> tp.Any:
    raise NotImplementedError

  @property
  def in_tree(self) -> jax.tree_util.PyTreeDef:
    return self._inner_obj.in_tree

  @property
  def in_avals(self):
    return self._inner_obj.in_avals

  @property
  def donate_argnums(self):
    return self._inner_obj.donate_argnums


@dataclasses.dataclass(frozen=True, slots=True)
class Compiled(Stage):
  """Compiled representation of a function specialized to types/values.

  A compiled computation is associated with an executable and the remaining
  information needed to execute it. It also provides a common API for
  querying properties of compiled computations across JAX's various
  compilation paths and backends.
  """

  compiled: jax.stages.Compiled
  jit_wrapped: JitWrapped

  @property
  def _inner_obj(self):
    return self.compiled

  @property
  def args_info(self) -> tp.Any:  # PyTree of ArgInfo
    # Fixed: was `raise self.compiled.args_info`, which made accessing the
    # property fail instead of returning the value (cf. Lowered.args_info).
    return self.compiled.args_info

  @staticmethod
  def call(*args, **kwargs):
    raise NotImplementedError

  def __call__(self, *args, **kwargs):
    with graphlib.update_context(self.jit_wrapped):
      pure_args, pure_kwargs = self.jit_wrapped._get_pure_args_kwargs(
        args, kwargs
      )
      pure_args_out, pure_kwargs_out, pure_out = self.compiled(
        *pure_args, **pure_kwargs
      )
      out = self.jit_wrapped._get_non_pure_out(
        pure_args_out, pure_kwargs_out, pure_out
      )
      return out

  @property
  def out_tree(self) -> jax.tree_util.PyTreeDef:
    return self.compiled.out_tree

  def as_text(self) -> str | None:
    """A human-readable text representation of this executable.

    Intended for visualization and debugging purposes. This is not a valid
    nor reliable serialization.

    Returns ``None`` if unavailable, e.g. based on backend, compiler, or
    runtime.
    """
    return self.compiled.as_text()

  def cost_analysis(self) -> tp.Any | None:
    """A summary of execution cost estimates.

    Intended for visualization and debugging purposes. The object output by
    this is some simple data structure that can easily be printed or
    serialized (e.g. nested dicts, lists, and tuples with numeric leaves).
    However, its structure can be arbitrary: it may be inconsistent across
    versions of JAX and jaxlib, or even across invocations.

    Returns ``None`` if unavailable, e.g. based on backend, compiler, or
    runtime.
    """
    return self.compiled.cost_analysis()

  def memory_analysis(self) -> tp.Any | None:
    """A summary of estimated memory requirements.

    Intended for visualization and debugging purposes. The object output by
    this is some simple data structure that can easily be printed or
    serialized (e.g. nested dicts, lists, and tuples with numeric leaves).
    However, its structure can be arbitrary: it may be inconsistent across
    versions of JAX and jaxlib, or even across invocations.

    Returns ``None`` if unavailable, e.g. based on backend, compiler, or
    runtime.
    """
    return self.compiled.memory_analysis()

  def runtime_executable(self) -> tp.Any | None:
    """An arbitrary object representation of this executable.

    Intended for debugging purposes. This is not valid nor reliable
    serialization. The output has no guarantee of consistency across
    invocations.

    Returns ``None`` if unavailable, e.g. based on backend, compiler, or
    runtime.
    """
    return self.compiled.runtime_executable()

  @property
  def input_shardings(self):  # PyTree[sharding.Sharding]
    return self.compiled.input_shardings

  @property
  def output_shardings(self):  # PyTree[sharding.Sharding]
    return self.compiled.output_shardings

  @property
  def input_layouts(self):
    return self.compiled.input_formats


@dataclasses.dataclass(frozen=True, slots=True)
class Lowered(Stage):
  """Lowering of a function specialized to argument types and values.

  A lowering is a computation ready for compilation. This class carries a
  lowering together with the remaining information needed to later compile
  and execute it. It also provides a common API for querying properties of
  lowered computations across JAX's various lowering paths
  (:func:`~jax.jit`, :func:`~jax.pmap`, etc.).
  """

  lowered: jax.stages.Lowered
  jit_wrapped: JitWrapped

  @property
  def _inner_obj(self):
    return self.lowered

  @property
  def args_info(self) -> tp.Any:  # PyTree of ArgInfo
    return self.lowered.args_info

  @property
  def out_tree(self):
    return self.lowered.out_tree

  @classmethod
  def from_flat_info(
    cls,
    lowering: tp.Any,  # type: ignore[name-defined]
    in_tree: jax.tree_util.PyTreeDef,
    in_avals,
    donate_argnums: tuple[int, ...],
    out_tree: jax.tree_util.PyTreeDef,
    no_kwargs: bool = False,
  ):
    raise NotImplementedError

  def compile(
    self, compiler_options: jax.stages.CompilerOptions | None = None
  ) -> Compiled:
    """Compile, returning a corresponding ``Compiled`` instance."""
    compiled = self.lowered.compile(compiler_options)
    return Compiled(compiled, self.jit_wrapped)

  def as_text(
    self, dialect: str | None = None, *, debug_info: bool = False
  ) -> str:
    """A human-readable text representation of this lowering.

    Intended for visualization and debugging purposes. This need not be a
    valid nor reliable serialization. Use `jax.export` if you want reliable
    and portable serialization.

    Args:
      dialect: Optional string specifying a lowering dialect (e.g.
        "stablehlo", or "hlo").
      debug_info: Whether to include debugging information, e.g., source
        location.
    """
    return self.lowered.as_text(dialect=dialect, debug_info=debug_info)

  def compiler_ir(self, dialect: str | None = None) -> tp.Any | None:
    """An arbitrary object representation of this lowering.

    Intended for debugging purposes. This is not a valid nor reliable
    serialization. The output has no guarantee of consistency across
    invocations. Use `jax.export` if you want reliable and portable
    serialization.

    Returns ``None`` if unavailable, e.g. based on backend, compiler, or
    runtime.

    Args:
      dialect: Optional string specifying a lowering dialect (e.g.
        "stablehlo", or "hlo").
    """
    return self.lowered.compiler_ir(dialect=dialect)

  def cost_analysis(self) -> tp.Any | None:
    """A summary of execution cost estimates.

    Intended for visualization and debugging purposes. The object output by
    this is some simple data structure that can easily be printed or
    serialized (e.g. nested dicts, lists, and tuples with numeric leaves).
    However, its structure can be arbitrary: it may be inconsistent across
    versions of JAX and jaxlib, or even across invocations.

    Returns ``None`` if unavailable, e.g. based on backend, compiler, or
    runtime.
    """
    return self.lowered.cost_analysis()


@dataclasses.dataclass(frozen=True, slots=True)
class Traced(Stage):
  """Traced form of a function specialized to argument types and values.

  A traced computation is ready for lowering. This class carries the traced
  representation with the remaining information needed to later lower,
  compile, and execute it.
  """

  traced: jax.stages.Traced
  jit_wrapped: JitWrapped

  @property
  def _inner_obj(self):
    return self.traced

  @property
  def out_info(self):
    return self.traced.out_info

  def lower(
    self, *, lowering_platforms: tuple[str, ...] | None = None
  ) -> Lowered:
    """Lower to compiler input, returning a ``Lowered`` instance."""
    lowered = self.traced.lower(lowering_platforms=lowering_platforms)
    return Lowered(lowered, self.jit_wrapped)


@dataclasses.dataclass(frozen=True, slots=True)
class SimpleCompiled(Stage):
  """``Compiled`` counterpart for ``SimpleJitWrapped`` (tree-mode jit)."""

  compiled: jax.stages.Compiled
  jit_wrapped: SimpleJitWrapped

  @property
  def _inner_obj(self):
    return self.compiled

  @property
  def args_info(self) -> tp.Any:
    # Fixed: was `raise self.compiled.args_info`, which made accessing the
    # property fail instead of returning the value (cf. SimpleLowered).
    return self.compiled.args_info

  @staticmethod
  def call(*args, **kwargs):
    raise NotImplementedError

  def __call__(self, *args, **kwargs):
    args, kwargs = self.jit_wrapped._maybe_to_tree(args, kwargs)
    extract.check_no_aliases('jit', args=args, kwargs=kwargs)
    out, updates = self.compiled(*args, **kwargs)
    extract.apply_variable_updates((args, kwargs), updates)
    return self.jit_wrapped._maybe_from_tree(out)

  @property
  def out_tree(self) -> jax.tree_util.PyTreeDef:
    return self.compiled.out_tree

  def as_text(self) -> str | None:
    return self.compiled.as_text()

  def cost_analysis(self) -> tp.Any | None:
    return self.compiled.cost_analysis()

  def memory_analysis(self) -> tp.Any | None:
    return self.compiled.memory_analysis()

  def runtime_executable(self) -> tp.Any | None:
    return self.compiled.runtime_executable()

  @property
  def input_shardings(self):
    return self.compiled.input_shardings

  @property
  def output_shardings(self):
    return self.compiled.output_shardings

  @property
  def input_layouts(self):
    return self.compiled.input_formats


@dataclasses.dataclass(frozen=True, slots=True)
class SimpleLowered(Stage):
  """``Lowered`` counterpart for ``SimpleJitWrapped`` (tree-mode jit)."""

  lowered: jax.stages.Lowered
  jit_wrapped: SimpleJitWrapped

  @property
  def _inner_obj(self):
    return self.lowered

  @property
  def args_info(self) -> tp.Any:
    return self.lowered.args_info

  @property
  def out_tree(self):
    return self.lowered.out_tree

  def compile(
    self, compiler_options: jax.stages.CompilerOptions | None = None
  ) -> SimpleCompiled:
    compiled = self.lowered.compile(compiler_options)
    return SimpleCompiled(compiled, self.jit_wrapped)

  def as_text(
    self, dialect: str | None = None, *, debug_info: bool = False
  ) -> str:
    return self.lowered.as_text(dialect=dialect, debug_info=debug_info)

  def compiler_ir(self, dialect: str | None = None) -> tp.Any | None:
    return self.lowered.compiler_ir(dialect=dialect)

  def cost_analysis(self) -> tp.Any | None:
    return self.lowered.cost_analysis()


@dataclasses.dataclass(frozen=True, slots=True)
class SimpleTraced(Stage):
  """``Traced`` counterpart for ``SimpleJitWrapped`` (tree-mode jit)."""

  traced: jax.stages.Traced
  jit_wrapped: SimpleJitWrapped

  @property
  def _inner_obj(self):
    return self.traced

  @property
  def out_info(self):
    return self.traced.out_info

  def lower(
    self, *, lowering_platforms: tuple[str, ...] | None = None
  ) -> SimpleLowered:
    lowered = self.traced.lower(lowering_platforms=lowering_platforms)
    return SimpleLowered(lowered, self.jit_wrapped)


# -------------------------------
# shard_map
# -------------------------------
# TODO: create StateSpec and consider enabling a mode that does
# not use filters during split for performance. Overall there might
# be performance limitations for using shard_map at a top-level


@dataclasses.dataclass(eq=False)
class SimpleShardMapFn:
  """Inner callable passed to ``jax.shard_map`` in tree-mode.

  Runs ``f`` and returns the output together with masked Variable updates
  so the wrapper can propagate mutations back to the originals.
  """

  f: tp.Callable[..., tp.Any]
  graph: bool
  out_specs: tp.Any

  def __post_init__(self):
    functools.update_wrapper(self, self.f, updated=())

  @extract.treemap_copy_args
  def __call__(self, *args):
    updates, snapshot = extract.updates_and_snapshot(args)
    if self.graph:
      args = extract.from_tree2(args)
    out = self.f(*args)
    if self.graph:
      out = extract.to_tree2(out, prefix=self.out_specs)
    extract.check_no_aliases('shard_map', args=updates, out=out)
    updates = extract.mask_variable_updates(updates, snapshot)
    return out, updates


@dataclasses.dataclass(eq=False)
class ShardMapFn:
  """Inner callable passed to ``jax.shard_map`` in graph-mode.

  Converts pure pytrees back to graph nodes, runs ``f``, and re-splits the
  results (inputs included, to capture mutations) for transport out of the
  shard_map boundary.
  """

  f: tp.Callable[..., tp.Any]
  in_specs: tp.Any
  out_specs: tp.Any
  kwarg_specs: tp.Any
  ctxtag: tp.Hashable

  def __post_init__(self):
    functools.update_wrapper(self, self.f)

  def __call__(self, *pure_args, **pure_kwargs):
    args, kwargs = extract.from_tree(
      (pure_args, pure_kwargs),
      merge_fn=_jit_merge_fn,
      ctxtag=self.ctxtag,
      is_inner=True,
    )
    out = self.f(*args, **kwargs)
    args_out, kwargs_out = extract.clear_non_graph_nodes((args, kwargs))
    pure_args_out, pure_kwargs_out, pure_out = extract.to_tree(
      (args_out, kwargs_out, out),
      prefix=(self.in_specs, self.kwarg_specs, self.out_specs),
      ctxtag=self.ctxtag,
      split_fn=_jit_split_fn,
    )
    return pure_args_out, pure_kwargs_out, pure_out


@tp.overload
def shard_map(
  f: F,
  *,
  mesh: Mesh | AbstractMesh,
  in_specs: Specs,
  out_specs: Specs,
  axis_names: tp.AbstractSet[AxisName] = frozenset(),
  check_vma: bool = True,
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> F: ...
@tp.overload
def shard_map(
  *,
  mesh: Mesh | AbstractMesh,
  in_specs: Specs,
  out_specs: Specs,
  axis_names: tp.AbstractSet[AxisName] = frozenset(),
  check_vma: bool = True,
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> tp.Callable[[F], F]: ...
def shard_map(
  f: F | type[Missing] = Missing,
  *,
  mesh: Mesh | AbstractMesh,
  in_specs: Specs,
  out_specs: Specs,
  axis_names: tp.AbstractSet[AxisName] = frozenset(),
  check_vma: bool = True,
  graph: bool | None = None,
  graph_updates: bool | None = None,
) -> F | tp.Callable[[F], F]:
  """
  Lifted version of
  `jax.shard_map <https://docs.jax.dev/en/latest/_autosummary/jax.shard_map.html>`_
  that can handle Modules / graph nodes as arguments.

  Simple data parallel example::

    import jax
    import jax.numpy as jnp
    from flax import nnx
    from jax.sharding import PartitionSpec as P

    mesh = jax.sharding.Mesh(jax.local_devices(), ('data',))

    m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
    x = jnp.ones((32, 2))

    @nnx.shard_map(
      mesh=mesh, in_specs=(P(None), P('data')), out_specs=P('data')
    )
    def f(m, x):
      return m(x)

    y = f(m, x)

    jax.debug.visualize_array_sharding(y)

  Notice that here we simply used some ``PartitionSpec`` to define the spec
  for the whole model and data. This works for simple cases but if we need
  to assign different ``PartitionSpec`` to different parts of the model we
  need to use ``StateSharding`` and create some filters that allow us to
  target specific parts of the model.

  Here's an example of how to do tensor parallelism for a simple MLP block
  using ``StateSharding`` and filters::

    mesh = jax.sharding.Mesh(jax.local_devices(), ('model',))

    class MLP(nnx.Module):
      def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(din, dhidden, use_bias=False, rngs=rngs)
        self.linear2 = nnx.Linear(dhidden, dout, use_bias=False, rngs=rngs)

      def __call__(self, x):
        return self.linear2(jax.nn.relu(self.linear1(x)))

    m = MLP(2, 64, 3, rngs=nnx.Rngs(0))
    x = jnp.ones((32, 2))

    def path_ends_with(*path_suffix):  # custom filter
      return lambda path, value: path[-len(path_suffix):] == path_suffix

    model_spec = nnx.StateSharding({
      path_ends_with('linear1', 'kernel'): P(None, 'model'),
      path_ends_with('linear2', 'kernel'): P('model', None),
    })

    @nnx.shard_map(mesh=mesh, in_specs=(model_spec, P(None)), out_specs=P(None))
    def f(m, x):
      y = m(x)
      return jax.lax.psum(y, 'model')

    y = f(m, x)

    jax.debug.visualize_array_sharding(m.linear1.kernel[...])
    jax.debug.visualize_array_sharding(m.linear2.kernel[...])

  Alternatively, a ``State`` object with the exact ``PartitionSpec`` for
  each state can be passed to ``StateSharding``::

    mesh = jax.sharding.Mesh(jax.local_devices(), ('model',))

    class MLP(nnx.Module):
      def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(din, dhidden, use_bias=False, rngs=rngs)
        self.linear2 = nnx.Linear(dhidden, dout, use_bias=False, rngs=rngs)

      def __call__(self, x):
        return self.linear2(jax.nn.relu(self.linear1(x)))

    m = MLP(2, 64, 3, rngs=nnx.Rngs(0))
    x = jnp.ones((32, 2))

    model_spec = nnx.State(
      {
        'linear1': {'kernel': P(None, 'model')},
        'linear2': {'kernel': P('model', None)},
      }
    )

    @nnx.shard_map(
      mesh=mesh,
      in_specs=(nnx.StateSharding(model_spec), P(None)),
      out_specs=P(None),
    )
    def f(m, x):
      y = m(x)
      return jax.lax.psum(y, 'model')

    y = f(m, x)

    jax.debug.visualize_array_sharding(m.linear1.kernel[...])
    jax.debug.visualize_array_sharding(m.linear2.kernel[...])

  Here ``model_spec`` was created manually but you can also automate this
  process by using ``nnx.get_partition_spec`` to automatically create it for
  you (see
  `Scale up on multiple devices <https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html>`_
  ).

  Args:
    f: callable to be mapped. Each application of ``f``, or "instance" of
      ``f``, takes as input a shard of the mapped-over arguments and
      produces a shard of the output.
    mesh: a ``jax.sharding.Mesh`` representing the array of devices over
      which to shard the data and on which to execute instances of ``f``.
      The names of the ``Mesh`` can be used in collective communication
      operations in ``f``. This is typically created by a utility function
      like :func:`jax.experimental.mesh_utils.create_device_mesh`.
    in_specs: a pytree with ``jax.sharding.PartitionSpec`` or
      ``nnx.StateSharding`` (mapping substates to ``PartitionSpec``s)
      instances as leaves, with a tree structure that is a tree prefix of
      the args tuple to be mapped over. Similar to
      ``jax.sharding.NamedSharding``, each ``PartitionSpec`` represents how
      the corresponding argument (or subtree of arguments) should be
      sharded along the named axes of ``mesh``. In each ``PartitionSpec``,
      mentioning a ``mesh`` axis name at a position expresses sharding the
      corresponding argument array axis along that positional axis; not
      mentioning an axis name expresses replication. If an argument, or
      argument subtree, has a corresponding spec of None, that argument is
      not sharded.
    out_specs: a pytree with ``jax.sharding.PartitionSpec`` or
      ``nnx.StateSharding`` (mapping substates to ``PartitionSpec``s)
      instances as leaves, with a tree structure that is a tree prefix of
      the output of ``f``. Each ``PartitionSpec`` represents how the
      corresponding output shards should be concatenated. In each
      ``PartitionSpec``, mentioning a ``mesh`` axis name at a position
      expresses concatenation of that mesh axis's shards along the
      corresponding positional axis. Not mentioning a ``mesh`` axis name
      expresses a promise that the output values are equal along that mesh
      axis, and that rather than concatenating only a single value should
      be produced.
    axis_names: optional set of axis names from ``mesh`` over which the
      function ``f`` is manual. If empty, ``f`` is manual over all mesh
      axes.
    check_vma: optional boolean representing whether to enable additional
      validity checks and automatic differentiation optimizations. The
      validity checks concern whether any mesh axis names not mentioned in
      ``out_specs`` are consistent with how the outputs of ``f`` are
      replicated.
    graph: If ``True`` (default), uses graph-mode which supports the full
      NNX feature set including shared references and reference semantics.
      If ``False``, uses tree-mode which treats Modules as regular JAX
      pytrees, avoiding the overhead of the graph protocol. Tree-mode does
      not support ``StateSharding`` or shared ``Variable`` references.
    graph_updates: If ``True``, propagates updates on graph structure that
      happen inside the transform to the input graphs, has no effect when
      ``graph=False``. When ``False``, using ``StateSharding`` is not
      supported.

  Returns:
    A callable that applies the input function ``f`` across data sharded
    according to the ``mesh`` and ``in_specs``.
  """
  # Resolve None flags from the global nnx config.
  if graph is None:
    graph = graphlib.set_graph_mode.current_value()
  if graph_updates is None:
    graph_updates = graphlib.set_graph_updates.current_value()
  # Decorator form: return a partial awaiting `f`.
  if f is Missing:
    return functools.partial(
      shard_map,
      mesh=mesh,
      in_specs=in_specs,
      out_specs=out_specs,
      axis_names=axis_names,
      check_vma=check_vma,
      graph=graph,
      graph_updates=graph_updates,
    )  # type: ignore[return-value]
  assert not isinstance(f, type)
  f_unbound, _, was_bound = _resolve_bound_callable(f)
  if was_bound:
    _raise_bound_method_error('shard_map')

  # Simple path: tree-mode, or graph-mode without graph updates. Note this
  # branch is taken for `not graph_updates` too, where StateSharding is
  # equally unsupported.
  if not graph or not graph_updates:
    if any(isinstance(x, StateSharding) for x in jax.tree.leaves(in_specs)):
      raise ValueError(
        '`in_specs` cannot contain `StateSharding` objects '
        'when `graph=False`'
      )
    if any(isinstance(x, StateSharding) for x in jax.tree.leaves(out_specs)):
      raise ValueError(
        '`out_specs` cannot contain `StateSharding` objects '
        'when `graph=False`'
      )
    # Outputs are `(out, updates)`; updates mirror the inputs, so reuse
    # in_specs for them.
    shard_map_fn = jax.shard_map(
      SimpleShardMapFn(f_unbound, graph=graph, out_specs=out_specs),
      mesh=mesh,
      in_specs=in_specs,
      out_specs=(out_specs, in_specs),
      axis_names=axis_names,
      check_vma=check_vma,
    )

    @functools.wraps(f)
    def shard_map_wrapper(*args, **kwargs):
      if graph:
        args = extract.to_tree2(
          args,
          prefix=in_specs,
          check_aliasing=in_specs is not None,
        )
      extract.check_no_aliases('shard_map', args=args)
      out, updates = shard_map_fn(*args, **kwargs)
      # Propagate Variable mutations back to the caller's objects.
      extract.apply_variable_updates(args, updates)
      if graph:
        out = extract.from_tree2(out)
      return out

    shard_map_wrapper.inner = shard_map_fn  # type: ignore
    return shard_map_wrapper  # type: ignore

  # Graph-mode with graph updates: use the full split/merge protocol.
  kwarg_specs = PartitionSpec()
  # Wrap StateSharding specs as NodeStates so they survive the jax.shard_map
  # spec pytrees.
  jax_in_specs = jax.tree.map(
    lambda x: extract.NodeStates(
      _graphdef=PartitionSpec(),  # type: ignore[arg-type]
      states=x.shardings,
      metadata=x,
    )
    if isinstance(x, StateSharding)
    else x,
    in_specs,
  )
  jax_out_specs = jax.tree.map(
    lambda x: extract.NodeStates(
      _graphdef=PartitionSpec(),  # type: ignore[arg-type]
      states=x.shardings,
      metadata=x,
    )
    if isinstance(x, StateSharding)
    else x,
    out_specs,
  )

  @functools.wraps(f)  # type: ignore[no-redef]
  def shard_map_wrapper(*args, **kwargs):
    with graphlib.update_context(shard_map_wrapper):
      pure_args, pure_kwargs = extract.to_tree(
        (args, kwargs),
        prefix=(in_specs, kwarg_specs)
        if in_specs is not None or kwarg_specs is not None
        else None,
        split_fn=_jit_split_fn,
        check_aliasing=in_specs is not None or kwarg_specs is not None,
        ctxtag=shard_map_wrapper,
      )
      pure_args_out, pure_kwargs_out, pure_out = shard_map_fn(
        *pure_args, **pure_kwargs
      )
      _args_out, _kwargs_out, out = extract.from_tree(
        (pure_args_out, pure_kwargs_out, pure_out),
        merge_fn=_jit_merge_fn,
        is_inner=False,
        ctxtag=shard_map_wrapper,
      )
    return out

  # ShardMapFn returns (pure_args_out, pure_kwargs_out, pure_out); the
  # out_specs triple mirrors that.
  shard_map_fn = jax.shard_map(
    ShardMapFn(f_unbound, in_specs, out_specs, kwarg_specs, shard_map_wrapper),
    mesh=mesh,
    in_specs=jax_in_specs,
    out_specs=(jax_in_specs, kwarg_specs, jax_out_specs),  # type: ignore
    axis_names=axis_names,
    check_vma=check_vma,
  )
  shard_map_wrapper.inner = shard_map_fn  # type: ignore
  return shard_map_wrapper  # type: ignore
# We can't use private methods from jax._src.api_util # We copy the function: api_util.fun_signature def _fun_signature(fun: tp.Callable) -> inspect.Signature | None: try: return inspect.signature(fun) except (ValueError, TypeError): return None # Adapted copy of private jax function from api_util: fun_signature def _resolve_argnums( fun: tp.Callable, static_argnums: int | tp.Sequence[int] | None, static_argnames: str | tp.Iterable[str] | None, ) -> tuple[int, ...]: def _ensure_index_tuple(x: tp.Any) -> tuple[int, ...]: """Convert x to a tuple of indices.""" try: return (operator.index(x),) except TypeError: return tuple(map(operator.index, x)) def _ensure_str(x: str) -> str: if not isinstance(x, str): raise TypeError(f"argument is not a string: {x}") return x def _ensure_str_tuple(x: str | tp.Iterable[str]) -> tuple[str, ...]: """Convert x to a tuple of strings.""" if isinstance(x, str): return (x,) else: return tuple(map(_ensure_str, x)) signature = _fun_signature(fun) if signature is None: # Some built-in functions don't support signature. # See: https://github.com/python/cpython/issues/73485 # In this case no validation is done static_argnums = () if static_argnums is None else _ensure_index_tuple( static_argnums) else: # Infer argnums and argnames according to docstring # If nums is None and names is not None, then nums are inferred from the # names and vice-versa. 
_POSITIONAL_OR_KEYWORD = inspect.Parameter.POSITIONAL_OR_KEYWORD _POSITIONAL_ARGUMENTS = ( inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD ) def infer_argnums_and_argnames( sig: inspect.Signature, argnums: int | tp.Iterable[int] | None, argnames: str | tp.Iterable[str] | None, ) -> tuple[tuple[int, ...], tuple[str, ...]]: """Infer missing argnums and argnames for a function with inspect.""" if argnums is None and argnames is None: return (), () if argnums is not None and argnames is not None: argnums = _ensure_index_tuple(argnums) argnames = _ensure_str_tuple(argnames) return argnums, argnames parameters = sig.parameters if argnums is None: assert argnames is not None argnames = _ensure_str_tuple(argnames) argnums = tuple( i for i, (k, param) in enumerate(parameters.items()) if param.kind == _POSITIONAL_OR_KEYWORD and k in argnames ) else: argnums = _ensure_index_tuple(argnums) argnames = tuple( k for i, (k, param) in enumerate(parameters.items()) if param.kind == _POSITIONAL_OR_KEYWORD and i in argnums ) return argnums, argnames def _validate_argnums(sig: inspect.Signature, argnums: tuple[int, ...], argnums_name: str) -> None: n_pos_args = 0 for param in sig.parameters.values(): if param.kind in _POSITIONAL_ARGUMENTS: n_pos_args += 1 elif param.kind is inspect.Parameter.VAR_POSITIONAL: # We can have any number of positional arguments return if argnums and (-min(argnums) > n_pos_args or max(argnums) >= n_pos_args): raise ValueError(f"Jitted function has {argnums_name}={argnums}, " f"but only accepts {n_pos_args} positional arguments.") static_argnums, static_argnames = infer_argnums_and_argnames( signature, static_argnums, static_argnames) # Validation _validate_argnums(signature, static_argnums, "static_argnums") return static_argnums