summary
- flax.nnx.tabulate(obj, *input_args, depth=None, method='__call__', row_filter=<function filter_rng_streams>, table_kwargs=mappingproxy({}), column_kwargs=mappingproxy({}), console_kwargs=mappingproxy({}), compute_flops=False, compute_vjp_flops=False, **input_kwargs)
Creates a summary of the graph object represented as a table.
The table summarizes the object’s state and metadata. The table is structured as follows:
The first column represents the path of the object in the graph.
The second column represents the type of the object.
The third column represents the input arguments passed to the object’s method.
The fourth column represents the output of the object’s method.
The following columns provide information about the object’s state, grouped by Variable types.
Example:
>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> class Block(nnx.Module):
...   def __init__(self, din, dout, rngs: nnx.Rngs):
...     self.linear = nnx.Linear(din, dout, rngs=rngs)
...     self.bn = nnx.BatchNorm(dout, rngs=rngs)
...     self.dropout = nnx.Dropout(0.2, rngs=rngs)
...
...   def __call__(self, x):
...     return nnx.relu(self.dropout(self.bn(self.linear(x))))
...
>>> class Foo(nnx.Module):
...   def __init__(self, rngs: nnx.Rngs):
...     self.block1 = Block(32, 128, rngs=rngs)
...     self.block2 = Block(128, 10, rngs=rngs)
...
...   def __call__(self, x):
...     return self.block2(self.block1(x))
...
>>> foo = Foo(nnx.Rngs(0))
>>> # print(nnx.tabulate(foo, jnp.ones((1, 32))))

                                                       Foo Summary
┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓
┃ path           ┃ type      ┃ inputs         ┃ outputs        ┃ BatchStat          ┃ Param                   ┃ RngState ┃
┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩
│                │ Foo       │ float32[1,32]  │ float32[1,10]  │ 276 (1.1 KB)       │ 5,790 (23.2 KB)         │ 2 (12 B) │
├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤
│ block1         │ Block     │ float32[1,32]  │ float32[1,128] │ 256 (1.0 KB)       │ 4,480 (17.9 KB)         │ 2 (12 B) │
├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤
│ block1/linear  │ Linear    │ float32[1,32]  │ float32[1,128] │                    │ bias: float32[128]      │          │
│                │           │                │                │                    │ kernel: float32[32,128] │          │
│                │           │                │                │                    │                         │          │
│                │           │                │                │                    │ 4,224 (16.9 KB)         │          │
├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤
│ block1/bn      │ BatchNorm │ float32[1,128] │ float32[1,128] │ mean: float32[128] │ bias: float32[128]      │          │
│                │           │                │                │ var: float32[128]  │ scale: float32[128]     │          │
│                │           │                │                │                    │                         │          │
│                │           │                │                │ 256 (1.0 KB)       │ 256 (1.0 KB)            │          │
├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤
│ block1/dropout │ Dropout   │ float32[1,128] │ float32[1,128] │                    │                         │ 2 (12 B) │
├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤
│ block2         │ Block     │ float32[1,128] │ float32[1,10]  │ 20 (80 B)          │ 1,310 (5.2 KB)          │          │
├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤
│ block2/linear  │ Linear    │ float32[1,128] │ float32[1,10]  │                    │ bias: float32[10]       │          │
│                │           │                │                │                    │ kernel: float32[128,10] │          │
│                │           │                │                │                    │                         │          │
│                │           │                │                │                    │ 1,290 (5.2 KB)          │          │
├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤
│ block2/bn      │ BatchNorm │ float32[1,10]  │ float32[1,10]  │ mean: float32[10]  │ bias: float32[10]       │          │
│                │           │                │                │ var: float32[10]   │ scale: float32[10]      │          │
│                │           │                │                │                    │                         │          │
│                │           │                │                │ 20 (80 B)          │ 20 (80 B)               │          │
├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤
│ block2/dropout │ Dropout   │ float32[1,10]  │ float32[1,10]  │                    │                         │          │
├────────────────┼───────────┼────────────────┼────────────────┼────────────────────┼─────────────────────────┼──────────┤
│                │           │                │ Total          │ 276 (1.1 KB)       │ 5,790 (23.2 KB)         │ 2 (12 B) │
└────────────────┴───────────┴────────────────┴────────────────┴────────────────────┴─────────────────────────┴──────────┘
                                            Total Parameters: 6,068 (24.3 KB)
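The counts and sizes in the Param and BatchStat columns can be checked by hand. The sketch below is not part of the library; the helper names are hypothetical, and it assumes float32 elements (4 bytes each) and decimal units (1 KB = 1000 B), which reproduces the figures in the table. The RngState figure (2 elements, 12 B) is taken from the table as given rather than derived.

```python
def linear_params(din, dout):
    # kernel [din, dout] plus bias [dout]
    return din * dout + dout


def batchnorm_params(features):
    # scale [features] plus bias [features]
    return 2 * features


def batchnorm_stats(features):
    # running mean [features] plus running var [features]
    return 2 * features


def human_size(num_bytes):
    # Assumption: decimal units (1 KB = 1000 B), which matches the table.
    return f"{num_bytes} B" if num_bytes < 1000 else f"{num_bytes / 1000:.1f} KB"


def fmt(count, bytes_per_element=4):
    # Assumption: every element is a 4-byte float32.
    return f"{count:,} ({human_size(count * bytes_per_element)})"


block1 = linear_params(32, 128) + batchnorm_params(128)
block2 = linear_params(128, 10) + batchnorm_params(10)
params = block1 + block2
stats = batchnorm_stats(128) + batchnorm_stats(10)

# RngState values copied from the table: 2 elements, 12 bytes.
rng_count, rng_bytes = 2, 12
total_count = params + stats + rng_count
total_bytes = 4 * (params + stats) + rng_bytes
total = f"{total_count:,} ({human_size(total_bytes)})"

print(fmt(linear_params(32, 128)))  # 4,224 (16.9 KB)
print(fmt(block1))                  # 4,480 (17.9 KB)
print(fmt(block2))                  # 1,310 (5.2 KB)
print(fmt(params))                  # 5,790 (23.2 KB)
print(fmt(stats))                   # 276 (1.1 KB)
print(total)                        # 6,068 (24.3 KB)
```

Each parent row is the sum of its children, and the Total Parameters line adds up all Variable types, including the shared RngState counted once.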
Note that the RngState of block2/dropout is not shown in the table because it shares the same RngState with block1/dropout.
- Parameters:
obj – An object to summarize. It can be a pytree or a graph object such as nnx.Module or nnx.Optimizer.
*input_args – Positional arguments passed to the object’s method.
**input_kwargs – Keyword arguments passed to the object’s method.
depth – The maximum depth of nested objects to include in the table. Default is None.
method – The method to call on the object. Default is '__call__'.
row_filter – A callable that filters the rows to be displayed in the table. By default, it filters out rows with type nnx.RngStream.
table_kwargs – An optional dictionary with additional keyword arguments that are passed to the rich.table.Table constructor.
column_kwargs – An optional dictionary with additional keyword arguments that are passed to rich.table.Table.add_column when adding columns to the table.
console_kwargs – An optional dictionary with additional keyword arguments that are passed to rich.console.Console when rendering the table. By default, 'force_terminal' is True, and 'force_jupyter' is True if the code is running in a Jupyter notebook and False otherwise.
compute_flops – Whether to include a flops column in the table listing the estimated FLOPs cost of each module's forward pass. Does not incur actual on-device computation, compilation, or memory allocation, but still introduces overhead for large modules (e.g. an extra 20 seconds for a Stable Diffusion UNet, where tabulation would otherwise finish in 5 seconds).
compute_vjp_flops – Whether to include a vjp_flops column in the table listing the estimated FLOPs cost of each module's backward pass. Introduces a compute overhead of about 2-3x that of compute_flops.
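To make the three rich-related dictionaries concrete, here is a minimal sketch that uses rich directly, without tabulate. It only illustrates which rich call each dictionary's keys are meant for; the specific option values and the single table row are assumptions chosen for the example, not defaults of tabulate.

```python
import io

from rich import box
from rich.console import Console
from rich.table import Table

# Hypothetical dictionaries, shaped like the arguments described above.
table_kwargs = {"box": box.SIMPLE, "show_lines": False}   # -> rich.table.Table(...)
column_kwargs = {"justify": "right", "no_wrap": True}     # -> Table.add_column(...)
console_kwargs = {"force_terminal": False, "width": 60}   # -> rich.console.Console(...)

table = Table(title="Demo Summary", **table_kwargs)
for header in ("path", "type", "Param"):
    table.add_column(header, **column_kwargs)
table.add_row("block1/linear", "Linear", "4,224 (16.9 KB)")

# Render into an in-memory buffer so the result is available as a string.
buffer = io.StringIO()
Console(file=buffer, **console_kwargs).print(table)
rendered = buffer.getvalue()
print(rendered)
```

Setting 'force_terminal' to False is a common choice when the summary is written to a log file, since it keeps terminal styling codes out of the text.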
- Returns:
A string summarizing the object.