Skip to content

Commit

Permalink
Update mpl_draw() to fix multigraph plots (#1204)
Browse files Browse the repository at this point in the history
* Update #1: Fixing  mpl_draw() for multigraphs

* Update matplotlib.py for formatting

* Update rustworkx/visualization/matplotlib.py

Co-authored-by: Ivan Carvalho <[email protected]>

* Update matplotlib.py to remove the loop

* Add releasenotes

* Reformat connectionstyle string in rustworkx/visualization/matplotlib.py

Co-authored-by: Ivan Carvalho <[email protected]>

* Fixes #774

* Optimize edge search by using sets

---------

Co-authored-by: Ivan Carvalho <[email protected]>
  • Loading branch information
maxwell04-wq and IvanIsCoding authored Jun 10, 2024
1 parent c7a7d53 commit 646057a
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 46 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
---
fixes:
- |
Fixed the plots of multigraphs using :func:`.mpl_draw`. Previously, parallel edges of
multigraphs were plotted on top of each other, with overlapping arrows and labels.
The radius of parallel edges of the multigraph was fixed to be `0.25` for
`connectionstyle` supporting this argument in :func:`.draw_edges`. The edge lables
were offset to `0.25` in :func:`.draw_edge_labels` to align with their respective
edges. This fix can be tested using the following code:
.. jupyter-execute::
import rustworkx
from rustworkx.visualization import mpl_draw
graph = rustworkx.PyDiGraph()
graph.add_node('A')
graph.add_node('B')
graph.add_node('C')
graph.add_edge(1, 0, 2)
graph.add_edge(0, 1, 3)
graph.add_edge(1, 2, 4)
mpl_draw(graph, with_labels=True, labels=str, edge_labels=str, alpha=0.5)
- |
Refer to `#774 <https://github.com/Qiskit/rustworkx/issues/774>` for more
details.
108 changes: 62 additions & 46 deletions rustworkx/visualization/matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,9 @@ def draw_edges(
edge_color = "k"

# set edge positions
edge_pos = np.asarray([(pos[e[0]], pos[e[1]]) for e in edge_list])
edge_pos = set()
for e in edge_list:
edge_pos.add((tuple(pos[e[0]]), tuple(pos[e[1]])))

# Check if edge_color is an array of floats and map to edge_cmap.
# This is the only case handled differently from matplotlib
Expand Down Expand Up @@ -670,58 +672,17 @@ def to_marker_edge(marker_size, marker):
arrow_collection = []
mutation_scale = arrow_size # scale factor of arrow head

# compute view
mirustworkx = np.amin(np.ravel(edge_pos[:, :, 0]))
maxx = np.amax(np.ravel(edge_pos[:, :, 0]))
miny = np.amin(np.ravel(edge_pos[:, :, 1]))
maxy = np.amax(np.ravel(edge_pos[:, :, 1]))
w = maxx - mirustworkx
h = maxy - miny

base_connectionstyle = mpl.patches.ConnectionStyle(connectionstyle)

# Fallback for self-loop scale. Left outside of _connectionstyle so it is
# only computed once
max_nodesize = np.array(node_size).max()

def _connectionstyle(posA, posB, *args, **kwargs):
# check if we need to do a self-loop
if np.all(posA == posB):
# Self-loops are scaled by view extent, except in cases the extent
# is 0, e.g. for a single node. In this case, fall back to scaling
# by the maximum node size
selfloop_ht = 0.005 * max_nodesize if h == 0 else h
# this is called with _screen space_ values so covert back
# to data space
data_loc = ax.transData.inverted().transform(posA)
v_shift = 0.1 * selfloop_ht
h_shift = v_shift * 0.5
# put the top of the loop first so arrow is not hidden by node
path = [
# 1
data_loc + np.asarray([0, v_shift]),
# 4 4 4
data_loc + np.asarray([h_shift, v_shift]),
data_loc + np.asarray([h_shift, 0]),
data_loc,
# 4 4 4
data_loc + np.asarray([-h_shift, 0]),
data_loc + np.asarray([-h_shift, v_shift]),
data_loc + np.asarray([0, v_shift]),
]

ret = mpl.path.Path(ax.transData.transform(path), [1, 4, 4, 4, 4, 4, 4])
# if not, fall back to the user specified behavior
else:
ret = base_connectionstyle(posA, posB, *args, **kwargs)

return ret

# FancyArrowPatch doesn't handle color strings
arrow_colors = mpl.colors.colorConverter.to_rgba_array(edge_color, alpha)
for i, (src, dst) in enumerate(edge_pos):
x1, y1 = src
x2, y2 = dst
for i, edge in enumerate(edge_pos):
x1, y1 = edge[0][0], edge[0][1]
x2, y2 = edge[1][0], edge[1][1]
shrink_source = 0 # space from source to tail
shrink_target = 0 # space from head to target
if np.iterable(node_size): # many node sizes
Expand Down Expand Up @@ -754,6 +715,12 @@ def _connectionstyle(posA, posB, *args, **kwargs):
else:
line_width = width

# radius of edges
if tuple(reversed(edge)) in edge_pos:
rad = 0.25
else:
rad = 0.0

arrow = mpl.patches.FancyArrowPatch(
(x1, y1),
(x2, y2),
Expand All @@ -763,14 +730,57 @@ def _connectionstyle(posA, posB, *args, **kwargs):
mutation_scale=mutation_scale,
color=arrow_color,
linewidth=line_width,
connectionstyle=_connectionstyle,
connectionstyle=connectionstyle + f", rad = {rad}",
linestyle=style,
zorder=1,
) # arrows go behind nodes

arrow_collection.append(arrow)
ax.add_patch(arrow)

edge_pos = np.asarray(tuple(edge_pos))

# compute view
mirustworkx = np.amin(np.ravel(edge_pos[:, :, 0]))
maxx = np.amax(np.ravel(edge_pos[:, :, 0]))
miny = np.amin(np.ravel(edge_pos[:, :, 1]))
maxy = np.amax(np.ravel(edge_pos[:, :, 1]))
w = maxx - mirustworkx
h = maxy - miny

def _connectionstyle(posA, posB, *args, **kwargs):
# check if we need to do a self-loop
if np.all(posA == posB):
# Self-loops are scaled by view extent, except in cases the extent
# is 0, e.g. for a single node. In this case, fall back to scaling
# by the maximum node size
selfloop_ht = 0.005 * max_nodesize if h == 0 else h
# this is called with _screen space_ values so covert back
# to data space
data_loc = ax.transData.inverted().transform(posA)
v_shift = 0.1 * selfloop_ht
h_shift = v_shift * 0.5
# put the top of the loop first so arrow is not hidden by node
path = [
# 1
data_loc + np.asarray([0, v_shift]),
# 4 4 4
data_loc + np.asarray([h_shift, v_shift]),
data_loc + np.asarray([h_shift, 0]),
data_loc,
# 4 4 4
data_loc + np.asarray([-h_shift, 0]),
data_loc + np.asarray([-h_shift, v_shift]),
data_loc + np.asarray([0, v_shift]),
]

ret = mpl.path.Path(ax.transData.transform(path), [1, 4, 4, 4, 4, 4, 4])
# if not, fall back to the user specified behavior
else:
ret = base_connectionstyle(posA, posB, *args, **kwargs)

return ret

# update view
padx, pady = 0.05 * w, 0.05 * h
corners = (mirustworkx - padx, miny - pady), (maxx + padx, maxy + pady)
Expand Down Expand Up @@ -1001,6 +1011,12 @@ def draw_edge_labels(
x1 * label_pos + x2 * (1.0 - label_pos),
y1 * label_pos + y2 * (1.0 - label_pos),
)
if (n2, n1) in labels.keys(): # loop
dy = np.abs(y2 - y1)
if n2 > n1:
y -= 0.25 * dy
else:
y += 0.25 * dy

if rotate:
# in degrees
Expand Down

0 comments on commit 646057a

Please sign in to comment.