Skip to content

Commit

Permalink
Adds default decompositions for cirq.MatrixGate into X/Y/Z/CZ targe…
Browse files Browse the repository at this point in the history
…t gateset. (quantumlib#5088)

* Add default decompositions for cirq.MatrixGate

* Add special case to handle MatrixGate as a sub gate in ControlledGate
  • Loading branch information
tanujkhattar authored Mar 17, 2022
1 parent f02cf18 commit 78c2b79
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 21 deletions.
40 changes: 30 additions & 10 deletions cirq/ops/controlled_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import AbstractSet, Any, cast, Collection, Dict, Optional, Sequence, Tuple, Union
from typing import (
AbstractSet,
Any,
cast,
Collection,
Dict,
List,
Optional,
Sequence,
Tuple,
Union,
TYPE_CHECKING,
)

import numpy as np

import cirq
from cirq import protocols, value
from cirq import protocols, value, _import
from cirq._compat import deprecated
from cirq.ops import raw_types, controlled_operation as cop
from cirq.ops import raw_types, controlled_operation as cop, matrix_gates
from cirq.type_workarounds import NotImplementedType

if TYPE_CHECKING:
import cirq

line_qubit = _import.LazyLoader('line_qubit', globals(), 'cirq.devices')


@value.value_equality
class ControlledGate(raw_types.Gate):
Expand Down Expand Up @@ -137,17 +153,21 @@ def num_controls(self) -> int:
return len(self.control_qid_shape)

def _qid_shape_(self) -> Tuple[int, ...]:
return self.control_qid_shape + cirq.qid_shape(self.sub_gate)
return self.control_qid_shape + protocols.qid_shape(self.sub_gate)

def _decompose_(self, qubits):
if isinstance(self.sub_gate, matrix_gates.MatrixGate):
# Default decompositions of 2/3 qubit `cirq.MatrixGate` ignores global phase, which is
# local phase in the controlled variant and hence cannot be ignored.
return NotImplemented

result = protocols.decompose_once_with_qubits(
self.sub_gate, qubits[self.num_controls() :], NotImplemented
)

if result is NotImplemented:
return NotImplemented

decomposed = []
decomposed: List['cirq.Operation'] = []
for op in result:
decomposed.append(
cop.ControlledOperation(qubits[: self.num_controls()], op, self.control_values)
Expand All @@ -172,7 +192,7 @@ def _value_equality_values_(self):
)

def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs') -> np.ndarray:
qubits = cirq.LineQid.for_gate(self)
qubits = line_qubit.LineQid.for_gate(self)
op = self.sub_gate.on(*qubits[self.num_controls() :])
c_op = cop.ControlledOperation(qubits[: self.num_controls()], op, self.control_values)
return protocols.apply_unitary(c_op, args, default=NotImplemented)
Expand All @@ -181,7 +201,7 @@ def _has_unitary_(self) -> bool:
return protocols.has_unitary(self.sub_gate)

def _unitary_(self) -> Union[np.ndarray, NotImplementedType]:
qubits = cirq.LineQid.for_gate(self)
qubits = line_qubit.LineQid.for_gate(self)
op = self.sub_gate.on(*qubits[self.num_controls() :])
c_op = cop.ControlledOperation(qubits[: self.num_controls()], op, self.control_values)

Expand All @@ -191,7 +211,7 @@ def _has_mixture_(self) -> bool:
return protocols.has_mixture(self.sub_gate)

def _mixture_(self) -> Union[np.ndarray, NotImplementedType]:
qubits = cirq.LineQid.for_gate(self)
qubits = line_qubit.LineQid.for_gate(self)
op = self.sub_gate.on(*qubits[self.num_controls() :])
c_op = cop.ControlledOperation(qubits[: self.num_controls()], op, self.control_values)
return protocols.mixture(c_op, default=NotImplemented)
Expand Down
26 changes: 25 additions & 1 deletion cirq/ops/matrix_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,23 @@

import numpy as np

from cirq import linalg, protocols
from cirq import linalg, protocols, _import
from cirq._compat import proper_repr
from cirq.ops import raw_types

if TYPE_CHECKING:
import cirq

single_qubit_decompositions = _import.LazyLoader(
'single_qubit_decompositions', globals(), 'cirq.transformers.analytical_decompositions'
)
two_qubit_to_cz = _import.LazyLoader(
'two_qubit_to_cz', globals(), 'cirq.transformers.analytical_decompositions'
)
three_qubit_decomposition = _import.LazyLoader(
'three_qubit_decomposition', globals(), 'cirq.transformers.analytical_decompositions'
)


class MatrixGate(raw_types.Gate):
"""A unitary qubit or qudit gate defined entirely by its matrix."""
Expand Down Expand Up @@ -116,6 +126,20 @@ def _phase_by_(self, phase_turns: float, qubit_index: int) -> 'MatrixGate':
result[linalg.slice_for_qubits_equal_to([j], 1)] *= np.conj(p)
return MatrixGate(matrix=result.reshape(self._matrix.shape), qid_shape=self._qid_shape)

def _decompose_(self, qubits: Tuple['cirq.Qid', ...]) -> 'cirq.OP_TREE':
if self._qid_shape == (2,):
return [
g.on(qubits[0])
for g in single_qubit_decompositions.single_qubit_matrix_to_gates(self._matrix)
]
if self._qid_shape == (2,) * 2:
return two_qubit_to_cz.two_qubit_matrix_to_cz_operations(
*qubits, self._matrix, allow_partial_czs=True
)
if self._qid_shape == (2,) * 3:
return three_qubit_decomposition.three_qubit_matrix_to_operations(*qubits, self._matrix)
return NotImplemented

def _has_unitary_(self) -> bool:
return True

Expand Down
23 changes: 13 additions & 10 deletions cirq/ops/matrix_gates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,16 +276,19 @@ def test_str_executes():
assert '0' in str(cirq.MatrixGate(np.eye(4)))


def test_one_qubit_consistent():
u = cirq.testing.random_unitary(2)
g = cirq.MatrixGate(u)
cirq.testing.assert_implements_consistent_protocols(g)


def test_two_qubit_consistent():
u = cirq.testing.random_unitary(4)
g = cirq.MatrixGate(u)
cirq.testing.assert_implements_consistent_protocols(g)
@pytest.mark.parametrize('n', [1, 2, 3, 4, 5])
def test_implements_consistent_protocols(n):
u = cirq.testing.random_unitary(2 ** n)
g1 = cirq.MatrixGate(u)
cirq.testing.assert_implements_consistent_protocols(g1, ignoring_global_phase=True)
cirq.testing.assert_decompose_ends_at_default_gateset(g1)

if n == 1:
return

g2 = cirq.MatrixGate(u, qid_shape=(4,) + (2,) * (n - 2))
cirq.testing.assert_implements_consistent_protocols(g2, ignoring_global_phase=True)
cirq.testing.assert_decompose_ends_at_default_gateset(g2)


def test_repr():
Expand Down
2 changes: 2 additions & 0 deletions cirq/testing/consistent_decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ def assert_decompose_is_consistent_with_unitary(val: Any, ignoring_global_phase:

def _known_gate_with_no_decomposition(val: Any):
"""Checks whether `val` is a known gate with no default decomposition to default gateset."""
if isinstance(val, ops.MatrixGate):
return protocols.qid_shape(val) not in [(2,), (2,) * 2, (2,) * 3]
return False


Expand Down

0 comments on commit 78c2b79

Please sign in to comment.