Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add direct penzai.treescope support for NNX objects. #3948

Merged
1 commit merged into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ ipython_genutils
sphinx-design
jupytext==1.13.8
dm-haiku
penzai; python_version>='3.10'
penzai>=0.1.2; python_version>='3.10'

# Need to pin docutils to 0.16 to make bulleted lists appear correctly on
# ReadTheDocs: https://stackoverflow.com/a/68008428
Expand Down
31 changes: 30 additions & 1 deletion flax/nnx/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,9 +293,26 @@ def __nnx_repr__(self):
yield reprlib.Attr(
'static_fields', reprlib.PrettyMapping(self.static_fields)
)
yield reprlib.Attr('variables', reprlib.PrettyMapping(self.leaves))
yield reprlib.Attr('leaves', reprlib.PrettyMapping(self.leaves))
yield reprlib.Attr('metadata', self.metadata)

def __penzai_repr__(self, path, subtree_renderer):
from penzai.treescope import repr_lib as pz_repr_lib # type: ignore[import-not-found,import-untyped]
return pz_repr_lib.render_object_constructor(
object_type=type(self),
attributes={
'type': self.type,
'index': self.index,
'attributes': self.attributes,
'subgraphs': dict(self.subgraphs),
'static_fields': dict(self.static_fields),
'leaves': dict(self.leaves),
'metadata': self.metadata,
},
path=path,
subtree_renderer=subtree_renderer,
)


@dataclasses.dataclass(frozen=True, repr=False)
class GraphDef(tp.Generic[Node], reprlib.Representable):
Expand All @@ -308,6 +325,18 @@ def __nnx_repr__(self):
yield reprlib.Attr('nodedef', self.nodedef)
yield reprlib.Attr('index_mapping', self.index_mapping)

def __penzai_repr__(self, path, subtree_renderer):
from penzai.treescope import repr_lib as pz_repr_lib # type: ignore[import-not-found,import-untyped]
return pz_repr_lib.render_object_constructor(
object_type=type(self),
attributes={
'nodedef': self.nodedef,
'index_mapping': self.index_mapping,
},
path=path,
subtree_renderer=subtree_renderer,
)

def __deepcopy__(self, memo=None):
nodedef = deepcopy(self.nodedef, memo)
index_mapping = deepcopy(self.index_mapping, memo)
Expand Down
15 changes: 15 additions & 0 deletions flax/nnx/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,21 @@ def __init_subclass__(cls, experimental_pytree: bool = False) -> None:
flatten_func=partial(_module_flatten, with_keys=False),
)

def __penzai_repr__(self, path, subtree_renderer):
from penzai.treescope import repr_lib as pz_repr_lib # type: ignore[import-not-found,import-untyped]
from penzai.treescope import formatting_util # type: ignore[import-not-found,import-untyped]
children = {}
for name, value in vars(self).items():
if name.startswith('_'):
continue
children[name] = value
return pz_repr_lib.render_object_constructor(
object_type=type(self),
attributes=children,
path=path,
subtree_renderer=subtree_renderer,
color=formatting_util.color_from_string(type(self).__qualname__)
)

# -------------------------
# Pytree Definition
Expand Down
22 changes: 22 additions & 0 deletions flax/nnx/nnx/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@ def __nnx_repr__(self):
yield reprlib.Object(type(self))
yield reprlib.Attr('trace_state', self._trace_state)

def __penzai_repr__(self, path, subtree_renderer):
from penzai.treescope import repr_lib as pz_repr_lib # type: ignore[import-not-found,import-untyped]
return pz_repr_lib.render_object_constructor(
object_type=type(self),
attributes={'trace_state': self._trace_state},
path=path,
subtree_renderer=subtree_renderer,
)

class ObjectMeta(ABCMeta):
if not tp.TYPE_CHECKING:
Expand Down Expand Up @@ -162,6 +170,20 @@ def to_shape_dtype(value):
if clear_seen:
CONTEXT.seen_modules_repr = None

def __penzai_repr__(self, path, subtree_renderer):
from penzai.treescope import repr_lib as pz_repr_lib # type: ignore[import-not-found,import-untyped]
children = {}
for name, value in vars(self).items():
if name.startswith('_'):
continue
children[name] = value
return pz_repr_lib.render_object_constructor(
object_type=type(self),
attributes=children,
path=path,
subtree_renderer=subtree_renderer,
)

# Graph Definition
def _graph_node_flatten(self):
nodes = sorted(
Expand Down
23 changes: 23 additions & 0 deletions flax/nnx/nnx/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,14 @@ def __nnx_repr__(self):
continue
yield r

def __penzai_repr__(self, path, subtree_renderer):
children = {}
for k, v in self.state.items():
if isinstance(v, State):
v = NestedStateRepr(v)
children[k] = v
# Render as the dictionary itself at the same path.
return subtree_renderer(children, path=path)

class State(tp.MutableMapping[Key, tp.Any], reprlib.Representable):
def __init__(
Expand Down Expand Up @@ -131,6 +139,21 @@ def __nnx_repr__(self):
v = NestedStateRepr(v)
yield reprlib.Attr(repr(k), v)

def __penzai_repr__(self, path, subtree_renderer):
from penzai.treescope import repr_lib as pz_repr_lib # type: ignore[import-not-found,import-untyped]

children = {}
for k, v in self.items():
if isinstance(v, State):
v = NestedStateRepr(v)
children[k] = v
return pz_repr_lib.render_dictionary_wrapper(
object_type=type(self),
wrapped_dict=children,
path=path,
subtree_renderer=subtree_renderer,
)

def flat_state(self) -> FlatState:
return traverse_util.flatten_dict(self._mapping) # type: ignore

Expand Down
9 changes: 9 additions & 0 deletions flax/nnx/nnx/tracers.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,5 +58,14 @@ def __nnx_repr__(self):
yield reprlib.Object(f'{type(self).__name__}')
yield reprlib.Attr('jax_trace', self._jax_trace)

def __penzai_repr__(self, path, subtree_renderer):
from penzai.treescope import repr_lib as pz_repr_lib # type: ignore[import-not-found,import-untyped]
return pz_repr_lib.render_object_constructor(
object_type=type(self),
attributes={'jax_trace': self._jax_trace},
path=path,
subtree_renderer=subtree_renderer,
)

def __eq__(self, other):
return isinstance(other, TraceState) and self._jax_trace is other._jax_trace
30 changes: 30 additions & 0 deletions flax/nnx/nnx/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,22 @@ def __nnx_repr__(self):
continue
yield reprlib.Attr(name, repr(value))

def __penzai_repr__(self, path, subtree_renderer):
from penzai.treescope import repr_lib as pz_repr_lib # type: ignore[import-not-found,import-untyped]
children = {}
for name, value in vars(self).items():
if name == 'raw_value':
name = 'value'
if name.endswith('_hooks') or name == '_trace_state':
continue
children[name] = value
return pz_repr_lib.render_object_constructor(
object_type=type(self),
attributes=children,
path=path,
subtree_renderer=subtree_renderer,
)

# hooks API
if tp.TYPE_CHECKING:

Expand Down Expand Up @@ -592,6 +608,20 @@ def __nnx_repr__(self):
continue
yield reprlib.Attr(name, repr(value))

def __penzai_repr__(self, path, subtree_renderer):
from penzai.treescope import repr_lib as pz_repr_lib # type: ignore[import-not-found,import-untyped]
children = {'type': self.type}
for name, value in vars(self).items():
if name == 'type' or name.endswith('_hooks'):
continue
children[name] = value
return pz_repr_lib.render_object_constructor(
object_type=type(self),
attributes=children,
path=path,
subtree_renderer=subtree_renderer,
)

def replace(self, value: B) -> 'VariableState[B]':
return VariableState(self.type, value, **self.get_metadata())

Expand Down
89 changes: 1 addition & 88 deletions flax/nnx/nnx/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import importlib.util
import typing as tp

import jax

from flax import nnx

penzai_installed = importlib.util.find_spec('penzai') is not None
try:
Expand All @@ -44,85 +38,4 @@ def display(*args):

with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()):
for x in args:
value = to_dataclass(x)
pz.ts.display(value, ignore_exceptions=True)


def to_dataclass(node):
seen_nodes = set()
return _treemap_to_dataclass(node, seen_nodes)


def _to_dataclass(x, seen_nodes: set[int]):
if nnx.graph.is_graph_node(x):
if id(x) in seen_nodes:
dc_type = _make_dataclass_obj(
type(x),
{'repeated': True},
)
return dc_type
seen_nodes.add(id(x))
node_impl = nnx.graph.get_node_impl(x)
node_dict = node_impl.node_dict(x)
node_dict = {
str(key): _treemap_to_dataclass(value, seen_nodes)
for key, value in node_dict.items()
}
dc_type = _make_dataclass_obj(
type(x),
{str(key): value for key, value in node_dict.items()},
)
return dc_type
elif isinstance(x, (nnx.Variable, nnx.VariableState)):
obj_vars = vars(x).copy()
if 'raw_value' in obj_vars:
obj_vars['value'] = obj_vars.pop('raw_value')
if '_trace_state' in obj_vars:
del obj_vars['_trace_state']
for name in list(obj_vars):
if name.endswith('_hooks'):
del obj_vars[name]
obj_vars = {
key: _treemap_to_dataclass(value, seen_nodes)
for key, value in obj_vars.items()
}
dc_type = _make_dataclass_obj(
type(x),
obj_vars,
penzai_dataclass=not isinstance(x, nnx.VariableState),
)
return dc_type
elif isinstance(x, nnx.State):
return _treemap_to_dataclass(x._mapping, seen_nodes)
return x


def _treemap_to_dataclass(node, seen_nodes: set[int]):
def _to_dataclass_fn(x):
return _to_dataclass(x, seen_nodes)

return jax.tree.map(
_to_dataclass_fn,
node,
is_leaf=lambda x: isinstance(x, (nnx.VariableState, nnx.State)),
)


def _make_dataclass_obj(
cls, fields: tp.Mapping[str, tp.Any], penzai_dataclass: bool = True
) -> tp.Type:
from penzai import pz # type: ignore[import-error]

dataclass = pz.pytree_dataclass if penzai_dataclass else dataclasses.dataclass
base = pz.Layer if penzai_dataclass else object

attributes = {
'__annotations__': {key: type(value) for key, value in fields.items()},
}

if hasattr(cls, '__call__'):
attributes['__call__'] = cls.__call__

dc_type = type(cls.__name__, (base,), attributes)
dc_type = dataclass(dc_type)
return dc_type(**fields)
pz.ts.display(x, ignore_exceptions=True)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ testing = [
"nbstripout",
"black[jupyter]==23.7.0",
# "pyink==23.5.0", # disabling pyink fow now
"penzai; python_version>='3.10'",
"penzai>=0.1.2; python_version>='3.10'",
]

[project.urls]
Expand Down
Loading