diff --git a/ibis/expr/types/core.py b/ibis/expr/types/core.py index f259e98d3258..e5ebd2939047 100644 --- a/ibis/expr/types/core.py +++ b/ibis/expr/types/core.py @@ -28,6 +28,7 @@ import ibis.expr.types as ir from ibis.backends import BaseBackend + from ibis.expr.visualize import EdgeAttributeGetter, NodeAttributeGetter TimeContext = tuple[pd.Timestamp, pd.Timestamp] @@ -164,7 +165,9 @@ def visualize( label_edges: bool = False, verbose: bool = False, node_attr: Mapping[str, str] | None = None, + node_attr_getter: NodeAttributeGetter | None = None, edge_attr: Mapping[str, str] | None = None, + edge_attr_getter: EdgeAttributeGetter | None = None, ) -> None: """Visualize an expression as a GraphViz graph in the browser. @@ -180,15 +183,22 @@ def visualize( node_attr Mapping of ``(attribute, value)`` pairs set for all nodes. Options are specified by the ``graphviz`` Python library. + node_attr_getter + Callback taking a node and returning a mapping of ``(attribute, value)`` pairs + for that node. Options are specified by the ``graphviz`` Python library. edge_attr Mapping of ``(attribute, value)`` pairs set for all edges. Options are specified by the ``graphviz`` Python library. + edge_attr_getter + Callback taking two adjacent nodes and returning a mapping of ``(attribute, value)`` pairs + for the edge between those nodes. Options are specified by the ``graphviz`` Python library. Examples -------- Open the visualization of an expression in default browser: >>> import ibis + >>> import ibis.expr.operations as ops >>> left = ibis.table(dict(a="int64", b="string"), name="left") >>> right = ibis.table(dict(b="string", c="int64", d="string"), name="right") >>> expr = left.inner_join(right, "b").select(left.a, b=right.c, c=right.d) @@ -196,7 +206,9 @@ def visualize( ... format="svg", ... label_edges=True, ... node_attr={"fontname": "Roboto Mono", "fontsize": "10"}, + ... node_attr_getter=lambda node: isinstance(node, ops.Field) and {"shape": "oval"}, ... edge_attr={"fontsize": "8"}, + ... edge_attr_getter=lambda u, v: isinstance(u, ops.Field) and {"color": "red"}, ... ) # quartodoc: +SKIP # doctest: +SKIP Raises @@ -210,7 +222,9 @@ def visualize( viz.to_graph( self, node_attr=node_attr, + node_attr_getter=node_attr_getter, edge_attr=edge_attr, + edge_attr_getter=edge_attr_getter, label_edges=label_edges, ), format=format, diff --git a/ibis/expr/visualize.py b/ibis/expr/visualize.py index bcaa41a53e38..34ffa999f575 100644 --- a/ibis/expr/visualize.py +++ b/ibis/expr/visualize.py @@ -4,6 +4,7 @@ import sys import tempfile from html import escape +from typing import Callable, Optional import graphviz as gv @@ -88,8 +89,18 @@ def get_label(node): DEFAULT_NODE_ATTRS = {"shape": "box", "fontname": "Deja Vu Sans Mono"} DEFAULT_EDGE_ATTRS = {"fontname": "Deja Vu Sans Mono"} +NodeAttributeGetter = Callable[[ops.Node], Optional[dict[str, str]]] +EdgeAttributeGetter = Callable[[ops.Node, ops.Node], Optional[dict[str, str]]] -def to_graph(expr, node_attr=None, edge_attr=None, label_edges: bool = False): + +def to_graph( + expr, + node_attr=None, + node_attr_getter: NodeAttributeGetter | None = None, + edge_attr=None, + edge_attr_getter: EdgeAttributeGetter | None = None, + label_edges: bool = False, +): graph = Graph.from_bfs(expr.op(), filter=ops.Node) g = gv.Digraph( @@ -105,13 +116,21 @@ def to_graph(expr, node_attr=None, edge_attr=None, label_edges: bool = False): for v, us in graph.items(): vhash = str(hash(v)) if v not in seen: - g.node(vhash, label=get_label(v)) + g.node( + vhash, + label=get_label(v), + _attributes=node_attr_getter(v) if node_attr_getter else {}, + ) seen.add(v) for u in us: uhash = str(hash(u)) if u not in seen: - g.node(uhash, label=get_label(u)) + g.node( + uhash, + label=get_label(u), + _attributes=node_attr_getter(u) if node_attr_getter else {}, + ) seen.add(u) if (edge := (u, v)) not in edges: if not label_edges: @@ -128,7 +147,12 @@ def to_graph(expr, node_attr=None, edge_attr=None, label_edges: bool = False): name = None label = f"<.{name}>" - g.edge(uhash, vhash, label=label) + g.edge( + uhash, + vhash, + label=label, + _attributes=edge_attr_getter(u, v) if edge_attr_getter else {}, + ) edges.add(edge) return g