diff --git a/releasenotes/notes/fix-mpl-draw-digraph-plots-aecf86738ab9b0db.yaml b/releasenotes/notes/fix-mpl-draw-digraph-plots-aecf86738ab9b0db.yaml new file mode 100644 index 000000000..926e36044 --- /dev/null +++ b/releasenotes/notes/fix-mpl-draw-digraph-plots-aecf86738ab9b0db.yaml @@ -0,0 +1,29 @@ +--- +fixes: + - | + Fixed the plots of multigraphs using :func:`.mpl_draw`. Previously, parallel edges of + multigraphs were plotted on top of each other, with overlapping arrows and labels. + The radius of parallel edges of the multigraph was fixed to be `0.25` for + `connectionstyle` supporting this argument in :func:`.draw_edges`. The edge lables + were offset to `0.25` in :func:`.draw_edge_labels` to align with their respective + edges. This fix can be tested using the following code: + + .. jupyter-execute:: + + import rustworkx + from rustworkx.visualization import mpl_draw + + graph = rustworkx.PyDiGraph() + graph.add_node('A') + graph.add_node('B') + graph.add_node('C') + + graph.add_edge(1, 0, 2) + graph.add_edge(0, 1, 3) + graph.add_edge(1, 2, 4) + + mpl_draw(graph, with_labels=True, labels=str, edge_labels=str, alpha=0.5) + + - | + Refer to `#774 ` for more + details. diff --git a/rustworkx/visualization/matplotlib.py b/rustworkx/visualization/matplotlib.py index 7f1437112..559756a13 100644 --- a/rustworkx/visualization/matplotlib.py +++ b/rustworkx/visualization/matplotlib.py @@ -636,7 +636,9 @@ def draw_edges( edge_color = "k" # set edge positions - edge_pos = np.asarray([(pos[e[0]], pos[e[1]]) for e in edge_list]) + edge_pos = set() + for e in edge_list: + edge_pos.add((tuple(pos[e[0]]), tuple(pos[e[1]]))) # Check if edge_color is an array of floats and map to edge_cmap. # This is the only case handled differently from matplotlib @@ -670,58 +672,17 @@ def to_marker_edge(marker_size, marker): arrow_collection = [] mutation_scale = arrow_size # scale factor of arrow head - # compute view - mirustworkx = np.amin(np.ravel(edge_pos[:, :, 0])) - maxx = np.amax(np.ravel(edge_pos[:, :, 0])) - miny = np.amin(np.ravel(edge_pos[:, :, 1])) - maxy = np.amax(np.ravel(edge_pos[:, :, 1])) - w = maxx - mirustworkx - h = maxy - miny - base_connectionstyle = mpl.patches.ConnectionStyle(connectionstyle) # Fallback for self-loop scale. Left outside of _connectionstyle so it is # only computed once max_nodesize = np.array(node_size).max() - def _connectionstyle(posA, posB, *args, **kwargs): - # check if we need to do a self-loop - if np.all(posA == posB): - # Self-loops are scaled by view extent, except in cases the extent - # is 0, e.g. for a single node. In this case, fall back to scaling - # by the maximum node size - selfloop_ht = 0.005 * max_nodesize if h == 0 else h - # this is called with _screen space_ values so covert back - # to data space - data_loc = ax.transData.inverted().transform(posA) - v_shift = 0.1 * selfloop_ht - h_shift = v_shift * 0.5 - # put the top of the loop first so arrow is not hidden by node - path = [ - # 1 - data_loc + np.asarray([0, v_shift]), - # 4 4 4 - data_loc + np.asarray([h_shift, v_shift]), - data_loc + np.asarray([h_shift, 0]), - data_loc, - # 4 4 4 - data_loc + np.asarray([-h_shift, 0]), - data_loc + np.asarray([-h_shift, v_shift]), - data_loc + np.asarray([0, v_shift]), - ] - - ret = mpl.path.Path(ax.transData.transform(path), [1, 4, 4, 4, 4, 4, 4]) - # if not, fall back to the user specified behavior - else: - ret = base_connectionstyle(posA, posB, *args, **kwargs) - - return ret - # FancyArrowPatch doesn't handle color strings arrow_colors = mpl.colors.colorConverter.to_rgba_array(edge_color, alpha) - for i, (src, dst) in enumerate(edge_pos): - x1, y1 = src - x2, y2 = dst + for i, edge in enumerate(edge_pos): + x1, y1 = edge[0][0], edge[0][1] + x2, y2 = edge[1][0], edge[1][1] shrink_source = 0 # space from source to tail shrink_target = 0 # space from head to target if np.iterable(node_size): # many node sizes @@ -754,6 +715,12 @@ def _connectionstyle(posA, posB, *args, **kwargs): else: line_width = width + # radius of edges + if tuple(reversed(edge)) in edge_pos: + rad = 0.25 + else: + rad = 0.0 + arrow = mpl.patches.FancyArrowPatch( (x1, y1), (x2, y2), @@ -763,7 +730,7 @@ def _connectionstyle(posA, posB, *args, **kwargs): mutation_scale=mutation_scale, color=arrow_color, linewidth=line_width, - connectionstyle=_connectionstyle, + connectionstyle=connectionstyle + f", rad = {rad}", linestyle=style, zorder=1, ) # arrows go behind nodes @@ -771,6 +738,49 @@ def _connectionstyle(posA, posB, *args, **kwargs): arrow_collection.append(arrow) ax.add_patch(arrow) + edge_pos = np.asarray(tuple(edge_pos)) + + # compute view + mirustworkx = np.amin(np.ravel(edge_pos[:, :, 0])) + maxx = np.amax(np.ravel(edge_pos[:, :, 0])) + miny = np.amin(np.ravel(edge_pos[:, :, 1])) + maxy = np.amax(np.ravel(edge_pos[:, :, 1])) + w = maxx - mirustworkx + h = maxy - miny + + def _connectionstyle(posA, posB, *args, **kwargs): + # check if we need to do a self-loop + if np.all(posA == posB): + # Self-loops are scaled by view extent, except in cases the extent + # is 0, e.g. for a single node. In this case, fall back to scaling + # by the maximum node size + selfloop_ht = 0.005 * max_nodesize if h == 0 else h + # this is called with _screen space_ values so covert back + # to data space + data_loc = ax.transData.inverted().transform(posA) + v_shift = 0.1 * selfloop_ht + h_shift = v_shift * 0.5 + # put the top of the loop first so arrow is not hidden by node + path = [ + # 1 + data_loc + np.asarray([0, v_shift]), + # 4 4 4 + data_loc + np.asarray([h_shift, v_shift]), + data_loc + np.asarray([h_shift, 0]), + data_loc, + # 4 4 4 + data_loc + np.asarray([-h_shift, 0]), + data_loc + np.asarray([-h_shift, v_shift]), + data_loc + np.asarray([0, v_shift]), + ] + + ret = mpl.path.Path(ax.transData.transform(path), [1, 4, 4, 4, 4, 4, 4]) + # if not, fall back to the user specified behavior + else: + ret = base_connectionstyle(posA, posB, *args, **kwargs) + + return ret + # update view padx, pady = 0.05 * w, 0.05 * h corners = (mirustworkx - padx, miny - pady), (maxx + padx, maxy + pady) @@ -1001,6 +1011,12 @@ def draw_edge_labels( x1 * label_pos + x2 * (1.0 - label_pos), y1 * label_pos + y2 * (1.0 - label_pos), ) + if (n2, n1) in labels.keys(): # loop + dy = np.abs(y2 - y1) + if n2 > n1: + y -= 0.25 * dy + else: + y += 0.25 * dy if rotate: # in degrees