Skip to content

Commit

Permalink
feat(graphviz): node- and edge-specific custom attributes (#8527)
Browse files Browse the repository at this point in the history
  • Loading branch information
peter-gy authored Mar 5, 2024
1 parent 5c121df commit 98c52aa
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 4 deletions.
14 changes: 14 additions & 0 deletions ibis/expr/types/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import ibis.expr.types as ir
from ibis.backends import BaseBackend
from ibis.expr.visualize import EdgeAttributeGetter, NodeAttributeGetter

TimeContext = tuple[pd.Timestamp, pd.Timestamp]

Expand Down Expand Up @@ -164,7 +165,9 @@ def visualize(
label_edges: bool = False,
verbose: bool = False,
node_attr: Mapping[str, str] | None = None,
node_attr_getter: NodeAttributeGetter | None = None,
edge_attr: Mapping[str, str] | None = None,
edge_attr_getter: EdgeAttributeGetter | None = None,
) -> None:
"""Visualize an expression as a GraphViz graph in the browser.
Expand All @@ -180,23 +183,32 @@ def visualize(
node_attr
Mapping of ``(attribute, value)`` pairs set for all nodes.
Options are specified by the ``graphviz`` Python library.
node_attr_getter
Callback taking a node and returning a mapping of ``(attribute, value)`` pairs
for that node. Options are specified by the ``graphviz`` Python library.
edge_attr
Mapping of ``(attribute, value)`` pairs set for all edges.
Options are specified by the ``graphviz`` Python library.
edge_attr_getter
Callback taking two adjacent nodes and returning a mapping of ``(attribute, value)`` pairs
for the edge between those nodes. Options are specified by the ``graphviz`` Python library.
Examples
--------
Open the visualization of an expression in default browser:
>>> import ibis
>>> import ibis.expr.operations as ops
>>> left = ibis.table(dict(a="int64", b="string"), name="left")
>>> right = ibis.table(dict(b="string", c="int64", d="string"), name="right")
>>> expr = left.inner_join(right, "b").select(left.a, b=right.c, c=right.d)
>>> expr.visualize(
... format="svg",
... label_edges=True,
... node_attr={"fontname": "Roboto Mono", "fontsize": "10"},
... node_attr_getter=lambda node: isinstance(node, ops.Field) and {"shape": "oval"},
... edge_attr={"fontsize": "8"},
... edge_attr_getter=lambda u, v: isinstance(u, ops.Field) and {"color": "red"},
... ) # quartodoc: +SKIP # doctest: +SKIP
Raises
Expand All @@ -210,7 +222,9 @@ def visualize(
viz.to_graph(
self,
node_attr=node_attr,
node_attr_getter=node_attr_getter,
edge_attr=edge_attr,
edge_attr_getter=edge_attr_getter,
label_edges=label_edges,
),
format=format,
Expand Down
32 changes: 28 additions & 4 deletions ibis/expr/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import sys
import tempfile
from html import escape
from typing import Callable, Optional

import graphviz as gv

Expand Down Expand Up @@ -88,8 +89,18 @@ def get_label(node):
DEFAULT_NODE_ATTRS = {"shape": "box", "fontname": "Deja Vu Sans Mono"}
DEFAULT_EDGE_ATTRS = {"fontname": "Deja Vu Sans Mono"}

NodeAttributeGetter = Callable[[ops.Node], Optional[dict[str, str]]]
EdgeAttributeGetter = Callable[[ops.Node, ops.Node], Optional[dict[str, str]]]

def to_graph(expr, node_attr=None, edge_attr=None, label_edges: bool = False):

def to_graph(
expr,
node_attr=None,
node_attr_getter: NodeAttributeGetter | None = None,
edge_attr=None,
edge_attr_getter: EdgeAttributeGetter | None = None,
label_edges: bool = False,
):
graph = Graph.from_bfs(expr.op(), filter=ops.Node)

g = gv.Digraph(
Expand All @@ -105,13 +116,21 @@ def to_graph(expr, node_attr=None, edge_attr=None, label_edges: bool = False):
for v, us in graph.items():
vhash = str(hash(v))
if v not in seen:
g.node(vhash, label=get_label(v))
g.node(
vhash,
label=get_label(v),
_attributes=node_attr_getter(v) if node_attr_getter else {},
)
seen.add(v)

for u in us:
uhash = str(hash(u))
if u not in seen:
g.node(uhash, label=get_label(u))
g.node(
uhash,
label=get_label(u),
_attributes=node_attr_getter(u) if node_attr_getter else {},
)
seen.add(u)
if (edge := (u, v)) not in edges:
if not label_edges:
Expand All @@ -128,7 +147,12 @@ def to_graph(expr, node_attr=None, edge_attr=None, label_edges: bool = False):
name = None
label = f"<.{name}>"

g.edge(uhash, vhash, label=label)
g.edge(
uhash,
vhash,
label=label,
_attributes=edge_attr_getter(u, v) if edge_attr_getter else {},
)
edges.add(edge)
return g

Expand Down

0 comments on commit 98c52aa

Please sign in to comment.