Skip to content

Commit

Permalink
Add direct penzai.treescope support for NNX objects.
Browse files Browse the repository at this point in the history
This change implements the `__penzai_repr__` protocol on most NNX
objects, making it possible to directly visualize them using the
standard `penzai.treescope` configuration without an extra conversion
step. Modules, GraphDefs, and States are all visualizable.

The `nnx.display` function is no longer needed if Penzai is installed,
since `pz.ts.basic_interactive_setup()` followed by
`IPython.display.display` or `pz.show` (or just returning an object
from an IPython cell) is now sufficient to visualize NNX objects.

Also fixes GraphDef __repr__ to use "leaves" instead of "variables".

PiperOrigin-RevId: 637319547
  • Loading branch information
danieldjohnson authored and Flax Authors committed May 29, 2024
1 parent 2cfd174 commit 4c20880
Show file tree
Hide file tree
Showing 9 changed files with 132 additions and 91 deletions.
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ ipython_genutils
sphinx-design
jupytext==1.13.8
dm-haiku
penzai; python_version>='3.10'
penzai>=0.1.2; python_version>='3.10'

# Need to pin docutils to 0.16 to make bulleted lists appear correctly on
# ReadTheDocs: https://stackoverflow.com/a/68008428
Expand Down
31 changes: 30 additions & 1 deletion flax/nnx/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,9 +293,26 @@ def __nnx_repr__(self):
yield reprlib.Attr(
'static_fields', reprlib.PrettyMapping(self.static_fields)
)
yield reprlib.Attr('variables', reprlib.PrettyMapping(self.leaves))
yield reprlib.Attr('leaves', reprlib.PrettyMapping(self.leaves))
yield reprlib.Attr('metadata', self.metadata)

def __penzai_repr__(self, path, subtree_renderer):
from penzai.treescope import repr_lib as pz_repr_lib # type: ignore[import-not-found,import-untyped]
return pz_repr_lib.render_object_constructor(
object_type=type(self),
attributes={
'type': self.type,
'index': self.index,
'attributes': self.attributes,
'subgraphs': dict(self.subgraphs),
'static_fields': dict(self.static_fields),
'leaves': dict(self.leaves),
'metadata': self.metadata,
},
path=path,
subtree_renderer=subtree_renderer,
)


@dataclasses.dataclass(frozen=True, repr=False)
class GraphDef(tp.Generic[Node], reprlib.Representable):
Expand All @@ -308,6 +325,18 @@ def __nnx_repr__(self):
yield reprlib.Attr('nodedef', self.nodedef)
yield reprlib.Attr('index_mapping', self.index_mapping)

def __penzai_repr__(self, path, subtree_renderer):
from penzai.treescope import repr_lib as pz_repr_lib # type: ignore[import-not-found,import-untyped]
return pz_repr_lib.render_object_constructor(
object_type=type(self),
attributes={
'nodedef': self.nodedef,
'index_mapping': self.index_mapping,
},
path=path,
subtree_renderer=subtree_renderer,
)

def __deepcopy__(self, memo=None):
nodedef = deepcopy(self.nodedef, memo)
index_mapping = deepcopy(self.index_mapping, memo)
Expand Down
15 changes: 15 additions & 0 deletions flax/nnx/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,21 @@ def __init_subclass__(cls, experimental_pytree: bool = False) -> None:
flatten_func=partial(_module_flatten, with_keys=False),
)

def __penzai_repr__(self, path, subtree_renderer):
from penzai.treescope import repr_lib as pz_repr_lib # type: ignore[import-not-found,import-untyped]
from penzai.treescope import formatting_util # type: ignore[import-not-found,import-untyped]
children = {}
for name, value in vars(self).items():
if name.startswith('_'):
continue
children[name] = value
return pz_repr_lib.render_object_constructor(
object_type=type(self),
attributes=children,
path=path,
subtree_renderer=subtree_renderer,
color=formatting_util.color_from_string(type(self).__qualname__)
)

# -------------------------
# Pytree Definition
Expand Down
22 changes: 22 additions & 0 deletions flax/nnx/nnx/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@ def __nnx_repr__(self):
yield reprlib.Object(type(self))
yield reprlib.Attr('trace_state', self._trace_state)

def __penzai_repr__(self, path, subtree_renderer):
from penzai.treescope import repr_lib as pz_repr_lib # type: ignore[import-not-found,import-untyped]
return pz_repr_lib.render_object_constructor(
object_type=type(self),
attributes={'trace_state': self._trace_state},
path=path,
subtree_renderer=subtree_renderer,
)

class ObjectMeta(ABCMeta):
if not tp.TYPE_CHECKING:
Expand Down Expand Up @@ -162,6 +170,20 @@ def to_shape_dtype(value):
if clear_seen:
CONTEXT.seen_modules_repr = None

def __penzai_repr__(self, path, subtree_renderer):
from penzai.treescope import repr_lib as pz_repr_lib # type: ignore[import-not-found,import-untyped]
children = {}
for name, value in vars(self).items():
if name.startswith('_'):
continue
children[name] = value
return pz_repr_lib.render_object_constructor(
object_type=type(self),
attributes=children,
path=path,
subtree_renderer=subtree_renderer,
)

# Graph Definition
def _graph_node_flatten(self):
nodes = sorted(
Expand Down
23 changes: 23 additions & 0 deletions flax/nnx/nnx/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ def __nnx_repr__(self):
continue
yield r

def __penzai_repr__(self, path, subtree_renderer):
children = {}
for k, v in self.state.items():
if isinstance(v, State):
v = NestedStateRepr(v)
children[k] = v
# Render as the dictionary itself at the same path.
return subtree_renderer(children, path=path)

class State(tp.MutableMapping[Key, tp.Any], reprlib.Representable):
def __init__(
Expand Down Expand Up @@ -131,6 +139,21 @@ def __nnx_repr__(self):
v = NestedStateRepr(v)
yield reprlib.Attr(repr(k), v)

def __penzai_repr__(self, path, subtree_renderer):
from penzai.treescope import repr_lib as pz_repr_lib # type: ignore[import-not-found,import-untyped]

children = {}
for k, v in self.items():
if isinstance(v, State):
v = NestedStateRepr(v)
children[k] = v
return pz_repr_lib.render_dictionary_wrapper(
object_type=type(self),
wrapped_dict=children,
path=path,
subtree_renderer=subtree_renderer,
)

def flat_state(self) -> FlatState:
return traverse_util.flatten_dict(self._mapping) # type: ignore

Expand Down
9 changes: 9 additions & 0 deletions flax/nnx/nnx/tracers.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,14 @@ def __nnx_repr__(self):
yield reprlib.Object(f'{type(self).__name__}')
yield reprlib.Attr('jax_trace', self._jax_trace)

def __penzai_repr__(self, path, subtree_renderer):
from penzai.treescope import repr_lib as pz_repr_lib # type: ignore[import-not-found,import-untyped]
return pz_repr_lib.render_object_constructor(
object_type=type(self),
attributes={'jax_trace': self._jax_trace},
path=path,
subtree_renderer=subtree_renderer,
)

def __eq__(self, other):
return isinstance(other, TraceState) and self._jax_trace is other._jax_trace
30 changes: 30 additions & 0 deletions flax/nnx/nnx/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,22 @@ def __nnx_repr__(self):
continue
yield reprlib.Attr(name, repr(value))

def __penzai_repr__(self, path, subtree_renderer):
from penzai.treescope import repr_lib as pz_repr_lib # type: ignore[import-not-found,import-untyped]
children = {}
for name, value in vars(self).items():
if name == 'raw_value':
name = 'value'
if name.endswith('_hooks') or name == '_trace_state':
continue
children[name] = value
return pz_repr_lib.render_object_constructor(
object_type=type(self),
attributes=children,
path=path,
subtree_renderer=subtree_renderer,
)

# hooks API
if tp.TYPE_CHECKING:

Expand Down Expand Up @@ -592,6 +608,20 @@ def __nnx_repr__(self):
continue
yield reprlib.Attr(name, repr(value))

def __penzai_repr__(self, path, subtree_renderer):
from penzai.treescope import repr_lib as pz_repr_lib # type: ignore[import-not-found,import-untyped]
children = {'type': self.type}
for name, value in vars(self).items():
if name == 'type' or name.endswith('_hooks'):
continue
children[name] = value
return pz_repr_lib.render_object_constructor(
object_type=type(self),
attributes=children,
path=path,
subtree_renderer=subtree_renderer,
)

def replace(self, value: B) -> 'VariableState[B]':
return VariableState(self.type, value, **self.get_metadata())

Expand Down
89 changes: 1 addition & 88 deletions flax/nnx/nnx/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import importlib.util
import typing as tp

import jax

from flax import nnx

penzai_installed = importlib.util.find_spec('penzai') is not None
try:
Expand All @@ -44,85 +38,4 @@ def display(*args):

with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()):
for x in args:
value = to_dataclass(x)
pz.ts.display(value, ignore_exceptions=True)


def to_dataclass(node):
seen_nodes = set()
return _treemap_to_dataclass(node, seen_nodes)


def _to_dataclass(x, seen_nodes: set[int]):
if nnx.graph.is_graph_node(x):
if id(x) in seen_nodes:
dc_type = _make_dataclass_obj(
type(x),
{'repeated': True},
)
return dc_type
seen_nodes.add(id(x))
node_impl = nnx.graph.get_node_impl(x)
node_dict = node_impl.node_dict(x)
node_dict = {
str(key): _treemap_to_dataclass(value, seen_nodes)
for key, value in node_dict.items()
}
dc_type = _make_dataclass_obj(
type(x),
{str(key): value for key, value in node_dict.items()},
)
return dc_type
elif isinstance(x, (nnx.Variable, nnx.VariableState)):
obj_vars = vars(x).copy()
if 'raw_value' in obj_vars:
obj_vars['value'] = obj_vars.pop('raw_value')
if '_trace_state' in obj_vars:
del obj_vars['_trace_state']
for name in list(obj_vars):
if name.endswith('_hooks'):
del obj_vars[name]
obj_vars = {
key: _treemap_to_dataclass(value, seen_nodes)
for key, value in obj_vars.items()
}
dc_type = _make_dataclass_obj(
type(x),
obj_vars,
penzai_dataclass=not isinstance(x, nnx.VariableState),
)
return dc_type
elif isinstance(x, nnx.State):
return _treemap_to_dataclass(x._mapping, seen_nodes)
return x


def _treemap_to_dataclass(node, seen_nodes: set[int]):
def _to_dataclass_fn(x):
return _to_dataclass(x, seen_nodes)

return jax.tree.map(
_to_dataclass_fn,
node,
is_leaf=lambda x: isinstance(x, (nnx.VariableState, nnx.State)),
)


def _make_dataclass_obj(
cls, fields: tp.Mapping[str, tp.Any], penzai_dataclass: bool = True
) -> tp.Type:
from penzai import pz # type: ignore[import-error]

dataclass = pz.pytree_dataclass if penzai_dataclass else dataclasses.dataclass
base = pz.Layer if penzai_dataclass else object

attributes = {
'__annotations__': {key: type(value) for key, value in fields.items()},
}

if hasattr(cls, '__call__'):
attributes['__call__'] = cls.__call__

dc_type = type(cls.__name__, (base,), attributes)
dc_type = dataclass(dc_type)
return dc_type(**fields)
pz.ts.display(x, ignore_exceptions=True)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ testing = [
"nbstripout",
"black[jupyter]==23.7.0",
# "pyink==23.5.0", # disabling pyink fow now
"penzai; python_version>='3.10'",
"penzai>=0.1.2; python_version>='3.10'",
]

[project.urls]
Expand Down

0 comments on commit 4c20880

Please sign in to comment.