flax.struct package
Utilities for defining custom classes that can be used with JAX transformations.
- flax.struct.dataclass(clz=None, **kwargs)
Create a class which can be passed to functional transformations.
Note
Inherit from PyTreeNode instead to avoid type checking issues when using PyType.
JAX transformations such as jax.jit and jax.grad require objects that are immutable and can be mapped over using the jax.tree_util methods. The dataclass decorator makes it easy to define custom classes that can be passed safely to JAX. Define JAX data as normal attribute fields, and use pytree_node=False to define static metadata.
See example:
>>> from flax import struct
>>> import jax
>>> from typing import Any, Callable
>>> @struct.dataclass
... class Model:
...   params: Any
...   # use pytree_node=False to indicate an attribute should not be touched
...   # by JAX transformations.
...   apply_fn: Callable = struct.field(pytree_node=False)
...   def __call__(self, *args):
...     return self.apply_fn(*args)
>>> params = {}
>>> params_b = {}
>>> apply_fn = lambda v, x: x
>>> model = Model(params, apply_fn)
>>> # model.params = params_b  # Model is immutable. This will raise an error.
>>> model_b = model.replace(params=params_b)  # Use the replace method instead.
>>> # This class can now be used safely in JAX to compute gradients w.r.t. the
>>> # parameters.
>>> model = Model(params, apply_fn)
>>> loss_fn = lambda model: 3.
>>> model_grad = jax.grad(loss_fn)(model)
Note that dataclasses have an auto-generated __init__ where the arguments of the constructor and the attributes of the created instance match 1:1. If you want a “smart constructor”, for example to optionally derive some of the attributes from others, add a separate static method or class method. Consider the following example:
>>> @struct.dataclass
... class DirectionAndScaleKernel:
...   direction: jax.Array
...   scale: jax.Array
...   @classmethod
...   def create(cls, kernel):
...     scale = jax.numpy.linalg.norm(kernel, axis=0, keepdims=True)
...     direction = kernel / scale
...     return cls(direction, scale)
- Parameters:
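clz – the class that will be transformed by the decorator. If clz is None, a decorator is returned that applies the given **kwargs, which allows usage such as @struct.dataclass(kw_only=True).
**kwargs – keyword arguments forwarded to the standard library dataclasses.dataclass decorator.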
- Returns:
The new class.
- class flax.struct.PyTreeNode(*args, **kwargs)
Base class for dataclasses that should act like a JAX pytree node.
See flax.struct.dataclass for the jax.tree_util behavior. This base class additionally avoids type checking errors when using PyType.
Example:
>>> from flax import struct
>>> import jax
>>> from typing import Any, Callable
>>> class Model(struct.PyTreeNode):
...   params: Any
...   # use pytree_node=False to indicate an attribute should not be touched
...   # by JAX transformations.
...   apply_fn: Callable = struct.field(pytree_node=False)
...   def __call__(self, *args):
...     return self.apply_fn(*args)
>>> params = {}
>>> params_b = {}
>>> apply_fn = lambda v, x: x
>>> model = Model(params, apply_fn)
>>> # model.params = params_b  # Model is immutable. This will raise an error.
>>> model_b = model.replace(params=params_b)  # Use the replace method instead.
>>> # This class can now be used safely in JAX to compute gradients w.r.t. the
>>> # parameters.
>>> model = Model(params, apply_fn)
>>> loss_fn = lambda model: 3.
>>> model_grad = jax.grad(loss_fn)(model)