From fc35a4f7bd2ec828658bd21fc7fb5f4ec13b19e3 Mon Sep 17 00:00:00 2001 From: Tanuj Khattar Date: Thu, 17 Mar 2022 02:35:43 +0530 Subject: [PATCH] Add assert_decompose_ends_at_default_gateset consistency test (#5079) * Add assert_decompose_ends_at_default_gateset consistency test * Refactor assert_decompose_ends_at_default_gateset to provide hook to ignore gates without known decompositions --- cirq/protocols/decompose_protocol.py | 9 +++ cirq/testing/__init__.py | 1 + cirq/testing/consistent_decomposition.py | 17 ++++++ cirq/testing/consistent_decomposition_test.py | 61 +++++++++++++++++++ 4 files changed, 88 insertions(+) diff --git a/cirq/protocols/decompose_protocol.py b/cirq/protocols/decompose_protocol.py index dda931d1772..47aaa839414 100644 --- a/cirq/protocols/decompose_protocol.py +++ b/cirq/protocols/decompose_protocol.py @@ -47,6 +47,15 @@ DecomposeResult = Union[None, NotImplementedType, 'cirq.OP_TREE'] OpDecomposer = Callable[['cirq.Operation'], DecomposeResult] +DECOMPOSE_TARGET_GATESET = ops.Gateset( + ops.XPowGate, + ops.YPowGate, + ops.ZPowGate, + ops.CZPowGate, + ops.MeasurementGate, + ops.GlobalPhaseGate, +) + def _value_error_describing_bad_operation(op: 'cirq.Operation') -> ValueError: return ValueError(f"Operation doesn't satisfy the given `keep` but can't be decomposed: {op!r}") diff --git a/cirq/testing/__init__.py b/cirq/testing/__init__.py index f08cffc00d5..7124f91173e 100644 --- a/cirq/testing/__init__.py +++ b/cirq/testing/__init__.py @@ -33,6 +33,7 @@ ) from cirq.testing.consistent_decomposition import ( + assert_decompose_ends_at_default_gateset, assert_decompose_is_consistent_with_unitary, ) diff --git a/cirq/testing/consistent_decomposition.py b/cirq/testing/consistent_decomposition.py index 1050a48ffdd..978763c461e 100644 --- a/cirq/testing/consistent_decomposition.py +++ b/cirq/testing/consistent_decomposition.py @@ -47,3 +47,20 @@ def assert_decompose_is_consistent_with_unitary(val: Any, ignoring_global_phase: else: # coverage: ignore np.testing.assert_allclose(actual, expected, atol=1e-8) + + +def _known_gate_with_no_decomposition(val: Any): + """Checks whether `val` is a known gate with no default decomposition to default gateset.""" + return False + + +def assert_decompose_ends_at_default_gateset(val: Any): + """Asserts that cirq.decompose(val) ends at default cirq gateset or a known gate.""" + if _known_gate_with_no_decomposition(val): + return # coverage: ignore + args = () if isinstance(val, ops.Operation) else (tuple(devices.LineQid.for_gate(val)),) + dec_once = protocols.decompose_once(val, [val(*args[0]) if args else val], *args) + for op in [*ops.flatten_to_ops(protocols.decompose(d) for d in dec_once)]: + assert _known_gate_with_no_decomposition(op.gate) or ( + op in protocols.decompose_protocol.DECOMPOSE_TARGET_GATESET + ), f'{val} decomposed to {op}, which is not part of default cirq target gateset.' diff --git a/cirq/testing/consistent_decomposition_test.py b/cirq/testing/consistent_decomposition_test.py index e573084254a..002ed0668fd 100644 --- a/cirq/testing/consistent_decomposition_test.py +++ b/cirq/testing/consistent_decomposition_test.py @@ -15,6 +15,7 @@ import pytest import numpy as np +import sympy import cirq @@ -49,3 +50,63 @@ def test_assert_decompose_is_consistent_with_unitary(): cirq.testing.assert_decompose_is_consistent_with_unitary( BadGateDecompose().on(cirq.NamedQubit('q')) ) + + +class GateDecomposesToDefaultGateset(cirq.Gate): + def _num_qubits_(self): + return 2 + + def _decompose_(self, qubits): + return [GoodGateDecompose().on(qubits[0]), BadGateDecompose().on(qubits[1])] + + +class GateDecomposeDoesNotEndInDefaultGateset(cirq.Gate): + def _num_qubits_(self): + return 4 + + def _decompose_(self, qubits): + yield GateDecomposeNotImplemented().on_each(*qubits) + + +class GateDecomposeNotImplemented(cirq.SingleQubitGate): + def _decompose_(self, qubits): + return NotImplemented + + +class ParameterizedGate(cirq.SingleQubitGate): + def _num_qubits_(self): + return 2 + + def _decompose_(self, qubits): + yield cirq.X(qubits[0]) ** sympy.Symbol("x") + yield cirq.Y(qubits[1]) ** sympy.Symbol("y") + + +def test_assert_decompose_ends_at_default_gateset(): + + cirq.testing.assert_decompose_ends_at_default_gateset(GateDecomposesToDefaultGateset()) + cirq.testing.assert_decompose_ends_at_default_gateset( + GateDecomposesToDefaultGateset().on(*cirq.LineQubit.range(2)) + ) + + cirq.testing.assert_decompose_ends_at_default_gateset(ParameterizedGate()) + cirq.testing.assert_decompose_ends_at_default_gateset( + ParameterizedGate().on(*cirq.LineQubit.range(2)) + ) + + with pytest.raises(AssertionError): + cirq.testing.assert_decompose_ends_at_default_gateset(GateDecomposeNotImplemented()) + + with pytest.raises(AssertionError): + cirq.testing.assert_decompose_ends_at_default_gateset( + GateDecomposeNotImplemented().on(cirq.NamedQubit('q')) + ) + with pytest.raises(AssertionError): + cirq.testing.assert_decompose_ends_at_default_gateset( + GateDecomposeDoesNotEndInDefaultGateset() + ) + + with pytest.raises(AssertionError): + cirq.testing.assert_decompose_ends_at_default_gateset( + GateDecomposeDoesNotEndInDefaultGateset().on(*cirq.LineQubit.range(4)) + )