diff --git a/cirq-core/cirq/vis/density_matrix.py b/cirq-core/cirq/vis/density_matrix.py index d0c959a05e5..4cf726f77f4 100644 --- a/cirq-core/cirq/vis/density_matrix.py +++ b/cirq-core/cirq/vis/density_matrix.py @@ -41,18 +41,18 @@ def _plot_element_of_density_matrix(ax, x, y, r, phase, show_rect=False, show_te _image_opacity = 0.8 if not show_text else 0.4 circle_out = plt.Circle( - (x + 0.5, y + 0.5), radius=1 / _half_cell_size_after_padding, fill=False, color='#333333' + (x + 0.5, y + 0.5), radius=1 * _half_cell_size_after_padding, fill=False, color='#333333' ) circle_in = plt.Circle( (x + 0.5, y + 0.5), - radius=r / _half_cell_size_after_padding, + radius=r * _half_cell_size_after_padding, fill=True, color='IndianRed', alpha=_image_opacity, ) line = lines.Line2D( - (x + 0.5, x + 0.5 + np.cos(phase) / _half_cell_size_after_padding), - (y + 0.5, y + 0.5 + np.sin(phase) / _half_cell_size_after_padding), + (x + 0.5, x + 0.5 + np.cos(phase) * _half_cell_size_after_padding), + (y + 0.5, y + 0.5 + np.sin(phase) * _half_cell_size_after_padding), color='#333333', alpha=_image_opacity, ) @@ -128,7 +128,7 @@ def plot_density_matrix( f"{'0'*(num_qubits - len(f'{i:b}'))}{i:b}" for i in range(matrix.shape[0]) ] ax.set_xticks(ticks) - ax.set_xticklabels(labels) + ax.set_xticklabels(labels, rotation=90) ax.set_yticks(ticks) ax.set_yticklabels(reversed(labels)) ax.set_facecolor('#eeeeee') diff --git a/cirq-core/cirq/vis/density_matrix_test.py b/cirq-core/cirq/vis/density_matrix_test.py index 0f078db26ed..511392258c4 100644 --- a/cirq-core/cirq/vis/density_matrix_test.py +++ b/cirq-core/cirq/vis/density_matrix_test.py @@ -49,9 +49,8 @@ def test_density_matrix_plotter(size, show_text): @pytest.mark.parametrize('show_text', [True, False]) @pytest.mark.parametrize('size', [2, 4, 8, 16]) -def test_density_matrix_circle_sizes(size, show_text): +def test_density_matrix_circle_rectangle_sizes(size, show_text): matrix = cirq.testing.random_density_matrix(size) - # Check that the correct title is being shown ax = plot_density_matrix(matrix, show_text=show_text, title='Test Density Matrix Plot') # Check that the radius of all the circles in the matrix is correct circles = list(filter(lambda x: isinstance(x, patches.Circle), ax.get_children())) @@ -90,6 +89,29 @@ def test_density_matrix_circle_sizes(size, show_text): ) +@pytest.mark.parametrize('show_text', [True, False]) +@pytest.mark.parametrize('size', [2, 4, 8, 16]) +def test_density_matrix_sizes_upper_bounds(size, show_text): + matrix = cirq.testing.random_density_matrix(size) + ax = plot_density_matrix(matrix, show_text=show_text, title='Test Density Matrix Plot') + + circles = list(filter(lambda x: isinstance(x, patches.Circle), ax.get_children())) + max_radius = np.max([c.radius for c in circles if c.fill]) + + rects = list( + filter( + lambda x: isinstance(x, patches.Rectangle) and x.get_alpha() is not None, + ax.get_children(), + ) + ) + max_height = np.max([r.get_height() for r in rects]) + max_width = np.max([r.get_width() for r in rects]) + + assert max_height <= 1.0, "Some rectangle is exceeding out of it's cell size" + assert max_width <= 1.0, "Some rectangle is exceeding out of it's cell size" + assert max_radius * 2 <= 1.0, "Some circle is exceeding out of it's cell size" + + @pytest.mark.parametrize('show_rect', [True, False]) @pytest.mark.parametrize('value', [0.0, 1.0, 0.5 + 0.3j, 0.2 + 0.1j, 0.5 + 0.5j]) def test_density_element_plot(value, show_rect):