# Source code for flax.nnx.summary

# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from collections import defaultdict
import dataclasses
import inspect
import io
import typing as tp
from types import MappingProxyType
import functools
import itertools

import jax
import numpy as np
import rich.console
import rich.table
import rich.text
import yaml
import jax.numpy as jnp

from flax import nnx
from flax import typing
from flax.nnx import graphlib, statelib, variablelib

from functools import wraps


# Detect whether we are running inside IPython/Jupyter so rich can pick the
# appropriate rendering mode when the table is printed.
try:
  from IPython import get_ipython
except ImportError:
  in_ipython = False
else:
  in_ipython = get_ipython() is not None

# Custom YAML dumper that renders None as the string 'None' (rather than the
# YAML 'null') so the summary table reads naturally.
class NoneDumper(yaml.SafeDumper):
  pass


def _represent_none(dumper, data):
  # Emit the literal string 'None' for NoneType values.
  return dumper.represent_scalar('tag:yaml.org,2002:str', 'None')


NoneDumper.add_representer(type(None), _represent_none)

class SizeBytes(typing.SizeBytes):
  """Size counter whose repr uses rich markup: ``<count> [dim](<bytes>)[/dim]``."""

  def __repr__(self) -> str:
    return f'{self.size:,} [dim]({_bytes_repr(self.bytes)})[/dim]'

class ObjectInfo(tp.NamedTuple):
  """Per-node summary data collected by ``_collect_stats``."""

  # Location of the node within the object graph.
  path: statelib.PathParts
  # Aggregated sizes per Variable type (includes children's contributions).
  stats: dict[type[variablelib.Variable], SizeBytes]
  # Direct child Variables of this node, grouped by Variable type.
  variable_groups: defaultdict[
    type[variablelib.Variable], defaultdict[typing.Key, variablelib.Variable]
  ]

# Maps id(node) -> ObjectInfo for graph nodes, or None for Variable leaves
# that have already been attributed to their parent node's stats.
NodeStats = dict[int, ObjectInfo | None]

def _collect_stats(
  path: statelib.PathParts,
  node: tp.Any,
  node_stats: NodeStats,
  object_types: set[type],
):
  """Recursively collect per-Variable-type size statistics for a graph node.

  Populates ``node_stats`` (keyed by ``id``) with an ``ObjectInfo`` for every
  reachable graph node and adds each visited ``nnx.Pytree`` type to
  ``object_types``. Variable leaves get a ``None`` entry so shared Variables
  are only counted once.

  Raises:
    ValueError: if ``node`` is neither a graph node nor a Variable.
  """
  if not graphlib.is_node(node) and not isinstance(node, variablelib.Variable):
    raise ValueError(f'Expected a graph node or Variable, got {type(node)!r}.')

  # Already visited (shared/aliased node): don't double-count its stats.
  if id(node) in node_stats:
    return

  stats: dict[type[variablelib.Variable], SizeBytes] = {}
  variable_groups: defaultdict[
    type[variablelib.Variable], defaultdict[typing.Key, variablelib.Variable]
  ] = defaultdict(lambda: defaultdict())
  node_stats[id(node)] = ObjectInfo(path, stats, variable_groups)

  if isinstance(node, nnx.Pytree):
    # Tag the object so traced method calls can be matched back to this info.
    node._nnx_tabulate_id = id(node)  # type: ignore
    object_types.add(type(node))

  node_impl = graphlib.get_node_impl(node)
  assert node_impl is not None
  node_dict = node_impl.node_dict(node)
  for key, value in node_dict.items():
    if id(value) in node_stats:
      continue
    elif isinstance(value, variablelib.Variable):
      var_type = type(value)
      # All RngState subclasses are reported under a single RngState column.
      if issubclass(var_type, nnx.RngState):
        var_type = nnx.RngState
      size_bytes = SizeBytes.from_any(value.get_value())
      if var_type in stats:
        stats[var_type] += size_bytes
      else:
        stats[var_type] = size_bytes
      variable_groups[var_type][key] = value
      node_stats[id(value)] = None
    elif graphlib.is_node(value):
      _collect_stats((*path, key), value, node_stats, object_types)
      # accumulate stats from children
      child_info = node_stats[id(value)]
      assert child_info is not None
      for var_type, size_bytes in child_info.stats.items():
        if var_type in stats:
          stats[var_type] += size_bytes
        else:
          stats[var_type] = size_bytes

@dataclasses.dataclass(frozen=True, repr=False)
class ArrayRepr:
  shape: tuple[int, ...]
  dtype: tp.Any

  @classmethod
  def from_array(cls, x: jax.Array | np.ndarray):
    return cls(jnp.shape(x), jnp.result_type(x))

  def __str__(self):
    shape_repr = ','.join(str(x) for x in self.shape)
    return f'[dim]{self.dtype}[/dim][{shape_repr}]'


@dataclasses.dataclass
class CallInfo:
  """Recorded information about one traced method call: one table row."""

  # Position in the pre-order traversal (taken from a shared itertools.count).
  call_order: int
  # id() of the object the method was called on (its `_nnx_tabulate_id`).
  object_id: int
  # Type of that object.
  type: type
  # Path of the object within the graph (plus the method name if not __call__).
  path: statelib.PathParts
  # Rendered representation of the call arguments.
  inputs_repr: str
  # Call outputs, mapped through `_to_dummy_array`.
  outputs: tp.Any
  # Estimated forward-pass FLOPs (None when compute_flops is off).
  flops: int | None
  # Estimated backward-pass FLOPs (None when compute_vjp_flops is off).
  vjp_flops: int | None

class SimpleObjectRepr:
  """Placeholder repr for graph-node values: renders as ``TypeName(...)``."""

  def __init__(self, obj: tp.Any):
    self.type = type(obj)

  def __str__(self):
    return f'{self.type.__name__}(...)'

  # repr and str are intentionally identical.
  __repr__ = __str__


def _to_dummy_array(x):
  """Replace arrays and graph nodes with lightweight display-only stand-ins."""
  if isinstance(x, jax.ShapeDtypeStruct):
    return ArrayRepr(x.shape, x.dtype)
  if isinstance(x, (jax.Array, np.ndarray)):
    return ArrayRepr.from_array(x)
  if graphlib.is_graph_node(x):
    return SimpleObjectRepr(x)
  return x

def _pure_nnx_vjp(f, model, *args, **kwargs):
  """Wrap the nnx functional API around jax.vjp. Only handles pure method calls."""
  graphdef, state = nnx.split(model)

  def pure_fn(state, *args, **kwargs):
    # Rebuild the model inside the trace so `f` is called on a pure function.
    return f(nnx.merge(graphdef, state), *args, **kwargs)

  return jax.vjp(pure_fn, state, *args, **kwargs)

def filter_rng_streams(row: CallInfo):
  """Default row filter: drop rows whose object is an ``nnx.RngStream``."""
  is_rng_stream = issubclass(row.type, nnx.RngStream)
  return not is_rng_stream

def _create_obj_env(object_types):
  "Turn a set of object types into a dictionary mapping (type, method name) pairs to methods"
  result = {}
  for obj_type in object_types:
    for name, top_method in inspect.getmembers(obj_type, inspect.isfunction):
      if not name.startswith('_') or name == '__call__':
        result[(obj_type, name)] = top_method
  return result

def _get_inputs_repr(args, kwargs):
  """Render positional and keyword call arguments as a YAML-style string."""
  dummy_args, dummy_kwargs = jax.tree.map(_to_dummy_array, (args, kwargs))
  parts = []
  if dummy_args:
    # A single positional argument with no kwargs is shown unwrapped.
    if len(dummy_args) == 1 and not dummy_kwargs:
      parts.append(_as_yaml_str(dummy_args[0]))
    else:
      parts.append(_as_yaml_str(dummy_args))
  if dummy_kwargs:
    parts.append(_as_yaml_str(dummy_kwargs))
  return '\n'.join(parts)

def _save_call_info(counter, tracer_args, f, node_stats, compute_flops, compute_vjp_flops, seen):
  """Wrap method ``f`` so each call appends a ``CallInfo`` row to ``tracer_args``.

  The wrapper jit-traces the call (to obtain outputs and, optionally, FLOPs
  estimates) and records one row per unique (inputs_repr, object_id) pair;
  ``seen`` de-duplicates repeated identical calls. ``counter`` is a shared
  ``itertools.count`` assigning a global call order, and ``node_stats`` maps
  each tagged object id (``_nnx_tabulate_id``) to its ``ObjectInfo``.
  """

  # Used when computing vjp flops
  def do_vjp(*args, **kwargs):
    primals, f_vjp = jax.vjp(f, *args, **kwargs)
    return f_vjp(primals)

  method_name = f.__name__

  @functools.partial(jax.jit)
  def jit_f(graphdef, state):
    # Re-merge the object and arguments inside the trace so the call is pure.
    args, kwargs = nnx.merge(graphdef, state)
    return f(*args, **kwargs)

  @wraps(f)
  def wrapper(obj, *args, **kwargs):
    inputs_repr = _get_inputs_repr(args, kwargs)
    object_id = getattr(obj, '_nnx_tabulate_id')  # set by _collect_stats
    node_info = node_stats[object_id]
    path = node_info.path
    if method_name != '__call__':
      # Non-__call__ methods appear as an extra path segment in the table.
      path = (*path, method_name)
    identifier = (inputs_repr, object_id)
    counter_val = next(counter)
    graphdef, state = nnx.split(((obj, *args), kwargs))
    if compute_flops:
      # Lowering (without executing) is enough to read the cost analysis and
      # the abstract output info.
      lowered = jit_f.lower(graphdef, state)
      flops = _get_flops(lowered)
      outputs = lowered.out_info
    else:
      flops = None
      outputs = jit_f(graphdef, state)
    if identifier not in seen:
      seen.add(identifier)
      output_repr = jax.tree.map(_to_dummy_array, outputs)
      vjp_flops = _get_flops(jax.jit(do_vjp).lower(
        obj, *args, **kwargs)) if compute_vjp_flops else None
      tracer_args.append(
        CallInfo(counter_val, object_id, type(obj), path, inputs_repr,
          output_repr, flops, vjp_flops))
    return jit_f(graphdef, state)
  return wrapper

def _overwrite_methods(env):
  "Overwrite methods with functions from an environment"
  for (obj_type, name), f in env.items():
    setattr(obj_type, name, f)

def _get_flops(e) -> int:
  cost = e.cost_analysis() or e.compile().cost_analysis()
  return 0 if cost is None or 'flops' not in cost else int(cost['flops'])

def tabulate(
  obj,
  *input_args,
  depth: int | None = None,
  method: str = '__call__',
  row_filter: tp.Callable[[CallInfo], bool] = filter_rng_streams,
  table_kwargs: tp.Mapping[str, tp.Any] = MappingProxyType({}),
  column_kwargs: tp.Mapping[str, tp.Any] = MappingProxyType({}),
  console_kwargs: tp.Mapping[str, tp.Any] = MappingProxyType({}),
  compute_flops: bool = False,
  compute_vjp_flops: bool = False,
  **input_kwargs,
) -> str:
  """Creates a summary of the graph object represented as a table.

  The table summarizes the object's state and metadata. The table is
  structured as follows:

  - The first column represents the path of the object in the graph.
  - The second column represents the type of the object.
  - The third column represents the input arguments passed to the object's
    method.
  - The fourth column represents the output of the object's method.
  - The following columns provide information about the object's state,
    grouped by Variable types.

  Example::

    >>> from flax import nnx
    ...
    >>> class Block(nnx.Module):
    ...   def __init__(self, din, dout, rngs: nnx.Rngs):
    ...     self.linear = nnx.Linear(din, dout, rngs=rngs)
    ...     self.bn = nnx.BatchNorm(dout, rngs=rngs)
    ...     self.dropout = nnx.Dropout(0.2, rngs=rngs)
    ...
    ...   def __call__(self, x):
    ...     return nnx.relu(self.dropout(self.bn(self.linear(x))))
    ...
    >>> class Foo(nnx.Module):
    ...   def __init__(self, rngs: nnx.Rngs):
    ...     self.block1 = Block(32, 128, rngs=rngs)
    ...     self.block2 = Block(128, 10, rngs=rngs)
    ...
    ...   def __call__(self, x):
    ...     return self.block2(self.block1(x))
    ...
    >>> foo = Foo(nnx.Rngs(0))
    >>> # print(nnx.tabulate(foo, jnp.ones((1, 32))))

  This prints a rich table titled ``Foo Summary`` with one row per
  (sub)module call, one size column per Variable type (e.g. ``BatchStat``,
  ``Param``, ``RngState``), a ``Total`` footer row, and a
  ``Total Parameters`` caption. Note that rows that share state with an
  earlier row (e.g. a Dropout reusing the same ``RngState``) are not
  repeated in the table.

  Args:
    obj: A object to summarize. It can a pytree or a graph objects such as
      nnx.Module or nnx.Optimizer.
    *input_args: Positional arguments passed to the object's method.
    **input_kwargs: Keyword arguments passed to the object's method.
    depth: The depth of the table.
    method: The method to call on the object. Default is ``'__call__'``.
    row_filter: A callable that filters the rows to be displayed in the
      table. By default, it filters out rows with type ``nnx.RngStream``.
    table_kwargs: An optional dictionary with additional keyword arguments
      that are passed to ``rich.table.Table`` constructor.
    column_kwargs: An optional dictionary with additional keyword arguments
      that are passed to ``rich.table.Table.add_column`` when adding columns
      to the table.
    console_kwargs: An optional dictionary with additional keyword arguments
      that are passed to ``rich.console.Console`` when rendering the table.
      Default arguments are ``'force_terminal': True``, and
      ``'force_jupyter'`` is set to ``True`` if the code is running in a
      Jupyter notebook, otherwise it is set to ``False``.
    compute_flops: whether to include a ``flops`` column in the table listing
      the estimated FLOPs cost of each module forward pass. Does incur actual
      on-device computation / compilation / memory allocation, but still
      introduces overhead for large modules (e.g. extra 20 seconds for a
      Stable Diffusion's UNet, whereas otherwise tabulation would finish in
      5 seconds).
    compute_vjp_flops: whether to include a ``vjp_flops`` column in the table
      listing the estimated FLOPs cost of each module backward pass.
      Introduces a compute overhead of about 2-3X of ``compute_flops``.

  Returns:
    A string summarizing the object.
  """
  _console_kwargs = {'force_terminal': True, 'force_jupyter': in_ipython}
  _console_kwargs.update(console_kwargs)

  obj = graphlib.clone(obj)  # create copy to avoid side effects
  node_stats: NodeStats = {}
  object_types: set[type] = set()
  _collect_stats((), obj, node_stats, object_types)

  _variable_types: set[type] = {
    nnx.RngState  # type: ignore[misc]
    if isinstance(leaf, nnx.RngState)
    else type(leaf)
    for _, leaf in nnx.to_flat_state(nnx.state(obj))
  }
  variable_types: list[type] = sorted(_variable_types, key=lambda t: t.__name__)

  # Create a dictionary-version of the object's class. This makes
  # iteration over methods easier.
  env = _create_obj_env(object_types)
  # Information is recorded in post-order, but should be presented as a
  # pre-order traversal. This keeps track of the order of calls.
  counter = itertools.count(0)
  # Modify all the object's methods to save their lowered JIT representations.
  rows: list[CallInfo] = []
  seen: set = set()
  jits = {
    k: _save_call_info(
      counter, rows, v, node_stats, compute_flops, compute_vjp_flops, seen
    )
    for k, v in env.items()
  }
  _overwrite_methods(jits)
  # Trace the top function (which indirectly traces all the others)
  jits[(type(obj), method)](obj, *input_args, **input_kwargs)
  # Sort call info in pre-order traversal order
  rows.sort(key=lambda x: x.call_order)
  # Restore the object's original methods
  _overwrite_methods(env)

  if depth is not None:
    rows = [row for row in rows if len(row.path) <= depth and row_filter(row)]
  else:
    rows = [row for row in rows if row_filter(row)]

  rich_table = rich.table.Table(
    show_header=True,
    show_lines=True,
    show_footer=True,
    title=f'{type(obj).__name__} Summary',
    **table_kwargs,
  )
  rich_table.add_column('path', **column_kwargs)
  rich_table.add_column('type', **column_kwargs)
  rich_table.add_column('inputs', **column_kwargs)
  rich_table.add_column('outputs', **column_kwargs)
  if compute_flops:
    rich_table.add_column('flops', **column_kwargs)
  if compute_vjp_flops:
    rich_table.add_column('vjp_flops', **column_kwargs)
  for var_type in variable_types:
    rich_table.add_column(var_type.__name__, **column_kwargs)

  for row in rows:
    node_info = node_stats[row.object_id]
    assert node_info is not None
    col_reprs: list[str] = []
    path_str = '/'.join(map(str, row.path))
    col_reprs.append(path_str)
    col_reprs.append(row.type.__name__)
    col_reprs.append(row.inputs_repr)
    col_reprs.append(_as_yaml_str(row.outputs))
    if compute_flops:
      col_reprs.append(str(row.flops))
    if compute_vjp_flops:
      col_reprs.append(str(row.vjp_flops))
    for var_type in variable_types:
      attributes = {}
      variable: variablelib.Variable
      for name, variable in node_info.variable_groups[var_type].items():
        value = variable.get_value()
        value_repr = _render_array(value) if _has_shape_dtype(value) else ''
        metadata = variable.get_metadata()
        # Required metadata is structural, not informative: hide it.
        for required_key in var_type.required_metadata:
          metadata.pop(required_key, None)
        if metadata:
          attributes[name] = {
            'value': value_repr,
            **metadata,
          }
        elif value_repr:
          attributes[name] = value_repr  # type: ignore[assignment]
      if attributes:
        col_repr = _as_yaml_str(attributes) + '\n\n'
      else:
        col_repr = ''
      size_bytes = node_info.stats.get(var_type)  # type: ignore[call-overload]
      if size_bytes:
        col_repr += f'[bold]{size_bytes}[/bold]'
      col_reprs.append(col_repr)
    rich_table.add_row(*col_reprs)

  # Footer: a 'Total' label in the last non-stats column, then the top-level
  # object's aggregated size for each Variable-type column.
  total_offset = 3 + int(compute_flops) + int(compute_vjp_flops)
  rich_table.columns[total_offset].footer = rich.text.Text.from_markup(
    'Total', justify='right'
  )
  node_info = node_stats[id(obj)]
  assert node_info is not None
  for i, var_type in enumerate(variable_types):
    size_bytes = node_info.stats[var_type]
    rich_table.columns[i + total_offset + 1].footer = str(size_bytes)
  rich_table.caption_style = 'bold'
  total_size = sum(node_info.stats.values(), SizeBytes(0, 0))
  rich_table.caption = f'\nTotal Parameters: {total_size}'

  return _get_rich_repr(rich_table, _console_kwargs)
def _get_rich_repr(obj, console_kwargs): f = io.StringIO() console = rich.console.Console(file=f, **console_kwargs) console.print(obj) return f.getvalue() def _size_and_bytes(pytree: tp.Any) -> tuple[int, int]: leaves = jax.tree.leaves(pytree) size = sum(x.size for x in leaves if hasattr(x, 'size')) num_bytes = sum( x.size * x.dtype.itemsize for x in leaves if hasattr(x, 'size') ) return size, num_bytes def _size_and_bytes_repr(size: int, num_bytes: int) -> str: if not size: return '' bytes_repr = _bytes_repr(num_bytes) return f'{size:,} [dim]({bytes_repr})[/dim]' def _bytes_repr(num_bytes): count, units = ( (f'{num_bytes / 1e9:,.1f}', 'GB') if num_bytes > 1e9 else (f'{num_bytes / 1e6:,.1f}', 'MB') if num_bytes > 1e6 else (f'{num_bytes / 1e3:,.1f}', 'KB') if num_bytes > 1e3 else (f'{num_bytes:,}', 'B') ) return f'{count} {units}' def _has_shape_dtype(value): return hasattr(value, 'shape') and hasattr(value, 'dtype') def _normalize_values(x): if isinstance(x, type): return f'type[{x.__name__}]' elif isinstance(x, ArrayRepr | SimpleObjectRepr): return str(x) else: return repr(x) def _maybe_pytree_to_dict(pytree: tp.Any): path_leaves = jax.tree_util.tree_flatten_with_path(pytree)[0] path_leaves = [ (tuple(map(graphlib._key_path_to_key, path)), value) for path, value in path_leaves ] if len(path_leaves) < 1: return pytree elif len(path_leaves) == 1 and path_leaves[0][0] == (): return pytree else: return _unflatten_to_simple_structure(path_leaves, original=pytree) def _unflatten_to_simple_structure( xs: list[tuple[tuple[tp.Any, ...], tp.Any]], *, original: tp.Any ): """Rebuild a simple Python structure from path/value leaves. This variant is aware of the original object so it can: - Preserve empty containers that were elided by JAX flattening. - Pad trailing missing list/tuple items using the original length. - Distinguish placeholders for empty dict/list vs None. 
""" def _get_by_path(x, path: tuple[tp.Any, ...]): cur = x for k in path: cur = cur[k] return cur def _to_simple(x): # Convert to display-friendly simple structures if isinstance(x, (list, tuple)): return [_to_simple(e) for e in x] if isinstance(x, dict): return {k: _to_simple(v) for k, v in x.items()} return x result: list | dict = ( [] if len(xs) > 0 and isinstance(xs[0][0][0], int) else {} ) for path, value in xs: cursor = result for i, key in enumerate(path[:-1]): if isinstance(cursor, list): # Ensure list has slot for current key; infer placeholder from original while len(cursor) <= key: # path to the slot we are about to create slot_path = path[:i] + (len(cursor),) try: orig_slot = _get_by_path(original, slot_path) except Exception: orig_slot = None if isinstance(orig_slot, (list, tuple)): cursor.append([]) elif isinstance(orig_slot, dict): cursor.append({}) else: cursor.append(None) else: if key not in cursor: next_key = path[i + 1] if isinstance(next_key, int): cursor[key] = [] else: cursor[key] = {} cursor = cursor[key] if isinstance(cursor, list): # Handle gaps in indices caused by JAX flattening eliding empty containers while len(cursor) <= path[-1]: slot_path = path[:-1] + (len(cursor),) try: orig_slot = _get_by_path(original, slot_path) except Exception: orig_slot = None if isinstance(orig_slot, (list, tuple)): cursor.append([]) elif isinstance(orig_slot, dict): cursor.append({}) else: cursor.append(None) cursor[path[-1]] = value else: assert isinstance(cursor, dict) cursor[path[-1]] = value # If original is a sequence and result is a list, pad trailing items if isinstance(original, (list, tuple)) and isinstance(result, list): for i in range(len(result), len(original)): slot = original[i] result.append(_to_simple(slot)) return result def _as_yaml_str(value) -> str: if (hasattr(value, '__len__') and len(value) == 0) or value is None: return '' value = jax.tree.map(_normalize_values, value) value = _maybe_pytree_to_dict(value) file = io.StringIO() 
yaml.dump( value, file, Dumper=NoneDumper, default_flow_style=False, indent=2, sort_keys=False, explicit_end=False, ) return file.getvalue().replace('\n...', '').replace("'", '').strip() def _render_array(x): shape, dtype = jnp.shape(x), jnp.result_type(x) shape_repr = ','.join(str(x) for x in shape) return f'[dim]{dtype}[/dim][{shape_repr}]' def _sort_variable_types(types: tp.Iterable[type]) -> list[type]: def _variable_parents_count(t: type): return sum(1 for p in t.mro() if issubclass(p, variablelib.Variable)) type_sort_key = {t: (-_variable_parents_count(t), t.__name__) for t in types} return sorted(types, key=lambda t: type_sort_key[t])