diff --git a/cirq/ops/controlled_gate.py b/cirq/ops/controlled_gate.py index 615cf94abae..75e2ae03ccb 100644 --- a/cirq/ops/controlled_gate.py +++ b/cirq/ops/controlled_gate.py @@ -12,16 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import AbstractSet, Any, cast, Collection, Dict, Optional, Sequence, Tuple, Union +from typing import ( + AbstractSet, + Any, + cast, + Collection, + Dict, + List, + Optional, + Sequence, + Tuple, + Union, + TYPE_CHECKING, +) import numpy as np -import cirq -from cirq import protocols, value +from cirq import protocols, value, _import from cirq._compat import deprecated -from cirq.ops import raw_types, controlled_operation as cop +from cirq.ops import raw_types, controlled_operation as cop, matrix_gates from cirq.type_workarounds import NotImplementedType +if TYPE_CHECKING: + import cirq + +line_qubit = _import.LazyLoader('line_qubit', globals(), 'cirq.devices') + @value.value_equality class ControlledGate(raw_types.Gate): @@ -137,17 +153,21 @@ def num_controls(self) -> int: return len(self.control_qid_shape) def _qid_shape_(self) -> Tuple[int, ...]: - return self.control_qid_shape + cirq.qid_shape(self.sub_gate) + return self.control_qid_shape + protocols.qid_shape(self.sub_gate) def _decompose_(self, qubits): + if isinstance(self.sub_gate, matrix_gates.MatrixGate): + # Default decompositions of 2/3 qubit `cirq.MatrixGate` ignores global phase, which is + # local phase in the controlled variant and hence cannot be ignored. + return NotImplemented + result = protocols.decompose_once_with_qubits( self.sub_gate, qubits[self.num_controls() :], NotImplemented ) - if result is NotImplemented: return NotImplemented - decomposed = [] + decomposed: List['cirq.Operation'] = [] for op in result: decomposed.append( cop.ControlledOperation(qubits[: self.num_controls()], op, self.control_values) @@ -172,7 +192,7 @@ def _value_equality_values_(self): ) def _apply_unitary_(self, args: 'protocols.ApplyUnitaryArgs') -> np.ndarray: - qubits = cirq.LineQid.for_gate(self) + qubits = line_qubit.LineQid.for_gate(self) op = self.sub_gate.on(*qubits[self.num_controls() :]) c_op = cop.ControlledOperation(qubits[: self.num_controls()], op, self.control_values) return protocols.apply_unitary(c_op, args, default=NotImplemented) @@ -181,7 +201,7 @@ def _has_unitary_(self) -> bool: return protocols.has_unitary(self.sub_gate) def _unitary_(self) -> Union[np.ndarray, NotImplementedType]: - qubits = cirq.LineQid.for_gate(self) + qubits = line_qubit.LineQid.for_gate(self) op = self.sub_gate.on(*qubits[self.num_controls() :]) c_op = cop.ControlledOperation(qubits[: self.num_controls()], op, self.control_values) @@ -191,7 +211,7 @@ def _has_mixture_(self) -> bool: return protocols.has_mixture(self.sub_gate) def _mixture_(self) -> Union[np.ndarray, NotImplementedType]: - qubits = cirq.LineQid.for_gate(self) + qubits = line_qubit.LineQid.for_gate(self) op = self.sub_gate.on(*qubits[self.num_controls() :]) c_op = cop.ControlledOperation(qubits[: self.num_controls()], op, self.control_values) return protocols.mixture(c_op, default=NotImplemented) diff --git a/cirq/ops/matrix_gates.py b/cirq/ops/matrix_gates.py index 4513d644566..998df1948ba 100644 --- a/cirq/ops/matrix_gates.py +++ b/cirq/ops/matrix_gates.py @@ -18,13 +18,23 @@ import numpy as np -from cirq import linalg, protocols +from cirq import linalg, protocols, _import from cirq._compat import proper_repr from cirq.ops import raw_types if TYPE_CHECKING: import cirq +single_qubit_decompositions = _import.LazyLoader( + 'single_qubit_decompositions', globals(), 'cirq.transformers.analytical_decompositions' +) +two_qubit_to_cz = _import.LazyLoader( + 'two_qubit_to_cz', globals(), 'cirq.transformers.analytical_decompositions' +) +three_qubit_decomposition = _import.LazyLoader( + 'three_qubit_decomposition', globals(), 'cirq.transformers.analytical_decompositions' +) + class MatrixGate(raw_types.Gate): """A unitary qubit or qudit gate defined entirely by its matrix.""" @@ -116,6 +126,20 @@ def _phase_by_(self, phase_turns: float, qubit_index: int) -> 'MatrixGate': result[linalg.slice_for_qubits_equal_to([j], 1)] *= np.conj(p) return MatrixGate(matrix=result.reshape(self._matrix.shape), qid_shape=self._qid_shape) + def _decompose_(self, qubits: Tuple['cirq.Qid', ...]) -> 'cirq.OP_TREE': + if self._qid_shape == (2,): + return [ + g.on(qubits[0]) + for g in single_qubit_decompositions.single_qubit_matrix_to_gates(self._matrix) + ] + if self._qid_shape == (2,) * 2: + return two_qubit_to_cz.two_qubit_matrix_to_cz_operations( + *qubits, self._matrix, allow_partial_czs=True + ) + if self._qid_shape == (2,) * 3: + return three_qubit_decomposition.three_qubit_matrix_to_operations(*qubits, self._matrix) + return NotImplemented + def _has_unitary_(self) -> bool: return True diff --git a/cirq/ops/matrix_gates_test.py b/cirq/ops/matrix_gates_test.py index 16d21a45d18..848e685022d 100644 --- a/cirq/ops/matrix_gates_test.py +++ b/cirq/ops/matrix_gates_test.py @@ -276,16 +276,19 @@ def test_str_executes(): assert '0' in str(cirq.MatrixGate(np.eye(4))) -def test_one_qubit_consistent(): - u = cirq.testing.random_unitary(2) - g = cirq.MatrixGate(u) - cirq.testing.assert_implements_consistent_protocols(g) - - -def test_two_qubit_consistent(): - u = cirq.testing.random_unitary(4) - g = cirq.MatrixGate(u) - cirq.testing.assert_implements_consistent_protocols(g) +@pytest.mark.parametrize('n', [1, 2, 3, 4, 5]) +def test_implements_consistent_protocols(n): + u = cirq.testing.random_unitary(2 ** n) + g1 = cirq.MatrixGate(u) + cirq.testing.assert_implements_consistent_protocols(g1, ignoring_global_phase=True) + cirq.testing.assert_decompose_ends_at_default_gateset(g1) + + if n == 1: + return + + g2 = cirq.MatrixGate(u, qid_shape=(4,) + (2,) * (n - 2)) + cirq.testing.assert_implements_consistent_protocols(g2, ignoring_global_phase=True) + cirq.testing.assert_decompose_ends_at_default_gateset(g2) def test_repr(): diff --git a/cirq/testing/consistent_decomposition.py b/cirq/testing/consistent_decomposition.py index 978763c461e..313c074f14a 100644 --- a/cirq/testing/consistent_decomposition.py +++ b/cirq/testing/consistent_decomposition.py @@ -51,6 +51,8 @@ def assert_decompose_is_consistent_with_unitary(val: Any, ignoring_global_phase: def _known_gate_with_no_decomposition(val: Any): """Checks whether `val` is a known gate with no default decomposition to default gateset.""" + if isinstance(val, ops.MatrixGate): + return protocols.qid_shape(val) not in [(2,), (2,) * 2, (2,) * 3] return False