# Source code for flax.nnx.rnglib

# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import functools
import typing as tp

import jax
from jax import random
import jax.numpy as jnp

from flax import struct
from flax import typing
from flax.nnx import graphlib
from flax.nnx.nn import initializers
from flax.nnx.variablelib import Variable
from flax.nnx import filterlib
from flax.nnx.pytreelib import Pytree
from flax.typing import MISSING, Key, Missing
import warnings

# Generic type variables and aliases shared across this module.
F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])  # decorated-function type
A = tp.TypeVar('A')  # arbitrary node/pytree type
Counts = list[int]
AxesValue = tp.Union[int, None]
SplitPattern = tp.Union[AxesValue, tuple[AxesValue, ...]]
# Sharding specification accepted by keyless initializers; ``None`` means no
# sharding constraint on the produced array.
OutShardingType: tp.TypeAlias = (
  jax.sharding.NamedSharding | jax.sharding.PartitionSpec | None
)
Fargs = tp.ParamSpec('Fargs')  # captures initializer-constructor signatures


@tp.runtime_checkable
class KeylessInitializer(tp.Protocol):
  """Structural type for initializers that take no explicit PRNG key.

  Unlike ``jax.nn.initializers.Initializer``, the key is supplied internally
  (e.g. drawn from an ``Rngs``/``RngStream``), so callers only pass the
  shape, dtype, and optional output sharding.
  """

  def __call__(
    self,
    shape: typing.Shape,
    dtype: tp.Any | None = None,
    out_sharding: OutShardingType = None,
  ) -> jax.Array:
    # Protocol stub: never called directly.
    raise NotImplementedError


def _to_keyless(
  initializer_constructor: tp.Callable[Fargs, jax.nn.initializers.Initializer],
) -> tp.Callable[Fargs, KeylessInitializer]:
  # Typing-only helper: used under ``tp.TYPE_CHECKING`` to re-type initializer
  # constructors as returning ``KeylessInitializer``. Never invoked at
  # runtime, hence no implementation.
  raise NotImplementedError


def _function_to_method(random_f):
  @functools.wraps(random_f)
  def rngs_random_method(self: Rngs | RngStream, *args, **kwargs) -> jax.Array:
    return random_f(self(), *args, **kwargs)

  return rngs_random_method


def _initializer_to_method(
  initializer_constructor: tp.Callable[Fargs, jax.nn.initializers.Initializer],
):
  def rngs_initializer_method(
    self: Rngs | RngStream, *args: Fargs.args, **kwargs: Fargs.kwargs
  ) -> KeylessInitializer:
    init_fn = initializer_constructor(*args, **kwargs)

    def rngs_keyless_initializer(*init_args, **init_kwargs):
      return init_fn(self(), *init_args, **init_kwargs)

    return rngs_keyless_initializer

  return rngs_initializer_method


class RngState(Variable[jax.Array]):
  """Base Variable type for RNG state; ``tag`` names the owning stream."""

  tag: str


class RngCount(RngState): ...  # per-stream uint32 call counter


class RngKey(RngState): ...  # per-stream seed PRNG key


# Filter matching RNG state that is not a key, i.e. the counters.
NotKey = filterlib.All(RngState, filterlib.Not(RngKey))


class RngStream(Pytree):
  """A named random-state stream: a PRNG ``key`` plus a ``count``.

  Every call derives a fresh key via ``jax.random.fold_in(key, count)`` and
  increments the counter, so repeated calls yield unique keys from one seed.
  """

  def __init__(
    self,
    key: jax.Array | int,
    *,
    tag: str,
  ):
    # Accept raw integer seeds and uint32 key data for convenience; anything
    # else must already be a typed PRNG key array.
    if isinstance(key, int):
      key = random.key(key)
    elif isinstance(key, jax.Array) and key.dtype == jnp.uint32:
      key = random.wrap_key_data(key)
    if not isinstance(key, jax.Array) or not jnp.issubdtype(
      key.dtype, jax.dtypes.prng_key
    ):
      raise ValueError(
        f'Invalid rng value: {key}, expected a '
        f'jax.Array of jax.dtypes.prng_key sub-dtype'
      )
    # The counter mirrors the key's shape so split streams keep per-key counts.
    count = jnp.zeros(key.shape, dtype=jnp.uint32)
    self.tag = tag
    self.key = RngKey(key, tag=tag)
    self.count = RngCount(count, tag=tag)

  def __call__(self) -> jax.Array:
    """Return a fresh key and advance the counter."""
    self.count._check_can_update()
    key = random.fold_in(self.key[...], self.count[...])
    self.count[...] += 1
    return key

  def split(self, k: int | tuple[int, ...]):
    """Return a new stream whose key is this stream's next key split ``k`` ways."""
    key = random.split(self(), k)
    return type(self)(key, tag=self.tag)

  def fork(self, *, split: int | tuple[int, ...] | None = None):
    """Return a new stream seeded from this stream's next key.

    The ``split`` argument is deprecated; use the ``split`` method instead.
    """
    if split is not None:
      warnings.warn(
        "The 'split' argument of 'fork' is deprecated; use the 'split' method instead.",
        DeprecationWarning,
        stacklevel=2,
      )
    key = self()
    if split is not None:
      key = random.split(key, split)
    return type(self)(key, tag=self.tag)

  # ----------------------------------------------------------
  # random functions
  # ----------------------------------------------------------
  if tp.TYPE_CHECKING:
    # Typing-only aliases: expose jax.random signatures (minus the leading
    # key argument) to static type checkers. Never executed at runtime.
    bits = staticmethod(functools.partial(random.bits, random.key(0)))
    uniform = staticmethod(functools.partial(random.uniform, random.key(0)))
    randint = staticmethod(functools.partial(random.randint, random.key(0)))
    permutation = staticmethod(
      functools.partial(random.permutation, random.key(0))
    )
    choice = staticmethod(functools.partial(random.choice, random.key(0)))
    normal = staticmethod(functools.partial(random.normal, random.key(0)))
    multivariate_normal = staticmethod(
      functools.partial(random.multivariate_normal, random.key(0))
    )
    truncated_normal = staticmethod(
      functools.partial(random.truncated_normal, random.key(0))
    )
    bernoulli = staticmethod(
      functools.partial(random.bernoulli, random.key(0))
    )
    beta = staticmethod(functools.partial(random.beta, random.key(0)))
    cauchy = staticmethod(functools.partial(random.cauchy, random.key(0)))
    dirichlet = staticmethod(
      functools.partial(random.dirichlet, random.key(0))
    )
    exponential = staticmethod(
      functools.partial(random.exponential, random.key(0))
    )
    gamma = staticmethod(functools.partial(random.gamma, random.key(0)))
    loggamma = staticmethod(
      functools.partial(random.loggamma, random.key(0))
    )
    poisson = staticmethod(
      functools.partial(random.poisson, random.key(0))
    )
    gumbel = staticmethod(functools.partial(random.gumbel, random.key(0)))
    categorical = staticmethod(
      functools.partial(random.categorical, random.key(0))
    )
    laplace = staticmethod(
      functools.partial(random.laplace, random.key(0))
    )
    logistic = staticmethod(
      functools.partial(random.logistic, random.key(0))
    )
    pareto = staticmethod(functools.partial(random.pareto, random.key(0)))
    t = staticmethod(functools.partial(random.t, random.key(0)))
    chisquare = staticmethod(
      functools.partial(random.chisquare, random.key(0))
    )
    f = staticmethod(functools.partial(random.f, random.key(0)))
    rademacher = staticmethod(
      functools.partial(random.rademacher, random.key(0))
    )
    maxwell = staticmethod(
      functools.partial(random.maxwell, random.key(0))
    )
    double_sided_maxwell = staticmethod(
      functools.partial(random.double_sided_maxwell, random.key(0))
    )
    weibull_min = staticmethod(
      functools.partial(random.weibull_min, random.key(0))
    )
    orthogonal = staticmethod(
      functools.partial(random.orthogonal, random.key(0))
    )
    generalized_normal = staticmethod(
      functools.partial(random.generalized_normal, random.key(0))
    )
    ball = staticmethod(functools.partial(random.ball, random.key(0)))
    rayleigh = staticmethod(
      functools.partial(random.rayleigh, random.key(0))
    )
    wald = staticmethod(functools.partial(random.wald, random.key(0)))
    geometric = staticmethod(
      functools.partial(random.geometric, random.key(0))
    )
    triangular = staticmethod(
      functools.partial(random.triangular, random.key(0))
    )
    lognormal = staticmethod(
      functools.partial(random.lognormal, random.key(0))
    )
    binomial = staticmethod(
      functools.partial(random.binomial, random.key(0))
    )
    multinomial = staticmethod(
      functools.partial(random.multinomial, random.key(0))
    )
  else:
    # Runtime implementations: each draws a fresh key from ``self`` per call.
    bits = _function_to_method(random.bits)
    uniform = _function_to_method(random.uniform)
    randint = _function_to_method(random.randint)
    permutation = _function_to_method(random.permutation)
    choice = _function_to_method(random.choice)
    normal = _function_to_method(random.normal)
    multivariate_normal = _function_to_method(random.multivariate_normal)
    truncated_normal = _function_to_method(random.truncated_normal)
    bernoulli = _function_to_method(random.bernoulli)
    beta = _function_to_method(random.beta)
    cauchy = _function_to_method(random.cauchy)
    dirichlet = _function_to_method(random.dirichlet)
    exponential = _function_to_method(random.exponential)
    gamma = _function_to_method(random.gamma)
    loggamma = _function_to_method(random.loggamma)
    poisson = _function_to_method(random.poisson)
    gumbel = _function_to_method(random.gumbel)
    categorical = _function_to_method(random.categorical)
    laplace = _function_to_method(random.laplace)
    logistic = _function_to_method(random.logistic)
    pareto = _function_to_method(random.pareto)
    t = _function_to_method(random.t)
    chisquare = _function_to_method(random.chisquare)
    f = _function_to_method(random.f)
    rademacher = _function_to_method(random.rademacher)
    maxwell = _function_to_method(random.maxwell)
    double_sided_maxwell = _function_to_method(random.double_sided_maxwell)
    weibull_min = _function_to_method(random.weibull_min)
    orthogonal = _function_to_method(random.orthogonal)
    generalized_normal = _function_to_method(random.generalized_normal)
    ball = _function_to_method(random.ball)
    rayleigh = _function_to_method(random.rayleigh)
    wald = _function_to_method(random.wald)
    geometric = _function_to_method(random.geometric)
    triangular = _function_to_method(random.triangular)
    lognormal = _function_to_method(random.lognormal)
    binomial = _function_to_method(random.binomial)
    multinomial = _function_to_method(random.multinomial)

  # ----------------------------------------------------------
  # initializers
  # ----------------------------------------------------------
  if tp.TYPE_CHECKING:
    # skip constant
    delta_orthogonal = staticmethod(_to_keyless(initializers.delta_orthogonal))
    glorot_normal = staticmethod(_to_keyless(initializers.glorot_normal))
    glorot_uniform = staticmethod(_to_keyless(initializers.glorot_uniform))
    he_normal = staticmethod(_to_keyless(initializers.he_normal))
    he_uniform = staticmethod(_to_keyless(initializers.he_uniform))
    kaiming_normal = staticmethod(_to_keyless(initializers.kaiming_normal))
    kaiming_uniform = staticmethod(_to_keyless(initializers.kaiming_uniform))
    lecun_normal = staticmethod(_to_keyless(initializers.lecun_normal))
    lecun_uniform = staticmethod(_to_keyless(initializers.lecun_uniform))
    # skip normal as it conflicts with jax.random.normal
    # skip ones
    # skip orthogonal as it conflicts with jax.random.orthogonal
    # skip truncated_normal as it conflicts with jax.random.truncated_normal
    # skip uniform as it conflicts with jax.random.uniform
    variance_scaling = staticmethod(_to_keyless(initializers.variance_scaling))
    xavier_normal = staticmethod(_to_keyless(initializers.xavier_normal))
    xavier_uniform = staticmethod(_to_keyless(initializers.xavier_uniform))
    # skip zeros
  else:
    # skip constant
    delta_orthogonal = _initializer_to_method(initializers.delta_orthogonal)
    glorot_normal = _initializer_to_method(initializers.glorot_normal)
    glorot_uniform = _initializer_to_method(initializers.glorot_uniform)
    he_normal = _initializer_to_method(initializers.he_normal)
    he_uniform = _initializer_to_method(initializers.he_uniform)
    kaiming_normal = _initializer_to_method(initializers.kaiming_normal)
    kaiming_uniform = _initializer_to_method(initializers.kaiming_uniform)
    lecun_normal = _initializer_to_method(initializers.lecun_normal)
    lecun_uniform = _initializer_to_method(initializers.lecun_uniform)
    # skip normal as it conflicts with jax.random.normal
    # skip ones
    # skip orthogonal as it conflicts with jax.random.orthogonal
    # skip truncated_normal as it conflicts with jax.random.truncated_normal
    # skip uniform as it conflicts with jax.random.uniform
    variance_scaling = _initializer_to_method(initializers.variance_scaling)
    xavier_normal = _initializer_to_method(initializers.xavier_normal)
    xavier_uniform = _initializer_to_method(initializers.xavier_uniform)
    # skip zeros


# A stream seed: either a raw integer or a PRNG key array.
RngValue = tp.Union[int, jax.Array]
class Rngs(Pytree):
  """A small abstraction to manage RNG state.

  ``Rngs`` allows the creation of ``RngStream`` which are used to easily
  generate new unique random keys on demand. An ``RngStream`` is a wrapper
  around a JAX random ``key``, and a ``counter``. Every time a key is
  requested, the counter is incremented and the key is generated from the
  seed key and the counter by using ``jax.random.fold_in``.

  To create an ``Rngs`` pass in an integer or ``jax.random.key`` to the
  constructor as a keyword argument with the name of the stream. The key
  will be used as the starting seed for the stream, and the counter will be
  initialized to zero. Then call the stream to get a key::

    >>> from flax import nnx
    >>> import jax, jax.numpy as jnp

    >>> rngs = nnx.Rngs(params=0, dropout=1)
    >>> param_key1 = rngs.params()
    >>> param_key2 = rngs.params()
    >>> dropout_key1 = rngs.dropout()
    >>> dropout_key2 = rngs.dropout()
    ...
    >>> assert param_key1 != dropout_key1

  Trying to generate a key for a stream that was not specified during
  construction will result in an error being raised::

    >>> rngs = nnx.Rngs(params=0, dropout=1)
    >>> try:
    ...   key = rngs.unknown_stream()
    ... except AttributeError as e:
    ...   print(e)
    No RngStream named 'unknown_stream' found in Rngs.

  The ``default`` stream can be created by passing in a key to the
  constructor without specifying a stream name. When the ``default`` stream
  is set the ``rngs`` object can be called directly to get a key, and
  calling streams that were not specified during construction will fallback
  to ``default``::

    >>> rngs = nnx.Rngs(0, params=1)
    ...
    >>> key1 = rngs.default()  # uses 'default'
    >>> key2 = rngs()  # uses 'default'
    >>> key3 = rngs.params()  # uses 'params'
    >>> key4 = rngs.dropout()  # uses 'default'
    >>> key5 = rngs.unknown_stream()  # uses 'default'
  """

  def __init__(
    self,
    default: RngValue
    | RngStream
    | tp.Mapping[str, RngValue | RngStream]
    | None = None,
    **rngs: RngValue | RngStream,
  ):
    """
    Args:
      default: the starting seed for the ``default`` stream, defaults to None.
      **rngs: keyword arguments specifying the starting seed for each stream.
        The key can be an integer or a ``jax.random.key``.
    """
    if default is not None:
      if isinstance(default, tp.Mapping):
        rngs = {**default, **rngs}
      else:
        rngs['default'] = default
    for tag, key in rngs.items():
      # Existing streams are re-seeded from their current key value.
      if isinstance(key, RngStream):
        key = key.key.get_value()
      stream = RngStream(
        key=key,
        tag=tag,
      )
      setattr(self, tag, stream)

  def _get_stream(self, name: str, error_type: type[Exception]) -> RngStream:
    # Looks up a stream by name, falling back to 'default' when present;
    # error_type lets __getitem__/__getattr__ raise the conventional type.
    stream_vars = vars(self)
    if name not in stream_vars:
      if 'default' not in stream_vars:
        raise error_type(f"No RngStream named '{name}' found in Rngs.")
      stream = stream_vars['default']
    else:
      stream = stream_vars[name]
    return stream

  def __getitem__(self, name: str):
    return self._get_stream(name, KeyError)

  def __getattr__(self, name: str):
    return self._get_stream(name, AttributeError)

  def __call__(self):
    # Calling the Rngs object draws from the 'default' stream.
    return self.default()

  def __iter__(self) -> tp.Iterator[str]:
    for name, stream in vars(self).items():
      if isinstance(stream, RngStream):
        yield name

  def __len__(self) -> int:
    return sum(
      1 for stream in vars(self).values() if isinstance(stream, RngStream)
    )

  def __contains__(self, name: tp.Any) -> bool:
    return name in vars(self)

  def items(self):
    """Yields ``(name, stream)`` pairs for all streams."""
    for name, stream in vars(self).items():
      if isinstance(stream, RngStream):
        yield name, stream

  def split(
    self,
    k: tp.Mapping[filterlib.Filter, int | tuple[int, ...]]
    | int
    | tuple[int, ...],
  ):
    """
    Splits the keys of the newly created ``Rngs`` object.

    Example::

      >>> from flax import nnx
      ...
      >>> rngs = nnx.Rngs(params=1, dropout=2)
      >>> new_rngs = rngs.split(5)
      ...
      >>> assert new_rngs.params.key.shape == (5,)
      >>> assert new_rngs.dropout.key.shape == (5,)

    ``split`` also accepts a mapping of `Filters
    <https://flax.readthedocs.io/en/latest/guides/filters_guide.html>`__ to
    split sizes or None to control which streams are split and how they are
    split::

      >>> rngs = nnx.Rngs(params=1, dropout=2, noise=3)
      >>> new_rngs = rngs.split({
      ...   'params': 5,  # split params into 5 keys
      ...   'dropout': None,  # don't split dropout
      ...   ...: (2, 5),  # split anything else into 2x5 keys
      ... })
      ...
      >>> assert new_rngs.params.key.shape == (5,)
      >>> assert new_rngs.dropout.key.shape == ()
      >>> assert new_rngs.noise.key.shape == (2, 5)
    """
    # Normalize scalar/tuple split sizes to a catch-all mapping.
    if isinstance(k, int):
      k = {...: k}
    elif isinstance(k, tuple):
      k = {...: k}
    split_predicates = {filterlib.to_predicate(k): v for k, v in k.items()}
    keys: dict[str, RngStream] = {}
    for name, stream in self.items():
      for predicate, num_splits in split_predicates.items():
        if predicate((), stream):
          if num_splits is None:
            keys[name] = stream
          else:
            keys[name] = stream.split(num_splits)
          break
      else:
        # No filter matched: keep the stream unchanged.
        keys[name] = stream
    return Rngs(**keys)

  def fork(
    self,
    /,
    *,
    split: tp.Mapping[filterlib.Filter, int | tuple[int, ...]]
    | int
    | tuple[int, ...]
    | None = None,
  ):
    """Returns a new Rngs object with new unique RNG keys.

    Example::

      >>> from flax import nnx
      ...
      >>> rngs = nnx.Rngs(params=1, dropout=2)
      >>> new_rngs = rngs.fork()
      ...
      >>> assert rngs.params() != new_rngs.params()
    """
    if split is not None:
      warnings.warn(
        "The 'split' argument of 'fork' is deprecated; use the 'split' method instead.",
        DeprecationWarning,
        stacklevel=2,
      )
    # Normalize to a filter mapping; empty mapping means fork-only.
    if split is None:
      split = {}
    elif isinstance(split, int):
      split = {...: split}
    elif isinstance(split, tuple):
      split = {...: split}
    split_predicates = {filterlib.to_predicate(k): v for k, v in split.items()}
    keys: dict[str, RngStream] = {}
    for name, stream in self.items():
      for predicate, num_splits in split_predicates.items():
        if predicate((), stream):
          keys[name] = stream.fork(split=num_splits)
          break
      else:
        keys[name] = stream.fork()
    return Rngs(**keys)

  # ----------------------------------------------------------
  # random functions
  # ----------------------------------------------------------
  if tp.TYPE_CHECKING:
    # Typing-only aliases: expose jax.random signatures (minus the leading
    # key argument) to static type checkers. Never executed at runtime.
    bits = staticmethod(functools.partial(random.bits, random.key(0)))
    uniform = staticmethod(functools.partial(random.uniform, random.key(0)))
    randint = staticmethod(functools.partial(random.randint, random.key(0)))
    permutation = staticmethod(
      functools.partial(random.permutation, random.key(0))
    )
    choice = staticmethod(functools.partial(random.choice, random.key(0)))
    normal = staticmethod(functools.partial(random.normal, random.key(0)))
    multivariate_normal = staticmethod(
      functools.partial(random.multivariate_normal, random.key(0))
    )
    truncated_normal = staticmethod(
      functools.partial(random.truncated_normal, random.key(0))
    )
    bernoulli = staticmethod(
      functools.partial(random.bernoulli, random.key(0))
    )
    beta = staticmethod(functools.partial(random.beta, random.key(0)))
    cauchy = staticmethod(functools.partial(random.cauchy, random.key(0)))
    dirichlet = staticmethod(
      functools.partial(random.dirichlet, random.key(0))
    )
    exponential = staticmethod(
      functools.partial(random.exponential, random.key(0))
    )
    gamma = staticmethod(functools.partial(random.gamma, random.key(0)))
    loggamma = staticmethod(
      functools.partial(random.loggamma, random.key(0))
    )
    poisson = staticmethod(
      functools.partial(random.poisson, random.key(0))
    )
    gumbel = staticmethod(functools.partial(random.gumbel, random.key(0)))
    categorical = staticmethod(
      functools.partial(random.categorical, random.key(0))
    )
    laplace = staticmethod(
      functools.partial(random.laplace, random.key(0))
    )
    logistic = staticmethod(
      functools.partial(random.logistic, random.key(0))
    )
    pareto = staticmethod(functools.partial(random.pareto, random.key(0)))
    t = staticmethod(functools.partial(random.t, random.key(0)))
    chisquare = staticmethod(
      functools.partial(random.chisquare, random.key(0))
    )
    f = staticmethod(functools.partial(random.f, random.key(0)))
    rademacher = staticmethod(
      functools.partial(random.rademacher, random.key(0))
    )
    maxwell = staticmethod(
      functools.partial(random.maxwell, random.key(0))
    )
    double_sided_maxwell = staticmethod(
      functools.partial(random.double_sided_maxwell, random.key(0))
    )
    weibull_min = staticmethod(
      functools.partial(random.weibull_min, random.key(0))
    )
    orthogonal = staticmethod(
      functools.partial(random.orthogonal, random.key(0))
    )
    generalized_normal = staticmethod(
      functools.partial(random.generalized_normal, random.key(0))
    )
    ball = staticmethod(functools.partial(random.ball, random.key(0)))
    rayleigh = staticmethod(
      functools.partial(random.rayleigh, random.key(0))
    )
    wald = staticmethod(functools.partial(random.wald, random.key(0)))
    geometric = staticmethod(
      functools.partial(random.geometric, random.key(0))
    )
    triangular = staticmethod(
      functools.partial(random.triangular, random.key(0))
    )
    lognormal = staticmethod(
      functools.partial(random.lognormal, random.key(0))
    )
    binomial = staticmethod(
      functools.partial(random.binomial, random.key(0))
    )
    multinomial = staticmethod(
      functools.partial(random.multinomial, random.key(0))
    )
  else:
    # Runtime implementations: each draws a fresh key from ``self`` per call.
    bits = _function_to_method(random.bits)
    uniform = _function_to_method(random.uniform)
    randint = _function_to_method(random.randint)
    permutation = _function_to_method(random.permutation)
    choice = _function_to_method(random.choice)
    normal = _function_to_method(random.normal)
    multivariate_normal = _function_to_method(random.multivariate_normal)
    truncated_normal = _function_to_method(random.truncated_normal)
    bernoulli = _function_to_method(random.bernoulli)
    beta = _function_to_method(random.beta)
    cauchy = _function_to_method(random.cauchy)
    dirichlet = _function_to_method(random.dirichlet)
    exponential = _function_to_method(random.exponential)
    gamma = _function_to_method(random.gamma)
    loggamma = _function_to_method(random.loggamma)
    poisson = _function_to_method(random.poisson)
    gumbel = _function_to_method(random.gumbel)
    categorical = _function_to_method(random.categorical)
    laplace = _function_to_method(random.laplace)
    logistic = _function_to_method(random.logistic)
    pareto = _function_to_method(random.pareto)
    t = _function_to_method(random.t)
    chisquare = _function_to_method(random.chisquare)
    f = _function_to_method(random.f)
    rademacher = _function_to_method(random.rademacher)
    maxwell = _function_to_method(random.maxwell)
    double_sided_maxwell = _function_to_method(random.double_sided_maxwell)
    weibull_min = _function_to_method(random.weibull_min)
    orthogonal = _function_to_method(random.orthogonal)
    generalized_normal = _function_to_method(random.generalized_normal)
    ball = _function_to_method(random.ball)
    rayleigh = _function_to_method(random.rayleigh)
    wald = _function_to_method(random.wald)
    geometric = _function_to_method(random.geometric)
    triangular = _function_to_method(random.triangular)
    lognormal = _function_to_method(random.lognormal)
    binomial = _function_to_method(random.binomial)
    multinomial = _function_to_method(random.multinomial)

  # ----------------------------------------------------------
  # initializers
  # ----------------------------------------------------------
  if tp.TYPE_CHECKING:
    # skip constant
    delta_orthogonal = staticmethod(_to_keyless(initializers.delta_orthogonal))
    glorot_normal = staticmethod(_to_keyless(initializers.glorot_normal))
    glorot_uniform = staticmethod(_to_keyless(initializers.glorot_uniform))
    he_normal = staticmethod(_to_keyless(initializers.he_normal))
    he_uniform = staticmethod(_to_keyless(initializers.he_uniform))
    kaiming_normal = staticmethod(_to_keyless(initializers.kaiming_normal))
    kaiming_uniform = staticmethod(_to_keyless(initializers.kaiming_uniform))
    lecun_normal = staticmethod(_to_keyless(initializers.lecun_normal))
    lecun_uniform = staticmethod(_to_keyless(initializers.lecun_uniform))
    # skip normal as it conflicts with jax.random.normal
    # skip ones
    # skip orthogonal as it conflicts with jax.random.orthogonal
    # skip truncated_normal as it conflicts with jax.random.truncated_normal
    # skip uniform as it conflicts with jax.random.uniform
    variance_scaling = staticmethod(_to_keyless(initializers.variance_scaling))
    xavier_normal = staticmethod(_to_keyless(initializers.xavier_normal))
    xavier_uniform = staticmethod(_to_keyless(initializers.xavier_uniform))
    # skip zeros
  else:
    # skip constant
    delta_orthogonal = _initializer_to_method(initializers.delta_orthogonal)
    glorot_normal = _initializer_to_method(initializers.glorot_normal)
    glorot_uniform = _initializer_to_method(initializers.glorot_uniform)
    he_normal = _initializer_to_method(initializers.he_normal)
    he_uniform = _initializer_to_method(initializers.he_uniform)
    kaiming_normal = _initializer_to_method(initializers.kaiming_normal)
    kaiming_uniform = _initializer_to_method(initializers.kaiming_uniform)
    lecun_normal = _initializer_to_method(initializers.lecun_normal)
    lecun_uniform = _initializer_to_method(initializers.lecun_uniform)
    # skip normal as it conflicts with jax.random.normal
    # skip ones
    # skip orthogonal as it conflicts with jax.random.orthogonal
    # skip truncated_normal as it conflicts with jax.random.truncated_normal
    # skip uniform as it conflicts with jax.random.uniform
    variance_scaling = _initializer_to_method(initializers.variance_scaling)
    xavier_normal = _initializer_to_method(initializers.xavier_normal)
    xavier_uniform = _initializer_to_method(initializers.xavier_uniform)
    # skip zeros
# skip zeros StreamBackup = ( tuple[RngStream, jax.Array, jax.Array] | tuple[RngStream, jax.Array] ) class SplitBackups(struct.PyTreeNode, tp.Iterable[StreamBackup]): backups: list[StreamBackup] def __iter__(self) -> tp.Iterator[StreamBackup]: return iter(self.backups) def __enter__(self): return self def __exit__(self, *args): restore_rngs(self) @tp.overload def split_rngs( node: tp.Any, /, *, splits: int | tuple[int, ...], only: filterlib.Filter = ..., squeeze: bool = False, graph: tp.Literal[True] | None = None, ) -> SplitBackups: ... @tp.overload def split_rngs( node: A, /, *, splits: int | tuple[int, ...], only: filterlib.Filter = ..., squeeze: bool = False, graph: tp.Literal[False], ) -> A: ... @tp.overload def split_rngs( *, splits: int | tuple[int, ...], only: filterlib.Filter = ..., squeeze: bool = False, graph: bool | None = None, ) -> tp.Callable[[F], F]: ...
def split_rngs(
  node: tp.Any = MISSING,
  /,
  *,
  splits: int | tuple[int, ...],
  only: filterlib.Filter = ...,
  squeeze: bool = False,
  graph: bool | None = None,
) -> SplitBackups | tp.Any | tp.Callable[[F], F]:
  """Splits the (nested) Rng states of the given node.

  Args:
    node: the base node containing the rng states to split.
    splits: an integer or tuple of integers specifying the
      shape of the split rng keys.
    only: a Filter selecting which rng states to split.
    graph: If ``True`` (default), uses graph-mode which supports the full
      NNX feature set including shared references. If ``False``, uses
      tree-mode which treats Modules as regular JAX pytrees, avoiding the
      overhead of the graph protocol.

  Returns:
    A SplitBackups iterable if ``node`` is provided, otherwise a
    decorator that splits the rng states of the inputs to the
    decorated function.

  Example::

    >>> from flax import nnx
    ...
    >>> rngs = nnx.Rngs(params=0, dropout=1)
    >>> _ = nnx.split_rngs(rngs, splits=5)
    >>> rngs.params.key.shape, rngs.dropout.key.shape
    ((5,), (5,))

    >>> rngs = nnx.Rngs(params=0, dropout=1)
    >>> _ = nnx.split_rngs(rngs, splits=(2, 5))
    >>> rngs.params.key.shape, rngs.dropout.key.shape
    ((2, 5), (2, 5))

    >>> rngs = nnx.Rngs(params=0, dropout=1)
    >>> _ = nnx.split_rngs(rngs, splits=5, only='params')
    >>> rngs.params.key.shape, rngs.dropout.key.shape
    ((5,), ())

  Once split, random state can be used with transforms like
  :func:`nnx.vmap`::

    >>> class Model(nnx.Module):
    ...   def __init__(self, rngs):
    ...     self.linear = nnx.Linear(2, 3, rngs=rngs)
    ...     self.dropout = nnx.Dropout(0.5, rngs=rngs)
    ...
    >>> rngs = nnx.Rngs(params=0, dropout=1)
    >>> _ = nnx.split_rngs(rngs, splits=5, only='params')
    ...
    >>> state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None})
    ...
    >>> @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes)
    ... def create_model(rngs):
    ...   return Model(rngs)
    ...
    >>> model = create_model(rngs)
    >>> model.dropout.rngs.key.shape
    ()

  ``split_rngs`` returns a SplitBackups object that can be used to restore
  the original unsplit rng states using :func:`nnx.restore_rngs`, this is
  useful when you only want to split the rng states temporarily::

    >>> rngs = nnx.Rngs(params=0, dropout=1)
    ...
    >>> backups = nnx.split_rngs(rngs, splits=5, only='params')
    >>> model = create_model(rngs)
    >>> nnx.restore_rngs(backups)
    ...
    >>> model.dropout.rngs.key.shape
    ()

  SplitBackups can also be used as a context manager to automatically
  restore the rng states when exiting the context::

    >>> rngs = nnx.Rngs(params=0, dropout=1)
    ...
    >>> with nnx.split_rngs(rngs, splits=5, only='params'):
    ...   model = create_model(rngs)
    ...
    >>> model.dropout.rngs.key.shape
    ()

    >>> state_axes = nnx.StateAxes({(nnx.Param, 'params'): 0, ...: None})
    ...
    >>> @nnx.split_rngs(splits=5, only='params')
    ... @nnx.vmap(in_axes=(state_axes,), out_axes=state_axes)
    ... def create_model(rngs):
    ...   return Model(rngs)
    ...
    >>> rngs = nnx.Rngs(params=0, dropout=1)
    >>> model = create_model(rngs)
    >>> model.dropout.rngs.key.shape
    ()
  """
  if graph is None:
    graph = graphlib.set_graph_mode.current_value()
  if isinstance(node, Missing):
    # No node given: act as a decorator that splits the rng states found in
    # the decorated function's arguments.

    def split_rngs_decorator(f: F) -> F:
      @functools.wraps(f)
      def split_rngs_wrapper(*args, **kwargs):
        if graph:
          # Graph-mode mutates in place; restore on exit via the backups.
          with split_rngs(
            (args, kwargs),
            splits=splits,
            only=only,
            squeeze=squeeze,
            graph=True,
          ):
            return f(*args, **kwargs)
        else:
          # Tree-mode returns new (args, kwargs) with split streams.
          args, kwargs = split_rngs(
            (args, kwargs),
            splits=splits,
            only=only,
            squeeze=squeeze,
            graph=False,
          )
          return f(*args, **kwargs)

      return tp.cast(F, split_rngs_wrapper)

    return split_rngs_decorator  # type: ignore[bad-return-type]

  if squeeze and splits != 1:
    raise ValueError('squeeze=True is only supported for splits=1')

  if graph:
    return _graph_split_rngs(
      node,
      splits=splits,
      only=only,
      squeeze=squeeze,
    )
  else:
    return _tree_split_rngs(
      node,
      splits=splits,
      only=only,
      squeeze=squeeze,
    )
def _graph_split_rngs(
  node: tp.Any,
  /,
  *,
  splits: int | tuple[int, ...],
  only: filterlib.Filter = ...,
  squeeze: bool = False,
) -> SplitBackups:
  # Graph-mode implementation of split_rngs: mutates matching streams in
  # place and records (stream, key, count) backups so the original state can
  # be restored with ``restore_rngs``.
  predicate = filterlib.to_predicate(only)
  backups: list[StreamBackup] = []
  for path, stream in graphlib.iter_graph(node, graph=True):
    if (
      isinstance(stream, RngStream)
      and predicate((*path, 'key'), stream.key)
      and predicate((*path, 'count'), stream.count)
    ):
      key = stream()
      backups.append((stream, stream.key[...], stream.count[...]))
      key = random.split(key, splits)
      if squeeze:
        # splits == 1 (checked by the caller): drop the singleton axis.
        key = key[0]
      stream.key.set_value(key)
      if squeeze:
        counts_shape = stream.count.shape
      elif isinstance(splits, int):
        counts_shape = (splits, *stream.count.shape)
      else:
        counts_shape = (*splits, *stream.count.shape)
      # Reset counters for the newly split keys.
      stream.count.set_value(jnp.zeros(counts_shape, dtype=jnp.uint32))
  return SplitBackups(backups)


def _tree_split_rngs(
  node: tp.Any,
  /,
  *,
  splits: int | tuple[int, ...],
  only: filterlib.Filter = ...,
  squeeze: bool = False,
) -> tp.Any:
  # Tree-mode implementation of split_rngs: returns a new pytree with the
  # matching streams replaced by split versions (no backups needed).
  predicate = filterlib.to_predicate(only)

  def _split_stream(path, node):
    if (
      isinstance(node, RngStream)
      and predicate((*path, 'key'), node.key)
      and predicate((*path, 'count'), node.count)
    ):
      key = random.split(node(), splits)
      if squeeze:
        # splits == 1 (checked by the caller): drop the singleton axis.
        key = key[0]
      if squeeze:
        counts_shape = node.count.shape
      elif isinstance(splits, int):
        counts_shape = (splits, *node.count.shape)
      else:
        counts_shape = (*splits, *node.count.shape)
      node.key = RngKey(key, tag=node.tag)
      node.count = RngCount(
        jnp.zeros(counts_shape, dtype=jnp.uint32), tag=node.tag
      )
    return node

  return graphlib.recursive_map(_split_stream, node, graph=False)


@tp.overload
def fork_rngs(
  node: tp.Any,
  /,
  *,
  split: tp.Mapping[filterlib.Filter, int | tuple[int, ...] | None]
  | int
  | None = None,
  graph: bool | None = None,
) -> SplitBackups: ...
@tp.overload
def fork_rngs(
  *,
  split: tp.Mapping[filterlib.Filter, int | tuple[int, ...] | None]
  | int
  | None = None,
  graph: bool | None = None,
) -> tp.Callable[[F], F]: ...
def fork_rngs(
  node: tp.Any = MISSING,
  /,
  *,
  split: tp.Mapping[filterlib.Filter, int | tuple[int, ...] | None]
  | int
  | None = None,
  graph: bool | None = None,
) -> SplitBackups | tp.Callable[[F], F]:
  """Forks the (nested) Rng states of the given node.

  Args:
    node: the base node containing the rng states to fork.
    graph: If ``True`` (default), uses graph-mode which supports the full
      NNX feature set including shared references. If ``False``, uses
      tree-mode which treats Modules as regular JAX pytrees, avoiding the
      overhead of the graph protocol.

  Returns:
    A SplitBackups iterable if ``node`` is provided, otherwise a
    decorator that forks the rng states of the inputs to the
    decorated function.

  Example::

    >>> from flax import nnx
    ...
    >>> rngs = nnx.Rngs(params=0, dropout=1)
    >>> _ = nnx.fork_rngs(rngs)

  ``fork_rngs`` returns a SplitBackups object that can be used to restore
  the original unforked rng states using :func:`nnx.restore_rngs`, this is
  useful when you only want to fork the rng states temporarily::

    >>> rngs = nnx.Rngs(params=0, dropout=1)
    ...
    >>> backups = nnx.fork_rngs(rngs)
    >>> model = nnx.Linear(2, 3, rngs=rngs)
    >>> nnx.restore_rngs(backups)
    ...

  SplitBackups can also be used as a context manager to automatically
  restore the rng states when exiting the context::

    >>> rngs = nnx.Rngs(params=0, dropout=1)
    ...
    >>> with nnx.fork_rngs(rngs):
    ...   model = nnx.Linear(2, 3, rngs=rngs)
  """
  if isinstance(node, Missing):
    # No node given: act as a decorator that forks the rng states found in
    # the decorated function's arguments, restoring them afterwards.

    def fork_rngs_decorator(f: F) -> F:
      @functools.wraps(f)
      def fork_rngs_wrapper(*args, **kwargs):
        with fork_rngs((args, kwargs), split=split):
          return f(*args, **kwargs)

      return tp.cast(F, fork_rngs_wrapper)

    return fork_rngs_decorator  # type: ignore[bad-return-type]

  # Normalize to a filter mapping; ``None`` means fork without splitting.
  if split is None:
    split = {...: None}
  elif isinstance(split, int | tuple):
    split = {...: split}

  predicate_splits = {filterlib.to_predicate(k): v for k, v in split.items()}
  backups: list[StreamBackup] = []
  for path, stream in graphlib.iter_graph(node, graph=graph):
    for predicate, splits in predicate_splits.items():
      if (
        isinstance(stream, RngStream)
        and predicate((*path, 'key'), stream.key)
        and predicate((*path, 'count'), stream.count)
      ):
        forked_stream = stream.fork(split=splits)
        # backup the original stream state
        backups.append((stream, stream.key[...], stream.count[...]))
        # apply the forked key and count to the original stream
        stream.key.set_value(forked_stream.key.get_value())
        stream.count.set_value(forked_stream.count.get_value())
  return SplitBackups(backups)
def backup_keys(node: tp.Any, /, *, graph: bool | None = None):
  """Collect ``(stream, key_value)`` backups for every RngStream in ``node``.

  Only the key is captured (not the count), so :func:`restore_rngs` will
  restore keys but leave counts untouched for these backups.

  Args:
    node: the node to scan for RngStream instances.
    graph: forwarded to ``graphlib.iter_graph``; selects graph-mode
      (``True``) or tree-mode (``False``) traversal.

  Returns:
    A list of ``(stream, key_value)`` tuples usable with
    :func:`restore_rngs`.
  """
  backups: list[StreamBackup] = []
  for _, stream in graphlib.iter_graph(node, graph=graph):
    if isinstance(stream, RngStream):
      backups.append((stream, stream.key[...]))
  return backups


def _scalars_only(
  path: tuple[Key, ...], scalar_key: jax.Array, target_shape: tuple[int, ...]
) -> jax.Array:
  """Reseed policy that only accepts scalar target keys.

  Raises:
    ValueError: if the target stream key is not a scalar.
  """
  if target_shape != ():
    raise ValueError(
      f'Cannot reseed stream at path {path!r} because it has a non-scalar key, '
      f'found key with shape {target_shape}. If all your multi-dimensional '
      'keys have unique values on all dimensions, set policy="match_shape", '
      'else provide a custom reseed policy.'
    )
  return scalar_key


def _match_shape(
  path: tuple[Key, ...], scalar_key: jax.Array, target_shape: tuple[int, ...]
) -> jax.Array:
  """Reseed policy that splits the scalar key to match the target key shape."""
  if target_shape == ():
    return scalar_key
  # ``random.split`` accepts a shape tuple, producing a key array of that
  # shape from the single scalar key.
  return random.split(scalar_key, target_shape)
def reseed(
  node,
  /,
  *,
  graph: bool | None = None,
  policy: tp.Literal['scalars_only', 'match_shape']
  | tp.Callable[
    [tuple, jax.Array, tuple[int, ...]], jax.Array
  ] = 'scalars_only',
  **stream_keys: RngValue,
):
  """Update the keys of the specified RNG streams with new keys.

  Args:
    node: the node to reseed the RNG streams in.
    graph: If ``True`` (default), uses graph-mode which supports the full
      NNX feature set including shared references. If ``False``, uses
      tree-mode which treats Modules as regular JAX pytrees, avoiding the
      overhead of the graph protocol.
    policy: defines how the new scalar key for each RngStream is used to
      reseed the stream. If ``'scalars_only'`` is given (the default), an
      error is raised if the target stream key is not a scalar. If
      ``'match_shape'`` is given, the new scalar key is split to match the
      shape of the target stream key. A callable of the form
      ``(path, scalar_key, target_shape) -> new_key`` can be passed to
      define a custom reseeding policy.
    **stream_keys: a mapping of stream names to new keys. The keys can be
      either integers or ``jax.random.key``.

  Example::

    >>> from flax import nnx
    >>> import jax.numpy as jnp
    ...
    >>> class Model(nnx.Module):
    ...   def __init__(self, rngs):
    ...     self.linear = nnx.Linear(2, 3, rngs=rngs)
    ...     self.dropout = nnx.Dropout(0.5, rngs=rngs)
    ...   def __call__(self, x):
    ...     return self.dropout(self.linear(x))
    ...
    >>> model = Model(nnx.Rngs(params=0, dropout=42))
    >>> x = jnp.ones((1, 2))
    ...
    >>> y1 = model(x)
    ...
    >>> # reset the ``dropout`` stream key to 42
    >>> nnx.reseed(model, dropout=42)
    >>> y2 = model(x)
    ...
    >>> jnp.allclose(y1, y2)
    Array(True, dtype=bool)
  """
  # Resolve the named policies to their implementations; reject anything
  # that is neither a known name nor a callable.
  if policy == 'scalars_only':
    policy = _scalars_only
  elif policy == 'match_shape':
    policy = _match_shape
  elif not callable(policy):
    raise ValueError(
      f'policy must be "scalars_only", "match_shape" or a callable, '
      f'got {policy!r}'
    )
  rngs = Rngs(**stream_keys)
  for path, stream in graphlib.iter_graph(node, graph=graph):
    if isinstance(stream, RngStream):
      # Only streams whose tag was explicitly named in ``stream_keys``
      # are reseeded; all others are left untouched.
      if stream.key.tag in stream_keys:
        # Draw a fresh scalar key for this stream and adapt it to the
        # stream's existing key shape via the reseed policy.
        key = rngs[stream.key.tag]()
        key = policy(path, key, stream.key.shape)
        stream.key.set_value(key)
        # Reset the counter so the reseeded stream restarts from zero.
        stream.count.set_value(jnp.zeros(key.shape, dtype=jnp.uint32))
def restore_rngs(backups: tp.Iterable[StreamBackup], /):
  """Write backed-up key (and optionally count) values back into their streams.

  Each backup is either ``(stream, key_value)`` — as produced by
  ``backup_keys`` — or ``(stream, key_value, count_value)`` — as produced
  by ``fork_rngs``. The count is only restored when it was backed up.
  """
  for stream, key_value, *rest in backups:
    stream.key.set_value(key_value)
    if rest:
      # three-element backup: restore the count as well
      stream.count.set_value(rest[0])