From af8615655d51693294f4c36206bfb6d827c1474c Mon Sep 17 00:00:00 2001 From: Dax Fohl Date: Tue, 28 Dec 2021 10:11:19 -0800 Subject: [PATCH] Fix ResetGate to work with MPS, StabilizerSampler (#4765) This abstracts the implementation of ResetGate to be a measurment and then a PlusGate with an offset that gets the qubit back to the zero state (for 2-dimensional qubits, this is equivalent to X iff measurement==1). This allows Reset to work for all simulators, not just the specific ones whose cases were implemented in the existing code. There is a special consideration for density-matrix-like simulators. For these, we do not want to actually perform the measurement, as a density matrix can represent the mixed state of all measurement results. Performing the measurement would lose that information. Therefore, here we add a `can_represent_mixed_states` property to the ActOnArgs, and if that is true, then we allow the simulator to fall back to its own apply_channel implementation. This new property allows other density-matrix-like state representations (say a superoperator simulator) to adopt the same behavior without having to update `ResetGate._act_on_`. --- .../cirq/contrib/quimb/mps_simulator_test.py | 11 +++++ cirq-core/cirq/ops/common_channels.py | 48 +++++++++---------- cirq-core/cirq/ops/common_channels_test.py | 2 +- cirq-core/cirq/sim/act_on_args.py | 4 ++ .../cirq/sim/act_on_args_container_test.py | 2 +- .../cirq/sim/act_on_density_matrix_args.py | 4 ++ .../sim/clifford/stabilizer_sampler_test.py | 11 +++++ 7 files changed, 56 insertions(+), 26 deletions(-) diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator_test.py b/cirq-core/cirq/contrib/quimb/mps_simulator_test.py index 5d85735eac0..1025f05e944 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator_test.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator_test.py @@ -223,6 +223,17 @@ def test_measurement_1qubit(): assert sum(result.measurements['1'])[0] > 20 +def test_reset(): + q = cirq.LineQubit(0) + simulator = ccq.mps_simulator.MPSSimulator() + c = cirq.Circuit(cirq.X(q), cirq.reset(q), cirq.measure(q)) + assert simulator.sample(c)['0'][0] == 0 + c = cirq.Circuit(cirq.H(q), cirq.reset(q), cirq.measure(q)) + assert simulator.sample(c)['0'][0] == 0 + c = cirq.Circuit(cirq.reset(q), cirq.measure(q)) + assert simulator.sample(c)['0'][0] == 0 + + def test_measurement_2qubits(): q0, q1, q2 = cirq.LineQubit.range(3) circuit = cirq.Circuit(cirq.H(q0), cirq.H(q1), cirq.H(q2), cirq.measure(q0, q2)) diff --git a/cirq-core/cirq/ops/common_channels.py b/cirq-core/cirq/ops/common_channels.py index c282d362d9e..936e19f4faf 100644 --- a/cirq-core/cirq/ops/common_channels.py +++ b/cirq-core/cirq/ops/common_channels.py @@ -717,33 +717,33 @@ def _qasm_(self, args: 'cirq.QasmArgs', qubits: Tuple['cirq.Qid', ...]) -> Optio def _qid_shape_(self): return (self._dimension,) - def _act_on_(self, args: 'cirq.ActOnArgs', qubits: Sequence['cirq.Qid']): - from cirq import sim, ops + def _act_on_(self, args: 'cirq.OperationTarget', qubits: Sequence['cirq.Qid']): + if len(qubits) != 1: + return NotImplemented - if isinstance(args, sim.ActOnStabilizerCHFormArgs): - axe = args.qubit_map[qubits[0]] - if args.state._measure(axe, args.prng): - ops.X._act_on_(args, qubits) - return True + class PlusGate(raw_types.Gate): + """A qudit gate that increments a qudit state mod its dimension.""" + + def __init__(self, dimension, increment=1): + self.dimension = dimension + self.increment = increment % dimension + + def _qid_shape_(self): + return (self.dimension,) + + def _unitary_(self): + inc = (self.increment - 1) % self.dimension + 1 + u = np.empty((self.dimension, self.dimension)) + u[inc:] = np.eye(self.dimension)[:-inc] + u[:inc] = np.eye(self.dimension)[-inc:] + return u - if isinstance(args, sim.ActOnStateVectorArgs): - # Do a silent measurement. - axes = args.get_axes(qubits) - measurements, _ = sim.measure_state_vector( - args.target_tensor, - axes, - out=args.target_tensor, - qid_shape=args.target_tensor.shape, - ) - result = measurements[0] - - # Use measurement result to zero the qid. - if result: - zero = args.subspace_index(axes, 0) - other = args.subspace_index(axes, result) - args.target_tensor[zero] = args.target_tensor[other] - args.target_tensor[other] = 0 + from cirq.sim import act_on_args + if isinstance(args, act_on_args.ActOnArgs) and not args.can_represent_mixed_states: + result = args._perform_measurement(qubits)[0] + gate = PlusGate(self.dimension, self.dimension - result) + protocols.act_on(gate, args, qubits) return True return NotImplemented diff --git a/cirq-core/cirq/ops/common_channels_test.py b/cirq-core/cirq/ops/common_channels_test.py index ee89b9b82fb..ffc0084d8a5 100644 --- a/cirq-core/cirq/ops/common_channels_test.py +++ b/cirq-core/cirq/ops/common_channels_test.py @@ -498,7 +498,7 @@ def test_reset_act_on(): target_tensor=cirq.one_hot( index=(1, 1, 1, 1, 1), shape=(2, 2, 2, 2, 2), dtype=np.complex64 ), - available_buffer=np.empty(shape=(2, 2, 2, 2, 2)), + available_buffer=np.empty(shape=(2, 2, 2, 2, 2), dtype=np.complex64), qubits=cirq.LineQubit.range(5), prng=np.random.RandomState(), log_of_measurement_results={}, diff --git a/cirq-core/cirq/sim/act_on_args.py b/cirq-core/cirq/sim/act_on_args.py index 07c5e7666eb..dd670dbdd1d 100644 --- a/cirq-core/cirq/sim/act_on_args.py +++ b/cirq-core/cirq/sim/act_on_args.py @@ -276,6 +276,10 @@ def __len__(self) -> int: def __iter__(self) -> Iterator[Optional['cirq.Qid']]: return iter(self.qubits) + @property + def can_represent_mixed_states(self) -> bool: + return False + def strat_act_on_from_apply_decompose( val: Any, diff --git a/cirq-core/cirq/sim/act_on_args_container_test.py b/cirq-core/cirq/sim/act_on_args_container_test.py index 16380bc592c..e72887aea60 100644 --- a/cirq-core/cirq/sim/act_on_args_container_test.py +++ b/cirq-core/cirq/sim/act_on_args_container_test.py @@ -24,7 +24,7 @@ def __init__(self, qubits, logs): ) def _perform_measurement(self, qubits: Sequence[cirq.Qid]) -> List[int]: - return [] + return [0] * len(qubits) def copy(self) -> 'EmptyActOnArgs': return EmptyActOnArgs( diff --git a/cirq-core/cirq/sim/act_on_density_matrix_args.py b/cirq-core/cirq/sim/act_on_density_matrix_args.py index fe4792553a7..a1c5ff1df25 100644 --- a/cirq-core/cirq/sim/act_on_density_matrix_args.py +++ b/cirq-core/cirq/sim/act_on_density_matrix_args.py @@ -174,6 +174,10 @@ def sample( seed=seed, ) + @property + def can_represent_mixed_states(self) -> bool: + return True + def __repr__(self) -> str: return ( 'cirq.ActOnDensityMatrixArgs(' diff --git a/cirq-core/cirq/sim/clifford/stabilizer_sampler_test.py b/cirq-core/cirq/sim/clifford/stabilizer_sampler_test.py index b85a0c08fc1..08bedfc2f15 100644 --- a/cirq-core/cirq/sim/clifford/stabilizer_sampler_test.py +++ b/cirq-core/cirq/sim/clifford/stabilizer_sampler_test.py @@ -29,3 +29,14 @@ def test_produces_samples(): result = cirq.StabilizerSampler().sample(c, repetitions=100) assert 5 < sum(result['a']) < 95 assert np.all(result['a'] ^ result['b'] == 0) + + +def test_reset(): + q = cirq.LineQubit(0) + sampler = cirq.StabilizerSampler() + c = cirq.Circuit(cirq.X(q), cirq.reset(q), cirq.measure(q)) + assert sampler.sample(c)['0'][0] == 0 + c = cirq.Circuit(cirq.H(q), cirq.reset(q), cirq.measure(q)) + assert sampler.sample(c)['0'][0] == 0 + c = cirq.Circuit(cirq.reset(q), cirq.measure(q)) + assert sampler.sample(c)['0'][0] == 0