diff --git a/cirq/study/visualize.py b/cirq/study/visualize.py index 65aaa0ab189..5cc94192988 100644 --- a/cirq/study/visualize.py +++ b/cirq/study/visualize.py @@ -24,8 +24,6 @@ def plot_state_histogram(result: trial_result.TrialResult) -> np.ndarray: States is a bitstring representation of all the qubit states in a single result. - Currently this function assumes each measurement gate applies to only - a single qubit. Args: result: The trial results to plot. @@ -39,17 +37,15 @@ def plot_state_histogram(result: trial_result.TrialResult) -> np.ndarray: # This allows cirq to be usable without python3-tk. import matplotlib.pyplot as plt - num_qubits = len(result.measurements.keys()) + num_qubits = sum([value.shape[1] for value in result.measurements.values()]) states = 2**num_qubits values = np.zeros(states) - # measurements is a dict of {measurement gate key: # array(repetitions, boolean result)} # Convert this to an array of repetitions, each with an array of booleans. # e.g. {q1: array([[True, True]]), q2: array([[False, False]])} # --> array([[True, False], [True, False]]) - measurement_by_result = np.array([ - v.transpose()[0] for k, v in result.measurements.items()]).transpose() + measurement_by_result = np.hstack(list(result.measurements.values())) for meas in measurement_by_result: # Convert each array of booleans to a string representation. diff --git a/cirq/study/visualize_test.py b/cirq/study/visualize_test.py index 8cd83e9caf9..88f7920d4b9 100644 --- a/cirq/study/visualize_test.py +++ b/cirq/study/visualize_test.py @@ -38,3 +38,30 @@ def test_plot_state_histogram(): expected_values = [0., 0., 0., 5.] np.testing.assert_equal(values_plotted, expected_values) + + +def test_plot_state_histogram_multi_1(): + pl.switch_backend('PDF') + qubits = cirq.LineQubit.range(4) + c = cirq.Circuit( + cirq.X.on_each(*qubits[1:]), + cirq.measure(*qubits), # One multi-qubit measurement + ) + r = cirq.sample(c, repetitions=5) + values_plotted = visualize.plot_state_histogram(r) + expected_values = [0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0] + np.testing.assert_equal(values_plotted, expected_values) + + +def test_plot_state_histogram_multi_2(): + pl.switch_backend('PDF') + qubits = cirq.LineQubit.range(4) + c = cirq.Circuit( + cirq.X.on_each(*qubits[1:]), + cirq.measure(*qubits[:2]), # One multi-qubit measurement + cirq.measure_each(*qubits[2:]), # Multiple single-qubit measurement + ) + r = cirq.sample(c, repetitions=5) + values_plotted = visualize.plot_state_histogram(r) + expected_values = [0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0] + np.testing.assert_equal(values_plotted, expected_values)