diff --git a/cirq-core/cirq/contrib/quimb/mps_simulator_test.py b/cirq-core/cirq/contrib/quimb/mps_simulator_test.py index 5d85735eac0..1025f05e944 100644 --- a/cirq-core/cirq/contrib/quimb/mps_simulator_test.py +++ b/cirq-core/cirq/contrib/quimb/mps_simulator_test.py @@ -223,6 +223,17 @@ def test_measurement_1qubit(): assert sum(result.measurements['1'])[0] > 20 +def test_reset(): + q = cirq.LineQubit(0) + simulator = ccq.mps_simulator.MPSSimulator() + c = cirq.Circuit(cirq.X(q), cirq.reset(q), cirq.measure(q)) + assert simulator.sample(c)['0'][0] == 0 + c = cirq.Circuit(cirq.H(q), cirq.reset(q), cirq.measure(q)) + assert simulator.sample(c)['0'][0] == 0 + c = cirq.Circuit(cirq.reset(q), cirq.measure(q)) + assert simulator.sample(c)['0'][0] == 0 + + def test_measurement_2qubits(): q0, q1, q2 = cirq.LineQubit.range(3) circuit = cirq.Circuit(cirq.H(q0), cirq.H(q1), cirq.H(q2), cirq.measure(q0, q2)) diff --git a/cirq-core/cirq/ops/common_channels.py b/cirq-core/cirq/ops/common_channels.py index c282d362d9e..936e19f4faf 100644 --- a/cirq-core/cirq/ops/common_channels.py +++ b/cirq-core/cirq/ops/common_channels.py @@ -717,33 +717,33 @@ def _qasm_(self, args: 'cirq.QasmArgs', qubits: Tuple['cirq.Qid', ...]) -> Optio def _qid_shape_(self): return (self._dimension,) - def _act_on_(self, args: 'cirq.ActOnArgs', qubits: Sequence['cirq.Qid']): - from cirq import sim, ops + def _act_on_(self, args: 'cirq.OperationTarget', qubits: Sequence['cirq.Qid']): + if len(qubits) != 1: + return NotImplemented - if isinstance(args, sim.ActOnStabilizerCHFormArgs): - axe = args.qubit_map[qubits[0]] - if args.state._measure(axe, args.prng): - ops.X._act_on_(args, qubits) - return True + class PlusGate(raw_types.Gate): + """A qudit gate that increments a qudit state mod its dimension.""" + + def __init__(self, dimension, increment=1): + self.dimension = dimension + self.increment = increment % dimension + + def _qid_shape_(self): + return (self.dimension,) + + def _unitary_(self): + inc = (self.increment - 1) % self.dimension + 1 + u = np.empty((self.dimension, self.dimension)) + u[inc:] = np.eye(self.dimension)[:-inc] + u[:inc] = np.eye(self.dimension)[-inc:] + return u - if isinstance(args, sim.ActOnStateVectorArgs): - # Do a silent measurement. - axes = args.get_axes(qubits) - measurements, _ = sim.measure_state_vector( - args.target_tensor, - axes, - out=args.target_tensor, - qid_shape=args.target_tensor.shape, - ) - result = measurements[0] - - # Use measurement result to zero the qid. - if result: - zero = args.subspace_index(axes, 0) - other = args.subspace_index(axes, result) - args.target_tensor[zero] = args.target_tensor[other] - args.target_tensor[other] = 0 + from cirq.sim import act_on_args + if isinstance(args, act_on_args.ActOnArgs) and not args.can_represent_mixed_states: + result = args._perform_measurement(qubits)[0] + gate = PlusGate(self.dimension, self.dimension - result) + protocols.act_on(gate, args, qubits) return True return NotImplemented diff --git a/cirq-core/cirq/ops/common_channels_test.py b/cirq-core/cirq/ops/common_channels_test.py index ee89b9b82fb..ffc0084d8a5 100644 --- a/cirq-core/cirq/ops/common_channels_test.py +++ b/cirq-core/cirq/ops/common_channels_test.py @@ -498,7 +498,7 @@ def test_reset_act_on(): target_tensor=cirq.one_hot( index=(1, 1, 1, 1, 1), shape=(2, 2, 2, 2, 2), dtype=np.complex64 ), - available_buffer=np.empty(shape=(2, 2, 2, 2, 2)), + available_buffer=np.empty(shape=(2, 2, 2, 2, 2), dtype=np.complex64), qubits=cirq.LineQubit.range(5), prng=np.random.RandomState(), log_of_measurement_results={}, diff --git a/cirq-core/cirq/sim/act_on_args.py b/cirq-core/cirq/sim/act_on_args.py index 07c5e7666eb..dd670dbdd1d 100644 --- a/cirq-core/cirq/sim/act_on_args.py +++ b/cirq-core/cirq/sim/act_on_args.py @@ -276,6 +276,10 @@ def __len__(self) -> int: def __iter__(self) -> Iterator[Optional['cirq.Qid']]: return iter(self.qubits) + @property + def can_represent_mixed_states(self) -> bool: + return False + def strat_act_on_from_apply_decompose( val: Any, diff --git a/cirq-core/cirq/sim/act_on_args_container_test.py b/cirq-core/cirq/sim/act_on_args_container_test.py index 16380bc592c..e72887aea60 100644 --- a/cirq-core/cirq/sim/act_on_args_container_test.py +++ b/cirq-core/cirq/sim/act_on_args_container_test.py @@ -24,7 +24,7 @@ def __init__(self, qubits, logs): ) def _perform_measurement(self, qubits: Sequence[cirq.Qid]) -> List[int]: - return [] + return [0] * len(qubits) def copy(self) -> 'EmptyActOnArgs': return EmptyActOnArgs( diff --git a/cirq-core/cirq/sim/act_on_density_matrix_args.py b/cirq-core/cirq/sim/act_on_density_matrix_args.py index fe4792553a7..a1c5ff1df25 100644 --- a/cirq-core/cirq/sim/act_on_density_matrix_args.py +++ b/cirq-core/cirq/sim/act_on_density_matrix_args.py @@ -174,6 +174,10 @@ def sample( seed=seed, ) + @property + def can_represent_mixed_states(self) -> bool: + return True + def __repr__(self) -> str: return ( 'cirq.ActOnDensityMatrixArgs(' diff --git a/cirq-core/cirq/sim/clifford/stabilizer_sampler_test.py b/cirq-core/cirq/sim/clifford/stabilizer_sampler_test.py index b85a0c08fc1..08bedfc2f15 100644 --- a/cirq-core/cirq/sim/clifford/stabilizer_sampler_test.py +++ b/cirq-core/cirq/sim/clifford/stabilizer_sampler_test.py @@ -29,3 +29,14 @@ def test_produces_samples(): result = cirq.StabilizerSampler().sample(c, repetitions=100) assert 5 < sum(result['a']) < 95 assert np.all(result['a'] ^ result['b'] == 0) + + +def test_reset(): + q = cirq.LineQubit(0) + sampler = cirq.StabilizerSampler() + c = cirq.Circuit(cirq.X(q), cirq.reset(q), cirq.measure(q)) + assert sampler.sample(c)['0'][0] == 0 + c = cirq.Circuit(cirq.H(q), cirq.reset(q), cirq.measure(q)) + assert sampler.sample(c)['0'][0] == 0 + c = cirq.Circuit(cirq.reset(q), cirq.measure(q)) + assert sampler.sample(c)['0'][0] == 0