# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pytype: skip-file
from __future__ import annotations
import dataclasses
import functools
from functools import partial
import itertools as it
import threading
import typing as tp
from typing import Any
import warnings
from flax import config
from flax import errors
from flax.core import spmd as core_spmd
from flax.nnx import reprlib, tracers, visualization
from flax.typing import BaseConfigContext, MISSING, Missing, SizeBytes
import jax
from jax._src.state.types import AbstractRef
import jax.experimental
from jax.experimental import hijax as hjx
import jax.tree_util as jtu
import treescope # type: ignore[import-untyped]
A = tp.TypeVar('A')
B = tp.TypeVar('B')
C = tp.TypeVar('C')
F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])
P = tp.TypeVar('P', bound=property)
V = tp.TypeVar('V', bound='Variable[Any]')
GetValueHook = tp.Callable[['Variable[A]', A], A]
SetValueHook = tp.Callable[['Variable[A]', A], A]
CreateValueHook = tp.Callable[['Variable[A]', A], A]
AxisName = str
AxisIndex = int
AddAxisHook = tp.Callable[[V, AxisIndex, AxisName | None], None]
RemoveAxisHook = tp.Callable[[V, AxisIndex, AxisName | None], None]
# JAX array refs were renamed a few times between JAX v0.7.0 and v0.8.0.
# The following ensures we avoid an ImportError or DeprecationWarning.
if hasattr(jax, 'new_ref') and hasattr(jax, 'Ref'):
# JAX v0.7.2 or newer
from jax import Ref
elif hasattr(jax, 'array_ref') and hasattr(jax, 'ArrayRef'):
# JAX v0.7.1
from jax import ArrayRef as Ref # type: ignore[import-untyped,no-redef]
else:
# JAX v0.7.0 or older
from jax.experimental import MutableArray as Ref # type: ignore[no-redef]
@dataclasses.dataclass
class VariableContext(threading.local):
  """Thread-local stacks holding overrides for Variable defaults.

  Subclassing ``threading.local`` gives every thread its own independent
  copy of these stacks. In each stack the last element is the active value.
  """

  # Overrides of the `hijax` default; written by `var_defaults(hijax=...)`.
  variable_hijax_stack: list[bool] = dataclasses.field(default_factory=list)
  # Overrides of the `ref` default; written by `var_defaults(ref=...)`.
  variable_ref_stack: list[bool] = dataclasses.field(default_factory=list)
  # Stack exposed to `use_eager_sharding` through its `get_stack` classmethod.
  eager_shard_stack: list[bool] = dataclasses.field(default_factory=list)


# Module-level singleton read and written by `var_defaults`,
# `VarDefaultsContext` and `use_eager_sharding`.
VARIABLE_CONTEXT = VariableContext()
class use_eager_sharding(BaseConfigContext):
  """Sets whether Variables use eager sharding by default.

  Example usage::

    >>> from flax import nnx
    >>> # Use eager sharding by default
    >>> nnx.use_eager_sharding(True)
    <...>
    >>> # Variable will now use eager sharding
    >>> nnx.using_eager_sharding()
    True

  The same call can be used as a context manager to change the default only
  for the duration of a block of code::

    >>> nnx.use_eager_sharding(False)
    <...>
    >>> with nnx.use_eager_sharding(True):
    ...   nnx.using_eager_sharding()
    True
    >>> # it will reset outside
    >>> v = nnx.Variable(jax.numpy.ones((2, 3)))
    >>> nnx.using_eager_sharding()
    False

  Args:
    value: whether Variables should use eager sharding by default.

  Returns:
    A context manager that restores the previous setting on exit.
  """

  @classmethod
  def get_default(cls):
    # Value taken from the global Flax config; presumably consulted by
    # BaseConfigContext when no override has been pushed.
    return config.flax_always_shard_variable

  @classmethod
  def get_stack(cls):
    # Thread-local stack of eager-sharding overrides.
    return VARIABLE_CONTEXT.eager_shard_stack
def using_eager_sharding() -> bool:
  """Reports the current eager-sharding default for Variables.

  Example::

    >>> from flax import nnx
    >>> nnx.use_eager_sharding(True)
    <...>
    >>> nnx.using_eager_sharding()
    True
    >>> nnx.use_eager_sharding(False)
    <...>
    >>> nnx.using_eager_sharding()
    False

  Returns:
    True if Variables currently use eager sharding by default.
  """
  current = use_eager_sharding.current_value()
  return current
@dataclasses.dataclass(frozen=True)
class VarDefaults(tp.Mapping[str, tp.Any]):
  """Immutable snapshot of the Variable defaults returned by ``var_defaults()``.

  Also acts as a read-only mapping whose keys are the field names
  (``'hijax'``, ``'ref'``), so e.g. ``dict(var_defaults())`` works.
  """

  hijax: bool
  ref: bool

  def _keys(self) -> tuple[str, ...]:
    # Field names in declaration order; these are the only valid keys.
    return tuple(f.name for f in dataclasses.fields(self))

  def __getitem__(self, key: str) -> tp.Any:
    # The Mapping protocol requires KeyError for unknown keys: the mixin
    # `__contains__` and `get` only catch KeyError. A bare getattr would
    # raise AttributeError and would also expose unrelated attributes
    # such as '__class__'.
    if key not in self._keys():
      raise KeyError(key)
    return getattr(self, key)

  def __iter__(self) -> tp.Iterator[str]:
    # Iterate field names directly instead of deep-copying via asdict().
    return iter(self._keys())

  def __len__(self) -> int:
    return len(dataclasses.fields(self))
@tp.overload
def var_defaults() -> VarDefaults: ...
@tp.overload
def var_defaults(
    *, hijax: bool | None = None, ref: bool | None = None
) -> VarDefaultsContext: ...
def var_defaults(
    *, hijax: bool | None = None, ref: bool | None = None
) -> VarDefaultsContext | VarDefaults:
  """Reads or overrides the default ``hijax`` / ``ref`` settings for Variables.

  With no arguments, returns a :class:`VarDefaults` snapshot of the active
  values. ``hijax`` falls back to ``config.flax_hijax_variable`` and ``ref``
  falls back to ``False`` when no override is set.

  With ``hijax`` and/or ``ref``, the top of the corresponding thread-local
  stack is overwritten immediately (pushed if the stack is empty). The
  returned :class:`VarDefaultsContext` records the replaced value so the
  change can be scoped with ``with`` or applied as a decorator.
  """
  if hijax is None and ref is None:
    hijax_stack = VARIABLE_CONTEXT.variable_hijax_stack
    ref_stack = VARIABLE_CONTEXT.variable_ref_stack
    return VarDefaults(
        hijax=hijax_stack[-1] if hijax_stack else config.flax_hijax_variable,
        ref=ref_stack[-1] if ref_stack else False,
    )

  def _replace_top(stack, new_value):
    # Overwrite the active value, or push one if none is active yet, and
    # return the value that was replaced (None when the stack was empty).
    if not stack:
      stack.append(new_value)
      return None
    replaced = stack[-1]
    stack[-1] = new_value
    return replaced

  hijax_prev = (
      None
      if hijax is None
      else _replace_top(VARIABLE_CONTEXT.variable_hijax_stack, hijax)
  )
  ref_prev = (
      None
      if ref is None
      else _replace_top(VARIABLE_CONTEXT.variable_ref_stack, ref)
  )
  return VarDefaultsContext(
      hijax_prev=hijax_prev,
      hijax_new=hijax,
      ref_prev=ref_prev,
      ref_new=ref,
  )
class VarDefaultsContext:
  """Scopes a change already applied by :func:`var_defaults`.

  By the time this object exists, ``var_defaults`` has already written the new
  values to the top of the thread-local stacks. This object remembers the
  replaced values so the change can be undone at the end of a ``with`` block,
  or confined to each call of a decorated function.
  """

  def __init__(
      self,
      *,
      hijax_prev: bool | None,
      hijax_new: bool | None,
      ref_prev: bool | None,
      ref_new: bool | None,
  ):
    self.hijax_prev = hijax_prev
    self.hijax_new = hijax_new
    self.ref_prev = ref_prev
    self.ref_new = ref_new

  def _entries(self):
    # (stack, previous value, new value) for each managed setting, in a fixed
    # hijax-then-ref order.
    return (
        (VARIABLE_CONTEXT.variable_hijax_stack, self.hijax_prev, self.hijax_new),
        (VARIABLE_CONTEXT.variable_ref_stack, self.ref_prev, self.ref_new),
    )

  def __enter__(self):
    # Tuck the previous value beneath the (already active) new one so that
    # popping on exit brings it back.
    for stack, previous, new in self._entries():
      if new is not None and previous is not None:
        stack.insert(-1, previous)

  def __exit__(self, exc_type, exc_value, traceback):
    for stack, _, new in self._entries():
      if new is not None:
        stack.pop()

  def __call__(self, f: F) -> F:
    # Decorator usage: revert the eager change made by var_defaults() and
    # instead apply the new values only while `f` is running.
    for stack, previous, new in self._entries():
      if new is not None:
        stack.pop()
        if previous is not None:
          stack.append(previous)

    @functools.wraps(f)
    def var_defaults_wrapper(*args, **kwargs):
      active = [
          (stack, new) for stack, _, new in self._entries() if new is not None
      ]
      for stack, new in active:
        stack.append(new)
      try:
        return f(*args, **kwargs)
      finally:
        for stack, _ in active:
          stack.pop()

    return var_defaults_wrapper  # type: ignore[return-value]
def is_array_ref(x) -> tp.TypeGuard[Ref]:
  """Returns True when ``x`` is a JAX array reference.

  ``x`` must itself be a ``jax.Array``, ``AbstractRef`` or ``Ref`` instance,
  and its JAX type (``jax.typeof(x)``) must be an ``AbstractRef`` or ``Ref``.
  """
  if not isinstance(x, jax.Array | AbstractRef | Ref):
    return False
  return isinstance(jax.typeof(x), AbstractRef | Ref)
# Readability-only aliases for annotations; both are simply `Any`.
PyTreeDef = Leaf = tp.Any
# ---------------------------------
# hijax
# ---------------------------------
@dataclasses.dataclass(frozen=True)
class VariableQDD:
  """Quasi-dynamic data (QDD) describing the current contents of a Variable.

  Holds the abstract values of the flattened leaves, the treedef that
  rebuilds the Variable from them, and the concrete ``Variable`` subclass.
  """

  leaf_avals: tuple[hjx.AbstractValue, ...]
  treedef: PyTreeDef
  var_type: type[Variable[Any]]

  def _map_leaves(self, fn):
    # New QDD with `fn` applied to every leaf aval; treedef/type are kept.
    return VariableQDD(
        tuple(fn(aval) for aval in self.leaf_avals), self.treedef, self.var_type
    )

  def to_tangent_qdd(self):
    return self._map_leaves(lambda aval: aval.to_tangent_aval())

  def normalize(self):
    return self._map_leaves(lambda aval: aval.normalize())
class VariableEffect(jax.core.Effect):
  """Effect reported by the new/get/set variable primitives for QDD Variables."""
  ...


# Singleton instance returned from the primitives' `abstract_eval` methods.
variable_effect = VariableEffect()
# Register the effect type with hijax's allowed-control-flow effects
# (presumably so cond/scan/while may contain these primitives).
hjx.control_flow_allowed_effects.add_type(VariableEffect)
def _bind_new_variable(
    *leaves, treedef, var_type, has_qdd, ref
) -> HijaxVariable:
  """Binds ``new_variable_p`` after materializing symbolic-zero leaves.

  Each leaf goes through ``hjx.instantiate_zeros`` first, so Zero tangents
  become concrete values before they are stored in the new Variable.
  """
  dense_leaves = [hjx.instantiate_zeros(leaf) for leaf in leaves]
  params = dict(treedef=treedef, var_type=var_type, has_qdd=has_qdd, ref=ref)
  return new_variable_p.bind(*dense_leaves, **params)
def _new_hijax_from_variable(variable: Variable) -> HijaxVariable:
  """Converts a regular ``Variable`` into a ``HijaxVariable``.

  The Variable is flattened into leaves and a treedef. Variables with
  ``ref=True`` are created with ``has_qdd=False``, while all others get
  ``has_qdd=True``.
  """
  is_ref = variable.ref
  leaves, treedef = jax.tree.flatten(variable)
  return _bind_new_variable(
      *leaves,
      treedef=treedef,
      var_type=type(variable),
      has_qdd=not is_ref,
      ref=is_ref,
  )
class NewVariable(hjx.HiPrimitive):
  """Primitive that creates a ``HijaxVariable`` from flattened leaves.

  Params: ``treedef`` (rebuilds the Variable from the leaves), ``var_type``
  (the ``Variable`` subclass), ``has_qdd`` (whether contents are tracked as
  mutable QDD; only such variables can be updated by ``set_variable_p``) and
  ``ref``.
  """

  def is_high(self, *leaves, treedef, var_type, has_qdd, ref) -> bool:
    return True  # type: ignore

  def impl(self, *leaves, treedef, var_type, has_qdd, ref):
    # Eager path: build the concrete HijaxVariable directly.
    return HijaxVariable._new(
        leaves, treedef, var_type, has_qdd, ref=ref
    )

  def abstract_eval(self, *leaves, treedef, var_type, has_qdd, ref):
    # `leaves` are abstract values here and are stored on the aval.
    aval = AbstractVariable(
        var_type, treedef, leaves, has_qdd, ref=ref
    )
    if has_qdd:
      # Mutable variables pair the aval with an initial QDD and report the
      # variable effect.
      qdd = VariableQDD(tuple(leaves), treedef, var_type)
      aval_qdd = hjx.AvalQDD(aval, qdd)  # type: ignore
      return aval_qdd, {variable_effect}
    else:
      return aval, set()

  def to_lojax(self, *leaves, treedef, var_type, has_qdd, ref):
    return HijaxVariable._new(leaves, treedef, var_type, has_qdd, ref=ref)

  def jvp(_, primals, tangents, *, treedef, var_type, has_qdd, ref):
    # Only supported without QDD: primals and tangents are each wrapped in
    # their own new variable with identical params.
    if has_qdd:
      raise NotImplementedError(
          "jvp not implemented for 'new_variable' with QDD"
      )
    primal_hijax_var = _bind_new_variable(
        *primals, treedef=treedef, var_type=var_type, has_qdd=has_qdd, ref=ref
    )
    tangent_hijax_var = _bind_new_variable(
        *tangents, treedef=treedef, var_type=var_type, has_qdd=has_qdd, ref=ref
    )
    return primal_hijax_var, tangent_hijax_var

  def transpose(
      _, out_var: HijaxVariable, *input_leaves, treedef, var_type, has_qdd, ref
  ):
    # Only supported without QDD: the leaf cotangents are read back out of
    # the output-variable cotangent with get_variable_p, using the input
    # leaves' avals (undefined primals carry their aval in `.aval`).
    if has_qdd:
      raise NotImplementedError(
          "transpose not implemented for 'new_variable' with QDD"
      )
    avals = tuple(
        map(
            lambda x: x.aval if hjx.is_undefined_primal(x) else jax.typeof(x),
            input_leaves,
        )
    )
    leaves_dot = get_variable_p.bind(
        out_var,
        treedef=treedef,
        avals=avals,
        var_type=var_type,
        has_qdd=has_qdd,
    )
    return leaves_dot


new_variable_p = NewVariable(f'new_variable')
def _set_hijax_state(hijax_var, variable: Variable):
  """Writes the flattened contents of ``variable`` into ``hijax_var``.

  Binds ``set_variable_p``, which raises ``ImmutableVariableError`` when
  ``hijax_var`` was created with ``has_qdd=False``.
  """
  flat_leaves, flat_treedef = jax.tree.flatten(variable)
  variable_cls = type(variable)
  set_variable_p.bind(
      hijax_var, *flat_leaves, treedef=flat_treedef, var_type=variable_cls
  )
class SetVariable(hjx.HiPrimitive):
  """Primitive that overwrites the contents of a ``HijaxVariable``.

  Operands: the hijax variable followed by its new flattened leaves. Params:
  ``treedef`` for those leaves and ``var_type``, which must be the variable's
  own ``_var_type``. Variables with ``has_qdd=False`` are immutable and
  raise ``errors.ImmutableVariableError``. Produces no results.
  """

  multiple_results = True

  def is_high(_, *leaf_avals, treedef, var_type) -> bool:
    return True  # type: ignore

  # TODO: upstream this to Box
  def impl(_, hijax_var: HijaxVariable, *leaves, treedef, var_type):
    if not hijax_var.has_qdd:
      raise errors.ImmutableVariableError(
          "Trying to update Variable with 'has_qdd=False'."
      )
    assert var_type is hijax_var._var_type
    # Bypass HijaxVariable's forwarding __setattr__.
    object.__setattr__(hijax_var, '_leaves', leaves)
    object.__setattr__(hijax_var, '_treedef', treedef)
    return []

  def abstract_eval(
      _, aval_mutable_qdd: hjx.AvalMutableQDD, *leaf_avals, treedef, var_type
  ):
    hijax_var: AbstractVariable = aval_mutable_qdd.aval  # type: ignore
    assert isinstance(hijax_var, AbstractVariable)
    if not hijax_var.has_qdd:
      raise errors.ImmutableVariableError(
          "Trying to update Variable with 'has_qdd=False'."
      )
    assert var_type is hijax_var._var_type
    # Record the new contents in the mutable QDD so later reads see them.
    aval_mutable_qdd.mutable_qdd.update(
        VariableQDD(leaf_avals, treedef, var_type)
    )
    effects = {variable_effect} if hijax_var.has_qdd else set()
    return [], effects  # TODO better typechecking...

  def to_lojax(_, hijax_var: HijaxVariable, *leaves, treedef, var_type):
    if not hijax_var.has_qdd:
      raise errors.ImmutableVariableError(
          "Trying to update Variable with 'has_qdd=False'."
      )
    assert var_type is hijax_var._var_type
    object.__setattr__(hijax_var, '_leaves', leaves)
    object.__setattr__(hijax_var, '_treedef', treedef)
    return []

  def jvp(_, primals, tangents, *, treedef, var_type):
    variable: HijaxVariable
    variable, *vals = primals
    variable_dot: HijaxVariable
    variable_dot, *val_dots = tangents
    # NOTE(review): assumes the tangent variable exposes `_raw_value`;
    # confirm against how tangent hijax variables are represented.
    if type(variable_dot._raw_value) is hjx.Zero:
      raise Exception(
          "can't differentiate Variable._set operation, "
          'did you forget jax.lax.stop_gradient?'
      )
    # Forward the `var_type` param. `type(variable)` is the hijax/tracer
    # class, not the Variable subclass, and would trip the
    # `var_type is hijax_var._var_type` assertion in impl/abstract_eval.
    set_variable_p.bind(
        variable, *vals, treedef=treedef, var_type=var_type
    )
    set_variable_p.bind(
        variable_dot, *val_dots, treedef=treedef, var_type=var_type
    )
    return [], []

  def transpose(_, *args, treedef, var_type):
    raise NotImplementedError('transpose not implemented for SetHijaxVariable')


set_variable_p = SetVariable(f'set_variable')
def _get_hijax_state(hijax_var: HijaxVariable | AbstractVariable) -> Variable:
  """Reconstructs the ``Variable`` currently held by ``hijax_var``.

  For QDD variables, the treedef and leaf avals come from the current QDD.
  Otherwise they are stored on ``hijax_var`` itself: the leaves are already
  avals on tracers/abstract variables, and are converted with
  ``jax.typeof`` for concrete values. The leaves are then read with
  ``get_variable_p`` and unflattened.
  """
  if hijax_var.has_qdd:
    qdd: VariableQDD = jax.experimental.cur_qdd(hijax_var)
    treedef = qdd.treedef
    leaf_avals = tuple(qdd.leaf_avals)
  else:
    assert hijax_var._treedef is not None
    assert hijax_var._leaves is not None
    treedef = hijax_var._treedef
    if isinstance(hijax_var, (jax.core.Tracer, AbstractVariable)):
      leaf_avals = hijax_var._leaves
    else:
      leaf_avals = tuple(map(jax.typeof, hijax_var._leaves))
  leaf_vals = get_variable_p.bind(
      hijax_var,
      treedef=treedef,
      avals=leaf_avals,
      var_type=hijax_var._var_type,
      has_qdd=hijax_var.has_qdd,
  )
  return jax.tree.unflatten(treedef, leaf_vals)
class GetVariable(hjx.HiPrimitive):
  """Primitive that reads the flattened leaves out of a ``HijaxVariable``.

  Params: ``treedef`` and ``avals`` describe the leaves being read,
  ``var_type`` is the ``Variable`` subclass and ``has_qdd`` controls whether
  ``variable_effect`` is reported. Returns one result per leaf.
  """

  multiple_results = True

  def impl(
      self, hijax_var: HijaxVariable, *, treedef, avals, var_type, has_qdd
  ):
    # Eager path: the leaves are stored directly on the object.
    return hijax_var._leaves

  def abstract_eval(self, abstract_var, *, treedef, avals, var_type, has_qdd):
    # Output avals are exactly the requested leaf avals.
    if has_qdd:
      return avals, {variable_effect}
    else:
      return avals, set()

  def to_lojax(
      _, hijax_var: HijaxVariable, *, treedef, avals, var_type, has_qdd
  ):
    return hijax_var._leaves

  def jvp(_, primals, tangents, *, treedef, avals, var_type, has_qdd):
    # Only supported without QDD: read primal leaves from the primal variable
    # and tangent leaves (with tangent avals) from the tangent variable.
    if has_qdd:
      raise NotImplementedError(
          "jvp not implemented for 'get_variable' with QDD"
      )
    (hijax_var,), (hijax_var_dot,) = primals, tangents
    return (
        get_variable_p.bind(
            hijax_var,
            treedef=treedef,
            avals=avals,
            var_type=var_type,
            has_qdd=has_qdd,
        ),
        get_variable_p.bind(
            hijax_var_dot,
            treedef=treedef,
            avals=tuple(a.to_tangent_aval() for a in avals),
            var_type=var_type,
            has_qdd=has_qdd,
        ),
    )

  def transpose(_, out, hijax_var, *, treedef, avals, var_type, has_qdd):
    # Only supported without QDD: the cotangent for the variable is a new
    # variable built from the leaf cotangents `out`, reusing the input
    # variable's treedef and `ref` flag.
    if has_qdd:
      raise NotImplementedError(
          "transpose not implemented for 'get_variable' with QDD"
      )
    abstract_var: AbstractVariable = (
        hijax_var.aval
        if hjx.is_undefined_primal(hijax_var)
        else jax.typeof(hijax_var)
    )
    hijax_var_dot = _bind_new_variable(
        *out,
        treedef=abstract_var._treedef,
        var_type=var_type,
        has_qdd=has_qdd,
        ref=abstract_var.ref,
    )
    return (hijax_var_dot,)


get_variable_p = GetVariable(f'get_variable')
# ---------------------------------
# HijaxVariable
# ---------------------------------
def _variable_has_changed(old: Variable, new: Variable) -> bool:
old_structure = jax.tree.structure(old)
new_structure = jax.tree.structure(new)
if old_structure != new_structure: # type: ignore[operator]
return True
old_leaves = jax.tree.leaves(old)
new_leaves = jax.tree.leaves(new)
return any(o is not n for o, n in zip(old_leaves, new_leaves))
def _as_hijax_property(name: str, *, get: bool, set: bool) -> property:
"""Creates a property that operates on the hijax type."""
def _getter_wrapper(hijax_var):
variable = _get_hijax_state(hijax_var)
old_state = jax.tree.map(lambda x: x, variable)
out = getattr(variable, name)
if _variable_has_changed(old_state, variable):
_set_hijax_state(hijax_var, variable)
return out
def _setter_wrapper(hijax_var, value):
variable = _get_hijax_state(hijax_var)
setattr(variable, name, value)
_set_hijax_state(hijax_var, variable)
_hijax_property = property(
fget=_getter_wrapper if get else None,
fset=_setter_wrapper if set else None,
)
return _hijax_property # type: ignore[return]
def _as_aval_property(p: property) -> hjx.aval_property:
"""Wraps a property `p` operate on the aval type."""
_aval_property = hjx.aval_property(fget=p.fget)
return _aval_property # type: ignore[return]
def _as_hijax_attribute(name: str) -> property:
"""Creates a property that operates on the hijax type."""
def _getter_wrapper(hijax_var):
variable = _get_hijax_state(hijax_var)
old_state = jax.tree.map(lambda x: x, variable)
out = getattr(variable, name)
if _variable_has_changed(old_state, variable):
_set_hijax_state(hijax_var, variable)
return out
_getter_wrapper.__name__ = name
_hijax_property = property(fget=_getter_wrapper)
return _hijax_property # type: ignore[return]
def _as_hijax_method(name: str) -> tp.Any:
"""Creates a method that operates on the hijax type."""
def hijax_method_wrapper(hijax_var, *args, **kwargs):
variable = _get_hijax_state(hijax_var)
old_state = jax.tree.map(lambda x: x, variable)
method = getattr(variable, name)
out = method(*args, **kwargs)
if _variable_has_changed(old_state, variable):
_set_hijax_state(hijax_var, variable)
return out
hijax_method_wrapper.__name__ = name
return hijax_method_wrapper
def _as_tracer_method(name: str):
def op(self, hijax_var, *args, **kwargs):
variable = _get_hijax_state(hijax_var)
old_state = jax.tree.map(lambda x: x, variable)
out = getattr(variable, name)(*args, **kwargs)
if _variable_has_changed(old_state, variable):
_set_hijax_state(hijax_var, variable)
return out
op.__name__ = name
return op
def _not_an_attribute_property(name: str):
def _op(self):
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'"
)
return property(_op)
class HijaxVariableMeta(type):
  """Metaclass that widens ``isinstance`` checks for ``HijaxVariable``.

  Besides real instances, tracers whose JAX type is an ``AbstractVariable``
  are also accepted.
  """

  def __instancecheck__(self, instance):
    if super().__instancecheck__(instance):
      return True
    return isinstance(instance, jax.core.Tracer) and isinstance(
        jax.typeof(instance), AbstractVariable
    )
class HijaxVariable(
    tp.Generic[A], reprlib.Representable, metaclass=HijaxVariableMeta
):  # type: ignore
  """Hijax (JAX-traceable) counterpart of :class:`Variable`.

  Stores a Variable flattened into ``_leaves`` + ``_treedef`` together with
  its concrete ``_var_type``. Most of the ``Variable`` API is forwarded. A
  forwarded member rebuilds the Variable with ``_get_hijax_state``, runs the
  corresponding Variable member, and writes the state back with
  ``_set_hijax_state`` when the Variable changed.
  """

  __slots__ = ('_treedef', '_leaves', '_var_type', 'has_qdd', '_ref')
  # Treedef that unflattens `_leaves` back into a `_var_type` instance.
  _treedef: PyTreeDef
  # Flattened leaves of the underlying Variable.
  _leaves: tuple[Leaf, ...]
  # Concrete Variable subclass this hijax value stands for.
  _var_type: type[Variable[tp.Any]]
  # Whether contents are tracked as mutable QDD; set_variable_p refuses to
  # update variables where this is False.
  has_qdd: bool
  _ref: bool

  @classmethod
  def _new(
      cls,
      leaves: tuple[Leaf, ...],
      treedef: PyTreeDef,
      var_type: type[Variable[A]],
      has_qdd: bool,
      *,
      ref: bool = False,
  ):
    """Allocates a HijaxVariable without running ``__init__``."""
    hijax_var = object.__new__(cls)
    # object.__setattr__ bypasses the forwarding __setattr__ defined below.
    object.__setattr__(hijax_var, '_treedef', treedef)
    object.__setattr__(hijax_var, '_leaves', leaves)
    object.__setattr__(hijax_var, '_var_type', var_type)
    object.__setattr__(hijax_var, 'has_qdd', has_qdd)
    object.__setattr__(hijax_var, '_ref', ref)
    return hijax_var

  __init__ = _as_hijax_method('__init__')

  @property
  def value(self) -> A:
    raise NotImplementedError(
        'HijaxVariable.value property is not implemented. For Variable[Array] instances use:\n\n'
        '  variable[...]\n\n'
        'For other Variable types use:\n\n'
        '  variable.get_value()\n'
    )

  @value.setter
  def value(self, new_value: A):
    raise NotImplementedError(
        'HijaxVariable.value property is not implemented. For Variable[Array] instances use:\n\n'
        '  variable[...] = new_value\n\n'
        'For other Variable types use:\n\n'
        '  variable.set_value(new_value)\n'
    )

  @property
  def var_type(self) -> type[Variable[A]]:
    return self._var_type

  _trace_state = _as_hijax_property('_trace_state', get=True, set=False)
  _can_update = _as_hijax_property('_can_update', get=True, set=False)
  _check_can_update = _as_hijax_method('_check_can_update')
  __getattr__ = _as_hijax_method('__getattr__')
  __setattr__ = _as_hijax_method('__setattr__')
  __delattr__ = _as_hijax_method('__delattr__')
  type = _as_hijax_property('type', get=True, set=False)
  hijax = _as_hijax_property('hijax', get=True, set=False)

  @property
  def ref(self) -> bool:
    return self._ref

  get_metadata = _as_hijax_method('get_metadata')
  set_metadata = _as_hijax_method('set_metadata')

  def copy_from(self, other: Variable[A] | HijaxVariable[A]) -> None:
    if isinstance(other, HijaxVariable):
      other = _get_hijax_state(other)
    variable = _get_hijax_state(self)
    variable.copy_from(other)  # type: ignore[arg-type]
    _set_hijax_state(self, variable)

  def update_from_state(self, variable_state: Variable[A] | HijaxVariable[A]):
    if isinstance(variable_state, HijaxVariable):
      variable_state = _get_hijax_state(variable_state)
    variable = _get_hijax_state(self)
    variable.update_from_state(variable_state)  # type: ignore[arg-type]
    _set_hijax_state(self, variable)

  get_raw_value = _as_hijax_method('get_raw_value')
  set_raw_value = _as_hijax_method('set_raw_value')
  set_value = _as_hijax_method('set_value')
  get_value = _as_hijax_method('get_value')
  create_value = _as_hijax_method('create_value')
  add_axis = _as_hijax_method('add_axis')
  remove_axis = _as_hijax_method('remove_axis')
  copy = _as_hijax_method('copy')
  replace = _as_hijax_method('replace')
  to_state = _as_hijax_method('to_state')

  def from_metadata(self, value: A, metadata: dict[str, tp.Any]):
    """Builds a new variable of this instance's ``_var_type``.

    This must be an instance method: ``_var_type`` is a per-instance slot,
    so reading it off the class (as a classmethod would) yields a slot
    descriptor and fails with AttributeError.
    """
    return self._var_type.from_metadata(value, metadata)  # type: ignore[misc]

  __nnx_repr__ = _as_hijax_method('__nnx_repr__')
  __treescope_repr__ = _as_hijax_method('__treescope_repr__')

  # --------------------------------------------
  # proxy methods
  # --------------------------------------------
  __jax_array__ = _as_hijax_method('__jax_array__')
  __getitem__ = _as_hijax_method('__getitem__')
  __setitem__ = _as_hijax_method('__setitem__')
  __delitem__ = _as_hijax_method('__delitem__')
  __call__ = _as_hijax_method('__call__')
  __len__ = _as_hijax_method('__len__')
  __iter__ = _as_hijax_method('__iter__')
  __contains__ = _as_hijax_method('__contains__')
  __add__ = _as_hijax_method('__add__')
  __sub__ = _as_hijax_method('__sub__')
  __mul__ = _as_hijax_method('__mul__')
  __matmul__ = _as_hijax_method('__matmul__')
  __truediv__ = _as_hijax_method('__truediv__')
  __floordiv__ = _as_hijax_method('__floordiv__')
  __mod__ = _as_hijax_method('__mod__')
  __divmod__ = _as_hijax_method('__divmod__')
  __pow__ = _as_hijax_method('__pow__')
  __lshift__ = _as_hijax_method('__lshift__')
  __rshift__ = _as_hijax_method('__rshift__')
  __and__ = _as_hijax_method('__and__')
  __xor__ = _as_hijax_method('__xor__')
  __or__ = _as_hijax_method('__or__')
  __radd__ = _as_hijax_method('__radd__')
  __rsub__ = _as_hijax_method('__rsub__')
  __rmul__ = _as_hijax_method('__rmul__')
  __rmatmul__ = _as_hijax_method('__rmatmul__')
  __rtruediv__ = _as_hijax_method('__rtruediv__')
  __rfloordiv__ = _as_hijax_method('__rfloordiv__')
  __rmod__ = _as_hijax_method('__rmod__')
  __rdivmod__ = _as_hijax_method('__rdivmod__')
  __rpow__ = _as_hijax_method('__rpow__')
  __rlshift__ = _as_hijax_method('__rlshift__')
  __rrshift__ = _as_hijax_method('__rrshift__')
  __rand__ = _as_hijax_method('__rand__')
  __rxor__ = _as_hijax_method('__rxor__')
  __ror__ = _as_hijax_method('__ror__')
  __iadd__ = _as_hijax_method('__iadd__')
  __isub__ = _as_hijax_method('__isub__')
  __imul__ = _as_hijax_method('__imul__')
  __imatmul__ = _as_hijax_method('__imatmul__')
  __itruediv__ = _as_hijax_method('__itruediv__')
  __ifloordiv__ = _as_hijax_method('__ifloordiv__')
  __imod__ = _as_hijax_method('__imod__')
  __ipow__ = _as_hijax_method('__ipow__')
  __ilshift__ = _as_hijax_method('__ilshift__')
  __irshift__ = _as_hijax_method('__irshift__')
  __iand__ = _as_hijax_method('__iand__')
  __ixor__ = _as_hijax_method('__ixor__')
  __ior__ = _as_hijax_method('__ior__')
  __neg__ = _as_hijax_method('__neg__')
  __pos__ = _as_hijax_method('__pos__')
  __abs__ = _as_hijax_method('__abs__')
  __invert__ = _as_hijax_method('__invert__')
  __complex__ = _as_hijax_method('__complex__')
  __int__ = _as_hijax_method('__int__')
  __float__ = _as_hijax_method('__float__')
  __index__ = _as_hijax_method('__index__')
  __round__ = _as_hijax_method('__round__')
  __trunc__ = _as_hijax_method('__trunc__')
  __floor__ = _as_hijax_method('__floor__')
  __ceil__ = _as_hijax_method('__ceil__')

  # --------------------------------------------
  # hijax interface
  # --------------------------------------------
  def cur_qdd(self):
    return self.type_state()

  def type_state(self):
    # QDD built from the concrete leaves currently stored on the object.
    leaf_avals = tuple(map(jax.typeof, self._leaves))
    return VariableQDD(leaf_avals, self._treedef, self._var_type)
def _to_abstract_variable(hijax_var: HijaxVariable):
  """Builds the ``AbstractVariable`` aval for a concrete ``HijaxVariable``.

  QDD variables are described by type only; their contents are tracked
  through ``VariableQDD``, so treedef and leaves are omitted. Other variables
  carry their treedef and leaf avals.
  """
  if hijax_var.has_qdd:
    treedef, leaves = None, None
  else:
    treedef, leaves = (
        hijax_var._treedef,
        tuple(map(jax.typeof, hijax_var._leaves)),
    )
  return AbstractVariable(
      hijax_var._var_type,
      treedef,
      leaves,
      hijax_var.has_qdd,
      ref=hijax_var.ref,
  )


# Tell hijax how to compute the aval of a concrete HijaxVariable.
hjx.register_hitype(HijaxVariable, _to_abstract_variable)
# ---------------------------------
# AbstractVariable
# ---------------------------------
class AbstractVariable(tp.Generic[A], hjx.MutableHiType):
  """Abstract value (aval) of a ``HijaxVariable``.

  Variables with ``has_qdd=True`` are described by type only: ``_treedef``
  and ``_leaves`` are None and their contents live in a ``VariableQDD``.
  Other variables carry their treedef and leaf avals directly. Variable API
  access on tracers is forwarded through ``hjx.aval_method`` /
  ``hjx.aval_property`` wrappers.
  """

  __slots__ = ['_var_type', '_treedef', '_leaves', 'has_qdd', '_ref']
  _var_type: type[Variable[A]]
  _treedef: PyTreeDef | None
  _leaves: tuple[hjx.AbstractValue, ...] | None
  has_qdd: bool
  _ref: bool

  @property
  def ref(self) -> bool:
    return self._ref

  @property
  def hijax(self):
    return True

  _check_can_update = hjx.aval_method(HijaxVariable._check_can_update)

  def __init__(
      self,
      var_type: type[Variable[A]],
      treedef: PyTreeDef | None,
      leaves: tuple[hjx.AbstractValue, ...] | None,
      has_qdd: bool,
      *,
      ref: bool = False,
  ):
    if (treedef is None) ^ (leaves is None):
      raise ValueError('treedef and leaves must be both provided or both None')
    object.__setattr__(self, '_treedef', treedef)
    object.__setattr__(self, '_leaves', leaves)
    object.__setattr__(self, '_var_type', var_type)
    object.__setattr__(self, 'has_qdd', has_qdd)
    object.__setattr__(self, '_ref', ref)

  # Array-like attributes are deliberately absent on this aval.
  @property
  def dtype(self):
    raise AttributeError

  @property
  def ndim(self):
    raise AttributeError

  @property
  def size(self):
    raise AttributeError

  @property
  def shape(self):
    raise AttributeError

  def __getattr__(self, name: str):
    # Forward unknown public attributes to the underlying Variable; names
    # defined on this class or starting with '_' are not forwarded.
    if hasattr(AbstractVariable, name):
      raise AttributeError(
          f"'{type(self).__name__}' object has no attribute '{name}'"
      )
    if name.startswith('_'):
      raise AttributeError(
          f"'{type(self).__name__}' object has no attribute '{name}'"
      )
    return _as_aval_property(_as_hijax_attribute(name))

  # __setattr__ supported via __getattr__
  # __delattr__ CURRENTLY NOT SUPPORTED
  type = _as_aval_property(HijaxVariable.type)
  get_metadata = hjx.aval_method(HijaxVariable.get_metadata)
  set_metadata = hjx.aval_method(HijaxVariable.set_metadata)
  copy_from = hjx.aval_method(HijaxVariable.copy_from)
  update_from_state = hjx.aval_method(HijaxVariable.update_from_state)
  get_raw_value = hjx.aval_method(HijaxVariable.get_raw_value)
  set_raw_value = hjx.aval_method(HijaxVariable.set_raw_value)
  set_value = hjx.aval_method(HijaxVariable.set_value)
  get_value = hjx.aval_method(HijaxVariable.get_value)
  create_value = hjx.aval_method(HijaxVariable.create_value)
  add_axis = hjx.aval_method(HijaxVariable.add_axis)
  remove_axis = hjx.aval_method(HijaxVariable.remove_axis)
  replace = hjx.aval_method(HijaxVariable.replace)

  @hjx.aval_method
  def from_metadata(self, value, metadata: dict[str, tp.Any]):
    aval: AbstractVariable = self.aval  # type: ignore
    variable = aval._var_type.from_metadata(value, metadata)
    return variable

  copy = hjx.aval_method(HijaxVariable.copy)
  to_state = hjx.aval_method(HijaxVariable.to_state)

  def __str__(self):
    return f'{self._var_type.__name__}()'

  def __repr__(self):
    return f'{self._var_type.__name__}()'

  @hjx.aval_method
  def __treescope_repr__(self, path, subtree_renderer):
    raise NotImplementedError

  # ---------------------------------
  # proxy methods
  # ---------------------------------
  __jax_array__ = hjx.aval_method(HijaxVariable.__jax_array__)
  _getitem = _as_tracer_method('__getitem__')
  _setitem = _as_tracer_method('__setitem__')
  # __delitem__ CURRENTLY NOT SUPPORTED
  # __call__ CURRENTLY NOT SUPPORTED
  _len = _as_tracer_method('__len__')
  _iter = _as_tracer_method('__iter__')
  # __contains__ CURRENTLY NOT SUPPORTED
  _add = _as_tracer_method('__add__')
  _sub = _as_tracer_method('__sub__')
  _mul = _as_tracer_method('__mul__')
  _matmul = _as_tracer_method('__matmul__')
  _truediv = _as_tracer_method('__truediv__')
  _floordiv = _as_tracer_method('__floordiv__')
  _mod = _as_tracer_method('__mod__')
  _divmod = _as_tracer_method('__divmod__')
  _pow = _as_tracer_method('__pow__')
  _lshift = _as_tracer_method('__lshift__')
  _rshift = _as_tracer_method('__rshift__')
  _and = _as_tracer_method('__and__')
  _xor = _as_tracer_method('__xor__')
  _or = _as_tracer_method('__or__')
  _radd = _as_tracer_method('__radd__')
  _rsub = _as_tracer_method('__rsub__')
  _rmul = _as_tracer_method('__rmul__')
  _rmatmul = _as_tracer_method('__rmatmul__')
  _rtruediv = _as_tracer_method('__rtruediv__')
  _rfloordiv = _as_tracer_method('__rfloordiv__')
  _rmod = _as_tracer_method('__rmod__')
  _rdivmod = _as_tracer_method('__rdivmod__')
  _rpow = _as_tracer_method('__rpow__')
  _rlshift = _as_tracer_method('__rlshift__')
  _rrshift = _as_tracer_method('__rrshift__')
  _rand = _as_tracer_method('__rand__')
  _rxor = _as_tracer_method('__rxor__')
  _ror = _as_tracer_method('__ror__')
  # _iadd CURRENTLY NOT SUPPORTED
  # _isub CURRENTLY NOT SUPPORTED
  # _imul CURRENTLY NOT SUPPORTED
  # _imatmul CURRENTLY NOT SUPPORTED
  # _itruediv CURRENTLY NOT SUPPORTED
  # _ifloordiv CURRENTLY NOT SUPPORTED
  # _imod CURRENTLY NOT SUPPORTED
  # _ipow CURRENTLY NOT SUPPORTED
  # _ilshift CURRENTLY NOT SUPPORTED
  # _irshift CURRENTLY NOT SUPPORTED
  # _iand CURRENTLY NOT SUPPORTED
  # _ixor CURRENTLY NOT SUPPORTED
  # _ior CURRENTLY NOT SUPPORTED
  _neg = _as_tracer_method('__neg__')
  _pos = _as_tracer_method('__pos__')
  _abs = _as_tracer_method('__abs__')
  _invert = _as_tracer_method('__invert__')
  _complex = _as_tracer_method('__complex__')
  _int = _as_tracer_method('__int__')
  _float = _as_tracer_method('__float__')
  _index = _as_tracer_method('__index__')
  _round = _as_tracer_method('__round__')
  _trunc = _as_tracer_method('__trunc__')
  _floor = _as_tracer_method('__floor__')
  _ceil = _as_tracer_method('__ceil__')

  # --------------------------------
  # hijax interface
  # --------------------------------
  cur_qdd = _not_an_attribute_property('cur_qdd')

  def __hash__(self):
    if self._leaves is not None and self._treedef is not None:
      return hash(
          (AbstractVariable, self._var_type, self._treedef, self._leaves)
      )
    else:
      assert self._leaves is None and self._treedef is None
      return hash((AbstractVariable, self._var_type))

  def __eq__(self, other):
    # Compare every field that __hash__ uses, so that `a == b` implies
    # `hash(a) == hash(b)`. Comparing `_var_type` alone made avals with
    # different leaves equal but differently hashed.
    if not isinstance(other, AbstractVariable):
      return False
    if self._var_type != other._var_type:
      return False
    # Avoid comparing a treedef against None (QDD vs non-QDD avals).
    if (self._treedef is None) != (other._treedef is None):
      return False
    return self._treedef == other._treedef and self._leaves == other._leaves

  def str_short(self, short_dtypes=False, **_) -> str:  # type: ignore
    return f'{self._var_type.__name__}()'

  # mutable interface
  def lo_ty_qdd(self, variable_state: VariableQDD) -> list:  # type: ignore
    # Low-level types: the concatenation of every leaf aval's lo_ty().
    return [lo_ty for t in variable_state.leaf_avals for lo_ty in t.lo_ty()]

  def new_from_loval(  # type: ignore[override]
      self, variable_state: VariableQDD, *lo_vals
  ) -> HijaxVariable:
    lo_vals_ = iter(lo_vals)
    # Each leaf aval consumes exactly len(lo_ty()) low-level values.
    hi_vals = [
        hi_ty.raise_val(*it.islice(lo_vals_, len(hi_ty.lo_ty())))  # type: ignore
        for hi_ty in variable_state.leaf_avals
    ]
    assert next(lo_vals_, None) is None
    return HijaxVariable._new(
        tuple(hi_vals),
        variable_state.treedef,
        self._var_type,
        has_qdd=self.has_qdd,
        ref=self.ref,
    )  # will be mutated

  def read_loval(self, variable_state: VariableQDD, variable) -> list:  # type: ignore
    leaf_vals, treedef = jax.tree.flatten(_get_hijax_state(variable))
    assert treedef == variable_state.treedef
    return [
        lo_val
        for hi_ty, hi_val in zip(variable_state.leaf_avals, leaf_vals)
        for lo_val in hi_ty.lower_val(hi_val)
    ]  # type: ignore

  def update_from_loval(  # type: ignore[override]
      self, box_state: VariableQDD, variable, *lo_vals
  ) -> None:
    lo_vals_ = iter(lo_vals)
    hi_vals = [
        hi_ty.raise_val(*it.islice(lo_vals_, len(hi_ty.lo_ty())))  # type: ignore
        for hi_ty in box_state.leaf_avals
    ]
    assert next(lo_vals_, None) is None
    _set_hijax_state(variable, jax.tree.unflatten(box_state.treedef, hi_vals))

  def to_tangent_aval(self):
    return AbstractVariable(
        self._var_type,
        self._treedef,
        self._leaves,
        self.has_qdd,
        ref=self.ref,
    )
# --------------------------------------------
# Variable
# --------------------------------------------
def _remap_sharding_metadata(metadata: dict[str, tp.Any]) -> None:
if 'sharding' in metadata:
warnings.warn(
"'sharding' is deprecated, use 'out_sharding' instead.",
DeprecationWarning,
stacklevel=3,
)
metadata['out_sharding'] = metadata.pop('sharding')
if 'sharding_names' in metadata:
warnings.warn(
"'sharding_names' is deprecated, use 'out_sharding' instead.",
DeprecationWarning,
stacklevel=3,
)
metadata['out_sharding'] = metadata.pop('sharding_names')
def _variable_operator(name: str) -> tp.Callable[[Variable[A], tp.Any], A]:
  """Create a binary dunder method that forwards ``name`` to the wrapped value.

  The generated method reads ``self.get_value()`` and, if ``other`` is also a
  ``Variable``, unwraps it with ``other.get_value()`` before calling
  ``value.<name>(other)``. The result is returned as-is (not re-wrapped).
  """

  def variable_operator_method(self, other):
    lhs = self.get_value()
    rhs = other.get_value() if isinstance(other, Variable) else other
    bound_op = getattr(lhs, name)
    return bound_op(rhs)

  variable_operator_method.__name__ = name
  return variable_operator_method
def _variable_unary_operator(name: str) -> tp.Callable[[Variable[A]], A]:
def variable_unary_operator_method(self):
value = self.get_value()
return getattr(value, name)()
variable_unary_operator_method.__name__ = name
return variable_unary_operator_method
class VariableMeta(type):
  """Metaclass of :class:`Variable`.

  * Classes that do not declare ``__slots__`` get ``__slots__ = ()``.
  * ``isinstance`` also recognizes JAX tracers whose ``jax.typeof`` is an
    ``AbstractVariable``, and ``HijaxVariable`` instances, by checking their
    ``_var_type`` against the class.
  * Calling a Variable class constructs the instance normally and, when the
    resulting variable has ``hijax`` set, converts it with
    ``_new_hijax_from_variable``.
  """

  def __new__(cls, cls_name, bases, attrs):
    # Slot-based storage by default unless the class declares its own slots.
    attrs.setdefault('__slots__', ())
    return super().__new__(cls, cls_name, bases, attrs)

  def __instancecheck__(self, instance):
    if super().__instancecheck__(instance):
      return True
    if isinstance(instance, jax.core.Tracer):
      aval = jax.typeof(instance)
      if isinstance(aval, AbstractVariable):
        return issubclass(aval._var_type, self)
    if not isinstance(instance, HijaxVariable):
      return False
    return issubclass(instance._var_type, self)

  if not tp.TYPE_CHECKING:

    def __call__(cls, *args, **kwargs):
      return cls._variable_meta_call(*args, **kwargs)

  def _variable_meta_call(cls, *args, **kwargs):
    instance = super().__call__(*args, **kwargs)
    return _new_hijax_from_variable(instance) if instance.hijax else instance
class Variable(tp.Generic[A], reprlib.Representable, metaclass=VariableMeta):
  """The base class for all ``Variable`` types. Create custom ``Variable``
  types by subclassing this class. Numerous NNX graph functions can filter
  for specific ``Variable`` types, for example, :func:`split`, :func:`state`,
  :func:`pop`, and :func:`State.filter`.

  A ``Variable`` holds a raw value plus a metadata dictionary. The metadata
  always contains the keys in ``required_metadata`` (``hijax``, ``ref`` and
  ``eager_sharding``). It can be read and modified with :meth:`get_metadata`,
  :meth:`set_metadata` and :meth:`del_metadata`; any metadata entry is also
  readable as an attribute.

  Example usage::

    >>> from flax import nnx
    >>> import jax, jax.numpy as jnp

    >>> class CustomVariable(nnx.Variable):
    ...   pass

    >>> class Model(nnx.Module):
    ...   def __init__(self, rngs):
    ...     self.linear = nnx.Linear(2, 3, rngs=rngs)
    ...     self.custom_variable = CustomVariable(jnp.ones((1, 3)))
    ...   def __call__(self, x):
    ...     return self.linear(x) + self.custom_variable
    >>> model = Model(rngs=nnx.Rngs(0))

    >>> linear_variables = nnx.state(model, nnx.Param)
    >>> jax.tree.map(jnp.shape, linear_variables)
    State({
      'linear': {
        'bias': Param(
          value=(3,)
        ),
        'kernel': Param(
          value=(2, 3)
        )
      }
    })

    >>> custom_variable = nnx.state(model, CustomVariable)
    >>> jax.tree.map(jnp.shape, custom_variable)
    State({
      'custom_variable': CustomVariable(
        value=(1, 3)
      )
    })

    >>> variables = nnx.state(model)
    >>> jax.tree.map(jnp.shape, variables)
    State({
      'custom_variable': CustomVariable(
        value=(1, 3)
      ),
      'linear': {
        'bias': Param(
          value=(3,)
        ),
        'kernel': Param(
          value=(2, 3)
        )
      }
    })
  """

  __slots__ = ('_raw_value', '_trace_state', '_var_metadata')
  _raw_value: A
  _trace_state: tracers.TraceState
  _var_metadata: dict[str, tp.Any]
  # Keys that ``__init__`` always writes into the metadata dict; they are
  # stripped by ``get_metadata(exclude_required=True)`` and cannot be deleted.
  required_metadata = frozenset(
    ['hijax', 'ref', 'eager_sharding']
  )

  @property
  def var_type(self):
    return type(self)

  @property
  def hijax(self) -> bool:
    return self._var_metadata['hijax']

  @property
  def ref(self) -> bool:
    return self._var_metadata['ref']

  @property
  def shape(self: Variable[jax.Array]) -> tuple[int, ...]:
    return self.get_value().shape

  @property
  def sharding_names(self):
    warnings.warn(
      "'sharding_names' is deprecated, use 'out_sharding' instead.",
      DeprecationWarning,
      stacklevel=2,
    )
    return self.get_metadata('out_sharding', None)

  def __init__(
    self,
    value: A | VariableMetadata[A],
    *,
    hijax: bool | None = None,
    ref: bool | None = None,
    eager_sharding: bool | None = None,
    **metadata: tp.Any,
  ):
    """Create a Variable.

    Args:
      value: the initial value, or a ``VariableMetadata`` whose ``raw_value``
        is used as the value and whose ``metadata`` is merged into this
        Variable's metadata.
      hijax: defaults to ``var_defaults().hijax`` when not given.
      ref: defaults to ``var_defaults().ref`` when not given; when true the
        final value is wrapped with ``jax.new_ref``.
      eager_sharding: defaults to ``using_eager_sharding()``; when true and an
        ``'out_sharding'`` entry is present, the value is sharded with
        ``core_spmd.shard_value``.
      **metadata: extra metadata entries. Deprecated ``sharding`` /
        ``sharding_names`` keys are remapped to ``out_sharding``.

    Raises:
      ValueError: if ``hijax``/``ref``/``eager_sharding`` conflict with the
        values in a ``VariableMetadata``, or if ``value`` contains a Ref.
    """
    var_t = type(self)

    if isinstance(value, VariableMetadata):
      aux_metadata = dict(value.metadata)
      if 'hijax' in aux_metadata:
        if hijax is not None and hijax != aux_metadata['hijax']:
          raise ValueError(
            'Cannot specify hijax both in VariableMetadata and as an '
            'argument to Variable constructor.'
          )
        hijax = aux_metadata.pop('hijax')
      if 'ref' in aux_metadata:
        if ref is not None and ref != aux_metadata['ref']:
          raise ValueError(
            'Cannot specify ref both in VariableMetadata and as an '
            'argument to Variable constructor.'
          )
        ref = aux_metadata.pop('ref')
      if 'eager_sharding' in aux_metadata:
        if (
          eager_sharding is not None
          and eager_sharding != aux_metadata['eager_sharding']
        ):
          raise ValueError(
            'Cannot specify eager_sharding both in VariableMetadata and as '
            'an argument to Variable constructor.'
          )
        eager_sharding = aux_metadata['eager_sharding']
      metadata.update(aux_metadata)
      value = tp.cast(A, value.raw_value)

    if hijax is None:
      hijax = var_defaults().hijax
    if ref is None:
      ref = var_defaults().ref
    if eager_sharding is None:
      eager_sharding = using_eager_sharding()

    if any(is_array_ref(v) for v in jax.tree.leaves(value)):
      raise ValueError('Cannot pass a Ref directly into Variable constructor.')

    metadata['hijax'] = hijax
    metadata['ref'] = ref
    metadata['eager_sharding'] = eager_sharding

    # Bypass __setattr__ (and its trace check) while initializing slots.
    object.__setattr__(self, '_trace_state', tracers.TraceState())
    object.__setattr__(self, '_var_metadata', metadata)
    object.__setattr__(self, '_raw_value', value)

    # Class-level hooks become the defaults unless overridden via metadata.
    if hasattr(var_t, 'on_get_value') and 'on_get_value' not in metadata:
      metadata['on_get_value'] = var_t.on_get_value
    if hasattr(var_t, 'on_set_value') and 'on_set_value' not in metadata:
      metadata['on_set_value'] = var_t.on_set_value
    if hasattr(var_t, 'on_create_value') and 'on_create_value' not in metadata:
      metadata['on_create_value'] = var_t.on_create_value
    if hasattr(var_t, 'on_add_axis') and 'on_add_axis' not in metadata:
      metadata['on_add_axis'] = var_t.on_add_axis
    if hasattr(var_t, 'on_remove_axis') and 'on_remove_axis' not in metadata:
      metadata['on_remove_axis'] = var_t.on_remove_axis

    _remap_sharding_metadata(metadata)

    # run create_value hooks
    if 'on_create_value' in metadata:
      value = metadata['on_create_value'](self, value)
      object.__setattr__(self, '_raw_value', value)
    # run create_value hook
    value = self.create_value(value)  # type: ignore
    # shard the _value if applicable
    if eager_sharding and 'out_sharding' in metadata:
      value = core_spmd.shard_value(
        value,
        metadata['out_sharding'],
        metadata.get('sharding_rules', None),
        metadata.get('mesh', None),
      )
    if ref:
      value = jax.new_ref(value)  # type: ignore
    object.__setattr__(self, '_raw_value', value)

  @property
  def _can_update(self) -> bool:
    """Whether the Variable can be updated in-place in the current trace context."""
    if self.hijax:
      return True
    else:
      return self._trace_state.is_valid()

  def _check_can_update(self):
    if not self.hijax and not self._trace_state.is_valid():
      raise errors.TraceContextError(
        f'Cannot mutate {type(self).__name__} from a different trace level'
      )

  def __getattr__(self, name: str) -> tp.Any:
    # Only reached when normal lookup fails: metadata entries first, then
    # attributes of the wrapped raw value.
    if name in object.__getattribute__(self, '_var_metadata'):
      return self._var_metadata[name]
    return getattr(object.__getattribute__(self, '_raw_value'), name)

  def __setattr__(self, name: str, value: tp.Any):
    self._check_can_update()
    try:
      object.__setattr__(self, name, value)
    except AttributeError as e:
      raise AttributeError(
        f'Cannot set attribute {name}. '
        f'To set Variable metadata use either:\n\n'
        f'  variable.set_metadata({name}=value)\n\nor\n\n'
        f"  variable.set_metadata('{name}', value)"
      ) from e

  def __delattr__(self, name: str):
    self._check_can_update()
    try:
      object.__delattr__(self, name)
    except AttributeError as e:
      raise AttributeError(
        f'Cannot delete attribute {name}. '
        f'To delete Variable metadata use:\n\n'
        f"  variable.del_metadata('{name}')"
      ) from e

  # NOTE(cgarciae): adding this for backward compatibility with VariableState
  @property
  def type(self):
    """The type of the variable."""
    return type(self)

  @tp.overload
  def get_metadata(
    self, *, exclude_required: bool = False
  ) -> dict[str, tp.Any]: ...
  @tp.overload
  def get_metadata(self, name: str, default: tp.Any = MISSING) -> tp.Any: ...
  def get_metadata(
    self,
    name: str | None = None,
    default: tp.Any = MISSING,
    *,
    exclude_required: bool = False,
  ) -> tp.Any:
    """Return the metadata dict, or a single metadata entry.

    Args:
      name: if ``None``, a shallow copy of the whole metadata dict is
        returned; otherwise the value stored under ``name``.
      default: returned when ``name`` is not present. If not given, a missing
        ``name`` raises ``KeyError``.
      exclude_required: only used when ``name`` is ``None``; if true, the
        ``required_metadata`` keys are removed from the returned copy.

    Returns:
      A new dict (safe for callers to mutate) or the requested value.

    Raises:
      KeyError: if ``name`` is missing and no ``default`` was given.
    """
    if name is None:
      metadata = dict(self._var_metadata)
      if exclude_required:
        for key in self.required_metadata:
          metadata.pop(key, None)
      return metadata
    if name in self._var_metadata:
      return self._var_metadata[name]
    if not isinstance(default, Missing):
      return default
    raise KeyError(f'{type(self).__name__} has no metadata entry {name!r}')

  @tp.overload
  def set_metadata(self, metadata: dict[str, tp.Any], /) -> None: ...
  @tp.overload
  def set_metadata(self, name: str, value: tp.Any, /) -> None: ...
  @tp.overload
  def set_metadata(self, **metadata: tp.Any) -> None: ...
  def set_metadata(self, *args: tp.Any, **kwargs: tp.Any) -> None:
    """Set Variable metadata.

    Three call forms are supported:

    * ``set_metadata(metadata_dict)`` replaces the metadata with a copy of the
      dict. Required entries missing from it are carried over.
    * ``set_metadata(name, value)`` sets a single entry.
    * ``set_metadata(**entries)`` updates several entries.

    Raises:
      TypeError: if the arguments match none of the forms above.
      errors.TraceContextError: if the Variable cannot be mutated in the
        current trace context.
    """
    if args and kwargs:
      raise TypeError(
        'set_metadata accepts either positional arguments or keyword '
        f'arguments, not both; got args={args!r}, kwargs={kwargs!r}'
      )
    self._check_can_update()
    if len(args) == 1:
      (metadata,) = args
      if not isinstance(metadata, dict):
        raise TypeError(
          f'Expected a dict of metadata, got {type(metadata).__name__}'
        )
      new_metadata = dict(metadata)
      # Keep required entries (e.g. 'hijax', used by the ``hijax`` property)
      # when the replacement dict does not provide them.
      for key, val in self._var_metadata.items():
        if key in self.required_metadata and key not in new_metadata:
          new_metadata[key] = val
      self._var_metadata = new_metadata
    elif len(args) == 2:
      name, value = args
      if not isinstance(name, str):
        raise TypeError(
          f'Metadata name must be a str, got {type(name).__name__}'
        )
      self._var_metadata[name] = value
    elif args:
      raise TypeError(
        f'set_metadata takes 1 or 2 positional arguments, got {len(args)}'
      )
    elif kwargs:
      self._var_metadata.update(kwargs)
    else:
      raise TypeError('set_metadata requires at least one argument')

  def del_metadata(self, name: str) -> None:
    """Delete the metadata entry ``name``.

    Raises:
      ValueError: if ``name`` is one of ``required_metadata``.
      KeyError: if ``name`` is not present.
    """
    self._check_can_update()
    if name in self.required_metadata:
      raise ValueError(f'Cannot delete required metadata {name!r}.')
    del self._var_metadata[name]

  def copy_from(self, other: Variable[A]) -> None:
    """Copy ``other``'s raw value and metadata into this Variable (same type only)."""
    if type(self) is not type(other):
      raise ValueError(
        f'Cannot copy from incompatible container, '
        f'expected {type(self).__name__}, got {type(other).__name__}'
      )
    if self is other:
      return
    self._raw_value = other._raw_value
    self._var_metadata.clear()
    self._var_metadata.update(other.get_metadata())

  def update_from_state(self, variable_state: Variable[A]):
    """Adopt ``variable_state``'s raw value and metadata, keeping this
    Variable's own ``hijax`` and ``ref`` settings."""
    self._raw_value = variable_state._raw_value
    if self._var_metadata != variable_state._var_metadata:
      metadata = variable_state.get_metadata()
      metadata['hijax'] = self.hijax
      metadata['ref'] = self.ref
      self._var_metadata = metadata

  @tp.final
  def get_raw_value(self) -> A:
    return self._raw_value

  # @tp.final
  def set_raw_value(self, value: A, *, _unsafe_bypass_check: bool = False):
    if not _unsafe_bypass_check:
      self._check_can_update()
    self._raw_value = value

  @property
  def raw_value(self) -> A:
    warnings.warn(
      "'.raw_value' access is now deprecated. Use:\n\n"
      '  variable.get_raw_value()\n',
      DeprecationWarning,
      stacklevel=2,
    )
    return self.get_raw_value()

  @raw_value.setter
  def raw_value(self, value: A):
    warnings.warn(
      "'.raw_value' setter is now deprecated. Use:\n\n"
      '  variable.set_raw_value(value)\n',
      DeprecationWarning,
      stacklevel=2,
    )
    self.set_raw_value(value)

  @property
  def value(self) -> A:
    warnings.warn(
      "'.value' access is now deprecated. For Variable[Array] instances use:\n\n"
      '  variable[...]\n\n'
      'For other Variable types use:\n\n'
      '  variable.get_value()\n',
      DeprecationWarning,
      stacklevel=2,
    )
    return self.get_value()

  @value.setter
  def value(self, value: A):
    warnings.warn(
      "'.value' setter is now deprecated. For Variable[Array] instances use:\n\n"
      '  variable[...] = value\n\n'
      'For other Variable types use:\n\n'
      '  variable.set_value(value)\n',
      DeprecationWarning,
      stacklevel=2,
    )
    self.set_value(value)

  def create_value(self, value: A):
    return value

  def get_value(self, *, index: tp.Any = MISSING) -> A:
    """Return the (optionally indexed) value, after the ``on_get_value`` hook.

    Refs are always dereferenced. For a plain ``jax.Array`` with
    ``index is ...`` the indexing is skipped since it is a no-op.
    """
    value = jax.tree.map(lambda x: x, self._raw_value)  # make a copy
    if not isinstance(index, Missing):
      if is_array_ref(value):
        value = value[index]
      elif isinstance(value, jax.Array) and index is ...:
        pass  # skip trivial access
      else:
        value = value[index]
    elif is_array_ref(value):
      value = value[...]
    if 'on_get_value' in self._var_metadata:
      value = self._var_metadata['on_get_value'](self, value)
    return value  # type: ignore

  def set_value(self, value: A, *, index: tp.Any = MISSING):
    """Set the whole value, or ``value[index]``, after the ``on_set_value`` hook.

    Raises:
      ValueError: if ``value`` is itself a Variable.
    """
    value = jax.tree.map(lambda x: x, value)  # make a copy
    if isinstance(value, Variable):
      raise ValueError(
        'Cannot set value to a Variable, use `copy_from` method instead'
      )
    if 'on_set_value' in self._var_metadata:
      value = self._var_metadata['on_set_value'](self, value)
    # update _raw_value
    if is_array_ref(self._raw_value):
      if isinstance(index, Missing):
        self._raw_value[...] = value
      else:
        self._raw_value[index] = value
    elif isinstance(self._raw_value, jax.Array) and (
      not isinstance(index, Missing)
    ):
      # A full replacement with matching shape/dtype/sharding can swap the
      # array directly. ``is`` (not ``==``) matters here: an array index would
      # compare elementwise with ``...`` and make this ``and`` chain raise.
      if (
        index is ...
        and isinstance(value, jax.Array)
        and value.shape == self._raw_value.shape
        and value.dtype == self._raw_value.dtype
        and (
          getattr(value, 'sharding', None)
          == getattr(self._raw_value, 'sharding', None)
        )
      ):
        self._raw_value = value
      else:
        self._raw_value = self._raw_value.at[index].set(value)  # type: ignore
    else:
      if isinstance(index, Missing):
        self._raw_value = value
      else:
        self._raw_value[index] = value  # type: ignore

  def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None):
    if 'on_add_axis' in self._var_metadata:
      self._var_metadata['on_add_axis'](self, axis_index, axis_name)

  def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None):
    if 'on_remove_axis' in self._var_metadata:
      self._var_metadata['on_remove_axis'](self, axis_index, axis_name)

  @tp.overload
  def copy(self, value: B, **kwargs) -> Variable[B]: ...

  @tp.overload
  def copy(self, **kwargs) -> Variable[A]: ...

  def copy(
    self,
    value: tp.Any = MISSING,
    *,
    _copy_ref: bool = True,
    **updates,
  ) -> Variable[tp.Any]:
    """Return a new Variable of the same type with ``updates`` merged into a
    copy of the metadata, holding ``value`` (or this Variable's raw value).

    With ``_copy_ref`` (the default), a Ref value is dereferenced and, when
    the new metadata has ``ref`` set, wrapped in a fresh ``jax.new_ref``.
    """
    assert 'raw_value' not in updates
    new_metadata = self.get_metadata() | updates
    if not isinstance(value, Missing):
      pass
    elif 'value' in updates:
      value = updates.pop('value')
    else:
      value = self.get_raw_value()

    if _copy_ref and is_array_ref(value):
      value = value[...]

    if _copy_ref and new_metadata['ref']:
      value = jax.new_ref(value)
      new_metadata['ref'] = True

    obj = self.from_metadata(value, new_metadata)
    return obj

  @classmethod
  def _new(
    cls,
    value: A,
    metadata: dict[str, tp.Any],
  ) -> Variable[A]:
    obj = object.__new__(cls)
    # skip __setattr__ for trace_state initialization
    object.__setattr__(obj, '_trace_state', tracers.TraceState())
    object.__setattr__(obj, '_var_metadata', metadata)
    object.__setattr__(obj, '_raw_value', value)
    return obj

  @classmethod
  def from_metadata(
    cls,
    value: A,
    attributes: dict[str, tp.Any],
  ) -> Variable[A]:
    variable = cls._new(value, dict(attributes))
    if attributes['hijax']:
      variable = _new_hijax_from_variable(variable)  # type: ignore[assignment]
    return variable  # type: ignore[return-value]

  replace = copy
  to_state = copy

  def __nnx_repr__(self):
    stats = SizeBytes.from_any(self._raw_value)
    if stats:
      comment = f' # {stats}'
    else:
      comment = ''

    yield reprlib.Object(type=type(self).__name__, comment=comment)
    yield reprlib.Attr('value', self.get_value())
    for name, value in self._var_metadata.items():
      # Hide entries that just restate the configured defaults.
      if name == 'hijax' and value == config.flax_hijax_variable:
        continue
      if name == 'ref' and not value:
        continue
      if name == 'eager_sharding' and value == config.flax_always_shard_variable:
        continue
      yield reprlib.Attr(name, value)

  def __treescope_repr__(self, path, subtree_renderer):
    size_bytes = SizeBytes.from_any(self.get_value())
    if size_bytes:
      stats_repr = f' # {size_bytes}'
      first_line_annotation = treescope.rendering_parts.comment_color(
        treescope.rendering_parts.text(f'{stats_repr}')
      )
    else:
      first_line_annotation = None
    metadata = {
      name: value
      for name, value in self._var_metadata.items()
      if not (name == 'hijax' and value == config.flax_hijax_variable)
      and not (name == 'ref' and not value)
      and not (name == 'eager_sharding' and value == config.flax_always_shard_variable)
    }

    children = {'value': self.get_value(), **metadata}
    return visualization.render_object_constructor(
      object_type=type(self),
      attributes=children,
      path=path,
      subtree_renderer=subtree_renderer,
      first_line_annotation=first_line_annotation,
    )

  # hooks API
  if tp.TYPE_CHECKING:

    def on_get_value(self, value: A) -> A: ...

    def on_set_value(self, value: A) -> A: ...

    def on_create_value(self, value: A) -> A: ...

    def on_add_axis(
      self: V, axis_index: AxisIndex, axis_name: AxisName | None
    ) -> V: ...

    def on_remove_axis(
      self: V, axis_index: AxisIndex, axis_name: AxisName | None
    ) -> V: ...

  def __jax_array__(self):
    return self.get_value()

  # pickle support
  def __getstate__(self):
    return {
      '_raw_value': self._raw_value,
      '_trace_state': self._trace_state,
      '_var_metadata': self._var_metadata,
    }

  def __setstate__(self, state):
    # skip __setattr__ for trace_state initialization
    object.__setattr__(self, '_trace_state', state['_trace_state'])
    object.__setattr__(self, '_var_metadata', state['_var_metadata'])
    object.__setattr__(self, '_raw_value', state['_raw_value'])

  # --------------------------------------------
  # proxy methods
  # --------------------------------------------
  @tp.overload
  def __getitem__(self: Variable[jax.Array], key) -> jax.Array:
    ...

  @tp.overload
  def __getitem__(self: Variable[dict[tp.Any, B]], key) -> B:
    ...

  @tp.overload
  def __getitem__(self: Variable[list[B]], key: int) -> B:
    ...

  @tp.overload
  def __getitem__(self: Variable[tuple[B, ...]], key: int) -> B:
    ...

  @tp.overload
  def __getitem__(self, key) -> tp.Any:
    ...

  def __getitem__(self, key):
    return self.get_value(index=key)

  def __setitem__(self, key, value) -> None:
    self.set_value(value, index=key)

  def __delitem__(self, key) -> None:
    value = self.get_value()
    del value[key]  # type: ignore
    self.set_value(value)  # type: ignore

  def __call__(self, *args, **kwargs) -> tp.Any:
    return self.get_value()(*args, **kwargs)  # type: ignore

  def __len__(self) -> int:
    return len(self.get_value())  # type: ignore

  def __iter__(self) -> tp.Iterator:
    return iter(self.get_value())  # type: ignore

  def __contains__(self, item) -> bool:
    return item in self.get_value()  # type: ignore

  __add__ = _variable_operator('__add__')
  __sub__ = _variable_operator('__sub__')
  __mul__ = _variable_operator('__mul__')
  __matmul__ = _variable_operator('__matmul__')
  __truediv__ = _variable_operator('__truediv__')
  __floordiv__ = _variable_operator('__floordiv__')
  __mod__ = _variable_operator('__mod__')
  __pow__ = _variable_operator('__pow__')
  __lshift__ = _variable_operator('__lshift__')
  __rshift__ = _variable_operator('__rshift__')
  __and__ = _variable_operator('__and__')
  __xor__ = _variable_operator('__xor__')
  __or__ = _variable_operator('__or__')

  __radd__ = _variable_operator('__radd__')
  __rsub__ = _variable_operator('__rsub__')
  __rmul__ = _variable_operator('__rmul__')
  __rmatmul__ = _variable_operator('__rmatmul__')
  __rtruediv__ = _variable_operator('__rtruediv__')
  __rfloordiv__ = _variable_operator('__rfloordiv__')
  __rmod__ = _variable_operator('__rmod__')
  __rpow__ = _variable_operator('__rpow__')
  __rlshift__ = _variable_operator('__rlshift__')
  __rrshift__ = _variable_operator('__rrshift__')
  __rand__ = _variable_operator('__rand__')
  __rxor__ = _variable_operator('__rxor__')
  __ror__ = _variable_operator('__ror__')

  def __eq__(self, other) -> bool:
    if isinstance(other, Variable):
      other = other.get_value()
    return self.get_value().__eq__(other)  # type: ignore

  def __ne__(self, other) -> bool:
    # Explicit, because the default ``__ne__`` negates ``__eq__``'s result
    # with ``not``, which raises for multi-element arrays.
    if isinstance(other, Variable):
      other = other.get_value()
    return self.get_value().__ne__(other)  # type: ignore

  def __iadd__(self: V, other) -> V:
    raise NotImplementedError(
      'In-place operations are no longer supported for Variable.\n'
      'Use `variable[...] += x` instead.'
    )

  def __isub__(self: V, other) -> V:
    raise NotImplementedError(
      'In-place operations are no longer supported for Variable.\n'
      'Use `variable[...] -= x` instead.'
    )

  def __imul__(self: V, other) -> V:
    raise NotImplementedError(
      'In-place operations are no longer supported for Variable.\n'
      'Use `variable[...] *= x` instead.'
    )

  def __imatmul__(self: V, other) -> V:
    raise NotImplementedError(
      'In-place operations are no longer supported for Variable.\n'
      'Use `variable[...] @= x` instead.'
    )

  def __itruediv__(self: V, other) -> V:
    raise NotImplementedError(
      'In-place operations are no longer supported for Variable.\n'
      'Use `variable[...] /= x` instead.'
    )

  def __ifloordiv__(self: V, other) -> V:
    raise NotImplementedError(
      'In-place operations are no longer supported for Variable.\n'
      'Use `variable[...] //= x` instead.'
    )

  def __imod__(self: V, other) -> V:
    raise NotImplementedError(
      'In-place operations are no longer supported for Variable.\n'
      'Use `variable[...] %= x` instead.'
    )

  def __ipow__(self: V, other) -> V:
    raise NotImplementedError(
      'In-place operations are no longer supported for Variable.\n'
      'Use `variable[...] **= x` instead.'
    )

  def __ilshift__(self: V, other) -> V:
    raise NotImplementedError(
      'In-place operations are no longer supported for Variable.\n'
      'Use `variable[...] <<= x` instead.'
    )

  def __irshift__(self: V, other) -> V:
    raise NotImplementedError(
      'In-place operations are no longer supported for Variable.\n'
      'Use `variable[...] >>= x` instead.'
    )

  def __iand__(self: V, other) -> V:
    raise NotImplementedError(
      'In-place operations are no longer supported for Variable.\n'
      'Use `variable[...] &= x` instead.'
    )

  def __ixor__(self: V, other) -> V:
    raise NotImplementedError(
      'In-place operations are no longer supported for Variable.\n'
      'Use `variable[...] ^= x` instead.'
    )

  def __ior__(self: V, other) -> V:
    raise NotImplementedError(
      'In-place operations are no longer supported for Variable.\n'
      'Use `variable[...] |= x` instead.'
    )

  __neg__ = _variable_unary_operator('__neg__')
  __pos__ = _variable_unary_operator('__pos__')
  __abs__ = _variable_unary_operator('__abs__')
  __invert__ = _variable_unary_operator('__invert__')
  __complex__ = _variable_unary_operator('__complex__')
  __int__ = _variable_unary_operator('__int__')
  __float__ = _variable_unary_operator('__float__')
  __index__ = _variable_unary_operator('__index__')
  __trunc__ = _variable_unary_operator('__trunc__')
  __floor__ = _variable_unary_operator('__floor__')
  __ceil__ = _variable_unary_operator('__ceil__')

  def __round__(self, ndigits: int | None = None) -> A:
    # ``round(v)`` passes no ndigits; forwarding None (instead of 0) keeps
    # Python semantics, e.g. round(2.7) -> 3 (int) rather than 3.0.
    value = self.get_value()
    if ndigits is None:
      return round(value)  # type: ignore
    return round(value, ndigits)  # type: ignore

  # --------------------------------------------

  def __init_subclass__(cls) -> None:
    # VariableMeta.__new__ already puts ``__slots__`` in every class
    # namespace, so this branch is normally not taken.
    if '__slots__' not in vars(cls):
      cls.__slots__ = ()  # type: ignore[assignment]
    super().__init_subclass__()
    # Every subclass is its own pytree node type.
    jax.tree_util.register_pytree_with_keys(
      cls,
      flatten_with_keys=_variable_flatten_with_keys,
      unflatten_func=partial(_variable_unflatten, cls),  # type: ignore
      flatten_func=_variable_flatten,
    )
def _variable_flatten_with_keys(x: Variable[tp.Any]):
metadata = tuple(sorted(x._var_metadata.items()))
node = (jtu.GetAttrKey('value'), x._raw_value)
return (node,), metadata
def _variable_flatten(x: Variable[tp.Any]):
metadata = tuple(sorted(x._var_metadata.items()))
return (x._raw_value,), metadata
def _variable_unflatten(
cls: type[Variable[tp.Any]],
static: tuple[tuple[str, tp.Any], ...],
children: tuple[tp.Any],
):
return cls._new(children[0], dict(static))
# Register the base class itself as a pytree node; subclasses are registered
# individually when they are defined.
jax.tree_util.register_pytree_with_keys(
  Variable,
  flatten_func=_variable_flatten,
  flatten_with_keys=_variable_flatten_with_keys,
  unflatten_func=partial(_variable_unflatten, Variable),  # type: ignore
)

# Backward-compatible alias: ``VariableState`` is now just ``Variable``.
VariableState = Variable
class Param(Variable[A]):
  """The canonical learnable parameter. All learnable parameters
  in NNX layer modules will have the ``Param`` :class:`Variable`
  type::

    >>> from flax import nnx
    >>> import jax, jax.numpy as jnp
    >>> layer = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
    >>> jax.tree.map(jnp.shape, nnx.state(layer))
    State({
      'bias': Param(
        value=(3,)
      ),
      'kernel': Param(
        value=(2, 3)
      )
    })
  """

  pass
class BatchStat(Variable[A]):
  """The mean and variance batch statistics stored in
  the :class:`BatchNorm` layer. Note, these are not the
  learnable scale and bias parameters, but rather the
  running average statistics that are typically used
  during post-training inference::

    >>> from flax import nnx
    >>> import jax, jax.numpy as jnp
    >>> layer = nnx.BatchNorm(3, rngs=nnx.Rngs(0))
    >>> jax.tree.map(jnp.shape, nnx.state(layer))
    State({
      'bias': Param(
        value=(3,)
      ),
      'mean': BatchStat(
        value=(3,)
      ),
      'scale': Param(
        value=(3,)
      ),
      'var': BatchStat(
        value=(3,)
      )
    })
  """

  pass
class Cache(Variable[A]):
  """Autoregressive cache in :class:`MultiHeadAttention`::

    >>> from flax import nnx
    >>> import jax, jax.numpy as jnp
    >>> layer = nnx.MultiHeadAttention(
    ...   num_heads=2,
    ...   in_features=3,
    ...   qkv_features=6,
    ...   out_features=6,
    ...   decode=True,
    ...   rngs=nnx.Rngs(0),
    ... )
    >>> layer.init_cache((1, 3))
    >>> jax.tree.map(jnp.shape, nnx.state(layer, nnx.Cache))
    State({
      'cache_index': Cache(
        value=()
      ),
      'cached_key': Cache(
        value=(1, 2, 3)
      ),
      'cached_value': Cache(
        value=(1, 2, 3)
      )
    })
  """

  pass
class Perturbation(Intermediate[A]):
  """:class:`Variable` subtype used by :func:`Module.perturb`.

  After a forward pass that calls ``perturb``, each perturbation shows up in
  the module's state as a ``Perturbation`` entry named after the key given to
  ``perturb``::

    >>> from flax import nnx
    >>> import jax, jax.numpy as jnp

    >>> class Model(nnx.Module):
    ...   def __init__(self, rngs):
    ...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
    ...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
    ...   def __call__(self, x):
    ...     x = self.linear1(x)
    ...     x = self.perturb('i', x)
    ...     x = self.linear2(x)
    ...     return x
    >>> model = Model(rngs=nnx.Rngs(0))

    >>> x = jnp.ones((1, 2))
    >>> y = model(x)
    >>> jax.tree.map(jnp.shape, nnx.state(model, nnx.Perturbation))
    State({
      'i': Perturbation(
        value=(1, 3)
      )
    })
  """

  pass
###################################################
### Variable type/class <-> string name mapping ###
###################################################

# Registry from Linen-style collection names (e.g. 'params') to NNX Variable
# types. Assumption: the mapping is one-to-one, i.e. every name maps to a
# single type and every type is registered under a single name.
VariableTypeCache: dict[str, tp.Type[Variable[tp.Any]]] = {}
def variable_type_from_name(
  name: str,
  /,
  *,
  base: type[Variable[tp.Any]] = Variable,
  allow_register: bool = False,
) -> tp.Type[Variable[tp.Any]]:
  """Given a Linen-style collection name, get or create its NNX Variable class.

  Args:
    name: the collection name to look up in ``VariableTypeCache``.
    base: base class used when a new type has to be created.
    allow_register: if ``True`` and ``name`` is not registered, a new
      subclass of ``base`` named ``name`` is created, registered and returned.

  Returns:
    The Variable type registered under ``name``.

  Raises:
    ValueError: if ``name`` is not registered and ``allow_register`` is
      ``False``.
  """
  if name not in VariableTypeCache:
    if not allow_register:
      raise ValueError(
        f'Name {name} is not registered in the registry. '
        'To register a new name, use register_variable_name() '
        'or set allow_register=True.'
      )
    VariableTypeCache[name] = type(name, (base,), {})
  return VariableTypeCache[name]
def variable_name_from_type(
  typ: tp.Type[Variable[tp.Any]], /, *, allow_register: bool = False
) -> str:
  """Given an NNX Variable type, get its Linen-style collection name.

  This is the inverse of :func:`variable_type_from_name`.

  Args:
    typ: the ``Variable`` subclass to look up.
    allow_register: if ``True`` and ``typ`` is not registered yet, register
      it under ``typ.__name__`` and return that name.

  Returns:
    The collection name registered for ``typ``.

  Raises:
    ValueError: if ``typ`` is not registered and ``allow_register`` is
      ``False``, or if ``typ.__name__`` is already registered to a different
      type.
  """
  for name, t in VariableTypeCache.items():
    # Registry values are classes, so compare by identity.
    if t is typ:
      return name

  if not allow_register:
    raise ValueError(
      f'Type {typ} is not registered in the registry. '
      'To register a new type, use register_variable_name() '
      'or set allow_register=True.'
    )
  name = typ.__name__
  if name in VariableTypeCache:
    # ``typ`` was not found above, so this name belongs to another type.
    raise ValueError(
      f'Name {name} is already registered in the registry as '
      f'{VariableTypeCache[name]}. '
      f'It cannot be linked with this type {typ}.'
    )
  register_variable_name(name, typ)
  return name
@tp.overload
def register_variable_name(
  name: str,
  typ: type[Variable[tp.Any]],
  *,
  overwrite: bool = False,
) -> type[Variable[tp.Any]]: ...


@tp.overload
def register_variable_name(
  name: str,
  *,
  overwrite: bool = False,
) -> tp.Callable[[type[Variable[tp.Any]]], type[Variable[tp.Any]]]: ...


def register_variable_name(
  name: str,
  typ: type[Variable[A]] | Missing = MISSING,
  *,
  overwrite=False,
) -> type[Variable[A]] | tp.Callable[[type[Variable[A]]], type[Variable[A]]]:
  """Register a pair of Linen collection name and its NNX type.

  Can be called directly, ``register_variable_name('params', Param)``, or used
  as a class decorator, ``@register_variable_name('params')``.

  Args:
    name: the Linen-style collection name.
    typ: the Variable type. If omitted, a decorator that registers the
      decorated class under ``name`` is returned.
    overwrite: allow replacing an existing registration for ``name``.

  Returns:
    ``typ`` itself, or a decorator when ``typ`` is omitted.

  Raises:
    ValueError: if ``name`` is already registered and ``overwrite`` is
      ``False``.
  """
  if isinstance(typ, Missing):
    # Decorator form: defer registration until the class is available.
    return partial(register_variable_name, name, overwrite=overwrite)
  typ = tp.cast(type[Variable[A]], typ)
  if not overwrite and name in VariableTypeCache:
    raise ValueError(
      f'Name {name} already mapped to type {VariableTypeCache[name]}. '
      'To overwrite, call register_variable_name() with `overwrite=True`.'
    )
  VariableTypeCache[name] = typ
  return typ
# add known variable type names
# Built-in Linen collection name -> NNX Variable type pairs, registered at
# import time so that variable_type_from_name / variable_name_from_type
# resolve them without needing allow_register=True.
register_variable_name('params', Param)
register_variable_name('batch_stats', BatchStat)
register_variable_name('cache', Cache)
register_variable_name('intermediates', Intermediate)
register_variable_name('perturbations', Perturbation)