Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support multi-qubit measurement in plot_state_histogram #3054

Merged
merged 12 commits into from
Jun 9, 2020
8 changes: 2 additions & 6 deletions cirq/study/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ def plot_state_histogram(result: trial_result.TrialResult) -> np.ndarray:

States is a bitstring representation of all the qubit states in a single
result.
Currently this function assumes each measurement gate applies to only
a single qubit.

Args:
result: The trial results to plot.
Expand All @@ -39,17 +37,15 @@ def plot_state_histogram(result: trial_result.TrialResult) -> np.ndarray:
# This allows cirq to be usable without python3-tk.
import matplotlib.pyplot as plt

num_qubits = len(result.measurements.keys())
num_qubits = sum([value.shape[1] for value in result.measurements.values()])
states = 2**num_qubits
values = np.zeros(states)

# measurements is a dict of {measurement gate key:
# array(repetitions, boolean result)}
# Convert this to an array of repetitions, each with an array of booleans.
# e.g. {q1: array([[True, True]]), q2: array([[False, False]])}
# --> array([[True, False], [True, False]])
measurement_by_result = np.array([
v.transpose()[0] for k, v in result.measurements.items()]).transpose()
measurement_by_result = np.hstack(list(result.measurements.values()))
cduck marked this conversation as resolved.
Show resolved Hide resolved

for meas in measurement_by_result:
# Convert each array of booleans to a string representation.
Expand Down
37 changes: 37 additions & 0 deletions cirq/study/visualize_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,40 @@ def test_plot_state_histogram():
expected_values = [0., 0., 0., 5.]

np.testing.assert_equal(values_plotted, expected_values)


def test_plot_state_histogram_multi_1():
cduck marked this conversation as resolved.
Show resolved Hide resolved
qubits = cirq.LineQubit.range(3)
c1 = cirq.Circuit(
cirq.X.on_each(*qubits),
cirq.measure(*qubits), # One multi-qubit measurement
)
r1 = cirq.sample(c1, repetitions=5)
values_plotted = visualize.plot_state_histogram(r1)
expected_values = [0, 0, 0, 0, 0, 0, 0, 5]
np.testing.assert_equal(values_plotted, expected_values)


def test_plot_state_histogram_multi_2():
cduck marked this conversation as resolved.
Show resolved Hide resolved
qubits = cirq.LineQubit.range(3)
c2 = cirq.Circuit(
cirq.X.on_each(*qubits),
cirq.measure_each(*qubits), # One multi-qubit measurement
cduck marked this conversation as resolved.
Show resolved Hide resolved
)
r2 = cirq.sample(c2, repetitions=5)
values_plotted = visualize.plot_state_histogram(r2)
expected_values = [0, 0, 0, 0, 0, 0, 0, 5]
np.testing.assert_equal(values_plotted, expected_values)


def test_plot_state_histogram_multi_3():
cduck marked this conversation as resolved.
Show resolved Hide resolved
qubits = cirq.LineQubit.range(3)
c3 = cirq.Circuit(
cirq.X.on_each(*qubits),
cirq.measure(*qubits[:2]), # One multi-qubit measurement
cirq.measure_each(*qubits[2:]), # One multi-qubit measurement
cduck marked this conversation as resolved.
Show resolved Hide resolved
)
r3 = cirq.sample(c3, repetitions=5)
values_plotted = visualize.plot_state_histogram(r3)
expected_values = [0, 0, 0, 0, 0, 0, 0, 5]
np.testing.assert_equal(values_plotted, expected_values)