flax.traverse_util package

A utility for traversing immutable data structures.

A Traversal can be used to iterate and update complex data structures. Traversals take in an object and return a subset of its contents. For example, a Traversal could select an attribute of an object:

>>> from flax import traverse_util
>>> import dataclasses

>>> @dataclasses.dataclass
... class Foo:
...   foo: int = 0
...   bar: int = 0
...
>>> x = Foo(foo=1)
>>> iterator = traverse_util.TraverseAttr('foo').iterate(x)
>>> list(iterator)
[1]

More complex traversals can be constructed using composition. It is often useful to start from the identity traversal and use a method chain to construct the intended Traversal:

>>> data = [{'foo': 1, 'bar': 2}, {'foo': 3, 'bar': 4}]
>>> traversal = traverse_util.t_identity.each()['foo']
>>> iterator = traversal.iterate(data)
>>> list(iterator)
[1, 3]
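Composition can be pictured as chaining small selection steps, where each step feeds the objects it selects to the next one. The plain-Python sketch below mimics the `t_identity.each()['foo']` chain using hypothetical helper functions `identity`, `each`, and `item`. It only illustrates the idea and does not reproduce flax's Traversal classes.

```python
from typing import Any, Callable, Iterator

# Hypothetical model for this sketch only: a traversal is a function that
# maps an object to an iterator over the sub-objects it selects.
Traversal = Callable[[Any], Iterator[Any]]


def identity() -> Traversal:
    # Select the object itself (plays the role of t_identity).
    def run(x: Any) -> Iterator[Any]:
        yield x
    return run


def each(parent: Traversal) -> Traversal:
    # Select every element of each container produced by `parent`.
    def run(x: Any) -> Iterator[Any]:
        for container in parent(x):
            yield from container
    return run


def item(parent: Traversal, key: Any) -> Traversal:
    # Select obj[key] for each object produced by `parent`.
    def run(x: Any) -> Iterator[Any]:
        for obj in parent(x):
            yield obj[key]
    return run


data = [{'foo': 1, 'bar': 2}, {'foo': 3, 'bar': 4}]
# Mirrors the chain t_identity.each()['foo'] from the example above.
traversal = item(each(identity()), 'foo')
print(list(traversal(data)))  # [1, 3]
```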

Traversals can also be used to make changes using the update method:

>>> data = {'foo': Foo(bar=2)}
>>> traversal = traverse_util.t_identity['foo'].bar
>>> data = traversal.update(lambda x: x + x, data)
>>> data
{'foo': Foo(foo=0, bar=4)}

Traversals never mutate the original data. Instead, update returns a copy of the data with the provided updates applied.
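To make the copy-on-update behavior concrete, here is a hand-written sketch of what the `t_identity['foo'].bar` update above amounts to for this particular data. The helper `update_item_attr` is hypothetical and not part of flax. It assumes the selected value is a dataclass, so that `dataclasses.replace` can build the modified copy.

```python
import dataclasses
from typing import Any, Callable


@dataclasses.dataclass
class Foo:
    foo: int = 0
    bar: int = 0


def update_item_attr(data: dict, key: Any, attr: str,
                     fn: Callable[[Any], Any]) -> dict:
    """Hypothetical helper: apply fn to data[key].<attr> without mutation."""
    obj = data[key]
    # dataclasses.replace builds a new instance; `obj` itself is untouched.
    new_obj = dataclasses.replace(obj, **{attr: fn(getattr(obj, attr))})
    # Shallow-copy the outer dict and swap in the new value.
    new_data = dict(data)
    new_data[key] = new_obj
    return new_data


original = {'foo': Foo(bar=2)}
updated = update_item_attr(original, 'foo', 'bar', lambda x: x + x)
print(updated)   # {'foo': Foo(foo=0, bar=4)}
print(original)  # {'foo': Foo(foo=0, bar=2)}
```

Because each step builds new objects instead of assigning into existing ones, any other reference to `original` keeps seeing the old values.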

Dict utils

flax.traverse_util.flatten_dict(xs, keep_empty_nodes=False, is_leaf=None, sep=None)

Flatten a nested dictionary.

The nested keys are flattened to a tuple. See unflatten_dict for how to restore the nested dictionary structure.

Example:

>>> from flax.traverse_util import flatten_dict

>>> xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}}
>>> flat_xs = flatten_dict(xs)
>>> flat_xs
{('foo',): 1, ('bar', 'a'): 2}

Note that empty dictionaries are dropped by default and will not be restored by unflatten_dict; set keep_empty_nodes=True to preserve them.

Parameters:
  • xs – a nested dictionary

  • keep_empty_nodes – replaces empty dictionaries with traverse_util.empty_node.

  • is_leaf – an optional function that takes the next nested dictionary and nested keys and returns True if the nested dictionary is a leaf (i.e., should not be flattened further).

  • sep – if specified, then the keys of the returned dictionary will be sep-joined strings (if None, then keys will be tuples).

Returns:

The flattened dictionary.
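The parameters above can be tied together in a short plain-Python sketch of the documented behavior: tuple keys by default, sep-joined string keys when sep is given, a placeholder for empty dictionaries when keep_empty_nodes is set, and an is_leaf hook that stops the recursion. The names `flatten` and `EMPTY_NODE` are hypothetical, `EMPTY_NODE` stands in for traverse_util.empty_node, and passing `(path, node)` to is_leaf is this sketch's own convention. This is not flax's actual implementation.

```python
from typing import Any, Callable, Optional

# Hypothetical stand-in for traverse_util.empty_node in this sketch.
EMPTY_NODE = object()


def flatten(xs: dict, keep_empty_nodes: bool = False,
            is_leaf: Optional[Callable[[tuple, Any], bool]] = None,
            sep: Optional[str] = None) -> dict:
    """Sketch of the documented flatten_dict semantics (not flax's code)."""

    def key(path: tuple):
        # Tuple keys by default; sep-joined strings (requires str keys) otherwise.
        return path if sep is None else sep.join(path)

    def walk(node: Any, path: tuple) -> dict:
        # Non-dicts, and dicts that is_leaf claims, are emitted as leaves.
        if not isinstance(node, dict) or (is_leaf and is_leaf(path, node)):
            return {key(path): node}
        if not node:
            # Empty dicts are dropped unless keep_empty_nodes is set.
            if keep_empty_nodes and path:
                return {key(path): EMPTY_NODE}
            return {}
        out = {}
        for k, v in node.items():
            out.update(walk(v, path + (k,)))
        return out

    return walk(xs, ())


xs = {'foo': 1, 'bar': {'a': 2, 'b': {}}}
print(flatten(xs))           # {('foo',): 1, ('bar', 'a'): 2}
print(flatten(xs, sep='/'))  # {'foo': 1, 'bar/a': 2}
# With keep_empty_nodes=True the empty dict under 'b' becomes a placeholder.
print(flatten(xs, keep_empty_nodes=True)[('bar', 'b')] is EMPTY_NODE)  # True
```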

flax.traverse_util.unflatten_dict(xs, sep=None)

Unflatten a dictionary.

This reverses the flattening performed by flatten_dict; see flatten_dict for details.

Example:

>>> from flax.traverse_util import unflatten_dict
>>> flat_xs = {
...   ('foo',): 1,
...   ('bar', 'a'): 2,
... }
>>> xs = unflatten_dict(flat_xs)
>>> xs
{'foo': 1, 'bar': {'a': 2}}

Parameters:
  • xs – a flattened dictionary

  • sep – separator (same as used with flatten_dict()).

Returns:

The nested dictionary.
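Conceptually, unflattening walks each key path and creates intermediate dictionaries as needed. The plain-Python sketch below illustrates that idea. `unflatten` and `EMPTY_NODE` are hypothetical names for this illustration, with `EMPTY_NODE` standing in for traverse_util.empty_node; it is not flax's implementation.

```python
from typing import Optional

# Hypothetical stand-in for traverse_util.empty_node in this sketch.
EMPTY_NODE = object()


def unflatten(flat: dict, sep: Optional[str] = None) -> dict:
    """Sketch of the documented unflatten_dict semantics (not flax's code)."""
    result: dict = {}
    for path, value in flat.items():
        if sep is not None:
            # String keys are split back into a path of keys.
            path = tuple(path.split(sep))
        if value is EMPTY_NODE:
            # Placeholders for empty dicts become {} again.
            value = {}
        cursor = result
        for k in path[:-1]:
            # Create intermediate dicts on demand and descend into them.
            cursor = cursor.setdefault(k, {})
        cursor[path[-1]] = value
    return result


print(unflatten({('foo',): 1, ('bar', 'a'): 2}))   # {'foo': 1, 'bar': {'a': 2}}
print(unflatten({'foo': 1, 'bar/a': 2}, sep='/'))  # {'foo': 1, 'bar': {'a': 2}}
print(unflatten({('bar', 'b'): EMPTY_NODE}))       # {'bar': {'b': {}}}
```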

flax.traverse_util.path_aware_map(f, nested_dict)

A map function that operates over nested dictionary structures while taking the path to each leaf into account.

Example:

>>> from flax import traverse_util

>>> params = {'a': {'x': 10, 'y': 3}, 'b': {'x': 20}}
>>> f = lambda path, x: x + 5 if 'x' in path else -x
>>> traverse_util.path_aware_map(f, params)
{'a': {'x': 15, 'y': -3}, 'b': {'x': 25}}

Parameters:
  • f – A callable that takes in (path, value) arguments and maps them to a new value. Here path is a tuple of strings.

  • nested_dict – A nested dictionary structure.

Returns:

A new nested dictionary structure with the mapped values.
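The recursion behind a path-aware map can be sketched in a few lines of plain Python. `path_map` is a hypothetical name, not flax's implementation. It assumes that only `dict` instances are internal nodes (other mapping types are treated as leaves here), and it keeps empty sub-dictionaries as `{}`; the library function may handle such edge cases differently.

```python
from typing import Any, Callable


def path_map(f: Callable[[tuple, Any], Any], nested: dict,
             prefix: tuple = ()) -> dict:
    """Hypothetical sketch of a path-aware map over nested dicts."""
    out = {}
    for k, v in nested.items():
        path = prefix + (k,)
        if isinstance(v, dict):
            # Internal node: recurse with the extended path.
            out[k] = path_map(f, v, path)
        else:
            # Leaf: hand f the full key path together with the value.
            out[k] = f(path, v)
    return out


params = {'a': {'x': 10, 'y': 3}, 'b': {'x': 20}}
print(path_map(lambda path, x: x + 5 if 'x' in path else -x, params))
# {'a': {'x': 15, 'y': -3}, 'b': {'x': 25}}

# Illustrative use: label leaves by their last key, e.g. to build a mask.
weights = {'dense': {'kernel': 1.0, 'bias': 0.0}}
print(path_map(lambda path, _: path[-1] == 'kernel', weights))
# {'dense': {'kernel': True, 'bias': False}}
```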