Skip to content

Commit

Permalink
Density Matrix plotting is being fixed (quantumlib#4805)
Browse files Browse the repository at this point in the history
The plotting positions were dividing by cell size instead of multiplying, fixed that.
Closes quantumlib#4804.
  • Loading branch information
AnimeshSinha1309 authored and rht committed May 1, 2023
1 parent f6447dd commit 5b87a19
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 7 deletions.
10 changes: 5 additions & 5 deletions cirq-core/cirq/vis/density_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,18 @@ def _plot_element_of_density_matrix(ax, x, y, r, phase, show_rect=False, show_te
_image_opacity = 0.8 if not show_text else 0.4

circle_out = plt.Circle(
(x + 0.5, y + 0.5), radius=1 / _half_cell_size_after_padding, fill=False, color='#333333'
(x + 0.5, y + 0.5), radius=1 * _half_cell_size_after_padding, fill=False, color='#333333'
)
circle_in = plt.Circle(
(x + 0.5, y + 0.5),
radius=r / _half_cell_size_after_padding,
radius=r * _half_cell_size_after_padding,
fill=True,
color='IndianRed',
alpha=_image_opacity,
)
line = lines.Line2D(
(x + 0.5, x + 0.5 + np.cos(phase) / _half_cell_size_after_padding),
(y + 0.5, y + 0.5 + np.sin(phase) / _half_cell_size_after_padding),
(x + 0.5, x + 0.5 + np.cos(phase) * _half_cell_size_after_padding),
(y + 0.5, y + 0.5 + np.sin(phase) * _half_cell_size_after_padding),
color='#333333',
alpha=_image_opacity,
)
Expand Down Expand Up @@ -128,7 +128,7 @@ def plot_density_matrix(
f"{'0'*(num_qubits - len(f'{i:b}'))}{i:b}" for i in range(matrix.shape[0])
]
ax.set_xticks(ticks)
ax.set_xticklabels(labels)
ax.set_xticklabels(labels, rotation=90)
ax.set_yticks(ticks)
ax.set_yticklabels(reversed(labels))
ax.set_facecolor('#eeeeee')
Expand Down
26 changes: 24 additions & 2 deletions cirq-core/cirq/vis/density_matrix_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,8 @@ def test_density_matrix_plotter(size, show_text):

@pytest.mark.parametrize('show_text', [True, False])
@pytest.mark.parametrize('size', [2, 4, 8, 16])
def test_density_matrix_circle_sizes(size, show_text):
def test_density_matrix_circle_rectangle_sizes(size, show_text):
matrix = cirq.testing.random_density_matrix(size)
# Check that the correct title is being shown
ax = plot_density_matrix(matrix, show_text=show_text, title='Test Density Matrix Plot')
# Check that the radius of all the circles in the matrix is correct
circles = list(filter(lambda x: isinstance(x, patches.Circle), ax.get_children()))
Expand Down Expand Up @@ -90,6 +89,29 @@ def test_density_matrix_circle_sizes(size, show_text):
)


@pytest.mark.parametrize('show_text', [True, False])
@pytest.mark.parametrize('size', [2, 4, 8, 16])
def test_density_matrix_sizes_upper_bounds(size, show_text):
matrix = cirq.testing.random_density_matrix(size)
ax = plot_density_matrix(matrix, show_text=show_text, title='Test Density Matrix Plot')

circles = list(filter(lambda x: isinstance(x, patches.Circle), ax.get_children()))
max_radius = np.max([c.radius for c in circles if c.fill])

rects = list(
filter(
lambda x: isinstance(x, patches.Rectangle) and x.get_alpha() is not None,
ax.get_children(),
)
)
max_height = np.max([r.get_height() for r in rects])
max_width = np.max([r.get_width() for r in rects])

assert max_height <= 1.0, "Some rectangle is exceeding out of it's cell size"
assert max_width <= 1.0, "Some rectangle is exceeding out of it's cell size"
assert max_radius * 2 <= 1.0, "Some circle is exceeding out of it's cell size"


@pytest.mark.parametrize('show_rect', [True, False])
@pytest.mark.parametrize('value', [0.0, 1.0, 0.5 + 0.3j, 0.2 + 0.1j, 0.5 + 0.5j])
def test_density_element_plot(value, show_rect):
Expand Down

0 comments on commit 5b87a19

Please sign in to comment.