# Source code for flax.nnx.variablelib

# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pytype: skip-file
from __future__ import annotations

import dataclasses
import functools
from functools import partial
import itertools as it
import threading
import typing as tp
from typing import Any
import warnings

from flax import config
from flax import errors
from flax.core import spmd as core_spmd
from flax.nnx import reprlib, tracers, visualization
from flax.typing import BaseConfigContext, MISSING, Missing, SizeBytes
import jax
from jax._src.state.types import AbstractRef
import jax.experimental
from jax.experimental import hijax as hjx
import jax.tree_util as jtu
import treescope  # type: ignore[import-untyped]

A = tp.TypeVar('A')
B = tp.TypeVar('B')
C = tp.TypeVar('C')
F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])
P = tp.TypeVar('P', bound=property)
V = tp.TypeVar('V', bound='Variable[Any]')
GetValueHook = tp.Callable[['Variable[A]', A], A]
SetValueHook = tp.Callable[['Variable[A]', A], A]
CreateValueHook = tp.Callable[['Variable[A]', A], A]
AxisName = str
AxisIndex = int
AddAxisHook = tp.Callable[[V, AxisIndex, AxisName | None], None]
RemoveAxisHook = tp.Callable[[V, AxisIndex, AxisName | None], None]

# JAX array refs were renamed a few times between JAX v0.7.0 and v0.8.0.
# The following ensures we avoid an ImportError or DeprecationWarning.
if hasattr(jax, 'new_ref') and hasattr(jax, 'Ref'):
  # JAX v0.7.2 or newer
  from jax import Ref
elif hasattr(jax, 'array_ref') and hasattr(jax, 'ArrayRef'):
  # JAX v0.7.1
  from jax import ArrayRef as Ref  # type: ignore[import-untyped,no-redef]
else:
  # JAX v0.7.0 or older
  from jax.experimental import MutableArray as Ref  # type: ignore[no-redef]


@dataclasses.dataclass
class VariableContext(threading.local):
  """Per-thread stacks of overrides for Variable defaults.

  Because this subclasses ``threading.local``, the dataclass ``__init__`` runs
  again in every thread that touches the instance. Each thread therefore
  starts with its own empty stacks.
  """

  # Override stack for the ``hijax`` default (read by ``var_defaults``).
  variable_hijax_stack: list[bool] = dataclasses.field(
    default_factory=list
  )
  # Override stack for the ``ref`` default (read by ``var_defaults``).
  variable_ref_stack: list[bool] = dataclasses.field(
    default_factory=list
  )
  # Override stack for eager sharding (read by ``use_eager_sharding``).
  eager_shard_stack: list[bool] = dataclasses.field(
    default_factory=list
  )


# Module-level singleton; its attributes resolve per thread.
VARIABLE_CONTEXT = VariableContext()


class use_eager_sharding(BaseConfigContext):
  """Configures whether newly created Variables use eager sharding by default.

  Calling it sets the default. It can also be used as a context manager to
  change the default only for a block of code.

  Example usage::

    >>> from flax import nnx
    >>> # Use eager sharding by default
    >>> nnx.use_eager_sharding(True)
    <...>
    >>> nnx.using_eager_sharding()
    True

  Scoped usage as a context manager::

    >>> nnx.use_eager_sharding(False)
    <...>
    >>> with nnx.use_eager_sharding(True):
    ...   nnx.using_eager_sharding()
    True
    >>> # the previous value is back once the block exits
    >>> nnx.using_eager_sharding()
    False

  Args:
    value: Whether Variables should use eager sharding by default.

  Returns:
    A context manager that resets the context to the previous value.
  """

  @classmethod
  def get_default(cls):
    # Default supplied to BaseConfigContext: the global flax config flag
    # (presumably used when no override is on the stack).
    return config.flax_always_shard_variable

  @classmethod
  def get_stack(cls):
    # The per-thread override stack kept on VARIABLE_CONTEXT.
    return VARIABLE_CONTEXT.eager_shard_stack


def using_eager_sharding() -> bool:
  """Reports whether Variables currently default to eager sharding.

  Example::

    >>> from flax import nnx
    >>> nnx.use_eager_sharding(True)
    <...>
    >>> nnx.using_eager_sharding()
    True
    >>> nnx.use_eager_sharding(False)
    <...>
    >>> nnx.using_eager_sharding()
    False

  Returns:
    The value currently reported by ``use_eager_sharding``: ``True`` when
    Variables use eager sharding by default.
  """
  current = use_eager_sharding.current_value()
  return current


@dataclasses.dataclass(frozen=True)
class VarDefaults(tp.Mapping[str, tp.Any]):
  """Immutable snapshot of the default ``hijax`` and ``ref`` Variable flags.

  Also behaves as a read-only mapping from field name to value, so it can be
  passed to ``dict(...)`` or unpacked with ``**``.
  """

  hijax: bool
  ref: bool

  def __getitem__(self, key: str) -> tp.Any:
    # Only dataclass fields are keys. Raising KeyError (instead of the
    # AttributeError that a bare getattr would raise) keeps the Mapping mixin
    # methods (``in``, ``get``, ``==``) working for unknown keys. It also
    # stops non-field attributes such as ``__class__`` from being exposed.
    if key not in self._field_names():
      raise KeyError(key)
    return getattr(self, key)

  def __iter__(self) -> tp.Iterator[str]:
    return iter(self._field_names())

  def __len__(self) -> int:
    return len(dataclasses.fields(self))

  def _field_names(self) -> tuple[str, ...]:
    # Field names in declaration order (no deep copy, unlike ``asdict``).
    return tuple(field.name for field in dataclasses.fields(self))


@tp.overload
def var_defaults() -> VarDefaults: ...


@tp.overload
def var_defaults(
  *, hijax: bool | None = None, ref: bool | None = None
) -> VarDefaultsContext: ...


def var_defaults(
  *, hijax: bool | None = None, ref: bool | None = None
) -> VarDefaultsContext | VarDefaults:
  """Reads or overrides the default ``hijax`` / ``ref`` flags for Variables.

  With no arguments, returns a ``VarDefaults`` snapshot. The ``hijax`` value
  is the top of the per-thread hijax stack, or ``config.flax_hijax_variable``
  if that stack is empty. The ``ref`` value is the top of the ref stack, or
  ``False`` if that stack is empty.

  With at least one argument, every non-``None`` argument immediately
  replaces the top of its stack (or is pushed if the stack is empty). The
  returned ``VarDefaultsContext`` records what was replaced, so the change can
  be scoped with ``with`` or used as a decorator.

  Args:
    hijax: New default for ``hijax``; ``None`` leaves it unchanged.
    ref: New default for ``ref``; ``None`` leaves it unchanged.

  Returns:
    A ``VarDefaults`` snapshot when called without arguments, otherwise a
    ``VarDefaultsContext``.
  """
  hijax_stack = VARIABLE_CONTEXT.variable_hijax_stack
  ref_stack = VARIABLE_CONTEXT.variable_ref_stack

  if hijax is None and ref is None:
    current_hijax = (
      hijax_stack[-1] if hijax_stack else config.flax_hijax_variable
    )
    current_ref = ref_stack[-1] if ref_stack else False
    return VarDefaults(hijax=current_hijax, ref=current_ref)

  def _override_top(stack: list[bool], value: bool | None) -> bool | None:
    # Install `value` as the top entry and return the entry it replaced.
    # Returns None if nothing was overridden or the stack was empty.
    if value is None:
      return None
    if not stack:
      stack.append(value)
      return None
    replaced = stack[-1]
    stack[-1] = value
    return replaced

  return VarDefaultsContext(
    hijax_prev=_override_top(hijax_stack, hijax),
    hijax_new=hijax,
    ref_prev=_override_top(ref_stack, ref),
    ref_new=ref,
  )


class VarDefaultsContext:
  """Scopes or keeps the overrides applied by ``var_defaults(...)``.

  By the time this object is built, ``var_defaults`` has already written the
  new value(s) onto the per-thread stacks in ``VARIABLE_CONTEXT``. It either
  replaced the top entry or appended one if the stack was empty. The object
  then supports three usage patterns:

  * plain call: nothing else touches the stacks, so the eager change stays
    in effect;
  * ``with`` statement: ``__enter__`` re-inserts the replaced value beneath
    the new top, and ``__exit__`` pops the new top, restoring the old value;
  * decorator: ``__call__`` undoes the eager change, then returns a wrapper
    that pushes the new value(s) only while the wrapped function runs.

  A ``*_new`` of ``None`` means that flag was not overridden, and its stack
  is never touched. A ``*_prev`` of ``None`` means the stack was empty before
  the override.
  """

  def __init__(
    self,
    *,
    hijax_prev: bool | None,
    hijax_new: bool | None,
    ref_prev: bool | None,
    ref_new: bool | None,
  ) -> None:
    # `*_new`: the override passed to var_defaults (None = not overridden).
    # `*_prev`: the value that override replaced on top of its stack
    # (None = the stack was empty and the override was appended instead).
    self.hijax_prev = hijax_prev
    self.hijax_new = hijax_new
    self.ref_prev = ref_prev
    self.ref_new = ref_new

  def __enter__(self) -> None:
    # var_defaults() already overwrote the top entry with the new value.
    # Re-insert the value it replaced just below that top, so that the pop()
    # in __exit__ brings it back. If there was no previous value, the new
    # value was appended, and popping it alone restores the original stack.
    # NOTE: nothing is returned, so `with var_defaults(...) as x` binds None.
    if self.hijax_new is not None and self.hijax_prev is not None:
      VARIABLE_CONTEXT.variable_hijax_stack.insert(-1, self.hijax_prev)
    if self.ref_new is not None and self.ref_prev is not None:
      VARIABLE_CONTEXT.variable_ref_stack.insert(-1, self.ref_prev)

  def __exit__(self, exc_type, exc_value, traceback) -> None:
    # Drop the entry installed for this context. Returning None means any
    # exception raised inside the `with` body propagates.
    if self.hijax_new is not None:
      VARIABLE_CONTEXT.variable_hijax_stack.pop()
    if self.ref_new is not None:
      VARIABLE_CONTEXT.variable_ref_stack.pop()

  def __call__(self, f: F) -> F:
    """Decorator form: apply the overrides only while ``f`` is running."""
    # undo stack change for decorator usage
    # (remove the value var_defaults() installed and put back what it
    # replaced, so decorating a function has no lasting global effect)
    if self.hijax_new is not None:
      VARIABLE_CONTEXT.variable_hijax_stack.pop()
      if self.hijax_prev is not None:
        VARIABLE_CONTEXT.variable_hijax_stack.append(self.hijax_prev)

    if self.ref_new is not None:
      VARIABLE_CONTEXT.variable_ref_stack.pop()
      if self.ref_prev is not None:
        VARIABLE_CONTEXT.variable_ref_stack.append(self.ref_prev)

    @functools.wraps(f)
    def var_defaults_wrapper(*args, **kwargs):
      # Push the override(s) for this call and always pop them afterwards,
      # even if `f` raises; the stack is read when the call happens, so the
      # push lands on the calling thread's stack.
      if self.hijax_new is not None:
        VARIABLE_CONTEXT.variable_hijax_stack.append(self.hijax_new)
      if self.ref_new is not None:
        VARIABLE_CONTEXT.variable_ref_stack.append(self.ref_new)
      try:
        return f(*args, **kwargs)
      finally:
        if self.hijax_new is not None:
          VARIABLE_CONTEXT.variable_hijax_stack.pop()
        if self.ref_new is not None:
          VARIABLE_CONTEXT.variable_ref_stack.pop()

    return var_defaults_wrapper  # type: ignore[return-value]


def is_array_ref(x) -> tp.TypeGuard[Ref]:
  """Returns whether ``x`` is a JAX array reference.

  Two conditions must hold. ``x`` must be an instance of ``jax.Array``,
  ``AbstractRef`` or ``Ref``. Its JAX type (``jax.typeof(x)``) must be an
  ``AbstractRef`` or ``Ref``. The type is only queried when the first check
  passes.
  """
  array_or_ref = jax.Array | AbstractRef | Ref
  if not isinstance(x, array_or_ref):
    return False
  ref_types = AbstractRef | Ref
  return isinstance(jax.typeof(x), ref_types)


@dataclasses.dataclass
class VariableMetadata(tp.Generic[A]):
  """Bundles a raw value with the hooks and metadata attached to it.

  Attributes:
    raw_value: The wrapped value.
    set_value_hooks: ``SetValueHook`` callables (empty by default).
    get_value_hooks: ``GetValueHook`` callables (empty by default).
    create_value_hooks: ``CreateValueHook`` callables (empty by default).
    add_axis_hooks: ``AddAxisHook`` callables (empty by default).
    remove_axis_hooks: ``RemoveAxisHook`` callables (empty by default).
    metadata: Extra key/value metadata; each instance gets a fresh empty
      ``dict`` by default.
  """

  raw_value: A
  set_value_hooks: tuple[SetValueHook[A], ...] = ()
  get_value_hooks: tuple[GetValueHook[A], ...] = ()
  create_value_hooks: tuple[CreateValueHook[A], ...] = ()
  add_axis_hooks: tuple[AddAxisHook[Variable[A]], ...] = ()
  remove_axis_hooks: tuple[RemoveAxisHook[Variable[A]], ...] = ()
  metadata: tp.Mapping[str, tp.Any] = dataclasses.field(default_factory=dict)
PyTreeDef = tp.Any
Leaf = tp.Any

# ---------------------------------
# hijax
# ---------------------------------


@dataclasses.dataclass(frozen=True)
class VariableQDD:
  """Type-state (QDD) tracked for a mutable hijax Variable.

  Holds the avals of the Variable's flattened leaves, the treedef needed to
  unflatten them, and the concrete Variable subclass.
  """

  leaf_avals: tuple[hjx.AbstractValue, ...]
  treedef: PyTreeDef
  var_type: type[Variable[Any]]

  def to_tangent_qdd(self):
    # Same structure, with every leaf aval mapped to its tangent aval.
    leaf_avals = tuple(a.to_tangent_aval() for a in self.leaf_avals)
    return VariableQDD(leaf_avals, self.treedef, self.var_type)

  def normalize(self):
    leaf_types = tuple(a.normalize() for a in self.leaf_avals)
    return VariableQDD(leaf_types, self.treedef, self.var_type)


class VariableEffect(jax.core.Effect): ...


variable_effect = VariableEffect()
hjx.control_flow_allowed_effects.add_type(VariableEffect)


def _bind_new_variable(
  *leaves, treedef, var_type, has_qdd, ref
) -> HijaxVariable:
  """Binds new_variable_p after instantiating any Zero tangents."""
  leaves = tuple(hjx.instantiate_zeros(leaf) for leaf in leaves)
  return new_variable_p.bind(
    *leaves,
    treedef=treedef,
    var_type=var_type,
    has_qdd=has_qdd,
    ref=ref,
  )


def _new_hijax_from_variable(variable: Variable) -> HijaxVariable:
  """Wraps a plain Variable into a HijaxVariable via new_variable_p.

  The result carries QDD (is mutable) exactly when ``variable.ref`` is False.
  """
  has_qdd = not variable.ref
  leaves, treedef = jax.tree.flatten(variable)
  var_type = type(variable)
  hijax_var = _bind_new_variable(
    *leaves,
    treedef=treedef,
    var_type=var_type,
    has_qdd=has_qdd,
    ref=variable.ref,
  )
  return hijax_var


class NewVariable(hjx.HiPrimitive):
  """Primitive that builds a HijaxVariable from flattened Variable leaves."""

  def is_high(self, *leaves, treedef, var_type, has_qdd, ref) -> bool:
    return True  # type: ignore

  def impl(self, *leaves, treedef, var_type, has_qdd, ref):
    return HijaxVariable._new(leaves, treedef, var_type, has_qdd, ref=ref)

  def abstract_eval(self, *leaves, treedef, var_type, has_qdd, ref):
    aval = AbstractVariable(var_type, treedef, leaves, has_qdd, ref=ref)
    if has_qdd:
      # Mutable variables carry their leaf avals as QDD plus an effect.
      qdd = VariableQDD(tuple(leaves), treedef, var_type)
      aval_qdd = hjx.AvalQDD(aval, qdd)  # type: ignore
      return aval_qdd, {variable_effect}
    else:
      return aval, set()

  def to_lojax(self, *leaves, treedef, var_type, has_qdd, ref):
    return HijaxVariable._new(leaves, treedef, var_type, has_qdd, ref=ref)

  def jvp(_, primals, tangents, *, treedef, var_type, has_qdd, ref):
    if has_qdd:
      raise NotImplementedError(
        "jvp not implemented for 'new_variable' with QDD"
      )
    primal_hijax_var = _bind_new_variable(
      *primals, treedef=treedef, var_type=var_type, has_qdd=has_qdd, ref=ref
    )
    tangent_hijax_var = _bind_new_variable(
      *tangents, treedef=treedef, var_type=var_type, has_qdd=has_qdd, ref=ref
    )
    return primal_hijax_var, tangent_hijax_var

  def transpose(
    _, out_var: HijaxVariable, *input_leaves, treedef, var_type, has_qdd, ref
  ):
    if has_qdd:
      raise NotImplementedError(
        "transpose not implemented for 'new_variable' with QDD"
      )
    avals = tuple(
      map(
        lambda x: x.aval if hjx.is_undefined_primal(x) else jax.typeof(x),
        input_leaves,
      )
    )
    # The cotangent of the inputs is read back out of the output variable.
    leaves_dot = get_variable_p.bind(
      out_var,
      treedef=treedef,
      avals=avals,
      var_type=var_type,
      has_qdd=has_qdd,
    )
    return leaves_dot


new_variable_p = NewVariable('new_variable')


def _set_hijax_state(hijax_var, variable: Variable):
  """Writes the leaves/treedef of ``variable`` into ``hijax_var``."""
  leaves, treedef = jax.tree.flatten(variable)
  set_variable_p.bind(
    hijax_var, *leaves, treedef=treedef, var_type=type(variable)
  )


class SetVariable(hjx.HiPrimitive):
  """Primitive that replaces the leaves of a mutable (has_qdd) Variable."""

  multiple_results = True

  def is_high(_, *leaf_avals, treedef, var_type) -> bool:
    return True  # type: ignore

  # TODO: upstream this to Box
  def impl(_, hijax_var: HijaxVariable, *leaves, treedef, var_type):
    if not hijax_var.has_qdd:
      raise errors.ImmutableVariableError(
        "Trying to update Variable with 'has_qdd=False'."
      )
    assert var_type is hijax_var._var_type
    object.__setattr__(hijax_var, '_leaves', leaves)
    object.__setattr__(hijax_var, '_treedef', treedef)
    return []

  def abstract_eval(
    _, aval_mutable_qdd: hjx.AvalMutableQDD, *leaf_avals, treedef, var_type
  ):
    hijax_var: AbstractVariable = aval_mutable_qdd.aval  # type: ignore
    assert isinstance(hijax_var, AbstractVariable)
    if not hijax_var.has_qdd:
      raise errors.ImmutableVariableError(
        "Trying to update Variable with 'has_qdd=False'."
      )
    assert var_type is hijax_var._var_type
    aval_mutable_qdd.mutable_qdd.update(
      VariableQDD(leaf_avals, treedef, var_type)
    )
    effects = {variable_effect} if hijax_var.has_qdd else set()
    return [], effects  # TODO better typechecking...

  def to_lojax(_, hijax_var: HijaxVariable, *leaves, treedef, var_type):
    if not hijax_var.has_qdd:
      raise errors.ImmutableVariableError(
        "Trying to update Variable with 'has_qdd=False'."
      )
    assert var_type is hijax_var._var_type
    object.__setattr__(hijax_var, '_leaves', leaves)
    object.__setattr__(hijax_var, '_treedef', treedef)
    return []

  def jvp(_, primals, tangents, *, treedef, var_type):
    variable: Variable
    variable, *vals = primals
    variable_dot: Variable
    variable_dot, *val_dots = tangents
    if type(variable_dot._raw_value) is hjx.Zero:
      raise Exception(
        "can't differentiate Variable._set operation, "
        'did you forget jax.lax.stop_gradient?'
      )
    # Forward this primitive's own `var_type`. `type(variable)` would be the
    # runtime class of the primal (e.g. HijaxVariable or a tracer), not the
    # Variable subclass that impl/abstract_eval assert against.
    set_variable_p.bind(variable, *vals, treedef=treedef, var_type=var_type)
    set_variable_p.bind(
      variable_dot, *val_dots, treedef=treedef, var_type=var_type
    )
    return [], []

  def transpose(_, *args, treedef, var_type):
    raise NotImplementedError('transpose not implemented for SetHijaxVariable')


set_variable_p = SetVariable('set_variable')


def _get_hijax_state(hijax_var: HijaxVariable | AbstractVariable) -> Variable:
  """Rebuilds a plain Variable from a hijax variable (or its aval).

  Mutable (has_qdd) variables take their treedef/avals from the current QDD.
  Immutable ones use the treedef/leaves stored on the object itself.
  """
  if hijax_var.has_qdd:
    tys: VariableQDD = jax.experimental.cur_qdd(hijax_var)
    leaf_vals = get_variable_p.bind(
      hijax_var,
      treedef=tys.treedef,
      avals=tuple(tys.leaf_avals),
      var_type=hijax_var._var_type,
      has_qdd=hijax_var.has_qdd,
    )
    variable = jax.tree.unflatten(tys.treedef, leaf_vals)
  else:
    assert hijax_var._treedef is not None
    assert hijax_var._leaves is not None
    if isinstance(hijax_var, (jax.core.Tracer, AbstractVariable)):
      # Tracers/avals already expose abstract leaves.
      leaf_avals = hijax_var._leaves
    else:
      leaf_avals = tuple(map(jax.typeof, hijax_var._leaves))
    leaf_vals = get_variable_p.bind(
      hijax_var,
      treedef=hijax_var._treedef,
      avals=leaf_avals,
      var_type=hijax_var._var_type,
      has_qdd=hijax_var.has_qdd,
    )
    variable = jax.tree.unflatten(hijax_var._treedef, leaf_vals)
  return variable


class GetVariable(hjx.HiPrimitive):
  """Primitive that reads the flattened leaves of a HijaxVariable."""

  multiple_results = True

  def impl(
    self, hijax_var: HijaxVariable, *, treedef, avals, var_type, has_qdd
  ):
    return hijax_var._leaves

  def abstract_eval(self, abstract_var, *, treedef, avals, var_type, has_qdd):
    if has_qdd:
      return avals, {variable_effect}
    else:
      return avals, set()

  def to_lojax(
    _, hijax_var: HijaxVariable, *, treedef, avals, var_type, has_qdd
  ):
    return hijax_var._leaves

  def jvp(_, primals, tangents, *, treedef, avals, var_type, has_qdd):
    if has_qdd:
      raise NotImplementedError(
        "jvp not implemented for 'get_variable' with QDD"
      )
    (hijax_var,), (hijax_var_dot,) = primals, tangents
    return (
      get_variable_p.bind(
        hijax_var,
        treedef=treedef,
        avals=avals,
        var_type=var_type,
        has_qdd=has_qdd,
      ),
      get_variable_p.bind(
        hijax_var_dot,
        treedef=treedef,
        avals=tuple(a.to_tangent_aval() for a in avals),
        var_type=var_type,
        has_qdd=has_qdd,
      ),
    )

  def transpose(_, out, hijax_var, *, treedef, avals, var_type, has_qdd):
    if has_qdd:
      raise NotImplementedError(
        "transpose not implemented for 'get_variable' with QDD"
      )
    abstract_var: AbstractVariable = (
      hijax_var.aval
      if hjx.is_undefined_primal(hijax_var)
      else jax.typeof(hijax_var)
    )
    # The cotangent of the variable is a new variable built from `out`.
    hijax_var_dot = _bind_new_variable(
      *out,
      treedef=abstract_var._treedef,
      var_type=var_type,
      has_qdd=has_qdd,
      ref=abstract_var.ref,
    )
    return (hijax_var_dot,)


get_variable_p = GetVariable('get_variable')

# ---------------------------------
# HijaxVariable
# ---------------------------------


def _variable_has_changed(old: Variable, new: Variable) -> bool:
  """True if the tree structure differs or any leaf object was replaced."""
  old_structure = jax.tree.structure(old)
  new_structure = jax.tree.structure(new)
  if old_structure != new_structure:  # type: ignore[operator]
    return True
  old_leaves = jax.tree.leaves(old)
  new_leaves = jax.tree.leaves(new)
  # Identity (not equality) comparison: any reassigned leaf counts as a change.
  return any(o is not n for o, n in zip(old_leaves, new_leaves))


def _as_hijax_property(name: str, *, get: bool, set: bool) -> property:
  """Creates a property that operates on the hijax type."""

  def _getter_wrapper(hijax_var):
    variable = _get_hijax_state(hijax_var)
    old_state = jax.tree.map(lambda x: x, variable)
    out = getattr(variable, name)
    # Write back only if reading the attribute mutated the Variable.
    if _variable_has_changed(old_state, variable):
      _set_hijax_state(hijax_var, variable)
    return out

  def _setter_wrapper(hijax_var, value):
    variable = _get_hijax_state(hijax_var)
    setattr(variable, name, value)
    _set_hijax_state(hijax_var, variable)

  _hijax_property = property(
    fget=_getter_wrapper if get else None,
    fset=_setter_wrapper if set else None,
  )
  return _hijax_property  # type: ignore[return]


def _as_aval_property(p: property) -> hjx.aval_property:
  """Wraps a property `p` so that it operates on the aval type."""
  _aval_property = hjx.aval_property(fget=p.fget)
  return _aval_property  # type: ignore[return]


def _as_hijax_attribute(name: str) -> property:
  """Creates a read-only property that operates on the hijax type."""

  def _getter_wrapper(hijax_var):
    variable = _get_hijax_state(hijax_var)
    old_state = jax.tree.map(lambda x: x, variable)
    out = getattr(variable, name)
    if _variable_has_changed(old_state, variable):
      _set_hijax_state(hijax_var, variable)
    return out

  _getter_wrapper.__name__ = name
  _hijax_property = property(fget=_getter_wrapper)
  return _hijax_property  # type: ignore[return]


def _as_hijax_method(name: str) -> tp.Any:
  """Creates a method that operates on the hijax type.

  The generated method rebuilds the plain Variable and calls its `name`
  method. It writes the state back only if the call changed the Variable.
  """

  def hijax_method_wrapper(hijax_var, *args, **kwargs):
    variable = _get_hijax_state(hijax_var)
    old_state = jax.tree.map(lambda x: x, variable)
    method = getattr(variable, name)
    out = method(*args, **kwargs)
    if _variable_has_changed(old_state, variable):
      _set_hijax_state(hijax_var, variable)
    return out

  hijax_method_wrapper.__name__ = name
  return hijax_method_wrapper


def _as_tracer_method(name: str):
  """Like `_as_hijax_method`, but with an extra leading `self` (aval) arg."""

  def op(self, hijax_var, *args, **kwargs):
    variable = _get_hijax_state(hijax_var)
    old_state = jax.tree.map(lambda x: x, variable)
    out = getattr(variable, name)(*args, **kwargs)
    if _variable_has_changed(old_state, variable):
      _set_hijax_state(hijax_var, variable)
    return out

  op.__name__ = name
  return op


def _not_an_attribute_property(name: str):
  """Property that always raises AttributeError for `name`."""

  def _op(self):
    raise AttributeError(
      f"'{type(self).__name__}' object has no attribute '{name}'"
    )

  return property(_op)


class HijaxVariableMeta(type):
  # isinstance() is also True for tracers whose aval is an AbstractVariable.
  def __instancecheck__(self, instance):
    if super().__instancecheck__(instance):
      return True

    if isinstance(instance, jax.core.Tracer):
      ty = jax.typeof(instance)
      return isinstance(ty, AbstractVariable)
    return False


# Hijax representation of a Variable: stores the flattened leaves, treedef and
# Variable subclass. Most attribute/method access is forwarded to a rebuilt
# plain Variable via the `_as_hijax_*` helpers above.
class HijaxVariable(
  tp.Generic[A], reprlib.Representable, metaclass=HijaxVariableMeta
):  # type: ignore
  __slots__ = ('_treedef', '_leaves', '_var_type', 'has_qdd', '_ref')
  _treedef: PyTreeDef
  _leaves: tuple[Leaf, ...]
  _var_type: type[Variable[tp.Any]]
  has_qdd: bool
  _ref: bool

  @classmethod
  def _new(
    cls,
    leaves: tuple[Leaf, ...],
    treedef: PyTreeDef,
    var_type: type[Variable[A]],
    has_qdd: bool,
    *,
    ref: bool = False,
  ):
    # object.__setattr__ bypasses the forwarding __setattr__ defined below.
    hijax_var = object.__new__(cls)
    object.__setattr__(hijax_var, '_treedef', treedef)
    object.__setattr__(hijax_var, '_leaves', leaves)
    object.__setattr__(hijax_var, '_var_type', var_type)
    object.__setattr__(hijax_var, 'has_qdd', has_qdd)
    object.__setattr__(hijax_var, '_ref', ref)
    return hijax_var

  __init__ = _as_hijax_method('__init__')

  @property
  def value(self) -> A:
    raise NotImplementedError(
      'HijaxVariable.value property is not implemented. For Variable[Array] instances use:\n\n'
      ' variable[...]\n\n'
      'For other Variable types use:\n\n'
      ' variable.get_value()\n'
    )

  @value.setter
  def value(self, new_value: A):
    raise NotImplementedError(
      'HijaxVariable.value property is not implemented. For Variable[Array] instances use:\n\n'
      ' variable[...] = new_value\n\n'
      'For other Variable types use:\n\n'
      ' variable.set_value(new_value)\n'
    )

  @property
  def var_type(self) -> type[Variable[A]]:
    return self._var_type

  _trace_state = _as_hijax_property('_trace_state', get=True, set=False)
  _can_update = _as_hijax_property('_can_update', get=True, set=False)
  _check_can_update = _as_hijax_method('_check_can_update')
  __getattr__ = _as_hijax_method('__getattr__')
  __setattr__ = _as_hijax_method('__setattr__')
  __delattr__ = _as_hijax_method('__delattr__')
  type = _as_hijax_property('type', get=True, set=False)
  hijax = _as_hijax_property('hijax', get=True, set=False)

  @property
  def ref(self) -> bool:
    return self._ref

  get_metadata = _as_hijax_method('get_metadata')
  set_metadata = _as_hijax_method('set_metadata')

  def copy_from(self, other: Variable[A] | HijaxVariable[A]) -> None:
    if isinstance(other, HijaxVariable):
      other = _get_hijax_state(other)
    variable = _get_hijax_state(self)
    variable.copy_from(other)  # type: ignore[arg-type]
    _set_hijax_state(self, variable)

  def update_from_state(self, variable_state: Variable[A] | HijaxVariable[A]):
    if isinstance(variable_state, HijaxVariable):
      variable_state = _get_hijax_state(variable_state)
    variable = _get_hijax_state(self)
    variable.update_from_state(variable_state)  # type: ignore[arg-type]
    _set_hijax_state(self, variable)

  get_raw_value = _as_hijax_method('get_raw_value')
  set_raw_value = _as_hijax_method('set_raw_value')
  set_value = _as_hijax_method('set_value')
  get_value = _as_hijax_method('get_value')
  create_value = _as_hijax_method('create_value')
  add_axis = _as_hijax_method('add_axis')
  remove_axis = _as_hijax_method('remove_axis')
  copy = _as_hijax_method('copy')
  replace = _as_hijax_method('replace')
  to_state = _as_hijax_method('to_state')

  # NOTE(review): `_var_type` is declared in __slots__, so on the class
  # `cls._var_type` is a slot descriptor, not a Variable subclass, and this
  # classmethod appears unable to work as written. Confirm whether an
  # instance method was intended.
  @classmethod
  def from_metadata(cls, value: A, metadata: dict[str, tp.Any]):
    return cls._var_type.from_metadata(value, metadata)  # type: ignore[misc]

  __nnx_repr__ = _as_hijax_method('__nnx_repr__')
  __treescope_repr__ = _as_hijax_method('__treescope_repr__')

  # --------------------------------------------
  # proxy methods
  # --------------------------------------------
  __jax_array__ = _as_hijax_method('__jax_array__')
  __getitem__ = _as_hijax_method('__getitem__')
  __setitem__ = _as_hijax_method('__setitem__')
  __delitem__ = _as_hijax_method('__delitem__')
  __call__ = _as_hijax_method('__call__')
  __len__ = _as_hijax_method('__len__')
  __iter__ = _as_hijax_method('__iter__')
  __contains__ = _as_hijax_method('__contains__')

  __add__ = _as_hijax_method('__add__')
  __sub__ = _as_hijax_method('__sub__')
  __mul__ = _as_hijax_method('__mul__')
  __matmul__ = _as_hijax_method('__matmul__')
  __truediv__ = _as_hijax_method('__truediv__')
  __floordiv__ = _as_hijax_method('__floordiv__')
  __mod__ = _as_hijax_method('__mod__')
  __divmod__ = _as_hijax_method('__divmod__')
  __pow__ = _as_hijax_method('__pow__')
  __lshift__ = _as_hijax_method('__lshift__')
  __rshift__ = _as_hijax_method('__rshift__')
  __and__ = _as_hijax_method('__and__')
  __xor__ = _as_hijax_method('__xor__')
  __or__ = _as_hijax_method('__or__')

  __radd__ = _as_hijax_method('__radd__')
  __rsub__ = _as_hijax_method('__rsub__')
  __rmul__ = _as_hijax_method('__rmul__')
  __rmatmul__ = _as_hijax_method('__rmatmul__')
  __rtruediv__ = _as_hijax_method('__rtruediv__')
  __rfloordiv__ = _as_hijax_method('__rfloordiv__')
  __rmod__ = _as_hijax_method('__rmod__')
  __rdivmod__ = _as_hijax_method('__rdivmod__')
  __rpow__ = _as_hijax_method('__rpow__')
  __rlshift__ = _as_hijax_method('__rlshift__')
  __rrshift__ = _as_hijax_method('__rrshift__')
  __rand__ = _as_hijax_method('__rand__')
  __rxor__ = _as_hijax_method('__rxor__')
  __ror__ = _as_hijax_method('__ror__')

  __iadd__ = _as_hijax_method('__iadd__')
  __isub__ = _as_hijax_method('__isub__')
  __imul__ = _as_hijax_method('__imul__')
  __imatmul__ = _as_hijax_method('__imatmul__')
  __itruediv__ = _as_hijax_method('__itruediv__')
  __ifloordiv__ = _as_hijax_method('__ifloordiv__')
  __imod__ = _as_hijax_method('__imod__')
  __ipow__ = _as_hijax_method('__ipow__')
  __ilshift__ = _as_hijax_method('__ilshift__')
  __irshift__ = _as_hijax_method('__irshift__')
  __iand__ = _as_hijax_method('__iand__')
  __ixor__ = _as_hijax_method('__ixor__')
  __ior__ = _as_hijax_method('__ior__')

  __neg__ = _as_hijax_method('__neg__')
  __pos__ = _as_hijax_method('__pos__')
  __abs__ = _as_hijax_method('__abs__')
  __invert__ = _as_hijax_method('__invert__')
  __complex__ = _as_hijax_method('__complex__')
  __int__ = _as_hijax_method('__int__')
  __float__ = _as_hijax_method('__float__')
  __index__ = _as_hijax_method('__index__')
  __round__ = _as_hijax_method('__round__')
  __trunc__ = _as_hijax_method('__trunc__')
  __floor__ = _as_hijax_method('__floor__')
  __ceil__ = _as_hijax_method('__ceil__')

  # --------------------------------------------
  # hijax interface
  # --------------------------------------------

  def cur_qdd(self):
    return self.type_state()

  def type_state(self):
    leaf_avals = tuple(map(jax.typeof, self._leaves))
    return VariableQDD(leaf_avals, self._treedef, self._var_type)


def _to_abstract_variable(hijax_var: HijaxVariable):
  """Builds the AbstractVariable aval for a concrete HijaxVariable.

  Mutable (has_qdd) variables keep their leaves in QDD, so the aval stores no
  treedef/leaves. Immutable ones record their leaf avals and treedef.
  """
  if hijax_var.has_qdd:
    treedef = None
    leaves = None
  else:
    leaves = tuple(map(jax.typeof, hijax_var._leaves))
    treedef = hijax_var._treedef
  return AbstractVariable(
    hijax_var._var_type,
    treedef,
    leaves,
    hijax_var.has_qdd,
    ref=hijax_var.ref,
  )


hjx.register_hitype(HijaxVariable, _to_abstract_variable)

# ---------------------------------
# AbstractVariable
# ---------------------------------


# Aval (abstract type) of a HijaxVariable. Method calls on tracers are routed
# back through the HijaxVariable implementations via `hjx.aval_method` /
# `_as_tracer_method`.
class AbstractVariable(tp.Generic[A], hjx.MutableHiType):
  __slots__ = ['_var_type', '_treedef', '_leaves', 'has_qdd', '_ref']
  _var_type: type[Variable[A]]
  _treedef: PyTreeDef | None
  _leaves: tuple[hjx.AbstractValue, ...] | None
  has_qdd: bool
  _ref: bool

  @property
  def ref(self) -> bool:
    return self._ref

  @property
  def hijax(self):
    return True

  _check_can_update = hjx.aval_method(HijaxVariable._check_can_update)

  def __init__(
    self,
    var_type: type[Variable[A]],
    treedef: PyTreeDef | None,
    leaves: tuple[hjx.AbstractValue, ...] | None,
    has_qdd: bool,
    *,
    ref: bool = False,
  ):
    if (treedef is None) ^ (leaves is None):
      raise ValueError('treedef and leaves must be both provided or both None')
    object.__setattr__(self, '_treedef', treedef)
    object.__setattr__(self, '_leaves', leaves)
    object.__setattr__(self, '_var_type', var_type)
    object.__setattr__(self, 'has_qdd', has_qdd)
    object.__setattr__(self, '_ref', ref)

  @property
  def dtype(self):
    raise AttributeError

  @property
  def ndim(self):
    raise AttributeError

  @property
  def size(self):
    raise AttributeError

  @property
  def shape(self):
    raise AttributeError

  def __getattr__(self, name: str):
    # Forward unknown attributes to the value
    if hasattr(AbstractVariable, name):
      raise AttributeError(
        f"'{type(self).__name__}' object has no attribute '{name}'"
      )
    if name.startswith('_'):
      raise AttributeError(
        f"'{type(self).__name__}' object has no attribute '{name}'"
      )
    return _as_aval_property(_as_hijax_attribute(name))

  # __setattr__ supported via __getattr__
  # __delattr__ CURRENTLY NOT SUPPORTED

  type = _as_aval_property(HijaxVariable.type)
  get_metadata = hjx.aval_method(HijaxVariable.get_metadata)
  set_metadata = hjx.aval_method(HijaxVariable.set_metadata)
  copy_from = hjx.aval_method(HijaxVariable.copy_from)
  update_from_state = hjx.aval_method(HijaxVariable.update_from_state)
  get_raw_value = hjx.aval_method(HijaxVariable.get_raw_value)
  set_raw_value = hjx.aval_method(HijaxVariable.set_raw_value)
  set_value = hjx.aval_method(HijaxVariable.set_value)
  get_value = hjx.aval_method(HijaxVariable.get_value)
  create_value = hjx.aval_method(HijaxVariable.create_value)
  add_axis = hjx.aval_method(HijaxVariable.add_axis)
  remove_axis = hjx.aval_method(HijaxVariable.remove_axis)
  replace = hjx.aval_method(HijaxVariable.replace)

  @hjx.aval_method
  def from_metadata(self, value, metadata: dict[str, tp.Any]):
    aval: AbstractVariable = self.aval  # type: ignore
    variable = aval._var_type.from_metadata(value, metadata)
    return variable

  copy = hjx.aval_method(HijaxVariable.copy)
  to_state = hjx.aval_method(HijaxVariable.to_state)

  def __str__(self):
    return f'{self._var_type.__name__}()'

  def __repr__(self):
    return f'{self._var_type.__name__}()'

  @hjx.aval_method
  def __treescope_repr__(self, path, subtree_renderer):
    raise NotImplementedError

  # ---------------------------------
  # proxy methods
  # ---------------------------------
  __jax_array__ = hjx.aval_method(HijaxVariable.__jax_array__)
  _getitem = _as_tracer_method('__getitem__')
  _setitem = _as_tracer_method('__setitem__')
  # __delitem__ CURRENTLY NOT SUPPORTED
  # __call__ CURRENTLY NOT SUPPORTED
  _len = _as_tracer_method('__len__')
  _iter = _as_tracer_method('__iter__')
  # __contains__ CURRENTLY NOT SUPPORTED

  _add = _as_tracer_method('__add__')
  _sub = _as_tracer_method('__sub__')
  _mul = _as_tracer_method('__mul__')
  _matmul = _as_tracer_method('__matmul__')
  _truediv = _as_tracer_method('__truediv__')
  _floordiv = _as_tracer_method('__floordiv__')
  _mod = _as_tracer_method('__mod__')
  _divmod = _as_tracer_method('__divmod__')
  _pow = _as_tracer_method('__pow__')
  _lshift = _as_tracer_method('__lshift__')
  _rshift = _as_tracer_method('__rshift__')
  _and = _as_tracer_method('__and__')
  _xor = _as_tracer_method('__xor__')
  _or = _as_tracer_method('__or__')

  _radd = _as_tracer_method('__radd__')
  _rsub = _as_tracer_method('__rsub__')
  _rmul = _as_tracer_method('__rmul__')
  _rmatmul = _as_tracer_method('__rmatmul__')
  _rtruediv = _as_tracer_method('__rtruediv__')
  _rfloordiv = _as_tracer_method('__rfloordiv__')
  _rmod = _as_tracer_method('__rmod__')
  _rdivmod = _as_tracer_method('__rdivmod__')
  _rpow = _as_tracer_method('__rpow__')
  _rlshift = _as_tracer_method('__rlshift__')
  _rrshift = _as_tracer_method('__rrshift__')
  _rand = _as_tracer_method('__rand__')
  _rxor = _as_tracer_method('__rxor__')
  _ror = _as_tracer_method('__ror__')

  # _iadd CURRENTLY NOT SUPPORTED
  # _isub CURRENTLY NOT SUPPORTED
  # _imul CURRENTLY NOT SUPPORTED
  # _imatmul CURRENTLY NOT SUPPORTED
  # _itruediv CURRENTLY NOT SUPPORTED
  # _ifloordiv CURRENTLY NOT SUPPORTED
  # _imod CURRENTLY NOT SUPPORTED
  # _ipow CURRENTLY NOT SUPPORTED
  # _ilshift CURRENTLY NOT SUPPORTED
  # _irshift CURRENTLY NOT SUPPORTED
  # _iand CURRENTLY NOT SUPPORTED
  # _ixor CURRENTLY NOT SUPPORTED
  # _ior CURRENTLY NOT SUPPORTED

  _neg = _as_tracer_method('__neg__')
  _pos = _as_tracer_method('__pos__')
  _abs = _as_tracer_method('__abs__')
  _invert = _as_tracer_method('__invert__')
  _complex = _as_tracer_method('__complex__')
  _int = _as_tracer_method('__int__')
  _float = _as_tracer_method('__float__')
  _index = _as_tracer_method('__index__')
  _round = _as_tracer_method('__round__')
  _trunc = _as_tracer_method('__trunc__')
  _floor = _as_tracer_method('__floor__')
  _ceil = _as_tracer_method('__ceil__')

  # --------------------------------
  # hijax interface
  # --------------------------------

  cur_qdd = _not_an_attribute_property('cur_qdd')

  # NOTE(review): __eq__ compares only `_var_type`, while __hash__ also mixes
  # in `_treedef`/`_leaves` when they are present. Two instances can therefore
  # compare equal yet hash differently, which violates the hash/eq contract
  # for dict/set keys. Left unchanged because tightening either side could
  # alter JAX cache / type-matching behavior. Confirm the intended semantics.
  def __hash__(self):
    if self._leaves is not None and self._treedef is not None:
      return hash(
        (AbstractVariable, self._var_type, self._treedef, self._leaves)
      )
    else:
      assert self._leaves is None and self._treedef is None
      return hash((AbstractVariable, self._var_type))

  def __eq__(self, other):
    return (
      isinstance(other, AbstractVariable) and self._var_type == other._var_type
    )

  def str_short(self, short_dtypes=False, **_) -> str:  # type: ignore
    return f'{self._var_type.__name__}()'

  # mutable interface
  def lo_ty_qdd(self, variable_state: VariableQDD) -> list:  # type: ignore
    return [lo_ty for t in variable_state.leaf_avals for lo_ty in t.lo_ty()]

  def new_from_loval(  # type: ignore[override]
    self, variable_state: VariableQDD, *lo_vals
  ) -> HijaxVariable:
    # Regroup the flat low-level values into one high-level value per leaf.
    lo_vals_ = iter(lo_vals)
    hi_vals = [
      hi_ty.raise_val(*it.islice(lo_vals_, len(hi_ty.lo_ty())))  # type: ignore
      for hi_ty in variable_state.leaf_avals
    ]
    assert next(lo_vals_, None) is None
    # Result unused; unflatten still checks the leaves against the treedef.
    variable: Variable = jax.tree.unflatten(variable_state.treedef, hi_vals)
    return HijaxVariable._new(
      hi_vals,
      variable_state.treedef,
      self._var_type,
      has_qdd=self.has_qdd,
      ref=self.ref,
    )  # will be mutated

  def read_loval(self, variable_state: VariableQDD, variable) -> list:  # type: ignore
    leaf_vals, treedef = jax.tree.flatten(_get_hijax_state(variable))
    assert treedef == variable_state.treedef
    return [
      lo_val
      for hi_ty, hi_val in zip(variable_state.leaf_avals, leaf_vals)
      for lo_val in hi_ty.lower_val(hi_val)
    ]  # type: ignore

  def update_from_loval(  # type: ignore[override]
    self, box_state: VariableQDD, variable, *lo_vals
  ) -> None:
    lo_vals_ = iter(lo_vals)
    hi_vals = [
      hi_ty.raise_val(*it.islice(lo_vals_, len(hi_ty.lo_ty())))  # type: ignore
      for hi_ty in box_state.leaf_avals
    ]
    assert next(lo_vals_, None) is None
    _set_hijax_state(variable, jax.tree.unflatten(box_state.treedef, hi_vals))

  def to_tangent_aval(self):
    return AbstractVariable(
      self._var_type,
      self._treedef,
      self._leaves,
      self.has_qdd,
      ref=self.ref,
    )


# --------------------------------------------
# Variable
# --------------------------------------------


def _remap_sharding_metadata(metadata: dict[str, tp.Any]) -> None:
  """Renames deprecated sharding keys to 'out_sharding' in place.

  Emits a DeprecationWarning for each deprecated key found. If both
  'sharding' and 'sharding_names' are present, 'sharding_names' wins.
  """
  if 'sharding' in metadata:
    warnings.warn(
      "'sharding' is deprecated, use 'out_sharding' instead.",
      DeprecationWarning,
      stacklevel=3,
    )
    metadata['out_sharding'] = metadata.pop('sharding')
  if 'sharding_names' in metadata:
    warnings.warn(
      "'sharding_names' is deprecated, use 'out_sharding' instead.",
      DeprecationWarning,
      stacklevel=3,
    )
    metadata['out_sharding'] = metadata.pop('sharding_names')


def _variable_operator(name: str) -> tp.Callable[[Variable[A], tp.Any], A]:
  """Builds a binary operator that applies `name` to the Variable's value.

  If the other operand is also a Variable, its value is unwrapped first.
  """

  def variable_operator_method(self, other):
    value = self.get_value()
    if isinstance(other, Variable):
      other = other.get_value()
    return getattr(value, name)(other)

  variable_operator_method.__name__ = name
  return variable_operator_method


def _variable_unary_operator(name: str) -> tp.Callable[[Variable[A]], A]:
  """Builds a unary operator that applies `name` to the Variable's value."""

  def variable_unary_operator_method(self):
    value = self.get_value()
    return getattr(value, name)()

  variable_unary_operator_method.__name__ = name
  return variable_unary_operator_method


class VariableMeta(type):
  """Metaclass for Variable types.

  * Gives every subclass an empty ``__slots__`` unless it declares its own.
  * Extends isinstance() to hijax variables and tracers of them.
  * Returns a HijaxVariable from construction when ``variable.hijax`` is true.
  """

  def __new__(cls, cls_name, bases, attrs):
    if '__slots__' not in attrs:
      attrs['__slots__'] = ()
    return super().__new__(cls, cls_name, bases, attrs)

  def __instancecheck__(self, instance):
    if super().__instancecheck__(instance):
      return True

    if isinstance(instance, jax.core.Tracer):
      ty = jax.typeof(instance)
      if isinstance(ty, AbstractVariable):
        return issubclass(ty._var_type, self)
    if isinstance(instance, HijaxVariable):
      return issubclass(instance._var_type, self)
    return False

  if not tp.TYPE_CHECKING:

    def __call__(cls, *args, **kwargs):
      return cls._variable_meta_call(*args, **kwargs)

  def _variable_meta_call(cls, *args, **kwargs):
    variable = super().__call__(*args, **kwargs)
    if variable.hijax:
      return _new_hijax_from_variable(variable)
    return variable
class Variable(tp.Generic[A], reprlib.Representable, metaclass=VariableMeta):
  """The base class for all ``Variable`` types.

  Create custom ``Variable`` types by subclassing this class. Numerous NNX
  graph functions can filter for specific ``Variable`` types, for example,
  :func:`split`, :func:`state`, :func:`pop`, and :func:`State.filter`.

  Example usage::

    >>> from flax import nnx
    >>> import jax, jax.numpy as jnp

    >>> class CustomVariable(nnx.Variable):
    ...   pass

    >>> class Model(nnx.Module):
    ...   def __init__(self, rngs):
    ...     self.linear = nnx.Linear(2, 3, rngs=rngs)
    ...     self.custom_variable = CustomVariable(jnp.ones((1, 3)))
    ...   def __call__(self, x):
    ...     return self.linear(x) + self.custom_variable
    >>> model = Model(rngs=nnx.Rngs(0))

    >>> linear_variables = nnx.state(model, nnx.Param)
    >>> jax.tree.map(jnp.shape, linear_variables)
    State({
      'linear': {
        'bias': Param(
          value=(3,)
        ),
        'kernel': Param(
          value=(2, 3)
        )
      }
    })

    >>> custom_variable = nnx.state(model, CustomVariable)
    >>> jax.tree.map(jnp.shape, custom_variable)
    State({
      'custom_variable': CustomVariable(
        value=(1, 3)
      )
    })

    >>> variables = nnx.state(model)
    >>> jax.tree.map(jnp.shape, variables)
    State({
      'custom_variable': CustomVariable(
        value=(1, 3)
      ),
      'linear': {
        'bias': Param(
          value=(3,)
        ),
        'kernel': Param(
          value=(2, 3)
        )
      }
    })

  The metadata keys listed in ``required_metadata`` (``hijax``, ``ref`` and
  ``eager_sharding``) are always populated by the constructor;
  ``set_metadata`` rejects changes to them and ``del_metadata`` refuses to
  remove them.
  """

  __slots__ = ('_raw_value', '_trace_state', '_var_metadata')

  # The wrapped payload (possibly a jax Ref when ``ref=True``).
  _raw_value: A
  # Records the trace the Variable was created in; used to reject mutation
  # from a different trace level when ``hijax`` is False.
  _trace_state: tracers.TraceState
  # All metadata, including the required keys and any per-instance hooks.
  _var_metadata: dict[str, tp.Any]

  # Keys that every Variable carries and that may not be changed or deleted.
  required_metadata = frozenset(
      ['hijax', 'ref', 'eager_sharding']
  )

  @property
  def var_type(self):
    return type(self)

  @property
  def hijax(self) -> bool:
    return self._var_metadata['hijax']

  @property
  def ref(self) -> bool:
    return self._var_metadata['ref']

  @property
  def shape(self: Variable[jax.Array]) -> tuple[int, ...]:
    return self.get_value().shape

  @property
  def sharding_names(self):
    warnings.warn(
        "'sharding_names' is deprecated, use 'out_sharding' instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return self.get_metadata('out_sharding', None)

  def __init__(
      self,
      value: A | VariableMetadata[A],
      *,
      hijax: bool | None = None,
      ref: bool | None = None,
      eager_sharding: bool | None = None,
      **metadata: tp.Any,
  ):
    """Create a Variable wrapping ``value``.

    Args:
      value: the payload, or a ``VariableMetadata`` whose ``raw_value`` is
        the payload and whose ``metadata`` is merged into this Variable's.
      hijax: whether this is a hijax Variable; defaults to
        ``var_defaults().hijax``.
      ref: whether the payload is stored as a jax Ref; defaults to
        ``var_defaults().ref``.
      eager_sharding: whether to shard the value at construction when
        ``out_sharding`` metadata is present; defaults to
        ``using_eager_sharding()``.
      **metadata: additional metadata entries.

    Raises:
      ValueError: if ``hijax``/``ref``/``eager_sharding`` are given both as
        arguments and in ``VariableMetadata`` with different values, or if
        ``value`` already contains a Ref.
    """
    var_t = type(self)
    if isinstance(value, VariableMetadata):
      aux_metadata = dict(value.metadata)
      if 'hijax' in aux_metadata:
        if hijax is not None and hijax != aux_metadata['hijax']:
          raise ValueError(
              'Cannot specify hijax both in VariableMetadata and as an '
              'argument to Variable constructor.'
          )
        hijax = aux_metadata.pop('hijax')
      if 'ref' in aux_metadata:
        if ref is not None and ref != aux_metadata['ref']:
          raise ValueError(
              'Cannot specify ref both in VariableMetadata and as an '
              'argument to Variable constructor.'
          )
        ref = aux_metadata.pop('ref')
      if 'eager_sharding' in aux_metadata:
        if (
            eager_sharding is not None
            and eager_sharding != aux_metadata['eager_sharding']
        ):
          raise ValueError(
              'Cannot specify eager_sharding both in VariableMetadata and as '
              'an argument to Variable constructor.'
          )
        eager_sharding = aux_metadata['eager_sharding']
      metadata.update(aux_metadata)
      value = tp.cast(A, value.raw_value)
    if hijax is None:
      hijax = var_defaults().hijax
    if ref is None:
      ref = var_defaults().ref
    if eager_sharding is None:
      eager_sharding = using_eager_sharding()
    # The Variable creates its own Ref below when ref=True; refuse pre-made ones.
    if any(is_array_ref(v) for v in jax.tree.leaves(value)):
      raise ValueError('Cannot pass a Ref directly into Variable constructor.')
    metadata['hijax'] = hijax
    metadata['ref'] = ref
    metadata['eager_sharding'] = eager_sharding

    # Bypass __setattr__ (which checks the trace state) during construction.
    object.__setattr__(self, '_trace_state', tracers.TraceState())
    object.__setattr__(self, '_var_metadata', metadata)
    object.__setattr__(self, '_raw_value', value)

    # Class-level hooks become per-instance metadata unless explicitly given.
    if hasattr(var_t, 'on_get_value') and 'on_get_value' not in metadata:
      metadata['on_get_value'] = var_t.on_get_value
    if hasattr(var_t, 'on_set_value') and 'on_set_value' not in metadata:
      metadata['on_set_value'] = var_t.on_set_value
    if hasattr(var_t, 'on_create_value') and 'on_create_value' not in metadata:
      metadata['on_create_value'] = var_t.on_create_value
    if hasattr(var_t, 'on_add_axis') and 'on_add_axis' not in metadata:
      metadata['on_add_axis'] = var_t.on_add_axis
    if hasattr(var_t, 'on_remove_axis') and 'on_remove_axis' not in metadata:
      metadata['on_remove_axis'] = var_t.on_remove_axis

    _remap_sharding_metadata(metadata)

    # run create_value hooks
    if 'on_create_value' in metadata:
      value = metadata['on_create_value'](self, value)
    object.__setattr__(self, '_raw_value', value)
    # run create_value hook
    value = self.create_value(value)  # type: ignore
    # shard the _value if applicable
    if eager_sharding and 'out_sharding' in metadata:
      value = core_spmd.shard_value(
          value,
          metadata['out_sharding'],
          metadata.get('sharding_rules', None),
          metadata.get('mesh', None),
      )
    if ref:
      value = jax.new_ref(value)  # type: ignore
    object.__setattr__(self, '_raw_value', value)

  @property
  def _can_update(self) -> bool:
    """Whether the Variable can be updated in-place in the current trace context."""
    if self.hijax:
      return True
    else:
      return self._trace_state.is_valid()

  def _check_can_update(self):
    """Raise ``TraceContextError`` if a non-hijax Variable is mutated from another trace."""
    if not self.hijax and not self._trace_state.is_valid():
      raise errors.TraceContextError(
          f'Cannot mutate {type(self).__name__} from a different trace level'
      )

  def __getattr__(self, name: str) -> tp.Any:
    # Only reached when normal lookup fails: expose metadata entries as
    # attributes, otherwise delegate to the wrapped value.
    if name in object.__getattribute__(self, '_var_metadata'):
      return self._var_metadata[name]
    return getattr(object.__getattribute__(self, '_raw_value'), name)

  def __setattr__(self, name: str, value: tp.Any):
    self._check_can_update()
    try:
      object.__setattr__(self, name, value)
    except AttributeError as e:
      # __slots__ rejects unknown names; point users at the metadata API.
      raise AttributeError(
          f'Cannot set attribute {name}. '
          f'To set Variable metadata use either:\n\n'
          f' variable.set_metadata({name}=value)\n\nor\n\n'
          f" variable.set_metadata('{name}', value)"
      ) from e

  def __delattr__(self, name: str):
    self._check_can_update()
    try:
      object.__delattr__(self, name)
    except AttributeError as e:
      raise AttributeError(
          f'Cannot delete attribute {name}. '
          f'To delete Variable metadata use:\n\n'
          f" variable.del_metadata('{name}')"
      ) from e

  # NOTE(cgarciae): adding this for backward compatibility with VariableState
  @property
  def type(self):
    """The type of the variable."""
    return type(self)

  @tp.overload
  def get_metadata(
      self, *, exclude_required: bool = False
  ) -> dict[str, tp.Any]: ...
  @tp.overload
  def get_metadata(self, name: str, default: tp.Any = MISSING) -> tp.Any: ...

  def get_metadata(
      self,
      name: str | None = None,
      default: tp.Any = MISSING,
      *,
      exclude_required: bool | None = None,
  ) -> tp.Any:
    """Get metadata for the Variable.

    Args:
      name: The key of the metadata element to get. If not provided, returns
        a shallow copy of the full metadata dictionary.
      default: The default value to return if the metadata key is not found.
        If not provided and the key is not found, raises a KeyError.
      exclude_required: Only valid when ``name`` is not given. If true, the
        keys in ``required_metadata`` are omitted from the returned copy.

    Raises:
      TypeError: if ``name`` is combined with ``exclude_required``, or if a
        ``default`` is given without a ``name``.
      KeyError: if ``name`` is missing and no ``default`` was given.
    """
    if name is not None and exclude_required is not None:
      raise TypeError(
          "Cannot specify both 'name' and 'exclude_required' arguments."
      )
    if name is None:
      if not isinstance(default, Missing):
        raise TypeError(
            "Cannot provide a default value when 'name' is not provided. "
            f'Got default={default}'
        )
      # Return a copy so callers cannot mutate the Variable's metadata.
      metadata = self._var_metadata.copy()
      if exclude_required:
        for key in self.required_metadata:
          metadata.pop(key, None)
      return metadata
    # Single-key lookup: no need to copy the whole dictionary.
    if name not in self._var_metadata and not isinstance(default, Missing):
      return default
    return self._var_metadata[name]

  @tp.overload
  def set_metadata(self, metadata: dict[str, tp.Any], /) -> None: ...
  @tp.overload
  def set_metadata(self, name: str, value: tp.Any, /) -> None: ...
  @tp.overload
  def set_metadata(self, **metadata: tp.Any) -> None: ...

  def set_metadata(self, *args, **kwargs) -> None:
    """Set metadata for the Variable.

    `set_metadata` can be called in 3 ways:

    1. By passing a dictionary of metadata as the first argument, this will
       replace the entire Variable's metadata.
    2. By passing a name and value as the first two arguments, this will set
       the metadata entry for the given name to the given value.
    3. By using keyword arguments, this will update the Variable's metadata
       with the provided key-value pairs.

    In every form the deprecated ``sharding`` / ``sharding_names`` keys are
    remapped to ``out_sharding`` with a ``DeprecationWarning``.

    Raises:
      TypeError: if positional and keyword arguments are mixed, or the call
        matches none of the forms above.
      ValueError: if the call would change ``hijax``, ``ref`` or
        ``eager_sharding``.
    """
    self._check_can_update()
    if args and kwargs:
      raise TypeError(
          'Cannot mix positional and keyword arguments in set_metadata'
      )

    # These keys are validated identically in all three call forms.
    fixed_keys = ('hijax', 'ref', 'eager_sharding')

    def check_unchanged(key: str, new_value: tp.Any) -> None:
      current = getattr(self, key)
      if new_value != current:
        raise ValueError(
            f'Cannot change `{key}` metadata, expected {current}, '
            f'got {new_value}'
        )

    if len(args) == 1:
      metadata = dict(args[0])
      _remap_sharding_metadata(metadata)
      for key in fixed_keys:
        if key not in metadata:
          metadata[key] = getattr(self, key)
        check_unchanged(key, metadata[key])
      self._var_metadata = metadata
    elif len(args) == 2:
      name, value = args
      if name in ('sharding_names', 'sharding'):
        warnings.warn(
            f"'{name}' is deprecated, use 'out_sharding' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        name = 'out_sharding'
      if name in fixed_keys:
        check_unchanged(name, value)
      self._var_metadata[name] = value
    elif kwargs:
      _remap_sharding_metadata(kwargs)
      for key in fixed_keys:
        if key in kwargs:
          check_unchanged(key, kwargs[key])
      self._var_metadata.update(kwargs)
    else:
      raise TypeError(
          f'set_metadata takes either 1 or 2 arguments, or at least 1 keyword argument, '
          f'got args={args}, kwargs={kwargs}'
      )

  def has_metadata(self, name: str) -> bool:
    """Check if the Variable has a metadata entry for the given name.

    Args:
      name: The key of the metadata element to check.

    Returns:
      True if the metadata entry exists, False otherwise.
    """
    return name in self._var_metadata

  def del_metadata(self, name: str) -> None:
    """Delete a metadata entry for the Variable.

    Args:
      name: The key of the metadata element to delete.

    Raises:
      ValueError: if ``name`` is one of ``required_metadata``.
      KeyError: if ``name`` is not present.
    """
    self._check_can_update()
    # Other methods (e.g. set_metadata) read every required key, so none of
    # them may be removed.
    if name in self.required_metadata:
      raise ValueError(f'Cannot delete `{name}` metadata')
    del self._var_metadata[name]

  def copy_from(self, other: Variable[A]) -> None:
    """Replace this Variable's raw value and metadata with those of ``other``.

    Raises:
      ValueError: if ``other`` is not exactly the same Variable type.
    """
    if type(self) is not type(other):
      raise ValueError(
          f'Cannot copy from incompatible container, '
          f'expected {type(self).__name__}, got {type(other).__name__}'
      )
    if self is other:
      return
    self._raw_value = other._raw_value
    self._var_metadata.clear()
    self._var_metadata.update(other.get_metadata())

  def update_from_state(self, variable_state: Variable[A]):
    """Adopt ``variable_state``'s raw value and metadata, keeping own hijax/ref."""
    self._raw_value = variable_state._raw_value
    if self._var_metadata != variable_state._var_metadata:
      metadata = variable_state.get_metadata()
      metadata['hijax'] = self.hijax
      metadata['ref'] = self.ref
      self._var_metadata = metadata

  @tp.final
  def get_raw_value(self) -> A:
    """Return the stored value as-is (no hooks; a Ref is not dereferenced)."""
    return self._raw_value

  # @tp.final
  def set_raw_value(self, value: A, *, _unsafe_bypass_check: bool = False):
    """Overwrite the stored value as-is, bypassing ``on_set_value`` hooks."""
    if not _unsafe_bypass_check:
      self._check_can_update()
    self._raw_value = value

  @property
  def raw_value(self) -> A:
    warnings.warn(
        "'.raw_value' access is now deprecated. Use:\n\n"
        ' variable.get_raw_value()\n',
        DeprecationWarning,
        stacklevel=2,
    )
    return self.get_raw_value()

  @raw_value.setter
  def raw_value(self, value: A):
    warnings.warn(
        "'.raw_value' setter is now deprecated. Use:\n\n"
        ' variable.set_raw_value(value)\n',
        DeprecationWarning,
        stacklevel=2,
    )
    self.set_raw_value(value)

  @property
  def value(self) -> A:
    warnings.warn(
        "'.value' access is now deprecated. For Variable[Array] instances use:\n\n"
        ' variable[...]\n\n'
        'For other Variable types use:\n\n'
        ' variable.get_value()\n',
        DeprecationWarning,
        stacklevel=2,
    )
    return self.get_value()

  @value.setter
  def value(self, value: A):
    warnings.warn(
        "'.value' setter is now deprecated. For Variable[Array] instances use:\n\n"
        ' variable[...] = value\n\n'
        'For other Variable types use:\n\n'
        ' variable.set_value(value)\n',
        DeprecationWarning,
        stacklevel=2,
    )
    self.set_value(value)

  def create_value(self, value: A):
    """Subclass hook run once at construction; the default returns ``value``."""
    return value

  def get_value(self, *, index: tp.Any = MISSING) -> A:
    """Return the (optionally indexed) value, after the ``on_get_value`` hook.

    Refs are dereferenced. For a plain ``jax.Array``, ``index=...`` returns
    the array itself without an indexing op.
    """
    value = jax.tree.map(lambda x: x, self._raw_value)  # make a copy
    if not isinstance(index, Missing):
      if is_array_ref(value):
        value = value[index]
      elif isinstance(value, jax.Array) and index is ...:
        pass  # skip trivial access
      else:
        value = value[index]
    elif is_array_ref(value):
      value = value[...]
    if 'on_get_value' in self._var_metadata:
      value = self._var_metadata['on_get_value'](self, value)
    return value  # type: ignore

  def set_value(self, value: A, *, index: tp.Any = MISSING):
    """Store ``value`` (or write it at ``index``) after ``on_set_value``.

    Raises:
      ValueError: if ``value`` is itself a Variable (use ``copy_from``).
    """
    value = jax.tree.map(lambda x: x, value)  # make a copy
    if isinstance(value, Variable):
      raise ValueError(
          'Cannot set value to a Variable, use `copy_from` method instead'
      )
    if 'on_set_value' in self._var_metadata:
      value = self._var_metadata['on_set_value'](self, value)
    # update _raw_value
    if is_array_ref(self._raw_value):
      if isinstance(index, Missing):
        self._raw_value[...] = value
      else:
        self._raw_value[index] = value
    elif isinstance(self._raw_value, jax.Array) and (
        not isinstance(index, Missing)
    ):
      # A full replacement (`var[...] = x`) with a matching array is stored
      # directly instead of going through `.at[...].set`. `is` (not `==`) is
      # required: `==` against an array index (e.g. a NumPy index array)
      # is elementwise and its truth value is ambiguous.
      if (
          index is ...
          and isinstance(value, jax.Array)
          and value.shape == self._raw_value[index].shape
          and value.dtype == self._raw_value.dtype
          and (
              getattr(value, 'sharding', None)
              == getattr(self._raw_value, 'sharding', None)
          )
      ):
        self._raw_value = value
      else:
        self._raw_value = self._raw_value.at[index].set(value)  # type: ignore
    else:
      if isinstance(index, Missing):
        self._raw_value = value
      else:
        self._raw_value[index] = value  # type: ignore

  def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None):
    """Invoke the ``on_add_axis`` hook, if any."""
    if 'on_add_axis' in self._var_metadata:
      self._var_metadata['on_add_axis'](self, axis_index, axis_name)

  def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None):
    """Invoke the ``on_remove_axis`` hook, if any."""
    if 'on_remove_axis' in self._var_metadata:
      self._var_metadata['on_remove_axis'](self, axis_index, axis_name)

  @tp.overload
  def copy(self, value: B, **kwargs) -> Variable[B]: ...

  @tp.overload
  def copy(self, **kwargs) -> Variable[A]: ...

  def copy(
      self,
      value: tp.Any = MISSING,
      *,
      _copy_ref: bool = True,
      **updates,
  ) -> Variable[tp.Any]:
    """Return a new Variable of the same type with updated metadata.

    Args:
      value: the new payload. If omitted, ``updates['value']`` is used when
        present, otherwise this Variable's raw value. The positional
        argument takes precedence over a ``value=`` keyword.
      _copy_ref: if true, a Ref payload is dereferenced and, when the new
        metadata has ``ref=True``, wrapped in a fresh Ref.
      **updates: metadata entries to override. ``value`` is treated as the
        payload, never as metadata.
    """
    assert 'raw_value' not in updates
    # Pop 'value' *before* merging, otherwise it would be stored as a
    # metadata entry (showing up in reprs and in the pytree aux data).
    value_update = updates.pop('value', MISSING)
    new_metadata = self.get_metadata() | updates
    if isinstance(value, Missing):
      if not isinstance(value_update, Missing):
        value = value_update
      else:
        value = self.get_raw_value()

    if _copy_ref and is_array_ref(value):
      value = value[...]

    if _copy_ref and new_metadata['ref']:
      value = jax.new_ref(value)
      new_metadata['ref'] = True

    obj = self.from_metadata(value, new_metadata)
    return obj

  @classmethod
  def _new(
      cls,
      value: A,
      metadata: dict[str, tp.Any],
  ) -> Variable[A]:
    """Build an instance directly from a raw value and metadata, skipping __init__."""
    obj = object.__new__(cls)
    # skip __setattr__ for trace_state initialization
    object.__setattr__(obj, '_trace_state', tracers.TraceState())
    object.__setattr__(obj, '_var_metadata', metadata)
    object.__setattr__(obj, '_raw_value', value)
    return obj

  @classmethod
  def from_metadata(
      cls,
      value: A,
      attributes: dict[str, tp.Any],
  ) -> Variable[A]:
    """Build an instance from a value and a (copied) metadata dict."""
    variable = cls._new(value, dict(attributes))
    if attributes['hijax']:
      variable = _new_hijax_from_variable(variable)  # type: ignore[assignment]
    return variable  # type: ignore[return-value]

  replace = copy
  to_state = copy

  def __nnx_repr__(self):
    stats = SizeBytes.from_any(self._raw_value)
    if stats:
      comment = f' # {stats}'
    else:
      comment = ''

    yield reprlib.Object(type=type(self).__name__, comment=comment)
    yield reprlib.Attr('value', self.get_value())
    # Hide required entries that are at their configured defaults.
    for name, value in self._var_metadata.items():
      if name == 'hijax' and value == config.flax_hijax_variable:
        continue
      if name == 'ref' and not value:
        continue
      if name == 'eager_sharding' and value == config.flax_always_shard_variable:
        continue
      yield reprlib.Attr(name, value)

  def __treescope_repr__(self, path, subtree_renderer):
    size_bytes = SizeBytes.from_any(self.get_value())
    if size_bytes:
      stats_repr = f' # {size_bytes}'
      first_line_annotation = treescope.rendering_parts.comment_color(
          treescope.rendering_parts.text(f'{stats_repr}')
      )
    else:
      first_line_annotation = None
    # Same default-hiding rules as __nnx_repr__.
    metadata = {
        name: value
        for name, value in self._var_metadata.items()
        if not (name == 'hijax' and value == config.flax_hijax_variable)
        and not (name == 'ref' and not value)
        and not (
            name == 'eager_sharding'
            and value == config.flax_always_shard_variable
        )
    }

    children = {'value': self.get_value(), **metadata}
    return visualization.render_object_constructor(
        object_type=type(self),
        attributes=children,
        path=path,
        subtree_renderer=subtree_renderer,
        first_line_annotation=first_line_annotation,
    )

  # hooks API
  if tp.TYPE_CHECKING:

    def on_get_value(self, value: A) -> A: ...

    def on_set_value(self, value: A) -> A: ...

    def on_create_value(self, value: A) -> A: ...

    def on_add_axis(
        self: V, axis_index: AxisIndex, axis_name: AxisName | None
    ) -> V: ...

    def on_remove_axis(
        self: V, axis_index: AxisIndex, axis_name: AxisName | None
    ) -> V: ...

  def __jax_array__(self):
    return self.get_value()

  # pickle support
  def __getstate__(self):
    return {
        '_raw_value': self._raw_value,
        '_trace_state': self._trace_state,
        '_var_metadata': self._var_metadata,
    }

  def __setstate__(self, state):
    # skip __setattr__ for trace_state initialization
    object.__setattr__(self, '_trace_state', state['_trace_state'])
    object.__setattr__(self, '_var_metadata', state['_var_metadata'])
    object.__setattr__(self, '_raw_value', state['_raw_value'])

  # --------------------------------------------
  # proxy methods
  # --------------------------------------------
  @tp.overload
  def __getitem__(self: Variable[jax.Array], key) -> jax.Array: ...
  @tp.overload
  def __getitem__(self: Variable[dict[tp.Any, B]], key) -> B: ...
  @tp.overload
  def __getitem__(self: Variable[list[B]], key: int) -> B: ...
  @tp.overload
  def __getitem__(self: Variable[tuple[B, ...]], key: int) -> B: ...
  @tp.overload
  def __getitem__(self, key) -> tp.Any: ...
  def __getitem__(self, key):
    return self.get_value(index=key)

  def __setitem__(self, key, value) -> None:
    self.set_value(value, index=key)

  def __delitem__(self, key) -> None:
    value = self.get_value()
    del value[key]  # type: ignore
    self.set_value(value)  # type: ignore

  def __call__(self, *args, **kwargs) -> tp.Any:
    return self.get_value()(*args, **kwargs)  # type: ignore

  def __len__(self) -> int:
    return len(self.get_value())  # type: ignore

  def __iter__(self) -> tp.Iterator:
    return iter(self.get_value())  # type: ignore

  def __contains__(self, item) -> bool:
    return item in self.get_value()  # type: ignore

  __add__ = _variable_operator('__add__')
  __sub__ = _variable_operator('__sub__')
  __mul__ = _variable_operator('__mul__')
  __matmul__ = _variable_operator('__matmul__')
  __truediv__ = _variable_operator('__truediv__')
  __floordiv__ = _variable_operator('__floordiv__')
  __mod__ = _variable_operator('__mod__')
  __pow__ = _variable_operator('__pow__')
  __lshift__ = _variable_operator('__lshift__')
  __rshift__ = _variable_operator('__rshift__')
  __and__ = _variable_operator('__and__')
  __xor__ = _variable_operator('__xor__')
  __or__ = _variable_operator('__or__')

  __radd__ = _variable_operator('__radd__')
  __rsub__ = _variable_operator('__rsub__')
  __rmul__ = _variable_operator('__rmul__')
  __rmatmul__ = _variable_operator('__rmatmul__')
  __rtruediv__ = _variable_operator('__rtruediv__')
  __rfloordiv__ = _variable_operator('__rfloordiv__')
  __rmod__ = _variable_operator('__rmod__')
  __rpow__ = _variable_operator('__rpow__')
  __rlshift__ = _variable_operator('__rlshift__')
  __rrshift__ = _variable_operator('__rrshift__')
  __rand__ = _variable_operator('__rand__')
  __rxor__ = _variable_operator('__rxor__')
  __ror__ = _variable_operator('__ror__')

  def __eq__(self, other) -> bool:
    if isinstance(other, Variable):
      other = other.get_value()

    return self.get_value().__eq__(other)  # type: ignore

  # In-place operators are rejected; each message points at the supported
  # `variable[...] <op>= x` form (the `.value` property is deprecated).
  def __iadd__(self: V, other) -> V:
    raise NotImplementedError(
        'In-place operations are no longer supported for Variable.\n'
        'Use `variable[...] += x` instead.'
    )

  def __isub__(self: V, other) -> V:
    raise NotImplementedError(
        'In-place operations are no longer supported for Variable.\n'
        'Use `variable[...] -= x` instead.'
    )

  def __imul__(self: V, other) -> V:
    raise NotImplementedError(
        'In-place operations are no longer supported for Variable.\n'
        'Use `variable[...] *= x` instead.'
    )

  def __imatmul__(self: V, other) -> V:
    raise NotImplementedError(
        'In-place operations are no longer supported for Variable.\n'
        'Use `variable[...] @= x` instead.'
    )

  def __itruediv__(self: V, other) -> V:
    raise NotImplementedError(
        'In-place operations are no longer supported for Variable.\n'
        'Use `variable[...] /= x` instead.'
    )

  def __ifloordiv__(self: V, other) -> V:
    raise NotImplementedError(
        'In-place operations are no longer supported for Variable.\n'
        'Use `variable[...] //= x` instead.'
    )

  def __imod__(self: V, other) -> V:
    raise NotImplementedError(
        'In-place operations are no longer supported for Variable.\n'
        'Use `variable[...] %= x` instead.'
    )

  def __ipow__(self: V, other) -> V:
    raise NotImplementedError(
        'In-place operations are no longer supported for Variable.\n'
        'Use `variable[...] **= x` instead.'
    )

  def __ilshift__(self: V, other) -> V:
    raise NotImplementedError(
        'In-place operations are no longer supported for Variable.\n'
        'Use `variable[...] <<= x` instead.'
    )

  def __irshift__(self: V, other) -> V:
    raise NotImplementedError(
        'In-place operations are no longer supported for Variable.\n'
        'Use `variable[...] >>= x` instead.'
    )

  def __iand__(self: V, other) -> V:
    raise NotImplementedError(
        'In-place operations are no longer supported for Variable.\n'
        'Use `variable[...] &= x` instead.'
    )

  def __ixor__(self: V, other) -> V:
    raise NotImplementedError(
        'In-place operations are no longer supported for Variable.\n'
        'Use `variable[...] ^= x` instead.'
    )

  def __ior__(self: V, other) -> V:
    raise NotImplementedError(
        'In-place operations are no longer supported for Variable.\n'
        'Use `variable[...] |= x` instead.'
    )

  __neg__ = _variable_unary_operator('__neg__')
  __pos__ = _variable_unary_operator('__pos__')
  __abs__ = _variable_unary_operator('__abs__')
  __invert__ = _variable_unary_operator('__invert__')
  __complex__ = _variable_unary_operator('__complex__')
  __int__ = _variable_unary_operator('__int__')
  __float__ = _variable_unary_operator('__float__')
  __index__ = _variable_unary_operator('__index__')
  __trunc__ = _variable_unary_operator('__trunc__')
  __floor__ = _variable_unary_operator('__floor__')
  __ceil__ = _variable_unary_operator('__ceil__')

  def __round__(self, ndigits: int | None = None) -> A:
    """Forward ``round()`` to the value, preserving the no-``ndigits`` form.

    ``round(x)`` calls ``x.__round__()`` with no argument, which for Python
    and NumPy floats returns an integer; forwarding ``0`` instead would
    return a float.
    """
    value = self.get_value()
    if ndigits is None:
      return value.__round__()  # type: ignore
    return value.__round__(ndigits)  # type: ignore

  # --------------------------------------------

  def __init_subclass__(cls) -> None:
    # Keep subclasses slot-only and register each one as a JAX pytree.
    if '__slots__' not in vars(cls):
      cls.__slots__ = ()  # type: ignore[assignment]
    super().__init_subclass__()
    jax.tree_util.register_pytree_with_keys(
        cls,
        flatten_with_keys=_variable_flatten_with_keys,
        unflatten_func=partial(_variable_unflatten, cls),  # type: ignore
        flatten_func=_variable_flatten,
    )
def _variable_flatten(x: Variable[tp.Any]):
  """Flatten a Variable for JAX.

  Returns a one-element children tuple holding the raw value, and the
  metadata as a tuple of ``(key, value)`` pairs sorted by key, used as the
  pytree auxiliary data.
  """
  static = tuple(sorted(x._var_metadata.items()))
  return (x._raw_value,), static


def _variable_flatten_with_keys(x: Variable[tp.Any]):
  """Like ``_variable_flatten`` but labels the single child as ``.value``."""
  (raw,), static = _variable_flatten(x)
  keyed_child = (jtu.GetAttrKey('value'), raw)
  return (keyed_child,), static


def _variable_unflatten(
    cls: type[Variable[tp.Any]],
    static: tuple[tuple[str, tp.Any], ...],
    children: tuple[tp.Any],
):
  """Rebuild a ``cls`` instance from its raw value and sorted metadata pairs."""
  metadata = dict(static)
  return cls._new(children[0], metadata)


jax.tree_util.register_pytree_with_keys(
    Variable,
    flatten_with_keys=_variable_flatten_with_keys,
    unflatten_func=partial(_variable_unflatten, Variable),  # type: ignore
    flatten_func=_variable_flatten,
)

# Backward-compatible alias: Variable now plays the former VariableState role.
VariableState = Variable
class Param(Variable[A]):
  """The canonical learnable parameter.

  Every learnable parameter held by an NNX layer module uses the ``Param``
  :class:`Variable` type::

    >>> from flax import nnx
    >>> import jax, jax.numpy as jnp

    >>> layer = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
    >>> jax.tree.map(jnp.shape, nnx.state(layer))
    State({
      'bias': Param(
        value=(3,)
      ),
      'kernel': Param(
        value=(2, 3)
      )
    })
  """
class BatchStat(Variable[A]):
  """Running mean and variance statistics kept by :class:`BatchNorm`.

  These are not the learnable ``scale`` and ``bias`` parameters. They are
  the running-average statistics typically used for inference after
  training::

    >>> from flax import nnx
    >>> import jax, jax.numpy as jnp

    >>> layer = nnx.BatchNorm(3, rngs=nnx.Rngs(0))
    >>> jax.tree.map(jnp.shape, nnx.state(layer))
    State({
      'bias': Param(
        value=(3,)
      ),
      'mean': BatchStat(
        value=(3,)
      ),
      'scale': Param(
        value=(3,)
      ),
      'var': BatchStat(
        value=(3,)
      )
    })
  """
class Cache(Variable[A]):
  """Autoregressive decoding cache used by :class:`MultiHeadAttention`::

    >>> from flax import nnx
    >>> import jax, jax.numpy as jnp

    >>> layer = nnx.MultiHeadAttention(
    ...   num_heads=2,
    ...   in_features=3,
    ...   qkv_features=6,
    ...   out_features=6,
    ...   decode=True,
    ...   rngs=nnx.Rngs(0),
    ... )

    >>> layer.init_cache((1, 3))
    >>> jax.tree.map(jnp.shape, nnx.state(layer, nnx.Cache))
    State({
      'cache_index': Cache(
        value=()
      ),
      'cached_key': Cache(
        value=(1, 2, 3)
      ),
      'cached_value': Cache(
        value=(1, 2, 3)
      )
    })
  """
class Intermediate(Variable[A]):
  """:class:`Variable` type commonly used with :func:`Module.sow`::

    >>> from flax import nnx
    >>> import jax, jax.numpy as jnp

    >>> class Model(nnx.Module):
    ...   def __init__(self, rngs):
    ...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
    ...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
    ...   def __call__(self, x):
    ...     x = self.linear1(x)
    ...     self.sow(nnx.Intermediate, 'i', x)
    ...     x = self.linear2(x)
    ...     return x
    >>> model = Model(rngs=nnx.Rngs(0))

    >>> x = jnp.ones((1, 2))
    >>> y = model(x)
    >>> jax.tree.map(jnp.shape, nnx.state(model, nnx.Intermediate))
    State({
      'i': Intermediate(
        value=((1, 3),)
      )
    })
  """
class Perturbation(Intermediate[A]):
  """:class:`Variable` type commonly used with :func:`Module.perturb`::

    >>> from flax import nnx
    >>> import jax, jax.numpy as jnp

    >>> class Model(nnx.Module):
    ...   def __init__(self, rngs):
    ...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
    ...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
    ...   def __call__(self, x):
    ...     x = self.linear1(x)
    ...     x = self.perturb('i', x)
    ...     x = self.linear2(x)
    ...     return x
    >>> model = Model(rngs=nnx.Rngs(0))

    >>> x = jnp.ones((1, 2))
    >>> y = model(x)
    >>> jax.tree.map(jnp.shape, nnx.state(model, nnx.Perturbation))
    State({
      'i': Perturbation(
        value=(1, 3)
      )
    })
  """
def _as_hook_tuple(hooks: tp.Any) -> tuple[tp.Any, ...]:
  """Normalize a single hook, a sequence of hooks, or nothing to a tuple."""
  if not hooks:
    return ()
  if callable(hooks):
    return (hooks,)
  return tuple(hooks)


def with_metadata(
    initializer: F,
    set_value_hooks: tp.Union[
        SetValueHook[A], tp.Sequence[SetValueHook[A]]
    ] = (),
    get_value_hooks: tp.Union[
        GetValueHook[A], tp.Sequence[GetValueHook[A]]
    ] = (),
    create_value_hooks: tp.Union[
        CreateValueHook[A], tp.Sequence[CreateValueHook[A]]
    ] = (),
    add_axis_hooks: tp.Union[
        AddAxisHook[Variable[A]], tp.Sequence[AddAxisHook[Variable[A]]]
    ] = (),
    remove_axis_hooks: tp.Union[
        RemoveAxisHook[Variable[A]],
        tp.Sequence[RemoveAxisHook[Variable[A]]],
    ] = (),
    **metadata: tp.Any,
) -> F:
  """Wrap ``initializer`` so its result is returned inside ``VariableMetadata``.

  When the wrapped result is passed to a :class:`Variable` constructor, the
  ``metadata`` entries are merged into that Variable's metadata.

  Args:
    initializer: the callable producing the raw value.
    set_value_hooks: a hook or sequence of hooks, stored as a tuple.
    get_value_hooks: a hook or sequence of hooks, stored as a tuple.
    create_value_hooks: a hook or sequence of hooks, stored as a tuple.
    add_axis_hooks: a hook or sequence of hooks, stored as a tuple.
    remove_axis_hooks: a hook or sequence of hooks, stored as a tuple.
    **metadata: extra metadata attached to the produced value.

  Returns:
    A function with ``initializer``'s signature (via ``functools.wraps``)
    that forwards all positional and keyword arguments to ``initializer``.
  """
  set_value_hooks = _as_hook_tuple(set_value_hooks)
  get_value_hooks = _as_hook_tuple(get_value_hooks)
  create_value_hooks = _as_hook_tuple(create_value_hooks)
  add_axis_hooks = _as_hook_tuple(add_axis_hooks)
  remove_axis_hooks = _as_hook_tuple(remove_axis_hooks)

  @functools.wraps(initializer)
  def wrapper(*args, **kwargs):
    # Keyword arguments (e.g. ``dtype=``) are forwarded as well.
    return VariableMetadata(
        initializer(*args, **kwargs),
        set_value_hooks=set_value_hooks,
        get_value_hooks=get_value_hooks,
        create_value_hooks=create_value_hooks,
        add_axis_hooks=add_axis_hooks,
        remove_axis_hooks=remove_axis_hooks,
        metadata=metadata,
    )

  return wrapper  # type: ignore
# ---------------------------------------------------
# Variable type/class <-> string name mapping
# ---------------------------------------------------

# Registry from Linen-style collection name (e.g. 'params') to NNX Variable
# type. Assumption: the mapping is 1-1 and unique.
VariableTypeCache: dict[str, tp.Type[Variable[tp.Any]]] = dict()
def variable_type_from_name(
    name: str,
    /,
    *,
    base: type[Variable[tp.Any]] = Variable,
    allow_register: bool = False,
) -> tp.Type[Variable[tp.Any]]:
  """Given a Linen-style collection name, get or create its NNX Variable class.

  Args:
    name: the collection name to look up.
    base: base class for a newly created type.
    allow_register: if true, an unknown ``name`` gets a fresh subclass of
      ``base`` named ``name``, which is stored in ``VariableTypeCache``.

  Raises:
    ValueError: if ``name`` is unknown and ``allow_register`` is false.
  """
  if name in VariableTypeCache:
    return VariableTypeCache[name]
  if not allow_register:
    raise ValueError(
        f'Name {name} is not registered in the registry. '
        'To register a new name, use register_variable_name() '
        'or set allow_register=True.'
    )
  created = type(name, (base,), {})
  VariableTypeCache[name] = created
  return created
def variable_name_from_type(
    typ: tp.Type[Variable[tp.Any]], /, *, allow_register: bool = False
) -> str:
  """Given an NNX Variable type, get its Linen-style collection name.

  Should output the exact inverse of `variable_type_from_name()`.

  Args:
    typ: the Variable type to look up.
    allow_register: if true, an unregistered ``typ`` is registered under its
      ``__name__``.

  Raises:
    ValueError: if ``typ`` is unregistered and ``allow_register`` is false,
      or if its ``__name__`` is already mapped to a different type.
  """
  for name, registered in VariableTypeCache.items():
    if typ == registered:
      return name

  if not allow_register:
    raise ValueError(
        f'Type {typ} is not registered in the registry. '
        'To register a new type, use register_variable_name() '
        'or set allow_register=True.'
    )
  name = typ.__name__
  if name in VariableTypeCache:
    # These pieces were previously plain strings, so the placeholders were
    # printed literally instead of being interpolated.
    raise ValueError(
        f'Name {name} is already registered in the registry as '
        f'{VariableTypeCache[name]}. '
        f'It cannot be linked with this type {typ}.'
    )
  register_variable_name(name, typ)
  return name
# Direct form: register_variable_name(name, typ) -> typ
@tp.overload
def register_variable_name(
    name: str,
    typ: type[Variable[tp.Any]],
    *,
    overwrite: bool = False,
) -> type[Variable[tp.Any]]:
  ...


# Decorator form: @register_variable_name(name)
@tp.overload
def register_variable_name(
    name: str,
    *,
    overwrite: bool = False,
) -> tp.Callable[[type[Variable[tp.Any]]], type[Variable[tp.Any]]]:
  ...
def register_variable_name(
    name: str,
    typ: type[Variable[A]] | Missing = MISSING,
    *,
    overwrite=False,
) -> type[Variable[A]] | tp.Callable[[type[Variable[A]]], type[Variable[A]]]:
  """Register a pair of Linen collection name and its NNX type.

  Usable directly (``register_variable_name('params', Param)``) or, when
  ``typ`` is omitted, as a class decorator factory
  (``@register_variable_name('params')``).

  Raises:
    ValueError: if ``name`` is already mapped and ``overwrite`` is false.
  """
  if isinstance(typ, Missing):
    # Decorator form: registration happens once the class is supplied.
    return partial(register_variable_name, name, overwrite=overwrite)

  typ = tp.cast(type[Variable[A]], typ)
  already_mapped = name in VariableTypeCache
  if already_mapped and not overwrite:
    raise ValueError(
        f'Name {name} already mapped to type {VariableTypeCache[name]}. '
        'To overwrite, call register_variable_name() with `overwrite=True`.'
    )
  VariableTypeCache[name] = typ
  return typ
# Register the Linen collection names of the built-in Variable types.
for _collection, _var_type in (
    ('params', Param),
    ('batch_stats', BatchStat),
    ('cache', Cache),
    ('intermediates', Intermediate),
    ('perturbations', Perturbation),
):
  register_variable_name(_collection, _var_type)
# Don't leak the loop variables into the module namespace.
del _collection, _var_type