flax.traverse_util package
A utility for traversing immutable data structures.
A Traversal can be used to iterate and update complex data structures. Traversals take in an object and return a subset of its contents. For example, a Traversal could select an attribute of an object:
>>> from flax import traverse_util
>>> import dataclasses
>>> @dataclasses.dataclass
... class Foo:
...   foo: int = 0
...   bar: int = 0
...
>>> x = Foo(foo=1)
>>> iterator = traverse_util.TraverseAttr('foo').iterate(x)
>>> list(iterator)
[1]
More complex traversals can be constructed using composition. It is often useful to start from the identity traversal and use a method chain to construct the intended Traversal:
>>> data = [{'foo': 1, 'bar': 2}, {'foo': 3, 'bar': 4}]
>>> traversal = traverse_util.t_identity.each()['foo']
>>> iterator = traversal.iterate(data)
>>> list(iterator)
[1, 3]
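To make the composition idea concrete, here is a minimal pure-Python sketch that models a traversal as a function returning an iterator over the values it selects. The helper names (`identity`, `each`, `item`, `attr`, `compose`) are hypothetical and are not Flax's API. `compose` feeds every value produced by one step into the next, which for this input mirrors what the `t_identity.each()['foo']` chain above selects.

```python
from dataclasses import dataclass
from typing import Any, Callable, Iterator

# Hypothetical model: a traversal is a function from an object to an
# iterator over the parts it selects (not Flax's Traversal class).
TraversalFn = Callable[[Any], Iterator[Any]]


def identity() -> TraversalFn:
    def run(x):
        yield x  # select the object itself
    return run


def each() -> TraversalFn:
    def run(xs):
        yield from xs  # select every element of a sequence
    return run


def item(key) -> TraversalFn:
    def run(xs):
        yield xs[key]  # select xs[key]
    return run


def attr(name: str) -> TraversalFn:
    def run(x):
        yield getattr(x, name)  # select x.<name>
    return run


def compose(*steps: TraversalFn) -> TraversalFn:
    # Apply each step to every value produced by the previous step.
    def run(x):
        values = [x]
        for step in steps:
            values = [y for v in values for y in step(v)]
        yield from values
    return run


@dataclass
class Foo:
    foo: int = 0
    bar: int = 0


data = [{'foo': 1, 'bar': 2}, {'foo': 3, 'bar': 4}]
print(list(compose(identity(), each(), item('foo'))(data)))  # [1, 3]
print(list(compose(attr('foo'))(Foo(foo=1))))                # [1]
```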
Traversals can also be used to make changes using the update method:
>>> data = {'foo': Foo(bar=2)}
>>> traversal = traverse_util.t_identity['foo'].bar
>>> data = traversal.update(lambda x: x + x, data)
>>> data
{'foo': Foo(foo=0, bar=4)}
Traversals never mutate the original data. Instead, update returns a new copy of the data with the provided updates applied.
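The non-mutating update can be sketched the same way, assuming plain dicts and dataclasses: each step returns a copy (via `dict(...)` or `dataclasses.replace`) instead of assigning in place. `update_item` and `update_attr` are hypothetical helpers, not Flax functions; nesting them reproduces the `['foo'].bar` update above while leaving `original` unchanged.

```python
import dataclasses
from typing import Any, Callable


@dataclasses.dataclass
class Foo:
    foo: int = 0
    bar: int = 0


def update_item(key, fn: Callable[[Any], Any], data: dict) -> dict:
    """Return a shallow copy of `data` with data[key] replaced by fn(data[key])."""
    new = dict(data)  # shallow copy; `data` itself is never modified
    new[key] = fn(data[key])
    return new


def update_attr(name: str, fn: Callable[[Any], Any], obj: Any) -> Any:
    """Return a copy of dataclass `obj` with field `name` replaced by fn(value)."""
    return dataclasses.replace(obj, **{name: fn(getattr(obj, name))})


original = {'foo': Foo(bar=2)}
# Descend into ['foo'], then double its .bar field; both steps copy.
updated = update_item(
    'foo', lambda foo: update_attr('bar', lambda x: x + x, foo), original)

print(updated)   # {'foo': Foo(foo=0, bar=4)}
print(original)  # {'foo': Foo(foo=0, bar=2)}
```

Because only the touched path is copied, untouched entries are shared between `original` and `updated`, which keeps the copy cheap.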
Dict utils
- flax.traverse_util.flatten_dict(xs, keep_empty_nodes=False, is_leaf=None, sep=None)
Flatten a nested dictionary.
The nested keys are flattened to a tuple. See unflatten_dict on how to restore the nested dictionary structure.
Example:
>>> from flax.traverse_util import flatten_dict
>>> xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}}
>>> flat_xs = flatten_dict(xs)
>>> flat_xs
{('foo',): 1, ('bar', 'a'): 2}
Note that empty dictionaries are ignored (unless keep_empty_nodes=True) and will not be restored by unflatten_dict.
- Parameters:
xs – a nested dictionary
keep_empty_nodes – replaces empty dictionaries with traverse_util.empty_node.
is_leaf – an optional function that takes the next nested dictionary and nested keys and returns True if the nested dictionary is a leaf (i.e., should not be flattened further).
sep – if specified, then the keys of the returned dictionary will be sep-joined strings (if None, then keys will be tuples).
- Returns:
The flattened dictionary.
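A minimal recursive sketch of the flattening rules described above, covering `sep` and `keep_empty_nodes` (but not `is_leaf`). It is not Flax's implementation: `EMPTY` is a hypothetical sentinel standing in for `traverse_util.empty_node`, and keys are assumed to be strings so that `sep.join` works.

```python
from typing import Any, Dict, Optional, Tuple

# Hypothetical sentinel standing in for traverse_util.empty_node.
EMPTY = object()


def flatten(xs: dict, keep_empty_nodes: bool = False,
            sep: Optional[str] = None) -> Dict[Any, Any]:
    """Simplified sketch of nested-dict flattening."""
    def key(path: Tuple[str, ...]):
        # Tuple keys by default, sep-joined strings when sep is given.
        return path if sep is None else sep.join(path)

    def walk(node, prefix):
        if not isinstance(node, dict):
            return {key(prefix): node}  # leaf value
        if not node:  # empty dict: dropped unless keep_empty_nodes
            if keep_empty_nodes and prefix:
                return {key(prefix): EMPTY}
            return {}
        out = {}
        for k, v in node.items():
            out.update(walk(v, prefix + (k,)))
        return out

    return walk(xs, ())


xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}}
print(flatten(xs))           # {('foo',): 1, ('bar', 'a'): 2}
print(flatten(xs, sep='/'))  # {'foo': 1, 'bar/a': 2}
print(flatten(xs, keep_empty_nodes=True)[('bar', 'b')] is EMPTY)  # True
```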
- flax.traverse_util.unflatten_dict(xs, sep=None)
Unflatten a dictionary.
See flatten_dict.
Example:
>>> from flax.traverse_util import unflatten_dict
>>> flat_xs = {
...   ('foo',): 1,
...   ('bar', 'a'): 2,
... }
>>> xs = unflatten_dict(flat_xs)
>>> xs
{'foo': 1, 'bar': {'a': 2}}
- Parameters:
xs – a flattened dictionary
sep – separator (same as used with flatten_dict()).
- Returns:
The nested dictionary.
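The inverse direction can be sketched by walking each key path and creating intermediate dictionaries as needed. This hypothetical `unflatten` helper is not Flax's implementation; it assumes string keys when `sep` is given and ignores the empty-node case entirely.

```python
from typing import Any, Dict, Optional


def unflatten(flat: Dict[Any, Any], sep: Optional[str] = None) -> dict:
    """Simplified sketch: rebuild a nested dict from tuple or sep-joined keys."""
    result: dict = {}
    for path, value in flat.items():
        if sep is not None:
            path = tuple(path.split(sep))  # 'bar/a' -> ('bar', 'a')
        cursor = result
        for k in path[:-1]:
            cursor = cursor.setdefault(k, {})  # create intermediate dicts
        cursor[path[-1]] = value
    return result


print(unflatten({('foo',): 1, ('bar', 'a'): 2}))   # {'foo': 1, 'bar': {'a': 2}}
print(unflatten({'foo': 1, 'bar/a': 2}, sep='/'))  # {'foo': 1, 'bar': {'a': 2}}
```

Using the same `sep` in both directions is what makes the round trip lossless for non-empty dictionaries.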
- flax.traverse_util.path_aware_map(f, nested_dict)
A map function that operates over nested dictionary structures while taking the path to each leaf into account.
Example:
>>> import jax.numpy as jnp
>>> from flax import traverse_util
>>> params = {'a': {'x': 10, 'y': 3}, 'b': {'x': 20}}
>>> f = lambda path, x: x + 5 if 'x' in path else -x
>>> traverse_util.path_aware_map(f, params)
{'a': {'x': 15, 'y': -3}, 'b': {'x': 25}}
- Parameters:
f – A callable that takes in (path, value) arguments and maps them to a new value. Here path is a tuple of strings.
nested_dict – A nested dictionary structure.
- Returns:
A new nested dictionary structure with the mapped values.
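The behavior of path_aware_map can be approximated with a direct recursion that threads the key path down to each leaf. The `path_map` function below is a hypothetical stand-in, not Flax's implementation; it treats any non-dict value as a leaf and returns empty dicts as empty dicts. The example labels entries by the last element of their path, for instance to treat biases differently from kernels.

```python
from typing import Any, Callable, Tuple


def path_map(f: Callable[[Tuple[str, ...], Any], Any], nested: dict,
             prefix: Tuple[str, ...] = ()) -> dict:
    """Apply f(path, leaf) to every non-dict leaf, keeping the nesting."""
    out = {}
    for k, v in nested.items():
        path = prefix + (k,)
        # Recurse into sub-dicts; call f on everything else.
        out[k] = path_map(f, v, path) if isinstance(v, dict) else f(path, v)
    return out


params = {'dense': {'kernel': 1.0, 'bias': 0.0}, 'scale': 2.0}
labels = path_map(lambda path, _: 'no_decay' if path[-1] == 'bias' else 'decay',
                  params)
print(labels)  # {'dense': {'kernel': 'decay', 'bias': 'no_decay'}, 'scale': 'decay'}
```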