Filters
Flax NNX uses Filters extensively as a way to create nnx.State groups in APIs, such as nnx.split, nnx.state(), and many of the Flax NNX transformations (transforms).
In this guide you will learn how to:
- Use Filters to group Flax NNX variables and states into subgroups;
- Understand relationships between types, such as nnx.Param or nnx.BatchStat, and Filters;
- Express your Filters flexibly with the nnx.filterlib.Filter language.
In the following example nnx.Param and nnx.BatchStat are used as Filters to split the model into two groups: one with the parameters and the other with the batch statistics:
from flax import nnx

class Foo(nnx.Module):
  def __init__(self):
    self.a = nnx.Param(0)
    self.b = nnx.BatchStat(True)

foo = Foo()
graphdef, params, batch_stats = nnx.split(foo, nnx.Param, nnx.BatchStat)

print(f'{params = }')
print(f'{batch_stats = }')

params = State({
  'a': Param(
    value=0
  )
})
batch_stats = State({
  'b': BatchStat(
    value=True
  )
})
Let’s dive deeper into Filters.
The Filter Protocol
In general, Flax Filters are predicate functions of the form:
(path: tuple[Key, ...], value: Any) -> bool
where:
- Key is a hashable and comparable type;
- path is a tuple of Keys representing the path to the value in a nested structure; and
- value is the value at the path.
The function returns True if the value should be included in the group, and False otherwise.
Types are not functions of this form. They are treated as Filters because, as you will learn in the next section, types and some other literals are converted to predicates. For example, nnx.Param is roughly converted to a predicate like this:
def is_param(path, value) -> bool:
  return isinstance(value, nnx.Param)
print(f'{is_param((), nnx.Param(0)) = }')
is_param((), nnx.Param(0)) = True
Such a function matches any value that is an instance of nnx.Param. Internally, Flax NNX uses OfType, which defines a callable of this form for a given type:
is_param = nnx.OfType(nnx.Param)
print(f'{is_param((), nnx.Param(0)) = }')
is_param((), nnx.Param(0)) = True
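Because a Filter is just a callable over (path, value), a predicate can also decide based on the path rather than the type. The following is a minimal sketch that does not use Flax: the flat dictionary and the path_contains helper are hypothetical stand-ins, used only to illustrate the protocol on plain Python values.

```python
from typing import Any, Callable, Hashable

# Type aliases mirroring the protocol described above (illustrative only).
Key = Hashable
Path = tuple[Key, ...]
Predicate = Callable[[Path, Any], bool]


def path_contains(key: Key) -> Predicate:
  """Build a predicate that is True when `key` appears anywhere in the path."""
  def predicate(path: Path, value: Any) -> bool:
    return key in path
  return predicate


# Hypothetical flat state: path tuples mapped to plain values.
flat_state = {
  ('encoder', 'kernel'): 1.0,
  ('encoder', 'bias'): 0.0,
  ('decoder', 'kernel'): 2.0,
}

is_bias = path_contains('bias')
biases = {path: value for path, value in flat_state.items() if is_bias(path, value)}
print(biases)
```

The predicate ignores the value entirely here, which shows that the two arguments are independent sources of information for a Filter.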
The Filter DSL
Flax NNX exposes a small domain-specific language (DSL), formalized as the nnx.filterlib.Filter type. This means users don’t have to write predicate functions like the ones in the previous section.
Here is a list of all the callable Filters included in Flax NNX, and their corresponding DSL literals (when available):
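| Literal | Callable | Description |
|---|---|---|
| ... or True | Everything() | Matches all values |
| None or False | Nothing() | Matches no values |
| type | OfType(type) | Matches values that are instances of type |
| | PathContains(key) | Matches values that have an associated path that contains the given key |
| '{filter}' str | WithTag('{filter}') | Matches values that have string tag that is equal to '{filter}' |
| (*filters) tuple or [*filters] list | Any(*filters) | Matches values that match any of the inner filters |
| | All(*filters) | Matches values that match all of the inner filters |
| | Not(filter) | Matches values that do not match the inner filter |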
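To make the combinator rows more concrete, here is a rough sketch of how "any", "all", and "not" combinators can be built from plain predicates. It does not use Flax: any_of, all_of, and negate are hypothetical names chosen for illustration, and the two leaf predicates are made up for the example.

```python
from typing import Any, Callable

Predicate = Callable[[tuple, Any], bool]


def any_of(*predicates: Predicate) -> Predicate:
  """True when at least one inner predicate matches."""
  return lambda path, value: any(p(path, value) for p in predicates)


def all_of(*predicates: Predicate) -> Predicate:
  """True only when every inner predicate matches."""
  return lambda path, value: all(p(path, value) for p in predicates)


def negate(predicate: Predicate) -> Predicate:
  """True when the inner predicate does not match."""
  return lambda path, value: not predicate(path, value)


# Hypothetical leaf predicates.
def is_float(path, value):
  return isinstance(value, float)


def in_encoder(path, value):
  return 'encoder' in path


encoder_floats = all_of(is_float, in_encoder)
not_encoder = negate(in_encoder)
float_or_encoder = any_of(is_float, in_encoder)

print(encoder_floats(('encoder', 'kernel'), 1.0))  # True
print(encoder_floats(('decoder', 'kernel'), 1.0))  # False
print(not_encoder(('decoder', 'kernel'), 1.0))     # True
print(float_or_encoder(('encoder', 'step'), 3))    # True
```

Since every combinator returns another (path, value) predicate, combinators nest freely, which is what lets a small set of building blocks express fairly specific selections.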
Let’s check out the DSL in action by using nnx.vmap as an example. Consider the following:
- You want to vectorize all parameters;
- Apply 'dropout' Rng(Keys|Counts) on the 0th axis; and
- Broadcast the rest.
To do this, you can use the following Filters to define an nnx.StateAxes object that you can pass to nnx.vmap’s in_axes to specify how the model’s various sub-states should be vectorized:
state_axes = nnx.StateAxes({(nnx.Param, 'dropout'): 0, ...: None})
@nnx.vmap(in_axes=(state_axes, 0))
def forward(model, x):
  ...
Here (nnx.Param, 'dropout') expands to Any(OfType(nnx.Param), WithTag('dropout')) and ... expands to Everything().
If you wish to manually convert a literal into a predicate, you can use nnx.filterlib.to_predicate:
is_param = nnx.filterlib.to_predicate(nnx.Param)
everything = nnx.filterlib.to_predicate(...)
nothing = nnx.filterlib.to_predicate(False)
params_or_dropout = nnx.filterlib.to_predicate((nnx.Param, 'dropout'))
print(f'{is_param = }')
print(f'{everything = }')
print(f'{nothing = }')
print(f'{params_or_dropout = }')
is_param = OfType(<class 'flax.nnx.variablelib.Param'>)
everything = Everything()
nothing = Nothing()
params_or_dropout = Any(OfType(<class 'flax.nnx.variablelib.Param'>), WithTag('dropout'))
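As a rough mental model of this conversion, the sketch below dispatches on the kind of literal and returns a plain predicate. It is a simplified stand-in written without Flax, not the library's implementation: to_predicate_sketch and the Tagged class are hypothetical, and the string case assumes that a tagged value exposes its tag as a tag attribute.

```python
from typing import Any, Callable

Predicate = Callable[[tuple, Any], bool]


def to_predicate_sketch(literal: Any) -> Predicate:
  """Simplified stand-in for literal-to-predicate conversion (not Flax's code)."""
  if isinstance(literal, str):
    # Assumption: a string matches values whose `tag` attribute equals it.
    return lambda path, value: getattr(value, 'tag', None) == literal
  if isinstance(literal, type):
    # Checked before `callable`, because classes are callable too.
    return lambda path, value: isinstance(value, literal)
  if literal is ... or literal is True:
    return lambda path, value: True
  if literal is None or literal is False:
    return lambda path, value: False
  if isinstance(literal, (list, tuple)):
    # A sequence matches when any of its converted members matches.
    inner = [to_predicate_sketch(item) for item in literal]
    return lambda path, value: any(p(path, value) for p in inner)
  if callable(literal):
    # Already a predicate: pass it through unchanged.
    return literal
  raise TypeError(f'Invalid filter: {literal!r}')


class Tagged:
  """Hypothetical value type carrying a `tag`."""
  def __init__(self, tag: str):
    self.tag = tag


floats_or_dropout = to_predicate_sketch((float, 'dropout'))
print(floats_or_dropout((), 1.0))                # True
print(floats_or_dropout((), Tagged('dropout')))  # True
print(floats_or_dropout((), Tagged('params')))   # False
print(to_predicate_sketch(...)((), object()))    # True
print(to_predicate_sketch(None)((), object()))   # False
```

The order of the checks matters in a dispatcher like this: a class would also pass the callable test, so the type case has to come first.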
Grouping States
With the knowledge of Filters from previous sections at hand, let’s learn how to roughly implement nnx.split. Here are the key ideas:
1. Use nnx.graph.flatten to get the GraphDef and nnx.State representation of the node.
2. Convert all the Filters to predicates.
3. Use State.flat_state to get the flat representation of the state.
4. Traverse all the (path, value) pairs in the flat state and group them according to the predicates.
5. Use State.from_flat_path to convert the flat states to nested nnx.States.
from typing import Any

KeyPath = tuple[nnx.graph.Key, ...]

def split(node, *filters):
  # Flatten the node into its GraphDef and its state.
  graphdef, state = nnx.graph.flatten(node)
  # Convert every Filter literal into a predicate.
  predicates = [nnx.filterlib.to_predicate(f) for f in filters]
  # One flat {path: value} mapping per predicate.
  flat_states: list[dict[KeyPath, Any]] = [{} for _ in predicates]

  for path, value in state:
    for i, predicate in enumerate(predicates):
      if predicate(path, value):
        # The first matching predicate claims the value.
        flat_states[i][path] = value
        break
    else:
      # No predicate matched this value.
      raise ValueError(f'No filter matched {path = } {value = }')

  # Rebuild a nested State from each flat mapping.
  states: tuple[nnx.GraphState, ...] = tuple(
    nnx.State.from_flat_path(flat_state) for flat_state in flat_states
  )
  return graphdef, *states
# Let's test it.
foo = Foo()
graphdef, params, batch_stats = split(foo, nnx.Param, nnx.BatchStat)
print(f'{params = }')
print(f'{batch_stats = }')
params = State({
  'a': Param(
    value=0
  )
})
batch_stats = State({
  'b': BatchStat(
    value=True
  )
})
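The split function above raises a ValueError when a value matches none of the Filters. The sketch below reproduces the same first-match loop on a plain dictionary, without Flax, to show that a trailing catch-all predicate (the role played by ... in the DSL) collects whatever the earlier predicates leave behind. The names group_flat, is_float, and everything are hypothetical and exist only for this illustration.

```python
from typing import Any, Callable

Predicate = Callable[[tuple, Any], bool]


def group_flat(flat_state: dict, *predicates: Predicate) -> list[dict]:
  """First-match grouping over a flat {path: value} mapping."""
  groups: list[dict] = [{} for _ in predicates]
  for path, value in flat_state.items():
    for group, predicate in zip(groups, predicates):
      if predicate(path, value):
        group[path] = value
        break
    else:
      raise ValueError(f'No filter matched {path = } {value = }')
  return groups


def is_float(path, value):
  return isinstance(value, float)


def everything(path, value):
  return True


flat_state = {('a',): 1.0, ('b',): True, ('c',): 'label'}

# Without a catch-all, the non-float entries have nowhere to go.
try:
  group_flat(flat_state, is_float)
except ValueError as err:
  print('error:', err)

# With a catch-all in last position, every entry lands in some group.
floats, rest = group_flat(flat_state, is_float, everything)
print(floats)
print(rest)
```

Placing the catch-all last is essential: in any earlier position it would claim every value before the more specific predicates are consulted.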
Note: Filtering is order-dependent. The first Filter that matches a value keeps it, so you should place more specific Filters before more general ones.
For example, as demonstrated below, if you:
- Create a SpecialParam type that is a subclass of nnx.Param, and a Bar object (subclassing nnx.Module) that contains both types of parameters; and
- Try to split the nnx.Params before the SpecialParams
then all the values will be placed in the nnx.Param group, and the SpecialParam group will be empty because all SpecialParams are also nnx.Params:
class SpecialParam(nnx.Param):
  pass

class Bar(nnx.Module):
  def __init__(self):
    self.a = nnx.Param(0)
    self.b = SpecialParam(0)

bar = Bar()
graphdef, params, special_params = split(bar, nnx.Param, SpecialParam) # wrong!
print(f'{params = }')
print(f'{special_params = }')

params = State({
  'a': Param(
    value=0
  ),
  'b': SpecialParam(
    value=0
  )
})
special_params = State({})
Reversing the order ensures that the SpecialParams are captured first:
graphdef, special_params, params = split(bar, SpecialParam, nnx.Param) # correct!
print(f'{params = }')
print(f'{special_params = }')
params = State({
  'a': Param(
    value=0
  )
})
special_params = State({
  'b': SpecialParam(
    value=0
  )
})
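Because a misordered Filter list fails silently (the later group simply comes back empty), it can be useful to check type Filters before splitting. The helper below is a hypothetical sketch, not part of Flax: find_shadowed reports every type Filter that appears after one of its own base classes, and the three classes are plain stand-ins for the variable types used above.

```python
def find_shadowed(filters: list) -> list[tuple[type, type]]:
  """Return (earlier, later) pairs where the later type filter can never match."""
  shadowed = []
  for index, later in enumerate(filters):
    if not isinstance(later, type):
      continue  # Only type filters are checked in this sketch.
    for earlier in filters[:index]:
      # Every instance of `later` is also an instance of `earlier`,
      # so first-match grouping hands all of them to `earlier`.
      if isinstance(earlier, type) and issubclass(later, earlier):
        shadowed.append((earlier, later))
        break
  return shadowed


# Plain stand-ins for the variable types (hypothetical, no Flax involved).
class Param: ...
class SpecialParam(Param): ...
class BatchStat: ...


print(find_shadowed([Param, SpecialParam, BatchStat]))  # one shadowed pair
print(find_shadowed([SpecialParam, Param, BatchStat]))  # no shadowed pairs
```

This check only covers type literals; strings, tuples, and custom callables would need their own reasoning, so it is best read as a lint-style aid rather than a guarantee.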