Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update mpl_draw() to fix multigraph plots #1204

Merged
merged 13 commits into from
Jun 10, 2024
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
---
fixes:
- |
Fixed the plots of multigraphs using :func:`.mpl_draw`. Previously, parallel edges of
multigraphs were plotted on top of each other, with overlapping arrows and labels.
The radius of parallel edges of the multigraph was fixed to be `0.25` for
`connectionstyle` supporting this argument in :func:`.draw_edges`. The edge lables
were offset to `0.25` in :func:`.draw_edge_labels` to align with their respective
edges. This fix can be tested using the following code:

.. jupyter-execute::

import rustworkx
from rustworkx.visualization import mpl_draw

graph = rustworkx.PyDiGraph()
graph.add_node('A')
graph.add_node('B')
graph.add_node('C')

graph.add_edge(1, 0, 2)
graph.add_edge(0, 1, 3)
graph.add_edge(1, 2, 4)

mpl_draw(graph, with_labels=True, labels=str, edge_labels=str, alpha=0.5)

- |
Refer to `#774 <https://github.com/Qiskit/rustworkx/issues/774>` for more
details.
108 changes: 62 additions & 46 deletions rustworkx/visualization/matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,9 @@ def draw_edges(
edge_color = "k"

# set edge positions
edge_pos = np.asarray([(pos[e[0]], pos[e[1]]) for e in edge_list])
edge_pos = set()
for e in edge_list:
edge_pos.add((tuple(pos[e[0]]), tuple(pos[e[1]])))

# Check if edge_color is an array of floats and map to edge_cmap.
# This is the only case handled differently from matplotlib
Expand Down Expand Up @@ -670,58 +672,17 @@ def to_marker_edge(marker_size, marker):
arrow_collection = []
mutation_scale = arrow_size # scale factor of arrow head

# compute view
mirustworkx = np.amin(np.ravel(edge_pos[:, :, 0]))
maxx = np.amax(np.ravel(edge_pos[:, :, 0]))
miny = np.amin(np.ravel(edge_pos[:, :, 1]))
maxy = np.amax(np.ravel(edge_pos[:, :, 1]))
w = maxx - mirustworkx
h = maxy - miny

base_connectionstyle = mpl.patches.ConnectionStyle(connectionstyle)

# Fallback for self-loop scale. Left outside of _connectionstyle so it is
# only computed once
max_nodesize = np.array(node_size).max()

def _connectionstyle(posA, posB, *args, **kwargs):
# check if we need to do a self-loop
if np.all(posA == posB):
# Self-loops are scaled by view extent, except in cases the extent
# is 0, e.g. for a single node. In this case, fall back to scaling
# by the maximum node size
selfloop_ht = 0.005 * max_nodesize if h == 0 else h
# this is called with _screen space_ values so covert back
# to data space
data_loc = ax.transData.inverted().transform(posA)
v_shift = 0.1 * selfloop_ht
h_shift = v_shift * 0.5
# put the top of the loop first so arrow is not hidden by node
path = [
# 1
data_loc + np.asarray([0, v_shift]),
# 4 4 4
data_loc + np.asarray([h_shift, v_shift]),
data_loc + np.asarray([h_shift, 0]),
data_loc,
# 4 4 4
data_loc + np.asarray([-h_shift, 0]),
data_loc + np.asarray([-h_shift, v_shift]),
data_loc + np.asarray([0, v_shift]),
]

ret = mpl.path.Path(ax.transData.transform(path), [1, 4, 4, 4, 4, 4, 4])
# if not, fall back to the user specified behavior
else:
ret = base_connectionstyle(posA, posB, *args, **kwargs)

return ret

# FancyArrowPatch doesn't handle color strings
arrow_colors = mpl.colors.colorConverter.to_rgba_array(edge_color, alpha)
for i, (src, dst) in enumerate(edge_pos):
x1, y1 = src
x2, y2 = dst
for i, edge in enumerate(edge_pos):
x1, y1 = edge[0][0], edge[0][1]
x2, y2 = edge[1][0], edge[1][1]
shrink_source = 0 # space from source to tail
shrink_target = 0 # space from head to target
if np.iterable(node_size): # many node sizes
Expand Down Expand Up @@ -754,6 +715,12 @@ def _connectionstyle(posA, posB, *args, **kwargs):
else:
line_width = width

# radius of edges
if tuple(reversed(edge)) in edge_pos:
rad = 0.25
else:
rad = 0.0

arrow = mpl.patches.FancyArrowPatch(
(x1, y1),
(x2, y2),
Expand All @@ -763,14 +730,57 @@ def _connectionstyle(posA, posB, *args, **kwargs):
mutation_scale=mutation_scale,
color=arrow_color,
linewidth=line_width,
connectionstyle=_connectionstyle,
connectionstyle=connectionstyle + f", rad = {rad}",
maxwell04-wq marked this conversation as resolved.
Show resolved Hide resolved
linestyle=style,
zorder=1,
) # arrows go behind nodes

arrow_collection.append(arrow)
ax.add_patch(arrow)

edge_pos = np.asarray(tuple(edge_pos))

# compute view
mirustworkx = np.amin(np.ravel(edge_pos[:, :, 0]))
maxx = np.amax(np.ravel(edge_pos[:, :, 0]))
miny = np.amin(np.ravel(edge_pos[:, :, 1]))
maxy = np.amax(np.ravel(edge_pos[:, :, 1]))
w = maxx - mirustworkx
h = maxy - miny

def _connectionstyle(posA, posB, *args, **kwargs):
# check if we need to do a self-loop
if np.all(posA == posB):
# Self-loops are scaled by view extent, except in cases the extent
# is 0, e.g. for a single node. In this case, fall back to scaling
# by the maximum node size
selfloop_ht = 0.005 * max_nodesize if h == 0 else h
# this is called with _screen space_ values so covert back
# to data space
data_loc = ax.transData.inverted().transform(posA)
v_shift = 0.1 * selfloop_ht
h_shift = v_shift * 0.5
# put the top of the loop first so arrow is not hidden by node
path = [
# 1
data_loc + np.asarray([0, v_shift]),
# 4 4 4
data_loc + np.asarray([h_shift, v_shift]),
data_loc + np.asarray([h_shift, 0]),
data_loc,
# 4 4 4
data_loc + np.asarray([-h_shift, 0]),
data_loc + np.asarray([-h_shift, v_shift]),
data_loc + np.asarray([0, v_shift]),
]

ret = mpl.path.Path(ax.transData.transform(path), [1, 4, 4, 4, 4, 4, 4])
# if not, fall back to the user specified behavior
else:
ret = base_connectionstyle(posA, posB, *args, **kwargs)

return ret

# update view
padx, pady = 0.05 * w, 0.05 * h
corners = (mirustworkx - padx, miny - pady), (maxx + padx, maxy + pady)
Expand Down Expand Up @@ -1001,6 +1011,12 @@ def draw_edge_labels(
x1 * label_pos + x2 * (1.0 - label_pos),
y1 * label_pos + y2 * (1.0 - label_pos),
)
if (n2, n1) in labels.keys(): # loop
dy = np.abs(y2 - y1)
if n2 > n1:
y -= 0.25 * dy
else:
y += 0.25 * dy

if rotate:
# in degrees
Expand Down