Skip to content

Commit

Permalink
Fix ResetGate to work with MPS, StabilizerSampler (quantumlib#4765)
Browse files Browse the repository at this point in the history
This abstracts the implementation of ResetGate to be a measurment and then a PlusGate with an offset that gets the qubit back to the zero state (for 2-dimensional qubits, this is equivalent to X iff measurement==1).

This allows Reset to work for all simulators, not just the specific ones whose cases were implemented in the existing code.

There is a special consideration for density-matrix-like simulators. For these, we do not want to actually perform the measurement, as a density matrix can represent the mixed state of all measurement results. Performing the measurement would lose that information. Therefore, here we add a `can_represent_mixed_states` property to the ActOnArgs, and if that is true, then we allow the simulator to fall back to its own apply_channel implementation. This new property allows other density-matrix-like state representations (say a superoperator simulator) to adopt the same behavior without having to update `ResetGate._act_on_`.
  • Loading branch information
daxfohl authored and MichaelBroughton committed Jan 22, 2022
1 parent c054e14 commit af86156
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 26 deletions.
11 changes: 11 additions & 0 deletions cirq-core/cirq/contrib/quimb/mps_simulator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,17 @@ def test_measurement_1qubit():
assert sum(result.measurements['1'])[0] > 20


def test_reset():
q = cirq.LineQubit(0)
simulator = ccq.mps_simulator.MPSSimulator()
c = cirq.Circuit(cirq.X(q), cirq.reset(q), cirq.measure(q))
assert simulator.sample(c)['0'][0] == 0
c = cirq.Circuit(cirq.H(q), cirq.reset(q), cirq.measure(q))
assert simulator.sample(c)['0'][0] == 0
c = cirq.Circuit(cirq.reset(q), cirq.measure(q))
assert simulator.sample(c)['0'][0] == 0


def test_measurement_2qubits():
q0, q1, q2 = cirq.LineQubit.range(3)
circuit = cirq.Circuit(cirq.H(q0), cirq.H(q1), cirq.H(q2), cirq.measure(q0, q2))
Expand Down
48 changes: 24 additions & 24 deletions cirq-core/cirq/ops/common_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,33 +717,33 @@ def _qasm_(self, args: 'cirq.QasmArgs', qubits: Tuple['cirq.Qid', ...]) -> Optio
def _qid_shape_(self):
return (self._dimension,)

def _act_on_(self, args: 'cirq.ActOnArgs', qubits: Sequence['cirq.Qid']):
from cirq import sim, ops
def _act_on_(self, args: 'cirq.OperationTarget', qubits: Sequence['cirq.Qid']):
if len(qubits) != 1:
return NotImplemented

if isinstance(args, sim.ActOnStabilizerCHFormArgs):
axe = args.qubit_map[qubits[0]]
if args.state._measure(axe, args.prng):
ops.X._act_on_(args, qubits)
return True
class PlusGate(raw_types.Gate):
"""A qudit gate that increments a qudit state mod its dimension."""

def __init__(self, dimension, increment=1):
self.dimension = dimension
self.increment = increment % dimension

def _qid_shape_(self):
return (self.dimension,)

def _unitary_(self):
inc = (self.increment - 1) % self.dimension + 1
u = np.empty((self.dimension, self.dimension))
u[inc:] = np.eye(self.dimension)[:-inc]
u[:inc] = np.eye(self.dimension)[-inc:]
return u

if isinstance(args, sim.ActOnStateVectorArgs):
# Do a silent measurement.
axes = args.get_axes(qubits)
measurements, _ = sim.measure_state_vector(
args.target_tensor,
axes,
out=args.target_tensor,
qid_shape=args.target_tensor.shape,
)
result = measurements[0]

# Use measurement result to zero the qid.
if result:
zero = args.subspace_index(axes, 0)
other = args.subspace_index(axes, result)
args.target_tensor[zero] = args.target_tensor[other]
args.target_tensor[other] = 0
from cirq.sim import act_on_args

if isinstance(args, act_on_args.ActOnArgs) and not args.can_represent_mixed_states:
result = args._perform_measurement(qubits)[0]
gate = PlusGate(self.dimension, self.dimension - result)
protocols.act_on(gate, args, qubits)
return True

return NotImplemented
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/common_channels_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def test_reset_act_on():
target_tensor=cirq.one_hot(
index=(1, 1, 1, 1, 1), shape=(2, 2, 2, 2, 2), dtype=np.complex64
),
available_buffer=np.empty(shape=(2, 2, 2, 2, 2)),
available_buffer=np.empty(shape=(2, 2, 2, 2, 2), dtype=np.complex64),
qubits=cirq.LineQubit.range(5),
prng=np.random.RandomState(),
log_of_measurement_results={},
Expand Down
4 changes: 4 additions & 0 deletions cirq-core/cirq/sim/act_on_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,10 @@ def __len__(self) -> int:
def __iter__(self) -> Iterator[Optional['cirq.Qid']]:
return iter(self.qubits)

@property
def can_represent_mixed_states(self) -> bool:
return False


def strat_act_on_from_apply_decompose(
val: Any,
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/sim/act_on_args_container_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def __init__(self, qubits, logs):
)

def _perform_measurement(self, qubits: Sequence[cirq.Qid]) -> List[int]:
return []
return [0] * len(qubits)

def copy(self) -> 'EmptyActOnArgs':
return EmptyActOnArgs(
Expand Down
4 changes: 4 additions & 0 deletions cirq-core/cirq/sim/act_on_density_matrix_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,10 @@ def sample(
seed=seed,
)

@property
def can_represent_mixed_states(self) -> bool:
return True

def __repr__(self) -> str:
return (
'cirq.ActOnDensityMatrixArgs('
Expand Down
11 changes: 11 additions & 0 deletions cirq-core/cirq/sim/clifford/stabilizer_sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,14 @@ def test_produces_samples():
result = cirq.StabilizerSampler().sample(c, repetitions=100)
assert 5 < sum(result['a']) < 95
assert np.all(result['a'] ^ result['b'] == 0)


def test_reset():
q = cirq.LineQubit(0)
sampler = cirq.StabilizerSampler()
c = cirq.Circuit(cirq.X(q), cirq.reset(q), cirq.measure(q))
assert sampler.sample(c)['0'][0] == 0
c = cirq.Circuit(cirq.H(q), cirq.reset(q), cirq.measure(q))
assert sampler.sample(c)['0'][0] == 0
c = cirq.Circuit(cirq.reset(q), cirq.measure(q))
assert sampler.sample(c)['0'][0] == 0

0 comments on commit af86156

Please sign in to comment.