Skip to content

Commit

Permalink
Deferred measurements transformer (quantumlib#4849)
Browse files Browse the repository at this point in the history
Closes quantumlib#4818, Also reimplements `mux` simulation based on this, in preparation to deprecate `ignore_measurement_results`.

Needs a follow-up after quantumlib#4512 to support classical controls on multi-qubit measurements, as we need some way of defining the condition "at least one qubit is not zero" to match the classical interpretation of a multi-qubit measurement.
  • Loading branch information
daxfohl authored and 95-martin-orion committed Mar 2, 2022
1 parent 7e0d25f commit 488ec7b
Show file tree
Hide file tree
Showing 8 changed files with 569 additions and 4 deletions.
2 changes: 2 additions & 0 deletions cirq-core/cirq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,8 @@
decompose_multi_controlled_x,
decompose_multi_controlled_rotation,
decompose_two_qubit_interaction_into_four_fsim_gates,
defer_measurements,
dephase_measurements,
drop_empty_moments,
drop_negligible_operations,
eject_phased_paulis,
Expand Down
2 changes: 1 addition & 1 deletion cirq-core/cirq/ops/kraus_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(
self._key = key

@staticmethod
def from_channel(channel: 'KrausChannel', key: Union[str, 'cirq.MeasurementKey', None] = None):
def from_channel(channel: 'cirq.Gate', key: Union[str, 'cirq.MeasurementKey', None] = None):
"""Creates a copy of a channel with the given measurement key."""
return KrausChannel(kraus_ops=list(protocols.kraus(channel)), key=key)

Expand Down
3 changes: 3 additions & 0 deletions cirq-core/cirq/ops/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,6 @@ def __repr__(self) -> str:

def _json_dict_(self) -> Dict[str, str]:
return {}

def __hash__(self):
return hash(VirtualTag)
6 changes: 4 additions & 2 deletions cirq-core/cirq/sim/mux.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from cirq._doc import document
from cirq.sim import sparse_simulator, density_matrix_simulator
from cirq.sim.clifford import clifford_simulator
from cirq.transformers import measurement_transformers

if TYPE_CHECKING:
import cirq
Expand Down Expand Up @@ -281,9 +282,10 @@ def final_density_matrix(
dtype=dtype,
noise=noise,
seed=seed,
ignore_measurement_results=(ignore_measurement_results),
).simulate(
program=circuit_like,
program=measurement_transformers.dephase_measurements(circuit_like)
if ignore_measurement_results
else circuit_like,
initial_state=initial_state,
qubit_order=qubit_order,
param_resolver=param_resolver,
Expand Down
5 changes: 5 additions & 0 deletions cirq-core/cirq/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@

from cirq.transformers.eject_z import eject_z

from cirq.transformers.measurement_transformers import (
defer_measurements,
dephase_measurements,
)

from cirq.transformers.synchronize_terminal_measurements import synchronize_terminal_measurements

from cirq.transformers.transformer_api import (
Expand Down
177 changes: 177 additions & 0 deletions cirq-core/cirq/transformers/measurement_transformers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# Copyright 2022 The Cirq Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union

from cirq import ops, protocols, value
from cirq.transformers import (
transformer_api,
transformer_primitives,
)
from cirq.transformers.synchronize_terminal_measurements import find_terminal_measurements

if TYPE_CHECKING:
import cirq


class _MeasurementQid(ops.Qid):
"""A qubit that substitutes in for a deferred measurement.
Exactly one qubit will be created per qubit in the measurement gate.
"""

def __init__(self, key: Union[str, 'cirq.MeasurementKey'], qid: 'cirq.Qid'):
"""Initializes the qubit.
Args:
key: The key of the measurement gate being deferred.
qid: One qubit that is being measured. Each deferred measurement
should create one new _MeasurementQid per qubit being measured
by that gate.
"""
self._key = value.MeasurementKey.parse_serialized(key) if isinstance(key, str) else key
self._qid = qid

@property
def dimension(self) -> int:
return self._qid.dimension

def _comparison_key(self) -> Any:
return (str(self._key), self._qid._comparison_key())

def __str__(self) -> str:
return f"M('{self._key}', q={self._qid})"

def __repr__(self) -> str:
return f'_MeasurementQid({self._key!r}, {self._qid!r})'


@transformer_api.transformer
def defer_measurements(
circuit: 'cirq.AbstractCircuit', *, context: Optional['cirq.TransformerContext'] = None
) -> 'cirq.Circuit':
"""Implements the Deferred Measurement Principle.
Uses the Deferred Measurement Principle to move all measurements to the
end of the circuit. All non-terminal measurements are changed to
conditional quantum gates onto ancilla qubits, and classically controlled
operations are transformed to quantum controls from those ancilla qubits.
Finally, measurements of all ancilla qubits are appended to the end of the
circuit.
Optimizing deferred measurements is an area of active research, and future
iterations may contain optimizations that reduce the number of ancilla
qubits, so one should not depend on the exact shape of the output from this
function. Only the logical equivalence is guaranteed to remain unchanged.
Moment and subcircuit structure is not preserved.
Args:
circuit: The circuit to transform. It will not be modified.
context: `cirq.TransformerContext` storing common configurable options
for transformers.
Returns:
A circuit with equivalent logic, but all measurements at the end of the
circuit.
Raises:
ValueError: If sympy-based classical conditions are used, or if
conditions based on multi-qubit measurements exist. (The latter of
these is planned to be implemented soon).
"""

circuit = transformer_primitives.unroll_circuit_op(circuit, deep=True, tags_to_check=None)
terminal_measurements = {op for _, op in find_terminal_measurements(circuit)}
measurement_qubits: Dict['cirq.MeasurementKey', List['_MeasurementQid']] = {}

def defer(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
if op in terminal_measurements:
return op
gate = op.gate
if isinstance(gate, ops.MeasurementGate):
key = value.MeasurementKey.parse_serialized(gate.key)
targets = [_MeasurementQid(key, q) for q in op.qubits]
measurement_qubits[key] = targets
cxs = [ops.CX(q, target) for q, target in zip(op.qubits, targets)]
xs = [ops.X(targets[i]) for i, b in enumerate(gate.full_invert_mask()) if b]
return cxs + xs
elif protocols.is_measurement(op):
return [defer(op, None) for op in protocols.decompose_once(op)]
elif op.classical_controls:
controls = []
for c in op.classical_controls:
if isinstance(c, value.KeyCondition):
if c.key not in measurement_qubits:
raise ValueError(f'Deferred measurement for key={c.key} not found.')
qubits = measurement_qubits[c.key]
if len(qubits) != 1:
# TODO: Multi-qubit conditions require
# https://github.com/quantumlib/Cirq/issues/4512
# Remember to update docstring above once this works.
raise ValueError('Only single qubit conditions are allowed.')
controls.extend(qubits)
else:
raise ValueError('Only KeyConditions are allowed.')
return op.without_classical_controls().controlled_by(
*controls, control_values=[tuple(range(1, q.dimension)) for q in controls]
)
return op

circuit = transformer_primitives.map_operations_and_unroll(
circuit=circuit,
map_func=defer,
tags_to_ignore=context.tags_to_ignore if context else (),
raise_if_add_qubits=False,
).unfreeze()
for k, qubits in measurement_qubits.items():
circuit.append(ops.measure(*qubits, key=k))
return circuit


@transformer_api.transformer
def dephase_measurements(
circuit: 'cirq.AbstractCircuit', *, context: Optional['cirq.TransformerContext'] = None
) -> 'cirq.Circuit':
"""Changes all measurements to a dephase operation.
This transformer is useful when using a density matrix simulator, when
wishing to calculate the final density matrix of a circuit and not simulate
the measurements themselves.
Args:
circuit: The circuit to transform. It will not be modified.
context: `cirq.TransformerContext` storing common configurable options
for transformers.
Returns:
A copy of the circuit, with dephase operations in place of all
measurements.
Raises:
ValueError: If the circuit contains classical controls. In this case,
it is required to change these to quantum controls via
`cirq.defer_measurements` first. Since deferral adds ancilla qubits
to the circuit, this is not done automatically, to prevent
surprises.
"""

def dephase(op: 'cirq.Operation', _) -> 'cirq.OP_TREE':
gate = op.gate
if isinstance(gate, ops.MeasurementGate):
key = value.MeasurementKey.parse_serialized(gate.key)
return ops.KrausChannel.from_channel(ops.phase_damp(1), key=key).on_each(op.qubits)
elif isinstance(op, ops.ClassicallyControlledOperation):
raise ValueError('Use cirq.defer_measurements first to remove classical controls.')
return op

ignored = () if context is None else context.tags_to_ignore
return transformer_primitives.map_operations(
circuit, dephase, deep=True, tags_to_ignore=ignored
).unfreeze()
Loading

0 comments on commit 488ec7b

Please sign in to comment.